remove resolve_op_overloads and use splitting_ops directly (#28081)

Signed-off-by: Boyuan Feng <boyuan@meta.com>

parent 1aaecda078
commit b158df2813
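This commit replaces eager operator resolution (resolve_defined_ops, which turned splitting_ops strings into OpOverload objects up front) with a per-node string match at partition time (should_split). A splitting_ops entry can now name either an OpOverloadPacket ("aten::add") or a specific OpOverload ("aten::add.default"). A minimal sketch of the two name forms, assuming only a PyTorch install; the printed values follow the new test and should_split hunks below:

import torch

# OpOverloadPacket: matched by its qualified name
packet = torch.ops.aten.add
print(packet._qualified_op_name)  # aten::add

# OpOverload: matched by packet name or "<packet>.<overload>"
overload = torch.ops.aten.add.default
print(f"{overload.name()}.{overload._overloadname}")  # aten::add.default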
@@ -214,28 +214,72 @@ def test_splitting_ops_dynamic():
     assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE
 
 
-def test_resolve_operator_overload():
+def test_should_split():
     import torch
 
-    from vllm.compilation.partition_rules import resolve_defined_ops
+    from vllm.compilation.partition_rules import should_split
 
-    # Test valid operator names
-    resolved = resolve_defined_ops(["aten::mm.default", "aten::addmm.default"])
-    assert len(resolved) == 2
-    assert resolved[0] is torch.ops.aten.mm.default
-    assert resolved[1] is torch.ops.aten.addmm.default
-
-    # Test that invalid operators are skipped (not raising exceptions)
-    resolved = resolve_defined_ops(
-        [
-            "aten::mm.default",
-            "aten::nonexistent_op.default",  # This should be skipped
-            "aten::addmm.default",
-        ]
+    graph = torch.fx.Graph()
+    node = torch.fx.Node(
+        graph=graph,
+        name="dummy_node",
+        op="call_function",
+        target=torch.ops.aten.add.default,
+        args=(),
+        kwargs={},
     )
-    assert len(resolved) == 2  # Only 2 valid ops
-    assert resolved[0] is torch.ops.aten.mm.default
-    assert resolved[1] is torch.ops.aten.addmm.default
+
+    # supports OpOverloadPacket
+    splitting_ops = ["aten::add"]
+    assert should_split(node, splitting_ops)
+
+    # supports OpOverload
+    splitting_ops = ["aten::add.default"]
+    assert should_split(node, splitting_ops)
+
+    # a non-matching OpOverload name does not split
+    splitting_ops = ["aten::add.Tensor"]
+    assert not should_split(node, splitting_ops)
+
+    @torch.library.custom_op(
+        "silly::attention",
+        mutates_args=["out"],
+    )
+    def attention(
+        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor
+    ) -> None:
+        out.copy_(q + k + v)
+
+    q, k, v, out = [torch.randn(1)] * 4
+
+    # supports custom ops as OpOverloadPacket
+    node = torch.fx.Node(
+        graph=graph,
+        name="dummy_node",
+        op="call_function",
+        target=torch.ops.silly.attention,
+        args=(q, k, v, out),
+        kwargs={},
+    )
+
+    splitting_ops = ["silly::attention"]
+    assert should_split(node, splitting_ops)
+
+    # supports custom ops as OpOverload
+    node = torch.fx.Node(
+        graph=graph,
+        name="dummy_node",
+        op="call_function",
+        target=torch.ops.silly.attention.default,
+        args=(q, k, v, out),
+        kwargs={},
+    )
+
+    splitting_ops = ["silly::attention"]
+    assert should_split(node, splitting_ops)
+
+    splitting_ops = ["silly::attention.default"]
+    assert should_split(node, splitting_ops)
 
 
 @pytest.mark.skipif(
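The test's custom op can be reproduced standalone. A hedged sketch, assuming torch >= 2.4 for torch.library.custom_op; separate tensors are swapped in for the test's aliased ones:

import torch

@torch.library.custom_op("silly::attention", mutates_args=["out"])
def attention(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor
) -> None:
    # Toy kernel: writes q + k + v into the preallocated output.
    out.copy_(q + k + v)

q, k, v, out = (torch.randn(1) for _ in range(4))
attention(q, k, v, out)

# Registration creates both target forms the partitioner can match on:
assert isinstance(torch.ops.silly.attention, torch._ops.OpOverloadPacket)
assert isinstance(torch.ops.silly.attention.default, torch._ops.OpOverload)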
@@ -19,7 +19,7 @@ import vllm.envs as envs
 from vllm.compilation.inductor_pass import pass_context
 from vllm.compilation.partition_rules import (
     inductor_partition_rule_context,
-    resolve_defined_ops,
+    should_split,
 )
 from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
 from vllm.logger import init_logger
@@ -303,7 +303,7 @@ class SplitItem:
 
 
 def split_graph(
-    graph: fx.GraphModule, resolved_ops: list[torch._ops.OpOverload]
+    graph: fx.GraphModule, splitting_ops: list[str]
 ) -> tuple[fx.GraphModule, list[SplitItem]]:
     # split graph by ops
     subgraph_id = 0
@@ -312,12 +312,8 @@ def split_graph(
     for node in graph.graph.nodes:
         if node.op in ("output", "placeholder"):
             continue
-        # Match node.target against resolved_ops
-        # node.target can be OpOverloadPacket, need to check .default
-        if node.op == "call_function" and (
-            node.target in resolved_ops
-            or (hasattr(node.target, "default") and node.target.default in resolved_ops)
-        ):
+        if should_split(node, splitting_ops):
             subgraph_id += 1
             node_to_subgraph_id[node] = subgraph_id
             split_op_graphs.append(subgraph_id)
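split_graph now consults should_split per node instead of testing membership in a pre-resolved OpOverload list. A hypothetical end-to-end illustration (not from this commit), using make_fx to get an aten-level graph similar to what dynamo hands the backend:

import torch
from torch.fx.experimental.proxy_tensor import make_fx

from vllm.compilation.partition_rules import should_split

def f(x, y):
    return torch.mm(x, y) + x

gm = make_fx(f)(torch.randn(2, 2), torch.randn(2, 2))

# Placeholders and outputs are skipped by should_split itself.
split_nodes = [n for n in gm.graph.nodes if should_split(n, ["aten::mm"])]
print([n.target for n in split_nodes])  # expect [torch.ops.aten.mm.default]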
@@ -653,8 +649,7 @@ class VllmBackend:
         else:
             fx_split_ops = self.compilation_config.splitting_ops or []
 
-        resolved_split_ops = resolve_defined_ops(fx_split_ops)
-        self.split_gm, self.piecewise_graphs = split_graph(graph, resolved_split_ops)
+        self.split_gm, self.piecewise_graphs = split_graph(graph, fx_split_ops)
 
         from torch._dynamo.utils import lazy_format_graph_code
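The fx_split_ops list above is the user-facing config field, passed through unmodified. A minimal sketch, assuming CompilationConfig accepts splitting_ops as a constructor argument; the attention op name is illustrative (borrowed from the old docstring), not a guaranteed default:

from vllm.config import CompilationConfig

config = CompilationConfig(splitting_ops=["vllm::unified_attention"])
print(config.splitting_ops or [])  # what VllmBackend passes to split_graph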
@@ -2,54 +2,39 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import contextlib
-import logging
 
 import torch
-from torch._library.utils import lookup_op
 
 from vllm.logger import init_logger
 
 logger = init_logger(__name__)
 
 
-def resolve_defined_ops(op_names: list[str]) -> list["torch._ops.OpOverload"]:
-    """Resolve operator names to OpOverload objects.
-
-    Skips operators that fail to resolve (e.g., operators not registered or
-    model-specific operators not present in the current model).
-
-    Note: Users should inspect the operator graph before lowering and ensure
-    the specified operators are present in the final graph. Built-in PyTorch
-    operators (aten::*, torch::*) may be decomposed, fused, or transformed
-    during Inductor's compilation passes, so use them with caution.
-
-    Args:
-        op_names: List of operator names in PyTorch format
-            (e.g., "vllm::unified_attention")
-
-    Returns:
-        List of successfully resolved operator overloads
+def should_split(node: torch.fx.Node, splitting_ops: list[str]) -> bool:
     """
+    Check whether a node should be split out for dynamo graph partition.
+    This operates on the dynamo graph, so node.target can be anything;
+    only OpOverload and OpOverloadPacket targets are checked for a split.
     """
-    resolved = []
-    for op_name in op_names:
-        try:
-            resolved.append(lookup_op(op_name))
-        except Exception:
-            # Skip operators that don't exist (e.g., model-specific ops)
-            # Do not warn for attention ops, warn for others
-            # (most likely manually specified)
-            from vllm.config import CompilationConfig
-
-            logger.log(
-                logging.DEBUG
-                if op_name in CompilationConfig._attention_ops
-                else logging.WARNING,
-                "Failed to resolve operator for CUDAGraph partition: %s",
-                op_name,
-            )
-            continue
+    if node.op != "call_function":
+        return False
 
-    return resolved
+    target = node.target
+
+    if isinstance(target, torch._ops.OpOverloadPacket):
+        # Example: "aten::add"
+        return target._qualified_op_name in splitting_ops
+
+    if isinstance(target, torch._ops.OpOverload):
+        # Example: "aten::add"
+        packet_name = target.name()
+
+        # Example: "aten::add.default"
+        op_overload_name = f"{packet_name}.{target._overloadname}"
+        return op_overload_name in splitting_ops or packet_name in splitting_ops
+
+    return False
 
 
 @contextlib.contextmanager
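Usage sketch for the new predicate, building a bare FX node the same way the updated test does; assumes the import path shown in this diff:

import torch
from vllm.compilation.partition_rules import should_split

graph = torch.fx.Graph()
node = torch.fx.Node(
    graph=graph,
    name="mm_node",
    op="call_function",
    target=torch.ops.aten.mm.default,
    args=(),
    kwargs={},
)

assert should_split(node, ["aten::mm"])          # packet name matches
assert should_split(node, ["aten::mm.default"])  # exact overload matches
assert not should_split(node, ["aten::addmm"])   # unrelated op: no split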