diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py
index 0f3fabf05fc28..5c7d759b1812d 100644
--- a/vllm/v1/attention/backends/rocm_attn.py
+++ b/vllm/v1/attention/backends/rocm_attn.py
@@ -9,7 +9,8 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
 from vllm.attention.ops.paged_attn import PagedAttention
 from vllm.attention.ops.prefix_prefill import context_attention_fwd
 from vllm.logger import init_logger
-from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
+from vllm.v1.attention.backends.flash_attn import (
+    FlashAttentionMetadata, FlashAttentionMetadataBuilder)
 
 logger = init_logger(__name__)
 
@@ -49,6 +50,10 @@ class ROCmAttentionBackend(AttentionBackend):
 
     def use_cascade_attention(*args, **kwargs) -> bool:
         return False
 
+    @staticmethod
+    def get_builder_cls() -> Type["FlashAttentionMetadataBuilder"]:
+        return FlashAttentionMetadataBuilder
+
 
 class ROCmAttentionImpl(AttentionImpl):