Use Llama RMSNorm custom op for Gemma (#2974)

commit 95529e3253 (parent 344020c926)
Author: Woosuk Kwon (committed by GitHub)
Date:   2024-02-21 18:28:23 -08:00
Signature: GPG Key ID B5690EEEBB952194 (no known key found for this signature in database)

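The change below relies on the calling convention of vLLM's Llama RMSNorm custom op: norm(x) returns a single normalized tensor, while norm(x, residual) adds the residual first and returns an (output, residual) pair, which is why GemmaDecoderLayer.forward now threads residual through and returns a tuple. Below is a minimal pure-PyTorch sketch of that interface only, assuming those semantics; the class name RMSNormSketch is invented for this note and this is not vLLM's fused CUDA implementation.

from typing import Optional, Tuple, Union

import torch
import torch.nn as nn


class RMSNormSketch(nn.Module):
    # forward(x) -> normalized tensor
    # forward(x, residual) -> (normalized tensor, updated residual)

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        if residual is not None:
            x = x + residual  # fuse the residual add into the norm call
            residual = x      # becomes the residual input of the next layer
        variance = x.float().pow(2).mean(dim=-1, keepdim=True)
        out = (x.float() * torch.rsqrt(variance + self.eps)).type_as(x)
        out = out * self.weight
        return out if residual is None else (out, residual)


# Usage mirroring the decoder layer in the diff: the first call has no residual yet.
norm = RMSNormSketch(hidden_size=8)
h = torch.randn(2, 8)
r = h
h = norm(h)        # single-tensor form
h, r = norm(h, r)  # fused form: returns (output, residual)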

@@ -22,6 +22,7 @@ from transformers import GemmaConfig
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.attention import PagedAttention
+from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearMethodBase,
                                                QKVParallelLinear,
@@ -40,21 +41,6 @@ from vllm.sequence import SamplerOutput
 KVCache = Tuple[torch.Tensor, torch.Tensor]
-class GemmaRMSNorm(nn.Module):
-    def __init__(self, dim: int, eps: float = 1e-6):
-        super().__init__()
-        self.eps = eps
-        self.weight = nn.Parameter(torch.zeros(dim))
-    def _norm(self, x):
-        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
-    def forward(self, x):
-        output = self._norm(x.float()).type_as(x)
-        return output * (1 + self.weight)
 class GemmaMLP(nn.Module):
     def __init__(
@@ -185,10 +171,10 @@ class GemmaDecoderLayer(nn.Module):
             intermediate_size=config.intermediate_size,
             linear_method=linear_method,
         )
-        self.input_layernorm = GemmaRMSNorm(config.hidden_size,
-                                            eps=config.rms_norm_eps)
-        self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size,
-                                                     eps=config.rms_norm_eps)
+        self.input_layernorm = RMSNorm(config.hidden_size,
+                                       eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(config.hidden_size,
+                                                eps=config.rms_norm_eps)
     def forward(
         self,
@@ -196,25 +182,27 @@ class GemmaDecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         kv_cache: KVCache,
         input_metadata: InputMetadata,
-    ) -> torch.Tensor:
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
-        residual = hidden_states
-        hidden_states = self.input_layernorm(hidden_states)
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(
+                hidden_states, residual)
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
             kv_cache=kv_cache,
             input_metadata=input_metadata,
         )
-        hidden_states = residual + hidden_states
         # Fully Connected
-        residual = hidden_states
-        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states, residual = self.post_attention_layernorm(
+            hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
-        hidden_states = residual + hidden_states
-        return hidden_states
+        return hidden_states, residual
 class GemmaModel(nn.Module):
@@ -235,7 +223,7 @@ class GemmaModel(nn.Module):
             GemmaDecoderLayer(config, linear_method)
             for _ in range(config.num_hidden_layers)
         ])
-        self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
     def forward(
         self,
@@ -246,17 +234,19 @@ class GemmaModel(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
         # Normalize the embedding by sqrt(hidden_size)
-        hidden_states = hidden_states * (self.config.hidden_size**0.5)
+        hidden_states *= self.config.hidden_size**0.5
+        residual = None
         for i in range(len(self.layers)):
             layer = self.layers[i]
-            hidden_states = layer(
+            hidden_states, residual = layer(
                 positions,
                 hidden_states,
                 kv_caches[i],
                 input_metadata,
+                residual,
             )
-        hidden_states = self.norm(hidden_states)
+        hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
@@ -321,6 +311,10 @@ class GemmaForCausalLM(nn.Module):
                 # Skip loading extra layer for lora models.
                 if "lm_head" in name:
                     continue
+                # GemmaRMSNorm is different from Llama's in that it multiplies
+                # (1 + weight) to the output, instead of just weight.
+                if "norm.weight" in name:
+                    loaded_weight += 1.0
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
@@ -329,5 +323,5 @@ class GemmaForCausalLM(nn.Module):
         unloaded_params = params_dict.keys() - loaded_params
         if unloaded_params:
             raise RuntimeError(
-                f"Some weights are not initialized from checkpoints: {unloaded_params}"
-            )
+                "Some weights are not initialized from checkpoints: "
+                f"{unloaded_params}")