[Core] Support inplace model weights loading (#18745)

Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
22quinn 2025-06-02 02:38:50 -07:00 committed by GitHub
parent b9f61e1387
commit 9760fd8f6a
13 changed files with 240 additions and 288 deletions
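
At a high level, this commit splits model loading into two phases: `BaseModelLoader.load_model` initializes the model architecture, and a new standalone `BaseModelLoader.load_weights` streams weights into an already-initialized module. A minimal sketch of the intended call pattern, assuming `load_config`, `vllm_config`, and `model_config` are already constructed:

    from vllm.model_executor.model_loader import get_model_loader

    loader = get_model_loader(load_config)
    # First call: initialize the module and load weights in one go.
    model = loader.load_model(vllm_config=vllm_config,
                              model_config=model_config)
    # Later: reuse the same module and only reload the weights inplace
    # (e.g. after the checkpoint on disk has been updated).
    loader.load_weights(model, model_config=model_config)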

View File

@@ -4,7 +4,6 @@ import gc
 import os
 import pathlib
 import subprocess
-from unittest.mock import MagicMock, patch

 import pytest
 import torch
@@ -16,7 +15,6 @@ from vllm.engine.arg_utils import EngineArgs
 from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig,
                                                          TensorSerializer,
                                                          is_vllm_tensorized,
-                                                         load_with_tensorizer,
                                                          open_stream,
                                                          tensorize_vllm_model)
 # yapf: enable
@@ -61,21 +59,6 @@ def write_keyfile(keyfile_path: str):
         f.write(encryption_params.key)


-@patch('vllm.model_executor.model_loader.tensorizer.TensorizerAgent')
-def test_load_with_tensorizer(mock_agent, tensorizer_config):
-    mock_linear_method = MagicMock()
-    mock_agent_instance = mock_agent.return_value
-    mock_agent_instance.deserialize.return_value = MagicMock()
-
-    result = load_with_tensorizer(tensorizer_config,
-                                  quant_method=mock_linear_method)
-
-    mock_agent.assert_called_once_with(tensorizer_config,
-                                       quant_method=mock_linear_method)
-    mock_agent_instance.deserialize.assert_called_once()
-    assert result == mock_agent_instance.deserialize.return_value
-
-
 @pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
 def test_can_deserialize_s3(vllm_runner):
     model_ref = "EleutherAI/pythia-1.4b"

View File

@@ -94,6 +94,9 @@ def model_runner():
     return runner


+model_runner_2 = model_runner
+
+
 def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
     new_reqs = []
     num_scheduled_tokens = {}
@@ -366,3 +369,18 @@ def test_kv_cache_stride_order(monkeypatch, model_runner):
         assert all(kv.is_contiguous() for kv in model_runner.kv_caches)
     else:
         assert all(not kv.is_contiguous() for kv in model_runner.kv_caches)
+
+
+def test_load_model_weights_inplace(dist_init, model_runner, model_runner_2):
+    # In this test, model_runner loads the model + weights in one go, while
+    # model_runner_2 loads dummy weights first, then loads real weights inplace
+    model_runner.load_model()
+    original_load_format = model_runner_2.load_config.load_format
+    model_runner_2.load_config.load_format = "dummy"
+    model_runner_2.load_model()  # Initial model loading with dummy weights
+    assert str(model_runner.get_model().state_dict()) != str(
+        model_runner_2.get_model().state_dict())
+    model_runner_2.load_config.load_format = original_load_format
+    model_runner_2.load_model()  # Load real weights inplace
+    assert str(model_runner.get_model().state_dict()) == str(
+        model_runner_2.get_model().state_dict())
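
Note that the test compares `str(state_dict())` rather than the tensors themselves; the string form of a tensor embeds its (truncated) values, so dummy-initialized and checkpoint-loaded weights render differently. A stricter variant, purely illustrative and not part of this change, would compare values directly:

    def state_dicts_equal(m1: nn.Module, m2: nn.Module) -> bool:
        sd1, sd2 = m1.state_dict(), m2.state_dict()
        # Same parameter names, and every tensor exactly equal.
        return sd1.keys() == sd2.keys() and all(
            torch.equal(sd1[k], sd2[k]) for k in sd1)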

View File

@ -1,9 +1,12 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import torch
import torch.nn as nn import torch.nn as nn
from vllm.config import LoadConfig, ModelConfig, VllmConfig from vllm.config import LoadConfig, ModelConfig, VllmConfig
from vllm.model_executor.model_loader.utils import (
initialize_model, process_weights_after_loading, set_default_torch_dtype)
class BaseModelLoader(ABC): class BaseModelLoader(ABC):
@ -18,7 +21,22 @@ class BaseModelLoader(ABC):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def load_model(self, *, vllm_config: VllmConfig, def load_weights(self, model: nn.Module,
model_config: ModelConfig) -> None:
"""Load weights into a model. This standalone API allows
inplace weights loading for an already-initialized model"""
raise NotImplementedError
def load_model(self, vllm_config: VllmConfig,
model_config: ModelConfig) -> nn.Module: model_config: ModelConfig) -> nn.Module:
"""Load a model with the given configurations.""" """Load a model with the given configurations."""
raise NotImplementedError device_config = vllm_config.device_config
target_device = torch.device(device_config.device)
with set_default_torch_dtype(model_config.dtype):
with target_device:
model = initialize_model(vllm_config=vllm_config,
model_config=model_config)
# Quantization does not happen in `load_weights` but after it
self.load_weights(model, model_config)
process_weights_after_loading(model, model_config, target_device)
return model.eval()
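
With `load_model` now a template method on the base class, a concrete loader only needs to provide `download_model` and `load_weights`; device placement, dtype defaults, quantization post-processing, and `eval()` are handled once here. A hypothetical subclass, just to illustrate the contract (`iter_local_tensors` is an assumed helper, not part of this change):

    class LocalFolderLoader(BaseModelLoader):

        def download_model(self, model_config: ModelConfig) -> None:
            pass  # weights are assumed to already be on local disk

        def load_weights(self, model: nn.Module,
                         model_config: ModelConfig) -> None:
            # Stream (name, tensor) pairs into the existing module.
            model.load_weights(iter_local_tensors(model_config.model))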

View File

@@ -14,7 +14,7 @@ from huggingface_hub import HfApi
 from torch import nn
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME

-from vllm.config import LoadConfig, ModelConfig, VllmConfig
+from vllm.config import LoadConfig, ModelConfig
 from vllm.distributed import (get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
 # yapf: enable
@@ -28,7 +28,6 @@ from vllm.model_executor.layers.linear import (LinearBase,
                                                RowParallelLinear)
 from vllm.model_executor.model_loader.base_loader import BaseModelLoader
 from vllm.model_executor.model_loader.utils import (ParamMapping,
-                                                    initialize_model,
                                                     set_default_torch_dtype)
 from vllm.model_executor.model_loader.weight_utils import (
     download_safetensors_index_file_from_hf, download_weights_from_hf,
@@ -408,8 +407,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         ), "vllm currently does not support BNB quantization for"
         f" {type(model).__name__}"

-    def _load_weights(self, model_config: ModelConfig,
-                      model: nn.Module) -> None:
+    def load_weights(self, model: nn.Module,
+                     model_config: ModelConfig) -> None:
         if not hasattr(model, "load_weights"):
             raise AttributeError(
                 "The required method 'load_weights' is not defined in class"
@@ -568,15 +566,3 @@ class BitsAndBytesModelLoader(BaseModelLoader):
     def download_model(self, model_config: ModelConfig) -> None:
         self._prepare_weights(model_config.model, model_config.revision)
-
-    def load_model(self, vllm_config: VllmConfig,
-                   model_config: ModelConfig) -> nn.Module:
-        device_config = vllm_config.device_config
-
-        with set_default_torch_dtype(model_config.dtype):
-            with torch.device(device_config.device):
-
-                model = initialize_model(vllm_config=vllm_config)
-
-                self._load_weights(model_config, model)
-
-        return model.eval()

View File

@@ -12,11 +12,9 @@ from torch import nn
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME

 from vllm import envs
-from vllm.config import LoadConfig, LoadFormat, ModelConfig, VllmConfig
+from vllm.config import LoadConfig, LoadFormat, ModelConfig
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader.base_loader import BaseModelLoader
-from vllm.model_executor.model_loader.utils import (
-    initialize_model, process_weights_after_loading, set_default_torch_dtype)
 from vllm.model_executor.model_loader.weight_utils import (
     download_safetensors_index_file_from_hf, download_weights_from_hf,
     fastsafetensors_weights_iterator, filter_duplicate_safetensors_files,
@@ -264,15 +262,8 @@ class DefaultModelLoader(BaseModelLoader):
             fall_back_to_pt=True,
             allow_patterns_overrides=None)

-    def load_model(self, vllm_config: VllmConfig,
-                   model_config: ModelConfig) -> nn.Module:
-        device_config = vllm_config.device_config
-        target_device = torch.device(device_config.device)
-        with set_default_torch_dtype(model_config.dtype):
-            with target_device:
-                model = initialize_model(vllm_config=vllm_config,
-                                         model_config=model_config)
-
+    def load_weights(self, model: nn.Module,
+                     model_config: ModelConfig) -> None:
         weights_to_load = {name for name, _ in model.named_parameters()}
         loaded_weights = model.load_weights(
             self.get_all_weights(model_config, model))
@@ -286,10 +277,5 @@ class DefaultModelLoader(BaseModelLoader):
         if model_config.quantization is None and loaded_weights is not None:
             weights_not_loaded = weights_to_load - loaded_weights
             if weights_not_loaded:
-                raise ValueError(
-                    "Following weights were not initialized from "
-                    f"checkpoint: {weights_not_loaded}")
-
-        process_weights_after_loading(model, model_config, target_device)
-
-        return model.eval()
+                raise ValueError("Following weights were not initialized from "
+                                 f"checkpoint: {weights_not_loaded}")

View File

@ -1,11 +1,8 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import torch
import torch.nn as nn import torch.nn as nn
from vllm.config import LoadConfig, ModelConfig, VllmConfig from vllm.config import LoadConfig, ModelConfig
from vllm.model_executor.model_loader.base_loader import BaseModelLoader from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.utils import (
initialize_model, process_weights_after_loading, set_default_torch_dtype)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
initialize_dummy_weights) initialize_dummy_weights)
@ -22,16 +19,8 @@ class DummyModelLoader(BaseModelLoader):
def download_model(self, model_config: ModelConfig) -> None: def download_model(self, model_config: ModelConfig) -> None:
pass # Nothing to download pass # Nothing to download
def load_model(self, vllm_config: VllmConfig, def load_weights(self, model: nn.Module,
model_config: ModelConfig) -> nn.Module: model_config: ModelConfig) -> None:
device_config = vllm_config.device_config
target_device = torch.device(device_config.device)
with set_default_torch_dtype(model_config.dtype):
with target_device:
model = initialize_model(vllm_config=vllm_config)
# NOTE(woosuk): For accurate performance evaluation, we assign # NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights. # random values to the weights.
initialize_dummy_weights(model) initialize_dummy_weights(model)
process_weights_after_loading(model, model_config, target_device)
return model.eval()

View File

@@ -92,6 +92,13 @@ class GGUFModelLoader(BaseModelLoader):
     def download_model(self, model_config: ModelConfig) -> None:
         self._prepare_weights(model_config.model)

+    def load_weights(self, model: nn.Module,
+                     model_config: ModelConfig) -> None:
+        local_model_path = self._prepare_weights(model_config.model)
+        gguf_weights_map = self._get_gguf_weights_map(model_config)
+        model.load_weights(
+            self._get_weights_iterator(local_model_path, gguf_weights_map))
+
     def load_model(self, vllm_config: VllmConfig,
                    model_config: ModelConfig) -> nn.Module:
         device_config = vllm_config.device_config
@@ -106,8 +113,7 @@ class GGUFModelLoader(BaseModelLoader):
         with set_default_torch_dtype(model_config.dtype):
             with target_device:
                 model = initialize_model(vllm_config=vllm_config)
-            model.load_weights(
-                self._get_weights_iterator(local_model_path, gguf_weights_map))
+            self.load_weights(model, model_config)
             process_weights_after_loading(model, model_config, target_device)
         return model

View File

@@ -9,10 +9,8 @@ import torch
 from torch import nn
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME

-from vllm.config import LoadConfig, ModelConfig, VllmConfig
+from vllm.config import LoadConfig, ModelConfig
 from vllm.model_executor.model_loader.base_loader import BaseModelLoader
-from vllm.model_executor.model_loader.utils import (
-    initialize_model, process_weights_after_loading, set_default_torch_dtype)
 from vllm.model_executor.model_loader.weight_utils import (
     download_safetensors_index_file_from_hf, download_weights_from_hf,
     runai_safetensors_weights_iterator)
@@ -100,21 +98,11 @@ class RunaiModelStreamerLoader(BaseModelLoader):
         """Download model if necessary"""
         self._prepare_weights(model_config.model, model_config.revision)

-    def load_model(self, vllm_config: VllmConfig,
-                   model_config: ModelConfig) -> nn.Module:
-        """Perform streaming of the model to destination"""
-        device_config = vllm_config.device_config
-        target_device = torch.device(device_config.device)
-        with set_default_torch_dtype(model_config.dtype):
-            with target_device:
-                model = initialize_model(vllm_config=vllm_config)
-
-            model_weights = model_config.model
-            if hasattr(model_config, "model_weights"):
-                model_weights = model_config.model_weights
-            model.load_weights(
-                self._get_weights_iterator(model_weights,
-                                           model_config.revision))
-
-            process_weights_after_loading(model, model_config, target_device)
-        return model.eval()
+    def load_weights(self, model: nn.Module,
+                     model_config: ModelConfig) -> None:
+        """Load weights into a model."""
+        model_weights = model_config.model
+        if hasattr(model_config, "model_weights"):
+            model_weights = model_config.model_weights
+        model.load_weights(
+            self._get_weights_iterator(model_weights, model_config.revision))

View File

@@ -9,11 +9,9 @@ from typing import Any, Optional
 import torch
 from torch import nn

-from vllm.config import LoadConfig, ModelConfig, VllmConfig
+from vllm.config import LoadConfig, ModelConfig
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader.base_loader import BaseModelLoader
-from vllm.model_executor.model_loader.utils import (
-    initialize_model, process_weights_after_loading, set_default_torch_dtype)
 from vllm.model_executor.model_loader.weight_utils import (
     download_weights_from_hf, runai_safetensors_weights_iterator)
 from vllm.transformers_utils.s3_utils import glob as s3_glob
@@ -100,11 +98,8 @@ class ShardedStateLoader(BaseModelLoader):
     def download_model(self, model_config: ModelConfig) -> None:
         self._prepare_weights(model_config.model, model_config.revision)

-    def load_model(self, vllm_config: VllmConfig,
-                   model_config: ModelConfig) -> nn.Module:
-        device_config = vllm_config.device_config
-        target_device = torch.device(device_config.device)
-
+    def load_weights(self, model: nn.Module,
+                     model_config: ModelConfig) -> None:
         from vllm.distributed import get_tensor_model_parallel_rank

         model_weights = model_config.model
@@ -112,11 +107,6 @@
             model_weights = model_config.model_weights
         local_model_path = model_weights

-        with set_default_torch_dtype(model_config.dtype):
-            with target_device:
-                model = initialize_model(vllm_config=vllm_config)
-                process_weights_after_loading(model, model_config,
-                                              target_device)
         rank = get_tensor_model_parallel_rank()
         pattern = os.path.join(
             local_model_path,
@@ -158,7 +148,6 @@
         if state_dict:
             raise ValueError(
                 f"Missing keys {tuple(state_dict)} in loaded state!")
-        return model.eval()

     def iterate_over_files(
             self, paths) -> Generator[tuple[str, torch.Tensor], None, None]:

View File

@@ -21,7 +21,8 @@ from torch.utils._python_dispatch import TorchDispatchMode
 from transformers import PretrainedConfig

 import vllm.envs as envs
-from vllm.config import ModelConfig, ParallelConfig, set_current_vllm_config
+from vllm.config import (ModelConfig, ParallelConfig, VllmConfig,
+                         set_current_vllm_config)
 from vllm.engine.arg_utils import EngineArgs
 from vllm.logger import init_logger
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -208,12 +209,6 @@ class TensorizerConfig:
             **tensorizer_args.stream_params)

-def load_with_tensorizer(tensorizer_config: TensorizerConfig,
-                         **extra_kwargs) -> nn.Module:
-    tensorizer = TensorizerAgent(tensorizer_config, **extra_kwargs)
-    return tensorizer.deserialize()
-
 @dataclass
 class TensorizerArgs:
     tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, BinaryIO, str,
@@ -366,40 +361,22 @@
     return tensorizer_args

-class TensorizerAgent:
-    """
-    A class for performing tensorizer deserializations specifically for
-    vLLM models using plaid_mode. Uses TensorizerArgs to configure the
-    behavior of the TensorDeserializer when loading tensors from a serialized
-    model. For deserializations of HuggingFace models, TensorDeserializer is
-    instead used as an iterator directly in the func hf_model_weights_iterator
-    in vllm/model_executor/model_loader/weight_utils.py
-    """
-
-    def __init__(self, tensorizer_config: TensorizerConfig, vllm_config):
-        self.tensorizer_config = tensorizer_config
-        self.tensorizer_args = (
-            self.tensorizer_config._construct_tensorizer_args())
-        self.vllm_config = vllm_config
-        self.model = self._init_model()
-
-    def _init_model(self):
-        assert self.tensorizer_config.hf_config is not None
-        model_args = self.tensorizer_config.hf_config
-        model_args.torch_dtype = self.tensorizer_config.dtype
-        assert self.tensorizer_config.model_class is not None
-        # TODO: Do we need to consider old-style model class?
-        with meta_tensor_mode(), set_current_vllm_config(self.vllm_config,
-                                                         check_compile=True):
-            return self.tensorizer_config.model_class(
-                vllm_config=self.vllm_config)
-
-    def _resize_lora_embeddings(self):
-        """Modify LoRA embedding layers to use bigger tensors
-        to allow for adapter added tokens."""
-        for child in self.model.modules():
-            if (isinstance(child, VocabParallelEmbedding)
-                    and child.weight.shape[0]
-                    < child.num_embeddings_per_partition):
+def _check_tensors_on_meta_device(model: nn.Module) -> None:
+    for tensor in model.state_dict().values():
+        if tensor.device.type == 'meta':
+            raise ValueError(
+                "The serialized model contains tensors on the meta device,"
+                " indicating that some tensors were not loaded properly."
+                " Please check that the parameters of the model being"
+                " specified match that of the serialized model, such as"
+                " its quantization.")
+
+def _resize_lora_embeddings(model: nn.Module):
+    """Modify LoRA embedding layers to use bigger tensors
+    to allow for adapter added tokens."""
+    for child in model.modules():
+        if (isinstance(child, VocabParallelEmbedding) and child.weight.shape[0]
+                < child.num_embeddings_per_partition):
             new_weight = torch.empty(child.num_embeddings_per_partition,
                                      child.embedding_dim,
@@ -409,41 +386,32 @@
             new_weight[child.weight.shape[0]:].fill_(0)
             child.weight.data = new_weight

-    def _check_tensors_on_meta_device(self):
-        for tensor in self.model.state_dict().values():
-            if tensor.device.type == 'meta':
-                raise ValueError(
-                    "The serialized model contains tensors on the meta device,"
-                    " indicating that some tensors were not loaded properly."
-                    " Please check that the parameters of the model being"
-                    " specified match that of the serialized model, such as"
-                    " its quantization.")
-
-    def deserialize(self):
-        """
-        Deserialize the model using the TensorDeserializer. This method is
-        specifically for vLLM models using tensorizer's plaid_mode.
-
-        The deserializer makes use of tensorizer_args.stream_params
-        to configure the behavior of the stream when loading tensors from a
-        serialized model. The deserializer_params are used to configure the
-        behavior of the TensorDeserializer when loading tensors themselves.
-        Documentation on these params can be found in TensorizerArgs
-
-        Returns:
-            nn.Module: The deserialized model.
-        """
-        before_mem = get_mem_usage()
-        start = time.perf_counter()
-        with _read_stream(
-                self.tensorizer_config.tensorizer_uri,
-                **self.tensorizer_args.stream_params
-        ) as stream, TensorDeserializer(
-                stream,
-                dtype=self.tensorizer_config.dtype,
-                device=f'cuda:{torch.cuda.current_device()}',
-                **self.tensorizer_args.deserializer_params) as deserializer:
-            deserializer.load_into_module(self.model)
-            end = time.perf_counter()
+def init_tensorizer_model(tensorizer_config: TensorizerConfig,
+                          vllm_config: VllmConfig) -> nn.Module:
+    assert tensorizer_config.hf_config is not None
+    model_args = tensorizer_config.hf_config
+    model_args.torch_dtype = tensorizer_config.dtype
+    assert tensorizer_config.model_class is not None
+    # TODO: Do we need to consider old-style model class?
+    with meta_tensor_mode(), set_current_vllm_config(vllm_config,
+                                                     check_compile=True):
+        return tensorizer_config.model_class(vllm_config=vllm_config)
+
+def deserialize_tensorizer_model(model: nn.Module,
+                                 tensorizer_config: TensorizerConfig) -> None:
+    tensorizer_args = tensorizer_config._construct_tensorizer_args()
+    before_mem = get_mem_usage()
+    start = time.perf_counter()
+    with _read_stream(
+            tensorizer_config.tensorizer_uri,
+            **tensorizer_args.stream_params) as stream, TensorDeserializer(
+                stream,
+                dtype=tensorizer_config.dtype,
+                device=f'cuda:{torch.cuda.current_device()}',
+                **tensorizer_args.deserializer_params) as deserializer:
+        deserializer.load_into_module(model)
+        end = time.perf_counter()

     total_bytes_str = convert_bytes(deserializer.total_tensor_bytes)
@@ -456,10 +424,9 @@
     logger.info("Memory usage before: %s", before_mem)
     logger.info("Memory usage after: %s", after_mem)

-        self._check_tensors_on_meta_device()
-        self._resize_lora_embeddings()
-        del self.model.vllm_tensorized_marker
-        return self.model.eval()
+    _check_tensors_on_meta_device(model)
+    _resize_lora_embeddings(model)
+    del model.vllm_tensorized_marker

 def tensorizer_weights_iterator(
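
The net effect of this refactor: the stateful `TensorizerAgent` class is gone, replaced by two free functions that mirror the init/load split above. `init_tensorizer_model` builds the module under `meta_tensor_mode`, and `deserialize_tensorizer_model` streams tensors into it in place, so the old `load_with_tensorizer(...)` call decomposes roughly as:

    model = init_tensorizer_model(tensorizer_config=tensorizer_config,
                                  vllm_config=vllm_config)
    deserialize_tensorizer_model(model, tensorizer_config)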

View File

@@ -11,8 +11,8 @@ from vllm.config import LoadConfig, ModelConfig, ParallelConfig, VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader.base_loader import BaseModelLoader
 from vllm.model_executor.model_loader.tensorizer import (
-    TensorizerConfig, is_vllm_tensorized, load_with_tensorizer,
-    serialize_vllm_model, tensorizer_weights_iterator)
+    TensorizerConfig, deserialize_tensorizer_model, init_tensorizer_model,
+    is_vllm_tensorized, serialize_vllm_model, tensorizer_weights_iterator)
 from vllm.model_executor.model_loader.utils import (get_model_architecture,
                                                     initialize_model,
                                                     set_default_torch_dtype)
@@ -61,38 +61,34 @@ class TensorizerLoader(BaseModelLoader):
                 model.load_weights(self._get_weights_iterator())
         return model.eval()

-    def _load_model_serialized(
-        self,
-        vllm_config: VllmConfig,
-    ) -> nn.Module:
-        """Load a serialized model with tensorizer.
-
-        Expects a vLLM-tensorized model. See the
-        examples/others/tensorize_vllm_model.py example script
-        for serializing vLLM models."""
-        device_config = vllm_config.device_config
-        model_config = vllm_config.model_config
-
-        with set_default_torch_dtype(model_config.dtype):
-            with torch.device(device_config.device):
-                model_class = get_model_architecture(model_config)[0]
-
-                tensorizer_config = copy.copy(self.tensorizer_config)
-                tensorizer_config.model_class = model_class
-                tensorizer_config.hf_config = model_config.hf_config
-                tensorizer_config.dtype = model_config.dtype
-
-                model = load_with_tensorizer(tensorizer_config,
-                                             vllm_config=vllm_config)
-        return model.eval()
-
     def download_model(self, model_config: ModelConfig) -> None:
         self.tensorizer_config.verify_with_model_config(model_config)

         with self.tensorizer_config.open_stream():
             pass

+    def _patch_tensorizer_config(
+            self, model_config: ModelConfig) -> TensorizerConfig:
+        model_class = get_model_architecture(model_config)[0]
+        tensorizer_config = copy.copy(self.tensorizer_config)
+        tensorizer_config.model_class = model_class
+        tensorizer_config.hf_config = model_config.hf_config
+        tensorizer_config.dtype = model_config.dtype
+        return tensorizer_config
+
+    def load_weights(self, model: nn.Module,
+                     model_config: ModelConfig) -> None:
+        """Load serialized model weights with tensorizer.
+
+        Expects a vLLM-tensorized model. See the
+        examples/others/tensorize_vllm_model.py example script
+        for serializing vLLM models."""
+        if is_vllm_tensorized(self.tensorizer_config):
+            tensorizer_config = self._patch_tensorizer_config(model_config)
+            deserialize_tensorizer_model(model, tensorizer_config)
+        else:
+            model.load_weights(self._get_weights_iterator())
+
     def load_model(self, vllm_config: VllmConfig,
                    model_config: ModelConfig) -> nn.Module:
         parallel_config = vllm_config.parallel_config
@@ -106,7 +102,11 @@ class TensorizerLoader(BaseModelLoader):
                 get_tensor_model_parallel_rank())

         if is_vllm_tensorized(self.tensorizer_config):
-            return self._load_model_serialized(vllm_config=vllm_config)
+            tensorizer_config = self._patch_tensorizer_config(model_config)
+            model = init_tensorizer_model(tensorizer_config=tensorizer_config,
+                                          vllm_config=vllm_config)
+            self.load_weights(model, model_config)
+            return model
         return self._load_model_serialized_cpu(vllm_config=vllm_config)

     @staticmethod

View File

@@ -28,7 +28,7 @@ from vllm.forward_context import (DPMetadata, get_forward_context,
                                   set_forward_context)
 from vllm.logger import init_logger
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
-from vllm.model_executor.model_loader import TensorizerLoader, get_model
+from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
 from vllm.multimodal.utils import group_mm_inputs_by_modality
@@ -1564,7 +1564,18 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         logger.info("Starting to load model %s...", self.model_config.model)
         with DeviceMemoryProfiler() as m:  # noqa: SIM117
             time_before_load = time.perf_counter()
-            self.model = get_model(vllm_config=self.vllm_config)
+            model_loader = get_model_loader(self.load_config)
+            if not hasattr(self, "model"):
+                logger.info("Loading model from scratch...")
+                self.model = model_loader.load_model(
+                    vllm_config=self.vllm_config,
+                    model_config=self.model_config)
+            else:
+                logger.info(
+                    "Model was already initialized. Loading weights inplace..."
+                )
+                model_loader.load_weights(self.model,
+                                          model_config=self.model_config)
             if self.lora_config:
                 self.model = self.load_lora_model(self.model,
                                                   self.model_config,
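
The runner-side contract follows directly: the first `load_model()` call creates `self.model` from scratch, and any later call only refreshes its weights. A sketch of the reload pattern this enables, assuming a constructed `model_runner` whose checkpoint was updated out of band:

    model_runner.load_model()  # initializes self.model and loads weights
    # ... checkpoint on disk is replaced ...
    model_runner.load_model()  # same module object; weights reloaded inplace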

View File

@@ -21,7 +21,7 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
 from vllm.lora.layers import BaseLayerWithLoRA
-from vllm.model_executor.model_loader import get_model
+from vllm.model_executor.model_loader import get_model_loader
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs,
                                     PlaceholderRange)
@@ -171,7 +171,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
         self.encoder_cache_size = encoder_cache_size

         # Lazy initialization
-        # self.model: nn.Module  # Set after load_model
+        self.model: nn.Module  # Set after load_model
         self.kv_caches: list[torch.Tensor] = []
         # req_id -> (input_id -> encoder_output)
         self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}
@@ -419,7 +419,6 @@ class TPUModelRunner(LoRAModelRunnerMixin):
         return len(unscheduled_req_ids) > 0 or len(req_ids_to_add) > 0

     def get_model(self) -> nn.Module:
-        assert self.model is not None
         return self.model

     def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
@@ -936,7 +935,18 @@ class TPUModelRunner(LoRAModelRunnerMixin):
                     "vllm.model_executor.layers.vocab_parallel_embedding."
                     "get_tensor_model_parallel_rank",
                     return_value=xm_tp_rank):
-                model = get_model(vllm_config=self.vllm_config)
+                model_loader = get_model_loader(self.load_config)
+                if not hasattr(self, "model"):
+                    logger.info("Loading model from scratch...")
+                    model = model_loader.load_model(
+                        vllm_config=self.vllm_config,
+                        model_config=self.model_config)
+                else:
+                    logger.info(
+                        "Model was already initialized. Loading weights inplace..."
+                    )
+                    model_loader.load_weights(self.model,
+                                              model_config=self.model_config)
             if self.lora_config is not None:
                 model = self.load_lora_model(model, self.model_config,
                                              self.scheduler_config,
@@ -947,6 +957,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
         # loading.
         xm.mark_step()
         xm.wait_device_ops()
-        self.model = model
+        if not hasattr(self, "model"):
+            self.model = model
         self.sampler = TPUSampler()