[Bugfix] Massage MLA's usage of flash attn for ROCm (#13310)

Tyler Michael Smith 2025-02-15 00:33:25 -05:00 committed by GitHub
parent 579d7a63b2
commit 97a3d6d995


@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
+import functools
 from abc import abstractmethod
 from dataclasses import dataclass
 from typing import Any, Dict, Generic, List, Optional, Tuple
@@ -183,6 +184,15 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
         self.o_proj = o_proj
         self.vllm_flash_attn_version = get_flash_attn_version()
+        # Handle the differences between the flash_attn_varlen from flash_attn
+        # and the one from vllm_flash_attn. The former is used on ROCm and the
+        # latter has an additional parameter to control FA2 vs FA3
+        self.flash_attn_varlen_func = flash_attn_varlen_func
+        if self.vllm_flash_attn_version is not None:
+            self.flash_attn_varlen_func = \
+                functools.partial(flash_attn_varlen_func,
+                                  fa_version=self.vllm_flash_attn_version)

     def _v_up_proj_and_o_proj(self, x):
         if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
             if is_fp8(self.W_UV_O):
@@ -487,7 +497,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
         v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
                                            value=0)
-        attn_output = flash_attn_varlen_func(
+        attn_output = self.flash_attn_varlen_func(
             q=q,
             k=k,
             v=v_padded,
@@ -497,7 +507,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
             max_seqlen_k=max_prefill_seq_len,
             softmax_scale=self.scale,
             causal=True,
-            fa_version=self.vllm_flash_attn_version,
         )
         attn_output = attn_output\
             .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\