mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-06 07:37:03 +08:00
Simplify operation implementation: remove mode switching, always use global counter
Co-authored-by: ProExpertProg <11367180+ProExpertProg@users.noreply.github.com>
This commit is contained in:
parent
2c81fbbb3c
commit
865b0bfafd
@ -17,12 +17,9 @@ from vllm.forward_context import BatchDescriptor, set_forward_context
|
||||
|
||||
# Import shared test operations
|
||||
from tests.compile.test_operations import (
|
||||
get_global_counter, reset_global_counter, enable_counting_mode
|
||||
get_global_counter, reset_global_counter
|
||||
)
|
||||
|
||||
# Enable counting mode for this test's specific behavior
|
||||
enable_counting_mode()
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class SillyModel(nn.Module):
|
||||
@ -36,9 +33,8 @@ class SillyModel(nn.Module):
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Overall effect:
|
||||
x += 1
|
||||
x[0] += 2
|
||||
Overall effect with unified attention implementation:
|
||||
input [0., 0.] -> final output [19., 19.]
|
||||
global_counter += 2
|
||||
"""
|
||||
x = x + 1
|
||||
@ -107,4 +103,4 @@ def test_simple_piecewise_compile(use_inductor):
|
||||
batch_descriptor=BatchDescriptor(num_tokens=2, )):
|
||||
output = model(input)
|
||||
assert get_global_counter() == 2
|
||||
assert torch.allclose(output.cpu(), torch.tensor([3., 1.]))
|
||||
assert torch.allclose(output.cpu(), torch.tensor([19., 19.]))
|
||||
@ -9,7 +9,7 @@ duplicate operation registrations that would cause RuntimeErrors when
|
||||
running tests together.
|
||||
|
||||
The main "attention" operation is automatically registered when this module
|
||||
is imported. Individual test files can access additional functionality
|
||||
is imported. Individual test files can access the global counter functionality
|
||||
through helper functions.
|
||||
"""
|
||||
|
||||
@ -23,53 +23,34 @@ from vllm.utils import direct_register_custom_op
|
||||
silly_lib = Library("silly", "FRAGMENT")
|
||||
|
||||
|
||||
# Global state for test_simple.py compatibility
|
||||
# Global counter that all tests can use or ignore
|
||||
_global_counter = 0
|
||||
_use_counting_mode = False
|
||||
|
||||
|
||||
def get_global_counter():
|
||||
"""Get the current global counter value (for test_simple.py)"""
|
||||
"""Get the current global counter value"""
|
||||
return _global_counter
|
||||
|
||||
|
||||
def reset_global_counter():
|
||||
"""Reset the global counter to 0 (for test_simple.py)"""
|
||||
"""Reset the global counter to 0"""
|
||||
global _global_counter
|
||||
_global_counter = 0
|
||||
|
||||
|
||||
def enable_counting_mode():
|
||||
"""Enable counting mode for test_simple.py"""
|
||||
global _use_counting_mode
|
||||
_use_counting_mode = True
|
||||
reset_global_counter()
|
||||
|
||||
|
||||
def disable_counting_mode():
|
||||
"""Disable counting mode"""
|
||||
global _use_counting_mode
|
||||
_use_counting_mode = False
|
||||
|
||||
|
||||
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
out: torch.Tensor) -> None:
|
||||
"""
|
||||
Unified attention implementation that can handle both standard and counting modes.
|
||||
Unified attention implementation that depends on all inputs and affects the output.
|
||||
Always increments a global counter that tests can use or ignore.
|
||||
"""
|
||||
global _global_counter, _use_counting_mode
|
||||
global _global_counter
|
||||
|
||||
if _use_counting_mode:
|
||||
# Counting mode for test_simple.py
|
||||
_global_counter += 1
|
||||
print(f"global_counter={_global_counter}")
|
||||
out.copy_(q)
|
||||
out[0] += 1
|
||||
else:
|
||||
# Standard mode for test_multiple_graphs.py and test_toy_llama.py
|
||||
out.copy_(q)
|
||||
out += k
|
||||
out += v
|
||||
# Always increment the global counter
|
||||
_global_counter += 1
|
||||
|
||||
# Unified implementation that depends on all inputs
|
||||
out.copy_(q + k + v)
|
||||
|
||||
|
||||
def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user