[Minor] avoid registering a new custom op and just import silly_attn (#28578)
Signed-off-by: Boyuan Feng <boyuan@meta.com>
commit fd75d3e8c0
parent c9a3a02149
@@ -15,6 +15,9 @@ from vllm.engine.arg_utils import EngineArgs
 from vllm.platforms import current_platform
 from vllm.utils.torch_utils import _is_torch_equal_or_newer
 
+# This import automatically registers `torch.ops.silly.attention`
+from . import silly_attention  # noqa: F401
+
 
 def test_version():
     # Test the version comparison logic using the private function
@@ -257,15 +260,6 @@ def test_should_split():
     splitting_ops = ["aten::add.Tensor"]
     assert not should_split(node, splitting_ops)
 
-    @torch.library.custom_op(
-        "silly::attention",
-        mutates_args=["out"],
-    )
-    def attention(
-        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor
-    ) -> None:
-        out.copy_(q + k + v)
-
     q, k, v, out = [torch.randn(1)] * 4
 
     # supports custom ops as OpOverloadPacket
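For context, the diff relies on the side effect of importing a module that registers the custom op. The actual contents of vLLM's silly_attention module are not shown here, so the following is a minimal sketch of what such a module might look like; the op name and body mirror the inline definition removed above, and register_fake is the standard PyTorch mechanism for giving the op a meta implementation.

# silly_attention.py -- hypothetical sketch, not the actual vLLM module.
# Importing this module registers torch.ops.silly.attention as a side effect
# of executing the decorators at import time.
import torch


@torch.library.custom_op("silly::attention", mutates_args=["out"])
def attention(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor
) -> None:
    # Toy "attention": same body as the inline op removed by this commit.
    out.copy_(q + k + v)


@attention.register_fake
def _(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor
) -> None:
    # Meta/fake implementation: the op mutates `out` in place and returns None.
    return None

With that in place, a test only needs `from . import silly_attention  # noqa: F401` (as added above) and can call `torch.ops.silly.attention(q, k, v, out)` directly; re-registering the same op name inline, as the removed block did, can conflict with the registration performed by the import.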