mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-29 07:37:03 +08:00
[Core] Support model loader plugins (#21067)
Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
This commit is contained in:
parent
f0f4de8f26
commit
610852a423
@ -2,7 +2,6 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
from vllm import SamplingParams
|
from vllm import SamplingParams
|
||||||
from vllm.config import LoadFormat
|
|
||||||
|
|
||||||
test_model = "openai-community/gpt2"
|
test_model = "openai-community/gpt2"
|
||||||
|
|
||||||
@ -17,7 +16,6 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0)
|
|||||||
|
|
||||||
|
|
||||||
def test_model_loader_download_files(vllm_runner):
|
def test_model_loader_download_files(vllm_runner):
|
||||||
with vllm_runner(test_model,
|
with vllm_runner(test_model, load_format="fastsafetensors") as llm:
|
||||||
load_format=LoadFormat.FASTSAFETENSORS) as llm:
|
|
||||||
deserialized_outputs = llm.generate(prompts, sampling_params)
|
deserialized_outputs = llm.generate(prompts, sampling_params)
|
||||||
assert deserialized_outputs
|
assert deserialized_outputs
|
||||||
|
|||||||
0
tests/model_executor/model_loader/__init__.py
Normal file
0
tests/model_executor/model_loader/__init__.py
Normal file
37
tests/model_executor/model_loader/test_registry.py
Normal file
37
tests/model_executor/model_loader/test_registry.py
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from vllm.config import LoadConfig, ModelConfig
|
||||||
|
from vllm.model_executor.model_loader import (get_model_loader,
|
||||||
|
register_model_loader)
|
||||||
|
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
||||||
|
|
||||||
|
|
||||||
|
@register_model_loader("custom_load_format")
|
||||||
|
class CustomModelLoader(BaseModelLoader):
|
||||||
|
|
||||||
|
def __init__(self, load_config: LoadConfig) -> None:
|
||||||
|
super().__init__(load_config)
|
||||||
|
|
||||||
|
def download_model(self, model_config: ModelConfig) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def load_weights(self, model: nn.Module,
|
||||||
|
model_config: ModelConfig) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def test_register_model_loader():
|
||||||
|
load_config = LoadConfig(load_format="custom_load_format")
|
||||||
|
assert isinstance(get_model_loader(load_config), CustomModelLoader)
|
||||||
|
|
||||||
|
|
||||||
|
def test_invalid_model_loader():
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
|
||||||
|
@register_model_loader("invalid_load_format")
|
||||||
|
class InValidModelLoader:
|
||||||
|
pass
|
||||||
@ -2,9 +2,10 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
from vllm import SamplingParams
|
from vllm import SamplingParams
|
||||||
from vllm.config import LoadConfig, LoadFormat
|
from vllm.config import LoadConfig
|
||||||
from vllm.model_executor.model_loader import get_model_loader
|
from vllm.model_executor.model_loader import get_model_loader
|
||||||
|
|
||||||
|
load_format = "runai_streamer"
|
||||||
test_model = "openai-community/gpt2"
|
test_model = "openai-community/gpt2"
|
||||||
|
|
||||||
prompts = [
|
prompts = [
|
||||||
@ -18,7 +19,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0)
|
|||||||
|
|
||||||
|
|
||||||
def get_runai_model_loader():
|
def get_runai_model_loader():
|
||||||
load_config = LoadConfig(load_format=LoadFormat.RUNAI_STREAMER)
|
load_config = LoadConfig(load_format=load_format)
|
||||||
return get_model_loader(load_config)
|
return get_model_loader(load_config)
|
||||||
|
|
||||||
|
|
||||||
@ -28,6 +29,6 @@ def test_get_model_loader_with_runai_flag():
|
|||||||
|
|
||||||
|
|
||||||
def test_runai_model_loader_download_files(vllm_runner):
|
def test_runai_model_loader_download_files(vllm_runner):
|
||||||
with vllm_runner(test_model, load_format=LoadFormat.RUNAI_STREAMER) as llm:
|
with vllm_runner(test_model, load_format=load_format) as llm:
|
||||||
deserialized_outputs = llm.generate(prompts, sampling_params)
|
deserialized_outputs = llm.generate(prompts, sampling_params)
|
||||||
assert deserialized_outputs
|
assert deserialized_outputs
|
||||||
|
|||||||
@ -65,7 +65,7 @@ if TYPE_CHECKING:
|
|||||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||||
from vllm.model_executor.layers.quantization.base_config import (
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
QuantizationConfig)
|
QuantizationConfig)
|
||||||
from vllm.model_executor.model_loader import BaseModelLoader
|
from vllm.model_executor.model_loader import LoadFormats
|
||||||
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
|
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
|
||||||
|
|
||||||
ConfigType = type[DataclassInstance]
|
ConfigType = type[DataclassInstance]
|
||||||
@ -78,6 +78,7 @@ else:
|
|||||||
QuantizationConfig = Any
|
QuantizationConfig = Any
|
||||||
QuantizationMethods = Any
|
QuantizationMethods = Any
|
||||||
BaseModelLoader = Any
|
BaseModelLoader = Any
|
||||||
|
LoadFormats = Any
|
||||||
TensorizerConfig = Any
|
TensorizerConfig = Any
|
||||||
ConfigType = type
|
ConfigType = type
|
||||||
HfOverrides = Union[dict[str, Any], Callable[[type], type]]
|
HfOverrides = Union[dict[str, Any], Callable[[type], type]]
|
||||||
@ -1773,29 +1774,12 @@ class CacheConfig:
|
|||||||
logger.warning("Possibly too large swap space. %s", msg)
|
logger.warning("Possibly too large swap space. %s", msg)
|
||||||
|
|
||||||
|
|
||||||
class LoadFormat(str, enum.Enum):
|
|
||||||
AUTO = "auto"
|
|
||||||
PT = "pt"
|
|
||||||
SAFETENSORS = "safetensors"
|
|
||||||
NPCACHE = "npcache"
|
|
||||||
DUMMY = "dummy"
|
|
||||||
TENSORIZER = "tensorizer"
|
|
||||||
SHARDED_STATE = "sharded_state"
|
|
||||||
GGUF = "gguf"
|
|
||||||
BITSANDBYTES = "bitsandbytes"
|
|
||||||
MISTRAL = "mistral"
|
|
||||||
RUNAI_STREAMER = "runai_streamer"
|
|
||||||
RUNAI_STREAMER_SHARDED = "runai_streamer_sharded"
|
|
||||||
FASTSAFETENSORS = "fastsafetensors"
|
|
||||||
|
|
||||||
|
|
||||||
@config
|
@config
|
||||||
@dataclass
|
@dataclass
|
||||||
class LoadConfig:
|
class LoadConfig:
|
||||||
"""Configuration for loading the model weights."""
|
"""Configuration for loading the model weights."""
|
||||||
|
|
||||||
load_format: Union[str, LoadFormat,
|
load_format: Union[str, LoadFormats] = "auto"
|
||||||
"BaseModelLoader"] = LoadFormat.AUTO.value
|
|
||||||
"""The format of the model weights to load:\n
|
"""The format of the model weights to load:\n
|
||||||
- "auto" will try to load the weights in the safetensors format and fall
|
- "auto" will try to load the weights in the safetensors format and fall
|
||||||
back to the pytorch bin format if safetensors format is not available.\n
|
back to the pytorch bin format if safetensors format is not available.\n
|
||||||
@ -1816,7 +1800,8 @@ class LoadConfig:
|
|||||||
- "gguf" will load weights from GGUF format files (details specified in
|
- "gguf" will load weights from GGUF format files (details specified in
|
||||||
https://github.com/ggml-org/ggml/blob/master/docs/gguf.md).\n
|
https://github.com/ggml-org/ggml/blob/master/docs/gguf.md).\n
|
||||||
- "mistral" will load weights from consolidated safetensors files used by
|
- "mistral" will load weights from consolidated safetensors files used by
|
||||||
Mistral models."""
|
Mistral models.
|
||||||
|
- Other custom values can be supported via plugins."""
|
||||||
download_dir: Optional[str] = None
|
download_dir: Optional[str] = None
|
||||||
"""Directory to download and load the weights, default to the default
|
"""Directory to download and load the weights, default to the default
|
||||||
cache directory of Hugging Face."""
|
cache directory of Hugging Face."""
|
||||||
@ -1864,10 +1849,7 @@ class LoadConfig:
|
|||||||
return hash_str
|
return hash_str
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if isinstance(self.load_format, str):
|
self.load_format = self.load_format.lower()
|
||||||
load_format = self.load_format.lower()
|
|
||||||
self.load_format = LoadFormat(load_format)
|
|
||||||
|
|
||||||
if self.ignore_patterns is not None and len(self.ignore_patterns) > 0:
|
if self.ignore_patterns is not None and len(self.ignore_patterns) > 0:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Ignoring the following patterns when downloading weights: %s",
|
"Ignoring the following patterns when downloading weights: %s",
|
||||||
|
|||||||
@ -26,13 +26,12 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
|
|||||||
DetailedTraceModules, Device, DeviceConfig,
|
DetailedTraceModules, Device, DeviceConfig,
|
||||||
DistributedExecutorBackend, GuidedDecodingBackend,
|
DistributedExecutorBackend, GuidedDecodingBackend,
|
||||||
GuidedDecodingBackendV1, HfOverrides, KVEventsConfig,
|
GuidedDecodingBackendV1, HfOverrides, KVEventsConfig,
|
||||||
KVTransferConfig, LoadConfig, LoadFormat,
|
KVTransferConfig, LoadConfig, LogprobsMode,
|
||||||
LogprobsMode, LoRAConfig, ModelConfig, ModelDType,
|
LoRAConfig, ModelConfig, ModelDType, ModelImpl,
|
||||||
ModelImpl, MultiModalConfig, ObservabilityConfig,
|
MultiModalConfig, ObservabilityConfig, ParallelConfig,
|
||||||
ParallelConfig, PoolerConfig, PrefixCachingHashAlgo,
|
PoolerConfig, PrefixCachingHashAlgo, SchedulerConfig,
|
||||||
SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
|
SchedulerPolicy, SpeculativeConfig, TaskOption,
|
||||||
TaskOption, TokenizerMode, VllmConfig, get_attr_docs,
|
TokenizerMode, VllmConfig, get_attr_docs, get_field)
|
||||||
get_field)
|
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.platforms import CpuArchEnum, current_platform
|
from vllm.platforms import CpuArchEnum, current_platform
|
||||||
from vllm.plugins import load_general_plugins
|
from vllm.plugins import load_general_plugins
|
||||||
@ -47,10 +46,12 @@ from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser,
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.executor.executor_base import ExecutorBase
|
from vllm.executor.executor_base import ExecutorBase
|
||||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||||
|
from vllm.model_executor.model_loader import LoadFormats
|
||||||
from vllm.usage.usage_lib import UsageContext
|
from vllm.usage.usage_lib import UsageContext
|
||||||
else:
|
else:
|
||||||
ExecutorBase = Any
|
ExecutorBase = Any
|
||||||
QuantizationMethods = Any
|
QuantizationMethods = Any
|
||||||
|
LoadFormats = Any
|
||||||
UsageContext = Any
|
UsageContext = Any
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@ -276,7 +277,7 @@ class EngineArgs:
|
|||||||
trust_remote_code: bool = ModelConfig.trust_remote_code
|
trust_remote_code: bool = ModelConfig.trust_remote_code
|
||||||
allowed_local_media_path: str = ModelConfig.allowed_local_media_path
|
allowed_local_media_path: str = ModelConfig.allowed_local_media_path
|
||||||
download_dir: Optional[str] = LoadConfig.download_dir
|
download_dir: Optional[str] = LoadConfig.download_dir
|
||||||
load_format: str = LoadConfig.load_format
|
load_format: Union[str, LoadFormats] = LoadConfig.load_format
|
||||||
config_format: str = ModelConfig.config_format
|
config_format: str = ModelConfig.config_format
|
||||||
dtype: ModelDType = ModelConfig.dtype
|
dtype: ModelDType = ModelConfig.dtype
|
||||||
kv_cache_dtype: CacheDType = CacheConfig.cache_dtype
|
kv_cache_dtype: CacheDType = CacheConfig.cache_dtype
|
||||||
@ -547,9 +548,7 @@ class EngineArgs:
|
|||||||
title="LoadConfig",
|
title="LoadConfig",
|
||||||
description=LoadConfig.__doc__,
|
description=LoadConfig.__doc__,
|
||||||
)
|
)
|
||||||
load_group.add_argument("--load-format",
|
load_group.add_argument("--load-format", **load_kwargs["load_format"])
|
||||||
choices=[f.value for f in LoadFormat],
|
|
||||||
**load_kwargs["load_format"])
|
|
||||||
load_group.add_argument("--download-dir",
|
load_group.add_argument("--download-dir",
|
||||||
**load_kwargs["download_dir"])
|
**load_kwargs["download_dir"])
|
||||||
load_group.add_argument("--model-loader-extra-config",
|
load_group.add_argument("--model-loader-extra-config",
|
||||||
@ -864,10 +863,9 @@ class EngineArgs:
|
|||||||
|
|
||||||
# NOTE: This is to allow model loading from S3 in CI
|
# NOTE: This is to allow model loading from S3 in CI
|
||||||
if (not isinstance(self, AsyncEngineArgs) and envs.VLLM_CI_USE_S3
|
if (not isinstance(self, AsyncEngineArgs) and envs.VLLM_CI_USE_S3
|
||||||
and self.model in MODELS_ON_S3
|
and self.model in MODELS_ON_S3 and self.load_format == "auto"):
|
||||||
and self.load_format == LoadFormat.AUTO): # noqa: E501
|
|
||||||
self.model = f"{MODEL_WEIGHTS_S3_BUCKET}/{self.model}"
|
self.model = f"{MODEL_WEIGHTS_S3_BUCKET}/{self.model}"
|
||||||
self.load_format = LoadFormat.RUNAI_STREAMER
|
self.load_format = "runai_streamer"
|
||||||
|
|
||||||
return ModelConfig(
|
return ModelConfig(
|
||||||
model=self.model,
|
model=self.model,
|
||||||
@ -1299,7 +1297,7 @@ class EngineArgs:
|
|||||||
#############################################################
|
#############################################################
|
||||||
# Unsupported Feature Flags on V1.
|
# Unsupported Feature Flags on V1.
|
||||||
|
|
||||||
if self.load_format == LoadFormat.SHARDED_STATE.value:
|
if self.load_format == "sharded_state":
|
||||||
_raise_or_fallback(
|
_raise_or_fallback(
|
||||||
feature_name=f"--load_format {self.load_format}",
|
feature_name=f"--load_format {self.load_format}",
|
||||||
recommend_to_remove=False)
|
recommend_to_remove=False)
|
||||||
|
|||||||
@ -1,11 +1,12 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Literal, Optional
|
||||||
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from vllm.config import LoadConfig, LoadFormat, ModelConfig, VllmConfig
|
from vllm.config import LoadConfig, ModelConfig, VllmConfig
|
||||||
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
||||||
from vllm.model_executor.model_loader.bitsandbytes_loader import (
|
from vllm.model_executor.model_loader.bitsandbytes_loader import (
|
||||||
BitsAndBytesModelLoader)
|
BitsAndBytesModelLoader)
|
||||||
@ -20,34 +21,92 @@ from vllm.model_executor.model_loader.tensorizer_loader import TensorizerLoader
|
|||||||
from vllm.model_executor.model_loader.utils import (
|
from vllm.model_executor.model_loader.utils import (
|
||||||
get_architecture_class_name, get_model_architecture, get_model_cls)
|
get_architecture_class_name, get_model_architecture, get_model_cls)
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
# Reminder: Please update docstring in `LoadConfig`
|
||||||
|
# if a new load format is added here
|
||||||
|
LoadFormats = Literal[
|
||||||
|
"auto",
|
||||||
|
"bitsandbytes",
|
||||||
|
"dummy",
|
||||||
|
"fastsafetensors",
|
||||||
|
"gguf",
|
||||||
|
"mistral",
|
||||||
|
"npcache",
|
||||||
|
"pt",
|
||||||
|
"runai_streamer",
|
||||||
|
"runai_streamer_sharded",
|
||||||
|
"safetensors",
|
||||||
|
"sharded_state",
|
||||||
|
"tensorizer",
|
||||||
|
]
|
||||||
|
_LOAD_FORMAT_TO_MODEL_LOADER: dict[str, type[BaseModelLoader]] = {
|
||||||
|
"auto": DefaultModelLoader,
|
||||||
|
"bitsandbytes": BitsAndBytesModelLoader,
|
||||||
|
"dummy": DummyModelLoader,
|
||||||
|
"fastsafetensors": DefaultModelLoader,
|
||||||
|
"gguf": GGUFModelLoader,
|
||||||
|
"mistral": DefaultModelLoader,
|
||||||
|
"npcache": DefaultModelLoader,
|
||||||
|
"pt": DefaultModelLoader,
|
||||||
|
"runai_streamer": RunaiModelStreamerLoader,
|
||||||
|
"runai_streamer_sharded": ShardedStateLoader,
|
||||||
|
"safetensors": DefaultModelLoader,
|
||||||
|
"sharded_state": ShardedStateLoader,
|
||||||
|
"tensorizer": TensorizerLoader,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def register_model_loader(load_format: str):
|
||||||
|
"""Register a customized vllm model loader.
|
||||||
|
|
||||||
|
When a load format is not supported by vllm, you can register a customized
|
||||||
|
model loader to support it.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
load_format (str): The model loader format name.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> from vllm.config import LoadConfig
|
||||||
|
>>> from vllm.model_executor.model_loader import get_model_loader, register_model_loader
|
||||||
|
>>> from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
||||||
|
>>>
|
||||||
|
>>> @register_model_loader("my_loader")
|
||||||
|
... class MyModelLoader(BaseModelLoader):
|
||||||
|
... def download_model(self):
|
||||||
|
... pass
|
||||||
|
...
|
||||||
|
... def load_weights(self):
|
||||||
|
... pass
|
||||||
|
>>>
|
||||||
|
>>> load_config = LoadConfig(load_format="my_loader")
|
||||||
|
>>> type(get_model_loader(load_config))
|
||||||
|
<class 'MyModelLoader'>
|
||||||
|
""" # noqa: E501
|
||||||
|
|
||||||
|
def _wrapper(model_loader_cls):
|
||||||
|
if load_format in _LOAD_FORMAT_TO_MODEL_LOADER:
|
||||||
|
logger.warning(
|
||||||
|
"Load format `%s` is already registered, and will be "
|
||||||
|
"overwritten by the new loader class `%s`.", load_format,
|
||||||
|
model_loader_cls)
|
||||||
|
if not issubclass(model_loader_cls, BaseModelLoader):
|
||||||
|
raise ValueError("The model loader must be a subclass of "
|
||||||
|
"`BaseModelLoader`.")
|
||||||
|
_LOAD_FORMAT_TO_MODEL_LOADER[load_format] = model_loader_cls
|
||||||
|
logger.info("Registered model loader `%s` with load format `%s`",
|
||||||
|
model_loader_cls, load_format)
|
||||||
|
return model_loader_cls
|
||||||
|
|
||||||
|
return _wrapper
|
||||||
|
|
||||||
|
|
||||||
def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
|
def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
|
||||||
"""Get a model loader based on the load format."""
|
"""Get a model loader based on the load format."""
|
||||||
if isinstance(load_config.load_format, type):
|
load_format = load_config.load_format
|
||||||
return load_config.load_format(load_config)
|
if load_format not in _LOAD_FORMAT_TO_MODEL_LOADER:
|
||||||
|
raise ValueError(f"Load format `{load_format}` is not supported")
|
||||||
if load_config.load_format == LoadFormat.DUMMY:
|
return _LOAD_FORMAT_TO_MODEL_LOADER[load_format](load_config)
|
||||||
return DummyModelLoader(load_config)
|
|
||||||
|
|
||||||
if load_config.load_format == LoadFormat.TENSORIZER:
|
|
||||||
return TensorizerLoader(load_config)
|
|
||||||
|
|
||||||
if load_config.load_format == LoadFormat.SHARDED_STATE:
|
|
||||||
return ShardedStateLoader(load_config)
|
|
||||||
|
|
||||||
if load_config.load_format == LoadFormat.BITSANDBYTES:
|
|
||||||
return BitsAndBytesModelLoader(load_config)
|
|
||||||
|
|
||||||
if load_config.load_format == LoadFormat.GGUF:
|
|
||||||
return GGUFModelLoader(load_config)
|
|
||||||
|
|
||||||
if load_config.load_format == LoadFormat.RUNAI_STREAMER:
|
|
||||||
return RunaiModelStreamerLoader(load_config)
|
|
||||||
|
|
||||||
if load_config.load_format == LoadFormat.RUNAI_STREAMER_SHARDED:
|
|
||||||
return ShardedStateLoader(load_config, runai_model_streamer=True)
|
|
||||||
|
|
||||||
return DefaultModelLoader(load_config)
|
|
||||||
|
|
||||||
|
|
||||||
def get_model(*,
|
def get_model(*,
|
||||||
@ -66,6 +125,7 @@ __all__ = [
|
|||||||
"get_architecture_class_name",
|
"get_architecture_class_name",
|
||||||
"get_model_architecture",
|
"get_model_architecture",
|
||||||
"get_model_cls",
|
"get_model_cls",
|
||||||
|
"register_model_loader",
|
||||||
"BaseModelLoader",
|
"BaseModelLoader",
|
||||||
"BitsAndBytesModelLoader",
|
"BitsAndBytesModelLoader",
|
||||||
"GGUFModelLoader",
|
"GGUFModelLoader",
|
||||||
|
|||||||
@ -13,7 +13,7 @@ from torch import nn
|
|||||||
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
|
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
|
||||||
|
|
||||||
from vllm import envs
|
from vllm import envs
|
||||||
from vllm.config import LoadConfig, LoadFormat, ModelConfig
|
from vllm.config import LoadConfig, ModelConfig
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
||||||
from vllm.model_executor.model_loader.weight_utils import (
|
from vllm.model_executor.model_loader.weight_utils import (
|
||||||
@ -104,19 +104,19 @@ class DefaultModelLoader(BaseModelLoader):
|
|||||||
use_safetensors = False
|
use_safetensors = False
|
||||||
index_file = SAFE_WEIGHTS_INDEX_NAME
|
index_file = SAFE_WEIGHTS_INDEX_NAME
|
||||||
# Some quantized models use .pt files for storing the weights.
|
# Some quantized models use .pt files for storing the weights.
|
||||||
if load_format == LoadFormat.AUTO:
|
if load_format == "auto":
|
||||||
allow_patterns = ["*.safetensors", "*.bin"]
|
allow_patterns = ["*.safetensors", "*.bin"]
|
||||||
elif (load_format == LoadFormat.SAFETENSORS
|
elif (load_format == "safetensors"
|
||||||
or load_format == LoadFormat.FASTSAFETENSORS):
|
or load_format == "fastsafetensors"):
|
||||||
use_safetensors = True
|
use_safetensors = True
|
||||||
allow_patterns = ["*.safetensors"]
|
allow_patterns = ["*.safetensors"]
|
||||||
elif load_format == LoadFormat.MISTRAL:
|
elif load_format == "mistral":
|
||||||
use_safetensors = True
|
use_safetensors = True
|
||||||
allow_patterns = ["consolidated*.safetensors"]
|
allow_patterns = ["consolidated*.safetensors"]
|
||||||
index_file = "consolidated.safetensors.index.json"
|
index_file = "consolidated.safetensors.index.json"
|
||||||
elif load_format == LoadFormat.PT:
|
elif load_format == "pt":
|
||||||
allow_patterns = ["*.pt"]
|
allow_patterns = ["*.pt"]
|
||||||
elif load_format == LoadFormat.NPCACHE:
|
elif load_format == "npcache":
|
||||||
allow_patterns = ["*.bin"]
|
allow_patterns = ["*.bin"]
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown load_format: {load_format}")
|
raise ValueError(f"Unknown load_format: {load_format}")
|
||||||
@ -178,7 +178,7 @@ class DefaultModelLoader(BaseModelLoader):
|
|||||||
hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
|
hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
|
||||||
source.model_or_path, source.revision, source.fall_back_to_pt,
|
source.model_or_path, source.revision, source.fall_back_to_pt,
|
||||||
source.allow_patterns_overrides)
|
source.allow_patterns_overrides)
|
||||||
if self.load_config.load_format == LoadFormat.NPCACHE:
|
if self.load_config.load_format == "npcache":
|
||||||
# Currently np_cache only support *.bin checkpoints
|
# Currently np_cache only support *.bin checkpoints
|
||||||
assert use_safetensors is False
|
assert use_safetensors is False
|
||||||
weights_iterator = np_cache_weights_iterator(
|
weights_iterator = np_cache_weights_iterator(
|
||||||
@ -189,7 +189,7 @@ class DefaultModelLoader(BaseModelLoader):
|
|||||||
self.load_config.use_tqdm_on_load,
|
self.load_config.use_tqdm_on_load,
|
||||||
)
|
)
|
||||||
elif use_safetensors:
|
elif use_safetensors:
|
||||||
if self.load_config.load_format == LoadFormat.FASTSAFETENSORS:
|
if self.load_config.load_format == "fastsafetensors":
|
||||||
weights_iterator = fastsafetensors_weights_iterator(
|
weights_iterator = fastsafetensors_weights_iterator(
|
||||||
hf_weights_files,
|
hf_weights_files,
|
||||||
self.load_config.use_tqdm_on_load,
|
self.load_config.use_tqdm_on_load,
|
||||||
|
|||||||
@ -32,12 +32,9 @@ class ShardedStateLoader(BaseModelLoader):
|
|||||||
|
|
||||||
DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"
|
DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self, load_config: LoadConfig):
|
||||||
load_config: LoadConfig,
|
|
||||||
runai_model_streamer: bool = False):
|
|
||||||
super().__init__(load_config)
|
super().__init__(load_config)
|
||||||
|
|
||||||
self.runai_model_streamer = runai_model_streamer
|
|
||||||
extra_config = ({} if load_config.model_loader_extra_config is None
|
extra_config = ({} if load_config.model_loader_extra_config is None
|
||||||
else load_config.model_loader_extra_config.copy())
|
else load_config.model_loader_extra_config.copy())
|
||||||
self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN)
|
self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN)
|
||||||
@ -152,7 +149,7 @@ class ShardedStateLoader(BaseModelLoader):
|
|||||||
|
|
||||||
def iterate_over_files(
|
def iterate_over_files(
|
||||||
self, paths) -> Generator[tuple[str, torch.Tensor], None, None]:
|
self, paths) -> Generator[tuple[str, torch.Tensor], None, None]:
|
||||||
if self.runai_model_streamer:
|
if self.load_config.load_format == "runai_streamer_sharded":
|
||||||
yield from runai_safetensors_weights_iterator(paths, True)
|
yield from runai_safetensors_weights_iterator(paths, True)
|
||||||
else:
|
else:
|
||||||
from safetensors.torch import safe_open
|
from safetensors.torch import safe_open
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user