commit c59c1e7b2c (parent d4adf92beb)
Woosuk Kwon, 2024-04-16 08:05:36 +00:00
4 changed files with 0 additions and 450 deletions

View File

@@ -1,146 +0,0 @@
"""Attention layer with Pallas FlashAttention and PagedAttention."""
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Type
import torch
from torch_xla.experimental.custom_kernel import flash_attention
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata)
class PallasAttentionBackend(AttentionBackend):
@staticmethod
def get_impl_cls() -> Type["PallasAttentionImpl"]:
return PallasAttentionImpl
@staticmethod
def make_metadata(*args, **kwargs) -> "PallasAttentionMetadata":
return PallasAttentionMetadata(*args, **kwargs)
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
return (2, num_kv_heads, num_blocks, block_size, head_size)
@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: Dict[int, int],
) -> None:
raise NotImplementedError(
"Swapping blocks is not supported on TPU backend.")
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: Dict[int, List[int]],
) -> None:
raise NotImplementedError(
"Copying blocks is not supported on TPU backend.")
@dataclass
class PallasAttentionMetadata(AttentionMetadata):
"""Metadata for PallasAttentionBackend."""
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
is_prompt: bool
slot_mapping: torch.Tensor
block_tables: torch.Tensor
class PallasAttentionImpl(AttentionImpl):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
if sliding_window is not None:
raise NotImplementedError(
"Sliding window is not supported on TPU backend.")
if alibi_slopes is not None:
raise NotImplementedError(
"Alibi slopes are not supported on TPU backend.")
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
# TODO(woosuk): Check supported head sizes.
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: PallasAttentionMetadata,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.
Args:
query: shape = [batch_size, seq_len, num_heads * head_size]
key: shape = [batch_size, seq_len, num_kv_heads * head_size]
value: shape = [batch_size, seq_len, num_kv_heads * head_size]
            kv_cache: shape = [2, num_kv_heads, num_blocks, block_size, head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [batch_size, seq_len, num_heads * head_size]
"""
batch_size, seq_len, hidden_size = query.shape
# Reshape the query, key, and value tensors.
query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size)
value = value.view(batch_size, seq_len, self.num_kv_heads,
self.head_size)
        if kv_cache is not None:
            # Write the new keys and values into the cache at the positions
            # given by slot_mapping. If kv_cache is not provided, they are not
            # cached; that happens only during the initial memory profiling run.
            key_cache, value_cache = kv_cache[0], kv_cache[1]
            key_cache.index_copy_(dim=2,
                                  index=attn_metadata.slot_mapping,
                                  source=key)
            value_cache.index_copy_(dim=2,
                                    index=attn_metadata.slot_mapping,
                                    source=value)
if attn_metadata.is_prompt:
# Prompt run.
if kv_cache is None or attn_metadata.block_tables.numel() == 0:
# normal attention
output = flash_attention(
query.permute(0, 2, 1, 3),
key.permute(0, 2, 1, 3),
value.permute(0, 2, 1, 3),
causal=True,
)
output = output.permute(0, 2, 1, 3)
output = output.reshape(batch_size, seq_len, hidden_size)
else:
# prefix-enabled attention
raise NotImplementedError(
"Prefix-enabled attention is not supported on TPU backend."
)
        else:
            # Decoding run. Not implemented yet; raise explicitly rather than
            # falling through and returning an unbound `output`.
            raise NotImplementedError(
                "Decoding is not supported on TPU backend yet.")
return output
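To make the prompt path concrete, here is a hedged usage sketch. It assumes a TPU host with torch_xla (flash_attention is a Pallas TPU kernel) and that the AttentionMetadata base class contributes no additional required fields; all sizes are illustrative.

impl = PallasAttentionImpl(num_heads=8, head_size=128, scale=128**-0.5)
meta = PallasAttentionMetadata(
    is_prompt=True,
    slot_mapping=torch.empty(0, dtype=torch.int64),
    block_tables=torch.empty(0, dtype=torch.int32),  # numel() == 0 => flash path
)
batch_size, seq_len, hidden = 2, 16, 8 * 128
q = torch.randn(batch_size, seq_len, hidden)
k = torch.randn(batch_size, seq_len, hidden)
v = torch.randn(batch_size, seq_len, hidden)
# kv_cache=None mirrors the memory-profiling run, so nothing is written back.
out = impl.forward(q, k, v, kv_cache=None, attn_metadata=meta)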

View File

@@ -44,11 +44,6 @@ def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]:
logger.info("Using Torch SDPA backend.")
from vllm.attention.backends.torch_sdpa import TorchSDPABackend
return TorchSDPABackend
elif backend == _Backend.PALLAS:
logger.info("Using PallasAttention backend.")
from vllm.attention.backends.pallas import ( # noqa: F401
PallasAttentionBackend)
return PallasAttentionBackend
else:
raise ValueError("Invalid attention backend.")
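For context, get_attn_backend maps a detected _Backend enum value to a backend class; callers never import backends directly. A hedged usage sketch (the selector's module path and its device probing are assumptions not shown in this hunk):

import torch
from vllm.attention.selector import get_attn_backend  # assumed module path

backend_cls = get_attn_backend(torch.bfloat16)
impl_cls = backend_cls.get_impl_cls()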

View File

@@ -1,299 +0,0 @@
"""Inference-only Gemma model compatible with HF weights.
Adapted from
https://github.com/google/gemma_pytorch/blob/main/gemma/model_xla.py
NOTE(woosuk): This is a temporary workaround to run the Gemma model using
PyTorch XLA. This should be merged into the main Gemma model implementation
once the custom ops are refactored and the model becomes torch.compile-able.
"""
from typing import List
import torch
import torch.nn.functional as F
from torch import nn
from transformers import GemmaConfig, PreTrainedModel
from vllm.attention import Attention, AttentionMetadata
class Linear(nn.Module):
"""PyTorch Linear layer without parameter initialization."""
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
):
super().__init__()
self.weight = nn.Parameter(torch.empty((out_features, in_features)))
if bias:
self.bias = nn.Parameter(torch.empty(out_features))
else:
self.bias = None
def forward(self, x: torch.Tensor) -> torch.Tensor:
return F.linear(x, self.weight, self.bias)
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
"""Precomputes the frequency cis."""
freqs = 1.0 / (theta**(torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device) # type: ignore
freqs = torch.outer(t, freqs).float() # type: ignore
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
return freqs_cis
def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
"""Applies the rotary embedding to the query and key tensors."""
x_ = torch.view_as_complex(
torch.stack(torch.chunk(x.transpose(1, 2).float(), 2, dim=-1), dim=-1))
x_out = torch.view_as_real(x_ * freqs_cis).type_as(x)
x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2)
x_out = x_out.reshape(x_out.shape[0], x_out.shape[1], x_out.shape[2],
-1).transpose(1, 2)
    # Flatten the heads back into the last dimension:
    # [batch, seq_len, num_heads * head_dim].
    return x_out.reshape(x_out.shape[0], x_out.shape[1], -1)
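A quick CPU-runnable shape check for the two RoPE helpers above; dimensions are illustrative. Note that apply_rotary_emb hands back the heads re-flattened into the last dimension:

batch, seq_len, n_heads, head_dim = 2, 16, 8, 128
freqs = precompute_freqs_cis(head_dim, seq_len)  # complex64, [seq_len, head_dim // 2]
x = torch.randn(batch, seq_len, n_heads, head_dim)
out = apply_rotary_emb(x, freqs)  # freqs broadcasts over batch and heads
assert out.shape == (batch, seq_len, n_heads * head_dim)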
class RMSNorm(torch.nn.Module):
def __init__(
self,
dim: int,
eps: float = 1e-6,
):
super().__init__()
self.eps = eps
        # Stored as (scale - 1): forward() multiplies by (1 + weight), so zeros
        # give identity scaling, matching the gemma_pytorch reference.
        self.weight = nn.Parameter(torch.zeros(dim))
def _norm(self, x: torch.Tensor) -> torch.Tensor:
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x: torch.Tensor) -> torch.Tensor:
orig_dtype = x.dtype
x = self._norm(x.float())
output = x * (1 + self.weight.float())
return output.to(orig_dtype)
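Note the (1 + weight) convention above: the scale is stored shifted by one, so a zero weight means identity scaling. A small numeric check of the forward pass, runnable on CPU:

norm = RMSNorm(dim=4)
x = torch.randn(2, 4)
rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + norm.eps)
torch.testing.assert_close(norm(x), x * rms * (1 + norm.weight))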
class GemmaMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.gate_proj = Linear(
hidden_size,
intermediate_size,
bias=False,
)
self.up_proj = Linear(
hidden_size,
intermediate_size,
bias=False,
)
self.down_proj = Linear(
intermediate_size,
hidden_size,
bias=False,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
gate = self.gate_proj(x)
gate = F.gelu(gate, approximate="tanh")
up = self.up_proj(x)
fuse = gate * up
outputs = self.down_proj(fuse)
return outputs
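The block above is Gemma's GeGLU: down_proj(gelu_tanh(gate_proj(x)) * up_proj(x)). A minimal functional check; since the bare Linear above deliberately skips initialization, the sketch fills the weights first:

mlp = GemmaMLP(hidden_size=16, intermediate_size=32)
for p in mlp.parameters():
    nn.init.normal_(p, std=0.02)  # Linear leaves parameters uninitialized
x = torch.randn(2, 4, 16)
ref = mlp.down_proj(F.gelu(mlp.gate_proj(x), approximate="tanh") * mlp.up_proj(x))
torch.testing.assert_close(mlp(x), ref)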
class GemmaAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
head_dim: int,
):
super().__init__()
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.hidden_size = hidden_size
self.head_dim = head_dim
self.scaling = self.head_dim**-0.5
self.q_proj = Linear(self.hidden_size,
self.num_heads * self.head_dim,
bias=False)
self.k_proj = Linear(self.hidden_size,
self.num_kv_heads * self.head_dim,
bias=False)
self.v_proj = Linear(self.hidden_size,
self.num_kv_heads * self.head_dim,
bias=False)
self.o_proj = Linear(self.num_heads * self.head_dim,
self.hidden_size,
bias=False)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads)
def forward(
self,
hidden_states: torch.Tensor,
freqs_cis: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
batch_size, seq_len, _ = hidden_states.shape
q = self.q_proj(hidden_states)
k = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
        # RoPE is applied per head; apply_rotary_emb returns q and k with the
        # heads re-flattened, which is the layout self.attn expects. v stays
        # flat as [batch, seq, num_kv_heads * head_dim] for the same reason.
        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim)
        q = apply_rotary_emb(q, freqs_cis)
        k = k.view(batch_size, seq_len, self.num_kv_heads, self.head_dim)
        k = apply_rotary_emb(k, freqs_cis)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
o = self.o_proj(attn_output)
return o
class GemmaDecoderLayer(nn.Module):
def __init__(
self,
config: GemmaConfig,
):
super().__init__()
self.self_attn = GemmaAttention(
hidden_size=config.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
head_dim=config.head_dim,
)
self.mlp = GemmaMLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
freqs_cis: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
# Self Attention
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.self_attn(
hidden_states=hidden_states,
freqs_cis=freqs_cis,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
hidden_states = residual + hidden_states
# MLP
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class GemmaModel(nn.Module):
def __init__(
self,
config: GemmaConfig,
):
super().__init__()
self.config = config
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(
config.vocab_size,
config.hidden_size,
)
self.layers = nn.ModuleList()
for _ in range(config.num_hidden_layers):
self.layers.append(GemmaDecoderLayer(config))
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
input_ids: torch.Tensor,
freqs_cis: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
# Gemma normalizes the embedding by sqrt(hidden_size).
# FIXME(woosuk): Downcast the normalizer.
hidden_states = hidden_states * (self.config.hidden_size**0.5)
        for layer, kv_cache in zip(self.layers, kv_caches):
            hidden_states = layer(
                hidden_states=hidden_states,
                freqs_cis=freqs_cis,
                kv_cache=kv_cache,
                attn_metadata=attn_metadata,
            )
hidden_states = self.norm(hidden_states)
return hidden_states
class GemmaForCausalLM(PreTrainedModel):
def __init__(
self,
config: GemmaConfig,
):
super().__init__(config)
self.config = config
self.model = GemmaModel(config)
        rope_theta = getattr(config, "rope_theta", 10000)
        # Complex-valued rotary table of shape
        # [max_position_embeddings * 2, head_dim // 2]; each complex64 entry
        # carries the (real, imaginary) pair implicitly.
        freqs_cis = precompute_freqs_cis(config.head_dim,
                                         config.max_position_embeddings * 2,
                                         theta=rope_theta)
        self.register_buffer("freqs_cis", freqs_cis)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
freqs_cis = self.freqs_cis.index_select(0, positions)
hidden_states = self.model(
input_ids=input_ids,
freqs_cis=freqs_cis,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
)
return hidden_states
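Finally, a hedged sketch of the rotary lookup in forward() above: freqs_cis holds one complex row per position (precomputed for max_position_embeddings * 2 positions), and index_select(0, positions) requires positions to be a 1-D tensor of position ids, one per token:

positions = torch.arange(16)  # position ids of a 16-token prompt
table = precompute_freqs_cis(128, 32 * 2)  # head_dim=128, max_pos=32, doubled
freqs_cis = table.index_select(0, positions)  # [16, 64] complex64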