[BugFix] Fix noop elimination edge case (#26394)
Signed-off-by: Andy Lo <andy@mistral.ai>
This commit is contained in:
parent 213b64452a
commit 67661375fa
@@ -12,15 +12,23 @@ from .backend import TestBackend
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
-@pytest.mark.parametrize("num_tokens", [256, 1024])
+# Important edge case is when `num_tokens == buffer_size`
+@pytest.mark.parametrize(
+    ("num_tokens", "buffer_size"), [(256, 256), (256, 512), (1024, 1024), (1024, 1025)]
+)
 @pytest.mark.parametrize("hidden_size", [64, 4096])
-def test_noop_elimination(dtype, num_tokens, hidden_size):
+def test_noop_elimination(dtype, num_tokens, hidden_size, buffer_size):
     torch.set_default_device("cuda")
     torch.set_default_dtype(dtype)
     torch.manual_seed(1)
 
     class Model(torch.nn.Module):
         def __init__(self) -> None:
             super().__init__()
+            self.pos_embed = torch.empty(buffer_size, hidden_size, dtype=dtype)
 
         def forward(self, x):
+            x += self.pos_embed[: x.shape[0]]
             # Chain of reshapes
             y = x.reshape(-1, 128, 32)
             z = y.reshape(-1, 4096)
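
For intuition on why `num_tokens == buffer_size` is the important edge case, here is a small standalone sketch (plain PyTorch, independent of the vLLM pass and the TestBackend harness; `slice_is_noop` is a hypothetical helper, not part of the test). The slice of the positional-embedding buffer only degenerates into a no-op when the tokens fill the buffer exactly, which is precisely what the pass must recognize without also removing genuine partial slices:

import torch

def slice_is_noop(num_tokens: int, buffer_size: int, hidden_size: int = 64) -> bool:
    # Mirrors the pattern in the test's Model.forward: pos_embed[: x.shape[0]]
    pos_embed = torch.empty(buffer_size, hidden_size)
    sliced = pos_embed[:num_tokens]
    # The slice is a no-op exactly when it keeps the full first dimension.
    return sliced.shape == pos_embed.shape

assert slice_is_noop(256, 256)      # num_tokens == buffer_size: eliminable
assert not slice_is_noop(256, 512)  # partial slice: must be preserved

The parametrization pairs above cover both situations, so the test exercises the eliminable and the non-eliminable slice.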
@@ -65,9 +73,10 @@ def test_noop_elimination(dtype, num_tokens, hidden_size):
     torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)
 
     # The no-op reshape and slice should be eliminated.
+    # The initial slice on the positional embedding should remain.
     # The chain of reshapes should be fused into a single reshape.
     assert backend.op_count(torch.ops.aten.reshape.default) == 1
-    assert backend.op_count(torch.ops.aten.slice.Tensor) == 0
+    assert backend.op_count(torch.ops.aten.slice.Tensor) == 1
     assert backend.op_count(torch.ops.aten.slice_scatter.default) == 0
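
As a rough sanity check of the updated assertions, this minimal snippet (plain PyTorch, outside the compilation pass) shows that the chained reshapes in the test model are equivalent to a single reshape, which is why exactly one aten.reshape is expected to survive, while the one surviving aten.slice corresponds to the partial positional-embedding slice added above:

import torch

x = torch.randn(256, 4096)
# Chain of reshapes as in the test model...
y = x.reshape(-1, 128, 32)
z = y.reshape(-1, 4096)
# ...is equivalent to a single reshape, so the pass can collapse the chain.
assert torch.equal(z, x.reshape(-1, 4096))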
@@ -81,49 +81,32 @@ class NoOpEliminationPass(VllmInductorPass):
                     graph.erase_node(input)
                     count += 1
 
-                # Case 2: remove this reshape if it produces the original shape
-                input, shape = node.args[:2]
-                input_shape = input.meta["val"].shape
-                if len(shape) != len(input_shape):
-                    # Reshape changing rank, skip
-                    continue
-
-                if shape.count(-1) > 1:
-                    # Invalid reshape args, skip
-                    continue
-
-                if self.reshape_all_dims_equivalent(shape, input_shape):
-                    node.replace_all_uses_with(input)
-                    graph.erase_node(node)
-                    count += 1
-
-            elif is_func(node, torch.ops.aten.slice.Tensor):
-                # python slicing semantics are different from reshape
-                # Don't treat -1 as inferred dimension
-                input, dim_index, start, end = node.args[:4]
+            # remove reshape/slice if it produces the original shape
+            if is_func(node, torch.ops.aten.reshape.default) or is_func(
+                node, torch.ops.aten.slice.Tensor
+            ):
+                input = node.args[0]
                 input_shape = input.meta["val"].shape
                 output_shape = node.meta["val"].shape
-                if output_shape == input_shape:
+                if self.all_dims_equivalent(input_shape, output_shape):
                     node.replace_all_uses_with(input)
                     graph.erase_node(node)
                     count += 1
 
             elif is_func(node, torch.ops.aten.slice_scatter.default):
                 base, view, dim_index, start, end = node.args[:5]
                 base_shape = base.meta["val"].shape
                 view_shape = view.meta["val"].shape
-                if base_shape == view_shape:
+                if self.all_dims_equivalent(base_shape, view_shape):
                     node.replace_all_uses_with(view)
                     graph.erase_node(node)
                     count += 1
 
         logger.debug("Removed %s no-op reshapes and slices", count)
 
-    # ---------------------- Reshape helpers ----------------------
-    def reshape_dims_equivalent(
-        self, dim: Union[int, torch.fx.Node], i_dim: Union[int, SymInt]
+    # ---------------------- Shape comparison helpers ----------------------
+    def dims_equivalent(
+        self, dim: Union[int, SymInt], i_dim: Union[int, SymInt]
     ) -> bool:
         """
         This function checks if two dimensions are equivalent.
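
A condensed sketch of the decision the rewritten branch makes, expressed over plain shapes rather than torch.fx nodes (the graph bookkeeping with node.meta, replace_all_uses_with and erase_node is omitted; `would_eliminate` is an illustrative stand-in for the pass's all_dims_equivalent check, not vLLM code):

import torch

def would_eliminate(input_shape: torch.Size, output_shape: torch.Size) -> bool:
    # A reshape/slice node is removable only when every output dimension
    # matches the corresponding input dimension.
    if len(input_shape) != len(output_shape):
        return False  # a rank change can never be a no-op
    return all(i == o for i, o in zip(input_shape, output_shape))

x = torch.randn(256, 4096)
assert would_eliminate(x.shape, x.reshape(256, 4096).shape)        # no-op reshape
assert not would_eliminate(x.shape, x.reshape(-1, 128, 32).shape)  # rank changed
assert not would_eliminate(x.shape, x[:128].shape)                 # real slice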
@@ -131,27 +114,24 @@ class NoOpEliminationPass(VllmInductorPass):
         :param i_dim: The corresponding dimension in the input tensor
         :return: Are the dimensions equivalent?
 
-        There are three cases in which the dimensions are equivalent:
+        There are two cases in which the dimensions are equivalent:
         1. The dimensions are equal (both integers)
-        2. The reshape dimension is -1 (i.e. inferred)
-        3. The dimensions both correspond to the same SymInt
-
-        While case 2 does not guarantee the dimensions are equal,
-        they are equal if all other dimensions are equal.
-
-        In case 3, the reshape dimension is a torch.fx.Node,
-        and its value is a SymInt. That value is equal to the
-        input dimension.
+        2. The dimensions both correspond to the same SymInt
         """
-        # Case 1 and 2
-        if dim == i_dim or dim == -1:
-            return True
-        # Case 3
-        return isinstance(dim, torch.fx.Node) and dim.meta["val"] == i_dim
+        # Case 1
+        if isinstance(i_dim, int) and isinstance(dim, int):
+            return dim == i_dim
+        # Case 2
+        if isinstance(i_dim, SymInt) and isinstance(dim, SymInt):
+            return dim == i_dim
+        return False
 
-    def reshape_all_dims_equivalent(
-        self,
-        dims: Iterable[Union[int, torch.fx.Node]],
-        i_dims: Iterable[Union[int, SymInt]],
+    def all_dims_equivalent(
+        self, dims: Iterable[Union[int, SymInt]], i_dims: Iterable[Union[int, SymInt]]
     ) -> bool:
-        return all(self.reshape_dims_equivalent(s, i_s) for s, i_s in zip(dims, i_dims))
+        dims_ = list(dims)
+        i_dims_ = list(i_dims)
+        if len(dims_) != len(i_dims_):
+            # Different ranks can't be equivalent
+            return False
+        return all(self.dims_equivalent(s, i_s) for s, i_s in zip(dims, i_dims))
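
For reference, a standalone approximation of the new helpers (the names mirror the diff, but this is a sketch assuming only stock torch.SymInt behaviour, not the vLLM class methods themselves), showing the type-aware comparison and the rank guard that replace the old treatment of -1 as an inferred dimension:

from typing import Iterable, Union

from torch import SymInt

def dims_equivalent(dim: Union[int, SymInt], i_dim: Union[int, SymInt]) -> bool:
    # Only compare like with like: concrete ints to ints, symbolic dims to symbolic dims.
    if isinstance(dim, int) and isinstance(i_dim, int):
        return dim == i_dim
    if isinstance(dim, SymInt) and isinstance(i_dim, SymInt):
        return dim == i_dim
    return False

def all_dims_equivalent(dims: Iterable[Union[int, SymInt]],
                        i_dims: Iterable[Union[int, SymInt]]) -> bool:
    dims_, i_dims_ = list(dims), list(i_dims)
    # Different ranks can never describe the same tensor.
    if len(dims_) != len(i_dims_):
        return False
    return all(dims_equivalent(d, i) for d, i in zip(dims_, i_dims_))

assert all_dims_equivalent((256, 4096), (256, 4096))
assert not all_dims_equivalent((256, 4096), (256, 128, 32))
assert not all_dims_equivalent((-1, 4096), (256, 4096))  # -1 no longer counts as inferred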