mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 04:26:00 +08:00
Add Gemma model (#2964)
This commit is contained in:
parent
017d9f1515
commit
5253edaacb
@ -20,6 +20,7 @@ _MODELS = {
|
|||||||
"DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
|
"DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
|
||||||
"DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
|
"DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
|
||||||
"FalconForCausalLM": ("falcon", "FalconForCausalLM"),
|
"FalconForCausalLM": ("falcon", "FalconForCausalLM"),
|
||||||
|
"GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
|
||||||
"GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
|
"GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
|
||||||
"GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
|
"GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
|
||||||
"GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
|
"GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
|
||||||
|
|||||||
333
vllm/model_executor/models/gemma.py
Normal file
333
vllm/model_executor/models/gemma.py
Normal file
@ -0,0 +1,333 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2023 The vLLM team.
|
||||||
|
# Copyright (c) Google Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Inference-only Gemma model compatible with HuggingFace weights."""
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from transformers import GemmaConfig
|
||||||
|
|
||||||
|
from vllm.model_executor.input_metadata import InputMetadata
|
||||||
|
from vllm.model_executor.layers.attention import PagedAttention
|
||||||
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||||
|
LinearMethodBase,
|
||||||
|
QKVParallelLinear,
|
||||||
|
RowParallelLinear)
|
||||||
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
|
VocabParallelEmbedding)
|
||||||
|
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||||
|
get_tensor_model_parallel_world_size)
|
||||||
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||||
|
from vllm.model_executor.weight_utils import (default_weight_loader,
|
||||||
|
hf_model_weights_iterator)
|
||||||
|
from vllm.sequence import SamplerOutput
|
||||||
|
|
||||||
|
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||||
|
|
||||||
|
|
||||||
|
class GemmaRMSNorm(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, dim: int, eps: float = 1e-6):
|
||||||
|
super().__init__()
|
||||||
|
self.eps = eps
|
||||||
|
self.weight = nn.Parameter(torch.zeros(dim))
|
||||||
|
|
||||||
|
def _norm(self, x):
|
||||||
|
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
output = self._norm(x.float()).type_as(x)
|
||||||
|
return output * (1 + self.weight)
|
||||||
|
|
||||||
|
|
||||||
|
class GemmaMLP(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size: int,
|
||||||
|
intermediate_size: int,
|
||||||
|
linear_method: Optional[LinearMethodBase] = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.gate_proj = ColumnParallelLinear(hidden_size,
|
||||||
|
intermediate_size,
|
||||||
|
bias=False,
|
||||||
|
linear_method=linear_method)
|
||||||
|
self.up_proj = ColumnParallelLinear(hidden_size,
|
||||||
|
intermediate_size,
|
||||||
|
bias=False,
|
||||||
|
linear_method=linear_method)
|
||||||
|
self.down_proj = RowParallelLinear(intermediate_size,
|
||||||
|
hidden_size,
|
||||||
|
bias=False,
|
||||||
|
linear_method=linear_method)
|
||||||
|
self.act_fn = nn.GELU()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
gate, _ = self.gate_proj(x)
|
||||||
|
gate = self.act_fn(gate)
|
||||||
|
up, _ = self.up_proj(x)
|
||||||
|
fuse = gate * up
|
||||||
|
outputs, _ = self.down_proj(fuse)
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
class GemmaAttention(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
hidden_size: int,
|
||||||
|
num_heads: int,
|
||||||
|
num_kv_heads: int,
|
||||||
|
head_dim: int,
|
||||||
|
max_position_embeddings: int = 8192,
|
||||||
|
rope_theta: float = 10000,
|
||||||
|
linear_method: Optional[LinearMethodBase] = None) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
self.total_num_heads = num_heads
|
||||||
|
assert self.total_num_heads % tp_size == 0
|
||||||
|
self.num_heads = self.total_num_heads // tp_size
|
||||||
|
self.total_num_kv_heads = num_kv_heads
|
||||||
|
if self.total_num_kv_heads >= tp_size:
|
||||||
|
# Number of KV heads is greater than TP size, so we partition
|
||||||
|
# the KV heads across multiple tensor parallel GPUs.
|
||||||
|
assert self.total_num_kv_heads % tp_size == 0
|
||||||
|
else:
|
||||||
|
# Number of KV heads is less than TP size, so we replicate
|
||||||
|
# the KV heads across multiple tensor parallel GPUs.
|
||||||
|
assert tp_size % self.total_num_kv_heads == 0
|
||||||
|
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
||||||
|
self.head_dim = head_dim
|
||||||
|
self.q_size = self.num_heads * self.head_dim
|
||||||
|
self.kv_size = self.num_kv_heads * self.head_dim
|
||||||
|
self.scaling = self.head_dim**-0.5
|
||||||
|
self.rope_theta = rope_theta
|
||||||
|
|
||||||
|
self.qkv_proj = QKVParallelLinear(
|
||||||
|
hidden_size,
|
||||||
|
self.head_dim,
|
||||||
|
self.total_num_heads,
|
||||||
|
self.total_num_kv_heads,
|
||||||
|
bias=False,
|
||||||
|
linear_method=linear_method,
|
||||||
|
)
|
||||||
|
self.o_proj = RowParallelLinear(
|
||||||
|
self.total_num_heads * self.head_dim,
|
||||||
|
hidden_size,
|
||||||
|
bias=False,
|
||||||
|
linear_method=linear_method,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.rotary_emb = get_rope(
|
||||||
|
self.head_dim,
|
||||||
|
rotary_dim=self.head_dim,
|
||||||
|
max_position=max_position_embeddings,
|
||||||
|
base=self.rope_theta,
|
||||||
|
is_neox_style=True,
|
||||||
|
)
|
||||||
|
self.attn = PagedAttention(self.num_heads,
|
||||||
|
self.head_dim,
|
||||||
|
self.scaling,
|
||||||
|
num_kv_heads=self.num_kv_heads)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
kv_cache: KVCache,
|
||||||
|
input_metadata: InputMetadata,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
qkv, _ = self.qkv_proj(hidden_states)
|
||||||
|
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||||
|
q, k = self.rotary_emb(positions, q, k)
|
||||||
|
k_cache, v_cache = kv_cache
|
||||||
|
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
|
||||||
|
output, _ = self.o_proj(attn_output)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class GemmaDecoderLayer(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: GemmaConfig,
|
||||||
|
linear_method: Optional[LinearMethodBase] = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
self.self_attn = GemmaAttention(
|
||||||
|
hidden_size=self.hidden_size,
|
||||||
|
num_heads=config.num_attention_heads,
|
||||||
|
num_kv_heads=config.num_key_value_heads,
|
||||||
|
head_dim=config.head_dim,
|
||||||
|
max_position_embeddings=config.max_position_embeddings,
|
||||||
|
rope_theta=config.rope_theta,
|
||||||
|
linear_method=linear_method,
|
||||||
|
)
|
||||||
|
self.mlp = GemmaMLP(
|
||||||
|
hidden_size=self.hidden_size,
|
||||||
|
intermediate_size=config.intermediate_size,
|
||||||
|
linear_method=linear_method,
|
||||||
|
)
|
||||||
|
self.input_layernorm = GemmaRMSNorm(config.hidden_size,
|
||||||
|
eps=config.rms_norm_eps)
|
||||||
|
self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size,
|
||||||
|
eps=config.rms_norm_eps)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
kv_cache: KVCache,
|
||||||
|
input_metadata: InputMetadata,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
# Self Attention
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
|
hidden_states = self.self_attn(
|
||||||
|
positions=positions,
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
kv_cache=kv_cache,
|
||||||
|
input_metadata=input_metadata,
|
||||||
|
)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
|
# Fully Connected
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||||
|
hidden_states = self.mlp(hidden_states)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class GemmaModel(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: GemmaConfig,
|
||||||
|
linear_method: Optional[LinearMethodBase] = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
self.embed_tokens = VocabParallelEmbedding(
|
||||||
|
config.vocab_size,
|
||||||
|
config.hidden_size,
|
||||||
|
)
|
||||||
|
self.layers = nn.ModuleList([
|
||||||
|
GemmaDecoderLayer(config, linear_method)
|
||||||
|
for _ in range(config.num_hidden_layers)
|
||||||
|
])
|
||||||
|
self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
kv_caches: List[KVCache],
|
||||||
|
input_metadata: InputMetadata,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
hidden_states = self.embed_tokens(input_ids)
|
||||||
|
# Normalize the embedding by sqrt(hidden_size)
|
||||||
|
hidden_states = hidden_states * (self.config.hidden_size**0.5)
|
||||||
|
|
||||||
|
for i in range(len(self.layers)):
|
||||||
|
layer = self.layers[i]
|
||||||
|
hidden_states = layer(
|
||||||
|
positions,
|
||||||
|
hidden_states,
|
||||||
|
kv_caches[i],
|
||||||
|
input_metadata,
|
||||||
|
)
|
||||||
|
hidden_states = self.norm(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class GemmaForCausalLM(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: GemmaConfig,
|
||||||
|
linear_method: Optional[LinearMethodBase] = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.linear_method = linear_method
|
||||||
|
self.model = GemmaModel(config, linear_method)
|
||||||
|
self.sampler = Sampler(config.vocab_size)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
kv_caches: List[KVCache],
|
||||||
|
input_metadata: InputMetadata,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||||
|
input_metadata)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
def sample(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
sampling_metadata: SamplingMetadata,
|
||||||
|
) -> Optional[SamplerOutput]:
|
||||||
|
next_tokens = self.sampler(self.model.embed_tokens.weight,
|
||||||
|
hidden_states, sampling_metadata)
|
||||||
|
return next_tokens
|
||||||
|
|
||||||
|
def load_weights(self,
|
||||||
|
model_name_or_path: str,
|
||||||
|
cache_dir: Optional[str] = None,
|
||||||
|
load_format: str = "auto",
|
||||||
|
revision: Optional[str] = None):
|
||||||
|
stacked_params_mapping = [
|
||||||
|
# (param_name, shard_name, shard_id)
|
||||||
|
("qkv_proj", "q_proj", "q"),
|
||||||
|
("qkv_proj", "k_proj", "k"),
|
||||||
|
("qkv_proj", "v_proj", "v"),
|
||||||
|
]
|
||||||
|
params_dict = dict(self.named_parameters())
|
||||||
|
loaded_params = set()
|
||||||
|
for name, loaded_weight in hf_model_weights_iterator(
|
||||||
|
model_name_or_path, cache_dir, load_format, revision):
|
||||||
|
for (param_name, shard_name, shard_id) in stacked_params_mapping:
|
||||||
|
if shard_name not in name:
|
||||||
|
continue
|
||||||
|
name = name.replace(shard_name, param_name)
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = param.weight_loader
|
||||||
|
weight_loader(param, loaded_weight, shard_id)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# Skip loading extra layer for lora models.
|
||||||
|
if "lm_head" in name:
|
||||||
|
continue
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = getattr(param, "weight_loader",
|
||||||
|
default_weight_loader)
|
||||||
|
weight_loader(param, loaded_weight)
|
||||||
|
loaded_params.add(name)
|
||||||
|
unloaded_params = params_dict.keys() - loaded_params
|
||||||
|
if unloaded_params:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Some weights are not initialized from checkpoints: {unloaded_params}"
|
||||||
|
)
|
||||||
Loading…
x
Reference in New Issue
Block a user