Fix Baichuan: handle the different position embeddings of the 7B and 13B models (#643)

Song, 2023-08-02 13:22:51 +08:00 (committed by GitHub)
parent d4c7755ca8
commit 64f23c2900
3 changed files with 75 additions and 16 deletions

vllm/model_executor/model_loader.py

@@ -11,7 +11,8 @@ from vllm.model_executor.weight_utils import initialize_dummy_weights
 # TODO(woosuk): Lazy-load the model classes.
 _MODEL_REGISTRY = {
-    "BaiChuanForCausalLM": BaiChuanForCausalLM,
+    "BaiChuanForCausalLM": BaiChuanForCausalLM,  # baichuan-7b
+    "BaichuanForCausalLM": BaichuanForCausalLM,  # baichuan-13b
     "BloomForCausalLM": BloomForCausalLM,
     "GPT2LMHeadModel": GPT2LMHeadModel,
     "GPTBigCodeForCausalLM": GPTBigCodeForCausalLM,

vllm/model_executor/models/__init__.py

@@ -1,4 +1,4 @@
-from vllm.model_executor.models.baichuan import BaiChuanForCausalLM
+from vllm.model_executor.models.baichuan import BaiChuanForCausalLM, BaichuanForCausalLM
 from vllm.model_executor.models.bloom import BloomForCausalLM
 from vllm.model_executor.models.gpt2 import GPT2LMHeadModel
 from vllm.model_executor.models.gpt_bigcode import GPTBigCodeForCausalLM
@@ -10,6 +10,7 @@ from vllm.model_executor.models.opt import OPTForCausalLM
 __all__ = [
     "BaiChuanForCausalLM",
+    "BaichuanForCausalLM",
     "BloomForCausalLM",
     "GPT2LMHeadModel",
     "GPTBigCodeForCausalLM",

vllm/model_executor/models/baichuan.py

@@ -22,6 +22,7 @@
 The input of the model is flattened to a 1D tensor of tokens. The model uses
 InputMetadata to extract the original 2D shape of the input.
 """
+import math
 from typing import Dict, List, Optional, Tuple

 import torch
@@ -31,7 +32,7 @@ from vllm.sequence import SequenceOutputs
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
+from vllm.model_executor.layers.attention import PagedAttentionWithRoPE, PagedAttentionWithALiBi
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
                                               load_tensor_parallel_weights)
@@ -44,6 +45,31 @@ from vllm.transformers_utils.configs.baichuan import BaiChuanConfig
 KVCache = Tuple[torch.Tensor, torch.Tensor]


+def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
+    closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
+    base = torch.tensor(
+        2**(-(2**-(math.log2(closest_power_of_2) - 3))),
+        dtype=torch.float32,
+    )
+    powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
+    slopes = torch.pow(base, powers)
+    if closest_power_of_2 != total_num_heads:
+        extra_base = torch.tensor(
+            2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
+            dtype=torch.float32,
+        )
+        num_remaining_heads = min(closest_power_of_2,
+                                  total_num_heads - closest_power_of_2)
+        extra_powers = torch.arange(start=1,
+                                    end=1 + 2 * num_remaining_heads,
+                                    step=2,
+                                    dtype=torch.int32)
+        slopes = torch.cat(
+            [slopes, torch.pow(extra_base, extra_powers)], dim=0)
+    return slopes
+
+
 class BaiChuanMLP(nn.Module):

     def __init__(
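
A quick sanity check on _get_alibi_slopes (a sketch that assumes the helper above is importable from vllm.model_executor.models.baichuan): for a power-of-two head count the slopes are a single geometric sequence, and for baichuan-13b's 40 heads the last 8 slopes come from a second, interleaved sequence.

    import torch
    from vllm.model_executor.models.baichuan import _get_alibi_slopes

    # 8 heads: the closest power of two is 8 itself, ratio 1/2,
    # so the slopes are 2^-1, 2^-2, ..., 2^-8.
    slopes_8 = _get_alibi_slopes(8)
    assert torch.allclose(slopes_8, torch.tensor([2.0**-k for k in range(1, 9)]))

    # 40 heads (baichuan-13b): 32 slopes with ratio 2^-0.25, plus 8 extra
    # slopes taken at odd powers of 2^-0.125 for the heads beyond 32.
    slopes_40 = _get_alibi_slopes(40)
    assert slopes_40.shape == (40,)
    assert torch.isclose(slopes_40[0], torch.tensor(2.0**-0.25))
    assert torch.isclose(slopes_40[32], torch.tensor(2.0**-0.125))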
@@ -82,6 +108,7 @@ class BaiChuanAttention(nn.Module):
         self,
         hidden_size: int,
         num_heads: int,
+        position_embedding: str,
     ):
         super().__init__()
         self.hidden_size = hidden_size
@@ -92,7 +119,7 @@ class BaiChuanAttention(nn.Module):
         self.num_heads = (self.total_num_heads //
                           tensor_model_parallel_world_size)
         self.head_dim = hidden_size // self.total_num_heads
-        self.scaling = self.head_dim**-0.5
+        self.position_embedding = position_embedding

         # pylint: disable=invalid-name
         self.W_pack = ColumnParallelLinear(
@@ -109,7 +136,19 @@ class BaiChuanAttention(nn.Module):
             input_is_parallel=True,
             perform_initialization=False,
         )
-        self.attn = PagedAttentionWithRoPE(self.num_heads,
-                                           self.head_dim,
-                                           self.scaling,
+        # Create the alibi slopes and slice them.
+        if self.position_embedding == "ALIBI":
+            tp_rank = get_tensor_model_parallel_rank()
+            head_start = tp_rank * self.num_heads
+            head_end = (tp_rank + 1) * self.num_heads
+            alibi_slopes = _get_alibi_slopes(self.total_num_heads)
+            alibi_slopes = alibi_slopes[head_start:head_end].tolist()
+
+            scaling = self.head_dim**-0.5
+            self.attn = PagedAttentionWithALiBi(self.num_heads, self.head_dim,
+                                                scaling, alibi_slopes)
+        else:
+            self.scaling = self.head_dim**-0.5
+            self.attn = PagedAttentionWithRoPE(self.num_heads,
+                                               self.head_dim,
+                                               self.scaling,
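
The slicing above gives each tensor-parallel rank only the ALiBi slopes for the heads it owns. A standalone illustration of the index arithmetic (hypothetical numbers, no vLLM imports needed):

    # 40 total heads split across a tensor-parallel world of size 4:
    # each rank keeps 10 heads and the matching 10 slopes.
    total_num_heads = 40
    tp_world_size = 4
    num_heads = total_num_heads // tp_world_size   # 10 heads per rank

    for tp_rank in range(tp_world_size):
        head_start = tp_rank * num_heads
        head_end = (tp_rank + 1) * num_heads
        print(f"rank {tp_rank}: slopes[{head_start}:{head_end}]")
    # rank 0: slopes[0:10]
    # rank 1: slopes[10:20]
    # rank 2: slopes[20:30]
    # rank 3: slopes[30:40]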
@@ -126,20 +165,26 @@ class BaiChuanAttention(nn.Module):
         qkv, _ = self.W_pack(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
         k_cache, v_cache = kv_cache
-        attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
-                                input_metadata, cache_event)
+        if self.position_embedding == "ALIBI":
+            attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
+                                    cache_event)
+        else:
+            attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
+                                    input_metadata, cache_event)
         output, _ = self.o_proj(attn_output)
         return output


 class BaiChuanDecoderLayer(nn.Module):

-    def __init__(self, config: BaiChuanConfig):
+    def __init__(self, config: BaiChuanConfig, position_embedding: str):
         super().__init__()
         self.hidden_size = config.hidden_size
         self.self_attn = BaiChuanAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
+            position_embedding=position_embedding,
         )
         self.mlp = BaiChuanMLP(
             hidden_size=self.hidden_size,
@@ -181,7 +226,7 @@ class BaiChuanDecoderLayer(nn.Module):

 class BaiChuanModel(nn.Module):

-    def __init__(self, config: BaiChuanConfig):
+    def __init__(self, config: BaiChuanConfig, position_embedding: str):
         super().__init__()
         self.config = config
         self.padding_idx = config.pad_token_id
@@ -192,7 +237,7 @@
                                                    config.hidden_size,
                                                    perform_initialization=False)
         self.layers = nn.ModuleList([
-            BaiChuanDecoderLayer(config)
+            BaiChuanDecoderLayer(config, position_embedding)
             for _ in range(config.num_hidden_layers)
         ])
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -223,12 +268,12 @@ class BaiChuanModel(nn.Module):
         return hidden_states


-class BaiChuanForCausalLM(nn.Module):
+class BaiChuanBaseForCausalLM(nn.Module):

-    def __init__(self, config):
+    def __init__(self, config, position_embedding: str):
         super().__init__()
         self.config = config
-        self.model = BaiChuanModel(config)
+        self.model = BaiChuanModel(config, position_embedding)
         self.lm_head = ColumnParallelLinear(config.hidden_size,
                                             config.vocab_size,
                                             bias=False,
@@ -318,3 +363,15 @@
             self._row_parallel_weights,
             tp_rank,
         )
+
+
+class BaichuanForCausalLM(BaiChuanBaseForCausalLM):  # baichuan 13b
+
+    def __init__(self, config):
+        super().__init__(config, "ALIBI")
+
+
+class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):  # baichuan 7b
+
+    def __init__(self, config):
+        super().__init__(config, "ROPE")
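
With both classes registered, either checkpoint can be loaded without extra flags; the right class, and therefore the right position embedding, is picked from the architecture name the checkpoint declares. A minimal usage sketch (the Hugging Face repo names are assumptions; any Baichuan checkpoint declaring the matching architecture should behave the same way):

    from vllm import LLM, SamplingParams

    # baichuan-7b declares "BaiChuanForCausalLM" -> rotary position embeddings.
    llm_7b = LLM(model="baichuan-inc/Baichuan-7B", trust_remote_code=True)

    # baichuan-13b declares "BaichuanForCausalLM" -> ALiBi slopes, no rotary.
    llm_13b = LLM(model="baichuan-inc/Baichuan-13B-Base", trust_remote_code=True)

    outputs = llm_13b.generate(["The capital of France is"],
                               SamplingParams(temperature=0.0, max_tokens=16))
    print(outputs[0].outputs[0].text)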