diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py
index db7e50eff72bc..296743dbfa041 100644
--- a/tests/quantization/test_compressed_tensors.py
+++ b/tests/quantization/test_compressed_tensors.py
@@ -17,7 +17,9 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
     CompressedTensorsW4A4Fp4, CompressedTensorsW4A16Fp4,
     CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8,
     CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8,
-    CompressedTensorsWNA16, cutlass_fp4_supported)
+    CompressedTensorsWNA16)
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    cutlass_fp4_supported)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     sparse_cutlass_supported)
 from vllm.platforms import current_platform
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
index 90b45e32a688d..bc348df84d016 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -33,7 +33,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
     find_matched_target, is_activation_quantization_format,
     should_ignore_layer)
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
-from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import (  # noqa: E501
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
     cutlass_fp4_supported)
 from vllm.platforms import current_platform
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index 7da52ce6ff8c8..8f69636dda7bf 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -27,8 +27,8 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
     prepare_moe_fp4_layer_for_marlin)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
     prepare_moe_fp8_layer_for_marlin)
-from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import (  # noqa: E501
-    cutlass_fp4_supported)
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    cutlass_fp4_supported, swizzle_blockscale)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
 from vllm.model_executor.utils import set_weight_attrs
@@ -193,29 +193,6 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
             {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
         set_weight_attrs(w2_input_scale, extra_weight_attrs)
 
-    def swizzle_blockscale(self, scale: torch.tensor):
-        assert (scale.dtype == torch.float8_e4m3fn)
-        # Pad and blockwise interleave weight_scale
-        scale_ndim = scale.ndim
-        if scale.ndim == 2:
-            scale = scale.unsqueeze(0)
-        assert scale.ndim == 3
-        B, M, K = scale.shape
-        round_up_multiple = lambda x, m: (x + m - 1) // m * m
-        M_padded = round_up_multiple(M, 128)
-        K_padded = round_up_multiple(K, 4)
-        padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype)
-        padded_scale[:B, :M, :K] = scale
-        batches, rows, cols = padded_scale.shape
-        assert rows % 128 == 0
-        assert cols % 4 == 0
-        padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32,
-                                            cols // 4, 4)
-        swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
-        swizzled_scale = swizzled_scale.contiguous().cuda()
-        return (swizzled_scale.reshape(M, K)
-                if scale_ndim == 2 else swizzled_scale.reshape(B, M, K))
-
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         # From packed to weight
@@ -243,13 +220,13 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
             return
 
         # swizzle weight scales
-        layer.w13_blockscale_swizzled = torch.nn.Parameter(
-            self.swizzle_blockscale(layer.w13_weight_scale),
-            requires_grad=False)
+        layer.w13_blockscale_swizzled = torch.nn.Parameter(swizzle_blockscale(
+            layer.w13_weight_scale),
+                                                           requires_grad=False)
 
-        layer.w2_blockscale_swizzled = torch.nn.Parameter(
-            self.swizzle_blockscale(layer.w2_weight_scale),
-            requires_grad=False)
+        layer.w2_blockscale_swizzled = torch.nn.Parameter(swizzle_blockscale(
+            layer.w2_weight_scale),
+                                                          requires_grad=False)
 
         # w13
         w13_input_global_scale = layer.w13_input_global_scale.max(
diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py
index 81611ed07aaa4..38866586ae29e 100644
--- a/vllm/model_executor/layers/quantization/modelopt.py
+++ b/vllm/model_executor/layers/quantization/modelopt.py
@@ -9,8 +9,7 @@ from torch.nn.parameter import Parameter
 
 import vllm.envs as envs
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
-from vllm._custom_ops import (cutlass_scaled_fp4_mm,
-                              cutlass_scaled_mm_supports_fp4, scaled_fp4_quant)
+from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
 from vllm.distributed import get_ep_group
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig
@@ -28,7 +27,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
     apply_fp4_marlin_linear, is_fp4_marlin_supported,
     prepare_fp4_layer_for_marlin, prepare_moe_fp4_layer_for_marlin)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
-    GroupShape, is_layer_skipped)
+    GroupShape, cutlass_fp4_supported, is_layer_skipped, swizzle_blockscale)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     Fp8LinearOp, requantize_with_max_scale)
 from vllm.model_executor.parameter import (ModelWeightParameter,
@@ -667,14 +666,6 @@ class ModelOptNvFp4Config(QuantizationConfig):
         return None
 
 
-def cutlass_fp4_supported() -> bool:
-    if not current_platform.is_cuda():
-        return False
-    capability_tuple = current_platform.get_device_capability()
-    capability = -1 if capability_tuple is None else capability_tuple.to_int()
-    return cutlass_scaled_mm_supports_fp4(capability)
-
-
 class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
     """
     Supports loading kv-cache scaling factors from FP8 checkpoints.
@@ -772,29 +763,6 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
 
         layer.register_parameter("weight_scale", weight_scale)
 
-    def swizzle_blockscale(self, scale: torch.tensor):
-        assert (scale.dtype == torch.float8_e4m3fn)
-        # Pad and blockwise interleave weight_scale
-        scale_ndim = scale.ndim
-        if scale.ndim == 2:
-            scale = scale.unsqueeze(0)
-        assert scale.ndim == 3
-        B, M, K = scale.shape
-        round_up_multiple = lambda x, m: (x + m - 1) // m * m
-        M_padded = round_up_multiple(M, 128)
-        K_padded = round_up_multiple(K, 4)
-        padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype)
-        padded_scale[:B, :M, :K] = scale
-        batches, rows, cols = padded_scale.shape
-        assert rows % 128 == 0
-        assert cols % 4 == 0
-        padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32,
-                                            cols // 4, 4)
-        swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
-        swizzled_scale = swizzled_scale.contiguous().cuda()
-        return (swizzled_scale.reshape(M, K)
-                if scale_ndim == 2 else swizzled_scale.reshape(B, M, K))
-
     def process_weights_after_loading(self, layer: Module) -> None:
 
         # global scales:
@@ -814,7 +782,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
             "Expected weight_scale.dim(1) to be divisible by 16")
         assert (layer.weight_scale.dtype == torch.float8_e4m3fn), (
             "Weight Block scale must be represented as FP8-E4M3")
-        swizzled_weight_scale = self.swizzle_blockscale(layer.weight_scale)
+        swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
 
         layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
                                                 requires_grad=False)
@@ -1060,29 +1028,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
                                         weight_loader=weight_loader)
         layer.register_parameter("w2_input_scale", w2_input_scale)
 
-    def swizzle_blockscale(self, scale: torch.tensor):
-        assert (scale.dtype == torch.float8_e4m3fn)
-        # Pad and blockwise interleave weight_scale
-        scale_ndim = scale.ndim
-        if scale.ndim == 2:
-            scale = scale.unsqueeze(0)
-        assert scale.ndim == 3
-        B, M, K = scale.shape
-        round_up_multiple = lambda x, m: (x + m - 1) // m * m
-        M_padded = round_up_multiple(M, 128)
-        K_padded = round_up_multiple(K, 4)
-        padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype)
-        padded_scale[:B, :M, :K] = scale
-        batches, rows, cols = padded_scale.shape
-        assert rows % 128 == 0
-        assert cols % 4 == 0
-        padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32,
-                                            cols // 4, 4)
-        swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
-        swizzled_scale = swizzled_scale.contiguous().cuda()
-        return (swizzled_scale.reshape(M, K)
-                if scale_ndim == 2 else swizzled_scale.reshape(B, M, K))
-
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         # GEMM 1
         # The FlashInfer Cutlass fused MoE kernel expects the combined weights
@@ -1128,8 +1073,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
             "Expected weight_scale.dim(1) to be divisible by 16")
         assert (layer.w13_weight_scale.dtype == torch.float8_e4m3fn), (
             "Weight Blockscale must be represented as FP8-E4M3")
-        w13_blockscale_swizzled = self.swizzle_blockscale(
-            layer.w13_weight_scale)
+        w13_blockscale_swizzled = swizzle_blockscale(layer.w13_weight_scale)
 
         layer.w13_blockscale_swizzled = Parameter(w13_blockscale_swizzled,
                                                   requires_grad=False)
@@ -1151,7 +1095,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
            "Expected weight_scale.dim(1) to be divisible by 16")
         assert (layer.w2_weight_scale.dtype == torch.float8_e4m3fn), (
             "Weight Blockscale must be represented as FP8-E4M3")
-        w2_blockscale_swizzled = self.swizzle_blockscale(layer.w2_weight_scale)
+        w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale)
 
         layer.w2_blockscale_swizzled = Parameter(w2_blockscale_swizzled,
                                                  requires_grad=False)
diff --git a/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py b/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py
index fb3287d3b89e6..8648771cb0177 100644
--- a/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py
@@ -2,13 +2,12 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import torch
 
-from vllm._custom_ops import cutlass_scaled_mm_supports_fp4
-from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
 
 __all__ = [
-    "break_fp4_bytes", "dequantize_to_dtype", "ref_nvfp4_quant",
-    "cutlass_fp4_supported"
+    "break_fp4_bytes",
+    "dequantize_to_dtype",
+    "ref_nvfp4_quant",
 ]
 
 FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
@@ -17,14 +16,6 @@ kE2M1ToFloat = torch.tensor([0., 0.5, 1., 1.5, 2., 3., 4., 6.],
                             dtype=torch.float32)
 
 
-def cutlass_fp4_supported() -> bool:
-    if not current_platform.is_cuda():
-        return False
-    capability_tuple = current_platform.get_device_capability()
-    capability = -1 if capability_tuple is None else capability_tuple.to_int()
-    return cutlass_scaled_mm_supports_fp4(capability)
-
-
 def break_fp4_bytes(a, dtype):
     assert a.dtype == torch.uint8
     m, n = a.shape
diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py
index 54361a2323c28..428e9e99aa881 100644
--- a/vllm/model_executor/layers/quantization/utils/quant_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py
@@ -8,8 +8,10 @@ from typing import ClassVar, NamedTuple, Optional
 import numpy
 import torch
 
+from vllm._custom_ops import cutlass_scaled_mm_supports_fp4
 from vllm.model_executor.layers.quantization.qqq import (
     MARLIN_QQQ_SUPPORTED_NUM_BITS)
+from vllm.platforms import current_platform
 from vllm.scalar_type import ScalarType, scalar_types
@@ -592,3 +594,56 @@ def awq_pack(
     q_w = q_w.reshape((-1, size_n)).contiguous()
 
     return pack_cols(q_w, num_bits, size_k, size_n)
+
+
+def swizzle_blockscale(scale: torch.Tensor) -> torch.Tensor:
+    """
+    Pad and block-interleave the FP4 block-scales so that they match the data
+    layout expected by the CUTLASS / FlashInfer kernels.
+
+    Parameters
+    ----------
+    scale: torch.Tensor
+
+    Returns
+    -------
+    torch.Tensor
+        The swizzled tensor with the same logical shape as *scale*.
+    """
+    assert scale.dtype == torch.float8_e4m3fn, (
+        "swizzle_blockscale expects the input tensor to be in "
+        "torch.float8_e4m3fn format.")
+
+    scale_ndim = scale.ndim
+    if scale_ndim == 2:
+        scale = scale.unsqueeze(0)  # (1, M, K)
+    assert scale.ndim == 3, "Expected a 2-D or 3-D tensor for block scales."
+
+    B, M, K = scale.shape
+
+    def _round_up(x: int, m: int) -> int:
+        return (x + m - 1) // m * m
+
+    M_padded = _round_up(M, 128)
+    K_padded = _round_up(K, 4)
+
+    padded = torch.zeros((B, M_padded, K_padded),
+                         dtype=scale.dtype,
+                         device=scale.device)
+    padded[:B, :M, :K] = scale
+
+    # Reshape / permute to the layout required by the kernel.
+    padded = padded.reshape(B, M_padded // 128, 4, 32, K_padded // 4, 4)
+    swizzled = padded.permute(0, 1, 4, 3, 2, 5).contiguous().cuda()
+
+    if scale_ndim == 2:
+        return swizzled.reshape(M, K)
+    return swizzled.reshape(B, M, K)
+
+
+def cutlass_fp4_supported() -> bool:
+    if not current_platform.is_cuda():
+        return False
+    capability_tuple = current_platform.get_device_capability()
+    capability = -1 if capability_tuple is None else capability_tuple.to_int()
+    return cutlass_scaled_mm_supports_fp4(capability)
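Below is a minimal, hypothetical smoke test (not part of the diff above) showing how the relocated helpers could be exercised from their new home in `quant_utils`. It assumes a CUDA device, since `swizzle_blockscale` moves its result to the GPU, and the 128x16 FP8-E4M3 scale tensor is an arbitrary choice for illustration:

```python
# Sketch only: the imports come from the quant_utils module touched in this
# diff; the tensor shape and assertions are illustrative, not a real test.
import torch

from vllm.model_executor.layers.quantization.utils.quant_utils import (
    cutlass_fp4_supported, swizzle_blockscale)

if cutlass_fp4_supported():
    # Block scales are stored as FP8-E4M3; (128, 16) is already a multiple of
    # the (128, 4) swizzle tile, so the output keeps the logical shape.
    scale = torch.rand(128, 16).to(torch.float8_e4m3fn)
    swizzled = swizzle_blockscale(scale)
    assert swizzled.shape == scale.shape and swizzled.is_cuda
```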