Support Llama 4 for fused_marlin_moe (#20457)

Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Michael Goin 2025-07-04 16:55:10 +09:00 committed by GitHub
parent 1caca5a589
commit 0e3fe896e2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 11 additions and 17 deletions

View File

@@ -24,6 +24,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
quant_type_id: int,
apply_router_weight_on_input: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
global_scale1: Optional[torch.Tensor] = None,
@@ -149,7 +150,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
topk_weights,
moe_block_size=block_size_m,
top_k=topk,
mul_topk_weights=False,
mul_topk_weights=apply_router_weight_on_input,
is_ep=expert_map is not None,
b_q_type=quant_type,
size_m=M,
@@ -182,7 +183,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
topk_weights,
moe_block_size=block_size_m,
top_k=1,
mul_topk_weights=True,
mul_topk_weights=not apply_router_weight_on_input,
is_ep=expert_map is not None,
b_q_type=quant_type,
size_m=M * topk,
@@ -208,6 +209,7 @@ def fused_marlin_moe_fake(hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
quant_type_id: int,
apply_router_weight_on_input: bool = False,
global_num_experts: int = -1,
global_scale1: Optional[torch.Tensor] = None,
global_scale2: Optional[torch.Tensor] = None,

View File

@@ -493,11 +493,6 @@ class AWQMoEMethod(FusedMoEMethodBase):
assert activation == "silu", "Only SiLU activation is supported."
if apply_router_weight_on_input:
raise NotImplementedError(
"Apply router weight on input is not supported for"
"fused Marlin MoE method.")
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
@@ -520,6 +515,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
topk_weights,
topk_ids,
quant_type_id=self.quant_type.id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_zeros=layer.w13_qzeros,

View File

@@ -322,6 +322,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
global_scale1=layer.w13_weight_scale_2,
global_scale2=layer.w2_weight_scale_2,
quant_type_id=scalar_types.float4_e2m1f.id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map)
@@ -669,8 +670,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
if self.use_marlin:
assert activation == "silu", (
f"{activation} not supported for Marlin MoE.")
assert not apply_router_weight_on_input, (
"Apply router weight on input not supported for Marlin MoE.")
return torch.ops.vllm.fused_marlin_moe(
x,
layer.w13_weight,
@@ -681,6 +680,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
topk_weights,
topk_ids,
quant_type_id=scalar_types.float8_e4m3fn.id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map)
@@ -1356,8 +1356,6 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
assert activation == "silu", (
f"{activation} not supported for Marlin MoE.")
assert not apply_router_weight_on_input, (
"Apply router weight on input not supported for Marlin MoE.")
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
@@ -1381,6 +1379,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
topk_weights,
topk_ids,
quant_type_id=self.quant_type.id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
g_idx1=layer.w13_weight_g_idx,

View File

@@ -889,8 +889,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
elif self.use_marlin:
assert activation == "silu", (
f"{activation} not supported for Marlin MoE.")
assert not apply_router_weight_on_input, (
"Apply router weight on input not supported for Marlin MoE.")
return torch.ops.vllm.fused_marlin_moe(
x,
layer.w13_weight,
@@ -901,6 +899,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
topk_weights,
topk_ids,
quant_type_id=scalar_types.float8_e4m3fn.id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map)
else:

View File

@@ -645,10 +645,6 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
"EPLB not supported for `GPTQMarlinMoEMethod` yet.")
assert activation == "silu", "Only SiLU activation is supported."
if apply_router_weight_on_input:
raise NotImplementedError(
"Apply router weight on input is not supported for "
"fused Marlin MoE method.")
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
@@ -672,6 +668,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
topk_weights,
topk_ids,
quant_type_id=self.quant_type.id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
g_idx1=layer.w13_g_idx,

View File

@@ -700,6 +700,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
global_scale1=layer.w13_weight_scale_2,
global_scale2=layer.w2_weight_scale_2,
quant_type_id=scalar_types.float4_e2m1f.id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map)