diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py
index e92e6c5e2dc8d..9294068c64d1a 100644
--- a/vllm/attention/backends/rocm_flash_attn.py
+++ b/vllm/attention/backends/rocm_flash_attn.py
@@ -247,7 +247,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                     self.use_naive_attn = True
 
             if self.use_naive_attn:
-                self.attn_func = _naive_attention
+                self.attn_func = _sdpa_attention
                 logger.debug("Using naive attention in ROCmBackend")
 
     def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
@@ -342,11 +342,18 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                     # Interleave for MQA workaround.
                     key = self.repeat_kv(key, self.num_queries_per_kv)
                     value = self.repeat_kv(value, self.num_queries_per_kv)
+                query = query.movedim(0, query.dim() - 2)
+                key = key.movedim(0, key.dim() - 2)
+                value = value.movedim(0, value.dim() - 2)
+                # sdpa math backend attention
                 out = self.attn_func(
                     query,
                     key,
                     value,
                     prefill_meta.seq_lens,
+                    num_tokens,
+                    self.num_heads,
+                    self.head_size,
                     self.scale,
                 )
             else:
@@ -402,45 +409,34 @@ class ROCmFlashAttentionImpl(AttentionImpl):
         return output.view(num_tokens, hidden_size)
 
 
-def _naive_attention(
+def _sdpa_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
     seq_lens: List[int],
+    num_tokens: int,
+    num_heads: int,
+    head_size: int,
     scale: float,
 ) -> torch.Tensor:
-    output = torch.empty_like(query)
     start = 0
-    for _, seq_len in enumerate(seq_lens):
+    output = torch.empty((num_tokens, num_heads, head_size),
+                         dtype=query.dtype,
+                         device=query.device)
+
+    for seq_len in seq_lens:
         end = start + seq_len
-        out = _naive_masked_attention(
-            query[start:end],
-            key[start:end],
-            value[start:end],
-            scale,
-        )
-        # TODO(woosuk): Unnecessary copy. Optimize.
-        output[start:end].copy_(out)
-        start += seq_len
+        with torch.backends.cuda.sdp_kernel(enable_math=True,
+                                            enable_flash=False,
+                                            enable_mem_efficient=False):
+            sub_out = torch.nn.functional.scaled_dot_product_attention(
+                query[:, start:end, :],
+                key[:, start:end, :],
+                value[:, start:end, :],
+                dropout_p=0.0,
+                is_causal=True,
+                scale=scale).movedim(query.dim() - 2, 0)
+            output[start:end, :, :] = sub_out
+            start = end
 
     return output
-
-
-def _naive_masked_attention(
-    query: torch.Tensor,
-    key: torch.Tensor,
-    value: torch.Tensor,
-    scale: float,
-) -> torch.Tensor:
-    seq_len, head_size, head_dim = query.shape
-    attn_mask = torch.triu(torch.ones(seq_len,
-                                      seq_len,
-                                      dtype=query.dtype,
-                                      device=query.device),
-                           diagonal=1)
-    attn_mask = attn_mask * torch.finfo(query.dtype).min
-    attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
-    attn_weights = attn_weights + attn_mask.float()
-    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
-    out = torch.einsum("hqk,khd->qhd", attn_weights, value)
-    return out
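
For reference, below is a minimal standalone sketch (not part of the patch) of the same per-sequence causal SDPA pattern the diff introduces, alongside the math of the removed naive helper so the two paths can be compared numerically. The helper names (`sdpa_prefill_reference`, `naive_masked_attention`) and the toy shapes are made up for illustration; the sketch omits the `torch.backends.cuda.sdp_kernel` context manager used in the patch so it runs on CPU, and it assumes PyTorch 2.1+ for the `scale=` keyword of `scaled_dot_product_attention`.

```python
from typing import List

import torch


def sdpa_prefill_reference(query: torch.Tensor,
                           key: torch.Tensor,
                           value: torch.Tensor,
                           seq_lens: List[int],
                           scale: float) -> torch.Tensor:
    """Per-sequence causal SDPA over packed (num_tokens, num_heads, head_size)
    tensors, mirroring the _sdpa_attention loop in the patch."""
    num_tokens, num_heads, head_size = query.shape
    # Move the token dim next to last so SDPA sees (num_heads, seq, head_size),
    # the same transposition the patch performs with movedim before the call.
    query = query.movedim(0, query.dim() - 2)
    key = key.movedim(0, key.dim() - 2)
    value = value.movedim(0, value.dim() - 2)
    output = torch.empty((num_tokens, num_heads, head_size),
                         dtype=query.dtype,
                         device=query.device)
    start = 0
    for seq_len in seq_lens:
        end = start + seq_len
        # is_causal=True applies the same lower-triangular mask the removed
        # _naive_masked_attention built by hand with torch.triu.
        sub_out = torch.nn.functional.scaled_dot_product_attention(
            query[:, start:end, :],
            key[:, start:end, :],
            value[:, start:end, :],
            dropout_p=0.0,
            is_causal=True,
            scale=scale).movedim(query.dim() - 2, 0)
        output[start:end, :, :] = sub_out
        start = end
    return output


def naive_masked_attention(q, k, v, scale):
    # Same math as the removed _naive_masked_attention helper:
    # q, k, v are (seq_len, num_heads, head_size).
    seq_len = q.shape[0]
    mask = torch.triu(torch.ones(seq_len, seq_len, dtype=q.dtype), diagonal=1)
    mask = mask * torch.finfo(q.dtype).min
    w = scale * torch.einsum("qhd,khd->hqk", q, k).float()
    w = torch.softmax(w + mask.float(), dim=-1).to(v.dtype)
    return torch.einsum("hqk,khd->qhd", w, v)


if __name__ == "__main__":
    torch.manual_seed(0)
    seq_lens = [3, 5]
    num_heads, head_size = 4, 8
    num_tokens = sum(seq_lens)
    scale = head_size**-0.5
    q = torch.randn(num_tokens, num_heads, head_size)
    k = torch.randn(num_tokens, num_heads, head_size)
    v = torch.randn(num_tokens, num_heads, head_size)

    sdpa_out = sdpa_prefill_reference(q, k, v, seq_lens, scale)

    # Apply the old per-sequence naive path for comparison.
    bounds, s = [], 0
    for n in seq_lens:
        bounds.append((s, s + n))
        s += n
    naive_out = torch.cat([
        naive_masked_attention(q[a:b], k[a:b], v[a:b], scale)
        for a, b in bounds
    ])

    print(sdpa_out.shape)  # torch.Size([8, 4, 8])
    # Expected to be tiny (floating-point rounding only).
    print((sdpa_out - naive_out).abs().max())
```

The sketch keeps the packed-token layout of the backend: sequences stay concatenated along the token dimension, and each `query[:, start:end, :]` slice is attended independently with a causal mask, which is what allows the math-backend SDPA call to replace the hand-rolled `triu` masking without changing results.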