mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 09:45:49 +08:00
[Bugfix] Add file lock for ModelScope download (#14060)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
parent
f64ffa8c25
commit
6a84164add
@ -14,6 +14,8 @@ from tqdm.asyncio import tqdm
|
|||||||
from transformers import (AutoTokenizer, PreTrainedTokenizer,
|
from transformers import (AutoTokenizer, PreTrainedTokenizer,
|
||||||
PreTrainedTokenizerFast)
|
PreTrainedTokenizerFast)
|
||||||
|
|
||||||
|
from vllm.model_executor.model_loader.weight_utils import get_lock
|
||||||
|
|
||||||
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
|
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
|
||||||
|
|
||||||
|
|
||||||
@ -430,12 +432,15 @@ def get_model(pretrained_model_name_or_path: str) -> str:
|
|||||||
if os.getenv('VLLM_USE_MODELSCOPE', 'False').lower() == 'true':
|
if os.getenv('VLLM_USE_MODELSCOPE', 'False').lower() == 'true':
|
||||||
from modelscope import snapshot_download
|
from modelscope import snapshot_download
|
||||||
|
|
||||||
model_path = snapshot_download(
|
# Use file lock to prevent multiple processes from
|
||||||
model_id=pretrained_model_name_or_path,
|
# downloading the same model files at the same time.
|
||||||
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
|
with get_lock(pretrained_model_name_or_path):
|
||||||
ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"])
|
model_path = snapshot_download(
|
||||||
|
model_id=pretrained_model_name_or_path,
|
||||||
|
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
|
||||||
|
ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"])
|
||||||
|
|
||||||
return model_path
|
return model_path
|
||||||
return pretrained_model_name_or_path
|
return pretrained_model_name_or_path
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -49,7 +49,7 @@ from vllm.model_executor.model_loader.utils import (ParamMapping,
|
|||||||
from vllm.model_executor.model_loader.weight_utils import (
|
from vllm.model_executor.model_loader.weight_utils import (
|
||||||
download_safetensors_index_file_from_hf, download_weights_from_hf,
|
download_safetensors_index_file_from_hf, download_weights_from_hf,
|
||||||
filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
|
filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
|
||||||
get_gguf_extra_tensor_names, gguf_quant_weights_iterator,
|
get_gguf_extra_tensor_names, get_lock, gguf_quant_weights_iterator,
|
||||||
initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator,
|
initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator,
|
||||||
runai_safetensors_weights_iterator, safetensors_weights_iterator)
|
runai_safetensors_weights_iterator, safetensors_weights_iterator)
|
||||||
from vllm.model_executor.utils import set_weight_attrs
|
from vllm.model_executor.utils import set_weight_attrs
|
||||||
@ -235,13 +235,17 @@ class DefaultModelLoader(BaseModelLoader):
|
|||||||
from modelscope.hub.snapshot_download import snapshot_download
|
from modelscope.hub.snapshot_download import snapshot_download
|
||||||
|
|
||||||
if not os.path.exists(model):
|
if not os.path.exists(model):
|
||||||
model_path = snapshot_download(
|
# Use file lock to prevent multiple processes from
|
||||||
model_id=model,
|
# downloading the same model weights at the same time.
|
||||||
cache_dir=self.load_config.download_dir,
|
with get_lock(model, self.load_config.download_dir):
|
||||||
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
|
model_path = snapshot_download(
|
||||||
revision=revision,
|
model_id=model,
|
||||||
ignore_file_pattern=self.load_config.ignore_patterns,
|
cache_dir=self.load_config.download_dir,
|
||||||
)
|
local_files_only=huggingface_hub.constants.
|
||||||
|
HF_HUB_OFFLINE,
|
||||||
|
revision=revision,
|
||||||
|
ignore_file_pattern=self.load_config.ignore_patterns,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
model_path = model
|
model_path = model
|
||||||
return model_path
|
return model_path
|
||||||
|
|||||||
@ -8,6 +8,7 @@ import os
|
|||||||
import tempfile
|
import tempfile
|
||||||
import time
|
import time
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
from pathlib import Path
|
||||||
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
|
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import filelock
|
import filelock
|
||||||
@ -67,8 +68,10 @@ class DisabledTqdm(tqdm):
|
|||||||
super().__init__(*args, **kwargs, disable=True)
|
super().__init__(*args, **kwargs, disable=True)
|
||||||
|
|
||||||
|
|
||||||
def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None):
|
def get_lock(model_name_or_path: Union[str, Path],
|
||||||
|
cache_dir: Optional[str] = None):
|
||||||
lock_dir = cache_dir or temp_dir
|
lock_dir = cache_dir or temp_dir
|
||||||
|
model_name_or_path = str(model_name_or_path)
|
||||||
os.makedirs(os.path.dirname(lock_dir), exist_ok=True)
|
os.makedirs(os.path.dirname(lock_dir), exist_ok=True)
|
||||||
model_name = model_name_or_path.replace("/", "-")
|
model_name = model_name_or_path.replace("/", "-")
|
||||||
hash_name = hashlib.sha256(model_name.encode()).hexdigest()
|
hash_name = hashlib.sha256(model_name.encode()).hexdigest()
|
||||||
|
|||||||
@ -150,16 +150,22 @@ def get_tokenizer(
|
|||||||
# pylint: disable=C.
|
# pylint: disable=C.
|
||||||
from modelscope.hub.snapshot_download import snapshot_download
|
from modelscope.hub.snapshot_download import snapshot_download
|
||||||
|
|
||||||
|
# avoid circular import
|
||||||
|
from vllm.model_executor.model_loader.weight_utils import get_lock
|
||||||
|
|
||||||
# Only set the tokenizer here, model will be downloaded on the workers.
|
# Only set the tokenizer here, model will be downloaded on the workers.
|
||||||
if not os.path.exists(tokenizer_name):
|
if not os.path.exists(tokenizer_name):
|
||||||
tokenizer_path = snapshot_download(
|
# Use file lock to prevent multiple processes from
|
||||||
model_id=tokenizer_name,
|
# downloading the same file at the same time.
|
||||||
cache_dir=download_dir,
|
with get_lock(tokenizer_name, download_dir):
|
||||||
revision=revision,
|
tokenizer_path = snapshot_download(
|
||||||
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
|
model_id=tokenizer_name,
|
||||||
# Ignore weights - we only need the tokenizer.
|
cache_dir=download_dir,
|
||||||
ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"])
|
revision=revision,
|
||||||
tokenizer_name = tokenizer_path
|
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
|
||||||
|
# Ignore weights - we only need the tokenizer.
|
||||||
|
ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"])
|
||||||
|
tokenizer_name = tokenizer_path
|
||||||
|
|
||||||
if tokenizer_mode == "slow":
|
if tokenizer_mode == "slow":
|
||||||
if kwargs.get("use_fast", False):
|
if kwargs.get("use_fast", False):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user