[UX] Make vllm bench serve discover model by default and use --input-len (#30816)
Signed-off-by: mgoin <mgoin64@gmail.com>
commit 519ef9a911 (parent a100152288)

This change makes --model optional for `vllm bench serve`: when it is omitted, the benchmark fetches the first model from the server's /v1/models endpoint. It also adds general --input-len and --output-len flags that are mapped onto the dataset-specific length arguments (e.g. --random-input-len, --sonnet-output-len).
```diff
@@ -19,21 +19,18 @@ def server():
 @pytest.mark.benchmark
 def test_bench_serve(server):
     # Test default model detection and input/output len
     command = [
         "vllm",
         "bench",
         "serve",
-        "--model",
-        MODEL_NAME,
         "--host",
         server.host,
         "--port",
         str(server.port),
         "--dataset-name",
         "random",
-        "--random-input-len",
+        "--input-len",
         "32",
-        "--random-output-len",
+        "--output-len",
         "4",
         "--num-prompts",
         "5",
```
```diff
@@ -10,8 +10,10 @@ On the client side, run:
     vllm bench serve \
         --backend <backend or endpoint type. Default 'openai'> \
         --label <benchmark result label. Default using backend> \
-        --model <your_model> \
+        --model <your_model. Optional, defaults to first model from server> \
         --dataset-name <dataset_name. Default 'random'> \
+        --input-len <general input length. Optional, maps to dataset-specific args> \
+        --output-len <general output length. Optional, maps to dataset-specific args> \
         --request-rate <request_rate. Default inf> \
         --num-prompts <num_prompts. Default 1000>
 """
```
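As a rough illustration of the simplified invocation documented above, the benchmark can now be launched without --model. The sketch below shells out from Python; the host, port, and an already-running server are assumptions for the example, not part of this commit.

```python
# Sketch only: assumes `vllm` is on PATH and a server is already serving on localhost:8000.
import subprocess

cmd = [
    "vllm", "bench", "serve",
    "--host", "localhost",
    "--port", "8000",
    "--dataset-name", "random",
    "--input-len", "32",   # general flag, mapped to --random-input-len internally
    "--output-len", "4",
    "--num-prompts", "5",
]
subprocess.run(cmd, check=True)
```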
```diff
@@ -57,6 +59,33 @@ TERM_PLOTLIB_AVAILABLE = (importlib.util.find_spec("termplotlib") is not None) a
 )


+async def get_first_model_from_server(
+    base_url: str, headers: dict | None = None
+) -> str:
+    """Fetch the first model from the server's /v1/models endpoint."""
+    models_url = f"{base_url}/v1/models"
+    async with aiohttp.ClientSession() as session:
+        try:
+            async with session.get(models_url, headers=headers) as response:
+                response.raise_for_status()
+                data = await response.json()
+                if "data" in data and len(data["data"]) > 0:
+                    return data["data"][0]["id"]
+                else:
+                    raise ValueError(
+                        f"No models found on the server at {base_url}. "
+                        "Make sure the server is running and has models loaded."
+                    )
+        except (aiohttp.ClientError, json.JSONDecodeError) as e:
+            raise RuntimeError(
+                f"Failed to fetch models from server at {models_url}. "
+                "Check that:\n"
+                "1. The server is running\n"
+                "2. The server URL is correct\n"
+                f"Error: {e}"
+            ) from e
+
+
 class TaskType(Enum):
     GENERATION = "generation"
     POOLING = "pooling"
```
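For context, a minimal standalone sketch of the same discovery pattern against an OpenAI-compatible server. The base URL and the absence of auth headers are assumptions for the example; the real implementation is the `get_first_model_from_server` helper added above.

```python
# Minimal sketch (assumes a local OpenAI-compatible server at http://localhost:8000, no auth).
import asyncio

import aiohttp


async def first_model(base_url: str) -> str:
    # Same idea as get_first_model_from_server above: query /v1/models
    # and take the id of the first entry in the "data" list.
    async with aiohttp.ClientSession() as session:
        async with session.get(f"{base_url}/v1/models") as response:
            response.raise_for_status()
            payload = await response.json()
            return payload["data"][0]["id"]


if __name__ == "__main__":
    print(asyncio.run(first_model("http://localhost:8000")))
```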
```diff
@@ -1025,8 +1054,26 @@ def add_cli_args(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--model",
         type=str,
-        required=True,
-        help="Name of the model.",
+        required=False,
+        default=None,
+        help="Name of the model. If not specified, will fetch the first model "
+        "from the server's /v1/models endpoint.",
     )
+    parser.add_argument(
+        "--input-len",
+        type=int,
+        default=None,
+        help="General input length for datasets. Maps to dataset-specific "
+        "input length arguments (e.g., --random-input-len, --sonnet-input-len). "
+        "If not specified, uses dataset defaults.",
+    )
+    parser.add_argument(
+        "--output-len",
+        type=int,
+        default=None,
+        help="General output length for datasets. Maps to dataset-specific "
+        "output length arguments (e.g., --random-output-len, --sonnet-output-len). "
+        "If not specified, uses dataset defaults.",
+    )
     parser.add_argument(
         "--tokenizer",
```
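To see the effect of the new defaults in isolation, here is a toy parser that mirrors just these three arguments (not the real add_cli_args): omitting --model now parses to None instead of erroring out.

```python
# Toy reproduction of the changed/added arguments, for illustration only.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, required=False, default=None)
parser.add_argument("--input-len", type=int, default=None)
parser.add_argument("--output-len", type=int, default=None)

args = parser.parse_args(["--input-len", "32"])
print(args.model)      # None -> benchmark will discover the model from /v1/models
print(args.input_len)  # 32   -> later fanned out to the dataset-specific *-input-len flags
```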
```diff
@@ -1332,10 +1379,6 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
         raise ValueError("For exponential ramp-up, the start RPS cannot be 0.")

     label = args.label
-    model_id = args.model
-    model_name = args.served_model_name
-    tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
-    tokenizer_mode = args.tokenizer_mode

     if args.base_url is not None:
         api_url = f"{args.base_url}{args.endpoint}"
```
```diff
@@ -1356,6 +1399,18 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
         else:
             raise ValueError("Invalid header format. Please use KEY=VALUE format.")

+    # Fetch model from server if not specified
+    if args.model is None:
+        print("Model not specified, fetching first model from server...")
+        model_id = await get_first_model_from_server(base_url, headers)
+        print(f"Using model: {model_id}")
+    else:
+        model_id = args.model
+
+    model_name = args.served_model_name
+    tokenizer_id = args.tokenizer if args.tokenizer is not None else model_id
+    tokenizer_mode = args.tokenizer_mode
+
     tokenizer = get_tokenizer(
         tokenizer_id,
         tokenizer_mode=tokenizer_mode,
```
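The resolution order introduced here is: an explicit --model wins, otherwise the first served model is used, and the tokenizer falls back to whichever model id was resolved. A small synchronous sketch of that precedence; the function and the model name are illustrative only, not the benchmark's actual API.

```python
# Illustrative precedence only; discover() stands in for get_first_model_from_server.
def resolve_model_and_tokenizer(model_arg, tokenizer_arg, discover):
    model_id = model_arg if model_arg is not None else discover()
    tokenizer_id = tokenizer_arg if tokenizer_arg is not None else model_id
    return model_id, tokenizer_id


print(resolve_model_and_tokenizer(None, None, lambda: "example/served-model"))
# -> ('example/served-model', 'example/served-model')
```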
```diff
@@ -1368,6 +1423,20 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
             "'--dataset-path' if required."
         )

+    # Map general --input-len and --output-len to all dataset-specific arguments
+    if args.input_len is not None:
+        args.random_input_len = args.input_len
+        args.sonnet_input_len = args.input_len
+
+    if args.output_len is not None:
+        args.random_output_len = args.output_len
+        args.sonnet_output_len = args.output_len
+        args.sharegpt_output_len = args.output_len
+        args.custom_output_len = args.output_len
+        args.hf_output_len = args.output_len
+        args.spec_bench_output_len = args.output_len
+        args.prefix_repetition_output_len = args.output_len
+
     # when using random datasets, default to ignoring EOS
     # so generation runs to the requested length
     if (
```
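The mapping above means a single --input-len / --output-len fans out to each dataset's specific flag. A quick check of that fan-out on a bare Namespace; this is a sketch of the same logic, not the benchmark's own code path, and the None placeholders stand in for the real dataset defaults defined in add_cli_args.

```python
# Sketch: reproduce the fan-out on a plain Namespace to show the end state.
from argparse import Namespace

# Placeholder dataset defaults; the real defaults live in add_cli_args.
args = Namespace(
    input_len=32, output_len=4,
    random_input_len=None, sonnet_input_len=None,
    random_output_len=None, sonnet_output_len=None, sharegpt_output_len=None,
    custom_output_len=None, hf_output_len=None,
    spec_bench_output_len=None, prefix_repetition_output_len=None,
)

# Same fan-out as the mapping in the hunk above.
if args.input_len is not None:
    args.random_input_len = args.sonnet_input_len = args.input_len
if args.output_len is not None:
    for name in ("random", "sonnet", "sharegpt", "custom", "hf",
                 "spec_bench", "prefix_repetition"):
        setattr(args, f"{name}_output_len", args.output_len)

print(args.random_input_len, args.random_output_len)  # -> 32 4
```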