mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-02 20:17:53 +08:00
[ROCm][V1][Bugfix] Add get_builder_cls method to the ROCmAttentionBackend class (#14065)
Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
parent
3b5567a209
commit
b28246f6ff
@ -9,7 +9,8 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
|||||||
from vllm.attention.ops.paged_attn import PagedAttention
|
from vllm.attention.ops.paged_attn import PagedAttention
|
||||||
from vllm.attention.ops.prefix_prefill import context_attention_fwd
|
from vllm.attention.ops.prefix_prefill import context_attention_fwd
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
|
from vllm.v1.attention.backends.flash_attn import (
|
||||||
|
FlashAttentionMetadata, FlashAttentionMetadataBuilder)
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -49,6 +50,10 @@ class ROCmAttentionBackend(AttentionBackend):
|
|||||||
def use_cascade_attention(*args, **kwargs) -> bool:
|
def use_cascade_attention(*args, **kwargs) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_builder_cls() -> Type["FlashAttentionMetadataBuilder"]:
|
||||||
|
return FlashAttentionMetadataBuilder
|
||||||
|
|
||||||
|
|
||||||
class ROCmAttentionImpl(AttentionImpl):
|
class ROCmAttentionImpl(AttentionImpl):
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user