diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 2e9b63c4cb00e..a8becaaf5f74c 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -12,6 +12,7 @@ import torch import torch.distributed import torch.nn as nn +from vllm.utils import current_stream from vllm.attention import AttentionType, get_attn_backend from vllm.attention.backends.abstract import (AttentionBackend, AttentionMetadataBuilder) @@ -1326,10 +1327,10 @@ class GPUModelRunner(LoRAModelRunnerMixin): # since different ubatches may be on different streams than this one # they all need to wait on the gpu kernels launched in # _prepare_inputs before continuing - if torch.cuda.current_stream() != root_stream: + if current_stream() != root_stream: start_evt = torch.cuda.Event() # Make sure we wait then record so we don't miss the event - torch.cuda.current_stream().wait_event(start_evt) + current_stream().wait_event(start_evt) root_stream.record_event(start_evt) model_output = _run(token_slice, attn_metadata, use_dummy_input, ubatch_ctx, setup_done_evt) @@ -1337,16 +1338,16 @@ class GPUModelRunner(LoRAModelRunnerMixin): if save_results: results.append(model_output.clone()) - if torch.cuda.current_stream() != root_stream: + if current_stream() != root_stream: # Make the root stream for the ubatch to finish # Make sure we wait then record so we don't miss the event root_stream.wait_event(ubatch_ctx.done_evt) - torch.cuda.current_stream().record_event(ubatch_ctx.done_evt) + current_stream().record_event(ubatch_ctx.done_evt) def _run_ubatches(ubatch_slices, attn_metadata, is_dummy_run): results = [] assert len(ubatch_slices) == 2, "Only two ubatches has been tested" - root_stream = torch.cuda.current_stream() + root_stream = current_stream() ubatch_ctxs = make_ubatch_context_chain(len(ubatch_slices), stream=root_stream, # Only works currently if everything is run on the same stream device=self.device) diff --git a/vllm/v1/worker/ubatching.py b/vllm/v1/worker/ubatching.py index 5ed4fb91458bc..024f64a07ff8c 100644 --- a/vllm/v1/worker/ubatching.py +++ b/vllm/v1/worker/ubatching.py @@ -6,6 +6,8 @@ import torch.profiler as profiler from typing import Optional from torch.library import Library from torch.library import custom_op, register_kernel + +from vllm.utils import current_stream from vllm import forward_context class UBatchContext: @@ -21,7 +23,7 @@ class UBatchContext: self.signal_event = signal_event self.schedule = schedule self.stream = stream - self.original_stream = torch.cuda.current_stream() + self.original_stream = current_stream() self.done_evt = torch.cuda.Event() self.forward_context = None @@ -32,7 +34,7 @@ class UBatchContext: global _CURRENT_CONTEXT _CURRENT_CONTEXT[threading.get_ident()] = self - self.original_stream = torch.cuda.current_stream() + self.original_stream = current_stream() self.original_forward_context = forward_context._forward_context self.forward_context = self.original_forward_context