[Bugfix gpt-oss] Fix float32 convert for flashinfer sink support (#23016)
Signed-off-by: mgoin <mgoin64@gmail.com>
commit 000cceca8c
parent 68373d3126
@@ -308,6 +308,15 @@ class Attention(nn.Module):
         if hasattr(self.impl, "process_weights_after_loading"):
             self.impl.process_weights_after_loading(act_dtype)
 
+        # FlashInfer requires attention sinks to be float32
+        if (self.backend == _Backend.FLASHINFER_VLLM_V1
+                and hasattr(self.impl, 'sinks')):
+            from vllm.v1.attention.backends.flashinfer import FlashInferImpl
+            assert isinstance(self.impl, FlashInferImpl)
+            if (self.impl.sinks is not None
+                    and self.impl.sinks.dtype != torch.float32):
+                self.impl.sinks = self.impl.sinks.to(torch.float32)
+
     def get_attn_backend(self) -> type[AttentionBackend]:
         return self.attn_backend
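For reference, a minimal, self-contained sketch of the cast-after-loading check added above. The names `ToyImpl` and `ensure_float32_sinks` are illustrative stand-ins, not vLLM APIs.

    # Toy sketch of the float32 sink cast added in the hunk above.
    import torch

    class ToyImpl:
        def __init__(self, num_heads: int, dtype: torch.dtype):
            # Sinks end up in the model's dtype (e.g. bfloat16) after weight loading.
            self.sinks = torch.zeros(num_heads, dtype=dtype)

    def ensure_float32_sinks(impl) -> None:
        # Same guard as the diff: cast only when sinks exist and are not float32.
        sinks = getattr(impl, "sinks", None)
        if sinks is not None and sinks.dtype != torch.float32:
            impl.sinks = sinks.to(torch.float32)

    impl = ToyImpl(num_heads=8, dtype=torch.bfloat16)
    ensure_float32_sinks(impl)
    assert impl.sinks.dtype == torch.float32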
@@ -642,9 +642,6 @@ class FlashInferImpl(AttentionImpl):
                     f"heads in the layer. Expected {num_heads}, but got "
                     f"{sinks.shape[0]}."
                 )
-            # Cast sinks to float32 if needed (FlashInfer requirement)
-            if sinks.dtype != torch.float32:
-                sinks = sinks.to(torch.float32)
             self.sinks = sinks
 
     def forward(
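One plausible reading of why the constructor-time cast was dropped, sketched with plain tensors (an assumption for illustration, not taken from the commit): `tensor.to(dtype)` returns a copy, so a cast taken before the sink weights are actually loaded never sees the loaded values, while the new cast in the layer-level hunk runs after loading.

    # Hedged toy sketch (assumption, not vLLM internals): why casting before
    # weight loading is fragile. tensor.to(dtype) returns a new tensor, so an
    # early cast does not track values written into the original storage later.
    import torch

    sinks_storage = torch.zeros(8, dtype=torch.bfloat16)  # storage the loader fills
    early_cast = sinks_storage.to(torch.float32)          # cast taken at __init__ time

    # Simulated weight loading writes into the original bfloat16 storage...
    sinks_storage.copy_(torch.arange(8, dtype=torch.bfloat16))

    # ...and the early float32 copy still holds the pre-load zeros.
    assert early_cast.sum().item() == 0.0
    # Casting after loading, as the layer-level hunk now does, sees the real values.
    assert sinks_storage.to(torch.float32).sum().item() == 28.0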