mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-03 15:24:36 +08:00
[BugFix] Fix FA3 IMA with FULL_AND_PIECEWISE and cascade attention (default) (#28702)
This commit is contained in:
parent
9324e10275
commit
db56a59970
@ -170,6 +170,7 @@ def test_cascade(
|
|||||||
logits_soft_cap=soft_cap if soft_cap is not None else 0,
|
logits_soft_cap=soft_cap if soft_cap is not None else 0,
|
||||||
block_table=block_tables,
|
block_table=block_tables,
|
||||||
common_prefix_len=common_prefix_len,
|
common_prefix_len=common_prefix_len,
|
||||||
|
max_num_splits=0, # no max
|
||||||
fa_version=fa_version,
|
fa_version=fa_version,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -704,6 +704,7 @@ class FlashAttentionImpl(AttentionImpl):
|
|||||||
logits_soft_cap=self.logits_soft_cap,
|
logits_soft_cap=self.logits_soft_cap,
|
||||||
block_table=attn_metadata.block_table,
|
block_table=attn_metadata.block_table,
|
||||||
common_prefix_len=attn_metadata.common_prefix_len,
|
common_prefix_len=attn_metadata.common_prefix_len,
|
||||||
|
max_num_splits=attn_metadata.max_num_splits,
|
||||||
fa_version=self.vllm_flash_attn_version,
|
fa_version=self.vllm_flash_attn_version,
|
||||||
prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
|
prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
|
||||||
suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
|
suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
|
||||||
@ -950,6 +951,7 @@ def cascade_attention(
|
|||||||
logits_soft_cap: float,
|
logits_soft_cap: float,
|
||||||
block_table: torch.Tensor,
|
block_table: torch.Tensor,
|
||||||
common_prefix_len: int,
|
common_prefix_len: int,
|
||||||
|
max_num_splits: int,
|
||||||
fa_version: int,
|
fa_version: int,
|
||||||
prefix_scheduler_metadata: torch.Tensor | None = None,
|
prefix_scheduler_metadata: torch.Tensor | None = None,
|
||||||
suffix_scheduler_metadata: torch.Tensor | None = None,
|
suffix_scheduler_metadata: torch.Tensor | None = None,
|
||||||
@ -994,7 +996,7 @@ def cascade_attention(
|
|||||||
# s_aux is incorporated into prefix_lse inside the GPU kernel,
|
# s_aux is incorporated into prefix_lse inside the GPU kernel,
|
||||||
# enabling its effect during the final attention merge.
|
# enabling its effect during the final attention merge.
|
||||||
s_aux=s_aux,
|
s_aux=s_aux,
|
||||||
num_splits=1 if vllm_is_batch_invariant() else 0,
|
num_splits=1 if vllm_is_batch_invariant() else max_num_splits,
|
||||||
)
|
)
|
||||||
|
|
||||||
descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2])
|
descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2])
|
||||||
@ -1019,7 +1021,7 @@ def cascade_attention(
|
|||||||
q_descale=q_descale.expand(descale_shape) if q_descale is not None else None,
|
q_descale=q_descale.expand(descale_shape) if q_descale is not None else None,
|
||||||
k_descale=k_descale.expand(descale_shape) if k_descale is not None else None,
|
k_descale=k_descale.expand(descale_shape) if k_descale is not None else None,
|
||||||
v_descale=v_descale.expand(descale_shape) if v_descale is not None else None,
|
v_descale=v_descale.expand(descale_shape) if v_descale is not None else None,
|
||||||
num_splits=1 if vllm_is_batch_invariant() else 0,
|
num_splits=1 if vllm_is_batch_invariant() else max_num_splits,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Merge prefix and suffix outputs, and store the result in output.
|
# Merge prefix and suffix outputs, and store the result in output.
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user