[Bugfix] Update Run:AI Model Streamer Loading Integration (#23845)

Signed-off-by: Omer Dayan (SW-GPU) <omer@run.ai>
Signed-off-by: Peter Schuurman <psch@google.com>
Co-authored-by: Omer Dayan (SW-GPU) <omer@run.ai>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>

Commit 4377b1ae3b (parent 009d689b0c)
Authored by pwschuurman on 2025-09-09 21:37:17 -07:00; committed by GitHub.
7 changed files with 187 additions and 122 deletions

setup.py

@@ -656,8 +656,10 @@ setup(
         "bench": ["pandas", "datasets"],
         "tensorizer": ["tensorizer==2.10.1"],
         "fastsafetensors": ["fastsafetensors >= 0.1.10"],
-        "runai":
-        ["runai-model-streamer >= 0.13.3", "runai-model-streamer-s3", "boto3"],
+        "runai": [
+            "runai-model-streamer >= 0.14.0", "runai-model-streamer-gcs",
+            "google-cloud-storage", "runai-model-streamer-s3", "boto3"
+        ],
         "audio": ["librosa", "soundfile",
                   "mistral_common[audio]"],  # Required for audio processing
         "video": [],  # Kept for backwards compatibility

New test file for vllm/transformers_utils/runai_utils.py

@@ -0,0 +1,39 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import glob
+import os
+import tempfile
+
+import huggingface_hub.constants
+
+from vllm.model_executor.model_loader.weight_utils import (
+    download_weights_from_hf)
+from vllm.transformers_utils.runai_utils import (is_runai_obj_uri,
+                                                 list_safetensors)
+
+
+def test_is_runai_obj_uri():
+    assert is_runai_obj_uri("gs://some-gcs-bucket/path")
+    assert is_runai_obj_uri("s3://some-s3-bucket/path")
+    assert not is_runai_obj_uri("nfs://some-nfs-path")
+
+
+def test_runai_list_safetensors_local():
+    with tempfile.TemporaryDirectory() as tmpdir:
+        huggingface_hub.constants.HF_HUB_OFFLINE = False
+        download_weights_from_hf("openai-community/gpt2",
+                                 allow_patterns=["*.safetensors", "*.json"],
+                                 cache_dir=tmpdir)
+        safetensors = glob.glob(f"{tmpdir}/**/*.safetensors", recursive=True)
+        assert len(safetensors) > 0
+        parentdir = [
+            os.path.dirname(safetensor) for safetensor in safetensors
+        ][0]
+        files = list_safetensors(parentdir)
+        assert len(safetensors) == len(files)
+
+
+if __name__ == "__main__":
+    test_is_runai_obj_uri()
+    test_runai_list_safetensors_local()
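
These tests back the end-to-end flow where an engine points directly at object storage. A minimal usage sketch, assuming the runai extra is installed, credentials are configured, and gs://my-bucket/my-model (an illustrative path) holds safetensors weights plus config files:

    from vllm import LLM

    # The gs:// scheme is recognized by is_runai_obj_uri(); config and
    # tokenizer files are mirrored to a temporary directory while the
    # safetensors weights are streamed by the Run:AI model streamer.
    llm = LLM(model="gs://my-bucket/my-model", load_format="runai_streamer")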

vllm/config.py

@@ -48,8 +48,9 @@ from vllm.transformers_utils.config import (
     is_interleaved, maybe_override_with_speculators_target_model,
     try_get_generation_config, try_get_safetensors_metadata,
     try_get_tokenizer_config, uses_mrope)
-from vllm.transformers_utils.s3_utils import S3Model
-from vllm.transformers_utils.utils import is_s3, maybe_model_redirect
+from vllm.transformers_utils.runai_utils import (ObjectStorageModel,
+                                                 is_runai_obj_uri)
+from vllm.transformers_utils.utils import maybe_model_redirect
 from vllm.utils import (DEFAULT_MAX_NUM_BATCHED_TOKENS,
                         STR_DUAL_CHUNK_FLASH_ATTN_VAL, LayerBlockType,
                         LazyLoader, common_broadcastable_dtype, random_uuid)
@@ -556,15 +557,6 @@ class ModelConfig:
                 "affect the random state of the Python process that "
                 "launched vLLM.", self.seed)
 
-        if self.runner != "draft":
-            # If we're not running the draft model, check for speculators config
-            # If speculators config, set model / tokenizer to be target model
-            self.model, self.tokenizer = maybe_override_with_speculators_target_model(  # noqa: E501
-                model=self.model,
-                tokenizer=self.tokenizer,
-                revision=self.revision,
-                trust_remote_code=self.trust_remote_code)
-
         # Keep set served_model_name before maybe_model_redirect(self.model)
         self.served_model_name = get_served_model_name(self.model,
                                                        self.served_model_name)
@@ -603,7 +595,16 @@ class ModelConfig:
                     f"'Please instead use `--hf-overrides '{hf_overrides_str}'`")
                 warnings.warn(DeprecationWarning(msg), stacklevel=2)
 
-        self.maybe_pull_model_tokenizer_for_s3(self.model, self.tokenizer)
+        self.maybe_pull_model_tokenizer_for_runai(self.model, self.tokenizer)
+
+        if self.runner != "draft":
+            # If we're not running the draft model, check for speculators config
+            # If speculators config, set model / tokenizer to be target model
+            self.model, self.tokenizer = maybe_override_with_speculators_target_model(  # noqa: E501
+                model=self.model,
+                tokenizer=self.tokenizer,
+                revision=self.revision,
+                trust_remote_code=self.trust_remote_code)
 
         if (backend := envs.VLLM_ATTENTION_BACKEND
             ) and backend == "FLASHINFER" and find_spec("flashinfer") is None:
@@ -832,41 +833,42 @@ class ModelConfig:
         """The architecture vllm actually used."""
         return self._architecture
 
-    def maybe_pull_model_tokenizer_for_s3(self, model: str,
-                                          tokenizer: str) -> None:
-        """Pull model/tokenizer from S3 to temporary directory when needed.
+    def maybe_pull_model_tokenizer_for_runai(self, model: str,
+                                             tokenizer: str) -> None:
+        """Pull model/tokenizer from Object Storage to temporary
+        directory when needed.
 
         Args:
             model: Model name or path
             tokenizer: Tokenizer name or path
         """
-        if not (is_s3(model) or is_s3(tokenizer)):
+        if not (is_runai_obj_uri(model) or is_runai_obj_uri(tokenizer)):
             return
 
-        if is_s3(model):
-            s3_model = S3Model()
-            s3_model.pull_files(model,
-                                allow_pattern=["*.model", "*.py", "*.json"])
+        if is_runai_obj_uri(model):
+            object_storage_model = ObjectStorageModel()
+            object_storage_model.pull_files(
+                model, allow_pattern=["*.model", "*.py", "*.json"])
             self.model_weights = model
-            self.model = s3_model.dir
+            self.model = object_storage_model.dir
 
             # If tokenizer is same as model, download to same directory
             if model == tokenizer:
-                s3_model.pull_files(model,
-                                    ignore_pattern=[
-                                        "*.pt", "*.safetensors", "*.bin",
-                                        "*.tensors"
-                                    ])
-                self.tokenizer = s3_model.dir
+                object_storage_model.pull_files(model,
+                                                ignore_pattern=[
+                                                    "*.pt", "*.safetensors",
+                                                    "*.bin", "*.tensors"
+                                                ])
+                self.tokenizer = object_storage_model.dir
                 return
 
         # Only download tokenizer if needed and not already handled
-        if is_s3(tokenizer):
-            s3_tokenizer = S3Model()
-            s3_tokenizer.pull_files(
+        if is_runai_obj_uri(tokenizer):
+            object_storage_tokenizer = ObjectStorageModel()
+            object_storage_tokenizer.pull_files(
                 model,
                 ignore_pattern=["*.pt", "*.safetensors", "*.bin", "*.tensors"])
-            self.tokenizer = s3_tokenizer.dir
+            self.tokenizer = object_storage_tokenizer.dir
 
     def _init_multimodal_config(self) -> Optional["MultiModalConfig"]:
         if self._model_info.supports_multimodal:
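
The net effect of the renamed hook, sketched for a model and tokenizer sharing one bucket path (the temporary directory name is illustrative):

    # Given model == tokenizer == "gs://my-bucket/my-model":
    #   self.model_weights -> "gs://my-bucket/my-model" (weights stay remote)
    #   self.model         -> "/tmp/tmpab12cd"  (mirrored *.model/*.py/*.json)
    #   self.tokenizer     -> "/tmp/tmpab12cd"  (same directory is reused)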

vllm/engine/arg_utils.py

@@ -1053,9 +1053,10 @@ class EngineArgs:
             SpeculatorsConfig)
 
         if self.speculative_config is None:
-            hf_config = get_config(self.hf_config_path or self.model,
-                                   self.trust_remote_code, self.revision,
-                                   self.code_revision, self.config_format)
+            hf_config = get_config(
+                self.hf_config_path or target_model_config.model,
+                self.trust_remote_code, self.revision, self.code_revision,
+                self.config_format)
 
             # if loading a SpeculatorsConfig, load the speculative_config
             # details from the config directly
@@ -1065,7 +1066,7 @@ class EngineArgs:
                 self.speculative_config = {}
                 self.speculative_config[
                     "num_speculative_tokens"] = hf_config.num_lookahead_tokens
-                self.speculative_config["model"] = self.model
+                self.speculative_config["model"] = target_model_config.model
                 self.speculative_config["method"] = hf_config.method
             else:
                 return None
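
Reading target_model_config.model instead of self.model appears to be the motivation for the bugfix here: ModelConfig.__post_init__ may have pulled an object-storage model into a local temporary directory (and applied the speculators override), so target_model_config.model is a path get_config can actually read, while EngineArgs' own self.model may still be the raw s3:// or gs:// URI.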

vllm/model_executor/model_loader/runai_streamer_loader.py

@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 # ruff: noqa: SIM117
-import glob
 import os
 from collections.abc import Generator
 from typing import Optional
@@ -15,8 +14,8 @@ from vllm.model_executor.model_loader.base_loader import BaseModelLoader
 from vllm.model_executor.model_loader.weight_utils import (
     download_safetensors_index_file_from_hf, download_weights_from_hf,
     runai_safetensors_weights_iterator)
-from vllm.transformers_utils.s3_utils import glob as s3_glob
-from vllm.transformers_utils.utils import is_s3
+from vllm.transformers_utils.runai_utils import (is_runai_obj_uri,
+                                                 list_safetensors)
 
 
 class RunaiModelStreamerLoader(BaseModelLoader):
@@ -53,27 +52,22 @@ class RunaiModelStreamerLoader(BaseModelLoader):
         If the model is not local, it will be downloaded."""
-        is_s3_path = is_s3(model_name_or_path)
+        is_object_storage_path = is_runai_obj_uri(model_name_or_path)
         is_local = os.path.isdir(model_name_or_path)
         safetensors_pattern = "*.safetensors"
         index_file = SAFE_WEIGHTS_INDEX_NAME
 
-        hf_folder = (model_name_or_path if
-                     (is_local or is_s3_path) else download_weights_from_hf(
+        hf_folder = (model_name_or_path if (is_local or is_object_storage_path)
+                     else download_weights_from_hf(
                          model_name_or_path,
                          self.load_config.download_dir,
                          [safetensors_pattern],
                          revision,
                          ignore_patterns=self.load_config.ignore_patterns,
                      ))
-        if is_s3_path:
-            hf_weights_files = s3_glob(path=hf_folder,
-                                       allow_pattern=[safetensors_pattern])
-        else:
-            hf_weights_files = glob.glob(
-                os.path.join(hf_folder, safetensors_pattern))
+        hf_weights_files = list_safetensors(path=hf_folder)
 
-        if not is_local and not is_s3_path:
+        if not is_local and not is_object_storage_path:
             download_safetensors_index_file_from_hf(
                 model_name_or_path, index_file, self.load_config.download_dir,
                 revision)
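
Collapsing the is_s3_path branch works because runai_model_streamer's list_safetensors handles local directories as well as s3:// and gs:// URIs, which is exactly what the new test_runai_list_safetensors_local test above exercises.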

vllm/transformers_utils/runai_utils.py (new file)

@@ -0,0 +1,99 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import os
+import shutil
+import signal
+import tempfile
+from typing import Optional
+
+from vllm.logger import init_logger
+from vllm.utils import PlaceholderModule
+
+logger = init_logger(__name__)
+
+SUPPORTED_SCHEMES = ['s3://', 'gs://']
+
+try:
+    from runai_model_streamer import list_safetensors as runai_list_safetensors
+    from runai_model_streamer import pull_files as runai_pull_files
+except (ImportError, OSError):
+    # see https://github.com/run-ai/runai-model-streamer/issues/26
+    # OSError will be raised on arm64 platform
+    runai_model_streamer = PlaceholderModule(
+        "runai_model_streamer")  # type: ignore[assignment]
+    runai_pull_files = runai_model_streamer.placeholder_attr("pull_files")
+    runai_list_safetensors = runai_model_streamer.placeholder_attr(
+        "list_safetensors")
+
+
+def list_safetensors(path: str = "") -> list[str]:
+    """
+    List full file names from object path and filter by allow pattern.
+
+    Args:
+        path: The object storage path to list from.
+
+    Returns:
+        list[str]: List of full object storage paths allowed by the pattern
+    """
+    return runai_list_safetensors(path)
+
+
+def is_runai_obj_uri(model_or_path: str) -> bool:
+    return model_or_path.lower().startswith(tuple(SUPPORTED_SCHEMES))
+
+
+class ObjectStorageModel:
+    """
+    A class representing an ObjectStorage model mirrored into a
+    temporary directory.
+
+    Attributes:
+        dir: The temporary created directory.
+
+    Methods:
+        pull_files(): Pull model from object storage to the temporary
+                      directory.
+    """
+
+    def __init__(self) -> None:
+        for sig in (signal.SIGINT, signal.SIGTERM):
+            existing_handler = signal.getsignal(sig)
+            signal.signal(sig, self._close_by_signal(existing_handler))
+
+        self.dir = tempfile.mkdtemp()
+
+    def __del__(self):
+        self._close()
+
+    def _close(self) -> None:
+        if os.path.exists(self.dir):
+            shutil.rmtree(self.dir)
+
+    def _close_by_signal(self, existing_handler=None):
+
+        def new_handler(signum, frame):
+            self._close()
+            if existing_handler:
+                existing_handler(signum, frame)
+
+        return new_handler
+
+    def pull_files(self,
+                   model_path: str = "",
+                   allow_pattern: Optional[list[str]] = None,
+                   ignore_pattern: Optional[list[str]] = None) -> None:
+        """
+        Pull files from object storage into the temporary directory.
+
+        Args:
+            model_path: The object storage path of the model.
+            allow_pattern: A list of patterns of which files to pull.
+            ignore_pattern: A list of patterns of which files not to pull.
+        """
+        if not model_path.endswith("/"):
+            model_path = model_path + "/"
+        runai_pull_files(model_path, self.dir, allow_pattern, ignore_pattern)
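
A short usage sketch of the new module, assuming object-storage credentials are configured and using an illustrative bucket path:

    from vllm.transformers_utils.runai_utils import (ObjectStorageModel,
                                                     is_runai_obj_uri)

    uri = "gs://my-bucket/my-model"
    assert is_runai_obj_uri(uri)  # recognizes the s3:// and gs:// schemes

    tmp = ObjectStorageModel()  # temp dir plus SIGINT/SIGTERM cleanup hooks
    # Mirror only small config/tokenizer files; weights are streamed later.
    tmp.pull_files(uri, allow_pattern=["*.model", "*.py", "*.json"])
    print(tmp.dir)  # temporary directory now holding the pulled files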

vllm/transformers_utils/s3_utils.py

@@ -2,11 +2,6 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import fnmatch
 import os
-import shutil
-import signal
-import tempfile
-from pathlib import Path
 from typing import Optional
 
 from vllm.utils import PlaceholderModule
@@ -93,70 +88,3 @@ def list_files(
     paths = _filter_ignore(paths, ignore_pattern)
     return bucket_name, prefix, paths
-
-
-class S3Model:
-    """
-    A class representing a S3 model mirrored into a temporary directory.
-
-    Attributes:
-        s3: S3 client.
-        dir: The temporary created directory.
-
-    Methods:
-        pull_files(): Pull model from S3 to the temporary directory.
-    """
-
-    def __init__(self) -> None:
-        self.s3 = boto3.client('s3')
-        for sig in (signal.SIGINT, signal.SIGTERM):
-            existing_handler = signal.getsignal(sig)
-            signal.signal(sig, self._close_by_signal(existing_handler))
-
-        self.dir = tempfile.mkdtemp()
-
-    def __del__(self):
-        self._close()
-
-    def _close(self) -> None:
-        if os.path.exists(self.dir):
-            shutil.rmtree(self.dir)
-
-    def _close_by_signal(self, existing_handler=None):
-
-        def new_handler(signum, frame):
-            self._close()
-            if existing_handler:
-                existing_handler(signum, frame)
-
-        return new_handler
-
-    def pull_files(self,
-                   s3_model_path: str = "",
-                   allow_pattern: Optional[list[str]] = None,
-                   ignore_pattern: Optional[list[str]] = None) -> None:
-        """
-        Pull files from S3 storage into the temporary directory.
-
-        Args:
-            s3_model_path: The S3 path of the model.
-            allow_pattern: A list of patterns of which files to pull.
-            ignore_pattern: A list of patterns of which files not to pull.
-        """
-        if not s3_model_path.endswith("/"):
-            s3_model_path = s3_model_path + "/"
-        bucket_name, base_dir, files = list_files(self.s3, s3_model_path,
-                                                  allow_pattern,
-                                                  ignore_pattern)
-        if len(files) == 0:
-            return
-
-        for file in files:
-            destination_file = os.path.join(
-                self.dir,
-                file.removeprefix(base_dir).lstrip("/"))
-            local_dir = Path(destination_file).parent
-            os.makedirs(local_dir, exist_ok=True)
-            self.s3.download_file(bucket_name, file, destination_file)
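
The module keeps its listing helpers (list_files and the pattern filters); model mirroring now lives in ObjectStorageModel in vllm/transformers_utils/runai_utils.py, which ModelConfig above uses in place of the removed S3Model.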