diff --git a/vllm/benchmarks/lib/endpoint_request_func.py b/vllm/benchmarks/lib/endpoint_request_func.py index 066b8fe83438..725b7df8b187 100644 --- a/vllm/benchmarks/lib/endpoint_request_func.py +++ b/vllm/benchmarks/lib/endpoint_request_func.py @@ -8,8 +8,9 @@ import os import sys import time import traceback +from collections.abc import Awaitable from dataclasses import dataclass, field -from typing import Optional, Union +from typing import Optional, Protocol, Union import aiohttp from tqdm.asyncio import tqdm @@ -92,6 +93,16 @@ class RequestFuncOutput: start_time: float = 0.0 +class RequestFunc(Protocol): + def __call__( + self, + request_func_input: RequestFuncInput, + session: aiohttp.ClientSession, + pbar: Optional[tqdm] = None, + ) -> Awaitable[RequestFuncOutput]: + ... + + async def async_request_openai_completions( request_func_input: RequestFuncInput, session: aiohttp.ClientSession, @@ -507,7 +518,7 @@ async def async_request_openai_embeddings( # TODO: Add more request functions for different API protocols. -ASYNC_REQUEST_FUNCS = { +ASYNC_REQUEST_FUNCS: dict[str, RequestFunc] = { "vllm": async_request_openai_completions, "openai": async_request_openai_completions, "openai-chat": async_request_openai_chat_completions, diff --git a/vllm/benchmarks/lib/ready_checker.py b/vllm/benchmarks/lib/ready_checker.py index 7e836158386a..87fc16b55012 100644 --- a/vllm/benchmarks/lib/ready_checker.py +++ b/vllm/benchmarks/lib/ready_checker.py @@ -8,11 +8,12 @@ import time import aiohttp from tqdm.asyncio import tqdm -from .endpoint_request_func import RequestFuncInput, RequestFuncOutput +from .endpoint_request_func import (RequestFunc, RequestFuncInput, + RequestFuncOutput) async def wait_for_endpoint( - request_func, + request_func: RequestFunc, test_input: RequestFuncInput, session: aiohttp.ClientSession, timeout_seconds: int = 600,