[Frontend] track server_load (#13950)
This commit is contained in:
parent 9d2b4a70f4
commit 73deea2fdb
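Summary: this change adds opt-in server load tracking to the OpenAI-compatible frontend. A new `--enable-server-load-tracking` CLI flag (default False) turns the feature on; when enabled, a `load_aware_call` decorator increments `app.state.server_load_metrics` when a GPU-bound request starts and decrements it once the response has finished (via a response background task, so streaming responses are counted until fully sent), and a new `GET /load` endpoint reports the current counter as `{"server_load": N}`. The diff below spans four files, in order: the OpenAI server test suite, the API server routes, the CLI argument parser, and `vllm/entrypoints/utils.py`.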
@@ -171,3 +171,51 @@ async def test_request_wrong_content_type(server: RemoteOpenAIServer):
         extra_headers={
             "Content-Type": "application/x-www-form-urlencoded"
         })
+
+
+@pytest.mark.parametrize(
+    "server_args",
+    [
+        pytest.param(["--enable-server-load-tracking"],
+                     id="enable-server-load-tracking")
+    ],
+    indirect=True,
+)
+@pytest.mark.asyncio
+async def test_server_load(server: RemoteOpenAIServer):
+    # Check initial server load
+    response = requests.get(server.url_for("load"))
+    assert response.status_code == HTTPStatus.OK
+    assert response.json().get("server_load") == 0
+
+    def make_long_completion_request():
+        return requests.post(
+            server.url_for("v1/completions"),
+            headers={"Content-Type": "application/json"},
+            json={
+                "prompt": "Give me a long story",
+                "max_tokens": 1000,
+                "temperature": 0,
+            },
+        )
+
+    # Start the completion request in a background thread.
+    completion_future = asyncio.create_task(
+        asyncio.to_thread(make_long_completion_request))
+
+    # Give a short delay to ensure the request has started.
+    await asyncio.sleep(0.1)
+
+    # Check server load while the completion request is running.
+    response = requests.get(server.url_for("load"))
+    assert response.status_code == HTTPStatus.OK
+    assert response.json().get("server_load") == 1
+
+    # Wait for the completion request to finish.
+    await completion_future
+    await asyncio.sleep(0.1)
+
+    # Check server load after the completion request has finished.
+    response = requests.get(server.url_for("load"))
+    assert response.status_code == HTTPStatus.OK
+    assert response.json().get("server_load") == 0
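For operators, the endpoint the test exercises can be polled from any HTTP client. A minimal sketch, assuming a server started with `--enable-server-load-tracking` and listening at `http://localhost:8000` (both placeholders for your deployment):

import requests

BASE_URL = "http://localhost:8000"  # placeholder; point at your vLLM server


def get_server_load(base_url: str = BASE_URL) -> int:
    """Return the number of in-flight GPU-bound requests via GET /load."""
    response = requests.get(f"{base_url}/load")
    response.raise_for_status()
    return response.json()["server_load"]


if __name__ == "__main__":
    print("current server load:", get_server_load())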
@@ -80,7 +80,7 @@ from vllm.entrypoints.openai.serving_tokenization import (
 from vllm.entrypoints.openai.serving_transcription import (
     OpenAIServingTranscription)
 from vllm.entrypoints.openai.tool_parsers import ToolParserManager
-from vllm.entrypoints.utils import with_cancellation
+from vllm.entrypoints.utils import load_aware_call, with_cancellation
 from vllm.logger import init_logger
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import (FlexibleArgumentParser, get_open_zmq_ipc_path,
@@ -347,6 +347,24 @@ async def health(raw_request: Request) -> Response:
     return Response(status_code=200)
 
 
+@router.get("/load")
+async def get_server_load_metrics(request: Request):
+    # This endpoint returns the current server load metrics.
+    # It tracks requests utilizing the GPU from the following routes:
+    # - /v1/chat/completions
+    # - /v1/completions
+    # - /v1/audio/transcriptions
+    # - /v1/embeddings
+    # - /pooling
+    # - /score
+    # - /v1/score
+    # - /rerank
+    # - /v1/rerank
+    # - /v2/rerank
+    return JSONResponse(
+        content={'server_load': request.app.state.server_load_metrics})
+
+
 @router.api_route("/ping", methods=["GET", "POST"])
 async def ping(raw_request: Request) -> Response:
     """Ping check. Endpoint required for SageMaker"""
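The route itself is a thin read of a counter kept on the FastAPI application state. The pattern is easy to reproduce outside vLLM; a self-contained sketch (the app wiring here is illustrative, not vLLM's):

from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse

app = FastAPI()
# Counter incremented/decremented around GPU-bound work, as in the diff.
app.state.server_load_metrics = 0


@app.get("/load")
async def get_server_load_metrics(request: Request):
    # Report the in-flight request counter stored on the app state.
    return JSONResponse(
        content={"server_load": request.app.state.server_load_metrics})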
@@ -400,6 +418,7 @@ async def show_version():
 @router.post("/v1/chat/completions",
              dependencies=[Depends(validate_json_request)])
 @with_cancellation
+@load_aware_call
 async def create_chat_completion(request: ChatCompletionRequest,
                                  raw_request: Request):
     handler = chat(raw_request)
@@ -421,6 +440,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
 
 @router.post("/v1/completions", dependencies=[Depends(validate_json_request)])
 @with_cancellation
+@load_aware_call
 async def create_completion(request: CompletionRequest, raw_request: Request):
     handler = completion(raw_request)
     if handler is None:
@@ -439,6 +459,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
 
 @router.post("/v1/embeddings", dependencies=[Depends(validate_json_request)])
 @with_cancellation
+@load_aware_call
 async def create_embedding(request: EmbeddingRequest, raw_request: Request):
     handler = embedding(raw_request)
     if handler is None:
@@ -485,6 +506,7 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
 
 @router.post("/pooling", dependencies=[Depends(validate_json_request)])
 @with_cancellation
+@load_aware_call
 async def create_pooling(request: PoolingRequest, raw_request: Request):
     handler = pooling(raw_request)
     if handler is None:
@@ -503,6 +525,7 @@ async def create_pooling(request: PoolingRequest, raw_request: Request):
 
 @router.post("/score", dependencies=[Depends(validate_json_request)])
 @with_cancellation
+@load_aware_call
 async def create_score(request: ScoreRequest, raw_request: Request):
     handler = score(raw_request)
     if handler is None:
@@ -521,6 +544,7 @@ async def create_score(request: ScoreRequest, raw_request: Request):
 
 @router.post("/v1/score", dependencies=[Depends(validate_json_request)])
 @with_cancellation
+@load_aware_call
 async def create_score_v1(request: ScoreRequest, raw_request: Request):
     logger.warning(
         "To indicate that Score API is not part of standard OpenAI API, we "
@@ -531,10 +555,10 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request):
 
 @router.post("/v1/audio/transcriptions")
 @with_cancellation
+@load_aware_call
 async def create_transcriptions(request: Annotated[TranscriptionRequest,
                                                    Form()],
                                 raw_request: Request):
-
     handler = transcription(raw_request)
     if handler is None:
         return base(raw_request).create_error_response(
@@ -556,6 +580,7 @@ async def create_transcriptions(request: Annotated[TranscriptionRequest,
 
 @router.post("/rerank", dependencies=[Depends(validate_json_request)])
 @with_cancellation
+@load_aware_call
 async def do_rerank(request: RerankRequest, raw_request: Request):
     handler = rerank(raw_request)
     if handler is None:
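Each of the hunks above applies the same two-decorator stack. Note the ordering: `@with_cancellation` is outermost, `@load_aware_call` sits directly on the handler, and the handler must declare `raw_request` because both wrappers pass it by keyword (FastAPI supplies named handler parameters as keyword arguments). A sketch of the resulting route pattern, assuming vllm is installed so both decorators import; `/my_endpoint` is a hypothetical route, not part of this diff:

from fastapi import APIRouter, Request
from fastapi.responses import JSONResponse

from vllm.entrypoints.utils import load_aware_call, with_cancellation

router = APIRouter()


@router.post("/my_endpoint")
@with_cancellation  # outer: cancel the handler on client disconnect
@load_aware_call  # inner: bump/decrement app.state.server_load_metrics
async def my_endpoint(raw_request: Request):
    # Both wrappers expect `raw_request` to arrive as a keyword argument.
    return JSONResponse(content={"ok": True})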
@@ -894,6 +919,9 @@ async def init_app_state(
     ) if model_config.runner_type == "transcription" else None
     state.task = model_config.task
 
+    state.enable_server_load_tracking = args.enable_server_load_tracking
+    state.server_load_metrics = 0
+
 
 def create_server_socket(addr: tuple[str, int]) -> socket.socket:
     family = socket.AF_INET
@@ -257,6 +257,13 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         action='store_true',
         default=False,
         help="If set to True, enable prompt_tokens_details in usage.")
+    parser.add_argument(
+        "--enable-server-load-tracking",
+        action='store_true',
+        default=False,
+        help=
+        "If set to True, enable tracking server_load_metrics in the app state."
+    )
 
     return parser
 
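The new flag is a standard store_true argument. A simplified standalone sketch of how it parses, using plain argparse in place of vLLM's FlexibleArgumentParser:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--enable-server-load-tracking",
    action="store_true",
    default=False,
    help="If set to True, enable tracking server_load_metrics in the app "
    "state.")

args = parser.parse_args(["--enable-server-load-tracking"])
print(args.enable_server_load_tracking)  # -> True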
@@ -4,6 +4,8 @@ import asyncio
 import functools
 
 from fastapi import Request
+from fastapi.responses import JSONResponse, StreamingResponse
+from starlette.background import BackgroundTask, BackgroundTasks
 
 
 async def listen_for_disconnect(request: Request) -> None:
@@ -17,9 +19,9 @@ async def listen_for_disconnect(request: Request) -> None:
 def with_cancellation(handler_func):
     """Decorator that allows a route handler to be cancelled by client
     disconnections.
 
     This does _not_ use request.is_disconnected, which does not work with
     middleware. Instead this follows the pattern from
     starlette.StreamingResponse, which simultaneously awaits on two tasks- one
     to wait for an http disconnect message, and the other to do the work that we
     want done. When the first task finishes, the other is cancelled.
@@ -57,3 +59,45 @@ def with_cancellation(handler_func):
         return None
 
     return wrapper
+
+
+def decrement_server_load(request: Request):
+    request.app.state.server_load_metrics -= 1
+
+
+def load_aware_call(func):
+
+    @functools.wraps(func)
+    async def wrapper(*args, raw_request: Request, **kwargs):
+        if not raw_request.app.state.enable_server_load_tracking:
+            return await func(*args, raw_request=raw_request, **kwargs)
+
+        raw_request.app.state.server_load_metrics += 1
+        try:
+            response = await func(*args, raw_request=raw_request, **kwargs)
+        except Exception:
+            raw_request.app.state.server_load_metrics -= 1
+            raise
+
+        if isinstance(response, (JSONResponse, StreamingResponse)):
+            if response.background is None:
+                response.background = BackgroundTask(decrement_server_load,
+                                                     raw_request)
+            elif isinstance(response.background, BackgroundTasks):
+                response.background.add_task(decrement_server_load,
+                                             raw_request)
+            elif isinstance(response.background, BackgroundTask):
+                # Convert the single BackgroundTask to BackgroundTasks
+                # and chain the decrement_server_load task to it
+                tasks = BackgroundTasks()
+                tasks.add_task(response.background.func,
+                               *response.background.args,
+                               **response.background.kwargs)
+                tasks.add_task(decrement_server_load, raw_request)
+                response.background = tasks
+        else:
+            raw_request.app.state.server_load_metrics -= 1
+
+        return response
+
+    return wrapper
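Because the decrement is attached as a Starlette background task, the counter stays elevated until the response body has actually been sent, which is what makes the accounting correct for streaming responses. A test-style sketch, assuming vllm is installed so `load_aware_call` imports; the SimpleNamespace fake below only mimics the two app-state attributes the decorator touches and is not how FastAPI invokes handlers in production:

import asyncio
from types import SimpleNamespace

from fastapi.responses import JSONResponse

from vllm.entrypoints.utils import load_aware_call


@load_aware_call
async def dummy_handler(raw_request):
    return JSONResponse(content={"ok": True})


async def main():
    # Fake just enough of a Request: .app.state with the two fields
    # init_app_state sets in this diff.
    fake_request = SimpleNamespace(app=SimpleNamespace(state=SimpleNamespace(
        enable_server_load_tracking=True, server_load_metrics=0)))

    response = await dummy_handler(raw_request=fake_request)
    print(fake_request.app.state.server_load_metrics)  # -> 1 (not yet sent)

    # Starlette BackgroundTask instances are awaitable callables; running
    # one here simulates the response having been fully delivered.
    await response.background()
    print(fake_request.app.state.server_load_metrics)  # -> 0


asyncio.run(main())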