[ROCm][AMD] Use pytorch sdpa math backend to do naive attention (#4965)

Hongxia Yang authored 2024-06-07 22:13:12 -04:00, committed by GitHub
parent b3376e5c76
commit c96fc06747

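Summary of the change: the ROCm backend's naive prefill attention now routes each sequence through torch.nn.functional.scaled_dot_product_attention with only the math backend enabled, instead of a hand-rolled masked einsum. SDPA expects inputs laid out as (..., seq_len, head_dim), which is why the diff moves the token dimension with movedim(0, -2) before the call and moves it back afterwards. The snippet below is a minimal sketch of that pattern, not code from the commit; the toy shapes are made up, and it assumes a PyTorch version with the scale= keyword (2.1 or newer).

    import torch
    import torch.nn.functional as F

    # Hypothetical toy shapes, only to illustrate the layout the diff sets up.
    num_tokens, num_heads, head_size = 16, 8, 64
    scale = head_size ** -0.5

    q = torch.randn(num_tokens, num_heads, head_size)
    k = torch.randn(num_tokens, num_heads, head_size)
    v = torch.randn(num_tokens, num_heads, head_size)

    # SDPA wants (..., seq_len, head_dim): move the token dim next to head_dim,
    # i.e. (num_tokens, num_heads, head_size) -> (num_heads, num_tokens, head_size).
    q, k, v = (t.movedim(0, t.dim() - 2) for t in (q, k, v))

    # Pin the math backend (plain eager attention), as the commit does;
    # flash and mem-efficient kernels stay disabled on this path.
    with torch.backends.cuda.sdp_kernel(enable_math=True,
                                        enable_flash=False,
                                        enable_mem_efficient=False):
        out = F.scaled_dot_product_attention(q, k, v,
                                             dropout_p=0.0,
                                             is_causal=True,
                                             scale=scale)

    # Move the token dim back to the front: (num_tokens, num_heads, head_size).
    out = out.movedim(q.dim() - 2, 0)
    print(out.shape)  # torch.Size([16, 8, 64])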

@@ -247,7 +247,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                     self.use_naive_attn = True
 
             if self.use_naive_attn:
-                self.attn_func = _naive_attention
+                self.attn_func = _sdpa_attention
                 logger.debug("Using naive attention in ROCmBackend")
 
     def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
@@ -342,11 +342,18 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                     # Interleave for MQA workaround.
                     key = self.repeat_kv(key, self.num_queries_per_kv)
                     value = self.repeat_kv(value, self.num_queries_per_kv)
+                    query = query.movedim(0, query.dim() - 2)
+                    key = key.movedim(0, key.dim() - 2)
+                    value = value.movedim(0, value.dim() - 2)
+                    # sdpa math backend attention
                     out = self.attn_func(
                         query,
                         key,
                         value,
                         prefill_meta.seq_lens,
+                        num_tokens,
+                        self.num_heads,
+                        self.head_size,
                         self.scale,
                     )
                 else:
@@ -402,45 +409,34 @@ class ROCmFlashAttentionImpl(AttentionImpl):
         return output.view(num_tokens, hidden_size)
 
 
-def _naive_attention(
+def _sdpa_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
     seq_lens: List[int],
+    num_tokens: int,
+    num_heads: int,
+    head_size: int,
     scale: float,
 ) -> torch.Tensor:
-    output = torch.empty_like(query)
     start = 0
-    for _, seq_len in enumerate(seq_lens):
+    output = torch.empty((num_tokens, num_heads, head_size),
+                         dtype=query.dtype,
+                         device=query.device)
+
+    for seq_len in seq_lens:
         end = start + seq_len
-        out = _naive_masked_attention(
-            query[start:end],
-            key[start:end],
-            value[start:end],
-            scale,
-        )
-        # TODO(woosuk): Unnecessary copy. Optimize.
-        output[start:end].copy_(out)
-        start += seq_len
+        with torch.backends.cuda.sdp_kernel(enable_math=True,
+                                            enable_flash=False,
+                                            enable_mem_efficient=False):
+            sub_out = torch.nn.functional.scaled_dot_product_attention(
+                query[:, start:end, :],
+                key[:, start:end, :],
+                value[:, start:end, :],
+                dropout_p=0.0,
+                is_causal=True,
+                scale=scale).movedim(query.dim() - 2, 0)
+            output[start:end, :, :] = sub_out
+            start = end
 
     return output
-
-
-def _naive_masked_attention(
-    query: torch.Tensor,
-    key: torch.Tensor,
-    value: torch.Tensor,
-    scale: float,
-) -> torch.Tensor:
-    seq_len, head_size, head_dim = query.shape
-    attn_mask = torch.triu(torch.ones(seq_len,
-                                      seq_len,
-                                      dtype=query.dtype,
-                                      device=query.device),
-                           diagonal=1)
-    attn_mask = attn_mask * torch.finfo(query.dtype).min
-    attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
-    attn_weights = attn_weights + attn_mask.float()
-    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
-    out = torch.einsum("hqk,khd->qhd", attn_weights, value)
-    return out
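
For reference, the removed _naive_masked_attention and the new SDPA math-backend call compute the same causal attention up to floating-point noise. Below is a small self-contained check along those lines; it is my own sketch, not part of the commit, the shapes, seed, and tolerance are arbitrary, and the scale= keyword again assumes PyTorch 2.1+.

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    seq_len, num_heads, head_size = 12, 4, 32
    scale = head_size ** -0.5
    q = torch.randn(seq_len, num_heads, head_size)
    k = torch.randn(seq_len, num_heads, head_size)
    v = torch.randn(seq_len, num_heads, head_size)

    # Old path: explicit causal mask plus einsum softmax, as in the removed helper.
    mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1) * torch.finfo(q.dtype).min
    w = scale * torch.einsum("qhd,khd->hqk", q, k).float() + mask.float()
    ref = torch.einsum("hqk,khd->qhd", torch.softmax(w, dim=-1).to(v.dtype), v)

    # New path: SDPA with is_causal=True on a (heads, tokens, head_dim) layout
    # (backend pinning omitted for this CPU-only check).
    out = F.scaled_dot_product_attention(
        q.movedim(0, 1), k.movedim(0, 1), v.movedim(0, 1),
        dropout_p=0.0, is_causal=True, scale=scale).movedim(1, 0)

    print(torch.allclose(ref, out, atol=1e-5))  # expect True, up to numerical noise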