mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-28 20:27:08 +08:00
Fix flashattn
This commit is contained in:
parent
696b653193
commit
363e6a950f
@ -3,10 +3,10 @@ from dataclasses import dataclass
|
|||||||
from typing import Dict, List, Optional, Tuple, Type
|
from typing import Dict, List, Optional, Tuple, Type
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from torch_xla.experimental.custom_kernel import flash_attention
|
||||||
|
|
||||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||||
AttentionMetadata)
|
AttentionMetadata)
|
||||||
from torch_xla.experimental.custom_kernel import flash_attention
|
|
||||||
|
|
||||||
|
|
||||||
class PallasAttentionBackend(AttentionBackend):
|
class PallasAttentionBackend(AttentionBackend):
|
||||||
@ -52,8 +52,8 @@ class PallasAttentionMetadata(AttentionMetadata):
|
|||||||
# Currently, input sequences can only contain all prompts
|
# Currently, input sequences can only contain all prompts
|
||||||
# or all decoding. True if all sequences are prompts.
|
# or all decoding. True if all sequences are prompts.
|
||||||
is_prompt: bool
|
is_prompt: bool
|
||||||
|
slot_mapping: torch.Tensor
|
||||||
|
block_tables: torch.Tensor
|
||||||
|
|
||||||
|
|
||||||
class PallasAttentionImpl(AttentionImpl):
|
class PallasAttentionImpl(AttentionImpl):
|
||||||
@ -74,12 +74,10 @@ class PallasAttentionImpl(AttentionImpl):
|
|||||||
|
|
||||||
if sliding_window is not None:
|
if sliding_window is not None:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"Sliding window is not supported on TPU backend."
|
"Sliding window is not supported on TPU backend.")
|
||||||
)
|
|
||||||
if alibi_slopes is not None:
|
if alibi_slopes is not None:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"Alibi slopes are not supported on TPU backend."
|
"Alibi slopes are not supported on TPU backend.")
|
||||||
)
|
|
||||||
|
|
||||||
assert self.num_heads % self.num_kv_heads == 0
|
assert self.num_heads % self.num_kv_heads == 0
|
||||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||||
@ -108,7 +106,8 @@ class PallasAttentionImpl(AttentionImpl):
|
|||||||
# Reshape the query, key, and value tensors.
|
# Reshape the query, key, and value tensors.
|
||||||
query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
|
query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
|
||||||
key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size)
|
key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size)
|
||||||
value = value.view(batch_size, seq_len, self.num_kv_heads, self.head_size)
|
value = value.view(batch_size, seq_len, self.num_kv_heads,
|
||||||
|
self.head_size)
|
||||||
|
|
||||||
if kv_cache is not None:
|
if kv_cache is not None:
|
||||||
key_cache, value_cache = kv_cache[0], kv_cache[1]
|
key_cache, value_cache = kv_cache[0], kv_cache[1]
|
||||||
@ -116,10 +115,14 @@ class PallasAttentionImpl(AttentionImpl):
|
|||||||
# Reshape the input keys and values and store them in the cache.
|
# Reshape the input keys and values and store them in the cache.
|
||||||
# If kv_cache is not provided, the new key and value tensors are
|
# If kv_cache is not provided, the new key and value tensors are
|
||||||
# not cached. This happens during the initial memory profiling run.
|
# not cached. This happens during the initial memory profiling run.
|
||||||
key_cache.index_copy_(dim=2, index=attn_metadata.slot_mapping, source=key)
|
key_cache.index_copy_(dim=2,
|
||||||
value_cache.index_copy_(dim=2, index=attn_metadata.slot_mapping, source=value)
|
index=attn_metadata.slot_mapping,
|
||||||
|
source=key)
|
||||||
|
value_cache.index_copy_(dim=2,
|
||||||
|
index=attn_metadata.slot_mapping,
|
||||||
|
source=value)
|
||||||
|
|
||||||
if True: # FIXME
|
if attn_metadata.is_prompt:
|
||||||
# Prompt run.
|
# Prompt run.
|
||||||
if kv_cache is None or attn_metadata.block_tables.numel() == 0:
|
if kv_cache is None or attn_metadata.block_tables.numel() == 0:
|
||||||
# normal attention
|
# normal attention
|
||||||
@ -130,6 +133,7 @@ class PallasAttentionImpl(AttentionImpl):
|
|||||||
causal=True,
|
causal=True,
|
||||||
)
|
)
|
||||||
output = output.permute(0, 2, 1, 3)
|
output = output.permute(0, 2, 1, 3)
|
||||||
|
output = output.reshape(batch_size, seq_len, hidden_size)
|
||||||
else:
|
else:
|
||||||
# prefix-enabled attention
|
# prefix-enabled attention
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
@ -139,5 +143,4 @@ class PallasAttentionImpl(AttentionImpl):
|
|||||||
# Decoding run.
|
# Decoding run.
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Reshape the output tensor.
|
return output
|
||||||
return output.view(batch_size, seq_len, hidden_size)
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user