mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-10 04:06:49 +08:00
[UX] Make vllm bench serve discover model by default and use --input-len (#30816)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
parent
a100152288
commit
519ef9a911
@ -19,21 +19,18 @@ def server():
|
|||||||
|
|
||||||
@pytest.mark.benchmark
|
@pytest.mark.benchmark
|
||||||
def test_bench_serve(server):
|
def test_bench_serve(server):
|
||||||
|
# Test default model detection and input/output len
|
||||||
command = [
|
command = [
|
||||||
"vllm",
|
"vllm",
|
||||||
"bench",
|
"bench",
|
||||||
"serve",
|
"serve",
|
||||||
"--model",
|
|
||||||
MODEL_NAME,
|
|
||||||
"--host",
|
"--host",
|
||||||
server.host,
|
server.host,
|
||||||
"--port",
|
"--port",
|
||||||
str(server.port),
|
str(server.port),
|
||||||
"--dataset-name",
|
"--input-len",
|
||||||
"random",
|
|
||||||
"--random-input-len",
|
|
||||||
"32",
|
"32",
|
||||||
"--random-output-len",
|
"--output-len",
|
||||||
"4",
|
"4",
|
||||||
"--num-prompts",
|
"--num-prompts",
|
||||||
"5",
|
"5",
|
||||||
|
|||||||
@ -10,8 +10,10 @@ On the client side, run:
|
|||||||
vllm bench serve \
|
vllm bench serve \
|
||||||
--backend <backend or endpoint type. Default 'openai'> \
|
--backend <backend or endpoint type. Default 'openai'> \
|
||||||
--label <benchmark result label. Default using backend> \
|
--label <benchmark result label. Default using backend> \
|
||||||
--model <your_model> \
|
--model <your_model. Optional, defaults to first model from server> \
|
||||||
--dataset-name <dataset_name. Default 'random'> \
|
--dataset-name <dataset_name. Default 'random'> \
|
||||||
|
--input-len <general input length. Optional, maps to dataset-specific args> \
|
||||||
|
--output-len <general output length. Optional, maps to dataset-specific args> \
|
||||||
--request-rate <request_rate. Default inf> \
|
--request-rate <request_rate. Default inf> \
|
||||||
--num-prompts <num_prompts. Default 1000>
|
--num-prompts <num_prompts. Default 1000>
|
||||||
"""
|
"""
|
||||||
@ -57,6 +59,33 @@ TERM_PLOTLIB_AVAILABLE = (importlib.util.find_spec("termplotlib") is not None) a
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_first_model_from_server(
|
||||||
|
base_url: str, headers: dict | None = None
|
||||||
|
) -> str:
|
||||||
|
"""Fetch the first model from the server's /v1/models endpoint."""
|
||||||
|
models_url = f"{base_url}/v1/models"
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
try:
|
||||||
|
async with session.get(models_url, headers=headers) as response:
|
||||||
|
response.raise_for_status()
|
||||||
|
data = await response.json()
|
||||||
|
if "data" in data and len(data["data"]) > 0:
|
||||||
|
return data["data"][0]["id"]
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"No models found on the server at {base_url}. "
|
||||||
|
"Make sure the server is running and has models loaded."
|
||||||
|
)
|
||||||
|
except (aiohttp.ClientError, json.JSONDecodeError) as e:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Failed to fetch models from server at {models_url}. "
|
||||||
|
"Check that:\n"
|
||||||
|
"1. The server is running\n"
|
||||||
|
"2. The server URL is correct\n"
|
||||||
|
f"Error: {e}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
class TaskType(Enum):
|
class TaskType(Enum):
|
||||||
GENERATION = "generation"
|
GENERATION = "generation"
|
||||||
POOLING = "pooling"
|
POOLING = "pooling"
|
||||||
@ -1025,8 +1054,26 @@ def add_cli_args(parser: argparse.ArgumentParser):
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--model",
|
"--model",
|
||||||
type=str,
|
type=str,
|
||||||
required=True,
|
required=False,
|
||||||
help="Name of the model.",
|
default=None,
|
||||||
|
help="Name of the model. If not specified, will fetch the first model "
|
||||||
|
"from the server's /v1/models endpoint.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--input-len",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="General input length for datasets. Maps to dataset-specific "
|
||||||
|
"input length arguments (e.g., --random-input-len, --sonnet-input-len). "
|
||||||
|
"If not specified, uses dataset defaults.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output-len",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="General output length for datasets. Maps to dataset-specific "
|
||||||
|
"output length arguments (e.g., --random-output-len, --sonnet-output-len). "
|
||||||
|
"If not specified, uses dataset defaults.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--tokenizer",
|
"--tokenizer",
|
||||||
@ -1332,10 +1379,6 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
|
|||||||
raise ValueError("For exponential ramp-up, the start RPS cannot be 0.")
|
raise ValueError("For exponential ramp-up, the start RPS cannot be 0.")
|
||||||
|
|
||||||
label = args.label
|
label = args.label
|
||||||
model_id = args.model
|
|
||||||
model_name = args.served_model_name
|
|
||||||
tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
|
|
||||||
tokenizer_mode = args.tokenizer_mode
|
|
||||||
|
|
||||||
if args.base_url is not None:
|
if args.base_url is not None:
|
||||||
api_url = f"{args.base_url}{args.endpoint}"
|
api_url = f"{args.base_url}{args.endpoint}"
|
||||||
@ -1356,6 +1399,18 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
|
|||||||
else:
|
else:
|
||||||
raise ValueError("Invalid header format. Please use KEY=VALUE format.")
|
raise ValueError("Invalid header format. Please use KEY=VALUE format.")
|
||||||
|
|
||||||
|
# Fetch model from server if not specified
|
||||||
|
if args.model is None:
|
||||||
|
print("Model not specified, fetching first model from server...")
|
||||||
|
model_id = await get_first_model_from_server(base_url, headers)
|
||||||
|
print(f"Using model: {model_id}")
|
||||||
|
else:
|
||||||
|
model_id = args.model
|
||||||
|
|
||||||
|
model_name = args.served_model_name
|
||||||
|
tokenizer_id = args.tokenizer if args.tokenizer is not None else model_id
|
||||||
|
tokenizer_mode = args.tokenizer_mode
|
||||||
|
|
||||||
tokenizer = get_tokenizer(
|
tokenizer = get_tokenizer(
|
||||||
tokenizer_id,
|
tokenizer_id,
|
||||||
tokenizer_mode=tokenizer_mode,
|
tokenizer_mode=tokenizer_mode,
|
||||||
@ -1368,6 +1423,20 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
|
|||||||
"'--dataset-path' if required."
|
"'--dataset-path' if required."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Map general --input-len and --output-len to all dataset-specific arguments
|
||||||
|
if args.input_len is not None:
|
||||||
|
args.random_input_len = args.input_len
|
||||||
|
args.sonnet_input_len = args.input_len
|
||||||
|
|
||||||
|
if args.output_len is not None:
|
||||||
|
args.random_output_len = args.output_len
|
||||||
|
args.sonnet_output_len = args.output_len
|
||||||
|
args.sharegpt_output_len = args.output_len
|
||||||
|
args.custom_output_len = args.output_len
|
||||||
|
args.hf_output_len = args.output_len
|
||||||
|
args.spec_bench_output_len = args.output_len
|
||||||
|
args.prefix_repetition_output_len = args.output_len
|
||||||
|
|
||||||
# when using random datasets, default to ignoring EOS
|
# when using random datasets, default to ignoring EOS
|
||||||
# so generation runs to the requested length
|
# so generation runs to the requested length
|
||||||
if (
|
if (
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user