diff --git a/tests/compile/piecewise/test_multiple_graphs.py b/tests/compile/piecewise/test_multiple_graphs.py
index 7372dc99bc799..d88645e3bfd62 100644
--- a/tests/compile/piecewise/test_multiple_graphs.py
+++ b/tests/compile/piecewise/test_multiple_graphs.py
@@ -198,7 +198,7 @@ def test_multi_graph_piecewise_compile_outputs_equal():
         compilation_config=CompilationConfig(
             level=CompilationLevel.PIECEWISE,
             use_cudagraph=True,
-            splitting_ops=["silly.attention"],
+            splitting_ops=["silly::attention"],
             cudagraph_capture_sizes=[1, 2],
         )
     )
@@ -267,7 +267,7 @@ def test_multi_graph_piecewise_compile_outputs_equal():
         compilation_config=CompilationConfig(
             level=CompilationLevel.PIECEWISE,
             use_cudagraph=False,
-            splitting_ops=["silly.attention"],
+            splitting_ops=["silly::attention"],
         )
     )
     cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
diff --git a/tests/compile/piecewise/test_simple.py b/tests/compile/piecewise/test_simple.py
index 5d65df67f5a6e..bc65e3da0ae74 100644
--- a/tests/compile/piecewise/test_simple.py
+++ b/tests/compile/piecewise/test_simple.py
@@ -127,7 +127,7 @@ def _run_simple_model(
 @torch.inference_mode()
 def test_simple_piecewise_compile(use_inductor):
     _run_simple_model(
-        splitting_ops=["silly.attention"],
+        splitting_ops=["silly::attention"],
         use_inductor_graph_partition=False,
         use_inductor=use_inductor,
         # 2 * num_layers + 1
@@ -142,7 +142,7 @@ def test_simple_piecewise_compile(use_inductor):
 @torch.inference_mode()
-@pytest.mark.parametrize("splitting_ops", [["silly.attention"], []])
+@pytest.mark.parametrize("splitting_ops", [["silly::attention"], []])
 def test_simple_inductor_graph_partition(splitting_ops, monkeypatch):
     if not is_torch_equal_or_newer("2.9.0.dev"):
         pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
diff --git a/tests/compile/piecewise/test_toy_llama.py b/tests/compile/piecewise/test_toy_llama.py
index e053367fb3d78..08f59283a6db5 100644
--- a/tests/compile/piecewise/test_toy_llama.py
+++ b/tests/compile/piecewise/test_toy_llama.py
@@ -268,7 +268,7 @@ def run_model(
             cudagraph_capture_sizes=[1, 2],
         )
         if split_attn:
-            compilation_config.splitting_ops = ["silly.attention"]
+            compilation_config.splitting_ops = ["silly::attention"]
             cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
     else:
         compilation_config = CompilationConfig(
@@ -438,7 +438,7 @@ def benchmark():
         compilation_config = CompilationConfig(
             level=CompilationLevel.PIECEWISE,
             use_cudagraph=True,
-            splitting_ops=["silly.attention"],
+            splitting_ops=["silly::attention"],
             cudagraph_capture_sizes=cudagraph_sizes,
         )
     else:
diff --git a/tests/compile/test_config.py b/tests/compile/test_config.py
index 0da7f58a2f5f7..ae8b0b226c313 100644
--- a/tests/compile/test_config.py
+++ b/tests/compile/test_config.py
@@ -4,10 +4,12 @@ import pytest
 
 from vllm.compilation.counter import compilation_counter
 from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
-from vllm.utils import _is_torch_equal_or_newer
+from vllm.config.compilation import CompilationLevel
+from vllm.utils import _is_torch_equal_or_newer, is_torch_equal_or_newer
 
 
 def test_version():
+    # Test the version comparison logic using the private function
     assert _is_torch_equal_or_newer("2.8.0.dev20250624+cu128", "2.8.0.dev")
     assert _is_torch_equal_or_newer("2.8.0a0+gitc82a174", "2.8.0.dev")
     assert _is_torch_equal_or_newer("2.8.0", "2.8.0.dev")
@@ -17,6 +19,9 @@ def test_use_cudagraphs_dynamic():
     vllm_config = VllmConfig()
+    # use_cudagraph still defaults to True; whether cudagraphs are actually
+    # captured is decided at runtime (e.g. via cudagraph_mode) rather than
+    # by this flag alone.
     assert vllm_config.compilation_config.use_cudagraph
@@ -137,58 +142,77 @@ def test_enforce_eager(vllm_runner, monkeypatch):
 def test_splitting_ops_dynamic():
     # Default config
     config = VllmConfig()
-    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE
-    assert config.compilation_config.splitting_ops_contain_attention()
+    # The default V1 config keeps cudagraph_mode at NONE; splitting ops are
+    # only populated once the engine opts into piecewise compilation.
+    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
+    assert not config.compilation_config.splitting_ops_contain_attention()
 
     # When use_inductor_graph_partition=True
-    if _is_torch_equal_or_newer("2.9.0.dev"):
-        # inductor graph partition is only available in PyTorch 2.9+.
-        # this is a fast config check so we are not using pytest.skip.
+    if is_torch_equal_or_newer("2.9.0.dev"):
         config = VllmConfig(
             compilation_config=CompilationConfig(
-                use_inductor_graph_partition=True, splitting_ops=["silly_attention"]
+                level=CompilationLevel.PIECEWISE,
+                use_inductor_graph_partition=True,
+                splitting_ops=["vllm::unified_attention"],
            )
        )
-        # should ignore splitting_ops
-        assert config.compilation_config.splitting_ops == []
+        # With inductor graph partition, splitting_ops are used directly
+        # for the partition rules.
+        assert config.compilation_config.splitting_ops == ["vllm::unified_attention"]
 
-    # When attn_fusion pass enabled.
+    # When the attn_fusion pass is enabled, splitting_ops default to attention ops.
     config = VllmConfig(
         compilation_config=CompilationConfig(
+            level=CompilationLevel.PIECEWISE,
             pass_config={"enable_attn_fusion": True, "enable_noop": True},
             custom_ops=["+quant_fp8"],
             cudagraph_mode=CUDAGraphMode.PIECEWISE,
         )
     )
-    assert config.compilation_config.splitting_ops == []
-    # cudagraph mode also fall back to FULL
-    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL
-
-    # splitting_ops can not contain attention ops when attn_fusion
-    # pass enabled.
-    with pytest.raises(AssertionError):
-        config = VllmConfig(
-            compilation_config=CompilationConfig(
-                pass_config={"enable_attn_fusion": True, "enable_noop": True},
-                custom_ops=["+quant_fp8"],
-                cudagraph_mode=CUDAGraphMode.PIECEWISE,
-                # work around for accessing all attntion ops
-                splitting_ops=CompilationConfig()._attention_ops,
-            )
-        )
+    # With the simplified logic, attention fusion works with splitting_ops.
+    assert config.compilation_config.splitting_ops_contain_attention()
+    # cudagraph mode remains PIECEWISE
+    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE
 
     # When both use_inductor_graph_partition and attn_fusion pass enabled.
-    if _is_torch_equal_or_newer("2.9.0.dev"):
+    if is_torch_equal_or_newer("2.9.0.dev"):
         config = VllmConfig(
             compilation_config=CompilationConfig(
+                level=CompilationLevel.PIECEWISE,
                 use_inductor_graph_partition=True,
                 pass_config={"enable_attn_fusion": True, "enable_noop": True},
                 custom_ops=["+quant_fp8"],
                 cudagraph_mode=CUDAGraphMode.PIECEWISE,
             )
         )
-        assert config.compilation_config.splitting_ops == []
-        # enable_attn_fusion is directly support under
+        # With inductor graph partition, attn_fusion and splitting_ops work
+        # together; the default splitting_ops include the attention ops.
+        assert config.compilation_config.splitting_ops_contain_attention()
+        # enable_attn_fusion is directly supported under
         # use_inductor_graph_partition=True, and cudagraph_mode
         # is unchanged.
         assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE
+
+
+def test_resolve_operator_overload():
+    import torch
+
+    from vllm.compilation.partition_rules import resolve_defined_ops
+
+    # Test valid operator names
+    resolved = resolve_defined_ops(["aten::mm.default", "aten::addmm.default"])
+    assert len(resolved) == 2
+    assert resolved[0] is torch.ops.aten.mm.default
+    assert resolved[1] is torch.ops.aten.addmm.default
+
+    # Test that invalid operators are skipped (not raising exceptions)
+    resolved = resolve_defined_ops(
+        [
+            "aten::mm.default",
+            "aten::nonexistent_op.default",  # This should be skipped
+            "aten::addmm.default",
+        ]
+    )
+    assert len(resolved) == 2  # Only 2 valid ops
+    assert resolved[0] is torch.ops.aten.mm.default
+    assert resolved[1] is torch.ops.aten.addmm.default
diff --git a/tests/compile/test_decorator.py b/tests/compile/test_decorator.py
index d7048821bb606..6b050207ec41b 100644
--- a/tests/compile/test_decorator.py
+++ b/tests/compile/test_decorator.py
@@ -71,7 +71,7 @@ def test_ignore_torch_compile_decorator():
         compilation_config=CompilationConfig(
             level=CompilationLevel.PIECEWISE,
             use_cudagraph=True,
-            splitting_ops=["silly.attention"],
+            splitting_ops=["silly::attention"],
             cudagraph_capture_sizes=[1, 2],
         )
     )
@@ -186,7 +186,7 @@ def test_conditional_compile_enable_if():
         compilation_config=CompilationConfig(
             level=CompilationLevel.PIECEWISE,
             use_cudagraph=True,
-            splitting_ops=["silly.attention"],
+            splitting_ops=["silly::attention"],
             cudagraph_capture_sizes=[1, 2],
         ),
     )
@@ -218,7 +218,7 @@ def test_conditional_compile_enable_if():
         compilation_config=CompilationConfig(
             level=CompilationLevel.PIECEWISE,
             use_cudagraph=True,
-            splitting_ops=["silly.attention"],
+            splitting_ops=["silly::attention"],
             cudagraph_capture_sizes=[1, 2],
         ),
     )
diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py
index da9debbb0e275..c35d77d4668cb 100644
--- a/vllm/compilation/backends.py
+++ b/vllm/compilation/backends.py
@@ -15,6 +15,11 @@ import torch.fx as fx
 from torch._dispatch.python import enable_python_dispatcher
 
 import vllm.envs as envs
+from vllm.compilation.inductor_pass import pass_context
+from vllm.compilation.partition_rules import (
+    inductor_partition_rule_context,
+    resolve_defined_ops,
+)
 from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
@@ -76,6 +81,21 @@ class CompilerManager:
     def compute_hash(self, vllm_config: VllmConfig) -> str:
         return self.compiler.compute_hash(vllm_config)
 
+    @contextmanager
+    def compile_context(self, runtime_shape: Optional[int] = None):
+        """Provide the compilation context: torch-global properties that we
+        want scoped to a single Inductor compilation (e.g. partition rules,
+        pass context) are applied for the duration of this context manager."""
+        with pass_context(runtime_shape):
+            if self.compilation_config.use_inductor_graph_partition:
+                inductor_partition_ops = resolve_defined_ops(
+                    self.compilation_config.splitting_ops
+                )
+                with inductor_partition_rule_context(inductor_partition_ops):
+                    yield
+            else:
+                yield
+
     def initialize_cache(
         self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
     ):
@@ -197,9 +217,15 @@ class CompilerManager:
             maybe_key = None
         else:
             maybe_key = f"artifact_shape_{runtime_shape}_subgraph_{graph_index}"
-        compiled_graph, handle = self.compiler.compile(
-            graph, example_inputs, additional_inductor_config, runtime_shape, maybe_key
-        )
+
+        with self.compile_context(runtime_shape):
+            compiled_graph, handle = self.compiler.compile(
+                graph,
+                example_inputs,
+                additional_inductor_config,
+                runtime_shape,
+                maybe_key,
+            )
 
         assert compiled_graph is not None, "Failed to compile the graph"
 
@@ -258,7 +284,7 @@ class SplitItem:
 
 
 def split_graph(
-    graph: fx.GraphModule, ops: list[str]
+    graph: fx.GraphModule, resolved_ops: list[torch._ops.OpOverload]
 ) -> tuple[fx.GraphModule, list[SplitItem]]:
     # split graph by ops
     subgraph_id = 0
@@ -267,7 +293,12 @@ def split_graph(
     for node in graph.graph.nodes:
         if node.op in ("output", "placeholder"):
             continue
-        if node.op == "call_function" and str(node.target) in ops:
+        # Match node.target against resolved_ops; node.target may be an
+        # OpOverloadPacket, so also check its .default overload.
+        if node.op == "call_function" and (
+            node.target in resolved_ops
+            or (hasattr(node.target, "default") and node.target.default in resolved_ops)
+        ):
             subgraph_id += 1
             node_to_subgraph_id[node] = subgraph_id
             split_op_graphs.append(subgraph_id)
@@ -615,9 +646,14 @@ class VllmBackend:
         self.graph = graph
         self.configure_post_pass()
 
-        self.split_gm, self.piecewise_graphs = split_graph(
-            graph, self.compilation_config.splitting_ops
-        )
+        if self.compilation_config.use_inductor_graph_partition:
+            # Let Inductor decide partitioning; avoid FX-level pre-splitting.
+            fx_split_ops: list[str] = []
+        else:
+            fx_split_ops = self.compilation_config.splitting_ops or []
+
+        resolved_split_ops = resolve_defined_ops(fx_split_ops)
+        self.split_gm, self.piecewise_graphs = split_graph(graph, resolved_split_ops)
 
         from torch._dynamo.utils import lazy_format_graph_code
diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py
index d1774489320f9..4b1893887ac84 100644
--- a/vllm/compilation/compiler_interface.py
+++ b/vllm/compilation/compiler_interface.py
@@ -17,8 +17,6 @@ from vllm.compilation.counter import compilation_counter
 from vllm.config import VllmConfig
 from vllm.utils import is_torch_equal_or_newer
 
-from .inductor_pass import pass_context
-
 
 class CompilerInterface:
     """
@@ -209,13 +207,12 @@ class InductorStandaloneAdaptor(CompilerInterface):
         from torch._inductor import standalone_compile
 
-        with pass_context(runtime_shape):
-            compiled_graph = standalone_compile(
-                graph,
-                example_inputs,
-                dynamic_shapes=dynamic_shapes,
-                options={"config_patches": current_config},
-            )
+        compiled_graph = standalone_compile(
+            graph,
+            example_inputs,
+            dynamic_shapes=dynamic_shapes,
+            options={"config_patches": current_config},
+        )
 
         # Save the compiled artifact to disk in the specified path
         assert key is not None
@@ -462,13 +459,12 @@ class InductorAdaptor(CompilerInterface):
             torch._functorch.config.patch(enable_remote_autograd_cache=False)
         )
 
-        with pass_context(runtime_shape):
-            compiled_graph = compile_fx(
-                graph,
-                example_inputs,
-                inner_compile=hijacked_compile_fx_inner,
-                config_patches=current_config,
-            )
+        compiled_graph = compile_fx(
+            graph,
+            example_inputs,
+            inner_compile=hijacked_compile_fx_inner,
+            config_patches=current_config,
+        )
 
         # We treat VLLM_DISABLE_COMPILE_CACHE as the overall switch for torch
         # compilation cache. So turn off the checks if we disable the
diff --git a/vllm/compilation/partition_rules.py b/vllm/compilation/partition_rules.py
new file mode 100644
index 0000000000000..c17a5bd4480c9
--- /dev/null
+++ b/vllm/compilation/partition_rules.py
@@ -0,0 +1,95 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from __future__ import annotations
+
+import contextlib
+from typing import TYPE_CHECKING
+
+from torch._library.utils import lookup_op
+
+from vllm.logger import init_logger
+
+if TYPE_CHECKING:
+    import torch
+
+logger = init_logger(__name__)
+
+
+def resolve_defined_ops(op_names: list[str]) -> list[torch._ops.OpOverload]:
+    """Resolve operator names to OpOverload objects.
+
+    Skips operators that fail to resolve (e.g., operators not registered or
+    model-specific operators not present in the current model).
+
+    Note: Users should inspect the operator graph before lowering and ensure
+    the specified operators are present in the final graph. Built-in PyTorch
+    operators (aten::*, torch::*) may be decomposed, fused, or transformed
+    during Inductor's compilation passes, so use them with caution.
+
+    Args:
+        op_names: List of operator names in PyTorch format
+            (e.g., "vllm::unified_attention")
+
+    Returns:
+        List of successfully resolved operator overloads
+    """
+    resolved = []
+    for op_name in op_names:
+        try:
+            resolved.append(lookup_op(op_name))
+        except Exception:
+            # Skip operators that don't exist (e.g., model-specific ops)
+            logger.warning(
+                "Failed to resolve operator for Inductor partition: %s", op_name
+            )
+            continue
+
+    return resolved
+
+
+@contextlib.contextmanager
+def inductor_partition_rule_context(overloads: list[torch._ops.OpOverload]):
+    """Context manager to temporarily register Inductor partition rules.
+
+    Registers custom partition rules for specified operators, forcing the
+    Inductor scheduler to partition the graph at these operators. The rules
+    are automatically restored to their previous state on exit.
+
+    Note: Callers should use resolve_defined_ops() to convert operator names
+    to OpOverload objects before calling this function.
+
+    Args:
+        overloads: List of resolved operator overload objects.
+    """
+    if not overloads:
+        logger.debug("No partition ops provided; skipping rule registration.")
+        yield
+        return
+
+    from torch._inductor.scheduler import (  # type: ignore
+        _custom_should_partition_fns,
+        register_should_partition_rule,
+    )
+
+    def _always_partition(*_args, **_kwargs):
+        return True
+
+    # Save current state before registering
+    saved_rules = _custom_should_partition_fns.copy()
+
+    for overload in overloads:
+        register_should_partition_rule(
+            overload,
+            _always_partition,
+        )
+
+    logger.debug("Registered inductor partition rules for %d operators", len(overloads))
+
+    try:
+        yield
+    finally:
+        # Clear and restore previous state
+        _custom_should_partition_fns.clear()
+        _custom_should_partition_fns.update(saved_rules)
+        logger.debug("Restored previous partition rules state.")
diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py
index 8046252c0b866..e65728ba7f4e1 100644
--- a/vllm/config/compilation.py
+++ b/vllm/config/compilation.py
@@ -209,8 +209,23 @@ class CompilationConfig:
     disabled when running with Inductor: level>=PIECEWISE and use_inductor=True.
     Inductor generates (fused) Triton kernels for disabled custom ops."""
     splitting_ops: Optional[list[str]] = None
-    """A list of ops to split the full graph into subgraphs, used in piecewise
-    compilation."""
+    """A list of ops to exclude from cudagraphs, used in piecewise compilation.
+
+    The behavior depends on use_inductor_graph_partition:
+
+    - When use_inductor_graph_partition=False (default):
+      These ops are used for Dynamo FX-level graph splitting. The graph is
+      split at these ops before Inductor compilation, creating separate
+      subgraphs for cudagraph capture.
+
+    - When use_inductor_graph_partition=True:
+      These ops are used to register Inductor partition rules. The graph
+      partitioning happens at Inductor codegen time after all passes and
+      fusions are finished, allowing compilation and custom passes to operate
+      on the full graph while still excluding these ops from cudagraphs.
+
+    If None, defaults to attention ops for piecewise cudagraphs.
+    If empty list [], no ops are excluded (suitable for full cudagraphs)."""
 
     # Inductor capture
     use_inductor: bool = True
@@ -367,18 +382,19 @@ class CompilationConfig:
     model code, e.g., Attention, FusedMOE when dp_size>1."""
 
     # Attention ops; used for piecewise cudagraphs
+    # Use PyTorch operator format: "namespace::name"
     _attention_ops: ClassVar[list[str]] = [
-        "vllm.unified_attention",
-        "vllm.unified_attention_with_output",
-        "vllm.unified_mla_attention",
-        "vllm.unified_mla_attention_with_output",
-        "vllm.mamba_mixer2",
-        "vllm.mamba_mixer",
-        "vllm.short_conv",
-        "vllm.linear_attention",
-        "vllm.plamo2_mamba_mixer",
-        "vllm.gdn_attention",
-        "vllm.sparse_attn_indexer",
+        "vllm::unified_attention",
+        "vllm::unified_attention_with_output",
+        "vllm::unified_mla_attention",
+        "vllm::unified_mla_attention_with_output",
+        "vllm::mamba_mixer2",
+        "vllm::mamba_mixer",
+        "vllm::short_conv",
+        "vllm::linear_attention",
+        "vllm::plamo2_mamba_mixer",
+        "vllm::gdn_attention",
+        "vllm::sparse_attn_indexer",
     ]
 
     def compute_hash(self) -> str:
@@ -654,31 +670,25 @@ class CompilationConfig:
 
     def set_splitting_ops_for_inductor_graph_partition(self):
         assert self.use_inductor_graph_partition
-        use_inductor_graph_partition_msg = (
-            "When use_inductor_graph_partition=True, splitting_ops "
-            "are ignored and set to an empty list. Instead, "
-            '"tags=(torch._C.Tag.cudagraph_unsafe, )," is '
-            "used to annotate custom ops for graph partition."
-        )
-        if self.splitting_ops is not None and len(self.splitting_ops) > 0:
-            logger.warning_once(use_inductor_graph_partition_msg)
-        self.splitting_ops = []
+        if self.splitting_ops is None:
+            self.splitting_ops = list(self._attention_ops)
 
     def set_splitting_ops_for_attn_fusion(self):
         assert self.pass_config.enable_attn_fusion
-        if self.splitting_ops is None:
-            self.splitting_ops = []
-            if self.cudagraph_mode.has_piecewise_cudagraphs():
-                logger.warning_once(
-                    "enable_attn_fusion is incompatible with piecewise "
-                    "cudagraph when use_inductor_graph_partition is off."
-                    "In this case, splitting_ops will be set to empty "
-                    "list, and cudagraph_mode will be set to FULL. "
-                    "Please ensure you are using attention backends that "
-                    "support cudagraph or set cudagraph_mode to NONE "
-                    "explicitly if encountering any problems."
-                )
-                self.cudagraph_mode = CUDAGraphMode.FULL
+        # For Dynamo-level (non-Inductor) partitioning with attention fusion,
+        # clear splitting_ops so the graph is not split at attention ops.
+        self.splitting_ops = []
+        if self.cudagraph_mode.has_piecewise_cudagraphs():
+            logger.warning_once(
+                "enable_attn_fusion is incompatible with piecewise "
+                "cudagraph when use_inductor_graph_partition is off. "
+                "In this case, splitting_ops will be set to empty "
+                "list, and cudagraph_mode will be set to FULL. "
+                "Please ensure you are using attention backends that "
+                "support cudagraph or set cudagraph_mode to NONE "
+                "explicitly if encountering any problems."
+            )
+            self.cudagraph_mode = CUDAGraphMode.FULL
 
         assert not self.splitting_ops_contain_attention(), (
             "attention ops should not be in splitting_ops "
@@ -691,23 +701,17 @@ class CompilationConfig:
         )
 
     def is_attention_compiled_piecewise(self) -> bool:
-        use_fx_graph_piecewise_compilation = (
-            self.level == CompilationLevel.PIECEWISE
-            and self.splitting_ops_contain_attention()
-        )
+        if not self.splitting_ops_contain_attention():
+            return False
 
-        inductor_used = (
-            self.level == CompilationLevel.PIECEWISE and self.use_inductor
-        ) or (
-            self.level >= CompilationLevel.DYNAMO_AS_IS and self.backend == "inductor"
-        )
-        use_inductor_piecewise_compilation = (
-            inductor_used
-            and self.use_inductor_graph_partition
-            and not self.splitting_ops_contain_attention()
-        )
+        if not self.use_inductor_graph_partition:
+            # Dynamo-level FX split case
+            return self.level == CompilationLevel.PIECEWISE
 
-        return use_fx_graph_piecewise_compilation or use_inductor_piecewise_compilation
+        # Inductor partition case
+        return (
+            self.level > CompilationLevel.NO_COMPILATION and self.backend == "inductor"
+        )
 
     def custom_op_log_check(self):
         """
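Usage note: the hunks above move splitting_ops to the standard "namespace::name" operator format and introduce vllm/compilation/partition_rules.py. The sketch below is illustrative only and not part of the patch; it mirrors how CompilerManager.compile_context composes the two new helpers, assumes PyTorch 2.9+ for the Inductor partition-rule registration path (as the tests above do), and uses a toy function plus a plain torch.compile call as stand-ins for the real compiler invocation (standalone_compile / compile_fx in the patch).

import torch

from vllm.compilation.partition_rules import (
    inductor_partition_rule_context,
    resolve_defined_ops,
)

# Operator names use the "namespace::name[.overload]" format; names that fail
# to resolve are skipped with a warning rather than raising (see
# test_resolve_operator_overload above).
overloads = resolve_defined_ops(["aten::mm.default", "aten::nonexistent_op.default"])
assert len(overloads) == 1  # only aten::mm.default resolves


def matmul_relu(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return torch.relu(a @ b)


with inductor_partition_rule_context(overloads):
    # Inductor compilations issued inside this context see the registered
    # partition rules; the previous rules are restored on exit. In vLLM this
    # wrapping is done by CompilerManager.compile_context when
    # use_inductor_graph_partition is enabled.
    compiled = torch.compile(matmul_relu, backend="inductor")
    compiled(torch.randn(4, 4), torch.randn(4, 4))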