mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 09:06:02 +08:00
Adds parallel model weight loading for runai_streamer (#21330)
Signed-off-by: bbartels <benjamin@bartels.dev> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
This commit is contained in:
parent
774d0c014b
commit
b194557a6c
3
setup.py
3
setup.py
@ -659,7 +659,8 @@ setup(
|
||||
"bench": ["pandas", "datasets"],
|
||||
"tensorizer": ["tensorizer==2.10.1"],
|
||||
"fastsafetensors": ["fastsafetensors >= 0.1.10"],
|
||||
"runai": ["runai-model-streamer", "runai-model-streamer-s3", "boto3"],
|
||||
"runai":
|
||||
["runai-model-streamer >= 0.13.3", "runai-model-streamer-s3", "boto3"],
|
||||
"audio": ["librosa", "soundfile",
|
||||
"mistral_common[audio]"], # Required for audio processing
|
||||
"video": [] # Kept for backwards compatibility
|
||||
|
||||
@ -482,14 +482,20 @@ def runai_safetensors_weights_iterator(
|
||||
) -> Generator[tuple[str, torch.Tensor], None, None]:
|
||||
"""Iterate over the weights in the model safetensor files."""
|
||||
with SafetensorsStreamer() as streamer:
|
||||
for st_file in tqdm(
|
||||
hf_weights_files,
|
||||
desc="Loading safetensors using Runai Model Streamer",
|
||||
disable=not enable_tqdm(use_tqdm_on_load),
|
||||
bar_format=_BAR_FORMAT,
|
||||
):
|
||||
streamer.stream_file(st_file)
|
||||
yield from streamer.get_tensors()
|
||||
streamer.stream_files(hf_weights_files)
|
||||
total_tensors = sum(
|
||||
len(tensors_meta)
|
||||
for tensors_meta in streamer.files_to_tensors_metadata.values())
|
||||
|
||||
tensor_iter = tqdm(
|
||||
streamer.get_tensors(),
|
||||
total=total_tensors,
|
||||
desc="Loading safetensors using Runai Model Streamer",
|
||||
bar_format=_BAR_FORMAT,
|
||||
disable=not enable_tqdm(use_tqdm_on_load),
|
||||
)
|
||||
|
||||
yield from tensor_iter
|
||||
|
||||
|
||||
def fastsafetensors_weights_iterator(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user