Add blacklist in model checkpoint (#1325)

This commit is contained in:
Woosuk Kwon 2023-10-12 01:05:37 -07:00 committed by GitHub
parent ee8217e5be
commit 875afe38ab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -144,8 +144,18 @@ def prepare_hf_model_weights(
for pattern in allow_patterns:
hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
if not use_safetensors:
# Exclude files that are not needed for inference.
# https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
blacklist = [
"training_args.bin",
"optimizer.bin",
"optimizer.pt",
"scheduler.pt",
"scaler.pt",
]
hf_weights_files = [
x for x in hf_weights_files if not x.endswith("training_args.bin")
f for f in hf_weights_files
if not any(f.endswith(x) for x in blacklist)
]
if len(hf_weights_files) == 0 and use_safetensors and fall_back_to_pt: