mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-16 15:37:13 +08:00
Move LoadConfig from config/__init__.py to config/load.py (#24566)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
parent
9e3c3a7df2
commit
f36355abfd
@ -6,9 +6,9 @@ import random
|
||||
import tempfile
|
||||
from unittest.mock import patch
|
||||
|
||||
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
|
||||
ModelConfig, ParallelConfig, SchedulerConfig,
|
||||
VllmConfig)
|
||||
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
|
||||
ParallelConfig, SchedulerConfig, VllmConfig)
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.lora.models import LoRAMapping
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.v1.worker.gpu_worker import Worker
|
||||
|
||||
@ -4,7 +4,8 @@
|
||||
import pytest
|
||||
from torch import nn
|
||||
|
||||
from vllm.config import LoadConfig, ModelConfig
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.model_executor.model_loader import (get_model_loader,
|
||||
register_model_loader)
|
||||
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm import SamplingParams
|
||||
from vllm.config import LoadConfig
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.model_executor.model_loader import get_model_loader
|
||||
|
||||
load_format = "runai_streamer"
|
||||
|
||||
@ -6,8 +6,9 @@ from dataclasses import MISSING, Field, asdict, dataclass, field
|
||||
import pytest
|
||||
|
||||
from vllm.compilation.backends import VllmBackend
|
||||
from vllm.config import (LoadConfig, ModelConfig, PoolerConfig, VllmConfig,
|
||||
get_field, update_config)
|
||||
from vllm.config import (ModelConfig, PoolerConfig, VllmConfig, get_field,
|
||||
update_config)
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.model_executor.layers.pooler import PoolingType
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
@ -12,9 +12,10 @@ from tests.v1.attention.utils import (BatchSpec, _Backend,
|
||||
create_common_attn_metadata,
|
||||
create_standard_kv_cache_spec,
|
||||
get_attention_backend)
|
||||
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
|
||||
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
|
||||
ParallelConfig, SchedulerConfig, SpeculativeConfig,
|
||||
VllmConfig)
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.model_executor.models.llama import LlamaForCausalLM
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.spec_decode.eagle import EagleProposer
|
||||
|
||||
@ -34,6 +34,7 @@ from vllm.config.compilation import (CompilationConfig, CompilationLevel,
|
||||
CUDAGraphMode, PassConfig)
|
||||
from vllm.config.kv_events import KVEventsConfig
|
||||
from vllm.config.kv_transfer import KVTransferConfig
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.config.parallel import (DistributedExecutorBackend, EPLBConfig,
|
||||
ParallelConfig)
|
||||
from vllm.config.scheduler import SchedulerConfig, SchedulerPolicy
|
||||
@ -64,8 +65,6 @@ if TYPE_CHECKING:
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.model_loader import LoadFormats
|
||||
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
|
||||
from vllm.v1.sample.logits_processor import LogitsProcessor
|
||||
|
||||
HfOverrides = Union[dict, Callable[[type], type]]
|
||||
@ -75,8 +74,6 @@ else:
|
||||
QuantizationConfig = Any
|
||||
QuantizationMethods = Any
|
||||
BaseModelLoader = Any
|
||||
LoadFormats = Any
|
||||
TensorizerConfig = Any
|
||||
LogitsProcessor = Any
|
||||
HfOverrides = Union[dict[str, Any], Callable[[type], type]]
|
||||
|
||||
@ -1801,90 +1798,6 @@ class ModelConfig:
|
||||
return max_model_len
|
||||
|
||||
|
||||
# NOTE(review): this copy of LoadConfig was removed from vllm/config/__init__.py
# in this commit; the canonical definition now lives in vllm/config/load.py.
@config
@dataclass
class LoadConfig:
    """Configuration for loading the model weights."""

    # Weight-file format selector; the attribute docstring below is scraped
    # by vLLM's config tooling to produce CLI help text.
    load_format: Union[str, LoadFormats] = "auto"
    """The format of the model weights to load:\n
    - "auto" will try to load the weights in the safetensors format and fall
    back to the pytorch bin format if safetensors format is not available.\n
    - "pt" will load the weights in the pytorch bin format.\n
    - "safetensors" will load the weights in the safetensors format.\n
    - "npcache" will load the weights in pytorch format and store a numpy cache
    to speed up the loading.\n
    - "dummy" will initialize the weights with random values, which is mainly
    for profiling.\n
    - "tensorizer" will use CoreWeave's tensorizer library for fast weight
    loading. See the Tensorize vLLM Model script in the Examples section for
    more information.\n
    - "runai_streamer" will load the Safetensors weights using Run:ai Model
    Streamer.\n
    - "bitsandbytes" will load the weights using bitsandbytes quantization.\n
    - "sharded_state" will load weights from pre-sharded checkpoint files,
    supporting efficient loading of tensor-parallel models.\n
    - "gguf" will load weights from GGUF format files (details specified in
    https://github.com/ggml-org/ggml/blob/master/docs/gguf.md).\n
    - "mistral" will load weights from consolidated safetensors files used by
    Mistral models.
    - Other custom values can be supported via plugins."""
    download_dir: Optional[str] = None
    """Directory to download and load the weights, default to the default
    cache directory of Hugging Face."""
    model_loader_extra_config: Union[dict, TensorizerConfig] = field(
        default_factory=dict)
    """Extra config for model loader. This will be passed to the model loader
    corresponding to the chosen load_format."""
    device: Optional[str] = None
    """Device to which model weights will be loaded, default to
    device_config.device"""
    ignore_patterns: Optional[Union[list[str], str]] = None
    """The list of patterns to ignore when loading the model. Default to
    "original/**/*" to avoid repeated loading of llama's checkpoints."""
    use_tqdm_on_load: bool = True
    """Whether to enable tqdm for showing progress bar when loading model
    weights."""
    pt_load_map_location: Union[str, dict[str, str]] = "cpu"
    """
    pt_load_map_location: the map location for loading pytorch checkpoint, to
    support loading checkpoints can only be loaded on certain devices like
    "cuda", this is equivalent to {"": "cuda"}. Another supported format is
    mapping from different devices like from GPU 1 to GPU 0:
    {"cuda:1": "cuda:0"}. Note that when passed from command line, the strings
    in dictionary needs to be double quoted for json parsing. For more details,
    see original doc for `map_location` in https://pytorch.org/docs/stable/generated/torch.load.html
    """

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        # no factors to consider.
        # this config will not affect the computation graph.
        factors: list[Any] = []
        # usedforsecurity=False: md5 is used only as a cheap fingerprint here.
        hash_str = hashlib.md5(str(factors).encode(),
                               usedforsecurity=False).hexdigest()
        return hash_str

    def __post_init__(self):
        # Normalize the format name; comparisons downstream are lowercase.
        self.load_format = self.load_format.lower()
        if self.ignore_patterns is not None and len(self.ignore_patterns) > 0:
            logger.info(
                "Ignoring the following patterns when downloading weights: %s",
                self.ignore_patterns)
        else:
            # Default skips HF repos' "original/" checkpoint duplicates.
            self.ignore_patterns = ["original/**/*"]
|
||||
|
||||
|
||||
Device = Literal["auto", "cuda", "cpu", "tpu", "xpu"]
|
||||
|
||||
|
||||
|
||||
104
vllm/config/load.py
Normal file
104
vllm/config/load.py
Normal file
@ -0,0 +1,104 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import hashlib
|
||||
from dataclasses import field
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from vllm.config.utils import config
|
||||
from vllm.logger import init_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.model_loader import LoadFormats
|
||||
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
|
||||
else:
|
||||
LoadFormats = Any
|
||||
TensorizerConfig = Any
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@config
@dataclass
class LoadConfig:
    """Configuration for loading the model weights."""

    load_format: Union[str, LoadFormats] = "auto"
    """The format of the model weights to load:\n
    - "auto" will try to load the weights in the safetensors format and fall
    back to the pytorch bin format if safetensors format is not available.\n
    - "pt" will load the weights in the pytorch bin format.\n
    - "safetensors" will load the weights in the safetensors format.\n
    - "npcache" will load the weights in pytorch format and store a numpy cache
    to speed up the loading.\n
    - "dummy" will initialize the weights with random values, which is mainly
    for profiling.\n
    - "tensorizer" will use CoreWeave's tensorizer library for fast weight
    loading. See the Tensorize vLLM Model script in the Examples section for
    more information.\n
    - "runai_streamer" will load the Safetensors weights using Run:ai Model
    Streamer.\n
    - "bitsandbytes" will load the weights using bitsandbytes quantization.\n
    - "sharded_state" will load weights from pre-sharded checkpoint files,
    supporting efficient loading of tensor-parallel models.\n
    - "gguf" will load weights from GGUF format files (details specified in
    https://github.com/ggml-org/ggml/blob/master/docs/gguf.md).\n
    - "mistral" will load weights from consolidated safetensors files used by
    Mistral models.
    - Other custom values can be supported via plugins."""
    download_dir: Optional[str] = None
    """Directory to download and load the weights, default to the default
    cache directory of Hugging Face."""
    model_loader_extra_config: Union[dict, TensorizerConfig] = field(
        default_factory=dict)
    """Extra config for model loader. This will be passed to the model loader
    corresponding to the chosen load_format."""
    device: Optional[str] = None
    """Device to which model weights will be loaded, default to
    device_config.device"""
    ignore_patterns: Optional[Union[list[str], str]] = None
    """The list of patterns to ignore when loading the model. Default to
    "original/**/*" to avoid repeated loading of llama's checkpoints."""
    use_tqdm_on_load: bool = True
    """Whether to enable tqdm for showing progress bar when loading model
    weights."""
    pt_load_map_location: Union[str, dict[str, str]] = "cpu"
    """
    pt_load_map_location: the map location for loading pytorch checkpoint, to
    support loading checkpoints can only be loaded on certain devices like
    "cuda", this is equivalent to {"": "cuda"}. Another supported format is
    mapping from different devices like from GPU 1 to GPU 0:
    {"cuda:1": "cuda:0"}. Note that when passed from command line, the strings
    in dictionary needs to be double quoted for json parsing. For more details,
    see original doc for `map_location` in https://pytorch.org/docs/stable/generated/torch.load.html
    """

    def compute_hash(self) -> str:
        """Return a fingerprint of the factors that shape the compute graph.

        WARNING: whenever a new field is added to this config, add it to
        the ``factors`` list below if it affects the computation graph.

        None of LoadConfig's fields change the structure of the computation
        graph (they only control how weights are obtained), so the factor
        list is empty and the resulting hash is constant.
        """
        # Empty on purpose: loading options do not alter the graph.
        factors: list[Any] = []
        serialized = str(factors).encode()
        # usedforsecurity=False: md5 here is a cache key, not a security hash.
        digest = hashlib.md5(serialized, usedforsecurity=False)
        return digest.hexdigest()

    def __post_init__(self):
        # Format names are matched case-insensitively downstream.
        self.load_format = self.load_format.lower()

        # Unset or empty -> fall back to the default skip pattern, which
        # avoids re-downloading llama-style "original/" checkpoint copies.
        if self.ignore_patterns is None or len(self.ignore_patterns) == 0:
            self.ignore_patterns = ["original/**/*"]
        else:
            logger.info(
                "Ignoring the following patterns when downloading weights: %s",
                self.ignore_patterns)
|
||||
@ -5,7 +5,8 @@ from typing import Literal, Optional
|
||||
|
||||
from torch import nn
|
||||
|
||||
from vllm.config import LoadConfig, ModelConfig, VllmConfig
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
||||
from vllm.model_executor.model_loader.bitsandbytes_loader import (
|
||||
@ -67,7 +68,7 @@ def register_model_loader(load_format: str):
|
||||
load_format (str): The model loader format name.
|
||||
|
||||
Examples:
|
||||
>>> from vllm.config import LoadConfig
|
||||
>>> from vllm.config.load import LoadConfig
|
||||
>>> from vllm.model_executor.model_loader import get_model_loader, register_model_loader
|
||||
>>> from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
||||
>>>
|
||||
|
||||
@ -5,7 +5,8 @@ from abc import ABC, abstractmethod
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import LoadConfig, ModelConfig, VllmConfig
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader.utils import (
|
||||
initialize_model, process_weights_after_loading, set_default_torch_dtype)
|
||||
|
||||
@ -16,7 +16,8 @@ from packaging import version
|
||||
from torch import nn
|
||||
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
|
||||
|
||||
from vllm.config import LoadConfig, ModelConfig
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size)
|
||||
# yapf: enable
|
||||
|
||||
@ -11,7 +11,8 @@ import torch
|
||||
from torch import nn
|
||||
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
|
||||
|
||||
from vllm.config import LoadConfig, ModelConfig
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
|
||||
@ -2,7 +2,8 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import LoadConfig, ModelConfig
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
initialize_dummy_weights)
|
||||
|
||||
@ -9,7 +9,8 @@ import torch.nn as nn
|
||||
from huggingface_hub import hf_hub_download
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
from vllm.config import LoadConfig, ModelConfig, VllmConfig
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
||||
from vllm.model_executor.model_loader.utils import (
|
||||
initialize_model, process_weights_after_loading, set_default_torch_dtype)
|
||||
|
||||
@ -9,7 +9,8 @@ import torch
|
||||
from torch import nn
|
||||
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
|
||||
|
||||
from vllm.config import LoadConfig, ModelConfig
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
download_safetensors_index_file_from_hf, download_weights_from_hf,
|
||||
|
||||
@ -10,7 +10,8 @@ from typing import Any, Optional
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.config import LoadConfig, ModelConfig
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
|
||||
@ -8,7 +8,8 @@ from typing import Union
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.config import LoadConfig, ModelConfig, ParallelConfig, VllmConfig
|
||||
from vllm.config import ModelConfig, ParallelConfig, VllmConfig
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
||||
from vllm.model_executor.model_loader.tensorizer import (
|
||||
|
||||
@ -23,7 +23,8 @@ from safetensors.torch import load_file, safe_open, save_file
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from vllm import envs
|
||||
from vllm.config import LoadConfig, ModelConfig
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_rank
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization import (QuantizationConfig,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user