From 6f59beaf0b1f2ec922c027bcee5bf52042af7430 Mon Sep 17 00:00:00 2001
From: antrec
Date: Tue, 7 Oct 2025 16:29:19 +0200
Subject: [PATCH] [Model] Add support for ModernBertForTokenClassification (#26340)

Signed-off-by: Antoine Recanati Le Goat
Signed-off-by: antrec
Co-authored-by: Antoine Recanati Le Goat
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
---
 docs/models/supported_models.md               |  1 +
 .../pooling/test_token_classification.py      | 33 ++++++++-
 tests/models/registry.py                      |  3 +
 vllm/model_executor/models/modernbert.py      | 73 ++++++++++++++++++-
 vllm/model_executor/models/registry.py        |  4 +
 5 files changed, 112 insertions(+), 2 deletions(-)

diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index 60fe5b887952..10ccd73d8f30 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -576,6 +576,7 @@ These models primarily support the [`LLM.encode`](./pooling_models.md#llmencode)
 | Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) |
 |--------------|--------|-------------------|-----------------------------|-----------------------------------------|---------------------|
 | `BertForTokenClassification` | bert-based | `boltuix/NeuroBERT-NER` (see note), etc. | | | ✅︎ |
+| `ModernBertForTokenClassification` | ModernBERT-based | `disham993/electrical-ner-ModernBERT-base` | | | ✅︎ |
 
 !!! note
     Named Entity Recognition (NER) usage, please refer to , .

diff --git a/tests/models/language/pooling/test_token_classification.py b/tests/models/language/pooling/test_token_classification.py
index 4849f1ec4d36..f72dfb46d9fd 100644
--- a/tests/models/language/pooling/test_token_classification.py
+++ b/tests/models/language/pooling/test_token_classification.py
@@ -11,7 +11,38 @@ from tests.models.utils import softmax
 # The float32 is required for this tiny model to pass the test.
 @pytest.mark.parametrize("dtype", ["float"])
 @torch.inference_mode
-def test_models(
+def test_bert_models(
+    hf_runner,
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+) -> None:
+    with vllm_runner(model, max_model_len=None, dtype=dtype) as vllm_model:
+        vllm_outputs = vllm_model.encode(example_prompts)
+
+    with hf_runner(
+        model, dtype=dtype, auto_cls=AutoModelForTokenClassification
+    ) as hf_model:
+        tokenizer = hf_model.tokenizer
+        hf_outputs = []
+        for prompt in example_prompts:
+            inputs = tokenizer([prompt], return_tensors="pt")
+            inputs = hf_model.wrap_device(inputs)
+            output = hf_model.model(**inputs)
+            hf_outputs.append(softmax(output.logits[0]))
+
+    # check logits difference
+    for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
+        hf_output = torch.tensor(hf_output).cpu().float()
+        vllm_output = torch.tensor(vllm_output).cpu().float()
+        assert torch.allclose(hf_output, vllm_output, 1e-2)
+
+
+@pytest.mark.parametrize("model", ["disham993/electrical-ner-ModernBERT-base"])
+@pytest.mark.parametrize("dtype", ["float"])
+@torch.inference_mode
+def test_modernbert_models(
     hf_runner,
     vllm_runner,
     example_prompts,
diff --git a/tests/models/registry.py b/tests/models/registry.py
index e1d9f1d1dd74..297e8854c5bd 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -527,6 +527,9 @@ _SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS = {
     "ModernBertForSequenceClassification": _HfExamplesInfo(
         "Alibaba-NLP/gte-reranker-modernbert-base"
     ),
+    "ModernBertForTokenClassification": _HfExamplesInfo(
+        "disham993/electrical-ner-ModernBERT-base"
+    ),
     "RobertaForSequenceClassification": _HfExamplesInfo(
         "cross-encoder/quora-roberta-base"
     ),
diff --git a/vllm/model_executor/models/modernbert.py b/vllm/model_executor/models/modernbert.py
index 2e3b76aaaabc..58e2acb8ce92 100644
--- a/vllm/model_executor/models/modernbert.py
+++ b/vllm/model_executor/models/modernbert.py
@@ -6,6 +6,7 @@ from typing import Optional, Union
 import torch
 from torch import nn
 from transformers import ModernBertConfig
+from transformers.activations import ACT2FN
 
 from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
 from vllm.compilation.decorators import support_torch_compile
@@ -29,7 +30,7 @@ from vllm.v1.pool.metadata import PoolingMetadata
 
 from .interfaces import SupportsCrossEncoding
 from .interfaces_base import default_pooling_type
-from .utils import WeightsMapper, maybe_prefix
+from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
 
 
 class ModernBertEmbeddings(nn.Module):
@@ -379,3 +380,73 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
             inputs_embeds=inputs_embeds,
             positions=positions,
         )
+
+
+class ModernBertPredictionHead(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.dense = nn.Linear(
+            config.hidden_size, config.hidden_size, bias=config.classifier_bias
+        )
+        self.act = ACT2FN[config.classifier_activation]
+        self.norm = nn.LayerNorm(
+            config.hidden_size,
+            eps=getattr(config, "norm_eps", 1e-5),
+            bias=getattr(config, "norm_bias", True),
+        )
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        return self.norm(self.act(self.dense(hidden_states)))
+
+
+@default_pooling_type("ALL")
+class ModernBertForTokenClassification(nn.Module):
+    is_pooling_model = True
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+        config = vllm_config.model_config.hf_config
+        self.head_dtype = vllm_config.model_config.head_dtype
+        self.num_labels = config.num_labels
+        self.model = ModernBertModel(
+            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "modernbert")
+        )
+        self.head = ModernBertPredictionHead(config)
+        self.classifier = nn.Linear(
+            config.hidden_size, config.num_labels, dtype=self.head_dtype
+        )
+
+        pooler_config = vllm_config.model_config.pooler_config
+        assert pooler_config is not None
+
+        self.pooler = DispatchPooler(
+            {
+                "encode": Pooler.for_encode(pooler_config),
+            }
+        )
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
+        loader = AutoWeightsLoader(self, skip_prefixes=["drop"])
+        loaded_params = loader.load_weights(weights)
+        return loaded_params
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor],
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        hidden_states = self.model(
+            input_ids=input_ids,
+            positions=positions,
+            inputs_embeds=inputs_embeds,
+            intermediate_tensors=intermediate_tensors,
+        )
+        hidden_states = self.head(hidden_states)
+        hidden_states = hidden_states.to(self.head_dtype)
+        return self.classifier(hidden_states)
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index 7c324b7e7872..c680d29923f8 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -225,6 +225,10 @@ _CROSS_ENCODER_MODELS = {
         "modernbert",
         "ModernBertForSequenceClassification",
     ),
+    "ModernBertForTokenClassification": (
+        "modernbert",
+        "ModernBertForTokenClassification",
+    ),
     "RobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"),
     "XLMRobertaForSequenceClassification": (
         "roberta",
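
Usage sketch (not part of the commit): a minimal offline NER run exercising the new
architecture, assuming a vLLM build with this patch applied. The pooling-runner API
(`LLM(..., runner="pooling")`, `LLM.encode`, `output.outputs.data`) is assumed to match
the existing `BertForTokenClassification` NER examples referenced in the docs table;
the prompt text is invented for illustration.

    # usage_sketch.py -- hypothetical example, not included in this patch
    from transformers import AutoConfig
    from vllm import LLM

    model_id = "disham993/electrical-ner-ModernBERT-base"
    config = AutoConfig.from_pretrained(model_id)  # provides id2label for decoding

    llm = LLM(model=model_id, runner="pooling")
    prompts = ["Connect the ESP32 to the buck converter through a MOSFET."]

    for prompt, output in zip(prompts, llm.encode(prompts)):
        # @default_pooling_type("ALL") keeps every position: one row of logits per token
        logits = output.outputs.data  # tensor of shape [num_tokens, num_labels]
        label_ids = logits.argmax(dim=-1).tolist()
        print(prompt, [config.id2label[i] for i in label_ids])

Because the model registers under the "encode" task only, generation entrypoints are
not expected to work; per-token logits come back through the pooling path shown above.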