[BugFix] Fix FA3 IMA with FULL_AND_PIECEWISE and cascade attention (default) (#28702)

This commit is contained in:
Lucas Wilkinson 2025-11-14 07:19:22 -05:00 committed by GitHub
parent 9324e10275
commit db56a59970
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 5 additions and 2 deletions

View File

@ -170,6 +170,7 @@ def test_cascade(
logits_soft_cap=soft_cap if soft_cap is not None else 0,
block_table=block_tables,
common_prefix_len=common_prefix_len,
max_num_splits=0, # no max
fa_version=fa_version,
)

View File

@ -704,6 +704,7 @@ class FlashAttentionImpl(AttentionImpl):
logits_soft_cap=self.logits_soft_cap,
block_table=attn_metadata.block_table,
common_prefix_len=attn_metadata.common_prefix_len,
max_num_splits=attn_metadata.max_num_splits,
fa_version=self.vllm_flash_attn_version,
prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
@ -950,6 +951,7 @@ def cascade_attention(
logits_soft_cap: float,
block_table: torch.Tensor,
common_prefix_len: int,
max_num_splits: int,
fa_version: int,
prefix_scheduler_metadata: torch.Tensor | None = None,
suffix_scheduler_metadata: torch.Tensor | None = None,
@ -994,7 +996,7 @@ def cascade_attention(
# s_aux is incorporated into prefix_lse inside the GPU kernel,
# enabling its effect during the final attention merge.
s_aux=s_aux,
num_splits=1 if vllm_is_batch_invariant() else 0,
num_splits=1 if vllm_is_batch_invariant() else max_num_splits,
)
descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2])
@ -1019,7 +1021,7 @@ def cascade_attention(
q_descale=q_descale.expand(descale_shape) if q_descale is not None else None,
k_descale=k_descale.expand(descale_shape) if k_descale is not None else None,
v_descale=v_descale.expand(descale_shape) if v_descale is not None else None,
num_splits=1 if vllm_is_batch_invariant() else 0,
num_splits=1 if vllm_is_batch_invariant() else max_num_splits,
)
# Merge prefix and suffix outputs, and store the result in output.