Adds parallel model weight loading for runai_streamer (#21330)

Signed-off-by: bbartels <benjamin@bartels.dev>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
This commit is contained in:
Benjamin Bartels 2025-07-22 16:15:53 +01:00 committed by GitHub
parent 774d0c014b
commit b194557a6c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 16 additions and 9 deletions

View File

@ -659,7 +659,8 @@ setup(
"bench": ["pandas", "datasets"],
"tensorizer": ["tensorizer==2.10.1"],
"fastsafetensors": ["fastsafetensors >= 0.1.10"],
"runai": ["runai-model-streamer", "runai-model-streamer-s3", "boto3"],
"runai":
["runai-model-streamer >= 0.13.3", "runai-model-streamer-s3", "boto3"],
"audio": ["librosa", "soundfile",
"mistral_common[audio]"], # Required for audio processing
"video": [] # Kept for backwards compatibility

View File

@ -482,14 +482,20 @@ def runai_safetensors_weights_iterator(
) -> Generator[tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model safetensor files."""
with SafetensorsStreamer() as streamer:
for st_file in tqdm(
hf_weights_files,
desc="Loading safetensors using Runai Model Streamer",
disable=not enable_tqdm(use_tqdm_on_load),
bar_format=_BAR_FORMAT,
):
streamer.stream_file(st_file)
yield from streamer.get_tensors()
streamer.stream_files(hf_weights_files)
total_tensors = sum(
len(tensors_meta)
for tensors_meta in streamer.files_to_tensors_metadata.values())
tensor_iter = tqdm(
streamer.get_tensors(),
total=total_tensors,
desc="Loading safetensors using Runai Model Streamer",
bar_format=_BAR_FORMAT,
disable=not enable_tqdm(use_tqdm_on_load),
)
yield from tensor_iter
def fastsafetensors_weights_iterator(