[UX] Make vllm bench serve discover model by default and use --input-len (#30816)

Signed-off-by: mgoin <mgoin64@gmail.com>
Michael Goin 2025-12-17 04:55:30 -05:00 committed by GitHub
parent a100152288
commit 519ef9a911
2 changed files with 79 additions and 13 deletions


@@ -19,21 +19,18 @@ def server():
 @pytest.mark.benchmark
 def test_bench_serve(server):
+    # Test default model detection and input/output len
     command = [
         "vllm",
         "bench",
         "serve",
-        "--model",
-        MODEL_NAME,
         "--host",
         server.host,
         "--port",
         str(server.port),
         "--dataset-name",
         "random",
-        "--random-input-len",
+        "--input-len",
         "32",
-        "--random-output-len",
+        "--output-len",
         "4",
         "--num-prompts",
         "5",


@@ -10,8 +10,10 @@ On the client side, run:
     vllm bench serve \
         --backend <backend or endpoint type. Default 'openai'> \
         --label <benchmark result label. Default using backend> \
-        --model <your_model> \
+        --model <your_model. Optional, defaults to first model from server> \
         --dataset-name <dataset_name. Default 'random'> \
+        --input-len <general input length. Optional, maps to dataset-specific args> \
+        --output-len <general output length. Optional, maps to dataset-specific args> \
         --request-rate <request_rate. Default inf> \
         --num-prompts <num_prompts. Default 1000>
 """
@@ -57,6 +59,33 @@ TERM_PLOTLIB_AVAILABLE = (importlib.util.find_spec("termplotlib") is not None) a
 )
 
 
+async def get_first_model_from_server(
+    base_url: str, headers: dict | None = None
+) -> str:
+    """Fetch the first model from the server's /v1/models endpoint."""
+    models_url = f"{base_url}/v1/models"
+    async with aiohttp.ClientSession() as session:
+        try:
+            async with session.get(models_url, headers=headers) as response:
+                response.raise_for_status()
+                data = await response.json()
+                if "data" in data and len(data["data"]) > 0:
+                    return data["data"][0]["id"]
+                else:
+                    raise ValueError(
+                        f"No models found on the server at {base_url}. "
+                        "Make sure the server is running and has models loaded."
+                    )
+        except (aiohttp.ClientError, json.JSONDecodeError) as e:
+            raise RuntimeError(
+                f"Failed to fetch models from server at {models_url}. "
+                "Check that:\n"
+                "1. The server is running\n"
+                "2. The server URL is correct\n"
+                f"Error: {e}"
+            ) from e
+
+
 class TaskType(Enum):
     GENERATION = "generation"
     POOLING = "pooling"
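
A minimal sketch of exercising the new helper on its own; the import path is an assumption, and a server at http://localhost:8000 is assumed to be reachable:

```python
import asyncio

# Assumed import path; adjust to wherever the benchmark module actually lives.
from vllm.benchmarks.serve import get_first_model_from_server


async def pick_model() -> str:
    # Returns the id of the first model listed by /v1/models, or raises
    # ValueError / RuntimeError as defined in the helper above.
    return await get_first_model_from_server("http://localhost:8000")


if __name__ == "__main__":
    print(asyncio.run(pick_model()))
```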
@@ -1025,8 +1054,26 @@ def add_cli_args(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--model",
         type=str,
-        required=True,
-        help="Name of the model.",
+        required=False,
+        default=None,
+        help="Name of the model. If not specified, will fetch the first model "
+        "from the server's /v1/models endpoint.",
     )
+    parser.add_argument(
+        "--input-len",
+        type=int,
+        default=None,
+        help="General input length for datasets. Maps to dataset-specific "
+        "input length arguments (e.g., --random-input-len, --sonnet-input-len). "
+        "If not specified, uses dataset defaults.",
+    )
+    parser.add_argument(
+        "--output-len",
+        type=int,
+        default=None,
+        help="General output length for datasets. Maps to dataset-specific "
+        "output length arguments (e.g., --random-output-len, --sonnet-output-len). "
+        "If not specified, uses dataset defaults.",
+    )
     parser.add_argument(
         "--tokenizer",
@@ -1332,10 +1379,6 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
         raise ValueError("For exponential ramp-up, the start RPS cannot be 0.")
 
     label = args.label
-    model_id = args.model
-    model_name = args.served_model_name
-    tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
-    tokenizer_mode = args.tokenizer_mode
 
     if args.base_url is not None:
         api_url = f"{args.base_url}{args.endpoint}"
@@ -1356,6 +1399,18 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
         else:
             raise ValueError("Invalid header format. Please use KEY=VALUE format.")
 
+    # Fetch model from server if not specified
+    if args.model is None:
+        print("Model not specified, fetching first model from server...")
+        model_id = await get_first_model_from_server(base_url, headers)
+        print(f"Using model: {model_id}")
+    else:
+        model_id = args.model
+
+    model_name = args.served_model_name
+    tokenizer_id = args.tokenizer if args.tokenizer is not None else model_id
+    tokenizer_mode = args.tokenizer_mode
+
     tokenizer = get_tokenizer(
         tokenizer_id,
         tokenizer_mode=tokenizer_mode,
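
A compact sketch of the resolution order the hunk above now follows: an explicit --model wins, otherwise the first served model is used, and the tokenizer id falls back to the resolved model id. The helper is the one added in this commit; the wrapper function and import path are illustrative:

```python
from vllm.benchmarks.serve import get_first_model_from_server  # assumed path


async def resolve_model_and_tokenizer(
    model_arg: str | None,
    tokenizer_arg: str | None,
    base_url: str,
    headers: dict | None = None,
) -> tuple[str, str]:
    # Explicit --model takes precedence; otherwise ask the server.
    if model_arg is None:
        model_id = await get_first_model_from_server(base_url, headers)
    else:
        model_id = model_arg
    # --tokenizer falls back to the resolved model id, not the raw CLI value.
    tokenizer_id = tokenizer_arg if tokenizer_arg is not None else model_id
    return model_id, tokenizer_id
```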
@@ -1368,6 +1423,20 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
             "'--dataset-path' if required."
         )
 
+    # Map general --input-len and --output-len to all dataset-specific arguments
+    if args.input_len is not None:
+        args.random_input_len = args.input_len
+        args.sonnet_input_len = args.input_len
+    if args.output_len is not None:
+        args.random_output_len = args.output_len
+        args.sonnet_output_len = args.output_len
+        args.sharegpt_output_len = args.output_len
+        args.custom_output_len = args.output_len
+        args.hf_output_len = args.output_len
+        args.spec_bench_output_len = args.output_len
+        args.prefix_repetition_output_len = args.output_len
+
     # when using random datasets, default to ignoring EOS
     # so generation runs to the requested length
     if (