cache comm stream

Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
Sage Moore 2025-07-22 13:55:55 +00:00
parent 29a5ac1d04
commit 6d83b5ef3f
2 changed files with 18 additions and 14 deletions

View File

@ -217,6 +217,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Request states. # Request states.
self.requests: dict[str, CachedRequestState] = {} self.requests: dict[str, CachedRequestState] = {}
self.comm_stream = torch.cuda.Stream()
# Input Batch # Input Batch
# NOTE(Chen): Ideally, we should initialize the input batch inside # NOTE(Chen): Ideally, we should initialize the input batch inside
@ -1488,7 +1489,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
return input_ids, positions, inputs_embeds, intermediate_tensors return input_ids, positions, inputs_embeds, intermediate_tensors
def _make_ubatch_metadata(self, ubatch_slices, attn_metadata, def _make_ubatch_metadata(self, ubatch_slices, attn_metadata,
compute_stream, num_tokens_across_dp, compute_stream, num_tokens_across_dp,
skip_cuda_graphs, skip_cuda_graphs,
scheduler_output) -> list[UbatchMetadata]: scheduler_output) -> list[UbatchMetadata]:
@ -1506,6 +1507,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
ubatch_ctxs = make_ubatch_contexts( ubatch_ctxs = make_ubatch_contexts(
num_micro_batches=len(ubatch_slices), num_micro_batches=len(ubatch_slices),
comm_stream=self.comm_stream,
compute_stream=compute_stream, compute_stream=compute_stream,
forward_contexts=forward_contexts, forward_contexts=forward_contexts,
device=self.device) device=self.device)
@ -1584,7 +1586,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# run normal batch # run normal batch
else: else:
input_ids, positions, inputs_embeds, intermediate_tensors = \ input_ids, positions, inputs_embeds, intermediate_tensors = \
self._get_model_inputs(slice(0, num_scheduled_tokens), self._get_model_inputs(slice(0, num_scheduled_tokens),
scheduler_output) scheduler_output)
with set_forward_context(attn_metadata, with set_forward_context(attn_metadata,
vllm_config=self.vllm_config, vllm_config=self.vllm_config,

View File

@ -1,4 +1,5 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import threading import threading
from typing import Optional from typing import Optional
@ -14,17 +15,16 @@ class UBatchContext:
Context manager for micro-batching synchronization using threading events. Context manager for micro-batching synchronization using threading events.
""" """
def __init__( def __init__(self,
self, id: int,
id: int, comm_stream: torch.cuda.Stream,
comm_stream: torch.cuda.Stream, compute_stream: torch.cuda.Stream,
compute_stream: torch.cuda.Stream, forward_context: ForwardContext,
forward_context: ForwardContext, cpu_wait_event: threading.Event,
cpu_wait_event: threading.Event, cpu_signal_event: threading.Event,
cpu_signal_event: threading.Event, gpu_comm_done_event: torch.cuda.Event,
gpu_comm_done_event: torch.cuda.Event, gpu_compute_done_event: torch.cuda.Event,
gpu_compute_done_event: torch.cuda.Event, schedule: str = "default"):
schedule: str = "default"):
self.id = id self.id = id
self.comm_stream = comm_stream self.comm_stream = comm_stream
self.compute_stream = compute_stream self.compute_stream = compute_stream
@ -139,9 +139,11 @@ def yield_and_switch_from_comm_to_compute(schedule="default"):
if ctx is not None and ctx.schedule == schedule: if ctx is not None and ctx.schedule == schedule:
ctx.yield_and_switch_from_comm_to_compute() ctx.yield_and_switch_from_comm_to_compute()
def make_ubatch_contexts( def make_ubatch_contexts(
num_micro_batches: int, num_micro_batches: int,
compute_stream: torch.cuda.Stream, compute_stream: torch.cuda.Stream,
comm_stream: torch.cuda.Stream,
forward_contexts: list[ForwardContext], forward_contexts: list[ForwardContext],
device: Optional[torch.device] = None, device: Optional[torch.device] = None,
schedule: str = "default", schedule: str = "default",
@ -158,7 +160,7 @@ def make_ubatch_contexts(
torch.cuda.Event() for _ in range(num_micro_batches) torch.cuda.Event() for _ in range(num_micro_batches)
] ]
device = device or torch.cuda.current_device() device = device or torch.cuda.current_device()
comm_stream = torch.cuda.Stream(device) # comm_stream = torch.cuda.Stream(device)
assert len(forward_contexts) == 2 assert len(forward_contexts) == 2