mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-06 15:57:05 +08:00
Add support for splitting the DeepEP low-latency (LL) dispatch/combine kernels
Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
parent
1c41175b2a
commit
582d301f98
@ -1937,6 +1937,9 @@ class ParallelConfig:
|
||||
request is greater than this threshold, microbatching will be used.
|
||||
Otherwise, the request will be processed in a single batch."""
|
||||
|
||||
enable_async_comms: bool = False
|
||||
"""enable async comms"""
|
||||
|
||||
ray_workers_use_nsight: bool = False
|
||||
"""Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler."""
|
||||
|
||||
|
||||
@ -304,6 +304,7 @@ class EngineArgs:
|
||||
data_parallel_backend: str = ParallelConfig.data_parallel_backend
|
||||
enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
|
||||
enable_microbatching: bool = ParallelConfig.enable_microbatching
|
||||
enable_async_comms: bool = ParallelConfig.enable_async_comms
|
||||
enable_eplb: bool = ParallelConfig.enable_eplb
|
||||
num_redundant_experts: int = ParallelConfig.num_redundant_experts
|
||||
eplb_window_size: int = ParallelConfig.eplb_window_size
|
||||
@ -640,6 +641,8 @@ class EngineArgs:
|
||||
**parallel_kwargs["enable_expert_parallel"])
|
||||
parallel_group.add_argument("--enable-microbatching",
|
||||
**parallel_kwargs["enable_microbatching"])
|
||||
parallel_group.add_argument("--enable-async-comms",
|
||||
**parallel_kwargs["enable_async_comms"])
|
||||
parallel_group.add_argument("--enable-eplb",
|
||||
**parallel_kwargs["enable_eplb"])
|
||||
parallel_group.add_argument("--num-redundant-experts",
|
||||
@ -1166,6 +1169,7 @@ class EngineArgs:
|
||||
data_parallel_hybrid_lb=self.data_parallel_hybrid_lb,
|
||||
enable_expert_parallel=self.enable_expert_parallel,
|
||||
enable_microbatching=self.enable_microbatching,
|
||||
enable_async_comms=self.enable_async_comms,
|
||||
enable_eplb=self.enable_eplb,
|
||||
num_redundant_experts=self.num_redundant_experts,
|
||||
eplb_window_size=self.eplb_window_size,
|
||||
|
||||
@ -127,6 +127,8 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
hidden_size = a1.size(1)
|
||||
ubatch_ctx = get_current_ubatch_context()
|
||||
a2a_idx = ubatch_ctx.id if ubatch_ctx is not None else 0
|
||||
do_recv_hook = True if ubatch_ctx is not None and \
|
||||
ubatch_ctx.enable_async_comms else False
|
||||
|
||||
if self.use_fp8_dispatch:
|
||||
assert hidden_size % 128 == 0, \
|
||||
@ -147,16 +149,16 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
|
||||
# Dispatch
|
||||
yield_and_switch_from_compute_to_comm(schedule="default")
|
||||
expert_x, expert_num_tokens, handle, _, _= \
|
||||
expert_x, expert_num_tokens, handle, _, recv_hook= \
|
||||
self.buffers[a2a_idx].low_latency_dispatch(a1,
|
||||
topk_ids,
|
||||
self.max_tokens_per_rank,
|
||||
num_experts,
|
||||
use_fp8=self.use_fp8_dispatch,
|
||||
async_finish=False,
|
||||
return_recv_hook=False)
|
||||
return_recv_hook=do_recv_hook)
|
||||
self.handles[a2a_idx] = handle
|
||||
yield_and_switch_from_comm_to_compute(schedule="default")
|
||||
yield_and_switch_from_comm_to_compute(schedule="default", recv_hook=recv_hook)
|
||||
|
||||
expert_x, expert_x_scale = self._do_quant(
|
||||
expert_x, a1_scale, a2_scale, a1.dtype, quant_config.quant_dtype,
|
||||
@ -178,6 +180,8 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
ubatch_ctx = get_current_ubatch_context()
|
||||
a2a_idx = ubatch_ctx.id if ubatch_ctx is not None else 0
|
||||
handle = self.handles[a2a_idx]
|
||||
do_recv_hook = True if ubatch_ctx is not None and \
|
||||
ubatch_ctx.enable_async_comms else False
|
||||
assert handle is not None
|
||||
|
||||
combine_topk_weights = topk_weights
|
||||
@ -187,12 +191,12 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
|
||||
# TODO (varun) : Enable zero copy mode
|
||||
yield_and_switch_from_compute_to_comm(schedule="default")
|
||||
_ = self.buffers[a2a_idx].low_latency_combine(fused_expert_output,
|
||||
_, _, recv_hook = self.buffers[a2a_idx].low_latency_combine(fused_expert_output,
|
||||
topk_ids,
|
||||
combine_topk_weights,
|
||||
handle,
|
||||
async_finish=False,
|
||||
zero_copy=False,
|
||||
return_recv_hook=False,
|
||||
return_recv_hook=do_recv_hook,
|
||||
out=output)
|
||||
yield_and_switch_from_comm_to_compute(schedule="default")
|
||||
yield_and_switch_from_comm_to_compute(schedule="default", recv_hook=recv_hook)
|
||||
|
||||
@ -1656,7 +1656,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
comm_stream=self.comm_stream,
|
||||
compute_stream=compute_stream,
|
||||
forward_contexts=forward_contexts,
|
||||
device=self.device)
|
||||
device=self.device,
|
||||
enable_async_comms=self.parallel_config.enable_async_comms)
|
||||
|
||||
ubatch_metadata: list[UbatchMetadata] = []
|
||||
for i, (_, tokens_slice) in enumerate(ubatch_slices):
|
||||
|
||||
@ -24,6 +24,7 @@ class UBatchContext:
|
||||
cpu_signal_event: threading.Event,
|
||||
gpu_comm_done_event: torch.cuda.Event,
|
||||
gpu_compute_done_event: torch.cuda.Event,
|
||||
enable_async_comms: bool,
|
||||
schedule: str = "default"):
|
||||
self.id = id
|
||||
self.comm_stream = comm_stream
|
||||
@ -34,6 +35,7 @@ class UBatchContext:
|
||||
self.current_stream = compute_stream
|
||||
self.gpu_comm_done_event = gpu_comm_done_event
|
||||
self.gpu_compute_done_event = gpu_compute_done_event
|
||||
self.enable_async_comms = enable_async_comms
|
||||
self.schedule = schedule
|
||||
|
||||
def __enter__(self):
|
||||
@ -106,10 +108,14 @@ class UBatchContext:
|
||||
self.update_stream(self.comm_stream)
|
||||
self._wait_compute_done()
|
||||
|
||||
def yield_and_switch_from_comm_to_compute(self):
|
||||
def yield_and_switch_from_comm_to_compute(self, recv_hook = None):
|
||||
assert current_stream() == self.comm_stream
|
||||
self._signal_comm_done()
|
||||
if recv_hook is None:
|
||||
self._signal_comm_done()
|
||||
self._cpu_yield()
|
||||
if recv_hook is not None:
|
||||
recv_hook()
|
||||
self._signal_comm_done()
|
||||
assert self.current_stream == self.comm_stream
|
||||
self.update_stream(self.compute_stream)
|
||||
self._wait_comm_done()
|
||||
@ -133,11 +139,11 @@ def yield_and_switch_from_compute_to_comm(schedule="default"):
|
||||
ctx.yield_and_switch_from_compute_to_comm()
|
||||
|
||||
|
||||
def yield_and_switch_from_comm_to_compute(schedule="default"):
|
||||
def yield_and_switch_from_comm_to_compute(schedule="default", recv_hook = None):
|
||||
# Perform the barrier if a context exists for this thread
|
||||
ctx = get_current_ubatch_context()
|
||||
if ctx is not None and ctx.schedule == schedule:
|
||||
ctx.yield_and_switch_from_comm_to_compute()
|
||||
ctx.yield_and_switch_from_comm_to_compute(recv_hook=recv_hook)
|
||||
|
||||
|
||||
def make_ubatch_contexts(
|
||||
@ -146,6 +152,7 @@ def make_ubatch_contexts(
|
||||
comm_stream: torch.cuda.Stream,
|
||||
forward_contexts: list[ForwardContext],
|
||||
device: Optional[torch.device] = None,
|
||||
enable_async_comms: bool = False,
|
||||
schedule: str = "default",
|
||||
) -> list[UBatchContext]:
|
||||
assert num_micro_batches == 2, "only been tested with 2 micro-batches"
|
||||
@ -175,6 +182,7 @@ def make_ubatch_contexts(
|
||||
num_micro_batches],
|
||||
gpu_comm_done_event=gpu_comm_done_events[i],
|
||||
gpu_compute_done_event=gpu_compute_done_events[i],
|
||||
enable_async_comms=enable_async_comms,
|
||||
schedule=schedule)
|
||||
ctxs.append(ctx)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user