From b28246f6ff1684ef166d04cc9185e113a8474696 Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Fri, 28 Feb 2025 23:18:32 -0800 Subject: [PATCH] [ROCm][V1][Bugfix] Add get_builder_cls method to the ROCmAttentionBackend class (#14065) Signed-off-by: Sage Moore --- vllm/v1/attention/backends/rocm_attn.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py index 0f3fabf05fc28..5c7d759b1812d 100644 --- a/vllm/v1/attention/backends/rocm_attn.py +++ b/vllm/v1/attention/backends/rocm_attn.py @@ -9,7 +9,8 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.ops.paged_attn import PagedAttention from vllm.attention.ops.prefix_prefill import context_attention_fwd from vllm.logger import init_logger -from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata +from vllm.v1.attention.backends.flash_attn import ( + FlashAttentionMetadata, FlashAttentionMetadataBuilder) logger = init_logger(__name__) @@ -49,6 +50,10 @@ class ROCmAttentionBackend(AttentionBackend): def use_cascade_attention(*args, **kwargs) -> bool: return False + @staticmethod + def get_builder_cls() -> Type["FlashAttentionMetadataBuilder"]: + return FlashAttentionMetadataBuilder + class ROCmAttentionImpl(AttentionImpl):