mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-23 13:54:29 +08:00
parent
a1b3de86cd
commit
80a2f812f1
@ -3,8 +3,10 @@
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
pip install psutil numpy torch transformers
|
||||
pip install flash-attn # This may take up to 10 mins.
|
||||
pip install psutil numpy ray torch
|
||||
pip install git+https://github.com/huggingface/transformers # Required for LLaMA.
|
||||
pip install sentencepiece # Required for LlamaTokenizer.
|
||||
pip install flash-attn # This may take up to 20 mins.
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
|
||||
@ -61,4 +61,5 @@ class SimpleFrontend:
|
||||
for seq in seq_group.seqs:
|
||||
token_ids = seq.get_token_ids()
|
||||
output = self.tokenizer.decode(token_ids, skip_special_tokens=True)
|
||||
output = output.strip()
|
||||
print(f'Seq {seq.seq_id}: {output!r}')
|
||||
|
||||
357
cacheflow/models/llama.py
Normal file
357
cacheflow/models/llama.py
Normal file
@ -0,0 +1,357 @@
|
||||
"""1D LLaMA model compatible with HuggingFace weights."""
|
||||
import os
|
||||
import glob
|
||||
import filelock
|
||||
from tqdm import tqdm
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import LlamaConfig
|
||||
from transformers import PreTrainedModel
|
||||
|
||||
from cacheflow.models import InputMetadata
|
||||
from cacheflow.models.attention import OPTCacheFlowAttention
|
||||
from cacheflow.models.sample import Sampler
|
||||
from cacheflow.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from cacheflow.parallel_utils.tensor_parallel import (VocabParallelEmbedding,
|
||||
ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
from cacheflow.sequence import SequenceOutputs
|
||||
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
|
||||
class LlamaRMSNorm(nn.Module):
|
||||
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
# convert into half-precision if necessary
|
||||
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
||||
hidden_states = hidden_states.to(self.weight.dtype)
|
||||
return self.weight * hidden_states
|
||||
|
||||
|
||||
class LlamaRotaryEmbedding(torch.nn.Module):
|
||||
|
||||
def __init__(self, dim, max_position_embeddings=2048, base=10000):
|
||||
super().__init__()
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
|
||||
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2) / dim))
|
||||
self.register_buffer("inv_freq", inv_freq)
|
||||
|
||||
# Create cos and sin embeddings.
|
||||
t = torch.arange(max_position_embeddings).float()
|
||||
freqs = torch.einsum("i,j->ij", t, self.inv_freq.float())
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos().to(dtype=self.inv_freq.dtype)
|
||||
sin = emb.sin().to(dtype=self.inv_freq.dtype)
|
||||
self.register_buffer("cos_cached", cos, persistent=False)
|
||||
self.register_buffer("sin_cached", sin, persistent=False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.LongTensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
cos = F.embedding(positions, self.cos_cached)
|
||||
sin = F.embedding(positions, self.sin_cached)
|
||||
return cos, sin
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(q, k, cos, sin):
|
||||
# TODO: Optimize.
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
class LlamaMLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
hidden_act: str,
|
||||
):
|
||||
super().__init__()
|
||||
# TODO: Merge the gate and down linear layers.
|
||||
self.gate_proj = ColumnParallelLinear(hidden_size, intermediate_size,
|
||||
bias=False, gather_output=False,
|
||||
perform_initialization=False)
|
||||
self.down_proj = RowParallelLinear(intermediate_size, hidden_size,
|
||||
bias=False, input_is_parallel=True,
|
||||
perform_initialization=False)
|
||||
self.up_proj = ColumnParallelLinear(hidden_size, intermediate_size,
|
||||
bias=False, gather_output=False,
|
||||
perform_initialization=False)
|
||||
assert hidden_act == 'silu'
|
||||
self.act_fn = nn.SiLU()
|
||||
|
||||
def forward(self, x):
|
||||
gate, _ = self.gate_proj(x)
|
||||
up, _ = self.up_proj(x)
|
||||
x = self.act_fn(gate) * up
|
||||
x, _ = self.down_proj(x)
|
||||
return x
|
||||
|
||||
|
||||
class LlamaAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
|
||||
self.total_num_heads = num_heads
|
||||
assert self.total_num_heads % tensor_model_parallel_world_size == 0
|
||||
self.num_heads = self.total_num_heads // tensor_model_parallel_world_size
|
||||
self.head_dim = hidden_size // self.total_num_heads
|
||||
self.scaling = self.head_dim ** -0.5
|
||||
|
||||
# TODO: Merge the QKV linear layers.
|
||||
self.q_proj = ColumnParallelLinear(
|
||||
hidden_size,
|
||||
self.total_num_heads * self.head_dim,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
perform_initialization=False,
|
||||
)
|
||||
self.k_proj = ColumnParallelLinear(
|
||||
hidden_size,
|
||||
self.total_num_heads * self.head_dim,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
perform_initialization=False,
|
||||
)
|
||||
self.v_proj = ColumnParallelLinear(
|
||||
hidden_size,
|
||||
self.total_num_heads * self.head_dim,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
perform_initialization=False,
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False,
|
||||
)
|
||||
self.rotary_emb = LlamaRotaryEmbedding(self.head_dim)
|
||||
# FIXME(woosuk): Rename this.
|
||||
self.attn = OPTCacheFlowAttention(scale=self.scaling)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.LongTensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
q, _ = self.q_proj(hidden_states)
|
||||
k, _ = self.k_proj(hidden_states)
|
||||
v, _ = self.v_proj(hidden_states)
|
||||
|
||||
# Apply rotrary embedding.
|
||||
# TODO: Optimize.
|
||||
q = q.view(-1, self.num_heads, self.head_dim).transpose(0, 1)
|
||||
k = k.view(-1, self.num_heads, self.head_dim).transpose(0, 1)
|
||||
cos, sin = self.rotary_emb(positions)
|
||||
q, k = apply_rotary_pos_emb(q, k, cos, sin)
|
||||
q = q.transpose(0, 1).contiguous().view(-1, self.num_heads * self.head_dim)
|
||||
k = k.transpose(0, 1).contiguous().view(-1, self.num_heads * self.head_dim)
|
||||
|
||||
key_cache, value_cache = kv_cache
|
||||
attn_output = self.attn(
|
||||
q, k, v, key_cache, value_cache, input_metadata, cache_event)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
|
||||
class LlamaDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(self, config: LlamaConfig):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.self_attn = LlamaAttention(
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
)
|
||||
self.mlp = LlamaMLP(
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
)
|
||||
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.LongTensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
# Self Attention
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
input_metadata=input_metadata,
|
||||
cache_event=cache_event,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LlamaModel(nn.Module):
|
||||
|
||||
def __init__(self, config: LlamaConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size,
|
||||
perform_initialization=False)
|
||||
self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
positions: torch.LongTensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
for i in range(len(self.layers)):
|
||||
if cache_events is None:
|
||||
cache_event = None
|
||||
else:
|
||||
cache_event = cache_events[i]
|
||||
layer = self.layers[i]
|
||||
hidden_states = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
kv_caches[i],
|
||||
input_metadata,
|
||||
cache_event,
|
||||
)
|
||||
hidden_states = self.norm(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LlamaForCausalLM(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.model = LlamaModel(config)
|
||||
self.lm_head = ColumnParallelLinear(config.hidden_size,
|
||||
config.vocab_size,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
perform_initialization=False)
|
||||
self.sampler = Sampler()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
positions: torch.LongTensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> Dict[int, SequenceOutputs]:
|
||||
hidden_states = self.model(
|
||||
input_ids, positions, kv_caches, input_metadata, cache_events)
|
||||
next_tokens = self.sampler(
|
||||
self.lm_head.weight, hidden_states, input_metadata)
|
||||
return next_tokens
|
||||
|
||||
_column_parallel_weights = ["embed_tokens.weight", "lm_head.weight",
|
||||
"q_proj.weight", "k_proj.weight",
|
||||
"v_proj.weight", "gate_proj.weight",
|
||||
"up_proj.weight"]
|
||||
_row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
|
||||
|
||||
def load_weights(self, weights_path: str):
|
||||
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||
state_dict = self.state_dict()
|
||||
for name, param in state_dict.items():
|
||||
loaded_weight = torch.from_numpy(np.load(os.path.join(weights_path,
|
||||
name)))
|
||||
for p in self._column_parallel_weights:
|
||||
if p in name:
|
||||
shard_size = param.shape[0]
|
||||
loaded_weight = loaded_weight[
|
||||
shard_size * tensor_model_parallel_rank
|
||||
:shard_size * (tensor_model_parallel_rank + 1)]
|
||||
break
|
||||
for p in self._row_parallel_weights:
|
||||
if p in name:
|
||||
shard_size = param.shape[1]
|
||||
loaded_weight = loaded_weight[
|
||||
:,
|
||||
shard_size * tensor_model_parallel_rank
|
||||
:shard_size * (tensor_model_parallel_rank + 1)]
|
||||
break
|
||||
|
||||
assert param.shape == loaded_weight.shape
|
||||
param.data.copy_(loaded_weight)
|
||||
|
||||
@staticmethod
|
||||
def get_weights(model_name: str, path: str):
|
||||
if not os.path.isfile(os.path.join(model_name, "config.json")):
|
||||
raise ValueError("LLaMA model's model_name has to be a path"
|
||||
"to the huggingface model's directory.")
|
||||
path = os.path.join(model_name, f"np")
|
||||
path = os.path.abspath(os.path.expanduser(path))
|
||||
os.makedirs(path, exist_ok=True)
|
||||
lock_path = os.path.join(path, "file_lock")
|
||||
lock = filelock.FileLock(lock_path)
|
||||
|
||||
with lock:
|
||||
test_weight_path = os.path.join(path, "model.embed_tokens.weight")
|
||||
if os.path.exists(test_weight_path):
|
||||
return path
|
||||
|
||||
bin_files = glob.glob(os.path.join(model_name, "*.bin"))
|
||||
|
||||
for bin_file in tqdm(bin_files, desc="Convert format"):
|
||||
state = torch.load(bin_file, map_location="cpu")
|
||||
for name, param in tqdm(state.items(), leave=False):
|
||||
param_path = os.path.join(path, name)
|
||||
with open(param_path, "wb") as f:
|
||||
np.save(f, param.cpu().detach().numpy())
|
||||
|
||||
return path
|
||||
@ -15,11 +15,30 @@ class CacheFlowMemoryAnalyzer:
|
||||
) -> int:
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_workspace_size(self) -> int:
|
||||
return 1 * _GiB
|
||||
|
||||
def get_cache_block_size(self) -> int:
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_max_num_cpu_blocks(
|
||||
self,
|
||||
memory_utilization: float,
|
||||
swap_space: int,
|
||||
) -> int:
|
||||
raise NotImplementedError()
|
||||
swap_space = swap_space * _GiB
|
||||
cpu_memory = self.cpu_memory
|
||||
if swap_space > 0.8 * cpu_memory:
|
||||
raise ValueError(f'The swap space ({swap_space / _GiB:.2f} GiB) '
|
||||
'takes more than 80% of the available memory '
|
||||
f'({cpu_memory / _GiB:.2f} GiB).'
|
||||
'Please check the swap space size.')
|
||||
if swap_space > 0.5 * cpu_memory:
|
||||
print(f'WARNING: The swap space ({swap_space / _GiB:.2f} GiB) '
|
||||
'takes more than 50% of the available memory '
|
||||
f'({cpu_memory / _GiB:.2f} GiB).'
|
||||
'This may slow the system performance.')
|
||||
max_num_blocks = swap_space // self.get_cache_block_size()
|
||||
return max_num_blocks
|
||||
|
||||
|
||||
class OPTMemoryAnalyzer(CacheFlowMemoryAnalyzer):
|
||||
@ -52,9 +71,9 @@ class OPTMemoryAnalyzer(CacheFlowMemoryAnalyzer):
|
||||
|
||||
def _get_param_size(self) -> int:
|
||||
word_embedding = self.vocab_size * self.embedding_size // self.tensor_parallel_size
|
||||
if self.embedding_size != self.vocab_size:
|
||||
if self.embedding_size != self.hidden_size:
|
||||
# Project in/out.
|
||||
word_embedding += 2 * self.embedding_size * self.vocab_size
|
||||
word_embedding += 2 * self.embedding_size * self.hidden_size
|
||||
position_embedding = self.max_position * self.hidden_size
|
||||
|
||||
ln1 = 2 * self.hidden_size
|
||||
@ -89,15 +108,15 @@ class OPTMemoryAnalyzer(CacheFlowMemoryAnalyzer):
|
||||
ffn = max_num_batched_tokens * self.ffn_size // self.tensor_parallel_size
|
||||
# Double the activation size for input and output.
|
||||
max_act = 2 * (max(qkv, ffn) + residual)
|
||||
# Size of output logits.
|
||||
output_logits = 2 * (max_num_batched_tokens * self.vocab_size)
|
||||
max_act = max(max_act, output_logits)
|
||||
dtype_size = get_dtype_size(self.dtype)
|
||||
return dtype_size * max_act
|
||||
|
||||
def _get_workspace_size(self) -> int:
|
||||
return 1 * _GiB
|
||||
|
||||
def _get_cache_block_size(self) -> int:
|
||||
key_cache_block = self.block_size * self.num_heads * self.head_size
|
||||
value_cache_block = self.block_size * self.num_heads * self.head_size
|
||||
def get_cache_block_size(self) -> int:
|
||||
key_cache_block = self.block_size * self.hidden_size // self.tensor_parallel_size
|
||||
value_cache_block = key_cache_block
|
||||
total = self.num_layers * (key_cache_block + value_cache_block)
|
||||
dtype_size = get_dtype_size(self.dtype)
|
||||
return dtype_size * total
|
||||
@ -112,26 +131,111 @@ class OPTMemoryAnalyzer(CacheFlowMemoryAnalyzer):
|
||||
|
||||
param_size = self._get_param_size()
|
||||
act_size = self._get_max_act_size(max_num_batched_tokens)
|
||||
workspace_size = self._get_workspace_size()
|
||||
workspace_size = self.get_workspace_size()
|
||||
|
||||
max_cache_size = usable_memory - (param_size + act_size + workspace_size)
|
||||
max_num_blocks = max_cache_size // self._get_cache_block_size()
|
||||
if max_cache_size <= 0:
|
||||
raise RuntimeError('Not enough GPU memory.')
|
||||
max_num_blocks = max_cache_size // self.get_cache_block_size()
|
||||
return max_num_blocks
|
||||
|
||||
def get_max_num_cpu_blocks(
|
||||
|
||||
class LlamaMemoryAnalyzer(CacheFlowMemoryAnalyzer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
swap_space: int,
|
||||
model_name: str,
|
||||
block_size: int,
|
||||
dtype: torch.dtype,
|
||||
gpu_memory: int,
|
||||
cpu_memory: int,
|
||||
tensor_parallel_size: int,
|
||||
) -> None:
|
||||
self.model_name = model_name
|
||||
self.block_size = block_size
|
||||
self.dtype = dtype
|
||||
self.gpu_memory = gpu_memory
|
||||
self.cpu_memory = cpu_memory
|
||||
self.tensor_parallel_size = tensor_parallel_size
|
||||
|
||||
config = AutoConfig.from_pretrained(model_name)
|
||||
self.num_layers = config.num_hidden_layers
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_size = config.hidden_size // self.num_heads
|
||||
self.ffn_size = config.intermediate_size
|
||||
self.vocab_size = config.vocab_size
|
||||
# FIXME
|
||||
self.max_position = 2048
|
||||
|
||||
def _get_param_size(self) -> int:
|
||||
word_embedding = self.vocab_size * self.hidden_size // self.tensor_parallel_size
|
||||
position_embedding = self.max_position * self.hidden_size
|
||||
|
||||
# NOTE: LLaMA does not have bias terms.
|
||||
ln1 = self.hidden_size
|
||||
q = self.hidden_size * self.hidden_size // self.tensor_parallel_size
|
||||
k = self.hidden_size * self.hidden_size // self.tensor_parallel_size
|
||||
v = self.hidden_size * self.hidden_size // self.tensor_parallel_size
|
||||
out = self.hidden_size * self.hidden_size // self.tensor_parallel_size
|
||||
# Rotary embedding.
|
||||
# TODO(woosuk): Share the rotary embedding between layers.
|
||||
rot = self.max_position * self.head_size
|
||||
mha = ln1 + q + k + v + out + rot
|
||||
|
||||
ln2 = self.hidden_size
|
||||
gate = self.hidden_size * self.ffn_size // self.tensor_parallel_size
|
||||
down = self.ffn_size * self.hidden_size // self.tensor_parallel_size
|
||||
up = self.hidden_size * self.ffn_size // self.tensor_parallel_size
|
||||
ffn = ln2 + gate + down + up
|
||||
|
||||
total = (word_embedding + position_embedding + self.num_layers * (mha + ffn))
|
||||
dtype_size = get_dtype_size(self.dtype)
|
||||
return dtype_size * total
|
||||
|
||||
def _get_max_act_size(
|
||||
self,
|
||||
max_num_batched_tokens: int,
|
||||
) -> int:
|
||||
swap_space = swap_space * _GiB
|
||||
if swap_space > 0.8 * self.cpu_memory:
|
||||
raise ValueError(f'The swap space ({swap_space / _GiB:.2f} GiB) '
|
||||
'takes more than 80% of the available memory '
|
||||
f'({self.cpu_memory / _GiB:.2f} GiB).'
|
||||
'Please check the swap space size.')
|
||||
if swap_space > 0.5 * self.cpu_memory:
|
||||
print(f'WARNING: The swap space ({swap_space / _GiB:.2f} GiB) '
|
||||
'takes more than 50% of the available memory '
|
||||
f'({self.cpu_memory / _GiB:.2f} GiB).'
|
||||
'This may slow the system performance.')
|
||||
max_num_blocks = swap_space // self._get_cache_block_size()
|
||||
# NOTE: We approxmiately calculate the maximum activation size by
|
||||
# estimating
|
||||
# 1) the maximum activation tensor size during inference
|
||||
# 2) the residual tensor size during inference
|
||||
# Here, we assume that FlashAttention is used and
|
||||
# thus the attention maps are never materialized in GPU DRAM.
|
||||
residual = max_num_batched_tokens * self.hidden_size
|
||||
qkv = 3 * (max_num_batched_tokens * self.hidden_size) // self.tensor_parallel_size
|
||||
ffn = 2 * (max_num_batched_tokens * self.ffn_size) // self.tensor_parallel_size
|
||||
# Double the activation size for input and output.
|
||||
max_act = 2 * (max(qkv, ffn) + residual)
|
||||
# Size of output logits.
|
||||
output_logits = 2 * (max_num_batched_tokens * self.vocab_size)
|
||||
max_act = max(max_act, output_logits)
|
||||
dtype_size = get_dtype_size(self.dtype)
|
||||
return dtype_size * max_act
|
||||
|
||||
def get_cache_block_size(self) -> int:
|
||||
key_cache_block = self.block_size * self.hidden_size // self.tensor_parallel_size
|
||||
value_cache_block = key_cache_block
|
||||
total = self.num_layers * (key_cache_block + value_cache_block)
|
||||
dtype_size = get_dtype_size(self.dtype)
|
||||
return dtype_size * total
|
||||
|
||||
def get_max_num_gpu_blocks(
|
||||
self,
|
||||
max_num_batched_tokens: int,
|
||||
memory_utilization: float = 0.95,
|
||||
) -> int:
|
||||
# NOTE(woosuk): This assumes that the machine has homogeneous GPUs.
|
||||
gpu_memory = self.gpu_memory
|
||||
usable_memory = int(memory_utilization * gpu_memory)
|
||||
|
||||
param_size = self._get_param_size()
|
||||
act_size = self._get_max_act_size(max_num_batched_tokens)
|
||||
workspace_size = self.get_workspace_size()
|
||||
|
||||
max_cache_size = usable_memory - (param_size + act_size + workspace_size)
|
||||
if max_cache_size <= 0:
|
||||
raise RuntimeError('Not enough GPU memory.')
|
||||
max_num_blocks = max_cache_size // self.get_cache_block_size()
|
||||
return max_num_blocks
|
||||
|
||||
@ -6,16 +6,20 @@ import torch.nn as nn
|
||||
from transformers import AutoConfig
|
||||
|
||||
from cacheflow.models.memory_analyzer import CacheFlowMemoryAnalyzer
|
||||
from cacheflow.models.memory_analyzer import LlamaMemoryAnalyzer
|
||||
from cacheflow.models.memory_analyzer import OPTMemoryAnalyzer
|
||||
from cacheflow.models.llama import LlamaForCausalLM
|
||||
from cacheflow.models.opt import OPTForCausalLM
|
||||
from cacheflow.models.utils import get_torch_dtype
|
||||
|
||||
|
||||
_MODELS = {
|
||||
'llama': LlamaForCausalLM,
|
||||
'opt': OPTForCausalLM,
|
||||
}
|
||||
|
||||
_MEMORY_ANALYZERS = {
|
||||
'llama': LlamaMemoryAnalyzer,
|
||||
'opt': OPTMemoryAnalyzer,
|
||||
}
|
||||
|
||||
@ -31,7 +35,7 @@ def get_model(
|
||||
for model_class_name, model_class in _MODELS.items():
|
||||
if model_class_name in model_name:
|
||||
# Download model weights if it's not cached.
|
||||
weights_dir = model_class.download_weights(model_name, path=path)
|
||||
weights_dir = model_class.get_weights(model_name, path=path)
|
||||
# Create a model instance.
|
||||
model = model_class(config)
|
||||
# Load the weights from the cached or downloaded files.
|
||||
|
||||
@ -299,7 +299,7 @@ class OPTForCausalLM(nn.Module):
|
||||
param.data.copy_(loaded_weight)
|
||||
|
||||
@staticmethod
|
||||
def download_weights(model_name: str, path: str):
|
||||
def get_weights(model_name: str, path: str):
|
||||
path = os.path.join(path, f"{model_name}-np")
|
||||
path = os.path.abspath(os.path.expanduser(path))
|
||||
os.makedirs(path, exist_ok=True)
|
||||
@ -316,11 +316,8 @@ class OPTForCausalLM(nn.Module):
|
||||
cache_dir=os.path.join(path, "cache"))
|
||||
bin_files = glob.glob(os.path.join(folder, "*.bin"))
|
||||
|
||||
if "/" in model_name:
|
||||
model_name = model_name.split("/")[1].lower()
|
||||
|
||||
for bin_file in tqdm(bin_files, desc="Convert format"):
|
||||
state = torch.load(bin_file)
|
||||
state = torch.load(bin_file, map_location="cpu")
|
||||
for name, param in tqdm(state.items(), leave=False):
|
||||
if name.startswith("decoder."):
|
||||
name = "model." + name
|
||||
|
||||
@ -39,7 +39,7 @@ class Sampler(nn.Module):
|
||||
# Compute the probabilities.
|
||||
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
|
||||
# Compute the log probabilities (before applying top-p).
|
||||
logprobs = torch.log(probs)
|
||||
logprobs = torch.log(probs, out=logits)
|
||||
|
||||
# Apply top-p truncation.
|
||||
top_ps = _get_top_ps(input_metadata)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user