mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-22 18:17:55 +08:00
cache comm stream
Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
parent
29a5ac1d04
commit
6d83b5ef3f
@ -217,6 +217,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
|
|
||||||
# Request states.
|
# Request states.
|
||||||
self.requests: dict[str, CachedRequestState] = {}
|
self.requests: dict[str, CachedRequestState] = {}
|
||||||
|
self.comm_stream = torch.cuda.Stream()
|
||||||
|
|
||||||
# Input Batch
|
# Input Batch
|
||||||
# NOTE(Chen): Ideally, we should initialize the input batch inside
|
# NOTE(Chen): Ideally, we should initialize the input batch inside
|
||||||
@ -1506,6 +1507,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
|
|
||||||
ubatch_ctxs = make_ubatch_contexts(
|
ubatch_ctxs = make_ubatch_contexts(
|
||||||
num_micro_batches=len(ubatch_slices),
|
num_micro_batches=len(ubatch_slices),
|
||||||
|
comm_stream=self.comm_stream,
|
||||||
compute_stream=compute_stream,
|
compute_stream=compute_stream,
|
||||||
forward_contexts=forward_contexts,
|
forward_contexts=forward_contexts,
|
||||||
device=self.device)
|
device=self.device)
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
import threading
|
import threading
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
@ -14,17 +15,16 @@ class UBatchContext:
|
|||||||
Context manager for micro-batching synchronization using threading events.
|
Context manager for micro-batching synchronization using threading events.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self,
|
||||||
self,
|
id: int,
|
||||||
id: int,
|
comm_stream: torch.cuda.Stream,
|
||||||
comm_stream: torch.cuda.Stream,
|
compute_stream: torch.cuda.Stream,
|
||||||
compute_stream: torch.cuda.Stream,
|
forward_context: ForwardContext,
|
||||||
forward_context: ForwardContext,
|
cpu_wait_event: threading.Event,
|
||||||
cpu_wait_event: threading.Event,
|
cpu_signal_event: threading.Event,
|
||||||
cpu_signal_event: threading.Event,
|
gpu_comm_done_event: torch.cuda.Event,
|
||||||
gpu_comm_done_event: torch.cuda.Event,
|
gpu_compute_done_event: torch.cuda.Event,
|
||||||
gpu_compute_done_event: torch.cuda.Event,
|
schedule: str = "default"):
|
||||||
schedule: str = "default"):
|
|
||||||
self.id = id
|
self.id = id
|
||||||
self.comm_stream = comm_stream
|
self.comm_stream = comm_stream
|
||||||
self.compute_stream = compute_stream
|
self.compute_stream = compute_stream
|
||||||
@ -139,9 +139,11 @@ def yield_and_switch_from_comm_to_compute(schedule="default"):
|
|||||||
if ctx is not None and ctx.schedule == schedule:
|
if ctx is not None and ctx.schedule == schedule:
|
||||||
ctx.yield_and_switch_from_comm_to_compute()
|
ctx.yield_and_switch_from_comm_to_compute()
|
||||||
|
|
||||||
|
|
||||||
def make_ubatch_contexts(
|
def make_ubatch_contexts(
|
||||||
num_micro_batches: int,
|
num_micro_batches: int,
|
||||||
compute_stream: torch.cuda.Stream,
|
compute_stream: torch.cuda.Stream,
|
||||||
|
comm_stream: torch.cuda.Stream,
|
||||||
forward_contexts: list[ForwardContext],
|
forward_contexts: list[ForwardContext],
|
||||||
device: Optional[torch.device] = None,
|
device: Optional[torch.device] = None,
|
||||||
schedule: str = "default",
|
schedule: str = "default",
|
||||||
@ -158,7 +160,7 @@ def make_ubatch_contexts(
|
|||||||
torch.cuda.Event() for _ in range(num_micro_batches)
|
torch.cuda.Event() for _ in range(num_micro_batches)
|
||||||
]
|
]
|
||||||
device = device or torch.cuda.current_device()
|
device = device or torch.cuda.current_device()
|
||||||
comm_stream = torch.cuda.Stream(device)
|
# comm_stream = torch.cuda.Stream(device)
|
||||||
|
|
||||||
assert len(forward_contexts) == 2
|
assert len(forward_contexts) == 2
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user