[Frontend] track server_load (#13950)

daniel-salib authored on 2025-03-14 09:53:17 -07:00 (committed by GitHub)
parent 9d2b4a70f4
commit 73deea2fdb
4 changed files with 131 additions and 4 deletions


@@ -171,3 +171,51 @@ async def test_request_wrong_content_type(server: RemoteOpenAIServer):
            extra_headers={
                "Content-Type": "application/x-www-form-urlencoded"
            })


@pytest.mark.parametrize(
    "server_args",
    [
        pytest.param(["--enable-server-load-tracking"],
                     id="enable-server-load-tracking")
    ],
    indirect=True,
)
@pytest.mark.asyncio
async def test_server_load(server: RemoteOpenAIServer):
    # Check initial server load
    response = requests.get(server.url_for("load"))
    assert response.status_code == HTTPStatus.OK
    assert response.json().get("server_load") == 0

    def make_long_completion_request():
        return requests.post(
            server.url_for("v1/completions"),
            headers={"Content-Type": "application/json"},
            json={
                "prompt": "Give me a long story",
                "max_tokens": 1000,
                "temperature": 0,
            },
        )

    # Start the completion request in a background thread.
    completion_future = asyncio.create_task(
        asyncio.to_thread(make_long_completion_request))

    # Give a short delay to ensure the request has started.
    await asyncio.sleep(0.1)

    # Check server load while the completion request is running.
    response = requests.get(server.url_for("load"))
    assert response.status_code == HTTPStatus.OK
    assert response.json().get("server_load") == 1

    # Wait for the completion request to finish.
    await completion_future
    await asyncio.sleep(0.1)

    # Check server load after the completion request has finished.
    response = requests.get(server.url_for("load"))
    assert response.status_code == HTTPStatus.OK
    assert response.json().get("server_load") == 0


@@ -80,7 +80,7 @@ from vllm.entrypoints.openai.serving_tokenization import (
from vllm.entrypoints.openai.serving_transcription import (
    OpenAIServingTranscription)
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
from vllm.entrypoints.utils import load_aware_call, with_cancellation
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
from vllm.utils import (FlexibleArgumentParser, get_open_zmq_ipc_path,
@@ -347,6 +347,24 @@ async def health(raw_request: Request) -> Response:
    return Response(status_code=200)


@router.get("/load")
async def get_server_load_metrics(request: Request):
    # This endpoint returns the current server load metrics.
    # It tracks requests utilizing the GPU from the following routes:
    # - /v1/chat/completions
    # - /v1/completions
    # - /v1/audio/transcriptions
    # - /v1/embeddings
    # - /pooling
    # - /score
    # - /v1/score
    # - /rerank
    # - /v1/rerank
    # - /v2/rerank
    return JSONResponse(
        content={'server_load': request.app.state.server_load_metrics})
@router.api_route("/ping", methods=["GET", "POST"])
async def ping(raw_request: Request) -> Response:
    """Ping check. Endpoint required for SageMaker"""
@@ -400,6 +418,7 @@ async def show_version():
@router.post("/v1/chat/completions",
             dependencies=[Depends(validate_json_request)])
@with_cancellation
@load_aware_call
async def create_chat_completion(request: ChatCompletionRequest,
                                 raw_request: Request):
    handler = chat(raw_request)
@@ -421,6 +440,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
@router.post("/v1/completions", dependencies=[Depends(validate_json_request)])
@with_cancellation
@load_aware_call
async def create_completion(request: CompletionRequest, raw_request: Request):
    handler = completion(raw_request)
    if handler is None:
@@ -439,6 +459,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
@router.post("/v1/embeddings", dependencies=[Depends(validate_json_request)])
@with_cancellation
@load_aware_call
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
    handler = embedding(raw_request)
    if handler is None:
@@ -485,6 +506,7 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
@router.post("/pooling", dependencies=[Depends(validate_json_request)])
@with_cancellation
@load_aware_call
async def create_pooling(request: PoolingRequest, raw_request: Request):
    handler = pooling(raw_request)
    if handler is None:
@@ -503,6 +525,7 @@ async def create_pooling(request: PoolingRequest, raw_request: Request):
@router.post("/score", dependencies=[Depends(validate_json_request)])
@with_cancellation
@load_aware_call
async def create_score(request: ScoreRequest, raw_request: Request):
    handler = score(raw_request)
    if handler is None:
@@ -521,6 +544,7 @@ async def create_score(request: ScoreRequest, raw_request: Request):
@router.post("/v1/score", dependencies=[Depends(validate_json_request)])
@with_cancellation
@load_aware_call
async def create_score_v1(request: ScoreRequest, raw_request: Request):
    logger.warning(
        "To indicate that Score API is not part of standard OpenAI API, we "
@@ -531,10 +555,10 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request):
@router.post("/v1/audio/transcriptions")
@with_cancellation
@load_aware_call
async def create_transcriptions(request: Annotated[TranscriptionRequest,
                                                   Form()],
                                raw_request: Request):
    handler = transcription(raw_request)
    if handler is None:
        return base(raw_request).create_error_response(
@@ -556,6 +580,7 @@ async def create_transcriptions(request: Annotated[TranscriptionRequest,
@router.post("/rerank", dependencies=[Depends(validate_json_request)])
@with_cancellation
@load_aware_call
async def do_rerank(request: RerankRequest, raw_request: Request):
    handler = rerank(raw_request)
    if handler is None:
@@ -894,6 +919,9 @@ async def init_app_state(
    ) if model_config.runner_type == "transcription" else None
    state.task = model_config.task

    state.enable_server_load_tracking = args.enable_server_load_tracking
    state.server_load_metrics = 0


def create_server_socket(addr: tuple[str, int]) -> socket.socket:
    family = socket.AF_INET


@@ -257,6 +257,13 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
        action='store_true',
        default=False,
        help="If set to True, enable prompt_tokens_details in usage.")
    parser.add_argument(
        "--enable-server-load-tracking",
        action='store_true',
        default=False,
        help=
        "If set to True, enable tracking server_load_metrics in the app state."
    )

    return parser
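For context (not part of this diff): a quick end-to-end check of the new flag could look like the sketch below. It assumes vLLM is installed, the small model name is only an example, and the default port 8000 is free.

import subprocess
import time

import requests

# Start the OpenAI-compatible server with load tracking enabled
# ("facebook/opt-125m" is only an example model).
proc = subprocess.Popen(
    ["vllm", "serve", "facebook/opt-125m", "--enable-server-load-tracking"])
try:
    # Wait until the server reports healthy.
    while True:
        try:
            if requests.get("http://localhost:8000/health").status_code == 200:
                break
        except requests.ConnectionError:
            pass
        time.sleep(1)
    # With no requests in flight, the new endpoint reports 0.
    print(requests.get("http://localhost:8000/load").json())  # {'server_load': 0}
finally:
    proc.terminate()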


@@ -4,6 +4,8 @@ import asyncio
import functools

from fastapi import Request
from fastapi.responses import JSONResponse, StreamingResponse
from starlette.background import BackgroundTask, BackgroundTasks


async def listen_for_disconnect(request: Request) -> None:
@@ -17,9 +19,9 @@ async def listen_for_disconnect(request: Request) -> None:
def with_cancellation(handler_func):
    """Decorator that allows a route handler to be cancelled by client
    disconnections.

    This does _not_ use request.is_disconnected, which does not work with
    middleware. Instead this follows the pattern from
    starlette.StreamingResponse, which simultaneously awaits on two tasks- one
    to wait for an http disconnect message, and the other to do the work that we
    want done. When the first task finishes, the other is cancelled.
@@ -57,3 +59,45 @@ def with_cancellation(handler_func):
        return None

    return wrapper


def decrement_server_load(request: Request):
    request.app.state.server_load_metrics -= 1


def load_aware_call(func):

    @functools.wraps(func)
    async def wrapper(*args, raw_request: Request, **kwargs):
        if not raw_request.app.state.enable_server_load_tracking:
            return await func(*args, raw_request=raw_request, **kwargs)

        raw_request.app.state.server_load_metrics += 1
        try:
            response = await func(*args, raw_request=raw_request, **kwargs)
        except Exception:
            raw_request.app.state.server_load_metrics -= 1
            raise

        if isinstance(response, (JSONResponse, StreamingResponse)):
            if response.background is None:
                response.background = BackgroundTask(decrement_server_load,
                                                     raw_request)
            elif isinstance(response.background, BackgroundTasks):
                response.background.add_task(decrement_server_load,
                                             raw_request)
            elif isinstance(response.background, BackgroundTask):
                # Convert the single BackgroundTask to BackgroundTasks
                # and chain the decrement_server_load task to it
                tasks = BackgroundTasks()
                tasks.add_task(response.background.func,
                               *response.background.args,
                               **response.background.kwargs)
                tasks.add_task(decrement_server_load, raw_request)
                response.background = tasks
        else:
            raw_request.app.state.server_load_metrics -= 1

        return response

    return wrapper
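To make the BackgroundTask handoff above concrete, here is a standalone sketch (not vLLM code) of the same pattern: a Starlette BackgroundTask attached to a response runs only after the body has been fully sent, so the counter stays incremented for the entire lifetime of a streaming response.

import asyncio

from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from starlette.background import BackgroundTask

app = FastAPI()
app.state.server_load_metrics = 0  # same counter name as in the diff


def decrement(request: Request):
    request.app.state.server_load_metrics -= 1


@app.get("/slow")
async def slow(request: Request):
    request.app.state.server_load_metrics += 1

    async def body():
        for i in range(5):
            await asyncio.sleep(1)
            yield f"chunk {i}\n"

    # The counter is decremented only once the last chunk has been delivered.
    return StreamingResponse(body(),
                             background=BackgroundTask(decrement, request))


@app.get("/load")
async def load(request: Request):
    return {"server_load": request.app.state.server_load_metrics}

Run with, e.g., "uvicorn demo:app" (module name hypothetical): while GET /slow is still streaming, GET /load returns {"server_load": 1}; it drops back to 0 only after the final chunk has been sent.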