mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-20 23:25:44 +08:00
[Bugfix] Fix precisions in Gemma 1 (#5913)
This commit is contained in:
parent
ba4994443a
commit
580353da93
@ -17,6 +17,7 @@ MODELS = [
|
|||||||
"stabilityai/stablelm-3b-4e1t",
|
"stabilityai/stablelm-3b-4e1t",
|
||||||
# "allenai/OLMo-1B", # Broken
|
# "allenai/OLMo-1B", # Broken
|
||||||
"bigcode/starcoder2-3b",
|
"bigcode/starcoder2-3b",
|
||||||
|
"google/gemma-1.1-2b-it",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -26,14 +26,14 @@ from vllm.config import CacheConfig, LoRAConfig
|
|||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.activation import GeluAndMul
|
from vllm.model_executor.layers.activation import GeluAndMul
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import GemmaRMSNorm
|
||||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
RowParallelLinear)
|
RowParallelLinear)
|
||||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
from vllm.model_executor.layers.quantization.base_config import (
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
QuantizationConfig)
|
QuantizationConfig)
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import GemmaRotaryEmbedding
|
||||||
from vllm.model_executor.layers.sampler import Sampler
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
VocabParallelEmbedding)
|
VocabParallelEmbedding)
|
||||||
@ -148,12 +148,14 @@ class GemmaAttention(nn.Module):
|
|||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.rotary_emb = get_rope(
|
# TODO(woosuk): Use the `get_rope` interface.
|
||||||
|
self.rotary_emb = GemmaRotaryEmbedding(
|
||||||
self.head_dim,
|
self.head_dim,
|
||||||
rotary_dim=self.head_dim,
|
rotary_dim=self.head_dim,
|
||||||
max_position=max_position_embeddings,
|
max_position_embeddings=max_position_embeddings,
|
||||||
base=self.rope_theta,
|
base=self.rope_theta,
|
||||||
is_neox_style=True,
|
is_neox_style=True,
|
||||||
|
dtype=torch.get_default_dtype(),
|
||||||
)
|
)
|
||||||
self.attn = Attention(self.num_heads,
|
self.attn = Attention(self.num_heads,
|
||||||
self.head_dim,
|
self.head_dim,
|
||||||
@ -204,10 +206,10 @@ class GemmaDecoderLayer(nn.Module):
|
|||||||
hidden_activation=getattr(config, "hidden_activation", None),
|
hidden_activation=getattr(config, "hidden_activation", None),
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
self.input_layernorm = GemmaRMSNorm(config.hidden_size,
|
||||||
eps=config.rms_norm_eps)
|
eps=config.rms_norm_eps)
|
||||||
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size,
|
||||||
eps=config.rms_norm_eps)
|
eps=config.rms_norm_eps)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -257,7 +259,7 @@ class GemmaModel(nn.Module):
|
|||||||
GemmaDecoderLayer(config, cache_config, quant_config)
|
GemmaDecoderLayer(config, cache_config, quant_config)
|
||||||
for _ in range(config.num_hidden_layers)
|
for _ in range(config.num_hidden_layers)
|
||||||
])
|
])
|
||||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
|
||||||
# Normalize the embedding by sqrt(hidden_size)
|
# Normalize the embedding by sqrt(hidden_size)
|
||||||
# The normalizer's data type should be downcasted to the model's
|
# The normalizer's data type should be downcasted to the model's
|
||||||
@ -331,7 +333,6 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA):
|
|||||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||||
self.sampler = Sampler()
|
self.sampler = Sampler()
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
@ -388,10 +389,6 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA):
|
|||||||
# Skip loading extra bias for GPTQ models.
|
# Skip loading extra bias for GPTQ models.
|
||||||
if name.endswith(".bias") and name not in params_dict:
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
continue
|
continue
|
||||||
# GemmaRMSNorm is different from Llama's in that it multiplies
|
|
||||||
# (1 + weight) to the output, instead of just weight.
|
|
||||||
if "norm.weight" in name:
|
|
||||||
loaded_weight += 1.0
|
|
||||||
param = params_dict[name]
|
param = params_dict[name]
|
||||||
weight_loader = getattr(param, "weight_loader",
|
weight_loader = getattr(param, "weight_loader",
|
||||||
default_weight_loader)
|
default_weight_loader)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user