[Frontend] Change CompilationMode to a proper Enum (#28165)

Signed-off-by: Yanan Cao <gmagogsfm@gmail.com>
This commit is contained in:
Yanan Cao 2025-11-11 16:46:18 -08:00 committed by GitHub
parent 1788aa1efb
commit 48c879369f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 108 additions and 23 deletions

View File

@ -127,7 +127,9 @@ def test_compile_correctness(
CompilationMode.VLLM_COMPILE, CompilationMode.VLLM_COMPILE,
]: ]:
for mode in [CompilationMode.NONE, comp_mode]: for mode in [CompilationMode.NONE, comp_mode]:
all_args.append(final_args + [f"-O.mode={mode}", "-O.backend=inductor"]) all_args.append(
final_args + [f"-O.mode={mode.name}", "-O.backend=inductor"]
)
# inductor will change the output, so we only compare if the output # inductor will change the output, so we only compare if the output
# is close, not exactly the same. # is close, not exactly the same.
@ -146,7 +148,7 @@ def test_compile_correctness(
CompilationMode.DYNAMO_TRACE_ONCE, CompilationMode.DYNAMO_TRACE_ONCE,
CompilationMode.VLLM_COMPILE, CompilationMode.VLLM_COMPILE,
]: ]:
all_args.append(final_args + [f"-O.mode={mode}", "-O.backend=eager"]) all_args.append(final_args + [f"-O.mode={mode.name}", "-O.backend=eager"])
all_envs.append({}) all_envs.append({})
all_envs.append({}) all_envs.append({})

View File

@ -8,6 +8,7 @@ import os
import pytest import pytest
import yaml import yaml
from transformers import AutoTokenizer from transformers import AutoTokenizer
from pydantic import ValidationError
from vllm.transformers_utils.detokenizer_utils import convert_ids_list_to_tokens from vllm.transformers_utils.detokenizer_utils import convert_ids_list_to_tokens
@ -376,6 +377,65 @@ def test_load_config_file(tmp_path):
os.remove(str(config_file_path)) os.remove(str(config_file_path))
def test_compilation_mode_string_values(parser):
"""Test that -O.mode accepts both integer and string mode values."""
args = parser.parse_args(["-O.mode", "0"])
assert args.compilation_config == {"mode": 0}
args = parser.parse_args(["-O3"])
assert args.compilation_config == {"mode": 3}
args = parser.parse_args(["-O.mode=NONE"])
assert args.compilation_config == {"mode": "NONE"}
args = parser.parse_args(["-O.mode", "STOCK_TORCH_COMPILE"])
assert args.compilation_config == {"mode": "STOCK_TORCH_COMPILE"}
args = parser.parse_args(["-O.mode=DYNAMO_TRACE_ONCE"])
assert args.compilation_config == {"mode": "DYNAMO_TRACE_ONCE"}
args = parser.parse_args(["-O.mode", "VLLM_COMPILE"])
assert args.compilation_config == {"mode": "VLLM_COMPILE"}
args = parser.parse_args(["-O.mode=none"])
assert args.compilation_config == {"mode": "none"}
args = parser.parse_args(["-O.mode=vllm_compile"])
assert args.compilation_config == {"mode": "vllm_compile"}
def test_compilation_config_mode_validator():
"""Test that CompilationConfig.mode field validator converts strings to integers."""
from vllm.config.compilation import CompilationConfig, CompilationMode
config = CompilationConfig(mode=0)
assert config.mode == CompilationMode.NONE
config = CompilationConfig(mode=3)
assert config.mode == CompilationMode.VLLM_COMPILE
config = CompilationConfig(mode="NONE")
assert config.mode == CompilationMode.NONE
config = CompilationConfig(mode="STOCK_TORCH_COMPILE")
assert config.mode == CompilationMode.STOCK_TORCH_COMPILE
config = CompilationConfig(mode="DYNAMO_TRACE_ONCE")
assert config.mode == CompilationMode.DYNAMO_TRACE_ONCE
config = CompilationConfig(mode="VLLM_COMPILE")
assert config.mode == CompilationMode.VLLM_COMPILE
config = CompilationConfig(mode="none")
assert config.mode == CompilationMode.NONE
config = CompilationConfig(mode="vllm_compile")
assert config.mode == CompilationMode.VLLM_COMPILE
with pytest.raises(ValidationError, match="Invalid compilation mode"):
CompilationConfig(mode="INVALID_MODE")
def test_flat_product(): def test_flat_product():
# Check regular itertools.product behavior # Check regular itertools.product behavior
result1 = list(flat_product([1, 2, 3], ["a", "b"])) result1 = list(flat_product([1, 2, 3], ["a", "b"]))

View File

@ -31,7 +31,9 @@ class TorchCompileWrapperWithCustomDispatcher:
""" """
def __init__( def __init__(
self, compiled_callable: Callable | None = None, compilation_mode: int = 0 self,
compiled_callable: Callable | None = None,
compilation_mode: CompilationMode = CompilationMode.NONE,
): ):
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
self.vllm_config = vllm_config self.vllm_config = vllm_config

View File

@ -28,7 +28,7 @@ else:
logger = init_logger(__name__) logger = init_logger(__name__)
class CompilationMode: class CompilationMode(enum.IntEnum):
"""The compilation approach used for torch.compile-based compilation of the """The compilation approach used for torch.compile-based compilation of the
model.""" model."""
@ -244,7 +244,7 @@ class CompilationConfig:
Please use mode. Currently all levels are mapped to mode. Please use mode. Currently all levels are mapped to mode.
""" """
# Top-level Compilation control # Top-level Compilation control
mode: int | None = None mode: CompilationMode | None = None
"""The compilation approach used for torch.compile-based compilation of the """The compilation approach used for torch.compile-based compilation of the
model. model.
@ -579,6 +579,27 @@ class CompilationConfig:
__str__ = __repr__ __str__ = __repr__
@field_validator("mode", mode="before")
@classmethod
def validate_mode_before(cls, value: Any) -> Any:
"""
Enable parsing the `mode` field from string mode names.
Accepts both integers (0-3) and string names, like NONE, STOCK_TORCH_COMPILE,
DYNAMO_TRACE_ONCE, VLLM_COMPILE.
"""
if isinstance(value, str):
# Convert string mode name to integer value
mode_name = value.upper()
if mode_name not in CompilationMode.__members__:
raise ValueError(
f"Invalid compilation mode: {value}. "
f"Valid modes are: {', '.join(CompilationMode.__members__.keys())}"
)
return CompilationMode[mode_name]
return value
@field_validator("cudagraph_mode", mode="before") @field_validator("cudagraph_mode", mode="before")
@classmethod @classmethod
def validate_cudagraph_mode_before(cls, value: Any) -> Any: def validate_cudagraph_mode_before(cls, value: Any) -> Any:
@ -904,7 +925,7 @@ class CompilationConfig:
return self.mode == CompilationMode.VLLM_COMPILE return self.mode == CompilationMode.VLLM_COMPILE
# Inductor partition case # Inductor partition case
return self.backend == "inductor" and self.mode > CompilationMode.NONE return self.backend == "inductor" and self.mode != CompilationMode.NONE
def custom_op_log_check(self): def custom_op_log_check(self):
""" """

View File

@ -422,16 +422,13 @@ class VllmConfig:
self.compilation_config.mode = CompilationMode.VLLM_COMPILE self.compilation_config.mode = CompilationMode.VLLM_COMPILE
else: else:
self.compilation_config.mode = CompilationMode.NONE self.compilation_config.mode = CompilationMode.NONE
else:
assert self.compilation_config.mode >= CompilationMode.NONE
assert self.compilation_config.mode <= CompilationMode.VLLM_COMPILE
# If user does not set custom ops via none or all set it here based on # If user does not set custom ops via none or all set it here based on
# compilation mode and backend. # compilation mode and backend.
if all(s not in self.compilation_config.custom_ops for s in ("all", "none")): if all(s not in self.compilation_config.custom_ops for s in ("all", "none")):
if ( if (
self.compilation_config.backend == "inductor" self.compilation_config.backend == "inductor"
and self.compilation_config.mode > CompilationMode.NONE and self.compilation_config.mode != CompilationMode.NONE
): ):
self.compilation_config.custom_ops.append("none") self.compilation_config.custom_ops.append("none")
else: else:

View File

@ -23,6 +23,7 @@ from vllm.config import (
StructuredOutputsConfig, StructuredOutputsConfig,
is_init_field, is_init_field,
) )
from vllm.config.compilation import CompilationMode
from vllm.config.model import ( from vllm.config.model import (
ConvertOption, ConvertOption,
HfOverrides, HfOverrides,
@ -259,7 +260,9 @@ class LLM:
if compilation_config is not None: if compilation_config is not None:
if isinstance(compilation_config, int): if isinstance(compilation_config, int):
compilation_config_instance = CompilationConfig(mode=compilation_config) compilation_config_instance = CompilationConfig(
mode=CompilationMode(compilation_config)
)
elif isinstance(compilation_config, dict): elif isinstance(compilation_config, dict):
compilation_config_instance = CompilationConfig( compilation_config_instance = CompilationConfig(
**{ **{