mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-16 15:37:13 +08:00
Support Llama 4 for fused_marlin_moe (#20457)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
parent
1caca5a589
commit
0e3fe896e2
@ -24,6 +24,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
quant_type_id: int,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
global_scale1: Optional[torch.Tensor] = None,
|
||||
@ -149,7 +150,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
|
||||
topk_weights,
|
||||
moe_block_size=block_size_m,
|
||||
top_k=topk,
|
||||
mul_topk_weights=False,
|
||||
mul_topk_weights=apply_router_weight_on_input,
|
||||
is_ep=expert_map is not None,
|
||||
b_q_type=quant_type,
|
||||
size_m=M,
|
||||
@ -182,7 +183,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
|
||||
topk_weights,
|
||||
moe_block_size=block_size_m,
|
||||
top_k=1,
|
||||
mul_topk_weights=True,
|
||||
mul_topk_weights=not apply_router_weight_on_input,
|
||||
is_ep=expert_map is not None,
|
||||
b_q_type=quant_type,
|
||||
size_m=M * topk,
|
||||
@ -208,6 +209,7 @@ def fused_marlin_moe_fake(hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
quant_type_id: int,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
global_scale1: Optional[torch.Tensor] = None,
|
||||
global_scale2: Optional[torch.Tensor] = None,
|
||||
|
||||
@ -493,11 +493,6 @@ class AWQMoEMethod(FusedMoEMethodBase):
|
||||
|
||||
assert activation == "silu", "Only SiLU activation is supported."
|
||||
|
||||
if apply_router_weight_on_input:
|
||||
raise NotImplementedError(
|
||||
"Apply router weight on input is not supported for"
|
||||
"fused Marlin MoE method.")
|
||||
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
@ -520,6 +515,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
quant_type_id=self.quant_type.id,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_zeros=layer.w13_qzeros,
|
||||
|
||||
@ -322,6 +322,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
|
||||
global_scale1=layer.w13_weight_scale_2,
|
||||
global_scale2=layer.w2_weight_scale_2,
|
||||
quant_type_id=scalar_types.float4_e2m1f.id,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map)
|
||||
|
||||
@ -669,8 +670,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
if self.use_marlin:
|
||||
assert activation == "silu", (
|
||||
f"{activation} not supported for Marlin MoE.")
|
||||
assert not apply_router_weight_on_input, (
|
||||
"Apply router weight on input not supported for Marlin MoE.")
|
||||
return torch.ops.vllm.fused_marlin_moe(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
@ -681,6 +680,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
quant_type_id=scalar_types.float8_e4m3fn.id,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map)
|
||||
|
||||
@ -1356,8 +1356,6 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
|
||||
|
||||
assert activation == "silu", (
|
||||
f"{activation} not supported for Marlin MoE.")
|
||||
assert not apply_router_weight_on_input, (
|
||||
"Apply router weight on input not supported for Marlin MoE.")
|
||||
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
@ -1381,6 +1379,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
quant_type_id=self.quant_type.id,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
g_idx1=layer.w13_weight_g_idx,
|
||||
|
||||
@ -889,8 +889,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
elif self.use_marlin:
|
||||
assert activation == "silu", (
|
||||
f"{activation} not supported for Marlin MoE.")
|
||||
assert not apply_router_weight_on_input, (
|
||||
"Apply router weight on input not supported for Marlin MoE.")
|
||||
return torch.ops.vllm.fused_marlin_moe(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
@ -901,6 +899,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
quant_type_id=scalar_types.float8_e4m3fn.id,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map)
|
||||
else:
|
||||
|
||||
@ -645,10 +645,6 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
"EPLB not supported for `GPTQMarlinMoEMethod` yet.")
|
||||
|
||||
assert activation == "silu", "Only SiLU activation is supported."
|
||||
if apply_router_weight_on_input:
|
||||
raise NotImplementedError(
|
||||
"Apply router weight on input is not supported for "
|
||||
"fused Marlin MoE method.")
|
||||
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
@ -672,6 +668,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
quant_type_id=self.quant_type.id,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
g_idx1=layer.w13_g_idx,
|
||||
|
||||
@ -700,6 +700,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
global_scale1=layer.w13_weight_scale_2,
|
||||
global_scale2=layer.w2_weight_scale_2,
|
||||
quant_type_id=scalar_types.float4_e2m1f.id,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user