From dbfbf9f32445ffadcbb731adef00b4a393612be9 Mon Sep 17 00:00:00 2001
From: Matthew Bonanni
Date: Thu, 23 Oct 2025 15:58:15 -0400
Subject: [PATCH] [Attention] Fix FlashMLA metadata builder arguments for q_len > 1 (#27368)

Signed-off-by: Matthew Bonanni
---
 vllm/v1/attention/backends/mla/flashmla.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py
index 3e481af29544..1f98204031ed 100644
--- a/vllm/v1/attention/backends/mla/flashmla.py
+++ b/vllm/v1/attention/backends/mla/flashmla.py
@@ -120,9 +120,13 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
         num_decode_tokens: int,
         dcp_tot_seq_lens_device: torch.Tensor | None,
     ) -> FlashMLADecodeMetadata:
+        query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
+        # we use the max but all should be the same due to uniform length requirement
+        max_query_len = query_lens_cpu.max().item()
+        num_q_tokens_per_head_k = max_query_len * self.num_q_heads // 1
         tile_scheduler_metadata, num_splits = get_mla_metadata(
             seq_lens_device,
-            self.num_q_heads,
+            num_q_tokens_per_head_k,
             1,  # MQA for the decode path
             is_fp8_kvcache=self.is_fp8_kvcache,
         )
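
Note (not part of the patch): a minimal, self-contained sketch of the argument computation this diff introduces. The head counts and query_start_loc values below are made up for illustration; the point is that the second argument to get_mla_metadata is q tokens per KV head (seq_len_q * num_q_heads // num_heads_k), which only equals num_q_heads when q_len == 1.

    # Standalone sketch, assuming hypothetical values; not vLLM code.
    import torch

    num_q_heads = 16   # hypothetical query head count
    num_heads_k = 1    # MQA on the decode path, as in the patched call

    # query_start_loc holds cumulative query offsets per request; three
    # requests with 2 query tokens each (q_len > 1) look like:
    query_start_loc_cpu = torch.tensor([0, 2, 4, 6])

    # Per-request query lengths are the first differences of the offsets.
    query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]

    # All decode requests are expected to share the same length; take the max.
    max_query_len = query_lens_cpu.max().item()

    # q tokens per KV head; passing only num_q_heads (the old argument)
    # undercounts whenever max_query_len > 1.
    num_q_tokens_per_head_k = max_query_len * num_q_heads // num_heads_k

    print(query_lens_cpu.tolist())    # [2, 2, 2]
    print(num_q_tokens_per_head_k)    # 32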