[Feature] Add process_weights_after_loading to AttentionImpl (#26870)
Signed-off-by: rongfu.leng <rongfu.leng@daocloud.io>
commit 5afd3276df
parent 43721bc67f
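For context, here is a minimal, self-contained sketch of the dispatch pattern this change introduces. The *Sketch class names below are illustrative stand-ins, not vLLM's actual AttentionImpl/Attention classes; only the method name process_weights_after_loading comes from the diff itself.

from abc import ABC

import torch


class AttentionImplSketch(ABC):
    # New base-class hook: a no-op by default, so every backend impl has it.
    def process_weights_after_loading(self, act_dtype: torch.dtype) -> None:
        pass


class AttentionLayerSketch:
    def __init__(self, impl: AttentionImplSketch):
        self.impl = impl

    # The layer can now delegate unconditionally instead of guarding the call
    # with hasattr() and backend-specific isinstance checks.
    def process_weights_after_loading(self, act_dtype: torch.dtype) -> None:
        self.impl.process_weights_after_loading(act_dtype)


layer = AttentionLayerSketch(impl=AttentionImplSketch())
layer.process_weights_after_loading(torch.float16)  # no-op for the base impl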
@@ -207,6 +207,9 @@ class AttentionImpl(ABC, Generic[T]):
         """
         return False
 
+    def process_weights_after_loading(self, act_dtype: torch.dtype):
+        pass
+
 
 class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
     @abstractmethod
@@ -404,16 +404,7 @@ class Attention(nn.Module, AttentionLayerBase):
         return s
 
     def process_weights_after_loading(self, act_dtype: torch.dtype):
-        if hasattr(self.impl, "process_weights_after_loading"):
-            self.impl.process_weights_after_loading(act_dtype)
-
-        # FlashInfer requires attention sinks to be float32
-        if self.backend == _Backend.FLASHINFER and hasattr(self.impl, "sinks"):
-            from vllm.v1.attention.backends.flashinfer import FlashInferImpl
-
-            assert isinstance(self.impl, FlashInferImpl)
-            if self.impl.sinks is not None and self.impl.sinks.dtype != torch.float32:
-                self.impl.sinks = self.impl.sinks.to(torch.float32)
+        self.impl.process_weights_after_loading(act_dtype)
 
     def get_attn_backend(self) -> type[AttentionBackend]:
        return self.attn_backend
@@ -833,6 +833,11 @@ class FlashInferImpl(AttentionImpl):
 
         return self.support_trtllm_attn
 
+    # FlashInfer requires attention sinks to be float32
+    def process_weights_after_loading(self, act_dtype: torch.dtype):
+        if self.sinks is not None and self.sinks.dtype != torch.float32:
+            self.sinks = self.sinks.to(torch.float32)
+
     def forward(
         self,
         layer: torch.nn.Module,
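As a rough usage sketch of the backend-specific side (again with illustrative class names rather than the real FlashInferImpl), a backend that needs post-load fix-ups simply overrides the hook; here, casting attention sinks to float32 the way the FlashInfer hunk above does.

from typing import Optional

import torch


class AttentionImplSketch:
    # Re-declared here so the sketch stays standalone: default no-op hook.
    def process_weights_after_loading(self, act_dtype: torch.dtype) -> None:
        pass


class FlashInferLikeImplSketch(AttentionImplSketch):
    def __init__(self, sinks: Optional[torch.Tensor] = None):
        self.sinks = sinks

    # Backend-specific override: upcast attention sinks to float32.
    def process_weights_after_loading(self, act_dtype: torch.dtype) -> None:
        if self.sinks is not None and self.sinks.dtype != torch.float32:
            self.sinks = self.sinks.to(torch.float32)


impl = FlashInferLikeImplSketch(sinks=torch.zeros(4, dtype=torch.bfloat16))
impl.process_weights_after_loading(torch.bfloat16)
assert impl.sinks.dtype == torch.float32  # sinks were upcast by the hook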