Refactor duplicate torch operation registrations to use shared module

Instead of changing library names (not scalable), create a shared test_operations.py module that:
- Provides a single "silly" library for all compilation tests
- Registers a unified attention operation that can handle both standard and counting modes
- Eliminates duplicate registration errors when running all tests together
- Maintains backward compatibility with existing test behavior

Addresses feedback to make the solution more scalable and maintainable.

Co-authored-by: ProExpertProg <11367180+ProExpertProg@users.noreply.github.com>
This commit is contained in:
copilot-swe-agent[bot] 2025-08-20 14:12:37 +00:00
parent 47dcf0940f
commit 2c81fbbb3c
4 changed files with 110 additions and 89 deletions

View File

@ -6,7 +6,6 @@ are compiled and graph captured separately.
"""
import torch
from torch import nn
from torch.library import Library
from vllm.compilation.backends import set_model_tag
from vllm.compilation.counter import compilation_counter
@ -16,10 +15,10 @@ from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
set_current_vllm_config)
from vllm.envs import VLLM_USE_V1
from vllm.forward_context import set_forward_context
from vllm.utils import direct_register_custom_op
# create a library to hold the custom op
silly_lib = Library("silly_multiple", "FRAGMENT") # noqa
# Import shared test operations
# The standard attention operation is automatically registered when imported
import tests.compile.test_operations
BATCH_SIZE = 32
MLP_SIZE = 128
@ -27,27 +26,6 @@ HIDDEN_SIZE = 1024
RANDOM_SEED = 0
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
out: torch.Tensor) -> None:
out.copy_(q)
out += k
out += v
def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
out: torch.Tensor) -> None:
return
direct_register_custom_op(
op_name="attention",
op_func=silly_attention,
mutates_args=["out"],
fake_impl=silly_attention_fake,
target_lib=silly_lib,
)
@support_torch_compile
class ParentModel(nn.Module):
@ -90,7 +68,7 @@ class Attention(nn.Module):
x = self.pre_attn(x)
x = self.rms_norm_ref(x)
attn_output = torch.empty_like(x)
torch.ops.silly_multiple.attention(x, x, x, attn_output)
torch.ops.silly.attention(x, x, x, attn_output)
x = attn_output
x = self.rms_norm_ref(x)
x = self.post_attn(x)
@ -188,7 +166,7 @@ def test_ignore_torch_compile_decorator():
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x + x
attn_output = torch.empty_like(x)
torch.ops.silly_multiple.attention(x, x, x, attn_output)
torch.ops.silly.attention(x, x, x, attn_output)
x = attn_output
x = x * 3
return x
@ -347,4 +325,4 @@ def test_multi_graph_piecewise_compile_outputs_equal():
assert torch.allclose(outputs[0], outputs[1])
# Expect bitwise equivalence using inductor w/ and w/o cudagraph
assert torch.equal(outputs[0], outputs[2])
assert torch.equal(outputs[0], outputs[2])

View File

@ -7,7 +7,6 @@ can exactly calculate the expected output and side effects.
import pytest
import torch
from torch import nn
from torch.library import Library
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
@ -15,36 +14,15 @@ from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
VllmConfig, set_current_vllm_config)
from vllm.envs import VLLM_USE_V1
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.utils import direct_register_custom_op
global_counter = 0
# create a library to hold the custom op
silly_lib = Library("silly_simple", "FRAGMENT") # noqa
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
out: torch.Tensor) -> None:
global global_counter
global_counter += 1
print(f"{global_counter=}")
out.copy_(q)
out[0] += 1
def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
out: torch.Tensor) -> None:
return
direct_register_custom_op(
op_name="attention",
op_func=silly_attention,
mutates_args=["out"],
fake_impl=silly_attention_fake,
target_lib=silly_lib,
# Import shared test operations
from tests.compile.test_operations import (
get_global_counter, reset_global_counter, enable_counting_mode
)
# Enable counting mode for this test's specific behavior
enable_counting_mode()
@support_torch_compile
class SillyModel(nn.Module):
@ -66,12 +44,12 @@ class SillyModel(nn.Module):
x = x + 1
x = x + 2
out = torch.empty_like(x)
torch.ops.silly_simple.attention(x, x, x, out)
torch.ops.silly.attention(x, x, x, out)
x = out
x = x - 2
x = x - 1
out = torch.empty_like(x)
torch.ops.silly_simple.attention(x, x, x, out)
torch.ops.silly.attention(x, x, x, out)
x = out
x = x + 1
return x
@ -121,13 +99,12 @@ def test_simple_piecewise_compile(use_inductor):
model(torch.randn(1).cuda())
input = torch.zeros(2).cuda()
global global_counter
global_counter = 0
reset_global_counter()
with set_forward_context(
None,
vllm_config=vllm_config,
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
batch_descriptor=BatchDescriptor(num_tokens=2, )):
output = model(input)
assert global_counter == 2
assert torch.allclose(output.cpu(), torch.tensor([3., 1.]))
assert get_global_counter() == 2
assert torch.allclose(output.cpu(), torch.tensor([3., 1.]))

View File

@ -14,38 +14,16 @@ from typing import Any, Optional
import pytest
import torch
from torch import nn
from torch.library import Library
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
VllmConfig, set_current_vllm_config)
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.utils import direct_register_custom_op
# create a library to hold the custom op
silly_lib = Library("silly_toy_llama", "FRAGMENT") # noqa
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
out: torch.Tensor) -> None:
out.copy_(q)
out += k
out += v
def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
out: torch.Tensor) -> None:
return
direct_register_custom_op(
op_name="attention",
op_func=silly_attention,
mutates_args=["out"],
fake_impl=silly_attention_fake,
target_lib=silly_lib,
)
# Import shared test operations
# The standard attention operation is automatically registered when imported
import tests.compile.test_operations
@dataclass
@ -160,7 +138,7 @@ class LlamaAttention(nn.Module):
k = k + positions.unsqueeze(1)
attn_output = torch.empty_like(q)
torch.ops.silly_toy_llama.attention(q, k, v, attn_output)
torch.ops.silly.attention(q, k, v, attn_output)
output = self.output_projection(attn_output)
return output
@ -492,4 +470,4 @@ def benchmark():
if __name__ == "__main__":
benchmark()
benchmark()

View File

@ -0,0 +1,88 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Shared PyTorch custom operations for compilation tests.
This module provides a centralized place to define and register custom
PyTorch operations used across multiple compilation tests. This avoids
duplicate operation registrations that would cause RuntimeErrors when
running tests together.
The main "attention" operation is automatically registered when this module
is imported. Individual test files can access additional functionality
through helper functions.
"""
import torch
from torch.library import Library
from vllm.utils import direct_register_custom_op
# Shared library for all compilation test operations
# Using "silly" namespace to match existing test expectations
silly_lib = Library("silly", "FRAGMENT")
# Global state for test_simple.py compatibility
_global_counter = 0
_use_counting_mode = False
def get_global_counter():
"""Get the current global counter value (for test_simple.py)"""
return _global_counter
def reset_global_counter():
"""Reset the global counter to 0 (for test_simple.py)"""
global _global_counter
_global_counter = 0
def enable_counting_mode():
"""Enable counting mode for test_simple.py"""
global _use_counting_mode
_use_counting_mode = True
reset_global_counter()
def disable_counting_mode():
"""Disable counting mode"""
global _use_counting_mode
_use_counting_mode = False
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
out: torch.Tensor) -> None:
"""
Unified attention implementation that can handle both standard and counting modes.
"""
global _global_counter, _use_counting_mode
if _use_counting_mode:
# Counting mode for test_simple.py
_global_counter += 1
print(f"global_counter={_global_counter}")
out.copy_(q)
out[0] += 1
else:
# Standard mode for test_multiple_graphs.py and test_toy_llama.py
out.copy_(q)
out += k
out += v
def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
out: torch.Tensor) -> None:
"""Fake implementation for testing"""
return
# Register the unified attention operation
direct_register_custom_op(
op_name="attention",
op_func=silly_attention,
mutates_args=["out"],
fake_impl=silly_attention_fake,
target_lib=silly_lib,
)