mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-08 02:07:02 +08:00
cache comm stream
Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
parent
29a5ac1d04
commit
6d83b5ef3f
@ -217,6 +217,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
# Request states.
|
||||
self.requests: dict[str, CachedRequestState] = {}
|
||||
self.comm_stream = torch.cuda.Stream()
|
||||
|
||||
# Input Batch
|
||||
# NOTE(Chen): Ideally, we should initialize the input batch inside
|
||||
@ -1488,7 +1489,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
return input_ids, positions, inputs_embeds, intermediate_tensors
|
||||
|
||||
def _make_ubatch_metadata(self, ubatch_slices, attn_metadata,
|
||||
compute_stream, num_tokens_across_dp,
|
||||
compute_stream, num_tokens_across_dp,
|
||||
skip_cuda_graphs,
|
||||
scheduler_output) -> list[UbatchMetadata]:
|
||||
|
||||
@ -1506,6 +1507,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
ubatch_ctxs = make_ubatch_contexts(
|
||||
num_micro_batches=len(ubatch_slices),
|
||||
comm_stream=self.comm_stream,
|
||||
compute_stream=compute_stream,
|
||||
forward_contexts=forward_contexts,
|
||||
device=self.device)
|
||||
@ -1584,7 +1586,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
# run normal batch
|
||||
else:
|
||||
input_ids, positions, inputs_embeds, intermediate_tensors = \
|
||||
self._get_model_inputs(slice(0, num_scheduled_tokens),
|
||||
self._get_model_inputs(slice(0, num_scheduled_tokens),
|
||||
scheduler_output)
|
||||
with set_forward_context(attn_metadata,
|
||||
vllm_config=self.vllm_config,
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import threading
|
||||
from typing import Optional
|
||||
|
||||
@ -14,17 +15,16 @@ class UBatchContext:
|
||||
Context manager for micro-batching synchronization using threading events.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: int,
|
||||
comm_stream: torch.cuda.Stream,
|
||||
compute_stream: torch.cuda.Stream,
|
||||
forward_context: ForwardContext,
|
||||
cpu_wait_event: threading.Event,
|
||||
cpu_signal_event: threading.Event,
|
||||
gpu_comm_done_event: torch.cuda.Event,
|
||||
gpu_compute_done_event: torch.cuda.Event,
|
||||
schedule: str = "default"):
|
||||
def __init__(self,
|
||||
id: int,
|
||||
comm_stream: torch.cuda.Stream,
|
||||
compute_stream: torch.cuda.Stream,
|
||||
forward_context: ForwardContext,
|
||||
cpu_wait_event: threading.Event,
|
||||
cpu_signal_event: threading.Event,
|
||||
gpu_comm_done_event: torch.cuda.Event,
|
||||
gpu_compute_done_event: torch.cuda.Event,
|
||||
schedule: str = "default"):
|
||||
self.id = id
|
||||
self.comm_stream = comm_stream
|
||||
self.compute_stream = compute_stream
|
||||
@ -139,9 +139,11 @@ def yield_and_switch_from_comm_to_compute(schedule="default"):
|
||||
if ctx is not None and ctx.schedule == schedule:
|
||||
ctx.yield_and_switch_from_comm_to_compute()
|
||||
|
||||
|
||||
def make_ubatch_contexts(
|
||||
num_micro_batches: int,
|
||||
compute_stream: torch.cuda.Stream,
|
||||
comm_stream: torch.cuda.Stream,
|
||||
forward_contexts: list[ForwardContext],
|
||||
device: Optional[torch.device] = None,
|
||||
schedule: str = "default",
|
||||
@ -158,7 +160,7 @@ def make_ubatch_contexts(
|
||||
torch.cuda.Event() for _ in range(num_micro_batches)
|
||||
]
|
||||
device = device or torch.cuda.current_device()
|
||||
comm_stream = torch.cuda.Stream(device)
|
||||
# comm_stream = torch.cuda.Stream(device)
|
||||
|
||||
assert len(forward_contexts) == 2
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user