[BugFix] bugfix for Flash Attention MLA with full cuda graph IMA following pr-25490 (#27128)

Signed-off-by: qqma <qqma@amazon.com>
Co-authored-by: qqma <qqma@amazon.com>

parent: 8669c69afa
commit: 5beacce2ea
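Summary of the change, as read from the hunks below: previously `_schedule_decode` always planned the FA3 scheduler metadata with `self.max_num_splits`, while the `max_num_splits` value recorded for the actual decode dispatch was computed only afterwards and could differ (0 when the batch exceeds the capture size, 1 under batch-invariant mode). Under full CUDA graphs that mismatch is consistent with the illegal memory access (IMA) named in the title. The fix computes `max_num_splits` before scheduling and threads it through `_schedule_decode` as an explicit parameter; it also hoists `self.max_cudagraph_size` out of the AOT-schedule branch so the new gate can always read it.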
@@ -89,10 +89,9 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
         self.use_full_cuda_graph = (
             self.compilation_config.cudagraph_mode.has_full_cudagraphs()
         )
+        self.max_cudagraph_size = self.compilation_config.max_capture_size
 
         if self.use_full_cuda_graph and self.fa_aot_schedule:
-            self.max_cudagraph_size = self.compilation_config.max_capture_size
-
             if self.max_cudagraph_size > 992:
                 # This condition derives from FA3's internal heuristic.
                 # TODO(woosuk): Support larger cudagraph sizes.
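A plausible reason for the hoist (an inference from the hunks, not stated in the commit message): the new gate in build() reads self.max_cudagraph_size whenever full CUDA graphs are enabled, even when fa_aot_schedule is False, but before the fix the attribute was only assigned inside the combined branch. A minimal sketch of that failure mode, with illustrative stand-in names and values rather than vLLM's real config plumbing:

# Minimal sketch of the pre-fix failure mode; names and values are
# illustrative stand-ins, not vLLM's real config plumbing.
class BuilderSketch:
    def __init__(self, use_full_cuda_graph: bool, fa_aot_schedule: bool):
        self.use_full_cuda_graph = use_full_cuda_graph
        self.fa_aot_schedule = fa_aot_schedule
        self.max_cudagraph_size = 992  # post-fix: assigned unconditionally
        if self.use_full_cuda_graph and self.fa_aot_schedule:
            pass  # pre-fix: the assignment above lived only in this branch

    def fixed_splits_wanted(self, num_decode_tokens: int) -> bool:
        # build() evaluates this gate whenever full cuda graphs are on; with
        # the pre-fix placement it would raise AttributeError for
        # use_full_cuda_graph=True, fa_aot_schedule=False.
        return self.use_full_cuda_graph and num_decode_tokens <= self.max_cudagraph_size

print(BuilderSketch(True, False).fixed_splits_wanted(256))  # True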
@@ -114,7 +113,14 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
             self.max_num_splits = 1
 
     def _schedule_decode(
-        self, num_reqs, cu_query_lens, max_query_len, seqlens, max_seq_len, causal
+        self,
+        num_reqs,
+        cu_query_lens,
+        max_query_len,
+        seqlens,
+        max_seq_len,
+        causal,
+        max_num_splits,
     ):
         if self.fa_aot_schedule:
             return get_scheduler_metadata(
@@ -130,7 +136,7 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
                 page_size=self.page_size,
                 cu_seqlens_q=cu_query_lens,
                 causal=causal,
-                num_splits=self.max_num_splits,
+                num_splits=max_num_splits,
             )
         return None
 
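The two hunks above carry the core of the fix: `_schedule_decode` no longer reads `self.max_num_splits` when calling `get_scheduler_metadata`; it forwards whatever the caller passes. A condensed, self-contained sketch of the before/after pattern, with a hypothetical plan() standing in for get_scheduler_metadata:

# plan() is a hypothetical stand-in for get_scheduler_metadata.
def plan(num_splits: int) -> dict:
    return {"num_splits": num_splits}

class Before:
    max_num_splits = 64  # capped at init time
    def schedule(self):
        # Planning always used the capped value, even when the decode
        # dispatch later recorded a different one (e.g. 0).
        return plan(num_splits=self.max_num_splits)

class After:
    max_num_splits = 64
    def schedule(self, max_num_splits: int):
        # Planning sees exactly the value the caller computed for dispatch.
        return plan(num_splits=max_num_splits)

print(Before().schedule())  # {'num_splits': 64}
print(After().schedule(0))  # {'num_splits': 0}, matching dispatch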
@@ -148,6 +154,15 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
         max_query_len = query_lens_cpu.max().item()
         max_seq_len = seq_lens_device.max().item()
 
+        # For Flash Attention MLA + full cudagraph
+        max_num_splits = 0
+        if self.use_full_cuda_graph and num_decode_tokens <= self.max_cudagraph_size:
+            # NOTE(woosuk): Setting num_splits > 1 may increase the memory
+            # usage, because the intermediate buffers of size [num_splits,
+            # num_heads, num_tokens, head_size] are allocated. Therefore,
+            # we only set num_splits when using cuda graphs.
+            max_num_splits = self.max_num_splits
+
         scheduler_metadata = self._schedule_decode(
             num_reqs=seq_lens_cpu.numel(),
             cu_query_lens=query_start_loc_device,
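The NOTE in the added block can be quantified. If the split-K intermediates really are of shape [num_splits, num_heads, num_tokens, head_size], memory grows linearly with num_splits; back-of-envelope arithmetic with illustrative (not measured) sizes:

# Rough cost of split-K partial-result buffers (illustrative sizes only).
num_splits, num_heads, num_tokens, head_size = 16, 16, 992, 512
elems = num_splits * num_heads * num_tokens * head_size
print(f"{elems * 4 / 2**20:.0f} MiB at fp32")  # 496 MiB, hence the gating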
@@ -155,10 +170,9 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
             seqlens=seq_lens_device,
             max_seq_len=max_seq_len,
             causal=True,
+            max_num_splits=max_num_splits,
         )
 
-        # For FA3 + full cudagraph
-        max_num_splits = 0
         if self.use_full_cuda_graph and scheduler_metadata is not None:
             n = scheduler_metadata.shape[0]
             # Ensure the persistent buffer is large enough
@@ -174,13 +188,6 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
                 self.scheduler_metadata[n:] = 0
                 scheduler_metadata = self.scheduler_metadata[:n]
 
-            if num_decode_tokens <= self.max_cudagraph_size:
-                # NOTE(woosuk): Setting num_splits > 1 may increase the memory
-                # usage, because the intermediate buffers of size [num_splits,
-                # num_heads, num_tokens, head_size] are allocated. Therefore,
-                # we only set num_splits when using cuda graphs.
-                max_num_splits = self.max_num_splits
-
         if vllm_is_batch_invariant():
             max_num_splits = 1
 
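Putting the hunks together, the value passed to `_schedule_decode` is now fully determined up front: 0 when full cuda graphs are off or the decode batch exceeds the capture size, `self.max_num_splits` otherwise (the batch-invariant override to 1 still runs afterwards, unchanged). A compact, runnable restatement of that selection, assuming these inputs; this is a sketch, not a vLLM API:

# Sketch of the split-count selection; not a vLLM API.
def select_max_num_splits(use_full_cuda_graph: bool, num_decode_tokens: int,
                          max_cudagraph_size: int, capped_max: int) -> int:
    # 0 lets FlashAttention choose; a fixed count is only pinned when the
    # decode batch is small enough to be captured in a full cuda graph.
    if use_full_cuda_graph and num_decode_tokens <= max_cudagraph_size:
        return capped_max
    return 0

print(select_max_num_splits(True, 512, 992, 64))   # 64: within capture range
print(select_max_num_splits(True, 2048, 992, 64))  # 0: beyond capture range
print(select_max_num_splits(False, 512, 992, 64))  # 0: no full cuda graph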