[Refactor] [1/N] to simplify the vLLM serving architecture (#28040)
Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
Parent: 69520bc695
Commit: 3f42b05fbc
@@ -232,7 +232,7 @@ async def test_server_load(server: RemoteOpenAIServer):
 @pytest.mark.asyncio
 async def test_health_check_engine_dead_error():
     # Import the health function directly to test it in isolation
-    from vllm.entrypoints.openai.api_server import health
+    from vllm.entrypoints.serve.instrumentator.health import health

     # Create a mock request that simulates what FastAPI would provide
     mock_request = Mock(spec=Request)
@@ -118,6 +118,7 @@ async def init_app(
         )
     )
     app.state.engine_client = engine
+    app.state.args = args
     return app


@@ -20,21 +20,15 @@ from http import HTTPStatus
 from typing import Annotated, Any, Literal

 import model_hosting_container_standards.sagemaker as sagemaker_standards
-import prometheus_client
 import pydantic
-import regex as re
 import uvloop
 from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Query, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, Response, StreamingResponse
-from prometheus_client import make_asgi_app
-from prometheus_fastapi_instrumentator import Instrumentator
 from starlette.concurrency import iterate_in_threadpool
 from starlette.datastructures import URL, Headers, MutableHeaders, State
-from starlette.routing import Mount
 from starlette.types import ASGIApp, Message, Receive, Scope, Send
-from typing_extensions import assert_never

 import vllm.envs as envs
 from vllm.config import VllmConfig
@@ -56,17 +50,11 @@ from vllm.entrypoints.openai.protocol import (
     ChatCompletionResponse,
     CompletionRequest,
     CompletionResponse,
-    DetokenizeRequest,
-    DetokenizeResponse,
     ErrorInfo,
     ErrorResponse,
-    GenerateRequest,
-    GenerateResponse,
     ResponsesRequest,
     ResponsesResponse,
     StreamingResponsesResponse,
-    TokenizeRequest,
-    TokenizeResponse,
     TranscriptionRequest,
     TranscriptionResponseVariant,
     TranslationRequest,
@@ -80,8 +68,6 @@ from vllm.entrypoints.openai.serving_models import (
     OpenAIServingModels,
 )
 from vllm.entrypoints.openai.serving_responses import OpenAIServingResponses
-from vllm.entrypoints.openai.serving_tokenization import OpenAIServingTokenization
-from vllm.entrypoints.openai.serving_tokens import ServingTokens
 from vllm.entrypoints.openai.serving_transcription import (
     OpenAIServingTranscription,
     OpenAIServingTranslation,
@@ -92,6 +78,11 @@ from vllm.entrypoints.pooling.classify.serving import ServingClassification
 from vllm.entrypoints.pooling.embed.serving import OpenAIServingEmbedding
 from vllm.entrypoints.pooling.pooling.serving import OpenAIServingPooling
 from vllm.entrypoints.pooling.score.serving import ServingScores
+from vllm.entrypoints.serve.disagg.serving import ServingTokens
+from vllm.entrypoints.serve.elastic_ep.middleware import (
+    ScalingMiddleware,
+)
+from vllm.entrypoints.serve.tokenize.serving import OpenAIServingTokenization
 from vllm.entrypoints.tool_server import DemoToolServer, MCPToolServer, ToolServer
 from vllm.entrypoints.utils import (
     cli_env_setup,
@@ -109,8 +100,6 @@ from vllm.utils.argparse_utils import FlexibleArgumentParser
 from vllm.utils.gc_utils import freeze_gc_heap
 from vllm.utils.network_utils import is_valid_ipv6_address
 from vllm.utils.system_utils import decorate_logs, set_ulimit
-from vllm.v1.engine.exceptions import EngineDeadError
-from vllm.v1.metrics.prometheus import get_prometheus_registry
 from vllm.version import __version__ as VLLM_VERSION

 prometheus_multiproc_dir: tempfile.TemporaryDirectory
@@ -245,39 +234,6 @@ async def build_async_engine_client_from_engine_args(
 router = APIRouter()


-class PrometheusResponse(Response):
-    media_type = prometheus_client.CONTENT_TYPE_LATEST
-
-
-def mount_metrics(app: FastAPI):
-    """Mount prometheus metrics to a FastAPI app."""
-
-    registry = get_prometheus_registry()
-
-    # `response_class=PrometheusResponse` is needed to return an HTTP response
-    # with header "Content-Type: text/plain; version=0.0.4; charset=utf-8"
-    # instead of the default "application/json" which is incorrect.
-    # See https://github.com/trallnag/prometheus-fastapi-instrumentator/issues/163#issue-1296092364
-    Instrumentator(
-        excluded_handlers=[
-            "/metrics",
-            "/health",
-            "/load",
-            "/ping",
-            "/version",
-            "/server_info",
-        ],
-        registry=registry,
-    ).add().instrument(app).expose(app, response_class=PrometheusResponse)
-
-    # Add prometheus asgi middleware to route /metrics requests
-    metrics_route = Mount("/metrics", make_asgi_app(registry=registry))
-
-    # Workaround for 307 Redirect for /metrics
-    metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
-    app.routes.append(metrics_route)
-
-
 def base(request: Request) -> OpenAIServing:
     # Reuse the existing instance
     return tokenization(request)
@@ -323,16 +279,6 @@ def generate_tokens(request: Request) -> ServingTokens | None:
     return request.app.state.serving_tokens


-@router.get("/health", response_class=Response)
-async def health(raw_request: Request) -> Response:
-    """Health check."""
-    try:
-        await engine_client(raw_request).check_health()
-        return Response(status_code=200)
-    except EngineDeadError:
-        return Response(status_code=503)
-
-
 @router.get("/load")
 async def get_server_load_metrics(request: Request):
     # This endpoint returns the current server load metrics.
@@ -352,167 +298,6 @@ async def get_server_load_metrics(request: Request):
     return JSONResponse(content={"server_load": request.app.state.server_load_metrics})


-@router.post("/pause")
-async def pause_generation(
-    raw_request: Request,
-    wait_for_inflight_requests: bool = Query(False),
-    clear_cache: bool = Query(True),
-) -> JSONResponse:
-    """Pause generation requests to allow weight updates.
-
-    Args:
-        wait_for_inflight_requests: When ``True`` waits for in-flight
-            requests to finish before pausing. When ``False`` (default),
-            aborts any in-flight requests immediately.
-        clear_cache: Whether to clear KV/prefix caches after draining.
-    """
-
-    engine = engine_client(raw_request)
-
-    try:
-        await engine.pause_generation(
-            wait_for_inflight_requests=wait_for_inflight_requests,
-            clear_cache=clear_cache,
-        )
-        return JSONResponse(
-            content={"status": "paused"},
-            status_code=HTTPStatus.OK.value,
-        )
-
-    except ValueError as err:
-        return JSONResponse(
-            content={"error": str(err)},
-            status_code=HTTPStatus.BAD_REQUEST.value,
-        )
-    except Exception as err:  # pragma: no cover - defensive
-        logger.exception("Failed to pause generation")
-        return JSONResponse(
-            content={"error": f"Failed to pause generation: {err}"},
-            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
-        )
-
-
-@router.post("/resume")
-async def resume_generation(raw_request: Request) -> JSONResponse:
-    """Resume generation after a pause."""
-
-    engine = engine_client(raw_request)
-
-    try:
-        await engine.resume_generation()
-        return JSONResponse(
-            content={"status": "resumed"},
-            status_code=HTTPStatus.OK.value,
-        )
-    except Exception as err:  # pragma: no cover - defensive
-        logger.exception("Failed to resume generation")
-        return JSONResponse(
-            content={"error": f"Failed to resume generation: {err}"},
-            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
-        )
-
-
-@router.get("/is_paused")
-async def is_paused(raw_request: Request) -> JSONResponse:
-    """Return the current pause status."""
-
-    engine = engine_client(raw_request)
-
-    try:
-        paused = await engine.is_paused()
-    except Exception as err:  # pragma: no cover - defensive
-        logger.exception("Failed to fetch pause status")
-        return JSONResponse(
-            content={"error": f"Failed to fetch pause status: {err}"},
-            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
-        )
-
-    return JSONResponse(content={"is_paused": paused})
-
-
-@router.post(
-    "/tokenize",
-    dependencies=[Depends(validate_json_request)],
-    responses={
-        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
-        HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
-        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
-        HTTPStatus.NOT_IMPLEMENTED.value: {"model": ErrorResponse},
-    },
-)
-@with_cancellation
-async def tokenize(request: TokenizeRequest, raw_request: Request):
-    handler = tokenization(raw_request)
-
-    try:
-        generator = await handler.create_tokenize(request, raw_request)
-    except NotImplementedError as e:
-        raise HTTPException(
-            status_code=HTTPStatus.NOT_IMPLEMENTED.value, detail=str(e)
-        ) from e
-    except Exception as e:
-        raise HTTPException(
-            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
-        ) from e
-
-    if isinstance(generator, ErrorResponse):
-        return JSONResponse(
-            content=generator.model_dump(), status_code=generator.error.code
-        )
-    elif isinstance(generator, TokenizeResponse):
-        return JSONResponse(content=generator.model_dump())
-
-    assert_never(generator)
-
-
-@router.post(
-    "/detokenize",
-    dependencies=[Depends(validate_json_request)],
-    responses={
-        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
-        HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
-        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
-    },
-)
-@with_cancellation
-async def detokenize(request: DetokenizeRequest, raw_request: Request):
-    handler = tokenization(raw_request)
-
-    try:
-        generator = await handler.create_detokenize(request, raw_request)
-    except OverflowError as e:
-        raise RequestValidationError(errors=[str(e)]) from e
-    except Exception as e:
-        raise HTTPException(
-            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
-        ) from e
-
-    if isinstance(generator, ErrorResponse):
-        return JSONResponse(
-            content=generator.model_dump(), status_code=generator.error.code
-        )
-    elif isinstance(generator, DetokenizeResponse):
-        return JSONResponse(content=generator.model_dump())
-
-    assert_never(generator)
-
-
-def maybe_register_tokenizer_info_endpoint(args):
-    """Conditionally register the tokenizer info endpoint if enabled."""
-    if getattr(args, "enable_tokenizer_info_endpoint", False):
-
-        @router.get("/tokenizer_info")
-        async def get_tokenizer_info(raw_request: Request):
-            """Get comprehensive tokenizer information."""
-            result = await tokenization(raw_request).get_tokenizer_info()
-            return JSONResponse(
-                content=result.model_dump(),
-                status_code=result.error.code
-                if isinstance(result, ErrorResponse)
-                else 200,
-            )
-
-
 @router.get("/v1/models")
 async def show_available_models(raw_request: Request):
     handler = models(raw_request)
@@ -898,33 +683,6 @@ if envs.VLLM_SERVER_DEV_MODE:
         await engine_client(raw_request).reset_mm_cache()
         return Response(status_code=200)

-    @router.post("/sleep")
-    async def sleep(raw_request: Request):
-        # get POST params
-        level = raw_request.query_params.get("level", "1")
-        await engine_client(raw_request).sleep(int(level))
-        # FIXME: in v0 with frontend multiprocessing, the sleep command
-        # is sent but does not finish yet when we return a response.
-        return Response(status_code=200)
-
-    @router.post("/wake_up")
-    async def wake_up(raw_request: Request):
-        tags = raw_request.query_params.getlist("tags")
-        if tags == []:
-            # set to None to wake up all tags if no tags are provided
-            tags = None
-        logger.info("wake up the engine with tags: %s", tags)
-        await engine_client(raw_request).wake_up(tags)
-        # FIXME: in v0 with frontend multiprocessing, the wake-up command
-        # is sent but does not finish yet when we return a response.
-        return Response(status_code=200)
-
-    @router.get("/is_sleeping")
-    async def is_sleeping(raw_request: Request):
-        logger.info("check whether the engine is sleeping")
-        is_sleeping = await engine_client(raw_request).is_sleeping()
-        return JSONResponse(content={"is_sleeping": is_sleeping})
-
     @router.post("/collective_rpc")
     async def collective_rpc(raw_request: Request):
         try:
@@ -952,138 +710,13 @@ if envs.VLLM_SERVER_DEV_MODE:
             return Response(status_code=200)
         response: list[Any] = []
         for result in results:
-            if result is None or isinstance(result, (dict, list)):
+            if result is None or isinstance(result, dict | list):
                 response.append(result)
             else:
                 response.append(str(result))
         return JSONResponse(content={"results": response})


-    @router.post(
-        "/scale_elastic_ep",
-        dependencies=[Depends(validate_json_request)],
-        responses={
-            HTTPStatus.OK.value: {"model": dict},
-            HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
-            HTTPStatus.REQUEST_TIMEOUT.value: {"model": ErrorResponse},
-            HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
-        },
-    )
-    async def scale_elastic_ep(raw_request: Request):
-        try:
-            body = await raw_request.json()
-        except json.JSONDecodeError as e:
-            raise HTTPException(status_code=400, detail="Invalid JSON format") from e  # noqa: B904
-
-        new_data_parallel_size = body.get("new_data_parallel_size")
-        drain_timeout = body.get("drain_timeout", 120)  # Default 2 minutes
-
-        if new_data_parallel_size is None:
-            raise HTTPException(
-                status_code=400, detail="new_data_parallel_size is required"
-            )
-
-        if not isinstance(new_data_parallel_size, int) or new_data_parallel_size <= 0:
-            raise HTTPException(
-                status_code=400, detail="new_data_parallel_size must be a positive integer"
-            )
-
-        if not isinstance(drain_timeout, int) or drain_timeout <= 0:
-            raise HTTPException(
-                status_code=400, detail="drain_timeout must be a positive integer"
-            )
-
-        # Set scaling flag to prevent new requests
-        global _scaling_elastic_ep
-        _scaling_elastic_ep = True
-        client = engine_client(raw_request)
-        try:
-            await client.scale_elastic_ep(new_data_parallel_size, drain_timeout)
-            return JSONResponse(
-                {
-                    "message": f"Scaled to {new_data_parallel_size} data parallel engines",
-                }
-            )
-        except TimeoutError as e:
-            raise HTTPException(
-                status_code=408,
-                detail="Scale failed due to request drain timeout "
-                f"after {drain_timeout} seconds",
-            ) from e
-        except Exception as e:
-            logger.error("Scale failed: %s", e)
-            raise HTTPException(status_code=500, detail="Scale failed") from e
-        finally:
-            _scaling_elastic_ep = False
-
-
-    @router.post("/is_scaling_elastic_ep")
-    async def is_scaling_elastic_ep(raw_request: Request):
-        return JSONResponse({"is_scaling_elastic_ep": _scaling_elastic_ep})
-
-
-@router.post(
-    "/inference/v1/generate",
-    dependencies=[Depends(validate_json_request)],
-    responses={
-        HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
-        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
-        HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
-        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
-    },
-)
-@with_cancellation
-@load_aware_call
-async def generate(request: GenerateRequest, raw_request: Request):
-    handler = generate_tokens(raw_request)
-    if handler is None:
-        return base(raw_request).create_error_response(
-            message="The model does not support generate tokens API"
-        )
-    try:
-        generator = await handler.serve_tokens(request, raw_request)
-    except Exception as e:
-        raise HTTPException(
-            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
-        ) from e
-    if isinstance(generator, ErrorResponse):
-        return JSONResponse(
-            content=generator.model_dump(), status_code=generator.error.code
-        )
-
-    elif isinstance(generator, GenerateResponse):
-        return JSONResponse(content=generator.model_dump())
-
-    return StreamingResponse(content=generator, media_type="text/event-stream")
-
-
-if envs.VLLM_TORCH_PROFILER_DIR:
-    logger.warning_once(
-        "Torch Profiler is enabled in the API server. This should ONLY be "
-        "used for local development!"
-    )
-elif envs.VLLM_TORCH_CUDA_PROFILE:
-    logger.warning_once(
-        "CUDA Profiler is enabled in the API server. This should ONLY be "
-        "used for local development!"
-    )
-if envs.VLLM_TORCH_PROFILER_DIR or envs.VLLM_TORCH_CUDA_PROFILE:
-
-    @router.post("/start_profile")
-    async def start_profile(raw_request: Request):
-        logger.info("Starting profiler...")
-        await engine_client(raw_request).start_profile()
-        logger.info("Profiler started.")
-        return Response(status_code=200)
-
-    @router.post("/stop_profile")
-    async def stop_profile(raw_request: Request):
-        logger.info("Stopping profiler...")
-        await engine_client(raw_request).stop_profile()
-        logger.info("Profiler stopped.")
-        return Response(status_code=200)
-
-
 def load_log_config(log_config_file: str | None) -> dict | None:
     if not log_config_file:
         return None
@@ -1176,41 +809,6 @@ class XRequestIdMiddleware:
         return self.app(scope, receive, send_with_request_id)


-# Global variable to track scaling state
-_scaling_elastic_ep = False
-
-
-class ScalingMiddleware:
-    """
-    Middleware that checks if the model is currently scaling and
-    returns a 503 Service Unavailable response if it is.
-
-    This middleware applies to all HTTP requests and prevents
-    processing when the model is in a scaling state.
-    """
-
-    def __init__(self, app: ASGIApp) -> None:
-        self.app = app
-
-    def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]:
-        if scope["type"] != "http":
-            return self.app(scope, receive, send)
-
-        # Check global scaling state
-        global _scaling_elastic_ep
-        if _scaling_elastic_ep:
-            # Return 503 Service Unavailable response
-            response = JSONResponse(
-                content={
-                    "error": "The model is currently scaling. Please try again later."
-                },
-                status_code=503,
-            )
-            return response(scope, receive, send)
-
-        return self.app(scope, receive, send)
-
-
 def _extract_content_from_chunk(chunk_data: dict) -> str:
     """Extract content from a streaming response chunk."""
     try:
@@ -1353,15 +951,10 @@ def build_app(args: Namespace) -> FastAPI:
         )
     else:
         app = FastAPI(lifespan=lifespan)
+    app.state.args = args
+    from vllm.entrypoints.serve import register_vllm_serve_api_routers

-    if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
-        logger.warning(
-            "LoRA dynamic loading & unloading is enabled in the API server. "
-            "This should ONLY be used for local development!"
-        )
-        from vllm.entrypoints.dynamic_lora import register_dynamic_lora_routes
-
-        register_dynamic_lora_routes(router)
+    register_vllm_serve_api_routers(app)

     from vllm.entrypoints.sagemaker.routes import register_sagemaker_routes

@@ -1370,8 +963,6 @@ def build_app(args: Namespace) -> FastAPI:

     app.root_path = args.root_path

-    mount_metrics(app)
-
     from vllm.entrypoints.pooling import register_pooling_api_routers

     register_pooling_api_routers(app)
@@ -1462,31 +1053,6 @@ def build_app(args: Namespace) -> FastAPI:
         )

     app = sagemaker_standards.bootstrap(app)
-    # Optional endpoints
-    if args.tokens_only:
-
-        @app.post("/abort_requests")
-        async def abort_requests(raw_request: Request):
-            """
-            Abort one or more requests. To be used in a
-            Disaggregated Everything setup.
-            """
-            try:
-                body = await raw_request.json()
-            except json.JSONDecodeError as e:
-                raise HTTPException(
-                    status_code=HTTPStatus.BAD_REQUEST.value,
-                    detail=f"JSON decode error: {e}",
-                ) from e
-            request_ids = body.get("request_ids")
-            if request_ids is None:
-                raise HTTPException(
-                    status_code=HTTPStatus.BAD_REQUEST.value,
-                    detail="Missing 'request_ids' in request body",
-                )
-            # Abort requests in background
-            asyncio.create_task(engine_client(raw_request).abort(request_ids))
-            return Response(status_code=200)

     return app
@@ -1515,7 +1081,7 @@ async def init_app_state(
     state.engine_client = engine_client
     state.log_stats = not args.disable_log_stats
     state.vllm_config = vllm_config
+    state.args = args
     supported_tasks = await engine_client.get_supported_tasks()
     logger.info("Supported tasks: %s", supported_tasks)

@@ -1839,7 +1405,6 @@ async def run_server_worker(
         args,
         client_config=client_config,
     ) as engine_client:
-        maybe_register_tokenizer_info_endpoint(args)
         app = build_app(args)

         await init_app_state(engine_client, app.state, args)
@@ -74,8 +74,6 @@ from vllm.entrypoints.openai.protocol import (
     ErrorResponse,
     FunctionCall,
     FunctionDefinition,
-    GenerateRequest,
-    GenerateResponse,
     ResponsesRequest,
     TokenizeChatRequest,
     TokenizeCompletionRequest,
@@ -87,6 +85,7 @@ from vllm.entrypoints.openai.protocol import (
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
 from vllm.entrypoints.renderer import BaseRenderer, CompletionRenderer, RenderConfig
+from vllm.entrypoints.serve.disagg.protocol import GenerateRequest, GenerateResponse
 from vllm.entrypoints.utils import _validate_truncation_size
 from vllm.inputs.data import PromptType
 from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
@@ -16,7 +16,6 @@ from vllm.entrypoints.openai.api_server import (
     completion,
     create_chat_completion,
     create_completion,
-    health,
     validate_json_request,
 )
 from vllm.entrypoints.openai.protocol import (
@@ -38,6 +37,7 @@ from vllm.entrypoints.pooling.score.api_router import (
     score,
 )
 from vllm.entrypoints.pooling.score.protocol import RerankRequest, ScoreRequest
+from vllm.entrypoints.serve.instrumentator.health import health

 # TODO: RequestType = TypeForm[BaseModel] when recognized by type checkers
 # (requires typing_extensions >= 4.13)
vllm/entrypoints/serve/__init__.py (new file)
@@ -0,0 +1,60 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+
+from fastapi import FastAPI
+
+
+def register_vllm_serve_api_routers(app: FastAPI):
+    from vllm.entrypoints.serve.lora.api_router import (
+        attach_router as attach_lora_router,
+    )
+
+    attach_lora_router(app)
+    from vllm.entrypoints.serve.elastic_ep.api_router import (
+        attach_router as attach_elastic_ep_router,
+    )
+
+    attach_elastic_ep_router(app)
+
+    from vllm.entrypoints.serve.profile.api_router import (
+        attach_router as attach_profile_router,
+    )
+
+    attach_profile_router(app)
+
+    from vllm.entrypoints.serve.sleep.api_router import (
+        attach_router as attach_sleep_router,
+    )
+
+    attach_sleep_router(app)
+
+    from vllm.entrypoints.serve.tokenize.api_router import (
+        attach_router as attach_tokenize_router,
+    )
+
+    attach_tokenize_router(app)
+
+    from vllm.entrypoints.serve.disagg.api_router import (
+        attach_router as attach_disagg_router,
+    )
+
+    attach_disagg_router(app)
+
+    from vllm.entrypoints.serve.rlhf.api_router import (
+        attach_router as attach_rlhf_router,
+    )
+
+    attach_rlhf_router(app)
+
+    from vllm.entrypoints.serve.instrumentator.metrics import (
+        attach_router as attach_metrics_router,
+    )
+
+    attach_metrics_router(app)
+
+    from vllm.entrypoints.serve.instrumentator.health import (
+        attach_router as attach_health_router,
+    )
+
+    attach_health_router(app)
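build_app() now drives all of the sub-routers above through a single call (see the @@ -1353,15 +951,10 @@ hunk earlier in this diff): it stores the CLI namespace on app.state.args and then calls register_vllm_serve_api_routers(app). A minimal, hedged sketch of that wiring on a bare FastAPI app; the Namespace fields and the commented engine client stub are illustrative assumptions, not part of this PR:

# Sketch only: mirrors how build_app()/init_app_state() feed the new routers.
from argparse import Namespace

from fastapi import FastAPI

from vllm.entrypoints.serve import register_vllm_serve_api_routers

app = FastAPI()
# The sub-routers read configuration from app.state, so populate it first.
app.state.args = Namespace(tokens_only=False)  # hypothetical subset of the real CLI namespace
# app.state.engine_client = engine_client      # normally set by init_app_state()

# Mounts the lora, elastic_ep, profile, sleep, tokenize, disagg, rlhf, metrics and health routers.
register_vllm_serve_api_routers(app)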
vllm/entrypoints/serve/disagg/__init__.py (new file, empty)
vllm/entrypoints/serve/disagg/api_router.py (new file)
@@ -0,0 +1,110 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+
+import asyncio
+import json
+from http import HTTPStatus
+
+from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, Response
+from fastapi.responses import JSONResponse, StreamingResponse
+
+from vllm.engine.protocol import EngineClient
+from vllm.entrypoints.openai.api_server import validate_json_request
+from vllm.entrypoints.openai.protocol import (
+    ErrorResponse,
+)
+from vllm.entrypoints.serve.disagg.protocol import (
+    GenerateRequest,
+    GenerateResponse,
+)
+from vllm.entrypoints.serve.disagg.serving import (
+    ServingTokens,
+)
+from vllm.entrypoints.serve.tokenize.serving import OpenAIServingTokenization
+from vllm.entrypoints.utils import (
+    load_aware_call,
+    with_cancellation,
+)
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+def tokenization(request: Request) -> OpenAIServingTokenization:
+    return request.app.state.openai_serving_tokenization
+
+
+def generate_tokens(request: Request) -> ServingTokens | None:
+    return request.app.state.serving_tokens
+
+
+def engine_client(request: Request) -> EngineClient:
+    return request.app.state.engine_client
+
+
+router = APIRouter()
+
+
+@router.post(
+    "/inference/v1/generate",
+    dependencies=[Depends(validate_json_request)],
+    responses={
+        HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
+        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
+        HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
+        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
+    },
+)
+@with_cancellation
+@load_aware_call
+async def generate(request: GenerateRequest, raw_request: Request):
+    handler = generate_tokens(raw_request)
+    if handler is None:
+        return tokenization(raw_request).create_error_response(
+            message="The model does not support generate tokens API"
+        )
+    try:
+        generator = await handler.serve_tokens(request, raw_request)
+    except Exception as e:
+        raise HTTPException(
+            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
+        ) from e
+    if isinstance(generator, ErrorResponse):
+        return JSONResponse(
+            content=generator.model_dump(), status_code=generator.error.code
+        )
+
+    elif isinstance(generator, GenerateResponse):
+        return JSONResponse(content=generator.model_dump())
+
+    return StreamingResponse(content=generator, media_type="text/event-stream")
+
+
+def attach_router(app: FastAPI):
+    if getattr(app.state.args, "tokens_only", False):
+
+        @router.post("/abort_requests")
+        async def abort_requests(raw_request: Request):
+            """
+            Abort one or more requests. To be used in a
+            Disaggregated Everything setup.
+            """
+            try:
+                body = await raw_request.json()
+            except json.JSONDecodeError as e:
+                raise HTTPException(
+                    status_code=HTTPStatus.BAD_REQUEST.value,
+                    detail=f"JSON decode error: {e}",
+                ) from e
+            request_ids = body.get("request_ids")
+            if request_ids is None:
+                raise HTTPException(
+                    status_code=HTTPStatus.BAD_REQUEST.value,
+                    detail="Missing 'request_ids' in request body",
+                )
+            # Abort requests in background
+            asyncio.create_task(engine_client(raw_request).abort(request_ids))
+            return Response(status_code=200)

+    app.include_router(router)
vllm/entrypoints/serve/disagg/protocol.py (new file)
@@ -0,0 +1,90 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Any
+
+from pydantic import BaseModel, Field
+
+from vllm.entrypoints.openai.protocol import (
+    ChatCompletionLogProbs,
+    Logprob,
+    SamplingParams,
+    StreamOptions,
+)
+from vllm.utils import random_uuid
+
+
+####### Tokens IN <> Tokens OUT #######
+class GenerateRequest(BaseModel):
+    request_id: str = Field(
+        default_factory=lambda: f"{random_uuid()}",
+        description=(
+            "The request_id related to this request. If the caller does "
+            "not set it, a random_uuid will be generated. This id is used "
+            "through out the inference process and return in response."
+        ),
+    )
+    token_ids: list[int]
+    """The token ids to generate text from."""
+
+    # features: MultiModalFeatureSpec
+    # TODO (NickLucche): implement once Renderer work is completed
+    features: str | None = None
+    """The processed MM inputs for the model."""
+
+    sampling_params: SamplingParams
+    """The sampling parameters for the model."""
+
+    model: str | None = None
+
+    stream: bool | None = False
+    stream_options: StreamOptions | None = None
+    cache_salt: str | None = Field(
+        default=None,
+        description=(
+            "If specified, the prefix cache will be salted with the provided "
+            "string to prevent an attacker to guess prompts in multi-user "
+            "environments. The salt should be random, protected from "
+            "access by 3rd parties, and long enough to be "
+            "unpredictable (e.g., 43 characters base64-encoded, corresponding "
+            "to 256 bit)."
+        ),
+    )
+    priority: int = Field(
+        default=0,
+        description=(
+            "The priority of the request (lower means earlier handling; "
+            "default: 0). Any priority other than 0 will raise an error "
+            "if the served model does not use priority scheduling."
+        ),
+    )
+    kv_transfer_params: dict[str, Any] | None = Field(
+        default=None,
+        description="KVTransfer parameters used for disaggregated serving.",
+    )
+
+
+class GenerateResponseChoice(BaseModel):
+    index: int
+    logprobs: ChatCompletionLogProbs | None = None
+    # per OpenAI spec this is the default
+    finish_reason: str | None = "stop"
+    token_ids: list[int] | None = None
+
+
+class GenerateResponse(BaseModel):
+    request_id: str = Field(
+        default_factory=lambda: f"{random_uuid()}",
+        description=(
+            "The request_id related to this request. If the caller does "
+            "not set it, a random_uuid will be generated. This id is used "
+            "through out the inference process and return in response."
+        ),
+    )
+    choices: list[GenerateResponseChoice]
+
+    prompt_logprobs: list[dict[int, Logprob] | None] | None = None
+
+    kv_transfer_params: dict[str, Any] | None = Field(
+        default=None,
+        description="KVTransfer parameters used for disaggregated serving.",
+    )
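For orientation, a hedged client-side sketch of the tokens-in/tokens-out exchange these models describe, targeting the /inference/v1/generate route from api_router.py above. The host/port, the client library, and the exact shape accepted for sampling_params are assumptions; treat all values as placeholders.

# Hypothetical request against the disagg generate route (not part of this PR).
import requests  # illustrative HTTP client choice

payload = {
    "request_id": "demo-0001",              # optional; a random uuid is generated otherwise
    "token_ids": [101, 2023, 2003, 102],    # prompt token ids from an external tokenizer
    "sampling_params": {"max_tokens": 32},  # validated into SamplingParams server-side
    "stream": False,
}
resp = requests.post("http://localhost:8000/inference/v1/generate", json=payload, timeout=60)
data = resp.json()  # GenerateResponse: request_id, choices, prompt_logprobs, kv_transfer_params
print(data["choices"][0]["token_ids"], data["choices"][0]["finish_reason"])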
@@ -1,5 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+
 import asyncio
 import time
 from collections.abc import AsyncGenerator
@@ -14,15 +16,17 @@ from vllm.entrypoints.openai.protocol import (
     ChatCompletionLogProbs,
     ChatCompletionLogProbsContent,
     ErrorResponse,
-    GenerateRequest,
-    GenerateResponse,
-    GenerateResponseChoice,
     PromptTokenUsageInfo,
     RequestResponseMetadata,
     UsageInfo,
 )
 from vllm.entrypoints.openai.serving_engine import OpenAIServing, clamp_prompt_logprobs
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
+from vllm.entrypoints.serve.disagg.protocol import (
+    GenerateRequest,
+    GenerateResponse,
+    GenerateResponseChoice,
+)
 from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
 from vllm.logger import init_logger
 from vllm.logprobs import Logprob
vllm/entrypoints/serve/elastic_ep/__init__.py (new file, empty)
vllm/entrypoints/serve/elastic_ep/api_router.py (new file)
@@ -0,0 +1,96 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+
+import json
+from http import HTTPStatus
+
+from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request
+from fastapi.responses import JSONResponse
+
+from vllm.engine.protocol import EngineClient
+from vllm.entrypoints.openai.api_server import validate_json_request
+from vllm.entrypoints.openai.protocol import (
+    ErrorResponse,
+)
+from vllm.entrypoints.serve.elastic_ep.middleware import (
+    get_scaling_elastic_ep,
+    set_scaling_elastic_ep,
+)
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+def engine_client(request: Request) -> EngineClient:
+    return request.app.state.engine_client
+
+
+router = APIRouter()
+
+
+@router.post(
+    "/scale_elastic_ep",
+    dependencies=[Depends(validate_json_request)],
+    responses={
+        HTTPStatus.OK.value: {"model": dict},
+        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
+        HTTPStatus.REQUEST_TIMEOUT.value: {"model": ErrorResponse},
+        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
+    },
+)
+async def scale_elastic_ep(raw_request: Request):
+    try:
+        body = await raw_request.json()
+    except json.JSONDecodeError as e:
+        raise HTTPException(status_code=400, detail="Invalid JSON format") from e  # noqa: B904
+
+    new_data_parallel_size = body.get("new_data_parallel_size")
+    drain_timeout = body.get("drain_timeout", 120)  # Default 2 minutes
+
+    if new_data_parallel_size is None:
+        raise HTTPException(
+            status_code=400, detail="new_data_parallel_size is required"
+        )
+
+    if not isinstance(new_data_parallel_size, int) or new_data_parallel_size <= 0:
+        raise HTTPException(
+            status_code=400,
+            detail="new_data_parallel_size must be a positive integer",
+        )
+
+    if not isinstance(drain_timeout, int) or drain_timeout <= 0:
+        raise HTTPException(
+            status_code=400, detail="drain_timeout must be a positive integer"
+        )
+
+    # Set scaling flag to prevent new requests
+    set_scaling_elastic_ep(True)
+    client = engine_client(raw_request)
+    try:
+        await client.scale_elastic_ep(new_data_parallel_size, drain_timeout)
+        return JSONResponse(
+            {
+                "message": f"Scaled to {new_data_parallel_size} data parallel engines",
+            }
+        )
+    except TimeoutError as e:
+        raise HTTPException(
+            status_code=408,
+            detail="Scale failed due to request drain timeout "
+            f"after {drain_timeout} seconds",
+        ) from e
+    except Exception as e:
+        logger.error("Scale failed: %s", e)
+        raise HTTPException(status_code=500, detail="Scale failed") from e
+    finally:
+        set_scaling_elastic_ep(False)
+
+
+@router.post("/is_scaling_elastic_ep")
+async def is_scaling_elastic_ep(raw_request: Request):
+    return JSONResponse({"is_scaling_elastic_ep": get_scaling_elastic_ep()})
+
+
+def attach_router(app: FastAPI):
+    app.include_router(router)
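A hedged sketch of exercising the endpoints above; the JSON keys mirror what scale_elastic_ep reads (new_data_parallel_size, optional drain_timeout), while the base URL and client library are assumptions.

# Hypothetical client calls (not part of this PR).
import requests

base = "http://localhost:8000"
r = requests.post(
    f"{base}/scale_elastic_ep",
    json={"new_data_parallel_size": 4, "drain_timeout": 120},
    timeout=300,
)
print(r.status_code, r.json())  # 200 + {"message": "Scaled to 4 data parallel engines"} on success

# While a scale is in flight, the flag toggled by set_scaling_elastic_ep(True) is visible here.
print(requests.post(f"{base}/is_scaling_elastic_ep").json())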
vllm/entrypoints/serve/elastic_ep/middleware.py (new file)
@@ -0,0 +1,49 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from collections.abc import Awaitable
+
+from fastapi.responses import JSONResponse
+from starlette.types import ASGIApp, Receive, Scope, Send
+
+# Global variable to track scaling state
+_scaling_elastic_ep = False
+
+
+def get_scaling_elastic_ep():
+    return _scaling_elastic_ep
+
+
+def set_scaling_elastic_ep(value):
+    global _scaling_elastic_ep
+    _scaling_elastic_ep = value
+
+
+class ScalingMiddleware:
+    """
+    Middleware that checks if the model is currently scaling and
+    returns a 503 Service Unavailable response if it is.
+
+    This middleware applies to all HTTP requests and prevents
+    processing when the model is in a scaling state.
+    """
+
+    def __init__(self, app: ASGIApp) -> None:
+        self.app = app
+
+    def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]:
+        if scope["type"] != "http":
+            return self.app(scope, receive, send)
+
+        # Check global scaling state
+        if get_scaling_elastic_ep():
+            # Return 503 Service Unavailable response
+            response = JSONResponse(
+                content={
+                    "error": "The model is currently scaling. Please try again later."
+                },
+                status_code=503,
+            )
+            return response(scope, receive, send)
+
+        return self.app(scope, receive, send)
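ScalingMiddleware is a plain ASGI middleware, and the import added to api_server.py earlier in this diff suggests build_app() keeps registering it there. A minimal standalone sketch of the same wiring, for illustration only:

# Sketch: return 503 for all HTTP traffic while an elastic-EP scale is in progress.
from fastapi import FastAPI

from vllm.entrypoints.serve.elastic_ep.middleware import (
    ScalingMiddleware,
    set_scaling_elastic_ep,
)

app = FastAPI()
app.add_middleware(ScalingMiddleware)

set_scaling_elastic_ep(True)   # requests now receive a 503 JSON error
set_scaling_elastic_ep(False)  # traffic is accepted again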
vllm/entrypoints/serve/instrumentator/__init__.py (new file, empty)
vllm/entrypoints/serve/instrumentator/health.py (new file)
@@ -0,0 +1,33 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+
+from fastapi import APIRouter, Request
+from fastapi.responses import Response
+
+from vllm.engine.protocol import EngineClient
+from vllm.logger import init_logger
+from vllm.v1.engine.exceptions import EngineDeadError
+
+logger = init_logger(__name__)
+
+
+router = APIRouter()
+
+
+def engine_client(request: Request) -> EngineClient:
+    return request.app.state.engine_client
+
+
+@router.get("/health", response_class=Response)
+async def health(raw_request: Request) -> Response:
+    """Health check."""
+    try:
+        await engine_client(raw_request).check_health()
+        return Response(status_code=200)
+    except EngineDeadError:
+        return Response(status_code=503)
+
+
+def attach_router(app):
+    app.include_router(router)
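The updated test in the first hunk of this diff imports health from this new module and exercises it in isolation. A hedged sketch of that pattern; the mock construction here is an assumption about the test's surroundings, not a copy of the test file:

# Sketch of an isolation test for the relocated health handler.
from unittest.mock import AsyncMock, Mock

import pytest
from fastapi import Request

from vllm.entrypoints.serve.instrumentator.health import health
from vllm.v1.engine.exceptions import EngineDeadError


@pytest.mark.asyncio
async def test_health_returns_503_when_engine_dead():
    # Mock request whose engine client raises EngineDeadError from check_health().
    mock_request = Mock(spec=Request)
    mock_request.app.state.engine_client.check_health = AsyncMock(
        side_effect=EngineDeadError()
    )

    response = await health(mock_request)
    assert response.status_code == 503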
vllm/entrypoints/serve/instrumentator/metrics.py (new file)
@@ -0,0 +1,46 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+
+import re
+
+import prometheus_client
+from fastapi import FastAPI, Response
+from prometheus_client import make_asgi_app
+from prometheus_fastapi_instrumentator import Instrumentator
+from starlette.routing import Mount
+
+from vllm.v1.metrics.prometheus import get_prometheus_registry
+
+
+class PrometheusResponse(Response):
+    media_type = prometheus_client.CONTENT_TYPE_LATEST
+
+
+def attach_router(app: FastAPI):
+    """Mount prometheus metrics to a FastAPI app."""
+
+    registry = get_prometheus_registry()
+
+    # `response_class=PrometheusResponse` is needed to return an HTTP response
+    # with header "Content-Type: text/plain; version=0.0.4; charset=utf-8"
+    # instead of the default "application/json" which is incorrect.
+    # See https://github.com/trallnag/prometheus-fastapi-instrumentator/issues/163#issue-1296092364
+    Instrumentator(
+        excluded_handlers=[
+            "/metrics",
+            "/health",
+            "/load",
+            "/ping",
+            "/version",
+            "/server_info",
+        ],
+        registry=registry,
+    ).add().instrument(app).expose(app, response_class=PrometheusResponse)
+
+    # Add prometheus asgi middleware to route /metrics requests
+    metrics_route = Mount("/metrics", make_asgi_app(registry=registry))
+
+    # Workaround for 307 Redirect for /metrics
+    metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
+    app.routes.append(metrics_route)
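This module replaces the old module-level mount_metrics() removed from api_server.py earlier in this diff; exposing /metrics becomes a single attach call. A minimal hedged sketch:

# Sketch: mount Prometheus metrics on a standalone FastAPI app via the new helper.
from fastapi import FastAPI

from vllm.entrypoints.serve.instrumentator.metrics import attach_router

app = FastAPI()
attach_router(app)  # instruments the app and mounts the /metrics ASGI endpoint
# A GET on /metrics then returns text/plain Prometheus exposition data.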
vllm/entrypoints/serve/lora/__init__.py (new file, empty)
@@ -1,9 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+
 import model_hosting_container_standards.sagemaker as sagemaker_standards
-from fastapi import APIRouter, Depends, Request
+from fastapi import APIRouter, Depends, FastAPI, Request
 from fastapi.responses import JSONResponse, Response

+from vllm import envs
 from vllm.entrypoints.openai.api_server import models, validate_json_request
 from vllm.entrypoints.openai.protocol import (
     ErrorResponse,
@@ -14,9 +17,18 @@ from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.logger import init_logger

 logger = init_logger(__name__)
+router = APIRouter()


-def register_dynamic_lora_routes(router: APIRouter):
+def attach_router(app: FastAPI):
+    if not envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
+        """If LoRA dynamic loading & unloading is not enabled, do nothing."""
+        return
+    logger.warning(
+        "LoRA dynamic loading & unloading is enabled in the API server. "
+        "This should ONLY be used for local development!"
+    )
+
     @sagemaker_standards.register_load_adapter_handler(
         request_shape={
             "lora_name": "body.name",
@@ -54,4 +66,5 @@ def register_dynamic_lora_routes(router: APIRouter):

         return Response(status_code=200, content=response)

-    return router
+    # register the router
+    app.include_router(router)
vllm/entrypoints/serve/profile/__init__.py (new file, empty)
49
vllm/entrypoints/serve/profile/api_router.py
Normal file
49
vllm/entrypoints/serve/profile/api_router.py
Normal file
@ -0,0 +1,49 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


from fastapi import APIRouter, FastAPI, Request
from fastapi.responses import Response

import vllm.envs as envs
from vllm.engine.protocol import EngineClient
from vllm.logger import init_logger

logger = init_logger(__name__)

router = APIRouter()


def engine_client(request: Request) -> EngineClient:
    return request.app.state.engine_client


@router.post("/start_profile")
async def start_profile(raw_request: Request):
    logger.info("Starting profiler...")
    await engine_client(raw_request).start_profile()
    logger.info("Profiler started.")
    return Response(status_code=200)


@router.post("/stop_profile")
async def stop_profile(raw_request: Request):
    logger.info("Stopping profiler...")
    await engine_client(raw_request).stop_profile()
    logger.info("Profiler stopped.")
    return Response(status_code=200)


def attach_router(app: FastAPI):
    if envs.VLLM_TORCH_PROFILER_DIR:
        logger.warning_once(
            "Torch Profiler is enabled in the API server. This should ONLY be "
            "used for local development!"
        )
    elif envs.VLLM_TORCH_CUDA_PROFILE:
        logger.warning_once(
            "CUDA Profiler is enabled in the API server. This should ONLY be "
            "used for local development!"
        )
    if envs.VLLM_TORCH_PROFILER_DIR or envs.VLLM_TORCH_CUDA_PROFILE:
        app.include_router(router)
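For illustration only (not part of the diff), a client-side sketch of driving the profiler endpoints registered above; it assumes a server at localhost:8000 launched with VLLM_TORCH_PROFILER_DIR set so the router is actually attached:

import requests  # client-side dependency, not used by the server code above

BASE = "http://localhost:8000"  # placeholder address

requests.post(f"{BASE}/start_profile").raise_for_status()
# ... issue a few inference requests here so the trace captures real work ...
requests.post(f"{BASE}/stop_profile").raise_for_status()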
vllm/entrypoints/serve/rlhf/__init__.py (new, empty file)
vllm/entrypoints/serve/rlhf/api_router.py (new file, 102 lines)
@ -0,0 +1,102 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


from http import HTTPStatus

from fastapi import APIRouter, FastAPI, Query, Request
from fastapi.responses import JSONResponse

from vllm.engine.protocol import EngineClient
from vllm.logger import init_logger

logger = init_logger(__name__)


def engine_client(request: Request) -> EngineClient:
    return request.app.state.engine_client


router = APIRouter()


@router.post("/pause")
async def pause_generation(
    raw_request: Request,
    wait_for_inflight_requests: bool = Query(False),
    clear_cache: bool = Query(True),
) -> JSONResponse:
    """Pause generation requests to allow weight updates.

    Args:
        wait_for_inflight_requests: When ``True`` waits for in-flight
            requests to finish before pausing. When ``False`` (default),
            aborts any in-flight requests immediately.
        clear_cache: Whether to clear KV/prefix caches after draining.
    """

    engine = engine_client(raw_request)

    try:
        await engine.pause_generation(
            wait_for_inflight_requests=wait_for_inflight_requests,
            clear_cache=clear_cache,
        )
        return JSONResponse(
            content={"status": "paused"},
            status_code=HTTPStatus.OK.value,
        )

    except ValueError as err:
        return JSONResponse(
            content={"error": str(err)},
            status_code=HTTPStatus.BAD_REQUEST.value,
        )
    except Exception as err:  # pragma: no cover - defensive
        logger.exception("Failed to pause generation")
        return JSONResponse(
            content={"error": f"Failed to pause generation: {err}"},
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
        )


@router.post("/resume")
async def resume_generation(raw_request: Request) -> JSONResponse:
    """Resume generation after a pause."""

    engine = engine_client(raw_request)

    try:
        await engine.resume_generation()
        return JSONResponse(
            content={"status": "resumed"},
            status_code=HTTPStatus.OK.value,
        )
    except Exception as err:  # pragma: no cover - defensive
        logger.exception("Failed to resume generation")
        return JSONResponse(
            content={"error": f"Failed to resume generation: {err}"},
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
        )


@router.get("/is_paused")
async def is_paused(raw_request: Request) -> JSONResponse:
    """Return the current pause status."""

    engine = engine_client(raw_request)

    try:
        paused = await engine.is_paused()
    except Exception as err:  # pragma: no cover - defensive
        logger.exception("Failed to fetch pause status")
        return JSONResponse(
            content={"error": f"Failed to fetch pause status: {err}"},
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
        )

    return JSONResponse(content={"is_paused": paused})


def attach_router(app: FastAPI):
    app.include_router(router)
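As a hedged usage sketch (not part of the diff), the pause/resume endpoints above would typically bracket a weight update in an RLHF loop; the server URL and the weight-update step are assumptions:

import requests

BASE = "http://localhost:8000"  # placeholder address

# Drain in-flight requests and clear the KV/prefix caches before swapping weights.
requests.post(
    f"{BASE}/pause",
    params={"wait_for_inflight_requests": "true", "clear_cache": "true"},
).raise_for_status()
assert requests.get(f"{BASE}/is_paused").json()["is_paused"]

# ... push updated weights to the engine here (trainer-specific, not shown) ...

requests.post(f"{BASE}/resume").raise_for_status()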
vllm/entrypoints/serve/sleep/__init__.py (new, empty file)
vllm/entrypoints/serve/sleep/api_router.py (new file, 60 lines)
@ -0,0 +1,60 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


from fastapi import APIRouter, FastAPI, Request
from fastapi.responses import JSONResponse, Response

import vllm.envs as envs
from vllm.engine.protocol import EngineClient
from vllm.logger import init_logger

logger = init_logger(__name__)


def engine_client(request: Request) -> EngineClient:
    return request.app.state.engine_client


router = APIRouter()


@router.post("/sleep")
async def sleep(raw_request: Request):
    # get POST params
    level = raw_request.query_params.get("level", "1")
    await engine_client(raw_request).sleep(int(level))
    # FIXME: in v0 with frontend multiprocessing, the sleep command
    # is sent but does not finish yet when we return a response.
    return Response(status_code=200)


@router.post("/wake_up")
async def wake_up(raw_request: Request):
    tags = raw_request.query_params.getlist("tags")
    if tags == []:
        # set to None to wake up all tags if no tags are provided
        tags = None
    logger.info("wake up the engine with tags: %s", tags)
    await engine_client(raw_request).wake_up(tags)
    # FIXME: in v0 with frontend multiprocessing, the wake-up command
    # is sent but does not finish yet when we return a response.
    return Response(status_code=200)


@router.get("/is_sleeping")
async def is_sleeping(raw_request: Request):
    logger.info("check whether the engine is sleeping")
    is_sleeping = await engine_client(raw_request).is_sleeping()
    return JSONResponse(content={"is_sleeping": is_sleeping})


def attach_router(app: FastAPI):
    if not envs.VLLM_SERVER_DEV_MODE:
        return
    logger.warning(
        "SECURITY WARNING: Development endpoints are enabled! "
        "This should NOT be used in production!"
    )

    app.include_router(router)
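A small client sketch (not part of the diff) for the sleep endpoints; note that attach_router above only registers them when VLLM_SERVER_DEV_MODE is enabled, and the address below is a placeholder:

import requests

BASE = "http://localhost:8000"  # placeholder address

requests.post(f"{BASE}/sleep", params={"level": "1"}).raise_for_status()
print(requests.get(f"{BASE}/is_sleeping").json())  # expected: {"is_sleeping": true}
requests.post(f"{BASE}/wake_up").raise_for_status()  # omitting "tags" wakes everything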
vllm/entrypoints/serve/tokenize/__init__.py (new, empty file)
vllm/entrypoints/serve/tokenize/api_router.py (new file, 118 lines)
@ -0,0 +1,118 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


from http import HTTPStatus

from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from typing_extensions import assert_never

from vllm.entrypoints.openai.api_server import validate_json_request
from vllm.entrypoints.openai.protocol import (
    DetokenizeRequest,
    DetokenizeResponse,
    ErrorResponse,
    TokenizeRequest,
    TokenizeResponse,
)
from vllm.entrypoints.serve.tokenize.serving import OpenAIServingTokenization
from vllm.entrypoints.utils import (
    with_cancellation,
)
from vllm.logger import init_logger

logger = init_logger(__name__)


def tokenization(request: Request) -> OpenAIServingTokenization:
    return request.app.state.openai_serving_tokenization


router = APIRouter()


@router.post(
    "/tokenize",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
        HTTPStatus.NOT_IMPLEMENTED.value: {"model": ErrorResponse},
    },
)
@with_cancellation
async def tokenize(request: TokenizeRequest, raw_request: Request):
    handler = tokenization(raw_request)

    try:
        generator = await handler.create_tokenize(request, raw_request)
    except NotImplementedError as e:
        raise HTTPException(
            status_code=HTTPStatus.NOT_IMPLEMENTED.value, detail=str(e)
        ) from e
    except Exception as e:
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
        ) from e

    if isinstance(generator, ErrorResponse):
        return JSONResponse(
            content=generator.model_dump(), status_code=generator.error.code
        )
    elif isinstance(generator, TokenizeResponse):
        return JSONResponse(content=generator.model_dump())

    assert_never(generator)


@router.post(
    "/detokenize",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
async def detokenize(request: DetokenizeRequest, raw_request: Request):
    handler = tokenization(raw_request)

    try:
        generator = await handler.create_detokenize(request, raw_request)
    except OverflowError as e:
        raise RequestValidationError(errors=[str(e)]) from e
    except Exception as e:
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
        ) from e

    if isinstance(generator, ErrorResponse):
        return JSONResponse(
            content=generator.model_dump(), status_code=generator.error.code
        )
    elif isinstance(generator, DetokenizeResponse):
        return JSONResponse(content=generator.model_dump())

    assert_never(generator)


def attach_router(app: FastAPI):
    if getattr(app.state.args, "enable_tokenizer_info_endpoint", False):
        """Conditionally register the tokenizer info endpoint if enabled."""

        @router.get("/tokenizer_info")
        async def get_tokenizer_info(raw_request: Request):
            """Get comprehensive tokenizer information."""
            result = await tokenization(raw_request).get_tokenizer_info()
            return JSONResponse(
                content=result.model_dump(),
                status_code=result.error.code
                if isinstance(result, ErrorResponse)
                else 200,
            )

    app.include_router(router)
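Finally, a hedged client sketch for the relocated /tokenize and /detokenize routes; the server address, model name, and the response field names are assumptions based on the existing OpenAI-compatible protocol, not defined in this diff:

import requests

BASE = "http://localhost:8000"  # placeholder address
MODEL = "my-model"  # placeholder served model name

tok = requests.post(
    f"{BASE}/tokenize", json={"model": MODEL, "prompt": "Hello world"}
).json()
detok = requests.post(
    f"{BASE}/detokenize", json={"model": MODEL, "tokens": tok["tokens"]}
).json()
print(tok["count"], detok["prompt"])  # field names assumed from the protocol module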