mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 03:54:56 +08:00
[BugFix] Fix async-scheduling + FlashAttn MLA (#28990)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
parent
1ffe934c8a
commit
48fc8b1e59
@ -755,6 +755,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
seq_lens = common_attn_metadata.seq_lens
|
||||
seq_lens_cpu = common_attn_metadata.seq_lens_cpu
|
||||
dcp_local_seq_lens = common_attn_metadata.dcp_local_seq_lens
|
||||
dcp_local_seq_lens_cpu = common_attn_metadata.dcp_local_seq_lens_cpu
|
||||
|
||||
query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
|
||||
|
||||
@ -944,18 +945,20 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
|
||||
decode_metadata = None
|
||||
if num_decodes > 0:
|
||||
dcp_tot_seq_lens_device = None
|
||||
if self.dcp_world_size > 1:
|
||||
dcp_tot_seq_lens_device = seq_lens[:num_decodes]
|
||||
seq_lens_cpu = dcp_local_seq_lens_cpu
|
||||
seq_lens = dcp_local_seq_lens
|
||||
|
||||
decode_metadata = self._build_decode(
|
||||
block_table_tensor=block_table_tensor[:num_decodes, ...],
|
||||
seq_lens_cpu=seq_lens_cpu[:num_decodes],
|
||||
seq_lens_device=dcp_local_seq_lens[:num_decodes]
|
||||
if self.dcp_world_size > 1 and dcp_local_seq_lens is not None
|
||||
else seq_lens[:num_decodes],
|
||||
seq_lens_device=seq_lens[:num_decodes],
|
||||
query_start_loc_cpu=query_start_loc_cpu[: num_decodes + 1],
|
||||
query_start_loc_device=query_start_loc[: num_decodes + 1],
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
dcp_tot_seq_lens_device=seq_lens[:num_decodes]
|
||||
if self.dcp_world_size > 1
|
||||
else None,
|
||||
dcp_tot_seq_lens_device=dcp_tot_seq_lens_device,
|
||||
)
|
||||
|
||||
attn_metadata = self.metadata_cls(
|
||||
|
||||
@ -173,7 +173,7 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
|
||||
) -> FlashAttnMLADecodeMetadata:
|
||||
query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
|
||||
max_query_len = query_lens_cpu.max().item()
|
||||
max_seq_len = seq_lens_device.max().item()
|
||||
max_seq_len = seq_lens_cpu.max().item()
|
||||
|
||||
# For Flash Attention MLA + full cudagraph
|
||||
max_num_splits = 0
|
||||
|
||||
@ -92,6 +92,7 @@ class CommonAttentionMetadata:
|
||||
encoder_seq_lens: np.ndarray | None = None
|
||||
|
||||
dcp_local_seq_lens: torch.Tensor | None = None
|
||||
dcp_local_seq_lens_cpu: torch.Tensor | None = None
|
||||
"""Sequence lengths of the local rank in decode context parallelism world"""
|
||||
|
||||
|
||||
|
||||
@ -1451,9 +1451,12 @@ class GPUModelRunner(
|
||||
num_computed_tokens_cpu = self.input_batch.num_computed_tokens_cpu_tensor[
|
||||
:num_reqs
|
||||
]
|
||||
dcp_local_seq_lens = (
|
||||
self.dcp_local_seq_lens.gpu[:num_reqs] if self.dcp_world_size > 1 else None
|
||||
)
|
||||
|
||||
dcp_local_seq_lens, dcp_local_seq_lens_cpu = None, None
|
||||
if self.dcp_world_size > 1:
|
||||
dcp_local_seq_lens = self.dcp_local_seq_lens.gpu[:num_reqs]
|
||||
dcp_local_seq_lens_cpu = self.dcp_local_seq_lens.cpu[:num_reqs]
|
||||
|
||||
spec_decode_common_attn_metadata = None
|
||||
|
||||
if for_cudagraph_capture:
|
||||
@ -1521,6 +1524,7 @@ class GPUModelRunner(
|
||||
causal=True,
|
||||
encoder_seq_lens=encoder_seq_lens,
|
||||
dcp_local_seq_lens=dcp_local_seq_lens,
|
||||
dcp_local_seq_lens_cpu=dcp_local_seq_lens_cpu,
|
||||
)
|
||||
|
||||
if self.speculative_config and spec_decode_common_attn_metadata is None:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user