diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index bb3dcddba3e9..1f6b7e41b37e 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -607,6 +607,7 @@ class FlashAttentionImpl(AttentionImpl):
                 q_descale=layer._q_scale,
                 k_descale=layer._k_scale,
                 v_descale=layer._v_scale,
+                s_aux=self.sinks,
             )
 
         return output
@@ -767,6 +768,7 @@ def cascade_attention(
     q_descale: Optional[torch.Tensor] = None,
     k_descale: Optional[torch.Tensor] = None,
     v_descale: Optional[torch.Tensor] = None,
+    s_aux: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     assert alibi_slopes is None, "Cascade attention does not support ALiBi."
     # TODO: Support sliding window.
@@ -801,6 +803,9 @@ def cascade_attention(
         q_descale=q_descale.expand(descale_shape) if q_descale is not None else None,
         k_descale=k_descale.expand(descale_shape) if k_descale is not None else None,
         v_descale=v_descale.expand(descale_shape) if v_descale is not None else None,
+        # s_aux is incorporated into prefix_lse inside the GPU kernel,
+        # enabling its effect during the final attention merge.
+        s_aux=s_aux,
     )
 
     descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2])
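
The comment in the last hunk refers to how an auxiliary "sink" logit interacts with the log-sum-exp (LSE) based merge used by cascade attention. As a rough illustration only (this is not the kernel code; the shapes, function names, and single-query setup below are assumptions made for the sketch), folding `s_aux` into a partial result's LSE lets the standard LSE-weighted merge account for the sink without changing the merge step itself:

```python
# Minimal reference sketch, NOT the FlashAttention kernel: shows how an
# auxiliary sink logit (s_aux) can be folded into the log-sum-exp (LSE) of a
# partial attention result so that the usual LSE-weighted merge picks it up.
# Shapes and function names here are illustrative assumptions.
import torch


def attention_with_sink(scores: torch.Tensor, values: torch.Tensor,
                        s_aux: torch.Tensor):
    """Single-query, single-head attention where s_aux is an extra logit that
    absorbs probability mass but contributes no value (an "attention sink")."""
    # LSE over the real key scores plus the sink logit.
    lse = torch.logsumexp(torch.cat([scores, s_aux.reshape(1)]), dim=0)
    probs = torch.exp(scores - lse)   # probabilities over real keys; sum < 1
    return probs @ values, lse        # the sink adds no value contribution


def merge_partial_attention(out_a, lse_a, out_b, lse_b):
    """LSE-weighted merge of two partial attention results, e.g. the
    shared-prefix output and the per-request-suffix output."""
    lse = torch.logaddexp(lse_a, lse_b)
    return torch.exp(lse_a - lse) * out_a + torch.exp(lse_b - lse) * out_b
```

In this sketch, applying the sink only when computing the prefix result (and therefore only in `prefix_lse`) is enough for the merged output to reflect it, which appears to match the intent of the comment added in the diff.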