[Frontend] Change CompilationMode to a proper Enum (#28165)

Signed-off-by: Yanan Cao <gmagogsfm@gmail.com>
This commit is contained in:
Yanan Cao 2025-11-11 16:46:18 -08:00 committed by GitHub
parent 1788aa1efb
commit 48c879369f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 108 additions and 23 deletions

View File

@ -127,7 +127,9 @@ def test_compile_correctness(
CompilationMode.VLLM_COMPILE,
]:
for mode in [CompilationMode.NONE, comp_mode]:
all_args.append(final_args + [f"-O.mode={mode}", "-O.backend=inductor"])
all_args.append(
final_args + [f"-O.mode={mode.name}", "-O.backend=inductor"]
)
# inductor will change the output, so we only compare if the output
# is close, not exactly the same.
@ -146,7 +148,7 @@ def test_compile_correctness(
CompilationMode.DYNAMO_TRACE_ONCE,
CompilationMode.VLLM_COMPILE,
]:
all_args.append(final_args + [f"-O.mode={mode}", "-O.backend=eager"])
all_args.append(final_args + [f"-O.mode={mode.name}", "-O.backend=eager"])
all_envs.append({})
all_envs.append({})

View File

@ -8,6 +8,7 @@ import os
import pytest
import yaml
from transformers import AutoTokenizer
from pydantic import ValidationError
from vllm.transformers_utils.detokenizer_utils import convert_ids_list_to_tokens
@ -376,6 +377,65 @@ def test_load_config_file(tmp_path):
os.remove(str(config_file_path))
def test_compilation_mode_string_values(parser):
    """Check that `-O.mode` accepts integer values and string mode names.

    Each entry pairs a command line with the `compilation_config` dict the
    parser is expected to produce. String names are compared exactly as they
    were typed (upper or lower case), so this test only pins what the parser
    hands on, not any later conversion of the name.
    """
    expectations = [
        # Integer forms, including the `-O3` shorthand.
        (["-O.mode", "0"], {"mode": 0}),
        (["-O3"], {"mode": 3}),
        # Upper-case mode names, with both `=` and space-separated syntax.
        (["-O.mode=NONE"], {"mode": "NONE"}),
        (["-O.mode", "STOCK_TORCH_COMPILE"], {"mode": "STOCK_TORCH_COMPILE"}),
        (["-O.mode=DYNAMO_TRACE_ONCE"], {"mode": "DYNAMO_TRACE_ONCE"}),
        (["-O.mode", "VLLM_COMPILE"], {"mode": "VLLM_COMPILE"}),
        # Lower-case names are passed through unchanged as well.
        (["-O.mode=none"], {"mode": "none"}),
        (["-O.mode=vllm_compile"], {"mode": "vllm_compile"}),
    ]
    for cli_args, expected_config in expectations:
        parsed = parser.parse_args(cli_args)
        assert parsed.compilation_config == expected_config
def test_compilation_config_mode_validator():
    """Check how `CompilationConfig` resolves the `mode` argument.

    Integers, upper-case names and lower-case names must each yield a `mode`
    equal to the matching `CompilationMode` member, while an unknown name must
    be rejected with a `ValidationError` mentioning "Invalid compilation mode".
    """
    from vllm.config.compilation import CompilationConfig, CompilationMode

    accepted = [
        # Plain integers.
        (0, CompilationMode.NONE),
        (3, CompilationMode.VLLM_COMPILE),
        # Member names in upper case.
        ("NONE", CompilationMode.NONE),
        ("STOCK_TORCH_COMPILE", CompilationMode.STOCK_TORCH_COMPILE),
        ("DYNAMO_TRACE_ONCE", CompilationMode.DYNAMO_TRACE_ONCE),
        ("VLLM_COMPILE", CompilationMode.VLLM_COMPILE),
        # Member names in lower case.
        ("none", CompilationMode.NONE),
        ("vllm_compile", CompilationMode.VLLM_COMPILE),
    ]
    for raw_mode, expected_mode in accepted:
        assert CompilationConfig(mode=raw_mode).mode == expected_mode

    with pytest.raises(ValidationError, match="Invalid compilation mode"):
        CompilationConfig(mode="INVALID_MODE")
def test_flat_product():
# Check regular itertools.product behavior
result1 = list(flat_product([1, 2, 3], ["a", "b"]))

View File

@ -31,7 +31,9 @@ class TorchCompileWrapperWithCustomDispatcher:
"""
def __init__(
self, compiled_callable: Callable | None = None, compilation_mode: int = 0
self,
compiled_callable: Callable | None = None,
compilation_mode: CompilationMode = CompilationMode.NONE,
):
vllm_config = get_current_vllm_config()
self.vllm_config = vllm_config

View File

@ -28,7 +28,7 @@ else:
logger = init_logger(__name__)
class CompilationMode:
class CompilationMode(enum.IntEnum):
"""The compilation approach used for torch.compile-based compilation of the
model."""
@ -244,7 +244,7 @@ class CompilationConfig:
Please use mode. Currently all levels are mapped to mode.
"""
# Top-level Compilation control
mode: int | None = None
mode: CompilationMode | None = None
"""The compilation approach used for torch.compile-based compilation of the
model.
@ -579,6 +579,27 @@ class CompilationConfig:
__str__ = __repr__
@field_validator("mode", mode="before")
@classmethod
def validate_mode_before(cls, value: Any) -> Any:
    """
    Resolve a string `mode` value to its `CompilationMode` member.

    Registered as a "before" validator for the `mode` field. A string is
    upper-cased and looked up among the `CompilationMode` member names
    (e.g. NONE, STOCK_TORCH_COMPILE, DYNAMO_TRACE_ONCE, VLLM_COMPILE), so
    the lookup is case-insensitive. Any non-string value is returned
    unchanged.

    Raises:
        ValueError: If a string does not name a `CompilationMode` member.
    """
    # Only strings need translating; everything else is handed back as-is.
    if not isinstance(value, str):
        return value

    members = CompilationMode.__members__
    normalized = value.upper()
    if normalized in members:
        return members[normalized]

    # Report the value exactly as given, together with the accepted names.
    raise ValueError(
        f"Invalid compilation mode: {value}. "
        f"Valid modes are: {', '.join(CompilationMode.__members__.keys())}"
    )
@field_validator("cudagraph_mode", mode="before")
@classmethod
def validate_cudagraph_mode_before(cls, value: Any) -> Any:
@ -904,7 +925,7 @@ class CompilationConfig:
return self.mode == CompilationMode.VLLM_COMPILE
# Inductor partition case
return self.backend == "inductor" and self.mode > CompilationMode.NONE
return self.backend == "inductor" and self.mode != CompilationMode.NONE
def custom_op_log_check(self):
"""

View File

@ -422,16 +422,13 @@ class VllmConfig:
self.compilation_config.mode = CompilationMode.VLLM_COMPILE
else:
self.compilation_config.mode = CompilationMode.NONE
else:
assert self.compilation_config.mode >= CompilationMode.NONE
assert self.compilation_config.mode <= CompilationMode.VLLM_COMPILE
# If the user has not set custom ops via "none" or "all", set them here based on
# the compilation mode and backend.
if all(s not in self.compilation_config.custom_ops for s in ("all", "none")):
if (
self.compilation_config.backend == "inductor"
and self.compilation_config.mode > CompilationMode.NONE
and self.compilation_config.mode != CompilationMode.NONE
):
self.compilation_config.custom_ops.append("none")
else:

View File

@ -23,6 +23,7 @@ from vllm.config import (
StructuredOutputsConfig,
is_init_field,
)
from vllm.config.compilation import CompilationMode
from vllm.config.model import (
ConvertOption,
HfOverrides,
@ -259,7 +260,9 @@ class LLM:
if compilation_config is not None:
if isinstance(compilation_config, int):
compilation_config_instance = CompilationConfig(mode=compilation_config)
compilation_config_instance = CompilationConfig(
mode=CompilationMode(compilation_config)
)
elif isinstance(compilation_config, dict):
compilation_config_instance = CompilationConfig(
**{