[Misc] Capture and log the time of loading weights (#13666)

This commit is contained in:
Jun Duan 2025-02-22 01:06:34 -05:00 committed by GitHub
parent c6ed93860f
commit 68d535ef44
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 10 additions and 5 deletions

View File

@ -1048,6 +1048,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def load_model(self) -> None:
logger.info("Starting to load model %s...", self.model_config.model)
with DeviceMemoryProfiler() as m: # noqa: SIM117
time_before_load = time.perf_counter()
self.model = get_model(vllm_config=self.vllm_config)
if self.lora_config:
self.model = self.load_lora_model(self.model,
@ -1055,10 +1056,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.scheduler_config,
self.lora_config,
self.device)
time_after_load = time.perf_counter()
self.model_memory_usage = m.consumed_memory
logger.info("Loading model weights took %.4f GB",
self.model_memory_usage / float(2**30))
logger.info("Loading model weights took %.4f GB and %.6f seconds",
self.model_memory_usage / float(2**30),
time_after_load - time_before_load)
def _get_prompt_logprobs_dict(
self,

View File

@ -1109,11 +1109,14 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
def load_model(self) -> None:
logger.info("Starting to load model %s...", self.model_config.model)
with DeviceMemoryProfiler(self.device) as m:
time_before_load = time.perf_counter()
self.model = get_model(vllm_config=self.vllm_config)
time_after_load = time.perf_counter()
self.model_memory_usage = m.consumed_memory
logger.info("Loading model weights took %.4f GB",
self.model_memory_usage / float(2**30))
logger.info("Loading model weights took %.4f GB and %.6f seconds",
self.model_memory_usage / float(2**30),
time_after_load - time_before_load)
if self.lora_config:
assert supports_lora(