mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-07 20:17:06 +08:00
use vllm current_stream
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
parent
04f11d97a0
commit
2259b47951
@ -12,6 +12,7 @@ import torch
|
||||
import torch.distributed
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.utils import current_stream
|
||||
from vllm.attention import AttentionType, get_attn_backend
|
||||
from vllm.attention.backends.abstract import (AttentionBackend,
|
||||
AttentionMetadataBuilder)
|
||||
@ -1326,10 +1327,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
# since different ubatches may be on different streams than this one
|
||||
# they all need to wait on the gpu kernels launched in
|
||||
# _prepare_inputs before continuing
|
||||
if torch.cuda.current_stream() != root_stream:
|
||||
if current_stream() != root_stream:
|
||||
start_evt = torch.cuda.Event()
|
||||
# Make sure we wait then record so we don't miss the event
|
||||
torch.cuda.current_stream().wait_event(start_evt)
|
||||
current_stream().wait_event(start_evt)
|
||||
root_stream.record_event(start_evt)
|
||||
|
||||
model_output = _run(token_slice, attn_metadata, use_dummy_input, ubatch_ctx, setup_done_evt)
|
||||
@ -1337,16 +1338,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
if save_results:
|
||||
results.append(model_output.clone())
|
||||
|
||||
if torch.cuda.current_stream() != root_stream:
|
||||
if current_stream() != root_stream:
|
||||
# Make the root stream wait for the ubatch to finish
|
||||
# Make sure we wait then record so we don't miss the event
|
||||
root_stream.wait_event(ubatch_ctx.done_evt)
|
||||
torch.cuda.current_stream().record_event(ubatch_ctx.done_evt)
|
||||
current_stream().record_event(ubatch_ctx.done_evt)
|
||||
|
||||
def _run_ubatches(ubatch_slices, attn_metadata, is_dummy_run):
|
||||
results = []
|
||||
assert len(ubatch_slices) == 2, "Only two ubatches has been tested"
|
||||
root_stream = torch.cuda.current_stream()
|
||||
root_stream = current_stream()
|
||||
ubatch_ctxs = make_ubatch_context_chain(len(ubatch_slices),
|
||||
stream=root_stream, # Only works currently if everything is run on the same stream
|
||||
device=self.device)
|
||||
|
||||
@ -6,6 +6,8 @@ import torch.profiler as profiler
|
||||
from typing import Optional
|
||||
from torch.library import Library
|
||||
from torch.library import custom_op, register_kernel
|
||||
|
||||
from vllm.utils import current_stream
|
||||
from vllm import forward_context
|
||||
|
||||
class UBatchContext:
|
||||
@ -21,7 +23,7 @@ class UBatchContext:
|
||||
self.signal_event = signal_event
|
||||
self.schedule = schedule
|
||||
self.stream = stream
|
||||
self.original_stream = torch.cuda.current_stream()
|
||||
self.original_stream = current_stream()
|
||||
self.done_evt = torch.cuda.Event()
|
||||
self.forward_context = None
|
||||
|
||||
@ -32,7 +34,7 @@ class UBatchContext:
|
||||
global _CURRENT_CONTEXT
|
||||
_CURRENT_CONTEXT[threading.get_ident()] = self
|
||||
|
||||
self.original_stream = torch.cuda.current_stream()
|
||||
self.original_stream = current_stream()
|
||||
self.original_forward_context = forward_context._forward_context
|
||||
self.forward_context = self.original_forward_context
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user