# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable, Set

import torch
from torch import nn
from transformers import ModernBertConfig
from transformers.activations import ACT2FN

from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear
from vllm.model_executor.layers.pooler import (
    ClassifierPooler,
    DispatchPooler,
    Pooler,
    PoolingMethod,
    PoolingParamsUpdate,
    PoolingType,
)
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors
from vllm.tasks import PoolingTask
from vllm.v1.pool.metadata import PoolingMetadata

from .interfaces import SupportsCrossEncoding
from .interfaces_base import default_pooling_type
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix


class ModernBertEmbeddings(nn.Module):
    def __init__(self, config: ModernBertConfig):
        super().__init__()
        self.config = config
        self.tok_embeddings = VocabParallelEmbedding(
            config.vocab_size, config.hidden_size
        )
        eps = (
            getattr(config, "norm_eps", None)
            or getattr(config, "layer_norm_eps", None)
            or 1e-5
        )
        self.norm = nn.LayerNorm(config.hidden_size, eps=eps, bias=config.norm_bias)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.tok_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        if inputs_embeds is not None:
            return self.norm(inputs_embeds)
        return self.norm(self.tok_embeddings(input_ids))


class ModernBertRotaryEmbedding(RotaryEmbedding):
    def __init__(
        self, config: ModernBertConfig, head_size: int, dim: int, base: float
    ):
        super().__init__(
            head_size=head_size,
            rotary_dim=dim,
            max_position_embeddings=config.max_position_embeddings,
            base=base,
            is_neox_style=True,
            dtype=torch.float16,
        )
        self.config = config

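# ModernBERT alternates global and local attention: every
# `global_attn_every_n_layers`-th layer attends over the full sequence, while
# the remaining layers pass a sliding window of `config.local_attention // 2`
# tokens to the attention backend and may use a separate RoPE base
# (`local_rope_theta`).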
class ModernBertAttention(nn.Module):
    def __init__(self, config: ModernBertConfig, layer_id: int | None = None):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.layer_id = layer_id
        self.deterministic_flash_attn = config.deterministic_flash_attn
        self.num_heads = config.num_attention_heads
        assert self.num_heads % tp_size == 0
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.all_head_size = self.head_dim * self.num_heads
        self.scaling = self.head_dim**-0.5
        self.Wqkv = QKVParallelLinear(
            config.hidden_size,
            self.head_dim,
            self.num_heads,
            bias=config.attention_bias,
        )

        sliding_window = None
        if layer_id % config.global_attn_every_n_layers != 0:
            sliding_window = config.local_attention // 2
            rope_theta = (
                config.local_rope_theta
                if config.local_rope_theta is not None
                else config.global_rope_theta
            )
        else:
            rope_theta = config.global_rope_theta

        self.rotary_emb = ModernBertRotaryEmbedding(
            config=config, head_size=self.head_dim, dim=self.head_dim, base=rope_theta
        )
        self.attn = EncoderOnlyAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            prefix=f"{layer_id}.attn",
            per_layer_sliding_window=sliding_window,
        )
        self.Wo = RowParallelLinear(
            config.hidden_size, config.hidden_size, bias=config.attention_bias
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> torch.Tensor:
        qkv, _ = self.Wqkv(hidden_states)
        q, k, v = qkv.split([self.all_head_size] * 3, dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        attn_output = self.attn(q, k, v)
        hidden_states, _ = self.Wo(attn_output)
        return hidden_states


class ModernBertMLP(nn.Module):
    def __init__(self, config: ModernBertConfig):
        super().__init__()
        self.config = config
        self.Wi = nn.Linear(
            config.hidden_size, int(config.intermediate_size) * 2, bias=config.mlp_bias
        )
        self.act = nn.GELU()
        self.Wo = RowParallelLinear(
            config.intermediate_size, config.hidden_size, bias=config.mlp_bias
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        input, gate = self.Wi(hidden_states).chunk(2, dim=-1)
        return self.Wo(self.act(input) * gate)[0]


class ModernBertLayer(nn.Module):
    def __init__(
        self, config: ModernBertConfig, prefix: str = "", layer_id: int | None = None
    ):
        super().__init__()
        self.config = config
        if layer_id == 0:
            self.attn_norm = nn.Identity()
        else:
            self.attn_norm = nn.LayerNorm(
                config.hidden_size, eps=config.norm_eps, bias=config.norm_bias
            )
        self.attn = ModernBertAttention(config=config, layer_id=layer_id)
        self.mlp_norm = nn.LayerNorm(
            config.hidden_size, eps=config.norm_eps, bias=config.norm_bias
        )
        self.mlp = ModernBertMLP(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> torch.Tensor:
        attn_outputs = self.attn(
            hidden_states=self.attn_norm(hidden_states), position_ids=position_ids
        )
        hidden_states = hidden_states + attn_outputs
        mlp_output = self.mlp(self.mlp_norm(hidden_states))
        hidden_states = hidden_states + mlp_output
        return hidden_states


class ModernBertEncoderLayer(nn.Module):
    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.layers = nn.ModuleList(
            [
                ModernBertLayer(config=config, layer_id=layer_id)
                for layer_id in range(config.num_hidden_layers)
            ]
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> torch.Tensor:
        for layer in self.layers:
            hidden_states = layer(hidden_states, position_ids)
        return hidden_states


@support_torch_compile
@default_pooling_type("CLS")
class ModernBertModel(nn.Module):
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={"layers.": "encoder_layer.layers."}
    )

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.config = config
        self.embeddings = ModernBertEmbeddings(config)
        self.encoder_layer = ModernBertEncoderLayer(vllm_config)
        self.final_norm = nn.LayerNorm(
            config.hidden_size, eps=config.norm_eps, bias=config.norm_bias
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embeddings.get_input_embeddings(input_ids)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        weights = self.hf_to_vllm_mapper.apply(weights)
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if name.endswith(".bias") and name not in params_dict:
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.embeddings(
                input_ids=input_ids, inputs_embeds=inputs_embeds
            )
        outputs = self.encoder_layer(
            hidden_states=hidden_states,
            position_ids=positions,
        )
        return self.final_norm(outputs)

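# Pooling head used for sequence classification / scoring: pools hidden states
# according to `config.classifier_pooling` (e.g. CLS or MEAN), then applies the
# checkpoint's dense -> GELU -> LayerNorm head. In the HF checkpoint these
# weights live under the `head.` prefix; the classifier's `load_weights` remaps
# them to `pooling.`.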
class ModernBertPooler(Pooler):
    def __init__(self, config: ModernBertConfig):
        super().__init__()
        pooling_type = PoolingType[config.classifier_pooling.upper()]
        self.pooling = PoolingMethod.from_pooling_type(pooling_type)
        self.dense = nn.Linear(
            config.hidden_size, config.hidden_size, bias=config.classifier_bias
        )
        self.act = nn.GELU()
        self.norm = nn.LayerNorm(
            config.hidden_size, eps=config.norm_eps, bias=config.norm_bias
        )

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return self.pooling.get_supported_tasks()

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        return self.pooling.get_pooling_updates(task)

    def _head(self, pooled_output: torch.Tensor):
        pooled_output = pooled_output.to(self.dense.weight.dtype)
        return self.norm(self.act(self.dense(pooled_output)))

    def forward(
        self,
        hidden_states: torch.Tensor | list[torch.Tensor],
        pooling_metadata: PoolingMetadata,
    ) -> torch.Tensor | list[torch.Tensor]:
        pooled_output = self.pooling(hidden_states, pooling_metadata)
        if isinstance(pooled_output, list):
            pooled_output = [self._head(output) for output in pooled_output]
        else:
            pooled_output = self._head(pooled_output)
        return pooled_output


@default_pooling_type("CLS")
class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
    is_pooling_model = True

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.config = config
        self.model = ModernBertModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "modernbert")
        )
        self.classifier = nn.Linear(
            config.hidden_size,
            config.num_labels,
            dtype=vllm_config.model_config.head_dtype,
        )
        self.pooling = ModernBertPooler(config)

        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None

        self.pooler = DispatchPooler(
            {
                "token_classify": Pooler.for_token_classify(
                    pooler_config, classifier=self.classifier
                ),
                "classify": ClassifierPooler(
                    pooling=self.pooling, classifier=self.classifier, act_fn="classify"
                ),
                "score": ClassifierPooler(
                    pooling=self.pooling, classifier=self.classifier, act_fn="score"
                ),
            }
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.get_input_embeddings(input_ids)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        self_weights = []

        def weight_filter():
            for name, weight in weights:
                if name.startswith("model."):
                    yield name[len("model.") :], weight
                else:
                    self_weights.append((name, weight))

        self.model.load_weights(weight_filter())

        params_dict = dict(self.named_parameters())

        for name, loaded_weight in self_weights:
            if name.startswith("classifier"):
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            if name.startswith("head"):
                param = params_dict["pooling." + name[len("head") + 1 :]]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)

    def forward(
        self,
        input_ids: torch.LongTensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        return self.model(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            positions=positions,
        )

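# Token-classification variant: the dense -> activation -> LayerNorm prediction
# head runs on every token's hidden state before the per-token classifier, and
# the "ALL" pooling type emits one output per token instead of one per sequence.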
class ModernBertPredictionHead(nn.Module):
    def __init__(self, config: ModernBertConfig):
        super().__init__()
        self.config = config
        self.dense = nn.Linear(
            config.hidden_size, config.hidden_size, bias=config.classifier_bias
        )
        self.act = ACT2FN[config.classifier_activation]
        self.norm = nn.LayerNorm(
            config.hidden_size,
            eps=getattr(config, "norm_eps", 1e-5),
            bias=getattr(config, "norm_bias", True),
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.norm(self.act(self.dense(hidden_states)))


@default_pooling_type("ALL")
class ModernBertForTokenClassification(nn.Module):
    is_pooling_model = True

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.head_dtype = vllm_config.model_config.head_dtype
        self.num_labels = config.num_labels
        self.model = ModernBertModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "modernbert")
        )
        self.head = ModernBertPredictionHead(config)
        self.classifier = nn.Linear(
            config.hidden_size, config.num_labels, dtype=self.head_dtype
        )

        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None

        self.pooler = DispatchPooler(
            {
                "token_classify": Pooler.for_token_classify(
                    pooler_config=pooler_config
                ),
            }
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.get_input_embeddings(input_ids)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        loader = AutoWeightsLoader(self, skip_prefixes=["drop"])
        return loader.load_weights(weights)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        hidden_states = self.model(
            input_ids=input_ids,
            positions=positions,
            inputs_embeds=inputs_embeds,
            intermediate_tensors=intermediate_tensors,
        )
        hidden_states = self.head(hidden_states)
        hidden_states = hidden_states.to(self.head_dtype)
        return self.classifier(hidden_states)
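
# Usage sketch (illustrative only, not part of this module; the model name is a
# placeholder): a ModernBERT cross-encoder checkpoint served through vLLM's
# offline API could be scored roughly like this:
#
#     from vllm import LLM
#
#     llm = LLM(model="<modernbert-reranker-checkpoint>")
#     outputs = llm.score("query text", ["candidate a", "candidate b"])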