[Feature] Add process_weights_after_loading to AttentionImpl (#26870)
Signed-off-by: rongfu.leng <rongfu.leng@daocloud.io>
This commit is contained in:
parent 43721bc67f
commit 5afd3276df
@@ -207,6 +207,9 @@ class AttentionImpl(ABC, Generic[T]):
         """
         return False
 
+    def process_weights_after_loading(self, act_dtype: torch.dtype):
+        pass
+
 
 class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
     @abstractmethod
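For illustration only: a backend implementation can now override this base-class no-op to fix up its own tensors once checkpoint weights have been loaded. The sketch below is not part of this commit; ExampleImpl and its bias_cache buffer are hypothetical stand-ins for an AttentionImpl subclass, written under the assumption that the hook is called once after weight loading with the model's activation dtype.

import torch


class ExampleImpl:
    """Illustrative stand-in for an AttentionImpl subclass."""

    def __init__(self):
        # Hypothetical buffer created in float32 before weights are loaded.
        self.bias_cache = torch.zeros(16, dtype=torch.float32)

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        # Override the new default no-op to match the activation dtype.
        self.bias_cache = self.bias_cache.to(act_dtype)


impl = ExampleImpl()
impl.process_weights_after_loading(torch.bfloat16)
assert impl.bias_cache.dtype == torch.bfloat16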
@@ -404,16 +404,7 @@ class Attention(nn.Module, AttentionLayerBase):
         return s
 
     def process_weights_after_loading(self, act_dtype: torch.dtype):
-        if hasattr(self.impl, "process_weights_after_loading"):
-            self.impl.process_weights_after_loading(act_dtype)
-
-        # FlashInfer requires attention sinks to be float32
-        if self.backend == _Backend.FLASHINFER and hasattr(self.impl, "sinks"):
-            from vllm.v1.attention.backends.flashinfer import FlashInferImpl
-
-            assert isinstance(self.impl, FlashInferImpl)
-            if self.impl.sinks is not None and self.impl.sinks.dtype != torch.float32:
-                self.impl.sinks = self.impl.sinks.to(torch.float32)
+        self.impl.process_weights_after_loading(act_dtype)
 
     def get_attn_backend(self) -> type[AttentionBackend]:
         return self.attn_backend
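With the default defined on AttentionImpl, the layer-level method can delegate unconditionally instead of probing with hasattr() and special-casing FlashInfer. A minimal sketch of that delegation pattern, using hypothetical stand-in classes rather than the real vLLM types:

import torch


class ExampleImpl:
    def process_weights_after_loading(self, act_dtype: torch.dtype):
        # Backend-specific fix-ups go here; the inherited default is a no-op.
        pass


class ExampleLayer:
    """Stand-in for the Attention layer after this change."""

    def __init__(self, impl: ExampleImpl):
        self.impl = impl

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        # No hasattr() check and no backend-specific branch needed.
        self.impl.process_weights_after_loading(act_dtype)


layer = ExampleLayer(ExampleImpl())
layer.process_weights_after_loading(torch.bfloat16)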
@@ -833,6 +833,11 @@ class FlashInferImpl(AttentionImpl):
 
         return self.support_trtllm_attn
 
+    # FlashInfer requires attention sinks to be float32
+    def process_weights_after_loading(self, act_dtype: torch.dtype):
+        if self.sinks is not None and self.sinks.dtype != torch.float32:
+            self.sinks = self.sinks.to(torch.float32)
+
     def forward(
         self,
         layer: torch.nn.Module,
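For reference, the dtype normalization that FlashInferImpl now performs in its own hook boils down to the following; the tensor is a hypothetical stand-in for the per-head sinks buffer, not data from this commit:

import torch

sinks = torch.zeros(32, dtype=torch.bfloat16)  # hypothetical sinks buffer
if sinks is not None and sinks.dtype != torch.float32:
    sinks = sinks.to(torch.float32)
assert sinks.dtype == torch.float32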