diff --git a/tests/runai_model_streamer_test/test_runai_model_streamer_loader.py b/tests/runai_model_streamer_test/test_runai_model_streamer_loader.py index aa91fa8e1c1c6..8b96184f579e4 100644 --- a/tests/runai_model_streamer_test/test_runai_model_streamer_loader.py +++ b/tests/runai_model_streamer_test/test_runai_model_streamer_loader.py @@ -2,8 +2,7 @@ from vllm import SamplingParams from vllm.config import LoadConfig, LoadFormat -from vllm.model_executor.model_loader.loader import (RunaiModelStreamerLoader, - get_model_loader) +from vllm.model_executor.model_loader import get_model_loader test_model = "openai-community/gpt2" @@ -24,7 +23,7 @@ def get_runai_model_loader(): def test_get_model_loader_with_runai_flag(): model_loader = get_runai_model_loader() - assert isinstance(model_loader, RunaiModelStreamerLoader) + assert model_loader.__class__.__name__ == "RunaiModelStreamerLoader" def test_runai_model_loader_download_files(vllm_runner): diff --git a/tests/test_sharded_state_loader.py b/tests/test_sharded_state_loader.py index 94b0156e104b2..77fec0968000f 100644 --- a/tests/test_sharded_state_loader.py +++ b/tests/test_sharded_state_loader.py @@ -10,7 +10,7 @@ import torch from huggingface_hub import snapshot_download from vllm import LLM, SamplingParams -from vllm.model_executor.model_loader.loader import ShardedStateLoader +from vllm.model_executor.model_loader import ShardedStateLoader prompts = [ "Hello, my name is", diff --git a/tests/utils.py b/tests/utils.py index 0983687e2ce9a..bf38d7843853d 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -29,7 +29,7 @@ from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) from vllm.engine.arg_utils import AsyncEngineArgs from vllm.entrypoints.openai.cli_args import make_arg_parser -from vllm.model_executor.model_loader.loader import get_model_loader +from vllm.model_executor.model_loader import get_model_loader from vllm.platforms import current_platform from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.utils import (FlexibleArgumentParser, GB_bytes, diff --git a/vllm/config.py b/vllm/config.py index 40beace3040c5..23f69c6123692 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -54,7 +54,7 @@ if TYPE_CHECKING: from vllm.executor.executor_base import ExecutorBase from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) - from vllm.model_executor.model_loader.loader import BaseModelLoader + from vllm.model_executor.model_loader import BaseModelLoader ConfigType = type[DataclassInstance] else: diff --git a/vllm/model_executor/model_loader/__init__.py b/vllm/model_executor/model_loader/__init__.py index 9048c70c7a714..92a0b0923b6e0 100644 --- a/vllm/model_executor/model_loader/__init__.py +++ b/vllm/model_executor/model_loader/__init__.py @@ -2,19 +2,67 @@ from torch import nn -from vllm.config import VllmConfig -from vllm.model_executor.model_loader.loader import (BaseModelLoader, - get_model_loader) +from vllm.config import LoadConfig, LoadFormat, VllmConfig +from vllm.model_executor.model_loader.base_loader import BaseModelLoader +from vllm.model_executor.model_loader.bitsandbytes_loader import ( + BitsAndBytesModelLoader) +from vllm.model_executor.model_loader.default_loader import DefaultModelLoader +from vllm.model_executor.model_loader.dummy_loader import DummyModelLoader +from vllm.model_executor.model_loader.gguf_loader import GGUFModelLoader +from vllm.model_executor.model_loader.runai_streamer_loader import ( + RunaiModelStreamerLoader) 
+from vllm.model_executor.model_loader.sharded_state_loader import ( + ShardedStateLoader) +from vllm.model_executor.model_loader.tensorizer_loader import TensorizerLoader from vllm.model_executor.model_loader.utils import ( get_architecture_class_name, get_model_architecture) +def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: + """Get a model loader based on the load format.""" + if isinstance(load_config.load_format, type): + return load_config.load_format(load_config) + + if load_config.load_format == LoadFormat.DUMMY: + return DummyModelLoader(load_config) + + if load_config.load_format == LoadFormat.TENSORIZER: + return TensorizerLoader(load_config) + + if load_config.load_format == LoadFormat.SHARDED_STATE: + return ShardedStateLoader(load_config) + + if load_config.load_format == LoadFormat.BITSANDBYTES: + return BitsAndBytesModelLoader(load_config) + + if load_config.load_format == LoadFormat.GGUF: + return GGUFModelLoader(load_config) + + if load_config.load_format == LoadFormat.RUNAI_STREAMER: + return RunaiModelStreamerLoader(load_config) + + if load_config.load_format == LoadFormat.RUNAI_STREAMER_SHARDED: + return ShardedStateLoader(load_config, runai_model_streamer=True) + + return DefaultModelLoader(load_config) + + def get_model(*, vllm_config: VllmConfig) -> nn.Module: loader = get_model_loader(vllm_config.load_config) return loader.load_model(vllm_config=vllm_config) __all__ = [ - "get_model", "get_model_loader", "BaseModelLoader", - "get_architecture_class_name", "get_model_architecture" + "get_model", + "get_model_loader", + "get_architecture_class_name", + "get_model_architecture", + "BaseModelLoader", + "BitsAndBytesModelLoader", + "GGUFModelLoader", + "DefaultModelLoader", + "DummyModelLoader", + "RunaiModelStreamerLoader", + "ShardedStateLoader", + "TensorizerLoader", ] diff --git a/vllm/model_executor/model_loader/base_loader.py b/vllm/model_executor/model_loader/base_loader.py new file mode 100644 index 0000000000000..f17cab05c25d3 --- /dev/null +++ b/vllm/model_executor/model_loader/base_loader.py @@ -0,0 +1,23 @@ +# SPDX-License-Identifier: Apache-2.0 +from abc import ABC, abstractmethod + +import torch.nn as nn + +from vllm.config import LoadConfig, ModelConfig, VllmConfig + + +class BaseModelLoader(ABC): + """Base class for model loaders.""" + + def __init__(self, load_config: LoadConfig): + self.load_config = load_config + + @abstractmethod + def download_model(self, model_config: ModelConfig) -> None: + """Download a model so that it can be immediately loaded.""" + raise NotImplementedError + + @abstractmethod + def load_model(self, *, vllm_config: VllmConfig) -> nn.Module: + """Load a model with the given configurations.""" + raise NotImplementedError diff --git a/vllm/model_executor/model_loader/bitsandbytes_loader.py b/vllm/model_executor/model_loader/bitsandbytes_loader.py new file mode 100644 index 0000000000000..57189bfafc06e --- /dev/null +++ b/vllm/model_executor/model_loader/bitsandbytes_loader.py @@ -0,0 +1,568 @@ +# SPDX-License-Identifier: Apache-2.0 +# ruff: noqa: SIM117 +import copy +import fnmatch +import glob +import itertools +import math +import os +from typing import Any, Callable, Dict, Generator, List, Optional, Tuple + +import numpy as np +import torch +from huggingface_hub import HfApi +from torch import nn +from transformers.utils import SAFE_WEIGHTS_INDEX_NAME + +from vllm.config import LoadConfig, ModelConfig, VllmConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + 
get_tensor_model_parallel_world_size)
+from vllm.logger import init_logger
+# yapf conflicts with isort for this block
+# yapf: disable
+from vllm.model_executor.layers.linear import (LinearBase,
+                                               MergedColumnParallelLinear,
+                                               QKVParallelLinear,
+                                               ReplicatedLinear,
+                                               RowParallelLinear)
+# yapf: enable
+from vllm.model_executor.model_loader.base_loader import BaseModelLoader
+from vllm.model_executor.model_loader.utils import (ParamMapping,
+                                                    initialize_model,
+                                                    set_default_torch_dtype)
+from vllm.model_executor.model_loader.weight_utils import (
+    download_safetensors_index_file_from_hf, download_weights_from_hf,
+    filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
+    pt_weights_iterator, safetensors_weights_iterator)
+from vllm.model_executor.utils import set_weight_attrs
+from vllm.platforms import current_platform
+
+logger = init_logger(__name__)
+
+
+class BitsAndBytesModelLoader(BaseModelLoader):
+    """Model loader to load model weights with BitsAndBytes quantization."""
+
+    possible_config_file_names = ["adapter_config.json"]
+
+    def __init__(self, load_config: LoadConfig):
+        super().__init__(load_config)
+
+        # Save the module names without sharding.
+        self.unsharded_weights_modules: List[str] = []
+        # Save the module names that are sharded by column.
+        self.column_sharded_weights_modules: List[str] = []
+        # Store all module names (from transformers) that support
+        # BNB quantization.
+        self.target_modules: List[str] = []
+        # Mapping of weight names from transformers to vLLM.
+        self.weight_mapper: Callable = lambda name: name
+
+    def _get_weight_files(
+        self,
+        model_name_or_path: str,
+        allowed_patterns: List[str],
+        revision: Optional[str] = None,
+    ) -> Tuple[str, List[str], str]:
+        """Retrieve weight files. Download the files if necessary.
+
+        Return the model folder, the weight files and the matched pattern."""
+        is_local = os.path.isdir(model_name_or_path)
+
+        if is_local:
+            for pattern in allowed_patterns:
+                weight_files = glob.glob(
+                    os.path.join(model_name_or_path, pattern))
+                if weight_files:
+                    return model_name_or_path, weight_files, pattern
+        else:
+            hf_api = HfApi()
+            repo_files = hf_api.list_repo_files(repo_id=model_name_or_path)
+            for pattern in allowed_patterns:
+                matching_files = fnmatch.filter(repo_files, pattern)
+                if matching_files:
+                    hf_folder = download_weights_from_hf(
+                        model_name_or_path,
+                        self.load_config.download_dir,
+                        [pattern],
+                        revision,
+                        ignore_patterns=self.load_config.ignore_patterns,
+                    )
+                    return hf_folder, glob.glob(
+                        os.path.join(hf_folder, pattern)), pattern
+
+        raise RuntimeError(
+            f"No model weights found in: `{model_name_or_path}`")
+
+    def _prepare_weights(self, model_name_or_path: str,
+                         revision: Optional[str]) -> Tuple[List[str], bool]:
+        """Prepare weight files for the model."""
+
+        allowed_patterns = ["*.safetensors", "*.bin", "*.pt"]
+
+        hf_folder, hf_weights_files, matched_pattern = self._get_weight_files(
+            model_name_or_path, allowed_patterns, revision)
+
+        use_safetensors = matched_pattern == "*.safetensors"
+        is_local = os.path.isdir(model_name_or_path)
+        index_file = SAFE_WEIGHTS_INDEX_NAME
+        if use_safetensors:
+            # For models like Mistral-7B-Instruct-v0.3,
+            # there are both sharded safetensors files and a consolidated
+            # safetensors file. Using both breaks.
+            # Here, we download the `model.safetensors.index.json` and filter
+            # any files not found in the index.
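+            # (Illustrative: if the index lists only the
+            # `model-0000x-of-0000y.safetensors` shards, a stray
+            # `consolidated.safetensors` in the same folder is filtered out.)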
+ if not is_local: + download_safetensors_index_file_from_hf( + model_name_or_path, + index_file, + self.load_config.download_dir, + revision, + ) + hf_weights_files = filter_duplicate_safetensors_files( + hf_weights_files, hf_folder, index_file) + else: + hf_weights_files = filter_files_not_needed_for_inference( + hf_weights_files) + + if len(hf_weights_files) == 0: + raise RuntimeError( + f"Cannot find any model weights with `{model_name_or_path}`") + + return hf_weights_files, use_safetensors + + def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool): + if use_safetensors: + iterator = safetensors_weights_iterator( + hf_weights_files, + self.load_config.use_tqdm_on_load, + ) + else: + iterator = pt_weights_iterator( + hf_weights_files, + self.load_config.use_tqdm_on_load, + self.load_config.pt_load_map_location, + ) + for org_name, param in iterator: + # mapping weight names from transformers to vllm while preserving + # original names. + mapped_name = self.weight_mapper(org_name) + yield org_name, mapped_name, param + + def _get_quantized_weights_iterator( + self, + model_name_or_path: str, + revision: Optional[str], + pre_quant: bool, + load_8bit: bool, + ) -> Tuple[Generator[Tuple[str, torch.Tensor], None, None], Dict[str, + Any]]: + """Get an iterator to the model weights with bitsandbytes quantization, + as well as the quantization state dictionary.""" + + # only load the bitsandbytes module when needed + try: + import bitsandbytes + + if bitsandbytes.__version__ < "0.45.3": + raise ImportError("bitsandbytes version is wrong. Please " + "install bitsandbytes>=0.45.3.") + except ImportError as err: + raise ImportError("Please install bitsandbytes>=0.45.3 via " + "`pip install bitsandbytes>=0.45.3` to use " + "bitsandbytes quantizer.") from err + + hf_weights_files, use_safetensors = self._prepare_weights( + model_name_or_path, revision) + + quant_state_dict: Dict[str, Any] = {} + + if pre_quant: + if load_8bit: + return self._quantized_8bit_generator( + hf_weights_files, use_safetensors, + quant_state_dict), quant_state_dict + else: + return self._quantized_4bit_generator( + hf_weights_files, use_safetensors, + quant_state_dict), quant_state_dict + + return self._unquantized_generator(hf_weights_files, use_safetensors, + quant_state_dict), quant_state_dict + + def _is_8bit_weight_name(self, weight_name: str): + quantized_suffix = {".scb", ".weight_format"} + return any(weight_name.lower().endswith(suffix) + for suffix in quantized_suffix) + + def _is_4bit_weight_name(self, weight_name: str): + quantized_suffix = { + "absmax", + "quant_map", + "nested_absmax", + "nested_quant_map", + "bitsandbytes", + } + suffix = weight_name.split(".")[-1] + return any(q_suffix in suffix for q_suffix in quantized_suffix) + + def _quantized_8bit_generator(self, hf_weights_files, use_safetensors, + quant_state_dict) -> Generator: + for ( + org_weight_name, + mapped_weight_name, + weight_tensor, + ) in self._hf_weight_iter(hf_weights_files, use_safetensors): + if not mapped_weight_name.lower().endswith(".scb"): + continue + + weight_key = mapped_weight_name.lower().replace(".scb", ".weight") + quant_state_dict[weight_key] = weight_tensor + + for ( + org_weight_name, + mapped_weight_name, + weight_tensor, + ) in self._hf_weight_iter(hf_weights_files, use_safetensors): + if self._is_8bit_weight_name(mapped_weight_name): + continue + + if mapped_weight_name in quant_state_dict: + set_weight_attrs(weight_tensor, {"load_in_8bit": True}) + yield org_weight_name, weight_tensor + else: + yield 
org_weight_name, weight_tensor
+
+    def _quantized_4bit_generator(self, hf_weights_files, use_safetensors,
+                                  quant_state_dict) -> Generator:
+        from bitsandbytes.functional import QuantState
+
+        # First, iterate over all quant state weights
+        weight_iterator = self._hf_weight_iter(hf_weights_files,
+                                               use_safetensors)
+        temp_state_dict = {}
+        for (
+                org_weight_name,
+                mapped_weight_name,
+                weight_tensor,
+        ) in weight_iterator:
+            if not self._is_4bit_weight_name(mapped_weight_name):
+                continue
+            # The bitsandbytes library requires
+            # weight.quant_state.bitsandbytes__* to be on CPU
+            if "quant_state.bitsandbytes" in mapped_weight_name:
+                temp_state_dict[mapped_weight_name] = weight_tensor.cpu().data
+            else:
+                temp_state_dict[mapped_weight_name] = weight_tensor
+
+        # Closure to parse quant_state for each prequant weight
+        def _parse_quant_state(param_name: str,
+                               temp_state_dict: Dict) -> QuantState:
+            quant_state = {}
+            for k in temp_state_dict:
+                if param_name + "." in k:
+                    quant_state[k] = temp_state_dict[k]
+
+            return QuantState.from_dict(quant_state,
+                                        device=current_platform.device_type)
+
+        # Second, iterate over all prequant and normal weights;
+        # pre-quantized weights have an associated quant_state
+        for (
+                org_weight_name,
+                mapped_weight_name,
+                weight_tensor,
+        ) in self._hf_weight_iter(hf_weights_files, use_safetensors):
+            if self._is_4bit_weight_name(mapped_weight_name):
+                continue
+
+            if (f"{mapped_weight_name}.quant_state.bitsandbytes__nf4"
+                    in temp_state_dict) or (
+                        f"{mapped_weight_name}.quant_state.bitsandbytes__fp4"
+                        in temp_state_dict):
+                quant_state = _parse_quant_state(mapped_weight_name,
+                                                 temp_state_dict)
+                quant_state_dict[mapped_weight_name] = quant_state
+                yield org_weight_name, weight_tensor
+            else:
+                yield org_weight_name, weight_tensor
+
+    def _unquantized_generator(self, hf_weights_files, use_safetensors,
+                               quant_state_dict) -> Generator:
+        from bitsandbytes.functional import quantize_4bit
+
+        tp_size = get_tensor_model_parallel_world_size()
+        tp_rank = get_tensor_model_parallel_rank()
+
+        for (
+                org_weight_name,
+                mapped_weight_name,
+                weight_tensor,
+        ) in self._hf_weight_iter(hf_weights_files, use_safetensors):
+            if any(target_module in mapped_weight_name
+                   for target_module in self.target_modules
+                   ) and mapped_weight_name.endswith(".weight"):
+                # Without sharding
+                if any(
+                        mapped_weight_name.startswith(module)
+                        for module in self.unsharded_weights_modules):
+                    weight_sub_tensor = weight_tensor
+                # Shard by column
+                elif any(
+                        mapped_weight_name.startswith(module)
+                        for module in self.column_sharded_weights_modules):
+                    total_size = weight_tensor.size(-1)
+                    start_index = total_size // tp_size * tp_rank
+                    end_index = total_size // tp_size * (tp_rank + 1)
+                    weight_sub_tensor = weight_tensor[...,
+                                                      start_index:end_index]
+                # Weights were fused on disk. In this case, we assume that
+                # the weight and the module share the same name.
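+                # (Illustrative: a `gate_up_proj` weight saved as one fused
+                # tensor is split per shard below using the module's
+                # output_sizes before being quantized.)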
+                elif any(
+                        mapped_weight_name.startswith(module)
+                        for module in self.maybe_fused_weights_modules):
+                    # Special case for fused weights:
+                    # get the size of each shard weight tensor.
+                    total_shard_sizes = next(
+                        (sizes for module, sizes in
+                         self.maybe_fused_weights_modules.items()
+                         if mapped_weight_name.startswith(module)))
+                    total_size = weight_tensor.size(0)
+                    assert total_size == sum(total_shard_sizes)
+                    # Get the start/end index of each shard weight tensor.
+                    total_start_index = list(
+                        itertools.accumulate([0] + total_shard_sizes))[:-1]
+                    shard_weights_index = [(
+                        idx + size // tp_size * tp_rank,
+                        idx + size // tp_size * (tp_rank + 1),
+                    ) for idx, size in zip(total_start_index,
+                                           total_shard_sizes)]
+                    # Slice and reorder the weight tensor.
+                    weight_tensor = [
+                        weight_tensor[start_index:end_index, ...]
+                        for start_index, end_index in shard_weights_index
+                    ]
+                    weight_sub_tensor = torch.cat(weight_tensor, dim=0)
+                # Shard by row
+                else:
+                    total_size = weight_tensor.size(0)
+                    start_index = total_size // tp_size * tp_rank
+                    end_index = total_size // tp_size * (tp_rank + 1)
+                    weight_sub_tensor = weight_tensor[start_index:end_index,
+                                                      ...]
+
+                # bitsandbytes requires the data to be on GPU
+                if weight_sub_tensor.is_cuda:
+                    loaded_weight = weight_sub_tensor
+                else:
+                    loaded_weight = weight_sub_tensor.cuda()
+
+                # Remove the following after this issue is fixed:
+                # https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1342
+                if loaded_weight.is_contiguous() is False:
+                    loaded_weight = loaded_weight.contiguous()
+
+                with set_default_torch_dtype(torch.float32):
+                    processed_weight, quant_state = quantize_4bit(
+                        loaded_weight,
+                        compress_statistics=True,
+                        quant_type="nf4",
+                    )
+
+                quant_state_dict[mapped_weight_name] = quant_state
+            else:
+                processed_weight = weight_tensor
+            yield org_weight_name, processed_weight
+
+    def _get_bnb_target_modules(self, model: nn.Module) -> None:
+
+        for name, module in model.named_modules():
+            if isinstance(module, (LinearBase, )):
+                if modules_info := self.modules_mapping.get_sub_modules(name):
+                    # Map vLLM's names to transformers' names.
+                    rep_name, sub_modules = modules_info
+                    for sub_name in sub_modules:
+                        self.target_modules.append(
+                            name.replace(rep_name, sub_name))
+                # Add the original module name even if the module has a
+                # stacked map, in case the model has a mixture of disk-merged
+                # and disk-split weights with the same last name.
+                self.target_modules.append(name)
+
+        assert self.target_modules, (
+            "vllm currently does not support BNB quantization for "
+            f"{type(model).__name__}")
+
+    def _load_weights(self, model_config: ModelConfig,
+                      model: nn.Module) -> None:
+        if not hasattr(model, "load_weights"):
+            raise AttributeError(
+                "The required method 'load_weights' is not defined in class"
+                f" {type(model).__name__}.")
+
+        if not hasattr(model, "packed_modules_mapping"):
+            raise AttributeError(
+                f"Model {type(model).__name__} does not support BitsAndBytes "
+                "quantization yet. No 'packed_modules_mapping' found.")
+
+        self.modules_mapping = ParamMapping(
+            copy.deepcopy(model.packed_modules_mapping))
+
+        # For some models like Molmo, we need to use hf_to_vllm_mapper
+        # to ensure correct loading of weights.
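+        # (The mapper rewrites checkpoint weight names into vLLM's naming
+        # scheme, so they can be matched against the target modules collected
+        # below.)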
+        if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None):
+            self.weight_mapper = lambda name: hf_to_vllm_mapper._map_name(name)
+
+        # Modules whose weights might be fused on disk; we need their
+        # output_sizes to shard them on the fly correctly under TP.
+        self.maybe_fused_weights_modules: Dict[str, List[int]] = {}
+        self._get_bnb_target_modules(model)
+        for name, module in model.named_modules():
+            # Some modules like `ReplicatedLinear` should not have their
+            # weights sharded. Implementing it this way avoids adding a new
+            # static variable to the model implementation.
+            if isinstance(module, (ReplicatedLinear, )):
+                self.unsharded_weights_modules.append(name)
+            # `QKVParallelLinear` and `MergedColumnParallelLinear` might have
+            # fused weights on disk. We need to use the output sizes of these
+            # modules to shard the weights correctly.
+            elif isinstance(module,
+                            (QKVParallelLinear, MergedColumnParallelLinear)):
+                self.maybe_fused_weights_modules[name] = module.output_sizes
+            # In TP, these weights are partitioned along the column
+            # dimension (dim=-1).
+            elif isinstance(module, (RowParallelLinear, )):
+                self.column_sharded_weights_modules.append(name)
+
+        self.model_type = type(model).__name__
+
+        logger.info("Loading weights with BitsAndBytes quantization. "
+                    "May take a while ...")
+
+        quant_config = getattr(model_config.hf_config, "quantization_config",
+                               None)
+
+        pre_quant = False
+        if quant_config is not None:
+            quant_method = quant_config.get("quant_method")
+            if quant_method == "bitsandbytes":
+                pre_quant = True
+            else:
+                raise ValueError(
+                    f"BitsAndBytes loader does not support {quant_method} "
+                    "quantization")
+
+        # The quant_states in pre-quantized models cannot work with a split
+        # weight tensor, so TP does not work with pre-quantized bnb models.
+        if pre_quant and get_tensor_model_parallel_world_size() > 1:
+            raise ValueError(
+                "Pre-quantized BitsAndBytes models do not support tensor "
+                "parallelism. Please try pipeline parallelism instead.")
+
+        load_8bit = False
+        if pre_quant:
+            load_8bit = quant_config.get("load_in_8bit", False)
+
+        qweight_iterator, quant_state_dict = (
+            self._get_quantized_weights_iterator(model_config.model,
+                                                 model_config.revision,
+                                                 pre_quant, load_8bit))
+
+        weights_to_load = {name for name, _ in model.named_parameters()}
+        loaded_weights = model.load_weights(qweight_iterator)
+        # Some models may not implement tracking of loaded weights.
+        if loaded_weights is not None:
+            weights_not_loaded = weights_to_load - loaded_weights
+            if weights_not_loaded:
+                raise ValueError("Following weights were not initialized from "
+                                 f"checkpoint: {weights_not_loaded}")
+
+        torch.cuda.empty_cache()
+
+        param_dict = dict(model.named_parameters())
+        stacked_quant_state_dict: Dict[str, Dict[int, Any]] = {}
+        # TODO: Change this lazy import to a normal import
+        # after the checks are updated to run on a new version.
+        from vllm.model_executor.models.utils import is_pp_missing_parameter
+
+        for quant_param_name in quant_state_dict:
+            if is_pp_missing_parameter(quant_param_name, model):
+                continue
+
+            non_stacked_param_name = quant_param_name
+
+            shard_index = 0
+            for shard_name, (
+                    weight_name,
+                    index,
+            ) in self.modules_mapping.inverse_packed_mapping.items():
+                # Some models, such as MiniCPM V2.5/2.6, contain both the
+                # module names 'kv_proj' and 'qkv_proj'.
To prevent 'kv_proj'
+                # from being incorrectly identified as being present in
+                # 'vpm.encoder.layers.0.self_attn.qkv_proj.weight', we also
+                # require the character preceding the match to be a '.'.
+                shard_pos = quant_param_name.find(shard_name)
+                can_correct_rename = (shard_pos
+                                      > 0) and (quant_param_name[shard_pos - 1]
+                                                == ".")
+                # If the quant_param_name is packed, it won't occur in the
+                # param_dict before renaming.
+                new_quant_param_name = quant_param_name.replace(
+                    shard_name, weight_name)
+                need_rename = (quant_param_name not in param_dict) \
+                              and (new_quant_param_name in param_dict)
+                if can_correct_rename and need_rename:
+                    shard_index = index
+                    quant_param_name = new_quant_param_name
+                    break
+
+            # Models like Clip/Siglip may skip some layers in initialization,
+            # causing unused quant_param_name in quant_state_dict.
+            if quant_param_name not in param_dict:
+                continue
+
+            if quant_param_name not in stacked_quant_state_dict:
+                stacked_quant_state_dict[quant_param_name] = {}
+
+            stacked_quant_state_dict[quant_param_name][shard_index] = (
+                quant_state_dict[non_stacked_param_name])
+
+        # Save quant_states and offsets as attributes of the parameters.
+        for param_name, param in param_dict.items():
+            if param_name in stacked_quant_state_dict:
+                quant_states = stacked_quant_state_dict[param_name]
+                set_weight_attrs(param, {"bnb_quant_state": quant_states})
+
+                pack_ratio = getattr(param, "pack_factor", -1)
+                if pack_ratio == -1:
+                    raise ValueError(
+                        f"pack_factor not set for parameter {param_name}.")
+
+                num_elements = [0] * len(quant_states)
+                for seq, quant_state in quant_states.items():
+                    num_elements[seq] = (math.prod(quant_state.shape) //
+                                         pack_ratio)
+
+                offsets = np.concatenate(([0], np.cumsum(num_elements)))
+                # Make torch infer_schema happy.
+                offsets = torch.tensor(offsets).cpu()
+                set_weight_attrs(param, {"bnb_shard_offsets": offsets})
+
+                if load_8bit:
+                    set_weight_attrs(
+                        param, {"matmul_state": [None] * len(quant_states)})
+
+    def download_model(self, model_config: ModelConfig) -> None:
+        self._prepare_weights(model_config.model, model_config.revision)
+
+    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
+        device_config = vllm_config.device_config
+        model_config = vllm_config.model_config
+
+        with set_default_torch_dtype(model_config.dtype):
+            with torch.device(device_config.device):
+
+                model = initialize_model(vllm_config=vllm_config)
+
+                self._load_weights(model_config, model)
+
+        return model.eval()
diff --git a/vllm/model_executor/model_loader/default_loader.py b/vllm/model_executor/model_loader/default_loader.py
new file mode 100644
index 0000000000000..c8bc4aecaecf6
--- /dev/null
+++ b/vllm/model_executor/model_loader/default_loader.py
@@ -0,0 +1,293 @@
+# SPDX-License-Identifier: Apache-2.0
+import dataclasses
+import glob
+import os
+import time
+from typing import Generator, Iterable, List, Optional, Tuple, cast
+
+import huggingface_hub
+import torch
+from torch import nn
+from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
+
+from vllm.config import LoadConfig, LoadFormat, ModelConfig, VllmConfig
+from vllm.envs import VLLM_USE_MODELSCOPE
+from vllm.logger import init_logger
+from vllm.model_executor.model_loader.base_loader import BaseModelLoader
+from vllm.model_executor.model_loader.utils import (
+    initialize_model, process_weights_after_loading, set_default_torch_dtype)
+from vllm.model_executor.model_loader.weight_utils import (
+    download_safetensors_index_file_from_hf, download_weights_from_hf,
+    fastsafetensors_weights_iterator, filter_duplicate_safetensors_files,
+
filter_files_not_needed_for_inference, get_lock, np_cache_weights_iterator, + pt_weights_iterator, safetensors_weights_iterator) +from vllm.platforms import current_platform + +logger = init_logger(__name__) + + +class DefaultModelLoader(BaseModelLoader): + """Model loader that can load different file types from disk.""" + + @dataclasses.dataclass + class Source: + """A source for weights.""" + + model_or_path: str + """The model ID or path.""" + + revision: Optional[str] + """The optional model revision.""" + + prefix: str = "" + """A prefix to prepend to all weights.""" + + fall_back_to_pt: bool = True + """Whether .pt weights can be used.""" + + allow_patterns_overrides: Optional[list[str]] = None + """If defined, weights will load exclusively using these patterns.""" + + counter_before_loading_weights: float = 0.0 + counter_after_loading_weights: float = 0.0 + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + if load_config.model_loader_extra_config: + raise ValueError(f"Model loader extra config is not supported for " + f"load format {load_config.load_format}") + + def _maybe_download_from_modelscope( + self, model: str, revision: Optional[str]) -> Optional[str]: + """Download model from ModelScope hub if VLLM_USE_MODELSCOPE is True. + + Returns the path to the downloaded model, or None if the model is not + downloaded from ModelScope.""" + if VLLM_USE_MODELSCOPE: + # download model from ModelScope hub, + # lazy import so that modelscope is not required for normal use. + # pylint: disable=C. + from modelscope.hub.snapshot_download import snapshot_download + + if not os.path.exists(model): + # Use file lock to prevent multiple processes from + # downloading the same model weights at the same time. + with get_lock(model, self.load_config.download_dir): + model_path = snapshot_download( + model_id=model, + cache_dir=self.load_config.download_dir, + local_files_only=huggingface_hub.constants. + HF_HUB_OFFLINE, + revision=revision, + ignore_file_pattern=self.load_config.ignore_patterns, + ) + else: + model_path = model + return model_path + return None + + def _prepare_weights( + self, + model_name_or_path: str, + revision: Optional[str], + fall_back_to_pt: bool, + allow_patterns_overrides: Optional[list[str]], + ) -> Tuple[str, List[str], bool]: + """Prepare weights for the model. + + If the model is not local, it will be downloaded.""" + model_name_or_path = (self._maybe_download_from_modelscope( + model_name_or_path, revision) or model_name_or_path) + + is_local = os.path.isdir(model_name_or_path) + load_format = self.load_config.load_format + use_safetensors = False + index_file = SAFE_WEIGHTS_INDEX_NAME + # Some quantized models use .pt files for storing the weights. 
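+        # The patterns chosen here determine both which files are downloaded
+        # and, via `use_safetensors`, which weights iterator is used later.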
+ if load_format == LoadFormat.AUTO: + allow_patterns = ["*.safetensors", "*.bin"] + elif (load_format == LoadFormat.SAFETENSORS + or load_format == LoadFormat.FASTSAFETENSORS): + use_safetensors = True + allow_patterns = ["*.safetensors"] + elif load_format == LoadFormat.MISTRAL: + use_safetensors = True + allow_patterns = ["consolidated*.safetensors"] + index_file = "consolidated.safetensors.index.json" + elif load_format == LoadFormat.PT: + allow_patterns = ["*.pt"] + elif load_format == LoadFormat.NPCACHE: + allow_patterns = ["*.bin"] + else: + raise ValueError(f"Unknown load_format: {load_format}") + + if fall_back_to_pt: + allow_patterns += ["*.pt"] + + if allow_patterns_overrides is not None: + allow_patterns = allow_patterns_overrides + + if not is_local: + hf_folder = download_weights_from_hf( + model_name_or_path, + self.load_config.download_dir, + allow_patterns, + revision, + ignore_patterns=self.load_config.ignore_patterns, + ) + else: + hf_folder = model_name_or_path + + hf_weights_files: List[str] = [] + for pattern in allow_patterns: + hf_weights_files += glob.glob(os.path.join(hf_folder, pattern)) + if len(hf_weights_files) > 0: + if pattern == "*.safetensors": + use_safetensors = True + break + + if use_safetensors: + # For models like Mistral-7B-Instruct-v0.3 + # there are both sharded safetensors files and a consolidated + # safetensors file. Using both breaks. + # Here, we download the `model.safetensors.index.json` and filter + # any files not found in the index. + if not is_local: + download_safetensors_index_file_from_hf( + model_name_or_path, + index_file, + self.load_config.download_dir, + revision, + ) + hf_weights_files = filter_duplicate_safetensors_files( + hf_weights_files, hf_folder, index_file) + else: + hf_weights_files = filter_files_not_needed_for_inference( + hf_weights_files) + + if len(hf_weights_files) == 0: + raise RuntimeError( + f"Cannot find any model weights with `{model_name_or_path}`") + + return hf_folder, hf_weights_files, use_safetensors + + def _get_weights_iterator( + self, source: "Source" + ) -> Generator[Tuple[str, torch.Tensor], None, None]: + """Get an iterator for the model weights based on the load format.""" + hf_folder, hf_weights_files, use_safetensors = self._prepare_weights( + source.model_or_path, source.revision, source.fall_back_to_pt, + source.allow_patterns_overrides) + if self.load_config.load_format == LoadFormat.NPCACHE: + # Currently np_cache only support *.bin checkpoints + assert use_safetensors is False + weights_iterator = np_cache_weights_iterator( + source.model_or_path, + self.load_config.download_dir, + hf_folder, + hf_weights_files, + self.load_config.use_tqdm_on_load, + ) + elif use_safetensors: + if self.load_config.load_format == LoadFormat.FASTSAFETENSORS: + weights_iterator = fastsafetensors_weights_iterator( + hf_weights_files, + self.load_config.use_tqdm_on_load, + ) + else: + weights_iterator = safetensors_weights_iterator( + hf_weights_files, + self.load_config.use_tqdm_on_load, + ) + else: + weights_iterator = pt_weights_iterator( + hf_weights_files, + self.load_config.use_tqdm_on_load, + self.load_config.pt_load_map_location, + ) + + if current_platform.is_tpu(): + # In PyTorch XLA, we should call `xm.mark_step` frequently so that + # not too many ops are accumulated in the XLA program. 
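+            # Wrapping the iterator inserts a `mark_step` after every yielded
+            # tensor, keeping the pending XLA graph small.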
+ import torch_xla.core.xla_model as xm + + def _xla_weights_iterator(iterator: Generator): + for weights in iterator: + yield weights + xm.mark_step() + + weights_iterator = _xla_weights_iterator(weights_iterator) + + elif current_platform.is_hpu(): + import habana_frameworks.torch.core as htcore + + def _hpu_weights_iterator(iterator: Generator): + for weights in iterator: + yield weights + htcore.mark_step() + + weights_iterator = _hpu_weights_iterator(weights_iterator) + + if self.counter_before_loading_weights == 0.0: + self.counter_before_loading_weights = time.perf_counter() + # Apply the prefix. + return ((source.prefix + name, tensor) + for (name, tensor) in weights_iterator) + + def get_all_weights( + self, + model_config: ModelConfig, + model: nn.Module, + ) -> Generator[Tuple[str, torch.Tensor], None, None]: + primary_weights = DefaultModelLoader.Source( + model_config.model, + model_config.revision, + prefix="", + fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", + True), + allow_patterns_overrides=getattr(model, "allow_patterns_overrides", + None), + ) + yield from self._get_weights_iterator(primary_weights) + + secondary_weights = cast( + Iterable[DefaultModelLoader.Source], + getattr(model, "secondary_weights", ()), + ) + for source in secondary_weights: + yield from self._get_weights_iterator(source) + + def download_model(self, model_config: ModelConfig) -> None: + self._prepare_weights(model_config.model, + model_config.revision, + fall_back_to_pt=True, + allow_patterns_overrides=None) + + def load_model(self, vllm_config: VllmConfig) -> nn.Module: + device_config = vllm_config.device_config + model_config = vllm_config.model_config + target_device = torch.device(device_config.device) + with set_default_torch_dtype(model_config.dtype): + with target_device: + model = initialize_model(vllm_config=vllm_config) + + weights_to_load = {name for name, _ in model.named_parameters()} + loaded_weights = model.load_weights( + self.get_all_weights(model_config, model)) + self.counter_after_loading_weights = time.perf_counter() + logger.info( + "Loading weights took %.2f seconds", + self.counter_after_loading_weights - + self.counter_before_loading_weights) + # We only enable strict check for non-quantized models + # that have loaded weights tracking currently. 
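+        # (`load_weights` returns None when a model does not track which
+        # parameters it initialized, so the strict check below is skipped for
+        # such models.)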
+ if model_config.quantization is None and loaded_weights is not None: + weights_not_loaded = weights_to_load - loaded_weights + if weights_not_loaded: + raise ValueError( + "Following weights were not initialized from " + f"checkpoint: {weights_not_loaded}") + + process_weights_after_loading(model, model_config, target_device) + + return model.eval() diff --git a/vllm/model_executor/model_loader/dummy_loader.py b/vllm/model_executor/model_loader/dummy_loader.py new file mode 100644 index 0000000000000..5047a161f3f97 --- /dev/null +++ b/vllm/model_executor/model_loader/dummy_loader.py @@ -0,0 +1,37 @@ +# SPDX-License-Identifier: Apache-2.0 +import torch +import torch.nn as nn + +from vllm.config import LoadConfig, ModelConfig, VllmConfig +from vllm.model_executor.model_loader.base_loader import BaseModelLoader +from vllm.model_executor.model_loader.utils import ( + initialize_model, process_weights_after_loading, set_default_torch_dtype) +from vllm.model_executor.model_loader.weight_utils import ( + initialize_dummy_weights) + + +class DummyModelLoader(BaseModelLoader): + """Model loader that will set model weights to random values.""" + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + if load_config.model_loader_extra_config: + raise ValueError(f"Model loader extra config is not supported for " + f"load format {load_config.load_format}") + + def download_model(self, model_config: ModelConfig) -> None: + pass # Nothing to download + + def load_model(self, vllm_config: VllmConfig) -> nn.Module: + device_config = vllm_config.device_config + model_config = vllm_config.model_config + target_device = torch.device(device_config.device) + with set_default_torch_dtype(model_config.dtype): + with target_device: + model = initialize_model(vllm_config=vllm_config) + # NOTE(woosuk): For accurate performance evaluation, we assign + # random values to the weights. + initialize_dummy_weights(model) + + process_weights_after_loading(model, model_config, target_device) + return model.eval() diff --git a/vllm/model_executor/model_loader/gguf_loader.py b/vllm/model_executor/model_loader/gguf_loader.py new file mode 100644 index 0000000000000..ace1cd371286c --- /dev/null +++ b/vllm/model_executor/model_loader/gguf_loader.py @@ -0,0 +1,113 @@ +# SPDX-License-Identifier: Apache-2.0 +import os +from typing import Dict, Generator, Tuple + +import gguf +import torch +import torch.nn as nn +from transformers import AutoModelForCausalLM + +from vllm.config import LoadConfig, ModelConfig, VllmConfig +from vllm.model_executor.model_loader.base_loader import BaseModelLoader +from vllm.model_executor.model_loader.utils import ( + initialize_model, process_weights_after_loading, set_default_torch_dtype) +from vllm.model_executor.model_loader.weight_utils import ( + get_gguf_extra_tensor_names, gguf_quant_weights_iterator) + + +class GGUFModelLoader(BaseModelLoader): + """ + Model loader that can load GGUF files. This is useful for loading models + that are quantized with GGUF and saved in the GGUF format. This loader + supports loading both full models and sharded models. 
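+
+    Note that `_prepare_weights` expects the model path to point at a single
+    local GGUF file; remote checkpoints must be downloaded beforehand.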
+ """ + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + if load_config.model_loader_extra_config: + raise ValueError(f"Model loader extra config is not supported for " + f"load format {load_config.load_format}") + + def _prepare_weights(self, model_name_or_path: str): + if os.path.isfile(model_name_or_path): + return model_name_or_path + else: + raise ValueError(f"{model_name_or_path} is not a file.") + + def _get_gguf_weights_map(self, model_config: ModelConfig): + """ + GGUF uses this naming convention for their tensors from HF checkpoint: + `blk.N.BB.weight` and `blk.N.BB.bias` + where N signifies the block number of a layer, and BB signifies the + attention/mlp layer components. + See "Standardized tensor names" in + https://github.com/ggerganov/ggml/blob/master/docs/gguf.md for details. + """ + config = model_config.hf_config + model_type = config.model_type + gguf_to_hf_name_map = {} + # hack: ggufs have a different name than transformers + if model_type == "cohere": + model_type = "command-r" + if model_type in ("deepseek_v3", "deepseek_v2"): + model_type = "deepseek2" + # GGUF layer map assumes that we will have a merged expert weights + # so we need to map them manually + for idx in range(config.num_hidden_layers): + gguf_to_hf_name_map[f"blk.{idx}.exp_probs_b.bias"] = \ + f"model.layers.{idx}.mlp.gate.e_score_correction_bias" + gguf_to_hf_name_map[f"blk.{idx}.ffn_down_exps.weight"] = \ + f"model.layers.{idx}.mlp.experts.0.down_proj.weight" + gguf_to_hf_name_map[f"blk.{idx}.ffn_gate_exps.weight"] = \ + f"model.layers.{idx}.mlp.experts.0.gate_proj.weight" + gguf_to_hf_name_map[f"blk.{idx}.ffn_up_exps.weight"] = \ + f"model.layers.{idx}.mlp.experts.0.up_proj.weight" + + arch = None + for key, value in gguf.MODEL_ARCH_NAMES.items(): + if value == model_type: + arch = key + break + if arch is None: + raise RuntimeError(f"Unknown gguf model_type: {model_type}") + num_layers = config.num_hidden_layers + name_map = gguf.get_tensor_name_map(arch, num_layers) + with torch.device("meta"): + dummy_model = AutoModelForCausalLM.from_config( + config, trust_remote_code=model_config.trust_remote_code) + state_dict = dummy_model.state_dict() + + for hf_name in state_dict: + name, suffix = hf_name.rsplit(".", 1) + gguf_name = name_map.get_name(name) + gguf_to_hf_name_map[f"{gguf_name}.{suffix}"] = hf_name + return gguf_to_hf_name_map + + def _get_weights_iterator( + self, model_name_or_path: str, gguf_to_hf_name_map: Dict[str, str] + ) -> Generator[Tuple[str, torch.Tensor], None, None]: + return gguf_quant_weights_iterator(model_name_or_path, + gguf_to_hf_name_map) + + def download_model(self, model_config: ModelConfig) -> None: + self._prepare_weights(model_config.model) + + def load_model(self, vllm_config: VllmConfig) -> nn.Module: + device_config = vllm_config.device_config + model_config = vllm_config.model_config + local_model_path = self._prepare_weights(model_config.model) + gguf_weights_map = self._get_gguf_weights_map(model_config) + # we can only know if tie word embeddings after mapping weights + if "lm_head.weight" in get_gguf_extra_tensor_names( + local_model_path, gguf_weights_map): + model_config.hf_config.update({"tie_word_embeddings": True}) + + target_device = torch.device(device_config.device) + with set_default_torch_dtype(model_config.dtype): + with target_device: + model = initialize_model(vllm_config=vllm_config) + model.load_weights( + self._get_weights_iterator(local_model_path, gguf_weights_map)) + + process_weights_after_loading(model, 
model_config, target_device) + return model diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py deleted file mode 100644 index 01f75db9ee862..0000000000000 --- a/vllm/model_executor/model_loader/loader.py +++ /dev/null @@ -1,1544 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -# ruff: noqa: SIM117 -import collections -import copy -import dataclasses -import fnmatch -import glob -import inspect -import itertools -import math -import os -import time -import warnings -from abc import ABC, abstractmethod -from contextlib import contextmanager -from typing import (Any, Callable, Dict, Generator, Iterable, List, Optional, - Tuple, cast) - -import gguf -import huggingface_hub -import numpy as np -import torch -from huggingface_hub import HfApi -from torch import nn -from transformers import AutoModelForCausalLM -from transformers.utils import SAFE_WEIGHTS_INDEX_NAME - -from vllm.attention import Attention -from vllm.config import (LoadConfig, LoadFormat, ModelConfig, ParallelConfig, - VllmConfig, set_current_vllm_config) -from vllm.distributed import (get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) -from vllm.envs import VLLM_USE_MODELSCOPE -from vllm.logger import init_logger -# yapf conflicts with isort for this block -# yapf: disable -from vllm.model_executor.layers.linear import (LinearBase, - MergedColumnParallelLinear, - QKVCrossParallelLinear, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) -# yapf: enable -from vllm.model_executor.layers.quantization.base_config import ( - QuantizeMethodBase) -from vllm.model_executor.model_loader.tensorizer import ( - TensorizerConfig, is_vllm_tensorized, load_with_tensorizer, - serialize_vllm_model, tensorizer_weights_iterator) -from vllm.model_executor.model_loader.utils import (ParamMapping, - configure_quant_config, - get_model_architecture, - set_default_torch_dtype) -from vllm.model_executor.model_loader.weight_utils import ( - download_safetensors_index_file_from_hf, download_weights_from_hf, - fastsafetensors_weights_iterator, filter_duplicate_safetensors_files, - filter_files_not_needed_for_inference, get_gguf_extra_tensor_names, - get_lock, gguf_quant_weights_iterator, initialize_dummy_weights, - np_cache_weights_iterator, pt_weights_iterator, - runai_safetensors_weights_iterator, safetensors_weights_iterator) -from vllm.model_executor.utils import set_weight_attrs -from vllm.platforms import current_platform -from vllm.transformers_utils.s3_utils import glob as s3_glob -from vllm.transformers_utils.utils import is_s3 -from vllm.utils import is_pin_memory_available - - -@contextmanager -def device_loading_context(module: torch.nn.Module, - target_device: torch.device): - if target_device.type == "cpu": - # If target is CPU, no need to move anything - yield module - return - - original_device_states: Dict[str, torch.device] = {} - - # Store original device states and move parameters to GPU if they're on CPU - for name, p in module.named_parameters(): - if p.device.type == "cpu": - original_device_states[name] = p.device - p.data = p.data.to(target_device) - # Parameters already on target device are not touched - - try: - yield module - - finally: - # Restore parameters to their original devices, ignoring new parameters - pin_memory = is_pin_memory_available() - for name, p in module.named_parameters(): - if name in original_device_states: - original_device: torch.device = original_device_states[name] - if original_device.type == "cpu": - # `torch.empty_like` 
does not support `pin_memory` argument - cpu_data = torch.empty_strided( - size=p.data.size(), - stride=p.data.stride(), - dtype=p.data.dtype, - layout=p.data.layout, - device="cpu", - pin_memory=pin_memory, - ) - cpu_data.copy_(p.data) - p.data = cpu_data - else: - p.data = p.data.to(original_device) - # New parameters or parameters already on target device are untouched - - -logger = init_logger(__name__) - - -def _initialize_model( - vllm_config: VllmConfig, - *, - prefix: str = "", - model_class: Optional[type[nn.Module]] = None, -) -> nn.Module: - """Initialize a model with the given configurations.""" - model_config = vllm_config.model_config - if model_class is None: - model_class, _ = get_model_architecture(model_config) - - if vllm_config.quant_config is not None: - configure_quant_config(vllm_config.quant_config, model_class) - - signatures = inspect.signature(model_class.__init__) - all_params = [param.name for param in signatures.parameters.values()] - if "vllm_config" in all_params and "prefix" in all_params: - # new-style model class - with set_current_vllm_config(vllm_config, check_compile=True): - return model_class(vllm_config=vllm_config, prefix=prefix) - - msg = ("vLLM model class should accept `vllm_config` and `prefix` as " - "input arguments. Possibly you have an old-style model class" - " registered from out of tree and it is used for new vLLM version. " - "Check https://docs.vllm.ai/en/latest/design/arch_overview.html " - "for the design and update the model class accordingly.") - warnings.warn(msg, DeprecationWarning, stacklevel=2) - - logger.warning( - "Trying to guess the arguments for old-style model class %s", - model_class, - ) - # try to be compatible with old-style model class - kwargs = {} - if "prefix" in all_params: - kwargs["prefix"] = prefix - if "config" in all_params: - kwargs["config"] = model_config.hf_config - if "cache_config" in all_params: - kwargs["cache_config"] = vllm_config.cache_config - if "quant_config" in all_params: - kwargs["quant_config"] = vllm_config.quant_config - if "lora_config" in all_params: - kwargs["lora_config"] = vllm_config.lora_config - if "scheduler_config" in all_params: - kwargs["scheduler_config"] = vllm_config.scheduler_config - with set_current_vllm_config(vllm_config, check_compile=True): - return model_class(**kwargs) - - -def _process_weights_after_loading(model: nn.Module, model_config: ModelConfig, - target_device: torch.device) -> None: - for _, module in model.named_modules(): - if isinstance(module, QKVCrossParallelLinear): - # NOTE(Isotr0py): special case for cross QKV layer because - # q and kv proj aren't registered as submodules intentionally - module.process_weights_after_loading() - continue - quant_method = getattr(module, "quant_method", None) - if isinstance(quant_method, QuantizeMethodBase): - # When quant methods need to process weights after loading - # (for repacking, quantizing, etc), they expect parameters - # to be on the global target device. This scope is for the - # case where cpu offloading is used, where we will move the - # parameters onto device for processing and back off after. - with device_loading_context(module, target_device): - quant_method.process_weights_after_loading(module) - - # Currently only used by MLA. - # NOTE: This intentionally happens after other modules so we can easily - # decompress the weights for MLA. 
- for _, module in model.named_modules(): - if isinstance(module, Attention) and \ - hasattr(module, "process_weights_after_loading"): - # TODO(lucas): see if there is a way to unify the signatures - # of process_weights_after_loading - module.process_weights_after_loading(model_config.dtype) - - -class BaseModelLoader(ABC): - """Base class for model loaders.""" - - def __init__(self, load_config: LoadConfig): - self.load_config = load_config - - @abstractmethod - def download_model(self, model_config: ModelConfig) -> None: - """Download a model so that it can be immediately loaded.""" - raise NotImplementedError - - @abstractmethod - def load_model(self, *, vllm_config: VllmConfig) -> nn.Module: - """Load a model with the given configurations.""" - raise NotImplementedError - - -class DefaultModelLoader(BaseModelLoader): - """Model loader that can load different file types from disk.""" - - @dataclasses.dataclass - class Source: - """A source for weights.""" - - model_or_path: str - """The model ID or path.""" - - revision: Optional[str] - """The optional model revision.""" - - prefix: str = "" - """A prefix to prepend to all weights.""" - - fall_back_to_pt: bool = True - """Whether .pt weights can be used.""" - - allow_patterns_overrides: Optional[list[str]] = None - """If defined, weights will load exclusively using these patterns.""" - - counter_before_loading_weights: float = 0.0 - counter_after_loading_weights: float = 0.0 - - def __init__(self, load_config: LoadConfig): - super().__init__(load_config) - if load_config.model_loader_extra_config: - raise ValueError(f"Model loader extra config is not supported for " - f"load format {load_config.load_format}") - - def _maybe_download_from_modelscope( - self, model: str, revision: Optional[str]) -> Optional[str]: - """Download model from ModelScope hub if VLLM_USE_MODELSCOPE is True. - - Returns the path to the downloaded model, or None if the model is not - downloaded from ModelScope.""" - if VLLM_USE_MODELSCOPE: - # download model from ModelScope hub, - # lazy import so that modelscope is not required for normal use. - # pylint: disable=C. - from modelscope.hub.snapshot_download import snapshot_download - - if not os.path.exists(model): - # Use file lock to prevent multiple processes from - # downloading the same model weights at the same time. - with get_lock(model, self.load_config.download_dir): - model_path = snapshot_download( - model_id=model, - cache_dir=self.load_config.download_dir, - local_files_only=huggingface_hub.constants. - HF_HUB_OFFLINE, - revision=revision, - ignore_file_pattern=self.load_config.ignore_patterns, - ) - else: - model_path = model - return model_path - return None - - def _prepare_weights( - self, - model_name_or_path: str, - revision: Optional[str], - fall_back_to_pt: bool, - allow_patterns_overrides: Optional[list[str]], - ) -> Tuple[str, List[str], bool]: - """Prepare weights for the model. - - If the model is not local, it will be downloaded.""" - model_name_or_path = (self._maybe_download_from_modelscope( - model_name_or_path, revision) or model_name_or_path) - - is_local = os.path.isdir(model_name_or_path) - load_format = self.load_config.load_format - use_safetensors = False - index_file = SAFE_WEIGHTS_INDEX_NAME - # Some quantized models use .pt files for storing the weights. 
- if load_format == LoadFormat.AUTO: - allow_patterns = ["*.safetensors", "*.bin"] - elif (load_format == LoadFormat.SAFETENSORS - or load_format == LoadFormat.FASTSAFETENSORS): - use_safetensors = True - allow_patterns = ["*.safetensors"] - elif load_format == LoadFormat.MISTRAL: - use_safetensors = True - allow_patterns = ["consolidated*.safetensors"] - index_file = "consolidated.safetensors.index.json" - elif load_format == LoadFormat.PT: - allow_patterns = ["*.pt"] - elif load_format == LoadFormat.NPCACHE: - allow_patterns = ["*.bin"] - else: - raise ValueError(f"Unknown load_format: {load_format}") - - if fall_back_to_pt: - allow_patterns += ["*.pt"] - - if allow_patterns_overrides is not None: - allow_patterns = allow_patterns_overrides - - if not is_local: - hf_folder = download_weights_from_hf( - model_name_or_path, - self.load_config.download_dir, - allow_patterns, - revision, - ignore_patterns=self.load_config.ignore_patterns, - ) - else: - hf_folder = model_name_or_path - - hf_weights_files: List[str] = [] - for pattern in allow_patterns: - hf_weights_files += glob.glob(os.path.join(hf_folder, pattern)) - if len(hf_weights_files) > 0: - if pattern == "*.safetensors": - use_safetensors = True - break - - if use_safetensors: - # For models like Mistral-7B-Instruct-v0.3 - # there are both sharded safetensors files and a consolidated - # safetensors file. Using both breaks. - # Here, we download the `model.safetensors.index.json` and filter - # any files not found in the index. - if not is_local: - download_safetensors_index_file_from_hf( - model_name_or_path, - index_file, - self.load_config.download_dir, - revision, - ) - hf_weights_files = filter_duplicate_safetensors_files( - hf_weights_files, hf_folder, index_file) - else: - hf_weights_files = filter_files_not_needed_for_inference( - hf_weights_files) - - if len(hf_weights_files) == 0: - raise RuntimeError( - f"Cannot find any model weights with `{model_name_or_path}`") - - return hf_folder, hf_weights_files, use_safetensors - - def _get_weights_iterator( - self, source: "Source" - ) -> Generator[Tuple[str, torch.Tensor], None, None]: - """Get an iterator for the model weights based on the load format.""" - hf_folder, hf_weights_files, use_safetensors = self._prepare_weights( - source.model_or_path, source.revision, source.fall_back_to_pt, - source.allow_patterns_overrides) - if self.load_config.load_format == LoadFormat.NPCACHE: - # Currently np_cache only support *.bin checkpoints - assert use_safetensors is False - weights_iterator = np_cache_weights_iterator( - source.model_or_path, - self.load_config.download_dir, - hf_folder, - hf_weights_files, - self.load_config.use_tqdm_on_load, - ) - elif use_safetensors: - if self.load_config.load_format == LoadFormat.FASTSAFETENSORS: - weights_iterator = fastsafetensors_weights_iterator( - hf_weights_files, - self.load_config.use_tqdm_on_load, - ) - else: - weights_iterator = safetensors_weights_iterator( - hf_weights_files, - self.load_config.use_tqdm_on_load, - ) - else: - weights_iterator = pt_weights_iterator( - hf_weights_files, - self.load_config.use_tqdm_on_load, - self.load_config.pt_load_map_location, - ) - - if current_platform.is_tpu(): - # In PyTorch XLA, we should call `xm.mark_step` frequently so that - # not too many ops are accumulated in the XLA program. 
- import torch_xla.core.xla_model as xm - - def _xla_weights_iterator(iterator: Generator): - for weights in iterator: - yield weights - xm.mark_step() - - weights_iterator = _xla_weights_iterator(weights_iterator) - - elif current_platform.is_hpu(): - import habana_frameworks.torch.core as htcore - - def _hpu_weights_iterator(iterator: Generator): - for weights in iterator: - yield weights - htcore.mark_step() - - weights_iterator = _hpu_weights_iterator(weights_iterator) - - if self.counter_before_loading_weights == 0.0: - self.counter_before_loading_weights = time.perf_counter() - # Apply the prefix. - return ((source.prefix + name, tensor) - for (name, tensor) in weights_iterator) - - def get_all_weights( - self, - model_config: ModelConfig, - model: nn.Module, - ) -> Generator[Tuple[str, torch.Tensor], None, None]: - primary_weights = DefaultModelLoader.Source( - model_config.model, - model_config.revision, - prefix="", - fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", - True), - allow_patterns_overrides=getattr(model, "allow_patterns_overrides", - None), - ) - yield from self._get_weights_iterator(primary_weights) - - secondary_weights = cast( - Iterable[DefaultModelLoader.Source], - getattr(model, "secondary_weights", ()), - ) - for source in secondary_weights: - yield from self._get_weights_iterator(source) - - def download_model(self, model_config: ModelConfig) -> None: - self._prepare_weights(model_config.model, - model_config.revision, - fall_back_to_pt=True, - allow_patterns_overrides=None) - - def load_model(self, vllm_config: VllmConfig) -> nn.Module: - device_config = vllm_config.device_config - model_config = vllm_config.model_config - target_device = torch.device(device_config.device) - with set_default_torch_dtype(model_config.dtype): - with target_device: - model = _initialize_model(vllm_config=vllm_config) - - weights_to_load = {name for name, _ in model.named_parameters()} - loaded_weights = model.load_weights( - self.get_all_weights(model_config, model)) - self.counter_after_loading_weights = time.perf_counter() - logger.info( - "Loading weights took %.2f seconds", - self.counter_after_loading_weights - - self.counter_before_loading_weights) - # We only enable strict check for non-quantized models - # that have loaded weights tracking currently. - if model_config.quantization is None and loaded_weights is not None: - weights_not_loaded = weights_to_load - loaded_weights - if weights_not_loaded: - raise ValueError( - "Following weights were not initialized from " - f"checkpoint: {weights_not_loaded}") - - _process_weights_after_loading(model, model_config, target_device) - - return model.eval() - - -class DummyModelLoader(BaseModelLoader): - """Model loader that will set model weights to random values.""" - - def __init__(self, load_config: LoadConfig): - super().__init__(load_config) - if load_config.model_loader_extra_config: - raise ValueError(f"Model loader extra config is not supported for " - f"load format {load_config.load_format}") - - def download_model(self, model_config: ModelConfig) -> None: - pass # Nothing to download - - def load_model(self, vllm_config: VllmConfig) -> nn.Module: - device_config = vllm_config.device_config - model_config = vllm_config.model_config - target_device = torch.device(device_config.device) - with set_default_torch_dtype(model_config.dtype): - with target_device: - model = _initialize_model(vllm_config=vllm_config) - # NOTE(woosuk): For accurate performance evaluation, we assign - # random values to the weights. 
- initialize_dummy_weights(model) - - _process_weights_after_loading(model, model_config, target_device) - return model.eval() - - -class TensorizerLoader(BaseModelLoader): - """Model loader using CoreWeave's tensorizer library.""" - - def __init__(self, load_config: LoadConfig): - super().__init__(load_config) - if isinstance(load_config.model_loader_extra_config, TensorizerConfig): - self.tensorizer_config = load_config.model_loader_extra_config - else: - self.tensorizer_config = TensorizerConfig( - **load_config.model_loader_extra_config) - - def _verify_config(self, model_config: ModelConfig, - parallel_config: ParallelConfig): - self.tensorizer_config.verify_with_model_config(model_config) - self.tensorizer_config.verify_with_parallel_config(parallel_config) - - def _get_weights_iterator( - self, ) -> Generator[Tuple[str, torch.Tensor], None, None]: - tensorizer_args = self.tensorizer_config._construct_tensorizer_args() - return tensorizer_weights_iterator(tensorizer_args) - - def _load_model_serialized_cpu( - self, - vllm_config: VllmConfig, - ) -> nn.Module: - """Load a serialized model with tensorizer to the CPU. - - This is only necessary when the model isn't vLLM-tensorized (see - examples/other/tensorize_vllm_model.py) This should still - be faster than default HuggingFace loading, but will be slower than - loading a vLLM-tensorized model. - """ - device_config = vllm_config.device_config - model_config = vllm_config.model_config - with set_default_torch_dtype(model_config.dtype): - with torch.device(device_config.device): - model = _initialize_model(vllm_config=vllm_config) - - model.load_weights(self._get_weights_iterator()) - return model.eval() - - def _load_model_serialized( - self, - vllm_config: VllmConfig, - ) -> nn.Module: - """Load a serialized model with tensorizer. - - Expects a vLLM-tensorized model. 
See the - examples/other/tensorize_vllm_model.py example script - for serializing vLLM models.""" - - device_config = vllm_config.device_config - model_config = vllm_config.model_config - - with set_default_torch_dtype(model_config.dtype): - with torch.device(device_config.device): - model_class = get_model_architecture(model_config)[0] - - tensorizer_config = copy.copy(self.tensorizer_config) - tensorizer_config.model_class = model_class - tensorizer_config.hf_config = model_config.hf_config - tensorizer_config.dtype = model_config.dtype - - model = load_with_tensorizer(tensorizer_config, - vllm_config=vllm_config) - return model.eval() - - def download_model(self, model_config: ModelConfig) -> None: - self.tensorizer_config.verify_with_model_config(model_config) - - with self.tensorizer_config.open_stream(): - pass - - def load_model(self, vllm_config: VllmConfig) -> nn.Module: - model_config = vllm_config.model_config - parallel_config = vllm_config.parallel_config - self._verify_config(model_config, parallel_config) - - if parallel_config.tensor_parallel_size > 1: - from vllm.distributed import get_tensor_model_parallel_rank - - self.tensorizer_config.tensorizer_uri = ( - self.tensorizer_config.tensorizer_uri % - get_tensor_model_parallel_rank()) - - if is_vllm_tensorized(self.tensorizer_config): - return self._load_model_serialized(vllm_config=vllm_config) - return self._load_model_serialized_cpu(vllm_config=vllm_config) - - @staticmethod - def save_model( - model: torch.nn.Module, - tensorizer_config: TensorizerConfig, - ) -> None: - serialize_vllm_model( - model=model, - tensorizer_config=tensorizer_config, - ) - - -class ShardedStateLoader(BaseModelLoader): - """ - Model loader that directly loads each worker's model state dict, which - enables a fast load path for large tensor-parallel models where each worker - only needs to read its own shard rather than the entire checkpoint. See - `examples/offline_inference/save_sharded_state.py` for creating a sharded - checkpoint. - """ - - DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors" - - def __init__(self, - load_config: LoadConfig, - runai_model_streamer: bool = False): - super().__init__(load_config) - - self.runai_model_streamer = runai_model_streamer - extra_config = ({} if load_config.model_loader_extra_config is None - else load_config.model_loader_extra_config.copy()) - self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN) - if extra_config: - raise ValueError(f"Unexpected extra config keys for load format " - f"{load_config.load_format}: " - f"{load_config.model_loader_extra_config.keys()}") - - @staticmethod - def _filter_subtensors( - tensors: Dict[str, torch.Tensor], ) -> Dict[str, torch.Tensor]: - """ - Filter out all tensors that share the same memory or a subset of the - memory of another tensor. 
- """ - same_storage_groups: Dict[Any, List[Tuple[str, torch.Tensor]]] = ( - collections.defaultdict(list)) - for key, tensor in tensors.items(): - if tensor.numel(): - ptr = tensor.untyped_storage().data_ptr() - same_storage_groups[tensor.device, ptr].append((key, tensor)) - - def get_end_ptr(tensor: torch.Tensor) -> int: - return tensor.view(-1)[-1].data_ptr() + tensor.element_size() - - result: Dict[str, torch.Tensor] = {} - for group in same_storage_groups.values(): - for k, t in group: - a, b = t.data_ptr(), get_end_ptr(t) - for k2, t2 in group: - if not t2.is_contiguous(): - continue - a2, b2 = t2.data_ptr(), get_end_ptr(t2) - if a < a2 or b2 < b: - continue - if a2 < a or b < b2 or not t.is_contiguous(): - break # t2 covers strictly more memory than t. - if k2 < k: - # Same tensors, keep the one with the smaller key. - break - else: - result[k] = t - return result - - def _prepare_weights(self, model_name_or_path: str, - revision: Optional[str]): - if is_s3(model_name_or_path) or os.path.isdir(model_name_or_path): - return model_name_or_path - else: - allow_patterns = ["*.safetensors"] - return download_weights_from_hf( - model_name_or_path, - self.load_config.download_dir, - allow_patterns, - revision, - ignore_patterns=self.load_config.ignore_patterns, - ) - - def download_model(self, model_config: ModelConfig) -> None: - self._prepare_weights(model_config.model, model_config.revision) - - def load_model(self, vllm_config: VllmConfig) -> nn.Module: - device_config = vllm_config.device_config - model_config = vllm_config.model_config - target_device = torch.device(device_config.device) - - from vllm.distributed import get_tensor_model_parallel_rank - - model_weights = model_config.model - if hasattr(model_config, "model_weights"): - model_weights = model_config.model_weights - local_model_path = model_weights - - with set_default_torch_dtype(model_config.dtype): - with target_device: - model = _initialize_model(vllm_config=vllm_config) - _process_weights_after_loading(model, model_config, - target_device) - rank = get_tensor_model_parallel_rank() - pattern = os.path.join( - local_model_path, - self.pattern.format(rank=rank, part="*"), - ) - - filepaths = [] - if is_s3(local_model_path): - file_pattern = f"*{self.pattern.format(rank=rank, part=' * ')}" - filepaths = s3_glob(path=local_model_path, - allow_pattern=[file_pattern]) - else: - filepaths = glob.glob(pattern) - if not filepaths: - # TODO: support un-sharded checkpoints too - raise ValueError( - f"Could not find checkpoint files '{pattern}', only " - f"pre-sharded checkpoints are currently supported!") - state_dict = self._filter_subtensors(model.state_dict()) - for key, tensor in self.iterate_over_files(filepaths): - # If loading with LoRA enabled, additional padding may - # be added to certain parameters. We only load into a - # narrowed view of the parameter data. 
- param_data = state_dict[key].data - param_shape = state_dict[key].shape - for dim, size in enumerate(tensor.shape): - if size < param_shape[dim]: - param_data = param_data.narrow(dim, 0, size) - if tensor.shape != param_shape: - logger.warning( - "loading tensor of shape %s into " - "parameter '%s' of shape %s", - tensor.shape, - key, - param_shape, - ) - param_data.copy_(tensor) - state_dict.pop(key) - if state_dict: - raise ValueError( - f"Missing keys {tuple(state_dict)} in loaded state!") - return model.eval() - - def iterate_over_files( - self, paths) -> Generator[Tuple[str, torch.Tensor], None, None]: - if self.runai_model_streamer: - yield from runai_safetensors_weights_iterator(paths, True) - else: - from safetensors.torch import safe_open - for path in paths: - with safe_open(path, framework="pt") as f: - for key in f.keys(): # noqa: SIM118 - tensor = f.get_tensor(key) - yield key, tensor - - @staticmethod - def save_model( - model: torch.nn.Module, - path: str, - pattern: Optional[str] = None, - max_size: Optional[int] = None, - ) -> None: - from safetensors.torch import save_file - - from vllm.distributed import get_tensor_model_parallel_rank - - if pattern is None: - pattern = ShardedStateLoader.DEFAULT_PATTERN - rank = get_tensor_model_parallel_rank() - part_idx = 0 - total_size = 0 - state_dict = ShardedStateLoader._filter_subtensors(model.state_dict()) - state_dict_part: Dict[str, torch.Tensor] = {} - for key, tensor in state_dict.items(): - param_size = tensor.nelement() * tensor.element_size() - if max_size is not None and total_size + param_size > max_size: - filename = pattern.format(rank=rank, part=part_idx) - save_file( - state_dict_part, - os.path.join(path, filename), - ) - part_idx += 1 - total_size = 0 - state_dict_part = {} - state_dict_part[key] = tensor - total_size += param_size - if len(state_dict_part) > 0: - filename = pattern.format(rank=rank, part=part_idx) - save_file( - state_dict_part, - os.path.join(path, filename), - ) - - -class BitsAndBytesModelLoader(BaseModelLoader): - """Model loader to load model weights with BitAndBytes quantization.""" - - possible_config_file_names = ["adapter_config.json"] - - def __init__(self, load_config: LoadConfig): - super().__init__(load_config) - - # Save the module names without sharding. - self.unsharded_weights_modules: List[str] = [] - # Save the module names that are sharded by column. - self.column_sharded_weights_modules: List[str] = [] - # Store all module names (from transformers) that support - # BNB quantization. - self.target_modules: List[str] = [] - # mapping weight names from transformers to vllm. - self.weight_mapper: Callable = lambda name: name - - def _get_weight_files( - self, - model_name_or_path: str, - allowed_patterns: List[str], - revision: Optional[str] = None, - ) -> Tuple[str, List[str], str]: - """Retrieve weight files. Download the files if necessary. 
- - Return the weight files and the file pattern.""" - is_local = os.path.isdir(model_name_or_path) - - if is_local: - for pattern in allowed_patterns: - weight_files = glob.glob( - os.path.join(model_name_or_path, pattern)) - if weight_files: - return model_name_or_path, weight_files, pattern - else: - hf_api = HfApi() - repo_files = hf_api.list_repo_files(repo_id=model_name_or_path) - for pattern in allowed_patterns: - matching_files = fnmatch.filter(repo_files, pattern) - if matching_files: - hf_folder = download_weights_from_hf( - model_name_or_path, - self.load_config.download_dir, - [pattern], - revision, - ignore_patterns=self.load_config.ignore_patterns, - ) - return hf_folder, glob.glob( - os.path.join(hf_folder, pattern)), pattern - - raise RuntimeError( - f"No model weights found in: `{model_name_or_path}`") - - def _prepare_weights(self, model_name_or_path: str, - revision: Optional[str]) -> Tuple[List[str], bool]: - """Prepare weight files for the model.""" - - allowed_patterns = ["*.safetensors", "*.bin", "*.pt"] - - hf_folder, hf_weights_files, matched_pattern = self._get_weight_files( - model_name_or_path, allowed_patterns, revision) - - use_safetensors = matched_pattern == "*.safetensors" - is_local = os.path.isdir(model_name_or_path) - index_file = SAFE_WEIGHTS_INDEX_NAME - if use_safetensors: - # For models like Mistral-7B-Instruct-v0.3 - # there are both sharded safetensors files and a consolidated - # safetensors file. Using both breaks. - # Here, we download the `model.safetensors.index.json` and filter - # any files not found in the index. - if not is_local: - download_safetensors_index_file_from_hf( - model_name_or_path, - index_file, - self.load_config.download_dir, - revision, - ) - hf_weights_files = filter_duplicate_safetensors_files( - hf_weights_files, hf_folder, index_file) - else: - hf_weights_files = filter_files_not_needed_for_inference( - hf_weights_files) - - if len(hf_weights_files) == 0: - raise RuntimeError( - f"Cannot find any model weights with `{model_name_or_path}`") - - return hf_weights_files, use_safetensors - - def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool): - if use_safetensors: - iterator = safetensors_weights_iterator( - hf_weights_files, - self.load_config.use_tqdm_on_load, - ) - else: - iterator = pt_weights_iterator( - hf_weights_files, - self.load_config.use_tqdm_on_load, - self.load_config.pt_load_map_location, - ) - for org_name, param in iterator: - # mapping weight names from transformers to vllm while preserving - # original names. - mapped_name = self.weight_mapper(org_name) - yield org_name, mapped_name, param - - def _get_quantized_weights_iterator( - self, - model_name_or_path: str, - revision: Optional[str], - pre_quant: bool, - load_8bit: bool, - ) -> Tuple[Generator[Tuple[str, torch.Tensor], None, None], Dict[str, - Any]]: - """Get an iterator to the model weights with bitsandbytes quantization, - as well as the quantization state dictionary.""" - - # only load the bitsandbytes module when needed - try: - import bitsandbytes - - if bitsandbytes.__version__ < "0.45.3": - raise ImportError("bitsandbytes version is wrong. 
Please " - "install bitsandbytes>=0.45.3.") - except ImportError as err: - raise ImportError("Please install bitsandbytes>=0.45.3 via " - "`pip install bitsandbytes>=0.45.3` to use " - "bitsandbytes quantizer.") from err - - hf_weights_files, use_safetensors = self._prepare_weights( - model_name_or_path, revision) - - quant_state_dict: Dict[str, Any] = {} - - if pre_quant: - if load_8bit: - return self._quantized_8bit_generator( - hf_weights_files, use_safetensors, - quant_state_dict), quant_state_dict - else: - return self._quantized_4bit_generator( - hf_weights_files, use_safetensors, - quant_state_dict), quant_state_dict - - return self._unquantized_generator(hf_weights_files, use_safetensors, - quant_state_dict), quant_state_dict - - def _is_8bit_weight_name(self, weight_name: str): - quantized_suffix = {".scb", ".weight_format"} - return any(weight_name.lower().endswith(suffix) - for suffix in quantized_suffix) - - def _is_4bit_weight_name(self, weight_name: str): - quantized_suffix = { - "absmax", - "quant_map", - "nested_absmax", - "nested_quant_map", - "bitsandbytes", - } - suffix = weight_name.split(".")[-1] - return any(q_suffix in suffix for q_suffix in quantized_suffix) - - def _quantized_8bit_generator(self, hf_weights_files, use_safetensors, - quant_state_dict) -> Generator: - for ( - org_weight_name, - mapped_weight_name, - weight_tensor, - ) in self._hf_weight_iter(hf_weights_files, use_safetensors): - if not mapped_weight_name.lower().endswith(".scb"): - continue - - weight_key = mapped_weight_name.lower().replace(".scb", ".weight") - quant_state_dict[weight_key] = weight_tensor - - for ( - org_weight_name, - mapped_weight_name, - weight_tensor, - ) in self._hf_weight_iter(hf_weights_files, use_safetensors): - if self._is_8bit_weight_name(mapped_weight_name): - continue - - if mapped_weight_name in quant_state_dict: - set_weight_attrs(weight_tensor, {"load_in_8bit": True}) - yield org_weight_name, weight_tensor - else: - yield org_weight_name, weight_tensor - - def _quantized_4bit_generator(self, hf_weights_files, use_safetensors, - quant_state_dict) -> Generator: - from bitsandbytes.functional import QuantState - - # First iterate over all quant state weights - weight_iterator = self._hf_weight_iter(hf_weights_files, - use_safetensors) - temp_state_dict = {} - for ( - org_weight_name, - mapped_weight_name, - weight_tensor, - ) in weight_iterator: - if not self._is_4bit_weight_name(mapped_weight_name): - continue - # bitsandbytes library requires - # weight.quant_state.bitsandbytes__* in CPU - if "quant_state.bitsandbytes" in mapped_weight_name: - temp_state_dict[mapped_weight_name] = weight_tensor.cpu().data - else: - temp_state_dict[mapped_weight_name] = weight_tensor - - # Closure to parse quant_state for each prequant weight - def _parse_quant_state(param_name: str, - temp_state_dict: Dict) -> QuantState: - quant_state = {} - for k in temp_state_dict: - if param_name + "." 
in k: - quant_state[k] = temp_state_dict[k] - - return QuantState.from_dict(quant_state, - device=current_platform.device_type) - - # Second, iterate over all pre-quantized and normal weights; - # pre-quantized weights have a quant_state - for ( - org_weight_name, - mapped_weight_name, - weight_tensor, - ) in self._hf_weight_iter(hf_weights_files, use_safetensors): - if self._is_4bit_weight_name(mapped_weight_name): - continue - - if (f"{mapped_weight_name}.quant_state.bitsandbytes__nf4" - in temp_state_dict) or ( - f"{mapped_weight_name}.quant_state.bitsandbytes__fp4" - in temp_state_dict): - quant_state = _parse_quant_state(mapped_weight_name, - temp_state_dict) - quant_state_dict[mapped_weight_name] = quant_state - yield org_weight_name, weight_tensor - else: - yield org_weight_name, weight_tensor - - def _unquantized_generator(self, hf_weights_files, use_safetensors, - quant_state_dict) -> Generator: - from bitsandbytes.functional import quantize_4bit - - tp_size = get_tensor_model_parallel_world_size() - tp_rank = get_tensor_model_parallel_rank() - - for ( - org_weight_name, - mapped_weight_name, - weight_tensor, - ) in self._hf_weight_iter(hf_weights_files, use_safetensors): - if any(target_module in mapped_weight_name - for target_module in self.target_modules - ) and mapped_weight_name.endswith(".weight"): - # Without sharding - if any( - mapped_weight_name.startswith(module) - for module in self.unsharded_weights_modules): - weight_sub_tensor = weight_tensor - # Shard by column - elif any( - mapped_weight_name.startswith(module) - for module in self.column_sharded_weights_modules): - total_size = weight_tensor.size(-1) - start_index = total_size // tp_size * tp_rank - end_index = total_size // tp_size * (tp_rank + 1) - weight_sub_tensor = weight_tensor[..., - start_index:end_index] - # Weights may have been fused on disk. In this case, we assume - # that the weight and module use the same name. - elif any( - mapped_weight_name.startswith(module) - for module in self.maybe_fused_weights_modules): - # Special case for fused weights: - # get the size of each shard weight tensor - total_shard_sizes = next( - (sizes for module, sizes in - self.maybe_fused_weights_modules.items() - if mapped_weight_name.startswith(module))) - total_size = weight_tensor.size(0) - assert total_size == sum(total_shard_sizes) - # get the start/end index of each shard weight tensor - total_start_index = list( - itertools.accumulate([0] + total_shard_sizes))[:-1] - shard_weights_index = [( - idx + size // tp_size * tp_rank, - idx + size // tp_size * (tp_rank + 1), - ) for idx, size in zip(total_start_index, - total_shard_sizes)] - # slice and reorder the weight tensor - weight_tensor = [ - weight_tensor[start_index:end_index, ...] - for start_index, end_index in shard_weights_index - ] - weight_sub_tensor = torch.cat(weight_tensor, dim=0) - # Shard by row - else: - total_size = weight_tensor.size(0) - start_index = total_size // tp_size * tp_rank - end_index = total_size // tp_size * (tp_rank + 1) - weight_sub_tensor = weight_tensor[start_index:end_index, - ...]
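A quick illustration of the per-rank slice arithmetic used in the column- and row-shard branches above (illustrative sizes; assumes, as the code does, that the dimension divides evenly by the tensor-parallel world size):

import torch

tp_size, tp_rank = 4, 1                    # hypothetical world size and rank
weight = torch.arange(16.0).reshape(8, 2)  # total_size == 8 along dim 0

total_size = weight.size(0)
start_index = total_size // tp_size * tp_rank
end_index = total_size // tp_size * (tp_rank + 1)
shard = weight[start_index:end_index, ...]  # rank 1 receives rows 2 and 3
assert shard.shape == (2, 2)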
- - # bitsandbytes requires data in GPU - if weight_sub_tensor.is_cuda: - loaded_weight = weight_sub_tensor - else: - loaded_weight = weight_sub_tensor.cuda() - - # remove the following after the issue is fixed: - # https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1342 - if loaded_weight.is_contiguous() is False: - loaded_weight = loaded_weight.contiguous() - - with set_default_torch_dtype(torch.float32): - processed_weight, quant_state = quantize_4bit( - loaded_weight, - compress_statistics=True, - quant_type="nf4", - ) - - quant_state_dict[mapped_weight_name] = quant_state - else: - processed_weight = weight_tensor - yield org_weight_name, processed_weight - - def _get_bnb_target_modules(self, model: nn.Module) -> None: - - for name, module in model.named_modules(): - if isinstance(module, (LinearBase, )): - if modules_info := self.modules_mapping.get_sub_modules(name): - # Map vllm's names to transformers's names. - rep_name, sub_modules = modules_info - for sub_name in sub_modules: - self.target_modules.append( - name.replace(rep_name, sub_name)) - # Add original module name even if the module has stacked map, - # in case model has a mixture of disk-merged and disk-splitted - # weights with same last name. - self.target_modules.append(name) - - assert (self.target_modules - ), "vllm currently does not support BNB quantization for" - f" {type(model).__name__}" - - def _load_weights(self, model_config: ModelConfig, - model: nn.Module) -> None: - if not hasattr(model, "load_weights"): - raise AttributeError( - "The required method 'load_weights' is not defined in class" - f" {type(model).__name__}.") - - if not hasattr(model, "packed_modules_mapping"): - raise AttributeError( - f"Model {type(model).__name__} does not support BitsAndBytes " - "quantization yet. No 'packed_modules_mapping' found.") - - self.modules_mapping = ParamMapping( - copy.deepcopy(model.packed_modules_mapping)) - - # For some models like Molmo, we need to use hf_to_vllm_mapper - # to ensure correct loading of weights. - if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None): - self.weight_mapper = lambda name: hf_to_vllm_mapper._map_name(name) - - # Modules whose weights might have fused on disk - # we need their output_sizes to make shard in flight correctly with TP - self.maybe_fused_weights_modules: Dict[str, List[int]] = {} - self._get_bnb_target_modules(model) - for name, module in model.named_modules(): - # Some modules like `ReplicatedLinear` should not have their weights - # sharded. The reason for implementing it this way is to avoid new - # static variable in the model implementation. - if isinstance(module, (ReplicatedLinear, )): - self.unsharded_weights_modules.append(name) - # `QKVParallelLinear` and `MergedColumnParallelLinear` might have - # fused weights on disk. We need to use the output sizes of these - # modules to shard the weights correctly. - elif isinstance(module, - (QKVParallelLinear, MergedColumnParallelLinear)): - self.maybe_fused_weights_modules[name] = module.output_sizes - # In TP, these weights are partitioned along the column - # dimension (dim=-1) - elif isinstance(module, (RowParallelLinear, )): - self.column_sharded_weights_modules.append(name) - - self.model_type = type(model).__name__ - - logger.info("Loading weights with BitsAndBytes quantization. 
" - "May take a while ...") - - quant_config = getattr(model_config.hf_config, "quantization_config", - None) - - pre_quant = False - if quant_config is not None: - quant_method = quant_config.get("quant_method") - if quant_method == "bitsandbytes": - pre_quant = True - else: - raise ValueError( - f"BitsAndBytes loader does not support {quant_method} " - "quantization") - - # The quant_states in pre_quantized models cannot work with a split - # weight tensor. So TP does not work with pre_quantized bnb models. - if pre_quant and get_tensor_model_parallel_world_size() > 1: - raise ValueError( - "Prequant BitsAndBytes models with tensor parallelism is not " - "supported. Please try with pipeline parallelism.") - - load_8bit = False - if pre_quant: - load_8bit = quant_config.get("load_in_8bit", False) - - qweight_iterator, quant_state_dict = ( - self._get_quantized_weights_iterator(model_config.model, - model_config.revision, - pre_quant, load_8bit)) - - weights_to_load = {name for name, _ in model.named_parameters()} - loaded_weights = model.load_weights(qweight_iterator) - # Some models may have weights loading tracker unimplemented. - if loaded_weights is not None: - weights_not_loaded = weights_to_load - loaded_weights - if weights_not_loaded: - raise ValueError("Following weights were not initialized from " - f"checkpoint: {weights_not_loaded}") - - torch.cuda.empty_cache() - - param_dict = dict(model.named_parameters()) - stacked_quant_state_dict: Dict[str, Dict[int, Any]] = {} - # TODO: Change this lazy import to normal import - # after the checks are updated to run on a new version - from vllm.model_executor.models.utils import is_pp_missing_parameter - - for quant_param_name in quant_state_dict: - if is_pp_missing_parameter(quant_param_name, model): - continue - - non_stacked_param_name = quant_param_name - - shard_index = 0 - for shard_name, ( - weight_name, - index, - ) in self.modules_mapping.inverse_packed_mapping.items(): - # Some models, such as MiniCPM V2.5/2.6, contain both - # module names 'kv_proj' and 'qkv_proj'. To prevent 'kv_proj' - # from being incorrectly identified as being present in - # 'vpm.encoder.layers.0.self_attn.qkv_proj.weight - shard_pos = quant_param_name.find(shard_name) - can_correct_rename = (shard_pos - > 0) and (quant_param_name[shard_pos - 1] - == ".") - # If the quant_param_name is packed, it won't occur in the - # param_dict before renaming. - new_quant_param_name = quant_param_name.replace( - shard_name, weight_name) - need_rename = (quant_param_name not in param_dict) \ - and (new_quant_param_name in param_dict) - if can_correct_rename and need_rename: - shard_index = index - quant_param_name = new_quant_param_name - break - - # Models like Clip/Siglip may skip some layers in initialization, - # causing unused quant_param_name in state_dict. 
- if quant_param_name not in param_dict: - continue - - if quant_param_name not in stacked_quant_state_dict: - stacked_quant_state_dict[quant_param_name] = {} - - stacked_quant_state_dict[quant_param_name][shard_index] = ( - quant_state_dict[non_stacked_param_name]) - - # save quant_states and offsets as the attributes of the parameters - for param_name, param in param_dict.items(): - if param_name in stacked_quant_state_dict: - quant_states = stacked_quant_state_dict[param_name] - set_weight_attrs(param, {"bnb_quant_state": quant_states}) - - pack_ratio = getattr(param, "pack_factor", -1) - if pack_ratio == -1: - raise ValueError( - f"pack_factor not set for parameter {param_name}.") - - num_elements = [0] * len(quant_states) - for seq, quant_state in quant_states.items(): - num_elements[seq] = (math.prod(quant_state.shape) // - pack_ratio) - - offsets = np.concatenate(([0], np.cumsum(num_elements))) - # Make torch infer_schema happy - offsets = torch.tensor(offsets).cpu() - set_weight_attrs(param, {"bnb_shard_offsets": offsets}) - - if load_8bit: - set_weight_attrs( - param, {"matmul_state": [None] * len(quant_states)}) - - def download_model(self, model_config: ModelConfig) -> None: - self._prepare_weights(model_config.model, model_config.revision) - - def load_model(self, vllm_config: VllmConfig) -> nn.Module: - device_config = vllm_config.device_config - model_config = vllm_config.model_config - with set_default_torch_dtype(model_config.dtype): - with torch.device(device_config.device): - model = _initialize_model(vllm_config=vllm_config) - - self._load_weights(model_config, model) - - return model.eval() - - -class GGUFModelLoader(BaseModelLoader): - """ - Model loader that can load GGUF files. This is useful for loading models - that are quantized with GGUF and saved in the GGUF format. This loader - supports loading both full models and sharded models. - """ - - def __init__(self, load_config: LoadConfig): - super().__init__(load_config) - if load_config.model_loader_extra_config: - raise ValueError(f"Model loader extra config is not supported for " - f"load format {load_config.load_format}") - - def _prepare_weights(self, model_name_or_path: str): - if os.path.isfile(model_name_or_path): - return model_name_or_path - else: - raise ValueError(f"{model_name_or_path} is not a file.") - - def _get_gguf_weights_map(self, model_config: ModelConfig): - """ - GGUF uses this naming convention for their tensors from HF checkpoint: - `blk.N.BB.weight` and `blk.N.BB.bias` - where N signifies the block number of a layer, and BB signifies the - attention/mlp layer components. - See "Standardized tensor names" in - https://github.com/ggerganov/ggml/blob/master/docs/gguf.md for details. 
- """ - config = model_config.hf_config - model_type = config.model_type - gguf_to_hf_name_map = {} - # hack: ggufs have a different name than transformers - if model_type == "cohere": - model_type = "command-r" - if model_type in ("deepseek_v3", "deepseek_v2"): - model_type = "deepseek2" - # GGUF layer map assumes that we will have a merged expert weights - # so we need to map them manually - for idx in range(config.num_hidden_layers): - gguf_to_hf_name_map[f"blk.{idx}.exp_probs_b.bias"] = \ - f"model.layers.{idx}.mlp.gate.e_score_correction_bias" - gguf_to_hf_name_map[f"blk.{idx}.ffn_down_exps.weight"] = \ - f"model.layers.{idx}.mlp.experts.0.down_proj.weight" - gguf_to_hf_name_map[f"blk.{idx}.ffn_gate_exps.weight"] = \ - f"model.layers.{idx}.mlp.experts.0.gate_proj.weight" - gguf_to_hf_name_map[f"blk.{idx}.ffn_up_exps.weight"] = \ - f"model.layers.{idx}.mlp.experts.0.up_proj.weight" - - arch = None - for key, value in gguf.MODEL_ARCH_NAMES.items(): - if value == model_type: - arch = key - break - if arch is None: - raise RuntimeError(f"Unknown gguf model_type: {model_type}") - num_layers = config.num_hidden_layers - name_map = gguf.get_tensor_name_map(arch, num_layers) - with torch.device("meta"): - dummy_model = AutoModelForCausalLM.from_config( - config, trust_remote_code=model_config.trust_remote_code) - state_dict = dummy_model.state_dict() - - for hf_name in state_dict: - name, suffix = hf_name.rsplit(".", 1) - gguf_name = name_map.get_name(name) - gguf_to_hf_name_map[f"{gguf_name}.{suffix}"] = hf_name - return gguf_to_hf_name_map - - def _get_weights_iterator( - self, model_name_or_path: str, gguf_to_hf_name_map: Dict[str, str] - ) -> Generator[Tuple[str, torch.Tensor], None, None]: - return gguf_quant_weights_iterator(model_name_or_path, - gguf_to_hf_name_map) - - def download_model(self, model_config: ModelConfig) -> None: - self._prepare_weights(model_config.model) - - def load_model(self, vllm_config: VllmConfig) -> nn.Module: - device_config = vllm_config.device_config - model_config = vllm_config.model_config - local_model_path = self._prepare_weights(model_config.model) - gguf_weights_map = self._get_gguf_weights_map(model_config) - # we can only know if tie word embeddings after mapping weights - if "lm_head.weight" in get_gguf_extra_tensor_names( - local_model_path, gguf_weights_map): - model_config.hf_config.update({"tie_word_embeddings": True}) - - target_device = torch.device(device_config.device) - with set_default_torch_dtype(model_config.dtype): - with target_device: - model = _initialize_model(vllm_config=vllm_config) - model.load_weights( - self._get_weights_iterator(local_model_path, gguf_weights_map)) - - _process_weights_after_loading(model, model_config, target_device) - return model - - -class RunaiModelStreamerLoader(BaseModelLoader): - """ - Model loader that can load safetensors - files from local FS or S3 bucket. 
- """ - - def __init__(self, load_config: LoadConfig): - super().__init__(load_config) - if load_config.model_loader_extra_config: - extra_config = load_config.model_loader_extra_config - - if ("concurrency" in extra_config - and isinstance(extra_config.get("concurrency"), int)): - os.environ["RUNAI_STREAMER_CONCURRENCY"] = str( - extra_config.get("concurrency")) - - if ("memory_limit" in extra_config - and isinstance(extra_config.get("memory_limit"), int)): - os.environ["RUNAI_STREAMER_MEMORY_LIMIT"] = str( - extra_config.get("memory_limit")) - - runai_streamer_s3_endpoint = os.getenv( - 'RUNAI_STREAMER_S3_ENDPOINT') - aws_endpoint_url = os.getenv('AWS_ENDPOINT_URL') - if (runai_streamer_s3_endpoint is None - and aws_endpoint_url is not None): - os.environ["RUNAI_STREAMER_S3_ENDPOINT"] = aws_endpoint_url - - def _prepare_weights(self, model_name_or_path: str, - revision: Optional[str]) -> List[str]: - """Prepare weights for the model. - - If the model is not local, it will be downloaded.""" - - is_s3_path = is_s3(model_name_or_path) - is_local = os.path.isdir(model_name_or_path) - safetensors_pattern = "*.safetensors" - index_file = SAFE_WEIGHTS_INDEX_NAME - - hf_folder = (model_name_or_path if - (is_local or is_s3_path) else download_weights_from_hf( - model_name_or_path, - self.load_config.download_dir, - [safetensors_pattern], - revision, - ignore_patterns=self.load_config.ignore_patterns, - )) - if is_s3_path: - hf_weights_files = s3_glob(path=hf_folder, - allow_pattern=[safetensors_pattern]) - else: - hf_weights_files = glob.glob( - os.path.join(hf_folder, safetensors_pattern)) - - if not is_local and not is_s3_path: - download_safetensors_index_file_from_hf( - model_name_or_path, index_file, self.load_config.download_dir, - revision) - - if not hf_weights_files: - raise RuntimeError( - f"Cannot find any safetensors model weights with " - f"`{model_name_or_path}`") - - return hf_weights_files - - def _get_weights_iterator( - self, model_or_path: str, - revision: str) -> Generator[Tuple[str, torch.Tensor], None, None]: - """Get an iterator for the model weights based on the load format.""" - hf_weights_files = self._prepare_weights(model_or_path, revision) - return runai_safetensors_weights_iterator( - hf_weights_files, - self.load_config.use_tqdm_on_load, - ) - - def download_model(self, model_config: ModelConfig) -> None: - """Download model if necessary""" - self._prepare_weights(model_config.model, model_config.revision) - - def load_model(self, vllm_config: VllmConfig) -> nn.Module: - """Perform streaming of the model to destination""" - device_config = vllm_config.device_config - model_config = vllm_config.model_config - - target_device = torch.device(device_config.device) - with set_default_torch_dtype(model_config.dtype): - with target_device: - model = _initialize_model(vllm_config=vllm_config) - - model_weights = model_config.model - if hasattr(model_config, "model_weights"): - model_weights = model_config.model_weights - model.load_weights( - self._get_weights_iterator(model_weights, - model_config.revision)) - - _process_weights_after_loading(model, model_config, target_device) - return model.eval() - - -def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: - """Get a model loader based on the load format.""" - if isinstance(load_config.load_format, type): - return load_config.load_format(load_config) - - if load_config.load_format == LoadFormat.DUMMY: - return DummyModelLoader(load_config) - - if load_config.load_format == LoadFormat.TENSORIZER: - return 
TensorizerLoader(load_config) - - if load_config.load_format == LoadFormat.SHARDED_STATE: - return ShardedStateLoader(load_config) - - if load_config.load_format == LoadFormat.BITSANDBYTES: - return BitsAndBytesModelLoader(load_config) - - if load_config.load_format == LoadFormat.GGUF: - return GGUFModelLoader(load_config) - - if load_config.load_format == LoadFormat.RUNAI_STREAMER: - return RunaiModelStreamerLoader(load_config) - - if load_config.load_format == LoadFormat.RUNAI_STREAMER_SHARDED: - return ShardedStateLoader(load_config, runai_model_streamer=True) - - return DefaultModelLoader(load_config) diff --git a/vllm/model_executor/model_loader/runai_streamer_loader.py b/vllm/model_executor/model_loader/runai_streamer_loader.py new file mode 100644 index 0000000000000..1fbb5ca566448 --- /dev/null +++ b/vllm/model_executor/model_loader/runai_streamer_loader.py @@ -0,0 +1,120 @@ +# SPDX-License-Identifier: Apache-2.0 +# ruff: noqa: SIM117 +import glob +import os +from typing import Generator, List, Optional, Tuple + +import torch +from torch import nn +from transformers.utils import SAFE_WEIGHTS_INDEX_NAME + +from vllm.config import LoadConfig, ModelConfig, VllmConfig +from vllm.model_executor.model_loader.base_loader import BaseModelLoader +from vllm.model_executor.model_loader.utils import ( + initialize_model, process_weights_after_loading, set_default_torch_dtype) +from vllm.model_executor.model_loader.weight_utils import ( + download_safetensors_index_file_from_hf, download_weights_from_hf, + runai_safetensors_weights_iterator) +from vllm.transformers_utils.s3_utils import glob as s3_glob +from vllm.transformers_utils.utils import is_s3 + + +class RunaiModelStreamerLoader(BaseModelLoader): + """ + Model loader that can load safetensors + files from local FS or S3 bucket. + """ + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + if load_config.model_loader_extra_config: + extra_config = load_config.model_loader_extra_config + + if ("concurrency" in extra_config + and isinstance(extra_config.get("concurrency"), int)): + os.environ["RUNAI_STREAMER_CONCURRENCY"] = str( + extra_config.get("concurrency")) + + if ("memory_limit" in extra_config + and isinstance(extra_config.get("memory_limit"), int)): + os.environ["RUNAI_STREAMER_MEMORY_LIMIT"] = str( + extra_config.get("memory_limit")) + + runai_streamer_s3_endpoint = os.getenv( + 'RUNAI_STREAMER_S3_ENDPOINT') + aws_endpoint_url = os.getenv('AWS_ENDPOINT_URL') + if (runai_streamer_s3_endpoint is None + and aws_endpoint_url is not None): + os.environ["RUNAI_STREAMER_S3_ENDPOINT"] = aws_endpoint_url + + def _prepare_weights(self, model_name_or_path: str, + revision: Optional[str]) -> List[str]: + """Prepare weights for the model. 
+ + If the model is not local, it will be downloaded.""" + + is_s3_path = is_s3(model_name_or_path) + is_local = os.path.isdir(model_name_or_path) + safetensors_pattern = "*.safetensors" + index_file = SAFE_WEIGHTS_INDEX_NAME + + hf_folder = (model_name_or_path if + (is_local or is_s3_path) else download_weights_from_hf( + model_name_or_path, + self.load_config.download_dir, + [safetensors_pattern], + revision, + ignore_patterns=self.load_config.ignore_patterns, + )) + if is_s3_path: + hf_weights_files = s3_glob(path=hf_folder, + allow_pattern=[safetensors_pattern]) + else: + hf_weights_files = glob.glob( + os.path.join(hf_folder, safetensors_pattern)) + + if not is_local and not is_s3_path: + download_safetensors_index_file_from_hf( + model_name_or_path, index_file, self.load_config.download_dir, + revision) + + if not hf_weights_files: + raise RuntimeError( + f"Cannot find any safetensors model weights with " + f"`{model_name_or_path}`") + + return hf_weights_files + + def _get_weights_iterator( + self, model_or_path: str, + revision: str) -> Generator[Tuple[str, torch.Tensor], None, None]: + """Get an iterator for the model weights based on the load format.""" + hf_weights_files = self._prepare_weights(model_or_path, revision) + return runai_safetensors_weights_iterator( + hf_weights_files, + self.load_config.use_tqdm_on_load, + ) + + def download_model(self, model_config: ModelConfig) -> None: + """Download model if necessary""" + self._prepare_weights(model_config.model, model_config.revision) + + def load_model(self, vllm_config: VllmConfig) -> nn.Module: + """Perform streaming of the model to destination""" + device_config = vllm_config.device_config + model_config = vllm_config.model_config + + target_device = torch.device(device_config.device) + with set_default_torch_dtype(model_config.dtype): + with target_device: + model = initialize_model(vllm_config=vllm_config) + + model_weights = model_config.model + if hasattr(model_config, "model_weights"): + model_weights = model_config.model_weights + model.load_weights( + self._get_weights_iterator(model_weights, + model_config.revision)) + + process_weights_after_loading(model, model_config, target_device) + return model.eval() diff --git a/vllm/model_executor/model_loader/sharded_state_loader.py b/vllm/model_executor/model_loader/sharded_state_loader.py new file mode 100644 index 0000000000000..152a3d699726f --- /dev/null +++ b/vllm/model_executor/model_loader/sharded_state_loader.py @@ -0,0 +1,210 @@ +# SPDX-License-Identifier: Apache-2.0 + +import collections +import glob +import os +from typing import Any, Dict, Generator, List, Optional, Tuple + +import torch +from torch import nn + +from vllm.config import LoadConfig, ModelConfig, VllmConfig +from vllm.logger import init_logger +from vllm.model_executor.model_loader.base_loader import BaseModelLoader +from vllm.model_executor.model_loader.utils import ( + initialize_model, process_weights_after_loading, set_default_torch_dtype) +from vllm.model_executor.model_loader.weight_utils import ( + download_weights_from_hf, runai_safetensors_weights_iterator) +from vllm.transformers_utils.s3_utils import glob as s3_glob +from vllm.transformers_utils.utils import is_s3 + +logger = init_logger(__name__) + + +class ShardedStateLoader(BaseModelLoader): + """ + Model loader that directly loads each worker's model state dict, which + enables a fast load path for large tensor-parallel models where each worker + only needs to read its own shard rather than the entire checkpoint. 
See + `examples/offline_inference/save_sharded_state.py` for creating a sharded + checkpoint. + """ + + DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors" + + def __init__(self, + load_config: LoadConfig, + runai_model_streamer: bool = False): + super().__init__(load_config) + + self.runai_model_streamer = runai_model_streamer + extra_config = ({} if load_config.model_loader_extra_config is None + else load_config.model_loader_extra_config.copy()) + self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN) + if extra_config: + raise ValueError(f"Unexpected extra config keys for load format " + f"{load_config.load_format}: " + f"{load_config.model_loader_extra_config.keys()}") + + @staticmethod + def _filter_subtensors( + tensors: Dict[str, torch.Tensor], ) -> Dict[str, torch.Tensor]: + """ + Filter out all tensors that share the same memory or a subset of the + memory of another tensor. + """ + same_storage_groups: Dict[Any, List[Tuple[str, torch.Tensor]]] = ( + collections.defaultdict(list)) + for key, tensor in tensors.items(): + if tensor.numel(): + ptr = tensor.untyped_storage().data_ptr() + same_storage_groups[tensor.device, ptr].append((key, tensor)) + + def get_end_ptr(tensor: torch.Tensor) -> int: + return tensor.view(-1)[-1].data_ptr() + tensor.element_size() + + result: Dict[str, torch.Tensor] = {} + for group in same_storage_groups.values(): + for k, t in group: + a, b = t.data_ptr(), get_end_ptr(t) + for k2, t2 in group: + if not t2.is_contiguous(): + continue + a2, b2 = t2.data_ptr(), get_end_ptr(t2) + if a < a2 or b2 < b: + continue + if a2 < a or b < b2 or not t.is_contiguous(): + break # t2 covers strictly more memory than t. + if k2 < k: + # Same tensors, keep the one with the smaller key. + break + else: + result[k] = t + return result + + def _prepare_weights(self, model_name_or_path: str, + revision: Optional[str]): + if is_s3(model_name_or_path) or os.path.isdir(model_name_or_path): + return model_name_or_path + else: + allow_patterns = ["*.safetensors"] + return download_weights_from_hf( + model_name_or_path, + self.load_config.download_dir, + allow_patterns, + revision, + ignore_patterns=self.load_config.ignore_patterns, + ) + + def download_model(self, model_config: ModelConfig) -> None: + self._prepare_weights(model_config.model, model_config.revision) + + def load_model(self, vllm_config: VllmConfig) -> nn.Module: + device_config = vllm_config.device_config + model_config = vllm_config.model_config + target_device = torch.device(device_config.device) + + from vllm.distributed import get_tensor_model_parallel_rank + + model_weights = model_config.model + if hasattr(model_config, "model_weights"): + model_weights = model_config.model_weights + local_model_path = model_weights + + with set_default_torch_dtype(model_config.dtype): + with target_device: + model = initialize_model(vllm_config=vllm_config) + process_weights_after_loading(model, model_config, + target_device) + rank = get_tensor_model_parallel_rank() + pattern = os.path.join( + local_model_path, + self.pattern.format(rank=rank, part="*"), + ) + + filepaths = [] + if is_s3(local_model_path): + file_pattern = f"*{self.pattern.format(rank=rank, part=' * ')}" + filepaths = s3_glob(path=local_model_path, + allow_pattern=[file_pattern]) + else: + filepaths = glob.glob(pattern) + if not filepaths: + # TODO: support un-sharded checkpoints too + raise ValueError( + f"Could not find checkpoint files '{pattern}', only " + f"pre-sharded checkpoints are currently supported!") + state_dict = 
self._filter_subtensors(model.state_dict()) + for key, tensor in self.iterate_over_files(filepaths): + # If loading with LoRA enabled, additional padding may + # be added to certain parameters. We only load into a + # narrowed view of the parameter data. + param_data = state_dict[key].data + param_shape = state_dict[key].shape + for dim, size in enumerate(tensor.shape): + if size < param_shape[dim]: + param_data = param_data.narrow(dim, 0, size) + if tensor.shape != param_shape: + logger.warning( + "loading tensor of shape %s into " + "parameter '%s' of shape %s", + tensor.shape, + key, + param_shape, + ) + param_data.copy_(tensor) + state_dict.pop(key) + if state_dict: + raise ValueError( + f"Missing keys {tuple(state_dict)} in loaded state!") + return model.eval() + + def iterate_over_files( + self, paths) -> Generator[Tuple[str, torch.Tensor], None, None]: + if self.runai_model_streamer: + yield from runai_safetensors_weights_iterator(paths, True) + else: + from safetensors.torch import safe_open + for path in paths: + with safe_open(path, framework="pt") as f: + for key in f.keys(): # noqa: SIM118 + tensor = f.get_tensor(key) + yield key, tensor + + @staticmethod + def save_model( + model: torch.nn.Module, + path: str, + pattern: Optional[str] = None, + max_size: Optional[int] = None, + ) -> None: + from safetensors.torch import save_file + + from vllm.distributed import get_tensor_model_parallel_rank + + if pattern is None: + pattern = ShardedStateLoader.DEFAULT_PATTERN + rank = get_tensor_model_parallel_rank() + part_idx = 0 + total_size = 0 + state_dict = ShardedStateLoader._filter_subtensors(model.state_dict()) + state_dict_part: Dict[str, torch.Tensor] = {} + for key, tensor in state_dict.items(): + param_size = tensor.nelement() * tensor.element_size() + if max_size is not None and total_size + param_size > max_size: + filename = pattern.format(rank=rank, part=part_idx) + save_file( + state_dict_part, + os.path.join(path, filename), + ) + part_idx += 1 + total_size = 0 + state_dict_part = {} + state_dict_part[key] = tensor + total_size += param_size + if len(state_dict_part) > 0: + filename = pattern.format(rank=rank, part=part_idx) + save_file( + state_dict_part, + os.path.join(path, filename), + ) diff --git a/vllm/model_executor/model_loader/tensorizer_loader.py b/vllm/model_executor/model_loader/tensorizer_loader.py new file mode 100644 index 0000000000000..7cf3940ab644a --- /dev/null +++ b/vllm/model_executor/model_loader/tensorizer_loader.py @@ -0,0 +1,119 @@ +# SPDX-License-Identifier: Apache-2.0 +# ruff: noqa: SIM117 +import copy +from typing import Generator, Tuple + +import torch +from torch import nn + +from vllm.config import LoadConfig, ModelConfig, ParallelConfig, VllmConfig +from vllm.logger import init_logger +from vllm.model_executor.model_loader.base_loader import BaseModelLoader +from vllm.model_executor.model_loader.tensorizer import ( + TensorizerConfig, is_vllm_tensorized, load_with_tensorizer, + serialize_vllm_model, tensorizer_weights_iterator) +from vllm.model_executor.model_loader.utils import (get_model_architecture, + initialize_model, + set_default_torch_dtype) + +logger = init_logger(__name__) + + +class TensorizerLoader(BaseModelLoader): + """Model loader using CoreWeave's tensorizer library.""" + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + if isinstance(load_config.model_loader_extra_config, TensorizerConfig): + self.tensorizer_config = load_config.model_loader_extra_config + else: + self.tensorizer_config = 
TensorizerConfig( + **load_config.model_loader_extra_config) + + def _verify_config(self, model_config: ModelConfig, + parallel_config: ParallelConfig): + self.tensorizer_config.verify_with_model_config(model_config) + self.tensorizer_config.verify_with_parallel_config(parallel_config) + + def _get_weights_iterator( + self, ) -> Generator[Tuple[str, torch.Tensor], None, None]: + tensorizer_args = self.tensorizer_config._construct_tensorizer_args() + return tensorizer_weights_iterator(tensorizer_args) + + def _load_model_serialized_cpu( + self, + vllm_config: VllmConfig, + ) -> nn.Module: + """Load a serialized model with tensorizer to the CPU. + + This is only necessary when the model isn't vLLM-tensorized (see + examples/other/tensorize_vllm_model.py) This should still + be faster than default HuggingFace loading, but will be slower than + loading a vLLM-tensorized model. + """ + device_config = vllm_config.device_config + model_config = vllm_config.model_config + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + model = initialize_model(vllm_config=vllm_config) + + model.load_weights(self._get_weights_iterator()) + return model.eval() + + def _load_model_serialized( + self, + vllm_config: VllmConfig, + ) -> nn.Module: + """Load a serialized model with tensorizer. + + Expects a vLLM-tensorized model. See the + examples/other/tensorize_vllm_model.py example script + for serializing vLLM models.""" + + device_config = vllm_config.device_config + model_config = vllm_config.model_config + + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + model_class = get_model_architecture(model_config)[0] + + tensorizer_config = copy.copy(self.tensorizer_config) + tensorizer_config.model_class = model_class + tensorizer_config.hf_config = model_config.hf_config + tensorizer_config.dtype = model_config.dtype + + model = load_with_tensorizer(tensorizer_config, + vllm_config=vllm_config) + return model.eval() + + def download_model(self, model_config: ModelConfig) -> None: + self.tensorizer_config.verify_with_model_config(model_config) + + with self.tensorizer_config.open_stream(): + pass + + def load_model(self, vllm_config: VllmConfig) -> nn.Module: + model_config = vllm_config.model_config + parallel_config = vllm_config.parallel_config + self._verify_config(model_config, parallel_config) + + if parallel_config.tensor_parallel_size > 1: + from vllm.distributed import get_tensor_model_parallel_rank + + self.tensorizer_config.tensorizer_uri = ( + self.tensorizer_config.tensorizer_uri % + get_tensor_model_parallel_rank()) + + if is_vllm_tensorized(self.tensorizer_config): + return self._load_model_serialized(vllm_config=vllm_config) + return self._load_model_serialized_cpu(vllm_config=vllm_config) + + @staticmethod + def save_model( + model: torch.nn.Module, + tensorizer_config: TensorizerConfig, + ) -> None: + serialize_vllm_model( + model=model, + tensorizer_config=tensorizer_config, + ) diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index 0ca6b6fd88b6a..42528cd7e4334 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -1,6 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 """Utilities for selecting and loading models.""" import contextlib +import inspect +import warnings +from contextlib import contextmanager from dataclasses import dataclass, field from typing import Dict, List, Optional, Tuple, Type @@ -9,14 +12,18 
@@ import transformers from torch import nn from transformers.dynamic_module_utils import get_class_from_dynamic_module -from vllm.config import ModelConfig, ModelImpl +from vllm.attention import Attention +from vllm.config import (ModelConfig, ModelImpl, VllmConfig, + set_current_vllm_config) from vllm.logger import init_logger +from vllm.model_executor.layers.linear import QKVCrossParallelLinear from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) + QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models.adapters import (as_classification_model, as_embedding_model, as_reward_model) +from vllm.utils import is_pin_memory_available logger = init_logger(__name__) @@ -30,6 +37,128 @@ def set_default_torch_dtype(dtype: torch.dtype): torch.set_default_dtype(old_dtype) +def initialize_model( + vllm_config: VllmConfig, + *, + prefix: str = "", + model_class: Optional[type[nn.Module]] = None, +) -> nn.Module: + """Initialize a model with the given configurations.""" + model_config = vllm_config.model_config + if model_class is None: + model_class, _ = get_model_architecture(model_config) + + if vllm_config.quant_config is not None: + configure_quant_config(vllm_config.quant_config, model_class) + + signatures = inspect.signature(model_class.__init__) + all_params = [param.name for param in signatures.parameters.values()] + if "vllm_config" in all_params and "prefix" in all_params: + # new-style model class + with set_current_vllm_config(vllm_config, check_compile=True): + return model_class(vllm_config=vllm_config, prefix=prefix) + + msg = ("vLLM model class should accept `vllm_config` and `prefix` as " + "input arguments. Possibly you have an old-style model class" + " registered out of tree and it is being used with a newer vLLM " + "version. Check " + "https://docs.vllm.ai/en/latest/design/arch_overview.html " + "for the design and update the model class accordingly.") + warnings.warn(msg, DeprecationWarning, stacklevel=2) + + logger.warning( + "Trying to guess the arguments for old-style model class %s", + model_class, + ) + # try to be compatible with old-style model class + kwargs = {} + if "prefix" in all_params: + kwargs["prefix"] = prefix + if "config" in all_params: + kwargs["config"] = model_config.hf_config + if "cache_config" in all_params: + kwargs["cache_config"] = vllm_config.cache_config + if "quant_config" in all_params: + kwargs["quant_config"] = vllm_config.quant_config + if "lora_config" in all_params: + kwargs["lora_config"] = vllm_config.lora_config + if "scheduler_config" in all_params: + kwargs["scheduler_config"] = vllm_config.scheduler_config + with set_current_vllm_config(vllm_config, check_compile=True): + return model_class(**kwargs) + + +def process_weights_after_loading(model: nn.Module, model_config: ModelConfig, + target_device: torch.device) -> None: + for _, module in model.named_modules(): + if isinstance(module, QKVCrossParallelLinear): + # NOTE(Isotr0py): special case for the cross-attention QKV layer; + # its q and kv projections are intentionally not registered as + # submodules + module.process_weights_after_loading() + continue + quant_method = getattr(module, "quant_method", None) + if isinstance(quant_method, QuantizeMethodBase): + # When quant methods need to process weights after loading + # (for repacking, quantizing, etc.), they expect parameters + # to be on the global target device.
This scope is for the + # case where cpu offloading is used, where we will move the + # parameters onto device for processing and back off after. + with device_loading_context(module, target_device): + quant_method.process_weights_after_loading(module) + + # Currently only used by MLA. + # NOTE: This intentionally happens after other modules so we can easily + # decompress the weights for MLA. + for _, module in model.named_modules(): + if isinstance(module, Attention) and \ + hasattr(module, "process_weights_after_loading"): + # TODO(lucas): see if there is a way to unify the signatures + # of process_weights_after_loading + module.process_weights_after_loading(model_config.dtype) + + +@contextmanager +def device_loading_context(module: torch.nn.Module, + target_device: torch.device): + if target_device.type == "cpu": + # If target is CPU, no need to move anything + yield module + return + + original_device_states: Dict[str, torch.device] = {} + + # Store original device states and move parameters to GPU if they're on CPU + for name, p in module.named_parameters(): + if p.device.type == "cpu": + original_device_states[name] = p.device + p.data = p.data.to(target_device) + # Parameters already on target device are not touched + + try: + yield module + + finally: + # Restore parameters to their original devices, ignoring new parameters + pin_memory = is_pin_memory_available() + for name, p in module.named_parameters(): + if name in original_device_states: + original_device: torch.device = original_device_states[name] + if original_device.type == "cpu": + # `torch.empty_like` does not support `pin_memory` argument + cpu_data = torch.empty_strided( + size=p.data.size(), + stride=p.data.stride(), + dtype=p.data.dtype, + layout=p.data.layout, + device="cpu", + pin_memory=pin_memory, + ) + cpu_data.copy_(p.data) + p.data = cpu_data + else: + p.data = p.data.to(original_device) + # New parameters or parameters already on target device are untouched + + def resolve_transformers_arch(model_config: ModelConfig, architectures: list[str]): for i, arch in enumerate(architectures): diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py index 56a7f02c4159d..741b9837398c2 100644 --- a/vllm/model_executor/models/mllama4.py +++ b/vllm/model_executor/models/mllama4.py @@ -37,7 +37,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.model_loader.loader import _initialize_model +from vllm.model_executor.model_loader.utils import initialize_model from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY @@ -670,7 +670,7 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, self.config, None, prefix=maybe_prefix(prefix, "multi_modal_projector")) - self.language_model = _initialize_model( + self.language_model = initialize_model( vllm_config=vllm_config.with_hf_config(config.text_config, ["LlamaForCausalLM"]), prefix=maybe_prefix(prefix, "language_model"), diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index bfa48099b7416..0bc5d218f8d06 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -17,7 +17,7 @@ from vllm.config import 
VllmConfig from vllm.forward_context import get_forward_context from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.model_loader.loader import DefaultModelLoader +from vllm.model_executor.model_loader import DefaultModelLoader from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 1be40ecd3e28b..0458e3ce03b5d 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -273,7 +273,7 @@ def init_vllm_registered_model( Helper function to initialize an inner model registered to vLLM, based on the arguments passed to the outer vLLM model. """ - from vllm.model_executor.model_loader.loader import _initialize_model + from vllm.model_executor.model_loader.utils import initialize_model if hf_config is None and architectures is not None: # So that the architectures field is overridden @@ -283,7 +283,7 @@ def init_vllm_registered_model( vllm_config = vllm_config.with_hf_config(hf_config, architectures=architectures) - return _initialize_model(vllm_config=vllm_config, prefix=prefix) + return initialize_model(vllm_config=vllm_config, prefix=prefix) @overload diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 0989c7e19de96..13cfcc4bbb6e7 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -7,7 +7,7 @@ from vllm.config import (CompilationLevel, VllmConfig, get_layers_from_vllm_config, set_current_vllm_config) from vllm.forward_context import set_forward_context from vllm.logger import init_logger -from vllm.model_executor.model_loader.loader import get_model_loader +from vllm.model_executor.model_loader import get_model_loader from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index da2ecfc4bccba..5352b1c5a37cf 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -318,7 +318,7 @@ class Worker(WorkerBase): pattern: Optional[str] = None, max_size: Optional[int] = None, ) -> None: - from vllm.model_executor.model_loader.loader import ShardedStateLoader + from vllm.model_executor.model_loader import ShardedStateLoader ShardedStateLoader.save_model( self.model_runner.model, path, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index e22bbcc656ff4..d96021cc688e2 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1220,7 +1220,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): pattern: Optional[str] = None, max_size: Optional[int] = None, ) -> None: - from vllm.model_executor.model_loader.loader import ShardedStateLoader + from vllm.model_executor.model_loader import ShardedStateLoader ShardedStateLoader.save_model( self.model, path, @@ -1232,7 +1232,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): self, tensorizer_config: TensorizerConfig, ) -> None: - from vllm.model_executor.model_loader.loader import TensorizerLoader + from vllm.model_executor.model_loader import TensorizerLoader TensorizerLoader.save_model( self.model, tensorizer_config=tensorizer_config,
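Closing aside: after this refactor the loaders and the dispatcher are re-exported from the package root, so call sites like the worker hooks above reduce to imports of the form below. A hedged sketch, assuming vLLM at this revision is installed:

from vllm.config import LoadConfig, LoadFormat
from vllm.model_executor.model_loader import (ShardedStateLoader,
                                              get_model_loader)

# The dispatcher returns the loader matching the configured load format.
loader = get_model_loader(LoadConfig(load_format=LoadFormat.SHARDED_STATE))
assert isinstance(loader, ShardedStateLoader)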