mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-06 09:25:45 +08:00
[gpt-oss] Enhance error msg on attention sink init (#22335)
Signed-off-by: simon-mo <xmo@berkeley.edu>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Co-authored-by: simon-mo <xmo@berkeley.edu>
This commit is contained in:
parent
ec7cb19224
commit
31f5dc5b2a
@@ -638,11 +638,15 @@ class FlashInferImpl(AttentionImpl):
|
|||||||
|
|
||||||
self.sinks: Optional[torch.Tensor] = None
|
self.sinks: Optional[torch.Tensor] = None
|
||||||
if sinks is not None:
|
if sinks is not None:
|
||||||
assert sinks.shape[0] == num_heads, (
|
if sinks.shape[0] != num_heads:
|
||||||
"Sinks must have the same number of heads "
|
raise ValueError(
|
||||||
"as the number of heads in the layer"
|
"Sinks must have the same number of heads as the number of "
|
||||||
)
|
f"heads in the layer. Expected {num_heads}, but got "
|
||||||
assert sinks.dtype == torch.float32, "Sinks must be of type float32"
|
f"{sinks.shape[0]}."
|
||||||
|
)
|
||||||
|
if sinks.dtype != torch.float32:
|
||||||
|
raise ValueError("Sinks must be of type float32, but got "
|
||||||
|
f"{sinks.dtype}.")
|
||||||
self.sinks = sinks
|
self.sinks = sinks
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user