[BugFix] Fix async-scheduling + FlashAttn MLA (#28990)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
parent 1ffe934c8a
commit 48fc8b1e59
@@ -755,6 +755,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
         seq_lens = common_attn_metadata.seq_lens
         seq_lens_cpu = common_attn_metadata.seq_lens_cpu
         dcp_local_seq_lens = common_attn_metadata.dcp_local_seq_lens
+        dcp_local_seq_lens_cpu = common_attn_metadata.dcp_local_seq_lens_cpu

         query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]

@@ -944,18 +945,20 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):

         decode_metadata = None
         if num_decodes > 0:
+            dcp_tot_seq_lens_device = None
+            if self.dcp_world_size > 1:
+                dcp_tot_seq_lens_device = seq_lens[:num_decodes]
+                seq_lens_cpu = dcp_local_seq_lens_cpu
+                seq_lens = dcp_local_seq_lens
+
             decode_metadata = self._build_decode(
                 block_table_tensor=block_table_tensor[:num_decodes, ...],
                 seq_lens_cpu=seq_lens_cpu[:num_decodes],
-                seq_lens_device=dcp_local_seq_lens[:num_decodes]
-                if self.dcp_world_size > 1 and dcp_local_seq_lens is not None
-                else seq_lens[:num_decodes],
+                seq_lens_device=seq_lens[:num_decodes],
                 query_start_loc_cpu=query_start_loc_cpu[: num_decodes + 1],
                 query_start_loc_device=query_start_loc[: num_decodes + 1],
                 num_decode_tokens=num_decode_tokens,
-                dcp_tot_seq_lens_device=seq_lens[:num_decodes]
-                if self.dcp_world_size > 1
-                else None,
+                dcp_tot_seq_lens_device=dcp_tot_seq_lens_device,
             )

         attn_metadata = self.metadata_cls(

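Reading of the hunk above (mine, not text from the commit): under decode context parallelism (DCP), the scheduler-provided seq_lens hold the total per-request lengths, while dcp_local_seq_lens / dcp_local_seq_lens_cpu hold this rank's local share. The change hoists the selection out of the _build_decode(...) call: the totals are saved as dcp_tot_seq_lens_device, then seq_lens / seq_lens_cpu are rebound to the local copies, so the call site needs no inline conditionals and the CPU lengths stay paired with the device lengths. A minimal sketch of the equivalent selection, with made-up tensors and a hypothetical helper name:

import torch

def select_decode_seq_lens(
    seq_lens: torch.Tensor,
    seq_lens_cpu: torch.Tensor,
    dcp_local_seq_lens: torch.Tensor | None,
    dcp_local_seq_lens_cpu: torch.Tensor | None,
    dcp_world_size: int,
    num_decodes: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
    # Sketch only: returns (seq_lens_device, seq_lens_cpu, dcp_tot_seq_lens_device)
    # for the first num_decodes requests, mirroring the post-fix control flow.
    dcp_tot_seq_lens_device = None
    if dcp_world_size > 1:
        # Keep the totals for the kernel, then switch to this rank's local lengths.
        dcp_tot_seq_lens_device = seq_lens[:num_decodes]
        seq_lens = dcp_local_seq_lens
        seq_lens_cpu = dcp_local_seq_lens_cpu
    return seq_lens[:num_decodes], seq_lens_cpu[:num_decodes], dcp_tot_seq_lens_device
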
@@ -173,7 +173,7 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
     ) -> FlashAttnMLADecodeMetadata:
         query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
         max_query_len = query_lens_cpu.max().item()
-        max_seq_len = seq_lens_device.max().item()
+        max_seq_len = seq_lens_cpu.max().item()

         # For Flash Attention MLA + full cudagraph
         max_num_splits = 0

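Why this one-line change matters for async scheduling (my summary, not part of the commit message): Tensor.item() on a CUDA tensor has to copy the scalar back to the host, which waits on the GPU stream and defeats the CPU/GPU overlap that async scheduling is built around. Taking the max from the host-side lengths keeps the metadata build purely on the CPU. A small self-contained illustration of the difference:

import torch

# Host-side lengths: max() and item() are ordinary CPU work, no GPU involvement.
seq_lens_cpu = torch.tensor([512, 1024, 256], dtype=torch.int32)
max_seq_len = seq_lens_cpu.max().item()

if torch.cuda.is_available():
    seq_lens_device = seq_lens_cpu.to("cuda", non_blocking=True)
    # Device-side path: .item() must copy the scalar back to the host and
    # therefore synchronizes with the CUDA stream, blocking further CPU work.
    assert seq_lens_device.max().item() == max_seq_len
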
@@ -92,6 +92,7 @@ class CommonAttentionMetadata:
     encoder_seq_lens: np.ndarray | None = None

     dcp_local_seq_lens: torch.Tensor | None = None
+    dcp_local_seq_lens_cpu: torch.Tensor | None = None
     """Sequence lengths of the local rank in decode context parallelism world"""

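The new field follows the existing pattern of carrying both a device and a host copy of the same lengths, so metadata builders can derive scalars (such as a maximum length) without reading back from the GPU. A stripped-down illustration of that pattern; this is not the real CommonAttentionMetadata class, just the shape of it:

from dataclasses import dataclass

import torch

@dataclass
class DcpSeqLens:
    # Per-request lengths owned by this decode-context-parallel rank, kept
    # twice: the device tensor feeds attention kernels, the CPU tensor lets
    # host-side code compute scalars without a device sync.
    dcp_local_seq_lens: torch.Tensor | None = None
    dcp_local_seq_lens_cpu: torch.Tensor | None = None

    def max_local_seq_len(self) -> int:
        # Hypothetical helper: prefer the CPU copy so no GPU sync is needed.
        assert self.dcp_local_seq_lens_cpu is not None
        return int(self.dcp_local_seq_lens_cpu.max().item())
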
@@ -1451,9 +1451,12 @@ class GPUModelRunner(
         num_computed_tokens_cpu = self.input_batch.num_computed_tokens_cpu_tensor[
             :num_reqs
         ]
-        dcp_local_seq_lens = (
-            self.dcp_local_seq_lens.gpu[:num_reqs] if self.dcp_world_size > 1 else None
-        )
+        dcp_local_seq_lens, dcp_local_seq_lens_cpu = None, None
+        if self.dcp_world_size > 1:
+            dcp_local_seq_lens = self.dcp_local_seq_lens.gpu[:num_reqs]
+            dcp_local_seq_lens_cpu = self.dcp_local_seq_lens.cpu[:num_reqs]

         spec_decode_common_attn_metadata = None

         if for_cudagraph_capture:

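Here self.dcp_local_seq_lens exposes paired .gpu and .cpu views of the same per-request buffer, and the change slices both so the CPU copy travels alongside the device copy into the attention metadata. A hypothetical sketch of such a paired buffer, to show why slicing both views is cheap; this is an assumption for illustration, not vLLM's actual implementation:

import torch

class PairedCpuGpuBuffer:
    """Hypothetical host/device pair; illustration only."""

    def __init__(self, size: int, dtype: torch.dtype = torch.int32) -> None:
        use_cuda = torch.cuda.is_available()
        # Pinned host memory makes the host-to-device copy asynchronous.
        self.cpu = torch.zeros(size, dtype=dtype, pin_memory=use_cuda)
        self.gpu = self.cpu.to("cuda") if use_cuda else self.cpu.clone()

    def copy_to_gpu(self, num_reqs: int) -> None:
        # Slicing returns views, so only the first num_reqs entries are copied.
        self.gpu[:num_reqs].copy_(self.cpu[:num_reqs], non_blocking=True)

With a buffer like this, the runner can hand buf.gpu[:num_reqs] to kernels and buf.cpu[:num_reqs] to the metadata builder without forcing a synchronization.
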
@@ -1521,6 +1524,7 @@ class GPUModelRunner(
             causal=True,
             encoder_seq_lens=encoder_seq_lens,
             dcp_local_seq_lens=dcp_local_seq_lens,
+            dcp_local_seq_lens_cpu=dcp_local_seq_lens_cpu,
         )

         if self.speculative_config and spec_decode_common_attn_metadata is None: