mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-24 20:05:01 +08:00
[gpt-oss] flashinfer attention sink init (#22330)
Signed-off-by: simon-mo <xmo@berkeley.edu> Co-authored-by: LiuXiaoxuanPKU <lilyliupku@gmail.com> Co-authored-by: simon-mo <xmo@berkeley.edu> Co-authored-by: Chen Zhang <zhangch99@outlook.com> Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu> Co-authored-by: Hongxia Yang <62075498+hongxiayang@users.noreply.github.com> Co-authored-by: Minseok Lee <47620120+minseokl@users.noreply.github.com>
This commit is contained in:
parent
a47e6ffe93
commit
90ec006937
@ -611,6 +611,7 @@ class FlashInferImpl(AttentionImpl):
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[int] = None,
|
||||
sinks: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
@ -635,6 +636,15 @@ class FlashInferImpl(AttentionImpl):
|
||||
"are not implemented for "
|
||||
"FlashInferImpl")
|
||||
|
||||
self.sinks: Optional[torch.Tensor] = None
|
||||
if sinks is not None:
|
||||
assert sinks.shape[0] == num_heads, (
|
||||
"Sinks must have the same number of heads "
|
||||
"as the number of heads in the layer"
|
||||
)
|
||||
assert sinks.dtype == torch.float32, "Sinks must be of type float32"
|
||||
self.sinks = sinks
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user