mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-29 04:37:03 +08:00
Remove
This commit is contained in:
parent
d4adf92beb
commit
c59c1e7b2c
@ -1,146 +0,0 @@
|
|||||||
"""Attention layer with Pallas FlashAttention and PagedAttention."""
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Dict, List, Optional, Tuple, Type
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch_xla.experimental.custom_kernel import flash_attention
|
|
||||||
|
|
||||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
|
||||||
AttentionMetadata)
|
|
||||||
|
|
||||||
|
|
||||||
class PallasAttentionBackend(AttentionBackend):
    """Backend descriptor that wires the Pallas attention classes together.

    Every hook is a static method, so the class is used without being
    instantiated. Block swapping and block copying are not available on
    this backend and always raise ``NotImplementedError``.
    """

    @staticmethod
    def get_impl_cls() -> Type["PallasAttentionImpl"]:
        """Return the attention implementation class for this backend."""
        return PallasAttentionImpl

    @staticmethod
    def make_metadata(*args, **kwargs) -> "PallasAttentionMetadata":
        """Build a ``PallasAttentionMetadata`` from the forwarded arguments."""
        metadata = PallasAttentionMetadata(*args, **kwargs)
        return metadata

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV cache shape for the given cache geometry.

        The result is
        ``(2, num_kv_heads, num_blocks, block_size, head_size)``.
        """
        per_head_shape = (num_blocks, block_size, head_size)
        return (2, num_kv_heads, *per_head_shape)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: Dict[int, int],
    ) -> None:
        """Unsupported on this backend; always raises.

        Raises:
            NotImplementedError: unconditionally.
        """
        message = "Swapping blocks is not supported on TPU backend."
        raise NotImplementedError(message)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: Dict[int, List[int]],
    ) -> None:
        """Unsupported on this backend; always raises.

        Raises:
            NotImplementedError: unconditionally.
        """
        message = "Copying blocks is not supported on TPU backend."
        raise NotImplementedError(message)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
class PallasAttentionMetadata(AttentionMetadata):
    """Per-batch attention metadata consumed by the Pallas backend."""

    # A batch is currently homogeneous: either every sequence is a prompt or
    # every sequence is decoding. This flag is True in the all-prompt case.
    is_prompt: bool
    # Passed as the `index` argument when new keys and values are copied
    # into the KV cache.
    slot_mapping: torch.Tensor
    # An empty tensor (numel() == 0) selects the plain, non-prefix attention
    # path for prompt batches.
    block_tables: torch.Tensor
|
|
||||||
|
|
||||||
|
|
||||||
class PallasAttentionImpl(AttentionImpl):
    """Attention implementation backed by the Pallas ``flash_attention`` kernel.

    Only the prompt path without a prefix is implemented. Sliding window,
    ALiBi slopes, prefix-enabled attention and decoding all raise
    ``NotImplementedError``.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
    ) -> None:
        """Store the head geometry and reject unsupported options.

        Args:
            num_heads: Number of query heads.
            head_size: Size of each head.
            scale: Attention scale; stored as a Python float.
            num_kv_heads: Number of key/value heads; defaults to
                ``num_heads`` when ``None``.
            alibi_slopes: Must be ``None``.
            sliding_window: Must be ``None``.

        Raises:
            NotImplementedError: if ``sliding_window`` or ``alibi_slopes``
                is given.
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads

        if sliding_window is not None:
            raise NotImplementedError(
                "Sliding window is not supported on TPU backend.")
        if alibi_slopes is not None:
            raise NotImplementedError(
                "Alibi slopes are not supported on TPU backend.")

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        # TODO(woosuk): Check supported head sizes.

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: PallasAttentionMetadata,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention and PagedAttention.

        Args:
            query: shape = [batch_size, seq_len, num_heads * head_size]
            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
            value: shape = [batch_size, seq_len, num_kv_heads * head_size]
            kv_cache = [2, num_kv_heads, num_blocks, block_size, head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [batch_size, seq_len, num_heads * head_size]
        Raises:
            NotImplementedError: for a decoding batch, or for a prompt batch
                that has both a KV cache and a non-empty block table
                (prefix-enabled attention).
        """
        batch_size, seq_len, hidden_size = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
        key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size)
        value = value.view(batch_size, seq_len, self.num_kv_heads,
                           self.head_size)

        # If kv_cache is not provided, the new key and value tensors are
        # not cached. This happens during the initial memory profiling run.
        if kv_cache is not None:
            key_cache, value_cache = kv_cache[0], kv_cache[1]
            # Store the input keys and values in the cache.
            # NOTE(review): assumes the layout of `key`/`value` matches what
            # index_copy_ along dim=2 expects for the cache -- TODO confirm.
            key_cache.index_copy_(dim=2,
                                  index=attn_metadata.slot_mapping,
                                  source=key)
            value_cache.index_copy_(dim=2,
                                    index=attn_metadata.slot_mapping,
                                    source=value)

        if not attn_metadata.is_prompt:
            # Decoding run. Previously this branch fell through to
            # `return output` with `output` never assigned, which surfaced
            # as an UnboundLocalError; fail explicitly instead.
            raise NotImplementedError(
                "Decoding is not supported on TPU backend.")

        # Prompt run.
        if kv_cache is not None and attn_metadata.block_tables.numel() != 0:
            # prefix-enabled attention
            raise NotImplementedError(
                "Prefix-enabled attention is not supported on TPU backend.")

        # normal attention: the tensors are permuted so that the head axis
        # precedes the sequence axis for the kernel call, then permuted back.
        output = flash_attention(
            query.permute(0, 2, 1, 3),
            key.permute(0, 2, 1, 3),
            value.permute(0, 2, 1, 3),
            causal=True,
        )
        output = output.permute(0, 2, 1, 3)
        output = output.reshape(batch_size, seq_len, hidden_size)
        return output
|
|
||||||
@ -44,11 +44,6 @@ def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]:
|
|||||||
logger.info("Using Torch SDPA backend.")
|
logger.info("Using Torch SDPA backend.")
|
||||||
from vllm.attention.backends.torch_sdpa import TorchSDPABackend
|
from vllm.attention.backends.torch_sdpa import TorchSDPABackend
|
||||||
return TorchSDPABackend
|
return TorchSDPABackend
|
||||||
elif backend == _Backend.PALLAS:
|
|
||||||
logger.info("Using PallasAttention backend.")
|
|
||||||
from vllm.attention.backends.pallas import ( # noqa: F401
|
|
||||||
PallasAttentionBackend)
|
|
||||||
return PallasAttentionBackend
|
|
||||||
else:
|
else:
|
||||||
raise ValueError("Invalid attention backend.")
|
raise ValueError("Invalid attention backend.")
|
||||||
|
|
||||||
|
|||||||
@ -1,299 +0,0 @@
|
|||||||
"""Inference-only Gemma model compatible with HF weights.
|
|
||||||
|
|
||||||
Adapted from
|
|
||||||
https://github.com/google/gemma_pytorch/blob/main/gemma/model_xla.py
|
|
||||||
|
|
||||||
NOTE(woosuk): This is a temporary workaround to run the Gemma model using
|
|
||||||
PyTorch XLA. This should be merged into the main Gemma model implementation
|
|
||||||
once the custom ops are refactored and the model becomes torch.compile-able.
|
|
||||||
"""
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from torch import nn
|
|
||||||
from transformers import GemmaConfig, PreTrainedModel
|
|
||||||
|
|
||||||
from vllm.attention import Attention, AttentionMetadata
|
|
||||||
|
|
||||||
|
|
||||||
class Linear(nn.Module):
    """Linear layer whose parameters are allocated but never initialized.

    The weight (and optional bias) come from ``torch.empty``, so their
    contents are arbitrary until values are written into them.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
    ):
        """Allocate an ``(out_features, in_features)`` weight and optional bias.

        Args:
            in_features: Size of the last dimension of the input.
            out_features: Size of the last dimension of the output.
            bias: When False, ``self.bias`` is set to ``None``.
        """
        super().__init__()
        weight_shape = (out_features, in_features)
        self.weight = nn.Parameter(torch.empty(weight_shape))
        self.bias = nn.Parameter(torch.empty(out_features)) if bias else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``F.linear`` with this layer's weight and bias."""
        projected = F.linear(x, self.weight, self.bias)
        return projected
|
|
||||||
|
|
||||||
|
|
||||||
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """Build a table of unit-magnitude complex rotary factors.

    Args:
        dim: Head dimension; ``dim // 2`` frequencies are produced.
        end: Number of positions (rows) in the table.
        theta: Base of the inverse-frequency schedule.

    Returns:
        A complex tensor of shape ``(end, dim // 2)`` whose entry
        ``[p, j]`` has magnitude 1 and angle ``p / theta**(2j / dim)``.
    """
    even_indices = torch.arange(0, dim, 2)[:(dim // 2)].float()
    inv_freq = 1.0 / (theta**(even_indices / dim))
    positions = torch.arange(end, device=inv_freq.device)  # type: ignore
    angles = torch.outer(positions, inv_freq).float()  # type: ignore
    # Magnitude 1 everywhere, so each entry is cos(angle) + i*sin(angle).
    return torch.polar(torch.ones_like(angles), angles)  # complex64
|
|
||||||
|
|
||||||
|
|
||||||
def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """Rotate ``x`` by the complex factors in ``freqs_cis``.

    The first and second halves of the last dimension of ``x`` are treated
    as the real and imaginary parts of complex numbers, which are multiplied
    by ``freqs_cis``.

    Args:
        x: 4-D tensor; dims 1 and 2 are swapped before the rotation and
            swapped back afterwards.
        freqs_cis: Complex factors, broadcast against the transposed,
            complex-valued view of ``x``.

    Returns:
        A 3-D tensor with the dtype of ``x``: the last two dimensions of
        the rotated result are flattened into one.
    """
    halves = torch.chunk(x.transpose(1, 2).float(), 2, dim=-1)
    # Pair up (first-half, second-half) elements as (real, imag).
    as_complex = torch.view_as_complex(torch.stack(halves, dim=-1))
    rotated = torch.view_as_real(as_complex * freqs_cis).type_as(x)
    # Undo the pairing: all real parts first, then all imaginary parts.
    real_part, imag_part = torch.chunk(rotated, 2, dim=-1)
    merged = torch.cat((real_part, imag_part), dim=-2)
    d0, d1, d2 = merged.shape[0], merged.shape[1], merged.shape[2]
    merged = merged.reshape(d0, d1, d2, -1).transpose(1, 2)
    # Flatten the trailing two dimensions into one.
    return merged.reshape(merged.shape[0], merged.shape[1], -1)
|
|
||||||
|
|
||||||
|
|
||||||
class RMSNorm(torch.nn.Module):
    """Root-mean-square normalization with a ``(1 + weight)`` scale.

    The computation runs in float32 and the result is cast back to the
    input dtype.
    """

    def __init__(
        self,
        dim: int,
        eps: float = 1e-6,
    ):
        """Create a learnable ``weight`` of size ``dim``, filled with ones.

        Args:
            dim: Size of the normalized (last) dimension.
            eps: Constant added to the mean square before the reciprocal
                square root.
        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        """Scale ``x`` by the reciprocal RMS of its last dimension."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` and multiply by ``1 + weight``."""
        normalized = self._norm(x.float())
        scaled = normalized * (1 + self.weight.float())
        return scaled.to(x.dtype)
|
|
||||||
|
|
||||||
|
|
||||||
class GemmaMLP(nn.Module):
    """Gated feed-forward block: ``down(gelu_tanh(gate(x)) * up(x))``."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
    ):
        """Create the three bias-free projections.

        Args:
            hidden_size: Input and output width of the block.
            intermediate_size: Width of the gated inner activation.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size

        # Registration order (gate, up, down) is kept as-is.
        self.gate_proj = Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the gated MLP on ``x``."""
        activated_gate = F.gelu(self.gate_proj(x), approximate="tanh")
        gated = activated_gate * self.up_proj(x)
        return self.down_proj(gated)
|
|
||||||
|
|
||||||
|
|
||||||
class GemmaAttention(nn.Module):
    """Self-attention block: Q/K/V projections, rotary embedding, attention.

    Queries and keys get the rotary embedding; values are passed to the
    attention layer straight from their projection.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
    ):
        """Create the projections and the attention layer.

        Args:
            hidden_size: Width of the hidden states.
            num_heads: Number of query heads; must be a multiple of
                ``num_kv_heads``.
            num_kv_heads: Number of key/value heads.
            head_dim: Size of each head.
        """
        super().__init__()
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads

        queries_per_kv, leftover = divmod(self.num_heads, self.num_kv_heads)
        assert leftover == 0
        self.num_queries_per_kv = queries_per_kv

        self.hidden_size = hidden_size
        self.head_dim = head_dim
        self.scaling = self.head_dim**-0.5

        q_width = self.num_heads * self.head_dim
        kv_width = self.num_kv_heads * self.head_dim
        self.q_proj = Linear(self.hidden_size, q_width, bias=False)
        self.k_proj = Linear(self.hidden_size, kv_width, bias=False)
        self.v_proj = Linear(self.hidden_size, kv_width, bias=False)
        self.o_proj = Linear(q_width, self.hidden_size, bias=False)
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        freqs_cis: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Project, rotate, attend, and project back.

        Args:
            hidden_states: 3-D tensor unpacked as
                ``(batch_size, seq_len, _)``.
            freqs_cis: Complex rotary factors for ``apply_rotary_emb``.
            kv_cache: Cache tensor forwarded to the attention layer.
            attn_metadata: Metadata forwarded to the attention layer.

        Returns:
            The output projection of the attention result.
        """
        batch_size, seq_len, _ = hidden_states.shape

        # Queries: split into heads, then apply the rotary embedding.
        queries = self.q_proj(hidden_states)
        queries = queries.view(batch_size, seq_len, self.num_heads,
                               self.head_dim)
        queries = apply_rotary_emb(queries, freqs_cis)

        # Keys: same treatment with the KV head count.
        keys = self.k_proj(hidden_states)
        keys = keys.view(batch_size, seq_len, self.num_kv_heads,
                         self.head_dim)
        keys = apply_rotary_emb(keys, freqs_cis)

        # Values are neither reshaped nor rotated here.
        values = self.v_proj(hidden_states)

        attended = self.attn(queries, keys, values, kv_cache, attn_metadata)
        return self.o_proj(attended)
|
|
||||||
|
|
||||||
|
|
||||||
class GemmaDecoderLayer(nn.Module):
    """One pre-norm transformer layer: attention and MLP, each with a residual."""

    def __init__(
        self,
        config: GemmaConfig,
    ):
        """Build the attention block, MLP and both norms from ``config``."""
        super().__init__()
        width = config.hidden_size
        norm_eps = config.rms_norm_eps
        self.self_attn = GemmaAttention(
            hidden_size=width,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            head_dim=config.head_dim,
        )
        self.mlp = GemmaMLP(
            hidden_size=width,
            intermediate_size=config.intermediate_size,
        )
        self.input_layernorm = RMSNorm(width, eps=norm_eps)
        self.post_attention_layernorm = RMSNorm(width, eps=norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        freqs_cis: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Apply the attention sub-layer, then the MLP sub-layer.

        Each sub-layer normalizes its input, runs, and adds the
        un-normalized input back as a residual.
        """
        # Self Attention
        attn_out = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            freqs_cis=freqs_cis,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        after_attn = hidden_states + attn_out

        # MLP
        mlp_out = self.mlp(self.post_attention_layernorm(after_attn))
        return after_attn + mlp_out
|
|
||||||
|
|
||||||
|
|
||||||
class GemmaModel(nn.Module):
    """Token embedding, a stack of decoder layers, and a final norm."""

    def __init__(
        self,
        config: GemmaConfig,
    ):
        """Build ``config.num_hidden_layers`` decoder layers from ``config``."""
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size,
                                         config.hidden_size)
        decoder_layers = [
            GemmaDecoderLayer(config) for _ in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(decoder_layers)
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        freqs_cis: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and run them through every layer.

        Args:
            input_ids: Token ids fed to the embedding table.
            freqs_cis: Complex rotary factors passed to every layer.
            kv_caches: One cache per layer, indexed by layer position.
            attn_metadata: Metadata passed to every layer.

        Returns:
            The hidden states after the final norm.
        """
        # Gemma normalizes the embedding by sqrt(hidden_size).
        # FIXME(woosuk): Downcast the normalizer.
        embed_scale = self.config.hidden_size**0.5
        hidden_states = self.embed_tokens(input_ids) * embed_scale

        for layer_idx, layer in enumerate(self.layers):
            hidden_states = layer(
                hidden_states=hidden_states,
                freqs_cis=freqs_cis,
                kv_cache=kv_caches[layer_idx],
                attn_metadata=attn_metadata,
            )
        return self.norm(hidden_states)
|
|
||||||
|
|
||||||
|
|
||||||
class GemmaForCausalLM(PreTrainedModel):
    """Top-level wrapper around ``GemmaModel`` with a rotary-factor table.

    ``forward`` returns the hidden states of the wrapped model; no
    vocabulary projection is applied in this class.
    """

    def __init__(
        self,
        config: GemmaConfig,
    ):
        """Build the model and register the ``freqs_cis`` buffer.

        The rope base is read from ``config.rope_theta`` and falls back to
        10000 when the attribute is absent. The table covers twice
        ``config.max_position_embeddings`` positions.
        """
        super().__init__(config)
        self.config = config

        self.model = GemmaModel(config)
        # [head_dim * 2, ] -> complex -> two dim (real, imaginary) implicitly
        table_len = config.max_position_embeddings * 2
        rotary_table = precompute_freqs_cis(
            config.head_dim,
            table_len,
            theta=getattr(config, 'rope_theta', 10000),
        )
        self.register_buffer('freqs_cis', rotary_table)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Select rotary factors for ``positions`` and run the model.

        Args:
            input_ids: Token ids for the wrapped model.
            positions: Row indices into the ``freqs_cis`` table.
            kv_caches: One cache per decoder layer.
            attn_metadata: Metadata forwarded to the wrapped model.

        Returns:
            The hidden states produced by the wrapped model.
        """
        selected_freqs = self.freqs_cis.index_select(0, positions)
        return self.model(
            input_ids=input_ids,
            freqs_cis=selected_freqs,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
        )
|
|
||||||
Loading…
x
Reference in New Issue
Block a user