Simplify operation implementation: remove mode switching, always use global counter

Co-authored-by: ProExpertProg <11367180+ProExpertProg@users.noreply.github.com>
Author: copilot-swe-agent[bot]
Date:   2025-08-20 23:33:41 +00:00
Commit: 865b0bfafd (parent 2c81fbbb3c)
2 changed files with 16 additions and 39 deletions
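
At a glance, the change collapses the op's two code paths into one. A condensed before/after sketch (all names are taken from the diff below; the real op is registered through vLLM's custom-op machinery, omitted here):

    import torch

    _global_counter = 0
    _use_counting_mode = False  # module-level flag removed by this commit

    # Before: behavior branched on the flag.
    def silly_attention_before(q, k, v, out):
        global _global_counter
        if _use_counting_mode:      # special-cased for test_simple.py
            _global_counter += 1
            out.copy_(q)
            out[0] += 1
        else:                       # path used by the other tests
            out.copy_(q)
            out += k
            out += v

    # After: a single path; the counter is always incremented.
    def silly_attention_after(q, k, v, out):
        global _global_counter
        _global_counter += 1
        out.copy_(q + k + v)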

test_simple.py

@@ -17,12 +17,9 @@ from vllm.forward_context import BatchDescriptor, set_forward_context

 # Import shared test operations
 from tests.compile.test_operations import (
-    get_global_counter, reset_global_counter, enable_counting_mode
+    get_global_counter, reset_global_counter
 )

-# Enable counting mode for this test's specific behavior
-enable_counting_mode()
-

 @support_torch_compile
 class SillyModel(nn.Module):
@@ -36,9 +33,8 @@ class SillyModel(nn.Module):

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
-        Overall effect:
-        x += 1
-        x[0] += 2
+        Overall effect with unified attention implementation:
+        input [0., 0.] -> final output [19., 19.]
         global_counter += 2
         """
         x = x + 1
@@ -107,4 +103,4 @@ def test_simple_piecewise_compile(use_inductor):
                     batch_descriptor=BatchDescriptor(num_tokens=2, )):
             output = model(input)
     assert get_global_counter() == 2
-    assert torch.allclose(output.cpu(), torch.tensor([3., 1.]))
+    assert torch.allclose(output.cpu(), torch.tensor([19., 19.]))
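
The updated assertions rely only on the counting contract of the shared module: every silly_attention call increments the counter, and this model's forward invokes the op twice (the docstring's "global_counter += 2"). A minimal usage sketch, assuming the helpers exported by tests.compile.test_operations (shown in the next file):

    from tests.compile.test_operations import (
        get_global_counter, reset_global_counter)

    reset_global_counter()             # each test starts from 0
    # ... run one forward pass of the compiled model here ...
    # silly_attention increments the counter once per call, and this
    # model calls it twice per forward, hence:
    assert get_global_counter() == 2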

tests/compile/test_operations.py

@@ -9,7 +9,7 @@ duplicate operation registrations that would cause RuntimeErrors when
 running tests together.

 The main "attention" operation is automatically registered when this module
-is imported. Individual test files can access additional functionality
+is imported. Individual test files can access the global counter functionality
 through helper functions.
 """
@@ -23,53 +23,34 @@ from vllm.utils import direct_register_custom_op

 silly_lib = Library("silly", "FRAGMENT")

-# Global state for test_simple.py compatibility
+# Global counter that all tests can use or ignore
 _global_counter = 0
-_use_counting_mode = False


 def get_global_counter():
-    """Get the current global counter value (for test_simple.py)"""
+    """Get the current global counter value"""
     return _global_counter


 def reset_global_counter():
-    """Reset the global counter to 0 (for test_simple.py)"""
+    """Reset the global counter to 0"""
     global _global_counter
     _global_counter = 0

-
-def enable_counting_mode():
-    """Enable counting mode for test_simple.py"""
-    global _use_counting_mode
-    _use_counting_mode = True
-    reset_global_counter()
-
-
-def disable_counting_mode():
-    """Disable counting mode"""
-    global _use_counting_mode
-    _use_counting_mode = False
-

 def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                     out: torch.Tensor) -> None:
     """
-    Unified attention implementation that can handle both standard and counting modes.
+    Unified attention implementation that depends on all inputs and affects the output.
+    Always increments a global counter that tests can use or ignore.
     """
-    global _global_counter, _use_counting_mode
+    global _global_counter

-    if _use_counting_mode:
-        # Counting mode for test_simple.py
-        _global_counter += 1
-        print(f"global_counter={_global_counter}")
-        out.copy_(q)
-        out[0] += 1
-    else:
-        # Standard mode for test_multiple_graphs.py and test_toy_llama.py
-        out.copy_(q)
-        out += k
-        out += v
+    # Always increment the global counter
+    _global_counter += 1
+
+    # Unified implementation that depends on all inputs
+    out.copy_(q + k + v)


 def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
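
To see the unified semantics in isolation, here is a standalone sketch (hypothetical names, plain PyTorch, none of the Library/direct_register_custom_op machinery) that mirrors the new silly_attention body, with a small numeric check:

    import torch

    _global_counter = 0

    def silly_attention_sketch(q: torch.Tensor, k: torch.Tensor,
                               v: torch.Tensor, out: torch.Tensor) -> None:
        """Mirror of the unified op: bump the counter, write q + k + v."""
        global _global_counter
        _global_counter += 1          # every call counts; tests may ignore it
        out.copy_(q + k + v)          # output depends on all three inputs

    q = torch.tensor([1., 2.])
    k = torch.tensor([10., 20.])
    v = torch.tensor([100., 200.])
    out = torch.empty(2)
    silly_attention_sketch(q, k, v, out)
    assert _global_counter == 1
    assert torch.allclose(out, torch.tensor([111., 222.]))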