From c7ae7edb3326be48c65a3cd7151385545c877dda Mon Sep 17 00:00:00 2001
From: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Date: Sat, 27 Sep 2025 13:18:20 +0100
Subject: [PATCH] Fix GPTQ model loading in Transformers backend (#25770)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Isotr0py
Signed-off-by: yewentao256
---
 tests/models/test_transformers.py          | 10 +++++++---
 vllm/model_executor/models/transformers.py | 22 +++++++++++++++++-----
 vllm/model_executor/models/utils.py        |  7 +++++--
 3 files changed, 29 insertions(+), 10 deletions(-)

diff --git a/tests/models/test_transformers.py b/tests/models/test_transformers.py
index 1817d4aeee9f9..e4b5e7c244539 100644
--- a/tests/models/test_transformers.py
+++ b/tests/models/test_transformers.py
@@ -100,10 +100,9 @@ def test_distributed(
                          kwargs_test=kwargs)
 
 
-@pytest.mark.skipif(
-    current_platform.is_rocm(),
-    reason="bitsandbytes quantization is currently not supported in rocm.")
 @pytest.mark.parametrize("model, quantization_kwargs", [
+    ("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", {}),
+    ("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", {}),
     (
         "meta-llama/Llama-3.2-1B-Instruct",
         {
@@ -121,6 +120,11 @@ def test_quantization(
     max_tokens: int,
     num_logprobs: int,
 ) -> None:
+    if (current_platform.is_rocm()
+            and quantization_kwargs.get("quantization", "") == "bitsandbytes"):
+        pytest.skip(
+            "bitsandbytes quantization is currently not supported in rocm.")
+
     with vllm_runner(
             model, model_impl="auto", enforce_eager=True,
             **quantization_kwargs) as vllm_model:  # type: ignore[arg-type]
diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py
index 3d7b06633f342..7cfb639f675d5 100644
--- a/vllm/model_executor/models/transformers.py
+++ b/vllm/model_executor/models/transformers.py
@@ -447,7 +447,8 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
         self.device_config: DeviceConfig = vllm_config.device_config
         self.model_config: ModelConfig = vllm_config.model_config
         self.parallel_config: ParallelConfig = vllm_config.parallel_config
-        self.quant_config: QuantizationConfig = vllm_config.quant_config
+        self.quant_config: Optional[
+            QuantizationConfig] = vllm_config.quant_config
 
         self.pp_group = get_pp_group()
         self.pp_size = self.pp_group.world_size
@@ -456,7 +457,18 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
 
         # Weights to skip in `self.load_weights`
         self.skip_prefixes: list[str] = []
+        """Skip loading weights whose qualname starts with these prefixes."""
         self.skip_substrs: list[str] = []
+        """Skip loading weights whose qualname contains these substrings."""
+        self.ignore_unexpected_prefixes: list[str] = []
+        """Ignore unexpected weights whose qualname starts with these prefixes.
+        """
+        self.ignore_unexpected_suffixes: list[str] = []
+        """Ignore unexpected weights whose qualname ends with these suffixes."""
+
+        # Skip loading extra bias for GPTQ models.
+        if self.quant_config and "gptq" in self.quant_config.get_name():
+            self.ignore_unexpected_suffixes.append(".bias")
 
         # Set correct attn and init on "meta" to delay allocating GPU tensors
         # TODO: @raushan, use the public `model.set_attn_implementation()`
@@ -563,9 +575,7 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
             raise ValueError(
                 f"{type(self.model)} does not support tensor parallel. {tip}")
 
-        def _tensor_parallel(module: nn.Module,
-                             prefix: str = "",
-                             tp_plan=None):
+        def _tensor_parallel(module: nn.Module, prefix: str, tp_plan=None):
             tp_plan = tp_plan or {}
 
             # If the current module is a PreTrainedModel, set the tp_plan for
@@ -597,7 +607,7 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
                                  prefix=qual_name,
                                  tp_plan=tp_plan)
 
-        _tensor_parallel(self.model)
+        _tensor_parallel(self.model, prefix="model")
 
     def create_attention_instances(
         self,
@@ -696,6 +706,8 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
             self,
             skip_prefixes=self.skip_prefixes,
             skip_substrs=self.skip_substrs,
+            ignore_unexpected_prefixes=self.ignore_unexpected_prefixes,
+            ignore_unexpected_suffixes=self.ignore_unexpected_suffixes,
         )
         return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
 
diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py
index 7b3f20c6b28a1..bb6a0bd022021 100644
--- a/vllm/model_executor/models/utils.py
+++ b/vllm/model_executor/models/utils.py
@@ -109,6 +109,7 @@ class AutoWeightsLoader:
         skip_prefixes: Optional[list[str]] = None,
         skip_substrs: Optional[list[str]] = None,
         ignore_unexpected_prefixes: Optional[list[str]] = None,
+        ignore_unexpected_suffixes: Optional[list[str]] = None,
     ) -> None:
         super().__init__()
 
@@ -116,6 +117,7 @@ class AutoWeightsLoader:
         self.skip_prefixes = skip_prefixes or []
         self.skip_substrs = skip_substrs or []
         self.ignore_unexpected_prefixes = ignore_unexpected_prefixes or []
+        self.ignore_unexpected_suffixes = ignore_unexpected_suffixes or []
 
         # update default skip_substrs
         self.skip_substrs += self.ROTARY_EMBEDS_UNUSED_WEIGHTS
@@ -149,8 +151,9 @@ class AutoWeightsLoader:
                 or any(substr in qualname for substr in self.skip_substrs))
 
     def _can_ignore_unexpected(self, qualname: str) -> bool:
-        return any(
-            qualname.startswith(p) for p in self.ignore_unexpected_prefixes)
+        iup = (qualname.startswith(p) for p in self.ignore_unexpected_prefixes)
+        ius = (qualname.endswith(s) for s in self.ignore_unexpected_suffixes)
+        return any(iup) or any(ius)
 
     def _load_param(
         self,