From 1a456c7c90afdb534d1203d7e4ea5747aada801c Mon Sep 17 00:00:00 2001
From: Douglas Lehr <91553416+dllehr-amd@users.noreply.github.com>
Date: Wed, 17 Sep 2025 17:29:14 -0500
Subject: [PATCH] Aiter mha fp8 fix (#24991)

Signed-off-by: Doug Lehr
Co-authored-by: Doug Lehr
---
 vllm/attention/ops/rocm_aiter_paged_attn.py | 4 ++--
 vllm/v1/attention/backends/rocm_aiter_fa.py | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/vllm/attention/ops/rocm_aiter_paged_attn.py b/vllm/attention/ops/rocm_aiter_paged_attn.py
index ad97152e208b..2a0336de8cf7 100644
--- a/vllm/attention/ops/rocm_aiter_paged_attn.py
+++ b/vllm/attention/ops/rocm_aiter_paged_attn.py
@@ -81,8 +81,8 @@ class AITERPagedAttention(PagedAttention):
                 blocksparse_head_sliding_step=blocksparse_head_sliding_step)
 
         if "fp8" in kv_cache_dtype:
-            key_cache = key_cache.view(torch.float8_e4m3fnuz)
-            value_cache = value_cache.view(torch.float8_e4m3fnuz)
+            key_cache = key_cache.view(current_platform.fp8_dtype())
+            value_cache = value_cache.view(current_platform.fp8_dtype())
 
         if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1:
             # use blocksparse paged attention
diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py
index a4e2758bd311..8eb3505cf274 100644
--- a/vllm/v1/attention/backends/rocm_aiter_fa.py
+++ b/vllm/v1/attention/backends/rocm_aiter_fa.py
@@ -479,8 +479,8 @@ class AiterFlashAttentionImpl(AttentionImpl):
             )
 
             if self.kv_cache_dtype.startswith("fp8"):
-                key_cache = key_cache.view(torch.float8_e4m3fnuz)
-                value_cache = value_cache.view(torch.float8_e4m3fnuz)
+                key_cache = key_cache.view(current_platform.fp8_dtype())
+                value_cache = value_cache.view(current_platform.fp8_dtype())
 
             if not attn_metadata.use_cascade:
                 cu_seqlens_q = attn_metadata.query_start_loc
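
Editor's note: the patch replaces the hardcoded torch.float8_e4m3fnuz with current_platform.fp8_dtype(), so the fp8 KV cache is reinterpreted with whichever fp8 format the running platform actually uses, rather than always assuming the "fnuz" E4M3 variant. The sketch below illustrates the idea outside of vLLM; it is a minimal example assuming only PyTorch with fp8 dtypes, and the platform_fp8_dtype helper and is_fnuz_hw flag are hypothetical stand-ins, not vLLM APIs.

    # Minimal sketch (assumes PyTorch with fp8 dtypes available).
    # platform_fp8_dtype() and is_fnuz_hw are illustrative only; vLLM's own
    # selection lives behind current_platform.fp8_dtype().
    import torch

    def platform_fp8_dtype(is_fnuz_hw: bool) -> torch.dtype:
        # Some ROCm GPUs use the "fnuz" E4M3 encoding; other hardware uses
        # the OCP E4M3 encoding. Pick the one matching the platform.
        return torch.float8_e4m3fnuz if is_fnuz_hw else torch.float8_e4m3fn

    # A KV-cache buffer stored as raw bytes can be reinterpreted in place with
    # Tensor.view(dtype), since both fp8 formats are one byte per element.
    kv_cache = torch.zeros(16, dtype=torch.uint8)
    fp8_view = kv_cache.view(platform_fp8_dtype(is_fnuz_hw=False))
    print(fp8_view.dtype)  # torch.float8_e4m3fn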