[Frontend] [Core] Integrate Tensorizer into S3 loading machinery, allow passing arbitrary arguments during save/load (#19619)

Signed-off-by: Sanger Steel <sangersteel@gmail.com>
Co-authored-by: Eta <esyra@coreweave.com>
Sanger Steel 2025-07-08 01:47:43 -04:00 committed by GitHub
parent e34d130c16
commit 72d14d0eed
18 changed files with 814 additions and 196 deletions

View File

@ -4,6 +4,7 @@
import argparse
import dataclasses
import json
import logging
import os
import uuid
@ -15,9 +16,13 @@ from vllm.model_executor.model_loader.tensorizer import (
TensorizerConfig,
tensorize_lora_adapter,
tensorize_vllm_model,
tensorizer_kwargs_arg,
)
from vllm.utils import FlexibleArgumentParser
logger = logging.getLogger()
# yapf conflicts with isort for this docstring
# yapf: disable
"""
@ -119,7 +124,7 @@ vllm serve <model_path> \
"""
def parse_args():
def get_parser():
parser = FlexibleArgumentParser(
description="An example script that can be used to serialize and "
"deserialize vLLM models. These models "
@ -135,13 +140,13 @@ def parse_args():
required=False,
help="Path to a LoRA adapter to "
"serialize along with model tensors. This can then be deserialized "
"along with the model by passing a tensorizer_config kwarg to "
"LoRARequest with type TensorizerConfig. See the docstring for this "
"for a usage example."
"along with the model by instantiating a TensorizerConfig object, "
"creating a dict from it with TensorizerConfig.to_serializable(), "
"and passing it to LoRARequest's initializer with the kwarg "
"tensorizer_config_dict."
)
subparsers = parser.add_subparsers(dest='command')
subparsers = parser.add_subparsers(dest='command', required=True)
serialize_parser = subparsers.add_parser(
'serialize', help="Serialize a model to `--serialized-directory`")
@ -171,6 +176,14 @@ def parse_args():
"where `suffix` is given by `--suffix` or a random UUID if not "
"provided.")
serialize_parser.add_argument(
"--serialization-kwargs",
type=tensorizer_kwargs_arg,
required=False,
help=("A JSON string containing additional keyword arguments to "
"pass to Tensorizer's TensorSerializer during "
"serialization."))
serialize_parser.add_argument(
"--keyfile",
type=str,
@ -186,9 +199,17 @@ def parse_args():
deserialize_parser.add_argument(
"--path-to-tensors",
type=str,
required=True,
required=False,
help="The local path or S3 URI to the model tensors to deserialize. ")
deserialize_parser.add_argument(
"--serialized-directory",
type=str,
required=False,
help="Directory with model artifacts for loading. Assumes a "
"model.tensors file exists therein. Can supersede "
"--path-to-tensors.")
deserialize_parser.add_argument(
"--keyfile",
type=str,
@ -196,11 +217,27 @@ def parse_args():
help=("Path to a binary key to use to decrypt the model weights,"
" if the model was serialized with encryption"))
deserialize_parser.add_argument(
"--deserialization-kwargs",
type=tensorizer_kwargs_arg,
required=False,
help=("A JSON string containing additional keyword arguments to "
"pass to Tensorizer's `TensorDeserializer` during "
"deserialization."))
TensorizerArgs.add_cli_args(deserialize_parser)
return parser.parse_args()
return parser
def merge_extra_config_with_tensorizer_config(extra_cfg: dict,
cfg: TensorizerConfig):
for k, v in extra_cfg.items():
if hasattr(cfg, k):
setattr(cfg, k, v)
logger.info(
"Updating TensorizerConfig with %s from "
"--model-loader-extra-config provided", k
)
def deserialize(args, tensorizer_config):
if args.lora_path:
@ -230,7 +267,8 @@ def deserialize(args, tensorizer_config):
lora_request=LoRARequest("sql-lora",
1,
args.lora_path,
tensorizer_config = tensorizer_config)
tensorizer_config_dict = tensorizer_config
.to_serializable())
)
)
else:
@ -243,7 +281,8 @@ def deserialize(args, tensorizer_config):
def main():
args = parse_args()
parser = get_parser()
args = parser.parse_args()
s3_access_key_id = (getattr(args, 's3_access_key_id', None)
or os.environ.get("S3_ACCESS_KEY_ID", None))
@ -265,13 +304,24 @@ def main():
else:
keyfile = None
extra_config = {}
if args.model_loader_extra_config:
config = json.loads(args.model_loader_extra_config)
tensorizer_args = \
TensorizerConfig(**config)._construct_tensorizer_args()
tensorizer_args.tensorizer_uri = args.path_to_tensors
else:
tensorizer_args = None
extra_config = json.loads(args.model_loader_extra_config)
tensorizer_dir = (args.serialized_directory or
extra_config.get("tensorizer_dir"))
tensorizer_uri = (getattr(args, "path_to_tensors", None)
or extra_config.get("tensorizer_uri"))
if tensorizer_dir and tensorizer_uri:
parser.error("--serialized-directory and --path-to-tensors "
"cannot both be provided")
if not tensorizer_dir and not tensorizer_uri:
parser.error("Either --serialized-directory or --path-to-tensors "
"must be provided")
if args.command == "serialize":
eng_args_dict = {f.name: getattr(args, f.name) for f in
@ -281,7 +331,7 @@ def main():
argparse.Namespace(**eng_args_dict)
)
input_dir = args.serialized_directory.rstrip('/')
input_dir = tensorizer_dir.rstrip('/')
suffix = args.suffix if args.suffix else uuid.uuid4().hex
base_path = f"{input_dir}/vllm/{model_ref}/{suffix}"
if engine_args.tensor_parallel_size > 1:
@ -292,21 +342,29 @@ def main():
tensorizer_config = TensorizerConfig(
tensorizer_uri=model_path,
encryption_keyfile=keyfile,
**credentials)
serialization_kwargs=args.serialization_kwargs or {},
**credentials
)
if args.lora_path:
tensorizer_config.lora_dir = tensorizer_config.tensorizer_dir
tensorize_lora_adapter(args.lora_path, tensorizer_config)
merge_extra_config_with_tensorizer_config(extra_config,
tensorizer_config)
tensorize_vllm_model(engine_args, tensorizer_config)
elif args.command == "deserialize":
if not tensorizer_args:
tensorizer_config = TensorizerConfig(
tensorizer_uri=args.path_to_tensors,
encryption_keyfile = keyfile,
**credentials
)
tensorizer_config = TensorizerConfig(
tensorizer_uri=args.path_to_tensors,
tensorizer_dir=args.serialized_directory,
encryption_keyfile=keyfile,
deserialization_kwargs=args.deserialization_kwargs or {},
**credentials
)
merge_extra_config_with_tensorizer_config(extra_config,
tensorizer_config)
deserialize(args, tensorizer_config)
else:
raise ValueError("Either serialize or deserialize must be specified.")

View File

@ -1,6 +1,6 @@
# testing
pytest
tensorizer>=2.9.0
tensorizer==2.10.1
pytest-forked
pytest-asyncio
pytest-rerunfailures

View File

@ -11,7 +11,7 @@ datasets
ray>=2.10.0,<2.45.0
peft
pytest-asyncio
tensorizer>=2.9.0
tensorizer==2.10.1
packaging>=24.2
setuptools>=77.0.3,<80.0.0
setuptools-scm>=8

View File

@ -1,6 +1,6 @@
# testing
pytest
tensorizer>=2.9.0
tensorizer==2.10.1
pytest-forked
pytest-asyncio
pytest-rerunfailures

View File

@ -739,7 +739,7 @@ tenacity==9.0.0
# via
# lm-eval
# plotly
tensorizer==2.9.0
tensorizer==2.10.1
# via -r requirements/test.in
threadpoolctl==3.5.0
# via scikit-learn

View File

@ -689,7 +689,7 @@ setup(
install_requires=get_requirements(),
extras_require={
"bench": ["pandas", "datasets"],
"tensorizer": ["tensorizer>=2.9.0"],
"tensorizer": ["tensorizer==2.10.1"],
"fastsafetensors": ["fastsafetensors >= 0.1.10"],
"runai": ["runai-model-streamer", "runai-model-streamer-s3", "boto3"],
"audio": ["librosa", "soundfile"], # Required for audio processing

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc
import json
import os
import tempfile
import openai
@ -58,18 +58,20 @@ def tensorize_model_and_lora(tmp_dir, model_uri):
@pytest.fixture(scope="module")
def server(model_uri, tensorize_model_and_lora):
model_loader_extra_config = {
"tensorizer_uri": model_uri,
}
# Here, model_uri is a directory containing a model.tensors
# file and all other necessary model artifacts, in particular an
# HF `config.json` file. Tensorizer can therefore infer the
# `TensorizerConfig`, so --model-loader-extra-config can be
# omitted entirely.
## Start OpenAI API server
args = [
"--load-format", "tensorizer", "--device", "cuda",
"--model-loader-extra-config",
json.dumps(model_loader_extra_config), "--enable-lora"
"--load-format", "tensorizer", "--served-model-name", MODEL_NAME,
"--enable-lora"
]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
model_dir = os.path.dirname(model_uri)
with RemoteOpenAIServer(model_dir, args) as remote_server:
yield remote_server
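The same directory-based loading works for offline inference; a small sketch with an assumed local path:

from vllm import LLM

# The directory (local or s3://) holds model.tensors plus the HF config and
# tokenizer files, so no --model-loader-extra-config / TensorizerConfig is
# needed; the loader infers it from the model tag.
llm = LLM(model="/path/to/serialized/model/dir", load_format="tensorizer")
print(llm.generate(["Hello, my name is"])[0].outputs[0].text)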

View File

@ -169,7 +169,8 @@ def test_tp2_serialize_and_deserialize_lora(tmp_path, sql_lora_files,
f"{VLLM_PATH}/examples/others/tensorize_vllm_model.py", "--model",
MODEL_PATH, "--lora-path", lora_path, "--tensor-parallel-size",
str(tp_size), "serialize", "--serialized-directory",
str(tmp_path), "--suffix", suffix
str(tmp_path), "--suffix", suffix, "--serialization-kwargs",
'{"limit_cpu_concurrency": 4}'
],
check=True,
capture_output=True,
@ -195,7 +196,7 @@ def test_tp2_serialize_and_deserialize_lora(tmp_path, sql_lora_files,
tensor_parallel_size=2,
max_loras=2)
tensorizer_config_dict = tensorizer_config.to_dict()
tensorizer_config_dict = tensorizer_config.to_serializable()
print("lora adapter created")
assert do_sample(loaded_vllm_model,

View File

@ -1,9 +1,28 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable
import pytest
from vllm import LLM, EngineArgs
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.model_executor.model_loader import tensorizer as tensorizer_mod
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.v1.executor.abstract import UniProcExecutor
from vllm.worker.worker_base import WorkerWrapperBase
MODEL_REF = "facebook/opt-125m"
@pytest.fixture()
def model_ref():
return MODEL_REF
@pytest.fixture(autouse=True)
def allow_insecure_serialization(monkeypatch):
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
@pytest.fixture(autouse=True)
@ -11,7 +30,73 @@ def cleanup():
cleanup_dist_env_and_memory(shutdown_ray=True)
@pytest.fixture()
def just_serialize_model_tensors(model_ref, monkeypatch, tmp_path):
def noop(*args, **kwargs):
return None
args = EngineArgs(model=model_ref)
tc = TensorizerConfig(tensorizer_uri=f"{tmp_path}/model.tensors")
monkeypatch.setattr(tensorizer_mod, "serialize_extra_artifacts", noop)
tensorizer_mod.tensorize_vllm_model(args, tc)
yield tmp_path
@pytest.fixture(autouse=True)
def tensorizer_config():
config = TensorizerConfig(tensorizer_uri="vllm")
return config
@pytest.fixture()
def model_path(model_ref, tmp_path):
yield tmp_path / model_ref / "model.tensors"
def assert_from_collective_rpc(engine: LLM, closure: Callable,
closure_kwargs: dict):
res = engine.collective_rpc(method=closure, kwargs=closure_kwargs)
return all(res)
# This is an object pulled from tests/v1/engine/test_engine_core.py,
# modified to strip the `load_model` call from its `_init_executor`
# method. It's used purely as a dummy utility for running methods that
# test Tensorizer functionality.
class DummyExecutor(UniProcExecutor):
def _init_executor(self) -> None:
"""Initialize the worker and load the model.
"""
self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config,
rpc_rank=0)
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
local_rank = 0
# set local rank as the device index if specified
device_info = self.vllm_config.device_config.device.__str__().split(
":")
if len(device_info) > 1:
local_rank = int(device_info[1])
rank = 0
is_driver_worker = True
kwargs = dict(
vllm_config=self.vllm_config,
local_rank=local_rank,
rank=rank,
distributed_init_method=distributed_init_method,
is_driver_worker=is_driver_worker,
)
self.collective_rpc("init_worker", args=([kwargs], ))
self.collective_rpc("init_device")
@property
def max_concurrent_batches(self) -> int:
return 2
def shutdown(self):
if hasattr(self, 'thread_pool'):
self.thread_pool.shutdown(wait=False)

View File

@ -1,36 +1,51 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import gc
import json
import os
import pathlib
import subprocess
import sys
from typing import Any
import pytest
import torch
from vllm import SamplingParams
import vllm.model_executor.model_loader.tensorizer
from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs
# yapf conflicts with isort for this docstring
# yapf: disable
from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig,
TensorSerializer,
is_vllm_tensorized,
open_stream,
tensorize_vllm_model)
from vllm.model_executor.model_loader.tensorizer_loader import (
BLACKLISTED_TENSORIZER_ARGS)
# yapf: enable
from vllm.utils import PlaceholderModule
from ..utils import VLLM_PATH
from ..utils import VLLM_PATH, RemoteOpenAIServer
from .conftest import DummyExecutor, assert_from_collective_rpc
try:
import tensorizer
from tensorizer import EncryptionParams
except ImportError:
tensorizer = PlaceholderModule("tensorizer") # type: ignore[assignment]
EncryptionParams = tensorizer.placeholder_attr("EncryptionParams")
class TensorizerCaughtError(Exception):
pass
EXAMPLES_PATH = VLLM_PATH / "examples"
pytest_plugins = "pytest_asyncio",
prompts = [
"Hello, my name is",
"The president of the United States is",
@ -40,9 +55,37 @@ prompts = [
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0)
model_ref = "facebook/opt-125m"
tensorize_model_for_testing_script = os.path.join(
os.path.dirname(__file__), "tensorize_vllm_model_for_testing.py")
def patch_init_and_catch_error(self, obj, method_name,
expected_error: type[Exception]):
original = getattr(obj, method_name, None)
if original is None:
raise ValueError("Method '{}' not found.".format(method_name))
def wrapper(*args, **kwargs):
try:
return original(*args, **kwargs)
except expected_error as err:
raise TensorizerCaughtError from err
setattr(obj, method_name, wrapper)
self.load_model()
def assert_specific_tensorizer_error_is_raised(
executor,
obj: Any,
method_name: str,
expected_error: type[Exception],
):
with pytest.raises(TensorizerCaughtError):
executor.collective_rpc(patch_init_and_catch_error,
args=(
obj,
method_name,
expected_error,
))
def is_curl_installed():
@ -81,11 +124,10 @@ def test_can_deserialize_s3(vllm_runner):
@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_deserialized_encrypted_vllm_model_has_same_outputs(
vllm_runner, tmp_path):
model_ref, vllm_runner, tmp_path, model_path):
args = EngineArgs(model=model_ref)
with vllm_runner(model_ref) as vllm_model:
model_path = tmp_path / (model_ref + ".tensors")
key_path = tmp_path / (model_ref + ".key")
key_path = tmp_path / model_ref / "model.key"
write_keyfile(key_path)
outputs = vllm_model.generate(prompts, sampling_params)
@ -111,9 +153,9 @@ def test_deserialized_encrypted_vllm_model_has_same_outputs(
def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner,
tmp_path):
tmp_path, model_ref,
model_path):
with hf_runner(model_ref) as hf_model:
model_path = tmp_path / (model_ref + ".tensors")
max_tokens = 50
outputs = hf_model.generate_greedy(prompts, max_tokens=max_tokens)
with open_stream(model_path, "wb+") as stream:
@ -123,7 +165,7 @@ def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner,
with vllm_runner(model_ref,
load_format="tensorizer",
model_loader_extra_config=TensorizerConfig(
tensorizer_uri=model_path,
tensorizer_uri=str(model_path),
num_readers=1,
)) as loaded_hf_model:
deserialized_outputs = loaded_hf_model.generate_greedy(
@ -132,7 +174,7 @@ def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner,
assert outputs == deserialized_outputs
def test_load_without_tensorizer_load_format(vllm_runner, capfd):
def test_load_without_tensorizer_load_format(vllm_runner, capfd, model_ref):
model = None
try:
model = vllm_runner(
@ -150,7 +192,8 @@ def test_load_without_tensorizer_load_format(vllm_runner, capfd):
torch.cuda.empty_cache()
def test_raise_value_error_on_invalid_load_format(vllm_runner, capfd):
def test_raise_value_error_on_invalid_load_format(vllm_runner, capfd,
model_ref):
model = None
try:
model = vllm_runner(
@ -208,7 +251,7 @@ def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs(
outputs = base_model.generate(prompts, sampling_params)
# load model with two shards and serialize with encryption
model_path = str(tmp_path / (model_ref + "-%02d.tensors"))
model_path = str(tmp_path / model_ref / "model-%02d.tensors")
key_path = tmp_path / (model_ref + ".key")
tensorizer_config = TensorizerConfig(
@ -242,13 +285,12 @@ def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs(
@pytest.mark.flaky(reruns=3)
def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path):
def test_vllm_tensorized_model_has_same_outputs(model_ref, vllm_runner,
tmp_path, model_path):
gc.collect()
torch.cuda.empty_cache()
model_ref = "facebook/opt-125m"
model_path = tmp_path / (model_ref + ".tensors")
config = TensorizerConfig(tensorizer_uri=str(model_path))
args = EngineArgs(model=model_ref, device="cuda")
args = EngineArgs(model=model_ref)
with vllm_runner(model_ref) as vllm_model:
outputs = vllm_model.generate(prompts, sampling_params)
@ -264,3 +306,243 @@ def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path):
# noqa: E501
assert outputs == deserialized_outputs
def test_load_with_just_model_tensors(just_serialize_model_tensors, model_ref):
# For backwards compatibility, ensure a Tensorizer model can still be
# loaded for inference by passing the model reference name (not a
# local/S3 dir) along with the location of the model tensors
model_dir = just_serialize_model_tensors
extra_config = {"tensorizer_uri": f"{model_dir}/model.tensors"}
## Start OpenAI API server
args = [
"--load-format",
"tensorizer",
"--model-loader-extra-config",
json.dumps(extra_config),
]
with RemoteOpenAIServer(model_ref, args):
# This test only concerns itself with being able to load the model
# and successfully initialize the server
pass
def test_assert_serialization_kwargs_passed_to_tensor_serializer(tmp_path):
serialization_params = {
"limit_cpu_concurrency": 2,
}
model_ref = "facebook/opt-125m"
model_path = tmp_path / (model_ref + ".tensors")
config = TensorizerConfig(tensorizer_uri=str(model_path),
serialization_kwargs=serialization_params)
llm = LLM(model=model_ref, )
def serialization_test(self, *args, **kwargs):
# This is performed in the ephemeral worker process, so monkey-patching
# will actually work, and cleanup is guaranteed, so there is no need
# to reset anything afterwards
original_dict = serialization_params
to_compare = {}
original = tensorizer.serialization.TensorSerializer.__init__
def tensorizer_serializer_wrapper(self, *args, **kwargs):
nonlocal to_compare
to_compare = kwargs.copy()
return original(self, *args, **kwargs)
tensorizer.serialization.TensorSerializer.__init__ = (
tensorizer_serializer_wrapper)
tensorizer_config = TensorizerConfig(**kwargs["tensorizer_config"])
self.save_tensorized_model(tensorizer_config=tensorizer_config, )
return to_compare | original_dict == to_compare
kwargs = {"tensorizer_config": config.to_serializable()}
assert assert_from_collective_rpc(llm, serialization_test, kwargs)
def test_assert_deserialization_kwargs_passed_to_tensor_deserializer(
tmp_path, capfd):
deserialization_kwargs = {
"num_readers": "bar", # illegal value
}
serialization_params = {
"limit_cpu_concurrency": 2,
}
model_ref = "facebook/opt-125m"
model_path = tmp_path / (model_ref + ".tensors")
config = TensorizerConfig(tensorizer_uri=str(model_path),
serialization_kwargs=serialization_params)
args = EngineArgs(model=model_ref)
tensorize_vllm_model(args, config)
loader_tc = TensorizerConfig(
tensorizer_uri=str(model_path),
deserialization_kwargs=deserialization_kwargs,
)
engine_args = EngineArgs(
model="facebook/opt-125m",
load_format="tensorizer",
model_loader_extra_config=loader_tc.to_serializable(),
)
vllm_config = engine_args.create_engine_config()
executor = DummyExecutor(vllm_config)
assert_specific_tensorizer_error_is_raised(
executor,
tensorizer.serialization.TensorDeserializer,
"__init__",
TypeError,
)
def test_assert_stream_kwargs_passed_to_tensor_deserializer(tmp_path, capfd):
deserialization_kwargs = {
"num_readers": 1,
}
serialization_params = {
"limit_cpu_concurrency": 2,
}
model_ref = "facebook/opt-125m"
model_path = tmp_path / (model_ref + ".tensors")
config = TensorizerConfig(tensorizer_uri=str(model_path),
serialization_kwargs=serialization_params)
args = EngineArgs(model=model_ref)
tensorize_vllm_model(args, config)
stream_kwargs = {"mode": "foo"}
loader_tc = TensorizerConfig(
tensorizer_uri=str(model_path),
deserialization_kwargs=deserialization_kwargs,
stream_kwargs=stream_kwargs,
)
engine_args = EngineArgs(
model="facebook/opt-125m",
load_format="tensorizer",
model_loader_extra_config=loader_tc.to_serializable(),
)
vllm_config = engine_args.create_engine_config()
executor = DummyExecutor(vllm_config)
assert_specific_tensorizer_error_is_raised(
executor,
vllm.model_executor.model_loader.tensorizer,
"open_stream",
ValueError,
)
@pytest.mark.asyncio
async def test_serialize_and_serve_entrypoints(tmp_path):
model_ref = "facebook/opt-125m"
suffix = "test"
try:
result = subprocess.run([
sys.executable,
f"{VLLM_PATH}/examples/others/tensorize_vllm_model.py", "--model",
model_ref, "serialize", "--serialized-directory",
str(tmp_path), "--suffix", suffix, "--serialization-kwargs",
'{"limit_cpu_concurrency": 4}'
],
check=True,
capture_output=True,
text=True)
except subprocess.CalledProcessError as e:
print("Tensorizing failed.")
print("STDOUT:\n", e.stdout)
print("STDERR:\n", e.stderr)
raise
assert "Successfully serialized" in result.stdout
# Next, try to serve with vllm serve
model_uri = tmp_path / "vllm" / model_ref / suffix / "model.tensors"
model_loader_extra_config = {
"tensorizer_uri": str(model_uri),
"stream_kwargs": {
"force_http": False,
},
"deserialization_kwargs": {
"verify_hash": True,
"num_readers": 8,
}
}
cmd = [
"-m", "vllm.entrypoints.cli.main", "serve", "--host", "localhost",
"--load-format", "tensorizer", model_ref,
"--model-loader-extra-config",
json.dumps(model_loader_extra_config, indent=2)
]
proc = await asyncio.create_subprocess_exec(
sys.executable,
*cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.STDOUT,
)
assert proc.stdout is not None
fut = proc.stdout.readuntil(b"Application startup complete.")
try:
await asyncio.wait_for(fut, 180)
except asyncio.TimeoutError:
pytest.fail("Server did not start successfully")
finally:
proc.terminate()
await proc.communicate()
@pytest.mark.parametrize("illegal_value", BLACKLISTED_TENSORIZER_ARGS)
def test_blacklisted_parameter_for_loading(tmp_path, vllm_runner, capfd,
illegal_value):
serialization_params = {
"limit_cpu_concurrency": 2,
}
model_ref = "facebook/opt-125m"
model_path = tmp_path / (model_ref + ".tensors")
config = TensorizerConfig(tensorizer_uri=str(model_path),
serialization_kwargs=serialization_params)
args = EngineArgs(model=model_ref)
tensorize_vllm_model(args, config)
loader_tc = {"tensorizer_uri": str(model_path), illegal_value: "foo"}
try:
vllm_runner(
model_ref,
load_format="tensorizer",
model_loader_extra_config=loader_tc,
)
except RuntimeError:
out, err = capfd.readouterr()
combined_output = out + err
assert (f"ValueError: {illegal_value} is not an allowed "
f"Tensorizer argument.") in combined_output

View File

@ -686,8 +686,11 @@ class ModelConfig:
# If tokenizer is same as model, download to same directory
if model == tokenizer:
s3_model.pull_files(
model, ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
s3_model.pull_files(model,
ignore_pattern=[
"*.pt", "*.safetensors", "*.bin",
"*.tensors"
])
self.tokenizer = s3_model.dir
return
@ -695,7 +698,8 @@ class ModelConfig:
if is_s3(tokenizer):
s3_tokenizer = S3Model()
s3_tokenizer.pull_files(
model, ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
model,
ignore_pattern=["*.pt", "*.safetensors", "*.bin", "*.tensors"])
self.tokenizer = s3_tokenizer.dir
def _init_multimodal_config(self) -> Optional["MultiModalConfig"]:

View File

@ -58,7 +58,8 @@ def parse_type(return_type: Callable[[str], T]) -> Callable[[str], T]:
def _parse_type(val: str) -> T:
try:
if return_type is json.loads and not re.match("^{.*}$", val):
if return_type is json.loads and not re.match(
r"(?s)^\s*{.*}\s*$", val):
return cast(T, nullable_kvs(val))
return return_type(val)
except ValueError as e:
@ -80,7 +81,7 @@ def optional_type(
def union_dict_and_str(val: str) -> Optional[Union[str, dict[str, str]]]:
if not re.match("^{.*}$", val):
if not re.match(r"(?s)^\s*{.*}\s*$", val):
return str(val)
return optional_type(json.loads)(val)
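A standalone illustration of the relaxed JSON detection above: with the (?s) flag and optional surrounding whitespace, a pretty-printed multi-line JSON object is now handed to json.loads instead of falling through to the key=value parser.

import json
import re

val = """
{
  "deserialization_kwargs": {"num_readers": 8}
}
"""
# The old pattern "^{.*}$" fails here (leading newline, '.' not matching
# newlines); the new pattern accepts it.
assert not re.match("^{.*}$", val)
assert re.match(r"(?s)^\s*{.*}\s*$", val)
print(json.loads(val))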
@ -1001,11 +1002,42 @@ class EngineArgs:
override_attention_dtype=self.override_attention_dtype,
)
def valid_tensorizer_config_provided(self) -> bool:
"""
Checks whether a parseable TensorizerConfig was passed to
self.model_loader_extra_config. The config may be a dict or a
TensorizerConfig object; in the latter case (detected by the presence
of TensorizerConfig's .to_serializable() method) it is first converted
into its serializable dict form. Returns False if the config already
specifies tensorizer_uri or tensorizer_dir, and True otherwise.
"""
if self.model_loader_extra_config:
if hasattr(self.model_loader_extra_config, "to_serializable"):
self.model_loader_extra_config = (
self.model_loader_extra_config.to_serializable())
for allowed_to_pass in ["tensorizer_uri", "tensorizer_dir"]:
try:
self.model_loader_extra_config[allowed_to_pass]
return False
except KeyError:
pass
return True
def create_load_config(self) -> LoadConfig:
if self.quantization == "bitsandbytes":
self.load_format = "bitsandbytes"
if (self.load_format == "tensorizer"
and self.valid_tensorizer_config_provided()):
logger.info("Inferring Tensorizer args from %s", self.model)
self.model_loader_extra_config = {"tensorizer_dir": self.model}
else:
logger.info(
"Using Tensorizer args from --model-loader-extra-config. "
"Note that you can now simply pass the S3 directory in the "
"model tag instead of providing the JSON string.")
return LoadConfig(
load_format=self.load_format,
download_dir=self.download_dir,

View File

@ -245,9 +245,10 @@ class LoRAModel(AdapterModel):
lora_tensor_path = os.path.join(tensorizer_config.tensorizer_dir,
"adapter_model.tensors")
tensorizer_args = tensorizer_config._construct_tensorizer_args()
tensors = TensorDeserializer(lora_tensor_path,
dtype=tensorizer_config.dtype,
**tensorizer_args.deserializer_params)
tensors = TensorDeserializer(
lora_tensor_path,
dtype=tensorizer_config.dtype,
**tensorizer_args.deserialization_kwargs)
check_unexpected_modules(tensors)
elif os.path.isfile(lora_tensor_path):

View File

@ -106,7 +106,7 @@ class PEFTHelper:
"adapter_config.json")
with open_stream(lora_config_path,
mode="rb",
**tensorizer_args.stream_params) as f:
**tensorizer_args.stream_kwargs) as f:
config = json.load(f)
logger.info("Successfully deserialized LoRA config from %s",

View File

@ -5,18 +5,18 @@ import argparse
import contextlib
import contextvars
import dataclasses
import io
import json
import os
import tempfile
import threading
import time
from collections.abc import Generator
from dataclasses import dataclass
from functools import partial
from typing import TYPE_CHECKING, Any, BinaryIO, Optional, Union
from collections.abc import Generator, MutableMapping
from dataclasses import asdict, dataclass, field, fields
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union
import regex as re
import torch
from huggingface_hub import snapshot_download
from torch import nn
from torch.utils._python_dispatch import TorchDispatchMode
from transformers import PretrainedConfig
@ -39,10 +39,6 @@ try:
from tensorizer.utils import (convert_bytes, get_mem_usage,
no_init_or_tensor)
_read_stream, _write_stream = (partial(
open_stream,
mode=mode,
) for mode in ("rb", "wb+"))
except ImportError:
tensorizer = PlaceholderModule("tensorizer")
DecryptionParams = tensorizer.placeholder_attr("DecryptionParams")
@ -54,9 +50,6 @@ except ImportError:
get_mem_usage = tensorizer.placeholder_attr("utils.get_mem_usage")
no_init_or_tensor = tensorizer.placeholder_attr("utils.no_init_or_tensor")
_read_stream = tensorizer.placeholder_attr("_read_stream")
_write_stream = tensorizer.placeholder_attr("_write_stream")
__all__ = [
'EncryptionParams', 'DecryptionParams', 'TensorDeserializer',
'TensorSerializer', 'open_stream', 'convert_bytes', 'get_mem_usage',
@ -66,6 +59,23 @@ __all__ = [
logger = init_logger(__name__)
def is_valid_deserialization_uri(uri: Optional[str]) -> bool:
if uri:
scheme = uri.lower().split("://")[0]
return scheme in {"s3", "http", "https"} or os.path.exists(uri)
return False
def tensorizer_kwargs_arg(value):
loaded = json.loads(value)
if not isinstance(loaded, dict):
raise argparse.ArgumentTypeError(
f"Not deserializable to dict: {value}. serialization_kwargs and "
f"deserialization_kwargs must be "
f"deserializable from a JSON string to a dictionary. ")
return loaded
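A quick demonstration of this argparse type, using arbitrary values: JSON objects come back as dicts, and any other JSON is rejected.

import argparse

from vllm.model_executor.model_loader.tensorizer import tensorizer_kwargs_arg

print(tensorizer_kwargs_arg('{"limit_cpu_concurrency": 4}'))
# -> {'limit_cpu_concurrency': 4}

try:
    tensorizer_kwargs_arg('[1, 2, 3]')  # valid JSON, but not an object
except argparse.ArgumentTypeError as e:
    print(e)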
class MetaTensorMode(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
@ -137,101 +147,45 @@ class _NoInitOrTensorImpl:
@dataclass
class TensorizerConfig:
tensorizer_uri: Union[str, None] = None
vllm_tensorized: Optional[bool] = False
verify_hash: Optional[bool] = False
class TensorizerConfig(MutableMapping):
tensorizer_uri: Optional[str] = None
tensorizer_dir: Optional[str] = None
vllm_tensorized: Optional[bool] = None
verify_hash: Optional[bool] = None
num_readers: Optional[int] = None
encryption_keyfile: Optional[str] = None
s3_access_key_id: Optional[str] = None
s3_secret_access_key: Optional[str] = None
s3_endpoint: Optional[str] = None
model_class: Optional[type[torch.nn.Module]] = None
hf_config: Optional[PretrainedConfig] = None
dtype: Optional[Union[str, torch.dtype]] = None
lora_dir: Optional[str] = None
_is_sharded: bool = False
def __post_init__(self):
# check if the configuration is for a sharded vLLM model
self._is_sharded = isinstance(self.tensorizer_uri, str) \
and re.search(r'%0\dd', self.tensorizer_uri) is not None
if not self.tensorizer_uri and not self.lora_dir:
raise ValueError("tensorizer_uri must be provided.")
if not self.tensorizer_uri and self.lora_dir:
self.tensorizer_uri = f"{self.lora_dir}/adapter_model.tensors"
assert self.tensorizer_uri is not None, ("tensorizer_uri must be "
"provided.")
self.tensorizer_dir = os.path.dirname(self.tensorizer_uri)
self.lora_dir = self.tensorizer_dir
@classmethod
def as_dict(cls, *args, **kwargs) -> dict[str, Any]:
cfg = TensorizerConfig(*args, **kwargs)
return dataclasses.asdict(cfg)
def to_dict(self) -> dict[str, Any]:
return dataclasses.asdict(self)
def _construct_tensorizer_args(self) -> "TensorizerArgs":
tensorizer_args = {
"tensorizer_uri": self.tensorizer_uri,
"vllm_tensorized": self.vllm_tensorized,
"verify_hash": self.verify_hash,
"num_readers": self.num_readers,
"encryption_keyfile": self.encryption_keyfile,
"s3_access_key_id": self.s3_access_key_id,
"s3_secret_access_key": self.s3_secret_access_key,
"s3_endpoint": self.s3_endpoint,
}
return TensorizerArgs(**tensorizer_args) # type: ignore
def verify_with_parallel_config(
self,
parallel_config: "ParallelConfig",
) -> None:
if parallel_config.tensor_parallel_size > 1 \
and not self._is_sharded:
raise ValueError(
"For a sharded model, tensorizer_uri should include a"
" string format template like '%04d' to be formatted"
" with the rank of the shard")
def verify_with_model_config(self, model_config: "ModelConfig") -> None:
if (model_config.quantization is not None
and self.tensorizer_uri is not None):
logger.warning(
"Loading a model using Tensorizer with quantization on vLLM"
" is unstable and may lead to errors.")
def open_stream(self, tensorizer_args: Optional["TensorizerArgs"] = None):
if tensorizer_args is None:
tensorizer_args = self._construct_tensorizer_args()
return open_stream(self.tensorizer_uri,
**tensorizer_args.stream_params)
@dataclass
class TensorizerArgs:
tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, BinaryIO, str,
bytes, os.PathLike, int]
vllm_tensorized: Optional[bool] = False
verify_hash: Optional[bool] = False
num_readers: Optional[int] = None
encryption_keyfile: Optional[str] = None
s3_access_key_id: Optional[str] = None
s3_secret_access_key: Optional[str] = None
s3_endpoint: Optional[str] = None
stream_kwargs: Optional[dict[str, Any]] = None
serialization_kwargs: Optional[dict[str, Any]] = None
deserialization_kwargs: Optional[dict[str, Any]] = None
_extra_serialization_attrs: Optional[dict[str, Any]] = field(init=False,
default=None)
model_class: Optional[type[torch.nn.Module]] = field(init=False,
default=None)
hf_config: Optional[PretrainedConfig] = field(init=False, default=None)
dtype: Optional[Union[str, torch.dtype]] = field(init=False, default=None)
_is_sharded: bool = field(init=False, default=False)
_fields: ClassVar[tuple[str, ...]]
_keys: ClassVar[frozenset[str]]
"""
Args for the TensorizerAgent class. These are used to configure the behavior
of the TensorDeserializer when loading tensors from a serialized model.
Args:
Args for the TensorizerConfig class. These are used to configure the
behavior of model serialization and deserialization using Tensorizer.
Args:
tensorizer_uri: Path to serialized model tensors. Can be a local file
path or a S3 URI. This is a required field unless lora_dir is
provided and the config is meant to be used for the
`tensorize_lora_adapter` function.
`tensorize_lora_adapter` function. Unless a `tensorizer_dir` or
`lora_dir` is passed to this object's initializer, this is a required
argument.
tensorizer_dir: Path to a directory containing serialized model tensors,
and all other potential model artifacts to load the model, such as
configs and tokenizer files. Can be passed instead of `tensorizer_uri`
where the `model.tensors` file will be assumed to be in this
directory.
vllm_tensorized: If True, indicates that the serialized model is a
vLLM model. This is used to determine the behavior of the
TensorDeserializer when loading tensors from a serialized model.
@ -256,34 +210,174 @@ class TensorizerArgs:
be set via the S3_SECRET_ACCESS_KEY environment variable.
s3_endpoint: The endpoint for the S3 bucket. Can also be set via the
S3_ENDPOINT_URL environment variable.
lora_dir: Path to a directory containing LoRA adapter artifacts for
serialization or deserialization. When serializing LoRA adapters
this is the only necessary parameter to pass to this object's
initializer.
"""
def __post_init__(self):
self.file_obj = self.tensorizer_uri
self.s3_access_key_id = self.s3_access_key_id or envs.S3_ACCESS_KEY_ID
self.s3_secret_access_key = (self.s3_secret_access_key
# check if the configuration is for a sharded vLLM model
self._is_sharded = isinstance(self.tensorizer_uri, str) \
and re.search(r'%0\dd', self.tensorizer_uri) is not None
if self.tensorizer_dir and self.tensorizer_uri:
raise ValueError(
"Either tensorizer_dir or tensorizer_uri must be provided, "
"not both.")
if self.tensorizer_dir and self.lora_dir:
raise ValueError(
"Only one of tensorizer_dir or lora_dir may be specified. "
"Use lora_dir exclusively when serializing LoRA adapters, "
"and tensorizer_dir or tensorizer_uri otherwise.")
if not self.tensorizer_uri:
if self.lora_dir:
self.tensorizer_uri = f"{self.lora_dir}/adapter_model.tensors"
elif self.tensorizer_dir:
self.tensorizer_uri = f"{self.tensorizer_dir}/model.tensors"
else:
raise ValueError("Unable to resolve tensorizer_uri. "
"A valid tensorizer_uri or tensorizer_dir "
"must be provided for deserialization, and a "
"valid tensorizer_uri, tensorizer_uri, or "
"lora_dir for serialization.")
else:
self.tensorizer_dir = os.path.dirname(self.tensorizer_uri)
if not self.serialization_kwargs:
self.serialization_kwargs = {}
if not self.deserialization_kwargs:
self.deserialization_kwargs = {}
def to_serializable(self) -> dict[str, Any]:
# TensorizerConfig needs to be msgpack-serializable, so it has to morph
# back and forth between itself and its dict representation.
# The dict representation is meant to round-trip, i.e. the following
# must remain initializable:
# TensorizerConfig(**my_tensorizer_cfg.to_serializable())
# This means the dict must not retain non-initializable parameters or
# post-init attribute state. Private and unset parameters are also
# dropped, so only public attributes with non-None values are kept.
raw_tc_dict = asdict(self)
blacklisted = []
if "tensorizer_uri" in raw_tc_dict and "tensorizer_dir" in raw_tc_dict:
blacklisted.append("tensorizer_dir")
if "tensorizer_dir" in raw_tc_dict and "lora_dir" in raw_tc_dict:
blacklisted.append("tensorizer_dir")
tc_dict = {}
for k, v in raw_tc_dict.items():
if (k not in blacklisted and k not in tc_dict
and not k.startswith("_") and v is not None):
tc_dict[k] = v
return tc_dict
def _construct_tensorizer_args(self) -> "TensorizerArgs":
return TensorizerArgs(self) # type: ignore
def verify_with_parallel_config(
self,
parallel_config: "ParallelConfig",
) -> None:
if parallel_config.tensor_parallel_size > 1 \
and not self._is_sharded:
raise ValueError(
"For a sharded model, tensorizer_uri should include a"
" string format template like '%04d' to be formatted"
" with the rank of the shard")
def verify_with_model_config(self, model_config: "ModelConfig") -> None:
if (model_config.quantization is not None
and self.tensorizer_uri is not None):
logger.warning(
"Loading a model using Tensorizer with quantization on vLLM"
" is unstable and may lead to errors.")
def open_stream(self, tensorizer_args: Optional["TensorizerArgs"] = None):
if tensorizer_args is None:
tensorizer_args = self._construct_tensorizer_args()
return open_stream(self.tensorizer_uri,
**tensorizer_args.stream_kwargs)
def keys(self):
return self._keys
def __len__(self):
return len(fields(self))
def __iter__(self):
return iter(self._fields)
def __getitem__(self, item: str) -> Any:
if item not in self.keys():
raise KeyError(item)
return getattr(self, item)
def __setitem__(self, key: str, value: Any) -> None:
if key not in self.keys():
# Disallow modifying invalid keys
raise KeyError(key)
setattr(self, key, value)
def __delitem__(self, key, /):
if key not in self.keys():
raise KeyError(key)
delattr(self, key)
TensorizerConfig._fields = tuple(f.name for f in fields(TensorizerConfig))
TensorizerConfig._keys = frozenset(TensorizerConfig._fields)
@dataclass
class TensorizerArgs:
tensorizer_uri: Optional[str] = None
tensorizer_dir: Optional[str] = None
encryption_keyfile: Optional[str] = None
def __init__(self, tensorizer_config: TensorizerConfig):
for k, v in tensorizer_config.items():
setattr(self, k, v)
self.file_obj = tensorizer_config.tensorizer_uri
self.s3_access_key_id = (tensorizer_config.s3_access_key_id
or envs.S3_ACCESS_KEY_ID)
self.s3_secret_access_key = (tensorizer_config.s3_secret_access_key
or envs.S3_SECRET_ACCESS_KEY)
self.s3_endpoint = self.s3_endpoint or envs.S3_ENDPOINT_URL
self.stream_params = {
"s3_access_key_id": self.s3_access_key_id,
"s3_secret_access_key": self.s3_secret_access_key,
"s3_endpoint": self.s3_endpoint,
self.s3_endpoint = tensorizer_config.s3_endpoint or envs.S3_ENDPOINT_URL
self.stream_kwargs = {
"s3_access_key_id": tensorizer_config.s3_access_key_id,
"s3_secret_access_key": tensorizer_config.s3_secret_access_key,
"s3_endpoint": tensorizer_config.s3_endpoint,
**(tensorizer_config.stream_kwargs or {})
}
self.deserializer_params = {
"verify_hash": self.verify_hash,
"encryption": self.encryption_keyfile,
"num_readers": self.num_readers
self.deserialization_kwargs = {
"verify_hash": tensorizer_config.verify_hash,
"encryption": tensorizer_config.encryption_keyfile,
"num_readers": tensorizer_config.num_readers,
**(tensorizer_config.deserialization_kwargs or {})
}
if self.encryption_keyfile:
with open_stream(
self.encryption_keyfile,
**self.stream_params,
tensorizer_config.encryption_keyfile,
**self.stream_kwargs,
) as stream:
key = stream.read()
decryption_params = DecryptionParams.from_key(key)
self.deserializer_params['encryption'] = decryption_params
self.deserialization_kwargs['encryption'] = decryption_params
@staticmethod
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
@ -405,15 +499,22 @@ def init_tensorizer_model(tensorizer_config: TensorizerConfig,
def deserialize_tensorizer_model(model: nn.Module,
tensorizer_config: TensorizerConfig) -> None:
tensorizer_args = tensorizer_config._construct_tensorizer_args()
if not is_valid_deserialization_uri(tensorizer_config.tensorizer_uri):
raise ValueError(
f"{tensorizer_config.tensorizer_uri} is not a valid "
f"tensorizer URI. Please check that the URI is correct. "
f"It must either point to a local existing file, or have a "
f"S3, HTTP or HTTPS scheme.")
before_mem = get_mem_usage()
start = time.perf_counter()
with _read_stream(
with open_stream(
tensorizer_config.tensorizer_uri,
**tensorizer_args.stream_params) as stream, TensorDeserializer(
mode="rb",
**tensorizer_args.stream_kwargs) as stream, TensorDeserializer(
stream,
dtype=tensorizer_config.dtype,
device=f'cuda:{torch.cuda.current_device()}',
**tensorizer_args.deserializer_params) as deserializer:
device=torch.device("cuda", torch.cuda.current_device()),
**tensorizer_args.deserialization_kwargs) as deserializer:
deserializer.load_into_module(model)
end = time.perf_counter()
@ -442,9 +543,9 @@ def tensorizer_weights_iterator(
"examples/others/tensorize_vllm_model.py example script "
"for serializing vLLM models.")
deserializer_args = tensorizer_args.deserializer_params
stream_params = tensorizer_args.stream_params
stream = open_stream(tensorizer_args.tensorizer_uri, **stream_params)
deserializer_args = tensorizer_args.deserialization_kwargs
stream_kwargs = tensorizer_args.stream_kwargs
stream = open_stream(tensorizer_args.tensorizer_uri, **stream_kwargs)
with TensorDeserializer(stream, **deserializer_args,
device="cpu") as state:
yield from state.items()
@ -465,8 +566,8 @@ def is_vllm_tensorized(tensorizer_config: "TensorizerConfig") -> bool:
"""
tensorizer_args = tensorizer_config._construct_tensorizer_args()
deserializer = TensorDeserializer(open_stream(
tensorizer_args.tensorizer_uri, **tensorizer_args.stream_params),
**tensorizer_args.deserializer_params,
tensorizer_args.tensorizer_uri, **tensorizer_args.stream_kwargs),
**tensorizer_args.deserialization_kwargs,
lazy_load=True)
if tensorizer_config.vllm_tensorized:
logger.warning(
@ -477,13 +578,41 @@ def is_vllm_tensorized(tensorizer_config: "TensorizerConfig") -> bool:
return ".vllm_tensorized_marker" in deserializer
def serialize_extra_artifacts(
tensorizer_args: TensorizerArgs,
served_model_name: Union[str, list[str], None]) -> None:
if not isinstance(served_model_name, str):
raise ValueError(
f"served_model_name must be a str for serialize_extra_artifacts, "
f"not {type(served_model_name)}.")
with tempfile.TemporaryDirectory() as tmpdir:
snapshot_download(served_model_name,
local_dir=tmpdir,
ignore_patterns=[
"*.pt", "*.safetensors", "*.bin", "*.cache",
"*.gitattributes", "*.md"
])
for artifact in os.scandir(tmpdir):
if not artifact.is_file():
continue
with open(artifact.path, "rb") as f, open_stream(
f"{tensorizer_args.tensorizer_dir}/{artifact.name}",
mode="wb+",
**tensorizer_args.stream_kwargs) as stream:
logger.info("Writing artifact %s", artifact.name)
stream.write(f.read())
def serialize_vllm_model(
model: nn.Module,
tensorizer_config: TensorizerConfig,
model_config: "ModelConfig",
) -> nn.Module:
model.register_parameter(
"vllm_tensorized_marker",
nn.Parameter(torch.tensor((1, ), device="meta"), requires_grad=False))
tensorizer_args = tensorizer_config._construct_tensorizer_args()
encryption_params = None
@ -497,10 +626,16 @@ def serialize_vllm_model(
from vllm.distributed import get_tensor_model_parallel_rank
output_file = output_file % get_tensor_model_parallel_rank()
with _write_stream(output_file, **tensorizer_args.stream_params) as stream:
serializer = TensorSerializer(stream, encryption=encryption_params)
with open_stream(output_file, mode="wb+",
**tensorizer_args.stream_kwargs) as stream:
serializer = TensorSerializer(stream,
encryption=encryption_params,
**tensorizer_config.serialization_kwargs)
serializer.write_module(model)
serializer.close()
serialize_extra_artifacts(tensorizer_args, model_config.served_model_name)
logger.info("Successfully serialized model to %s", str(output_file))
return model
@ -522,8 +657,9 @@ def tensorize_vllm_model(engine_args: "EngineArgs",
if generate_keyfile and (keyfile :=
tensorizer_config.encryption_keyfile) is not None:
encryption_params = EncryptionParams.random()
with _write_stream(
with open_stream(
keyfile,
mode="wb+",
s3_access_key_id=tensorizer_config.s3_access_key_id,
s3_secret_access_key=tensorizer_config.s3_secret_access_key,
s3_endpoint=tensorizer_config.s3_endpoint,
@ -537,13 +673,13 @@ def tensorize_vllm_model(engine_args: "EngineArgs",
engine = LLMEngine.from_engine_args(engine_args)
engine.model_executor.collective_rpc(
"save_tensorized_model",
kwargs=dict(tensorizer_config=tensorizer_config),
kwargs={"tensorizer_config": tensorizer_config.to_serializable()},
)
else:
engine = V1LLMEngine.from_vllm_config(engine_config)
engine.collective_rpc(
"save_tensorized_model",
kwargs=dict(tensorizer_config=tensorizer_config),
kwargs={"tensorizer_config": tensorizer_config.to_serializable()},
)
@ -586,14 +722,14 @@ def tensorize_lora_adapter(lora_path: str,
with open_stream(f"{tensorizer_config.lora_dir}/adapter_config.json",
mode="wb+",
**tensorizer_args.stream_params) as f:
**tensorizer_args.stream_kwargs) as f:
f.write(json.dumps(config).encode("utf-8"))
lora_uri = (f"{tensorizer_config.lora_dir}"
f"/adapter_model.tensors")
with open_stream(lora_uri, mode="wb+",
**tensorizer_args.stream_params) as f:
**tensorizer_args.stream_kwargs) as f:
serializer = TensorSerializer(f)
serializer.write_state_dict(tensors)
serializer.close()
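Tying the pieces together, a short sketch (URI and kwargs are illustrative, borrowed from the tests in this change) of how TensorizerConfig now carries arbitrary serialization/deserialization/stream keyword arguments and round-trips through its dict form:

from vllm.model_executor.model_loader.tensorizer import TensorizerConfig

cfg = TensorizerConfig(
    tensorizer_uri="s3://my-bucket/vllm/facebook/opt-125m/v1/model.tensors",
    serialization_kwargs={"limit_cpu_concurrency": 4},
    deserialization_kwargs={"num_readers": 8, "verify_hash": True},
    stream_kwargs={"force_http": False},
)

# to_serializable() keeps only public, non-None, re-initializable fields, so
# the dict can cross process boundaries (collective_rpc, msgpack) and be
# rebuilt on the other side.
cfg_dict = cfg.to_serializable()
rebuilt = TensorizerConfig(**cfg_dict)
assert rebuilt.tensorizer_uri == cfg.tensorizer_uri
assert rebuilt.deserialization_kwargs == cfg.deserialization_kwargs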

View File

@ -20,6 +20,18 @@ from vllm.model_executor.model_loader.utils import (get_model_architecture,
logger = init_logger(__name__)
BLACKLISTED_TENSORIZER_ARGS = {
"device", # vLLM decides this
"dtype", # vLLM decides this
"mode", # Not meant to be configurable by the user
}
def validate_config(config: dict):
for k, v in config.items():
if v is not None and k in BLACKLISTED_TENSORIZER_ARGS:
raise ValueError(f"{k} is not an allowed Tensorizer argument.")
class TensorizerLoader(BaseModelLoader):
"""Model loader using CoreWeave's tensorizer library."""
@ -29,6 +41,7 @@ class TensorizerLoader(BaseModelLoader):
if isinstance(load_config.model_loader_extra_config, TensorizerConfig):
self.tensorizer_config = load_config.model_loader_extra_config
else:
validate_config(load_config.model_loader_extra_config)
self.tensorizer_config = TensorizerConfig(
**load_config.model_loader_extra_config)
@ -118,10 +131,12 @@ class TensorizerLoader(BaseModelLoader):
def save_model(
model: torch.nn.Module,
tensorizer_config: Union[TensorizerConfig, dict],
model_config: ModelConfig,
) -> None:
if isinstance(tensorizer_config, dict):
tensorizer_config = TensorizerConfig(**tensorizer_config)
serialize_vllm_model(
model=model,
tensorizer_config=tensorizer_config,
model_config=model_config,
)
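The validate_config guard rejects blacklisted keys before a TensorizerConfig is ever constructed; a small sketch of the failure mode:

from vllm.model_executor.model_loader.tensorizer_loader import validate_config

# "device", "dtype" and "mode" are decided by vLLM / the loader itself, so a
# user-supplied value is refused up front.
try:
    validate_config({"tensorizer_uri": "/tmp/model.tensors", "device": "cpu"})
except ValueError as err:
    print(err)  # device is not an allowed Tensorizer argument.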

View File

@ -1820,6 +1820,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
TensorizerLoader.save_model(
self.model,
tensorizer_config=tensorizer_config,
model_config=self.model_config,
)
def _get_prompt_logprobs_dict(

View File

@ -1246,6 +1246,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
TensorizerLoader.save_model(
self.model,
tensorizer_config=tensorizer_config,
model_config=self.model_config,
)
def get_max_block_per_batch(self) -> int: