mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-16 00:16:36 +08:00
331 lines
12 KiB
Python
331 lines
12 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
# Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main
|
|
import math
|
|
from collections.abc import Iterable
|
|
from itertools import islice
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from transformers import MptConfig
|
|
|
|
from vllm.attention import Attention
|
|
from vllm.compilation.decorators import support_torch_compile
|
|
from vllm.config import CacheConfig, VllmConfig
|
|
from vllm.distributed import (
|
|
get_pp_group,
|
|
get_tensor_model_parallel_rank,
|
|
get_tensor_model_parallel_world_size,
|
|
)
|
|
from vllm.model_executor.layers.activation import get_act_fn
|
|
from vllm.model_executor.layers.linear import (
|
|
ColumnParallelLinear,
|
|
QKVParallelLinear,
|
|
RowParallelLinear,
|
|
)
|
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
|
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|
from vllm.sequence import IntermediateTensors
|
|
|
|
from .interfaces import SupportsPP
|
|
from .utils import (
|
|
AutoWeightsLoader,
|
|
is_pp_missing_parameter,
|
|
make_empty_intermediate_tensors_factory,
|
|
make_layers,
|
|
maybe_prefix,
|
|
)
|
|
|
|
|
|
def _get_alibi_slopes(
|
|
total_num_heads: int,
|
|
alibi_bias_max: int,
|
|
) -> torch.Tensor:
|
|
next_power_of_2 = 2 ** math.ceil(math.log2(total_num_heads))
|
|
m = torch.arange(1, next_power_of_2 + 1, dtype=torch.float32)
|
|
m = m.mul(alibi_bias_max / next_power_of_2)
|
|
slopes = 1.0 / torch.pow(2, m)
|
|
if next_power_of_2 != total_num_heads:
|
|
slopes = torch.concat([slopes[1::2], slopes[::2]])[:total_num_heads]
|
|
return slopes
|
|
|
|
|
|
class MPTAttention(nn.Module):
    """Self-attention layer for MPT.

    Q, K and V come from a single fused ``Wqkv`` projection. Optional
    clamping and optional Q/K LayerNorms are applied before attention.
    Positional information is supplied through ALiBi slopes passed to
    ``Attention``; ``position_ids`` are not used.
    """

    def __init__(
        self,
        config: MptConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.d_model = config.d_model
        self.total_num_heads = config.n_heads
        self.head_dim = self.d_model // self.total_num_heads
        self.clip_qkv = config.attn_config.clip_qkv
        self.qk_ln = config.attn_config.qk_ln
        self.alibi_bias_max = config.attn_config.alibi_bias_max
        # Use the config's KV head count when present (presumably for
        # grouped/multi-query variants); otherwise one KV head per query head.
        if "kv_n_heads" in config.attn_config:
            self.total_num_kv_heads = config.attn_config.kv_n_heads
        else:
            self.total_num_kv_heads = self.total_num_heads
        assert not config.attn_config.prefix_lm
        assert config.attn_config.alibi

        # pylint: disable=invalid-name
        self.Wqkv = QKVParallelLinear(
            self.d_model,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=not config.no_bias,
            quant_config=quant_config,
            # The prefix lets quantization configs match this layer by name.
            prefix=f"{prefix}.Wqkv",
        )
        if self.qk_ln:
            # NOTE: these norms use the unsharded Q/K widths, while forward()
            # applies them to this rank's slice. They therefore only line up
            # when the tensor-parallel size is 1.
            self.q_ln = nn.LayerNorm(self.d_model)
            # K carries one head_dim slot per KV head. This is narrower than
            # d_model whenever there are fewer KV heads than query heads.
            self.k_ln = nn.LayerNorm(self.total_num_kv_heads * self.head_dim)
        self.out_proj = RowParallelLinear(
            self.d_model,
            self.d_model,
            bias=not config.no_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )

        tp_world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size

        if self.total_num_kv_heads >= tp_world_size:
            # At least as many KV heads as TP ranks: partition the KV heads
            # across the tensor-parallel GPUs.
            assert self.total_num_kv_heads % tp_world_size == 0
        else:
            # Fewer KV heads than TP ranks: replicate the KV heads across
            # the tensor-parallel GPUs.
            assert tp_world_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

        # Each rank keeps only the ALiBi slopes of its own query heads.
        tp_rank = get_tensor_model_parallel_rank()
        head_start = tp_rank * self.num_heads
        head_end = head_start + self.num_heads
        alibi_slopes = _get_alibi_slopes(self.total_num_heads, self.alibi_bias_max)
        alibi_slopes = alibi_slopes[head_start:head_end].tolist()

        scaling = self.head_dim**-0.5
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            scaling,
            alibi_slopes=alibi_slopes,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Apply attention to ``hidden_states``.

        Args:
            position_ids: Ignored; ALiBi slopes provide positional bias.
            hidden_states: Input activations for this layer.

        Returns:
            The attention output after the ``out_proj`` projection.
        """
        del position_ids  # unused.
        qkv, _ = self.Wqkv(hidden_states)
        if self.clip_qkv is not None:
            # In-place clamp to [-clip_qkv, clip_qkv].
            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        if self.qk_ln:
            q = self.q_ln(q)
            k = self.k_ln(k)
        attn_output = self.attn(q, k, v)
        output, _ = self.out_proj(attn_output)
        return output
|
|
|
|
|
|
class MPTMLP(nn.Module):
    """MPT feed-forward network: up-projection, GELU, down-projection."""

    def __init__(
        self,
        config: MptConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        hidden_size = config.d_model
        expansion_ratio = config.expansion_ratio
        # int() keeps the layer width integral even if the config stores
        # expansion_ratio as a float (e.g. 4.0); it is a no-op for ints.
        intermediate_size = int(expansion_ratio * hidden_size)
        self.up_proj = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            bias=not config.no_bias,
            quant_config=quant_config,
            # The prefix lets quantization configs match this layer by name.
            prefix=f"{prefix}.up_proj",
        )
        self.act = get_act_fn("gelu")
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=not config.no_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``down_proj(gelu(up_proj(x)))``."""
        x, _ = self.up_proj(x)
        x = self.act(x)
        x, _ = self.down_proj(x)
        return x


class MPTBlock(nn.Module):
    """Pre-LayerNorm MPT decoder block.

    Computes ``h = x + attn(norm_1(x))`` followed by
    ``h + ffn(norm_2(h))``.
    """

    def __init__(
        self,
        config: MptConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        hidden_size = config.d_model
        self.norm_1 = nn.LayerNorm(hidden_size)
        self.attn = MPTAttention(
            config, cache_config, quant_config, prefix=f"{prefix}.attn"
        )
        self.norm_2 = nn.LayerNorm(hidden_size)
        self.ffn = MPTMLP(config, quant_config, prefix=f"{prefix}.ffn")

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention and FFN sub-layers, each with a residual add."""
        x = self.norm_1(hidden_states)
        x = self.attn(
            position_ids=position_ids,
            hidden_states=x,
        )
        hidden_states = hidden_states + x
        x = self.norm_2(hidden_states)
        x = self.ffn(x)
        hidden_states = hidden_states + x
        return hidden_states
|
|
|
|
|
|
@support_torch_compile
class MPTModel(nn.Module):
    """MPT decoder stack: token embedding, transformer blocks, final LayerNorm.

    With pipeline parallelism, this rank runs only the blocks in
    ``[start_layer, end_layer)``. Ranks other than the last return their
    hidden states wrapped in ``IntermediateTensors``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        hf_config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        assert hf_config.embedding_fraction == 1.0
        assert hf_config.norm_type == "low_precision_layernorm"

        self.wte = VocabParallelEmbedding(
            hf_config.vocab_size,
            hf_config.d_model,
        )
        # make_layers supplies each block's prefix as the `prefix` keyword.
        self.start_layer, self.end_layer, self.blocks = make_layers(
            hf_config.n_layers,
            lambda prefix: MPTBlock(
                hf_config, cache_config, quant_config, prefix=prefix
            ),
            prefix=f"{prefix}.blocks",
        )
        self.norm_f = nn.LayerNorm(hf_config.d_model)

        if hf_config.no_bias:
            # Unregister every remaining bias Parameter (e.g. on the
            # LayerNorms) so the model has no bias terms at all.
            for submodule in self.modules():
                if isinstance(getattr(submodule, "bias", None), nn.Parameter):
                    submodule.register_parameter("bias", None)

        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states"], hf_config.d_model
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.wte(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this rank's slice of the decoder.

        The first pipeline rank starts from ``inputs_embeds`` (if given) or
        from embedded ``input_ids``. Later ranks resume from the
        ``hidden_states`` entry of ``intermediate_tensors``. The last rank
        applies the final LayerNorm; other ranks return
        ``IntermediateTensors``.
        """
        pp_group = get_pp_group()
        if not pp_group.is_first_rank:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        elif inputs_embeds is None:
            hidden_states = self.get_input_embeddings(input_ids)
        else:
            hidden_states = inputs_embeds

        for block in islice(self.blocks, self.start_layer, self.end_layer):
            hidden_states = block(position_ids, hidden_states)

        if pp_group.is_last_rank:
            return self.norm_f(hidden_states)
        return IntermediateTensors({"hidden_states": hidden_states})

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Copy checkpoint tensors into matching parameters.

        Skips bias tensors this model does not define (e.g. extra biases in
        GPTQ checkpoints) and parameters owned by other pipeline ranks.

        Returns:
            The names of the parameters that were loaded.
        """
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            unexpected_bias = name.endswith(".bias") and name not in params_dict
            if unexpected_bias or is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            loader_fn = getattr(param, "weight_loader", default_weight_loader)
            loader_fn(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
|
|
|
|
|
|
class MPTForCausalLM(nn.Module, SupportsPP):
    """MPT language model: ``MPTModel`` plus a logits head.

    The logits head is the transformer's input embedding (``wte``), so the
    config is required to declare tied word embeddings.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        self.config = hf_config
        assert hf_config.tie_word_embeddings
        self.quant_config = vllm_config.quant_config

        self.transformer = MPTModel(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "transformer"),
        )
        # Tied output head: reuse the embedding module directly.
        self.lm_head = self.transformer.wte
        self.logits_processor = LogitsProcessor(hf_config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate embedding lookup to the underlying transformer."""
        return self.transformer.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the transformer and return its hidden states (or PP tensors)."""
        return self.transformer(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project hidden states to vocabulary logits via the tied head."""
        return self.logits_processor(self.lm_head, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights through ``AutoWeightsLoader``."""
        return AutoWeightsLoader(self).load_weights(weights)
|