[Model][2/N] Improve all pooling task | Support multi-vector retrieval (#25370)

Signed-off-by: wang.yuqi <noooop@126.com>
This commit is contained in:
wang.yuqi 2025-10-15 19:14:41 +08:00 committed by GitHub
parent d4d1a6024f
commit f54f85129e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
41 changed files with 786 additions and 399 deletions

View File

@ -26,6 +26,12 @@ python examples/offline_inference/pooling/embed_jina_embeddings_v3.py
python examples/offline_inference/pooling/embed_matryoshka_fy.py python examples/offline_inference/pooling/embed_matryoshka_fy.py
``` ```
## Multi-vector retrieval usage
```bash
python examples/offline_inference/pooling/multi_vector_retrieval.py
```
## Named Entity Recognition (NER) usage ## Named Entity Recognition (NER) usage
```bash ```bash

View File

@ -0,0 +1,56 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from argparse import Namespace
from vllm import LLM, EngineArgs
from vllm.utils import FlexibleArgumentParser
def parse_args():
    """Parse the command line using the engine's CLI arguments.

    The example-specific defaults select the ``BAAI/bge-m3`` model, the
    pooling runner, and eager execution; each can still be overridden on
    the command line.
    """
    cli = EngineArgs.add_cli_args(FlexibleArgumentParser())
    # Set example specific arguments
    cli.set_defaults(model="BAAI/bge-m3", runner="pooling", enforce_eager=True)
    return cli.parse_args()
def main(args: Namespace):
    """Embed a few sample prompts and print the size of each result.

    First calls ``llm.embed`` and prints the length of every returned
    embedding, then calls ``llm.encode`` with ``pooling_task="token_embed"``
    and prints the shape of every returned ``data`` object.

    Args:
        args: Parsed CLI namespace; all of its attributes are forwarded as
            keyword arguments to the ``LLM`` constructor.
    """
    # Sample prompts.
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    # Create an LLM.
    # You should pass runner="pooling" for embedding models
    llm = LLM(**vars(args))

    # Generate embedding. The output is a list of EmbeddingRequestOutputs.
    outputs = llm.embed(prompts)

    # Print the outputs.
    print("\nGenerated Outputs:\n" + "-" * 60)
    # Only the outputs are inspected, so iterate them directly rather than
    # zipping with prompts that are never used.
    for output in outputs:
        embeds = output.outputs.embedding
        print(len(embeds))

    # Generate embedding for each token.
    # The output is a list of PoolingRequestOutput.
    outputs = llm.encode(prompts, pooling_task="token_embed")

    # Print the outputs.
    print("\nGenerated Outputs:\n" + "-" * 60)
    for output in outputs:
        multi_vector = output.outputs.data
        print(multi_vector.shape)
if __name__ == "__main__":
    # Parse the CLI options and run the example only when executed as a script.
    args = parse_args()
    main(args)

View File

@ -40,7 +40,7 @@ def main():
model_impl="terratorch", model_impl="terratorch",
) )
pooling_params = PoolingParams(task="encode", softmax=False) pooling_params = PoolingParams(task="token_classify", activation=False)
pooler_output = llm.encode( pooler_output = llm.encode(
img_prompt, img_prompt,
pooling_params=pooling_params, pooling_params=pooling_params,

View File

@ -18,6 +18,12 @@ python examples/online_serving/pooling/embedding_embed_dtype_client.py
python examples/online_serving/pooling/jinaai_rerank_client.py python examples/online_serving/pooling/jinaai_rerank_client.py
``` ```
## Multi-vector retrieval usage
```bash
python examples/online_serving/pooling/multi_vector_retrieval_client.py
```
## Named Entity Recognition (NER) usage ## Named Entity Recognition (NER) usage
```bash ```bash

View File

@ -0,0 +1,54 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Example online usage of Pooling API for multi vector retrieval.
Run `vllm serve <model> --runner pooling`
to start up the server in vLLM. e.g.
vllm serve BAAI/bge-m3
"""
import argparse
import requests
import torch
def post_http_request(prompt: dict, api_url: str) -> requests.Response:
    """POST ``prompt`` as a JSON body to ``api_url`` and return the response.

    Args:
        prompt: Request payload, sent as the JSON body.
        api_url: Full URL of the endpoint to call.

    Returns:
        The ``requests.Response`` exactly as returned by ``requests.post``;
        the status code is not checked here.
    """
    return requests.post(
        api_url,
        headers={"User-Agent": "Test Client"},
        json=prompt,
    )
def parse_args():
    """Parse the client's command-line options.

    Returns:
        Namespace with ``host`` (default ``"localhost"``), ``port``
        (default ``8000``) and ``model`` (default ``"BAAI/bge-m3"``).
    """
    cli = argparse.ArgumentParser()
    # (flag, type, default) for every supported option.
    options = (
        ("--host", str, "localhost"),
        ("--port", int, 8000),
        ("--model", str, "BAAI/bge-m3"),
    )
    for flag, kind, fallback in options:
        cli.add_argument(flag, type=kind, default=fallback)
    return cli.parse_args()
def main(args):
    """Send sample prompts to the ``/pooling`` endpoint and print result shapes.

    Args:
        args: Parsed CLI namespace providing ``host``, ``port`` and ``model``.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
    """
    api_url = f"http://{args.host}:{args.port}/pooling"
    model_name = args.model

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    prompt = {"model": model_name, "input": prompts}

    pooling_response = post_http_request(prompt=prompt, api_url=api_url)
    # Fail loudly on HTTP errors; otherwise an error payload without a
    # "data" field would only show up as a confusing KeyError below.
    pooling_response.raise_for_status()

    for output in pooling_response.json()["data"]:
        multi_vector = torch.tensor(output["data"])
        print(multi_vector.shape)
if __name__ == "__main__":
    # Parse the CLI options and run the client only when executed as a script.
    args = parse_args()
    main(args)

View File

@ -1011,8 +1011,12 @@ class VllmRunner:
req_outputs = self.llm.embed(inputs, *args, **kwargs) req_outputs = self.llm.embed(inputs, *args, **kwargs)
return [req_output.outputs.embedding for req_output in req_outputs] return [req_output.outputs.embedding for req_output in req_outputs]
def encode(self, prompts: list[str]) -> list[list[float]]: def token_embed(self, prompts: list[str]) -> list[list[float]]:
req_outputs = self.llm.encode(prompts) req_outputs = self.llm.encode(prompts, pooling_task="token_embed")
return [req_output.outputs.data for req_output in req_outputs]
def token_classify(self, prompts: list[str]) -> list[list[float]]:
req_outputs = self.llm.encode(prompts, pooling_task="token_classify")
return [req_output.outputs.data for req_output in req_outputs] return [req_output.outputs.data for req_output in req_outputs]
def reward(self, prompts: list[str]) -> list[list[float]]: def reward(self, prompts: list[str]) -> list[list[float]]:

View File

@ -63,7 +63,7 @@ def test_encode_api(llm: LLM):
# chunked prefill does not support all pooling # chunked prefill does not support all pooling
err_msg = "pooling_task must be one of.+" err_msg = "pooling_task must be one of.+"
with pytest.raises(ValueError, match=err_msg): with pytest.raises(ValueError, match=err_msg):
llm.encode(prompts, use_tqdm=False) llm.encode(prompts, pooling_task="token_classify", use_tqdm=False)
def test_score_api(llm: LLM): def test_score_api(llm: LLM):

View File

@ -35,6 +35,13 @@ def llm():
cleanup_dist_env_and_memory() cleanup_dist_env_and_memory()
@pytest.mark.skip_global_cleanup
def test_encode_api(llm: LLM):
    """``LLM.encode`` with ``token_embed`` yields an (11, 384) result for the first prompt."""
    results = llm.encode(prompts, pooling_task="token_embed", use_tqdm=False)
    first_data = results[0].outputs.data
    expected_shape = (11, 384)
    assert first_data.shape == expected_shape
def test_pooling_params(llm: LLM): def test_pooling_params(llm: LLM):
def get_outputs(normalize): def get_outputs(normalize):
outputs = llm.embed( outputs = llm.embed(

View File

@ -57,20 +57,24 @@ def test_multiple_pooling_params(llm: LLM):
] ]
# Multiple PoolingParams should be matched with each prompt # Multiple PoolingParams should be matched with each prompt
outputs = llm.encode(PROMPTS, pooling_params=pooling_params) outputs = llm.encode(PROMPTS, pooling_params=pooling_params, pooling_task="embed")
assert len(PROMPTS) == len(outputs) assert len(PROMPTS) == len(outputs)
# Exception raised, if the size of params does not match the size of prompts # Exception raised, if the size of params does not match the size of prompts
with pytest.raises(ValueError): with pytest.raises(ValueError):
outputs = llm.encode(PROMPTS, pooling_params=pooling_params[:3]) outputs = llm.encode(
PROMPTS, pooling_params=pooling_params[:3], pooling_task="embed"
)
# Single PoolingParams should be applied to every prompt # Single PoolingParams should be applied to every prompt
single_pooling_params = PoolingParams() single_pooling_params = PoolingParams()
outputs = llm.encode(PROMPTS, pooling_params=single_pooling_params) outputs = llm.encode(
PROMPTS, pooling_params=single_pooling_params, pooling_task="embed"
)
assert len(PROMPTS) == len(outputs) assert len(PROMPTS) == len(outputs)
# pooling_params is None, default params should be applied # pooling_params is None, default params should be applied
outputs = llm.encode(PROMPTS, pooling_params=None) outputs = llm.encode(PROMPTS, pooling_params=None, pooling_task="embed")
assert len(PROMPTS) == len(outputs) assert len(PROMPTS) == len(outputs)

View File

@ -36,22 +36,23 @@ def llm():
cleanup_dist_env_and_memory() cleanup_dist_env_and_memory()
@pytest.mark.skip_global_cleanup
def test_pooling_params(llm: LLM): def test_pooling_params(llm: LLM):
def get_outputs(softmax): def get_outputs(activation):
outputs = llm.reward( outputs = llm.reward(
prompts, pooling_params=PoolingParams(softmax=softmax), use_tqdm=False prompts, pooling_params=PoolingParams(activation=activation), use_tqdm=False
) )
return torch.cat([x.outputs.data for x in outputs]) return torch.cat([x.outputs.data for x in outputs])
default = get_outputs(softmax=None) default = get_outputs(activation=None)
w_softmax = get_outputs(softmax=True) w_activation = get_outputs(activation=True)
wo_softmax = get_outputs(softmax=False) wo_activation = get_outputs(activation=False)
assert torch.allclose(default, w_softmax, atol=1e-2), "Default should use softmax." assert torch.allclose(default, w_activation, atol=1e-2), (
assert not torch.allclose(w_softmax, wo_softmax, atol=1e-2), ( "Default should use activation."
"wo_softmax should not use softmax."
) )
assert torch.allclose(softmax(wo_softmax), w_softmax, atol=1e-2), ( assert not torch.allclose(w_activation, wo_activation, atol=1e-2), (
"w_softmax should be close to softmax(wo_softmax)." "wo_activation should not use activation."
)
assert torch.allclose(softmax(wo_activation), w_activation, atol=1e-2), (
"w_activation should be close to activation(wo_activation)."
) )

View File

@ -17,6 +17,7 @@ from tests.utils import RemoteOpenAIServer
from vllm.entrypoints.openai.protocol import ( from vllm.entrypoints.openai.protocol import (
EMBED_DTYPE_TO_TORCH_DTYPE, EMBED_DTYPE_TO_TORCH_DTYPE,
EmbeddingResponse, EmbeddingResponse,
PoolingResponse,
) )
from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.transformers_utils.tokenizer import get_tokenizer
@ -509,3 +510,20 @@ async def test_normalize(server: RemoteOpenAIServer, model_name: str):
assert torch.allclose(w_normal, F.normalize(wo_normal, p=2, dim=-1), atol=1e-2), ( assert torch.allclose(w_normal, F.normalize(wo_normal, p=2, dim=-1), atol=1e-2), (
"w_normal should be close to normal(wo_normal)." "w_normal should be close to normal(wo_normal)."
) )
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_pooling(server: RemoteOpenAIServer, model_name: str):
    """The ``pooling`` endpoint returns one entry holding 11 rows of 384 values."""
    payload = {
        "model": model_name,
        "input": ["The chef prepared a delicious meal."],
        "encoding_format": "float",
    }
    raw = requests.post(server.url_for("pooling"), json=payload)
    parsed = PoolingResponse.model_validate(raw.json())

    # Exactly one result for the single input text.
    assert len(parsed.data) == 1
    rows = parsed.data[0].data
    assert len(rows) == 11
    assert len(rows[0]) == 384

View File

@ -7,7 +7,7 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from tests.utils import RemoteOpenAIServer from tests.utils import RemoteOpenAIServer
from vllm.entrypoints.openai.protocol import RerankResponse from vllm.entrypoints.openai.protocol import PoolingResponse, RerankResponse
MODEL_NAME = "BAAI/bge-reranker-base" MODEL_NAME = "BAAI/bge-reranker-base"
DTYPE = "bfloat16" DTYPE = "bfloat16"
@ -159,3 +159,20 @@ async def test_activation(server: RemoteOpenAIServer, model_name: str):
assert torch.allclose(F.sigmoid(wo_activation), w_activation, atol=1e-2), ( assert torch.allclose(F.sigmoid(wo_activation), w_activation, atol=1e-2), (
"w_activation should be close to activation(wo_activation)." "w_activation should be close to activation(wo_activation)."
) )
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_pooling(server: RemoteOpenAIServer, model_name: str):
    """The ``pooling`` endpoint returns one entry holding 11 rows of a single value."""
    payload = {
        "model": model_name,
        "input": ["The chef prepared a delicious meal."],
        "encoding_format": "float",
    }
    raw = requests.post(server.url_for("pooling"), json=payload)
    parsed = PoolingResponse.model_validate(raw.json())

    # Exactly one result for the single input text.
    assert len(parsed.data) == 1
    rows = parsed.data[0].data
    assert len(rows) == 11
    assert len(rows[0]) == 1

View File

@ -0,0 +1,45 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from transformers import AutoModel
from tests.models.utils import check_embeddings_close
@pytest.mark.parametrize("model", ["BAAI/bge-m3"])
@pytest.mark.parametrize("dtype", ["half"])
@torch.inference_mode
def test_embed_models(hf_runner, vllm_runner, example_prompts, model: str, dtype: str):
    """Compare vLLM ``token_embed`` outputs with the HF model's last hidden states.

    For every example prompt, the vLLM result is checked against the first
    batch element of ``last_hidden_state`` from the HF ``AutoModel`` using
    ``check_embeddings_close`` with a tolerance of ``1e-2``.
    """
    # NOTE(review): `dtype` is parametrized but not forwarded to either
    # runner -- confirm whether that is intended.
    with vllm_runner(model, runner="pooling", max_model_len=None) as vllm_model:
        vllm_outputs = vllm_model.token_embed(example_prompts)

    with hf_runner(model, auto_cls=AutoModel) as hf_model:
        tokenize = hf_model.tokenizer

        def _hidden_states(text):
            # One prompt per forward pass; keep the single batch element,
            # upcast to float32 and move it to the CPU.
            batch = hf_model.wrap_device(tokenize([text], return_tensors="pt"))
            result = hf_model.model(**batch)
            return result.last_hidden_state[0].float().cpu()

        hf_outputs = [_hidden_states(text) for text in example_prompts]

    for expected, actual in zip(hf_outputs, vllm_outputs):
        check_embeddings_close(
            embeddings_0_lst=expected,
            embeddings_1_lst=actual,
            name_0="hf",
            name_1="vllm",
            tol=1e-2,
        )

View File

@ -93,7 +93,7 @@ def test_embed_models_using_normalize(
], ],
) )
@pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("dtype", ["half"])
def test_reward_models_using_softmax( def test_reward_models_using_activation(
hf_runner, hf_runner,
vllm_runner, vllm_runner,
example_prompts, example_prompts,
@ -104,22 +104,64 @@ def test_reward_models_using_softmax(
model, model,
max_model_len=1024, max_model_len=1024,
dtype=dtype, dtype=dtype,
pooler_config=PoolerConfig(softmax=False), pooler_config=PoolerConfig(activation=False),
) as vllm_model: ) as vllm_model:
wo_softmax = vllm_model.encode(example_prompts) wo_activation = vllm_model.reward(example_prompts)
with vllm_runner( with vllm_runner(
model, max_model_len=1024, dtype=dtype, pooler_config=PoolerConfig(softmax=True) model,
max_model_len=1024,
dtype=dtype,
pooler_config=PoolerConfig(activation=True),
) as vllm_model: ) as vllm_model:
w_softmax = vllm_model.encode(example_prompts) w_activation = vllm_model.reward(example_prompts)
for wo, w in zip(wo_softmax, w_softmax): for wo, w in zip(wo_activation, w_activation):
wo = torch.tensor(wo) wo = torch.tensor(wo)
w = torch.tensor(w) w = torch.tensor(w)
assert not torch.allclose(wo, w, atol=1e-2), ( assert not torch.allclose(wo, w, atol=1e-2), (
"pooler_config softmax is not working" "pooler_config activation is not working"
) )
assert torch.allclose(softmax(wo), w, atol=1e-2), ( assert torch.allclose(softmax(wo), w, atol=1e-2), (
"w_softmax should be close to softmax(wo_softmax)." "w_activation should be close to activation(wo_activation)."
)
@pytest.mark.parametrize(
"model",
[
"intfloat/multilingual-e5-small",
],
)
@pytest.mark.parametrize("dtype", ["half"])
def test_multi_vector_retrieval_models_using_normalize(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
) -> None:
with vllm_runner(
model,
max_model_len=512,
dtype=dtype,
pooler_config=PoolerConfig(normalize=False),
) as vllm_model:
wo_normalize = vllm_model.token_embed(example_prompts)
with vllm_runner(
model,
max_model_len=512,
dtype=dtype,
pooler_config=PoolerConfig(normalize=True),
) as vllm_model:
w_normalize = vllm_model.token_embed(example_prompts)
for wo, w in zip(wo_normalize, w_normalize):
assert not torch.allclose(wo, w, atol=1e-2), (
"pooler_config normalize is not working"
)
assert torch.allclose(F.normalize(wo, p=2, dim=-1), w, atol=1e-2), (
"w_normal should be close to normal(wo_normal)."
) )

View File

@ -19,7 +19,7 @@ def test_bert_models(
dtype: str, dtype: str,
) -> None: ) -> None:
with vllm_runner(model, max_model_len=None, dtype=dtype) as vllm_model: with vllm_runner(model, max_model_len=None, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.encode(example_prompts) vllm_outputs = vllm_model.token_classify(example_prompts)
with hf_runner( with hf_runner(
model, dtype=dtype, auto_cls=AutoModelForTokenClassification model, dtype=dtype, auto_cls=AutoModelForTokenClassification
@ -50,7 +50,7 @@ def test_modernbert_models(
dtype: str, dtype: str,
) -> None: ) -> None:
with vllm_runner(model, max_model_len=None, dtype=dtype) as vllm_model: with vllm_runner(model, max_model_len=None, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.encode(example_prompts) vllm_outputs = vllm_model.token_classify(example_prompts)
with hf_runner( with hf_runner(
model, dtype=dtype, auto_cls=AutoModelForTokenClassification model, dtype=dtype, auto_cls=AutoModelForTokenClassification

View File

@ -39,7 +39,7 @@ def _run_test(
max_num_seqs=32, max_num_seqs=32,
default_torch_num_threads=1, default_torch_num_threads=1,
) as vllm_model: ) as vllm_model:
vllm_model.encode(prompt) vllm_model.llm.encode(prompt, pooling_task="token_classify")
MODELS = ["mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11"] MODELS = ["mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11"]

View File

@ -30,7 +30,7 @@ class MyGemma2Embedding(nn.Module):
self.pooler = DispatchPooler( self.pooler = DispatchPooler(
{ {
"encode": Pooler.for_encode(pooler_config), "token_embed": Pooler.for_token_embed(pooler_config),
"embed": Pooler.for_embed(pooler_config), "embed": Pooler.for_embed(pooler_config),
} }
) )

View File

@ -93,7 +93,7 @@ def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str):
out_data_format="b64_json", out_data_format="b64_json",
) )
pooling_params = PoolingParams(task="encode", softmax=False) pooling_params = PoolingParams(activation=False)
with vllm_runner( with vllm_runner(
model_name, model_name,
@ -108,8 +108,7 @@ def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str):
io_processor_plugin="prithvi_to_tiff", io_processor_plugin="prithvi_to_tiff",
) as llm_runner: ) as llm_runner:
pooler_output = llm_runner.get_llm().encode( pooler_output = llm_runner.get_llm().encode(
img_prompt, img_prompt, pooling_params=pooling_params, pooling_task="token_classify"
pooling_params=pooling_params,
) )
output = pooler_output[0].outputs output = pooler_output[0].outputs

View File

@ -1,10 +1,12 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
import pytest import pytest
from tests.models.utils import EmbedModelInfo from tests.models.utils import EmbedModelInfo
from vllm import PoolingParams from vllm import PoolingParams
from vllm.config import ModelConfig from vllm.config import ModelConfig, PoolerConfig
EMBEDDING_MODELS = [ EMBEDDING_MODELS = [
EmbedModelInfo("intfloat/multilingual-e5-small", is_matryoshka=False), EmbedModelInfo("intfloat/multilingual-e5-small", is_matryoshka=False),
@ -15,6 +17,15 @@ EMBEDDING_MODELS = [
), ),
] ]
# PoolingParams field names, grouped so each test below can combine the
# groups it expects `PoolingParams.verify` to reject for the task under test.
classify_parameters = ["activation"]
embed_parameters = ["dimensions", "normalize"]
step_pooling_parameters = ["step_tag_id", "returned_token_ids"]
@dataclass
class MockModelConfig:
    """Stand-in passed as ``model_config`` to ``PoolingParams.verify`` in these tests."""

    # The only field this stand-in carries.
    pooler_config: PoolerConfig
def test_task(): def test_task():
pooling_params = PoolingParams() pooling_params = PoolingParams()
@ -24,25 +35,27 @@ def test_task():
pooling_params.verify(task="score") pooling_params.verify(task="score")
with pytest.raises(ValueError): with pytest.raises(ValueError):
pooling_params.verify(task="encode") pooling_params.verify(task="classify")
def test_embed(): def test_embed():
task = "embed" task = "embed"
model_config = MockModelConfig(pooler_config=PoolerConfig(pooling_type="CLS"))
pooling_params = PoolingParams(normalize=None) pooling_params = PoolingParams(normalize=None)
pooling_params.verify(task=task) pooling_params.verify(task=task, model_config=model_config)
pooling_params = PoolingParams(normalize=True) pooling_params = PoolingParams(normalize=True)
pooling_params.verify(task=task) pooling_params.verify(task=task, model_config=model_config)
pooling_params = PoolingParams(normalize=False) pooling_params = PoolingParams(normalize=False)
pooling_params.verify(task=task) pooling_params.verify(task=task, model_config=model_config)
invalid_parameters = ["activation", "softmax"] invalid_parameters = classify_parameters + step_pooling_parameters
for p in invalid_parameters: for p in invalid_parameters:
with pytest.raises(ValueError): with pytest.raises(ValueError):
pooling_params = PoolingParams(**{p: True}) pooling_params = PoolingParams(**{p: True})
pooling_params.verify(task=task) pooling_params.verify(task=task, model_config=model_config)
@pytest.mark.parametrize("model_info", EMBEDDING_MODELS) @pytest.mark.parametrize("model_info", EMBEDDING_MODELS)
@ -73,35 +86,71 @@ def test_embed_dimensions(model_info: EmbedModelInfo):
@pytest.mark.parametrize("task", ["score", "classify"]) @pytest.mark.parametrize("task", ["score", "classify"])
def test_classify(task): def test_classify(task):
model_config = MockModelConfig(pooler_config=PoolerConfig(pooling_type="CLS"))
pooling_params = PoolingParams(activation=None) pooling_params = PoolingParams(activation=None)
pooling_params.verify(task=task) pooling_params.verify(task=task, model_config=model_config)
pooling_params = PoolingParams(activation=True) pooling_params = PoolingParams(activation=True)
pooling_params.verify(task=task) pooling_params.verify(task=task, model_config=model_config)
pooling_params = PoolingParams(activation=False) pooling_params = PoolingParams(activation=False)
pooling_params.verify(task=task) pooling_params.verify(task=task, model_config=model_config)
invalid_parameters = ["dimensions", "normalize", "softmax"] invalid_parameters = embed_parameters + step_pooling_parameters
for p in invalid_parameters: for p in invalid_parameters:
with pytest.raises(ValueError): with pytest.raises(ValueError):
pooling_params = PoolingParams(**{p: True}) pooling_params = PoolingParams(**{p: True})
pooling_params.verify(task=task) pooling_params.verify(task=task, model_config=model_config)
def test_encode(): @pytest.mark.parametrize("pooling_type", ["ALL", "STEP"])
task = "encode" def test_token_embed(pooling_type: str):
pooling_params = PoolingParams(softmax=None) task = "token_embed"
pooling_params.verify(task=task) model_config = MockModelConfig(
pooler_config=PoolerConfig(pooling_type=pooling_type)
)
pooling_params = PoolingParams(softmax=True) pooling_params = PoolingParams(normalize=None)
pooling_params.verify(task=task) pooling_params.verify(task=task, model_config=model_config)
pooling_params = PoolingParams(softmax=False) pooling_params = PoolingParams(normalize=True)
pooling_params.verify(task=task) pooling_params.verify(task=task, model_config=model_config)
pooling_params = PoolingParams(normalize=False)
pooling_params.verify(task=task, model_config=model_config)
invalid_parameters = classify_parameters
if pooling_type != "STEP":
invalid_parameters = classify_parameters + step_pooling_parameters
invalid_parameters = ["dimensions", "normalize", "activation"]
for p in invalid_parameters: for p in invalid_parameters:
with pytest.raises(ValueError): with pytest.raises(ValueError):
pooling_params = PoolingParams(**{p: True}) pooling_params = PoolingParams(**{p: True})
pooling_params.verify(task=task) pooling_params.verify(task=task, model_config=model_config)
@pytest.mark.parametrize("pooling_type", ["ALL", "STEP"])
def test_token_classify(pooling_type: str):
    """``token_classify`` accepts ``activation`` and rejects the embed parameters.

    The step-pooling parameters must also be rejected unless the pooling
    type is ``"STEP"``.
    """
    task = "token_classify"
    model_config = MockModelConfig(
        pooler_config=PoolerConfig(pooling_type=pooling_type)
    )

    # Every activation setting, including "unset", must verify cleanly.
    for flag in (None, True, False):
        PoolingParams(activation=flag).verify(task=task, model_config=model_config)

    rejected = (
        embed_parameters
        if pooling_type == "STEP"
        else embed_parameters + step_pooling_parameters
    )
    for name in rejected:
        with pytest.raises(ValueError):
            PoolingParams(**{name: True}).verify(
                task=task, model_config=model_config
            )

View File

@ -951,7 +951,7 @@ class LLM:
truncate_prompt_tokens: int | None = None, truncate_prompt_tokens: int | None = None,
use_tqdm: bool | Callable[..., tqdm] = True, use_tqdm: bool | Callable[..., tqdm] = True,
lora_request: list[LoRARequest] | LoRARequest | None = None, lora_request: list[LoRARequest] | LoRARequest | None = None,
pooling_task: PoolingTask = "encode", pooling_task: PoolingTask | None = None,
tokenization_kwargs: dict[str, Any] | None = None, tokenization_kwargs: dict[str, Any] | None = None,
) -> list[PoolingRequestOutput]: ) -> list[PoolingRequestOutput]:
"""Apply pooling to the hidden states corresponding to the input """Apply pooling to the hidden states corresponding to the input
@ -986,25 +986,24 @@ class LLM:
instead pass them via the `inputs` parameter. instead pass them via the `inputs` parameter.
""" """
if self.supported_tasks == ["encode"] and pooling_task is None: error_str = (
pooling_task = "encode" "pooling_task required for `LLM.encode`\n"
"Please use one of the more specific methods or set the "
"pooling_task when using `LLM.encode`:\n"
" - For embeddings, use `LLM.embed(...)` "
'or `pooling_task="embed"`.\n'
" - For classification logits, use `LLM.classify(...)` "
'or `pooling_task="classify"`.\n'
" - For similarity scores, use `LLM.score(...)`.\n"
" - For rewards, use `LLM.reward(...)` "
'or `pooling_task="token_classify"`\n'
" - For token classification, "
'use `pooling_task="token_classify"`\n'
' - For multi-vector retrieval, use `pooling_task="token_embed"`'
)
if pooling_task is None: if pooling_task is None:
pooling_task = "embed" if "embed" in self.supported_tasks else "encode" raise ValueError(error_str)
logger.warning_once(
"`LLM.encode` is currently using `pooling_task = %s`.\n"
"Please use one of the more specific methods or set the "
"task directly when using `LLM.encode`:\n"
" - For embeddings, use `LLM.embed(...)` "
'or `pooling_task="embed"`.\n'
" - For classification logits, use `LLM.classify(...)` "
'or `pooling_task="classify"`.\n'
" - For rewards, use `LLM.reward(...)` "
'or `pooling_task="reward"`\n'
" - For similarity scores, use `LLM.score(...)`.",
pooling_task,
)
model_config = self.model_config model_config = self.model_config
runner_type = model_config.runner_type runner_type = model_config.runner_type
@ -1206,7 +1205,7 @@ class LLM:
lora_request=lora_request, lora_request=lora_request,
pooling_params=pooling_params, pooling_params=pooling_params,
truncate_prompt_tokens=truncate_prompt_tokens, truncate_prompt_tokens=truncate_prompt_tokens,
pooling_task="encode", pooling_task="token_classify",
) )
def _embedding_score( def _embedding_score(

View File

@ -1748,16 +1748,19 @@ async def init_app_state(
else None else None
) )
state.openai_serving_pooling = ( state.openai_serving_pooling = (
OpenAIServingPooling( (
engine_client, OpenAIServingPooling(
state.openai_serving_models, engine_client,
request_logger=request_logger, state.openai_serving_models,
chat_template=resolved_chat_template, supported_tasks=supported_tasks,
chat_template_content_format=args.chat_template_content_format, request_logger=request_logger,
trust_request_chat_template=args.trust_request_chat_template, chat_template=resolved_chat_template,
log_error_stack=args.log_error_stack, chat_template_content_format=args.chat_template_content_format,
trust_request_chat_template=args.trust_request_chat_template,
log_error_stack=args.log_error_stack,
)
) )
if "encode" in supported_tasks if ("token_embed" in supported_tasks or "token_classify" in supported_tasks)
else None else None
) )
state.openai_serving_embedding = ( state.openai_serving_embedding = (

View File

@ -1682,7 +1682,7 @@ class IOProcessorRequest(OpenAIBaseModel, Generic[T]):
When using plugins IOProcessor plugins, the actual input is processed When using plugins IOProcessor plugins, the actual input is processed
by the plugin itself. Hence, we use a generic type for the request data by the plugin itself. Hence, we use a generic type for the request data
""" """
softmax: bool = True activation: bool = False
embed_dtype: str = Field( embed_dtype: str = Field(
default="float32", default="float32",
@ -1693,7 +1693,7 @@ class IOProcessorRequest(OpenAIBaseModel, Generic[T]):
) )
def to_pooling_params(self): def to_pooling_params(self):
return PoolingParams(task="encode", softmax=self.softmax) return PoolingParams(task="token_classify", activation=self.activation)
class IOProcessorResponse(OpenAIBaseModel, Generic[T]): class IOProcessorResponse(OpenAIBaseModel, Generic[T]):

View File

@ -35,6 +35,7 @@ from vllm.entrypoints.renderer import RenderConfig
from vllm.entrypoints.utils import _validate_truncation_size from vllm.entrypoints.utils import _validate_truncation_size
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import PoolingOutput, PoolingRequestOutput from vllm.outputs import PoolingOutput, PoolingRequestOutput
from vllm.tasks import SupportedTask
from vllm.utils import merge_async_iterators from vllm.utils import merge_async_iterators
logger = init_logger(__name__) logger = init_logger(__name__)
@ -62,6 +63,7 @@ class OpenAIServingPooling(OpenAIServing):
engine_client: EngineClient, engine_client: EngineClient,
models: OpenAIServingModels, models: OpenAIServingModels,
*, *,
supported_tasks: tuple[SupportedTask, ...],
request_logger: RequestLogger | None, request_logger: RequestLogger | None,
chat_template: str | None, chat_template: str | None,
chat_template_content_format: ChatTemplateContentFormatOption, chat_template_content_format: ChatTemplateContentFormatOption,
@ -75,6 +77,7 @@ class OpenAIServingPooling(OpenAIServing):
log_error_stack=log_error_stack, log_error_stack=log_error_stack,
) )
self.supported_tasks = supported_tasks
self.chat_template = chat_template self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format self.chat_template_content_format: Final = chat_template_content_format
self.trust_request_chat_template = trust_request_chat_template self.trust_request_chat_template = trust_request_chat_template
@ -178,8 +181,17 @@ class OpenAIServingPooling(OpenAIServing):
try: try:
pooling_params = request.to_pooling_params() pooling_params = request.to_pooling_params()
if "token_embed" in self.supported_tasks:
pooling_task = "token_embed"
elif "token_classify" in self.supported_tasks:
pooling_task = "token_classify"
else:
return self.create_error_response(
f"pooling_task must be one of {self.supported_tasks}."
)
try: try:
pooling_params.verify("encode", self.model_config) pooling_params.verify(pooling_task, self.model_config)
except ValueError as e: except ValueError as e:
return self.create_error_response(str(e)) return self.create_error_response(str(e))

View File

@ -64,66 +64,6 @@ class PoolingParamsUpdate:
params.requires_token_ids = self.requires_token_ids params.requires_token_ids = self.requires_token_ids
class Pooler(nn.Module, ABC):
    """The interface required for all poolers used in pooling models in vLLM."""

    @staticmethod
    def for_encode(pooler_config: PoolerConfig):
        """Build the pooler for the ``encode`` task.

        Returns a ``StepPooler`` when ``pooler_config.pooling_type`` is
        ``"STEP"``; otherwise a ``SimplePooler`` configured with
        ``PoolingType.ALL``.
        """
        if pooler_config.pooling_type != "STEP":
            return SimplePooler.from_config(
                ResolvedPoolingConfig(task="encode", pooling_type=PoolingType.ALL)
            )
        return StepPooler()

    @staticmethod
    def for_embed(pooler_config: PoolerConfig):
        """Build a ``SimplePooler`` for the ``embed`` task from ``pooler_config``."""
        return SimplePooler.from_config(
            ResolvedPoolingConfig.from_config(
                task="embed",
                pooler_config=pooler_config,
            )
        )

    @staticmethod
    def for_classify(
        pooler_config: PoolerConfig,
        classifier: ClassifierFn | None,
    ):
        """Build a ``ClassifierPooler`` for the ``classify`` task.

        The pooling method is derived from the pooling type resolved out of
        ``pooler_config``; ``classifier`` is passed through unchanged.
        """
        resolved = ResolvedPoolingConfig.from_config(
            task="classify",
            pooler_config=pooler_config,
        )
        return ClassifierPooler(
            pooling=PoolingMethod.from_pooling_type(resolved.pooling_type),
            classifier=classifier,
        )

    @abstractmethod
    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the set of pooling tasks this pooler supports."""
        raise NotImplementedError

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Return the pooling-parameter update to use for a supported task.

        The base implementation returns a default ``PoolingParamsUpdate()``.
        """
        return PoolingParamsUpdate()

    @abstractmethod
    def forward(
        self,
        hidden_states: list[torch.Tensor] | torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pool ``hidden_states`` according to ``pooling_metadata``."""
        raise NotImplementedError
def get_prompt_lens( def get_prompt_lens(
hidden_states: torch.Tensor | list[torch.Tensor], hidden_states: torch.Tensor | list[torch.Tensor],
pooling_metadata: PoolingMetadata, pooling_metadata: PoolingMetadata,
@ -237,7 +177,7 @@ class PoolingMethod(nn.Module, ABC):
class CLSPool(PoolingMethod): class CLSPool(PoolingMethod):
def get_supported_tasks(self) -> Set[PoolingTask]: def get_supported_tasks(self) -> Set[PoolingTask]:
return {"encode", "embed", "classify", "score"} return {"token_embed", "token_classify", "embed", "classify", "score"}
def forward_all( def forward_all(
self, self,
@ -253,7 +193,7 @@ class CLSPool(PoolingMethod):
class LastPool(PoolingMethod): class LastPool(PoolingMethod):
def get_supported_tasks(self) -> Set[PoolingTask]: def get_supported_tasks(self) -> Set[PoolingTask]:
return {"encode", "embed", "classify", "score"} return {"token_embed", "token_classify", "embed", "classify", "score"}
def forward_all( def forward_all(
self, self,
@ -265,7 +205,7 @@ class LastPool(PoolingMethod):
class AllPool(PoolingMethod): class AllPool(PoolingMethod):
def get_supported_tasks(self) -> Set[PoolingTask]: def get_supported_tasks(self) -> Set[PoolingTask]:
return {"encode"} return {"token_embed", "token_classify"}
def forward_all( def forward_all(
self, self,
@ -284,7 +224,7 @@ class AllPool(PoolingMethod):
class MeanPool(PoolingMethod): class MeanPool(PoolingMethod):
def get_supported_tasks(self) -> Set[PoolingTask]: def get_supported_tasks(self) -> Set[PoolingTask]:
return {"encode", "embed", "classify", "score"} return {"token_embed", "token_classify", "embed", "classify", "score"}
def forward_all( def forward_all(
self, self,
@ -398,6 +338,82 @@ class LambdaPoolerActivation(PoolerActivation):
return self.fn(pooled_data) return self.fn(pooled_data)
class Pooler(nn.Module, ABC):
    """Abstract interface implemented by every pooler used in vLLM pooling models.

    The static ``for_*`` factories build the pooler for one pooling task.
    Subclasses must implement `get_supported_tasks` and `forward`.
    """

    @staticmethod
    def for_token_embed(pooler_config: PoolerConfig):
        """Build the pooler for the ``token_embed`` task.

        Returns a `StepPooler` when ``pooler_config.pooling_type`` is
        ``"STEP"`` and an `AllPooler` otherwise; either one wraps a
        `TokenEmbeddingPoolerHead`.
        """
        token_head = TokenEmbeddingPoolerHead()
        pooler_cls = StepPooler if pooler_config.pooling_type == "STEP" else AllPooler
        return pooler_cls(head=token_head)

    @staticmethod
    def for_token_classify(
        pooler_config: PoolerConfig,
        classifier: ClassifierFn | None = None,
        act_fn: PoolerActivation | str | None = None,
    ):
        """Build the pooler for the ``token_classify`` task.

        ``classifier`` and ``act_fn`` are handed to
        `TokenClassifierPoolerHead`. The pooler is a `StepPooler` when
        ``pooler_config.pooling_type`` is ``"STEP"`` and an `AllPooler`
        otherwise.
        """
        token_head = TokenClassifierPoolerHead(classifier=classifier, act_fn=act_fn)
        pooler_cls = StepPooler if pooler_config.pooling_type == "STEP" else AllPooler
        return pooler_cls(head=token_head)

    @staticmethod
    def for_embed(pooler_config: PoolerConfig):
        """Build the `SimplePooler` for the ``embed`` task.

        The pooling method comes from the pooling type resolved for the
        ``embed`` task; the head is an `EmbeddingPoolerHead`.
        """
        resolved = ResolvedPoolingConfig.from_config(
            task="embed",
            pooler_config=pooler_config,
        )
        return SimplePooler(
            pooling=PoolingMethod.from_pooling_type(resolved.pooling_type),
            head=EmbeddingPoolerHead(),
        )

    @staticmethod
    def for_classify(
        pooler_config: PoolerConfig,
        classifier: ClassifierFn | None,
        act_fn: PoolerActivation | str | None = None,
    ):
        """Build the `ClassifierPooler` for sequence-level classification.

        The pooling method comes from the pooling type resolved for the
        ``classify`` task; ``classifier`` and ``act_fn`` are passed through
        unchanged.
        """
        resolved = ResolvedPoolingConfig.from_config(
            task="classify",
            pooler_config=pooler_config,
        )
        return ClassifierPooler(
            pooling=PoolingMethod.from_pooling_type(resolved.pooling_type),
            classifier=classifier,
            act_fn=act_fn,
        )

    @abstractmethod
    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the set of pooling tasks this pooler can serve."""
        raise NotImplementedError

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Return the pooling-parameter update to apply for a supported task.

        The base implementation returns a default-constructed
        `PoolingParamsUpdate`.
        """
        return PoolingParamsUpdate()

    @abstractmethod
    def forward(
        self,
        hidden_states: list[torch.Tensor] | torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pool ``hidden_states`` according to ``pooling_metadata``."""
        raise NotImplementedError
class PoolerHead(nn.Module): class PoolerHead(nn.Module):
def __init__(self, activation: PoolerActivation) -> None: def __init__(self, activation: PoolerActivation) -> None:
super().__init__() super().__init__()
@ -416,7 +432,6 @@ class EmbeddingPoolerHead(PoolerHead):
super().__init__(activation=PoolerNormalize()) super().__init__(activation=PoolerNormalize())
# Load ST projector if available # Load ST projector if available
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
self.projector: nn.Module | None = ( self.projector: nn.Module | None = (
_load_st_projector(vllm_config.model_config) if vllm_config else None _load_st_projector(vllm_config.model_config) if vllm_config else None
@ -471,39 +486,6 @@ class EmbeddingPoolerHead(PoolerHead):
return pooled_data return pooled_data
class RewardPoolerHead(PoolerHead):
def __init__(self) -> None:
super().__init__(activation=PoolerClassify(static_num_labels=False))
vllm_config = get_current_vllm_config()
self.head_dtype = vllm_config.model_config.head_dtype
def forward(
self,
pooled_data: list[torch.Tensor] | torch.Tensor,
pooling_metadata: PoolingMetadata,
):
if isinstance(pooled_data, list):
pooled_data = [p.to(self.head_dtype) for p in pooled_data]
else:
pooled_data = pooled_data.to(self.head_dtype)
pooling_params = get_pooling_params(pooling_metadata)
# for softmax
flags = [p.softmax for p in pooling_params]
if len(set(flags)) == 1:
if flags[0]:
pooled_data = self.activation(pooled_data)
else:
pooled_data = [
self.activation(vecs) if f else vecs
for vecs, f in zip(pooled_data, flags)
]
return pooled_data
class SimplePooler(Pooler): class SimplePooler(Pooler):
"""A layer that pools specific information from hidden states. """A layer that pools specific information from hidden states.
@ -513,20 +495,6 @@ class SimplePooler(Pooler):
3. Returns structured results as `PoolerOutput`. 3. Returns structured results as `PoolerOutput`.
""" """
@classmethod
def from_config(
cls,
pooler_config: ResolvedPoolingConfig,
) -> "SimplePooler":
pooling = PoolingMethod.from_pooling_type(pooler_config.pooling_type)
if pooler_config.task == "embed":
head = EmbeddingPoolerHead()
elif pooler_config.task == "encode":
head = RewardPoolerHead()
else:
raise NotImplementedError(f"Unknown task: {pooler_config.task}")
return cls(pooling, head)
def __init__(self, pooling: PoolingMethod, head: PoolerHead) -> None: def __init__(self, pooling: PoolingMethod, head: PoolerHead) -> None:
super().__init__() super().__init__()
@ -549,58 +517,6 @@ class SimplePooler(Pooler):
return pooled_data return pooled_data
class StepPooler(Pooler):
def __init__(
self,
) -> None:
super().__init__()
self.pooling = AllPool()
self.head = RewardPoolerHead()
def extract_states(
self,
hidden_states: torch.Tensor | list[torch.Tensor],
pooling_metadata: PoolingMetadata,
) -> list[torch.Tensor] | torch.Tensor:
pooled_data_lst = self.pooling(hidden_states, pooling_metadata)
prompt_token_ids = get_prompt_token_ids(pooling_metadata)
pooled_data = list[torch.Tensor]()
pooling_params = get_pooling_params(pooling_metadata)
for data, token_id, pooling_param in zip(
pooled_data_lst, prompt_token_ids, pooling_params
):
step_tag_id = pooling_param.step_tag_id
returned_token_ids = pooling_param.returned_token_ids
if returned_token_ids is not None and len(returned_token_ids) > 0:
data = data[:, returned_token_ids]
if step_tag_id is not None:
data = data[token_id == step_tag_id]
pooled_data.append(data)
return pooled_data
def get_supported_tasks(self) -> Set[PoolingTask]:
return {"encode"}
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
return PoolingParamsUpdate(requires_token_ids=True)
def forward(
self,
hidden_states: torch.Tensor | list[torch.Tensor],
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
pooled_data = self.extract_states(hidden_states, pooling_metadata)
pooled_data = self.head(pooled_data, pooling_metadata)
return pooled_data
class ClassifierPooler(Pooler): class ClassifierPooler(Pooler):
"""A pooling layer for classification tasks. """A pooling layer for classification tasks.
@ -611,26 +527,46 @@ class ClassifierPooler(Pooler):
""" """
@staticmethod @staticmethod
def act_fn_for_seq_cls(config: ModelConfig): def act_fn_for_seq_cls(model_config: ModelConfig):
return get_classification_activation_function(config.hf_config) return get_classification_activation_function(model_config.hf_config)
@staticmethod @staticmethod
def act_fn_for_cross_encoder(config: ModelConfig): def act_fn_for_cross_encoder(model_config: ModelConfig):
return get_cross_encoder_activation_function(config.hf_config) return get_cross_encoder_activation_function(model_config.hf_config)
@staticmethod
def resolve_act_fn(
model_config: ModelConfig,
static_num_labels: bool = True,
act_fn: PoolerActivation | str | None = None,
):
if isinstance(act_fn, str):
if act_fn == "classify":
return ClassifierPooler.act_fn_for_seq_cls(model_config)
elif act_fn == "score":
return ClassifierPooler.act_fn_for_cross_encoder(model_config)
else:
raise ValueError(f"act_fn [{act_fn=}] not supported.")
elif act_fn is None:
return PoolerClassify(static_num_labels=static_num_labels)
else:
assert callable(act_fn)
return act_fn
def __init__( def __init__(
self, self,
pooling: PoolingFn, pooling: PoolingFn,
classifier: ClassifierFn | None, classifier: ClassifierFn | None,
act_fn: PoolerActivation | None = None, act_fn: PoolerActivation | str | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
self.pooling = pooling self.pooling = pooling
self.classifier = classifier self.classifier = classifier
self.act_fn = act_fn or PoolerClassify() self.act_fn = self.resolve_act_fn(
vllm_config.model_config, static_num_labels=True, act_fn=act_fn
)
self.logit_bias: float | None = ( self.logit_bias: float | None = (
vllm_config.model_config.pooler_config.logit_bias vllm_config.model_config.pooler_config.logit_bias
) )
@ -672,6 +608,150 @@ class ClassifierPooler(Pooler):
return scores return scores
class TokenEmbeddingPoolerHead(EmbeddingPoolerHead):
    """Embedding head applied to the token states of a single request."""

    def forward(
        self, pooled_data: torch.Tensor, pooling_param: PoolingParams
    ) -> torch.Tensor:
        """Cast, project, truncate and optionally normalize token states.

        Args:
            pooled_data: Token states, shape ``[n_tokens, hidden_dimension]``.
            pooling_param: Supplies ``dimensions`` (matryoshka truncation of
                the last axis) and ``normalize`` (whether the activation is
                applied).

        Returns:
            Tensor of shape ``[n_tokens, embedding_dimension]``.
        """
        # [n_tokens, hidden_dimension]
        embeddings = pooled_data.to(self.head_dtype)

        # Sentence-Transformers projector, when one was loaded.
        projector = self.projector
        if projector is not None:
            embeddings = projector(embeddings)

        # [n_tokens, embedding_dimension]; matryoshka truncation.
        embeddings = embeddings[..., : pooling_param.dimensions]

        if not pooling_param.normalize:
            return embeddings
        return self.activation(embeddings)
class TokenClassifierPoolerHead(nn.Module):
    """Classification head applied to the token states of a single request.

    Args:
        classifier: Optional callable mapping token states to per-token
            scores. When None, the token states are used as the scores.
        act_fn: Activation specification, resolved through
            `ClassifierPooler.resolve_act_fn` with ``static_num_labels=False``.
    """

    def __init__(
        self,
        classifier: ClassifierFn | None,
        act_fn: PoolerActivation | str | None = None,
    ) -> None:
        super().__init__()
        vllm_config = get_current_vllm_config()

        self.classifier = classifier
        self.act_fn = ClassifierPooler.resolve_act_fn(
            vllm_config.model_config, static_num_labels=False, act_fn=act_fn
        )
        self.logit_bias: float | None = (
            vllm_config.model_config.pooler_config.logit_bias
        )
        self.head_dtype = vllm_config.model_config.head_dtype

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the single task this head serves."""
        return {"token_classify"}

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_param: PoolingParams,
    ) -> torch.Tensor:
        """Score every token of one request.

        Args:
            hidden_states: Token states, shape ``[n_token, hidden_size]``.
            pooling_param: Its ``activation`` flag decides whether the
                activation function is applied.

        Returns:
            Scores of shape ``[n_token, num_labels]``. The input tensor is
            never modified.
        """
        hidden_states = hidden_states.to(self.head_dtype)
        # hidden_states shape: [n_token, hidden_size]

        if self.classifier is not None:
            scores = self.classifier(hidden_states)
        else:
            scores = hidden_states
        # scores shape: [n_token, num_labels]

        if self.logit_bias is not None:
            # Subtract out of place: without a classifier (and with no dtype
            # conversion needed) `scores` is the caller's own tensor, and an
            # in-place `-=` would silently change the caller's hidden states.
            scores = scores - self.logit_bias

        if pooling_param.activation:
            scores = self.act_fn(scores)

        # scores shape: [n_token, num_labels]
        return scores
class AllPooler(Pooler):
    """Pooler that runs `AllPool` and then applies ``head`` request by request.

    Args:
        head: Module called as ``head(pooled_item, pooling_param)`` for each
            item produced by the pooling step.
    """

    def __init__(self, head: nn.Module | PoolerHead) -> None:
        super().__init__()

        self.head = head
        self.pooling = AllPool()

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the token-level tasks this pooler serves."""
        return {"token_embed", "token_classify"}

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pool ``hidden_states`` and run the head on each pooled item.

        Each pooled item is paired with the pooling params at the same
        position; the two sequences must be equally long.
        """
        per_request = self.pooling(hidden_states, pooling_metadata)
        params = get_pooling_params(pooling_metadata)
        assert len(per_request) == len(params)

        outputs = []
        for states, param in zip(per_request, params):
            outputs.append(self.head(states, param))
        return outputs
class StepPooler(Pooler):
    """Pooler that filters pooled token states per request before the head.

    The filtering is driven by each request's ``returned_token_ids`` and
    ``step_tag_id`` pooling params; see `extract_states`.

    Args:
        head: Module called as ``head(states, pooling_param)`` on each
            request's filtered states.
    """

    def __init__(self, head: nn.Module | PoolerHead) -> None:
        super().__init__()

        self.head = head
        self.pooling = AllPool()

    def extract_states(
        self,
        hidden_states: torch.Tensor | list[torch.Tensor],
        pooling_metadata: PoolingMetadata,
    ) -> torch.Tensor | list[torch.Tensor]:
        """Pool ``hidden_states`` and filter each request's result.

        For every request:
          * a non-empty ``returned_token_ids`` selects those indices along
            the second axis of the pooled data;
          * a ``step_tag_id`` that is not None keeps only the rows whose
            prompt token id equals it.

        Returns:
            A list with one tensor per request.
        """
        all_states = self.pooling(hidden_states, pooling_metadata)
        token_ids_per_prompt = get_prompt_token_ids(pooling_metadata)
        params_per_prompt = get_pooling_params(pooling_metadata)

        selected: list[torch.Tensor] = []
        for states, token_ids, params in zip(
            all_states, token_ids_per_prompt, params_per_prompt
        ):
            tag_id = params.step_tag_id
            keep_columns = params.returned_token_ids

            if keep_columns is not None and len(keep_columns) > 0:
                states = states[:, keep_columns]

            if tag_id is not None:
                states = states[token_ids == tag_id]

            selected.append(states)

        return selected

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the token-level tasks this pooler serves."""
        return {"token_embed", "token_classify"}

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Request prompt token ids, which `extract_states` reads."""
        return PoolingParamsUpdate(requires_token_ids=True)

    def forward(
        self,
        hidden_states: torch.Tensor | list[torch.Tensor],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Filter the pooled states, then run the head on each request.

        Each filtered item is paired with the pooling params at the same
        position; the two sequences must be equally long.
        """
        per_request = self.extract_states(hidden_states, pooling_metadata)
        params = get_pooling_params(pooling_metadata)
        assert len(per_request) == len(params)

        outputs = []
        for states, param in zip(per_request, params):
            outputs.append(self.head(states, param))
        return outputs
class DispatchPooler(Pooler): class DispatchPooler(Pooler):
"""Dispatches calls to a sub-pooler based on the pooling task.""" """Dispatches calls to a sub-pooler based on the pooling task."""

View File

@ -250,7 +250,7 @@ def as_embedding_model(cls: _T) -> _T:
self.pooler = DispatchPooler( self.pooler = DispatchPooler(
{ {
"encode": Pooler.for_encode(pooler_config), "token_embed": Pooler.for_token_embed(pooler_config),
"embed": Pooler.for_embed(pooler_config), "embed": Pooler.for_embed(pooler_config),
}, },
) )
@ -279,11 +279,8 @@ def as_seq_cls_model(cls: _T) -> _T:
# Lazy import # Lazy import
from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.pooler import ( from vllm.model_executor.layers.pooler import (
ClassifierPooler,
DispatchPooler, DispatchPooler,
Pooler, Pooler,
PoolingMethod,
PoolingType,
) )
from vllm.model_executor.models.interfaces import SupportsCrossEncoding from vllm.model_executor.models.interfaces import SupportsCrossEncoding
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
@ -302,42 +299,29 @@ def as_seq_cls_model(cls: _T) -> _T:
model_config.hidden_size, model_config.hidden_size,
config.num_labels, config.num_labels,
bias=False, bias=False,
params_dtype=torch.float32, params_dtype=vllm_config.model_config.head_dtype,
quant_config=quant_config, quant_config=quant_config,
return_bias=False,
prefix=maybe_prefix(prefix, "score"), prefix=maybe_prefix(prefix, "score"),
) )
pooler_config = vllm_config.model_config.pooler_config pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None assert pooler_config is not None
pooling_type_str = pooler_config.pooling_type
assert pooling_type_str is not None
pooling_type = PoolingType[pooling_type_str]
self.pooler = DispatchPooler( self.pooler = DispatchPooler(
{ {
"encode": Pooler.for_encode(pooler_config), "token_classify": Pooler.for_token_classify(
"classify": ClassifierPooler( pooler_config, classifier=self.score
pooling=PoolingMethod.from_pooling_type(pooling_type),
classifier=self._classifier,
act_fn=ClassifierPooler.act_fn_for_seq_cls(
vllm_config.model_config
),
), ),
"score": ClassifierPooler( "classify": Pooler.for_classify(
pooling=PoolingMethod.from_pooling_type(pooling_type), pooler_config, classifier=self.score, act_fn="classify"
classifier=self._classifier, ),
act_fn=ClassifierPooler.act_fn_for_cross_encoder( "score": Pooler.for_classify(
vllm_config.model_config pooler_config, classifier=self.score, act_fn="score"
),
), ),
} }
) )
def _classifier(self, x: torch.Tensor):
x, _ = self.score(x.float())
return x
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
@ -393,7 +377,11 @@ def as_reward_model(cls: _T) -> _T:
assert pooler_config is not None assert pooler_config is not None
self.pooler = DispatchPooler( self.pooler = DispatchPooler(
{"encode": Pooler.for_encode(pooler_config)}, {
"token_classify": Pooler.for_token_classify(
pooler_config=pooler_config
)
}
) )
ModelForReward.__name__ = _get_pooling_model_name(cls.__name__, "ForReward") ModelForReward.__name__ = _get_pooling_model_name(cls.__name__, "ForReward")

View File

@ -521,7 +521,7 @@ class BertEmbeddingModel(nn.Module, SupportsQuant):
def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler: def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
return DispatchPooler( return DispatchPooler(
{ {
"encode": Pooler.for_encode(pooler_config), "token_embed": Pooler.for_token_embed(pooler_config),
"embed": Pooler.for_embed(pooler_config), "embed": Pooler.for_embed(pooler_config),
} }
) )
@ -724,7 +724,7 @@ class BertSpladeSparseEmbeddingModel(BertEmbeddingModel):
return DispatchPooler( return DispatchPooler(
{ {
"encode": Pooler.for_encode(pooler_config), "token_embed": Pooler.for_token_embed(pooler_config),
"embed": SPLADESparsePooler( "embed": SPLADESparsePooler(
mlm_head=self.mlm_head, mlm_head=self.mlm_head,
cls_token_id=cls_id, cls_token_id=cls_id,
@ -821,20 +821,16 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, SupportsQu
self.pooler = DispatchPooler( self.pooler = DispatchPooler(
{ {
"encode": Pooler.for_encode(pooler_config), "token_classify": Pooler.for_token_classify(
pooler_config, classifier=self.classifier
),
"classify": ClassifierPooler( "classify": ClassifierPooler(
pooling=self.bert.pooler, pooling=self.bert.pooler,
classifier=self.classifier, classifier=self.classifier,
act_fn=ClassifierPooler.act_fn_for_seq_cls( act_fn="classify",
vllm_config.model_config
),
), ),
"score": ClassifierPooler( "score": ClassifierPooler(
pooling=self.bert.pooler, pooling=self.bert.pooler, classifier=self.classifier, act_fn="score"
classifier=self.classifier,
act_fn=ClassifierPooler.act_fn_for_cross_encoder(
vllm_config.model_config
),
), ),
} }
) )
@ -891,7 +887,9 @@ class BertForTokenClassification(nn.Module):
self.pooler = DispatchPooler( self.pooler = DispatchPooler(
{ {
"encode": Pooler.for_encode(pooler_config), "token_classify": Pooler.for_token_classify(
pooler_config=pooler_config
),
} }
) )

View File

@ -695,20 +695,16 @@ class GteNewForSequenceClassification(nn.Module, SupportsCrossEncoding):
self.pooler = DispatchPooler( self.pooler = DispatchPooler(
{ {
"encode": Pooler.for_encode(pooler_config), "token_classify": Pooler.for_token_classify(
pooler_config, classifier=self.classifier
),
"classify": ClassifierPooler( "classify": ClassifierPooler(
pooling=self.new.pooler, pooling=self.new.pooler,
classifier=self.classifier, classifier=self.classifier,
act_fn=ClassifierPooler.act_fn_for_seq_cls( act_fn="classify",
vllm_config.model_config
),
), ),
"score": ClassifierPooler( "score": ClassifierPooler(
pooling=self.new.pooler, pooling=self.new.pooler, classifier=self.classifier, act_fn="score"
classifier=self.classifier,
act_fn=ClassifierPooler.act_fn_for_cross_encoder(
vllm_config.model_config
),
), ),
} }
) )

View File

@ -837,7 +837,7 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
self.pooler = DispatchPooler( self.pooler = DispatchPooler(
{ {
"encode": Pooler.for_encode(pooler_config), "token_embed": Pooler.for_token_embed(pooler_config),
"embed": Pooler.for_embed(pooler_config), "embed": Pooler.for_embed(pooler_config),
} }
) )

View File

@ -353,8 +353,15 @@ class GPT2ForSequenceClassification(nn.Module, SupportsCrossEncoding):
self.pooler = DispatchPooler( self.pooler = DispatchPooler(
{ {
"encode": Pooler.for_encode(pooler_config), "token_classify": Pooler.for_token_classify(
"classify": Pooler.for_classify(pooler_config, classifier=self.score), pooler_config, classifier=self.score
),
"classify": Pooler.for_classify(
pooler_config, classifier=self.score, act_fn="classify"
),
"score": Pooler.for_classify(
pooler_config, classifier=self.score, act_fn="score"
),
} }
) )

View File

@ -239,7 +239,7 @@ class GritLM(LlamaForCausalLM):
if pooler_config is not None: if pooler_config is not None:
self.pooler = DispatchPooler( self.pooler = DispatchPooler(
{ {
"encode": Pooler.for_encode(pooler_config), "token_embed": Pooler.for_token_embed(pooler_config),
"embed": GritLMPooler(vllm_config.model_config), "embed": GritLMPooler(vllm_config.model_config),
} }
) )

View File

@ -444,7 +444,7 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
assert pooler_config is not None assert pooler_config is not None
self.pooler = DispatchPooler( self.pooler = DispatchPooler(
{"encode": Pooler.for_encode(pooler_config)}, {"token_classify": Pooler.for_token_classify(pooler_config)}
) )
def forward( def forward(

View File

@ -604,10 +604,14 @@ class JambaForSequenceClassification(JambaForCausalLM):
self.pooler = DispatchPooler( self.pooler = DispatchPooler(
{ {
"encode": Pooler.for_encode(pooler_config), "token_classify": Pooler.for_token_classify(
pooler_config, classifier=self.score
),
"classify": Pooler.for_classify( "classify": Pooler.for_classify(
pooler_config, pooler_config, classifier=self.score, act_fn="classify"
classifier=self.score, ),
"score": Pooler.for_classify(
pooler_config, classifier=self.score, act_fn="score"
), ),
} }
) )

View File

@ -97,9 +97,15 @@ class JinaVLForSequenceClassification(
self.score = JinaVLScorer(vllm_config.model_config) self.score = JinaVLScorer(vllm_config.model_config)
self.pooler = DispatchPooler( self.pooler = DispatchPooler(
{ {
"encode": Pooler.for_encode(pooler_config), "token_classify": Pooler.for_token_classify(
"classify": Pooler.for_classify(pooler_config, classifier=self.score), pooler_config, classifier=self.score
"score": Pooler.for_classify(pooler_config, classifier=self.score), ),
"classify": Pooler.for_classify(
pooler_config, classifier=self.score, act_fn="classify"
),
"score": Pooler.for_classify(
pooler_config, classifier=self.score, act_fn="score"
),
} }
) )

View File

@ -322,20 +322,14 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
self.pooler = DispatchPooler( self.pooler = DispatchPooler(
{ {
"encode": Pooler.for_encode(pooler_config), "token_classify": Pooler.for_token_classify(
pooler_config, classifier=self.classifier
),
"classify": ClassifierPooler( "classify": ClassifierPooler(
pooling=self.pooling, pooling=self.pooling, classifier=self.classifier, act_fn="classify"
classifier=self.classifier,
act_fn=ClassifierPooler.act_fn_for_seq_cls(
vllm_config.model_config
),
), ),
"score": ClassifierPooler( "score": ClassifierPooler(
pooling=self.pooling, pooling=self.pooling, classifier=self.classifier, act_fn="score"
classifier=self.classifier,
act_fn=ClassifierPooler.act_fn_for_cross_encoder(
vllm_config.model_config
),
), ),
} }
) )
@ -421,7 +415,9 @@ class ModernBertForTokenClassification(nn.Module):
self.pooler = DispatchPooler( self.pooler = DispatchPooler(
{ {
"encode": Pooler.for_encode(pooler_config), "token_classify": Pooler.for_token_classify(
pooler_config=pooler_config
),
} }
) )

View File

@ -107,7 +107,7 @@ class Qwen2ForRewardModel(Qwen2RewardBaseModel):
assert pooler_config is not None assert pooler_config is not None
self.pooler = DispatchPooler( self.pooler = DispatchPooler(
{"encode": Pooler.for_encode(pooler_config)}, {"token_classify": Pooler.for_token_classify(pooler_config)}
) )
@ -120,4 +120,6 @@ class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel):
pooler_config = vllm_config.model_config.pooler_config pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None assert pooler_config is not None
self.pooler = DispatchPooler({"encode": Pooler.for_encode(pooler_config)}) self.pooler = DispatchPooler(
{"token_classify": Pooler.for_token_classify(pooler_config)}
)

View File

@ -105,15 +105,7 @@ class RobertaClassificationHead(nn.Module):
@default_pooling_type("CLS") @default_pooling_type("CLS")
class RobertaEmbeddingModel(BertEmbeddingModel): class RobertaEmbeddingModel(BertEmbeddingModel):
"""A model that uses Roberta to provide embedding functionalities. """A model that uses Roberta to provide embedding functionalities."""
This class encapsulates the BertModel and provides an interface for
embedding operations and customized pooling functions.
Attributes:
model: An instance of BertModel used for forward operations.
_pooler: An instance of Pooler used for pooling operations.
"""
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix) super().__init__(vllm_config=vllm_config, prefix=prefix)
@ -212,20 +204,14 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
self.pooler = DispatchPooler( self.pooler = DispatchPooler(
{ {
"encode": Pooler.for_encode(pooler_config), "token_classify": Pooler.for_token_classify(
pooler_config=pooler_config, classifier=self.classifier
),
"classify": ClassifierPooler( "classify": ClassifierPooler(
pooling=CLSPool(), pooling=CLSPool(), classifier=self.classifier, act_fn="classify"
classifier=self.classifier,
act_fn=ClassifierPooler.act_fn_for_seq_cls(
vllm_config.model_config
),
), ),
"score": ClassifierPooler( "score": ClassifierPooler(
pooling=CLSPool(), pooling=CLSPool(), classifier=self.classifier, act_fn="score"
classifier=self.classifier,
act_fn=ClassifierPooler.act_fn_for_cross_encoder(
vllm_config.model_config
),
), ),
} }
) )

View File

@ -250,7 +250,7 @@ class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
assert pooler_config is not None assert pooler_config is not None
self.pooler = DispatchPooler( self.pooler = DispatchPooler(
{"encode": Pooler.for_encode(pooler_config)}, {"token_classify": Pooler.for_token_classify(pooler_config)}
) )
def get_input_embeddings( def get_input_embeddings(

View File

@ -135,7 +135,7 @@ class TransformersEmbeddingModel(TransformersPoolingBase):
self.pooler = DispatchPooler( self.pooler = DispatchPooler(
{ {
"encode": Pooler.for_encode(pooler_config), "token_embed": Pooler.for_token_embed(pooler_config),
"embed": Pooler.for_embed(pooler_config), "embed": Pooler.for_embed(pooler_config),
} }
) )
@ -190,20 +190,14 @@ class TransformersForSequenceClassification(TransformersPoolingBase):
self.pooler = DispatchPooler( self.pooler = DispatchPooler(
{ {
"encode": Pooler.for_encode(pooler_config), "token_classify": Pooler.for_token_classify(
pooler_config, classifier=self.classifier
),
"classify": ClassifierPooler( "classify": ClassifierPooler(
pooling=CLSPool(), pooling=CLSPool(), classifier=self.classifier, act_fn="classify"
classifier=self.classifier,
act_fn=ClassifierPooler.act_fn_for_seq_cls(
vllm_config.model_config
),
), ),
"score": ClassifierPooler( "score": ClassifierPooler(
pooling=CLSPool(), pooling=CLSPool(), classifier=self.classifier, act_fn="score"
classifier=self.classifier,
act_fn=ClassifierPooler.act_fn_for_cross_encoder(
vllm_config.model_config
),
), ),
} }
) )

View File

@ -10,7 +10,7 @@ from vllm.sampling_params import RequestOutputKind
from vllm.tasks import PoolingTask from vllm.tasks import PoolingTask
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import ModelConfig from vllm.config import ModelConfig, PoolerConfig
class PoolingParams( class PoolingParams(
@ -30,7 +30,6 @@ class PoolingParams(
if model support matryoshka representation. if model support matryoshka representation.
activation: Whether to apply activation function to activation: Whether to apply activation function to
the classification outputs. the classification outputs.
softmax: Whether to apply softmax to the reward outputs.
""" """
# --8<-- [start:common-pooling-params] # --8<-- [start:common-pooling-params]
@ -48,32 +47,19 @@ class PoolingParams(
activation: bool | None = None activation: bool | None = None
# --8<-- [end:classification-pooling-params] # --8<-- [end:classification-pooling-params]
## for reward models ## for step pooling models
softmax: bool | None = None
step_tag_id: int | None = None step_tag_id: int | None = None
returned_token_ids: list[int] | None = None returned_token_ids: list[int] | None = None
## Internal use only
task: PoolingTask | None = None task: PoolingTask | None = None
"""Internal use only."""
requires_token_ids: bool = False requires_token_ids: bool = False
"""Internal use only."""
extra_kwargs: dict[str, Any] | None = None extra_kwargs: dict[str, Any] | None = None
"""Internal use only."""
output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY
@property @property
def all_parameters(self) -> list[str]: def all_parameters(self) -> list[str]:
return [ return ["dimensions", "normalize", "activation"]
"dimensions",
"normalize",
"activation",
"softmax",
"step_tag_id",
"returned_token_ids",
]
@property @property
def valid_parameters(self): def valid_parameters(self):
@ -81,7 +67,8 @@ class PoolingParams(
"embed": ["dimensions", "normalize"], "embed": ["dimensions", "normalize"],
"classify": ["activation"], "classify": ["activation"],
"score": ["activation"], "score": ["activation"],
"encode": ["softmax", "step_tag_id", "returned_token_ids"], "token_embed": ["dimensions", "normalize"],
"token_classify": ["activation"],
} }
def clone(self) -> "PoolingParams": def clone(self) -> "PoolingParams":
@ -100,7 +87,6 @@ class PoolingParams(
# NOTE: Task validation needs to done against the model instance, # NOTE: Task validation needs to done against the model instance,
# which is not available in model config. So, it's not included # which is not available in model config. So, it's not included
# in this method # in this method
self._merge_default_parameters(model_config) self._merge_default_parameters(model_config)
self._set_default_parameters(model_config) self._set_default_parameters(model_config)
self._verify_valid_parameters() self._verify_valid_parameters()
@ -125,8 +111,34 @@ class PoolingParams(
if getattr(self, k, None) is None: if getattr(self, k, None) is None:
setattr(self, k, getattr(pooler_config, k)) setattr(self, k, getattr(pooler_config, k))
self._verify_step_pooling(pooler_config, valid_parameters)
def _verify_step_pooling(
self, pooler_config: "PoolerConfig", valid_parameters: list[str]
):
step_pooling_parameters = ["step_tag_id", "returned_token_ids"]
if pooler_config.pooling_type != "STEP":
invalid_parameters = []
for k in step_pooling_parameters:
if getattr(self, k, None) is not None:
invalid_parameters.append(k)
if invalid_parameters:
raise ValueError(
f"Task {self.task} only supports {valid_parameters} "
f"parameters, does not support "
f"{invalid_parameters} parameters"
)
else:
for k in step_pooling_parameters:
if getattr(pooler_config, k, None) is None:
continue
if getattr(self, k, None) is None:
setattr(self, k, getattr(pooler_config, k))
def _set_default_parameters(self, model_config: Optional["ModelConfig"]): def _set_default_parameters(self, model_config: Optional["ModelConfig"]):
if self.task == "embed": if self.task in ["embed", "token_embed"]:
if self.normalize is None: if self.normalize is None:
self.normalize = True self.normalize = True
@ -150,13 +162,9 @@ class PoolingParams(
elif self.dimensions < 1: elif self.dimensions < 1:
raise ValueError("Dimensions must be greater than 0") raise ValueError("Dimensions must be greater than 0")
elif self.task in ["classify", "score"]: elif self.task in ["classify", "score", "token_classify"]:
if self.activation is None: if self.activation is None:
self.activation = True self.activation = True
elif self.task == "encode":
if self.softmax is None:
self.softmax = True
else: else:
raise ValueError(f"Unknown pooling task: {self.task}") raise ValueError(f"Unknown pooling task: {self.task}")
@ -185,7 +193,6 @@ class PoolingParams(
f"normalize={self.normalize}, " f"normalize={self.normalize}, "
f"dimensions={self.dimensions}, " f"dimensions={self.dimensions}, "
f"activation={self.activation}, " f"activation={self.activation}, "
f"softmax={self.softmax}, "
f"step_tag_id={self.step_tag_id}, " f"step_tag_id={self.step_tag_id}, "
f"returned_token_ids={self.returned_token_ids}, " f"returned_token_ids={self.returned_token_ids}, "
f"requires_token_ids={self.requires_token_ids}, " f"requires_token_ids={self.requires_token_ids}, "

View File

@ -5,7 +5,7 @@ from typing import Literal, get_args
GenerationTask = Literal["generate", "transcription"] GenerationTask = Literal["generate", "transcription"]
GENERATION_TASKS = get_args(GenerationTask) GENERATION_TASKS = get_args(GenerationTask)
PoolingTask = Literal["encode", "embed", "classify", "score"] PoolingTask = Literal["embed", "classify", "score", "token_embed", "token_classify"]
POOLING_TASKS = get_args(PoolingTask) POOLING_TASKS = get_args(PoolingTask)
SupportedTask = Literal[GenerationTask, PoolingTask] SupportedTask = Literal[GenerationTask, PoolingTask]

View File

@ -1926,15 +1926,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
supported_tasks = list(model.pooler.get_supported_tasks()) supported_tasks = list(model.pooler.get_supported_tasks())
if ( if self.scheduler_config.chunked_prefill_enabled:
self.scheduler_config.chunked_prefill_enabled if "token_embed" in supported_tasks:
and "encode" in supported_tasks supported_tasks.remove("token_embed")
): if "token_classify" in supported_tasks:
supported_tasks.remove("encode") supported_tasks.remove("token_classify")
logger.debug_once( logger.debug_once(
"Chunked prefill is not supported with " "Chunked prefill is not supported with "
"encode task which using ALL pooling. " "token_embed and token_classify tasks "
"which using ALL pooling. "
"Please turn off chunked prefill by " "Please turn off chunked prefill by "
"`--no-enable-chunked-prefill` before using it." "`--no-enable-chunked-prefill` before using it."
) )