From 0c824fc46fe6f19723ba7a930581951d1c2b35f2 Mon Sep 17 00:00:00 2001
From: Morrison Turnansky
Date: Tue, 7 Oct 2025 15:53:43 -0400
Subject: [PATCH] [Frontend] CompilationConfig overhaul (#20283): deprecate use_inductor in favor of backend, simplify custom_ops (#26113)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: morrison-turnansky
Signed-off-by: Morrison Turnansky
Signed-off-by: Luka Govedič
Co-authored-by: Luka Govedič
Co-authored-by: Jiangyun Zhu
---
 tests/compile/piecewise/test_toy_llama.py     | 31 +++-----
 .../model_executor/test_enabled_custom_ops.py | 40 ++++++-----
 vllm/compilation/backends.py                  |  6 +-
 vllm/config/compilation.py                    | 71 ++++++++++++++++---
 vllm/config/vllm.py                           | 19 +++++
 vllm/model_executor/custom_op.py              | 20 +++---
 vllm/platforms/cpu.py                         |  2 -
 7 files changed, 126 insertions(+), 63 deletions(-)

diff --git a/tests/compile/piecewise/test_toy_llama.py b/tests/compile/piecewise/test_toy_llama.py
index e053367fb3d7..c3aff8ddad49 100644
--- a/tests/compile/piecewise/test_toy_llama.py
+++ b/tests/compile/piecewise/test_toy_llama.py
@@ -258,13 +258,13 @@ def tractable_computation(
 
 @torch.inference_mode
 def run_model(
-    llama_config, use_compile: bool, use_inductor: bool, split_attn: bool = False
+    llama_config, use_compile: bool, backend: str, split_attn: bool = False
 ) -> torch.Tensor:
     if use_compile:
         compilation_config = CompilationConfig(
             level=CompilationLevel.PIECEWISE,
             use_cudagraph=True,
-            use_inductor=use_inductor,
+            backend=backend,
             cudagraph_capture_sizes=[1, 2],
         )
         if split_attn:
@@ -338,8 +338,8 @@ def run_model(
     return output.cpu()
 
 
-@pytest.mark.parametrize("use_inductor", [True, False])
-def test_toy_llama(use_inductor: bool):
+@pytest.mark.parametrize("backend", ["inductor", "eager"])
+def test_toy_llama(backend: str):
     # compare output with and without piecewise compilation
 
     llama_config = LlamaConfig(
@@ -358,10 +358,10 @@ def test_toy_llama(use_inductor: bool):
         num_backend_compilations=0,
         num_cudagraph_captured=0,
     ):
-        outputs.append(run_model(llama_config, use_inductor=False, use_compile=False))
-        run_model(tractable_config, use_inductor=False, use_compile=False)
+        outputs.append(run_model(llama_config, backend="eager", use_compile=False))
+        run_model(tractable_config, backend="eager", use_compile=False)
 
-    if use_inductor:
+    if backend == "inductor":
         kwargs = {"num_inductor_compiles": 1, "num_eager_compiles": 0}
     else:
         kwargs = {"num_eager_compiles": 1, "num_inductor_compiles": 0}
@@ -377,10 +377,8 @@ def test_toy_llama(use_inductor: bool):
         num_cudagraph_captured=2,
         **kwargs,
     ):
-        outputs.append(
-            run_model(llama_config, use_inductor=use_inductor, use_compile=True)
-        )
-        run_model(tractable_config, use_inductor=use_inductor, use_compile=True)
+        outputs.append(run_model(llama_config, backend=backend, use_compile=True))
+        run_model(tractable_config, backend=backend, use_compile=True)
 
     with compilation_counter.expect(
         num_graphs_seen=1,  # one graph for the model
@@ -395,16 +393,9 @@ def test_toy_llama(use_inductor: bool):
         ),  # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
     ):
         outputs.append(
-            run_model(
-                llama_config,
-                use_inductor=use_inductor,
-                use_compile=True,
-                split_attn=True,
-            )
+            run_model(llama_config, backend=backend, use_compile=True, split_attn=True)
         )
-        run_model(
-            tractable_config, use_inductor=use_inductor, use_compile=True, split_attn=True
-        )
+        run_model(tractable_config, backend=backend, use_compile=True, split_attn=True)
 
     for i in range(1, len(outputs)):
         assert torch.allclose(outputs[0], outputs[i])
diff --git a/tests/model_executor/test_enabled_custom_ops.py b/tests/model_executor/test_enabled_custom_ops.py
index 12aad4cb8da0..ab3a3a8268a3 100644
--- a/tests/model_executor/test_enabled_custom_ops.py
+++ b/tests/model_executor/test_enabled_custom_ops.py
@@ -37,57 +37,59 @@ class Relu3(ReLUSquaredActivation):
 
 
 @pytest.mark.parametrize(
-    "env, torch_level, use_inductor, ops_enabled, default_on",
+    "env, torch_level, backend, ops_enabled, default_on",
     [
         # Default values based on compile level
         # - All by default (no Inductor compilation)
-        (None, 0, False, [True] * 4, True),
-        (None, 1, True, [True] * 4, True),
-        (None, 2, False, [True] * 4, True),
+        (None, 0, "eager", [True] * 4, True),
+        (None, 1, "eager", [True] * 4, True),
+        (None, 2, "eager", [True] * 4, True),
+        (None, 3, "eager", [True] * 4, True),
         # - None by default (with Inductor)
-        (None, 3, True, [False] * 4, False),
-        (None, 4, True, [False] * 4, False),
-        # - All by default (without Inductor)
-        (None, 3, False, [True] * 4, True),
-        (None, 4, False, [True] * 4, True),
+        (None, 0, "inductor", [True] * 4, True),
+        # - None by default (with Inductor)
+        (None, 1, "inductor", [False] * 4, False),
+        (None, 2, "inductor", [False] * 4, False),
+        (None, 3, "inductor", [False] * 4, False),
         # Explicitly enabling/disabling
         #
         # Default: all
        #
         # All but SiluAndMul
-        ("+rms_norm,-silu_and_mul", 0, True, [1, 0, 1, 1], True),
+        ("+rms_norm,-silu_and_mul", 0, "inductor", [1, 0, 1, 1], True),
         # Only ReLU3
-        ("none,-rms_norm,+relu3", 1, False, [0, 0, 0, 1], False),
+        ("none,-rms_norm,+relu3", 1, "eager", [0, 0, 0, 1], False),
         # All but SiluAndMul
-        ("all,-silu_and_mul", 2, True, [1, 0, 1, 1], True),
+        ("all,-silu_and_mul", 2, "inductor", [1, 0, 1, 1], True),
         # All but ReLU3 (even if ReLU2 is on)
-        ("-relu3,+relu2", 3, False, [1, 1, 1, 0], True),
+        ("-relu3,+relu2", 3, "eager", [1, 1, 1, 0], True),
         # RMSNorm and SiluAndMul
-        ("none,-relu3,+rms_norm,+silu_and_mul", 4, False, [1, 1, 0, 0], False),
+        ("none,-relu3,+rms_norm,+silu_and_mul", 3, "eager", [1, 1, 0, 0], False),
         # All but RMSNorm
-        ("-rms_norm", 3, False, [0, 1, 1, 1], True),
+        ("-rms_norm", 3, "eager", [0, 1, 1, 1], True),
         #
         # Default: none
         #
         # Only ReLU3
-        ("-silu_and_mul,+relu3", 3, True, [0, 0, 0, 1], False),
+        ("none,+relu3", 3, "inductor", [0, 0, 0, 1], False),
         # All but RMSNorm
-        ("all,-rms_norm", 4, True, [0, 1, 1, 1], True),
+        ("all,-rms_norm", 3, "inductor", [0, 1, 1, 1], True),
     ],
 )
 def test_enabled_ops(
     env: Optional[str],
     torch_level: int,
-    use_inductor: bool,
+    backend: str,
     ops_enabled: list[int],
     default_on: bool,
 ):
     custom_ops = env.split(",") if env else []
     vllm_config = VllmConfig(
         compilation_config=CompilationConfig(
-            use_inductor=bool(use_inductor), level=torch_level, custom_ops=custom_ops
+            backend=backend, level=torch_level, custom_ops=custom_ops
         )
     )
+    # breakpoint()
     with set_current_vllm_config(vllm_config):
         assert CustomOp.default_on() == default_on
 
diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py
index da9debbb0e27..55bd3d0c60b1 100644
--- a/vllm/compilation/backends.py
+++ b/vllm/compilation/backends.py
@@ -34,7 +34,7 @@ logger = init_logger(__name__)
 
 
 def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
-    if compilation_config.use_inductor:
+    if compilation_config.backend == "inductor":
         # Use standalone compile only if requested, version is new enough,
         # and the symbol actually exists in this PyTorch build.
         if (
@@ -48,6 +48,10 @@ def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
             logger.debug("Using InductorAdaptor")
             return InductorAdaptor()
     else:
+        assert compilation_config.backend == "eager", (
+            "Custom backends not supported with CompilationLevel.PIECEWISE"
+        )
+
         logger.debug("Using EagerAdaptor")
         return EagerAdaptor()
 
diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py
index 3443d2e1559e..9346bfa6307a 100644
--- a/vllm/config/compilation.py
+++ b/vllm/config/compilation.py
@@ -180,10 +180,11 @@ class CompilationConfig:
     """The directory to store the compiled graph, to accelerate Inductor
     compilation. By default, it will use model-related information to generate
     a cache directory."""
-    backend: str = ""
+    backend: str = "inductor"
     """The backend for compilation. It needs to be a string:
 
-    - "" (empty string): use the default backend.
+    - "" (empty string): use the default backend ("inductor" on CUDA-alike
+      platforms).
     - "eager"/"openxla"/...: use the specified backend registered in PyTorch.
     - "full.module.name": a qualified name which can be used to import the
@@ -192,7 +193,11 @@ class CompilationConfig:
     distributed setting. When the compilation level is 1 or 2, the backend is
     used for the compilation directly (it sees the whole graph). When the
     compilation level is 3, the backend is used for the piecewise compilation
-    (it sees a part of the graph)."""
+    (it sees a part of the graph). The backend can not be custom for compilation
+    level 3. Furthermore, compilation is only piecewise if splitting ops is set
+    accordingly and use_inductor_cudagraphs_partition is off. Note that the
+    default options for splitting ops are sufficient for piecewise compilation.
+    """
     custom_ops: list[str] = field(default_factory=list)
     """Fine-grained control over which custom ops to enable/disable. Use 'all'
     to enable all, 'none' to disable all. Also specify a list of custom op
@@ -210,8 +215,12 @@ class CompilationConfig:
     compilation."""
 
     # Inductor capture
-    use_inductor: bool = True
-    """Whether to use inductor compilation:
+    use_inductor: Optional[bool] = None
+    """
+    Whether to use inductor compilation.
+
+    This flag is deprecated and will be removed.
+    Please use the 'backend' option instead.
 
     - False: inductor compilation is not used. graph runs in eager
         (custom_ops enabled by default).
@@ -219,7 +228,11 @@ class CompilationConfig:
        One graph for symbolic shape and one graph per size in compile_sizes
        are compiled using configurations in inductor_compile_config.
-    This setting is ignored if level
 Union[str, Callable]:
+        """
+        Initialize the backend for the compilation config from a vllm config.
+        Arguments:
+            vllm_config: The vllm config to initialize the backend from.
+        Returns:
+            The backend for the compilation config.
+        """
+        if self.level is None:
+            raise ValueError(
+                "No compilation level is set. This method should only be \
+                called via vllm config where the level is set if none is \
+                provided."
+            )
         if self.level == CompilationLevel.NO_COMPILATION:
             raise ValueError("No compilation level is set.")
@@ -531,15 +580,15 @@ class CompilationConfig:
         torch_backends = list_backends(exclude_tags=tuple())
 
         if self.level in [CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE]:
-            if self.backend == "":
-                return "eager"
             if self.backend in torch_backends:
                 return self.backend
             return resolve_obj_by_qualname(self.backend)
 
-        # TODO: pass user-specified backend to piecewise compilation
-        # merge with the config use_inductor
         assert self.level == CompilationLevel.PIECEWISE
+        if self.backend not in ["eager", "inductor"]:
+            raise ValueError(
+                f"Invalid backend for piecewise compilation: {self.backend}"
+            )
 
         from vllm.compilation.backends import VllmBackend
 
@@ -692,7 +741,7 @@ class CompilationConfig:
         )
 
         inductor_used = (
-            self.level == CompilationLevel.PIECEWISE and self.use_inductor
+            self.level == CompilationLevel.PIECEWISE and self.backend == "inductor"
         ) or (
             self.level >= CompilationLevel.DYNAMO_AS_IS and self.backend == "inductor"
         )
diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py
index b5856958ce2e..37b8c3fe6677 100644
--- a/vllm/config/vllm.py
+++ b/vllm/config/vllm.py
@@ -318,6 +318,25 @@ class VllmConfig:
             # NB: Passing both --enforce-eager and a compilation level
             # in V0 means the compilation level wins out.
             self.compilation_config.level = CompilationLevel.NO_COMPILATION
+        else:
+            assert self.compilation_config.level >= CompilationLevel.NO_COMPILATION
+            assert self.compilation_config.level <= CompilationLevel.PIECEWISE
+            assert self.compilation_config.level <= 3
+
+        # If user does not set custom ops via none or all set it here based on
+        # compilation level and backend.
+        if (
+            self.compilation_config.custom_ops.count("none")
+            + self.compilation_config.custom_ops.count("all")
+            == 0
+        ):
+            if (
+                self.compilation_config.level > 0
+                and self.compilation_config.backend != "eager"
+            ):
+                self.compilation_config.custom_ops.append("none")
+            else:
+                self.compilation_config.custom_ops.append("all")
 
         # async tp is built on top of sequence parallelism
         # and requires it to be enabled.
diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py
index ad5a09ca970d..6a0ea266378a 100644
--- a/vllm/model_executor/custom_op.py
+++ b/vllm/model_executor/custom_op.py
@@ -114,7 +114,9 @@ class CustomOp(nn.Module):
         custom_ops = compilation_config.custom_ops
         if not hasattr(cls, "name"):
             logger.warning_once(
-                "Custom op %s was not registered, which means it won't appear in the op registry. It will be enabled/disabled based on the global settings.",  # noqa: E501
+                "Custom op %s was not registered, which means it won't appear\
+                in the op registry. It will be enabled/disabled based on the\
+                global settings.",  # noqa: E501
                 cls.__name__,
             )
             return CustomOp.default_on()
@@ -128,19 +130,17 @@ class CustomOp(nn.Module):
    @staticmethod
    def default_on() -> bool:
        """
-        On by default if PyTorch Inductor is not used.
-        Specifying 'all' or 'none' in custom_op takes precedence.
+        Behavior controlled by `CompilationConfig.custom_ops`: On by default if
+        'all', off by default if 'none'.
+        When PyTorch Inductor is used, 'none' is the default value,
+        otherwise 'all'.
         """
-        from vllm.config import CompilationLevel
-
         compilation_config = get_cached_compilation_config()
-        default_on = (
-            compilation_config.level < CompilationLevel.PIECEWISE
-            or not compilation_config.use_inductor
-        )
         count_none = compilation_config.custom_ops.count("none")
         count_all = compilation_config.custom_ops.count("all")
-        return default_on and not count_none > 0 or count_all > 0
+        assert count_none + count_all == 1
+
+        return not count_none > 0 or count_all > 0
 
     # Dictionary of all custom ops (classes, indexed by registered name).
     # To check if an op with a name is enabled, call .enabled() on the class.
diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py
index 2f87664003dc..24e08a8ecbd7 100644
--- a/vllm/platforms/cpu.py
+++ b/vllm/platforms/cpu.py
@@ -274,8 +274,6 @@ class CpuPlatform(Platform):
                 "epilogue_fusion": True,
             }
         )
-        if compilation_config.use_inductor:
-            compilation_config.custom_ops = ["none"]
 
         if vllm_config.lora_config is not None:
             compilation_config.level = CompilationLevel.NO_COMPILATION