From 6f2f53a82dc9703abb389761bf9931e3c9a5a75b Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Mon, 30 Jun 2025 00:05:40 +0200 Subject: [PATCH] [Quantization] Add compressed-tensors NVFP4 MoE Support (#19990) Signed-off-by: Dipika Sikka Signed-off-by: Dipika --- tests/quantization/test_compressed_tensors.py | 6 +- vllm/model_executor/layers/fused_moe/layer.py | 4 +- .../compressed_tensors/compressed_tensors.py | 4 +- .../compressed_tensors_moe.py | 275 +++++++++++++++++- .../schemes/compressed_tensors_w4a4_nvfp4.py | 13 +- .../utils/nvfp4_emulation_utils.py | 15 +- 6 files changed, 295 insertions(+), 22 deletions(-) diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py index 516bf4513816a..3646ad6c481b0 100644 --- a/tests/quantization/test_compressed_tensors.py +++ b/tests/quantization/test_compressed_tensors.py @@ -17,7 +17,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso CompressedTensorsW4A4Fp4, CompressedTensorsW4A16Fp4, CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8, - CompressedTensorsWNA16) + CompressedTensorsWNA16, cutlass_fp4_supported) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( sparse_cutlass_supported) from vllm.platforms import current_platform @@ -668,8 +668,8 @@ def test_compressed_tensors_nvfp4(vllm_runner, args): assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod) if isinstance(qkv_proj.scheme, scheme) or isinstance( - qkv_proj.scheme, CompressedTensorsW4A16Fp4 - ) and not CompressedTensorsW4A4Fp4.cutlass_fp4_supported(): + qkv_proj.scheme, + CompressedTensorsW4A16Fp4) and not cutlass_fp4_supported(): assert True else: raise AssertionError("FP4 Scheme Mismatch") diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 5408ef1f75e89..e6f555d315d8e 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1246,6 +1246,7 @@ class FusedMoE(torch.nn.Module): param.materialize(final_shape, dtype=loaded_weight.dtype) expert_data = param.data if full_load else param.data[expert_id] + # Case input scale: input_scale loading is only supported for fp8 if "input_scale" in weight_name: # this is needed for compressed-tensors only @@ -1273,6 +1274,7 @@ class FusedMoE(torch.nn.Module): tp_rank=self.tp_rank) return True if return_success else None + # TODO @dsikka: ModelOpt should follow the proper MoE loading pattern if "ModelOpt" in quant_method_name: if ('weight_scale_2' in weight_name or 'input_scale' in weight_name): @@ -1289,7 +1291,7 @@ class FusedMoE(torch.nn.Module): tp_rank=self.tp_rank) return True if return_success else None - # Case weight scales, zero_points and offset + # Case weight scales, zero_points and offset, weight/input global scales if ("scale" in weight_name or "zero" in weight_name or "offset" in weight_name): # load the weight scales and zp based on the quantization scheme diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index d21abb2741a29..4f87b2a44f0ac 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -33,6 +33,8 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( find_matched_target, is_activation_quantization_format, should_ignore_layer) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod +from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( # noqa: E501 + cutlass_fp4_supported) from vllm.platforms import current_platform logger = init_logger(__name__) @@ -375,7 +377,7 @@ class CompressedTensorsConfig(QuantizationConfig): if is_activation_quantization_format(self.quant_format): if self._is_fp4a4_nvfp4(weight_quant, input_quant): - if CompressedTensorsW4A4Fp4.cutlass_fp4_supported( + if cutlass_fp4_supported( ) or envs.VLLM_USE_NVFP4_CT_EMULATIONS: return CompressedTensorsW4A4Fp4() else: diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 4a19473005708..fa4ce5668091b 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -21,8 +21,12 @@ from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils.marlin_utils import ( check_moe_marlin_supports_layer, marlin_make_workspace_new, marlin_moe_permute_scales) +from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( + prepare_moe_fp4_layer_for_marlin) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( prepare_moe_fp8_layer_for_marlin) +from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( # noqa: E501 + cutlass_fp4_supported) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize) from vllm.model_executor.utils import set_weight_attrs @@ -46,12 +50,11 @@ class GPTQMarlinState(Enum): __all__ = [ - "CompressedTensorsMoEMethod", - "CompressedTensorsW8A8Fp8MoEMethod", + "CompressedTensorsMoEMethod", "CompressedTensorsW8A8Fp8MoEMethod", "CompressedTensorsW8A8Fp8MoECutlassMethod", "CompressedTensorsW8A8Int8MoEMethod", - "CompressedTensorsWNA16MarlinMoEMethod", - "CompressedTensorsWNA16MoEMethod", + "CompressedTensorsWNA16MarlinMoEMethod", "CompressedTensorsWNA16MoEMethod", + "CompressedTensorsW4A4MoeMethod" ] @@ -84,6 +87,8 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase): else: logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod") return CompressedTensorsWNA16MarlinMoEMethod(quant_config) + elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant): + return CompressedTensorsW4A4MoeMethod() elif quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant): return CompressedTensorsW8A8Fp8MoECutlassMethod(quant_config) elif quant_config._is_fp8_w8a8(weight_quant, input_quant): @@ -95,6 +100,268 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase): f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}") +class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): + + def __init__(self): + self.use_marlin = not cutlass_fp4_supported() + self.group_size = 16 + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size_per_partition: int, + params_dtype: torch.dtype, **extra_weight_attrs): + + layer.num_experts = num_experts + layer.params_dtype = params_dtype + + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + # 2 fp4 items are packed in the input dimension + hidden_size // 2, + requires_grad=False, + dtype=torch.uint8), + requires_grad=False) + layer.register_parameter("w13_weight_packed", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + # 2 fp4 items are packed in the input dimension + intermediate_size_per_partition // 2, + dtype=torch.uint8), + requires_grad=False) + layer.register_parameter("w2_weight_packed", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # Weight Scales + w13_weight_scale = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + # 2 fp4 items are packed in the input dimension + hidden_size // self.group_size, + dtype=torch.float8_e4m3fn), + requires_grad=False) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.GROUP.value}) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + + w2_weight_scale = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + # 2 fp4 items are packed in the input dimension + intermediate_size_per_partition // self.group_size, + dtype=torch.float8_e4m3fn), + requires_grad=False) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.GROUP.value}) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + # Weight Global Scales + w13_weight_scale_2 = torch.nn.Parameter(torch.empty( + num_experts, 2, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_weight_global_scale", w13_weight_scale_2) + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) + set_weight_attrs(w13_weight_scale_2, extra_weight_attrs) + + w2_weight_scale_2 = torch.nn.Parameter(torch.empty( + num_experts, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w2_weight_global_scale", w2_weight_scale_2) + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) + set_weight_attrs(w2_weight_scale_2, extra_weight_attrs) + + # Input Global Scales + w13_input_scale = torch.nn.Parameter(torch.empty(num_experts, + 2, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_input_global_scale", w13_input_scale) + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) + set_weight_attrs(w13_input_scale, extra_weight_attrs) + + w2_input_scale = torch.nn.Parameter(torch.empty(num_experts, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w2_input_global_scale", w2_input_scale) + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) + set_weight_attrs(w2_input_scale, extra_weight_attrs) + + def swizzle_blockscale(self, scale: torch.tensor): + assert (scale.dtype == torch.float8_e4m3fn) + # Pad and blockwise interleave weight_scale + scale_ndim = scale.ndim + if scale.ndim == 2: + scale = scale.unsqueeze(0) + assert scale.ndim == 3 + B, M, K = scale.shape + round_up_multiple = lambda x, m: (x + m - 1) // m * m + M_padded = round_up_multiple(M, 128) + K_padded = round_up_multiple(K, 4) + padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype) + padded_scale[:B, :M, :K] = scale + batches, rows, cols = padded_scale.shape + assert rows % 128 == 0 + assert cols % 4 == 0 + padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32, + cols // 4, 4) + swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5)) + swizzled_scale = swizzled_scale.contiguous().cuda() + return (swizzled_scale.reshape(M, K) + if scale_ndim == 2 else swizzled_scale.reshape(B, M, K)) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + + # From packed to weight + layer.w13_weight = torch.nn.Parameter(layer.w13_weight_packed.data, + requires_grad=False) + + layer.w2_weight = torch.nn.Parameter(layer.w2_weight_packed.data, + requires_grad=False) + + if not torch.allclose(layer.w13_weight_global_scale[:, 0], + layer.w13_weight_global_scale[:, 1]): + logger.warning_once( + "w1_weight_global_scale must match w3_weight_global_scale. " + "Accuracy may be affected.") + + # Take inverse of global scale saved to disk + layer.w13_weight_scale_2 = torch.nn.Parameter( + 1 / layer.w13_weight_global_scale[:, 0], requires_grad=False) + + layer.w2_weight_scale_2 = torch.nn.Parameter( + 1 / layer.w2_weight_global_scale.data, requires_grad=False) + + if self.use_marlin: + prepare_moe_fp4_layer_for_marlin(layer) + return + + # swizzle weight scales + layer.w13_blockscale_swizzled = torch.nn.Parameter( + self.swizzle_blockscale(layer.w13_weight_scale), + requires_grad=False) + + layer.w2_blockscale_swizzled = torch.nn.Parameter( + self.swizzle_blockscale(layer.w2_weight_scale), + requires_grad=False) + + # w13 + w13_input_global_scale = layer.w13_input_global_scale.max( + dim=1).values.to(torch.float32) + + layer.g1_alphas = torch.nn.Parameter( + ((1 / w13_input_global_scale) * layer.w13_weight_scale_2), + requires_grad=False) + + layer.w13_input_scale_quant = torch.nn.Parameter( + (w13_input_global_scale), requires_grad=False) + + # w2 + layer.g2_alphas = torch.nn.Parameter( + ((1 / layer.w2_input_global_scale) * layer.w2_weight_scale_2).to( + torch.float32), + requires_grad=False) + + layer.w2_input_scale_quant = torch.nn.Parameter( + (layer.w2_input_global_scale), requires_grad=False) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if enable_eplb: + raise NotImplementedError("EPLB not supported for " + "`CompressedTensorsW4A4MoeMethod` yet.") + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + ) + + if self.use_marlin: + return torch.ops.vllm.fused_marlin_moe( + x, + layer.w13_weight, + layer.w2_weight, + layer.w13_weight_scale, + layer.w2_weight_scale, + router_logits, + topk_weights, + topk_ids, + global_scale1=layer.w13_weight_scale_2, + global_scale2=layer.w2_weight_scale_2, + quant_type_id=scalar_types.float4_e2m1f.id, + global_num_experts=global_num_experts, + expert_map=expert_map) + + assert activation == "silu", "Only SiLU activation is supported." + assert not apply_router_weight_on_input, ( + "Router weight on input is not " + "supported for CompressedTensorsW4A4MoeMethod.") + assert expert_map is None, ("Expert Parallelism / expert_map " + "is currently not supported for " + "CompressedTensorsW4A4MoeMethod.") + + from vllm.model_executor.layers.fused_moe.cutlass_moe import ( + cutlass_moe_fp4) + + # Cutlass moe takes in activations in BF16/Half precision + # and fp4 quantized weights loaded from the checkpoint + return cutlass_moe_fp4(a=x, + w1_fp4=layer.w13_weight, + w1_blockscale=layer.w13_blockscale_swizzled, + w1_alphas=layer.g1_alphas, + w2_fp4=layer.w2_weight, + w2_blockscale=layer.w2_blockscale_swizzled, + w2_alphas=layer.g2_alphas, + topk_weights=topk_weights, + topk_ids=topk_ids, + m=x.shape[0], + n=layer.w2_weight.shape[2] * 2, + k=x.shape[1], + e=layer.w13_weight.shape[0], + a1_gscale=layer.w13_input_scale_quant, + a2_gscale=layer.w2_input_scale_quant, + device=x.device).to(x.dtype) + + class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): def __init__( diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py index ec1d4a6c0efae..65cbc49d2640a 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py @@ -5,8 +5,7 @@ import torch from torch.nn.parameter import Parameter import vllm.envs as envs -from vllm._custom_ops import (cutlass_scaled_fp4_mm, - cutlass_scaled_mm_supports_fp4, scaled_fp4_quant) +from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant from vllm.logger import init_logger from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme) @@ -15,7 +14,6 @@ from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import from vllm.model_executor.parameter import (GroupQuantScaleParameter, ModelWeightParameter, PerTensorScaleParameter) -from vllm.platforms import current_platform logger = init_logger(__name__) @@ -33,15 +31,6 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme): return 80 return 100 - @classmethod - def cutlass_fp4_supported(cls) -> bool: - if not current_platform.is_cuda(): - return False - capability_tuple = current_platform.get_device_capability() - capability = -1 if capability_tuple is None else capability_tuple.to_int( # noqa: E501 - ) - return cutlass_scaled_mm_supports_fp4(capability) - def create_weights(self, layer: torch.nn.Module, output_partition_sizes: list[int], input_size_per_partition: int, diff --git a/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py b/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py index d5ce6d7ad757a..fb3287d3b89e6 100644 --- a/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py +++ b/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py @@ -2,9 +2,14 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import torch +from vllm._custom_ops import cutlass_scaled_mm_supports_fp4 +from vllm.platforms import current_platform from vllm.scalar_type import scalar_types -__all__ = ["break_fp4_bytes", "dequantize_to_dtype", "ref_nvfp4_quant"] +__all__ = [ + "break_fp4_bytes", "dequantize_to_dtype", "ref_nvfp4_quant", + "cutlass_fp4_supported" +] FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max() @@ -12,6 +17,14 @@ kE2M1ToFloat = torch.tensor([0., 0.5, 1., 1.5, 2., 3., 4., 6.], dtype=torch.float32) +def cutlass_fp4_supported() -> bool: + if not current_platform.is_cuda(): + return False + capability_tuple = current_platform.get_device_capability() + capability = -1 if capability_tuple is None else capability_tuple.to_int() + return cutlass_scaled_mm_supports_fp4(capability) + + def break_fp4_bytes(a, dtype): assert a.dtype == torch.uint8 m, n = a.shape