diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index 8592d1b26dfa8..caf9ecc91108d 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -611,6 +611,7 @@ class FlashInferImpl(AttentionImpl):
         logits_soft_cap: Optional[float] = None,
         attn_type: AttentionType = AttentionType.DECODER,
         kv_sharing_target_layer_name: Optional[int] = None,
+        sinks: Optional[torch.Tensor] = None,
     ) -> None:
         self.num_heads = num_heads
         self.head_size = head_size
@@ -635,6 +636,15 @@ class FlashInferImpl(AttentionImpl):
                 "are not implemented for "
                 "FlashInferImpl")
 
+        self.sinks: Optional[torch.Tensor] = None
+        if sinks is not None:
+            assert sinks.shape[0] == num_heads, (
+                "Sinks must have the same number of heads "
+                "as the number of heads in the layer"
+            )
+            assert sinks.dtype == torch.float32, "Sinks must be of type float32"
+            self.sinks = sinks
+
     def forward(
         self,
         layer: torch.nn.Module,
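
A minimal usage sketch of the new parameter, for reviewers. The head count and zero-initialization below are hypothetical and not taken from this diff; the sketch only shows what the added assertions require of the tensor (one entry per attention head, float32 dtype):

    import torch

    num_heads = 32  # hypothetical head count for the attention layer

    # One sink logit per attention head, in float32, so that
    # sinks.shape[0] == num_heads and sinks.dtype == torch.float32
    # both hold and the new assertions in __init__ pass.
    sinks = torch.zeros(num_heads, dtype=torch.float32)

    # Either of these would trip the added checks:
    #   torch.zeros(num_heads + 1, dtype=torch.float32)  # wrong shape[0]
    #   torch.zeros(num_heads, dtype=torch.float16)      # wrong dtype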