[BUGFIX] GPTQ quantization compatibility for Qwen3 MOE models (AutoGPTQ and AutoRound-GPTQ) (#23994)

Signed-off-by: JartX <sagformas@epdcenter.es>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Authored by JartX on 2025-09-01 05:33:40 +02:00; committed by GitHub
parent 14b4326b94
commit 183a70967a
3 changed files with 17 additions and 4 deletions


@@ -37,6 +37,7 @@ class GPTQConfig(QuantizationConfig):
         desc_act: bool,
         lm_head_quantized: bool,
         dynamic: dict[str, dict[str, Union[int, bool]]],
+        autoround_version: str = "",
     ) -> None:
         # GPTQModel use `dynamic` config property to allow per module
         # quantization config so each module can be individually optimized.
@@ -74,6 +75,9 @@ class GPTQConfig(QuantizationConfig):
                 "Currently, only 2/3/4/8-bit weight quantization is "
                 f"supported for GPTQ, but got {self.weight_bits} bits.")

+        # used to identify GPTQ model quantized by autoround
+        self.autoround_version = autoround_version
+
     def __repr__(self) -> str:
         return (f"GPTQConfig(weight_bits={self.weight_bits}, "
                 f"group_size={self.group_size}, "
@@ -108,8 +112,10 @@ class GPTQConfig(QuantizationConfig):
         desc_act = cls.get_from_keys(config, ["desc_act"])
         lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
                                                  default=False)
+        autoround_version = cls.get_from_keys_or(config, ["autoround_version"],
+                                                 default="")
         return cls(weight_bits, group_size, desc_act, lm_head_quantized,
-                   dynamic)
+                   dynamic, autoround_version)

     def get_quant_method(
             self, layer: torch.nn.Module, prefix: str
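
For reference, a minimal sketch of how the new field is picked up on load, assuming vLLM with this change is installed; the quantization_config dict and the version string are illustrative, only the key names come from the hunks above:

    from vllm.model_executor.layers.quantization.gptq import GPTQConfig

    # Illustrative HF-style quantization_config. AutoRound-exported GPTQ
    # checkpoints carry the extra "autoround_version" key; plain AutoGPTQ
    # checkpoints do not, so the field falls back to "".
    hf_quant_config = {
        "bits": 4,
        "group_size": 128,
        "desc_act": False,
        "lm_head": False,
        "dynamic": {},
        "autoround_version": "0.4.2",  # hypothetical version string
    }

    cfg = GPTQConfig.from_config(hf_quant_config)
    print(cfg.autoround_version)  # "0.4.2" -> identified as an AutoRound-GPTQ checkpoint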


@@ -119,6 +119,9 @@ class GPTQMarlinConfig(QuantizationConfig):
         self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)]

+        # used to identify GPTQ model quantized by autoround
+        self.autoround_version = full_config.get("autoround_version", "")
+
     def __repr__(self) -> str:
         return (f"GPTQMarlinConfig(quant_type={self.quant_type}, "
                 f"group_size={self.group_size}, "


@@ -159,9 +159,13 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
     def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
         # GPTQ configs do not have a list of ignored modules, however AutoGPTQ
-        # seems to avoid gate quantization.
-        # See: https://huggingface.co/Qwen/Qwen3-30B-A3B-GPTQ-Int4
-        if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)):
+        # seems to avoid gate quantization while AutoRound does.
+        # See: https://huggingface.co/Qwen/Qwen3-30B-A3B-GPTQ-Int4,
+        # and https://huggingface.co/jart25/Qwen3-Coder-30B-A3B-Instruct-Int4-gptq
+        if isinstance(
+                quant_config,
+                (GPTQConfig,
+                 GPTQMarlinConfig)) and not quant_config.autoround_version:
             return None
         return quant_config
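
The net effect on the Qwen3 MoE gate can be seen in a standalone sketch; the stub classes below merely stand in for vLLM's configs, and only the decision logic from the hunk above is reproduced:

    class GPTQConfig:
        def __init__(self, autoround_version: str = "") -> None:
            self.autoround_version = autoround_version

    class GPTQMarlinConfig(GPTQConfig):
        pass

    def maybe_ignore_quant_config(quant_config):
        # AutoGPTQ leaves the MoE gate unquantized, so the gate gets no quant
        # config; AutoRound-GPTQ does quantize the gate, which a non-empty
        # autoround_version reveals, so the config is kept.
        if isinstance(quant_config,
                      (GPTQConfig,
                       GPTQMarlinConfig)) and not quant_config.autoround_version:
            return None
        return quant_config

    print(maybe_ignore_quant_config(GPTQConfig()))                           # None: gate stays unquantized (AutoGPTQ)
    print(maybe_ignore_quant_config(GPTQConfig(autoround_version="0.4.2")))  # config kept: gate quantized (AutoRound)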