mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-07 17:42:21 +08:00
[Bugfix] Fix RobertaModel loading (#11940)
Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
parent
a991f7d508
commit
d697dc01b4
@ -2,7 +2,7 @@ import os
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from vllm.model_executor.layers.pooler import PoolingType
|
from vllm.model_executor.layers.pooler import CLSPool, PoolingType
|
||||||
from vllm.model_executor.models.bert import BertEmbeddingModel
|
from vllm.model_executor.models.bert import BertEmbeddingModel
|
||||||
from vllm.model_executor.models.roberta import RobertaEmbeddingModel
|
from vllm.model_executor.models.roberta import RobertaEmbeddingModel
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
@ -92,3 +92,28 @@ def test_roberta_model_loading_with_params(vllm_runner):
|
|||||||
|
|
||||||
# assert output
|
# assert output
|
||||||
assert output
|
assert output
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(current_platform.is_rocm(),
|
||||||
|
reason="Xformers backend is not supported on ROCm.")
|
||||||
|
def test_facebook_roberta_model_loading_with_params(vllm_runner):
|
||||||
|
"""
|
||||||
|
Test loading roberta-base model with no lm_head.
|
||||||
|
"""
|
||||||
|
model_name = "FacebookAI/roberta-base"
|
||||||
|
with vllm_runner(model_name=model_name,
|
||||||
|
dtype="float16",
|
||||||
|
max_model_len=MAX_MODEL_LEN) as model:
|
||||||
|
output = model.encode("Write a short story about a robot that"
|
||||||
|
" dreams for the first time.\n")
|
||||||
|
|
||||||
|
model_tokenizer = model.model.llm_engine.tokenizer
|
||||||
|
assert model_tokenizer.tokenizer_id == model_name
|
||||||
|
|
||||||
|
model = model.model.llm_engine.model_executor\
|
||||||
|
.driver_worker.model_runner.model
|
||||||
|
assert not hasattr(model, "lm_head")
|
||||||
|
assert isinstance(model, RobertaEmbeddingModel)
|
||||||
|
assert isinstance(model._pooler, CLSPool)
|
||||||
|
|
||||||
|
assert output
|
||||||
|
|||||||
@ -25,6 +25,7 @@ from ..utils import check_embeddings_close
|
|||||||
pytest.param("ssmits/Qwen2-7B-Instruct-embed-base"),
|
pytest.param("ssmits/Qwen2-7B-Instruct-embed-base"),
|
||||||
pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct"),
|
pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct"),
|
||||||
pytest.param("Alibaba-NLP/gte-Qwen2-7B-instruct"),
|
pytest.param("Alibaba-NLP/gte-Qwen2-7B-instruct"),
|
||||||
|
pytest.param("sentence-transformers/stsb-roberta-base-v2"),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@pytest.mark.parametrize("dtype", ["half"])
|
@pytest.mark.parametrize("dtype", ["half"])
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
|
import itertools
|
||||||
from typing import Iterable, List, Optional, Tuple
|
from typing import Iterable, List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -20,6 +21,30 @@ from vllm.transformers_utils.config import (
|
|||||||
from .interfaces import SupportsCrossEncoding
|
from .interfaces import SupportsCrossEncoding
|
||||||
|
|
||||||
|
|
||||||
|
def roberta_task_weights_filter(
|
||||||
|
all_weights: Iterable[Tuple[str, torch.Tensor]]
|
||||||
|
) -> Tuple[Iterable[Tuple[str, torch.Tensor]], Iterable[Tuple[str,
|
||||||
|
torch.Tensor]]]:
|
||||||
|
"""
|
||||||
|
Separate task-specific weights that are applied on top
|
||||||
|
of the encoder-decoder bert base.
|
||||||
|
To do so, return two generators over the original iterator.
|
||||||
|
Also, remove the "roberta." prefix to make it loadable
|
||||||
|
from vanilla BertModel.
|
||||||
|
"""
|
||||||
|
# Copy of a lazy iterator without in-memory overhead so both
|
||||||
|
# iterators can be iterated upon independently.
|
||||||
|
all_weights1, all_weights2 = itertools.tee(all_weights)
|
||||||
|
|
||||||
|
def encoder_decoder_weights():
|
||||||
|
for name, weight in all_weights1:
|
||||||
|
if name.startswith("roberta."):
|
||||||
|
yield (name[len("roberta."):], weight)
|
||||||
|
|
||||||
|
return encoder_decoder_weights(), ((n, w) for n, w in all_weights2
|
||||||
|
if not n.startswith("roberta."))
|
||||||
|
|
||||||
|
|
||||||
class RobertaEmbedding(nn.Module):
|
class RobertaEmbedding(nn.Module):
|
||||||
|
|
||||||
def __init__(self, config: RobertaConfig):
|
def __init__(self, config: RobertaConfig):
|
||||||
@ -152,6 +177,18 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
|
|||||||
prefix=prefix,
|
prefix=prefix,
|
||||||
embedding_class=RobertaEmbedding)
|
embedding_class=RobertaEmbedding)
|
||||||
|
|
||||||
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
|
weights = self.hf_to_vllm_mapper.apply(weights)
|
||||||
|
# Separate weights in "roberta"-prefixed and all else (not in memory).
|
||||||
|
# For use with models like FacebookAI/roberta-base.
|
||||||
|
bert_weights, task_weights = roberta_task_weights_filter(weights)
|
||||||
|
loaded = self.model.load_weights(bert_weights)
|
||||||
|
if not len(loaded):
|
||||||
|
# Fix for models like `sentence-transformers/stsb-roberta-base-v2`
|
||||||
|
# which use the same architecture, but have no "roberta" prefix.
|
||||||
|
loaded = self.model.load_weights(task_weights)
|
||||||
|
assert len(loaded), "Unable to load RobertaEmbeddingModel"
|
||||||
|
|
||||||
|
|
||||||
class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
|
class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
|
||||||
"""A model that uses Roberta to provide embedding functionalities.
|
"""A model that uses Roberta to provide embedding functionalities.
|
||||||
@ -181,20 +218,12 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
|
|||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
|
|
||||||
self_weights = []
|
bert_weights, task_weights = roberta_task_weights_filter(weights)
|
||||||
|
self.roberta.load_weights(bert_weights)
|
||||||
def weight_filter():
|
|
||||||
for name, weight in weights:
|
|
||||||
if name.startswith("roberta."):
|
|
||||||
yield (name[len("roberta."):], weight)
|
|
||||||
else:
|
|
||||||
self_weights.append((name, weight))
|
|
||||||
|
|
||||||
self.roberta.load_weights(weight_filter())
|
|
||||||
|
|
||||||
params_dict = dict(self.named_parameters())
|
params_dict = dict(self.named_parameters())
|
||||||
|
|
||||||
for name, loaded_weight in self_weights:
|
for name, loaded_weight in task_weights:
|
||||||
if name.startswith("classifier"):
|
if name.startswith("classifier"):
|
||||||
param = params_dict[name]
|
param = params_dict[name]
|
||||||
weight_loader = getattr(param, "weight_loader",
|
weight_loader = getattr(param, "weight_loader",
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user