mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-04 06:06:33 +08:00
Add blacklist in model checkpoint (#1325)
This commit is contained in:
parent
ee8217e5be
commit
875afe38ab
@ -144,8 +144,18 @@ def prepare_hf_model_weights(
|
||||
for pattern in allow_patterns:
|
||||
hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
|
||||
if not use_safetensors:
|
||||
# Exclude files that are not needed for inference.
|
||||
# https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
|
||||
blacklist = [
|
||||
"training_args.bin",
|
||||
"optimizer.bin",
|
||||
"optimizer.pt",
|
||||
"scheduler.pt",
|
||||
"scaler.pt",
|
||||
]
|
||||
hf_weights_files = [
|
||||
x for x in hf_weights_files if not x.endswith("training_args.bin")
|
||||
f for f in hf_weights_files
|
||||
if not any(f.endswith(x) for x in blacklist)
|
||||
]
|
||||
|
||||
if len(hf_weights_files) == 0 and use_safetensors and fall_back_to_pt:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user