[BugFix] fix: aot passes kvcache dtype information (#19750)
Signed-off-by: Mickael Seznec <mickael@mistral.ai>
parent 82de9b9d46
commit e1a7fe4af5
@@ -99,6 +99,13 @@ class FlashAttentionBackend(AttentionBackend):
             raise ValueError(f"Unknown cache layout format {cache_layout}.")
         return stride_order
 
+    @staticmethod
+    def get_fp8_dtype_for_flashattn(kv_cache_dtype: str) -> torch.dtype:
+        if kv_cache_dtype in ("fp8", "fp8_e4m3"):
+            return torch.float8_e4m3fn
+        else:
+            raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")
+
 
 @dataclass
 class FlashAttentionMetadata:
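
The new helper is small enough to exercise in isolation. The sketch below is a standalone copy of the same logic (not an import from vLLM) and assumes a torch build that exposes torch.float8_e4m3fn; it only illustrates the string-to-dtype mapping and the error path.

# Standalone copy for illustration: map a kv-cache dtype string to the
# torch dtype that FlashAttention should interpret the cache as.
import torch

def get_fp8_dtype_for_flashattn(kv_cache_dtype: str) -> torch.dtype:
    if kv_cache_dtype in ("fp8", "fp8_e4m3"):
        return torch.float8_e4m3fn
    else:
        raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")

print(get_fp8_dtype_for_flashattn("fp8"))       # torch.float8_e4m3fn
print(get_fp8_dtype_for_flashattn("fp8_e4m3"))  # torch.float8_e4m3fn
# get_fp8_dtype_for_flashattn("fp8_e5m2") raises ValueError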
@@ -161,6 +168,7 @@ class FlashAttentionMetadataBuilder(
             self.parallel_config)
         self.num_heads_kv = self.model_config.get_num_kv_heads(
             self.parallel_config)
+        self.kv_cache_dtype = kv_cache_spec.dtype
         self.headdim = self.model_config.get_head_size()
         self.block_size = kv_cache_spec.block_size
 
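
Caching kv_cache_spec.dtype on the builder matters because the two dtype sources differ in kind: cache_config.cache_dtype is a string such as "auto" or "fp8", while kv_cache_spec.dtype is already a torch.dtype. A minimal sketch of the selection this enables (the function name select_qkv_dtype is hypothetical, introduced only for illustration):

# Hedged sketch of the dtype selection used by the next hunk.
# cache_dtype is the string from CacheConfig; kv_cache_dtype is the torch
# dtype stored from the KV-cache spec. Requires a torch build with float8.
import torch

def select_qkv_dtype(cache_dtype: str, kv_cache_dtype: torch.dtype) -> torch.dtype:
    if cache_dtype.startswith("fp8"):
        return torch.float8_e4m3fn  # via get_fp8_dtype_for_flashattn in the diff
    return kv_cache_dtype           # non-fp8 caches keep their native dtype

print(select_qkv_dtype("fp8", torch.uint8))      # torch.float8_e4m3fn
print(select_qkv_dtype("auto", torch.bfloat16))  # torch.bfloat16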
@@ -239,17 +247,24 @@ class FlashAttentionMetadataBuilder(
 
         def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                      max_seq_len, causal):
+            cache_dtype = self.cache_config.cache_dtype
+            if cache_dtype.startswith("fp8"):
+                qkv_dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
+                    cache_dtype)
+            else:
+                qkv_dtype = self.kv_cache_dtype
             if aot_schedule:
                 return get_scheduler_metadata(
                     batch_size=batch_size,
                     max_seqlen_q=max_query_len,
                     max_seqlen_k=max_seq_len,
-                    cache_seqlens=seqlens,
                     num_heads_q=self.num_heads_q,
                     num_heads_kv=self.num_heads_kv,
                     headdim=self.headdim,
-                    page_size=self.block_size,
+                    cache_seqlens=seqlens,
+                    qkv_dtype=qkv_dtype,
                     cu_seqlens_q=cu_query_lens,
+                    page_size=self.block_size,
                     causal=causal,
                     window_size=self.aot_sliding_window,
                     num_splits=self.max_num_splits,
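
get_scheduler_metadata is the ahead-of-time scheduler entry point provided by the flash-attn package; before this change the call never received the KV-cache dtype, so the AOT schedule could disagree with an FP8 cache. The sketch below mirrors the kwargs from the diff but stubs out the real function so it stays self-contained; build_aot_schedule and the stub are illustrative names, not vLLM APIs.

# Hedged sketch of the scheduler call after the fix.
import torch

def get_scheduler_metadata(**kwargs):  # stand-in for the real FA3 entry point
    return kwargs

def build_aot_schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                       max_seq_len, qkv_dtype):
    return get_scheduler_metadata(
        batch_size=batch_size,
        max_seqlen_q=max_query_len,
        max_seqlen_k=max_seq_len,
        cache_seqlens=seqlens,
        qkv_dtype=qkv_dtype,  # newly forwarded KV-cache dtype
        cu_seqlens_q=cu_query_lens,
        causal=True,
    )

meta = build_aot_schedule(
    batch_size=2,
    cu_query_lens=torch.tensor([0, 3, 7], dtype=torch.int32),
    max_query_len=4,
    seqlens=torch.tensor([16, 32], dtype=torch.int32),
    max_seq_len=32,
    qkv_dtype=torch.float8_e4m3fn,
)
print(meta["qkv_dtype"])  # torch.float8_e4m3fn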
@@ -474,8 +489,10 @@ class FlashAttentionImpl(AttentionImpl):
             )
 
         if self.kv_cache_dtype.startswith("fp8"):
-            key_cache = key_cache.view(torch.float8_e4m3fn)
-            value_cache = value_cache.view(torch.float8_e4m3fn)
+            dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
+                self.kv_cache_dtype)
+            key_cache = key_cache.view(dtype)
+            value_cache = value_cache.view(dtype)
             num_tokens, num_heads, head_size = query.shape
             query, _ = ops.scaled_fp8_quant(
                 query.reshape(
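
On the forward path the hardcoded torch.float8_e4m3fn is replaced by whatever the backend helper returns, and the cache reinterpretation itself is a zero-copy Tensor.view(dtype). A small sketch of that reinterpretation, assuming a KV cache allocated as raw uint8 bytes and a torch build with float8 support (the shape is hypothetical):

# Hedged sketch: reinterpret an FP8 KV-cache buffer (allocated as uint8
# bytes) as float8_e4m3fn without copying, as the forward pass now does.
import torch

key_cache = torch.zeros(2, 16, 4, 64, dtype=torch.uint8)  # hypothetical cache shape
dtype = torch.float8_e4m3fn                                # what the helper returns for "fp8"

key_cache_fp8 = key_cache.view(dtype)  # same storage; both dtypes are 1 byte wide
print(key_cache_fp8.dtype, key_cache_fp8.shape)
assert key_cache_fp8.data_ptr() == key_cache.data_ptr()   # no copy was made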