Support OLMo models. (#2832)
This commit is contained in: parent a61f0521b8, commit ab3a5a8259
@@ -70,6 +70,7 @@ vLLM seamlessly supports many Hugging Face models, including the following architectures:
 - Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.)
 - Mixtral (`mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, etc.)
 - MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.)
+- OLMo (`allenai/OLMo-1B`, `allenai/OLMo-7B`, etc.)
 - OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.)
 - Phi (`microsoft/phi-1_5`, `microsoft/phi-2`, etc.)
 - Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.)
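With this entry in the supported-model list, an OLMo checkpoint can be run through vLLM's usual offline inference API. A minimal sketch, assuming vLLM is installed at this commit and the `allenai/OLMo-1B` weights are downloadable (`trust_remote_code` is shown as a precaution for the custom tokenizer and may not be strictly required):

    from vllm import LLM, SamplingParams

    # Load the 1B OLMo checkpoint through vLLM's offline inference API.
    llm = LLM(model="allenai/OLMo-1B", trust_remote_code=True)
    sampling_params = SamplingParams(temperature=0.0, max_tokens=64)

    # Greedy generation for a single prompt.
    outputs = llm.generate(["Language models are"], sampling_params)
    print(outputs[0].outputs[0].text)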
@@ -62,6 +62,9 @@ Alongside each architecture, we include some popular models that use it.
   * - :code:`MPTForCausalLM`
     - MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter
     - :code:`mosaicml/mpt-7b`, :code:`mosaicml/mpt-7b-storywriter`, :code:`mosaicml/mpt-30b`, etc.
+  * - :code:`OLMoForCausalLM`
+    - OLMo
+    - :code:`allenai/OLMo-1B`, :code:`allenai/OLMo-7B`, etc.
   * - :code:`OPTForCausalLM`
     - OPT, OPT-IML
     - :code:`facebook/opt-66b`, :code:`facebook/opt-iml-max-30b`, etc.
@@ -5,11 +5,20 @@ Run `pytest tests/models/test_models.py --forked`.
 import pytest
 
 MODELS = [
-    "facebook/opt-125m", "meta-llama/Llama-2-7b-hf",
-    "mistralai/Mistral-7B-v0.1", "Deci/DeciLM-7b", "tiiuae/falcon-7b", "gpt2",
-    "bigcode/tiny_starcoder_py", "EleutherAI/gpt-j-6b",
-    "EleutherAI/pythia-70m", "bigscience/bloom-560m", "mosaicml/mpt-7b",
-    "microsoft/phi-2", "stabilityai/stablelm-3b-4e1t"
+    "facebook/opt-125m",
+    "meta-llama/Llama-2-7b-hf",
+    "mistralai/Mistral-7B-v0.1",
+    "Deci/DeciLM-7b",
+    "tiiuae/falcon-7b",
+    "gpt2",
+    "bigcode/tiny_starcoder_py",
+    "EleutherAI/gpt-j-6b",
+    "EleutherAI/pythia-70m",
+    "bigscience/bloom-560m",
+    "mosaicml/mpt-7b",
+    "microsoft/phi-2",
+    "stabilityai/stablelm-3b-4e1t",
+    "allenai/OLMo-1B",
 ]
 
 
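To exercise only the new entry, the parametrized test id can be selected by substring; a sketch of the invocation (the `--forked` flag comes from the file's own docstring):

    pytest tests/models/test_models.py --forked -k "OLMo"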
@@ -35,6 +35,7 @@ _MODELS = {
     # transformers's mpt class has lower case
     "MptForCausalLM": ("mpt", "MPTForCausalLM"),
     "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
+    "OLMoForCausalLM": ("olmo", "OLMoForCausalLM"),
     "OPTForCausalLM": ("opt", "OPTForCausalLM"),
     "PhiForCausalLM": ("phi", "PhiForCausalLM"),
     "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
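The registry maps a Hugging Face architecture name to a (module, class) pair under vllm/model_executor/models. A rough sketch of how such a lookup can be resolved with importlib; the helper name and the standalone `_MODELS` copy here are illustrative, not vLLM's actual loader code:

    import importlib
    from typing import Type

    # Illustrative copy of the mapping entry added by this commit.
    _MODELS = {
        "OLMoForCausalLM": ("olmo", "OLMoForCausalLM"),
    }

    def resolve_model_class(architecture: str) -> Type:
        """Hypothetical helper: import the module and fetch the model class."""
        module_name, class_name = _MODELS[architecture]
        module = importlib.import_module(
            f"vllm.model_executor.models.{module_name}")
        return getattr(module, class_name)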
vllm/model_executor/models/olmo.py (new file, 378 lines)
@@ -0,0 +1,378 @@
# coding=utf-8
# Adapted from
# https://github.com/allenai/OLMo/blob/v0.2.4/olmo/model.py and
# https://github.com/allenai/OLMo/blob/v0.2.4/hf_olmo/modeling_olmo.py
# Copyright 2023 The vLLM team.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
#
# BSD 3-Clause License
#
# Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Inference-only OLMo model compatible with HuggingFace weights."""
from typing import List, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import nn

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    LinearMethodBase,
    QKVParallelLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_world_size, )
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (
    default_weight_loader,
    hf_model_weights_iterator,
)
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.olmo import OLMoConfig

KVCache = Tuple[torch.Tensor, torch.Tensor]

class SwiGLU(nn.Module):

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x, gate = x.chunk(2, dim=-1)
        return F.silu(gate) * x

    @property
    def output_multiplier(self) -> float:
        return 0.5


class OlmoAttention(nn.Module):
    """
    This is the attention block where the output is computed as ``Attention(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))``
    (plus another skip connection).
    """

    def __init__(
        self,
        config: OLMoConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.hidden_size = config.d_model
        assert config.d_model % config.n_heads == 0
        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
        )
        self.total_num_heads = self.config.n_heads
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = self.total_num_heads // tensor_model_parallel_world_size
        self.head_dim = self.hidden_size // self.total_num_heads

        # Layer norms.
        self.attn_norm = nn.LayerNorm(config.d_model,
                                      elementwise_affine=False,
                                      bias=False)
        # Attention input projection. Projects x -> (q, k, v)
        self.att_proj = QKVParallelLinear(
            config.d_model,
            self.head_dim,
            self.total_num_heads,
            bias=config.include_bias,
            linear_method=linear_method,
        )

        # Rotary embeddings.
        if self.config.rope:
            rope_theta = getattr(config, "rope_theta", 10000)
            max_position_embeddings = getattr(config,
                                              "max_position_embeddings", 8192)
            self.rotary_emb = get_rope(
                self.head_dim,
                rotary_dim=self.head_dim,
                max_position=max_position_embeddings,
                base=rope_theta,
            )
        self.scaling = self.head_dim**-0.5
        self.attn = PagedAttention(self.num_heads,
                                   self.head_dim,
                                   scale=self.scaling)

        # Attention output projection.
        self.attn_out = RowParallelLinear(
            config.d_model,
            config.d_model,
            bias=config.include_bias,
            linear_method=linear_method,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        hidden_states = self.attn_norm(hidden_states)
        qkv, _ = self.att_proj(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        if self.config.rope:
            q, k = self.rotary_emb(positions, q, k)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
        output, _ = self.attn_out(attn_output)
        return output


class OlmoMLP(nn.Module):
    """
    This is the MLP block where the output is computed as ``MLP(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))``
    (plus another skip connection).
    """

    def __init__(
        self,
        config: OLMoConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.hidden_size = (config.mlp_hidden_size if config.mlp_hidden_size
                            is not None else config.mlp_ratio * config.d_model)

        # Layer norms.
        self.ff_norm = nn.LayerNorm(config.d_model,
                                    elementwise_affine=False,
                                    bias=False)

        # Feed-forward input projection.
        self.ff_proj = ColumnParallelLinear(
            config.d_model,
            self.hidden_size,
            bias=config.include_bias,
            linear_method=linear_method,
        )

        # Activation function.
        # self.act = SiluAndMul()
        # self.act.output_multiplier = 0.5
        self.act = SwiGLU()
        assert (self.act.output_multiplier * self.hidden_size) % 1 == 0

        # Feed-forward output projection.
        self.ff_out = RowParallelLinear(
            int(self.act.output_multiplier * self.hidden_size),
            config.d_model,
            bias=config.include_bias,
            linear_method=linear_method,
        )

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        # Add feed-forward projection.
        # shape: (batch_size, seq_len, d_model)
        og_x = x
        x = self.ff_norm(x)
        x, _ = self.ff_proj(x)
        x = self.act(x)
        x, _ = self.ff_out(x)
        x = og_x + x

        return x

class OlmoBlock(nn.Module):
    """
    This is a typical transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))``
    (plus another skip connection).
    """

    def __init__(self,
                 config: OLMoConfig,
                 linear_method: Optional[LinearMethodBase] = None):
        super().__init__()
        # Attention block.
        self.attn = OlmoAttention(config, linear_method)

        # MLP block.
        self.mlp = OlmoMLP(config, linear_method)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        # Attention block.
        og_x = hidden_states
        x = self.attn(positions, hidden_states, kv_cache, input_metadata)
        x = x + og_x

        # MLP block.
        hidden_states = self.mlp(x)
        return hidden_states


class OlmoModel(nn.Module):

    def __init__(self,
                 config: OLMoConfig,
                 linear_method: Optional[LinearMethodBase] = None):
        super().__init__()
        self.config = config

        self.transformer = nn.ModuleDict(
            dict(
                wte=VocabParallelEmbedding(
                    config.embedding_size or config.vocab_size,
                    config.d_model,
                ),
                ln_f=nn.LayerNorm(config.d_model,
                                  elementwise_affine=False,
                                  bias=False),
            ))

        blocks = [
            OlmoBlock(config, linear_method) for i in range(config.n_layers)
        ]
        if self.config.block_group_size > 1:
            raise NotImplementedError("Block group size > 1 not supported yet")
        else:
            self.transformer.update({"blocks": nn.ModuleList(blocks)})

        if not config.weight_tying:
            self.transformer.update({
                "ff_out":
                ColumnParallelLinear(
                    config.d_model,
                    config.embedding_size or config.vocab_size,
                    bias=config.include_bias,
                    linear_method=linear_method,
                )
            })

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """
        :param input_ids: A tensor of shape `(batch_size, seq_len)`.
        """
        # Get embeddings of input.
        # shape: (batch_size, seq_len, d_model)
        x = self.transformer.wte(input_ids)  # type: ignore

        # Apply blocks one-by-one.
        for block_idx, block in enumerate(self.transformer.blocks):
            # shape: (batch_size, seq_len, d_model)
            x = block(
                positions,
                x,
                kv_caches[block_idx],
                input_metadata,
            )

        # Apply final layer norm.
        # shape: (batch_size, seq_len or 1, d_model)
        x = self.transformer.ln_f(x)  # type: ignore
        return x


class OLMoForCausalLM(nn.Module):
    """
    Extremely barebones HF model wrapper.
    """

    def __init__(self,
                 config: OLMoConfig,
                 linear_method: Optional[LinearMethodBase] = None):
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = OlmoModel(config, linear_method)
        self.lm_head_weight = (self.model.transformer.wte.weight
                               if config.weight_tying else
                               self.model.transformer.ff_out.weight)
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        hidden_states = self.model(
            input_ids=input_ids,
            positions=positions,
            kv_caches=kv_caches,
            input_metadata=input_metadata,
        )
        return hidden_states

    def sample(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(self.lm_head_weight, hidden_states,
                                   sampling_metadata)
        return next_tokens

    def load_weights(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        load_format: str = "auto",
        revision: Optional[str] = None,
    ):
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            # attention
            if ".att" in name:
                name = name.replace(".att", ".attn.att")
            # mlp
            if ".ff" in name and "transformer.ff_out" not in name:
                name = name.replace(".ff", ".mlp.ff")
            # there is no bias in olmo
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
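One detail of the MLP above worth calling out: `ff_proj` projects to the full `mlp_hidden_size` (value and gate halves together), `SwiGLU` splits that projection in two, so `ff_out` only consumes `output_multiplier * hidden_size` features. A self-contained sketch of that shape flow, with plain `nn.Linear` standing in for the parallel linear layers and purely illustrative sizes:

    import torch
    import torch.nn.functional as F
    from torch import nn

    d_model, mlp_hidden_size = 2048, 16384  # illustrative sizes, not a real checkpoint's

    class SwiGLU(nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # Split the projection into value and gate halves, then gate the value.
            x, gate = x.chunk(2, dim=-1)
            return F.silu(gate) * x

    ff_proj = nn.Linear(d_model, mlp_hidden_size, bias=False)
    act = SwiGLU()
    # ff_out consumes only half of mlp_hidden_size, matching output_multiplier = 0.5.
    ff_out = nn.Linear(mlp_hidden_size // 2, d_model, bias=False)

    x = torch.randn(1, 4, d_model)
    h = act(ff_proj(x))
    assert h.shape[-1] == mlp_hidden_size // 2
    assert ff_out(h).shape[-1] == d_model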
@@ -1,6 +1,7 @@
 from vllm.transformers_utils.configs.baichuan import BaiChuanConfig
 from vllm.transformers_utils.configs.chatglm import ChatGLMConfig
 from vllm.transformers_utils.configs.mpt import MPTConfig
+from vllm.transformers_utils.configs.olmo import OLMoConfig
 from vllm.transformers_utils.configs.qwen import QWenConfig
 # RWConfig is for the original tiiuae/falcon-40b(-instruct) and
 # tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
@@ -11,6 +12,7 @@ __all__ = [
     "BaiChuanConfig",
     "ChatGLMConfig",
     "MPTConfig",
+    "OLMoConfig",
     "QWenConfig",
     "RWConfig",
 ]
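With the export in place, the new config is importable from the package root. A minimal sketch (the default field values mirror the base GPT-2 sizes, per the comment in the file below):

    from vllm.transformers_utils.configs import OLMoConfig

    cfg = OLMoConfig()
    print(cfg.d_model, cfg.n_heads, cfg.n_layers)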
vllm/transformers_utils/configs/olmo.py (new file, 72 lines)
@@ -0,0 +1,72 @@
# coding=utf-8
# adapted from https://github.com/allenai/OLMo/blob/v0.2.4/hf_olmo/configuration_olmo.py
"""OLMo configuration"""
from transformers import PretrainedConfig


class OLMoConfig(PretrainedConfig):
    model_type = 'olmo'
    attribute_map = {
        'num_attention_heads': 'n_heads',
        'hidden_size': 'd_model',
        'num_hidden_layers': 'n_layers',
    }

    # Note that the defaults for these attributes are equivalent to the base GPT2 model.
    def __init__(
        self,
        d_model=768,
        n_heads=12,
        n_layers=12,
        mlp_ratio=4,
        mlp_hidden_size=None,
        activation_type="swiglu",
        block_type="sequential",
        block_group_size=1,
        alibi=False,
        alibi_bias_max=8.0,
        rope=False,
        rope_full_precision=True,
        multi_query_attention=False,
        attention_layer_norm=False,
        layer_norm_type="default",
        layer_norm_with_affine=True,
        attention_layer_norm_with_affine=True,
        max_sequence_length=1024,
        include_bias=True,
        bias_for_layer_norm=None,
        scale_logits=False,
        vocab_size=50257,
        embedding_size=50304,
        weight_tying=True,
        eos_token_id=50256,
        pad_token_id=50256,
        **kwargs,
    ):
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.mlp_ratio = mlp_ratio
        self.mlp_hidden_size = mlp_hidden_size
        self.activation_type = activation_type
        self.block_type = block_type
        self.block_group_size = block_group_size
        self.alibi = alibi
        self.alibi_bias_max = alibi_bias_max
        self.rope = rope
        self.rope_full_precision = rope_full_precision
        self.multi_query_attention = multi_query_attention
        self.attention_layer_norm = attention_layer_norm
        self.layer_norm_type = layer_norm_type
        self.layer_norm_with_affine = layer_norm_with_affine
        self.attention_layer_norm_with_affine = attention_layer_norm_with_affine
        self.max_sequence_length = max_sequence_length
        self.include_bias = include_bias
        self.bias_for_layer_norm = bias_for_layer_norm
        self.scale_logits = scale_logits
        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        self.weight_tying = weight_tying
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        super().__init__(**kwargs)
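Because `attribute_map` is a standard `PretrainedConfig` mechanism, the rest of vLLM can keep reading `hidden_size`, `num_attention_heads`, and `num_hidden_layers` even though OLMo names them `d_model`, `n_heads`, and `n_layers`. A minimal sketch of that aliasing with a stand-in config class (the class name and values are illustrative, not a real checkpoint's):

    from transformers import PretrainedConfig

    class TinyOLMoLikeConfig(PretrainedConfig):
        # Same aliasing trick as OLMoConfig above.
        attribute_map = {
            "num_attention_heads": "n_heads",
            "hidden_size": "d_model",
            "num_hidden_layers": "n_layers",
        }

        def __init__(self, d_model=2048, n_heads=16, n_layers=16, **kwargs):
            self.d_model = d_model
            self.n_heads = n_heads
            self.n_layers = n_layers
            super().__init__(**kwargs)

    cfg = TinyOLMoLikeConfig()
    assert cfg.hidden_size == cfg.d_model == 2048
    assert cfg.num_attention_heads == 16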