mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-07 01:17:03 +08:00
wip seperate comm and compute threads
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
This commit is contained in:
parent
2f3920638c
commit
7b31e8a8ff
@ -7,7 +7,10 @@ import torch
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.utils import (
|
||||
moe_kernel_quantize_input)
|
||||
from vllm.v1.worker.ubatching import get_current_ubatch_context, yield_impl
|
||||
from vllm.v1.worker.ubatching import (
|
||||
get_current_ubatch_context, yield_and_switch_from_compute_to_comm_impl,
|
||||
yield_and_switch_from_comm_to_compute_impl
|
||||
)
|
||||
|
||||
|
||||
# Note use: layer.get_all_to_all() to get an AllToAll instance
|
||||
@ -119,14 +122,10 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
do_recv=not send,
|
||||
)
|
||||
|
||||
#print("Dispatch pre-wait")
|
||||
if (ubatch_ctx := get_current_ubatch_context()) is not None:
|
||||
ubatch_ctx.gpu_stream_wait()
|
||||
#print("Dispatch launched")
|
||||
yield_and_switch_from_compute_to_comm_impl(schedule="default")
|
||||
dispatch(True) # Send
|
||||
yield_impl(gpu_wait=False)
|
||||
dispatch(False) # Recv
|
||||
#print("Finished dispatch")
|
||||
yield_and_switch_from_comm_to_compute_impl(schedule="default")
|
||||
|
||||
return expert_x, expert_x_scale, expert_num_tokens
|
||||
|
||||
@ -164,11 +163,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
do_recv=not send,
|
||||
)
|
||||
|
||||
#print("Combine pre-wait")
|
||||
if (ubatch_ctx := get_current_ubatch_context()) is not None:
|
||||
ubatch_ctx.gpu_stream_wait()
|
||||
yield_and_switch_from_compute_to_comm_impl(schedule="default")
|
||||
combine(True)
|
||||
#print("Combine launched")
|
||||
yield_impl(gpu_wait=False)
|
||||
combine(False)
|
||||
#print("Finished combine")
|
||||
yield_and_switch_from_comm_to_compute_impl(schedule="default")
|
||||
@ -59,7 +59,7 @@ from vllm.v1.utils import bind_kv_cache
|
||||
from vllm.v1.worker.block_table import BlockTable
|
||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
|
||||
from vllm.v1.worker.ubatching import make_ubatch_context_chain, UBatchContext
|
||||
from vllm.v1.worker.ubatching import make_ubatch_contexts, UBatchContext
|
||||
|
||||
from .utils import (gather_mm_placeholders, sanity_check_mm_encoder_outputs,
|
||||
scatter_mm_placeholders)
|
||||
@ -1342,19 +1342,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
# attn_metadata[i] if attn_metadata is not None else None,
|
||||
# self.vllm_config, num_tokens=(tokens_slice.stop - tokens_slice.start)
|
||||
# ) for i, (_, tokens_slice) in enumerate(ubatch_slices)]
|
||||
ubatch_ctxs, start_hook = make_ubatch_context_chain(
|
||||
ubatch_ctxs, start_hook = make_ubatch_contexts(
|
||||
len(ubatch_slices),
|
||||
#fwd_ctxs=ubatch_fwd_ctxs,
|
||||
streams=self.ubatch_streams, #stream=root_stream, # Only works currently if everything is run on the same stream
|
||||
compute_stream=root_stream,
|
||||
device=self.device)
|
||||
setup_done = threading.Event()
|
||||
ubatch_threads = []
|
||||
|
||||
# Initialize Events? not sure if this helps
|
||||
for ubatch_ctx in ubatch_ctxs:
|
||||
ubatch_ctx.gpu_wait_event.record(ubatch_ctx.stream)
|
||||
ubatch_ctx.stream.wait_event(ubatch_ctx.gpu_wait_event)
|
||||
|
||||
# Ubatches will manually manage the forward context, so we override
|
||||
# it to None here so we can have it restored correctly later
|
||||
with override_forward_context(None):
|
||||
@ -1388,9 +1382,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
for thread in ubatch_threads:
|
||||
thread.join()
|
||||
|
||||
for ubatch_ctx in ubatch_ctxs:
|
||||
root_stream.wait_stream(ubatch_ctx.stream)
|
||||
|
||||
torch.cuda.set_stream(root_stream)
|
||||
return torch.cat(results, dim=0)
|
||||
|
||||
@ -17,36 +17,31 @@ class UBatchContext:
|
||||
"""
|
||||
def __init__(self,
|
||||
id: int,
|
||||
stream: torch.cuda.Stream,
|
||||
comm_stream: torch.cuda.Stream,
|
||||
compute_stream: torch.cuda.Stream,
|
||||
#fwd_ctx: forward_context.ForwardContext,
|
||||
cpu_wait_event: threading.Event,
|
||||
cpu_signal_event: threading.Event,
|
||||
gpu_wait_event: torch.cuda.Event,
|
||||
gpu_signal_event: torch.cuda.Event,
|
||||
gpu_wait_on_launch: bool = False,
|
||||
schedule="default"):
|
||||
gpu_comm_done_event: torch.cuda.Event,
|
||||
gpu_compute_done_event: torch.cuda.Event,
|
||||
schedule: str = "default"):
|
||||
self.id = id
|
||||
self.stream = stream
|
||||
self.comm_stream = comm_stream
|
||||
self.compute_stream = compute_stream
|
||||
self.original_stream = current_stream()
|
||||
self.forward_context = None #fwd_ctx
|
||||
self.cpu_wait_event = cpu_wait_event
|
||||
self.cpu_signal_event = cpu_signal_event
|
||||
self.gpu_wait_event = gpu_wait_event
|
||||
self.gpu_signal_event = gpu_signal_event
|
||||
self.gpu_comm_done_event = gpu_comm_done_event
|
||||
self.gpu_compute_done_event = gpu_compute_done_event
|
||||
self.schedule = schedule
|
||||
self.done_event = torch.cuda.Event()
|
||||
self.gpu_wait_on_launch = gpu_wait_on_launch
|
||||
|
||||
def __enter__(self):
|
||||
global _CURRENT_CONTEXT
|
||||
_CURRENT_CONTEXT[threading.get_ident()] = self
|
||||
self._cpu_wait()
|
||||
# start_event = torch.cuda.Event()
|
||||
# self.original_stream.record_event(start_event)
|
||||
# self.stream.wait_event(start_event)
|
||||
print("Starting ubatch %d" % self.id)
|
||||
# if self.gpu_wait_on_launch:
|
||||
self.gpu_stream_wait()
|
||||
# Assume we start on the compute stream
|
||||
assert current_stream() == self.compute_stream, \
|
||||
"Expected to start on the compute stream, but found %s" % current_stream()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
@ -54,9 +49,6 @@ class UBatchContext:
|
||||
_CURRENT_CONTEXT[threading.get_ident()] = None
|
||||
torch.cuda.set_stream(self.original_stream)
|
||||
print("Finishing ubatch %d" % self.id)
|
||||
self._signal()
|
||||
self._signal()
|
||||
self._signal()
|
||||
return False
|
||||
|
||||
def _restore_context(self):
|
||||
@ -65,66 +57,37 @@ class UBatchContext:
|
||||
torch.cuda.set_stream(self.stream)
|
||||
forward_context._forward_context = self.forward_context
|
||||
|
||||
# Seperate GPU wait so we can do
|
||||
# ubatch0
|
||||
# 1) work
|
||||
# 2) dispatch
|
||||
# 3) yield
|
||||
# ubatch1
|
||||
# 1) work
|
||||
# 2) gpu wait
|
||||
# 3) dispatch
|
||||
# 4) yield
|
||||
#
|
||||
# This way we can have the CPU schedule ubatch1-dispatch while ubatch0
|
||||
# before yielding back to ubatch1 but ensure we wont start the dispatch
|
||||
# until ubatch0-dispatch is done avoiding overlapping dispatches that
|
||||
# might share underlying buffers
|
||||
#
|
||||
# NOTE(lucas): I think we need to do:
|
||||
# ubatch0
|
||||
# - work
|
||||
# - dispatch send
|
||||
# - yield
|
||||
# ubatch1
|
||||
# - work
|
||||
# - yield
|
||||
# ubatch0
|
||||
# - dispatch recv
|
||||
# - gpu record, event0
|
||||
# - yield
|
||||
# ubatch1
|
||||
# - gpu wait, event0
|
||||
# - dispatch send
|
||||
# - yield
|
||||
# ubatch0
|
||||
# - work
|
||||
# .....
|
||||
# To ensure we record the cuda event before waiting
|
||||
def gpu_stream_wait(self):
|
||||
print("Waiting ubatch %d on %s in stream %s" % (self.id, self.gpu_wait_event, self.stream))
|
||||
self.stream.wait_event(self.gpu_wait_event)
|
||||
def _signal_comm_done(self):
|
||||
self.gpu_comm_done_event.record(self.comm_stream)
|
||||
|
||||
def _signal_compute_done(self):
|
||||
self.gpu_compute_done_event.record(self.compute_stream)
|
||||
|
||||
def _yield(self, gpu_wait: bool = True):
|
||||
#print("Yielding ubatch %d" % self.id)
|
||||
self._signal()
|
||||
self._cpu_wait()
|
||||
#print("Resuming ubatch %d" % self.id)
|
||||
if gpu_wait:
|
||||
self.gpu_stream_wait()
|
||||
def _wait_compute_done(self):
|
||||
self.comm_stream.wait_event(self.gpu_compute_done_event)
|
||||
|
||||
def _signal(self):
|
||||
# Wait for the next batch to signal back
|
||||
print(f"signaling ubatch {self.id} to {self.gpu_signal_event} on {self.stream}")
|
||||
self.gpu_signal_event.record(self.stream)
|
||||
# Signal that this batch reached the barrier
|
||||
def _wait_comm_done(self):
|
||||
self.compute_stream.wait_event(self.gpu_comm_done_event)
|
||||
|
||||
def _cpu_yield(self, gpu_wait: bool = True):
|
||||
self.cpu_signal_event.set()
|
||||
|
||||
def _cpu_wait(self):
|
||||
self.cpu_wait_event.wait()
|
||||
self.cpu_wait_event.clear()
|
||||
self._restore_context()
|
||||
|
||||
def yield_and_switch_from_compute_to_comm(self):
|
||||
self._signal_compute_done()
|
||||
self._cpu_yield()
|
||||
torch.cuda.set_stream(self.comm_stream)
|
||||
self._wait_compute_done()
|
||||
|
||||
def yield_and_switch_from_comm_to_compute(self):
|
||||
self._signal_comm_done()
|
||||
self._cpu_yield()
|
||||
torch.cuda.set_stream(self.compute_stream)
|
||||
self._wait_comm_done()
|
||||
|
||||
|
||||
_CURRENT_CONTEXT: dict = {}
|
||||
|
||||
def get_current_ubatch_context() -> Optional[UBatchContext]:
|
||||
@ -134,23 +97,36 @@ def get_current_ubatch_context() -> Optional[UBatchContext]:
|
||||
"""
|
||||
return _CURRENT_CONTEXT.get(threading.get_ident(), None)
|
||||
|
||||
def yield_impl(schedule="default", gpu_wait: bool = True):
|
||||
def yield_and_switch_from_compute_to_comm_impl(schedule="default"):
|
||||
# Perform the barrier if a context exists for this thread
|
||||
ctx = get_current_ubatch_context()
|
||||
#print("you are in yield_impl", ctx)
|
||||
if ctx is not None:
|
||||
ctx._yield(gpu_wait=gpu_wait)
|
||||
ctx.yield_and_switch_from_compute_to_comm()
|
||||
|
||||
def yield_and_switch_from_comm_to_compute_impl(schedule="default"):
|
||||
# Perform the barrier if a context exists for this thread
|
||||
ctx = get_current_ubatch_context()
|
||||
if ctx is not None:
|
||||
ctx.yield_and_switch_from_comm_to_compute()
|
||||
|
||||
# 2) Register kernel for CUDA, mark as mutating to prevent the compiler from
|
||||
# optimizing it away (TODO: see if this is actually needed)
|
||||
@custom_op("vllm::yield_", mutates_args=("x",))
|
||||
def yield_(x: torch.Tensor, schedule: str="default") -> None:
|
||||
yield_impl(schedule)
|
||||
@custom_op("vllm::yield_and_switch_from_compute_to_comm", mutates_args=("x",))
|
||||
def yield_and_switch_from_compute_to_comm(x: torch.Tensor, schedule: str="default") -> None:
|
||||
yield_and_switch_from_compute_to_comm_impl(schedule)
|
||||
|
||||
# 3) Fake implementation for shape prop and FX tracing
|
||||
@yield_.register_fake
|
||||
def yield_(x: torch.Tensor, schedule: str="default") -> None:
|
||||
@yield_and_switch_from_compute_to_comm.register_fake
|
||||
def yield_and_switch_from_compute_to_comm(x: torch.Tensor, schedule: str="default") -> None:
|
||||
pass
|
||||
|
||||
@custom_op("vllm::yield_and_switch_from_comm_to_compute", mutates_args=("x",))
|
||||
def yield_and_switch_from_comm_to_compute(x: torch.Tensor, schedule: str="default") -> None:
|
||||
yield_and_switch_from_comm_to_compute_impl(schedule)
|
||||
|
||||
@yield_and_switch_from_comm_to_compute.register_fake
|
||||
def yield_and_switch_from_comm_to_compute(x: torch.Tensor, schedule: str="default") -> None:
|
||||
pass
|
||||
|
||||
def dump_ubatching_state():
|
||||
@ -169,16 +145,13 @@ def dump_ubatching_state():
|
||||
f" CPU Signal Event: {ctx.cpu_signal_event}\n"
|
||||
f" GPU Signal Event: {ctx.gpu_signal_event} ({ctx.gpu_signal_event.query()})\n")
|
||||
|
||||
|
||||
|
||||
"""
|
||||
|
||||
"""
|
||||
def make_ubatch_context_chain(
|
||||
def make_ubatch_contexts(
|
||||
num_micro_batches: int,
|
||||
#fwd_ctxs: forward_context.ForwardContext,
|
||||
streams: Optional[list[torch.Stream]] = None,
|
||||
device: Optional[torch.device] = None
|
||||
compute_stream: torch.cuda.Stream,
|
||||
device: Optional[torch.device] = None,
|
||||
schedule: str = "default",
|
||||
) -> list[UBatchContext]:
|
||||
assert num_micro_batches == 2, "only been tested with 2 micro-batches"
|
||||
|
||||
@ -186,26 +159,26 @@ def make_ubatch_context_chain(
|
||||
Create a context manager for micro-batching synchronization.
|
||||
"""
|
||||
cpu_events = [threading.Event() for _ in range(num_micro_batches)]
|
||||
gpu_events = [torch.cuda.Event(blocking=True) for _ in range(num_micro_batches)]
|
||||
gpu_comm_done_events = [
|
||||
torch.cuda.Event() for _ in range(num_micro_batches)
|
||||
]
|
||||
gpu_compute_done_events = [
|
||||
torch.cuda.Event() for _ in range(num_micro_batches)
|
||||
]
|
||||
device = device or torch.cuda.current_device()
|
||||
|
||||
comm_stream = torch.cuda.Stream(device)
|
||||
|
||||
ctxs = []
|
||||
for i in range(num_micro_batches):
|
||||
stream = (streams[i] if streams else None) or torch.cuda.Stream(device)
|
||||
ctx = UBatchContext(id=i,
|
||||
stream=stream,
|
||||
#fwd_ctx=fwd_ctxs[i],
|
||||
compute_stream=compute_stream,
|
||||
comm_stream=comm_stream,
|
||||
cpu_wait_event=cpu_events[i],
|
||||
cpu_signal_event=cpu_events[(i + 1) % num_micro_batches],
|
||||
gpu_wait_event=gpu_events[i],
|
||||
gpu_signal_event=gpu_events[(i + 1) % num_micro_batches],
|
||||
gpu_wait_on_launch=(i > 0),
|
||||
gpu_comm_done_event=gpu_comm_done_events[i],
|
||||
gpu_compute_done_event=gpu_compute_done_events[i],
|
||||
schedule=schedule
|
||||
)
|
||||
ctxs.append(ctx)
|
||||
|
||||
def start_hook(from_stream: torch.cuda.Stream):
|
||||
ctxs[0].gpu_wait_event.record(from_stream)
|
||||
print('singal to ubatch %d event %s from stream %s' % (ctxs[0].id, ctxs[0].gpu_wait_event, from_stream))
|
||||
ctxs[0].cpu_wait_event.set()
|
||||
|
||||
return ctxs, start_hook
|
||||
return ctxs,
|
||||
Loading…
x
Reference in New Issue
Block a user