diff --git a/tests/compile/test_basic_correctness.py b/tests/compile/test_basic_correctness.py
index 132a838b8d44..3f6898607f6b 100644
--- a/tests/compile/test_basic_correctness.py
+++ b/tests/compile/test_basic_correctness.py
@@ -127,7 +127,9 @@ def test_compile_correctness(
         CompilationMode.VLLM_COMPILE,
     ]:
         for mode in [CompilationMode.NONE, comp_mode]:
-            all_args.append(final_args + [f"-O.mode={mode}", "-O.backend=inductor"])
+            all_args.append(
+                final_args + [f"-O.mode={mode.name}", "-O.backend=inductor"]
+            )
 
     # inductor will change the output, so we only compare if the output
     # is close, not exactly the same.
@@ -146,7 +148,7 @@ def test_compile_correctness(
         CompilationMode.DYNAMO_TRACE_ONCE,
         CompilationMode.VLLM_COMPILE,
     ]:
-        all_args.append(final_args + [f"-O.mode={mode}", "-O.backend=eager"])
+        all_args.append(final_args + [f"-O.mode={mode.name}", "-O.backend=eager"])
         all_envs.append({})
 
     all_envs.append({})
diff --git a/tests/utils_/test_argparse_utils.py b/tests/utils_/test_argparse_utils.py
index 51684edcc8a3..3310753d2b6d 100644
--- a/tests/utils_/test_argparse_utils.py
+++ b/tests/utils_/test_argparse_utils.py
@@ -8,6 +8,7 @@
 import os
 import pytest
 import yaml
 from transformers import AutoTokenizer
+from pydantic import ValidationError
 
 from vllm.transformers_utils.detokenizer_utils import convert_ids_list_to_tokens
@@ -376,6 +377,65 @@ def test_load_config_file(tmp_path):
     os.remove(str(config_file_path))
 
 
+def test_compilation_mode_string_values(parser):
+    """Test that -O.mode accepts both integer and string mode values."""
+    args = parser.parse_args(["-O.mode", "0"])
+    assert args.compilation_config == {"mode": 0}
+
+    args = parser.parse_args(["-O3"])
+    assert args.compilation_config == {"mode": 3}
+
+    args = parser.parse_args(["-O.mode=NONE"])
+    assert args.compilation_config == {"mode": "NONE"}
+
+    args = parser.parse_args(["-O.mode", "STOCK_TORCH_COMPILE"])
+    assert args.compilation_config == {"mode": "STOCK_TORCH_COMPILE"}
+
+    args = parser.parse_args(["-O.mode=DYNAMO_TRACE_ONCE"])
+    assert args.compilation_config == {"mode": "DYNAMO_TRACE_ONCE"}
+
+    args = parser.parse_args(["-O.mode", "VLLM_COMPILE"])
+    assert args.compilation_config == {"mode": "VLLM_COMPILE"}
+
+    args = parser.parse_args(["-O.mode=none"])
+    assert args.compilation_config == {"mode": "none"}
+
+    args = parser.parse_args(["-O.mode=vllm_compile"])
+    assert args.compilation_config == {"mode": "vllm_compile"}
+
+
+def test_compilation_config_mode_validator():
+    """Test that CompilationConfig.mode field validator converts strings to integers."""
+    from vllm.config.compilation import CompilationConfig, CompilationMode
+
+    config = CompilationConfig(mode=0)
+    assert config.mode == CompilationMode.NONE
+
+    config = CompilationConfig(mode=3)
+    assert config.mode == CompilationMode.VLLM_COMPILE
+
+    config = CompilationConfig(mode="NONE")
+    assert config.mode == CompilationMode.NONE
+
+    config = CompilationConfig(mode="STOCK_TORCH_COMPILE")
+    assert config.mode == CompilationMode.STOCK_TORCH_COMPILE
+
+    config = CompilationConfig(mode="DYNAMO_TRACE_ONCE")
+    assert config.mode == CompilationMode.DYNAMO_TRACE_ONCE
+
+    config = CompilationConfig(mode="VLLM_COMPILE")
+    assert config.mode == CompilationMode.VLLM_COMPILE
+
+    config = CompilationConfig(mode="none")
+    assert config.mode == CompilationMode.NONE
+
+    config = CompilationConfig(mode="vllm_compile")
+    assert config.mode == CompilationMode.VLLM_COMPILE
+
+    with pytest.raises(ValidationError, match="Invalid compilation mode"):
+        CompilationConfig(mode="INVALID_MODE")
+
+
 def test_flat_product():
     # Check regular itertools.product behavior
     result1 = list(flat_product([1, 2, 3], ["a", "b"]))
diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py
index 4b10c85209f6..4d26619bd128 100644
--- a/vllm/compilation/wrapper.py
+++ b/vllm/compilation/wrapper.py
@@ -31,7 +31,9 @@ class TorchCompileWrapperWithCustomDispatcher:
     """
 
     def __init__(
-        self, compiled_callable: Callable | None = None, compilation_mode: int = 0
+        self,
+        compiled_callable: Callable | None = None,
+        compilation_mode: CompilationMode = CompilationMode.NONE,
     ):
         vllm_config = get_current_vllm_config()
         self.vllm_config = vllm_config
diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py
index 9c9557df4e73..e1d60ee84d89 100644
--- a/vllm/config/compilation.py
+++ b/vllm/config/compilation.py
@@ -28,7 +28,7 @@ else:
 logger = init_logger(__name__)
 
 
-class CompilationMode:
+class CompilationMode(enum.IntEnum):
     """The compilation approach used for torch.compile-based compilation of the
     model."""
 
@@ -115,7 +115,7 @@ class PassConfig:
     """The threshold of the communicated tensor sizes under which
     vllm should use flashinfer fused allreduce. Specified as a
     float in MB.
-    Unspecified will fallback to default values 
+    Unspecified will fallback to default values
     which are compute capability and world size dependent.
     FI_ALLREDUCE_FUSION_MAX_SIZE_MB = {
         90: {
@@ -244,7 +244,7 @@ class CompilationConfig:
     Please use mode. Currently all levels are mapped to mode.
     """
     # Top-level Compilation control
-    mode: int | None = None
+    mode: CompilationMode | None = None
     """The compilation approach used for torch.compile-based compilation of the
     model.
 
@@ -377,23 +377,23 @@ class CompilationConfig:
     FULL mode: Capture full cudagraph for all batches. Can be good for small
     models or workloads with small prompts; not supported by many backends.
     Generally for performance FULL_AND_PIECEWISE is better.
-    
+
     FULL_DECODE_ONLY mode: Capture full cudagraph for decode batches only.
     Mixed prefill-decode batches are run without cudagraphs. Can be good for
     decode instances in a P/D setup where prefill is not as important so we
     can save some memory.
-    
+
     FULL_AND_PIECEWISE mode: Capture full cudagraph for decode batches and
     piecewise cudagraph for prefill and mixed prefill-decode batches. This is
     the most performant mode for most models and is the default.
 
     Currently, the cudagraph mode is only used for the v1 engine.
-    Note that the cudagraph logic is generally orthogonal to the 
-    compilation logic. While piecewise cudagraphs require piecewise 
+    Note that the cudagraph logic is generally orthogonal to the
+    compilation logic. While piecewise cudagraphs require piecewise
     compilation (mode=VLLM_COMPILE and non-empty splitting_ops), full
     cudagraphs are supported with and without compilation.
-    
-    Warning: This flag is new and subject to change in addition 
+
+    Warning: This flag is new and subject to change in addition
     more modes may be added.
     """
     use_cudagraph: bool = True
@@ -422,7 +422,7 @@ class CompilationConfig:
     cudagraph. If the caller can guarantee that the same input buffers
     are always used, it can set this to False. Otherwise, it should
     set this to True, and the compiler will copy the input to an
-    internally managed buffer. Default is False. 
+    internally managed buffer. Default is False.
     Note that this flag is only effective when cudagraph_mode is PIECEWISE.
     """
     full_cuda_graph: bool | None = False
@@ -451,7 +451,7 @@ class CompilationConfig:
     outside the partition functions. For a graph with N cudagraph-unsafe ops
     (e.g., Attention), there would be N+1 partitions. To mark an op as
     cudagraph unsafe, we can add `tags=(torch._C.Tag.cudagraph_unsafe)` when
-    register the custom op. 
+    register the custom op.
 
     This config supports both full cudagraph and piecewise cudagraph without
     compiling twice. For piecewise cudagraph, it applies vLLM CUDAGraph wrapper
@@ -468,8 +468,8 @@ class CompilationConfig:
 
     max_cudagraph_capture_size: int | None = field(default=None)
     """The maximum cudagraph capture size.
-    
-    If cudagraph_capture_sizes is specified, this will be set to the largest 
+
+    If cudagraph_capture_sizes is specified, this will be set to the largest
     size in that list (or checked for consistency if specified). If
     cudagraph_capture_sizes is not specified, the list of sizes is generated
     automatically following the pattern:
@@ -478,7 +478,7 @@ class CompilationConfig:
     range(256, max_cudagraph_capture_size + 1, 16))
 
     If not specified, max_cudagraph_capture_size is set to min(max_num_seqs*2,
-    512) by default. This voids OOM in tight memory scenarios with small 
+    512) by default. This voids OOM in tight memory scenarios with small
     max_num_seqs, and prevents capture of many large graphs (>512) that would
     greatly increase startup time with limited performance benefit.
     """
@@ -579,6 +579,27 @@ class CompilationConfig:
 
     __str__ = __repr__
 
+    @field_validator("mode", mode="before")
+    @classmethod
+    def validate_mode_before(cls, value: Any) -> Any:
+        """
+        Enable parsing the `mode` field from string mode names.
+        Accepts both integers (0-3) and string names, like NONE, STOCK_TORCH_COMPILE,
+        DYNAMO_TRACE_ONCE, VLLM_COMPILE.
+        """
+        if isinstance(value, str):
+            # Convert string mode name to integer value
+            mode_name = value.upper()
+
+            if mode_name not in CompilationMode.__members__:
+                raise ValueError(
+                    f"Invalid compilation mode: {value}. "
+                    f"Valid modes are: {', '.join(CompilationMode.__members__.keys())}"
+                )
+
+            return CompilationMode[mode_name]
+        return value
+
     @field_validator("cudagraph_mode", mode="before")
     @classmethod
     def validate_cudagraph_mode_before(cls, value: Any) -> Any:
@@ -904,7 +925,7 @@ class CompilationConfig:
             return self.mode == CompilationMode.VLLM_COMPILE
 
         # Inductor partition case
-        return self.backend == "inductor" and self.mode > CompilationMode.NONE
+        return self.backend == "inductor" and self.mode != CompilationMode.NONE
 
     def custom_op_log_check(self):
         """
diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py
index 0fca967d9083..df9a1fd08af6 100644
--- a/vllm/config/vllm.py
+++ b/vllm/config/vllm.py
@@ -422,16 +422,13 @@ class VllmConfig:
                 self.compilation_config.mode = CompilationMode.VLLM_COMPILE
             else:
                 self.compilation_config.mode = CompilationMode.NONE
-        else:
-            assert self.compilation_config.mode >= CompilationMode.NONE
-            assert self.compilation_config.mode <= CompilationMode.VLLM_COMPILE
 
         # If user does not set custom ops via none or all set it here based on
         # compilation mode and backend.
         if all(s not in self.compilation_config.custom_ops for s in ("all", "none")):
             if (
                 self.compilation_config.backend == "inductor"
-                and self.compilation_config.mode > CompilationMode.NONE
+                and self.compilation_config.mode != CompilationMode.NONE
             ):
                 self.compilation_config.custom_ops.append("none")
             else:
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index 22fe2ae9280a..62717a7eacdf 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -23,6 +23,7 @@ from vllm.config import (
     StructuredOutputsConfig,
     is_init_field,
 )
+from vllm.config.compilation import CompilationMode
 from vllm.config.model import (
     ConvertOption,
     HfOverrides,
@@ -259,7 +260,9 @@ class LLM:
 
         if compilation_config is not None:
             if isinstance(compilation_config, int):
-                compilation_config_instance = CompilationConfig(mode=compilation_config)
+                compilation_config_instance = CompilationConfig(
+                    mode=CompilationMode(compilation_config)
+                )
             elif isinstance(compilation_config, dict):
                 compilation_config_instance = CompilationConfig(
                     **{