mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-02 17:24:36 +08:00
[Misc] Minimum requirements for SageMaker compatibility (#11576)
This commit is contained in:
parent
5dba257506
commit
68d37809b9
13
Dockerfile
13
Dockerfile
@ -234,8 +234,8 @@ RUN mv vllm test_docs/
|
|||||||
#################### TEST IMAGE ####################
|
#################### TEST IMAGE ####################
|
||||||
|
|
||||||
#################### OPENAI API SERVER ####################
|
#################### OPENAI API SERVER ####################
|
||||||
# openai api server alternative
|
# base openai image with additional requirements, for any subsequent openai-style images
|
||||||
FROM vllm-base AS vllm-openai
|
FROM vllm-base AS vllm-openai-base
|
||||||
|
|
||||||
# install additional dependencies for openai api server
|
# install additional dependencies for openai api server
|
||||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||||
@ -247,5 +247,14 @@ RUN --mount=type=cache,target=/root/.cache/pip \
|
|||||||
|
|
||||||
ENV VLLM_USAGE_SOURCE production-docker-image
|
ENV VLLM_USAGE_SOURCE production-docker-image
|
||||||
|
|
||||||
|
# define sagemaker first, so it is not default from `docker build`
|
||||||
|
FROM vllm-openai-base AS vllm-sagemaker
|
||||||
|
|
||||||
|
COPY examples/sagemaker-entrypoint.sh .
|
||||||
|
RUN chmod +x sagemaker-entrypoint.sh
|
||||||
|
ENTRYPOINT ["./sagemaker-entrypoint.sh"]
|
||||||
|
|
||||||
|
FROM vllm-openai-base AS vllm-openai
|
||||||
|
|
||||||
ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]
|
ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]
|
||||||
#################### OPENAI API SERVER ####################
|
#################### OPENAI API SERVER ####################
|
||||||
|
|||||||
24
examples/sagemaker-entrypoint.sh
Normal file
24
examples/sagemaker-entrypoint.sh
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
#!/bin/bash
# SageMaker container entrypoint: translate SM_VLLM_* environment variables
# into CLI flags and exec the OpenAI-compatible vLLM API server.

# Environment variables carrying this prefix are forwarded as server flags.
PREFIX="SM_VLLM_"
ARG_PREFIX="--"

# SageMaker requires inference containers to listen on port 8080:
# https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-code-container-response
ARGS=(--port 8080)

# Each SM_VLLM_FOO_BAR=baz becomes "--foo-bar baz"; an empty value yields a
# bare flag ("--foo-bar").
while IFS='=' read -r name val; do
    flag="${name#"${PREFIX}"}"   # strip the SM_VLLM_ prefix
    flag="${flag,,}"             # lowercase (bash 4+ expansion)
    flag="${flag//_/-}"          # underscores -> dashes

    ARGS+=("${ARG_PREFIX}${flag}")
    [ -n "$val" ] && ARGS+=("$val")
done < <(env | grep "^${PREFIX}")

# Replace this shell with the server process so signals reach it directly.
exec python3 -m vllm.entrypoints.openai.api_server "${ARGS[@]}"
|
||||||
@ -16,7 +16,7 @@ from http import HTTPStatus
|
|||||||
from typing import AsyncIterator, Optional, Set, Tuple
|
from typing import AsyncIterator, Optional, Set, Tuple
|
||||||
|
|
||||||
import uvloop
|
import uvloop
|
||||||
from fastapi import APIRouter, FastAPI, Request
|
from fastapi import APIRouter, FastAPI, HTTPException, Request
|
||||||
from fastapi.exceptions import RequestValidationError
|
from fastapi.exceptions import RequestValidationError
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.responses import JSONResponse, Response, StreamingResponse
|
from fastapi.responses import JSONResponse, Response, StreamingResponse
|
||||||
@ -44,11 +44,15 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
|||||||
CompletionResponse,
|
CompletionResponse,
|
||||||
DetokenizeRequest,
|
DetokenizeRequest,
|
||||||
DetokenizeResponse,
|
DetokenizeResponse,
|
||||||
|
EmbeddingChatRequest,
|
||||||
|
EmbeddingCompletionRequest,
|
||||||
EmbeddingRequest,
|
EmbeddingRequest,
|
||||||
EmbeddingResponse,
|
EmbeddingResponse,
|
||||||
EmbeddingResponseData,
|
EmbeddingResponseData,
|
||||||
ErrorResponse,
|
ErrorResponse,
|
||||||
LoadLoraAdapterRequest,
|
LoadLoraAdapterRequest,
|
||||||
|
PoolingChatRequest,
|
||||||
|
PoolingCompletionRequest,
|
||||||
PoolingRequest, PoolingResponse,
|
PoolingRequest, PoolingResponse,
|
||||||
ScoreRequest, ScoreResponse,
|
ScoreRequest, ScoreResponse,
|
||||||
TokenizeRequest,
|
TokenizeRequest,
|
||||||
@ -310,6 +314,12 @@ async def health(raw_request: Request) -> Response:
|
|||||||
return Response(status_code=200)
|
return Response(status_code=200)
|
||||||
|
|
||||||
|
|
||||||
|
@router.api_route("/ping", methods=["GET", "POST"])
async def ping(raw_request: Request) -> Response:
    """Liveness probe; SageMaker requires a /ping endpoint.

    Delegates to the regular /health handler.
    """
    result = await health(raw_request)
    return result
|
||||||
|
|
||||||
|
|
||||||
@router.post("/tokenize")
|
@router.post("/tokenize")
|
||||||
@with_cancellation
|
@with_cancellation
|
||||||
async def tokenize(request: TokenizeRequest, raw_request: Request):
|
async def tokenize(request: TokenizeRequest, raw_request: Request):
|
||||||
@ -483,6 +493,54 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request):
|
|||||||
return await create_score(request, raw_request)
|
return await create_score(request, raw_request)
|
||||||
|
|
||||||
|
|
||||||
|
# Maps a model task to its (request schema, handler) pairs for /invocations.
# The "messages" entry serves chat-style payloads; "default" covers the rest.
TASK_HANDLERS = {
    "generate": {
        "messages": (ChatCompletionRequest, create_chat_completion),
        "default": (CompletionRequest, create_completion),
    },
    "embed": {
        "messages": (EmbeddingChatRequest, create_embedding),
        "default": (EmbeddingCompletionRequest, create_embedding),
    },
    "score": {
        "default": (ScoreRequest, create_score),
    },
    "reward": {
        "messages": (PoolingChatRequest, create_pooling),
        "default": (PoolingCompletionRequest, create_pooling),
    },
    "classify": {
        "messages": (PoolingChatRequest, create_pooling),
        "default": (PoolingCompletionRequest, create_pooling),
    },
}


@router.post("/invocations")
async def invocations(raw_request: Request):
    """
    For SageMaker, routes requests to other handlers based on model `task`.
    """
    body = await raw_request.json()
    task = raw_request.app.state.task

    if task not in TASK_HANDLERS:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported task: '{task}' for '/invocations'. "
            f"Expected one of {set(TASK_HANDLERS.keys())}")

    # Chat-style bodies carry a "messages" key; everything else takes the
    # task's default schema/handler.
    variant = "messages" if "messages" in body else "default"
    request_model, handler = TASK_HANDLERS[task][variant]

    # this is required since we lose the FastAPI automatic casting
    request = request_model.model_validate(body)
    return await handler(request, raw_request)
|
||||||
|
|
||||||
|
|
||||||
if envs.VLLM_TORCH_PROFILER_DIR:
|
if envs.VLLM_TORCH_PROFILER_DIR:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Torch Profiler is enabled in the API server. This should ONLY be "
|
"Torch Profiler is enabled in the API server. This should ONLY be "
|
||||||
@ -687,6 +745,7 @@ def init_app_state(
|
|||||||
chat_template=resolved_chat_template,
|
chat_template=resolved_chat_template,
|
||||||
chat_template_content_format=args.chat_template_content_format,
|
chat_template_content_format=args.chat_template_content_format,
|
||||||
)
|
)
|
||||||
|
state.task = model_config.task
|
||||||
|
|
||||||
|
|
||||||
def create_server_socket(addr: Tuple[str, int]) -> socket.socket:
|
def create_server_socket(addr: Tuple[str, int]) -> socket.socket:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user