diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py
index 63bfe565b1211..49d76bbeaa3a1 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py
@@ -12,6 +12,8 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
 from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import (  # noqa: E501
     run_nvfp4_emulations)
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    swizzle_blockscale)
 from vllm.model_executor.parameter import (GroupQuantScaleParameter,
                                            ModelWeightParameter,
                                            PerTensorScaleParameter)
@@ -83,29 +85,6 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
             weight_loader=weight_loader)
         layer.register_parameter("input_global_scale", input_global_scale)

-    def swizzle_blockscale(self, scale: torch.tensor):
-        assert (scale.dtype == torch.float8_e4m3fn)
-        # Pad and blockwise interleave weight_scale
-        scale_ndim = scale.ndim
-        if scale.ndim == 2:
-            scale = scale.unsqueeze(0)
-        assert scale.ndim == 3
-        B, M, K = scale.shape
-        round_up_multiple = lambda x, m: (x + m - 1) // m * m
-        M_padded = round_up_multiple(M, 128)
-        K_padded = round_up_multiple(K, 4)
-        padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype)
-        padded_scale[:B, :M, :K] = scale
-        batches, rows, cols = padded_scale.shape
-        assert rows % 128 == 0
-        assert cols % 4 == 0
-        padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32,
-                                            cols // 4, 4)
-        swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
-        swizzled_scale = swizzled_scale.contiguous().cuda()
-        return (swizzled_scale.reshape(M, K)
-                if scale_ndim == 2 else swizzled_scale.reshape(B, M, K))
-
     def process_weights_after_loading(self, layer) -> None:

         global_input_scale = layer.input_global_scale.max().to(torch.float32)
@@ -137,7 +116,7 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
                                            requires_grad=False)
             layer.weight_packed = Parameter(weight, requires_grad=False)
         else:
-            swizzled_weight_scale = self.swizzle_blockscale(layer.weight_scale)
+            swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
             layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
                                                     requires_grad=False)
             layer.weight_packed = Parameter(layer.weight_packed.data,
diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py
index 3cfaca6230b12..97e5922ebd55f 100644
--- a/vllm/model_executor/layers/quantization/utils/quant_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py
@@ -552,8 +552,8 @@ def swizzle_blockscale(scale: torch.Tensor) -> torch.Tensor:
     swizzled = padded.permute(0, 1, 4, 3, 2, 5).contiguous().cuda()

     if scale_ndim == 2:
-        return swizzled.reshape(M, K)
-    return swizzled.reshape(B, M, K)
+        return swizzled.reshape(M_padded, K_padded)
+    return swizzled.reshape(B, M_padded, K_padded)


 def cutlass_fp4_supported() -> bool:
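
Why the quant_utils hunk matters: swizzle_blockscale pads M up to a multiple of 128 and K up to a multiple of 4 before interleaving, so the swizzled buffer physically holds M_padded x K_padded elements; reshaping it back to the unpadded (M, K) raises a RuntimeError whenever padding actually occurred. Below is a minimal standalone sketch of the shape arithmetic only (illustrative shapes and round_up helper, not the vLLM implementation; dtype checks and the .cuda() move are omitted):

import torch

def round_up(x: int, m: int) -> int:
    # Round x up to the next multiple of m.
    return (x + m - 1) // m * m

# Unaligned block-scale shape: neither dimension matches the 128x4 tile.
M, K = 100, 3
M_padded, K_padded = round_up(M, 128), round_up(K, 4)  # 128, 4

padded = torch.zeros((1, M_padded, K_padded))
# Tile into (B, M_padded//128, 4, 32, K_padded//4, 4), then interleave
# with the same permutation the patch uses.
tiles = padded.reshape(1, M_padded // 128, 4, 32, K_padded // 4, 4)
swizzled = tiles.permute(0, 1, 4, 3, 2, 5).contiguous()

assert swizzled.numel() == M_padded * K_padded  # 512 elements
swizzled.reshape(M_padded, K_padded)            # ok: 128 * 4 == 512
# swizzled.reshape(M, K) would fail: 100 * 3 == 300 != 512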