add support for splitting dispatch/combine deepep ll kernels

Signed-off-by: Sage Moore <sage@neuralmagic.com>
Sage Moore 2025-07-30 20:37:48 +00:00
parent 1c41175b2a
commit 582d301f98
5 changed files with 31 additions and 11 deletions

View File

@@ -1937,6 +1937,9 @@ class ParallelConfig:
 request is greater than this threshold, microbatching will be used.
 Otherwise, the request will be processed in a single batch."""
+enable_async_comms: bool = False
+"""enable async comms"""
 ray_workers_use_nsight: bool = False
 """Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler."""

View File

@@ -304,6 +304,7 @@ class EngineArgs:
 data_parallel_backend: str = ParallelConfig.data_parallel_backend
 enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
 enable_microbatching: bool = ParallelConfig.enable_microbatching
+enable_async_comms: bool = ParallelConfig.enable_async_comms
 enable_eplb: bool = ParallelConfig.enable_eplb
 num_redundant_experts: int = ParallelConfig.num_redundant_experts
 eplb_window_size: int = ParallelConfig.eplb_window_size
@@ -640,6 +641,8 @@ class EngineArgs:
 **parallel_kwargs["enable_expert_parallel"])
 parallel_group.add_argument("--enable-microbatching",
 **parallel_kwargs["enable_microbatching"])
+parallel_group.add_argument("--enable-async-comms",
+**parallel_kwargs["enable_async_comms"])
 parallel_group.add_argument("--enable-eplb",
 **parallel_kwargs["enable_eplb"])
 parallel_group.add_argument("--num-redundant-experts",
@@ -1166,6 +1169,7 @@ class EngineArgs:
 data_parallel_hybrid_lb=self.data_parallel_hybrid_lb,
 enable_expert_parallel=self.enable_expert_parallel,
 enable_microbatching=self.enable_microbatching,
+enable_async_comms=self.enable_async_comms,
 enable_eplb=self.enable_eplb,
 num_redundant_experts=self.num_redundant_experts,
 eplb_window_size=self.eplb_window_size,
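Together, these three hunks plumb the flag from the CLI (--enable-async-comms) down into ParallelConfig. A minimal end-to-end sketch (EngineArgs.create_engine_config exists in vLLM; the model name and the exact set of other arguments are assumptions):

    from vllm.engine.arg_utils import EngineArgs

    args = EngineArgs(
        model="deepseek-ai/DeepSeek-V2-Lite",  # hypothetical model choice
        enable_expert_parallel=True,
        enable_microbatching=True,
        enable_async_comms=True,  # equivalent to passing --enable-async-comms
    )
    config = args.create_engine_config()
    assert config.parallel_config.enable_async_comms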

View File

@@ -127,6 +127,8 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
 hidden_size = a1.size(1)
 ubatch_ctx = get_current_ubatch_context()
 a2a_idx = ubatch_ctx.id if ubatch_ctx is not None else 0
+do_recv_hook = True if ubatch_ctx is not None and \
+ubatch_ctx.enable_async_comms else False
 if self.use_fp8_dispatch:
 assert hidden_size % 128 == 0, \
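The new predicate turns the receive-hook path on only when the call runs inside a microbatch context that was created with async comms enabled. As an aside (not part of the commit), the same condition can be written more compactly:

    do_recv_hook = ubatch_ctx is not None and ubatch_ctx.enable_async_comms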
@@ -147,16 +149,16 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
 # Dispatch
 yield_and_switch_from_compute_to_comm(schedule="default")
-expert_x, expert_num_tokens, handle, _, _= \
+expert_x, expert_num_tokens, handle, _, recv_hook= \
 self.buffers[a2a_idx].low_latency_dispatch(a1,
 topk_ids,
 self.max_tokens_per_rank,
 num_experts,
 use_fp8=self.use_fp8_dispatch,
 async_finish=False,
-return_recv_hook=False)
+return_recv_hook=do_recv_hook)
 self.handles[a2a_idx] = handle
-yield_and_switch_from_comm_to_compute(schedule="default")
+yield_and_switch_from_comm_to_compute(schedule="default", recv_hook=recv_hook)
 expert_x, expert_x_scale = self._do_quant(
 expert_x, a1_scale, a2_scale, a1.dtype, quant_config.quant_dtype,
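With return_recv_hook=do_recv_hook, DeepEP's low_latency_dispatch only issues the send and additionally returns a hook; the receive completes when that hook is invoked. A sketch of the pattern in isolation (dispatch_with_overlap and launch_other_microbatch_compute are hypothetical names; buffer is assumed to be a deep_ep.Buffer, and the argument names follow the call above):

    def dispatch_with_overlap(buffer, a1, topk_ids, max_tokens_per_rank,
                              num_experts, launch_other_microbatch_compute):
        # Send phase: returns immediately, together with a hook for the
        # receive side (return values follow the call in the diff above).
        expert_x, expert_num_tokens, handle, _, recv_hook = \
            buffer.low_latency_dispatch(a1, topk_ids, max_tokens_per_rank,
                                        num_experts, use_fp8=True,
                                        async_finish=False,
                                        return_recv_hook=True)
        # Overlap: let the other microbatch run its compute while the
        # dispatched tokens are still in flight.
        launch_other_microbatch_compute()
        # Receive phase: blocks until the tokens have actually arrived.
        recv_hook()
        return expert_x, expert_num_tokens, handle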
@@ -178,6 +180,8 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
 ubatch_ctx = get_current_ubatch_context()
 a2a_idx = ubatch_ctx.id if ubatch_ctx is not None else 0
 handle = self.handles[a2a_idx]
+do_recv_hook = True if ubatch_ctx is not None and \
+ubatch_ctx.enable_async_comms else False
 assert handle is not None
 combine_topk_weights = topk_weights
@@ -187,12 +191,12 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
 # TODO (varun) : Enable zero copy mode
 yield_and_switch_from_compute_to_comm(schedule="default")
-_ = self.buffers[a2a_idx].low_latency_combine(fused_expert_output,
+_, _, recv_hook = self.buffers[a2a_idx].low_latency_combine(fused_expert_output,
 topk_ids,
 combine_topk_weights,
 handle,
 async_finish=False,
 zero_copy=False,
-return_recv_hook=False,
+return_recv_hook=do_recv_hook,
 out=output)
-yield_and_switch_from_comm_to_compute(schedule="default")
+yield_and_switch_from_comm_to_compute(schedule="default", recv_hook=recv_hook)
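The combine side mirrors the dispatch side: with do_recv_hook set, low_latency_combine returns a hook instead of blocking, and the hook is handed to yield_and_switch_from_comm_to_compute so the receive is deferred until the CPU has yielded to the other microbatch. The underlying pattern, sketched as a free-standing helper (combine_with_overlap and launch_other_microbatch_compute are hypothetical; the argument names follow the call above):

    def combine_with_overlap(buffer, fused_expert_output, topk_ids,
                             combine_topk_weights, handle, output,
                             launch_other_microbatch_compute):
        # Send phase of the combine; the hook defers the blocking receive.
        _, _, recv_hook = buffer.low_latency_combine(
            fused_expert_output, topk_ids, combine_topk_weights, handle,
            async_finish=False, zero_copy=False, return_recv_hook=True,
            out=output)
        launch_other_microbatch_compute()
        # `output` only holds the combined tokens once the hook returns.
        recv_hook()
        return output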

View File

@@ -1656,7 +1656,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 comm_stream=self.comm_stream,
 compute_stream=compute_stream,
 forward_contexts=forward_contexts,
-device=self.device)
+device=self.device,
+enable_async_comms=self.parallel_config.enable_async_comms)
 ubatch_metadata: list[UbatchMetadata] = []
 for i, (_, tokens_slice) in enumerate(ubatch_slices):

View File

@@ -24,6 +24,7 @@ class UBatchContext:
 cpu_signal_event: threading.Event,
 gpu_comm_done_event: torch.cuda.Event,
 gpu_compute_done_event: torch.cuda.Event,
+enable_async_comms: bool,
 schedule: str = "default"):
 self.id = id
 self.comm_stream = comm_stream
@@ -34,6 +35,7 @@ class UBatchContext:
 self.current_stream = compute_stream
 self.gpu_comm_done_event = gpu_comm_done_event
 self.gpu_compute_done_event = gpu_compute_done_event
+self.enable_async_comms = enable_async_comms
 self.schedule = schedule
 def __enter__(self):
@@ -106,10 +108,14 @@ class UBatchContext:
 self.update_stream(self.comm_stream)
 self._wait_compute_done()
-def yield_and_switch_from_comm_to_compute(self):
+def yield_and_switch_from_comm_to_compute(self, recv_hook = None):
 assert current_stream() == self.comm_stream
-self._signal_comm_done()
+if recv_hook is None:
+self._signal_comm_done()
 self._cpu_yield()
+if recv_hook is not None:
+recv_hook()
+self._signal_comm_done()
 assert self.current_stream == self.comm_stream
 self.update_stream(self.compute_stream)
 self._wait_comm_done()
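Since the diff view flattens indentation, here is how the method reads after this change, reconstructed from the hunk above (helpers such as _signal_comm_done, _cpu_yield, update_stream and _wait_comm_done come from the surrounding context; their bodies are not shown in this diff, and the comments are interpretation):

    def yield_and_switch_from_comm_to_compute(self, recv_hook=None):
        assert current_stream() == self.comm_stream
        if recv_hook is None:
            # Synchronous path: comm has already finished, so signal
            # comm-done before yielding the CPU to the other microbatch.
            self._signal_comm_done()
        self._cpu_yield()
        if recv_hook is not None:
            # Async path: block on the pending receive only after the other
            # microbatch has had a chance to launch work, then signal.
            recv_hook()
            self._signal_comm_done()
        assert self.current_stream == self.comm_stream
        self.update_stream(self.compute_stream)
        self._wait_comm_done()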
@@ -133,11 +139,11 @@ def yield_and_switch_from_compute_to_comm(schedule="default"):
 ctx.yield_and_switch_from_compute_to_comm()
-def yield_and_switch_from_comm_to_compute(schedule="default"):
+def yield_and_switch_from_comm_to_compute(schedule="default", recv_hook = None):
 # Perform the barrier if a context exists for this thread
 ctx = get_current_ubatch_context()
 if ctx is not None and ctx.schedule == schedule:
-ctx.yield_and_switch_from_comm_to_compute()
+ctx.yield_and_switch_from_comm_to_compute(recv_hook=recv_hook)
def make_ubatch_contexts(
@@ -146,6 +152,7 @@ def make_ubatch_contexts(
 comm_stream: torch.cuda.Stream,
 forward_contexts: list[ForwardContext],
 device: Optional[torch.device] = None,
+enable_async_comms: bool = False,
 schedule: str = "default",
 ) -> list[UBatchContext]:
 assert num_micro_batches == 2, "only been tested with 2 micro-batches"
@@ -175,6 +182,7 @@ def make_ubatch_contexts(
 num_micro_batches],
 gpu_comm_done_event=gpu_comm_done_events[i],
 gpu_compute_done_event=gpu_compute_done_events[i],
+enable_async_comms=enable_async_comms,
 schedule=schedule)
 ctxs.append(ctx)
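Because the new keyword defaults to False, existing callers of make_ubatch_contexts keep the synchronous signal-then-yield behaviour; only the GPUModelRunner call site above opts in, based on parallel_config.enable_async_comms. A small sketch of the default path (stream and forward-context setup is elided and assumed to exist):

    ctxs = make_ubatch_contexts(
        num_micro_batches=2,
        compute_stream=compute_stream,
        comm_stream=comm_stream,
        forward_contexts=forward_contexts,
    )
    # Without enable_async_comms=True, every context keeps the old
    # synchronous ordering in yield_and_switch_from_comm_to_compute.
    assert all(not ctx.enable_async_comms for ctx in ctxs)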