[Feature] Add process_weights_after_loading to AttentionImpl (#26870)

Signed-off-by: rongfu.leng <rongfu.leng@daocloud.io>
rongfu.leng 2025-10-16 23:02:30 +08:00 committed by GitHub
parent 43721bc67f
commit 5afd3276df
3 changed files with 9 additions and 10 deletions


@@ -207,6 +207,9 @@ class AttentionImpl(ABC, Generic[T]):
         """
         return False
 
+    def process_weights_after_loading(self, act_dtype: torch.dtype):
+        pass
+
 
 class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
     @abstractmethod
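With the no-op default on the base class, any backend can opt in simply by
overriding the hook. A minimal sketch (not part of this commit) of a custom
implementation, assuming the module path of the AttentionImpl class above;
MyAttentionImpl and its aux_scale tensor are hypothetical, and a real backend
would also implement forward() and the other abstract methods:

    import torch

    from vllm.attention.backends.abstract import AttentionImpl  # assumed path

    class MyAttentionImpl(AttentionImpl):
        """Hypothetical backend illustrating the new post-load hook."""

        def process_weights_after_loading(self, act_dtype: torch.dtype):
            # Example fixup: cast a (hypothetical) auxiliary tensor to the
            # activation dtype so the kernel sees consistent inputs.
            aux_scale = getattr(self, "aux_scale", None)
            if aux_scale is not None and aux_scale.dtype != act_dtype:
                self.aux_scale = aux_scale.to(act_dtype)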


@@ -404,16 +404,7 @@ class Attention(nn.Module, AttentionLayerBase):
         return s
 
     def process_weights_after_loading(self, act_dtype: torch.dtype):
-        if hasattr(self.impl, "process_weights_after_loading"):
-            self.impl.process_weights_after_loading(act_dtype)
-
-        # FlashInfer requires attention sinks to be float32
-        if self.backend == _Backend.FLASHINFER and hasattr(self.impl, "sinks"):
-            from vllm.v1.attention.backends.flashinfer import FlashInferImpl
-
-            assert isinstance(self.impl, FlashInferImpl)
-            if self.impl.sinks is not None and self.impl.sinks.dtype != torch.float32:
-                self.impl.sinks = self.impl.sinks.to(torch.float32)
+        self.impl.process_weights_after_loading(act_dtype)
 
     def get_attn_backend(self) -> type[AttentionBackend]:
         return self.attn_backend
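Because the hook now lives on AttentionImpl itself, the layer can delegate
unconditionally: the hasattr() guard and the FlashInfer-specific branch are
gone, and the sink fixup moves into the backend (next hunk). A hedged sketch
of the resulting call pattern at load time; the loop and function name are
illustrative, not vLLM's actual loader code:

    import torch

    from vllm.attention import Attention  # assumed import path

    def run_post_load_hooks(model: torch.nn.Module, act_dtype: torch.dtype):
        # Illustrative loader-side traversal: every Attention layer accepts
        # the call directly, since the base class provides a no-op default.
        for module in model.modules():
            if isinstance(module, Attention):
                module.process_weights_after_loading(act_dtype)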


@@ -833,6 +833,11 @@ class FlashInferImpl(AttentionImpl):
         return self.support_trtllm_attn
 
+    # FlashInfer requires attention sinks to be float32
+    def process_weights_after_loading(self, act_dtype: torch.dtype):
+        if self.sinks is not None and self.sinks.dtype != torch.float32:
+            self.sinks = self.sinks.to(torch.float32)
+
     def forward(
         self,
         layer: torch.nn.Module,
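For reference, a self-contained illustration of the sink-dtype rule above
(plain Python, not vLLM code): sinks loaded in a low-precision checkpoint
dtype are upcast to float32 regardless of the activation dtype.

    import torch

    class _SinksHolder:
        """Stand-in for FlashInferImpl's sink handling (illustrative only)."""

        def __init__(self, sinks: torch.Tensor | None):
            self.sinks = sinks

        def process_weights_after_loading(self, act_dtype: torch.dtype):
            # Mirrors the hunk above: FlashInfer needs float32 sinks,
            # independent of the activation dtype passed in.
            if self.sinks is not None and self.sinks.dtype != torch.float32:
                self.sinks = self.sinks.to(torch.float32)

    holder = _SinksHolder(torch.zeros(8, dtype=torch.bfloat16))
    holder.process_weights_after_loading(act_dtype=torch.bfloat16)
    assert holder.sinks.dtype == torch.float32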