[Bugfix gpt-oss] Fix float32 convert for flashinfer sink support (#23016)
Signed-off-by: mgoin <mgoin64@gmail.com>
parent 68373d3126
commit 000cceca8c
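The diff below moves the float32 cast of the attention-sink tensor from FlashInferImpl's constructor into the Attention layer's post-weight-loading step, so the cast runs after checkpoint values actually exist. A minimal, self-contained sketch of that technique (ToyAttnImpl and ToyAttention are illustrative stand-ins, not the real vLLM classes):

import torch
import torch.nn as nn


class ToyAttnImpl:
    """Stand-in backend that stores optional attention-sink logits."""

    def __init__(self, sinks: torch.Tensor | None = None):
        self.sinks = sinks


class ToyAttention(nn.Module):
    """Stand-in layer mirroring the post-load hook pattern in the diff."""

    def __init__(self, impl: ToyAttnImpl):
        super().__init__()
        self.impl = impl

    def process_weights_after_loading(self, act_dtype: torch.dtype) -> None:
        # FlashInfer expects sinks in float32; casting here, after the
        # checkpoint has been loaded, guarantees the loaded values are the
        # ones being converted. act_dtype is kept only to mirror the hook
        # signature in the diff.
        if (self.impl.sinks is not None
                and self.impl.sinks.dtype != torch.float32):
            self.impl.sinks = self.impl.sinks.to(torch.float32)


impl = ToyAttnImpl(sinks=torch.zeros(8, dtype=torch.bfloat16))
layer = ToyAttention(impl)
layer.process_weights_after_loading(torch.bfloat16)
assert impl.sinks.dtype == torch.float32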
vllm/attention/layer.py
@@ -308,6 +308,15 @@ class Attention(nn.Module):
         if hasattr(self.impl, "process_weights_after_loading"):
             self.impl.process_weights_after_loading(act_dtype)
 
+        # FlashInfer requires attention sinks to be float32
+        if (self.backend == _Backend.FLASHINFER_VLLM_V1
+                and hasattr(self.impl, 'sinks')):
+            from vllm.v1.attention.backends.flashinfer import FlashInferImpl
+            assert isinstance(self.impl, FlashInferImpl)
+            if (self.impl.sinks is not None
+                    and self.impl.sinks.dtype != torch.float32):
+                self.impl.sinks = self.impl.sinks.to(torch.float32)
+
     def get_attn_backend(self) -> type[AttentionBackend]:
         return self.attn_backend
vllm/v1/attention/backends/flashinfer.py
@@ -642,9 +642,6 @@ class FlashInferImpl(AttentionImpl):
                     f"heads in the layer. Expected {num_heads}, but got "
                     f"{sinks.shape[0]}."
                 )
-            # Cast sinks to float32 if needed (FlashInfer requirement)
-            if sinks.dtype != torch.float32:
-                sinks = sinks.to(torch.float32)
             self.sinks = sinks
 
     def forward(
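Why the timing matters, as a hedged reconstruction of the bug using plain tensors rather than vLLM's actual loading path: a cast performed at construction operates on placeholder values, and if weight loading later re-binds the attribute, the float32 copy is simply discarded.

import torch


class ToyImpl:
    def __init__(self, sinks: torch.Tensor):
        # Construction-time cast (the pattern this commit removes): it runs
        # before any checkpoint values exist.
        self.sinks = sinks.to(torch.float32)


impl = ToyImpl(torch.empty(4, dtype=torch.bfloat16))

# Simulated weight loading that re-binds the attribute to checkpoint data in
# the checkpoint dtype (an assumption for illustration, not vLLM's loader).
impl.sinks = torch.randn(4, dtype=torch.bfloat16)

# The constructor's cast has been lost; a post-load cast is still needed.
assert impl.sinks.dtype == torch.bfloat16
impl.sinks = impl.sinks.to(torch.float32)
assert impl.sinks.dtype == torch.float32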