[BugFix] Fix noop elimination edge case (#26394)

Signed-off-by: Andy Lo <andy@mistral.ai>
Andy Lo 2025-10-10 14:33:04 +01:00 committed by GitHub
parent 213b64452a
commit 67661375fa
2 changed files with 39 additions and 50 deletions
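
Note on the edge case being fixed: the no-op elimination pass drops reshape/slice nodes whose output shape already matches their input shape. Under dynamic shapes, a slice such as `pos_embed[: x.shape[0]]` can have a traced output shape that compares equal to the full buffer shape whenever the example batch happens to fill the buffer (`num_tokens == buffer_size`), even though the slice is still required for smaller runtime batches. A minimal sketch of that pattern follows; the module and parameter names are illustrative, not taken from the repository.

import torch


class PosEmbedModel(torch.nn.Module):
    """Hypothetical module reproducing the pattern this fix targets."""

    def __init__(self, buffer_size: int = 256, hidden_size: int = 64) -> None:
        super().__init__()
        self.pos_embed = torch.empty(buffer_size, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # When x.shape[0] == buffer_size, the sliced view has the same shape as
        # the full buffer at trace time, yet the slice must survive compilation
        # because x.shape[0] is dynamic and may be smaller on later calls.
        return x + self.pos_embed[: x.shape[0]]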


@@ -12,15 +12,23 @@ from .backend import TestBackend
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
@pytest.mark.parametrize("num_tokens", [256, 1024])
# Important edge case is when `num_tokens == buffer_size`
@pytest.mark.parametrize(
("num_tokens", "buffer_size"), [(256, 256), (256, 512), (1024, 1024), (1024, 1025)]
)
@pytest.mark.parametrize("hidden_size", [64, 4096])
def test_noop_elimination(dtype, num_tokens, hidden_size):
def test_noop_elimination(dtype, num_tokens, hidden_size, buffer_size):
torch.set_default_device("cuda")
torch.set_default_dtype(dtype)
torch.manual_seed(1)
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.pos_embed = torch.empty(buffer_size, hidden_size, dtype=dtype)
def forward(self, x):
x += self.pos_embed[: x.shape[0]]
# Chain of reshapes
y = x.reshape(-1, 128, 32)
z = y.reshape(-1, 4096)
@@ -65,9 +73,10 @@ def test_noop_elimination(dtype, num_tokens, hidden_size):
torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)
# The no-op reshape and slice should be eliminated.
# The initial slice on the positional embedding should remain.
# The chain of reshapes should be fused into a single reshape.
assert backend.op_count(torch.ops.aten.reshape.default) == 1
assert backend.op_count(torch.ops.aten.slice.Tensor) == 0
assert backend.op_count(torch.ops.aten.slice.Tensor) == 1
assert backend.op_count(torch.ops.aten.slice_scatter.default) == 0
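
The assertions above count ops in the post-pass FX graph: the positional-embedding slice must survive (count 1), the chain of reshapes must collapse to a single reshape, and no slice_scatter may remain. A rough sketch of how such counts can be read off an fx.GraphModule is shown below; the real `TestBackend.op_count` used by the test may be implemented differently.

from collections import Counter

import torch


def count_call_function_ops(gm: torch.fx.GraphModule) -> Counter:
    # Tally call_function nodes by their target overload, e.g.
    # counts[torch.ops.aten.slice.Tensor] or counts[torch.ops.aten.reshape.default].
    return Counter(
        node.target for node in gm.graph.nodes if node.op == "call_function"
    )


# Expected for the compiled test model after this fix:
#   counts[torch.ops.aten.reshape.default] == 1   (reshape chain fused)
#   counts[torch.ops.aten.slice.Tensor] == 1      (pos_embed slice kept)
#   counts[torch.ops.aten.slice_scatter.default] == 0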


@@ -81,49 +81,32 @@ class NoOpEliminationPass(VllmInductorPass):
graph.erase_node(input)
count += 1
# Case 2: remove this reshape if it produces the original shape
input, shape = node.args[:2]
input_shape = input.meta["val"].shape
if len(shape) != len(input_shape):
# Reshape changing rank, skip
continue
if shape.count(-1) > 1:
# Invalid reshape args, skip
continue
if self.reshape_all_dims_equivalent(shape, input_shape):
node.replace_all_uses_with(input)
graph.erase_node(node)
count += 1
elif is_func(node, torch.ops.aten.slice.Tensor):
# python slicing semantics are different from reshape
# Don't treat -1 as inferred dimension
input, dim_index, start, end = node.args[:4]
# remove reshape/slice if it produces the original shape
if is_func(node, torch.ops.aten.reshape.default) or is_func(
node, torch.ops.aten.slice.Tensor
):
input = node.args[0]
input_shape = input.meta["val"].shape
output_shape = node.meta["val"].shape
if output_shape == input_shape:
if self.all_dims_equivalent(input_shape, output_shape):
node.replace_all_uses_with(input)
graph.erase_node(node)
count += 1
elif is_func(node, torch.ops.aten.slice_scatter.default):
base, view, dim_index, start, end = node.args[:5]
base_shape = base.meta["val"].shape
view_shape = view.meta["val"].shape
if base_shape == view_shape:
if self.all_dims_equivalent(base_shape, view_shape):
node.replace_all_uses_with(view)
graph.erase_node(node)
count += 1
logger.debug("Removed %s no-op reshapes and slices", count)
# ---------------------- Reshape helpers ----------------------
def reshape_dims_equivalent(
self, dim: Union[int, torch.fx.Node], i_dim: Union[int, SymInt]
# ---------------------- Shape comparison helpers ----------------------
def dims_equivalent(
self, dim: Union[int, SymInt], i_dim: Union[int, SymInt]
) -> bool:
"""
This function checks if two dimensions are equivalent.
@@ -131,27 +114,24 @@ class NoOpEliminationPass(VllmInductorPass):
:param i_dim: The corresponding dimension in the input tensor
:return: Are the dimensions equivalent?
There are three cases in which the dimensions are equivalent:
There are two cases in which the dimensions are equivalent:
1. The dimensions are equal (both integers)
2. The reshape dimension is -1 (i.e. inferred)
3. The dimensions both correspond to the same SymInt
While case 2 does not guarantee the dimensions are equal,
they are equal if all other dimensions are equal.
In case 3, the reshape dimension is a torch.fx.Node,
and its value is a SymInt. That value is equal to the
input dimension.
2. The dimensions both correspond to the same SymInt
"""
# Case 1 and 2
if dim == i_dim or dim == -1:
return True
# Case 3
return isinstance(dim, torch.fx.Node) and dim.meta["val"] == i_dim
# Case 1
if isinstance(i_dim, int) and isinstance(dim, int):
return dim == i_dim
# Case 2
if isinstance(i_dim, SymInt) and isinstance(dim, SymInt):
return dim == i_dim
return False
def reshape_all_dims_equivalent(
self,
dims: Iterable[Union[int, torch.fx.Node]],
i_dims: Iterable[Union[int, SymInt]],
def all_dims_equivalent(
self, dims: Iterable[Union[int, SymInt]], i_dims: Iterable[Union[int, SymInt]]
) -> bool:
return all(self.reshape_dims_equivalent(s, i_s) for s, i_s in zip(dims, i_dims))
dims_ = list(dims)
i_dims_ = list(i_dims)
if len(dims_) != len(i_dims_):
# Different ranks can't be equivalent
return False
return all(self.dims_equivalent(s, i_s) for s, i_s in zip(dims_, i_dims_))
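
The heart of the fix is the type-aware dimension comparison above: a concrete int and a SymInt that merely happen to take the same value for the traced example input are no longer treated as equivalent, so a slice whose symbolic length coincides with the static buffer length is kept. A standalone sketch of the rule, assuming shapes are given as sequences of int or torch.SymInt (the helper names here are hypothetical):

from typing import Iterable, Union

import torch


def dims_match(dim: Union[int, torch.SymInt], other: Union[int, torch.SymInt]) -> bool:
    # Case 1: two concrete ints must be equal.
    if isinstance(dim, int) and isinstance(other, int):
        return dim == other
    # Case 2: two symbolic ints must be the same symbol (this comparison may
    # install a guard, exactly as in the pass above).
    if isinstance(dim, torch.SymInt) and isinstance(other, torch.SymInt):
        return dim == other
    # A SymInt paired with an int is never a no-op match, even if the traced
    # example input makes them equal (the num_tokens == buffer_size case).
    return False


def shapes_match(
    a: Iterable[Union[int, torch.SymInt]], b: Iterable[Union[int, torch.SymInt]]
) -> bool:
    a_, b_ = list(a), list(b)
    return len(a_) == len(b_) and all(dims_match(x, y) for x, y in zip(a_, b_))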