Mirror of https://git.datalinker.icu/vllm-project/vllm.git
Fix flashattn

parent 696b653193
commit 363e6a950f
@@ -3,10 +3,10 @@ from dataclasses import dataclass
 from typing import Dict, List, Optional, Tuple, Type
 
 import torch
+from torch_xla.experimental.custom_kernel import flash_attention
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata)
-from torch_xla.experimental.custom_kernel import flash_attention
 
 
 class PallasAttentionBackend(AttentionBackend):
@@ -52,8 +52,8 @@ class PallasAttentionMetadata(AttentionMetadata):
     # Currently, input sequences can only contain all prompts
     # or all decoding. True if all sequences are prompts.
     is_prompt: bool
     slot_mapping: torch.Tensor
     block_tables: torch.Tensor
 
 
 class PallasAttentionImpl(AttentionImpl):
@@ -74,12 +74,10 @@ class PallasAttentionImpl(AttentionImpl):
 
         if sliding_window is not None:
             raise NotImplementedError(
-                "Sliding window is not supported on TPU backend."
-            )
+                "Sliding window is not supported on TPU backend.")
         if alibi_slopes is not None:
             raise NotImplementedError(
-                "Alibi slopes are not supported on TPU backend."
-            )
+                "Alibi slopes are not supported on TPU backend.")
 
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
@@ -108,7 +106,8 @@ class PallasAttentionImpl(AttentionImpl):
         # Reshape the query, key, and value tensors.
         query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
         key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size)
-        value = value.view(batch_size, seq_len, self.num_kv_heads, self.head_size)
+        value = value.view(batch_size, seq_len, self.num_kv_heads,
+                           self.head_size)
 
         if kv_cache is not None:
             key_cache, value_cache = kv_cache[0], kv_cache[1]
@@ -116,10 +115,14 @@ class PallasAttentionImpl(AttentionImpl):
             # Reshape the input keys and values and store them in the cache.
             # If kv_cache is not provided, the new key and value tensors are
             # not cached. This happens during the initial memory profiling run.
-            key_cache.index_copy_(dim=2, index=attn_metadata.slot_mapping, source=key)
-            value_cache.index_copy_(dim=2, index=attn_metadata.slot_mapping, source=value)
+            key_cache.index_copy_(dim=2,
+                                  index=attn_metadata.slot_mapping,
+                                  source=key)
+            value_cache.index_copy_(dim=2,
+                                    index=attn_metadata.slot_mapping,
+                                    source=value)
         if True:  # FIXME
             if attn_metadata.is_prompt:
                 # Prompt run.
                 if kv_cache is None or attn_metadata.block_tables.numel() == 0:
                     # normal attention
@@ -130,6 +133,7 @@ class PallasAttentionImpl(AttentionImpl):
                         causal=True,
                     )
                     output = output.permute(0, 2, 1, 3)
+                    output = output.reshape(batch_size, seq_len, hidden_size)
                 else:
                     # prefix-enabled attention
                     raise NotImplementedError(
@@ -139,5 +143,4 @@ class PallasAttentionImpl(AttentionImpl):
                 # Decoding run.
                 pass
 
-        # Reshape the output tensor.
-        return output.view(batch_size, seq_len, hidden_size)
+        return output
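The index_copy_ calls in the cache-update hunk are what scatter the freshly computed keys and values into the KV cache along the slot dimension. A minimal sketch of the op itself, with toy shapes that are assumptions for illustration and not the backend's actual cache layout:

import torch

# Toy shapes, chosen only to illustrate index_copy_; the real TPU cache
# layout is defined elsewhere in the backend.
num_kv_heads, head_size, num_slots, num_tokens = 2, 4, 16, 3

key_cache = torch.zeros(1, num_kv_heads, num_slots, head_size)
key = torch.randn(1, num_kv_heads, num_tokens, head_size)

# One cache slot per new token.
slot_mapping = torch.tensor([5, 9, 10])

# Copies key[:, :, i] into key_cache[:, :, slot_mapping[i]] for each i;
# len(slot_mapping) must equal key.size(2).
key_cache.index_copy_(dim=2, index=slot_mapping, source=key)

assert torch.equal(key_cache[:, :, 5], key[:, :, 0])

The value_cache line follows the same pattern with source=value.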
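On the reshape added after the permute: the flash_attention output appears to come back head-major, i.e. [batch_size, num_heads, seq_len, head_size] (an inference from the permute(0, 2, 1, 3) in the hunk, not something the diff states explicitly), so it is transposed to [batch_size, seq_len, num_heads, head_size] and the heads are folded into hidden_size before the method returns:

import torch

batch_size, num_heads, seq_len, head_size = 2, 8, 16, 64
hidden_size = num_heads * head_size

# Stand-in for the kernel output under the assumed layout above.
attn_out = torch.randn(batch_size, num_heads, seq_len, head_size)

out = attn_out.permute(0, 2, 1, 3)                   # [b, s, heads, head_size]
out = out.reshape(batch_size, seq_len, hidden_size)  # [b, s, heads * head_size]

assert out.shape == (batch_size, seq_len, hidden_size)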
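A note on the last hunk, which drops the trailing output.view(...) in favour of returning the tensor that was already reshaped inside the prompt branch: after a permute the output is no longer contiguous, so a later .view on it would raise, while .reshape copies when needed. A standalone PyTorch illustration of the difference (not backend code):

import torch

x = torch.randn(2, 8, 16, 64).permute(0, 2, 1, 3)  # non-contiguous after permute

try:
    x.view(2, 16, 8 * 64)           # .view cannot merge the permuted dims
except RuntimeError as err:
    print("view failed:", err)

y = x.reshape(2, 16, 8 * 64)        # .reshape falls back to a copy
print(y.shape)                      # torch.Size([2, 16, 512])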