[Feature]: Improve the user experience of loading GGUF models from HuggingFace by supporting repo_id:quant_type references (#29137)

Signed-off-by: Injae Ryou <injaeryou@gmail.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Injae Ryou 2025-11-25 23:28:53 +09:00 committed by GitHub
parent 0231ce836a
commit 794029f012
10 changed files with 579 additions and 36 deletions
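
In short, a GGUF model can now be referenced as <repo_id>:<quant_type>, and vLLM resolves and downloads the matching .gguf file(s) itself. A minimal sketch of the intended end-user flow, using the repo and quant type from the tests below (this runs a real download, and a tokenizer must still be passed explicitly for GGUF models):

    from vllm import LLM

    # repo_id:quant_type — the GGUF loader downloads the matching file(s)
    llm = LLM(
        model="unsloth/Qwen3-0.6B-GGUF:IQ1_S",
        tokenizer="Qwen/Qwen3-0.6B",  # GGUF models require an explicit tokenizer
    )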

@@ -0,0 +1,240 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import MagicMock, patch
import pytest
from vllm.config import ModelConfig
from vllm.config.load import LoadConfig
from vllm.model_executor.model_loader.gguf_loader import GGUFModelLoader
from vllm.model_executor.model_loader.weight_utils import download_gguf
class TestGGUFDownload:
"""Test GGUF model downloading functionality."""
@patch("vllm.model_executor.model_loader.weight_utils.download_weights_from_hf")
def test_download_gguf_single_file(self, mock_download):
"""Test downloading a single GGUF file."""
# Setup mock
mock_folder = "/tmp/mock_cache"
mock_download.return_value = mock_folder
# Mock glob to return a single file
with patch("glob.glob") as mock_glob:
mock_glob.side_effect = lambda pattern, **kwargs: (
[f"{mock_folder}/model-IQ1_S.gguf"] if "IQ1_S" in pattern else []
)
result = download_gguf("unsloth/Qwen3-0.6B-GGUF", "IQ1_S")
# Verify download_weights_from_hf was called with correct patterns
mock_download.assert_called_once_with(
model_name_or_path="unsloth/Qwen3-0.6B-GGUF",
cache_dir=None,
allow_patterns=[
"*-IQ1_S.gguf",
"*-IQ1_S-*.gguf",
"*/*-IQ1_S.gguf",
"*/*-IQ1_S-*.gguf",
],
revision=None,
ignore_patterns=None,
)
# Verify result is the file path, not folder
assert result == f"{mock_folder}/model-IQ1_S.gguf"
@patch("vllm.model_executor.model_loader.weight_utils.download_weights_from_hf")
def test_download_gguf_sharded_files(self, mock_download):
"""Test downloading sharded GGUF files."""
mock_folder = "/tmp/mock_cache"
mock_download.return_value = mock_folder
# Mock glob to return sharded files
with patch("glob.glob") as mock_glob:
mock_glob.side_effect = lambda pattern, **kwargs: (
[
f"{mock_folder}/model-Q2_K-00001-of-00002.gguf",
f"{mock_folder}/model-Q2_K-00002-of-00002.gguf",
]
if "Q2_K" in pattern
else []
)
result = download_gguf("unsloth/gpt-oss-120b-GGUF", "Q2_K")
# Should return the first file after sorting
assert result == f"{mock_folder}/model-Q2_K-00001-of-00002.gguf"
@patch("vllm.model_executor.model_loader.weight_utils.download_weights_from_hf")
def test_download_gguf_subdir(self, mock_download):
"""Test downloading GGUF files from subdirectory."""
mock_folder = "/tmp/mock_cache"
mock_download.return_value = mock_folder
with patch("glob.glob") as mock_glob:
mock_glob.side_effect = lambda pattern, **kwargs: (
[f"{mock_folder}/Q2_K/model-Q2_K.gguf"]
if "Q2_K" in pattern or "**/*.gguf" in pattern
else []
)
result = download_gguf("unsloth/gpt-oss-120b-GGUF", "Q2_K")
assert result == f"{mock_folder}/Q2_K/model-Q2_K.gguf"
@patch("vllm.model_executor.model_loader.weight_utils.download_weights_from_hf")
@patch("glob.glob", return_value=[])
def test_download_gguf_no_files_found(self, mock_glob, mock_download):
"""Test error when no GGUF files are found."""
mock_folder = "/tmp/mock_cache"
mock_download.return_value = mock_folder
with pytest.raises(ValueError, match="Downloaded GGUF files not found"):
download_gguf("unsloth/Qwen3-0.6B-GGUF", "IQ1_S")
class TestGGUFModelLoader:
"""Test GGUFModelLoader class methods."""
@patch("os.path.isfile", return_value=True)
def test_prepare_weights_local_file(self, mock_isfile):
"""Test _prepare_weights with local file."""
load_config = LoadConfig(load_format="gguf")
loader = GGUFModelLoader(load_config)
# Create a simple mock ModelConfig with only the model attribute
model_config = MagicMock()
model_config.model = "/path/to/model.gguf"
result = loader._prepare_weights(model_config)
assert result == "/path/to/model.gguf"
mock_isfile.assert_called_once_with("/path/to/model.gguf")
@patch("vllm.model_executor.model_loader.gguf_loader.hf_hub_download")
@patch("os.path.isfile", return_value=False)
def test_prepare_weights_https_url(self, mock_isfile, mock_hf_download):
"""Test _prepare_weights with HTTPS URL."""
load_config = LoadConfig(load_format="gguf")
loader = GGUFModelLoader(load_config)
mock_hf_download.return_value = "/downloaded/model.gguf"
# Create a simple mock ModelConfig with only the model attribute
model_config = MagicMock()
model_config.model = "https://huggingface.co/model.gguf"
result = loader._prepare_weights(model_config)
assert result == "/downloaded/model.gguf"
mock_hf_download.assert_called_once_with(
url="https://huggingface.co/model.gguf"
)
@patch("vllm.model_executor.model_loader.gguf_loader.hf_hub_download")
@patch("os.path.isfile", return_value=False)
def test_prepare_weights_repo_filename(self, mock_isfile, mock_hf_download):
"""Test _prepare_weights with repo_id/filename.gguf format."""
load_config = LoadConfig(load_format="gguf")
loader = GGUFModelLoader(load_config)
mock_hf_download.return_value = "/downloaded/model.gguf"
# Create a simple mock ModelConfig with only the model attribute
model_config = MagicMock()
model_config.model = "unsloth/Qwen3-0.6B-GGUF/model.gguf"
result = loader._prepare_weights(model_config)
assert result == "/downloaded/model.gguf"
mock_hf_download.assert_called_once_with(
repo_id="unsloth/Qwen3-0.6B-GGUF", filename="model.gguf"
)
@patch("vllm.config.model.get_hf_image_processor_config", return_value=None)
@patch("vllm.transformers_utils.config.file_or_path_exists", return_value=True)
@patch("vllm.config.model.get_config")
@patch("vllm.config.model.is_gguf", return_value=True)
@patch("vllm.model_executor.model_loader.gguf_loader.download_gguf")
@patch("os.path.isfile", return_value=False)
def test_prepare_weights_repo_quant_type(
self,
mock_isfile,
mock_download_gguf,
mock_is_gguf,
mock_get_config,
mock_file_exists,
mock_get_image_config,
):
"""Test _prepare_weights with repo_id:quant_type format."""
mock_hf_config = MagicMock()
mock_hf_config.architectures = ["Qwen3ForCausalLM"]
class MockTextConfig:
max_position_embeddings = 4096
sliding_window = None
model_type = "qwen3"
num_attention_heads = 32
mock_text_config = MockTextConfig()
mock_hf_config.get_text_config.return_value = mock_text_config
mock_hf_config.dtype = "bfloat16"
mock_get_config.return_value = mock_hf_config
load_config = LoadConfig(load_format="gguf")
loader = GGUFModelLoader(load_config)
mock_download_gguf.return_value = "/downloaded/model-IQ1_S.gguf"
model_config = ModelConfig(
model="unsloth/Qwen3-0.6B-GGUF:IQ1_S", tokenizer="Qwen/Qwen3-0.6B"
)
result = loader._prepare_weights(model_config)
# The actual result will be the downloaded file path from mock
assert result == "/downloaded/model-IQ1_S.gguf"
mock_download_gguf.assert_called_once_with(
"unsloth/Qwen3-0.6B-GGUF",
"IQ1_S",
cache_dir=None,
revision=None,
ignore_patterns=["original/**/*"],
)
@patch("vllm.config.model.get_hf_image_processor_config", return_value=None)
@patch("vllm.config.model.get_config")
@patch("vllm.config.model.is_gguf", return_value=False)
@patch("vllm.transformers_utils.utils.check_gguf_file", return_value=False)
@patch("os.path.isfile", return_value=False)
def test_prepare_weights_invalid_format(
self,
mock_isfile,
mock_check_gguf,
mock_is_gguf,
mock_get_config,
mock_get_image_config,
):
"""Test _prepare_weights with invalid format."""
mock_hf_config = MagicMock()
mock_hf_config.architectures = ["Qwen3ForCausalLM"]
class MockTextConfig:
max_position_embeddings = 4096
sliding_window = None
model_type = "qwen3"
num_attention_heads = 32
mock_text_config = MockTextConfig()
mock_hf_config.get_text_config.return_value = mock_text_config
mock_hf_config.dtype = "bfloat16"
mock_get_config.return_value = mock_hf_config
load_config = LoadConfig(load_format="gguf")
loader = GGUFModelLoader(load_config)
# Create ModelConfig with a valid repo_id to avoid validation errors
# Then test _prepare_weights with invalid format
model_config = ModelConfig(model="unsloth/Qwen3-0.6B")
# Manually set model to invalid format after creation
model_config.model = "invalid-format"
with pytest.raises(ValueError, match="Unrecognised GGUF reference"):
loader._prepare_weights(model_config)

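The sharded-file expectation above (shard 00001 wins) follows from the sort key download_gguf uses later in this diff: candidates are ordered by dash count, then lexicographically, so unsharded files beat shards and earlier shards beat later ones. A standalone check of that key with hypothetical paths:

    files = [
        "/tmp/m/model-Q2_K-00002-of-00002.gguf",
        "/tmp/m/model-Q2_K.gguf",
        "/tmp/m/model-Q2_K-00001-of-00002.gguf",
    ]
    files.sort(key=lambda x: (x.count("-"), x))
    assert files[0].endswith("model-Q2_K.gguf")       # unsharded file first
    assert files[1].endswith("-00001-of-00002.gguf")  # then shard 00001
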
@@ -1,11 +1,17 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from pathlib import Path
from unittest.mock import patch
import pytest
from vllm.transformers_utils.utils import (
is_cloud_storage,
is_gcs,
is_gguf,
is_remote_gguf,
is_s3,
split_remote_gguf,
)
@@ -28,3 +34,143 @@ def test_is_cloud_storage():
assert is_cloud_storage("s3://model-path/path-to-model")
assert not is_cloud_storage("/unix/local/path")
assert not is_cloud_storage("nfs://nfs-fqdn.local")
class TestIsRemoteGGUF:
"""Test is_remote_gguf utility function."""
def test_is_remote_gguf_with_colon_and_slash(self):
"""Test is_remote_gguf with repo_id:quant_type format."""
# Valid quant types
assert is_remote_gguf("unsloth/Qwen3-0.6B-GGUF:IQ1_S")
assert is_remote_gguf("user/repo:Q2_K")
assert is_remote_gguf("repo/model:Q4_K")
assert is_remote_gguf("repo/model:Q8_0")
# Invalid quant types should return False
assert not is_remote_gguf("repo/model:quant")
assert not is_remote_gguf("repo/model:INVALID")
assert not is_remote_gguf("repo/model:invalid_type")
def test_is_remote_gguf_without_colon(self):
"""Test is_remote_gguf without colon."""
assert not is_remote_gguf("repo/model")
assert not is_remote_gguf("unsloth/Qwen3-0.6B-GGUF")
def test_is_remote_gguf_without_slash(self):
"""Test is_remote_gguf without slash."""
assert not is_remote_gguf("model.gguf")
# Even with valid quant_type, no slash means not remote GGUF
assert not is_remote_gguf("model:IQ1_S")
assert not is_remote_gguf("model:quant")
def test_is_remote_gguf_local_path(self):
"""Test is_remote_gguf with local file path."""
assert not is_remote_gguf("/path/to/model.gguf")
assert not is_remote_gguf("./model.gguf")
def test_is_remote_gguf_with_path_object(self):
"""Test is_remote_gguf with Path object."""
assert is_remote_gguf(Path("unsloth/Qwen3-0.6B-GGUF:IQ1_S"))
assert not is_remote_gguf(Path("repo/model"))
def test_is_remote_gguf_with_http_https(self):
"""Test is_remote_gguf with HTTP/HTTPS URLs."""
# HTTP/HTTPS URLs should return False even with valid quant_type
assert not is_remote_gguf("http://example.com/repo/model:IQ1_S")
assert not is_remote_gguf("https://huggingface.co/repo/model:Q2_K")
assert not is_remote_gguf("http://repo/model:Q4_K")
assert not is_remote_gguf("https://repo/model:Q8_0")
def test_is_remote_gguf_with_cloud_storage(self):
"""Test is_remote_gguf with cloud storage paths."""
# Cloud storage paths should return False even with valid quant_type
assert not is_remote_gguf("s3://bucket/repo/model:IQ1_S")
assert not is_remote_gguf("gs://bucket/repo/model:Q2_K")
assert not is_remote_gguf("s3://repo/model:Q4_K")
assert not is_remote_gguf("gs://repo/model:Q8_0")
class TestSplitRemoteGGUF:
"""Test split_remote_gguf utility function."""
def test_split_remote_gguf_valid(self):
"""Test split_remote_gguf with valid repo_id:quant_type format."""
repo_id, quant_type = split_remote_gguf("unsloth/Qwen3-0.6B-GGUF:IQ1_S")
assert repo_id == "unsloth/Qwen3-0.6B-GGUF"
assert quant_type == "IQ1_S"
repo_id, quant_type = split_remote_gguf("repo/model:Q2_K")
assert repo_id == "repo/model"
assert quant_type == "Q2_K"
def test_split_remote_gguf_with_path_object(self):
"""Test split_remote_gguf with Path object."""
repo_id, quant_type = split_remote_gguf(Path("unsloth/Qwen3-0.6B-GGUF:IQ1_S"))
assert repo_id == "unsloth/Qwen3-0.6B-GGUF"
assert quant_type == "IQ1_S"
def test_split_remote_gguf_invalid(self):
"""Test split_remote_gguf with invalid format."""
# Invalid format (no colon) - is_remote_gguf returns False
with pytest.raises(ValueError, match="Wrong GGUF model"):
split_remote_gguf("repo/model")
# Invalid quant type - is_remote_gguf returns False
with pytest.raises(ValueError, match="Wrong GGUF model"):
split_remote_gguf("repo/model:INVALID_TYPE")
# HTTP URL - is_remote_gguf returns False
with pytest.raises(ValueError, match="Wrong GGUF model"):
split_remote_gguf("http://repo/model:IQ1_S")
# Cloud storage - is_remote_gguf returns False
with pytest.raises(ValueError, match="Wrong GGUF model"):
split_remote_gguf("s3://bucket/repo/model:Q2_K")
class TestIsGGUF:
"""Test is_gguf utility function."""
@patch("vllm.transformers_utils.utils.check_gguf_file", return_value=True)
def test_is_gguf_with_local_file(self, mock_check_gguf):
"""Test is_gguf with local GGUF file."""
assert is_gguf("/path/to/model.gguf")
assert is_gguf("./model.gguf")
def test_is_gguf_with_remote_gguf(self):
"""Test is_gguf with remote GGUF format."""
# Valid remote GGUF format (repo_id:quant_type with valid quant_type)
assert is_gguf("unsloth/Qwen3-0.6B-GGUF:IQ1_S")
assert is_gguf("repo/model:Q2_K")
assert is_gguf("repo/model:Q4_K")
# Invalid quant_type should return False
assert not is_gguf("repo/model:quant")
assert not is_gguf("repo/model:INVALID")
@patch("vllm.transformers_utils.utils.check_gguf_file", return_value=False)
def test_is_gguf_false(self, mock_check_gguf):
"""Test is_gguf returns False for non-GGUF models."""
assert not is_gguf("unsloth/Qwen3-0.6B")
assert not is_gguf("repo/model")
assert not is_gguf("model")
def test_is_gguf_edge_cases(self):
"""Test is_gguf with edge cases."""
# Empty string
assert not is_gguf("")
# Only colon, no slash (even with valid quant_type)
assert not is_gguf("model:IQ1_S")
# Only slash, no colon
assert not is_gguf("repo/model")
# HTTP/HTTPS URLs
assert not is_gguf("http://repo/model:IQ1_S")
assert not is_gguf("https://repo/model:Q2_K")
# Cloud storage
assert not is_gguf("s3://bucket/repo/model:IQ1_S")
assert not is_gguf("gs://bucket/repo/model:Q2_K")

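The valid quant-type strings exercised above are not hard-coded in vLLM; as the utils diff at the end shows, they are exactly the member names of the GGMLQuantizationType enum from the gguf package. A quick way to inspect them, assuming gguf is installed:

    from gguf import GGMLQuantizationType

    # Accepted quant_type values are the enum member names
    assert "IQ1_S" in GGMLQuantizationType._member_names_
    assert "Q2_K" in GGMLQuantizationType._member_names_
    assert "INVALID" not in GGMLQuantizationType._member_names_
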
@@ -39,7 +39,12 @@ from vllm.transformers_utils.gguf_utils import (
maybe_patch_hf_config_from_gguf,
)
from vllm.transformers_utils.runai_utils import ObjectStorageModel, is_runai_obj_uri
-from vllm.transformers_utils.utils import check_gguf_file, maybe_model_redirect
+from vllm.transformers_utils.utils import (
+is_gguf,
+is_remote_gguf,
+maybe_model_redirect,
+split_remote_gguf,
+)
from vllm.utils.import_utils import LazyLoader
from vllm.utils.torch_utils import common_broadcastable_dtype
@@ -440,7 +445,8 @@ class ModelConfig:
self.model = maybe_model_redirect(self.model)
# The tokenizer is consistent with the model by default.
if self.tokenizer is None:
-if check_gguf_file(self.model):
+# Check if this is a GGUF model (either local file or remote GGUF)
+if is_gguf(self.model):
raise ValueError(
"Using a tokenizer is mandatory when loading a GGUF model. "
"Please specify the tokenizer path or name using the "
@@ -832,7 +838,10 @@ class ModelConfig:
self.tokenizer = object_storage_tokenizer.dir
def _get_encoder_config(self):
-return get_sentence_transformer_tokenizer_config(self.model, self.revision)
+model = self.model
+if is_remote_gguf(model):
+model, _ = split_remote_gguf(model)
+return get_sentence_transformer_tokenizer_config(model, self.revision)
def _verify_tokenizer_mode(self) -> None:
tokenizer_mode = cast(TokenizerMode, self.tokenizer_mode.lower())

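With the two hunks above, both the mandatory-tokenizer check and the encoder-config lookup treat the remote reference like any other GGUF model. A sketch of the guard (resolving the config requires network access):

    from vllm.config import ModelConfig

    # Raises ValueError: GGUF models require an explicit tokenizer,
    # now for remote repo_id:quant_type references as well.
    ModelConfig(model="unsloth/Qwen3-0.6B-GGUF:IQ1_S")
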
@@ -86,7 +86,7 @@ from vllm.transformers_utils.config import (
is_interleaved,
maybe_override_with_speculators,
)
-from vllm.transformers_utils.utils import check_gguf_file, is_cloud_storage
+from vllm.transformers_utils.utils import is_cloud_storage, is_gguf
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.mem_constants import GiB_bytes
from vllm.utils.network_utils import get_ip
@@ -1148,8 +1148,8 @@ class EngineArgs:
return engine_args
def create_model_config(self) -> ModelConfig:
-# gguf file needs a specific model loader and doesn't use hf_repo
-if check_gguf_file(self.model):
+# gguf file needs a specific model loader
+if is_gguf(self.model):
self.quantization = self.load_format = "gguf"
# NOTE(woosuk): In V1, we use separate processes for workers (unless

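For any GGUF reference, local file or remote, create_model_config now forces both quantization and load_format to "gguf". A sketch via the EngineArgs entry point (fetching the repo metadata requires network access):

    from vllm.engine.arg_utils import EngineArgs

    args = EngineArgs(
        model="unsloth/Qwen3-0.6B-GGUF:IQ1_S",
        tokenizer="Qwen/Qwen3-0.6B",
    )
    config = args.create_model_config()
    assert args.quantization == args.load_format == "gguf"
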
@@ -18,6 +18,7 @@ from vllm.model_executor.model_loader.utils import (
process_weights_after_loading,
)
from vllm.model_executor.model_loader.weight_utils import (
+download_gguf,
get_gguf_extra_tensor_names,
get_gguf_weight_type_map,
gguf_quant_weights_iterator,
@@ -43,7 +44,8 @@ class GGUFModelLoader(BaseModelLoader):
f"load format {load_config.load_format}"
)
-def _prepare_weights(self, model_name_or_path: str):
+def _prepare_weights(self, model_config: ModelConfig):
+model_name_or_path = model_config.model
if os.path.isfile(model_name_or_path):
return model_name_or_path
# for raw HTTPS link
@@ -55,10 +57,21 @@ class GGUFModelLoader(BaseModelLoader):
if "/" in model_name_or_path and model_name_or_path.endswith(".gguf"):
repo_id, filename = model_name_or_path.rsplit("/", 1)
return hf_hub_download(repo_id=repo_id, filename=filename)
-else:
+# repo_id:quant_type
+elif "/" in model_name_or_path and ":" in model_name_or_path:
+repo_id, quant_type = model_name_or_path.rsplit(":", 1)
+return download_gguf(
+repo_id,
+quant_type,
+cache_dir=self.load_config.download_dir,
+revision=model_config.revision,
+ignore_patterns=self.load_config.ignore_patterns,
+)
raise ValueError(
f"Unrecognised GGUF reference: {model_name_or_path} "
-"(expected local file, raw URL, or <repo_id>/<filename>.gguf)"
+"(expected local file, raw URL, <repo_id>/<filename>.gguf, "
+"or <repo_id>:<quant_type>)"
)
def _get_gguf_weights_map(self, model_config: ModelConfig):
@@ -244,7 +257,7 @@
gguf_to_hf_name_map: dict[str, str],
) -> dict[str, str]:
weight_type_map = get_gguf_weight_type_map(
-model_config.model, gguf_to_hf_name_map
+model_name_or_path, gguf_to_hf_name_map
)
is_multimodal = hasattr(model_config.hf_config, "vision_config")
if is_multimodal:
@@ -290,10 +303,10 @@
yield from gguf_quant_weights_iterator(model_name_or_path, gguf_to_hf_name_map)
def download_model(self, model_config: ModelConfig) -> None:
-self._prepare_weights(model_config.model)
+self._prepare_weights(model_config)
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
-local_model_path = self._prepare_weights(model_config.model)
+local_model_path = self._prepare_weights(model_config)
gguf_weights_map = self._get_gguf_weights_map(model_config)
model.load_weights(
self._get_weights_iterator(model_config, local_model_path, gguf_weights_map)
@@ -303,7 +316,7 @@
self, vllm_config: VllmConfig, model_config: ModelConfig
) -> nn.Module:
device_config = vllm_config.device_config
-local_model_path = self._prepare_weights(model_config.model)
+local_model_path = self._prepare_weights(model_config)
gguf_weights_map = self._get_gguf_weights_map(model_config)
# we can only know if tie word embeddings after mapping weights
if "lm_head.weight" in get_gguf_extra_tensor_names(

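_prepare_weights now recognises four reference forms, tried in order. An illustrative, self-contained restatement of that dispatch (classification only; the real method returns a local file path):

    import os

    def classify_gguf_reference(ref: str) -> str:
        # Mirrors the branch order in GGUFModelLoader._prepare_weights
        if os.path.isfile(ref):
            return "local file"
        if ref.startswith(("http://", "https://")):
            return "raw URL"
        if "/" in ref and ref.endswith(".gguf"):
            return "repo_id/filename.gguf"
        if "/" in ref and ":" in ref:
            return "repo_id:quant_type"
        raise ValueError(f"Unrecognised GGUF reference: {ref}")

    assert classify_gguf_reference("unsloth/Qwen3-0.6B-GGUF:IQ1_S") == "repo_id:quant_type"
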
@@ -369,6 +369,52 @@ def get_sparse_attention_config(
return config
def download_gguf(
repo_id: str,
quant_type: str,
cache_dir: str | None = None,
revision: str | None = None,
ignore_patterns: str | list[str] | None = None,
) -> str:
# Use patterns that snapshot_download can handle directly
# Patterns to match:
# - *-{quant_type}.gguf (root)
# - *-{quant_type}-*.gguf (root sharded)
# - */*-{quant_type}.gguf (subdir)
# - */*-{quant_type}-*.gguf (subdir sharded)
allow_patterns = [
f"*-{quant_type}.gguf",
f"*-{quant_type}-*.gguf",
f"*/*-{quant_type}.gguf",
f"*/*-{quant_type}-*.gguf",
]
# Use download_weights_from_hf which handles caching and downloading
folder = download_weights_from_hf(
model_name_or_path=repo_id,
cache_dir=cache_dir,
allow_patterns=allow_patterns,
revision=revision,
ignore_patterns=ignore_patterns,
)
# Find the downloaded file(s) in the folder
local_files = []
for pattern in allow_patterns:
# Convert pattern to glob pattern for local filesystem
glob_pattern = os.path.join(folder, pattern)
local_files.extend(glob.glob(glob_pattern))
if not local_files:
raise ValueError(
f"Downloaded GGUF files not found in {folder} for quant_type {quant_type}"
)
# Sort to ensure consistent ordering (prefer non-sharded files)
local_files.sort(key=lambda x: (x.count("-"), x))
return local_files[0]
def download_weights_from_hf(
model_name_or_path: str,
cache_dir: str | None,

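The allow_patterns cover the four layouts GGUF repos commonly use: unsharded or sharded files, at the repo root or in a per-quant subdirectory. A rough fnmatch approximation of the selection (the Hub applies its own pattern semantics, so this is only indicative):

    import fnmatch

    allow_patterns = ["*-Q2_K.gguf", "*-Q2_K-*.gguf", "*/*-Q2_K.gguf", "*/*-Q2_K-*.gguf"]
    candidates = [
        "model-Q2_K.gguf",                 # root, unsharded
        "model-Q2_K-00001-of-00002.gguf",  # root, sharded
        "Q2_K/model-Q2_K.gguf",            # subdirectory
        "model-IQ1_S.gguf",                # other quant type
    ]
    selected = [
        f for f in candidates
        if any(fnmatch.fnmatch(f, p) for p in allow_patterns)
    ]
    assert selected == candidates[:3]  # the IQ1_S file is filtered out
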
@@ -42,7 +42,10 @@ from vllm.logger import init_logger
from vllm.transformers_utils.config_parser_base import ConfigParserBase
from vllm.transformers_utils.utils import (
check_gguf_file,
+is_gguf,
+is_remote_gguf,
parse_safetensors_file_metadata,
+split_remote_gguf,
)
if envs.VLLM_USE_MODELSCOPE:
@@ -629,10 +632,12 @@ def maybe_override_with_speculators(
Returns:
Tuple of (resolved_model, resolved_tokenizer, speculative_config)
"""
-is_gguf = check_gguf_file(model)
-if is_gguf:
+if check_gguf_file(model):
kwargs["gguf_file"] = Path(model).name
gguf_model_repo = Path(model).parent
+elif is_remote_gguf(model):
+repo_id, _ = split_remote_gguf(model)
+gguf_model_repo = Path(repo_id)
else:
gguf_model_repo = None
kwargs["local_files_only"] = huggingface_hub.constants.HF_HUB_OFFLINE
@@ -678,10 +683,18 @@ def get_config(
) -> PretrainedConfig:
# Separate model folder from file path for GGUF models
-is_gguf = check_gguf_file(model)
-if is_gguf:
+_is_gguf = is_gguf(model)
+_is_remote_gguf = is_remote_gguf(model)
+if _is_gguf:
+if check_gguf_file(model):
+# Local GGUF file
kwargs["gguf_file"] = Path(model).name
model = Path(model).parent
+elif _is_remote_gguf:
+# Remote GGUF - extract repo_id from repo_id:quant_type format
+# The actual GGUF file will be downloaded later by GGUFModelLoader
+# Keep model as repo_id:quant_type for download, but use repo_id for config
+model, _ = split_remote_gguf(model)
if config_format == "auto":
try:
@@ -689,10 +702,25 @@ def get_config(
# Transformers implementation.
if file_or_path_exists(model, MISTRAL_CONFIG_NAME, revision=revision):
config_format = "mistral"
-elif is_gguf or file_or_path_exists(
+elif (_is_gguf and not _is_remote_gguf) or file_or_path_exists(
model, HF_CONFIG_NAME, revision=revision
):
config_format = "hf"
+# Remote GGUF models must have config.json in the repo,
+# otherwise the config can't be parsed correctly.
+# FIXME(Isotr0py): Support remote GGUF repos without config.json
+elif _is_remote_gguf and not file_or_path_exists(
+model, HF_CONFIG_NAME, revision=revision
+):
+err_msg = (
+"Could not find config.json in the remote GGUF model repo. "
+"To load a remote GGUF model via `<repo_id>:<quant_type>`, "
+"ensure the repo contains a config.json (HF format) file. "
+"Otherwise, specify --hf-config-path <original_repo> in the "
+"engine args to fetch the config from the unquantized HF model."
+)
+logger.error(err_msg)
+raise ValueError(err_msg)
else:
raise ValueError(
"Could not detect config format for no config file found. "
@@ -713,9 +741,6 @@ def get_config(
"'config.json'.\n"
" - For Mistral models: ensure the presence of a "
"'params.json'.\n"
-"3. For GGUF: pass the local path of the GGUF checkpoint.\n"
-" Loading GGUF from a remote repo directly is not yet "
-"supported.\n"
).format(model=model)
raise ValueError(error_message) from e
@@ -729,7 +754,7 @@ def get_config(
**kwargs,
)
# Special architecture mapping check for GGUF models
-if is_gguf:
+if _is_gguf:
if config.model_type not in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
raise RuntimeError(f"Can't get gguf config for {config.model_type}.")
model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type]
@@ -889,6 +914,8 @@ def get_pooling_config(model: str, revision: str | None = "main") -> dict | None
A dictionary containing the pooling type and whether
normalization is used, or None if no pooling configuration is found.
"""
+if is_remote_gguf(model):
+model, _ = split_remote_gguf(model)
modules_file_name = "modules.json"
@@ -1108,6 +1135,8 @@ def get_hf_image_processor_config(
# Separate model folder from file path for GGUF models
if check_gguf_file(model):
model = Path(model).parent
+elif is_remote_gguf(model):
+model, _ = split_remote_gguf(model)
return get_image_processor_config(
model, token=hf_token, revision=revision, **kwargs
)

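When the GGUF repo itself lacks config.json, the new error message points at --hf-config-path. A hedged sketch of that fallback; the repo names here are hypothetical placeholders:

    from vllm import LLM

    llm = LLM(
        model="some-user/SomeModel-GGUF:Q4_K",  # hypothetical repo without config.json
        tokenizer="some-user/SomeModel",        # hypothetical original repo
        hf_config_path="some-user/SomeModel",   # fetch config from the unquantized model
    )
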
@@ -18,7 +18,7 @@ from transformers.processing_utils import ProcessorMixin
from transformers.video_processing_utils import BaseVideoProcessor
from typing_extensions import TypeVar
-from vllm.transformers_utils.utils import check_gguf_file, convert_model_repo_to_path
+from vllm.transformers_utils.utils import convert_model_repo_to_path, is_gguf
from vllm.utils.func_utils import get_allowed_kwarg_only_overrides
if TYPE_CHECKING:
@@ -236,8 +236,8 @@ def cached_processor_from_config(
processor_cls: type[_P] | tuple[type[_P], ...] = ProcessorMixin,
**kwargs: Any,
) -> _P:
-if check_gguf_file(model_config.model):
-assert not check_gguf_file(model_config.tokenizer), (
+if is_gguf(model_config.model):
+assert not is_gguf(model_config.tokenizer), (
"For multimodal GGUF models, the original tokenizer "
"should be used to correctly load processor."
)
@@ -350,8 +350,8 @@ def cached_image_processor_from_config(
model_config: "ModelConfig",
**kwargs: Any,
):
-if check_gguf_file(model_config.model):
-assert not check_gguf_file(model_config.tokenizer), (
+if is_gguf(model_config.model):
+assert not is_gguf(model_config.tokenizer), (
"For multimodal GGUF models, the original tokenizer "
"should be used to correctly load image processor."
)

@@ -20,7 +20,12 @@ from vllm.transformers_utils.config import (
list_filtered_repo_files,
)
from vllm.transformers_utils.tokenizers import MistralTokenizer
-from vllm.transformers_utils.utils import check_gguf_file
+from vllm.transformers_utils.utils import (
+check_gguf_file,
+is_gguf,
+is_remote_gguf,
+split_remote_gguf,
+)
if TYPE_CHECKING:
from vllm.config import ModelConfig
@@ -180,10 +185,12 @@ def get_tokenizer(
kwargs["truncation_side"] = "left"
# Separate model folder from file path for GGUF models
-is_gguf = check_gguf_file(tokenizer_name)
-if is_gguf:
+if is_gguf(tokenizer_name):
+if check_gguf_file(tokenizer_name):
kwargs["gguf_file"] = Path(tokenizer_name).name
tokenizer_name = Path(tokenizer_name).parent
+elif is_remote_gguf(tokenizer_name):
+tokenizer_name, _ = split_remote_gguf(tokenizer_name)
# if `tokenizer_mode` == "auto", check if tokenizer can be loaded via Mistral format
# first to use official Mistral tokenizer if possible.

@@ -9,6 +9,8 @@ from os import PathLike
from pathlib import Path
from typing import Any
+from gguf import GGMLQuantizationType
import vllm.envs as envs
from vllm.logger import init_logger
@@ -46,6 +48,57 @@ def check_gguf_file(model: str | PathLike) -> bool:
return False
@cache
def is_remote_gguf(model: str | Path) -> bool:
"""Check if the model is a remote GGUF model."""
model = str(model)
return (
(not is_cloud_storage(model))
and (not model.startswith(("http://", "https://")))
and ("/" in model and ":" in model)
and is_valid_gguf_quant_type(model.rsplit(":", 1)[1])
)
def is_valid_gguf_quant_type(gguf_quant_type: str) -> bool:
"""Check if the quant type is a valid GGUF quant type."""
return getattr(GGMLQuantizationType, gguf_quant_type, None) is not None
def split_remote_gguf(model: str | Path) -> tuple[str, str]:
"""Split the model into repo_id and quant type."""
model = str(model)
if is_remote_gguf(model):
parts = model.rsplit(":", 1)
return (parts[0], parts[1])
raise ValueError(
f"Wrong GGUF model or invalid GGUF quant type: {model}.\n"
"- It should be in repo_id:quant_type format.\n"
f"- Valid GGMLQuantizationType values: {GGMLQuantizationType._member_names_}"
)
def is_gguf(model: str | Path) -> bool:
"""Check if the model is a GGUF model.
Args:
model: Model name, path, or Path object to check.
Returns:
True if the model is a GGUF model, False otherwise.
"""
model = str(model)
# Check if it's a local GGUF file
if check_gguf_file(model):
return True
# Check if it's a remote GGUF model (repo_id:quant_type format)
return is_remote_gguf(model)
def modelscope_list_repo_files(
repo_id: str,
revision: str | None = None,
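
Taken together, the helpers give callers one predicate (is_gguf), one narrower predicate (is_remote_gguf), and one splitter for the new reference form. A usage sketch, assuming this branch of vLLM is installed:

    from pathlib import Path
    from vllm.transformers_utils.utils import is_gguf, is_remote_gguf, split_remote_gguf

    assert is_remote_gguf("unsloth/Qwen3-0.6B-GGUF:IQ1_S")
    assert not is_remote_gguf("s3://bucket/repo/model:Q2_K")  # cloud storage is excluded
    assert not is_remote_gguf("https://repo/model:Q8_0")      # URLs are excluded

    repo_id, quant = split_remote_gguf(Path("unsloth/Qwen3-0.6B-GGUF:IQ1_S"))
    assert (repo_id, quant) == ("unsloth/Qwen3-0.6B-GGUF", "IQ1_S")

    assert is_gguf("repo/model:Q2_K") and not is_gguf("repo/model")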