# SPDX-License-Identifier: Apache-2.0
# ruff: noqa: SIM117
import copy
from collections.abc import Generator

import torch
from torch import nn

from vllm.config import LoadConfig, ModelConfig, ParallelConfig, VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.tensorizer import (
    TensorizerConfig, is_vllm_tensorized, load_with_tensorizer,
    serialize_vllm_model, tensorizer_weights_iterator)
from vllm.model_executor.model_loader.utils import (get_model_architecture,
                                                    initialize_model,
                                                    set_default_torch_dtype)

logger = init_logger(__name__)


class TensorizerLoader(BaseModelLoader):
    """Model loader using CoreWeave's tensorizer library."""

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        # Accept either a ready-made TensorizerConfig or a plain dict of
        # kwargs from which one can be constructed.
        if isinstance(load_config.model_loader_extra_config, TensorizerConfig):
            self.tensorizer_config = load_config.model_loader_extra_config
        else:
            self.tensorizer_config = TensorizerConfig(
                **load_config.model_loader_extra_config)

    def _verify_config(self, model_config: ModelConfig,
                       parallel_config: ParallelConfig):
        self.tensorizer_config.verify_with_model_config(model_config)
        self.tensorizer_config.verify_with_parallel_config(parallel_config)

    def _get_weights_iterator(
            self) -> Generator[tuple[str, torch.Tensor], None, None]:
        tensorizer_args = self.tensorizer_config._construct_tensorizer_args()
        return tensorizer_weights_iterator(tensorizer_args)

    def _load_model_serialized_cpu(
        self,
        vllm_config: VllmConfig,
    ) -> nn.Module:
        """Load a serialized model with tensorizer to the CPU.

        This is only necessary when the model isn't vLLM-tensorized (see
        examples/other/tensorize_vllm_model.py). This should still be
        faster than default HuggingFace loading, but will be slower than
        loading a vLLM-tensorized model.
        """
        device_config = vllm_config.device_config
        model_config = vllm_config.model_config
        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = initialize_model(vllm_config=vllm_config)

            model.load_weights(self._get_weights_iterator())
        return model.eval()

    def _load_model_serialized(
        self,
        vllm_config: VllmConfig,
    ) -> nn.Module:
        """Load a serialized model with tensorizer.

        Expects a vLLM-tensorized model. See the
        examples/other/tensorize_vllm_model.py example script
        for serializing vLLM models.
        """
        device_config = vllm_config.device_config
        model_config = vllm_config.model_config
        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model_class = get_model_architecture(model_config)[0]

                tensorizer_config = copy.copy(self.tensorizer_config)
                tensorizer_config.model_class = model_class
                tensorizer_config.hf_config = model_config.hf_config
                tensorizer_config.dtype = model_config.dtype

                model = load_with_tensorizer(tensorizer_config,
                                             vllm_config=vllm_config)
        return model.eval()

    def download_model(self, model_config: ModelConfig) -> None:
        self.tensorizer_config.verify_with_model_config(model_config)

        # Opening the stream is sufficient to fetch the serialized model;
        # no weights need to be deserialized here.
        with self.tensorizer_config.open_stream():
            pass

    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
        model_config = vllm_config.model_config
        parallel_config = vllm_config.parallel_config
        self._verify_config(model_config, parallel_config)

        if parallel_config.tensor_parallel_size > 1:
            from vllm.distributed import get_tensor_model_parallel_rank
            # The URI is expected to contain a printf-style format
            # specifier; substitute this worker's tensor-parallel rank
            # so each rank loads its own shard.
            self.tensorizer_config.tensorizer_uri = (
                self.tensorizer_config.tensorizer_uri %
                get_tensor_model_parallel_rank())

        if is_vllm_tensorized(self.tensorizer_config):
            return self._load_model_serialized(vllm_config=vllm_config)
        return self._load_model_serialized_cpu(vllm_config=vllm_config)

    @staticmethod
    def save_model(
        model: torch.nn.Module,
        tensorizer_config: TensorizerConfig,
    ) -> None:
        serialize_vllm_model(
            model=model,
            tensorizer_config=tensorizer_config,
        )
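

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; not executed on import). Shows how a
# TensorizerLoader is typically constructed from a LoadConfig whose extra
# config points at a serialized model. The S3 URI below is a hypothetical
# placeholder. Note that with tensor_parallel_size > 1, load_model()
# "%"-formats the worker's tensor-parallel rank into the URI, so the URI
# must then contain a printf-style specifier such as "%03d".
#
#     from vllm.config import LoadConfig
#
#     load_config = LoadConfig(
#         load_format="tensorizer",
#         model_loader_extra_config={
#             "tensorizer_uri": "s3://my-bucket/vllm/model-rank-%03d.tensors",
#         },
#     )
#     loader = TensorizerLoader(load_config)
#     # vllm_config is assumed to be an already-built VllmConfig:
#     model = loader.load_model(vllm_config=vllm_config)
# ---------------------------------------------------------------------------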