diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py index 2c07fe29fb0e..d68aa22bed0c 100644 --- a/tests/quantization/test_compressed_tensors.py +++ b/tests/quantization/test_compressed_tensors.py @@ -14,9 +14,10 @@ from compressed_tensors.quantization import QuantizationType from tests.models.utils import check_logprobs_close from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 CompressedTensors24, CompressedTensorsLinearMethod, - CompressedTensorsW4A16Fp4, CompressedTensorsW4A16Sparse24, - CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8, - CompressedTensorsW8A16Fp8, CompressedTensorsWNA16) + CompressedTensorsW4A4Fp4, CompressedTensorsW4A16Fp4, + CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8, + CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8, + CompressedTensorsWNA16) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( sparse_cutlass_supported) from vllm.platforms import current_platform @@ -651,9 +652,13 @@ def test_compressed_tensors_2of4_sparse_compressed(vllm_runner, args_2of4): assert output -def test_compressed_tensors_nvfp4a16(vllm_runner): - # run weight only example - model = "nm-testing/TinyLlama-1.1B-Chat-v1.0-NVFP4A16" +@pytest.mark.parametrize( + "args", + [("nm-testing/TinyLlama-1.1B-Chat-v1.0-NVFP4A16", + CompressedTensorsW4A16Fp4), + ("nm-testing/TinyLlama-1.1B-Chat-v1.0-NVFP4", CompressedTensorsW4A4Fp4)]) +def test_compressed_tensors_nvfp4(vllm_runner, args): + model, scheme = args with vllm_runner(model, enforce_eager=True) as llm: def check_model(model): @@ -662,7 +667,7 @@ def test_compressed_tensors_nvfp4a16(vllm_runner): qkv_proj = layer.self_attn.qkv_proj assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod) - assert isinstance(qkv_proj.scheme, CompressedTensorsW4A16Fp4) + assert isinstance(qkv_proj.scheme, scheme) assert qkv_proj.scheme.group_size == 16 llm.apply_model(check_model) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 1ee4617e1054..28c62fc5e58b 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -24,10 +24,10 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso CompressedTensorsMoEMethod) from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS, CompressedTensors24, - CompressedTensorsScheme, CompressedTensorsW4A16Fp4, - CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8, - CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8, - CompressedTensorsWNA16) + CompressedTensorsScheme, CompressedTensorsW4A4Fp4, + CompressedTensorsW4A16Fp4, CompressedTensorsW4A16Sparse24, + CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8, + CompressedTensorsW8A16Fp8, CompressedTensorsWNA16) from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( find_matched_target, is_activation_quantization_format, should_ignore_layer) @@ -218,6 +218,26 @@ class CompressedTensorsConfig(QuantizationConfig): else: return False + def _is_fp4a4_nvfp4(self, weight_quant: BaseModel, input_quant: BaseModel): + + if weight_quant is None or input_quant is None: + return False + + is_tensor_group_quant = (weight_quant.strategy + == 
QuantizationStrategy.TENSOR_GROUP.value + and input_quant.strategy + == QuantizationStrategy.TENSOR_GROUP.value) + is_symmetric = weight_quant.symmetric and input_quant.symmetric + + is_group_size_16 = (weight_quant.group_size == 16 + and input_quant.group_size == 16) + is_float_type = (weight_quant.type == QuantizationType.FLOAT + and input_quant.type == QuantizationType.FLOAT.value) + is_4_bits = weight_quant.num_bits == 4 and input_quant.num_bits == 4 + + return (is_tensor_group_quant and is_float_type and is_4_bits + and is_group_size_16 and is_symmetric) + def _is_fp4a16_nvfp4(self, weight_quant: BaseModel, input_quant: BaseModel): @@ -353,6 +373,9 @@ class CompressedTensorsConfig(QuantizationConfig): actorder=weight_quant.actorder) if is_activation_quantization_format(self.quant_format): + if self._is_fp4a4_nvfp4(weight_quant, input_quant): + return CompressedTensorsW4A4Fp4() + if self._is_fp8_w8a8(weight_quant, input_quant): is_fp8_w8a8_supported = self._check_scheme_supported( CompressedTensorsW8A8Fp8.get_min_capability(), error=False) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py index 25924c733e76..6e4e75df7604 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from .compressed_tensors_scheme import CompressedTensorsScheme +from .compressed_tensors_w4a4_nvfp4 import CompressedTensorsW4A4Fp4 from .compressed_tensors_w4a16_24 import (W4A16SPARSE24_SUPPORTED_BITS, CompressedTensorsW4A16Sparse24) from .compressed_tensors_w4a16_nvfp4 import CompressedTensorsW4A16Fp4 @@ -18,5 +19,6 @@ __all__ = [ "CompressedTensorsW8A16Fp8", "CompressedTensorsW4A16Sparse24", "CompressedTensorsW8A8Int8", "CompressedTensorsW8A8Fp8", "WNA16_SUPPORTED_BITS", "W4A16SPARSE24_SUPPORTED_BITS", - "CompressedTensors24", "CompressedTensorsW4A16Fp4" + "CompressedTensors24", "CompressedTensorsW4A16Fp4", + "CompressedTensorsW4A4Fp4" ] diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py new file mode 100644 index 000000000000..9899db3243a4 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py @@ -0,0 +1,178 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Callable, Optional + +import torch +from torch.nn.parameter import Parameter + +from vllm._custom_ops import (cutlass_scaled_fp4_mm, + cutlass_scaled_mm_supports_fp4, scaled_fp4_quant) +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsScheme) +from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( # noqa: E501 + dequantize_to_dtype, ref_nvfp4_quant) +from vllm.model_executor.parameter import (GroupQuantScaleParameter, + ModelWeightParameter, + PerTensorScaleParameter) +from vllm.platforms import current_platform + +logger = init_logger(__name__) + +__all__ = ["CompressedTensorsW4A4Fp4"] + + +def cutlass_fp4_supported() -> bool: + if not current_platform.is_cuda(): + return False + capability_tuple = current_platform.get_device_capability() + capability = -1 if capability_tuple is None 
else capability_tuple.to_int() + return cutlass_scaled_mm_supports_fp4(capability) + + +class CompressedTensorsW4A4Fp4(CompressedTensorsScheme): + + def __init__(self): + self.group_size = 16 + self.cutlass_nvfp4_supported = cutlass_fp4_supported() + if not self.cutlass_nvfp4_supported: + logger.warning("Current platform does not support cutlass NVFP4." + " Running emulations.") + + @classmethod + def get_min_capability(cls) -> int: + # dont restrict as emulations + return 80 + + def run_nvfp4_emulations(self, x: torch.Tensor, layer): + x_m, x_k = x.shape + output_dtype = x.dtype + + # quantize input to (FP4 and interleaved block scale) + x_fp4, x_blockscale = ref_nvfp4_quant(x, layer.input_global_scale, + self.group_size) + + # dequantize input + x_fp4 = x_fp4.reshape(x_m, x_k // self.group_size, self.group_size) + x_blockscale = x_blockscale.unsqueeze(-1) / layer.input_global_scale + x_dq = (x_fp4 * x_blockscale).reshape(x_m, x_k).to(output_dtype) + del x_fp4, x_blockscale + + # dequantize weight + w_fp4 = layer.weight.data.view(torch.uint8) + w_blockscale = layer.weight_scale_swizzled.data + w_global_scale = layer.weight_global_scale + w_dq = dequantize_to_dtype(w_fp4, w_blockscale, w_global_scale, + output_dtype, x.device, self.group_size) + + # matmul + out = torch.matmul(x_dq, w_dq.t()) + del w_dq, x_dq + return out + + def create_weights(self, layer: torch.nn.Module, + output_partition_sizes: list[int], + input_size_per_partition: int, + params_dtype: torch.dtype, weight_loader: Callable, + **kwargs): + output_size_per_partition = sum(output_partition_sizes) + layer.logical_widths = output_partition_sizes + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + + # Weight + weight = ModelWeightParameter(data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition // 2, + dtype=torch.uint8), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) + layer.register_parameter("weight_packed", weight) + + # Global Weight Scale + weight_global_scale = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader) + layer.register_parameter("weight_global_scale", weight_global_scale) + + # Per Group Weight Scale + weight_scale = GroupQuantScaleParameter(data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition // self.group_size, + dtype=torch.float8_e4m3fn, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) + + layer.register_parameter("weight_scale", weight_scale) + + input_global_scale = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader) + layer.register_parameter("input_global_scale", input_global_scale) + + def swizzle_blockscale(self, scale: torch.tensor): + assert (scale.dtype == torch.float8_e4m3fn) + # Pad and blockwise interleave weight_scale + scale_ndim = scale.ndim + if scale.ndim == 2: + scale = scale.unsqueeze(0) + assert scale.ndim == 3 + B, M, K = scale.shape + round_up_multiple = lambda x, m: (x + m - 1) // m * m + M_padded = round_up_multiple(M, 128) + K_padded = round_up_multiple(K, 4) + padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype) + padded_scale[:B, :M, :K] = scale + batches, rows, cols = padded_scale.shape + assert rows % 128 == 0 + assert cols % 4 == 0 + padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32, + cols // 4, 4) + swizzled_scale = padded_scale.permute((0, 1, 4, 3, 
2, 5)) + swizzled_scale = swizzled_scale.contiguous().cuda() + return (swizzled_scale.reshape(M, K) + if scale_ndim == 2 else swizzled_scale.reshape(B, M, K)) + + def process_weights_after_loading(self, layer) -> None: + + global_input_scale = layer.input_global_scale.max().to(torch.float32) + layer.input_global_scale = Parameter(global_input_scale, + requires_grad=False) + + layer.weight_global_scale = Parameter( + layer.weight_global_scale.max().to(torch.float32), + requires_grad=False) + + swizzled_weight_scale = self.swizzle_blockscale(layer.weight_scale) + layer.weight_scale_swizzled = Parameter(swizzled_weight_scale, + requires_grad=False) + + # required by cutlass kernel; need Parameter, not ModelWeightParameter + layer.weight = Parameter(layer.weight_packed.data, requires_grad=False) + + if self.cutlass_nvfp4_supported: + layer.alpha = Parameter(layer.input_global_scale * + layer.weight_global_scale, + requires_grad=False) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + + if self.cutlass_nvfp4_supported: + output_dtype = x.dtype + output_shape = [x.shape[0], layer.weight.shape[0]] + + # quantize BF16 or FP16 to (FP4 and interleaved block scale) + x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_global_scale) + + out = cutlass_scaled_fp4_mm(x_fp4, layer.weight, x_blockscale, + layer.weight_scale_swizzled, + 1 / layer.alpha, output_dtype) + if bias is not None: + out = out + bias + return out.view(*output_shape) + return self.run_nvfp4_emulations(x, layer) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py index 402646498cee..099d8613fc1a 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py @@ -15,6 +15,7 @@ def is_activation_quantization_format(format: str) -> bool: CompressionFormat.naive_quantized.value, CompressionFormat.int_quantized.value, CompressionFormat.float_quantized.value, + CompressionFormat.nvfp4_pack_quantized.value ] return format in _ACTIVATION_QUANTIZATION_FORMATS diff --git a/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py b/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py index 6e8e98d544f8..c4ef3ce24c03 100644 --- a/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py +++ b/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py @@ -2,10 +2,11 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import torch -__all__ = [ - "break_fp4_bytes", - "dequantize_to_dtype", -] +from vllm.scalar_type import scalar_types + +__all__ = ["break_fp4_bytes", "dequantize_to_dtype", "ref_nvfp4_quant"] + +FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max() kE2M1ToFloat = torch.tensor([0., 0.5, 1., 1.5, 2., 3., 4., 6.], dtype=torch.float32) @@ -60,3 +61,44 @@ def dequantize_to_dtype(tensor_fp4, # scale the tensor out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k) return out.to(dtype) + + +def get_reciprocal(x): + if isinstance(x, torch.Tensor): + return torch.where(x == 0, torch.tensor(0.0, dtype=x.dtype), 1.0 / x) + elif isinstance(x, (float, int)): + return 0.0 if x == 0 else 1.0 / x + else: + raise TypeError("Input must be a float, int, or a torch.Tensor.") + + +def cast_to_fp4(x): + sign = torch.sign(x) + x = torch.abs(x) + x[(x >= 0.0) & (x <= 0.25)] = 0.0 + x[(x > 0.25) & (x < 0.75)] 
= 0.5 + x[(x >= 0.75) & (x <= 1.25)] = 1.0 + x[(x > 1.25) & (x < 1.75)] = 1.5 + x[(x >= 1.75) & (x <= 2.5)] = 2.0 + x[(x > 2.5) & (x < 3.5)] = 3.0 + x[(x >= 3.5) & (x <= 5.0)] = 4.0 + x[x > 5.0] = 6.0 + return x * sign + + +def ref_nvfp4_quant(x, global_scale, block_size): + assert global_scale.dtype == torch.float32 + assert x.ndim == 2 + m, n = x.shape + x = torch.reshape(x, (m, n // block_size, block_size)) + vec_max = torch.max(torch.abs(x), dim=-1, + keepdim=True)[0].to(torch.float32) + scale = global_scale * (vec_max * get_reciprocal(FLOAT4_E2M1_MAX)) + scale = torch.clamp(scale, max=448, min=-448) + scale = scale.to(torch.float8_e4m3fn).to(torch.float32) + output_scale = get_reciprocal(scale * get_reciprocal(global_scale)) + + scaled_x = x.to(torch.float32) * output_scale + clipped_x = torch.clamp(scaled_x, -6.0, 6.0).reshape(m, n) + # both outputs are float32 + return cast_to_fp4(clipped_x), scale.squeeze(-1)
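
The notes and sketches below are illustrative only; they are not additional hunks of this patch.

The new _is_fp4a4_nvfp4 predicate accepts weight and input-activation quantization args of roughly the following shape. The numeric values are read directly off the checks in the predicate; the string spellings of the compressed-tensors enum values ("float", "tensor_group") are an assumption, not copied from a real checkpoint.

# Illustrative quantization args that would satisfy _is_fp4a4_nvfp4.
# Enum string values are assumed; check them against compressed-tensors.
nvfp4_quant_args = {
    "num_bits": 4,               # is_4_bits
    "type": "float",             # QuantizationType.FLOAT
    "strategy": "tensor_group",  # QuantizationStrategy.TENSOR_GROUP
    "group_size": 16,            # is_group_size_16
    "symmetric": True,           # is_symmetric
}
# The same args must hold for both weights and input activations, and the
# checkpoint format must be CompressionFormat.nvfp4_pack_quantized, which
# utils.py now treats as an activation-quantization format.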
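
A minimal, self-contained sketch of the activation path in run_nvfp4_emulations, useful for sanity-checking the reference quantizer added to nvfp4_emulation_utils.py. The global-scale formula (FP8 E4M3 max times FP4 E2M1 max over the tensor amax) is an assumed calibration recipe, not something this patch defines.

# Quantize/dequantize round trip through ref_nvfp4_quant, mirroring how
# run_nvfp4_emulations handles activations. 448 = FP8 E4M3 max; the
# global-scale choice below is an assumption about calibration.
import torch

from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import (
    FLOAT4_E2M1_MAX, ref_nvfp4_quant)

group_size = 16
x = torch.randn(8, 64, dtype=torch.bfloat16)
x_global_scale = (448 * FLOAT4_E2M1_MAX) / x.abs().max().to(torch.float32)

x_fp4, x_blockscale = ref_nvfp4_quant(x, x_global_scale, group_size)

# Dequantize exactly as run_nvfp4_emulations does for activations.
m, k = x.shape
x_dq = (x_fp4.reshape(m, k // group_size, group_size) *
        (x_blockscale.unsqueeze(-1) / x_global_scale)).reshape(m, k)

print("max abs error:", (x_dq - x.to(torch.float32)).abs().max().item())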
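
Finally, an end-to-end smoke test through the public offline-inference API, using the W4A4 checkpoint added to the parametrized test. enforce_eager matches the test setup; on hardware without cutlass FP4 support this should fall through to the emulation path.

# Offline-inference smoke test for the new W4A4 NVFP4 scheme. The model name
# is the one used in test_compressed_tensors_nvfp4; the rest is standard vLLM.
from vllm import LLM, SamplingParams

llm = LLM(model="nm-testing/TinyLlama-1.1B-Chat-v1.0-NVFP4", enforce_eager=True)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=16))
print(outputs[0].outputs[0].text)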