diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index 40b76994f412c..1988c73ba7e2e 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -24,6 +24,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, quant_type_id: int, + apply_router_weight_on_input: bool = False, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, global_scale1: Optional[torch.Tensor] = None, @@ -149,7 +150,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor, topk_weights, moe_block_size=block_size_m, top_k=topk, - mul_topk_weights=False, + mul_topk_weights=apply_router_weight_on_input, is_ep=expert_map is not None, b_q_type=quant_type, size_m=M, @@ -182,7 +183,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor, topk_weights, moe_block_size=block_size_m, top_k=1, - mul_topk_weights=True, + mul_topk_weights=not apply_router_weight_on_input, is_ep=expert_map is not None, b_q_type=quant_type, size_m=M * topk, @@ -208,6 +209,7 @@ def fused_marlin_moe_fake(hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, quant_type_id: int, + apply_router_weight_on_input: bool = False, global_num_experts: int = -1, global_scale1: Optional[torch.Tensor] = None, global_scale2: Optional[torch.Tensor] = None, diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index aff54bc495b2d..0fdded0b5a7fc 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -493,11 +493,6 @@ class AWQMoEMethod(FusedMoEMethodBase): assert activation == "silu", "Only SiLU activation is supported." - if apply_router_weight_on_input: - raise NotImplementedError( - "Apply router weight on input is not supported for" - "fused Marlin MoE method.") - topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, router_logits=router_logits, @@ -520,6 +515,7 @@ class AWQMoEMethod(FusedMoEMethodBase): topk_weights, topk_ids, quant_type_id=self.quant_type.id, + apply_router_weight_on_input=apply_router_weight_on_input, global_num_experts=global_num_experts, expert_map=expert_map, w1_zeros=layer.w13_qzeros, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index b6dab4320ee3c..48eeda5450b0b 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -322,6 +322,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): global_scale1=layer.w13_weight_scale_2, global_scale2=layer.w2_weight_scale_2, quant_type_id=scalar_types.float4_e2m1f.id, + apply_router_weight_on_input=apply_router_weight_on_input, global_num_experts=global_num_experts, expert_map=expert_map) @@ -669,8 +670,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): if self.use_marlin: assert activation == "silu", ( f"{activation} not supported for Marlin MoE.") - assert not apply_router_weight_on_input, ( - "Apply router weight on input not supported for Marlin MoE.") return torch.ops.vllm.fused_marlin_moe( x, layer.w13_weight, @@ -681,6 +680,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): topk_weights, topk_ids, quant_type_id=scalar_types.float8_e4m3fn.id, + apply_router_weight_on_input=apply_router_weight_on_input, global_num_experts=global_num_experts, expert_map=expert_map) @@ -1356,8 +1356,6 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): assert activation == "silu", ( f"{activation} not supported for Marlin MoE.") - assert not apply_router_weight_on_input, ( - "Apply router weight on input not supported for Marlin MoE.") topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, @@ -1381,6 +1379,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): topk_weights, topk_ids, quant_type_id=self.quant_type.id, + apply_router_weight_on_input=apply_router_weight_on_input, global_num_experts=global_num_experts, expert_map=expert_map, g_idx1=layer.w13_weight_g_idx, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index fef8f4d46d8ed..a7d221780beec 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -889,8 +889,6 @@ class Fp8MoEMethod(FusedMoEMethodBase): elif self.use_marlin: assert activation == "silu", ( f"{activation} not supported for Marlin MoE.") - assert not apply_router_weight_on_input, ( - "Apply router weight on input not supported for Marlin MoE.") return torch.ops.vllm.fused_marlin_moe( x, layer.w13_weight, @@ -901,6 +899,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): topk_weights, topk_ids, quant_type_id=scalar_types.float8_e4m3fn.id, + apply_router_weight_on_input=apply_router_weight_on_input, global_num_experts=global_num_experts, expert_map=expert_map) else: diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 48ab04c9ab37f..9bed5e2e48898 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -645,10 +645,6 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): "EPLB not supported for `GPTQMarlinMoEMethod` yet.") assert activation == "silu", "Only SiLU activation is supported." - if apply_router_weight_on_input: - raise NotImplementedError( - "Apply router weight on input is not supported for " - "fused Marlin MoE method.") topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, @@ -672,6 +668,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): topk_weights, topk_ids, quant_type_id=self.quant_type.id, + apply_router_weight_on_input=apply_router_weight_on_input, global_num_experts=global_num_experts, expert_map=expert_map, g_idx1=layer.w13_g_idx, diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index a10911b84afc4..9db875330230a 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -700,6 +700,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): global_scale1=layer.w13_weight_scale_2, global_scale2=layer.w2_weight_scale_2, quant_type_id=scalar_types.float4_e2m1f.id, + apply_router_weight_on_input=apply_router_weight_on_input, global_num_experts=global_num_experts, expert_map=expert_map)