[Bugfix] Fix precisions in Gemma 1 (#5913)

This commit is contained in:
Woosuk Kwon 2024-06-28 20:10:21 -07:00 committed by GitHub
parent ba4994443a
commit 580353da93
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 12 additions and 14 deletions

View File

@ -17,6 +17,7 @@ MODELS = [
"stabilityai/stablelm-3b-4e1t",
# "allenai/OLMo-1B", # Broken
"bigcode/starcoder2-3b",
"google/gemma-1.1-2b-it",
]

View File

@ -26,14 +26,14 @@ from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import GeluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.layernorm import GemmaRMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.rotary_embedding import GemmaRotaryEmbedding
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
@ -148,12 +148,14 @@ class GemmaAttention(nn.Module):
quant_config=quant_config,
)
self.rotary_emb = get_rope(
# TODO(woosuk): Use the `get_rope` interface.
self.rotary_emb = GemmaRotaryEmbedding(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
max_position_embeddings=max_position_embeddings,
base=self.rope_theta,
is_neox_style=True,
dtype=torch.get_default_dtype(),
)
self.attn = Attention(self.num_heads,
self.head_dim,
@ -204,10 +206,10 @@ class GemmaDecoderLayer(nn.Module):
hidden_activation=getattr(config, "hidden_activation", None),
quant_config=quant_config,
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.input_layernorm = GemmaRMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
@ -257,7 +259,7 @@ class GemmaModel(nn.Module):
GemmaDecoderLayer(config, cache_config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# Normalize the embedding by sqrt(hidden_size)
# The normalizer's data type should be downcasted to the model's
@ -331,7 +333,6 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA):
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,
@ -388,10 +389,6 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# GemmaRMSNorm is different from Llama's in that it multiplies
# (1 + weight) to the output, instead of just weight.
if "norm.weight" in name:
loaded_weight += 1.0
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)