diff --git a/tests/compile/test_decorator.py b/tests/compile/test_decorator.py index 51f8ddd566d56..e6f6a6733017b 100644 --- a/tests/compile/test_decorator.py +++ b/tests/compile/test_decorator.py @@ -2,44 +2,20 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import torch from torch import nn -from torch.library import Library +# This import automatically registers torch ops for testing (like silly.attention) +import tests.compile.testing_ops from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import (ignore_torch_compile, support_torch_compile) from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel, CUDAGraphMode, VllmConfig, set_current_vllm_config) from vllm.forward_context import BatchDescriptor, set_forward_context -from vllm.utils import direct_register_custom_op - -# create a library to hold the custom op -silly_lib = Library("silly", "FRAGMENT") # noqa BATCH_SIZE = 32 MLP_SIZE = 128 -def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - out: torch.Tensor) -> None: - out.copy_(q) - out += k - out += v - - -def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - out: torch.Tensor) -> None: - return - - -direct_register_custom_op( - op_name="attention", - op_func=silly_attention, - mutates_args=["out"], - fake_impl=silly_attention_fake, - target_lib=silly_lib, -) - - @torch.inference_mode def run_model(vllm_config: VllmConfig, model: nn.Module, cudagraph_runtime_mode: CUDAGraphMode):