diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index 3e79a1a8c24b2..4e3e15a35ada2 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -170,8 +170,6 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): "w1_scale and w2_scale must not " "be None for FlashInferExperts") - assert not apply_router_weight_on_input - quant_scales = [ a1_gscale, w1_scale.view(torch.int32), diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py index 7fdb465c459d9..36aca8cf74b6d 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py @@ -60,7 +60,12 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: - assert not apply_router_weight_on_input + if apply_router_weight_on_input: + topk = topk_ids.size(1) + # TODO: this only works for topK=1, will need to update for topK>1 + assert topk == 1, \ + "apply_router_weight_on_input is only implemented for topk=1" + a1.mul_(topk_weights.to(a1.dtype)) (a1_gscale, use_dp, local_tokens) = extract_required_args( extra_prepare_args, ['a1_gscale', 'use_dp', 'local_tokens']) diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 147b275eaf525..bed502226716c 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -1299,8 +1299,9 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): output2_scale_scalar=layer.g2_alphas.data, num_experts=global_num_experts, top_k=top_k, - n_group=num_expert_group, - topk_group=topk_group, + n_group=num_expert_group + if num_expert_group is not None else 0, + topk_group=topk_group if topk_group is not None else 0, intermediate_size=layer.intermediate_size_per_partition, local_expert_offset=layer.ep_rank * layer.local_num_experts, local_num_experts=layer.local_num_experts,