mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-17 04:15:01 +08:00
340 lines
12 KiB
Python
340 lines
12 KiB
Python
# coding=utf-8
|
|
# Adapted from
|
|
# https://huggingface.co/microsoft/phi-1_5/blob/main/modeling_phi.py
|
|
# Copyright 2023 The vLLM team.
|
|
# Copyright (c) Microsoft Corporation.
|
|
# Licensed under the MIT license.
|
|
#
|
|
# BSD 3-Clause License
|
|
#
|
|
# Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu.
|
|
# All rights reserved.
|
|
#
|
|
# Redistribution and use in source and binary forms, with or without
|
|
# modification, are permitted provided that the following conditions are met:
|
|
#
|
|
# * Redistributions of source code must retain the above copyright notice, this
|
|
# list of conditions and the following disclaimer.
|
|
#
|
|
# * Redistributions in binary form must reproduce the above copyright notice,
|
|
# this list of conditions and the following disclaimer in the documentation
|
|
# and/or other materials provided with the distribution.
|
|
#
|
|
# * Neither the name of the copyright holder nor the names of its
|
|
# contributors may be used to endorse or promote products derived from
|
|
# this software without specific prior written permission.
|
|
#
|
|
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
"""Inference-only Phi-1.5 model compatible with HuggingFace weights."""
|
|
from typing import Iterable, List, Optional, Tuple
|
|
|
|
import torch
|
|
from torch import nn
|
|
from transformers import PhiConfig
|
|
|
|
from vllm.attention import Attention, AttentionMetadata
|
|
from vllm.config import CacheConfig, LoRAConfig
|
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
|
from vllm.model_executor.layers.activation import get_act_fn
|
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
|
QKVParallelLinear,
|
|
RowParallelLinear)
|
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
|
from vllm.model_executor.layers.quantization.base_config import (
|
|
QuantizationConfig)
|
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
|
from vllm.model_executor.layers.sampler import Sampler
|
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
|
ParallelLMHead, VocabParallelEmbedding)
|
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
|
from vllm.sequence import IntermediateTensors, SamplerOutput
|
|
|
|
from .interfaces import SupportsLoRA
|
|
|
|
|
|
class PhiAttention(nn.Module):
    """Multi-head self-attention used by each Phi decoder layer.

    Builds a fused QKV projection, an output projection (``dense``), a
    (partial) rotary embedding and the paged ``Attention`` op. Heads are
    split evenly across tensor-parallel ranks.
    """

    def __init__(self,
                 config: PhiConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        """Create the attention submodules from ``config``.

        Args:
            config: HF ``PhiConfig``; reads ``num_attention_heads``,
                ``hidden_size``, ``partial_rotary_factor`` and, when
                present, ``rope_theta`` / ``n_positions`` /
                ``max_position_embeddings``.
            cache_config: Forwarded to ``Attention``.
            quant_config: Forwarded to the linear layers and ``Attention``.
        """
        super().__init__()
        self.total_num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.total_num_heads

        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        # Each TP rank must own a whole number of heads.
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)

        # pylint: disable=C0103
        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_size,
            self.total_num_heads,
            bias=True,
            quant_config=quant_config,
        )
        self.dense = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            quant_config=quant_config,
        )

        scaling = self.head_size**-0.5
        # rotary_dim = partial_rotary_factor * per-head dimension.
        rotary_dim = int(config.partial_rotary_factor *
                         (config.hidden_size // config.num_attention_heads))
        assert rotary_dim % 2 == 0

        # pylint: disable=C0301
        # Refer to:
        # https://huggingface.co/microsoft/phi-1_5/blob/d212a789620c380ff32ca1d1ee9943a777360987/modeling_phi.py#L518
        # That reference uses a fixed base of 10000; honor the config's
        # rope_theta when it is set and fall back to 10000 otherwise.
        rope_theta = getattr(config, "rope_theta", 10000)
        # Prefer the legacy `n_positions` key (what this code always read),
        # then the HF PhiConfig name `max_position_embeddings`, then 2048.
        if hasattr(config, "n_positions"):
            max_position_embeddings = config.n_positions
        else:
            max_position_embeddings = getattr(config,
                                              "max_position_embeddings",
                                              2048)
        self.rotary_emb = get_rope(
            self.head_size,
            rotary_dim=rotary_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
        )
        self.attn = Attention(self.num_heads,
                              self.head_size,
                              scaling,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Compute attention for one layer.

        Args:
            position_ids: Token positions consumed by the rotary embedding.
            hidden_states: Input activations for this layer.
            kv_cache: This layer's KV cache, passed through to ``self.attn``.
            attn_metadata: Batch attention metadata for ``self.attn``.

        Returns:
            The attention output after the ``dense`` projection.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # QKVParallelLinear was built without a separate KV head count, so
        # q, k and v are presumably equal width and an even 3-way split of
        # the last dim recovers them.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.dense(attn_output)
        return output
|
|
|
|
|
|
class PhiMLP(nn.Module):
    """Feed-forward sublayer of a Phi decoder layer: fc1 -> act -> fc2."""

    def __init__(self,
                 config: PhiConfig,
                 quant_config: Optional[QuantizationConfig] = None):
        """Build the two projections and the activation.

        The intermediate width is ``config.n_inner`` when it is set (not
        None), otherwise four times ``config.hidden_size``.
        """
        super().__init__()

        intermediate_size = getattr(config, "n_inner", None)
        if intermediate_size is None:
            intermediate_size = 4 * config.hidden_size

        self.fc1 = ColumnParallelLinear(
            config.hidden_size,
            intermediate_size,
            quant_config=quant_config,
        )
        self.fc2 = RowParallelLinear(
            intermediate_size,
            config.hidden_size,
            quant_config=quant_config,
        )
        self.act = get_act_fn(config.hidden_act, quant_config,
                              intermediate_size)

    def forward(self, hidden_states):
        """Apply fc1, the activation, then fc2; return fc2's output."""
        projected, _ = self.fc1(hidden_states)
        activated = self.act(projected)
        output, _ = self.fc2(activated)
        return output
|
|
|
|
|
|
class PhiLayer(nn.Module):
    """One Phi decoder layer with parallel attention and MLP branches.

    A single LayerNorm output feeds both branches; their results are summed
    together with the un-normalized layer input.
    """

    def __init__(self,
                 config: PhiConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        self.input_layernorm = nn.LayerNorm(config.hidden_size,
                                            eps=config.layer_norm_eps)
        self.self_attn = PhiAttention(config, cache_config, quant_config)
        self.mlp = PhiMLP(config, quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Return ``attn(ln(x)) + mlp(ln(x)) + x`` for input ``x``."""
        normed = self.input_layernorm(hidden_states)
        attn_branch = self.self_attn(
            position_ids=position_ids,
            hidden_states=normed,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        mlp_branch = self.mlp(normed)
        # Keep the (attn + mlp) + residual summation order.
        return attn_branch + mlp_branch + hidden_states
|
|
|
|
|
|
class PhiModel(nn.Module):
    """Phi decoder stack: token embedding, PhiLayers, final LayerNorm."""

    def __init__(self,
                 config: PhiConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                   config.hidden_size)
        decoder_layers = [
            PhiLayer(config, cache_config, quant_config)
            for _ in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(decoder_layers)
        self.final_layernorm = nn.LayerNorm(config.hidden_size,
                                            eps=config.layer_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids``, run every layer, then apply the final norm.

        ``kv_caches`` is indexed per layer, so it must hold at least
        ``config.num_hidden_layers`` entries.
        """
        hidden = self.embed_tokens(input_ids)
        num_layers = self.config.num_hidden_layers
        for layer_idx in range(num_layers):
            hidden = self.layers[layer_idx](
                positions,
                hidden,
                kv_caches[layer_idx],
                attn_metadata,
            )
        return self.final_layernorm(hidden)
|
|
|
|
|
|
class PhiForCausalLM(nn.Module, SupportsLoRA):
    """Phi causal language model: a PhiModel backbone plus a biased LM head.

    Provides the forward / compute_logits / sample / load_weights methods,
    with LoRA metadata declared as class attributes.
    """

    # Fused parameter name -> checkpoint projection names packed into it.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ]
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "dense",
        "fc1",
        "fc2",
    ]
    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(
        self,
        config: PhiConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ):
        super().__init__()

        self.config = config
        self.lora_config = lora_config

        self.quant_config = quant_config

        self.model = PhiModel(config, cache_config, quant_config)

        # The LM head carries a bias, which compute_logits passes explicitly.
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      bias=True,
                                      quant_config=quant_config)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        """Run the backbone and return its final hidden states (not logits).

        ``intermediate_tensors`` is accepted for interface compatibility but
        is not used by this model.
        """
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits via the biased head."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata, self.lm_head.bias)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate token selection to the shared ``Sampler``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Copy checkpoint tensors into this model's parameters.

        Separate ``q_proj`` / ``k_proj`` / ``v_proj`` tensors are routed into
        the fused ``qkv_proj`` parameter through its shard-aware
        ``weight_loader``. All other tensors are loaded by name, using the
        parameter's own ``weight_loader`` if it has one and
        ``default_weight_loader`` otherwise. Rotary ``inv_freq`` buffers and
        bias tensors without a matching parameter are skipped.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v")
        ]
        params_dict = dict(self.named_parameters())

        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models. Use `break`, not
                # `continue`: it leaves the mapping loop AND bypasses the
                # `else` branch, dropping this tensor immediately. Continuing
                # would re-match the renamed name, because "qkv_proj" itself
                # contains the substring "v_proj".
                if name.endswith(".bias") and name not in params_dict:
                    break
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a stacked q/k/v projection: load by exact name.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # pylint: disable=E1136

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
|