Fix boolean nested params, add dict format support, and enhance plotting for vllm bench sweep (#29025)

Signed-off-by: Luka Govedič <luka.govedic@gmail.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: ProExpertProg <11367180+ProExpertProg@users.noreply.github.com>
Co-authored-by: Luka Govedič <luka.govedic@gmail.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
Copilot 2025-12-02 20:40:56 +00:00 committed by GitHub
parent a2b053dc85
commit 1c593e117d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 614 additions and 22 deletions

View File

@ -0,0 +1,257 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import tempfile
from pathlib import Path
import pytest
from vllm.benchmarks.sweep.param_sweep import ParameterSweep, ParameterSweepItem
class TestParameterSweepItem:
    """Test ParameterSweepItem functionality (record -> CLI argument generation)."""

    @pytest.mark.parametrize(
        "input_dict,expected",
        [
            (
                {"compilation_config.use_inductor_graph_partition": False},
                "--compilation-config.use_inductor_graph_partition=false",
            ),
            (
                {"compilation_config.use_inductor_graph_partition": True},
                "--compilation-config.use_inductor_graph_partition=true",
            ),
            (
                {"compilation_config.use_inductor": False},
                "--compilation-config.use_inductor=false",
            ),
            (
                {"compilation_config.use_inductor": True},
                "--compilation-config.use_inductor=true",
            ),
        ],
    )
    def test_nested_boolean_params(self, input_dict, expected):
        """Test that nested boolean params use =true/false syntax."""
        item = ParameterSweepItem.from_record(input_dict)
        cmd = item.apply_to_cmd(["vllm", "serve", "model"])
        assert expected in cmd

    @pytest.mark.parametrize(
        "input_dict,expected",
        [
            ({"enable_prefix_caching": False}, "--no-enable-prefix-caching"),
            ({"enable_prefix_caching": True}, "--enable-prefix-caching"),
            ({"disable_log_stats": False}, "--no-disable-log-stats"),
            ({"disable_log_stats": True}, "--disable-log-stats"),
        ],
    )
    def test_non_nested_boolean_params(self, input_dict, expected):
        """Test that non-nested boolean params use --no- prefix."""
        item = ParameterSweepItem.from_record(input_dict)
        cmd = item.apply_to_cmd(["vllm", "serve", "model"])
        assert expected in cmd

    @pytest.mark.parametrize(
        "compilation_config",
        [
            {"cudagraph_mode": "full", "mode": 2, "use_inductor_graph_partition": True},
            {
                "cudagraph_mode": "piecewise",
                "mode": 3,
                "use_inductor_graph_partition": False,
            },
        ],
    )
    def test_nested_dict_value(self, compilation_config):
        """Test that nested dict values are serialized as JSON."""
        item = ParameterSweepItem.from_record(
            {"compilation_config": compilation_config}
        )
        cmd = item.apply_to_cmd(["vllm", "serve", "model"])
        assert "--compilation-config" in cmd
        # The dict should be JSON serialized as the token after the key
        idx = cmd.index("--compilation-config")
        assert json.loads(cmd[idx + 1]) == compilation_config

    @pytest.mark.parametrize(
        "input_dict,expected_key,expected_value",
        [
            ({"model": "test-model"}, "--model", "test-model"),
            ({"max_tokens": 100}, "--max-tokens", "100"),
            ({"temperature": 0.7}, "--temperature", "0.7"),
        ],
    )
    def test_string_and_numeric_values(self, input_dict, expected_key, expected_value):
        """Test that string and numeric values are handled correctly."""
        item = ParameterSweepItem.from_record(input_dict)
        cmd = item.apply_to_cmd(["vllm", "serve"])
        assert expected_key in cmd
        assert expected_value in cmd

    @pytest.mark.parametrize(
        "input_dict,expected_key,key_idx_offset",
        [
            ({"max_tokens": 200}, "--max-tokens", 1),
            ({"enable_prefix_caching": False}, "--no-enable-prefix-caching", 0),
        ],
    )
    def test_replace_existing_parameter(self, input_dict, expected_key, key_idx_offset):
        """Test that existing parameters in cmd are replaced."""
        item = ParameterSweepItem.from_record(input_dict)
        # key_idx_offset == 1: the option carries a separate value token;
        # key_idx_offset == 0: the option is a bare boolean flag.
        if key_idx_offset == 1:
            # Key-value pair
            cmd = item.apply_to_cmd(["vllm", "serve", "--max-tokens", "100", "model"])
            assert expected_key in cmd
            idx = cmd.index(expected_key)
            assert cmd[idx + 1] == "200"
            assert "100" not in cmd
        else:
            # Boolean flag: the positive flag must be replaced, not duplicated
            cmd = item.apply_to_cmd(
                ["vllm", "serve", "--enable-prefix-caching", "model"]
            )
            assert expected_key in cmd
            assert "--enable-prefix-caching" not in cmd
class TestParameterSweep:
    """Test ParameterSweep construction, JSON loading, and name validation."""

    def test_from_records_list(self):
        """Test creating ParameterSweep from a list of records."""
        records = [
            {"max_tokens": 100, "temperature": 0.7},
            {"max_tokens": 200, "temperature": 0.9},
        ]
        sweep = ParameterSweep.from_records(records)
        assert len(sweep) == 2
        assert sweep[0]["max_tokens"] == 100
        assert sweep[1]["max_tokens"] == 200

    def test_read_from_dict(self):
        """Test creating ParameterSweep from a dict format."""
        data = {
            "experiment1": {"max_tokens": 100, "temperature": 0.7},
            "experiment2": {"max_tokens": 200, "temperature": 0.9},
        }
        sweep = ParameterSweep.read_from_dict(data)
        assert len(sweep) == 2
        # Check that items have the _benchmark_name field
        names = {item["_benchmark_name"] for item in sweep}
        assert names == {"experiment1", "experiment2"}
        # Check that parameters are preserved
        for item in sweep:
            if item["_benchmark_name"] == "experiment1":
                assert item["max_tokens"] == 100
                assert item["temperature"] == 0.7
            elif item["_benchmark_name"] == "experiment2":
                assert item["max_tokens"] == 200
                assert item["temperature"] == 0.9

    def test_read_json_list_format(self):
        """Test reading JSON file with list format."""
        records = [
            {"max_tokens": 100, "temperature": 0.7},
            {"max_tokens": 200, "temperature": 0.9},
        ]
        # delete=False so the closed file can be re-opened by read_json
        # (needed on platforms where an open NamedTemporaryFile cannot be
        # opened a second time); cleaned up manually in the finally block.
        with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
            json.dump(records, f)
            temp_path = Path(f.name)
        try:
            sweep = ParameterSweep.read_json(temp_path)
            assert len(sweep) == 2
            assert sweep[0]["max_tokens"] == 100
            assert sweep[1]["max_tokens"] == 200
        finally:
            temp_path.unlink()

    def test_read_json_dict_format(self):
        """Test reading JSON file with dict format."""
        data = {
            "experiment1": {"max_tokens": 100, "temperature": 0.7},
            "experiment2": {"max_tokens": 200, "temperature": 0.9},
        }
        with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
            json.dump(data, f)
            temp_path = Path(f.name)
        try:
            sweep = ParameterSweep.read_json(temp_path)
            assert len(sweep) == 2
            # Check that items have the _benchmark_name field
            names = {item["_benchmark_name"] for item in sweep}
            assert names == {"experiment1", "experiment2"}
        finally:
            temp_path.unlink()

    def test_unique_benchmark_names_validation(self):
        """Test that duplicate _benchmark_name values raise an error."""
        # Test with duplicate names in list format
        records = [
            {"_benchmark_name": "exp1", "max_tokens": 100},
            {"_benchmark_name": "exp1", "max_tokens": 200},
        ]
        with pytest.raises(ValueError, match="Duplicate _benchmark_name values"):
            ParameterSweep.from_records(records)

    def test_unique_benchmark_names_multiple_duplicates(self):
        """Test validation with multiple duplicate names."""
        records = [
            {"_benchmark_name": "exp1", "max_tokens": 100},
            {"_benchmark_name": "exp1", "max_tokens": 200},
            {"_benchmark_name": "exp2", "max_tokens": 300},
            {"_benchmark_name": "exp2", "max_tokens": 400},
        ]
        with pytest.raises(ValueError, match="Duplicate _benchmark_name values"):
            ParameterSweep.from_records(records)

    def test_no_benchmark_names_allowed(self):
        """Test that records without _benchmark_name are allowed."""
        records = [
            {"max_tokens": 100, "temperature": 0.7},
            {"max_tokens": 200, "temperature": 0.9},
        ]
        sweep = ParameterSweep.from_records(records)
        assert len(sweep) == 2

    def test_mixed_benchmark_names_allowed(self):
        """Test that mixing records with and without _benchmark_name is allowed."""
        records = [
            {"_benchmark_name": "exp1", "max_tokens": 100},
            {"max_tokens": 200, "temperature": 0.9},
        ]
        sweep = ParameterSweep.from_records(records)
        assert len(sweep) == 2
class TestParameterSweepItemKeyNormalization:
    """Checks how ParameterSweepItem turns record keys into CLI option names."""

    def test_underscore_to_hyphen_conversion(self):
        """Underscores in a plain (non-dotted) key become hyphens on the CLI."""
        sweep_item = ParameterSweepItem.from_record({"max_tokens": 100})
        args = sweep_item.apply_to_cmd(["vllm", "serve"])
        assert "--max-tokens" in args

    def test_nested_key_preserves_suffix(self):
        """For a dotted key, only the prefix is hyphenated; the suffix keeps underscores."""
        sweep_item = ParameterSweepItem.from_record(
            {"compilation_config.some_nested_param": "value"}
        )
        args = sweep_item.apply_to_cmd(["vllm", "serve"])
        # The prefix (compilation_config) is converted to hyphens while the
        # suffix (some_nested_param) is left untouched.
        matching = [arg for arg in args if "compilation-config.some_nested_param" in arg]
        assert matching

View File

@ -0,0 +1,171 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pandas as pd
import pytest
from vllm.benchmarks.sweep.plot import (
PlotEqualTo,
PlotFilterBase,
PlotFilters,
PlotGreaterThan,
PlotGreaterThanOrEqualTo,
PlotLessThan,
PlotLessThanOrEqualTo,
PlotNotEqualTo,
)
class TestPlotFilters:
    """Test PlotFilter functionality including 'inf' edge case."""

    def setup_method(self):
        """Create sample DataFrames for testing."""
        # DataFrame with numeric values
        self.df_numeric = pd.DataFrame(
            {
                "request_rate": [1.0, 5.0, 10.0, 50.0, 100.0],
                "value": [10, 20, 30, 40, 50],
            }
        )
        # DataFrame with float('inf') - note: string "inf" values are coerced
        # to float when loading data, so we only test with float('inf')
        self.df_inf_float = pd.DataFrame(
            {
                "request_rate": [1.0, 5.0, 10.0, float("inf"), float("inf")],
                "value": [10, 20, 30, 40, 50],
            }
        )

    @pytest.mark.parametrize(
        "target,expected_count",
        [
            ("5.0", 1),
            ("10.0", 1),
            ("1.0", 1),
        ],
    )
    def test_equal_to_numeric(self, target, expected_count):
        """Test PlotEqualTo with numeric values."""
        # Targets are passed as strings; the filter is expected to coerce them.
        filter_obj = PlotEqualTo("request_rate", target)
        result = filter_obj.apply(self.df_numeric)
        assert len(result) == expected_count

    def test_equal_to_inf_float(self):
        """Test PlotEqualTo with float('inf')."""
        filter_obj = PlotEqualTo("request_rate", "inf")
        result = filter_obj.apply(self.df_inf_float)
        # Should match both float('inf') entries because float('inf') == float('inf')
        assert len(result) == 2

    @pytest.mark.parametrize(
        "target,expected_count",
        [
            ("5.0", 4),  # All except 5.0
            ("1.0", 4),  # All except 1.0
        ],
    )
    def test_not_equal_to_numeric(self, target, expected_count):
        """Test PlotNotEqualTo with numeric values."""
        filter_obj = PlotNotEqualTo("request_rate", target)
        result = filter_obj.apply(self.df_numeric)
        assert len(result) == expected_count

    def test_not_equal_to_inf_float(self):
        """Test PlotNotEqualTo with float('inf')."""
        filter_obj = PlotNotEqualTo("request_rate", "inf")
        result = filter_obj.apply(self.df_inf_float)
        # Should exclude float('inf') entries
        assert len(result) == 3

    @pytest.mark.parametrize(
        "target,expected_count",
        [
            ("10.0", 2),  # 1.0, 5.0
            ("50.0", 3),  # 1.0, 5.0, 10.0
            ("5.0", 1),  # 1.0
        ],
    )
    def test_less_than(self, target, expected_count):
        """Test PlotLessThan with numeric values."""
        filter_obj = PlotLessThan("request_rate", target)
        result = filter_obj.apply(self.df_numeric)
        assert len(result) == expected_count

    @pytest.mark.parametrize(
        "target,expected_count",
        [
            ("10.0", 3),  # 1.0, 5.0, 10.0
            ("5.0", 2),  # 1.0, 5.0
        ],
    )
    def test_less_than_or_equal_to(self, target, expected_count):
        """Test PlotLessThanOrEqualTo with numeric values."""
        filter_obj = PlotLessThanOrEqualTo("request_rate", target)
        result = filter_obj.apply(self.df_numeric)
        assert len(result) == expected_count

    @pytest.mark.parametrize(
        "target,expected_count",
        [
            ("10.0", 2),  # 50.0, 100.0
            ("5.0", 3),  # 10.0, 50.0, 100.0
        ],
    )
    def test_greater_than(self, target, expected_count):
        """Test PlotGreaterThan with numeric values."""
        filter_obj = PlotGreaterThan("request_rate", target)
        result = filter_obj.apply(self.df_numeric)
        assert len(result) == expected_count

    @pytest.mark.parametrize(
        "target,expected_count",
        [
            ("10.0", 3),  # 10.0, 50.0, 100.0
            ("5.0", 4),  # 5.0, 10.0, 50.0, 100.0
        ],
    )
    def test_greater_than_or_equal_to(self, target, expected_count):
        """Test PlotGreaterThanOrEqualTo with numeric values."""
        filter_obj = PlotGreaterThanOrEqualTo("request_rate", target)
        result = filter_obj.apply(self.df_numeric)
        assert len(result) == expected_count

    @pytest.mark.parametrize(
        "filter_str,expected_var,expected_target,expected_type",
        [
            ("request_rate==5.0", "request_rate", "5.0", PlotEqualTo),
            ("request_rate!=10.0", "request_rate", "10.0", PlotNotEqualTo),
            ("request_rate<50.0", "request_rate", "50.0", PlotLessThan),
            ("request_rate<=50.0", "request_rate", "50.0", PlotLessThanOrEqualTo),
            ("request_rate>10.0", "request_rate", "10.0", PlotGreaterThan),
            ("request_rate>=10.0", "request_rate", "10.0", PlotGreaterThanOrEqualTo),
            ("request_rate==inf", "request_rate", "inf", PlotEqualTo),
            # Quotes around the target are stripped during parsing.
            ("request_rate!='inf'", "request_rate", "inf", PlotNotEqualTo),
        ],
    )
    def test_parse_str(self, filter_str, expected_var, expected_target, expected_type):
        """Test parsing filter strings."""
        filter_obj = PlotFilterBase.parse_str(filter_str)
        assert isinstance(filter_obj, expected_type)
        assert filter_obj.var == expected_var
        assert filter_obj.target == expected_target

    def test_parse_str_inf_edge_case(self):
        """Test parsing 'inf' string in filter."""
        # The target stays a string at parse time; coercion to float happens
        # in apply().
        filter_obj = PlotFilterBase.parse_str("request_rate==inf")
        assert isinstance(filter_obj, PlotEqualTo)
        assert filter_obj.var == "request_rate"
        assert filter_obj.target == "inf"

    def test_parse_multiple_filters(self):
        """Test parsing multiple comma-separated filters."""
        filters = PlotFilters.parse_str("request_rate>5.0,value<=40")
        assert len(filters) == 2
        assert isinstance(filters[0], PlotGreaterThan)
        assert isinstance(filters[1], PlotLessThanOrEqualTo)

    def test_parse_empty_filter(self):
        """Test parsing empty filter string."""
        filters = PlotFilters.parse_str("")
        assert len(filters) == 0

View File

@ -9,8 +9,26 @@ class ParameterSweep(list["ParameterSweepItem"]):
@classmethod
def read_json(cls, filepath: os.PathLike):
with open(filepath, "rb") as f:
records = json.load(f)
data = json.load(f)
# Support both list and dict formats
if isinstance(data, dict):
return cls.read_from_dict(data)
return cls.from_records(data)
@classmethod
def read_from_dict(cls, data: dict[str, dict[str, object]]):
    """
    Read parameter sweep from a dict format where keys are names.

    Each top-level key becomes the item's ``_benchmark_name``; the value
    dict supplies the item's parameters.

    Example:
        {
            "experiment1": {"max_tokens": 100, "temperature": 0.7},
            "experiment2": {"max_tokens": 200, "temperature": 0.9}
        }
    """
    records = [{"_benchmark_name": name, **params} for name, params in data.items()]
    return cls.from_records(records)
@classmethod
@ -21,6 +39,15 @@ class ParameterSweep(list["ParameterSweepItem"]):
f"but found type: {type(records)}"
)
# Validate that all _benchmark_name values are unique if provided
names = [r["_benchmark_name"] for r in records if "_benchmark_name" in r]
if names and len(names) != len(set(names)):
duplicates = [name for name in names if names.count(name) > 1]
raise ValueError(
f"Duplicate _benchmark_name values found: {set(duplicates)}. "
f"All _benchmark_name values must be unique."
)
return cls(ParameterSweepItem.from_record(record) for record in records)
@ -38,6 +65,18 @@ class ParameterSweepItem(dict[str, object]):
def __or__(self, other: dict[str, Any]):
return type(self)(super().__or__(other))
@property
def name(self) -> str:
    """
    Get the name for this parameter sweep item.

    Returns the '_benchmark_name' field if present, otherwise returns a text
    representation of all parameters (joined with "-" by ``as_text``).
    """
    if "_benchmark_name" in self:
        return self["_benchmark_name"]
    return self.as_text(sep="-")
# In JSON, we prefer "_"
def _iter_param_key_candidates(self, param_key: str):
# Inner config arguments are not converted by the CLI
@ -63,29 +102,57 @@ class ParameterSweepItem(dict[str, object]):
def has_param(self, param_key: str) -> bool:
return any(k in self for k in self._iter_param_key_candidates(param_key))
def _normalize_cmd_kv_pair(self, k: str, v: object) -> list[str]:
    """
    Normalize a key-value pair into command-line arguments.

    Returns a list containing either:
    - A single element for boolean flags (e.g., ['--flag'] or ['--flag=true'])
    - Two elements for key-value pairs (e.g., ['--key', 'value'])
    """
    if isinstance(v, bool):
        # For nested params (dotted keys such as "compilation_config.x"),
        # use the =true/false syntax, since --no- prefixing does not apply
        # to nested config fields.
        if "." in k:
            return [f"{self._normalize_cmd_key(k)}={'true' if v else 'false'}"]
        else:
            # Top-level booleans use the --flag / --no-flag convention.
            return [self._normalize_cmd_key(k if v else "no-" + k)]
    else:
        # Non-boolean values are emitted as a separate string token.
        return [self._normalize_cmd_key(k), str(v)]
def apply_to_cmd(self, cmd: list[str]) -> list[str]:
    """
    Return a copy of *cmd* with this item's parameters applied.

    Each parameter either replaces an existing occurrence of its CLI key
    in *cmd* (in place, preserving position) or is appended at the end.
    The input list is not mutated.
    """
    cmd = list(cmd)
    for k, v in self.items():
        # Skip the '_benchmark_name' field, not a parameter
        if k == "_benchmark_name":
            continue
        # Serialize dict values as JSON
        if isinstance(v, dict):
            v = json.dumps(v)
        # Try every CLI spelling of the key against the existing command.
        for k_candidate in self._iter_cmd_key_candidates(k):
            try:
                k_idx = cmd.index(k_candidate)
                # Replace existing parameter
                normalized = self._normalize_cmd_kv_pair(k, v)
                if len(normalized) == 1:
                    # Boolean flag: single token swapped in place
                    cmd[k_idx] = normalized[0]
                else:
                    # Key-value pair: replace both the key and its value token
                    cmd[k_idx] = normalized[0]
                    cmd[k_idx + 1] = normalized[1]
                break
            except ValueError:
                # This spelling is not present in cmd; try the next candidate.
                continue
        else:
            # No existing occurrence found: add new parameter at the end.
            cmd.extend(self._normalize_cmd_kv_pair(k, v))
    return cmd
def as_text(self, sep: str = ", ") -> str:
return sep.join(f"{k}={v}" for k, v in self.items())
return sep.join(f"{k}={v}" for k, v in self.items() if k != "_benchmark_name")

View File

@ -65,6 +65,18 @@ class PlotEqualTo(PlotFilterBase):
return df[df[self.var] == target]
@dataclass
class PlotNotEqualTo(PlotFilterBase):
    """Filter that keeps rows where ``df[var]`` differs from ``target``."""

    @override
    def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
        # Compare numerically when the target parses as a float (this also
        # covers "inf"/"-inf" strings); otherwise fall back to the raw string.
        try:
            target = float(self.target)
        except ValueError:
            target = self.target
        return df[df[self.var] != target]
@dataclass
class PlotLessThan(PlotFilterBase):
@override
@ -96,6 +108,7 @@ class PlotGreaterThanOrEqualTo(PlotFilterBase):
# NOTE: The ordering is important! Match longer op_keys first
PLOT_FILTERS: dict[str, type[PlotFilterBase]] = {
"==": PlotEqualTo,
"!=": PlotNotEqualTo,
"<=": PlotLessThanOrEqualTo,
">=": PlotGreaterThanOrEqualTo,
"<": PlotLessThan,
@ -167,6 +180,27 @@ def _json_load_bytes(path: Path) -> list[dict[str, object]]:
return json.load(f)
def _convert_inf_nan_strings(data: list[dict[str, object]]) -> list[dict[str, object]]:
"""
Convert string values "inf", "-inf", and "nan" to their float equivalents.
This handles the case where JSON serialization represents inf/nan as strings.
"""
converted_data = []
for record in data:
converted_record = {}
for key, value in record.items():
if isinstance(value, str):
if value in ["inf", "-inf", "nan"]:
converted_record[key] = float(value)
else:
converted_record[key] = value
else:
converted_record[key] = value
converted_data.append(converted_record)
return converted_data
def _get_metric(run_data: dict[str, object], metric_key: str):
try:
return run_data[metric_key]
@ -178,12 +212,15 @@ def _get_group(run_data: dict[str, object], group_keys: list[str]):
return tuple((k, str(_get_metric(run_data, k))) for k in group_keys)
def _get_fig_path(fig_dir: Path, group: tuple[tuple[str, str], ...]):
def _get_fig_path(fig_dir: Path, group: tuple[tuple[str, str], ...], fig_name: str):
parts = list[str]()
# Start with figure name (always provided, defaults to "FIGURE")
parts.append(fig_name)
# Always append group data if present
if group:
parts.extend(("FIGURE-", *(f"{k}={v}" for k, v in group)))
else:
parts.append("figure")
parts.extend(f"{k}={v}" for k, v in group)
return fig_dir / sanitize_filename("-".join(parts) + ".png")
@ -217,6 +254,10 @@ def _plot_fig(
scale_x: str | None,
scale_y: str | None,
dry_run: bool,
fig_name: str,
error_bars: bool,
fig_height: float,
fig_dpi: int,
):
fig_group, fig_data = fig_group_data
@ -230,7 +271,7 @@ def _plot_fig(
for _, row_data in row_groups
)
fig_path = _get_fig_path(fig_dir, fig_group)
fig_path = _get_fig_path(fig_dir, fig_group, fig_name)
print("[BEGIN FIGURE]")
print(f"Group: {dict(fig_group)}")
@ -241,6 +282,8 @@ def _plot_fig(
print("[END FIGURE]")
return
# Convert string "inf", "-inf", and "nan" to their float equivalents
fig_data = _convert_inf_nan_strings(fig_data)
df = pd.DataFrame.from_records(fig_data)
if var_x not in df.columns:
@ -275,6 +318,10 @@ def _plot_fig(
df = filter_by.apply(df)
df = bin_by.apply(df)
# Sort by curve_by columns alphabetically for consistent legend ordering
if curve_by:
df = df.sort_values(by=curve_by)
df["row_group"] = (
pd.concat(
[k + "=" + df[k].astype(str) for k in row_by],
@ -293,7 +340,7 @@ def _plot_fig(
else "(All)"
)
g = sns.FacetGrid(df, row="row_group", col="col_group")
g = sns.FacetGrid(df, row="row_group", col="col_group", height=fig_height)
if row_by and col_by:
g.set_titles("{row_name}\n{col_name}")
@ -320,6 +367,7 @@ def _plot_fig(
style=style,
size=size,
markers=True,
errorbar="sd" if error_bars else None,
)
g.add_legend(title=hue)
@ -339,11 +387,12 @@ def _plot_fig(
y=var_y,
hue="curve_group",
markers=True,
errorbar="sd" if error_bars else None,
)
g.add_legend()
g.savefig(fig_path)
g.savefig(fig_path, dpi=fig_dpi)
plt.close(g.figure)
print("[END FIGURE]")
@ -364,6 +413,10 @@ def plot(
scale_x: str | None,
scale_y: str | None,
dry_run: bool,
fig_name: str = "FIGURE",
error_bars: bool = True,
fig_height: float = 6.4,
fig_dpi: int = 300,
):
all_data = [
run_data
@ -398,6 +451,10 @@ def plot(
scale_x=scale_x,
scale_y=scale_y,
dry_run=dry_run,
fig_name=fig_name,
error_bars=error_bars,
fig_height=fig_height,
fig_dpi=fig_dpi,
),
fig_groups,
)
@ -419,6 +476,10 @@ class SweepPlotArgs:
scale_x: str | None
scale_y: str | None
dry_run: bool
fig_name: str = "FIGURE"
error_bars: bool = True
fig_height: float = 6.4
fig_dpi: int = 300
parser_name: ClassVar[str] = "plot"
parser_help: ClassVar[str] = "Plot performance curves from parameter sweep results."
@ -448,6 +509,10 @@ class SweepPlotArgs:
scale_x=args.scale_x,
scale_y=args.scale_y,
dry_run=args.dry_run,
fig_name=args.fig_name,
error_bars=not args.no_error_bars,
fig_height=args.fig_height,
fig_dpi=args.fig_dpi,
)
@classmethod
@ -541,6 +606,32 @@ class SweepPlotArgs:
"Currently only accepts string values such as 'log' and 'sqrt'. "
"See also: https://seaborn.pydata.org/generated/seaborn.objects.Plot.scale.html",
)
parser.add_argument(
"--fig-name",
type=str,
default="FIGURE",
help="Name prefix for the output figure file. "
"Group data is always appended when present. "
"Default: 'FIGURE'. Example: --fig-name my_performance_plot",
)
parser.add_argument(
"--no-error-bars",
action="store_true",
help="If set, disables error bars on the plot. "
"By default, error bars are shown.",
)
parser.add_argument(
"--fig-height",
type=float,
default=6.4,
help="Height of each subplot in inches. Default: 6.4",
)
parser.add_argument(
"--fig-dpi",
type=int,
default=300,
help="Resolution of the output figure in dots per inch. Default: 300",
)
parser.add_argument(
"--dry-run",
action="store_true",
@ -566,6 +657,10 @@ def run_main(args: SweepPlotArgs):
scale_x=args.scale_x,
scale_y=args.scale_y,
dry_run=args.dry_run,
fig_name=args.fig_name,
error_bars=args.error_bars,
fig_height=args.fig_height,
fig_dpi=args.fig_dpi,
)

View File

@ -138,9 +138,9 @@ def _get_comb_base_path(
):
parts = list[str]()
if serve_comb:
parts.extend(("SERVE-", serve_comb.as_text(sep="-")))
parts.extend(("SERVE-", serve_comb.name))
if bench_comb:
parts.extend(("BENCH-", bench_comb.as_text(sep="-")))
parts.extend(("BENCH-", bench_comb.name))
return output_dir / sanitize_filename("-".join(parts))
@ -345,8 +345,9 @@ class SweepServeArgs:
"--serve-params",
type=str,
default=None,
help="Path to JSON file containing a list of parameter combinations "
"for the `vllm serve` command. "
help="Path to JSON file containing parameter combinations "
"for the `vllm serve` command. Can be either a list of dicts or a dict "
"where keys are benchmark names. "
"If both `serve_params` and `bench_params` are given, "
"this script will iterate over their Cartesian product.",
)
@ -354,8 +355,9 @@ class SweepServeArgs:
"--bench-params",
type=str,
default=None,
help="Path to JSON file containing a list of parameter combinations "
"for the `vllm bench serve` command. "
help="Path to JSON file containing parameter combinations "
"for the `vllm bench serve` command. Can be either a list of dicts or "
"a dict where keys are benchmark names. "
"If both `serve_params` and `bench_params` are given, "
"this script will iterate over their Cartesian product.",
)