[Frontend] Resettle pooling entrypoints (#29634)

Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>

parent 83805a6078
commit 62de4f4257
.github/CODEOWNERS (vendored, 1 line added)

@@ -149,6 +149,7 @@ mkdocs.yaml @hmellor
 /examples/*/pooling/ @noooop
 /tests/models/*/pooling* @noooop
 /tests/entrypoints/pooling @noooop
+/vllm/entrypoints/pooling @aarnphm @chaunceyjiang @noooop
 /vllm/config/pooler.py @noooop
 /vllm/pooling_params.py @noooop
 /vllm/model_executor/layers/pooler.py @noooop
@@ -77,7 +77,7 @@ The `parse_request` method is used for validating the user prompt and converting
 The `pre_process*` methods take the validated plugin input to generate vLLM's model prompts for regular inference.
 The `post_process*` methods take `PoolingRequestOutput` objects as input and generate a custom plugin output.
 The `validate_or_generate_params` method is used for validating with the plugin any `SamplingParameters`/`PoolingParameters` received with the user request, or to generate new ones if none are specified. The function always returns the validated/generated parameters.
-The `output_to_response` method is used only for online serving and converts the plugin output to the `IOProcessorResponse` type that is then returned by the API Server. The implementation of the `/pooling` serving endpoint is available here [vllm/entrypoints/openai/serving_pooling.py](../../vllm/entrypoints/openai/serving_pooling.py).
+The `output_to_response` method is used only for online serving and converts the plugin output to the `IOProcessorResponse` type that is then returned by the API Server. The implementation of the `/pooling` serving endpoint is available here [vllm/entrypoints/openai/serving_pooling.py](../../vllm/entrypoints/pooling/pooling/serving.py).
 An example implementation of a plugin that enables generating geotiff images with the PrithviGeospatialMAE model is available [here](https://github.com/IBM/terratorch/tree/main/terratorch/vllm/plugins/segmentation). Please, also refer to our online ([examples/online_serving/pooling/prithvi_geospatial_mae.py](../../examples/online_serving/pooling/prithvi_geospatial_mae.py)) and offline ([examples/offline_inference/pooling/prithvi_geospatial_mae_io_processor.py](../../examples/offline_inference/pooling/prithvi_geospatial_mae_io_processor.py)) inference examples.
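Editor's note: to make the method roles described in the hunk above concrete, here is a minimal, illustrative plugin skeleton. It is not the real base-class API; the method names follow the prose, while the argument lists, types, and return shapes are assumptions rather than anything taken from this commit.

```python
# Hypothetical IO processor plugin sketch; names and signatures are
# illustrative assumptions, not vLLM's actual plugin interface.
from typing import Any

from vllm.entrypoints.pooling.pooling.protocol import IOProcessorResponse  # new path per this commit
from vllm.outputs import PoolingRequestOutput
from vllm.pooling_params import PoolingParams


class MyGeoPlugin:  # would subclass the real plugin base class
    def parse_request(self, request: Any) -> dict:
        # Validate the user prompt and convert it into the plugin's input type.
        return {"tile": request["data"]}

    def pre_process(self, plugin_input: dict) -> list:
        # Turn the validated plugin input into regular vLLM model prompts.
        return [plugin_input["tile"]]

    def post_process(self, outputs: list[PoolingRequestOutput]) -> Any:
        # Convert PoolingRequestOutput objects into the custom plugin output.
        return {"predictions": outputs}

    def validate_or_generate_params(self, params: PoolingParams | None) -> PoolingParams:
        # Validate caller-supplied params, or generate defaults if none were given.
        return params or PoolingParams()

    def output_to_response(self, plugin_output: Any) -> IOProcessorResponse:
        # Online serving only: wrap the plugin output for the API server.
        return IOProcessorResponse(data=plugin_output)
```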
@@ -351,7 +351,7 @@ The following extra parameters are supported by default:
 ??? code

     ```python
-    --8<-- "vllm/entrypoints/openai/protocol.py:embedding-extra-params"
+    --8<-- "vllm/entrypoints/pooling/embed/protocol.py:embedding-extra-params"
     ```

 For chat-like input (i.e. if `messages` is passed), these extra parameters are supported instead:
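Editor's note: as a usage sketch (not part of this commit), the embedding extra parameters referenced above can be passed through the OpenAI client's `extra_body`. The parameter names come from the embedding request model shown later in this diff; the base URL and model name are placeholders.

```python
# Illustrative call against a running vLLM server; base_url/model are placeholders.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

resp = client.embeddings.create(
    model="intfloat/e5-small",            # placeholder model name
    input=["vLLM pooling entrypoints"],
    extra_body={
        "add_special_tokens": True,       # vLLM-specific extra parameter
        "normalize": True,                # whether to normalize the embedding
        "priority": 0,                    # scheduling priority (0 = default)
    },
)
print(len(resp.data[0].embedding))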
@@ -359,7 +359,7 @@ For chat-like input (i.e. if `messages` is passed), these extra parameters are s
 ??? code

     ```python
-    --8<-- "vllm/entrypoints/openai/protocol.py:chat-embedding-extra-params"
+    --8<-- "vllm/entrypoints/pooling/embed/protocol.py:chat-embedding-extra-params"
     ```

 ### Transcriptions API
@@ -629,7 +629,7 @@ The following [pooling parameters][vllm.PoolingParams] are supported.
 The following extra parameters are supported:

 ```python
---8<-- "vllm/entrypoints/openai/protocol.py:classification-extra-params"
+--8<-- "vllm/entrypoints/pooling/classify/protocol.py:classification-extra-params"
 ```

 ### Score API
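Editor's note: for context only, a classification request using the extra parameters above might look like the following sketch. The server URL is a placeholder, the model is the classifier already used by the tests in this diff, and the route is assumed to stay at `/classify` after the move.

```python
# Illustrative /classify request; URL is a placeholder.
import requests

payload = {
    "model": "jason9693/Qwen2.5-1.5B-apeach",  # classifier used in the tests in this PR
    "input": ["This movie was great!"],
    "use_activation": True,   # extra parameter: apply activation to the outputs
    "priority": 0,            # extra parameter: scheduling priority
}
resp = requests.post("http://localhost:8000/classify", json=payload, timeout=30)
print(resp.json()["data"][0]["probs"])
```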
@@ -834,7 +834,7 @@ The following [pooling parameters][vllm.PoolingParams] are supported.
 The following extra parameters are supported:

 ```python
---8<-- "vllm/entrypoints/openai/protocol.py:score-extra-params"
+--8<-- "vllm/entrypoints/pooling/score/protocol.py:score-extra-params"
 ```

 ### Re-rank API
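Editor's note: an illustrative sketch (not from this commit) of the Score API, which takes two text sides plus the extra parameters above. The URL and model name are placeholders, and the route is assumed to remain `/score`.

```python
# Illustrative /score request; URL and model name are placeholders.
import requests

payload = {
    "model": "BAAI/bge-reranker-base",        # placeholder cross-encoder model
    "text_1": "What is the capital of France?",
    "text_2": ["Paris is the capital of France.", "The moon orbits the Earth."],
    "use_activation": True,                    # extra parameter (defaults to True)
}
resp = requests.post("http://localhost:8000/score", json=payload, timeout=30)
for item in resp.json()["data"]:
    print(item["index"], item["score"])
```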
@@ -915,7 +915,7 @@ The following [pooling parameters][vllm.PoolingParams] are supported.
 The following extra parameters are supported:

 ```python
---8<-- "vllm/entrypoints/openai/protocol.py:rerank-extra-params"
+--8<-- "vllm/entrypoints/pooling/score/protocol.py:rerank-extra-params"
 ```

 ## Ray Serve LLM
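Editor's note: one more usage sketch, again not part of this commit. The Re-rank API, whose extra parameters are listed in the hunk above, follows the JinaAI-style `/rerank` route; URL and model name are placeholders.

```python
# Illustrative /rerank request; URL and model name are placeholders.
import requests

payload = {
    "model": "BAAI/bge-reranker-base",           # placeholder reranker model
    "query": "vLLM pooling entrypoints",
    "documents": [
        "Pooling endpoints were moved under vllm/entrypoints/pooling.",
        "Unrelated text about cooking pasta.",
    ],
    "top_n": 1,
    "use_activation": True,                       # extra parameter (defaults to True)
}
resp = requests.post("http://localhost:8000/rerank", json=payload, timeout=30)
for r in resp.json()["results"]:
    print(r["relevance_score"], r["document"]["text"])
```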
@@ -7,7 +7,7 @@ import tempfile
 import pytest

-from vllm.entrypoints.openai.protocol import BatchRequestOutput
+from vllm.entrypoints.openai.run_batch import BatchRequestOutput

 MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM"

@@ -7,7 +7,8 @@ import torch
 import torch.nn.functional as F

 from tests.utils import RemoteOpenAIServer
-from vllm.entrypoints.openai.protocol import ClassificationResponse, PoolingResponse
+from vllm.entrypoints.pooling.classify.protocol import ClassificationResponse
+from vllm.entrypoints.pooling.pooling.protocol import PoolingResponse

 MODEL_NAME = "jason9693/Qwen2.5-1.5B-apeach"
 DTYPE = "float32"  # Use float32 to avoid NaN issue

@@ -7,7 +7,7 @@ import pytest
 import requests

 from tests.utils import RemoteOpenAIServer
-from vllm.entrypoints.openai.protocol import ClassificationResponse
+from vllm.entrypoints.pooling.classify.protocol import ClassificationResponse

 VLM_MODEL_NAME = "muziyongshixin/Qwen2.5-VL-7B-for-VideoCls"
 MAXIMUM_VIDEOS = 1
@@ -15,10 +15,8 @@ import torch.nn.functional as F
 from tests.models.language.pooling.embed_utils import run_embedding_correctness_test
 from tests.models.utils import check_embeddings_close
 from tests.utils import RemoteOpenAIServer
-from vllm.entrypoints.openai.protocol import (
-    EmbeddingResponse,
-    PoolingResponse,
-)
+from vllm.entrypoints.pooling.embed.protocol import EmbeddingResponse
+from vllm.entrypoints.pooling.pooling.protocol import PoolingResponse
 from vllm.platforms import current_platform
 from vllm.transformers_utils.tokenizer import get_tokenizer
 from vllm.utils.serial_utils import (

@@ -11,7 +11,7 @@ from tests.conftest import HfRunner
 from tests.models.language.pooling.embed_utils import run_embedding_correctness_test
 from tests.models.utils import EmbedModelInfo
 from tests.utils import RemoteOpenAIServer
-from vllm.entrypoints.openai.protocol import EmbeddingResponse
+from vllm.entrypoints.pooling.embed.protocol import EmbeddingResponse
 from vllm.platforms import current_platform

 if current_platform.is_rocm():

@@ -15,7 +15,7 @@ import pytest
 import pytest_asyncio

 from tests.utils import RemoteOpenAIServer
-from vllm.entrypoints.openai.protocol import EmbeddingResponse
+from vllm.entrypoints.pooling.embed.protocol import EmbeddingResponse
 from vllm.platforms import current_platform

 if current_platform.is_rocm():

@@ -8,7 +8,7 @@ import requests
 from transformers import AutoProcessor

 from tests.utils import VLLM_PATH, RemoteOpenAIServer
-from vllm.entrypoints.openai.protocol import EmbeddingResponse
+from vllm.entrypoints.pooling.embed.protocol import EmbeddingResponse
 from vllm.multimodal.utils import encode_image_base64, fetch_image

 MODEL_NAME = "TIGER-Lab/VLM2Vec-Full"
@@ -11,7 +11,7 @@ import torch
 from tests.models.utils import check_embeddings_close
 from tests.utils import RemoteOpenAIServer
-from vllm.entrypoints.openai.protocol import PoolingResponse
+from vllm.entrypoints.pooling.pooling.protocol import PoolingResponse
 from vllm.transformers_utils.tokenizer import get_tokenizer
 from vllm.utils.serial_utils import (
     EMBED_DTYPE_TO_TORCH_DTYPE,

@@ -7,7 +7,8 @@ import torch
 import torch.nn.functional as F

 from tests.utils import RemoteOpenAIServer
-from vllm.entrypoints.openai.protocol import PoolingResponse, RerankResponse
+from vllm.entrypoints.pooling.pooling.protocol import PoolingResponse
+from vllm.entrypoints.pooling.score.protocol import RerankResponse
 from vllm.platforms import current_platform

 if current_platform.is_rocm():

@@ -9,7 +9,7 @@ import torch.nn.functional as F
 from torch import tensor

 from tests.utils import RemoteOpenAIServer
-from vllm.entrypoints.openai.protocol import ScoreResponse
+from vllm.entrypoints.pooling.score.protocol import ScoreResponse
 from vllm.platforms import current_platform

 if current_platform.is_rocm():
@@ -18,7 +18,10 @@ from einops import rearrange
 from terratorch.datamodules import Sen1Floods11NonGeoDataModule

 from vllm.config import VllmConfig
-from vllm.entrypoints.openai.protocol import IOProcessorRequest, IOProcessorResponse
+from vllm.entrypoints.pooling.pooling.protocol import (
+    IOProcessorRequest,
+    IOProcessorResponse,
+)
 from vllm.inputs.data import PromptType
 from vllm.logger import init_logger
 from vllm.outputs import PoolingRequestOutput

@@ -7,7 +7,7 @@ import requests
 from tests.utils import RemoteOpenAIServer
 from vllm.config import VllmConfig
-from vllm.entrypoints.openai.protocol import IOProcessorResponse
+from vllm.entrypoints.pooling.pooling.protocol import IOProcessorResponse
 from vllm.plugins.io_processors import get_io_processor

 MODEL_NAME = "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11"
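Editor's note: taken together, the test updates above exercise the new module layout. The post-move import locations, exactly as they appear in this diff, are:

```python
# New import locations introduced by this commit (as used in the updated tests):
from vllm.entrypoints.pooling.classify.protocol import ClassificationResponse
from vllm.entrypoints.pooling.embed.protocol import EmbeddingResponse
from vllm.entrypoints.pooling.pooling.protocol import (
    IOProcessorRequest,
    IOProcessorResponse,
    PoolingResponse,
)
from vllm.entrypoints.pooling.score.protocol import RerankResponse, ScoreResponse

# BatchRequestOutput now comes from the run_batch module rather than protocol:
from vllm.entrypoints.openai.run_batch import BatchRequestOutput
```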
@@ -14,7 +14,7 @@ import socket
 import tempfile
 import uuid
 from argparse import Namespace
-from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Callable
+from collections.abc import AsyncGenerator, AsyncIterator, Awaitable
 from contextlib import asynccontextmanager
 from http import HTTPStatus
 from typing import Annotated, Any, Literal
@@ -54,29 +54,16 @@ from vllm.entrypoints.openai.orca_metrics import metrics_header
 from vllm.entrypoints.openai.protocol import (
     ChatCompletionRequest,
     ChatCompletionResponse,
-    ClassificationRequest,
-    ClassificationResponse,
     CompletionRequest,
     CompletionResponse,
     DetokenizeRequest,
     DetokenizeResponse,
-    EmbeddingBytesResponse,
-    EmbeddingRequest,
-    EmbeddingResponse,
     ErrorInfo,
     ErrorResponse,
     GenerateRequest,
     GenerateResponse,
-    IOProcessorResponse,
-    PoolingBytesResponse,
-    PoolingRequest,
-    PoolingResponse,
-    RerankRequest,
-    RerankResponse,
     ResponsesRequest,
     ResponsesResponse,
-    ScoreRequest,
-    ScoreResponse,
     StreamingResponsesResponse,
     TokenizeRequest,
     TokenizeResponse,
@@ -86,17 +73,13 @@ from vllm.entrypoints.openai.protocol import (
     TranslationResponse,
 )
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
-from vllm.entrypoints.openai.serving_classification import ServingClassification
 from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
-from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
 from vllm.entrypoints.openai.serving_engine import OpenAIServing
 from vllm.entrypoints.openai.serving_models import (
     BaseModelPath,
     OpenAIServingModels,
 )
-from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling
 from vllm.entrypoints.openai.serving_responses import OpenAIServingResponses
-from vllm.entrypoints.openai.serving_score import ServingScores
 from vllm.entrypoints.openai.serving_tokenization import OpenAIServingTokenization
 from vllm.entrypoints.openai.serving_tokens import ServingTokens
 from vllm.entrypoints.openai.serving_transcription import (

@@ -104,6 +87,11 @@ from vllm.entrypoints.openai.serving_transcription import (
     OpenAIServingTranslation,
 )
 from vllm.entrypoints.openai.tool_parsers import ToolParserManager
+from vllm.entrypoints.openai.utils import validate_json_request
+from vllm.entrypoints.pooling.classify.serving import ServingClassification
+from vllm.entrypoints.pooling.embed.serving import OpenAIServingEmbedding
+from vllm.entrypoints.pooling.pooling.serving import OpenAIServingPooling
+from vllm.entrypoints.pooling.score.serving import ServingScores
 from vllm.entrypoints.tool_server import DemoToolServer, MCPToolServer, ToolServer
 from vllm.entrypoints.utils import (
     cli_env_setup,
@@ -254,15 +242,6 @@ async def build_async_engine_client_from_engine_args(
         async_llm.shutdown()


-async def validate_json_request(raw_request: Request):
-    content_type = raw_request.headers.get("content-type", "").lower()
-    media_type = content_type.split(";", maxsplit=1)[0]
-    if media_type != "application/json":
-        raise RequestValidationError(
-            errors=["Unsupported Media Type: Only 'application/json' is allowed"]
-        )
-
-
 router = APIRouter()

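Editor's note: the removed helper is not gone; it is now imported from vllm.entrypoints.openai.utils (see the added import earlier in this file). A minimal usage sketch with a hypothetical route name:

```python
# Hypothetical route showing the dependency wiring; only the new import path
# and the Depends(validate_json_request) pattern are taken from this commit.
from fastapi import APIRouter, Depends

from vllm.entrypoints.openai.utils import validate_json_request

example_router = APIRouter()


@example_router.post("/example", dependencies=[Depends(validate_json_request)])
async def example_endpoint() -> dict[str, bool]:
    # Requests whose Content-Type is not application/json are rejected
    # by the dependency before this handler runs.
    return {"ok": True}
```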
@@ -324,26 +303,6 @@ def completion(request: Request) -> OpenAIServingCompletion | None:
     return request.app.state.openai_serving_completion


-def pooling(request: Request) -> OpenAIServingPooling | None:
-    return request.app.state.openai_serving_pooling
-
-
-def embedding(request: Request) -> OpenAIServingEmbedding | None:
-    return request.app.state.openai_serving_embedding
-
-
-def score(request: Request) -> ServingScores | None:
-    return request.app.state.openai_serving_scores
-
-
-def classify(request: Request) -> ServingClassification | None:
-    return request.app.state.openai_serving_classification
-
-
-def rerank(request: Request) -> ServingScores | None:
-    return request.app.state.openai_serving_scores
-
-
 def tokenization(request: Request) -> OpenAIServingTokenization:
     return request.app.state.openai_serving_tokenization

@@ -817,166 +776,6 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
     return StreamingResponse(content=generator, media_type="text/event-stream")


-@router.post(
-    "/v1/embeddings",
-    dependencies=[Depends(validate_json_request)],
-    responses={
-        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
-        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
-    },
-)
-@with_cancellation
-@load_aware_call
-async def create_embedding(
-    request: EmbeddingRequest,
-    raw_request: Request,
-):
-    handler = embedding(raw_request)
-    if handler is None:
-        return base(raw_request).create_error_response(
-            message="The model does not support Embeddings API"
-        )
-
-    try:
-        generator = await handler.create_embedding(request, raw_request)
-    except Exception as e:
-        raise HTTPException(
-            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
-        ) from e
-
-    if isinstance(generator, ErrorResponse):
-        return JSONResponse(
-            content=generator.model_dump(), status_code=generator.error.code
-        )
-    elif isinstance(generator, EmbeddingResponse):
-        return JSONResponse(content=generator.model_dump())
-    elif isinstance(generator, EmbeddingBytesResponse):
-        return StreamingResponse(
-            content=generator.body,
-            headers={"metadata": generator.metadata},
-            media_type=generator.media_type,
-        )
-
-    assert_never(generator)
-
-
-@router.post(
-    "/pooling",
-    dependencies=[Depends(validate_json_request)],
-    responses={
-        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
-        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
-    },
-)
-@with_cancellation
-@load_aware_call
-async def create_pooling(request: PoolingRequest, raw_request: Request):
-    handler = pooling(raw_request)
-    if handler is None:
-        return base(raw_request).create_error_response(
-            message="The model does not support Pooling API"
-        )
-    try:
-        generator = await handler.create_pooling(request, raw_request)
-    except Exception as e:
-        raise HTTPException(
-            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
-        ) from e
-    if isinstance(generator, ErrorResponse):
-        return JSONResponse(
-            content=generator.model_dump(), status_code=generator.error.code
-        )
-    elif isinstance(generator, (PoolingResponse, IOProcessorResponse)):
-        return JSONResponse(content=generator.model_dump())
-    elif isinstance(generator, PoolingBytesResponse):
-        return StreamingResponse(
-            content=generator.body,
-            headers={"metadata": generator.metadata},
-            media_type=generator.media_type,
-        )
-
-    assert_never(generator)
-
-
-@router.post("/classify", dependencies=[Depends(validate_json_request)])
-@with_cancellation
-@load_aware_call
-async def create_classify(request: ClassificationRequest, raw_request: Request):
-    handler = classify(raw_request)
-    if handler is None:
-        return base(raw_request).create_error_response(
-            message="The model does not support Classification API"
-        )
-
-    try:
-        generator = await handler.create_classify(request, raw_request)
-    except Exception as e:
-        raise HTTPException(
-            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
-        ) from e
-    if isinstance(generator, ErrorResponse):
-        return JSONResponse(
-            content=generator.model_dump(), status_code=generator.error.code
-        )
-
-    elif isinstance(generator, ClassificationResponse):
-        return JSONResponse(content=generator.model_dump())
-
-    assert_never(generator)
-
-
-@router.post(
-    "/score",
-    dependencies=[Depends(validate_json_request)],
-    responses={
-        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
-        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
-    },
-)
-@with_cancellation
-@load_aware_call
-async def create_score(request: ScoreRequest, raw_request: Request):
-    handler = score(raw_request)
-    if handler is None:
-        return base(raw_request).create_error_response(
-            message="The model does not support Score API"
-        )
-
-    try:
-        generator = await handler.create_score(request, raw_request)
-    except Exception as e:
-        raise HTTPException(
-            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
-        ) from e
-    if isinstance(generator, ErrorResponse):
-        return JSONResponse(
-            content=generator.model_dump(), status_code=generator.error.code
-        )
-    elif isinstance(generator, ScoreResponse):
-        return JSONResponse(content=generator.model_dump())
-
-    assert_never(generator)
-
-
-@router.post(
-    "/v1/score",
-    dependencies=[Depends(validate_json_request)],
-    responses={
-        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
-        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
-    },
-)
-@with_cancellation
-@load_aware_call
-async def create_score_v1(request: ScoreRequest, raw_request: Request):
-    logger.warning(
-        "To indicate that Score API is not part of standard OpenAI API, we "
-        "have moved it to `/score`. Please update your client accordingly."
-    )
-
-    return await create_score(request, raw_request)
-
-
 @router.post(
     "/v1/audio/transcriptions",
     responses={
@@ -1055,70 +854,6 @@ async def create_translations(
     return StreamingResponse(content=generator, media_type="text/event-stream")


-@router.post(
-    "/rerank",
-    dependencies=[Depends(validate_json_request)],
-    responses={
-        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
-        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
-    },
-)
-@with_cancellation
-@load_aware_call
-async def do_rerank(request: RerankRequest, raw_request: Request):
-    handler = rerank(raw_request)
-    if handler is None:
-        return base(raw_request).create_error_response(
-            message="The model does not support Rerank (Score) API"
-        )
-    try:
-        generator = await handler.do_rerank(request, raw_request)
-    except Exception as e:
-        raise HTTPException(
-            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
-        ) from e
-    if isinstance(generator, ErrorResponse):
-        return JSONResponse(
-            content=generator.model_dump(), status_code=generator.error.code
-        )
-    elif isinstance(generator, RerankResponse):
-        return JSONResponse(content=generator.model_dump())
-
-    assert_never(generator)
-
-
-@router.post(
-    "/v1/rerank",
-    dependencies=[Depends(validate_json_request)],
-    responses={
-        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
-        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
-    },
-)
-@with_cancellation
-async def do_rerank_v1(request: RerankRequest, raw_request: Request):
-    logger.warning_once(
-        "To indicate that the rerank API is not part of the standard OpenAI"
-        " API, we have located it at `/rerank`. Please update your client "
-        "accordingly. (Note: Conforms to JinaAI rerank API)"
-    )
-
-    return await do_rerank(request, raw_request)
-
-
-@router.post(
-    "/v2/rerank",
-    dependencies=[Depends(validate_json_request)],
-    responses={
-        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
-        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
-    },
-)
-@with_cancellation
-async def do_rerank_v2(request: RerankRequest, raw_request: Request):
-    return await do_rerank(request, raw_request)
-
-
 if envs.VLLM_SERVER_DEV_MODE:
     logger.warning(
         "SECURITY WARNING: Development endpoints are enabled! "
@@ -1285,30 +1020,6 @@ async def is_scaling_elastic_ep(raw_request: Request):
     return JSONResponse({"is_scaling_elastic_ep": _scaling_elastic_ep})


-# TODO: RequestType = TypeForm[BaseModel] when recognized by type checkers
-# (requires typing_extensions >= 4.13)
-RequestType = Any
-GetHandlerFn = Callable[[Request], OpenAIServing | None]
-EndpointFn = Callable[[RequestType, Request], Awaitable[Any]]
-
-# NOTE: Items defined earlier take higher priority
-INVOCATION_TYPES: list[tuple[RequestType, tuple[GetHandlerFn, EndpointFn]]] = [
-    (ChatCompletionRequest, (chat, create_chat_completion)),
-    (CompletionRequest, (completion, create_completion)),
-    (EmbeddingRequest, (embedding, create_embedding)),
-    (ClassificationRequest, (classify, create_classify)),
-    (ScoreRequest, (score, create_score)),
-    (RerankRequest, (rerank, do_rerank)),
-    (PoolingRequest, (pooling, create_pooling)),
-]
-
-# NOTE: Construct the TypeAdapters only once
-INVOCATION_VALIDATORS = [
-    (pydantic.TypeAdapter(request_type), (get_handler, endpoint))
-    for request_type, (get_handler, endpoint) in INVOCATION_TYPES
-]
-
-
 @router.post(
     "/inference/v1/generate",
     dependencies=[Depends(validate_json_request)],
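Editor's note: for readers unfamiliar with the removed dispatch table, each entry pairs a request model with its handler lookup and endpoint, and a pydantic.TypeAdapter is built once per model so an incoming JSON body can be validated against each candidate type in priority order. A minimal sketch of that pattern, with stand-in models rather than vLLM's actual types:

```python
# Minimal illustration of TypeAdapter-based request dispatch; the request
# models and handlers here are stand-ins, not vLLM's actual types.
import pydantic


class EchoRequest(pydantic.BaseModel):
    text: str


class SumRequest(pydantic.BaseModel):
    numbers: list[int]


VALIDATORS = [
    (pydantic.TypeAdapter(EchoRequest), lambda r: r.text),
    (pydantic.TypeAdapter(SumRequest), lambda r: sum(r.numbers)),
]


def dispatch(body: dict):
    # Try each registered request type in order; the first one that validates wins.
    for adapter, endpoint in VALIDATORS:
        try:
            request = adapter.validate_python(body)
        except pydantic.ValidationError:
            continue
        return endpoint(request)
    raise ValueError("Request body did not match any registered request type")


print(dispatch({"text": "hello"}))        # -> "hello"
print(dispatch({"numbers": [1, 2, 3]}))   # -> 6
```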
@@ -1653,12 +1364,16 @@ def build_app(args: Namespace) -> FastAPI:
         from vllm.entrypoints.sagemaker.routes import register_sagemaker_routes

         register_sagemaker_routes(router)

     app.include_router(router)

     app.root_path = args.root_path

     mount_metrics(app)

+    from vllm.entrypoints.pooling import register_pooling_api_routers
+
+    register_pooling_api_routers(app)
+
     app.add_middleware(
         CORSMiddleware,
         allow_origins=args.allowed_origins,
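Editor's note: the new register_pooling_api_routers(app) call is now the only hook the OpenAI API server needs for the pooling endpoints. Its implementation is not shown in this excerpt; the following is only a plausible shape for such a helper, illustrating the FastAPI pattern. The sub-router names and layout are assumptions; only the entry point register_pooling_api_routers(app) appears in this commit.

```python
# Hypothetical sketch of a router-registration helper; sub-router names are
# assumptions, not vLLM's actual module layout.
from fastapi import APIRouter, FastAPI

classify_router = APIRouter()   # e.g. /classify
embed_router = APIRouter()      # e.g. /v1/embeddings
pooling_router = APIRouter()    # e.g. /pooling
score_router = APIRouter()      # e.g. /score, /rerank, /v1/rerank, /v2/rerank


def register_pooling_api_routers(app: FastAPI) -> None:
    # Attach each pooling-related sub-router to the main application.
    for router in (classify_router, embed_router, pooling_router, score_router):
        app.include_router(router)
```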
@@ -6,7 +6,7 @@
 import json
 import time
 from http import HTTPStatus
-from typing import Annotated, Any, ClassVar, Generic, Literal, TypeAlias, TypeVar
+from typing import Annotated, Any, ClassVar, Literal, TypeAlias

 import regex as re
 import torch
@@ -48,14 +48,6 @@ from openai.types.responses.response_reasoning_item import (
 )
 from openai_harmony import Message as OpenAIHarmonyMessage

-from vllm.config.pooler import get_use_activation
-from vllm.tasks import PoolingTask
-from vllm.utils.serial_utils import (
-    EmbedDType,
-    EncodingFormat,
-    Endianness,
-)
-
 # Backward compatibility for OpenAI client versions
 try:  # For older openai versions (< 1.100.0)
     from openai.types.responses import ResponseTextConfig
@@ -70,19 +62,14 @@ from pydantic import (
     BaseModel,
     ConfigDict,
     Field,
-    TypeAdapter,
     ValidationError,
-    ValidationInfo,
     field_serializer,
-    field_validator,
     model_validator,
 )

 from vllm.entrypoints.chat_utils import ChatCompletionMessageParam, make_tool_call_id
-from vllm.entrypoints.score_utils import ScoreContentPartParam, ScoreMultiModalParam
 from vllm.logger import init_logger
 from vllm.logprobs import Logprob
-from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import (
     BeamSearchParams,
     RequestOutputKind,
@@ -1345,401 +1332,6 @@ class CompletionRequest(OpenAIBaseModel):
         return data


-class EmbeddingCompletionRequest(OpenAIBaseModel):
-    # Ordered by official OpenAI API documentation
-    # https://platform.openai.com/docs/api-reference/embeddings
-    model: str | None = None
-    input: list[int] | list[list[int]] | str | list[str]
-    encoding_format: EncodingFormat = "float"
-    dimensions: int | None = None
-    user: str | None = None
-    truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
-
-    # --8<-- [start:embedding-extra-params]
-    add_special_tokens: bool = Field(default=True, description="If true (the default), special tokens (e.g. BOS) will be added to the prompt.")
-    priority: int = Field(default=0, description="The priority of the request (lower means earlier handling; default: 0). Any priority other than 0 will raise an error if the served model does not use priority scheduling.")
-    request_id: str = Field(default_factory=random_uuid, description="The request_id related to this request. If the caller does not set it, a random_uuid will be generated. This id is used through out the inference process and return in response.")
-    normalize: bool | None = Field(default=None, description="Whether to normalize the embeddings outputs. Default is True.")
-    embed_dtype: EmbedDType = Field(default="float32", description="What dtype to use for encoding. Default to using float32 for base64 encoding to match the OpenAI python client behavior. This parameter will affect base64 and binary_response.")
-    endianness: Endianness = Field(default="native", description="What endianness to use for encoding. Default to using native for base64 encoding to match the OpenAI python client behavior. This parameter will affect base64 and binary_response.")
-    # --8<-- [end:embedding-extra-params]
-
-    def to_pooling_params(self):
-        return PoolingParams(
-            truncate_prompt_tokens=self.truncate_prompt_tokens,
-            dimensions=self.dimensions,
-            normalize=self.normalize,
-        )
-
-
-class EmbeddingChatRequest(OpenAIBaseModel):
-    model: str | None = None
-    messages: list[ChatCompletionMessageParam]
-
-    encoding_format: EncodingFormat = "float"
-    dimensions: int | None = None
-    user: str | None = None
-    truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
-
-    # --8<-- [start:chat-embedding-extra-params]
-    add_generation_prompt: bool = Field(default=False, description="If true, the generation prompt will be added to the chat template. This is a parameter used by chat template in tokenizer config of the model.")
-    add_special_tokens: bool = Field(default=False, description="If true, special tokens (e.g. BOS) will be added to the prompt on top of what is added by the chat template. For most models, the chat template takes care of adding the special tokens so this should be set to false (as is the default).")
-    chat_template: str | None = Field(default=None, description="A Jinja template to use for this conversion. As of transformers v4.44, default chat template is no longer allowed, so you must provide a chat template if the tokenizer does not define one.")
-    chat_template_kwargs: dict[str, Any] | None = Field(default=None, description="Additional keyword args to pass to the template renderer. Will be accessible by the chat template.")
-    mm_processor_kwargs: dict[str, Any] | None = Field(default=None, description="Additional kwargs to pass to the HF processor.")
-    priority: int = Field(default=0, description="The priority of the request (lower means earlier handling; default: 0). Any priority other than 0 will raise an error if the served model does not use priority scheduling.")
-    request_id: str = Field(default_factory=random_uuid, description="The request_id related to this request. If the caller does not set it, a random_uuid will be generated. This id is used through out the inference process and return in response.")
-    normalize: bool | None = Field(default=None, description="Whether to normalize the embeddings outputs. Default is True.")
-    embed_dtype: EmbedDType = Field(default="float32", description="What dtype to use for encoding. Default to using float32 for base64 encoding to match the OpenAI python client behavior. This parameter will affect base64 and binary_response.")
-    endianness: Endianness = Field(default="native", description="What endianness to use for encoding. Default to using native for base64 encoding to match the OpenAI python client behavior. This parameter will affect base64 and binary_response.")
-    # --8<-- [end:chat-embedding-extra-params]
-
-    @model_validator(mode="before")
-    @classmethod
-    def check_generation_prompt(cls, data):
-        if data.get("continue_final_message") and data.get("add_generation_prompt"):
-            raise ValueError(
-                "Cannot set both `continue_final_message` and "
-                "`add_generation_prompt` to True."
-            )
-        return data
-
-    def to_pooling_params(self):
-        return PoolingParams(
-            truncate_prompt_tokens=self.truncate_prompt_tokens,
-            dimensions=self.dimensions,
-            normalize=self.normalize,
-        )
-
-
-EmbeddingRequest: TypeAlias = EmbeddingCompletionRequest | EmbeddingChatRequest
-
-
-class PoolingCompletionRequest(EmbeddingCompletionRequest):
-    task: PoolingTask | None = None
-    softmax: bool | None = Field(default=None, description="softmax will be deprecated, please use use_activation instead.")
-    activation: bool | None = Field(default=None, description="activation will be deprecated, please use use_activation instead.")
-    use_activation: bool | None = Field(default=None, description="Whether to use activation for classification outputs. If it is a classify or token_classify task, the default is True; for other tasks, this value should be None.")
-
-    def to_pooling_params(self):
-        return PoolingParams(
-            truncate_prompt_tokens=self.truncate_prompt_tokens,
-            dimensions=self.dimensions,
-            normalize=self.normalize,
-            use_activation=get_use_activation(self),
-        )
-
-
-class PoolingChatRequest(EmbeddingChatRequest):
-    task: PoolingTask | None = None
-    softmax: bool | None = Field(default=None, description="softmax will be deprecated, please use use_activation instead.")
-    activation: bool | None = Field(default=None, description="activation will be deprecated, please use use_activation instead.")
-    use_activation: bool | None = Field(default=None, description="Whether to use activation for classification outputs. If it is a classify or token_classify task, the default is True; for other tasks, this value should be None.")
-
-    def to_pooling_params(self):
-        return PoolingParams(
-            truncate_prompt_tokens=self.truncate_prompt_tokens,
-            dimensions=self.dimensions,
-            normalize=self.normalize,
-            use_activation=get_use_activation(self),
-        )
-
-
-T = TypeVar("T")
-
-
-class IOProcessorRequest(OpenAIBaseModel, Generic[T]):
-    model: str | None = None
-
-    priority: int = Field(default=0)
-    """
-    The priority of the request (lower means earlier handling;
-    default: 0). Any priority other than 0 will raise an error
-    if the served model does not use priority scheduling.
-    """
-    data: T
-
-    task: PoolingTask = "plugin"
-    encoding_format: EncodingFormat = "float"
-    embed_dtype: EmbedDType = Field(default="float32", description="What dtype to use for encoding. Default to using float32 for base64 encoding to match the OpenAI python client behavior. This parameter will affect base64 and binary_response.")
-    endianness: Endianness = Field(default="native", description="What endianness to use for encoding. Default to using native for base64 encoding to match the OpenAI python client behavior. This parameter will affect base64 and binary_response.")
-
-    def to_pooling_params(self):
-        return PoolingParams()
-
-
-class IOProcessorResponse(OpenAIBaseModel, Generic[T]):
-    request_id: str | None = None
-    """
-    The request_id associated with this response
-    """
-    created_at: int = Field(default_factory=lambda: int(time.time()))
-
-    data: T
-    """
-    When using IOProcessor plugins, the actual output is generated
-    by the plugin itself. Hence, we use a generic type for the response data
-    """
-
-
-PoolingRequest: TypeAlias = (
-    PoolingCompletionRequest | PoolingChatRequest | IOProcessorRequest
-)
-
-
-class ScoreRequest(OpenAIBaseModel):
-    model: str | None = None
-    text_1: list[str] | str | ScoreMultiModalParam
-    text_2: list[str] | str | ScoreMultiModalParam
-    truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
-
-    # --8<-- [start:score-extra-params]
-    mm_processor_kwargs: dict[str, Any] | None = Field(default=None, description="Additional kwargs to pass to the HF processor.")
-    priority: int = Field(default=0, description="The priority of the request (lower means earlier handling; default: 0). Any priority other than 0 will raise an error if the served model does not use priority scheduling.")
-    softmax: bool | None = Field(default=None, description="softmax will be deprecated, please use use_activation instead.")
-    activation: bool | None = Field(default=None, description="activation will be deprecated, please use use_activation instead.")
-    use_activation: bool | None = Field(default=None, description="Whether to use activation for classification outputs. Default is True.")
-    # --8<-- [end:score-extra-params]
-
-    def to_pooling_params(self):
-        return PoolingParams(
-            truncate_prompt_tokens=self.truncate_prompt_tokens,
-            use_activation=get_use_activation(self),
-        )
-
-
-class RerankRequest(OpenAIBaseModel):
-    model: str | None = None
-    query: str | ScoreMultiModalParam
-    documents: list[str] | ScoreMultiModalParam
-    top_n: int = Field(default_factory=lambda: 0)
-    truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
-
-    # --8<-- [start:rerank-extra-params]
-    mm_processor_kwargs: dict[str, Any] | None = Field(default=None, description="Additional kwargs to pass to the HF processor.")
-    priority: int = Field(default=0, description="The priority of the request (lower means earlier handling; default: 0). Any priority other than 0 will raise an error if the served model does not use priority scheduling.")
-    softmax: bool | None = Field(default=None, description="softmax will be deprecated, please use use_activation instead.")
-    activation: bool | None = Field(default=None, description="activation will be deprecated, please use use_activation instead.")
-    use_activation: bool | None = Field(default=None, description="Whether to use activation for classification outputs. Default is True.")
-    # --8<-- [end:rerank-extra-params]
-
-    def to_pooling_params(self):
-        return PoolingParams(
-            truncate_prompt_tokens=self.truncate_prompt_tokens,
-            use_activation=get_use_activation(self),
-        )
-
-
-class RerankDocument(BaseModel):
-    text: str | None = None
-    multi_modal: ScoreContentPartParam | None = None
-
-
-class RerankResult(BaseModel):
-    index: int
-    document: RerankDocument
-    relevance_score: float
-
-
-class RerankUsage(BaseModel):
-    total_tokens: int
-
-
-class RerankResponse(OpenAIBaseModel):
-    id: str
-    model: str
-    usage: RerankUsage
-    results: list[RerankResult]
-
-
 class CompletionLogProbs(OpenAIBaseModel):
     text_offset: list[int] = Field(default_factory=list)
     token_logprobs: list[float | None] = Field(default_factory=list)
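Editor's note: as orientation, the request models removed above all translate their user-facing fields into vLLM PoolingParams through to_pooling_params(). A minimal usage sketch, assuming EmbeddingCompletionRequest keeps this shape in its new home; the post-move import path is an inference from this commit's test changes, not something shown in this excerpt.

```python
# Illustrative only; the import path is inferred from the import moves in this
# commit (vllm.entrypoints.pooling.embed.protocol) and may differ.
from vllm.entrypoints.pooling.embed.protocol import EmbeddingCompletionRequest

req = EmbeddingCompletionRequest(
    input="resettle pooling entrypoints",
    dimensions=128,
    normalize=True,
)
params = req.to_pooling_params()
# PoolingParams carrying truncate_prompt_tokens=None, dimensions=128, normalize=True
print(params)
```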
@ -1809,229 +1401,6 @@ class CompletionStreamResponse(OpenAIBaseModel):
|
|||||||
usage: UsageInfo | None = Field(default=None)
|
usage: UsageInfo | None = Field(default=None)
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingResponseData(OpenAIBaseModel):
|
|
||||||
index: int
|
|
||||||
object: str = "embedding"
|
|
||||||
embedding: list[float] | str
|
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingResponse(OpenAIBaseModel):
|
|
||||||
id: str = Field(default_factory=lambda: f"embd-{random_uuid()}")
|
|
||||||
object: str = "list"
|
|
||||||
created: int = Field(default_factory=lambda: int(time.time()))
|
|
||||||
model: str
|
|
||||||
data: list[EmbeddingResponseData]
|
|
||||||
usage: UsageInfo
|
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingBytesResponse(OpenAIBaseModel):
|
|
||||||
body: list[bytes]
|
|
||||||
metadata: str
|
|
||||||
media_type: str = "application/octet-stream"
|
|
||||||
|
|
||||||
|
|
||||||
class PoolingResponseData(OpenAIBaseModel):
|
|
||||||
index: int
|
|
||||||
object: str = "pooling"
|
|
||||||
data: list[list[float]] | list[float] | str
|
|
||||||
|
|
||||||
|
|
||||||
class PoolingResponse(OpenAIBaseModel):
|
|
||||||
id: str = Field(default_factory=lambda: f"pool-{random_uuid()}")
|
|
||||||
object: str = "list"
|
|
||||||
created: int = Field(default_factory=lambda: int(time.time()))
|
|
||||||
model: str
|
|
||||||
data: list[PoolingResponseData]
|
|
||||||
usage: UsageInfo
|
|
||||||
|
|
||||||
|
|
||||||
class PoolingBytesResponse(OpenAIBaseModel):
|
|
||||||
body: list[bytes]
|
|
||||||
metadata: str
|
|
||||||
media_type: str = "application/octet-stream"
|
|
||||||
|
|
||||||
|
|
||||||
class ScoreResponseData(OpenAIBaseModel):
|
|
||||||
index: int
|
|
||||||
object: str = "score"
|
|
||||||
score: float
|
|
||||||
|
|
||||||
|
|
||||||
class ScoreResponse(OpenAIBaseModel):
|
|
||||||
id: str = Field(default_factory=lambda: f"embd-{random_uuid()}")
|
|
||||||
object: str = "list"
|
|
||||||
created: int = Field(default_factory=lambda: int(time.time()))
|
|
||||||
model: str
|
|
||||||
data: list[ScoreResponseData]
|
|
||||||
usage: UsageInfo
|
|
||||||
|
|
||||||
|
|
||||||
class ClassificationCompletionRequest(OpenAIBaseModel):
|
|
||||||
model: str | None = None
|
|
||||||
input: list[str] | str
|
|
||||||
truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
|
|
||||||
user: str | None = None
|
|
||||||
|
|
||||||
# --8<-- [start:classification-extra-params]
|
|
||||||
priority: int = Field(
|
|
||||||
default=0,
|
|
||||||
description=(
|
|
||||||
"The priority of the request (lower means earlier handling; "
|
|
||||||
"default: 0). Any priority other than 0 will raise an error "
|
|
||||||
"if the served model does not use priority scheduling."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
add_special_tokens: bool = Field(
|
|
||||||
default=True,
|
|
||||||
description=(
|
|
||||||
"If true (the default), special tokens (e.g. BOS) will be added to "
|
|
||||||
"the prompt."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
request_id: str = Field(
|
|
||||||
default_factory=random_uuid,
|
|
||||||
description=(
|
|
||||||
"The request_id related to this request. If the caller does "
|
|
||||||
"not set it, a random_uuid will be generated. This id is used "
|
|
||||||
"through out the inference process and return in response."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
softmax: bool | None = Field(
|
|
||||||
default=None,
|
|
||||||
description="softmax will be deprecated, please use use_activation instead.",
|
|
||||||
)
|
|
||||||
|
|
||||||
activation: bool | None = Field(
|
|
||||||
default=None,
|
|
||||||
description="activation will be deprecated, please use use_activation instead.",
|
|
||||||
)
|
|
||||||
|
|
||||||
use_activation: bool | None = Field(
|
|
||||||
default=None,
|
|
||||||
description="Whether to use activation for classification outputs. "
|
|
||||||
"Default is True.",
|
|
||||||
)
|
|
||||||
# --8<-- [end:classification-extra-params]
|
|
||||||
|
|
||||||
def to_pooling_params(self):
|
|
||||||
return PoolingParams(
|
|
||||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
|
||||||
use_activation=get_use_activation(self),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ClassificationChatRequest(OpenAIBaseModel):
|
|
||||||
model: str | None = None
|
|
||||||
messages: list[ChatCompletionMessageParam]
|
|
||||||
truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
|
|
||||||
user: str | None = None
|
|
||||||
|
|
||||||
# --8<-- [start:chat-classification-extra-params]
|
|
||||||
add_generation_prompt: bool = Field(
|
|
||||||
default=False,
|
|
||||||
description=(
|
|
||||||
"If true, the generation prompt will be added to the chat template. "
|
|
||||||
"This is a parameter used by chat template in tokenizer config of the "
|
|
||||||
"model."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
add_special_tokens: bool = Field(
|
|
||||||
default=False,
|
|
||||||
description=(
|
|
||||||
"If true, special tokens (e.g. BOS) will be added to the prompt "
|
|
||||||
"on top of what is added by the chat template. "
|
|
||||||
"For most models, the chat template takes care of adding the "
|
|
||||||
"special tokens so this should be set to false (as is the "
|
|
||||||
"default)."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
chat_template: str | None = Field(
|
|
||||||
default=None,
|
|
||||||
description=(
|
|
||||||
"A Jinja template to use for this conversion. "
|
|
||||||
"As of transformers v4.44, default chat template is no longer "
|
|
||||||
"allowed, so you must provide a chat template if the tokenizer "
|
|
||||||
"does not define one."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
chat_template_kwargs: dict[str, Any] | None = Field(
|
|
||||||
default=None,
|
|
||||||
description=(
|
|
||||||
"Additional keyword args to pass to the template renderer. "
|
|
||||||
"Will be accessible by the chat template."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
mm_processor_kwargs: dict[str, Any] | None = Field(
|
|
||||||
default=None,
|
|
||||||
description=("Additional kwargs to pass to the HF processor."),
|
|
||||||
)
|
|
||||||
|
|
||||||
priority: int = Field(
|
|
||||||
default=0,
|
|
||||||
description=(
|
|
||||||
"The priority of the request (lower means earlier handling; "
|
|
||||||
"default: 0). Any priority other than 0 will raise an error "
|
|
||||||
"if the served model does not use priority scheduling."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
request_id: str = Field(
|
|
||||||
default_factory=random_uuid,
|
|
||||||
description=(
|
|
||||||
"The request_id related to this request. If the caller does "
|
|
||||||
"not set it, a random_uuid will be generated. This id is used "
|
|
||||||
"through out the inference process and return in response."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
softmax: bool | None = Field(
|
|
||||||
default=None,
|
|
||||||
description="softmax will be deprecated, please use use_activation instead.",
|
|
||||||
)
|
|
||||||
|
|
||||||
activation: bool | None = Field(
|
|
||||||
default=None,
|
|
||||||
description="activation will be deprecated, please use use_activation instead.",
|
|
||||||
)
|
|
||||||
|
|
||||||
use_activation: bool | None = Field(
|
|
||||||
default=None,
|
|
||||||
description="Whether to use activation for classification outputs. "
|
|
||||||
"Default is True.",
|
|
||||||
)
|
|
||||||
# --8<-- [end:chat-classification-extra-params]
|
|
||||||
|
|
||||||
def to_pooling_params(self):
|
|
||||||
return PoolingParams(
|
|
||||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
|
||||||
use_activation=get_use_activation(self),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
ClassificationRequest: TypeAlias = (
|
|
||||||
ClassificationCompletionRequest | ClassificationChatRequest
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ClassificationData(OpenAIBaseModel):
|
|
||||||
index: int
|
|
||||||
label: str | None
|
|
||||||
probs: list[float]
|
|
||||||
num_classes: int
|
|
||||||
|
|
||||||
|
|
||||||
class ClassificationResponse(OpenAIBaseModel):
|
|
||||||
id: str = Field(default_factory=lambda: f"classify-{random_uuid()}")
|
|
||||||
object: str = "list"
|
|
||||||
created: int = Field(default_factory=lambda: int(time.time()))
|
|
||||||
model: str
|
|
||||||
data: list[ClassificationData]
|
|
||||||
usage: UsageInfo
|
|
||||||
|
|
||||||
|
|
||||||
class FunctionCall(OpenAIBaseModel):
    name: str
    arguments: str

@@ -2409,83 +1778,6 @@ StreamingResponsesResponse: TypeAlias = (
    | ResponseCodeInterpreterCallCompletedEvent
)
|
|
||||||
BatchRequestInputBody: TypeAlias = (
|
|
||||||
ChatCompletionRequest | EmbeddingRequest | ScoreRequest | RerankRequest
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class BatchRequestInput(OpenAIBaseModel):
|
|
||||||
"""
|
|
||||||
The per-line object of the batch input file.
|
|
||||||
|
|
||||||
NOTE: Currently only the `/v1/chat/completions` endpoint is supported.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# A developer-provided per-request id that will be used to match outputs to
|
|
||||||
# inputs. Must be unique for each request in a batch.
|
|
||||||
custom_id: str
|
|
||||||
|
|
||||||
# The HTTP method to be used for the request. Currently only POST is
|
|
||||||
# supported.
|
|
||||||
method: str
|
|
||||||
|
|
||||||
# The OpenAI API relative URL to be used for the request. Currently
|
|
||||||
# /v1/chat/completions is supported.
|
|
||||||
url: str
|
|
||||||
|
|
||||||
# The parameters of the request.
|
|
||||||
body: BatchRequestInputBody
|
|
||||||
|
|
||||||
@field_validator("body", mode="plain")
|
|
||||||
@classmethod
|
|
||||||
def check_type_for_url(cls, value: Any, info: ValidationInfo):
|
|
||||||
# Use url to disambiguate models
|
|
||||||
url: str = info.data["url"]
|
|
||||||
if url == "/v1/chat/completions":
|
|
||||||
return ChatCompletionRequest.model_validate(value)
|
|
||||||
if url == "/v1/embeddings":
|
|
||||||
return TypeAdapter(EmbeddingRequest).validate_python(value)
|
|
||||||
if url.endswith("/score"):
|
|
||||||
return ScoreRequest.model_validate(value)
|
|
||||||
if url.endswith("/rerank"):
|
|
||||||
return RerankRequest.model_validate(value)
|
|
||||||
return TypeAdapter(BatchRequestInputBody).validate_python(value)
|
|
||||||
|
|
||||||
|
|
||||||
class BatchResponseData(OpenAIBaseModel):
|
|
||||||
# HTTP status code of the response.
|
|
||||||
status_code: int = 200
|
|
||||||
|
|
||||||
# An unique identifier for the API request.
|
|
||||||
request_id: str
|
|
||||||
|
|
||||||
# The body of the response.
|
|
||||||
body: (
|
|
||||||
ChatCompletionResponse
|
|
||||||
| EmbeddingResponse
|
|
||||||
| ScoreResponse
|
|
||||||
| RerankResponse
|
|
||||||
| None
|
|
||||||
) = None
|
|
||||||
|
|
||||||
|
|
||||||
class BatchRequestOutput(OpenAIBaseModel):
|
|
||||||
"""
|
|
||||||
The per-line object of the batch output and error files
|
|
||||||
"""
|
|
||||||
|
|
||||||
id: str
|
|
||||||
|
|
||||||
# A developer-provided per-request id that will be used to match outputs to
|
|
||||||
# inputs.
|
|
||||||
custom_id: str
|
|
||||||
|
|
||||||
response: BatchResponseData | None
|
|
||||||
|
|
||||||
# For requests that failed with a non-HTTP error, this will contain more
|
|
||||||
# information on the cause of the failure.
|
|
||||||
error: Any | None
|
|
||||||
|
|
||||||
|
|
||||||
class TokenizeCompletionRequest(OpenAIBaseModel):
    model: str | None = None

@@ -7,29 +7,35 @@ from argparse import Namespace
 from collections.abc import Awaitable, Callable
 from http import HTTPStatus
 from io import StringIO
+from typing import Any, TypeAlias

 import aiohttp
 import torch
 from prometheus_client import start_http_server
+from pydantic import TypeAdapter, field_validator
+from pydantic_core.core_schema import ValidationInfo
 from tqdm import tqdm

 from vllm.engine.arg_utils import AsyncEngineArgs, optional_type
 from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.protocol import (
-    BatchRequestInput,
-    BatchRequestOutput,
-    BatchResponseData,
+    ChatCompletionRequest,
     ChatCompletionResponse,
-    EmbeddingResponse,
     ErrorResponse,
-    RerankResponse,
-    ScoreResponse,
+    OpenAIBaseModel,
 )
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
-from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
 from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
-from vllm.entrypoints.openai.serving_score import ServingScores
+from vllm.entrypoints.pooling.embed.protocol import EmbeddingRequest, EmbeddingResponse
+from vllm.entrypoints.pooling.embed.serving import OpenAIServingEmbedding
+from vllm.entrypoints.pooling.score.protocol import (
+    RerankRequest,
+    RerankResponse,
+    ScoreRequest,
+    ScoreResponse,
+)
+from vllm.entrypoints.pooling.score.serving import ServingScores
 from vllm.logger import init_logger
 from vllm.reasoning import ReasoningParserManager
 from vllm.utils import random_uuid
@@ -39,6 +45,84 @@ from vllm.version import __version__ as VLLM_VERSION

logger = init_logger(__name__)

|
|
||||||
|
BatchRequestInputBody: TypeAlias = (
|
||||||
|
ChatCompletionRequest | EmbeddingRequest | ScoreRequest | RerankRequest
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BatchRequestInput(OpenAIBaseModel):
|
||||||
|
"""
|
||||||
|
The per-line object of the batch input file.
|
||||||
|
|
||||||
|
NOTE: Currently only the `/v1/chat/completions` endpoint is supported.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# A developer-provided per-request id that will be used to match outputs to
|
||||||
|
# inputs. Must be unique for each request in a batch.
|
||||||
|
custom_id: str
|
||||||
|
|
||||||
|
# The HTTP method to be used for the request. Currently only POST is
|
||||||
|
# supported.
|
||||||
|
method: str
|
||||||
|
|
||||||
|
# The OpenAI API relative URL to be used for the request. Currently
|
||||||
|
# /v1/chat/completions is supported.
|
||||||
|
url: str
|
||||||
|
|
||||||
|
# The parameters of the request.
|
||||||
|
body: BatchRequestInputBody
|
||||||
|
|
||||||
|
@field_validator("body", mode="plain")
|
||||||
|
@classmethod
|
||||||
|
def check_type_for_url(cls, value: Any, info: ValidationInfo):
|
||||||
|
# Use url to disambiguate models
|
||||||
|
url: str = info.data["url"]
|
||||||
|
if url == "/v1/chat/completions":
|
||||||
|
return ChatCompletionRequest.model_validate(value)
|
||||||
|
if url == "/v1/embeddings":
|
||||||
|
return TypeAdapter(EmbeddingRequest).validate_python(value)
|
||||||
|
if url.endswith("/score"):
|
||||||
|
return ScoreRequest.model_validate(value)
|
||||||
|
if url.endswith("/rerank"):
|
||||||
|
return RerankRequest.model_validate(value)
|
||||||
|
return TypeAdapter(BatchRequestInputBody).validate_python(value)
|
||||||
|
|
||||||
|
|
||||||
|
class BatchResponseData(OpenAIBaseModel):
|
||||||
|
# HTTP status code of the response.
|
||||||
|
status_code: int = 200
|
||||||
|
|
||||||
|
# An unique identifier for the API request.
|
||||||
|
request_id: str
|
||||||
|
|
||||||
|
# The body of the response.
|
||||||
|
body: (
|
||||||
|
ChatCompletionResponse
|
||||||
|
| EmbeddingResponse
|
||||||
|
| ScoreResponse
|
||||||
|
| RerankResponse
|
||||||
|
| None
|
||||||
|
) = None
|
||||||
|
|
||||||
|
|
||||||
|
class BatchRequestOutput(OpenAIBaseModel):
|
||||||
|
"""
|
||||||
|
The per-line object of the batch output and error files
|
||||||
|
"""
|
||||||
|
|
||||||
|
id: str
|
||||||
|
|
||||||
|
# A developer-provided per-request id that will be used to match outputs to
|
||||||
|
# inputs.
|
||||||
|
custom_id: str
|
||||||
|
|
||||||
|
response: BatchResponseData | None
|
||||||
|
|
||||||
|
# For requests that failed with a non-HTTP error, this will contain more
|
||||||
|
# information on the cause of the failure.
|
||||||
|
error: Any | None
|
||||||
|
|
||||||
|
|
||||||
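For reference, a batch input file is a JSONL file whose lines parse into BatchRequestInput, and the check_type_for_url validator above picks the request model from the url field. A minimal sketch of validating one line offline, assuming these classes now live in the batch runner module (vllm.entrypoints.openai.run_batch); the payload and model name are illustrative, not taken from the diff:

import json

from vllm.entrypoints.openai.run_batch import BatchRequestInput  # assumed module path

# One line of a hypothetical batch input file targeting the embeddings endpoint.
line = json.dumps(
    {
        "custom_id": "req-1",
        "method": "POST",
        "url": "/v1/embeddings",
        "body": {"model": "intfloat/e5-small", "input": "hello world"},
    }
)

# The plain-mode validator on `body` dispatches on `url`, so this line is parsed
# as an EmbeddingRequest rather than a ChatCompletionRequest.
parsed = BatchRequestInput.model_validate_json(line)
print(type(parsed.body).__name__)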
def make_arg_parser(parser: FlexibleArgumentParser):
    parser.add_argument(
        "-i",

@@ -18,6 +18,28 @@ from pydantic import ConfigDict, TypeAdapter
 from starlette.datastructures import Headers
 from typing_extensions import TypeIs
+
+from vllm.entrypoints.pooling.classify.protocol import (
+    ClassificationChatRequest,
+    ClassificationCompletionRequest,
+    ClassificationRequest,
+    ClassificationResponse,
+)
+from vllm.entrypoints.pooling.embed.protocol import (
+    EmbeddingChatRequest,
+    EmbeddingCompletionRequest,
+    EmbeddingRequest,
+    EmbeddingResponse,
+)
+from vllm.entrypoints.pooling.pooling.protocol import (
+    IOProcessorRequest,
+    PoolingResponse,
+)
+from vllm.entrypoints.pooling.score.protocol import (
+    RerankRequest,
+    ScoreRequest,
+    ScoreResponse,
+)

 if sys.version_info >= (3, 12):
     from typing import TypedDict
 else:
@@ -45,29 +67,16 @@ from vllm.entrypoints.openai.protocol import (
     ChatCompletionNamedToolChoiceParam,
     ChatCompletionRequest,
     ChatCompletionResponse,
-    ClassificationChatRequest,
-    ClassificationCompletionRequest,
-    ClassificationRequest,
-    ClassificationResponse,
     CompletionRequest,
     CompletionResponse,
     DetokenizeRequest,
-    EmbeddingChatRequest,
-    EmbeddingCompletionRequest,
-    EmbeddingRequest,
-    EmbeddingResponse,
     ErrorInfo,
     ErrorResponse,
     FunctionCall,
     FunctionDefinition,
     GenerateRequest,
     GenerateResponse,
-    IOProcessorRequest,
-    PoolingResponse,
-    RerankRequest,
     ResponsesRequest,
-    ScoreRequest,
-    ScoreResponse,
     TokenizeChatRequest,
     TokenizeCompletionRequest,
     TokenizeResponse,
@@ -2,6 +2,9 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from typing import TypeVar

+from fastapi import Request
+from fastapi.exceptions import RequestValidationError
+
 from vllm.entrypoints.openai.protocol import (
     ChatCompletionRequest,
     ChatCompletionResponseChoice,
@@ -35,3 +38,12 @@ def maybe_filter_parallel_tool_calls(
     ]

     return choice
+
+
+async def validate_json_request(raw_request: Request):
+    content_type = raw_request.headers.get("content-type", "").lower()
+    media_type = content_type.split(";", maxsplit=1)[0]
+    if media_type != "application/json":
+        raise RequestValidationError(
+            errors=["Unsupported Media Type: Only 'application/json' is allowed"]
+        )
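The new validate_json_request helper is meant to be attached as a FastAPI dependency. A rough sketch of the behaviour it enforces, using a made-up test app and route name for illustration:

from fastapi import Depends, FastAPI
from fastapi.testclient import TestClient

from vllm.entrypoints.openai.utils import validate_json_request

app = FastAPI()


@app.post("/echo", dependencies=[Depends(validate_json_request)])
async def echo(payload: dict):
    return payload


client = TestClient(app)
# JSON bodies pass through the dependency untouched.
assert client.post("/echo", json={"ok": True}).status_code == 200
# Non-JSON content types raise RequestValidationError, which FastAPI's default
# handler turns into a 422 response.
resp = client.post("/echo", content=b"ok", headers={"content-type": "text/plain"})
assert resp.status_code == 422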
vllm/entrypoints/pooling/__init__.py (new file, 16 lines)
@@ -0,0 +1,16 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from fastapi import FastAPI


def register_pooling_api_routers(app: FastAPI):
    from vllm.entrypoints.pooling.classify.api_router import router as classify_router
    from vllm.entrypoints.pooling.embed.api_router import router as embed_router
    from vllm.entrypoints.pooling.pooling.api_router import router as pooling_router
    from vllm.entrypoints.pooling.score.api_router import router as score_router

    app.include_router(classify_router)
    app.include_router(embed_router)
    app.include_router(score_router)
    app.include_router(pooling_router)
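A minimal sketch of how these routers are mounted; in the real server the API server builds the app and populates app.state with the serving handlers, so this standalone snippet only illustrates the wiring:

from fastapi import FastAPI

from vllm.entrypoints.pooling import register_pooling_api_routers

app = FastAPI()
register_pooling_api_routers(app)

# Each router looks up its handler on app.state at request time, e.g.
# app.state.openai_serving_classification, app.state.openai_serving_embedding, ...
print([route.path for route in app.routes])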
vllm/entrypoints/pooling/classify/__init__.py (new file, 0 lines)

vllm/entrypoints/pooling/classify/api_router.py (new file, 50 lines)
@@ -0,0 +1,50 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from http import HTTPStatus

from fastapi import APIRouter, Depends, HTTPException, Request
from starlette.responses import JSONResponse
from typing_extensions import assert_never

from vllm.entrypoints.openai.protocol import ErrorResponse
from vllm.entrypoints.openai.utils import validate_json_request
from vllm.entrypoints.pooling.classify.protocol import (
    ClassificationRequest,
    ClassificationResponse,
)
from vllm.entrypoints.pooling.classify.serving import ServingClassification
from vllm.entrypoints.utils import load_aware_call, with_cancellation

router = APIRouter()


def classify(request: Request) -> ServingClassification | None:
    return request.app.state.openai_serving_classification


@router.post("/classify", dependencies=[Depends(validate_json_request)])
@with_cancellation
@load_aware_call
async def create_classify(request: ClassificationRequest, raw_request: Request):
    handler = classify(raw_request)
    if handler is None:
        base_server = raw_request.app.state.openai_serving_tokenization
        return base_server.create_error_response(
            message="The model does not support Classification API"
        )

    try:
        generator = await handler.create_classify(request, raw_request)
    except Exception as e:
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
        ) from e
    if isinstance(generator, ErrorResponse):
        return JSONResponse(
            content=generator.model_dump(), status_code=generator.error.code
        )

    elif isinstance(generator, ClassificationResponse):
        return JSONResponse(content=generator.model_dump())

    assert_never(generator)
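Assuming a vLLM server is running with a sequence-classification model at http://localhost:8000 (host, port and model name are placeholders for this sketch), the /classify route can be exercised with a plain HTTP client; the payload fields mirror ClassificationCompletionRequest and the response fields mirror ClassificationData:

import requests

resp = requests.post(
    "http://localhost:8000/classify",
    json={
        "model": "my-classifier",  # placeholder model name
        "input": ["vLLM is wonderful", "the build is broken"],
    },
    timeout=30,
)
resp.raise_for_status()
for item in resp.json()["data"]:
    print(item["index"], item["label"], item["probs"])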
vllm/entrypoints/pooling/classify/protocol.py (new file, 181 lines)
@@ -0,0 +1,181 @@
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import time
|
||||||
|
from typing import Annotated, Any, TypeAlias
|
||||||
|
|
||||||
|
from pydantic import (
|
||||||
|
Field,
|
||||||
|
)
|
||||||
|
|
||||||
|
from vllm import PoolingParams
|
||||||
|
from vllm.config.pooler import get_use_activation
|
||||||
|
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
|
||||||
|
from vllm.entrypoints.openai.protocol import OpenAIBaseModel, UsageInfo
|
||||||
|
from vllm.utils import random_uuid
|
||||||
|
|
||||||
|
|
||||||
|
class ClassificationCompletionRequest(OpenAIBaseModel):
|
||||||
|
model: str | None = None
|
||||||
|
input: list[str] | str
|
||||||
|
truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
|
||||||
|
user: str | None = None
|
||||||
|
|
||||||
|
# --8<-- [start:classification-extra-params]
|
||||||
|
priority: int = Field(
|
||||||
|
default=0,
|
||||||
|
description=(
|
||||||
|
"The priority of the request (lower means earlier handling; "
|
||||||
|
"default: 0). Any priority other than 0 will raise an error "
|
||||||
|
"if the served model does not use priority scheduling."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
add_special_tokens: bool = Field(
|
||||||
|
default=True,
|
||||||
|
description=(
|
||||||
|
"If true (the default), special tokens (e.g. BOS) will be added to "
|
||||||
|
"the prompt."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
request_id: str = Field(
|
||||||
|
default_factory=random_uuid,
|
||||||
|
description=(
|
||||||
|
"The request_id related to this request. If the caller does "
|
||||||
|
"not set it, a random_uuid will be generated. This id is used "
|
||||||
|
"through out the inference process and return in response."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
softmax: bool | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="softmax will be deprecated, please use use_activation instead.",
|
||||||
|
)
|
||||||
|
|
||||||
|
activation: bool | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="activation will be deprecated, please use use_activation instead.",
|
||||||
|
)
|
||||||
|
|
||||||
|
use_activation: bool | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Whether to use activation for classification outputs. "
|
||||||
|
"Default is True.",
|
||||||
|
)
|
||||||
|
# --8<-- [end:classification-extra-params]
|
||||||
|
|
||||||
|
def to_pooling_params(self):
|
||||||
|
return PoolingParams(
|
||||||
|
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||||
|
use_activation=get_use_activation(self),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ClassificationChatRequest(OpenAIBaseModel):
|
||||||
|
model: str | None = None
|
||||||
|
messages: list[ChatCompletionMessageParam]
|
||||||
|
truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
|
||||||
|
user: str | None = None
|
||||||
|
|
||||||
|
# --8<-- [start:chat-classification-extra-params]
|
||||||
|
add_generation_prompt: bool = Field(
|
||||||
|
default=False,
|
||||||
|
description=(
|
||||||
|
"If true, the generation prompt will be added to the chat template. "
|
||||||
|
"This is a parameter used by chat template in tokenizer config of the "
|
||||||
|
"model."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
add_special_tokens: bool = Field(
|
||||||
|
default=False,
|
||||||
|
description=(
|
||||||
|
"If true, special tokens (e.g. BOS) will be added to the prompt "
|
||||||
|
"on top of what is added by the chat template. "
|
||||||
|
"For most models, the chat template takes care of adding the "
|
||||||
|
"special tokens so this should be set to false (as is the "
|
||||||
|
"default)."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
chat_template: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description=(
|
||||||
|
"A Jinja template to use for this conversion. "
|
||||||
|
"As of transformers v4.44, default chat template is no longer "
|
||||||
|
"allowed, so you must provide a chat template if the tokenizer "
|
||||||
|
"does not define one."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
chat_template_kwargs: dict[str, Any] | None = Field(
|
||||||
|
default=None,
|
||||||
|
description=(
|
||||||
|
"Additional keyword args to pass to the template renderer. "
|
||||||
|
"Will be accessible by the chat template."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
mm_processor_kwargs: dict[str, Any] | None = Field(
|
||||||
|
default=None,
|
||||||
|
description=("Additional kwargs to pass to the HF processor."),
|
||||||
|
)
|
||||||
|
|
||||||
|
priority: int = Field(
|
||||||
|
default=0,
|
||||||
|
description=(
|
||||||
|
"The priority of the request (lower means earlier handling; "
|
||||||
|
"default: 0). Any priority other than 0 will raise an error "
|
||||||
|
"if the served model does not use priority scheduling."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
request_id: str = Field(
|
||||||
|
default_factory=random_uuid,
|
||||||
|
description=(
|
||||||
|
"The request_id related to this request. If the caller does "
|
||||||
|
"not set it, a random_uuid will be generated. This id is used "
|
||||||
|
"through out the inference process and return in response."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
softmax: bool | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="softmax will be deprecated, please use use_activation instead.",
|
||||||
|
)
|
||||||
|
|
||||||
|
activation: bool | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="activation will be deprecated, please use use_activation instead.",
|
||||||
|
)
|
||||||
|
|
||||||
|
use_activation: bool | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Whether to use activation for classification outputs. "
|
||||||
|
"Default is True.",
|
||||||
|
)
|
||||||
|
# --8<-- [end:chat-classification-extra-params]
|
||||||
|
|
||||||
|
def to_pooling_params(self):
|
||||||
|
return PoolingParams(
|
||||||
|
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||||
|
use_activation=get_use_activation(self),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
ClassificationRequest: TypeAlias = (
|
||||||
|
ClassificationCompletionRequest | ClassificationChatRequest
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ClassificationData(OpenAIBaseModel):
|
||||||
|
index: int
|
||||||
|
label: str | None
|
||||||
|
probs: list[float]
|
||||||
|
num_classes: int
|
||||||
|
|
||||||
|
|
||||||
|
class ClassificationResponse(OpenAIBaseModel):
|
||||||
|
id: str = Field(default_factory=lambda: f"classify-{random_uuid()}")
|
||||||
|
object: str = "list"
|
||||||
|
created: int = Field(default_factory=lambda: int(time.time()))
|
||||||
|
model: str
|
||||||
|
data: list[ClassificationData]
|
||||||
|
usage: UsageInfo
|
||||||
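As with the other pooling protocols, these request models translate into PoolingParams through to_pooling_params, with the deprecated softmax/activation fields folded in via get_use_activation. A small sketch using the module path introduced by this diff (the model name and inputs are placeholders):

from vllm.entrypoints.pooling.classify.protocol import ClassificationCompletionRequest

req = ClassificationCompletionRequest(
    model="my-classifier",  # placeholder model name
    input="ship it?",
    truncate_prompt_tokens=128,
    use_activation=False,
)
params = req.to_pooling_params()
# to_pooling_params() carries the truncation limit and the resolved activation
# flag into the engine-side PoolingParams.
print(params.truncate_prompt_tokens, params.use_activation)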
@@ -13,11 +13,6 @@ from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.protocol import (
     ChatCompletionRequest,
-    ClassificationChatRequest,
-    ClassificationCompletionRequest,
-    ClassificationData,
-    ClassificationRequest,
-    ClassificationResponse,
     ErrorResponse,
     UsageInfo,
 )
@@ -27,6 +22,13 @@ from vllm.entrypoints.openai.serving_engine import (
     ServeContext,
 )
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
+from vllm.entrypoints.pooling.classify.protocol import (
+    ClassificationChatRequest,
+    ClassificationCompletionRequest,
+    ClassificationData,
+    ClassificationRequest,
+    ClassificationResponse,
+)
 from vllm.entrypoints.renderer import RenderConfig
 from vllm.logger import init_logger
 from vllm.outputs import ClassificationOutput, PoolingRequestOutput
vllm/entrypoints/pooling/embed/__init__.py (new file, 0 lines)

vllm/entrypoints/pooling/embed/api_router.py (new file, 67 lines)
@@ -0,0 +1,67 @@
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from http import HTTPStatus
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||||
|
from fastapi.responses import JSONResponse, StreamingResponse
|
||||||
|
from typing_extensions import assert_never
|
||||||
|
|
||||||
|
from vllm.entrypoints.openai.protocol import ErrorResponse
|
||||||
|
from vllm.entrypoints.openai.utils import validate_json_request
|
||||||
|
from vllm.entrypoints.pooling.embed.protocol import (
|
||||||
|
EmbeddingBytesResponse,
|
||||||
|
EmbeddingRequest,
|
||||||
|
EmbeddingResponse,
|
||||||
|
)
|
||||||
|
from vllm.entrypoints.pooling.embed.serving import OpenAIServingEmbedding
|
||||||
|
from vllm.entrypoints.utils import load_aware_call, with_cancellation
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
def embedding(request: Request) -> OpenAIServingEmbedding | None:
|
||||||
|
return request.app.state.openai_serving_embedding
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/v1/embeddings",
|
||||||
|
dependencies=[Depends(validate_json_request)],
|
||||||
|
responses={
|
||||||
|
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
|
||||||
|
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
@with_cancellation
|
||||||
|
@load_aware_call
|
||||||
|
async def create_embedding(
|
||||||
|
request: EmbeddingRequest,
|
||||||
|
raw_request: Request,
|
||||||
|
):
|
||||||
|
handler = embedding(raw_request)
|
||||||
|
if handler is None:
|
||||||
|
base_server = raw_request.app.state.openai_serving_tokenization
|
||||||
|
return base_server.create_error_response(
|
||||||
|
message="The model does not support Embeddings API"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
generator = await handler.create_embedding(request, raw_request)
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
|
||||||
|
) from e
|
||||||
|
|
||||||
|
if isinstance(generator, ErrorResponse):
|
||||||
|
return JSONResponse(
|
||||||
|
content=generator.model_dump(), status_code=generator.error.code
|
||||||
|
)
|
||||||
|
elif isinstance(generator, EmbeddingResponse):
|
||||||
|
return JSONResponse(content=generator.model_dump())
|
||||||
|
elif isinstance(generator, EmbeddingBytesResponse):
|
||||||
|
return StreamingResponse(
|
||||||
|
content=generator.body,
|
||||||
|
headers={"metadata": generator.metadata},
|
||||||
|
media_type=generator.media_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert_never(generator)
|
||||||
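Because the embeddings route keeps the OpenAI-compatible /v1/embeddings path, the official client keeps working against it. A sketch assuming a local server and an embedding model (base URL, API key and model name are placeholders):

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

out = client.embeddings.create(
    model="intfloat/e5-small",  # placeholder model name
    input=["first sentence", "second sentence"],
    encoding_format="float",
)
# One embedding vector per input string.
print(len(out.data), len(out.data[0].embedding))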
vllm/entrypoints/pooling/embed/protocol.py (new file, 208 lines)
@@ -0,0 +1,208 @@
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import time
|
||||||
|
from typing import Annotated, Any, TypeAlias
|
||||||
|
|
||||||
|
from pydantic import (
|
||||||
|
Field,
|
||||||
|
model_validator,
|
||||||
|
)
|
||||||
|
|
||||||
|
from vllm import PoolingParams
|
||||||
|
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
|
||||||
|
from vllm.entrypoints.openai.protocol import OpenAIBaseModel, UsageInfo
|
||||||
|
from vllm.utils import random_uuid
|
||||||
|
from vllm.utils.serial_utils import EmbedDType, EncodingFormat, Endianness
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingCompletionRequest(OpenAIBaseModel):
|
||||||
|
# Ordered by official OpenAI API documentation
|
||||||
|
# https://platform.openai.com/docs/api-reference/embeddings
|
||||||
|
model: str | None = None
|
||||||
|
input: list[int] | list[list[int]] | str | list[str]
|
||||||
|
encoding_format: EncodingFormat = "float"
|
||||||
|
dimensions: int | None = None
|
||||||
|
user: str | None = None
|
||||||
|
truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
|
||||||
|
|
||||||
|
# --8<-- [start:embedding-extra-params]
|
||||||
|
add_special_tokens: bool = Field(
|
||||||
|
default=True,
|
||||||
|
description=(
|
||||||
|
"If true (the default), special tokens (e.g. BOS) will be added to "
|
||||||
|
"the prompt."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
priority: int = Field(
|
||||||
|
default=0,
|
||||||
|
description=(
|
||||||
|
"The priority of the request (lower means earlier handling; "
|
||||||
|
"default: 0). Any priority other than 0 will raise an error "
|
||||||
|
"if the served model does not use priority scheduling."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
request_id: str = Field(
|
||||||
|
default_factory=random_uuid,
|
||||||
|
description=(
|
||||||
|
"The request_id related to this request. If the caller does "
|
||||||
|
"not set it, a random_uuid will be generated. This id is used "
|
||||||
|
"through out the inference process and return in response."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
normalize: bool | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Whether to normalize the embeddings outputs. Default is True.",
|
||||||
|
)
|
||||||
|
embed_dtype: EmbedDType = Field(
|
||||||
|
default="float32",
|
||||||
|
description=(
|
||||||
|
"What dtype to use for encoding. Default to using float32 for base64 "
|
||||||
|
"encoding to match the OpenAI python client behavior. "
|
||||||
|
"This parameter will affect base64 and binary_response."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
endianness: Endianness = Field(
|
||||||
|
default="native",
|
||||||
|
description=(
|
||||||
|
"What endianness to use for encoding. Default to using native for "
|
||||||
|
"base64 encoding to match the OpenAI python client behavior."
|
||||||
|
"This parameter will affect base64 and binary_response."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
# --8<-- [end:embedding-extra-params]
|
||||||
|
|
||||||
|
def to_pooling_params(self):
|
||||||
|
return PoolingParams(
|
||||||
|
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||||
|
dimensions=self.dimensions,
|
||||||
|
normalize=self.normalize,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingChatRequest(OpenAIBaseModel):
|
||||||
|
model: str | None = None
|
||||||
|
messages: list[ChatCompletionMessageParam]
|
||||||
|
|
||||||
|
encoding_format: EncodingFormat = "float"
|
||||||
|
dimensions: int | None = None
|
||||||
|
user: str | None = None
|
||||||
|
truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
|
||||||
|
|
||||||
|
# --8<-- [start:chat-embedding-extra-params]
|
||||||
|
add_generation_prompt: bool = Field(
|
||||||
|
default=False,
|
||||||
|
description=(
|
||||||
|
"If true, the generation prompt will be added to the chat template. "
|
||||||
|
"This is a parameter used by chat template in tokenizer config of the "
|
||||||
|
"model."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
add_special_tokens: bool = Field(
|
||||||
|
default=False,
|
||||||
|
description=(
|
||||||
|
"If true, special tokens (e.g. BOS) will be added to the prompt "
|
||||||
|
"on top of what is added by the chat template. "
|
||||||
|
"For most models, the chat template takes care of adding the "
|
||||||
|
"special tokens so this should be set to false (as is the "
|
||||||
|
"default)."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
chat_template: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description=(
|
||||||
|
"A Jinja template to use for this conversion. "
|
||||||
|
"As of transformers v4.44, default chat template is no longer "
|
||||||
|
"allowed, so you must provide a chat template if the tokenizer "
|
||||||
|
"does not define one."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
chat_template_kwargs: dict[str, Any] | None = Field(
|
||||||
|
default=None,
|
||||||
|
description=(
|
||||||
|
"Additional keyword args to pass to the template renderer. "
|
||||||
|
"Will be accessible by the chat template."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
mm_processor_kwargs: dict[str, Any] | None = Field(
|
||||||
|
default=None,
|
||||||
|
description=("Additional kwargs to pass to the HF processor."),
|
||||||
|
)
|
||||||
|
priority: int = Field(
|
||||||
|
default=0,
|
||||||
|
description=(
|
||||||
|
"The priority of the request (lower means earlier handling; "
|
||||||
|
"default: 0). Any priority other than 0 will raise an error "
|
||||||
|
"if the served model does not use priority scheduling."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
request_id: str = Field(
|
||||||
|
default_factory=random_uuid,
|
||||||
|
description=(
|
||||||
|
"The request_id related to this request. If the caller does "
|
||||||
|
"not set it, a random_uuid will be generated. This id is used "
|
||||||
|
"through out the inference process and return in response."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
normalize: bool | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Whether to normalize the embeddings outputs. Default is True.",
|
||||||
|
)
|
||||||
|
embed_dtype: EmbedDType = Field(
|
||||||
|
default="float32",
|
||||||
|
description=(
|
||||||
|
"What dtype to use for encoding. Default to using float32 for base64 "
|
||||||
|
"encoding to match the OpenAI python client behavior. "
|
||||||
|
"This parameter will affect base64 and binary_response."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
endianness: Endianness = Field(
|
||||||
|
default="native",
|
||||||
|
description=(
|
||||||
|
"What endianness to use for encoding. Default to using native for "
|
||||||
|
"base64 encoding to match the OpenAI python client behavior."
|
||||||
|
"This parameter will affect base64 and binary_response."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
# --8<-- [end:chat-embedding-extra-params]
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_generation_prompt(cls, data):
|
||||||
|
if data.get("continue_final_message") and data.get("add_generation_prompt"):
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot set both `continue_final_message` and "
|
||||||
|
"`add_generation_prompt` to True."
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
|
||||||
|
def to_pooling_params(self):
|
||||||
|
return PoolingParams(
|
||||||
|
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||||
|
dimensions=self.dimensions,
|
||||||
|
normalize=self.normalize,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
EmbeddingRequest: TypeAlias = EmbeddingCompletionRequest | EmbeddingChatRequest
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingResponseData(OpenAIBaseModel):
|
||||||
|
index: int
|
||||||
|
object: str = "embedding"
|
||||||
|
embedding: list[float] | str
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingResponse(OpenAIBaseModel):
|
||||||
|
id: str = Field(default_factory=lambda: f"embd-{random_uuid()}")
|
||||||
|
object: str = "list"
|
||||||
|
created: int = Field(default_factory=lambda: int(time.time()))
|
||||||
|
model: str
|
||||||
|
data: list[EmbeddingResponseData]
|
||||||
|
usage: UsageInfo
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingBytesResponse(OpenAIBaseModel):
|
||||||
|
body: list[bytes]
|
||||||
|
metadata: str
|
||||||
|
media_type: str = "application/octet-stream"
|
||||||
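A short sketch of how the extra embedding parameters flow into PoolingParams; dimensions and normalize travel with the request, while encoding_format (and embed_dtype/endianness) only affect how the response payload is serialized. Model name and input are placeholders:

from vllm.entrypoints.pooling.embed.protocol import EmbeddingCompletionRequest

req = EmbeddingCompletionRequest(
    model="intfloat/e5-small",  # placeholder model name
    input="hello world",
    dimensions=256,
    normalize=False,
    encoding_format="base64",
)
params = req.to_pooling_params()
# dimensions and normalize are forwarded to the engine via PoolingParams.
print(params.dimensions, params.normalize)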
@@ -13,12 +13,6 @@ from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.protocol import (
-    EmbeddingBytesResponse,
-    EmbeddingChatRequest,
-    EmbeddingCompletionRequest,
-    EmbeddingRequest,
-    EmbeddingResponse,
-    EmbeddingResponseData,
     ErrorResponse,
     UsageInfo,
 )
@@ -29,6 +23,14 @@ from vllm.entrypoints.openai.serving_engine import (
     TextTokensPrompt,
 )
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
+from vllm.entrypoints.pooling.embed.protocol import (
+    EmbeddingBytesResponse,
+    EmbeddingChatRequest,
+    EmbeddingCompletionRequest,
+    EmbeddingRequest,
+    EmbeddingResponse,
+    EmbeddingResponseData,
+)
 from vllm.entrypoints.renderer import RenderConfig
 from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
 from vllm.logger import init_logger
vllm/entrypoints/pooling/pooling/__init__.py (new file, 0 lines)

vllm/entrypoints/pooling/pooling/api_router.py (new file, 63 lines)
@@ -0,0 +1,63 @@
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from http import HTTPStatus
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||||
|
from fastapi.responses import JSONResponse, StreamingResponse
|
||||||
|
from typing_extensions import assert_never
|
||||||
|
|
||||||
|
from vllm.entrypoints.openai.protocol import ErrorResponse
|
||||||
|
from vllm.entrypoints.openai.utils import validate_json_request
|
||||||
|
from vllm.entrypoints.pooling.pooling.protocol import (
|
||||||
|
IOProcessorResponse,
|
||||||
|
PoolingBytesResponse,
|
||||||
|
PoolingRequest,
|
||||||
|
PoolingResponse,
|
||||||
|
)
|
||||||
|
from vllm.entrypoints.pooling.pooling.serving import OpenAIServingPooling
|
||||||
|
from vllm.entrypoints.utils import load_aware_call, with_cancellation
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
def pooling(request: Request) -> OpenAIServingPooling | None:
|
||||||
|
return request.app.state.openai_serving_pooling
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/pooling",
|
||||||
|
dependencies=[Depends(validate_json_request)],
|
||||||
|
responses={
|
||||||
|
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
|
||||||
|
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
@with_cancellation
|
||||||
|
@load_aware_call
|
||||||
|
async def create_pooling(request: PoolingRequest, raw_request: Request):
|
||||||
|
handler = pooling(raw_request)
|
||||||
|
if handler is None:
|
||||||
|
base_server = raw_request.app.state.openai_serving_tokenization
|
||||||
|
return base_server.create_error_response(
|
||||||
|
message="The model does not support Pooling API"
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
generator = await handler.create_pooling(request, raw_request)
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
|
||||||
|
) from e
|
||||||
|
if isinstance(generator, ErrorResponse):
|
||||||
|
return JSONResponse(
|
||||||
|
content=generator.model_dump(), status_code=generator.error.code
|
||||||
|
)
|
||||||
|
elif isinstance(generator, (PoolingResponse, IOProcessorResponse)):
|
||||||
|
return JSONResponse(content=generator.model_dump())
|
||||||
|
elif isinstance(generator, PoolingBytesResponse):
|
||||||
|
return StreamingResponse(
|
||||||
|
content=generator.body,
|
||||||
|
headers={"metadata": generator.metadata},
|
||||||
|
media_type=generator.media_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert_never(generator)
|
||||||
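The /pooling route accepts the completion-style, chat-style and IO-processor request shapes defined in the protocol module below. A sketch of a plain HTTP call, with host and model name as placeholders:

import requests

resp = requests.post(
    "http://localhost:8000/pooling",
    json={"model": "my-pooling-model", "input": "some text", "task": "embed"},
    timeout=30,
)
resp.raise_for_status()
# PoolingResponse.data is a list of PoolingResponseData objects whose `data`
# field holds the raw pooled values.
print(resp.json()["data"][0]["data"][:8])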
vllm/entrypoints/pooling/pooling/protocol.py (new file, 148 lines)
@@ -0,0 +1,148 @@
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import time
|
||||||
|
from typing import Generic, TypeAlias, TypeVar
|
||||||
|
|
||||||
|
from pydantic import (
|
||||||
|
Field,
|
||||||
|
)
|
||||||
|
|
||||||
|
from vllm import PoolingParams
|
||||||
|
from vllm.config.pooler import get_use_activation
|
||||||
|
from vllm.entrypoints.openai.protocol import OpenAIBaseModel, UsageInfo
|
||||||
|
from vllm.entrypoints.pooling.embed.protocol import (
|
||||||
|
EmbeddingChatRequest,
|
||||||
|
EmbeddingCompletionRequest,
|
||||||
|
)
|
||||||
|
from vllm.tasks import PoolingTask
|
||||||
|
from vllm.utils import random_uuid
|
||||||
|
from vllm.utils.serial_utils import EmbedDType, EncodingFormat, Endianness
|
||||||
|
|
||||||
|
|
||||||
|
class PoolingCompletionRequest(EmbeddingCompletionRequest):
|
||||||
|
task: PoolingTask | None = None
|
||||||
|
softmax: bool | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="softmax will be deprecated, please use use_activation instead.",
|
||||||
|
)
|
||||||
|
activation: bool | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="activation will be deprecated, please use use_activation instead.",
|
||||||
|
)
|
||||||
|
use_activation: bool | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Whether to use activation for classification outputs. "
|
||||||
|
"If it is a classify or token_classify task, the default is True; "
|
||||||
|
"for other tasks, this value should be None.",
|
||||||
|
)
|
||||||
|
|
||||||
|
def to_pooling_params(self):
|
||||||
|
return PoolingParams(
|
||||||
|
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||||
|
dimensions=self.dimensions,
|
||||||
|
normalize=self.normalize,
|
||||||
|
use_activation=get_use_activation(self),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PoolingChatRequest(EmbeddingChatRequest):
|
||||||
|
task: PoolingTask | None = None
|
||||||
|
softmax: bool | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="softmax will be deprecated, please use use_activation instead.",
|
||||||
|
)
|
||||||
|
activation: bool | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="activation will be deprecated, please use use_activation instead.",
|
||||||
|
)
|
||||||
|
use_activation: bool | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Whether to use activation for classification outputs. "
|
||||||
|
"If it is a classify or token_classify task, the default is True; "
|
||||||
|
"for other tasks, this value should be None.",
|
||||||
|
)
|
||||||
|
|
||||||
|
def to_pooling_params(self):
|
||||||
|
return PoolingParams(
|
||||||
|
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||||
|
dimensions=self.dimensions,
|
||||||
|
normalize=self.normalize,
|
||||||
|
use_activation=get_use_activation(self),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
class IOProcessorRequest(OpenAIBaseModel, Generic[T]):
|
||||||
|
model: str | None = None
|
||||||
|
|
||||||
|
priority: int = Field(default=0)
|
||||||
|
"""
|
||||||
|
The priority of the request (lower means earlier handling;
|
||||||
|
default: 0). Any priority other than 0 will raise an error
|
||||||
|
if the served model does not use priority scheduling.
|
||||||
|
"""
|
||||||
|
data: T
|
||||||
|
|
||||||
|
task: PoolingTask = "plugin"
|
||||||
|
encoding_format: EncodingFormat = "float"
|
||||||
|
embed_dtype: EmbedDType = Field(
|
||||||
|
default="float32",
|
||||||
|
description=(
|
||||||
|
"What dtype to use for encoding. Default to using float32 for base64 "
|
||||||
|
"encoding to match the OpenAI python client behavior. "
|
||||||
|
"This parameter will affect base64 and binary_response."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
endianness: Endianness = Field(
|
||||||
|
default="native",
|
||||||
|
description=(
|
||||||
|
"What endianness to use for encoding. Default to using native for "
|
||||||
|
"base64 encoding to match the OpenAI python client behavior."
|
||||||
|
"This parameter will affect base64 and binary_response."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def to_pooling_params(self):
|
||||||
|
return PoolingParams()
|
||||||
|
|
||||||
|
|
||||||
|
class IOProcessorResponse(OpenAIBaseModel, Generic[T]):
|
||||||
|
request_id: str | None = None
|
||||||
|
"""
|
||||||
|
The request_id associated with this response
|
||||||
|
"""
|
||||||
|
created_at: int = Field(default_factory=lambda: int(time.time()))
|
||||||
|
|
||||||
|
data: T
|
||||||
|
"""
|
||||||
|
When using plugins IOProcessor plugins, the actual output is generated
|
||||||
|
by the plugin itself. Hence, we use a generic type for the response data
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
PoolingRequest: TypeAlias = (
|
||||||
|
PoolingCompletionRequest | PoolingChatRequest | IOProcessorRequest
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PoolingResponseData(OpenAIBaseModel):
|
||||||
|
index: int
|
||||||
|
object: str = "pooling"
|
||||||
|
data: list[list[float]] | list[float] | str
|
||||||
|
|
||||||
|
|
||||||
|
class PoolingResponse(OpenAIBaseModel):
|
||||||
|
id: str = Field(default_factory=lambda: f"pool-{random_uuid()}")
|
||||||
|
object: str = "list"
|
||||||
|
created: int = Field(default_factory=lambda: int(time.time()))
|
||||||
|
model: str
|
||||||
|
data: list[PoolingResponseData]
|
||||||
|
usage: UsageInfo
|
||||||
|
|
||||||
|
|
||||||
|
class PoolingBytesResponse(OpenAIBaseModel):
|
||||||
|
body: list[bytes]
|
||||||
|
metadata: str
|
||||||
|
media_type: str = "application/octet-stream"
|
||||||
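PoolingCompletionRequest inherits the embedding fields and layers task plus the softmax/activation/use_activation knobs on top. A small sketch of the resulting PoolingParams (model name, input and the chosen task are placeholders):

from vllm.entrypoints.pooling.pooling.protocol import PoolingCompletionRequest

req = PoolingCompletionRequest(
    model="my-model",  # placeholder model name
    input="hello",
    task="classify",
    use_activation=True,
)
params = req.to_pooling_params()
# to_pooling_params() forwards truncation, dimensions, normalize and use_activation.
print(params.use_activation, params.normalize, params.dimensions)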
@@ -16,6 +16,11 @@ from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.protocol import (
     ErrorResponse,
+    UsageInfo,
+)
+from vllm.entrypoints.openai.serving_engine import OpenAIServing
+from vllm.entrypoints.openai.serving_models import OpenAIServingModels
+from vllm.entrypoints.pooling.pooling.protocol import (
     IOProcessorRequest,
     IOProcessorResponse,
     PoolingBytesResponse,
@@ -24,10 +29,7 @@ from vllm.entrypoints.openai.protocol import (
     PoolingRequest,
     PoolingResponse,
     PoolingResponseData,
-    UsageInfo,
 )
-from vllm.entrypoints.openai.serving_engine import OpenAIServing
-from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.entrypoints.renderer import RenderConfig
 from vllm.entrypoints.utils import _validate_truncation_size
 from vllm.logger import init_logger
vllm/entrypoints/pooling/score/__init__.py (new file, 0 lines)

vllm/entrypoints/pooling/score/api_router.py (new file, 149 lines)
@@ -0,0 +1,149 @@
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from http import HTTPStatus
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from typing_extensions import assert_never
|
||||||
|
|
||||||
|
from vllm.entrypoints.openai.protocol import ErrorResponse
|
||||||
|
from vllm.entrypoints.openai.utils import validate_json_request
|
||||||
|
from vllm.entrypoints.pooling.score.protocol import (
|
||||||
|
RerankRequest,
|
||||||
|
RerankResponse,
|
||||||
|
ScoreRequest,
|
||||||
|
ScoreResponse,
|
||||||
|
)
|
||||||
|
from vllm.entrypoints.pooling.score.serving import ServingScores
|
||||||
|
from vllm.entrypoints.utils import load_aware_call, with_cancellation
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def score(request: Request) -> ServingScores | None:
|
||||||
|
return request.app.state.openai_serving_scores
|
||||||
|
|
||||||
|
|
||||||
|
def rerank(request: Request) -> ServingScores | None:
|
||||||
|
return request.app.state.openai_serving_scores
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/score",
|
||||||
|
dependencies=[Depends(validate_json_request)],
|
||||||
|
responses={
|
||||||
|
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
|
||||||
|
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
@with_cancellation
|
||||||
|
@load_aware_call
|
||||||
|
async def create_score(request: ScoreRequest, raw_request: Request):
|
||||||
|
handler = score(raw_request)
|
||||||
|
if handler is None:
|
||||||
|
base_server = raw_request.app.state.openai_serving_tokenization
|
||||||
|
return base_server.create_error_response(
|
||||||
|
message="The model does not support Score API"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
generator = await handler.create_score(request, raw_request)
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
|
||||||
|
) from e
|
||||||
|
    if isinstance(generator, ErrorResponse):
        return JSONResponse(
            content=generator.model_dump(), status_code=generator.error.code
        )
    elif isinstance(generator, ScoreResponse):
        return JSONResponse(content=generator.model_dump())

    assert_never(generator)


@router.post(
    "/v1/score",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def create_score_v1(request: ScoreRequest, raw_request: Request):
    logger.warning(
        "To indicate that Score API is not part of standard OpenAI API, we "
        "have moved it to `/score`. Please update your client accordingly."
    )

    return await create_score(request, raw_request)


@router.post(
    "/rerank",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def do_rerank(request: RerankRequest, raw_request: Request):
    handler = rerank(raw_request)
    if handler is None:
        base_server = raw_request.app.state.openai_serving_tokenization
        return base_server.create_error_response(
            message="The model does not support Rerank (Score) API"
        )
    try:
        generator = await handler.do_rerank(request, raw_request)
    except Exception as e:
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
        ) from e
    if isinstance(generator, ErrorResponse):
        return JSONResponse(
            content=generator.model_dump(), status_code=generator.error.code
        )
    elif isinstance(generator, RerankResponse):
        return JSONResponse(content=generator.model_dump())

    assert_never(generator)


@router.post(
    "/v1/rerank",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
async def do_rerank_v1(request: RerankRequest, raw_request: Request):
    logger.warning_once(
        "To indicate that the rerank API is not part of the standard OpenAI"
        " API, we have located it at `/rerank`. Please update your client "
        "accordingly. (Note: Conforms to JinaAI rerank API)"
    )

    return await do_rerank(request, raw_request)


@router.post(
    "/v2/rerank",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
async def do_rerank_v2(request: RerankRequest, raw_request: Request):
    return await do_rerank(request, raw_request)
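For orientation, here is a minimal client-side sketch, not part of this commit, of how the `/score` and `/v2/rerank` routes above might be called once a server is running. The base URL, model name, and payload values are illustrative assumptions; the request fields mirror the `ScoreRequest` and `RerankRequest` models introduced in the new protocol module below.

# Hypothetical client calls against a locally running vLLM server; the
# base URL, model name, and payloads are assumptions for illustration.
import requests

BASE_URL = "http://localhost:8000"  # assumed server address

# /score: score text_1 against each entry in text_2 (mirrors ScoreRequest)
score_resp = requests.post(
    f"{BASE_URL}/score",
    json={
        "model": "example-cross-encoder",  # assumed model name
        "text_1": "What is the capital of France?",
        "text_2": ["Paris is the capital of France.", "Berlin is in Germany."],
    },
)
print(score_resp.json())

# /v2/rerank: rank documents against a query (mirrors RerankRequest)
rerank_resp = requests.post(
    f"{BASE_URL}/v2/rerank",
    json={
        "model": "example-cross-encoder",
        "query": "What is the capital of France?",
        "documents": ["Paris is the capital of France.", "Berlin is in Germany."],
        "top_n": 1,
    },
)
print(rerank_resp.json())

The deprecated `/v1/score` and `/v1/rerank` aliases forward to the same handlers and only add a deprecation warning.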
145 vllm/entrypoints/pooling/score/protocol.py Normal file
@ -0,0 +1,145 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from typing import Annotated, Any

from pydantic import (
    BaseModel,
    Field,
)

from vllm import PoolingParams
from vllm.config.pooler import get_use_activation
from vllm.entrypoints.openai.protocol import OpenAIBaseModel, UsageInfo
from vllm.entrypoints.score_utils import ScoreContentPartParam, ScoreMultiModalParam
from vllm.utils import random_uuid


class ScoreRequest(OpenAIBaseModel):
    model: str | None = None
    text_1: list[str] | str | ScoreMultiModalParam
    text_2: list[str] | str | ScoreMultiModalParam
    truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None

    # --8<-- [start:score-extra-params]

    mm_processor_kwargs: dict[str, Any] | None = Field(
        default=None,
        description=("Additional kwargs to pass to the HF processor."),
    )

    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )

    softmax: bool | None = Field(
        default=None,
        description="softmax will be deprecated, please use use_activation instead.",
    )

    activation: bool | None = Field(
        default=None,
        description="activation will be deprecated, please use use_activation instead.",
    )

    use_activation: bool | None = Field(
        default=None,
        description="Whether to use activation for classification outputs. "
        "Default is True.",
    )
    # --8<-- [end:score-extra-params]

    def to_pooling_params(self):
        return PoolingParams(
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            use_activation=get_use_activation(self),
        )


class RerankRequest(OpenAIBaseModel):
    model: str | None = None
    query: str | ScoreMultiModalParam
    documents: list[str] | ScoreMultiModalParam
    top_n: int = Field(default_factory=lambda: 0)
    truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None

    # --8<-- [start:rerank-extra-params]

    mm_processor_kwargs: dict[str, Any] | None = Field(
        default=None,
        description=("Additional kwargs to pass to the HF processor."),
    )

    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )

    softmax: bool | None = Field(
        default=None,
        description="softmax will be deprecated, please use use_activation instead.",
    )

    activation: bool | None = Field(
        default=None,
        description="activation will be deprecated, please use use_activation instead.",
    )

    use_activation: bool | None = Field(
        default=None,
        description="Whether to use activation for classification outputs. "
        "Default is True.",
    )
    # --8<-- [end:rerank-extra-params]

    def to_pooling_params(self):
        return PoolingParams(
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            use_activation=get_use_activation(self),
        )


class RerankDocument(BaseModel):
    text: str | None = None
    multi_modal: ScoreContentPartParam | None = None


class RerankResult(BaseModel):
    index: int
    document: RerankDocument
    relevance_score: float


class RerankUsage(BaseModel):
    total_tokens: int


class RerankResponse(OpenAIBaseModel):
    id: str
    model: str
    usage: RerankUsage
    results: list[RerankResult]


class ScoreResponseData(OpenAIBaseModel):
    index: int
    object: str = "score"
    score: float


class ScoreResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"embd-{random_uuid()}")
    object: str = "list"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    data: list[ScoreResponseData]
    usage: UsageInfo
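A small usage sketch, not part of the diff: after the server validates a JSON body into one of the request models above, `to_pooling_params` resolves the deprecated `softmax`/`activation` flags and `use_activation` through `get_use_activation` into a single `PoolingParams` object. The model name and field values below are assumptions for illustration.

# Illustrative only: build a ScoreRequest the way the server would after
# validating a request body, then derive the PoolingParams used for pooling.
from vllm.entrypoints.pooling.score.protocol import ScoreRequest

req = ScoreRequest(
    model="example-cross-encoder",  # assumed model name
    text_1="What is the capital of France?",
    text_2=["Paris is the capital of France.", "Berlin is in Germany."],
    truncate_prompt_tokens=512,
    use_activation=True,
)
params = req.to_pooling_params()
# params carries truncate_prompt_tokens=512 and use_activation=True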
@ -11,6 +11,11 @@ from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.protocol import (
     ErrorResponse,
+    UsageInfo,
+)
+from vllm.entrypoints.openai.serving_engine import OpenAIServing
+from vllm.entrypoints.openai.serving_models import OpenAIServingModels
+from vllm.entrypoints.pooling.score.protocol import (
     RerankDocument,
     RerankRequest,
     RerankResponse,
@ -19,10 +24,7 @@ from vllm.entrypoints.openai.protocol import (
     ScoreRequest,
     ScoreResponse,
     ScoreResponseData,
-    UsageInfo,
 )
-from vllm.entrypoints.openai.serving_engine import OpenAIServing
-from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.entrypoints.score_utils import (
     ScoreContentPartParam,
     ScoreMultiModalParam,
@ -1,7 +1,9 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import json
+from collections.abc import Awaitable, Callable
 from http import HTTPStatus
+from typing import Any

 import model_hosting_container_standards.sagemaker as sagemaker_standards
 import pydantic
@ -9,12 +11,56 @@ from fastapi import APIRouter, Depends, HTTPException, Request
 from fastapi.responses import JSONResponse, Response

 from vllm.entrypoints.openai.api_server import (
-    INVOCATION_VALIDATORS,
     base,
+    chat,
+    completion,
+    create_chat_completion,
+    create_completion,
     health,
     validate_json_request,
 )
-from vllm.entrypoints.openai.protocol import ErrorResponse
+from vllm.entrypoints.openai.protocol import (
+    ChatCompletionRequest,
+    CompletionRequest,
+    ErrorResponse,
+)
+from vllm.entrypoints.openai.serving_engine import OpenAIServing
+from vllm.entrypoints.pooling.classify.api_router import classify, create_classify
+from vllm.entrypoints.pooling.classify.protocol import ClassificationRequest
+from vllm.entrypoints.pooling.embed.api_router import create_embedding, embedding
+from vllm.entrypoints.pooling.embed.protocol import EmbeddingRequest
+from vllm.entrypoints.pooling.pooling.api_router import create_pooling, pooling
+from vllm.entrypoints.pooling.pooling.protocol import PoolingRequest
+from vllm.entrypoints.pooling.score.api_router import (
+    create_score,
+    do_rerank,
+    rerank,
+    score,
+)
+from vllm.entrypoints.pooling.score.protocol import RerankRequest, ScoreRequest
+
+# TODO: RequestType = TypeForm[BaseModel] when recognized by type checkers
+# (requires typing_extensions >= 4.13)
+RequestType = Any
+GetHandlerFn = Callable[[Request], OpenAIServing | None]
+EndpointFn = Callable[[RequestType, Request], Awaitable[Any]]
+
+# NOTE: Items defined earlier take higher priority
+INVOCATION_TYPES: list[tuple[RequestType, tuple[GetHandlerFn, EndpointFn]]] = [
+    (ChatCompletionRequest, (chat, create_chat_completion)),
+    (CompletionRequest, (completion, create_completion)),
+    (EmbeddingRequest, (embedding, create_embedding)),
+    (ClassificationRequest, (classify, create_classify)),
+    (ScoreRequest, (score, create_score)),
+    (RerankRequest, (rerank, do_rerank)),
+    (PoolingRequest, (pooling, create_pooling)),
+]
+
+# NOTE: Construct the TypeAdapters only once
+INVOCATION_VALIDATORS = [
+    (pydantic.TypeAdapter(request_type), (get_handler, endpoint))
+    for request_type, (get_handler, endpoint) in INVOCATION_TYPES
+]
+
+
 def register_sagemaker_routes(router: APIRouter):
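A simplified sketch, not the actual vLLM handler, of the dispatch pattern that `INVOCATION_TYPES`/`INVOCATION_VALIDATORS` above enable: each pre-built `pydantic.TypeAdapter` is tried in priority order, and the first request type that both parses and has a handler for the served model is routed to its endpoint. The function name and error handling here are assumptions.

# Hypothetical dispatcher illustrating the priority-ordered TypeAdapter
# pattern; the real SageMaker invocation route in vLLM may differ.
import pydantic
from fastapi import Request


async def dispatch_invocation(raw_request: Request, payload: dict):
    for validator, (get_handler, endpoint) in INVOCATION_VALIDATORS:
        try:
            request = validator.validate_python(payload)
        except pydantic.ValidationError:
            continue  # body does not match this request type; try the next one
        if get_handler(raw_request) is None:
            continue  # served model has no handler for this request type
        return await endpoint(request, raw_request)
    raise ValueError("Request body does not match any supported invocation type")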
@ -6,7 +6,7 @@ from collections.abc import AsyncGenerator, Sequence
 from typing import Any, Generic, TypeVar

 from vllm.config import VllmConfig
-from vllm.entrypoints.openai.protocol import IOProcessorResponse
+from vllm.entrypoints.pooling.pooling.protocol import IOProcessorResponse
 from vllm.inputs.data import PromptType
 from vllm.outputs import PoolingRequestOutput
 from vllm.pooling_params import PoolingParams