[Model] Add support for ModernBertForTokenClassification (#26340)

Signed-off-by: Antoine Recanati Le Goat <antoine.recanati@sancare.fr>
Signed-off-by: antrec <antoine.recanati@gmail.com>
Co-authored-by: Antoine Recanati Le Goat <antoine.recanati@sancare.fr>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
antrec 2025-10-07 16:29:19 +02:00 committed by GitHub
parent 41f1cf38f2
commit 6f59beaf0b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 112 additions and 2 deletions

View File

@ -576,6 +576,7 @@ These models primarily support the [`LLM.encode`](./pooling_models.md#llmencode)
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) |
|--------------|--------|-------------------|-----------------------------|-----------------------------------------|---------------------|
| `BertForTokenClassification` | bert-based | `boltuix/NeuroBERT-NER` (see note), etc. | | | ✅︎ |
| `ModernBertForTokenClassification` | ModernBERT-based | `disham993/electrical-ner-ModernBERT-base` | | | ✅︎ |
!!! note
Named Entity Recognition (NER) usage, please refer to <gh-file:examples/offline_inference/pooling/ner.py>, <gh-file:examples/online_serving/pooling/ner.py>.

View File

@ -11,7 +11,38 @@ from tests.models.utils import softmax
# The float32 is required for this tiny model to pass the test.
@pytest.mark.parametrize("dtype", ["float"])
@torch.inference_mode
def test_models(
def test_bert_models(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
) -> None:
with vllm_runner(model, max_model_len=None, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.encode(example_prompts)
with hf_runner(
model, dtype=dtype, auto_cls=AutoModelForTokenClassification
) as hf_model:
tokenizer = hf_model.tokenizer
hf_outputs = []
for prompt in example_prompts:
inputs = tokenizer([prompt], return_tensors="pt")
inputs = hf_model.wrap_device(inputs)
output = hf_model.model(**inputs)
hf_outputs.append(softmax(output.logits[0]))
# check logits difference
for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
hf_output = torch.tensor(hf_output).cpu().float()
vllm_output = torch.tensor(vllm_output).cpu().float()
assert torch.allclose(hf_output, vllm_output, 1e-2)
@pytest.mark.parametrize("model", ["disham993/electrical-ner-ModernBERT-base"])
@pytest.mark.parametrize("dtype", ["float"])
@torch.inference_mode
def test_modernbert_models(
hf_runner,
vllm_runner,
example_prompts,

View File

@ -527,6 +527,9 @@ _SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS = {
"ModernBertForSequenceClassification": _HfExamplesInfo(
"Alibaba-NLP/gte-reranker-modernbert-base"
),
"ModernBertForTokenClassification": _HfExamplesInfo(
"disham993/electrical-ner-ModernBERT-base"
),
"RobertaForSequenceClassification": _HfExamplesInfo(
"cross-encoder/quora-roberta-base"
),

View File

@ -6,6 +6,7 @@ from typing import Optional, Union
import torch
from torch import nn
from transformers import ModernBertConfig
from transformers.activations import ACT2FN
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
from vllm.compilation.decorators import support_torch_compile
@ -29,7 +30,7 @@ from vllm.v1.pool.metadata import PoolingMetadata
from .interfaces import SupportsCrossEncoding
from .interfaces_base import default_pooling_type
from .utils import WeightsMapper, maybe_prefix
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
class ModernBertEmbeddings(nn.Module):
@ -379,3 +380,73 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
inputs_embeds=inputs_embeds,
positions=positions,
)
class ModernBertPredictionHead(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.dense = nn.Linear(
config.hidden_size, config.hidden_size, bias=config.classifier_bias
)
self.act = ACT2FN[config.classifier_activation]
self.norm = nn.LayerNorm(
config.hidden_size,
eps=getattr(config, "norm_eps", 1e-5),
bias=getattr(config, "norm_bias", True),
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return self.norm(self.act(self.dense(hidden_states)))
@default_pooling_type("ALL")
class ModernBertForTokenClassification(nn.Module):
is_pooling_model = True
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
self.head_dtype = vllm_config.model_config.head_dtype
self.num_labels = config.num_labels
self.model = ModernBertModel(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "modernbert")
)
self.head = ModernBertPredictionHead(config)
self.classifier = nn.Linear(
config.hidden_size, config.num_labels, dtype=self.head_dtype
)
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = DispatchPooler(
{
"encode": Pooler.for_encode(pooler_config),
}
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self, skip_prefixes=["drop"])
loaded_params = loader.load_weights(weights)
return loaded_params
def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
hidden_states = self.model(
input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors,
)
hidden_states = self.head(hidden_states)
hidden_states = hidden_states.to(self.head_dtype)
return self.classifier(hidden_states)

View File

@ -225,6 +225,10 @@ _CROSS_ENCODER_MODELS = {
"modernbert",
"ModernBertForSequenceClassification",
),
"ModernBertForTokenClassification": (
"modernbert",
"ModernBertForTokenClassification",
),
"RobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"),
"XLMRobertaForSequenceClassification": (
"roberta",