mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-15 01:35:01 +08:00
414 lines
14 KiB
Python
414 lines
14 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
# Adapted from
|
|
# https://github.com/huggingface/transformers/blob/v4.40.1/src/transformers/models/olmo/modeling_olmo.py
|
|
# Copyright 2024 The vLLM team.
|
|
# Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
|
# and OPT implementations in this library. It has been modified from its
|
|
# original forms to accommodate minor architectural differences compared
|
|
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""Inference-only OLMo model compatible with HuggingFace weights."""
|
|
|
|
from collections.abc import Iterable
|
|
from itertools import islice
|
|
|
|
import torch
|
|
from torch import nn
|
|
from transformers import OlmoConfig
|
|
|
|
from vllm.attention.layer import Attention
|
|
from vllm.compilation.decorators import support_torch_compile
|
|
from vllm.config import CacheConfig, VllmConfig
|
|
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
|
from vllm.model_executor.layers.activation import SiluAndMul
|
|
from vllm.model_executor.layers.linear import (
|
|
MergedColumnParallelLinear,
|
|
QKVParallelLinear,
|
|
RowParallelLinear,
|
|
)
|
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
|
ParallelLMHead,
|
|
VocabParallelEmbedding,
|
|
)
|
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|
from vllm.sequence import IntermediateTensors
|
|
|
|
from .interfaces import SupportsLoRA, SupportsPP
|
|
from .utils import (
|
|
AutoWeightsLoader,
|
|
is_pp_missing_parameter,
|
|
make_empty_intermediate_tensors_factory,
|
|
make_layers,
|
|
maybe_prefix,
|
|
)
|
|
|
|
|
|
class OlmoAttention(nn.Module):
    """
    Self-attention sub-block of an OLMo decoder layer.

    Produces the `Attention(LN(x))` term in `MLP(LN(x + Attention(LN(x))))`
    (plus another skip connection). Heads are split evenly across
    tensor-parallel ranks. Query, key and value all use the full head count,
    so the fused QKV projection output is cut into three equal chunks.
    """

    def __init__(
        self,
        config: OlmoConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config

        hidden_size = config.hidden_size
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        total_heads = config.num_attention_heads
        self.total_num_heads = total_heads

        # Heads must tile the hidden dimension exactly and divide evenly
        # across tensor-parallel ranks.
        assert hidden_size % total_heads == 0
        assert total_heads % tp_size == 0

        head_dim = hidden_size // total_heads
        self.num_heads = total_heads // tp_size
        self.head_dim = head_dim
        self.max_position_embeddings = config.max_position_embeddings
        # Optional symmetric clamp bound applied to the fused QKV output.
        self.clip_qkv = config.clip_qkv

        # Fused input projection: x -> concatenated (q, k, v).
        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            head_dim,
            total_heads,
            bias=config.attention_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        # Rotary position embedding over the full head dimension.
        self.rotary_emb = get_rope(
            head_dim,
            rotary_dim=head_dim,
            max_position=self.max_position_embeddings,
            rope_parameters=config.rope_parameters,
        )

        self.scaling = self.head_dim**-0.5
        self.attn = Attention(
            self.num_heads,
            head_dim,
            scale=self.scaling,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

        # Output projection back to the model width.
        self.o_proj = RowParallelLinear(
            hidden_size,
            hidden_size,
            bias=config.attention_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project, (optionally) clip, rotate, attend and project back."""
        fused_qkv, _ = self.qkv_proj(hidden_states)
        bound = self.clip_qkv
        if bound is not None:
            # In-place clamp of the fused projection before it is split.
            fused_qkv.clamp_(min=-bound, max=bound)
        query, key, value = fused_qkv.chunk(chunks=3, dim=-1)
        query, key = self.rotary_emb(positions, query, key)
        context = self.attn(query, key, value)
        projected, _ = self.o_proj(context)
        return projected
|
|
|
|
|
|
class OlmoMLP(nn.Module):
    """
    Gated feed-forward sub-block of an OLMo decoder layer.

    Produces the `MLP(LN(x))` term in `MLP(LN(x + Attention(LN(x))))`
    (plus another skip connection). The gate and up projections are computed
    by one merged column-parallel linear. ``SiluAndMul`` then combines the two
    halves (presumably SiLU(gate) * up; see its definition), and a
    row-parallel linear maps the result back to the model width.
    """

    def __init__(
        self,
        config: OlmoConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        width = config.hidden_size
        inner = config.intermediate_size
        self.hidden_size = width
        self.intermediate_size = inner

        # Merged gate + up projection (two output blocks of size `inner`).
        self.gate_up_proj = MergedColumnParallelLinear(
            width,
            [inner] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )

        # Gating activation applied to the merged projection output.
        self.act_fn = SiluAndMul()

        # Projection from the intermediate size back to the model width.
        self.down_proj = RowParallelLinear(
            inner,
            width,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Apply gate/up projection, gating activation, then down projection."""
        merged, _ = self.gate_up_proj(x)
        gated = self.act_fn(merged)
        result, _ = self.down_proj(gated)
        return result
|
|
|
|
|
|
class OlmoDecoderLayer(nn.Module):
    """
    This is a typical transformer block where the output is
    computed as `MLP(LN(x + Attention(LN(x))))`
    (plus another skip connection).

    Both LayerNorms are constructed with ``elementwise_affine=False`` and
    ``bias=False``, so they carry no learnable parameters.
    """

    def __init__(
        self,
        config: OlmoConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        # Attention block.
        self.self_attn = OlmoAttention(
            config, cache_config, quant_config, prefix=f"{prefix}.self_attn"
        )

        # MLP block.
        self.mlp = OlmoMLP(config, quant_config, prefix=f"{prefix}.mlp")

        # Parameter-free LayerNorms applied before each sub-block.
        self.input_layernorm = nn.LayerNorm(
            config.hidden_size, elementwise_affine=False, bias=False
        )
        self.post_attention_layernorm = nn.LayerNorm(
            config.hidden_size, elementwise_affine=False, bias=False
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Apply the pre-norm attention and MLP sub-blocks with residuals.

        :param positions: Token positions, forwarded to the attention block
            (used there by the rotary embedding).
        :param hidden_states: Input activations for this layer.
        :returns: The updated hidden states (a single tensor; the previous
            tuple annotation did not match what this method returns).
        """
        # Attention block: x + Attention(LN(x)).
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(positions, hidden_states)
        hidden_states = residual + hidden_states

        # MLP block: h + MLP(LN(h)).
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states
|
|
|
|
|
|
@support_torch_compile
class OlmoModel(nn.Module):
    """OLMo decoder stack: token embedding, decoder layers and a final norm.

    Pipeline-parallel aware:
    - Only the first PP rank embeds tokens.
    - Each rank runs the layers in ``[start_layer, end_layer)`` as reported
      by ``make_layers``.
    - Every rank except the last returns its activations as
      ``IntermediateTensors``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size, config.hidden_size
        )
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: OlmoDecoderLayer(
                config, cache_config, quant_config, prefix=prefix
            ),
            prefix=f"{prefix}.layers",
        )
        # Final parameter-free LayerNorm (no affine weight, no bias).
        self.norm = nn.LayerNorm(
            config.hidden_size, elementwise_affine=False, bias=False
        )
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states"], config.hidden_size
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this rank's slice of the decoder stack.

        :param input_ids: Token ids to embed. Only used on the first PP rank,
            and only when ``inputs_embeds`` is not given.
        :param positions: Token positions, forwarded to every decoder layer.
        :param intermediate_tensors: Activations from the previous PP rank.
            Required on every rank except the first.
        :param inputs_embeds: Optional precomputed embeddings. When given,
            they replace ``embed_tokens`` on the first PP rank.
        :returns: Final-normed hidden states on the last PP rank; otherwise
            ``IntermediateTensors`` carrying ``"hidden_states"``.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]

        # Apply this rank's blocks one-by-one.
        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states = layer(positions, hidden_states)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        # Apply final layer norm.
        hidden_states = self.norm(hidden_states)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Separate ``q/k/v_proj`` and ``gate/up_proj`` checkpoint tensors are
        routed into the fused ``qkv_proj`` / ``gate_up_proj`` parameters via
        their loader's ``shard_id``. All other tensors use the parameter's
        ``weight_loader``, falling back to ``default_weight_loader``.

        :param weights: Iterable of ``(checkpoint_name, tensor)`` pairs.
        :returns: The set of (rewritten) parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            # Resolve the destination parameter exactly once and stop at the
            # first match. A rewritten name can contain a later shard name
            # (``qkv_proj`` contains ``v_proj``, ``gate_up_proj`` contains
            # ``up_proj``), so re-scanning after a rewrite would corrupt it.
            shard_id = None
            for param_name, weight_name, stacked_shard_id in stacked_params_mapping:
                if weight_name in name:
                    name = name.replace(weight_name, param_name)
                    shard_id = stacked_shard_id
                    break

            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            # Parameters owned by another pipeline-parallel rank.
            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            # `is None` (not truthiness): gate_proj's shard id is 0.
            if shard_id is None:
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            else:
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
            loaded_params.add(name)
        return loaded_params
|
|
|
|
|
|
class OlmoForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
    """
    Causal-LM wrapper around ``OlmoModel``.

    Adds an LM head and a logits processor on top of the backbone. With
    ``tie_word_embeddings`` the LM head is the backbone's token embedding
    module itself, and the checkpoint's ``lm_head.weight`` is skipped at
    load time.
    """

    # Fused parameter name -> the checkpoint sub-modules packed into it.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = hf_config

        self.model = OlmoModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        # Share the input embedding when tied; otherwise build a separate head.
        self.lm_head = (
            self.model.embed_tokens
            if hf_config.tie_word_embeddings
            else ParallelLMHead(
                hf_config.vocab_size,
                hf_config.hidden_size,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
        )
        self.logits_processor = LogitsProcessor(hf_config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate token-embedding lookup to the backbone."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the backbone and return its output unchanged."""
        return self.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Turn hidden states into logits via the LM head."""
        return self.logits_processor(self.lm_head, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load a checkpoint through ``AutoWeightsLoader``.

        ``lm_head.weight`` is skipped when the head is tied to the embedding.
        """
        tied = self.config.tie_word_embeddings
        skipped = ["lm_head.weight"] if tied else None
        loader = AutoWeightsLoader(self, skip_prefixes=skipped)
        return loader.load_weights(weights)
|