[Misc] DP : Add ExpertTokensMetadata (#20332)
Signed-off-by: Varun <vsundarr@redhat.com>
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun <vsundarr@redhat.com>
This commit is contained in:
parent b7d9e9416f
commit 805d62ca88
@@ -260,8 +260,11 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         a2_scale: Optional[torch.Tensor],
         workspace13: torch.Tensor,
         workspace2: torch.Tensor,
-        expert_num_tokens: Optional[torch.Tensor],
+        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
     ):
+        assert expert_tokens_meta is not None
+        expert_num_tokens = expert_tokens_meta.expert_num_tokens
+
         import deep_gemm as dg
         assert hidden_states.ndim == 3
         assert self.block_shape is not None
@@ -287,7 +290,6 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
             masked_m=expert_num_tokens,
             expected_m=expected_m)
 
-        assert expert_num_tokens is not None
         a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm(workspace1,
                                                       expert_num_tokens)
 
@@ -129,7 +129,7 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         a2_scale: Optional[torch.Tensor],
         workspace13: torch.Tensor,
         workspace2: torch.Tensor,
-        expert_num_tokens: Optional[torch.Tensor],
+        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
     ):
         experts = (self.batched_deep_gemm_experts
                    if self.allow_deep_gemm else self.batched_triton_experts)
@@ -137,4 +137,4 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         experts.apply(output, hidden_states, w1, w2, topk_ids, activation,
                       global_num_experts, expert_map, w1_scale, w2_scale,
                       w1_zp, w2_zp, a1q_scale, a2_scale, workspace13,
-                      workspace2, expert_num_tokens)
+                      workspace2, expert_tokens_meta)
@@ -303,11 +303,17 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
         a2_scale: Optional[torch.Tensor],
         workspace13: torch.Tensor,
         workspace2: torch.Tensor,
-        expert_num_tokens: Optional[torch.Tensor],
+        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
     ):
         assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
         assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE"
 
+        expert_num_tokens = None
+        if expert_tokens_meta is not None:
+            expert_num_tokens = expert_tokens_meta.expert_num_tokens
+
         activation_callable = lambda o, i: self.activation(activation, o, i)
 
         in_dtype = hidden_states.dtype
         run_cutlass_moe_fp8(
             output, hidden_states, w1, w2, topk_ids, activation_callable,
@@ -119,7 +119,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         a2_scale: Optional[torch.Tensor],
         workspace13: torch.Tensor,
         workspace2: torch.Tensor,
-        expert_num_tokens: Optional[torch.Tensor],
+        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
     ):
         import deep_gemm as dg
         assert self.block_shape is not None
@@ -62,8 +62,9 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
 
         has_scales = token_scales is not None
 
-        (num_tokens_per_rank, num_tokens_per_rdma_rank, expert_num_tokens,
-         is_token_in_rank, event) = self.buffer.get_dispatch_layout(
+        (num_tokens_per_rank, num_tokens_per_rdma_rank,
+         dispatch_expert_num_tokens, is_token_in_rank,
+         event) = self.buffer.get_dispatch_layout(
             topk_idx=rank_topk_ids,
             num_experts=num_experts,
             previous_event=None,
@@ -83,7 +84,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
             num_tokens_per_rank=num_tokens_per_rank,
             num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
             is_token_in_rank=is_token_in_rank,
-            num_tokens_per_expert=expert_num_tokens,
+            num_tokens_per_expert=dispatch_expert_num_tokens,
             topk_idx=rank_topk_ids,
             topk_weights=rank_topk_weights,
             # expert_alignment rounds the number of tokens per expert
@@ -115,7 +116,13 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
             num_experts - 1 if self.rank_expert_offset == 0 else 0,
             expert_topk_ids + self.rank_expert_offset)
 
-        return (expert_x, expert_x_scale, expert_num_tokens, expert_topk_ids,
+        # Makes a GPU-CPU copy.
+        # TODO (varun): Maybe it is better to re-compute the expert_num_tokens
+        # on GPU.
+        expert_tokens_meta = mk.ExpertTokensMetadata.make_from_list(
+            expert_num_tokens_per_expert_list, device=expert_x.device)
+
+        return (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids,
                 expert_topk_weights)
 
     def prepare(
@@ -129,8 +136,9 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
         expert_map: Optional[torch.Tensor],
         apply_router_weight_on_input: bool,
         quant_config: FusedMoEQuantConfig,
-    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
-               Optional[torch.Tensor], Optional[torch.Tensor]]:
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor],
+               Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
+               Optional[torch.Tensor]]:
 
         if apply_router_weight_on_input:
             topk = topk_ids.size(1)
@@ -149,7 +157,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
             )
             if a1q_scale is not None and a1q_scale.numel() == 1:
                 a1q_scale = a1q_scale.view(1, 1)
-            (expert_x, expert_x_scale, expert_num_tokens, expert_topk_ids,
+            (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids,
              expert_topk_weights) = self._do_dispatch(
                  tokens=a1q,
                  token_scales=a1q_scale,
@@ -159,7 +167,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
         else:
             # DeepEP kernels only support dispatching per-token-quant
             # quantization. dispatch in bfloat16.
-            (expert_x, _, expert_num_tokens, expert_topk_ids,
+            (expert_x, _, expert_tokens_meta, expert_topk_ids,
              expert_topk_weights) = self._do_dispatch(
                  tokens=a1,
                  token_scales=None,
@@ -176,7 +184,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
                 per_act_token_quant=False,
                 block_shape=quant_config.block_shape)
 
-        return (expert_x, expert_x_scale, expert_num_tokens, expert_topk_ids,
+        return (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids,
                 expert_topk_weights)
 
     def _apply_weights_and_reduce(self, num_tokens: int,
@@ -119,8 +119,9 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
         expert_map: Optional[torch.Tensor],
         apply_router_weight_on_input: bool,
         quant_config: FusedMoEQuantConfig,
-    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
-               Optional[torch.Tensor], Optional[torch.Tensor]]:
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor],
+               Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
+               Optional[torch.Tensor]]:
 
         hidden_size = a1.size(1)
         assert hidden_size in self.SUPPORTED_HIDDEN_SIZES, \
@@ -158,7 +159,10 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
             expert_x, a1_scale, a2_scale, a1.dtype, quant_config.quant_dtype,
             quant_config.per_act_token_quant, quant_config.block_shape)
 
-        return (expert_x, expert_x_scale, expert_num_tokens, None, None)
+        expert_tokens_meta = mk.ExpertTokensMetadata(
+            expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None)
+
+        return (expert_x, expert_x_scale, expert_tokens_meta, None, None)
 
     def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
                  topk_weights: torch.Tensor, topk_ids: torch.Tensor,
@@ -505,8 +505,9 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
         expert_map: Optional[torch.Tensor],
         apply_router_weight_on_input: bool,
         quant_config: FusedMoEQuantConfig,
-    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
-               Optional[torch.Tensor], Optional[torch.Tensor]]:
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor],
+               Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
+               Optional[torch.Tensor]]:
         assert a1.dim() == 2
         assert topk_ids.dim() == 2
         assert topk_ids.size(0) == a1.size(0)
@@ -587,7 +588,10 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
 
         assert b_a1_scale is None or b_a1_scale.ndim == 3
 
-        return b_a1, b_a1_scale, tokens_per_expert, None, None
+        expert_tokens_meta = mk.ExpertTokensMetadata(
+            expert_num_tokens=tokens_per_expert, expert_num_tokens_cpu=None)
+
+        return b_a1, b_a1_scale, expert_tokens_meta, None, None
 
     def finalize(
         self,
@@ -694,28 +698,19 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
         else:
             return t.to(f32) * group_broadcast(scale, t.shape)
 
-    def apply(
-        self,
-        output: torch.Tensor,
-        hidden_states: torch.Tensor,
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        topk_ids: torch.Tensor,
-        activation: str,
-        global_num_experts: int,
-        expert_map: Optional[torch.Tensor],
-        w1_scale: Optional[torch.Tensor],
-        w2_scale: Optional[torch.Tensor],
-        w1_zp: Optional[torch.Tensor],
-        w2_zp: Optional[torch.Tensor],
-        a1q_scale: Optional[torch.Tensor],
-        a2_scale: Optional[torch.Tensor],
-        workspace13: torch.Tensor,
-        workspace2: torch.Tensor,
-        expert_num_tokens: Optional[torch.Tensor],
-    ):
+    def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
+              w1: torch.Tensor, w2: torch.Tensor, topk_ids: torch.Tensor,
+              activation: str, global_num_experts: int,
+              expert_map: Optional[torch.Tensor],
+              w1_scale: Optional[torch.Tensor],
+              w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
+              w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
+              a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
+              workspace2: torch.Tensor,
+              expert_tokens_meta: Optional[mk.ExpertTokensMetadata]):
         assert hidden_states.dim() == 3
-        assert expert_num_tokens is not None
+        assert expert_tokens_meta is not None
+        expert_num_tokens = expert_tokens_meta.expert_num_tokens
 
         num_local_experts = w1.size(0)
         assert num_local_experts == w1.size(0), (
@@ -902,26 +897,16 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
         output = (num_experts, max_num_tokens * num_dp, K)
         return (workspace13, workspace2, output, a.dtype)
 
-    def apply(
-        self,
-        output: torch.Tensor,
-        hidden_states: torch.Tensor,
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        topk_ids: torch.Tensor,
-        activation: str,
-        global_num_experts: int,
-        expert_map: Optional[torch.Tensor],
-        w1_scale: Optional[torch.Tensor],
-        w2_scale: Optional[torch.Tensor],
-        w1_zp: Optional[torch.Tensor],
-        w2_zp: Optional[torch.Tensor],
-        a1q_scale: Optional[torch.Tensor],
-        a2_scale: Optional[torch.Tensor],
-        workspace13: torch.Tensor,
-        workspace2: torch.Tensor,
-        expert_num_tokens: Optional[torch.Tensor],
-    ):
+    def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
+              w1: torch.Tensor, w2: torch.Tensor, topk_ids: torch.Tensor,
+              activation: str, global_num_experts: int,
+              expert_map: Optional[torch.Tensor],
+              w1_scale: Optional[torch.Tensor],
+              w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
+              w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
+              a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
+              workspace2: torch.Tensor,
+              expert_tokens_meta: Optional[mk.ExpertTokensMetadata]):
         # Check constraints.
         if self.use_int4_w4a16:
             assert hidden_states.size(-1) // 2 == w1.size(2), (
@@ -938,6 +923,9 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
         assert hidden_states.dtype in [
             torch.float32, torch.float16, torch.bfloat16, torch.float8_e4m3fn
         ]
+        assert expert_tokens_meta is not None
+
+        expert_num_tokens = expert_tokens_meta.expert_num_tokens
 
         E, max_num_tokens, N, K, top_k_num = mk._moe_problem_size(
             hidden_states, w1, w2, topk_ids)
@@ -1630,7 +1630,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
         a2_scale: Optional[torch.Tensor],
         workspace13: torch.Tensor,
         workspace2: torch.Tensor,
-        expert_num_tokens: Optional[torch.Tensor],
+        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
     ):
         # Check constraints.
         if self.use_int4_w4a16:
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from abc import ABC, abstractmethod
+from dataclasses import dataclass
 from enum import Enum
 from math import prod
 from typing import Optional, final
@@ -95,6 +96,26 @@ class FusedMoEActivationFormat(Enum):
     BatchedExperts = "batched_experts",
 
 
+@dataclass
+class ExpertTokensMetadata:
+    """
+    Metadata regarding expert-token routing.
+    """
+    expert_num_tokens: torch.Tensor
+    expert_num_tokens_cpu: Optional[torch.Tensor]
+
+    @staticmethod
+    def make_from_list(expert_num_tokens_list: list[int],
+                       device: str) -> "ExpertTokensMetadata":
+        expert_num_tokens_cpu = torch.tensor(expert_num_tokens_list,
+                                             device="cpu",
+                                             dtype=torch.int32)
+        return ExpertTokensMetadata(
+            expert_num_tokens=expert_num_tokens_cpu.to(device,
+                                                       non_blocking=True),
+            expert_num_tokens_cpu=expert_num_tokens_cpu)
+
+
 # TODO: pass FusedMoEParallelConfig in as ctor parameter?
 class FusedMoEPrepareAndFinalize(ABC):
     """
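For readers skimming the diff, here is a minimal usage sketch of the new dataclass. It assumes `mk` refers to the modular-kernel module patched above (the import path is not shown in the diff and is an assumption), and the token counts are made-up values; it only illustrates the two construction paths: direct construction from an existing GPU tensor, and `make_from_list` from a Python list.

```python
import torch

# Assumption: this is the module patched above; the diff itself does not
# name the file, so the import path is illustrative.
import vllm.model_executor.layers.fused_moe.modular_kernel as mk

# Path 1: the per-expert token counts already live on the GPU (e.g. produced
# by a dispatch kernel); wrap them directly and leave the CPU copy unset.
gpu_counts = torch.tensor([3, 0, 5, 1], dtype=torch.int32, device="cuda")
meta = mk.ExpertTokensMetadata(expert_num_tokens=gpu_counts,
                               expert_num_tokens_cpu=None)

# Path 2: the counts are only available as a Python list; make_from_list
# builds the CPU tensor first and copies it to the target device, keeping
# both the GPU and CPU views.
meta_both = mk.ExpertTokensMetadata.make_from_list([3, 0, 5, 1],
                                                   device="cuda")
assert meta_both.expert_num_tokens_cpu is not None
assert meta_both.expert_num_tokens.device.type == "cuda"
```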
@@ -114,8 +135,9 @@ class FusedMoEPrepareAndFinalize(ABC):
         expert_map: Optional[torch.Tensor],
         apply_router_weight_on_input: bool,
         quant_config: FusedMoEQuantConfig,
-    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
-               Optional[torch.Tensor], Optional[torch.Tensor]]:
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor],
+               Optional[ExpertTokensMetadata], Optional[torch.Tensor],
+               Optional[torch.Tensor]]:
         """
         Perform any quantization (and/or) dispatching needed
         for this kernel.
@@ -134,7 +156,8 @@ class FusedMoEPrepareAndFinalize(ABC):
         Returns a tuple of:
         - quantized + dispatched a.
         - quantized + dispatched a1_scales.
-        - Optional tensor as big as number of local experts that contains the
+        - Optional ExpertTokensMetadata containing gpu/cpu tensors
+          as big as the number of local experts with the information about the
           number of tokens assigned to each local expert.
         - Optional dispatched expert topk IDs
         - Optional dispatched expert topk weight
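To make the new return contract concrete, the following is a minimal sketch of a `prepare()` implementation with a heavily reduced signature. The class name, argument list, and bincount-based counting are hypothetical; only the five-element return shape and `ExpertTokensMetadata` come from the diff.

```python
from typing import Optional

import torch

# Assumption: `mk` is the modular-kernel module patched above.
import vllm.model_executor.layers.fused_moe.modular_kernel as mk


class PassThroughPrepare:  # hypothetical, not part of the diff
    """No quantization and no dispatch; activations pass through as-is."""

    def prepare(
        self, a1: torch.Tensor, topk_ids: torch.Tensor, num_local_experts: int
    ) -> tuple[torch.Tensor, Optional[torch.Tensor],
               Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
               Optional[torch.Tensor]]:
        # Count how many tokens were routed to each local expert.
        expert_num_tokens = torch.bincount(
            topk_ids.flatten(), minlength=num_local_experts).to(torch.int32)

        expert_tokens_meta = mk.ExpertTokensMetadata(
            expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None)

        # (dispatched a, dispatched a1_scales, metadata,
        #  dispatched topk ids, dispatched topk weights)
        return a1, None, expert_tokens_meta, None, None
```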
@@ -318,7 +341,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
         a2_scale: Optional[torch.Tensor],
         workspace13: torch.Tensor,
         workspace2: torch.Tensor,
-        expert_num_tokens: Optional[torch.Tensor],
+        expert_tokens_meta: Optional[ExpertTokensMetadata],
     ):
         """
         This function computes the intermediate result of a Mixture of Experts
@@ -351,8 +374,10 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
           must be large enough to hold output of either MoE gemm.
         - workspace2 (torch.Tensor): A scratch tensor used for the activation
           function.
-        - expert_num_tokens: An optional tensor containing the number of tokens
-          assigned to each expert when using batched experts format input.
+        - expert_tokens_meta (Optional[ExpertTokensMetadata]) - An optional
+          ExpertTokensMetadata object containing gpu/cpu tensors
+          as big as the number of local experts with the information about the
+          number of tokens assigned to each local expert.
         """
         raise NotImplementedError
 
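The two consumption patterns used by the experts implementations in this diff boil down to the following sketch (free functions with a reduced argument list, illustrative only): kernels that need the batched-format counts assert the metadata is present, while kernels that can run without it fall back to `None`.

```python
from typing import Optional

import torch

# Assumption: `mk` is the modular-kernel module patched above.
import vllm.model_executor.layers.fused_moe.modular_kernel as mk


def counts_required(
        expert_tokens_meta: Optional[mk.ExpertTokensMetadata]) -> torch.Tensor:
    # Batched-experts kernels (e.g. the DeepGEMM paths above) cannot run
    # without the per-expert token counts, so the metadata is mandatory.
    assert expert_tokens_meta is not None
    return expert_tokens_meta.expert_num_tokens


def counts_optional(
    expert_tokens_meta: Optional[mk.ExpertTokensMetadata]
) -> Optional[torch.Tensor]:
    # Kernels like the CUTLASS path above tolerate a missing metadata object
    # and simply pass None downstream.
    expert_num_tokens = None
    if expert_tokens_meta is not None:
        expert_num_tokens = expert_tokens_meta.expert_num_tokens
    return expert_num_tokens
```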
@@ -458,7 +483,7 @@ class FusedMoEModularKernel(torch.nn.Module):
         if global_num_experts == -1:
             global_num_experts = local_num_experts
 
-        (a1q, a1q_scale, expert_num_tokens, _expert_topk_ids,
+        (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
          _expert_topk_weights) = self.prepare_finalize.prepare(
              a1,
              a1_scale,
@@ -542,7 +567,7 @@ class FusedMoEModularKernel(torch.nn.Module):
                 a2_scale=a2_scale,
                 workspace13=workspace13,
                 workspace2=workspace2,
-                expert_num_tokens=expert_num_tokens,
+                expert_tokens_meta=expert_tokens_meta,
             )
         else:
             # The leading output dimension may not be equal to M, so
@@ -589,7 +614,7 @@ class FusedMoEModularKernel(torch.nn.Module):
                     a2_scale=curr_a2_scale,
                     workspace13=workspace13,
                     workspace2=workspace2,
-                    expert_num_tokens=expert_num_tokens,
+                    expert_tokens_meta=expert_tokens_meta,
                 )
 
         self.prepare_finalize.finalize(output, fused_out, topk_weights,
@@ -94,8 +94,9 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
         expert_map: Optional[torch.Tensor],
         apply_router_weight_on_input: bool,
         quant_config: FusedMoEQuantConfig,
-    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
-               Optional[torch.Tensor], Optional[torch.Tensor]]:
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor],
+               Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
+               Optional[torch.Tensor]]:
         num_tokens = a1.size(0)  # M
         hidden_dim = a1.size(-1)  # K
 
@@ -200,7 +201,10 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
             expert_x_scale = expert_x_scale[:, :, :orig_a_scale_block_shape]
             assert expert_x_scale.ndim == 3
 
-        return expert_x, expert_x_scale, expert_num_tokens, None, None
+        expert_tokens_meta = mk.ExpertTokensMetadata(
+            expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None)
+
+        return expert_x, expert_x_scale, expert_tokens_meta, None, None
 
     def finalize(
         self,
@@ -38,8 +38,9 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
         expert_map: Optional[torch.Tensor],
         apply_router_weight_on_input: bool,
         quant_config: FusedMoEQuantConfig,
-    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
-               Optional[torch.Tensor], Optional[torch.Tensor]]:
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor],
+               Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
+               Optional[torch.Tensor]]:
 
         if apply_router_weight_on_input:
             topk = topk_ids.size(1)
@@ -110,7 +110,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         a2_scale: Optional[torch.Tensor],
         workspace13: torch.Tensor,
         workspace2: torch.Tensor,
-        expert_num_tokens: Optional[torch.Tensor],
+        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
     ):
         use_deep_gemm = (self.allow_deep_gemm
                          and _valid_deep_gemm(hidden_states, w1, w2))
@@ -135,5 +135,5 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
             a2_scale,
             workspace13,
             workspace2,
-            expert_num_tokens,
+            expert_tokens_meta,
         )