[Compile] Fix noop_elimination pass and add tests for noop_elimination (#24880)

Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
This commit is contained in:
Jiangyun Zhu 2025-09-16 07:33:18 +08:00 committed by GitHub
parent 45bfa49cb8
commit 5bcc153d7b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 130 additions and 23 deletions

View File

@ -394,6 +394,7 @@ steps:
- pytest -v -s compile/test_async_tp.py
- pytest -v -s compile/test_fusion_all_reduce.py
- pytest -v -s compile/test_decorator.py
- pytest -v -s compile/test_noop_elimination.py
- label: PyTorch Fullgraph Smoke Test # 15min
timeout_in_minutes: 30

View File

@ -64,4 +64,8 @@ class TestBackend:
num_pre = len(list(find_op_nodes(op, self.graph_pre_pass)))
num_post = len(list(find_op_nodes(op, self.graph_post_pass)))
assert num_pre == 0, f"Unexpected op {op.name()} in pre-pass graph"
assert num_post > 0, f"Op {op.name()} not found in post-pass graph"
assert num_post > 0, f"Op {op.name()} not found in post-pass graph"
def op_count(self, op: OpOverload, before=False) -> int:
    """Return how many nodes calling ``op`` one of the captured graphs holds.

    :param op: The operator overload to look for.
    :param before: If true, inspect the graph captured before the pass
        ran; otherwise inspect the graph captured after it.
    :return: The number of nodes yielded by ``find_op_nodes`` for ``op``.
    """
    if before:
        target_graph = self.graph_pre_pass
    else:
        target_graph = self.graph_post_pass
    # Count matches lazily; only the total is needed.
    return sum(1 for _ in find_op_nodes(op, target_graph))

View File

@ -0,0 +1,106 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
import vllm
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import (CompilationConfig, CompilationLevel, PassConfig,
VllmConfig)
from .backend import TestBackend
@pytest.mark.parametrize("dtype",
                         [torch.float16, torch.bfloat16, torch.float32])
@pytest.mark.parametrize("num_tokens", [256, 1024])
@pytest.mark.parametrize("hidden_size", [64, 4096])
def test_noop_elimination(dtype, num_tokens, hidden_size):
    """Check that NoOpEliminationPass removes no-op ops from the graph.

    A small model made of a reshape chain, a no-op reshape, a no-op slice
    and a full-range slice_scatter is run eagerly and through
    ``torch.compile`` with a ``TestBackend`` wrapping the pass. The two
    results must match, and the post-pass graph must contain exactly one
    reshape and no slice / slice_scatter ops.
    """
    # NOTE(review): the default device and dtype set here are not restored
    # inside this function -- confirm no other test relies on the previous
    # process-wide defaults.
    torch.set_default_device("cuda")
    torch.set_default_dtype(dtype)
    torch.manual_seed(1)

    class Model(torch.nn.Module):

        def forward(self, x):
            # Chain of reshapes
            y = x.reshape(-1, 128, 32)
            z = y.reshape(-1, 4096)
            # No-op reshape
            a = z.reshape(-1, 4096)
            # Final reshape that should remain
            b = a.reshape(-1, 128, 32)
            # No-op slice
            c = b[0:b.shape[0]]
            # The pass should replace the result of this op with `c`.
            d = torch.slice_scatter(
                torch.ones_like(c),  # Dummy tensor to be scattered into
                c,  # Source tensor
                0,  # dim
                0,  # start
                c.shape[0],  # end
            )
            return d

    vllm_config = VllmConfig(compilation_config=CompilationConfig(
        level=CompilationLevel.PIECEWISE,
        pass_config=PassConfig(enable_noop=True),
    ))
    with vllm.config.set_current_vllm_config(vllm_config):
        noop_pass = NoOpEliminationPass(vllm_config)
        backend = TestBackend(noop_pass)
        model = Model()
        # First dimension dynamic
        x = torch.rand(num_tokens, hidden_size)
        torch._dynamo.mark_dynamic(x, 0)
        # Eager reference result, computed before compiling.
        result = model(x)
        model2 = torch.compile(model, backend=backend)
        result2 = model2(x)
        ATOL, RTOL = (2e-3, 2e-3)
        torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)
        # The no-op reshape and slice should be eliminated.
        # The chain of reshapes should be fused into a single reshape.
        assert backend.op_count(torch.ops.aten.reshape.default) == 1
        assert backend.op_count(torch.ops.aten.slice.Tensor) == 0
        assert backend.op_count(torch.ops.aten.slice_scatter.default) == 0
def test_non_noop_slice_preserved():
    """Ensure that a slice with end=-1 (dropping last row) is NOT eliminated.

    Regression test for a bug where end=-1 was treated like an inferred
    dimension (reshape semantics) leading to incorrect elimination.
    """
    # NOTE(review): the default device set here is not restored inside this
    # function -- confirm no other test relies on the previous default.
    torch.set_default_device("cuda")
    x = torch.randn(16, 16)

    class SliceModel(torch.nn.Module):

        def forward(self, x):
            base = x.clone()
            # `src` has 15 rows, one fewer than `base`, matching the
            # start=0 / end=-1 range that is scattered into below.
            src = torch.ones(15, 16)
            y = torch.slice_scatter(base, src, dim=0, start=0, end=-1)
            return x[0:-1, :], y

    vllm_config = VllmConfig(compilation_config=CompilationConfig(
        level=CompilationLevel.PIECEWISE,
        pass_config=PassConfig(enable_noop=True),
    ))
    with vllm.config.set_current_vllm_config(vllm_config):
        noop_pass = NoOpEliminationPass(vllm_config)
        backend = TestBackend(noop_pass)
        model = SliceModel()
        # Eager reference outputs to compare the compiled run against.
        ref = model(x)
        compiled = torch.compile(model, backend=backend)
        out = compiled(x)
        torch.testing.assert_close(ref, out)
        # The slice should remain (not a no-op).
        assert backend.op_count(torch.ops.aten.slice.Tensor) == 1
        assert backend.op_count(torch.ops.aten.slice_scatter.default) == 1

View File

@ -62,9 +62,6 @@ class NoOpEliminationPass(VllmInductorPass):
scaled_mm: "f16[s0, 4096]" = ...
at = auto_functionalized(fused_add_rms_norm, input = scaled_mm, ...)
out: "f16[s0, 4096]" = at[1]
TODO(luka): This is currently tested in test_fusion,
but separate tests could be good.
"""
def __call__(self, graph: torch.fx.Graph):
@ -96,17 +93,19 @@ class NoOpEliminationPass(VllmInductorPass):
# Invalid reshape args, skip
continue
if self.all_dims_equivalent(shape, input_shape):
if self.reshape_all_dims_equivalent(shape, input_shape):
node.replace_all_uses_with(input)
graph.erase_node(node)
count += 1
elif is_func(node, torch.ops.aten.slice.Tensor):
# python slicing semantics are different from reshape
# Don't treat -1 as inferred dimension
input, dim_index, start, end = node.args[:4]
input_shape = input.meta["val"].shape
i_dim = input_shape[dim_index]
output_shape = node.meta["val"].shape
if start == 0 and self.dims_equivalent(end, i_dim):
if output_shape == input_shape:
node.replace_all_uses_with(input)
graph.erase_node(node)
count += 1
@ -116,14 +115,7 @@ class NoOpEliminationPass(VllmInductorPass):
base_shape = base.meta["val"].shape
view_shape = view.meta["val"].shape
view_dim = view_shape[dim_index]
# Check that view fully covers base and the full view is used
# (if the view fully covered the base after slicing but was not
# fully used, we could replace slice_scatter with a simple slice
# but that's a niche case).
if (base_shape == view_shape and start == 0
and self.dims_equivalent(end, view_dim)):
if base_shape == view_shape:
node.replace_all_uses_with(view)
graph.erase_node(node)
count += 1
@ -132,13 +124,9 @@ class NoOpEliminationPass(VllmInductorPass):
self.dump_graph(graph, "after_noop_elimination")
self.end_and_log()
def all_dims_equivalent(self, dims: Iterable[Union[int, torch.fx.Node]],
i_dims: Iterable[Union[int, SymInt]]):
return all(
self.dims_equivalent(s, i_s) for s, i_s in zip(dims, i_dims))
def dims_equivalent(self, dim: Union[int, torch.fx.Node],
i_dim: Union[int, SymInt]) -> bool:
# ---------------------- Reshape helpers ----------------------
def reshape_dims_equivalent(self, dim: Union[int, torch.fx.Node],
i_dim: Union[int, SymInt]) -> bool:
"""
This function checks if two dimensions are equivalent.
:param dim: The dimension arg to reshape
@ -156,10 +144,18 @@ class NoOpEliminationPass(VllmInductorPass):
In case 3, the reshape dimension is a torch.fx.Node,
and its value is a SymInt. That value is equal to the
input dimension.
"""
# Case 1 and 2
if dim == i_dim or dim == -1:
return True
# Case 3
return isinstance(dim, torch.fx.Node) and dim.meta["val"] == i_dim
def reshape_all_dims_equivalent(
    self,
    dims: Iterable[Union[int, torch.fx.Node]],
    i_dims: Iterable[Union[int, SymInt]],
) -> bool:
    """
    Check whether a reshape's requested shape matches the input shape
    dimension by dimension.

    :param dims: The requested shape passed to reshape.
    :param i_dims: The shape of the tensor being reshaped.
    :return: True only if both shapes have the same number of dimensions
        and every pair of dimensions is accepted by
        ``reshape_dims_equivalent``.
    """
    # Materialize both shapes so their ranks can be compared. zip() alone
    # silently stops at the shorter iterable, which would let a
    # rank-changing reshape (e.g. [4] vs (4, 1)) look like a no-op.
    target_dims = list(dims)
    input_dims = list(i_dims)
    if len(target_dims) != len(input_dims):
        return False
    return all(
        self.reshape_dims_equivalent(s, i_s)
        for s, i_s in zip(target_dims, input_dims))