mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-27 02:26:32 +08:00
[BugFix] Fix FA3 IMA with FULL_AND_PIECEWISE and cascade attention (default) (#28702)
This commit is contained in:
parent
9324e10275
commit
db56a59970
@ -170,6 +170,7 @@ def test_cascade(
|
||||
logits_soft_cap=soft_cap if soft_cap is not None else 0,
|
||||
block_table=block_tables,
|
||||
common_prefix_len=common_prefix_len,
|
||||
max_num_splits=0, # no max
|
||||
fa_version=fa_version,
|
||||
)
|
||||
|
||||
|
||||
@ -704,6 +704,7 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
logits_soft_cap=self.logits_soft_cap,
|
||||
block_table=attn_metadata.block_table,
|
||||
common_prefix_len=attn_metadata.common_prefix_len,
|
||||
max_num_splits=attn_metadata.max_num_splits,
|
||||
fa_version=self.vllm_flash_attn_version,
|
||||
prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
|
||||
suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
|
||||
@ -950,6 +951,7 @@ def cascade_attention(
|
||||
logits_soft_cap: float,
|
||||
block_table: torch.Tensor,
|
||||
common_prefix_len: int,
|
||||
max_num_splits: int,
|
||||
fa_version: int,
|
||||
prefix_scheduler_metadata: torch.Tensor | None = None,
|
||||
suffix_scheduler_metadata: torch.Tensor | None = None,
|
||||
@ -994,7 +996,7 @@ def cascade_attention(
|
||||
# s_aux is incorporated into prefix_lse inside the GPU kernel,
|
||||
# enabling its effect during the final attention merge.
|
||||
s_aux=s_aux,
|
||||
num_splits=1 if vllm_is_batch_invariant() else 0,
|
||||
num_splits=1 if vllm_is_batch_invariant() else max_num_splits,
|
||||
)
|
||||
|
||||
descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2])
|
||||
@ -1019,7 +1021,7 @@ def cascade_attention(
|
||||
q_descale=q_descale.expand(descale_shape) if q_descale is not None else None,
|
||||
k_descale=k_descale.expand(descale_shape) if k_descale is not None else None,
|
||||
v_descale=v_descale.expand(descale_shape) if v_descale is not None else None,
|
||||
num_splits=1 if vllm_is_batch_invariant() else 0,
|
||||
num_splits=1 if vllm_is_batch_invariant() else max_num_splits,
|
||||
)
|
||||
|
||||
# Merge prefix and suffix outputs, and store the result in output.
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user