diff --git a/vllm/benchmarks/serve.py b/vllm/benchmarks/serve.py index 3c85a1e8fdd9..79e39fb86b32 100644 --- a/vllm/benchmarks/serve.py +++ b/vllm/benchmarks/serve.py @@ -58,7 +58,7 @@ TERM_PLOTLIB_AVAILABLE = (importlib.util.find_spec("termplotlib") is not None) a class TaskType(Enum): GENERATION = "generation" - EMBEDDING = "embedding" + POOLING = "pooling" @dataclass @@ -1084,10 +1084,12 @@ def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument( "--percentile-metrics", type=str, - default="ttft,tpot,itl", + default=None, help="Comma-separated list of selected metrics to report percentils. " "This argument specifies the metrics to report percentiles. " - 'Allowed metric names are "ttft", "tpot", "itl", "e2el". ', + 'Allowed metric names are "ttft", "tpot", "itl", "e2el". ' + 'If not specified, defaults to "ttft,tpot,itl" for generative models ' + 'and "e2el" for pooling models.', ) parser.add_argument( "--metric-percentiles", @@ -1310,7 +1312,11 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]: goodput_config_dict = check_goodput_args(args) backend = args.backend - task_type = TaskType.EMBEDDING if "embeddings" in backend else TaskType.GENERATION + task_type = ( + TaskType.POOLING + if "embeddings" in backend or "rerank" in backend + else TaskType.GENERATION + ) # Collect the sampling parameters. if task_type == TaskType.GENERATION: @@ -1336,12 +1342,17 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]: if "temperature" not in sampling_params: sampling_params["temperature"] = 0.0 # Default to greedy decoding. + + default_percentile_metrics = "ttft,tpot,itl" else: sampling_params = {} + default_percentile_metrics = "e2el" extra_body = args.extra_body or {} extra_body = {**sampling_params, **extra_body} + percentile_metrics: str = args.percentile_metrics or default_percentile_metrics + # Avoid GC processing "static" data - reduce pause times. gc.collect() gc.freeze() @@ -1360,7 +1371,7 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]: burstiness=args.burstiness, disable_tqdm=args.disable_tqdm, profile=args.profile, - selected_percentile_metrics=args.percentile_metrics.split(","), + selected_percentile_metrics=percentile_metrics.split(","), selected_percentiles=[float(p) for p in args.metric_percentiles.split(",")], ignore_eos=args.ignore_eos, goodput_config_dict=goodput_config_dict,