[ROCm][AMD][Model] Adding alibi slopes support in ROCm triton flash attention and naive flash attention (#6043)
Co-authored-by: Hongxia Yang <62075498+hongxiayang@users.noreply.github.com>
This commit is contained in:
parent 3dd507083f
commit 56b325e977
@@ -166,6 +166,37 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
         return self._cached_decode_metadata
 
 
+def _make_alibi_bias(alibi_slopes: torch.Tensor,
+                     dtype: torch.dtype,
+                     seq_lens: Optional[List[int]],
+                     make_attn_mask: bool = True) -> List[torch.Tensor]:
+    attn_biases = []
+    if seq_lens:
+        for seq_len in seq_lens:
+            bias = torch.arange(seq_len, dtype=dtype)
+            # NOTE(zhuohan): HF uses
+            #     `bias = bias[None, :].repeat(seq_len, 1)`
+            # here. We find that both biases give the same results, but
+            # the bias below more accurately follows the original ALiBi
+            # paper.
+            bias = bias[None, :] - bias[:, None]
+
+            num_heads = alibi_slopes.shape[0]
+            bias = bias[None, :].repeat(
+                (num_heads, 1, 1)).to(alibi_slopes.device)
+            bias.mul_(alibi_slopes[:, None, None])
+            if make_attn_mask:
+                inf_mask = torch.empty(
+                    (1, seq_len, seq_len),
+                    dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1).to(
+                        alibi_slopes.device)
+                attn_biases.append((bias + inf_mask).to(dtype))
+            else:
+                attn_biases.append(bias.to(dtype))
+
+    return attn_biases
+
+
 class ROCmFlashAttentionImpl(AttentionImpl):
     """
     If the input tensors contain prompt tokens, the layout is as follows:
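For orientation, a small usage sketch of the helper added above. It assumes torch is available and that _make_alibi_bias from this diff is in scope; the slope values and sequence lengths are made up for illustration.

import torch

# Two heads with hypothetical slopes, two prompts of length 3 and 5.
slopes = torch.tensor([0.5, 0.25])
seq_lens = [3, 5]

# Bias-only variant (used on the triton path): finite values everywhere.
biases = _make_alibi_bias(slopes, torch.float32, seq_lens, make_attn_mask=False)
print(biases[0].shape)               # torch.Size([2, 3, 3]) -- (num_heads, seq_len, seq_len)
print(torch.isinf(biases[0]).any())  # tensor(False)

# Mask variant (used on the naive/SDPA path): the -inf upper triangle is
# folded in, so the bias doubles as a causal attention mask.
masks = _make_alibi_bias(slopes, torch.float32, seq_lens, make_attn_mask=True)
print(torch.isinf(masks[1]).any())   # tensor(True)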
@@ -324,7 +355,14 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                 # triton attention
                 # When block_tables are not filled, it means q and k are the
                 # prompt, and they have the same length.
+                attn_masks = None
                 if self.use_triton_flash_attn:
+                    if self.alibi_slopes is not None:
+                        attn_masks = _make_alibi_bias(
+                            self.alibi_slopes,
+                            query.dtype,
+                            attn_metadata.seq_lens,
+                            make_attn_mask=False)  # type: ignore
                     out, _ = self.attn_func(
                         query,
                         key,
@@ -336,12 +374,20 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                         prefill_meta.max_prefill_seq_len,
                         True,
                         self.scale,
+                        attn_masks[0][None]
+                        if attn_masks is not None else None,
                     )
                 elif self.use_naive_attn:
                     if self.num_kv_heads != self.num_heads:
                         # Interleave for MQA workaround.
                         key = self.repeat_kv(key, self.num_queries_per_kv)
                         value = self.repeat_kv(value, self.num_queries_per_kv)
+                    if self.alibi_slopes is not None:
+                        attn_masks = _make_alibi_bias(
+                            self.alibi_slopes,
+                            query.dtype,
+                            attn_metadata.seq_lens,
+                            make_attn_mask=True)  # type: ignore
                     query = query.movedim(0, query.dim() - 2)
                     key = key.movedim(0, key.dim() - 2)
                     value = value.movedim(0, value.dim() - 2)
@@ -355,6 +401,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                         self.num_heads,
                         self.head_size,
                         self.scale,
+                        attn_masks,
                     )
                 else:
                     out = self.attn_func(
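Note the asymmetry between the two prefill paths above: the triton call site passes the bias-only tensor with a leading batch dimension (attn_masks[0][None]) and the positional True argument appears to be the kernel's causal flag, so causal masking stays inside the kernel; the naive path hands the full per-sequence mask list to the SDPA helper. A tiny illustrative sketch of the [None] expansion (shapes are hypothetical):

import torch

# Stand-in for attn_masks[0]: one prompt, 2 heads, seq_len 4.
bias = torch.zeros(2, 4, 4)
# The triton call site adds a leading batch dimension before passing it in.
print(bias[None].shape)  # torch.Size([1, 2, 4, 4])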
@@ -418,13 +465,14 @@ def _sdpa_attention(
     num_heads: int,
     head_size: int,
     scale: float,
+    attn_masks: Optional[List[torch.Tensor]] = None,
 ) -> torch.Tensor:
     start = 0
     output = torch.empty((num_tokens, num_heads, head_size),
                          dtype=query.dtype,
                          device=query.device)
 
-    for seq_len in seq_lens:
+    for i, seq_len in enumerate(seq_lens):
         end = start + seq_len
         with torch.backends.cuda.sdp_kernel(enable_math=True,
                                             enable_flash=False,
@@ -434,7 +482,8 @@ def _sdpa_attention(
                 key[:, start:end, :],
                 value[:, start:end, :],
                 dropout_p=0.0,
-                is_causal=True,
+                is_causal=attn_masks is None,
+                attn_mask=attn_masks[i] if attn_masks else None,
                 scale=scale).movedim(query.dim() - 2, 0)
             output[start:end, :, :] = sub_out
             start = end
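A minimal, self-contained sketch of the PyTorch SDPA pattern the modified _sdpa_attention follows. Shapes and values are hypothetical; the point is that torch.nn.functional.scaled_dot_product_attention rejects an explicit attn_mask combined with is_causal=True, which is why the diff switches is_causal to attn_masks is None and relies on the -inf entries already baked into the ALiBi mask.

import torch
import torch.nn.functional as F

num_heads, seq_len, head_size = 2, 4, 8
q = torch.randn(num_heads, seq_len, head_size)
k = torch.randn(num_heads, seq_len, head_size)
v = torch.randn(num_heads, seq_len, head_size)

# Additive ALiBi-style mask with the causal -inf entries already included,
# as produced by _make_alibi_bias(..., make_attn_mask=True).
bias = torch.randn(num_heads, seq_len, seq_len)
causal = torch.full((seq_len, seq_len), float("-inf")).triu_(diagonal=1)
mask = bias + causal

# With an explicit mask, is_causal must be False; without ALiBi, the plain
# causal path is kept unchanged.
out_alibi = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, is_causal=False)
out_plain = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out_alibi.shape, out_plain.shape)  # torch.Size([2, 4, 8]) torch.Size([2, 4, 8])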