Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 01:15:26 +08:00)

[Refactor] [1/N] to simplify the vLLM serving architecture (#28040)
Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>

This commit is contained in: parent 69520bc695, commit 3f42b05fbc
@@ -232,7 +232,7 @@ async def test_server_load(server: RemoteOpenAIServer):
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check_engine_dead_error():
|
||||
# Import the health function directly to test it in isolation
|
||||
from vllm.entrypoints.openai.api_server import health
|
||||
from vllm.entrypoints.serve.instrumentator.health import health
|
||||
|
||||
# Create a mock request that simulates what FastAPI would provide
|
||||
mock_request = Mock(spec=Request)
|
||||
|
||||
@@ -118,6 +118,7 @@ async def init_app(
|
||||
)
|
||||
)
|
||||
app.state.engine_client = engine
|
||||
app.state.args = args
|
||||
return app
|
||||
|
||||
|
||||
|
||||
@@ -20,21 +20,15 @@ from http import HTTPStatus
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
import model_hosting_container_standards.sagemaker as sagemaker_standards
|
||||
import prometheus_client
|
||||
import pydantic
|
||||
import regex as re
|
||||
import uvloop
|
||||
from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Query, Request
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, Response, StreamingResponse
|
||||
from prometheus_client import make_asgi_app
|
||||
from prometheus_fastapi_instrumentator import Instrumentator
|
||||
from starlette.concurrency import iterate_in_threadpool
|
||||
from starlette.datastructures import URL, Headers, MutableHeaders, State
|
||||
from starlette.routing import Mount
|
||||
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||
from typing_extensions import assert_never
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import VllmConfig
|
||||
@@ -56,17 +50,11 @@ from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionResponse,
|
||||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
DetokenizeRequest,
|
||||
DetokenizeResponse,
|
||||
ErrorInfo,
|
||||
ErrorResponse,
|
||||
GenerateRequest,
|
||||
GenerateResponse,
|
||||
ResponsesRequest,
|
||||
ResponsesResponse,
|
||||
StreamingResponsesResponse,
|
||||
TokenizeRequest,
|
||||
TokenizeResponse,
|
||||
TranscriptionRequest,
|
||||
TranscriptionResponseVariant,
|
||||
TranslationRequest,
|
||||
@@ -80,8 +68,6 @@ from vllm.entrypoints.openai.serving_models import (
|
||||
OpenAIServingModels,
|
||||
)
|
||||
from vllm.entrypoints.openai.serving_responses import OpenAIServingResponses
|
||||
from vllm.entrypoints.openai.serving_tokenization import OpenAIServingTokenization
|
||||
from vllm.entrypoints.openai.serving_tokens import ServingTokens
|
||||
from vllm.entrypoints.openai.serving_transcription import (
|
||||
OpenAIServingTranscription,
|
||||
OpenAIServingTranslation,
|
||||
@@ -92,6 +78,11 @@ from vllm.entrypoints.pooling.classify.serving import ServingClassification
|
||||
from vllm.entrypoints.pooling.embed.serving import OpenAIServingEmbedding
|
||||
from vllm.entrypoints.pooling.pooling.serving import OpenAIServingPooling
|
||||
from vllm.entrypoints.pooling.score.serving import ServingScores
|
||||
from vllm.entrypoints.serve.disagg.serving import ServingTokens
|
||||
from vllm.entrypoints.serve.elastic_ep.middleware import (
|
||||
ScalingMiddleware,
|
||||
)
|
||||
from vllm.entrypoints.serve.tokenize.serving import OpenAIServingTokenization
|
||||
from vllm.entrypoints.tool_server import DemoToolServer, MCPToolServer, ToolServer
|
||||
from vllm.entrypoints.utils import (
|
||||
cli_env_setup,
|
||||
@@ -109,8 +100,6 @@ from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
from vllm.utils.gc_utils import freeze_gc_heap
|
||||
from vllm.utils.network_utils import is_valid_ipv6_address
|
||||
from vllm.utils.system_utils import decorate_logs, set_ulimit
|
||||
from vllm.v1.engine.exceptions import EngineDeadError
|
||||
from vllm.v1.metrics.prometheus import get_prometheus_registry
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
|
||||
prometheus_multiproc_dir: tempfile.TemporaryDirectory
|
||||
@@ -245,39 +234,6 @@ async def build_async_engine_client_from_engine_args(
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class PrometheusResponse(Response):
|
||||
media_type = prometheus_client.CONTENT_TYPE_LATEST
|
||||
|
||||
|
||||
def mount_metrics(app: FastAPI):
|
||||
"""Mount prometheus metrics to a FastAPI app."""
|
||||
|
||||
registry = get_prometheus_registry()
|
||||
|
||||
# `response_class=PrometheusResponse` is needed to return an HTTP response
|
||||
# with header "Content-Type: text/plain; version=0.0.4; charset=utf-8"
|
||||
# instead of the default "application/json" which is incorrect.
|
||||
# See https://github.com/trallnag/prometheus-fastapi-instrumentator/issues/163#issue-1296092364
|
||||
Instrumentator(
|
||||
excluded_handlers=[
|
||||
"/metrics",
|
||||
"/health",
|
||||
"/load",
|
||||
"/ping",
|
||||
"/version",
|
||||
"/server_info",
|
||||
],
|
||||
registry=registry,
|
||||
).add().instrument(app).expose(app, response_class=PrometheusResponse)
|
||||
|
||||
# Add prometheus asgi middleware to route /metrics requests
|
||||
metrics_route = Mount("/metrics", make_asgi_app(registry=registry))
|
||||
|
||||
# Workaround for 307 Redirect for /metrics
|
||||
metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
|
||||
app.routes.append(metrics_route)
|
||||
|
||||
|
||||
def base(request: Request) -> OpenAIServing:
|
||||
# Reuse the existing instance
|
||||
return tokenization(request)
|
||||
@@ -323,16 +279,6 @@ def generate_tokens(request: Request) -> ServingTokens | None:
|
||||
return request.app.state.serving_tokens
|
||||
|
||||
|
||||
@router.get("/health", response_class=Response)
|
||||
async def health(raw_request: Request) -> Response:
|
||||
"""Health check."""
|
||||
try:
|
||||
await engine_client(raw_request).check_health()
|
||||
return Response(status_code=200)
|
||||
except EngineDeadError:
|
||||
return Response(status_code=503)
|
||||
|
||||
|
||||
@router.get("/load")
|
||||
async def get_server_load_metrics(request: Request):
|
||||
# This endpoint returns the current server load metrics.
|
||||
@@ -352,167 +298,6 @@ async def get_server_load_metrics(request: Request):
|
||||
return JSONResponse(content={"server_load": request.app.state.server_load_metrics})
|
||||
|
||||
|
||||
@router.post("/pause")
|
||||
async def pause_generation(
|
||||
raw_request: Request,
|
||||
wait_for_inflight_requests: bool = Query(False),
|
||||
clear_cache: bool = Query(True),
|
||||
) -> JSONResponse:
|
||||
"""Pause generation requests to allow weight updates.
|
||||
|
||||
Args:
|
||||
wait_for_inflight_requests: When ``True`` waits for in-flight
|
||||
requests to finish before pausing. When ``False`` (default),
|
||||
aborts any in-flight requests immediately.
|
||||
clear_cache: Whether to clear KV/prefix caches after draining.
|
||||
"""
|
||||
|
||||
engine = engine_client(raw_request)
|
||||
|
||||
try:
|
||||
await engine.pause_generation(
|
||||
wait_for_inflight_requests=wait_for_inflight_requests,
|
||||
clear_cache=clear_cache,
|
||||
)
|
||||
return JSONResponse(
|
||||
content={"status": "paused"},
|
||||
status_code=HTTPStatus.OK.value,
|
||||
)
|
||||
|
||||
except ValueError as err:
|
||||
return JSONResponse(
|
||||
content={"error": str(err)},
|
||||
status_code=HTTPStatus.BAD_REQUEST.value,
|
||||
)
|
||||
except Exception as err: # pragma: no cover - defensive
|
||||
logger.exception("Failed to pause generation")
|
||||
return JSONResponse(
|
||||
content={"error": f"Failed to pause generation: {err}"},
|
||||
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/resume")
|
||||
async def resume_generation(raw_request: Request) -> JSONResponse:
|
||||
"""Resume generation after a pause."""
|
||||
|
||||
engine = engine_client(raw_request)
|
||||
|
||||
try:
|
||||
await engine.resume_generation()
|
||||
return JSONResponse(
|
||||
content={"status": "resumed"},
|
||||
status_code=HTTPStatus.OK.value,
|
||||
)
|
||||
except Exception as err: # pragma: no cover - defensive
|
||||
logger.exception("Failed to resume generation")
|
||||
return JSONResponse(
|
||||
content={"error": f"Failed to resume generation: {err}"},
|
||||
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/is_paused")
|
||||
async def is_paused(raw_request: Request) -> JSONResponse:
|
||||
"""Return the current pause status."""
|
||||
|
||||
engine = engine_client(raw_request)
|
||||
|
||||
try:
|
||||
paused = await engine.is_paused()
|
||||
except Exception as err: # pragma: no cover - defensive
|
||||
logger.exception("Failed to fetch pause status")
|
||||
return JSONResponse(
|
||||
content={"error": f"Failed to fetch pause status: {err}"},
|
||||
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
|
||||
)
|
||||
|
||||
return JSONResponse(content={"is_paused": paused})
|
||||
|
||||
|
||||
@router.post(
|
||||
"/tokenize",
|
||||
dependencies=[Depends(validate_json_request)],
|
||||
responses={
|
||||
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
|
||||
HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
|
||||
HTTPStatus.NOT_IMPLEMENTED.value: {"model": ErrorResponse},
|
||||
},
|
||||
)
|
||||
@with_cancellation
|
||||
async def tokenize(request: TokenizeRequest, raw_request: Request):
|
||||
handler = tokenization(raw_request)
|
||||
|
||||
try:
|
||||
generator = await handler.create_tokenize(request, raw_request)
|
||||
except NotImplementedError as e:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.NOT_IMPLEMENTED.value, detail=str(e)
|
||||
) from e
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
|
||||
) from e
|
||||
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(
|
||||
content=generator.model_dump(), status_code=generator.error.code
|
||||
)
|
||||
elif isinstance(generator, TokenizeResponse):
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
|
||||
assert_never(generator)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/detokenize",
|
||||
dependencies=[Depends(validate_json_request)],
|
||||
responses={
|
||||
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
|
||||
HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
|
||||
},
|
||||
)
|
||||
@with_cancellation
|
||||
async def detokenize(request: DetokenizeRequest, raw_request: Request):
|
||||
handler = tokenization(raw_request)
|
||||
|
||||
try:
|
||||
generator = await handler.create_detokenize(request, raw_request)
|
||||
except OverflowError as e:
|
||||
raise RequestValidationError(errors=[str(e)]) from e
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
|
||||
) from e
|
||||
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(
|
||||
content=generator.model_dump(), status_code=generator.error.code
|
||||
)
|
||||
elif isinstance(generator, DetokenizeResponse):
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
|
||||
assert_never(generator)
|
||||
|
||||
|
||||
def maybe_register_tokenizer_info_endpoint(args):
|
||||
"""Conditionally register the tokenizer info endpoint if enabled."""
|
||||
if getattr(args, "enable_tokenizer_info_endpoint", False):
|
||||
|
||||
@router.get("/tokenizer_info")
|
||||
async def get_tokenizer_info(raw_request: Request):
|
||||
"""Get comprehensive tokenizer information."""
|
||||
result = await tokenization(raw_request).get_tokenizer_info()
|
||||
return JSONResponse(
|
||||
content=result.model_dump(),
|
||||
status_code=result.error.code
|
||||
if isinstance(result, ErrorResponse)
|
||||
else 200,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/v1/models")
|
||||
async def show_available_models(raw_request: Request):
|
||||
handler = models(raw_request)
|
||||
@@ -898,33 +683,6 @@ if envs.VLLM_SERVER_DEV_MODE:
|
||||
await engine_client(raw_request).reset_mm_cache()
|
||||
return Response(status_code=200)
|
||||
|
||||
@router.post("/sleep")
|
||||
async def sleep(raw_request: Request):
|
||||
# get POST params
|
||||
level = raw_request.query_params.get("level", "1")
|
||||
await engine_client(raw_request).sleep(int(level))
|
||||
# FIXME: in v0 with frontend multiprocessing, the sleep command
|
||||
# is sent but does not finish yet when we return a response.
|
||||
return Response(status_code=200)
|
||||
|
||||
@router.post("/wake_up")
|
||||
async def wake_up(raw_request: Request):
|
||||
tags = raw_request.query_params.getlist("tags")
|
||||
if tags == []:
|
||||
# set to None to wake up all tags if no tags are provided
|
||||
tags = None
|
||||
logger.info("wake up the engine with tags: %s", tags)
|
||||
await engine_client(raw_request).wake_up(tags)
|
||||
# FIXME: in v0 with frontend multiprocessing, the wake-up command
|
||||
# is sent but does not finish yet when we return a response.
|
||||
return Response(status_code=200)
|
||||
|
||||
@router.get("/is_sleeping")
|
||||
async def is_sleeping(raw_request: Request):
|
||||
logger.info("check whether the engine is sleeping")
|
||||
is_sleeping = await engine_client(raw_request).is_sleeping()
|
||||
return JSONResponse(content={"is_sleeping": is_sleeping})
|
||||
|
||||
@router.post("/collective_rpc")
|
||||
async def collective_rpc(raw_request: Request):
|
||||
try:
|
||||
@@ -952,138 +710,13 @@ if envs.VLLM_SERVER_DEV_MODE:
|
||||
return Response(status_code=200)
|
||||
response: list[Any] = []
|
||||
for result in results:
|
||||
if result is None or isinstance(result, (dict, list)):
|
||||
if result is None or isinstance(result, dict | list):
|
||||
response.append(result)
|
||||
else:
|
||||
response.append(str(result))
|
||||
return JSONResponse(content={"results": response})
|
||||
|
||||
|
||||
@router.post(
|
||||
"/scale_elastic_ep",
|
||||
dependencies=[Depends(validate_json_request)],
|
||||
responses={
|
||||
HTTPStatus.OK.value: {"model": dict},
|
||||
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
|
||||
HTTPStatus.REQUEST_TIMEOUT.value: {"model": ErrorResponse},
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
|
||||
},
|
||||
)
|
||||
async def scale_elastic_ep(raw_request: Request):
|
||||
try:
|
||||
body = await raw_request.json()
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(status_code=400, detail="Invalid JSON format") from e # noqa: B904
|
||||
|
||||
new_data_parallel_size = body.get("new_data_parallel_size")
|
||||
drain_timeout = body.get("drain_timeout", 120) # Default 2 minutes
|
||||
|
||||
if new_data_parallel_size is None:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="new_data_parallel_size is required"
|
||||
)
|
||||
|
||||
if not isinstance(new_data_parallel_size, int) or new_data_parallel_size <= 0:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="new_data_parallel_size must be a positive integer"
|
||||
)
|
||||
|
||||
if not isinstance(drain_timeout, int) or drain_timeout <= 0:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="drain_timeout must be a positive integer"
|
||||
)
|
||||
|
||||
# Set scaling flag to prevent new requests
|
||||
global _scaling_elastic_ep
|
||||
_scaling_elastic_ep = True
|
||||
client = engine_client(raw_request)
|
||||
try:
|
||||
await client.scale_elastic_ep(new_data_parallel_size, drain_timeout)
|
||||
return JSONResponse(
|
||||
{
|
||||
"message": f"Scaled to {new_data_parallel_size} data parallel engines",
|
||||
}
|
||||
)
|
||||
except TimeoutError as e:
|
||||
raise HTTPException(
|
||||
status_code=408,
|
||||
detail="Scale failed due to request drain timeout "
|
||||
f"after {drain_timeout} seconds",
|
||||
) from e
|
||||
except Exception as e:
|
||||
logger.error("Scale failed: %s", e)
|
||||
raise HTTPException(status_code=500, detail="Scale failed") from e
|
||||
finally:
|
||||
_scaling_elastic_ep = False
|
||||
|
||||
|
||||
@router.post("/is_scaling_elastic_ep")
|
||||
async def is_scaling_elastic_ep(raw_request: Request):
|
||||
return JSONResponse({"is_scaling_elastic_ep": _scaling_elastic_ep})
|
||||
|
||||
|
||||
@router.post(
|
||||
"/inference/v1/generate",
|
||||
dependencies=[Depends(validate_json_request)],
|
||||
responses={
|
||||
HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
|
||||
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
|
||||
HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
|
||||
},
|
||||
)
|
||||
@with_cancellation
|
||||
@load_aware_call
|
||||
async def generate(request: GenerateRequest, raw_request: Request):
|
||||
handler = generate_tokens(raw_request)
|
||||
if handler is None:
|
||||
return base(raw_request).create_error_response(
|
||||
message="The model does not support generate tokens API"
|
||||
)
|
||||
try:
|
||||
generator = await handler.serve_tokens(request, raw_request)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
|
||||
) from e
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(
|
||||
content=generator.model_dump(), status_code=generator.error.code
|
||||
)
|
||||
|
||||
elif isinstance(generator, GenerateResponse):
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
|
||||
return StreamingResponse(content=generator, media_type="text/event-stream")
|
||||
|
||||
|
||||
if envs.VLLM_TORCH_PROFILER_DIR:
|
||||
logger.warning_once(
|
||||
"Torch Profiler is enabled in the API server. This should ONLY be "
|
||||
"used for local development!"
|
||||
)
|
||||
elif envs.VLLM_TORCH_CUDA_PROFILE:
|
||||
logger.warning_once(
|
||||
"CUDA Profiler is enabled in the API server. This should ONLY be "
|
||||
"used for local development!"
|
||||
)
|
||||
if envs.VLLM_TORCH_PROFILER_DIR or envs.VLLM_TORCH_CUDA_PROFILE:
|
||||
|
||||
@router.post("/start_profile")
|
||||
async def start_profile(raw_request: Request):
|
||||
logger.info("Starting profiler...")
|
||||
await engine_client(raw_request).start_profile()
|
||||
logger.info("Profiler started.")
|
||||
return Response(status_code=200)
|
||||
|
||||
@router.post("/stop_profile")
|
||||
async def stop_profile(raw_request: Request):
|
||||
logger.info("Stopping profiler...")
|
||||
await engine_client(raw_request).stop_profile()
|
||||
logger.info("Profiler stopped.")
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
def load_log_config(log_config_file: str | None) -> dict | None:
|
||||
if not log_config_file:
|
||||
return None
|
||||
@@ -1176,41 +809,6 @@ class XRequestIdMiddleware:
|
||||
return self.app(scope, receive, send_with_request_id)
|
||||
|
||||
|
||||
# Global variable to track scaling state
|
||||
_scaling_elastic_ep = False
|
||||
|
||||
|
||||
class ScalingMiddleware:
|
||||
"""
|
||||
Middleware that checks if the model is currently scaling and
|
||||
returns a 503 Service Unavailable response if it is.
|
||||
|
||||
This middleware applies to all HTTP requests and prevents
|
||||
processing when the model is in a scaling state.
|
||||
"""
|
||||
|
||||
def __init__(self, app: ASGIApp) -> None:
|
||||
self.app = app
|
||||
|
||||
def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]:
|
||||
if scope["type"] != "http":
|
||||
return self.app(scope, receive, send)
|
||||
|
||||
# Check global scaling state
|
||||
global _scaling_elastic_ep
|
||||
if _scaling_elastic_ep:
|
||||
# Return 503 Service Unavailable response
|
||||
response = JSONResponse(
|
||||
content={
|
||||
"error": "The model is currently scaling. Please try again later."
|
||||
},
|
||||
status_code=503,
|
||||
)
|
||||
return response(scope, receive, send)
|
||||
|
||||
return self.app(scope, receive, send)
|
||||
|
||||
|
||||
def _extract_content_from_chunk(chunk_data: dict) -> str:
|
||||
"""Extract content from a streaming response chunk."""
|
||||
try:
|
||||
@@ -1353,15 +951,10 @@ def build_app(args: Namespace) -> FastAPI:
|
||||
)
|
||||
else:
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
app.state.args = args
|
||||
from vllm.entrypoints.serve import register_vllm_serve_api_routers
|
||||
|
||||
if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
|
||||
logger.warning(
|
||||
"LoRA dynamic loading & unloading is enabled in the API server. "
|
||||
"This should ONLY be used for local development!"
|
||||
)
|
||||
from vllm.entrypoints.dynamic_lora import register_dynamic_lora_routes
|
||||
|
||||
register_dynamic_lora_routes(router)
|
||||
register_vllm_serve_api_routers(app)
|
||||
|
||||
from vllm.entrypoints.sagemaker.routes import register_sagemaker_routes
|
||||
|
||||
@@ -1370,8 +963,6 @@ def build_app(args: Namespace) -> FastAPI:
|
||||
|
||||
app.root_path = args.root_path
|
||||
|
||||
mount_metrics(app)
|
||||
|
||||
from vllm.entrypoints.pooling import register_pooling_api_routers
|
||||
|
||||
register_pooling_api_routers(app)
|
||||
@@ -1462,31 +1053,6 @@ def build_app(args: Namespace) -> FastAPI:
|
||||
)
|
||||
|
||||
app = sagemaker_standards.bootstrap(app)
|
||||
# Optional endpoints
|
||||
if args.tokens_only:
|
||||
|
||||
@app.post("/abort_requests")
|
||||
async def abort_requests(raw_request: Request):
|
||||
"""
|
||||
Abort one or more requests. To be used in a
|
||||
Disaggregated Everything setup.
|
||||
"""
|
||||
try:
|
||||
body = await raw_request.json()
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.BAD_REQUEST.value,
|
||||
detail=f"JSON decode error: {e}",
|
||||
) from e
|
||||
request_ids = body.get("request_ids")
|
||||
if request_ids is None:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.BAD_REQUEST.value,
|
||||
detail="Missing 'request_ids' in request body",
|
||||
)
|
||||
# Abort requests in background
|
||||
asyncio.create_task(engine_client(raw_request).abort(request_ids))
|
||||
return Response(status_code=200)
|
||||
|
||||
return app
|
||||
|
||||
@@ -1515,7 +1081,7 @@ async def init_app_state(
|
||||
state.engine_client = engine_client
|
||||
state.log_stats = not args.disable_log_stats
|
||||
state.vllm_config = vllm_config
|
||||
|
||||
state.args = args
|
||||
supported_tasks = await engine_client.get_supported_tasks()
|
||||
logger.info("Supported tasks: %s", supported_tasks)
|
||||
|
||||
@@ -1839,7 +1405,6 @@ async def run_server_worker(
|
||||
args,
|
||||
client_config=client_config,
|
||||
) as engine_client:
|
||||
maybe_register_tokenizer_info_endpoint(args)
|
||||
app = build_app(args)
|
||||
|
||||
await init_app_state(engine_client, app.state, args)
|
||||
|
||||
@@ -74,8 +74,6 @@ from vllm.entrypoints.openai.protocol import (
|
||||
ErrorResponse,
|
||||
FunctionCall,
|
||||
FunctionDefinition,
|
||||
GenerateRequest,
|
||||
GenerateResponse,
|
||||
ResponsesRequest,
|
||||
TokenizeChatRequest,
|
||||
TokenizeCompletionRequest,
|
||||
@@ -87,6 +85,7 @@ from vllm.entrypoints.openai.protocol import (
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
|
||||
from vllm.entrypoints.renderer import BaseRenderer, CompletionRenderer, RenderConfig
|
||||
from vllm.entrypoints.serve.disagg.protocol import GenerateRequest, GenerateResponse
|
||||
from vllm.entrypoints.utils import _validate_truncation_size
|
||||
from vllm.inputs.data import PromptType
|
||||
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
|
||||
|
||||
@@ -16,7 +16,6 @@ from vllm.entrypoints.openai.api_server import (
|
||||
completion,
|
||||
create_chat_completion,
|
||||
create_completion,
|
||||
health,
|
||||
validate_json_request,
|
||||
)
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
@@ -38,6 +37,7 @@ from vllm.entrypoints.pooling.score.api_router import (
|
||||
score,
|
||||
)
|
||||
from vllm.entrypoints.pooling.score.protocol import RerankRequest, ScoreRequest
|
||||
from vllm.entrypoints.serve.instrumentator.health import health
|
||||
|
||||
# TODO: RequestType = TypeForm[BaseModel] when recognized by type checkers
|
||||
# (requires typing_extensions >= 4.13)
|
||||
|
||||
60  vllm/entrypoints/serve/__init__.py  Normal file
@@ -0,0 +1,60 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
||||
|
||||
def register_vllm_serve_api_routers(app: FastAPI):
|
||||
from vllm.entrypoints.serve.lora.api_router import (
|
||||
attach_router as attach_lora_router,
|
||||
)
|
||||
|
||||
attach_lora_router(app)
|
||||
from vllm.entrypoints.serve.elastic_ep.api_router import (
|
||||
attach_router as attach_elastic_ep_router,
|
||||
)
|
||||
|
||||
attach_elastic_ep_router(app)
|
||||
|
||||
from vllm.entrypoints.serve.profile.api_router import (
|
||||
attach_router as attach_profile_router,
|
||||
)
|
||||
|
||||
attach_profile_router(app)
|
||||
|
||||
from vllm.entrypoints.serve.sleep.api_router import (
|
||||
attach_router as attach_sleep_router,
|
||||
)
|
||||
|
||||
attach_sleep_router(app)
|
||||
|
||||
from vllm.entrypoints.serve.tokenize.api_router import (
|
||||
attach_router as attach_tokenize_router,
|
||||
)
|
||||
|
||||
attach_tokenize_router(app)
|
||||
|
||||
from vllm.entrypoints.serve.disagg.api_router import (
|
||||
attach_router as attach_disagg_router,
|
||||
)
|
||||
|
||||
attach_disagg_router(app)
|
||||
|
||||
from vllm.entrypoints.serve.rlhf.api_router import (
|
||||
attach_router as attach_rlhf_router,
|
||||
)
|
||||
|
||||
attach_rlhf_router(app)
|
||||
|
||||
from vllm.entrypoints.serve.instrumentator.metrics import (
|
||||
attach_router as attach_metrics_router,
|
||||
)
|
||||
|
||||
attach_metrics_router(app)
|
||||
|
||||
from vllm.entrypoints.serve.instrumentator.health import (
|
||||
attach_router as attach_health_router,
|
||||
)
|
||||
|
||||
attach_health_router(app)
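Every sub-module wired up above exposes the same attach_router(app) hook. A minimal sketch of that convention, where the module and route names are illustrative and not part of this commit:

from fastapi import APIRouter, FastAPI, Request
from fastapi.responses import JSONResponse

router = APIRouter()


@router.get("/example_status")
async def example_status(raw_request: Request) -> JSONResponse:
    # Handlers pull shared objects off app.state, like the other routers do.
    return JSONResponse(content={"ok": True})


def attach_router(app: FastAPI):
    # A real sub-module would be imported lazily by
    # register_vllm_serve_api_routers() and attached like this.
    app.include_router(router)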
|
||||
0  vllm/entrypoints/serve/disagg/__init__.py  Normal file
110  vllm/entrypoints/serve/disagg/api_router.py  Normal file
@@ -0,0 +1,110 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from http import HTTPStatus
|
||||
|
||||
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, Response
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.openai.api_server import validate_json_request
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ErrorResponse,
|
||||
)
|
||||
from vllm.entrypoints.serve.disagg.protocol import (
|
||||
GenerateRequest,
|
||||
GenerateResponse,
|
||||
)
|
||||
from vllm.entrypoints.serve.disagg.serving import (
|
||||
ServingTokens,
|
||||
)
|
||||
from vllm.entrypoints.serve.tokenize.serving import OpenAIServingTokenization
|
||||
from vllm.entrypoints.utils import (
|
||||
load_aware_call,
|
||||
with_cancellation,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def tokenization(request: Request) -> OpenAIServingTokenization:
|
||||
return request.app.state.openai_serving_tokenization
|
||||
|
||||
|
||||
def generate_tokens(request: Request) -> ServingTokens | None:
|
||||
return request.app.state.serving_tokens
|
||||
|
||||
|
||||
def engine_client(request: Request) -> EngineClient:
|
||||
return request.app.state.engine_client
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post(
|
||||
"/inference/v1/generate",
|
||||
dependencies=[Depends(validate_json_request)],
|
||||
responses={
|
||||
HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
|
||||
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
|
||||
HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
|
||||
},
|
||||
)
|
||||
@with_cancellation
|
||||
@load_aware_call
|
||||
async def generate(request: GenerateRequest, raw_request: Request):
|
||||
handler = generate_tokens(raw_request)
|
||||
if handler is None:
|
||||
return tokenization(raw_request).create_error_response(
|
||||
message="The model does not support generate tokens API"
|
||||
)
|
||||
try:
|
||||
generator = await handler.serve_tokens(request, raw_request)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
|
||||
) from e
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(
|
||||
content=generator.model_dump(), status_code=generator.error.code
|
||||
)
|
||||
|
||||
elif isinstance(generator, GenerateResponse):
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
|
||||
return StreamingResponse(content=generator, media_type="text/event-stream")
|
||||
|
||||
|
||||
def attach_router(app: FastAPI):
|
||||
if getattr(app.state.args, "tokens_only", False):
|
||||
|
||||
@router.post("/abort_requests")
|
||||
async def abort_requests(raw_request: Request):
|
||||
"""
|
||||
Abort one or more requests. To be used in a
|
||||
Disaggregated Everything setup.
|
||||
"""
|
||||
try:
|
||||
body = await raw_request.json()
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.BAD_REQUEST.value,
|
||||
detail=f"JSON decode error: {e}",
|
||||
) from e
|
||||
request_ids = body.get("request_ids")
|
||||
if request_ids is None:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.BAD_REQUEST.value,
|
||||
detail="Missing 'request_ids' in request body",
|
||||
)
|
||||
# Abort requests in background
|
||||
asyncio.create_task(engine_client(raw_request).abort(request_ids))
|
||||
return Response(status_code=200)
|
||||
|
||||
app.include_router(router)
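A hedged client-side sketch of exercising the tokens-in/tokens-out route above; the host/port and the exact keys accepted inside sampling_params are assumptions for illustration, not defined by this commit:

import requests

# The prompt is already tokenized by the caller; sampling_params keys are assumed.
payload = {
    "token_ids": [1, 2, 3, 4],
    "sampling_params": {"max_tokens": 16, "temperature": 0.0},
    "stream": False,
}
resp = requests.post("http://localhost:8000/inference/v1/generate", json=payload)
resp.raise_for_status()
# Non-streaming responses follow the GenerateResponse schema.
print(resp.json()["choices"][0]["token_ids"])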
|
||||
90  vllm/entrypoints/serve/disagg/protocol.py  Normal file
@@ -0,0 +1,90 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionLogProbs,
|
||||
Logprob,
|
||||
SamplingParams,
|
||||
StreamOptions,
|
||||
)
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
|
||||
####### Tokens IN <> Tokens OUT #######
|
||||
class GenerateRequest(BaseModel):
|
||||
request_id: str = Field(
|
||||
default_factory=lambda: f"{random_uuid()}",
|
||||
description=(
|
||||
"The request_id related to this request. If the caller does "
|
||||
"not set it, a random_uuid will be generated. This id is used "
|
||||
"through out the inference process and return in response."
|
||||
),
|
||||
)
|
||||
token_ids: list[int]
|
||||
"""The token ids to generate text from."""
|
||||
|
||||
# features: MultiModalFeatureSpec
|
||||
# TODO (NickLucche): implement once Renderer work is completed
|
||||
features: str | None = None
|
||||
"""The processed MM inputs for the model."""
|
||||
|
||||
sampling_params: SamplingParams
|
||||
"""The sampling parameters for the model."""
|
||||
|
||||
model: str | None = None
|
||||
|
||||
stream: bool | None = False
|
||||
stream_options: StreamOptions | None = None
|
||||
cache_salt: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"If specified, the prefix cache will be salted with the provided "
|
||||
"string to prevent an attacker to guess prompts in multi-user "
|
||||
"environments. The salt should be random, protected from "
|
||||
"access by 3rd parties, and long enough to be "
|
||||
"unpredictable (e.g., 43 characters base64-encoded, corresponding "
|
||||
"to 256 bit)."
|
||||
),
|
||||
)
|
||||
priority: int = Field(
|
||||
default=0,
|
||||
description=(
|
||||
"The priority of the request (lower means earlier handling; "
|
||||
"default: 0). Any priority other than 0 will raise an error "
|
||||
"if the served model does not use priority scheduling."
|
||||
),
|
||||
)
|
||||
kv_transfer_params: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description="KVTransfer parameters used for disaggregated serving.",
|
||||
)
|
||||
|
||||
|
||||
class GenerateResponseChoice(BaseModel):
|
||||
index: int
|
||||
logprobs: ChatCompletionLogProbs | None = None
|
||||
# per OpenAI spec this is the default
|
||||
finish_reason: str | None = "stop"
|
||||
token_ids: list[int] | None = None
|
||||
|
||||
|
||||
class GenerateResponse(BaseModel):
|
||||
request_id: str = Field(
|
||||
default_factory=lambda: f"{random_uuid()}",
|
||||
description=(
|
||||
"The request_id related to this request. If the caller does "
|
||||
"not set it, a random_uuid will be generated. This id is used "
|
||||
"through out the inference process and return in response."
|
||||
),
|
||||
)
|
||||
choices: list[GenerateResponseChoice]
|
||||
|
||||
prompt_logprobs: list[dict[int, Logprob] | None] | None = None
|
||||
|
||||
kv_transfer_params: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description="KVTransfer parameters used for disaggregated serving.",
|
||||
)
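For reference, a minimal sketch of constructing and serializing the response-side models defined above, assuming a vLLM install where this module is importable:

from vllm.entrypoints.serve.disagg.protocol import (
    GenerateResponse,
    GenerateResponseChoice,
)

choice = GenerateResponseChoice(index=0, token_ids=[11, 12, 13], finish_reason="stop")
# request_id is filled in by the random_uuid() default factory.
resp = GenerateResponse(choices=[choice])
print(resp.model_dump_json(indent=2))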
|
||||
@@ -1,5 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
@@ -14,15 +16,17 @@ from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionLogProbs,
|
||||
ChatCompletionLogProbsContent,
|
||||
ErrorResponse,
|
||||
GenerateRequest,
|
||||
GenerateResponse,
|
||||
GenerateResponseChoice,
|
||||
PromptTokenUsageInfo,
|
||||
RequestResponseMetadata,
|
||||
UsageInfo,
|
||||
)
|
||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing, clamp_prompt_logprobs
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.entrypoints.serve.disagg.protocol import (
|
||||
GenerateRequest,
|
||||
GenerateResponse,
|
||||
GenerateResponseChoice,
|
||||
)
|
||||
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logprobs import Logprob
|
||||
0  vllm/entrypoints/serve/elastic_ep/__init__.py  Normal file
96  vllm/entrypoints/serve/elastic_ep/api_router.py  Normal file
@@ -0,0 +1,96 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import json
|
||||
from http import HTTPStatus
|
||||
|
||||
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.openai.api_server import validate_json_request
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ErrorResponse,
|
||||
)
|
||||
from vllm.entrypoints.serve.elastic_ep.middleware import (
|
||||
get_scaling_elastic_ep,
|
||||
set_scaling_elastic_ep,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def engine_client(request: Request) -> EngineClient:
|
||||
return request.app.state.engine_client
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post(
|
||||
"/scale_elastic_ep",
|
||||
dependencies=[Depends(validate_json_request)],
|
||||
responses={
|
||||
HTTPStatus.OK.value: {"model": dict},
|
||||
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
|
||||
HTTPStatus.REQUEST_TIMEOUT.value: {"model": ErrorResponse},
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
|
||||
},
|
||||
)
|
||||
async def scale_elastic_ep(raw_request: Request):
|
||||
try:
|
||||
body = await raw_request.json()
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(status_code=400, detail="Invalid JSON format") from e # noqa: B904
|
||||
|
||||
new_data_parallel_size = body.get("new_data_parallel_size")
|
||||
drain_timeout = body.get("drain_timeout", 120) # Default 2 minutes
|
||||
|
||||
if new_data_parallel_size is None:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="new_data_parallel_size is required"
|
||||
)
|
||||
|
||||
if not isinstance(new_data_parallel_size, int) or new_data_parallel_size <= 0:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="new_data_parallel_size must be a positive integer",
|
||||
)
|
||||
|
||||
if not isinstance(drain_timeout, int) or drain_timeout <= 0:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="drain_timeout must be a positive integer"
|
||||
)
|
||||
|
||||
# Set scaling flag to prevent new requests
|
||||
set_scaling_elastic_ep(True)
|
||||
client = engine_client(raw_request)
|
||||
try:
|
||||
await client.scale_elastic_ep(new_data_parallel_size, drain_timeout)
|
||||
return JSONResponse(
|
||||
{
|
||||
"message": f"Scaled to {new_data_parallel_size} data parallel engines",
|
||||
}
|
||||
)
|
||||
except TimeoutError as e:
|
||||
raise HTTPException(
|
||||
status_code=408,
|
||||
detail="Scale failed due to request drain timeout "
|
||||
f"after {drain_timeout} seconds",
|
||||
) from e
|
||||
except Exception as e:
|
||||
logger.error("Scale failed: %s", e)
|
||||
raise HTTPException(status_code=500, detail="Scale failed") from e
|
||||
finally:
|
||||
set_scaling_elastic_ep(False)
|
||||
|
||||
|
||||
@router.post("/is_scaling_elastic_ep")
|
||||
async def is_scaling_elastic_ep(raw_request: Request):
|
||||
return JSONResponse({"is_scaling_elastic_ep": get_scaling_elastic_ep()})
|
||||
|
||||
|
||||
def attach_router(app: FastAPI):
|
||||
app.include_router(router)
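A hedged sketch of driving the scaling endpoints above from a client; the host/port are assumptions:

import requests

base = "http://localhost:8000"  # assumed server address
body = {"new_data_parallel_size": 4, "drain_timeout": 120}
resp = requests.post(f"{base}/scale_elastic_ep", json=body, timeout=300)
print(resp.status_code, resp.json())

# Poll the same flag that ScalingMiddleware consults while scaling is in flight.
print(requests.post(f"{base}/is_scaling_elastic_ep").json())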
|
||||
49  vllm/entrypoints/serve/elastic_ep/middleware.py  Normal file
@@ -0,0 +1,49 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Awaitable
|
||||
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
# Global variable to track scaling state
|
||||
_scaling_elastic_ep = False
|
||||
|
||||
|
||||
def get_scaling_elastic_ep():
|
||||
return _scaling_elastic_ep
|
||||
|
||||
|
||||
def set_scaling_elastic_ep(value):
|
||||
global _scaling_elastic_ep
|
||||
_scaling_elastic_ep = value
|
||||
|
||||
|
||||
class ScalingMiddleware:
|
||||
"""
|
||||
Middleware that checks if the model is currently scaling and
|
||||
returns a 503 Service Unavailable response if it is.
|
||||
|
||||
This middleware applies to all HTTP requests and prevents
|
||||
processing when the model is in a scaling state.
|
||||
"""
|
||||
|
||||
def __init__(self, app: ASGIApp) -> None:
|
||||
self.app = app
|
||||
|
||||
def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]:
|
||||
if scope["type"] != "http":
|
||||
return self.app(scope, receive, send)
|
||||
|
||||
# Check global scaling state
|
||||
if get_scaling_elastic_ep():
|
||||
# Return 503 Service Unavailable response
|
||||
response = JSONResponse(
|
||||
content={
|
||||
"error": "The model is currently scaling. Please try again later."
|
||||
},
|
||||
status_code=503,
|
||||
)
|
||||
return response(scope, receive, send)
|
||||
|
||||
return self.app(scope, receive, send)
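Since ScalingMiddleware is a plain ASGI middleware, wiring it into an app is a single call; a minimal sketch (the standalone FastAPI app here is illustrative):

from fastapi import FastAPI

from vllm.entrypoints.serve.elastic_ep.middleware import (
    ScalingMiddleware,
    set_scaling_elastic_ep,
)

app = FastAPI()
app.add_middleware(ScalingMiddleware)

# While a scale operation is in flight, every HTTP request is answered with 503.
set_scaling_elastic_ep(True)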
|
||||
0  vllm/entrypoints/serve/instrumentator/__init__.py  Normal file
33  vllm/entrypoints/serve/instrumentator/health.py  Normal file
@@ -0,0 +1,33 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
from fastapi import APIRouter, Request
|
||||
from fastapi.responses import Response
|
||||
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.engine.exceptions import EngineDeadError
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def engine_client(request: Request) -> EngineClient:
|
||||
return request.app.state.engine_client
|
||||
|
||||
|
||||
@router.get("/health", response_class=Response)
|
||||
async def health(raw_request: Request) -> Response:
|
||||
"""Health check."""
|
||||
try:
|
||||
await engine_client(raw_request).check_health()
|
||||
return Response(status_code=200)
|
||||
except EngineDeadError:
|
||||
return Response(status_code=503)
|
||||
|
||||
|
||||
def attach_router(app):
|
||||
app.include_router(router)
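The relocated health() handler can still be unit-tested in isolation, mirroring the test adjusted at the top of this diff; a hedged sketch:

from unittest.mock import AsyncMock, Mock

import pytest
from fastapi import Request

from vllm.entrypoints.serve.instrumentator.health import health
from vllm.v1.engine.exceptions import EngineDeadError


@pytest.mark.asyncio
async def test_health_returns_503_when_engine_dead():
    # Simulate the request object FastAPI would provide, with a dead engine.
    mock_request = Mock(spec=Request)
    mock_request.app.state.engine_client.check_health = AsyncMock(
        side_effect=EngineDeadError()
    )
    response = await health(mock_request)
    assert response.status_code == 503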
|
||||
46  vllm/entrypoints/serve/instrumentator/metrics.py  Normal file
@@ -0,0 +1,46 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import re
|
||||
|
||||
import prometheus_client
|
||||
from fastapi import FastAPI, Response
|
||||
from prometheus_client import make_asgi_app
|
||||
from prometheus_fastapi_instrumentator import Instrumentator
|
||||
from starlette.routing import Mount
|
||||
|
||||
from vllm.v1.metrics.prometheus import get_prometheus_registry
|
||||
|
||||
|
||||
class PrometheusResponse(Response):
|
||||
media_type = prometheus_client.CONTENT_TYPE_LATEST
|
||||
|
||||
|
||||
def attach_router(app: FastAPI):
|
||||
"""Mount prometheus metrics to a FastAPI app."""
|
||||
|
||||
registry = get_prometheus_registry()
|
||||
|
||||
# `response_class=PrometheusResponse` is needed to return an HTTP response
|
||||
# with header "Content-Type: text/plain; version=0.0.4; charset=utf-8"
|
||||
# instead of the default "application/json" which is incorrect.
|
||||
# See https://github.com/trallnag/prometheus-fastapi-instrumentator/issues/163#issue-1296092364
|
||||
Instrumentator(
|
||||
excluded_handlers=[
|
||||
"/metrics",
|
||||
"/health",
|
||||
"/load",
|
||||
"/ping",
|
||||
"/version",
|
||||
"/server_info",
|
||||
],
|
||||
registry=registry,
|
||||
).add().instrument(app).expose(app, response_class=PrometheusResponse)
|
||||
|
||||
# Add prometheus asgi middleware to route /metrics requests
|
||||
metrics_route = Mount("/metrics", make_asgi_app(registry=registry))
|
||||
|
||||
# Workaround for 307 Redirect for /metrics
|
||||
metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
|
||||
app.routes.append(metrics_route)
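A hedged sketch of scraping the mounted metrics endpoint; the host/port are assumptions:

import requests

resp = requests.get("http://localhost:8000/metrics")
assert resp.status_code == 200
# PrometheusResponse sets the text/plain exposition content type.
assert resp.headers["content-type"].startswith("text/plain")
print(resp.text[:200])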
|
||||
0  vllm/entrypoints/serve/lora/__init__.py  Normal file
@@ -1,9 +1,12 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import model_hosting_container_standards.sagemaker as sagemaker_standards
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from fastapi import APIRouter, Depends, FastAPI, Request
|
||||
from fastapi.responses import JSONResponse, Response
|
||||
|
||||
from vllm import envs
|
||||
from vllm.entrypoints.openai.api_server import models, validate_json_request
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ErrorResponse,
|
||||
@@ -14,9 +17,18 @@ from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def register_dynamic_lora_routes(router: APIRouter):
|
||||
def attach_router(app: FastAPI):
|
||||
if not envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
|
||||
"""If LoRA dynamic loading & unloading is not enabled, do nothing."""
|
||||
return
|
||||
logger.warning(
|
||||
"LoRA dynamic loading & unloading is enabled in the API server. "
|
||||
"This should ONLY be used for local development!"
|
||||
)
|
||||
|
||||
@sagemaker_standards.register_load_adapter_handler(
|
||||
request_shape={
|
||||
"lora_name": "body.name",
|
||||
@@ -54,4 +66,5 @@ def register_dynamic_lora_routes(router: APIRouter):
|
||||
|
||||
return Response(status_code=200, content=response)
|
||||
|
||||
return router
|
||||
# register the router
|
||||
app.include_router(router)
|
||||
0  vllm/entrypoints/serve/profile/__init__.py  Normal file
49  vllm/entrypoints/serve/profile/api_router.py  Normal file
@@ -0,0 +1,49 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
from fastapi import APIRouter, FastAPI, Request
|
||||
from fastapi.responses import Response
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def engine_client(request: Request) -> EngineClient:
|
||||
return request.app.state.engine_client
|
||||
|
||||
|
||||
@router.post("/start_profile")
|
||||
async def start_profile(raw_request: Request):
|
||||
logger.info("Starting profiler...")
|
||||
await engine_client(raw_request).start_profile()
|
||||
logger.info("Profiler started.")
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
@router.post("/stop_profile")
|
||||
async def stop_profile(raw_request: Request):
|
||||
logger.info("Stopping profiler...")
|
||||
await engine_client(raw_request).stop_profile()
|
||||
logger.info("Profiler stopped.")
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
def attach_router(app: FastAPI):
|
||||
if envs.VLLM_TORCH_PROFILER_DIR:
|
||||
logger.warning_once(
|
||||
"Torch Profiler is enabled in the API server. This should ONLY be "
|
||||
"used for local development!"
|
||||
)
|
||||
elif envs.VLLM_TORCH_CUDA_PROFILE:
|
||||
logger.warning_once(
|
||||
"CUDA Profiler is enabled in the API server. This should ONLY be "
|
||||
"used for local development!"
|
||||
)
|
||||
if envs.VLLM_TORCH_PROFILER_DIR or envs.VLLM_TORCH_CUDA_PROFILE:
|
||||
app.include_router(router)
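A hedged sketch of toggling the profiler from a client; these routes are only mounted when one of the profiler env vars is set, and the host/port are assumptions:

import requests

base = "http://localhost:8000"  # assumed server address
requests.post(f"{base}/start_profile").raise_for_status()
# ... issue the requests you want captured in the trace ...
requests.post(f"{base}/stop_profile").raise_for_status()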
|
||||
0  vllm/entrypoints/serve/rlhf/__init__.py  Normal file
102  vllm/entrypoints/serve/rlhf/api_router.py  Normal file
@@ -0,0 +1,102 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
from http import HTTPStatus
|
||||
|
||||
from fastapi import APIRouter, FastAPI, Query, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def engine_client(request: Request) -> EngineClient:
|
||||
return request.app.state.engine_client
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/pause")
|
||||
async def pause_generation(
|
||||
raw_request: Request,
|
||||
wait_for_inflight_requests: bool = Query(False),
|
||||
clear_cache: bool = Query(True),
|
||||
) -> JSONResponse:
|
||||
"""Pause generation requests to allow weight updates.
|
||||
|
||||
Args:
|
||||
wait_for_inflight_requests: When ``True`` waits for in-flight
|
||||
requests to finish before pausing. When ``False`` (default),
|
||||
aborts any in-flight requests immediately.
|
||||
clear_cache: Whether to clear KV/prefix caches after draining.
|
||||
"""
|
||||
|
||||
engine = engine_client(raw_request)
|
||||
|
||||
try:
|
||||
await engine.pause_generation(
|
||||
wait_for_inflight_requests=wait_for_inflight_requests,
|
||||
clear_cache=clear_cache,
|
||||
)
|
||||
return JSONResponse(
|
||||
content={"status": "paused"},
|
||||
status_code=HTTPStatus.OK.value,
|
||||
)
|
||||
|
||||
except ValueError as err:
|
||||
return JSONResponse(
|
||||
content={"error": str(err)},
|
||||
status_code=HTTPStatus.BAD_REQUEST.value,
|
||||
)
|
||||
except Exception as err: # pragma: no cover - defensive
|
||||
logger.exception("Failed to pause generation")
|
||||
return JSONResponse(
|
||||
content={"error": f"Failed to pause generation: {err}"},
|
||||
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/resume")
|
||||
async def resume_generation(raw_request: Request) -> JSONResponse:
|
||||
"""Resume generation after a pause."""
|
||||
|
||||
engine = engine_client(raw_request)
|
||||
|
||||
try:
|
||||
await engine.resume_generation()
|
||||
return JSONResponse(
|
||||
content={"status": "resumed"},
|
||||
status_code=HTTPStatus.OK.value,
|
||||
)
|
||||
except Exception as err: # pragma: no cover - defensive
|
||||
logger.exception("Failed to resume generation")
|
||||
return JSONResponse(
|
||||
content={"error": f"Failed to resume generation: {err}"},
|
||||
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/is_paused")
|
||||
async def is_paused(raw_request: Request) -> JSONResponse:
|
||||
"""Return the current pause status."""
|
||||
|
||||
engine = engine_client(raw_request)
|
||||
|
||||
try:
|
||||
paused = await engine.is_paused()
|
||||
except Exception as err: # pragma: no cover - defensive
|
||||
logger.exception("Failed to fetch pause status")
|
||||
return JSONResponse(
|
||||
content={"error": f"Failed to fetch pause status: {err}"},
|
||||
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
|
||||
)
|
||||
|
||||
return JSONResponse(content={"is_paused": paused})
|
||||
|
||||
|
||||
def attach_router(app: FastAPI):
|
||||
app.include_router(router)
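A hedged sketch of the pause/update/resume cycle these endpoints enable for RLHF-style weight updates; the host/port are assumptions:

import requests

base = "http://localhost:8000"  # assumed server address
requests.post(
    f"{base}/pause",
    params={"wait_for_inflight_requests": "true", "clear_cache": "true"},
).raise_for_status()
assert requests.get(f"{base}/is_paused").json()["is_paused"]

# ... push updated weights to the engine here ...

requests.post(f"{base}/resume").raise_for_status()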
|
||||
0  vllm/entrypoints/serve/sleep/__init__.py  Normal file
60  vllm/entrypoints/serve/sleep/api_router.py  Normal file
@@ -0,0 +1,60 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
from fastapi import APIRouter, FastAPI, Request
|
||||
from fastapi.responses import JSONResponse, Response
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def engine_client(request: Request) -> EngineClient:
|
||||
return request.app.state.engine_client
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/sleep")
|
||||
async def sleep(raw_request: Request):
|
||||
# get POST params
|
||||
level = raw_request.query_params.get("level", "1")
|
||||
await engine_client(raw_request).sleep(int(level))
|
||||
# FIXME: in v0 with frontend multiprocessing, the sleep command
|
||||
# is sent but does not finish yet when we return a response.
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
@router.post("/wake_up")
|
||||
async def wake_up(raw_request: Request):
|
||||
tags = raw_request.query_params.getlist("tags")
|
||||
if tags == []:
|
||||
# set to None to wake up all tags if no tags are provided
|
||||
tags = None
|
||||
logger.info("wake up the engine with tags: %s", tags)
|
||||
await engine_client(raw_request).wake_up(tags)
|
||||
# FIXME: in v0 with frontend multiprocessing, the wake-up command
|
||||
# is sent but does not finish yet when we return a response.
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
@router.get("/is_sleeping")
|
||||
async def is_sleeping(raw_request: Request):
|
||||
logger.info("check whether the engine is sleeping")
|
||||
is_sleeping = await engine_client(raw_request).is_sleeping()
|
||||
return JSONResponse(content={"is_sleeping": is_sleeping})
|
||||
|
||||
|
||||
def attach_router(app: FastAPI):
|
||||
if not envs.VLLM_SERVER_DEV_MODE:
|
||||
return
|
||||
logger.warning(
|
||||
"SECURITY WARNING: Development endpoints are enabled! "
|
||||
"This should NOT be used in production!"
|
||||
)
|
||||
|
||||
app.include_router(router)
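A hedged sketch of the sleep/wake-up cycle; these routes are only mounted in VLLM_SERVER_DEV_MODE, and the host/port are assumptions:

import requests

base = "http://localhost:8000"  # assumed server address
requests.post(f"{base}/sleep", params={"level": "1"}).raise_for_status()
print(requests.get(f"{base}/is_sleeping").json())  # {"is_sleeping": true}
requests.post(f"{base}/wake_up").raise_for_status()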
|
||||
0  vllm/entrypoints/serve/tokenize/__init__.py  Normal file
118  vllm/entrypoints/serve/tokenize/api_router.py  Normal file
@@ -0,0 +1,118 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
from http import HTTPStatus
|
||||
|
||||
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse
|
||||
from typing_extensions import assert_never
|
||||
|
||||
from vllm.entrypoints.openai.api_server import validate_json_request
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
DetokenizeRequest,
|
||||
DetokenizeResponse,
|
||||
ErrorResponse,
|
||||
TokenizeRequest,
|
||||
TokenizeResponse,
|
||||
)
|
||||
from vllm.entrypoints.serve.tokenize.serving import OpenAIServingTokenization
|
||||
from vllm.entrypoints.utils import (
|
||||
with_cancellation,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def tokenization(request: Request) -> OpenAIServingTokenization:
|
||||
return request.app.state.openai_serving_tokenization
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post(
|
||||
"/tokenize",
|
||||
dependencies=[Depends(validate_json_request)],
|
||||
responses={
|
||||
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
|
||||
HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
|
||||
HTTPStatus.NOT_IMPLEMENTED.value: {"model": ErrorResponse},
|
||||
},
|
||||
)
|
||||
@with_cancellation
|
||||
async def tokenize(request: TokenizeRequest, raw_request: Request):
|
||||
handler = tokenization(raw_request)
|
||||
|
||||
try:
|
||||
generator = await handler.create_tokenize(request, raw_request)
|
||||
except NotImplementedError as e:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.NOT_IMPLEMENTED.value, detail=str(e)
|
||||
) from e
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
|
||||
) from e
|
||||
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(
|
||||
content=generator.model_dump(), status_code=generator.error.code
|
||||
)
|
||||
elif isinstance(generator, TokenizeResponse):
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
|
||||
assert_never(generator)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/detokenize",
|
||||
dependencies=[Depends(validate_json_request)],
|
||||
responses={
|
||||
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
|
||||
HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
|
||||
},
|
||||
)
|
||||
@with_cancellation
|
||||
async def detokenize(request: DetokenizeRequest, raw_request: Request):
|
||||
handler = tokenization(raw_request)
|
||||
|
||||
try:
|
||||
generator = await handler.create_detokenize(request, raw_request)
|
||||
except OverflowError as e:
|
||||
raise RequestValidationError(errors=[str(e)]) from e
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
|
||||
) from e
|
||||
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(
|
||||
content=generator.model_dump(), status_code=generator.error.code
|
||||
)
|
||||
elif isinstance(generator, DetokenizeResponse):
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
|
||||
assert_never(generator)
|
||||
|
||||
|
||||
def attach_router(app: FastAPI):
|
||||
if getattr(app.state.args, "enable_tokenizer_info_endpoint", False):
|
||||
"""Conditionally register the tokenizer info endpoint if enabled."""
|
||||
|
||||
@router.get("/tokenizer_info")
|
||||
async def get_tokenizer_info(raw_request: Request):
|
||||
"""Get comprehensive tokenizer information."""
|
||||
result = await tokenization(raw_request).get_tokenizer_info()
|
||||
return JSONResponse(
|
||||
content=result.model_dump(),
|
||||
status_code=result.error.code
|
||||
if isinstance(result, ErrorResponse)
|
||||
else 200,
|
||||
)
|
||||
|
||||
app.include_router(router)
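A hedged sketch of a tokenize/detokenize round trip against the routes above; the host/port, model name, and response field names are assumptions based on the TokenizeRequest/DetokenizeRequest models:

import requests

base = "http://localhost:8000"  # assumed server address
tok = requests.post(
    f"{base}/tokenize", json={"model": "my-model", "prompt": "Hello world"}
).json()
print(tok["tokens"])  # assumed field name in TokenizeResponse

detok = requests.post(
    f"{base}/detokenize", json={"model": "my-model", "tokens": tok["tokens"]}
).json()
print(detok["prompt"])  # assumed field name in DetokenizeResponse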