[Kernels] Overlap shared experts with combine instead of dispatch (#24254)

Signed-off-by: Bill Nell <bnell@redhat.com>
bnellnm 2025-09-18 00:10:21 -04:00 committed by GitHub
parent 027d37df38
commit dc2979c585
4 changed files with 203 additions and 36 deletions

@@ -240,7 +240,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
quant_config)
return receiver()
def finalize(
def _finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
@@ -248,7 +248,8 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> None:
do_async: bool,
) -> Optional[Callable]:
assert self.handle is not None
@@ -271,7 +272,46 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_weights=None,
config=self._get_combine_config(),
previous_event=None,
async_finish=False,
async_finish=do_async,
allocate_on_comm_stream=False)
# Respect inplace outputs.
output.copy_(combined_x, non_blocking=True)
if do_async:
def _receiver():
event.current_stream_wait()
# Respect inplace outputs.
output.copy_(combined_x, non_blocking=True)
return lambda: _receiver()
else:
# Respect inplace outputs.
output.copy_(combined_x, non_blocking=True)
return None
def finalize_async(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> Callable:
receiver = self._finalize(output, fused_expert_output, topk_weights,
topk_ids, apply_router_weight_on_input,
weight_and_reduce_impl, True)
assert receiver is not None
return receiver
def finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> None:
self._finalize(output, fused_expert_output, topk_weights, topk_ids,
apply_router_weight_on_input, weight_and_reduce_impl,
False)
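The high-throughput path now routes both entry points through a private _finalize that launches the combine with async_finish=do_async and, in the async case, defers the event wait and the in-place copy to a receiver callable. Below is a minimal sketch of that pattern, assuming a DeepEP-style buffer whose combine returns a completion event; the buffer argument and the _finalize_sketch name are illustrative stand-ins, not the vLLM code. The real finalize() then simply calls _finalize(..., do_async=False), while finalize_async() returns the receiver.

from typing import Callable, Optional

import torch


def _finalize_sketch(buffer, output: torch.Tensor,
                     fused_expert_output: torch.Tensor,
                     do_async: bool) -> Optional[Callable]:
    # Launch the combine; with async_finish=True the call returns right away
    # together with an event marking completion of the all-to-all.
    combined_x, event = buffer.combine(fused_expert_output,
                                       async_finish=do_async)

    if do_async:
        def _receiver() -> None:
            # Wait for the combine on the current stream, then honor the
            # in-place output contract.
            event.current_stream_wait()
            output.copy_(combined_x, non_blocking=True)

        return _receiver

    # Synchronous path: the combine has already finished.
    output.copy_(combined_x, non_blocking=True)
    return None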

@@ -12,8 +12,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input, normalize_batched_scales_shape)
from vllm.v1.worker.ubatching import (dbo_current_ubatch_id, dbo_enabled,
dbo_maybe_run_recv_hook,
dbo_register_recv_hook, dbo_yield)
dbo_maybe_run_recv_hook)
# DeepEP kernels quantize dispatch inputs in 128 element chunks.
DEEPEP_QUANT_BLOCK_SIZE = 128
@@ -198,7 +197,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
hook()
return receiver()
def finalize(
def _finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
@@ -206,13 +205,14 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> None:
do_async: bool,
) -> Optional[Callable]:
assert isinstance(
weight_and_reduce_impl, TopKWeightAndReduceDelegate
), ("Weight application and reduction happens in the combine kernel.")
a2a_idx = dbo_current_ubatch_id()
do_recv_hook = dbo_enabled()
do_recv_hook = dbo_enabled() or do_async
handle = self.handles[a2a_idx]
assert handle is not None
@@ -232,6 +232,45 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
zero_copy=False,
return_recv_hook=do_recv_hook,
out=output)
if recv_hook is not None:
dbo_register_recv_hook(recv_hook)
dbo_yield()
return recv_hook
def finalize_async(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> Callable:
recv_hook = self._finalize(
output,
fused_expert_output,
topk_weights,
topk_ids,
apply_router_weight_on_input,
weight_and_reduce_impl,
do_async=True,
)
assert recv_hook is not None
return recv_hook
def finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> None:
self._finalize(
output,
fused_expert_output,
topk_weights,
topk_ids,
apply_router_weight_on_input,
weight_and_reduce_impl,
do_async=False,
)
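In the low-latency path the combine kernel itself hands back a receive hook. The change above stops registering that hook with the DBO machinery inside this class (the dbo_register_recv_hook/dbo_yield calls and imports are removed) and instead returns it, leaving FusedMoEModularKernel to decide whether to register it or invoke it. A rough sketch of the _finalize core follows; combine_fn is an illustrative stand-in for the DeepEP low-latency combine call shown in the diff, and its (tokens, event, hook) return shape is an assumption.

from typing import Callable, Optional

import torch


def _finalize_ll_sketch(combine_fn, output: torch.Tensor,
                        fused_expert_output: torch.Tensor,
                        topk_ids: torch.Tensor, topk_weights: torch.Tensor,
                        dbo_on: bool, do_async: bool) -> Optional[Callable]:
    # Request a receive hook whenever completion should be deferred, either
    # because DBO is active or because finalize_async() was called.
    want_hook = dbo_on or do_async

    # Assumed (combined_x, event, hook) return, mirroring the diff's kwargs.
    _, _, recv_hook = combine_fn(fused_expert_output,
                                 topk_ids,
                                 topk_weights,
                                 zero_copy=False,
                                 return_recv_hook=want_hook,
                                 out=output)

    # None on the blocking path; otherwise the caller invokes it (or hands it
    # to the DBO scheduler) once it is ready to wait for the combine.
    return recv_hook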

@@ -209,7 +209,8 @@ class FusedMoEPrepareAndFinalize(ABC):
def supports_async(self) -> bool:
"""
Indicates whether or not this class implements prepare_async.
Indicates whether or not this class implements prepare_async and
finalize_async.
"""
return False
@@ -275,6 +276,42 @@ class FusedMoEPrepareAndFinalize(ABC):
"""
raise NotImplementedError
def finalize_async(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: TopKWeightAndReduce,
) -> Callable:
"""
Perform any combine plus apply weights and perform a reduction on the
fused experts output but do not wait for results from other workers.
- output: The output tensor, written in place. Must be (M, K) shape.
- fused_expert_output: The unweighted, unreduced output of the fused
experts, it will have (M, topk, K) shape.
- topk_weights: The weights to be applied to the fused_experts_output.
- topk_ids: The topk_ids.
- apply_router_weight_on_input: When False, apply the weights to
fused_expert_output.
- weight_and_reduce_impl: An optional TopKWeightAndReduce
implementation.
Returns a callback that when invoked waits for results from other
workers and has the same return signature as `finalize`, e.g.
receiver = obj.finalize_async(output, ...)
... output not valid yet ...
receiver()
... output valid here ...
is equivalent to:
obj.finalize(output, ...)
"""
raise NotImplementedError
@property
@abstractmethod
def activation_format(self) -> FusedMoEActivationFormat:
@@ -814,23 +851,20 @@ class FusedMoEModularKernel(torch.nn.Module):
"""
a1 = hidden_states
output = a1 if inplace else torch.zeros_like(a1)
if inplace and self.shared_experts is None:
output = a1
else:
output = torch.zeros_like(a1)
local_num_experts = w1.size(0)
if global_num_experts == -1:
global_num_experts = local_num_experts
shared_output: torch.Tensor
if not self.prepare_finalize.supports_async():
# We shouldn't be running an a2a kernel that doesn't
# support async prepare/finalize
assert not dbo_enabled()
# Run shared experts serially with dispatch.
if self.shared_experts is not None:
shared_output = self.shared_experts(a1)
(a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
_expert_topk_weights) = self.prepare_finalize.prepare(
a1,
@@ -854,9 +888,6 @@ class FusedMoEModularKernel(torch.nn.Module):
self.fused_experts.quant_config,
)
if self.shared_experts is not None:
shared_output = self.shared_experts(a1)
# If DBO is being used, register the hook with the ubatch context
# and call it in dbo_maybe_run_recv_hook instead of passing it to
# the receiver.
@@ -900,16 +931,42 @@ class FusedMoEModularKernel(torch.nn.Module):
apply_router_weight_on_input=apply_router_weight_on_input,
)
self.prepare_finalize.finalize(
output,
fused_out,
topk_weights,
topk_ids,
apply_router_weight_on_input,
self.fused_experts.finalize_weight_and_reduce_impl(),
)
shared_output: Optional[torch.Tensor] = None
if not self.prepare_finalize.supports_async():
assert not dbo_enabled()
self.prepare_finalize.finalize(
output,
fused_out,
topk_weights,
topk_ids,
apply_router_weight_on_input,
self.fused_experts.finalize_weight_and_reduce_impl(),
)
if self.shared_experts is not None:
shared_output = self.shared_experts(a1)
else:
recv_hook = self.prepare_finalize.finalize_async(
output,
fused_out,
topk_weights,
topk_ids,
apply_router_weight_on_input,
self.fused_experts.finalize_weight_and_reduce_impl(),
)
if self.shared_experts is not None:
shared_output = self.shared_experts(a1)
assert recv_hook is not None
dbo_register_recv_hook(recv_hook)
dbo_yield()
if not dbo_enabled():
recv_hook()
if self.shared_experts is None:
return output
else:
assert shared_output is not None
return shared_output, output
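This is the heart of the change: the shared experts used to run next to dispatch (the block removed earlier in this hunk) and now run while the combine is in flight. Below is a condensed sketch of the new control flow in FusedMoEModularKernel.forward, assuming the DBO helpers are no-ops when dual-batch overlap is disabled (as implied by the `if not dbo_enabled(): recv_hook()` fallback); the wrapper function and its parameter list are illustrative.

from vllm.v1.worker.ubatching import (dbo_enabled, dbo_register_recv_hook,
                                      dbo_yield)


def _finalize_with_overlap(prepare_finalize, shared_experts, a1, output,
                           fused_out, topk_weights, topk_ids,
                           apply_router_weight_on_input, weight_and_reduce):
    shared_output = None

    if not prepare_finalize.supports_async():
        # Blocking path: combine first, shared experts afterwards.
        assert not dbo_enabled()
        prepare_finalize.finalize(output, fused_out, topk_weights, topk_ids,
                                  apply_router_weight_on_input,
                                  weight_and_reduce)
        if shared_experts is not None:
            shared_output = shared_experts(a1)
    else:
        # Kick off the combine without waiting for it ...
        recv_hook = prepare_finalize.finalize_async(
            output, fused_out, topk_weights, topk_ids,
            apply_router_weight_on_input, weight_and_reduce)

        # ... and overlap the shared experts with the in-flight all-to-all.
        if shared_experts is not None:
            shared_output = shared_experts(a1)

        # Under DBO the hook is handed to the ubatch scheduler; otherwise
        # wait here so that `output` is valid before returning.
        dbo_register_recv_hook(recv_hook)
        dbo_yield()
        if not dbo_enabled():
            recv_hook()

    return output if shared_experts is None else (shared_output, output)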

@@ -272,7 +272,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
hook()
return receiver()
def finalize(
def finalize_async(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
@@ -280,7 +280,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> None:
) -> Callable:
assert isinstance(
weight_and_reduce_impl, TopKWeightAndReduceDelegate
), ("Weight application and reduction happens in the combine kernel.")
@@ -303,8 +303,39 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
if apply_router_weight_on_input:
topk_weights = torch.ones_like(topk_weights)
topk_ids_u32 = topk_ids.view(dtype=torch.uint32)
self.a2a.combine(out_tokens=output,
indices=topk_ids.view(dtype=torch.uint32),
indices=topk_ids_u32,
weights=topk_weights,
expert_y=fused_expert_output,
bound_m=bound_m)
bound_m=bound_m,
do_send=True,
do_recv=False)
return lambda: self.a2a.combine(out_tokens=output,
indices=topk_ids_u32,
weights=topk_weights,
expert_y=fused_expert_output,
bound_m=bound_m,
do_send=False,
do_recv=True)
def finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> None:
receiver = self.finalize_async(
output,
fused_expert_output,
topk_weights,
topk_ids,
apply_router_weight_on_input,
weight_and_reduce_impl,
)
receiver()
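For pplx, the combine kernel already supports a split-phase protocol, so finalize_async only issues the send half and returns a callable that performs the receive half; the blocking finalize is just finalize_async followed by an immediate receiver() call, keeping a single code path. A compact sketch of that split, with a2a standing in for the pplx all-to-all handle and the wrapper itself being illustrative:

from functools import partial
from typing import Callable

import torch


def finalize_async_sketch(a2a, output: torch.Tensor,
                          fused_expert_output: torch.Tensor,
                          topk_ids: torch.Tensor, topk_weights: torch.Tensor,
                          bound_m) -> Callable:
    combine = partial(a2a.combine,
                      out_tokens=output,
                      indices=topk_ids.view(dtype=torch.uint32),
                      weights=topk_weights,
                      expert_y=fused_expert_output,
                      bound_m=bound_m)

    # Send phase runs immediately; weight application and reduction happen
    # inside the combine kernel itself.
    combine(do_send=True, do_recv=False)

    # The matching receive is deferred: `output` only becomes valid once the
    # returned callable has been invoked.
    return partial(combine, do_send=False, do_recv=True)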