diff --git a/tests/entrypoints/openai/test_basic.py b/tests/entrypoints/openai/test_basic.py
index e7bf974f13ed..a4ac80070773 100644
--- a/tests/entrypoints/openai/test_basic.py
+++ b/tests/entrypoints/openai/test_basic.py
@@ -171,3 +171,51 @@ async def test_request_wrong_content_type(server: RemoteOpenAIServer):
             extra_headers={
                 "Content-Type": "application/x-www-form-urlencoded"
             })
+
+
+@pytest.mark.parametrize(
+    "server_args",
+    [
+        pytest.param(["--enable-server-load-tracking"],
+                     id="enable-server-load-tracking")
+    ],
+    indirect=True,
+)
+@pytest.mark.asyncio
+async def test_server_load(server: RemoteOpenAIServer):
+    # Check initial server load
+    response = requests.get(server.url_for("load"))
+    assert response.status_code == HTTPStatus.OK
+    assert response.json().get("server_load") == 0
+
+    def make_long_completion_request():
+        return requests.post(
+            server.url_for("v1/completions"),
+            headers={"Content-Type": "application/json"},
+            json={
+                "prompt": "Give me a long story",
+                "max_tokens": 1000,
+                "temperature": 0,
+            },
+        )
+
+    # Start the completion request in a background thread.
+    completion_future = asyncio.create_task(
+        asyncio.to_thread(make_long_completion_request))
+
+    # Give a short delay to ensure the request has started.
+    await asyncio.sleep(0.1)
+
+    # Check server load while the completion request is running.
+    response = requests.get(server.url_for("load"))
+    assert response.status_code == HTTPStatus.OK
+    assert response.json().get("server_load") == 1
+
+    # Wait for the completion request to finish.
+    await completion_future
+    await asyncio.sleep(0.1)
+
+    # Check server load after the completion request has finished.
+    response = requests.get(server.url_for("load"))
+    assert response.status_code == HTTPStatus.OK
+    assert response.json().get("server_load") == 0
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 7583078e9462..52e65fc214bc 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -80,7 +80,7 @@ from vllm.entrypoints.openai.serving_tokenization import (
 from vllm.entrypoints.openai.serving_transcription import (
     OpenAIServingTranscription)
 from vllm.entrypoints.openai.tool_parsers import ToolParserManager
-from vllm.entrypoints.utils import with_cancellation
+from vllm.entrypoints.utils import load_aware_call, with_cancellation
 from vllm.logger import init_logger
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import (FlexibleArgumentParser, get_open_zmq_ipc_path,
@@ -347,6 +347,24 @@ async def health(raw_request: Request) -> Response:
     return Response(status_code=200)
 
 
+@router.get("/load")
+async def get_server_load_metrics(request: Request):
+    # This endpoint returns the current server load metrics.
+    # It tracks requests utilizing the GPU from the following routes:
+    # - /v1/chat/completions
+    # - /v1/completions
+    # - /v1/audio/transcriptions
+    # - /v1/embeddings
+    # - /pooling
+    # - /score
+    # - /v1/score
+    # - /rerank
+    # - /v1/rerank
+    # - /v2/rerank
+    return JSONResponse(
+        content={'server_load': request.app.state.server_load_metrics})
+
+
 @router.api_route("/ping", methods=["GET", "POST"])
 async def ping(raw_request: Request) -> Response:
     """Ping check. Endpoint required for SageMaker"""
@@ -400,6 +418,7 @@ async def show_version():
 @router.post("/v1/chat/completions",
              dependencies=[Depends(validate_json_request)])
 @with_cancellation
+@load_aware_call
 async def create_chat_completion(request: ChatCompletionRequest,
                                  raw_request: Request):
     handler = chat(raw_request)
@@ -421,6 +440,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
 
 @router.post("/v1/completions", dependencies=[Depends(validate_json_request)])
 @with_cancellation
+@load_aware_call
 async def create_completion(request: CompletionRequest, raw_request: Request):
     handler = completion(raw_request)
     if handler is None:
@@ -439,6 +459,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
 
 @router.post("/v1/embeddings", dependencies=[Depends(validate_json_request)])
 @with_cancellation
+@load_aware_call
 async def create_embedding(request: EmbeddingRequest, raw_request: Request):
     handler = embedding(raw_request)
     if handler is None:
@@ -485,6 +506,7 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
 
 @router.post("/pooling", dependencies=[Depends(validate_json_request)])
 @with_cancellation
+@load_aware_call
 async def create_pooling(request: PoolingRequest, raw_request: Request):
     handler = pooling(raw_request)
     if handler is None:
@@ -503,6 +525,7 @@ async def create_pooling(request: PoolingRequest, raw_request: Request):
 
 @router.post("/score", dependencies=[Depends(validate_json_request)])
 @with_cancellation
+@load_aware_call
 async def create_score(request: ScoreRequest, raw_request: Request):
     handler = score(raw_request)
     if handler is None:
@@ -521,6 +544,7 @@ async def create_score(request: ScoreRequest, raw_request: Request):
 
 @router.post("/v1/score", dependencies=[Depends(validate_json_request)])
 @with_cancellation
+@load_aware_call
 async def create_score_v1(request: ScoreRequest, raw_request: Request):
     logger.warning(
         "To indicate that Score API is not part of standard OpenAI API, we "
@@ -531,10 +555,10 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request):
 
 @router.post("/v1/audio/transcriptions")
 @with_cancellation
+@load_aware_call
 async def create_transcriptions(request: Annotated[TranscriptionRequest,
                                                    Form()],
                                 raw_request: Request):
-
     handler = transcription(raw_request)
     if handler is None:
         return base(raw_request).create_error_response(
@@ -556,6 +580,7 @@ async def create_transcriptions(request: Annotated[TranscriptionRequest,
 
 @router.post("/rerank", dependencies=[Depends(validate_json_request)])
 @with_cancellation
+@load_aware_call
 async def do_rerank(request: RerankRequest, raw_request: Request):
     handler = rerank(raw_request)
     if handler is None:
@@ -894,6 +919,9 @@ async def init_app_state(
     ) if model_config.runner_type == "transcription" else None
     state.task = model_config.task
 
+    state.enable_server_load_tracking = args.enable_server_load_tracking
+    state.server_load_metrics = 0
+
 
 def create_server_socket(addr: tuple[str, int]) -> socket.socket:
     family = socket.AF_INET
diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py
index b8cc57430f85..bd66416d90cc 100644
--- a/vllm/entrypoints/openai/cli_args.py
+++ b/vllm/entrypoints/openai/cli_args.py
@@ -257,6 +257,13 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         action='store_true',
         default=False,
         help="If set to True, enable prompt_tokens_details in usage.")
+    parser.add_argument(
+        "--enable-server-load-tracking",
+        action='store_true',
+        default=False,
+        help=
+        "If set to True, enable tracking server_load_metrics in the app state."
+    )
 
     return parser
 
diff --git a/vllm/entrypoints/utils.py b/vllm/entrypoints/utils.py
index 9af37871d57c..60cbb58af3d9 100644
--- a/vllm/entrypoints/utils.py
+++ b/vllm/entrypoints/utils.py
@@ -4,6 +4,8 @@ import asyncio
 import functools
 
 from fastapi import Request
+from fastapi.responses import JSONResponse, StreamingResponse
+from starlette.background import BackgroundTask, BackgroundTasks
 
 
 async def listen_for_disconnect(request: Request) -> None:
@@ -17,9 +19,9 @@ async def listen_for_disconnect(request: Request) -> None:
 def with_cancellation(handler_func):
     """Decorator that allows a route handler to be cancelled by client
     disconnections.
-    
+
     This does _not_ use request.is_disconnected, which does not work with
-    middleware. Instead this follows the pattern from 
+    middleware. Instead this follows the pattern from
     starlette.StreamingResponse, which simultaneously awaits on two tasks- one
     to wait for an http disconnect message, and the other to do the work that
     we want done. When the first task finishes, the other is cancelled.
@@ -57,3 +59,45 @@ def with_cancellation(handler_func):
             return None
 
     return wrapper
+
+
+def decrement_server_load(request: Request):
+    request.app.state.server_load_metrics -= 1
+
+
+def load_aware_call(func):
+
+    @functools.wraps(func)
+    async def wrapper(*args, raw_request: Request, **kwargs):
+        if not raw_request.app.state.enable_server_load_tracking:
+            return await func(*args, raw_request=raw_request, **kwargs)
+
+        raw_request.app.state.server_load_metrics += 1
+        try:
+            response = await func(*args, raw_request=raw_request, **kwargs)
+        except Exception:
+            raw_request.app.state.server_load_metrics -= 1
+            raise
+
+        if isinstance(response, (JSONResponse, StreamingResponse)):
+            if response.background is None:
+                response.background = BackgroundTask(decrement_server_load,
+                                                     raw_request)
+            elif isinstance(response.background, BackgroundTasks):
+                response.background.add_task(decrement_server_load,
+                                             raw_request)
+            elif isinstance(response.background, BackgroundTask):
+                # Convert the single BackgroundTask to BackgroundTasks
+                # and chain the decrement_server_load task to it
+                tasks = BackgroundTasks()
+                tasks.add_task(response.background.func,
+                               *response.background.args,
+                               **response.background.kwargs)
+                tasks.add_task(decrement_server_load, raw_request)
+                response.background = tasks
+        else:
+            raw_request.app.state.server_load_metrics -= 1
+
+        return response
+
+    return wrapper
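
For illustration only (not part of the patch): a minimal client-side sketch of how the new /load endpoint behaves once a server is launched with --enable-server-load-tracking. The base URL, prompt, and sleep intervals are assumptions that mirror the test above, not values taken from the diff.

# Hypothetical usage sketch: poll /load while a completion request is in flight.
# Assumes a vLLM OpenAI-compatible server is already running on localhost:8000
# and was started with --enable-server-load-tracking.
import threading
import time

import requests

BASE_URL = "http://localhost:8000"  # assumed local server address


def send_completion():
    # Any route decorated with @load_aware_call (e.g. /v1/completions)
    # increments server_load_metrics for the duration of the request.
    requests.post(f"{BASE_URL}/v1/completions",
                  json={"prompt": "Give me a long story", "max_tokens": 1000})


thread = threading.Thread(target=send_completion)
thread.start()
time.sleep(0.5)  # give the request time to reach the server

# While the completion is running, /load reports the in-flight request count.
print(requests.get(f"{BASE_URL}/load").json())  # e.g. {"server_load": 1}

thread.join()
time.sleep(0.5)  # allow the response's background task to decrement the counter
print(requests.get(f"{BASE_URL}/load").json())  # back to {"server_load": 0}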