mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-15 04:53:33 +08:00
Support Llama 4 for fused_marlin_moe (#20457)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
parent
1caca5a589
commit
0e3fe896e2
@ -24,6 +24,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
|
|||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
quant_type_id: int,
|
quant_type_id: int,
|
||||||
|
apply_router_weight_on_input: bool = False,
|
||||||
global_num_experts: int = -1,
|
global_num_experts: int = -1,
|
||||||
expert_map: Optional[torch.Tensor] = None,
|
expert_map: Optional[torch.Tensor] = None,
|
||||||
global_scale1: Optional[torch.Tensor] = None,
|
global_scale1: Optional[torch.Tensor] = None,
|
||||||
@ -149,7 +150,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
|
|||||||
topk_weights,
|
topk_weights,
|
||||||
moe_block_size=block_size_m,
|
moe_block_size=block_size_m,
|
||||||
top_k=topk,
|
top_k=topk,
|
||||||
mul_topk_weights=False,
|
mul_topk_weights=apply_router_weight_on_input,
|
||||||
is_ep=expert_map is not None,
|
is_ep=expert_map is not None,
|
||||||
b_q_type=quant_type,
|
b_q_type=quant_type,
|
||||||
size_m=M,
|
size_m=M,
|
||||||
@ -182,7 +183,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
|
|||||||
topk_weights,
|
topk_weights,
|
||||||
moe_block_size=block_size_m,
|
moe_block_size=block_size_m,
|
||||||
top_k=1,
|
top_k=1,
|
||||||
mul_topk_weights=True,
|
mul_topk_weights=not apply_router_weight_on_input,
|
||||||
is_ep=expert_map is not None,
|
is_ep=expert_map is not None,
|
||||||
b_q_type=quant_type,
|
b_q_type=quant_type,
|
||||||
size_m=M * topk,
|
size_m=M * topk,
|
||||||
@ -208,6 +209,7 @@ def fused_marlin_moe_fake(hidden_states: torch.Tensor,
|
|||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
quant_type_id: int,
|
quant_type_id: int,
|
||||||
|
apply_router_weight_on_input: bool = False,
|
||||||
global_num_experts: int = -1,
|
global_num_experts: int = -1,
|
||||||
global_scale1: Optional[torch.Tensor] = None,
|
global_scale1: Optional[torch.Tensor] = None,
|
||||||
global_scale2: Optional[torch.Tensor] = None,
|
global_scale2: Optional[torch.Tensor] = None,
|
||||||
|
|||||||
@ -493,11 +493,6 @@ class AWQMoEMethod(FusedMoEMethodBase):
|
|||||||
|
|
||||||
assert activation == "silu", "Only SiLU activation is supported."
|
assert activation == "silu", "Only SiLU activation is supported."
|
||||||
|
|
||||||
if apply_router_weight_on_input:
|
|
||||||
raise NotImplementedError(
|
|
||||||
"Apply router weight on input is not supported for"
|
|
||||||
"fused Marlin MoE method.")
|
|
||||||
|
|
||||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||||
hidden_states=x,
|
hidden_states=x,
|
||||||
router_logits=router_logits,
|
router_logits=router_logits,
|
||||||
@ -520,6 +515,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
|
|||||||
topk_weights,
|
topk_weights,
|
||||||
topk_ids,
|
topk_ids,
|
||||||
quant_type_id=self.quant_type.id,
|
quant_type_id=self.quant_type.id,
|
||||||
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
global_num_experts=global_num_experts,
|
global_num_experts=global_num_experts,
|
||||||
expert_map=expert_map,
|
expert_map=expert_map,
|
||||||
w1_zeros=layer.w13_qzeros,
|
w1_zeros=layer.w13_qzeros,
|
||||||
|
|||||||
@ -322,6 +322,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
|
|||||||
global_scale1=layer.w13_weight_scale_2,
|
global_scale1=layer.w13_weight_scale_2,
|
||||||
global_scale2=layer.w2_weight_scale_2,
|
global_scale2=layer.w2_weight_scale_2,
|
||||||
quant_type_id=scalar_types.float4_e2m1f.id,
|
quant_type_id=scalar_types.float4_e2m1f.id,
|
||||||
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
global_num_experts=global_num_experts,
|
global_num_experts=global_num_experts,
|
||||||
expert_map=expert_map)
|
expert_map=expert_map)
|
||||||
|
|
||||||
@ -669,8 +670,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
if self.use_marlin:
|
if self.use_marlin:
|
||||||
assert activation == "silu", (
|
assert activation == "silu", (
|
||||||
f"{activation} not supported for Marlin MoE.")
|
f"{activation} not supported for Marlin MoE.")
|
||||||
assert not apply_router_weight_on_input, (
|
|
||||||
"Apply router weight on input not supported for Marlin MoE.")
|
|
||||||
return torch.ops.vllm.fused_marlin_moe(
|
return torch.ops.vllm.fused_marlin_moe(
|
||||||
x,
|
x,
|
||||||
layer.w13_weight,
|
layer.w13_weight,
|
||||||
@ -681,6 +680,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
topk_weights,
|
topk_weights,
|
||||||
topk_ids,
|
topk_ids,
|
||||||
quant_type_id=scalar_types.float8_e4m3fn.id,
|
quant_type_id=scalar_types.float8_e4m3fn.id,
|
||||||
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
global_num_experts=global_num_experts,
|
global_num_experts=global_num_experts,
|
||||||
expert_map=expert_map)
|
expert_map=expert_map)
|
||||||
|
|
||||||
@ -1356,8 +1356,6 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
|
|||||||
|
|
||||||
assert activation == "silu", (
|
assert activation == "silu", (
|
||||||
f"{activation} not supported for Marlin MoE.")
|
f"{activation} not supported for Marlin MoE.")
|
||||||
assert not apply_router_weight_on_input, (
|
|
||||||
"Apply router weight on input not supported for Marlin MoE.")
|
|
||||||
|
|
||||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||||
hidden_states=x,
|
hidden_states=x,
|
||||||
@ -1381,6 +1379,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
|
|||||||
topk_weights,
|
topk_weights,
|
||||||
topk_ids,
|
topk_ids,
|
||||||
quant_type_id=self.quant_type.id,
|
quant_type_id=self.quant_type.id,
|
||||||
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
global_num_experts=global_num_experts,
|
global_num_experts=global_num_experts,
|
||||||
expert_map=expert_map,
|
expert_map=expert_map,
|
||||||
g_idx1=layer.w13_weight_g_idx,
|
g_idx1=layer.w13_weight_g_idx,
|
||||||
|
|||||||
@ -889,8 +889,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
elif self.use_marlin:
|
elif self.use_marlin:
|
||||||
assert activation == "silu", (
|
assert activation == "silu", (
|
||||||
f"{activation} not supported for Marlin MoE.")
|
f"{activation} not supported for Marlin MoE.")
|
||||||
assert not apply_router_weight_on_input, (
|
|
||||||
"Apply router weight on input not supported for Marlin MoE.")
|
|
||||||
return torch.ops.vllm.fused_marlin_moe(
|
return torch.ops.vllm.fused_marlin_moe(
|
||||||
x,
|
x,
|
||||||
layer.w13_weight,
|
layer.w13_weight,
|
||||||
@ -901,6 +899,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
topk_weights,
|
topk_weights,
|
||||||
topk_ids,
|
topk_ids,
|
||||||
quant_type_id=scalar_types.float8_e4m3fn.id,
|
quant_type_id=scalar_types.float8_e4m3fn.id,
|
||||||
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
global_num_experts=global_num_experts,
|
global_num_experts=global_num_experts,
|
||||||
expert_map=expert_map)
|
expert_map=expert_map)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -645,10 +645,6 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
|||||||
"EPLB not supported for `GPTQMarlinMoEMethod` yet.")
|
"EPLB not supported for `GPTQMarlinMoEMethod` yet.")
|
||||||
|
|
||||||
assert activation == "silu", "Only SiLU activation is supported."
|
assert activation == "silu", "Only SiLU activation is supported."
|
||||||
if apply_router_weight_on_input:
|
|
||||||
raise NotImplementedError(
|
|
||||||
"Apply router weight on input is not supported for "
|
|
||||||
"fused Marlin MoE method.")
|
|
||||||
|
|
||||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||||
hidden_states=x,
|
hidden_states=x,
|
||||||
@ -672,6 +668,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
|||||||
topk_weights,
|
topk_weights,
|
||||||
topk_ids,
|
topk_ids,
|
||||||
quant_type_id=self.quant_type.id,
|
quant_type_id=self.quant_type.id,
|
||||||
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
global_num_experts=global_num_experts,
|
global_num_experts=global_num_experts,
|
||||||
expert_map=expert_map,
|
expert_map=expert_map,
|
||||||
g_idx1=layer.w13_g_idx,
|
g_idx1=layer.w13_g_idx,
|
||||||
|
|||||||
@ -700,6 +700,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
|||||||
global_scale1=layer.w13_weight_scale_2,
|
global_scale1=layer.w13_weight_scale_2,
|
||||||
global_scale2=layer.w2_weight_scale_2,
|
global_scale2=layer.w2_weight_scale_2,
|
||||||
quant_type_id=scalar_types.float4_e2m1f.id,
|
quant_type_id=scalar_types.float4_e2m1f.id,
|
||||||
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
global_num_experts=global_num_experts,
|
global_num_experts=global_num_experts,
|
||||||
expert_map=expert_map)
|
expert_map=expert_map)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user