Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 02:15:01 +08:00)
Use Llama RMSNorm custom op for Gemma (#2974)
This commit is contained in:
parent 344020c926
commit 95529e3253
@@ -22,6 +22,7 @@ from transformers import GemmaConfig
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.attention import PagedAttention
+from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearMethodBase,
                                                QKVParallelLinear,
@@ -40,21 +41,6 @@ from vllm.sequence import SamplerOutput
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 
 
-class GemmaRMSNorm(nn.Module):
-
-    def __init__(self, dim: int, eps: float = 1e-6):
-        super().__init__()
-        self.eps = eps
-        self.weight = nn.Parameter(torch.zeros(dim))
-
-    def _norm(self, x):
-        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
-
-    def forward(self, x):
-        output = self._norm(x.float()).type_as(x)
-        return output * (1 + self.weight)
-
-
 class GemmaMLP(nn.Module):
 
     def __init__(
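Note: the class removed above differs from the Llama-style RMSNorm that replaces it only in its scaling convention. The Gemma weight is stored zero-centered and the output is scaled by (1 + weight), while the Llama-style layer stores the scale directly and multiplies by it. A minimal, plain-PyTorch sketch of the two conventions (illustrative only, not vLLM's custom op):

```python
import torch
import torch.nn as nn


def _rms(x: torch.Tensor, eps: float) -> torch.Tensor:
    # Shared RMS normalization: x / sqrt(mean(x^2) + eps), computed in fp32.
    x_f = x.float()
    return (x_f * torch.rsqrt(x_f.pow(2).mean(-1, keepdim=True) + eps)).type_as(x)


class GemmaStyleRMSNorm(nn.Module):
    """Convention removed in this diff: weight initialized to zero,
    output scaled by (1 + weight)."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return _rms(x, self.eps) * (1 + self.weight)


class LlamaStyleRMSNorm(nn.Module):
    """Convention of the replacement layer: weight initialized to one,
    output scaled by the weight directly."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return _rms(x, self.eps) * self.weight
```

Because the two differ only by a constant offset in the scale, a Gemma checkpoint can be loaded into the Llama-style layer by adding 1.0 to every norm weight, which is exactly what the weight-loading hunk near the bottom of this diff does.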
@@ -185,10 +171,10 @@ class GemmaDecoderLayer(nn.Module):
             intermediate_size=config.intermediate_size,
             linear_method=linear_method,
         )
-        self.input_layernorm = GemmaRMSNorm(config.hidden_size,
-                                            eps=config.rms_norm_eps)
-        self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size,
-                                                     eps=config.rms_norm_eps)
+        self.input_layernorm = RMSNorm(config.hidden_size,
+                                       eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(config.hidden_size,
+                                                eps=config.rms_norm_eps)
 
     def forward(
         self,
@@ -196,25 +182,27 @@ class GemmaDecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         kv_cache: KVCache,
         input_metadata: InputMetadata,
+        residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
-        residual = hidden_states
-        hidden_states = self.input_layernorm(hidden_states)
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(
+                hidden_states, residual)
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
             kv_cache=kv_cache,
             input_metadata=input_metadata,
         )
-        hidden_states = residual + hidden_states
 
         # Fully Connected
-        residual = hidden_states
-        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states, residual = self.post_attention_layernorm(
+            hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
-        hidden_states = residual + hidden_states
-
-        return hidden_states
+        return hidden_states, residual
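The rewritten forward above relies on the replacement RMSNorm accepting an optional residual: the explicit `hidden_states = residual + hidden_states` additions disappear because, when a residual is passed, the norm adds it to its input, hands the sum back as the next residual, and normalizes that sum. A rough pure-PyTorch sketch of that call contract (the class name is illustrative; the real layer dispatches to a fused kernel):

```python
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn


class FusedAddRMSNorm(nn.Module):
    """Illustrative stand-in for an RMSNorm that optionally fuses the
    residual addition with the normalization."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        if residual is not None:
            # Fold the skip connection into the norm: the summed activations
            # are both the value to normalize and the residual handed to the
            # next layer.
            x = x + residual
            residual = x
        out = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        out = out * self.weight
        return out if residual is None else (out, residual)
```

Under that contract each decoder layer returns `(hidden_states, residual)`, and the addition that used to follow the attention and MLP blocks happens inside the next norm call, so a single kernel can do the add and the normalization in one pass over the activations.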
@@ -235,7 +223,7 @@ class GemmaModel(nn.Module):
             GemmaDecoderLayer(config, linear_method)
             for _ in range(config.num_hidden_layers)
         ])
-        self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def forward(
         self,
@@ -246,17 +234,19 @@ class GemmaModel(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
         # Normalize the embedding by sqrt(hidden_size)
-        hidden_states = hidden_states * (self.config.hidden_size**0.5)
+        hidden_states *= self.config.hidden_size**0.5
 
+        residual = None
         for i in range(len(self.layers)):
             layer = self.layers[i]
-            hidden_states = layer(
+            hidden_states, residual = layer(
                 positions,
                 hidden_states,
                 kv_caches[i],
                 input_metadata,
+                residual,
             )
-        hidden_states = self.norm(hidden_states)
+        hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
@@ -321,6 +311,10 @@ class GemmaForCausalLM(nn.Module):
                 # Skip loading extra layer for lora models.
                 if "lm_head" in name:
                     continue
+                # GemmaRMSNorm is different from Llama's in that it multiplies
+                # (1 + weight) to the output, instead of just weight.
+                if "norm.weight" in name:
+                    loaded_weight += 1.0
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
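The `loaded_weight += 1.0` line above reconciles the two scaling conventions at load time: a Gemma checkpoint stores the zero-centered weight w, and storing w + 1 in a layer that multiplies by its weight directly reproduces the (1 + w) scaling of the removed GemmaRMSNorm. A small self-contained numerical check of that equivalence (hypothetical values, plain PyTorch):

```python
import torch

torch.manual_seed(0)
dim, eps = 8, 1e-6
x = torch.randn(2, dim)
w = torch.empty(dim).uniform_(-0.1, 0.1)  # zero-centered Gemma-style norm weight

# Shared RMS normalization of the input.
norm = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

gemma_out = norm * (1 + w)    # what the removed GemmaRMSNorm computes
llama_out = norm * (w + 1.0)  # weight-direct scaling after loaded_weight += 1.0

assert torch.allclose(gemma_out, llama_out)
```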
@@ -329,5 +323,5 @@ class GemmaForCausalLM(nn.Module):
         unloaded_params = params_dict.keys() - loaded_params
         if unloaded_params:
             raise RuntimeError(
-                f"Some weights are not initialized from checkpoints: {unloaded_params}"
-            )
+                "Some weights are not initialized from checkpoints: "
+                f"{unloaded_params}")