[BugFix] Fix async-scheduling + FlashAttn MLA (#28990)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Lucas Wilkinson 2025-11-19 10:04:07 -05:00, committed by GitHub
parent 1ffe934c8a
commit 48fc8b1e59
4 changed files with 18 additions and 10 deletions

vllm/v1/attention/backends/mla/common.py

@@ -755,6 +755,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
         seq_lens = common_attn_metadata.seq_lens
         seq_lens_cpu = common_attn_metadata.seq_lens_cpu
         dcp_local_seq_lens = common_attn_metadata.dcp_local_seq_lens
+        dcp_local_seq_lens_cpu = common_attn_metadata.dcp_local_seq_lens_cpu
         query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
@@ -944,18 +945,20 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
         decode_metadata = None
         if num_decodes > 0:
+            dcp_tot_seq_lens_device = None
+            if self.dcp_world_size > 1:
+                dcp_tot_seq_lens_device = seq_lens[:num_decodes]
+                seq_lens_cpu = dcp_local_seq_lens_cpu
+                seq_lens = dcp_local_seq_lens
             decode_metadata = self._build_decode(
                 block_table_tensor=block_table_tensor[:num_decodes, ...],
                 seq_lens_cpu=seq_lens_cpu[:num_decodes],
-                seq_lens_device=dcp_local_seq_lens[:num_decodes]
-                if self.dcp_world_size > 1 and dcp_local_seq_lens is not None
-                else seq_lens[:num_decodes],
+                seq_lens_device=seq_lens[:num_decodes],
                 query_start_loc_cpu=query_start_loc_cpu[: num_decodes + 1],
                 query_start_loc_device=query_start_loc[: num_decodes + 1],
                 num_decode_tokens=num_decode_tokens,
-                dcp_tot_seq_lens_device=seq_lens[:num_decodes]
-                if self.dcp_world_size > 1
-                else None,
+                dcp_tot_seq_lens_device=dcp_tot_seq_lens_device,
             )
         attn_metadata = self.metadata_cls(
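
The rewritten decode path hoists the DCP conditionals out of the call's keyword arguments: the global lengths are captured as `dcp_tot_seq_lens_device` first, then the CPU and device views are swapped to the DCP-local tensors together, so `_build_decode` can no longer receive a mismatched pair (global CPU lengths next to local device lengths). A minimal standalone sketch of that pattern, with a hypothetical helper name and not the actual vLLM code:

```python
import torch

# Hypothetical rework of the hunk above: under decode context parallelism
# (DCP), swap the CPU and device views of the sequence lengths *together*,
# after capturing the global lengths for dcp_tot_seq_lens_device.
def select_decode_seq_lens(
    seq_lens: torch.Tensor,                       # global lengths (device)
    seq_lens_cpu: torch.Tensor,                   # global lengths (CPU twin)
    dcp_local_seq_lens: torch.Tensor | None,      # local lengths (device)
    dcp_local_seq_lens_cpu: torch.Tensor | None,  # local lengths (CPU twin)
    dcp_world_size: int,
    num_decodes: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
    dcp_tot_seq_lens_device = None
    if dcp_world_size > 1:
        # Capture the global lengths before the names are shadowed.
        dcp_tot_seq_lens_device = seq_lens[:num_decodes]
        seq_lens_cpu = dcp_local_seq_lens_cpu
        seq_lens = dcp_local_seq_lens
    # The CPU and device slices now always describe the same quantity.
    return seq_lens[:num_decodes], seq_lens_cpu[:num_decodes], dcp_tot_seq_lens_device
```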

vllm/v1/attention/backends/mla/flashattn_mla.py

@@ -173,7 +173,7 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
     ) -> FlashAttnMLADecodeMetadata:
         query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
         max_query_len = query_lens_cpu.max().item()
-        max_seq_len = seq_lens_device.max().item()
+        max_seq_len = seq_lens_cpu.max().item()
         # For Flash Attention MLA + full cudagraph
         max_num_splits = 0
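
This one-line change is the heart of the fix: `max_seq_len` is now reduced on the CPU twin instead of the device tensor. Calling `.item()` on a CUDA tensor blocks until the GPU stream catches up, and under async scheduling the device-side lengths may not even be populated yet when the metadata builder runs. A small illustration of the difference, assuming a CUDA device is available (hypothetical values, not vLLM code):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# The CPU twin is written by the host and is immediately readable.
seq_lens_cpu = torch.tensor([17, 33, 8], dtype=torch.int32)
# The device copy is enqueued asynchronously; the GPU fills it in "later".
seq_lens_device = seq_lens_cpu.to(device, non_blocking=True)

max_seq_len = seq_lens_cpu.max().item()        # host-only: 33, no sync
# max_seq_len = seq_lens_device.max().item()   # forces a stream sync first
```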

vllm/v1/attention/backends/utils.py

@@ -92,6 +92,7 @@ class CommonAttentionMetadata:
     encoder_seq_lens: np.ndarray | None = None
     dcp_local_seq_lens: torch.Tensor | None = None
+    dcp_local_seq_lens_cpu: torch.Tensor | None = None
     """Sequence lengths of the local rank in decode context parallelism world"""

vllm/v1/worker/gpu_model_runner.py

@@ -1451,9 +1451,12 @@ class GPUModelRunner(
         num_computed_tokens_cpu = self.input_batch.num_computed_tokens_cpu_tensor[
             :num_reqs
         ]
-        dcp_local_seq_lens = (
-            self.dcp_local_seq_lens.gpu[:num_reqs] if self.dcp_world_size > 1 else None
-        )
+        dcp_local_seq_lens, dcp_local_seq_lens_cpu = None, None
+        if self.dcp_world_size > 1:
+            dcp_local_seq_lens = self.dcp_local_seq_lens.gpu[:num_reqs]
+            dcp_local_seq_lens_cpu = self.dcp_local_seq_lens.cpu[:num_reqs]
         spec_decode_common_attn_metadata = None
         if for_cudagraph_capture:
@@ -1521,6 +1524,7 @@ class GPUModelRunner(
             causal=True,
             encoder_seq_lens=encoder_seq_lens,
             dcp_local_seq_lens=dcp_local_seq_lens,
+            dcp_local_seq_lens_cpu=dcp_local_seq_lens_cpu,
         )
         if self.speculative_config and spec_decode_common_attn_metadata is None:
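
Here `self.dcp_local_seq_lens` is a paired host/device buffer exposing `.cpu` and `.gpu` views, and the runner now threads both slices into the metadata so downstream builders can stay host-side. A simplified stand-in for that buffer pattern (vLLM's own helper differs in detail):

```python
import torch

class PairedBuffer:
    """Simplified host/device buffer pair; not vLLM's actual helper."""

    def __init__(self, size: int, dtype: torch.dtype = torch.int32) -> None:
        use_cuda = torch.cuda.is_available()
        # Pinned host memory makes the async H2D copy truly non-blocking.
        self.cpu = torch.zeros(size, dtype=dtype, pin_memory=use_cuda)
        self.gpu = self.cpu.to("cuda") if use_cuda else self.cpu.clone()

    def copy_to_gpu(self, n: int) -> None:
        # The host view is written first; the device view lags behind
        # until this enqueued copy completes on the stream.
        self.gpu[:n].copy_(self.cpu[:n], non_blocking=True)
```

Handing both `buf.gpu[:num_reqs]` and `buf.cpu[:num_reqs]` to `CommonAttentionMetadata` lets a builder compute values like `max_seq_len` from the CPU view without waiting on, or racing with, that copy.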