diff --git a/vllm/config.py b/vllm/config.py
index 2e8d58411181..06ade925fe51 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -24,9 +24,16 @@ class ModelConfig:
             downloading the model and tokenizer.
         download_dir: Directory to download and load the weights, default to
             the default cache directory of huggingface.
-        use_np_weights: Save a numpy copy of model weights for faster loading.
-            This can increase the disk usage by up to 2x.
-        use_dummy_weights: Use dummy values for model weights (for profiling).
+        load_format: The format of the model weights to load:
+            "auto" will try to load the weights in the safetensors format and
+                fall back to the pytorch bin format if safetensors format is
+                not available.
+            "pt" will load the weights in the pytorch bin format.
+            "safetensors" will load the weights in the safetensors format.
+            "npcache" will load the weights in pytorch format and store
+                a numpy cache to speed up the loading.
+            "dummy" will initialize the weights with random values, which is
+                mainly for profiling.
         dtype: Data type for model weights and activations. The "auto" option
             will use FP16 precision for FP32 and FP16 models, and BF16
             precision for BF16 models.
@@ -40,8 +47,7 @@ class ModelConfig:
         tokenizer_mode: str,
         trust_remote_code: bool,
         download_dir: Optional[str],
-        use_np_weights: bool,
-        use_dummy_weights: bool,
+        load_format: str,
         dtype: str,
         seed: int,
     ) -> None:
@@ -50,14 +56,24 @@ class ModelConfig:
         self.tokenizer_mode = tokenizer_mode
         self.trust_remote_code = trust_remote_code
         self.download_dir = download_dir
-        self.use_np_weights = use_np_weights
-        self.use_dummy_weights = use_dummy_weights
+        self.load_format = load_format
         self.seed = seed
 
         self.hf_config = get_config(model, trust_remote_code)
         self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
+        self._verify_load_format()
         self._verify_tokenizer_mode()
 
+    def _verify_load_format(self) -> None:
+        load_format = self.load_format.lower()
+        if load_format not in [
+                "auto", "pt", "safetensors", "npcache", "dummy"
+        ]:
+            raise ValueError(
+                f"Unknown load format: {self.load_format}. Must be one of "
+                "'auto', 'pt', 'safetensors', 'npcache', or 'dummy'.")
+        self.load_format = load_format
+
     def _verify_tokenizer_mode(self) -> None:
         tokenizer_mode = self.tokenizer_mode.lower()
         if tokenizer_mode not in ["auto", "slow"]:
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 99fe593b4cb0..b775ec089dcf 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -15,8 +15,7 @@ class EngineArgs:
     tokenizer_mode: str = 'auto'
     trust_remote_code: bool = False
    download_dir: Optional[str] = None
-    use_np_weights: bool = False
-    use_dummy_weights: bool = False
+    load_format: str = 'auto'
    dtype: str = 'auto'
    seed: int = 0
    worker_use_ray: bool = False
@@ -65,14 +64,21 @@ class EngineArgs:
                             help='directory to download and load the weights, '
                             'default to the default cache dir of '
                             'huggingface')
-        parser.add_argument('--use-np-weights',
-                            action='store_true',
-                            help='save a numpy copy of model weights for '
-                            'faster loading. This can increase the disk '
-                            'usage by up to 2x.')
-        parser.add_argument('--use-dummy-weights',
-                            action='store_true',
-                            help='use dummy values for model weights')
+        parser.add_argument(
+            '--load-format',
+            type=str,
+            default=EngineArgs.load_format,
+            choices=['auto', 'pt', 'safetensors', 'npcache', 'dummy'],
+            help='The format of the model weights to load. '
+            '"auto" will try to load the weights in the safetensors format '
+            'and fall back to the pytorch bin format if safetensors format '
+            'is not available. '
+            '"pt" will load the weights in the pytorch bin format. '
+            '"safetensors" will load the weights in the safetensors format. '
+            '"npcache" will load the weights in pytorch format and store '
+            'a numpy cache to speed up the loading. '
+            '"dummy" will initialize the weights with random values, '
+            'which is mainly for profiling.')
         # TODO(woosuk): Support FP32.
         parser.add_argument(
             '--dtype',
@@ -146,9 +152,8 @@ class EngineArgs:
         # Initialize the configs.
         model_config = ModelConfig(self.model, self.tokenizer,
                                    self.tokenizer_mode, self.trust_remote_code,
-                                   self.download_dir, self.use_np_weights,
-                                   self.use_dummy_weights, self.dtype,
-                                   self.seed)
+                                   self.download_dir, self.load_format,
+                                   self.dtype, self.seed)
         cache_config = CacheConfig(self.block_size,
                                    self.gpu_memory_utilization,
                                    self.swap_space)
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 4ea443d8451d..784c0caf1999 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -76,9 +76,8 @@ class LLMEngine:
             f"tokenizer_mode={model_config.tokenizer_mode}, "
             f"trust_remote_code={model_config.trust_remote_code}, "
             f"dtype={model_config.dtype}, "
-            f"use_dummy_weights={model_config.use_dummy_weights}, "
             f"download_dir={model_config.download_dir!r}, "
-            f"use_np_weights={model_config.use_np_weights}, "
+            f"load_format={model_config.load_format}, "
             f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
             f"seed={model_config.seed})")
         # TODO(woosuk): Print more configs in debug mode.
diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py
index 4cddf8c360ee..71a326883b62 100644
--- a/vllm/model_executor/model_loader.py
+++ b/vllm/model_executor/model_loader.py
@@ -56,7 +56,7 @@ def get_model(model_config: ModelConfig) -> nn.Module:
     # Create a model instance.
     # The weights will be initialized as empty tensors.
     model = model_class(model_config.hf_config)
-    if model_config.use_dummy_weights:
+    if model_config.load_format == "dummy":
         model = model.cuda()
         # NOTE(woosuk): For accurate performance evaluation, we assign
         # random values to the weights.
@@ -64,6 +64,6 @@ def get_model(model_config: ModelConfig) -> nn.Module:
     else:
         # Load the weights from the cached or downloaded files.
         model.load_weights(model_config.model, model_config.download_dir,
-                           model_config.use_np_weights)
+                           model_config.load_format)
         model = model.cuda()
     return model.eval()
diff --git a/vllm/model_executor/models/aquila.py b/vllm/model_executor/models/aquila.py
index 189a31d36ca5..dcd849d722f2 100644
--- a/vllm/model_executor/models/aquila.py
+++ b/vllm/model_executor/models/aquila.py
@@ -288,7 +288,7 @@ class AquilaForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     use_np_cache: bool = False):
+                     load_format: str = "auto"):
         tp_size = get_tensor_model_parallel_world_size()
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         q_proj_shard_size = (self.config.hidden_size // tp_size)
@@ -305,7 +305,7 @@ class AquilaForCausalLM(nn.Module):
         state_dict = self.state_dict()
 
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, use_np_cache):
+                model_name_or_path, cache_dir, load_format):
             if "rotary_emb.inv_freq" in name:
                 continue
 
diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py
index 9ebe37d66a02..77ada1e76522 100644
--- a/vllm/model_executor/models/baichuan.py
+++ b/vllm/model_executor/models/baichuan.py
@@ -35,8 +35,8 @@ from vllm.model_executor.layers.attention import (PagedAttentionWithRoPE,
                                                   PagedAttentionWithALiBi)
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.weight_utils import (
-    hf_model_weights_iterator, load_padded_tensor_parallel_vocab,
-    load_tensor_parallel_weights)
+    convert_pyslice_to_tensor, hf_model_weights_iterator,
+    load_padded_tensor_parallel_vocab, load_tensor_parallel_weights)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
 from vllm.model_executor.parallel_utils.tensor_parallel import (
@@ -303,16 +303,18 @@ class BaiChuanBaseForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     use_np_cache: bool = False):
+                     load_format: str = "auto"):
         tp_world_size = get_tensor_model_parallel_world_size()
         tp_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
 
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, use_np_cache):
+                model_name_or_path, cache_dir, load_format):
             if "rotary_emb.inv_freq" in name:
                 continue
 
+            loaded_weight = convert_pyslice_to_tensor(loaded_weight)
+
             if "W_pack" in name:
                 total_num_heads = self.config.num_attention_heads
                 hidden_size = self.config.hidden_size
diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py
index c7e3ecf15b42..e17d8d075e14 100644
--- a/vllm/model_executor/models/bloom.py
+++ b/vllm/model_executor/models/bloom.py
@@ -279,11 +279,11 @@ class BloomForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     use_np_cache: bool = False):
+                     load_format: str = "auto"):
         tp_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, use_np_cache):
+                model_name_or_path, cache_dir, load_format):
             if name == "lm_head.weight":
                 # Since hidden_states are parallelized, we need to
                 # load lm_head.weight in parallel.
diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py
index 8a2a9fedfff0..883faf636269 100644
--- a/vllm/model_executor/models/falcon.py
+++ b/vllm/model_executor/models/falcon.py
@@ -31,7 +31,8 @@ from vllm.model_executor.layers.attention import (PagedAttention,
                                                   PagedAttentionWithALiBi,
                                                   PagedAttentionWithRoPE)
 from vllm.model_executor.layers.sampler import Sampler
-from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
+from vllm.model_executor.weight_utils import (convert_pyslice_to_tensor,
+                                              hf_model_weights_iterator,
                                               load_tensor_parallel_weights)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
@@ -419,7 +420,7 @@ class FalconForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     use_np_cache: bool = False):
+                     load_format: str = "auto"):
         tp_size = (get_tensor_model_parallel_world_size())
         tp_rank = get_tensor_model_parallel_rank()
 
@@ -451,8 +452,9 @@ class FalconForCausalLM(nn.Module):
         state_dict = self.state_dict()
 
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, use_np_cache):
+                model_name_or_path, cache_dir, load_format):
             if "query_key_value" in name:
+                loaded_weight = convert_pyslice_to_tensor(loaded_weight)
                 loaded_weight_size = loaded_weight.size()
                 loaded_weight = loaded_weight.view(
                     total_num_kv_heads, num_query_heads_per_kv_head + 2,
diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py
index 18a11b490755..ba1ee05b6f0e 100644
--- a/vllm/model_executor/models/gpt2.py
+++ b/vllm/model_executor/models/gpt2.py
@@ -32,8 +32,8 @@ from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.attention import PagedAttention
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.weight_utils import (
-    hf_model_weights_iterator, load_padded_tensor_parallel_vocab,
-    load_tensor_parallel_weights)
+    convert_pyslice_to_tensor, hf_model_weights_iterator,
+    load_padded_tensor_parallel_vocab, load_tensor_parallel_weights)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
 from vllm.model_executor.parallel_utils.tensor_parallel import (
@@ -231,14 +231,14 @@ class GPT2LMHeadModel(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     use_np_cache: bool = False):
+                     load_format: str = "auto"):
         tensor_model_parallel_world_size = (
             get_tensor_model_parallel_world_size())
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
 
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, use_np_cache):
+                model_name_or_path, cache_dir, load_format):
             if "lm_head.weight" in name:
                 # GPT-2 ties the weights of the embedding layer and the final
                 # linear layer.
@@ -251,6 +251,8 @@ class GPT2LMHeadModel(nn.Module):
             if not name.startswith("transformer."):
                 name = "transformer." + name
 
+            loaded_weight = convert_pyslice_to_tensor(loaded_weight)
+
             # The HF's GPT-2 implementation uses Conv1D instead of Linear.
             # Because of this, we need to transpose the weights.
             for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py
index 91c432bf8b0d..9d6d091cb1f1 100644
--- a/vllm/model_executor/models/gpt_bigcode.py
+++ b/vllm/model_executor/models/gpt_bigcode.py
@@ -33,8 +33,8 @@ from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.attention import PagedAttention
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.weight_utils import (
-    hf_model_weights_iterator, load_padded_tensor_parallel_vocab,
-    load_tensor_parallel_weights)
+    convert_pyslice_to_tensor, hf_model_weights_iterator,
+    load_padded_tensor_parallel_vocab, load_tensor_parallel_weights)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
 from vllm.model_executor.parallel_utils.tensor_parallel import (
@@ -259,14 +259,14 @@ class GPTBigCodeForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     use_np_cache: bool = False):
+                     load_format: str = "auto"):
         tensor_model_parallel_world_size = (
             get_tensor_model_parallel_world_size())
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
 
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, use_np_cache):
+                model_name_or_path, cache_dir, load_format):
             if "lm_head.weight" in name:
                 # GPT-2 ties the weights of the embedding layer and the final
                 # linear layer.
@@ -295,6 +295,7 @@ class GPTBigCodeForCausalLM(nn.Module):
                 head_start = tensor_model_parallel_rank * num_heads
                 head_end = (tensor_model_parallel_rank + 1) * num_heads
 
+                loaded_weight = convert_pyslice_to_tensor(loaded_weight)
                 wq, wk, wv = torch.split(
                     loaded_weight, [hidden_size, total_kv_size, total_kv_size],
                     dim=0)
diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py
index 0c9a7ef9d736..456b192322ed 100644
--- a/vllm/model_executor/models/gpt_j.py
+++ b/vllm/model_executor/models/gpt_j.py
@@ -222,11 +222,11 @@ class GPTJForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     use_np_cache: bool = False):
+                     load_format: str = "auto"):
         tp_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, use_np_cache):
+                model_name_or_path, cache_dir, load_format):
             if "attn.bias" in name or "attn.masked_bias" in name:
                 continue
 
diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py
index 531983989963..454600854a86 100644
--- a/vllm/model_executor/models/gpt_neox.py
+++ b/vllm/model_executor/models/gpt_neox.py
@@ -231,11 +231,11 @@ class GPTNeoXForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     use_np_cache: bool = False):
+                     load_format: str = "auto"):
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, use_np_cache):
+                model_name_or_path, cache_dir, load_format):
             if ("attention.bias" in name or "attention.masked_bias" in name
                     or "rotary_emb.inv_freq" in name):
                 continue
diff --git a/vllm/model_executor/models/internlm.py b/vllm/model_executor/models/internlm.py
index 50b26fcd3688..5cd68541141b 100644
--- a/vllm/model_executor/models/internlm.py
+++ b/vllm/model_executor/models/internlm.py
@@ -233,12 +233,12 @@ class InternLMForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     use_np_cache: bool = False):
+                     load_format: str = "auto"):
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
 
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, use_np_cache):
+                model_name_or_path, cache_dir, load_format):
             if "rotary_emb.inv_freq" in name:
                 continue
 
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index 62da5532187e..d217f447a498 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -271,8 +271,7 @@ class LlamaForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     use_np_cache: bool = False,
-                     use_safetensor: bool = True):
+                     load_format: str = "auto"):
         tp_size = get_tensor_model_parallel_world_size()
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         q_proj_shard_size = (self.config.hidden_size // tp_size)
@@ -289,7 +288,7 @@ class LlamaForCausalLM(nn.Module):
         state_dict = self.state_dict()
 
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, use_np_cache, use_safetensor):
+                model_name_or_path, cache_dir, load_format):
             if "rotary_emb.inv_freq" in name:
                 continue
 
diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py
index 7a75fac62b1b..4cd88900a4d9 100644
--- a/vllm/model_executor/models/mpt.py
+++ b/vllm/model_executor/models/mpt.py
@@ -10,7 +10,8 @@ from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.attention import PagedAttentionWithALiBi
 from vllm.model_executor.layers.sampler import Sampler
-from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
+from vllm.model_executor.weight_utils import (convert_pyslice_to_tensor,
+                                              hf_model_weights_iterator,
                                               load_tensor_parallel_weights)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
@@ -243,12 +244,12 @@ class MPTForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     use_np_cache: bool = False):
+                     load_format: str = "auto"):
         tp_world_size = get_tensor_model_parallel_world_size()
         tp_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, use_np_cache):
+                model_name_or_path, cache_dir, load_format):
             if "Wqkv" in name:
                 # NOTE(woosuk): MPT's fused QKV has the shape of
                 # [3 * num_heads * head_size, hidden_size].
@@ -260,7 +261,7 @@ class MPTForCausalLM(nn.Module):
                 num_heads = total_num_heads // tp_world_size
                 head_start = tp_rank * num_heads
                 head_end = (tp_rank + 1) * num_heads
-
+                loaded_weight = convert_pyslice_to_tensor(loaded_weight)
                 if name.endswith(".weight"):
                     loaded_weight = loaded_weight.view(3, total_num_heads,
                                                        head_size, hidden_size)
diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py
index 9bd503ae4209..508083df9297 100644
--- a/vllm/model_executor/models/opt.py
+++ b/vllm/model_executor/models/opt.py
@@ -297,12 +297,12 @@ class OPTForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     use_np_cache: bool = False):
+                     load_format: str = "auto"):
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
 
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, use_np_cache):
+                model_name_or_path, cache_dir, load_format):
             if "lm_head.weight" in name:
                 continue
 
diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py
index d511ef96b61f..a3557b5818f5 100644
--- a/vllm/model_executor/models/qwen.py
+++ b/vllm/model_executor/models/qwen.py
@@ -19,6 +19,7 @@ from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.weight_utils import (
+    convert_pyslice_to_tensor,
     hf_model_weights_iterator,
     load_padded_tensor_parallel_vocab,
     load_tensor_parallel_weights,
@@ -249,17 +250,19 @@ class QWenLMHeadModel(nn.Module):
         self,
         model_name_or_path: str,
         cache_dir: Optional[str] = None,
-        use_np_cache: bool = False,
+        load_format: str = "auto",
     ):
         tp_world_size = get_tensor_model_parallel_world_size()
         tp_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
 
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, use_np_cache):
+                model_name_or_path, cache_dir, load_format):
             if "rotary_emb.inv_freq" in name:
                 continue
 
+            loaded_weight = convert_pyslice_to_tensor(loaded_weight)
+
             if "c_attn" in name:
                 total_num_heads = self.config.num_attention_heads
                 hidden_size = self.config.hidden_size
diff --git a/vllm/model_executor/weight_utils.py b/vllm/model_executor/weight_utils.py
index 23844470faa5..578a1392845b 100644
--- a/vllm/model_executor/weight_utils.py
+++ b/vllm/model_executor/weight_utils.py
@@ -81,11 +81,12 @@ def convert_bin_to_safetensor_file(
 def prepare_hf_model_weights(
     model_name_or_path: str,
     cache_dir: Optional[str] = None,
-    use_safetensor: bool = False,
+    use_safetensors: bool = False,
+    fall_back_to_pt: bool = True,
 ):
     # Download model weights from huggingface.
     is_local = os.path.isdir(model_name_or_path)
-    allow_patterns = "*.safetensors" if use_safetensor else "*.bin"
+    allow_patterns = "*.safetensors" if use_safetensors else "*.bin"
     if not is_local:
         # Use file lock to prevent multiple processes from
         # downloading the same model weights at the same time.
@@ -97,32 +98,53 @@ def prepare_hf_model_weights(
     else:
         hf_folder = model_name_or_path
     hf_weights_files = glob.glob(os.path.join(hf_folder, allow_patterns))
-    if not use_safetensor:
+    if not use_safetensors:
         hf_weights_files = [
             x for x in hf_weights_files if not x.endswith("training_args.bin")
         ]
-    if len(hf_weights_files) == 0 and use_safetensor:
-        logger.warning("No *.safetensors files found, "
-                       "fall back to *.bin files")
+    if len(hf_weights_files) == 0 and use_safetensors and fall_back_to_pt:
         return prepare_hf_model_weights(model_name_or_path,
                                         cache_dir=cache_dir,
-                                        use_safetensor=False)
-    return hf_folder, hf_weights_files, use_safetensor
+                                        use_safetensors=False,
+                                        fall_back_to_pt=False)
+
+    if len(hf_weights_files) == 0:
+        raise RuntimeError(
+            f"Cannot find any model weights with `{model_name_or_path}`")
+
+    return hf_folder, hf_weights_files, use_safetensors
 
 
 def hf_model_weights_iterator(
     model_name_or_path: str,
     cache_dir: Optional[str] = None,
-    use_np_cache: bool = False,
-    use_safetensor: bool = False,
+    load_format: str = "auto",
 ) -> Iterator[Tuple[str, torch.Tensor]]:
-    hf_folder, hf_weights_files, use_safetensor = prepare_hf_model_weights(
-        model_name_or_path, cache_dir=cache_dir, use_safetensor=use_safetensor)
+    use_safetensors = False
+    use_np_cache = False
+    fall_back_to_pt = False
+    if load_format == "auto":
+        use_safetensors = True
+        fall_back_to_pt = True
+    elif load_format == "safetensors":
+        use_safetensors = True
+    elif load_format == "pt":
+        pass
+    elif load_format == "npcache":
+        use_np_cache = True
+    else:
+        raise ValueError(f"Unknown load_format: {load_format}")
+
+    hf_folder, hf_weights_files, use_safetensors = prepare_hf_model_weights(
+        model_name_or_path,
+        cache_dir=cache_dir,
+        use_safetensors=use_safetensors,
+        fall_back_to_pt=fall_back_to_pt)
 
     if use_np_cache:
         # Currently np_cache only support *.bin checkpoints
-        assert use_safetensor is False
+        assert use_safetensors is False
 
         # Convert the model weights from torch tensors to numpy arrays for
         # faster loading.
@@ -152,7 +174,7 @@ def hf_model_weights_iterator(
             with open(param_path, "rb") as f:
                 param = np.load(f)
             yield name, torch.from_numpy(param)
-    elif use_safetensor:
+    elif use_safetensors:
         for st_file in hf_weights_files:
             with safe_open(st_file, framework="pt") as f:
                 for name in f.keys():
@@ -167,6 +189,21 @@ def hf_model_weights_iterator(
             torch.cuda.empty_cache()
 
 
+def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
+    """convert PySafeSlice object from safetensors to torch.Tensor
+
+    PySafeSlice object supports indexing, which is done before loading the
+    actual tensor and can reduce the amount of memory being read into the
+    memory. However, it does not support more advanced functionalities
+    like `.view()` or `.t()`. Therefore, if we need to modify the loaded
+    tensor with these more complicated operators, we need to convert to
+    tensor first.
+    """
+    if not isinstance(x, torch.Tensor):
+        x = x[:]
+    return x
+
+
 def load_padded_tensor_parallel_vocab(
     param: torch.Tensor,
     loaded_weight: Any,  # `torch.Tensor` or `PySafeSlice`
@@ -176,11 +213,7 @@ def load_padded_tensor_parallel_vocab(
     start_idx = tensor_model_parallel_rank * shard_size
     end_idx = (tensor_model_parallel_rank + 1) * shard_size
     loaded_weight = loaded_weight[start_idx:end_idx]
-
-    # convert PySafeSlice object to torch.Tensor
-    if not isinstance(loaded_weight, torch.Tensor):
-        loaded_weight = loaded_weight[:]
-
+    loaded_weight = convert_pyslice_to_tensor(loaded_weight)
     param[:loaded_weight.shape[0]].copy_(loaded_weight)
 
 
@@ -207,10 +240,7 @@ def load_tensor_parallel_weights(
             loaded_weight = loaded_weight[:, start_idx:end_idx]
             break
 
-    # convert PySafeSlice object to torch.Tensor
-    if not isinstance(loaded_weight, torch.Tensor):
-        loaded_weight = loaded_weight[:]
-
+    loaded_weight = convert_pyslice_to_tensor(loaded_weight)
    assert param.shape == loaded_weight.shape, (
        f"{param_name} shape mismatch between model and checkpoint: "
        f"{param.shape} != {loaded_weight.shape}")