Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-13 02:45:01 +08:00)
remove resolve_op_overloads and use splitting_ops directly (#28081)

Signed-off-by: Boyuan Feng <boyuan@meta.com>

parent 1aaecda078
commit b158df2813
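In short, after this change the compilation config keeps splitting_ops as plain qualified-name strings ("namespace::op" or "namespace::op.overload"), and the partitioner matches graph nodes against those strings directly instead of first resolving them to OpOverload objects. A minimal sketch of the intended configuration shape (the field name comes from the diff below; the op name is illustrative only):

    from vllm.config import CompilationConfig

    # splitting_ops stays a list of operator-name strings; matching against
    # FX graph nodes now happens at partition time via should_split().
    compilation_config = CompilationConfig(
        splitting_ops=["silly::attention"],  # illustrative value
    )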
@@ -214,28 +214,72 @@ def test_splitting_ops_dynamic():
     assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE


-def test_resolve_operator_overload():
+def test_should_split():
     import torch

-    from vllm.compilation.partition_rules import resolve_defined_ops
+    from vllm.compilation.partition_rules import should_split

-    # Test valid operator names
-    resolved = resolve_defined_ops(["aten::mm.default", "aten::addmm.default"])
-    assert len(resolved) == 2
-    assert resolved[0] is torch.ops.aten.mm.default
-    assert resolved[1] is torch.ops.aten.addmm.default
-
-    # Test that invalid operators are skipped (not raising exceptions)
-    resolved = resolve_defined_ops(
-        [
-            "aten::mm.default",
-            "aten::nonexistent_op.default",  # This should be skipped
-            "aten::addmm.default",
-        ]
-    )
-    assert len(resolved) == 2  # Only 2 valid ops
-    assert resolved[0] is torch.ops.aten.mm.default
-    assert resolved[1] is torch.ops.aten.addmm.default
+    graph = torch.fx.Graph()
+    node = torch.fx.Node(
+        graph=graph,
+        name="dummy_node",
+        op="call_function",
+        target=torch.ops.aten.add.default,
+        args=(),
+        kwargs={},
+    )
+
+    # supports OpOverloadPacket
+    splitting_ops = ["aten::add"]
+    assert should_split(node, splitting_ops)
+
+    # supports OpOverload
+    splitting_ops = ["aten::add.default"]
+    assert should_split(node, splitting_ops)
+
+    # a different OpOverload name does not match
+    splitting_ops = ["aten::add.Tensor"]
+    assert not should_split(node, splitting_ops)
+
+    @torch.library.custom_op(
+        "silly::attention",
+        mutates_args=["out"],
+    )
+    def attention(
+        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor
+    ) -> None:
+        out.copy_(q + k + v)
+
+    q, k, v, out = [torch.randn(1)] * 4
+
+    # supports custom ops as OpOverloadPacket
+    node = torch.fx.Node(
+        graph=graph,
+        name="dummy_node",
+        op="call_function",
+        target=torch.ops.silly.attention,
+        args=(q, k, v, out),
+        kwargs={},
+    )
+
+    splitting_ops = ["silly::attention"]
+    assert should_split(node, splitting_ops)
+
+    # supports custom ops as OpOverload
+    node = torch.fx.Node(
+        graph=graph,
+        name="dummy_node",
+        op="call_function",
+        target=torch.ops.silly.attention.default,
+        args=(q, k, v, out),
+        kwargs={},
+    )
+
+    splitting_ops = ["silly::attention"]
+    assert should_split(node, splitting_ops)
+
+    splitting_ops = ["silly::attention.default"]
+    assert should_split(node, splitting_ops)


 @pytest.mark.skipif(
@@ -19,7 +19,7 @@ import vllm.envs as envs
 from vllm.compilation.inductor_pass import pass_context
 from vllm.compilation.partition_rules import (
     inductor_partition_rule_context,
-    resolve_defined_ops,
+    should_split,
 )
 from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
 from vllm.logger import init_logger
@@ -303,7 +303,7 @@ class SplitItem:


 def split_graph(
-    graph: fx.GraphModule, resolved_ops: list[torch._ops.OpOverload]
+    graph: fx.GraphModule, splitting_ops: list[str]
 ) -> tuple[fx.GraphModule, list[SplitItem]]:
     # split graph by ops
     subgraph_id = 0
@@ -312,12 +312,8 @@ def split_graph(
     for node in graph.graph.nodes:
         if node.op in ("output", "placeholder"):
             continue
-        # Match node.target against resolved_ops
-        # node.target can be OpOverloadPacket, need to check .default
-        if node.op == "call_function" and (
-            node.target in resolved_ops
-            or (hasattr(node.target, "default") and node.target.default in resolved_ops)
-        ):
+        if should_split(node, splitting_ops):
            subgraph_id += 1
            node_to_subgraph_id[node] = subgraph_id
            split_op_graphs.append(subgraph_id)
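For orientation, a minimal sketch of the new call shape of split_graph, with the operator names passed straight through as strings (graph_module stands in for the FX GraphModule the backend already holds; the op name is illustrative):

    # no resolve_defined_ops() step; qualified names are matched per node
    split_gm, split_items = split_graph(graph_module, ["silly::attention"])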
@@ -653,8 +649,7 @@ class VllmBackend:
         else:
             fx_split_ops = self.compilation_config.splitting_ops or []

-            resolved_split_ops = resolve_defined_ops(fx_split_ops)
-            self.split_gm, self.piecewise_graphs = split_graph(graph, resolved_split_ops)
+            self.split_gm, self.piecewise_graphs = split_graph(graph, fx_split_ops)

         from torch._dynamo.utils import lazy_format_graph_code

@@ -2,54 +2,39 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import contextlib
-import logging

 import torch
-from torch._library.utils import lookup_op

 from vllm.logger import init_logger

 logger = init_logger(__name__)


-def resolve_defined_ops(op_names: list[str]) -> list["torch._ops.OpOverload"]:
-    """Resolve operator names to OpOverload objects.
-
-    Skips operators that fail to resolve (e.g., operators not registered or
-    model-specific operators not present in the current model).
-
-    Note: Users should inspect the operator graph before lowering and ensure
-    the specified operators are present in the final graph. Built-in PyTorch
-    operators (aten::*, torch::*) may be decomposed, fused, or transformed
-    during Inductor's compilation passes, so use them with caution.
-
-    Args:
-        op_names: List of operator names in PyTorch format
-            (e.g., "vllm::unified_attention")
-
-    Returns:
-        List of successfully resolved operator overloads
-    """
-    resolved = []
-    for op_name in op_names:
-        try:
-            resolved.append(lookup_op(op_name))
-        except Exception:
-            # Skip operators that don't exist (e.g., model-specific ops)
-            # Do not warn for attention ops, warn for others
-            # (most likely manually specified)
-            from vllm.config import CompilationConfig
-
-            logger.log(
-                logging.DEBUG
-                if op_name in CompilationConfig._attention_ops
-                else logging.WARNING,
-                "Failed to resolve operator for CUDAGraph partition: %s",
-                op_name,
-            )
-            continue
-
-    return resolved
+def should_split(node: torch.fx.Node, splitting_ops: list[str]) -> bool:
+    """
+    Check whether a node should be split out for the dynamo graph partition.
+    It operates on the dynamo graph, so node.target can be anything; we only
+    check and split on OpOverload and OpOverloadPacket targets.
+    """
+
+    if node.op != "call_function":
+        return False
+
+    target = node.target
+
+    if isinstance(target, torch._ops.OpOverloadPacket):
+        # Example: "aten::add"
+        return target._qualified_op_name in splitting_ops
+
+    if isinstance(target, torch._ops.OpOverload):
+        # Example: "aten::add"
+        packet_name = target.name()
+
+        # Example: "aten::add.default"
+        op_overload_name = f"{packet_name}.{target._overloadname}"
+        return op_overload_name in splitting_ops or packet_name in splitting_ops
+
+    return False


 @contextlib.contextmanager
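As a rough illustration of the name forms should_split compares against splitting_ops (the private attributes are the same ones used in the code above; the exact printed strings are assumed, not verified here):

    import torch

    packet = torch.ops.aten.mm            # OpOverloadPacket
    overload = torch.ops.aten.mm.default  # OpOverload

    # Packet name, e.g. "aten::mm"
    print(packet._qualified_op_name)

    # Overload name built the same way should_split builds it, e.g. "aten::mm.default"
    print(f"{overload.name()}.{overload._overloadname}")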