From f36355abfd1d87d9e89fcc881fc99f3df72eeb9b Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Wed, 10 Sep 2025 14:14:18 +0100 Subject: [PATCH] Move `LoadConfig` from `config/__init__.py` to `config/load.py` (#24566) Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- tests/lora/test_worker.py | 6 +- .../model_loader/test_registry.py | 3 +- .../test_runai_model_streamer_loader.py | 2 +- tests/test_config.py | 5 +- tests/v1/spec_decode/test_eagle.py | 3 +- vllm/config/__init__.py | 89 +-------------- vllm/config/load.py | 104 ++++++++++++++++++ vllm/model_executor/model_loader/__init__.py | 5 +- .../model_loader/base_loader.py | 3 +- .../model_loader/bitsandbytes_loader.py | 3 +- .../model_loader/default_loader.py | 3 +- .../model_loader/dummy_loader.py | 3 +- .../model_loader/gguf_loader.py | 3 +- .../model_loader/runai_streamer_loader.py | 3 +- .../model_loader/sharded_state_loader.py | 3 +- .../model_loader/tensorizer_loader.py | 3 +- .../model_loader/weight_utils.py | 3 +- 17 files changed, 137 insertions(+), 107 deletions(-) create mode 100644 vllm/config/load.py diff --git a/tests/lora/test_worker.py b/tests/lora/test_worker.py index a836ff94ba3ed..02bfe0bf914a9 100644 --- a/tests/lora/test_worker.py +++ b/tests/lora/test_worker.py @@ -6,9 +6,9 @@ import random import tempfile from unittest.mock import patch -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ParallelConfig, SchedulerConfig, - VllmConfig) +from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, + ParallelConfig, SchedulerConfig, VllmConfig) +from vllm.config.load import LoadConfig from vllm.lora.models import LoRAMapping from vllm.lora.request import LoRARequest from vllm.v1.worker.gpu_worker import Worker diff --git a/tests/model_executor/model_loader/test_registry.py b/tests/model_executor/model_loader/test_registry.py index 93a3e34835b5a..639ee6db9270f 100644 --- a/tests/model_executor/model_loader/test_registry.py +++ b/tests/model_executor/model_loader/test_registry.py @@ -4,7 +4,8 @@ import pytest from torch import nn -from vllm.config import LoadConfig, ModelConfig +from vllm.config import ModelConfig +from vllm.config.load import LoadConfig from vllm.model_executor.model_loader import (get_model_loader, register_model_loader) from vllm.model_executor.model_loader.base_loader import BaseModelLoader diff --git a/tests/runai_model_streamer_test/test_runai_model_streamer_loader.py b/tests/runai_model_streamer_test/test_runai_model_streamer_loader.py index 84c615b6b8dbc..22bdb3b44eb03 100644 --- a/tests/runai_model_streamer_test/test_runai_model_streamer_loader.py +++ b/tests/runai_model_streamer_test/test_runai_model_streamer_loader.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from vllm import SamplingParams -from vllm.config import LoadConfig +from vllm.config.load import LoadConfig from vllm.model_executor.model_loader import get_model_loader load_format = "runai_streamer" diff --git a/tests/test_config.py b/tests/test_config.py index 957771a4226bc..373fbd267539a 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -6,8 +6,9 @@ from dataclasses import MISSING, Field, asdict, dataclass, field import pytest from vllm.compilation.backends import VllmBackend -from vllm.config import (LoadConfig, ModelConfig, PoolerConfig, VllmConfig, - get_field, update_config) +from vllm.config import (ModelConfig, PoolerConfig, VllmConfig, get_field, + update_config) +from vllm.config.load import LoadConfig from vllm.model_executor.layers.pooler import PoolingType from vllm.platforms import current_platform diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py index 46e3a611c6d26..ddedc61aae296 100644 --- a/tests/v1/spec_decode/test_eagle.py +++ b/tests/v1/spec_decode/test_eagle.py @@ -12,9 +12,10 @@ from tests.v1.attention.utils import (BatchSpec, _Backend, create_common_attn_metadata, create_standard_kv_cache_spec, get_attention_backend) -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig, +from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, ParallelConfig, SchedulerConfig, SpeculativeConfig, VllmConfig) +from vllm.config.load import LoadConfig from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.platforms import current_platform from vllm.v1.spec_decode.eagle import EagleProposer diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index 4bab06a98cb21..c8d531f12a2e4 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -34,6 +34,7 @@ from vllm.config.compilation import (CompilationConfig, CompilationLevel, CUDAGraphMode, PassConfig) from vllm.config.kv_events import KVEventsConfig from vllm.config.kv_transfer import KVTransferConfig +from vllm.config.load import LoadConfig from vllm.config.parallel import (DistributedExecutorBackend, EPLBConfig, ParallelConfig) from vllm.config.scheduler import SchedulerConfig, SchedulerPolicy @@ -64,8 +65,6 @@ if TYPE_CHECKING: from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) - from vllm.model_executor.model_loader import LoadFormats - from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.v1.sample.logits_processor import LogitsProcessor HfOverrides = Union[dict, Callable[[type], type]] @@ -75,8 +74,6 @@ else: QuantizationConfig = Any QuantizationMethods = Any BaseModelLoader = Any - LoadFormats = Any - TensorizerConfig = Any LogitsProcessor = Any HfOverrides = Union[dict[str, Any], Callable[[type], type]] @@ -1801,90 +1798,6 @@ class ModelConfig: return max_model_len -@config -@dataclass -class LoadConfig: - """Configuration for loading the model weights.""" - - load_format: Union[str, LoadFormats] = "auto" - """The format of the model weights to load:\n - - "auto" will try to load the weights in the safetensors format and fall - back to the pytorch bin format if safetensors format is not available.\n - - "pt" will load the weights in the pytorch bin format.\n - - "safetensors" will load the weights in the safetensors format.\n - - "npcache" will load the weights in pytorch format and store a numpy cache - to speed up the loading.\n - - "dummy" will initialize the weights with random values, which is mainly - for profiling.\n - - "tensorizer" will use CoreWeave's tensorizer library for fast weight - loading. See the Tensorize vLLM Model script in the Examples section for - more information.\n - - "runai_streamer" will load the Safetensors weights using Run:ai Model - Streamer.\n - - "bitsandbytes" will load the weights using bitsandbytes quantization.\n - - "sharded_state" will load weights from pre-sharded checkpoint files, - supporting efficient loading of tensor-parallel models.\n - - "gguf" will load weights from GGUF format files (details specified in - https://github.com/ggml-org/ggml/blob/master/docs/gguf.md).\n - - "mistral" will load weights from consolidated safetensors files used by - Mistral models. - - Other custom values can be supported via plugins.""" - download_dir: Optional[str] = None - """Directory to download and load the weights, default to the default - cache directory of Hugging Face.""" - model_loader_extra_config: Union[dict, TensorizerConfig] = field( - default_factory=dict) - """Extra config for model loader. This will be passed to the model loader - corresponding to the chosen load_format.""" - device: Optional[str] = None - """Device to which model weights will be loaded, default to - device_config.device""" - ignore_patterns: Optional[Union[list[str], str]] = None - """The list of patterns to ignore when loading the model. Default to - "original/**/*" to avoid repeated loading of llama's checkpoints.""" - use_tqdm_on_load: bool = True - """Whether to enable tqdm for showing progress bar when loading model - weights.""" - pt_load_map_location: Union[str, dict[str, str]] = "cpu" - """ - pt_load_map_location: the map location for loading pytorch checkpoint, to - support loading checkpoints can only be loaded on certain devices like - "cuda", this is equivalent to {"": "cuda"}. Another supported format is - mapping from different devices like from GPU 1 to GPU 0: - {"cuda:1": "cuda:0"}. Note that when passed from command line, the strings - in dictionary needs to be double quoted for json parsing. For more details, - see original doc for `map_location` in https://pytorch.org/docs/stable/generated/torch.load.html - """ - - def compute_hash(self) -> str: - """ - WARNING: Whenever a new field is added to this config, - ensure that it is included in the factors list if - it affects the computation graph. - - Provide a hash that uniquely identifies all the configs - that affect the structure of the computation - graph from input ids/embeddings to the final hidden states, - excluding anything before input ids/embeddings and after - the final hidden states. - """ - # no factors to consider. - # this config will not affect the computation graph. - factors: list[Any] = [] - hash_str = hashlib.md5(str(factors).encode(), - usedforsecurity=False).hexdigest() - return hash_str - - def __post_init__(self): - self.load_format = self.load_format.lower() - if self.ignore_patterns is not None and len(self.ignore_patterns) > 0: - logger.info( - "Ignoring the following patterns when downloading weights: %s", - self.ignore_patterns) - else: - self.ignore_patterns = ["original/**/*"] - - Device = Literal["auto", "cuda", "cpu", "tpu", "xpu"] diff --git a/vllm/config/load.py b/vllm/config/load.py new file mode 100644 index 0000000000000..e4999e36b49bf --- /dev/null +++ b/vllm/config/load.py @@ -0,0 +1,104 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import hashlib +from dataclasses import field +from typing import TYPE_CHECKING, Any, Optional, Union + +from pydantic.dataclasses import dataclass + +from vllm.config.utils import config +from vllm.logger import init_logger + +if TYPE_CHECKING: + from vllm.model_executor.model_loader import LoadFormats + from vllm.model_executor.model_loader.tensorizer import TensorizerConfig +else: + LoadFormats = Any + TensorizerConfig = Any + +logger = init_logger(__name__) + + +@config +@dataclass +class LoadConfig: + """Configuration for loading the model weights.""" + + load_format: Union[str, LoadFormats] = "auto" + """The format of the model weights to load:\n + - "auto" will try to load the weights in the safetensors format and fall + back to the pytorch bin format if safetensors format is not available.\n + - "pt" will load the weights in the pytorch bin format.\n + - "safetensors" will load the weights in the safetensors format.\n + - "npcache" will load the weights in pytorch format and store a numpy cache + to speed up the loading.\n + - "dummy" will initialize the weights with random values, which is mainly + for profiling.\n + - "tensorizer" will use CoreWeave's tensorizer library for fast weight + loading. See the Tensorize vLLM Model script in the Examples section for + more information.\n + - "runai_streamer" will load the Safetensors weights using Run:ai Model + Streamer.\n + - "bitsandbytes" will load the weights using bitsandbytes quantization.\n + - "sharded_state" will load weights from pre-sharded checkpoint files, + supporting efficient loading of tensor-parallel models.\n + - "gguf" will load weights from GGUF format files (details specified in + https://github.com/ggml-org/ggml/blob/master/docs/gguf.md).\n + - "mistral" will load weights from consolidated safetensors files used by + Mistral models. + - Other custom values can be supported via plugins.""" + download_dir: Optional[str] = None + """Directory to download and load the weights, default to the default + cache directory of Hugging Face.""" + model_loader_extra_config: Union[dict, TensorizerConfig] = field( + default_factory=dict) + """Extra config for model loader. This will be passed to the model loader + corresponding to the chosen load_format.""" + device: Optional[str] = None + """Device to which model weights will be loaded, default to + device_config.device""" + ignore_patterns: Optional[Union[list[str], str]] = None + """The list of patterns to ignore when loading the model. Default to + "original/**/*" to avoid repeated loading of llama's checkpoints.""" + use_tqdm_on_load: bool = True + """Whether to enable tqdm for showing progress bar when loading model + weights.""" + pt_load_map_location: Union[str, dict[str, str]] = "cpu" + """ + pt_load_map_location: the map location for loading pytorch checkpoint, to + support loading checkpoints can only be loaded on certain devices like + "cuda", this is equivalent to {"": "cuda"}. Another supported format is + mapping from different devices like from GPU 1 to GPU 0: + {"cuda:1": "cuda:0"}. Note that when passed from command line, the strings + in dictionary needs to be double quoted for json parsing. For more details, + see original doc for `map_location` in https://pytorch.org/docs/stable/generated/torch.load.html + """ + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # this config will not affect the computation graph. + factors: list[Any] = [] + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest() + return hash_str + + def __post_init__(self): + self.load_format = self.load_format.lower() + if self.ignore_patterns is not None and len(self.ignore_patterns) > 0: + logger.info( + "Ignoring the following patterns when downloading weights: %s", + self.ignore_patterns) + else: + self.ignore_patterns = ["original/**/*"] diff --git a/vllm/model_executor/model_loader/__init__.py b/vllm/model_executor/model_loader/__init__.py index 2dada794a8f3e..138a2ff30b622 100644 --- a/vllm/model_executor/model_loader/__init__.py +++ b/vllm/model_executor/model_loader/__init__.py @@ -5,7 +5,8 @@ from typing import Literal, Optional from torch import nn -from vllm.config import LoadConfig, ModelConfig, VllmConfig +from vllm.config import ModelConfig, VllmConfig +from vllm.config.load import LoadConfig from vllm.logger import init_logger from vllm.model_executor.model_loader.base_loader import BaseModelLoader from vllm.model_executor.model_loader.bitsandbytes_loader import ( @@ -67,7 +68,7 @@ def register_model_loader(load_format: str): load_format (str): The model loader format name. Examples: - >>> from vllm.config import LoadConfig + >>> from vllm.config.load import LoadConfig >>> from vllm.model_executor.model_loader import get_model_loader, register_model_loader >>> from vllm.model_executor.model_loader.base_loader import BaseModelLoader >>> diff --git a/vllm/model_executor/model_loader/base_loader.py b/vllm/model_executor/model_loader/base_loader.py index 4cf6c7988960d..ab538a3c95620 100644 --- a/vllm/model_executor/model_loader/base_loader.py +++ b/vllm/model_executor/model_loader/base_loader.py @@ -5,7 +5,8 @@ from abc import ABC, abstractmethod import torch import torch.nn as nn -from vllm.config import LoadConfig, ModelConfig, VllmConfig +from vllm.config import ModelConfig, VllmConfig +from vllm.config.load import LoadConfig from vllm.logger import init_logger from vllm.model_executor.model_loader.utils import ( initialize_model, process_weights_after_loading, set_default_torch_dtype) diff --git a/vllm/model_executor/model_loader/bitsandbytes_loader.py b/vllm/model_executor/model_loader/bitsandbytes_loader.py index c8dd1ec0ec3c6..9c34159f9a269 100644 --- a/vllm/model_executor/model_loader/bitsandbytes_loader.py +++ b/vllm/model_executor/model_loader/bitsandbytes_loader.py @@ -16,7 +16,8 @@ from packaging import version from torch import nn from transformers.utils import SAFE_WEIGHTS_INDEX_NAME -from vllm.config import LoadConfig, ModelConfig +from vllm.config import ModelConfig +from vllm.config.load import LoadConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) # yapf: enable diff --git a/vllm/model_executor/model_loader/default_loader.py b/vllm/model_executor/model_loader/default_loader.py index 4badc31753445..f883e1e739102 100644 --- a/vllm/model_executor/model_loader/default_loader.py +++ b/vllm/model_executor/model_loader/default_loader.py @@ -11,7 +11,8 @@ import torch from torch import nn from transformers.utils import SAFE_WEIGHTS_INDEX_NAME -from vllm.config import LoadConfig, ModelConfig +from vllm.config import ModelConfig +from vllm.config.load import LoadConfig from vllm.logger import init_logger from vllm.model_executor.model_loader.base_loader import BaseModelLoader from vllm.model_executor.model_loader.weight_utils import ( diff --git a/vllm/model_executor/model_loader/dummy_loader.py b/vllm/model_executor/model_loader/dummy_loader.py index f4a7da5744e04..5b8c6268f64ef 100644 --- a/vllm/model_executor/model_loader/dummy_loader.py +++ b/vllm/model_executor/model_loader/dummy_loader.py @@ -2,7 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import torch.nn as nn -from vllm.config import LoadConfig, ModelConfig +from vllm.config import ModelConfig +from vllm.config.load import LoadConfig from vllm.model_executor.model_loader.base_loader import BaseModelLoader from vllm.model_executor.model_loader.weight_utils import ( initialize_dummy_weights) diff --git a/vllm/model_executor/model_loader/gguf_loader.py b/vllm/model_executor/model_loader/gguf_loader.py index 9877cb3b7c06e..aaee8f3f76353 100644 --- a/vllm/model_executor/model_loader/gguf_loader.py +++ b/vllm/model_executor/model_loader/gguf_loader.py @@ -9,7 +9,8 @@ import torch.nn as nn from huggingface_hub import hf_hub_download from transformers import AutoModelForCausalLM -from vllm.config import LoadConfig, ModelConfig, VllmConfig +from vllm.config import ModelConfig, VllmConfig +from vllm.config.load import LoadConfig from vllm.model_executor.model_loader.base_loader import BaseModelLoader from vllm.model_executor.model_loader.utils import ( initialize_model, process_weights_after_loading, set_default_torch_dtype) diff --git a/vllm/model_executor/model_loader/runai_streamer_loader.py b/vllm/model_executor/model_loader/runai_streamer_loader.py index 19e73efc91077..dc941401a04e0 100644 --- a/vllm/model_executor/model_loader/runai_streamer_loader.py +++ b/vllm/model_executor/model_loader/runai_streamer_loader.py @@ -9,7 +9,8 @@ import torch from torch import nn from transformers.utils import SAFE_WEIGHTS_INDEX_NAME -from vllm.config import LoadConfig, ModelConfig +from vllm.config import ModelConfig +from vllm.config.load import LoadConfig from vllm.model_executor.model_loader.base_loader import BaseModelLoader from vllm.model_executor.model_loader.weight_utils import ( download_safetensors_index_file_from_hf, download_weights_from_hf, diff --git a/vllm/model_executor/model_loader/sharded_state_loader.py b/vllm/model_executor/model_loader/sharded_state_loader.py index 3edd4ec4007e8..a85ca065d1d27 100644 --- a/vllm/model_executor/model_loader/sharded_state_loader.py +++ b/vllm/model_executor/model_loader/sharded_state_loader.py @@ -10,7 +10,8 @@ from typing import Any, Optional import torch from torch import nn -from vllm.config import LoadConfig, ModelConfig +from vllm.config import ModelConfig +from vllm.config.load import LoadConfig from vllm.logger import init_logger from vllm.model_executor.model_loader.base_loader import BaseModelLoader from vllm.model_executor.model_loader.weight_utils import ( diff --git a/vllm/model_executor/model_loader/tensorizer_loader.py b/vllm/model_executor/model_loader/tensorizer_loader.py index fa01758ab4cee..65ea49c642944 100644 --- a/vllm/model_executor/model_loader/tensorizer_loader.py +++ b/vllm/model_executor/model_loader/tensorizer_loader.py @@ -8,7 +8,8 @@ from typing import Union import torch from torch import nn -from vllm.config import LoadConfig, ModelConfig, ParallelConfig, VllmConfig +from vllm.config import ModelConfig, ParallelConfig, VllmConfig +from vllm.config.load import LoadConfig from vllm.logger import init_logger from vllm.model_executor.model_loader.base_loader import BaseModelLoader from vllm.model_executor.model_loader.tensorizer import ( diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index a4eda36148d7a..0de8dbbca9c7f 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -23,7 +23,8 @@ from safetensors.torch import load_file, safe_open, save_file from tqdm.auto import tqdm from vllm import envs -from vllm.config import LoadConfig, ModelConfig +from vllm.config import ModelConfig +from vllm.config.load import LoadConfig from vllm.distributed import get_tensor_model_parallel_rank from vllm.logger import init_logger from vllm.model_executor.layers.quantization import (QuantizationConfig,