mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-18 08:37:00 +08:00
[ROCm][AMD] Use pytorch sdpa math backend to do naive attention (#4965)
This commit is contained in:
parent
b3376e5c76
commit
c96fc06747
@ -247,7 +247,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
|||||||
self.use_naive_attn = True
|
self.use_naive_attn = True
|
||||||
|
|
||||||
if self.use_naive_attn:
|
if self.use_naive_attn:
|
||||||
self.attn_func = _naive_attention
|
self.attn_func = _sdpa_attention
|
||||||
logger.debug("Using naive attention in ROCmBackend")
|
logger.debug("Using naive attention in ROCmBackend")
|
||||||
|
|
||||||
def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
|
def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||||
@ -342,11 +342,18 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
|||||||
# Interleave for MQA workaround.
|
# Interleave for MQA workaround.
|
||||||
key = self.repeat_kv(key, self.num_queries_per_kv)
|
key = self.repeat_kv(key, self.num_queries_per_kv)
|
||||||
value = self.repeat_kv(value, self.num_queries_per_kv)
|
value = self.repeat_kv(value, self.num_queries_per_kv)
|
||||||
|
query = query.movedim(0, query.dim() - 2)
|
||||||
|
key = key.movedim(0, key.dim() - 2)
|
||||||
|
value = value.movedim(0, value.dim() - 2)
|
||||||
|
# sdpa math backend attention
|
||||||
out = self.attn_func(
|
out = self.attn_func(
|
||||||
query,
|
query,
|
||||||
key,
|
key,
|
||||||
value,
|
value,
|
||||||
prefill_meta.seq_lens,
|
prefill_meta.seq_lens,
|
||||||
|
num_tokens,
|
||||||
|
self.num_heads,
|
||||||
|
self.head_size,
|
||||||
self.scale,
|
self.scale,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@ -402,45 +409,34 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
|||||||
return output.view(num_tokens, hidden_size)
|
return output.view(num_tokens, hidden_size)
|
||||||
|
|
||||||
|
|
||||||
def _naive_attention(
|
def _sdpa_attention(
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
key: torch.Tensor,
|
key: torch.Tensor,
|
||||||
value: torch.Tensor,
|
value: torch.Tensor,
|
||||||
seq_lens: List[int],
|
seq_lens: List[int],
|
||||||
|
num_tokens: int,
|
||||||
|
num_heads: int,
|
||||||
|
head_size: int,
|
||||||
scale: float,
|
scale: float,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
output = torch.empty_like(query)
|
|
||||||
start = 0
|
start = 0
|
||||||
for _, seq_len in enumerate(seq_lens):
|
output = torch.empty((num_tokens, num_heads, head_size),
|
||||||
|
dtype=query.dtype,
|
||||||
|
device=query.device)
|
||||||
|
|
||||||
|
for seq_len in seq_lens:
|
||||||
end = start + seq_len
|
end = start + seq_len
|
||||||
out = _naive_masked_attention(
|
with torch.backends.cuda.sdp_kernel(enable_math=True,
|
||||||
query[start:end],
|
enable_flash=False,
|
||||||
key[start:end],
|
enable_mem_efficient=False):
|
||||||
value[start:end],
|
sub_out = torch.nn.functional.scaled_dot_product_attention(
|
||||||
scale,
|
query[:, start:end, :],
|
||||||
)
|
key[:, start:end, :],
|
||||||
# TODO(woosuk): Unnecessary copy. Optimize.
|
value[:, start:end, :],
|
||||||
output[start:end].copy_(out)
|
dropout_p=0.0,
|
||||||
start += seq_len
|
is_causal=True,
|
||||||
|
scale=scale).movedim(query.dim() - 2, 0)
|
||||||
|
output[start:end, :, :] = sub_out
|
||||||
|
start = end
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
def _naive_masked_attention(
|
|
||||||
query: torch.Tensor,
|
|
||||||
key: torch.Tensor,
|
|
||||||
value: torch.Tensor,
|
|
||||||
scale: float,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
seq_len, head_size, head_dim = query.shape
|
|
||||||
attn_mask = torch.triu(torch.ones(seq_len,
|
|
||||||
seq_len,
|
|
||||||
dtype=query.dtype,
|
|
||||||
device=query.device),
|
|
||||||
diagonal=1)
|
|
||||||
attn_mask = attn_mask * torch.finfo(query.dtype).min
|
|
||||||
attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
|
|
||||||
attn_weights = attn_weights + attn_mask.float()
|
|
||||||
attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
|
|
||||||
out = torch.einsum("hqk,khd->qhd", attn_weights, value)
|
|
||||||
return out
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user