[Bugfix][Model] Refactor OLMo model to support new HF format in transformers 4.40.0 (#4324)

Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Isotr0py 2024-04-26 00:35:56 +08:00 committed by GitHub
parent 479d69fad0
commit fbf152d976
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 150 additions and 162 deletions

View File

@ -74,7 +74,7 @@ vLLM seamlessly supports many Hugging Face models, including the following archi
- Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.) - Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.)
- Mixtral (`mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, `mistral-community/Mixtral-8x22B-v0.1`, etc.) - Mixtral (`mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, `mistral-community/Mixtral-8x22B-v0.1`, etc.)
- MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.) - MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.)
- OLMo (`allenai/OLMo-1B`, `allenai/OLMo-7B`, etc.) - OLMo (`allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc.)
- OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.) - OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.)
- Orion (`OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc.) - Orion (`OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc.)
- Phi (`microsoft/phi-1_5`, `microsoft/phi-2`, etc.) - Phi (`microsoft/phi-1_5`, `microsoft/phi-2`, etc.)

View File

@ -101,7 +101,7 @@ Alongside each architecture, we include some popular models that use it.
- -
* - :code:`OLMoForCausalLM` * - :code:`OLMoForCausalLM`
- OLMo - OLMo
- :code:`allenai/OLMo-1B`, :code:`allenai/OLMo-7B`, etc. - :code:`allenai/OLMo-1B-hf`, :code:`allenai/OLMo-7B-hf`, etc.
- -
* - :code:`OPTForCausalLM` * - :code:`OPTForCausalLM`
- OPT, OPT-IML - OPT, OPT-IML

View File

@ -26,7 +26,6 @@ requests
ray ray
peft peft
awscli awscli
ai2-olmo # required for OLMo
# Benchmarking # Benchmarking
aiohttp aiohttp

View File

@ -42,7 +42,7 @@ _MODELS = {
"MptForCausalLM": ("mpt", "MPTForCausalLM"), "MptForCausalLM": ("mpt", "MPTForCausalLM"),
"MPTForCausalLM": ("mpt", "MPTForCausalLM"), "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
"MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"), "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
"OLMoForCausalLM": ("olmo", "OLMoForCausalLM"), "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
"OPTForCausalLM": ("opt", "OPTForCausalLM"), "OPTForCausalLM": ("opt", "OPTForCausalLM"),
"OrionForCausalLM": ("orion", "OrionForCausalLM"), "OrionForCausalLM": ("orion", "OrionForCausalLM"),
"PhiForCausalLM": ("phi", "PhiForCausalLM"), "PhiForCausalLM": ("phi", "PhiForCausalLM"),

View File

@ -1,53 +1,36 @@
# coding=utf-8 # coding=utf-8
# Adapted from # Adapted from
# https://github.com/allenai/OLMo/blob/v0.2.4/olmo/model.py and # https://github.com/huggingface/transformers/blob/v4.40.1/src/transformers/models/olmo/modeling_olmo.py
# https://github.com/allenai/OLMo/blob/v0.2.4/hf_olmo/modeling_olmo.py # Copyright 2024 The vLLM team.
# Copyright 2023 The vLLM team. # Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# #
# BSD 3-Clause License # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
# #
# Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu. # Licensed under the Apache License, Version 2.0 (the "License");
# All rights reserved. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# #
# Redistribution and use in source and binary forms, with or without # http://www.apache.org/licenses/LICENSE-2.0
# modification, are permitted provided that the following conditions are met:
# #
# * Redistributions of source code must retain the above copyright notice, this # Unless required by applicable law or agreed to in writing, software
# list of conditions and the following disclaimer. # distributed under the License is distributed on an "AS IS" BASIS,
# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# * Redistributions in binary form must reproduce the above copyright notice, # See the License for the specific language governing permissions and
# this list of conditions and the following disclaimer in the documentation # limitations under the License.
# and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Inference-only OLMo model compatible with HuggingFace weights.""" """Inference-only OLMo model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple from typing import Iterable, List, Optional, Tuple
import torch import torch
# this model must need this dependency
from hf_olmo import OLMoConfig
from torch import nn from torch import nn
from transformers import OlmoConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (LinearMethodBase,
LinearMethodBase,
MergedColumnParallelLinear, MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
@ -55,7 +38,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
@ -70,55 +53,52 @@ class OlmoAttention(nn.Module):
def __init__( def __init__(
self, self,
config: OLMoConfig, config: OlmoConfig,
linear_method: Optional[LinearMethodBase] = None, linear_method: Optional[LinearMethodBase] = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
self.hidden_size = config.d_model self.hidden_size = config.hidden_size
assert config.d_model % config.n_heads == 0
tensor_model_parallel_world_size = ( tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size()) get_tensor_model_parallel_world_size())
self.total_num_heads = self.config.n_heads self.total_num_heads = config.num_attention_heads
assert self.hidden_size % self.total_num_heads == 0
assert self.total_num_heads % tensor_model_parallel_world_size == 0 assert self.total_num_heads % tensor_model_parallel_world_size == 0
self.num_heads = (self.total_num_heads // self.num_heads = (self.total_num_heads //
tensor_model_parallel_world_size) tensor_model_parallel_world_size)
self.head_dim = self.hidden_size // self.total_num_heads self.head_dim = self.hidden_size // self.total_num_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.clip_qkv = config.clip_qkv
# Layer norms.
self.attn_norm = nn.LayerNorm(config.d_model,
elementwise_affine=False,
bias=False)
# Attention input projection. Projects x -> (q, k, v) # Attention input projection. Projects x -> (q, k, v)
self.att_proj = QKVParallelLinear( self.qkv_proj = QKVParallelLinear(
config.d_model, self.hidden_size,
self.head_dim, self.head_dim,
self.total_num_heads, self.total_num_heads,
bias=config.include_bias, bias=config.attention_bias,
linear_method=linear_method, linear_method=linear_method,
) )
# Rotary embeddings. # Rotary embeddings.
if self.config.rope: self.rotary_emb = get_rope(
rope_theta = getattr(config, "rope_theta", 10000) self.head_dim,
max_position_embeddings = getattr(config, rotary_dim=self.head_dim,
"max_position_embeddings", 8192) max_position=self.max_position_embeddings,
self.rotary_emb = get_rope( base=self.rope_theta,
self.head_dim, )
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=rope_theta,
)
self.scaling = self.head_dim**-0.5 self.scaling = self.head_dim**-0.5
self.attn = Attention(self.num_heads, self.attn = Attention(self.num_heads,
self.head_dim, self.head_dim,
scale=self.scaling) scale=self.scaling)
# Attention output projection. # Attention output projection.
self.attn_out = RowParallelLinear( self.o_proj = RowParallelLinear(
config.d_model, self.hidden_size,
config.d_model, self.hidden_size,
bias=config.include_bias, bias=config.attention_bias,
linear_method=linear_method, linear_method=linear_method,
) )
@ -129,13 +109,13 @@ class OlmoAttention(nn.Module):
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.attn_norm(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
qkv, _ = self.att_proj(hidden_states) if self.clip_qkv is not None:
qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
q, k, v = qkv.chunk(chunks=3, dim=-1) q, k, v = qkv.chunk(chunks=3, dim=-1)
if self.config.rope: q, k = self.rotary_emb(positions, q, k)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.attn_out(attn_output) output, _ = self.o_proj(attn_output)
return output return output
@ -148,37 +128,30 @@ class OlmoMLP(nn.Module):
def __init__( def __init__(
self, self,
config: OLMoConfig, config: OlmoConfig,
linear_method: Optional[LinearMethodBase] = None, linear_method: Optional[LinearMethodBase] = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
self.hidden_size = (config.mlp_hidden_size if config.mlp_hidden_size self.hidden_size = config.hidden_size
is not None else config.mlp_ratio * config.d_model) self.intermediate_size = config.intermediate_size
# Layer norms.
self.ff_norm = nn.LayerNorm(config.d_model,
elementwise_affine=False,
bias=False)
# Feed-forward input projection. # Feed-forward input projection.
self.ff_proj = MergedColumnParallelLinear( self.gate_up_proj = MergedColumnParallelLinear(
config.d_model, self.hidden_size,
[self.hidden_size // 2] * 2, [self.intermediate_size] * 2,
bias=config.include_bias, bias=False,
linear_method=linear_method, linear_method=linear_method,
) )
# Activation function. # Activation function.
self.act = SiluAndMul() self.act_fn = SiluAndMul()
self.act.output_multiplier = 0.5
assert (self.act.output_multiplier * self.hidden_size) % 1 == 0
# Feed-forward output projection. # Feed-forward output projection.
self.ff_out = RowParallelLinear( self.down_proj = RowParallelLinear(
int(self.act.output_multiplier * self.hidden_size), self.intermediate_size,
config.d_model, self.hidden_size,
bias=config.include_bias, bias=False,
linear_method=linear_method, linear_method=linear_method,
) )
@ -186,19 +159,13 @@ class OlmoMLP(nn.Module):
self, self,
x: torch.Tensor, x: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
# Add feed-forward projection. gate_up, _ = self.gate_up_proj(x)
# shape: (batch_size, seq_len, d_model) x = self.act_fn(gate_up)
og_x = x x, _ = self.down_proj(x)
x = self.ff_norm(x)
x, _ = self.ff_proj(x)
x = self.act(x)
x, _ = self.ff_out(x)
x = og_x + x
return x return x
class OlmoBlock(nn.Module): class OlmoDecoderLayer(nn.Module):
""" """
This is a typical transformer block where the output is This is a typical transformer block where the output is
computed as ``MLP(LN(x + Attention(LN(x))))`` computed as ``MLP(LN(x + Attention(LN(x))))``
@ -206,15 +173,23 @@ class OlmoBlock(nn.Module):
""" """
def __init__(self, def __init__(self,
config: OLMoConfig, config: OlmoConfig,
linear_method: Optional[LinearMethodBase] = None): linear_method: Optional[LinearMethodBase] = None):
super().__init__() super().__init__()
# Attention block. # Attention block.
self.attn = OlmoAttention(config, linear_method) self.self_attn = OlmoAttention(config, linear_method)
# MLP block. # MLP block.
self.mlp = OlmoMLP(config, linear_method) self.mlp = OlmoMLP(config, linear_method)
# LayerNorm
self.input_layernorm = nn.LayerNorm(config.hidden_size,
elementwise_affine=False,
bias=False)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
elementwise_affine=False,
bias=False)
def forward( def forward(
self, self,
positions: torch.Tensor, positions: torch.Tensor,
@ -223,52 +198,37 @@ class OlmoBlock(nn.Module):
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
# Attention block. # Attention block.
og_x = hidden_states residual = hidden_states
x = self.attn(positions, hidden_states, kv_cache, attn_metadata) hidden_states = self.input_layernorm(hidden_states)
x = x + og_x hidden_states = self.self_attn(positions, hidden_states, kv_cache,
attn_metadata)
hidden_states = hidden_states + residual
# MLP block. # MLP block.
hidden_states = self.mlp(x) residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states return hidden_states
class OlmoModel(nn.Module): class OlmoModel(nn.Module):
def __init__(self, def __init__(self,
config: OLMoConfig, config: OlmoConfig,
linear_method: Optional[LinearMethodBase] = None): linear_method: Optional[LinearMethodBase] = None):
super().__init__() super().__init__()
self.config = config self.config = config
self.transformer = nn.ModuleDict( self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
dict( config.hidden_size)
wte=VocabParallelEmbedding( self.layers = nn.ModuleList([
config.embedding_size or config.vocab_size, OlmoDecoderLayer(config, linear_method)
config.d_model, for layer_idx in range(config.num_hidden_layers)
), ])
ln_f=nn.LayerNorm(config.d_model, self.norm = nn.LayerNorm(config.hidden_size,
elementwise_affine=False, elementwise_affine=False,
bias=False), bias=False)
))
blocks = [
OlmoBlock(config, linear_method) for i in range(config.n_layers)
]
if self.config.block_group_size > 1:
raise NotImplementedError("Block group size > 1 not supported yet")
else:
self.transformer.update({"blocks": nn.ModuleList(blocks)})
if not config.weight_tying:
self.transformer.update({
"ff_out":
ColumnParallelLinear(
config.d_model,
config.embedding_size or config.vocab_size,
bias=config.include_bias,
linear_method=linear_method,
)
})
def forward( def forward(
self, self,
@ -282,39 +242,49 @@ class OlmoModel(nn.Module):
""" """
# Get embeddings of input. # Get embeddings of input.
# shape: (batch_size, seq_len, d_model) # shape: (batch_size, seq_len, d_model)
x = self.transformer.wte(input_ids) # type: ignore inputs_embeds = self.embed_tokens(input_ids)
# embed positions
hidden_states = inputs_embeds
# Apply blocks one-by-one. # Apply blocks one-by-one.
for block_idx, block in enumerate(self.transformer.blocks): for layer_idx, decoder_layer in enumerate(self.layers):
# shape: (batch_size, seq_len, d_model) # shape: (batch_size, seq_len, d_model)
x = block( hidden_states = decoder_layer(
positions, positions,
x, hidden_states,
kv_caches[block_idx], kv_caches[layer_idx],
attn_metadata, attn_metadata,
) )
# Apply final layer norm. # Apply final layer norm.
# shape: (batch_size, seq_len or 1, d_model) # shape: (batch_size, seq_len or 1, d_model)
x = self.transformer.ln_f(x) # type: ignore hidden_states = self.norm(hidden_states)
return x return hidden_states
class OLMoForCausalLM(nn.Module): class OlmoForCausalLM(nn.Module):
""" """
Extremely barebones HF model wrapper. Extremely barebones HF model wrapper.
""" """
def __init__(self, def __init__(self,
config: OLMoConfig, config: OlmoConfig,
linear_method: Optional[LinearMethodBase] = None): linear_method: Optional[LinearMethodBase] = None):
super().__init__() super().__init__()
self.config = config self.config = config
self.linear_method = linear_method self.linear_method = linear_method
self.model = OlmoModel(config, linear_method) self.model = OlmoModel(config, linear_method)
self.lm_head_weight = (self.model.transformer.wte.weight if config.tie_word_embeddings:
if config.weight_tying else self.lm_head_weight = self.model.embed_tokens.weight
self.model.transformer.ff_out.weight) else:
self.unpadded_vocab_size = config.vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
)
self.lm_head_weight = self.lm_head.weight
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
@ -348,20 +318,39 @@ class OLMoForCausalLM(nn.Module):
return next_tokens return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters(remove_duplicate=False)) params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in weights: for name, loaded_weight in weights:
# attention if "rotary_emb.inv_freq" in name:
if ".att" in name: continue
name = name.replace(".att", ".attn.att") if ("rotary_emb.cos_cached" in name
# mlp or "rotary_emb.sin_cached" in name):
if ".ff_proj" in name: # Models trained using ColossalAI may include these tensors in
name = name.replace(".ff_proj", ".mlp.ff_proj") # the checkpoint. Skip them.
# Reverse the weight for the MergeColumnParallelLinear continue
loaded_weight = torch.concat(loaded_weight.chunk(2)[::-1]) for (param_name, weight_name, shard_id) in stacked_params_mapping:
if ".ff_out" in name and "transformer.ff_out" not in name: if weight_name not in name:
name = name.replace(".ff_out", ".mlp.ff_out") continue
# there is no bias in olmo name = name.replace(weight_name, param_name)
param = params_dict[name] # Skip loading extra bias for GPTQ models.
weight_loader = getattr(param, "weight_loader", if name.endswith(".bias") and name not in params_dict:
default_weight_loader) continue
weight_loader(param, loaded_weight) param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)