mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-13 22:55:57 +08:00
232 lines
7.8 KiB
Python
232 lines
7.8 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
import threading
|
|
from typing import Optional
|
|
|
|
import torch
|
|
|
|
from vllm import forward_context
|
|
from vllm.forward_context import ForwardContext
|
|
from vllm.utils.torch_utils import current_stream
|
|
|
|
# Maps the ident of every thread currently inside a UBatchContext (see
# UBatchContext.__enter__/__exit__) to that context's id, which is also its
# index into _CURRENT_CONTEXTS.
_THREAD_ID_TO_CONTEXT: dict[int, int] = {}
# Active contexts indexed by micro-batch id. A slot holds None whenever no
# thread is inside the corresponding context.
_CURRENT_CONTEXTS: list[Optional["UBatchContext"]] = [None, None]
|
|
|
|
|
|
class UBatchContext:
    """
    Context manager that coordinates one micro-batch thread with its peer.

    CPU-side hand-off is done with a pair of ``threading.Event`` objects (one
    this context waits on, one it sets to wake the other side). Ordering
    between the compute and communication CUDA streams is done with a pair of
    ``torch.cuda.Event`` objects recorded on one stream and waited on by the
    other.
    """

    def __init__(
        self,
        id: int,
        comm_stream: torch.cuda.Stream,
        compute_stream: torch.cuda.Stream,
        forward_context: ForwardContext,
        ready_barrier: threading.Barrier,
        cpu_wait_event: threading.Event,
        cpu_signal_event: threading.Event,
        gpu_comm_done_event: torch.cuda.Event,
        gpu_compute_done_event: torch.cuda.Event,
        schedule: str = "default",
    ):
        # Identity / configuration.
        self.id = id
        self.schedule = schedule
        self.forward_context = forward_context

        # CUDA streams; a freshly built context considers itself to be on
        # the compute stream.
        self.comm_stream = comm_stream
        self.compute_stream = compute_stream
        self.current_stream = compute_stream

        # CPU-side synchronization primitives.
        self.ready_barrier = ready_barrier
        self.cpu_wait_event = cpu_wait_event
        self.cpu_signal_event = cpu_signal_event

        # GPU-side synchronization primitives.
        self.gpu_comm_done_event = gpu_comm_done_event
        self.gpu_compute_done_event = gpu_compute_done_event

        # Optional callable run (once) by maybe_run_recv_hook().
        self.recv_hook = None

    def __enter__(self):
        # Register this thread/context pair in the module-level registries
        # before rendezvousing on the barrier.
        _THREAD_ID_TO_CONTEXT[threading.get_ident()] = self.id
        _CURRENT_CONTEXTS[self.id] = self
        self.ready_barrier.wait()

        # Sleep until our wait event is set, then consume the wake-up.
        self.cpu_wait_event.wait()
        self.cpu_wait_event.clear()
        self._restore_context()
        # Assume we want to start on the compute stream
        self.update_stream(self.compute_stream)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Unregister first, then run any pending hook and wake the peer.
        _CURRENT_CONTEXTS[self.id] = None
        del _THREAD_ID_TO_CONTEXT[threading.get_ident()]
        self.maybe_run_recv_hook()
        self.cpu_signal_event.set()
        self.cpu_wait_event.clear()
        # Never suppress an exception raised inside the with-body.
        return False

    def _restore_context(self):
        """Install this micro-batch's forward context as the active one."""
        forward_context._forward_context = self.forward_context

    def update_stream(self, stream):
        """Remember ``stream`` as current and make it the active CUDA stream
        if it is not already."""
        self.current_stream = stream
        if current_stream() != stream:
            torch.cuda.set_stream(stream)

    def _signal_comm_done(self):
        """Record the comm-done event on the communication stream."""
        self.gpu_comm_done_event.record(self.comm_stream)

    def _signal_compute_done(self):
        """Record the compute-done event on the compute stream."""
        self.gpu_compute_done_event.record(self.compute_stream)

    def _wait_compute_done(self):
        """Make the communication stream wait on the compute-done event."""
        self.comm_stream.wait_event(self.gpu_compute_done_event)

    def _wait_comm_done(self):
        """Make the compute stream wait on the comm-done event."""
        self.compute_stream.wait_event(self.gpu_comm_done_event)

    def _cpu_yield(self):
        """Wake the peer thread, sleep until woken again, then restore state."""
        # It is critical for correctness that only one thread is running
        # at a time. These asserts just make sure that this is the only
        # thread running before waking the other one up and going to sleep
        assert forward_context._forward_context == self.forward_context
        assert current_stream() == self.current_stream
        assert not self.cpu_wait_event.is_set()

        self.cpu_signal_event.set()
        self.cpu_wait_event.wait()
        self.cpu_wait_event.clear()
        self._restore_context()

    def switch_to_comm(self):
        """Switch to the communication stream without GPU synchronization."""
        self.update_stream(self.comm_stream)

    def switch_to_compute(self):
        """Switch to the compute stream without GPU synchronization."""
        self.update_stream(self.compute_stream)

    def switch_to_comm_sync(self):
        """Switch to the communication stream, ordering it after the work
        recorded so far on the compute stream."""
        self._signal_compute_done()
        self.update_stream(self.comm_stream)
        self._wait_compute_done()

    def switch_to_compute_sync(self):
        """Switch to the compute stream, ordering it after the work recorded
        so far on the communication stream."""
        self._signal_comm_done()
        self.update_stream(self.compute_stream)
        self._wait_comm_done()

    def maybe_run_recv_hook(self):
        """Run the registered receive hook, if any, and clear it."""
        hook = self.recv_hook
        if hook is None:
            return
        hook()
        self.recv_hook = None

    def yield_(self):
        """Yield the CPU to the peer and resume on whichever stream was
        active when the yield started."""
        self.current_stream = current_stream()
        self._cpu_yield()
        self.update_stream(self.current_stream)

    def yield_and_switch_from_compute_to_comm(self):
        """Record compute-done, yield to the peer, then resume on the
        communication stream once it waits on the compute-done event."""
        assert current_stream() == self.compute_stream
        self._signal_compute_done()
        self._cpu_yield()
        assert self.current_stream == self.compute_stream
        self.update_stream(self.comm_stream)
        self._wait_compute_done()

    def yield_and_switch_from_comm_to_compute(self):
        """Record comm-done, yield to the peer, then resume on the compute
        stream once it waits on the comm-done event."""
        assert current_stream() == self.comm_stream
        self._signal_comm_done()
        self._cpu_yield()
        assert self.current_stream == self.comm_stream
        self.update_stream(self.compute_stream)
        self._wait_comm_done()
|
|
|
|
|
|
def dbo_enabled() -> bool:
    """Return True while at least one thread is registered in
    ``_THREAD_ID_TO_CONTEXT`` (i.e. is inside a ``UBatchContext``)."""
    return bool(_THREAD_ID_TO_CONTEXT)
|
|
|
|
|
|
def dbo_current_ubatch_id() -> int:
    """Return the micro-batch id registered for the calling thread.

    Returns 0 when no thread is registered at all. When the registry is
    non-empty the calling thread must be registered, otherwise the lookup
    raises ``KeyError``.
    """
    return (
        _THREAD_ID_TO_CONTEXT[threading.get_ident()]
        if _THREAD_ID_TO_CONTEXT
        else 0
    )
|
|
|
|
|
|
def _register_ubatch_function(func):
    """Turn a ``UBatchContext`` method into a module-level function.

    The returned wrapper looks up the context registered for the calling
    thread and invokes ``func`` on it, forwarding all arguments. When no
    thread is registered in ``_THREAD_ID_TO_CONTEXT`` the wrapper is a no-op
    and returns ``None``.

    Args:
        func: Callable taking the context as its first positional argument.

    Returns:
        A wrapper that returns whatever ``func`` returns (or ``None`` when
        no micro-batch context is active).
    """

    def wrapper(*args, **kwargs):
        if not _THREAD_ID_TO_CONTEXT:
            return None
        ctx_idx = _THREAD_ID_TO_CONTEXT[threading.get_ident()]
        ctx = _CURRENT_CONTEXTS[ctx_idx]
        # Propagate the result instead of silently discarding it.
        return func(ctx, *args, **kwargs)

    return wrapper
|
|
|
|
|
|
# Module-level entry points: each one dispatches to the corresponding
# UBatchContext method on the calling thread's context, and does nothing
# when no micro-batch context is registered.

# Receive-hook handling and plain yielding.
dbo_maybe_run_recv_hook = _register_ubatch_function(
    UBatchContext.maybe_run_recv_hook
)
dbo_yield = _register_ubatch_function(
    UBatchContext.yield_
)

# Yield to the peer and come back on the other stream.
dbo_yield_and_switch_from_compute_to_comm = _register_ubatch_function(
    UBatchContext.yield_and_switch_from_compute_to_comm
)
dbo_yield_and_switch_from_comm_to_compute = _register_ubatch_function(
    UBatchContext.yield_and_switch_from_comm_to_compute
)

# Stream switches without yielding (unsynchronized, then synchronized).
dbo_switch_to_comm = _register_ubatch_function(
    UBatchContext.switch_to_comm
)
dbo_switch_to_compute = _register_ubatch_function(
    UBatchContext.switch_to_compute
)
dbo_switch_to_comm_sync = _register_ubatch_function(
    UBatchContext.switch_to_comm_sync
)
dbo_switch_to_compute_sync = _register_ubatch_function(
    UBatchContext.switch_to_compute_sync
)
|
|
|
|
|
|
def dbo_register_recv_hook(recv_hook):
    """Store ``recv_hook`` on the *other* micro-batch's context.

    Does nothing when no thread is registered in ``_THREAD_ID_TO_CONTEXT``.
    The peer is selected as ``(own id + 1) % 2``.
    """
    if not _THREAD_ID_TO_CONTEXT:
        return
    own_idx = _THREAD_ID_TO_CONTEXT[threading.get_ident()]
    peer_ctx = _CURRENT_CONTEXTS[(own_idx + 1) % 2]
    peer_ctx.recv_hook = recv_hook
|
|
|
|
|
|
def dbo_get_previous_event(func, *args, **kwargs):
    """Call ``func(*args, **kwargs)`` with the calling thread's micro-batch
    compute stream set as the active CUDA stream and return its result.

    Returns ``None`` without calling ``func`` when no thread is registered in
    ``_THREAD_ID_TO_CONTEXT``.
    """
    if not _THREAD_ID_TO_CONTEXT:
        return None
    ubatch_ctx = _CURRENT_CONTEXTS[_THREAD_ID_TO_CONTEXT[threading.get_ident()]]
    # execute callable on the ubatch compute stream to record/wait events there
    with torch.cuda.stream(ubatch_ctx.compute_stream):
        return func(*args, **kwargs)
|
|
|
|
|
|
def make_ubatch_contexts(
    num_micro_batches: int,
    compute_stream: torch.cuda.Stream,
    comm_stream: torch.cuda.Stream,
    forward_contexts: list[ForwardContext],
    ready_barrier: threading.Barrier,
    schedule: str = "default",
) -> list[UBatchContext]:
    """
    Create a context manager for micro-batching synchronization.

    One ``UBatchContext`` is built per micro-batch. All contexts share the
    same compute/communication streams and ready barrier. Each context waits
    on its own CPU event and signals the next context's event (wrapping
    around), and owns its own pair of CUDA events.

    Args:
        num_micro_batches: Number of micro-batches; must be 2.
        compute_stream: Stream shared by all contexts for compute work.
        comm_stream: Stream shared by all contexts for communication work.
        forward_contexts: One forward context per micro-batch.
        ready_barrier: Barrier every context waits on when entered.
        schedule: Stored on each context as ``schedule``.

    Returns:
        The contexts, ordered by micro-batch id.
    """
    assert num_micro_batches == 2, "only been tested with 2 micro-batches"
    # One forward context is required per micro-batch.
    assert len(forward_contexts) == num_micro_batches

    cpu_events = [threading.Event() for _ in range(num_micro_batches)]
    gpu_comm_done_events = [torch.cuda.Event() for _ in range(num_micro_batches)]
    gpu_compute_done_events = [torch.cuda.Event() for _ in range(num_micro_batches)]

    return [
        UBatchContext(
            id=i,
            compute_stream=compute_stream,
            comm_stream=comm_stream,
            forward_context=forward_contexts[i],
            ready_barrier=ready_barrier,
            cpu_wait_event=cpu_events[i],
            # Signalling wakes the next micro-batch in round-robin order.
            cpu_signal_event=cpu_events[(i + 1) % num_micro_batches],
            gpu_comm_done_event=gpu_comm_done_events[i],
            gpu_compute_done_event=gpu_compute_done_events[i],
            schedule=schedule,
        )
        for i in range(num_micro_batches)
    ]
|