From db56a59970a84842da2adc3aa64e436f42448b48 Mon Sep 17 00:00:00 2001
From: Lucas Wilkinson
Date: Fri, 14 Nov 2025 07:19:22 -0500
Subject: [PATCH] [BugFix] Fix FA3 IMA with FULL_AND_PIECEWISE and cascade attention (default) (#28702)

---
 tests/kernels/attention/test_cascade_flash_attn.py | 1 +
 vllm/v1/attention/backends/flash_attn.py           | 6 ++++--
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/tests/kernels/attention/test_cascade_flash_attn.py b/tests/kernels/attention/test_cascade_flash_attn.py
index 4295f852f95bb..20f573821b25f 100755
--- a/tests/kernels/attention/test_cascade_flash_attn.py
+++ b/tests/kernels/attention/test_cascade_flash_attn.py
@@ -170,6 +170,7 @@ def test_cascade(
         logits_soft_cap=soft_cap if soft_cap is not None else 0,
         block_table=block_tables,
         common_prefix_len=common_prefix_len,
+        max_num_splits=0,  # no max
         fa_version=fa_version,
     )
 
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index 81623549ae850..a5d4435000d4d 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -704,6 +704,7 @@ class FlashAttentionImpl(AttentionImpl):
                 logits_soft_cap=self.logits_soft_cap,
                 block_table=attn_metadata.block_table,
                 common_prefix_len=attn_metadata.common_prefix_len,
+                max_num_splits=attn_metadata.max_num_splits,
                 fa_version=self.vllm_flash_attn_version,
                 prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
                 suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
@@ -950,6 +951,7 @@ def cascade_attention(
     logits_soft_cap: float,
     block_table: torch.Tensor,
     common_prefix_len: int,
+    max_num_splits: int,
     fa_version: int,
     prefix_scheduler_metadata: torch.Tensor | None = None,
     suffix_scheduler_metadata: torch.Tensor | None = None,
@@ -994,7 +996,7 @@ def cascade_attention(
         # s_aux is incorporated into prefix_lse inside the GPU kernel,
         # enabling its effect during the final attention merge.
         s_aux=s_aux,
-        num_splits=1 if vllm_is_batch_invariant() else 0,
+        num_splits=1 if vllm_is_batch_invariant() else max_num_splits,
     )
 
     descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2])
@@ -1019,7 +1021,7 @@ def cascade_attention(
         q_descale=q_descale.expand(descale_shape) if q_descale is not None else None,
         k_descale=k_descale.expand(descale_shape) if k_descale is not None else None,
         v_descale=v_descale.expand(descale_shape) if v_descale is not None else None,
-        num_splits=1 if vllm_is_batch_invariant() else 0,
+        num_splits=1 if vllm_is_batch_invariant() else max_num_splits,
     )
 
     # Merge prefix and suffix outputs, and store the result in output.
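
Note (illustration only, not part of the patch): the standalone Python sketch below shows the split-selection logic the diff changes. Before the fix, cascade_attention hardcoded num_splits=0, letting the FA3 kernel choose the number of splits on its own; the fix forwards the caller's max_num_splits cap instead, while batch-invariant mode stays pinned to a single split. The helper name pick_num_splits is hypothetical; only the names batch_invariant-style flag and max_num_splits mirror the diff, and no vLLM imports are used.

# Hypothetical helper (not vLLM code) sketching the value that cascade_attention
# now forwards as `num_splits` to the FA3 varlen kernel after this patch.
def pick_num_splits(batch_invariant: bool, max_num_splits: int) -> int:
    if batch_invariant:
        return 1           # force a single split for deterministic results
    return max_num_splits  # 0 = "no cap, kernel decides"; >0 = upper bound


if __name__ == "__main__":
    assert pick_num_splits(True, 8) == 1   # batch-invariant mode: always one split
    assert pick_num_splits(False, 8) == 8  # new behavior: respect the caller's cap
    assert pick_num_splits(False, 0) == 0  # "no max", as in the updated test

The test change passes max_num_splits=0 explicitly so the kernel keeps its old unconstrained behavior there, while the backend now threads attn_metadata.max_num_splits through to both the prefix and suffix attention calls.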