mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 10:46:08 +08:00
[BugFix] fix: aot passes kvcache dtype information (#19750)
Signed-off-by: Mickael Seznec <mickael@mistral.ai>
This commit is contained in:
parent
82de9b9d46
commit
e1a7fe4af5
@ -99,6 +99,13 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
raise ValueError(f"Unknown cache layout format {cache_layout}.")
|
||||
return stride_order
|
||||
|
||||
@staticmethod
|
||||
def get_fp8_dtype_for_flashattn(kv_cache_dtype: str) -> torch.dtype:
|
||||
if kv_cache_dtype in ("fp8", "fp8_e4m3"):
|
||||
return torch.float8_e4m3fn
|
||||
else:
|
||||
raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashAttentionMetadata:
|
||||
@ -161,6 +168,7 @@ class FlashAttentionMetadataBuilder(
|
||||
self.parallel_config)
|
||||
self.num_heads_kv = self.model_config.get_num_kv_heads(
|
||||
self.parallel_config)
|
||||
self.kv_cache_dtype = kv_cache_spec.dtype
|
||||
self.headdim = self.model_config.get_head_size()
|
||||
self.block_size = kv_cache_spec.block_size
|
||||
|
||||
@ -239,17 +247,24 @@ class FlashAttentionMetadataBuilder(
|
||||
|
||||
def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
|
||||
max_seq_len, causal):
|
||||
cache_dtype = self.cache_config.cache_dtype
|
||||
if cache_dtype.startswith("fp8"):
|
||||
qkv_dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
|
||||
cache_dtype)
|
||||
else:
|
||||
qkv_dtype = self.kv_cache_dtype
|
||||
if aot_schedule:
|
||||
return get_scheduler_metadata(
|
||||
batch_size=batch_size,
|
||||
max_seqlen_q=max_query_len,
|
||||
max_seqlen_k=max_seq_len,
|
||||
cache_seqlens=seqlens,
|
||||
num_heads_q=self.num_heads_q,
|
||||
num_heads_kv=self.num_heads_kv,
|
||||
headdim=self.headdim,
|
||||
page_size=self.block_size,
|
||||
cache_seqlens=seqlens,
|
||||
qkv_dtype=qkv_dtype,
|
||||
cu_seqlens_q=cu_query_lens,
|
||||
page_size=self.block_size,
|
||||
causal=causal,
|
||||
window_size=self.aot_sliding_window,
|
||||
num_splits=self.max_num_splits,
|
||||
@ -474,8 +489,10 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
)
|
||||
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
key_cache = key_cache.view(torch.float8_e4m3fn)
|
||||
value_cache = value_cache.view(torch.float8_e4m3fn)
|
||||
dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
|
||||
self.kv_cache_dtype)
|
||||
key_cache = key_cache.view(dtype)
|
||||
value_cache = value_cache.view(dtype)
|
||||
num_tokens, num_heads, head_size = query.shape
|
||||
query, _ = ops.scaled_fp8_quant(
|
||||
query.reshape(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user