Support OLMo models. (#2832)

This commit is contained in:
Isotr0py 2024-02-19 13:05:15 +08:00 committed by GitHub
parent a61f0521b8
commit ab3a5a8259
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 471 additions and 5 deletions

View File

@ -70,6 +70,7 @@ vLLM seamlessly supports many Hugging Face models, including the following archi
- Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.)
- Mixtral (`mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, etc.)
- MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.)
- OLMo (`allenai/OLMo-1B`, `allenai/OLMo-7B`, etc.)
- OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.)
- Phi (`microsoft/phi-1_5`, `microsoft/phi-2`, etc.)
- Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.)

View File

@ -62,6 +62,9 @@ Alongside each architecture, we include some popular models that use it.
* - :code:`MPTForCausalLM`
- MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter
- :code:`mosaicml/mpt-7b`, :code:`mosaicml/mpt-7b-storywriter`, :code:`mosaicml/mpt-30b`, etc.
* - :code:`OLMoForCausalLM`
- OLMo
- :code:`allenai/OLMo-1B`, :code:`allenai/OLMo-7B`, etc.
* - :code:`OPTForCausalLM`
- OPT, OPT-IML
- :code:`facebook/opt-66b`, :code:`facebook/opt-iml-max-30b`, etc.

View File

@ -5,11 +5,20 @@ Run `pytest tests/models/test_models.py --forked`.
import pytest
MODELS = [
"facebook/opt-125m", "meta-llama/Llama-2-7b-hf",
"mistralai/Mistral-7B-v0.1", "Deci/DeciLM-7b", "tiiuae/falcon-7b", "gpt2",
"bigcode/tiny_starcoder_py", "EleutherAI/gpt-j-6b",
"EleutherAI/pythia-70m", "bigscience/bloom-560m", "mosaicml/mpt-7b",
"microsoft/phi-2", "stabilityai/stablelm-3b-4e1t"
"facebook/opt-125m",
"meta-llama/Llama-2-7b-hf",
"mistralai/Mistral-7B-v0.1",
"Deci/DeciLM-7b",
"tiiuae/falcon-7b",
"gpt2",
"bigcode/tiny_starcoder_py",
"EleutherAI/gpt-j-6b",
"EleutherAI/pythia-70m",
"bigscience/bloom-560m",
"mosaicml/mpt-7b",
"microsoft/phi-2",
"stabilityai/stablelm-3b-4e1t",
"allenai/OLMo-1B",
]

View File

@ -35,6 +35,7 @@ _MODELS = {
# transformers's mpt class has lower case
"MptForCausalLM": ("mpt", "MPTForCausalLM"),
"MPTForCausalLM": ("mpt", "MPTForCausalLM"),
"OLMoForCausalLM": ("olmo", "OLMoForCausalLM"),
"OPTForCausalLM": ("opt", "OPTForCausalLM"),
"PhiForCausalLM": ("phi", "PhiForCausalLM"),
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),

View File

@ -0,0 +1,378 @@
# coding=utf-8
# Adapted from
# https://github.com/allenai/OLMo/blob/v0.2.4/olmo/model.py and
# https://github.com/allenai/OLMo/blob/v0.2.4/hf_olmo/modeling_olmo.py
# Copyright 2023 The vLLM team.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
#
# BSD 3-Clause License
#
# Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Inference-only OLMo model compatible with HuggingFace weights."""
from typing import List, Optional, Tuple
import torch
import torch.nn.functional as F
from torch import nn
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
LinearMethodBase,
QKVParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size, )
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (
default_weight_loader,
hf_model_weights_iterator,
)
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.olmo import OLMoConfig
KVCache = Tuple[torch.Tensor, torch.Tensor]
class SwiGLU(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x, gate = x.chunk(2, dim=-1)
return F.silu(gate) * x
@property
def output_multiplier(self) -> float:
return 0.5
class OlmoAttention(nn.Module):
"""
This is the attention block where the output is computed as ``Attention(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))``
(plus another skip connection).
"""
def __init__(
self,
config: OLMoConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
self.config = config
self.hidden_size = config.d_model
assert config.d_model % config.n_heads == 0
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
)
self.total_num_heads = self.config.n_heads
assert self.total_num_heads % tensor_model_parallel_world_size == 0
self.num_heads = self.total_num_heads // tensor_model_parallel_world_size
self.head_dim = self.hidden_size // self.total_num_heads
# Layer norms.
self.attn_norm = nn.LayerNorm(config.d_model,
elementwise_affine=False,
bias=False)
# Attention input projection. Projects x -> (q, k, v)
self.att_proj = QKVParallelLinear(
config.d_model,
self.head_dim,
self.total_num_heads,
bias=config.include_bias,
linear_method=linear_method,
)
# Rotary embeddings.
if self.config.rope:
rope_theta = getattr(config, "rope_theta", 10000)
max_position_embeddings = getattr(config,
"max_position_embeddings", 8192)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=rope_theta,
)
self.scaling = self.head_dim**-0.5
self.attn = PagedAttention(self.num_heads,
self.head_dim,
scale=self.scaling)
# Attention output projection.
self.attn_out = RowParallelLinear(
config.d_model,
config.d_model,
bias=config.include_bias,
linear_method=linear_method,
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
) -> torch.Tensor:
hidden_states = self.attn_norm(hidden_states)
qkv, _ = self.att_proj(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
if self.config.rope:
q, k = self.rotary_emb(positions, q, k)
k_cache, v_cache = kv_cache
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
output, _ = self.attn_out(attn_output)
return output
class OlmoMLP(nn.Module):
"""
This is the MLP block where the output is computed as ``MLP(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))``
(plus another skip connection).
"""
def __init__(
self,
config: OLMoConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
self.config = config
self.hidden_size = (config.mlp_hidden_size if config.mlp_hidden_size
is not None else config.mlp_ratio * config.d_model)
# Layer norms.
self.ff_norm = nn.LayerNorm(config.d_model,
elementwise_affine=False,
bias=False)
# Feed-forward input projection.
self.ff_proj = ColumnParallelLinear(
config.d_model,
self.hidden_size,
bias=config.include_bias,
linear_method=linear_method,
)
# Activation function.
# self.act = SiluAndMul()
# self.act.output_multiplier = 0.5
self.act = SwiGLU()
assert (self.act.output_multiplier * self.hidden_size) % 1 == 0
# Feed-forward output projection.
self.ff_out = RowParallelLinear(
int(self.act.output_multiplier * self.hidden_size),
config.d_model,
bias=config.include_bias,
linear_method=linear_method,
)
def forward(
self,
x: torch.Tensor,
) -> torch.Tensor:
# Add feed-forward projection.
# shape: (batch_size, seq_len, d_model)
og_x = x
x = self.ff_norm(x)
x, _ = self.ff_proj(x)
x = self.act(x)
x, _ = self.ff_out(x)
x = og_x + x
return x
class OlmoBlock(nn.Module):
"""
This is a typical transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))``
(plus another skip connection).
"""
def __init__(self,
config: OLMoConfig,
linear_method: Optional[LinearMethodBase] = None):
super().__init__()
# Attention block.
self.attn = OlmoAttention(config, linear_method)
# MLP block.
self.mlp = OlmoMLP(config, linear_method)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
# Attention block.
og_x = hidden_states
x = self.attn(positions, hidden_states, kv_cache, input_metadata)
x = x + og_x
# MLP block.
hidden_states = self.mlp(x)
return hidden_states
class OlmoModel(nn.Module):
def __init__(self,
config: OLMoConfig,
linear_method: Optional[LinearMethodBase] = None):
super().__init__()
self.config = config
self.transformer = nn.ModuleDict(
dict(
wte=VocabParallelEmbedding(
config.embedding_size or config.vocab_size,
config.d_model,
),
ln_f=nn.LayerNorm(config.d_model,
elementwise_affine=False,
bias=False),
))
blocks = [
OlmoBlock(config, linear_method) for i in range(config.n_layers)
]
if self.config.block_group_size > 1:
raise NotImplementedError("Block group size > 1 not supported yet")
else:
self.transformer.update({"blocks": nn.ModuleList(blocks)})
if not config.weight_tying:
self.transformer.update({
"ff_out":
ColumnParallelLinear(
config.d_model,
config.embedding_size or config.vocab_size,
bias=config.include_bias,
linear_method=linear_method,
)
})
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
) -> torch.Tensor:
"""
:param input_ids: A tensor of shape `(batch_size, seq_len)`.
"""
# Get embeddings of input.
# shape: (batch_size, seq_len, d_model)
x = self.transformer.wte(input_ids) # type: ignore
# Apply blocks one-by-one.
for block_idx, block in enumerate(self.transformer.blocks):
# shape: (batch_size, seq_len, d_model)
x = block(
positions,
x,
kv_caches[block_idx],
input_metadata,
)
# Apply final layer norm.
# shape: (batch_size, seq_len or 1, d_model)
x = self.transformer.ln_f(x) # type: ignore
return x
class OLMoForCausalLM(nn.Module):
"""
Extremely barebones HF model wrapper.
"""
def __init__(self,
config: OLMoConfig,
linear_method: Optional[LinearMethodBase] = None):
super().__init__()
self.config = config
self.linear_method = linear_method
self.model = OlmoModel(config, linear_method)
self.lm_head_weight = (self.model.transformer.wte.weight
if config.weight_tying else
self.model.transformer.ff_out.weight)
self.sampler = Sampler(config.vocab_size)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
) -> torch.Tensor:
hidden_states = self.model(
input_ids=input_ids,
positions=positions,
kv_caches=kv_caches,
input_metadata=input_metadata,
)
return hidden_states
def sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
sampling_metadata)
return next_tokens
def load_weights(
self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None,
):
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision):
# attention
if ".att" in name:
name = name.replace(".att", ".attn.att")
# mlp
if ".ff" in name and "transformer.ff_out" not in name:
name = name.replace(".ff", ".mlp.ff")
# there is no bias in olmo
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

View File

@ -1,6 +1,7 @@
from vllm.transformers_utils.configs.baichuan import BaiChuanConfig
from vllm.transformers_utils.configs.chatglm import ChatGLMConfig
from vllm.transformers_utils.configs.mpt import MPTConfig
from vllm.transformers_utils.configs.olmo import OLMoConfig
from vllm.transformers_utils.configs.qwen import QWenConfig
# RWConfig is for the original tiiuae/falcon-40b(-instruct) and
# tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
@ -11,6 +12,7 @@ __all__ = [
"BaiChuanConfig",
"ChatGLMConfig",
"MPTConfig",
"OLMoConfig",
"QWenConfig",
"RWConfig",
]

View File

@ -0,0 +1,72 @@
# coding=utf-8
# adapted from https://github.com/allenai/OLMo/blob/v0.2.4/hf_olmo/configuration_olmo.py
"""OLMo configuration"""
from transformers import PretrainedConfig
class OLMoConfig(PretrainedConfig):
model_type = 'olmo'
attribute_map = {
'num_attention_heads': 'n_heads',
'hidden_size': 'd_model',
'num_hidden_layers': 'n_layers',
}
# Note that the defaults for these attributes are equivalent to the base GPT2 model.
def __init__(
self,
d_model=768,
n_heads=12,
n_layers=12,
mlp_ratio=4,
mlp_hidden_size=None,
activation_type="swiglu",
block_type="sequential",
block_group_size=1,
alibi=False,
alibi_bias_max=8.0,
rope=False,
rope_full_precision=True,
multi_query_attention=False,
attention_layer_norm=False,
layer_norm_type="default",
layer_norm_with_affine=True,
attention_layer_norm_with_affine=True,
max_sequence_length=1024,
include_bias=True,
bias_for_layer_norm=None,
scale_logits=False,
vocab_size=50257,
embedding_size=50304,
weight_tying=True,
eos_token_id=50256,
pad_token_id=50256,
**kwargs,
):
self.d_model = d_model
self.n_heads = n_heads
self.n_layers = n_layers
self.mlp_ratio = mlp_ratio
self.mlp_hidden_size = mlp_hidden_size
self.activation_type = activation_type
self.block_type = block_type
self.block_group_size = block_group_size
self.alibi = alibi
self.alibi_bias_max = alibi_bias_max
self.rope = rope
self.rope_full_precision = rope_full_precision
self.multi_query_attention = multi_query_attention
self.attention_layer_norm = attention_layer_norm
self.layer_norm_type = layer_norm_type
self.layer_norm_with_affine = layer_norm_with_affine
self.attention_layer_norm_with_affine = attention_layer_norm_with_affine
self.max_sequence_length = max_sequence_length
self.include_bias = include_bias
self.bias_for_layer_norm = bias_for_layer_norm
self.scale_logits = scale_logits
self.vocab_size = vocab_size
self.embedding_size = embedding_size
self.weight_tying = weight_tying
self.eos_token_id = eos_token_id
self.pad_token_id = pad_token_id
super().__init__(**kwargs)