[WIP] Add Pallas backend
parent 46b31ed98d
commit 02e614d922
vllm/attention/backends/pallas.py (new file, 143 lines)
@@ -0,0 +1,143 @@
"""Attention layer with Pallas FlashAttention and PagedAttention."""
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Type

import torch

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionMetadata)
from torch_xla.experimental.custom_kernel import flash_attention

class PallasAttentionBackend(AttentionBackend):

    @staticmethod
    def get_impl_cls() -> Type["PallasAttentionImpl"]:
        return PallasAttentionImpl

    @staticmethod
    def make_metadata(*args, **kwargs) -> "PallasAttentionMetadata":
        return PallasAttentionMetadata(*args, **kwargs)

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
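        # Illustrative example (hypothetical sizes, not from this commit):
        # num_kv_heads=8, num_blocks=512, block_size=16, head_size=128 gives a
        # cache of shape (2, 8, 512, 16, 128); index 0 holds the key cache and
        # index 1 the value cache (see forward() below).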
        return (2, num_kv_heads, num_blocks, block_size, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: Dict[int, int],
    ) -> None:
        raise NotImplementedError(
            "Swapping blocks is not supported on TPU backend.")

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: Dict[int, List[int]],
    ) -> None:
        raise NotImplementedError(
            "Copying blocks is not supported on TPU backend.")

@dataclass
class PallasAttentionMetadata(AttentionMetadata):
    """Metadata for PallasAttentionBackend."""
    # Currently, input sequences can only contain all prompts
    # or all decoding. True if all sequences are prompts.
    is_prompt: bool
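    # NOTE: forward() below also reads attn_metadata.slot_mapping and
    # attn_metadata.block_tables; whether those fields live on this dataclass
    # or on the AttentionMetadata base class is not shown in this WIP hunk.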

class PallasAttentionImpl(AttentionImpl):

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
    ) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads

        if sliding_window is not None:
            raise NotImplementedError(
                "Sliding window is not supported on TPU backend.")
        if alibi_slopes is not None:
            raise NotImplementedError(
                "Alibi slopes are not supported on TPU backend.")

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        # TODO(woosuk): Check supported head sizes.
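        # Illustrative example (hypothetical sizes): num_heads=32 with
        # num_kv_heads=8 (grouped-query attention) gives num_queries_per_kv=4,
        # i.e. four query heads share each KV head.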

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: PallasAttentionMetadata,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention and PagedAttention.

        Args:
            query: shape = [batch_size, seq_len, num_heads * head_size]
            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
            value: shape = [batch_size, seq_len, num_kv_heads * head_size]
            kv_cache: shape = [2, num_kv_heads, num_blocks, block_size,
                head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [batch_size, seq_len, num_heads * head_size]
        """
        batch_size, seq_len, hidden_size = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(batch_size, seq_len, self.num_heads,
                           self.head_size)
        key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size)
        value = value.view(batch_size, seq_len, self.num_kv_heads,
                           self.head_size)

        if kv_cache is not None:
            key_cache, value_cache = kv_cache[0], kv_cache[1]

            # Reshape the input keys and values and store them in the cache.
            # If kv_cache is not provided, the new key and value tensors are
            # not cached. This happens during the initial memory profiling run.
            key_cache.index_copy_(dim=2, index=attn_metadata.slot_mapping,
                                  source=key)
            value_cache.index_copy_(dim=2, index=attn_metadata.slot_mapping,
                                    source=value)

        if True:  # FIXME
            # Prompt run.
            if kv_cache is None or attn_metadata.block_tables.numel() == 0:
                # normal attention
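                # The permutes below convert [batch_size, seq_len, num_heads,
                # head_size] to the [batch_size, num_heads, seq_len, head_size]
                # layout expected by the torch_xla flash_attention kernel, and
                # back afterwards.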
                output = flash_attention(
                    query.permute(0, 2, 1, 3),
                    key.permute(0, 2, 1, 3),
                    value.permute(0, 2, 1, 3),
                    causal=True,
                )
                output = output.permute(0, 2, 1, 3)
            else:
                # prefix-enabled attention
                raise NotImplementedError(
                    "Prefix-enabled attention is not supported on TPU backend.")
        else:
            # Decoding run.
            pass

        # Reshape the output tensor.
        return output.view(batch_size, seq_len, hidden_size)
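For orientation, here is a minimal sketch of how the classes in this new file are meant to fit together. The sizes are hypothetical, importing the module requires torch_xla, and running forward() would additionally need a TPU/XLA device and fully populated metadata, so this is illustrative rather than part of the commit.

import torch

from vllm.attention.backends.pallas import PallasAttentionBackend

num_kv_heads, num_heads, head_size = 8, 32, 128  # hypothetical model config
num_blocks, block_size = 512, 16                 # hypothetical cache config

# KV cache laid out as (2, num_kv_heads, num_blocks, block_size, head_size);
# index 0 is the key cache, index 1 the value cache.
kv_shape = PallasAttentionBackend.get_kv_cache_shape(
    num_blocks, block_size, num_kv_heads, head_size)
kv_cache = torch.zeros(kv_shape, dtype=torch.bfloat16)

# The model runner asks the backend for its implementation class and builds it
# with the model's attention hyperparameters.
impl_cls = PallasAttentionBackend.get_impl_cls()
attn = impl_cls(num_heads=num_heads, head_size=head_size,
                scale=head_size**-0.5, num_kv_heads=num_kv_heads)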
@@ -12,7 +12,12 @@ logger = init_logger(__name__)

 @lru_cache(maxsize=None)
 def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]:
-    if _can_use_flash_attn(dtype):
+    if True:
+        logger.info("Using PallasAttention backend.")
+        from vllm.attention.backends.pallas import (  # noqa: F401
+            PallasAttentionBackend)
+        return PallasAttentionBackend
+    elif _can_use_flash_attn(dtype):
         logger.info("Using FlashAttention backend.")
         from vllm.attention.backends.flash_attn import (  # noqa: F401
             FlashAttentionBackend)
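For context, a short sketch of how the selector above is consumed. It assumes this hunk belongs to the module that defines get_attn_backend (vllm.attention.selector in vLLM); with this WIP change the `if True:` branch always wins, so the Pallas backend is returned regardless of dtype and the FlashAttention path is kept for later.

import torch

from vllm.attention.selector import get_attn_backend

backend_cls = get_attn_backend(torch.bfloat16)  # -> PallasAttentionBackend here
impl_cls = backend_cls.get_impl_cls()           # -> PallasAttentionImpl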