mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-16 12:07:21 +08:00
[Bugfix] Fix Qwen3-Reranker-8B load (#28117)
Signed-off-by: wang.yuqi <noooop@126.com>
This commit is contained in:
parent
faedbb4d4f
commit
802748bddb
@ -186,15 +186,21 @@ def _create_pooling_model_cls(orig_cls: _T) -> _T:
|
||||
def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
|
||||
raise NotImplementedError
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||
def load_weights(
|
||||
self,
|
||||
weights: Iterable[tuple[str, torch.Tensor]],
|
||||
load_lm_head: bool = False,
|
||||
):
|
||||
# TODO: Support uninitialized params tracking
|
||||
|
||||
# We have deleted this attribute, so don't load it
|
||||
weights = (
|
||||
(name, data)
|
||||
for name, data in weights
|
||||
if not name.startswith("lm_head.")
|
||||
)
|
||||
# For most pooling models: We have deleted this attribute, so don't load it.
|
||||
# For converting an LLM into a seq cls model, we need the lm_head.
|
||||
if not load_lm_head:
|
||||
weights = (
|
||||
(name, data)
|
||||
for name, data in weights
|
||||
if not name.startswith("lm_head.")
|
||||
)
|
||||
|
||||
# If `*ForCausalLM` defines `load_weights` on the inner model
|
||||
# and there are no other inner modules with parameters,
|
||||
@ -431,8 +437,12 @@ def load_weights_using_from_2_way_softmax(
|
||||
)
|
||||
model.lm_head = model.lm_head.tie_weights(embed_tokens)
|
||||
|
||||
# Skip ModelForSequenceClassification in MRO to avoid infinite recursion
|
||||
loaded_weights = type(model).__mro__[1].load_weights(model, weights)
|
||||
# ModelForPooling is dynamically defined inside the _create_pooling_model_cls
|
||||
# function, so we need use this hacky method to obtain it.
|
||||
pooling_model_cls = next(
|
||||
x for x in type(model).__mro__ if x.__name__ == "ModelForPooling"
|
||||
)
|
||||
loaded_weights = pooling_model_cls.load_weights(model, weights, load_lm_head=True)
|
||||
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user