From 363e6a950f2f14839b6d11d8d0094889bf500f81 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Wed, 10 Apr 2024 08:02:40 +0000
Subject: [PATCH] Fix flashattn

---
 vllm/attention/backends/pallas.py | 29 ++++++++++++++++-------------
 1 file changed, 16 insertions(+), 13 deletions(-)

diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py
index 242a50f9cf7a7..80faac2ec63fc 100644
--- a/vllm/attention/backends/pallas.py
+++ b/vllm/attention/backends/pallas.py
@@ -3,10 +3,10 @@ from dataclasses import dataclass
 from typing import Dict, List, Optional, Tuple, Type
 
 import torch
+from torch_xla.experimental.custom_kernel import flash_attention
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata)
-from torch_xla.experimental.custom_kernel import flash_attention
 
 
 class PallasAttentionBackend(AttentionBackend):
@@ -52,8 +52,8 @@ class PallasAttentionMetadata(AttentionMetadata):
     # Currently, input sequences can only contain all prompts
     # or all decoding. True if all sequences are prompts.
     is_prompt: bool
-
-
+    slot_mapping: torch.Tensor
+    block_tables: torch.Tensor
 
 
 class PallasAttentionImpl(AttentionImpl):
@@ -74,12 +74,10 @@ class PallasAttentionImpl(AttentionImpl):
 
         if sliding_window is not None:
             raise NotImplementedError(
-                "Sliding window is not supported on TPU backend."
-            )
+                "Sliding window is not supported on TPU backend.")
         if alibi_slopes is not None:
             raise NotImplementedError(
-                "Alibi slopes are not supported on TPU backend."
-            )
+                "Alibi slopes are not supported on TPU backend.")
 
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
@@ -108,7 +106,8 @@ class PallasAttentionImpl(AttentionImpl):
         # Reshape the query, key, and value tensors.
         query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
         key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size)
-        value = value.view(batch_size, seq_len, self.num_kv_heads, self.head_size)
+        value = value.view(batch_size, seq_len, self.num_kv_heads,
+                           self.head_size)
 
         if kv_cache is not None:
             key_cache, value_cache = kv_cache[0], kv_cache[1]
@@ -116,10 +115,14 @@ class PallasAttentionImpl(AttentionImpl):
             # Reshape the input keys and values and store them in the cache.
             # If kv_cache is not provided, the new key and value tensors are
             # not cached. This happens during the initial memory profiling run.
-            key_cache.index_copy_(dim=2, index=attn_metadata.slot_mapping, source=key)
-            value_cache.index_copy_(dim=2, index=attn_metadata.slot_mapping, source=value)
+            key_cache.index_copy_(dim=2,
+                                  index=attn_metadata.slot_mapping,
+                                  source=key)
+            value_cache.index_copy_(dim=2,
+                                    index=attn_metadata.slot_mapping,
+                                    source=value)
 
-        if True:  # FIXME
+        if attn_metadata.is_prompt:
             # Prompt run.
             if kv_cache is None or attn_metadata.block_tables.numel() == 0:
                 # normal attention
@@ -130,6 +133,7 @@ class PallasAttentionImpl(AttentionImpl):
                     causal=True,
                 )
                 output = output.permute(0, 2, 1, 3)
+                output = output.reshape(batch_size, seq_len, hidden_size)
             else:
                 # prefix-enabled attention
                 raise NotImplementedError(
@@ -139,5 +143,4 @@ class PallasAttentionImpl(AttentionImpl):
             # Decoding run.
             pass
 
-        # Reshape the output tensor.
-        return output.view(batch_size, seq_len, hidden_size)
+        return output
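
Note for reviewers: the hunks above are mostly shape bookkeeping around the Pallas
flash-attention kernel, which consumes and produces [batch, heads, seq, head_dim]
tensors while the backend's inputs and outputs are [batch, seq, hidden]. The sketch
below walks through that layout round-trip. It is a minimal CPU illustration, not
part of the patch: torch.nn.functional.scaled_dot_product_attention stands in for
the TPU-only flash_attention kernel (both take [B, H, S, D] inputs), and all sizes
are made-up example values.

    import torch
    import torch.nn.functional as F

    # Illustrative sizes only; the real values come from the model config.
    batch_size, seq_len, num_heads, head_size = 2, 16, 4, 32
    hidden_size = num_heads * head_size

    # Inputs as the backend's forward() receives them: [B, S, hidden].
    query = torch.randn(batch_size, seq_len, hidden_size)
    key = torch.randn(batch_size, seq_len, hidden_size)
    value = torch.randn(batch_size, seq_len, hidden_size)

    # Split heads, as the .view() calls in the patch do: [B, S, H, D].
    q = query.view(batch_size, seq_len, num_heads, head_size)
    k = key.view(batch_size, seq_len, num_heads, head_size)
    v = value.view(batch_size, seq_len, num_heads, head_size)

    # The kernel expects [B, H, S, D], hence permute(0, 2, 1, 3) on the way in.
    # scaled_dot_product_attention is a CPU stand-in for flash_attention here.
    out = F.scaled_dot_product_attention(
        q.permute(0, 2, 1, 3),
        k.permute(0, 2, 1, 3),
        v.permute(0, 2, 1, 3),
        is_causal=True,
    )

    # Undo the permute and flatten heads, matching the permute + reshape
    # lines this patch adds before returning: back to [B, S, hidden].
    out = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, hidden_size)
    assert out.shape == (batch_size, seq_len, hidden_size)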