From 5f696c33b1fbf33fe91ecdd958874b9dd52f79b4 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Thu, 18 Sep 2025 23:22:01 +0800 Subject: [PATCH] [New Model] Support BertForTokenClassification / Named Entity Recognition (NER) task (#24872) Signed-off-by: wang.yuqi Signed-off-by: Isotr0py Co-authored-by: Isotr0py --- docs/models/supported_models.md | 11 +++ examples/offline_inference/pooling/README.md | 8 ++- examples/offline_inference/pooling/ner.py | 54 ++++++++++++++ examples/online_serving/pooling/README.md | 6 ++ examples/online_serving/pooling/ner.py | 71 +++++++++++++++++++ .../pooling/test_token_classification.py | 39 ++++++++++ tests/models/registry.py | 1 + vllm/entrypoints/llm.py | 4 ++ vllm/model_executor/models/bert.py | 52 ++++++++++++++ vllm/model_executor/models/registry.py | 1 + vllm/v1/attention/backends/flex_attention.py | 12 +++- 11 files changed, 257 insertions(+), 2 deletions(-) create mode 100644 examples/offline_inference/pooling/ner.py create mode 100644 examples/online_serving/pooling/ner.py create mode 100644 tests/models/language/pooling/test_token_classification.py diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 7aeaeca97699c..b67ebcbe3c81a 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -554,6 +554,17 @@ If your model is not in the above list, we will try to automatically convert the For process-supervised reward models such as `peiyi9979/math-shepherd-mistral-7b-prm`, the pooling config should be set explicitly, e.g.: `--override-pooler-config '{"pooling_type": "STEP", "step_tag_id": 123, "returned_token_ids": [456, 789]}'`. +#### Token Classification + +These models primarily support the [`LLM.encode`](./pooling_models.md#llmencode) API. + +| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) | +|--------------|--------|-------------------|-----------------------------|-----------------------------------------|---------------------| +| `BertForTokenClassification` | bert-based | `boltuix/NeuroBERT-NER` (see note), etc. | | | ✅︎ | + +!!! note + Named Entity Recognition (NER) usage, please refer to , . + [](){ #supported-mm-models } ## List of Multimodal Language Models diff --git a/examples/offline_inference/pooling/README.md b/examples/offline_inference/pooling/README.md index 8693f5e08e0ba..79afbd9cfac47 100644 --- a/examples/offline_inference/pooling/README.md +++ b/examples/offline_inference/pooling/README.md @@ -26,8 +26,14 @@ python examples/offline_inference/pooling/embed_jina_embeddings_v3.py python examples/offline_inference/pooling/embed_matryoshka_fy.py ``` +## Named Entity Recognition (NER) usage + +```bash +python examples/offline_inference/pooling/ner.py +``` + ## Qwen3 reranker usage ```bash -python qwen3_reranker.py +python examples/offline_inference/pooling/qwen3_reranker.py ``` diff --git a/examples/offline_inference/pooling/ner.py b/examples/offline_inference/pooling/ner.py new file mode 100644 index 0000000000000..f18742fac0d54 --- /dev/null +++ b/examples/offline_inference/pooling/ner.py @@ -0,0 +1,54 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Adapted from https://huggingface.co/boltuix/NeuroBERT-NER + +from argparse import Namespace + +from vllm import LLM, EngineArgs +from vllm.utils import FlexibleArgumentParser + + +def parse_args(): + parser = FlexibleArgumentParser() + parser = EngineArgs.add_cli_args(parser) + # Set example specific arguments + parser.set_defaults( + model="boltuix/NeuroBERT-NER", + runner="pooling", + enforce_eager=True, + trust_remote_code=True, + ) + return parser.parse_args() + + +def main(args: Namespace): + # Sample prompts. + prompts = [ + "Barack Obama visited Microsoft headquarters in Seattle on January 2025." + ] + + # Create an LLM. + llm = LLM(**vars(args)) + tokenizer = llm.get_tokenizer() + label_map = llm.llm_engine.vllm_config.model_config.hf_config.id2label + + # Run inference + outputs = llm.encode(prompts) + + for prompt, output in zip(prompts, outputs): + logits = output.outputs.data + predictions = logits.argmax(dim=-1) + + # Map predictions to labels + tokens = tokenizer.convert_ids_to_tokens(output.prompt_token_ids) + labels = [label_map[p.item()] for p in predictions] + + # Print results + for token, label in zip(tokens, labels): + if token not in tokenizer.all_special_tokens: + print(f"{token:15} → {label}") + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/examples/online_serving/pooling/README.md b/examples/online_serving/pooling/README.md index f7926542202d6..2c271b6a32bc2 100644 --- a/examples/online_serving/pooling/README.md +++ b/examples/online_serving/pooling/README.md @@ -12,6 +12,12 @@ python examples/online_serving/pooling/cohere_rerank_client.py python examples/online_serving/pooling/jinaai_rerank_client.py ``` +## Named Entity Recognition (NER) usage + +```bash +python examples/online_serving/pooling/ner.py +``` + ## Openai chat embedding for multimodal usage ```bash diff --git a/examples/online_serving/pooling/ner.py b/examples/online_serving/pooling/ner.py new file mode 100644 index 0000000000000..9ec2bd45a0fe5 --- /dev/null +++ b/examples/online_serving/pooling/ner.py @@ -0,0 +1,71 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Adapted from https://huggingface.co/boltuix/NeuroBERT-NER + +""" +Example online usage of Pooling API for Named Entity Recognition (NER). + +Run `vllm serve --runner pooling` +to start up the server in vLLM. e.g. + +vllm serve boltuix/NeuroBERT-NER +""" + +import argparse + +import requests +import torch + + +def post_http_request(prompt: dict, api_url: str) -> requests.Response: + headers = {"User-Agent": "Test Client"} + response = requests.post(api_url, headers=headers, json=prompt) + return response + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=8000) + parser.add_argument("--model", type=str, default="boltuix/NeuroBERT-NER") + + return parser.parse_args() + + +def main(args): + from transformers import AutoConfig, AutoTokenizer + + api_url = f"http://{args.host}:{args.port}/pooling" + model_name = args.model + + # Load tokenizer and config + tokenizer = AutoTokenizer.from_pretrained(model_name) + config = AutoConfig.from_pretrained(model_name) + label_map = config.id2label + + # Input text + text = "Barack Obama visited Microsoft headquarters in Seattle on January 2025." + prompt = {"model": model_name, "input": text} + + pooling_response = post_http_request(prompt=prompt, api_url=api_url) + + # Run inference + output = pooling_response.json()["data"][0] + logits = torch.tensor(output["data"]) + predictions = logits.argmax(dim=-1) + inputs = tokenizer(text, return_tensors="pt") + + # Map predictions to labels + tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]) + labels = [label_map[p.item()] for p in predictions] + assert len(tokens) == len(predictions) + + # Print results + for token, label in zip(tokens, labels): + if token not in tokenizer.all_special_tokens: + print(f"{token:15} → {label}") + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/tests/models/language/pooling/test_token_classification.py b/tests/models/language/pooling/test_token_classification.py new file mode 100644 index 0000000000000..fd5e48a8b1449 --- /dev/null +++ b/tests/models/language/pooling/test_token_classification.py @@ -0,0 +1,39 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest +import torch +from transformers import AutoModelForTokenClassification + +from tests.models.utils import softmax + + +@pytest.mark.parametrize("model", ["boltuix/NeuroBERT-NER"]) +# The float32 is required for this tiny model to pass the test. +@pytest.mark.parametrize("dtype", ["float"]) +@torch.inference_mode +def test_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, +) -> None: + with vllm_runner(model, max_model_len=None, dtype=dtype) as vllm_model: + vllm_outputs = vllm_model.encode(example_prompts) + + with hf_runner(model, + dtype=dtype, + auto_cls=AutoModelForTokenClassification) as hf_model: + tokenizer = hf_model.tokenizer + hf_outputs = [] + for prompt in example_prompts: + inputs = tokenizer([prompt], return_tensors="pt") + inputs = hf_model.wrap_device(inputs) + output = hf_model.model(**inputs) + hf_outputs.append(softmax(output.logits[0])) + + # check logits difference + for hf_output, vllm_output in zip(hf_outputs, vllm_outputs): + hf_output = torch.tensor(hf_output).cpu().float() + vllm_output = torch.tensor(vllm_output).cpu().float() + assert torch.allclose(hf_output, vllm_output, 1e-2) diff --git a/tests/models/registry.py b/tests/models/registry.py index 93aa9d4025498..e9cc5170ade74 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -414,6 +414,7 @@ _SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS = { # [Cross-encoder] "BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2"), # noqa: E501 + "BertForTokenClassification": _HfExamplesInfo("boltuix/NeuroBERT-NER"), "GteNewForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-multilingual-reranker-base", # noqa: E501 trust_remote_code=True, hf_overrides={ diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 63e9478612bb1..df6b16c73d6e7 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -943,6 +943,10 @@ class LLM: considered legacy and may be deprecated in the future. You should instead pass them via the `inputs` parameter. """ + + if self.supported_tasks == ["encode"] and pooling_task is None: + pooling_task = "encode" + if pooling_task is None: if "embed" in self.supported_tasks: pooling_task = "embed" diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index c07e5364814ac..ee32587f6b1b4 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -611,3 +611,55 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, positions=positions, inputs_embeds=inputs_embeds, intermediate_tensors=intermediate_tensors) + + +@default_pooling_type("ALL") +class BertForTokenClassification(nn.Module): + is_pooling_model = True + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + self.head_dtype = vllm_config.model_config.head_dtype + self.num_labels = config.num_labels + self.bert = BertModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "bert"), + embedding_class=BertEmbedding) + self.classifier = nn.Linear(config.hidden_size, + config.num_labels, + dtype=self.head_dtype) + + pooler_config = vllm_config.model_config.pooler_config + assert pooler_config is not None + + self.pooler = DispatchPooler({ + "encode": + Pooler.for_encode(pooler_config), + }) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + loader = AutoWeightsLoader(self) + loaded_params = loader.load_weights(weights) + return loaded_params + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + if token_type_ids is not None: + assert self.bert.config.vocab_size < (1 << TOKEN_TYPE_SHIFT) + assert input_ids is not None + _encode_token_type_ids(input_ids, token_type_ids) + + hidden_states = self.bert(input_ids=input_ids, + positions=positions, + inputs_embeds=inputs_embeds, + intermediate_tensors=intermediate_tensors) + + hidden_states = hidden_states.to(self.head_dtype) + return self.classifier(hidden_states) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 707b57106e6d9..1382fd9e93ea3 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -193,6 +193,7 @@ _EMBEDDING_MODELS = { _CROSS_ENCODER_MODELS = { "BertForSequenceClassification": ("bert", "BertForSequenceClassification"), + "BertForTokenClassification": ("bert", "BertForTokenClassification"), "GteNewForSequenceClassification": ("bert_with_rope", "GteNewForSequenceClassification"), "ModernBertForSequenceClassification": ("modernbert", diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py index cb983494216a7..662d3984554ad 100644 --- a/vllm/v1/attention/backends/flex_attention.py +++ b/vllm/v1/attention/backends/flex_attention.py @@ -720,6 +720,15 @@ class FlexAttentionImpl(AttentionImpl): (query, key, value), ) + query = query[:, :, :num_actual_tokens, :] + if ((key_tensor.size(-2) > num_actual_tokens) + or (value_tensor.size(-2) > num_actual_tokens)): + # In the encoder-only model with torch.compile, + # qkv might be padded, which might cause exception. + # see: https://github.com/vllm-project/vllm/pull/24872#discussion_r2353252290 + key_tensor = key_tensor[:, :, :num_actual_tokens, :] + value_tensor = value_tensor[:, :, :num_actual_tokens, :] + else: assert self.attn_type == AttentionType.DECODER key_cache, value_cache = kv_cache.unbind(0) @@ -744,7 +753,8 @@ class FlexAttentionImpl(AttentionImpl): (query, key_cache, value_cache), ) - query = query[:, :, :num_actual_tokens, :] + query = query[:, :, :num_actual_tokens, :] + # Doesn't work for now -> constraint violation # torch._dynamo.try_mark_dynamic(query, 2)