use vllm current_stream

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Lucas Wilkinson 2025-05-21 04:53:50 +00:00
parent 04f11d97a0
commit 2259b47951
2 changed files with 10 additions and 7 deletions
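
Why the swap: vllm.utils.current_stream() is a cheaper drop-in for
torch.cuda.current_stream(), which constructs a fresh Stream wrapper on
every call. Below is a minimal sketch of the caching idea, assuming the
helper keeps a module-level cache refreshed on stream switches; it is an
assumption about the helper's shape, not the exact vLLM code.

    import torch

    _current_stream = None

    def current_stream() -> torch.cuda.Stream:
        # torch.cuda.current_stream() allocates a new Stream object per
        # call, which adds up on hot paths; cache it instead.
        global _current_stream
        if _current_stream is None:
            _current_stream = torch.cuda.current_stream()
        return _current_stream

    _original_set_stream = torch.cuda.set_stream

    def _set_stream(stream: torch.cuda.Stream) -> None:
        # keep the cache in sync whenever the active stream changes
        global _current_stream
        _current_stream = stream
        _original_set_stream(stream)

    torch.cuda.set_stream = _set_stream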

View File

@@ -12,6 +12,7 @@ import torch
 import torch.distributed
 import torch.nn as nn
+from vllm.utils import current_stream
 from vllm.attention import AttentionType, get_attn_backend
 from vllm.attention.backends.abstract import (AttentionBackend,
                                               AttentionMetadataBuilder)
@@ -1326,10 +1327,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             # since different ubatches may be on different streams than this one
             # they all need to wait on the gpu kernels launched in
             # _prepare_inputs before continuing
-            if torch.cuda.current_stream() != root_stream:
+            if current_stream() != root_stream:
                 start_evt = torch.cuda.Event()
                 # Make sure we wait then record so we don't miss the event
-                torch.cuda.current_stream().wait_event(start_evt)
+                current_stream().wait_event(start_evt)
                 root_stream.record_event(start_evt)
             model_output = _run(token_slice, attn_metadata, use_dummy_input, ubatch_ctx, setup_done_evt)
@@ -1337,16 +1338,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             if save_results:
                 results.append(model_output.clone())
-            if torch.cuda.current_stream() != root_stream:
+            if current_stream() != root_stream:
                 # Make the root stream wait for the ubatch to finish
                 # Make sure we wait then record so we don't miss the event
                 root_stream.wait_event(ubatch_ctx.done_evt)
-                torch.cuda.current_stream().record_event(ubatch_ctx.done_evt)
+                current_stream().record_event(ubatch_ctx.done_evt)

         def _run_ubatches(ubatch_slices, attn_metadata, is_dummy_run):
             results = []
             assert len(ubatch_slices) == 2, "Only two ubatches have been tested"
-            root_stream = torch.cuda.current_stream()
+            root_stream = current_stream()
             ubatch_ctxs = make_ubatch_context_chain(len(ubatch_slices),
                 stream=root_stream,  # Only works currently if everything is run on the same stream
                 device=self.device)
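
The two hunks above implement a fork/join between the root stream and the
per-ubatch streams using CUDA events. The diff issues wait_event before
record_event because the ubatches run on separate host threads; a
single-threaded sketch of the same handoff, using the conventional
record-then-wait order, looks like the following (worker_stream and the
toy matmul are illustrative, not names from the diff):

    import torch

    root_stream = torch.cuda.current_stream()
    worker_stream = torch.cuda.Stream()
    start_evt = torch.cuda.Event()
    done_evt = torch.cuda.Event()

    # fork: the worker must not start before kernels already queued on
    # the root stream (e.g. input preparation) have run
    root_stream.record_event(start_evt)
    worker_stream.wait_event(start_evt)

    with torch.cuda.stream(worker_stream):
        x = torch.randn(1024, 1024, device="cuda")
        y = x @ x  # stands in for the ubatch forward pass

    # join: the root stream must not consume y until the worker's
    # kernels have finished
    worker_stream.record_event(done_evt)
    root_stream.wait_event(done_evt)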


@@ -6,6 +6,8 @@ import torch.profiler as profiler
 from typing import Optional
 from torch.library import Library
 from torch.library import custom_op, register_kernel
+from vllm.utils import current_stream
+from vllm import forward_context

 class UBatchContext:
@@ -21,7 +23,7 @@ class UBatchContext:
         self.signal_event = signal_event
         self.schedule = schedule
         self.stream = stream
-        self.original_stream = torch.cuda.current_stream()
+        self.original_stream = current_stream()
         self.done_evt = torch.cuda.Event()
         self.forward_context = None
@@ -32,7 +34,7 @@ class UBatchContext:
         global _CURRENT_CONTEXT
         _CURRENT_CONTEXT[threading.get_ident()] = self
-        self.original_stream = torch.cuda.current_stream()
+        self.original_stream = current_stream()
         self.original_forward_context = forward_context._forward_context
         self.forward_context = self.original_forward_context
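
UBatchContext snapshots the active stream into original_stream on entry so
it can be restored when the ubatch finishes. The same bookkeeping in
isolation, as a hypothetical standalone context manager (StreamSwap is not
part of vLLM):

    import torch

    class StreamSwap:
        """Run a block on a given stream, restoring the previous one on exit."""

        def __init__(self, stream: torch.cuda.Stream):
            self.stream = stream
            self.original_stream = None

        def __enter__(self):
            # remember whichever stream was active so __exit__ can restore it
            self.original_stream = torch.cuda.current_stream()
            torch.cuda.set_stream(self.stream)
            return self

        def __exit__(self, exc_type, exc, tb):
            torch.cuda.set_stream(self.original_stream)
            return False

Using the cached current_stream() inside such a class is what this commit
enables: the snapshot on entry becomes a cheap cache read instead of a
fresh Stream construction.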