mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-15 03:37:06 +08:00
[Misc] ModularKernel : Perform WeightAndReduce inside TritonExperts & DeepGemmExperts (#20725)
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com> Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
This commit is contained in:
parent
8bb43b9c9e
commit
c0569dbc82
@ -260,6 +260,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
@ -273,6 +274,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
apply_router_weight_on_input: bool,
|
||||
):
|
||||
assert expert_tokens_meta is not None
|
||||
expert_num_tokens = expert_tokens_meta.expert_num_tokens
|
||||
|
||||
@ -129,30 +129,22 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
return self.batched_triton_experts.workspace_shapes(
|
||||
a, aq, M, N, K, topk, global_num_experts, local_num_experts)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
w1_scale: Optional[torch.Tensor],
|
||||
w2_scale: Optional[torch.Tensor],
|
||||
w1_zp: Optional[torch.Tensor],
|
||||
w2_zp: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
):
|
||||
def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor, activation: str, global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
w1_scale: Optional[torch.Tensor],
|
||||
w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
|
||||
w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
apply_router_weight_on_input: bool):
|
||||
experts = (self.batched_deep_gemm_experts
|
||||
if self.allow_deep_gemm else self.batched_triton_experts)
|
||||
assert experts is not None
|
||||
experts.apply(output, hidden_states, w1, w2, topk_ids, activation,
|
||||
global_num_experts, expert_map, w1_scale, w2_scale,
|
||||
w1_zp, w2_zp, a1q_scale, a2_scale, workspace13,
|
||||
workspace2, expert_tokens_meta)
|
||||
experts.apply(output, hidden_states, w1, w2, topk_weights, topk_ids,
|
||||
activation, global_num_experts, expert_map, w1_scale,
|
||||
w2_scale, w1_zp, w2_zp, a1q_scale, a2_scale, workspace13,
|
||||
workspace2, expert_tokens_meta,
|
||||
apply_router_weight_on_input)
|
||||
|
||||
@ -291,26 +291,17 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
return (workspace1, workspace2, output,
|
||||
self.out_dtype if self.out_dtype is not None else a.dtype)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
w1_scale: Optional[torch.Tensor],
|
||||
w2_scale: Optional[torch.Tensor],
|
||||
w1_zp: Optional[torch.Tensor],
|
||||
w2_zp: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
):
|
||||
def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor, activation: str, global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
w1_scale: Optional[torch.Tensor],
|
||||
w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
|
||||
w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
apply_router_weight_on_input: bool):
|
||||
assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
|
||||
assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE"
|
||||
|
||||
|
||||
@ -13,7 +13,7 @@ from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP)
|
||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
TopKWeightAndReduceDelegate)
|
||||
TopKWeightAndReduceContiguous, TopKWeightAndReduceNoOP)
|
||||
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
per_token_group_quant_fp8)
|
||||
@ -90,8 +90,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
return True
|
||||
|
||||
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
|
||||
# Let PrepareAndFinalize::finalize() decide the impl.
|
||||
return TopKWeightAndReduceDelegate()
|
||||
return TopKWeightAndReduceNoOP()
|
||||
|
||||
def workspace_shapes(
|
||||
self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int,
|
||||
@ -104,9 +103,9 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
block_m = self.block_shape[0]
|
||||
M_sum = (M * topk) + num_experts * (block_m - 1)
|
||||
M_sum = round_up(M_sum, block_m)
|
||||
workspace1 = (M_sum, max(N * 2, K))
|
||||
workspace1 = (M_sum, max(N // 2, K))
|
||||
workspace2 = (M_sum, max(N, K))
|
||||
output = (M, topk, K)
|
||||
output = (M, K)
|
||||
return (workspace1, workspace2, output, a.dtype)
|
||||
|
||||
def apply(
|
||||
@ -115,6 +114,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
@ -128,11 +128,14 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
apply_router_weight_on_input: bool,
|
||||
):
|
||||
assert self.block_shape is not None
|
||||
|
||||
a1q = hidden_states
|
||||
_, N, K = w1.size()
|
||||
M, _ = output.size()
|
||||
num_topk = topk_ids.size(1)
|
||||
|
||||
if global_num_experts == -1:
|
||||
global_num_experts = w1.size(0)
|
||||
@ -159,11 +162,12 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
# Note: M_sum is different than the pre-permuted shape of a1q.
|
||||
M_sum = a1q.size(0)
|
||||
|
||||
mm1_out = _resize_cache(workspace13, (M_sum, N))
|
||||
act_out = _resize_cache(workspace2, (M_sum, N // 2))
|
||||
quant_out = _resize_cache(workspace13.view(dtype=torch.float8_e4m3fn),
|
||||
mm1_out = _resize_cache(workspace2, (M_sum, N))
|
||||
act_out = _resize_cache(workspace13, (M_sum, N // 2))
|
||||
quant_out = _resize_cache(workspace2.view(dtype=torch.float8_e4m3fn),
|
||||
(M_sum, N // 2))
|
||||
mm2_out = _resize_cache(workspace2, (M_sum, K))
|
||||
mm2_out = _resize_cache(workspace13, (M_sum, K))
|
||||
perm_out = _resize_cache(workspace2, (M * num_topk, K))
|
||||
|
||||
m_grouped_fp8_gemm_nt_contiguous((a1q, a1q_scale), (w1, w1_scale),
|
||||
mm1_out, expert_ids)
|
||||
@ -179,7 +183,14 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
m_grouped_fp8_gemm_nt_contiguous((a2q, a2q_scale), (w2, w2_scale),
|
||||
mm2_out, expert_ids)
|
||||
|
||||
torch.index_select(mm2_out, 0, inv_perm, out=output.view((-1, K)))
|
||||
torch.index_select(mm2_out, 0, inv_perm, out=perm_out)
|
||||
|
||||
TopKWeightAndReduceContiguous().apply(
|
||||
output=output,
|
||||
fused_expert_output=perm_out,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input)
|
||||
|
||||
|
||||
def deep_gemm_moe_fp8(
|
||||
|
||||
@ -696,15 +696,16 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
return t.to(f32) * group_broadcast(scale, t.shape)
|
||||
|
||||
def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor, w2: torch.Tensor, topk_ids: torch.Tensor,
|
||||
activation: str, global_num_experts: int,
|
||||
w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor, activation: str, global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
w1_scale: Optional[torch.Tensor],
|
||||
w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
|
||||
w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata]):
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
apply_router_weight_on_input: bool):
|
||||
assert hidden_states.dim() == 3
|
||||
assert expert_tokens_meta is not None
|
||||
expert_num_tokens = expert_tokens_meta.expert_num_tokens
|
||||
@ -899,15 +900,16 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
return (workspace13, workspace2, output, a.dtype)
|
||||
|
||||
def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor, w2: torch.Tensor, topk_ids: torch.Tensor,
|
||||
activation: str, global_num_experts: int,
|
||||
w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor, activation: str, global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
w1_scale: Optional[torch.Tensor],
|
||||
w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
|
||||
w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata]):
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
apply_router_weight_on_input: bool):
|
||||
# Check constraints.
|
||||
if self.use_int4_w4a16:
|
||||
assert hidden_states.size(-1) // 2 == w1.size(2), (
|
||||
|
||||
@ -26,7 +26,7 @@ from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP)
|
||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
TopKWeightAndReduceDelegate)
|
||||
TopKWeightAndReduceNoOP)
|
||||
from vllm.model_executor.layers.fused_moe.utils import (
|
||||
_resize_cache, moe_kernel_quantize_input)
|
||||
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
|
||||
@ -1606,8 +1606,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
return True
|
||||
|
||||
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
|
||||
# Let PrepareAndFinalize::finalize() decide the impl.
|
||||
return TopKWeightAndReduceDelegate()
|
||||
return TopKWeightAndReduceNoOP()
|
||||
|
||||
def workspace_shapes(
|
||||
self,
|
||||
@ -1620,9 +1619,9 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
|
||||
workspace1 = (M, topk, max(N * 2, K))
|
||||
workspace2 = (M, topk, N)
|
||||
output = (M, topk, K)
|
||||
workspace1 = (M, topk, max(N // 2, K))
|
||||
workspace2 = (M, topk, max(N, K))
|
||||
output = (M, K)
|
||||
return (workspace1, workspace2, output, a.dtype)
|
||||
|
||||
def apply(
|
||||
@ -1631,6 +1630,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
@ -1644,6 +1644,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
apply_router_weight_on_input: bool,
|
||||
):
|
||||
# Check constraints.
|
||||
if self.use_int4_w4a16:
|
||||
@ -1696,37 +1697,39 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
raise ValueError(
|
||||
f"Unsupported compute_type: {hidden_states.dtype}")
|
||||
|
||||
# We can reuse the memory between these because by the time we need
|
||||
# cache3, we're done with cache1
|
||||
intermediate_cache1 = _resize_cache(workspace13,
|
||||
# Note that the output tensor might be in workspace1
|
||||
intermediate_cache1 = _resize_cache(workspace2,
|
||||
(num_tokens, top_k_num, N))
|
||||
intermediate_cache2 = _resize_cache(workspace2,
|
||||
intermediate_cache2 = _resize_cache(workspace13,
|
||||
(num_tokens * top_k_num, N // 2))
|
||||
intermediate_cache3 = _resize_cache(workspace2,
|
||||
(num_tokens, top_k_num, K))
|
||||
|
||||
sorted_token_ids, expert_ids, num_tokens_post_padded = (
|
||||
moe_align_block_size(topk_ids, config['BLOCK_SIZE_M'],
|
||||
global_num_experts, expert_map))
|
||||
|
||||
invoke_fused_moe_kernel(hidden_states,
|
||||
w1,
|
||||
intermediate_cache1,
|
||||
a1q_scale,
|
||||
w1_scale,
|
||||
w1_zp,
|
||||
None,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
False,
|
||||
top_k_num,
|
||||
config,
|
||||
compute_type=compute_type,
|
||||
use_fp8_w8a8=self.use_fp8_w8a8,
|
||||
use_int8_w8a8=self.use_int8_w8a8,
|
||||
use_int8_w8a16=self.use_int8_w8a16,
|
||||
use_int4_w4a16=self.use_int4_w4a16,
|
||||
per_channel_quant=self.per_act_token_quant,
|
||||
block_shape=self.block_shape)
|
||||
invoke_fused_moe_kernel(
|
||||
hidden_states,
|
||||
w1,
|
||||
intermediate_cache1,
|
||||
a1q_scale,
|
||||
w1_scale,
|
||||
w1_zp,
|
||||
None, # topk_weights
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
False, # mul_routed_weights
|
||||
top_k_num,
|
||||
config,
|
||||
compute_type=compute_type,
|
||||
use_fp8_w8a8=self.use_fp8_w8a8,
|
||||
use_int8_w8a8=self.use_int8_w8a8,
|
||||
use_int8_w8a16=self.use_int8_w8a16,
|
||||
use_int4_w4a16=self.use_int4_w4a16,
|
||||
per_channel_quant=self.per_act_token_quant,
|
||||
block_shape=self.block_shape)
|
||||
|
||||
self.activation(activation, intermediate_cache2,
|
||||
intermediate_cache1.view(-1, N))
|
||||
@ -1739,15 +1742,15 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
|
||||
invoke_fused_moe_kernel(qintermediate_cache2,
|
||||
w2,
|
||||
output,
|
||||
intermediate_cache3,
|
||||
a2q_scale,
|
||||
w2_scale,
|
||||
w2_zp,
|
||||
None,
|
||||
topk_weights,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
False,
|
||||
not apply_router_weight_on_input,
|
||||
1,
|
||||
config,
|
||||
compute_type=compute_type,
|
||||
@ -1758,6 +1761,8 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
per_channel_quant=self.per_act_token_quant,
|
||||
block_shape=self.block_shape)
|
||||
|
||||
ops.moe_sum(intermediate_cache3, output)
|
||||
|
||||
|
||||
def modular_triton_fused_moe(
|
||||
use_fp8_w8a8: bool,
|
||||
|
||||
@ -360,6 +360,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
@ -373,6 +374,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: Optional[ExpertTokensMetadata],
|
||||
apply_router_weight_on_input: bool,
|
||||
):
|
||||
"""
|
||||
This function computes the intermediate result of a Mixture of Experts
|
||||
@ -384,6 +386,8 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
layer.
|
||||
- w1 (torch.Tensor): The first set of expert weights.
|
||||
- w2 (torch.Tensor): The second set of expert weights.
|
||||
- topk_weights: A map of row to expert weights. Some implementations
|
||||
choose to do weight application.
|
||||
- topk_ids (torch.Tensor): A map of row to expert id.
|
||||
- activation (str): The activation function to apply after the first
|
||||
MoE layer.
|
||||
@ -409,6 +413,9 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
ExpertTokensMetadata object containing gpu/cpu tensors
|
||||
as big as the number of local experts with the information about the
|
||||
number of tokens assigned to each local expert.
|
||||
- apply_router_weight_on_input: True if router weights are already
|
||||
applied on the input. This is relevant if the implementation
|
||||
chooses to do weight application.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@ -452,17 +459,21 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
f"{fused_experts.__class__.__name__}."
|
||||
f"{fused_experts.activation_formats[0]}")
|
||||
|
||||
def _do_fused_experts(
|
||||
self, fused_out: Optional[torch.Tensor], a1: torch.Tensor,
|
||||
a1q: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
|
||||
topk_ids: torch.Tensor, activation: str, global_num_experts: int,
|
||||
local_num_experts: int, expert_map: Optional[torch.Tensor],
|
||||
w1_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor],
|
||||
w1_zp: Optional[torch.Tensor], w2_zp: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
expert_tokens_meta: Optional[ExpertTokensMetadata]
|
||||
) -> torch.Tensor:
|
||||
def _do_fused_experts(self, fused_out: Optional[torch.Tensor],
|
||||
a1: torch.Tensor, a1q: torch.Tensor,
|
||||
w1: torch.Tensor, w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
|
||||
activation: str, global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
w1_scale: Optional[torch.Tensor],
|
||||
w2_scale: Optional[torch.Tensor],
|
||||
w1_zp: Optional[torch.Tensor],
|
||||
w2_zp: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
expert_tokens_meta: Optional[ExpertTokensMetadata],
|
||||
apply_router_weight_on_input: bool) -> torch.Tensor:
|
||||
|
||||
_, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)
|
||||
|
||||
@ -485,36 +496,49 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
# reuse workspace13 for the output
|
||||
fused_out = _resize_cache(workspace13, fused_out_shape)
|
||||
|
||||
self.fused_experts.apply(fused_out,
|
||||
a1q,
|
||||
w1,
|
||||
w2,
|
||||
topk_ids=topk_ids,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
w1_zp=w1_zp,
|
||||
w2_zp=w2_zp,
|
||||
a1q_scale=a1q_scale,
|
||||
a2_scale=a2_scale,
|
||||
workspace13=workspace13,
|
||||
workspace2=workspace2,
|
||||
expert_tokens_meta=expert_tokens_meta)
|
||||
self.fused_experts.apply(
|
||||
fused_out,
|
||||
a1q,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
w1_zp=w1_zp,
|
||||
w2_zp=w2_zp,
|
||||
a1q_scale=a1q_scale,
|
||||
a2_scale=a2_scale,
|
||||
workspace13=workspace13,
|
||||
workspace2=workspace2,
|
||||
expert_tokens_meta=expert_tokens_meta,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input)
|
||||
|
||||
return fused_out
|
||||
|
||||
def _maybe_chunk_fused_experts(
|
||||
self, a1: torch.Tensor, a1q: torch.Tensor, w1: torch.Tensor,
|
||||
w2: torch.Tensor, topk_ids: torch.Tensor, activation: str,
|
||||
global_num_experts: int, local_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
w1_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor],
|
||||
w1_zp: Optional[torch.Tensor], w2_zp: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
expert_tokens_meta: Optional[ExpertTokensMetadata]
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
a1q: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
w1_scale: Optional[torch.Tensor],
|
||||
w2_scale: Optional[torch.Tensor],
|
||||
w1_zp: Optional[torch.Tensor],
|
||||
w2_zp: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
expert_tokens_meta: Optional[ExpertTokensMetadata],
|
||||
apply_router_weight_on_input: bool,
|
||||
) -> torch.Tensor:
|
||||
|
||||
_, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)
|
||||
@ -529,6 +553,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
a1q=a1q,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
@ -540,7 +565,8 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
w2_zp=w2_zp,
|
||||
a1q_scale=a1q_scale,
|
||||
a2_scale=a2_scale,
|
||||
expert_tokens_meta=expert_tokens_meta)
|
||||
expert_tokens_meta=expert_tokens_meta,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input)
|
||||
|
||||
# Chunking required case
|
||||
assert num_chunks > 1
|
||||
@ -557,11 +583,12 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
def slice_input_tensors(
|
||||
chunk_idx: int
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[torch.Tensor], torch.Tensor]:
|
||||
Optional[torch.Tensor], torch.Tensor, torch.Tensor]:
|
||||
s = chunk_idx * CHUNK_SIZE
|
||||
e = min(s + CHUNK_SIZE, M)
|
||||
return (a1q[s:e], _chunk_scales(a1q_scale, s, e),
|
||||
_chunk_scales(a2_scale, s, e), topk_ids[s:e])
|
||||
_chunk_scales(a2_scale, s,
|
||||
e), topk_ids[s:e], topk_weights[s:e])
|
||||
|
||||
def slice_output_tensor(chunk_idx: int) -> torch.Tensor:
|
||||
assert fused_out.size(0) % M == 0, (
|
||||
@ -594,7 +621,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
expert_num_tokens_cpu=c_expert_num_tokens_cpu)
|
||||
|
||||
for chunk_idx in range(num_chunks):
|
||||
c_a1q, c_a1q_scale, c_a2_scale, c_topk_ids = (
|
||||
c_a1q, c_a1q_scale, c_a2_scale, c_topk_ids, c_topk_weights = (
|
||||
slice_input_tensors(chunk_idx))
|
||||
|
||||
c_expert_tokens_meta = None
|
||||
@ -603,23 +630,26 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
expert_tokens_meta, c_topk_ids, local_num_experts,
|
||||
expert_map)
|
||||
|
||||
self._do_fused_experts(fused_out=slice_output_tensor(chunk_idx),
|
||||
a1=a1,
|
||||
a1q=c_a1q,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_ids=c_topk_ids,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
local_num_experts=local_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
w1_zp=w1_zp,
|
||||
w2_zp=w2_zp,
|
||||
a1q_scale=c_a1q_scale,
|
||||
a2_scale=c_a2_scale,
|
||||
expert_tokens_meta=c_expert_tokens_meta)
|
||||
self._do_fused_experts(
|
||||
fused_out=slice_output_tensor(chunk_idx),
|
||||
a1=a1,
|
||||
a1q=c_a1q,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_weights=c_topk_weights,
|
||||
topk_ids=c_topk_ids,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
local_num_experts=local_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
w1_zp=w1_zp,
|
||||
w2_zp=w2_zp,
|
||||
a1q_scale=c_a1q_scale,
|
||||
a2_scale=c_a2_scale,
|
||||
expert_tokens_meta=c_expert_tokens_meta,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input)
|
||||
|
||||
return fused_out
|
||||
|
||||
@ -719,6 +749,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
a1q=a1q,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
@ -730,7 +761,8 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
w2_zp=w2_zp,
|
||||
a1q_scale=a1q_scale,
|
||||
a2_scale=a2_scale,
|
||||
expert_tokens_meta=expert_tokens_meta)
|
||||
expert_tokens_meta=expert_tokens_meta,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input)
|
||||
|
||||
self.prepare_finalize.finalize(
|
||||
output, fused_out, topk_weights, topk_ids,
|
||||
|
||||
@ -48,11 +48,18 @@ class TopKWeightAndReduceNoOP(mk.TopKWeightAndReduce):
|
||||
fused_expert_output: torch.Tensor, topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
apply_router_weight_on_input: bool) -> torch.Tensor:
|
||||
# Relax this if an explicit copy is necessary. Note that,
|
||||
# if a copy is employed we have to make sure that the
|
||||
# tensors don't overlap
|
||||
assert output is None
|
||||
return fused_expert_output
|
||||
# Weight application and reduction operations are already done.
|
||||
if output is None:
|
||||
return fused_expert_output
|
||||
|
||||
# MoEPrepareAndFinalizeNoEP needs the output to be in the `output`
|
||||
# tensor.
|
||||
assert output.size() == fused_expert_output.size(), (
|
||||
"output shape is expected to match the fused_expert_output shape. "
|
||||
f"But got output={output.size()}, "
|
||||
f"used_expert_output={fused_expert_output.size()}")
|
||||
output.copy_(fused_expert_output, non_blocking=True)
|
||||
return output
|
||||
|
||||
|
||||
class TopKWeightAndReduceContiguous(mk.TopKWeightAndReduce):
|
||||
|
||||
@ -122,6 +122,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
@ -135,6 +136,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
apply_router_weight_on_input: bool,
|
||||
):
|
||||
use_deep_gemm = (self.allow_deep_gemm
|
||||
and (_valid_deep_gemm(hidden_states, w1, w2)
|
||||
@ -148,6 +150,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
activation,
|
||||
global_num_experts,
|
||||
@ -161,4 +164,5 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
workspace13,
|
||||
workspace2,
|
||||
expert_tokens_meta,
|
||||
apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user