From 67661375fad8dbea1d9f3b812c55ec42758cf7d2 Mon Sep 17 00:00:00 2001
From: Andy Lo
Date: Fri, 10 Oct 2025 14:33:04 +0100
Subject: [PATCH] [BugFix] Fix noop elimination edge case (#26394)

Signed-off-by: Andy Lo
---
 tests/compile/test_noop_elimination.py | 15 ++++--
 vllm/compilation/noop_elimination.py   | 74 ++++++++++----------------
 2 files changed, 39 insertions(+), 50 deletions(-)

diff --git a/tests/compile/test_noop_elimination.py b/tests/compile/test_noop_elimination.py
index fda7f4e3bafa5..188f4514dda5f 100644
--- a/tests/compile/test_noop_elimination.py
+++ b/tests/compile/test_noop_elimination.py
@@ -12,15 +12,23 @@ from .backend import TestBackend
 
 
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
-@pytest.mark.parametrize("num_tokens", [256, 1024])
+# Important edge case is when `num_tokens == buffer_size`
+@pytest.mark.parametrize(
+    ("num_tokens", "buffer_size"), [(256, 256), (256, 512), (1024, 1024), (1024, 1025)]
+)
 @pytest.mark.parametrize("hidden_size", [64, 4096])
-def test_noop_elimination(dtype, num_tokens, hidden_size):
+def test_noop_elimination(dtype, num_tokens, hidden_size, buffer_size):
     torch.set_default_device("cuda")
     torch.set_default_dtype(dtype)
     torch.manual_seed(1)
 
     class Model(torch.nn.Module):
+        def __init__(self) -> None:
+            super().__init__()
+            self.pos_embed = torch.empty(buffer_size, hidden_size, dtype=dtype)
+
         def forward(self, x):
+            x += self.pos_embed[: x.shape[0]]
             # Chain of reshapes
             y = x.reshape(-1, 128, 32)
             z = y.reshape(-1, 4096)
@@ -65,9 +73,10 @@ def test_noop_elimination(dtype, num_tokens, hidden_size):
     torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)
 
     # The no-op reshape and slice should be eliminated.
+    # The initial slice on the positional embedding should remain.
     # The chain of reshapes should be fused into a single reshape.
     assert backend.op_count(torch.ops.aten.reshape.default) == 1
-    assert backend.op_count(torch.ops.aten.slice.Tensor) == 0
+    assert backend.op_count(torch.ops.aten.slice.Tensor) == 1
     assert backend.op_count(torch.ops.aten.slice_scatter.default) == 0

diff --git a/vllm/compilation/noop_elimination.py b/vllm/compilation/noop_elimination.py
index 3d807ab3a6de7..45668c7af3151 100644
--- a/vllm/compilation/noop_elimination.py
+++ b/vllm/compilation/noop_elimination.py
@@ -81,49 +81,32 @@ class NoOpEliminationPass(VllmInductorPass):
                     graph.erase_node(input)
                     count += 1
 
-                # Case 2: remove this reshape if it produces the original shape
-                input, shape = node.args[:2]
-                input_shape = input.meta["val"].shape
-                if len(shape) != len(input_shape):
-                    # Reshape changing rank, skip
-                    continue
-
-                if shape.count(-1) > 1:
-                    # Invalid reshape args, skip
-                    continue
-
-                if self.reshape_all_dims_equivalent(shape, input_shape):
-                    node.replace_all_uses_with(input)
-                    graph.erase_node(node)
-                    count += 1
-
-            elif is_func(node, torch.ops.aten.slice.Tensor):
-                # python slicing semantics are different from reshape
-                # Don't treat -1 as inferred dimension
-                input, dim_index, start, end = node.args[:4]
+            # remove reshape/slice if it produces the original shape
+            if is_func(node, torch.ops.aten.reshape.default) or is_func(
+                node, torch.ops.aten.slice.Tensor
+            ):
+                input = node.args[0]
                 input_shape = input.meta["val"].shape
                 output_shape = node.meta["val"].shape
-
-                if output_shape == input_shape:
+                if self.all_dims_equivalent(input_shape, output_shape):
                     node.replace_all_uses_with(input)
                     graph.erase_node(node)
                     count += 1
-
             elif is_func(node, torch.ops.aten.slice_scatter.default):
                 base, view, dim_index, start, end = node.args[:5]
                 base_shape = base.meta["val"].shape
                 view_shape = view.meta["val"].shape
-                if base_shape == view_shape:
+                if self.all_dims_equivalent(base_shape, view_shape):
                     node.replace_all_uses_with(view)
                     graph.erase_node(node)
                     count += 1
 
         logger.debug("Removed %s no-op reshapes and slices", count)
 
-    # ---------------------- Reshape helpers ----------------------
-    def reshape_dims_equivalent(
-        self, dim: Union[int, torch.fx.Node], i_dim: Union[int, SymInt]
+    # ---------------------- Shape comparison helpers ----------------------
+    def dims_equivalent(
+        self, dim: Union[int, SymInt], i_dim: Union[int, SymInt]
     ) -> bool:
         """
         This function checks if two dimensions are equivalent.
@@ -131,27 +114,24 @@ class NoOpEliminationPass(VllmInductorPass):
         :param i_dim: The corresponding dimension in the input tensor
         :return: Are the dimensions equivalent?
 
-        There are three cases in which the dimensions are equivalent:
+        There are two cases in which the dimensions are equivalent:
         1. The dimensions are equal (both integers)
-        2. The reshape dimension is -1 (i.e. inferred)
-        3. The dimensions both correspond to the same SymInt
-
-        While case 2 does not guarantee the dimensions are equal,
-        they are equal if all other dimensions are equal.
-
-        In case 3, the reshape dimension is a torch.fx.Node,
-        and its value is a SymInt. That value is equal to the
-        input dimension.
+        2. The dimensions both correspond to the same SymInt
         """
-        # Case 1 and 2
-        if dim == i_dim or dim == -1:
-            return True
-        # Case 3
-        return isinstance(dim, torch.fx.Node) and dim.meta["val"] == i_dim
+        # Case 1
+        if isinstance(i_dim, int) and isinstance(dim, int):
+            return dim == i_dim
+        # Case 2
+        if isinstance(i_dim, SymInt) and isinstance(dim, SymInt):
+            return dim == i_dim
+        return False
 
-    def reshape_all_dims_equivalent(
-        self,
-        dims: Iterable[Union[int, torch.fx.Node]],
-        i_dims: Iterable[Union[int, SymInt]],
+    def all_dims_equivalent(
+        self, dims: Iterable[Union[int, SymInt]], i_dims: Iterable[Union[int, SymInt]]
     ) -> bool:
-        return all(self.reshape_dims_equivalent(s, i_s) for s, i_s in zip(dims, i_dims))
+        dims_ = list(dims)
+        i_dims_ = list(i_dims)
+        if len(dims_) != len(i_dims_):
+            # Different ranks can't be equivalent
+            return False
+        return all(self.dims_equivalent(s, i_s) for s, i_s in zip(dims_, i_dims_))
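
Note for reviewers: the sketch below is illustrative context only, not part of the patch. It is a minimal, self-contained Python example (stock PyTorch, no vLLM imports; the Model, buffer_size, and num_tokens names are assumptions that mirror the updated test) showing the shape pattern the new parametrization exercises: a fixed-size positional-embedding buffer sliced by a dynamic token count. That slice is only a no-op when the two sizes happen to coincide, so a pass that compares concrete traced sizes rather than symbolic dimensions could unsoundly remove it.

# edge_case_sketch.py -- illustrative only, not part of this patch.
import torch


class Model(torch.nn.Module):
    """Mirrors the shape pattern from the updated test."""

    def __init__(self, buffer_size: int, hidden_size: int) -> None:
        super().__init__()
        # Fixed-size buffer, analogous to self.pos_embed in the test.
        self.register_buffer("pos_embed", torch.randn(buffer_size, hidden_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # This slice is a no-op only when x.shape[0] == buffer_size.
        # Under dynamic shapes, x.shape[0] is a SymInt while buffer_size is a
        # plain int, so the two dimensions must not be treated as equivalent.
        return x + self.pos_embed[: x.shape[0]]


if __name__ == "__main__":
    buffer_size, hidden_size = 256, 64
    model = Model(buffer_size, hidden_size)
    compiled = torch.compile(model, dynamic=True)

    # Compile with the edge-case input (num_tokens == buffer_size), then reuse
    # the compiled model with a smaller token count. If the slice had been
    # eliminated based on the first input's concrete sizes, the second call
    # would produce incorrect results.
    _ = compiled(torch.randn(buffer_size, hidden_size))
    out = compiled(torch.randn(100, hidden_size))
    assert out.shape == (100, hidden_size)

This is also why the updated test expects exactly one remaining aten.slice op: the pos_embed slice must survive even when num_tokens equals buffer_size.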