[Bugfix] Fix gptq failure on T4s (#7264)

This commit is contained in:
Lucas Wilkinson 2024-08-07 16:05:37 -04:00 committed by GitHub
parent 469b3bc538
commit 311f743831
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 14 additions and 15 deletions

View File

@ -126,8 +126,7 @@ class AWQMarlinConfig(QuantizationConfig):
return check_marlin_supported(quant_type=cls.TYPE_MAP[num_bits],
group_size=group_size,
has_zp=has_zp,
min_capability=cls.get_min_capability())
has_zp=has_zp)
class AWQMarlinLinearMethod(LinearMethodBase):

View File

@ -136,8 +136,7 @@ class GPTQMarlinConfig(QuantizationConfig):
return False
return check_marlin_supported(quant_type=cls.TYPE_MAP[(num_bits, sym)],
group_size=group_size,
min_capability=cls.get_min_capability())
group_size=group_size)
class GPTQMarlinLinearMethod(LinearMethodBase):

View File

@ -26,12 +26,13 @@ USE_FP32_REDUCE_DEFAULT = True
# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ.
# TODO: we may want to move this into the C++ so its closer to the actual impl
def query_marlin_supported_quant_types(has_zp: bool,
min_capability: Optional[int] = None):
if min_capability is None:
device_capability: Optional[int] = None
):
if device_capability is None:
major, minor = current_platform.get_device_capability()
min_capability = major * 10 + minor
device_capability = major * 10 + minor
if min_capability < 80:
if device_capability < 80:
return []
if has_zp:
@ -48,20 +49,20 @@ def _check_marlin_supported(
quant_type: ScalarType,
group_size: Optional[int],
has_zp: bool,
min_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]:
device_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]:
if min_capability is None:
if device_capability is None:
major, minor = current_platform.get_device_capability()
min_capability = major * 10 + minor
device_capability = major * 10 + minor
supported_types = query_marlin_supported_quant_types(
has_zp, min_capability)
has_zp, device_capability)
if quant_type not in supported_types:
return (False, f"Marlin does not support weight_bits = {quant_type}. "
f"Only types = {supported_types} "
f"are supported (for group_size = {group_size}, "
f"min_capability = {min_capability}, zp = {has_zp}).")
f"device_capability = {device_capability}, zp = {has_zp}).")
if (group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES):
return (False, f"Marlin does not support group_size = {group_size}. "
f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} "
@ -73,9 +74,9 @@ def _check_marlin_supported(
def check_marlin_supported(quant_type: ScalarType,
group_size: int,
has_zp: bool = False,
min_capability: Optional[int] = None) -> bool:
device_capability: Optional[int] = None) -> bool:
cond, _ = _check_marlin_supported(quant_type, group_size, has_zp,
min_capability)
device_capability)
return cond