[Bugfix] Respect modules_to_not_convert within awq_marlin (#9895)
Signed-off-by: mgoin <michael@neuralmagic.com>
parent 2094062b4e
commit 8f0a9ca890
--- a/vllm/model_executor/layers/quantization/awq_marlin.py
+++ b/vllm/model_executor/layers/quantization/awq_marlin.py
@@ -9,7 +9,9 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
+                                               UnquantizedLinearMethod,
                                                set_weight_attrs)
+from vllm.model_executor.layers.quantization.awq import is_layer_skipped_awq
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.utils import replace_parameter
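The two new imports carry the fix: `is_layer_skipped_awq` decides whether a layer's prefix matches an entry in `modules_to_not_convert`, and `UnquantizedLinearMethod` is the fallback applied to matching layers. As a rough mental model (a simplified sketch, not vLLM's actual class), an unquantized linear method keeps full-precision weights and runs a plain matmul:

```python
from typing import Optional

import torch
import torch.nn.functional as F


class UnquantizedLinearSketch:
    """Simplified stand-in for UnquantizedLinearMethod: no weight packing,
    no dequant kernels, just a standard fp16/bf16 linear."""

    def apply(self, weight: torch.Tensor, x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        return F.linear(x, weight, bias)
```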
@@ -36,13 +38,18 @@ class AWQMarlinConfig(QuantizationConfig):
         8: scalar_types.uint8,
     }
 
-    def __init__(self, weight_bits: int, group_size: int, has_zp: bool,
-                 lm_head_quantized: bool) -> None:
+    def __init__(self,
+                 weight_bits: int,
+                 group_size: int,
+                 zero_point: bool,
+                 lm_head_quantized: bool,
+                 modules_to_not_convert: Optional[List[str]] = None) -> None:
         self.pack_factor = 32 // weight_bits  # packed into int32
         self.group_size = group_size
-        self.has_zp = has_zp
+        self.zero_point = zero_point
         self.lm_head_quantized = lm_head_quantized
         self.weight_bits = weight_bits
+        self.modules_to_not_convert = modules_to_not_convert or []
 
         if self.weight_bits not in self.TYPE_MAP:
             raise ValueError(f"Unsupported num_bits = {self.weight_bits}. "
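For reference, the new `modules_to_not_convert` entry comes straight from the checkpoint's quantization config. A config that exercises this path looks roughly like the sketch below (illustrative values, not taken from a specific model):

```python
# Illustrative AutoAWQ-style quantization_config; the exact module names
# depend on the model, e.g. a multimodal model may exclude its vision tower.
quant_config = {
    "quant_method": "awq",
    "bits": 4,
    "group_size": 128,
    "zero_point": True,
    # Layers whose prefixes match these entries stay in full precision.
    "modules_to_not_convert": ["visual"],
}
```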
@@ -52,13 +59,14 @@ class AWQMarlinConfig(QuantizationConfig):
 
         verify_marlin_supported(self.quant_type,
                                 group_size=self.group_size,
-                                has_zp=self.has_zp)
+                                has_zp=self.zero_point)
 
     def __repr__(self) -> str:
         return (f"AWQMarlinConfig(quant_type={self.quant_type}, "
                 f"group_size={self.group_size}, "
-                f"has_zp={self.has_zp}, "
-                f"lm_head_quantized={self.lm_head_quantized})")
+                f"zero_point={self.zero_point}, "
+                f"lm_head_quantized={self.lm_head_quantized}, "
+                f"modules_to_not_convert={self.modules_to_not_convert})")
 
     @classmethod
     def get_name(cls) -> str:
@@ -80,10 +88,13 @@ class AWQMarlinConfig(QuantizationConfig):
     def from_config(cls, config: Dict[str, Any]) -> "AWQMarlinConfig":
         weight_bits = cls.get_from_keys(config, ["bits"])
         group_size = cls.get_from_keys(config, ["group_size"])
-        has_zp = cls.get_from_keys(config, ["zero_point"])
+        zero_point = cls.get_from_keys(config, ["zero_point"])
         lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
                                                  default=False)
-        return cls(weight_bits, group_size, has_zp, lm_head_quantized)
+        modules_to_not_convert = cls.get_from_keys_or(
+            config, ["modules_to_not_convert"], None)
+        return cls(weight_bits, group_size, zero_point, lm_head_quantized,
+                   modules_to_not_convert)
 
     @classmethod
     def override_quantization_method(cls, hf_quant_cfg,
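Given a dict like the one sketched above, the updated `from_config` threads the extra key through to the constructor; a missing key falls back to `None`, which `__init__` normalizes to an empty list. A hypothetical usage:

```python
# Reusing the illustrative quant_config from above:
cfg = AWQMarlinConfig.from_config(quant_config)
assert cfg.zero_point is True
assert cfg.modules_to_not_convert == ["visual"]

# Without the key, nothing is skipped:
quant_config.pop("modules_to_not_convert")
assert AWQMarlinConfig.from_config(quant_config).modules_to_not_convert == []
```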
@@ -109,6 +120,8 @@ class AWQMarlinConfig(QuantizationConfig):
                          prefix: str) -> Optional["QuantizeMethodBase"]:
         if (isinstance(layer, LinearBase) or
             (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
+            if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
+                return UnquantizedLinearMethod()
             return AWQMarlinLinearMethod(self)
         elif isinstance(layer, FusedMoE):
             return AWQMoEMethod(self)
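This hunk is the behavioral core of the bugfix: before handing a linear layer to the Marlin kernel, the dispatch first checks the layer's prefix against the skip list. `is_layer_skipped_awq` lives in vLLM's `awq.py`; to a first approximation it is a substring match of each listed name against the dotted layer prefix. A sketch under that assumption:

```python
from typing import List


def is_layer_skipped_awq_sketch(prefix: str,
                                modules_to_not_convert: List[str]) -> bool:
    # Assumed behavior: a layer is skipped if any listed module name
    # occurs in its fully qualified (dotted) prefix.
    return any(name in prefix for name in modules_to_not_convert)


assert is_layer_skipped_awq_sketch("visual.blocks.0.attn.qkv", ["visual"])
assert not is_layer_skipped_awq_sketch("model.layers.0.mlp.gate_up_proj",
                                       ["visual"])
```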
@@ -123,7 +136,7 @@ class AWQMarlinConfig(QuantizationConfig):
         quant_method = quant_config.get("quant_method", "").lower()
         num_bits = quant_config.get("bits")
         group_size = quant_config.get("group_size")
-        has_zp = quant_config.get("zero_point")
+        zero_point = quant_config.get("zero_point")
 
         if not current_platform.is_cuda():
             return False
@@ -132,7 +145,7 @@ class AWQMarlinConfig(QuantizationConfig):
             return False
 
         # If we cannot find the info needed in the config, cannot convert.
-        if (num_bits is None or group_size is None or has_zp is None):
+        if (num_bits is None or group_size is None or zero_point is None):
             return False
 
         if num_bits not in cls.TYPE_MAP:
@@ -140,7 +153,7 @@ class AWQMarlinConfig(QuantizationConfig):
 
         return check_marlin_supported(quant_type=cls.TYPE_MAP[num_bits],
                                       group_size=group_size,
-                                      has_zp=has_zp)
+                                      has_zp=zero_point)
 
 
 class AWQMarlinLinearMethod(LinearMethodBase):