Mirror of https://git.datalinker.icu/vllm-project/vllm.git
[Bugfix] Massage MLA's usage of flash attn for RoCM (#13310)
parent 579d7a63b2
commit 97a3d6d995
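The bug being fixed: on ROCm, vLLM falls back to the upstream flash_attn package, whose flash_attn_varlen_func has no fa_version keyword; MLA's prefill path passed that keyword unconditionally, which only vllm_flash_attn accepts. The snippet below is a minimal illustration of that failure mode, using a hypothetical stand-in function rather than the real flash_attn API.

# Stand-in for ROCm's flash_attn.flash_attn_varlen_func: it exposes no
# fa_version keyword, so forwarding one unconditionally raises TypeError.
def rocm_style_varlen_func(q, k, v, softmax_scale=None, causal=False):
    return "attention output"

try:
    rocm_style_varlen_func(q=None, k=None, v=None, causal=True, fa_version=3)
except TypeError as exc:
    print(exc)  # ... got an unexpected keyword argument 'fa_version'

The diff below removes that unconditional keyword by binding it once, at construction time, only when the vLLM fork of flash attention is in use.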
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
+import functools
 from abc import abstractmethod
 from dataclasses import dataclass
 from typing import Any, Dict, Generic, List, Optional, Tuple
@@ -183,6 +184,15 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
         self.o_proj = o_proj
         self.vllm_flash_attn_version = get_flash_attn_version()
 
+        # Handle the differences between the flash_attn_varlen from flash_attn
+        # and the one from vllm_flash_attn. The former is used on RoCM and the
+        # latter has an additional parameter to control FA2 vs FA3
+        self.flash_attn_varlen_func = flash_attn_varlen_func
+        if self.vllm_flash_attn_version is not None:
+            self.flash_attn_varlen_func = \
+                functools.partial(flash_attn_varlen_func,
+                                  fa_version=self.vllm_flash_attn_version)
+
     def _v_up_proj_and_o_proj(self, x):
         if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
             if is_fp8(self.W_UV_O):
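The hunk above picks the callable once, in the constructor. A minimal sketch of the same functools.partial pattern, with hypothetical stand-ins for the two backends (the real functions live in flash_attn and vllm_flash_attn):

import functools

def flash_attn_varlen(q, k, v, causal=False):
    # Stand-in for flash_attn's signature on ROCm (no fa_version).
    return "flash_attn path"

def vllm_flash_attn_varlen(q, k, v, causal=False, fa_version=2):
    # Stand-in for vllm_flash_attn's signature (extra FA2/FA3 selector).
    return f"vllm_flash_attn path, FA{fa_version}"

def bind_varlen_func(vllm_flash_attn_version):
    """Return one callable with a backend-agnostic calling convention."""
    if vllm_flash_attn_version is not None:
        # vllm_flash_attn build: pre-bind the FA version so call sites
        # never mention it.
        return functools.partial(vllm_flash_attn_varlen,
                                 fa_version=vllm_flash_attn_version)
    # ROCm / upstream flash_attn build: use the function as-is.
    return flash_attn_varlen

Binding the keyword in __init__ keeps every attention call site free of backend conditionals, as the remaining hunks show.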
@@ -487,7 +497,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
         v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
                                            value=0)
 
-        attn_output = flash_attn_varlen_func(
+        attn_output = self.flash_attn_varlen_func(
             q=q,
             k=k,
             v=v_padded,
@@ -497,7 +507,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
             max_seqlen_k=max_prefill_seq_len,
             softmax_scale=self.scale,
             causal=True,
-            fa_version=self.vllm_flash_attn_version,
         )
         attn_output = attn_output\
             .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\
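With the keyword bound (or absent) up front, the prefill call site above can drop fa_version entirely and still run on both builds. A hedged usage example, reusing the bind_varlen_func stand-ins from the sketch earlier (tensor arguments replaced by None for brevity):

# CUDA-like build: FA3 is selected by the partial, not by the call site.
cuda_func = bind_varlen_func(vllm_flash_attn_version=3)
print(cuda_func(q=None, k=None, v=None, causal=True))   # vllm_flash_attn path, FA3

# ROCm-like build: the same call, and no fa_version ever reaches flash_attn.
rocm_func = bind_varlen_func(vllm_flash_attn_version=None)
print(rocm_func(q=None, k=None, v=None, causal=True))   # flash_attn path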