remove resolve_op_overloads and use splitting_ops directly (#28081)

Signed-off-by: Boyuan Feng <boyuan@meta.com>
Author: Boyuan Feng, committed 2025-11-07 17:13:13 -08:00 via GitHub
parent 1aaecda078
commit b158df2813
3 changed files with 89 additions and 65 deletions


@@ -214,28 +214,72 @@ def test_splitting_ops_dynamic():
     assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE


-def test_resolve_operator_overload():
+def test_should_split():
     import torch

-    from vllm.compilation.partition_rules import resolve_defined_ops
+    from vllm.compilation.partition_rules import should_split

-    # Test valid operator names
-    resolved = resolve_defined_ops(["aten::mm.default", "aten::addmm.default"])
-    assert len(resolved) == 2
-    assert resolved[0] is torch.ops.aten.mm.default
-    assert resolved[1] is torch.ops.aten.addmm.default
-
-    # Test that invalid operators are skipped (not raising exceptions)
-    resolved = resolve_defined_ops(
-        [
-            "aten::mm.default",
-            "aten::nonexistent_op.default",  # This should be skipped
-            "aten::addmm.default",
-        ]
-    )
-    assert len(resolved) == 2  # Only 2 valid ops
-    assert resolved[0] is torch.ops.aten.mm.default
-    assert resolved[1] is torch.ops.aten.addmm.default
+    graph = torch.fx.Graph()
+    node = torch.fx.Node(
+        graph=graph,
+        name="dummy_node",
+        op="call_function",
+        target=torch.ops.aten.add.default,
+        args=(),
+        kwargs={},
+    )
+
+    # supports OpOverloadPacket names
+    splitting_ops = ["aten::add"]
+    assert should_split(node, splitting_ops)
+
+    # supports OpOverload names
+    splitting_ops = ["aten::add.default"]
+    assert should_split(node, splitting_ops)
+
+    # a mismatched OpOverload name does not split
+    splitting_ops = ["aten::add.Tensor"]
+    assert not should_split(node, splitting_ops)
+
+    @torch.library.custom_op(
+        "silly::attention",
+        mutates_args=["out"],
+    )
+    def attention(
+        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor
+    ) -> None:
+        out.copy_(q + k + v)
+
+    q, k, v, out = [torch.randn(1)] * 4
+
+    # supports custom ops as OpOverloadPacket
+    node = torch.fx.Node(
+        graph=graph,
+        name="dummy_node",
+        op="call_function",
+        target=torch.ops.silly.attention,
+        args=(q, k, v, out),
+        kwargs={},
+    )
+    splitting_ops = ["silly::attention"]
+    assert should_split(node, splitting_ops)
+
+    # supports custom ops as OpOverload
+    node = torch.fx.Node(
+        graph=graph,
+        name="dummy_node",
+        op="call_function",
+        target=torch.ops.silly.attention.default,
+        args=(q, k, v, out),
+        kwargs={},
+    )
+    splitting_ops = ["silly::attention"]
+    assert should_split(node, splitting_ops)
+    splitting_ops = ["silly::attention.default"]
+    assert should_split(node, splitting_ops)


 @pytest.mark.skipif(
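The test above matches nodes purely by name strings. As a quick reference for those string forms, here is a small standalone sketch (plain PyTorch, no vLLM) that prints the attributes should_split consults; _qualified_op_name and _overloadname are private torch attributes, and the values in the comments are what I would expect on a recent torch build rather than guaranteed output.

# Hedged sketch of the operator name formats used by splitting_ops.
import torch

packet = torch.ops.aten.add            # an OpOverloadPacket
overload = torch.ops.aten.add.Tensor   # one concrete OpOverload

print(packet._qualified_op_name)       # expected: "aten::add"
print(overload._overloadname)          # expected: "Tensor"
print(overload.name())                 # expected: "aten::add.Tensor"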


@@ -19,7 +19,7 @@ import vllm.envs as envs
 from vllm.compilation.inductor_pass import pass_context
 from vllm.compilation.partition_rules import (
     inductor_partition_rule_context,
-    resolve_defined_ops,
+    should_split,
 )
 from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
 from vllm.logger import init_logger
@@ -303,7 +303,7 @@ class SplitItem:


 def split_graph(
-    graph: fx.GraphModule, resolved_ops: list[torch._ops.OpOverload]
+    graph: fx.GraphModule, splitting_ops: list[str]
 ) -> tuple[fx.GraphModule, list[SplitItem]]:
     # split graph by ops
     subgraph_id = 0
@@ -312,12 +312,8 @@ def split_graph(
     for node in graph.graph.nodes:
         if node.op in ("output", "placeholder"):
             continue
-        # Match node.target against resolved_ops
-        # node.target can be OpOverloadPacket, need to check .default
-        if node.op == "call_function" and (
-            node.target in resolved_ops
-            or (hasattr(node.target, "default") and node.target.default in resolved_ops)
-        ):
+        if should_split(node, splitting_ops):
             subgraph_id += 1
             node_to_subgraph_id[node] = subgraph_id
             split_op_graphs.append(subgraph_id)
@@ -653,8 +649,7 @@ class VllmBackend:
         else:
             fx_split_ops = self.compilation_config.splitting_ops or []

-        resolved_split_ops = resolve_defined_ops(fx_split_ops)
-        self.split_gm, self.piecewise_graphs = split_graph(graph, resolved_split_ops)
+        self.split_gm, self.piecewise_graphs = split_graph(graph, fx_split_ops)

         from torch._dynamo.utils import lazy_format_graph_code
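With split_graph taking names directly, compilation_config.splitting_ops flows through unresolved. A hedged usage sketch follows; the specific op names are illustrative ("vllm::unified_attention" comes from the old docstring, "silly::attention.default" from the test above), not a recommended configuration.

# Sketch: splitting ops are plain strings on CompilationConfig; both packet-style
# and overload-style names should be accepted by should_split.
from vllm.config import CompilationConfig

compilation_config = CompilationConfig(
    splitting_ops=[
        "vllm::unified_attention",     # OpOverloadPacket-style name
        "silly::attention.default",    # OpOverload-style name (test-only op)
    ],
)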


@@ -2,54 +2,39 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import contextlib
-import logging

 import torch
-from torch._library.utils import lookup_op

 from vllm.logger import init_logger

 logger = init_logger(__name__)


-def resolve_defined_ops(op_names: list[str]) -> list["torch._ops.OpOverload"]:
-    """Resolve operator names to OpOverload objects.
-
-    Skips operators that fail to resolve (e.g., operators not registered or
-    model-specific operators not present in the current model).
-
-    Note: Users should inspect the operator graph before lowering and ensure
-    the specified operators are present in the final graph. Built-in PyTorch
-    operators (aten::*, torch::*) may be decomposed, fused, or transformed
-    during Inductor's compilation passes, so use them with caution.
-
-    Args:
-        op_names: List of operator names in PyTorch format
-            (e.g., "vllm::unified_attention")
-
-    Returns:
-        List of successfully resolved operator overloads
-    """
-    resolved = []
-    for op_name in op_names:
-        try:
-            resolved.append(lookup_op(op_name))
-        except Exception:
-            # Skip operators that don't exist (e.g., model-specific ops)
-            # Do not warn for attention ops, warn for others
-            # (most likely manually specified)
-            from vllm.config import CompilationConfig
-
-            logger.log(
-                logging.DEBUG
-                if op_name in CompilationConfig._attention_ops
-                else logging.WARNING,
-                "Failed to resolve operator for CUDAGraph partition: %s",
-                op_name,
-            )
-            continue
-    return resolved
+def should_split(node: torch.fx.Node, splitting_ops: list[str]) -> bool:
+    """
+    Check whether a node should be split on for the dynamo graph partition.
+    This operates on the dynamo graph, so node.target can be anything;
+    only OpOverload and OpOverloadPacket targets are checked and split on.
+    """
+    if node.op != "call_function":
+        return False
+
+    target = node.target
+    if isinstance(target, torch._ops.OpOverloadPacket):
+        # Example: "aten::add"
+        return target._qualified_op_name in splitting_ops
+
+    if isinstance(target, torch._ops.OpOverload):
+        # Example: "aten::add"
+        packet_name = target.name()
+        # Example: "aten::add.default"
+        op_overload_name = f"{packet_name}.{target._overloadname}"
+        return op_overload_name in splitting_ops or packet_name in splitting_ops
+
+    return False


 @contextlib.contextmanager
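To see the helper end to end outside the backend, here is a hedged sketch that traces a toy module with torch.fx and applies should_split the same way the split_graph loop above does; the Toy module and the choice of "aten::mm" are my own assumptions, not part of this commit.

# Sketch: flag the FX nodes that should_split would partition on.
import torch
import torch.fx

from vllm.compilation.partition_rules import should_split


class Toy(torch.nn.Module):
    def forward(self, x, y):
        # both ops are traced as call_function nodes with OpOverload targets
        return torch.ops.aten.mm.default(torch.ops.aten.add.Tensor(x, y), y)


gm = torch.fx.symbolic_trace(Toy())
splitting_ops = ["aten::mm"]  # matches torch.ops.aten.mm.default via its packet name
split_nodes = [n for n in gm.graph.nodes if should_split(n, splitting_ops)]
print([n.target for n in split_nodes])  # expected: [torch.ops.aten.mm.default]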