[Benchmark] Show E2EL by default for pooling models (#27014)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2025-10-16 20:47:09 +08:00 committed by GitHub
parent dcbb3f1871
commit 334535b6fb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -58,7 +58,7 @@ TERM_PLOTLIB_AVAILABLE = (importlib.util.find_spec("termplotlib") is not None) a
class TaskType(Enum):
GENERATION = "generation"
EMBEDDING = "embedding"
POOLING = "pooling"
@dataclass
@@ -1084,10 +1084,12 @@ def add_cli_args(parser: argparse.ArgumentParser):
parser.add_argument(
"--percentile-metrics",
type=str,
default="ttft,tpot,itl",
default=None,
help="Comma-separated list of selected metrics to report percentils. "
"This argument specifies the metrics to report percentiles. "
'Allowed metric names are "ttft", "tpot", "itl", "e2el". ',
'Allowed metric names are "ttft", "tpot", "itl", "e2el". '
'If not specified, defaults to "ttft,tpot,itl" for generative models '
'and "e2el" for pooling models.',
)
parser.add_argument(
"--metric-percentiles",
@@ -1310,7 +1312,11 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
goodput_config_dict = check_goodput_args(args)
backend = args.backend
task_type = TaskType.EMBEDDING if "embeddings" in backend else TaskType.GENERATION
task_type = (
TaskType.POOLING
if "embeddings" in backend or "rerank" in backend
else TaskType.GENERATION
)
# Collect the sampling parameters.
if task_type == TaskType.GENERATION:
@@ -1336,12 +1342,17 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
if "temperature" not in sampling_params:
sampling_params["temperature"] = 0.0 # Default to greedy decoding.
default_percentile_metrics = "ttft,tpot,itl"
else:
sampling_params = {}
default_percentile_metrics = "e2el"
extra_body = args.extra_body or {}
extra_body = {**sampling_params, **extra_body}
percentile_metrics: str = args.percentile_metrics or default_percentile_metrics
# Avoid GC processing "static" data - reduce pause times.
gc.collect()
gc.freeze()
@@ -1360,7 +1371,7 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
burstiness=args.burstiness,
disable_tqdm=args.disable_tqdm,
profile=args.profile,
selected_percentile_metrics=args.percentile_metrics.split(","),
selected_percentile_metrics=percentile_metrics.split(","),
selected_percentiles=[float(p) for p in args.metric_percentiles.split(",")],
ignore_eos=args.ignore_eos,
goodput_config_dict=goodput_config_dict,