diff --git a/tests/compile/test_config.py b/tests/compile/test_config.py
index 4145e84c2ee0c..7455147f2b95a 100644
--- a/tests/compile/test_config.py
+++ b/tests/compile/test_config.py
@@ -214,28 +214,72 @@ def test_splitting_ops_dynamic():
     assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE
 
 
-def test_resolve_operator_overload():
+def test_should_split():
     import torch
 
-    from vllm.compilation.partition_rules import resolve_defined_ops
+    from vllm.compilation.partition_rules import should_split
 
-    # Test valid operator names
-    resolved = resolve_defined_ops(["aten::mm.default", "aten::addmm.default"])
-    assert len(resolved) == 2
-    assert resolved[0] is torch.ops.aten.mm.default
-    assert resolved[1] is torch.ops.aten.addmm.default
-
-    # Test that invalid operators are skipped (not raising exceptions)
-    resolved = resolve_defined_ops(
-        [
-            "aten::mm.default",
-            "aten::nonexistent_op.default",  # This should be skipped
-            "aten::addmm.default",
-        ]
+    graph = torch.fx.Graph()
+    node = torch.fx.Node(
+        graph=graph,
+        name="dummy_node",
+        op="call_function",
+        target=torch.ops.aten.add.default,
+        args=(),
+        kwargs={},
     )
-    assert len(resolved) == 2  # Only 2 valid ops
-    assert resolved[0] is torch.ops.aten.mm.default
-    assert resolved[1] is torch.ops.aten.addmm.default
+
+    # supports OpOverloadPacket
+    splitting_ops = ["aten::add"]
+    assert should_split(node, splitting_ops)
+
+    # supports OpOverload
+    splitting_ops = ["aten::add.default"]
+    assert should_split(node, splitting_ops)
+
+    # supports OpOverload
+    splitting_ops = ["aten::add.Tensor"]
+    assert not should_split(node, splitting_ops)
+
+    @torch.library.custom_op(
+        "silly::attention",
+        mutates_args=["out"],
+    )
+    def attention(
+        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor
+    ) -> None:
+        out.copy_(q + k + v)
+
+    q, k, v, out = [torch.randn(1)] * 4
+
+    # supports custom ops as OpOverloadPacket
+    node = torch.fx.Node(
+        graph=graph,
+        name="dummy_node",
+        op="call_function",
+        target=torch.ops.silly.attention,
+        args=(q, k, v, out),
+        kwargs={},
+    )
+
+    splitting_ops = ["silly::attention"]
+    assert should_split(node, splitting_ops)
+
+    # supports custom ops as OpOverload
+    node = torch.fx.Node(
+        graph=graph,
+        name="dummy_node",
+        op="call_function",
+        target=torch.ops.silly.attention.default,
+        args=(q, k, v, out),
+        kwargs={},
+    )
+
+    splitting_ops = ["silly::attention"]
+    assert should_split(node, splitting_ops)
+
+    splitting_ops = ["silly::attention.default"]
+    assert should_split(node, splitting_ops)
 
 
 @pytest.mark.skipif(
diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py
index 38300bebb8705..be69075f94f09 100644
--- a/vllm/compilation/backends.py
+++ b/vllm/compilation/backends.py
@@ -19,7 +19,7 @@ import vllm.envs as envs
 from vllm.compilation.inductor_pass import pass_context
 from vllm.compilation.partition_rules import (
     inductor_partition_rule_context,
-    resolve_defined_ops,
+    should_split,
 )
 from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
 from vllm.logger import init_logger
@@ -303,7 +303,7 @@ class SplitItem:
 
 
 def split_graph(
-    graph: fx.GraphModule, resolved_ops: list[torch._ops.OpOverload]
+    graph: fx.GraphModule, splitting_ops: list[str]
 ) -> tuple[fx.GraphModule, list[SplitItem]]:
     # split graph by ops
     subgraph_id = 0
@@ -312,12 +312,8 @@ def split_graph(
     for node in graph.graph.nodes:
         if node.op in ("output", "placeholder"):
             continue
-        # Match node.target against resolved_ops
-        # node.target can be OpOverloadPacket, need to check .default
-        if node.op == "call_function" and (
-            node.target in resolved_ops
-            or (hasattr(node.target, "default") and node.target.default in resolved_ops)
-        ):
+
+        if should_split(node, splitting_ops):
             subgraph_id += 1
             node_to_subgraph_id[node] = subgraph_id
             split_op_graphs.append(subgraph_id)
@@ -653,8 +649,7 @@ class VllmBackend:
         else:
             fx_split_ops = self.compilation_config.splitting_ops or []
 
-        resolved_split_ops = resolve_defined_ops(fx_split_ops)
-        self.split_gm, self.piecewise_graphs = split_graph(graph, resolved_split_ops)
+        self.split_gm, self.piecewise_graphs = split_graph(graph, fx_split_ops)
 
         from torch._dynamo.utils import lazy_format_graph_code
 
diff --git a/vllm/compilation/partition_rules.py b/vllm/compilation/partition_rules.py
index 094b86dcb4aa2..08bd27e809526 100644
--- a/vllm/compilation/partition_rules.py
+++ b/vllm/compilation/partition_rules.py
@@ -2,54 +2,39 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import contextlib
-import logging
 
 import torch
-from torch._library.utils import lookup_op
 
 from vllm.logger import init_logger
 
 logger = init_logger(__name__)
 
 
-def resolve_defined_ops(op_names: list[str]) -> list["torch._ops.OpOverload"]:
-    """Resolve operator names to OpOverload objects.
-
-    Skips operators that fail to resolve (e.g., operators not registered or
-    model-specific operators not present in the current model).
-
-    Note: Users should inspect the operator graph before lowering and ensure
-    the specified operators are present in the final graph. Built-in PyTorch
-    operators (aten::*, torch::*) may be decomposed, fused, or transformed
-    during Inductor's compilation passes, so use them with caution.
-
-    Args:
-        op_names: List of operator names in PyTorch format
-            (e.g., "vllm::unified_attention")
-
-    Returns:
-        List of successfully resolved operator overloads
+def should_split(node: torch.fx.Node, splitting_ops: list[str]) -> bool:
+    """
+    Check if a node should be split for dynamo graph partition.
+    It operates on dynamo graph, so the node.target can be anything.
+    We need to check and split only on OpOverload and OpOverloadPacket.
     """
-    resolved = []
-    for op_name in op_names:
-        try:
-            resolved.append(lookup_op(op_name))
-        except Exception:
-            # Skip operators that don't exist (e.g., model-specific ops)
-            # Do not warn for attention ops, warn for others
-            # (most likely manually specified)
-            from vllm.config import CompilationConfig
-            logger.log(
-                logging.DEBUG
-                if op_name in CompilationConfig._attention_ops
-                else logging.WARNING,
-                "Failed to resolve operator for CUDAGraph partition: %s",
-                op_name,
-            )
-            continue
+    if node.op != "call_function":
+        return False
 
-    return resolved
+    target = node.target
+
+    if isinstance(target, torch._ops.OpOverloadPacket):
+        # Example: "aten::add"
+        return target._qualified_op_name in splitting_ops
+
+    if isinstance(target, torch._ops.OpOverload):
+        # Example: "aten::add"
+        packet_name = target.name()
+
+        # Example: "aten::add.default"
+        op_overload_name = f"{packet_name}.{target._overloadname}"
+        return op_overload_name in splitting_ops or packet_name in splitting_ops
+
+    return False
 
 
 @contextlib.contextmanager