From f54f85129e4665c16f39b097463c3c350ef34210 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Wed, 15 Oct 2025 19:14:41 +0800 Subject: [PATCH] [Model][2/N] Improve all pooling task | Support multi-vector retrieval (#25370) Signed-off-by: wang.yuqi --- examples/offline_inference/pooling/README.md | 6 + .../pooling/multi_vector_retrieval.py | 56 +++ .../prithvi_geospatial_mae_io_processor.py | 2 +- examples/online_serving/pooling/README.md | 6 + .../pooling/multi_vector_retrieval_client.py | 54 +++ tests/conftest.py | 8 +- .../entrypoints/pooling/llm/test_classify.py | 2 +- .../entrypoints/pooling/llm/test_embedding.py | 7 + tests/entrypoints/pooling/llm/test_encode.py | 12 +- tests/entrypoints/pooling/llm/test_reward.py | 23 +- .../pooling/openai/test_embedding.py | 18 + .../entrypoints/pooling/openai/test_rerank.py | 19 +- .../pooling/test_multi_vector_retrieval.py | 45 ++ .../test_pooler_config_init_behaviour.py | 58 ++- .../pooling/test_token_classification.py | 4 +- .../multimodal/pooling/test_prithvi_mae.py | 2 +- .../my_gemma_embedding.py | 2 +- .../test_io_processor_plugins.py | 5 +- tests/test_pooling_params.py | 93 +++- vllm/entrypoints/llm.py | 37 +- vllm/entrypoints/openai/api_server.py | 21 +- vllm/entrypoints/openai/protocol.py | 4 +- vllm/entrypoints/openai/serving_pooling.py | 14 +- vllm/model_executor/layers/pooler.py | 422 +++++++++++------- vllm/model_executor/models/adapters.py | 42 +- vllm/model_executor/models/bert.py | 22 +- vllm/model_executor/models/bert_with_rope.py | 14 +- vllm/model_executor/models/clip.py | 2 +- vllm/model_executor/models/gpt2.py | 11 +- vllm/model_executor/models/gritlm.py | 2 +- vllm/model_executor/models/internlm2.py | 2 +- vllm/model_executor/models/jamba.py | 10 +- vllm/model_executor/models/jina_vl.py | 12 +- vllm/model_executor/models/modernbert.py | 20 +- vllm/model_executor/models/qwen2_rm.py | 6 +- vllm/model_executor/models/roberta.py | 26 +- vllm/model_executor/models/terratorch.py | 2 +- .../models/transformers_pooling.py | 18 +- vllm/pooling_params.py | 61 +-- vllm/tasks.py | 2 +- vllm/v1/worker/gpu_model_runner.py | 13 +- 41 files changed, 786 insertions(+), 399 deletions(-) create mode 100644 examples/offline_inference/pooling/multi_vector_retrieval.py create mode 100644 examples/online_serving/pooling/multi_vector_retrieval_client.py create mode 100644 tests/models/language/pooling/test_multi_vector_retrieval.py diff --git a/examples/offline_inference/pooling/README.md b/examples/offline_inference/pooling/README.md index 79afbd9cfac4..7c535e91afac 100644 --- a/examples/offline_inference/pooling/README.md +++ b/examples/offline_inference/pooling/README.md @@ -26,6 +26,12 @@ python examples/offline_inference/pooling/embed_jina_embeddings_v3.py python examples/offline_inference/pooling/embed_matryoshka_fy.py ``` +## Multi vector retrieval usage + +```bash +python examples/offline_inference/pooling/multi_vector_retrieval.py +``` + ## Named Entity Recognition (NER) usage ```bash diff --git a/examples/offline_inference/pooling/multi_vector_retrieval.py b/examples/offline_inference/pooling/multi_vector_retrieval.py new file mode 100644 index 000000000000..8b8892117d37 --- /dev/null +++ b/examples/offline_inference/pooling/multi_vector_retrieval.py @@ -0,0 +1,56 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from argparse import Namespace + +from vllm import LLM, EngineArgs +from vllm.utils import FlexibleArgumentParser + + +def parse_args(): + parser = 
FlexibleArgumentParser() + parser = EngineArgs.add_cli_args(parser) + # Set example specific arguments + parser.set_defaults( + model="BAAI/bge-m3", + runner="pooling", + enforce_eager=True, + ) + return parser.parse_args() + + +def main(args: Namespace): + # Sample prompts. + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + # Create an LLM. + # You should pass runner="pooling" for embedding models + llm = LLM(**vars(args)) + + # Generate embedding. The output is a list of EmbeddingRequestOutputs. + outputs = llm.embed(prompts) + + # Print the outputs. + print("\nGenerated Outputs:\n" + "-" * 60) + for prompt, output in zip(prompts, outputs): + embeds = output.outputs.embedding + print(len(embeds)) + + # Generate embedding for each token. The output is a list of PoolingRequestOutput. + outputs = llm.encode(prompts, pooling_task="token_embed") + + # Print the outputs. + print("\nGenerated Outputs:\n" + "-" * 60) + for prompt, output in zip(prompts, outputs): + multi_vector = output.outputs.data + print(multi_vector.shape) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/examples/offline_inference/prithvi_geospatial_mae_io_processor.py b/examples/offline_inference/prithvi_geospatial_mae_io_processor.py index 418c40645f9f..6c47b5715438 100644 --- a/examples/offline_inference/prithvi_geospatial_mae_io_processor.py +++ b/examples/offline_inference/prithvi_geospatial_mae_io_processor.py @@ -40,7 +40,7 @@ def main(): model_impl="terratorch", ) - pooling_params = PoolingParams(task="encode", softmax=False) + pooling_params = PoolingParams(task="token_classify", activation=False) pooler_output = llm.encode( img_prompt, pooling_params=pooling_params, diff --git a/examples/online_serving/pooling/README.md b/examples/online_serving/pooling/README.md index ac4e40221edf..91345e0ae778 100644 --- a/examples/online_serving/pooling/README.md +++ b/examples/online_serving/pooling/README.md @@ -18,6 +18,12 @@ python examples/online_serving/pooling/embedding_embed_dtype_client.py python examples/online_serving/pooling/jinaai_rerank_client.py ``` +## Multi vector retrieval usage + +```bash +python examples/online_serving/pooling/multi_vector_retrieval_client.py +``` + ## Named Entity Recognition (NER) usage ```bash diff --git a/examples/online_serving/pooling/multi_vector_retrieval_client.py b/examples/online_serving/pooling/multi_vector_retrieval_client.py new file mode 100644 index 000000000000..ef8c4745aa53 --- /dev/null +++ b/examples/online_serving/pooling/multi_vector_retrieval_client.py @@ -0,0 +1,54 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Example online usage of Pooling API for multi vector retrieval. + +Run `vllm serve --runner pooling` +to start up the server in vLLM. e.g. 
+ +vllm serve BAAI/bge-m3 +""" + +import argparse + +import requests +import torch + + +def post_http_request(prompt: dict, api_url: str) -> requests.Response: + headers = {"User-Agent": "Test Client"} + response = requests.post(api_url, headers=headers, json=prompt) + return response + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=8000) + parser.add_argument("--model", type=str, default="BAAI/bge-m3") + + return parser.parse_args() + + +def main(args): + api_url = f"http://{args.host}:{args.port}/pooling" + model_name = args.model + + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + prompt = {"model": model_name, "input": prompts} + + pooling_response = post_http_request(prompt=prompt, api_url=api_url) + for output in pooling_response.json()["data"]: + multi_vector = torch.tensor(output["data"]) + print(multi_vector.shape) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/tests/conftest.py b/tests/conftest.py index 2fde7f97836d..9126b3d668b9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1011,8 +1011,12 @@ class VllmRunner: req_outputs = self.llm.embed(inputs, *args, **kwargs) return [req_output.outputs.embedding for req_output in req_outputs] - def encode(self, prompts: list[str]) -> list[list[float]]: - req_outputs = self.llm.encode(prompts) + def token_embed(self, prompts: list[str]) -> list[list[float]]: + req_outputs = self.llm.encode(prompts, pooling_task="token_embed") + return [req_output.outputs.data for req_output in req_outputs] + + def token_classify(self, prompts: list[str]) -> list[list[float]]: + req_outputs = self.llm.encode(prompts, pooling_task="token_classify") return [req_output.outputs.data for req_output in req_outputs] def reward(self, prompts: list[str]) -> list[list[float]]: diff --git a/tests/entrypoints/pooling/llm/test_classify.py b/tests/entrypoints/pooling/llm/test_classify.py index 488c82c9fe7f..96f634ee0a8c 100644 --- a/tests/entrypoints/pooling/llm/test_classify.py +++ b/tests/entrypoints/pooling/llm/test_classify.py @@ -63,7 +63,7 @@ def test_encode_api(llm: LLM): # chunked prefill does not support all pooling err_msg = "pooling_task must be one of.+" with pytest.raises(ValueError, match=err_msg): - llm.encode(prompts, use_tqdm=False) + llm.encode(prompts, pooling_task="token_classify", use_tqdm=False) def test_score_api(llm: LLM): diff --git a/tests/entrypoints/pooling/llm/test_embedding.py b/tests/entrypoints/pooling/llm/test_embedding.py index c53941390bd1..5455b5f91fc0 100644 --- a/tests/entrypoints/pooling/llm/test_embedding.py +++ b/tests/entrypoints/pooling/llm/test_embedding.py @@ -35,6 +35,13 @@ def llm(): cleanup_dist_env_and_memory() +@pytest.mark.skip_global_cleanup +def test_encode_api(llm: LLM): + outputs = llm.encode(prompts, pooling_task="token_embed", use_tqdm=False) + multi_vector = outputs[0].outputs.data + assert multi_vector.shape == (11, 384) + + def test_pooling_params(llm: LLM): def get_outputs(normalize): outputs = llm.embed( diff --git a/tests/entrypoints/pooling/llm/test_encode.py b/tests/entrypoints/pooling/llm/test_encode.py index 9ba380334e5a..ca85d2758fce 100644 --- a/tests/entrypoints/pooling/llm/test_encode.py +++ b/tests/entrypoints/pooling/llm/test_encode.py @@ -57,20 +57,24 @@ def test_multiple_pooling_params(llm: LLM): ] # Multiple PoolingParams should be matched 
with each prompt - outputs = llm.encode(PROMPTS, pooling_params=pooling_params) + outputs = llm.encode(PROMPTS, pooling_params=pooling_params, pooling_task="embed") assert len(PROMPTS) == len(outputs) # Exception raised, if the size of params does not match the size of prompts with pytest.raises(ValueError): - outputs = llm.encode(PROMPTS, pooling_params=pooling_params[:3]) + outputs = llm.encode( + PROMPTS, pooling_params=pooling_params[:3], pooling_task="embed" + ) # Single PoolingParams should be applied to every prompt single_pooling_params = PoolingParams() - outputs = llm.encode(PROMPTS, pooling_params=single_pooling_params) + outputs = llm.encode( + PROMPTS, pooling_params=single_pooling_params, pooling_task="embed" + ) assert len(PROMPTS) == len(outputs) # pooling_params is None, default params should be applied - outputs = llm.encode(PROMPTS, pooling_params=None) + outputs = llm.encode(PROMPTS, pooling_params=None, pooling_task="embed") assert len(PROMPTS) == len(outputs) diff --git a/tests/entrypoints/pooling/llm/test_reward.py b/tests/entrypoints/pooling/llm/test_reward.py index 8312ff180b36..81058dbad891 100644 --- a/tests/entrypoints/pooling/llm/test_reward.py +++ b/tests/entrypoints/pooling/llm/test_reward.py @@ -36,22 +36,23 @@ def llm(): cleanup_dist_env_and_memory() -@pytest.mark.skip_global_cleanup def test_pooling_params(llm: LLM): - def get_outputs(softmax): + def get_outputs(activation): outputs = llm.reward( - prompts, pooling_params=PoolingParams(softmax=softmax), use_tqdm=False + prompts, pooling_params=PoolingParams(activation=activation), use_tqdm=False ) return torch.cat([x.outputs.data for x in outputs]) - default = get_outputs(softmax=None) - w_softmax = get_outputs(softmax=True) - wo_softmax = get_outputs(softmax=False) + default = get_outputs(activation=None) + w_activation = get_outputs(activation=True) + wo_activation = get_outputs(activation=False) - assert torch.allclose(default, w_softmax, atol=1e-2), "Default should use softmax." - assert not torch.allclose(w_softmax, wo_softmax, atol=1e-2), ( - "wo_softmax should not use softmax." + assert torch.allclose(default, w_activation, atol=1e-2), ( + "Default should use activation." ) - assert torch.allclose(softmax(wo_softmax), w_softmax, atol=1e-2), ( - "w_softmax should be close to softmax(wo_softmax)." + assert not torch.allclose(w_activation, wo_activation, atol=1e-2), ( + "wo_activation should not use activation." + ) + assert torch.allclose(softmax(wo_activation), w_activation, atol=1e-2), ( + "w_activation should be close to activation(wo_activation)." ) diff --git a/tests/entrypoints/pooling/openai/test_embedding.py b/tests/entrypoints/pooling/openai/test_embedding.py index 8a3d298a48e2..ab8ca9d68e0e 100644 --- a/tests/entrypoints/pooling/openai/test_embedding.py +++ b/tests/entrypoints/pooling/openai/test_embedding.py @@ -17,6 +17,7 @@ from tests.utils import RemoteOpenAIServer from vllm.entrypoints.openai.protocol import ( EMBED_DTYPE_TO_TORCH_DTYPE, EmbeddingResponse, + PoolingResponse, ) from vllm.transformers_utils.tokenizer import get_tokenizer @@ -509,3 +510,20 @@ async def test_normalize(server: RemoteOpenAIServer, model_name: str): assert torch.allclose(w_normal, F.normalize(wo_normal, p=2, dim=-1), atol=1e-2), ( "w_normal should be close to normal(wo_normal)." 
) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_pooling(server: RemoteOpenAIServer, model_name: str): + input_text = ["The chef prepared a delicious meal."] + + response = requests.post( + server.url_for("pooling"), + json={"model": model_name, "input": input_text, "encoding_format": "float"}, + ) + + poolings = PoolingResponse.model_validate(response.json()) + + assert len(poolings.data) == 1 + assert len(poolings.data[0].data) == 11 + assert len(poolings.data[0].data[0]) == 384 diff --git a/tests/entrypoints/pooling/openai/test_rerank.py b/tests/entrypoints/pooling/openai/test_rerank.py index 9980fcff16c1..e43148d25fee 100644 --- a/tests/entrypoints/pooling/openai/test_rerank.py +++ b/tests/entrypoints/pooling/openai/test_rerank.py @@ -7,7 +7,7 @@ import torch import torch.nn.functional as F from tests.utils import RemoteOpenAIServer -from vllm.entrypoints.openai.protocol import RerankResponse +from vllm.entrypoints.openai.protocol import PoolingResponse, RerankResponse MODEL_NAME = "BAAI/bge-reranker-base" DTYPE = "bfloat16" @@ -159,3 +159,20 @@ async def test_activation(server: RemoteOpenAIServer, model_name: str): assert torch.allclose(F.sigmoid(wo_activation), w_activation, atol=1e-2), ( "w_activation should be close to activation(wo_activation)." ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_pooling(server: RemoteOpenAIServer, model_name: str): + input_text = ["The chef prepared a delicious meal."] + + response = requests.post( + server.url_for("pooling"), + json={"model": model_name, "input": input_text, "encoding_format": "float"}, + ) + + poolings = PoolingResponse.model_validate(response.json()) + + assert len(poolings.data) == 1 + assert len(poolings.data[0].data) == 11 + assert len(poolings.data[0].data[0]) == 1 diff --git a/tests/models/language/pooling/test_multi_vector_retrieval.py b/tests/models/language/pooling/test_multi_vector_retrieval.py new file mode 100644 index 000000000000..302f2df13557 --- /dev/null +++ b/tests/models/language/pooling/test_multi_vector_retrieval.py @@ -0,0 +1,45 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest +import torch +from transformers import AutoModel + +from tests.models.utils import check_embeddings_close + + +@pytest.mark.parametrize( + "model", + ["BAAI/bge-m3"], +) +@pytest.mark.parametrize("dtype", ["half"]) +@torch.inference_mode +def test_embed_models(hf_runner, vllm_runner, example_prompts, model: str, dtype: str): + with vllm_runner( + model, + runner="pooling", + max_model_len=None, + ) as vllm_model: + vllm_outputs = vllm_model.token_embed(example_prompts) + + with hf_runner( + model, + auto_cls=AutoModel, + ) as hf_model: + tokenizer = hf_model.tokenizer + hf_outputs = [] + for prompt in example_prompts: + inputs = tokenizer([prompt], return_tensors="pt") + inputs = hf_model.wrap_device(inputs) + output = hf_model.model(**inputs) + embedding = output.last_hidden_state[0].float() + # normal + hf_outputs.append(embedding.cpu()) + + for hf_output, vllm_output in zip(hf_outputs, vllm_outputs): + check_embeddings_close( + embeddings_0_lst=hf_output, + embeddings_1_lst=vllm_output, + name_0="hf", + name_1="vllm", + tol=1e-2, + ) diff --git a/tests/models/language/pooling/test_pooler_config_init_behaviour.py b/tests/models/language/pooling/test_pooler_config_init_behaviour.py index 674bf02b7b98..55663ee3f1b4 100644 --- 
a/tests/models/language/pooling/test_pooler_config_init_behaviour.py +++ b/tests/models/language/pooling/test_pooler_config_init_behaviour.py @@ -93,7 +93,7 @@ def test_embed_models_using_normalize( ], ) @pytest.mark.parametrize("dtype", ["half"]) -def test_reward_models_using_softmax( +def test_reward_models_using_activation( hf_runner, vllm_runner, example_prompts, @@ -104,22 +104,64 @@ def test_reward_models_using_softmax( model, max_model_len=1024, dtype=dtype, - pooler_config=PoolerConfig(softmax=False), + pooler_config=PoolerConfig(activation=False), ) as vllm_model: - wo_softmax = vllm_model.encode(example_prompts) + wo_activation = vllm_model.reward(example_prompts) with vllm_runner( - model, max_model_len=1024, dtype=dtype, pooler_config=PoolerConfig(softmax=True) + model, + max_model_len=1024, + dtype=dtype, + pooler_config=PoolerConfig(activation=True), ) as vllm_model: - w_softmax = vllm_model.encode(example_prompts) + w_activation = vllm_model.reward(example_prompts) - for wo, w in zip(wo_softmax, w_softmax): + for wo, w in zip(wo_activation, w_activation): wo = torch.tensor(wo) w = torch.tensor(w) assert not torch.allclose(wo, w, atol=1e-2), ( - "pooler_config softmax is not working" + "pooler_config activation is not working" ) assert torch.allclose(softmax(wo), w, atol=1e-2), ( - "w_softmax should be close to softmax(wo_softmax)." + "w_activation should be close to activation(wo_activation)." + ) + + +@pytest.mark.parametrize( + "model", + [ + "intfloat/multilingual-e5-small", + ], +) +@pytest.mark.parametrize("dtype", ["half"]) +def test_multi_vector_retrieval_models_using_normalize( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, +) -> None: + with vllm_runner( + model, + max_model_len=512, + dtype=dtype, + pooler_config=PoolerConfig(normalize=False), + ) as vllm_model: + wo_normalize = vllm_model.token_embed(example_prompts) + + with vllm_runner( + model, + max_model_len=512, + dtype=dtype, + pooler_config=PoolerConfig(normalize=True), + ) as vllm_model: + w_normalize = vllm_model.token_embed(example_prompts) + + for wo, w in zip(wo_normalize, w_normalize): + assert not torch.allclose(wo, w, atol=1e-2), ( + "pooler_config normalize is not working" + ) + assert torch.allclose(F.normalize(wo, p=2, dim=-1), w, atol=1e-2), ( + "w_normal should be close to normal(wo_normal)." 
) diff --git a/tests/models/language/pooling/test_token_classification.py b/tests/models/language/pooling/test_token_classification.py index 784d9fc31267..2dfc0072126b 100644 --- a/tests/models/language/pooling/test_token_classification.py +++ b/tests/models/language/pooling/test_token_classification.py @@ -19,7 +19,7 @@ def test_bert_models( dtype: str, ) -> None: with vllm_runner(model, max_model_len=None, dtype=dtype) as vllm_model: - vllm_outputs = vllm_model.encode(example_prompts) + vllm_outputs = vllm_model.token_classify(example_prompts) with hf_runner( model, dtype=dtype, auto_cls=AutoModelForTokenClassification @@ -50,7 +50,7 @@ def test_modernbert_models( dtype: str, ) -> None: with vllm_runner(model, max_model_len=None, dtype=dtype) as vllm_model: - vllm_outputs = vllm_model.encode(example_prompts) + vllm_outputs = vllm_model.token_classify(example_prompts) with hf_runner( model, dtype=dtype, auto_cls=AutoModelForTokenClassification diff --git a/tests/models/multimodal/pooling/test_prithvi_mae.py b/tests/models/multimodal/pooling/test_prithvi_mae.py index abf4150a9132..62154b083487 100644 --- a/tests/models/multimodal/pooling/test_prithvi_mae.py +++ b/tests/models/multimodal/pooling/test_prithvi_mae.py @@ -39,7 +39,7 @@ def _run_test( max_num_seqs=32, default_torch_num_threads=1, ) as vllm_model: - vllm_model.encode(prompt) + vllm_model.llm.encode(prompt, pooling_task="token_classify") MODELS = ["mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11"] diff --git a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py index d1dae587d38e..98245cdf0c98 100644 --- a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py +++ b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py @@ -30,7 +30,7 @@ class MyGemma2Embedding(nn.Module): self.pooler = DispatchPooler( { - "encode": Pooler.for_encode(pooler_config), + "token_embed": Pooler.for_token_embed(pooler_config), "embed": Pooler.for_embed(pooler_config), } ) diff --git a/tests/plugins_tests/test_io_processor_plugins.py b/tests/plugins_tests/test_io_processor_plugins.py index 912b32755e80..936f27fb69bc 100644 --- a/tests/plugins_tests/test_io_processor_plugins.py +++ b/tests/plugins_tests/test_io_processor_plugins.py @@ -93,7 +93,7 @@ def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str): out_data_format="b64_json", ) - pooling_params = PoolingParams(task="encode", softmax=False) + pooling_params = PoolingParams(activation=False) with vllm_runner( model_name, @@ -108,8 +108,7 @@ def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str): io_processor_plugin="prithvi_to_tiff", ) as llm_runner: pooler_output = llm_runner.get_llm().encode( - img_prompt, - pooling_params=pooling_params, + img_prompt, pooling_params=pooling_params, pooling_task="token_classify" ) output = pooler_output[0].outputs diff --git a/tests/test_pooling_params.py b/tests/test_pooling_params.py index e3561ac3a577..e73d7efc1483 100644 --- a/tests/test_pooling_params.py +++ b/tests/test_pooling_params.py @@ -1,10 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import dataclass + import pytest from tests.models.utils import EmbedModelInfo from vllm import PoolingParams -from vllm.config import ModelConfig +from vllm.config import ModelConfig, PoolerConfig EMBEDDING_MODELS = [ EmbedModelInfo("intfloat/multilingual-e5-small", 
is_matryoshka=False), @@ -15,6 +17,15 @@ EMBEDDING_MODELS = [ ), ] +classify_parameters = ["activation"] +embed_parameters = ["dimensions", "normalize"] +step_pooling_parameters = ["step_tag_id", "returned_token_ids"] + + +@dataclass() +class MockModelConfig: + pooler_config: PoolerConfig + def test_task(): pooling_params = PoolingParams() @@ -24,25 +35,27 @@ def test_task(): pooling_params.verify(task="score") with pytest.raises(ValueError): - pooling_params.verify(task="encode") + pooling_params.verify(task="classify") def test_embed(): task = "embed" + model_config = MockModelConfig(pooler_config=PoolerConfig(pooling_type="CLS")) + pooling_params = PoolingParams(normalize=None) - pooling_params.verify(task=task) + pooling_params.verify(task=task, model_config=model_config) pooling_params = PoolingParams(normalize=True) - pooling_params.verify(task=task) + pooling_params.verify(task=task, model_config=model_config) pooling_params = PoolingParams(normalize=False) - pooling_params.verify(task=task) + pooling_params.verify(task=task, model_config=model_config) - invalid_parameters = ["activation", "softmax"] + invalid_parameters = classify_parameters + step_pooling_parameters for p in invalid_parameters: with pytest.raises(ValueError): pooling_params = PoolingParams(**{p: True}) - pooling_params.verify(task=task) + pooling_params.verify(task=task, model_config=model_config) @pytest.mark.parametrize("model_info", EMBEDDING_MODELS) @@ -73,35 +86,71 @@ def test_embed_dimensions(model_info: EmbedModelInfo): @pytest.mark.parametrize("task", ["score", "classify"]) def test_classify(task): + model_config = MockModelConfig(pooler_config=PoolerConfig(pooling_type="CLS")) + pooling_params = PoolingParams(activation=None) - pooling_params.verify(task=task) + pooling_params.verify(task=task, model_config=model_config) pooling_params = PoolingParams(activation=True) - pooling_params.verify(task=task) + pooling_params.verify(task=task, model_config=model_config) pooling_params = PoolingParams(activation=False) - pooling_params.verify(task=task) + pooling_params.verify(task=task, model_config=model_config) - invalid_parameters = ["dimensions", "normalize", "softmax"] + invalid_parameters = embed_parameters + step_pooling_parameters for p in invalid_parameters: with pytest.raises(ValueError): pooling_params = PoolingParams(**{p: True}) - pooling_params.verify(task=task) + pooling_params.verify(task=task, model_config=model_config) -def test_encode(): - task = "encode" - pooling_params = PoolingParams(softmax=None) - pooling_params.verify(task=task) +@pytest.mark.parametrize("pooling_type", ["ALL", "STEP"]) +def test_token_embed(pooling_type: str): + task = "token_embed" + model_config = MockModelConfig( + pooler_config=PoolerConfig(pooling_type=pooling_type) + ) - pooling_params = PoolingParams(softmax=True) - pooling_params.verify(task=task) + pooling_params = PoolingParams(normalize=None) + pooling_params.verify(task=task, model_config=model_config) - pooling_params = PoolingParams(softmax=False) - pooling_params.verify(task=task) + pooling_params = PoolingParams(normalize=True) + pooling_params.verify(task=task, model_config=model_config) + + pooling_params = PoolingParams(normalize=False) + pooling_params.verify(task=task, model_config=model_config) + + invalid_parameters = classify_parameters + if pooling_type != "STEP": + invalid_parameters = classify_parameters + step_pooling_parameters - invalid_parameters = ["dimensions", "normalize", "activation"] for p in invalid_parameters: with 
pytest.raises(ValueError): pooling_params = PoolingParams(**{p: True}) - pooling_params.verify(task=task) + pooling_params.verify(task=task, model_config=model_config) + + +@pytest.mark.parametrize("pooling_type", ["ALL", "STEP"]) +def test_token_classify(pooling_type: str): + task = "token_classify" + model_config = MockModelConfig( + pooler_config=PoolerConfig(pooling_type=pooling_type) + ) + + pooling_params = PoolingParams(activation=None) + pooling_params.verify(task=task, model_config=model_config) + + pooling_params = PoolingParams(activation=True) + pooling_params.verify(task=task, model_config=model_config) + + pooling_params = PoolingParams(activation=False) + pooling_params.verify(task=task, model_config=model_config) + + invalid_parameters = embed_parameters + if pooling_type != "STEP": + invalid_parameters = embed_parameters + step_pooling_parameters + + for p in invalid_parameters: + with pytest.raises(ValueError): + pooling_params = PoolingParams(**{p: True}) + pooling_params.verify(task=task, model_config=model_config) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 61376736d0f7..e2db9d049a75 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -951,7 +951,7 @@ class LLM: truncate_prompt_tokens: int | None = None, use_tqdm: bool | Callable[..., tqdm] = True, lora_request: list[LoRARequest] | LoRARequest | None = None, - pooling_task: PoolingTask = "encode", + pooling_task: PoolingTask | None = None, tokenization_kwargs: dict[str, Any] | None = None, ) -> list[PoolingRequestOutput]: """Apply pooling to the hidden states corresponding to the input @@ -986,25 +986,24 @@ class LLM: instead pass them via the `inputs` parameter. """ - if self.supported_tasks == ["encode"] and pooling_task is None: - pooling_task = "encode" + error_str = ( + "pooling_task required for `LLM.encode`\n" + "Please use one of the more specific methods or set the " + "pooling_task when using `LLM.encode`:\n" + " - For embeddings, use `LLM.embed(...)` " + 'or `pooling_task="embed"`.\n' + " - For classification logits, use `LLM.classify(...)` " + 'or `pooling_task="classify"`.\n' + " - For similarity scores, use `LLM.score(...)`.\n" + " - For rewards, use `LLM.reward(...)` " + 'or `pooling_task="token_classify"`\n' + " - For token classification, " + 'use `pooling_task="token_classify"`\n' + ' - For multi-vector retrieval, use `pooling_task="token_embed"`' + ) if pooling_task is None: - pooling_task = "embed" if "embed" in self.supported_tasks else "encode" - - logger.warning_once( - "`LLM.encode` is currently using `pooling_task = %s`.\n" - "Please use one of the more specific methods or set the " - "task directly when using `LLM.encode`:\n" - " - For embeddings, use `LLM.embed(...)` " - 'or `pooling_task="embed"`.\n' - " - For classification logits, use `LLM.classify(...)` " - 'or `pooling_task="classify"`.\n' - " - For rewards, use `LLM.reward(...)` " - 'or `pooling_task="reward"`\n' - " - For similarity scores, use `LLM.score(...)`.", - pooling_task, - ) + raise ValueError(error_str) model_config = self.model_config runner_type = model_config.runner_type @@ -1206,7 +1205,7 @@ class LLM: lora_request=lora_request, pooling_params=pooling_params, truncate_prompt_tokens=truncate_prompt_tokens, - pooling_task="encode", + pooling_task="token_classify", ) def _embedding_score( diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index fd80ba7a9afc..0ac035595690 100644 --- a/vllm/entrypoints/openai/api_server.py +++ 
b/vllm/entrypoints/openai/api_server.py @@ -1748,16 +1748,19 @@ async def init_app_state( else None ) state.openai_serving_pooling = ( - OpenAIServingPooling( - engine_client, - state.openai_serving_models, - request_logger=request_logger, - chat_template=resolved_chat_template, - chat_template_content_format=args.chat_template_content_format, - trust_request_chat_template=args.trust_request_chat_template, - log_error_stack=args.log_error_stack, + ( + OpenAIServingPooling( + engine_client, + state.openai_serving_models, + supported_tasks=supported_tasks, + request_logger=request_logger, + chat_template=resolved_chat_template, + chat_template_content_format=args.chat_template_content_format, + trust_request_chat_template=args.trust_request_chat_template, + log_error_stack=args.log_error_stack, + ) ) - if "encode" in supported_tasks + if ("token_embed" in supported_tasks or "token_classify" in supported_tasks) else None ) state.openai_serving_embedding = ( diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 86e1e62ff437..5b8a118280da 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -1682,7 +1682,7 @@ class IOProcessorRequest(OpenAIBaseModel, Generic[T]): When using plugins IOProcessor plugins, the actual input is processed by the plugin itself. Hence, we use a generic type for the request data """ - softmax: bool = True + activation: bool = False embed_dtype: str = Field( default="float32", @@ -1693,7 +1693,7 @@ class IOProcessorRequest(OpenAIBaseModel, Generic[T]): ) def to_pooling_params(self): - return PoolingParams(task="encode", softmax=self.softmax) + return PoolingParams(task="token_classify", activation=self.activation) class IOProcessorResponse(OpenAIBaseModel, Generic[T]): diff --git a/vllm/entrypoints/openai/serving_pooling.py b/vllm/entrypoints/openai/serving_pooling.py index 3ed17abe0946..aa81a233b297 100644 --- a/vllm/entrypoints/openai/serving_pooling.py +++ b/vllm/entrypoints/openai/serving_pooling.py @@ -35,6 +35,7 @@ from vllm.entrypoints.renderer import RenderConfig from vllm.entrypoints.utils import _validate_truncation_size from vllm.logger import init_logger from vllm.outputs import PoolingOutput, PoolingRequestOutput +from vllm.tasks import SupportedTask from vllm.utils import merge_async_iterators logger = init_logger(__name__) @@ -62,6 +63,7 @@ class OpenAIServingPooling(OpenAIServing): engine_client: EngineClient, models: OpenAIServingModels, *, + supported_tasks: tuple[SupportedTask, ...], request_logger: RequestLogger | None, chat_template: str | None, chat_template_content_format: ChatTemplateContentFormatOption, @@ -75,6 +77,7 @@ class OpenAIServingPooling(OpenAIServing): log_error_stack=log_error_stack, ) + self.supported_tasks = supported_tasks self.chat_template = chat_template self.chat_template_content_format: Final = chat_template_content_format self.trust_request_chat_template = trust_request_chat_template @@ -178,8 +181,17 @@ class OpenAIServingPooling(OpenAIServing): try: pooling_params = request.to_pooling_params() + if "token_embed" in self.supported_tasks: + pooling_task = "token_embed" + elif "token_classify" in self.supported_tasks: + pooling_task = "token_classify" + else: + return self.create_error_response( + f"pooling_task must be one of {self.supported_tasks}." 
+ ) + try: - pooling_params.verify("encode", self.model_config) + pooling_params.verify(pooling_task, self.model_config) except ValueError as e: return self.create_error_response(str(e)) diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index 010c607bcabf..84e176f0ea89 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -64,66 +64,6 @@ class PoolingParamsUpdate: params.requires_token_ids = self.requires_token_ids -class Pooler(nn.Module, ABC): - """The interface required for all poolers used in pooling models in vLLM.""" - - @staticmethod - def for_encode(pooler_config: PoolerConfig): - if pooler_config.pooling_type == "STEP": - return StepPooler() - - resolved_config = ResolvedPoolingConfig( - task="encode", pooling_type=PoolingType.ALL - ) - - return SimplePooler.from_config(resolved_config) - - @staticmethod - def for_embed(pooler_config: PoolerConfig): - resolved_config = ResolvedPoolingConfig.from_config( - task="embed", - pooler_config=pooler_config, - ) - - return SimplePooler.from_config(resolved_config) - - @staticmethod - def for_classify( - pooler_config: PoolerConfig, - classifier: ClassifierFn | None, - ): - resolved_config = ResolvedPoolingConfig.from_config( - task="classify", - pooler_config=pooler_config, - ) - - pooling = PoolingMethod.from_pooling_type(resolved_config.pooling_type) - - return ClassifierPooler( - pooling=pooling, - classifier=classifier, - ) - - @abstractmethod - def get_supported_tasks(self) -> Set[PoolingTask]: - """Determine which pooling tasks are supported.""" - raise NotImplementedError - - def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate: - """ - Construct the updated pooling parameters to use for a supported task. 
- """ - return PoolingParamsUpdate() - - @abstractmethod - def forward( - self, - hidden_states: list[torch.Tensor] | torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> PoolerOutput: - raise NotImplementedError - - def get_prompt_lens( hidden_states: torch.Tensor | list[torch.Tensor], pooling_metadata: PoolingMetadata, @@ -237,7 +177,7 @@ class PoolingMethod(nn.Module, ABC): class CLSPool(PoolingMethod): def get_supported_tasks(self) -> Set[PoolingTask]: - return {"encode", "embed", "classify", "score"} + return {"token_embed", "token_classify", "embed", "classify", "score"} def forward_all( self, @@ -253,7 +193,7 @@ class CLSPool(PoolingMethod): class LastPool(PoolingMethod): def get_supported_tasks(self) -> Set[PoolingTask]: - return {"encode", "embed", "classify", "score"} + return {"token_embed", "token_classify", "embed", "classify", "score"} def forward_all( self, @@ -265,7 +205,7 @@ class LastPool(PoolingMethod): class AllPool(PoolingMethod): def get_supported_tasks(self) -> Set[PoolingTask]: - return {"encode"} + return {"token_embed", "token_classify"} def forward_all( self, @@ -284,7 +224,7 @@ class AllPool(PoolingMethod): class MeanPool(PoolingMethod): def get_supported_tasks(self) -> Set[PoolingTask]: - return {"encode", "embed", "classify", "score"} + return {"token_embed", "token_classify", "embed", "classify", "score"} def forward_all( self, @@ -398,6 +338,82 @@ class LambdaPoolerActivation(PoolerActivation): return self.fn(pooled_data) +class Pooler(nn.Module, ABC): + """The interface required for all poolers used in pooling models in vLLM.""" + + @staticmethod + def for_token_embed(pooler_config: PoolerConfig): + head = TokenEmbeddingPoolerHead() + + if pooler_config.pooling_type == "STEP": + return StepPooler(head=head) + + return AllPooler(head=head) + + @staticmethod + def for_token_classify( + pooler_config: PoolerConfig, + classifier: ClassifierFn | None = None, + act_fn: PoolerActivation | str | None = None, + ): + head = TokenClassifierPoolerHead(classifier=classifier, act_fn=act_fn) + + if pooler_config.pooling_type == "STEP": + return StepPooler(head=head) + + return AllPooler(head=head) + + @staticmethod + def for_embed(pooler_config: PoolerConfig): + resolved_config = ResolvedPoolingConfig.from_config( + task="embed", + pooler_config=pooler_config, + ) + + pooling = PoolingMethod.from_pooling_type(resolved_config.pooling_type) + head = EmbeddingPoolerHead() + + return SimplePooler(pooling=pooling, head=head) + + @staticmethod + def for_classify( + pooler_config: PoolerConfig, + classifier: ClassifierFn | None, + act_fn: PoolerActivation | str | None = None, + ): + resolved_config = ResolvedPoolingConfig.from_config( + task="classify", + pooler_config=pooler_config, + ) + + pooling = PoolingMethod.from_pooling_type(resolved_config.pooling_type) + + return ClassifierPooler( + pooling=pooling, + classifier=classifier, + act_fn=act_fn, + ) + + @abstractmethod + def get_supported_tasks(self) -> Set[PoolingTask]: + """Determine which pooling tasks are supported.""" + raise NotImplementedError + + def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate: + """ + Construct the updated pooling parameters to use for a supported task. 
+ """ + return PoolingParamsUpdate() + + @abstractmethod + def forward( + self, + hidden_states: list[torch.Tensor] | torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> PoolerOutput: + raise NotImplementedError + + class PoolerHead(nn.Module): def __init__(self, activation: PoolerActivation) -> None: super().__init__() @@ -416,7 +432,6 @@ class EmbeddingPoolerHead(PoolerHead): super().__init__(activation=PoolerNormalize()) # Load ST projector if available - vllm_config = get_current_vllm_config() self.projector: nn.Module | None = ( _load_st_projector(vllm_config.model_config) if vllm_config else None @@ -471,39 +486,6 @@ class EmbeddingPoolerHead(PoolerHead): return pooled_data -class RewardPoolerHead(PoolerHead): - def __init__(self) -> None: - super().__init__(activation=PoolerClassify(static_num_labels=False)) - - vllm_config = get_current_vllm_config() - self.head_dtype = vllm_config.model_config.head_dtype - - def forward( - self, - pooled_data: list[torch.Tensor] | torch.Tensor, - pooling_metadata: PoolingMetadata, - ): - if isinstance(pooled_data, list): - pooled_data = [p.to(self.head_dtype) for p in pooled_data] - else: - pooled_data = pooled_data.to(self.head_dtype) - - pooling_params = get_pooling_params(pooling_metadata) - - # for softmax - flags = [p.softmax for p in pooling_params] - if len(set(flags)) == 1: - if flags[0]: - pooled_data = self.activation(pooled_data) - else: - pooled_data = [ - self.activation(vecs) if f else vecs - for vecs, f in zip(pooled_data, flags) - ] - - return pooled_data - - class SimplePooler(Pooler): """A layer that pools specific information from hidden states. @@ -513,20 +495,6 @@ class SimplePooler(Pooler): 3. Returns structured results as `PoolerOutput`. """ - @classmethod - def from_config( - cls, - pooler_config: ResolvedPoolingConfig, - ) -> "SimplePooler": - pooling = PoolingMethod.from_pooling_type(pooler_config.pooling_type) - if pooler_config.task == "embed": - head = EmbeddingPoolerHead() - elif pooler_config.task == "encode": - head = RewardPoolerHead() - else: - raise NotImplementedError(f"Unknown task: {pooler_config.task}") - return cls(pooling, head) - def __init__(self, pooling: PoolingMethod, head: PoolerHead) -> None: super().__init__() @@ -549,58 +517,6 @@ class SimplePooler(Pooler): return pooled_data -class StepPooler(Pooler): - def __init__( - self, - ) -> None: - super().__init__() - - self.pooling = AllPool() - self.head = RewardPoolerHead() - - def extract_states( - self, - hidden_states: torch.Tensor | list[torch.Tensor], - pooling_metadata: PoolingMetadata, - ) -> list[torch.Tensor] | torch.Tensor: - pooled_data_lst = self.pooling(hidden_states, pooling_metadata) - prompt_token_ids = get_prompt_token_ids(pooling_metadata) - - pooled_data = list[torch.Tensor]() - - pooling_params = get_pooling_params(pooling_metadata) - - for data, token_id, pooling_param in zip( - pooled_data_lst, prompt_token_ids, pooling_params - ): - step_tag_id = pooling_param.step_tag_id - returned_token_ids = pooling_param.returned_token_ids - - if returned_token_ids is not None and len(returned_token_ids) > 0: - data = data[:, returned_token_ids] - - if step_tag_id is not None: - data = data[token_id == step_tag_id] - pooled_data.append(data) - - return pooled_data - - def get_supported_tasks(self) -> Set[PoolingTask]: - return {"encode"} - - def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate: - return PoolingParamsUpdate(requires_token_ids=True) - - def forward( - self, - hidden_states: torch.Tensor | 
list[torch.Tensor], - pooling_metadata: PoolingMetadata, - ) -> PoolerOutput: - pooled_data = self.extract_states(hidden_states, pooling_metadata) - pooled_data = self.head(pooled_data, pooling_metadata) - return pooled_data - - class ClassifierPooler(Pooler): """A pooling layer for classification tasks. @@ -611,26 +527,46 @@ class ClassifierPooler(Pooler): """ @staticmethod - def act_fn_for_seq_cls(config: ModelConfig): - return get_classification_activation_function(config.hf_config) + def act_fn_for_seq_cls(model_config: ModelConfig): + return get_classification_activation_function(model_config.hf_config) @staticmethod - def act_fn_for_cross_encoder(config: ModelConfig): - return get_cross_encoder_activation_function(config.hf_config) + def act_fn_for_cross_encoder(model_config: ModelConfig): + return get_cross_encoder_activation_function(model_config.hf_config) + + @staticmethod + def resolve_act_fn( + model_config: ModelConfig, + static_num_labels: bool = True, + act_fn: PoolerActivation | str | None = None, + ): + if isinstance(act_fn, str): + if act_fn == "classify": + return ClassifierPooler.act_fn_for_seq_cls(model_config) + elif act_fn == "score": + return ClassifierPooler.act_fn_for_cross_encoder(model_config) + else: + raise ValueError(f"act_fn [{act_fn=}] not supported.") + elif act_fn is None: + return PoolerClassify(static_num_labels=static_num_labels) + else: + assert callable(act_fn) + return act_fn def __init__( self, pooling: PoolingFn, classifier: ClassifierFn | None, - act_fn: PoolerActivation | None = None, + act_fn: PoolerActivation | str | None = None, ) -> None: super().__init__() vllm_config = get_current_vllm_config() - self.pooling = pooling self.classifier = classifier - self.act_fn = act_fn or PoolerClassify() + self.act_fn = self.resolve_act_fn( + vllm_config.model_config, static_num_labels=True, act_fn=act_fn + ) self.logit_bias: float | None = ( vllm_config.model_config.pooler_config.logit_bias ) @@ -672,6 +608,150 @@ class ClassifierPooler(Pooler): return scores +class TokenEmbeddingPoolerHead(EmbeddingPoolerHead): + def forward( + self, pooled_data: torch.Tensor, pooling_param: PoolingParams + ) -> torch.Tensor: + pooled_data = pooled_data.to(self.head_dtype) + # pooled_data shape: [n_tokens, hidden_dimension] + + # Apply ST projector + if self.projector is not None: + pooled_data = self.projector(pooled_data) + # pooled_data shape: [n_tokens, embedding_dimension] + + # for matryoshka representation + pooled_data = pooled_data[..., : pooling_param.dimensions] + + # for normalize + if pooling_param.normalize: + pooled_data = self.activation(pooled_data) + + # pooled_data shape: [n_tokens, embedding_dimension] + return pooled_data + + +class TokenClassifierPoolerHead(nn.Module): + def __init__( + self, + classifier: ClassifierFn | None, + act_fn: PoolerActivation | str | None = None, + ) -> None: + super().__init__() + vllm_config = get_current_vllm_config() + + self.classifier = classifier + self.act_fn = ClassifierPooler.resolve_act_fn( + vllm_config.model_config, static_num_labels=False, act_fn=act_fn + ) + self.logit_bias: float | None = ( + vllm_config.model_config.pooler_config.logit_bias + ) + self.head_dtype = vllm_config.model_config.head_dtype + + def get_supported_tasks(self) -> Set[PoolingTask]: + return {"token_classify"} + + def forward( + self, + hidden_states: torch.Tensor, + pooling_param: PoolingParams, + ) -> torch.Tensor: + hidden_states = hidden_states.to(self.head_dtype) + # hidden_states shape: [n_token, hidden_size] + + if 
self.classifier is not None: + scores = self.classifier(hidden_states) + else: + scores = hidden_states + # scores shape: [n_token, num_labels] + + if self.logit_bias is not None: + scores -= self.logit_bias + + if pooling_param.activation: + scores = self.act_fn(scores) + + # scores shape: [n_token, num_labels] + return scores + + +class AllPooler(Pooler): + def __init__(self, head: nn.Module | PoolerHead) -> None: + super().__init__() + + self.pooling = AllPool() + self.head = head + + def get_supported_tasks(self) -> Set[PoolingTask]: + return {"token_embed", "token_classify"} + + def forward( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> PoolerOutput: + pooled_data = self.pooling(hidden_states, pooling_metadata) + pooling_params = get_pooling_params(pooling_metadata) + assert len(pooled_data) == len(pooling_params) + + pooled_data = [self.head(d, p) for d, p in zip(pooled_data, pooling_params)] + return pooled_data + + +class StepPooler(Pooler): + def __init__(self, head: nn.Module | PoolerHead) -> None: + super().__init__() + + self.pooling = AllPool() + self.head = head + + def extract_states( + self, + hidden_states: torch.Tensor | list[torch.Tensor], + pooling_metadata: PoolingMetadata, + ) -> torch.Tensor | list[torch.Tensor]: + pooled_data_lst = self.pooling(hidden_states, pooling_metadata) + prompt_token_ids = get_prompt_token_ids(pooling_metadata) + + pooled_data = list[torch.Tensor]() + + pooling_params = get_pooling_params(pooling_metadata) + + for data, token_id, pooling_param in zip( + pooled_data_lst, prompt_token_ids, pooling_params + ): + step_tag_id = pooling_param.step_tag_id + returned_token_ids = pooling_param.returned_token_ids + + if returned_token_ids is not None and len(returned_token_ids) > 0: + data = data[:, returned_token_ids] + + if step_tag_id is not None: + data = data[token_id == step_tag_id] + pooled_data.append(data) + + return pooled_data + + def get_supported_tasks(self) -> Set[PoolingTask]: + return {"token_embed", "token_classify"} + + def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate: + return PoolingParamsUpdate(requires_token_ids=True) + + def forward( + self, + hidden_states: torch.Tensor | list[torch.Tensor], + pooling_metadata: PoolingMetadata, + ) -> PoolerOutput: + pooled_data = self.extract_states(hidden_states, pooling_metadata) + pooling_params = get_pooling_params(pooling_metadata) + assert len(pooled_data) == len(pooling_params) + + pooled_data = [self.head(d, p) for d, p in zip(pooled_data, pooling_params)] + return pooled_data + + class DispatchPooler(Pooler): """Dispatches calls to a sub-pooler based on the pooling task.""" diff --git a/vllm/model_executor/models/adapters.py b/vllm/model_executor/models/adapters.py index 1d3874b16484..5d51cd375741 100644 --- a/vllm/model_executor/models/adapters.py +++ b/vllm/model_executor/models/adapters.py @@ -250,7 +250,7 @@ def as_embedding_model(cls: _T) -> _T: self.pooler = DispatchPooler( { - "encode": Pooler.for_encode(pooler_config), + "token_embed": Pooler.for_token_embed(pooler_config), "embed": Pooler.for_embed(pooler_config), }, ) @@ -279,11 +279,8 @@ def as_seq_cls_model(cls: _T) -> _T: # Lazy import from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.pooler import ( - ClassifierPooler, DispatchPooler, Pooler, - PoolingMethod, - PoolingType, ) from vllm.model_executor.models.interfaces import SupportsCrossEncoding from vllm.sequence import IntermediateTensors @@ -302,42 +299,29 @@ def 
as_seq_cls_model(cls: _T) -> _T: model_config.hidden_size, config.num_labels, bias=False, - params_dtype=torch.float32, + params_dtype=vllm_config.model_config.head_dtype, quant_config=quant_config, + return_bias=False, prefix=maybe_prefix(prefix, "score"), ) pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None - pooling_type_str = pooler_config.pooling_type - assert pooling_type_str is not None - pooling_type = PoolingType[pooling_type_str] - self.pooler = DispatchPooler( { - "encode": Pooler.for_encode(pooler_config), - "classify": ClassifierPooler( - pooling=PoolingMethod.from_pooling_type(pooling_type), - classifier=self._classifier, - act_fn=ClassifierPooler.act_fn_for_seq_cls( - vllm_config.model_config - ), + "token_classify": Pooler.for_token_classify( + pooler_config, classifier=self.score ), - "score": ClassifierPooler( - pooling=PoolingMethod.from_pooling_type(pooling_type), - classifier=self._classifier, - act_fn=ClassifierPooler.act_fn_for_cross_encoder( - vllm_config.model_config - ), + "classify": Pooler.for_classify( + pooler_config, classifier=self.score, act_fn="classify" + ), + "score": Pooler.for_classify( + pooler_config, classifier=self.score, act_fn="score" ), } ) - def _classifier(self, x: torch.Tensor): - x, _ = self.score(x.float()) - return x - def forward( self, input_ids: torch.Tensor, @@ -393,7 +377,11 @@ def as_reward_model(cls: _T) -> _T: assert pooler_config is not None self.pooler = DispatchPooler( - {"encode": Pooler.for_encode(pooler_config)}, + { + "token_classify": Pooler.for_token_classify( + pooler_config=pooler_config + ) + } ) ModelForReward.__name__ = _get_pooling_model_name(cls.__name__, "ForReward") diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index 6e81eb8dc91b..1c2334a78543 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -521,7 +521,7 @@ class BertEmbeddingModel(nn.Module, SupportsQuant): def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler: return DispatchPooler( { - "encode": Pooler.for_encode(pooler_config), + "token_embed": Pooler.for_token_embed(pooler_config), "embed": Pooler.for_embed(pooler_config), } ) @@ -724,7 +724,7 @@ class BertSpladeSparseEmbeddingModel(BertEmbeddingModel): return DispatchPooler( { - "encode": Pooler.for_encode(pooler_config), + "token_embed": Pooler.for_token_embed(pooler_config), "embed": SPLADESparsePooler( mlm_head=self.mlm_head, cls_token_id=cls_id, @@ -821,20 +821,16 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, SupportsQu self.pooler = DispatchPooler( { - "encode": Pooler.for_encode(pooler_config), + "token_classify": Pooler.for_token_classify( + pooler_config, classifier=self.classifier + ), "classify": ClassifierPooler( pooling=self.bert.pooler, classifier=self.classifier, - act_fn=ClassifierPooler.act_fn_for_seq_cls( - vllm_config.model_config - ), + act_fn="classify", ), "score": ClassifierPooler( - pooling=self.bert.pooler, - classifier=self.classifier, - act_fn=ClassifierPooler.act_fn_for_cross_encoder( - vllm_config.model_config - ), + pooling=self.bert.pooler, classifier=self.classifier, act_fn="score" ), } ) @@ -891,7 +887,9 @@ class BertForTokenClassification(nn.Module): self.pooler = DispatchPooler( { - "encode": Pooler.for_encode(pooler_config), + "token_classify": Pooler.for_token_classify( + pooler_config=pooler_config + ), } ) diff --git a/vllm/model_executor/models/bert_with_rope.py b/vllm/model_executor/models/bert_with_rope.py index 
49111dd9ffab..31fdc4d21245 100644 --- a/vllm/model_executor/models/bert_with_rope.py +++ b/vllm/model_executor/models/bert_with_rope.py @@ -695,20 +695,16 @@ class GteNewForSequenceClassification(nn.Module, SupportsCrossEncoding): self.pooler = DispatchPooler( { - "encode": Pooler.for_encode(pooler_config), + "token_classify": Pooler.for_token_classify( + pooler_config, classifier=self.classifier + ), "classify": ClassifierPooler( pooling=self.new.pooler, classifier=self.classifier, - act_fn=ClassifierPooler.act_fn_for_seq_cls( - vllm_config.model_config - ), + act_fn="classify", ), "score": ClassifierPooler( - pooling=self.new.pooler, - classifier=self.classifier, - act_fn=ClassifierPooler.act_fn_for_cross_encoder( - vllm_config.model_config - ), + pooling=self.new.pooler, classifier=self.classifier, act_fn="score" ), } ) diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py index 3d7b28af8bdb..27953c27188d 100644 --- a/vllm/model_executor/models/clip.py +++ b/vllm/model_executor/models/clip.py @@ -837,7 +837,7 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant): self.pooler = DispatchPooler( { - "encode": Pooler.for_encode(pooler_config), + "token_embed": Pooler.for_token_embed(pooler_config), "embed": Pooler.for_embed(pooler_config), } ) diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index ddd6e53b4a43..6d99d02a32be 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -353,8 +353,15 @@ class GPT2ForSequenceClassification(nn.Module, SupportsCrossEncoding): self.pooler = DispatchPooler( { - "encode": Pooler.for_encode(pooler_config), - "classify": Pooler.for_classify(pooler_config, classifier=self.score), + "token_classify": Pooler.for_token_classify( + pooler_config, classifier=self.score + ), + "classify": Pooler.for_classify( + pooler_config, classifier=self.score, act_fn="classify" + ), + "score": Pooler.for_classify( + pooler_config, classifier=self.score, act_fn="score" + ), } ) diff --git a/vllm/model_executor/models/gritlm.py b/vllm/model_executor/models/gritlm.py index ede3e34881b1..181c4ed2dca5 100644 --- a/vllm/model_executor/models/gritlm.py +++ b/vllm/model_executor/models/gritlm.py @@ -239,7 +239,7 @@ class GritLM(LlamaForCausalLM): if pooler_config is not None: self.pooler = DispatchPooler( { - "encode": Pooler.for_encode(pooler_config), + "token_embed": Pooler.for_token_embed(pooler_config), "embed": GritLMPooler(vllm_config.model_config), } ) diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index 8d83a1478dff..c5bbd5497a14 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -444,7 +444,7 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM): assert pooler_config is not None self.pooler = DispatchPooler( - {"encode": Pooler.for_encode(pooler_config)}, + {"token_classify": Pooler.for_token_classify(pooler_config)} ) def forward( diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 49cb9311a786..f8a87cf6965f 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -604,10 +604,14 @@ class JambaForSequenceClassification(JambaForCausalLM): self.pooler = DispatchPooler( { - "encode": Pooler.for_encode(pooler_config), + "token_classify": Pooler.for_token_classify( + pooler_config, classifier=self.score + ), "classify": Pooler.for_classify( - pooler_config, - classifier=self.score, 
+ pooler_config, classifier=self.score, act_fn="classify" + ), + "score": Pooler.for_classify( + pooler_config, classifier=self.score, act_fn="score" ), } ) diff --git a/vllm/model_executor/models/jina_vl.py b/vllm/model_executor/models/jina_vl.py index a9333155243d..05a40837954d 100644 --- a/vllm/model_executor/models/jina_vl.py +++ b/vllm/model_executor/models/jina_vl.py @@ -97,9 +97,15 @@ class JinaVLForSequenceClassification( self.score = JinaVLScorer(vllm_config.model_config) self.pooler = DispatchPooler( { - "encode": Pooler.for_encode(pooler_config), - "classify": Pooler.for_classify(pooler_config, classifier=self.score), - "score": Pooler.for_classify(pooler_config, classifier=self.score), + "token_classify": Pooler.for_token_classify( + pooler_config, classifier=self.score + ), + "classify": Pooler.for_classify( + pooler_config, classifier=self.score, act_fn="classify" + ), + "score": Pooler.for_classify( + pooler_config, classifier=self.score, act_fn="score" + ), } ) diff --git a/vllm/model_executor/models/modernbert.py b/vllm/model_executor/models/modernbert.py index ff9f6a41ab99..5dbf38c69086 100644 --- a/vllm/model_executor/models/modernbert.py +++ b/vllm/model_executor/models/modernbert.py @@ -322,20 +322,14 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding): self.pooler = DispatchPooler( { - "encode": Pooler.for_encode(pooler_config), + "token_classify": Pooler.for_token_classify( + pooler_config, classifier=self.classifier + ), "classify": ClassifierPooler( - pooling=self.pooling, - classifier=self.classifier, - act_fn=ClassifierPooler.act_fn_for_seq_cls( - vllm_config.model_config - ), + pooling=self.pooling, classifier=self.classifier, act_fn="classify" ), "score": ClassifierPooler( - pooling=self.pooling, - classifier=self.classifier, - act_fn=ClassifierPooler.act_fn_for_cross_encoder( - vllm_config.model_config - ), + pooling=self.pooling, classifier=self.classifier, act_fn="score" ), } ) @@ -421,7 +415,9 @@ class ModernBertForTokenClassification(nn.Module): self.pooler = DispatchPooler( { - "encode": Pooler.for_encode(pooler_config), + "token_classify": Pooler.for_token_classify( + pooler_config=pooler_config + ), } ) diff --git a/vllm/model_executor/models/qwen2_rm.py b/vllm/model_executor/models/qwen2_rm.py index c2f2ba637f09..e2ba0e262cf7 100644 --- a/vllm/model_executor/models/qwen2_rm.py +++ b/vllm/model_executor/models/qwen2_rm.py @@ -107,7 +107,7 @@ class Qwen2ForRewardModel(Qwen2RewardBaseModel): assert pooler_config is not None self.pooler = DispatchPooler( - {"encode": Pooler.for_encode(pooler_config)}, + {"token_classify": Pooler.for_token_classify(pooler_config)} ) @@ -120,4 +120,6 @@ class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel): pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None - self.pooler = DispatchPooler({"encode": Pooler.for_encode(pooler_config)}) + self.pooler = DispatchPooler( + {"token_classify": Pooler.for_token_classify(pooler_config)} + ) diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py index 456226360b91..cfccb904f46c 100644 --- a/vllm/model_executor/models/roberta.py +++ b/vllm/model_executor/models/roberta.py @@ -105,15 +105,7 @@ class RobertaClassificationHead(nn.Module): @default_pooling_type("CLS") class RobertaEmbeddingModel(BertEmbeddingModel): - """A model that uses Roberta to provide embedding functionalities. 
- - This class encapsulates the BertModel and provides an interface for - embedding operations and customized pooling functions. - - Attributes: - model: An instance of BertModel used for forward operations. - _pooler: An instance of Pooler used for pooling operations. - """ + """A model that uses Roberta to provide embedding functionalities.""" def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__(vllm_config=vllm_config, prefix=prefix) @@ -212,20 +204,14 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding): self.pooler = DispatchPooler( { - "encode": Pooler.for_encode(pooler_config), + "token_classify": Pooler.for_token_classify( + pooler_config=pooler_config, classifier=self.classifier + ), "classify": ClassifierPooler( - pooling=CLSPool(), - classifier=self.classifier, - act_fn=ClassifierPooler.act_fn_for_seq_cls( - vllm_config.model_config - ), + pooling=CLSPool(), classifier=self.classifier, act_fn="classify" ), "score": ClassifierPooler( - pooling=CLSPool(), - classifier=self.classifier, - act_fn=ClassifierPooler.act_fn_for_cross_encoder( - vllm_config.model_config - ), + pooling=CLSPool(), classifier=self.classifier, act_fn="score" ), } ) diff --git a/vllm/model_executor/models/terratorch.py b/vllm/model_executor/models/terratorch.py index e8506666db5b..0252705c62b1 100644 --- a/vllm/model_executor/models/terratorch.py +++ b/vllm/model_executor/models/terratorch.py @@ -250,7 +250,7 @@ class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal): assert pooler_config is not None self.pooler = DispatchPooler( - {"encode": Pooler.for_encode(pooler_config)}, + {"token_classify": Pooler.for_token_classify(pooler_config)} ) def get_input_embeddings( diff --git a/vllm/model_executor/models/transformers_pooling.py b/vllm/model_executor/models/transformers_pooling.py index 411fb92e9460..7ddeb403da44 100644 --- a/vllm/model_executor/models/transformers_pooling.py +++ b/vllm/model_executor/models/transformers_pooling.py @@ -135,7 +135,7 @@ class TransformersEmbeddingModel(TransformersPoolingBase): self.pooler = DispatchPooler( { - "encode": Pooler.for_encode(pooler_config), + "token_embed": Pooler.for_token_embed(pooler_config), "embed": Pooler.for_embed(pooler_config), } ) @@ -190,20 +190,14 @@ class TransformersForSequenceClassification(TransformersPoolingBase): self.pooler = DispatchPooler( { - "encode": Pooler.for_encode(pooler_config), + "token_classify": Pooler.for_token_classify( + pooler_config, classifier=self.classifier + ), "classify": ClassifierPooler( - pooling=CLSPool(), - classifier=self.classifier, - act_fn=ClassifierPooler.act_fn_for_seq_cls( - vllm_config.model_config - ), + pooling=CLSPool(), classifier=self.classifier, act_fn="classify" ), "score": ClassifierPooler( - pooling=CLSPool(), - classifier=self.classifier, - act_fn=ClassifierPooler.act_fn_for_cross_encoder( - vllm_config.model_config - ), + pooling=CLSPool(), classifier=self.classifier, act_fn="score" ), } ) diff --git a/vllm/pooling_params.py b/vllm/pooling_params.py index 175a4ac01b83..c6dff6e01c1d 100644 --- a/vllm/pooling_params.py +++ b/vllm/pooling_params.py @@ -10,7 +10,7 @@ from vllm.sampling_params import RequestOutputKind from vllm.tasks import PoolingTask if TYPE_CHECKING: - from vllm.config import ModelConfig + from vllm.config import ModelConfig, PoolerConfig class PoolingParams( @@ -30,7 +30,6 @@ class PoolingParams( if model support matryoshka representation. activation: Whether to apply activation function to the classification outputs. 
-        softmax: Whether to apply softmax to the reward outputs.
     """
 
     # --8<-- [start:common-pooling-params]
@@ -48,32 +47,19 @@ class PoolingParams(
     activation: bool | None = None
     # --8<-- [end:classification-pooling-params]
 
-    ## for reward models
-    softmax: bool | None = None
+    ## for step pooling models
     step_tag_id: int | None = None
     returned_token_ids: list[int] | None = None
 
+    ## Internal use only
     task: PoolingTask | None = None
-    """Internal use only."""
-
     requires_token_ids: bool = False
-    """Internal use only."""
-
     extra_kwargs: dict[str, Any] | None = None
-    """Internal use only."""
-
     output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY
 
     @property
     def all_parameters(self) -> list[str]:
-        return [
-            "dimensions",
-            "normalize",
-            "activation",
-            "softmax",
-            "step_tag_id",
-            "returned_token_ids",
-        ]
+        return ["dimensions", "normalize", "activation"]
 
     @property
     def valid_parameters(self):
@@ -81,7 +67,8 @@ class PoolingParams(
             "embed": ["dimensions", "normalize"],
             "classify": ["activation"],
             "score": ["activation"],
-            "encode": ["softmax", "step_tag_id", "returned_token_ids"],
+            "token_embed": ["dimensions", "normalize"],
+            "token_classify": ["activation"],
         }
 
     def clone(self) -> "PoolingParams":
@@ -100,7 +87,6 @@ class PoolingParams(
         # NOTE: Task validation needs to done against the model instance,
         # which is not available in model config. So, it's not included
         # in this method
-
         self._merge_default_parameters(model_config)
         self._set_default_parameters(model_config)
         self._verify_valid_parameters()
@@ -125,8 +111,34 @@ class PoolingParams(
             if getattr(self, k, None) is None:
                 setattr(self, k, getattr(pooler_config, k))
 
+        self._verify_step_pooling(pooler_config, valid_parameters)
+
+    def _verify_step_pooling(
+        self, pooler_config: "PoolerConfig", valid_parameters: list[str]
+    ):
+        step_pooling_parameters = ["step_tag_id", "returned_token_ids"]
+        if pooler_config.pooling_type != "STEP":
+            invalid_parameters = []
+            for k in step_pooling_parameters:
+                if getattr(self, k, None) is not None:
+                    invalid_parameters.append(k)
+
+            if invalid_parameters:
+                raise ValueError(
+                    f"Task {self.task} only supports {valid_parameters} "
+                    f"parameters; it does not support "
+                    f"{invalid_parameters}."
+                )
+        else:
+            for k in step_pooling_parameters:
+                if getattr(pooler_config, k, None) is None:
+                    continue
+
+                if getattr(self, k, None) is None:
+                    setattr(self, k, getattr(pooler_config, k))
+
     def _set_default_parameters(self, model_config: Optional["ModelConfig"]):
-        if self.task == "embed":
+        if self.task in ["embed", "token_embed"]:
             if self.normalize is None:
                 self.normalize = True
 
@@ -150,13 +162,9 @@ class PoolingParams(
             elif self.dimensions < 1:
                 raise ValueError("Dimensions must be greater than 0")
 
-        elif self.task in ["classify", "score"]:
+        elif self.task in ["classify", "score", "token_classify"]:
             if self.activation is None:
                 self.activation = True
-
-        elif self.task == "encode":
-            if self.softmax is None:
-                self.softmax = True
         else:
             raise ValueError(f"Unknown pooling task: {self.task}")
 
@@ -185,7 +193,6 @@ class PoolingParams(
             f"normalize={self.normalize}, "
             f"dimensions={self.dimensions}, "
             f"activation={self.activation}, "
-            f"softmax={self.softmax}, "
             f"step_tag_id={self.step_tag_id}, "
             f"returned_token_ids={self.returned_token_ids}, "
             f"requires_token_ids={self.requires_token_ids}, "
diff --git a/vllm/tasks.py b/vllm/tasks.py
index 85c5c6e43620..6551444d1710 100644
--- a/vllm/tasks.py
+++ b/vllm/tasks.py
@@ -5,7 +5,7 @@ from typing import Literal, get_args
 
 GenerationTask = Literal["generate", "transcription"]
Literal["generate", "transcription"] GENERATION_TASKS = get_args(GenerationTask) -PoolingTask = Literal["encode", "embed", "classify", "score"] +PoolingTask = Literal["embed", "classify", "score", "token_embed", "token_classify"] POOLING_TASKS = get_args(PoolingTask) SupportedTask = Literal[GenerationTask, PoolingTask] diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index d995a609318c..9e394dbb592e 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1926,15 +1926,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): supported_tasks = list(model.pooler.get_supported_tasks()) - if ( - self.scheduler_config.chunked_prefill_enabled - and "encode" in supported_tasks - ): - supported_tasks.remove("encode") + if self.scheduler_config.chunked_prefill_enabled: + if "token_embed" in supported_tasks: + supported_tasks.remove("token_embed") + if "token_classify" in supported_tasks: + supported_tasks.remove("token_classify") logger.debug_once( "Chunked prefill is not supported with " - "encode task which using ALL pooling. " + "token_embed and token_classify tasks " + "which using ALL pooling. " "Please turn off chunked prefill by " "`--no-enable-chunked-prefill` before using it." )