mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-03 18:35:42 +08:00
Support S3 Sharded loading with RunAI Model Streamer (#16317)
Signed-off-by: Omer Dayan (SW-GPU) <omer@run.ai> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
This commit is contained in:
parent
188b7f9b8c
commit
71ce44047f
@ -1489,6 +1489,7 @@ class LoadFormat(str, enum.Enum):
|
|||||||
BITSANDBYTES = "bitsandbytes"
|
BITSANDBYTES = "bitsandbytes"
|
||||||
MISTRAL = "mistral"
|
MISTRAL = "mistral"
|
||||||
RUNAI_STREAMER = "runai_streamer"
|
RUNAI_STREAMER = "runai_streamer"
|
||||||
|
RUNAI_STREAMER_SHARDED = "runai_streamer_sharded"
|
||||||
FASTSAFETENSORS = "fastsafetensors"
|
FASTSAFETENSORS = "fastsafetensors"
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -611,8 +611,12 @@ class ShardedStateLoader(BaseModelLoader):
|
|||||||
|
|
||||||
DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"
|
DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"
|
||||||
|
|
||||||
def __init__(self, load_config: LoadConfig):
|
def __init__(self,
|
||||||
|
load_config: LoadConfig,
|
||||||
|
runai_model_streamer: bool = False):
|
||||||
super().__init__(load_config)
|
super().__init__(load_config)
|
||||||
|
|
||||||
|
self.runai_model_streamer = runai_model_streamer
|
||||||
extra_config = ({} if load_config.model_loader_extra_config is None
|
extra_config = ({} if load_config.model_loader_extra_config is None
|
||||||
else load_config.model_loader_extra_config.copy())
|
else load_config.model_loader_extra_config.copy())
|
||||||
self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN)
|
self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN)
|
||||||
@ -659,7 +663,7 @@ class ShardedStateLoader(BaseModelLoader):
|
|||||||
|
|
||||||
def _prepare_weights(self, model_name_or_path: str,
|
def _prepare_weights(self, model_name_or_path: str,
|
||||||
revision: Optional[str]):
|
revision: Optional[str]):
|
||||||
if os.path.isdir(model_name_or_path):
|
if is_s3(model_name_or_path) or os.path.isdir(model_name_or_path):
|
||||||
return model_name_or_path
|
return model_name_or_path
|
||||||
else:
|
else:
|
||||||
allow_patterns = ["*.safetensors"]
|
allow_patterns = ["*.safetensors"]
|
||||||
@ -678,12 +682,13 @@ class ShardedStateLoader(BaseModelLoader):
|
|||||||
device_config = vllm_config.device_config
|
device_config = vllm_config.device_config
|
||||||
model_config = vllm_config.model_config
|
model_config = vllm_config.model_config
|
||||||
target_device = torch.device(device_config.device)
|
target_device = torch.device(device_config.device)
|
||||||
from safetensors.torch import safe_open
|
|
||||||
|
|
||||||
from vllm.distributed import get_tensor_model_parallel_rank
|
from vllm.distributed import get_tensor_model_parallel_rank
|
||||||
|
|
||||||
local_model_path = self._prepare_weights(model_config.model,
|
model_weights = model_config.model
|
||||||
model_config.revision)
|
if hasattr(model_config, "model_weights"):
|
||||||
|
model_weights = model_config.model_weights
|
||||||
|
local_model_path = model_weights
|
||||||
|
|
||||||
with set_default_torch_dtype(model_config.dtype):
|
with set_default_torch_dtype(model_config.dtype):
|
||||||
with target_device:
|
with target_device:
|
||||||
@ -695,40 +700,56 @@ class ShardedStateLoader(BaseModelLoader):
|
|||||||
local_model_path,
|
local_model_path,
|
||||||
self.pattern.format(rank=rank, part="*"),
|
self.pattern.format(rank=rank, part="*"),
|
||||||
)
|
)
|
||||||
filepaths = glob.glob(pattern)
|
|
||||||
|
filepaths = []
|
||||||
|
if is_s3(local_model_path):
|
||||||
|
file_pattern = f"*{self.pattern.format(rank=rank, part=" * ")}"
|
||||||
|
filepaths = s3_glob(path=local_model_path,
|
||||||
|
allow_pattern=[file_pattern])
|
||||||
|
else:
|
||||||
|
filepaths = glob.glob(pattern)
|
||||||
if not filepaths:
|
if not filepaths:
|
||||||
# TODO: support un-sharded checkpoints too
|
# TODO: support un-sharded checkpoints too
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Could not find checkpoint files '{pattern}', only "
|
f"Could not find checkpoint files '{pattern}', only "
|
||||||
f"pre-sharded checkpoints are currently supported!")
|
f"pre-sharded checkpoints are currently supported!")
|
||||||
state_dict = self._filter_subtensors(model.state_dict())
|
state_dict = self._filter_subtensors(model.state_dict())
|
||||||
for path in filepaths:
|
for key, tensor in self.iterate_over_files(filepaths):
|
||||||
with safe_open(path, framework="pt") as f:
|
# If loading with LoRA enabled, additional padding may
|
||||||
for key in f.keys(): # noqa: SIM118
|
# be added to certain parameters. We only load into a
|
||||||
tensor = f.get_tensor(key)
|
# narrowed view of the parameter data.
|
||||||
# If loading with LoRA enabled, additional padding may
|
param_data = state_dict[key].data
|
||||||
# be added to certain parameters. We only load into a
|
param_shape = state_dict[key].shape
|
||||||
# narrowed view of the parameter data.
|
for dim, size in enumerate(tensor.shape):
|
||||||
param_data = state_dict[key].data
|
if size < param_shape[dim]:
|
||||||
param_shape = state_dict[key].shape
|
param_data = param_data.narrow(dim, 0, size)
|
||||||
for dim, size in enumerate(tensor.shape):
|
if tensor.shape != param_shape:
|
||||||
if size < param_shape[dim]:
|
logger.warning(
|
||||||
param_data = param_data.narrow(dim, 0, size)
|
"loading tensor of shape %s into "
|
||||||
if tensor.shape != param_shape:
|
"parameter '%s' of shape %s",
|
||||||
logger.warning(
|
tensor.shape,
|
||||||
"loading tensor of shape %s into "
|
key,
|
||||||
"parameter '%s' of shape %s",
|
param_shape,
|
||||||
tensor.shape,
|
)
|
||||||
key,
|
param_data.copy_(tensor)
|
||||||
param_shape,
|
state_dict.pop(key)
|
||||||
)
|
|
||||||
param_data.copy_(tensor)
|
|
||||||
state_dict.pop(key)
|
|
||||||
if state_dict:
|
if state_dict:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Missing keys {tuple(state_dict)} in loaded state!")
|
f"Missing keys {tuple(state_dict)} in loaded state!")
|
||||||
return model.eval()
|
return model.eval()
|
||||||
|
|
||||||
|
def iterate_over_files(
|
||||||
|
self, paths) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||||
|
if self.runai_model_streamer:
|
||||||
|
yield from runai_safetensors_weights_iterator(paths, True)
|
||||||
|
else:
|
||||||
|
from safetensors.torch import safe_open
|
||||||
|
for path in paths:
|
||||||
|
with safe_open(path, framework="pt") as f:
|
||||||
|
for key in f.keys(): # noqa: SIM118
|
||||||
|
tensor = f.get_tensor(key)
|
||||||
|
yield key, tensor
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def save_model(
|
def save_model(
|
||||||
model: torch.nn.Module,
|
model: torch.nn.Module,
|
||||||
@ -1515,4 +1536,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
|
|||||||
if load_config.load_format == LoadFormat.RUNAI_STREAMER:
|
if load_config.load_format == LoadFormat.RUNAI_STREAMER:
|
||||||
return RunaiModelStreamerLoader(load_config)
|
return RunaiModelStreamerLoader(load_config)
|
||||||
|
|
||||||
|
if load_config.load_format == LoadFormat.RUNAI_STREAMER_SHARDED:
|
||||||
|
return ShardedStateLoader(load_config, runai_model_streamer=True)
|
||||||
|
|
||||||
return DefaultModelLoader(load_config)
|
return DefaultModelLoader(load_config)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user