[ROCm][AMD][Model] Adding ALiBi slopes support in ROCm Triton flash attention and naive flash attention (#6043)

Co-authored-by: Hongxia Yang <62075498+hongxiayang@users.noreply.github.com>
Gregory Shtrasberg 2024-07-04 01:19:38 -04:00 committed by GitHub
parent 3dd507083f
commit 56b325e977

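The new `_make_alibi_bias` helper in the diff below builds one additive bias tensor per sequence, shaped (num_heads, seq_len, seq_len), where entry (h, i, j) is slope_h * (j - i); when `make_attn_mask=True` it also folds an upper-triangular -inf causal mask into the bias. A minimal standalone sketch of the same construction (the helper name and defaults here are illustrative, not part of the commit):

    import torch

    def alibi_bias_sketch(alibi_slopes: torch.Tensor,
                          seq_len: int,
                          causal: bool = True,
                          dtype: torch.dtype = torch.float16) -> torch.Tensor:
        # Relative positions: entry (i, j) is j - i, as in the ALiBi paper.
        pos = torch.arange(seq_len, dtype=torch.float32)
        bias = pos[None, :] - pos[:, None]                 # (seq_len, seq_len)
        # One copy per head, scaled by that head's slope.
        num_heads = alibi_slopes.shape[0]
        bias = bias[None, :, :].repeat(num_heads, 1, 1)    # (num_heads, seq_len, seq_len)
        bias = bias * alibi_slopes[:, None, None].float()
        if causal:
            # Fold the causal mask in: -inf above the diagonal, 0 elsewhere.
            bias = bias + torch.full((seq_len, seq_len), float("-inf")).triu(diagonal=1)
        return bias.to(dtype)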

@@ -166,6 +166,37 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
return self._cached_decode_metadata
def _make_alibi_bias(alibi_slopes: torch.Tensor,
dtype: torch.dtype,
seq_lens: Optional[List[int]],
make_attn_mask: bool = True) -> List[torch.Tensor]:
attn_biases = []
if seq_lens:
for seq_len in seq_lens:
bias = torch.arange(seq_len, dtype=dtype)
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(seq_len, 1)`
# here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi
# paper.
bias = bias[None, :] - bias[:, None]
num_heads = alibi_slopes.shape[0]
bias = bias[None, :].repeat(
(num_heads, 1, 1)).to(alibi_slopes.device)
bias.mul_(alibi_slopes[:, None, None])
if make_attn_mask:
inf_mask = torch.empty(
(1, seq_len, seq_len),
dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1).to(
alibi_slopes.device)
attn_biases.append((bias + inf_mask).to(dtype))
else:
attn_biases.append(bias.to(dtype))
return attn_biases
class ROCmFlashAttentionImpl(AttentionImpl):
"""
If the input tensors contain prompt tokens, the layout is as follows:
@@ -324,7 +355,14 @@ class ROCmFlashAttentionImpl(AttentionImpl):
# triton attention
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
attn_masks = None
if self.use_triton_flash_attn:
if self.alibi_slopes is not None:
attn_masks = _make_alibi_bias(
self.alibi_slopes,
query.dtype,
attn_metadata.seq_lens,
make_attn_mask=False) # type: ignore
out, _ = self.attn_func(
query,
key,
@@ -336,12 +374,20 @@ class ROCmFlashAttentionImpl(AttentionImpl):
prefill_meta.max_prefill_seq_len,
True,
self.scale,
attn_masks[0][None]
if attn_masks is not None else None,
)
elif self.use_naive_attn:
if self.num_kv_heads != self.num_heads:
# Interleave for MQA workaround.
key = self.repeat_kv(key, self.num_queries_per_kv)
value = self.repeat_kv(value, self.num_queries_per_kv)
if self.alibi_slopes is not None:
attn_masks = _make_alibi_bias(
self.alibi_slopes,
query.dtype,
attn_metadata.seq_lens,
make_attn_mask=True) # type: ignore
query = query.movedim(0, query.dim() - 2)
key = key.movedim(0, key.dim() - 2)
value = value.movedim(0, value.dim() - 2)
@@ -355,6 +401,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
self.num_heads,
self.head_size,
self.scale,
attn_masks,
)
else:
out = self.attn_func(
@@ -418,13 +465,14 @@ def _sdpa_attention(
num_heads: int,
head_size: int,
scale: float,
attn_masks: Optional[List[torch.Tensor]] = None,
) -> torch.Tensor:
start = 0
output = torch.empty((num_tokens, num_heads, head_size),
dtype=query.dtype,
device=query.device)
for seq_len in seq_lens:
for i, seq_len in enumerate(seq_lens):
end = start + seq_len
with torch.backends.cuda.sdp_kernel(enable_math=True,
enable_flash=False,
@@ -434,7 +482,8 @@ def _sdpa_attention(
key[:, start:end, :],
value[:, start:end, :],
dropout_p=0.0,
is_causal=True,
is_causal=attn_masks is None,
attn_mask=attn_masks[i] if attn_masks else None,
scale=scale).movedim(query.dim() - 2, 0)
output[start:end, :, :] = sub_out
start = end
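A note on how the two call sites consume these masks: the Triton path requests the raw bias only (make_attn_mask=False) and passes its own causal flag to the kernel, while the naive path builds fully formed masks (make_attn_mask=True) for torch's scaled_dot_product_attention and therefore sets is_causal only when no mask is supplied. A small sketch of that second interaction (shapes and slope values chosen purely for illustration):

    import torch
    import torch.nn.functional as F

    num_heads, seq_len, head_size = 8, 16, 64
    q = torch.randn(num_heads, seq_len, head_size)
    k = torch.randn(num_heads, seq_len, head_size)
    v = torch.randn(num_heads, seq_len, head_size)

    # Additive mask that already contains both the ALiBi bias and the -inf
    # causal part, shaped (num_heads, seq_len, seq_len).
    slopes = torch.tensor([2.0 ** -(i + 1) for i in range(num_heads)])
    pos = torch.arange(seq_len, dtype=torch.float32)
    mask = (pos[None, :] - pos[:, None])[None].repeat(num_heads, 1, 1) * slopes[:, None, None]
    mask = mask + torch.full((seq_len, seq_len), float("-inf")).triu(diagonal=1)

    # With an explicit mask the causal behaviour lives in the mask itself, so
    # is_causal stays False, mirroring `is_causal=attn_masks is None` above.
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, is_causal=False)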