From b866cdbd05b13e0c0ab349efc6fca834fbe21760 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Mon, 23 Dec 2024 13:23:38 -0500 Subject: [PATCH] [Misc] Add assertion and helpful message for marlin24 compressed models (#11388) --- .../compressed_tensors/schemes/compressed_tensors_w4a16_24.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py index 9ad61a64e406c..61d1c911cd1ad 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py @@ -61,6 +61,10 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme): params_dtype: torch.dtype, weight_loader: Callable, **kwargs): + assert params_dtype == torch.float16, ( + "float16 is required for marlin24 compressed models. Set dtype=torch.float16" # noqa: E501 + ) + pack_factor = 32 // self.quant_type.size_bits output_size_per_partition = sum(output_partition_sizes)