From 96b9aa5aa076e64c68765232aec343e4d0006e2a Mon Sep 17 00:00:00 2001 From: Morrison Turnansky Date: Tue, 14 Oct 2025 22:51:16 -0400 Subject: [PATCH] [Frontend][torch.compile] CompilationConfig Overhaul (#20283): name change compilation level to compilation mode, deprecation compilation level (#26355) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: morrison-turnansky Signed-off-by: Morrison Turnansky Co-authored-by: Luka Govedič --- docs/configuration/conserving_memory.md | 4 +- docs/design/cuda_graphs.md | 4 +- examples/offline_inference/data_parallel.py | 2 +- .../compile/piecewise/test_multiple_graphs.py | 10 +- tests/compile/piecewise/test_simple.py | 4 +- tests/compile/piecewise/test_toy_llama.py | 10 +- tests/compile/test_aot_compile.py | 4 +- tests/compile/test_async_tp.py | 3 +- tests/compile/test_basic_correctness.py | 30 +++-- tests/compile/test_config.py | 20 ++-- tests/compile/test_decorator.py | 10 +- tests/compile/test_full_graph.py | 29 ++--- tests/compile/test_fusion.py | 4 +- tests/compile/test_fusion_all_reduce.py | 4 +- tests/compile/test_fusion_attn.py | 4 +- tests/compile/test_noop_elimination.py | 6 +- tests/compile/test_wrapper.py | 4 +- tests/distributed/test_sequence_parallel.py | 3 +- tests/engine/test_arg_utils.py | 20 ++-- tests/tpu/test_custom_dispatcher.py | 6 +- tests/utils_/test_utils.py | 10 +- tests/v1/cudagraph/test_cudagraph_dispatch.py | 22 ++-- tests/v1/cudagraph/test_cudagraph_mode.py | 39 +++---- tests/v1/e2e/test_kv_sharing_fast_prefill.py | 6 +- vllm/compilation/backends.py | 4 +- vllm/compilation/compiler_interface.py | 2 +- vllm/compilation/counter.py | 4 +- vllm/compilation/decorators.py | 10 +- vllm/compilation/monitor.py | 6 +- vllm/compilation/wrapper.py | 8 +- vllm/config/__init__.py | 4 +- vllm/config/compilation.py | 106 ++++++++++++------ vllm/config/vllm.py | 50 ++++----- vllm/entrypoints/llm.py | 6 +- .../layers/quantization/utils/w8a8_utils.py | 4 +- vllm/platforms/cpu.py | 8 +- vllm/platforms/tpu.py | 11 +- vllm/platforms/xpu.py | 4 +- vllm/utils/__init__.py | 10 +- vllm/v1/cudagraph_dispatcher.py | 4 +- vllm/v1/spec_decode/eagle.py | 4 +- vllm/v1/worker/gpu_model_runner.py | 15 +-- 42 files changed, 270 insertions(+), 248 deletions(-) diff --git a/docs/configuration/conserving_memory.md b/docs/configuration/conserving_memory.md index 2b0654fa6d463..85906d23dee33 100644 --- a/docs/configuration/conserving_memory.md +++ b/docs/configuration/conserving_memory.md @@ -58,12 +58,12 @@ You can adjust `compilation_config` to achieve a better balance between inferenc ```python from vllm import LLM - from vllm.config import CompilationConfig, CompilationLevel + from vllm.config import CompilationConfig, CompilationMode llm = LLM( model="meta-llama/Llama-3.1-8B-Instruct", compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, + mode=CompilationMode.VLLM_COMPILE, # By default, it goes up to max_num_seqs cudagraph_capture_sizes=[1, 2, 4, 8, 16], ), diff --git a/docs/design/cuda_graphs.md b/docs/design/cuda_graphs.md index 315746b0ef674..c6d71589be985 100644 --- a/docs/design/cuda_graphs.md +++ b/docs/design/cuda_graphs.md @@ -167,7 +167,7 @@ class AttentionCGSupport(enum.Enum): """NO CUDA Graphs support""" ``` -Suppose we have hybrid attention backends (e.g., in mamba mixer models). In that case, we seek the minimum capability of all backends to determine the final capability of the model, and we might resolve the incompatible CUDA Graphs mode by downgrading the mode to the best fit one. For example, downgrading `FULL` mode to `FULL_AND_PIECEWISE` mode if the minimum capability is `UNIFORM_BATCH`, or `PIECEWISE` mode if the minimum capability is `NEVER` for -O3 compilation level. For the complete fallback policy, please see the code of [initialize_cudagraph_capture][vllm.v1.worker.gpu_model_runner.GPUModelRunner.initialize_cudagraph_capture]. +Suppose we have hybrid attention backends (e.g., in mamba mixer models). In that case, we seek the minimum capability of all backends to determine the final capability of the model, and we might resolve the incompatible CUDA Graphs mode by downgrading the mode to the best fit one. For example, downgrading `FULL` mode to `FULL_AND_PIECEWISE` mode if the minimum capability is `UNIFORM_BATCH`, or `PIECEWISE` mode if the minimum capability is `NEVER` for -O3 compilation mode. For the complete fallback policy, please see the code of [initialize_cudagraph_capture][vllm.v1.worker.gpu_model_runner.GPUModelRunner.initialize_cudagraph_capture]. The following table lists backends that support full CUDA Graphs at the time of writing. @@ -202,7 +202,7 @@ os.environ.setdefault("VLLM_LOGGING_LEVEL", "DEBUG") import vllm from vllm.config import CUDAGraphMode -compilation_config = {"level": 3, "cudagraph_mode": "FULL_AND_PIECEWISE"} +compilation_config = {"mode": 3, "cudagraph_mode": "FULL_AND_PIECEWISE"} model = vllm.LLM( model="meta-llama/Llama-3.1-8B-Instruct", dtype="auto", diff --git a/examples/offline_inference/data_parallel.py b/examples/offline_inference/data_parallel.py index 0076d4d30ee8e..a3e671a0f4cca 100644 --- a/examples/offline_inference/data_parallel.py +++ b/examples/offline_inference/data_parallel.py @@ -95,7 +95,7 @@ def parse_args(): parser.add_argument( "--compilation-config", type=int, - help=("Compilation optimization (O) level 0-3."), + help=("Compilation optimization (O) mode 0-3."), ) parser.add_argument( "--quantization", diff --git a/tests/compile/piecewise/test_multiple_graphs.py b/tests/compile/piecewise/test_multiple_graphs.py index 0d265bc596386..d1f741479acf4 100644 --- a/tests/compile/piecewise/test_multiple_graphs.py +++ b/tests/compile/piecewise/test_multiple_graphs.py @@ -14,7 +14,7 @@ from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import ignore_torch_compile, support_torch_compile from vllm.config import ( CompilationConfig, - CompilationLevel, + CompilationMode, CUDAGraphMode, VllmConfig, set_current_vllm_config, @@ -199,10 +199,10 @@ def test_multi_graph_piecewise_compile(use_inductor_graph_partition: bool): outputs = [] - # piecewise compile + # vllmcompile compile vllm_config = VllmConfig( compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, + mode=CompilationMode.VLLM_COMPILE, use_cudagraph=True, splitting_ops=["silly::attention"], cudagraph_capture_sizes=[1, 2], @@ -251,7 +251,7 @@ def test_multi_graph_piecewise_compile(use_inductor_graph_partition: bool): # no compile or cudagraph vllm_config = VllmConfig( compilation_config=CompilationConfig( - level=CompilationLevel.NO_COMPILATION, + mode=CompilationMode.NONE, ) ) cudagraph_runtime_mode = CUDAGraphMode.NONE @@ -280,7 +280,7 @@ def test_multi_graph_piecewise_compile(use_inductor_graph_partition: bool): # piecewise compile without CUDA graph vllm_config = VllmConfig( compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, + mode=CompilationMode.VLLM_COMPILE, use_cudagraph=False, splitting_ops=["silly::attention"], use_inductor_graph_partition=use_inductor_graph_partition, diff --git a/tests/compile/piecewise/test_simple.py b/tests/compile/piecewise/test_simple.py index bc65e3da0ae74..f61a0a4eb740d 100644 --- a/tests/compile/piecewise/test_simple.py +++ b/tests/compile/piecewise/test_simple.py @@ -13,7 +13,7 @@ from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import support_torch_compile from vllm.config import ( CompilationConfig, - CompilationLevel, + CompilationMode, CUDAGraphMode, VllmConfig, set_current_vllm_config, @@ -61,7 +61,7 @@ def _run_simple_model( ): vllm_config = VllmConfig( compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, + mode=CompilationMode.VLLM_COMPILE, use_cudagraph=True, use_inductor=use_inductor, splitting_ops=splitting_ops, diff --git a/tests/compile/piecewise/test_toy_llama.py b/tests/compile/piecewise/test_toy_llama.py index 7ab610fa78115..75a89d692fa8f 100644 --- a/tests/compile/piecewise/test_toy_llama.py +++ b/tests/compile/piecewise/test_toy_llama.py @@ -21,7 +21,7 @@ from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import support_torch_compile from vllm.config import ( CompilationConfig, - CompilationLevel, + CompilationMode, CUDAGraphMode, VllmConfig, set_current_vllm_config, @@ -356,13 +356,13 @@ def test_toy_llama( ) compile_config_no_compile = CompilationConfig( - level=CompilationLevel.NO_COMPILATION, + level=CompilationMode.NONE, cudagraph_mode=CUDAGraphMode.NONE, backend="eager", ) compile_config_no_split = CompilationConfig( - level=CompilationLevel.PIECEWISE, + level=CompilationMode.VLLM_COMPILE, use_inductor_graph_partition=use_inductor_graph_partition, cudagraph_mode=CUDAGraphMode.PIECEWISE, backend=backend, @@ -458,14 +458,14 @@ def benchmark(): for piecewise in [False, True]: if piecewise: compilation_config = CompilationConfig( - level=CompilationLevel.PIECEWISE, + mode=CompilationMode.VLLM_COMPILE, use_cudagraph=True, splitting_ops=["silly::attention"], cudagraph_capture_sizes=cudagraph_sizes, ) else: compilation_config = CompilationConfig( - level=CompilationLevel.PIECEWISE, + mode=CompilationMode.VLLM_COMPILE, cudagraph_capture_sizes=cudagraph_sizes, ) diff --git a/tests/compile/test_aot_compile.py b/tests/compile/test_aot_compile.py index 08f79d90cd367..1701d85fe84e7 100644 --- a/tests/compile/test_aot_compile.py +++ b/tests/compile/test_aot_compile.py @@ -10,7 +10,7 @@ import torch from vllm.compilation.decorators import support_torch_compile from vllm.config import ( CompilationConfig, - CompilationLevel, + CompilationMode, VllmConfig, set_current_vllm_config, ) @@ -38,7 +38,7 @@ class CompiledMod(torch.nn.Module): def make_vllm_config() -> VllmConfig: return VllmConfig( compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, + level=CompilationMode.VLLM_COMPILE, ) ) diff --git a/tests/compile/test_async_tp.py b/tests/compile/test_async_tp.py index 102a929bf2409..60856f5a58067 100644 --- a/tests/compile/test_async_tp.py +++ b/tests/compile/test_async_tp.py @@ -10,6 +10,7 @@ import vllm.envs as envs from vllm.compilation.collective_fusion import AsyncTPPass from vllm.config import ( CompilationConfig, + CompilationMode, DeviceConfig, ModelConfig, PassConfig, @@ -400,7 +401,7 @@ def test_async_tp_pass_correctness( common_args.append("--enforce-eager") compilation_config = { - "level": 3, + "mode": CompilationMode.VLLM_COMPILE, "compile_sizes": [2, 4, 8], "splitting_ops": [], "pass_config": {"enable_async_tp": async_tp_enabled}, diff --git a/tests/compile/test_basic_correctness.py b/tests/compile/test_basic_correctness.py index ab6a17e149fcd..954774a8e3983 100644 --- a/tests/compile/test_basic_correctness.py +++ b/tests/compile/test_basic_correctness.py @@ -4,7 +4,7 @@ import dataclasses import pytest -from vllm.config import CompilationLevel +from vllm.config import CompilationMode from vllm.utils import cuda_device_count_stateless from ..utils import compare_all_settings @@ -21,7 +21,7 @@ class TestSetting: # we cannot afford testing the full Cartesian product -# of all models and all levels +# of all models and all modes @pytest.mark.parametrize( "test_setting", [ @@ -121,15 +121,13 @@ def test_compile_correctness( all_args: list[list[str]] = [] all_envs: list[dict[str, str] | None] = [] - for comp_level in [ - CompilationLevel.DYNAMO_AS_IS, - CompilationLevel.DYNAMO_ONCE, - CompilationLevel.PIECEWISE, + for comp_mode in [ + CompilationMode.STOCK_TORCH_COMPILE, + CompilationMode.DYNAMO_TRACE_ONCE, + CompilationMode.VLLM_COMPILE, ]: - for level in [CompilationLevel.NO_COMPILATION, comp_level]: - all_args.append( - final_args + [f"-O.level={level}", "-O.backend=inductor"] - ) + for mode in [CompilationMode.NONE, comp_mode]: + all_args.append(final_args + [f"-O.mode={mode}", "-O.backend=inductor"]) # inductor will change the output, so we only compare if the output # is close, not exactly the same. @@ -142,13 +140,13 @@ def test_compile_correctness( all_envs.clear() all_args.clear() - for level in [ - CompilationLevel.NO_COMPILATION, - CompilationLevel.DYNAMO_AS_IS, - CompilationLevel.DYNAMO_ONCE, - CompilationLevel.PIECEWISE, + for mode in [ + CompilationMode.NONE, + CompilationMode.STOCK_TORCH_COMPILE, + CompilationMode.DYNAMO_TRACE_ONCE, + CompilationMode.VLLM_COMPILE, ]: - all_args.append(final_args + [f"-O.level={level}", "-O.backend=eager"]) + all_args.append(final_args + [f"-O.mode={mode}", "-O.backend=eager"]) all_envs.append({}) all_envs.append({}) diff --git a/tests/compile/test_config.py b/tests/compile/test_config.py index ae8b0b226c313..7f51c763da73c 100644 --- a/tests/compile/test_config.py +++ b/tests/compile/test_config.py @@ -4,7 +4,7 @@ import pytest from vllm.compilation.counter import compilation_counter from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig -from vllm.config.compilation import CompilationLevel +from vllm.config.compilation import CompilationMode from vllm.utils import _is_torch_equal_or_newer, is_torch_equal_or_newer @@ -90,16 +90,16 @@ def test_use_cudagraphs(vllm_runner, monkeypatch, enabled): # forked needed to workaround https://github.com/vllm-project/vllm/issues/21073 @pytest.mark.forked -def test_dynamo_as_is(vllm_runner, monkeypatch): +def test_stock_torch_compile(vllm_runner, monkeypatch): # Disable multiprocessing so that the counter is in the same process monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") with ( - compilation_counter.expect(dynamo_as_is_count=1), + compilation_counter.expect(stock_torch_compile_count=1), # loading the model causes compilation (if enabled) to happen vllm_runner( "facebook/opt-125m", - compilation_config={"level": 1}, + compilation_config={"mode": CompilationMode.STOCK_TORCH_COMPILE}, gpu_memory_utilization=0.4, ) as _, ): @@ -112,11 +112,11 @@ def test_no_compilation(vllm_runner, monkeypatch): # Disable multiprocessing so that the counter is in the same process monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") with ( - compilation_counter.expect(num_graphs_seen=0, dynamo_as_is_count=0), + compilation_counter.expect(num_graphs_seen=0, stock_torch_compile_count=0), # loading the model causes compilation (if enabled) to happen vllm_runner( "facebook/opt-125m", - compilation_config={"level": 0}, + compilation_config={"mode": CompilationMode.NONE}, gpu_memory_utilization=0.4, ) as _, ): @@ -130,7 +130,7 @@ def test_enforce_eager(vllm_runner, monkeypatch): monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") with ( - compilation_counter.expect(num_graphs_seen=0, dynamo_as_is_count=0), + compilation_counter.expect(num_graphs_seen=0, stock_torch_compile_count=0), # loading the model causes compilation (if enabled) to happen vllm_runner( "facebook/opt-125m", enforce_eager=True, gpu_memory_utilization=0.4 @@ -151,7 +151,7 @@ def test_splitting_ops_dynamic(): if is_torch_equal_or_newer("2.9.0.dev"): config = VllmConfig( compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, + level=CompilationMode.VLLM_COMPILE, use_inductor_graph_partition=True, splitting_ops=["vllm::unified_attention"], ) @@ -163,7 +163,7 @@ def test_splitting_ops_dynamic(): # When attn_fusion pass enabled, splitting_ops now default to attention ops. config = VllmConfig( compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, + level=CompilationMode.VLLM_COMPILE, pass_config={"enable_attn_fusion": True, "enable_noop": True}, custom_ops=["+quant_fp8"], cudagraph_mode=CUDAGraphMode.PIECEWISE, @@ -178,7 +178,7 @@ def test_splitting_ops_dynamic(): if is_torch_equal_or_newer("2.9.0.dev"): config = VllmConfig( compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, + level=CompilationMode.VLLM_COMPILE, use_inductor_graph_partition=True, pass_config={"enable_attn_fusion": True, "enable_noop": True}, custom_ops=["+quant_fp8"], diff --git a/tests/compile/test_decorator.py b/tests/compile/test_decorator.py index 63cb266094a12..4d60899a628a9 100644 --- a/tests/compile/test_decorator.py +++ b/tests/compile/test_decorator.py @@ -8,7 +8,7 @@ from vllm.compilation.decorators import ignore_torch_compile, support_torch_comp from vllm.config import ( CacheConfig, CompilationConfig, - CompilationLevel, + CompilationMode, CUDAGraphMode, VllmConfig, set_current_vllm_config, @@ -66,10 +66,10 @@ def run_model( def test_ignore_torch_compile_decorator(): - # piecewise + # vllmcompile vllm_config = VllmConfig( compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, + mode=CompilationMode.VLLM_COMPILE, use_cudagraph=True, splitting_ops=["silly::attention"], cudagraph_capture_sizes=[1, 2], @@ -185,7 +185,7 @@ def test_conditional_compile_enable_if(): kv_sharing_fast_prefill=True, ), compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, + mode=CompilationMode.VLLM_COMPILE, use_cudagraph=True, splitting_ops=["silly::attention"], cudagraph_capture_sizes=[1, 2], @@ -218,7 +218,7 @@ def test_conditional_compile_enable_if(): kv_sharing_fast_prefill=False, ), compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, + mode=CompilationMode.VLLM_COMPILE, use_cudagraph=True, splitting_ops=["silly::attention"], cudagraph_capture_sizes=[1, 2], diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index 2f3794c90b204..2d290771f9ad7 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -12,7 +12,7 @@ from tests.quantization.utils import is_quant_method_supported from vllm import LLM, SamplingParams from vllm.attention.backends.registry import _Backend from vllm.attention.selector import global_force_attn_backend_context_manager -from vllm.config import CompilationConfig, CompilationLevel, CUDAGraphMode, PassConfig +from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, PassConfig from vllm.platforms import current_platform from vllm.utils import is_torch_equal_or_newer @@ -80,22 +80,22 @@ def models_list(*, all: bool = True, keywords: list[str] | None = None): @pytest.mark.parametrize( - "optimization_level", - [CompilationLevel.DYNAMO_ONCE, CompilationLevel.PIECEWISE], + "compilation_mode", + [CompilationMode.DYNAMO_TRACE_ONCE, CompilationMode.VLLM_COMPILE], ) @pytest.mark.parametrize("model_info", models_list(all=True)) @create_new_process_for_each_test() def test_full_graph( monkeypatch: pytest.MonkeyPatch, model_info: tuple[str, dict[str, Any]], - optimization_level: int, + compilation_mode: int, ): model, model_kwargs = model_info with monkeypatch.context(): print(f"MODEL={model}") - run_model(optimization_level, model, model_kwargs) + run_model(compilation_mode, model, model_kwargs) # TODO(luka) add other supported compilation config scenarios here @@ -104,7 +104,7 @@ def test_full_graph( [ # additional compile sizes, only some of the models ( - CompilationConfig(level=CompilationLevel.PIECEWISE, compile_sizes=[1, 2]), + CompilationConfig(mode=CompilationMode.VLLM_COMPILE, compile_sizes=[1, 2]), model, ) for model in models_list(all=False) @@ -113,7 +113,7 @@ def test_full_graph( # RMSNorm + quant fusion, only 8-bit quant models ( CompilationConfig( - level=CompilationLevel.PIECEWISE, + mode=CompilationMode.VLLM_COMPILE, custom_ops=["+rms_norm"], pass_config=PassConfig(enable_fusion=True, enable_noop=True), ), @@ -125,7 +125,8 @@ def test_full_graph( # Test depyf integration works ( CompilationConfig( - level=CompilationLevel.PIECEWISE, debug_dump_path=tempfile.gettempdir() + mode=CompilationMode.VLLM_COMPILE, + debug_dump_path=tempfile.gettempdir(), ), ("facebook/opt-125m", {}), ), @@ -134,7 +135,7 @@ def test_full_graph( # graph inductor partition ( CompilationConfig( - level=CompilationLevel.PIECEWISE, + mode=CompilationMode.VLLM_COMPILE, # inductor graph partition uses # torch._C.Tag.cudagraph_unsafe to specify splitting ops use_inductor_graph_partition=True, @@ -164,10 +165,10 @@ def test_custom_compile_config( @pytest.mark.parametrize( - "optimization_level", - [CompilationLevel.NO_COMPILATION, CompilationLevel.PIECEWISE], + "compilation_mode", + [CompilationMode.NONE, CompilationMode.VLLM_COMPILE], ) -def test_fp8_kv_scale_compile(optimization_level: int): +def test_fp8_kv_scale_compile(compilation_mode: int): model = "Qwen/Qwen2-0.5B" model_kwargs = { "quantization": "fp8", @@ -175,7 +176,7 @@ def test_fp8_kv_scale_compile(optimization_level: int): "calculate_kv_scales": True, "max_model_len": 512, } - run_model(optimization_level, model, model_kwargs) + run_model(compilation_mode, model, model_kwargs) def test_inductor_graph_partition_attn_fusion(caplog_vllm): @@ -184,7 +185,7 @@ def test_inductor_graph_partition_attn_fusion(caplog_vllm): model = "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8" compilation_config = CompilationConfig( - level=CompilationLevel.PIECEWISE, + mode=CompilationMode.VLLM_COMPILE, use_inductor_graph_partition=True, cudagraph_mode=CUDAGraphMode.PIECEWISE, custom_ops=["+quant_fp8"], diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index 7c22336432299..1a5eaf2639b36 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -13,7 +13,7 @@ from vllm.compilation.fusion import ( ) from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.compilation.post_cleanup import PostCleanupPass -from vllm.config import CompilationConfig, CompilationLevel, PassConfig, VllmConfig +from vllm.config import CompilationConfig, CompilationMode, PassConfig, VllmConfig from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, @@ -114,7 +114,7 @@ def test_fusion_rmsnorm_quant( vllm_config = VllmConfig( compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, + mode=CompilationMode.VLLM_COMPILE, custom_ops=["+rms_norm", "+quant_fp8"], pass_config=PassConfig(enable_fusion=True, enable_noop=True), ) diff --git a/tests/compile/test_fusion_all_reduce.py b/tests/compile/test_fusion_all_reduce.py index 455d1bb039057..fbcd6c71fb723 100644 --- a/tests/compile/test_fusion_all_reduce.py +++ b/tests/compile/test_fusion_all_reduce.py @@ -12,7 +12,7 @@ from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.compilation.post_cleanup import PostCleanupPass from vllm.config import ( CompilationConfig, - CompilationLevel, + CompilationMode, DeviceConfig, ModelConfig, PassConfig, @@ -219,7 +219,7 @@ def all_reduce_fusion_pass_on_test_model( vllm_config = VllmConfig( compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm", "+quant_fp8"] + mode=CompilationMode.VLLM_COMPILE, custom_ops=["+rms_norm", "+quant_fp8"] ) ) vllm_config.compilation_config.pass_config = PassConfig( diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index d1ab85cfb875c..a8d78daa32a1d 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -19,7 +19,7 @@ from vllm.compilation.post_cleanup import PostCleanupPass from vllm.config import ( CacheConfig, CompilationConfig, - CompilationLevel, + CompilationMode, ModelConfig, PassConfig, SchedulerConfig, @@ -321,7 +321,7 @@ def test_attention_quant_pattern( ), scheduler_config=SchedulerConfig(max_num_seqs=1024), compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, + mode=CompilationMode.VLLM_COMPILE, custom_ops=["+quant_fp8"], use_inductor_graph_partition=use_inductor_graph_partition, ), diff --git a/tests/compile/test_noop_elimination.py b/tests/compile/test_noop_elimination.py index 188f4514dda5f..0ccc1a0161629 100644 --- a/tests/compile/test_noop_elimination.py +++ b/tests/compile/test_noop_elimination.py @@ -6,7 +6,7 @@ import torch import vllm from vllm.compilation.noop_elimination import NoOpEliminationPass -from vllm.config import CompilationConfig, CompilationLevel, PassConfig, VllmConfig +from vllm.config import CompilationConfig, CompilationMode, PassConfig, VllmConfig from .backend import TestBackend @@ -50,7 +50,7 @@ def test_noop_elimination(dtype, num_tokens, hidden_size, buffer_size): vllm_config = VllmConfig( compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, + mode=CompilationMode.VLLM_COMPILE, pass_config=PassConfig(enable_noop=True), ) ) @@ -98,7 +98,7 @@ def test_non_noop_slice_preserved(): vllm_config = VllmConfig( compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, + mode=CompilationMode.VLLM_COMPILE, pass_config=PassConfig(enable_noop=True), ) ) diff --git a/tests/compile/test_wrapper.py b/tests/compile/test_wrapper.py index b2fff822bbbb5..da0afd9eaa49f 100644 --- a/tests/compile/test_wrapper.py +++ b/tests/compile/test_wrapper.py @@ -5,7 +5,7 @@ import torch from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher -from vllm.config import CompilationLevel +from vllm.config import CompilationMode class MyMod(torch.nn.Module): @@ -20,7 +20,7 @@ class MyWrapper(TorchCompileWrapperWithCustomDispatcher): self.model = model compiled_callable = torch.compile(self.forward, backend="eager") super().__init__( - compiled_callable, compilation_level=CompilationLevel.DYNAMO_ONCE + compiled_callable, compilation_mode=CompilationMode.DYNAMO_TRACE_ONCE ) def forward(self, x: torch.Tensor, cache: torch.Tensor | None = None): diff --git a/tests/distributed/test_sequence_parallel.py b/tests/distributed/test_sequence_parallel.py index a431bf30fc890..362e9daf5ae04 100644 --- a/tests/distributed/test_sequence_parallel.py +++ b/tests/distributed/test_sequence_parallel.py @@ -15,6 +15,7 @@ from typing import Literal, NamedTuple import pytest +from vllm.config.compilation import CompilationMode from vllm.config.model import RunnerOption from vllm.logger import init_logger @@ -234,7 +235,7 @@ def _compare_sp( common_args.append("--skip-tokenizer-init") compilation_config = { - "level": 3, + "mode": CompilationMode.VLLM_COMPILE, "custom_ops": ["+rms_norm"], "compile_sizes": [4, 8], "pass_config": { diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py index 78928a53942f9..c73083b0b5ef6 100644 --- a/tests/engine/test_arg_utils.py +++ b/tests/engine/test_arg_utils.py @@ -226,30 +226,30 @@ def test_compilation_config(): # set to O3 args = parser.parse_args(["-O0"]) - assert args.compilation_config.level == 0 + assert args.compilation_config.mode == 0 # set to O 3 (space) args = parser.parse_args(["-O", "1"]) - assert args.compilation_config.level == 1 + assert args.compilation_config.mode == 1 # set to O 3 (equals) args = parser.parse_args(["-O=2"]) - assert args.compilation_config.level == 2 + assert args.compilation_config.mode == 2 - # set to O.level 3 - args = parser.parse_args(["-O.level", "3"]) - assert args.compilation_config.level == 3 + # set to O.mode 3 + args = parser.parse_args(["-O.mode", "3"]) + assert args.compilation_config.mode == 3 # set to string form of a dict args = parser.parse_args( [ "-O", - '{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], ' + '{"mode": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], ' '"use_inductor": false}', ] ) assert ( - args.compilation_config.level == 3 + args.compilation_config.mode == 3 and args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8] and not args.compilation_config.use_inductor ) @@ -258,12 +258,12 @@ def test_compilation_config(): args = parser.parse_args( [ "--compilation-config=" - '{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], ' + '{"mode": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], ' '"use_inductor": true}', ] ) assert ( - args.compilation_config.level == 3 + args.compilation_config.mode == 3 and args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8] and args.compilation_config.use_inductor ) diff --git a/tests/tpu/test_custom_dispatcher.py b/tests/tpu/test_custom_dispatcher.py index 102e5ddf16d6d..cf455ff3edbd3 100644 --- a/tests/tpu/test_custom_dispatcher.py +++ b/tests/tpu/test_custom_dispatcher.py @@ -3,7 +3,7 @@ import pytest -from vllm.config import CompilationLevel +from vllm.config import CompilationMode from ..utils import compare_two_settings @@ -21,13 +21,13 @@ def test_custom_dispatcher(monkeypatch: pytest.MonkeyPatch): "--max-model-len=256", "--max-num-seqs=32", "--enforce-eager", - f"-O{CompilationLevel.DYNAMO_ONCE}", + f"-O{CompilationMode.DYNAMO_TRACE_ONCE}", ], arg2=[ "--max-model-len=256", "--max-num-seqs=32", "--enforce-eager", - f"-O{CompilationLevel.DYNAMO_AS_IS}", + f"-O{CompilationMode.STOCK_TORCH_COMPILE}", ], env1={}, env2={}, diff --git a/tests/utils_/test_utils.py b/tests/utils_/test_utils.py index 308629ab05834..af5fc758f2c26 100644 --- a/tests/utils_/test_utils.py +++ b/tests/utils_/test_utils.py @@ -299,7 +299,7 @@ def test_dict_args(parser): "val2", "--hf-overrides.key2.key4", "val3", - # Test compile config and compilation level + # Test compile config and compilation mode "-O.use_inductor=true", "-O.backend", "custom", @@ -352,7 +352,7 @@ def test_dict_args(parser): }, } assert parsed_args.compilation_config == { - "level": 1, + "mode": 1, "use_inductor": True, "backend": "custom", "custom_ops": ["-quant_fp8", "+silu_mul", "-rms_norm"], @@ -367,7 +367,7 @@ def test_duplicate_dict_args(caplog_vllm, parser): "--hf-overrides.key1", "val2", "-O1", - "-O.level", + "-O.mode", "2", "-O3", ] @@ -375,12 +375,12 @@ def test_duplicate_dict_args(caplog_vllm, parser): parsed_args = parser.parse_args(args) # Should be the last value assert parsed_args.hf_overrides == {"key1": "val2"} - assert parsed_args.compilation_config == {"level": 3} + assert parsed_args.compilation_config == {"mode": 3} assert len(caplog_vllm.records) == 1 assert "duplicate" in caplog_vllm.text assert "--hf-overrides.key1" in caplog_vllm.text - assert "-O.level" in caplog_vllm.text + assert "-O.mode" in caplog_vllm.text @pytest.mark.parametrize( diff --git a/tests/v1/cudagraph/test_cudagraph_dispatch.py b/tests/v1/cudagraph/test_cudagraph_dispatch.py index 59841a446db3e..02fa27e3f05f7 100644 --- a/tests/v1/cudagraph/test_cudagraph_dispatch.py +++ b/tests/v1/cudagraph/test_cudagraph_dispatch.py @@ -11,7 +11,7 @@ from vllm.compilation.cuda_graph import CUDAGraphWrapper from vllm.compilation.monitor import set_cudagraph_capturing_enabled from vllm.config import ( CompilationConfig, - CompilationLevel, + CompilationMode, CUDAGraphMode, ParallelConfig, SchedulerConfig, @@ -42,7 +42,7 @@ def _create_vllm_config( mock_config.parallel_config = ParallelConfig() # Mimic the behavior of VllmConfig.__post_init__() - if compilation_config.level == CompilationLevel.PIECEWISE: + if compilation_config.mode == CompilationMode.VLLM_COMPILE: compilation_config.set_splitting_ops_for_v1() return mock_config @@ -50,23 +50,23 @@ def _create_vllm_config( class TestCudagraphDispatcher: @pytest.mark.parametrize( - "case_id,cudagraph_mode_str,compilation_level", + "case_id,cudagraph_mode_str,compilation_mode", [ # Test case 0: Full CG for mixed batches, no separate routine - (0, "FULL", CompilationLevel.NO_COMPILATION), + (0, "FULL", CompilationMode.NONE), # Test case 1: Full CG for uniform batches, piecewise for mixed - (1, "FULL_AND_PIECEWISE", CompilationLevel.NO_COMPILATION), + (1, "FULL_AND_PIECEWISE", CompilationMode.NONE), # Test case 2: Full CG for uniform batches, no CG for mixed - (2, "FULL_DECODE_ONLY", CompilationLevel.NO_COMPILATION), - # Test case 3: Piecewise for all - (3, "PIECEWISE", CompilationLevel.PIECEWISE), + (2, "FULL_DECODE_ONLY", CompilationMode.NONE), + # Test case 3: PIECEWISE for all + (3, "PIECEWISE", CompilationMode.VLLM_COMPILE), ], ) - def test_dispatcher(self, cudagraph_mode_str, compilation_level): + def test_dispatcher(self, cudagraph_mode_str, compilation_mode): # Setup dispatcher comp_config = CompilationConfig( cudagraph_mode=cudagraph_mode_str, - level=compilation_level, + mode=compilation_mode, cudagraph_capture_sizes=[1, 8], ) @@ -242,7 +242,7 @@ class TestCudagraphIntegration: def setup_method(self): # only FULL mode for non-uniform batches self.comp_config = CompilationConfig( - level=CompilationLevel.PIECEWISE, + mode=CompilationMode.VLLM_COMPILE, cudagraph_mode="FULL", cudagraph_capture_sizes=[10, 20], ) diff --git a/tests/v1/cudagraph/test_cudagraph_mode.py b/tests/v1/cudagraph/test_cudagraph_mode.py index 8c8148ae20948..818ae1d7ba677 100644 --- a/tests/v1/cudagraph/test_cudagraph_mode.py +++ b/tests/v1/cudagraph/test_cudagraph_mode.py @@ -10,7 +10,7 @@ import pytest from tests.utils import wait_for_gpu_memory_to_clear from tests.v1.attention.utils import full_cg_backend_configs as backend_configs from vllm import LLM -from vllm.config import CompilationConfig +from vllm.config import CompilationConfig, CompilationMode from vllm.platforms import current_platform @@ -73,7 +73,7 @@ def test_backend_and_cudagraph_mode_combo(backend_name, cudagraph_mode, supporte gpu_memory_utilization=0.45, max_model_len=1024, compilation_config=CompilationConfig( - level=3, cudagraph_mode=cudagraph_mode + mode=CompilationMode.VLLM_COMPILE, cudagraph_mode=cudagraph_mode ), ) llm.generate(["Hello, my name is"] * 10) @@ -90,32 +90,27 @@ def test_backend_and_cudagraph_mode_combo(backend_name, cudagraph_mode, supporte ) -# test cudagraph_mode with different compilation level. -# (backend_name, cudagraph_mode, compilation_level, supported) +# test cudagraph_mode with different compilation mode. +# (backend_name, cudagraph_mode, compilation_mode, supported) combo_cases_2 = [ - ("FA2", "FULL", 0, True), # no compilation + full cudagraph - ("FA2", "FULL", 3, True), # piecewise compilation + full cudagraph - ("FA2", "PIECEWISE", 0, False), # no compilation + piecewise cudagraph - ("FA2", "PIECEWISE", 3, True), # piecewise compilation + piecewise cudagraph - ( - "FA2", - "FULL_AND_PIECEWISE", - 0, - False, - ), # piecewise cudagraph not supported without piecewise compilation - ("FA2", "FULL_AND_PIECEWISE", 3, True), - ("FA2", "FULL_DECODE_ONLY", 0, True), - ("FA2", "FULL_DECODE_ONLY", 3, True), - ("FA2", "NONE", 0, True), # no compilation + no cudagraph - ("FA2", "NONE", 3, True), # piecewise compilation + no cudagraph + ("FA2", "FULL", CompilationMode.NONE, True), + ("FA2", "FULL", CompilationMode.VLLM_COMPILE, True), + ("FA2", "PIECEWISE", CompilationMode.NONE, False), + ("FA2", "PIECEWISE", CompilationMode.VLLM_COMPILE, True), + ("FA2", "FULL_AND_PIECEWISE", CompilationMode.NONE, False), + ("FA2", "FULL_AND_PIECEWISE", CompilationMode.VLLM_COMPILE, True), + ("FA2", "FULL_DECODE_ONLY", CompilationMode.NONE, True), + ("FA2", "FULL_DECODE_ONLY", CompilationMode.VLLM_COMPILE, True), + ("FA2", "NONE", CompilationMode.NONE, True), + ("FA2", "NONE", CompilationMode.VLLM_COMPILE, True), ] @pytest.mark.parametrize( - "backend_name,cudagraph_mode,compilation_level,supported", combo_cases_2 + "backend_name,cudagraph_mode,compilation_mode,supported", combo_cases_2 ) def test_cudagraph_compilation_combo(combo_case): - backend_name, cudagraph_mode, compilation_level, supported = combo_case + backend_name, cudagraph_mode, compilation_mode, supported = combo_case env_vars = backend_configs[backend_name].env_vars @@ -130,7 +125,7 @@ def test_cudagraph_compilation_combo(combo_case): gpu_memory_utilization=0.45, max_model_len=1024, compilation_config=CompilationConfig( - level=compilation_level, cudagraph_mode=cudagraph_mode + mode=compilation_mode, cudagraph_mode=cudagraph_mode ), ) llm.generate(["Hello, my name is"] * 10) diff --git a/tests/v1/e2e/test_kv_sharing_fast_prefill.py b/tests/v1/e2e/test_kv_sharing_fast_prefill.py index 89e5f26ac627f..f2c6d1c1fd1a4 100644 --- a/tests/v1/e2e/test_kv_sharing_fast_prefill.py +++ b/tests/v1/e2e/test_kv_sharing_fast_prefill.py @@ -7,7 +7,7 @@ import pytest import torch from vllm import LLM, SamplingParams -from vllm.config import CompilationConfig, CompilationLevel +from vllm.config import CompilationConfig, CompilationMode from vllm.distributed import cleanup_dist_env_and_memory from ...utils import fork_new_process_for_each_test @@ -75,9 +75,9 @@ def test_kv_sharing_fast_prefill( # This allows vLLM compilation backend to handle allocating and # managing buffers for cudagraph cudagraph_copy_inputs=True, - level=CompilationLevel.PIECEWISE + mode=CompilationMode.VLLM_COMPILE if not enforce_eager - else CompilationLevel.NO_COMPILATION, + else CompilationMode.NONE, ) with monkeypatch.context() as m: diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 46c433fe6aefb..91be7e85af518 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -56,7 +56,7 @@ def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface: return InductorAdaptor() else: assert compilation_config.backend == "eager", ( - "Custom backends not supported with CompilationLevel.PIECEWISE" + "Custom backends not supported with CompilationMode.VLLM_COMPILE" ) logger.debug("Using EagerAdaptor") @@ -481,7 +481,7 @@ def set_model_tag(tag: str): class VllmBackend: """The compilation backend for `torch.compile` with vLLM. - It is used for compilation level of `CompilationLevel.PIECEWISE`, + It is used for compilation mode of `CompilationMode.VLLM_COMPILE`, where we customize the compilation. The major work of this backend is to split the graph into diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index 4553007027e39..e2369a635ad1f 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -575,7 +575,7 @@ class InductorAdaptor(CompilerInterface): Because it is re-entrant, we always set it (even if entering via Dynamo and the context was already entered). We might want to revisit if it - should be set at a different level of compilation. + should be set at a different mode of compilation. This is likely a bug in PyTorch: public APIs should not rely on manually setting up internal contexts. But we also rely on non-public diff --git a/vllm/compilation/counter.py b/vllm/compilation/counter.py index 9e8de831bcb29..20918099f169d 100644 --- a/vllm/compilation/counter.py +++ b/vllm/compilation/counter.py @@ -27,8 +27,8 @@ class CompilationCounter: num_cache_entries_updated: int = 0 # The number of standalone_compile compiled artifacts saved num_compiled_artifacts_saved: int = 0 - # Number of times a model was loaded with CompilationLevel.DYNAMO_AS_IS - dynamo_as_is_count: int = 0 + # Number of times a model was loaded with CompilationMode.STOCK_TORCH_COMPILE + stock_torch_compile_count: int = 0 def clone(self) -> "CompilationCounter": return copy.deepcopy(self) diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index fe19d4e851294..20d4681e2c789 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -18,7 +18,7 @@ from torch._dynamo.symbolic_convert import InliningInstructionTranslator import vllm.envs as envs from vllm.compilation.counter import compilation_counter from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher -from vllm.config import CompilationLevel, VllmConfig, set_current_vllm_config +from vllm.config import CompilationMode, VllmConfig, set_current_vllm_config from vllm.logger import init_logger from vllm.sequence import IntermediateTensors from vllm.utils import resolve_obj_by_qualname, supports_dynamo @@ -233,11 +233,11 @@ def _support_torch_compile( old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs) self.vllm_config = vllm_config enable_compile = enable_if is None or enable_if(vllm_config) - # for CompilationLevel.DYNAMO_AS_IS , the upper level model runner + # for CompilationMode.STOCK_TORCH_COMPILE , the upper level model runner # will handle the compilation, so we don't need to do anything here. self.do_not_compile = ( - vllm_config.compilation_config.level - in [CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS] + vllm_config.compilation_config.mode + in [CompilationMode.NONE, CompilationMode.STOCK_TORCH_COMPILE] or not supports_dynamo() or _should_ignore_torch_compile(self.__class__) or not enable_compile @@ -247,7 +247,7 @@ def _support_torch_compile( compilation_counter.num_models_seen += 1 TorchCompileWrapperWithCustomDispatcher.__init__( - self, compilation_level=vllm_config.compilation_config.level + self, compilation_mode=vllm_config.compilation_config.mode ) cls.__init__ = __init__ diff --git a/vllm/compilation/monitor.py b/vllm/compilation/monitor.py index d3c437795fabb..1e6d0e79228b0 100644 --- a/vllm/compilation/monitor.py +++ b/vllm/compilation/monitor.py @@ -3,7 +3,7 @@ import time -from vllm.config import CompilationConfig, CompilationLevel, VllmConfig +from vllm.config import CompilationConfig, CompilationMode, VllmConfig from vllm.logger import init_logger logger = init_logger(__name__) @@ -18,7 +18,7 @@ def start_monitoring_torch_compile(vllm_config: VllmConfig): compilation_config: CompilationConfig = vllm_config.compilation_config path = vllm_config.compile_debug_dump_path() - if compilation_config.level == CompilationLevel.PIECEWISE and path: + if compilation_config.mode == CompilationMode.VLLM_COMPILE and path: import depyf path.mkdir(parents=True, exist_ok=True) @@ -29,7 +29,7 @@ def start_monitoring_torch_compile(vllm_config: VllmConfig): def end_monitoring_torch_compile(vllm_config: VllmConfig): compilation_config: CompilationConfig = vllm_config.compilation_config - if compilation_config.level == CompilationLevel.PIECEWISE: + if compilation_config.mode == CompilationMode.VLLM_COMPILE: logger.info( "torch.compile takes %.2f s in total", compilation_config.compilation_time ) diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py index b4a0d89af0d6d..4b10c85209f63 100644 --- a/vllm/compilation/wrapper.py +++ b/vllm/compilation/wrapper.py @@ -11,7 +11,7 @@ from types import CodeType import torch import vllm.envs as envs -from vllm.config import CompilationLevel, CUDAGraphMode, get_current_vllm_config +from vllm.config import CompilationMode, CUDAGraphMode, get_current_vllm_config from vllm.logger import init_logger logger = init_logger(__name__) @@ -31,7 +31,7 @@ class TorchCompileWrapperWithCustomDispatcher: """ def __init__( - self, compiled_callable: Callable | None = None, compilation_level: int = 0 + self, compiled_callable: Callable | None = None, compilation_mode: int = 0 ): vllm_config = get_current_vllm_config() self.vllm_config = vllm_config @@ -72,7 +72,7 @@ class TorchCompileWrapperWithCustomDispatcher: # subclasses can use this to switch between the custom dispatcher # and the default Dynamo guard mechanism. self.use_custom_dispatcher: bool = ( - compilation_level >= CompilationLevel.DYNAMO_ONCE + compilation_mode >= CompilationMode.DYNAMO_TRACE_ONCE ) def aot_compile(self, *args, **kwargs): @@ -85,7 +85,7 @@ class TorchCompileWrapperWithCustomDispatcher: return self.compiled_callable.aot_compile((args, kwargs)) def __call__(self, *args, **kwargs): - """Implement the dispatch logic here, beyond the torch.compile level. + """Implement the dispatch logic here, beyond the torch.compile mode. NOTE: this function can have additional arguments beyond the forward method, for directly dispatching to the compiled code. """ diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index 6a0197d044dcd..7f1cc52024205 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -4,7 +4,7 @@ from vllm.config.cache import CacheConfig from vllm.config.compilation import ( CompilationConfig, - CompilationLevel, + CompilationMode, CUDAGraphMode, PassConfig, ) @@ -49,7 +49,7 @@ __all__ = [ "CacheConfig", # From vllm.config.compilation "CompilationConfig", - "CompilationLevel", + "CompilationMode", "CUDAGraphMode", "PassConfig", # From vllm.config.device diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index fb80835ba48a1..a34fb0bf920c0 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -26,12 +26,20 @@ else: logger = init_logger(__name__) -class CompilationLevel: - # constants for the levels of the compilation process - NO_COMPILATION = 0 - DYNAMO_AS_IS = 1 - DYNAMO_ONCE = 2 - PIECEWISE = 3 +class CompilationMode: + """The compilation approach used for torch.compile-based compilation of the + model.""" + + NONE = 0 + """No torch.compile compilation is applied, model runs in fully eager pytorch mode. + The model runs as-is.""" + STOCK_TORCH_COMPILE = 1 + """The standard `torch.compile` compilation pipeline.""" + DYNAMO_TRACE_ONCE = 2 + """Single Dynamo trace through the model, avoiding recompilation.""" + VLLM_COMPILE = 3 + """Custom vLLM Inductor-based backend with caching, piecewise compilation, + shape specialization, and custom passes.""" class CUDAGraphMode(enum.Enum): @@ -134,7 +142,7 @@ class CompilationConfig: """Configuration for compilation. It has three parts: - Top-level Compilation control: - - [`level`][vllm.config.CompilationConfig.level] + - [`mode`][vllm.config.CompilationConfig.mode] - [`debug_dump_path`][vllm.config.CompilationConfig.debug_dump_path] - [`cache_dir`][vllm.config.CompilationConfig.cache_dir] - [`backend`][vllm.config.CompilationConfig.backend] @@ -171,14 +179,26 @@ class CompilationConfig: # Top-level Compilation control level: int | None = None - """The level of compilation: + """ + Level is deprecated and will be removed in the next release, + either 0.12.0 or 0.11.2 whichever is soonest. + Please use mode. Currently all levels are mapped to mode. + """ + # Top-level Compilation control + mode: int | None = None + """The compilation approach used for torch.compile-based compilation of the + model. - - None: If None, we will select the default compilation level. - For V1 engine this is 3, for V0 engine this is 0. - - 0: no compilation. - - 1: dynamo as is. - - 2: dynamo once. - - 3: piecewise compilation.""" + - None: If None, we will select the default compilation mode. + For V1 engine this is 3. + - 0: NONE: No torch.compile compilation is applied, model runs in fully + eager pytorch mode. The model runs as-is. + - 1: STOCK_TORCH_COMPILE: The standard `torch.compile` compilation pipeline. + - 2: DYNAMO_TRACE_ONCE: Single Dynamo trace through the model, avoiding + recompilation by removing guards. + Requires no dynamic-shape-dependent control-flow. + - 3: VLLM_COMPILE: Custom vLLM Inductor-based backend with caching, + piecewise compilation, shape specialization, and custom passes.""" debug_dump_path: Path | None = None """The path to dump the debug information.""" cache_dir: str = "" @@ -195,11 +215,11 @@ class CompilationConfig: backend function. We use string to avoid serialization issues when using compilation in a - distributed setting. When the compilation level is 1 or 2, the backend is + distributed setting. When the compilation mode is 1 or 2, the backend is used for the compilation directly (it sees the whole graph). When the - compilation level is 3, the backend is used for the piecewise compilation + compilation mode is 3, the backend is used for the piecewise compilation (it sees a part of the graph). The backend can not be custom for compilation - level 3, i.e. the backend must be either eager or inductor. Furthermore, + mode 3, i.e. the backend must be either eager or inductor. Furthermore, compilation is only piecewise if splitting ops is set accordingly and use_inductor_graph_partition is off. Note that the default options for splitting ops are sufficient for piecewise compilation. @@ -214,7 +234,7 @@ class CompilationConfig: - 'none,+op1,+op2' to enable only op1 and op2 By default, all custom ops are enabled when running without Inductor and - disabled when running with Inductor: level>=PIECEWISE and use_inductor=True. + disabled when running with Inductor: mode>=VLLM_COMPILE and use_inductor=True. Inductor generates (fused) Triton kernels for disabled custom ops.""" splitting_ops: list[str] | None = None """A list of ops to exclude from cudagraphs, used in piecewise compilation. @@ -249,7 +269,7 @@ class CompilationConfig: One graph for symbolic shape and one graph per size in compile_sizes are compiled using configurations in inductor_compile_config. - This setting is ignored if level