Use NVFP4 Marlin for CompressedTensorsW4A16Fp4 (#18000)

Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Dipika <dipikasikka1@gmail.com>
Co-authored-by: Dipika <dipikasikka1@gmail.com>
This commit is contained in:
Michael Goin 2025-05-12 20:07:34 -04:00 committed by GitHub
parent 9d7ea9dbbf
commit 307939f299
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -2,13 +2,12 @@
from typing import Callable, List, Optional from typing import Callable, List, Optional
import torch import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme) CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( # noqa: E501 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
dequantize_to_dtype) apply_fp4_marlin_linear, prepare_fp4_layer_for_marlin)
from vllm.model_executor.parameter import (GroupQuantScaleParameter, from vllm.model_executor.parameter import (GroupQuantScaleParameter,
ModelWeightParameter, ModelWeightParameter,
PerTensorScaleParameter) PerTensorScaleParameter)
@ -31,6 +30,10 @@ class CompressedTensorsW4A16Fp4(CompressedTensorsScheme):
input_size_per_partition: int, input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable, params_dtype: torch.dtype, weight_loader: Callable,
**kwargs): **kwargs):
output_size_per_partition = sum(output_partition_sizes)
layer.logical_widths = output_partition_sizes
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
# Weight # Weight
weight = ModelWeightParameter(data=torch.empty( weight = ModelWeightParameter(data=torch.empty(
@ -60,48 +63,30 @@ class CompressedTensorsW4A16Fp4(CompressedTensorsScheme):
layer.register_parameter("weight_scale", weight_scale) layer.register_parameter("weight_scale", weight_scale)
def swizzle_blockscale(self, scale: torch.tensor):
assert (scale.dtype == torch.float8_e4m3fn)
# Pad and blockwise interleave weight_scale
scale_ndim = scale.ndim
if scale.ndim == 2:
scale = scale.unsqueeze(0)
assert scale.ndim == 3
B, M, K = scale.shape
round_up_multiple = lambda x, m: (x + m - 1) // m * m
M_padded = round_up_multiple(M, 128)
K_padded = round_up_multiple(K, 4)
padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype)
padded_scale[:B, :M, :K] = scale
batches, rows, cols = padded_scale.shape
assert rows % 128 == 0
assert cols % 4 == 0
padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32,
cols // 4, 4)
swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
swizzled_scale = swizzled_scale.contiguous().cuda()
return (swizzled_scale.reshape(M, K)
if scale_ndim == 2 else swizzled_scale.reshape(B, M, K))
def process_weights_after_loading(self, layer) -> None: def process_weights_after_loading(self, layer) -> None:
layer.weight_global_scale = Parameter( # Process parameters for marlin repacking
layer.weight_global_scale.max().to(torch.float32),
# Rename weight_packed to weight that marlin expects
layer.weight = Parameter(layer.weight_packed.data, requires_grad=False)
del layer.weight_packed
# Rename weight_global_scale to weight_scale_2 that marlin expects
# Note: ct stores the inverse of what is expected by the marlin kernel
layer.weight_scale_2 = Parameter(
1 / layer.weight_global_scale.max().to(torch.float32),
requires_grad=False) requires_grad=False)
# Note: a post weight loading step but not required for the emulation del layer.weight_global_scale
swizzled_weight_scale = self.swizzle_blockscale(layer.weight_scale)
layer.weight_scale_swizzled = Parameter(swizzled_weight_scale, prepare_fp4_layer_for_marlin(layer)
requires_grad=False)
def apply_weights(self, def apply_weights(self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor: bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return apply_fp4_marlin_linear(input=x,
w_fp4 = layer.weight_packed.data weight=layer.weight,
w_global_scale = layer.weight_global_scale weight_scale=layer.weight_scale,
w_blockscale = layer.weight_scale_swizzled.data weight_scale_2=layer.weight_scale_2,
w_dq = dequantize_to_dtype(w_fp4, w_blockscale, w_global_scale, workspace=layer.workspace,
x.dtype, x.device, self.group_size) size_n=layer.output_size_per_partition,
out = F.linear(x, w_dq) size_k=layer.input_size_per_partition,
del w_dq, w_fp4, w_global_scale, w_blockscale bias=bias)
return out