From 307939f299db4ee695720fdeae3fb4b2dc233353 Mon Sep 17 00:00:00 2001
From: Michael Goin
Date: Mon, 12 May 2025 20:07:34 -0400
Subject: [PATCH] Use NVFP4 Marlin for CompressedTensorsW4A16Fp4 (#18000)

Signed-off-by: mgoin
Signed-off-by: Dipika
Co-authored-by: Dipika
---
 .../schemes/compressed_tensors_w4a16_nvfp4.py | 67 +++++++------------
 1 file changed, 26 insertions(+), 41 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py
index f192a81645154..caa4fe89c6213 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py
@@ -2,13 +2,12 @@
 from typing import Callable, List, Optional
 
 import torch
-import torch.nn.functional as F
 from torch.nn.parameter import Parameter
 
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
-from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import (  # noqa: E501
-    dequantize_to_dtype)
+from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
+    apply_fp4_marlin_linear, prepare_fp4_layer_for_marlin)
 from vllm.model_executor.parameter import (GroupQuantScaleParameter,
                                            ModelWeightParameter,
                                            PerTensorScaleParameter)
@@ -31,6 +30,10 @@ class CompressedTensorsW4A16Fp4(CompressedTensorsScheme):
                        input_size_per_partition: int,
                        params_dtype: torch.dtype, weight_loader: Callable,
                        **kwargs):
+        output_size_per_partition = sum(output_partition_sizes)
+        layer.logical_widths = output_partition_sizes
+        layer.input_size_per_partition = input_size_per_partition
+        layer.output_size_per_partition = output_size_per_partition
 
         # Weight
         weight = ModelWeightParameter(data=torch.empty(
@@ -60,48 +63,30 @@ class CompressedTensorsW4A16Fp4(CompressedTensorsScheme):
 
         layer.register_parameter("weight_scale", weight_scale)
 
-    def swizzle_blockscale(self, scale: torch.tensor):
-        assert (scale.dtype == torch.float8_e4m3fn)
-        # Pad and blockwise interleave weight_scale
-        scale_ndim = scale.ndim
-        if scale.ndim == 2:
-            scale = scale.unsqueeze(0)
-        assert scale.ndim == 3
-        B, M, K = scale.shape
-        round_up_multiple = lambda x, m: (x + m - 1) // m * m
-        M_padded = round_up_multiple(M, 128)
-        K_padded = round_up_multiple(K, 4)
-        padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype)
-        padded_scale[:B, :M, :K] = scale
-        batches, rows, cols = padded_scale.shape
-        assert rows % 128 == 0
-        assert cols % 4 == 0
-        padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32,
-                                            cols // 4, 4)
-        swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
-        swizzled_scale = swizzled_scale.contiguous().cuda()
-        return (swizzled_scale.reshape(M, K)
-                if scale_ndim == 2 else swizzled_scale.reshape(B, M, K))
-
     def process_weights_after_loading(self, layer) -> None:
-        layer.weight_global_scale = Parameter(
-            layer.weight_global_scale.max().to(torch.float32),
+        # Process parameters for marlin repacking
+
+        # Rename weight_packed to weight that marlin expects
+        layer.weight = Parameter(layer.weight_packed.data, requires_grad=False)
+        del layer.weight_packed
+        # Rename weight_global_scale to weight_scale_2 that marlin expects
+        # Note: ct stores the inverse of what is expected by the marlin kernel
+        layer.weight_scale_2 = Parameter(
+            1 / layer.weight_global_scale.max().to(torch.float32),
            requires_grad=False)
-        # Note: a post weight loading step but not required for the emulation
-        swizzled_weight_scale = self.swizzle_blockscale(layer.weight_scale)
-        layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
-                                                requires_grad=False)
+        del layer.weight_global_scale
+
+        prepare_fp4_layer_for_marlin(layer)
 
     def apply_weights(self,
                       layer: torch.nn.Module,
                       x: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-
-        w_fp4 = layer.weight_packed.data
-        w_global_scale = layer.weight_global_scale
-        w_blockscale = layer.weight_scale_swizzled.data
-        w_dq = dequantize_to_dtype(w_fp4, w_blockscale, w_global_scale,
-                                   x.dtype, x.device, self.group_size)
-        out = F.linear(x, w_dq)
-        del w_dq, w_fp4, w_global_scale, w_blockscale
-        return out
+        return apply_fp4_marlin_linear(input=x,
+                                       weight=layer.weight,
+                                       weight_scale=layer.weight_scale,
+                                       weight_scale_2=layer.weight_scale_2,
+                                       workspace=layer.workspace,
+                                       size_n=layer.output_size_per_partition,
+                                       size_k=layer.input_size_per_partition,
+                                       bias=bias)
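
Note on the new bookkeeping in create_weights: for fused modules (e.g. a
merged QKV projection), output_partition_sizes carries one width per logical
sub-layer, and the Marlin kernel consumes the fused GEMM shape. A minimal
sketch of what gets recorded, using made-up sizes (the 4096/512 values are
hypothetical, not from the patch):

    # Hypothetical fused QKV widths for one tensor-parallel rank.
    output_partition_sizes = [4096, 512, 512]
    input_size_per_partition = 4096

    logical_widths = output_partition_sizes  # per-sub-layer output widths
    size_n = sum(output_partition_sizes)     # 5120: fused output dimension
    size_k = input_size_per_partition        # 4096: reduction dimension

    # apply_weights later forwards exactly these values to
    # apply_fp4_marlin_linear as size_n / size_k.
    print(size_n, size_k)  # 5120 4096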
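Note on the 1 / weight_global_scale conversion: per the patch's own comment,
compressed-tensors stores the inverse of what the Marlin kernel expects, i.e.
the removed emulation path divided by the global scale during dequantization,
whereas Marlin multiplies by weight_scale_2. A minimal sketch of that
equivalence with made-up scalar values (none of these numbers come from the
patch):

    import torch

    weight_global_scale = torch.tensor(200.0)  # hypothetical ct global scale
    fp8_block_scale = torch.tensor(1.5)        # hypothetical per-group scale
    fp4_value = torch.tensor(3.0)              # one decoded FP4 element

    # Emulation convention: dequantization divides by the global scale.
    w_emulated = fp4_value * fp8_block_scale / weight_global_scale

    # Marlin convention: weight_scale_2 is a multiplier, hence the one-time
    # reciprocal taken in process_weights_after_loading.
    weight_scale_2 = 1 / weight_global_scale
    w_marlin = fp4_value * fp8_block_scale * weight_scale_2

    assert torch.allclose(w_emulated, w_marlin)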