Simplify weight loading logic (#2133)

Roy 2023-12-17 04:41:23 +08:00 committed by GitHub
parent 2acd76f346
commit eed74a558f
3 changed files with 33 additions and 37 deletions

vllm/config.py

@@ -122,15 +122,10 @@ class ModelConfig:
         # TODO: Remove this check once HF updates the pt weights of Mixtral.
         architectures = getattr(self.hf_config, "architectures", [])
-        if "MixtralForCausalLM" in architectures:
-            if load_format == "pt":
-                raise ValueError(
-                    "Currently, the 'pt' format is not supported for Mixtral. "
-                    "Please use the 'safetensors' format instead. ")
-            elif load_format == "auto":
-                # Do not fall back to pt weights.
-                load_format = "safetensors"
+        if "MixtralForCausalLM" in architectures and load_format == "pt":
+            raise ValueError(
+                "Currently, the 'pt' format is not supported for Mixtral. "
+                "Please use the 'safetensors' format instead. ")
         self.load_format = load_format

     def _verify_tokenizer_mode(self) -> None:
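
For reference, the resulting validation behavior as a self-contained sketch (verify_load_format is a hypothetical helper name; only the names in the diff are real):

    def verify_load_format(load_format: str, architectures: list) -> str:
        # "pt" is still rejected for Mixtral, but "auto" is no longer
        # rewritten to "safetensors" here; the safetensors-first ordering
        # now lives in prepare_hf_model_weights.
        if "MixtralForCausalLM" in architectures and load_format == "pt":
            raise ValueError(
                "Currently, the 'pt' format is not supported for Mixtral. "
                "Please use the 'safetensors' format instead. ")
        return load_format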

vllm/model_executor/models/mixtral.py

@@ -412,7 +412,11 @@ class MixtralForCausalLM(nn.Module):
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format, revision):
+                model_name_or_path,
+                cache_dir,
+                load_format,
+                revision,
+                fall_back_to_pt=False):
             if "rotary_emb.inv_freq" in name:
                 continue
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
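
The new keyword lets each model opt out of the .pt fallback at the call site; a usage sketch (the model id is illustrative):

    # Mixtral's pt weights are unsupported, so the loader skips the
    # *.pt fallback instead of relying on config-time rewriting.
    for name, loaded_weight in hf_model_weights_iterator(
            "mistralai/Mixtral-8x7B-v0.1",  # illustrative model id
            cache_dir=None,
            load_format="auto",
            revision=None,
            fall_back_to_pt=False):
        ...  # weight-loading body unchanged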

vllm/model_executor/weight_utils.py

@@ -125,15 +125,29 @@ def get_quant_config(
 def prepare_hf_model_weights(
     model_name_or_path: str,
     cache_dir: Optional[str] = None,
-    use_safetensors: bool = False,
+    load_format: str = "auto",
     fall_back_to_pt: bool = True,
     revision: Optional[str] = None,
 ) -> Tuple[str, List[str], bool]:
     # Download model weights from huggingface.
     is_local = os.path.isdir(model_name_or_path)
+    use_safetensors = False
     # Some quantized models use .pt files for storing the weights.
-    allow_patterns = ["*.safetensors"
-                      ] if use_safetensors else ["*.bin", "*.pt"]
+    if load_format == "auto":
+        allow_patterns = ["*.safetensors", "*.bin"]
+    elif load_format == "safetensors":
+        use_safetensors = True
+        allow_patterns = ["*.safetensors"]
+    elif load_format == "pt":
+        allow_patterns = ["*.pt"]
+    elif load_format == "npcache":
+        allow_patterns = ["*.bin"]
+    else:
+        raise ValueError(f"Unknown load_format: {load_format}")
+    if fall_back_to_pt:
+        allow_patterns += ["*.pt"]
     if not is_local:
         # Use file lock to prevent multiple processes from
         # downloading the same model weights at the same time.
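
All format-to-pattern decisions now happen in one place. A minimal standalone sketch of the mapping (select_patterns is a hypothetical name, not part of vLLM):

    from typing import List

    def select_patterns(load_format: str, fall_back_to_pt: bool) -> List[str]:
        # Glob patterns are listed in priority order; "auto" prefers
        # safetensors over .bin checkpoints.
        if load_format == "auto":
            patterns = ["*.safetensors", "*.bin"]
        elif load_format == "safetensors":
            patterns = ["*.safetensors"]
        elif load_format == "pt":
            patterns = ["*.pt"]
        elif load_format == "npcache":
            patterns = ["*.bin"]
        else:
            raise ValueError(f"Unknown load_format: {load_format}")
        if fall_back_to_pt:
            # Some quantized models store weights only as .pt files.
            patterns += ["*.pt"]
        return patterns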
@@ -148,6 +162,10 @@ def prepare_hf_model_weights(
     hf_weights_files: List[str] = []
     for pattern in allow_patterns:
         hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
+        if len(hf_weights_files) > 0:
+            if pattern == "*.safetensors":
+                use_safetensors = True
+            break
     if not use_safetensors:
         # Exclude files that are not needed for inference.
         # https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
@@ -163,13 +181,6 @@ def prepare_hf_model_weights(
             if not any(f.endswith(x) for x in blacklist)
         ]
-    if len(hf_weights_files) == 0 and use_safetensors and fall_back_to_pt:
-        return prepare_hf_model_weights(model_name_or_path,
-                                        cache_dir=cache_dir,
-                                        use_safetensors=False,
-                                        fall_back_to_pt=False,
-                                        revision=revision)
     if len(hf_weights_files) == 0:
         raise RuntimeError(
             f"Cannot find any model weights with `{model_name_or_path}`")
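
The recursive safetensors-to-pt retry is now redundant: the ordered pattern list plus the first-match scan above yields the same fallback in a single pass. A self-contained sketch of that scan (find_weight_files is a hypothetical name):

    import glob
    import os
    from typing import List, Tuple

    def find_weight_files(hf_folder: str,
                          allow_patterns: List[str]) -> Tuple[List[str], bool]:
        # Try patterns in priority order; the first pattern with any match
        # wins, so *.safetensors shadows *.bin and *.pt when present.
        hf_weights_files: List[str] = []
        for pattern in allow_patterns:
            hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
            if len(hf_weights_files) > 0:
                return hf_weights_files, pattern == "*.safetensors"
        raise RuntimeError(f"Cannot find any model weights in `{hf_folder}`")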
@@ -182,30 +193,16 @@ def hf_model_weights_iterator(
     cache_dir: Optional[str] = None,
     load_format: str = "auto",
     revision: Optional[str] = None,
+    fall_back_to_pt: Optional[bool] = True,
 ) -> Iterator[Tuple[str, torch.Tensor]]:
-    use_safetensors = False
-    use_np_cache = False
-    fall_back_to_pt = False
-    if load_format == "auto":
-        use_safetensors = True
-        fall_back_to_pt = True
-    elif load_format == "safetensors":
-        use_safetensors = True
-    elif load_format == "pt":
-        pass
-    elif load_format == "npcache":
-        use_np_cache = True
-    else:
-        raise ValueError(f"Unknown load_format: {load_format}")
     hf_folder, hf_weights_files, use_safetensors = prepare_hf_model_weights(
         model_name_or_path,
         cache_dir=cache_dir,
-        use_safetensors=use_safetensors,
+        load_format=load_format,
         fall_back_to_pt=fall_back_to_pt,
         revision=revision)
-    if use_np_cache:
+    if load_format == "npcache":
         # Currently np_cache only support *.bin checkpoints
         assert use_safetensors is False
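
With the dispatch centralized, hf_model_weights_iterator simply forwards load_format; the npcache branch is selected by string comparison instead of a local use_np_cache flag. A hedged call sketch (the model id is illustrative):

    # npcache reads *.bin checkpoints only, so prepare_hf_model_weights
    # can never report safetensors for this format.
    hf_folder, hf_weights_files, use_safetensors = prepare_hf_model_weights(
        "facebook/opt-125m",  # illustrative model id
        cache_dir=None,
        load_format="npcache",
        fall_back_to_pt=False,
        revision=None)
    assert use_safetensors is False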