Fix flashattn

Woosuk Kwon 2024-04-10 08:02:40 +00:00
parent 696b653193
commit 363e6a950f

@@ -3,10 +3,10 @@ from dataclasses import dataclass
 from typing import Dict, List, Optional, Tuple, Type
 
 import torch
+from torch_xla.experimental.custom_kernel import flash_attention
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata)
-from torch_xla.experimental.custom_kernel import flash_attention
 
 
 class PallasAttentionBackend(AttentionBackend):
@@ -52,8 +52,8 @@ class PallasAttentionMetadata(AttentionMetadata):
 
     # Currently, input sequences can only contain all prompts
     # or all decoding. True if all sequences are prompts.
     is_prompt: bool
-    slot_mapping: torch.Tensor
+    block_tables: torch.Tensor
 
 
 class PallasAttentionImpl(AttentionImpl):
@@ -74,12 +74,10 @@ class PallasAttentionImpl(AttentionImpl):
         if sliding_window is not None:
             raise NotImplementedError(
-                "Sliding window is not supported on TPU backend."
-            )
+                "Sliding window is not supported on TPU backend.")
         if alibi_slopes is not None:
             raise NotImplementedError(
-                "Alibi slopes are not supported on TPU backend."
-            )
+                "Alibi slopes are not supported on TPU backend.")
 
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
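
The assert and the num_queries_per_kv line above are the grouped-query attention bookkeeping: each KV head is shared by num_heads // num_kv_heads query heads. A minimal sketch of the idea, with made-up sizes rather than values from this file (repeat_interleave is one common way to materialize the sharing, not necessarily what this backend does):

    import torch

    # Hypothetical sizes for illustration only.
    num_heads, num_kv_heads, head_size = 8, 2, 64
    assert num_heads % num_kv_heads == 0
    num_queries_per_kv = num_heads // num_kv_heads  # 4 query heads per KV head

    # Repeat each KV head so its shape lines up with the query heads.
    key = torch.randn(1, 16, num_kv_heads, head_size)       # (batch, seq, kv_heads, head)
    key = key.repeat_interleave(num_queries_per_kv, dim=2)  # (batch, seq, heads, head)
    assert key.shape == (1, 16, num_heads, head_size)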
@@ -108,7 +106,8 @@ class PallasAttentionImpl(AttentionImpl):
         # Reshape the query, key, and value tensors.
         query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
         key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size)
-        value = value.view(batch_size, seq_len, self.num_kv_heads, self.head_size)
+        value = value.view(batch_size, seq_len, self.num_kv_heads,
+                           self.head_size)
 
         if kv_cache is not None:
             key_cache, value_cache = kv_cache[0], kv_cache[1]
@@ -116,10 +115,14 @@ class PallasAttentionImpl(AttentionImpl):
             # Reshape the input keys and values and store them in the cache.
             # If kv_cache is not provided, the new key and value tensors are
             # not cached. This happens during the initial memory profiling run.
-            key_cache.index_copy_(dim=2, index=attn_metadata.slot_mapping, source=key)
-            value_cache.index_copy_(dim=2, index=attn_metadata.slot_mapping, source=value)
+            key_cache.index_copy_(dim=2,
+                                  index=attn_metadata.slot_mapping,
+                                  source=key)
+            value_cache.index_copy_(dim=2,
+                                    index=attn_metadata.slot_mapping,
+                                    source=value)
 
-        if True:  # FIXME
+        if attn_metadata.is_prompt:
             # Prompt run.
             if kv_cache is None or attn_metadata.block_tables.numel() == 0:
                 # normal attention
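
For context on the reflowed calls above: Tensor.index_copy_(dim, index, source) copies slices of source into self at the positions given by index along dim, which is exactly the slot-based KV-cache write. A self-contained sketch with illustrative shapes (the real cache layout is not shown in this hunk; only dim=2 is known from the call):

    import torch

    # Toy cache: (num_kv_heads, head_size, num_slots); dim=2 indexes slots,
    # mirroring the index_copy_(dim=2, ...) calls in the diff.
    num_kv_heads, head_size, num_slots = 2, 4, 8
    key_cache = torch.zeros(num_kv_heads, head_size, num_slots)

    slot_mapping = torch.tensor([5, 1, 6])  # three new tokens land in these slots
    new_keys = torch.randn(num_kv_heads, head_size, 3)

    key_cache.index_copy_(2, slot_mapping, new_keys)
    assert torch.equal(key_cache[:, :, 5], new_keys[:, :, 0])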
@@ -130,6 +133,7 @@ class PallasAttentionImpl(AttentionImpl):
                     causal=True,
                 )
                 output = output.permute(0, 2, 1, 3)
+                output = output.reshape(batch_size, seq_len, hidden_size)
             else:
                 # prefix-enabled attention
                 raise NotImplementedError(
@@ -139,5 +143,4 @@ class PallasAttentionImpl(AttentionImpl):
             # Decoding run.
             pass
 
-        # Reshape the output tensor.
-        return output.view(batch_size, seq_len, hidden_size)
+        return output
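
The added reshape in the prompt branch is the substance of the fix. Judging from the permute(0, 2, 1, 3), the kernel returns output as (batch, num_heads, seq_len, head_size); the prompt path now flattens it to (batch, seq_len, hidden_size) itself, which is why the final return drops its own view call. Note that reshape, not view, is needed after permute, since the permuted tensor is no longer contiguous. A shape-only sketch with placeholder sizes:

    import torch

    batch_size, num_heads, seq_len, head_size = 2, 8, 16, 64
    hidden_size = num_heads * head_size

    # Stand-in for the attention kernel's output layout.
    output = torch.randn(batch_size, num_heads, seq_len, head_size)

    output = output.permute(0, 2, 1, 3)  # (batch, seq_len, num_heads, head_size)
    # .view() would raise here because permute() makes the tensor
    # non-contiguous; .reshape() copies when it has to.
    output = output.reshape(batch_size, seq_len, hidden_size)
    assert output.shape == (batch_size, seq_len, hidden_size)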