From 71ce44047f20478f9c61d96907fdc2dac89e7e0a Mon Sep 17 00:00:00 2001
From: omer-dayan
Date: Tue, 22 Apr 2025 07:21:49 +0300
Subject: [PATCH] Support S3 Sharded loading with RunAI Model Streamer (#16317)

Signed-off-by: Omer Dayan (SW-GPU)
Co-authored-by: Cyrus Leung
---
 vllm/config.py                             |  1 +
 vllm/model_executor/model_loader/loader.py | 80 ++++++++++++++--------
 2 files changed, 53 insertions(+), 28 deletions(-)

diff --git a/vllm/config.py b/vllm/config.py
index a3ed94bc50f82..20ca20ad2b6d5 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -1489,6 +1489,7 @@ class LoadFormat(str, enum.Enum):
     BITSANDBYTES = "bitsandbytes"
     MISTRAL = "mistral"
     RUNAI_STREAMER = "runai_streamer"
+    RUNAI_STREAMER_SHARDED = "runai_streamer_sharded"
     FASTSAFETENSORS = "fastsafetensors"
 
diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py
index b0a0a20aa76f0..ae5662a9b48a9 100644
--- a/vllm/model_executor/model_loader/loader.py
+++ b/vllm/model_executor/model_loader/loader.py
@@ -611,8 +611,12 @@ class ShardedStateLoader(BaseModelLoader):
 
     DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"
 
-    def __init__(self, load_config: LoadConfig):
+    def __init__(self,
+                 load_config: LoadConfig,
+                 runai_model_streamer: bool = False):
         super().__init__(load_config)
+
+        self.runai_model_streamer = runai_model_streamer
         extra_config = ({} if load_config.model_loader_extra_config is None
                         else load_config.model_loader_extra_config.copy())
         self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN)
@@ -659,7 +663,7 @@ class ShardedStateLoader(BaseModelLoader):
 
     def _prepare_weights(self, model_name_or_path: str,
                          revision: Optional[str]):
-        if os.path.isdir(model_name_or_path):
+        if is_s3(model_name_or_path) or os.path.isdir(model_name_or_path):
             return model_name_or_path
         else:
             allow_patterns = ["*.safetensors"]
@@ -678,12 +682,13 @@ class ShardedStateLoader(BaseModelLoader):
         device_config = vllm_config.device_config
         model_config = vllm_config.model_config
         target_device = torch.device(device_config.device)
-        from safetensors.torch import safe_open
 
         from vllm.distributed import get_tensor_model_parallel_rank
 
-        local_model_path = self._prepare_weights(model_config.model,
-                                                 model_config.revision)
+        model_weights = model_config.model
+        if hasattr(model_config, "model_weights"):
+            model_weights = model_config.model_weights
+        local_model_path = model_weights
 
         with set_default_torch_dtype(model_config.dtype):
             with target_device:
@@ -695,40 +700,56 @@ class ShardedStateLoader(BaseModelLoader):
                     local_model_path,
                     self.pattern.format(rank=rank, part="*"),
                 )
-                filepaths = glob.glob(pattern)
+
+                filepaths = []
+                if is_s3(local_model_path):
+                    file_pattern = f"*{self.pattern.format(rank=rank, part='*')}"
+                    filepaths = s3_glob(path=local_model_path,
+                                        allow_pattern=[file_pattern])
+                else:
+                    filepaths = glob.glob(pattern)
                 if not filepaths:
                     # TODO: support un-sharded checkpoints too
                     raise ValueError(
                         f"Could not find checkpoint files '{pattern}', only "
                         f"pre-sharded checkpoints are currently supported!")
                 state_dict = self._filter_subtensors(model.state_dict())
-                for path in filepaths:
-                    with safe_open(path, framework="pt") as f:
-                        for key in f.keys():  # noqa: SIM118
-                            tensor = f.get_tensor(key)
-                            # If loading with LoRA enabled, additional padding may
-                            # be added to certain parameters. We only load into a
-                            # narrowed view of the parameter data.
-                            param_data = state_dict[key].data
-                            param_shape = state_dict[key].shape
-                            for dim, size in enumerate(tensor.shape):
-                                if size < param_shape[dim]:
-                                    param_data = param_data.narrow(dim, 0, size)
-                            if tensor.shape != param_shape:
-                                logger.warning(
-                                    "loading tensor of shape %s into "
-                                    "parameter '%s' of shape %s",
-                                    tensor.shape,
-                                    key,
-                                    param_shape,
-                                )
-                            param_data.copy_(tensor)
-                            state_dict.pop(key)
+                for key, tensor in self.iterate_over_files(filepaths):
+                    # If loading with LoRA enabled, additional padding may
+                    # be added to certain parameters. We only load into a
+                    # narrowed view of the parameter data.
+                    param_data = state_dict[key].data
+                    param_shape = state_dict[key].shape
+                    for dim, size in enumerate(tensor.shape):
+                        if size < param_shape[dim]:
+                            param_data = param_data.narrow(dim, 0, size)
+                    if tensor.shape != param_shape:
+                        logger.warning(
+                            "loading tensor of shape %s into "
+                            "parameter '%s' of shape %s",
+                            tensor.shape,
+                            key,
+                            param_shape,
+                        )
+                    param_data.copy_(tensor)
+                    state_dict.pop(key)
             if state_dict:
                 raise ValueError(
                     f"Missing keys {tuple(state_dict)} in loaded state!")
         return model.eval()
 
+    def iterate_over_files(
+            self, paths) -> Generator[Tuple[str, torch.Tensor], None, None]:
+        if self.runai_model_streamer:
+            yield from runai_safetensors_weights_iterator(paths, True)
+        else:
+            from safetensors.torch import safe_open
+            for path in paths:
+                with safe_open(path, framework="pt") as f:
+                    for key in f.keys():  # noqa: SIM118
+                        tensor = f.get_tensor(key)
+                        yield key, tensor
+
     @staticmethod
     def save_model(
         model: torch.nn.Module,
@@ -1515,4 +1536,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
     if load_config.load_format == LoadFormat.RUNAI_STREAMER:
         return RunaiModelStreamerLoader(load_config)
 
+    if load_config.load_format == LoadFormat.RUNAI_STREAMER_SHARDED:
+        return ShardedStateLoader(load_config, runai_model_streamer=True)
+
     return DefaultModelLoader(load_config)
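
Usage note (not part of the patch): a minimal sketch of how the new load
format might be invoked, assuming a hypothetical S3 bucket/prefix holding
checkpoints pre-sharded with ShardedStateLoader.save_model. The string
"runai_streamer_sharded" is the LoadFormat member added above, and the
load_format argument is vLLM's existing mechanism for selecting a loader.

    from vllm import LLM

    llm = LLM(
        model="s3://my-bucket/my-sharded-model/",  # hypothetical path
        load_format="runai_streamer_sharded",
        # Must match the tensor-parallel size the shards were saved with,
        # since files are located per rank via the
        # "model-rank-{rank}-part-{part}.safetensors" pattern.
        tensor_parallel_size=2,
    )

With this format, get_model_loader returns a ShardedStateLoader constructed
with runai_model_streamer=True, so iterate_over_files streams each per-rank
safetensors file through runai_safetensors_weights_iterator rather than
opening it locally with safe_open.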