mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-04 02:37:03 +08:00
Fix boolean nested params, add dict format support, and enhance plotting for vllm bench sweep (#29025)
Signed-off-by: Luka Govedič <luka.govedic@gmail.com> Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: ProExpertProg <11367180+ProExpertProg@users.noreply.github.com> Co-authored-by: Luka Govedič <luka.govedic@gmail.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
parent
a2b053dc85
commit
1c593e117d
257
tests/benchmarks/test_param_sweep.py
Normal file
257
tests/benchmarks/test_param_sweep.py
Normal file
@ -0,0 +1,257 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import json
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.benchmarks.sweep.param_sweep import ParameterSweep, ParameterSweepItem
|
||||
|
||||
|
||||
class TestParameterSweepItem:
    """Tests for how ParameterSweepItem entries are applied to a command line."""

    @pytest.mark.parametrize(
        "input_dict,expected",
        [
            (
                {"compilation_config.use_inductor_graph_partition": False},
                "--compilation-config.use_inductor_graph_partition=false",
            ),
            (
                {"compilation_config.use_inductor_graph_partition": True},
                "--compilation-config.use_inductor_graph_partition=true",
            ),
            (
                {"compilation_config.use_inductor": False},
                "--compilation-config.use_inductor=false",
            ),
            (
                {"compilation_config.use_inductor": True},
                "--compilation-config.use_inductor=true",
            ),
        ],
    )
    def test_nested_boolean_params(self, input_dict, expected):
        """Dotted (nested) boolean keys are rendered as ``--key=true/false``."""
        item = ParameterSweepItem.from_record(input_dict)
        cmd = item.apply_to_cmd(["vllm", "serve", "model"])
        assert expected in cmd

    @pytest.mark.parametrize(
        "input_dict,expected",
        [
            ({"enable_prefix_caching": False}, "--no-enable-prefix-caching"),
            ({"enable_prefix_caching": True}, "--enable-prefix-caching"),
            ({"disable_log_stats": False}, "--no-disable-log-stats"),
            ({"disable_log_stats": True}, "--disable-log-stats"),
        ],
    )
    def test_non_nested_boolean_params(self, input_dict, expected):
        """Top-level boolean keys are rendered as ``--flag`` / ``--no-flag``."""
        item = ParameterSweepItem.from_record(input_dict)
        cmd = item.apply_to_cmd(["vllm", "serve", "model"])
        assert expected in cmd

    @pytest.mark.parametrize(
        "compilation_config",
        [
            {"cudagraph_mode": "full", "mode": 2, "use_inductor_graph_partition": True},
            {
                "cudagraph_mode": "piecewise",
                "mode": 3,
                "use_inductor_graph_partition": False,
            },
        ],
    )
    def test_nested_dict_value(self, compilation_config):
        """A dict value is passed as a JSON string right after its flag."""
        item = ParameterSweepItem.from_record(
            {"compilation_config": compilation_config}
        )
        cmd = item.apply_to_cmd(["vllm", "serve", "model"])
        assert "--compilation-config" in cmd

        # The token following the flag must decode back to the original dict.
        idx = cmd.index("--compilation-config")
        assert json.loads(cmd[idx + 1]) == compilation_config

    @pytest.mark.parametrize(
        "input_dict,expected_key,expected_value",
        [
            ({"model": "test-model"}, "--model", "test-model"),
            ({"max_tokens": 100}, "--max-tokens", "100"),
            ({"temperature": 0.7}, "--temperature", "0.7"),
        ],
    )
    def test_string_and_numeric_values(self, input_dict, expected_key, expected_value):
        """String and numeric values are stringified and placed after their flag."""
        item = ParameterSweepItem.from_record(input_dict)
        cmd = item.apply_to_cmd(["vllm", "serve"])
        assert expected_key in cmd

        # Check adjacency, not mere membership: the value must belong to the key.
        idx = cmd.index(expected_key)
        assert cmd[idx + 1] == expected_value

    @pytest.mark.parametrize(
        "input_dict,expected_key,key_idx_offset",
        [
            ({"max_tokens": 200}, "--max-tokens", 1),
            ({"enable_prefix_caching": False}, "--no-enable-prefix-caching", 0),
        ],
    )
    def test_replace_existing_parameter(self, input_dict, expected_key, key_idx_offset):
        """An option already present in the command is replaced, not duplicated.

        ``key_idx_offset`` selects the scenario: 1 for a key/value option whose
        value sits one token after the key, 0 for a bare boolean flag.
        """
        item = ParameterSweepItem.from_record(input_dict)

        if key_idx_offset == 1:
            # Key-value pair: the old value must be overwritten in place.
            cmd = item.apply_to_cmd(["vllm", "serve", "--max-tokens", "100", "model"])
            assert cmd.count(expected_key) == 1
            idx = cmd.index(expected_key)
            assert cmd[idx + 1] == "200"
            assert "100" not in cmd
        else:
            # Boolean flag: the positive flag must be swapped for the negated one.
            cmd = item.apply_to_cmd(
                ["vllm", "serve", "--enable-prefix-caching", "model"]
            )
            assert cmd.count(expected_key) == 1
            assert "--enable-prefix-caching" not in cmd
|
||||
|
||||
|
||||
class TestParameterSweep:
    """Tests for building a ParameterSweep from records, dicts and JSON files."""

    def test_from_records_list(self):
        """A list of records yields one sweep item per record, in order."""
        records = [
            {"max_tokens": 100, "temperature": 0.7},
            {"max_tokens": 200, "temperature": 0.9},
        ]
        sweep = ParameterSweep.from_records(records)
        assert len(sweep) == 2
        assert sweep[0]["max_tokens"] == 100
        assert sweep[1]["max_tokens"] == 200

    def test_read_from_dict(self):
        """The dict format stores each key as the item's ``_benchmark_name``."""
        data = {
            "experiment1": {"max_tokens": 100, "temperature": 0.7},
            "experiment2": {"max_tokens": 200, "temperature": 0.9},
        }
        sweep = ParameterSweep.read_from_dict(data)
        assert len(sweep) == 2

        # Index by name so an unexpected or missing name fails loudly instead
        # of being skipped by a conditional.
        by_name = {item["_benchmark_name"]: item for item in sweep}
        assert set(by_name) == {"experiment1", "experiment2"}

        # Every parameter of every experiment must be preserved unchanged.
        for name, params in data.items():
            for key, value in params.items():
                assert by_name[name][key] == value

    def test_read_json_list_format(self, tmp_path):
        """A JSON file holding a list is read as a list of records."""
        records = [
            {"max_tokens": 100, "temperature": 0.7},
            {"max_tokens": 200, "temperature": 0.9},
        ]

        # tmp_path is removed by pytest, so no manual cleanup is required.
        json_path = tmp_path / "params.json"
        json_path.write_text(json.dumps(records), encoding="utf-8")

        sweep = ParameterSweep.read_json(json_path)
        assert len(sweep) == 2
        assert sweep[0]["max_tokens"] == 100
        assert sweep[1]["max_tokens"] == 200

    def test_read_json_dict_format(self, tmp_path):
        """A JSON file holding a dict is read via the named-experiment format."""
        data = {
            "experiment1": {"max_tokens": 100, "temperature": 0.7},
            "experiment2": {"max_tokens": 200, "temperature": 0.9},
        }

        json_path = tmp_path / "params.json"
        json_path.write_text(json.dumps(data), encoding="utf-8")

        sweep = ParameterSweep.read_json(json_path)
        assert len(sweep) == 2

        # Check that items have the _benchmark_name field
        names = {item["_benchmark_name"] for item in sweep}
        assert names == {"experiment1", "experiment2"}

    def test_unique_benchmark_names_validation(self):
        """Two records sharing a ``_benchmark_name`` raise ValueError."""
        records = [
            {"_benchmark_name": "exp1", "max_tokens": 100},
            {"_benchmark_name": "exp1", "max_tokens": 200},
        ]

        with pytest.raises(ValueError, match="Duplicate _benchmark_name values"):
            ParameterSweep.from_records(records)

    def test_unique_benchmark_names_multiple_duplicates(self):
        """Several distinct duplicated names still raise ValueError."""
        records = [
            {"_benchmark_name": "exp1", "max_tokens": 100},
            {"_benchmark_name": "exp1", "max_tokens": 200},
            {"_benchmark_name": "exp2", "max_tokens": 300},
            {"_benchmark_name": "exp2", "max_tokens": 400},
        ]

        with pytest.raises(ValueError, match="Duplicate _benchmark_name values"):
            ParameterSweep.from_records(records)

    def test_no_benchmark_names_allowed(self):
        """Records without any ``_benchmark_name`` are accepted."""
        records = [
            {"max_tokens": 100, "temperature": 0.7},
            {"max_tokens": 200, "temperature": 0.9},
        ]
        sweep = ParameterSweep.from_records(records)
        assert len(sweep) == 2

    def test_mixed_benchmark_names_allowed(self):
        """Named and unnamed records may be mixed in one sweep."""
        records = [
            {"_benchmark_name": "exp1", "max_tokens": 100},
            {"max_tokens": 200, "temperature": 0.9},
        ]
        sweep = ParameterSweep.from_records(records)
        assert len(sweep) == 2
|
||||
|
||||
|
||||
class TestParameterSweepItemKeyNormalization:
    """Tests for how record keys are turned into command-line option names."""

    def test_underscore_to_hyphen_conversion(self):
        """Underscores in a top-level key become hyphens in the CLI flag."""
        item = ParameterSweepItem.from_record({"max_tokens": 100})
        cmd = item.apply_to_cmd(["vllm", "serve"])
        assert "--max-tokens" in cmd

    def test_nested_key_preserves_suffix(self):
        """Only the part before the dot is hyphenated; the suffix keeps underscores."""
        item = ParameterSweepItem.from_record(
            {"compilation_config.some_nested_param": "value"}
        )
        cmd = item.apply_to_cmd(["vllm", "serve"])

        # The prefix (compilation_config) gets converted to hyphens,
        # but the suffix (some_nested_param) is preserved.
        assert any("compilation-config.some_nested_param" in arg for arg in cmd)
|
||||
171
tests/benchmarks/test_plot_filters.py
Normal file
171
tests/benchmarks/test_plot_filters.py
Normal file
@ -0,0 +1,171 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pandas as pd
|
||||
import pytest
|
||||
|
||||
from vllm.benchmarks.sweep.plot import (
|
||||
PlotEqualTo,
|
||||
PlotFilterBase,
|
||||
PlotFilters,
|
||||
PlotGreaterThan,
|
||||
PlotGreaterThanOrEqualTo,
|
||||
PlotLessThan,
|
||||
PlotLessThanOrEqualTo,
|
||||
PlotNotEqualTo,
|
||||
)
|
||||
|
||||
|
||||
class TestPlotFilters:
    """Tests for the plot filter classes, including the 'inf' edge case."""

    def setup_method(self):
        """Create fresh sample DataFrames before every test."""
        # DataFrame with finite numeric values only.
        self.df_numeric = pd.DataFrame(
            {
                "request_rate": [1.0, 5.0, 10.0, 50.0, 100.0],
                "value": [10, 20, 30, 40, 50],
            }
        )

        # DataFrame with float('inf') - note: string "inf" values are coerced
        # to float when loading data, so we only test with float('inf').
        self.df_inf_float = pd.DataFrame(
            {
                "request_rate": [1.0, 5.0, 10.0, float("inf"), float("inf")],
                "value": [10, 20, 30, 40, 50],
            }
        )

    @pytest.mark.parametrize(
        "target,expected_count",
        [
            ("5.0", 1),
            ("10.0", 1),
            ("1.0", 1),
        ],
    )
    def test_equal_to_numeric(self, target, expected_count):
        """PlotEqualTo keeps only the rows equal to a numeric target."""
        filter_obj = PlotEqualTo("request_rate", target)
        result = filter_obj.apply(self.df_numeric)
        assert len(result) == expected_count

    def test_equal_to_inf_float(self):
        """PlotEqualTo with target 'inf' matches rows holding float('inf')."""
        filter_obj = PlotEqualTo("request_rate", "inf")
        result = filter_obj.apply(self.df_inf_float)

        # Both float('inf') rows match because float('inf') == float('inf').
        assert len(result) == 2
        assert result["value"].tolist() == [40, 50]

    @pytest.mark.parametrize(
        "target,expected_count",
        [
            ("5.0", 4),  # All except 5.0
            ("1.0", 4),  # All except 1.0
        ],
    )
    def test_not_equal_to_numeric(self, target, expected_count):
        """PlotNotEqualTo drops only the rows equal to a numeric target."""
        filter_obj = PlotNotEqualTo("request_rate", target)
        result = filter_obj.apply(self.df_numeric)
        assert len(result) == expected_count

    def test_not_equal_to_inf_float(self):
        """PlotNotEqualTo with target 'inf' drops rows holding float('inf')."""
        filter_obj = PlotNotEqualTo("request_rate", "inf")
        result = filter_obj.apply(self.df_inf_float)

        # Only the three finite rows remain.
        assert len(result) == 3
        assert result["value"].tolist() == [10, 20, 30]

    @pytest.mark.parametrize(
        "target,expected_count",
        [
            ("10.0", 2),  # 1.0, 5.0
            ("50.0", 3),  # 1.0, 5.0, 10.0
            ("5.0", 1),  # 1.0
        ],
    )
    def test_less_than(self, target, expected_count):
        """PlotLessThan keeps rows strictly below the target."""
        filter_obj = PlotLessThan("request_rate", target)
        result = filter_obj.apply(self.df_numeric)
        assert len(result) == expected_count

    @pytest.mark.parametrize(
        "target,expected_count",
        [
            ("10.0", 3),  # 1.0, 5.0, 10.0
            ("5.0", 2),  # 1.0, 5.0
        ],
    )
    def test_less_than_or_equal_to(self, target, expected_count):
        """PlotLessThanOrEqualTo keeps rows at or below the target."""
        filter_obj = PlotLessThanOrEqualTo("request_rate", target)
        result = filter_obj.apply(self.df_numeric)
        assert len(result) == expected_count

    @pytest.mark.parametrize(
        "target,expected_count",
        [
            ("10.0", 2),  # 50.0, 100.0
            ("5.0", 3),  # 10.0, 50.0, 100.0
        ],
    )
    def test_greater_than(self, target, expected_count):
        """PlotGreaterThan keeps rows strictly above the target."""
        filter_obj = PlotGreaterThan("request_rate", target)
        result = filter_obj.apply(self.df_numeric)
        assert len(result) == expected_count

    @pytest.mark.parametrize(
        "target,expected_count",
        [
            ("10.0", 3),  # 10.0, 50.0, 100.0
            ("5.0", 4),  # 5.0, 10.0, 50.0, 100.0
        ],
    )
    def test_greater_than_or_equal_to(self, target, expected_count):
        """PlotGreaterThanOrEqualTo keeps rows at or above the target."""
        filter_obj = PlotGreaterThanOrEqualTo("request_rate", target)
        result = filter_obj.apply(self.df_numeric)
        assert len(result) == expected_count

    @pytest.mark.parametrize(
        "filter_str,expected_var,expected_target,expected_type",
        [
            ("request_rate==5.0", "request_rate", "5.0", PlotEqualTo),
            ("request_rate!=10.0", "request_rate", "10.0", PlotNotEqualTo),
            ("request_rate<50.0", "request_rate", "50.0", PlotLessThan),
            ("request_rate<=50.0", "request_rate", "50.0", PlotLessThanOrEqualTo),
            ("request_rate>10.0", "request_rate", "10.0", PlotGreaterThan),
            ("request_rate>=10.0", "request_rate", "10.0", PlotGreaterThanOrEqualTo),
            ("request_rate==inf", "request_rate", "inf", PlotEqualTo),
            ("request_rate!='inf'", "request_rate", "inf", PlotNotEqualTo),
        ],
    )
    def test_parse_str(self, filter_str, expected_var, expected_target, expected_type):
        """Each operator string parses to its filter class, variable and target."""
        filter_obj = PlotFilterBase.parse_str(filter_str)
        assert isinstance(filter_obj, expected_type)
        assert filter_obj.var == expected_var
        assert filter_obj.target == expected_target

    def test_parse_str_inf_edge_case(self):
        """The target 'inf' is kept as the string 'inf' by the parser."""
        filter_obj = PlotFilterBase.parse_str("request_rate==inf")
        assert isinstance(filter_obj, PlotEqualTo)
        assert filter_obj.var == "request_rate"
        assert filter_obj.target == "inf"

    def test_parse_multiple_filters(self):
        """A comma-separated string parses into one filter per clause, in order."""
        filters = PlotFilters.parse_str("request_rate>5.0,value<=40")
        assert len(filters) == 2
        assert isinstance(filters[0], PlotGreaterThan)
        assert isinstance(filters[1], PlotLessThanOrEqualTo)

    def test_parse_empty_filter(self):
        """An empty string parses into an empty filter collection."""
        filters = PlotFilters.parse_str("")
        assert len(filters) == 0
|
||||
@ -9,8 +9,26 @@ class ParameterSweep(list["ParameterSweepItem"]):
|
||||
@classmethod
def read_json(cls, filepath: os.PathLike):
    """
    Load a parameter sweep from a JSON file.

    The file may hold either a dict (handled by ``read_from_dict``) or any
    other JSON value, which is passed unchanged to ``from_records``.
    """
    with open(filepath, "rb") as f:
        # Read the stream exactly once; a second json.load on the same
        # handle would see an exhausted file and fail.
        data = json.load(f)

    # Support both list and dict formats
    if isinstance(data, dict):
        return cls.read_from_dict(data)

    return cls.from_records(data)
|
||||
|
||||
@classmethod
def read_from_dict(cls, data: dict[str, dict[str, object]]):
    """
    Read parameter sweep from a dict format where keys are names.

    Each key is stored in its record under ``_benchmark_name`` and the
    resulting records are handed to ``from_records``. Because the
    parameters are unpacked after the name, a ``_benchmark_name`` entry
    inside the parameters takes precedence over the dict key.

    Example:
    {
        "experiment1": {"max_tokens": 100, "temperature": 0.7},
        "experiment2": {"max_tokens": 200, "temperature": 0.9}
    }
    """
    records = [{"_benchmark_name": name, **params} for name, params in data.items()]
    return cls.from_records(records)
|
||||
|
||||
@classmethod
|
||||
@ -21,6 +39,15 @@ class ParameterSweep(list["ParameterSweepItem"]):
|
||||
f"but found type: {type(records)}"
|
||||
)
|
||||
|
||||
# Validate that all _benchmark_name values are unique if provided
|
||||
names = [r["_benchmark_name"] for r in records if "_benchmark_name" in r]
|
||||
if names and len(names) != len(set(names)):
|
||||
duplicates = [name for name in names if names.count(name) > 1]
|
||||
raise ValueError(
|
||||
f"Duplicate _benchmark_name values found: {set(duplicates)}. "
|
||||
f"All _benchmark_name values must be unique."
|
||||
)
|
||||
|
||||
return cls(ParameterSweepItem.from_record(record) for record in records)
|
||||
|
||||
|
||||
@ -38,6 +65,18 @@ class ParameterSweepItem(dict[str, object]):
|
||||
def __or__(self, other: dict[str, Any]):
|
||||
return type(self)(super().__or__(other))
|
||||
|
||||
@property
def name(self) -> str:
    """
    Get the name for this parameter sweep item.

    Returns the '_benchmark_name' field if present, otherwise returns a text
    representation of all parameters joined with "-".
    """
    if "_benchmark_name" in self:
        # Values are arbitrary objects (e.g. a number read from JSON);
        # coerce so the declared ``str`` return type always holds.
        return str(self["_benchmark_name"])
    return self.as_text(sep="-")
|
||||
|
||||
# In JSON, we prefer "_"
|
||||
def _iter_param_key_candidates(self, param_key: str):
|
||||
# Inner config arguments are not converted by the CLI
|
||||
@ -63,29 +102,57 @@ class ParameterSweepItem(dict[str, object]):
|
||||
def has_param(self, param_key: str) -> bool:
|
||||
return any(k in self for k in self._iter_param_key_candidates(param_key))
|
||||
|
||||
def _normalize_cmd_kv_pair(self, k: str, v: object) -> list[str]:
    """
    Normalize a key-value pair into command-line arguments.

    Returns a list containing either:
    - A single element for boolean flags (e.g., ['--flag'] or ['--flag=true'])
    - Two elements for key-value pairs (e.g., ['--key', 'value'])
    """
    if not isinstance(v, bool):
        return [self._normalize_cmd_key(k), str(v)]

    # For nested params (containing "."), use =true/false syntax
    if "." in k:
        literal = "true" if v else "false"
        return [f"{self._normalize_cmd_key(k)}={literal}"]

    # Top-level booleans are expressed by the flag itself or its "no-" form.
    return [self._normalize_cmd_key(k if v else "no-" + k)]
|
||||
|
||||
def apply_to_cmd(self, cmd: list[str]) -> list[str]:
    """
    Return a copy of ``cmd`` with this item's parameters applied.

    For each parameter, the first candidate option name found in the command
    is replaced in place (together with the following token for key-value
    options); parameters not found are appended. The ``_benchmark_name``
    entry is metadata and is never emitted. The input list is not modified.
    """
    cmd = list(cmd)

    for k, v in self.items():
        # Skip the '_benchmark_name' field, not a parameter
        if k == "_benchmark_name":
            continue

        # Serialize dict values as JSON
        if isinstance(v, dict):
            v = json.dumps(v)

        normalized = self._normalize_cmd_kv_pair(k, v)

        for k_candidate in self._iter_cmd_key_candidates(k):
            # Keep the try body minimal: only the lookup signals "not present".
            try:
                k_idx = cmd.index(k_candidate)
            except ValueError:
                continue

            # Replace existing parameter. The slice covers the key alone for
            # a single-token flag, or key + value for a pair; if the key is
            # the last token the slice is simply shorter and the value is
            # inserted instead of raising IndexError.
            cmd[k_idx : k_idx + len(normalized)] = normalized
            break
        else:
            # Add new parameter
            cmd.extend(normalized)

    return cmd
|
||||
|
||||
def as_text(self, sep: str = ", ") -> str:
    """
    Render the parameters as ``key=value`` pairs joined by ``sep``.

    The ``_benchmark_name`` entry is metadata rather than a parameter and is
    left out of the text.
    """
    return sep.join(f"{k}={v}" for k, v in self.items() if k != "_benchmark_name")
|
||||
|
||||
@ -65,6 +65,18 @@ class PlotEqualTo(PlotFilterBase):
|
||||
return df[df[self.var] == target]
|
||||
|
||||
|
||||
@dataclass
class PlotNotEqualTo(PlotFilterBase):
    """Filter that keeps the rows whose ``var`` column differs from ``target``."""

    @override
    def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
        """Return the rows of ``df`` where column ``self.var`` != the target."""
        # Compare numerically when the target parses as a float (this also
        # covers "inf"); otherwise compare against the raw target value.
        try:
            target = float(self.target)
        except ValueError:
            target = self.target

        return df[df[self.var] != target]
|
||||
|
||||
|
||||
@dataclass
|
||||
class PlotLessThan(PlotFilterBase):
|
||||
@override
|
||||
@ -96,6 +108,7 @@ class PlotGreaterThanOrEqualTo(PlotFilterBase):
|
||||
# NOTE: The ordering is important! Match longer op_keys first
|
||||
PLOT_FILTERS: dict[str, type[PlotFilterBase]] = {
|
||||
"==": PlotEqualTo,
|
||||
"!=": PlotNotEqualTo,
|
||||
"<=": PlotLessThanOrEqualTo,
|
||||
">=": PlotGreaterThanOrEqualTo,
|
||||
"<": PlotLessThan,
|
||||
@ -167,6 +180,27 @@ def _json_load_bytes(path: Path) -> list[dict[str, object]]:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def _convert_inf_nan_strings(data: list[dict[str, object]]) -> list[dict[str, object]]:
    """
    Convert string values "inf", "-inf", and "nan" to their float equivalents.

    This handles the case where JSON serialization represents inf/nan as strings.
    Only exact matches of those three strings are converted; every other value
    is carried over unchanged. New record dicts are built, so the input records
    are not modified.
    """
    special_strings = ("inf", "-inf", "nan")

    return [
        {
            key: (
                float(value)
                if isinstance(value, str) and value in special_strings
                else value
            )
            for key, value in record.items()
        }
        for record in data
    ]
|
||||
|
||||
|
||||
def _get_metric(run_data: dict[str, object], metric_key: str):
|
||||
try:
|
||||
return run_data[metric_key]
|
||||
@ -178,12 +212,15 @@ def _get_group(run_data: dict[str, object], group_keys: list[str]):
|
||||
return tuple((k, str(_get_metric(run_data, k))) for k in group_keys)
|
||||
|
||||
|
||||
def _get_fig_path(fig_dir: Path, group: tuple[tuple[str, str], ...], fig_name: str):
    """
    Build the output path of a figure inside ``fig_dir``.

    The file name is ``fig_name`` followed by one ``key=value`` part per group
    entry, joined with "-", given a ".png" suffix and passed through
    ``sanitize_filename``.
    """
    # Start with figure name (always provided, defaults to "FIGURE")
    parts = [fig_name]

    # Always append group data if present
    if group:
        parts.extend(f"{k}={v}" for k, v in group)

    return fig_dir / sanitize_filename("-".join(parts) + ".png")
|
||||
|
||||
@ -217,6 +254,10 @@ def _plot_fig(
|
||||
scale_x: str | None,
|
||||
scale_y: str | None,
|
||||
dry_run: bool,
|
||||
fig_name: str,
|
||||
error_bars: bool,
|
||||
fig_height: float,
|
||||
fig_dpi: int,
|
||||
):
|
||||
fig_group, fig_data = fig_group_data
|
||||
|
||||
@ -230,7 +271,7 @@ def _plot_fig(
|
||||
for _, row_data in row_groups
|
||||
)
|
||||
|
||||
fig_path = _get_fig_path(fig_dir, fig_group)
|
||||
fig_path = _get_fig_path(fig_dir, fig_group, fig_name)
|
||||
|
||||
print("[BEGIN FIGURE]")
|
||||
print(f"Group: {dict(fig_group)}")
|
||||
@ -241,6 +282,8 @@ def _plot_fig(
|
||||
print("[END FIGURE]")
|
||||
return
|
||||
|
||||
# Convert string "inf", "-inf", and "nan" to their float equivalents
|
||||
fig_data = _convert_inf_nan_strings(fig_data)
|
||||
df = pd.DataFrame.from_records(fig_data)
|
||||
|
||||
if var_x not in df.columns:
|
||||
@ -275,6 +318,10 @@ def _plot_fig(
|
||||
df = filter_by.apply(df)
|
||||
df = bin_by.apply(df)
|
||||
|
||||
# Sort by curve_by columns alphabetically for consistent legend ordering
|
||||
if curve_by:
|
||||
df = df.sort_values(by=curve_by)
|
||||
|
||||
df["row_group"] = (
|
||||
pd.concat(
|
||||
[k + "=" + df[k].astype(str) for k in row_by],
|
||||
@ -293,7 +340,7 @@ def _plot_fig(
|
||||
else "(All)"
|
||||
)
|
||||
|
||||
g = sns.FacetGrid(df, row="row_group", col="col_group")
|
||||
g = sns.FacetGrid(df, row="row_group", col="col_group", height=fig_height)
|
||||
|
||||
if row_by and col_by:
|
||||
g.set_titles("{row_name}\n{col_name}")
|
||||
@ -320,6 +367,7 @@ def _plot_fig(
|
||||
style=style,
|
||||
size=size,
|
||||
markers=True,
|
||||
errorbar="sd" if error_bars else None,
|
||||
)
|
||||
|
||||
g.add_legend(title=hue)
|
||||
@ -339,11 +387,12 @@ def _plot_fig(
|
||||
y=var_y,
|
||||
hue="curve_group",
|
||||
markers=True,
|
||||
errorbar="sd" if error_bars else None,
|
||||
)
|
||||
|
||||
g.add_legend()
|
||||
|
||||
g.savefig(fig_path)
|
||||
g.savefig(fig_path, dpi=fig_dpi)
|
||||
plt.close(g.figure)
|
||||
|
||||
print("[END FIGURE]")
|
||||
@ -364,6 +413,10 @@ def plot(
|
||||
scale_x: str | None,
|
||||
scale_y: str | None,
|
||||
dry_run: bool,
|
||||
fig_name: str = "FIGURE",
|
||||
error_bars: bool = True,
|
||||
fig_height: float = 6.4,
|
||||
fig_dpi: int = 300,
|
||||
):
|
||||
all_data = [
|
||||
run_data
|
||||
@ -398,6 +451,10 @@ def plot(
|
||||
scale_x=scale_x,
|
||||
scale_y=scale_y,
|
||||
dry_run=dry_run,
|
||||
fig_name=fig_name,
|
||||
error_bars=error_bars,
|
||||
fig_height=fig_height,
|
||||
fig_dpi=fig_dpi,
|
||||
),
|
||||
fig_groups,
|
||||
)
|
||||
@ -419,6 +476,10 @@ class SweepPlotArgs:
|
||||
scale_x: str | None
|
||||
scale_y: str | None
|
||||
dry_run: bool
|
||||
fig_name: str = "FIGURE"
|
||||
error_bars: bool = True
|
||||
fig_height: float = 6.4
|
||||
fig_dpi: int = 300
|
||||
|
||||
parser_name: ClassVar[str] = "plot"
|
||||
parser_help: ClassVar[str] = "Plot performance curves from parameter sweep results."
|
||||
@ -448,6 +509,10 @@ class SweepPlotArgs:
|
||||
scale_x=args.scale_x,
|
||||
scale_y=args.scale_y,
|
||||
dry_run=args.dry_run,
|
||||
fig_name=args.fig_name,
|
||||
error_bars=not args.no_error_bars,
|
||||
fig_height=args.fig_height,
|
||||
fig_dpi=args.fig_dpi,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -541,6 +606,32 @@ class SweepPlotArgs:
|
||||
"Currently only accepts string values such as 'log' and 'sqrt'. "
|
||||
"See also: https://seaborn.pydata.org/generated/seaborn.objects.Plot.scale.html",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fig-name",
|
||||
type=str,
|
||||
default="FIGURE",
|
||||
help="Name prefix for the output figure file. "
|
||||
"Group data is always appended when present. "
|
||||
"Default: 'FIGURE'. Example: --fig-name my_performance_plot",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-error-bars",
|
||||
action="store_true",
|
||||
help="If set, disables error bars on the plot. "
|
||||
"By default, error bars are shown.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fig-height",
|
||||
type=float,
|
||||
default=6.4,
|
||||
help="Height of each subplot in inches. Default: 6.4",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fig-dpi",
|
||||
type=int,
|
||||
default=300,
|
||||
help="Resolution of the output figure in dots per inch. Default: 300",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dry-run",
|
||||
action="store_true",
|
||||
@ -566,6 +657,10 @@ def run_main(args: SweepPlotArgs):
|
||||
scale_x=args.scale_x,
|
||||
scale_y=args.scale_y,
|
||||
dry_run=args.dry_run,
|
||||
fig_name=args.fig_name,
|
||||
error_bars=args.error_bars,
|
||||
fig_height=args.fig_height,
|
||||
fig_dpi=args.fig_dpi,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -138,9 +138,9 @@ def _get_comb_base_path(
|
||||
):
|
||||
parts = list[str]()
|
||||
if serve_comb:
|
||||
parts.extend(("SERVE-", serve_comb.as_text(sep="-")))
|
||||
parts.extend(("SERVE-", serve_comb.name))
|
||||
if bench_comb:
|
||||
parts.extend(("BENCH-", bench_comb.as_text(sep="-")))
|
||||
parts.extend(("BENCH-", bench_comb.name))
|
||||
|
||||
return output_dir / sanitize_filename("-".join(parts))
|
||||
|
||||
@ -345,8 +345,9 @@ class SweepServeArgs:
|
||||
"--serve-params",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to JSON file containing a list of parameter combinations "
|
||||
"for the `vllm serve` command. "
|
||||
help="Path to JSON file containing parameter combinations "
|
||||
"for the `vllm serve` command. Can be either a list of dicts or a dict "
|
||||
"where keys are benchmark names. "
|
||||
"If both `serve_params` and `bench_params` are given, "
|
||||
"this script will iterate over their Cartesian product.",
|
||||
)
|
||||
@ -354,8 +355,9 @@ class SweepServeArgs:
|
||||
"--bench-params",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to JSON file containing a list of parameter combinations "
|
||||
"for the `vllm bench serve` command. "
|
||||
help="Path to JSON file containing parameter combinations "
|
||||
"for the `vllm bench serve` command. Can be either a list of dicts or "
|
||||
"a dict where keys are benchmark names. "
|
||||
"If both `serve_params` and `bench_params` are given, "
|
||||
"this script will iterate over their Cartesian product.",
|
||||
)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user