Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-14 14:55:02 +08:00)
[Attention] Fix FlashMLA metadata builder arguments for q_len > 1 (#27368)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
Parent: ca76486a16
Commit: dbfbf9f324
@@ -120,9 +120,13 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
         num_decode_tokens: int,
         dcp_tot_seq_lens_device: torch.Tensor | None,
     ) -> FlashMLADecodeMetadata:
+        query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
+        # we use the max but all should be the same due to uniform length requirement
+        max_query_len = query_lens_cpu.max().item()
+        num_q_tokens_per_head_k = max_query_len * self.num_q_heads // 1
         tile_scheduler_metadata, num_splits = get_mla_metadata(
             seq_lens_device,
-            self.num_q_heads,
+            num_q_tokens_per_head_k,
             1, # MQA for the decode path
             is_fp8_kvcache=self.is_fp8_kvcache,
         )
Loading…
x
Reference in New Issue
Block a user