diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 9f79ac3a5f9f4..a0e04dbac9213 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -217,6 +217,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
 
         # Request states.
         self.requests: dict[str, CachedRequestState] = {}
+        self.comm_stream = torch.cuda.Stream()
 
         # Input Batch
         # NOTE(Chen): Ideally, we should initialize the input batch inside
@@ -1488,7 +1489,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         return input_ids, positions, inputs_embeds, intermediate_tensors
 
     def _make_ubatch_metadata(self, ubatch_slices, attn_metadata,
-                              compute_stream, num_tokens_across_dp, 
+                              compute_stream, num_tokens_across_dp,
                               skip_cuda_graphs,
                               scheduler_output) -> list[UbatchMetadata]:
 
@@ -1506,6 +1507,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
 
         ubatch_ctxs = make_ubatch_contexts(
             num_micro_batches=len(ubatch_slices),
+            comm_stream=self.comm_stream,
             compute_stream=compute_stream,
             forward_contexts=forward_contexts,
             device=self.device)
@@ -1584,7 +1586,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         # run normal batch
         else:
             input_ids, positions, inputs_embeds, intermediate_tensors = \
-                self._get_model_inputs(slice(0, num_scheduled_tokens), 
+                self._get_model_inputs(slice(0, num_scheduled_tokens),
                                        scheduler_output)
             with set_forward_context(attn_metadata,
                                      vllm_config=self.vllm_config,
diff --git a/vllm/v1/worker/ubatching.py b/vllm/v1/worker/ubatching.py
index b16854c9375a5..433cf28bef1c4 100644
--- a/vllm/v1/worker/ubatching.py
+++ b/vllm/v1/worker/ubatching.py
@@ -1,4 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import threading
 from typing import Optional
 
@@ -14,17 +15,16 @@ class UBatchContext:
     Context manager for micro-batching synchronization using threading events.
     """
 
-    def __init__(
-        self,
-        id: int,
-        comm_stream: torch.cuda.Stream,
-        compute_stream: torch.cuda.Stream,
-        forward_context: ForwardContext,
-        cpu_wait_event: threading.Event,
-        cpu_signal_event: threading.Event,
-        gpu_comm_done_event: torch.cuda.Event,
-        gpu_compute_done_event: torch.cuda.Event,
-        schedule: str = "default"):
+    def __init__(self,
+                 id: int,
+                 comm_stream: torch.cuda.Stream,
+                 compute_stream: torch.cuda.Stream,
+                 forward_context: ForwardContext,
+                 cpu_wait_event: threading.Event,
+                 cpu_signal_event: threading.Event,
+                 gpu_comm_done_event: torch.cuda.Event,
+                 gpu_compute_done_event: torch.cuda.Event,
+                 schedule: str = "default"):
         self.id = id
         self.comm_stream = comm_stream
         self.compute_stream = compute_stream
@@ -139,9 +139,11 @@ def yield_and_switch_from_comm_to_compute(schedule="default"):
     if ctx is not None and ctx.schedule == schedule:
         ctx.yield_and_switch_from_comm_to_compute()
 
+
 def make_ubatch_contexts(
     num_micro_batches: int,
     compute_stream: torch.cuda.Stream,
+    comm_stream: torch.cuda.Stream,
     forward_contexts: list[ForwardContext],
     device: Optional[torch.device] = None,
     schedule: str = "default",
@@ -158,7 +160,7 @@ def make_ubatch_contexts(
         torch.cuda.Event() for _ in range(num_micro_batches)
     ]
     device = device or torch.cuda.current_device()
-    comm_stream = torch.cuda.Stream(device)
+    # comm_stream = torch.cuda.Stream(device)
 
     assert len(forward_contexts) == 2