Support S3 Sharded loading with RunAI Model Streamer (#16317)

Signed-off-by: Omer Dayan (SW-GPU) <omer@run.ai>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
This commit is contained in:
omer-dayan 2025-04-22 07:21:49 +03:00 committed by GitHub
parent 188b7f9b8c
commit 71ce44047f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 53 additions and 28 deletions

View File

@ -1489,6 +1489,7 @@ class LoadFormat(str, enum.Enum):
BITSANDBYTES = "bitsandbytes" BITSANDBYTES = "bitsandbytes"
MISTRAL = "mistral" MISTRAL = "mistral"
RUNAI_STREAMER = "runai_streamer" RUNAI_STREAMER = "runai_streamer"
RUNAI_STREAMER_SHARDED = "runai_streamer_sharded"
FASTSAFETENSORS = "fastsafetensors" FASTSAFETENSORS = "fastsafetensors"

View File

@ -611,8 +611,12 @@ class ShardedStateLoader(BaseModelLoader):
DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors" DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"
def __init__(self, load_config: LoadConfig): def __init__(self,
load_config: LoadConfig,
runai_model_streamer: bool = False):
super().__init__(load_config) super().__init__(load_config)
self.runai_model_streamer = runai_model_streamer
extra_config = ({} if load_config.model_loader_extra_config is None extra_config = ({} if load_config.model_loader_extra_config is None
else load_config.model_loader_extra_config.copy()) else load_config.model_loader_extra_config.copy())
self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN) self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN)
@ -659,7 +663,7 @@ class ShardedStateLoader(BaseModelLoader):
def _prepare_weights(self, model_name_or_path: str, def _prepare_weights(self, model_name_or_path: str,
revision: Optional[str]): revision: Optional[str]):
if os.path.isdir(model_name_or_path): if is_s3(model_name_or_path) or os.path.isdir(model_name_or_path):
return model_name_or_path return model_name_or_path
else: else:
allow_patterns = ["*.safetensors"] allow_patterns = ["*.safetensors"]
@ -678,12 +682,13 @@ class ShardedStateLoader(BaseModelLoader):
device_config = vllm_config.device_config device_config = vllm_config.device_config
model_config = vllm_config.model_config model_config = vllm_config.model_config
target_device = torch.device(device_config.device) target_device = torch.device(device_config.device)
from safetensors.torch import safe_open
from vllm.distributed import get_tensor_model_parallel_rank from vllm.distributed import get_tensor_model_parallel_rank
local_model_path = self._prepare_weights(model_config.model, model_weights = model_config.model
model_config.revision) if hasattr(model_config, "model_weights"):
model_weights = model_config.model_weights
local_model_path = model_weights
with set_default_torch_dtype(model_config.dtype): with set_default_torch_dtype(model_config.dtype):
with target_device: with target_device:
@ -695,40 +700,56 @@ class ShardedStateLoader(BaseModelLoader):
local_model_path, local_model_path,
self.pattern.format(rank=rank, part="*"), self.pattern.format(rank=rank, part="*"),
) )
filepaths = glob.glob(pattern)
filepaths = []
if is_s3(local_model_path):
file_pattern = f"*{self.pattern.format(rank=rank, part=" * ")}"
filepaths = s3_glob(path=local_model_path,
allow_pattern=[file_pattern])
else:
filepaths = glob.glob(pattern)
if not filepaths: if not filepaths:
# TODO: support un-sharded checkpoints too # TODO: support un-sharded checkpoints too
raise ValueError( raise ValueError(
f"Could not find checkpoint files '{pattern}', only " f"Could not find checkpoint files '{pattern}', only "
f"pre-sharded checkpoints are currently supported!") f"pre-sharded checkpoints are currently supported!")
state_dict = self._filter_subtensors(model.state_dict()) state_dict = self._filter_subtensors(model.state_dict())
for path in filepaths: for key, tensor in self.iterate_over_files(filepaths):
with safe_open(path, framework="pt") as f: # If loading with LoRA enabled, additional padding may
for key in f.keys(): # noqa: SIM118 # be added to certain parameters. We only load into a
tensor = f.get_tensor(key) # narrowed view of the parameter data.
# If loading with LoRA enabled, additional padding may param_data = state_dict[key].data
# be added to certain parameters. We only load into a param_shape = state_dict[key].shape
# narrowed view of the parameter data. for dim, size in enumerate(tensor.shape):
param_data = state_dict[key].data if size < param_shape[dim]:
param_shape = state_dict[key].shape param_data = param_data.narrow(dim, 0, size)
for dim, size in enumerate(tensor.shape): if tensor.shape != param_shape:
if size < param_shape[dim]: logger.warning(
param_data = param_data.narrow(dim, 0, size) "loading tensor of shape %s into "
if tensor.shape != param_shape: "parameter '%s' of shape %s",
logger.warning( tensor.shape,
"loading tensor of shape %s into " key,
"parameter '%s' of shape %s", param_shape,
tensor.shape, )
key, param_data.copy_(tensor)
param_shape, state_dict.pop(key)
)
param_data.copy_(tensor)
state_dict.pop(key)
if state_dict: if state_dict:
raise ValueError( raise ValueError(
f"Missing keys {tuple(state_dict)} in loaded state!") f"Missing keys {tuple(state_dict)} in loaded state!")
return model.eval() return model.eval()
def iterate_over_files(
self, paths) -> Generator[Tuple[str, torch.Tensor], None, None]:
if self.runai_model_streamer:
yield from runai_safetensors_weights_iterator(paths, True)
else:
from safetensors.torch import safe_open
for path in paths:
with safe_open(path, framework="pt") as f:
for key in f.keys(): # noqa: SIM118
tensor = f.get_tensor(key)
yield key, tensor
@staticmethod @staticmethod
def save_model( def save_model(
model: torch.nn.Module, model: torch.nn.Module,
@ -1515,4 +1536,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
if load_config.load_format == LoadFormat.RUNAI_STREAMER: if load_config.load_format == LoadFormat.RUNAI_STREAMER:
return RunaiModelStreamerLoader(load_config) return RunaiModelStreamerLoader(load_config)
if load_config.load_format == LoadFormat.RUNAI_STREAMER_SHARDED:
return ShardedStateLoader(load_config, runai_model_streamer=True)
return DefaultModelLoader(load_config) return DefaultModelLoader(load_config)