[Bugfix] Fix Qwen3-Reranker-8B load (#28117)

Signed-off-by: wang.yuqi <noooop@126.com>
Author: wang.yuqi, 2025-11-06 02:33:50 +08:00 (committed by GitHub)
parent faedbb4d4f
commit 802748bddb

@@ -186,15 +186,21 @@ def _create_pooling_model_cls(orig_cls: _T) -> _T:
         def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
             raise NotImplementedError
 
-        def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
+        def load_weights(
+            self,
+            weights: Iterable[tuple[str, torch.Tensor]],
+            load_lm_head: bool = False,
+        ):
             # TODO: Support uninitialized params tracking
 
-            # We have deleted this attribute, so don't load it
-            weights = (
-                (name, data)
-                for name, data in weights
-                if not name.startswith("lm_head.")
-            )
+            # For most pooling models: We have deleted this attribute, so don't load it.
+            # For converting an LLM into a seq cls model, we need the lm_head.
+            if not load_lm_head:
+                weights = (
+                    (name, data)
+                    for name, data in weights
+                    if not name.startswith("lm_head.")
+                )
 
             # If `*ForCausalLM` defines `load_weights` on the inner model
             # and there are no other inner modules with parameters,
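
To make the new `load_lm_head` flag concrete, here is a minimal, self-contained sketch of the filtering behavior this hunk introduces. It is not vLLM code: `filter_weights` and the dummy weight names are illustrative only.

from collections.abc import Iterable

def filter_weights(
    weights: Iterable[tuple[str, object]],
    load_lm_head: bool = False,
) -> Iterable[tuple[str, object]]:
    # By default, drop lm_head.* entries lazily via a generator, since pooling
    # models have deleted that attribute. Pass load_lm_head=True to keep them
    # when converting an LLM checkpoint into a sequence-classification model.
    if not load_lm_head:
        weights = (
            (name, data)
            for name, data in weights
            if not name.startswith("lm_head.")
        )
    return weights

weights = [("model.embed_tokens.weight", 0), ("lm_head.weight", 1)]
assert [n for n, _ in filter_weights(weights)] == ["model.embed_tokens.weight"]
assert [n for n, _ in filter_weights(weights, load_lm_head=True)] == [
    "model.embed_tokens.weight",
    "lm_head.weight",
]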
@@ -431,8 +437,12 @@ def load_weights_using_from_2_way_softmax(
     )
     model.lm_head = model.lm_head.tie_weights(embed_tokens)
 
-    # Skip ModelForSequenceClassification in MRO to avoid infinite recursion
-    loaded_weights = type(model).__mro__[1].load_weights(model, weights)
+    # ModelForPooling is dynamically defined inside the _create_pooling_model_cls
+    # function, so we need to use this hacky method to obtain it.
+    pooling_model_cls = next(
+        x for x in type(model).__mro__ if x.__name__ == "ModelForPooling"
+    )
+    loaded_weights = pooling_model_cls.load_weights(model, weights, load_lm_head=True)
 
     from vllm.transformers_utils.tokenizer import get_tokenizer
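
Why the by-name MRO lookup: ModelForPooling is created inside a factory function, so its position in __mro__ shifts whenever another class is mixed in above it, and the hard-coded __mro__[1] breaks. A minimal sketch of that failure mode; all class names here are hypothetical stand-ins, not vLLM classes.

def _create_pooling_model_cls(orig_cls):
    # Dynamically created class, as in the factory above.
    class ModelForPooling(orig_cls):
        def load_weights(self, weights, load_lm_head=False):
            return f"ModelForPooling.load_weights(load_lm_head={load_lm_head})"
    return ModelForPooling

class FakeForCausalLM:  # hypothetical base model
    pass

class ExtraMixin:  # any mixin inserted above ModelForPooling
    pass

class FakeForSequenceClassification(
    ExtraMixin, _create_pooling_model_cls(FakeForCausalLM)
):
    pass

model = FakeForSequenceClassification()
# __mro__ is [FakeForSequenceClassification, ExtraMixin, ModelForPooling, ...],
# so __mro__[1] is ExtraMixin, not ModelForPooling; the by-name search still works.
assert type(model).__mro__[1].__name__ == "ExtraMixin"
pooling_model_cls = next(
    x for x in type(model).__mro__ if x.__name__ == "ModelForPooling"
)
print(pooling_model_cls.load_weights(model, [], load_lm_head=True))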
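
For reference, the code path this fix exercises is reached by loading the original Qwen3-Reranker checkpoint as a sequence-classification model. A usage sketch in the spirit of vLLM's Qwen3 reranker example; the exact flags and keys may vary across vLLM versions, so check the docs for your release.

from vllm import LLM

# Convert the original reranker checkpoint into a seq cls model at load time.
# The hf_overrides below follow vLLM's Qwen3-Reranker example.
llm = LLM(
    model="Qwen/Qwen3-Reranker-8B",
    task="score",
    hf_overrides={
        "architectures": ["Qwen3ForSequenceClassification"],
        "classifier_from_token": ["no", "yes"],
        "is_original_qwen3_reranker": True,
    },
)

scores = llm.score(
    "What is the capital of France?",
    [
        "Paris is the capital of France.",
        "The 2008 Olympics were held in Beijing.",
    ],
)
print([out.outputs.score for out in scores])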