[torch.compile] Make inductor partition rules respect splitting_ops #25691 (#25845)

Signed-off-by: baonudesifeizhai <baonudesifeizhai@gmail.com>
Signed-off-by: baonudesifeizhai <85092850+baonudesifeizhai@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
baonudesifeizhai 2025-10-10 12:35:28 -04:00 committed by GitHub
parent e519281920
commit cddce79fda
9 changed files with 267 additions and 112 deletions

View File

@@ -198,7 +198,7 @@ def test_multi_graph_piecewise_compile_outputs_equal():
         compilation_config=CompilationConfig(
             level=CompilationLevel.PIECEWISE,
             use_cudagraph=True,
-            splitting_ops=["silly.attention"],
+            splitting_ops=["silly::attention"],
             cudagraph_capture_sizes=[1, 2],
         )
     )
@@ -267,7 +267,7 @@ def test_multi_graph_piecewise_compile_outputs_equal():
         compilation_config=CompilationConfig(
             level=CompilationLevel.PIECEWISE,
             use_cudagraph=False,
-            splitting_ops=["silly.attention"],
+            splitting_ops=["silly::attention"],
         )
     )
     cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
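Note the recurring rename across these tests: splitting_ops entries move from the dotted form ("silly.attention") to PyTorch's qualified operator format ("namespace::name"). A minimal sketch of where that qualified name comes from, using a stand-in custom op (the "silly" namespace mirrors the test fixtures; the body is a placeholder):

import torch

# Registering under namespace "silly" yields the qualified name
# "silly::attention"; the packet is then reached via attribute access.
@torch.library.custom_op("silly::attention", mutates_args=())
def _attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    return q + k + v  # placeholder body for the sketch

print(torch.ops.silly.attention.default)  # OpOverload for "silly::attention"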

View File

@@ -127,7 +127,7 @@ def _run_simple_model(
 @torch.inference_mode()
 def test_simple_piecewise_compile(use_inductor):
     _run_simple_model(
-        splitting_ops=["silly.attention"],
+        splitting_ops=["silly::attention"],
         use_inductor_graph_partition=False,
         use_inductor=use_inductor,
         # 2 * num_layers + 1
@@ -142,7 +142,7 @@ def test_simple_piecewise_compile(use_inductor):
 @torch.inference_mode()
-@pytest.mark.parametrize("splitting_ops", [["silly.attention"], []])
+@pytest.mark.parametrize("splitting_ops", [["silly::attention"], []])
 def test_simple_inductor_graph_partition(splitting_ops, monkeypatch):
     if not is_torch_equal_or_newer("2.9.0.dev"):
         pytest.skip("inductor graph partition is only available in PyTorch 2.9+")

View File

@@ -268,7 +268,7 @@ def run_model(
         cudagraph_capture_sizes=[1, 2],
     )
     if split_attn:
-        compilation_config.splitting_ops = ["silly.attention"]
+        compilation_config.splitting_ops = ["silly::attention"]
         cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
     else:
         compilation_config = CompilationConfig(
@@ -438,7 +438,7 @@ def benchmark():
         compilation_config = CompilationConfig(
             level=CompilationLevel.PIECEWISE,
             use_cudagraph=True,
-            splitting_ops=["silly.attention"],
+            splitting_ops=["silly::attention"],
             cudagraph_capture_sizes=cudagraph_sizes,
         )
     else:

View File

@@ -4,10 +4,12 @@ import pytest
 from vllm.compilation.counter import compilation_counter
 from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
-from vllm.utils import _is_torch_equal_or_newer
+from vllm.config.compilation import CompilationLevel
+from vllm.utils import _is_torch_equal_or_newer, is_torch_equal_or_newer


 def test_version():
+    # Test the version comparison logic using the private function
     assert _is_torch_equal_or_newer("2.8.0.dev20250624+cu128", "2.8.0.dev")
     assert _is_torch_equal_or_newer("2.8.0a0+gitc82a174", "2.8.0.dev")
     assert _is_torch_equal_or_newer("2.8.0", "2.8.0.dev")
@@ -17,6 +19,9 @@ def test_version():

 def test_use_cudagraphs_dynamic():
     vllm_config = VllmConfig()
+    # Default V1 configuration now starts without cudagraphs enabled; the
+    # engine decides when to capture based on runtime settings instead of a
+    # blanket default.
     assert vllm_config.compilation_config.use_cudagraph
@@ -137,58 +142,77 @@ def test_enforce_eager(vllm_runner, monkeypatch):

 def test_splitting_ops_dynamic():
     # Default config
     config = VllmConfig()
-    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE
-    assert config.compilation_config.splitting_ops_contain_attention()
+    # Default V1 config leaves cudagraph mode unset; splitting ops are only
+    # populated when the engine decides to use piecewise compilation.
+    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
+    assert not config.compilation_config.splitting_ops_contain_attention()

     # When use_inductor_graph_partition=True
-    if _is_torch_equal_or_newer("2.9.0.dev"):
+    if is_torch_equal_or_newer("2.9.0.dev"):
+        # inductor graph partition is only available in PyTorch 2.9+.
+        # this is a fast config check so we are not using pytest.skip.
         config = VllmConfig(
             compilation_config=CompilationConfig(
-                use_inductor_graph_partition=True, splitting_ops=["silly_attention"]
+                level=CompilationLevel.PIECEWISE,
+                use_inductor_graph_partition=True,
+                splitting_ops=["vllm::unified_attention"],
             )
         )
-        # should ignore splitting_ops
-        assert config.compilation_config.splitting_ops == []
+        # with inductor partition we use splitting_ops directly for
+        # partition rules
+        assert config.compilation_config.splitting_ops == ["vllm::unified_attention"]

-    # When attn_fusion pass enabled.
+    # When attn_fusion pass enabled, splitting_ops now default to attention ops.
     config = VllmConfig(
         compilation_config=CompilationConfig(
+            level=CompilationLevel.PIECEWISE,
             pass_config={"enable_attn_fusion": True, "enable_noop": True},
             custom_ops=["+quant_fp8"],
             cudagraph_mode=CUDAGraphMode.PIECEWISE,
         )
     )
-    assert config.compilation_config.splitting_ops == []
-    # cudagraph mode also falls back to FULL
-    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL
+    # With the new simplified logic, attention fusion works with splitting_ops
+    assert config.compilation_config.splitting_ops_contain_attention()
+    # cudagraph mode remains PIECEWISE
+    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE
+
+    # splitting_ops cannot contain attention ops when the attn_fusion
+    # pass is enabled.
+    with pytest.raises(AssertionError):
+        config = VllmConfig(
+            compilation_config=CompilationConfig(
+                pass_config={"enable_attn_fusion": True, "enable_noop": True},
+                custom_ops=["+quant_fp8"],
+                cudagraph_mode=CUDAGraphMode.PIECEWISE,
+                # workaround for accessing all attention ops
+                splitting_ops=CompilationConfig()._attention_ops,
+            )
+        )

     # When both use_inductor_graph_partition and attn_fusion pass enabled.
-    if _is_torch_equal_or_newer("2.9.0.dev"):
+    if is_torch_equal_or_newer("2.9.0.dev"):
         config = VllmConfig(
             compilation_config=CompilationConfig(
+                level=CompilationLevel.PIECEWISE,
                 use_inductor_graph_partition=True,
                 pass_config={"enable_attn_fusion": True, "enable_noop": True},
                 custom_ops=["+quant_fp8"],
                 cudagraph_mode=CUDAGraphMode.PIECEWISE,
             )
         )
-        assert config.compilation_config.splitting_ops == []
-        # enable_attn_fusion is directly supported under
+        # With inductor graph partition, attn_fusion and splitting_ops
+        # work together. Default splitting_ops include attention ops.
+        assert config.compilation_config.splitting_ops_contain_attention()
+        # enable_attn_fusion is directly supported under
         # use_inductor_graph_partition=True, and cudagraph_mode
         # is unchanged.
         assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE
+
+
+def test_resolve_operator_overload():
+    import torch
+
+    from vllm.compilation.partition_rules import resolve_defined_ops
+
+    # Test valid operator names
+    resolved = resolve_defined_ops(["aten::mm.default", "aten::addmm.default"])
+    assert len(resolved) == 2
+    assert resolved[0] is torch.ops.aten.mm.default
+    assert resolved[1] is torch.ops.aten.addmm.default
+
+    # Test that invalid operators are skipped (not raising exceptions)
+    resolved = resolve_defined_ops(
+        [
+            "aten::mm.default",
+            "aten::nonexistent_op.default",  # This should be skipped
+            "aten::addmm.default",
+        ]
+    )
+    assert len(resolved) == 2  # Only 2 valid ops
+    assert resolved[0] is torch.ops.aten.mm.default
+    assert resolved[1] is torch.ops.aten.addmm.default
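The skip-rather-than-raise behavior exercised above exists because torch's own resolver raises on unknown names. A standalone check (assuming a PyTorch build that exposes torch._library.utils.lookup_op, the same helper the new partition_rules module uses):

import torch
from torch._library.utils import lookup_op

assert lookup_op("aten::mm.default") is torch.ops.aten.mm.default

try:
    lookup_op("aten::definitely_not_an_op.default")
except Exception as exc:
    # lookup_op raises here; resolve_defined_ops catches, warns, and skips.
    print(type(exc).__name__)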

View File

@@ -71,7 +71,7 @@ def test_ignore_torch_compile_decorator():
         compilation_config=CompilationConfig(
             level=CompilationLevel.PIECEWISE,
             use_cudagraph=True,
-            splitting_ops=["silly.attention"],
+            splitting_ops=["silly::attention"],
             cudagraph_capture_sizes=[1, 2],
         )
     )
@@ -186,7 +186,7 @@ def test_conditional_compile_enable_if():
         compilation_config=CompilationConfig(
             level=CompilationLevel.PIECEWISE,
             use_cudagraph=True,
-            splitting_ops=["silly.attention"],
+            splitting_ops=["silly::attention"],
             cudagraph_capture_sizes=[1, 2],
         ),
     )
@@ -218,7 +218,7 @@ def test_conditional_compile_enable_if():
         compilation_config=CompilationConfig(
             level=CompilationLevel.PIECEWISE,
             use_cudagraph=True,
-            splitting_ops=["silly.attention"],
+            splitting_ops=["silly::attention"],
             cudagraph_capture_sizes=[1, 2],
         ),
     )

View File

@@ -15,6 +15,11 @@ import torch.fx as fx
 from torch._dispatch.python import enable_python_dispatcher

 import vllm.envs as envs
+from vllm.compilation.inductor_pass import pass_context
+from vllm.compilation.partition_rules import (
+    inductor_partition_rule_context,
+    resolve_defined_ops,
+)
 from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
@@ -76,6 +81,21 @@ class CompilerManager:
     def compute_hash(self, vllm_config: VllmConfig) -> str:
         return self.compiler.compute_hash(vllm_config)

+    @contextmanager
+    def compile_context(self, runtime_shape: Optional[int] = None):
+        """Provide compilation context for the duration of compilation to set
+        any torch global properties we want to scope to a single Inductor
+        compilation (e.g. partition rules, pass context)."""
+        with pass_context(runtime_shape):
+            if self.compilation_config.use_inductor_graph_partition:
+                inductor_partition_ops = resolve_defined_ops(
+                    self.compilation_config.splitting_ops
+                )
+                with inductor_partition_rule_context(inductor_partition_ops):
+                    yield
+            else:
+                yield
+
     def initialize_cache(
         self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
     ):
@@ -197,9 +217,15 @@ class CompilerManager:
             maybe_key = None
         else:
             maybe_key = f"artifact_shape_{runtime_shape}_subgraph_{graph_index}"
-        compiled_graph, handle = self.compiler.compile(
-            graph, example_inputs, additional_inductor_config, runtime_shape, maybe_key
-        )
+
+        with self.compile_context(runtime_shape):
+            compiled_graph, handle = self.compiler.compile(
+                graph,
+                example_inputs,
+                additional_inductor_config,
+                runtime_shape,
+                maybe_key,
+            )

         assert compiled_graph is not None, "Failed to compile the graph"
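For readers tracing the control flow: with inductor graph partition enabled, one compilation now runs under the stacked contexts below. This is a flattened sketch of compile_context, not vLLM source; compiler, splitting_ops, inductor_config, runtime_shape, and key stand in for the manager's attributes and arguments:

with pass_context(runtime_shape):
    # This branch is only taken when use_inductor_graph_partition=True.
    partition_ops = resolve_defined_ops(splitting_ops)
    with inductor_partition_rule_context(partition_ops):
        compiled_graph, handle = compiler.compile(
            graph, example_inputs, inductor_config, runtime_shape, key
        )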
@@ -258,7 +284,7 @@ class SplitItem:

 def split_graph(
-    graph: fx.GraphModule, ops: list[str]
+    graph: fx.GraphModule, resolved_ops: list[torch._ops.OpOverload]
 ) -> tuple[fx.GraphModule, list[SplitItem]]:
     # split graph by ops
     subgraph_id = 0
@@ -267,7 +293,12 @@ def split_graph(
     for node in graph.graph.nodes:
         if node.op in ("output", "placeholder"):
             continue
-        if node.op == "call_function" and str(node.target) in ops:
+        # Match node.target against resolved_ops
+        # node.target can be OpOverloadPacket, need to check .default
+        if node.op == "call_function" and (
+            node.target in resolved_ops
+            or (hasattr(node.target, "default") and node.target.default in resolved_ops)
+        ):
             subgraph_id += 1
             node_to_subgraph_id[node] = subgraph_id
             split_op_graphs.append(subgraph_id)
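The two-way membership test is needed because an FX call_function target can be either an overload packet or a concrete overload, while resolve_defined_ops always returns overloads. A quick illustration with a built-in op:

import torch

packet = torch.ops.aten.mm            # OpOverloadPacket
overload = torch.ops.aten.mm.default  # OpOverload (what resolve_defined_ops returns)

assert isinstance(packet, torch._ops.OpOverloadPacket)
assert isinstance(overload, torch._ops.OpOverload)
assert packet.default is overload     # hence the .default fallback above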
@@ -615,9 +646,14 @@ class VllmBackend:
         self.graph = graph
         self.configure_post_pass()

-        self.split_gm, self.piecewise_graphs = split_graph(
-            graph, self.compilation_config.splitting_ops
-        )
+        if self.compilation_config.use_inductor_graph_partition:
+            # Let Inductor decide partitioning; avoid FX-level pre-splitting.
+            fx_split_ops: list[str] = []
+        else:
+            fx_split_ops = self.compilation_config.splitting_ops or []
+        resolved_split_ops = resolve_defined_ops(fx_split_ops)
+        self.split_gm, self.piecewise_graphs = split_graph(graph, resolved_split_ops)

         from torch._dynamo.utils import lazy_format_graph_code

View File

@@ -17,8 +17,6 @@ from vllm.compilation.counter import compilation_counter
 from vllm.config import VllmConfig
 from vllm.utils import is_torch_equal_or_newer

-from .inductor_pass import pass_context
-

 class CompilerInterface:
     """
@@ -209,13 +207,12 @@ class InductorStandaloneAdaptor(CompilerInterface):
         from torch._inductor import standalone_compile

-        with pass_context(runtime_shape):
-            compiled_graph = standalone_compile(
-                graph,
-                example_inputs,
-                dynamic_shapes=dynamic_shapes,
-                options={"config_patches": current_config},
-            )
+        compiled_graph = standalone_compile(
+            graph,
+            example_inputs,
+            dynamic_shapes=dynamic_shapes,
+            options={"config_patches": current_config},
+        )

         # Save the compiled artifact to disk in the specified path
         assert key is not None
@@ -462,13 +459,12 @@ class InductorAdaptor(CompilerInterface):
             torch._functorch.config.patch(enable_remote_autograd_cache=False)
         )

-        with pass_context(runtime_shape):
-            compiled_graph = compile_fx(
-                graph,
-                example_inputs,
-                inner_compile=hijacked_compile_fx_inner,
-                config_patches=current_config,
-            )
+        compiled_graph = compile_fx(
+            graph,
+            example_inputs,
+            inner_compile=hijacked_compile_fx_inner,
+            config_patches=current_config,
+        )

         # We treat VLLM_DISABLE_COMPILE_CACHE as the overall switch for torch
         # compilation cache. So turn off the checks if we disable the

View File

@@ -0,0 +1,95 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from __future__ import annotations
+
+import contextlib
+from typing import TYPE_CHECKING
+
+from torch._library.utils import lookup_op
+
+from vllm.logger import init_logger
+
+if TYPE_CHECKING:
+    import torch
+
+logger = init_logger(__name__)
+
+
+def resolve_defined_ops(op_names: list[str]) -> list[torch._ops.OpOverload]:
+    """Resolve operator names to OpOverload objects.
+
+    Skips operators that fail to resolve (e.g., operators not registered or
+    model-specific operators not present in the current model).
+
+    Note: Users should inspect the operator graph before lowering and ensure
+    the specified operators are present in the final graph. Built-in PyTorch
+    operators (aten::*, torch::*) may be decomposed, fused, or transformed
+    during Inductor's compilation passes, so use them with caution.
+
+    Args:
+        op_names: List of operator names in PyTorch format
+            (e.g., "vllm::unified_attention")
+
+    Returns:
+        List of successfully resolved operator overloads
+    """
+    resolved = []
+    for op_name in op_names:
+        try:
+            resolved.append(lookup_op(op_name))
+        except Exception:
+            # Skip operators that don't exist (e.g., model-specific ops)
+            logger.warning(
+                "Failed to resolve operator for Inductor partition: %s", op_name
+            )
+            continue
+    return resolved
+
+
+@contextlib.contextmanager
+def inductor_partition_rule_context(overloads: list[torch._ops.OpOverload]):
+    """Context manager to temporarily register Inductor partition rules.
+
+    Registers custom partition rules for specified operators, forcing the
+    Inductor scheduler to partition the graph at these operators. The rules
+    are automatically restored to their previous state on exit.
+
+    Note: Callers should use resolve_defined_ops() to convert operator names
+    to OpOverload objects before calling this function.
+
+    Args:
+        overloads: List of resolved operator overload objects.
+    """
+    if not overloads:
+        logger.debug("No partition ops provided; skipping rule registration.")
+        yield
+        return
+
+    from torch._inductor.scheduler import (  # type: ignore
+        _custom_should_partition_fns,
+        register_should_partition_rule,
+    )
+
+    def _always_partition(*_args, **_kwargs):
+        return True
+
+    # Save current state before registering
+    saved_rules = _custom_should_partition_fns.copy()
+    for overload in overloads:
+        register_should_partition_rule(
+            overload,
+            _always_partition,
+        )
+    logger.debug("Registered inductor partition rules for %d operators", len(overloads))
+
+    try:
+        yield
+    finally:
+        # Clear and restore previous state
+        _custom_should_partition_fns.clear()
+        _custom_should_partition_fns.update(saved_rules)
+        logger.debug("Restored previous partition rules state.")

View File

@@ -209,8 +209,23 @@ class CompilationConfig:
     disabled when running with Inductor: level>=PIECEWISE and use_inductor=True.
     Inductor generates (fused) Triton kernels for disabled custom ops."""
     splitting_ops: Optional[list[str]] = None
-    """A list of ops to split the full graph into subgraphs, used in piecewise
-    compilation."""
+    """A list of ops to exclude from cudagraphs, used in piecewise compilation.
+
+    The behavior depends on use_inductor_graph_partition:
+
+    - When use_inductor_graph_partition=False (default):
+      These ops are used for Dynamo FX-level graph splitting. The graph is
+      split at these ops before Inductor compilation, creating separate
+      subgraphs for cudagraph capture.
+
+    - When use_inductor_graph_partition=True:
+      These ops are used to register Inductor partition rules. The graph
+      partitioning happens at Inductor codegen time after all passes and
+      fusions are finished, allowing compilation and custom passes to operate
+      on the full graph while still excluding these ops from cudagraphs.
+
+    If None, defaults to attention ops for piecewise cudagraphs.
+    If an empty list [], no ops are excluded (suitable for full cudagraphs)."""

     # Inductor capture
     use_inductor: bool = True
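As a concrete configuration, the second mode in the docstring could be requested like this (a sketch; the explicit op list is illustrative, since leaving splitting_ops as None picks the attention defaults):

from vllm.config import CompilationConfig

config = CompilationConfig(
    use_inductor_graph_partition=True,  # partition at Inductor codegen time (torch 2.9+)
    splitting_ops=["vllm::unified_attention"],  # excluded from cudagraphs
)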
@@ -367,18 +382,19 @@ class CompilationConfig:
     model code, e.g., Attention, FusedMOE when dp_size>1."""

     # Attention ops; used for piecewise cudagraphs
+    # Use PyTorch operator format: "namespace::name"
     _attention_ops: ClassVar[list[str]] = [
-        "vllm.unified_attention",
-        "vllm.unified_attention_with_output",
-        "vllm.unified_mla_attention",
-        "vllm.unified_mla_attention_with_output",
-        "vllm.mamba_mixer2",
-        "vllm.mamba_mixer",
-        "vllm.short_conv",
-        "vllm.linear_attention",
-        "vllm.plamo2_mamba_mixer",
-        "vllm.gdn_attention",
-        "vllm.sparse_attn_indexer",
+        "vllm::unified_attention",
+        "vllm::unified_attention_with_output",
+        "vllm::unified_mla_attention",
+        "vllm::unified_mla_attention_with_output",
+        "vllm::mamba_mixer2",
+        "vllm::mamba_mixer",
+        "vllm::short_conv",
+        "vllm::linear_attention",
+        "vllm::plamo2_mamba_mixer",
+        "vllm::gdn_attention",
+        "vllm::sparse_attn_indexer",
     ]

     def compute_hash(self) -> str:
@@ -654,31 +670,25 @@ class CompilationConfig:
     def set_splitting_ops_for_inductor_graph_partition(self):
         assert self.use_inductor_graph_partition
-        use_inductor_graph_partition_msg = (
-            "When use_inductor_graph_partition=True, splitting_ops "
-            "are ignored and set to an empty list. Instead, "
-            '"tags=(torch._C.Tag.cudagraph_unsafe, )," is '
-            "used to annotate custom ops for graph partition."
-        )
-        if self.splitting_ops is not None and len(self.splitting_ops) > 0:
-            logger.warning_once(use_inductor_graph_partition_msg)
-        self.splitting_ops = []
+        if self.splitting_ops is None:
+            self.splitting_ops = list(self._attention_ops)

     def set_splitting_ops_for_attn_fusion(self):
         assert self.pass_config.enable_attn_fusion
-        if self.splitting_ops is None:
-            self.splitting_ops = []
-            if self.cudagraph_mode.has_piecewise_cudagraphs():
-                logger.warning_once(
-                    "enable_attn_fusion is incompatible with piecewise "
-                    "cudagraph when use_inductor_graph_partition is off."
-                    "In this case, splitting_ops will be set to empty "
-                    "list, and cudagraph_mode will be set to FULL. "
-                    "Please ensure you are using attention backends that "
-                    "support cudagraph or set cudagraph_mode to NONE "
-                    "explicitly if encountering any problems."
-                )
-                self.cudagraph_mode = CUDAGraphMode.FULL
+        # For dynamo-partition (non-inductor) attention fusion,
+        # set splitting_ops to empty to avoid splitting at attention ops
+        self.splitting_ops = []
+        if self.cudagraph_mode.has_piecewise_cudagraphs():
+            logger.warning_once(
+                "enable_attn_fusion is incompatible with piecewise "
+                "cudagraph when use_inductor_graph_partition is off. "
+                "In this case, splitting_ops will be set to empty "
+                "list, and cudagraph_mode will be set to FULL. "
+                "Please ensure you are using attention backends that "
+                "support cudagraph or set cudagraph_mode to NONE "
+                "explicitly if encountering any problems."
+            )
+            self.cudagraph_mode = CUDAGraphMode.FULL

         assert not self.splitting_ops_contain_attention(), (
             "attention ops should not be in splitting_ops "
@@ -691,23 +701,17 @@ class CompilationConfig:
         )

     def is_attention_compiled_piecewise(self) -> bool:
-        use_fx_graph_piecewise_compilation = (
-            self.level == CompilationLevel.PIECEWISE
-            and self.splitting_ops_contain_attention()
-        )
+        if not self.splitting_ops_contain_attention():
+            return False

-        inductor_used = (
-            self.level == CompilationLevel.PIECEWISE and self.use_inductor
-        ) or (
-            self.level >= CompilationLevel.DYNAMO_AS_IS and self.backend == "inductor"
-        )
-        use_inductor_piecewise_compilation = (
-            inductor_used
-            and self.use_inductor_graph_partition
-            and not self.splitting_ops_contain_attention()
-        )
-
-        return use_fx_graph_piecewise_compilation or use_inductor_piecewise_compilation
+        if not self.use_inductor_graph_partition:
+            # Dynamo-level FX split case
+            return self.level == CompilationLevel.PIECEWISE
+
+        # Inductor partition case
+        return (
+            self.level > CompilationLevel.NO_COMPILATION and self.backend == "inductor"
+        )

     def custom_op_log_check(self):
         """