From 5afd3276dfd70397a08ba16ee8eb246ddb3d13ef Mon Sep 17 00:00:00 2001
From: "rongfu.leng"
Date: Thu, 16 Oct 2025 23:02:30 +0800
Subject: [PATCH] [Feature] Add process_weights_after_loading to AttentionImpl (#26870)

Signed-off-by: rongfu.leng
---
 vllm/attention/backends/abstract.py      |  3 +++
 vllm/attention/layer.py                  | 11 +----------
 vllm/v1/attention/backends/flashinfer.py |  5 +++++
 3 files changed, 9 insertions(+), 10 deletions(-)

diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py
index fb2db4d0b0ec3..e9c6a278a9411 100644
--- a/vllm/attention/backends/abstract.py
+++ b/vllm/attention/backends/abstract.py
@@ -207,6 +207,9 @@ class AttentionImpl(ABC, Generic[T]):
         """
         return False
 
+    def process_weights_after_loading(self, act_dtype: torch.dtype):
+        pass
+
 
 class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
     @abstractmethod
diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index 9f879f7272e21..4b591f07ca2d4 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -404,16 +404,7 @@ class Attention(nn.Module, AttentionLayerBase):
         return s
 
     def process_weights_after_loading(self, act_dtype: torch.dtype):
-        if hasattr(self.impl, "process_weights_after_loading"):
-            self.impl.process_weights_after_loading(act_dtype)
-
-        # FlashInfer requires attention sinks to be float32
-        if self.backend == _Backend.FLASHINFER and hasattr(self.impl, "sinks"):
-            from vllm.v1.attention.backends.flashinfer import FlashInferImpl
-
-            assert isinstance(self.impl, FlashInferImpl)
-            if self.impl.sinks is not None and self.impl.sinks.dtype != torch.float32:
-                self.impl.sinks = self.impl.sinks.to(torch.float32)
+        self.impl.process_weights_after_loading(act_dtype)
 
     def get_attn_backend(self) -> type[AttentionBackend]:
         return self.attn_backend
diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index eb9f6a280d8f6..34225602f025c 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -833,6 +833,11 @@ class FlashInferImpl(AttentionImpl):
 
         return self.support_trtllm_attn
 
+    # FlashInfer requires attention sinks to be float32
+    def process_weights_after_loading(self, act_dtype: torch.dtype):
+        if self.sinks is not None and self.sinks.dtype != torch.float32:
+            self.sinks = self.sinks.to(torch.float32)
+
     def forward(
         self,
         layer: torch.nn.Module,
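
Editor's sketch (not part of the patch): a minimal standalone illustration of the dispatch pattern this change introduces. The base class gains a no-op process_weights_after_loading(act_dtype) hook, so the generic Attention layer can call self.impl.process_weights_after_loading(act_dtype) unconditionally instead of using hasattr()/isinstance() checks, while backends that need post-load fixups (here, FlashInfer's requirement that attention sinks be float32) simply override the hook. The class names BaseImpl and SinksImpl below are invented for this sketch and are not vLLM's AttentionImpl/FlashInferImpl; only the hook name, its act_dtype parameter, and the float32 sinks fixup come from the patch above.

    # Standalone sketch of the pattern; requires only torch.
    import torch


    class BaseImpl:
        def process_weights_after_loading(self, act_dtype: torch.dtype) -> None:
            # Default: nothing to do after weight loading.
            pass


    class SinksImpl(BaseImpl):
        def __init__(self, sinks: torch.Tensor | None) -> None:
            self.sinks = sinks

        def process_weights_after_loading(self, act_dtype: torch.dtype) -> None:
            # Mirrors the FlashInferImpl override: sinks must be float32,
            # regardless of the activation dtype.
            if self.sinks is not None and self.sinks.dtype != torch.float32:
                self.sinks = self.sinks.to(torch.float32)


    # The caller (analogous to Attention.process_weights_after_loading after
    # this patch) reduces to a single unconditional call:
    for impl in (BaseImpl(), SinksImpl(torch.zeros(8, dtype=torch.bfloat16))):
        impl.process_weights_after_loading(act_dtype=torch.bfloat16)

The design choice is the usual "no-op default in the base class" refactor: backend-specific knowledge (FlashInfer's float32 sinks) moves out of the generic layer and into the backend implementation that owns it.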