vllm/vllm/model_executor/models/modernbert.py
wang.yuqi f4b76056ee
Improve enable chunked_prefill & prefix_caching logic. (#26623)
Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
Signed-off-by: wang.yuqi <noooop@126.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
2025-11-27 22:05:48 -08:00

452 lines
16 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable, Set
import torch
from torch import nn
from transformers import ModernBertConfig
from transformers.activations import ACT2FN
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear
from vllm.model_executor.layers.pooler import (
ClassifierPooler,
DispatchPooler,
Pooler,
PoolingMethod,
PoolingParamsUpdate,
PoolingType,
)
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors
from vllm.tasks import PoolingTask
from vllm.v1.pool.metadata import PoolingMetadata
from .interfaces import SupportsCrossEncoding
from .interfaces_base import attn_type, default_pooling_type
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
class ModernBertEmbeddings(nn.Module):
    """Token-embedding lookup followed by a LayerNorm.

    Args:
        config: HF ModernBERT config. Reads ``vocab_size``, ``hidden_size``
            and ``norm_bias``. The LayerNorm epsilon is the first truthy value
            among ``norm_eps`` and ``layer_norm_eps`` (missing attributes
            count as None), falling back to ``1e-5``.
    """

    def __init__(self, config: ModernBertConfig):
        super().__init__()
        self.config = config
        self.tok_embeddings = VocabParallelEmbedding(
            config.vocab_size, config.hidden_size
        )

        layer_norm_eps = 1e-5
        for eps_attr in ("norm_eps", "layer_norm_eps"):
            candidate = getattr(config, eps_attr, None)
            # Truthiness check: a falsy value (None or 0) falls through to the
            # next candidate.
            if candidate:
                layer_norm_eps = candidate
                break
        self.norm = nn.LayerNorm(
            config.hidden_size, eps=layer_norm_eps, bias=config.norm_bias
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the raw token embeddings (no LayerNorm applied)."""
        return self.tok_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Normalize ``inputs_embeds``, looking them up from ``input_ids`` if absent."""
        if inputs_embeds is None:
            inputs_embeds = self.tok_embeddings(input_ids)
        return self.norm(inputs_embeds)
class ModernBertRotaryEmbedding(RotaryEmbedding):
    """Neox-style rotary embedding configured from a ModernBERT config.

    Args:
        config: HF ModernBERT config; ``max_position_embeddings`` is read.
        head_size: Size of each attention head.
        dim: Number of rotary dimensions.
        base: RoPE theta.
    """

    def __init__(self, config: ModernBertConfig, head_size: int, dim: int, base: float):
        rope_kwargs = dict(
            head_size=head_size,
            rotary_dim=dim,
            max_position_embeddings=config.max_position_embeddings,
            base=base,
            is_neox_style=True,
            # NOTE(review): dtype is pinned to float16 regardless of the model
            # dtype; presumably this is the dtype of the base class's cos/sin
            # cache — confirm.
            dtype=torch.float16,
        )
        super().__init__(**rope_kwargs)
        self.config = config
class ModernBertAttention(nn.Module):
    """Self-attention for one ModernBERT layer.

    Layers whose ``layer_id`` is a multiple of
    ``config.global_attn_every_n_layers`` use no sliding window and
    ``global_rope_theta``. Every other layer uses a sliding window of
    ``config.local_attention // 2`` and ``local_rope_theta``, falling back to
    ``global_rope_theta`` when that is None.

    Heads are sharded across tensor-parallel ranks: ``Wqkv`` is a
    ``QKVParallelLinear`` and ``Wo`` a ``RowParallelLinear``. ``num_heads`` and
    ``all_head_size`` are therefore the per-rank values.

    Args:
        config: HF ModernBERT config.
        layer_id: Index of this layer in the encoder. Despite the ``None``
            default it must be an int, because it is used with ``%`` below.
    """

    def __init__(self, config: ModernBertConfig, layer_id: int | None = None):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.layer_id = layer_id
        self.deterministic_flash_attn = config.deterministic_flash_attn

        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        # QKVParallelLinear shards heads across TP ranks, so each rank only
        # produces q/k/v for (and attends over) total_num_heads // tp_size
        # heads. Using the global head count here broke the q/k/v split and the
        # attention head count whenever tp_size > 1.
        self.num_heads = self.total_num_heads // tp_size
        self.head_dim = config.hidden_size // self.total_num_heads
        self.all_head_size = self.head_dim * self.num_heads
        self.scaling = self.head_dim**-0.5

        self.Wqkv = QKVParallelLinear(
            config.hidden_size,
            self.head_dim,
            self.total_num_heads,
            bias=config.attention_bias,
        )

        sliding_window = None
        if layer_id % config.global_attn_every_n_layers != 0:
            # Local layer: windowed attention with its own RoPE base.
            sliding_window = config.local_attention // 2
            rope_theta = (
                config.local_rope_theta
                if config.local_rope_theta is not None
                else config.global_rope_theta
            )
        else:
            rope_theta = config.global_rope_theta

        self.rotary_emb = ModernBertRotaryEmbedding(
            config=config, head_size=self.head_dim, dim=self.head_dim, base=rope_theta
        )
        self.attn = EncoderOnlyAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            # NOTE(review): the prefix omits the parent module path; confirm
            # it is unique where attention layers are registered.
            prefix=f"{layer_id}.attn",
            per_layer_sliding_window=sliding_window,
        )
        # RowParallelLinear's default expects the per-rank attention output,
        # which is what self.attn produces.
        self.Wo = RowParallelLinear(
            config.hidden_size, config.hidden_size, bias=config.attention_bias
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Project to q/k/v, apply RoPE, attend, and project back out."""
        qkv, _ = self.Wqkv(hidden_states)
        q, k, v = qkv.split([self.all_head_size] * 3, dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        attn_output = self.attn(q, k, v)
        hidden_states, _ = self.Wo(attn_output)
        return hidden_states
class ModernBertMLP(nn.Module):
    """Gated feed-forward block computing ``Wo(act(x_in) * gate)``.

    ``Wi`` projects to ``2 * intermediate_size`` features. These are split in
    half along the last dimension into the activation input and the gate.

    Args:
        config: HF ModernBERT config. Reads ``hidden_size``,
            ``intermediate_size``, ``mlp_bias`` and ``hidden_activation``
            (treated as ``"gelu"`` when the attribute is absent).
    """

    def __init__(self, config: ModernBertConfig):
        super().__init__()
        self.config = config
        self.Wi = nn.Linear(
            config.hidden_size, int(config.intermediate_size) * 2, bias=config.mlp_bias
        )
        # Honor the configured activation instead of hard-coding GELU. The
        # transformers "gelu" entry is the exact GELU, which matches the
        # previous nn.GELU() for default configs.
        self.act = ACT2FN[getattr(config, "hidden_activation", "gelu")]
        # Wi is a replicated nn.Linear, so every TP rank holds the full
        # intermediate activation. Let Wo take its own shard of the input
        # rather than assume the input is already partitioned.
        self.Wo = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            bias=config.mlp_bias,
            input_is_parallel=False,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the gated MLP; returns only the output tensor of ``Wo``."""
        x_in, gate = self.Wi(hidden_states).chunk(2, dim=-1)
        output, _ = self.Wo(self.act(x_in) * gate)
        return output
class ModernBertLayer(nn.Module):
    """One pre-norm ModernBERT block: attention then gated MLP, each residual.

    Args:
        config: HF ModernBERT config.
        prefix: Accepted for interface compatibility; not used here.
        layer_id: Index of this layer. Layer 0 uses an identity in place of
            the pre-attention LayerNorm.
    """

    def __init__(
        self, config: ModernBertConfig, prefix: str = "", layer_id: int | None = None
    ):
        super().__init__()
        self.config = config

        def make_norm() -> nn.LayerNorm:
            return nn.LayerNorm(
                config.hidden_size, eps=config.norm_eps, bias=config.norm_bias
            )

        # NOTE(review): layer 0 skips the pre-attention norm. Presumably its
        # input is the already-normalized embedding output, and the checkpoint
        # has no layer-0 attn_norm weights — confirm.
        self.attn_norm = nn.Identity() if layer_id == 0 else make_norm()
        self.attn = ModernBertAttention(config=config, layer_id=layer_id)
        self.mlp_norm = make_norm()
        self.mlp = ModernBertMLP(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention and MLP sub-blocks, adding each to the residual."""
        attn_branch = self.attn(
            hidden_states=self.attn_norm(hidden_states), position_ids=position_ids
        )
        hidden_states = hidden_states + attn_branch
        return hidden_states + self.mlp(self.mlp_norm(hidden_states))
class ModernBertEncoderLayer(nn.Module):
    """The full stack of ``config.num_hidden_layers`` ModernBertLayer modules.

    Despite the singular name, this holds every encoder layer. The layers run
    sequentially and all receive the same ``position_ids``.

    Args:
        vllm_config: vLLM config; the HF config is read from
            ``vllm_config.model_config.hf_config``.
        prefix: Accepted for interface compatibility; not used here.
    """

    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.layers = nn.ModuleList(
            [
                ModernBertLayer(config=config, layer_id=layer_id)
                for layer_id in range(config.num_hidden_layers)
            ]
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Apply every layer in order and return the final hidden states."""
        for layer in self.layers:
            hidden_states = layer(hidden_states, position_ids)
        return hidden_states
@support_torch_compile
@default_pooling_type("CLS")
class ModernBertModel(nn.Module):
    """ModernBERT encoder backbone: embeddings, encoder layers, final LayerNorm."""

    # HF checkpoints name layers "layers.N..."; here they live under
    # self.encoder_layer.layers.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={"layers.": "encoder_layer.layers."}
    )

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.config = config
        self.embeddings = ModernBertEmbeddings(config)
        self.encoder_layer = ModernBertEncoderLayer(vllm_config)
        self.final_norm = nn.LayerNorm(
            config.hidden_size, eps=config.norm_eps, bias=config.norm_bias
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return raw token embeddings (before the embedding LayerNorm)."""
        return self.embeddings.embed_input_ids(input_ids)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load backbone weights and return the names of parameters loaded.

        Names are first remapped with ``hf_to_vllm_mapper``. A ``.bias`` tensor
        with no matching parameter is skipped (presumably a layer built with
        bias disabled). Any other unknown name raises ``KeyError``.
        """
        weights = self.hf_to_vllm_mapper.apply(weights)
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if name.endswith(".bias") and name not in params_dict:
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Encode tokens (or caller-supplied embeddings) into final hidden states."""
        # Always route through the embedding module so that caller-supplied
        # inputs_embeds get the same embedding LayerNorm as looked-up tokens.
        # embed_input_ids() returns the raw (pre-norm) token embeddings, so
        # bypassing self.embeddings here made
        # forward(inputs_embeds=embed_input_ids(ids)) disagree with
        # forward(input_ids=ids).
        hidden_states = self.embeddings(
            input_ids=input_ids, inputs_embeds=inputs_embeds
        )
        outputs = self.encoder_layer(
            hidden_states=hidden_states,
            position_ids=positions,
        )
        return self.final_norm(outputs)
class ModernBertPooler(Pooler):
    """Sequence pooler followed by a ``norm(act(dense(x)))`` head.

    The pooling method is chosen by ``config.classifier_pooling``, for example
    ``"cls"`` or ``"mean"``, which is upper-cased into a ``PoolingType`` name.

    Args:
        config: HF ModernBERT config. Reads ``classifier_pooling``,
            ``hidden_size``, ``classifier_bias``, ``classifier_activation``,
            ``norm_eps`` and ``norm_bias``.
    """

    def __init__(self, config: ModernBertConfig):
        super().__init__()
        pooling_type = PoolingType[config.classifier_pooling.upper()]
        self.pooling = PoolingMethod.from_pooling_type(pooling_type)
        self.dense = nn.Linear(
            config.hidden_size, config.hidden_size, config.classifier_bias
        )
        # Use the configured classifier activation, consistent with
        # ModernBertPredictionHead, instead of a hard-coded GELU. The default
        # "gelu" is numerically unchanged.
        self.act = ACT2FN[config.classifier_activation]
        self.norm = nn.LayerNorm(
            config.hidden_size, eps=config.norm_eps, bias=config.norm_bias
        )

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Delegate to the underlying pooling method."""
        return self.pooling.get_supported_tasks()

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Delegate to the underlying pooling method."""
        return self.pooling.get_pooling_updates(task)

    def _head(self, pooled_output: torch.Tensor) -> torch.Tensor:
        # Cast to the head's parameter dtype, which may differ from the dtype
        # of the pooled backbone states.
        pooled_output = pooled_output.to(self.dense.weight.dtype)
        return self.norm(self.act(self.dense(pooled_output)))

    def forward(
        self,
        hidden_states: torch.Tensor | list[torch.Tensor],
        pooling_metadata: PoolingMetadata,
    ) -> torch.Tensor | list[torch.Tensor]:
        """Pool, then apply the head to each pooled tensor (or the batch tensor)."""
        pooled_output = self.pooling(hidden_states, pooling_metadata)
        if isinstance(pooled_output, list):
            return [self._head(output) for output in pooled_output]
        return self._head(pooled_output)
@default_pooling_type("CLS")
class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
    """ModernBERT backbone with a pooled classification head.

    The "classify" and "score" tasks pool through ``self.pooling``
    (ModernBertPooler) and then ``self.classifier``. The "token_classify" task
    applies ``self.classifier`` through ``Pooler.for_token_classify``.
    """

    is_pooling_model = True

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.config = config
        self.model = ModernBertModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "modernbert")
        )
        # The classifier runs in the configured head dtype.
        self.classifier = nn.Linear(
            config.hidden_size,
            config.num_labels,
            dtype=vllm_config.model_config.head_dtype,
        )
        self.pooling = ModernBertPooler(config)

        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None
        task_poolers = {
            "token_classify": Pooler.for_token_classify(
                pooler_config, classifier=self.classifier
            ),
        }
        # For these two tasks the activation name matches the task name.
        for task_name in ("classify", "score"):
            task_poolers[task_name] = ClassifierPooler(
                pooling=self.pooling, classifier=self.classifier, act_fn=task_name
            )
        self.pooler = DispatchPooler(task_poolers)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return raw token embeddings from the backbone."""
        return self.model.embed_input_ids(input_ids)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load HF weights: "model.*" goes to the backbone, others go to the heads.

        ``classifier*`` tensors load into the same-named parameters.
        ``head.<x>`` tensors load into ``pooling.<x>``. Any other non-backbone
        tensor is ignored.
        """
        # Filled as a side effect while the backbone consumes the generator.
        non_backbone = []

        def backbone_weights():
            for weight_name, tensor in weights:
                if weight_name.startswith("model."):
                    yield weight_name[len("model.") :], tensor
                else:
                    non_backbone.append((weight_name, tensor))

        self.model.load_weights(backbone_weights())

        params_dict = dict(self.named_parameters())
        for weight_name, tensor in non_backbone:
            if weight_name.startswith("classifier"):
                target_name = weight_name
            elif weight_name.startswith("head"):
                target_name = "pooling." + weight_name[len("head") + 1 :]
            else:
                continue
            param = params_dict[target_name]
            loader = getattr(param, "weight_loader", default_weight_loader)
            loader(param, tensor)

    def forward(
        self,
        input_ids: torch.LongTensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Return backbone hidden states; pooling and classification happen in the pooler."""
        return self.model(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            positions=positions,
        )
class ModernBertPredictionHead(nn.Module):
    """Dense -> activation -> LayerNorm over the last (hidden) dimension.

    Args:
        config: HF ModernBERT config. Reads ``hidden_size``,
            ``classifier_bias`` and ``classifier_activation``. ``norm_eps``
            defaults to ``1e-5`` and ``norm_bias`` to ``True`` when absent.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        width = config.hidden_size
        self.dense = nn.Linear(width, width, bias=config.classifier_bias)
        self.act = ACT2FN[config.classifier_activation]
        self.norm = nn.LayerNorm(
            width,
            eps=getattr(config, "norm_eps", 1e-5),
            bias=getattr(config, "norm_bias", True),
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        projected = self.dense(hidden_states)
        activated = self.act(projected)
        return self.norm(activated)
@attn_type("encoder_only")
@default_pooling_type("ALL")
class ModernBertForTokenClassification(nn.Module):
    """ModernBERT backbone with a per-token prediction head and classifier.

    ``forward`` returns per-token logits (prediction head, cast to
    ``head_dtype``, then the linear classifier). The pooler only dispatches the
    "token_classify" task.
    """

    is_pooling_model = True

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        model_config = vllm_config.model_config
        config = model_config.hf_config
        self.head_dtype = model_config.head_dtype
        self.num_labels = config.num_labels
        self.model = ModernBertModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "modernbert")
        )
        self.head = ModernBertPredictionHead(config)
        self.classifier = nn.Linear(
            config.hidden_size, config.num_labels, dtype=self.head_dtype
        )

        pooler_config = model_config.pooler_config
        assert pooler_config is not None
        token_pooler = Pooler.for_token_classify(pooler_config=pooler_config)
        self.pooler = DispatchPooler({"token_classify": token_pooler})

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return raw token embeddings from the backbone."""
        return self.model.embed_input_ids(input_ids)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load weights via AutoWeightsLoader, skipping names prefixed "drop"."""
        return AutoWeightsLoader(self, skip_prefixes=["drop"]).load_weights(weights)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Return per-token classifier logits."""
        encoded = self.model(
            input_ids=input_ids,
            positions=positions,
            inputs_embeds=inputs_embeds,
            intermediate_tensors=intermediate_tensors,
        )
        head_out = self.head(encoded).to(self.head_dtype)
        return self.classifier(head_out)