mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-10 05:47:03 +08:00
[Bugfix] Fix gptq failure on T4s (#7264)
This commit is contained in:
parent
469b3bc538
commit
311f743831
@ -126,8 +126,7 @@ class AWQMarlinConfig(QuantizationConfig):
|
||||
|
||||
return check_marlin_supported(quant_type=cls.TYPE_MAP[num_bits],
|
||||
group_size=group_size,
|
||||
has_zp=has_zp,
|
||||
min_capability=cls.get_min_capability())
|
||||
has_zp=has_zp)
|
||||
|
||||
|
||||
class AWQMarlinLinearMethod(LinearMethodBase):
|
||||
|
||||
@ -136,8 +136,7 @@ class GPTQMarlinConfig(QuantizationConfig):
|
||||
return False
|
||||
|
||||
return check_marlin_supported(quant_type=cls.TYPE_MAP[(num_bits, sym)],
|
||||
group_size=group_size,
|
||||
min_capability=cls.get_min_capability())
|
||||
group_size=group_size)
|
||||
|
||||
|
||||
class GPTQMarlinLinearMethod(LinearMethodBase):
|
||||
|
||||
@ -26,12 +26,13 @@ USE_FP32_REDUCE_DEFAULT = True
|
||||
# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ.
|
||||
# TODO: we may want to move this into the C++ so it's closer to the actual impl
|
||||
def query_marlin_supported_quant_types(has_zp: bool,
|
||||
min_capability: Optional[int] = None):
|
||||
if min_capability is None:
|
||||
device_capability: Optional[int] = None
|
||||
):
|
||||
if device_capability is None:
|
||||
major, minor = current_platform.get_device_capability()
|
||||
min_capability = major * 10 + minor
|
||||
device_capability = major * 10 + minor
|
||||
|
||||
if min_capability < 80:
|
||||
if device_capability < 80:
|
||||
return []
|
||||
|
||||
if has_zp:
|
||||
@ -48,20 +49,20 @@ def _check_marlin_supported(
|
||||
quant_type: ScalarType,
|
||||
group_size: Optional[int],
|
||||
has_zp: bool,
|
||||
min_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]:
|
||||
device_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]:
|
||||
|
||||
if min_capability is None:
|
||||
if device_capability is None:
|
||||
major, minor = current_platform.get_device_capability()
|
||||
min_capability = major * 10 + minor
|
||||
device_capability = major * 10 + minor
|
||||
|
||||
supported_types = query_marlin_supported_quant_types(
|
||||
has_zp, min_capability)
|
||||
has_zp, device_capability)
|
||||
|
||||
if quant_type not in supported_types:
|
||||
return (False, f"Marlin does not support weight_bits = {quant_type}. "
|
||||
f"Only types = {supported_types} "
|
||||
f"are supported (for group_size = {group_size}, "
|
||||
f"min_capability = {min_capability}, zp = {has_zp}).")
|
||||
f"device_capability = {device_capability}, zp = {has_zp}).")
|
||||
if (group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES):
|
||||
return (False, f"Marlin does not support group_size = {group_size}. "
|
||||
f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} "
|
||||
@ -73,9 +74,9 @@ def _check_marlin_supported(
|
||||
def check_marlin_supported(quant_type: ScalarType,
|
||||
group_size: int,
|
||||
has_zp: bool = False,
|
||||
min_capability: Optional[int] = None) -> bool:
|
||||
device_capability: Optional[int] = None) -> bool:
|
||||
cond, _ = _check_marlin_supported(quant_type, group_size, has_zp,
|
||||
min_capability)
|
||||
device_capability)
|
||||
return cond
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user