Set weights_only=True when using torch.load() (#12366)

Signed-off-by: Russell Bryant <rbryant@redhat.com>
This commit is contained in:
Russell Bryant 2025-01-23 21:17:30 -05:00 committed by GitHub
parent 24b0205f58
commit d3d6bb13fb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 10 additions and 6 deletions

View File

@ -26,4 +26,4 @@ class ImageAsset:
""" """
image_path = get_vllm_public_assets(filename=f"{self.name}.pt", image_path = get_vllm_public_assets(filename=f"{self.name}.pt",
s3_prefix=VLM_IMAGES_DIR) s3_prefix=VLM_IMAGES_DIR)
return torch.load(image_path, map_location="cpu") return torch.load(image_path, map_location="cpu", weights_only=True)

View File

@ -273,7 +273,8 @@ class LoRAModel(AdapterModel):
new_embeddings_tensor_path) new_embeddings_tensor_path)
elif os.path.isfile(new_embeddings_bin_file_path): elif os.path.isfile(new_embeddings_bin_file_path):
embeddings = torch.load(new_embeddings_bin_file_path, embeddings = torch.load(new_embeddings_bin_file_path,
map_location=device) map_location=device,
weights_only=True)
return cls.from_lora_tensors( return cls.from_lora_tensors(
lora_model_id=get_lora_id() lora_model_id=get_lora_id()

View File

@ -93,7 +93,7 @@ def convert_bin_to_safetensor_file(
pt_filename: str, pt_filename: str,
sf_filename: str, sf_filename: str,
) -> None: ) -> None:
loaded = torch.load(pt_filename, map_location="cpu") loaded = torch.load(pt_filename, map_location="cpu", weights_only=True)
if "state_dict" in loaded: if "state_dict" in loaded:
loaded = loaded["state_dict"] loaded = loaded["state_dict"]
shared = _shared_pointers(loaded) shared = _shared_pointers(loaded)
@ -381,7 +381,9 @@ def np_cache_weights_iterator(
disable=not enable_tqdm, disable=not enable_tqdm,
bar_format=_BAR_FORMAT, bar_format=_BAR_FORMAT,
): ):
state = torch.load(bin_file, map_location="cpu") state = torch.load(bin_file,
map_location="cpu",
weights_only=True)
for name, param in state.items(): for name, param in state.items():
param_path = os.path.join(np_folder, name) param_path = os.path.join(np_folder, name)
with open(param_path, "wb") as f: with open(param_path, "wb") as f:
@ -447,7 +449,7 @@ def pt_weights_iterator(
disable=not enable_tqdm, disable=not enable_tqdm,
bar_format=_BAR_FORMAT, bar_format=_BAR_FORMAT,
): ):
state = torch.load(bin_file, map_location="cpu") state = torch.load(bin_file, map_location="cpu", weights_only=True)
yield from state.items() yield from state.items()
del state del state
torch.cuda.empty_cache() torch.cuda.empty_cache()

View File

@ -89,6 +89,7 @@ def load_peft_weights(model_id: str,
adapters_weights = safe_load_file(filename, device=device) adapters_weights = safe_load_file(filename, device=device)
else: else:
adapters_weights = torch.load(filename, adapters_weights = torch.load(filename,
map_location=torch.device(device)) map_location=torch.device(device),
weights_only=True)
return adapters_weights return adapters_weights