[ Misc ] Improve Min Capability Checking in compressed-tensors (#6522)

Robert Shaw, 2024-07-18 10:39:12 -04:00 (committed by GitHub)
parent 4634c8728b
commit 58ca663224
7 changed files with 41 additions and 8 deletions


@@ -37,7 +37,7 @@ class CompressedTensorsConfig(QuantizationConfig):
     @classmethod
     def get_min_capability(cls) -> int:
-        return 75
+        return 70

     def get_name(self) -> str:
         return "compressed_tensors"
@@ -85,13 +85,14 @@ class CompressedTensorsConfig(QuantizationConfig):
     def get_config_filenames(cls) -> List[str]:
         return []

-    def _check_gptq_and_marlin_can_run(self):
+    def _check_scheme_supported(self, min_capability: int):
         capability = current_platform.get_device_capability()
         capability = capability[0] * 10 + capability[1]
-        if capability < 80:
-            raise RuntimeError("The quantization config is not supported for ",
-                               "the current GPU. Minimum capability: 80. ",
-                               f"Current capability: {capability}.")
+        if capability < min_capability:
+            raise RuntimeError(
+                "Quantization scheme is not supported for ",
+                f"the current GPU. Min capability: {min_capability}. ",
+                f"Current capability: {capability}.")

     def _is_static_tensor_w8a8(self, weight_quant: BaseModel,
                                input_quant: BaseModel) -> bool:
@@ -171,7 +172,6 @@ class CompressedTensorsConfig(QuantizationConfig):
         # Detect If Mixed Precision
         if self._is_wNa16_group_channel(weight_quant, input_quant):
-            self._check_gptq_and_marlin_can_run()
             if (self.quant_format == CompressionFormat.marlin_24.value
                     and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS):
                 return CompressedTensorsW4A16Sparse24(
@@ -222,10 +222,16 @@ class CompressedTensorsConfig(QuantizationConfig):
            raise ValueError(
                f"Could not find quantization details for {layer}.")

-        return self._get_schema(
+        scheme = self._get_schema(
             weight_quant=layer_quant_details["weights"],
             input_quant=layer_quant_details["input_activations"])

+        # Raise error if device does not support the scheme
+        # (e.g. fp8 needs ada lovelace)
+        self._check_scheme_supported(scheme.get_min_capability())
+
+        return scheme
+

 class CompressedTensorsLinearMethod(LinearMethodBase):
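
For reference, the capability value compared against each scheme's minimum is the CUDA compute capability flattened as major * 10 + minor. Below is a minimal standalone sketch of the check introduced above; it is illustrative only, the function name check_min_capability is not part of this commit, it assumes a CUDA device is visible, and it uses torch directly in place of vLLM's current_platform helper.

import torch

def check_min_capability(min_capability: int) -> None:
    # Mirrors the _check_scheme_supported logic in this diff: flatten
    # (major, minor) into a single integer, e.g. A100 (8, 0) -> 80,
    # L4/L40 (8, 9) -> 89, H100 (9, 0) -> 90.
    major, minor = torch.cuda.get_device_capability()
    capability = major * 10 + minor
    if capability < min_capability:
        raise RuntimeError(
            f"Quantization scheme is not supported for the current GPU. "
            f"Min capability: {min_capability}. "
            f"Current capability: {capability}.")

# Example: the FP8 scheme below reports a minimum of 89 (Ada Lovelace), so
# check_min_capability(89) raises on an A100 (capability 80) but passes on
# an H100 (capability 90).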


@@ -12,6 +12,13 @@ class CompressedTensorsScheme(ABC):
     of different quantization schemes supported by CompressedTensors.
     """

+    @abstractmethod
+    def get_min_capability(self) -> int:
+        """
+        Get minimum device capability.
+        """
+        raise NotImplementedError
+
     @abstractmethod
     def create_weights(self, *args, **kwargs):
         """


@@ -18,6 +18,10 @@ class CompressedTensorsUnquantized(CompressedTensorsScheme):
     in a linear transformation.
     """

+    def get_min_capability(self) -> int:
+        # volta and up
+        return 70
+
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         pass


@@ -29,6 +29,10 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
             raise ValueError(
                 "group_size must be given when using strategy group")

+    def get_min_capability(self) -> int:
+        # ampere + up
+        return 80
+
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         pass


@@ -33,6 +33,10 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
                 "Consider quantizing with per tensor scales or upgrading "
                 "to Hopper.")

+    def get_min_capability(self) -> int:
+        # lovelace and up
+        return 89
+
     def process_weights_after_loading(self, layer) -> None:
         # If per tensor, when we have a fused module (e.g. QKV) with per
         # tensor scales (thus N scales being passed to the kernel),


@@ -19,6 +19,10 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
         self.strategy = strategy
         self.is_static_input_scheme = is_static_input_scheme

+    def get_min_capability(self) -> int:
+        # turing and up
+        return 75
+
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         # WEIGHT
         # Cutlass kernels need transposed weight.


@@ -42,6 +42,10 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
             group_size=self.group_size,
             is_sym=True)

+    def get_min_capability(self) -> int:
+        # ampere and up
+        return 80
+
     def create_weights(self, layer: torch.nn.Module, input_size: int,
                        output_partition_sizes: List[int],
                        input_size_per_partition: int,
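
Taken together, the per-scheme minimums introduced in this commit map onto NVIDIA GPU generations as summarized below. This is a quick reference only, not code from the diff; the dict name and the example GPUs are illustrative.

# Quick reference for the thresholds above (illustrative, not part of the commit).
SCHEME_MIN_CAPABILITY = {
    "CompressedTensorsUnquantized": 70,    # Volta and up (e.g. V100)
    "CompressedTensorsW8A8Int8": 75,       # Turing and up (e.g. T4)
    "CompressedTensorsWNA16": 80,          # Ampere and up (e.g. A100)
    "CompressedTensorsW4A16Sparse24": 80,  # Ampere and up
    "CompressedTensorsW8A8Fp8": 89,        # Ada Lovelace and up (e.g. L4, L40)
}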