[BugFix] fix: aot passes kvcache dtype information (#19750)

Signed-off-by: Mickael Seznec <mickael@mistral.ai>
Author: Mickaël Seznec
Date: 2025-08-01 07:45:02 +02:00 (committed by GitHub)
parent 82de9b9d46
commit e1a7fe4af5


@@ -99,6 +99,13 @@ class FlashAttentionBackend(AttentionBackend):
             raise ValueError(f"Unknown cache layout format {cache_layout}.")
         return stride_order
 
+    @staticmethod
+    def get_fp8_dtype_for_flashattn(kv_cache_dtype: str) -> torch.dtype:
+        if kv_cache_dtype in ("fp8", "fp8_e4m3"):
+            return torch.float8_e4m3fn
+        else:
+            raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")
+
 
 @dataclass
 class FlashAttentionMetadata:
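Standalone, the new helper is just a guarded string-to-dtype mapping; a minimal sketch of its behavior (requires a PyTorch build with float8 support, 2.1+):

import torch

def get_fp8_dtype_for_flashattn(kv_cache_dtype: str) -> torch.dtype:
    # "fp8" is treated as an alias of "fp8_e4m3"; any other string
    # (e.g. "fp8_e5m2") is rejected with a ValueError.
    if kv_cache_dtype in ("fp8", "fp8_e4m3"):
        return torch.float8_e4m3fn
    raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")

assert get_fp8_dtype_for_flashattn("fp8") is torch.float8_e4m3fn
assert get_fp8_dtype_for_flashattn("fp8_e4m3") is torch.float8_e4m3fn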
@@ -161,6 +168,7 @@ class FlashAttentionMetadataBuilder(
         self.num_heads_q = self.model_config.get_num_attention_heads(
             self.parallel_config)
         self.num_heads_kv = self.model_config.get_num_kv_heads(
             self.parallel_config)
+        self.kv_cache_dtype = kv_cache_spec.dtype
         self.headdim = self.model_config.get_head_size()
         self.block_size = kv_cache_spec.block_size
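The builder now records the KV-cache spec's torch dtype once, at construction time, so scheduling can later fall back to it for non-fp8 caches. A sketch of the idea, where the KVCacheSpec fields shown are simplified assumptions rather than the full vLLM class:

from dataclasses import dataclass
import torch

@dataclass
class KVCacheSpec:            # simplified stand-in for vLLM's spec
    block_size: int
    dtype: torch.dtype        # dtype the KV cache is allocated with

class Builder:
    def __init__(self, kv_cache_spec: KVCacheSpec):
        # Captured once so schedule() can use it as the non-fp8 fallback.
        self.kv_cache_dtype = kv_cache_spec.dtype
        self.block_size = kv_cache_spec.block_size

spec = KVCacheSpec(block_size=16, dtype=torch.bfloat16)
assert Builder(spec).kv_cache_dtype is torch.bfloat16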
@@ -239,17 +247,24 @@ class FlashAttentionMetadataBuilder(
         def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                      max_seq_len, causal):
+            cache_dtype = self.cache_config.cache_dtype
+            if cache_dtype.startswith("fp8"):
+                qkv_dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
+                    cache_dtype)
+            else:
+                qkv_dtype = self.kv_cache_dtype
             if aot_schedule:
                 return get_scheduler_metadata(
                     batch_size=batch_size,
                     max_seqlen_q=max_query_len,
                     max_seqlen_k=max_seq_len,
-                    cache_seqlens=seqlens,
                     num_heads_q=self.num_heads_q,
                     num_heads_kv=self.num_heads_kv,
                     headdim=self.headdim,
-                    page_size=self.block_size,
+                    cache_seqlens=seqlens,
+                    qkv_dtype=qkv_dtype,
                     cu_seqlens_q=cu_query_lens,
+                    page_size=self.block_size,
                     causal=causal,
                     window_size=self.aot_sliding_window,
                     num_splits=self.max_num_splits,
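This is the actual fix: the AOT scheduler previously never received the KV-cache dtype, so fp8 caches were scheduled as if they held the model dtype. The selection logic can be exercised on its own; a sketch assuming cache_dtype is the config string and kv_cache_dtype the builder's recorded torch dtype:

import torch

def resolve_qkv_dtype(cache_dtype: str,
                      kv_cache_dtype: torch.dtype) -> torch.dtype:
    # fp8 caches must be announced to the AOT scheduler as fp8,
    # otherwise it plans splits/tiles for the wrong element size.
    if cache_dtype.startswith("fp8"):
        if cache_dtype in ("fp8", "fp8_e4m3"):
            return torch.float8_e4m3fn
        raise ValueError(f"Unrecognized FP8 dtype: {cache_dtype}")
    # Non-fp8 settings: the cache holds the recorded dtype as-is.
    return kv_cache_dtype

assert resolve_qkv_dtype("fp8", torch.bfloat16) is torch.float8_e4m3fn
assert resolve_qkv_dtype("auto", torch.bfloat16) is torch.bfloat16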
@@ -474,8 +489,10 @@ class FlashAttentionImpl(AttentionImpl):
             )
 
             if self.kv_cache_dtype.startswith("fp8"):
-                key_cache = key_cache.view(torch.float8_e4m3fn)
-                value_cache = value_cache.view(torch.float8_e4m3fn)
+                dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
+                    self.kv_cache_dtype)
+                key_cache = key_cache.view(dtype)
+                value_cache = value_cache.view(dtype)
                 num_tokens, num_heads, head_size = query.shape
                 query, _ = ops.scaled_fp8_quant(
                     query.reshape(
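Routing the view through the shared helper replaces the hardcoded torch.float8_e4m3fn, so the cache reinterpretation stays consistent with whatever dtype the scheduler was told about. Note that .view(dtype) reinterprets the cache's storage rather than converting values; a standalone sketch, assuming the fp8 KV cache is allocated as raw uint8 as vLLM does:

import torch

# fp8 KV caches are allocated as raw uint8 bytes; view() reinterprets
# the same 1-byte-per-element storage as float8 without any copy.
key_cache = torch.zeros(2, 16, 4, 128, dtype=torch.uint8)
key_cache_fp8 = key_cache.view(torch.float8_e4m3fn)

assert key_cache_fp8.dtype is torch.float8_e4m3fn
assert key_cache_fp8.data_ptr() == key_cache.data_ptr()  # same storage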