From 31b96d1c643c5866dc080b57a71693de1b83cfc6 Mon Sep 17 00:00:00 2001
From: Michael Goin
Date: Thu, 10 Jul 2025 04:53:38 +0900
Subject: [PATCH] Support Llama 4 for cutlass_moe_fp4 (#20453)

Signed-off-by: mgoin
---
 .../layers/fused_moe/cutlass_moe.py         | 37 ++++++---
 .../compressed_tensors_moe.py               | 40 +++++-----
 .../layers/quantization/modelopt.py         | 77 ++++++++-----
 3 files changed, 80 insertions(+), 74 deletions(-)

diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py
index de588d512739..3b39b3b17ba0 100644
--- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py
+++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py
@@ -411,13 +411,23 @@ FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
 FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
 
 
-def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
-                    w1_fp4: torch.Tensor, w1_blockscale: torch.Tensor,
-                    w1_alphas: torch.Tensor, a2_gscale: torch.Tensor,
-                    w2_fp4: torch.Tensor, w2_blockscale: torch.Tensor,
-                    w2_alphas: torch.Tensor, topk_weights: torch.Tensor,
-                    topk_ids: torch.Tensor, m: int, n: int, k: int, e: int,
-                    device: torch.device):
+def cutlass_moe_fp4(a: torch.Tensor,
+                    a1_gscale: torch.Tensor,
+                    w1_fp4: torch.Tensor,
+                    w1_blockscale: torch.Tensor,
+                    w1_alphas: torch.Tensor,
+                    a2_gscale: torch.Tensor,
+                    w2_fp4: torch.Tensor,
+                    w2_blockscale: torch.Tensor,
+                    w2_alphas: torch.Tensor,
+                    topk_weights: torch.Tensor,
+                    topk_ids: torch.Tensor,
+                    m: int,
+                    n: int,
+                    k: int,
+                    e: int,
+                    device: torch.device,
+                    apply_router_weight_on_input: bool = False):
     """
     MoE implementation for FP4 Inputs
 
@@ -480,6 +490,12 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
     a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
     c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
 
+    if apply_router_weight_on_input:
+        # TODO: this only works for topK=1, will need to update for topK>1
+        assert num_topk == 1, \
+            "apply_router_weight_on_input is only implemented for topk=1"
+        a.mul_(topk_weights.to(out_dtype))
+
     # problem shapes should have [m, n, k]
     # Note that problem sizes are based on logical number of elements.
     ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, problem_sizes1,
@@ -517,8 +533,11 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
     del int_fp4, int_blockscale
 
     c2 = ops.shuffle_rows(c2, c_map)
-    out = (c2.view(m, num_topk, k) *
-           topk_weights.view(m, num_topk, 1).half()).sum(dim=1)
+    if not apply_router_weight_on_input:
+        out = (c2.view(m, num_topk, k) *
+               topk_weights.view(m, num_topk, 1).to(out_dtype)).sum(dim=1)
+    else:
+        out = c2.view(m, num_topk, k).sum(dim=1)
     return out.to(dtype=out_dtype)
 
 
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index 7aeb1cc7d84c..c17a390dba58 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -295,6 +295,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
         if enable_eplb:
             raise NotImplementedError("EPLB not supported for "
                                       "`CompressedTensorsW4A4MoeMethod` yet.")
+        assert activation == "silu", "Only SiLU activation is supported."
 
         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
@@ -326,10 +327,6 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
                 global_num_experts=global_num_experts,
                 expert_map=expert_map)
 
-        assert activation == "silu", "Only SiLU activation is supported."
-        assert not apply_router_weight_on_input, (
-            "Router weight on input is not "
-            "supported for CompressedTensorsW4A4MoeMethod.")
         assert expert_map is None, ("Expert Parallelism / expert_map "
                                     "is currently not supported for "
                                     "CompressedTensorsW4A4MoeMethod.")
@@ -339,22 +336,25 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
 
         # Cutlass moe takes in activations in BF16/Half precision
         # and fp4 quantized weights loaded from the checkpoint
-        return cutlass_moe_fp4(a=x,
-                               w1_fp4=layer.w13_weight,
-                               w1_blockscale=layer.w13_blockscale_swizzled,
-                               w1_alphas=layer.g1_alphas,
-                               w2_fp4=layer.w2_weight,
-                               w2_blockscale=layer.w2_blockscale_swizzled,
-                               w2_alphas=layer.g2_alphas,
-                               topk_weights=topk_weights,
-                               topk_ids=topk_ids,
-                               m=x.shape[0],
-                               n=layer.w2_weight.shape[2] * 2,
-                               k=x.shape[1],
-                               e=layer.w13_weight.shape[0],
-                               a1_gscale=layer.w13_input_scale_quant,
-                               a2_gscale=layer.w2_input_scale_quant,
-                               device=x.device).to(x.dtype)
+        return cutlass_moe_fp4(
+            a=x,
+            w1_fp4=layer.w13_weight,
+            w1_blockscale=layer.w13_blockscale_swizzled,
+            w1_alphas=layer.g1_alphas,
+            w2_fp4=layer.w2_weight,
+            w2_blockscale=layer.w2_blockscale_swizzled,
+            w2_alphas=layer.g2_alphas,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            m=x.shape[0],
+            n=layer.w2_weight.shape[2] * 2,
+            k=x.shape[1],
+            e=layer.w13_weight.shape[0],
+            a1_gscale=layer.w13_input_scale_quant,
+            a2_gscale=layer.w2_input_scale_quant,
+            device=x.device,
+            apply_router_weight_on_input=apply_router_weight_on_input).to(
+                x.dtype)
 
 
 class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py
index 9db875330230..2295c0e5fe9f 100644
--- a/vllm/model_executor/layers/quantization/modelopt.py
+++ b/vllm/model_executor/layers/quantization/modelopt.py
@@ -673,21 +673,21 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
         if enable_eplb:
             raise NotImplementedError(
                 "EPLB not supported for `ModelOptNvFp4FusedMoE` yet.")
+        assert activation == "silu", "Only SiLU activation is supported."
+
+        topk_weights, topk_ids = FusedMoE.select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias)
 
         if self.use_marlin:
-            topk_weights, topk_ids = FusedMoE.select_experts(
-                hidden_states=x,
-                router_logits=router_logits,
-                use_grouped_topk=use_grouped_topk,
-                top_k=top_k,
-                renormalize=renormalize,
-                topk_group=topk_group,
-                num_expert_group=num_expert_group,
-                custom_routing_function=custom_routing_function,
-                scoring_func=scoring_func,
-                e_score_correction_bias=e_score_correction_bias,
-            )
-
             return torch.ops.vllm.fused_marlin_moe(
                 x,
                 layer.w13_weight,
@@ -704,44 +704,31 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
                 global_num_experts=global_num_experts,
                 expert_map=expert_map)
 
-        assert activation == "silu", "Only SiLU activation is supported."
-        assert not apply_router_weight_on_input, (
-            "Router weight on input is not "
-            "supported for ModelOptNvFp4FusedMoE.")
         assert expert_map is None, ("Expert Parallelism / expert_map "
                                     "is currently not supported for "
                                     "ModelOptNvFp4FusedMoE.")
 
-        topk_weights, topk_ids = FusedMoE.select_experts(
-            hidden_states=x,
-            router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function,
-            scoring_func=scoring_func,
-            e_score_correction_bias=e_score_correction_bias)
-
         from vllm.model_executor.layers.fused_moe.cutlass_moe import (
             cutlass_moe_fp4)
 
         # Cutlass moe takes in activations in BF16/Half precision
         # and fp4 quantized weights loaded from the checkpoint
-        return cutlass_moe_fp4(a=x,
-                               w1_fp4=layer.w13_weight,
-                               w1_blockscale=layer.w13_blockscale_swizzled,
-                               w1_alphas=layer.g1_alphas,
-                               w2_fp4=layer.w2_weight,
-                               w2_blockscale=layer.w2_blockscale_swizzled,
-                               w2_alphas=layer.g2_alphas,
-                               topk_weights=topk_weights,
-                               topk_ids=topk_ids,
-                               m=x.shape[0],
-                               n=layer.w2_weight.shape[2] * 2,
-                               k=x.shape[1],
-                               e=layer.w13_weight.shape[0],
-                               a1_gscale=layer.w13_input_scale_quant,
-                               a2_gscale=layer.w2_input_scale_quant,
-                               device=x.device).to(x.dtype)
+        return cutlass_moe_fp4(
+            a=x,
+            w1_fp4=layer.w13_weight,
+            w1_blockscale=layer.w13_blockscale_swizzled,
+            w1_alphas=layer.g1_alphas,
+            w2_fp4=layer.w2_weight,
+            w2_blockscale=layer.w2_blockscale_swizzled,
+            w2_alphas=layer.g2_alphas,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            m=x.shape[0],
+            n=layer.w2_weight.shape[2] * 2,
+            k=x.shape[1],
+            e=layer.w13_weight.shape[0],
+            a1_gscale=layer.w13_input_scale_quant,
+            a2_gscale=layer.w2_input_scale_quant,
+            device=x.device,
+            apply_router_weight_on_input=apply_router_weight_on_input).to(
+                x.dtype)
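Note on the apply_router_weight_on_input path added in cutlass_moe_fp4: the router weight is folded into the activations with a single in-place multiply before quantization, which only broadcasts when every token routes to exactly one expert, hence the num_topk == 1 assertion (Llama 4 uses top-1 expert routing). The snippet below is a plain-PyTorch sketch of that shape constraint, not the CUTLASS kernel itself; the tensor sizes are made up for illustration.

    import torch

    # Stand-in shapes: m tokens with hidden size k (illustrative values only).
    m, k = 4, 8
    a = torch.randn(m, k)

    # top_k == 1: topk_weights is [m, 1] and broadcasts against the [m, k]
    # activations, so the weight can be applied once, before the expert GEMMs.
    topk_weights = torch.rand(m, 1)
    a.mul_(topk_weights)

    # top_k > 1: topk_weights is [m, top_k] and no longer broadcasts against a
    # single copy of the activations; each expert would need its own scaled copy.
    topk_weights = torch.rand(m, 2)
    try:
        a.mul_(topk_weights)
    except RuntimeError as err:
        print("in-place scaling fails for top_k > 1:", err)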