mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-17 14:35:44 +08:00
398 lines
14 KiB
Python
398 lines
14 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
# Adapted from
|
|
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
|
|
# Copyright 2023 The vLLM team.
|
|
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
|
|
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""Inference-only GPT-2 model compatible with HuggingFace weights."""
|
|
|
|
from collections.abc import Iterable
|
|
from itertools import islice
|
|
|
|
import torch
|
|
from torch import nn
|
|
from transformers import GPT2Config
|
|
|
|
from vllm.attention import Attention
|
|
from vllm.compilation.decorators import support_torch_compile
|
|
from vllm.config import CacheConfig, VllmConfig
|
|
from vllm.distributed.parallel_state import (
|
|
get_pp_group,
|
|
get_tensor_model_parallel_world_size,
|
|
)
|
|
from vllm.model_executor.layers.activation import get_act_fn
|
|
from vllm.model_executor.layers.linear import (
|
|
ColumnParallelLinear,
|
|
QKVParallelLinear,
|
|
RowParallelLinear,
|
|
)
|
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
|
ParallelLMHead,
|
|
VocabParallelEmbedding,
|
|
)
|
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|
from vllm.sequence import IntermediateTensors
|
|
|
|
from ..layers.pooler import DispatchPooler, Pooler
|
|
from .interfaces import SupportsCrossEncoding, SupportsPP
|
|
from .utils import (
|
|
AutoWeightsLoader,
|
|
is_pp_missing_parameter,
|
|
make_empty_intermediate_tensors_factory,
|
|
make_layers,
|
|
maybe_prefix,
|
|
)
|
|
|
|
|
|
class GPT2Attention(nn.Module):
    """Multi-head self-attention used inside each GPT-2 block.

    The fused ``c_attn`` projection produces Q, K and V together, the vLLM
    ``Attention`` layer computes attention over the local heads of this
    tensor-parallel rank, and ``c_proj`` maps the result back to
    ``hidden_size``.
    """

    def __init__(
        self,
        config: GPT2Config,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        n_heads_total = config.num_attention_heads
        tp_size = get_tensor_model_parallel_world_size()
        # Heads must split evenly across tensor-parallel ranks.
        assert n_heads_total % tp_size == 0
        self.num_heads = n_heads_total // tp_size
        self.head_dim = self.hidden_size // n_heads_total
        # 1 / sqrt(head_dim) attention scaling.
        self.scale = self.head_dim**-0.5

        # Fused Q/K/V input projection; forward() splits its output into
        # three equal chunks.
        self.c_attn = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            n_heads_total,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_attn",
        )
        # Output projection back to the model width.
        self.c_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_proj",
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            scale=self.scale,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Apply self-attention to ``hidden_states`` and return the projected output."""
        fused_qkv, _ = self.c_attn(hidden_states)
        query, key, value = fused_qkv.chunk(3, dim=-1)
        context = self.attn(query, key, value)
        output, _ = self.c_proj(context)
        return output
|
|
|
|
|
|
class GPT2MLP(nn.Module):
    """Feed-forward sub-layer of a GPT-2 block.

    ``c_fc`` widens ``hidden_size`` to ``intermediate_size``, the activation
    named by ``config.activation_function`` is applied, and ``c_proj`` maps
    back down to ``hidden_size``.
    """

    def __init__(
        self,
        intermediate_size: int,
        config: GPT2Config,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        d_model = config.hidden_size
        # Column-parallel up-projection ...
        self.c_fc = ColumnParallelLinear(
            d_model,
            intermediate_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_fc",
        )
        # ... followed by a row-parallel down-projection.
        self.c_proj = RowParallelLinear(
            intermediate_size,
            d_model,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_proj",
        )
        self.act = get_act_fn(config.activation_function)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Run up-projection, activation, then down-projection."""
        expanded, _ = self.c_fc(hidden_states)
        projected, _ = self.c_proj(self.act(expanded))
        return projected
|
|
|
|
|
|
class GPT2Block(nn.Module):
    """A single pre-LayerNorm GPT-2 transformer layer.

    Computes ``x = x + attn(ln_1(x))`` and then ``x = x + mlp(ln_2(x))``.
    """

    def __init__(
        self,
        config: GPT2Config,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        hidden_size = config.hidden_size
        # The MLP width falls back to 4x the hidden size when n_inner is unset.
        if config.n_inner is None:
            mlp_width = 4 * hidden_size
        else:
            mlp_width = config.n_inner
        eps = config.layer_norm_epsilon

        self.ln_1 = nn.LayerNorm(hidden_size, eps=eps)
        self.attn = GPT2Attention(
            config, cache_config, quant_config, prefix=f"{prefix}.attn"
        )
        self.ln_2 = nn.LayerNorm(hidden_size, eps=eps)
        self.mlp = GPT2MLP(mlp_width, config, quant_config, prefix=f"{prefix}.mlp")

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Apply the attention and MLP sub-layers, each with a residual add."""
        # Attention sub-layer + residual.
        attn_out = self.attn(hidden_states=self.ln_1(hidden_states))
        hidden_states = attn_out + hidden_states

        # Feed-forward sub-layer + residual.
        mlp_out = self.mlp(self.ln_2(hidden_states))
        return hidden_states + mlp_out
|
|
|
|
|
|
@support_torch_compile
class GPT2Model(nn.Module):
    """GPT-2 decoder stack.

    Consists of token embeddings (``wte``), learned absolute position
    embeddings (``wpe``), a stack of ``GPT2Block`` layers (``h``) and a final
    LayerNorm (``ln_f``). Under pipeline parallelism, only the first rank
    embeds the inputs and only the last rank applies ``ln_f``. Intermediate
    ranks exchange ``hidden_states`` through ``IntermediateTensors``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        # HF GPT-2 config options that this implementation does not support.
        assert not config.add_cross_attention
        assert not config.scale_attn_by_inverse_layer_idx
        assert not config.reorder_and_upcast_attn
        self.embed_dim = config.hidden_size
        self.wte = VocabParallelEmbedding(
            config.vocab_size,
            self.embed_dim,
            quant_config=quant_config,
            # maybe_prefix avoids a leading "." when prefix is empty.
            prefix=maybe_prefix(prefix, "wte"),
        )
        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
        self.start_layer, self.end_layer, self.h = make_layers(
            config.num_hidden_layers,
            lambda prefix: GPT2Block(config, cache_config, quant_config, prefix=prefix),
            prefix=maybe_prefix(prefix, "h"),
        )
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states"], config.n_embd
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return token embeddings for ``input_ids``, without position embeddings."""
        return self.wte(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the layers assigned to this pipeline stage.

        Returns:
            On non-last PP ranks, an ``IntermediateTensors`` holding
            ``"hidden_states"``. On the last rank, the hidden states after
            the final LayerNorm.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is None:
                inputs_embeds = self.get_input_embeddings(input_ids)
            # Add learned absolute position embeddings to the token embeddings.
            position_embeds = self.wpe(position_ids)
            hidden_states = inputs_embeds + position_embeds
        else:
            # Later stages continue from the previous rank's activations.
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]

        # Only the [start_layer, end_layer) slice returned by make_layers runs here.
        for layer in islice(self.h, self.start_layer, self.end_layer):
            hidden_states = layer(hidden_states)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        hidden_states = self.ln_f(hidden_states)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load HF GPT-2 checkpoint tensors into this module.

        Args:
            weights: ``(name, tensor)`` pairs whose names are relative to this
                module (for example ``h.0.attn.c_attn.weight``).

        Returns:
            The set of parameter names that were loaded.

        Raises:
            KeyError: If a checkpoint name has no matching parameter and is
                not otherwise skipped.
        """
        conv1d_names = ("c_attn", "c_proj", "c_fc")
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if ".attn.bias" in name or ".attn.masked_bias" in name:
                # Skip attention mask.
                # NOTE: "c_attn.bias" should not be skipped.
                continue

            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            # The HF's GPT-2 implementation uses Conv1D instead of Linear.
            # Because of this, we need to transpose the weights (exactly
            # once, even if the name matches more than one Conv1D entry).
            # Note(zhuohan): the logic below might break quantized models.
            if name.endswith(".weight") and any(
                conv1d_name in name for conv1d_name in conv1d_names
            ):
                loaded_weight = loaded_weight.t()
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
|
|
|
|
|
|
class GPT2LMHeadModel(nn.Module, SupportsPP):
    """GPT-2 causal language model.

    Combines a ``GPT2Model`` backbone (``self.transformer``) with a
    vocab-parallel ``lm_head``. ``forward`` returns hidden states;
    ``compute_logits`` turns them into vocabulary logits.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        self.transformer = GPT2Model(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "transformer")
        )
        self.lm_head = ParallelLMHead(
            self.config.vocab_size,
            self.config.hidden_size,
            quant_config=quant_config,
            # Same prefix helper as the backbone, so a top-level model gets
            # "lm_head" rather than ".lm_head".
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if self.config.tie_word_embeddings:
            # Share the output projection with the input token embedding.
            self.lm_head = self.lm_head.tie_weights(self.transformer.wte)

        self.logits_processor = LogitsProcessor(config.vocab_size)
        # Delegated to the backbone; presumably consumed by the
        # pipeline-parallel runner (SupportsPP).
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the backbone's token embeddings for ``input_ids``."""
        return self.transformer.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the backbone and return its hidden states (or PP intermediates)."""
        hidden_states = self.transformer(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project hidden states to vocabulary logits via ``lm_head``."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights and return the names of loaded parameters.

        Names that do not already start with ``"transformer."`` or
        ``"lm_head"`` are re-rooted under ``"transformer."`` so bare backbone
        names (e.g. ``h.0...``) are routed to ``self.transformer``.
        """
        loader = AutoWeightsLoader(self)
        weights = _add_transformer_prefix(weights)
        return loader.load_weights(weights)
|
|
|
|
|
|
class GPT2ForSequenceClassification(nn.Module, SupportsCrossEncoding):
    """GPT2 Model for sequence classification.

    Extends ``GPT2Model`` with a linear ``score`` head and a ``DispatchPooler``
    that serves the "token_classify", "classify" and "score" tasks. The head
    is intended to classify from the last token; the pooling behavior itself
    is configured through ``pooler_config``.

    Attributes:
        transformer: The ``GPT2Model`` backbone used for forward passes.
        score: Bias-free linear layer mapping hidden states to
            ``config.num_labels`` logits.
        pooler: ``DispatchPooler`` that applies ``score`` for each supported
            pooling task.
    """

    is_pooling_model = True

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        # The prefix matches the attribute name ("transformer"), as in
        # GPT2LMHeadModel, so layer and quantization names agree with the
        # module/checkpoint paths.
        self.transformer = GPT2Model(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "transformer")
        )
        self.score = nn.Linear(
            config.n_embd,
            config.num_labels,
            bias=False,
            dtype=vllm_config.model_config.head_dtype,
        )

        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None

        self.pooler = DispatchPooler(
            {
                "token_classify": Pooler.for_token_classify(
                    pooler_config, classifier=self.score
                ),
                "classify": Pooler.for_classify(
                    pooler_config, classifier=self.score, act_fn="classify"
                ),
                "score": Pooler.for_classify(
                    pooler_config, classifier=self.score, act_fn="score"
                ),
            }
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the backbone's token embeddings for ``input_ids``."""
        return self.transformer.get_input_embeddings(input_ids)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights and return the names of loaded parameters."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Return backbone hidden states.

        Pooling and scoring are not done here; presumably the runner applies
        them through ``self.pooler``.
        """
        hidden_states = self.transformer(
            input_ids=input_ids,
            position_ids=positions,
            inputs_embeds=inputs_embeds,
            intermediate_tensors=intermediate_tensors,
        )
        return hidden_states
|
|
|
|
|
|
def _add_transformer_prefix(
|
|
weights: Iterable[tuple[str, torch.Tensor]],
|
|
) -> Iterable[tuple[str, torch.Tensor]]:
|
|
for name, tensor in weights:
|
|
if not name.startswith("transformer.") and not name.startswith("lm_head"):
|
|
name = "transformer." + name
|
|
yield name, tensor
|