Add activation plumbing for fused_marlin_moe

Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
Tyler Michael Smith 2025-09-02 09:24:24 -04:00
parent 8bd5844989
commit d256cd23c1
5 changed files with 12 additions and 13 deletions

View File

@ -512,8 +512,6 @@ class AWQMoEMethod(FusedMoEMethodBase):
raise NotImplementedError(
"EPLB not supported for `AWQMoEMethod` yet.")
assert activation == "silu", "Only SiLU activation is supported."
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
@ -542,6 +540,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
quant_type_id=self.quant_type.id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
activation=activation,
expert_map=expert_map,
w1_zeros=layer.w13_qzeros,
w2_zeros=layer.w2_qzeros,

View File

@ -364,7 +364,6 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
if enable_eplb:
raise NotImplementedError("EPLB not supported for "
"`CompressedTensorsW4A4MoeMethod` yet.")
assert activation == "silu", "Only SiLU activation is supported."
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
@ -398,8 +397,11 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
quant_type_id=scalar_types.float4_e2m1f.id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
activation=activation,
expert_map=expert_map)
assert activation == "silu", "Only SiLU activation is supported."
# FlashInfer fused experts path
if self.fused_experts is not None:
assert is_valid_flashinfer_cutlass_fused_moe(
@ -924,8 +926,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
a2_scale=layer.w2_input_scale,
expert_map=expert_map)
if self.use_marlin:
assert activation == "silu", (
f"{activation} not supported for Marlin MoE.")
return torch.ops.vllm.fused_marlin_moe(
x,
layer.w13_weight,
@ -940,6 +940,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
quant_type_id=scalar_types.float8_e4m3fn.id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
activation=activation,
expert_map=expert_map)
assert self.fused_experts_func is not None
@ -1383,9 +1384,6 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
"EPLB not supported for "
"`CompressedTensorsWNA16MarlinMoEMethod` yet.")
assert activation == "silu", (
f"{activation} not supported for Marlin MoE.")
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
@ -1414,6 +1412,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
quant_type_id=self.quant_type.id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
activation=activation,
expert_map=expert_map,
g_idx1=layer.w13_weight_g_idx,
g_idx2=layer.w2_weight_g_idx,

View File

@ -1077,8 +1077,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
block_shape=self.quant_config.weight_block_size,
expert_map=expert_map)
elif self.use_marlin:
assert activation == "silu", (
f"{activation} not supported for Marlin MoE.")
return torch.ops.vllm.fused_marlin_moe(
x,
layer.w13_weight,
@ -1093,6 +1091,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
quant_type_id=scalar_types.float8_e4m3fn.id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
activation=activation,
expert_map=expert_map)
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
assert self.block_quant is None

View File

@ -661,8 +661,6 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
raise NotImplementedError(
"EPLB not supported for `GPTQMarlinMoEMethod` yet.")
assert activation == "silu", "Only SiLU activation is supported."
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
@ -691,6 +689,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
quant_type_id=self.quant_type.id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
activation=activation,
expert_map=expert_map,
g_idx1=layer.w13_g_idx,
g_idx2=layer.w2_g_idx,

View File

@ -1370,13 +1370,13 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `ModelOptNvFp4FusedMoE` yet.")
assert activation == "silu", "Only SiLU activation is supported."
if self.allow_flashinfer and \
self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
import flashinfer
from vllm.model_executor.models.llama4 import Llama4MoE
assert activation == "silu", "Only SiLU activation is supported."
a1_gscale = layer.w13_input_scale_quant
(hidden_states_fp4,
@ -1458,8 +1458,11 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
quant_type_id=scalar_types.float4_e2m1f.id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
activation=activation,
expert_map=expert_map)
assert activation == "silu", "Only SiLU activation is supported."
if self.fused_experts is not None:
assert self.allow_flashinfer and \
self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS