mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-30 03:27:04 +08:00
[Compile] Fix noop_elimination pass and add tests for noop_elimination (#24880)
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
This commit is contained in:
parent
45bfa49cb8
commit
5bcc153d7b
@ -394,6 +394,7 @@ steps:
|
|||||||
- pytest -v -s compile/test_async_tp.py
|
- pytest -v -s compile/test_async_tp.py
|
||||||
- pytest -v -s compile/test_fusion_all_reduce.py
|
- pytest -v -s compile/test_fusion_all_reduce.py
|
||||||
- pytest -v -s compile/test_decorator.py
|
- pytest -v -s compile/test_decorator.py
|
||||||
|
- pytest -v -s compile/test_noop_elimination.py
|
||||||
|
|
||||||
- label: PyTorch Fullgraph Smoke Test # 15min
|
- label: PyTorch Fullgraph Smoke Test # 15min
|
||||||
timeout_in_minutes: 30
|
timeout_in_minutes: 30
|
||||||
|
|||||||
@ -64,4 +64,8 @@ class TestBackend:
|
|||||||
num_pre = len(list(find_op_nodes(op, self.graph_pre_pass)))
|
num_pre = len(list(find_op_nodes(op, self.graph_pre_pass)))
|
||||||
num_post = len(list(find_op_nodes(op, self.graph_post_pass)))
|
num_post = len(list(find_op_nodes(op, self.graph_post_pass)))
|
||||||
assert num_pre == 0, f"Unexpected op {op.name()} in pre-pass graph"
|
assert num_pre == 0, f"Unexpected op {op.name()} in pre-pass graph"
|
||||||
assert num_post > 0, f"Op {op.name()} not found in post-pass graph"
|
assert num_post > 0, f"Op {op.name()} not found in post-pass graph"
|
||||||
|
|
||||||
|
def op_count(self, op: OpOverload, before=False) -> int:
|
||||||
|
graph = self.graph_pre_pass if before else self.graph_post_pass
|
||||||
|
return len(list(find_op_nodes(op, graph)))
|
||||||
|
|||||||
106
tests/compile/test_noop_elimination.py
Normal file
106
tests/compile/test_noop_elimination.py
Normal file
@ -0,0 +1,106 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import vllm
|
||||||
|
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||||
|
from vllm.config import (CompilationConfig, CompilationLevel, PassConfig,
|
||||||
|
VllmConfig)
|
||||||
|
|
||||||
|
from .backend import TestBackend
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("dtype",
|
||||||
|
[torch.float16, torch.bfloat16, torch.float32])
|
||||||
|
@pytest.mark.parametrize("num_tokens", [256, 1024])
|
||||||
|
@pytest.mark.parametrize("hidden_size", [64, 4096])
|
||||||
|
def test_noop_elimination(dtype, num_tokens, hidden_size):
|
||||||
|
torch.set_default_device("cuda")
|
||||||
|
torch.set_default_dtype(dtype)
|
||||||
|
torch.manual_seed(1)
|
||||||
|
|
||||||
|
class Model(torch.nn.Module):
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
# Chain of reshapes
|
||||||
|
y = x.reshape(-1, 128, 32)
|
||||||
|
z = y.reshape(-1, 4096)
|
||||||
|
# No-op reshape
|
||||||
|
a = z.reshape(-1, 4096)
|
||||||
|
# Final reshape that should remain
|
||||||
|
b = a.reshape(-1, 128, 32)
|
||||||
|
# No-op slice
|
||||||
|
c = b[0:b.shape[0]]
|
||||||
|
# The pass should replace the result of this op with `c`.
|
||||||
|
d = torch.slice_scatter(
|
||||||
|
torch.ones_like(c), # Dummy tensor to be scattered into
|
||||||
|
c, # Source tensor
|
||||||
|
0, # dim
|
||||||
|
0, # start
|
||||||
|
c.shape[0], # end
|
||||||
|
)
|
||||||
|
return d
|
||||||
|
|
||||||
|
vllm_config = VllmConfig(compilation_config=CompilationConfig(
|
||||||
|
level=CompilationLevel.PIECEWISE,
|
||||||
|
pass_config=PassConfig(enable_noop=True),
|
||||||
|
))
|
||||||
|
with vllm.config.set_current_vllm_config(vllm_config):
|
||||||
|
noop_pass = NoOpEliminationPass(vllm_config)
|
||||||
|
|
||||||
|
backend = TestBackend(noop_pass)
|
||||||
|
|
||||||
|
model = Model()
|
||||||
|
# First dimension dynamic
|
||||||
|
x = torch.rand(num_tokens, hidden_size)
|
||||||
|
torch._dynamo.mark_dynamic(x, 0)
|
||||||
|
|
||||||
|
result = model(x)
|
||||||
|
|
||||||
|
model2 = torch.compile(model, backend=backend)
|
||||||
|
result2 = model2(x)
|
||||||
|
|
||||||
|
ATOL, RTOL = (2e-3, 2e-3)
|
||||||
|
torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)
|
||||||
|
|
||||||
|
# The no-op reshape and slice should be eliminated.
|
||||||
|
# The chain of reshapes should be fused into a single reshape.
|
||||||
|
assert backend.op_count(torch.ops.aten.reshape.default) == 1
|
||||||
|
assert backend.op_count(torch.ops.aten.slice.Tensor) == 0
|
||||||
|
assert backend.op_count(torch.ops.aten.slice_scatter.default) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_non_noop_slice_preserved():
|
||||||
|
"""Ensure that a slice with end=-1 (dropping last row) is NOT eliminated.
|
||||||
|
|
||||||
|
Regression test for a bug where end=-1 was treated like an inferred
|
||||||
|
dimension (reshape semantics) leading to incorrect elimination.
|
||||||
|
"""
|
||||||
|
torch.set_default_device("cuda")
|
||||||
|
x = torch.randn(16, 16)
|
||||||
|
|
||||||
|
class SliceModel(torch.nn.Module):
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
base = x.clone()
|
||||||
|
src = torch.ones(15, 16)
|
||||||
|
y = torch.slice_scatter(base, src, dim=0, start=0, end=-1)
|
||||||
|
return x[0:-1, :], y
|
||||||
|
|
||||||
|
vllm_config = VllmConfig(compilation_config=CompilationConfig(
|
||||||
|
level=CompilationLevel.PIECEWISE,
|
||||||
|
pass_config=PassConfig(enable_noop=True),
|
||||||
|
))
|
||||||
|
with vllm.config.set_current_vllm_config(vllm_config):
|
||||||
|
noop_pass = NoOpEliminationPass(vllm_config)
|
||||||
|
backend = TestBackend(noop_pass)
|
||||||
|
model = SliceModel()
|
||||||
|
ref = model(x)
|
||||||
|
compiled = torch.compile(model, backend=backend)
|
||||||
|
out = compiled(x)
|
||||||
|
torch.testing.assert_close(ref, out)
|
||||||
|
# The slice should remain (not a no-op).
|
||||||
|
assert backend.op_count(torch.ops.aten.slice.Tensor) == 1
|
||||||
|
assert backend.op_count(torch.ops.aten.slice_scatter.default) == 1
|
||||||
@ -62,9 +62,6 @@ class NoOpEliminationPass(VllmInductorPass):
|
|||||||
scaled_mm: "f16[s0, 4096]" = ...
|
scaled_mm: "f16[s0, 4096]" = ...
|
||||||
at = auto_functionalized(fused_add_rms_norm, input = scaled_mm, ...)
|
at = auto_functionalized(fused_add_rms_norm, input = scaled_mm, ...)
|
||||||
out: "f16[s0, 4096]" = at[1]
|
out: "f16[s0, 4096]" = at[1]
|
||||||
|
|
||||||
TODO(luka): This is currently tested in test_fusion,
|
|
||||||
but separate tests could be good.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __call__(self, graph: torch.fx.Graph):
|
def __call__(self, graph: torch.fx.Graph):
|
||||||
@ -96,17 +93,19 @@ class NoOpEliminationPass(VllmInductorPass):
|
|||||||
# Invalid reshape args, skip
|
# Invalid reshape args, skip
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if self.all_dims_equivalent(shape, input_shape):
|
if self.reshape_all_dims_equivalent(shape, input_shape):
|
||||||
node.replace_all_uses_with(input)
|
node.replace_all_uses_with(input)
|
||||||
graph.erase_node(node)
|
graph.erase_node(node)
|
||||||
count += 1
|
count += 1
|
||||||
|
|
||||||
elif is_func(node, torch.ops.aten.slice.Tensor):
|
elif is_func(node, torch.ops.aten.slice.Tensor):
|
||||||
|
# python slicing semantics are different from reshape
|
||||||
|
# Don't treat -1 as inferred dimension
|
||||||
input, dim_index, start, end = node.args[:4]
|
input, dim_index, start, end = node.args[:4]
|
||||||
input_shape = input.meta["val"].shape
|
input_shape = input.meta["val"].shape
|
||||||
i_dim = input_shape[dim_index]
|
output_shape = node.meta["val"].shape
|
||||||
|
|
||||||
if start == 0 and self.dims_equivalent(end, i_dim):
|
if output_shape == input_shape:
|
||||||
node.replace_all_uses_with(input)
|
node.replace_all_uses_with(input)
|
||||||
graph.erase_node(node)
|
graph.erase_node(node)
|
||||||
count += 1
|
count += 1
|
||||||
@ -116,14 +115,7 @@ class NoOpEliminationPass(VllmInductorPass):
|
|||||||
base_shape = base.meta["val"].shape
|
base_shape = base.meta["val"].shape
|
||||||
view_shape = view.meta["val"].shape
|
view_shape = view.meta["val"].shape
|
||||||
|
|
||||||
view_dim = view_shape[dim_index]
|
if base_shape == view_shape:
|
||||||
|
|
||||||
# Check that view fully covers base and the full view is used
|
|
||||||
# (if the view fully covered the base after slicing but was not
|
|
||||||
# fully used, we could replace slice_scatter with a simple slice
|
|
||||||
# but that's a niche case).
|
|
||||||
if (base_shape == view_shape and start == 0
|
|
||||||
and self.dims_equivalent(end, view_dim)):
|
|
||||||
node.replace_all_uses_with(view)
|
node.replace_all_uses_with(view)
|
||||||
graph.erase_node(node)
|
graph.erase_node(node)
|
||||||
count += 1
|
count += 1
|
||||||
@ -132,13 +124,9 @@ class NoOpEliminationPass(VllmInductorPass):
|
|||||||
self.dump_graph(graph, "after_noop_elimination")
|
self.dump_graph(graph, "after_noop_elimination")
|
||||||
self.end_and_log()
|
self.end_and_log()
|
||||||
|
|
||||||
def all_dims_equivalent(self, dims: Iterable[Union[int, torch.fx.Node]],
|
# ---------------------- Reshape helpers ----------------------
|
||||||
i_dims: Iterable[Union[int, SymInt]]):
|
def reshape_dims_equivalent(self, dim: Union[int, torch.fx.Node],
|
||||||
return all(
|
i_dim: Union[int, SymInt]) -> bool:
|
||||||
self.dims_equivalent(s, i_s) for s, i_s in zip(dims, i_dims))
|
|
||||||
|
|
||||||
def dims_equivalent(self, dim: Union[int, torch.fx.Node],
|
|
||||||
i_dim: Union[int, SymInt]) -> bool:
|
|
||||||
"""
|
"""
|
||||||
This function checks if two dimensions are equivalent.
|
This function checks if two dimensions are equivalent.
|
||||||
:param dim: The dimension arg to reshape/slice
|
:param dim: The dimension arg to reshape/slice
|
||||||
@ -156,10 +144,18 @@ class NoOpEliminationPass(VllmInductorPass):
|
|||||||
In case 3, the reshape dimension is a torch.fx.Node,
|
In case 3, the reshape dimension is a torch.fx.Node,
|
||||||
and its value is a SymInt. That value is equal to the
|
and its value is a SymInt. That value is equal to the
|
||||||
input dimension.
|
input dimension.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
# Case 1 and 2
|
# Case 1 and 2
|
||||||
if dim == i_dim or dim == -1:
|
if dim == i_dim or dim == -1:
|
||||||
return True
|
return True
|
||||||
# Case 3
|
# Case 3
|
||||||
return isinstance(dim, torch.fx.Node) and dim.meta["val"] == i_dim
|
return isinstance(dim, torch.fx.Node) and dim.meta["val"] == i_dim
|
||||||
|
|
||||||
|
def reshape_all_dims_equivalent(
|
||||||
|
self,
|
||||||
|
dims: Iterable[Union[int, torch.fx.Node]],
|
||||||
|
i_dims: Iterable[Union[int, SymInt]],
|
||||||
|
) -> bool:
|
||||||
|
return all(
|
||||||
|
self.reshape_dims_equivalent(s, i_s)
|
||||||
|
for s, i_s in zip(dims, i_dims))
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user