Fix flashattn

This commit is contained in:
Woosuk Kwon 2024-04-10 08:02:40 +00:00
parent 696b653193
commit 363e6a950f

View File

@ -3,10 +3,10 @@ from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Type from typing import Dict, List, Optional, Tuple, Type
import torch import torch
from torch_xla.experimental.custom_kernel import flash_attention
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata) AttentionMetadata)
from torch_xla.experimental.custom_kernel import flash_attention
class PallasAttentionBackend(AttentionBackend): class PallasAttentionBackend(AttentionBackend):
@ -52,8 +52,8 @@ class PallasAttentionMetadata(AttentionMetadata):
# Currently, input sequences can only contain all prompts # Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts. # or all decoding. True if all sequences are prompts.
is_prompt: bool is_prompt: bool
slot_mapping: torch.Tensor
block_tables: torch.Tensor
class PallasAttentionImpl(AttentionImpl): class PallasAttentionImpl(AttentionImpl):
@ -74,12 +74,10 @@ class PallasAttentionImpl(AttentionImpl):
if sliding_window is not None: if sliding_window is not None:
raise NotImplementedError( raise NotImplementedError(
"Sliding window is not supported on TPU backend." "Sliding window is not supported on TPU backend.")
)
if alibi_slopes is not None: if alibi_slopes is not None:
raise NotImplementedError( raise NotImplementedError(
"Alibi slopes are not supported on TPU backend." "Alibi slopes are not supported on TPU backend.")
)
assert self.num_heads % self.num_kv_heads == 0 assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.num_queries_per_kv = self.num_heads // self.num_kv_heads
@ -108,7 +106,8 @@ class PallasAttentionImpl(AttentionImpl):
# Reshape the query, key, and value tensors. # Reshape the query, key, and value tensors.
query = query.view(batch_size, seq_len, self.num_heads, self.head_size) query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size) key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size)
value = value.view(batch_size, seq_len, self.num_kv_heads, self.head_size) value = value.view(batch_size, seq_len, self.num_kv_heads,
self.head_size)
if kv_cache is not None: if kv_cache is not None:
key_cache, value_cache = kv_cache[0], kv_cache[1] key_cache, value_cache = kv_cache[0], kv_cache[1]
@ -116,10 +115,14 @@ class PallasAttentionImpl(AttentionImpl):
# Reshape the input keys and values and store them in the cache. # Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are # If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run. # not cached. This happens during the initial memory profiling run.
key_cache.index_copy_(dim=2, index=attn_metadata.slot_mapping, source=key) key_cache.index_copy_(dim=2,
value_cache.index_copy_(dim=2, index=attn_metadata.slot_mapping, source=value) index=attn_metadata.slot_mapping,
source=key)
value_cache.index_copy_(dim=2,
index=attn_metadata.slot_mapping,
source=value)
if True: # FIXME if attn_metadata.is_prompt:
# Prompt run. # Prompt run.
if kv_cache is None or attn_metadata.block_tables.numel() == 0: if kv_cache is None or attn_metadata.block_tables.numel() == 0:
# normal attention # normal attention
@ -130,6 +133,7 @@ class PallasAttentionImpl(AttentionImpl):
causal=True, causal=True,
) )
output = output.permute(0, 2, 1, 3) output = output.permute(0, 2, 1, 3)
output = output.reshape(batch_size, seq_len, hidden_size)
else: else:
# prefix-enabled attention # prefix-enabled attention
raise NotImplementedError( raise NotImplementedError(
@ -139,5 +143,4 @@ class PallasAttentionImpl(AttentionImpl):
# Decoding run. # Decoding run.
pass pass
# Reshape the output tensor. return output
return output.view(batch_size, seq_len, hidden_size)