ChatGLM Support (#1261)

GoHomeToMacDonal 2023-11-07 08:09:33 +08:00 committed by GitHub
parent e7f579eb97
commit 1a2bbc9301
7 changed files with 490 additions and 4 deletions

vllm/config.py

@@ -166,6 +166,10 @@ class ModelConfig:
        if getattr(self.hf_config, "num_key_value_heads", None) is not None:
            return (self.hf_config.num_key_value_heads //
                    parallel_config.tensor_parallel_size)
        # For ChatGLM-2:
        if getattr(self.hf_config, "multi_query_group_num", None) is not None:
            return (self.hf_config.multi_query_group_num //
                    parallel_config.tensor_parallel_size)
        total_num_attention_heads = self.hf_config.num_attention_heads
        return total_num_attention_heads // parallel_config.tensor_parallel_size
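For a concrete sense of what the new branch computes: ChatGLM2-6B ships multi_query_group_num = 2 and no num_key_value_heads, so with an illustrative tensor-parallel size of 2 each rank ends up with a single KV head. A minimal standalone sketch of the same arithmetic (the hf_config and tensor_parallel_size names here are local to the example, not vLLM code):

# Standalone sketch of the KV-head arithmetic above; values mirror ChatGLM2-6B.
from types import SimpleNamespace

hf_config = SimpleNamespace(num_attention_heads=32,   # ChatGLM2-6B
                            multi_query_group_num=2)  # 2 KV groups (multi-query attention)
tensor_parallel_size = 2                              # example TP degree

if getattr(hf_config, "num_key_value_heads", None) is not None:
    num_kv_heads = hf_config.num_key_value_heads // tensor_parallel_size
elif getattr(hf_config, "multi_query_group_num", None) is not None:
    num_kv_heads = hf_config.multi_query_group_num // tensor_parallel_size
else:
    num_kv_heads = hf_config.num_attention_heads // tensor_parallel_size

print(num_kv_heads)  # -> 1 KV head per tensor-parallel rank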

vllm/model_executor/model_loader.py

@@ -18,6 +18,7 @@ _MODEL_REGISTRY = {
    "BaiChuanForCausalLM": BaiChuanForCausalLM,  # baichuan-7b
    "BaichuanForCausalLM": BaichuanForCausalLM,  # baichuan-13b
    "BloomForCausalLM": BloomForCausalLM,
    "ChatGLMModel": ChatGLMForCausalLM,
    "FalconForCausalLM": FalconForCausalLM,
    "GPT2LMHeadModel": GPT2LMHeadModel,
    "GPTBigCodeForCausalLM": GPTBigCodeForCausalLM,

vllm/model_executor/models/__init__.py

@@ -13,6 +13,7 @@ from vllm.model_executor.models.mistral import MistralForCausalLM
from vllm.model_executor.models.mpt import MptForCausalLM
from vllm.model_executor.models.opt import OPTForCausalLM
from vllm.model_executor.models.qwen import QWenLMHeadModel
from vllm.model_executor.models.chatglm import ChatGLMForCausalLM
from vllm.model_executor.models.yi import YiForCausalLM

__all__ = [
@@ -20,6 +21,7 @@ __all__ = [
    "BaiChuanForCausalLM",
    "BaichuanForCausalLM",
    "BloomForCausalLM",
    "ChatGLMForCausalLM",
    "FalconForCausalLM",
    "GPT2LMHeadModel",
    "GPTBigCodeForCausalLM",

vllm/model_executor/models/chatglm.py

@@ -0,0 +1,408 @@
# coding=utf-8
# Adapted from
# https://github.com/THUDM/ChatGLM2-6B
"""Inference-only ChatGLM model compatible with THUDM weights.

The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
from typing import Dict, List, Optional, Tuple

import torch
from torch import nn
from torch.nn import LayerNorm

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (
    hf_model_weights_iterator,
    load_tensor_parallel_weights,
)
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
from vllm.model_executor.parallel_utils.layers import VocabParallelEmbedding
from vllm.model_executor.parallel_utils.layers import (
    ColumnParallelLinear,
    RowParallelLinear,
)
from vllm.sequence import SequenceOutputs
from vllm.transformers_utils.configs import ChatGLMConfig

KVCache = Tuple[torch.Tensor, torch.Tensor]


class GLMAttention(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.multi_query_attention = config.multi_query_attention
        self.total_num_kv_heads = (config.multi_query_group_num
                                   if config.multi_query_attention else
                                   config.num_attention_heads)
        assert self.total_num_kv_heads % tp_size == 0
        self.num_kv_heads = self.total_num_kv_heads // tp_size
        self.head_dim = config.hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        self.query_key_value = ColumnParallelLinear(
            config.hidden_size,
            (self.total_num_heads + 2 * self.total_num_kv_heads) *
            self.head_dim,
            bias=config.add_qkv_bias,
            gather_output=False,
        )
        self.dense = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            config.hidden_size,
            bias=config.add_bias_linear,
            input_is_parallel=True,
        )
        self.attn = PagedAttentionWithRoPE(
            self.num_heads,
            self.head_dim,
            self.scaling,
            rotary_dim=self.head_dim // 2,
            num_kv_heads=self.num_kv_heads,
            is_neox_style=False,
            # is_glm_style=True
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        qkv, _ = self.query_key_value(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        key_cache, value_cache = kv_cache
        context_layer = self.attn(
            position_ids,
            q,
            k,
            v,
            key_cache,
            value_cache,
            input_metadata,
            cache_event,
        )
        attn_output, _ = self.dense(context_layer)
        return attn_output


class GLMMLP(nn.Module):
    """MLP.

    MLP will take the input with h hidden state, project it to 4*h
    hidden dimension, perform nonlinear transformation, and project the
    state back into h hidden dimension.
    """

    def __init__(self, config):
        super().__init__()
        self.add_bias = config.add_bias_linear

        # Project to 4h.
        self.dense_h_to_4h = ColumnParallelLinear(
            config.hidden_size,
            config.ffn_hidden_size * 2,
            bias=config.add_bias_linear,
            gather_output=False,
        )

        self.activation_func = SiluAndMul()

        # Project back to h.
        self.dense_4h_to_h = RowParallelLinear(
            config.ffn_hidden_size,
            config.hidden_size,
            bias=config.add_bias_linear,
            input_is_parallel=True,
        )

    def forward(self, hidden_states):
        # [s, b, 4hp]
        intermediate_parallel, _ = self.dense_h_to_4h(hidden_states)
        intermediate_parallel = self.activation_func(intermediate_parallel)
        # [s, b, h]
        output, _ = self.dense_4h_to_h(intermediate_parallel)
        return output


class GLMBlock(nn.Module):
    """A single transformer layer.

    Transformer layer takes input with size [s, b, h] and returns an
    output of the same size.
    """

    def __init__(
        self,
        config,
    ):
        super().__init__()
        self.apply_residual_connection_post_layernorm = (
            config.apply_residual_connection_post_layernorm)

        self.fp32_residual_connection = config.fp32_residual_connection

        layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
        # Layernorm on the input data.
        self.input_layernorm = layer_norm_func(config.hidden_size,
                                               eps=config.layernorm_epsilon)

        # Self attention.
        self.self_attention = GLMAttention(config)
        self.hidden_dropout = config.hidden_dropout

        # Layernorm on the attention output
        self.post_attention_layernorm = layer_norm_func(
            config.hidden_size, eps=config.layernorm_epsilon)

        # MLP
        self.mlp = GLMMLP(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        # hidden_states: [num_tokens, h]
        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        # Self attention.
        attention_output = self.self_attention(
            hidden_states=layernorm_output,
            position_ids=position_ids,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event,
        )

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        layernorm_input = residual + attention_output

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        output = self.mlp(layernorm_output) + residual

        return output


class GLMTransformer(nn.Module):
    """Transformer class."""

    def __init__(self, config):
        super().__init__()
        self.post_layer_norm = config.post_layer_norm

        # Number of layers.
        self.num_layers = config.num_layers

        # Transformer layers.
        self.layers = nn.ModuleList(
            [GLMBlock(config) for i in range(self.num_layers)])

        if self.post_layer_norm:
            layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
            # Final layer norm before output.
            self.final_layernorm = layer_norm_func(
                config.hidden_size, eps=config.layernorm_epsilon)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        for i in range(self.num_layers):
            if cache_events is None:
                cache_event = None
            else:
                cache_event = cache_events[i]
            layer = self.layers[i]
            hidden_states = layer(
                hidden_states=hidden_states,
                position_ids=position_ids,
                kv_cache=kv_caches[i],
                input_metadata=input_metadata,
                cache_event=cache_event,
            )
        # Final layer norm.
        if self.post_layer_norm:
            hidden_states = self.final_layernorm(hidden_states)

        return hidden_states


class ChatGLMModel(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.embedding = VocabParallelEmbedding(config.padded_vocab_size,
                                                config.hidden_size)

        self.num_layers = config.num_layers
        self.multi_query_group_num = config.multi_query_group_num
        self.kv_channels = config.kv_channels
        self.encoder = GLMTransformer(config)

        self.output_layer = ColumnParallelLinear(
            config.hidden_size,
            config.padded_vocab_size,
            bias=False,
            gather_output=False,
            params_dtype=config.torch_dtype,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ):
        inputs_embeds = self.embedding(input_ids)

        # Run encoder.
        hidden_states = self.encoder(
            hidden_states=inputs_embeds,
            position_ids=position_ids,
            kv_caches=kv_caches,
            input_metadata=input_metadata,
            cache_events=cache_events,
        )
        return hidden_states


class ChatGLMForCausalLM(nn.Module):

    def __init__(self, config: ChatGLMConfig):
        super().__init__()
        self.config: ChatGLMConfig = config
        self.transformer = ChatGLMModel(config)
        self.lm_head_weight = self.transformer.output_layer.weight
        self.sampler = Sampler(config.padded_vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> Dict[int, SequenceOutputs]:
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         input_metadata, cache_events)
        next_tokens = self.sampler(self.lm_head_weight, hidden_states,
                                   input_metadata)
        return next_tokens

    _column_parallel_weights = [
        "output_layer.weight",
        "embedding.weight",
    ]
    _row_parallel_weights = ["dense_4h_to_h", "self_attention.dense"]

    def load_weights(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        load_format: str = "auto",
        revision: Optional[str] = None,
    ):
        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()

        q_proj_shard_size = self.config.hidden_size // tp_size
        kv_proj_shard_size = (self.config.hidden_size //
                              self.config.num_attention_heads *
                              self.config.multi_query_group_num // tp_size)
        mlp_hidden_shard_size = self.config.ffn_hidden_size // tp_size

        state_dict = self.state_dict()
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            if "word_embeddings" in name:
                name = name.replace(".word_embeddings", "")

            if name in state_dict:
                param = state_dict[name]
                if "query_key_value" in name:
                    q_offset = q_proj_shard_size * tp_rank
                    k_offset = (q_proj_shard_size * tp_size +
                                kv_proj_shard_size * tp_rank)
                    v_offset = (q_proj_shard_size * tp_size +
                                kv_proj_shard_size * (tp_size + tp_rank))
                    wq = loaded_weight[q_offset:q_offset + q_proj_shard_size]
                    wk = loaded_weight[k_offset:k_offset + kv_proj_shard_size]
                    wv = loaded_weight[v_offset:v_offset + kv_proj_shard_size]
                    loaded_weight = torch.cat([wq, wk, wv], dim=0)
                    param.data.copy_(loaded_weight)
                    continue

                if "dense_h_to_4h" in name:
                    w_gate = loaded_weight[mlp_hidden_shard_size *
                                           tp_rank:mlp_hidden_shard_size *
                                           (tp_rank + 1)]
                    w_proj = loaded_weight[mlp_hidden_shard_size *
                                           (tp_size +
                                            tp_rank):mlp_hidden_shard_size *
                                           (tp_size + tp_rank + 1)]
                    loaded_weight = torch.cat([w_gate, w_proj], dim=0)
                    param.data.copy_(loaded_weight)
                    continue

                load_tensor_parallel_weights(
                    param,
                    loaded_weight,
                    name,
                    self._column_parallel_weights,
                    self._row_parallel_weights,
                    tp_rank,
                )
            elif name == "transformer.rotary_pos_emb.inv_freq":
                continue
            else:
                print("Warning: weight not found in the model's state dict:",
                      name)

vllm/transformers_utils/config.py

@@ -5,9 +5,10 @@ from transformers import AutoConfig, MptConfig, PretrainedConfig
from vllm.transformers_utils.configs import *  # pylint: disable=wildcard-import

_CONFIG_REGISTRY = {
    "mpt": MptConfig,
    "baichuan": BaiChuanConfig,
    "aquila": AquilaConfig,
    "baichuan": BaiChuanConfig,
    "chatglm": ChatGLMConfig,
    "mpt": MptConfig,
    "qwen": QWenConfig,
    "RefinedWeb": RWConfig,  # For tiiuae/falcon-40b(-instruct)
    "RefinedWebModel": RWConfig,  # For tiiuae/falcon-7b(-instruct)

vllm/transformers_utils/configs/__init__.py

@@ -1,5 +1,6 @@
from vllm.transformers_utils.configs.baichuan import BaiChuanConfig
from vllm.transformers_utils.configs.aquila import AquilaConfig
from vllm.transformers_utils.configs.baichuan import BaiChuanConfig
from vllm.transformers_utils.configs.chatglm import ChatGLMConfig
from vllm.transformers_utils.configs.qwen import QWenConfig
# RWConfig is for the original tiiuae/falcon-40b(-instruct) and
# tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
@@ -8,8 +9,9 @@ from vllm.transformers_utils.configs.falcon import RWConfig
from vllm.transformers_utils.configs.yi import YiConfig

__all__ = [
    "BaiChuanConfig",
    "AquilaConfig",
    "BaiChuanConfig",
    "ChatGLMConfig",
    "QWenConfig",
    "RWConfig",
    "YiConfig",

vllm/transformers_utils/configs/chatglm.py

@@ -0,0 +1,68 @@
# coding=utf-8
# Adapted from
# https://github.com/THUDM/ChatGLM2-6B
from transformers import PretrainedConfig


class ChatGLMConfig(PretrainedConfig):
    model_type = "chatglm"
    attribute_map = {
        "num_hidden_layers": "num_layers",
        "n_head_kv": "multi_query_group_num",
    }

    def __init__(self,
                 num_layers=28,
                 padded_vocab_size=65024,
                 hidden_size=4096,
                 ffn_hidden_size=13696,
                 kv_channels=128,
                 num_attention_heads=32,
                 seq_length=2048,
                 hidden_dropout=0.0,
                 attention_dropout=0.0,
                 layernorm_epsilon=1e-5,
                 rmsnorm=True,
                 apply_residual_connection_post_layernorm=False,
                 post_layer_norm=True,
                 add_bias_linear=False,
                 add_qkv_bias=False,
                 interleaved_qkv=False,
                 bias_dropout_fusion=True,
                 multi_query_attention=False,
                 multi_query_group_num=1,
                 apply_query_key_layer_scaling=True,
                 attention_softmax_in_fp32=True,
                 fp32_residual_connection=False,
                 quantization_bit=0,
                 pre_seq_len=None,
                 prefix_projection=False,
                 **kwargs):
        self.num_layers = num_layers
        self.vocab_size = padded_vocab_size
        self.padded_vocab_size = padded_vocab_size
        self.hidden_size = hidden_size
        self.ffn_hidden_size = ffn_hidden_size
        self.kv_channels = kv_channels
        self.num_attention_heads = num_attention_heads
        self.seq_length = seq_length
        self.hidden_dropout = hidden_dropout
        self.attention_dropout = attention_dropout
        self.layernorm_epsilon = layernorm_epsilon
        self.rmsnorm = rmsnorm
        self.apply_residual_connection_post_layernorm = (
            apply_residual_connection_post_layernorm)
        self.post_layer_norm = post_layer_norm
        self.add_bias_linear = add_bias_linear
        self.add_qkv_bias = add_qkv_bias
        self.bias_dropout_fusion = bias_dropout_fusion
        self.multi_query_attention = multi_query_attention
        self.multi_query_group_num = multi_query_group_num
        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = attention_softmax_in_fp32
        self.fp32_residual_connection = fp32_residual_connection
        self.quantization_bit = quantization_bit
        self.pre_seq_len = pre_seq_len
        self.prefix_projection = prefix_projection
        self.interleaved_qkv = interleaved_qkv
        super().__init__(**kwargs)
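With the config class, model implementation, and registry entries above in place, a ChatGLM2 checkpoint can be served through vLLM's standard entrypoint. A brief usage sketch; the model id, prompt, and sampling values are illustrative:

# Example usage once this change is installed; prompt and sampling values are illustrative.
from vllm import LLM, SamplingParams

llm = LLM(model="THUDM/chatglm2-6b", trust_remote_code=True)
outputs = llm.generate(["What is the capital of France?"],
                       SamplingParams(temperature=0.8, max_tokens=64))
print(outputs[0].outputs[0].text)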