[Fix] Fix MLA attention crash when using DP with DCP

Signed-off-by: Sachin Singh <sachinsingh@digitalocean.com>
Author: Sachin Singh, 2025-12-22 18:33:43 +05:30 (committed by Sachin Singh)
parent 6b16fff01b
commit ac80c90235

@@ -989,6 +989,18 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
        dcp_tot_seq_lens_device = None
        if self.dcp_world_size > 1:
            dcp_tot_seq_lens_device = seq_lens[:num_decodes]
            if dcp_local_seq_lens_cpu is None:
                dcp_local_seq_lens_cpu = get_dcp_local_seq_lens(
                    seq_lens_cpu,
                    self.dcp_world_size,
                    self.dcp_rank,
                    self.dcp_local_block_size,
                )
            if dcp_local_seq_lens is None:
                dcp_local_seq_lens = dcp_local_seq_lens_cpu.to(
                    seq_lens.device, non_blocking=True
                )
            seq_lens_cpu = dcp_local_seq_lens_cpu
            seq_lens = dcp_local_seq_lens
        # After DCP distribution, the maximum number of tokens for any rank is
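The patch adds `is None` fallbacks so the metadata builder can compute the DCP-local sequence lengths itself when they are not supplied, instead of crashing. To illustrate what such a helper has to produce, the sketch below derives per-rank lengths from the total sequence lengths under an assumed block-cyclic token layout with block size `dcp_local_block_size`. The function name, signature, and distribution scheme here are assumptions for illustration only, not vLLM's actual `get_dcp_local_seq_lens` implementation.

```python
import torch

def dcp_local_seq_lens_sketch(
    seq_lens_cpu: torch.Tensor,   # per-request total sequence lengths (int tensor)
    dcp_world_size: int,          # number of DCP ranks
    dcp_rank: int,                # this rank's index in [0, dcp_world_size)
    dcp_local_block_size: int,    # tokens handed to one rank per block
) -> torch.Tensor:
    # ASSUMPTION: tokens are laid out block-cyclically, i.e. block i of a
    # sequence (dcp_local_block_size tokens) lands on rank i % dcp_world_size.
    cycle = dcp_world_size * dcp_local_block_size
    full_cycles = seq_lens_cpu // cycle      # complete rounds over all ranks
    remainder = seq_lens_cpu % cycle         # leftover tokens in the last round
    # Tokens this rank receives from the partial round, clamped to one block.
    local_tail = (remainder - dcp_rank * dcp_local_block_size).clamp(
        min=0, max=dcp_local_block_size
    )
    return full_cycles * dcp_local_block_size + local_tail


# Example: 2 DCP ranks, block size 4, requests of length 10 and 3.
lens = torch.tensor([10, 3])
print(dcp_local_seq_lens_sketch(lens, 2, 0, 4))  # rank 0 -> tensor([6, 3])
print(dcp_local_seq_lens_sketch(lens, 2, 1, 4))  # rank 1 -> tensor([4, 0])
```

In the patched hunk this computation is only a fallback: when `dcp_local_seq_lens_cpu` / `dcp_local_seq_lens` are already provided, the builder uses them as-is and the added `is None` checks skip the recomputation.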