[Frontend] Change CompilationMode to a proper Enum (#28165)

Signed-off-by: Yanan Cao <gmagogsfm@gmail.com>
Yanan Cao 2025-11-11 16:46:18 -08:00 committed by GitHub
parent 1788aa1efb
commit 48c879369f
6 changed files with 108 additions and 23 deletions

View File

@@ -127,7 +127,9 @@ def test_compile_correctness(
CompilationMode.VLLM_COMPILE,
]:
for mode in [CompilationMode.NONE, comp_mode]:
all_args.append(final_args + [f"-O.mode={mode}", "-O.backend=inductor"])
all_args.append(
final_args + [f"-O.mode={mode.name}", "-O.backend=inductor"]
)
# inductor will change the output, so we only compare if the output
# is close, not exactly the same.
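For reference, a minimal standalone sketch (using a simplified stand-in enum rather than the real class) of what `mode.name` yields versus the member's integer value; the updated `-O.mode` parsing accepts either form, but the symbolic name keeps the generated arguments readable:

from enum import IntEnum

class Mode(IntEnum):            # simplified stand-in for CompilationMode
    NONE = 0
    VLLM_COMPILE = 3

mode = Mode.VLLM_COMPILE
print(f"-O.mode={mode.name}")   # -O.mode=VLLM_COMPILE  (symbolic name)
print(f"-O.mode={mode.value}")  # -O.mode=3             (integer value)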
@@ -146,7 +148,7 @@ def test_compile_correctness(
CompilationMode.DYNAMO_TRACE_ONCE,
CompilationMode.VLLM_COMPILE,
]:
all_args.append(final_args + [f"-O.mode={mode}", "-O.backend=eager"])
all_args.append(final_args + [f"-O.mode={mode.name}", "-O.backend=eager"])
all_envs.append({})
all_envs.append({})

View File

@@ -8,6 +8,7 @@ import os
import pytest
import yaml
from transformers import AutoTokenizer
from pydantic import ValidationError
from vllm.transformers_utils.detokenizer_utils import convert_ids_list_to_tokens
@@ -376,6 +377,65 @@ def test_load_config_file(tmp_path):
os.remove(str(config_file_path))
def test_compilation_mode_string_values(parser):
"""Test that -O.mode accepts both integer and string mode values."""
args = parser.parse_args(["-O.mode", "0"])
assert args.compilation_config == {"mode": 0}
args = parser.parse_args(["-O3"])
assert args.compilation_config == {"mode": 3}
args = parser.parse_args(["-O.mode=NONE"])
assert args.compilation_config == {"mode": "NONE"}
args = parser.parse_args(["-O.mode", "STOCK_TORCH_COMPILE"])
assert args.compilation_config == {"mode": "STOCK_TORCH_COMPILE"}
args = parser.parse_args(["-O.mode=DYNAMO_TRACE_ONCE"])
assert args.compilation_config == {"mode": "DYNAMO_TRACE_ONCE"}
args = parser.parse_args(["-O.mode", "VLLM_COMPILE"])
assert args.compilation_config == {"mode": "VLLM_COMPILE"}
args = parser.parse_args(["-O.mode=none"])
assert args.compilation_config == {"mode": "none"}
args = parser.parse_args(["-O.mode=vllm_compile"])
assert args.compilation_config == {"mode": "vllm_compile"}
def test_compilation_config_mode_validator():
"""Test that CompilationConfig.mode field validator converts strings to integers."""
from vllm.config.compilation import CompilationConfig, CompilationMode
config = CompilationConfig(mode=0)
assert config.mode == CompilationMode.NONE
config = CompilationConfig(mode=3)
assert config.mode == CompilationMode.VLLM_COMPILE
config = CompilationConfig(mode="NONE")
assert config.mode == CompilationMode.NONE
config = CompilationConfig(mode="STOCK_TORCH_COMPILE")
assert config.mode == CompilationMode.STOCK_TORCH_COMPILE
config = CompilationConfig(mode="DYNAMO_TRACE_ONCE")
assert config.mode == CompilationMode.DYNAMO_TRACE_ONCE
config = CompilationConfig(mode="VLLM_COMPILE")
assert config.mode == CompilationMode.VLLM_COMPILE
config = CompilationConfig(mode="none")
assert config.mode == CompilationMode.NONE
config = CompilationConfig(mode="vllm_compile")
assert config.mode == CompilationMode.VLLM_COMPILE
with pytest.raises(ValidationError, match="Invalid compilation mode"):
CompilationConfig(mode="INVALID_MODE")
def test_flat_product():
# Check regular itertools.product behavior
result1 = list(flat_product([1, 2, 3], ["a", "b"]))

View File

@@ -31,7 +31,9 @@ class TorchCompileWrapperWithCustomDispatcher:
"""
def __init__(
self, compiled_callable: Callable | None = None, compilation_mode: int = 0
self,
compiled_callable: Callable | None = None,
compilation_mode: CompilationMode = CompilationMode.NONE,
):
vllm_config = get_current_vllm_config()
self.vllm_config = vllm_config

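Because CompilationMode is an IntEnum, callers that still pass a plain integer for `compilation_mode` keep working; a short sketch of the relevant semantics (assumes vLLM is importable and that the member values match the tests above):

from vllm.config.compilation import CompilationMode

# IntEnum members compare equal to, and round-trip from, their integer values,
# so the old default of 0 and the new default of CompilationMode.NONE behave alike.
assert CompilationMode.NONE == 0
assert CompilationMode(3) is CompilationMode.VLLM_COMPILE
assert isinstance(CompilationMode.VLLM_COMPILE, int)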
View File

@@ -28,7 +28,7 @@ else:
logger = init_logger(__name__)
class CompilationMode:
class CompilationMode(enum.IntEnum):
"""The compilation approach used for torch.compile-based compilation of the
model."""
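Based on the names used throughout this change and the 0-3 range accepted by the new validator, the enum presumably defines these members (the values of the two middle members are inferred, not shown in this hunk):

import enum

class CompilationMode(enum.IntEnum):
    NONE = 0
    STOCK_TORCH_COMPILE = 1  # inferred value
    DYNAMO_TRACE_ONCE = 2    # inferred value
    VLLM_COMPILE = 3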
@@ -115,7 +115,7 @@ class PassConfig:
"""The threshold of the communicated tensor sizes under which
vllm should use flashinfer fused allreduce. Specified as a
float in MB.
If unspecified, this falls back to default values
which are compute capability and world size dependent.
FI_ALLREDUCE_FUSION_MAX_SIZE_MB = {
90: {
@@ -244,7 +244,7 @@ class CompilationConfig:
Please use mode. Currently all levels are mapped to mode.
"""
# Top-level Compilation control
mode: int | None = None
mode: CompilationMode | None = None
"""The compilation approach used for torch.compile-based compilation of the
model.
@@ -377,23 +377,23 @@ class CompilationConfig:
FULL mode: Capture full cudagraph for all batches. Can be good for small
models or workloads with small prompts; not supported by many backends.
Generally for performance FULL_AND_PIECEWISE is better.
FULL_DECODE_ONLY mode: Capture full cudagraph for decode batches only.
Mixed prefill-decode batches are run without cudagraphs. Can be good for
decode instances in a P/D setup where prefill is not as important so we
can save some memory.
FULL_AND_PIECEWISE mode: Capture full cudagraph for decode batches and
piecewise cudagraph for prefill and mixed prefill-decode batches.
This is the most performant mode for most models and is the default.
Currently, the cudagraph mode is only used for the v1 engine.
Note that the cudagraph logic is generally orthogonal to the
compilation logic. While piecewise cudagraphs require piecewise
compilation (mode=VLLM_COMPILE and non-empty splitting_ops), full
cudagraphs are supported with and without compilation.
Warning: This flag is new and subject to change; in addition,
more modes may be added.
"""
use_cudagraph: bool = True
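A hedged usage sketch: given the `validate_cudagraph_mode_before` validator further down in this file, `cudagraph_mode` can presumably also be given as a string, mirroring the new `mode` handling:

from vllm.config.compilation import CompilationConfig

# Assumption: the cudagraph_mode validator maps mode names given as strings
# onto the corresponding CUDAGraphMode member.
config = CompilationConfig(cudagraph_mode="FULL_AND_PIECEWISE")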
@@ -422,7 +422,7 @@ class CompilationConfig:
cudagraph. If the caller can guarantee that the same input buffers
are always used, it can set this to False. Otherwise, it should
set this to True, and the compiler will copy the input to an
internally managed buffer. Default is False.
Note that this flag is only effective when cudagraph_mode is PIECEWISE.
"""
full_cuda_graph: bool | None = False
@@ -451,7 +451,7 @@ class CompilationConfig:
outside the partition functions. For a graph with N cudagraph-unsafe ops
(e.g., Attention), there would be N+1 partitions. To mark an op as
cudagraph unsafe, we can add `tags=(torch._C.Tag.cudagraph_unsafe)` when
registering the custom op.
This config supports both full cudagraph and piecewise cudagraph without
compiling twice. For piecewise cudagraph, it applies vLLM CUDAGraph wrapper
@@ -468,8 +468,8 @@ class CompilationConfig:
max_cudagraph_capture_size: int | None = field(default=None)
"""The maximum cudagraph capture size.
If cudagraph_capture_sizes is specified, this will be set to the largest
size in that list (or checked for consistency if specified). If
cudagraph_capture_sizes is not specified, the list of sizes is generated
automatically following the pattern:
@@ -478,7 +478,7 @@ class CompilationConfig:
range(256, max_cudagraph_capture_size + 1, 16))
If not specified, max_cudagraph_capture_size is set to min(max_num_seqs*2,
512) by default. This avoids OOM in tight memory scenarios with small
max_num_seqs, and prevents capture of many large graphs (>512) that would
greatly increase startup time with limited performance benefit.
"""
@@ -579,6 +579,27 @@ class CompilationConfig:
__str__ = __repr__
@field_validator("mode", mode="before")
@classmethod
def validate_mode_before(cls, value: Any) -> Any:
"""
Enable parsing the `mode` field from string mode names.
Accepts both integers (0-3) and string names, like NONE, STOCK_TORCH_COMPILE,
DYNAMO_TRACE_ONCE, VLLM_COMPILE.
"""
if isinstance(value, str):
# Convert string mode name to integer value
mode_name = value.upper()
if mode_name not in CompilationMode.__members__:
raise ValueError(
f"Invalid compilation mode: {value}. "
f"Valid modes are: {', '.join(CompilationMode.__members__.keys())}"
)
return CompilationMode[mode_name]
return value
@field_validator("cudagraph_mode", mode="before")
@classmethod
def validate_cudagraph_mode_before(cls, value: Any) -> Any:
@@ -904,7 +925,7 @@ class CompilationConfig:
return self.mode == CompilationMode.VLLM_COMPILE
# Inductor partition case
return self.backend == "inductor" and self.mode > CompilationMode.NONE
return self.backend == "inductor" and self.mode != CompilationMode.NONE
def custom_op_log_check(self):
"""

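Both comparisons are still legal on an IntEnum; a quick sketch of why the equality check is the more direct statement of intent here:

from vllm.config.compilation import CompilationMode

mode = CompilationMode.VLLM_COMPILE
# IntEnum preserves integer ordering, so the old `> NONE` check still holds ...
assert mode > CompilationMode.NONE
# ... but the condition really means "compilation is enabled", i.e. not NONE.
assert mode != CompilationMode.NONE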
View File

@@ -422,16 +422,13 @@ class VllmConfig:
self.compilation_config.mode = CompilationMode.VLLM_COMPILE
else:
self.compilation_config.mode = CompilationMode.NONE
else:
assert self.compilation_config.mode >= CompilationMode.NONE
assert self.compilation_config.mode <= CompilationMode.VLLM_COMPILE
# If user does not set custom ops via none or all set it here based on
# compilation mode and backend.
if all(s not in self.compilation_config.custom_ops for s in ("all", "none")):
if (
self.compilation_config.backend == "inductor"
and self.compilation_config.mode > CompilationMode.NONE
and self.compilation_config.mode != CompilationMode.NONE
):
self.compilation_config.custom_ops.append("none")
else:

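The removed range asserts are presumably redundant now that `mode` is typed as `CompilationMode | None`: pydantic should reject out-of-range integers at construction time, mirroring the new test for invalid string modes. A hedged sketch:

import pytest
from pydantic import ValidationError
from vllm.config.compilation import CompilationConfig

# Assumption: an integer with no matching CompilationMode member fails validation.
with pytest.raises(ValidationError):
    CompilationConfig(mode=7)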
View File

@@ -23,6 +23,7 @@ from vllm.config import (
StructuredOutputsConfig,
is_init_field,
)
from vllm.config.compilation import CompilationMode
from vllm.config.model import (
ConvertOption,
HfOverrides,
@@ -259,7 +260,9 @@ class LLM:
if compilation_config is not None:
if isinstance(compilation_config, int):
compilation_config_instance = CompilationConfig(mode=compilation_config)
compilation_config_instance = CompilationConfig(
mode=CompilationMode(compilation_config)
)
elif isinstance(compilation_config, dict):
compilation_config_instance = CompilationConfig(
**{
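A hedged end-to-end sketch of the two forms this branch handles (the model name is only a placeholder):

from vllm import LLM

# Integer shorthand, coerced through CompilationMode(...) as in the diff above.
llm = LLM(model="facebook/opt-125m", compilation_config=3)

# Dict form; the string is normalized by CompilationConfig's mode validator.
llm = LLM(model="facebook/opt-125m", compilation_config={"mode": "VLLM_COMPILE"})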