[Misc] Add assertion and helpful message for marlin24 compressed models (#11388)

This commit is contained in:
Dipika Sikka 2024-12-23 13:23:38 -05:00 committed by GitHub
parent 2e726680b3
commit b866cdbd05
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -61,6 +61,10 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
assert params_dtype == torch.float16, (
"float16 is required for marlin24 compressd models. Set dtype=torch.float16" # noqa: E501
)
pack_factor = 32 // self.quant_type.size_bits
output_size_per_partition = sum(output_partition_sizes)