[Kernels] Overlap shared experts with combine instead of dispatch (#24254)

Signed-off-by: Bill Nell <bnell@redhat.com>

This commit is contained in:
  parent 027d37df38
  commit dc2979c585

@@ -240,7 +240,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
                 quant_config)
         return receiver()
 
-    def finalize(
+    def _finalize(
         self,
         output: torch.Tensor,
         fused_expert_output: torch.Tensor,
@@ -248,7 +248,8 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
         topk_ids: torch.Tensor,
         apply_router_weight_on_input: bool,
         weight_and_reduce_impl: mk.TopKWeightAndReduce,
-    ) -> None:
+        do_async: bool,
+    ) -> Optional[Callable]:
 
         assert self.handle is not None
 
@@ -271,7 +272,46 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
             topk_weights=None,
             config=self._get_combine_config(),
             previous_event=None,
-            async_finish=False,
+            async_finish=do_async,
             allocate_on_comm_stream=False)
-        # Respect inplace outputs.
-        output.copy_(combined_x, non_blocking=True)
+
+        if do_async:
+
+            def _receiver():
+                event.current_stream_wait()
+                # Respect inplace outputs.
+                output.copy_(combined_x, non_blocking=True)
+
+            return lambda: _receiver()
+        else:
+            # Respect inplace outputs.
+            output.copy_(combined_x, non_blocking=True)
+            return None
+
+    def finalize_async(
+        self,
+        output: torch.Tensor,
+        fused_expert_output: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        apply_router_weight_on_input: bool,
+        weight_and_reduce_impl: mk.TopKWeightAndReduce,
+    ) -> Callable:
+        receiver = self._finalize(output, fused_expert_output, topk_weights,
+                                  topk_ids, apply_router_weight_on_input,
+                                  weight_and_reduce_impl, True)
+        assert receiver is not None
+        return receiver
+
+    def finalize(
+        self,
+        output: torch.Tensor,
+        fused_expert_output: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        apply_router_weight_on_input: bool,
+        weight_and_reduce_impl: mk.TopKWeightAndReduce,
+    ) -> None:
+        self._finalize(output, fused_expert_output, topk_weights, topk_ids,
+                       apply_router_weight_on_input, weight_and_reduce_impl,
+                       False)
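
The high-throughput path above folds the old finalize body into a private _finalize(..., do_async) and adds thin finalize/finalize_async wrappers around it. A minimal sketch of that wrapper pattern, with a toy reduction standing in for DeepEP's combine (illustrative names only, not the vLLM or DeepEP API):

    from typing import Callable, Optional

    import torch


    class Combiner:
        """Toy stand-in for the finalize/_finalize/finalize_async split."""

        def _finalize(self, output: torch.Tensor, fused_out: torch.Tensor,
                      do_async: bool) -> Optional[Callable]:
            # Stand-in for the combine kernel: reduce over the topk dimension.
            combined = fused_out.sum(dim=1)
            if do_async:
                def _receiver():
                    # The real code waits on the combine's event before copying.
                    output.copy_(combined, non_blocking=True)
                return _receiver
            output.copy_(combined, non_blocking=True)
            return None

        def finalize_async(self, output, fused_out) -> Callable:
            receiver = self._finalize(output, fused_out, do_async=True)
            assert receiver is not None
            return receiver

        def finalize(self, output, fused_out) -> None:
            self._finalize(output, fused_out, do_async=False)
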
@@ -12,8 +12,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
 from vllm.model_executor.layers.fused_moe.utils import (
     moe_kernel_quantize_input, normalize_batched_scales_shape)
 from vllm.v1.worker.ubatching import (dbo_current_ubatch_id, dbo_enabled,
-                                      dbo_maybe_run_recv_hook,
-                                      dbo_register_recv_hook, dbo_yield)
+                                      dbo_maybe_run_recv_hook)
 
 # DeepEP kernels quantize dispatch inputs in 128 element chunks.
 DEEPEP_QUANT_BLOCK_SIZE = 128
@@ -198,7 +197,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
                 hook()
         return receiver()
 
-    def finalize(
+    def _finalize(
         self,
         output: torch.Tensor,
         fused_expert_output: torch.Tensor,
@@ -206,13 +205,14 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
         topk_ids: torch.Tensor,
         apply_router_weight_on_input: bool,
         weight_and_reduce_impl: mk.TopKWeightAndReduce,
-    ) -> None:
+        do_async: bool,
+    ) -> Optional[Callable]:
         assert isinstance(
             weight_and_reduce_impl, TopKWeightAndReduceDelegate
         ), ("Weight application and reduction happens in the combine kernel.")
 
         a2a_idx = dbo_current_ubatch_id()
-        do_recv_hook = dbo_enabled()
+        do_recv_hook = dbo_enabled() or do_async
         handle = self.handles[a2a_idx]
         assert handle is not None
 
@@ -232,6 +232,45 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
             zero_copy=False,
             return_recv_hook=do_recv_hook,
             out=output)
-        if recv_hook is not None:
-            dbo_register_recv_hook(recv_hook)
-            dbo_yield()
+
+        return recv_hook
+
+    def finalize_async(
+        self,
+        output: torch.Tensor,
+        fused_expert_output: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        apply_router_weight_on_input: bool,
+        weight_and_reduce_impl: mk.TopKWeightAndReduce,
+    ) -> Callable:
+        recv_hook = self._finalize(
+            output,
+            fused_expert_output,
+            topk_weights,
+            topk_ids,
+            apply_router_weight_on_input,
+            weight_and_reduce_impl,
+            do_async=True,
+        )
+        assert recv_hook is not None
+        return recv_hook
+
+    def finalize(
+        self,
+        output: torch.Tensor,
+        fused_expert_output: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        apply_router_weight_on_input: bool,
+        weight_and_reduce_impl: mk.TopKWeightAndReduce,
+    ) -> None:
+        self._finalize(
+            output,
+            fused_expert_output,
+            topk_weights,
+            topk_ids,
+            apply_router_weight_on_input,
+            weight_and_reduce_impl,
+            do_async=False,
+        )
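
The low-latency path already had a recv_hook from the combine; the change above stops registering it with the DBO ubatch context inside the method (hence the dropped dbo_register_recv_hook/dbo_yield imports) and instead returns the hook so the caller decides when to wait. A toy sketch of that contract, with a stand-in combine in place of DeepEP's low-latency kernel:

    from typing import Callable, Optional

    import torch


    def low_latency_combine(output: torch.Tensor, fused_out: torch.Tensor,
                            return_recv_hook: bool) -> Optional[Callable]:
        """Toy combine: optionally defer writing the result behind a hook."""
        combined = fused_out.sum(dim=1)

        def recv_hook():
            output.copy_(combined)

        if return_recv_hook:
            return recv_hook    # the caller decides when (and where) to wait
        recv_hook()             # synchronous path completes immediately
        return None


    def _finalize(output, fused_out, do_async: bool, dbo_on: bool):
        # Ask for a hook whenever DBO or an async caller will wait later.
        do_recv_hook = dbo_on or do_async
        return low_latency_combine(output, fused_out,
                                   return_recv_hook=do_recv_hook)
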
@@ -209,7 +209,8 @@ class FusedMoEPrepareAndFinalize(ABC):
 
     def supports_async(self) -> bool:
         """
-        Indicates whether or not this class implements prepare_async.
+        Indicates whether or not this class implements prepare_async and
+        finalize_async.
         """
         return False
 
@@ -275,6 +276,42 @@ class FusedMoEPrepareAndFinalize(ABC):
         """
         raise NotImplementedError
 
+    def finalize_async(
+        self,
+        output: torch.Tensor,
+        fused_expert_output: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        apply_router_weight_on_input: bool,
+        weight_and_reduce_impl: TopKWeightAndReduce,
+    ) -> Callable:
+        """
+        Perform any combine plus apply weights and perform a reduction on the
+        fused experts output but do not wait for results from other workers.
+        - output: The output tensor, written in place. Must be (M, K) shape.
+        - fused_expert_output: The unweighted, unreduced output of the fused
+          experts, it will have (M, topk, K) shape.
+        - topk_weights: The weights to be applied to the fused_experts_output.
+        - topk_ids: The topk_ids.
+        - apply_router_weight_on_input: When False, apply the weights to
+          fused_expert_output.
+        - weight_and_reduce_impl: An optional TopKWeightAndReduce
+          implementation.
+
+        Returns a callback that when invoked waits for results from other
+        workers and has the same return signature as `finalize`, e.g.
+
+            receiver = obj.finalize_async(output, ...)
+            ... output not valid yet ...
+            receiver()
+            ... output valid here ...
+
+        is equivalent to:
+
+            obj.finalize(output, ...)
+        """
+        raise NotImplementedError
+
     @property
     @abstractmethod
     def activation_format(self) -> FusedMoEActivationFormat:
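
The docstring above pins down the contract: calling finalize_async and then the returned receiver must leave output in the same state as a plain finalize call. A toy, self-contained illustration of that equivalence (not a real FusedMoEPrepareAndFinalize subclass):

    from typing import Callable

    import torch


    class ToyPrepareFinalize:
        def supports_async(self) -> bool:
            return True

        def finalize(self, output: torch.Tensor, fused: torch.Tensor) -> None:
            # Weight-and-reduce stand-in: sum over the topk dimension.
            output.copy_(fused.sum(dim=1))

        def finalize_async(self, output: torch.Tensor,
                           fused: torch.Tensor) -> Callable:
            return lambda: self.finalize(output, fused)


    pf = ToyPrepareFinalize()
    fused = torch.randn(4, 2, 8)     # (M, topk, K)
    out_sync = torch.empty(4, 8)     # (M, K)
    out_async = torch.empty(4, 8)

    pf.finalize(out_sync, fused)
    receiver = pf.finalize_async(out_async, fused)
    # ... out_async is not valid yet ...
    receiver()
    # ... out_async is valid here ...
    assert torch.allclose(out_sync, out_async)
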
@@ -814,23 +851,20 @@ class FusedMoEModularKernel(torch.nn.Module):
         """
 
         a1 = hidden_states
-        output = a1 if inplace else torch.zeros_like(a1)
+        if inplace and self.shared_experts is None:
+            output = a1
+        else:
+            output = torch.zeros_like(a1)
 
         local_num_experts = w1.size(0)
         if global_num_experts == -1:
             global_num_experts = local_num_experts
 
-        shared_output: torch.Tensor
-
         if not self.prepare_finalize.supports_async():
             # We shouldn't be running an a2a kernel that doesn't
             # support async prepare/finalize
             assert not dbo_enabled()
 
-            # Run shared experts serially with dispatch.
-            if self.shared_experts is not None:
-                shared_output = self.shared_experts(a1)
-
             (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
              _expert_topk_weights) = self.prepare_finalize.prepare(
                  a1,
@@ -854,9 +888,6 @@ class FusedMoEModularKernel(torch.nn.Module):
                 self.fused_experts.quant_config,
             )
 
-            if self.shared_experts is not None:
-                shared_output = self.shared_experts(a1)
-
             # If DBO is being used, register the hook with the ubatch context
             # and call it in dbo_maybe_run_recv_hook instead of passing it to
             # the receiver.
@@ -900,16 +931,42 @@ class FusedMoEModularKernel(torch.nn.Module):
             apply_router_weight_on_input=apply_router_weight_on_input,
         )
 
-        self.prepare_finalize.finalize(
-            output,
-            fused_out,
-            topk_weights,
-            topk_ids,
-            apply_router_weight_on_input,
-            self.fused_experts.finalize_weight_and_reduce_impl(),
-        )
+        shared_output: Optional[torch.Tensor] = None
+
+        if not self.prepare_finalize.supports_async():
+            assert not dbo_enabled()
+
+            self.prepare_finalize.finalize(
+                output,
+                fused_out,
+                topk_weights,
+                topk_ids,
+                apply_router_weight_on_input,
+                self.fused_experts.finalize_weight_and_reduce_impl(),
+            )
+
+            if self.shared_experts is not None:
+                shared_output = self.shared_experts(a1)
+        else:
+            recv_hook = self.prepare_finalize.finalize_async(
+                output,
+                fused_out,
+                topk_weights,
+                topk_ids,
+                apply_router_weight_on_input,
+                self.fused_experts.finalize_weight_and_reduce_impl(),
+            )
+
+            if self.shared_experts is not None:
+                shared_output = self.shared_experts(a1)
+
+            assert recv_hook is not None
+            dbo_register_recv_hook(recv_hook)
+            dbo_yield()
+            if not dbo_enabled():
+                recv_hook()
 
         if self.shared_experts is None:
             return output
         else:
             assert shared_output is not None
             return shared_output, output
 
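
This is the heart of the change: FusedMoEModularKernel.forward now launches the combine (blocking or async), runs the shared experts while an async combine is still in flight, and only then waits on the receiver hook, either via the DBO ubatch context or by calling it directly. Presumably this is also why output may no longer alias a1 when shared experts are present: the shared experts still read a1 while the combine writes output. A condensed, self-contained sketch of the scheduling, with the callables passed in as stand-ins for the real prepare/finalize and DBO machinery:

    from typing import Callable, Optional

    import torch


    def finalize_with_shared(
        a1: torch.Tensor,
        output: torch.Tensor,
        fused_out: torch.Tensor,
        finalize: Callable[[torch.Tensor, torch.Tensor], None],
        finalize_async: Callable[[torch.Tensor, torch.Tensor], Callable],
        supports_async: bool,
        shared_experts: Optional[Callable[[torch.Tensor], torch.Tensor]],
    ) -> Optional[torch.Tensor]:
        shared_output: Optional[torch.Tensor] = None
        if not supports_async:
            finalize(output, fused_out)              # blocking combine
            if shared_experts is not None:
                shared_output = shared_experts(a1)   # runs after the combine
        else:
            recv_hook = finalize_async(output, fused_out)
            if shared_experts is not None:
                shared_output = shared_experts(a1)   # overlaps the combine
            recv_hook()                              # wait for the combine
        return shared_output

In the real code the async branch registers recv_hook with dbo_register_recv_hook/dbo_yield when dual-batch overlap is enabled and only invokes it directly otherwise; the sketch collapses that into a direct call.
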
@@ -272,7 +272,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
                 hook()
         return receiver()
 
-    def finalize(
+    def finalize_async(
         self,
         output: torch.Tensor,
         fused_expert_output: torch.Tensor,
@@ -280,7 +280,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
         topk_ids: torch.Tensor,
         apply_router_weight_on_input: bool,
         weight_and_reduce_impl: mk.TopKWeightAndReduce,
-    ) -> None:
+    ) -> Callable:
         assert isinstance(
             weight_and_reduce_impl, TopKWeightAndReduceDelegate
         ), ("Weight application and reduction happens in the combine kernel.")
@@ -303,8 +303,39 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
         if apply_router_weight_on_input:
             topk_weights = torch.ones_like(topk_weights)
 
+        topk_ids_u32 = topk_ids.view(dtype=torch.uint32)
+
         self.a2a.combine(out_tokens=output,
-                         indices=topk_ids.view(dtype=torch.uint32),
+                         indices=topk_ids_u32,
                          weights=topk_weights,
                          expert_y=fused_expert_output,
-                         bound_m=bound_m)
+                         bound_m=bound_m,
+                         do_send=True,
+                         do_recv=False)
+
+        return lambda: self.a2a.combine(out_tokens=output,
+                                        indices=topk_ids_u32,
+                                        weights=topk_weights,
+                                        expert_y=fused_expert_output,
+                                        bound_m=bound_m,
+                                        do_send=False,
+                                        do_recv=True)
+
+    def finalize(
+        self,
+        output: torch.Tensor,
+        fused_expert_output: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        apply_router_weight_on_input: bool,
+        weight_and_reduce_impl: mk.TopKWeightAndReduce,
+    ) -> None:
+        receiver = self.finalize_async(
+            output,
+            fused_expert_output,
+            topk_weights,
+            topk_ids,
+            apply_router_weight_on_input,
+            weight_and_reduce_impl,
+        )
+        receiver()
 
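
In the pplx path the combine itself is driven in two phases: the send half runs eagerly with do_send=True, do_recv=False, and the recv half is deferred into the returned lambda. A toy sketch of that split, with TwoPhaseCombine standing in for the pplx all-to-all:

    from typing import Callable, Optional

    import torch


    class TwoPhaseCombine:
        """Toy stand-in for a combine that can be split into send/recv."""

        def __init__(self) -> None:
            self._pending: Optional[torch.Tensor] = None

        def combine(self, out_tokens: torch.Tensor, expert_y: torch.Tensor,
                    do_send: bool, do_recv: bool) -> None:
            if do_send:
                # "Send" phase: stage the reduced expert output.
                self._pending = expert_y.sum(dim=1)
            if do_recv:
                # "Recv" phase: materialize the staged result into the output.
                assert self._pending is not None
                out_tokens.copy_(self._pending)


    def finalize_async(a2a: TwoPhaseCombine, output: torch.Tensor,
                       fused_expert_output: torch.Tensor) -> Callable:
        # Kick off the send half eagerly; defer the recv half to the receiver.
        a2a.combine(out_tokens=output, expert_y=fused_expert_output,
                    do_send=True, do_recv=False)
        return lambda: a2a.combine(out_tokens=output,
                                   expert_y=fused_expert_output,
                                   do_send=False, do_recv=True)
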