- **Add SPDX license headers to python source files**
- **Check for SPDX headers using pre-commit**
commit 9d7ef44c3cfb72ca4c32e1c677d99259d10d4745
Author: Russell Bryant <rbryant@redhat.com>
Date: Fri Jan 31 14:18:24 2025 -0500

Add SPDX license headers to python source files

This commit adds SPDX license headers to python source files as
recommended to the project by the Linux Foundation. These headers provide
a concise way, both human and machine readable, of communicating license
information for each source file. They help avoid any ambiguity about the
license of the code and can also be easily used by tools to help manage
license compliance.

The Linux Foundation runs license scans against the codebase to help
ensure we are in compliance with the licenses of the code we use,
including dependencies. Having these headers in place helps that tool do
its job.

More information can be found on the SPDX site:
- https://spdx.dev/learn/handling-license-info/

Signed-off-by: Russell Bryant <rbryant@redhat.com>
commit 5a1cf1cb3b80759131c73f6a9dddebccac039dea
Author: Russell Bryant <rbryant@redhat.com>
Date: Fri Jan 31 14:36:32 2025 -0500

Check for SPDX headers using pre-commit

Signed-off-by: Russell Bryant <rbryant@redhat.com>

---------

Signed-off-by: Russell Bryant <rbryant@redhat.com>
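
The pre-commit hook itself is not shown on this page. Purely as an illustrative, hypothetical sketch (not vLLM's actual hook), a check of this kind boils down to a small script that pre-commit runs over the staged Python files:

#!/usr/bin/env python3
# Hypothetical SPDX check (illustration only, not vLLM's real hook).
# pre-commit passes the staged file names as command-line arguments.
import sys

EXPECTED = "# SPDX-License-Identifier: Apache-2.0"


def main(paths):
    missing = []
    for path in paths:
        with open(path, encoding="utf-8") as f:
            # First non-empty line should be the SPDX header.
            first = next((line.strip() for line in f if line.strip()), "")
        if first != EXPECTED:
            missing.append(path)
    for path in missing:
        print(f"missing SPDX header: {path}")
    return 1 if missing else 0


if __name__ == "__main__":
    sys.exit(main(sys.argv[1:]))

A real hook would need to tolerate shebangs and module docstrings; the point is only that the header is a fixed, machine-checkable line, the same `# SPDX-License-Identifier: Apache-2.0` that opens the file below.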
259 lines
10 KiB
Python
# SPDX-License-Identifier: Apache-2.0

import itertools
from typing import Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers import RobertaConfig

from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import CrossEncodingPooler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel
from vllm.model_executor.models.utils import maybe_prefix
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.transformers_utils.config import (
    get_cross_encoder_activation_function)

from .interfaces import SupportsCrossEncoding


def roberta_task_weights_filter(
    all_weights: Iterable[Tuple[str, torch.Tensor]]
) -> Tuple[Iterable[Tuple[str, torch.Tensor]], Iterable[Tuple[str,
                                                               torch.Tensor]]]:
    """
    Separate task-specific weights that are applied on top
    of the encoder-decoder bert base.
    To do so, return two generators over the original iterator.
    Also, remove the "roberta." prefix to make it loadable
    from vanilla BertModel.
    """
    # Copy of a lazy iterator without in-memory overhead so both
    # iterators can be iterated upon independently.
    all_weights1, all_weights2 = itertools.tee(all_weights)

    def encoder_decoder_weights():
        for name, weight in all_weights1:
            if name.startswith("roberta."):
                yield (name[len("roberta."):], weight)

    return encoder_decoder_weights(), ((n, w) for n, w in all_weights2
                                       if not n.startswith("roberta."))
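
# Illustrative example of the split above, with hypothetical checkpoint names:
#   [("roberta.encoder.layer.0.output.dense.weight", w0),
#    ("classifier.dense.weight", w1)]
# yields an encoder stream of ("encoder.layer.0.output.dense.weight", w0)
# and a task stream of ("classifier.dense.weight", w1).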


class RobertaEmbedding(nn.Module):

    def __init__(self, config: RobertaConfig):
        super().__init__()
        self.size = config.hidden_size
        self.word_embeddings = VocabParallelEmbedding(config.vocab_size,
                                                      config.hidden_size)
        self.padding_idx = config.pad_token_id
        self.position_embeddings = nn.Embedding(config.max_position_embeddings,
                                                config.hidden_size,
                                                padding_idx=self.padding_idx)

        self.token_type_embeddings = nn.Embedding(config.type_vocab_size,
                                                  config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size,
                                      eps=config.layer_norm_eps)
        self.position_ids = nn.Parameter(
            torch.empty((1, config.max_position_embeddings)), )

        self.position_embedding_type = config.position_embedding_type
        if self.position_embedding_type != "absolute":
            raise ValueError("Only 'absolute' position_embedding_type" +
                             " is supported")

    def forward(
        self,
        input_ids: torch.Tensor,
        seq_lens: torch.Tensor,
        position_ids: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        input_shape = input_ids.size()
        inputs_embeds = self.word_embeddings(input_ids)

        # Replace position ids because in RoBERTa models
        # they have to start at padding_idx + 1 and ignore
        # existing padding tokens.
        # References:
        # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
        # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
        pos_list = []
        token_list = []
        offset = 0
        for seq_len in seq_lens:
            pos_list.append(position_ids[offset:offset + seq_len])
            token_list.append(input_ids[offset:offset + seq_len])
            offset += seq_len

        new_pos_list = []
        for positions, tokens in zip(pos_list, token_list):
            # Verify the assumption that incoming positions are
            # always a sequence from 0 to N.
            expected_pos = torch.arange(positions.size()[0],
                                        dtype=torch.long,
                                        device=inputs_embeds.device)
            assert torch.equal(positions, expected_pos)
            new_pos_list.append(
                create_position_ids_from_input_ids(tokens, self.padding_idx))
        position_ids = torch.cat(new_pos_list)

        # Position embeddings.
        position_embeddings = self.position_embeddings(position_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape,
                                         dtype=torch.long,
                                         device=inputs_embeds.device)

        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        embeddings = inputs_embeds + token_type_embeddings + position_embeddings
        embeddings = self.LayerNorm(embeddings)
        return embeddings
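
# Note (illustrative): vLLM passes the batch flattened into a single token
# dimension, so with seq_lens = [3, 2] the first three entries of
# input_ids/position_ids belong to sequence 0 and the next two to sequence 1;
# RobertaEmbedding.forward above recomputes RoBERTa-style position ids for
# each of those per-sequence slices.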


# Adapted from transformers
def create_position_ids_from_input_ids(input_ids,
                                       padding_idx,
                                       past_key_values_length=0):
    """
    Replace non-padding symbols with their position numbers.
    Position numbers begin at padding_idx+1. Padding symbols
    are ignored. This is modified from fairseq's `utils.make_positions`.

    Args:
        input_ids: torch.Tensor of token ids.
        padding_idx: id of the padding token.
        past_key_values_length: offset added to the computed positions.

    Returns: torch.Tensor of position ids.
    """
    # The series of casts and type-conversions here are carefully
    # balanced to both work with ONNX export and XLA.
    mask = input_ids.ne(padding_idx).int()

    incremental_indices = (torch.cumsum(mask, dim=0).type_as(mask) +
                           past_key_values_length) * mask

    return incremental_indices.long() + padding_idx
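
# Worked example (illustrative), assuming padding_idx == 1 and
# past_key_values_length == 0:
#   input_ids             = [0, 10, 11, 12, 1, 1]
#   mask                  = [1,  1,  1,  1, 0, 0]
#   cumsum(mask) * mask   = [1,  2,  3,  4, 0, 0]
#   result (+padding_idx) = [2,  3,  4,  5, 1, 1]
# Non-padding tokens get positions starting at padding_idx + 1, and padding
# tokens keep position padding_idx.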


# Adapted from transformers
class RobertaClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""

    def __init__(self, config: RobertaConfig):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, features, **kwargs):
        x = features[0, :]  # take <s> token (equiv. to [CLS])
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.out_proj(x)
        return x


class RobertaEmbeddingModel(BertEmbeddingModel):
    """A model that uses Roberta to provide embedding functionalities.

    This class encapsulates the BertModel and provides an interface for
    embedding operations and customized pooling functions.

    Attributes:
        model: An instance of BertModel used for forward operations.
        _pooler: An instance of Pooler used for pooling operations.
    """

    def _build_model(self,
                     vllm_config: VllmConfig,
                     prefix: str = "") -> BertModel:
        return BertModel(vllm_config=vllm_config,
                         prefix=prefix,
                         embedding_class=RobertaEmbedding)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        weights = self.hf_to_vllm_mapper.apply(weights)
        # Separate weights in "roberta"-prefixed and all else (not in memory).
        # For use with models like FacebookAI/roberta-base.
        bert_weights, task_weights = roberta_task_weights_filter(weights)
        loaded = self.model.load_weights(bert_weights)
        if not len(loaded):
            # Fix for models like `sentence-transformers/stsb-roberta-base-v2`
            # which use the same architecture, but have no "roberta" prefix.
            loaded = self.model.load_weights(task_weights)
        assert len(loaded), "Unable to load RobertaEmbeddingModel"
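
# Illustrative checkpoint layouts handled by load_weights above
# (hypothetical parameter names):
#   prefixed   (e.g. FacebookAI/roberta-base):
#       "roberta.embeddings.word_embeddings.weight"
#   unprefixed (e.g. sentence-transformers/stsb-roberta-base-v2):
#       "embeddings.word_embeddings.weight"
# The first load_weights call consumes the prefixed stream; if nothing was
# loaded, the unprefixed stream is tried instead.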


class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
    """A model that uses RoBERTa for sequence classification (cross-encoding).

    This class encapsulates the BertModel (built with RoBERTa embeddings) and
    adds a classification head and a cross-encoding pooler on top of it.

    Attributes:
        roberta: An instance of BertModel used for forward operations.
        classifier: RobertaClassificationHead mapping hidden states to logits.
        _pooler: An instance of CrossEncodingPooler used for pooling operations.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config

        self.default_activation_function = \
            get_cross_encoder_activation_function(config)

        self.num_labels = config.num_labels
        self.roberta = BertModel(vllm_config=vllm_config,
                                 prefix=maybe_prefix(prefix, "bert"),
                                 embedding_class=RobertaEmbedding,
                                 add_pooling_layer=False)
        self.classifier = RobertaClassificationHead(config)
        self._pooler = CrossEncodingPooler(config, self.classifier)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

        bert_weights, task_weights = roberta_task_weights_filter(weights)
        self.roberta.load_weights(bert_weights)

        params_dict = dict(self.named_parameters())

        for name, loaded_weight in task_weights:
            if name.startswith("classifier"):
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        return self._pooler(hidden_states, pooling_metadata)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        return self.roberta(input_ids=input_ids,
                            position_ids=positions,
                            kv_caches=kv_caches,
                            inputs_embeds=inputs_embeds,
                            intermediate_tensors=intermediate_tensors,
                            attn_metadata=attn_metadata,
                            token_type_ids=token_type_ids)