Fix nvfp4 swizzling (#23140)

Signed-off-by: yiliu30 <yi4.liu@intel.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
commit 0278f1ac3a (parent a482e4e769)
Author: Yi Liu
Date: 2025-08-22 00:54:50 +08:00 (committed by GitHub)
2 changed files with 5 additions and 26 deletions

vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py

@@ -12,6 +12,8 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
 from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import (  # noqa: E501
     run_nvfp4_emulations)
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    swizzle_blockscale)
 from vllm.model_executor.parameter import (GroupQuantScaleParameter,
                                            ModelWeightParameter,
                                            PerTensorScaleParameter)
@@ -83,29 +85,6 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
             weight_loader=weight_loader)
         layer.register_parameter("input_global_scale", input_global_scale)
 
-    def swizzle_blockscale(self, scale: torch.tensor):
-        assert (scale.dtype == torch.float8_e4m3fn)
-        # Pad and blockwise interleave weight_scale
-        scale_ndim = scale.ndim
-        if scale.ndim == 2:
-            scale = scale.unsqueeze(0)
-        assert scale.ndim == 3
-        B, M, K = scale.shape
-        round_up_multiple = lambda x, m: (x + m - 1) // m * m
-        M_padded = round_up_multiple(M, 128)
-        K_padded = round_up_multiple(K, 4)
-        padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype)
-        padded_scale[:B, :M, :K] = scale
-        batches, rows, cols = padded_scale.shape
-        assert rows % 128 == 0
-        assert cols % 4 == 0
-        padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32,
-                                            cols // 4, 4)
-        swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
-        swizzled_scale = swizzled_scale.contiguous().cuda()
-        return (swizzled_scale.reshape(M, K)
-                if scale_ndim == 2 else swizzled_scale.reshape(B, M, K))
-
     def process_weights_after_loading(self, layer) -> None:
         global_input_scale = layer.input_global_scale.max().to(torch.float32)
@@ -137,7 +116,7 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
                                                     requires_grad=False)
             layer.weight_packed = Parameter(weight, requires_grad=False)
         else:
-            swizzled_weight_scale = self.swizzle_blockscale(layer.weight_scale)
+            swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
             layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
                                                     requires_grad=False)
             layer.weight_packed = Parameter(layer.weight_packed.data,
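
The 23 lines removed above duplicated the shared swizzle_blockscale helper in quant_utils, which this commit imports and calls instead. A minimal sketch of the 128x4 block interleave both perform (the function name is mine; this unbatched, CPU-only variant uses float32 for illustration, while the real helper asserts float8_e4m3fn, handles a batch dimension, and moves the result to CUDA):

import torch

def swizzle_128x4(scale: torch.Tensor) -> torch.Tensor:
    """Interleave a 2-D (M, K) block-scale tensor in 128x4 tiles.

    Assumes M and K are already padded to multiples of 128 and 4.
    """
    M, K = scale.shape
    # View rows as (M/128 groups, 4, 32) and columns as (K/4 groups, 4).
    tiled = scale.reshape(M // 128, 4, 32, K // 4, 4)
    # Unbatched equivalent of permute(0, 1, 4, 3, 2, 5) on the batched
    # 6-D view used above: the column-group axis moves ahead of the
    # 4x32 row split before flattening back to (M, K).
    return tiled.permute(0, 3, 2, 1, 4).contiguous().reshape(M, K)

s = torch.arange(128 * 8, dtype=torch.float32).reshape(128, 8)
print(swizzle_128x4(s)[0])  # first row of the swizzled layout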

vllm/model_executor/layers/quantization/utils/quant_utils.py

@@ -552,8 +552,8 @@ def swizzle_blockscale(scale: torch.Tensor) -> torch.Tensor:
     swizzled = padded.permute(0, 1, 4, 3, 2, 5).contiguous().cuda()
     if scale_ndim == 2:
-        return swizzled.reshape(M, K)
-    return swizzled.reshape(B, M, K)
+        return swizzled.reshape(M_padded, K_padded)
+    return swizzled.reshape(B, M_padded, K_padded)
 
 
 def cutlass_fp4_supported() -> bool:
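
The two changed return lines are the actual fix: when M is not a multiple of 128 (or K of 4), the swizzled buffer holds M_padded * K_padded elements per batch, so reshaping back to the unpadded (M, K) raises a size mismatch. A minimal repro under assumed shapes (float32 stands in for float8_e4m3fn here, since the shape logic is dtype-independent, and the .cuda() call is dropped):

import torch

M, K = 100, 4                      # M not a multiple of 128
M_padded, K_padded = 128, 4        # round_up(M, 128), round_up(K, 4)

padded = torch.zeros(1, M_padded, K_padded)
padded[:, :M, :] = torch.randn(M, K)
swizzled = (padded.reshape(1, M_padded // 128, 4, 32, K_padded // 4, 4)
            .permute(0, 1, 4, 3, 2, 5)
            .contiguous())

try:
    swizzled.reshape(M, K)         # old return: 512 elements into 400
except RuntimeError as exc:
    print("reshape(M, K) fails:", exc)

print(swizzled.reshape(M_padded, K_padded).shape)  # fixed: torch.Size([128, 4])

Returning the padded shape reflects the true element count of the interleaved buffer, so callers now see the swizzled scale at its padded extent.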