[torch.compile] Make inductor partition rules respect splitting_ops #25691 (#25845)

Signed-off-by: baonudesifeizhai <baonudesifeizhai@gmail.com>
Signed-off-by: baonudesifeizhai <85092850+baonudesifeizhai@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
baonudesifeizhai 2025-10-10 12:35:28 -04:00 committed by GitHub
parent e519281920
commit cddce79fda
9 changed files with 267 additions and 112 deletions

View File

@@ -198,7 +198,7 @@ def test_multi_graph_piecewise_compile_outputs_equal():
compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_cudagraph=True,
splitting_ops=["silly.attention"],
splitting_ops=["silly::attention"],
cudagraph_capture_sizes=[1, 2],
)
)
@@ -267,7 +267,7 @@ def test_multi_graph_piecewise_compile_outputs_equal():
compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_cudagraph=False,
splitting_ops=["silly.attention"],
splitting_ops=["silly::attention"],
)
)
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE

View File

@@ -127,7 +127,7 @@ def _run_simple_model(
@torch.inference_mode()
def test_simple_piecewise_compile(use_inductor):
_run_simple_model(
splitting_ops=["silly.attention"],
splitting_ops=["silly::attention"],
use_inductor_graph_partition=False,
use_inductor=use_inductor,
# 2 * num_layers + 1
@@ -142,7 +142,7 @@ def test_simple_piecewise_compile(use_inductor):
@torch.inference_mode()
@pytest.mark.parametrize("splitting_ops", [["silly.attention"], []])
@pytest.mark.parametrize("splitting_ops", [["silly::attention"], []])
def test_simple_inductor_graph_partition(splitting_ops, monkeypatch):
if not is_torch_equal_or_newer("2.9.0.dev"):
pytest.skip("inductor graph partition is only available in PyTorch 2.9+")

View File

@@ -268,7 +268,7 @@ def run_model(
cudagraph_capture_sizes=[1, 2],
)
if split_attn:
compilation_config.splitting_ops = ["silly.attention"]
compilation_config.splitting_ops = ["silly::attention"]
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
else:
compilation_config = CompilationConfig(
@@ -438,7 +438,7 @@ def benchmark():
compilation_config = CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_cudagraph=True,
splitting_ops=["silly.attention"],
splitting_ops=["silly::attention"],
cudagraph_capture_sizes=cudagraph_sizes,
)
else:

View File

@@ -4,10 +4,12 @@ import pytest
from vllm.compilation.counter import compilation_counter
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
from vllm.utils import _is_torch_equal_or_newer
from vllm.config.compilation import CompilationLevel
from vllm.utils import _is_torch_equal_or_newer, is_torch_equal_or_newer
def test_version():
# Test the version comparison logic using the private function
assert _is_torch_equal_or_newer("2.8.0.dev20250624+cu128", "2.8.0.dev")
assert _is_torch_equal_or_newer("2.8.0a0+gitc82a174", "2.8.0.dev")
assert _is_torch_equal_or_newer("2.8.0", "2.8.0.dev")
@@ -17,6 +19,9 @@ def test_version():
def test_use_cudagraphs_dynamic():
vllm_config = VllmConfig()
# cudagraphs are enabled by default in the V1 configuration; the engine
# decides at runtime which shapes to actually capture.
assert vllm_config.compilation_config.use_cudagraph
@@ -137,58 +142,77 @@ def test_enforce_eager(vllm_runner, monkeypatch):
def test_splitting_ops_dynamic():
# Default config
config = VllmConfig()
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE
assert config.compilation_config.splitting_ops_contain_attention()
# Default V1 config leaves cudagraph mode unset; splitting ops are only
# populated when the engine decides to use piecewise compilation.
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
assert not config.compilation_config.splitting_ops_contain_attention()
# When use_inductor_graph_partition=True
if _is_torch_equal_or_newer("2.9.0.dev"):
# inductor graph partition is only available in PyTorch 2.9+.
# this is a fast config check so we are not using pytest.skip.
if is_torch_equal_or_newer("2.9.0.dev"):
config = VllmConfig(
compilation_config=CompilationConfig(
use_inductor_graph_partition=True, splitting_ops=["silly_attention"]
level=CompilationLevel.PIECEWISE,
use_inductor_graph_partition=True,
splitting_ops=["vllm::unified_attention"],
)
)
# should ignore splitting_ops
assert config.compilation_config.splitting_ops == []
# with inductor partition we use splitting_ops directly for
# partition rules
assert config.compilation_config.splitting_ops == ["vllm::unified_attention"]
# When attn_fusion pass enabled.
# When attn_fusion pass enabled, splitting_ops now default to attention ops.
config = VllmConfig(
compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
pass_config={"enable_attn_fusion": True, "enable_noop": True},
custom_ops=["+quant_fp8"],
cudagraph_mode=CUDAGraphMode.PIECEWISE,
)
)
assert config.compilation_config.splitting_ops == []
# cudagraph mode also fall back to FULL
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL
# splitting_ops can not contain attention ops when attn_fusion
# pass enabled.
with pytest.raises(AssertionError):
config = VllmConfig(
compilation_config=CompilationConfig(
pass_config={"enable_attn_fusion": True, "enable_noop": True},
custom_ops=["+quant_fp8"],
cudagraph_mode=CUDAGraphMode.PIECEWISE,
# workaround for accessing all attention ops
splitting_ops=CompilationConfig()._attention_ops,
)
)
# With the new simplified logic, attention fusion works with splitting_ops
assert config.compilation_config.splitting_ops_contain_attention()
# cudagraph mode remains PIECEWISE
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE
# When both use_inductor_graph_partition and attn_fusion pass enabled.
if _is_torch_equal_or_newer("2.9.0.dev"):
if is_torch_equal_or_newer("2.9.0.dev"):
config = VllmConfig(
compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_inductor_graph_partition=True,
pass_config={"enable_attn_fusion": True, "enable_noop": True},
custom_ops=["+quant_fp8"],
cudagraph_mode=CUDAGraphMode.PIECEWISE,
)
)
assert config.compilation_config.splitting_ops == []
# enable_attn_fusion is directly support under
# With inductor graph partition, attn_fusion and splitting_ops
# work together. Default splitting_ops include attention ops.
assert config.compilation_config.splitting_ops_contain_attention()
# enable_attn_fusion is directly supported under
# use_inductor_graph_partition=True, and cudagraph_mode
# is unchanged.
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE
def test_resolve_operator_overload():
import torch
from vllm.compilation.partition_rules import resolve_defined_ops
# Test valid operator names
resolved = resolve_defined_ops(["aten::mm.default", "aten::addmm.default"])
assert len(resolved) == 2
assert resolved[0] is torch.ops.aten.mm.default
assert resolved[1] is torch.ops.aten.addmm.default
# Test that invalid operators are skipped (not raising exceptions)
resolved = resolve_defined_ops(
[
"aten::mm.default",
"aten::nonexistent_op.default", # This should be skipped
"aten::addmm.default",
]
)
assert len(resolved) == 2 # Only 2 valid ops
assert resolved[0] is torch.ops.aten.mm.default
assert resolved[1] is torch.ops.aten.addmm.default

View File

@@ -71,7 +71,7 @@ def test_ignore_torch_compile_decorator():
compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_cudagraph=True,
splitting_ops=["silly.attention"],
splitting_ops=["silly::attention"],
cudagraph_capture_sizes=[1, 2],
)
)
@@ -186,7 +186,7 @@ def test_conditional_compile_enable_if():
compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_cudagraph=True,
splitting_ops=["silly.attention"],
splitting_ops=["silly::attention"],
cudagraph_capture_sizes=[1, 2],
),
)
@@ -218,7 +218,7 @@ def test_conditional_compile_enable_if():
compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_cudagraph=True,
splitting_ops=["silly.attention"],
splitting_ops=["silly::attention"],
cudagraph_capture_sizes=[1, 2],
),
)

View File

@@ -15,6 +15,11 @@ import torch.fx as fx
from torch._dispatch.python import enable_python_dispatcher
import vllm.envs as envs
from vllm.compilation.inductor_pass import pass_context
from vllm.compilation.partition_rules import (
inductor_partition_rule_context,
resolve_defined_ops,
)
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
@@ -76,6 +81,21 @@ class CompilerManager:
def compute_hash(self, vllm_config: VllmConfig) -> str:
return self.compiler.compute_hash(vllm_config)
@contextmanager
def compile_context(self, runtime_shape: Optional[int] = None):
"""Provide compilation context for the duration of compilation to set
any torch global properties we want to scope to a single Inductor
compilation (e.g. partition rules, pass context)."""
with pass_context(runtime_shape):
if self.compilation_config.use_inductor_graph_partition:
inductor_partition_ops = resolve_defined_ops(
self.compilation_config.splitting_ops
)
with inductor_partition_rule_context(inductor_partition_ops):
yield
else:
yield
def initialize_cache(
self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
):
@@ -197,9 +217,15 @@ class CompilerManager:
maybe_key = None
else:
maybe_key = f"artifact_shape_{runtime_shape}_subgraph_{graph_index}"
compiled_graph, handle = self.compiler.compile(
graph, example_inputs, additional_inductor_config, runtime_shape, maybe_key
)
with self.compile_context(runtime_shape):
compiled_graph, handle = self.compiler.compile(
graph,
example_inputs,
additional_inductor_config,
runtime_shape,
maybe_key,
)
assert compiled_graph is not None, "Failed to compile the graph"
@@ -258,7 +284,7 @@ class SplitItem:
def split_graph(
graph: fx.GraphModule, ops: list[str]
graph: fx.GraphModule, resolved_ops: list[torch._ops.OpOverload]
) -> tuple[fx.GraphModule, list[SplitItem]]:
# split graph by ops
subgraph_id = 0
@@ -267,7 +293,12 @@ def split_graph(
for node in graph.graph.nodes:
if node.op in ("output", "placeholder"):
continue
if node.op == "call_function" and str(node.target) in ops:
# Match node.target against resolved_ops
# node.target can be OpOverloadPacket, need to check .default
if node.op == "call_function" and (
node.target in resolved_ops
or (hasattr(node.target, "default") and node.target.default in resolved_ops)
):
subgraph_id += 1
node_to_subgraph_id[node] = subgraph_id
split_op_graphs.append(subgraph_id)
@@ -615,9 +646,14 @@ class VllmBackend:
self.graph = graph
self.configure_post_pass()
self.split_gm, self.piecewise_graphs = split_graph(
graph, self.compilation_config.splitting_ops
)
if self.compilation_config.use_inductor_graph_partition:
# Let Inductor decide partitioning; avoid FX-level pre-splitting.
fx_split_ops: list[str] = []
else:
fx_split_ops = self.compilation_config.splitting_ops or []
resolved_split_ops = resolve_defined_ops(fx_split_ops)
self.split_gm, self.piecewise_graphs = split_graph(graph, resolved_split_ops)
from torch._dynamo.utils import lazy_format_graph_code
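
The net effect of the backend changes above, as a condensed, illustrative sketch (not the literal vLLM code; split_graph and resolve_defined_ops are the helpers shown in this diff):

    from vllm.compilation.partition_rules import resolve_defined_ops

    def build_split_graph(graph, compilation_config):
        if compilation_config.use_inductor_graph_partition:
            # Leave the FX graph whole; partitioning happens at Inductor
            # codegen time via the rules registered inside compile_context(),
            # so no FX-level pre-splitting is needed.
            fx_split_ops: list[str] = []
        else:
            # Dynamo/FX path: pre-split at the configured ops so each subgraph
            # can be captured into its own piecewise cudagraph.
            fx_split_ops = compilation_config.splitting_ops or []
        resolved = resolve_defined_ops(fx_split_ops)
        # split_graph() is defined in this same backend module.
        return split_graph(graph, resolved)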

View File

@@ -17,8 +17,6 @@ from vllm.compilation.counter import compilation_counter
from vllm.config import VllmConfig
from vllm.utils import is_torch_equal_or_newer
from .inductor_pass import pass_context
class CompilerInterface:
"""
@@ -209,13 +207,12 @@ class InductorStandaloneAdaptor(CompilerInterface):
from torch._inductor import standalone_compile
with pass_context(runtime_shape):
compiled_graph = standalone_compile(
graph,
example_inputs,
dynamic_shapes=dynamic_shapes,
options={"config_patches": current_config},
)
compiled_graph = standalone_compile(
graph,
example_inputs,
dynamic_shapes=dynamic_shapes,
options={"config_patches": current_config},
)
# Save the compiled artifact to disk in the specified path
assert key is not None
@@ -462,13 +459,12 @@ class InductorAdaptor(CompilerInterface):
torch._functorch.config.patch(enable_remote_autograd_cache=False)
)
with pass_context(runtime_shape):
compiled_graph = compile_fx(
graph,
example_inputs,
inner_compile=hijacked_compile_fx_inner,
config_patches=current_config,
)
compiled_graph = compile_fx(
graph,
example_inputs,
inner_compile=hijacked_compile_fx_inner,
config_patches=current_config,
)
# We treat VLLM_DISABLE_COMPILE_CACHE as the overall switch for torch
# compilation cache. So turn off the checks if we disable the

View File

@@ -0,0 +1,95 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import contextlib
from typing import TYPE_CHECKING
from torch._library.utils import lookup_op
from vllm.logger import init_logger
if TYPE_CHECKING:
import torch
logger = init_logger(__name__)
def resolve_defined_ops(op_names: list[str]) -> list[torch._ops.OpOverload]:
"""Resolve operator names to OpOverload objects.
Skips operators that fail to resolve (e.g., operators not registered or
model-specific operators not present in the current model).
Note: Users should inspect the operator graph before lowering and ensure
the specified operators are present in the final graph. Built-in PyTorch
operators (aten::*, torch::*) may be decomposed, fused, or transformed
during Inductor's compilation passes, so use them with caution.
Args:
op_names: List of operator names in PyTorch format
(e.g., "vllm::unified_attention")
Returns:
List of successfully resolved operator overloads
"""
resolved = []
for op_name in op_names:
try:
resolved.append(lookup_op(op_name))
except Exception:
# Skip operators that don't exist (e.g., model-specific ops)
logger.warning(
"Failed to resolve operator for Inductor partition: %s", op_name
)
continue
return resolved
@contextlib.contextmanager
def inductor_partition_rule_context(overloads: list[torch._ops.OpOverload]):
"""Context manager to temporarily register Inductor partition rules.
Registers custom partition rules for specified operators, forcing the
Inductor scheduler to partition the graph at these operators. The rules
are automatically restored to their previous state on exit.
Note: Callers should use resolve_defined_ops() to convert operator names
to OpOverload objects before calling this function.
Args:
overloads: List of resolved operator overload objects.
"""
if not overloads:
logger.debug("No partition ops provided; skipping rule registration.")
yield
return
from torch._inductor.scheduler import ( # type: ignore
_custom_should_partition_fns,
register_should_partition_rule,
)
def _always_partition(*_args, **_kwargs):
return True
# Save current state before registering
saved_rules = _custom_should_partition_fns.copy()
for overload in overloads:
register_should_partition_rule(
overload,
_always_partition,
)
logger.debug("Registered inductor partition rules for %d operators", len(overloads))
try:
yield
finally:
# Clear and restore previous state
_custom_should_partition_fns.clear()
_custom_should_partition_fns.update(saved_rules)
logger.debug("Restored previous partition rules state.")

View File

@@ -209,8 +209,23 @@ class CompilationConfig:
disabled when running with Inductor: level>=PIECEWISE and use_inductor=True.
Inductor generates (fused) Triton kernels for disabled custom ops."""
splitting_ops: Optional[list[str]] = None
"""A list of ops to split the full graph into subgraphs, used in piecewise
compilation."""
"""A list of ops to exclude from cudagraphs, used in piecewise compilation.
The behavior depends on use_inductor_graph_partition:
- When use_inductor_graph_partition=False (default):
These ops are used for Dynamo FX-level graph splitting. The graph is
split at these ops before Inductor compilation, creating separate
subgraphs for cudagraph capture.
- When use_inductor_graph_partition=True:
These ops are used to register Inductor partition rules. The graph
partitioning happens at Inductor codegen time after all passes and
fusions are finished, allowing compilation and custom passes to operate
on the full graph while still excluding these ops from cudagraphs.
If None, defaults to attention ops for piecewise cudagraphs.
If empty list [], no ops are excluded (suitable for full cudagraphs)."""
# Inductor capture
use_inductor: bool = True
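
The splitting_ops docstring above describes two modes; a hedged illustration of how the field might be set for each, using imports and values taken from the tests in this diff:

    from vllm.config import CompilationConfig
    from vllm.config.compilation import CompilationLevel

    # FX-level splitting (use_inductor_graph_partition=False): the graph is
    # split at these ops before Inductor compilation, and each subgraph gets
    # its own cudagraph.
    fx_split = CompilationConfig(
        level=CompilationLevel.PIECEWISE,
        splitting_ops=["vllm::unified_attention"],
    )

    # Inductor graph partition: the full graph reaches Inductor; partitioning
    # happens at codegen time, after all passes and fusions, using the same
    # op list.
    inductor_partition = CompilationConfig(
        level=CompilationLevel.PIECEWISE,
        use_inductor_graph_partition=True,
        splitting_ops=["vllm::unified_attention"],
    )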
@@ -367,18 +382,19 @@ class CompilationConfig:
model code, e.g., Attention, FusedMOE when dp_size>1."""
# Attention ops; used for piecewise cudagraphs
# Use PyTorch operator format: "namespace::name"
_attention_ops: ClassVar[list[str]] = [
"vllm.unified_attention",
"vllm.unified_attention_with_output",
"vllm.unified_mla_attention",
"vllm.unified_mla_attention_with_output",
"vllm.mamba_mixer2",
"vllm.mamba_mixer",
"vllm.short_conv",
"vllm.linear_attention",
"vllm.plamo2_mamba_mixer",
"vllm.gdn_attention",
"vllm.sparse_attn_indexer",
"vllm::unified_attention",
"vllm::unified_attention_with_output",
"vllm::unified_mla_attention",
"vllm::unified_mla_attention_with_output",
"vllm::mamba_mixer2",
"vllm::mamba_mixer",
"vllm::short_conv",
"vllm::linear_attention",
"vllm::plamo2_mamba_mixer",
"vllm::gdn_attention",
"vllm::sparse_attn_indexer",
]
def compute_hash(self) -> str:
@@ -654,31 +670,25 @@ class CompilationConfig:
def set_splitting_ops_for_inductor_graph_partition(self):
assert self.use_inductor_graph_partition
use_inductor_graph_partition_msg = (
"When use_inductor_graph_partition=True, splitting_ops "
"are ignored and set to an empty list. Instead, "
'"tags=(torch._C.Tag.cudagraph_unsafe, )," is '
"used to annotate custom ops for graph partition."
)
if self.splitting_ops is not None and len(self.splitting_ops) > 0:
logger.warning_once(use_inductor_graph_partition_msg)
self.splitting_ops = []
if self.splitting_ops is None:
self.splitting_ops = list(self._attention_ops)
def set_splitting_ops_for_attn_fusion(self):
assert self.pass_config.enable_attn_fusion
if self.splitting_ops is None:
self.splitting_ops = []
if self.cudagraph_mode.has_piecewise_cudagraphs():
logger.warning_once(
"enable_attn_fusion is incompatible with piecewise "
"cudagraph when use_inductor_graph_partition is off."
"In this case, splitting_ops will be set to empty "
"list, and cudagraph_mode will be set to FULL. "
"Please ensure you are using attention backends that "
"support cudagraph or set cudagraph_mode to NONE "
"explicitly if encountering any problems."
)
self.cudagraph_mode = CUDAGraphMode.FULL
# For dynamo-partition (non-inductor) attention fusion,
# set splitting_ops to empty to avoid splitting at attention ops
self.splitting_ops = []
if self.cudagraph_mode.has_piecewise_cudagraphs():
logger.warning_once(
"enable_attn_fusion is incompatible with piecewise "
"cudagraph when use_inductor_graph_partition is off. "
"In this case, splitting_ops will be set to empty "
"list, and cudagraph_mode will be set to FULL. "
"Please ensure you are using attention backends that "
"support cudagraph or set cudagraph_mode to NONE "
"explicitly if encountering any problems."
)
self.cudagraph_mode = CUDAGraphMode.FULL
assert not self.splitting_ops_contain_attention(), (
"attention ops should not be in splitting_ops "
@@ -691,23 +701,17 @@
)
def is_attention_compiled_piecewise(self) -> bool:
use_fx_graph_piecewise_compilation = (
self.level == CompilationLevel.PIECEWISE
and self.splitting_ops_contain_attention()
)
if not self.splitting_ops_contain_attention():
return False
inductor_used = (
self.level == CompilationLevel.PIECEWISE and self.use_inductor
) or (
self.level >= CompilationLevel.DYNAMO_AS_IS and self.backend == "inductor"
)
use_inductor_piecewise_compilation = (
inductor_used
and self.use_inductor_graph_partition
and not self.splitting_ops_contain_attention()
)
if not self.use_inductor_graph_partition:
# Dynamo-level FX split case
return self.level == CompilationLevel.PIECEWISE
return use_fx_graph_piecewise_compilation or use_inductor_piecewise_compilation
# Inductor partition case
return (
self.level > CompilationLevel.NO_COMPILATION and self.backend == "inductor"
)
def custom_op_log_check(self):
"""