[Frontend] Resettle pooling entrypoints (#29634)
Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
parent 83805a6078, commit 62de4f4257
1 .github/CODEOWNERS (vendored)
@@ -149,6 +149,7 @@ mkdocs.yaml @hmellor
/examples/*/pooling/ @noooop
/tests/models/*/pooling* @noooop
/tests/entrypoints/pooling @noooop
/vllm/entrypoints/pooling @aarnphm @chaunceyjiang @noooop
/vllm/config/pooler.py @noooop
/vllm/pooling_params.py @noooop
/vllm/model_executor/layers/pooler.py @noooop
@@ -77,7 +77,7 @@ The `parse_request` method is used for validating the user prompt and converting

The `pre_process*` methods take the validated plugin input to generate vLLM's model prompts for regular inference.

The `post_process*` methods take `PoolingRequestOutput` objects as input and generate a custom plugin output.

The `validate_or_generate_params` method validates, via the plugin, any `SamplingParameters`/`PoolingParameters` received with the user request, or generates new ones if none are specified. It always returns the validated or generated parameters.

- The `output_to_response` method is used only for online serving and converts the plugin output to the `IOProcessorResponse` type that is then returned by the API Server. The implementation of the `/pooling` serving endpoint is available here [vllm/entrypoints/openai/serving_pooling.py](../../vllm/entrypoints/openai/serving_pooling.py).
+ The `output_to_response` method is used only for online serving and converts the plugin output to the `IOProcessorResponse` type that is then returned by the API Server. The implementation of the `/pooling` serving endpoint is available here [vllm/entrypoints/openai/serving_pooling.py](../../vllm/entrypoints/pooling/pooling/serving.py).

An example implementation of a plugin that enables generating geotiff images with the PrithviGeospatialMAE model is available [here](https://github.com/IBM/terratorch/tree/main/terratorch/vllm/plugins/segmentation). Please also refer to our online ([examples/online_serving/pooling/prithvi_geospatial_mae.py](../../examples/online_serving/pooling/prithvi_geospatial_mae.py)) and offline ([examples/offline_inference/pooling/prithvi_geospatial_mae_io_processor.py](../../examples/offline_inference/pooling/prithvi_geospatial_mae_io_processor.py)) inference examples.
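For orientation, a minimal plugin skeleton following the methods described above might look like the sketch below. The class and method bodies are illustrative assumptions, not the actual interface definition; see the IOProcessor plugin code under `vllm/plugins/io_processors` and the linked terratorch example for the authoritative signatures.

```python
# Illustrative sketch only: method names follow the description above, but the
# base class, signatures, and return types are assumptions, not vLLM's API.
from vllm.outputs import PoolingRequestOutput


class MyIOProcessorPlugin:  # hypothetical plugin
    def parse_request(self, request):
        # Validate the user prompt and convert it into the plugin input type.
        return request

    def pre_process(self, plugin_input):
        # Build vLLM model prompts for regular inference from the plugin input.
        return [{"prompt": "placeholder"}]

    def post_process(self, outputs: list[PoolingRequestOutput]):
        # Turn raw pooling outputs into the plugin-specific output type.
        return {"num_outputs": len(outputs)}

    def validate_or_generate_params(self, params=None):
        # Validate caller-supplied Sampling/PoolingParameters, or create defaults.
        return params if params is not None else {}

    def output_to_response(self, plugin_output):
        # Online serving only: wrap the plugin output for the API server
        # (stands in for building an IOProcessorResponse).
        return {"data": plugin_output}
```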
@@ -351,7 +351,7 @@ The following extra parameters are supported by default:

??? code

```python
- --8<-- "vllm/entrypoints/openai/protocol.py:embedding-extra-params"
+ --8<-- "vllm/entrypoints/pooling/embed/protocol.py:embedding-extra-params"
```
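As a usage sketch (the model name and server address are placeholders), these extra parameters can be passed through the OpenAI client's `extra_body`:

```python
# Hedged example: passes the vLLM-specific extra parameters shown above via
# extra_body; "my-embedding-model" and the base_url are placeholders.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

resp = client.embeddings.create(
    model="my-embedding-model",
    input="vLLM is great",
    extra_body={
        "priority": 0,           # only non-zero with priority scheduling
        "normalize": True,       # whether to normalize the embedding outputs
        "add_special_tokens": True,
    },
)
print(len(resp.data[0].embedding))
```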
For chat-like input (i.e. if `messages` is passed), these extra parameters are supported instead:

@@ -359,7 +359,7 @@ For chat-like input (i.e. if `messages` is passed), these extra parameters are supported instead:

??? code

```python
- --8<-- "vllm/entrypoints/openai/protocol.py:chat-embedding-extra-params"
+ --8<-- "vllm/entrypoints/pooling/embed/protocol.py:chat-embedding-extra-params"
```

### Transcriptions API

@@ -629,7 +629,7 @@ The following [pooling parameters][vllm.PoolingParams] are supported.

The following extra parameters are supported:

```python
- --8<-- "vllm/entrypoints/openai/protocol.py:classification-extra-params"
+ --8<-- "vllm/entrypoints/pooling/classify/protocol.py:classification-extra-params"
```

### Score API

@@ -834,7 +834,7 @@ The following [pooling parameters][vllm.PoolingParams] are supported.

The following extra parameters are supported:

```python
- --8<-- "vllm/entrypoints/openai/protocol.py:score-extra-params"
+ --8<-- "vllm/entrypoints/pooling/score/protocol.py:score-extra-params"
```
### Re-rank API

@@ -915,7 +915,7 @@ The following [pooling parameters][vllm.PoolingParams] are supported.

The following extra parameters are supported:

```python
- --8<-- "vllm/entrypoints/openai/protocol.py:rerank-extra-params"
+ --8<-- "vllm/entrypoints/pooling/score/protocol.py:rerank-extra-params"
```
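For illustration, a raw HTTP call to the `/rerank` endpoint (model name and documents are placeholders) can exercise these parameters in the same way:

```python
# Hedged example for the /rerank endpoint; the model name is a placeholder and
# the field names mirror RerankRequest as shown elsewhere in this diff.
import requests

payload = {
    "model": "my-reranker-model",
    "query": "What is vLLM?",
    "documents": ["vLLM is an inference engine.", "Bananas are yellow."],
    "top_n": 1,
    "priority": 0,  # extra parameter shown above
}
resp = requests.post("http://localhost:8000/rerank", json=payload)
print(resp.json()["results"][0]["relevance_score"])
```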
## Ray Serve LLM
@ -7,7 +7,7 @@ import tempfile
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.entrypoints.openai.protocol import BatchRequestOutput
|
||||
from vllm.entrypoints.openai.run_batch import BatchRequestOutput
|
||||
|
||||
MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM"
|
||||
|
||||
|
||||
@ -7,7 +7,8 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
from vllm.entrypoints.openai.protocol import ClassificationResponse, PoolingResponse
|
||||
from vllm.entrypoints.pooling.classify.protocol import ClassificationResponse
|
||||
from vllm.entrypoints.pooling.pooling.protocol import PoolingResponse
|
||||
|
||||
MODEL_NAME = "jason9693/Qwen2.5-1.5B-apeach"
|
||||
DTYPE = "float32" # Use float32 to avoid NaN issue
|
||||
|
||||
@ -7,7 +7,7 @@ import pytest
|
||||
import requests
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
from vllm.entrypoints.openai.protocol import ClassificationResponse
|
||||
from vllm.entrypoints.pooling.classify.protocol import ClassificationResponse
|
||||
|
||||
VLM_MODEL_NAME = "muziyongshixin/Qwen2.5-VL-7B-for-VideoCls"
|
||||
MAXIMUM_VIDEOS = 1
|
||||
|
||||
@ -15,10 +15,8 @@ import torch.nn.functional as F
|
||||
from tests.models.language.pooling.embed_utils import run_embedding_correctness_test
|
||||
from tests.models.utils import check_embeddings_close
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
EmbeddingResponse,
|
||||
PoolingResponse,
|
||||
)
|
||||
from vllm.entrypoints.pooling.embed.protocol import EmbeddingResponse
|
||||
from vllm.entrypoints.pooling.pooling.protocol import PoolingResponse
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
from vllm.utils.serial_utils import (
|
||||
|
||||
@ -11,7 +11,7 @@ from tests.conftest import HfRunner
|
||||
from tests.models.language.pooling.embed_utils import run_embedding_correctness_test
|
||||
from tests.models.utils import EmbedModelInfo
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
from vllm.entrypoints.openai.protocol import EmbeddingResponse
|
||||
from vllm.entrypoints.pooling.embed.protocol import EmbeddingResponse
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if current_platform.is_rocm():
|
||||
|
||||
@ -15,7 +15,7 @@ import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
from vllm.entrypoints.openai.protocol import EmbeddingResponse
|
||||
from vllm.entrypoints.pooling.embed.protocol import EmbeddingResponse
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if current_platform.is_rocm():
|
||||
|
||||
@ -8,7 +8,7 @@ import requests
|
||||
from transformers import AutoProcessor
|
||||
|
||||
from tests.utils import VLLM_PATH, RemoteOpenAIServer
|
||||
from vllm.entrypoints.openai.protocol import EmbeddingResponse
|
||||
from vllm.entrypoints.pooling.embed.protocol import EmbeddingResponse
|
||||
from vllm.multimodal.utils import encode_image_base64, fetch_image
|
||||
|
||||
MODEL_NAME = "TIGER-Lab/VLM2Vec-Full"
|
||||
|
||||
@ -11,7 +11,7 @@ import torch
|
||||
|
||||
from tests.models.utils import check_embeddings_close
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
from vllm.entrypoints.openai.protocol import PoolingResponse
|
||||
from vllm.entrypoints.pooling.pooling.protocol import PoolingResponse
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
from vllm.utils.serial_utils import (
|
||||
EMBED_DTYPE_TO_TORCH_DTYPE,
|
||||
|
||||
@ -7,7 +7,8 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
from vllm.entrypoints.openai.protocol import PoolingResponse, RerankResponse
|
||||
from vllm.entrypoints.pooling.pooling.protocol import PoolingResponse
|
||||
from vllm.entrypoints.pooling.score.protocol import RerankResponse
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if current_platform.is_rocm():
|
||||
|
||||
@ -9,7 +9,7 @@ import torch.nn.functional as F
|
||||
from torch import tensor
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
from vllm.entrypoints.openai.protocol import ScoreResponse
|
||||
from vllm.entrypoints.pooling.score.protocol import ScoreResponse
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if current_platform.is_rocm():
|
||||
|
||||
@ -18,7 +18,10 @@ from einops import rearrange
|
||||
from terratorch.datamodules import Sen1Floods11NonGeoDataModule
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.entrypoints.openai.protocol import IOProcessorRequest, IOProcessorResponse
|
||||
from vllm.entrypoints.pooling.pooling.protocol import (
|
||||
IOProcessorRequest,
|
||||
IOProcessorResponse,
|
||||
)
|
||||
from vllm.inputs.data import PromptType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import PoolingRequestOutput
|
||||
|
||||
@ -7,7 +7,7 @@ import requests
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.entrypoints.openai.protocol import IOProcessorResponse
|
||||
from vllm.entrypoints.pooling.pooling.protocol import IOProcessorResponse
|
||||
from vllm.plugins.io_processors import get_io_processor
|
||||
|
||||
MODEL_NAME = "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11"
|
||||
|
||||
@ -14,7 +14,7 @@ import socket
|
||||
import tempfile
|
||||
import uuid
|
||||
from argparse import Namespace
|
||||
from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Callable
|
||||
from collections.abc import AsyncGenerator, AsyncIterator, Awaitable
|
||||
from contextlib import asynccontextmanager
|
||||
from http import HTTPStatus
|
||||
from typing import Annotated, Any, Literal
|
||||
@ -54,29 +54,16 @@ from vllm.entrypoints.openai.orca_metrics import metrics_header
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ClassificationRequest,
|
||||
ClassificationResponse,
|
||||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
DetokenizeRequest,
|
||||
DetokenizeResponse,
|
||||
EmbeddingBytesResponse,
|
||||
EmbeddingRequest,
|
||||
EmbeddingResponse,
|
||||
ErrorInfo,
|
||||
ErrorResponse,
|
||||
GenerateRequest,
|
||||
GenerateResponse,
|
||||
IOProcessorResponse,
|
||||
PoolingBytesResponse,
|
||||
PoolingRequest,
|
||||
PoolingResponse,
|
||||
RerankRequest,
|
||||
RerankResponse,
|
||||
ResponsesRequest,
|
||||
ResponsesResponse,
|
||||
ScoreRequest,
|
||||
ScoreResponse,
|
||||
StreamingResponsesResponse,
|
||||
TokenizeRequest,
|
||||
TokenizeResponse,
|
||||
@ -86,17 +73,13 @@ from vllm.entrypoints.openai.protocol import (
|
||||
TranslationResponse,
|
||||
)
|
||||
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||
from vllm.entrypoints.openai.serving_classification import ServingClassification
|
||||
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
|
||||
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
|
||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
||||
from vllm.entrypoints.openai.serving_models import (
|
||||
BaseModelPath,
|
||||
OpenAIServingModels,
|
||||
)
|
||||
from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling
|
||||
from vllm.entrypoints.openai.serving_responses import OpenAIServingResponses
|
||||
from vllm.entrypoints.openai.serving_score import ServingScores
|
||||
from vllm.entrypoints.openai.serving_tokenization import OpenAIServingTokenization
|
||||
from vllm.entrypoints.openai.serving_tokens import ServingTokens
|
||||
from vllm.entrypoints.openai.serving_transcription import (
|
||||
@ -104,6 +87,11 @@ from vllm.entrypoints.openai.serving_transcription import (
|
||||
OpenAIServingTranslation,
|
||||
)
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
|
||||
from vllm.entrypoints.openai.utils import validate_json_request
|
||||
from vllm.entrypoints.pooling.classify.serving import ServingClassification
|
||||
from vllm.entrypoints.pooling.embed.serving import OpenAIServingEmbedding
|
||||
from vllm.entrypoints.pooling.pooling.serving import OpenAIServingPooling
|
||||
from vllm.entrypoints.pooling.score.serving import ServingScores
|
||||
from vllm.entrypoints.tool_server import DemoToolServer, MCPToolServer, ToolServer
|
||||
from vllm.entrypoints.utils import (
|
||||
cli_env_setup,
|
||||
@ -254,15 +242,6 @@ async def build_async_engine_client_from_engine_args(
|
||||
async_llm.shutdown()
|
||||
|
||||
|
||||
async def validate_json_request(raw_request: Request):
|
||||
content_type = raw_request.headers.get("content-type", "").lower()
|
||||
media_type = content_type.split(";", maxsplit=1)[0]
|
||||
if media_type != "application/json":
|
||||
raise RequestValidationError(
|
||||
errors=["Unsupported Media Type: Only 'application/json' is allowed"]
|
||||
)
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@ -324,26 +303,6 @@ def completion(request: Request) -> OpenAIServingCompletion | None:
|
||||
return request.app.state.openai_serving_completion
|
||||
|
||||
|
||||
def pooling(request: Request) -> OpenAIServingPooling | None:
|
||||
return request.app.state.openai_serving_pooling
|
||||
|
||||
|
||||
def embedding(request: Request) -> OpenAIServingEmbedding | None:
|
||||
return request.app.state.openai_serving_embedding
|
||||
|
||||
|
||||
def score(request: Request) -> ServingScores | None:
|
||||
return request.app.state.openai_serving_scores
|
||||
|
||||
|
||||
def classify(request: Request) -> ServingClassification | None:
|
||||
return request.app.state.openai_serving_classification
|
||||
|
||||
|
||||
def rerank(request: Request) -> ServingScores | None:
|
||||
return request.app.state.openai_serving_scores
|
||||
|
||||
|
||||
def tokenization(request: Request) -> OpenAIServingTokenization:
|
||||
return request.app.state.openai_serving_tokenization
|
||||
|
||||
@ -817,166 +776,6 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
|
||||
return StreamingResponse(content=generator, media_type="text/event-stream")
|
||||
|
||||
|
||||
@router.post(
|
||||
"/v1/embeddings",
|
||||
dependencies=[Depends(validate_json_request)],
|
||||
responses={
|
||||
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
|
||||
},
|
||||
)
|
||||
@with_cancellation
|
||||
@load_aware_call
|
||||
async def create_embedding(
|
||||
request: EmbeddingRequest,
|
||||
raw_request: Request,
|
||||
):
|
||||
handler = embedding(raw_request)
|
||||
if handler is None:
|
||||
return base(raw_request).create_error_response(
|
||||
message="The model does not support Embeddings API"
|
||||
)
|
||||
|
||||
try:
|
||||
generator = await handler.create_embedding(request, raw_request)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
|
||||
) from e
|
||||
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(
|
||||
content=generator.model_dump(), status_code=generator.error.code
|
||||
)
|
||||
elif isinstance(generator, EmbeddingResponse):
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
elif isinstance(generator, EmbeddingBytesResponse):
|
||||
return StreamingResponse(
|
||||
content=generator.body,
|
||||
headers={"metadata": generator.metadata},
|
||||
media_type=generator.media_type,
|
||||
)
|
||||
|
||||
assert_never(generator)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/pooling",
|
||||
dependencies=[Depends(validate_json_request)],
|
||||
responses={
|
||||
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
|
||||
},
|
||||
)
|
||||
@with_cancellation
|
||||
@load_aware_call
|
||||
async def create_pooling(request: PoolingRequest, raw_request: Request):
|
||||
handler = pooling(raw_request)
|
||||
if handler is None:
|
||||
return base(raw_request).create_error_response(
|
||||
message="The model does not support Pooling API"
|
||||
)
|
||||
try:
|
||||
generator = await handler.create_pooling(request, raw_request)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
|
||||
) from e
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(
|
||||
content=generator.model_dump(), status_code=generator.error.code
|
||||
)
|
||||
elif isinstance(generator, (PoolingResponse, IOProcessorResponse)):
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
elif isinstance(generator, PoolingBytesResponse):
|
||||
return StreamingResponse(
|
||||
content=generator.body,
|
||||
headers={"metadata": generator.metadata},
|
||||
media_type=generator.media_type,
|
||||
)
|
||||
|
||||
assert_never(generator)
|
||||
|
||||
|
||||
@router.post("/classify", dependencies=[Depends(validate_json_request)])
|
||||
@with_cancellation
|
||||
@load_aware_call
|
||||
async def create_classify(request: ClassificationRequest, raw_request: Request):
|
||||
handler = classify(raw_request)
|
||||
if handler is None:
|
||||
return base(raw_request).create_error_response(
|
||||
message="The model does not support Classification API"
|
||||
)
|
||||
|
||||
try:
|
||||
generator = await handler.create_classify(request, raw_request)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
|
||||
) from e
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(
|
||||
content=generator.model_dump(), status_code=generator.error.code
|
||||
)
|
||||
|
||||
elif isinstance(generator, ClassificationResponse):
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
|
||||
assert_never(generator)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/score",
|
||||
dependencies=[Depends(validate_json_request)],
|
||||
responses={
|
||||
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
|
||||
},
|
||||
)
|
||||
@with_cancellation
|
||||
@load_aware_call
|
||||
async def create_score(request: ScoreRequest, raw_request: Request):
|
||||
handler = score(raw_request)
|
||||
if handler is None:
|
||||
return base(raw_request).create_error_response(
|
||||
message="The model does not support Score API"
|
||||
)
|
||||
|
||||
try:
|
||||
generator = await handler.create_score(request, raw_request)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
|
||||
) from e
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(
|
||||
content=generator.model_dump(), status_code=generator.error.code
|
||||
)
|
||||
elif isinstance(generator, ScoreResponse):
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
|
||||
assert_never(generator)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/v1/score",
|
||||
dependencies=[Depends(validate_json_request)],
|
||||
responses={
|
||||
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
|
||||
},
|
||||
)
|
||||
@with_cancellation
|
||||
@load_aware_call
|
||||
async def create_score_v1(request: ScoreRequest, raw_request: Request):
|
||||
logger.warning(
|
||||
"To indicate that Score API is not part of standard OpenAI API, we "
|
||||
"have moved it to `/score`. Please update your client accordingly."
|
||||
)
|
||||
|
||||
return await create_score(request, raw_request)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/v1/audio/transcriptions",
|
||||
responses={
|
||||
@ -1055,70 +854,6 @@ async def create_translations(
|
||||
return StreamingResponse(content=generator, media_type="text/event-stream")
|
||||
|
||||
|
||||
@router.post(
|
||||
"/rerank",
|
||||
dependencies=[Depends(validate_json_request)],
|
||||
responses={
|
||||
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
|
||||
},
|
||||
)
|
||||
@with_cancellation
|
||||
@load_aware_call
|
||||
async def do_rerank(request: RerankRequest, raw_request: Request):
|
||||
handler = rerank(raw_request)
|
||||
if handler is None:
|
||||
return base(raw_request).create_error_response(
|
||||
message="The model does not support Rerank (Score) API"
|
||||
)
|
||||
try:
|
||||
generator = await handler.do_rerank(request, raw_request)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
|
||||
) from e
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(
|
||||
content=generator.model_dump(), status_code=generator.error.code
|
||||
)
|
||||
elif isinstance(generator, RerankResponse):
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
|
||||
assert_never(generator)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/v1/rerank",
|
||||
dependencies=[Depends(validate_json_request)],
|
||||
responses={
|
||||
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
|
||||
},
|
||||
)
|
||||
@with_cancellation
|
||||
async def do_rerank_v1(request: RerankRequest, raw_request: Request):
|
||||
logger.warning_once(
|
||||
"To indicate that the rerank API is not part of the standard OpenAI"
|
||||
" API, we have located it at `/rerank`. Please update your client "
|
||||
"accordingly. (Note: Conforms to JinaAI rerank API)"
|
||||
)
|
||||
|
||||
return await do_rerank(request, raw_request)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/v2/rerank",
|
||||
dependencies=[Depends(validate_json_request)],
|
||||
responses={
|
||||
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
|
||||
},
|
||||
)
|
||||
@with_cancellation
|
||||
async def do_rerank_v2(request: RerankRequest, raw_request: Request):
|
||||
return await do_rerank(request, raw_request)
|
||||
|
||||
|
||||
if envs.VLLM_SERVER_DEV_MODE:
|
||||
logger.warning(
|
||||
"SECURITY WARNING: Development endpoints are enabled! "
|
||||
@ -1285,30 +1020,6 @@ async def is_scaling_elastic_ep(raw_request: Request):
|
||||
return JSONResponse({"is_scaling_elastic_ep": _scaling_elastic_ep})
|
||||
|
||||
|
||||
# TODO: RequestType = TypeForm[BaseModel] when recognized by type checkers
|
||||
# (requires typing_extensions >= 4.13)
|
||||
RequestType = Any
|
||||
GetHandlerFn = Callable[[Request], OpenAIServing | None]
|
||||
EndpointFn = Callable[[RequestType, Request], Awaitable[Any]]
|
||||
|
||||
# NOTE: Items defined earlier take higher priority
|
||||
INVOCATION_TYPES: list[tuple[RequestType, tuple[GetHandlerFn, EndpointFn]]] = [
|
||||
(ChatCompletionRequest, (chat, create_chat_completion)),
|
||||
(CompletionRequest, (completion, create_completion)),
|
||||
(EmbeddingRequest, (embedding, create_embedding)),
|
||||
(ClassificationRequest, (classify, create_classify)),
|
||||
(ScoreRequest, (score, create_score)),
|
||||
(RerankRequest, (rerank, do_rerank)),
|
||||
(PoolingRequest, (pooling, create_pooling)),
|
||||
]
|
||||
|
||||
# NOTE: Construct the TypeAdapters only once
|
||||
INVOCATION_VALIDATORS = [
|
||||
(pydantic.TypeAdapter(request_type), (get_handler, endpoint))
|
||||
for request_type, (get_handler, endpoint) in INVOCATION_TYPES
|
||||
]
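The notes above imply a first-match dispatch; a sketch of how such a validator list could be consumed (not the actual endpoint code, which is outside this hunk) is:

```python
# Sketch only: first-match dispatch over INVOCATION_VALIDATORS; earlier items
# take priority, mismatches fall through to the next validator.
def resolve_invocation(payload: dict):
    for validator, (get_handler, endpoint) in INVOCATION_VALIDATORS:
        try:
            request = validator.validate_python(payload)
        except pydantic.ValidationError:
            continue
        return request, get_handler, endpoint
    raise ValueError("Request does not match any supported invocation type")
```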
|
||||
|
||||
|
||||
@router.post(
|
||||
"/inference/v1/generate",
|
||||
dependencies=[Depends(validate_json_request)],
|
||||
@ -1653,12 +1364,16 @@ def build_app(args: Namespace) -> FastAPI:
|
||||
from vllm.entrypoints.sagemaker.routes import register_sagemaker_routes
|
||||
|
||||
register_sagemaker_routes(router)
|
||||
|
||||
app.include_router(router)
|
||||
|
||||
app.root_path = args.root_path
|
||||
|
||||
mount_metrics(app)
|
||||
|
||||
from vllm.entrypoints.pooling import register_pooling_api_routers
|
||||
|
||||
register_pooling_api_routers(app)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=args.allowed_origins,
|
||||
|
||||
@ -6,7 +6,7 @@
|
||||
import json
|
||||
import time
|
||||
from http import HTTPStatus
|
||||
from typing import Annotated, Any, ClassVar, Generic, Literal, TypeAlias, TypeVar
|
||||
from typing import Annotated, Any, ClassVar, Literal, TypeAlias
|
||||
|
||||
import regex as re
|
||||
import torch
|
||||
@ -48,14 +48,6 @@ from openai.types.responses.response_reasoning_item import (
|
||||
)
|
||||
from openai_harmony import Message as OpenAIHarmonyMessage
|
||||
|
||||
from vllm.config.pooler import get_use_activation
|
||||
from vllm.tasks import PoolingTask
|
||||
from vllm.utils.serial_utils import (
|
||||
EmbedDType,
|
||||
EncodingFormat,
|
||||
Endianness,
|
||||
)
|
||||
|
||||
# Backward compatibility for OpenAI client versions
|
||||
try: # For older openai versions (< 1.100.0)
|
||||
from openai.types.responses import ResponseTextConfig
|
||||
@ -70,19 +62,14 @@ from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
TypeAdapter,
|
||||
ValidationError,
|
||||
ValidationInfo,
|
||||
field_serializer,
|
||||
field_validator,
|
||||
model_validator,
|
||||
)
|
||||
|
||||
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam, make_tool_call_id
|
||||
from vllm.entrypoints.score_utils import ScoreContentPartParam, ScoreMultiModalParam
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logprobs import Logprob
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import (
|
||||
BeamSearchParams,
|
||||
RequestOutputKind,
|
||||
@ -1345,401 +1332,6 @@ class CompletionRequest(OpenAIBaseModel):
|
||||
return data
|
||||
|
||||
|
||||
class EmbeddingCompletionRequest(OpenAIBaseModel):
|
||||
# Ordered by official OpenAI API documentation
|
||||
# https://platform.openai.com/docs/api-reference/embeddings
|
||||
model: str | None = None
|
||||
input: list[int] | list[list[int]] | str | list[str]
|
||||
encoding_format: EncodingFormat = "float"
|
||||
dimensions: int | None = None
|
||||
user: str | None = None
|
||||
truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
|
||||
|
||||
# --8<-- [start:embedding-extra-params]
|
||||
add_special_tokens: bool = Field(
|
||||
default=True,
|
||||
description=(
|
||||
"If true (the default), special tokens (e.g. BOS) will be added to "
|
||||
"the prompt."
|
||||
),
|
||||
)
|
||||
priority: int = Field(
|
||||
default=0,
|
||||
description=(
|
||||
"The priority of the request (lower means earlier handling; "
|
||||
"default: 0). Any priority other than 0 will raise an error "
|
||||
"if the served model does not use priority scheduling."
|
||||
),
|
||||
)
|
||||
request_id: str = Field(
|
||||
default_factory=random_uuid,
|
||||
description=(
|
||||
"The request_id related to this request. If the caller does "
|
||||
"not set it, a random_uuid will be generated. This id is used "
|
||||
"through out the inference process and return in response."
|
||||
),
|
||||
)
|
||||
normalize: bool | None = Field(
|
||||
default=None,
|
||||
description="Whether to normalize the embeddings outputs. Default is True.",
|
||||
)
|
||||
embed_dtype: EmbedDType = Field(
|
||||
default="float32",
|
||||
description=(
|
||||
"What dtype to use for encoding. Default to using float32 for base64 "
|
||||
"encoding to match the OpenAI python client behavior. "
|
||||
"This parameter will affect base64 and binary_response."
|
||||
),
|
||||
)
|
||||
endianness: Endianness = Field(
|
||||
default="native",
|
||||
description=(
|
||||
"What endianness to use for encoding. Default to using native for "
|
||||
"base64 encoding to match the OpenAI python client behavior."
|
||||
"This parameter will affect base64 and binary_response."
|
||||
),
|
||||
)
|
||||
# --8<-- [end:embedding-extra-params]
|
||||
|
||||
def to_pooling_params(self):
|
||||
return PoolingParams(
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
dimensions=self.dimensions,
|
||||
normalize=self.normalize,
|
||||
)
|
||||
|
||||
|
||||
class EmbeddingChatRequest(OpenAIBaseModel):
|
||||
model: str | None = None
|
||||
messages: list[ChatCompletionMessageParam]
|
||||
|
||||
encoding_format: EncodingFormat = "float"
|
||||
dimensions: int | None = None
|
||||
user: str | None = None
|
||||
truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
|
||||
|
||||
# --8<-- [start:chat-embedding-extra-params]
|
||||
add_generation_prompt: bool = Field(
|
||||
default=False,
|
||||
description=(
|
||||
"If true, the generation prompt will be added to the chat template. "
|
||||
"This is a parameter used by chat template in tokenizer config of the "
|
||||
"model."
|
||||
),
|
||||
)
|
||||
|
||||
add_special_tokens: bool = Field(
|
||||
default=False,
|
||||
description=(
|
||||
"If true, special tokens (e.g. BOS) will be added to the prompt "
|
||||
"on top of what is added by the chat template. "
|
||||
"For most models, the chat template takes care of adding the "
|
||||
"special tokens so this should be set to false (as is the "
|
||||
"default)."
|
||||
),
|
||||
)
|
||||
chat_template: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"A Jinja template to use for this conversion. "
|
||||
"As of transformers v4.44, default chat template is no longer "
|
||||
"allowed, so you must provide a chat template if the tokenizer "
|
||||
"does not define one."
|
||||
),
|
||||
)
|
||||
chat_template_kwargs: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Additional keyword args to pass to the template renderer. "
|
||||
"Will be accessible by the chat template."
|
||||
),
|
||||
)
|
||||
mm_processor_kwargs: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description=("Additional kwargs to pass to the HF processor."),
|
||||
)
|
||||
priority: int = Field(
|
||||
default=0,
|
||||
description=(
|
||||
"The priority of the request (lower means earlier handling; "
|
||||
"default: 0). Any priority other than 0 will raise an error "
|
||||
"if the served model does not use priority scheduling."
|
||||
),
|
||||
)
|
||||
request_id: str = Field(
|
||||
default_factory=random_uuid,
|
||||
description=(
|
||||
"The request_id related to this request. If the caller does "
|
||||
"not set it, a random_uuid will be generated. This id is used "
|
||||
"through out the inference process and return in response."
|
||||
),
|
||||
)
|
||||
normalize: bool | None = Field(
|
||||
default=None,
|
||||
description="Whether to normalize the embeddings outputs. Default is True.",
|
||||
)
|
||||
embed_dtype: EmbedDType = Field(
|
||||
default="float32",
|
||||
description=(
|
||||
"What dtype to use for encoding. Default to using float32 for base64 "
|
||||
"encoding to match the OpenAI python client behavior. "
|
||||
"This parameter will affect base64 and binary_response."
|
||||
),
|
||||
)
|
||||
endianness: Endianness = Field(
|
||||
default="native",
|
||||
description=(
|
||||
"What endianness to use for encoding. Default to using native for "
|
||||
"base64 encoding to match the OpenAI python client behavior."
|
||||
"This parameter will affect base64 and binary_response."
|
||||
),
|
||||
)
|
||||
# --8<-- [end:chat-embedding-extra-params]
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_generation_prompt(cls, data):
|
||||
if data.get("continue_final_message") and data.get("add_generation_prompt"):
|
||||
raise ValueError(
|
||||
"Cannot set both `continue_final_message` and "
|
||||
"`add_generation_prompt` to True."
|
||||
)
|
||||
return data
|
||||
|
||||
def to_pooling_params(self):
|
||||
return PoolingParams(
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
dimensions=self.dimensions,
|
||||
normalize=self.normalize,
|
||||
)
|
||||
|
||||
|
||||
EmbeddingRequest: TypeAlias = EmbeddingCompletionRequest | EmbeddingChatRequest
|
||||
|
||||
|
||||
class PoolingCompletionRequest(EmbeddingCompletionRequest):
|
||||
task: PoolingTask | None = None
|
||||
softmax: bool | None = Field(
|
||||
default=None,
|
||||
description="softmax will be deprecated, please use use_activation instead.",
|
||||
)
|
||||
activation: bool | None = Field(
|
||||
default=None,
|
||||
description="activation will be deprecated, please use use_activation instead.",
|
||||
)
|
||||
use_activation: bool | None = Field(
|
||||
default=None,
|
||||
description="Whether to use activation for classification outputs. "
|
||||
"If it is a classify or token_classify task, the default is True; "
|
||||
"for other tasks, this value should be None.",
|
||||
)
|
||||
|
||||
def to_pooling_params(self):
|
||||
return PoolingParams(
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
dimensions=self.dimensions,
|
||||
normalize=self.normalize,
|
||||
use_activation=get_use_activation(self),
|
||||
)
|
||||
|
||||
|
||||
class PoolingChatRequest(EmbeddingChatRequest):
|
||||
task: PoolingTask | None = None
|
||||
softmax: bool | None = Field(
|
||||
default=None,
|
||||
description="softmax will be deprecated, please use use_activation instead.",
|
||||
)
|
||||
activation: bool | None = Field(
|
||||
default=None,
|
||||
description="activation will be deprecated, please use use_activation instead.",
|
||||
)
|
||||
use_activation: bool | None = Field(
|
||||
default=None,
|
||||
description="Whether to use activation for classification outputs. "
|
||||
"If it is a classify or token_classify task, the default is True; "
|
||||
"for other tasks, this value should be None.",
|
||||
)
|
||||
|
||||
def to_pooling_params(self):
|
||||
return PoolingParams(
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
dimensions=self.dimensions,
|
||||
normalize=self.normalize,
|
||||
use_activation=get_use_activation(self),
|
||||
)
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class IOProcessorRequest(OpenAIBaseModel, Generic[T]):
|
||||
model: str | None = None
|
||||
|
||||
priority: int = Field(default=0)
|
||||
"""
|
||||
The priority of the request (lower means earlier handling;
|
||||
default: 0). Any priority other than 0 will raise an error
|
||||
if the served model does not use priority scheduling.
|
||||
"""
|
||||
data: T
|
||||
|
||||
task: PoolingTask = "plugin"
|
||||
encoding_format: EncodingFormat = "float"
|
||||
embed_dtype: EmbedDType = Field(
|
||||
default="float32",
|
||||
description=(
|
||||
"What dtype to use for encoding. Default to using float32 for base64 "
|
||||
"encoding to match the OpenAI python client behavior. "
|
||||
"This parameter will affect base64 and binary_response."
|
||||
),
|
||||
)
|
||||
endianness: Endianness = Field(
|
||||
default="native",
|
||||
description=(
|
||||
"What endianness to use for encoding. Default to using native for "
|
||||
"base64 encoding to match the OpenAI python client behavior."
|
||||
"This parameter will affect base64 and binary_response."
|
||||
),
|
||||
)
|
||||
|
||||
def to_pooling_params(self):
|
||||
return PoolingParams()
|
||||
|
||||
|
||||
class IOProcessorResponse(OpenAIBaseModel, Generic[T]):
|
||||
request_id: str | None = None
|
||||
"""
|
||||
The request_id associated with this response
|
||||
"""
|
||||
created_at: int = Field(default_factory=lambda: int(time.time()))
|
||||
|
||||
data: T
|
||||
"""
|
||||
When using plugins IOProcessor plugins, the actual output is generated
|
||||
by the plugin itself. Hence, we use a generic type for the response data
|
||||
"""
|
||||
|
||||
|
||||
PoolingRequest: TypeAlias = (
|
||||
PoolingCompletionRequest | PoolingChatRequest | IOProcessorRequest
|
||||
)
|
||||
|
||||
|
||||
class ScoreRequest(OpenAIBaseModel):
|
||||
model: str | None = None
|
||||
text_1: list[str] | str | ScoreMultiModalParam
|
||||
text_2: list[str] | str | ScoreMultiModalParam
|
||||
truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
|
||||
|
||||
# --8<-- [start:score-extra-params]
|
||||
|
||||
mm_processor_kwargs: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description=("Additional kwargs to pass to the HF processor."),
|
||||
)
|
||||
|
||||
priority: int = Field(
|
||||
default=0,
|
||||
description=(
|
||||
"The priority of the request (lower means earlier handling; "
|
||||
"default: 0). Any priority other than 0 will raise an error "
|
||||
"if the served model does not use priority scheduling."
|
||||
),
|
||||
)
|
||||
|
||||
softmax: bool | None = Field(
|
||||
default=None,
|
||||
description="softmax will be deprecated, please use use_activation instead.",
|
||||
)
|
||||
|
||||
activation: bool | None = Field(
|
||||
default=None,
|
||||
description="activation will be deprecated, please use use_activation instead.",
|
||||
)
|
||||
|
||||
use_activation: bool | None = Field(
|
||||
default=None,
|
||||
description="Whether to use activation for classification outputs. "
|
||||
"Default is True.",
|
||||
)
|
||||
# --8<-- [end:score-extra-params]
|
||||
|
||||
def to_pooling_params(self):
|
||||
return PoolingParams(
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
use_activation=get_use_activation(self),
|
||||
)
|
||||
|
||||
|
||||
class RerankRequest(OpenAIBaseModel):
|
||||
model: str | None = None
|
||||
query: str | ScoreMultiModalParam
|
||||
documents: list[str] | ScoreMultiModalParam
|
||||
top_n: int = Field(default_factory=lambda: 0)
|
||||
truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
|
||||
|
||||
# --8<-- [start:rerank-extra-params]
|
||||
|
||||
mm_processor_kwargs: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description=("Additional kwargs to pass to the HF processor."),
|
||||
)
|
||||
|
||||
priority: int = Field(
|
||||
default=0,
|
||||
description=(
|
||||
"The priority of the request (lower means earlier handling; "
|
||||
"default: 0). Any priority other than 0 will raise an error "
|
||||
"if the served model does not use priority scheduling."
|
||||
),
|
||||
)
|
||||
|
||||
softmax: bool | None = Field(
|
||||
default=None,
|
||||
description="softmax will be deprecated, please use use_activation instead.",
|
||||
)
|
||||
|
||||
activation: bool | None = Field(
|
||||
default=None,
|
||||
description="activation will be deprecated, please use use_activation instead.",
|
||||
)
|
||||
|
||||
use_activation: bool | None = Field(
|
||||
default=None,
|
||||
description="Whether to use activation for classification outputs. "
|
||||
"Default is True.",
|
||||
)
|
||||
# --8<-- [end:rerank-extra-params]
|
||||
|
||||
def to_pooling_params(self):
|
||||
return PoolingParams(
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
use_activation=get_use_activation(self),
|
||||
)
|
||||
|
||||
|
||||
class RerankDocument(BaseModel):
|
||||
text: str | None = None
|
||||
multi_modal: ScoreContentPartParam | None = None
|
||||
|
||||
|
||||
class RerankResult(BaseModel):
|
||||
index: int
|
||||
document: RerankDocument
|
||||
relevance_score: float
|
||||
|
||||
|
||||
class RerankUsage(BaseModel):
|
||||
total_tokens: int
|
||||
|
||||
|
||||
class RerankResponse(OpenAIBaseModel):
|
||||
id: str
|
||||
model: str
|
||||
usage: RerankUsage
|
||||
results: list[RerankResult]
|
||||
|
||||
|
||||
class CompletionLogProbs(OpenAIBaseModel):
|
||||
text_offset: list[int] = Field(default_factory=list)
|
||||
token_logprobs: list[float | None] = Field(default_factory=list)
|
||||
@ -1809,229 +1401,6 @@ class CompletionStreamResponse(OpenAIBaseModel):
|
||||
usage: UsageInfo | None = Field(default=None)
|
||||
|
||||
|
||||
class EmbeddingResponseData(OpenAIBaseModel):
|
||||
index: int
|
||||
object: str = "embedding"
|
||||
embedding: list[float] | str
|
||||
|
||||
|
||||
class EmbeddingResponse(OpenAIBaseModel):
|
||||
id: str = Field(default_factory=lambda: f"embd-{random_uuid()}")
|
||||
object: str = "list"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
model: str
|
||||
data: list[EmbeddingResponseData]
|
||||
usage: UsageInfo
|
||||
|
||||
|
||||
class EmbeddingBytesResponse(OpenAIBaseModel):
|
||||
body: list[bytes]
|
||||
metadata: str
|
||||
media_type: str = "application/octet-stream"
|
||||
|
||||
|
||||
class PoolingResponseData(OpenAIBaseModel):
|
||||
index: int
|
||||
object: str = "pooling"
|
||||
data: list[list[float]] | list[float] | str
|
||||
|
||||
|
||||
class PoolingResponse(OpenAIBaseModel):
|
||||
id: str = Field(default_factory=lambda: f"pool-{random_uuid()}")
|
||||
object: str = "list"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
model: str
|
||||
data: list[PoolingResponseData]
|
||||
usage: UsageInfo
|
||||
|
||||
|
||||
class PoolingBytesResponse(OpenAIBaseModel):
|
||||
body: list[bytes]
|
||||
metadata: str
|
||||
media_type: str = "application/octet-stream"
|
||||
|
||||
|
||||
class ScoreResponseData(OpenAIBaseModel):
|
||||
index: int
|
||||
object: str = "score"
|
||||
score: float
|
||||
|
||||
|
||||
class ScoreResponse(OpenAIBaseModel):
|
||||
id: str = Field(default_factory=lambda: f"embd-{random_uuid()}")
|
||||
object: str = "list"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
model: str
|
||||
data: list[ScoreResponseData]
|
||||
usage: UsageInfo
|
||||
|
||||
|
||||
class ClassificationCompletionRequest(OpenAIBaseModel):
|
||||
model: str | None = None
|
||||
input: list[str] | str
|
||||
truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
|
||||
user: str | None = None
|
||||
|
||||
# --8<-- [start:classification-extra-params]
|
||||
priority: int = Field(
|
||||
default=0,
|
||||
description=(
|
||||
"The priority of the request (lower means earlier handling; "
|
||||
"default: 0). Any priority other than 0 will raise an error "
|
||||
"if the served model does not use priority scheduling."
|
||||
),
|
||||
)
|
||||
add_special_tokens: bool = Field(
|
||||
default=True,
|
||||
description=(
|
||||
"If true (the default), special tokens (e.g. BOS) will be added to "
|
||||
"the prompt."
|
||||
),
|
||||
)
|
||||
request_id: str = Field(
|
||||
default_factory=random_uuid,
|
||||
description=(
|
||||
"The request_id related to this request. If the caller does "
|
||||
"not set it, a random_uuid will be generated. This id is used "
|
||||
"through out the inference process and return in response."
|
||||
),
|
||||
)
|
||||
softmax: bool | None = Field(
|
||||
default=None,
|
||||
description="softmax will be deprecated, please use use_activation instead.",
|
||||
)
|
||||
|
||||
activation: bool | None = Field(
|
||||
default=None,
|
||||
description="activation will be deprecated, please use use_activation instead.",
|
||||
)
|
||||
|
||||
use_activation: bool | None = Field(
|
||||
default=None,
|
||||
description="Whether to use activation for classification outputs. "
|
||||
"Default is True.",
|
||||
)
|
||||
# --8<-- [end:classification-extra-params]
|
||||
|
||||
def to_pooling_params(self):
|
||||
return PoolingParams(
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
use_activation=get_use_activation(self),
|
||||
)
|
||||
|
||||
|
||||
class ClassificationChatRequest(OpenAIBaseModel):
|
||||
model: str | None = None
|
||||
messages: list[ChatCompletionMessageParam]
|
||||
truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
|
||||
user: str | None = None
|
||||
|
||||
# --8<-- [start:chat-classification-extra-params]
|
||||
add_generation_prompt: bool = Field(
|
||||
default=False,
|
||||
description=(
|
||||
"If true, the generation prompt will be added to the chat template. "
|
||||
"This is a parameter used by chat template in tokenizer config of the "
|
||||
"model."
|
||||
),
|
||||
)
|
||||
|
||||
add_special_tokens: bool = Field(
|
||||
default=False,
|
||||
description=(
|
||||
"If true, special tokens (e.g. BOS) will be added to the prompt "
|
||||
"on top of what is added by the chat template. "
|
||||
"For most models, the chat template takes care of adding the "
|
||||
"special tokens so this should be set to false (as is the "
|
||||
"default)."
|
||||
),
|
||||
)
|
||||
|
||||
chat_template: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"A Jinja template to use for this conversion. "
|
||||
"As of transformers v4.44, default chat template is no longer "
|
||||
"allowed, so you must provide a chat template if the tokenizer "
|
||||
"does not define one."
|
||||
),
|
||||
)
|
||||
|
||||
chat_template_kwargs: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Additional keyword args to pass to the template renderer. "
|
||||
"Will be accessible by the chat template."
|
||||
),
|
||||
)
|
||||
|
||||
mm_processor_kwargs: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description=("Additional kwargs to pass to the HF processor."),
|
||||
)
|
||||
|
||||
priority: int = Field(
|
||||
default=0,
|
||||
description=(
|
||||
"The priority of the request (lower means earlier handling; "
|
||||
"default: 0). Any priority other than 0 will raise an error "
|
||||
"if the served model does not use priority scheduling."
|
||||
),
|
||||
)
|
||||
|
||||
request_id: str = Field(
|
||||
default_factory=random_uuid,
|
||||
description=(
|
||||
"The request_id related to this request. If the caller does "
|
||||
"not set it, a random_uuid will be generated. This id is used "
|
||||
"through out the inference process and return in response."
|
||||
),
|
||||
)
|
||||
softmax: bool | None = Field(
|
||||
default=None,
|
||||
description="softmax will be deprecated, please use use_activation instead.",
|
||||
)
|
||||
|
||||
activation: bool | None = Field(
|
||||
default=None,
|
||||
description="activation will be deprecated, please use use_activation instead.",
|
||||
)
|
||||
|
||||
use_activation: bool | None = Field(
|
||||
default=None,
|
||||
description="Whether to use activation for classification outputs. "
|
||||
"Default is True.",
|
||||
)
|
||||
# --8<-- [end:chat-classification-extra-params]
|
||||
|
||||
def to_pooling_params(self):
|
||||
return PoolingParams(
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
use_activation=get_use_activation(self),
|
||||
)
|
||||
|
||||
|
||||
ClassificationRequest: TypeAlias = (
|
||||
ClassificationCompletionRequest | ClassificationChatRequest
|
||||
)
|
||||
|
||||
|
||||
class ClassificationData(OpenAIBaseModel):
|
||||
index: int
|
||||
label: str | None
|
||||
probs: list[float]
|
||||
num_classes: int
|
||||
|
||||
|
||||
class ClassificationResponse(OpenAIBaseModel):
|
||||
id: str = Field(default_factory=lambda: f"classify-{random_uuid()}")
|
||||
object: str = "list"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
model: str
|
||||
data: list[ClassificationData]
|
||||
usage: UsageInfo
|
||||
|
||||
|
||||
class FunctionCall(OpenAIBaseModel):
|
||||
name: str
|
||||
arguments: str
|
||||
@ -2409,83 +1778,6 @@ StreamingResponsesResponse: TypeAlias = (
|
||||
| ResponseCodeInterpreterCallCompletedEvent
|
||||
)
|
||||
|
||||
BatchRequestInputBody: TypeAlias = (
|
||||
ChatCompletionRequest | EmbeddingRequest | ScoreRequest | RerankRequest
|
||||
)
|
||||
|
||||
|
||||
class BatchRequestInput(OpenAIBaseModel):
|
||||
"""
|
||||
The per-line object of the batch input file.
|
||||
|
||||
NOTE: Currently only the `/v1/chat/completions` endpoint is supported.
|
||||
"""
|
||||
|
||||
# A developer-provided per-request id that will be used to match outputs to
|
||||
# inputs. Must be unique for each request in a batch.
|
||||
custom_id: str
|
||||
|
||||
# The HTTP method to be used for the request. Currently only POST is
|
||||
# supported.
|
||||
method: str
|
||||
|
||||
# The OpenAI API relative URL to be used for the request. Currently
|
||||
# /v1/chat/completions is supported.
|
||||
url: str
|
||||
|
||||
# The parameters of the request.
|
||||
body: BatchRequestInputBody
|
||||
|
||||
@field_validator("body", mode="plain")
|
||||
@classmethod
|
||||
def check_type_for_url(cls, value: Any, info: ValidationInfo):
|
||||
# Use url to disambiguate models
|
||||
url: str = info.data["url"]
|
||||
if url == "/v1/chat/completions":
|
||||
return ChatCompletionRequest.model_validate(value)
|
||||
if url == "/v1/embeddings":
|
||||
return TypeAdapter(EmbeddingRequest).validate_python(value)
|
||||
if url.endswith("/score"):
|
||||
return ScoreRequest.model_validate(value)
|
||||
if url.endswith("/rerank"):
|
||||
return RerankRequest.model_validate(value)
|
||||
return TypeAdapter(BatchRequestInputBody).validate_python(value)
|
||||
|
||||
|
||||
class BatchResponseData(OpenAIBaseModel):
|
||||
# HTTP status code of the response.
|
||||
status_code: int = 200
|
||||
|
||||
# An unique identifier for the API request.
|
||||
request_id: str
|
||||
|
||||
# The body of the response.
|
||||
body: (
|
||||
ChatCompletionResponse
|
||||
| EmbeddingResponse
|
||||
| ScoreResponse
|
||||
| RerankResponse
|
||||
| None
|
||||
) = None
|
||||
|
||||
|
||||
class BatchRequestOutput(OpenAIBaseModel):
|
||||
"""
|
||||
The per-line object of the batch output and error files
|
||||
"""
|
||||
|
||||
id: str
|
||||
|
||||
# A developer-provided per-request id that will be used to match outputs to
|
||||
# inputs.
|
||||
custom_id: str
|
||||
|
||||
response: BatchResponseData | None
|
||||
|
||||
# For requests that failed with a non-HTTP error, this will contain more
|
||||
# information on the cause of the failure.
|
||||
error: Any | None
|
||||
|
||||
|
||||
class TokenizeCompletionRequest(OpenAIBaseModel):
|
||||
model: str | None = None
|
||||
|
||||
@ -7,29 +7,35 @@ from argparse import Namespace
|
||||
from collections.abc import Awaitable, Callable
|
||||
from http import HTTPStatus
|
||||
from io import StringIO
|
||||
from typing import Any, TypeAlias
|
||||
|
||||
import aiohttp
|
||||
import torch
|
||||
from prometheus_client import start_http_server
|
||||
from pydantic import TypeAdapter, field_validator
|
||||
from pydantic_core.core_schema import ValidationInfo
|
||||
from tqdm import tqdm
|
||||
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs, optional_type
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
BatchRequestInput,
|
||||
BatchRequestOutput,
|
||||
BatchResponseData,
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
EmbeddingResponse,
|
||||
ErrorResponse,
|
||||
RerankResponse,
|
||||
ScoreResponse,
|
||||
OpenAIBaseModel,
|
||||
)
|
||||
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
|
||||
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
|
||||
from vllm.entrypoints.openai.serving_score import ServingScores
|
||||
from vllm.entrypoints.pooling.embed.protocol import EmbeddingRequest, EmbeddingResponse
|
||||
from vllm.entrypoints.pooling.embed.serving import OpenAIServingEmbedding
|
||||
from vllm.entrypoints.pooling.score.protocol import (
|
||||
RerankRequest,
|
||||
RerankResponse,
|
||||
ScoreRequest,
|
||||
ScoreResponse,
|
||||
)
|
||||
from vllm.entrypoints.pooling.score.serving import ServingScores
|
||||
from vllm.logger import init_logger
|
||||
from vllm.reasoning import ReasoningParserManager
|
||||
from vllm.utils import random_uuid
|
||||
@ -39,6 +45,84 @@ from vllm.version import __version__ as VLLM_VERSION
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
BatchRequestInputBody: TypeAlias = (
|
||||
ChatCompletionRequest | EmbeddingRequest | ScoreRequest | RerankRequest
|
||||
)
|
||||
|
||||
|
||||
class BatchRequestInput(OpenAIBaseModel):
|
||||
"""
|
||||
The per-line object of the batch input file.
|
||||
|
||||
NOTE: Currently only the `/v1/chat/completions` endpoint is supported.
|
||||
"""
|
||||
|
||||
# A developer-provided per-request id that will be used to match outputs to
|
||||
# inputs. Must be unique for each request in a batch.
|
||||
custom_id: str
|
||||
|
||||
# The HTTP method to be used for the request. Currently only POST is
|
||||
# supported.
|
||||
method: str
|
||||
|
||||
# The OpenAI API relative URL to be used for the request. Currently
|
||||
# /v1/chat/completions is supported.
|
||||
url: str
|
||||
|
||||
# The parameters of the request.
|
||||
body: BatchRequestInputBody
|
||||
|
||||
@field_validator("body", mode="plain")
|
||||
@classmethod
|
||||
def check_type_for_url(cls, value: Any, info: ValidationInfo):
|
||||
# Use url to disambiguate models
|
||||
url: str = info.data["url"]
|
||||
if url == "/v1/chat/completions":
|
||||
return ChatCompletionRequest.model_validate(value)
|
||||
if url == "/v1/embeddings":
|
||||
return TypeAdapter(EmbeddingRequest).validate_python(value)
|
||||
if url.endswith("/score"):
|
||||
return ScoreRequest.model_validate(value)
|
||||
if url.endswith("/rerank"):
|
||||
return RerankRequest.model_validate(value)
|
||||
return TypeAdapter(BatchRequestInputBody).validate_python(value)
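For reference, a single batch-input line that this validator would route to the embeddings branch might look like the sketch below (the model name is a placeholder):

```python
# Hypothetical JSONL line for run_batch; check_type_for_url selects the
# request model from the "url" field (EmbeddingRequest for /v1/embeddings).
import json

line = json.dumps({
    "custom_id": "request-1",   # must be unique within the batch
    "method": "POST",           # only POST is supported
    "url": "/v1/embeddings",
    "body": {"model": "my-embedding-model", "input": "hello world"},
})
print(line)
```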
|
||||
|
||||
|
||||
class BatchResponseData(OpenAIBaseModel):
|
||||
# HTTP status code of the response.
|
||||
status_code: int = 200
|
||||
|
||||
# An unique identifier for the API request.
|
||||
request_id: str
|
||||
|
||||
# The body of the response.
|
||||
body: (
|
||||
ChatCompletionResponse
|
||||
| EmbeddingResponse
|
||||
| ScoreResponse
|
||||
| RerankResponse
|
||||
| None
|
||||
) = None
|
||||
|
||||
|
||||
class BatchRequestOutput(OpenAIBaseModel):
|
||||
"""
|
||||
The per-line object of the batch output and error files
|
||||
"""
|
||||
|
||||
id: str
|
||||
|
||||
# A developer-provided per-request id that will be used to match outputs to
|
||||
# inputs.
|
||||
custom_id: str
|
||||
|
||||
response: BatchResponseData | None
|
||||
|
||||
# For requests that failed with a non-HTTP error, this will contain more
|
||||
# information on the cause of the failure.
|
||||
error: Any | None
|
||||
|
||||
|
||||
def make_arg_parser(parser: FlexibleArgumentParser):
|
||||
parser.add_argument(
|
||||
"-i",
|
||||
|
||||
@ -18,6 +18,28 @@ from pydantic import ConfigDict, TypeAdapter
|
||||
from starlette.datastructures import Headers
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
from vllm.entrypoints.pooling.classify.protocol import (
|
||||
ClassificationChatRequest,
|
||||
ClassificationCompletionRequest,
|
||||
ClassificationRequest,
|
||||
ClassificationResponse,
|
||||
)
|
||||
from vllm.entrypoints.pooling.embed.protocol import (
|
||||
EmbeddingChatRequest,
|
||||
EmbeddingCompletionRequest,
|
||||
EmbeddingRequest,
|
||||
EmbeddingResponse,
|
||||
)
|
||||
from vllm.entrypoints.pooling.pooling.protocol import (
|
||||
IOProcessorRequest,
|
||||
PoolingResponse,
|
||||
)
|
||||
from vllm.entrypoints.pooling.score.protocol import (
|
||||
RerankRequest,
|
||||
ScoreRequest,
|
||||
ScoreResponse,
|
||||
)
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import TypedDict
|
||||
else:
|
||||
@ -45,29 +67,16 @@ from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionNamedToolChoiceParam,
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ClassificationChatRequest,
|
||||
ClassificationCompletionRequest,
|
||||
ClassificationRequest,
|
||||
ClassificationResponse,
|
||||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
DetokenizeRequest,
|
||||
EmbeddingChatRequest,
|
||||
EmbeddingCompletionRequest,
|
||||
EmbeddingRequest,
|
||||
EmbeddingResponse,
|
||||
ErrorInfo,
|
||||
ErrorResponse,
|
||||
FunctionCall,
|
||||
FunctionDefinition,
|
||||
GenerateRequest,
|
||||
GenerateResponse,
|
||||
IOProcessorRequest,
|
||||
PoolingResponse,
|
||||
RerankRequest,
|
||||
ResponsesRequest,
|
||||
ScoreRequest,
|
||||
ScoreResponse,
|
||||
TokenizeChatRequest,
|
||||
TokenizeCompletionRequest,
|
||||
TokenizeResponse,
|
||||
|
||||
@@ -2,6 +2,9 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TypeVar

from fastapi import Request
from fastapi.exceptions import RequestValidationError

from vllm.entrypoints.openai.protocol import (
    ChatCompletionRequest,
    ChatCompletionResponseChoice,
@@ -35,3 +38,12 @@ def maybe_filter_parallel_tool_calls(
    ]

    return choice


async def validate_json_request(raw_request: Request):
    content_type = raw_request.headers.get("content-type", "").lower()
    media_type = content_type.split(";", maxsplit=1)[0]
    if media_type != "application/json":
        raise RequestValidationError(
            errors=["Unsupported Media Type: Only 'application/json' is allowed"]
        )
16 vllm/entrypoints/pooling/__init__.py Normal file
@@ -0,0 +1,16 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from fastapi import FastAPI


def register_pooling_api_routers(app: FastAPI):
    from vllm.entrypoints.pooling.classify.api_router import router as classify_router
    from vllm.entrypoints.pooling.embed.api_router import router as embed_router
    from vllm.entrypoints.pooling.pooling.api_router import router as pooling_router
    from vllm.entrypoints.pooling.score.api_router import router as score_router

    app.include_router(classify_router)
    app.include_router(embed_router)
    app.include_router(score_router)
    app.include_router(pooling_router)
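This helper is the single hook the HTTP frontend needs to expose all pooling routes. A minimal sketch of wiring it into a bare FastAPI app (illustrative only: in the real server the `app.state.openai_serving_*` handlers that these routers read are created during vLLM's startup, not by hand):

```python
# Illustrative sketch: mount the pooling routers on a FastAPI app.
from fastapi import FastAPI

from vllm.entrypoints.pooling import register_pooling_api_routers

app = FastAPI()
register_pooling_api_routers(app)
# Each router resolves its serving handler from app.state at request time
# (e.g. app.state.openai_serving_embedding); vLLM's API server populates
# these attributes on startup, so this stub alone would only return
# "not supported" error responses rather than real results.
```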
0 vllm/entrypoints/pooling/classify/__init__.py Normal file
50 vllm/entrypoints/pooling/classify/api_router.py Normal file
@@ -0,0 +1,50 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from http import HTTPStatus

from fastapi import APIRouter, Depends, HTTPException, Request
from starlette.responses import JSONResponse
from typing_extensions import assert_never

from vllm.entrypoints.openai.protocol import ErrorResponse
from vllm.entrypoints.openai.utils import validate_json_request
from vllm.entrypoints.pooling.classify.protocol import (
    ClassificationRequest,
    ClassificationResponse,
)
from vllm.entrypoints.pooling.classify.serving import ServingClassification
from vllm.entrypoints.utils import load_aware_call, with_cancellation

router = APIRouter()


def classify(request: Request) -> ServingClassification | None:
    return request.app.state.openai_serving_classification


@router.post("/classify", dependencies=[Depends(validate_json_request)])
@with_cancellation
@load_aware_call
async def create_classify(request: ClassificationRequest, raw_request: Request):
    handler = classify(raw_request)
    if handler is None:
        base_server = raw_request.app.state.openai_serving_tokenization
        return base_server.create_error_response(
            message="The model does not support Classification API"
        )

    try:
        generator = await handler.create_classify(request, raw_request)
    except Exception as e:
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
        ) from e
    if isinstance(generator, ErrorResponse):
        return JSONResponse(
            content=generator.model_dump(), status_code=generator.error.code
        )

    elif isinstance(generator, ClassificationResponse):
        return JSONResponse(content=generator.model_dump())

    assert_never(generator)
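For context, a hedged client-side sketch of calling the `/classify` route registered above (the server address and model name are assumptions; the request fields follow `ClassificationCompletionRequest` and the reply follows `ClassificationResponse` in the protocol module below):

```python
# Hypothetical client call against /classify on a locally served classifier.
import requests

resp = requests.post(
    "http://localhost:8000/classify",  # assumed vLLM server address
    json={"model": "my-classifier", "input": ["vLLM makes serving easy."]},
)
resp.raise_for_status()
for item in resp.json()["data"]:
    print(item["index"], item["label"], item["probs"])
```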
181 vllm/entrypoints/pooling/classify/protocol.py Normal file
@@ -0,0 +1,181 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import time
|
||||
from typing import Annotated, Any, TypeAlias
|
||||
|
||||
from pydantic import (
|
||||
Field,
|
||||
)
|
||||
|
||||
from vllm import PoolingParams
|
||||
from vllm.config.pooler import get_use_activation
|
||||
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
|
||||
from vllm.entrypoints.openai.protocol import OpenAIBaseModel, UsageInfo
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
|
||||
class ClassificationCompletionRequest(OpenAIBaseModel):
|
||||
model: str | None = None
|
||||
input: list[str] | str
|
||||
truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
|
||||
user: str | None = None
|
||||
|
||||
# --8<-- [start:classification-extra-params]
|
||||
priority: int = Field(
|
||||
default=0,
|
||||
description=(
|
||||
"The priority of the request (lower means earlier handling; "
|
||||
"default: 0). Any priority other than 0 will raise an error "
|
||||
"if the served model does not use priority scheduling."
|
||||
),
|
||||
)
|
||||
add_special_tokens: bool = Field(
|
||||
default=True,
|
||||
description=(
|
||||
"If true (the default), special tokens (e.g. BOS) will be added to "
|
||||
"the prompt."
|
||||
),
|
||||
)
|
||||
request_id: str = Field(
|
||||
default_factory=random_uuid,
|
||||
description=(
|
||||
"The request_id related to this request. If the caller does "
|
||||
"not set it, a random_uuid will be generated. This id is used "
|
||||
"through out the inference process and return in response."
|
||||
),
|
||||
)
|
||||
softmax: bool | None = Field(
|
||||
default=None,
|
||||
description="softmax will be deprecated, please use use_activation instead.",
|
||||
)
|
||||
|
||||
activation: bool | None = Field(
|
||||
default=None,
|
||||
description="activation will be deprecated, please use use_activation instead.",
|
||||
)
|
||||
|
||||
use_activation: bool | None = Field(
|
||||
default=None,
|
||||
description="Whether to use activation for classification outputs. "
|
||||
"Default is True.",
|
||||
)
|
||||
# --8<-- [end:classification-extra-params]
|
||||
|
||||
def to_pooling_params(self):
|
||||
return PoolingParams(
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
use_activation=get_use_activation(self),
|
||||
)
|
||||
|
||||
|
||||
class ClassificationChatRequest(OpenAIBaseModel):
|
||||
model: str | None = None
|
||||
messages: list[ChatCompletionMessageParam]
|
||||
truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
|
||||
user: str | None = None
|
||||
|
||||
# --8<-- [start:chat-classification-extra-params]
|
||||
add_generation_prompt: bool = Field(
|
||||
default=False,
|
||||
description=(
|
||||
"If true, the generation prompt will be added to the chat template. "
|
||||
"This is a parameter used by chat template in tokenizer config of the "
|
||||
"model."
|
||||
),
|
||||
)
|
||||
|
||||
add_special_tokens: bool = Field(
|
||||
default=False,
|
||||
description=(
|
||||
"If true, special tokens (e.g. BOS) will be added to the prompt "
|
||||
"on top of what is added by the chat template. "
|
||||
"For most models, the chat template takes care of adding the "
|
||||
"special tokens so this should be set to false (as is the "
|
||||
"default)."
|
||||
),
|
||||
)
|
||||
|
||||
chat_template: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"A Jinja template to use for this conversion. "
|
||||
"As of transformers v4.44, default chat template is no longer "
|
||||
"allowed, so you must provide a chat template if the tokenizer "
|
||||
"does not define one."
|
||||
),
|
||||
)
|
||||
|
||||
chat_template_kwargs: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Additional keyword args to pass to the template renderer. "
|
||||
"Will be accessible by the chat template."
|
||||
),
|
||||
)
|
||||
|
||||
mm_processor_kwargs: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description=("Additional kwargs to pass to the HF processor."),
|
||||
)
|
||||
|
||||
priority: int = Field(
|
||||
default=0,
|
||||
description=(
|
||||
"The priority of the request (lower means earlier handling; "
|
||||
"default: 0). Any priority other than 0 will raise an error "
|
||||
"if the served model does not use priority scheduling."
|
||||
),
|
||||
)
|
||||
|
||||
request_id: str = Field(
|
||||
default_factory=random_uuid,
|
||||
description=(
|
||||
"The request_id related to this request. If the caller does "
|
||||
"not set it, a random_uuid will be generated. This id is used "
|
||||
"through out the inference process and return in response."
|
||||
),
|
||||
)
|
||||
softmax: bool | None = Field(
|
||||
default=None,
|
||||
description="softmax will be deprecated, please use use_activation instead.",
|
||||
)
|
||||
|
||||
activation: bool | None = Field(
|
||||
default=None,
|
||||
description="activation will be deprecated, please use use_activation instead.",
|
||||
)
|
||||
|
||||
use_activation: bool | None = Field(
|
||||
default=None,
|
||||
description="Whether to use activation for classification outputs. "
|
||||
"Default is True.",
|
||||
)
|
||||
# --8<-- [end:chat-classification-extra-params]
|
||||
|
||||
def to_pooling_params(self):
|
||||
return PoolingParams(
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
use_activation=get_use_activation(self),
|
||||
)
|
||||
|
||||
|
||||
ClassificationRequest: TypeAlias = (
|
||||
ClassificationCompletionRequest | ClassificationChatRequest
|
||||
)
|
||||
|
||||
|
||||
class ClassificationData(OpenAIBaseModel):
|
||||
index: int
|
||||
label: str | None
|
||||
probs: list[float]
|
||||
num_classes: int
|
||||
|
||||
|
||||
class ClassificationResponse(OpenAIBaseModel):
|
||||
id: str = Field(default_factory=lambda: f"classify-{random_uuid()}")
|
||||
object: str = "list"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
model: str
|
||||
data: list[ClassificationData]
|
||||
usage: UsageInfo
|
||||
@@ -13,11 +13,6 @@ from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (
    ChatCompletionRequest,
    ClassificationChatRequest,
    ClassificationCompletionRequest,
    ClassificationData,
    ClassificationRequest,
    ClassificationResponse,
    ErrorResponse,
    UsageInfo,
)
@@ -27,6 +22,13 @@ from vllm.entrypoints.openai.serving_engine import (
    ServeContext,
)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.pooling.classify.protocol import (
    ClassificationChatRequest,
    ClassificationCompletionRequest,
    ClassificationData,
    ClassificationRequest,
    ClassificationResponse,
)
from vllm.entrypoints.renderer import RenderConfig
from vllm.logger import init_logger
from vllm.outputs import ClassificationOutput, PoolingRequestOutput
0 vllm/entrypoints/pooling/embed/__init__.py Normal file
67 vllm/entrypoints/pooling/embed/api_router.py Normal file
@@ -0,0 +1,67 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from http import HTTPStatus

from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import JSONResponse, StreamingResponse
from typing_extensions import assert_never

from vllm.entrypoints.openai.protocol import ErrorResponse
from vllm.entrypoints.openai.utils import validate_json_request
from vllm.entrypoints.pooling.embed.protocol import (
    EmbeddingBytesResponse,
    EmbeddingRequest,
    EmbeddingResponse,
)
from vllm.entrypoints.pooling.embed.serving import OpenAIServingEmbedding
from vllm.entrypoints.utils import load_aware_call, with_cancellation

router = APIRouter()


def embedding(request: Request) -> OpenAIServingEmbedding | None:
    return request.app.state.openai_serving_embedding


@router.post(
    "/v1/embeddings",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def create_embedding(
    request: EmbeddingRequest,
    raw_request: Request,
):
    handler = embedding(raw_request)
    if handler is None:
        base_server = raw_request.app.state.openai_serving_tokenization
        return base_server.create_error_response(
            message="The model does not support Embeddings API"
        )

    try:
        generator = await handler.create_embedding(request, raw_request)
    except Exception as e:
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
        ) from e

    if isinstance(generator, ErrorResponse):
        return JSONResponse(
            content=generator.model_dump(), status_code=generator.error.code
        )
    elif isinstance(generator, EmbeddingResponse):
        return JSONResponse(content=generator.model_dump())
    elif isinstance(generator, EmbeddingBytesResponse):
        return StreamingResponse(
            content=generator.body,
            headers={"metadata": generator.metadata},
            media_type=generator.media_type,
        )

    assert_never(generator)
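Because the embeddings route keeps its OpenAI-compatible path, the stock OpenAI Python client can still be pointed at it after this move; a minimal sketch (base URL and model name are assumptions about the local deployment):

```python
# Minimal sketch: query /v1/embeddings with the OpenAI Python client.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
result = client.embeddings.create(
    model="intfloat/e5-small-v2",  # assumed embedding model served by vLLM
    input=["resettled pooling entrypoints"],
)
print(len(result.data[0].embedding))  # embedding dimensionality
```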
208 vllm/entrypoints/pooling/embed/protocol.py Normal file
@@ -0,0 +1,208 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import time
|
||||
from typing import Annotated, Any, TypeAlias
|
||||
|
||||
from pydantic import (
|
||||
Field,
|
||||
model_validator,
|
||||
)
|
||||
|
||||
from vllm import PoolingParams
|
||||
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
|
||||
from vllm.entrypoints.openai.protocol import OpenAIBaseModel, UsageInfo
|
||||
from vllm.utils import random_uuid
|
||||
from vllm.utils.serial_utils import EmbedDType, EncodingFormat, Endianness
|
||||
|
||||
|
||||
class EmbeddingCompletionRequest(OpenAIBaseModel):
|
||||
# Ordered by official OpenAI API documentation
|
||||
# https://platform.openai.com/docs/api-reference/embeddings
|
||||
model: str | None = None
|
||||
input: list[int] | list[list[int]] | str | list[str]
|
||||
encoding_format: EncodingFormat = "float"
|
||||
dimensions: int | None = None
|
||||
user: str | None = None
|
||||
truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
|
||||
|
||||
# --8<-- [start:embedding-extra-params]
|
||||
add_special_tokens: bool = Field(
|
||||
default=True,
|
||||
description=(
|
||||
"If true (the default), special tokens (e.g. BOS) will be added to "
|
||||
"the prompt."
|
||||
),
|
||||
)
|
||||
priority: int = Field(
|
||||
default=0,
|
||||
description=(
|
||||
"The priority of the request (lower means earlier handling; "
|
||||
"default: 0). Any priority other than 0 will raise an error "
|
||||
"if the served model does not use priority scheduling."
|
||||
),
|
||||
)
|
||||
request_id: str = Field(
|
||||
default_factory=random_uuid,
|
||||
description=(
|
||||
"The request_id related to this request. If the caller does "
|
||||
"not set it, a random_uuid will be generated. This id is used "
|
||||
"through out the inference process and return in response."
|
||||
),
|
||||
)
|
||||
normalize: bool | None = Field(
|
||||
default=None,
|
||||
description="Whether to normalize the embeddings outputs. Default is True.",
|
||||
)
|
||||
embed_dtype: EmbedDType = Field(
|
||||
default="float32",
|
||||
description=(
|
||||
"What dtype to use for encoding. Default to using float32 for base64 "
|
||||
"encoding to match the OpenAI python client behavior. "
|
||||
"This parameter will affect base64 and binary_response."
|
||||
),
|
||||
)
|
||||
endianness: Endianness = Field(
|
||||
default="native",
|
||||
description=(
|
||||
"What endianness to use for encoding. Default to using native for "
|
||||
"base64 encoding to match the OpenAI python client behavior."
|
||||
"This parameter will affect base64 and binary_response."
|
||||
),
|
||||
)
|
||||
# --8<-- [end:embedding-extra-params]
|
||||
|
||||
def to_pooling_params(self):
|
||||
return PoolingParams(
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
dimensions=self.dimensions,
|
||||
normalize=self.normalize,
|
||||
)
|
||||
|
||||
|
||||
class EmbeddingChatRequest(OpenAIBaseModel):
|
||||
model: str | None = None
|
||||
messages: list[ChatCompletionMessageParam]
|
||||
|
||||
encoding_format: EncodingFormat = "float"
|
||||
dimensions: int | None = None
|
||||
user: str | None = None
|
||||
truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
|
||||
|
||||
# --8<-- [start:chat-embedding-extra-params]
|
||||
add_generation_prompt: bool = Field(
|
||||
default=False,
|
||||
description=(
|
||||
"If true, the generation prompt will be added to the chat template. "
|
||||
"This is a parameter used by chat template in tokenizer config of the "
|
||||
"model."
|
||||
),
|
||||
)
|
||||
|
||||
add_special_tokens: bool = Field(
|
||||
default=False,
|
||||
description=(
|
||||
"If true, special tokens (e.g. BOS) will be added to the prompt "
|
||||
"on top of what is added by the chat template. "
|
||||
"For most models, the chat template takes care of adding the "
|
||||
"special tokens so this should be set to false (as is the "
|
||||
"default)."
|
||||
),
|
||||
)
|
||||
chat_template: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"A Jinja template to use for this conversion. "
|
||||
"As of transformers v4.44, default chat template is no longer "
|
||||
"allowed, so you must provide a chat template if the tokenizer "
|
||||
"does not define one."
|
||||
),
|
||||
)
|
||||
chat_template_kwargs: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Additional keyword args to pass to the template renderer. "
|
||||
"Will be accessible by the chat template."
|
||||
),
|
||||
)
|
||||
mm_processor_kwargs: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description=("Additional kwargs to pass to the HF processor."),
|
||||
)
|
||||
priority: int = Field(
|
||||
default=0,
|
||||
description=(
|
||||
"The priority of the request (lower means earlier handling; "
|
||||
"default: 0). Any priority other than 0 will raise an error "
|
||||
"if the served model does not use priority scheduling."
|
||||
),
|
||||
)
|
||||
request_id: str = Field(
|
||||
default_factory=random_uuid,
|
||||
description=(
|
||||
"The request_id related to this request. If the caller does "
|
||||
"not set it, a random_uuid will be generated. This id is used "
|
||||
"through out the inference process and return in response."
|
||||
),
|
||||
)
|
||||
normalize: bool | None = Field(
|
||||
default=None,
|
||||
description="Whether to normalize the embeddings outputs. Default is True.",
|
||||
)
|
||||
embed_dtype: EmbedDType = Field(
|
||||
default="float32",
|
||||
description=(
|
||||
"What dtype to use for encoding. Default to using float32 for base64 "
|
||||
"encoding to match the OpenAI python client behavior. "
|
||||
"This parameter will affect base64 and binary_response."
|
||||
),
|
||||
)
|
||||
endianness: Endianness = Field(
|
||||
default="native",
|
||||
description=(
|
||||
"What endianness to use for encoding. Default to using native for "
|
||||
"base64 encoding to match the OpenAI python client behavior."
|
||||
"This parameter will affect base64 and binary_response."
|
||||
),
|
||||
)
|
||||
# --8<-- [end:chat-embedding-extra-params]
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_generation_prompt(cls, data):
|
||||
if data.get("continue_final_message") and data.get("add_generation_prompt"):
|
||||
raise ValueError(
|
||||
"Cannot set both `continue_final_message` and "
|
||||
"`add_generation_prompt` to True."
|
||||
)
|
||||
return data
|
||||
|
||||
def to_pooling_params(self):
|
||||
return PoolingParams(
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
dimensions=self.dimensions,
|
||||
normalize=self.normalize,
|
||||
)
|
||||
|
||||
|
||||
EmbeddingRequest: TypeAlias = EmbeddingCompletionRequest | EmbeddingChatRequest
|
||||
|
||||
|
||||
class EmbeddingResponseData(OpenAIBaseModel):
|
||||
index: int
|
||||
object: str = "embedding"
|
||||
embedding: list[float] | str
|
||||
|
||||
|
||||
class EmbeddingResponse(OpenAIBaseModel):
|
||||
id: str = Field(default_factory=lambda: f"embd-{random_uuid()}")
|
||||
object: str = "list"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
model: str
|
||||
data: list[EmbeddingResponseData]
|
||||
usage: UsageInfo
|
||||
|
||||
|
||||
class EmbeddingBytesResponse(OpenAIBaseModel):
|
||||
body: list[bytes]
|
||||
metadata: str
|
||||
media_type: str = "application/octet-stream"
|
||||
@@ -13,12 +13,6 @@ from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (
    EmbeddingBytesResponse,
    EmbeddingChatRequest,
    EmbeddingCompletionRequest,
    EmbeddingRequest,
    EmbeddingResponse,
    EmbeddingResponseData,
    ErrorResponse,
    UsageInfo,
)
@@ -29,6 +23,14 @@ from vllm.entrypoints.openai.serving_engine import (
    TextTokensPrompt,
)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.pooling.embed.protocol import (
    EmbeddingBytesResponse,
    EmbeddingChatRequest,
    EmbeddingCompletionRequest,
    EmbeddingRequest,
    EmbeddingResponse,
    EmbeddingResponseData,
)
from vllm.entrypoints.renderer import RenderConfig
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
from vllm.logger import init_logger
0 vllm/entrypoints/pooling/pooling/__init__.py Normal file
63 vllm/entrypoints/pooling/pooling/api_router.py Normal file
@@ -0,0 +1,63 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from http import HTTPStatus

from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import JSONResponse, StreamingResponse
from typing_extensions import assert_never

from vllm.entrypoints.openai.protocol import ErrorResponse
from vllm.entrypoints.openai.utils import validate_json_request
from vllm.entrypoints.pooling.pooling.protocol import (
    IOProcessorResponse,
    PoolingBytesResponse,
    PoolingRequest,
    PoolingResponse,
)
from vllm.entrypoints.pooling.pooling.serving import OpenAIServingPooling
from vllm.entrypoints.utils import load_aware_call, with_cancellation

router = APIRouter()


def pooling(request: Request) -> OpenAIServingPooling | None:
    return request.app.state.openai_serving_pooling


@router.post(
    "/pooling",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def create_pooling(request: PoolingRequest, raw_request: Request):
    handler = pooling(raw_request)
    if handler is None:
        base_server = raw_request.app.state.openai_serving_tokenization
        return base_server.create_error_response(
            message="The model does not support Pooling API"
        )
    try:
        generator = await handler.create_pooling(request, raw_request)
    except Exception as e:
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
        ) from e
    if isinstance(generator, ErrorResponse):
        return JSONResponse(
            content=generator.model_dump(), status_code=generator.error.code
        )
    elif isinstance(generator, (PoolingResponse, IOProcessorResponse)):
        return JSONResponse(content=generator.model_dump())
    elif isinstance(generator, PoolingBytesResponse):
        return StreamingResponse(
            content=generator.body,
            headers={"metadata": generator.metadata},
            media_type=generator.media_type,
        )

    assert_never(generator)
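A request sketch for the `/pooling` route registered above (the URL and model name are assumptions; the body follows `PoolingCompletionRequest` from the protocol module below, and the reply follows `PoolingResponse`):

```python
# Hypothetical request to /pooling for raw pooled outputs.
import requests

resp = requests.post(
    "http://localhost:8000/pooling",  # assumed vLLM server address
    json={"model": "my-pooling-model", "input": "Hello world"},
)
resp.raise_for_status()
print(resp.json()["data"][0]["data"])  # pooled vector (or token-level matrix)
```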
148 vllm/entrypoints/pooling/pooling/protocol.py Normal file
@@ -0,0 +1,148 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import time
|
||||
from typing import Generic, TypeAlias, TypeVar
|
||||
|
||||
from pydantic import (
|
||||
Field,
|
||||
)
|
||||
|
||||
from vllm import PoolingParams
|
||||
from vllm.config.pooler import get_use_activation
|
||||
from vllm.entrypoints.openai.protocol import OpenAIBaseModel, UsageInfo
|
||||
from vllm.entrypoints.pooling.embed.protocol import (
|
||||
EmbeddingChatRequest,
|
||||
EmbeddingCompletionRequest,
|
||||
)
|
||||
from vllm.tasks import PoolingTask
|
||||
from vllm.utils import random_uuid
|
||||
from vllm.utils.serial_utils import EmbedDType, EncodingFormat, Endianness
|
||||
|
||||
|
||||
class PoolingCompletionRequest(EmbeddingCompletionRequest):
|
||||
task: PoolingTask | None = None
|
||||
softmax: bool | None = Field(
|
||||
default=None,
|
||||
description="softmax will be deprecated, please use use_activation instead.",
|
||||
)
|
||||
activation: bool | None = Field(
|
||||
default=None,
|
||||
description="activation will be deprecated, please use use_activation instead.",
|
||||
)
|
||||
use_activation: bool | None = Field(
|
||||
default=None,
|
||||
description="Whether to use activation for classification outputs. "
|
||||
"If it is a classify or token_classify task, the default is True; "
|
||||
"for other tasks, this value should be None.",
|
||||
)
|
||||
|
||||
def to_pooling_params(self):
|
||||
return PoolingParams(
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
dimensions=self.dimensions,
|
||||
normalize=self.normalize,
|
||||
use_activation=get_use_activation(self),
|
||||
)
|
||||
|
||||
|
||||
class PoolingChatRequest(EmbeddingChatRequest):
|
||||
task: PoolingTask | None = None
|
||||
softmax: bool | None = Field(
|
||||
default=None,
|
||||
description="softmax will be deprecated, please use use_activation instead.",
|
||||
)
|
||||
activation: bool | None = Field(
|
||||
default=None,
|
||||
description="activation will be deprecated, please use use_activation instead.",
|
||||
)
|
||||
use_activation: bool | None = Field(
|
||||
default=None,
|
||||
description="Whether to use activation for classification outputs. "
|
||||
"If it is a classify or token_classify task, the default is True; "
|
||||
"for other tasks, this value should be None.",
|
||||
)
|
||||
|
||||
def to_pooling_params(self):
|
||||
return PoolingParams(
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
dimensions=self.dimensions,
|
||||
normalize=self.normalize,
|
||||
use_activation=get_use_activation(self),
|
||||
)
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class IOProcessorRequest(OpenAIBaseModel, Generic[T]):
|
||||
model: str | None = None
|
||||
|
||||
priority: int = Field(default=0)
|
||||
"""
|
||||
The priority of the request (lower means earlier handling;
|
||||
default: 0). Any priority other than 0 will raise an error
|
||||
if the served model does not use priority scheduling.
|
||||
"""
|
||||
data: T
|
||||
|
||||
task: PoolingTask = "plugin"
|
||||
encoding_format: EncodingFormat = "float"
|
||||
embed_dtype: EmbedDType = Field(
|
||||
default="float32",
|
||||
description=(
|
||||
"What dtype to use for encoding. Default to using float32 for base64 "
|
||||
"encoding to match the OpenAI python client behavior. "
|
||||
"This parameter will affect base64 and binary_response."
|
||||
),
|
||||
)
|
||||
endianness: Endianness = Field(
|
||||
default="native",
|
||||
description=(
|
||||
"What endianness to use for encoding. Default to using native for "
|
||||
"base64 encoding to match the OpenAI python client behavior."
|
||||
"This parameter will affect base64 and binary_response."
|
||||
),
|
||||
)
|
||||
|
||||
def to_pooling_params(self):
|
||||
return PoolingParams()
|
||||
|
||||
|
||||
class IOProcessorResponse(OpenAIBaseModel, Generic[T]):
|
||||
request_id: str | None = None
|
||||
"""
|
||||
The request_id associated with this response
|
||||
"""
|
||||
created_at: int = Field(default_factory=lambda: int(time.time()))
|
||||
|
||||
data: T
|
||||
"""
|
||||
When using plugins IOProcessor plugins, the actual output is generated
|
||||
by the plugin itself. Hence, we use a generic type for the response data
|
||||
"""
|
||||
|
||||
|
||||
PoolingRequest: TypeAlias = (
|
||||
PoolingCompletionRequest | PoolingChatRequest | IOProcessorRequest
|
||||
)
|
||||
|
||||
|
||||
class PoolingResponseData(OpenAIBaseModel):
|
||||
index: int
|
||||
object: str = "pooling"
|
||||
data: list[list[float]] | list[float] | str
|
||||
|
||||
|
||||
class PoolingResponse(OpenAIBaseModel):
|
||||
id: str = Field(default_factory=lambda: f"pool-{random_uuid()}")
|
||||
object: str = "list"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
model: str
|
||||
data: list[PoolingResponseData]
|
||||
usage: UsageInfo
|
||||
|
||||
|
||||
class PoolingBytesResponse(OpenAIBaseModel):
|
||||
body: list[bytes]
|
||||
metadata: str
|
||||
media_type: str = "application/octet-stream"
|
||||
@@ -16,6 +16,11 @@ from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (
    ErrorResponse,
    UsageInfo,
)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.pooling.pooling.protocol import (
    IOProcessorRequest,
    IOProcessorResponse,
    PoolingBytesResponse,
@@ -24,10 +29,7 @@ from vllm.entrypoints.openai.protocol import (
    PoolingRequest,
    PoolingResponse,
    PoolingResponseData,
    UsageInfo,
)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.renderer import RenderConfig
from vllm.entrypoints.utils import _validate_truncation_size
from vllm.logger import init_logger
0 vllm/entrypoints/pooling/score/__init__.py Normal file
149 vllm/entrypoints/pooling/score/api_router.py Normal file
@@ -0,0 +1,149 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from http import HTTPStatus
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from typing_extensions import assert_never
|
||||
|
||||
from vllm.entrypoints.openai.protocol import ErrorResponse
|
||||
from vllm.entrypoints.openai.utils import validate_json_request
|
||||
from vllm.entrypoints.pooling.score.protocol import (
|
||||
RerankRequest,
|
||||
RerankResponse,
|
||||
ScoreRequest,
|
||||
ScoreResponse,
|
||||
)
|
||||
from vllm.entrypoints.pooling.score.serving import ServingScores
|
||||
from vllm.entrypoints.utils import load_aware_call, with_cancellation
|
||||
from vllm.logger import init_logger
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def score(request: Request) -> ServingScores | None:
|
||||
return request.app.state.openai_serving_scores
|
||||
|
||||
|
||||
def rerank(request: Request) -> ServingScores | None:
|
||||
return request.app.state.openai_serving_scores
|
||||
|
||||
|
||||
@router.post(
|
||||
"/score",
|
||||
dependencies=[Depends(validate_json_request)],
|
||||
responses={
|
||||
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
|
||||
},
|
||||
)
|
||||
@with_cancellation
|
||||
@load_aware_call
|
||||
async def create_score(request: ScoreRequest, raw_request: Request):
|
||||
handler = score(raw_request)
|
||||
if handler is None:
|
||||
base_server = raw_request.app.state.openai_serving_tokenization
|
||||
return base_server.create_error_response(
|
||||
message="The model does not support Score API"
|
||||
)
|
||||
|
||||
try:
|
||||
generator = await handler.create_score(request, raw_request)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
|
||||
) from e
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(
|
||||
content=generator.model_dump(), status_code=generator.error.code
|
||||
)
|
||||
elif isinstance(generator, ScoreResponse):
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
|
||||
assert_never(generator)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/v1/score",
|
||||
dependencies=[Depends(validate_json_request)],
|
||||
responses={
|
||||
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
|
||||
},
|
||||
)
|
||||
@with_cancellation
|
||||
@load_aware_call
|
||||
async def create_score_v1(request: ScoreRequest, raw_request: Request):
|
||||
logger.warning(
|
||||
"To indicate that Score API is not part of standard OpenAI API, we "
|
||||
"have moved it to `/score`. Please update your client accordingly."
|
||||
)
|
||||
|
||||
return await create_score(request, raw_request)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/rerank",
|
||||
dependencies=[Depends(validate_json_request)],
|
||||
responses={
|
||||
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
|
||||
},
|
||||
)
|
||||
@with_cancellation
|
||||
@load_aware_call
|
||||
async def do_rerank(request: RerankRequest, raw_request: Request):
|
||||
handler = rerank(raw_request)
|
||||
if handler is None:
|
||||
base_server = raw_request.app.state.openai_serving_tokenization
|
||||
return base_server.create_error_response(
|
||||
message="The model does not support Rerank (Score) API"
|
||||
)
|
||||
try:
|
||||
generator = await handler.do_rerank(request, raw_request)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
|
||||
) from e
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(
|
||||
content=generator.model_dump(), status_code=generator.error.code
|
||||
)
|
||||
elif isinstance(generator, RerankResponse):
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
|
||||
assert_never(generator)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/v1/rerank",
|
||||
dependencies=[Depends(validate_json_request)],
|
||||
responses={
|
||||
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
|
||||
},
|
||||
)
|
||||
@with_cancellation
|
||||
async def do_rerank_v1(request: RerankRequest, raw_request: Request):
|
||||
logger.warning_once(
|
||||
"To indicate that the rerank API is not part of the standard OpenAI"
|
||||
" API, we have located it at `/rerank`. Please update your client "
|
||||
"accordingly. (Note: Conforms to JinaAI rerank API)"
|
||||
)
|
||||
|
||||
return await do_rerank(request, raw_request)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/v2/rerank",
|
||||
dependencies=[Depends(validate_json_request)],
|
||||
responses={
|
||||
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
|
||||
},
|
||||
)
|
||||
@with_cancellation
|
||||
async def do_rerank_v2(request: RerankRequest, raw_request: Request):
|
||||
return await do_rerank(request, raw_request)
|
||||
145 vllm/entrypoints/pooling/score/protocol.py Normal file
@@ -0,0 +1,145 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import time
|
||||
from typing import Annotated, Any
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
Field,
|
||||
)
|
||||
|
||||
from vllm import PoolingParams
|
||||
from vllm.config.pooler import get_use_activation
|
||||
from vllm.entrypoints.openai.protocol import OpenAIBaseModel, UsageInfo
|
||||
from vllm.entrypoints.score_utils import ScoreContentPartParam, ScoreMultiModalParam
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
|
||||
class ScoreRequest(OpenAIBaseModel):
|
||||
model: str | None = None
|
||||
text_1: list[str] | str | ScoreMultiModalParam
|
||||
text_2: list[str] | str | ScoreMultiModalParam
|
||||
truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
|
||||
|
||||
# --8<-- [start:score-extra-params]
|
||||
|
||||
mm_processor_kwargs: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description=("Additional kwargs to pass to the HF processor."),
|
||||
)
|
||||
|
||||
priority: int = Field(
|
||||
default=0,
|
||||
description=(
|
||||
"The priority of the request (lower means earlier handling; "
|
||||
"default: 0). Any priority other than 0 will raise an error "
|
||||
"if the served model does not use priority scheduling."
|
||||
),
|
||||
)
|
||||
|
||||
softmax: bool | None = Field(
|
||||
default=None,
|
||||
description="softmax will be deprecated, please use use_activation instead.",
|
||||
)
|
||||
|
||||
activation: bool | None = Field(
|
||||
default=None,
|
||||
description="activation will be deprecated, please use use_activation instead.",
|
||||
)
|
||||
|
||||
use_activation: bool | None = Field(
|
||||
default=None,
|
||||
description="Whether to use activation for classification outputs. "
|
||||
"Default is True.",
|
||||
)
|
||||
# --8<-- [end:score-extra-params]
|
||||
|
||||
def to_pooling_params(self):
|
||||
return PoolingParams(
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
use_activation=get_use_activation(self),
|
||||
)
|
||||
|
||||
|
||||
class RerankRequest(OpenAIBaseModel):
|
||||
model: str | None = None
|
||||
query: str | ScoreMultiModalParam
|
||||
documents: list[str] | ScoreMultiModalParam
|
||||
top_n: int = Field(default_factory=lambda: 0)
|
||||
truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
|
||||
|
||||
# --8<-- [start:rerank-extra-params]
|
||||
|
||||
mm_processor_kwargs: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description=("Additional kwargs to pass to the HF processor."),
|
||||
)
|
||||
|
||||
priority: int = Field(
|
||||
default=0,
|
||||
description=(
|
||||
"The priority of the request (lower means earlier handling; "
|
||||
"default: 0). Any priority other than 0 will raise an error "
|
||||
"if the served model does not use priority scheduling."
|
||||
),
|
||||
)
|
||||
|
||||
softmax: bool | None = Field(
|
||||
default=None,
|
||||
description="softmax will be deprecated, please use use_activation instead.",
|
||||
)
|
||||
|
||||
activation: bool | None = Field(
|
||||
default=None,
|
||||
description="activation will be deprecated, please use use_activation instead.",
|
||||
)
|
||||
|
||||
use_activation: bool | None = Field(
|
||||
default=None,
|
||||
description="Whether to use activation for classification outputs. "
|
||||
"Default is True.",
|
||||
)
|
||||
# --8<-- [end:rerank-extra-params]
|
||||
|
||||
def to_pooling_params(self):
|
||||
return PoolingParams(
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
use_activation=get_use_activation(self),
|
||||
)
|
||||
|
||||
|
||||
class RerankDocument(BaseModel):
|
||||
text: str | None = None
|
||||
multi_modal: ScoreContentPartParam | None = None
|
||||
|
||||
|
||||
class RerankResult(BaseModel):
|
||||
index: int
|
||||
document: RerankDocument
|
||||
relevance_score: float
|
||||
|
||||
|
||||
class RerankUsage(BaseModel):
|
||||
total_tokens: int
|
||||
|
||||
|
||||
class RerankResponse(OpenAIBaseModel):
|
||||
id: str
|
||||
model: str
|
||||
usage: RerankUsage
|
||||
results: list[RerankResult]
|
||||
|
||||
|
||||
class ScoreResponseData(OpenAIBaseModel):
|
||||
index: int
|
||||
object: str = "score"
|
||||
score: float
|
||||
|
||||
|
||||
class ScoreResponse(OpenAIBaseModel):
|
||||
id: str = Field(default_factory=lambda: f"embd-{random_uuid()}")
|
||||
object: str = "list"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
model: str
|
||||
data: list[ScoreResponseData]
|
||||
usage: UsageInfo
|
||||
@@ -11,6 +11,11 @@ from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (
    ErrorResponse,
    UsageInfo,
)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.pooling.score.protocol import (
    RerankDocument,
    RerankRequest,
    RerankResponse,
@@ -19,10 +24,7 @@ from vllm.entrypoints.openai.protocol import (
    ScoreRequest,
    ScoreResponse,
    ScoreResponseData,
    UsageInfo,
)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.score_utils import (
    ScoreContentPartParam,
    ScoreMultiModalParam,
@@ -1,7 +1,9 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from collections.abc import Awaitable, Callable
from http import HTTPStatus
from typing import Any

import model_hosting_container_standards.sagemaker as sagemaker_standards
import pydantic
@@ -9,12 +11,56 @@ from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import JSONResponse, Response

from vllm.entrypoints.openai.api_server import (
    INVOCATION_VALIDATORS,
    base,
    chat,
    completion,
    create_chat_completion,
    create_completion,
    health,
    validate_json_request,
)
from vllm.entrypoints.openai.protocol import ErrorResponse
from vllm.entrypoints.openai.protocol import (
    ChatCompletionRequest,
    CompletionRequest,
    ErrorResponse,
)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.pooling.classify.api_router import classify, create_classify
from vllm.entrypoints.pooling.classify.protocol import ClassificationRequest
from vllm.entrypoints.pooling.embed.api_router import create_embedding, embedding
from vllm.entrypoints.pooling.embed.protocol import EmbeddingRequest
from vllm.entrypoints.pooling.pooling.api_router import create_pooling, pooling
from vllm.entrypoints.pooling.pooling.protocol import PoolingRequest
from vllm.entrypoints.pooling.score.api_router import (
    create_score,
    do_rerank,
    rerank,
    score,
)
from vllm.entrypoints.pooling.score.protocol import RerankRequest, ScoreRequest

# TODO: RequestType = TypeForm[BaseModel] when recognized by type checkers
# (requires typing_extensions >= 4.13)
RequestType = Any
GetHandlerFn = Callable[[Request], OpenAIServing | None]
EndpointFn = Callable[[RequestType, Request], Awaitable[Any]]

# NOTE: Items defined earlier take higher priority
INVOCATION_TYPES: list[tuple[RequestType, tuple[GetHandlerFn, EndpointFn]]] = [
    (ChatCompletionRequest, (chat, create_chat_completion)),
    (CompletionRequest, (completion, create_completion)),
    (EmbeddingRequest, (embedding, create_embedding)),
    (ClassificationRequest, (classify, create_classify)),
    (ScoreRequest, (score, create_score)),
    (RerankRequest, (rerank, do_rerank)),
    (PoolingRequest, (pooling, create_pooling)),
]

# NOTE: Construct the TypeAdapters only once
INVOCATION_VALIDATORS = [
    (pydantic.TypeAdapter(request_type), (get_handler, endpoint))
    for request_type, (get_handler, endpoint) in INVOCATION_TYPES
]


def register_sagemaker_routes(router: APIRouter):
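The `INVOCATION_TYPES` / `INVOCATION_VALIDATORS` pair above acts as a priority-ordered dispatch table for the SageMaker-style `/invocations` endpoint: an incoming JSON body is validated against each request type in turn, and the first type that both validates and has a live serving handler wins. A condensed, illustrative sketch of that loop (not the verbatim handler from this file):

```python
# Illustrative dispatch over INVOCATION_VALIDATORS; names reused from the
# imports above, but the function itself is a simplified sketch.
async def dispatch_invocation(raw_request: Request):
    body = await raw_request.json()
    for validator, (get_handler, endpoint) in INVOCATION_VALIDATORS:
        try:
            request = validator.validate_python(body)
        except pydantic.ValidationError:
            continue  # body does not match this request type; try the next one
        if get_handler(raw_request) is None:
            continue  # model does not serve this API; keep looking
        return await endpoint(request, raw_request)
    raise HTTPException(
        status_code=HTTPStatus.BAD_REQUEST.value,
        detail="Request body does not match any supported invocation type",
    )
```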
@@ -6,7 +6,7 @@ from collections.abc import AsyncGenerator, Sequence
from typing import Any, Generic, TypeVar

from vllm.config import VllmConfig
from vllm.entrypoints.openai.protocol import IOProcessorResponse
from vllm.entrypoints.pooling.pooling.protocol import IOProcessorResponse
from vllm.inputs.data import PromptType
from vllm.outputs import PoolingRequestOutput
from vllm.pooling_params import PoolingParams