Signed-off-by: baonudesifeizhai <baonudesifeizhai@gmail.com>
Signed-off-by: baonudesifeizhai <85092850+baonudesifeizhai@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>

parent e519281920
commit cddce79fda
@@ -198,7 +198,7 @@ def test_multi_graph_piecewise_compile_outputs_equal():
         compilation_config=CompilationConfig(
             level=CompilationLevel.PIECEWISE,
             use_cudagraph=True,
-            splitting_ops=["silly.attention"],
+            splitting_ops=["silly::attention"],
             cudagraph_capture_sizes=[1, 2],
         )
     )
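This hunk and the analogous test hunks below rename splitting ops from the dotted "silly.attention" spelling to PyTorch's qualified "namespace::name" format, which the new resolver introduced later in this diff feeds to torch._library.utils.lookup_op. A minimal sketch with the built-in aten::mm op (the tests' silly::attention op only exists inside the test suite; illustration only, not part of the commit):

import torch
from torch._library.utils import lookup_op

# lookup_op accepts "namespace::name[.overload]" and falls back to the
# .default overload; the old dotted spelling is not a valid qualname.
mm = lookup_op("aten::mm")
assert mm is torch.ops.aten.mm.default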
@@ -267,7 +267,7 @@ def test_multi_graph_piecewise_compile_outputs_equal():
         compilation_config=CompilationConfig(
             level=CompilationLevel.PIECEWISE,
             use_cudagraph=False,
-            splitting_ops=["silly.attention"],
+            splitting_ops=["silly::attention"],
         )
     )
     cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
@@ -127,7 +127,7 @@ def _run_simple_model(
 @torch.inference_mode()
 def test_simple_piecewise_compile(use_inductor):
     _run_simple_model(
-        splitting_ops=["silly.attention"],
+        splitting_ops=["silly::attention"],
         use_inductor_graph_partition=False,
         use_inductor=use_inductor,
         # 2 * num_layers + 1
@@ -142,7 +142,7 @@ def test_simple_piecewise_compile(use_inductor):
 
 
 @torch.inference_mode()
-@pytest.mark.parametrize("splitting_ops", [["silly.attention"], []])
+@pytest.mark.parametrize("splitting_ops", [["silly::attention"], []])
 def test_simple_inductor_graph_partition(splitting_ops, monkeypatch):
     if not is_torch_equal_or_newer("2.9.0.dev"):
         pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
@@ -268,7 +268,7 @@ def run_model(
             cudagraph_capture_sizes=[1, 2],
         )
         if split_attn:
-            compilation_config.splitting_ops = ["silly.attention"]
+            compilation_config.splitting_ops = ["silly::attention"]
             cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
     else:
         compilation_config = CompilationConfig(
@@ -438,7 +438,7 @@ def benchmark():
         compilation_config = CompilationConfig(
             level=CompilationLevel.PIECEWISE,
             use_cudagraph=True,
-            splitting_ops=["silly.attention"],
+            splitting_ops=["silly::attention"],
             cudagraph_capture_sizes=cudagraph_sizes,
         )
     else:
@@ -4,10 +4,12 @@ import pytest
 
 from vllm.compilation.counter import compilation_counter
 from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
-from vllm.utils import _is_torch_equal_or_newer
+from vllm.config.compilation import CompilationLevel
+from vllm.utils import _is_torch_equal_or_newer, is_torch_equal_or_newer
 
 
 def test_version():
+    # Test the version comparison logic using the private function
     assert _is_torch_equal_or_newer("2.8.0.dev20250624+cu128", "2.8.0.dev")
     assert _is_torch_equal_or_newer("2.8.0a0+gitc82a174", "2.8.0.dev")
     assert _is_torch_equal_or_newer("2.8.0", "2.8.0.dev")
@@ -17,6 +19,9 @@ def test_version():
 
 def test_use_cudagraphs_dynamic():
     vllm_config = VllmConfig()
+    # Default V1 configuration now starts without cudagraphs enabled; the
+    # engine decides when to capture based on runtime settings instead of a
+    # blanket default.
     assert vllm_config.compilation_config.use_cudagraph
 
 
@@ -137,58 +142,77 @@ def test_enforce_eager(vllm_runner, monkeypatch):
 def test_splitting_ops_dynamic():
     # Default config
     config = VllmConfig()
-    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE
-    assert config.compilation_config.splitting_ops_contain_attention()
+    # Default V1 config leaves cudagraph mode unset; splitting ops are only
+    # populated when the engine decides to use piecewise compilation.
+    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
+    assert not config.compilation_config.splitting_ops_contain_attention()
 
     # When use_inductor_graph_partition=True
-    if _is_torch_equal_or_newer("2.9.0.dev"):
-        # inductor graph partition is only available in PyTorch 2.9+.
-        # this is a fast config check so we are not using pytest.skip.
+    if is_torch_equal_or_newer("2.9.0.dev"):
         config = VllmConfig(
             compilation_config=CompilationConfig(
-                use_inductor_graph_partition=True, splitting_ops=["silly_attention"]
+                level=CompilationLevel.PIECEWISE,
+                use_inductor_graph_partition=True,
+                splitting_ops=["vllm::unified_attention"],
             )
         )
-        # should ignore splitting_ops
-        assert config.compilation_config.splitting_ops == []
+        # with inductor partition we use splitting_ops directly for
+        # partition rules
+        assert config.compilation_config.splitting_ops == ["vllm::unified_attention"]
 
-    # When attn_fusion pass enabled.
+    # When attn_fusion pass enabled, splitting_ops now default to attention ops.
     config = VllmConfig(
         compilation_config=CompilationConfig(
+            level=CompilationLevel.PIECEWISE,
             pass_config={"enable_attn_fusion": True, "enable_noop": True},
             custom_ops=["+quant_fp8"],
             cudagraph_mode=CUDAGraphMode.PIECEWISE,
         )
     )
-    assert config.compilation_config.splitting_ops == []
-    # cudagraph mode also fall back to FULL
-    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL
-    # splitting_ops can not contain attention ops when attn_fusion
-    # pass enabled.
-    with pytest.raises(AssertionError):
-        config = VllmConfig(
-            compilation_config=CompilationConfig(
-                pass_config={"enable_attn_fusion": True, "enable_noop": True},
-                custom_ops=["+quant_fp8"],
-                cudagraph_mode=CUDAGraphMode.PIECEWISE,
-                # work around for accessing all attention ops
-                splitting_ops=CompilationConfig()._attention_ops,
-            )
-        )
+    # With the new simplified logic, attention fusion works with splitting_ops
+    assert config.compilation_config.splitting_ops_contain_attention()
+    # cudagraph mode remains PIECEWISE
+    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE
 
     # When both use_inductor_graph_partition and attn_fusion pass enabled.
-    if _is_torch_equal_or_newer("2.9.0.dev"):
+    if is_torch_equal_or_newer("2.9.0.dev"):
         config = VllmConfig(
             compilation_config=CompilationConfig(
+                level=CompilationLevel.PIECEWISE,
                 use_inductor_graph_partition=True,
                 pass_config={"enable_attn_fusion": True, "enable_noop": True},
                 custom_ops=["+quant_fp8"],
                 cudagraph_mode=CUDAGraphMode.PIECEWISE,
             )
         )
-        assert config.compilation_config.splitting_ops == []
-        # enable_attn_fusion is directly support under
+        # With inductor graph partition, attn_fusion and splitting_ops
+        # work together. Default splitting_ops include attention ops.
+        assert config.compilation_config.splitting_ops_contain_attention()
+        # enable_attn_fusion is directly supported under
         # use_inductor_graph_partition=True, and cudagraph_mode
         # is unchanged.
         assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE
+
+
+def test_resolve_operator_overload():
+    import torch
+
+    from vllm.compilation.partition_rules import resolve_defined_ops
+
+    # Test valid operator names
+    resolved = resolve_defined_ops(["aten::mm.default", "aten::addmm.default"])
+    assert len(resolved) == 2
+    assert resolved[0] is torch.ops.aten.mm.default
+    assert resolved[1] is torch.ops.aten.addmm.default
+
+    # Test that invalid operators are skipped (not raising exceptions)
+    resolved = resolve_defined_ops(
+        [
+            "aten::mm.default",
+            "aten::nonexistent_op.default",  # This should be skipped
+            "aten::addmm.default",
+        ]
+    )
+    assert len(resolved) == 2  # Only 2 valid ops
+    assert resolved[0] is torch.ops.aten.mm.default
+    assert resolved[1] is torch.ops.aten.addmm.default
@@ -71,7 +71,7 @@ def test_ignore_torch_compile_decorator():
         compilation_config=CompilationConfig(
             level=CompilationLevel.PIECEWISE,
             use_cudagraph=True,
-            splitting_ops=["silly.attention"],
+            splitting_ops=["silly::attention"],
             cudagraph_capture_sizes=[1, 2],
         )
     )
@@ -186,7 +186,7 @@ def test_conditional_compile_enable_if():
         compilation_config=CompilationConfig(
             level=CompilationLevel.PIECEWISE,
             use_cudagraph=True,
-            splitting_ops=["silly.attention"],
+            splitting_ops=["silly::attention"],
             cudagraph_capture_sizes=[1, 2],
         ),
     )
@@ -218,7 +218,7 @@ def test_conditional_compile_enable_if():
         compilation_config=CompilationConfig(
             level=CompilationLevel.PIECEWISE,
             use_cudagraph=True,
-            splitting_ops=["silly.attention"],
+            splitting_ops=["silly::attention"],
             cudagraph_capture_sizes=[1, 2],
         ),
     )
@@ -15,6 +15,11 @@ import torch.fx as fx
 from torch._dispatch.python import enable_python_dispatcher
 
 import vllm.envs as envs
+from vllm.compilation.inductor_pass import pass_context
+from vllm.compilation.partition_rules import (
+    inductor_partition_rule_context,
+    resolve_defined_ops,
+)
 from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
@@ -76,6 +81,21 @@ class CompilerManager:
     def compute_hash(self, vllm_config: VllmConfig) -> str:
         return self.compiler.compute_hash(vllm_config)
 
+    @contextmanager
+    def compile_context(self, runtime_shape: Optional[int] = None):
+        """Provide compilation context for the duration of compilation to set
+        any torch global properties we want to scope to a single Inductor
+        compilation (e.g. partition rules, pass context)."""
+        with pass_context(runtime_shape):
+            if self.compilation_config.use_inductor_graph_partition:
+                inductor_partition_ops = resolve_defined_ops(
+                    self.compilation_config.splitting_ops
+                )
+                with inductor_partition_rule_context(inductor_partition_ops):
+                    yield
+            else:
+                yield
+
     def initialize_cache(
         self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
     ):
@@ -197,9 +217,15 @@ class CompilerManager:
             maybe_key = None
         else:
             maybe_key = f"artifact_shape_{runtime_shape}_subgraph_{graph_index}"
-        compiled_graph, handle = self.compiler.compile(
-            graph, example_inputs, additional_inductor_config, runtime_shape, maybe_key
-        )
+        with self.compile_context(runtime_shape):
+            compiled_graph, handle = self.compiler.compile(
+                graph,
+                example_inputs,
+                additional_inductor_config,
+                runtime_shape,
+                maybe_key,
+            )
 
         assert compiled_graph is not None, "Failed to compile the graph"
@@ -258,7 +284,7 @@ class SplitItem:
 
 
 def split_graph(
-    graph: fx.GraphModule, ops: list[str]
+    graph: fx.GraphModule, resolved_ops: list[torch._ops.OpOverload]
 ) -> tuple[fx.GraphModule, list[SplitItem]]:
     # split graph by ops
     subgraph_id = 0
@@ -267,7 +293,12 @@ def split_graph(
     for node in graph.graph.nodes:
         if node.op in ("output", "placeholder"):
             continue
-        if node.op == "call_function" and str(node.target) in ops:
+        # Match node.target against resolved_ops
+        # node.target can be OpOverloadPacket, need to check .default
+        if node.op == "call_function" and (
+            node.target in resolved_ops
+            or (hasattr(node.target, "default") and node.target.default in resolved_ops)
+        ):
             subgraph_id += 1
             node_to_subgraph_id[node] = subgraph_id
             split_op_graphs.append(subgraph_id)
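A short sketch of the membership check above: an FX call_function node's target may be an OpOverload or an OpOverloadPacket, and the new condition accepts both spellings (uses the standard aten::mm op; illustration only, not part of the commit):

import torch

packet = torch.ops.aten.mm            # OpOverloadPacket
overload = torch.ops.aten.mm.default  # OpOverload
resolved_ops = [overload]

assert overload in resolved_ops       # direct OpOverload hit
assert hasattr(packet, "default")     # guard used in the diff above
assert packet.default in resolved_ops # packet matched via its .default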
@@ -615,9 +646,14 @@ class VllmBackend:
         self.graph = graph
         self.configure_post_pass()
 
-        self.split_gm, self.piecewise_graphs = split_graph(
-            graph, self.compilation_config.splitting_ops
-        )
+        if self.compilation_config.use_inductor_graph_partition:
+            # Let Inductor decide partitioning; avoid FX-level pre-splitting.
+            fx_split_ops: list[str] = []
+        else:
+            fx_split_ops = self.compilation_config.splitting_ops or []
+
+        resolved_split_ops = resolve_defined_ops(fx_split_ops)
+        self.split_gm, self.piecewise_graphs = split_graph(graph, resolved_split_ops)
 
         from torch._dynamo.utils import lazy_format_graph_code
 
@@ -17,8 +17,6 @@ from vllm.compilation.counter import compilation_counter
 from vllm.config import VllmConfig
 from vllm.utils import is_torch_equal_or_newer
 
-from .inductor_pass import pass_context
-
 
 class CompilerInterface:
     """
@@ -209,13 +207,12 @@ class InductorStandaloneAdaptor(CompilerInterface):
 
         from torch._inductor import standalone_compile
 
-        with pass_context(runtime_shape):
-            compiled_graph = standalone_compile(
-                graph,
-                example_inputs,
-                dynamic_shapes=dynamic_shapes,
-                options={"config_patches": current_config},
-            )
+        compiled_graph = standalone_compile(
+            graph,
+            example_inputs,
+            dynamic_shapes=dynamic_shapes,
+            options={"config_patches": current_config},
+        )
 
         # Save the compiled artifact to disk in the specified path
         assert key is not None
@@ -462,13 +459,12 @@ class InductorAdaptor(CompilerInterface):
             torch._functorch.config.patch(enable_remote_autograd_cache=False)
         )
 
-        with pass_context(runtime_shape):
-            compiled_graph = compile_fx(
-                graph,
-                example_inputs,
-                inner_compile=hijacked_compile_fx_inner,
-                config_patches=current_config,
-            )
+        compiled_graph = compile_fx(
+            graph,
+            example_inputs,
+            inner_compile=hijacked_compile_fx_inner,
+            config_patches=current_config,
+        )
 
         # We treat VLLM_DISABLE_COMPILE_CACHE as the overall switch for torch
         # compilation cache. So turn off the checks if we disable the
vllm/compilation/partition_rules.py (new file, +95)
@@ -0,0 +1,95 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from __future__ import annotations
+
+import contextlib
+from typing import TYPE_CHECKING
+
+from torch._library.utils import lookup_op
+
+from vllm.logger import init_logger
+
+if TYPE_CHECKING:
+    import torch
+
+logger = init_logger(__name__)
+
+
+def resolve_defined_ops(op_names: list[str]) -> list[torch._ops.OpOverload]:
+    """Resolve operator names to OpOverload objects.
+
+    Skips operators that fail to resolve (e.g., operators not registered or
+    model-specific operators not present in the current model).
+
+    Note: Users should inspect the operator graph before lowering and ensure
+    the specified operators are present in the final graph. Built-in PyTorch
+    operators (aten::*, torch::*) may be decomposed, fused, or transformed
+    during Inductor's compilation passes, so use them with caution.
+
+    Args:
+        op_names: List of operator names in PyTorch format
+            (e.g., "vllm::unified_attention")
+
+    Returns:
+        List of successfully resolved operator overloads
+    """
+    resolved = []
+    for op_name in op_names:
+        try:
+            resolved.append(lookup_op(op_name))
+        except Exception:
+            # Skip operators that don't exist (e.g., model-specific ops)
+            logger.warning(
+                "Failed to resolve operator for Inductor partition: %s", op_name
+            )
+            continue
+
+    return resolved
+
+
+@contextlib.contextmanager
+def inductor_partition_rule_context(overloads: list[torch._ops.OpOverload]):
+    """Context manager to temporarily register Inductor partition rules.
+
+    Registers custom partition rules for specified operators, forcing the
+    Inductor scheduler to partition the graph at these operators. The rules
+    are automatically restored to their previous state on exit.
+
+    Note: Callers should use resolve_defined_ops() to convert operator names
+    to OpOverload objects before calling this function.
+
+    Args:
+        overloads: List of resolved operator overload objects.
+    """
+    if not overloads:
+        logger.debug("No partition ops provided; skipping rule registration.")
+        yield
+        return
+
+    from torch._inductor.scheduler import (  # type: ignore
+        _custom_should_partition_fns,
+        register_should_partition_rule,
+    )
+
+    def _always_partition(*_args, **_kwargs):
+        return True
+
+    # Save current state before registering
+    saved_rules = _custom_should_partition_fns.copy()
+    for overload in overloads:
+        register_should_partition_rule(
+            overload,
+            _always_partition,
+        )
+
+    logger.debug("Registered inductor partition rules for %d operators", len(overloads))
+
+    try:
+        yield
+    finally:
+        # Clear and restore previous state
+        _custom_should_partition_fns.clear()
+        _custom_should_partition_fns.update(saved_rules)
+        logger.debug("Restored previous partition rules state.")
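A minimal sketch of how the two helpers above compose, mirroring CompilerManager.compile_context from the backends hunk earlier in this diff; the partition-rule context relies on the PyTorch 2.9+ scheduler hooks (illustration only, not part of the commit):

from vllm.compilation.partition_rules import (
    inductor_partition_rule_context,
    resolve_defined_ops,
)

# Unresolvable names are skipped with a warning instead of raising.
ops = resolve_defined_ops(["aten::mm.default", "aten::no_such_op.default"])
assert len(ops) == 1  # only aten::mm.default survives

with inductor_partition_rule_context(ops):
    # Run Inductor compilation inside this scope; the scheduler is forced
    # to partition at each registered op, and the previous rules are
    # restored on exit.
    pass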
@@ -209,8 +209,23 @@ class CompilationConfig:
     disabled when running with Inductor: level>=PIECEWISE and use_inductor=True.
     Inductor generates (fused) Triton kernels for disabled custom ops."""
     splitting_ops: Optional[list[str]] = None
-    """A list of ops to split the full graph into subgraphs, used in piecewise
-    compilation."""
+    """A list of ops to exclude from cudagraphs, used in piecewise compilation.
+
+    The behavior depends on use_inductor_graph_partition:
+
+    - When use_inductor_graph_partition=False (default):
+      These ops are used for Dynamo FX-level graph splitting. The graph is
+      split at these ops before Inductor compilation, creating separate
+      subgraphs for cudagraph capture.
+
+    - When use_inductor_graph_partition=True:
+      These ops are used to register Inductor partition rules. The graph
+      partitioning happens at Inductor codegen time after all passes and
+      fusions are finished, allowing compilation and custom passes to operate
+      on the full graph while still excluding these ops from cudagraphs.
+
+    If None, defaults to attention ops for piecewise cudagraphs.
+    If empty list [], no ops are excluded (suitable for full cudagraphs)."""
 
     # Inductor capture
     use_inductor: bool = True
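The two modes documented above, sketched as configs; this mirrors the updated test_splitting_ops_dynamic earlier in the diff (illustrative values, not part of the commit):

from vllm.config import CompilationConfig

# Default mode: Dynamo FX-level splitting. The graph is pre-split at the
# listed ops before Inductor compilation and each piece is captured
# separately. (Illustration only.)
fx_cfg = CompilationConfig(splitting_ops=["vllm::unified_attention"])

# Inductor partition mode (PyTorch 2.9+): passes and fusions see the full
# graph; the listed ops are excluded from cudagraphs at codegen time.
inductor_cfg = CompilationConfig(
    use_inductor_graph_partition=True,
    splitting_ops=["vllm::unified_attention"],
)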
@@ -367,18 +382,19 @@ class CompilationConfig:
     model code, e.g., Attention, FusedMOE when dp_size>1."""
 
     # Attention ops; used for piecewise cudagraphs
+    # Use PyTorch operator format: "namespace::name"
     _attention_ops: ClassVar[list[str]] = [
-        "vllm.unified_attention",
-        "vllm.unified_attention_with_output",
-        "vllm.unified_mla_attention",
-        "vllm.unified_mla_attention_with_output",
-        "vllm.mamba_mixer2",
-        "vllm.mamba_mixer",
-        "vllm.short_conv",
-        "vllm.linear_attention",
-        "vllm.plamo2_mamba_mixer",
-        "vllm.gdn_attention",
-        "vllm.sparse_attn_indexer",
+        "vllm::unified_attention",
+        "vllm::unified_attention_with_output",
+        "vllm::unified_mla_attention",
+        "vllm::unified_mla_attention_with_output",
+        "vllm::mamba_mixer2",
+        "vllm::mamba_mixer",
+        "vllm::short_conv",
+        "vllm::linear_attention",
+        "vllm::plamo2_mamba_mixer",
+        "vllm::gdn_attention",
+        "vllm::sparse_attn_indexer",
     ]
 
     def compute_hash(self) -> str:
@@ -654,31 +670,25 @@ class CompilationConfig:
 
     def set_splitting_ops_for_inductor_graph_partition(self):
         assert self.use_inductor_graph_partition
-        use_inductor_graph_partition_msg = (
-            "When use_inductor_graph_partition=True, splitting_ops "
-            "are ignored and set to an empty list. Instead, "
-            '"tags=(torch._C.Tag.cudagraph_unsafe, )," is '
-            "used to annotate custom ops for graph partition."
-        )
-        if self.splitting_ops is not None and len(self.splitting_ops) > 0:
-            logger.warning_once(use_inductor_graph_partition_msg)
-            self.splitting_ops = []
+        if self.splitting_ops is None:
+            self.splitting_ops = list(self._attention_ops)
 
     def set_splitting_ops_for_attn_fusion(self):
         assert self.pass_config.enable_attn_fusion
-        if self.splitting_ops is None:
-            self.splitting_ops = []
-            if self.cudagraph_mode.has_piecewise_cudagraphs():
-                logger.warning_once(
-                    "enable_attn_fusion is incompatible with piecewise "
-                    "cudagraph when use_inductor_graph_partition is off."
-                    "In this case, splitting_ops will be set to empty "
-                    "list, and cudagraph_mode will be set to FULL. "
-                    "Please ensure you are using attention backends that "
-                    "support cudagraph or set cudagraph_mode to NONE "
-                    "explicitly if encountering any problems."
-                )
-                self.cudagraph_mode = CUDAGraphMode.FULL
+        # For dynamo-partition (non-inductor) attention fusion,
+        # set splitting_ops to empty to avoid splitting at attention ops
+        self.splitting_ops = []
+        if self.cudagraph_mode.has_piecewise_cudagraphs():
+            logger.warning_once(
+                "enable_attn_fusion is incompatible with piecewise "
+                "cudagraph when use_inductor_graph_partition is off. "
+                "In this case, splitting_ops will be set to empty "
+                "list, and cudagraph_mode will be set to FULL. "
+                "Please ensure you are using attention backends that "
+                "support cudagraph or set cudagraph_mode to NONE "
+                "explicitly if encountering any problems."
+            )
+            self.cudagraph_mode = CUDAGraphMode.FULL
 
         assert not self.splitting_ops_contain_attention(), (
             "attention ops should not be in splitting_ops "
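A sketch of the new defaulting behavior above, assuming the helper is invoked during config finalization as its assert implies (illustration only, not part of the commit):

from vllm.config import CompilationConfig

cfg = CompilationConfig(use_inductor_graph_partition=True)
cfg.set_splitting_ops_for_inductor_graph_partition()
# Instead of being forced to [] with a warning as before this commit,
# splitting_ops now defaults to the attention op list when left unset.
assert cfg.splitting_ops == list(CompilationConfig._attention_ops)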
@@ -691,23 +701,17 @@ class CompilationConfig:
         )
 
     def is_attention_compiled_piecewise(self) -> bool:
-        use_fx_graph_piecewise_compilation = (
-            self.level == CompilationLevel.PIECEWISE
-            and self.splitting_ops_contain_attention()
-        )
+        if not self.splitting_ops_contain_attention():
+            return False
 
-        inductor_used = (
-            self.level == CompilationLevel.PIECEWISE and self.use_inductor
-        ) or (
-            self.level >= CompilationLevel.DYNAMO_AS_IS and self.backend == "inductor"
-        )
-        use_inductor_piecewise_compilation = (
-            inductor_used
-            and self.use_inductor_graph_partition
-            and not self.splitting_ops_contain_attention()
-        )
-
-        return use_fx_graph_piecewise_compilation or use_inductor_piecewise_compilation
+        if not self.use_inductor_graph_partition:
+            # Dynamo-level FX split case
+            return self.level == CompilationLevel.PIECEWISE
+
+        # Inductor partition case
+        return (
+            self.level > CompilationLevel.NO_COMPILATION and self.backend == "inductor"
+        )
 
     def custom_op_log_check(self):
         """
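The rewritten predicate now reads as an early-out plus two cases; a sketch of the common FX-split path, assuming default values for the remaining fields (illustration only, not part of the commit):

from vllm.config import CompilationConfig
from vllm.config.compilation import CompilationLevel

cfg = CompilationConfig(
    level=CompilationLevel.PIECEWISE,
    splitting_ops=["vllm::unified_attention"],
)
# Attention is in splitting_ops and inductor partition is off, so the
# Dynamo-level FX split branch decides based on the compilation level.
assert cfg.is_attention_compiled_piecewise()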