Mirror of https://git.datalinker.icu/vllm-project/vllm.git
[Bugfix] Fix tensor parallel issue in Qwen3 reranker weight loading (#20682)
Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Isotr0py <2037008807@qq.com>
parent b1235c3e10
commit 11c0198615
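What changed: `load_weights_using_from_2_way_softmax` builds the reranker's single-row `score` head from the difference between the lm_head rows for the "yes" and "no" tokens. The old code wrote that full-size tensor straight into `model.score.weight.data`, which breaks under tensor parallelism: when the score layer is sharded, each rank's parameter holds only a slice of the weight, so a direct copy of the full tensor mis-sizes or mis-places the values. The fix hands the full tensor to the parameter's `weight_loader`, which knows how to shard it, falling back to a plain copy for unsharded parameters. The sketch below is a minimal illustration of that pattern, with simplified signatures and a made-up sharding scheme, not the actual vLLM loaders:

import torch

def default_weight_loader(param, loaded_weight):
    # Unsharded parameter: shapes already match, a plain copy is safe.
    param.data.copy_(loaded_weight)

def sharded_weight_loader(param, loaded_weight, tp_rank, tp_size):
    # Parallel layer: keep only this rank's slice of the last dimension.
    shard = loaded_weight.chunk(tp_size, dim=-1)[tp_rank]
    param.data.copy_(shard)

def load_score_weight(param, full_weight):
    # The fixed code path: defer to whatever loader the layer attached.
    loader = getattr(param, "weight_loader", default_weight_loader)
    loader(param, full_weight)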
@@ -268,7 +268,8 @@ def mteb_test_rerank_models(hf_runner,
                             model_info: RerankModelInfo,
                             vllm_extra_kwargs=None,
                             hf_model_callback=None,
-                            vllm_mteb_encoder=VllmMtebEncoder):
+                            vllm_mteb_encoder=VllmMtebEncoder,
+                            atol=MTEB_RERANK_TOL):
     if not model_info.enable_test:
         # A model family has many models with the same architecture,
         # and we don't need to test each one.
@@ -301,4 +302,4 @@ def mteb_test_rerank_models(hf_runner,
     print("SentenceTransformers:", st_dtype, st_main_score)
     print("Difference:", st_main_score - vllm_main_score)
 
-    assert st_main_score == pytest.approx(vllm_main_score, abs=MTEB_RERANK_TOL)
+    assert st_main_score == pytest.approx(vllm_main_score, abs=atol)
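The new `atol` parameter lets callers loosen the score comparison for noisier configurations while the default stays at `MTEB_RERANK_TOL`; the tensor-parallel test added below passes `atol=1.2e-2`. For reference, `pytest.approx(value, abs=atol)` accepts anything within value ± atol; the scores in this snippet are made up:

import pytest

st_main_score, vllm_main_score = 0.264, 0.255  # illustrative scores
# |0.264 - 0.255| = 0.009 <= 0.012, so the assertion passes.
assert st_main_score == pytest.approx(vllm_main_score, abs=1.2e-2)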
@@ -6,6 +6,7 @@ import pytest
 import torch
 
 from tests.conftest import HfRunner
+from tests.utils import multi_gpu_test
 
 from .mteb_utils import RerankModelInfo, mteb_test_rerank_models
 
@@ -87,3 +88,29 @@ def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None:
 
     mteb_test_rerank_models(Qwen3RerankerHfRunner, vllm_runner, model_info,
                             vllm_extra_kwargs)
+
+
+@pytest.mark.parametrize("model_info", RERANK_MODELS)
+@multi_gpu_test(num_gpus=2)
+def test_rerank_models_mteb_tp(vllm_runner,
+                               model_info: RerankModelInfo) -> None:
+
+    assert model_info.architecture == "Qwen3ForSequenceClassification"
+
+    vllm_extra_kwargs: dict[str, Any] = {
+        "hf_overrides": {
+            "architectures": ["Qwen3ForSequenceClassification"],
+            "classifier_from_token": ["no", "yes"],
+            "is_original_qwen3_reranker": True,
+        },
+        "tensor_parallel_size": 2,
+    }
+
+    if model_info.name == "Qwen/Qwen3-Reranker-4B":
+        vllm_extra_kwargs["max_num_seqs"] = 1
+
+    mteb_test_rerank_models(Qwen3RerankerHfRunner,
+                            vllm_runner,
+                            model_info,
+                            vllm_extra_kwargs,
+                            atol=1.2e-2)
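The added test is what exercises the fix: `tensor_parallel_size=2` forces the score head through the sharded weight-loading path, and `hf_overrides` maps the original Qwen3-Reranker checkpoint (a causal LM scored on its "no"/"yes" tokens) onto `Qwen3ForSequenceClassification`. `multi_gpu_test` is the repo's decorator for gating a test on available GPU count; a rough sketch of the idea, assuming a simple device-count check rather than the real implementation:

import pytest
import torch

def multi_gpu_test_sketch(num_gpus: int):
    # Skip, rather than fail, on machines with too few GPUs.
    return pytest.mark.skipif(
        torch.cuda.device_count() < num_gpus,
        reason=f"needs at least {num_gpus} GPUs",
    )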
@@ -322,6 +322,8 @@ def load_weights_using_from_2_way_softmax(
     # refer to https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3
     from vllm.model_executor.layers.vocab_parallel_embedding import (
         ParallelLMHead)
+    from vllm.model_executor.model_loader.weight_utils import (
+        default_weight_loader)
     from vllm.model_executor.models.utils import AutoWeightsLoader
 
     model_config = model.vllm_config.model_config
@@ -329,8 +331,6 @@ def load_weights_using_from_2_way_softmax(
     tokens = cast(list[int], tokens)
     assert len(tokens) == 2
 
-    device = model.score.weight.device
-
     if model.config.tie_word_embeddings:
         model.lm_head = model.model.embed_tokens
     else:
@@ -349,10 +349,13 @@ def load_weights_using_from_2_way_softmax(
 
     false_id = tokenizer.convert_tokens_to_ids(tokens[0])
     true_id = tokenizer.convert_tokens_to_ids(tokens[1])
-    weight = model.lm_head.weight.data[true_id].to(device).to(
-        torch.float32) - model.lm_head.weight.data[false_id].to(device).to(
-            torch.float32)
-    model.score.weight.data.copy_(weight)
+    weight = model.lm_head.weight.data[[true_id]].to(
+        torch.float32) - model.lm_head.weight.data[[false_id]].to(
+            torch.float32)
+
+    param = model.score.weight
+    weight_loader = getattr(param, "weight_loader", default_weight_loader)
+    weight_loader(param, weight)
 
     del model.lm_head
     loaded_weights.add("score.weight")
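Two details in this last hunk are easy to miss. Indexing `lm_head.weight.data` with a one-element list keeps the result 2-D, shape [1, hidden_size], which matches the [num_labels, hidden_size] layout the weight loader expects; plain integer indexing would return a 1-D row. And the explicit `.to(device)` calls are gone because the weight loader copies into the parameter wherever it already lives. A quick shape check of the indexing difference, with made-up sizes:

import torch

hidden_size = 8
lm_head = torch.randn(32, hidden_size)  # stand-in for lm_head.weight.data
no_id, yes_id = 3, 7                    # illustrative token ids

weight = lm_head[[yes_id]].float() - lm_head[[no_id]].float()
assert weight.shape == (1, hidden_size)         # 2-D: one row kept
assert lm_head[yes_id].shape == (hidden_size,)  # integer index drops the dim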