[ROCm][V1][Bugfix] Add get_builder_cls method to the ROCmAttentionBackend class (#14065)

Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
Sage Moore 2025-02-28 23:18:32 -08:00 committed by GitHub
parent 3b5567a209
commit b28246f6ff
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -9,7 +9,8 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.attention.ops.prefix_prefill import context_attention_fwd
from vllm.logger import init_logger
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.flash_attn import (
FlashAttentionMetadata, FlashAttentionMetadataBuilder)
logger = init_logger(__name__)
@@ -49,6 +50,10 @@ class ROCmAttentionBackend(AttentionBackend):
def use_cascade_attention(*args, **kwargs) -> bool:
return False
@staticmethod
def get_builder_cls() -> Type["FlashAttentionMetadataBuilder"]:
    """Return the attention-metadata builder class used by this backend.

    Returns:
        The ``FlashAttentionMetadataBuilder`` class (the class object itself,
        not an instance) imported from
        ``vllm.v1.attention.backends.flash_attn``.
    """
    builder_cls = FlashAttentionMetadataBuilder
    return builder_cls
class ROCmAttentionImpl(AttentionImpl):