From 20b0d88d1630aec3a18a80590080b0ab1d16969f Mon Sep 17 00:00:00 2001 From: codethazine <41583025+codethazine@users.noreply.github.com> Date: Mon, 17 Jul 2023 21:50:55 +0100 Subject: [PATCH] Add support for baichuan (#365) --- vllm/model_executor/model_loader.py | 1 + vllm/model_executor/models/__init__.py | 2 + vllm/model_executor/models/baichuan.py | 293 ++++++++++++++++++++ vllm/transformers_utils/config.py | 1 + vllm/transformers_utils/configs/__init__.py | 2 + vllm/transformers_utils/configs/baichuan.py | 62 +++++ 6 files changed, 361 insertions(+) create mode 100644 vllm/model_executor/models/baichuan.py create mode 100644 vllm/transformers_utils/configs/baichuan.py diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index 40e5f583cad2..a3fd24c911b1 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -11,6 +11,7 @@ from vllm.model_executor.weight_utils import initialize_dummy_weights # TODO(woosuk): Lazy-load the model classes. _MODEL_REGISTRY = { + "BaiChuanForCausalLM": BaiChuanForCausalLM, "BloomForCausalLM": BloomForCausalLM, "GPT2LMHeadModel": GPT2LMHeadModel, "GPTBigCodeForCausalLM": GPTBigCodeForCausalLM, diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 8717c5682773..c3e3e5723e53 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -1,3 +1,4 @@ +from vllm.model_executor.models.baichuan import BaiChuanForCausalLM from vllm.model_executor.models.bloom import BloomForCausalLM from vllm.model_executor.models.gpt2 import GPT2LMHeadModel from vllm.model_executor.models.gpt_bigcode import GPTBigCodeForCausalLM @@ -8,6 +9,7 @@ from vllm.model_executor.models.mpt import MPTForCausalLM from vllm.model_executor.models.opt import OPTForCausalLM __all__ = [ + "BaiChuanForCausalLM", "BloomForCausalLM", "GPT2LMHeadModel", "GPTBigCodeForCausalLM", diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py new file mode 100644 index 000000000000..3696c56a903f --- /dev/null +++ b/vllm/model_executor/models/baichuan.py @@ -0,0 +1,293 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only BaiChuan model compatible with HuggingFace weights. + +The input of the model is flattened to a 1D tensor of tokens. The model uses +InputMetadata to extract the original 2D shape of the input. +""" +from typing import Dict, List, Optional, Tuple + +import torch +from torch import nn + +from vllm.sequence import SequenceOutputs +from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.attention import PagedAttentionWithRoPE +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.weight_utils import (hf_model_weights_iterator, + load_tensor_parallel_weights) +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from vllm.model_executor.parallel_utils.tensor_parallel import ( + VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear) +from vllm.transformers_utils.configs.baichuan import BaiChuanConfig + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +class BaiChuanMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + ): + super().__init__() + self.gate_up_proj = ColumnParallelLinear(hidden_size, + 2 * intermediate_size, + bias=False, + gather_output=False, + perform_initialization=False) + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + input_is_parallel=True, + perform_initialization=False) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class BaiChuanAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + hidden_size: int, + num_heads: int, + ): + super().__init__() + self.hidden_size = hidden_size + tensor_model_parallel_world_size = get_tensor_model_parallel_world_size( + ) + self.total_num_heads = num_heads + assert self.total_num_heads % tensor_model_parallel_world_size == 0 + self.num_heads = (self.total_num_heads // + tensor_model_parallel_world_size) + self.head_dim = hidden_size // self.total_num_heads + self.scaling = self.head_dim**-0.5 + + # pylint: disable=invalid-name + self.W_pack = ColumnParallelLinear( + hidden_size, + 3 * hidden_size, + bias=False, + gather_output=False, + perform_initialization=False, + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + input_is_parallel=True, + perform_initialization=False, + ) + + self.attn = PagedAttentionWithRoPE(self.num_heads, + self.head_dim, + self.scaling, + rotary_dim=self.head_dim) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + cache_event: Optional[torch.cuda.Event], + ) -> torch.Tensor: + qkv, _ = self.W_pack(hidden_states) + q, k, v = qkv.chunk(chunks=3, dim=-1) + k_cache, v_cache = kv_cache + attn_output = self.attn(positions, q, k, v, k_cache, v_cache, + input_metadata, cache_event) + output, _ = self.o_proj(attn_output) + return output + + +class BaiChuanDecoderLayer(nn.Module): + + def __init__(self, config: BaiChuanConfig): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = BaiChuanAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + ) + self.mlp = BaiChuanMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + cache_event: Optional[torch.cuda.Event], + ) -> torch.Tensor: + # Self Attention + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + input_metadata=input_metadata, + cache_event=cache_event, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class BaiChuanModel(nn.Module): + + def __init__(self, config: BaiChuanConfig): + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + perform_initialization=False) + self.layers = nn.ModuleList([ + BaiChuanDecoderLayer(config) + for _ in range(config.num_hidden_layers) + ]) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + cache_events: Optional[List[torch.cuda.Event]], + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + for i in range(len(self.layers)): + if cache_events is None: + cache_event = None + else: + cache_event = cache_events[i] + layer = self.layers[i] + hidden_states = layer( + positions, + hidden_states, + kv_caches[i], + input_metadata, + cache_event, + ) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class BaiChuanForCausalLM(nn.Module): + + def __init__(self, config): + super().__init__() + self.config = config + self.model = BaiChuanModel(config) + self.lm_head = ColumnParallelLinear(config.hidden_size, + config.vocab_size, + bias=False, + gather_output=False, + perform_initialization=False) + self.sampler = Sampler(config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + cache_events: Optional[List[torch.cuda.Event]], + ) -> Dict[int, SequenceOutputs]: + hidden_states = self.model(input_ids, positions, kv_caches, + input_metadata, cache_events) + next_tokens = self.sampler(self.lm_head.weight, hidden_states, + input_metadata) + return next_tokens + + _column_parallel_weights = [ + "embed_tokens.weight", "lm_head.weight", "W_pack.weight", + "gate_proj.weight", "up_proj.weight" + ] + _row_parallel_weights = ["o_proj.weight", "down_proj.weight"] + + def load_weights(self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + use_np_cache: bool = False): + tensor_model_parallel_rank = get_tensor_model_parallel_rank() + state_dict = self.state_dict() + + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, use_np_cache): + if "rotary_emb.inv_freq" in name: + continue + + is_gate_up_weight = False + for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]): + if weight_name not in name: + continue + param = state_dict[name.replace(weight_name, "gate_up_proj")] + shard_size = param.shape[0] // 2 + loaded_weight = loaded_weight[ + shard_size * tensor_model_parallel_rank:shard_size * + (tensor_model_parallel_rank + 1)] + param_slice = param.data[shard_size * stride_id:shard_size * + (stride_id + 1)] + assert param_slice.shape == loaded_weight.shape + param_slice.copy_(loaded_weight) + is_gate_up_weight = True + break + if is_gate_up_weight: + continue + + param = state_dict[name] + load_tensor_parallel_weights(param, loaded_weight, name, + self._column_parallel_weights, + self._row_parallel_weights, + tensor_model_parallel_rank) diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index aeb4a6aacb53..71e67118a599 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -4,6 +4,7 @@ from vllm.transformers_utils.configs import * # pylint: disable=wildcard-import _CONFIG_REGISTRY = { "mpt": MPTConfig, + "baichuan": BaiChuanConfig, } diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index fef4caf93e43..5f0ba4eb9fea 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -1,5 +1,7 @@ from vllm.transformers_utils.configs.mpt import MPTConfig +from vllm.transformers_utils.configs.baichuan import BaiChuanConfig __all__ = [ "MPTConfig", + "BaiChuanConfig", ] diff --git a/vllm/transformers_utils/configs/baichuan.py b/vllm/transformers_utils/configs/baichuan.py new file mode 100644 index 000000000000..869817525c11 --- /dev/null +++ b/vllm/transformers_utils/configs/baichuan.py @@ -0,0 +1,62 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from transformers.configuration_utils import PretrainedConfig + + +class BaiChuanConfig(PretrainedConfig): + model_type = "baichuan" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=64000, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + hidden_act="silu", + max_position_embeddings=4096, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=False, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + )