Fix duplicate torch operation registrations in tests/compile
Co-authored-by: ProExpertProg <11367180+ProExpertProg@users.noreply.github.com>
parent e263eccfae
commit 9d6f0372e5
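Why the rename matters: these test files under tests/compile all registered their fake attention op into the same "silly" torch.library namespace, so importing more than one of them into a single pytest process made the second silly::attention registration collide with the first. Below is a minimal sketch of that failure mode using plain torch.library registration; the schema string is illustrative, not the exact test code.

import torch
from torch.library import Library

# First test module registers silly::attention.
lib_a = Library("silly", "FRAGMENT")
lib_a.define("attention(Tensor q, Tensor k, Tensor v, Tensor(a!) out) -> ()")

# A second module imported into the same process tries to register the
# same op name in the same namespace; the dispatcher rejects it.
lib_b = Library("silly", "FRAGMENT")
try:
    lib_b.define("attention(Tensor q, Tensor k, Tensor v, Tensor(a!) out) -> ()")
except RuntimeError as err:
    print(f"duplicate registration rejected: {err}")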
@@ -19,7 +19,7 @@ from vllm.forward_context import set_forward_context
 from vllm.utils import direct_register_custom_op
 
 # create a library to hold the custom op
-silly_lib = Library("silly", "FRAGMENT")  # noqa
+silly_lib = Library("silly_multiple", "FRAGMENT")  # noqa
 
 BATCH_SIZE = 32
 MLP_SIZE = 128
@@ -188,7 +188,7 @@ def test_ignore_torch_compile_decorator():
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = x + x
         attn_output = torch.empty_like(x)
-        torch.ops.silly.attention(x, x, x, attn_output)
+        torch.ops.silly_multiple.attention(x, x, x, attn_output)
         x = attn_output
         x = x * 3
         return x
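Note that each namespace rename is a two-part change per file: the Library(...) constructor and every call site, because torch exposes registered ops as torch.ops.<namespace>.<op_name>. A hedged illustration of that resolution, with a stand-in CPU kernel of my own rather than the tests' implementation:

import torch
from torch.library import Library

lib = Library("silly_multiple", "FRAGMENT")
lib.define("attention(Tensor q, Tensor k, Tensor v, Tensor(a!) out) -> ()")

def attention_impl(q, k, v, out):
    out.copy_(q + k + v)  # stand-in kernel for illustration only

lib.impl("attention", attention_impl, "CompositeExplicitAutograd")

x = torch.randn(4, 8)
out = torch.empty_like(x)
torch.ops.silly_multiple.attention(x, x, x, out)  # resolves in the new namespace
# torch.ops.silly.attention(x, x, x, out) would now raise AttributeError,
# since nothing is registered under "silly" in this process.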
@@ -20,7 +20,7 @@ from vllm.utils import direct_register_custom_op
 global_counter = 0
 
 # create a library to hold the custom op
-silly_lib = Library("silly", "FRAGMENT")  # noqa
+silly_lib = Library("silly_simple", "FRAGMENT")  # noqa
 
 
 def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
@@ -66,12 +66,12 @@ class SillyModel(nn.Module):
         x = x + 1
         x = x + 2
         out = torch.empty_like(x)
-        torch.ops.silly.attention(x, x, x, out)
+        torch.ops.silly_simple.attention(x, x, x, out)
         x = out
         x = x - 2
         x = x - 1
         out = torch.empty_like(x)
-        torch.ops.silly.attention(x, x, x, out)
+        torch.ops.silly_simple.attention(x, x, x, out)
         x = out
         x = x + 1
         return x
@@ -24,7 +24,7 @@ from vllm.forward_context import BatchDescriptor, set_forward_context
 from vllm.utils import direct_register_custom_op
 
 # create a library to hold the custom op
-silly_lib = Library("silly", "FRAGMENT")  # noqa
+silly_lib = Library("silly_toy_llama", "FRAGMENT")  # noqa
 
 
 def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
@@ -160,7 +160,7 @@ class LlamaAttention(nn.Module):
         k = k + positions.unsqueeze(1)
 
         attn_output = torch.empty_like(q)
-        torch.ops.silly.attention(q, k, v, attn_output)
+        torch.ops.silly_toy_llama.attention(q, k, v, attn_output)
 
         output = self.output_projection(attn_output)
         return output
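Taken together, the post-commit pattern gives each test module its own namespace and hands that library to vllm's direct_register_custom_op helper (imported in the hunks above). A sketch reconstructing the pattern, assuming the helper's keyword arguments (op_name, op_func, mutates_args, fake_impl, target_lib) match the vllm version these tests import from:

import torch
from torch.library import Library

from vllm.utils import direct_register_custom_op

# Per-file namespace, as in the hunks above.
silly_lib = Library("silly_toy_llama", "FRAGMENT")  # noqa


def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                    out: torch.Tensor) -> None:
    # Toy "attention" used by the compile tests: writes into out in place.
    out.copy_(q + k + v)


def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                         out: torch.Tensor) -> None:
    # Fake (meta) implementation: no computation, shapes already match.
    return


direct_register_custom_op(
    op_name="attention",
    op_func=silly_attention,
    mutates_args=["out"],
    fake_impl=silly_attention_fake,
    target_lib=silly_lib,
)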