[Misc] DP : Add ExpertTokensMetadata (#20332)

Signed-off-by: Varun <vsundarr@redhat.com>
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun <vsundarr@redhat.com>
Authored by Varun Sundar Rabindranath on 2025-07-09 20:33:14 -04:00, committed by GitHub
parent b7d9e9416f
commit 805d62ca88
12 changed files with 117 additions and 79 deletions


@@ -260,8 +260,11 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         a2_scale: Optional[torch.Tensor],
         workspace13: torch.Tensor,
         workspace2: torch.Tensor,
-        expert_num_tokens: Optional[torch.Tensor],
+        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
     ):
+        assert expert_tokens_meta is not None
+        expert_num_tokens = expert_tokens_meta.expert_num_tokens
+
         import deep_gemm as dg
         assert hidden_states.ndim == 3
         assert self.block_shape is not None
@@ -287,7 +290,6 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
             masked_m=expert_num_tokens,
             expected_m=expected_m)

-        assert expert_num_tokens is not None

         a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm(workspace1,
                                                       expert_num_tokens)


@@ -129,7 +129,7 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         a2_scale: Optional[torch.Tensor],
         workspace13: torch.Tensor,
         workspace2: torch.Tensor,
-        expert_num_tokens: Optional[torch.Tensor],
+        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
     ):
         experts = (self.batched_deep_gemm_experts
                    if self.allow_deep_gemm else self.batched_triton_experts)
@@ -137,4 +137,4 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         experts.apply(output, hidden_states, w1, w2, topk_ids, activation,
                       global_num_experts, expert_map, w1_scale, w2_scale,
                       w1_zp, w2_zp, a1q_scale, a2_scale, workspace13,
-                      workspace2, expert_num_tokens)
+                      workspace2, expert_tokens_meta)


@@ -303,11 +303,17 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
         a2_scale: Optional[torch.Tensor],
         workspace13: torch.Tensor,
         workspace2: torch.Tensor,
-        expert_num_tokens: Optional[torch.Tensor],
+        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
     ):
         assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
         assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE"
+
+        expert_num_tokens = None
+        if expert_tokens_meta is not None:
+            expert_num_tokens = expert_tokens_meta.expert_num_tokens
+
         activation_callable = lambda o, i: self.activation(activation, o, i)

         in_dtype = hidden_states.dtype
         run_cutlass_moe_fp8(
             output, hidden_states, w1, w2, topk_ids, activation_callable,


@@ -119,7 +119,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         a2_scale: Optional[torch.Tensor],
         workspace13: torch.Tensor,
         workspace2: torch.Tensor,
-        expert_num_tokens: Optional[torch.Tensor],
+        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
     ):
         import deep_gemm as dg
         assert self.block_shape is not None


@@ -62,8 +62,9 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
         has_scales = token_scales is not None

-        (num_tokens_per_rank, num_tokens_per_rdma_rank, expert_num_tokens,
-         is_token_in_rank, event) = self.buffer.get_dispatch_layout(
+        (num_tokens_per_rank, num_tokens_per_rdma_rank,
+         dispatch_expert_num_tokens, is_token_in_rank,
+         event) = self.buffer.get_dispatch_layout(
             topk_idx=rank_topk_ids,
             num_experts=num_experts,
             previous_event=None,
@@ -83,7 +84,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
             num_tokens_per_rank=num_tokens_per_rank,
             num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
             is_token_in_rank=is_token_in_rank,
-            num_tokens_per_expert=expert_num_tokens,
+            num_tokens_per_expert=dispatch_expert_num_tokens,
             topk_idx=rank_topk_ids,
             topk_weights=rank_topk_weights,
             # expert_alignment rounds the number of tokens per expert
@@ -115,7 +116,13 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
                 num_experts - 1 if self.rank_expert_offset == 0 else 0,
                 expert_topk_ids + self.rank_expert_offset)

-        return (expert_x, expert_x_scale, expert_num_tokens, expert_topk_ids,
+        # Makes a GPU-CPU copy.
+        # TODO (varun): Maybe it is better to re-compute the expert_num_tokens
+        # on GPU.
+        expert_tokens_meta = mk.ExpertTokensMetadata.make_from_list(
+            expert_num_tokens_per_expert_list, device=expert_x.device)
+
+        return (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids,
                 expert_topk_weights)

     def prepare(
@@ -129,8 +136,9 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
         expert_map: Optional[torch.Tensor],
         apply_router_weight_on_input: bool,
         quant_config: FusedMoEQuantConfig,
-    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
-               Optional[torch.Tensor], Optional[torch.Tensor]]:
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor],
+               Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
+               Optional[torch.Tensor]]:

         if apply_router_weight_on_input:
             topk = topk_ids.size(1)
@@ -149,7 +157,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
            )
            if a1q_scale is not None and a1q_scale.numel() == 1:
                a1q_scale = a1q_scale.view(1, 1)
-           (expert_x, expert_x_scale, expert_num_tokens, expert_topk_ids,
+           (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids,
             expert_topk_weights) = self._do_dispatch(
                 tokens=a1q,
                 token_scales=a1q_scale,
@@ -159,7 +167,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
        else:
            # DeepEP kernels only support dispatching per-token-quant
            # quantization. dispatch in bfloat16.
-           (expert_x, _, expert_num_tokens, expert_topk_ids,
+           (expert_x, _, expert_tokens_meta, expert_topk_ids,
             expert_topk_weights) = self._do_dispatch(
                 tokens=a1,
                 token_scales=None,
@@ -176,7 +184,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
                 per_act_token_quant=False,
                 block_shape=quant_config.block_shape)

-        return (expert_x, expert_x_scale, expert_num_tokens, expert_topk_ids,
+        return (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids,
                 expert_topk_weights)

     def _apply_weights_and_reduce(self, num_tokens: int,
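Note: the TODO in the hunk above points out that ExpertTokensMetadata.make_from_list performs a GPU-CPU copy. A minimal sketch of the GPU-side alternative it alludes to is shown below; the helper name and arguments are illustrative only (not part of this patch) and assume the routed ids contain no padding or sentinel entries.

    import torch

    def expert_num_tokens_on_gpu(expert_topk_ids: torch.Tensor,
                                 num_local_experts: int) -> torch.Tensor:
        # Per-expert token counts computed directly on the device holding the
        # routed ids, avoiding the host round-trip of make_from_list.
        return torch.bincount(expert_topk_ids.flatten(),
                              minlength=num_local_experts).to(torch.int32)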


@@ -119,8 +119,9 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
         expert_map: Optional[torch.Tensor],
         apply_router_weight_on_input: bool,
         quant_config: FusedMoEQuantConfig,
-    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
-               Optional[torch.Tensor], Optional[torch.Tensor]]:
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor],
+               Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
+               Optional[torch.Tensor]]:

         hidden_size = a1.size(1)
         assert hidden_size in self.SUPPORTED_HIDDEN_SIZES, \
@@ -158,7 +159,10 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
             expert_x, a1_scale, a2_scale, a1.dtype, quant_config.quant_dtype,
             quant_config.per_act_token_quant, quant_config.block_shape)

-        return (expert_x, expert_x_scale, expert_num_tokens, None, None)
+        expert_tokens_meta = mk.ExpertTokensMetadata(
+            expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None)
+
+        return (expert_x, expert_x_scale, expert_tokens_meta, None, None)

     def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
                  topk_weights: torch.Tensor, topk_ids: torch.Tensor,


@@ -505,8 +505,9 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
         expert_map: Optional[torch.Tensor],
         apply_router_weight_on_input: bool,
         quant_config: FusedMoEQuantConfig,
-    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
-               Optional[torch.Tensor], Optional[torch.Tensor]]:
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor],
+               Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
+               Optional[torch.Tensor]]:
         assert a1.dim() == 2
         assert topk_ids.dim() == 2
         assert topk_ids.size(0) == a1.size(0)
@@ -587,7 +588,10 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
         assert b_a1_scale is None or b_a1_scale.ndim == 3

-        return b_a1, b_a1_scale, tokens_per_expert, None, None
+        expert_tokens_meta = mk.ExpertTokensMetadata(
+            expert_num_tokens=tokens_per_expert, expert_num_tokens_cpu=None)
+
+        return b_a1, b_a1_scale, expert_tokens_meta, None, None

     def finalize(
         self,
@@ -694,28 +698,19 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
         else:
             return t.to(f32) * group_broadcast(scale, t.shape)

-    def apply(
-        self,
-        output: torch.Tensor,
-        hidden_states: torch.Tensor,
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        topk_ids: torch.Tensor,
-        activation: str,
-        global_num_experts: int,
-        expert_map: Optional[torch.Tensor],
-        w1_scale: Optional[torch.Tensor],
-        w2_scale: Optional[torch.Tensor],
-        w1_zp: Optional[torch.Tensor],
-        w2_zp: Optional[torch.Tensor],
-        a1q_scale: Optional[torch.Tensor],
-        a2_scale: Optional[torch.Tensor],
-        workspace13: torch.Tensor,
-        workspace2: torch.Tensor,
-        expert_num_tokens: Optional[torch.Tensor],
-    ):
+    def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
+              w1: torch.Tensor, w2: torch.Tensor, topk_ids: torch.Tensor,
+              activation: str, global_num_experts: int,
+              expert_map: Optional[torch.Tensor],
+              w1_scale: Optional[torch.Tensor],
+              w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
+              w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
+              a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
+              workspace2: torch.Tensor,
+              expert_tokens_meta: Optional[mk.ExpertTokensMetadata]):
         assert hidden_states.dim() == 3
-        assert expert_num_tokens is not None
+        assert expert_tokens_meta is not None
+        expert_num_tokens = expert_tokens_meta.expert_num_tokens

         num_local_experts = w1.size(0)
         assert num_local_experts == w1.size(0), (
@@ -902,26 +897,16 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
             output = (num_experts, max_num_tokens * num_dp, K)
         return (workspace13, workspace2, output, a.dtype)

-    def apply(
-        self,
-        output: torch.Tensor,
-        hidden_states: torch.Tensor,
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        topk_ids: torch.Tensor,
-        activation: str,
-        global_num_experts: int,
-        expert_map: Optional[torch.Tensor],
-        w1_scale: Optional[torch.Tensor],
-        w2_scale: Optional[torch.Tensor],
-        w1_zp: Optional[torch.Tensor],
-        w2_zp: Optional[torch.Tensor],
-        a1q_scale: Optional[torch.Tensor],
-        a2_scale: Optional[torch.Tensor],
-        workspace13: torch.Tensor,
-        workspace2: torch.Tensor,
-        expert_num_tokens: Optional[torch.Tensor],
-    ):
+    def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
+              w1: torch.Tensor, w2: torch.Tensor, topk_ids: torch.Tensor,
+              activation: str, global_num_experts: int,
+              expert_map: Optional[torch.Tensor],
+              w1_scale: Optional[torch.Tensor],
+              w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
+              w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
+              a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
+              workspace2: torch.Tensor,
+              expert_tokens_meta: Optional[mk.ExpertTokensMetadata]):
         # Check constraints.
         if self.use_int4_w4a16:
             assert hidden_states.size(-1) // 2 == w1.size(2), (
@@ -938,6 +923,9 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
         assert hidden_states.dtype in [
             torch.float32, torch.float16, torch.bfloat16, torch.float8_e4m3fn
         ]
+        assert expert_tokens_meta is not None
+        expert_num_tokens = expert_tokens_meta.expert_num_tokens
+
         E, max_num_tokens, N, K, top_k_num = mk._moe_problem_size(
             hidden_states, w1, w2, topk_ids)


@@ -1630,7 +1630,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
         a2_scale: Optional[torch.Tensor],
         workspace13: torch.Tensor,
         workspace2: torch.Tensor,
-        expert_num_tokens: Optional[torch.Tensor],
+        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
     ):
         # Check constraints.
         if self.use_int4_w4a16:


@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from abc import ABC, abstractmethod
+from dataclasses import dataclass
 from enum import Enum
 from math import prod
 from typing import Optional, final
@@ -95,6 +96,26 @@ class FusedMoEActivationFormat(Enum):
     BatchedExperts = "batched_experts",


+@dataclass
+class ExpertTokensMetadata:
+    """
+    Metadata regarding expert-token routing.
+    """
+    expert_num_tokens: torch.Tensor
+    expert_num_tokens_cpu: Optional[torch.Tensor]
+
+    @staticmethod
+    def make_from_list(expert_num_tokens_list: list[int],
+                       device: str) -> "ExpertTokensMetadata":
+        expert_num_tokens_cpu = torch.tensor(expert_num_tokens_list,
+                                             device="cpu",
+                                             dtype=torch.int32)
+        return ExpertTokensMetadata(
+            expert_num_tokens=expert_num_tokens_cpu.to(device,
+                                                       non_blocking=True),
+            expert_num_tokens_cpu=expert_num_tokens_cpu)
+
+
 # TODO: pass FusedMoEParallelConfig in as ctor parameter?
 class FusedMoEPrepareAndFinalize(ABC):
     """
@@ -114,8 +135,9 @@ class FusedMoEPrepareAndFinalize(ABC):
         expert_map: Optional[torch.Tensor],
         apply_router_weight_on_input: bool,
         quant_config: FusedMoEQuantConfig,
-    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
-               Optional[torch.Tensor], Optional[torch.Tensor]]:
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor],
+               Optional[ExpertTokensMetadata], Optional[torch.Tensor],
+               Optional[torch.Tensor]]:
         """
         Perform any quantization (and/or) dispatching needed
         for this kernel.
@@ -134,7 +156,8 @@
         Returns a tuple of:
         - quantized + dispatched a.
         - quantized + dispatched a1_scales.
-        - Optional tensor as big as number of local experts that contains the
+        - Optional ExpertTokensMetadata containing gpu/cpu tensors
+          as big as the number of local experts with the information about the
           number of tokens assigned to each local expert.
         - Optional dispatched expert topk IDs
        - Optional dispatched expert topk weight
@@ -318,7 +341,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
         a2_scale: Optional[torch.Tensor],
         workspace13: torch.Tensor,
         workspace2: torch.Tensor,
-        expert_num_tokens: Optional[torch.Tensor],
+        expert_tokens_meta: Optional[ExpertTokensMetadata],
     ):
         """
         This function computes the intermediate result of a Mixture of Experts
@@ -351,8 +374,10 @@
           must be large enough to hold output of either MoE gemm.
        - workspace2 (torch.Tensor): A scratch tensor used for the activation
          function.
-        - expert_num_tokens: An optional tensor containing the number of tokens
-          assigned to each expert when using batched experts format input.
+        - expert_tokens_meta (Optional[ExpertTokensMetadata]) - An optional
+          ExpertTokensMetadata object containing gpu/cpu tensors
+          as big as the number of local experts with the information about the
+          number of tokens assigned to each local expert.
         """
         raise NotImplementedError
@@ -458,7 +483,7 @@ class FusedMoEModularKernel(torch.nn.Module):
         if global_num_experts == -1:
             global_num_experts = local_num_experts

-        (a1q, a1q_scale, expert_num_tokens, _expert_topk_ids,
+        (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
          _expert_topk_weights) = self.prepare_finalize.prepare(
             a1,
             a1_scale,
@@ -542,7 +567,7 @@
                 a2_scale=a2_scale,
                 workspace13=workspace13,
                 workspace2=workspace2,
-                expert_num_tokens=expert_num_tokens,
+                expert_tokens_meta=expert_tokens_meta,
             )
         else:
             # The leading output dimension may not be equal to M, so
@@ -589,7 +614,7 @@
                     a2_scale=curr_a2_scale,
                     workspace13=workspace13,
                     workspace2=workspace2,
-                    expert_num_tokens=expert_num_tokens,
+                    expert_tokens_meta=expert_tokens_meta,
                 )

         self.prepare_finalize.finalize(output, fused_out, topk_weights,
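For reference, a short sketch of the two construction paths for ExpertTokensMetadata that this change uses in its prepare() implementations; the module alias mk, the device string, and the example counts are illustrative, assuming the usual vllm.model_executor.layers.fused_moe.modular_kernel import path.

    import torch
    import vllm.model_executor.layers.fused_moe.modular_kernel as mk

    # Device-only counts (pplx / DeepEP low-latency / naive batched paths):
    gpu_counts = torch.zeros(8, dtype=torch.int32, device="cuda")
    meta = mk.ExpertTokensMetadata(expert_num_tokens=gpu_counts,
                                   expert_num_tokens_cpu=None)

    # Host-side list of counts (DeepEP high-throughput path): builds the CPU
    # tensor first and issues a non-blocking copy of it to the target device.
    meta = mk.ExpertTokensMetadata.make_from_list(
        [3, 0, 5, 1, 0, 0, 2, 4], device="cuda")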


@@ -94,8 +94,9 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
         expert_map: Optional[torch.Tensor],
         apply_router_weight_on_input: bool,
         quant_config: FusedMoEQuantConfig,
-    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
-               Optional[torch.Tensor], Optional[torch.Tensor]]:
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor],
+               Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
+               Optional[torch.Tensor]]:
         num_tokens = a1.size(0)  # M
         hidden_dim = a1.size(-1)  # K
@@ -200,7 +201,10 @@
             expert_x_scale = expert_x_scale[:, :, :orig_a_scale_block_shape]
             assert expert_x_scale.ndim == 3

-        return expert_x, expert_x_scale, expert_num_tokens, None, None
+        expert_tokens_meta = mk.ExpertTokensMetadata(
+            expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None)
+
+        return expert_x, expert_x_scale, expert_tokens_meta, None, None

     def finalize(
         self,


@@ -38,8 +38,9 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
         expert_map: Optional[torch.Tensor],
         apply_router_weight_on_input: bool,
         quant_config: FusedMoEQuantConfig,
-    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
-               Optional[torch.Tensor], Optional[torch.Tensor]]:
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor],
+               Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
+               Optional[torch.Tensor]]:

         if apply_router_weight_on_input:
             topk = topk_ids.size(1)


@@ -110,7 +110,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         a2_scale: Optional[torch.Tensor],
         workspace13: torch.Tensor,
         workspace2: torch.Tensor,
-        expert_num_tokens: Optional[torch.Tensor],
+        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
     ):
         use_deep_gemm = (self.allow_deep_gemm
                          and _valid_deep_gemm(hidden_states, w1, w2))
@@ -135,5 +135,5 @@
             a2_scale,
             workspace13,
             workspace2,
-            expert_num_tokens,
+            expert_tokens_meta,
         )