Merge 1061b876bc0393e825fe22c68e5ae9247e4e1a86 into 254f6b986720c92ddf97fbb1a6a6465da8e87e29

This commit is contained in:
Sachin Kumar Singh 2025-12-25 00:06:37 +00:00 committed by GitHub
commit a0b5c070d1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -776,7 +776,11 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
query_start_loc = common_attn_metadata.query_start_loc
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
seq_lens = common_attn_metadata.seq_lens
seq_lens_cpu = seq_lens.cpu()
dcp_local_seq_lens = common_attn_metadata.dcp_local_seq_lens
dcp_local_seq_lens_cpu = (
dcp_local_seq_lens.cpu() if dcp_local_seq_lens is not None else None
)
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(
@ -989,6 +993,19 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
dcp_tot_seq_lens_device = None
if self.dcp_world_size > 1:
dcp_tot_seq_lens_device = seq_lens[:num_decodes]
if dcp_local_seq_lens_cpu is None:
dcp_local_seq_lens_cpu = get_dcp_local_seq_lens(
seq_lens_cpu,
self.dcp_world_size,
self.dcp_rank,
self.dcp_local_block_size,
)
if dcp_local_seq_lens is None:
assert dcp_local_seq_lens_cpu is not None
dcp_local_seq_lens = dcp_local_seq_lens_cpu.to(
seq_lens.device, non_blocking=True
)
seq_lens_cpu = dcp_local_seq_lens_cpu
seq_lens = dcp_local_seq_lens
# After DCP distribution, the maximum number of tokens for any rank is