From 5bcc153d7bf69ef34bc5788a33f60f1792cf2861 Mon Sep 17 00:00:00 2001
From: Jiangyun Zhu
Date: Tue, 16 Sep 2025 07:33:18 +0800
Subject: [PATCH] [Compile] Fix noop_elimination pass and add tests for
 noop_elimination (#24880)

Signed-off-by: zjy0516
---
 .buildkite/test-pipeline.yaml          |   1 +
 tests/compile/backend.py               |   6 +-
 tests/compile/test_noop_elimination.py | 106 +++++++++++++++++++++++++
 vllm/compilation/noop_elimination.py   |  40 +++++-----
 4 files changed, 130 insertions(+), 23 deletions(-)
 create mode 100644 tests/compile/test_noop_elimination.py

diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index adb5c862eecd9..c3b5aa2907a49 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -394,6 +394,7 @@ steps:
   - pytest -v -s compile/test_async_tp.py
   - pytest -v -s compile/test_fusion_all_reduce.py
   - pytest -v -s compile/test_decorator.py
+  - pytest -v -s compile/test_noop_elimination.py
 
 - label: PyTorch Fullgraph Smoke Test # 15min
   timeout_in_minutes: 30
diff --git a/tests/compile/backend.py b/tests/compile/backend.py
index ace4d25534cdd..2c4287950dcfe 100644
--- a/tests/compile/backend.py
+++ b/tests/compile/backend.py
@@ -64,4 +64,8 @@ class TestBackend:
         num_pre = len(list(find_op_nodes(op, self.graph_pre_pass)))
         num_post = len(list(find_op_nodes(op, self.graph_post_pass)))
         assert num_pre == 0, f"Unexpected op {op.name()} in pre-pass graph"
-        assert num_post > 0, f"Op {op.name()} not found in post-pass graph"
\ No newline at end of file
+        assert num_post > 0, f"Op {op.name()} not found in post-pass graph"
+
+    def op_count(self, op: OpOverload, before: bool = False) -> int:
+        graph = self.graph_pre_pass if before else self.graph_post_pass
+        return len(list(find_op_nodes(op, graph)))
diff --git a/tests/compile/test_noop_elimination.py b/tests/compile/test_noop_elimination.py
new file mode 100644
index 0000000000000..242d531312675
--- /dev/null
+++ b/tests/compile/test_noop_elimination.py
@@ -0,0 +1,106 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import pytest
+import torch
+
+import vllm
+from vllm.compilation.noop_elimination import NoOpEliminationPass
+from vllm.config import (CompilationConfig, CompilationLevel, PassConfig,
+                         VllmConfig)
+
+from .backend import TestBackend
+
+
+@pytest.mark.parametrize("dtype",
+                         [torch.float16, torch.bfloat16, torch.float32])
+@pytest.mark.parametrize("num_tokens", [256, 1024])
+@pytest.mark.parametrize("hidden_size", [64, 4096])
+def test_noop_elimination(dtype, num_tokens, hidden_size):
+    torch.set_default_device("cuda")
+    torch.set_default_dtype(dtype)
+    torch.manual_seed(1)
+
+    class Model(torch.nn.Module):
+
+        def forward(self, x):
+            # Chain of reshapes
+            y = x.reshape(-1, 128, 32)
+            z = y.reshape(-1, 4096)
+            # No-op reshape
+            a = z.reshape(-1, 4096)
+            # Final reshape that should remain
+            b = a.reshape(-1, 128, 32)
+            # No-op slice
+            c = b[0:b.shape[0]]
+            # The pass should replace the result of this op with `c`.
+            d = torch.slice_scatter(
+                torch.ones_like(c),  # Dummy tensor to be scattered into
+                c,  # Source tensor
+                0,  # dim
+                0,  # start
+                c.shape[0],  # end
+            )
+            return d
+
+    vllm_config = VllmConfig(compilation_config=CompilationConfig(
+        level=CompilationLevel.PIECEWISE,
+        pass_config=PassConfig(enable_noop=True),
+    ))
+    with vllm.config.set_current_vllm_config(vllm_config):
+        noop_pass = NoOpEliminationPass(vllm_config)
+
+    backend = TestBackend(noop_pass)
+
+    model = Model()
+    # First dimension dynamic
+    x = torch.rand(num_tokens, hidden_size)
+    torch._dynamo.mark_dynamic(x, 0)
+
+    result = model(x)
+
+    model2 = torch.compile(model, backend=backend)
+    result2 = model2(x)
+
+    ATOL, RTOL = (2e-3, 2e-3)
+    torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)
+
+    # The no-op reshape and slice should be eliminated.
+    # The chain of reshapes should be fused into a single reshape.
+    assert backend.op_count(torch.ops.aten.reshape.default) == 1
+    assert backend.op_count(torch.ops.aten.slice.Tensor) == 0
+    assert backend.op_count(torch.ops.aten.slice_scatter.default) == 0
+
+
+def test_non_noop_slice_preserved():
+    """Ensure that a slice with end=-1 (dropping the last row) is NOT eliminated.
+
+    Regression test for a bug where end=-1 was treated like an inferred
+    dimension (reshape semantics), leading to incorrect elimination.
+    """
+    torch.set_default_device("cuda")
+    x = torch.randn(16, 16)
+
+    class SliceModel(torch.nn.Module):
+
+        def forward(self, x):
+            base = x.clone()
+            src = torch.ones(15, 16)
+            y = torch.slice_scatter(base, src, dim=0, start=0, end=-1)
+            return x[0:-1, :], y
+
+    vllm_config = VllmConfig(compilation_config=CompilationConfig(
+        level=CompilationLevel.PIECEWISE,
+        pass_config=PassConfig(enable_noop=True),
+    ))
+    with vllm.config.set_current_vllm_config(vllm_config):
+        noop_pass = NoOpEliminationPass(vllm_config)
+    backend = TestBackend(noop_pass)
+    model = SliceModel()
+    ref = model(x)
+    compiled = torch.compile(model, backend=backend)
+    out = compiled(x)
+    torch.testing.assert_close(ref, out)
+    # The slice should remain (not a no-op).
+    assert backend.op_count(torch.ops.aten.slice.Tensor) == 1
+    assert backend.op_count(torch.ops.aten.slice_scatter.default) == 1
diff --git a/vllm/compilation/noop_elimination.py b/vllm/compilation/noop_elimination.py
index 4888d4d1298e3..17e85e70218da 100644
--- a/vllm/compilation/noop_elimination.py
+++ b/vllm/compilation/noop_elimination.py
@@ -62,9 +62,6 @@ class NoOpEliminationPass(VllmInductorPass):
         scaled_mm: "f16[s0, 4096]" = ...
         at = auto_functionalized(fused_add_rms_norm, input = scaled_mm, ...)
         out: "f16[s0, 4096]" = at[1]
-
-    TODO(luka): This is currently tested in test_fusion,
-    but separate tests could be good.
""" def __call__(self, graph: torch.fx.Graph): @@ -96,17 +93,19 @@ class NoOpEliminationPass(VllmInductorPass): # Invalid reshape args, skip continue - if self.all_dims_equivalent(shape, input_shape): + if self.reshape_all_dims_equivalent(shape, input_shape): node.replace_all_uses_with(input) graph.erase_node(node) count += 1 elif is_func(node, torch.ops.aten.slice.Tensor): + # python slicing semantics are different from reshape + # Don't treat -1 as inferred dimension input, dim_index, start, end = node.args[:4] input_shape = input.meta["val"].shape - i_dim = input_shape[dim_index] + output_shape = node.meta["val"].shape - if start == 0 and self.dims_equivalent(end, i_dim): + if output_shape == input_shape: node.replace_all_uses_with(input) graph.erase_node(node) count += 1 @@ -116,14 +115,7 @@ class NoOpEliminationPass(VllmInductorPass): base_shape = base.meta["val"].shape view_shape = view.meta["val"].shape - view_dim = view_shape[dim_index] - - # Check that view fully covers base and the full view is used - # (if the view fully covered the base after slicing but was not - # fully used, we could replace slice_scatter with a simple slice - # but that's a niche case). - if (base_shape == view_shape and start == 0 - and self.dims_equivalent(end, view_dim)): + if base_shape == view_shape: node.replace_all_uses_with(view) graph.erase_node(node) count += 1 @@ -132,13 +124,9 @@ class NoOpEliminationPass(VllmInductorPass): self.dump_graph(graph, "after_noop_elimination") self.end_and_log() - def all_dims_equivalent(self, dims: Iterable[Union[int, torch.fx.Node]], - i_dims: Iterable[Union[int, SymInt]]): - return all( - self.dims_equivalent(s, i_s) for s, i_s in zip(dims, i_dims)) - - def dims_equivalent(self, dim: Union[int, torch.fx.Node], - i_dim: Union[int, SymInt]) -> bool: + # ---------------------- Reshape helpers ---------------------- + def reshape_dims_equivalent(self, dim: Union[int, torch.fx.Node], + i_dim: Union[int, SymInt]) -> bool: """ This function checks if two dimensions are equivalent. :param dim: The dimension arg to reshape/slice @@ -156,10 +144,18 @@ class NoOpEliminationPass(VllmInductorPass): In case 3, the reshape dimension is a torch.fx.Node, and its value is a SymInt. That value is equal to the input dimension. - """ # Case 1 and 2 if dim == i_dim or dim == -1: return True # Case 3 return isinstance(dim, torch.fx.Node) and dim.meta["val"] == i_dim + + def reshape_all_dims_equivalent( + self, + dims: Iterable[Union[int, torch.fx.Node]], + i_dims: Iterable[Union[int, SymInt]], + ) -> bool: + return all( + self.reshape_dims_equivalent(s, i_s) + for s, i_s in zip(dims, i_dims))