Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-26 04:09:37 +08:00)
[ROCm][AMD] Use pytorch sdpa math backend to do naive attention (#4965)

commit c96fc06747
parent b3376e5c76
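The commit replaces the ROCm backend's hand-rolled naive attention fallback with torch.nn.functional.scaled_dot_product_attention (SDPA), pinned to its math (reference) backend so that no flash or memory-efficient kernel is required on the affected AMD hardware. As a minimal standalone sketch of that pattern, separate from the diff below (shapes are illustrative; torch.backends.cuda.sdp_kernel is the PyTorch 2.x context manager the diff uses, superseded in later releases by torch.nn.attention.sdpa_kernel):

import torch
import torch.nn.functional as F

# Illustrative shapes: 8 heads over a 16-token sequence, head_dim 64.
# SDPA treats every leading dim as batch, so heads come first here.
q = torch.randn(8, 16, 64)
k = torch.randn(8, 16, 64)
v = torch.randn(8, 16, 64)

# Pin SDPA to the math (reference) backend; the flash and
# memory-efficient kernels are explicitly disabled, which is the
# point of the commit on GPUs where they are unsupported.
with torch.backends.cuda.sdp_kernel(enable_math=True,
                                    enable_flash=False,
                                    enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v,
                                         dropout_p=0.0,
                                         is_causal=True,
                                         scale=64 ** -0.5)
print(out.shape)  # torch.Size([8, 16, 64])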
--- a/vllm/attention/backends/rocm_flash_attn.py
+++ b/vllm/attention/backends/rocm_flash_attn.py
@@ -247,7 +247,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                     self.use_naive_attn = True
 
             if self.use_naive_attn:
-                self.attn_func = _naive_attention
+                self.attn_func = _sdpa_attention
                 logger.debug("Using naive attention in ROCmBackend")
 
     def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
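Because forward() always dispatches through self.attn_func, swapping the fallback is this one assignment; the extra parameters the new helper needs (num_tokens, num_heads, head_size) are threaded through at the call site in the next hunk. A hypothetical reduction of that dispatch pattern (not the vLLM class):

def _sdpa_attention(query, key, value):
    # Stand-in for the real helper; see the last hunk below.
    return query

class ROCmFlashAttentionImplSketch:
    """Hypothetical reduction of the dispatch logic, not the vLLM class."""

    def __init__(self, use_naive_attn: bool):
        self.use_naive_attn = use_naive_attn
        if self.use_naive_attn:
            # One assignment changes the fallback; every call site in
            # forward() goes through self.attn_func.
            self.attn_func = _sdpa_attention

impl = ROCmFlashAttentionImplSketch(use_naive_attn=True)
assert impl.attn_func is _sdpa_attention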
@@ -342,11 +342,18 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                     # Interleave for MQA workaround.
                     key = self.repeat_kv(key, self.num_queries_per_kv)
                     value = self.repeat_kv(value, self.num_queries_per_kv)
+                    query = query.movedim(0, query.dim() - 2)
+                    key = key.movedim(0, key.dim() - 2)
+                    value = value.movedim(0, value.dim() - 2)
+                    # sdpa math backend attention
                     out = self.attn_func(
                         query,
                         key,
                         value,
                         prefill_meta.seq_lens,
+                        num_tokens,
+                        self.num_heads,
+                        self.head_size,
                         self.scale,
                     )
                 else:
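A note on the three added movedim calls: vLLM's prefill tensors are token-major, shaped (num_tokens, num_heads, head_size), while SDPA expects the sequence dimension second-to-last and treats leading dimensions as batch, so heads must come first. A shape-only sketch of the round trip (dimension sizes are illustrative):

import torch

num_tokens, num_heads, head_size = 10, 4, 8
q = torch.randn(num_tokens, num_heads, head_size)  # token-major layout

# Move tokens to the second-to-last dim: (num_heads, num_tokens, head_size).
q_sdpa = q.movedim(0, q.dim() - 2)
assert q_sdpa.shape == (num_heads, num_tokens, head_size)

# The inverse, mirroring .movedim(query.dim() - 2, 0) on the SDPA output
# in _sdpa_attention below; it restores the token-major layout.
q_back = q_sdpa.movedim(q_sdpa.dim() - 2, 0)
assert torch.equal(q_back, q)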
@@ -402,45 +409,34 @@ class ROCmFlashAttentionImpl(AttentionImpl):
         return output.view(num_tokens, hidden_size)
 
 
-def _naive_attention(
+def _sdpa_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
     seq_lens: List[int],
+    num_tokens: int,
+    num_heads: int,
+    head_size: int,
     scale: float,
 ) -> torch.Tensor:
-    output = torch.empty_like(query)
     start = 0
-    for _, seq_len in enumerate(seq_lens):
+    output = torch.empty((num_tokens, num_heads, head_size),
+                         dtype=query.dtype,
+                         device=query.device)
+
+    for seq_len in seq_lens:
         end = start + seq_len
-        out = _naive_masked_attention(
-            query[start:end],
-            key[start:end],
-            value[start:end],
-            scale,
-        )
-        # TODO(woosuk): Unnecessary copy. Optimize.
-        output[start:end].copy_(out)
-        start += seq_len
+        with torch.backends.cuda.sdp_kernel(enable_math=True,
+                                            enable_flash=False,
+                                            enable_mem_efficient=False):
+            sub_out = torch.nn.functional.scaled_dot_product_attention(
+                query[:, start:end, :],
+                key[:, start:end, :],
+                value[:, start:end, :],
+                dropout_p=0.0,
+                is_causal=True,
+                scale=scale).movedim(query.dim() - 2, 0)
+            output[start:end, :, :] = sub_out
+            start = end
 
     return output
-
-
-def _naive_masked_attention(
-    query: torch.Tensor,
-    key: torch.Tensor,
-    value: torch.Tensor,
-    scale: float,
-) -> torch.Tensor:
-    seq_len, head_size, head_dim = query.shape
-    attn_mask = torch.triu(torch.ones(seq_len,
-                                      seq_len,
-                                      dtype=query.dtype,
-                                      device=query.device),
-                           diagonal=1)
-    attn_mask = attn_mask * torch.finfo(query.dtype).min
-    attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
-    attn_weights = attn_weights + attn_mask.float()
-    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
-    out = torch.einsum("hqk,khd->qhd", attn_weights, value)
-    return out
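Why _naive_masked_attention could be deleted outright: the SDPA math backend computes the same causally masked softmax product. A quick equivalence check in float32 (illustrative shapes; the removed code used torch.finfo(dtype).min instead of -inf for the mask, which yields the same weights after softmax):

import torch
import torch.nn.functional as F

seq_len, num_heads, head_size = 5, 2, 4
scale = head_size ** -0.5
q = torch.randn(seq_len, num_heads, head_size)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Reference: the removed _naive_masked_attention, condensed.
mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
w = scale * torch.einsum("qhd,khd->hqk", q, k) + mask
ref = torch.einsum("hqk,khd->qhd", torch.softmax(w, dim=-1), v)

# SDPA on the same data, heads moved to the batch dim and back.
out = F.scaled_dot_product_attention(q.movedim(0, 1),
                                     k.movedim(0, 1),
                                     v.movedim(0, 1),
                                     dropout_p=0.0,
                                     is_causal=True,
                                     scale=scale).movedim(1, 0)
torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-4)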