mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 06:35:00 +08:00
[Minor] Avoid registering a new custom op and just import silly_attention (#28578)
Signed-off-by: Boyuan Feng <boyuan@meta.com>
This commit is contained in:
parent
c9a3a02149
commit
fd75d3e8c0
@ -15,6 +15,9 @@ from vllm.engine.arg_utils import EngineArgs
|
|||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils.torch_utils import _is_torch_equal_or_newer
|
from vllm.utils.torch_utils import _is_torch_equal_or_newer
|
||||||
|
|
||||||
|
# This import automatically registers `torch.ops.silly.attention`
|
||||||
|
from . import silly_attention # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
def test_version():
|
def test_version():
|
||||||
# Test the version comparison logic using the private function
|
# Test the version comparison logic using the private function
|
||||||
@ -257,15 +260,6 @@ def test_should_split():
|
|||||||
splitting_ops = ["aten::add.Tensor"]
|
splitting_ops = ["aten::add.Tensor"]
|
||||||
assert not should_split(node, splitting_ops)
|
assert not should_split(node, splitting_ops)
|
||||||
|
|
||||||
@torch.library.custom_op(
|
|
||||||
"silly::attention",
|
|
||||||
mutates_args=["out"],
|
|
||||||
)
|
|
||||||
def attention(
|
|
||||||
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor
|
|
||||||
) -> None:
|
|
||||||
out.copy_(q + k + v)
|
|
||||||
|
|
||||||
q, k, v, out = [torch.randn(1)] * 4
|
q, k, v, out = [torch.randn(1)] * 4
|
||||||
|
|
||||||
# supports custom ops as OpOverloadPacket
|
# supports custom ops as OpOverloadPacket
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user