Fix typing for safetensors_load_strategy (#24641)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor 2025-09-11 11:41:39 +01:00 committed by GitHub
parent 25bb9e8c65
commit d6249d0699
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 3 additions and 4 deletions

View File

@@ -51,7 +51,7 @@ class LoadConfig:
     download_dir: Optional[str] = None
     """Directory to download and load the weights, default to the default
     cache directory of Hugging Face."""
-    safetensors_load_strategy: Optional[str] = "lazy"
+    safetensors_load_strategy: str = "lazy"
     """Specifies the loading strategy for safetensors weights.
     - "lazy" (default): Weights are memory-mapped from the file. This enables
     on-demand loading and is highly efficient for models on local storage.

View File

@@ -289,8 +289,7 @@ class EngineArgs:
     trust_remote_code: bool = ModelConfig.trust_remote_code
     allowed_local_media_path: str = ModelConfig.allowed_local_media_path
     download_dir: Optional[str] = LoadConfig.download_dir
-    safetensors_load_strategy: Optional[
-        str] = LoadConfig.safetensors_load_strategy
+    safetensors_load_strategy: str = LoadConfig.safetensors_load_strategy
     load_format: Union[str, LoadFormats] = LoadConfig.load_format
     config_format: str = ModelConfig.config_format
     dtype: ModelDType = ModelConfig.dtype

View File

@@ -519,7 +519,7 @@ def np_cache_weights_iterator(
 def safetensors_weights_iterator(
     hf_weights_files: list[str],
     use_tqdm_on_load: bool,
-    safetensors_load_strategy: Optional[str] = "lazy",
+    safetensors_load_strategy: str = "lazy",
 ) -> Generator[tuple[str, torch.Tensor], None, None]:
     """Iterate over the weights in the model safetensor files."""
     loading_desc = "Loading safetensors checkpoint shards"