From 262d263f6c56fa95e15422d3a475da8efdf67cc1 Mon Sep 17 00:00:00 2001
From: Yanan Cao
Date: Thu, 13 Nov 2025 12:09:05 -0800
Subject: [PATCH] [Bugfix] Eliminate tuple inputs to submodules in graph partitioning (#28533)

Signed-off-by: Yanan Cao
---
 .buildkite/test-pipeline.yaml         |   1 +
 tests/compile/test_graph_partition.py | 124 ++++++++++++++++++++++++++
 vllm/compilation/backends.py          |  17 +++-
 3 files changed, 140 insertions(+), 2 deletions(-)
 create mode 100644 tests/compile/test_graph_partition.py

diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index be1b79ddc4324..52539728215bb 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -445,6 +445,7 @@ steps:
   - vllm/
   - tests/compile
   commands:
+  - pytest -v -s compile/test_graph_partition.py
   - pytest -v -s compile/test_config.py
   - pytest -v -s compile/test_pass_manager.py
   - pytest -v -s compile/test_fusion.py
diff --git a/tests/compile/test_graph_partition.py b/tests/compile/test_graph_partition.py
new file mode 100644
index 0000000000000..1cd783843a626
--- /dev/null
+++ b/tests/compile/test_graph_partition.py
@@ -0,0 +1,124 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import operator
+
+import pytest
+import torch
+from torch.fx.experimental.proxy_tensor import make_fx
+
+from vllm.compilation.backends import split_graph
+
+
+def test_getitem_moved_to_producer_subgraph():
+    """
+    Test that getitem operations are moved to the same subgraph as their input,
+    preventing tuple inputs to submodules.
+    """
+
+    def model_fn(x: torch.Tensor) -> torch.Tensor:
+        # torch.split returns a tuple, creating real getitem operations.
+        # This should become the first submodule, which produces the tuple.
+        chunks = torch.split(x, x.shape[0] // 2, dim=0)
+
+        # The following ops should become the second submodule, consuming the tuple.
+        result_0 = torch.relu(chunks[0])
+        result_1 = torch.relu(chunks[1])
+        return torch.cat([result_0, result_1], dim=0)
+
+    x = torch.randn(4, 3)
+    gm = make_fx(model_fn)(x)
+
+    has_getitem = any(
+        node.op == "call_function" and node.target == operator.getitem
+        for node in gm.graph.nodes
+    )
+    assert has_getitem, "Test setup failed: graph should contain getitem operations"
+
+    # Split on the tuple producer aten::split
+    split_ops = ["aten::split.Tensor"]
+    split_gm, split_items = split_graph(gm, split_ops)
+    assert len(split_items) == 2, "Graph should be split into 2 submodules"
+
+    for split_item in split_items:
+        submodule = split_item.graph
+
+        getitem_on_placeholder = []
+        for node in submodule.graph.nodes:
+            if (
+                node.op == "call_function"
+                and node.target == operator.getitem
+                and node.args[0].op == "placeholder"
+            ):
+                getitem_on_placeholder.append(node)
+
+        assert len(getitem_on_placeholder) == 0, (
+            f"Submodule {split_item.submod_name} has getitem operations on "
+            f"placeholder nodes: {[n.name for n in getitem_on_placeholder]}. "
+            "This means tuple inputs were not properly eliminated."
+        )
+
+    new_x = torch.randn(4, 3)
+    output_original = gm(new_x)
+    output_split = split_gm(new_x)
+
+    assert torch.allclose(output_original, output_split), "Output mismatch"
+
+
+def test_no_tuple_inputs_with_multiple_consumers():
+    """
+    Test that when a tuple is consumed by multiple submodules after splitting,
+    getitem operations are properly moved to avoid tuple inputs.
+    """
+
+    def model_fn(x: torch.Tensor) -> torch.Tensor:
+        # torch.split returns a tuple, creating real getitem operations.
+        # This should become the first submodule, which produces the tuple.
+        chunks = torch.split(x, x.shape[0] // 2, dim=0)
+
+        # These should become the second submodule, consuming the tuple.
+        result_1 = torch.relu(chunks[0])
+        result_2 = torch.relu(chunks[1])
+
+        # Artificial graph-splitting point that creates another independent
+        # submodule consuming the tuple later.
+        # This becomes the third submodule.
+        result_1 = torch.sigmoid(result_1)
+
+        # The fourth submodule also consumes the tuple.
+        result = torch.cat([chunks[0], chunks[1], result_1, result_2])
+        return result
+
+    x = torch.randn(4, 3)
+    gm = make_fx(model_fn)(x)
+
+    has_getitem = any(
+        node.op == "call_function" and node.target == operator.getitem
+        for node in gm.graph.nodes
+    )
+    assert has_getitem, "Test setup failed: graph should contain getitem operations"
+
+    split_ops = ["aten::split.Tensor", "aten::sigmoid"]
+    split_gm, split_items = split_graph(gm, split_ops)
+    assert len(split_items) == 4, "Graph should be split into 4 submodules"
+
+    for split_item in split_items:
+        submodule = split_item.graph
+
+        for node in submodule.graph.nodes:
+            if (
+                node.op == "call_function"
+                and node.target == operator.getitem
+                and node.args[0].op == "placeholder"
+            ):
+                pytest.fail(
+                    f"Submodule {split_item.submod_name} has getitem on "
+                    f"placeholder {node.args[0].name}, indicating it receives "
+                    "a tuple input"
+                )
+
+    new_x = torch.randn(4, 3)
+    output_original = gm(new_x)
+    output_split = split_gm(new_x)
+
+    assert torch.allclose(output_original, output_split), "Output mismatch after split"
diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py
index be69075f94f09..60ef6eef21663 100644
--- a/vllm/compilation/backends.py
+++ b/vllm/compilation/backends.py
@@ -4,6 +4,7 @@
 import ast
 import dataclasses
 import hashlib
+import operator
 import os
 import pprint
 import time
@@ -307,12 +308,24 @@ def split_graph(
 ) -> tuple[fx.GraphModule, list[SplitItem]]:
     # split graph by ops
     subgraph_id = 0
-    node_to_subgraph_id = {}
-    split_op_graphs = []
+    node_to_subgraph_id: dict[fx.Node, int] = {}
+    split_op_graphs: list[int] = []
     for node in graph.graph.nodes:
         if node.op in ("output", "placeholder"):
             continue
+        # If this is a getitem operation on a node from an earlier subgraph,
+        # assign it to the same subgraph as its input. This avoids passing an
+        # entire tuple as an input to a submodule, which violates the input
+        # requirements of standalone_compile and AOTAutograd.
+        if node.op == "call_function" and node.target == operator.getitem:
+            # Keep this getitem in the same subgraph as its producer
+            input_node = node.args[0]
+            if input_node.op != "placeholder":
+                assert input_node in node_to_subgraph_id
+                node_to_subgraph_id[node] = node_to_subgraph_id[input_node]
+                continue
+
         if should_split(node, splitting_ops):
             subgraph_id += 1
             node_to_subgraph_id[node] = subgraph_id
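
Note (not part of the patch): the partition rule added above can be illustrated in isolation. The sketch below assumes only torch; is_splitting_op, assign_subgraph_ids, and fn are hypothetical stand-ins for vLLM's should_split and the assignment loop inside split_graph, showing how pinning each operator.getitem to its producer's subgraph keeps tuples from crossing submodule boundaries.

import operator

import torch
import torch.fx
from torch.fx.experimental.proxy_tensor import make_fx


def is_splitting_op(node: torch.fx.Node) -> bool:
    # Hypothetical stand-in for should_split(): treat aten.split as a splitting op.
    return node.op == "call_function" and "split" in str(node.target)


def assign_subgraph_ids(gm: torch.fx.GraphModule) -> dict[torch.fx.Node, int]:
    # Mirrors the subgraph-id assignment loop: splitting ops get their own
    # subgraph, and getitem nodes are pinned to their producer's subgraph.
    subgraph_id = 0
    node_to_subgraph_id: dict[torch.fx.Node, int] = {}
    for node in gm.graph.nodes:
        if node.op in ("output", "placeholder"):
            continue
        if node.op == "call_function" and node.target == operator.getitem:
            producer = node.args[0]
            if producer.op != "placeholder":
                # Without this, the getitem would land in a later subgraph and
                # the producer's tuple output would cross the boundary.
                node_to_subgraph_id[node] = node_to_subgraph_id[producer]
                continue
        if is_splitting_op(node):
            subgraph_id += 1
            node_to_subgraph_id[node] = subgraph_id
            subgraph_id += 1
        else:
            node_to_subgraph_id[node] = subgraph_id
    return node_to_subgraph_id


def fn(x: torch.Tensor) -> torch.Tensor:
    a, b = torch.split(x, x.shape[0] // 2, dim=0)
    return torch.cat([torch.relu(a), torch.relu(b)], dim=0)


gm = make_fx(fn)(torch.randn(4, 3))
ids = assign_subgraph_ids(gm)
getitem_ids = {ids[n] for n in gm.graph.nodes if n.target == operator.getitem}
split_ids = {ids[n] for n in gm.graph.nodes if is_splitting_op(n)}
# Every getitem shares a subgraph id with its aten.split producer, so no
# submodule boundary has to carry the tuple itself.
assert getitem_ids and getitem_ids <= split_ids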