Implement LLaMA (#9)

Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
Woosuk Kwon 2023-03-29 21:25:32 -07:00 committed by GitHub
parent a1b3de86cd
commit 80a2f812f1
7 changed files with 500 additions and 35 deletions


@@ -3,8 +3,10 @@
## Installation
```bash
pip install psutil numpy torch transformers
pip install flash-attn # This may take up to 10 mins.
pip install psutil numpy ray torch
pip install git+https://github.com/huggingface/transformers # Required for LLaMA.
pip install sentencepiece # Required for LlamaTokenizer.
pip install flash-attn # This may take up to 20 mins.
pip install -e .
```
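If the environment is set up correctly, the LLaMA-specific pieces can be smoke-tested from Python. This is a minimal check, assuming the development build of `transformers` installed above ships `LlamaConfig`/`LlamaTokenizer`:

```python
# Minimal sanity check: these imports only succeed with the development build
# of transformers (LLaMA support) and with sentencepiece installed.
from transformers import LlamaConfig, LlamaTokenizer  # noqa: F401
import sentencepiece  # noqa: F401

print("LLaMA config type:", LlamaConfig.model_type)  # expected: "llama"
```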


@@ -61,4 +61,5 @@ class SimpleFrontend:
for seq in seq_group.seqs:
token_ids = seq.get_token_ids()
output = self.tokenizer.decode(token_ids, skip_special_tokens=True)
output = output.strip()
print(f'Seq {seq.seq_id}: {output!r}')

cacheflow/models/llama.py Normal file

@@ -0,0 +1,357 @@
"""1D LLaMA model compatible with HuggingFace weights."""
import os
import glob
import filelock
from tqdm import tqdm
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from transformers import LlamaConfig
from transformers import PreTrainedModel
from cacheflow.models import InputMetadata
from cacheflow.models.attention import OPTCacheFlowAttention
from cacheflow.models.sample import Sampler
from cacheflow.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from cacheflow.parallel_utils.tensor_parallel import (VocabParallelEmbedding,
ColumnParallelLinear,
RowParallelLinear)
from cacheflow.sequence import SequenceOutputs
KVCache = Tuple[torch.Tensor, torch.Tensor]
class LlamaRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
# convert into half-precision if necessary
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
return self.weight * hidden_states
class LlamaRotaryEmbedding(torch.nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000):
super().__init__()
self.max_position_embeddings = max_position_embeddings
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2) / dim))
self.register_buffer("inv_freq", inv_freq)
# Create cos and sin embeddings.
t = torch.arange(max_position_embeddings).float()
freqs = torch.einsum("i,j->ij", t, self.inv_freq.float())
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos().to(dtype=self.inv_freq.dtype)
sin = emb.sin().to(dtype=self.inv_freq.dtype)
self.register_buffer("cos_cached", cos, persistent=False)
self.register_buffer("sin_cached", sin, persistent=False)
def forward(
self,
positions: torch.LongTensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
cos = F.embedding(positions, self.cos_cached)
sin = F.embedding(positions, self.sin_cached)
return cos, sin
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin):
# TODO: Optimize.
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
class LlamaMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
):
super().__init__()
# TODO: Merge the gate and down linear layers.
self.gate_proj = ColumnParallelLinear(hidden_size, intermediate_size,
bias=False, gather_output=False,
perform_initialization=False)
self.down_proj = RowParallelLinear(intermediate_size, hidden_size,
bias=False, input_is_parallel=True,
perform_initialization=False)
self.up_proj = ColumnParallelLinear(hidden_size, intermediate_size,
bias=False, gather_output=False,
perform_initialization=False)
assert hidden_act == 'silu'
self.act_fn = nn.SiLU()
def forward(self, x):
gate, _ = self.gate_proj(x)
up, _ = self.up_proj(x)
x = self.act_fn(gate) * up
x, _ = self.down_proj(x)
return x
class LlamaAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
):
super().__init__()
self.hidden_size = hidden_size
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tensor_model_parallel_world_size == 0
self.num_heads = self.total_num_heads // tensor_model_parallel_world_size
self.head_dim = hidden_size // self.total_num_heads
self.scaling = self.head_dim ** -0.5
# TODO: Merge the QKV linear layers.
self.q_proj = ColumnParallelLinear(
hidden_size,
self.total_num_heads * self.head_dim,
bias=False,
gather_output=False,
perform_initialization=False,
)
self.k_proj = ColumnParallelLinear(
hidden_size,
self.total_num_heads * self.head_dim,
bias=False,
gather_output=False,
perform_initialization=False,
)
self.v_proj = ColumnParallelLinear(
hidden_size,
self.total_num_heads * self.head_dim,
bias=False,
gather_output=False,
perform_initialization=False,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
input_is_parallel=True,
perform_initialization=False,
)
self.rotary_emb = LlamaRotaryEmbedding(self.head_dim)
# FIXME(woosuk): Rename this.
self.attn = OPTCacheFlowAttention(scale=self.scaling)
def forward(
self,
positions: torch.LongTensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
q, _ = self.q_proj(hidden_states)
k, _ = self.k_proj(hidden_states)
v, _ = self.v_proj(hidden_states)
        # Apply rotary embedding.
# TODO: Optimize.
q = q.view(-1, self.num_heads, self.head_dim).transpose(0, 1)
k = k.view(-1, self.num_heads, self.head_dim).transpose(0, 1)
cos, sin = self.rotary_emb(positions)
q, k = apply_rotary_pos_emb(q, k, cos, sin)
q = q.transpose(0, 1).contiguous().view(-1, self.num_heads * self.head_dim)
k = k.transpose(0, 1).contiguous().view(-1, self.num_heads * self.head_dim)
key_cache, value_cache = kv_cache
attn_output = self.attn(
q, k, v, key_cache, value_cache, input_metadata, cache_event)
output, _ = self.o_proj(attn_output)
return output
class LlamaDecoderLayer(nn.Module):
def __init__(self, config: LlamaConfig):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = LlamaAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
)
self.mlp = LlamaMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
)
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
positions: torch.LongTensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
# Self Attention
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class LlamaModel(nn.Module):
def __init__(self, config: LlamaConfig):
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size,
perform_initialization=False)
self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
input_ids: torch.LongTensor,
positions: torch.LongTensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
for i in range(len(self.layers)):
if cache_events is None:
cache_event = None
else:
cache_event = cache_events[i]
layer = self.layers[i]
hidden_states = layer(
positions,
hidden_states,
kv_caches[i],
input_metadata,
cache_event,
)
hidden_states = self.norm(hidden_states)
return hidden_states
class LlamaForCausalLM(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.model = LlamaModel(config)
self.lm_head = ColumnParallelLinear(config.hidden_size,
config.vocab_size,
bias=False,
gather_output=False,
perform_initialization=False)
self.sampler = Sampler()
def forward(
self,
input_ids: torch.LongTensor,
positions: torch.LongTensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]:
hidden_states = self.model(
input_ids, positions, kv_caches, input_metadata, cache_events)
next_tokens = self.sampler(
self.lm_head.weight, hidden_states, input_metadata)
return next_tokens
_column_parallel_weights = ["embed_tokens.weight", "lm_head.weight",
"q_proj.weight", "k_proj.weight",
"v_proj.weight", "gate_proj.weight",
"up_proj.weight"]
_row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
def load_weights(self, weights_path: str):
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, param in state_dict.items():
loaded_weight = torch.from_numpy(np.load(os.path.join(weights_path,
name)))
for p in self._column_parallel_weights:
if p in name:
shard_size = param.shape[0]
loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)]
break
for p in self._row_parallel_weights:
if p in name:
shard_size = param.shape[1]
loaded_weight = loaded_weight[
:,
shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)]
break
assert param.shape == loaded_weight.shape
param.data.copy_(loaded_weight)
@staticmethod
def get_weights(model_name: str, path: str):
if not os.path.isfile(os.path.join(model_name, "config.json")):
raise ValueError("LLaMA model's model_name has to be a path"
"to the huggingface model's directory.")
path = os.path.join(model_name, f"np")
path = os.path.abspath(os.path.expanduser(path))
os.makedirs(path, exist_ok=True)
lock_path = os.path.join(path, "file_lock")
lock = filelock.FileLock(lock_path)
with lock:
test_weight_path = os.path.join(path, "model.embed_tokens.weight")
if os.path.exists(test_weight_path):
return path
bin_files = glob.glob(os.path.join(model_name, "*.bin"))
for bin_file in tqdm(bin_files, desc="Convert format"):
state = torch.load(bin_file, map_location="cpu")
for name, param in tqdm(state.items(), leave=False):
param_path = os.path.join(path, name)
with open(param_path, "wb") as f:
np.save(f, param.cpu().detach().numpy())
return path
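The sharding rule in `load_weights` above can be illustrated with a small standalone sketch; the rank, world size, and toy shapes below are assumptions for illustration only:

```python
import numpy as np

# Toy example of the sharding in load_weights: column-parallel weights are
# sliced along dim 0, row-parallel weights along dim 1. Rank and shapes are
# assumed for illustration.
tensor_model_parallel_rank = 1
tensor_model_parallel_world_size = 2

full_weight = np.arange(8 * 4).reshape(8, 4)      # e.g. a full q_proj.weight [out, in]

# Column-parallel (q/k/v/gate/up projections, embeddings): slice output rows.
shard_size = full_weight.shape[0] // tensor_model_parallel_world_size
column_shard = full_weight[shard_size * tensor_model_parallel_rank:
                           shard_size * (tensor_model_parallel_rank + 1)]

# Row-parallel (o_proj, down_proj): slice input columns instead.
shard_size = full_weight.shape[1] // tensor_model_parallel_world_size
row_shard = full_weight[:, shard_size * tensor_model_parallel_rank:
                           shard_size * (tensor_model_parallel_rank + 1)]

print(column_shard.shape, row_shard.shape)        # (4, 4) (8, 2)
```

Each rank therefore loads only its contiguous slice of every parallel weight, which is why `param.shape` is compared against the sliced `loaded_weight` before copying.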


@@ -15,11 +15,30 @@ class CacheFlowMemoryAnalyzer:
) -> int:
raise NotImplementedError()
def get_workspace_size(self) -> int:
return 1 * _GiB
def get_cache_block_size(self) -> int:
raise NotImplementedError()
def get_max_num_cpu_blocks(
self,
memory_utilization: float,
swap_space: int,
) -> int:
raise NotImplementedError()
swap_space = swap_space * _GiB
cpu_memory = self.cpu_memory
if swap_space > 0.8 * cpu_memory:
            raise ValueError(f'The swap space ({swap_space / _GiB:.2f} GiB) '
                             'takes more than 80% of the available memory '
                             f'({cpu_memory / _GiB:.2f} GiB). '
                             'Please check the swap space size.')
        if swap_space > 0.5 * cpu_memory:
            print(f'WARNING: The swap space ({swap_space / _GiB:.2f} GiB) '
                  'takes more than 50% of the available memory '
                  f'({cpu_memory / _GiB:.2f} GiB). '
                  'This may slow down the system.')
max_num_blocks = swap_space // self.get_cache_block_size()
return max_num_blocks
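A rough worked example of the swap-space budgeting above, with assumed numbers: 16 GiB of swap and the roughly 4 MiB cache block derived in the next example. Neither value comes from this commit.

```python
_GiB = 1 << 30

# Assumed, illustrative values: 16 GiB of CPU swap and a 4 MiB KV-cache block.
swap_space = 16 * _GiB
cache_block_size = 4 * (1 << 20)

max_num_cpu_blocks = swap_space // cache_block_size
print(max_num_cpu_blocks)  # 4096 blocks' worth of KV cache can be swapped to CPU RAM
```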
class OPTMemoryAnalyzer(CacheFlowMemoryAnalyzer):
@@ -52,9 +71,9 @@ class OPTMemoryAnalyzer(CacheFlowMemoryAnalyzer):
def _get_param_size(self) -> int:
word_embedding = self.vocab_size * self.embedding_size // self.tensor_parallel_size
if self.embedding_size != self.vocab_size:
if self.embedding_size != self.hidden_size:
# Project in/out.
word_embedding += 2 * self.embedding_size * self.vocab_size
word_embedding += 2 * self.embedding_size * self.hidden_size
position_embedding = self.max_position * self.hidden_size
ln1 = 2 * self.hidden_size
@@ -89,15 +108,15 @@ class OPTMemoryAnalyzer(CacheFlowMemoryAnalyzer):
ffn = max_num_batched_tokens * self.ffn_size // self.tensor_parallel_size
# Double the activation size for input and output.
max_act = 2 * (max(qkv, ffn) + residual)
# Size of output logits.
output_logits = 2 * (max_num_batched_tokens * self.vocab_size)
max_act = max(max_act, output_logits)
dtype_size = get_dtype_size(self.dtype)
return dtype_size * max_act
def _get_workspace_size(self) -> int:
return 1 * _GiB
def _get_cache_block_size(self) -> int:
key_cache_block = self.block_size * self.num_heads * self.head_size
value_cache_block = self.block_size * self.num_heads * self.head_size
def get_cache_block_size(self) -> int:
key_cache_block = self.block_size * self.hidden_size // self.tensor_parallel_size
value_cache_block = key_cache_block
total = self.num_layers * (key_cache_block + value_cache_block)
dtype_size = get_dtype_size(self.dtype)
return dtype_size * total
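To make `get_cache_block_size` concrete, here is a worked example with assumed LLaMA-7B-like shapes; the block size, dtype, and tensor-parallel degree are illustrative choices, not defaults from this commit:

```python
# Assumed, illustrative configuration (roughly LLaMA-7B, fp16, no tensor parallelism).
block_size = 8            # tokens per KV-cache block
hidden_size = 4096
num_layers = 32
tensor_parallel_size = 1
dtype_size = 2            # bytes per fp16 element

key_cache_block = block_size * hidden_size // tensor_parallel_size   # 32,768 elements
value_cache_block = key_cache_block
total = num_layers * (key_cache_block + value_cache_block)           # 2,097,152 elements
print(dtype_size * total)  # 4,194,304 bytes, i.e. ~4 MiB per cache block
```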
@@ -112,26 +131,111 @@ class OPTMemoryAnalyzer(CacheFlowMemoryAnalyzer):
param_size = self._get_param_size()
act_size = self._get_max_act_size(max_num_batched_tokens)
workspace_size = self._get_workspace_size()
workspace_size = self.get_workspace_size()
max_cache_size = usable_memory - (param_size + act_size + workspace_size)
max_num_blocks = max_cache_size // self._get_cache_block_size()
if max_cache_size <= 0:
raise RuntimeError('Not enough GPU memory.')
max_num_blocks = max_cache_size // self.get_cache_block_size()
return max_num_blocks
def get_max_num_cpu_blocks(
class LlamaMemoryAnalyzer(CacheFlowMemoryAnalyzer):
def __init__(
self,
swap_space: int,
model_name: str,
block_size: int,
dtype: torch.dtype,
gpu_memory: int,
cpu_memory: int,
tensor_parallel_size: int,
) -> None:
self.model_name = model_name
self.block_size = block_size
self.dtype = dtype
self.gpu_memory = gpu_memory
self.cpu_memory = cpu_memory
self.tensor_parallel_size = tensor_parallel_size
config = AutoConfig.from_pretrained(model_name)
self.num_layers = config.num_hidden_layers
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_size = config.hidden_size // self.num_heads
self.ffn_size = config.intermediate_size
self.vocab_size = config.vocab_size
# FIXME
self.max_position = 2048
def _get_param_size(self) -> int:
word_embedding = self.vocab_size * self.hidden_size // self.tensor_parallel_size
position_embedding = self.max_position * self.hidden_size
# NOTE: LLaMA does not have bias terms.
ln1 = self.hidden_size
q = self.hidden_size * self.hidden_size // self.tensor_parallel_size
k = self.hidden_size * self.hidden_size // self.tensor_parallel_size
v = self.hidden_size * self.hidden_size // self.tensor_parallel_size
out = self.hidden_size * self.hidden_size // self.tensor_parallel_size
# Rotary embedding.
# TODO(woosuk): Share the rotary embedding between layers.
rot = self.max_position * self.head_size
mha = ln1 + q + k + v + out + rot
ln2 = self.hidden_size
gate = self.hidden_size * self.ffn_size // self.tensor_parallel_size
down = self.ffn_size * self.hidden_size // self.tensor_parallel_size
up = self.hidden_size * self.ffn_size // self.tensor_parallel_size
ffn = ln2 + gate + down + up
total = (word_embedding + position_embedding + self.num_layers * (mha + ffn))
dtype_size = get_dtype_size(self.dtype)
return dtype_size * total
def _get_max_act_size(
self,
max_num_batched_tokens: int,
) -> int:
swap_space = swap_space * _GiB
if swap_space > 0.8 * self.cpu_memory:
raise ValueError(f'The swap space ({swap_space / _GiB:.2f} GiB) '
'takes more than 80% of the available memory '
f'({self.cpu_memory / _GiB:.2f} GiB).'
'Please check the swap space size.')
if swap_space > 0.5 * self.cpu_memory:
print(f'WARNING: The swap space ({swap_space / _GiB:.2f} GiB) '
'takes more than 50% of the available memory '
f'({self.cpu_memory / _GiB:.2f} GiB).'
'This may slow the system performance.')
max_num_blocks = swap_space // self._get_cache_block_size()
        # NOTE: We approximately calculate the maximum activation size by
# estimating
# 1) the maximum activation tensor size during inference
# 2) the residual tensor size during inference
# Here, we assume that FlashAttention is used and
# thus the attention maps are never materialized in GPU DRAM.
residual = max_num_batched_tokens * self.hidden_size
qkv = 3 * (max_num_batched_tokens * self.hidden_size) // self.tensor_parallel_size
ffn = 2 * (max_num_batched_tokens * self.ffn_size) // self.tensor_parallel_size
# Double the activation size for input and output.
max_act = 2 * (max(qkv, ffn) + residual)
# Size of output logits.
output_logits = 2 * (max_num_batched_tokens * self.vocab_size)
max_act = max(max_act, output_logits)
dtype_size = get_dtype_size(self.dtype)
return dtype_size * max_act
def get_cache_block_size(self) -> int:
key_cache_block = self.block_size * self.hidden_size // self.tensor_parallel_size
value_cache_block = key_cache_block
total = self.num_layers * (key_cache_block + value_cache_block)
dtype_size = get_dtype_size(self.dtype)
return dtype_size * total
def get_max_num_gpu_blocks(
self,
max_num_batched_tokens: int,
memory_utilization: float = 0.95,
) -> int:
# NOTE(woosuk): This assumes that the machine has homogeneous GPUs.
gpu_memory = self.gpu_memory
usable_memory = int(memory_utilization * gpu_memory)
param_size = self._get_param_size()
act_size = self._get_max_act_size(max_num_batched_tokens)
workspace_size = self.get_workspace_size()
max_cache_size = usable_memory - (param_size + act_size + workspace_size)
if max_cache_size <= 0:
raise RuntimeError('Not enough GPU memory.')
max_num_blocks = max_cache_size // self.get_cache_block_size()
return max_num_blocks
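Plugging LLaMA-7B-style shapes into `_get_param_size` above gives a feel for the weight footprint. The configuration values below are assumed from the public 7B model, purely for illustration:

```python
# Assumed LLaMA-7B-like configuration (illustrative; not read from the commit).
hidden_size, ffn_size, num_layers = 4096, 11008, 32
vocab_size, max_position, head_size = 32000, 2048, 128
tensor_parallel_size, dtype_size = 1, 2   # single GPU, fp16

word_embedding = vocab_size * hidden_size // tensor_parallel_size
position_embedding = max_position * hidden_size
mha = (hidden_size                                               # input layernorm
       + 4 * hidden_size * hidden_size // tensor_parallel_size   # q, k, v, o projections
       + max_position * head_size)                               # rotary embedding table
ffn = (hidden_size                                               # post-attention layernorm
       + 3 * hidden_size * ffn_size // tensor_parallel_size)     # gate, down, up projections

total = word_embedding + position_embedding + num_layers * (mha + ffn)
print(f"{total / 1e9:.2f}B parameters")          # ~6.62B
print(f"{dtype_size * total / 2**30:.1f} GiB")   # ~12.3 GiB of weights in fp16
```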


@@ -6,16 +6,20 @@ import torch.nn as nn
from transformers import AutoConfig
from cacheflow.models.memory_analyzer import CacheFlowMemoryAnalyzer
from cacheflow.models.memory_analyzer import LlamaMemoryAnalyzer
from cacheflow.models.memory_analyzer import OPTMemoryAnalyzer
from cacheflow.models.llama import LlamaForCausalLM
from cacheflow.models.opt import OPTForCausalLM
from cacheflow.models.utils import get_torch_dtype
_MODELS = {
'llama': LlamaForCausalLM,
'opt': OPTForCausalLM,
}
_MEMORY_ANALYZERS = {
'llama': LlamaMemoryAnalyzer,
'opt': OPTMemoryAnalyzer,
}
@@ -31,7 +35,7 @@ def get_model(
for model_class_name, model_class in _MODELS.items():
if model_class_name in model_name:
            # Download the model weights if they are not cached.
weights_dir = model_class.download_weights(model_name, path=path)
weights_dir = model_class.get_weights(model_name, path=path)
# Create a model instance.
model = model_class(config)
# Load the weights from the cached or downloaded files.


@@ -299,7 +299,7 @@ class OPTForCausalLM(nn.Module):
param.data.copy_(loaded_weight)
@staticmethod
def download_weights(model_name: str, path: str):
def get_weights(model_name: str, path: str):
path = os.path.join(path, f"{model_name}-np")
path = os.path.abspath(os.path.expanduser(path))
os.makedirs(path, exist_ok=True)
@@ -316,11 +316,8 @@ class OPTForCausalLM(nn.Module):
cache_dir=os.path.join(path, "cache"))
bin_files = glob.glob(os.path.join(folder, "*.bin"))
if "/" in model_name:
model_name = model_name.split("/")[1].lower()
for bin_file in tqdm(bin_files, desc="Convert format"):
state = torch.load(bin_file)
state = torch.load(bin_file, map_location="cpu")
for name, param in tqdm(state.items(), leave=False):
if name.startswith("decoder."):
name = "model." + name


@@ -39,7 +39,7 @@ class Sampler(nn.Module):
# Compute the probabilities.
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
# Compute the log probabilities (before applying top-p).
logprobs = torch.log(probs)
logprobs = torch.log(probs, out=logits)
# Apply top-p truncation.
top_ps = _get_top_ps(input_metadata)
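The change above writes the log-probabilities into the existing logits buffer rather than allocating a second `[num_tokens, vocab_size]` tensor. A minimal sketch of the same pattern, with shapes assumed for illustration:

```python
import torch

# Assumed shapes, for illustration only.
logits = torch.randn(4, 32000)                      # [num_tokens, vocab_size], fp32
probs = torch.softmax(logits, dim=-1, dtype=torch.float)

# Write the log-probabilities into the existing logits storage.
logprobs = torch.log(probs, out=logits)
assert logprobs.data_ptr() == logits.data_ptr()     # no extra allocation
```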