diff --git a/tests/compile/test_config.py b/tests/compile/test_config.py index bb66ef5529b1..1e8a882a7f3e 100644 --- a/tests/compile/test_config.py +++ b/tests/compile/test_config.py @@ -15,6 +15,9 @@ from vllm.engine.arg_utils import EngineArgs from vllm.platforms import current_platform from vllm.utils.torch_utils import _is_torch_equal_or_newer +# This import automatically registers `torch.ops.silly.attention` +from . import silly_attention # noqa: F401 + def test_version(): # Test the version comparison logic using the private function @@ -257,15 +260,6 @@ def test_should_split(): splitting_ops = ["aten::add.Tensor"] assert not should_split(node, splitting_ops) - @torch.library.custom_op( - "silly::attention", - mutates_args=["out"], - ) - def attention( - q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor - ) -> None: - out.copy_(q + k + v) - q, k, v, out = [torch.randn(1)] * 4 # supports custom ops as OpOverloadPacket