diff --git a/tests/models/test_gguf_download.py b/tests/models/test_gguf_download.py new file mode 100644 index 000000000000..155768ac9bff --- /dev/null +++ b/tests/models/test_gguf_download.py @@ -0,0 +1,240 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from unittest.mock import MagicMock, patch + +import pytest + +from vllm.config import ModelConfig +from vllm.config.load import LoadConfig +from vllm.model_executor.model_loader.gguf_loader import GGUFModelLoader +from vllm.model_executor.model_loader.weight_utils import download_gguf + + +class TestGGUFDownload: + """Test GGUF model downloading functionality.""" + + @patch("vllm.model_executor.model_loader.weight_utils.download_weights_from_hf") + def test_download_gguf_single_file(self, mock_download): + """Test downloading a single GGUF file.""" + # Setup mock + mock_folder = "/tmp/mock_cache" + mock_download.return_value = mock_folder + + # Mock glob to return a single file + with patch("glob.glob") as mock_glob: + mock_glob.side_effect = lambda pattern, **kwargs: ( + [f"{mock_folder}/model-IQ1_S.gguf"] if "IQ1_S" in pattern else [] + ) + + result = download_gguf("unsloth/Qwen3-0.6B-GGUF", "IQ1_S") + + # Verify download_weights_from_hf was called with correct patterns + mock_download.assert_called_once_with( + model_name_or_path="unsloth/Qwen3-0.6B-GGUF", + cache_dir=None, + allow_patterns=[ + "*-IQ1_S.gguf", + "*-IQ1_S-*.gguf", + "*/*-IQ1_S.gguf", + "*/*-IQ1_S-*.gguf", + ], + revision=None, + ignore_patterns=None, + ) + + # Verify result is the file path, not folder + assert result == f"{mock_folder}/model-IQ1_S.gguf" + + @patch("vllm.model_executor.model_loader.weight_utils.download_weights_from_hf") + def test_download_gguf_sharded_files(self, mock_download): + """Test downloading sharded GGUF files.""" + mock_folder = "/tmp/mock_cache" + mock_download.return_value = mock_folder + + # Mock glob to return sharded files + with patch("glob.glob") as mock_glob: + mock_glob.side_effect = lambda pattern, **kwargs: ( + [ + f"{mock_folder}/model-Q2_K-00001-of-00002.gguf", + f"{mock_folder}/model-Q2_K-00002-of-00002.gguf", + ] + if "Q2_K" in pattern + else [] + ) + + result = download_gguf("unsloth/gpt-oss-120b-GGUF", "Q2_K") + + # Should return the first file after sorting + assert result == f"{mock_folder}/model-Q2_K-00001-of-00002.gguf" + + @patch("vllm.model_executor.model_loader.weight_utils.download_weights_from_hf") + def test_download_gguf_subdir(self, mock_download): + """Test downloading GGUF files from subdirectory.""" + mock_folder = "/tmp/mock_cache" + mock_download.return_value = mock_folder + + with patch("glob.glob") as mock_glob: + mock_glob.side_effect = lambda pattern, **kwargs: ( + [f"{mock_folder}/Q2_K/model-Q2_K.gguf"] + if "Q2_K" in pattern or "**/*.gguf" in pattern + else [] + ) + + result = download_gguf("unsloth/gpt-oss-120b-GGUF", "Q2_K") + + assert result == f"{mock_folder}/Q2_K/model-Q2_K.gguf" + + @patch("vllm.model_executor.model_loader.weight_utils.download_weights_from_hf") + @patch("glob.glob", return_value=[]) + def test_download_gguf_no_files_found(self, mock_glob, mock_download): + """Test error when no GGUF files are found.""" + mock_folder = "/tmp/mock_cache" + mock_download.return_value = mock_folder + + with pytest.raises(ValueError, match="Downloaded GGUF files not found"): + download_gguf("unsloth/Qwen3-0.6B-GGUF", "IQ1_S") + + +class TestGGUFModelLoader: + """Test GGUFModelLoader class methods.""" + + @patch("os.path.isfile", 
return_value=True) + def test_prepare_weights_local_file(self, mock_isfile): + """Test _prepare_weights with local file.""" + load_config = LoadConfig(load_format="gguf") + loader = GGUFModelLoader(load_config) + + # Create a simple mock ModelConfig with only the model attribute + model_config = MagicMock() + model_config.model = "/path/to/model.gguf" + + result = loader._prepare_weights(model_config) + assert result == "/path/to/model.gguf" + mock_isfile.assert_called_once_with("/path/to/model.gguf") + + @patch("vllm.model_executor.model_loader.gguf_loader.hf_hub_download") + @patch("os.path.isfile", return_value=False) + def test_prepare_weights_https_url(self, mock_isfile, mock_hf_download): + """Test _prepare_weights with HTTPS URL.""" + load_config = LoadConfig(load_format="gguf") + loader = GGUFModelLoader(load_config) + + mock_hf_download.return_value = "/downloaded/model.gguf" + + # Create a simple mock ModelConfig with only the model attribute + model_config = MagicMock() + model_config.model = "https://huggingface.co/model.gguf" + + result = loader._prepare_weights(model_config) + assert result == "/downloaded/model.gguf" + mock_hf_download.assert_called_once_with( + url="https://huggingface.co/model.gguf" + ) + + @patch("vllm.model_executor.model_loader.gguf_loader.hf_hub_download") + @patch("os.path.isfile", return_value=False) + def test_prepare_weights_repo_filename(self, mock_isfile, mock_hf_download): + """Test _prepare_weights with repo_id/filename.gguf format.""" + load_config = LoadConfig(load_format="gguf") + loader = GGUFModelLoader(load_config) + + mock_hf_download.return_value = "/downloaded/model.gguf" + + # Create a simple mock ModelConfig with only the model attribute + model_config = MagicMock() + model_config.model = "unsloth/Qwen3-0.6B-GGUF/model.gguf" + + result = loader._prepare_weights(model_config) + assert result == "/downloaded/model.gguf" + mock_hf_download.assert_called_once_with( + repo_id="unsloth/Qwen3-0.6B-GGUF", filename="model.gguf" + ) + + @patch("vllm.config.model.get_hf_image_processor_config", return_value=None) + @patch("vllm.transformers_utils.config.file_or_path_exists", return_value=True) + @patch("vllm.config.model.get_config") + @patch("vllm.config.model.is_gguf", return_value=True) + @patch("vllm.model_executor.model_loader.gguf_loader.download_gguf") + @patch("os.path.isfile", return_value=False) + def test_prepare_weights_repo_quant_type( + self, + mock_isfile, + mock_download_gguf, + mock_is_gguf, + mock_get_config, + mock_file_exists, + mock_get_image_config, + ): + """Test _prepare_weights with repo_id:quant_type format.""" + mock_hf_config = MagicMock() + mock_hf_config.architectures = ["Qwen3ForCausalLM"] + + class MockTextConfig: + max_position_embeddings = 4096 + sliding_window = None + model_type = "qwen3" + num_attention_heads = 32 + + mock_text_config = MockTextConfig() + mock_hf_config.get_text_config.return_value = mock_text_config + mock_hf_config.dtype = "bfloat16" + mock_get_config.return_value = mock_hf_config + + load_config = LoadConfig(load_format="gguf") + loader = GGUFModelLoader(load_config) + + mock_download_gguf.return_value = "/downloaded/model-IQ1_S.gguf" + + model_config = ModelConfig( + model="unsloth/Qwen3-0.6B-GGUF:IQ1_S", tokenizer="Qwen/Qwen3-0.6B" + ) + result = loader._prepare_weights(model_config) + # The actual result will be the downloaded file path from mock + assert result == "/downloaded/model-IQ1_S.gguf" + mock_download_gguf.assert_called_once_with( + "unsloth/Qwen3-0.6B-GGUF", + "IQ1_S", + 
cache_dir=None, + revision=None, + ignore_patterns=["original/**/*"], + ) + + @patch("vllm.config.model.get_hf_image_processor_config", return_value=None) + @patch("vllm.config.model.get_config") + @patch("vllm.config.model.is_gguf", return_value=False) + @patch("vllm.transformers_utils.utils.check_gguf_file", return_value=False) + @patch("os.path.isfile", return_value=False) + def test_prepare_weights_invalid_format( + self, + mock_isfile, + mock_check_gguf, + mock_is_gguf, + mock_get_config, + mock_get_image_config, + ): + """Test _prepare_weights with invalid format.""" + mock_hf_config = MagicMock() + mock_hf_config.architectures = ["Qwen3ForCausalLM"] + + class MockTextConfig: + max_position_embeddings = 4096 + sliding_window = None + model_type = "qwen3" + num_attention_heads = 32 + + mock_text_config = MockTextConfig() + mock_hf_config.get_text_config.return_value = mock_text_config + mock_hf_config.dtype = "bfloat16" + mock_get_config.return_value = mock_hf_config + + load_config = LoadConfig(load_format="gguf") + loader = GGUFModelLoader(load_config) + + # Create ModelConfig with a valid repo_id to avoid validation errors + # Then test _prepare_weights with invalid format + model_config = ModelConfig(model="unsloth/Qwen3-0.6B") + # Manually set model to invalid format after creation + model_config.model = "invalid-format" + with pytest.raises(ValueError, match="Unrecognised GGUF reference"): + loader._prepare_weights(model_config) diff --git a/tests/transformers_utils/test_utils.py b/tests/transformers_utils/test_utils.py index bfe1cec76c13..a8d0b9be9ec2 100644 --- a/tests/transformers_utils/test_utils.py +++ b/tests/transformers_utils/test_utils.py @@ -1,11 +1,17 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from pathlib import Path +from unittest.mock import patch +import pytest from vllm.transformers_utils.utils import ( is_cloud_storage, is_gcs, + is_gguf, + is_remote_gguf, is_s3, + split_remote_gguf, ) @@ -28,3 +34,143 @@ def test_is_cloud_storage(): assert is_cloud_storage("s3://model-path/path-to-model") assert not is_cloud_storage("/unix/local/path") assert not is_cloud_storage("nfs://nfs-fqdn.local") + + +class TestIsRemoteGGUF: + """Test is_remote_gguf utility function.""" + + def test_is_remote_gguf_with_colon_and_slash(self): + """Test is_remote_gguf with repo_id:quant_type format.""" + # Valid quant types + assert is_remote_gguf("unsloth/Qwen3-0.6B-GGUF:IQ1_S") + assert is_remote_gguf("user/repo:Q2_K") + assert is_remote_gguf("repo/model:Q4_K") + assert is_remote_gguf("repo/model:Q8_0") + + # Invalid quant types should return False + assert not is_remote_gguf("repo/model:quant") + assert not is_remote_gguf("repo/model:INVALID") + assert not is_remote_gguf("repo/model:invalid_type") + + def test_is_remote_gguf_without_colon(self): + """Test is_remote_gguf without colon.""" + assert not is_remote_gguf("repo/model") + assert not is_remote_gguf("unsloth/Qwen3-0.6B-GGUF") + + def test_is_remote_gguf_without_slash(self): + """Test is_remote_gguf without slash.""" + assert not is_remote_gguf("model.gguf") + # Even with valid quant_type, no slash means not remote GGUF + assert not is_remote_gguf("model:IQ1_S") + assert not is_remote_gguf("model:quant") + + def test_is_remote_gguf_local_path(self): + """Test is_remote_gguf with local file path.""" + assert not is_remote_gguf("/path/to/model.gguf") + assert not is_remote_gguf("./model.gguf") + + def test_is_remote_gguf_with_path_object(self): + """Test 
is_remote_gguf with Path object.""" + assert is_remote_gguf(Path("unsloth/Qwen3-0.6B-GGUF:IQ1_S")) + assert not is_remote_gguf(Path("repo/model")) + + def test_is_remote_gguf_with_http_https(self): + """Test is_remote_gguf with HTTP/HTTPS URLs.""" + # HTTP/HTTPS URLs should return False even with valid quant_type + assert not is_remote_gguf("http://example.com/repo/model:IQ1_S") + assert not is_remote_gguf("https://huggingface.co/repo/model:Q2_K") + assert not is_remote_gguf("http://repo/model:Q4_K") + assert not is_remote_gguf("https://repo/model:Q8_0") + + def test_is_remote_gguf_with_cloud_storage(self): + """Test is_remote_gguf with cloud storage paths.""" + # Cloud storage paths should return False even with valid quant_type + assert not is_remote_gguf("s3://bucket/repo/model:IQ1_S") + assert not is_remote_gguf("gs://bucket/repo/model:Q2_K") + assert not is_remote_gguf("s3://repo/model:Q4_K") + assert not is_remote_gguf("gs://repo/model:Q8_0") + + +class TestSplitRemoteGGUF: + """Test split_remote_gguf utility function.""" + + def test_split_remote_gguf_valid(self): + """Test split_remote_gguf with valid repo_id:quant_type format.""" + repo_id, quant_type = split_remote_gguf("unsloth/Qwen3-0.6B-GGUF:IQ1_S") + assert repo_id == "unsloth/Qwen3-0.6B-GGUF" + assert quant_type == "IQ1_S" + + repo_id, quant_type = split_remote_gguf("repo/model:Q2_K") + assert repo_id == "repo/model" + assert quant_type == "Q2_K" + + def test_split_remote_gguf_with_path_object(self): + """Test split_remote_gguf with Path object.""" + repo_id, quant_type = split_remote_gguf(Path("unsloth/Qwen3-0.6B-GGUF:IQ1_S")) + assert repo_id == "unsloth/Qwen3-0.6B-GGUF" + assert quant_type == "IQ1_S" + + def test_split_remote_gguf_invalid(self): + """Test split_remote_gguf with invalid format.""" + # Invalid format (no colon) - is_remote_gguf returns False + with pytest.raises(ValueError, match="Wrong GGUF model"): + split_remote_gguf("repo/model") + + # Invalid quant type - is_remote_gguf returns False + with pytest.raises(ValueError, match="Wrong GGUF model"): + split_remote_gguf("repo/model:INVALID_TYPE") + + # HTTP URL - is_remote_gguf returns False + with pytest.raises(ValueError, match="Wrong GGUF model"): + split_remote_gguf("http://repo/model:IQ1_S") + + # Cloud storage - is_remote_gguf returns False + with pytest.raises(ValueError, match="Wrong GGUF model"): + split_remote_gguf("s3://bucket/repo/model:Q2_K") + + +class TestIsGGUF: + """Test is_gguf utility function.""" + + @patch("vllm.transformers_utils.utils.check_gguf_file", return_value=True) + def test_is_gguf_with_local_file(self, mock_check_gguf): + """Test is_gguf with local GGUF file.""" + assert is_gguf("/path/to/model.gguf") + assert is_gguf("./model.gguf") + + def test_is_gguf_with_remote_gguf(self): + """Test is_gguf with remote GGUF format.""" + # Valid remote GGUF format (repo_id:quant_type with valid quant_type) + assert is_gguf("unsloth/Qwen3-0.6B-GGUF:IQ1_S") + assert is_gguf("repo/model:Q2_K") + assert is_gguf("repo/model:Q4_K") + + # Invalid quant_type should return False + assert not is_gguf("repo/model:quant") + assert not is_gguf("repo/model:INVALID") + + @patch("vllm.transformers_utils.utils.check_gguf_file", return_value=False) + def test_is_gguf_false(self, mock_check_gguf): + """Test is_gguf returns False for non-GGUF models.""" + assert not is_gguf("unsloth/Qwen3-0.6B") + assert not is_gguf("repo/model") + assert not is_gguf("model") + + def test_is_gguf_edge_cases(self): + """Test is_gguf with edge cases.""" + # Empty string + assert 
not is_gguf("") + + # Only colon, no slash (even with valid quant_type) + assert not is_gguf("model:IQ1_S") + + # Only slash, no colon + assert not is_gguf("repo/model") + + # HTTP/HTTPS URLs + assert not is_gguf("http://repo/model:IQ1_S") + assert not is_gguf("https://repo/model:Q2_K") + + # Cloud storage + assert not is_gguf("s3://bucket/repo/model:IQ1_S") + assert not is_gguf("gs://bucket/repo/model:Q2_K") diff --git a/vllm/config/model.py b/vllm/config/model.py index caa9a3440c41..14ffdec2e09d 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -39,7 +39,12 @@ from vllm.transformers_utils.gguf_utils import ( maybe_patch_hf_config_from_gguf, ) from vllm.transformers_utils.runai_utils import ObjectStorageModel, is_runai_obj_uri -from vllm.transformers_utils.utils import check_gguf_file, maybe_model_redirect +from vllm.transformers_utils.utils import ( + is_gguf, + is_remote_gguf, + maybe_model_redirect, + split_remote_gguf, +) from vllm.utils.import_utils import LazyLoader from vllm.utils.torch_utils import common_broadcastable_dtype @@ -440,7 +445,8 @@ class ModelConfig: self.model = maybe_model_redirect(self.model) # The tokenizer is consistent with the model by default. if self.tokenizer is None: - if check_gguf_file(self.model): + # Check if this is a GGUF model (either local file or remote GGUF) + if is_gguf(self.model): raise ValueError( "Using a tokenizer is mandatory when loading a GGUF model. " "Please specify the tokenizer path or name using the " @@ -832,7 +838,10 @@ class ModelConfig: self.tokenizer = object_storage_tokenizer.dir def _get_encoder_config(self): - return get_sentence_transformer_tokenizer_config(self.model, self.revision) + model = self.model + if is_remote_gguf(model): + model, _ = split_remote_gguf(model) + return get_sentence_transformer_tokenizer_config(model, self.revision) def _verify_tokenizer_mode(self) -> None: tokenizer_mode = cast(TokenizerMode, self.tokenizer_mode.lower()) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 177915715140..6d5b3392baa2 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -86,7 +86,7 @@ from vllm.transformers_utils.config import ( is_interleaved, maybe_override_with_speculators, ) -from vllm.transformers_utils.utils import check_gguf_file, is_cloud_storage +from vllm.transformers_utils.utils import is_cloud_storage, is_gguf from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.mem_constants import GiB_bytes from vllm.utils.network_utils import get_ip @@ -1148,8 +1148,8 @@ class EngineArgs: return engine_args def create_model_config(self) -> ModelConfig: - # gguf file needs a specific model loader and doesn't use hf_repo - if check_gguf_file(self.model): + # gguf file needs a specific model loader + if is_gguf(self.model): self.quantization = self.load_format = "gguf" # NOTE(woosuk): In V1, we use separate processes for workers (unless diff --git a/vllm/model_executor/model_loader/gguf_loader.py b/vllm/model_executor/model_loader/gguf_loader.py index 2416836be03c..74052f72ceab 100644 --- a/vllm/model_executor/model_loader/gguf_loader.py +++ b/vllm/model_executor/model_loader/gguf_loader.py @@ -18,6 +18,7 @@ from vllm.model_executor.model_loader.utils import ( process_weights_after_loading, ) from vllm.model_executor.model_loader.weight_utils import ( + download_gguf, get_gguf_extra_tensor_names, get_gguf_weight_type_map, gguf_quant_weights_iterator, @@ -43,7 +44,8 @@ class GGUFModelLoader(BaseModelLoader): f"load format {load_config.load_format}" ) 
- def _prepare_weights(self, model_name_or_path: str): + def _prepare_weights(self, model_config: ModelConfig): + model_name_or_path = model_config.model if os.path.isfile(model_name_or_path): return model_name_or_path # for raw HTTPS link @@ -55,12 +57,23 @@ class GGUFModelLoader(BaseModelLoader): if "/" in model_name_or_path and model_name_or_path.endswith(".gguf"): repo_id, filename = model_name_or_path.rsplit("/", 1) return hf_hub_download(repo_id=repo_id, filename=filename) - else: - raise ValueError( - f"Unrecognised GGUF reference: {model_name_or_path} " - "(expected local file, raw URL, or <repo_id>/<filename>.gguf)" + # repo_id:quant_type + elif "/" in model_name_or_path and ":" in model_name_or_path: + repo_id, quant_type = model_name_or_path.rsplit(":", 1) + return download_gguf( + repo_id, + quant_type, + cache_dir=self.load_config.download_dir, + revision=model_config.revision, + ignore_patterns=self.load_config.ignore_patterns, ) + raise ValueError( + f"Unrecognised GGUF reference: {model_name_or_path} " + "(expected local file, raw URL, <repo_id>/<filename>.gguf, " + "or <repo_id>:<quant_type>)" + ) + def _get_gguf_weights_map(self, model_config: ModelConfig): """ GGUF uses this naming convention for their tensors from HF checkpoint: @@ -244,7 +257,7 @@ class GGUFModelLoader(BaseModelLoader): gguf_to_hf_name_map: dict[str, str], ) -> dict[str, str]: weight_type_map = get_gguf_weight_type_map( - model_config.model, gguf_to_hf_name_map + model_name_or_path, gguf_to_hf_name_map ) is_multimodal = hasattr(model_config.hf_config, "vision_config") if is_multimodal: @@ -290,10 +303,10 @@ class GGUFModelLoader(BaseModelLoader): yield from gguf_quant_weights_iterator(model_name_or_path, gguf_to_hf_name_map) def download_model(self, model_config: ModelConfig) -> None: - self._prepare_weights(model_config.model) + self._prepare_weights(model_config) def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None: - local_model_path = self._prepare_weights(model_config.model) + local_model_path = self._prepare_weights(model_config) gguf_weights_map = self._get_gguf_weights_map(model_config) model.load_weights( self._get_weights_iterator(model_config, local_model_path, gguf_weights_map) ) @@ -303,7 +316,7 @@ class GGUFModelLoader(BaseModelLoader): self, vllm_config: VllmConfig, model_config: ModelConfig ) -> nn.Module: device_config = vllm_config.device_config - local_model_path = self._prepare_weights(model_config.model) + local_model_path = self._prepare_weights(model_config) gguf_weights_map = self._get_gguf_weights_map(model_config) # we can only know if tie word embeddings after mapping weights if "lm_head.weight" in get_gguf_extra_tensor_names( diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index 4572ebe2ea11..0809bdfa9d4c 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -369,6 +369,52 @@ def get_sparse_attention_config( return config +def download_gguf( + repo_id: str, + quant_type: str, + cache_dir: str | None = None, + revision: str | None = None, + ignore_patterns: str | list[str] | None = None, +) -> str: + # Use patterns that snapshot_download can handle directly + # Patterns to match: + # - *-{quant_type}.gguf (root) + # - *-{quant_type}-*.gguf (root sharded) + # - */*-{quant_type}.gguf (subdir) + # - */*-{quant_type}-*.gguf (subdir sharded) + allow_patterns = [ + f"*-{quant_type}.gguf", + f"*-{quant_type}-*.gguf", + f"*/*-{quant_type}.gguf", + f"*/*-{quant_type}-*.gguf", + ] + + # Use 
download_weights_from_hf which handles caching and downloading + folder = download_weights_from_hf( + model_name_or_path=repo_id, + cache_dir=cache_dir, + allow_patterns=allow_patterns, + revision=revision, + ignore_patterns=ignore_patterns, + ) + + # Find the downloaded file(s) in the folder + local_files = [] + for pattern in allow_patterns: + # Convert pattern to glob pattern for local filesystem + glob_pattern = os.path.join(folder, pattern) + local_files.extend(glob.glob(glob_pattern)) + + if not local_files: + raise ValueError( + f"Downloaded GGUF files not found in {folder} for quant_type {quant_type}" + ) + + # Sort to ensure consistent ordering (prefer non-sharded files) + local_files.sort(key=lambda x: (x.count("-"), x)) + return local_files[0] + + def download_weights_from_hf( model_name_or_path: str, cache_dir: str | None, diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index c1880a3fba0e..a29d92f67f5d 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -42,7 +42,10 @@ from vllm.logger import init_logger from vllm.transformers_utils.config_parser_base import ConfigParserBase from vllm.transformers_utils.utils import ( check_gguf_file, + is_gguf, + is_remote_gguf, parse_safetensors_file_metadata, + split_remote_gguf, ) if envs.VLLM_USE_MODELSCOPE: @@ -629,10 +632,12 @@ def maybe_override_with_speculators( Returns: Tuple of (resolved_model, resolved_tokenizer, speculative_config) """ - is_gguf = check_gguf_file(model) - if is_gguf: + if check_gguf_file(model): kwargs["gguf_file"] = Path(model).name gguf_model_repo = Path(model).parent + elif is_remote_gguf(model): + repo_id, _ = split_remote_gguf(model) + gguf_model_repo = Path(repo_id) else: gguf_model_repo = None kwargs["local_files_only"] = huggingface_hub.constants.HF_HUB_OFFLINE @@ -678,10 +683,18 @@ def get_config( ) -> PretrainedConfig: # Separate model folder from file path for GGUF models - is_gguf = check_gguf_file(model) - if is_gguf: - kwargs["gguf_file"] = Path(model).name - model = Path(model).parent + _is_gguf = is_gguf(model) + _is_remote_gguf = is_remote_gguf(model) + if _is_gguf: + if check_gguf_file(model): + # Local GGUF file + kwargs["gguf_file"] = Path(model).name + model = Path(model).parent + elif _is_remote_gguf: + # Remote GGUF - extract repo_id from repo_id:quant_type format + # The actual GGUF file will be downloaded later by GGUFModelLoader + # Only the repo_id is needed to resolve the config here + model, _ = split_remote_gguf(model) if config_format == "auto": try: @@ -689,10 +702,25 @@ # Transformers implementation. if file_or_path_exists(model, MISTRAL_CONFIG_NAME, revision=revision): config_format = "mistral" - elif is_gguf or file_or_path_exists( + elif (_is_gguf and not _is_remote_gguf) or file_or_path_exists( model, HF_CONFIG_NAME, revision=revision ): config_format = "hf" + # Remote GGUF models must have config.json in repo, + # otherwise the config can't be parsed correctly. + # FIXME(Isotr0py): Support remote GGUF repos without config.json + elif _is_remote_gguf and not file_or_path_exists( + model, HF_CONFIG_NAME, revision=revision + ): + err_msg = ( + "Could not find config.json for the remote GGUF model repo. " + "To load a remote GGUF model through `<repo_id>:<quant_type>`, " + "ensure the model repo has a config.json (HF format) file. " + "Otherwise, please specify --hf-config-path " + "in the engine args to fetch the config from the unquantized HF model." 
+ ) + logger.error(err_msg) + raise ValueError(err_msg) else: raise ValueError( "Could not detect config format for no config file found. " @@ -713,9 +741,6 @@ def get_config( "'config.json'.\n" " - For Mistral models: ensure the presence of a " "'params.json'.\n" - "3. For GGUF: pass the local path of the GGUF checkpoint.\n" - " Loading GGUF from a remote repo directly is not yet " - "supported.\n" ).format(model=model) raise ValueError(error_message) from e @@ -729,7 +754,7 @@ def get_config( **kwargs, ) # Special architecture mapping check for GGUF models - if is_gguf: + if _is_gguf: if config.model_type not in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: raise RuntimeError(f"Can't get gguf config for {config.model_type}.") model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type] @@ -889,6 +914,8 @@ def get_pooling_config(model: str, revision: str | None = "main") -> dict | None A dictionary containing the pooling type and whether normalization is used, or None if no pooling configuration is found. """ + if is_remote_gguf(model): + model, _ = split_remote_gguf(model) modules_file_name = "modules.json" @@ -1108,6 +1135,8 @@ def get_hf_image_processor_config( # Separate model folder from file path for GGUF models if check_gguf_file(model): model = Path(model).parent + elif is_remote_gguf(model): + model, _ = split_remote_gguf(model) return get_image_processor_config( model, token=hf_token, revision=revision, **kwargs ) diff --git a/vllm/transformers_utils/processor.py b/vllm/transformers_utils/processor.py index 8deacb5b0791..63cdf6337034 100644 --- a/vllm/transformers_utils/processor.py +++ b/vllm/transformers_utils/processor.py @@ -18,7 +18,7 @@ from transformers.processing_utils import ProcessorMixin from transformers.video_processing_utils import BaseVideoProcessor from typing_extensions import TypeVar -from vllm.transformers_utils.utils import check_gguf_file, convert_model_repo_to_path +from vllm.transformers_utils.utils import convert_model_repo_to_path, is_gguf from vllm.utils.func_utils import get_allowed_kwarg_only_overrides if TYPE_CHECKING: @@ -236,8 +236,8 @@ def cached_processor_from_config( processor_cls: type[_P] | tuple[type[_P], ...] = ProcessorMixin, **kwargs: Any, ) -> _P: - if check_gguf_file(model_config.model): - assert not check_gguf_file(model_config.tokenizer), ( + if is_gguf(model_config.model): + assert not is_gguf(model_config.tokenizer), ( "For multimodal GGUF models, the original tokenizer " "should be used to correctly load processor." ) @@ -350,8 +350,8 @@ def cached_image_processor_from_config( model_config: "ModelConfig", **kwargs: Any, ): - if check_gguf_file(model_config.model): - assert not check_gguf_file(model_config.tokenizer), ( + if is_gguf(model_config.model): + assert not is_gguf(model_config.tokenizer), ( "For multimodal GGUF models, the original tokenizer " "should be used to correctly load image processor." 
) diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index 233076741503..f0e0ba8ef424 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -20,7 +20,12 @@ from vllm.transformers_utils.config import ( list_filtered_repo_files, ) from vllm.transformers_utils.tokenizers import MistralTokenizer -from vllm.transformers_utils.utils import check_gguf_file +from vllm.transformers_utils.utils import ( + check_gguf_file, + is_gguf, + is_remote_gguf, + split_remote_gguf, +) if TYPE_CHECKING: from vllm.config import ModelConfig @@ -180,10 +185,12 @@ def get_tokenizer( kwargs["truncation_side"] = "left" # Separate model folder from file path for GGUF models - is_gguf = check_gguf_file(tokenizer_name) - if is_gguf: - kwargs["gguf_file"] = Path(tokenizer_name).name - tokenizer_name = Path(tokenizer_name).parent + if is_gguf(tokenizer_name): + if check_gguf_file(tokenizer_name): + kwargs["gguf_file"] = Path(tokenizer_name).name + tokenizer_name = Path(tokenizer_name).parent + elif is_remote_gguf(tokenizer_name): + tokenizer_name, _ = split_remote_gguf(tokenizer_name) # if `tokenizer_mode` == "auto", check if tokenizer can be loaded via Mistral format # first to use official Mistral tokenizer if possible. diff --git a/vllm/transformers_utils/utils.py b/vllm/transformers_utils/utils.py index 901a64d9d263..45a873c9f700 100644 --- a/vllm/transformers_utils/utils.py +++ b/vllm/transformers_utils/utils.py @@ -9,6 +9,8 @@ from os import PathLike from pathlib import Path from typing import Any +from gguf import GGMLQuantizationType + import vllm.envs as envs from vllm.logger import init_logger @@ -46,6 +48,57 @@ def check_gguf_file(model: str | PathLike) -> bool: return False + + +@cache +def is_remote_gguf(model: str | Path) -> bool: + """Check if the model is a remote GGUF model.""" + model = str(model) + return ( + (not is_cloud_storage(model)) + and (not model.startswith(("http://", "https://"))) + and ("/" in model and ":" in model) + and is_valid_gguf_quant_type(model.rsplit(":", 1)[1]) + ) + + +def is_valid_gguf_quant_type(gguf_quant_type: str) -> bool: + """Check if the quant type is a valid GGUF quant type.""" + return getattr(GGMLQuantizationType, gguf_quant_type, None) is not None + + +def split_remote_gguf(model: str | Path) -> tuple[str, str]: + """Split the model into repo_id and quant type.""" + model = str(model) + if is_remote_gguf(model): + parts = model.rsplit(":", 1) + return (parts[0], parts[1]) + raise ValueError( + f"Wrong GGUF model or invalid GGUF quant type: " + f"{model}.\n" + "- It should be in repo_id:quant_type format.\n" + "- Valid GGMLQuantizationType values: " + f"{GGMLQuantizationType._member_names_}" + ) + + +def is_gguf(model: str | Path) -> bool: + """Check if the model is a GGUF model. + + Args: + model: Model name, path, or Path object to check. + + Returns: + True if the model is a GGUF model, False otherwise. + """ + model = str(model) + + # Check if it's a local GGUF file + if check_gguf_file(model): + return True + + # Check if it's a remote GGUF model (repo_id:quant_type format) + return is_remote_gguf(model) + + def modelscope_list_repo_files( repo_id: str, revision: str | None = None,