mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-14 19:15:01 +08:00
Fix typing for safetensors_load_strategy (#24641)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
parent
25bb9e8c65
commit
d6249d0699
@@ -51,7 +51,7 @@ class LoadConfig:
|
|||||||
download_dir: Optional[str] = None
|
download_dir: Optional[str] = None
|
||||||
"""Directory to download and load the weights, default to the default
|
"""Directory to download and load the weights, default to the default
|
||||||
cache directory of Hugging Face."""
|
cache directory of Hugging Face."""
|
||||||
safetensors_load_strategy: Optional[str] = "lazy"
|
safetensors_load_strategy: str = "lazy"
|
||||||
"""Specifies the loading strategy for safetensors weights.
|
"""Specifies the loading strategy for safetensors weights.
|
||||||
- "lazy" (default): Weights are memory-mapped from the file. This enables
|
- "lazy" (default): Weights are memory-mapped from the file. This enables
|
||||||
on-demand loading and is highly efficient for models on local storage.
|
on-demand loading and is highly efficient for models on local storage.
|
||||||
|
|||||||
@@ -289,8 +289,7 @@ class EngineArgs:
|
|||||||
trust_remote_code: bool = ModelConfig.trust_remote_code
|
trust_remote_code: bool = ModelConfig.trust_remote_code
|
||||||
allowed_local_media_path: str = ModelConfig.allowed_local_media_path
|
allowed_local_media_path: str = ModelConfig.allowed_local_media_path
|
||||||
download_dir: Optional[str] = LoadConfig.download_dir
|
download_dir: Optional[str] = LoadConfig.download_dir
|
||||||
safetensors_load_strategy: Optional[
|
safetensors_load_strategy: str = LoadConfig.safetensors_load_strategy
|
||||||
str] = LoadConfig.safetensors_load_strategy
|
|
||||||
load_format: Union[str, LoadFormats] = LoadConfig.load_format
|
load_format: Union[str, LoadFormats] = LoadConfig.load_format
|
||||||
config_format: str = ModelConfig.config_format
|
config_format: str = ModelConfig.config_format
|
||||||
dtype: ModelDType = ModelConfig.dtype
|
dtype: ModelDType = ModelConfig.dtype
|
||||||
|
|||||||
@@ -519,7 +519,7 @@ def np_cache_weights_iterator(
|
|||||||
def safetensors_weights_iterator(
|
def safetensors_weights_iterator(
|
||||||
hf_weights_files: list[str],
|
hf_weights_files: list[str],
|
||||||
use_tqdm_on_load: bool,
|
use_tqdm_on_load: bool,
|
||||||
safetensors_load_strategy: Optional[str] = "lazy",
|
safetensors_load_strategy: str = "lazy",
|
||||||
) -> Generator[tuple[str, torch.Tensor], None, None]:
|
) -> Generator[tuple[str, torch.Tensor], None, None]:
|
||||||
"""Iterate over the weights in the model safetensor files."""
|
"""Iterate over the weights in the model safetensor files."""
|
||||||
loading_desc = "Loading safetensors checkpoint shards"
|
loading_desc = "Loading safetensors checkpoint shards"
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user