diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index 4c2a6c6b985b..3f9afa67aef7 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -99,6 +99,13 @@ class FlashAttentionBackend(AttentionBackend):
             raise ValueError(f"Unknown cache layout format {cache_layout}.")
         return stride_order
 
+    @staticmethod
+    def get_fp8_dtype_for_flashattn(kv_cache_dtype: str) -> torch.dtype:
+        if kv_cache_dtype in ("fp8", "fp8_e4m3"):
+            return torch.float8_e4m3fn
+        else:
+            raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")
+
 
 @dataclass
 class FlashAttentionMetadata:
@@ -161,6 +168,7 @@ class FlashAttentionMetadataBuilder(
             self.parallel_config)
         self.num_heads_kv = self.model_config.get_num_kv_heads(
             self.parallel_config)
+        self.kv_cache_dtype = kv_cache_spec.dtype
         self.headdim = self.model_config.get_head_size()
         self.block_size = kv_cache_spec.block_size
 
@@ -239,17 +247,24 @@ class FlashAttentionMetadataBuilder(
 
         def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                      max_seq_len, causal):
+            cache_dtype = self.cache_config.cache_dtype
+            if cache_dtype.startswith("fp8"):
+                qkv_dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
+                    cache_dtype)
+            else:
+                qkv_dtype = self.kv_cache_dtype
             if aot_schedule:
                 return get_scheduler_metadata(
                     batch_size=batch_size,
                     max_seqlen_q=max_query_len,
                     max_seqlen_k=max_seq_len,
-                    cache_seqlens=seqlens,
                     num_heads_q=self.num_heads_q,
                     num_heads_kv=self.num_heads_kv,
                     headdim=self.headdim,
-                    page_size=self.block_size,
+                    cache_seqlens=seqlens,
+                    qkv_dtype=qkv_dtype,
                     cu_seqlens_q=cu_query_lens,
+                    page_size=self.block_size,
                     causal=causal,
                     window_size=self.aot_sliding_window,
                     num_splits=self.max_num_splits,
@@ -474,8 +489,10 @@ class FlashAttentionImpl(AttentionImpl):
             )
 
             if self.kv_cache_dtype.startswith("fp8"):
-                key_cache = key_cache.view(torch.float8_e4m3fn)
-                value_cache = value_cache.view(torch.float8_e4m3fn)
+                dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
+                    self.kv_cache_dtype)
+                key_cache = key_cache.view(dtype)
+                value_cache = value_cache.view(dtype)
                 num_tokens, num_heads, head_size = query.shape
                 query, _ = ops.scaled_fp8_quant(
                     query.reshape(
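
The sketch below illustrates, in isolation, the dtype-selection logic the patch adds: FP8 KV caches are stored under a string config ("fp8" / "fp8_e4m3") rather than a torch dtype, so the scheduler-metadata path resolves that string to the real element type before passing it to FlashAttention. This is a minimal standalone sketch, not part of the diff; `pick_qkv_dtype` is a hypothetical helper name, and the uint8 storage assumption is inferred from the `.view(dtype)` reinterpretation in the last hunk.

    import torch

    def get_fp8_dtype_for_flashattn(kv_cache_dtype: str) -> torch.dtype:
        # Mirrors the new FlashAttentionBackend helper: vLLM's "fp8" cache
        # dtype is an alias for "fp8_e4m3"; both map to torch.float8_e4m3fn.
        if kv_cache_dtype in ("fp8", "fp8_e4m3"):
            return torch.float8_e4m3fn
        raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")

    def pick_qkv_dtype(cache_dtype: str,
                       kv_cache_torch_dtype: torch.dtype) -> torch.dtype:
        # Hypothetical helper standing in for the logic inside schedule():
        # when the cache is FP8, the string config (not the storage tensor's
        # dtype) determines what element type the AOT scheduler should plan
        # for; otherwise the cache tensor's own dtype is used directly.
        if cache_dtype.startswith("fp8"):
            return get_fp8_dtype_for_flashattn(cache_dtype)
        return kv_cache_torch_dtype

    # FP8 caches are reinterpreted via .view() at use time, as in the last
    # hunk of the diff; non-FP8 caches pass through unchanged.
    assert pick_qkv_dtype("fp8", torch.uint8) is torch.float8_e4m3fn
    assert pick_qkv_dtype("auto", torch.bfloat16) is torch.bfloat16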