From b194557a6cfdd9eab777234c2ab3d90907e1c8f3 Mon Sep 17 00:00:00 2001 From: Benjamin Bartels Date: Tue, 22 Jul 2025 16:15:53 +0100 Subject: [PATCH] Adds parallel model weight loading for runai_streamer (#21330) Signed-off-by: bbartels Co-authored-by: Cyrus Leung --- setup.py | 3 ++- .../model_loader/weight_utils.py | 22 ++++++++++++------- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/setup.py b/setup.py index 9a5ca3456a0fa..d46e678e7aa40 100644 --- a/setup.py +++ b/setup.py @@ -659,7 +659,8 @@ setup( "bench": ["pandas", "datasets"], "tensorizer": ["tensorizer==2.10.1"], "fastsafetensors": ["fastsafetensors >= 0.1.10"], - "runai": ["runai-model-streamer", "runai-model-streamer-s3", "boto3"], + "runai": + ["runai-model-streamer >= 0.13.3", "runai-model-streamer-s3", "boto3"], "audio": ["librosa", "soundfile", "mistral_common[audio]"], # Required for audio processing "video": [] # Kept for backwards compatibility diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index 64a2089921eea..074126fa669e9 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -482,14 +482,20 @@ def runai_safetensors_weights_iterator( ) -> Generator[tuple[str, torch.Tensor], None, None]: """Iterate over the weights in the model safetensor files.""" with SafetensorsStreamer() as streamer: - for st_file in tqdm( - hf_weights_files, - desc="Loading safetensors using Runai Model Streamer", - disable=not enable_tqdm(use_tqdm_on_load), - bar_format=_BAR_FORMAT, - ): - streamer.stream_file(st_file) - yield from streamer.get_tensors() + streamer.stream_files(hf_weights_files) + total_tensors = sum( + len(tensors_meta) + for tensors_meta in streamer.files_to_tensors_metadata.values()) + + tensor_iter = tqdm( + streamer.get_tensors(), + total=total_tensors, + desc="Loading safetensors using Runai Model Streamer", + bar_format=_BAR_FORMAT, + disable=not enable_tqdm(use_tqdm_on_load), + ) + + yield from tensor_iter def fastsafetensors_weights_iterator(