From 2c2b140ae8c60dc0c38e4d37274fc7106a72c21b Mon Sep 17 00:00:00 2001
From: czhu-cohere
Date: Tue, 26 Aug 2025 21:23:23 -0400
Subject: [PATCH] [quantization] use channel scales for w4a8 + misc fixes
 (#23570)

Signed-off-by: czhu-cohere
---
 tests/quantization/test_compressed_tensors.py | 44 +++++++++++++++++--
 .../schemes/compressed_tensors_w4a8_fp8.py    | 13 +++++-
 .../kernels/mixed_precision/MPLinearKernel.py |  1 +
 .../kernels/mixed_precision/cutlass.py        | 19 ++++----
 4 files changed, 63 insertions(+), 14 deletions(-)

diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py
index 296743dbfa04..b9774b7ee263 100644
--- a/tests/quantization/test_compressed_tensors.py
+++ b/tests/quantization/test_compressed_tensors.py
@@ -14,10 +14,10 @@ from compressed_tensors.quantization import QuantizationType
 from tests.models.utils import check_logprobs_close
 from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
     CompressedTensors24, CompressedTensorsLinearMethod,
-    CompressedTensorsW4A4Fp4, CompressedTensorsW4A16Fp4,
-    CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8,
-    CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8,
-    CompressedTensorsWNA16)
+    CompressedTensorsW4A4Fp4, CompressedTensorsW4A8Fp8,
+    CompressedTensorsW4A16Fp4, CompressedTensorsW4A16Sparse24,
+    CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
+    CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     cutlass_fp4_supported)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
@@ -683,3 +683,39 @@ def test_compressed_tensors_nvfp4(vllm_runner, args):
     output = llm.generate_greedy("Hello my name is", max_tokens=20)
     print(output)
     assert output
+
+
+@pytest.mark.skipif(
+    not current_platform.is_cuda()
+    or not current_platform.has_device_capability(90),
+    reason="W4A8 FP8 is not yet supported on this GPU type.",
+)
+@pytest.mark.parametrize("args", [
+    ("czhu-cohere/TinyLlama-1.1B-Chat-v1.0-W4A8-e2e", CompressedTensorsW4A8Fp8)
+])
+def test_compressed_tensors_w4a8_fp8(vllm_runner, args):
+    model, scheme = args
+    with vllm_runner(model, enforce_eager=True) as llm:
+
+        def check_model(model):
+            layer = model.model.layers[0]
+
+            qkv_proj = layer.self_attn.qkv_proj
+            o_proj = layer.self_attn.o_proj
+            gate_up_proj = layer.mlp.gate_up_proj
+            down_proj = layer.mlp.down_proj
+
+            for proj in (qkv_proj, o_proj, gate_up_proj, down_proj):
+                assert isinstance(proj.quant_method,
+                                  CompressedTensorsLinearMethod)
+                assert isinstance(proj.scheme, scheme)
+
+                assert proj.weight_packed.dtype is torch.int32
+                assert proj.weight_scale.dtype is torch.float8_e4m3fn
+                assert proj.weight_chan_scale.dtype is torch.float32
+                assert proj.scheme.group_size == 128
+
+        llm.apply_model(check_model)
+        output = llm.generate_greedy("Hello my name is", max_tokens=20)
+        print(output)
+        assert output
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_fp8.py
index f6cc49c2316b..3d9827058803 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_fp8.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_fp8.py
@@ -79,7 +79,8 @@ class CompressedTensorsW4A8Fp8(CompressedTensorsScheme):
             act_type=torch.float8_e4m3fn,  # always use fp8(e4m3)
             group_size=self.group_size,
             zero_points=not self.symmetric,
-            has_g_idx=self.has_g_idx
+            has_g_idx=self.has_g_idx,
+            out_type=params_dtype
         )

         kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config)
@@ -122,7 +123,7 @@ class CompressedTensorsW4A8Fp8(CompressedTensorsScheme):
             torch.empty(
                 output_size_per_partition,
                 scales_and_zp_size,
-                dtype=params_dtype,
+                dtype=torch.float8_e4m3fn,
             )
         }

@@ -140,9 +141,17 @@ class CompressedTensorsW4A8Fp8(CompressedTensorsScheme):
                                          dtype=torch.int64),
                                weight_loader=weight_loader)

+        # per-channel scales
+        weight_chan_scale = ChannelQuantScaleParameter(
+            data=torch.empty((output_size_per_partition, 1),
+                             dtype=torch.float32),
+            output_dim=0,
+            weight_loader=weight_loader)
+
         layer.register_parameter("weight_packed", weight)
         layer.register_parameter("weight_scale", weight_scale)
         layer.register_parameter("weight_shape", weight_shape)
+        layer.register_parameter("weight_chan_scale", weight_chan_scale)

         self.kernel = kernel_type(mp_linear_kernel_config,
                                   w_q_param_name="weight_packed",
diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py
index 07ecc096231a..1280f5f1eadf 100644
--- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py
+++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py
@@ -20,6 +20,7 @@ class MPLinearLayerConfig:
     group_size: int
     zero_points: bool
     has_g_idx: bool
+    out_type: Optional[torch.dtype] = None


 class MPLinearKernel(ABC):
diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/cutlass.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/cutlass.py
index f1d49693fc01..9e23c0dd3595 100644
--- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/cutlass.py
+++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/cutlass.py
@@ -60,13 +60,17 @@ class CutlassW4A8LinearKernel(MPLinearKernel):
         if in_features % 128 or out_features % 128:
             return False, "K and N must be divisible by 128, got "\
                 f"{c.partition_weight_shape}"
+
+        if c.out_type != torch.bfloat16:
+            return False, "Only bfloat16 output type currently supported, "\
+                f"got {c.out_type=}"
+
         return True, None

     # note assumes that
     # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
     # `weight_scale` is: {input_dim = 0, output_dim = 1}
     def process_weights_after_loading(self, layer: torch.nn.Module):
-        c = self.config

         # TODO(czhu): optimize speed/mem usage
         def transform_w_q(x):
@@ -86,19 +90,15 @@ class CutlassW4A8LinearKernel(MPLinearKernel):
         # Encode/reorder weights and pack scales
         self._transform_param(layer, self.w_q_name, transform_w_q)
         self._transform_param(layer, self.w_s_name, transform_w_s)
-
-        # TODO(czhu): support loading channel scales
-        self.w_ch_s = torch.ones((c.partition_weight_shape[1], ),
-                                 dtype=torch.float32,
-                                 device='cuda')
+        self._transform_param(layer, "weight_chan_scale", lambda x: x)

     def apply_weights(self,
                       layer: torch.nn.Module,
                       x: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-        assert bias is None, "bias not supported by CUTLASS W4A8"
         c = self.config
         w_q, w_s, _, _ = self._get_weight_params(layer)
+        w_ch_s = layer.weight_chan_scale

         x_2d = x.reshape(-1, x.shape[-1])
         out_shape = x.shape[:-1] + (c.partition_weight_shape[1], )
@@ -109,6 +109,9 @@ class CutlassW4A8LinearKernel(MPLinearKernel):
                               b_group_scales=w_s,
                               b_group_size=c.group_size,
                               a_token_scales=act_scales,
-                              b_channel_scales=self.w_ch_s)
+                              b_channel_scales=w_ch_s)
+
+        if bias is not None:
+            output.add_(bias)  # In-place add

         return output.reshape(out_shape)
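
For reference, a minimal offline-inference sketch for the W4A8 checkpoint exercised by the new test. It assumes vLLM's public LLM/SamplingParams API and a bfloat16 model dtype (matching the bfloat16-only output constraint of the CUTLASS W4A8 kernel above), and requires an SM90 GPU:

# Sketch only: load the W4A8-FP8 TinyLlama checkpoint and run a greedy
# generation, mirroring what test_compressed_tensors_w4a8_fp8 does via
# vllm_runner. dtype="bfloat16" is an assumption based on the kernel's
# bfloat16 output requirement.
from vllm import LLM, SamplingParams

llm = LLM(model="czhu-cohere/TinyLlama-1.1B-Chat-v1.0-W4A8-e2e",
          dtype="bfloat16",
          enforce_eager=True)
params = SamplingParams(temperature=0.0, max_tokens=20)  # greedy decoding
outputs = llm.generate(["Hello my name is"], params)
print(outputs[0].outputs[0].text)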