mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 02:15:01 +08:00
Signed-off-by: Omer Dayan (SW-GPU) <omer@run.ai> Signed-off-by: Peter Schuurman <psch@google.com> Co-authored-by: Omer Dayan (SW-GPU) <omer@run.ai> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
40 lines
1.3 KiB
Python
40 lines
1.3 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
import glob
|
|
import os
|
|
import tempfile
|
|
|
|
import huggingface_hub.constants
|
|
|
|
from vllm.model_executor.model_loader.weight_utils import (
|
|
download_weights_from_hf)
|
|
from vllm.transformers_utils.runai_utils import (is_runai_obj_uri,
|
|
list_safetensors)
|
|
|
|
|
|
def test_is_runai_obj_uri():
|
|
assert is_runai_obj_uri("gs://some-gcs-bucket/path")
|
|
assert is_runai_obj_uri("s3://some-s3-bucket/path")
|
|
assert not is_runai_obj_uri("nfs://some-nfs-path")
|
|
|
|
|
|
def test_runai_list_safetensors_local():
|
|
with tempfile.TemporaryDirectory() as tmpdir:
|
|
huggingface_hub.constants.HF_HUB_OFFLINE = False
|
|
download_weights_from_hf("openai-community/gpt2",
|
|
allow_patterns=["*.safetensors", "*.json"],
|
|
cache_dir=tmpdir)
|
|
safetensors = glob.glob(f"{tmpdir}/**/*.safetensors", recursive=True)
|
|
assert len(safetensors) > 0
|
|
parentdir = [
|
|
os.path.dirname(safetensor) for safetensor in safetensors
|
|
][0]
|
|
files = list_safetensors(parentdir)
|
|
assert len(safetensors) == len(files)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
test_is_runai_obj_uri()
|
|
test_runai_list_safetensors_local()
|