From 000cceca8c329d5b5d99e0186fbd444a390384cd Mon Sep 17 00:00:00 2001
From: Michael Goin
Date: Sat, 16 Aug 2025 14:16:00 -0400
Subject: [PATCH] [Bugfix gpt-oss] Fix float32 convert for flashinfer sink
 support (#23016)

Signed-off-by: mgoin
---
 vllm/attention/layer.py                  | 9 +++++++++
 vllm/v1/attention/backends/flashinfer.py | 3 ---
 2 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index 1a9c0e26b53c..0e87fa3f23e3 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -308,6 +308,15 @@ class Attention(nn.Module):
         if hasattr(self.impl, "process_weights_after_loading"):
             self.impl.process_weights_after_loading(act_dtype)
 
+        # FlashInfer requires attention sinks to be float32
+        if (self.backend == _Backend.FLASHINFER_VLLM_V1
+                and hasattr(self.impl, 'sinks')):
+            from vllm.v1.attention.backends.flashinfer import FlashInferImpl
+            assert isinstance(self.impl, FlashInferImpl)
+            if (self.impl.sinks is not None
+                    and self.impl.sinks.dtype != torch.float32):
+                self.impl.sinks = self.impl.sinks.to(torch.float32)
+
     def get_attn_backend(self) -> type[AttentionBackend]:
         return self.attn_backend
 
diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index eac3f33e1509..991904229fd7 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -642,9 +642,6 @@ class FlashInferImpl(AttentionImpl):
                     f"heads in the layer. Expected {num_heads}, but got "
                     f"{sinks.shape[0]}."
                 )
-            # Cast sinks to float32 if needed (FlashInfer requirement)
-            if sinks.dtype != torch.float32:
-                sinks = sinks.to(torch.float32)
             self.sinks = sinks
 
     def forward(
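
For reference, the patch relocates the float32 cast from FlashInferImpl.__init__ to
Attention.process_weights_after_loading, presumably so the cast runs after weight
loading is complete rather than at construction time. Below is a minimal standalone
sketch of the cast itself, not part of vLLM: FakeImpl and the 8-head bfloat16 tensor
are invented purely for illustration.

    import torch

    class FakeImpl:
        # Stand-in for an attention backend impl that holds a per-head sinks tensor.
        def __init__(self, sinks: torch.Tensor):
            self.sinks = sinks

    # e.g. one learned sink logit per attention head, loaded in bfloat16
    impl = FakeImpl(torch.zeros(8, dtype=torch.bfloat16))

    # The same conversion the patch performs after weight loading:
    # FlashInfer requires the sinks tensor to be float32.
    if impl.sinks is not None and impl.sinks.dtype != torch.float32:
        impl.sinks = impl.sinks.to(torch.float32)

    assert impl.sinks.dtype == torch.float32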