[gpt-oss] flashinfer attention sink init (#22330)

Signed-off-by: simon-mo <xmo@berkeley.edu>
Co-authored-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>
Co-authored-by: simon-mo <xmo@berkeley.edu>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: Hongxia Yang <62075498+hongxiayang@users.noreply.github.com>
Co-authored-by: Minseok Lee <47620120+minseokl@users.noreply.github.com>
Author: Yongye Zhu
Date: 2025-08-05 23:48:19 -07:00 (committed by GitHub)
Commit: 90ec006937
Parent: a47e6ffe93


@@ -611,6 +611,7 @@ class FlashInferImpl(AttentionImpl):
         logits_soft_cap: Optional[float] = None,
         attn_type: AttentionType = AttentionType.DECODER,
         kv_sharing_target_layer_name: Optional[str] = None,
+        sinks: Optional[torch.Tensor] = None,
     ) -> None:
         self.num_heads = num_heads
         self.head_size = head_size
@@ -635,6 +636,15 @@ class FlashInferImpl(AttentionImpl):
                                       "are not implemented for "
                                       "FlashInferImpl")
+        self.sinks: Optional[torch.Tensor] = None
+        if sinks is not None:
+            assert sinks.shape[0] == num_heads, (
+                "Sinks must have the same number of heads "
+                "as the number of heads in the layer"
+            )
+            assert sinks.dtype == torch.float32, "Sinks must be of type float32"
+            self.sinks = sinks
+
     def forward(
         self,
         layer: torch.nn.Module,
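
For reference, a minimal sketch of constructing a sinks tensor that satisfies the checks this commit adds: one value per attention head, dtype float32. The head count and initial value below are illustrative only, not taken from the commit.

    import torch

    num_heads = 64  # hypothetical head count for illustration

    # One sink logit per attention head; __init__ above asserts
    # sinks.shape[0] == num_heads and sinks.dtype == torch.float32.
    sinks = torch.zeros(num_heads, dtype=torch.float32)

    # These mirror the assertions added in FlashInferImpl.__init__.
    assert sinks.shape[0] == num_heads
    assert sinks.dtype == torch.float32

A tensor shaped or typed differently (e.g. (num_kv_heads,) or float16) would trip the new assertions at layer construction time rather than failing later inside the attention kernel.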