mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 02:05:01 +08:00
Support Microsoft Phi 1.5 (#1664)
This commit is contained in:
parent
cb08cd0d75
commit
521b35f799
@ -59,6 +59,7 @@ vLLM seamlessly supports many Hugging Face models, including the following archi
|
|||||||
- Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.)
|
- Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.)
|
||||||
- MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.)
|
- MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.)
|
||||||
- OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.)
|
- OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.)
|
||||||
|
- Phi-1.5 (`microsoft/phi-1_5`, etc.)
|
||||||
- Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.)
|
- Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.)
|
||||||
|
|
||||||
Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source):
|
Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source):
|
||||||
|
|||||||
@ -5,6 +5,7 @@ pandas # Required for Ray data.
|
|||||||
pyarrow # Required for Ray data.
|
pyarrow # Required for Ray data.
|
||||||
sentencepiece # Required for LLaMA tokenizer.
|
sentencepiece # Required for LLaMA tokenizer.
|
||||||
numpy
|
numpy
|
||||||
|
einops # Required for phi-1_5
|
||||||
torch >= 2.1.0
|
torch >= 2.1.0
|
||||||
transformers >= 4.34.0 # Required for Mistral.
|
transformers >= 4.34.0 # Required for Mistral.
|
||||||
xformers >= 0.0.22.post7 # Required for CUDA 12.1.
|
xformers >= 0.0.22.post7 # Required for CUDA 12.1.
|
||||||
|
|||||||
@ -15,6 +15,7 @@ MODELS = [
|
|||||||
"EleutherAI/pythia-70m",
|
"EleutherAI/pythia-70m",
|
||||||
"bigscience/bloom-560m",
|
"bigscience/bloom-560m",
|
||||||
"mosaicml/mpt-7b",
|
"mosaicml/mpt-7b",
|
||||||
|
"microsoft/phi-1_5",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -32,6 +32,7 @@ _MODEL_REGISTRY = {
|
|||||||
"MptForCausalLM": MPTForCausalLM,
|
"MptForCausalLM": MPTForCausalLM,
|
||||||
"MPTForCausalLM": MPTForCausalLM,
|
"MPTForCausalLM": MPTForCausalLM,
|
||||||
"OPTForCausalLM": OPTForCausalLM,
|
"OPTForCausalLM": OPTForCausalLM,
|
||||||
|
"PhiForCausalLM": PhiForCausalLM,
|
||||||
"QWenLMHeadModel": QWenLMHeadModel,
|
"QWenLMHeadModel": QWenLMHeadModel,
|
||||||
"RWForCausalLM": FalconForCausalLM,
|
"RWForCausalLM": FalconForCausalLM,
|
||||||
"YiForCausalLM": YiForCausalLM,
|
"YiForCausalLM": YiForCausalLM,
|
||||||
|
|||||||
@ -12,6 +12,7 @@ from vllm.model_executor.models.llama import LlamaForCausalLM
|
|||||||
from vllm.model_executor.models.mistral import MistralForCausalLM
|
from vllm.model_executor.models.mistral import MistralForCausalLM
|
||||||
from vllm.model_executor.models.mpt import MPTForCausalLM
|
from vllm.model_executor.models.mpt import MPTForCausalLM
|
||||||
from vllm.model_executor.models.opt import OPTForCausalLM
|
from vllm.model_executor.models.opt import OPTForCausalLM
|
||||||
|
from vllm.model_executor.models.phi_1_5 import PhiForCausalLM
|
||||||
from vllm.model_executor.models.qwen import QWenLMHeadModel
|
from vllm.model_executor.models.qwen import QWenLMHeadModel
|
||||||
from vllm.model_executor.models.chatglm import ChatGLMForCausalLM
|
from vllm.model_executor.models.chatglm import ChatGLMForCausalLM
|
||||||
from vllm.model_executor.models.yi import YiForCausalLM
|
from vllm.model_executor.models.yi import YiForCausalLM
|
||||||
@ -31,6 +32,7 @@ __all__ = [
|
|||||||
"LlamaForCausalLM",
|
"LlamaForCausalLM",
|
||||||
"MPTForCausalLM",
|
"MPTForCausalLM",
|
||||||
"OPTForCausalLM",
|
"OPTForCausalLM",
|
||||||
|
"PhiForCausalLM",
|
||||||
"QWenLMHeadModel",
|
"QWenLMHeadModel",
|
||||||
"MistralForCausalLM",
|
"MistralForCausalLM",
|
||||||
"YiForCausalLM",
|
"YiForCausalLM",
|
||||||
|
|||||||
314
vllm/model_executor/models/phi_1_5.py
Normal file
314
vllm/model_executor/models/phi_1_5.py
Normal file
@ -0,0 +1,314 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Adapted from
|
||||||
|
# https://huggingface.co/microsoft/phi-1_5/blob/main/modeling_phi.py
|
||||||
|
# Copyright 2023 The vLLM team.
|
||||||
|
# Copyright (c) Microsoft Corporation.
|
||||||
|
# Licensed under the MIT license.
|
||||||
|
#
|
||||||
|
# BSD 3-Clause License
|
||||||
|
#
|
||||||
|
# Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# Redistribution and use in source and binary forms, with or without
|
||||||
|
# modification, are permitted provided that the following conditions are met:
|
||||||
|
#
|
||||||
|
# * Redistributions of source code must retain the above copyright notice, this
|
||||||
|
# list of conditions and the following disclaimer.
|
||||||
|
#
|
||||||
|
# * Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
# this list of conditions and the following disclaimer in the documentation
|
||||||
|
# and/or other materials provided with the distribution.
|
||||||
|
#
|
||||||
|
# * Neither the name of the copyright holder nor the names of its
|
||||||
|
# contributors may be used to endorse or promote products derived from
|
||||||
|
# this software without specific prior written permission.
|
||||||
|
#
|
||||||
|
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||||
|
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
|
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||||
|
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||||
|
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||||
|
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||||
|
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||||
|
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
"""Inference-only Phi-1.5 model compatible with HuggingFace weights.
|
||||||
|
|
||||||
|
The input of the model is flattened to a 1D tensor of tokens. The model uses
|
||||||
|
InputMetadata to extract the original 2D shape of the input.
|
||||||
|
"""
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
|
from vllm.model_executor.input_metadata import InputMetadata
|
||||||
|
from vllm.model_executor.layers.activation import get_act_fn
|
||||||
|
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
|
||||||
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||||
|
LinearMethodBase,
|
||||||
|
QKVParallelLinear,
|
||||||
|
RowParallelLinear)
|
||||||
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
|
VocabParallelEmbedding, ParallelLMHead)
|
||||||
|
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||||
|
get_tensor_model_parallel_world_size)
|
||||||
|
from vllm.model_executor.weight_utils import (default_weight_loader,
|
||||||
|
hf_model_weights_iterator)
|
||||||
|
from vllm.sequence import SamplerOutput
|
||||||
|
|
||||||
|
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||||
|
|
||||||
|
|
||||||
|
class PhiEmbedding(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, config: PretrainedConfig):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.wte = VocabParallelEmbedding(
|
||||||
|
config.vocab_size,
|
||||||
|
config.hidden_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, input_ids: torch.LongTensor):
|
||||||
|
return self.wte(input_ids)
|
||||||
|
|
||||||
|
|
||||||
|
class PhiAttention(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
config: PretrainedConfig,
|
||||||
|
linear_method: Optional[LinearMethodBase] = None):
|
||||||
|
super().__init__()
|
||||||
|
self.total_num_heads = config.num_attention_heads
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
self.head_size = self.hidden_size // self.total_num_heads
|
||||||
|
|
||||||
|
tensor_model_parallel_world_size = (
|
||||||
|
get_tensor_model_parallel_world_size())
|
||||||
|
assert self.total_num_heads % tensor_model_parallel_world_size == 0
|
||||||
|
self.num_heads = (self.total_num_heads //
|
||||||
|
tensor_model_parallel_world_size)
|
||||||
|
|
||||||
|
# pylint: disable=C0103
|
||||||
|
self.Wqkv = QKVParallelLinear(
|
||||||
|
self.hidden_size,
|
||||||
|
self.head_size,
|
||||||
|
self.total_num_heads,
|
||||||
|
linear_method=linear_method,
|
||||||
|
)
|
||||||
|
self.qkv_proj = QKVParallelLinear(
|
||||||
|
config.hidden_size,
|
||||||
|
self.head_size,
|
||||||
|
self.total_num_heads,
|
||||||
|
bias=False,
|
||||||
|
linear_method=linear_method,
|
||||||
|
)
|
||||||
|
self.out_proj = RowParallelLinear(
|
||||||
|
self.hidden_size,
|
||||||
|
self.hidden_size,
|
||||||
|
linear_method=linear_method,
|
||||||
|
)
|
||||||
|
|
||||||
|
scaling = self.head_size**-0.5
|
||||||
|
rotary_dim = config.rotary_dim
|
||||||
|
assert rotary_dim % 2 == 0
|
||||||
|
|
||||||
|
# pylint: disable=C0301
|
||||||
|
# Refer to:
|
||||||
|
# https://huggingface.co/microsoft/phi-1_5/blob/d212a789620c380ff32ca1d1ee9943a777360987/modeling_phi.py#L518
|
||||||
|
rope_theta = 10000
|
||||||
|
max_position_embeddings = getattr(config, "n_positions", 2048)
|
||||||
|
self.attn = PagedAttentionWithRoPE(
|
||||||
|
self.num_heads,
|
||||||
|
self.head_size,
|
||||||
|
scaling,
|
||||||
|
rotary_dim,
|
||||||
|
base=rope_theta,
|
||||||
|
max_position=max_position_embeddings)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
position_ids: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
kv_cache: KVCache,
|
||||||
|
input_metadata: InputMetadata,
|
||||||
|
cache_event: Optional[torch.cuda.Event],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
qkv, _ = self.Wqkv(hidden_states)
|
||||||
|
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
||||||
|
k_cache, v_cache = kv_cache
|
||||||
|
attn_output = self.attn(position_ids, q, k, v, k_cache, v_cache,
|
||||||
|
input_metadata, cache_event)
|
||||||
|
output, _ = self.out_proj(attn_output)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class PhiMLP(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
config: PretrainedConfig,
|
||||||
|
linear_method: Optional[LinearMethodBase] = None):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
n_inner = getattr(config, "n_inner", None)
|
||||||
|
n_inner = n_inner if n_inner is not None else 4 * config.hidden_size
|
||||||
|
|
||||||
|
self.fc1 = ColumnParallelLinear(
|
||||||
|
config.hidden_size,
|
||||||
|
n_inner,
|
||||||
|
linear_method=linear_method,
|
||||||
|
)
|
||||||
|
self.fc2 = RowParallelLinear(
|
||||||
|
n_inner,
|
||||||
|
config.hidden_size,
|
||||||
|
linear_method=linear_method,
|
||||||
|
)
|
||||||
|
self.act = get_act_fn(config.activation_function)
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
hidden_states, _ = self.fc1(hidden_states)
|
||||||
|
hidden_states = self.act(hidden_states)
|
||||||
|
hidden_states, _ = self.fc2(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class PhiLayer(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
config: PretrainedConfig,
|
||||||
|
linear_method: Optional[LinearMethodBase] = None):
|
||||||
|
super().__init__()
|
||||||
|
self.ln = nn.LayerNorm(config.hidden_size,
|
||||||
|
eps=config.layer_norm_epsilon)
|
||||||
|
self.mixer = PhiAttention(config, linear_method)
|
||||||
|
self.mlp = PhiMLP(config, linear_method)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
position_ids: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
kv_cache: KVCache,
|
||||||
|
input_metadata: InputMetadata,
|
||||||
|
cache_event: Optional[torch.cuda.Event],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.ln(hidden_states)
|
||||||
|
attn_outputs = self.mixer(
|
||||||
|
position_ids=position_ids,
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
kv_cache=kv_cache,
|
||||||
|
input_metadata=input_metadata,
|
||||||
|
cache_event=cache_event,
|
||||||
|
)
|
||||||
|
feed_forward_hidden_states = self.mlp(hidden_states)
|
||||||
|
hidden_states = attn_outputs + feed_forward_hidden_states + residual
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class PhiCausalLMHead(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, config: PretrainedConfig):
|
||||||
|
super().__init__()
|
||||||
|
self.ln = nn.LayerNorm(config.hidden_size,
|
||||||
|
eps=config.layer_norm_epsilon)
|
||||||
|
self.linear = ParallelLMHead(config.vocab_size,
|
||||||
|
config.hidden_size,
|
||||||
|
bias=True)
|
||||||
|
self.sampler = Sampler(config.vocab_size)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
input_metadata: InputMetadata,
|
||||||
|
):
|
||||||
|
hidden_states = self.ln(hidden_states)
|
||||||
|
next_tokens = self.sampler(self.linear.weight, hidden_states,
|
||||||
|
input_metadata, self.linear.bias)
|
||||||
|
return next_tokens
|
||||||
|
|
||||||
|
|
||||||
|
class PhiModel(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
config: PretrainedConfig,
|
||||||
|
linear_method: Optional[LinearMethodBase] = None):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.linear_method = linear_method
|
||||||
|
self.embd = PhiEmbedding(config)
|
||||||
|
self.h = nn.ModuleList([
|
||||||
|
PhiLayer(config, linear_method)
|
||||||
|
for _ in range(config.num_hidden_layers)
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
kv_caches: List[KVCache],
|
||||||
|
input_metadata: InputMetadata,
|
||||||
|
cache_events: Optional[List[torch.cuda.Event]],
|
||||||
|
) -> SamplerOutput:
|
||||||
|
hidden_states = self.embd(input_ids)
|
||||||
|
for i in range(self.config.num_hidden_layers):
|
||||||
|
if cache_events is None:
|
||||||
|
cache_event = None
|
||||||
|
else:
|
||||||
|
cache_event = cache_events[i]
|
||||||
|
layer = self.h[i]
|
||||||
|
hidden_states = layer(
|
||||||
|
positions,
|
||||||
|
hidden_states,
|
||||||
|
kv_caches[i],
|
||||||
|
input_metadata,
|
||||||
|
cache_event,
|
||||||
|
)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class PhiForCausalLM(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
config: PretrainedConfig,
|
||||||
|
linear_method: Optional[LinearMethodBase] = None):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.linear_method = linear_method
|
||||||
|
|
||||||
|
self.transformer = PhiModel(config, linear_method)
|
||||||
|
self.lm_head = PhiCausalLMHead(config)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
kv_caches: List[KVCache],
|
||||||
|
input_metadata: InputMetadata,
|
||||||
|
cache_events: Optional[List[torch.cuda.Event]],
|
||||||
|
) -> SamplerOutput:
|
||||||
|
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||||
|
input_metadata, cache_events)
|
||||||
|
lm_logits = self.lm_head(hidden_states, input_metadata)
|
||||||
|
return lm_logits
|
||||||
|
|
||||||
|
def load_weights(self,
|
||||||
|
model_name_or_path: str,
|
||||||
|
cache_dir: Optional[str] = None,
|
||||||
|
load_format: str = "auto",
|
||||||
|
revision: Optional[str] = None):
|
||||||
|
params_dict = dict(self.named_parameters())
|
||||||
|
for name, loaded_weight in hf_model_weights_iterator(
|
||||||
|
model_name_or_path, cache_dir, load_format, revision):
|
||||||
|
if "rotary_emb.inv_freq" in name:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# pylint: disable=E1136
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = getattr(param, "weight_loader",
|
||||||
|
default_weight_loader)
|
||||||
|
weight_loader(param, loaded_weight)
|
||||||
Loading…
x
Reference in New Issue
Block a user