Accelerate LLaMA model loading (#234)

JFDuan authored 2023-08-30 16:00:13 +08:00 (committed by GitHub)
parent becd7a56f1
commit 0d93f15694
8 changed files with 190 additions and 112 deletions


@@ -34,8 +34,9 @@ from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
 from vllm.model_executor.layers.sampler import Sampler
-from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
-                                              load_tensor_parallel_weights)
+from vllm.model_executor.weight_utils import (
+    hf_model_weights_iterator, load_padded_tensor_parallel_vocab,
+    load_tensor_parallel_weights)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
 from vllm.model_executor.parallel_utils.tensor_parallel import (
@@ -280,8 +281,7 @@ class AquilaForCausalLM(nn.Module):
         return next_tokens

     _column_parallel_weights = [
-        "embed_tokens.weight", "lm_head.weight", "qkv_proj.weight",
-        "gate_proj.weight", "up_proj.weight"
+        "qkv_proj.weight", "gate_proj.weight", "up_proj.weight"
     ]
     _row_parallel_weights = ["o_proj.weight", "down_proj.weight"]

@@ -309,16 +309,6 @@ class AquilaForCausalLM(nn.Module):
             if "rotary_emb.inv_freq" in name:
                 continue

-            if "embed_tokens" in name or "lm_head" in name:
-                param = state_dict[name]
-                # Consider padding in the vocab size.
-                padded_vocab_size = (param.shape[0] * tp_size)
-                num_extra_rows = padded_vocab_size - self.config.vocab_size
-                extra_rows = torch.empty(num_extra_rows,
-                                         loaded_weight.shape[1])
-                extra_rows = extra_rows.to(loaded_weight)
-                loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
-
             is_attention_weight = False
             for weight_name, shard_size, offset in attention_weight_specs:
                 if weight_name not in name:
@@ -356,6 +346,11 @@ class AquilaForCausalLM(nn.Module):
                     continue

             param = state_dict[name]
+            if "embed_tokens" in name or "lm_head" in name:
+                load_padded_tensor_parallel_vocab(param, loaded_weight,
+                                                  tensor_model_parallel_rank)
+                continue
+
             load_tensor_parallel_weights(param, loaded_weight, name,
                                          self._column_parallel_weights,
                                          self._row_parallel_weights,
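
For context on what got faster: the deleted branch above materialized the vocab padding before sharding. It allocated num_extra_rows of uninitialized data, concatenated them onto the full embedding/lm_head tensor, and then let the generic column-parallel loader slice most of that copy away again. The replacement helper, load_padded_tensor_parallel_vocab (added in the weight_utils changes further down), slices first and copies only the rows that actually exist in the checkpoint. A rough standalone sketch with made-up sizes, not taken from the diff:

import torch

# Hypothetical sizes: a 32000-token vocab padded to 32256 so it divides
# evenly across tp_size=2 ranks; the hidden size is kept tiny for the sketch.
vocab_size, hidden_size, tp_size, rank = 32000, 8, 2, 1
padded_vocab_size = 32256
shard_size = padded_vocab_size // tp_size                 # 16128 rows per rank

checkpoint_weight = torch.rand(vocab_size, hidden_size)   # stand-in for the HF tensor
param = torch.empty(shard_size, hidden_size)              # this rank's model shard

# Old path: build a padded copy of the whole tensor, then slice the shard out.
extra_rows = torch.empty(padded_vocab_size - vocab_size, hidden_size)
padded = torch.cat([checkpoint_weight, extra_rows], dim=0)
param.copy_(padded[rank * shard_size:(rank + 1) * shard_size])

# New path: slice the original tensor and copy only the rows that exist;
# the padded tail of `param` is simply left untouched.
shard = checkpoint_weight[rank * shard_size:(rank + 1) * shard_size]
param[:shard.shape[0]].copy_(shard)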


@@ -32,10 +32,12 @@ from vllm.sequence import SequenceOutputs
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.attention import PagedAttentionWithRoPE, PagedAttentionWithALiBi
+from vllm.model_executor.layers.attention import (PagedAttentionWithRoPE,
+                                                  PagedAttentionWithALiBi)
 from vllm.model_executor.layers.sampler import Sampler
-from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
-                                              load_tensor_parallel_weights)
+from vllm.model_executor.weight_utils import (
+    hf_model_weights_iterator, load_padded_tensor_parallel_vocab,
+    load_tensor_parallel_weights)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
 from vllm.model_executor.parallel_utils.tensor_parallel import (
@@ -295,10 +297,7 @@ class BaiChuanBaseForCausalLM(nn.Module):
                                    input_metadata)
         return next_tokens

-    _column_parallel_weights = [
-        "embed_tokens.weight",
-        "lm_head.weight",
-    ]
+    _column_parallel_weights = []
     _row_parallel_weights = ["o_proj.weight", "down_proj.weight"]

     def load_weights(self,
@@ -314,16 +313,6 @@ class BaiChuanBaseForCausalLM(nn.Module):
             if "rotary_emb.inv_freq" in name:
                 continue

-            if "embed_tokens" in name or "lm_head" in name:
-                # Consider padding in the vocab size.
-                param = state_dict[name]
-                padded_vocab_size = param.shape[0] * tp_world_size
-                num_extra_rows = padded_vocab_size - self.config.vocab_size
-                extra_rows = torch.empty(num_extra_rows,
-                                         loaded_weight.shape[1])
-                extra_rows = extra_rows.to(loaded_weight)
-                loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
-
             if "W_pack" in name:
                 total_num_heads = self.config.num_attention_heads
                 hidden_size = self.config.hidden_size
@@ -355,6 +344,12 @@ class BaiChuanBaseForCausalLM(nn.Module):
                     continue

             param = state_dict[name]
+
+            if "embed_tokens" in name or "lm_head" in name:
+                load_padded_tensor_parallel_vocab(param, loaded_weight,
+                                                  tp_rank)
+                continue
+
             load_tensor_parallel_weights(
                 param,
                 loaded_weight,


@@ -31,8 +31,9 @@ from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.attention import PagedAttention
 from vllm.model_executor.layers.sampler import Sampler
-from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
-                                              load_tensor_parallel_weights)
+from vllm.model_executor.weight_utils import (
+    hf_model_weights_iterator, load_padded_tensor_parallel_vocab,
+    load_tensor_parallel_weights)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
 from vllm.model_executor.parallel_utils.tensor_parallel import (
@@ -224,7 +225,7 @@ class GPT2LMHeadModel(nn.Module):
                                    input_metadata)
         return next_tokens

-    _column_parallel_weights = ["wte.weight", "c_fc.weight", "c_fc.bias"]
+    _column_parallel_weights = ["c_fc.weight", "c_fc.bias"]
     _row_parallel_weights = ["c_proj.weight"]

     def load_weights(self,
@@ -261,14 +262,9 @@ class GPT2LMHeadModel(nn.Module):
             param = state_dict[name]

             if name == "transformer.wte.weight":
-                # Consider padding in the vocab size.
-                padded_vocab_size = (param.shape[0] *
-                                     tensor_model_parallel_world_size)
-                num_extra_rows = padded_vocab_size - self.config.vocab_size
-                extra_rows = torch.empty(num_extra_rows,
-                                         loaded_weight.shape[1])
-                extra_rows = extra_rows.to(loaded_weight)
-                loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
+                load_padded_tensor_parallel_vocab(param, loaded_weight,
+                                                  tensor_model_parallel_rank)
+                continue

             # For the fused QKV linear layer, manually shard the weights.
             if "c_attn" in name:


@@ -32,8 +32,9 @@ from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.attention import PagedAttention
 from vllm.model_executor.layers.sampler import Sampler
-from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
-                                              load_tensor_parallel_weights)
+from vllm.model_executor.weight_utils import (
+    hf_model_weights_iterator, load_padded_tensor_parallel_vocab,
+    load_tensor_parallel_weights)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
 from vllm.model_executor.parallel_utils.tensor_parallel import (
@@ -252,7 +253,7 @@ class GPTBigCodeForCausalLM(nn.Module):
                                    input_metadata)
         return next_tokens

-    _column_parallel_weights = ["wte.weight", "c_fc.weight", "c_fc.bias"]
+    _column_parallel_weights = ["c_fc.weight", "c_fc.bias"]
     _row_parallel_weights = ["c_proj.weight"]

     def load_weights(self,
@@ -328,14 +329,9 @@ class GPTBigCodeForCausalLM(nn.Module):
             param = state_dict[name]

             if name == "transformer.wte.weight":
-                # Consider padding in the vocab size.
-                padded_vocab_size = param.shape[
-                    0] * tensor_model_parallel_world_size
-                num_extra_rows = padded_vocab_size - self.config.vocab_size
-                extra_rows = torch.empty(num_extra_rows,
-                                         loaded_weight.shape[1])
-                extra_rows = extra_rows.to(loaded_weight)
-                loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
+                load_padded_tensor_parallel_vocab(param, loaded_weight,
+                                                  tensor_model_parallel_rank)
+                continue

             load_tensor_parallel_weights(param, loaded_weight, name,
                                          self._column_parallel_weights,


@@ -14,8 +14,9 @@ from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
 from vllm.model_executor.parallel_utils.tensor_parallel import (
     ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding)
-from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
-                                              load_tensor_parallel_weights)
+from vllm.model_executor.weight_utils import (
+    hf_model_weights_iterator, load_padded_tensor_parallel_vocab,
+    load_tensor_parallel_weights)
 from vllm.sequence import SequenceOutputs

 KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -225,8 +226,7 @@ class InternLMForCausalLM(nn.Module):
         return next_tokens

     _column_parallel_weights = [
-        "embed_tokens.weight", "lm_head.weight", "qkv_proj.weight",
-        "gate_proj.weight", "up_proj.weight"
+        "qkv_proj.weight", "gate_proj.weight", "up_proj.weight"
     ]
     _row_parallel_weights = ["o_proj.weight", "down_proj.weight"]

@@ -234,8 +234,6 @@ class InternLMForCausalLM(nn.Module):
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
                      use_np_cache: bool = False):
-        tensor_model_parallel_world_size = (
-            get_tensor_model_parallel_world_size())
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()

@@ -246,14 +244,9 @@ class InternLMForCausalLM(nn.Module):
             if "embed_tokens" in name or "lm_head" in name:
                 param = state_dict[name]
-                # Consider padding in the vocab size.
-                padded_vocab_size = (param.shape[0] *
-                                     tensor_model_parallel_world_size)
-                num_extra_rows = padded_vocab_size - self.config.vocab_size
-                extra_rows = torch.empty(num_extra_rows,
-                                         loaded_weight.shape[1])
-                extra_rows = extra_rows.to(loaded_weight)
-                loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
+                load_padded_tensor_parallel_vocab(param, loaded_weight,
+                                                  tensor_model_parallel_rank)
+                continue

             is_attention_weight = False
             for stride_id, att_weight_name in enumerate(


@@ -36,8 +36,9 @@ from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
 from vllm.model_executor.layers.sampler import Sampler
-from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
-                                              load_tensor_parallel_weights)
+from vllm.model_executor.weight_utils import (
+    load_tensor_parallel_weights, load_padded_tensor_parallel_vocab,
+    hf_model_weights_iterator)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
 from vllm.model_executor.parallel_utils.tensor_parallel import (
@@ -263,15 +264,15 @@ class LlamaForCausalLM(nn.Module):
         return next_tokens

     _column_parallel_weights = [
-        "embed_tokens.weight", "lm_head.weight", "qkv_proj.weight",
-        "gate_proj.weight", "up_proj.weight"
+        "qkv_proj.weight", "gate_proj.weight", "up_proj.weight"
     ]
     _row_parallel_weights = ["o_proj.weight", "down_proj.weight"]

     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     use_np_cache: bool = False):
+                     use_np_cache: bool = False,
+                     use_safetensor: bool = True):
         tp_size = get_tensor_model_parallel_world_size()
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         q_proj_shard_size = (self.config.hidden_size // tp_size)
@@ -288,20 +289,10 @@ class LlamaForCausalLM(nn.Module):
         state_dict = self.state_dict()

         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, use_np_cache):
+                model_name_or_path, cache_dir, use_np_cache, use_safetensor):
             if "rotary_emb.inv_freq" in name:
                 continue

-            if "embed_tokens" in name or "lm_head" in name:
-                param = state_dict[name]
-                # Consider padding in the vocab size.
-                padded_vocab_size = (param.shape[0] * tp_size)
-                num_extra_rows = padded_vocab_size - self.config.vocab_size
-                extra_rows = torch.empty(num_extra_rows,
-                                         loaded_weight.shape[1])
-                extra_rows = extra_rows.to(loaded_weight)
-                loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
-
             is_attention_weight = False
             for weight_name, shard_size, offset in attention_weight_specs:
                 if weight_name not in name:
@@ -339,6 +330,12 @@ class LlamaForCausalLM(nn.Module):
                     continue

             param = state_dict[name]
+
+            if "embed_tokens" in name or "lm_head" in name:
+                load_padded_tensor_parallel_vocab(param, loaded_weight,
+                                                  tensor_model_parallel_rank)
+                continue
+
             load_tensor_parallel_weights(param, loaded_weight, name,
                                          self._column_parallel_weights,
                                          self._row_parallel_weights,
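
LlamaForCausalLM now defaults to the safetensors path: load_weights grows a use_safetensor=True parameter and forwards it to hf_model_weights_iterator. The new iterator argument can also be exercised on its own, without building the model; a minimal sketch (the model id is illustrative, and the call downloads the checkpoint on first use):

from vllm.model_executor.weight_utils import hf_model_weights_iterator

# Iterate checkpoint tensors, preferring *.safetensors and falling back to
# *.bin when the repo does not publish safetensors files.
for name, weight in hf_model_weights_iterator("huggyllama/llama-7b",
                                              cache_dir=None,
                                              use_np_cache=False,
                                              use_safetensor=True):
    # With safetensors, `weight` is a lazy PySafeSlice handle; with *.bin it
    # is a torch.Tensor already materialized in CPU memory.
    print(name, type(weight).__name__)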


@@ -20,6 +20,7 @@ from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.weight_utils import (
     hf_model_weights_iterator,
+    load_padded_tensor_parallel_vocab,
     load_tensor_parallel_weights,
 )
 from vllm.model_executor.parallel_utils.parallel_state import (
@@ -241,7 +242,7 @@ class QWenLMHeadModel(nn.Module):
                                    input_metadata)
         return next_tokens

-    _column_parallel_weights = ["wte.weight", "lm_head.weight"]
+    _column_parallel_weights = []
     _row_parallel_weights = ["c_proj.weight"]

     def load_weights(
@@ -259,16 +260,6 @@ class QWenLMHeadModel(nn.Module):
             if "rotary_emb.inv_freq" in name:
                 continue

-            if "wte" in name or "lm_head" in name:
-                # Consider padding in the vocab size.
-                param = state_dict[name]
-                padded_vocab_size = param.shape[0] * tp_world_size
-                num_extra_rows = padded_vocab_size - self.config.vocab_size
-                extra_rows = torch.empty(num_extra_rows,
-                                         loaded_weight.shape[1])
-                extra_rows = extra_rows.to(loaded_weight)
-                loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
-
             if "c_attn" in name:
                 total_num_heads = self.config.num_attention_heads
                 hidden_size = self.config.hidden_size
@@ -306,6 +297,12 @@ class QWenLMHeadModel(nn.Module):
                     continue

             param = state_dict[name]
+
+            if "wte" in name or "lm_head" in name:
+                load_padded_tensor_parallel_vocab(param, loaded_weight,
+                                                  tp_rank)
+                continue
+
             load_tensor_parallel_weights(
                 param,
                 loaded_weight,


@@ -3,13 +3,19 @@ import filelock
 import glob
 import json
 import os
-from typing import Iterator, List, Optional, Tuple
+from collections import defaultdict
+from typing import Iterator, List, Optional, Tuple, Any

 from huggingface_hub import snapshot_download
+from safetensors.torch import load_file, save_file, safe_open
 import numpy as np
 import torch
 from tqdm.auto import tqdm

+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+

 class Disabledtqdm(tqdm):
@@ -17,43 +23,118 @@ class Disabledtqdm(tqdm):
         super().__init__(*args, **kwargs, disable=True)


-def hf_model_weights_iterator(
-    model_name_or_path: str,
-    cache_dir: Optional[str] = None,
-    use_np_cache: bool = False,
-) -> Iterator[Tuple[str, torch.Tensor]]:
-    # Prepare file lock directory to prevent multiple processes from
-    # downloading the same model weights at the same time.
+def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None):
     lock_dir = cache_dir if cache_dir is not None else "/tmp"
     lock_file_name = model_name_or_path.replace("/", "-") + ".lock"
     lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name))
+    return lock
+
+
+def _shared_pointers(tensors):
+    ptrs = defaultdict(list)
+    for k, v in tensors.items():
+        ptrs[v.data_ptr()].append(k)
+    failing = []
+    for _, names in ptrs.items():
+        if len(names) > 1:
+            failing.append(names)
+    return failing
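
_shared_pointers exists because safetensors refuses to serialize aliased tensors (for example tied input and output embeddings), so the converter below keeps only the first name of each aliased group. A quick illustration with toy tensors, not part of the diff:

import torch
from vllm.model_executor.weight_utils import _shared_pointers

weight = torch.rand(10, 4)
state = {
    "model.embed_tokens.weight": weight,   # tied weights share one storage
    "lm_head.weight": weight,
    "model.norm.weight": torch.ones(4),
}
print(_shared_pointers(state))
# -> [['model.embed_tokens.weight', 'lm_head.weight']]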
+
+
+def convert_bin_to_safetensor_file(
+    pt_filename: str,
+    sf_filename: str,
+):
+    loaded = torch.load(pt_filename, map_location="cpu")
+    if "state_dict" in loaded:
+        loaded = loaded["state_dict"]
+    shared = _shared_pointers(loaded)
+    for shared_weights in shared:
+        for name in shared_weights[1:]:
+            loaded.pop(name)
+
+    # For tensors to be contiguous
+    loaded = {k: v.contiguous() for k, v in loaded.items()}
+
+    dirname = os.path.dirname(sf_filename)
+    os.makedirs(dirname, exist_ok=True)
+    save_file(loaded, sf_filename, metadata={"format": "pt"})
+
+    # check file size
+    sf_size = os.stat(sf_filename).st_size
+    pt_size = os.stat(pt_filename).st_size
+    if (sf_size - pt_size) / pt_size > 0.01:
+        raise RuntimeError(f"""The file size different is more than 1%:
+         - {sf_filename}: {sf_size}
+         - {pt_filename}: {pt_size}
+         """)
+
+    # check if the tensors are the same
+    reloaded = load_file(sf_filename)
+    for k in loaded:
+        pt_tensor = loaded[k]
+        sf_tensor = reloaded[k]
+        if not torch.equal(pt_tensor, sf_tensor):
+            raise RuntimeError(f"The output tensors do not match for key {k}")
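
A hedged usage sketch of the converter (the paths are placeholders): it writes the safetensors file, then fails loudly if the result is more than 1% larger than the source shard or if any tensor does not round-trip exactly.

from vllm.model_executor.weight_utils import convert_bin_to_safetensor_file

# Convert one PyTorch checkpoint shard (hypothetical paths).
convert_bin_to_safetensor_file(
    "/models/my-llama/pytorch_model-00001-of-00002.bin",
    "/models/my-llama/model-00001-of-00002.safetensors",
)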
+
+
+def prepare_hf_model_weights(
+    model_name_or_path: str,
+    cache_dir: Optional[str] = None,
+    use_safetensor: bool = False,
+):
     # Download model weights from huggingface.
     is_local = os.path.isdir(model_name_or_path)
+    allow_patterns = "*.safetensors" if use_safetensor else "*.bin"
     if not is_local:
-        with lock:
+        # Use file lock to prevent multiple processes from
+        # downloading the same model weights at the same time.
+        with get_lock(model_name_or_path, cache_dir):
             hf_folder = snapshot_download(model_name_or_path,
-                                          allow_patterns="*.bin",
+                                          allow_patterns=allow_patterns,
                                           cache_dir=cache_dir,
                                           tqdm_class=Disabledtqdm)
     else:
         hf_folder = model_name_or_path
+    hf_weights_files = glob.glob(os.path.join(hf_folder, allow_patterns))
+    if not use_safetensor:
+        hf_weights_files = [
+            x for x in hf_weights_files if not x.endswith("training_args.bin")
+        ]

-    hf_bin_files = [
-        x for x in glob.glob(os.path.join(hf_folder, "*.bin"))
-        if not x.endswith("training_args.bin")
-    ]
+    if len(hf_weights_files) == 0 and use_safetensor:
+        logger.warning("No *.safetensors files found, "
+                       "fall back to *.bin files")
+        return prepare_hf_model_weights(model_name_or_path,
+                                        cache_dir=cache_dir,
+                                        use_safetensor=False)
+
+    return hf_folder, hf_weights_files, use_safetensor
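
prepare_hf_model_weights centralizes the download, the file lock, and the file discovery, and reports which format it actually found, so model code no longer globs for *.bin itself. Roughly, with an illustrative model id:

from vllm.model_executor.weight_utils import prepare_hf_model_weights

hf_folder, weight_files, found_safetensor = prepare_hf_model_weights(
    "facebook/opt-125m", cache_dir=None, use_safetensor=True)
# `weight_files` lists *.safetensors paths, or *.bin paths after the fallback;
# `found_safetensor` tells the caller which branch of the iterator to take.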
+
+
+def hf_model_weights_iterator(
+    model_name_or_path: str,
+    cache_dir: Optional[str] = None,
+    use_np_cache: bool = False,
+    use_safetensor: bool = False,
+) -> Iterator[Tuple[str, torch.Tensor]]:
+    hf_folder, hf_weights_files, use_safetensor = prepare_hf_model_weights(
+        model_name_or_path, cache_dir=cache_dir, use_safetensor=use_safetensor)
+
     if use_np_cache:
+        # Currently np_cache only support *.bin checkpoints
+        assert use_safetensor is False
+
         # Convert the model weights from torch tensors to numpy arrays for
         # faster loading.
         np_folder = os.path.join(hf_folder, "np")
         os.makedirs(np_folder, exist_ok=True)
         weight_names_file = os.path.join(np_folder, "weight_names.json")
-        with lock:
+        # Use file lock to prevent multiple processes from
+        # dumping the same model weights to numpy at the same time.
+        with get_lock(model_name_or_path, cache_dir):
             if not os.path.exists(weight_names_file):
                 weight_names = []
-                for bin_file in hf_bin_files:
+                for bin_file in hf_weights_files:
                     state = torch.load(bin_file, map_location="cpu")
                     for name, param in state.items():
                         param_path = os.path.join(np_folder, name)
@@ -71,8 +152,14 @@ def hf_model_weights_iterator(
             with open(param_path, "rb") as f:
                 param = np.load(f)
             yield name, torch.from_numpy(param)
+    elif use_safetensor:
+        for st_file in hf_weights_files:
+            with safe_open(st_file, framework="pt") as f:
+                for name in f.keys():
+                    param = f.get_slice(name)
+                    yield name, param
     else:
-        for bin_file in hf_bin_files:
+        for bin_file in hf_weights_files:
             state = torch.load(bin_file, map_location="cpu")
             for name, param in state.items():
                 yield name, param
@@ -80,9 +167,26 @@ def hf_model_weights_iterator(
     torch.cuda.empty_cache()


+def load_padded_tensor_parallel_vocab(
+    param: torch.Tensor,
+    loaded_weight: Any,  # `torch.Tensor` or `PySafeSlice`
+    tensor_model_parallel_rank: int,
+) -> None:
+    shard_size = param.shape[0]
+    start_idx = tensor_model_parallel_rank * shard_size
+    end_idx = (tensor_model_parallel_rank + 1) * shard_size
+    loaded_weight = loaded_weight[start_idx:end_idx]
+    # convert PySafeSlice object to torch.Tensor
+    if not isinstance(loaded_weight, torch.Tensor):
+        loaded_weight = loaded_weight[:]
+    param[:loaded_weight.shape[0]].copy_(loaded_weight)
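
The loaded_weight passed in here may be a PySafeSlice handle rather than a materialized tensor, and slicing it reads only the requested rows from disk; that lazy read, combined with skipping the padding concatenation, is where most of the speed-up for the large vocab weights comes from. A small self-contained sketch (toy file path and sizes; safe_open is imported from the top-level safetensors package here):

import torch
from safetensors import safe_open
from safetensors.torch import save_file
from vllm.model_executor.weight_utils import load_padded_tensor_parallel_vocab

# Write a toy checkpoint, then load only rank 1's half of the embedding.
save_file({"embed_tokens.weight": torch.rand(10, 4)}, "/tmp/toy.safetensors")

param = torch.empty(5, 4)  # padded shard for tp_size=2 (10 rows / 2 ranks)
with safe_open("/tmp/toy.safetensors", framework="pt") as f:
    weight_slice = f.get_slice("embed_tokens.weight")  # no tensor data read yet
    load_padded_tensor_parallel_vocab(param, weight_slice, 1)
# Only rows 5..10 were read from disk and copied into `param`.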
+
+
 def load_tensor_parallel_weights(
     param: torch.Tensor,
-    loaded_weight: torch.Tensor,
+    loaded_weight: Any,  # `torch.Tensor` or `PySafeSlice`
     param_name: str,
     column_parallel_weight_names: List[str],
     row_parallel_weight_names: List[str],
@@ -102,6 +206,11 @@ def load_tensor_parallel_weights(
             end_idx = (tensor_model_parallel_rank + 1) * shard_size
             loaded_weight = loaded_weight[:, start_idx:end_idx]
             break
+
+    # convert PySafeSlice object to torch.Tensor
+    if not isinstance(loaded_weight, torch.Tensor):
+        loaded_weight = loaded_weight[:]
+
     assert param.shape == loaded_weight.shape, (
         f"{param_name} shape mismatch between model and checkpoint: "
         f"{param.shape} != {loaded_weight.shape}")
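
Because load_tensor_parallel_weights may now also receive a PySafeSlice, the conversion to a real tensor happens only after the row or column slice has been taken, so at most one shard's worth of data is materialized per weight. A toy column-parallel illustration (the shapes and the parameter name are made up; the weight-name lists mirror the models above):

import torch
from vllm.model_executor.weight_utils import load_tensor_parallel_weights

full_weight = torch.rand(16, 8)   # checkpoint tensor, output dim = 16
param = torch.empty(8, 8)         # this rank's shard for tp_size = 2

load_tensor_parallel_weights(param, full_weight,
                             "model.layers.0.mlp.gate_proj.weight",
                             ["gate_proj.weight"],      # column-parallel names
                             ["down_proj.weight"],      # row-parallel names
                             1)                         # tensor-parallel rank
# `param` now holds rows 8..16 of the checkpoint tensor.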