Use NVFP4 Marlin for CompressedTensorsW4A16Fp4 (#18000)

Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Dipika <dipikasikka1@gmail.com>
Co-authored-by: Dipika <dipikasikka1@gmail.com>
Michael Goin 2025-05-12 20:07:34 -04:00 committed by GitHub
parent 9d7ea9dbbf
commit 307939f299
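
For context on the scale handling in this change: the new code's comment notes that compressed-tensors ("ct") stores the inverse of the per-tensor scale the Marlin kernel expects, which is why process_weights_after_loading now registers weight_scale_2 = 1 / weight_global_scale.max(). A rough, illustrative-only sketch of that relationship in plain PyTorch (toy shapes and an arbitrary scale value, not the vLLM kernel path):

import torch

# Illustrative-only sketch: each group of 16 FP4 weight values shares a block
# scale, and the whole tensor shares one per-tensor scale. The checkpoint
# stores the per-tensor scale in "divisor" form; Marlin wants the multiplier.
group_size = 16
n, k = 4, 32  # toy output/input sizes
w_fp4 = torch.randint(-6, 7, (n, k)).float()        # stand-in for unpacked FP4 values
block_scale = torch.rand(n, k // group_size) + 0.5  # stand-in for FP8 block scales
weight_global_scale = torch.tensor(1.5)             # arbitrary stand-in for the stored scale

# Emulation-style dequant divides by the stored per-tensor scale ...
w_dq = w_fp4 * block_scale.repeat_interleave(group_size, dim=1) / weight_global_scale

# ... while the Marlin path multiplies by weight_scale_2 = 1 / weight_global_scale.
weight_scale_2 = 1 / weight_global_scale
w_dq_marlin_form = (w_fp4 * block_scale.repeat_interleave(group_size, dim=1) *
                    weight_scale_2)
assert torch.allclose(w_dq, w_dq_marlin_form)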


@@ -2,13 +2,12 @@
 from typing import Callable, List, Optional
 
 import torch
-import torch.nn.functional as F
 from torch.nn.parameter import Parameter
 
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
-from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import (  # noqa: E501
-    dequantize_to_dtype)
+from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
+    apply_fp4_marlin_linear, prepare_fp4_layer_for_marlin)
 from vllm.model_executor.parameter import (GroupQuantScaleParameter,
                                            ModelWeightParameter,
                                            PerTensorScaleParameter)
@@ -31,6 +30,10 @@ class CompressedTensorsW4A16Fp4(CompressedTensorsScheme):
                        input_size_per_partition: int,
                        params_dtype: torch.dtype, weight_loader: Callable,
                        **kwargs):
+        output_size_per_partition = sum(output_partition_sizes)
+        layer.logical_widths = output_partition_sizes
+        layer.input_size_per_partition = input_size_per_partition
+        layer.output_size_per_partition = output_size_per_partition
 
         # Weight
         weight = ModelWeightParameter(data=torch.empty(
@@ -60,48 +63,30 @@ class CompressedTensorsW4A16Fp4(CompressedTensorsScheme):
         layer.register_parameter("weight_scale", weight_scale)
 
-    def swizzle_blockscale(self, scale: torch.tensor):
-        assert (scale.dtype == torch.float8_e4m3fn)
-        # Pad and blockwise interleave weight_scale
-        scale_ndim = scale.ndim
-        if scale.ndim == 2:
-            scale = scale.unsqueeze(0)
-        assert scale.ndim == 3
-        B, M, K = scale.shape
-        round_up_multiple = lambda x, m: (x + m - 1) // m * m
-        M_padded = round_up_multiple(M, 128)
-        K_padded = round_up_multiple(K, 4)
-        padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype)
-        padded_scale[:B, :M, :K] = scale
-        batches, rows, cols = padded_scale.shape
-        assert rows % 128 == 0
-        assert cols % 4 == 0
-        padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32,
-                                            cols // 4, 4)
-        swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
-        swizzled_scale = swizzled_scale.contiguous().cuda()
-        return (swizzled_scale.reshape(M, K)
-                if scale_ndim == 2 else swizzled_scale.reshape(B, M, K))
-
     def process_weights_after_loading(self, layer) -> None:
-
-        layer.weight_global_scale = Parameter(
-            layer.weight_global_scale.max().to(torch.float32),
-            requires_grad=False)
-        # Note: a post weight loading step but not required for the emulation
-        swizzled_weight_scale = self.swizzle_blockscale(layer.weight_scale)
-        layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
-                                                requires_grad=False)
+        # Process parameters for marlin repacking
+
+        # Rename weight_packed to weight that marlin expects
+        layer.weight = Parameter(layer.weight_packed.data, requires_grad=False)
+        del layer.weight_packed
+
+        # Rename weight_global_scale to weight_scale_2 that marlin expects
+        # Note: ct stores the inverse of what is expected by the marlin kernel
+        layer.weight_scale_2 = Parameter(
+            1 / layer.weight_global_scale.max().to(torch.float32),
+            requires_grad=False)
+        del layer.weight_global_scale
+
+        prepare_fp4_layer_for_marlin(layer)
 
     def apply_weights(self,
                       layer: torch.nn.Module,
                       x: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-
-        w_fp4 = layer.weight_packed.data
-        w_global_scale = layer.weight_global_scale
-        w_blockscale = layer.weight_scale_swizzled.data
-        w_dq = dequantize_to_dtype(w_fp4, w_blockscale, w_global_scale,
-                                   x.dtype, x.device, self.group_size)
-        out = F.linear(x, w_dq)
-        del w_dq, w_fp4, w_global_scale, w_blockscale
-        return out
+        return apply_fp4_marlin_linear(input=x,
+                                       weight=layer.weight,
+                                       weight_scale=layer.weight_scale,
+                                       weight_scale_2=layer.weight_scale_2,
+                                       workspace=layer.workspace,
+                                       size_n=layer.output_size_per_partition,
+                                       size_k=layer.input_size_per_partition,
+                                       bias=bias)
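
A note on the flow, for reference: apply_weights reads layer.workspace, which this scheme never creates itself, so prepare_fp4_layer_for_marlin is evidently responsible for repacking the renamed tensors into Marlin's layout and attaching that workspace buffer. Below is a minimal toy sketch of just the attribute renaming that process_weights_after_loading performs, with a bare torch.nn.Module standing in for the vLLM linear layer (shapes and scale value are made up):

import torch
from torch.nn.parameter import Parameter

# Toy stand-in for a vLLM linear layer right after weight loading.
layer = torch.nn.Module()
layer.weight_packed = Parameter(torch.empty(128, 64, dtype=torch.uint8),
                                requires_grad=False)
layer.weight_global_scale = Parameter(torch.tensor([1.5]), requires_grad=False)

# Rename weight_packed to the "weight" attribute the Marlin helpers expect.
layer.weight = Parameter(layer.weight_packed.data, requires_grad=False)
del layer.weight_packed

# Invert the per-tensor scale and expose it as "weight_scale_2".
layer.weight_scale_2 = Parameter(
    1 / layer.weight_global_scale.max().to(torch.float32),
    requires_grad=False)
del layer.weight_global_scale

# At this point the real scheme would call prepare_fp4_layer_for_marlin(layer).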