From 298695b76691867ecd320ea6a2c6d0c6a843d5ae Mon Sep 17 00:00:00 2001
From: Michael Feil <63565275+michaelfeil@users.noreply.github.com>
Date: Thu, 22 Jun 2023 19:49:27 +0200
Subject: [PATCH] GPTBigCode (StarCoder, SantaCoder Support) (#209)

---
 vllm/model_executor/layers/activation.py  |   1 +
 vllm/model_executor/model_loader.py       |   3 +-
 vllm/model_executor/models/__init__.py    |   5 +-
 vllm/model_executor/models/gpt_bigcode.py | 291 ++++++++++++++++++++++
 4 files changed, 298 insertions(+), 2 deletions(-)
 create mode 100644 vllm/model_executor/models/gpt_bigcode.py

diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py
index a0b3f6ff653a..ce41cfe0ad11 100644
--- a/vllm/model_executor/layers/activation.py
+++ b/vllm/model_executor/layers/activation.py
@@ -8,6 +8,7 @@
 _ACTIVATION_REGISTRY = {
     "gelu": nn.GELU(),
     "gelu_new": nn.GELU(approximate="tanh"),  # NOTE: This may introduce small rounding errors.
     "gelu_fast": nn.GELU(approximate="tanh"),  # NOTE: This may introduce small rounding errors.
+    "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),  # NOTE: This may introduce small rounding errors.
     "relu": nn.ReLU(),
 }
diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py
index ec5b75d08d8b..93faa116c68a 100644
--- a/vllm/model_executor/model_loader.py
+++ b/vllm/model_executor/model_loader.py
@@ -6,13 +6,14 @@
 import torch.nn as nn
 from transformers import PretrainedConfig
 from vllm.config import ModelConfig
-from vllm.model_executor.models import (GPT2LMHeadModel, GPTNeoXForCausalLM,
+from vllm.model_executor.models import (GPT2LMHeadModel, GPTBigCodeForCausalLM, GPTNeoXForCausalLM,
                                         LlamaForCausalLM, OPTForCausalLM)
 from vllm.model_executor.weight_utils import initialize_dummy_weights

 # TODO(woosuk): Lazy-load the model classes.
 _MODEL_REGISTRY = {
     "GPT2LMHeadModel": GPT2LMHeadModel,
+    "GPTBigCodeForCausalLM": GPTBigCodeForCausalLM,
     "GPTNeoXForCausalLM": GPTNeoXForCausalLM,
     "LlamaForCausalLM": LlamaForCausalLM,
     "OPTForCausalLM": OPTForCausalLM,
diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py
index 098b7d448b60..0636b6792541 100644
--- a/vllm/model_executor/models/__init__.py
+++ b/vllm/model_executor/models/__init__.py
@@ -1,11 +1,14 @@
-from vllm.model_executor.models.gpt_neox import GPTNeoXForCausalLM
 from vllm.model_executor.models.gpt2 import GPT2LMHeadModel
+from vllm.model_executor.models.gpt_bigcode import GPTBigCodeForCausalLM
+from vllm.model_executor.models.gpt_neox import GPTNeoXForCausalLM
 from vllm.model_executor.models.llama import LlamaForCausalLM
 from vllm.model_executor.models.opt import OPTForCausalLM

+
 __all__ = [
     "GPT2LMHeadModel",
+    "GPTBigCodeForCausalLM",
     "GPTNeoXForCausalLM",
     "LlamaForCausalLM",
     "OPTForCausalLM",
 ]
diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py
new file mode 100644
index 000000000000..3bd3c6fb1898
--- /dev/null
+++ b/vllm/model_executor/models/gpt_bigcode.py
@@ -0,0 +1,291 @@
+# coding=utf-8
+# Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
+# Copyright 2023 The vLLM team.
+# Copyright 2023 CTranslate2, and Michael Feil
+# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only GPTBigCode model compatible with HuggingFace weights.
+
+The input of the model is flattened to a 1D tensor of tokens. The model uses
+InputMetadata to extract the original 2D shape of the input.
+"""
+from typing import Dict, List, Optional, Tuple
+
+import torch
+from torch import nn
+import numpy as np
+from transformers import GPTBigCodeConfig
+
+from vllm.model_executor.input_metadata import InputMetadata
+from vllm.model_executor.layers.activation import get_act_fn
+from vllm.model_executor.layers.attention import PagedAttention
+from vllm.model_executor.layers.sampler import Sampler
+from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
+                                              load_tensor_parallel_weights)
+from vllm.model_executor.parallel_utils.parallel_state import (
+    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
+from vllm.model_executor.parallel_utils.tensor_parallel import (
+    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
+from vllm.sequence import SequenceOutputs
+
+KVCache = Tuple[torch.Tensor, torch.Tensor]
+
+
+class GPTBigCodeAttention(nn.Module):
+
+    def __init__(self, config: GPTBigCodeConfig):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        total_num_heads = config.num_attention_heads
+        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
+        assert total_num_heads % tensor_model_parallel_world_size == 0
+        self.num_heads = total_num_heads // tensor_model_parallel_world_size
+        self.head_dim = self.hidden_size // total_num_heads
+        self.scale = self.head_dim ** -0.5
+
+        self.c_attn = ColumnParallelLinear(self.hidden_size, 3 * self.hidden_size,
+                                           bias=True, gather_output=False,
+                                           perform_initialization=False)
+        self.c_proj = RowParallelLinear(self.hidden_size, self.hidden_size,
+                                        bias=True, input_is_parallel=True,
+                                        perform_initialization=False)
+        self.attn = PagedAttention(self.num_heads, self.head_dim,
+                                   scale=self.scale)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        kv_cache: KVCache,
+        input_metadata: InputMetadata,
+        cache_event: Optional[torch.cuda.Event],
+    ) -> torch.Tensor:
+        qkv, _ = self.c_attn(hidden_states)
+        q, k, v = qkv.chunk(chunks=3, dim=-1)
+        key_cache, value_cache = kv_cache
+        attn_output = self.attn(
+            q, k, v, key_cache, value_cache, input_metadata, cache_event)
+        attn_output, _ = self.c_proj(attn_output)
+        return attn_output
+
+
+class GPTBigMLP(nn.Module):
+
+    def __init__(
+        self,
+        intermediate_size: int,
+        config: GPTBigCodeConfig,
+    ):
+        super().__init__()
+        hidden_size = config.hidden_size
+        self.c_fc = ColumnParallelLinear(hidden_size, intermediate_size,
+                                         bias=True, gather_output=False,
+                                         perform_initialization=False)
+        self.c_proj = RowParallelLinear(intermediate_size, hidden_size,
+                                        bias=True, input_is_parallel=True,
+                                        perform_initialization=False)
+        self.act = get_act_fn(config.activation_function)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states, _ = self.c_fc(hidden_states)
+        hidden_states = self.act(hidden_states)
+        hidden_states, _ = self.c_proj(hidden_states)
+        return hidden_states
+
+
+class GPTBigCodeBlock(nn.Module):
+
+    def __init__(self, config: GPTBigCodeConfig):
+        super().__init__()
+        hidden_size = config.hidden_size
+        inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
+
+        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+        self.attn = GPTBigCodeAttention(config)
+        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+        self.mlp = GPTBigMLP(inner_dim, config)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        kv_cache: KVCache,
+        input_metadata: InputMetadata,
+        cache_event: Optional[torch.cuda.Event],
+    ) -> torch.Tensor:
+        residual = hidden_states
+        hidden_states = self.ln_1(hidden_states)
+        attn_output = self.attn(
+            hidden_states=hidden_states,
+            kv_cache=kv_cache,
+            input_metadata=input_metadata,
+            cache_event=cache_event,
+        )
+        # residual connection
+        hidden_states = attn_output + residual
+
+        residual = hidden_states
+        hidden_states = self.ln_2(hidden_states)
+        feed_forward_hidden_states = self.mlp(hidden_states)
+        # residual connection
+        hidden_states = residual + feed_forward_hidden_states
+        return hidden_states
+
+
+class GPTBigCodeModel(nn.Module):
+
+    def __init__(self, config: GPTBigCodeConfig):
+        super().__init__()
+        self.config = config
+        assert not config.add_cross_attention
+
+        self.embed_dim = config.hidden_size
+
+        # Optimization: Pad the vocab size to the next multiple of 64.
+        # This improves performance since GPUs are faster if the dimension
+        # is divisible by 64. In addition, it allows us to shard the embedding
+        # layer across 2, 4, 8, or more GPUs.
+        vocab_size = ((config.vocab_size + 63) // 64) * 64
+        self.wte = VocabParallelEmbedding(vocab_size, self.embed_dim)
+        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
+        self.h = nn.ModuleList(
+            [GPTBigCodeBlock(config) for _ in range(config.num_hidden_layers)])
+        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+        cache_events: Optional[List[torch.cuda.Event]],
+    ) -> torch.Tensor:
+        inputs_embeds = self.wte(input_ids)
+        position_embeds = self.wpe(position_ids)
+        hidden_states = inputs_embeds + position_embeds
+
+        for i in range(len(self.h)):
+            if cache_events is None:
+                cache_event = None
+            else:
+                cache_event = cache_events[i]
+            layer = self.h[i]
+            hidden_states = layer(
+                hidden_states, kv_caches[i], input_metadata, cache_event)
+
+        hidden_states = self.ln_f(hidden_states)
+        return hidden_states
+
+
+class GPTBigCodeForCausalLM(nn.Module):
+
+    def __init__(self, config: GPTBigCodeConfig):
+        super().__init__()
+        self.config = config
+        self.transformer = GPTBigCodeModel(config)
+        # TODO(zhuohan): create a new weight after implementing pipeline
+        # parallelism
+        self.lm_head_weight = self.transformer.wte.weight
+        self.sampler = Sampler(config.vocab_size)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+        cache_events: Optional[List[torch.cuda.Event]],
+    ) -> Dict[int, SequenceOutputs]:
+        hidden_states = self.transformer(
+            input_ids, positions, kv_caches, input_metadata, cache_events)
+        next_tokens = self.sampler(
+            self.lm_head_weight, hidden_states, input_metadata)
+        return next_tokens
+
+    _column_parallel_weights = ["wte.weight", "c_fc.weight", "c_fc.bias"]
+    _row_parallel_weights = ["c_proj.weight"]
["c_proj.weight"] + + def load_weights(self, model_name_or_path: str, + cache_dir: Optional[str] = None, + use_np_cache: bool = False): + tensor_model_parallel_world_size = get_tensor_model_parallel_world_size() + tensor_model_parallel_rank = get_tensor_model_parallel_rank() + state_dict = self.state_dict() + + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, use_np_cache): + if "lm_head.weight" in name: + # GPT-2 ties the weights of the embedding layer and the final + # linear layer. + continue + if ".attn.bias" in name: + # Skip attention mask. + # NOTE: "c_attn.bias" should not be skipped. + continue + + param = state_dict[name] + + def _expand_mqa_mha(qkv_array, n_head, head_dim): + """manipulates along axis=0 from MQA to MHA + inputs: qkv_array.shape=((n_heads + 2) * head_dim, hidden_dim) + with n_heads for q, then 1 for k, 1 for 1 v, times head dim + return: qkv_array.shape=(3 * n_heads * head_dim, hidden_dim) + + TODO: this function is no longer needed once vllm supports MQA. + """ + qkv_array = qkv_array.numpy() + + dims_q = n_head * head_dim + q, k, v = np.split(qkv_array, (dims_q, dims_q + head_dim), axis=0) + # q is fine, but k & v have not replicated shape along the first axis + # as long as MQA is not nativly supported, increase memory and replicated + # (head_dim, hidden_dim) to (n_heads * head_dim, hidden_dim) + if k.ndim == 2 and v.ndim == 2: + replication = (n_head, 1) # weights + else: + replication = n_head # biases + # replicate n_head times for q, v + k, v = np.tile(k, replication), np.tile(v, replication) + # concat q, k, v along the first axis (n_heads * head_dim, hidden_dim) + # to (3 * n_heads * head_dim, hidden_dim) + qkv_array = np.concatenate((q, k, v), axis=0) + return torch.from_numpy(qkv_array) + + # For the fused QKV linear layer, manually shard the weights. + if "c_attn" in name: + # GPT-2's fused QKV has the shape of [3 * num_heads * head_size, hidden_size]. + # When tensor parallelism is used, we shard the weights along the head dimension. + total_num_heads = self.config.num_attention_heads + hidden_size = self.config.hidden_size + head_size = hidden_size // total_num_heads + num_heads = total_num_heads // tensor_model_parallel_world_size + head_start = tensor_model_parallel_rank * num_heads + head_end = (tensor_model_parallel_rank + 1) * num_heads + + if name.endswith(".weight"): + loaded_weight = _expand_mqa_mha(loaded_weight, n_head=total_num_heads, head_dim=head_size) + loaded_weight = loaded_weight.view(3, total_num_heads, head_size, hidden_size) + loaded_weight = loaded_weight[:, head_start:head_end, :, :] + loaded_weight = loaded_weight.reshape(-1, hidden_size) + elif name.endswith(".bias"): + loaded_weight = _expand_mqa_mha(loaded_weight, n_head=total_num_heads, head_dim=head_size) + loaded_weight = loaded_weight.view(3, total_num_heads, head_size) + loaded_weight = loaded_weight[:, head_start:head_end, :] + loaded_weight = loaded_weight.reshape(-1) + else: + raise ValueError(f"Unexpected parameter name {name}") + load_tensor_parallel_weights(param, loaded_weight, name, + self._column_parallel_weights, + self._row_parallel_weights, + tensor_model_parallel_rank)