[Frontend][Doc][5/N] Improve all pooling tasks | Polish encode (pooling) API & documentation (#25524)

Signed-off-by: wang.yuqi <noooop@126.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Committed by wang.yuqi on 2025-10-30 20:13:05 +08:00 via GitHub
parent 74374386e2
commit 4464723f22
27 changed files with 499 additions and 131 deletions

View File

@ -79,7 +79,7 @@ The `post_process*` methods take `PoolingRequestOutput` objects as input and gen
The `validate_or_generate_params` method validates, with the plugin, any `SamplingParameters`/`PoolingParameters` received with the user request, or generates new ones if none are specified. The function always returns the validated/generated parameters.

The `output_to_response` method is used only for online serving and converts the plugin output to the `IOProcessorResponse` type that is then returned by the API Server. The implementation of the `/pooling` serving endpoint is available at [vllm/entrypoints/openai/serving_pooling.py](../../vllm/entrypoints/openai/serving_pooling.py).

An example implementation of a plugin that enables generating geotiff images with the PrithviGeospatialMAE model is available [here](https://github.com/IBM/terratorch/tree/main/terratorch/vllm/plugins/segmentation). Please also refer to our online ([examples/online_serving/pooling/prithvi_geospatial_mae.py](../../examples/online_serving/pooling/prithvi_geospatial_mae.py)) and offline ([examples/offline_inference/pooling/prithvi_geospatial_mae_io_processor.py](../../examples/offline_inference/pooling/prithvi_geospatial_mae_io_processor.py)) inference examples.
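To make the interface concrete, here is a schematic sketch of a plugin subclass. Only the methods described above are shown; the import path, class name, and signatures are illustrative assumptions and may differ between versions:

```python
from vllm.plugins.io_processors import IOProcessor  # illustrative import path


class MyGeotiffPlugin(IOProcessor):  # hypothetical plugin class
    def post_process(self, model_output, request_id=None):
        # Takes PoolingRequestOutput objects and builds the plugin output.
        ...

    def validate_or_generate_params(self, params=None):
        # Validates user-supplied SamplingParameters/PoolingParameters with
        # the plugin, or generates defaults; always returns the result.
        ...

    def output_to_response(self, plugin_output):
        # Online serving only: converts the plugin output into the
        # IOProcessorResponse returned by the API Server.
        ...
```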
## Using an IO Processor plugin

View File

@ -30,11 +30,11 @@ If `--runner pooling` has been set (manually or automatically) but the model doe
vLLM will attempt to automatically convert the model according to the architecture names
shown in the table below.

| Architecture                                     | `--convert` | Supported pooling tasks                |
|--------------------------------------------------|-------------|----------------------------------------|
| `*ForTextEncoding`, `*EmbeddingModel`, `*Model`  | `embed`     | `token_embed`, `embed`                 |
| `*For*Classification`, `*ClassificationModel`    | `classify`  | `token_classify`, `classify`, `score`  |
| `*ForRewardModeling`, `*RewardModel`             | `reward`    | `token_classify`                       |

!!! tip
    You can explicitly set `--convert <type>` to specify how to convert the model.
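    For example, to force a classification checkpoint to be converted (the model name is taken from the serving examples elsewhere in this page):

    ```bash
    # Enables the token_classify, classify, and score pooling tasks
    vllm serve jason9693/Qwen2.5-1.5B-apeach --convert classify
    ```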
@ -45,12 +45,14 @@ Each pooling model in vLLM supports one or more of these tasks according to
[Pooler.get_supported_tasks][vllm.model_executor.layers.pooler.Pooler.get_supported_tasks],
enabling the corresponding APIs:

| Task             | APIs                                                                          |
|------------------|-------------------------------------------------------------------------------|
| `embed`          | `LLM.embed(...)`, `LLM.score(...)`\*, `LLM.encode(..., pooling_task="embed")` |
| `classify`       | `LLM.classify(...)`, `LLM.encode(..., pooling_task="classify")`               |
| `score`          | `LLM.score(...)`                                                              |
| `token_classify` | `LLM.reward(...)`, `LLM.encode(..., pooling_task="token_classify")`           |
| `token_embed`    | `LLM.encode(..., pooling_task="token_embed")`                                 |
| `plugin`         | `LLM.encode(..., pooling_task="plugin")`                                      |

\* The `LLM.score(...)` API falls back to the `embed` task if the model does not support the `score` task.
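To make the task selection concrete, here is a minimal offline sketch using `LLM.encode`; the model name is illustrative and the printed shapes assume `intfloat/e5-small`, whose embedding size is 384:

```python
from vllm import LLM

llm = LLM(model="intfloat/e5-small")

# Sentence-level embedding: one vector per prompt.
(output,) = llm.encode(["Hello, my name is"], pooling_task="embed")
print(output.outputs.data.shape)  # torch.Size([384])

# Token-level embeddings (multi-vector retrieval): one vector per token.
(output,) = llm.encode(["Hello, my name is"], pooling_task="token_embed")
print(output.outputs.data.shape)  # torch.Size([num_tokens, 384])
```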
@ -144,7 +146,6 @@ A code example can be found here: [examples/offline_inference/basic/score.py](..
### `LLM.reward`

The [reward][vllm.LLM.reward] method is available to all reward models in vLLM.

```python
from vllm import LLM
@ -161,15 +162,17 @@ A code example can be found here: [examples/offline_inference/basic/reward.py](.
### `LLM.encode`

The [encode][vllm.LLM.encode] method is available to all pooling models in vLLM.

!!! note
    Please use one of the more specific methods or set the task directly when using `LLM.encode`:

    - For embeddings, use `LLM.embed(...)` or `pooling_task="embed"`.
    - For classification logits, use `LLM.classify(...)` or `pooling_task="classify"`.
    - For similarity scores, use `LLM.score(...)`.
    - For rewards, use `LLM.reward(...)` or `pooling_task="token_classify"`.
    - For token classification, use `pooling_task="token_classify"`.
    - For multi-vector retrieval, use `pooling_task="token_embed"`.
    - For IO Processor plugins, use `pooling_task="plugin"`.

```python
from vllm import LLM
@ -185,10 +188,47 @@ print(f"Data: {data!r}")
Our [OpenAI-Compatible Server](../serving/openai_compatible_server.md) provides endpoints that correspond to the offline APIs:

- [Embeddings API](../serving/openai_compatible_server.md#embeddings-api) is similar to `LLM.embed`, accepting both text and [multi-modal inputs](../features/multimodal_inputs.md) for embedding models.
- [Classification API](../serving/openai_compatible_server.md#classification-api) is similar to `LLM.classify` and is applicable to sequence classification models.
- [Score API](../serving/openai_compatible_server.md#score-api) is similar to `LLM.score` for cross-encoder models.
- [Pooling API](../serving/openai_compatible_server.md#pooling-api) is similar to `LLM.encode`, being applicable to all types of pooling models.

!!! note
    Please use one of the more specific endpoints or set the task directly when using the [Pooling API](../serving/openai_compatible_server.md#pooling-api):

    - For embeddings, use the [Embeddings API](../serving/openai_compatible_server.md#embeddings-api) or `"task":"embed"`.
    - For classification logits, use the [Classification API](../serving/openai_compatible_server.md#classification-api) or `"task":"classify"`.
    - For similarity scores, use the [Score API](../serving/openai_compatible_server.md#score-api).
    - For rewards, use `"task":"token_classify"`.
    - For token classification, use `"task":"token_classify"`.
    - For multi-vector retrieval, use `"task":"token_embed"`.
    - For IO Processor plugins, use `"task":"plugin"`.
```python
# start a supported embeddings model server with `vllm serve`, e.g.
# vllm serve intfloat/e5-small
import requests

host = "localhost"
port = "8000"
model_name = "intfloat/e5-small"
api_url = f"http://{host}:{port}/pooling"

prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

prompt = {"model": model_name, "input": prompts, "task": "embed"}
response = requests.post(api_url, json=prompt)

for output in response.json()["data"]:
    data = output["data"]
    print(f"Data: {data!r} (size={len(data)})")
```
## Matryoshka Embeddings
@ -265,3 +305,16 @@ Expected output:
```

An OpenAI client example can be found here: [examples/online_serving/pooling/openai_embedding_matryoshka_fy.py](../../examples/online_serving/pooling/openai_embedding_matryoshka_fy.py)
## Deprecated Features

### Encode task

We have split the `encode` task into two more specific token-wise tasks: `token_embed` and `token_classify` (see the sketch below):

- `token_embed` is the same as `embed` but token-wise, using normalization as the activation.
- `token_classify` is the same as `classify` but token-wise, using softmax as the default activation.
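For example, where a reward model or a multi-vector retriever previously used the `encode` task, a minimal sketch of the replacements (model names are taken from the serving examples elsewhere in this page):

```python
from vllm import LLM

# Reward / token-level models: token_classify replaces the old encode task.
llm = LLM(model="internlm/internlm2-1_8b-reward", trust_remote_code=True)
(output,) = llm.encode(["The future of AI is"], pooling_task="token_classify")
print(output.outputs.data.shape)  # one row of scores per prompt token

# Multi-vector retrieval: token_embed yields one normalized vector per token.
llm = LLM(model="BAAI/bge-m3")
(output,) = llm.encode(["The future of AI is"], pooling_task="token_embed")
print(output.outputs.data.shape)
```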
### Remove softmax from PoolingParams

We are going to remove `softmax` and `activation` from `PoolingParams`. Set `use_activation` instead, since we actually allow `classify` and `token_classify` to use any activation function.
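A minimal migration sketch, using the classification model from the tests in this commit:

```python
from vllm import LLM, PoolingParams

llm = LLM(model="jason9693/Qwen2.5-1.5B-apeach")

# Before (deprecated): PoolingParams(activation=False) or PoolingParams(softmax=False)
# After: disable the activation function to get raw logits.
outputs = llm.classify(
    ["This product was excellent and exceeded my expectations"],
    pooling_params=PoolingParams(use_activation=False),
)
print(outputs[0].outputs.probs)
```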

View File

@ -638,7 +638,7 @@ Usually, the score for a sentence pair refers to the similarity between two sent
You can find the documentation for cross encoder models at [sbert.net](https://www.sbert.net/docs/package_reference/cross_encoder/cross_encoder.html).

Code example: [examples/online_serving/pooling/openai_cross_encoder_score.py](../../examples/online_serving/pooling/openai_cross_encoder_score.py)
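As a minimal sketch of a single score request (the model name follows the serve command used in the examples README of this commit):

```python
# vllm serve BAAI/bge-reranker-v2-m3
import requests

response = requests.post(
    "http://localhost:8000/score",
    json={
        "model": "BAAI/bge-reranker-v2-m3",
        "text_1": "What is the capital of France?",
        "text_2": "The capital of France is Paris.",
    },
)
print("Score:", response.json()["data"][0]["score"])
```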
#### Single inference
@ -819,7 +819,7 @@ You can pass multi-modal inputs to scoring models by passing `content` including
print("Scoring output:", response_json["data"][0]["score"]) print("Scoring output:", response_json["data"][0]["score"])
print("Scoring output:", response_json["data"][1]["score"]) print("Scoring output:", response_json["data"][1]["score"])
``` ```
Full example: [examples/online_serving/openai_cross_encoder_score_for_multimodal.py](../../examples/online_serving/openai_cross_encoder_score_for_multimodal.py) Full example: [examples/online_serving/pooling/openai_cross_encoder_score_for_multimodal.py](../../examples/online_serving/pooling/openai_cross_encoder_score_for_multimodal.py)
#### Extra parameters #### Extra parameters

View File

@ -38,6 +38,18 @@ python examples/offline_inference/pooling/multi_vector_retrieval.py
python examples/offline_inference/pooling/ner.py
```
## Prithvi Geospatial MAE usage
```bash
python examples/offline_inference/pooling/prithvi_geospatial_mae.py
```
## IO Processor Plugins for Prithvi Geospatial MAE
```bash
python examples/offline_inference/pooling/prithvi_geospatial_mae_io_processor.py
```
## Qwen3 reranker usage

```bash

View File

@ -33,7 +33,7 @@ def main(args: Namespace):
label_map = llm.llm_engine.vllm_config.model_config.hf_config.id2label

# Run inference
outputs = llm.encode(prompts, pooling_task="token_classify")

for prompt, output in zip(prompts, outputs):
    logits = output.outputs.data

View File

@ -3,65 +3,95 @@
## Cohere rerank usage

```bash
# vllm serve BAAI/bge-reranker-base
python examples/online_serving/pooling/cohere_rerank_client.py
```

## Embedding requests base64 encoding_format usage

```bash
# vllm serve intfloat/e5-small
python examples/online_serving/pooling/embedding_requests_base64_client.py
```

## Embedding requests bytes encoding_format usage

```bash
# vllm serve intfloat/e5-small
python examples/online_serving/pooling/embedding_requests_bytes_client.py
```

## Jinaai rerank usage

```bash
# vllm serve BAAI/bge-reranker-base
python examples/online_serving/pooling/jinaai_rerank_client.py
```

## Multi vector retrieval usage

```bash
# vllm serve BAAI/bge-m3
python examples/online_serving/pooling/multi_vector_retrieval_client.py
```

## Named Entity Recognition (NER) usage

```bash
# vllm serve boltuix/NeuroBERT-NER
python examples/online_serving/pooling/ner_client.py
```

## OpenAI chat embedding for multimodal usage

```bash
python examples/online_serving/pooling/openai_chat_embedding_client_for_multimodal.py
```

## OpenAI classification usage

```bash
# vllm serve jason9693/Qwen2.5-1.5B-apeach
python examples/online_serving/pooling/openai_classification_client.py
```

## OpenAI cross_encoder score usage

```bash
# vllm serve BAAI/bge-reranker-v2-m3
python examples/online_serving/pooling/openai_cross_encoder_score.py
```

## OpenAI cross_encoder score for multimodal usage

```bash
# vllm serve jinaai/jina-reranker-m0
python examples/online_serving/pooling/openai_cross_encoder_score_for_multimodal.py
```

## OpenAI embedding usage

```bash
# vllm serve intfloat/e5-small
python examples/online_serving/pooling/openai_embedding_client.py
```

## OpenAI embedding matryoshka dimensions usage

```bash
# vllm serve jinaai/jina-embeddings-v3 --trust-remote-code
python examples/online_serving/pooling/openai_embedding_matryoshka_fy.py
```

## OpenAI pooling usage

```bash
# vllm serve internlm/internlm2-1_8b-reward --trust-remote-code
python examples/online_serving/pooling/openai_pooling_client.py
```

## Online Prithvi Geospatial MAE usage

```bash
python examples/online_serving/pooling/prithvi_geospatial_mae.py
```

View File

@ -37,15 +37,17 @@ def llm():
@pytest.mark.skip_global_cleanup
def test_pooling_params(llm: LLM):
    def get_outputs(use_activation):
        outputs = llm.classify(
            prompts,
            pooling_params=PoolingParams(use_activation=use_activation),
            use_tqdm=False,
        )
        return torch.tensor([x.outputs.probs for x in outputs])

    default = get_outputs(use_activation=None)
    w_activation = get_outputs(use_activation=True)
    wo_activation = get_outputs(use_activation=False)

    assert torch.allclose(default, w_activation, atol=1e-2), (
        "Default should use activation."

View File

@ -37,15 +37,17 @@ def llm():
def test_pooling_params(llm: LLM):
    def get_outputs(use_activation):
        outputs = llm.reward(
            prompts,
            pooling_params=PoolingParams(use_activation=use_activation),
            use_tqdm=False,
        )
        return torch.cat([x.outputs.data for x in outputs])

    default = get_outputs(use_activation=None)
    w_activation = get_outputs(use_activation=True)
    wo_activation = get_outputs(use_activation=False)

    assert torch.allclose(default, w_activation, atol=1e-2), (
        "Default should use activation."

View File

@ -34,21 +34,21 @@ def llm():
def test_pooling_params(llm: LLM):
    def get_outputs(use_activation):
        text_1 = "What is the capital of France?"
        text_2 = "The capital of France is Paris."
        outputs = llm.score(
            text_1,
            text_2,
            pooling_params=PoolingParams(use_activation=use_activation),
            use_tqdm=False,
        )
        return torch.tensor([x.outputs.score for x in outputs])

    default = get_outputs(use_activation=None)
    w_activation = get_outputs(use_activation=True)
    wo_activation = get_outputs(use_activation=False)

    assert torch.allclose(default, w_activation, atol=1e-2), (
        "Default should use activation."

View File

@ -7,7 +7,7 @@ import torch
import torch.nn.functional as F

from tests.utils import RemoteOpenAIServer
from vllm.entrypoints.openai.protocol import ClassificationResponse, PoolingResponse

MODEL_NAME = "jason9693/Qwen2.5-1.5B-apeach"
DTYPE = "float32"  # Use float32 to avoid NaN issue
@ -163,20 +163,24 @@ async def test_invocations(server: RemoteOpenAIServer):
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_use_activation(server: RemoteOpenAIServer, model_name: str):
    input_text = ["This product was excellent and exceeded my expectations"]

    async def get_outputs(use_activation):
        response = requests.post(
            server.url_for("classify"),
            json={
                "model": model_name,
                "input": input_text,
                "use_activation": use_activation,
            },
        )
        outputs = response.json()
        return torch.tensor([x["probs"] for x in outputs["data"]])

    default = await get_outputs(use_activation=None)
    w_activation = await get_outputs(use_activation=True)
    wo_activation = await get_outputs(use_activation=False)

    assert torch.allclose(default, w_activation, atol=1e-2), (
        "Default should use activation."
@ -191,18 +195,7 @@ async def test_activation(server: RemoteOpenAIServer, model_name: str):
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_score(server: RemoteOpenAIServer, model_name: str):
    # score api is only enabled for num_labels == 1.
    response = requests.post(
        server.url_for("score"),
@ -217,7 +210,7 @@ def test_score(server: RemoteOpenAIServer, model_name: str):
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_rerank(server: RemoteOpenAIServer, model_name: str):
    # rerank api is only enabled for num_labels == 1.
    response = requests.post(
        server.url_for("rerank"),
@ -228,3 +221,62 @@ def test_rerank(server: RemoteOpenAIServer, model_name: str):
        },
    )
    assert response.json()["error"]["type"] == "BadRequestError"


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_pooling_classify(server: RemoteOpenAIServer, model_name: str):
    input_text = "This product was excellent and exceeded my expectations"
    response = requests.post(
        server.url_for("pooling"),
        json={
            "model": model_name,
            "input": input_text,
            "encoding_format": "float",
            "task": "classify",
        },
    )

    poolings = PoolingResponse.model_validate(response.json())
    assert len(poolings.data) == 1
    assert len(poolings.data[0].data) == 2


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_pooling_token_classify(server: RemoteOpenAIServer, model_name: str):
    # token_classify uses ALL pooling, which does not support chunked prefill.
    task = "token_classify"
    response = requests.post(
        server.url_for("pooling"),
        json={
            "model": model_name,
            "input": "test",
            "encoding_format": "float",
            "task": task,
        },
    )
    assert response.json()["error"]["type"] == "BadRequestError"
    assert response.json()["error"]["message"].startswith(
        f"Task {task} is not supported"
    )


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("task", ["embed", "token_embed", "plugin"])
async def test_pooling_not_supported(
    server: RemoteOpenAIServer, model_name: str, task: str
):
    response = requests.post(
        server.url_for("pooling"),
        json={
            "model": model_name,
            "input": "test",
            "encoding_format": "float",
            "task": task,
        },
    )
    assert response.json()["error"]["type"] == "BadRequestError"
    assert response.json()["error"]["message"].startswith(
        f"Task {task} is not supported"
    )

View File

@ -562,12 +562,40 @@ async def test_normalize(server: RemoteOpenAIServer, model_name: str):
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_pooling_embed(server: RemoteOpenAIServer, model_name: str):
    task = "embed"
    input_text = ["The chef prepared a delicious meal."]
    response = requests.post(
        server.url_for("pooling"),
        json={
            "model": model_name,
            "input": input_text,
            "encoding_format": "float",
            "task": task,
        },
    )

    poolings = PoolingResponse.model_validate(response.json())
    assert len(poolings.data) == 1
    assert len(poolings.data[0].data) == 384


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_pooling_token_embed(server: RemoteOpenAIServer, model_name: str):
    task = "token_embed"
    input_text = ["The chef prepared a delicious meal."]
    response = requests.post(
        server.url_for("pooling"),
        json={
            "model": model_name,
            "input": input_text,
            "encoding_format": "float",
            "task": task,
        },
    )

    poolings = PoolingResponse.model_validate(response.json())
@ -575,3 +603,24 @@ async def test_pooling(server: RemoteOpenAIServer, model_name: str):
    assert len(poolings.data) == 1
    assert len(poolings.data[0].data) == 11
    assert len(poolings.data[0].data[0]) == 384


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("task", ["classify", "token_classify", "plugin"])
async def test_pooling_not_supported(
    server: RemoteOpenAIServer, model_name: str, task: str
):
    response = requests.post(
        server.url_for("pooling"),
        json={
            "model": model_name,
            "input": "test",
            "encoding_format": "float",
            "task": task,
        },
    )
    assert response.json()["error"]["type"] == "BadRequestError"
    assert response.json()["error"]["message"].startswith(
        f"Task {task} is not supported"
    )

View File

@ -125,8 +125,8 @@ def test_invocations(server: RemoteOpenAIServer):
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_use_activation(server: RemoteOpenAIServer, model_name: str):
    async def get_outputs(use_activation):
        query = "What is the capital of France?"
        documents = [
            "The capital of Brazil is Brasilia.",
@ -139,16 +139,16 @@ async def test_activation(server: RemoteOpenAIServer, model_name: str):
"model": model_name, "model": model_name,
"query": query, "query": query,
"documents": documents, "documents": documents,
"activation": activation, "use_activation": use_activation,
}, },
) )
outputs = response.json() outputs = response.json()
return torch.tensor([x["relevance_score"] for x in outputs["results"]]) return torch.tensor([x["relevance_score"] for x in outputs["results"]])
default = await get_outputs(activation=None) default = await get_outputs(use_activation=None)
w_activation = await get_outputs(activation=True) w_activation = await get_outputs(use_activation=True)
wo_activation = await get_outputs(activation=False) wo_activation = await get_outputs(use_activation=False)
assert torch.allclose(default, w_activation, atol=1e-2), ( assert torch.allclose(default, w_activation, atol=1e-2), (
"Default should use activation." "Default should use activation."
@ -163,7 +163,25 @@ async def test_activation(server: RemoteOpenAIServer, model_name: str):
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_pooling_classify(server: RemoteOpenAIServer, model_name: str):
    input_text = "This product was excellent and exceeded my expectations"
    response = requests.post(
        server.url_for("pooling"),
        json={
            "model": model_name,
            "input": input_text,
            "encoding_format": "float",
            "task": "classify",
        },
    )

    poolings = PoolingResponse.model_validate(response.json())
    assert len(poolings.data) == 1
    assert len(poolings.data[0].data) == 1


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_pooling_token_classify(server: RemoteOpenAIServer, model_name: str):
    input_text = ["The chef prepared a delicious meal."]

    response = requests.post(
@ -176,3 +194,24 @@ async def test_pooling(server: RemoteOpenAIServer, model_name: str):
    assert len(poolings.data) == 1
    assert len(poolings.data[0].data) == 11
    assert len(poolings.data[0].data[0]) == 1


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("task", ["embed", "token_embed", "plugin"])
async def test_pooling_not_supported(
    server: RemoteOpenAIServer, model_name: str, task: str
):
    response = requests.post(
        server.url_for("pooling"),
        json={
            "model": model_name,
            "input": "test",
            "encoding_format": "float",
            "task": task,
        },
    )
    assert response.json()["error"]["type"] == "BadRequestError"
    assert response.json()["error"]["message"].startswith(
        f"Task {task} is not supported"
    )

View File

@ -218,8 +218,8 @@ class TestModel:
    # TODO: reset this tolerance to 0.01 once we find
    # an alternative to flash_attn with bfloat16
    def test_use_activation(self, server: RemoteOpenAIServer, model: dict[str, Any]):
        def get_outputs(use_activation):
            text_1 = "What is the capital of France?"
            text_2 = "The capital of France is Paris."
            response = requests.post(
@ -228,7 +228,7 @@ class TestModel:
"model": model["name"], "model": model["name"],
"text_1": text_1, "text_1": text_1,
"text_2": text_2, "text_2": text_2,
"activation": activation, "use_activation": use_activation,
}, },
) )
if response.status_code != 200: if response.status_code != 200:
@ -238,9 +238,9 @@ class TestModel:
return torch.tensor([x["score"] for x in outputs["data"]]) return torch.tensor([x["score"] for x in outputs["data"]])
if model["is_cross_encoder"]: if model["is_cross_encoder"]:
default = get_outputs(activation=None) default = get_outputs(use_activation=None)
w_activation = get_outputs(activation=True) w_activation = get_outputs(use_activation=True)
wo_activation = get_outputs(activation=False) wo_activation = get_outputs(use_activation=False)
assert torch.allclose(default, w_activation, atol=1e-2), ( assert torch.allclose(default, w_activation, atol=1e-2), (
"Default should use activation." "Default should use activation."
@ -252,8 +252,8 @@ class TestModel:
"w_activation should be close to activation(wo_activation)." "w_activation should be close to activation(wo_activation)."
) )
else: else:
get_outputs(activation=None) get_outputs(use_activation=None)
# The activation parameter only works for the is_cross_encoder model # The activation parameter only works for the is_cross_encoder model
response = get_outputs(activation=True) response = get_outputs(use_activation=True)
assert response.status_code == 400 assert response.status_code == 400

View File

@ -24,7 +24,7 @@ def test_classify_models_using_activation(
        model,
        max_model_len=512,
        dtype=dtype,
        pooler_config=PoolerConfig(use_activation=False),
    ) as vllm_model:
        wo_activation_out = vllm_model.classify(example_prompts)
@ -32,7 +32,7 @@ def test_classify_models_using_activation(
        model,
        max_model_len=512,
        dtype=dtype,
        pooler_config=PoolerConfig(use_activation=True),
    ) as vllm_model:
        w_activation_out = vllm_model.classify(example_prompts)
@ -104,7 +104,7 @@ def test_reward_models_using_activation(
        model,
        max_model_len=1024,
        dtype=dtype,
        pooler_config=PoolerConfig(use_activation=False),
    ) as vllm_model:
        wo_activation = vllm_model.reward(example_prompts)
@ -112,7 +112,7 @@ def test_reward_models_using_activation(
        model,
        max_model_len=1024,
        dtype=dtype,
        pooler_config=PoolerConfig(use_activation=True),
    ) as vllm_model:
        w_activation = vllm_model.reward(example_prompts)

View File

@ -17,7 +17,7 @@ EMBEDDING_MODELS = [
    ),
]

classify_parameters = ["use_activation"]
embed_parameters = ["dimensions", "normalize"]
step_pooling_parameters = ["step_tag_id", "returned_token_ids"]
@ -88,13 +88,13 @@ def test_embed_dimensions(model_info: EmbedModelInfo):
def test_classify(task):
    model_config = MockModelConfig(pooler_config=PoolerConfig(pooling_type="CLS"))

    pooling_params = PoolingParams(use_activation=None)
    pooling_params.verify(task=task, model_config=model_config)

    pooling_params = PoolingParams(use_activation=True)
    pooling_params.verify(task=task, model_config=model_config)

    pooling_params = PoolingParams(use_activation=False)
    pooling_params.verify(task=task, model_config=model_config)

    invalid_parameters = embed_parameters + step_pooling_parameters
@ -137,13 +137,13 @@ def test_token_classify(pooling_type: str):
        pooler_config=PoolerConfig(pooling_type=pooling_type)
    )

    pooling_params = PoolingParams(use_activation=None)
    pooling_params.verify(task=task, model_config=model_config)

    pooling_params = PoolingParams(use_activation=True)
    pooling_params.verify(task=task, model_config=model_config)

    pooling_params = PoolingParams(use_activation=False)
    pooling_params.verify(task=task, model_config=model_config)

    invalid_parameters = embed_parameters

View File

@ -7,6 +7,9 @@ from typing import Any
from pydantic.dataclasses import dataclass

from vllm.config.utils import config
from vllm.logger import init_logger

logger = init_logger(__name__)


@config
@ -48,7 +51,15 @@ class PoolerConfig:
""" """
## for classification models ## for classification models
activation: bool | None = None softmax: float | None = None
"""
softmax will be deprecated, please use use_activation instead.
"""
activation: float | None = None
"""
activation will be deprecated, please use use_activation instead.
"""
use_activation: bool | None = None
""" """
Whether to apply activation function to the classification outputs. Whether to apply activation function to the classification outputs.
Defaults to True. Defaults to True.
@ -59,11 +70,6 @@ class PoolerConfig:
""" """
## for reward models ## for reward models
softmax: bool | None = None
"""
Whether to apply softmax to the reward outputs.
Defaults to True.
"""
step_tag_id: int | None = None step_tag_id: int | None = None
""" """
If set, only the score corresponding to the `step_tag_id` in the If set, only the score corresponding to the `step_tag_id` in the
@ -77,6 +83,10 @@ class PoolerConfig:
    `math-shepherd-mistral-7b-prm` model.
    """

    def __post_init__(self):
        # raise deprecated warning for softmax and activation
        self.use_activation = get_use_activation(self)

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
@ -94,3 +104,19 @@ class PoolerConfig:
        factors: list[Any] = []
        hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
        return hash_str


def get_use_activation(o: object):
    # Parentheses around the walrus operator are required: without them Python
    # would bind the result of `... is not None` (a bool) to the name instead
    # of the attribute value itself.
    if (softmax := getattr(o, "softmax", None)) is not None:
        logger.warning_once(
            "softmax will be deprecated, please use use_activation instead."
        )
        return softmax
    if (activation := getattr(o, "activation", None)) is not None:
        logger.warning_once(
            "activation will be deprecated, please use use_activation instead."
        )
        return activation
    return getattr(o, "use_activation", None)
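As an illustration of the precedence this helper implements, a small sketch; `_Req` is a hypothetical stand-in for a request or config object:

```python
class _Req:
    softmax = None         # deprecated field, unset
    activation = False     # deprecated field, set by the user
    use_activation = True  # new field

# softmax is checked first, then activation, then use_activation,
# so the deprecated activation=False wins here.
assert get_use_activation(_Req()) is False
```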

View File

@ -107,6 +107,7 @@ from vllm.entrypoints.utils import (
)
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParserManager
from vllm.tasks import POOLING_TASKS
from vllm.usage.usage_lib import UsageContext
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.network_utils import is_valid_ipv6_address
@ -1748,12 +1749,7 @@ async def init_app_state(
                log_error_stack=args.log_error_stack,
            )
        )
        if any(task in POOLING_TASKS for task in supported_tasks)
        else None
    )
    state.openai_serving_embedding = (

View File

@ -49,6 +49,8 @@ from openai.types.responses.response_reasoning_item import (
)
from openai_harmony import Message as OpenAIHarmonyMessage

from vllm.config.pooler import get_use_activation
from vllm.tasks import PoolingTask
from vllm.utils.serial_utils import (
    EmbedDType,
    EncodingFormat,
@ -1669,8 +1671,58 @@ class EmbeddingChatRequest(OpenAIBaseModel):
EmbeddingRequest: TypeAlias = EmbeddingCompletionRequest | EmbeddingChatRequest


class PoolingCompletionRequest(EmbeddingCompletionRequest):
    task: PoolingTask | None = None
    softmax: bool | None = Field(
        default=None,
        description="softmax will be deprecated, please use use_activation instead.",
    )
    activation: bool | None = Field(
        default=None,
        description="activation will be deprecated, please use use_activation instead.",
    )
    use_activation: bool | None = Field(
        default=None,
        description="Whether to use activation for classification outputs. "
        "If it is a classify or token_classify task, the default is True; "
        "for other tasks, this value should be None.",
    )

    def to_pooling_params(self):
        return PoolingParams(
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            dimensions=self.dimensions,
            normalize=self.normalize,
            use_activation=get_use_activation(self),
        )


class PoolingChatRequest(EmbeddingChatRequest):
    task: PoolingTask | None = None
    softmax: bool | None = Field(
        default=None,
        description="softmax will be deprecated, please use use_activation instead.",
    )
    activation: bool | None = Field(
        default=None,
        description="activation will be deprecated, please use use_activation instead.",
    )
    use_activation: bool | None = Field(
        default=None,
        description="Whether to use activation for classification outputs. "
        "If it is a classify or token_classify task, the default is True; "
        "for other tasks, this value should be None.",
    )

    def to_pooling_params(self):
        return PoolingParams(
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            dimensions=self.dimensions,
            normalize=self.normalize,
            use_activation=get_use_activation(self),
        )


T = TypeVar("T")
T = TypeVar("T") T = TypeVar("T")
@ -1686,6 +1738,7 @@ class IOProcessorRequest(OpenAIBaseModel, Generic[T]):
""" """
data: T data: T
task: PoolingTask = "plugin"
encoding_format: EncodingFormat = "float" encoding_format: EncodingFormat = "float"
embed_dtype: EmbedDType = Field( embed_dtype: EmbedDType = Field(
default="float32", default="float32",
@ -1749,14 +1802,27 @@ class ScoreRequest(OpenAIBaseModel):
        ),
    )

    softmax: bool | None = Field(
        default=None,
        description="softmax will be deprecated, please use use_activation instead.",
    )
    activation: bool | None = Field(
        default=None,
        description="activation will be deprecated, please use use_activation instead.",
    )
    use_activation: bool | None = Field(
        default=None,
        description="Whether to use activation for classification outputs. "
        "Default is True.",
    )
    # --8<-- [end:score-extra-params]

    def to_pooling_params(self):
        return PoolingParams(
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            use_activation=get_use_activation(self),
        )
@ -1783,14 +1849,27 @@ class RerankRequest(OpenAIBaseModel):
        ),
    )

    softmax: bool | None = Field(
        default=None,
        description="softmax will be deprecated, please use use_activation instead.",
    )
    activation: bool | None = Field(
        default=None,
        description="activation will be deprecated, please use use_activation instead.",
    )
    use_activation: bool | None = Field(
        default=None,
        description="Whether to use activation for classification outputs. "
        "Default is True.",
    )
    # --8<-- [end:rerank-extra-params]

    def to_pooling_params(self):
        return PoolingParams(
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            use_activation=get_use_activation(self),
        )
@ -1958,14 +2037,27 @@ class ClassificationRequest(OpenAIBaseModel):
        ),
    )

    softmax: bool | None = Field(
        default=None,
        description="softmax will be deprecated, please use use_activation instead.",
    )
    activation: bool | None = Field(
        default=None,
        description="activation will be deprecated, please use use_activation instead.",
    )
    use_activation: bool | None = Field(
        default=None,
        description="Whether to use activation for classification outputs. "
        "Default is True.",
    )
    # --8<-- [end:classification-extra-params]

    def to_pooling_params(self):
        return PoolingParams(
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            use_activation=get_use_activation(self),
        )

View File

@ -170,15 +170,24 @@ class OpenAIServingPooling(OpenAIServing):
        pooling_params = request.to_pooling_params()

        pooling_task: PoolingTask
        if request.task is None:
            if "token_embed" in self.supported_tasks:
                pooling_task = "token_embed"
            elif "token_classify" in self.supported_tasks:
                pooling_task = "token_classify"
            elif "plugin" in self.supported_tasks:
                pooling_task = "plugin"
            else:
                return self.create_error_response(
                    f"pooling_task must be one of {self.supported_tasks}."
                )
        else:
            pooling_task = request.task
            if pooling_task not in self.supported_tasks:
                return self.create_error_response(
                    f"Task {pooling_task} is not supported, it"
                    f" must be one of {self.supported_tasks}."
                )

        try:

View File

@ -607,7 +607,7 @@ class ClassifierPooler(Pooler):
            pooled_data -= self.logit_bias

        pooling_params = get_pooling_params(pooling_metadata)
        flags = [p.use_activation for p in pooling_params]

        if len(set(flags)) == 1:
            scores = self.act_fn(pooled_data) if flags[0] else pooled_data
@ -681,7 +681,7 @@ class TokenClassifierPoolerHead(nn.Module):
        if self.logit_bias is not None:
            scores -= self.logit_bias

        if pooling_param.use_activation:
            scores = self.act_fn(scores)

        # scores shape: [n_token, num_labels]

View File

@ -53,8 +53,8 @@ class JambaForSequenceClassificationConfig(VerifyAndUpdateConfig):
    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        pooler_config = vllm_config.model_config.pooler_config
        if pooler_config.use_activation is None:
            pooler_config.use_activation = False


class JinaRobertaModelConfig(VerifyAndUpdateConfig):

View File

@ -2,16 +2,15 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from copy import deepcopy
from typing import Annotated, Any, Optional

import msgspec

from vllm.config import ModelConfig, PoolerConfig
from vllm.config.pooler import get_use_activation
from vllm.sampling_params import RequestOutputKind
from vllm.tasks import PoolingTask


class PoolingParams(
    msgspec.Struct,
@ -25,10 +24,12 @@ class PoolingParams(
        Set to -1 to use the model's default truncation size.
        Set to k to keep only the last k tokens (left truncation).
        Set to None to disable truncation.
        dimensions: Reduce the dimensions of embeddings
            if the model supports matryoshka representation.
        normalize: Whether to normalize the embeddings outputs.
        softmax: softmax will be deprecated, please use use_activation instead.
        activation: activation will be deprecated, please use use_activation instead.
        use_activation: Whether to apply activation function to
            the classification outputs.
    """
@ -44,7 +45,9 @@ class PoolingParams(
    ## for classification, scoring and rerank
    # --8<-- [start:classification-pooling-params]
    softmax: bool | None = None
    activation: bool | None = None
    use_activation: bool | None = None
    # --8<-- [end:classification-pooling-params]

    ## for step pooling models
@ -59,16 +62,16 @@ class PoolingParams(
    @property
    def all_parameters(self) -> list[str]:
        return ["dimensions", "normalize", "use_activation"]

    @property
    def valid_parameters(self):
        return {
            "embed": ["dimensions", "normalize"],
            "classify": ["use_activation"],
            "score": ["use_activation"],
            "token_embed": ["dimensions", "normalize"],
            "token_classify": ["use_activation"],
        }

    def clone(self) -> "PoolingParams":
@ -84,6 +87,9 @@ class PoolingParams(
msg = f"You cannot overwrite {self.task=!r} with {task=!r}!" msg = f"You cannot overwrite {self.task=!r} with {task=!r}!"
raise ValueError(msg) raise ValueError(msg)
# raise deprecated warning for softmax and activation
self.use_activation = get_use_activation(self)
# plugin task uses io_processor.parse_request to verify inputs, # plugin task uses io_processor.parse_request to verify inputs,
# skipping PoolingParams verify # skipping PoolingParams verify
if self.task == "plugin": if self.task == "plugin":
@ -168,8 +174,8 @@ class PoolingParams(
raise ValueError("Dimensions must be greater than 0") raise ValueError("Dimensions must be greater than 0")
elif self.task in ["classify", "score", "token_classify"]: elif self.task in ["classify", "score", "token_classify"]:
if self.activation is None: if self.use_activation is None:
self.activation = True self.use_activation = True
else: else:
raise ValueError(f"Unknown pooling task: {self.task}") raise ValueError(f"Unknown pooling task: {self.task}")
@ -197,7 +203,7 @@ class PoolingParams(
f"task={self.task}, " f"task={self.task}, "
f"normalize={self.normalize}, " f"normalize={self.normalize}, "
f"dimensions={self.dimensions}, " f"dimensions={self.dimensions}, "
f"activation={self.activation}, " f"use_activation={self.use_activation}, "
f"step_tag_id={self.step_tag_id}, " f"step_tag_id={self.step_tag_id}, "
f"returned_token_ids={self.returned_token_ids}, " f"returned_token_ids={self.returned_token_ids}, "
f"requires_token_ids={self.requires_token_ids}, " f"requires_token_ids={self.requires_token_ids}, "