mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 07:15:01 +08:00
Add support for GPT-2 (#60)
This commit is contained in:
parent
130d5fd8c7
commit
e548c1488a
265
cacheflow/models/gpt2.py
Normal file
265
cacheflow/models/gpt2.py
Normal file
@ -0,0 +1,265 @@
|
||||
"""1D GPT-2 model compatible with HuggingFace weights."""
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import GPT2Config
|
||||
|
||||
from cacheflow.models import InputMetadata
|
||||
from cacheflow.models.attention import GPTCacheFlowAttention
|
||||
from cacheflow.models.sample import Sampler
|
||||
from cacheflow.models.utils import (hf_model_weights_iterator,
|
||||
load_tensor_parallel_weights)
|
||||
from cacheflow.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from cacheflow.parallel_utils.tensor_parallel import (VocabParallelEmbedding,
|
||||
ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
from cacheflow.sequence import SequenceOutputs
|
||||
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
|
||||
class GPT2Attention(nn.Module):
|
||||
|
||||
def __init__(self, config: GPT2Config):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
total_num_heads = config.num_attention_heads
|
||||
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
|
||||
assert total_num_heads % tensor_model_parallel_world_size == 0
|
||||
self.num_heads = total_num_heads // tensor_model_parallel_world_size
|
||||
self.head_dim = self.hidden_size // total_num_heads
|
||||
self.scale = self.head_dim ** -0.5
|
||||
|
||||
self.c_attn = ColumnParallelLinear(self.hidden_size, 3 * self.hidden_size, bias=True,
|
||||
gather_output=False,
|
||||
perform_initialization=False)
|
||||
self.c_proj = RowParallelLinear(self.hidden_size, self.hidden_size, bias=True,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False)
|
||||
self.attn = GPTCacheFlowAttention(scale=self.scale)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.c_attn(hidden_states)
|
||||
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
||||
key_cache, value_cache = kv_cache
|
||||
attn_output = self.attn(
|
||||
q, k, v, key_cache, value_cache, input_metadata, cache_event)
|
||||
attn_output, _ = self.c_proj(attn_output)
|
||||
return attn_output
|
||||
|
||||
|
||||
class GPT2MLP(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
intermediate_size: int,
|
||||
config: GPT2Config,
|
||||
):
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
self.c_fc = ColumnParallelLinear(hidden_size, intermediate_size,
|
||||
bias=True, gather_output=False,
|
||||
perform_initialization=False)
|
||||
self.c_proj = RowParallelLinear(intermediate_size, hidden_size,
|
||||
bias=True, input_is_parallel=True,
|
||||
perform_initialization=False)
|
||||
|
||||
act_fn = config.activation_function
|
||||
if act_fn != "gelu_new":
|
||||
raise ValueError(f"Unsupported activation: {act_fn}. "
|
||||
"GPT-2 only supports gelu_new for now.")
|
||||
self.act = torch.nn.GELU(approximate="tanh")
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states, _ = self.c_fc(hidden_states)
|
||||
hidden_states = self.act(hidden_states)
|
||||
hidden_states, _ = self.c_proj(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class GPT2Block(nn.Module):
|
||||
|
||||
def __init__(self, config: GPT2Config):
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
|
||||
|
||||
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||
self.attn = GPT2Attention(config)
|
||||
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||
self.mlp = GPT2MLP(inner_dim, config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
hidden_states = self.ln_1(hidden_states)
|
||||
attn_output = self.attn(
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
input_metadata=input_metadata,
|
||||
cache_event=cache_event,
|
||||
)
|
||||
# residual connection
|
||||
hidden_states = attn_output + residual
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.ln_2(hidden_states)
|
||||
feed_forward_hidden_states = self.mlp(hidden_states)
|
||||
# residual connection
|
||||
hidden_states = residual + feed_forward_hidden_states
|
||||
return hidden_states
|
||||
|
||||
|
||||
class GPT2Model(nn.Module):
|
||||
|
||||
def __init__(self, config: GPT2Config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
assert config.add_cross_attention == False
|
||||
assert config.scale_attn_by_inverse_layer_idx == False
|
||||
assert config.reorder_and_upcast_attn == False
|
||||
self.embed_dim = config.hidden_size
|
||||
|
||||
# Optimization: While the vocab size of GPT-2 is 50257, we extend it
|
||||
# to 50304 in order to make it divisible by 64.
|
||||
# This improves performance since GPUs are faster if the dimension
|
||||
# is divisible by 64. In addition, it allows us to shard the embedding
|
||||
# layer across 2, 4, 8, or more GPUs.
|
||||
vocab_size = ((config.vocab_size + 63) // 64) * 64
|
||||
self.wte = VocabParallelEmbedding(vocab_size, self.embed_dim)
|
||||
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
|
||||
self.h = nn.ModuleList(
|
||||
[GPT2Block(config) for _ in range(config.num_hidden_layers)])
|
||||
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
position_ids: torch.LongTensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.wte(input_ids)
|
||||
position_embeds = self.wpe(position_ids)
|
||||
hidden_states = inputs_embeds + position_embeds
|
||||
|
||||
for i in range(len(self.h)):
|
||||
if cache_events is None:
|
||||
cache_event = None
|
||||
else:
|
||||
cache_event = cache_events[i]
|
||||
layer = self.h[i]
|
||||
hidden_states = layer(
|
||||
hidden_states, kv_caches[i], input_metadata, cache_event)
|
||||
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class GPT2LMHeadModel(nn.Module):
|
||||
|
||||
def __init__(self, config: GPT2Config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.transformer = GPT2Model(config)
|
||||
# TODO(zhuohan): create a new weight after implementing pipeline
|
||||
# parallelism
|
||||
self.lm_head_weight = self.transformer.wte.weight
|
||||
self.sampler = Sampler(config.vocab_size)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
positions: torch.LongTensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> Dict[int, SequenceOutputs]:
|
||||
hidden_states = self.transformer(
|
||||
input_ids, positions, kv_caches, input_metadata, cache_events)
|
||||
next_tokens = self.sampler(
|
||||
self.lm_head_weight, hidden_states, input_metadata)
|
||||
return next_tokens
|
||||
|
||||
_column_parallel_weights = ["wte.weight", "c_fc.weight", "c_fc.bias"]
|
||||
_row_parallel_weights = ["c_proj.weight"]
|
||||
|
||||
def load_weights(self, model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
use_np_cache: bool = False):
|
||||
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
|
||||
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||
state_dict = self.state_dict()
|
||||
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, use_np_cache):
|
||||
if "lm_head.weight" in name:
|
||||
# GPT-2 ties the weights of the embedding layer and the final
|
||||
# linear layer.
|
||||
continue
|
||||
if ".attn.bias" in name:
|
||||
# Skip attention mask.
|
||||
# NOTE: "c_attn.bias" should not be skipped.
|
||||
continue
|
||||
name = "transformer." + name
|
||||
|
||||
# The HF's GPT-2 implementation uses Conv1D instead of Linear.
|
||||
# Because of this, we need to transpose the weights.
|
||||
for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
|
||||
if conv1d_weight_name not in name:
|
||||
continue
|
||||
if not name.endswith(".weight"):
|
||||
continue
|
||||
loaded_weight = loaded_weight.t()
|
||||
param = state_dict[name]
|
||||
|
||||
if name == "transformer.wte.weight":
|
||||
# Consider padding in the vocab size.
|
||||
padded_vocab_size = param.shape[0] * tensor_model_parallel_world_size
|
||||
num_extra_rows = padded_vocab_size - self.config.vocab_size
|
||||
extra_rows = torch.empty(num_extra_rows, loaded_weight.shape[1])
|
||||
extra_rows = extra_rows.to(loaded_weight)
|
||||
loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
|
||||
|
||||
# For the fused QKV linear layer, manually shard the weights.
|
||||
if "c_attn" in name:
|
||||
# GPT-2's fused QKV has the shape of [3 * num_heads * head_size, hidden_size].
|
||||
# When tensor parallelism is used, we shard the weights along the head dimension.
|
||||
total_num_heads = self.config.num_attention_heads
|
||||
hidden_size = self.config.hidden_size
|
||||
head_size = hidden_size // total_num_heads
|
||||
num_heads = total_num_heads // tensor_model_parallel_world_size
|
||||
head_start = tensor_model_parallel_rank * num_heads
|
||||
head_end = (tensor_model_parallel_rank + 1) * num_heads
|
||||
|
||||
if name.endswith(".weight"):
|
||||
loaded_weight = loaded_weight.view(3, total_num_heads, head_size, hidden_size)
|
||||
loaded_weight = loaded_weight[:, head_start:head_end, :, :]
|
||||
loaded_weight = loaded_weight.reshape(-1, hidden_size)
|
||||
elif name.endswith(".bias"):
|
||||
loaded_weight = loaded_weight.view(3, total_num_heads, head_size)
|
||||
loaded_weight = loaded_weight[:, head_start:head_end, :]
|
||||
loaded_weight = loaded_weight.reshape(-1)
|
||||
else:
|
||||
raise ValueError(f"Unexpected parameter name {name}")
|
||||
load_tensor_parallel_weights(param, loaded_weight, name,
|
||||
self._column_parallel_weights,
|
||||
self._row_parallel_weights)
|
||||
|
||||
def initialize_dummy_weights(self) -> None:
|
||||
for param in self.state_dict().values():
|
||||
param.data.uniform_(-1e-3, 1e-3)
|
||||
@ -173,7 +173,7 @@ class GPTNeoXForCausalLM(nn.Module):
|
||||
self.embed_out = ColumnParallelLinear(config.hidden_size, config.vocab_size,
|
||||
bias=False, gather_output=False,
|
||||
perform_initialization=False)
|
||||
self.sampler = Sampler()
|
||||
self.sampler = Sampler(config.vocab_size)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -205,8 +205,8 @@ class GPTNeoXForCausalLM(nn.Module):
|
||||
param = state_dict[name]
|
||||
if "query_key_value" in name:
|
||||
# NOTE(woosuk): GPT-NeoX's fused QKV has the shape of
|
||||
# [num_heads * 3 * head_size, num_heads * head_size], while the
|
||||
# required shape is [3 * num_heads * head_size, num_heads * head_size].
|
||||
# [num_heads * 3 * head_size, hidden_size], while the
|
||||
# required shape is [3 * num_heads * head_size, hidden_size].
|
||||
# Thus, we need weight conversion.
|
||||
shard_size = param.shape[0]
|
||||
loaded_weight = loaded_weight[shard_size * tensor_model_parallel_rank
|
||||
@ -218,11 +218,11 @@ class GPTNeoXForCausalLM(nn.Module):
|
||||
if 'query_key_value.weight' in name:
|
||||
loaded_weight = loaded_weight.view(-1, 3, head_size, hidden_size)
|
||||
loaded_weight = loaded_weight.transpose(0, 1)
|
||||
loaded_weight = loaded_weight.reshape(-1, hidden_size).contiguous()
|
||||
loaded_weight = loaded_weight.reshape(-1, hidden_size)
|
||||
elif 'query_key_value.bias' in name:
|
||||
loaded_weight = loaded_weight.view(-1, 3, head_size)
|
||||
loaded_weight = loaded_weight.transpose(0, 1)
|
||||
loaded_weight = loaded_weight.reshape(-1).contiguous()
|
||||
loaded_weight = loaded_weight.reshape(-1)
|
||||
else:
|
||||
raise ValueError(f"Unexpected weight name: {name}")
|
||||
load_tensor_parallel_weights(param, loaded_weight, name,
|
||||
|
||||
@ -192,7 +192,7 @@ class LlamaForCausalLM(nn.Module):
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
perform_initialization=False)
|
||||
self.sampler = Sampler()
|
||||
self.sampler = Sampler(config.vocab_size)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -72,6 +72,76 @@ class CacheFlowMemoryAnalyzer:
|
||||
return max_num_blocks
|
||||
|
||||
|
||||
class GPT2MemoryAnalyzer(CacheFlowMemoryAnalyzer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
block_size: int,
|
||||
dtype: torch.dtype,
|
||||
gpu_memory: int,
|
||||
cpu_memory: int,
|
||||
tensor_parallel_size: int,
|
||||
) -> None:
|
||||
self.model_name = model_name
|
||||
self.block_size = block_size
|
||||
self.dtype = dtype
|
||||
self.gpu_memory = gpu_memory
|
||||
self.cpu_memory = cpu_memory
|
||||
self.tensor_parallel_size = tensor_parallel_size
|
||||
|
||||
config = AutoConfig.from_pretrained(model_name)
|
||||
self.num_layers = config.num_hidden_layers
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_size = config.hidden_size // self.num_heads
|
||||
self.ffn_size = config.n_inner if config.n_inner is not None else 4 * self.hidden_size
|
||||
self.vocab_size = config.vocab_size
|
||||
self.max_position = config.max_position_embeddings
|
||||
|
||||
def get_param_size(self) -> int:
|
||||
word_embedding = self.vocab_size * self.hidden_size // self.tensor_parallel_size
|
||||
position_embedding = self.max_position * self.hidden_size
|
||||
|
||||
ln1 = 2 * self.hidden_size
|
||||
q = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
|
||||
k = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
|
||||
v = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
|
||||
out = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
|
||||
mha = ln1 + q + k + v + out
|
||||
|
||||
ln2 = 2 * self.hidden_size
|
||||
ffn1 = self.hidden_size * self.ffn_size // self.tensor_parallel_size + self.ffn_size
|
||||
ffn2 = self.ffn_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
|
||||
ffn = ln2 + ffn1 + ffn2
|
||||
|
||||
total = (word_embedding + position_embedding +
|
||||
self.num_layers * (mha + ffn))
|
||||
dtype_size = get_dtype_size(self.dtype)
|
||||
return dtype_size * total
|
||||
|
||||
def get_max_act_size(
|
||||
self,
|
||||
max_num_batched_tokens: int,
|
||||
) -> int:
|
||||
# NOTE: We approxmiately calculate the maximum activation size by
|
||||
# estimating
|
||||
# 1) the maximum activation tensor size during inference
|
||||
# 2) the residual tensor size during inference
|
||||
# Here, we assume that FlashAttention is used and
|
||||
# thus the attention maps are never materialized in GPU DRAM.
|
||||
residual = max_num_batched_tokens * self.hidden_size
|
||||
qkv = 3 * (max_num_batched_tokens * self.hidden_size) // self.tensor_parallel_size
|
||||
ffn = max_num_batched_tokens * self.ffn_size // self.tensor_parallel_size
|
||||
# Double the activation size for input and output.
|
||||
max_act = 2 * (max(qkv, ffn) + residual)
|
||||
# Size of output logits.
|
||||
output_logits = 2 * (max_num_batched_tokens * self.vocab_size)
|
||||
max_act = max(max_act, output_logits)
|
||||
dtype_size = get_dtype_size(self.dtype)
|
||||
return dtype_size * max_act
|
||||
|
||||
|
||||
class OPTMemoryAnalyzer(CacheFlowMemoryAnalyzer):
|
||||
|
||||
def __init__(
|
||||
|
||||
@ -5,9 +5,11 @@ import torch.nn as nn
|
||||
from transformers import AutoConfig
|
||||
|
||||
from cacheflow.models.memory_analyzer import CacheFlowMemoryAnalyzer
|
||||
from cacheflow.models.memory_analyzer import GPT2MemoryAnalyzer
|
||||
from cacheflow.models.memory_analyzer import GPTNeoXMemoryAnalyzer
|
||||
from cacheflow.models.memory_analyzer import LlamaMemoryAnalyzer
|
||||
from cacheflow.models.memory_analyzer import OPTMemoryAnalyzer
|
||||
from cacheflow.models.gpt2 import GPT2LMHeadModel
|
||||
from cacheflow.models.gpt_neox import GPTNeoXForCausalLM
|
||||
from cacheflow.models.llama import LlamaForCausalLM
|
||||
from cacheflow.models.opt import OPTForCausalLM
|
||||
@ -15,6 +17,7 @@ from cacheflow.models.utils import get_torch_dtype
|
||||
|
||||
|
||||
_MODELS = {
|
||||
'gpt2': GPT2LMHeadModel,
|
||||
'llama': LlamaForCausalLM,
|
||||
'opt': OPTForCausalLM,
|
||||
'stablelm': GPTNeoXForCausalLM,
|
||||
@ -22,6 +25,7 @@ _MODELS = {
|
||||
}
|
||||
|
||||
_MEMORY_ANALYZERS = {
|
||||
'gpt2': GPT2MemoryAnalyzer,
|
||||
'llama': LlamaMemoryAnalyzer,
|
||||
'opt': OPTMemoryAnalyzer,
|
||||
'stablelm': GPTNeoXMemoryAnalyzer,
|
||||
|
||||
@ -234,7 +234,7 @@ class OPTForCausalLM(nn.Module):
|
||||
# TODO(zhuohan): create a new weight after implementing pipeline
|
||||
# parallelism
|
||||
self.lm_head_weight = self.model.decoder.embed_tokens.weight
|
||||
self.sampler = Sampler()
|
||||
self.sampler = Sampler(config.vocab_size)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -11,8 +11,9 @@ from cacheflow.parallel_utils.tensor_parallel import gather_from_tensor_model_pa
|
||||
|
||||
class Sampler(nn.Module):
|
||||
|
||||
def __init__(self) -> None:
|
||||
def __init__(self, vocab_size: int) -> None:
|
||||
super().__init__()
|
||||
self.vocab_size = vocab_size
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -26,6 +27,8 @@ class Sampler(nn.Module):
|
||||
# Get the logits for the next tokens.
|
||||
logits = torch.matmul(hidden_states, embedding.t())
|
||||
logits = gather_from_tensor_model_parallel_region(logits)
|
||||
# Remove paddings in vocab.
|
||||
logits = logits[:, :self.vocab_size]
|
||||
|
||||
# Apply temperature scaling.
|
||||
temperatures = _get_temperatures(input_metadata)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user