[Bugfix] Eliminate tuple inputs to submodules in graph partitioning (#28533)

Signed-off-by: Yanan Cao <gmagogsfm@gmail.com>
Yanan Cao 2025-11-13 12:09:05 -08:00 committed by GitHub
parent 968060c15a
commit 262d263f6c
3 changed files with 140 additions and 2 deletions
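
For context, here is a minimal sketch (not part of this commit) that reproduces the problem the fix targets: when a tuple-producing op such as aten::split lands in one submodule and its getitem consumers land in another, the consumer submodule receives the whole tuple as a placeholder input, which standalone compilation paths cannot accept. The naive_callback helper below is hypothetical and only for illustration.

    import operator

    import torch
    from torch.fx.experimental.proxy_tensor import make_fx
    from torch.fx.passes.split_module import split_module


    def fn(x):
        # aten::split returns a tuple of tensors; unpacking traces getitem nodes
        a, b = torch.split(x, x.shape[0] // 2, dim=0)
        return torch.cat([torch.relu(a), torch.relu(b)], dim=0)


    gm = make_fx(fn)(torch.randn(4, 3))


    def naive_callback(node: torch.fx.Node) -> int:
        # Put the tuple producer in partition 0 and everything else in
        # partition 1, with no special handling for getitem.
        return 0 if node.target == torch.ops.aten.split.Tensor else 1


    split_gm = split_module(gm, None, naive_callback)

    # The consumer submodule now takes the whole tuple as a placeholder and
    # unpacks it with operator.getitem -- exactly the pattern this commit removes.
    for name, submod in split_gm.named_children():
        bad = [
            n.name
            for n in submod.graph.nodes
            if n.op == "call_function"
            and n.target == operator.getitem
            and n.args[0].op == "placeholder"
        ]
        if bad:
            print(f"{name} receives a tuple input, unpacked by: {bad}")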

View File

@@ -445,6 +445,7 @@ steps:
- vllm/
- tests/compile
commands:
- pytest -v -s compile/test_graph_partition.py
- pytest -v -s compile/test_config.py
- pytest -v -s compile/test_pass_manager.py
- pytest -v -s compile/test_fusion.py

View File

@@ -0,0 +1,124 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import operator

import pytest
import torch
from torch.fx.experimental.proxy_tensor import make_fx

from vllm.compilation.backends import split_graph


def test_getitem_moved_to_producer_subgraph():
    """
    Test that getitem operations are moved to the same subgraph as their input,
    preventing tuple inputs to submodules.
    """

    def model_fn(x: torch.Tensor) -> torch.Tensor:
        # torch.split returns a tuple, creating real getitem operations
        # Should become first submodule that produces tuple
        chunks = torch.split(x, x.shape[0] // 2, dim=0)
        # Following ops should become second submodule that consumes tuple
        result_0 = torch.relu(chunks[0])
        result_1 = torch.relu(chunks[1])
        return torch.cat([result_0, result_1], dim=0)

    x = torch.randn(4, 3)
    gm = make_fx(model_fn)(x)

    has_getitem = any(
        node.op == "call_function" and node.target == operator.getitem
        for node in gm.graph.nodes
    )
    assert has_getitem, "Test setup failed: graph should contain getitem operations"

    # Split on tuple producer aten::split
    split_ops = ["aten::split.Tensor"]
    split_gm, split_items = split_graph(gm, split_ops)

    assert len(split_items) == 2, "Graph should be split into 2 submodules"

    for split_item in split_items:
        submodule = split_item.graph
        getitem_on_placeholder = []
        for node in submodule.graph.nodes:
            if (
                node.op == "call_function"
                and node.target == operator.getitem
                and node.args[0].op == "placeholder"
            ):
                getitem_on_placeholder.append(node)
        assert len(getitem_on_placeholder) == 0, (
            f"Submodule {split_item.submod_name} has getitem operations on "
            f"placeholder nodes: {[n.name for n in getitem_on_placeholder]}. "
            "This means tuple inputs were not properly eliminated."
        )

    new_x = torch.randn(4, 3)
    output_original = gm(new_x)
    output_split = split_gm(new_x)
    assert torch.allclose(output_original, output_split), "Output mismatch"
def test_no_tuple_inputs_with_multiple_consumers():
    """
    Test that when a tuple is consumed by multiple split operations,
    getitem operations are properly moved to avoid tuple inputs.
    """

    def model_fn(x: torch.Tensor) -> torch.Tensor:
        # torch.split returns a tuple, creating real getitem operations
        # Should become first submodule that produces tuple
        chunks = torch.split(x, x.shape[0] // 2, dim=0)
        # These should become second submodule consuming tuple
        result_1 = torch.relu(chunks[0])
        result_2 = torch.relu(chunks[1])
        # Artificial graph splitting point to create another
        # independent submodule that consumes tuple later
        # This would become the third submodule
        result_1 = torch.sigmoid(result_1)
        # Fourth submodule that consumes tuple
        result = torch.cat([chunks[0], chunks[1], result_1, result_2])
        return result

    x = torch.randn(4, 3)
    gm = make_fx(model_fn)(x)

    has_getitem = any(
        node.op == "call_function" and node.target == operator.getitem
        for node in gm.graph.nodes
    )
    assert has_getitem, "Test setup failed: graph should contain getitem operations"

    split_ops = ["aten::split.Tensor", "aten::sigmoid"]
    split_gm, split_items = split_graph(gm, split_ops)

    assert len(split_items) == 4, "Graph should be split into 4 submodules"

    for split_item in split_items:
        submodule = split_item.graph
        for node in submodule.graph.nodes:
            if (
                node.op == "call_function"
                and node.target == operator.getitem
                and node.args[0].op == "placeholder"
            ):
                pytest.fail(
                    f"Submodule {split_item.submod_name} has getitem on "
                    f"placeholder {node.args[0].name}, indicating it receives "
                    "a tuple input"
                )

    new_x = torch.randn(4, 3)
    output_original = gm(new_x)
    output_split = split_gm(new_x)
    assert torch.allclose(output_original, output_split), "Output mismatch after split"

View File

@@ -4,6 +4,7 @@
import ast
import dataclasses
import hashlib
import operator
import os
import pprint
import time
@@ -307,12 +308,24 @@ def split_graph(
) -> tuple[fx.GraphModule, list[SplitItem]]:
    # split graph by ops
    subgraph_id = 0
    node_to_subgraph_id = {}
    split_op_graphs = []
    node_to_subgraph_id: dict[fx.Node, int] = {}
    split_op_graphs: list[int] = []
    for node in graph.graph.nodes:
        if node.op in ("output", "placeholder"):
            continue
        # Check if this is a getitem operation on a node from an earlier subgraph.
        # If so, assign it to the same subgraph as its input to avoid passing the
        # entire tuple as an input to submodules, which violates the input
        # requirements of standalone_compile and AOTAutograd.
        if node.op == "call_function" and node.target == operator.getitem:
            # Assign this getitem to the same subgraph as its input
            input_node = node.args[0]
            if input_node.op != "placeholder":
                assert input_node in node_to_subgraph_id
                node_to_subgraph_id[node] = node_to_subgraph_id[input_node]
                continue
        if should_split(node, splitting_ops):
            subgraph_id += 1
            node_to_subgraph_id[node] = subgraph_id
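
To round out the picture, here is a minimal sketch (not part of the change above) of the same partitioning with the getitem reassignment applied: by giving each getitem the partition id of its tuple producer, the resulting submodules exchange only plain tensors. The loop below mirrors the idea in split_graph but is simplified, uses torch.fx.passes.split_module directly, and all names are illustrative.

    import operator

    import torch
    from torch.fx.experimental.proxy_tensor import make_fx
    from torch.fx.passes.split_module import split_module


    def fn(x):
        a, b = torch.split(x, x.shape[0] // 2, dim=0)
        return torch.cat([torch.relu(a), torch.relu(b)], dim=0)


    gm = make_fx(fn)(torch.randn(4, 3))

    # Build a node -> partition id map; getitem inherits its producer's partition.
    partition: dict[torch.fx.Node, int] = {}
    current = 0
    for node in gm.graph.nodes:
        if node.op in ("output", "placeholder"):
            continue
        if node.op == "call_function" and node.target == operator.getitem:
            partition[node] = partition[node.args[0]]
            continue
        if node.target == torch.ops.aten.split.Tensor:
            # The splitting op gets its own partition.
            current += 1
            partition[node] = current
            current += 1
            continue
        partition[node] = current

    split_gm = split_module(gm, None, lambda n: partition[n], keep_original_order=True)

    # No submodule unpacks a tuple placeholder anymore, and outputs still match.
    x = torch.randn(4, 3)
    assert torch.allclose(gm(x), split_gm(x))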