[Bugfix] Fix tensor parallel issue in Qwen3 reranker weight loading (#20682)

Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
yurhett 2025-07-12 11:52:43 +08:00 committed by GitHub
parent b1235c3e10
commit 11c0198615
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 38 additions and 7 deletions

View File

@ -268,7 +268,8 @@ def mteb_test_rerank_models(hf_runner,
model_info: RerankModelInfo,
vllm_extra_kwargs=None,
hf_model_callback=None,
vllm_mteb_encoder=VllmMtebEncoder):
vllm_mteb_encoder=VllmMtebEncoder,
atol=MTEB_RERANK_TOL):
if not model_info.enable_test:
# A model family has many models with the same architecture,
# and we don't need to test each one.
@ -301,4 +302,4 @@ def mteb_test_rerank_models(hf_runner,
print("SentenceTransformers:", st_dtype, st_main_score)
print("Difference:", st_main_score - vllm_main_score)
assert st_main_score == pytest.approx(vllm_main_score, abs=MTEB_RERANK_TOL)
assert st_main_score == pytest.approx(vllm_main_score, abs=atol)

View File

@ -6,6 +6,7 @@ import pytest
import torch
from tests.conftest import HfRunner
from tests.utils import multi_gpu_test
from .mteb_utils import RerankModelInfo, mteb_test_rerank_models
@ -87,3 +88,29 @@ def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None:
mteb_test_rerank_models(Qwen3RerankerHfRunner, vllm_runner, model_info,
vllm_extra_kwargs)
@pytest.mark.parametrize("model_info", RERANK_MODELS)
@multi_gpu_test(num_gpus=2)
def test_rerank_models_mteb_tp(vllm_runner,
model_info: RerankModelInfo) -> None:
assert model_info.architecture == "Qwen3ForSequenceClassification"
vllm_extra_kwargs: dict[str, Any] = {
"hf_overrides": {
"architectures": ["Qwen3ForSequenceClassification"],
"classifier_from_token": ["no", "yes"],
"is_original_qwen3_reranker": True,
},
"tensor_parallel_size": 2,
}
if model_info.name == "Qwen/Qwen3-Reranker-4B":
vllm_extra_kwargs["max_num_seqs"] = 1
mteb_test_rerank_models(Qwen3RerankerHfRunner,
vllm_runner,
model_info,
vllm_extra_kwargs,
atol=1.2e-2)

View File

@ -322,6 +322,8 @@ def load_weights_using_from_2_way_softmax(
# refer to https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader)
from vllm.model_executor.models.utils import AutoWeightsLoader
model_config = model.vllm_config.model_config
@ -329,8 +331,6 @@ def load_weights_using_from_2_way_softmax(
tokens = cast(list[int], tokens)
assert len(tokens) == 2
device = model.score.weight.device
if model.config.tie_word_embeddings:
model.lm_head = model.model.embed_tokens
else:
@ -349,10 +349,13 @@ def load_weights_using_from_2_way_softmax(
false_id = tokenizer.convert_tokens_to_ids(tokens[0])
true_id = tokenizer.convert_tokens_to_ids(tokens[1])
weight = model.lm_head.weight.data[true_id].to(device).to(
torch.float32) - model.lm_head.weight.data[false_id].to(device).to(
weight = model.lm_head.weight.data[[true_id]].to(
torch.float32) - model.lm_head.weight.data[[false_id]].to(
torch.float32)
model.score.weight.data.copy_(weight)
param = model.score.weight
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, weight)
del model.lm_head
loaded_weights.add("score.weight")