diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py
index 70f716f95e89..c968a68f1a8e 100644
--- a/tests/quantization/test_compressed_tensors.py
+++ b/tests/quantization/test_compressed_tensors.py
@@ -13,9 +13,9 @@ from compressed_tensors.quantization import QuantizationType
 from tests.models.utils import check_logprobs_close
 from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
     CompressedTensors24, CompressedTensorsLinearMethod,
-    CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8,
-    CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8,
-    CompressedTensorsWNA16)
+    CompressedTensorsW4A16Fp4, CompressedTensorsW4A16Sparse24,
+    CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
+    CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     sparse_cutlass_supported)
 from vllm.platforms import current_platform
@@ -648,3 +648,23 @@ def test_compressed_tensors_2of4_sparse_compressed(vllm_runner, args_2of4):
         output = llm.generate_greedy("Hello my name is", max_tokens=20)
         print(output)
         assert output
+
+
+def test_compressed_tensors_nvfp4a16(vllm_runner):
+    # run weight only example
+    model = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP4"
+    with vllm_runner(model, enforce_eager=True) as llm:
+
+        def check_model(model):
+            layer = model.model.layers[0]
+
+            qkv_proj = layer.self_attn.qkv_proj
+            assert isinstance(qkv_proj.quant_method,
+                              CompressedTensorsLinearMethod)
+            assert isinstance(qkv_proj.scheme, CompressedTensorsW4A16Fp4)
+            assert qkv_proj.scheme.group_size == 16
+
+        llm.apply_model(check_model)
+        output = llm.generate_greedy("Hello my name is", max_tokens=20)
+        print(output)
+        assert output
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
index 0585c09bd84b..a001a8582dda 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -23,9 +23,10 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
     CompressedTensorsMoEMethod)
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS, CompressedTensors24,
-    CompressedTensorsScheme, CompressedTensorsW4A16Sparse24,
-    CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
-    CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
+    CompressedTensorsScheme, CompressedTensorsW4A16Fp4,
+    CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8,
+    CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8,
+    CompressedTensorsWNA16)
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
     find_matched_target, is_activation_quantization_format,
     should_ignore_layer)
@@ -216,6 +217,21 @@ class CompressedTensorsConfig(QuantizationConfig):
         else:
             return False
 
+    def _is_fp4a16_nvfp4(self, weight_quant: BaseModel,
+                         input_quant: BaseModel):
+
+        is_weight_only = weight_quant is not None and input_quant is None
+        is_group_quant = (
+            weight_quant.strategy == QuantizationStrategy.GROUP.value)
+        is_symmetric = weight_quant.symmetric
+
+        is_group_size_16 = weight_quant.group_size == 16
+        is_float_type = weight_quant.type == QuantizationType.FLOAT
+        is_4_bits = weight_quant.num_bits == 4
+
+        return (is_weight_only and is_group_quant and is_float_type
+                and is_4_bits and is_group_size_16 and is_symmetric)
+
     def _is_static_tensor_w8a8(self, weight_quant: BaseModel,
                                input_quant: BaseModel) -> bool:
         is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
@@ -315,6 +331,9 @@ class CompressedTensorsConfig(QuantizationConfig):
             input_quant: BaseModel) -> "CompressedTensorsScheme":
 
         # Detect If Mixed Precision
+        if self._is_fp4a16_nvfp4(weight_quant, input_quant):
+            return CompressedTensorsW4A16Fp4()
+
         if self._is_wNa16_group_channel(weight_quant, input_quant):
             if (self.quant_format == CompressionFormat.marlin_24.value
                     and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS):
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py
index b26c74f2484b..79bf5c108ac2 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py
@@ -3,6 +3,7 @@
 from .compressed_tensors_scheme import CompressedTensorsScheme
 from .compressed_tensors_w4a16_24 import (W4A16SPARSE24_SUPPORTED_BITS,
                                           CompressedTensorsW4A16Sparse24)
+from .compressed_tensors_w4a16_nvfp4 import CompressedTensorsW4A16Fp4
 from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8
 from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8
 from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8
@@ -16,5 +17,5 @@ __all__ = [
     "CompressedTensorsW8A16Fp8", "CompressedTensorsW4A16Sparse24",
     "CompressedTensorsW8A8Int8", "CompressedTensorsW8A8Fp8",
     "WNA16_SUPPORTED_BITS", "W4A16SPARSE24_SUPPORTED_BITS",
-    "CompressedTensors24"
+    "CompressedTensors24", "CompressedTensorsW4A16Fp4"
 ]
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py
new file mode 100644
index 000000000000..f192a8164515
--- /dev/null
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py
@@ -0,0 +1,107 @@
+# SPDX-License-Identifier: Apache-2.0
+from typing import Callable, List, Optional
+
+import torch
+import torch.nn.functional as F
+from torch.nn.parameter import Parameter
+
+from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
+    CompressedTensorsScheme)
+from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import (  # noqa: E501
+    dequantize_to_dtype)
+from vllm.model_executor.parameter import (GroupQuantScaleParameter,
+                                           ModelWeightParameter,
+                                           PerTensorScaleParameter)
+
+__all__ = ["CompressedTensorsW4A16Fp4"]
+
+
+class CompressedTensorsW4A16Fp4(CompressedTensorsScheme):
+
+    def __init__(self):
+        self.group_size = 16
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        # dont restrict as emulations
+        return 80
+
+    def create_weights(self, layer: torch.nn.Module,
+                       output_partition_sizes: List[int],
+                       input_size_per_partition: int,
+                       params_dtype: torch.dtype, weight_loader: Callable,
+                       **kwargs):
+
+        # Weight
+        weight = ModelWeightParameter(data=torch.empty(
+            sum(output_partition_sizes),
+            input_size_per_partition // 2,
+            dtype=torch.uint8),
+                                      input_dim=1,
+                                      output_dim=0,
+                                      weight_loader=weight_loader)
+        layer.register_parameter("weight_packed", weight)
+
+        # Global Weight Scale
+        weight_global_scale = PerTensorScaleParameter(
+            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
+            weight_loader=weight_loader)
+        layer.register_parameter("weight_global_scale", weight_global_scale)
+
+        # Per Group Weight Scale
+        weight_scale = GroupQuantScaleParameter(data=torch.empty(
+            sum(output_partition_sizes),
+            input_size_per_partition // self.group_size,
+            dtype=torch.float8_e4m3fn,
+        ),
+                                                input_dim=1,
+                                                output_dim=0,
+                                                weight_loader=weight_loader)
+
+        layer.register_parameter("weight_scale", weight_scale)
+
+    def swizzle_blockscale(self, scale: torch.tensor):
+        assert (scale.dtype == torch.float8_e4m3fn)
+        # Pad and blockwise interleave weight_scale
+        scale_ndim = scale.ndim
+        if scale.ndim == 2:
+            scale = scale.unsqueeze(0)
+        assert scale.ndim == 3
+        B, M, K = scale.shape
+        round_up_multiple = lambda x, m: (x + m - 1) // m * m
+        M_padded = round_up_multiple(M, 128)
+        K_padded = round_up_multiple(K, 4)
+        padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype)
+        padded_scale[:B, :M, :K] = scale
+        batches, rows, cols = padded_scale.shape
+        assert rows % 128 == 0
+        assert cols % 4 == 0
+        padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32,
+                                            cols // 4, 4)
+        swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
+        swizzled_scale = swizzled_scale.contiguous().cuda()
+        return (swizzled_scale.reshape(M, K)
+                if scale_ndim == 2 else swizzled_scale.reshape(B, M, K))
+
+    def process_weights_after_loading(self, layer) -> None:
+        layer.weight_global_scale = Parameter(
+            layer.weight_global_scale.max().to(torch.float32),
+            requires_grad=False)
+        # Note: a post weight loading step but not required for the emulation
+        swizzled_weight_scale = self.swizzle_blockscale(layer.weight_scale)
+        layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
+                                                requires_grad=False)
+
+    def apply_weights(self,
+                      layer: torch.nn.Module,
+                      x: torch.Tensor,
+                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+
+        w_fp4 = layer.weight_packed.data
+        w_global_scale = layer.weight_global_scale
+        w_blockscale = layer.weight_scale_swizzled.data
+        w_dq = dequantize_to_dtype(w_fp4, w_blockscale, w_global_scale,
+                                   x.dtype, x.device, self.group_size)
+        out = F.linear(x, w_dq)
+        del w_dq, w_fp4, w_global_scale, w_blockscale
+        return out
diff --git a/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py b/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py
new file mode 100644
index 000000000000..f292208311e2
--- /dev/null
+++ b/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py
@@ -0,0 +1,61 @@
+# SPDX-License-Identifier: Apache-2.0
+import torch
+
+__all__ = [
+    "break_fp4_bytes",
+    "dequantize_to_dtype",
+]
+
+kE2M1ToFloat = torch.tensor([0., 0.5, 1., 1.5, 2., 3., 4., 6.],
+                            dtype=torch.float32)
+
+
+def break_fp4_bytes(a, dtype):
+    assert a.dtype == torch.uint8
+    m, n = a.shape
+    # Vectorized nibble processing
+    a_flat = a.flatten()
+    high = (a_flat & 0xF0) >> 4  # Upper nibbles
+    low = a_flat & 0x0F  # Lower nibbles
+    # Combine nibbles for batch processing
+    combined = torch.stack((low, high), dim=1).flatten()
+    # Vectorized sign and magnitude extraction
+    signs = (combined & 0x08).to(torch.bool)  # Sign bits
+    abs_vals = (combined & 0x07).to(torch.long)
+    # Device-aware lookup and sign application
+    kE2M1 = kE2M1ToFloat.to(device=a.device)
+    values = kE2M1[abs_vals] * torch.where(signs, -1.0, 1.0)
+    # Reshape to final form
+    return values.reshape(m, n * 2).to(dtype=dtype)
+
+
+def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size):
+    m_tiles = (m + 128 - 1) // 128
+    f = block_size * 4
+    k_tiles = (k + f - 1) // f
+    tmp = torch.reshape(a_sf_swizzled, (1, m_tiles, k_tiles, 32, 4, 4))
+    tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5))
+    out = tmp.reshape(m_tiles * 128, k_tiles * f // block_size)
+    return out[0:m, 0:k]
+
+
+def dequantize_to_dtype(tensor_fp4,
+                        tensor_sf,
+                        global_scale,
+                        dtype,
+                        device,
+                        block_size=16):
+    """Dequantize the fp4 tensor back to high precision."""
+    # Two fp4 values are packed into one uint8.
+    assert tensor_fp4.dtype == torch.uint8
+    m, packed_k = tensor_fp4.shape
+    k = packed_k * 2
+    tensor_f32 = break_fp4_bytes(tensor_fp4, torch.float32)
+    tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size)
+    tensor_sf = tensor_sf.view(torch.float8_e4m3fn)
+    tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size)
+    tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale
+
+    # scale the tensor
+    out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k)
+    return out.to(dtype)
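
For reference, a minimal usage sketch of the FP4 decoding helper introduced in nvfp4_emulation_utils.py; this is not part of the patch, it assumes a vLLM checkout with this change applied, and the packed byte 0x97 is just an illustrative constant:

import torch

from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import (  # noqa: E501
    break_fp4_bytes)

# One uint8 packs two E2M1 (FP4) values, low nibble first:
# 0x7 encodes +6.0 (sign=0, code 0b111), 0x9 encodes -0.5 (sign=1, code 0b001).
packed = torch.tensor([[0x97]], dtype=torch.uint8)
unpacked = break_fp4_bytes(packed, torch.float32)
print(unpacked)  # -> values 6.0 and -0.5 in a (1, 2) tensor

# CompressedTensorsW4A16Fp4.apply_weights runs the same decoding inside
# dequantize_to_dtype, then applies the swizzled per-group (group_size=16)
# block scales divided by the global scale before the emulated F.linear matmul.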