mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-14 02:24:58 +08:00
[WIP][Core][Refactor] move vllm/model_executor/parallel_utils into vllm/distributed and vllm/device_communicators (#3950)
416 lines
16 KiB
Python
416 lines
16 KiB
Python
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
|
"""Inference-only BaiChuan model compatible with HuggingFace weights."""
|
|
import math
|
|
from typing import List, Optional, Tuple
|
|
|
|
import torch
|
|
from torch import nn
|
|
from transformers import PretrainedConfig
|
|
|
|
from vllm.attention import Attention, AttentionMetadata
|
|
from vllm.config import LoRAConfig
|
|
from vllm.distributed import (get_tensor_model_parallel_rank,
|
|
get_tensor_model_parallel_world_size)
|
|
from vllm.model_executor.layers.activation import SiluAndMul
|
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
|
from vllm.model_executor.layers.linear import (LinearMethodBase,
|
|
MergedColumnParallelLinear,
|
|
QKVParallelLinear,
|
|
RowParallelLinear)
|
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
|
from vllm.model_executor.layers.sampler import Sampler
|
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
|
ParallelLMHead, VocabParallelEmbedding)
|
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
|
from vllm.model_executor.weight_utils import (default_weight_loader,
|
|
hf_model_weights_iterator)
|
|
from vllm.sequence import SamplerOutput
|
|
|
|
|
|
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
|
|
closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
|
|
base = torch.tensor(
|
|
2**(-(2**-(math.log2(closest_power_of_2) - 3))),
|
|
dtype=torch.float32,
|
|
)
|
|
powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
|
|
slopes = torch.pow(base, powers)
|
|
|
|
if closest_power_of_2 != total_num_heads:
|
|
extra_base = torch.tensor(
|
|
2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
|
|
dtype=torch.float32,
|
|
)
|
|
num_remaining_heads = min(closest_power_of_2,
|
|
total_num_heads - closest_power_of_2)
|
|
extra_powers = torch.arange(start=1,
|
|
end=1 + 2 * num_remaining_heads,
|
|
step=2,
|
|
dtype=torch.int32)
|
|
slopes = torch.cat(
|
|
[slopes, torch.pow(extra_base, extra_powers)], dim=0)
|
|
return slopes
|
|
|
|
|
|
class BaiChuanMLP(nn.Module):
    """Gated feed-forward network used inside each BaiChuan decoder layer.

    A fused column-parallel projection produces the gate and up halves,
    ``SiluAndMul`` combines them, and a row-parallel projection maps the
    result back to ``hidden_size``. Neither projection has a bias.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        # Gate and up projections are fused into one matmul with two
        # equally sized output shards.
        merged_output_sizes = [intermediate_size, intermediate_size]
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            merged_output_sizes,
            bias=False,
            linear_method=linear_method,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            linear_method=linear_method,
        )
        # Only the fused SiLU-and-multiply activation is wired up here.
        if hidden_act == "silu":
            self.act_fn = SiluAndMul()
        else:
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")

    def forward(self, x):
        """Apply gate/up projection, SiLU-and-multiply, then down projection."""
        # The parallel linear layers return (output, bias); bias is unused.
        gate_and_up, _bias = self.gate_up_proj(x)
        activated = self.act_fn(gate_and_up)
        projected, _bias = self.down_proj(activated)
        return projected


class BaiChuanAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    Queries, keys and values come from a single fused projection
    (``W_pack``) and are split into equal thirds. Position information is
    injected either with rotary embeddings (any ``position_embedding`` other
    than ``"ALIBI"``) or with per-head ALiBi slopes handed to the attention
    backend (``position_embedding == "ALIBI"``). Heads are split evenly
    across tensor-parallel ranks.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        position_embedding: str,
        rope_theta: float = 10000,
        max_position_embeddings: int = 8192,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        """Build the fused QKV/output projections and the attention op.

        Args:
            hidden_size: Model hidden dimension.
            num_heads: Total number of attention heads across all TP ranks.
            position_embedding: ``"ALIBI"`` for ALiBi; otherwise RoPE is used.
            rope_theta: RoPE base (only used when RoPE is selected).
            max_position_embeddings: Maximum position for RoPE.
            linear_method: Optional quantization/linear method passed to
                the projection layers.

        Raises:
            ValueError: If ``num_heads`` is not divisible by the tensor
                parallel world size.
        """
        super().__init__()
        self.hidden_size = hidden_size
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        self.total_num_heads = num_heads
        # Heads must split evenly across TP ranks. Raise instead of assert
        # so the check is not stripped under `python -O`.
        if self.total_num_heads % tensor_model_parallel_world_size != 0:
            raise ValueError(
                f"Number of attention heads ({self.total_num_heads}) must be "
                f"divisible by the tensor parallel world size "
                f"({tensor_model_parallel_world_size}).")
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)
        self.head_dim = hidden_size // self.total_num_heads
        # NOTE: the attribute name is misspelled ("postion") but is kept as-is
        # so any existing references to it keep working.
        self.postion_embedding = position_embedding
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        # The softmax scale is the same for both position-embedding schemes.
        # Set it once so `self.scaling` also exists in the ALiBi case
        # (previously it was only set on the RoPE path).
        self.scaling = self.head_dim**-0.5

        # pylint: disable=invalid-name
        self.W_pack = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_heads,
            bias=False,
            linear_method=linear_method,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            linear_method=linear_method,
        )
        if self.postion_embedding == "ALIBI":
            # Compute slopes for all heads, then keep only this rank's
            # contiguous slice of heads.
            tp_rank = get_tensor_model_parallel_rank()
            head_start = tp_rank * self.num_heads
            head_end = (tp_rank + 1) * self.num_heads
            alibi_slopes = _get_alibi_slopes(self.total_num_heads)
            alibi_slopes = alibi_slopes[head_start:head_end].tolist()
            self.attn = Attention(self.num_heads,
                                  self.head_dim,
                                  self.scaling,
                                  alibi_slopes=alibi_slopes)
        else:
            self.rotary_emb = get_rope(
                self.head_dim,
                rotary_dim=self.head_dim,
                max_position=self.max_position_embeddings,
                base=self.rope_theta,
            )
            self.attn = Attention(self.num_heads, self.head_dim, self.scaling)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Project to q/k/v, apply RoPE if enabled, attend, and project out."""
        qkv, _ = self.W_pack(hidden_states)
        # Split the fused projection into equal q/k/v thirds along the last dim.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        if self.postion_embedding != "ALIBI":
            # ALiBi is applied inside the attention op; RoPE is applied here.
            q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class BaiChuanDecoderLayer(nn.Module):
    """One transformer block: normalized self-attention, then normalized MLP.

    The residual stream is threaded through the RMSNorm calls. When given
    ``(hidden_states, residual)``, each norm returns the normalized
    activations together with the residual to carry forward.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 position_embedding: str,
                 linear_method: Optional[LinearMethodBase] = None):
        super().__init__()
        self.hidden_size = config.hidden_size
        # rope_theta / max_position_embeddings are optional in the config;
        # fall back to 10000 / 8192 when absent.
        self.self_attn = BaiChuanAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            position_embedding=position_embedding,
            rope_theta=getattr(config, "rope_theta", 10000),
            max_position_embeddings=getattr(config, "max_position_embeddings",
                                            8192),
            linear_method=linear_method,
        )
        self.mlp = BaiChuanMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            linear_method=linear_method,
        )
        rms_eps = config.rms_norm_eps
        self.input_layernorm = RMSNorm(config.hidden_size, eps=rms_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=rms_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the block and return ``(mlp_output, residual)``.

        ``residual`` is ``None`` for the first layer, in which case the raw
        input becomes the residual.
        """
        # Self Attention
        if residual is not None:
            normed, residual = self.input_layernorm(hidden_states, residual)
        else:
            residual = hidden_states
            normed = self.input_layernorm(hidden_states)
        attn_out = self.self_attn(
            positions=positions,
            hidden_states=normed,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        normed, residual = self.post_attention_layernorm(attn_out, residual)
        return self.mlp(normed), residual


class BaiChuanModel(nn.Module):
    """Token embedding, a stack of decoder layers, and a final RMSNorm."""

    def __init__(self,
                 config: PretrainedConfig,
                 position_embedding: str,
                 linear_method: Optional[LinearMethodBase] = None):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                   config.hidden_size)
        decoder_layers = [
            BaiChuanDecoderLayer(config, position_embedding, linear_method)
            for _ in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(decoder_layers)
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed tokens, run every decoder layer, and apply the final norm.

        ``kv_caches`` is indexed by layer position, so it needs one entry
        per decoder layer.
        """
        hidden_states = self.embed_tokens(input_ids)
        # The first layer receives no residual; later layers receive the
        # residual returned by the previous layer.
        residual = None
        for layer_idx, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[layer_idx],
                attn_metadata,
                residual,
            )
        final_states, _ = self.norm(hidden_states, residual)
        return final_states


class BaiChuanBaseForCausalLM(nn.Module):
    """Shared causal-LM wrapper for the BaiChuan / Baichuan2 variants.

    Subclasses choose the position-embedding scheme (``"ROPE"`` or
    ``"ALIBI"``) and pass it to ``__init__``.
    """

    # Maps each fused module name to the checkpoint module names packed into it.
    packed_modules_mapping = {
        "W_pack": ["W_pack"],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "W_pack",
        "o_proj",
        "gate_up_proj",
        "down_proj",
    ]
    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(
        self,
        config,
        position_embedding: str,
        linear_method: Optional[LinearMethodBase] = None,
        lora_config: Optional[LoRAConfig] = None,
    ):
        """Build the decoder stack, LM head, logits processor, and sampler.

        ``lora_config`` is accepted for interface compatibility but is not
        used in this constructor.
        """
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = BaiChuanModel(config, position_embedding, linear_method)
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Return the final hidden states of the decoder stack."""
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project hidden states to vocabulary logits using the LM head weight."""
        logits = self.logits_processor(self.lm_head.weight, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits`` with the model's sampler."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        """Load checkpoint weights into this model's parameters.

        Checkpoint tensors named ``gate_proj`` / ``up_proj`` are loaded as
        shards 0 / 1 of the fused ``gate_up_proj`` parameter. Every other
        tensor is loaded into the parameter of the same name, using that
        parameter's ``weight_loader`` if it has one and
        ``default_weight_loader`` otherwise. ``rotary_emb.inv_freq`` tensors
        and bias tensors with no matching parameter (extra GPTQ biases) are
        skipped.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            if "rotary_emb.inv_freq" in name:
                continue
            if name == "lm_head.weight":
                # Unlike Baichuan, Baichuan2 normalizes the head weights.
                # Refer to:
                # https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/84603cde5ebffb6084e476cfaeceaf0b8b91fe54/modeling_baichuan.py#L508
                # Distinguish between Baichuan and Baichuan2 by checking the
                # vocab size. This is suggested by
                # https://github.com/vllm-project/vllm/pull/1022#discussion_r1325652704
                is_baichuan2 = self.config.vocab_size == 125696
                if is_baichuan2:
                    loaded_weight = torch.nn.functional.normalize(
                        loaded_weight)

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models. Use `break`, not
                # `continue`: `name` has already been rewritten, and
                # "up_proj" is a substring of "gate_up_proj", so continuing
                # would re-match and mangle the name before it is skipped.
                if name.endswith(".bias") and name not in params_dict:
                    break
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a stacked (fused) parameter: load it one-to-one.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)


class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
    """Baichuan 13B and Baichuan2 7B/13B."""

    def __init__(
        self,
        config,
        linear_method: Optional[LinearMethodBase] = None,
        lora_config: Optional[LoRAConfig] = None,
    ):
        # hidden_size 4096 identifies Baichuan2 7B, which uses rotary
        # embeddings; the 13B models (Baichuan and Baichuan2) use ALiBi.
        position_embedding = "ROPE" if config.hidden_size == 4096 else "ALIBI"
        super().__init__(config, position_embedding, linear_method,
                         lora_config)


class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
    """Baichuan 7B."""

    def __init__(
        self,
        config,
        linear_method: Optional[LinearMethodBase] = None,
        lora_config: Optional[LoRAConfig] = None,
    ):
        # Baichuan 7B always uses rotary position embeddings.
        super().__init__(
            config=config,
            position_embedding="ROPE",
            linear_method=linear_method,
            lora_config=lora_config,
        )