[Model] Add support for ModernBertForTokenClassification (#26340)
Signed-off-by: Antoine Recanati Le Goat <antoine.recanati@sancare.fr>
Signed-off-by: antrec <antoine.recanati@gmail.com>
Co-authored-by: Antoine Recanati Le Goat <antoine.recanati@sancare.fr>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
parent 41f1cf38f2
commit 6f59beaf0b
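
The commit wires the new architecture through the docs, the pooling tests, the ModernBERT model file, and the model registry. As a hypothetical smoke test (not part of this commit), one way to confirm the registration from an installed build:

```python
# Hypothetical check, not from the PR: once this commit is included, the
# architecture name should be reported by vLLM's model registry.
from vllm import ModelRegistry

assert "ModernBertForTokenClassification" in ModelRegistry.get_supported_archs()
```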
@@ -576,6 +576,7 @@ These models primarily support the [`LLM.encode`](./pooling_models.md#llmencode)
 | Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) |
 |--------------|--------|-------------------|-----------------------------|-----------------------------------------|---------------------|
 | `BertForTokenClassification` | bert-based | `boltuix/NeuroBERT-NER` (see note), etc. | | | ✅︎ |
+| `ModernBertForTokenClassification` | ModernBERT-based | `disham993/electrical-ner-ModernBERT-base` | | | ✅︎ |

 !!! note
     For Named Entity Recognition (NER) usage, please refer to <gh-file:examples/offline_inference/pooling/ner.py>, <gh-file:examples/online_serving/pooling/ner.py>.
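
The referenced NER scripts drive the model end to end through `LLM.encode`. A minimal offline sketch along those lines (the prompt and attribute names such as `output.outputs.data` are assumptions, not excerpts from the shipped examples):

```python
# Hedged sketch of offline NER with ModernBertForTokenClassification support;
# see examples/offline_inference/pooling/ner.py for the actual example script.
from transformers import AutoConfig, AutoTokenizer

from vllm import LLM

model_id = "disham993/electrical-ner-ModernBERT-base"
config = AutoConfig.from_pretrained(model_id)        # provides id2label for decoding
tokenizer = AutoTokenizer.from_pretrained(model_id)

llm = LLM(model=model_id)  # token-classification models run as pooling models
prompts = ["The 100 kVA transformer feeds the 480V switchgear."]

for prompt, output in zip(prompts, llm.encode(prompts)):
    logits = output.outputs.data             # assumed shape: [num_tokens, num_labels]
    token_ids = tokenizer(prompt)["input_ids"]
    tokens = tokenizer.convert_ids_to_tokens(token_ids)
    labels = [config.id2label[i] for i in logits.argmax(dim=-1).tolist()]
    print(list(zip(tokens, labels)))
```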
@@ -11,7 +11,38 @@ from tests.models.utils import softmax
 # The float32 is required for this tiny model to pass the test.
 @pytest.mark.parametrize("dtype", ["float"])
 @torch.inference_mode
-def test_models(
+def test_bert_models(
+    hf_runner,
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+) -> None:
+    with vllm_runner(model, max_model_len=None, dtype=dtype) as vllm_model:
+        vllm_outputs = vllm_model.encode(example_prompts)
+
+    with hf_runner(
+        model, dtype=dtype, auto_cls=AutoModelForTokenClassification
+    ) as hf_model:
+        tokenizer = hf_model.tokenizer
+        hf_outputs = []
+        for prompt in example_prompts:
+            inputs = tokenizer([prompt], return_tensors="pt")
+            inputs = hf_model.wrap_device(inputs)
+            output = hf_model.model(**inputs)
+            hf_outputs.append(softmax(output.logits[0]))
+
+    # check logits difference
+    for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
+        hf_output = torch.tensor(hf_output).cpu().float()
+        vllm_output = torch.tensor(vllm_output).cpu().float()
+        assert torch.allclose(hf_output, vllm_output, 1e-2)
+
+
+@pytest.mark.parametrize("model", ["disham993/electrical-ner-ModernBERT-base"])
+@pytest.mark.parametrize("dtype", ["float"])
+@torch.inference_mode
+def test_modernbert_models(
     hf_runner,
     vllm_runner,
     example_prompts,

@@ -527,6 +527,9 @@ _SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS = {
     "ModernBertForSequenceClassification": _HfExamplesInfo(
         "Alibaba-NLP/gte-reranker-modernbert-base"
     ),
+    "ModernBertForTokenClassification": _HfExamplesInfo(
+        "disham993/electrical-ner-ModernBERT-base"
+    ),
     "RobertaForSequenceClassification": _HfExamplesInfo(
         "cross-encoder/quora-roberta-base"
     ),

@@ -6,6 +6,7 @@ from typing import Optional, Union
 import torch
 from torch import nn
 from transformers import ModernBertConfig
+from transformers.activations import ACT2FN

 from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
 from vllm.compilation.decorators import support_torch_compile
@@ -29,7 +30,7 @@ from vllm.v1.pool.metadata import PoolingMetadata

 from .interfaces import SupportsCrossEncoding
 from .interfaces_base import default_pooling_type
-from .utils import WeightsMapper, maybe_prefix
+from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix


 class ModernBertEmbeddings(nn.Module):

@@ -379,3 +380,73 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
             inputs_embeds=inputs_embeds,
             positions=positions,
         )
+
+
+class ModernBertPredictionHead(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.dense = nn.Linear(
+            config.hidden_size, config.hidden_size, bias=config.classifier_bias
+        )
+        self.act = ACT2FN[config.classifier_activation]
+        self.norm = nn.LayerNorm(
+            config.hidden_size,
+            eps=getattr(config, "norm_eps", 1e-5),
+            bias=getattr(config, "norm_bias", True),
+        )
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        return self.norm(self.act(self.dense(hidden_states)))
+
+
+@default_pooling_type("ALL")
+class ModernBertForTokenClassification(nn.Module):
+    is_pooling_model = True
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+        config = vllm_config.model_config.hf_config
+        self.head_dtype = vllm_config.model_config.head_dtype
+        self.num_labels = config.num_labels
+        self.model = ModernBertModel(
+            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "modernbert")
+        )
+        self.head = ModernBertPredictionHead(config)
+        self.classifier = nn.Linear(
+            config.hidden_size, config.num_labels, dtype=self.head_dtype
+        )
+
+        pooler_config = vllm_config.model_config.pooler_config
+        assert pooler_config is not None
+
+        self.pooler = DispatchPooler(
+            {
+                "encode": Pooler.for_encode(pooler_config),
+            }
+        )
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
+        loader = AutoWeightsLoader(self, skip_prefixes=["drop"])
+        loaded_params = loader.load_weights(weights)
+        return loaded_params
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor],
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        hidden_states = self.model(
+            input_ids=input_ids,
+            positions=positions,
+            inputs_embeds=inputs_embeds,
+            intermediate_tensors=intermediate_tensors,
+        )
+        hidden_states = self.head(hidden_states)
+        hidden_states = hidden_states.to(self.head_dtype)
+        return self.classifier(hidden_states)

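For orientation (not part of the diff): because the pooling type is ALL, the model keeps one hidden state per token, so the added head is simply dense → activation → LayerNorm followed by a per-token linear classifier. A standalone shape sketch, with GELU standing in for whatever `config.classifier_activation` resolves to and purely illustrative sizes:

```python
import torch
from torch import nn

hidden_size, num_labels, seq_len = 768, 9, 16  # illustrative, not read from any config

# Mirrors the ModernBertPredictionHead pipeline: dense -> activation -> LayerNorm
head = nn.Sequential(
    nn.Linear(hidden_size, hidden_size),
    nn.GELU(),
    nn.LayerNorm(hidden_size),
)
classifier = nn.Linear(hidden_size, num_labels)

hidden_states = torch.randn(seq_len, hidden_size)  # one row per token
logits = classifier(head(hidden_states))           # [seq_len, num_labels]
print(logits.shape)                                # torch.Size([16, 9])
```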
@@ -225,6 +225,10 @@ _CROSS_ENCODER_MODELS = {
         "modernbert",
         "ModernBertForSequenceClassification",
     ),
+    "ModernBertForTokenClassification": (
+        "modernbert",
+        "ModernBertForTokenClassification",
+    ),
     "RobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"),
     "XLMRobertaForSequenceClassification": (
         "roberta",