diff --git a/tests/compile/piecewise/test_simple.py b/tests/compile/piecewise/test_simple.py index f8717d7245547..f235bf4a6fbbf 100644 --- a/tests/compile/piecewise/test_simple.py +++ b/tests/compile/piecewise/test_simple.py @@ -17,12 +17,9 @@ from vllm.forward_context import BatchDescriptor, set_forward_context # Import shared test operations from tests.compile.test_operations import ( - get_global_counter, reset_global_counter, enable_counting_mode + get_global_counter, reset_global_counter ) -# Enable counting mode for this test's specific behavior -enable_counting_mode() - @support_torch_compile class SillyModel(nn.Module): @@ -36,9 +33,8 @@ class SillyModel(nn.Module): def forward(self, x: torch.Tensor) -> torch.Tensor: """ - Overall effect: - x += 1 - x[0] += 2 + Overall effect with unified attention implementation: + input [0., 0.] -> final output [19., 19.] global_counter += 2 """ x = x + 1 @@ -107,4 +103,4 @@ def test_simple_piecewise_compile(use_inductor): batch_descriptor=BatchDescriptor(num_tokens=2, )): output = model(input) assert get_global_counter() == 2 - assert torch.allclose(output.cpu(), torch.tensor([3., 1.])) \ No newline at end of file + assert torch.allclose(output.cpu(), torch.tensor([19., 19.])) \ No newline at end of file diff --git a/tests/compile/test_operations.py b/tests/compile/test_operations.py index 59561d6fa21c2..32aa4108e6b2a 100644 --- a/tests/compile/test_operations.py +++ b/tests/compile/test_operations.py @@ -9,7 +9,7 @@ duplicate operation registrations that would cause RuntimeErrors when running tests together. The main "attention" operation is automatically registered when this module -is imported. Individual test files can access additional functionality +is imported. Individual test files can access the global counter functionality through helper functions. """ @@ -23,53 +23,34 @@ from vllm.utils import direct_register_custom_op silly_lib = Library("silly", "FRAGMENT") -# Global state for test_simple.py compatibility +# Global counter that all tests can use or ignore _global_counter = 0 -_use_counting_mode = False def get_global_counter(): - """Get the current global counter value (for test_simple.py)""" + """Get the current global counter value""" return _global_counter def reset_global_counter(): - """Reset the global counter to 0 (for test_simple.py)""" + """Reset the global counter to 0""" global _global_counter _global_counter = 0 -def enable_counting_mode(): - """Enable counting mode for test_simple.py""" - global _use_counting_mode - _use_counting_mode = True - reset_global_counter() - - -def disable_counting_mode(): - """Disable counting mode""" - global _use_counting_mode - _use_counting_mode = False - - def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor) -> None: """ - Unified attention implementation that can handle both standard and counting modes. + Unified attention implementation that depends on all inputs and affects the output. + Always increments a global counter that tests can use or ignore. """ - global _global_counter, _use_counting_mode + global _global_counter - if _use_counting_mode: - # Counting mode for test_simple.py - _global_counter += 1 - print(f"global_counter={_global_counter}") - out.copy_(q) - out[0] += 1 - else: - # Standard mode for test_multiple_graphs.py and test_toy_llama.py - out.copy_(q) - out += k - out += v + # Always increment the global counter + _global_counter += 1 + + # Unified implementation that depends on all inputs + out.copy_(q + k + v) def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,