[Feature]: Improve GGUF loading from HuggingFace user experience like repo_id:quant_type (#29137)
Signed-off-by: Injae Ryou <injaeryou@gmail.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>

parent 0231ce836a
commit 794029f012
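With this change a GGUF checkpoint on the Hub can be referenced directly as <repo_id>:<quant_type> instead of downloading the .gguf file first. A minimal usage sketch; the model and tokenizer names come from the tests in this commit, and the entry points assume the standard vLLM API and CLI:

# Python API
from vllm import LLM

llm = LLM(model="unsloth/Qwen3-0.6B-GGUF:IQ1_S", tokenizer="Qwen/Qwen3-0.6B")

# CLI (equivalent)
# vllm serve unsloth/Qwen3-0.6B-GGUF:IQ1_S --tokenizer Qwen/Qwen3-0.6B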
tests/models/test_gguf_download.py (new file, +240 lines)

@@ -0,0 +1,240 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from unittest.mock import MagicMock, patch

import pytest

from vllm.config import ModelConfig
from vllm.config.load import LoadConfig
from vllm.model_executor.model_loader.gguf_loader import GGUFModelLoader
from vllm.model_executor.model_loader.weight_utils import download_gguf


class TestGGUFDownload:
    """Test GGUF model downloading functionality."""

    @patch("vllm.model_executor.model_loader.weight_utils.download_weights_from_hf")
    def test_download_gguf_single_file(self, mock_download):
        """Test downloading a single GGUF file."""
        # Setup mock
        mock_folder = "/tmp/mock_cache"
        mock_download.return_value = mock_folder

        # Mock glob to return a single file
        with patch("glob.glob") as mock_glob:
            mock_glob.side_effect = lambda pattern, **kwargs: (
                [f"{mock_folder}/model-IQ1_S.gguf"] if "IQ1_S" in pattern else []
            )

            result = download_gguf("unsloth/Qwen3-0.6B-GGUF", "IQ1_S")

        # Verify download_weights_from_hf was called with correct patterns
        mock_download.assert_called_once_with(
            model_name_or_path="unsloth/Qwen3-0.6B-GGUF",
            cache_dir=None,
            allow_patterns=[
                "*-IQ1_S.gguf",
                "*-IQ1_S-*.gguf",
                "*/*-IQ1_S.gguf",
                "*/*-IQ1_S-*.gguf",
            ],
            revision=None,
            ignore_patterns=None,
        )

        # Verify result is the file path, not the folder
        assert result == f"{mock_folder}/model-IQ1_S.gguf"

    @patch("vllm.model_executor.model_loader.weight_utils.download_weights_from_hf")
    def test_download_gguf_sharded_files(self, mock_download):
        """Test downloading sharded GGUF files."""
        mock_folder = "/tmp/mock_cache"
        mock_download.return_value = mock_folder

        # Mock glob to return sharded files
        with patch("glob.glob") as mock_glob:
            mock_glob.side_effect = lambda pattern, **kwargs: (
                [
                    f"{mock_folder}/model-Q2_K-00001-of-00002.gguf",
                    f"{mock_folder}/model-Q2_K-00002-of-00002.gguf",
                ]
                if "Q2_K" in pattern
                else []
            )

            result = download_gguf("unsloth/gpt-oss-120b-GGUF", "Q2_K")

        # Should return the first file after sorting
        assert result == f"{mock_folder}/model-Q2_K-00001-of-00002.gguf"

    @patch("vllm.model_executor.model_loader.weight_utils.download_weights_from_hf")
    def test_download_gguf_subdir(self, mock_download):
        """Test downloading GGUF files from a subdirectory."""
        mock_folder = "/tmp/mock_cache"
        mock_download.return_value = mock_folder

        with patch("glob.glob") as mock_glob:
            mock_glob.side_effect = lambda pattern, **kwargs: (
                [f"{mock_folder}/Q2_K/model-Q2_K.gguf"]
                if "Q2_K" in pattern or "**/*.gguf" in pattern
                else []
            )

            result = download_gguf("unsloth/gpt-oss-120b-GGUF", "Q2_K")

        assert result == f"{mock_folder}/Q2_K/model-Q2_K.gguf"

    @patch("vllm.model_executor.model_loader.weight_utils.download_weights_from_hf")
    @patch("glob.glob", return_value=[])
    def test_download_gguf_no_files_found(self, mock_glob, mock_download):
        """Test error when no GGUF files are found."""
        mock_folder = "/tmp/mock_cache"
        mock_download.return_value = mock_folder

        with pytest.raises(ValueError, match="Downloaded GGUF files not found"):
            download_gguf("unsloth/Qwen3-0.6B-GGUF", "IQ1_S")


class TestGGUFModelLoader:
    """Test GGUFModelLoader class methods."""

    @patch("os.path.isfile", return_value=True)
    def test_prepare_weights_local_file(self, mock_isfile):
        """Test _prepare_weights with a local file."""
        load_config = LoadConfig(load_format="gguf")
        loader = GGUFModelLoader(load_config)

        # Create a simple mock ModelConfig with only the model attribute
        model_config = MagicMock()
        model_config.model = "/path/to/model.gguf"

        result = loader._prepare_weights(model_config)
        assert result == "/path/to/model.gguf"
        mock_isfile.assert_called_once_with("/path/to/model.gguf")

    @patch("vllm.model_executor.model_loader.gguf_loader.hf_hub_download")
    @patch("os.path.isfile", return_value=False)
    def test_prepare_weights_https_url(self, mock_isfile, mock_hf_download):
        """Test _prepare_weights with an HTTPS URL."""
        load_config = LoadConfig(load_format="gguf")
        loader = GGUFModelLoader(load_config)

        mock_hf_download.return_value = "/downloaded/model.gguf"

        # Create a simple mock ModelConfig with only the model attribute
        model_config = MagicMock()
        model_config.model = "https://huggingface.co/model.gguf"

        result = loader._prepare_weights(model_config)
        assert result == "/downloaded/model.gguf"
        mock_hf_download.assert_called_once_with(
            url="https://huggingface.co/model.gguf"
        )

    @patch("vllm.model_executor.model_loader.gguf_loader.hf_hub_download")
    @patch("os.path.isfile", return_value=False)
    def test_prepare_weights_repo_filename(self, mock_isfile, mock_hf_download):
        """Test _prepare_weights with repo_id/filename.gguf format."""
        load_config = LoadConfig(load_format="gguf")
        loader = GGUFModelLoader(load_config)

        mock_hf_download.return_value = "/downloaded/model.gguf"

        # Create a simple mock ModelConfig with only the model attribute
        model_config = MagicMock()
        model_config.model = "unsloth/Qwen3-0.6B-GGUF/model.gguf"

        result = loader._prepare_weights(model_config)
        assert result == "/downloaded/model.gguf"
        mock_hf_download.assert_called_once_with(
            repo_id="unsloth/Qwen3-0.6B-GGUF", filename="model.gguf"
        )

    @patch("vllm.config.model.get_hf_image_processor_config", return_value=None)
    @patch("vllm.transformers_utils.config.file_or_path_exists", return_value=True)
    @patch("vllm.config.model.get_config")
    @patch("vllm.config.model.is_gguf", return_value=True)
    @patch("vllm.model_executor.model_loader.gguf_loader.download_gguf")
    @patch("os.path.isfile", return_value=False)
    def test_prepare_weights_repo_quant_type(
        self,
        mock_isfile,
        mock_download_gguf,
        mock_is_gguf,
        mock_get_config,
        mock_file_exists,
        mock_get_image_config,
    ):
        """Test _prepare_weights with repo_id:quant_type format."""
        mock_hf_config = MagicMock()
        mock_hf_config.architectures = ["Qwen3ForCausalLM"]

        class MockTextConfig:
            max_position_embeddings = 4096
            sliding_window = None
            model_type = "qwen3"
            num_attention_heads = 32

        mock_text_config = MockTextConfig()
        mock_hf_config.get_text_config.return_value = mock_text_config
        mock_hf_config.dtype = "bfloat16"
        mock_get_config.return_value = mock_hf_config

        load_config = LoadConfig(load_format="gguf")
        loader = GGUFModelLoader(load_config)

        mock_download_gguf.return_value = "/downloaded/model-IQ1_S.gguf"

        model_config = ModelConfig(
            model="unsloth/Qwen3-0.6B-GGUF:IQ1_S", tokenizer="Qwen/Qwen3-0.6B"
        )
        result = loader._prepare_weights(model_config)
        # The actual result will be the downloaded file path from the mock
        assert result == "/downloaded/model-IQ1_S.gguf"
        mock_download_gguf.assert_called_once_with(
            "unsloth/Qwen3-0.6B-GGUF",
            "IQ1_S",
            cache_dir=None,
            revision=None,
            ignore_patterns=["original/**/*"],
        )

    @patch("vllm.config.model.get_hf_image_processor_config", return_value=None)
    @patch("vllm.config.model.get_config")
    @patch("vllm.config.model.is_gguf", return_value=False)
    @patch("vllm.transformers_utils.utils.check_gguf_file", return_value=False)
    @patch("os.path.isfile", return_value=False)
    def test_prepare_weights_invalid_format(
        self,
        mock_isfile,
        mock_check_gguf,
        mock_is_gguf,
        mock_get_config,
        mock_get_image_config,
    ):
        """Test _prepare_weights with an invalid format."""
        mock_hf_config = MagicMock()
        mock_hf_config.architectures = ["Qwen3ForCausalLM"]

        class MockTextConfig:
            max_position_embeddings = 4096
            sliding_window = None
            model_type = "qwen3"
            num_attention_heads = 32

        mock_text_config = MockTextConfig()
        mock_hf_config.get_text_config.return_value = mock_text_config
        mock_hf_config.dtype = "bfloat16"
        mock_get_config.return_value = mock_hf_config

        load_config = LoadConfig(load_format="gguf")
        loader = GGUFModelLoader(load_config)

        # Create ModelConfig with a valid repo_id to avoid validation errors,
        # then test _prepare_weights with an invalid format
        model_config = ModelConfig(model="unsloth/Qwen3-0.6B")
        # Manually set model to an invalid format after creation
        model_config.model = "invalid-format"
        with pytest.raises(ValueError, match="Unrecognised GGUF reference"):
            loader._prepare_weights(model_config)
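The tests above are fully mock-based, so no network access or GPU should be needed; they should run with plain pytest:

pytest tests/models/test_gguf_download.py -v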
@@ -1,11 +1,17 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from pathlib import Path
from unittest.mock import patch

import pytest

from vllm.transformers_utils.utils import (
    is_cloud_storage,
    is_gcs,
    is_gguf,
    is_remote_gguf,
    is_s3,
    split_remote_gguf,
)

@@ -28,3 +34,143 @@ def test_is_cloud_storage():
    assert is_cloud_storage("s3://model-path/path-to-model")
    assert not is_cloud_storage("/unix/local/path")
    assert not is_cloud_storage("nfs://nfs-fqdn.local")


class TestIsRemoteGGUF:
    """Test is_remote_gguf utility function."""

    def test_is_remote_gguf_with_colon_and_slash(self):
        """Test is_remote_gguf with repo_id:quant_type format."""
        # Valid quant types
        assert is_remote_gguf("unsloth/Qwen3-0.6B-GGUF:IQ1_S")
        assert is_remote_gguf("user/repo:Q2_K")
        assert is_remote_gguf("repo/model:Q4_K")
        assert is_remote_gguf("repo/model:Q8_0")

        # Invalid quant types should return False
        assert not is_remote_gguf("repo/model:quant")
        assert not is_remote_gguf("repo/model:INVALID")
        assert not is_remote_gguf("repo/model:invalid_type")

    def test_is_remote_gguf_without_colon(self):
        """Test is_remote_gguf without colon."""
        assert not is_remote_gguf("repo/model")
        assert not is_remote_gguf("unsloth/Qwen3-0.6B-GGUF")

    def test_is_remote_gguf_without_slash(self):
        """Test is_remote_gguf without slash."""
        assert not is_remote_gguf("model.gguf")
        # Even with a valid quant_type, no slash means not a remote GGUF
        assert not is_remote_gguf("model:IQ1_S")
        assert not is_remote_gguf("model:quant")

    def test_is_remote_gguf_local_path(self):
        """Test is_remote_gguf with local file path."""
        assert not is_remote_gguf("/path/to/model.gguf")
        assert not is_remote_gguf("./model.gguf")

    def test_is_remote_gguf_with_path_object(self):
        """Test is_remote_gguf with Path object."""
        assert is_remote_gguf(Path("unsloth/Qwen3-0.6B-GGUF:IQ1_S"))
        assert not is_remote_gguf(Path("repo/model"))

    def test_is_remote_gguf_with_http_https(self):
        """Test is_remote_gguf with HTTP/HTTPS URLs."""
        # HTTP/HTTPS URLs should return False even with a valid quant_type
        assert not is_remote_gguf("http://example.com/repo/model:IQ1_S")
        assert not is_remote_gguf("https://huggingface.co/repo/model:Q2_K")
        assert not is_remote_gguf("http://repo/model:Q4_K")
        assert not is_remote_gguf("https://repo/model:Q8_0")

    def test_is_remote_gguf_with_cloud_storage(self):
        """Test is_remote_gguf with cloud storage paths."""
        # Cloud storage paths should return False even with a valid quant_type
        assert not is_remote_gguf("s3://bucket/repo/model:IQ1_S")
        assert not is_remote_gguf("gs://bucket/repo/model:Q2_K")
        assert not is_remote_gguf("s3://repo/model:Q4_K")
        assert not is_remote_gguf("gs://repo/model:Q8_0")


class TestSplitRemoteGGUF:
    """Test split_remote_gguf utility function."""

    def test_split_remote_gguf_valid(self):
        """Test split_remote_gguf with valid repo_id:quant_type format."""
        repo_id, quant_type = split_remote_gguf("unsloth/Qwen3-0.6B-GGUF:IQ1_S")
        assert repo_id == "unsloth/Qwen3-0.6B-GGUF"
        assert quant_type == "IQ1_S"

        repo_id, quant_type = split_remote_gguf("repo/model:Q2_K")
        assert repo_id == "repo/model"
        assert quant_type == "Q2_K"

    def test_split_remote_gguf_with_path_object(self):
        """Test split_remote_gguf with Path object."""
        repo_id, quant_type = split_remote_gguf(Path("unsloth/Qwen3-0.6B-GGUF:IQ1_S"))
        assert repo_id == "unsloth/Qwen3-0.6B-GGUF"
        assert quant_type == "IQ1_S"

    def test_split_remote_gguf_invalid(self):
        """Test split_remote_gguf with invalid format."""
        # Invalid format (no colon) - is_remote_gguf returns False
        with pytest.raises(ValueError, match="Wrong GGUF model"):
            split_remote_gguf("repo/model")

        # Invalid quant type - is_remote_gguf returns False
        with pytest.raises(ValueError, match="Wrong GGUF model"):
            split_remote_gguf("repo/model:INVALID_TYPE")

        # HTTP URL - is_remote_gguf returns False
        with pytest.raises(ValueError, match="Wrong GGUF model"):
            split_remote_gguf("http://repo/model:IQ1_S")

        # Cloud storage - is_remote_gguf returns False
        with pytest.raises(ValueError, match="Wrong GGUF model"):
            split_remote_gguf("s3://bucket/repo/model:Q2_K")


class TestIsGGUF:
    """Test is_gguf utility function."""

    @patch("vllm.transformers_utils.utils.check_gguf_file", return_value=True)
    def test_is_gguf_with_local_file(self, mock_check_gguf):
        """Test is_gguf with local GGUF file."""
        assert is_gguf("/path/to/model.gguf")
        assert is_gguf("./model.gguf")

    def test_is_gguf_with_remote_gguf(self):
        """Test is_gguf with remote GGUF format."""
        # Valid remote GGUF format (repo_id:quant_type with valid quant_type)
        assert is_gguf("unsloth/Qwen3-0.6B-GGUF:IQ1_S")
        assert is_gguf("repo/model:Q2_K")
        assert is_gguf("repo/model:Q4_K")

        # Invalid quant_type should return False
        assert not is_gguf("repo/model:quant")
        assert not is_gguf("repo/model:INVALID")

    @patch("vllm.transformers_utils.utils.check_gguf_file", return_value=False)
    def test_is_gguf_false(self, mock_check_gguf):
        """Test is_gguf returns False for non-GGUF models."""
        assert not is_gguf("unsloth/Qwen3-0.6B")
        assert not is_gguf("repo/model")
        assert not is_gguf("model")

    def test_is_gguf_edge_cases(self):
        """Test is_gguf with edge cases."""
        # Empty string
        assert not is_gguf("")

        # Only colon, no slash (even with a valid quant_type)
        assert not is_gguf("model:IQ1_S")

        # Only slash, no colon
        assert not is_gguf("repo/model")

        # HTTP/HTTPS URLs
        assert not is_gguf("http://repo/model:IQ1_S")
        assert not is_gguf("https://repo/model:Q2_K")

        # Cloud storage
        assert not is_gguf("s3://bucket/repo/model:IQ1_S")
        assert not is_gguf("gs://bucket/repo/model:Q2_K")
@@ -39,7 +39,12 @@ from vllm.transformers_utils.gguf_utils import (
     maybe_patch_hf_config_from_gguf,
 )
 from vllm.transformers_utils.runai_utils import ObjectStorageModel, is_runai_obj_uri
-from vllm.transformers_utils.utils import check_gguf_file, maybe_model_redirect
+from vllm.transformers_utils.utils import (
+    is_gguf,
+    is_remote_gguf,
+    maybe_model_redirect,
+    split_remote_gguf,
+)
 from vllm.utils.import_utils import LazyLoader
 from vllm.utils.torch_utils import common_broadcastable_dtype

@@ -440,7 +445,8 @@ class ModelConfig:
         self.model = maybe_model_redirect(self.model)
         # The tokenizer is consistent with the model by default.
         if self.tokenizer is None:
-            if check_gguf_file(self.model):
+            # Check if this is a GGUF model (either local file or remote GGUF)
+            if is_gguf(self.model):
                 raise ValueError(
                     "Using a tokenizer is mandatory when loading a GGUF model. "
                     "Please specify the tokenizer path or name using the "

@@ -832,7 +838,10 @@ class ModelConfig:
             self.tokenizer = object_storage_tokenizer.dir

     def _get_encoder_config(self):
-        return get_sentence_transformer_tokenizer_config(self.model, self.revision)
+        model = self.model
+        if is_remote_gguf(model):
+            model, _ = split_remote_gguf(model)
+        return get_sentence_transformer_tokenizer_config(model, self.revision)

     def _verify_tokenizer_mode(self) -> None:
         tokenizer_mode = cast(TokenizerMode, self.tokenizer_mode.lower())
@@ -86,7 +86,7 @@ from vllm.transformers_utils.config import (
     is_interleaved,
     maybe_override_with_speculators,
 )
-from vllm.transformers_utils.utils import check_gguf_file, is_cloud_storage
+from vllm.transformers_utils.utils import is_cloud_storage, is_gguf
 from vllm.utils.argparse_utils import FlexibleArgumentParser
 from vllm.utils.mem_constants import GiB_bytes
 from vllm.utils.network_utils import get_ip

@@ -1148,8 +1148,8 @@ class EngineArgs:
         return engine_args

     def create_model_config(self) -> ModelConfig:
-        # gguf file needs a specific model loader and doesn't use hf_repo
-        if check_gguf_file(self.model):
+        # gguf file needs a specific model loader
+        if is_gguf(self.model):
             self.quantization = self.load_format = "gguf"

         # NOTE(woosuk): In V1, we use separate processes for workers (unless
@@ -18,6 +18,7 @@ from vllm.model_executor.model_loader.utils import (
     process_weights_after_loading,
 )
 from vllm.model_executor.model_loader.weight_utils import (
+    download_gguf,
     get_gguf_extra_tensor_names,
     get_gguf_weight_type_map,
     gguf_quant_weights_iterator,

@@ -43,7 +44,8 @@ class GGUFModelLoader(BaseModelLoader):
                 f"load format {load_config.load_format}"
             )

-    def _prepare_weights(self, model_name_or_path: str):
+    def _prepare_weights(self, model_config: ModelConfig):
+        model_name_or_path = model_config.model
         if os.path.isfile(model_name_or_path):
             return model_name_or_path
         # for raw HTTPS link

@@ -55,12 +57,23 @@ class GGUFModelLoader(BaseModelLoader):
         if "/" in model_name_or_path and model_name_or_path.endswith(".gguf"):
             repo_id, filename = model_name_or_path.rsplit("/", 1)
             return hf_hub_download(repo_id=repo_id, filename=filename)
-        else:
-            raise ValueError(
-                f"Unrecognised GGUF reference: {model_name_or_path} "
-                "(expected local file, raw URL, or <repo_id>/<filename>.gguf)"
-            )
+        # repo_id:quant_type
+        elif "/" in model_name_or_path and ":" in model_name_or_path:
+            repo_id, quant_type = model_name_or_path.rsplit(":", 1)
+            return download_gguf(
+                repo_id,
+                quant_type,
+                cache_dir=self.load_config.download_dir,
+                revision=model_config.revision,
+                ignore_patterns=self.load_config.ignore_patterns,
+            )
+
+        raise ValueError(
+            f"Unrecognised GGUF reference: {model_name_or_path} "
+            "(expected local file, raw URL, <repo_id>/<filename>.gguf, "
+            "or <repo_id>:<quant_type>)"
+        )

     def _get_gguf_weights_map(self, model_config: ModelConfig):
         """
         GGUF uses this naming convention for their tensors from HF checkpoint:

@@ -244,7 +257,7 @@ class GGUFModelLoader(BaseModelLoader):
         gguf_to_hf_name_map: dict[str, str],
     ) -> dict[str, str]:
         weight_type_map = get_gguf_weight_type_map(
-            model_config.model, gguf_to_hf_name_map
+            model_name_or_path, gguf_to_hf_name_map
         )
         is_multimodal = hasattr(model_config.hf_config, "vision_config")
         if is_multimodal:

@@ -290,10 +303,10 @@ class GGUFModelLoader(BaseModelLoader):
         yield from gguf_quant_weights_iterator(model_name_or_path, gguf_to_hf_name_map)

     def download_model(self, model_config: ModelConfig) -> None:
-        self._prepare_weights(model_config.model)
+        self._prepare_weights(model_config)

     def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
-        local_model_path = self._prepare_weights(model_config.model)
+        local_model_path = self._prepare_weights(model_config)
         gguf_weights_map = self._get_gguf_weights_map(model_config)
         model.load_weights(
             self._get_weights_iterator(model_config, local_model_path, gguf_weights_map)

@@ -303,7 +316,7 @@ class GGUFModelLoader(BaseModelLoader):
         self, vllm_config: VllmConfig, model_config: ModelConfig
     ) -> nn.Module:
         device_config = vllm_config.device_config
-        local_model_path = self._prepare_weights(model_config.model)
+        local_model_path = self._prepare_weights(model_config)
         gguf_weights_map = self._get_gguf_weights_map(model_config)
         # we can only know if tie word embeddings after mapping weights
         if "lm_head.weight" in get_gguf_extra_tensor_names(
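_prepare_weights now dispatches over four reference shapes. Illustrative examples (the file names below are hypothetical):

# /models/Qwen3-0.6B-IQ1_S.gguf                    -> local file, used as-is
# https://huggingface.co/.../Qwen3-0.6B-IQ1_S.gguf -> hf_hub_download(url=...)
# unsloth/Qwen3-0.6B-GGUF/Qwen3-0.6B-IQ1_S.gguf    -> hf_hub_download(repo_id=..., filename=...)
# unsloth/Qwen3-0.6B-GGUF:IQ1_S                    -> download_gguf(repo_id, quant_type, ...)

Anything else falls through to the "Unrecognised GGUF reference" error.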
@@ -369,6 +369,52 @@ def get_sparse_attention_config(
     return config


+def download_gguf(
+    repo_id: str,
+    quant_type: str,
+    cache_dir: str | None = None,
+    revision: str | None = None,
+    ignore_patterns: str | list[str] | None = None,
+) -> str:
+    # Use patterns that snapshot_download can handle directly.
+    # Patterns to match:
+    #   - *-{quant_type}.gguf       (root)
+    #   - *-{quant_type}-*.gguf     (root, sharded)
+    #   - */*-{quant_type}.gguf     (subdir)
+    #   - */*-{quant_type}-*.gguf   (subdir, sharded)
+    allow_patterns = [
+        f"*-{quant_type}.gguf",
+        f"*-{quant_type}-*.gguf",
+        f"*/*-{quant_type}.gguf",
+        f"*/*-{quant_type}-*.gguf",
+    ]
+
+    # Use download_weights_from_hf, which handles caching and downloading
+    folder = download_weights_from_hf(
+        model_name_or_path=repo_id,
+        cache_dir=cache_dir,
+        allow_patterns=allow_patterns,
+        revision=revision,
+        ignore_patterns=ignore_patterns,
+    )
+
+    # Find the downloaded file(s) in the folder
+    local_files = []
+    for pattern in allow_patterns:
+        # Convert the allow pattern to a glob pattern on the local filesystem
+        glob_pattern = os.path.join(folder, pattern)
+        local_files.extend(glob.glob(glob_pattern))
+
+    if not local_files:
+        raise ValueError(
+            f"Downloaded GGUF files not found in {folder} for quant_type {quant_type}"
+        )
+
+    # Sort to ensure consistent ordering (prefer non-sharded files)
+    local_files.sort(key=lambda x: (x.count("-"), x))
+    return local_files[0]
+
+
 def download_weights_from_hf(
     model_name_or_path: str,
     cache_dir: str | None,
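The final sort is the subtle part: the key (x.count("-"), x) ranks files with fewer hyphens first, so a non-sharded file beats its shards, and among shards the lexicographic tiebreak puts -00001-of-... first. A self-contained check (hypothetical paths):

files = [
    "/cache/model-Q2_K-00002-of-00002.gguf",
    "/cache/model-Q2_K-00001-of-00002.gguf",
    "/cache/model-Q2_K.gguf",
]
files.sort(key=lambda x: (x.count("-"), x))
assert files[0] == "/cache/model-Q2_K.gguf"       # non-sharded wins
assert files[1].endswith("-00001-of-00002.gguf")  # then the first shard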
@@ -42,7 +42,10 @@ from vllm.logger import init_logger
 from vllm.transformers_utils.config_parser_base import ConfigParserBase
 from vllm.transformers_utils.utils import (
     check_gguf_file,
+    is_gguf,
+    is_remote_gguf,
     parse_safetensors_file_metadata,
+    split_remote_gguf,
 )

 if envs.VLLM_USE_MODELSCOPE:

@@ -629,10 +632,12 @@ def maybe_override_with_speculators(
     Returns:
         Tuple of (resolved_model, resolved_tokenizer, speculative_config)
     """
-    is_gguf = check_gguf_file(model)
-    if is_gguf:
+    if check_gguf_file(model):
         kwargs["gguf_file"] = Path(model).name
         gguf_model_repo = Path(model).parent
+    elif is_remote_gguf(model):
+        repo_id, _ = split_remote_gguf(model)
+        gguf_model_repo = Path(repo_id)
     else:
         gguf_model_repo = None
     kwargs["local_files_only"] = huggingface_hub.constants.HF_HUB_OFFLINE

@@ -678,10 +683,18 @@ def get_config(
 ) -> PretrainedConfig:
     # Separate model folder from file path for GGUF models

-    is_gguf = check_gguf_file(model)
-    if is_gguf:
-        kwargs["gguf_file"] = Path(model).name
-        model = Path(model).parent
+    _is_gguf = is_gguf(model)
+    _is_remote_gguf = is_remote_gguf(model)
+    if _is_gguf:
+        if check_gguf_file(model):
+            # Local GGUF file
+            kwargs["gguf_file"] = Path(model).name
+            model = Path(model).parent
+        elif _is_remote_gguf:
+            # Remote GGUF: extract repo_id from the repo_id:quant_type format.
+            # The actual GGUF file will be downloaded later by GGUFModelLoader;
+            # here only the repo_id is needed to fetch the config.
+            model, _ = split_remote_gguf(model)

     if config_format == "auto":
         try:

@@ -689,10 +702,25 @@ def get_config(
             # Transformers implementation.
             if file_or_path_exists(model, MISTRAL_CONFIG_NAME, revision=revision):
                 config_format = "mistral"
-            elif is_gguf or file_or_path_exists(
+            elif (_is_gguf and not _is_remote_gguf) or file_or_path_exists(
                 model, HF_CONFIG_NAME, revision=revision
             ):
                 config_format = "hf"
+            # Remote GGUF models must have config.json in the repo,
+            # otherwise the config can't be parsed correctly.
+            # FIXME(Isotr0py): Support remote GGUF repos without config.json
+            elif _is_remote_gguf and not file_or_path_exists(
+                model, HF_CONFIG_NAME, revision=revision
+            ):
+                err_msg = (
+                    "Could not find config.json for remote GGUF model repo. "
+                    "To load a remote GGUF model through `<repo_id>:<quant_type>`, "
+                    "ensure the repo has a config.json (HF format) file. "
+                    "Otherwise, please specify --hf-config-path <original_repo> "
+                    "in engine args to fetch the config from the unquantized HF model."
+                )
+                logger.error(err_msg)
+                raise ValueError(err_msg)
             else:
                 raise ValueError(
                     "Could not detect config format for no config file found. "

@@ -713,9 +741,6 @@ def get_config(
                 "'config.json'.\n"
                 " - For Mistral models: ensure the presence of a "
                 "'params.json'.\n"
-                "3. For GGUF: pass the local path of the GGUF checkpoint.\n"
-                "    Loading GGUF from a remote repo directly is not yet "
-                "supported.\n"
             ).format(model=model)

             raise ValueError(error_message) from e

@@ -729,7 +754,7 @@ def get_config(
             **kwargs,
         )
         # Special architecture mapping check for GGUF models
-        if is_gguf:
+        if _is_gguf:
             if config.model_type not in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
                 raise RuntimeError(f"Can't get gguf config for {config.model_type}.")
             model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type]

@@ -889,6 +914,8 @@ def get_pooling_config(model: str, revision: str | None = "main") -> dict | None
         A dictionary containing the pooling type and whether
         normalization is used, or None if no pooling configuration is found.
     """
+    if is_remote_gguf(model):
+        model, _ = split_remote_gguf(model)

     modules_file_name = "modules.json"

@@ -1108,6 +1135,8 @@ def get_hf_image_processor_config(
     # Separate model folder from file path for GGUF models
     if check_gguf_file(model):
         model = Path(model).parent
+    elif is_remote_gguf(model):
+        model, _ = split_remote_gguf(model)
     return get_image_processor_config(
         model, token=hf_token, revision=revision, **kwargs
     )
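For a GGUF repo that ships no config.json, the error above points at the workaround: pass the unquantized source repo explicitly via the --hf-config-path engine arg named in the message. A hypothetical invocation (repo names made up):

# vllm serve some-user/SomeModel-GGUF:Q4_K \
#     --hf-config-path some-org/SomeModel \
#     --tokenizer some-org/SomeModel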
@@ -18,7 +18,7 @@ from transformers.processing_utils import ProcessorMixin
 from transformers.video_processing_utils import BaseVideoProcessor
 from typing_extensions import TypeVar

-from vllm.transformers_utils.utils import check_gguf_file, convert_model_repo_to_path
+from vllm.transformers_utils.utils import convert_model_repo_to_path, is_gguf
 from vllm.utils.func_utils import get_allowed_kwarg_only_overrides

 if TYPE_CHECKING:

@@ -236,8 +236,8 @@ def cached_processor_from_config(
     processor_cls: type[_P] | tuple[type[_P], ...] = ProcessorMixin,
     **kwargs: Any,
 ) -> _P:
-    if check_gguf_file(model_config.model):
-        assert not check_gguf_file(model_config.tokenizer), (
+    if is_gguf(model_config.model):
+        assert not is_gguf(model_config.tokenizer), (
             "For multimodal GGUF models, the original tokenizer "
             "should be used to correctly load processor."
         )

@@ -350,8 +350,8 @@ def cached_image_processor_from_config(
     model_config: "ModelConfig",
     **kwargs: Any,
 ):
-    if check_gguf_file(model_config.model):
-        assert not check_gguf_file(model_config.tokenizer), (
+    if is_gguf(model_config.model):
+        assert not is_gguf(model_config.tokenizer), (
             "For multimodal GGUF models, the original tokenizer "
             "should be used to correctly load image processor."
         )
@@ -20,7 +20,12 @@ from vllm.transformers_utils.config import (
     list_filtered_repo_files,
 )
 from vllm.transformers_utils.tokenizers import MistralTokenizer
-from vllm.transformers_utils.utils import check_gguf_file
+from vllm.transformers_utils.utils import (
+    check_gguf_file,
+    is_gguf,
+    is_remote_gguf,
+    split_remote_gguf,
+)

 if TYPE_CHECKING:
     from vllm.config import ModelConfig

@@ -180,10 +185,12 @@ def get_tokenizer(
         kwargs["truncation_side"] = "left"

     # Separate model folder from file path for GGUF models
-    is_gguf = check_gguf_file(tokenizer_name)
-    if is_gguf:
-        kwargs["gguf_file"] = Path(tokenizer_name).name
-        tokenizer_name = Path(tokenizer_name).parent
+    if is_gguf(tokenizer_name):
+        if check_gguf_file(tokenizer_name):
+            kwargs["gguf_file"] = Path(tokenizer_name).name
+            tokenizer_name = Path(tokenizer_name).parent
+        elif is_remote_gguf(tokenizer_name):
+            tokenizer_name, _ = split_remote_gguf(tokenizer_name)

     # if `tokenizer_mode` == "auto", check if tokenizer can be loaded via Mistral format
     # first to use official Mistral tokenizer if possible.
@@ -9,6 +9,8 @@ from os import PathLike
 from pathlib import Path
 from typing import Any

+from gguf import GGMLQuantizationType
+
 import vllm.envs as envs
 from vllm.logger import init_logger

@@ -46,6 +48,57 @@ def check_gguf_file(model: str | PathLike) -> bool:
     return False


+@cache
+def is_remote_gguf(model: str | Path) -> bool:
+    """Check if the model is a remote GGUF model."""
+    model = str(model)
+    return (
+        (not is_cloud_storage(model))
+        and (not model.startswith(("http://", "https://")))
+        and ("/" in model and ":" in model)
+        and is_valid_gguf_quant_type(model.rsplit(":", 1)[1])
+    )
+
+
+def is_valid_gguf_quant_type(gguf_quant_type: str) -> bool:
+    """Check if the quant type is a valid GGUF quant type."""
+    return getattr(GGMLQuantizationType, gguf_quant_type, None) is not None
+
+
+def split_remote_gguf(model: str | Path) -> tuple[str, str]:
+    """Split the model into repo_id and quant type."""
+    model = str(model)
+    if is_remote_gguf(model):
+        parts = model.rsplit(":", 1)
+        return (parts[0], parts[1])
+    raise ValueError(
+        f"Wrong GGUF model or invalid GGUF quant type: {model}.\n"
+        "- It should be in repo_id:quant_type format.\n"
+        f"- Valid GGMLQuantizationType values: {GGMLQuantizationType._member_names_}"
+    )
+
+
+def is_gguf(model: str | Path) -> bool:
+    """Check if the model is a GGUF model.
+
+    Args:
+        model: Model name, path, or Path object to check.
+
+    Returns:
+        True if the model is a GGUF model, False otherwise.
+    """
+    model = str(model)
+
+    # Check if it's a local GGUF file
+    if check_gguf_file(model):
+        return True
+
+    # Check if it's a remote GGUF model (repo_id:quant_type format)
+    return is_remote_gguf(model)
+
+
 def modelscope_list_repo_files(
     repo_id: str,
     revision: str | None = None,
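Taken together, the new helpers behave as the unit tests earlier in this commit assert; a condensed sketch:

from vllm.transformers_utils.utils import is_gguf, is_remote_gguf, split_remote_gguf

assert is_remote_gguf("unsloth/Qwen3-0.6B-GGUF:IQ1_S")
assert split_remote_gguf("unsloth/Qwen3-0.6B-GGUF:IQ1_S") == (
    "unsloth/Qwen3-0.6B-GGUF",
    "IQ1_S",
)
assert not is_remote_gguf("repo/model:INVALID")            # unknown quant type
assert not is_remote_gguf("https://repo/model:Q2_K")       # raw URLs excluded
assert not is_remote_gguf("s3://bucket/repo/model:Q2_K")   # cloud storage excluded
assert is_gguf("unsloth/Qwen3-0.6B-GGUF:IQ1_S")            # remote form counts as GGUF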