[Bugfix] Massage MLA's usage of flash attn for RoCM (#13310)

Tyler Michael Smith 2025-02-15 00:33:25 -05:00 committed by GitHub
parent 579d7a63b2
commit 97a3d6d995

@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
+import functools
 from abc import abstractmethod
 from dataclasses import dataclass
 from typing import Any, Dict, Generic, List, Optional, Tuple
@@ -183,6 +184,15 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
         self.o_proj = o_proj
         self.vllm_flash_attn_version = get_flash_attn_version()
 
+        # Handle the differences between the flash_attn_varlen from flash_attn
+        # and the one from vllm_flash_attn. The former is used on RoCM and the
+        # latter has an additional parameter to control FA2 vs FA3
+        self.flash_attn_varlen_func = flash_attn_varlen_func
+        if self.vllm_flash_attn_version is not None:
+            self.flash_attn_varlen_func = \
+                functools.partial(flash_attn_varlen_func,
+                                  fa_version=self.vllm_flash_attn_version)
+
     def _v_up_proj_and_o_proj(self, x):
         if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
             if is_fp8(self.W_UV_O):
@@ -487,7 +497,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
         v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
                                            value=0)
 
-        attn_output = flash_attn_varlen_func(
+        attn_output = self.flash_attn_varlen_func(
             q=q,
             k=k,
             v=v_padded,
@@ -497,7 +507,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
             max_seqlen_k=max_prefill_seq_len,
             softmax_scale=self.scale,
             causal=True,
-            fa_version=self.vllm_flash_attn_version,
         )
         attn_output = attn_output\
             .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\
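
The net effect of the diff: instead of threading fa_version through every call site, the constructor picks the varlen attention entry point once and partially applies the version when it is known, so the prefill code can drop its explicit fa_version= argument. Below is a minimal standalone sketch of that pattern; the two backend functions are hypothetical stand-ins, not the real flash_attn / vllm_flash_attn signatures, and tensor/cu_seqlens arguments are omitted to keep it small.

import functools


# Hypothetical stand-ins for the two backends. On ROCm, flash_attn's varlen
# function takes no fa_version parameter; vllm_flash_attn's accepts
# fa_version to select FA2 vs FA3.
def plain_varlen_func(q, k, v, softmax_scale, causal):
    return "plain backend"


def versioned_varlen_func(q, k, v, softmax_scale, causal, fa_version):
    return f"versioned backend, FA{fa_version}"


class MLAImplSketch:

    def __init__(self, fa_version=None):
        # Default to the signature without fa_version (the flash_attn/ROCm case).
        self.flash_attn_varlen_func = plain_varlen_func
        if fa_version is not None:
            # Bind fa_version once so call sites never need to pass it.
            self.flash_attn_varlen_func = functools.partial(
                versioned_varlen_func, fa_version=fa_version)

    def forward(self, q, k, v, scale):
        # Identical call either way; the constructor already chose the backend.
        return self.flash_attn_varlen_func(q, k, v,
                                           softmax_scale=scale,
                                           causal=True)


print(MLAImplSketch().forward(None, None, None, 1.0))             # plain backend
print(MLAImplSketch(fa_version=3).forward(None, None, None, 1.0)) # versioned backend, FA3

With the partial applied up front, both paths share one call signature, which is what lets the call site in _forward_prefill stay backend-agnostic.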