mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-23 08:12:22 +08:00
Remove
This commit is contained in:
parent
d4adf92beb
commit
c59c1e7b2c
@ -1,146 +0,0 @@
|
||||
"""Attention layer with Pallas FlashAttention and PagedAttention."""
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
from torch_xla.experimental.custom_kernel import flash_attention
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata)
|
||||
|
||||
|
||||
class PallasAttentionBackend(AttentionBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> Type["PallasAttentionImpl"]:
|
||||
return PallasAttentionImpl
|
||||
|
||||
@staticmethod
|
||||
def make_metadata(*args, **kwargs) -> "PallasAttentionMetadata":
|
||||
return PallasAttentionMetadata(*args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> Tuple[int, ...]:
|
||||
return (2, num_kv_heads, num_blocks, block_size, head_size)
|
||||
|
||||
@staticmethod
|
||||
def swap_blocks(
|
||||
src_kv_cache: torch.Tensor,
|
||||
dst_kv_cache: torch.Tensor,
|
||||
src_to_dst: Dict[int, int],
|
||||
) -> None:
|
||||
raise NotImplementedError(
|
||||
"Swapping blocks is not supported on TPU backend.")
|
||||
|
||||
@staticmethod
|
||||
def copy_blocks(
|
||||
kv_caches: List[torch.Tensor],
|
||||
src_to_dists: Dict[int, List[int]],
|
||||
) -> None:
|
||||
raise NotImplementedError(
|
||||
"Copying blocks is not supported on TPU backend.")
|
||||
|
||||
|
||||
@dataclass
|
||||
class PallasAttentionMetadata(AttentionMetadata):
|
||||
"""Metadata for PallasAttentionBackend."""
|
||||
# Currently, input sequences can only contain all prompts
|
||||
# or all decoding. True if all sequences are prompts.
|
||||
is_prompt: bool
|
||||
slot_mapping: torch.Tensor
|
||||
block_tables: torch.Tensor
|
||||
|
||||
|
||||
class PallasAttentionImpl(AttentionImpl):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: Optional[int] = None,
|
||||
alibi_slopes: Optional[List[float]] = None,
|
||||
sliding_window: Optional[int] = None,
|
||||
) -> None:
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
|
||||
|
||||
if sliding_window is not None:
|
||||
raise NotImplementedError(
|
||||
"Sliding window is not supported on TPU backend.")
|
||||
if alibi_slopes is not None:
|
||||
raise NotImplementedError(
|
||||
"Alibi slopes are not supported on TPU backend.")
|
||||
|
||||
assert self.num_heads % self.num_kv_heads == 0
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
# TODO(woosuk): Check supported head sizes.
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: PallasAttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with FlashAttention and PagedAttention.
|
||||
|
||||
Args:
|
||||
query: shape = [batch_size, seq_len, num_heads * head_size]
|
||||
key: shape = [batch_size, seq_len, num_kv_heads * head_size]
|
||||
value: shape = [batch_size, seq_len, num_kv_heads * head_size]
|
||||
kv_cache = [2, num_kv_heads, num_blocks, block_size, head_size]
|
||||
attn_metadata: Metadata for attention.
|
||||
Returns:
|
||||
shape = [batch_size, seq_len, num_heads * head_size]
|
||||
"""
|
||||
batch_size, seq_len, hidden_size = query.shape
|
||||
# Reshape the query, key, and value tensors.
|
||||
query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
|
||||
key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size)
|
||||
value = value.view(batch_size, seq_len, self.num_kv_heads,
|
||||
self.head_size)
|
||||
|
||||
if kv_cache is not None:
|
||||
key_cache, value_cache = kv_cache[0], kv_cache[1]
|
||||
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
# If kv_cache is not provided, the new key and value tensors are
|
||||
# not cached. This happens during the initial memory profiling run.
|
||||
key_cache.index_copy_(dim=2,
|
||||
index=attn_metadata.slot_mapping,
|
||||
source=key)
|
||||
value_cache.index_copy_(dim=2,
|
||||
index=attn_metadata.slot_mapping,
|
||||
source=value)
|
||||
|
||||
if attn_metadata.is_prompt:
|
||||
# Prompt run.
|
||||
if kv_cache is None or attn_metadata.block_tables.numel() == 0:
|
||||
# normal attention
|
||||
output = flash_attention(
|
||||
query.permute(0, 2, 1, 3),
|
||||
key.permute(0, 2, 1, 3),
|
||||
value.permute(0, 2, 1, 3),
|
||||
causal=True,
|
||||
)
|
||||
output = output.permute(0, 2, 1, 3)
|
||||
output = output.reshape(batch_size, seq_len, hidden_size)
|
||||
else:
|
||||
# prefix-enabled attention
|
||||
raise NotImplementedError(
|
||||
"Prefix-enabled attention is not supported on TPU backend."
|
||||
)
|
||||
else:
|
||||
# Decoding run.
|
||||
pass
|
||||
|
||||
return output
|
||||
@ -44,11 +44,6 @@ def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]:
|
||||
logger.info("Using Torch SDPA backend.")
|
||||
from vllm.attention.backends.torch_sdpa import TorchSDPABackend
|
||||
return TorchSDPABackend
|
||||
elif backend == _Backend.PALLAS:
|
||||
logger.info("Using PallasAttention backend.")
|
||||
from vllm.attention.backends.pallas import ( # noqa: F401
|
||||
PallasAttentionBackend)
|
||||
return PallasAttentionBackend
|
||||
else:
|
||||
raise ValueError("Invalid attention backend.")
|
||||
|
||||
|
||||
@ -1,299 +0,0 @@
|
||||
"""Inference-only Gemma model compatible with HF weights.
|
||||
|
||||
Adapted from
|
||||
https://github.com/google/gemma_pytorch/blob/main/gemma/model_xla.py
|
||||
|
||||
NOTE(woosuk): This is a temporary workaround to run the Gemma model using
|
||||
PyTorch XLA. This should be merged into the main Gemma model implementation
|
||||
once the custom ops are refactored and the model becomes torch.compile-able.
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from transformers import GemmaConfig, PreTrainedModel
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
|
||||
|
||||
class Linear(nn.Module):
|
||||
"""PyTorch Linear layer without parameter initialization."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.empty((out_features, in_features)))
|
||||
if bias:
|
||||
self.bias = nn.Parameter(torch.empty(out_features))
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return F.linear(x, self.weight, self.bias)
|
||||
|
||||
|
||||
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
|
||||
"""Precomputes the frequency cis."""
|
||||
freqs = 1.0 / (theta**(torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
|
||||
t = torch.arange(end, device=freqs.device) # type: ignore
|
||||
freqs = torch.outer(t, freqs).float() # type: ignore
|
||||
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
|
||||
return freqs_cis
|
||||
|
||||
|
||||
def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
|
||||
"""Applies the rotary embedding to the query and key tensors."""
|
||||
x_ = torch.view_as_complex(
|
||||
torch.stack(torch.chunk(x.transpose(1, 2).float(), 2, dim=-1), dim=-1))
|
||||
x_out = torch.view_as_real(x_ * freqs_cis).type_as(x)
|
||||
x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2)
|
||||
x_out = x_out.reshape(x_out.shape[0], x_out.shape[1], x_out.shape[2],
|
||||
-1).transpose(1, 2)
|
||||
# Reshape the output tensor to the original shape.
|
||||
return x_out.reshape(x_out.shape[0], x_out.shape[1], -1)
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
eps: float = 1e-6,
|
||||
):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = nn.Parameter(torch.ones(dim))
|
||||
|
||||
def _norm(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
orig_dtype = x.dtype
|
||||
x = self._norm(x.float())
|
||||
output = x * (1 + self.weight.float())
|
||||
return output.to(orig_dtype)
|
||||
|
||||
|
||||
class GemmaMLP(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
|
||||
self.gate_proj = Linear(
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
bias=False,
|
||||
)
|
||||
self.up_proj = Linear(
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
bias=False,
|
||||
)
|
||||
self.down_proj = Linear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
gate = self.gate_proj(x)
|
||||
gate = F.gelu(gate, approximate="tanh")
|
||||
up = self.up_proj(x)
|
||||
fuse = gate * up
|
||||
outputs = self.down_proj(fuse)
|
||||
return outputs
|
||||
|
||||
|
||||
class GemmaAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
head_dim: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.num_kv_heads = num_kv_heads
|
||||
|
||||
assert self.num_heads % self.num_kv_heads == 0
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.head_dim = head_dim
|
||||
self.scaling = self.head_dim**-0.5
|
||||
|
||||
self.q_proj = Linear(self.hidden_size,
|
||||
self.num_heads * self.head_dim,
|
||||
bias=False)
|
||||
self.k_proj = Linear(self.hidden_size,
|
||||
self.num_kv_heads * self.head_dim,
|
||||
bias=False)
|
||||
self.v_proj = Linear(self.hidden_size,
|
||||
self.num_kv_heads * self.head_dim,
|
||||
bias=False)
|
||||
self.o_proj = Linear(self.num_heads * self.head_dim,
|
||||
self.hidden_size,
|
||||
bias=False)
|
||||
self.attn = Attention(self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
batch_size, seq_len, _ = hidden_states.shape
|
||||
q = self.q_proj(hidden_states)
|
||||
k = self.k_proj(hidden_states)
|
||||
v = self.v_proj(hidden_states)
|
||||
|
||||
q = q.view(batch_size, seq_len, self.num_heads, self.head_dim)
|
||||
q = apply_rotary_emb(q, freqs_cis)
|
||||
k = k.view(batch_size, seq_len, self.num_kv_heads, self.head_dim)
|
||||
k = apply_rotary_emb(k, freqs_cis)
|
||||
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
o = self.o_proj(attn_output)
|
||||
return o
|
||||
|
||||
|
||||
class GemmaDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: GemmaConfig,
|
||||
):
|
||||
super().__init__()
|
||||
self.self_attn = GemmaAttention(
|
||||
hidden_size=config.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
num_kv_heads=config.num_key_value_heads,
|
||||
head_dim=config.head_dim,
|
||||
)
|
||||
self.mlp = GemmaMLP(
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
# Self Attention
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
hidden_states = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
freqs_cis=freqs_cis,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# MLP
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
return hidden_states
|
||||
|
||||
|
||||
class GemmaModel(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: GemmaConfig,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
self.embed_tokens = nn.Embedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
)
|
||||
self.layers = nn.ModuleList()
|
||||
for _ in range(config.num_hidden_layers):
|
||||
self.layers.append(GemmaDecoderLayer(config))
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
# Gemma normalizes the embedding by sqrt(hidden_size).
|
||||
# FIXME(woosuk): Downcast the normalizer.
|
||||
hidden_states = hidden_states * (self.config.hidden_size**0.5)
|
||||
|
||||
for i in range(len(self.layers)):
|
||||
layer = self.layers[i]
|
||||
hidden_states = layer(
|
||||
hidden_states=hidden_states,
|
||||
freqs_cis=freqs_cis,
|
||||
kv_cache=kv_caches[i],
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
hidden_states = self.norm(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class GemmaForCausalLM(PreTrainedModel):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: GemmaConfig,
|
||||
):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
self.model = GemmaModel(config)
|
||||
rope_theta = getattr(config, 'rope_theta', 10000)
|
||||
# [head_dim * 2, ] -> complex -> two dim (real, imaginary) implicitly
|
||||
freqs_cis = precompute_freqs_cis(config.head_dim,
|
||||
config.max_position_embeddings * 2,
|
||||
theta=rope_theta)
|
||||
self.register_buffer('freqs_cis', freqs_cis)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
freqs_cis = self.freqs_cis.index_select(0, positions)
|
||||
hidden_states = self.model(
|
||||
input_ids=input_ids,
|
||||
freqs_cis=freqs_cis,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
return hidden_states
|
||||
Loading…
x
Reference in New Issue
Block a user