[Model] Add support for ModernBertForTokenClassification (#26340)

Signed-off-by: Antoine Recanati Le Goat <antoine.recanati@sancare.fr>
Signed-off-by: antrec <antoine.recanati@gmail.com>
Co-authored-by: Antoine Recanati Le Goat <antoine.recanati@sancare.fr>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
antrec 2025-10-07 16:29:19 +02:00 committed by GitHub
parent 41f1cf38f2
commit 6f59beaf0b
5 changed files with 112 additions and 2 deletions

View File

@@ -576,6 +576,7 @@ These models primarily support the [`LLM.encode`](./pooling_models.md#llmencode)
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) |
|--------------|--------|-------------------|-----------------------------|-----------------------------------------|---------------------|
| `BertForTokenClassification` | bert-based | `boltuix/NeuroBERT-NER` (see note), etc. | | | ✅︎ |
| `ModernBertForTokenClassification` | ModernBERT-based | `disham993/electrical-ner-ModernBERT-base` | | | ✅︎ |
!!! note
    For Named Entity Recognition (NER) usage, please refer to <gh-file:examples/offline_inference/pooling/ner.py> and <gh-file:examples/online_serving/pooling/ner.py>.
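As a rough illustration of what the linked NER examples do with a token-classification model, a minimal offline sketch is shown below. It assumes the `runner="pooling"` constructor argument and the `outputs.data` / `prompt_token_ids` fields of `LLM.encode` results described in the pooling models docs; the prompt and variable names are made up, so treat it as a sketch rather than a drop-in script.

```python
from transformers import AutoConfig, AutoTokenizer

from vllm import LLM

model_id = "disham993/electrical-ner-ModernBERT-base"
tokenizer = AutoTokenizer.from_pretrained(model_id)
id2label = AutoConfig.from_pretrained(model_id).id2label

llm = LLM(model=model_id, runner="pooling")
(output,) = llm.encode(["The 220V transformer is connected to the relay."])

# One row of label scores per prompt token.
scores = output.outputs.data
tokens = tokenizer.convert_ids_to_tokens(output.prompt_token_ids)
for token, label_id in zip(tokens, scores.argmax(dim=-1).tolist()):
    print(token, id2label[label_id])
```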

View File

@@ -11,7 +11,38 @@ from tests.models.utils import softmax
# The float32 is required for this tiny model to pass the test.
@pytest.mark.parametrize("dtype", ["float"])
@torch.inference_mode
def test_bert_models(
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
) -> None:
    with vllm_runner(model, max_model_len=None, dtype=dtype) as vllm_model:
        vllm_outputs = vllm_model.encode(example_prompts)

    with hf_runner(
        model, dtype=dtype, auto_cls=AutoModelForTokenClassification
    ) as hf_model:
        tokenizer = hf_model.tokenizer
        hf_outputs = []
        for prompt in example_prompts:
            inputs = tokenizer([prompt], return_tensors="pt")
            inputs = hf_model.wrap_device(inputs)
            output = hf_model.model(**inputs)
            hf_outputs.append(softmax(output.logits[0]))

    # check logits difference
    for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
        hf_output = torch.tensor(hf_output).cpu().float()
        vllm_output = torch.tensor(vllm_output).cpu().float()
        assert torch.allclose(hf_output, vllm_output, 1e-2)

@pytest.mark.parametrize("model", ["disham993/electrical-ner-ModernBERT-base"])
@pytest.mark.parametrize("dtype", ["float"])
@torch.inference_mode
def test_modernbert_models(
    hf_runner,
    vllm_runner,
    example_prompts,

View File

@@ -527,6 +527,9 @@ _SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS = {
"ModernBertForSequenceClassification": _HfExamplesInfo( "ModernBertForSequenceClassification": _HfExamplesInfo(
"Alibaba-NLP/gte-reranker-modernbert-base" "Alibaba-NLP/gte-reranker-modernbert-base"
), ),
"ModernBertForTokenClassification": _HfExamplesInfo(
"disham993/electrical-ner-ModernBERT-base"
),
"RobertaForSequenceClassification": _HfExamplesInfo( "RobertaForSequenceClassification": _HfExamplesInfo(
"cross-encoder/quora-roberta-base" "cross-encoder/quora-roberta-base"
), ),

View File

@@ -6,6 +6,7 @@ from typing import Optional, Union
import torch
from torch import nn
from transformers import ModernBertConfig
from transformers.activations import ACT2FN
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
from vllm.compilation.decorators import support_torch_compile
@@ -29,7 +30,7 @@ from vllm.v1.pool.metadata import PoolingMetadata
from .interfaces import SupportsCrossEncoding
from .interfaces_base import default_pooling_type
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix


class ModernBertEmbeddings(nn.Module):
@@ -379,3 +380,73 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
            inputs_embeds=inputs_embeds,
            positions=positions,
        )
class ModernBertPredictionHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.dense = nn.Linear(
            config.hidden_size, config.hidden_size, bias=config.classifier_bias
        )
        self.act = ACT2FN[config.classifier_activation]
        self.norm = nn.LayerNorm(
            config.hidden_size,
            eps=getattr(config, "norm_eps", 1e-5),
            bias=getattr(config, "norm_bias", True),
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.norm(self.act(self.dense(hidden_states)))


@default_pooling_type("ALL")
class ModernBertForTokenClassification(nn.Module):
    is_pooling_model = True

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.head_dtype = vllm_config.model_config.head_dtype
        self.num_labels = config.num_labels
        self.model = ModernBertModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "modernbert")
        )
        self.head = ModernBertPredictionHead(config)
        self.classifier = nn.Linear(
            config.hidden_size, config.num_labels, dtype=self.head_dtype
        )

        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None

        self.pooler = DispatchPooler(
            {
                "encode": Pooler.for_encode(pooler_config),
            }
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.get_input_embeddings(input_ids)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        loader = AutoWeightsLoader(self, skip_prefixes=["drop"])
        loaded_params = loader.load_weights(weights)
        return loaded_params

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        hidden_states = self.model(
            input_ids=input_ids,
            positions=positions,
            inputs_embeds=inputs_embeds,
            intermediate_tensors=intermediate_tensors,
        )
        hidden_states = self.head(hidden_states)
        hidden_states = hidden_states.to(self.head_dtype)
        return self.classifier(hidden_states)
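For intuition about the per-token path added above (ModernBertModel → prediction head → classifier), here is a self-contained shape sketch. The sizes and the GELU activation are illustrative placeholders; the real values come from the ModernBERT config (`hidden_size`, `num_labels`, `classifier_activation`).

```python
import torch
from torch import nn

# Toy stand-ins for config.hidden_size / config.num_labels (illustrative values).
hidden_size, num_labels, num_tokens = 768, 9, 16

# Mirrors ModernBertPredictionHead: dense -> activation -> LayerNorm.
head = nn.Sequential(
    nn.Linear(hidden_size, hidden_size),
    nn.GELU(),  # placeholder for ACT2FN[config.classifier_activation]
    nn.LayerNorm(hidden_size),
)
classifier = nn.Linear(hidden_size, num_labels)

hidden_states = torch.randn(num_tokens, hidden_size)  # stand-in for ModernBertModel output
logits = classifier(head(hidden_states))
assert logits.shape == (num_tokens, num_labels)  # one row of label logits per token
```

Because the model is decorated with `@default_pooling_type("ALL")` and exposes only an `encode` pooler, these per-token outputs (one row of `num_labels` scores per prompt token) are what `LLM.encode` surfaces, which is what the test above compares against the Hugging Face reference.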

View File

@@ -225,6 +225,10 @@ _CROSS_ENCODER_MODELS = {
"modernbert", "modernbert",
"ModernBertForSequenceClassification", "ModernBertForSequenceClassification",
), ),
"ModernBertForTokenClassification": (
"modernbert",
"ModernBertForTokenClassification",
),
"RobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"), "RobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"),
"XLMRobertaForSequenceClassification": ( "XLMRobertaForSequenceClassification": (
"roberta", "roberta",