Enable safetensors loading for all models (#974)

Zhuohan Li 2023-09-07 15:49:52 -07:00 committed by GitHub
parent c07ece5ca4
commit c957c741d9
18 changed files with 143 additions and 83 deletions

View File

@@ -24,9 +24,16 @@ class ModelConfig:
downloading the model and tokenizer.
download_dir: Directory to download and load the weights; defaults to
the default cache directory of huggingface.
use_np_weights: Save a numpy copy of model weights for faster loading.
This can increase the disk usage by up to 2x.
use_dummy_weights: Use dummy values for model weights (for profiling).
load_format: The format of the model weights to load:
"auto" will try to load the weights in the safetensors format and
fall back to the pytorch bin format if safetensors format is
not available.
"pt" will load the weights in the pytorch bin format.
"safetensors" will load the weights in the safetensors format.
"npcache" will load the weights in pytorch format and store
a numpy cache to speed up the loading.
"dummy" will initialize the weights with random values, which is
mainly for profiling.
dtype: Data type for model weights and activations. The "auto" option
will use FP16 precision for FP32 and FP16 models, and BF16 precision
for BF16 models.
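A minimal sketch of how the new option surfaces through the Python API (the model name is illustrative; `LLM` forwards extra keyword arguments to EngineArgs):

    from vllm import LLM

    # Force safetensors checkpoints; the default "auto" tries safetensors
    # first and falls back to the pytorch bin format.
    llm = LLM(model="facebook/opt-125m", load_format="safetensors")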
@@ -40,8 +47,7 @@ class ModelConfig:
tokenizer_mode: str,
trust_remote_code: bool,
download_dir: Optional[str],
use_np_weights: bool,
use_dummy_weights: bool,
load_format: str,
dtype: str,
seed: int,
) -> None:
@@ -50,14 +56,24 @@ class ModelConfig:
self.tokenizer_mode = tokenizer_mode
self.trust_remote_code = trust_remote_code
self.download_dir = download_dir
self.use_np_weights = use_np_weights
self.use_dummy_weights = use_dummy_weights
self.load_format = load_format
self.seed = seed
self.hf_config = get_config(model, trust_remote_code)
self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
self._verify_load_format()
self._verify_tokenizer_mode()
def _verify_load_format(self) -> None:
load_format = self.load_format.lower()
if load_format not in [
"auto", "pt", "safetensors", "npcache", "dummy"
]:
raise ValueError(
f"Unknown load format: {self.load_format}. Must be one of "
"'auto', 'pt', 'safetensors', 'npcache', or 'dummy'.")
self.load_format = load_format
def _verify_tokenizer_mode(self) -> None:
tokenizer_mode = self.tokenizer_mode.lower()
if tokenizer_mode not in ["auto", "slow"]:

View File

@@ -15,8 +15,7 @@ class EngineArgs:
tokenizer_mode: str = 'auto'
trust_remote_code: bool = False
download_dir: Optional[str] = None
use_np_weights: bool = False
use_dummy_weights: bool = False
load_format: str = 'auto'
dtype: str = 'auto'
seed: int = 0
worker_use_ray: bool = False
@@ -65,14 +64,21 @@ class EngineArgs:
help='directory to download and load the weights, '
'defaults to the default cache dir of '
'huggingface')
parser.add_argument('--use-np-weights',
action='store_true',
help='save a numpy copy of model weights for '
'faster loading. This can increase the disk '
'usage by up to 2x.')
parser.add_argument('--use-dummy-weights',
action='store_true',
help='use dummy values for model weights')
parser.add_argument(
'--load-format',
type=str,
default=EngineArgs.load_format,
choices=['auto', 'pt', 'safetensors', 'npcache', 'dummy'],
help='The format of the model weights to load. '
'"auto" will try to load the weights in the safetensors format '
'and fall back to the pytorch bin format if safetensors format '
'is not available. '
'"pt" will load the weights in the pytorch bin format. '
'"safetensors" will load the weights in the safetensors format. '
'"npcache" will load the weights in pytorch format and store '
'a numpy cache to speed up the loading. '
'"dummy" will initialize the weights with random values, '
'which is mainly for profiling.')
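For reference, a hedged sketch of the flag flowing from the CLI into EngineArgs (the model name and argument values are illustrative):

    import argparse
    from vllm.engine.arg_utils import EngineArgs

    parser = argparse.ArgumentParser()
    parser = EngineArgs.add_cli_args(parser)
    args = parser.parse_args(
        ["--model", "facebook/opt-125m", "--load-format", "npcache"])
    engine_args = EngineArgs.from_cli_args(args)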
# TODO(woosuk): Support FP32.
parser.add_argument(
'--dtype',
@@ -146,9 +152,8 @@ class EngineArgs:
# Initialize the configs.
model_config = ModelConfig(self.model, self.tokenizer,
self.tokenizer_mode, self.trust_remote_code,
self.download_dir, self.use_np_weights,
self.use_dummy_weights, self.dtype,
self.seed)
self.download_dir, self.load_format,
self.dtype, self.seed)
cache_config = CacheConfig(self.block_size,
self.gpu_memory_utilization,
self.swap_space)

View File

@@ -76,9 +76,8 @@ class LLMEngine:
f"tokenizer_mode={model_config.tokenizer_mode}, "
f"trust_remote_code={model_config.trust_remote_code}, "
f"dtype={model_config.dtype}, "
f"use_dummy_weights={model_config.use_dummy_weights}, "
f"download_dir={model_config.download_dir!r}, "
f"use_np_weights={model_config.use_np_weights}, "
f"load_format={model_config.load_format}, "
f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
f"seed={model_config.seed})")
# TODO(woosuk): Print more configs in debug mode.

View File

@@ -56,7 +56,7 @@ def get_model(model_config: ModelConfig) -> nn.Module:
# Create a model instance.
# The weights will be initialized as empty tensors.
model = model_class(model_config.hf_config)
if model_config.use_dummy_weights:
if model_config.load_format == "dummy":
model = model.cuda()
# NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights.
@@ -64,6 +64,6 @@ def get_model(model_config: ModelConfig) -> nn.Module:
else:
# Load the weights from the cached or downloaded files.
model.load_weights(model_config.model, model_config.download_dir,
model_config.use_np_weights)
model_config.load_format)
model = model.cuda()
return model.eval()
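The "dummy" branch delegates to a weight-randomization helper; conceptually it amounts to the following sketch (not the exact helper):

    import torch.nn as nn

    def randomize_for_profiling(model: nn.Module) -> None:
        # Fill the empty tensors with small real values so profiling
        # exercises realistic kernels; accuracy is irrelevant here.
        for param in model.parameters():
            param.data.uniform_(-1e-3, 1e-3)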

View File

@@ -288,7 +288,7 @@ class AquilaForCausalLM(nn.Module):
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
use_np_cache: bool = False):
load_format: str = "auto"):
tp_size = get_tensor_model_parallel_world_size()
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
q_proj_shard_size = (self.config.hidden_size // tp_size)
@@ -305,7 +305,7 @@ class AquilaForCausalLM(nn.Module):
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache):
model_name_or_path, cache_dir, load_format):
if "rotary_emb.inv_freq" in name:
continue

View File

@@ -35,8 +35,8 @@ from vllm.model_executor.layers.attention import (PagedAttentionWithRoPE,
PagedAttentionWithALiBi)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (
hf_model_weights_iterator, load_padded_tensor_parallel_vocab,
load_tensor_parallel_weights)
convert_pyslice_to_tensor, hf_model_weights_iterator,
load_padded_tensor_parallel_vocab, load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
@@ -303,16 +303,18 @@ class BaiChuanBaseForCausalLM(nn.Module):
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
use_np_cache: bool = False):
load_format: str = "auto"):
tp_world_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache):
model_name_or_path, cache_dir, load_format):
if "rotary_emb.inv_freq" in name:
continue
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
if "W_pack" in name:
total_num_heads = self.config.num_attention_heads
hidden_size = self.config.hidden_size

View File

@@ -279,11 +279,11 @@ class BloomForCausalLM(nn.Module):
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
use_np_cache: bool = False):
load_format: str = "auto"):
tp_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache):
model_name_or_path, cache_dir, load_format):
if name == "lm_head.weight":
# Since hidden_states are parallelized, we need to
# load lm_head.weight in parallel.
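The parallel lm_head load mentioned in this comment means each tensor-parallel rank copies only its rows of the vocabulary matrix; a sketch with illustrative sizes:

    import torch

    vocab_size, hidden_size, tp_rank, tp_size = 250880, 4096, 0, 2
    lm_head = torch.randn(vocab_size, hidden_size)  # full checkpoint weight
    shard = vocab_size // tp_size
    local = lm_head[tp_rank * shard:(tp_rank + 1) * shard]  # rows for rank 0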

View File

@@ -31,7 +31,8 @@ from vllm.model_executor.layers.attention import (PagedAttention,
PagedAttentionWithALiBi,
PagedAttentionWithRoPE)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
from vllm.model_executor.weight_utils import (convert_pyslice_to_tensor,
hf_model_weights_iterator,
load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
@@ -419,7 +420,7 @@ class FalconForCausalLM(nn.Module):
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
use_np_cache: bool = False):
load_format: str = "auto"):
tp_size = (get_tensor_model_parallel_world_size())
tp_rank = get_tensor_model_parallel_rank()
@@ -451,8 +452,9 @@
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache):
model_name_or_path, cache_dir, load_format):
if "query_key_value" in name:
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
loaded_weight_size = loaded_weight.size()
loaded_weight = loaded_weight.view(
total_num_kv_heads, num_query_heads_per_kv_head + 2,

View File

@@ -32,8 +32,8 @@ from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (
hf_model_weights_iterator, load_padded_tensor_parallel_vocab,
load_tensor_parallel_weights)
convert_pyslice_to_tensor, hf_model_weights_iterator,
load_padded_tensor_parallel_vocab, load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
@@ -231,14 +231,14 @@ class GPT2LMHeadModel(nn.Module):
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
use_np_cache: bool = False):
load_format: str = "auto"):
tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache):
model_name_or_path, cache_dir, load_format):
if "lm_head.weight" in name:
# GPT-2 ties the weights of the embedding layer and the final
# linear layer.
@@ -251,6 +251,8 @@ class GPT2LMHeadModel(nn.Module):
if not name.startswith("transformer."):
name = "transformer." + name
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
# HF's GPT-2 implementation uses Conv1D instead of Linear.
# Because of this, we need to transpose the weights.
for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
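The transpose this comment refers to can be shown in isolation: HF's Conv1D stores its weight as (in_features, out_features), the transpose of nn.Linear's layout (shapes illustrative):

    import torch

    conv1d_weight = torch.randn(768, 2304)  # GPT-2 c_attn, stored (in, out)
    linear_weight = conv1d_weight.t()       # layout a Linear layer expects
    assert linear_weight.shape == (2304, 768)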

View File

@@ -33,8 +33,8 @@ from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (
hf_model_weights_iterator, load_padded_tensor_parallel_vocab,
load_tensor_parallel_weights)
convert_pyslice_to_tensor, hf_model_weights_iterator,
load_padded_tensor_parallel_vocab, load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
@@ -259,14 +259,14 @@ class GPTBigCodeForCausalLM(nn.Module):
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
use_np_cache: bool = False):
load_format: str = "auto"):
tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache):
model_name_or_path, cache_dir, load_format):
if "lm_head.weight" in name:
# GPT-2 ties the weights of the embedding layer and the final
# linear layer.
@@ -295,6 +295,7 @@ class GPTBigCodeForCausalLM(nn.Module):
head_start = tensor_model_parallel_rank * num_heads
head_end = (tensor_model_parallel_rank + 1) * num_heads
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
wq, wk, wv = torch.split(
loaded_weight, [hidden_size, total_kv_size, total_kv_size],
dim=0)
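The torch.split above carves the fused multi-query attention weight into one large query block and two small shared key/value blocks; a sketch with illustrative dimensions:

    import torch

    hidden_size, total_kv_size = 6144, 128  # MQA: one shared KV head of size 128
    fused = torch.randn(hidden_size + 2 * total_kv_size, hidden_size)
    wq, wk, wv = torch.split(
        fused, [hidden_size, total_kv_size, total_kv_size], dim=0)
    # wq: (6144, 6144); wk and wv: (128, 6144)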

View File

@@ -222,11 +222,11 @@ class GPTJForCausalLM(nn.Module):
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
use_np_cache: bool = False):
load_format: str = "auto"):
tp_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache):
model_name_or_path, cache_dir, load_format):
if "attn.bias" in name or "attn.masked_bias" in name:
continue

View File

@@ -231,11 +231,11 @@ class GPTNeoXForCausalLM(nn.Module):
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
use_np_cache: bool = False):
load_format: str = "auto"):
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache):
model_name_or_path, cache_dir, load_format):
if ("attention.bias" in name or "attention.masked_bias" in name
or "rotary_emb.inv_freq" in name):
continue

View File

@@ -233,12 +233,12 @@ class InternLMForCausalLM(nn.Module):
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
use_np_cache: bool = False):
load_format: str = "auto"):
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache):
model_name_or_path, cache_dir, load_format):
if "rotary_emb.inv_freq" in name:
continue

View File

@@ -271,8 +271,7 @@ class LlamaForCausalLM(nn.Module):
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
use_np_cache: bool = False,
use_safetensor: bool = True):
load_format: str = "auto"):
tp_size = get_tensor_model_parallel_world_size()
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
q_proj_shard_size = (self.config.hidden_size // tp_size)
@@ -289,7 +288,7 @@ class LlamaForCausalLM(nn.Module):
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache, use_safetensor):
model_name_or_path, cache_dir, load_format):
if "rotary_emb.inv_freq" in name:
continue

View File

@@ -10,7 +10,8 @@ from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttentionWithALiBi
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
from vllm.model_executor.weight_utils import (convert_pyslice_to_tensor,
hf_model_weights_iterator,
load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
@@ -243,12 +244,12 @@ class MPTForCausalLM(nn.Module):
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
use_np_cache: bool = False):
load_format: str = "auto"):
tp_world_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache):
model_name_or_path, cache_dir, load_format):
if "Wqkv" in name:
# NOTE(woosuk): MPT's fused QKV has the shape of
# [3 * num_heads * head_size, hidden_size].
@@ -260,7 +261,7 @@ class MPTForCausalLM(nn.Module):
num_heads = total_num_heads // tp_world_size
head_start = tp_rank * num_heads
head_end = (tp_rank + 1) * num_heads
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
if name.endswith(".weight"):
loaded_weight = loaded_weight.view(3, total_num_heads,
head_size, hidden_size)
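The view in this hunk makes the per-head layout explicit so each tensor-parallel rank can pick out its heads; a sketch with illustrative sizes:

    import torch

    total_num_heads, head_size, hidden_size = 32, 128, 4096
    wqkv = torch.randn(3 * total_num_heads * head_size, hidden_size)
    wqkv = wqkv.view(3, total_num_heads, head_size, hidden_size)
    tp_rank, num_heads = 1, 8  # 4-way tensor parallelism
    mine = wqkv[:, tp_rank * num_heads:(tp_rank + 1) * num_heads]
    q, k, v = mine[0], mine[1], mine[2]  # each (8, 128, 4096)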

View File

@@ -297,12 +297,12 @@ class OPTForCausalLM(nn.Module):
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
use_np_cache: bool = False):
load_format: str = "auto"):
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache):
model_name_or_path, cache_dir, load_format):
if "lm_head.weight" in name:
continue

View File

@@ -19,6 +19,7 @@ from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (
convert_pyslice_to_tensor,
hf_model_weights_iterator,
load_padded_tensor_parallel_vocab,
load_tensor_parallel_weights,
@@ -249,17 +250,19 @@ class QWenLMHeadModel(nn.Module):
self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
use_np_cache: bool = False,
load_format: str = "auto",
):
tp_world_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache):
model_name_or_path, cache_dir, load_format):
if "rotary_emb.inv_freq" in name:
continue
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
if "c_attn" in name:
total_num_heads = self.config.num_attention_heads
hidden_size = self.config.hidden_size

View File

@@ -81,11 +81,12 @@ def convert_bin_to_safetensor_file(
def prepare_hf_model_weights(
model_name_or_path: str,
cache_dir: Optional[str] = None,
use_safetensor: bool = False,
use_safetensors: bool = False,
fall_back_to_pt: bool = True,
):
# Download model weights from huggingface.
is_local = os.path.isdir(model_name_or_path)
allow_patterns = "*.safetensors" if use_safetensor else "*.bin"
allow_patterns = "*.safetensors" if use_safetensors else "*.bin"
if not is_local:
# Use file lock to prevent multiple processes from
# downloading the same model weights at the same time.
@@ -97,32 +98,53 @@ def prepare_hf_model_weights(
else:
hf_folder = model_name_or_path
hf_weights_files = glob.glob(os.path.join(hf_folder, allow_patterns))
if not use_safetensor:
if not use_safetensors:
hf_weights_files = [
x for x in hf_weights_files if not x.endswith("training_args.bin")
]
if len(hf_weights_files) == 0 and use_safetensor:
logger.warning("No *.safetensors files found, "
"fall back to *.bin files")
if len(hf_weights_files) == 0 and use_safetensors and fall_back_to_pt:
return prepare_hf_model_weights(model_name_or_path,
cache_dir=cache_dir,
use_safetensor=False)
return hf_folder, hf_weights_files, use_safetensor
use_safetensors=False,
fall_back_to_pt=False)
if len(hf_weights_files) == 0:
raise RuntimeError(
f"Cannot find any model weights with `{model_name_or_path}`")
return hf_folder, hf_weights_files, use_safetensors
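A usage sketch of the fallback behavior (model name illustrative): with fall_back_to_pt=True, a repo without *.safetensors files is retried once with *.bin patterns, and the returned flag reports which format was used:

    folder, weight_files, using_safetensors = prepare_hf_model_weights(
        "facebook/opt-125m", use_safetensors=True, fall_back_to_pt=True)
    # If the repo ships only pytorch checkpoints, using_safetensors comes
    # back False and weight_files lists the *.bin files instead.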
def hf_model_weights_iterator(
model_name_or_path: str,
cache_dir: Optional[str] = None,
use_np_cache: bool = False,
use_safetensor: bool = False,
load_format: str = "auto",
) -> Iterator[Tuple[str, torch.Tensor]]:
hf_folder, hf_weights_files, use_safetensor = prepare_hf_model_weights(
model_name_or_path, cache_dir=cache_dir, use_safetensor=use_safetensor)
use_safetensors = False
use_np_cache = False
fall_back_to_pt = False
if load_format == "auto":
use_safetensors = True
fall_back_to_pt = True
elif load_format == "safetensors":
use_safetensors = True
elif load_format == "pt":
pass
elif load_format == "npcache":
use_np_cache = True
else:
raise ValueError(f"Unknown load_format: {load_format}")
hf_folder, hf_weights_files, use_safetensors = prepare_hf_model_weights(
model_name_or_path,
cache_dir=cache_dir,
use_safetensors=use_safetensors,
fall_back_to_pt=fall_back_to_pt)
if use_np_cache:
# Currently np_cache only supports *.bin checkpoints
assert use_safetensor is False
assert use_safetensors is False
# Convert the model weights from torch tensors to numpy arrays for
# faster loading.
@@ -152,7 +174,7 @@ def hf_model_weights_iterator(
with open(param_path, "rb") as f:
param = np.load(f)
yield name, torch.from_numpy(param)
elif use_safetensor:
elif use_safetensors:
for st_file in hf_weights_files:
with safe_open(st_file, framework="pt") as f:
for name in f.keys():
@@ -167,6 +189,21 @@ def hf_model_weights_iterator(
torch.cuda.empty_cache()
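A usage sketch of the iterator (model name illustrative): weights stream one at a time instead of materializing a full state dict:

    for name, weight in hf_model_weights_iterator("facebook/opt-125m",
                                                  load_format="auto"):
        tensor = convert_pyslice_to_tensor(weight)  # weight may be a PySafeSlice
        print(name, tuple(tensor.shape))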
def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
"""convert PySafeSlice object from safetensors to torch.Tensor
PySafeSlice object supports indexing, which is done before loading the
actual tensor and can reduce the amount of memory being read into the
memory. However, it does not support more advanced functionalities
like `.view()` or `.t()`. Therefore, if we need to modify the loaded
tensor with these more complicated operators, we need to convert to
tensor first.
"""
if not isinstance(x, torch.Tensor):
x = x[:]
return x
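The lazy slicing the docstring describes comes from safetensors' safe_open; a sketch (file and tensor names illustrative):

    from safetensors import safe_open

    with safe_open("model.safetensors", framework="pt") as f:
        sl = f.get_slice("lm_head.weight")  # PySafeSlice: no data read yet
        shard = sl[0:1024]                  # reads only these rows from disk
        shard = convert_pyslice_to_tensor(shard)
        shard = shard.t()                   # .t() requires a real tensor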
def load_padded_tensor_parallel_vocab(
param: torch.Tensor,
loaded_weight: Any, # `torch.Tensor` or `PySafeSlice`
@@ -176,11 +213,7 @@ def load_padded_tensor_parallel_vocab(
start_idx = tensor_model_parallel_rank * shard_size
end_idx = (tensor_model_parallel_rank + 1) * shard_size
loaded_weight = loaded_weight[start_idx:end_idx]
# convert PySafeSlice object to torch.Tensor
if not isinstance(loaded_weight, torch.Tensor):
loaded_weight = loaded_weight[:]
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
param[:loaded_weight.shape[0]].copy_(loaded_weight)
@@ -207,10 +240,7 @@ def load_tensor_parallel_weights(
loaded_weight = loaded_weight[:, start_idx:end_idx]
break
# convert PySafeSlice object to torch.Tensor
if not isinstance(loaded_weight, torch.Tensor):
loaded_weight = loaded_weight[:]
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
assert param.shape == loaded_weight.shape, (
f"{param_name} shape mismatch between model and checkpoint: "
f"{param.shape} != {loaded_weight.shape}")