mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 05:25:00 +08:00
63 lines
2.0 KiB
Python
63 lines
2.0 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
|
|
import tempfile
|
|
from pathlib import Path
|
|
from unittest.mock import MagicMock, call, patch
|
|
|
|
import pytest
|
|
|
|
from vllm.transformers_utils.repo_utils import list_filtered_repo_files
|
|
|
|
|
|
@pytest.mark.parametrize(
    "allow_patterns,expected_relative_files",
    [
        (
            ["*.json", "correct*.txt"],
            ["json_file.json", "subfolder/correct.txt", "correct_2.txt"],
        ),
    ],
)
def test_list_filtered_repo_files(
    allow_patterns: list[str], expected_relative_files: list[str]
):
    """Check that ``list_filtered_repo_files`` keeps only matching files.

    ``list_repo_files`` (as looked up inside
    ``vllm.transformers_utils.repo_utils``) is patched to return the files of
    a local temporary tree. The test asserts both the filtered result and the
    exact keyword arguments forwarded to ``list_repo_files``.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Build a small fake repo: some files match the patterns, some don't.
        path_tmp_dir = Path(tmp_dir)
        subfolder = path_tmp_dir / "subfolder"
        subfolder.mkdir()
        (path_tmp_dir / "json_file.json").touch()
        (path_tmp_dir / "correct_2.txt").touch()
        (path_tmp_dir / "uncorrect.txt").touch()
        (path_tmp_dir / "uncorrect.jpeg").touch()
        (subfolder / "correct.txt").touch()
        (subfolder / "uncorrect_sub.txt").touch()

        # The expected paths use "/" separators ("subfolder/correct.txt"),
        # so emit POSIX-style relative paths; plain str() would produce "\"
        # separators on Windows and the comparison could never succeed.
        repo_files = [
            file.relative_to(path_tmp_dir).as_posix()
            for file in path_tmp_dir.rglob("*")
            if file.is_file()
        ]

        # Patch list_repo_files called by fn
        with patch(
            "vllm.transformers_utils.repo_utils.list_repo_files",
            MagicMock(return_value=repo_files),
        ) as mock_list_repo_files:
            out_files = sorted(
                list_filtered_repo_files(
                    tmp_dir, allow_patterns, "revision", "model", "token"
                )
            )

        assert out_files == sorted(expected_relative_files)
        # Exactly one call, with exactly these keyword arguments.
        assert mock_list_repo_files.call_args_list == [
            call(
                repo_id=tmp_dir,
                revision="revision",
                repo_type="model",
                token="token",
            )
        ]
|