diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py
index 1ee1dea729d9e..da3d9ff32830c 100644
--- a/vllm/attention/backends/rocm_flash_attn.py
+++ b/vllm/attention/backends/rocm_flash_attn.py
@@ -22,7 +22,6 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     GroupShape)
 from vllm.platforms import current_platform
-from vllm.platforms.rocm import use_rocm_custom_paged_attention
 
 if TYPE_CHECKING:
     from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
@@ -886,6 +885,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
             num_seqs, num_heads, head_size = decode_query.shape
             block_size = value_cache.shape[3]
             gqa_ratio = num_heads // self.num_kv_heads
+            from vllm.platforms.rocm import use_rocm_custom_paged_attention
             use_custom = use_rocm_custom_paged_attention(
                 decode_query.dtype, head_size, block_size, gqa_ratio,
                 decode_meta.max_decode_seq_len, self.sliding_window,
diff --git a/vllm/attention/ops/chunked_prefill_paged_decode.py b/vllm/attention/ops/chunked_prefill_paged_decode.py
index dc10d7eca9c2a..e5b90a8b27558 100644
--- a/vllm/attention/ops/chunked_prefill_paged_decode.py
+++ b/vllm/attention/ops/chunked_prefill_paged_decode.py
@@ -11,7 +11,6 @@ import torch
 
 from vllm import _custom_ops as ops
 from vllm.platforms import current_platform
-from vllm.platforms.rocm import use_rocm_custom_paged_attention
 from vllm.triton_utils import tl, triton
 
 from .prefix_prefill import context_attention_fwd
@@ -296,6 +295,7 @@ def chunked_prefill_paged_decode(
     num_queries_per_kv_padded = max(triton.next_power_of_2(num_queries_per_kv),
                                     16)
 
+    from vllm.platforms.rocm import use_rocm_custom_paged_attention
     use_custom = use_rocm_custom_paged_attention(
         query.dtype,
         head_size,
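Both files make the same change: the module-level `from vllm.platforms.rocm import use_rocm_custom_paged_attention` is removed and the import is deferred to the call site, so `vllm.platforms.rocm` is only resolved when the decode path actually runs. Below is a minimal, self-contained sketch of that lazy-import pattern; the names `heavy_platform_module` and `platform_specific_check` are hypothetical stand-ins for `vllm.platforms.rocm` and `use_rocm_custom_paged_attention`, and the stated motivation (avoiding import-time platform resolution or a circular import) is an assumption, not something the diff itself states.

```python
from importlib import import_module


def decide_custom_kernel(head_size: int, block_size: int) -> bool:
    """Sketch of a function-local (deferred) import.

    The platform-specific module is only resolved the first time this
    function runs, not when the enclosing module is imported -- the same
    move the diff makes with use_rocm_custom_paged_attention.
    """
    try:
        # Hypothetical stand-in for vllm.platforms.rocm; resolved lazily.
        platform_mod = import_module("heavy_platform_module")
    except ImportError:
        # Platform-specific module unavailable: fall back to the default path.
        return False
    # Hypothetical stand-in for use_rocm_custom_paged_attention.
    return platform_mod.platform_specific_check(head_size, block_size)
```

The trade-off is a small one-time cost on the first call; subsequent calls hit Python's `sys.modules` cache, so the repeated function-local import is effectively free.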