fix baichuan for different position embedding for 7b and 13b models (#643)
parent d4c7755ca8
commit 64f23c2900
vllm/model_executor/model_loader.py
@@ -11,7 +11,8 @@ from vllm.model_executor.weight_utils import initialize_dummy_weights
 
 # TODO(woosuk): Lazy-load the model classes.
 _MODEL_REGISTRY = {
-    "BaiChuanForCausalLM": BaiChuanForCausalLM,
+    "BaiChuanForCausalLM": BaiChuanForCausalLM,  # baichuan-7b
+    "BaichuanForCausalLM": BaichuanForCausalLM,  # baichuan-13b
     "BloomForCausalLM": BloomForCausalLM,
     "GPT2LMHeadModel": GPT2LMHeadModel,
     "GPTBigCodeForCausalLM": GPTBigCodeForCausalLM,
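For context, vLLM selects the model class by matching the `architectures` field of the Hugging Face config against this registry, so the lowercase-c "BaichuanForCausalLM" string used by the 13B checkpoints now resolves to the ALiBi variant while the 7B string keeps the RoPE variant. Below is a minimal, self-contained sketch of that lookup; `pick_model_class` and the toy stand-in classes are illustrative, not the repository's actual code.

from transformers import AutoConfig

# Toy stand-ins for the two classes registered in the diff above; the real
# implementations live in vllm.model_executor.models.baichuan.
class BaiChuanForCausalLM: ...    # baichuan-7b (RoPE)
class BaichuanForCausalLM: ...    # baichuan-13b (ALiBi)

_MODEL_REGISTRY = {
    "BaiChuanForCausalLM": BaiChuanForCausalLM,
    "BaichuanForCausalLM": BaichuanForCausalLM,
}

def pick_model_class(model_name_or_path: str):
    """Resolve a registered class from the HF config's `architectures` field."""
    config = AutoConfig.from_pretrained(model_name_or_path,
                                        trust_remote_code=True)
    for arch in config.architectures or []:
        if arch in _MODEL_REGISTRY:
            return _MODEL_REGISTRY[arch]
    raise ValueError(f"Unsupported architectures: {config.architectures}")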
vllm/model_executor/models/__init__.py
@@ -1,4 +1,4 @@
-from vllm.model_executor.models.baichuan import BaiChuanForCausalLM
+from vllm.model_executor.models.baichuan import BaiChuanForCausalLM, BaichuanForCausalLM
 from vllm.model_executor.models.bloom import BloomForCausalLM
 from vllm.model_executor.models.gpt2 import GPT2LMHeadModel
 from vllm.model_executor.models.gpt_bigcode import GPTBigCodeForCausalLM
@@ -10,6 +10,7 @@ from vllm.model_executor.models.opt import OPTForCausalLM
 
 __all__ = [
     "BaiChuanForCausalLM",
+    "BaichuanForCausalLM",
     "BloomForCausalLM",
     "GPT2LMHeadModel",
     "GPTBigCodeForCausalLM",
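With the new export in place, downstream code can import either wrapper straight from the models package. A trivial smoke test, purely illustrative:

from vllm.model_executor.models import BaiChuanForCausalLM, BaichuanForCausalLM

# Two distinct classes; the only difference is the position embedding they request.
assert BaiChuanForCausalLM is not BaichuanForCausalLM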
vllm/model_executor/models/baichuan.py
@@ -22,6 +22,7 @@
 The input of the model is flattened to a 1D tensor of tokens. The model uses
 InputMetadata to extract the original 2D shape of the input.
 """
+import math
 from typing import Dict, List, Optional, Tuple
 
 import torch
@@ -31,7 +32,7 @@ from vllm.sequence import SequenceOutputs
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
+from vllm.model_executor.layers.attention import PagedAttentionWithRoPE, PagedAttentionWithALiBi
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
                                               load_tensor_parallel_weights)
@@ -44,6 +45,31 @@ from vllm.transformers_utils.configs.baichuan import BaiChuanConfig
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 
 
+def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
+    closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
+    base = torch.tensor(
+        2**(-(2**-(math.log2(closest_power_of_2) - 3))),
+        dtype=torch.float32,
+    )
+    powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
+    slopes = torch.pow(base, powers)
+
+    if closest_power_of_2 != total_num_heads:
+        extra_base = torch.tensor(
+            2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
+            dtype=torch.float32,
+        )
+        num_remaining_heads = min(closest_power_of_2,
+                                  total_num_heads - closest_power_of_2)
+        extra_powers = torch.arange(start=1,
+                                    end=1 + 2 * num_remaining_heads,
+                                    step=2,
+                                    dtype=torch.int32)
+        slopes = torch.cat(
+            [slopes, torch.pow(extra_base, extra_powers)], dim=0)
+    return slopes
+
+
 class BaiChuanMLP(nn.Module):
 
     def __init__(
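The slope schedule above matches the ALiBi recipe: for a head count n that is a power of two, head i (1-indexed) gets slope 2^(-8i/n); non-power-of-two counts are padded with interleaved slopes from the next power of two. A quick standalone check of the closed form (the n = 8 case is just an example, not tied to a specific checkpoint):

import math
import torch

# Closed form for a power-of-two head count n: slope_i = 2 ** (-8 * i / n), i = 1..n.
n = 8
base = 2 ** (-(2 ** -(math.log2(n) - 3)))   # == 2 ** (-8 / n) == 0.5 for n == 8
slopes = torch.pow(torch.tensor(base), torch.arange(1, n + 1))
print(slopes)
# tensor([0.5000, 0.2500, 0.1250, 0.0625, 0.0312, 0.0156, 0.0078, 0.0039])
assert torch.allclose(slopes, 2.0 ** (-8.0 * torch.arange(1, n + 1) / n))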
@@ -82,6 +108,7 @@ class BaiChuanAttention(nn.Module):
         self,
         hidden_size: int,
         num_heads: int,
+        position_embedding: str,
     ):
         super().__init__()
         self.hidden_size = hidden_size
@@ -92,7 +119,7 @@ class BaiChuanAttention(nn.Module):
         self.num_heads = (self.total_num_heads //
                           tensor_model_parallel_world_size)
         self.head_dim = hidden_size // self.total_num_heads
-        self.scaling = self.head_dim**-0.5
+        self.postion_embedding = position_embedding
 
         # pylint: disable=invalid-name
         self.W_pack = ColumnParallelLinear(
@@ -109,7 +136,19 @@ class BaiChuanAttention(nn.Module):
             input_is_parallel=True,
             perform_initialization=False,
         )
-        self.attn = PagedAttentionWithRoPE(self.num_heads,
-                                           self.head_dim,
-                                           self.scaling,
+        # Create the alibi slopes and slice them.
+        if self.postion_embedding == "ALIBI":
+            tp_rank = get_tensor_model_parallel_rank()
+            head_start = tp_rank * self.num_heads
+            head_end = (tp_rank + 1) * self.num_heads
+            alibi_slopes = _get_alibi_slopes(self.total_num_heads)
+            alibi_slopes = alibi_slopes[head_start:head_end].tolist()
+
+            scaling = self.head_dim**-0.5
+            self.attn = PagedAttentionWithALiBi(self.num_heads, self.head_dim,
+                                                scaling, alibi_slopes)
+        else:
+            self.scaling = self.head_dim**-0.5
+            self.attn = PagedAttentionWithRoPE(self.num_heads,
+                                               self.head_dim,
+                                               self.scaling,
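In the ALiBi branch, each tensor-parallel rank keeps only the slopes of the heads it owns (head_start:head_end). A small illustration of that slicing with assumed numbers (32 heads across 4 ranks; neither value is taken from a specific Baichuan config):

import torch

total_num_heads = 32                            # assumed head count, illustration only
tp_world_size = 4                               # assumed tensor-parallel degree
num_heads = total_num_heads // tp_world_size    # 8 heads per rank

slopes = torch.rand(total_num_heads)            # stand-in for _get_alibi_slopes(...)
for tp_rank in range(tp_world_size):
    head_start = tp_rank * num_heads
    head_end = (tp_rank + 1) * num_heads
    shard = slopes[head_start:head_end]         # slopes handed to this rank's attention
    assert shard.numel() == num_heads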
@@ -126,20 +165,26 @@ class BaiChuanAttention(nn.Module):
         qkv, _ = self.W_pack(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
         k_cache, v_cache = kv_cache
-        attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
-                                input_metadata, cache_event)
+        if self.postion_embedding == "ALIBI":
+            attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
+                                    cache_event)
+        else:
+            attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
+                                    input_metadata, cache_event)
+
         output, _ = self.o_proj(attn_output)
         return output
 
 
 class BaiChuanDecoderLayer(nn.Module):
 
-    def __init__(self, config: BaiChuanConfig):
+    def __init__(self, config: BaiChuanConfig, position_embedding: str):
         super().__init__()
         self.hidden_size = config.hidden_size
         self.self_attn = BaiChuanAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
+            position_embedding=position_embedding,
         )
         self.mlp = BaiChuanMLP(
             hidden_size=self.hidden_size,
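Note that the ALiBi call drops the positions argument: the bias depends only on relative distance between tokens, so absolute positions are not needed there, while the RoPE path still consumes them. A toy sketch of the same dispatch with stand-in callables (not the vLLM attention classes):

from typing import Callable

def make_attention(position_embedding: str) -> Callable[..., str]:
    # Stand-ins: RoPE needs absolute positions, ALiBi does not.
    if position_embedding == "ALIBI":
        return lambda q, k, v: "alibi attention over q, k, v"
    return lambda positions, q, k, v: "rope attention over q, k, v at positions"

attn = make_attention("ALIBI")
print(attn("q", "k", "v"))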
@@ -181,7 +226,7 @@ class BaiChuanDecoderLayer(nn.Module):
 
 class BaiChuanModel(nn.Module):
 
-    def __init__(self, config: BaiChuanConfig):
+    def __init__(self, config: BaiChuanConfig, position_embedding: str):
         super().__init__()
         self.config = config
         self.padding_idx = config.pad_token_id
@@ -192,7 +237,7 @@ class BaiChuanModel(nn.Module):
                                                    config.hidden_size,
                                                    perform_initialization=False)
         self.layers = nn.ModuleList([
-            BaiChuanDecoderLayer(config)
+            BaiChuanDecoderLayer(config, position_embedding)
             for _ in range(config.num_hidden_layers)
         ])
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -223,12 +268,12 @@ class BaiChuanModel(nn.Module):
         return hidden_states
 
 
-class BaiChuanForCausalLM(nn.Module):
+class BaiChuanBaseForCausalLM(nn.Module):
 
-    def __init__(self, config):
+    def __init__(self, config, position_embedding: str):
         super().__init__()
         self.config = config
-        self.model = BaiChuanModel(config)
+        self.model = BaiChuanModel(config, position_embedding)
         self.lm_head = ColumnParallelLinear(config.hidden_size,
                                             config.vocab_size,
                                             bias=False,
@@ -318,3 +363,15 @@ class BaiChuanForCausalLM(nn.Module):
                 self._row_parallel_weights,
                 tp_rank,
             )
+
+
+class BaichuanForCausalLM(BaiChuanBaseForCausalLM):  # baichuan 13b
+
+    def __init__(self, config):
+        super().__init__(config, "ALIBI")
+
+
+class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):  # baichuan 7b
+
+    def __init__(self, config):
+        super().__init__(config, "ROPE")
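Taken together, the two thin wrappers pin the position-embedding mode at construction time, so the registry entry matched from the checkpoint's architectures string is all that distinguishes the RoPE (7B) path from the ALiBi (13B) path. An end-to-end sketch follows; the Hugging Face repo ID and the trust_remote_code flag are assumptions about the checkpoint and the installed vLLM version, not something this diff establishes:

from vllm import LLM, SamplingParams

# Assumed repo ID; swap in whichever Baichuan checkpoint you actually use.
# A 7B config advertising "BaiChuanForCausalLM" takes the RoPE path,
# a 13B config advertising "BaichuanForCausalLM" takes the ALiBi path.
llm = LLM(model="baichuan-inc/Baichuan-13B-Base", trust_remote_code=True)

outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=16))
print(outputs[0].outputs[0].text)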