[gpt-oss] Enhance error msg on attention sink init (#22335)

Signed-off-by: simon-mo <xmo@berkeley.edu>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Co-authored-by: simon-mo <xmo@berkeley.edu>
This commit is contained in:
Authored by Yongye Zhu on 2025-08-06 11:41:42 -07:00; committed by GitHub
parent ec7cb19224
commit 31f5dc5b2a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -638,11 +638,15 @@ class FlashInferImpl(AttentionImpl):
         self.sinks: Optional[torch.Tensor] = None
         if sinks is not None:
-            assert sinks.shape[0] == num_heads, (
-                "Sinks must have the same number of heads "
-                "as the number of heads in the layer"
-            )
-            assert sinks.dtype == torch.float32, "Sinks must be of type float32"
+            if sinks.shape[0] != num_heads:
+                raise ValueError(
+                    "Sinks must have the same number of heads as the number of "
+                    f"heads in the layer. Expected {num_heads}, but got "
+                    f"{sinks.shape[0]}."
+                )
+            if sinks.dtype != torch.float32:
+                raise ValueError("Sinks must be of type float32, but got "
+                                 f"{sinks.dtype}.")
             self.sinks = sinks

     def forward(