From 802748bddbe3759b11cfaa73bd504d6d26bfe408 Mon Sep 17 00:00:00 2001
From: "wang.yuqi"
Date: Thu, 6 Nov 2025 02:33:50 +0800
Subject: [PATCH] [Bugfix] Fix Qwen3-Reranker-8B load (#28117)

Signed-off-by: wang.yuqi
---
 vllm/model_executor/models/adapters.py | 28 +++++++++++++++++++++---------
 1 file changed, 19 insertions(+), 9 deletions(-)

diff --git a/vllm/model_executor/models/adapters.py b/vllm/model_executor/models/adapters.py
index 7990024c55d0c..f742090df71fd 100644
--- a/vllm/model_executor/models/adapters.py
+++ b/vllm/model_executor/models/adapters.py
@@ -186,15 +186,21 @@ def _create_pooling_model_cls(orig_cls: _T) -> _T:
         def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
             raise NotImplementedError
 
-        def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
+        def load_weights(
+            self,
+            weights: Iterable[tuple[str, torch.Tensor]],
+            load_lm_head: bool = False,
+        ):
             # TODO: Support uninitialized params tracking
 
-            # We have deleted this attribute, so don't load it
-            weights = (
-                (name, data)
-                for name, data in weights
-                if not name.startswith("lm_head.")
-            )
+            # For most pooling models: We have deleted this attribute, so don't load it.
+            # For converting an LLM into a seq cls model, we need the lm_head.
+            if not load_lm_head:
+                weights = (
+                    (name, data)
+                    for name, data in weights
+                    if not name.startswith("lm_head.")
+                )
 
             # If `*ForCausalLM` defines `load_weights` on the inner model
             # and there are no other inner modules with parameters,
@@ -431,8 +437,12 @@ def load_weights_using_from_2_way_softmax(
     )
     model.lm_head = model.lm_head.tie_weights(embed_tokens)
 
-    # Skip ModelForSequenceClassification in MRO to avoid infinite recursion
-    loaded_weights = type(model).__mro__[1].load_weights(model, weights)
+    # ModelForPooling is dynamically defined inside the _create_pooling_model_cls
+    # function, so we need to use this hacky method to obtain it.
+    pooling_model_cls = next(
+        x for x in type(model).__mro__ if x.__name__ == "ModelForPooling"
+    )
+    loaded_weights = pooling_model_cls.load_weights(model, weights, load_lm_head=True)
 
     from vllm.transformers_utils.tokenizer import get_tokenizer
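
Note (not part of the patch): the fix relies on looking up a class by name in the
instance's MRO, because ModelForPooling is created inside a factory function and
has no importable name. The following is a minimal standalone sketch of that
pattern, not vLLM code; the names make_pooling_cls, MyModel, and
MyModelForSequenceClassification are hypothetical stand-ins.

    # Sketch of the MRO-lookup trick used in the patch (hypothetical names).
    def make_pooling_cls(orig_cls):
        class ModelForPooling(orig_cls):
            def load_weights(self, weights, load_lm_head=False):
                # By default, drop lm_head weights (the pooling model deletes
                # that module); seq cls conversion asks to keep them.
                if not load_lm_head:
                    weights = [(n, w) for n, w in weights
                               if not n.startswith("lm_head.")]
                return [n for n, _ in weights]  # stand-in for real loading
        return ModelForPooling

    class MyModel:  # stands in for the original *ForCausalLM class
        pass

    class MyModelForSequenceClassification(make_pooling_cls(MyModel)):
        def load_weights(self, weights):
            # Find the dynamically created ModelForPooling class by name in
            # the MRO instead of hard-coding __mro__[1].
            pooling_cls = next(
                c for c in type(self).__mro__
                if c.__name__ == "ModelForPooling"
            )
            return pooling_cls.load_weights(self, weights, load_lm_head=True)

    model = MyModelForSequenceClassification()
    print(model.load_weights([("lm_head.weight", 0), ("model.layer.weight", 1)]))
    # -> ['lm_head.weight', 'model.layer.weight'] (lm_head kept for seq cls)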