[log] add weights loading time log to sharded_state loader (#28628)

Signed-off-by: Andy Xie <andy.xning@gmail.com>
Author: Ning Xie
Date: 2025-11-22 05:06:09 +08:00 (committed by GitHub)
parent 1840c5cb18
commit 53a1ba6ec5

@@ -4,6 +4,7 @@
 import collections
 import glob
 import os
+import time
 from collections.abc import Generator
 from typing import Any
@@ -132,6 +133,7 @@ class ShardedStateLoader(BaseModelLoader):
                 f"pre-sharded checkpoints are currently supported!"
             )
         state_dict = self._filter_subtensors(model.state_dict())
+        counter_before_loading_weights = time.perf_counter()
         for key, tensor in self.iterate_over_files(filepaths):
             # If loading with LoRA enabled, additional padding may
             # be added to certain parameters. We only load into a
@@ -150,6 +152,12 @@ class ShardedStateLoader(BaseModelLoader):
                 )
             param_data.copy_(tensor)
             state_dict.pop(key)
+        counter_after_loading_weights = time.perf_counter()
+        logger.info_once(
+            "Loading weights took %.2f seconds",
+            counter_after_loading_weights - counter_before_loading_weights,
+            scope="local",
+        )
         if state_dict:
             raise ValueError(f"Missing keys {tuple(state_dict)} in loaded state!")
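
For context, a minimal standalone sketch of the timing pattern this commit adds: time.perf_counter() is read once before and once after the shard-copy loop, and the difference is logged with the same "Loading weights took %.2f seconds" format. The load_one_file stub and the file names below are hypothetical placeholders for the real per-file tensor copies, and the stdlib logger stands in for vLLM's logger.info_once(..., scope="local"), which additionally deduplicates the message.

import logging
import time

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("sharded_state_demo")


def load_one_file(path: str) -> None:
    # Hypothetical stand-in for copying one checkpoint shard's
    # tensors into the model's parameters.
    time.sleep(0.05)


def load_all(filepaths: list[str]) -> None:
    counter_before_loading_weights = time.perf_counter()
    for path in filepaths:
        load_one_file(path)
    counter_after_loading_weights = time.perf_counter()
    # perf_counter() is a monotonic clock, so the difference is a
    # valid elapsed-time measurement even if the wall clock is
    # adjusted while weights are loading.
    logger.info(
        "Loading weights took %.2f seconds",
        counter_after_loading_weights - counter_before_loading_weights,
    )


if __name__ == "__main__":
    load_all(["model-rank-0-part-0.safetensors", "model-rank-0-part-1.safetensors"])

perf_counter() is preferred over time.time() here because it is monotonic and high-resolution, and the message matches the one vLLM's default loader already emits, so the sharded_state loader now reports weights loading time consistently with it.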