Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-13 23:05:02 +08:00)
[Frontend][Doc][5/N] Improve all pooling task | Polish encode (pooling) api & Document. (#25524)

Signed-off-by: wang.yuqi <noooop@126.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>

This commit is contained in: parent 74374386e2, commit 4464723f22
@@ -79,7 +79,7 @@ The `post_process*` methods take `PoolingRequestOutput` objects as input and gen
 The `validate_or_generate_params` method is used for validating with the plugin any `SamplingParameters`/`PoolingParameters` received with the user request, or to generate new ones if none are specified. The function always returns the validated/generated parameters.
 
 The `output_to_response` method is used only for online serving and converts the plugin output to the `IOProcessorResponse` type that is then returned by the API Server. The implementation of the `/pooling` serving endpoint is available here [vllm/entrypoints/openai/serving_pooling.py](../../vllm/entrypoints/openai/serving_pooling.py).
 
-An example implementation of a plugin that enables generating geotiff images with the PrithviGeospatialMAE model is available [here](https://github.com/IBM/terratorch/tree/main/terratorch/vllm/plugins/segmentation). Please also refer to our online ([examples/online_serving/prithvi_geospatial_mae.py](../../examples/online_serving/prithvi_geospatial_mae.py)) and offline ([examples/offline_inference/prithvi_geospatial_mae_io_processor.py](../../examples/offline_inference/prithvi_geospatial_mae_io_processor.py)) inference examples.
+An example implementation of a plugin that enables generating geotiff images with the PrithviGeospatialMAE model is available [here](https://github.com/IBM/terratorch/tree/main/terratorch/vllm/plugins/segmentation). Please also refer to our online ([examples/online_serving/pooling/prithvi_geospatial_mae.py](../../examples/online_serving/pooling/prithvi_geospatial_mae.py)) and offline ([examples/offline_inference/pooling/prithvi_geospatial_mae_io_processor.py](../../examples/offline_inference/pooling/prithvi_geospatial_mae_io_processor.py)) inference examples.
 
 ## Using an IO Processor plugin
 
@@ -30,11 +30,11 @@ If `--runner pooling` has been set (manually or automatically) but the model doe
 vLLM will attempt to automatically convert the model according to the architecture names
 shown in the table below.
 
 | Architecture | `--convert` | Supported pooling tasks |
-|-------------------------------------------------|-------------|-------------------------------|
-| `*ForTextEncoding`, `*EmbeddingModel`, `*Model` | `embed` | `encode`, `embed` |
-| `*For*Classification`, `*ClassificationModel` | `classify` | `encode`, `classify`, `score` |
-| `*ForRewardModeling`, `*RewardModel` | `reward` | `encode` |
+|-------------------------------------------------|-------------|---------------------------------------|
+| `*ForTextEncoding`, `*EmbeddingModel`, `*Model` | `embed` | `token_embed`, `embed` |
+| `*For*Classification`, `*ClassificationModel` | `classify` | `token_classify`, `classify`, `score` |
+| `*ForRewardModeling`, `*RewardModel` | `reward` | `token_classify` |
 
 !!! tip
     You can explicitly set `--convert <type>` to specify how to convert the model.
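
To make the auto-conversion above concrete, here is a minimal offline sketch. The model name and the `runner` keyword (assumed here to mirror the `--runner pooling` CLI flag in the offline constructor) are illustrative assumptions, not part of this diff:

```python
from vllm import LLM

# Sketch: `Qwen2ForSequenceClassification` matches `*For*Classification`,
# so vLLM auto-converts it, equivalent to passing `--convert classify`.
llm = LLM(model="jason9693/Qwen2.5-1.5B-apeach", runner="pooling")

(output,) = llm.classify(["This product was excellent"])
print(output.outputs.probs)  # class probabilities from the `classify` task
```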
@@ -45,12 +45,14 @@ Each pooling model in vLLM supports one or more of these tasks according to
 [Pooler.get_supported_tasks][vllm.model_executor.layers.pooler.Pooler.get_supported_tasks],
 enabling the corresponding APIs:
 
 | Task | APIs |
-|------------|--------------------------------------|
-| `encode` | `LLM.reward(...)` |
-| `embed` | `LLM.embed(...)`, `LLM.score(...)`\* |
-| `classify` | `LLM.classify(...)` |
-| `score` | `LLM.score(...)` |
+|------------------|-------------------------------------------------------------------------------|
+| `embed` | `LLM.embed(...)`, `LLM.score(...)`\*, `LLM.encode(..., pooling_task="embed")` |
+| `classify` | `LLM.classify(...)`, `LLM.encode(..., pooling_task="classify")` |
+| `score` | `LLM.score(...)` |
+| `token_classify` | `LLM.reward(...)`, `LLM.encode(..., pooling_task="token_classify")` |
+| `token_embed` | `LLM.encode(..., pooling_task="token_embed")` |
+| `plugin` | `LLM.encode(..., pooling_task="plugin")` |
 
 \* The `LLM.score(...)` API falls back to `embed` task if the model does not support `score` task.
 
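
To make the task/API table concrete, a short offline sketch (the model name and `runner` keyword are assumptions; any embedding model supporting `embed` and `token_embed` would do):

```python
from vllm import LLM

# Assumed embedding model; supports the "embed" and "token_embed" tasks.
llm = LLM(model="intfloat/e5-small", runner="pooling")

# Sentence-level embedding via the dedicated API...
(emb,) = llm.embed(["Hello, world!"])
print(len(emb.outputs.embedding))  # e.g. 384 dimensions for e5-small

# ...or per-token vectors via the generic API with an explicit task.
(tok,) = llm.encode(["Hello, world!"], pooling_task="token_embed")
print(tok.outputs.data.shape)  # one vector per input token
```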
@@ -144,7 +146,6 @@ A code example can be found here: [examples/offline_inference/basic/score.py](..
 ### `LLM.reward`
 
 The [reward][vllm.LLM.reward] method is available to all reward models in vLLM.
-It returns the extracted hidden states directly.
 
 ```python
 from vllm import LLM
@@ -161,15 +162,17 @@ A code example can be found here: [examples/offline_inference/basic/reward.py](.
 ### `LLM.encode`
 
 The [encode][vllm.LLM.encode] method is available to all pooling models in vLLM.
-It returns the extracted hidden states directly.
 
 !!! note
     Please use one of the more specific methods or set the task directly when using `LLM.encode`:
 
     - For embeddings, use `LLM.embed(...)` or `pooling_task="embed"`.
     - For classification logits, use `LLM.classify(...)` or `pooling_task="classify"`.
-    - For rewards, use `LLM.reward(...)` or `pooling_task="reward"`.
     - For similarity scores, use `LLM.score(...)`.
+    - For rewards, use `LLM.reward(...)` or `pooling_task="token_classify"`.
+    - For token classification, use `pooling_task="token_classify"`.
+    - For multi-vector retrieval, use `pooling_task="token_embed"`.
+    - For IO Processor Plugins, use `pooling_task="plugin"`.
 
 ```python
 from vllm import LLM
@@ -185,10 +188,47 @@ print(f"Data: {data!r}")
 
 Our [OpenAI-Compatible Server](../serving/openai_compatible_server.md) provides endpoints that correspond to the offline APIs:
 
-- [Pooling API](../serving/openai_compatible_server.md#pooling-api) is similar to `LLM.encode`, being applicable to all types of pooling models.
 - [Embeddings API](../serving/openai_compatible_server.md#embeddings-api) is similar to `LLM.embed`, accepting both text and [multi-modal inputs](../features/multimodal_inputs.md) for embedding models.
 - [Classification API](../serving/openai_compatible_server.md#classification-api) is similar to `LLM.classify` and is applicable to sequence classification models.
 - [Score API](../serving/openai_compatible_server.md#score-api) is similar to `LLM.score` for cross-encoder models.
+- [Pooling API](../serving/openai_compatible_server.md#pooling-api) is similar to `LLM.encode`, being applicable to all types of pooling models.
+
+!!! note
+    Please use one of the more specific endpoints or set the task directly when using the [Pooling API](../serving/openai_compatible_server.md#pooling-api):
+
+    - For embeddings, use the [Embeddings API](../serving/openai_compatible_server.md#embeddings-api) or `"task":"embed"`.
+    - For classification logits, use the [Classification API](../serving/openai_compatible_server.md#classification-api) or `"task":"classify"`.
+    - For similarity scores, use the [Score API](../serving/openai_compatible_server.md#score-api).
+    - For rewards, use `"task":"token_classify"`.
+    - For token classification, use `"task":"token_classify"`.
+    - For multi-vector retrieval, use `"task":"token_embed"`.
+    - For IO Processor Plugins, use `"task":"plugin"`.
+
+```python
+# start a supported embeddings model server with `vllm serve`, e.g.
+# vllm serve intfloat/e5-small
+import requests
+
+host = "localhost"
+port = "8000"
+model_name = "intfloat/e5-small"
+
+api_url = f"http://{host}:{port}/pooling"
+
+prompts = [
+    "Hello, my name is",
+    "The president of the United States is",
+    "The capital of France is",
+    "The future of AI is",
+]
+prompt = {"model": model_name, "input": prompts, "task": "embed"}
+
+response = requests.post(api_url, json=prompt)
+
+for output in response.json()["data"]:
+    data = output["data"]
+    print(f"Data: {data!r} (size={len(data)})")
+```
 
 ## Matryoshka Embeddings
 
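
As a companion to the embedding request above, a sketch of the same `/pooling` endpoint used for multi-vector retrieval. The model `BAAI/bge-m3` is borrowed from this PR's multi-vector example and is an assumption here:

```python
import requests

# Assumes a running server: vllm serve BAAI/bge-m3
resp = requests.post(
    "http://localhost:8000/pooling",
    json={
        "model": "BAAI/bge-m3",
        "input": "The capital of France is Paris.",
        "task": "token_embed",
    },
)
token_vectors = resp.json()["data"][0]["data"]
print(len(token_vectors), len(token_vectors[0]))  # n_tokens x hidden_size
```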
@@ -265,3 +305,16 @@ Expected output:
 ```
 
 An OpenAI client example can be found here: [examples/online_serving/pooling/openai_embedding_matryoshka_fy.py](../../examples/online_serving/pooling/openai_embedding_matryoshka_fy.py)
+
+## Deprecated Features
+
+### Encode task
+
+We have split the `encode` task into two more specific token-wise tasks: `token_embed` and `token_classify`:
+
+- `token_embed` is the same as `embed`, using normalization as the activation.
+- `token_classify` is the same as `classify`, using softmax as the default activation.
+
+### Remove softmax from PoolingParams
+
+We are going to remove `softmax` and `activation` from `PoolingParams`. Instead, you should set `use_activation`, since we actually allow `classify` and `token_classify` to use any activation function.
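
A minimal before/after sketch of the rename described in this hunk (the classifier model is borrowed from this PR's tests, and the `runner` keyword is an assumption):

```python
from vllm import LLM, PoolingParams

# Assumed classification model (used elsewhere in this PR's tests).
llm = LLM(model="jason9693/Qwen2.5-1.5B-apeach", runner="pooling")

# Old spelling (deprecated, still accepted with a warning):
#   PoolingParams(activation=False)
# New spelling:
(raw,) = llm.classify(
    ["vLLM is great!"],
    pooling_params=PoolingParams(use_activation=False),  # raw logits
)
(probs,) = llm.classify(
    ["vLLM is great!"],
    pooling_params=PoolingParams(use_activation=True),  # softmax applied
)
```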
@@ -638,7 +638,7 @@ Usually, the score for a sentence pair refers to the similarity between two sent
 You can find the documentation for cross encoder models at [sbert.net](https://www.sbert.net/docs/package_reference/cross_encoder/cross_encoder.html).
 
-Code example: [examples/online_serving/openai_cross_encoder_score.py](../../examples/online_serving/openai_cross_encoder_score.py)
+Code example: [examples/online_serving/pooling/openai_cross_encoder_score.py](../../examples/online_serving/pooling/openai_cross_encoder_score.py)
 
 #### Single inference
 
@@ -819,7 +819,7 @@ You can pass multi-modal inputs to scoring models by passing `content` including
 print("Scoring output:", response_json["data"][0]["score"])
 print("Scoring output:", response_json["data"][1]["score"])
 ```
-Full example: [examples/online_serving/openai_cross_encoder_score_for_multimodal.py](../../examples/online_serving/openai_cross_encoder_score_for_multimodal.py)
+Full example: [examples/online_serving/pooling/openai_cross_encoder_score_for_multimodal.py](../../examples/online_serving/pooling/openai_cross_encoder_score_for_multimodal.py)
 
 #### Extra parameters
 
@@ -38,6 +38,18 @@ python examples/offline_inference/pooling/multi_vector_retrieval.py
 python examples/offline_inference/pooling/ner.py
 ```
 
+## Prithvi Geospatial MAE usage
+
+```bash
+python examples/offline_inference/pooling/prithvi_geospatial_mae.py
+```
+
+## IO Processor Plugins for Prithvi Geospatial MAE
+
+```bash
+python examples/offline_inference/pooling/prithvi_geospatial_mae_io_processor.py
+```
+
 ## Qwen3 reranker usage
 
 ```bash
@@ -33,7 +33,7 @@ def main(args: Namespace):
     label_map = llm.llm_engine.vllm_config.model_config.hf_config.id2label
 
     # Run inference
-    outputs = llm.encode(prompts)
+    outputs = llm.encode(prompts, pooling_task="token_classify")
 
     for prompt, output in zip(prompts, outputs):
         logits = output.outputs.data
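
For context, a condensed sketch of how the `token_classify` output updated above can be decoded into NER labels, mirroring the example file this hunk touches (the model name comes from this PR's README; the `runner` keyword is an assumption):

```python
from vllm import LLM

# NER model referenced by this PR's example README.
llm = LLM(model="boltuix/NeuroBERT-NER", runner="pooling")
label_map = llm.llm_engine.vllm_config.model_config.hf_config.id2label

(output,) = llm.encode(["Barack Obama visited Paris"],
                       pooling_task="token_classify")
logits = output.outputs.data  # shape: [n_tokens, num_labels]
print([label_map[i] for i in logits.argmax(dim=-1).tolist()])
```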
@@ -3,65 +3,95 @@
 ## Cohere rerank usage
 
 ```bash
+# vllm serve BAAI/bge-reranker-base
 python examples/online_serving/pooling/cohere_rerank_client.py
 ```
 
 ## Embedding requests base64 encoding_format usage
 
 ```bash
+# vllm serve intfloat/e5-small
 python examples/online_serving/pooling/embedding_requests_base64_client.py
 ```
 
 ## Embedding requests bytes encoding_format usage
 
 ```bash
+# vllm serve intfloat/e5-small
 python examples/online_serving/pooling/embedding_requests_bytes_client.py
 ```
 
 ## Jinaai rerank usage
 
 ```bash
+# vllm serve BAAI/bge-reranker-base
 python examples/online_serving/pooling/jinaai_rerank_client.py
 ```
 
 ## Multi vector retrieval usage
 
 ```bash
+# vllm serve BAAI/bge-m3
 python examples/online_serving/pooling/multi_vector_retrieval_client.py
 ```
 
 ## Named Entity Recognition (NER) usage
 
 ```bash
+# vllm serve boltuix/NeuroBERT-NER
 python examples/online_serving/pooling/ner_client.py
 ```
 
-## Openai chat embedding for multimodal usage
+## OpenAI chat embedding for multimodal usage
 
 ```bash
 python examples/online_serving/pooling/openai_chat_embedding_client_for_multimodal.py
 ```
 
-## Openai classification usage
+## OpenAI classification usage
 
 ```bash
+# vllm serve jason9693/Qwen2.5-1.5B-apeach
 python examples/online_serving/pooling/openai_classification_client.py
 ```
 
-## Openai embedding usage
+## OpenAI cross_encoder score usage
+
+```bash
+# vllm serve BAAI/bge-reranker-v2-m3
+python examples/online_serving/pooling/openai_cross_encoder_score.py
+```
+
+## OpenAI cross_encoder score for multimodal usage
+
+```bash
+# vllm serve jinaai/jina-reranker-m0
+python examples/online_serving/pooling/openai_cross_encoder_score_for_multimodal.py
+```
+
+## OpenAI embedding usage
 
 ```bash
+# vllm serve intfloat/e5-small
 python examples/online_serving/pooling/openai_embedding_client.py
 ```
 
-## Openai embedding matryoshka dimensions usage
+## OpenAI embedding matryoshka dimensions usage
 
 ```bash
+# vllm serve jinaai/jina-embeddings-v3 --trust-remote-code
 python examples/online_serving/pooling/openai_embedding_matryoshka_fy.py
 ```
 
-## Openai pooling usage
+## OpenAI pooling usage
 
 ```bash
+# vllm serve internlm/internlm2-1_8b-reward --trust-remote-code
 python examples/online_serving/pooling/openai_pooling_client.py
 ```
+
+## Online Prithvi Geospatial MAE usage
+
+```bash
+python examples/online_serving/pooling/prithvi_geospatial_mae.py
+```
@@ -37,15 +37,17 @@ def llm():
 
 @pytest.mark.skip_global_cleanup
 def test_pooling_params(llm: LLM):
-    def get_outputs(activation):
+    def get_outputs(use_activation):
         outputs = llm.classify(
-            prompts, pooling_params=PoolingParams(activation=activation), use_tqdm=False
+            prompts,
+            pooling_params=PoolingParams(use_activation=use_activation),
+            use_tqdm=False,
         )
         return torch.tensor([x.outputs.probs for x in outputs])
 
-    default = get_outputs(activation=None)
-    w_activation = get_outputs(activation=True)
-    wo_activation = get_outputs(activation=False)
+    default = get_outputs(use_activation=None)
+    w_activation = get_outputs(use_activation=True)
+    wo_activation = get_outputs(use_activation=False)
 
     assert torch.allclose(default, w_activation, atol=1e-2), (
         "Default should use activation."
@@ -37,15 +37,17 @@ def llm():
 
 
 def test_pooling_params(llm: LLM):
-    def get_outputs(activation):
+    def get_outputs(use_activation):
         outputs = llm.reward(
-            prompts, pooling_params=PoolingParams(activation=activation), use_tqdm=False
+            prompts,
+            pooling_params=PoolingParams(use_activation=use_activation),
+            use_tqdm=False,
        )
         return torch.cat([x.outputs.data for x in outputs])
 
-    default = get_outputs(activation=None)
-    w_activation = get_outputs(activation=True)
-    wo_activation = get_outputs(activation=False)
+    default = get_outputs(use_activation=None)
+    w_activation = get_outputs(use_activation=True)
+    wo_activation = get_outputs(use_activation=False)
 
     assert torch.allclose(default, w_activation, atol=1e-2), (
         "Default should use activation."
@@ -34,21 +34,21 @@ def llm():
 
 
 def test_pooling_params(llm: LLM):
-    def get_outputs(activation):
+    def get_outputs(use_activation):
         text_1 = "What is the capital of France?"
         text_2 = "The capital of France is Paris."
 
         outputs = llm.score(
             text_1,
             text_2,
-            pooling_params=PoolingParams(activation=activation),
+            pooling_params=PoolingParams(use_activation=use_activation),
             use_tqdm=False,
         )
         return torch.tensor([x.outputs.score for x in outputs])
 
-    default = get_outputs(activation=None)
-    w_activation = get_outputs(activation=True)
-    wo_activation = get_outputs(activation=False)
+    default = get_outputs(use_activation=None)
+    w_activation = get_outputs(use_activation=True)
+    wo_activation = get_outputs(use_activation=False)
 
     assert torch.allclose(default, w_activation, atol=1e-2), (
         "Default should use activation."
@@ -7,7 +7,7 @@ import torch
 import torch.nn.functional as F
 
 from tests.utils import RemoteOpenAIServer
-from vllm.entrypoints.openai.protocol import ClassificationResponse
+from vllm.entrypoints.openai.protocol import ClassificationResponse, PoolingResponse
 
 MODEL_NAME = "jason9693/Qwen2.5-1.5B-apeach"
 DTYPE = "float32"  # Use float32 to avoid NaN issue
@@ -163,20 +163,24 @@ async def test_invocations(server: RemoteOpenAIServer):
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
-async def test_activation(server: RemoteOpenAIServer, model_name: str):
+async def test_use_activation(server: RemoteOpenAIServer, model_name: str):
     input_text = ["This product was excellent and exceeded my expectations"]
 
-    async def get_outputs(activation):
+    async def get_outputs(use_activation):
         response = requests.post(
             server.url_for("classify"),
-            json={"model": model_name, "input": input_text, "activation": activation},
+            json={
+                "model": model_name,
+                "input": input_text,
+                "use_activation": use_activation,
+            },
         )
         outputs = response.json()
         return torch.tensor([x["probs"] for x in outputs["data"]])
 
-    default = await get_outputs(activation=None)
-    w_activation = await get_outputs(activation=True)
-    wo_activation = await get_outputs(activation=False)
+    default = await get_outputs(use_activation=None)
+    w_activation = await get_outputs(use_activation=True)
+    wo_activation = await get_outputs(use_activation=False)
 
     assert torch.allclose(default, w_activation, atol=1e-2), (
         "Default should use activation."
@@ -191,18 +195,7 @@ async def test_activation(server: RemoteOpenAIServer, model_name: str):
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
-def test_pooling(server: RemoteOpenAIServer, model_name: str):
-    # pooling api uses ALL pooling, which does not support chunked prefill.
-    response = requests.post(
-        server.url_for("pooling"),
-        json={"model": model_name, "input": "test", "encoding_format": "float"},
-    )
-    assert response.json()["error"]["type"] == "BadRequestError"
-
-
-@pytest.mark.asyncio
-@pytest.mark.parametrize("model_name", [MODEL_NAME])
-def test_score(server: RemoteOpenAIServer, model_name: str):
+async def test_score(server: RemoteOpenAIServer, model_name: str):
     # score api is only enabled for num_labels == 1.
     response = requests.post(
         server.url_for("score"),
@@ -217,7 +210,7 @@ def test_score(server: RemoteOpenAIServer, model_name: str):
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
-def test_rerank(server: RemoteOpenAIServer, model_name: str):
+async def test_rerank(server: RemoteOpenAIServer, model_name: str):
     # rerank api is only enabled for num_labels == 1.
     response = requests.post(
         server.url_for("rerank"),
@@ -228,3 +221,62 @@ def test_rerank(server: RemoteOpenAIServer, model_name: str):
         },
     )
     assert response.json()["error"]["type"] == "BadRequestError"
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_pooling_classify(server: RemoteOpenAIServer, model_name: str):
+    input_text = "This product was excellent and exceeded my expectations"
+    response = requests.post(
+        server.url_for("pooling"),
+        json={
+            "model": model_name,
+            "input": input_text,
+            "encoding_format": "float",
+            "task": "classify",
+        },
+    )
+    poolings = PoolingResponse.model_validate(response.json())
+    assert len(poolings.data) == 1
+    assert len(poolings.data[0].data) == 2
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_pooling_token_classify(server: RemoteOpenAIServer, model_name: str):
+    # token_classify uses ALL pooling, which does not support chunked prefill.
+    task = "token_classify"
+    response = requests.post(
+        server.url_for("pooling"),
+        json={
+            "model": model_name,
+            "input": "test",
+            "encoding_format": "float",
+            "task": task,
+        },
+    )
+    assert response.json()["error"]["type"] == "BadRequestError"
+    assert response.json()["error"]["message"].startswith(
+        f"Task {task} is not supported"
+    )
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+@pytest.mark.parametrize("task", ["embed", "token_embed", "plugin"])
+async def test_pooling_not_supported(
+    server: RemoteOpenAIServer, model_name: str, task: str
+):
+    response = requests.post(
+        server.url_for("pooling"),
+        json={
+            "model": model_name,
+            "input": "test",
+            "encoding_format": "float",
+            "task": task,
+        },
+    )
+    assert response.json()["error"]["type"] == "BadRequestError"
+    assert response.json()["error"]["message"].startswith(
+        f"Task {task} is not supported"
+    )
@@ -562,12 +562,40 @@ async def test_normalize(server: RemoteOpenAIServer, model_name: str):
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
-async def test_pooling(server: RemoteOpenAIServer, model_name: str):
+async def test_pooling_embed(server: RemoteOpenAIServer, model_name: str):
+    task = "embed"
     input_text = ["The chef prepared a delicious meal."]
 
     response = requests.post(
         server.url_for("pooling"),
-        json={"model": model_name, "input": input_text, "encoding_format": "float"},
+        json={
+            "model": model_name,
+            "input": input_text,
+            "encoding_format": "float",
+            "task": task,
+        },
+    )
+
+    poolings = PoolingResponse.model_validate(response.json())
+
+    assert len(poolings.data) == 1
+    assert len(poolings.data[0].data) == 384
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_pooling_token_embed(server: RemoteOpenAIServer, model_name: str):
+    task = "token_embed"
+    input_text = ["The chef prepared a delicious meal."]
+
+    response = requests.post(
+        server.url_for("pooling"),
+        json={
+            "model": model_name,
+            "input": input_text,
+            "encoding_format": "float",
+            "task": task,
+        },
     )
 
     poolings = PoolingResponse.model_validate(response.json())
@@ -575,3 +603,24 @@ async def test_pooling(server: RemoteOpenAIServer, model_name: str):
     assert len(poolings.data) == 1
     assert len(poolings.data[0].data) == 11
     assert len(poolings.data[0].data[0]) == 384
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+@pytest.mark.parametrize("task", ["classify", "token_classify", "plugin"])
+async def test_pooling_not_supported(
+    server: RemoteOpenAIServer, model_name: str, task: str
+):
+    response = requests.post(
+        server.url_for("pooling"),
+        json={
+            "model": model_name,
+            "input": "test",
+            "encoding_format": "float",
+            "task": task,
+        },
+    )
+    assert response.json()["error"]["type"] == "BadRequestError"
+    assert response.json()["error"]["message"].startswith(
+        f"Task {task} is not supported"
+    )
@@ -125,8 +125,8 @@ def test_invocations(server: RemoteOpenAIServer):
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
-async def test_activation(server: RemoteOpenAIServer, model_name: str):
-    async def get_outputs(activation):
+async def test_use_activation(server: RemoteOpenAIServer, model_name: str):
+    async def get_outputs(use_activation):
         query = "What is the capital of France?"
         documents = [
             "The capital of Brazil is Brasilia.",
@@ -139,16 +139,16 @@ async def test_activation(server: RemoteOpenAIServer, model_name: str):
                 "model": model_name,
                 "query": query,
                 "documents": documents,
-                "activation": activation,
+                "use_activation": use_activation,
             },
         )
         outputs = response.json()
 
         return torch.tensor([x["relevance_score"] for x in outputs["results"]])
 
-    default = await get_outputs(activation=None)
-    w_activation = await get_outputs(activation=True)
-    wo_activation = await get_outputs(activation=False)
+    default = await get_outputs(use_activation=None)
+    w_activation = await get_outputs(use_activation=True)
+    wo_activation = await get_outputs(use_activation=False)
 
     assert torch.allclose(default, w_activation, atol=1e-2), (
         "Default should use activation."
@@ -163,7 +163,25 @@ async def test_activation(server: RemoteOpenAIServer, model_name: str):
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
-async def test_pooling(server: RemoteOpenAIServer, model_name: str):
+async def test_pooling_classify(server: RemoteOpenAIServer, model_name: str):
+    input_text = "This product was excellent and exceeded my expectations"
+    response = requests.post(
+        server.url_for("pooling"),
+        json={
+            "model": model_name,
+            "input": input_text,
+            "encoding_format": "float",
+            "task": "classify",
+        },
+    )
+    poolings = PoolingResponse.model_validate(response.json())
+    assert len(poolings.data) == 1
+    assert len(poolings.data[0].data) == 1
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_pooling_token_classify(server: RemoteOpenAIServer, model_name: str):
     input_text = ["The chef prepared a delicious meal."]
 
     response = requests.post(
@@ -176,3 +194,24 @@ async def test_pooling(server: RemoteOpenAIServer, model_name: str):
     assert len(poolings.data) == 1
     assert len(poolings.data[0].data) == 11
     assert len(poolings.data[0].data[0]) == 1
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+@pytest.mark.parametrize("task", ["embed", "token_embed", "plugin"])
+async def test_pooling_not_supported(
+    server: RemoteOpenAIServer, model_name: str, task: str
+):
+    response = requests.post(
+        server.url_for("pooling"),
+        json={
+            "model": model_name,
+            "input": "test",
+            "encoding_format": "float",
+            "task": task,
+        },
+    )
+    assert response.json()["error"]["type"] == "BadRequestError"
+    assert response.json()["error"]["message"].startswith(
+        f"Task {task} is not supported"
+    )
@@ -218,8 +218,8 @@ class TestModel:
         # TODO: reset this tolerance to 0.01 once we find
         # an alternative to flash_attn with bfloat16
 
-    def test_activation(self, server: RemoteOpenAIServer, model: dict[str, Any]):
-        def get_outputs(activation):
+    def test_use_activation(self, server: RemoteOpenAIServer, model: dict[str, Any]):
+        def get_outputs(use_activation):
             text_1 = "What is the capital of France?"
             text_2 = "The capital of France is Paris."
             response = requests.post(
@@ -228,7 +228,7 @@ class TestModel:
                     "model": model["name"],
                     "text_1": text_1,
                     "text_2": text_2,
-                    "activation": activation,
+                    "use_activation": use_activation,
                 },
             )
             if response.status_code != 200:
@@ -238,9 +238,9 @@ class TestModel:
             return torch.tensor([x["score"] for x in outputs["data"]])
 
         if model["is_cross_encoder"]:
-            default = get_outputs(activation=None)
-            w_activation = get_outputs(activation=True)
-            wo_activation = get_outputs(activation=False)
+            default = get_outputs(use_activation=None)
+            w_activation = get_outputs(use_activation=True)
+            wo_activation = get_outputs(use_activation=False)
 
             assert torch.allclose(default, w_activation, atol=1e-2), (
                 "Default should use activation."
@@ -252,8 +252,8 @@ class TestModel:
                 "w_activation should be close to activation(wo_activation)."
             )
         else:
-            get_outputs(activation=None)
+            get_outputs(use_activation=None)
 
             # The activation parameter only works for the is_cross_encoder model
-            response = get_outputs(activation=True)
+            response = get_outputs(use_activation=True)
             assert response.status_code == 400
@@ -24,7 +24,7 @@ def test_classify_models_using_activation(
         model,
         max_model_len=512,
         dtype=dtype,
-        pooler_config=PoolerConfig(activation=False),
+        pooler_config=PoolerConfig(use_activation=False),
     ) as vllm_model:
         wo_activation_out = vllm_model.classify(example_prompts)
 
@@ -32,7 +32,7 @@ def test_classify_models_using_activation(
         model,
         max_model_len=512,
         dtype=dtype,
-        pooler_config=PoolerConfig(activation=True),
+        pooler_config=PoolerConfig(use_activation=True),
     ) as vllm_model:
         w_activation_out = vllm_model.classify(example_prompts)
 
@@ -104,7 +104,7 @@ def test_reward_models_using_activation(
         model,
         max_model_len=1024,
         dtype=dtype,
-        pooler_config=PoolerConfig(activation=False),
+        pooler_config=PoolerConfig(use_activation=False),
     ) as vllm_model:
         wo_activation = vllm_model.reward(example_prompts)
 
@@ -112,7 +112,7 @@ def test_reward_models_using_activation(
         model,
         max_model_len=1024,
         dtype=dtype,
-        pooler_config=PoolerConfig(activation=True),
+        pooler_config=PoolerConfig(use_activation=True),
     ) as vllm_model:
         w_activation = vllm_model.reward(example_prompts)
 
@@ -17,7 +17,7 @@ EMBEDDING_MODELS = [
     ),
 ]
 
-classify_parameters = ["activation"]
+classify_parameters = ["use_activation"]
 embed_parameters = ["dimensions", "normalize"]
 step_pooling_parameters = ["step_tag_id", "returned_token_ids"]
 
@@ -88,13 +88,13 @@ def test_embed_dimensions(model_info: EmbedModelInfo):
 def test_classify(task):
     model_config = MockModelConfig(pooler_config=PoolerConfig(pooling_type="CLS"))
 
-    pooling_params = PoolingParams(activation=None)
+    pooling_params = PoolingParams(use_activation=None)
     pooling_params.verify(task=task, model_config=model_config)
 
-    pooling_params = PoolingParams(activation=True)
+    pooling_params = PoolingParams(use_activation=True)
     pooling_params.verify(task=task, model_config=model_config)
 
-    pooling_params = PoolingParams(activation=False)
+    pooling_params = PoolingParams(use_activation=False)
     pooling_params.verify(task=task, model_config=model_config)
 
     invalid_parameters = embed_parameters + step_pooling_parameters
@@ -137,13 +137,13 @@ def test_token_classify(pooling_type: str):
         pooler_config=PoolerConfig(pooling_type=pooling_type)
     )
 
-    pooling_params = PoolingParams(activation=None)
+    pooling_params = PoolingParams(use_activation=None)
     pooling_params.verify(task=task, model_config=model_config)
 
-    pooling_params = PoolingParams(activation=True)
+    pooling_params = PoolingParams(use_activation=True)
     pooling_params.verify(task=task, model_config=model_config)
 
-    pooling_params = PoolingParams(activation=False)
+    pooling_params = PoolingParams(use_activation=False)
     pooling_params.verify(task=task, model_config=model_config)
 
     invalid_parameters = embed_parameters
@@ -7,6 +7,9 @@ from typing import Any
 from pydantic.dataclasses import dataclass
 
 from vllm.config.utils import config
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
 
 
 @config
@@ -48,7 +51,15 @@ class PoolerConfig:
     """
 
     ## for classification models
-    activation: bool | None = None
+    softmax: bool | None = None
+    """
+    softmax will be deprecated, please use use_activation instead.
+    """
+    activation: bool | None = None
+    """
+    activation will be deprecated, please use use_activation instead.
+    """
+    use_activation: bool | None = None
     """
     Whether to apply activation function to the classification outputs.
     Defaults to True.
@@ -59,11 +70,6 @@ class PoolerConfig:
     """
 
     ## for reward models
-    softmax: bool | None = None
-    """
-    Whether to apply softmax to the reward outputs.
-    Defaults to True.
-    """
     step_tag_id: int | None = None
     """
     If set, only the score corresponding to the `step_tag_id` in the
@@ -77,6 +83,10 @@ class PoolerConfig:
     `math-shepherd-mistral-7b-prm` model.
     """
 
+    def __post_init__(self):
+        # raise deprecated warning for softmax and activation
+        self.use_activation = get_use_activation(self)
+
     def compute_hash(self) -> str:
         """
         WARNING: Whenever a new field is added to this config,
@@ -94,3 +104,19 @@ class PoolerConfig:
         factors: list[Any] = []
         hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
         return hash_str
+
+
+def get_use_activation(o: object):
+    if (softmax := getattr(o, "softmax", None)) is not None:
+        logger.warning_once(
+            "softmax will be deprecated, please use use_activation instead."
+        )
+        return softmax
+
+    if (activation := getattr(o, "activation", None)) is not None:
+        logger.warning_once(
+            "activation will be deprecated, please use use_activation instead."
+        )
+        return activation
+
+    return getattr(o, "use_activation", None)
@@ -107,6 +107,7 @@ from vllm.entrypoints.utils import (
 )
 from vllm.logger import init_logger
 from vllm.reasoning import ReasoningParserManager
+from vllm.tasks import POOLING_TASKS
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils.argparse_utils import FlexibleArgumentParser
 from vllm.utils.network_utils import is_valid_ipv6_address
@@ -1748,12 +1749,7 @@ async def init_app_state(
                 log_error_stack=args.log_error_stack,
             )
         )
-        if (
-            any(
-                task in supported_tasks
-                for task in ["token_embed", "token_classify", "plugin"]
-            )
-        )
+        if any(task in POOLING_TASKS for task in supported_tasks)
         else None
     )
     state.openai_serving_embedding = (
@@ -49,6 +49,8 @@ from openai.types.responses.response_reasoning_item import (
 )
 from openai_harmony import Message as OpenAIHarmonyMessage
 
+from vllm.config.pooler import get_use_activation
+from vllm.tasks import PoolingTask
 from vllm.utils.serial_utils import (
     EmbedDType,
     EncodingFormat,
@@ -1669,8 +1671,58 @@ class EmbeddingChatRequest(OpenAIBaseModel):
 
 EmbeddingRequest: TypeAlias = EmbeddingCompletionRequest | EmbeddingChatRequest
 
-PoolingCompletionRequest = EmbeddingCompletionRequest
-PoolingChatRequest = EmbeddingChatRequest
+class PoolingCompletionRequest(EmbeddingCompletionRequest):
+    task: PoolingTask | None = None
+    softmax: bool | None = Field(
+        default=None,
+        description="softmax will be deprecated, please use use_activation instead.",
+    )
+    activation: bool | None = Field(
+        default=None,
+        description="activation will be deprecated, please use use_activation instead.",
+    )
+    use_activation: bool | None = Field(
+        default=None,
+        description="Whether to use activation for classification outputs. "
+        "If it is a classify or token_classify task, the default is True; "
+        "for other tasks, this value should be None.",
+    )
+
+    def to_pooling_params(self):
+        return PoolingParams(
+            truncate_prompt_tokens=self.truncate_prompt_tokens,
+            dimensions=self.dimensions,
+            normalize=self.normalize,
+            use_activation=get_use_activation(self),
+        )
+
+
+class PoolingChatRequest(EmbeddingChatRequest):
+    task: PoolingTask | None = None
+    softmax: bool | None = Field(
+        default=None,
+        description="softmax will be deprecated, please use use_activation instead.",
+    )
+    activation: bool | None = Field(
+        default=None,
+        description="activation will be deprecated, please use use_activation instead.",
+    )
+    use_activation: bool | None = Field(
+        default=None,
+        description="Whether to use activation for classification outputs. "
+        "If it is a classify or token_classify task, the default is True; "
+        "for other tasks, this value should be None.",
+    )
+
+    def to_pooling_params(self):
+        return PoolingParams(
+            truncate_prompt_tokens=self.truncate_prompt_tokens,
+            dimensions=self.dimensions,
+            normalize=self.normalize,
+            use_activation=get_use_activation(self),
+        )
 
 
 T = TypeVar("T")
 
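
A sketch of a request exercising the new `PoolingCompletionRequest` fields above (server and model are borrowed from this PR's tests; that they are running locally is an assumption):

```python
import requests

# Assumes a running server: vllm serve jason9693/Qwen2.5-1.5B-apeach
resp = requests.post(
    "http://localhost:8000/pooling",
    json={
        "model": "jason9693/Qwen2.5-1.5B-apeach",
        "input": "This product was excellent",
        "task": "classify",
        "use_activation": False,  # raw logits instead of probabilities
    },
)
print(resp.json()["data"][0]["data"])
```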
@@ -1686,6 +1738,7 @@ class IOProcessorRequest(OpenAIBaseModel, Generic[T]):
     """
     data: T
 
+    task: PoolingTask = "plugin"
     encoding_format: EncodingFormat = "float"
     embed_dtype: EmbedDType = Field(
         default="float32",
@@ -1749,14 +1802,27 @@ class ScoreRequest(OpenAIBaseModel):
         ),
     )
 
-    activation: bool | None = None
+    softmax: bool | None = Field(
+        default=None,
+        description="softmax will be deprecated, please use use_activation instead.",
+    )
+
+    activation: bool | None = Field(
+        default=None,
+        description="activation will be deprecated, please use use_activation instead.",
+    )
+
+    use_activation: bool | None = Field(
+        default=None,
+        description="Whether to use activation for classification outputs. "
+        "Default is True.",
+    )
     # --8<-- [end:score-extra-params]
 
     def to_pooling_params(self):
         return PoolingParams(
             truncate_prompt_tokens=self.truncate_prompt_tokens,
-            activation=self.activation,
+            use_activation=get_use_activation(self),
         )
 
 
@@ -1783,14 +1849,27 @@ class RerankRequest(OpenAIBaseModel):
         ),
     )
 
-    activation: bool | None = None
+    softmax: bool | None = Field(
+        default=None,
+        description="softmax will be deprecated, please use use_activation instead.",
+    )
+
+    activation: bool | None = Field(
+        default=None,
+        description="activation will be deprecated, please use use_activation instead.",
+    )
+
+    use_activation: bool | None = Field(
+        default=None,
+        description="Whether to use activation for classification outputs. "
+        "Default is True.",
+    )
     # --8<-- [end:rerank-extra-params]
 
     def to_pooling_params(self):
         return PoolingParams(
             truncate_prompt_tokens=self.truncate_prompt_tokens,
-            activation=self.activation,
+            use_activation=get_use_activation(self),
         )
 
 
@@ -1958,14 +2037,27 @@ class ClassificationRequest(OpenAIBaseModel):
         ),
     )
 
-    activation: bool | None = None
+    softmax: bool | None = Field(
+        default=None,
+        description="softmax will be deprecated, please use use_activation instead.",
+    )
+
+    activation: bool | None = Field(
+        default=None,
+        description="activation will be deprecated, please use use_activation instead.",
+    )
+
+    use_activation: bool | None = Field(
+        default=None,
+        description="Whether to use activation for classification outputs. "
+        "Default is True.",
+    )
     # --8<-- [end:classification-extra-params]
 
     def to_pooling_params(self):
         return PoolingParams(
             truncate_prompt_tokens=self.truncate_prompt_tokens,
-            activation=self.activation,
+            use_activation=get_use_activation(self),
         )
 
 
@@ -170,15 +170,24 @@ class OpenAIServingPooling(OpenAIServing):
         pooling_params = request.to_pooling_params()

         pooling_task: PoolingTask
-        if "token_embed" in self.supported_tasks:
-            pooling_task = "token_embed"
-        elif "token_classify" in self.supported_tasks:
-            pooling_task = "token_classify"
-        elif "plugin" in self.supported_tasks:
-            pooling_task = "plugin"
+        if request.task is None:
+            if "token_embed" in self.supported_tasks:
+                pooling_task = "token_embed"
+            elif "token_classify" in self.supported_tasks:
+                pooling_task = "token_classify"
+            elif "plugin" in self.supported_tasks:
+                pooling_task = "plugin"
+            else:
+                return self.create_error_response(
+                    f"pooling_task must be one of {self.supported_tasks}."
+                )
         else:
+            pooling_task = request.task
+
+        if pooling_task not in self.supported_tasks:
             return self.create_error_response(
-                f"pooling_task must be one of {self.supported_tasks}."
+                f"Task {pooling_task} is not supported, it"
+                f" must be one of {self.supported_tasks}."
             )

         try:
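The `/pooling` handler now honors an explicit `task` on the request before falling back to the old priority order (`token_embed`, then `token_classify`, then `plugin`), and the final membership check covers both paths. A hypothetical client pinning the task — the payload shape is inferred from this handler, and the model name is illustrative:

```python
import requests

resp = requests.post(
    "http://localhost:8000/pooling",
    json={
        "model": "my-token-classifier",
        "input": "tag every token in this sentence",
        "task": "token_classify",  # bypasses the fallback order entirely
    },
)
print(resp.json())
```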
@@ -607,7 +607,7 @@ class ClassifierPooler(Pooler):
             pooled_data -= self.logit_bias

         pooling_params = get_pooling_params(pooling_metadata)
-        flags = [p.activation for p in pooling_params]
+        flags = [p.use_activation for p in pooling_params]

         if len(set(flags)) == 1:
             scores = self.act_fn(pooled_data) if flags[0] else pooled_data
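Because `use_activation` rides on each request's `PoolingParams`, one batch can mix requests that want activated scores with requests that want raw logits; the uniform-batch fast path above applies the activation once. A standalone sketch of both paths — the mixed-batch fallback is an assumption about the code that follows this hunk, not a quote of it:

```python
import torch


def apply_activation(
    pooled_data: torch.Tensor, flags: list[bool], act_fn
) -> torch.Tensor:
    # Uniform batch: one activation call (or none) for everyone.
    if len(set(flags)) == 1:
        return act_fn(pooled_data) if flags[0] else pooled_data
    # Mixed batch: honor each request's flag row by row.
    return torch.stack(
        [act_fn(row) if flag else row for row, flag in zip(pooled_data, flags)]
    )
```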
@@ -681,7 +681,7 @@ class TokenClassifierPoolerHead(nn.Module):
         if self.logit_bias is not None:
             scores -= self.logit_bias

-        if pooling_param.activation:
+        if pooling_param.use_activation:
             scores = self.act_fn(scores)

         # scores shape: [n_token, num_labels]
@@ -53,8 +53,8 @@ class JambaForSequenceClassificationConfig(VerifyAndUpdateConfig):
     @staticmethod
     def verify_and_update_config(vllm_config: "VllmConfig") -> None:
         pooler_config = vllm_config.model_config.pooler_config
-        if pooler_config.activation is None:
-            pooler_config.activation = False
+        if pooler_config.use_activation is None:
+            pooler_config.use_activation = False


 class JinaRobertaModelConfig(VerifyAndUpdateConfig):
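The Jamba hook shows the intended pattern for models whose classifier heads already emit calibrated scores: seed a default without clobbering an explicit user choice. A hypothetical hook for another model following the same shape — `MyModelForSequenceClassificationConfig` is invented for illustration, and the import paths are assumptions:

```python
from vllm.config import VllmConfig  # import path assumed
from vllm.model_executor.models.config import VerifyAndUpdateConfig  # assumed


class MyModelForSequenceClassificationConfig(VerifyAndUpdateConfig):
    @staticmethod
    def verify_and_update_config(vllm_config: VllmConfig) -> None:
        pooler_config = vllm_config.model_config.pooler_config
        # None means "user did not choose"; only then seed the model default.
        if pooler_config.use_activation is None:
            pooler_config.use_activation = False
```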
@@ -2,16 +2,15 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 from copy import deepcopy
-from typing import TYPE_CHECKING, Annotated, Any, Optional
+from typing import Annotated, Any, Optional

 import msgspec

+from vllm.config import ModelConfig, PoolerConfig
+from vllm.config.pooler import get_use_activation
 from vllm.sampling_params import RequestOutputKind
 from vllm.tasks import PoolingTask

-if TYPE_CHECKING:
-    from vllm.config import ModelConfig, PoolerConfig


 class PoolingParams(
     msgspec.Struct,
@@ -25,10 +24,12 @@ class PoolingParams(
         Set to -1 to use the model's default truncation size.
         Set to k to keep only the last k tokens (left truncation).
         Set to None to disable truncation.
-    normalize: Whether to normalize the embeddings outputs.
     dimensions: Reduce the dimensions of embeddings
         if model support matryoshka representation.
-    activation: Whether to apply activation function to
+    normalize: Whether to normalize the embeddings outputs.
+    softmax: softmax will be deprecated, please use use_activation instead.
+    activation: activation will be deprecated, please use use_activation instead.
+    use_activation: Whether to apply activation function to
         the classification outputs.
     """
@@ -44,7 +45,9 @@ class PoolingParams(

     ## for classification, scoring and rerank
     # --8<-- [start:classification-pooling-params]
+    softmax: bool | None = None
     activation: bool | None = None
+    use_activation: bool | None = None
     # --8<-- [end:classification-pooling-params]

     ## for step pooling models
@@ -59,16 +62,16 @@ class PoolingParams(
     @property
     def all_parameters(self) -> list[str]:
-        return ["dimensions", "normalize", "activation"]
+        return ["dimensions", "normalize", "use_activation"]

     @property
     def valid_parameters(self):
         return {
             "embed": ["dimensions", "normalize"],
-            "classify": ["activation"],
-            "score": ["activation"],
+            "classify": ["use_activation"],
+            "score": ["use_activation"],
             "token_embed": ["dimensions", "normalize"],
-            "token_classify": ["activation"],
+            "token_classify": ["use_activation"],
         }

     def clone(self) -> "PoolingParams":
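`all_parameters` and `valid_parameters` are what downstream validation reads, so renaming the entries here is what actually retires `activation` from the per-task allow-lists. Constructing params directly shows the split — a sketch only; whether an invalid combination raises or is silently ignored depends on validation code outside this hunk:

```python
from vllm.pooling_params import PoolingParams

# Classification-style tasks ("classify", "score", "token_classify")
# accept use_activation...
clf_params = PoolingParams(use_activation=False)

# ...while embedding-style tasks ("embed", "token_embed") accept
# dimensions and normalize instead.
embed_params = PoolingParams(dimensions=256, normalize=True)
```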
@@ -84,6 +87,9 @@ class PoolingParams(
             msg = f"You cannot overwrite {self.task=!r} with {task=!r}!"
             raise ValueError(msg)

+        # raise deprecated warning for softmax and activation
+        self.use_activation = get_use_activation(self)
+
         # plugin task uses io_processor.parse_request to verify inputs,
         # skipping PoolingParams verify
         if self.task == "plugin":
@@ -168,8 +174,8 @@ class PoolingParams(
             raise ValueError("Dimensions must be greater than 0")

         elif self.task in ["classify", "score", "token_classify"]:
-            if self.activation is None:
-                self.activation = True
+            if self.use_activation is None:
+                self.use_activation = True
         else:
             raise ValueError(f"Unknown pooling task: {self.task}")
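After `verify`, classification-style tasks default `use_activation` to `True`, and the deprecated fields have already been folded in by the `get_use_activation` call earlier in `verify`. A quick check of that behavior — it assumes `verify` can be called with just the task, which the error message above suggests but this diff does not confirm:

```python
from vllm.pooling_params import PoolingParams

p = PoolingParams()
p.verify(task="classify")         # arguments beyond `task` assumed optional
assert p.use_activation is True   # defaulted by the branch above

q = PoolingParams(activation=False)   # deprecated spelling
q.verify(task="score")
assert q.use_activation is False      # folded onto use_activation by verify
```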
@@ -197,7 +203,7 @@ class PoolingParams(
             f"task={self.task}, "
             f"normalize={self.normalize}, "
             f"dimensions={self.dimensions}, "
-            f"activation={self.activation}, "
+            f"use_activation={self.use_activation}, "
             f"step_tag_id={self.step_tag_id}, "
             f"returned_token_ids={self.returned_token_ids}, "
             f"requires_token_ids={self.requires_token_ids}, "