diff --git a/tests/benchmarks/test_serve_cli.py b/tests/benchmarks/test_serve_cli.py
index 90d685c966d3e..c579b38069864 100644
--- a/tests/benchmarks/test_serve_cli.py
+++ b/tests/benchmarks/test_serve_cli.py
@@ -19,21 +19,18 @@ def server():
 
 @pytest.mark.benchmark
 def test_bench_serve(server):
+    # Test default model detection and input/output len
     command = [
         "vllm",
         "bench",
         "serve",
-        "--model",
-        MODEL_NAME,
         "--host",
         server.host,
         "--port",
         str(server.port),
-        "--dataset-name",
-        "random",
-        "--random-input-len",
+        "--input-len",
         "32",
-        "--random-output-len",
+        "--output-len",
         "4",
         "--num-prompts",
         "5",
diff --git a/vllm/benchmarks/serve.py b/vllm/benchmarks/serve.py
index f5d8ea5a975a9..12756d1700c9f 100644
--- a/vllm/benchmarks/serve.py
+++ b/vllm/benchmarks/serve.py
@@ -10,8 +10,10 @@ On the client side, run:
     vllm bench serve \
         --backend \
         --label \
-        --model \
+        --model \
         --dataset-name \
+        --input-len \
+        --output-len \
         --request-rate \
         --num-prompts
 """
@@ -57,6 +59,33 @@ TERM_PLOTLIB_AVAILABLE = (importlib.util.find_spec("termplotlib") is not None) a
 )
 
 
+async def get_first_model_from_server(
+    base_url: str, headers: dict | None = None
+) -> str:
+    """Fetch the first model from the server's /v1/models endpoint."""
+    models_url = f"{base_url}/v1/models"
+    async with aiohttp.ClientSession() as session:
+        try:
+            async with session.get(models_url, headers=headers) as response:
+                response.raise_for_status()
+                data = await response.json()
+                if "data" in data and len(data["data"]) > 0:
+                    return data["data"][0]["id"]
+                else:
+                    raise ValueError(
+                        f"No models found on the server at {base_url}. "
+                        "Make sure the server is running and has models loaded."
+                    )
+        except (aiohttp.ClientError, json.JSONDecodeError) as e:
+            raise RuntimeError(
+                f"Failed to fetch models from server at {models_url}. "
+                "Check that:\n"
+                "1. The server is running\n"
+                "2. The server URL is correct\n"
+                f"Error: {e}"
+            ) from e
+
+
 class TaskType(Enum):
     GENERATION = "generation"
     POOLING = "pooling"
@@ -1025,8 +1054,26 @@ def add_cli_args(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--model",
         type=str,
-        required=True,
-        help="Name of the model.",
+        required=False,
+        default=None,
+        help="Name of the model. If not specified, will fetch the first model "
+        "from the server's /v1/models endpoint.",
+    )
+    parser.add_argument(
+        "--input-len",
+        type=int,
+        default=None,
+        help="General input length for datasets. Maps to dataset-specific "
+        "input length arguments (e.g., --random-input-len, --sonnet-input-len). "
+        "If not specified, uses dataset defaults.",
+    )
+    parser.add_argument(
+        "--output-len",
+        type=int,
+        default=None,
+        help="General output length for datasets. Maps to dataset-specific "
+        "output length arguments (e.g., --random-output-len, --sonnet-output-len). "
+        "If not specified, uses dataset defaults.",
     )
     parser.add_argument(
         "--tokenizer",
@@ -1332,10 +1379,6 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
         raise ValueError("For exponential ramp-up, the start RPS cannot be 0.")
 
     label = args.label
-    model_id = args.model
-    model_name = args.served_model_name
-    tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
-    tokenizer_mode = args.tokenizer_mode
 
     if args.base_url is not None:
         api_url = f"{args.base_url}{args.endpoint}"
@@ -1356,6 +1399,18 @@
             else:
                 raise ValueError("Invalid header format. Please use KEY=VALUE format.")
 
+    # Fetch model from server if not specified
+    if args.model is None:
+        print("Model not specified, fetching first model from server...")
+        model_id = await get_first_model_from_server(base_url, headers)
+        print(f"Using model: {model_id}")
+    else:
+        model_id = args.model
+
+    model_name = args.served_model_name
+    tokenizer_id = args.tokenizer if args.tokenizer is not None else model_id
+    tokenizer_mode = args.tokenizer_mode
+
     tokenizer = get_tokenizer(
         tokenizer_id,
         tokenizer_mode=tokenizer_mode,
@@ -1368,6 +1423,20 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
             "'--dataset-path' if required."
         )
 
+    # Map general --input-len and --output-len to all dataset-specific arguments
+    if args.input_len is not None:
+        args.random_input_len = args.input_len
+        args.sonnet_input_len = args.input_len
+
+    if args.output_len is not None:
+        args.random_output_len = args.output_len
+        args.sonnet_output_len = args.output_len
+        args.sharegpt_output_len = args.output_len
+        args.custom_output_len = args.output_len
+        args.hf_output_len = args.output_len
+        args.spec_bench_output_len = args.output_len
+        args.prefix_repetition_output_len = args.output_len
+
     # when using random datasets, default to ignoring EOS
     # so generation runs to the requested length
     if (