[Bugfix gpt-oss] Fix float32 convert for flashinfer sink support (#23016)

Signed-off-by: mgoin <mgoin64@gmail.com>
Author: Michael Goin
Date: 2025-08-16 14:16:00 -04:00
Committed by: GitHub
parent 68373d3126
commit 000cceca8c
2 changed files with 9 additions and 3 deletions


@@ -308,6 +308,15 @@ class Attention(nn.Module):
         if hasattr(self.impl, "process_weights_after_loading"):
             self.impl.process_weights_after_loading(act_dtype)
+
+        # FlashInfer requires attention sinks to be float32
+        if (self.backend == _Backend.FLASHINFER_VLLM_V1
+                and hasattr(self.impl, 'sinks')):
+            from vllm.v1.attention.backends.flashinfer import FlashInferImpl
+            assert isinstance(self.impl, FlashInferImpl)
+            if (self.impl.sinks is not None
+                    and self.impl.sinks.dtype != torch.float32):
+                self.impl.sinks = self.impl.sinks.to(torch.float32)

     def get_attn_backend(self) -> type[AttentionBackend]:
         return self.attn_backend
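
For context: the block added above converts the attention-sink tensor to float32 only after weight loading, since FlashInfer requires float32 sinks. Below is a minimal sketch of the same pattern outside vLLM; FakeImpl and post_load_fixup are hypothetical stand-ins for FlashInferImpl and Attention.process_weights_after_loading, not vLLM APIs.

import torch


class FakeImpl:
    """Hypothetical stand-in for an attention backend impl that holds sinks."""

    def __init__(self, sinks: torch.Tensor | None = None):
        # The checkpoint may provide sinks in a lower-precision dtype (e.g. bfloat16).
        self.sinks = sinks


def post_load_fixup(impl: FakeImpl) -> None:
    # Mirrors the added hunk: once the final values are in place,
    # convert sinks to float32 because FlashInfer requires it.
    if impl.sinks is not None and impl.sinks.dtype != torch.float32:
        impl.sinks = impl.sinks.to(torch.float32)


impl = FakeImpl(sinks=torch.zeros(8, dtype=torch.bfloat16))
post_load_fixup(impl)
assert impl.sinks.dtype == torch.float32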


@@ -642,9 +642,6 @@ class FlashInferImpl(AttentionImpl):
                     f"heads in the layer. Expected {num_heads}, but got "
                     f"{sinks.shape[0]}."
                 )
-            # Cast sinks to float32 if needed (FlashInfer requirement)
-            if sinks.dtype != torch.float32:
-                sinks = sinks.to(torch.float32)
             self.sinks = sinks

     def forward(
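
With the post-load conversion in place, the eager cast removed here from FlashInferImpl.__init__ (which presumably ran before the checkpoint values for sinks were loaded) is no longer needed; the impl keeps sinks in the checkpoint dtype and relies on Attention.process_weights_after_loading for the float32 conversion.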