[Misc] DP : Add ExpertTokensMetadata (#20332)

Signed-off-by: Varun <vsundarr@redhat.com>
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun <vsundarr@redhat.com>
Varun Sundar Rabindranath 2025-07-09 20:33:14 -04:00 committed by GitHub
parent b7d9e9416f
commit 805d62ca88
12 changed files with 117 additions and 79 deletions

View File

@@ -260,8 +260,11 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
):
assert expert_tokens_meta is not None
expert_num_tokens = expert_tokens_meta.expert_num_tokens
import deep_gemm as dg
assert hidden_states.ndim == 3
assert self.block_shape is not None
@@ -287,7 +290,6 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
masked_m=expert_num_tokens,
expected_m=expected_m)
assert expert_num_tokens is not None
a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm(workspace1,
expert_num_tokens)
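The batched DeepGEMM experts cannot run without per-expert token counts, so the new argument is treated as mandatory and unwrapped up front. A minimal sketch of that consumption pattern, assuming the vLLM tree is importable and eliding the actual kernel body:

    from typing import Optional

    import torch
    import vllm.model_executor.layers.fused_moe.modular_kernel as mk

    def consume_required_counts(
            expert_tokens_meta: Optional[mk.ExpertTokensMetadata]) -> torch.Tensor:
        # Batched-format inputs always carry counts; a missing value is a caller bug.
        assert expert_tokens_meta is not None
        return expert_tokens_meta.expert_num_tokens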

View File

@@ -129,7 +129,7 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
):
experts = (self.batched_deep_gemm_experts
if self.allow_deep_gemm else self.batched_triton_experts)
@@ -137,4 +137,4 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
experts.apply(output, hidden_states, w1, w2, topk_ids, activation,
global_num_experts, expert_map, w1_scale, w2_scale,
w1_zp, w2_zp, a1q_scale, a2_scale, workspace13,
workspace2, expert_num_tokens)
workspace2, expert_tokens_meta)

View File

@@ -303,11 +303,17 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
):
assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE"
expert_num_tokens = None
if expert_tokens_meta is not None:
expert_num_tokens = expert_tokens_meta.expert_num_tokens
activation_callable = lambda o, i: self.activation(activation, o, i)
in_dtype = hidden_states.dtype
run_cutlass_moe_fp8(
output, hidden_states, w1, w2, topk_ids, activation_callable,
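The CUTLASS experts, by contrast, can proceed without the metadata, so the unwrap tolerates None. The same pattern condensed into a hypothetical helper (not part of this commit), again assuming the vLLM tree is importable:

    from typing import Optional

    import torch
    import vllm.model_executor.layers.fused_moe.modular_kernel as mk

    def maybe_expert_num_tokens(
            meta: Optional[mk.ExpertTokensMetadata]) -> Optional[torch.Tensor]:
        # None simply means the prepare step produced no per-expert counts.
        return None if meta is None else meta.expert_num_tokens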

View File

@@ -119,7 +119,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
):
import deep_gemm as dg
assert self.block_shape is not None

View File

@@ -62,8 +62,9 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
has_scales = token_scales is not None
(num_tokens_per_rank, num_tokens_per_rdma_rank, expert_num_tokens,
is_token_in_rank, event) = self.buffer.get_dispatch_layout(
(num_tokens_per_rank, num_tokens_per_rdma_rank,
dispatch_expert_num_tokens, is_token_in_rank,
event) = self.buffer.get_dispatch_layout(
topk_idx=rank_topk_ids,
num_experts=num_experts,
previous_event=None,
@@ -83,7 +84,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
num_tokens_per_rank=num_tokens_per_rank,
num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
is_token_in_rank=is_token_in_rank,
num_tokens_per_expert=expert_num_tokens,
num_tokens_per_expert=dispatch_expert_num_tokens,
topk_idx=rank_topk_ids,
topk_weights=rank_topk_weights,
# expert_alignment rounds the number of tokens per expert
@@ -115,7 +116,13 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
num_experts - 1 if self.rank_expert_offset == 0 else 0,
expert_topk_ids + self.rank_expert_offset)
return (expert_x, expert_x_scale, expert_num_tokens, expert_topk_ids,
# Makes a GPU-CPU copy.
# TODO (varun): Maybe it is better to re-compute the expert_num_tokens
# on GPU.
expert_tokens_meta = mk.ExpertTokensMetadata.make_from_list(
expert_num_tokens_per_expert_list, device=expert_x.device)
return (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids,
expert_topk_weights)
def prepare(
@@ -129,8 +136,9 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
Optional[torch.Tensor], Optional[torch.Tensor]]:
) -> tuple[torch.Tensor, Optional[torch.Tensor],
Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
Optional[torch.Tensor]]:
if apply_router_weight_on_input:
topk = topk_ids.size(1)
@@ -149,7 +157,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
)
if a1q_scale is not None and a1q_scale.numel() == 1:
a1q_scale = a1q_scale.view(1, 1)
(expert_x, expert_x_scale, expert_num_tokens, expert_topk_ids,
(expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids,
expert_topk_weights) = self._do_dispatch(
tokens=a1q,
token_scales=a1q_scale,
@@ -159,7 +167,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
else:
# DeepEP kernels only support dispatching per-token-quant
# quantization. dispatch in bfloat16.
(expert_x, _, expert_num_tokens, expert_topk_ids,
(expert_x, _, expert_tokens_meta, expert_topk_ids,
expert_topk_weights) = self._do_dispatch(
tokens=a1,
token_scales=None,
@@ -176,7 +184,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
per_act_token_quant=False,
block_shape=quant_config.block_shape)
return (expert_x, expert_x_scale, expert_num_tokens, expert_topk_ids,
return (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids,
expert_topk_weights)
def _apply_weights_and_reduce(self, num_tokens: int,
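The high-throughput DeepEP path gets its per-expert counts back from dispatch as a Python list, so it goes through ExpertTokensMetadata.make_from_list, which builds the CPU tensor first and then copies it to the dispatch device (the host-device copy the comment in the hunk above flags). A standalone sketch of that helper's behaviour with a toy count list, assuming the vLLM tree is importable:

    import torch
    import vllm.model_executor.layers.fused_moe.modular_kernel as mk

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Stand-in for the expert_num_tokens_per_expert_list returned by the dispatch.
    counts_list = [7, 0, 12, 3]

    meta = mk.ExpertTokensMetadata.make_from_list(counts_list, device=device)

    # Both views are populated: a device tensor for the kernels plus a host copy.
    assert meta.expert_num_tokens.device.type == torch.device(device).type
    assert meta.expert_num_tokens_cpu is not None
    assert meta.expert_num_tokens_cpu.tolist() == counts_list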

View File

@@ -119,8 +119,9 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
Optional[torch.Tensor], Optional[torch.Tensor]]:
) -> tuple[torch.Tensor, Optional[torch.Tensor],
Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
Optional[torch.Tensor]]:
hidden_size = a1.size(1)
assert hidden_size in self.SUPPORTED_HIDDEN_SIZES, \
@@ -158,7 +159,10 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_x, a1_scale, a2_scale, a1.dtype, quant_config.quant_dtype,
quant_config.per_act_token_quant, quant_config.block_shape)
return (expert_x, expert_x_scale, expert_num_tokens, None, None)
expert_tokens_meta = mk.ExpertTokensMetadata(
expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None)
return (expert_x, expert_x_scale, expert_tokens_meta, None, None)
def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
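The low-latency path already has the counts as a device tensor coming out of the dispatch, so it wraps them directly and leaves the host copy unset. A short sketch with a made-up counts tensor, assuming the vLLM tree is importable:

    import torch
    import vllm.model_executor.layers.fused_moe.modular_kernel as mk

    num_local_experts = 4
    # Stand-in for the expert_num_tokens tensor the dispatch returns.
    counts = torch.zeros(num_local_experts, dtype=torch.int32)

    meta = mk.ExpertTokensMetadata(expert_num_tokens=counts,
                                   expert_num_tokens_cpu=None)  # no host copy here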

View File

@@ -505,8 +505,9 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
Optional[torch.Tensor], Optional[torch.Tensor]]:
) -> tuple[torch.Tensor, Optional[torch.Tensor],
Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
Optional[torch.Tensor]]:
assert a1.dim() == 2
assert topk_ids.dim() == 2
assert topk_ids.size(0) == a1.size(0)
@@ -587,7 +588,10 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
assert b_a1_scale is None or b_a1_scale.ndim == 3
return b_a1, b_a1_scale, tokens_per_expert, None, None
expert_tokens_meta = mk.ExpertTokensMetadata(
expert_num_tokens=tokens_per_expert, expert_num_tokens_cpu=None)
return b_a1, b_a1_scale, expert_tokens_meta, None, None
def finalize(
self,
@@ -694,28 +698,19 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
else:
return t.to(f32) * group_broadcast(scale, t.shape)
def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
):
def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w2: torch.Tensor, topk_ids: torch.Tensor,
activation: str, global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata]):
assert hidden_states.dim() == 3
assert expert_num_tokens is not None
assert expert_tokens_meta is not None
expert_num_tokens = expert_tokens_meta.expert_num_tokens
num_local_experts = w1.size(0)
assert num_local_experts == w1.size(0), (
@@ -902,26 +897,16 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
output = (num_experts, max_num_tokens * num_dp, K)
return (workspace13, workspace2, output, a.dtype)
def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
):
def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w2: torch.Tensor, topk_ids: torch.Tensor,
activation: str, global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata]):
# Check constraints.
if self.use_int4_w4a16:
assert hidden_states.size(-1) // 2 == w1.size(2), (
@@ -938,6 +923,9 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
assert hidden_states.dtype in [
torch.float32, torch.float16, torch.bfloat16, torch.float8_e4m3fn
]
assert expert_tokens_meta is not None
expert_num_tokens = expert_tokens_meta.expert_num_tokens
E, max_num_tokens, N, K, top_k_num = mk._moe_problem_size(
hidden_states, w1, w2, topk_ids)

View File

@@ -1630,7 +1630,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
):
# Check constraints.
if self.use_int4_w4a16:

View File

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from math import prod
from typing import Optional, final
@@ -95,6 +96,26 @@ class FusedMoEActivationFormat(Enum):
BatchedExperts = "batched_experts",
@dataclass
class ExpertTokensMetadata:
"""
Metadata regarding expert-token routing.
"""
expert_num_tokens: torch.Tensor
expert_num_tokens_cpu: Optional[torch.Tensor]
@staticmethod
def make_from_list(expert_num_tokens_list: list[int],
device: str) -> "ExpertTokensMetadata":
expert_num_tokens_cpu = torch.tensor(expert_num_tokens_list,
device="cpu",
dtype=torch.int32)
return ExpertTokensMetadata(
expert_num_tokens=expert_num_tokens_cpu.to(device,
non_blocking=True),
expert_num_tokens_cpu=expert_num_tokens_cpu)
# TODO: pass FusedMoEParallelConfig in as ctor parameter?
class FusedMoEPrepareAndFinalize(ABC):
"""
@@ -114,8 +135,9 @@ class FusedMoEPrepareAndFinalize(ABC):
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
Optional[torch.Tensor], Optional[torch.Tensor]]:
) -> tuple[torch.Tensor, Optional[torch.Tensor],
Optional[ExpertTokensMetadata], Optional[torch.Tensor],
Optional[torch.Tensor]]:
"""
Perform any quantization (and/or) dispatching needed
for this kernel.
@@ -134,7 +156,8 @@ class FusedMoEPrepareAndFinalize(ABC):
Returns a tuple of:
- quantized + dispatched a.
- quantized + dispatched a1_scales.
- Optional tensor as big as number of local experts that contains the
- Optional ExpertTokensMetadata containing gpu/cpu tensors
as big as the number of local experts with the information about the
number of tokens assigned to each local expert.
- Optional dispatched expert topk IDs
- Optional dispatched expert topk weight
@@ -318,7 +341,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
expert_tokens_meta: Optional[ExpertTokensMetadata],
):
"""
This function computes the intermediate result of a Mixture of Experts
@@ -351,8 +374,10 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
must be large enough to hold output of either MoE gemm.
- workspace2 (torch.Tensor): A scratch tensor used for the activation
function.
- expert_num_tokens: An optional tensor containing the number of tokens
assigned to each expert when using batched experts format input.
- expert_tokens_meta (Optional[ExpertTokensMetadata]) - An optional
ExpertTokensMetadata object containing gpu/cpu tensors
as big as the number of local experts with the information about the
number of tokens assigned to each local expert.
"""
raise NotImplementedError
@@ -458,7 +483,7 @@ class FusedMoEModularKernel(torch.nn.Module):
if global_num_experts == -1:
global_num_experts = local_num_experts
(a1q, a1q_scale, expert_num_tokens, _expert_topk_ids,
(a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
_expert_topk_weights) = self.prepare_finalize.prepare(
a1,
a1_scale,
@@ -542,7 +567,7 @@ class FusedMoEModularKernel(torch.nn.Module):
a2_scale=a2_scale,
workspace13=workspace13,
workspace2=workspace2,
expert_num_tokens=expert_num_tokens,
expert_tokens_meta=expert_tokens_meta,
)
else:
# The leading output dimension may not be equal to M, so
@@ -589,7 +614,7 @@ class FusedMoEModularKernel(torch.nn.Module):
a2_scale=curr_a2_scale,
workspace13=workspace13,
workspace2=workspace2,
expert_num_tokens=expert_num_tokens,
expert_tokens_meta=expert_tokens_meta,
)
self.prepare_finalize.finalize(output, fused_out, topk_weights,
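With those changes, FusedMoEModularKernel.forward simply threads whatever prepare() returned into the experts' apply() via the new expert_tokens_meta keyword. A toy sketch of that wiring with stub callables standing in for the real prepare/finalize and expert objects, assuming the vLLM tree is importable:

    from typing import Optional

    import torch
    import vllm.model_executor.layers.fused_moe.modular_kernel as mk

    def stub_prepare():
        # The third element of the prepare() tuple is now Optional[ExpertTokensMetadata].
        a1q = torch.randn(8, 16)
        meta = mk.ExpertTokensMetadata.make_from_list([5, 3], device="cpu")
        return a1q, None, meta, None, None

    def stub_apply(output: torch.Tensor, hidden_states: torch.Tensor, *,
                   expert_tokens_meta: Optional[mk.ExpertTokensMetadata]):
        # Experts unwrap the counts when present (or ignore the argument entirely).
        if expert_tokens_meta is not None:
            _ = expert_tokens_meta.expert_num_tokens
        output.copy_(hidden_states)

    a1q, _a1q_scale, expert_tokens_meta, _topk_ids, _topk_weights = stub_prepare()
    out = torch.empty_like(a1q)
    stub_apply(out, a1q, expert_tokens_meta=expert_tokens_meta)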

View File

@@ -94,8 +94,9 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
Optional[torch.Tensor], Optional[torch.Tensor]]:
) -> tuple[torch.Tensor, Optional[torch.Tensor],
Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
Optional[torch.Tensor]]:
num_tokens = a1.size(0) # M
hidden_dim = a1.size(-1) # K
@@ -200,7 +201,10 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_x_scale = expert_x_scale[:, :, :orig_a_scale_block_shape]
assert expert_x_scale.ndim == 3
return expert_x, expert_x_scale, expert_num_tokens, None, None
expert_tokens_meta = mk.ExpertTokensMetadata(
expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None)
return expert_x, expert_x_scale, expert_tokens_meta, None, None
def finalize(
self,

View File

@@ -38,8 +38,9 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
Optional[torch.Tensor], Optional[torch.Tensor]]:
) -> tuple[torch.Tensor, Optional[torch.Tensor],
Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
Optional[torch.Tensor]]:
if apply_router_weight_on_input:
topk = topk_ids.size(1)

View File

@@ -110,7 +110,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
):
use_deep_gemm = (self.allow_deep_gemm
and _valid_deep_gemm(hidden_states, w1, w2))
@@ -135,5 +135,5 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
a2_scale,
workspace13,
workspace2,
expert_num_tokens,
expert_tokens_meta,
)