cache comm stream

Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
Sage Moore 2025-07-22 13:55:55 +00:00
parent 29a5ac1d04
commit 6d83b5ef3f
2 changed files with 18 additions and 14 deletions

View File

@@ -217,6 +217,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Request states.
self.requests: dict[str, CachedRequestState] = {}
self.comm_stream = torch.cuda.Stream()
# Input Batch
# NOTE(Chen): Ideally, we should initialize the input batch inside
@@ -1488,7 +1489,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
return input_ids, positions, inputs_embeds, intermediate_tensors
def _make_ubatch_metadata(self, ubatch_slices, attn_metadata,
compute_stream, num_tokens_across_dp,
compute_stream, num_tokens_across_dp,
skip_cuda_graphs,
scheduler_output) -> list[UbatchMetadata]:
@@ -1506,6 +1507,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
ubatch_ctxs = make_ubatch_contexts(
num_micro_batches=len(ubatch_slices),
comm_stream=self.comm_stream,
compute_stream=compute_stream,
forward_contexts=forward_contexts,
device=self.device)
@@ -1584,7 +1586,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# run normal batch
else:
input_ids, positions, inputs_embeds, intermediate_tensors = \
self._get_model_inputs(slice(0, num_scheduled_tokens),
self._get_model_inputs(slice(0, num_scheduled_tokens),
scheduler_output)
with set_forward_context(attn_metadata,
vllm_config=self.vllm_config,

View File

@@ -1,4 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import threading
from typing import Optional
@@ -14,17 +15,16 @@ class UBatchContext:
Context manager for micro-batching synchronization using threading events.
"""
def __init__(
self,
id: int,
comm_stream: torch.cuda.Stream,
compute_stream: torch.cuda.Stream,
forward_context: ForwardContext,
cpu_wait_event: threading.Event,
cpu_signal_event: threading.Event,
gpu_comm_done_event: torch.cuda.Event,
gpu_compute_done_event: torch.cuda.Event,
schedule: str = "default"):
def __init__(self,
id: int,
comm_stream: torch.cuda.Stream,
compute_stream: torch.cuda.Stream,
forward_context: ForwardContext,
cpu_wait_event: threading.Event,
cpu_signal_event: threading.Event,
gpu_comm_done_event: torch.cuda.Event,
gpu_compute_done_event: torch.cuda.Event,
schedule: str = "default"):
self.id = id
self.comm_stream = comm_stream
self.compute_stream = compute_stream
@@ -139,9 +139,11 @@ def yield_and_switch_from_comm_to_compute(schedule="default"):
if ctx is not None and ctx.schedule == schedule:
ctx.yield_and_switch_from_comm_to_compute()
def make_ubatch_contexts(
num_micro_batches: int,
compute_stream: torch.cuda.Stream,
comm_stream: torch.cuda.Stream,
forward_contexts: list[ForwardContext],
device: Optional[torch.device] = None,
schedule: str = "default",
@@ -158,7 +160,7 @@ def make_ubatch_contexts(
torch.cuda.Event() for _ in range(num_micro_batches)
]
device = device or torch.cuda.current_device()
comm_stream = torch.cuda.Stream(device)
# comm_stream = torch.cuda.Stream(device)
assert len(forward_contexts) == 2