[Misc] Accurately capture the time of loading weights (#14063)

Signed-off-by: Jun Duan <jun.duan.phd@outlook.com>
This commit is contained in:
Jun Duan 2025-03-01 20:20:30 -05:00 committed by GitHub
parent cc5e8f6db8
commit 82fbeae92b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 13 additions and 2 deletions

View File

@@ -10,6 +10,7 @@ import inspect
import itertools import itertools
import math import math
import os import os
import time
import warnings import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from contextlib import contextmanager from contextlib import contextmanager
@@ -216,6 +217,9 @@ class DefaultModelLoader(BaseModelLoader):
allow_patterns_overrides: Optional[list[str]] = None allow_patterns_overrides: Optional[list[str]] = None
"""If defined, weights will load exclusively using these patterns.""" """If defined, weights will load exclusively using these patterns."""
counter_before_loading_weights: float = 0.0
counter_after_loading_weights: float = 0.0
def __init__(self, load_config: LoadConfig): def __init__(self, load_config: LoadConfig):
super().__init__(load_config) super().__init__(load_config)
if load_config.model_loader_extra_config: if load_config.model_loader_extra_config:
@@ -368,6 +372,8 @@ class DefaultModelLoader(BaseModelLoader):
weights_iterator = _xla_weights_iterator(weights_iterator) weights_iterator = _xla_weights_iterator(weights_iterator)
if self.counter_before_loading_weights == 0.0:
self.counter_before_loading_weights = time.perf_counter()
# Apply the prefix. # Apply the prefix.
return ((source.prefix + name, tensor) return ((source.prefix + name, tensor)
for (name, tensor) in weights_iterator) for (name, tensor) in weights_iterator)
@@ -412,6 +418,11 @@ class DefaultModelLoader(BaseModelLoader):
weights_to_load = {name for name, _ in model.named_parameters()} weights_to_load = {name for name, _ in model.named_parameters()}
loaded_weights = model.load_weights( loaded_weights = model.load_weights(
self._get_all_weights(model_config, model)) self._get_all_weights(model_config, model))
self.counter_after_loading_weights = time.perf_counter()
logger.info(
"Loading weights took %.2f seconds",
self.counter_after_loading_weights -
self.counter_before_loading_weights)
# We only enable strict check for non-quantized models # We only enable strict check for non-quantized models
# that have loaded weights tracking currently. # that have loaded weights tracking currently.
if model_config.quantization is None and loaded_weights is not None: if model_config.quantization is None and loaded_weights is not None:

View File

@@ -1061,7 +1061,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.device) self.device)
time_after_load = time.perf_counter() time_after_load = time.perf_counter()
self.model_memory_usage = m.consumed_memory self.model_memory_usage = m.consumed_memory
logger.info("Loading model weights took %.4f GB and %.6f seconds", logger.info("Model loading took %.4f GB and %.6f seconds",
self.model_memory_usage / float(2**30), self.model_memory_usage / float(2**30),
time_after_load - time_before_load) time_after_load - time_before_load)

View File

@@ -1114,7 +1114,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
time_after_load = time.perf_counter() time_after_load = time.perf_counter()
self.model_memory_usage = m.consumed_memory self.model_memory_usage = m.consumed_memory
logger.info("Loading model weights took %.4f GB and %.6f seconds", logger.info("Model loading took %.4f GB and %.6f seconds",
self.model_memory_usage / float(2**30), self.model_memory_usage / float(2**30),
time_after_load - time_before_load) time_after_load - time_before_load)