commit cddce79fda
parent e519281920

Signed-off-by: baonudesifeizhai <baonudesifeizhai@gmail.com>
Signed-off-by: baonudesifeizhai <85092850+baonudesifeizhai@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
@@ -198,7 +198,7 @@ def test_multi_graph_piecewise_compile_outputs_equal():
         compilation_config=CompilationConfig(
             level=CompilationLevel.PIECEWISE,
             use_cudagraph=True,
-            splitting_ops=["silly.attention"],
+            splitting_ops=["silly::attention"],
             cudagraph_capture_sizes=[1, 2],
         )
     )
@@ -267,7 +267,7 @@ def test_multi_graph_piecewise_compile_outputs_equal():
         compilation_config=CompilationConfig(
             level=CompilationLevel.PIECEWISE,
             use_cudagraph=False,
-            splitting_ops=["silly.attention"],
+            splitting_ops=["silly::attention"],
         )
     )
     cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE

@@ -127,7 +127,7 @@ def _run_simple_model(
 @torch.inference_mode()
 def test_simple_piecewise_compile(use_inductor):
     _run_simple_model(
-        splitting_ops=["silly.attention"],
+        splitting_ops=["silly::attention"],
         use_inductor_graph_partition=False,
         use_inductor=use_inductor,
         # 2 * num_layers + 1
@@ -142,7 +142,7 @@ def test_simple_piecewise_compile(use_inductor):


 @torch.inference_mode()
-@pytest.mark.parametrize("splitting_ops", [["silly.attention"], []])
+@pytest.mark.parametrize("splitting_ops", [["silly::attention"], []])
 def test_simple_inductor_graph_partition(splitting_ops, monkeypatch):
     if not is_torch_equal_or_newer("2.9.0.dev"):
         pytest.skip("inductor graph partition is only available in PyTorch 2.9+")

@@ -268,7 +268,7 @@ def run_model(
         cudagraph_capture_sizes=[1, 2],
     )
     if split_attn:
-        compilation_config.splitting_ops = ["silly.attention"]
+        compilation_config.splitting_ops = ["silly::attention"]
         cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
     else:
         compilation_config = CompilationConfig(
@@ -438,7 +438,7 @@ def benchmark():
         compilation_config = CompilationConfig(
             level=CompilationLevel.PIECEWISE,
             use_cudagraph=True,
-            splitting_ops=["silly.attention"],
+            splitting_ops=["silly::attention"],
             cudagraph_capture_sizes=cudagraph_sizes,
         )
     else:

@@ -4,10 +4,12 @@ import pytest

 from vllm.compilation.counter import compilation_counter
 from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
-from vllm.utils import _is_torch_equal_or_newer
+from vllm.config.compilation import CompilationLevel
+from vllm.utils import _is_torch_equal_or_newer, is_torch_equal_or_newer


 def test_version():
+    # Test the version comparison logic using the private function
     assert _is_torch_equal_or_newer("2.8.0.dev20250624+cu128", "2.8.0.dev")
     assert _is_torch_equal_or_newer("2.8.0a0+gitc82a174", "2.8.0.dev")
     assert _is_torch_equal_or_newer("2.8.0", "2.8.0.dev")
@@ -17,6 +19,9 @@ def test_version():

 def test_use_cudagraphs_dynamic():
     vllm_config = VllmConfig()
+    # Default V1 configuration now starts without cudagraphs enabled; the
+    # engine decides when to capture based on runtime settings instead of a
+    # blanket default.
     assert vllm_config.compilation_config.use_cudagraph

@@ -137,58 +142,77 @@ def test_enforce_eager(vllm_runner, monkeypatch):
 def test_splitting_ops_dynamic():
     # Default config
     config = VllmConfig()
-    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE
-    assert config.compilation_config.splitting_ops_contain_attention()
+    # Default V1 config leaves cudagraph mode unset; splitting ops are only
+    # populated when the engine decides to use piecewise compilation.
+    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
+    assert not config.compilation_config.splitting_ops_contain_attention()

     # When use_inductor_graph_partition=True
-    if _is_torch_equal_or_newer("2.9.0.dev"):
+    # inductor graph partition is only available in PyTorch 2.9+.
+    # this is a fast config check so we are not using pytest.skip.
+    if is_torch_equal_or_newer("2.9.0.dev"):
         config = VllmConfig(
             compilation_config=CompilationConfig(
-                use_inductor_graph_partition=True, splitting_ops=["silly_attention"]
+                level=CompilationLevel.PIECEWISE,
+                use_inductor_graph_partition=True,
+                splitting_ops=["vllm::unified_attention"],
             )
         )
-        # should ignore splitting_ops
-        assert config.compilation_config.splitting_ops == []
+        # with inductor partition we use splitting_ops directly for
+        # partition rules
+        assert config.compilation_config.splitting_ops == ["vllm::unified_attention"]

-    # When attn_fusion pass enabled.
+    # When attn_fusion pass enabled, splitting_ops now default to attention ops.
     config = VllmConfig(
         compilation_config=CompilationConfig(
+            level=CompilationLevel.PIECEWISE,
             pass_config={"enable_attn_fusion": True, "enable_noop": True},
             custom_ops=["+quant_fp8"],
             cudagraph_mode=CUDAGraphMode.PIECEWISE,
         )
     )
-    assert config.compilation_config.splitting_ops == []
-    # cudagraph mode also fall back to FULL
-    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL
-
-    # splitting_ops can not contain attention ops when attn_fusion
-    # pass enabled.
-    with pytest.raises(AssertionError):
-        config = VllmConfig(
-            compilation_config=CompilationConfig(
-                pass_config={"enable_attn_fusion": True, "enable_noop": True},
-                custom_ops=["+quant_fp8"],
-                cudagraph_mode=CUDAGraphMode.PIECEWISE,
-                # work around for accessing all attntion ops
-                splitting_ops=CompilationConfig()._attention_ops,
-            )
-        )
+    # With the new simplified logic, attention fusion works with splitting_ops
+    assert config.compilation_config.splitting_ops_contain_attention()
+    # cudagraph mode remains PIECEWISE
+    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE

     # When both use_inductor_graph_partition and attn_fusion pass enabled.
-    if _is_torch_equal_or_newer("2.9.0.dev"):
+    if is_torch_equal_or_newer("2.9.0.dev"):
         config = VllmConfig(
             compilation_config=CompilationConfig(
+                level=CompilationLevel.PIECEWISE,
                 use_inductor_graph_partition=True,
                 pass_config={"enable_attn_fusion": True, "enable_noop": True},
                 custom_ops=["+quant_fp8"],
                 cudagraph_mode=CUDAGraphMode.PIECEWISE,
             )
         )
-        assert config.compilation_config.splitting_ops == []
-        # enable_attn_fusion is directly support under
+        # With inductor graph partition, attn_fusion and splitting_ops
+        # work together. Default splitting_ops include attention ops.
+        assert config.compilation_config.splitting_ops_contain_attention()
+        # enable_attn_fusion is directly supported under
         # use_inductor_graph_partition=True, and cudagraph_mode
         # is unchanged.
         assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE


+def test_resolve_operator_overload():
+    import torch
+
+    from vllm.compilation.partition_rules import resolve_defined_ops
+
+    # Test valid operator names
+    resolved = resolve_defined_ops(["aten::mm.default", "aten::addmm.default"])
+    assert len(resolved) == 2
+    assert resolved[0] is torch.ops.aten.mm.default
+    assert resolved[1] is torch.ops.aten.addmm.default
+
+    # Test that invalid operators are skipped (not raising exceptions)
+    resolved = resolve_defined_ops(
+        [
+            "aten::mm.default",
+            "aten::nonexistent_op.default",  # This should be skipped
+            "aten::addmm.default",
+        ]
+    )
+    assert len(resolved) == 2  # Only 2 valid ops
+    assert resolved[0] is torch.ops.aten.mm.default
+    assert resolved[1] is torch.ops.aten.addmm.default

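For reference (illustrative, not part of this commit): the identity asserts in the new test hold because a qualified name such as "aten::mm.default" resolves to the single cached OpOverload object reachable through torch.ops. A minimal standalone sketch of that resolution:

    import torch

    def resolve_overload(qualified_name: str) -> torch._ops.OpOverload:
        # "aten::mm.default" -> namespace "aten", packet "mm", overload "default"
        packet_name, _, overload_name = qualified_name.partition(".")
        namespace, op_name = packet_name.split("::")
        packet = getattr(getattr(torch.ops, namespace), op_name)
        return getattr(packet, overload_name or "default")

    assert resolve_overload("aten::mm.default") is torch.ops.aten.mm.default

vLLM's resolve_defined_ops uses torch._library.utils.lookup_op for the same purpose (see the new partition_rules.py below) rather than this manual walk.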
@@ -71,7 +71,7 @@ def test_ignore_torch_compile_decorator():
         compilation_config=CompilationConfig(
             level=CompilationLevel.PIECEWISE,
             use_cudagraph=True,
-            splitting_ops=["silly.attention"],
+            splitting_ops=["silly::attention"],
             cudagraph_capture_sizes=[1, 2],
         )
     )
@@ -186,7 +186,7 @@ def test_conditional_compile_enable_if():
         compilation_config=CompilationConfig(
             level=CompilationLevel.PIECEWISE,
             use_cudagraph=True,
-            splitting_ops=["silly.attention"],
+            splitting_ops=["silly::attention"],
             cudagraph_capture_sizes=[1, 2],
         ),
     )
@@ -218,7 +218,7 @@ def test_conditional_compile_enable_if():
         compilation_config=CompilationConfig(
             level=CompilationLevel.PIECEWISE,
             use_cudagraph=True,
-            splitting_ops=["silly.attention"],
+            splitting_ops=["silly::attention"],
             cudagraph_capture_sizes=[1, 2],
         ),
     )

@@ -15,6 +15,11 @@ import torch.fx as fx
 from torch._dispatch.python import enable_python_dispatcher

 import vllm.envs as envs
+from vllm.compilation.inductor_pass import pass_context
+from vllm.compilation.partition_rules import (
+    inductor_partition_rule_context,
+    resolve_defined_ops,
+)
 from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
@@ -76,6 +81,21 @@ class CompilerManager:
     def compute_hash(self, vllm_config: VllmConfig) -> str:
         return self.compiler.compute_hash(vllm_config)

+    @contextmanager
+    def compile_context(self, runtime_shape: Optional[int] = None):
+        """Provide compilation context for the duration of compilation to set
+        any torch global properties we want to scope to a single Inductor
+        compilation (e.g. partition rules, pass context)."""
+        with pass_context(runtime_shape):
+            if self.compilation_config.use_inductor_graph_partition:
+                inductor_partition_ops = resolve_defined_ops(
+                    self.compilation_config.splitting_ops
+                )
+                with inductor_partition_rule_context(inductor_partition_ops):
+                    yield
+            else:
+                yield
+
     def initialize_cache(
         self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
     ):
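The new compile_context above follows the standard generator-based pattern for conditionally nesting a second context manager. A small standalone sketch of the same shape (hypothetical names, not vLLM API):

    from contextlib import contextmanager

    @contextmanager
    def scoped_property(name: str):
        print(f"set {name}")
        try:
            yield
        finally:
            print(f"restore {name}")

    @contextmanager
    def compile_scope(use_partition_rules: bool):
        # Always scope the pass context; add partition rules only when requested.
        with scoped_property("pass context"):
            if use_partition_rules:
                with scoped_property("partition rules"):
                    yield
            else:
                yield

    with compile_scope(use_partition_rules=True):
        print("compiling")

Either way, both properties are restored as soon as the enclosed compilation finishes, which is what keeps the partition rules from leaking into unrelated compilations.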
@@ -197,9 +217,15 @@ class CompilerManager:
             maybe_key = None
         else:
             maybe_key = f"artifact_shape_{runtime_shape}_subgraph_{graph_index}"
-        compiled_graph, handle = self.compiler.compile(
-            graph, example_inputs, additional_inductor_config, runtime_shape, maybe_key
-        )
+
+        with self.compile_context(runtime_shape):
+            compiled_graph, handle = self.compiler.compile(
+                graph,
+                example_inputs,
+                additional_inductor_config,
+                runtime_shape,
+                maybe_key,
+            )

         assert compiled_graph is not None, "Failed to compile the graph"

@@ -258,7 +284,7 @@ class SplitItem:


 def split_graph(
-    graph: fx.GraphModule, ops: list[str]
+    graph: fx.GraphModule, resolved_ops: list[torch._ops.OpOverload]
 ) -> tuple[fx.GraphModule, list[SplitItem]]:
     # split graph by ops
     subgraph_id = 0
@@ -267,7 +293,12 @@ def split_graph(
     for node in graph.graph.nodes:
         if node.op in ("output", "placeholder"):
             continue
-        if node.op == "call_function" and str(node.target) in ops:
+        # Match node.target against resolved_ops
+        # node.target can be OpOverloadPacket, need to check .default
+        if node.op == "call_function" and (
+            node.target in resolved_ops
+            or (hasattr(node.target, "default") and node.target.default in resolved_ops)
+        ):
             subgraph_id += 1
             node_to_subgraph_id[node] = subgraph_id
             split_op_graphs.append(subgraph_id)
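The extra .default check exists because an FX call_function target may be an OpOverloadPacket rather than a concrete OpOverload. A small illustration (not from this commit):

    import torch

    resolved_ops = [torch.ops.aten.mm.default]  # the kind of list resolve_defined_ops returns

    packet = torch.ops.aten.mm              # OpOverloadPacket: no overload selected yet
    overload = torch.ops.aten.mm.default    # OpOverload: a concrete overload

    assert overload in resolved_ops         # direct match
    assert packet is not overload           # the packet itself would not match ...
    assert packet.default in resolved_ops   # ... but its .default overload does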
@@ -615,9 +646,14 @@ class VllmBackend:
         self.graph = graph
         self.configure_post_pass()

-        self.split_gm, self.piecewise_graphs = split_graph(
-            graph, self.compilation_config.splitting_ops
-        )
+        if self.compilation_config.use_inductor_graph_partition:
+            # Let Inductor decide partitioning; avoid FX-level pre-splitting.
+            fx_split_ops: list[str] = []
+        else:
+            fx_split_ops = self.compilation_config.splitting_ops or []
+
+        resolved_split_ops = resolve_defined_ops(fx_split_ops)
+        self.split_gm, self.piecewise_graphs = split_graph(graph, resolved_split_ops)

         from torch._dynamo.utils import lazy_format_graph_code

@@ -17,8 +17,6 @@ from vllm.compilation.counter import compilation_counter
 from vllm.config import VllmConfig
 from vllm.utils import is_torch_equal_or_newer

-from .inductor_pass import pass_context
-

 class CompilerInterface:
     """
@@ -209,13 +207,12 @@ class InductorStandaloneAdaptor(CompilerInterface):

         from torch._inductor import standalone_compile

-        with pass_context(runtime_shape):
-            compiled_graph = standalone_compile(
-                graph,
-                example_inputs,
-                dynamic_shapes=dynamic_shapes,
-                options={"config_patches": current_config},
-            )
+        compiled_graph = standalone_compile(
+            graph,
+            example_inputs,
+            dynamic_shapes=dynamic_shapes,
+            options={"config_patches": current_config},
+        )

         # Save the compiled artifact to disk in the specified path
         assert key is not None
@@ -462,13 +459,12 @@ class InductorAdaptor(CompilerInterface):
                 torch._functorch.config.patch(enable_remote_autograd_cache=False)
             )

-            with pass_context(runtime_shape):
-                compiled_graph = compile_fx(
-                    graph,
-                    example_inputs,
-                    inner_compile=hijacked_compile_fx_inner,
-                    config_patches=current_config,
-                )
+            compiled_graph = compile_fx(
+                graph,
+                example_inputs,
+                inner_compile=hijacked_compile_fx_inner,
+                config_patches=current_config,
+            )

         # We treat VLLM_DISABLE_COMPILE_CACHE as the overall switch for torch
         # compilation cache. So turn off the checks if we disable the

vllm/compilation/partition_rules.py (new file, 95 lines added)
@@ -0,0 +1,95 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from __future__ import annotations
+
+import contextlib
+from typing import TYPE_CHECKING
+
+from torch._library.utils import lookup_op
+
+from vllm.logger import init_logger
+
+if TYPE_CHECKING:
+    import torch
+
+logger = init_logger(__name__)
+
+
+def resolve_defined_ops(op_names: list[str]) -> list[torch._ops.OpOverload]:
+    """Resolve operator names to OpOverload objects.
+
+    Skips operators that fail to resolve (e.g., operators not registered or
+    model-specific operators not present in the current model).
+
+    Note: Users should inspect the operator graph before lowering and ensure
+    the specified operators are present in the final graph. Built-in PyTorch
+    operators (aten::*, torch::*) may be decomposed, fused, or transformed
+    during Inductor's compilation passes, so use them with caution.
+
+    Args:
+        op_names: List of operator names in PyTorch format
+            (e.g., "vllm::unified_attention")
+
+    Returns:
+        List of successfully resolved operator overloads
+    """
+    resolved = []
+    for op_name in op_names:
+        try:
+            resolved.append(lookup_op(op_name))
+        except Exception:
+            # Skip operators that don't exist (e.g., model-specific ops)
+            logger.warning(
+                "Failed to resolve operator for Inductor partition: %s", op_name
+            )
+            continue
+
+    return resolved
+
+
+@contextlib.contextmanager
+def inductor_partition_rule_context(overloads: list[torch._ops.OpOverload]):
+    """Context manager to temporarily register Inductor partition rules.
+
+    Registers custom partition rules for specified operators, forcing the
+    Inductor scheduler to partition the graph at these operators. The rules
+    are automatically restored to their previous state on exit.
+
+    Note: Callers should use resolve_defined_ops() to convert operator names
+    to OpOverload objects before calling this function.
+
+    Args:
+        overloads: List of resolved operator overload objects.
+    """
+    if not overloads:
+        logger.debug("No partition ops provided; skipping rule registration.")
+        yield
+        return
+
+    from torch._inductor.scheduler import (  # type: ignore
+        _custom_should_partition_fns,
+        register_should_partition_rule,
+    )
+
+    def _always_partition(*_args, **_kwargs):
+        return True
+
+    # Save current state before registering
+    saved_rules = _custom_should_partition_fns.copy()
+
+    for overload in overloads:
+        register_should_partition_rule(
+            overload,
+            _always_partition,
+        )
+
+    logger.debug("Registered inductor partition rules for %d operators", len(overloads))
+
+    try:
+        yield
+    finally:
+        # Clear and restore previous state
+        _custom_should_partition_fns.clear()
+        _custom_should_partition_fns.update(saved_rules)
+        logger.debug("Restored previous partition rules state.")
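An illustrative composition of the two new helpers (usage sketch; the op list is an example, not exhaustive):

    from vllm.compilation.partition_rules import (
        inductor_partition_rule_context,
        resolve_defined_ops,
    )

    # Resolve names first; unknown ops are skipped with a warning rather than raising.
    resolved = resolve_defined_ops(
        ["vllm::unified_attention", "vllm::unified_attention_with_output"]
    )

    with inductor_partition_rule_context(resolved):
        # Any Inductor compilation triggered inside this block is partitioned at
        # the resolved ops; the previously registered rules are restored on exit.
        ...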
@@ -209,8 +209,23 @@ class CompilationConfig:
     disabled when running with Inductor: level>=PIECEWISE and use_inductor=True.
     Inductor generates (fused) Triton kernels for disabled custom ops."""
     splitting_ops: Optional[list[str]] = None
-    """A list of ops to split the full graph into subgraphs, used in piecewise
-    compilation."""
+    """A list of ops to exclude from cudagraphs, used in piecewise compilation.
+
+    The behavior depends on use_inductor_graph_partition:
+
+    - When use_inductor_graph_partition=False (default):
+      These ops are used for Dynamo FX-level graph splitting. The graph is
+      split at these ops before Inductor compilation, creating separate
+      subgraphs for cudagraph capture.
+
+    - When use_inductor_graph_partition=True:
+      These ops are used to register Inductor partition rules. The graph
+      partitioning happens at Inductor codegen time after all passes and
+      fusions are finished, allowing compilation and custom passes to operate
+      on the full graph while still excluding these ops from cudagraphs.
+
+    If None, defaults to attention ops for piecewise cudagraphs.
+    If empty list [], no ops are excluded (suitable for full cudagraphs)."""

     # Inductor capture
     use_inductor: bool = True
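To make the two modes described in the new docstring concrete, a hypothetical pair of configurations (values illustrative only):

    from vllm.config import CompilationConfig

    # FX-level splitting (default): the graph is split at these ops before Inductor.
    cfg_fx = CompilationConfig(
        use_inductor_graph_partition=False,
        splitting_ops=["vllm::unified_attention"],
    )

    # Inductor-level partitioning (requires PyTorch 2.9+): the same list feeds the
    # Inductor partition rules, so passes and fusions still see the full graph.
    cfg_inductor = CompilationConfig(
        use_inductor_graph_partition=True,
        splitting_ops=["vllm::unified_attention"],
    )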
@@ -367,18 +382,19 @@ class CompilationConfig:
     model code, e.g., Attention, FusedMOE when dp_size>1."""

     # Attention ops; used for piecewise cudagraphs
+    # Use PyTorch operator format: "namespace::name"
     _attention_ops: ClassVar[list[str]] = [
-        "vllm.unified_attention",
-        "vllm.unified_attention_with_output",
-        "vllm.unified_mla_attention",
-        "vllm.unified_mla_attention_with_output",
-        "vllm.mamba_mixer2",
-        "vllm.mamba_mixer",
-        "vllm.short_conv",
-        "vllm.linear_attention",
-        "vllm.plamo2_mamba_mixer",
-        "vllm.gdn_attention",
-        "vllm.sparse_attn_indexer",
+        "vllm::unified_attention",
+        "vllm::unified_attention_with_output",
+        "vllm::unified_mla_attention",
+        "vllm::unified_mla_attention_with_output",
+        "vllm::mamba_mixer2",
+        "vllm::mamba_mixer",
+        "vllm::short_conv",
+        "vllm::linear_attention",
+        "vllm::plamo2_mamba_mixer",
+        "vllm::gdn_attention",
+        "vllm::sparse_attn_indexer",
     ]

     def compute_hash(self) -> str:
@@ -654,31 +670,25 @@ class CompilationConfig:

     def set_splitting_ops_for_inductor_graph_partition(self):
         assert self.use_inductor_graph_partition
-        use_inductor_graph_partition_msg = (
-            "When use_inductor_graph_partition=True, splitting_ops "
-            "are ignored and set to an empty list. Instead, "
-            '"tags=(torch._C.Tag.cudagraph_unsafe, )," is '
-            "used to annotate custom ops for graph partition."
-        )
-        if self.splitting_ops is not None and len(self.splitting_ops) > 0:
-            logger.warning_once(use_inductor_graph_partition_msg)
-            self.splitting_ops = []
+        if self.splitting_ops is None:
+            self.splitting_ops = list(self._attention_ops)

     def set_splitting_ops_for_attn_fusion(self):
         assert self.pass_config.enable_attn_fusion
-        if self.splitting_ops is None:
-            self.splitting_ops = []
-            if self.cudagraph_mode.has_piecewise_cudagraphs():
-                logger.warning_once(
-                    "enable_attn_fusion is incompatible with piecewise "
-                    "cudagraph when use_inductor_graph_partition is off."
-                    "In this case, splitting_ops will be set to empty "
-                    "list, and cudagraph_mode will be set to FULL. "
-                    "Please ensure you are using attention backends that "
-                    "support cudagraph or set cudagraph_mode to NONE "
-                    "explicitly if encountering any problems."
-                )
-                self.cudagraph_mode = CUDAGraphMode.FULL
+        # For dynamo-partition (non-inductor) attention fusion,
+        # set splitting_ops to empty to avoid splitting at attention ops
+        self.splitting_ops = []
+        if self.cudagraph_mode.has_piecewise_cudagraphs():
+            logger.warning_once(
+                "enable_attn_fusion is incompatible with piecewise "
+                "cudagraph when use_inductor_graph_partition is off. "
+                "In this case, splitting_ops will be set to empty "
+                "list, and cudagraph_mode will be set to FULL. "
+                "Please ensure you are using attention backends that "
+                "support cudagraph or set cudagraph_mode to NONE "
+                "explicitly if encountering any problems."
+            )
+            self.cudagraph_mode = CUDAGraphMode.FULL

         assert not self.splitting_ops_contain_attention(), (
             "attention ops should not be in splitting_ops "
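A standalone sketch (hypothetical helper, mirroring the defaulting rules above and the docstring earlier; not vLLM API):

    def default_splitting_ops(splitting_ops, attention_ops,
                              use_inductor_graph_partition, enable_attn_fusion):
        """Assumed summary of how splitting_ops is resolved after this change."""
        if use_inductor_graph_partition:
            # Unset -> default to the attention ops; an explicit list is kept
            # and used directly for Inductor partition rules.
            return list(attention_ops) if splitting_ops is None else list(splitting_ops)
        if enable_attn_fusion:
            # Attention must stay inside a single graph so the fusion pass can see it.
            return []
        # Plain piecewise cudagraphs: unset defaults to the attention ops.
        return list(attention_ops) if splitting_ops is None else list(splitting_ops)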
@@ -691,23 +701,17 @@ class CompilationConfig:
         )

     def is_attention_compiled_piecewise(self) -> bool:
-        use_fx_graph_piecewise_compilation = (
-            self.level == CompilationLevel.PIECEWISE
-            and self.splitting_ops_contain_attention()
-        )
+        if not self.splitting_ops_contain_attention():
+            return False

-        inductor_used = (
-            self.level == CompilationLevel.PIECEWISE and self.use_inductor
-        ) or (
-            self.level >= CompilationLevel.DYNAMO_AS_IS and self.backend == "inductor"
-        )
-        use_inductor_piecewise_compilation = (
-            inductor_used
-            and self.use_inductor_graph_partition
-            and not self.splitting_ops_contain_attention()
-        )
+        if not self.use_inductor_graph_partition:
+            # Dynamo-level FX split case
+            return self.level == CompilationLevel.PIECEWISE

-        return use_fx_graph_piecewise_compilation or use_inductor_piecewise_compilation
+        # Inductor partition case
+        return (
+            self.level > CompilationLevel.NO_COMPILATION and self.backend == "inductor"
+        )

     def custom_op_log_check(self):
         """