mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-15 01:55:36 +08:00
Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io> Signed-off-by: wang.yuqi <noooop@126.com> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
452 lines
16 KiB
Python
452 lines
16 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
from collections.abc import Iterable, Set
|
|
|
|
import torch
|
|
from torch import nn
|
|
from transformers import ModernBertConfig
|
|
from transformers.activations import ACT2FN
|
|
|
|
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
|
|
from vllm.compilation.decorators import support_torch_compile
|
|
from vllm.config import VllmConfig
|
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
|
from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear
|
|
from vllm.model_executor.layers.pooler import (
|
|
ClassifierPooler,
|
|
DispatchPooler,
|
|
Pooler,
|
|
PoolingMethod,
|
|
PoolingParamsUpdate,
|
|
PoolingType,
|
|
)
|
|
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
|
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|
from vllm.sequence import IntermediateTensors
|
|
from vllm.tasks import PoolingTask
|
|
from vllm.v1.pool.metadata import PoolingMetadata
|
|
|
|
from .interfaces import SupportsCrossEncoding
|
|
from .interfaces_base import attn_type, default_pooling_type
|
|
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
|
|
|
|
|
|
class ModernBertEmbeddings(nn.Module):
    """Embedding stage of ModernBERT: token lookup followed by a LayerNorm.

    Token ids are mapped through a vocab-parallel embedding table and the
    result is layer-normalised. When pre-computed embeddings are supplied
    they bypass the lookup but still pass through the same LayerNorm.
    """

    def __init__(self, config: ModernBertConfig):
        super().__init__()
        self.config = config
        self.tok_embeddings = VocabParallelEmbedding(
            config.vocab_size, config.hidden_size
        )
        # Prefer `norm_eps`, then `layer_norm_eps`; if both are missing (or
        # falsy) fall back to 1e-5.
        norm_eps = getattr(config, "norm_eps", None)
        if not norm_eps:
            norm_eps = getattr(config, "layer_norm_eps", None)
        if not norm_eps:
            norm_eps = 1e-5
        self.norm = nn.LayerNorm(
            config.hidden_size, eps=norm_eps, bias=config.norm_bias
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return raw token embeddings for ``input_ids`` (no LayerNorm)."""
        return self.tok_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` (or use ``inputs_embeds``) and apply the LayerNorm.

        Args:
            input_ids: token ids; ignored when ``inputs_embeds`` is given.
            inputs_embeds: optional pre-computed (un-normalised) embeddings.

        Returns:
            The layer-normalised embeddings.
        """
        token_vectors = (
            self.tok_embeddings(input_ids) if inputs_embeds is None else inputs_embeds
        )
        return self.norm(token_vectors)
|
|
|
|
|
|
class ModernBertRotaryEmbedding(RotaryEmbedding):
    """Neox-style rotary embedding configured from a ModernBERT config.

    ``max_position_embeddings`` is taken from ``config``; ``head_size``,
    ``dim`` (the rotary dimension) and ``base`` (rope theta) are supplied
    by the caller so local and global attention layers can differ.
    """

    def __init__(self, config: ModernBertConfig, head_size: int, dim: int, base: float):
        # NOTE(review): the cos/sin cache dtype is hard-coded to float16 here
        # regardless of the model dtype — confirm this is intentional.
        rope_kwargs = dict(
            head_size=head_size,
            rotary_dim=dim,
            max_position_embeddings=config.max_position_embeddings,
            base=base,
            is_neox_style=True,
            dtype=torch.float16,
        )
        super().__init__(**rope_kwargs)
        self.config = config
|
|
|
|
|
|
class ModernBertAttention(nn.Module):
    """ModernBERT self-attention with rotary position embeddings.

    Layers whose index is a multiple of ``config.global_attn_every_n_layers``
    use full attention with ``global_rope_theta``. All other layers pass a
    sliding window of ``config.local_attention // 2`` to the attention layer
    and use ``local_rope_theta`` (or ``global_rope_theta`` when that is None).

    Q/K/V come from a ``QKVParallelLinear`` and the output projection is a
    ``RowParallelLinear``, so under tensor parallelism each rank only holds
    ``num_attention_heads // tp_size`` heads. Splitting and attention below
    therefore operate on the per-rank head count.
    """

    def __init__(self, config: ModernBertConfig, layer_id: int | None = None):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        # NOTE(review): layer_id is used with `%` below, so None would raise.
        self.layer_id = layer_id
        self.deterministic_flash_attn = config.deterministic_flash_attn
        # Total number of heads across all tensor-parallel ranks.
        self.num_heads = config.num_attention_heads
        assert self.num_heads % tp_size == 0
        # Heads owned by this rank: the sharded Wqkv output and the attention
        # kernel both work with this many heads.
        self.num_local_heads = self.num_heads // tp_size
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.all_head_size = self.head_dim * self.num_heads
        # Width of each of q/k/v on this rank (equals all_head_size when tp=1).
        self.local_head_size = self.head_dim * self.num_local_heads
        self.scaling = self.head_dim**-0.5
        # QKVParallelLinear takes the *total* head count and shards internally.
        self.Wqkv = QKVParallelLinear(
            config.hidden_size,
            self.head_dim,
            self.num_heads,
            bias=config.attention_bias,
        )

        sliding_window = None
        if layer_id % config.global_attn_every_n_layers != 0:
            # Local (sliding-window) attention layer.
            sliding_window = config.local_attention // 2
            rope_theta = (
                config.local_rope_theta
                if config.local_rope_theta is not None
                else config.global_rope_theta
            )
        else:
            rope_theta = config.global_rope_theta

        self.rotary_emb = ModernBertRotaryEmbedding(
            config=config, head_size=self.head_dim, dim=self.head_dim, base=rope_theta
        )
        self.attn = EncoderOnlyAttention(
            self.num_local_heads,
            self.head_dim,
            self.scaling,
            prefix=f"{layer_id}.attn",
            per_layer_sliding_window=sliding_window,
        )
        # Input to Wo is this rank's shard of the attention output, which is
        # what RowParallelLinear expects by default.
        self.Wo = RowParallelLinear(
            config.hidden_size, config.hidden_size, bias=config.attention_bias
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Project to Q/K/V, apply rotary embeddings, attend, and project out."""
        qkv, _ = self.Wqkv(hidden_states)
        q, k, v = qkv.split([self.local_head_size] * 3, dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        attn_outputs = self.attn(q, k, v)
        hidden_states, _ = self.Wo(attn_outputs)
        return hidden_states
|
|
|
|
|
|
class ModernBertMLP(nn.Module):
    """Gated feed-forward block of a ModernBERT layer.

    ``Wi`` projects to ``2 * intermediate_size``. The result is split into an
    input half and a gate half, combined as ``act(input) * gate``, and
    projected back to ``hidden_size`` by ``Wo``.
    """

    def __init__(self, config: ModernBertConfig):
        super().__init__()
        self.config = config
        # Wi is a plain, replicated nn.Linear: every tensor-parallel rank
        # computes the full intermediate activation.
        self.Wi = nn.Linear(
            config.hidden_size, int(config.intermediate_size) * 2, bias=config.mlp_bias
        )
        # Honour the configured activation (HF `hidden_activation`), falling
        # back to exact GELU, which this block previously always used.
        self.act = ACT2FN[getattr(config, "hidden_activation", "gelu")]
        # Wo's input is replicated, not sharded, so let RowParallelLinear slice
        # out this rank's partition. With the default input_is_parallel=True
        # the shapes only line up when tp_size == 1.
        self.Wo = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            bias=config.mlp_bias,
            input_is_parallel=False,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the gated MLP and return a tensor with ``hidden_size`` features."""
        projected, gate = self.Wi(hidden_states).chunk(2, dim=-1)
        out, _ = self.Wo(self.act(projected) * gate)
        return out
|
|
|
|
|
|
class ModernBertLayer(nn.Module):
    """One pre-norm ModernBERT block: attention, then gated MLP, each residual."""

    def __init__(
        self, config: ModernBertConfig, prefix: str = "", layer_id: int | None = None
    ):
        super().__init__()
        self.config = config
        # Layer 0 skips its attention pre-norm; ModernBertEmbeddings already
        # ends in a LayerNorm.
        if layer_id != 0:
            self.attn_norm = nn.LayerNorm(
                config.hidden_size, eps=config.norm_eps, bias=config.norm_bias
            )
        else:
            self.attn_norm = nn.Identity()
        self.attn = ModernBertAttention(config=config, layer_id=layer_id)
        self.mlp_norm = nn.LayerNorm(
            config.hidden_size, eps=config.norm_eps, bias=config.norm_bias
        )
        self.mlp = ModernBertMLP(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention and MLP sub-blocks, each added back onto the residual."""
        residual = hidden_states
        normed = self.attn_norm(residual)
        residual = residual + self.attn(hidden_states=normed, position_ids=position_ids)
        return residual + self.mlp(self.mlp_norm(residual))
|
|
|
|
|
|
class ModernBertEncoderLayer(nn.Module):
    """Stack of ``config.num_hidden_layers`` ModernBertLayer blocks.

    ``prefix`` is accepted for interface compatibility but is not used.
    """

    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.layers = nn.ModuleList(
            [
                ModernBertLayer(config=config, layer_id=layer_id)
                for layer_id in range(config.num_hidden_layers)
            ]
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Apply every layer in order, threading ``hidden_states`` through."""
        for layer in self.layers:
            hidden_states = layer(hidden_states, position_ids)
        return hidden_states
|
|
|
|
|
|
@support_torch_compile
@default_pooling_type("CLS")
class ModernBertModel(nn.Module):
    """ModernBERT encoder backbone: embeddings -> encoder layers -> final LayerNorm."""

    # HF checkpoints store transformer blocks under "layers."; in this module
    # they live under "encoder_layer.layers.".
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={"layers.": "encoder_layer.layers."}
    )

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.config = config
        self.embeddings = ModernBertEmbeddings(config)
        self.encoder_layer = ModernBertEncoderLayer(vllm_config)
        self.final_norm = nn.LayerNorm(
            config.hidden_size, eps=config.norm_eps, bias=config.norm_bias
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return raw (un-normalised) token embeddings for ``input_ids``."""
        return self.embeddings.embed_input_ids(input_ids)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load backbone weights and return the set of parameter names loaded.

        Names are first remapped with ``hf_to_vllm_mapper``. ``.bias`` entries
        with no matching parameter (e.g. when a layer was built with
        ``bias=False``) are skipped. Any other unknown name raises ``KeyError``.
        """
        weights = self.hf_to_vllm_mapper.apply(weights)
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if name.endswith(".bias") and name not in params_dict:
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Encode a batch and return the final layer-normalised hidden states.

        Args:
            input_ids: token ids; ignored when ``inputs_embeds`` is given.
            positions: position ids used by the rotary embeddings.
            intermediate_tensors: unused; kept for interface compatibility.
            inputs_embeds: optional un-normalised embeddings (as produced by
                ``embed_input_ids``).
        """
        # Always go through ModernBertEmbeddings so that supplied embeddings
        # receive the same LayerNorm as the token-id path. embed_input_ids
        # returns un-normalised embeddings, and layer 0 has no attention
        # pre-norm, so skipping this would feed un-normalised input into it.
        hidden_states = self.embeddings(
            input_ids=input_ids, inputs_embeds=inputs_embeds
        )
        outputs = self.encoder_layer(
            hidden_states=hidden_states,
            position_ids=positions,
        )
        return self.final_norm(outputs)
|
|
|
|
|
|
class ModernBertPooler(Pooler):
    """Sequence pooling followed by ModernBERT's ``dense -> act -> norm`` head.

    The pooling strategy is selected by ``config.classifier_pooling``. The
    head mirrors ``ModernBertPredictionHead``, including its use of
    ``config.classifier_activation``.
    """

    def __init__(self, config: ModernBertConfig):
        super().__init__()

        pooling_type = PoolingType[config.classifier_pooling.upper()]
        self.pooling = PoolingMethod.from_pooling_type(pooling_type)
        self.dense = nn.Linear(
            config.hidden_size, config.hidden_size, config.classifier_bias
        )
        # Use the configured head activation, as ModernBertPredictionHead
        # does; fall back to exact GELU, which this head previously hard-coded.
        self.act = ACT2FN[getattr(config, "classifier_activation", "gelu")]
        self.norm = nn.LayerNorm(
            config.hidden_size, eps=config.norm_eps, bias=config.norm_bias
        )

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Delegate to the underlying pooling method."""
        return self.pooling.get_supported_tasks()

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Delegate to the underlying pooling method."""
        return self.pooling.get_pooling_updates(task)

    def _head(self, pooled_output: torch.Tensor):
        """Cast to the dense layer's dtype and apply ``norm(act(dense(x)))``."""
        pooled_output = pooled_output.to(self.dense.weight.dtype)
        return self.norm(self.act(self.dense(pooled_output)))

    def forward(
        self,
        hidden_states: torch.Tensor | list[torch.Tensor],
        pooling_metadata: PoolingMetadata,
    ) -> torch.Tensor | list[torch.Tensor]:
        """Pool hidden states, then apply the head to each pooled output.

        The pooling method may return either a single tensor or a list of
        tensors; the head is applied element-wise in the list case.
        """
        pooled_output = self.pooling(hidden_states, pooling_metadata)

        if isinstance(pooled_output, list):
            return [self._head(output) for output in pooled_output]
        return self._head(pooled_output)
|
|
|
|
|
|
@default_pooling_type("CLS")
class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
    """Sequence classification / cross-encoder model built on ModernBertModel.

    The backbone output is pooled by ``ModernBertPooler`` and scored by a
    linear ``classifier``. Token-level classification is also exposed through
    the same classifier.
    """

    is_pooling_model = True

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        model_config = vllm_config.model_config
        config = model_config.hf_config
        self.config = config
        self.model = ModernBertModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "modernbert")
        )
        self.classifier = nn.Linear(
            config.hidden_size,
            config.num_labels,
            dtype=model_config.head_dtype,
        )
        self.pooling = ModernBertPooler(config)

        pooler_config = model_config.pooler_config
        assert pooler_config is not None

        def make_classifier_pooler(act_fn):
            # "classify" and "score" share pooling + classifier and differ
            # only in the activation name passed to ClassifierPooler.
            return ClassifierPooler(
                pooling=self.pooling, classifier=self.classifier, act_fn=act_fn
            )

        self.pooler = DispatchPooler(
            {
                "token_classify": Pooler.for_token_classify(
                    pooler_config, classifier=self.classifier
                ),
                "classify": make_classifier_pooler("classify"),
                "score": make_classifier_pooler("score"),
            }
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the backbone's raw token embeddings for ``input_ids``."""
        return self.model.embed_input_ids(input_ids)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load backbone, classifier and head weights from an HF checkpoint.

        Names starting with ``model.`` go to the backbone (prefix stripped).
        ``classifier*`` names are loaded as-is, and ``head.<rest>`` is loaded
        into ``pooling.<rest>``. Any other names are ignored.
        """
        # Filled lazily while the backbone consumes the generator below.
        leftover: list[tuple[str, torch.Tensor]] = []

        def backbone_weights():
            for name, tensor in weights:
                if not name.startswith("model."):
                    leftover.append((name, tensor))
                    continue
                yield name[len("model.") :], tensor

        self.model.load_weights(backbone_weights())

        params = dict(self.named_parameters())
        for name, tensor in leftover:
            if name.startswith("classifier"):
                target = params[name]
            elif name.startswith("head"):
                target = params["pooling." + name[len("head.") :]]
            else:
                continue
            loader = getattr(target, "weight_loader", default_weight_loader)
            loader(target, tensor)

    def forward(
        self,
        input_ids: torch.LongTensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Return backbone hidden states; pooling/classification happen in ``pooler``."""
        return self.model(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            positions=positions,
        )
|
|
|
|
|
|
class ModernBertPredictionHead(nn.Module):
    """Per-token head: ``norm(act(dense(hidden_states)))``.

    The activation comes from ``config.classifier_activation``. ``norm_eps``
    and ``norm_bias`` are optional on the config and default to 1e-5 / True.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        width = config.hidden_size
        self.dense = nn.Linear(width, width, bias=config.classifier_bias)
        self.act = ACT2FN[config.classifier_activation]
        self.norm = nn.LayerNorm(
            width,
            eps=getattr(config, "norm_eps", 1e-5),
            bias=getattr(config, "norm_bias", True),
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply dense projection, activation and LayerNorm in sequence."""
        projected = self.dense(hidden_states)
        activated = self.act(projected)
        return self.norm(activated)
|
|
|
|
|
|
@attn_type("encoder_only")
@default_pooling_type("ALL")
class ModernBertForTokenClassification(nn.Module):
    """Per-token classifier: ModernBERT backbone -> prediction head -> linear.

    ``forward`` returns per-token logits computed in ``head_dtype``.
    """

    is_pooling_model = True

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        model_config = vllm_config.model_config
        config = model_config.hf_config
        self.head_dtype = model_config.head_dtype
        self.num_labels = config.num_labels
        self.model = ModernBertModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "modernbert")
        )
        self.head = ModernBertPredictionHead(config)
        self.classifier = nn.Linear(
            config.hidden_size, config.num_labels, dtype=self.head_dtype
        )

        pooler_config = model_config.pooler_config
        assert pooler_config is not None

        token_pooler = Pooler.for_token_classify(pooler_config=pooler_config)
        self.pooler = DispatchPooler({"token_classify": token_pooler})

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the backbone's raw token embeddings for ``input_ids``."""
        return self.model.embed_input_ids(input_ids)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load all weights via AutoWeightsLoader, skipping names prefixed ``drop``."""
        return AutoWeightsLoader(self, skip_prefixes=["drop"]).load_weights(weights)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Encode, apply the prediction head, cast to ``head_dtype`` and classify."""
        encoded = self.model(
            input_ids=input_ids,
            positions=positions,
            inputs_embeds=inputs_embeds,
            intermediate_tensors=intermediate_tensors,
        )
        features = self.head(encoded).to(self.head_dtype)
        return self.classifier(features)
|