mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-27 17:34:49 +08:00
Add TPU gemma
This commit is contained in:
parent
563c1d7ec5
commit
52a1e908e4
0
vllm/model_executor/models/tpu/__init__.py
Normal file
0
vllm/model_executor/models/tpu/__init__.py
Normal file
295
vllm/model_executor/models/tpu/gemma.py
Normal file
295
vllm/model_executor/models/tpu/gemma.py
Normal file
@ -0,0 +1,295 @@
|
||||
"""Inference-only Gemma model compatible with HF weights.
|
||||
|
||||
NOTE(woosuk): This is a temporary workaround to run the Gemma model using
|
||||
PyTorch XLA. This should be merged into the main Gemma model implementation
|
||||
once the custom ops are refactored and the model becomes torch.compile-able.
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from transformers import GemmaConfig, PreTrainedModel
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
|
||||
|
||||
class Linear(nn.Module):
|
||||
"""PyTorch Linear layer without parameter initialization."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.empty((out_features, in_features)))
|
||||
if bias:
|
||||
self.bias = nn.Parameter(torch.empty(out_features))
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return F.linear(x, self.weight, self.bias)
|
||||
|
||||
|
||||
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
|
||||
"""Precomputes the frequency cis."""
|
||||
freqs = 1.0 / (theta**(torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
|
||||
t = torch.arange(end, device=freqs.device) # type: ignore
|
||||
freqs = torch.outer(t, freqs).float() # type: ignore
|
||||
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
|
||||
return freqs_cis
|
||||
|
||||
|
||||
def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
|
||||
"""Applies the rotary embedding to the query and key tensors."""
|
||||
x_ = torch.view_as_complex(
|
||||
torch.stack(torch.chunk(x.transpose(1, 2).float(), 2, dim=-1), dim=-1))
|
||||
x_out = torch.view_as_real(x_ * freqs_cis).type_as(x)
|
||||
x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2)
|
||||
x_out = x_out.reshape(x_out.shape[0], x_out.shape[1], x_out.shape[2],
|
||||
-1).transpose(1, 2)
|
||||
return x_out
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
eps: float = 1e-6,
|
||||
):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = nn.Parameter(torch.ones(dim))
|
||||
|
||||
def _norm(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
orig_dtype = x.dtype
|
||||
x = self._norm(x.float())
|
||||
output = x * (1 + self.weight.float())
|
||||
return output.to(orig_dtype)
|
||||
|
||||
|
||||
class GemmaMLP(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
|
||||
self.gate_proj = Linear(
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
bias=False,
|
||||
)
|
||||
self.up_proj = Linear(
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
bias=False,
|
||||
)
|
||||
self.down_proj = Linear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
gate = self.gate_proj(x)
|
||||
gate = F.gelu(gate, approximate="tanh")
|
||||
up = self.up_proj(x)
|
||||
fuse = gate * up
|
||||
outputs = self.down_proj(fuse)
|
||||
return outputs
|
||||
|
||||
|
||||
class GemmaAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
head_dim: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.num_kv_heads = num_kv_heads
|
||||
|
||||
assert self.num_heads % self.num_kv_heads == 0
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.head_dim = head_dim
|
||||
self.scaling = self.head_dim**-0.5
|
||||
|
||||
self.q_proj = Linear(self.hidden_size,
|
||||
self.num_heads * self.head_dim,
|
||||
bias=False)
|
||||
self.k_proj = Linear(self.hidden_size,
|
||||
self.num_kv_heads * self.head_dim,
|
||||
bias=False)
|
||||
self.v_proj = Linear(self.hidden_size,
|
||||
self.num_kv_heads * self.head_dim,
|
||||
bias=False)
|
||||
self.o_proj = Linear(self.num_heads * self.head_dim,
|
||||
self.hidden_size,
|
||||
bias=False)
|
||||
self.attn = Attention(self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
batch_size, seq_len, _ = hidden_states.shape
|
||||
q = self.q_proj(hidden_states)
|
||||
k = self.k_proj(hidden_states)
|
||||
v = self.v_proj(hidden_states)
|
||||
|
||||
q = q.view(batch_size, seq_len, self.num_heads, self.head_dim)
|
||||
q = apply_rotary_emb(q, freqs_cis)
|
||||
k = k.view(batch_size, seq_len, self.num_kv_heads, self.head_dim)
|
||||
k = apply_rotary_emb(k, freqs_cis)
|
||||
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
o = self.o_proj(attn_output)
|
||||
return o
|
||||
|
||||
|
||||
class GemmaDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: GemmaConfig,
|
||||
):
|
||||
super().__init__()
|
||||
self.self_attn = GemmaAttention(
|
||||
hidden_size=config.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
num_kv_heads=config.num_key_value_heads,
|
||||
head_dim=config.head_dim,
|
||||
)
|
||||
self.mlp = GemmaMLP(
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
# Self Attention
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
hidden_states = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
freqs_cis=freqs_cis,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# MLP
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
return hidden_states
|
||||
|
||||
|
||||
class GemmaModel(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: GemmaConfig,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
self.embed_tokens = nn.Embedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
)
|
||||
self.layers = nn.ModuleList()
|
||||
for _ in range(config.num_hidden_layers):
|
||||
self.layers.append(GemmaDecoderLayer(config))
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
# Gemma normalizes the embedding by sqrt(hidden_size).
|
||||
# FIXME(woosuk): Downcast the normalizer.
|
||||
hidden_states = hidden_states * (self.config.hidden_size**0.5)
|
||||
|
||||
for i in range(len(self.layers)):
|
||||
layer = self.layers[i]
|
||||
hidden_states = layer(
|
||||
hidden_states=hidden_states,
|
||||
freqs_cis=freqs_cis,
|
||||
kv_cache=kv_caches[i],
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
hidden_states = self.norm(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class GemmaForCausalLM(PreTrainedModel):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: GemmaConfig,
|
||||
):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
self.model = GemmaModel(config)
|
||||
rope_theta = getattr(config, 'rope_theta', 10000)
|
||||
# [head_dim * 2, ] -> complex -> two dim (real, imaginary) implicitly
|
||||
freqs_cis = precompute_freqs_cis(config.head_dim,
|
||||
config.max_position_embeddings * 2,
|
||||
theta=rope_theta)
|
||||
self.register_buffer('freqs_cis', freqs_cis)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
freqs_cis = self.freqs_cis.index_select(0, positions)
|
||||
hidden_states = self.model(
|
||||
input_ids=input_ids,
|
||||
freqs_cis=freqs_cis,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
return hidden_states
|
||||
Loading…
x
Reference in New Issue
Block a user