cache comm stream

Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
Sage Moore 2025-07-22 13:55:55 +00:00
parent 29a5ac1d04
commit 6d83b5ef3f
2 changed files with 18 additions and 14 deletions

View File

@ -217,6 +217,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Request states. # Request states.
self.requests: dict[str, CachedRequestState] = {} self.requests: dict[str, CachedRequestState] = {}
self.comm_stream = torch.cuda.Stream()
# Input Batch # Input Batch
# NOTE(Chen): Ideally, we should initialize the input batch inside # NOTE(Chen): Ideally, we should initialize the input batch inside
@ -1506,6 +1507,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
ubatch_ctxs = make_ubatch_contexts( ubatch_ctxs = make_ubatch_contexts(
num_micro_batches=len(ubatch_slices), num_micro_batches=len(ubatch_slices),
comm_stream=self.comm_stream,
compute_stream=compute_stream, compute_stream=compute_stream,
forward_contexts=forward_contexts, forward_contexts=forward_contexts,
device=self.device) device=self.device)

View File

@ -1,4 +1,5 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import threading import threading
from typing import Optional from typing import Optional
@ -14,17 +15,16 @@ class UBatchContext:
Context manager for micro-batching synchronization using threading events. Context manager for micro-batching synchronization using threading events.
""" """
def __init__( def __init__(self,
self, id: int,
id: int, comm_stream: torch.cuda.Stream,
comm_stream: torch.cuda.Stream, compute_stream: torch.cuda.Stream,
compute_stream: torch.cuda.Stream, forward_context: ForwardContext,
forward_context: ForwardContext, cpu_wait_event: threading.Event,
cpu_wait_event: threading.Event, cpu_signal_event: threading.Event,
cpu_signal_event: threading.Event, gpu_comm_done_event: torch.cuda.Event,
gpu_comm_done_event: torch.cuda.Event, gpu_compute_done_event: torch.cuda.Event,
gpu_compute_done_event: torch.cuda.Event, schedule: str = "default"):
schedule: str = "default"):
self.id = id self.id = id
self.comm_stream = comm_stream self.comm_stream = comm_stream
self.compute_stream = compute_stream self.compute_stream = compute_stream
@ -139,9 +139,11 @@ def yield_and_switch_from_comm_to_compute(schedule="default"):
if ctx is not None and ctx.schedule == schedule: if ctx is not None and ctx.schedule == schedule:
ctx.yield_and_switch_from_comm_to_compute() ctx.yield_and_switch_from_comm_to_compute()
def make_ubatch_contexts( def make_ubatch_contexts(
num_micro_batches: int, num_micro_batches: int,
compute_stream: torch.cuda.Stream, compute_stream: torch.cuda.Stream,
comm_stream: torch.cuda.Stream,
forward_contexts: list[ForwardContext], forward_contexts: list[ForwardContext],
device: Optional[torch.device] = None, device: Optional[torch.device] = None,
schedule: str = "default", schedule: str = "default",
@ -158,7 +160,7 @@ def make_ubatch_contexts(
torch.cuda.Event() for _ in range(num_micro_batches) torch.cuda.Event() for _ in range(num_micro_batches)
] ]
device = device or torch.cuda.current_device() device = device or torch.cuda.current_device()
comm_stream = torch.cuda.Stream(device) # comm_stream = torch.cuda.Stream(device)
assert len(forward_contexts) == 2 assert len(forward_contexts) == 2