From 2c81fbbb3ce137363829b0109a23cba2b2a16df2 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 20 Aug 2025 14:12:37 +0000 Subject: [PATCH] Refactor duplicate torch operation registrations to use shared module Instead of changing library names (not scalable), create a shared test_operations.py module that: - Provides a single "silly" library for all compilation tests - Registers a unified attention operation that can handle both standard and counting modes - Eliminates duplicate registration errors when running all tests together - Maintains backward compatibility with existing test behavior Addresses feedback to make the solution more scalable and maintainable. Co-authored-by: ProExpertProg <11367180+ProExpertProg@users.noreply.github.com> --- .../compile/piecewise/test_multiple_graphs.py | 34 ++----- tests/compile/piecewise/test_simple.py | 45 +++------- tests/compile/piecewise/test_toy_llama.py | 32 ++----- tests/compile/test_operations.py | 88 +++++++++++++++++++ 4 files changed, 110 insertions(+), 89 deletions(-) create mode 100644 tests/compile/test_operations.py diff --git a/tests/compile/piecewise/test_multiple_graphs.py b/tests/compile/piecewise/test_multiple_graphs.py index a5e18a2c0dcec..d70dc8811db8d 100644 --- a/tests/compile/piecewise/test_multiple_graphs.py +++ b/tests/compile/piecewise/test_multiple_graphs.py @@ -6,7 +6,6 @@ are compiled and graph captured separately. """ import torch from torch import nn -from torch.library import Library from vllm.compilation.backends import set_model_tag from vllm.compilation.counter import compilation_counter @@ -16,10 +15,10 @@ from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig, set_current_vllm_config) from vllm.envs import VLLM_USE_V1 from vllm.forward_context import set_forward_context -from vllm.utils import direct_register_custom_op -# create a library to hold the custom op -silly_lib = Library("silly_multiple", "FRAGMENT") # noqa +# Import shared test operations +# The standard attention operation is automatically registered when imported +import tests.compile.test_operations BATCH_SIZE = 32 MLP_SIZE = 128 @@ -27,27 +26,6 @@ HIDDEN_SIZE = 1024 RANDOM_SEED = 0 -def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - out: torch.Tensor) -> None: - out.copy_(q) - out += k - out += v - - -def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - out: torch.Tensor) -> None: - return - - -direct_register_custom_op( - op_name="attention", - op_func=silly_attention, - mutates_args=["out"], - fake_impl=silly_attention_fake, - target_lib=silly_lib, -) - - @support_torch_compile class ParentModel(nn.Module): @@ -90,7 +68,7 @@ class Attention(nn.Module): x = self.pre_attn(x) x = self.rms_norm_ref(x) attn_output = torch.empty_like(x) - torch.ops.silly_multiple.attention(x, x, x, attn_output) + torch.ops.silly.attention(x, x, x, attn_output) x = attn_output x = self.rms_norm_ref(x) x = self.post_attn(x) @@ -188,7 +166,7 @@ def test_ignore_torch_compile_decorator(): def forward(self, x: torch.Tensor) -> torch.Tensor: x = x + x attn_output = torch.empty_like(x) - torch.ops.silly_multiple.attention(x, x, x, attn_output) + torch.ops.silly.attention(x, x, x, attn_output) x = attn_output x = x * 3 return x @@ -347,4 +325,4 @@ def test_multi_graph_piecewise_compile_outputs_equal(): assert torch.allclose(outputs[0], outputs[1]) # Expect bitwise equivalence using inductor w/ and w/o cudagraph - assert torch.equal(outputs[0], outputs[2]) + assert torch.equal(outputs[0], outputs[2]) \ No newline at end of file diff --git a/tests/compile/piecewise/test_simple.py b/tests/compile/piecewise/test_simple.py index 2920d2cdd7ae7..f8717d7245547 100644 --- a/tests/compile/piecewise/test_simple.py +++ b/tests/compile/piecewise/test_simple.py @@ -7,7 +7,6 @@ can exactly calculate the expected output and side effects. import pytest import torch from torch import nn -from torch.library import Library from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import support_torch_compile @@ -15,36 +14,15 @@ from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode, VllmConfig, set_current_vllm_config) from vllm.envs import VLLM_USE_V1 from vllm.forward_context import BatchDescriptor, set_forward_context -from vllm.utils import direct_register_custom_op -global_counter = 0 - -# create a library to hold the custom op -silly_lib = Library("silly_simple", "FRAGMENT") # noqa - - -def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - out: torch.Tensor) -> None: - global global_counter - global_counter += 1 - print(f"{global_counter=}") - out.copy_(q) - out[0] += 1 - - -def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - out: torch.Tensor) -> None: - return - - -direct_register_custom_op( - op_name="attention", - op_func=silly_attention, - mutates_args=["out"], - fake_impl=silly_attention_fake, - target_lib=silly_lib, +# Import shared test operations +from tests.compile.test_operations import ( + get_global_counter, reset_global_counter, enable_counting_mode ) +# Enable counting mode for this test's specific behavior +enable_counting_mode() + @support_torch_compile class SillyModel(nn.Module): @@ -66,12 +44,12 @@ class SillyModel(nn.Module): x = x + 1 x = x + 2 out = torch.empty_like(x) - torch.ops.silly_simple.attention(x, x, x, out) + torch.ops.silly.attention(x, x, x, out) x = out x = x - 2 x = x - 1 out = torch.empty_like(x) - torch.ops.silly_simple.attention(x, x, x, out) + torch.ops.silly.attention(x, x, x, out) x = out x = x + 1 return x @@ -121,13 +99,12 @@ def test_simple_piecewise_compile(use_inductor): model(torch.randn(1).cuda()) input = torch.zeros(2).cuda() - global global_counter - global_counter = 0 + reset_global_counter() with set_forward_context( None, vllm_config=vllm_config, cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE, batch_descriptor=BatchDescriptor(num_tokens=2, )): output = model(input) - assert global_counter == 2 - assert torch.allclose(output.cpu(), torch.tensor([3., 1.])) + assert get_global_counter() == 2 + assert torch.allclose(output.cpu(), torch.tensor([3., 1.])) \ No newline at end of file diff --git a/tests/compile/piecewise/test_toy_llama.py b/tests/compile/piecewise/test_toy_llama.py index 0e7ab819c65e8..b2889f03a8061 100644 --- a/tests/compile/piecewise/test_toy_llama.py +++ b/tests/compile/piecewise/test_toy_llama.py @@ -14,38 +14,16 @@ from typing import Any, Optional import pytest import torch from torch import nn -from torch.library import Library from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import support_torch_compile from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode, VllmConfig, set_current_vllm_config) from vllm.forward_context import BatchDescriptor, set_forward_context -from vllm.utils import direct_register_custom_op -# create a library to hold the custom op -silly_lib = Library("silly_toy_llama", "FRAGMENT") # noqa - - -def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - out: torch.Tensor) -> None: - out.copy_(q) - out += k - out += v - - -def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - out: torch.Tensor) -> None: - return - - -direct_register_custom_op( - op_name="attention", - op_func=silly_attention, - mutates_args=["out"], - fake_impl=silly_attention_fake, - target_lib=silly_lib, -) +# Import shared test operations +# The standard attention operation is automatically registered when imported +import tests.compile.test_operations @dataclass @@ -160,7 +138,7 @@ class LlamaAttention(nn.Module): k = k + positions.unsqueeze(1) attn_output = torch.empty_like(q) - torch.ops.silly_toy_llama.attention(q, k, v, attn_output) + torch.ops.silly.attention(q, k, v, attn_output) output = self.output_projection(attn_output) return output @@ -492,4 +470,4 @@ def benchmark(): if __name__ == "__main__": - benchmark() + benchmark() \ No newline at end of file diff --git a/tests/compile/test_operations.py b/tests/compile/test_operations.py new file mode 100644 index 0000000000000..59561d6fa21c2 --- /dev/null +++ b/tests/compile/test_operations.py @@ -0,0 +1,88 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Shared PyTorch custom operations for compilation tests. + +This module provides a centralized place to define and register custom +PyTorch operations used across multiple compilation tests. This avoids +duplicate operation registrations that would cause RuntimeErrors when +running tests together. + +The main "attention" operation is automatically registered when this module +is imported. Individual test files can access additional functionality +through helper functions. +""" + +import torch +from torch.library import Library + +from vllm.utils import direct_register_custom_op + +# Shared library for all compilation test operations +# Using "silly" namespace to match existing test expectations +silly_lib = Library("silly", "FRAGMENT") + + +# Global state for test_simple.py compatibility +_global_counter = 0 +_use_counting_mode = False + + +def get_global_counter(): + """Get the current global counter value (for test_simple.py)""" + return _global_counter + + +def reset_global_counter(): + """Reset the global counter to 0 (for test_simple.py)""" + global _global_counter + _global_counter = 0 + + +def enable_counting_mode(): + """Enable counting mode for test_simple.py""" + global _use_counting_mode + _use_counting_mode = True + reset_global_counter() + + +def disable_counting_mode(): + """Disable counting mode""" + global _use_counting_mode + _use_counting_mode = False + + +def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + out: torch.Tensor) -> None: + """ + Unified attention implementation that can handle both standard and counting modes. + """ + global _global_counter, _use_counting_mode + + if _use_counting_mode: + # Counting mode for test_simple.py + _global_counter += 1 + print(f"global_counter={_global_counter}") + out.copy_(q) + out[0] += 1 + else: + # Standard mode for test_multiple_graphs.py and test_toy_llama.py + out.copy_(q) + out += k + out += v + + +def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + out: torch.Tensor) -> None: + """Fake implementation for testing""" + return + + +# Register the unified attention operation +direct_register_custom_op( + op_name="attention", + op_func=silly_attention, + mutates_args=["out"], + fake_impl=silly_attention_fake, + target_lib=silly_lib, +) \ No newline at end of file