remove resolve_op_overloads and use splitting_ops directly (#28081)

Signed-off-by: Boyuan Feng <boyuan@meta.com>
Boyuan Feng 2025-11-07 17:13:13 -08:00, committed by GitHub
parent 1aaecda078
commit b158df2813
3 changed files with 89 additions and 65 deletions
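
In short: instead of eagerly resolving splitting op names to OpOverload objects (and silently skipping names that fail to resolve), the graph splitter now keeps the names as strings and matches each fx node against them at split time. A minimal sketch of the before/after contract, using only names that already appear in this diff (lookup_op, aten::mm.default); it is illustrative, not code from the change:

import torch
from torch._library.utils import lookup_op

splitting_ops = ["aten::mm.default"]

# Before: names were resolved up front to OpOverload objects and fx nodes
# were matched by identity against the resolved list.
resolved = [lookup_op(name) for name in splitting_ops]
assert resolved[0] is torch.ops.aten.mm.default

# After: the strings are passed straight to split_graph(); should_split()
# compares each call_function node's target (OpOverload / OpOverloadPacket)
# against them by qualified name, so an unknown name simply never matches.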


@@ -214,28 +214,72 @@ def test_splitting_ops_dynamic():
     assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE


-def test_resolve_operator_overload():
+def test_should_split():
     import torch

-    from vllm.compilation.partition_rules import resolve_defined_ops
+    from vllm.compilation.partition_rules import should_split

-    # Test valid operator names
-    resolved = resolve_defined_ops(["aten::mm.default", "aten::addmm.default"])
-    assert len(resolved) == 2
-    assert resolved[0] is torch.ops.aten.mm.default
-    assert resolved[1] is torch.ops.aten.addmm.default
-
-    # Test that invalid operators are skipped (not raising exceptions)
-    resolved = resolve_defined_ops(
-        [
-            "aten::mm.default",
-            "aten::nonexistent_op.default",  # This should be skipped
-            "aten::addmm.default",
-        ]
-    )
-    assert len(resolved) == 2  # Only 2 valid ops
-    assert resolved[0] is torch.ops.aten.mm.default
-    assert resolved[1] is torch.ops.aten.addmm.default
+    graph = torch.fx.Graph()
+    node = torch.fx.Node(
+        graph=graph,
+        name="dummy_node",
+        op="call_function",
+        target=torch.ops.aten.add.default,
+        args=(),
+        kwargs={},
+    )
+
+    # supports OpOverloadPacket
+    splitting_ops = ["aten::add"]
+    assert should_split(node, splitting_ops)
+
+    # supports OpOverload
+    splitting_ops = ["aten::add.default"]
+    assert should_split(node, splitting_ops)
+
+    # a different OpOverload name does not match
+    splitting_ops = ["aten::add.Tensor"]
+    assert not should_split(node, splitting_ops)
+
+    @torch.library.custom_op(
+        "silly::attention",
+        mutates_args=["out"],
+    )
+    def attention(
+        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor
+    ) -> None:
+        out.copy_(q + k + v)
+
+    q, k, v, out = [torch.randn(1)] * 4
+
+    # supports custom ops as OpOverloadPacket
+    node = torch.fx.Node(
+        graph=graph,
+        name="dummy_node",
+        op="call_function",
+        target=torch.ops.silly.attention,
+        args=(q, k, v, out),
+        kwargs={},
+    )
+    splitting_ops = ["silly::attention"]
+    assert should_split(node, splitting_ops)
+
+    # supports custom ops as OpOverload
+    node = torch.fx.Node(
+        graph=graph,
+        name="dummy_node",
+        op="call_function",
+        target=torch.ops.silly.attention.default,
+        args=(q, k, v, out),
+        kwargs={},
+    )
+    splitting_ops = ["silly::attention"]
+    assert should_split(node, splitting_ops)
+    splitting_ops = ["silly::attention.default"]
+    assert should_split(node, splitting_ops)


 @pytest.mark.skipif(


@@ -19,7 +19,7 @@ import vllm.envs as envs
 from vllm.compilation.inductor_pass import pass_context
 from vllm.compilation.partition_rules import (
     inductor_partition_rule_context,
-    resolve_defined_ops,
+    should_split,
 )
 from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
 from vllm.logger import init_logger
@@ -303,7 +303,7 @@ class SplitItem:


 def split_graph(
-    graph: fx.GraphModule, resolved_ops: list[torch._ops.OpOverload]
+    graph: fx.GraphModule, splitting_ops: list[str]
 ) -> tuple[fx.GraphModule, list[SplitItem]]:
     # split graph by ops
     subgraph_id = 0
@@ -312,12 +312,8 @@ def split_graph(
     for node in graph.graph.nodes:
         if node.op in ("output", "placeholder"):
             continue
-        # Match node.target against resolved_ops
-        # node.target can be OpOverloadPacket, need to check .default
-        if node.op == "call_function" and (
-            node.target in resolved_ops
-            or (hasattr(node.target, "default") and node.target.default in resolved_ops)
-        ):
+        if should_split(node, splitting_ops):
             subgraph_id += 1
             node_to_subgraph_id[node] = subgraph_id
             split_op_graphs.append(subgraph_id)
@@ -653,8 +649,7 @@ class VllmBackend:
         else:
             fx_split_ops = self.compilation_config.splitting_ops or []

-        resolved_split_ops = resolve_defined_ops(fx_split_ops)
-        self.split_gm, self.piecewise_graphs = split_graph(graph, resolved_split_ops)
+        self.split_gm, self.piecewise_graphs = split_graph(graph, fx_split_ops)

         from torch._dynamo.utils import lazy_format_graph_code


@@ -2,54 +2,39 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import contextlib
-import logging

 import torch
-from torch._library.utils import lookup_op

 from vllm.logger import init_logger

 logger = init_logger(__name__)


-def resolve_defined_ops(op_names: list[str]) -> list["torch._ops.OpOverload"]:
-    """Resolve operator names to OpOverload objects.
-
-    Skips operators that fail to resolve (e.g., operators not registered or
-    model-specific operators not present in the current model).
-
-    Note: Users should inspect the operator graph before lowering and ensure
-    the specified operators are present in the final graph. Built-in PyTorch
-    operators (aten::*, torch::*) may be decomposed, fused, or transformed
-    during Inductor's compilation passes, so use them with caution.
-
-    Args:
-        op_names: List of operator names in PyTorch format
-            (e.g., "vllm::unified_attention")
-
-    Returns:
-        List of successfully resolved operator overloads
-    """
-    resolved = []
-    for op_name in op_names:
-        try:
-            resolved.append(lookup_op(op_name))
-        except Exception:
-            # Skip operators that don't exist (e.g., model-specific ops)
-            # Do not warn for attention ops, warn for others
-            # (most likely manually specified)
-            from vllm.config import CompilationConfig
-
-            logger.log(
-                logging.DEBUG
-                if op_name in CompilationConfig._attention_ops
-                else logging.WARNING,
-                "Failed to resolve operator for CUDAGraph partition: %s",
-                op_name,
-            )
-            continue
-    return resolved
+def should_split(node: torch.fx.Node, splitting_ops: list[str]) -> bool:
+    """
+    Check if a node should be split for dynamo graph partition.
+    It operates on the dynamo graph, so node.target can be anything.
+    We need to check and split only on OpOverload and OpOverloadPacket targets.
+    """
+    if node.op != "call_function":
+        return False
+
+    target = node.target
+    if isinstance(target, torch._ops.OpOverloadPacket):
+        # Example: "aten::add"
+        return target._qualified_op_name in splitting_ops
+    if isinstance(target, torch._ops.OpOverload):
+        # Example: "aten::add"
+        packet_name = target.name()
+        # Example: "aten::add.default"
+        op_overload_name = f"{packet_name}.{target._overloadname}"
+        return op_overload_name in splitting_ops or packet_name in splitting_ops
+    return False


 @contextlib.contextmanager
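
As a side note on the new helper's contract (the docstring above says node.target can be anything on the dynamo graph): an illustrative sketch, assuming should_split is importable from vllm.compilation.partition_rules as in this diff, showing that non-op targets and non-call_function nodes are never selected for splitting:

import operator

import torch
from vllm.compilation.partition_rules import should_split

graph = torch.fx.Graph()
x = graph.placeholder("x")

# A plain Python callable target is neither OpOverload nor OpOverloadPacket,
# so it is never selected, whatever splitting_ops contains.
py_node = graph.call_function(operator.add, (x, x))
assert not should_split(py_node, ["aten::add"])

# Placeholder (and output) nodes are not call_function nodes, so they are
# never selected either.
assert not should_split(x, ["aten::add"])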