diff --git a/benchmarks/multi_turn/benchmark_serving_multi_turn.py b/benchmarks/multi_turn/benchmark_serving_multi_turn.py
index 5d2ac66e5ab94..2c1a051cc9c97 100644
--- a/benchmarks/multi_turn/benchmark_serving_multi_turn.py
+++ b/benchmarks/multi_turn/benchmark_serving_multi_turn.py
@@ -63,6 +63,7 @@ class RequestArgs(NamedTuple):
     stream: bool
     limit_min_tokens: int  # Use negative value for no limit
     limit_max_tokens: int  # Use negative value for no limit
+    timeout_sec: int
 
 
 class BenchmarkArgs(NamedTuple):
@@ -214,6 +215,7 @@ async def send_request(
     stream: bool = True,
     min_tokens: int | None = None,
     max_tokens: int | None = None,
+    timeout_sec: int = 120,
 ) -> ServerResponse:
     payload = {
         "model": model,
@@ -235,10 +237,16 @@ async def send_request(
     headers = {"Content-Type": "application/json"}
 
     # Calculate the timeout for the request
-    timeout_sec = 120
     if max_tokens is not None:
         # Assume TPOT of 200ms and use max_tokens to determine timeout
-        timeout_sec = max(timeout_sec, int(max_tokens * 0.2))
+        token_based_timeout = int(max_tokens * 0.2)
+        if token_based_timeout > timeout_sec:
+            timeout_sec = token_based_timeout
+            logger.info(
+                "Using timeout of %ds based on max_tokens %d",
+                timeout_sec,
+                max_tokens,
+            )
 
     timeout = aiohttp.ClientTimeout(total=timeout_sec)
     valid_response = True
@@ -409,6 +417,7 @@ async def send_turn(
         req_args.stream,
         min_tokens,
         max_tokens,
+        req_args.timeout_sec,
     )
 
     if response.valid is False:
@@ -676,8 +685,18 @@ async def client_main(
 
         except asyncio.exceptions.TimeoutError:
             num_failures += 1
-            logger.exception(
-                f"{Color.RED}Client {client_id} - Timeout during conversation ID {conv_id} (turn: {current_turn}){Color.RESET}"  # noqa: E501
+            logger.error(
+                "%sClient %d - Timeout during conversation ID %s (turn: %d). "
+                "Base timeout is %ss (set with --request-timeout-sec), but the "
+                "effective timeout may be longer based on max_tokens. If this "
+                "is unexpected, consider increasing the timeout or checking "
+                "model performance.%s",
+                Color.RED,
+                client_id,
+                conv_id,
+                current_turn,
+                req_args.timeout_sec,
+                Color.RESET,
             )
             break  # Exit gracefully instead of raising an error
 
@@ -815,6 +834,9 @@ def get_client_config(
             "Invalid min/max tokens limits (min should not be larger than max)"
         )
 
+    if args.request_timeout_sec <= 0:
+        raise ValueError("Request timeout must be a positive number")
+
     # Arguments for API requests
     chat_url = f"{args.url}/v1/chat/completions"
     model_name = args.served_model_name if args.served_model_name else args.model
@@ -825,6 +847,7 @@ def get_client_config(
         stream=not args.no_stream,
         limit_min_tokens=args.limit_min_tokens,
         limit_max_tokens=args.limit_max_tokens,
+        timeout_sec=args.request_timeout_sec,
     )
 
     return client_args, req_args
@@ -968,7 +991,7 @@ async def main_mp(
                 f"(is alive: {client.is_alive()}){Color.RESET}"
             )
 
-            client.join(timeout=120)
+            client.join(timeout=req_args.timeout_sec + 1)
 
             if client.is_alive():
                 logger.warning(
@@ -1351,6 +1374,13 @@ async def main() -> None:
         action="store_true",
         help="Verify the LLM output (compare to the answers in the input JSON file)",
     )
+    parser.add_argument(
+        "--request-timeout-sec",
+        type=int,
+        default=120,
+        help="Timeout in seconds for each API request (default: 120). "
+        "Automatically increased if max tokens imply longer decoding.",
+    )
     parser.add_argument(
         "--no-stream",