[Misc] ModularKernel : Perform WeightAndReduce inside TritonExperts & DeepGemmExperts (#20725)

Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
This commit is contained in:
Varun Sundar Rabindranath 2025-07-15 01:17:16 +05:30 committed by GitHub
parent 8bb43b9c9e
commit c0569dbc82
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 203 additions and 157 deletions

View File

@ -260,6 +260,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
@ -273,6 +274,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
):
assert expert_tokens_meta is not None
expert_num_tokens = expert_tokens_meta.expert_num_tokens

View File

@ -129,30 +129,22 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
return self.batched_triton_experts.workspace_shapes(
a, aq, M, N, K, topk, global_num_experts, local_num_experts)
def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
):
def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, activation: str, global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool):
experts = (self.batched_deep_gemm_experts
if self.allow_deep_gemm else self.batched_triton_experts)
assert experts is not None
experts.apply(output, hidden_states, w1, w2, topk_ids, activation,
global_num_experts, expert_map, w1_scale, w2_scale,
w1_zp, w2_zp, a1q_scale, a2_scale, workspace13,
workspace2, expert_tokens_meta)
experts.apply(output, hidden_states, w1, w2, topk_weights, topk_ids,
activation, global_num_experts, expert_map, w1_scale,
w2_scale, w1_zp, w2_zp, a1q_scale, a2_scale, workspace13,
workspace2, expert_tokens_meta,
apply_router_weight_on_input)

View File

@ -291,26 +291,17 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
return (workspace1, workspace2, output,
self.out_dtype if self.out_dtype is not None else a.dtype)
def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
):
def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, activation: str, global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool):
assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE"

View File

@ -13,7 +13,7 @@ from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate)
TopKWeightAndReduceContiguous, TopKWeightAndReduceNoOP)
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
@ -90,8 +90,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
return True
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Let PrepareAndFinalize::finalize() decide the impl.
return TopKWeightAndReduceDelegate()
return TopKWeightAndReduceNoOP()
def workspace_shapes(
self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int,
@ -104,9 +103,9 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
block_m = self.block_shape[0]
M_sum = (M * topk) + num_experts * (block_m - 1)
M_sum = round_up(M_sum, block_m)
workspace1 = (M_sum, max(N * 2, K))
workspace1 = (M_sum, max(N // 2, K))
workspace2 = (M_sum, max(N, K))
output = (M, topk, K)
output = (M, K)
return (workspace1, workspace2, output, a.dtype)
def apply(
@ -115,6 +114,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
@ -128,11 +128,14 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
):
assert self.block_shape is not None
a1q = hidden_states
_, N, K = w1.size()
M, _ = output.size()
num_topk = topk_ids.size(1)
if global_num_experts == -1:
global_num_experts = w1.size(0)
@ -159,11 +162,12 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
# Note: M_sum is different than the pre-permuted shape of a1q.
M_sum = a1q.size(0)
mm1_out = _resize_cache(workspace13, (M_sum, N))
act_out = _resize_cache(workspace2, (M_sum, N // 2))
quant_out = _resize_cache(workspace13.view(dtype=torch.float8_e4m3fn),
mm1_out = _resize_cache(workspace2, (M_sum, N))
act_out = _resize_cache(workspace13, (M_sum, N // 2))
quant_out = _resize_cache(workspace2.view(dtype=torch.float8_e4m3fn),
(M_sum, N // 2))
mm2_out = _resize_cache(workspace2, (M_sum, K))
mm2_out = _resize_cache(workspace13, (M_sum, K))
perm_out = _resize_cache(workspace2, (M * num_topk, K))
m_grouped_fp8_gemm_nt_contiguous((a1q, a1q_scale), (w1, w1_scale),
mm1_out, expert_ids)
@ -179,7 +183,14 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
m_grouped_fp8_gemm_nt_contiguous((a2q, a2q_scale), (w2, w2_scale),
mm2_out, expert_ids)
torch.index_select(mm2_out, 0, inv_perm, out=output.view((-1, K)))
torch.index_select(mm2_out, 0, inv_perm, out=perm_out)
TopKWeightAndReduceContiguous().apply(
output=output,
fused_expert_output=perm_out,
topk_weights=topk_weights,
topk_ids=topk_ids,
apply_router_weight_on_input=apply_router_weight_on_input)
def deep_gemm_moe_fp8(

View File

@ -696,15 +696,16 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
return t.to(f32) * group_broadcast(scale, t.shape)
def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w2: torch.Tensor, topk_ids: torch.Tensor,
activation: str, global_num_experts: int,
w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, activation: str, global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata]):
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool):
assert hidden_states.dim() == 3
assert expert_tokens_meta is not None
expert_num_tokens = expert_tokens_meta.expert_num_tokens
@ -899,15 +900,16 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
return (workspace13, workspace2, output, a.dtype)
def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w2: torch.Tensor, topk_ids: torch.Tensor,
activation: str, global_num_experts: int,
w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, activation: str, global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata]):
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool):
# Check constraints.
if self.use_int4_w4a16:
assert hidden_states.size(-1) // 2 == w1.size(2), (

View File

@ -26,7 +26,7 @@ from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate)
TopKWeightAndReduceNoOP)
from vllm.model_executor.layers.fused_moe.utils import (
_resize_cache, moe_kernel_quantize_input)
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
@ -1606,8 +1606,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
return True
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Let PrepareAndFinalize::finalize() decide the impl.
return TopKWeightAndReduceDelegate()
return TopKWeightAndReduceNoOP()
def workspace_shapes(
self,
@ -1620,9 +1619,9 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int,
local_num_experts: int,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
workspace1 = (M, topk, max(N * 2, K))
workspace2 = (M, topk, N)
output = (M, topk, K)
workspace1 = (M, topk, max(N // 2, K))
workspace2 = (M, topk, max(N, K))
output = (M, K)
return (workspace1, workspace2, output, a.dtype)
def apply(
@ -1631,6 +1630,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
@ -1644,6 +1644,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
):
# Check constraints.
if self.use_int4_w4a16:
@ -1696,37 +1697,39 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
raise ValueError(
f"Unsupported compute_type: {hidden_states.dtype}")
# We can reuse the memory between these because by the time we need
# cache3, we're done with cache1
intermediate_cache1 = _resize_cache(workspace13,
# Note that the output tensor might be in workspace1
intermediate_cache1 = _resize_cache(workspace2,
(num_tokens, top_k_num, N))
intermediate_cache2 = _resize_cache(workspace2,
intermediate_cache2 = _resize_cache(workspace13,
(num_tokens * top_k_num, N // 2))
intermediate_cache3 = _resize_cache(workspace2,
(num_tokens, top_k_num, K))
sorted_token_ids, expert_ids, num_tokens_post_padded = (
moe_align_block_size(topk_ids, config['BLOCK_SIZE_M'],
global_num_experts, expert_map))
invoke_fused_moe_kernel(hidden_states,
w1,
intermediate_cache1,
a1q_scale,
w1_scale,
w1_zp,
None,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
False,
top_k_num,
config,
compute_type=compute_type,
use_fp8_w8a8=self.use_fp8_w8a8,
use_int8_w8a8=self.use_int8_w8a8,
use_int8_w8a16=self.use_int8_w8a16,
use_int4_w4a16=self.use_int4_w4a16,
per_channel_quant=self.per_act_token_quant,
block_shape=self.block_shape)
invoke_fused_moe_kernel(
hidden_states,
w1,
intermediate_cache1,
a1q_scale,
w1_scale,
w1_zp,
None, # topk_weights
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
False, # mul_routed_weights
top_k_num,
config,
compute_type=compute_type,
use_fp8_w8a8=self.use_fp8_w8a8,
use_int8_w8a8=self.use_int8_w8a8,
use_int8_w8a16=self.use_int8_w8a16,
use_int4_w4a16=self.use_int4_w4a16,
per_channel_quant=self.per_act_token_quant,
block_shape=self.block_shape)
self.activation(activation, intermediate_cache2,
intermediate_cache1.view(-1, N))
@ -1739,15 +1742,15 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
invoke_fused_moe_kernel(qintermediate_cache2,
w2,
output,
intermediate_cache3,
a2q_scale,
w2_scale,
w2_zp,
None,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
False,
not apply_router_weight_on_input,
1,
config,
compute_type=compute_type,
@ -1758,6 +1761,8 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
per_channel_quant=self.per_act_token_quant,
block_shape=self.block_shape)
ops.moe_sum(intermediate_cache3, output)
def modular_triton_fused_moe(
use_fp8_w8a8: bool,

View File

@ -360,6 +360,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
@ -373,6 +374,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[ExpertTokensMetadata],
apply_router_weight_on_input: bool,
):
"""
This function computes the intermediate result of a Mixture of Experts
@ -384,6 +386,8 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- topk_weights: A map of row to expert weights. Some implementations
choose to do weight application.
- topk_ids (torch.Tensor): A map of row to expert id.
- activation (str): The activation function to apply after the first
MoE layer.
@ -409,6 +413,9 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
ExpertTokensMetadata object containing gpu/cpu tensors
as big as the number of local experts with the information about the
number of tokens assigned to each local expert.
- apply_router_weight_on_input: True if router weights are already
applied on the input. This is relevant if the implementation
chooses to do weight application.
"""
raise NotImplementedError
@ -452,17 +459,21 @@ class FusedMoEModularKernel(torch.nn.Module):
f"{fused_experts.__class__.__name__}."
f"{fused_experts.activation_formats[0]}")
def _do_fused_experts(
self, fused_out: Optional[torch.Tensor], a1: torch.Tensor,
a1q: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
topk_ids: torch.Tensor, activation: str, global_num_experts: int,
local_num_experts: int, expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor], w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
expert_tokens_meta: Optional[ExpertTokensMetadata]
) -> torch.Tensor:
def _do_fused_experts(self, fused_out: Optional[torch.Tensor],
a1: torch.Tensor, a1q: torch.Tensor,
w1: torch.Tensor, w2: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
activation: str, global_num_experts: int,
local_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
expert_tokens_meta: Optional[ExpertTokensMetadata],
apply_router_weight_on_input: bool) -> torch.Tensor:
_, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)
@ -485,36 +496,49 @@ class FusedMoEModularKernel(torch.nn.Module):
# reuse workspace13 for the output
fused_out = _resize_cache(workspace13, fused_out_shape)
self.fused_experts.apply(fused_out,
a1q,
w1,
w2,
topk_ids=topk_ids,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=w1_zp,
w2_zp=w2_zp,
a1q_scale=a1q_scale,
a2_scale=a2_scale,
workspace13=workspace13,
workspace2=workspace2,
expert_tokens_meta=expert_tokens_meta)
self.fused_experts.apply(
fused_out,
a1q,
w1,
w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=w1_zp,
w2_zp=w2_zp,
a1q_scale=a1q_scale,
a2_scale=a2_scale,
workspace13=workspace13,
workspace2=workspace2,
expert_tokens_meta=expert_tokens_meta,
apply_router_weight_on_input=apply_router_weight_on_input)
return fused_out
def _maybe_chunk_fused_experts(
self, a1: torch.Tensor, a1q: torch.Tensor, w1: torch.Tensor,
w2: torch.Tensor, topk_ids: torch.Tensor, activation: str,
global_num_experts: int, local_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor], w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
expert_tokens_meta: Optional[ExpertTokensMetadata]
self,
a1: torch.Tensor,
a1q: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
local_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
expert_tokens_meta: Optional[ExpertTokensMetadata],
apply_router_weight_on_input: bool,
) -> torch.Tensor:
_, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)
@ -529,6 +553,7 @@ class FusedMoEModularKernel(torch.nn.Module):
a1q=a1q,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=activation,
global_num_experts=global_num_experts,
@ -540,7 +565,8 @@ class FusedMoEModularKernel(torch.nn.Module):
w2_zp=w2_zp,
a1q_scale=a1q_scale,
a2_scale=a2_scale,
expert_tokens_meta=expert_tokens_meta)
expert_tokens_meta=expert_tokens_meta,
apply_router_weight_on_input=apply_router_weight_on_input)
# Chunking required case
assert num_chunks > 1
@ -557,11 +583,12 @@ class FusedMoEModularKernel(torch.nn.Module):
def slice_input_tensors(
chunk_idx: int
) -> tuple[torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor], torch.Tensor]:
Optional[torch.Tensor], torch.Tensor, torch.Tensor]:
s = chunk_idx * CHUNK_SIZE
e = min(s + CHUNK_SIZE, M)
return (a1q[s:e], _chunk_scales(a1q_scale, s, e),
_chunk_scales(a2_scale, s, e), topk_ids[s:e])
_chunk_scales(a2_scale, s,
e), topk_ids[s:e], topk_weights[s:e])
def slice_output_tensor(chunk_idx: int) -> torch.Tensor:
assert fused_out.size(0) % M == 0, (
@ -594,7 +621,7 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_num_tokens_cpu=c_expert_num_tokens_cpu)
for chunk_idx in range(num_chunks):
c_a1q, c_a1q_scale, c_a2_scale, c_topk_ids = (
c_a1q, c_a1q_scale, c_a2_scale, c_topk_ids, c_topk_weights = (
slice_input_tensors(chunk_idx))
c_expert_tokens_meta = None
@ -603,23 +630,26 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_tokens_meta, c_topk_ids, local_num_experts,
expert_map)
self._do_fused_experts(fused_out=slice_output_tensor(chunk_idx),
a1=a1,
a1q=c_a1q,
w1=w1,
w2=w2,
topk_ids=c_topk_ids,
activation=activation,
global_num_experts=global_num_experts,
local_num_experts=local_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=w1_zp,
w2_zp=w2_zp,
a1q_scale=c_a1q_scale,
a2_scale=c_a2_scale,
expert_tokens_meta=c_expert_tokens_meta)
self._do_fused_experts(
fused_out=slice_output_tensor(chunk_idx),
a1=a1,
a1q=c_a1q,
w1=w1,
w2=w2,
topk_weights=c_topk_weights,
topk_ids=c_topk_ids,
activation=activation,
global_num_experts=global_num_experts,
local_num_experts=local_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=w1_zp,
w2_zp=w2_zp,
a1q_scale=c_a1q_scale,
a2_scale=c_a2_scale,
expert_tokens_meta=c_expert_tokens_meta,
apply_router_weight_on_input=apply_router_weight_on_input)
return fused_out
@ -719,6 +749,7 @@ class FusedMoEModularKernel(torch.nn.Module):
a1q=a1q,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=activation,
global_num_experts=global_num_experts,
@ -730,7 +761,8 @@ class FusedMoEModularKernel(torch.nn.Module):
w2_zp=w2_zp,
a1q_scale=a1q_scale,
a2_scale=a2_scale,
expert_tokens_meta=expert_tokens_meta)
expert_tokens_meta=expert_tokens_meta,
apply_router_weight_on_input=apply_router_weight_on_input)
self.prepare_finalize.finalize(
output, fused_out, topk_weights, topk_ids,

View File

@ -48,11 +48,18 @@ class TopKWeightAndReduceNoOP(mk.TopKWeightAndReduce):
fused_expert_output: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool) -> torch.Tensor:
# Relax this if an explicit copy is necessary. Note that,
# if a copy is employed we have to make sure that the
# tensors don't overlap
assert output is None
return fused_expert_output
# Weight application and reduction operations are already done.
if output is None:
return fused_expert_output
# MoEPrepareAndFinalizeNoEP needs the output to be in the `output`
# tensor.
assert output.size() == fused_expert_output.size(), (
"output shape is expected to match the fused_expert_output shape. "
f"But got output={output.size()}, "
f"used_expert_output={fused_expert_output.size()}")
output.copy_(fused_expert_output, non_blocking=True)
return output
class TopKWeightAndReduceContiguous(mk.TopKWeightAndReduce):

View File

@ -122,6 +122,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
@ -135,6 +136,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
):
use_deep_gemm = (self.allow_deep_gemm
and (_valid_deep_gemm(hidden_states, w1, w2)
@ -148,6 +150,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
hidden_states,
w1,
w2,
topk_weights,
topk_ids,
activation,
global_num_experts,
@ -161,4 +164,5 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace13,
workspace2,
expert_tokens_meta,
apply_router_weight_on_input,
)