mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 00:15:51 +08:00
parent
79af7e96a0
commit
a57d13cc96
@ -23,6 +23,7 @@ _MODEL_REGISTRY = {
|
||||
"LLaMAForCausalLM": LlamaForCausalLM, # For decapoda-research/llama-*
|
||||
"MPTForCausalLM": MPTForCausalLM,
|
||||
"OPTForCausalLM": OPTForCausalLM,
|
||||
"QWenLMHeadModel": QWenLMHeadModel,
|
||||
"RWForCausalLM": FalconForCausalLM,
|
||||
}
|
||||
|
||||
|
||||
@ -9,17 +9,11 @@ from vllm.model_executor.models.gpt_neox import GPTNeoXForCausalLM
|
||||
from vllm.model_executor.models.llama import LlamaForCausalLM
|
||||
from vllm.model_executor.models.mpt import MPTForCausalLM
|
||||
from vllm.model_executor.models.opt import OPTForCausalLM
|
||||
from vllm.model_executor.models.qwen import QWenLMHeadModel
|
||||
|
||||
__all__ = [
|
||||
"BaiChuanForCausalLM",
|
||||
"BaichuanForCausalLM",
|
||||
"BloomForCausalLM",
|
||||
"FalconForCausalLM",
|
||||
"GPT2LMHeadModel",
|
||||
"GPTBigCodeForCausalLM",
|
||||
"GPTJForCausalLM",
|
||||
"GPTNeoXForCausalLM",
|
||||
"LlamaForCausalLM",
|
||||
"MPTForCausalLM",
|
||||
"OPTForCausalLM",
|
||||
"BaiChuanForCausalLM", "BaichuanForCausalLM", "BloomForCausalLM",
|
||||
"FalconForCausalLM", "GPT2LMHeadModel", "GPTBigCodeForCausalLM",
|
||||
"GPTJForCausalLM", "GPTNeoXForCausalLM", "LlamaForCausalLM",
|
||||
"MPTForCausalLM", "OPTForCausalLM", "QWenLMHeadModel"
|
||||
]
|
||||
|
||||
316
vllm/model_executor/models/qwen.py
Normal file
316
vllm/model_executor/models/qwen.py
Normal file
@ -0,0 +1,316 @@
|
||||
# coding=utf-8
|
||||
# Adapted from
|
||||
# https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py
|
||||
# Copyright (c) Alibaba Cloud.
|
||||
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
|
||||
"""Inference-only QWen model compatible with HuggingFace weights.
|
||||
|
||||
The input of the model is flattened to a 1D tensor of tokens. The model uses
|
||||
InputMetadata to extract the original 2D shape of the input.
|
||||
"""
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.model_executor.input_metadata import InputMetadata
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.weight_utils import (
|
||||
hf_model_weights_iterator,
|
||||
load_tensor_parallel_weights,
|
||||
)
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||
VocabParallelEmbedding,
|
||||
ColumnParallelLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.sequence import SequenceOutputs
|
||||
from vllm.transformers_utils.configs.qwen import QWenConfig
|
||||
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
|
||||
class QWenMLP(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
hidden_act: str = "silu",
|
||||
):
|
||||
super().__init__()
|
||||
self.gate_up_proj = ColumnParallelLinear(
|
||||
hidden_size,
|
||||
2 * intermediate_size,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
perform_initialization=False,
|
||||
)
|
||||
self.c_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False,
|
||||
)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||
"Only silu is supported for now.")
|
||||
self.act_fn = SiluAndMul()
|
||||
|
||||
def forward(self, x):
|
||||
gate_up, _ = self.gate_up_proj(x)
|
||||
x = self.act_fn(gate_up)
|
||||
x, _ = self.c_proj(x)
|
||||
return x
|
||||
|
||||
|
||||
class QWenAttention(nn.Module):
|
||||
|
||||
def __init__(self, hidden_size: int, num_heads: int,
|
||||
max_position_embeddings: int):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
|
||||
)
|
||||
self.total_num_heads = num_heads
|
||||
assert self.total_num_heads % tensor_model_parallel_world_size == 0
|
||||
self.num_heads = (self.total_num_heads //
|
||||
tensor_model_parallel_world_size)
|
||||
self.head_dim = hidden_size // self.total_num_heads
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
self.c_attn = ColumnParallelLinear(
|
||||
hidden_size,
|
||||
3 * hidden_size,
|
||||
bias=True,
|
||||
gather_output=False,
|
||||
perform_initialization=False,
|
||||
)
|
||||
self.c_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False,
|
||||
)
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.attn = PagedAttentionWithRoPE(
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=max_position_embeddings,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.c_attn(hidden_states)
|
||||
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
||||
|
||||
k_cache, v_cache = kv_cache
|
||||
attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
|
||||
input_metadata, cache_event)
|
||||
|
||||
output, _ = self.c_proj(attn_output)
|
||||
return output
|
||||
|
||||
|
||||
class QWenBlock(nn.Module):
|
||||
|
||||
def __init__(self, config: QWenConfig):
|
||||
super().__init__()
|
||||
self.ln_1 = RMSNorm(config.n_embd, eps=config.layer_norm_epsilon)
|
||||
|
||||
self.attn = QWenAttention(config.n_embd, config.num_attention_heads,
|
||||
config.max_position_embeddings)
|
||||
|
||||
self.ln_2 = RMSNorm(config.n_embd, eps=config.layer_norm_epsilon)
|
||||
|
||||
self.mlp = QWenMLP(config.n_embd, config.ffn_hidden_size // 2)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
# Self Attention
|
||||
residual = hidden_states
|
||||
hidden_states = self.ln_1(hidden_states)
|
||||
hidden_states = self.attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
input_metadata=input_metadata,
|
||||
cache_event=cache_event,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = self.ln_2(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
return hidden_states
|
||||
|
||||
|
||||
class QWenModel(nn.Module):
|
||||
|
||||
def __init__(self, config: QWenConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
vocab_size = ((config.vocab_size + 63) // 64) * 64
|
||||
self.wte = VocabParallelEmbedding(vocab_size,
|
||||
config.n_embd,
|
||||
perform_initialization=False)
|
||||
self.h = nn.ModuleList(
|
||||
[QWenBlock(config) for _ in range(config.num_hidden_layers)])
|
||||
self.ln_f = RMSNorm(config.n_embd, eps=config.layer_norm_epsilon)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.wte(input_ids)
|
||||
for i in range(len(self.h)):
|
||||
if cache_events is None:
|
||||
cache_event = None
|
||||
else:
|
||||
cache_event = cache_events[i]
|
||||
layer = self.h[i]
|
||||
hidden_states = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
kv_caches[i],
|
||||
input_metadata,
|
||||
cache_event,
|
||||
)
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class QWenLMHeadModel(nn.Module):
|
||||
|
||||
def __init__(self, config: QWenConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.transformer = QWenModel(config)
|
||||
vocab_size = ((config.vocab_size + 63) // 64) * 64
|
||||
self.lm_head = ColumnParallelLinear(
|
||||
config.n_embd,
|
||||
vocab_size,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
perform_initialization=False,
|
||||
)
|
||||
self.sampler = Sampler(config.vocab_size)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> Dict[int, SequenceOutputs]:
|
||||
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||
input_metadata, cache_events)
|
||||
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
|
||||
input_metadata)
|
||||
return next_tokens
|
||||
|
||||
_column_parallel_weights = ["wte.weight", "lm_head.weight"]
|
||||
_row_parallel_weights = ["c_proj.weight"]
|
||||
|
||||
def load_weights(
|
||||
self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
use_np_cache: bool = False,
|
||||
):
|
||||
tp_world_size = get_tensor_model_parallel_world_size()
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
state_dict = self.state_dict()
|
||||
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, use_np_cache):
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
if "wte" in name or "lm_head" in name:
|
||||
# Consider padding in the vocab size.
|
||||
param = state_dict[name]
|
||||
padded_vocab_size = param.shape[0] * tp_world_size
|
||||
num_extra_rows = padded_vocab_size - self.config.vocab_size
|
||||
extra_rows = torch.empty(num_extra_rows,
|
||||
loaded_weight.shape[1])
|
||||
extra_rows = extra_rows.to(loaded_weight)
|
||||
loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
|
||||
|
||||
if "c_attn" in name:
|
||||
total_num_heads = self.config.num_attention_heads
|
||||
hidden_size = self.config.hidden_size
|
||||
head_size = hidden_size // total_num_heads
|
||||
num_heads = total_num_heads // tp_world_size
|
||||
head_start = tp_rank * num_heads
|
||||
head_end = (tp_rank + 1) * num_heads
|
||||
|
||||
if "weight" in name:
|
||||
loaded_weight = loaded_weight.view(3, total_num_heads,
|
||||
head_size, hidden_size)
|
||||
loaded_weight = loaded_weight[:, head_start:head_end, :, :]
|
||||
loaded_weight = loaded_weight.reshape(-1, hidden_size)
|
||||
elif "bias" in name:
|
||||
loaded_weight = loaded_weight.view(3, total_num_heads,
|
||||
head_size)
|
||||
loaded_weight = loaded_weight[:, head_start:head_end, :]
|
||||
loaded_weight = loaded_weight.reshape(-1)
|
||||
|
||||
is_gate_up_weight = False
|
||||
for stride_id, weight_name in enumerate(["w2", "w1"]):
|
||||
if weight_name not in name:
|
||||
continue
|
||||
param = state_dict[name.replace(weight_name, "gate_up_proj")]
|
||||
shard_size = param.shape[0] // 2
|
||||
loaded_weight = loaded_weight[shard_size * tp_rank:shard_size *
|
||||
(tp_rank + 1)]
|
||||
param_slice = param.data[shard_size * stride_id:shard_size *
|
||||
(stride_id + 1)]
|
||||
assert param_slice.shape == loaded_weight.shape
|
||||
param_slice.copy_(loaded_weight)
|
||||
is_gate_up_weight = True
|
||||
break
|
||||
if is_gate_up_weight:
|
||||
continue
|
||||
|
||||
param = state_dict[name]
|
||||
load_tensor_parallel_weights(
|
||||
param,
|
||||
loaded_weight,
|
||||
name,
|
||||
self._column_parallel_weights,
|
||||
self._row_parallel_weights,
|
||||
tp_rank,
|
||||
)
|
||||
@ -5,6 +5,7 @@ from vllm.transformers_utils.configs import * # pylint: disable=wildcard-import
|
||||
_CONFIG_REGISTRY = {
|
||||
"mpt": MPTConfig,
|
||||
"baichuan": BaiChuanConfig,
|
||||
"qwen": QWenConfig,
|
||||
"RefinedWeb": RWConfig, # For tiiuae/falcon-40b(-instruct)
|
||||
"RefinedWebModel": RWConfig, # For tiiuae/falcon-7b(-instruct)
|
||||
}
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from vllm.transformers_utils.configs.mpt import MPTConfig
|
||||
from vllm.transformers_utils.configs.baichuan import BaiChuanConfig
|
||||
from vllm.transformers_utils.configs.qwen import QWenConfig
|
||||
# RWConfig is for the original tiiuae/falcon-40b(-instruct) and
|
||||
# tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
|
||||
# `FalconConfig` class from the official HuggingFace transformers library.
|
||||
@ -8,5 +9,6 @@ from vllm.transformers_utils.configs.falcon import RWConfig
|
||||
__all__ = [
|
||||
"MPTConfig",
|
||||
"BaiChuanConfig",
|
||||
"QWenConfig",
|
||||
"RWConfig",
|
||||
]
|
||||
|
||||
71
vllm/transformers_utils/configs/qwen.py
Normal file
71
vllm/transformers_utils/configs/qwen.py
Normal file
@ -0,0 +1,71 @@
|
||||
# Copyright (c) Alibaba Cloud.
|
||||
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
|
||||
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
|
||||
class QWenConfig(PretrainedConfig):
|
||||
model_type = "qwen"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
attribute_map = {
|
||||
"hidden_size": "n_embd",
|
||||
"num_attention_heads": "n_head",
|
||||
"max_position_embeddings": "n_positions",
|
||||
"num_hidden_layers": "n_layer",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=151851,
|
||||
n_embd=4096,
|
||||
n_layer=32,
|
||||
n_head=32,
|
||||
n_inner=None,
|
||||
embd_pdrop=0.0,
|
||||
attn_pdrop=0.0,
|
||||
layer_norm_epsilon=1e-5,
|
||||
initializer_range=0.02,
|
||||
scale_attn_weights=True,
|
||||
use_cache=True,
|
||||
eos_token_id=151643,
|
||||
apply_residual_connection_post_layernorm=False,
|
||||
bf16=True,
|
||||
kv_channels=128,
|
||||
rotary_pct=1.0,
|
||||
rotary_emb_base=10000,
|
||||
use_dynamic_ntk=False,
|
||||
use_logn_attn=False,
|
||||
use_flash_attn=True,
|
||||
ffn_hidden_size=22016,
|
||||
no_bias=True,
|
||||
tie_word_embeddings=False,
|
||||
**kwargs,
|
||||
):
|
||||
self.eos_token_id = eos_token_id
|
||||
super().__init__(eos_token_id=eos_token_id,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs)
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
self.n_embd = n_embd
|
||||
self.n_layer = n_layer
|
||||
self.n_head = n_head
|
||||
self.n_inner = n_inner
|
||||
self.embd_pdrop = embd_pdrop
|
||||
self.attn_pdrop = attn_pdrop
|
||||
self.layer_norm_epsilon = layer_norm_epsilon
|
||||
self.initializer_range = initializer_range
|
||||
self.scale_attn_weights = scale_attn_weights
|
||||
self.use_cache = use_cache
|
||||
self.apply_residual_connection_post_layernorm = (
|
||||
apply_residual_connection_post_layernorm)
|
||||
self.bf16 = bf16
|
||||
self.kv_channels = kv_channels
|
||||
self.rotary_pct = rotary_pct
|
||||
self.rotary_emb_base = rotary_emb_base
|
||||
self.use_dynamic_ntk = use_dynamic_ntk
|
||||
self.use_logn_attn = use_logn_attn
|
||||
self.use_flash_attn = use_flash_attn
|
||||
self.ffn_hidden_size = ffn_hidden_size
|
||||
self.no_bias = no_bias
|
||||
self.tie_word_embeddings = tie_word_embeddings
|
||||
Loading…
x
Reference in New Issue
Block a user