mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-04 18:07:28 +08:00
Support S3 Sharded loading with RunAI Model Streamer (#16317)
Signed-off-by: Omer Dayan (SW-GPU) <omer@run.ai> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
This commit is contained in:
parent
188b7f9b8c
commit
71ce44047f
@ -1489,6 +1489,7 @@ class LoadFormat(str, enum.Enum):
|
||||
BITSANDBYTES = "bitsandbytes"
|
||||
MISTRAL = "mistral"
|
||||
RUNAI_STREAMER = "runai_streamer"
|
||||
RUNAI_STREAMER_SHARDED = "runai_streamer_sharded"
|
||||
FASTSAFETENSORS = "fastsafetensors"
|
||||
|
||||
|
||||
|
||||
@ -611,8 +611,12 @@ class ShardedStateLoader(BaseModelLoader):
|
||||
|
||||
DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"
|
||||
|
||||
def __init__(self, load_config: LoadConfig):
|
||||
def __init__(self,
|
||||
load_config: LoadConfig,
|
||||
runai_model_streamer: bool = False):
|
||||
super().__init__(load_config)
|
||||
|
||||
self.runai_model_streamer = runai_model_streamer
|
||||
extra_config = ({} if load_config.model_loader_extra_config is None
|
||||
else load_config.model_loader_extra_config.copy())
|
||||
self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN)
|
||||
@ -659,7 +663,7 @@ class ShardedStateLoader(BaseModelLoader):
|
||||
|
||||
def _prepare_weights(self, model_name_or_path: str,
|
||||
revision: Optional[str]):
|
||||
if os.path.isdir(model_name_or_path):
|
||||
if is_s3(model_name_or_path) or os.path.isdir(model_name_or_path):
|
||||
return model_name_or_path
|
||||
else:
|
||||
allow_patterns = ["*.safetensors"]
|
||||
@ -678,12 +682,13 @@ class ShardedStateLoader(BaseModelLoader):
|
||||
device_config = vllm_config.device_config
|
||||
model_config = vllm_config.model_config
|
||||
target_device = torch.device(device_config.device)
|
||||
from safetensors.torch import safe_open
|
||||
|
||||
from vllm.distributed import get_tensor_model_parallel_rank
|
||||
|
||||
local_model_path = self._prepare_weights(model_config.model,
|
||||
model_config.revision)
|
||||
model_weights = model_config.model
|
||||
if hasattr(model_config, "model_weights"):
|
||||
model_weights = model_config.model_weights
|
||||
local_model_path = model_weights
|
||||
|
||||
with set_default_torch_dtype(model_config.dtype):
|
||||
with target_device:
|
||||
@ -695,40 +700,56 @@ class ShardedStateLoader(BaseModelLoader):
|
||||
local_model_path,
|
||||
self.pattern.format(rank=rank, part="*"),
|
||||
)
|
||||
filepaths = glob.glob(pattern)
|
||||
|
||||
filepaths = []
|
||||
if is_s3(local_model_path):
|
||||
file_pattern = f"*{self.pattern.format(rank=rank, part=" * ")}"
|
||||
filepaths = s3_glob(path=local_model_path,
|
||||
allow_pattern=[file_pattern])
|
||||
else:
|
||||
filepaths = glob.glob(pattern)
|
||||
if not filepaths:
|
||||
# TODO: support un-sharded checkpoints too
|
||||
raise ValueError(
|
||||
f"Could not find checkpoint files '{pattern}', only "
|
||||
f"pre-sharded checkpoints are currently supported!")
|
||||
state_dict = self._filter_subtensors(model.state_dict())
|
||||
for path in filepaths:
|
||||
with safe_open(path, framework="pt") as f:
|
||||
for key in f.keys(): # noqa: SIM118
|
||||
tensor = f.get_tensor(key)
|
||||
# If loading with LoRA enabled, additional padding may
|
||||
# be added to certain parameters. We only load into a
|
||||
# narrowed view of the parameter data.
|
||||
param_data = state_dict[key].data
|
||||
param_shape = state_dict[key].shape
|
||||
for dim, size in enumerate(tensor.shape):
|
||||
if size < param_shape[dim]:
|
||||
param_data = param_data.narrow(dim, 0, size)
|
||||
if tensor.shape != param_shape:
|
||||
logger.warning(
|
||||
"loading tensor of shape %s into "
|
||||
"parameter '%s' of shape %s",
|
||||
tensor.shape,
|
||||
key,
|
||||
param_shape,
|
||||
)
|
||||
param_data.copy_(tensor)
|
||||
state_dict.pop(key)
|
||||
for key, tensor in self.iterate_over_files(filepaths):
|
||||
# If loading with LoRA enabled, additional padding may
|
||||
# be added to certain parameters. We only load into a
|
||||
# narrowed view of the parameter data.
|
||||
param_data = state_dict[key].data
|
||||
param_shape = state_dict[key].shape
|
||||
for dim, size in enumerate(tensor.shape):
|
||||
if size < param_shape[dim]:
|
||||
param_data = param_data.narrow(dim, 0, size)
|
||||
if tensor.shape != param_shape:
|
||||
logger.warning(
|
||||
"loading tensor of shape %s into "
|
||||
"parameter '%s' of shape %s",
|
||||
tensor.shape,
|
||||
key,
|
||||
param_shape,
|
||||
)
|
||||
param_data.copy_(tensor)
|
||||
state_dict.pop(key)
|
||||
if state_dict:
|
||||
raise ValueError(
|
||||
f"Missing keys {tuple(state_dict)} in loaded state!")
|
||||
return model.eval()
|
||||
|
||||
def iterate_over_files(
|
||||
self, paths) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||
if self.runai_model_streamer:
|
||||
yield from runai_safetensors_weights_iterator(paths, True)
|
||||
else:
|
||||
from safetensors.torch import safe_open
|
||||
for path in paths:
|
||||
with safe_open(path, framework="pt") as f:
|
||||
for key in f.keys(): # noqa: SIM118
|
||||
tensor = f.get_tensor(key)
|
||||
yield key, tensor
|
||||
|
||||
@staticmethod
|
||||
def save_model(
|
||||
model: torch.nn.Module,
|
||||
@ -1515,4 +1536,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
|
||||
if load_config.load_format == LoadFormat.RUNAI_STREAMER:
|
||||
return RunaiModelStreamerLoader(load_config)
|
||||
|
||||
if load_config.load_format == LoadFormat.RUNAI_STREAMER_SHARDED:
|
||||
return ShardedStateLoader(load_config, runai_model_streamer=True)
|
||||
|
||||
return DefaultModelLoader(load_config)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user