AITER MHA FP8 fix (#24991)

Signed-off-by: Doug Lehr <douglehr@amd.com>
Co-authored-by: Doug Lehr <douglehr@amd.com>
This commit is contained in:
Douglas Lehr 2025-09-17 17:29:14 -05:00 committed by GitHub
parent fedb75fa27
commit 1a456c7c90
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 4 additions and 4 deletions

View File

@ -81,8 +81,8 @@ class AITERPagedAttention(PagedAttention):
blocksparse_head_sliding_step=blocksparse_head_sliding_step)
if "fp8" in kv_cache_dtype:
- key_cache = key_cache.view(torch.float8_e4m3fnuz)
- value_cache = value_cache.view(torch.float8_e4m3fnuz)
+ key_cache = key_cache.view(current_platform.fp8_dtype())
+ value_cache = value_cache.view(current_platform.fp8_dtype())
if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1:
# use blocksparse paged attention

View File

@ -479,8 +479,8 @@ class AiterFlashAttentionImpl(AttentionImpl):
)
if self.kv_cache_dtype.startswith("fp8"):
- key_cache = key_cache.view(torch.float8_e4m3fnuz)
- value_cache = value_cache.view(torch.float8_e4m3fnuz)
+ key_cache = key_cache.view(current_platform.fp8_dtype())
+ value_cache = value_cache.view(current_platform.fp8_dtype())
if not attn_metadata.use_cascade:
cu_seqlens_q = attn_metadata.query_start_loc