mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 03:26:12 +08:00
Aiter mha fp8 fix (#24991)
Signed-off-by: Doug Lehr <douglehr@amd.com> Co-authored-by: Doug Lehr <douglehr@amd.com>
This commit is contained in:
parent
fedb75fa27
commit
1a456c7c90
@@ -81,8 +81,8 @@ class AITERPagedAttention(PagedAttention):
                 blocksparse_head_sliding_step=blocksparse_head_sliding_step)

         if "fp8" in kv_cache_dtype:
-            key_cache = key_cache.view(torch.float8_e4m3fnuz)
-            value_cache = value_cache.view(torch.float8_e4m3fnuz)
+            key_cache = key_cache.view(current_platform.fp8_dtype())
+            value_cache = value_cache.view(current_platform.fp8_dtype())

         if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1:
             # use blocksparse paged attention
||||
@@ -479,8 +479,8 @@ class AiterFlashAttentionImpl(AttentionImpl):
             )

         if self.kv_cache_dtype.startswith("fp8"):
-            key_cache = key_cache.view(torch.float8_e4m3fnuz)
-            value_cache = value_cache.view(torch.float8_e4m3fnuz)
+            key_cache = key_cache.view(current_platform.fp8_dtype())
+            value_cache = value_cache.view(current_platform.fp8_dtype())

         if not attn_metadata.use_cascade:
             cu_seqlens_q = attn_metadata.query_start_loc
Loading…
x
Reference in New Issue
Block a user