diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py
index a3fd24c911b1..c93d2db812e4 100644
--- a/vllm/model_executor/model_loader.py
+++ b/vllm/model_executor/model_loader.py
@@ -11,7 +11,8 @@ from vllm.model_executor.weight_utils import initialize_dummy_weights
 
 # TODO(woosuk): Lazy-load the model classes.
 _MODEL_REGISTRY = {
-    "BaiChuanForCausalLM": BaiChuanForCausalLM,
+    "BaiChuanForCausalLM": BaiChuanForCausalLM,  # baichuan-7b
+    "BaichuanForCausalLM": BaichuanForCausalLM,  # baichuan-13b
     "BloomForCausalLM": BloomForCausalLM,
     "GPT2LMHeadModel": GPT2LMHeadModel,
     "GPTBigCodeForCausalLM": GPTBigCodeForCausalLM,
diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py
index c3e3e5723e53..d3259a05104f 100644
--- a/vllm/model_executor/models/__init__.py
+++ b/vllm/model_executor/models/__init__.py
@@ -1,4 +1,4 @@
-from vllm.model_executor.models.baichuan import BaiChuanForCausalLM
+from vllm.model_executor.models.baichuan import BaiChuanForCausalLM, BaichuanForCausalLM
 from vllm.model_executor.models.bloom import BloomForCausalLM
 from vllm.model_executor.models.gpt2 import GPT2LMHeadModel
 from vllm.model_executor.models.gpt_bigcode import GPTBigCodeForCausalLM
@@ -10,6 +10,7 @@ from vllm.model_executor.models.opt import OPTForCausalLM
 
 __all__ = [
     "BaiChuanForCausalLM",
+    "BaichuanForCausalLM",
     "BloomForCausalLM",
     "GPT2LMHeadModel",
     "GPTBigCodeForCausalLM",
diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py
index ff5afe51d80e..3ec9ddbacadc 100644
--- a/vllm/model_executor/models/baichuan.py
+++ b/vllm/model_executor/models/baichuan.py
@@ -22,6 +22,7 @@
 The input of the model is flattened to a 1D tensor of tokens. The model uses
 InputMetadata to extract the original 2D shape of the input.
 """
+import math
 from typing import Dict, List, Optional, Tuple
 
 import torch
@@ -31,7 +32,7 @@ from vllm.sequence import SequenceOutputs
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
+from vllm.model_executor.layers.attention import PagedAttentionWithRoPE, PagedAttentionWithALiBi
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
                                               load_tensor_parallel_weights)
@@ -44,6 +45,31 @@ from vllm.transformers_utils.configs.baichuan import BaiChuanConfig
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 
 
+def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
+    closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
+    base = torch.tensor(
+        2**(-(2**-(math.log2(closest_power_of_2) - 3))),
+        dtype=torch.float32,
+    )
+    powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
+    slopes = torch.pow(base, powers)
+
+    if closest_power_of_2 != total_num_heads:
+        extra_base = torch.tensor(
+            2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
+            dtype=torch.float32,
+        )
+        num_remaining_heads = min(closest_power_of_2,
+                                  total_num_heads - closest_power_of_2)
+        extra_powers = torch.arange(start=1,
+                                    end=1 + 2 * num_remaining_heads,
+                                    step=2,
+                                    dtype=torch.int32)
+        slopes = torch.cat(
+            [slopes, torch.pow(extra_base, extra_powers)], dim=0)
+    return slopes
+
+
 class BaiChuanMLP(nn.Module):
 
     def __init__(
@@ -82,6 +108,7 @@ class BaiChuanAttention(nn.Module):
         self,
         hidden_size: int,
         num_heads: int,
+        position_embedding: str,
     ):
         super().__init__()
         self.hidden_size = hidden_size
@@ -92,7 +119,7 @@ class BaiChuanAttention(nn.Module):
         self.num_heads = (self.total_num_heads //
                           tensor_model_parallel_world_size)
         self.head_dim = hidden_size // self.total_num_heads
-        self.scaling = self.head_dim**-0.5
+        self.postion_embedding = position_embedding
 
         # pylint: disable=invalid-name
         self.W_pack = ColumnParallelLinear(
@@ -109,11 +136,23 @@ class BaiChuanAttention(nn.Module):
             input_is_parallel=True,
             perform_initialization=False,
         )
+        # Create the alibi slopes and slice them.
+        if self.postion_embedding == "ALIBI":
+            tp_rank = get_tensor_model_parallel_rank()
+            head_start = tp_rank * self.num_heads
+            head_end = (tp_rank + 1) * self.num_heads
+            alibi_slopes = _get_alibi_slopes(self.total_num_heads)
+            alibi_slopes = alibi_slopes[head_start:head_end].tolist()
 
-        self.attn = PagedAttentionWithRoPE(self.num_heads,
-                                           self.head_dim,
-                                           self.scaling,
-                                           rotary_dim=self.head_dim)
+            scaling = self.head_dim**-0.5
+            self.attn = PagedAttentionWithALiBi(self.num_heads, self.head_dim,
+                                                scaling, alibi_slopes)
+        else:
+            self.scaling = self.head_dim**-0.5
+            self.attn = PagedAttentionWithRoPE(self.num_heads,
+                                               self.head_dim,
+                                               self.scaling,
+                                               rotary_dim=self.head_dim)
 
     def forward(
         self,
@@ -126,20 +165,26 @@ class BaiChuanAttention(nn.Module):
         qkv, _ = self.W_pack(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
         k_cache, v_cache = kv_cache
-        attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
-                                input_metadata, cache_event)
+        if self.postion_embedding == "ALIBI":
+            attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
+                                    cache_event)
+        else:
+            attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
+                                    input_metadata, cache_event)
+
         output, _ = self.o_proj(attn_output)
         return output
 
 
 class BaiChuanDecoderLayer(nn.Module):
 
-    def __init__(self, config: BaiChuanConfig):
+    def __init__(self, config: BaiChuanConfig, position_embedding: str):
         super().__init__()
         self.hidden_size = config.hidden_size
         self.self_attn = BaiChuanAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
+            position_embedding=position_embedding,
         )
         self.mlp = BaiChuanMLP(
             hidden_size=self.hidden_size,
@@ -181,7 +226,7 @@ class BaiChuanDecoderLayer(nn.Module):
 
 class BaiChuanModel(nn.Module):
 
-    def __init__(self, config: BaiChuanConfig):
+    def __init__(self, config: BaiChuanConfig, position_embedding: str):
         super().__init__()
         self.config = config
         self.padding_idx = config.pad_token_id
@@ -192,7 +237,7 @@ class BaiChuanModel(nn.Module):
             config.hidden_size,
             perform_initialization=False)
         self.layers = nn.ModuleList([
-            BaiChuanDecoderLayer(config)
+            BaiChuanDecoderLayer(config, position_embedding)
             for _ in range(config.num_hidden_layers)
         ])
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -223,12 +268,12 @@ class BaiChuanModel(nn.Module):
         return hidden_states
 
 
-class BaiChuanForCausalLM(nn.Module):
+class BaiChuanBaseForCausalLM(nn.Module):
 
-    def __init__(self, config):
+    def __init__(self, config, position_embedding: str):
         super().__init__()
         self.config = config
-        self.model = BaiChuanModel(config)
+        self.model = BaiChuanModel(config, position_embedding)
         self.lm_head = ColumnParallelLinear(config.hidden_size,
                                             config.vocab_size,
                                             bias=False,
@@ -318,3 +363,15 @@ class BaiChuanForCausalLM(nn.Module):
                 self._row_parallel_weights,
                 tp_rank,
             )
+
+
+class BaichuanForCausalLM(BaiChuanBaseForCausalLM):  # baichuan 13b
+
+    def __init__(self, config):
+        super().__init__(config, "ALIBI")
+
+
+class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):  # baichuan 7b
+
+    def __init__(self, config):
+        super().__init__(config, "ROPE")
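
For reference, here is a minimal sketch (not part of the patch) of how the new _get_alibi_slopes schedule and the per-rank slicing in BaiChuanAttention.__init__ behave. It assumes a vLLM build that already contains this diff; the 40-head count matches Baichuan-13B, while tensor_parallel_size = 2 is only an illustrative choice.

# Hypothetical sanity check, assuming a vLLM build that includes this patch.
# It reproduces the head_start/head_end slicing from BaiChuanAttention.__init__
# for Baichuan-13B's 40 heads; tensor_parallel_size=2 is illustrative only.
from vllm.model_executor.models.baichuan import _get_alibi_slopes

total_num_heads = 40               # Baichuan-13B attention heads
tensor_parallel_size = 2           # assumed TP degree for this example
num_heads = total_num_heads // tensor_parallel_size

slopes = _get_alibi_slopes(total_num_heads)
assert slopes.shape == (total_num_heads,)   # one slope per head

for tp_rank in range(tensor_parallel_size):
    # Same slicing as the patch: each rank keeps the slopes of its own heads.
    head_start = tp_rank * num_heads
    head_end = (tp_rank + 1) * num_heads
    rank_slopes = slopes[head_start:head_end].tolist()
    print(f"rank {tp_rank}: {len(rank_slopes)} slopes, "
          f"first={rank_slopes[0]:.4f}, last={rank_slopes[-1]:.4f}")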
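And a hedged end-to-end sketch of how the new registry entry is exercised: model_loader.py looks up the "BaichuanForCausalLM" name from the checkpoint config's architectures list, which now resolves to the ALiBi-based 13B class. This assumes a vLLM build containing this patch; the checkpoint name, parallel degree, and sampling settings below are examples, not part of the diff.

# Illustrative usage, assuming a vLLM build that includes this patch.
from vllm import LLM, SamplingParams

llm = LLM(model="baichuan-inc/Baichuan-13B-Chat",  # 13B config -> ALiBi path
          trust_remote_code=True,
          tensor_parallel_size=2)
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=64)

outputs = llm.generate(["The capital of France is"], sampling_params)
print(outputs[0].outputs[0].text)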