[refactor] CTConfig methods to static/class methods (#28870)

Signed-off-by: HDCharles <charlesdavidhernandez@gmail.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
HDCharles authored 2025-11-26 12:21:58 -05:00 (committed by GitHub)
parent 0b0aa874e8
commit e603129505

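The change is mechanical throughout the class: predicate helpers that never read instance state drop self and become @staticmethod, while the SM90/SM100 capability checks become @classmethod because they dispatch to sibling helpers through cls (e.g. cls._check_scheme_supported). A minimal sketch of the before/after shape, using hypothetical _helper/_uses_helper names rather than anything from this diff:

class Before:
    def _helper(self, x: int) -> bool:
        # Never touches self, so the instance is dead weight.
        return x > 0

    def _uses_helper(self, x: int) -> bool:
        return self._helper(x)

class After:
    @staticmethod
    def _helper(x: int) -> bool:
        # No self: callable as After._helper(3) with no instance.
        return x > 0

    @classmethod
    def _uses_helper(cls, x: int) -> bool:
        # cls keeps dispatch going through the class, so a subclass
        # overriding _helper is still respected.
        return cls._helper(x)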

@@ -266,8 +266,9 @@ class CompressedTensorsConfig(QuantizationConfig):
     def get_config_filenames(cls) -> list[str]:
         return []
 
+    @staticmethod
     def _check_scheme_supported(
-        self, min_capability: int, error: bool = True, match_exact: bool = False
+        min_capability: int, error: bool = True, match_exact: bool = False
     ) -> bool:
         capability_tuple = current_platform.get_device_capability()
@@ -293,9 +294,8 @@ class CompressedTensorsConfig(QuantizationConfig):
         else:
             return False
 
-    def _is_fp4a4_nvfp4(
-        self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs
-    ):
+    @staticmethod
+    def _is_fp4a4_nvfp4(weight_quant: QuantizationArgs, input_quant: QuantizationArgs):
         if weight_quant is None or input_quant is None:
             return False
@@ -322,9 +322,8 @@ class CompressedTensorsConfig(QuantizationConfig):
             and is_symmetric
         )
 
-    def _is_fp4a16_nvfp4(
-        self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs
-    ):
+    @staticmethod
+    def _is_fp4a16_nvfp4(weight_quant: QuantizationArgs, input_quant: QuantizationArgs):
         is_weight_only = weight_quant is not None and input_quant is None
         is_tensor_group_quant = (
             weight_quant.strategy == QuantizationStrategy.TENSOR_GROUP.value
@@ -344,8 +343,9 @@ class CompressedTensorsConfig(QuantizationConfig):
             and is_symmetric
         )
 
+    @staticmethod
     def _is_static_tensor_w8a8(
-        self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs
+        weight_quant: QuantizationArgs, input_quant: QuantizationArgs
     ) -> bool:
         is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
         weight_strategy = (
@@ -362,8 +362,9 @@ class CompressedTensorsConfig(QuantizationConfig):
         # Only symmetric weight quantization supported.
         return is_8_bits and is_tensor and weight_quant.symmetric and is_static
 
+    @staticmethod
     def _is_dynamic_token_w8a8(
-        self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs
+        weight_quant: QuantizationArgs, input_quant: QuantizationArgs
     ) -> bool:
         is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
         weight_strategy = (
@@ -379,8 +380,9 @@ class CompressedTensorsConfig(QuantizationConfig):
         # Only symmetric weight quantization supported.
         return is_8_bits and is_token and weight_quant.symmetric and is_dynamic
 
+    @staticmethod
     def _is_dynamic_token_w4a8_int(
-        self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs
+        weight_quant: QuantizationArgs, input_quant: QuantizationArgs
     ) -> bool:
         is_weight_4_bits = weight_quant.num_bits == 4
         is_activation_8_bits = input_quant.num_bits == 8
@@ -403,8 +405,9 @@ class CompressedTensorsConfig(QuantizationConfig):
             and is_dynamic
         )
 
+    @staticmethod
     def _is_fp8_w8a8(
-        self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs
+        weight_quant: QuantizationArgs, input_quant: QuantizationArgs
     ) -> bool:
         # Confirm weights and activations quantized.
         if weight_quant is None or input_quant is None:
@@ -439,8 +442,9 @@ class CompressedTensorsConfig(QuantizationConfig):
         is_per_tensor_activation = input_quant.strategy == QuantizationStrategy.TENSOR
         return is_symmetric_activation and is_per_tensor_activation
 
+    @staticmethod
     def _is_fp8_w4a8(
-        self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs
+        weight_quant: QuantizationArgs, input_quant: QuantizationArgs
     ) -> bool:
         if not weight_quant or not input_quant:
             return False
@@ -462,29 +466,33 @@ class CompressedTensorsConfig(QuantizationConfig):
             and is_dynamic
         )
 
+    @classmethod
     def _is_fp8_w4a8_sm90(
-        self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs
+        cls, weight_quant: QuantizationArgs, input_quant: QuantizationArgs
     ) -> bool:
-        return self._check_scheme_supported(
+        return cls._check_scheme_supported(
             90, error=False, match_exact=True
-        ) and self._is_fp8_w4a8(weight_quant, input_quant)
+        ) and cls._is_fp8_w4a8(weight_quant, input_quant)
 
+    @classmethod
     def _is_fp8_w8a8_sm90(
-        self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs
+        cls, weight_quant: QuantizationArgs, input_quant: QuantizationArgs
     ) -> bool:
-        return self._check_scheme_supported(
+        return cls._check_scheme_supported(
             90, error=False, match_exact=True
-        ) and self._is_fp8_w8a8(weight_quant, input_quant)
+        ) and cls._is_fp8_w8a8(weight_quant, input_quant)
 
+    @classmethod
     def _is_fp8_w8a8_sm100(
-        self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs
+        cls, weight_quant: QuantizationArgs, input_quant: QuantizationArgs
     ) -> bool:
-        return self._check_scheme_supported(
+        return cls._check_scheme_supported(
             100, error=False, match_exact=True
-        ) and self._is_fp8_w8a8(weight_quant, input_quant)
+        ) and cls._is_fp8_w8a8(weight_quant, input_quant)
 
+    @staticmethod
     def _is_fp8_w8a16(
-        self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs
+        weight_quant: QuantizationArgs, input_quant: QuantizationArgs
     ) -> bool:
         # Confirm weights quantized.
         if weight_quant is None:
@@ -508,8 +516,9 @@ class CompressedTensorsConfig(QuantizationConfig):
             and is_tensor_or_channel_or_block_weight
         )
 
+    @staticmethod
     def _is_wNa16_group_channel(
-        self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs
+        weight_quant: QuantizationArgs, input_quant: QuantizationArgs
    ) -> bool:
         input_quant_none = input_quant is None
         is_channel_group = (
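Existing call sites need no changes for this to be safe: Python resolves static and class methods through instances as well as through the class. A hedged usage sketch (assuming config is an already-constructed CompressedTensorsConfig and weight_quant/input_quant are QuantizationArgs; these bindings are illustrative, not from the diff):

# Both spellings resolve to the same underlying function after the
# refactor; attribute lookup on the instance falls back to the class.
via_instance = config._is_fp8_w8a8(weight_quant, input_quant)
via_class = CompressedTensorsConfig._is_fp8_w8a8(weight_quant, input_quant)
assert via_instance == via_class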