mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 13:15:48 +08:00
[Frontend] [Core] Integrate Tensorizer in to S3 loading machinery, allow passing arbitrary arguments during save/load (#19619)
Signed-off-by: Sanger Steel <sangersteel@gmail.com> Co-authored-by: Eta <esyra@coreweave.com>
This commit is contained in:
parent
e34d130c16
commit
72d14d0eed
@ -4,6 +4,7 @@
|
||||
import argparse
|
||||
import dataclasses
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
|
||||
@ -15,9 +16,13 @@ from vllm.model_executor.model_loader.tensorizer import (
|
||||
TensorizerConfig,
|
||||
tensorize_lora_adapter,
|
||||
tensorize_vllm_model,
|
||||
tensorizer_kwargs_arg,
|
||||
)
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
# yapf conflicts with isort for this docstring
|
||||
# yapf: disable
|
||||
"""
|
||||
@ -119,7 +124,7 @@ vllm serve <model_path> \
|
||||
"""
|
||||
|
||||
|
||||
def parse_args():
|
||||
def get_parser():
|
||||
parser = FlexibleArgumentParser(
|
||||
description="An example script that can be used to serialize and "
|
||||
"deserialize vLLM models. These models "
|
||||
@ -135,13 +140,13 @@ def parse_args():
|
||||
required=False,
|
||||
help="Path to a LoRA adapter to "
|
||||
"serialize along with model tensors. This can then be deserialized "
|
||||
"along with the model by passing a tensorizer_config kwarg to "
|
||||
"LoRARequest with type TensorizerConfig. See the docstring for this "
|
||||
"for a usage example."
|
||||
|
||||
"along with the model by instantiating a TensorizerConfig object, "
|
||||
"creating a dict from it with TensorizerConfig.to_serializable(), "
|
||||
"and passing it to LoRARequest's initializer with the kwarg "
|
||||
"tensorizer_config_dict."
|
||||
)
|
||||
|
||||
subparsers = parser.add_subparsers(dest='command')
|
||||
subparsers = parser.add_subparsers(dest='command', required=True)
|
||||
|
||||
serialize_parser = subparsers.add_parser(
|
||||
'serialize', help="Serialize a model to `--serialized-directory`")
|
||||
@ -171,6 +176,14 @@ def parse_args():
|
||||
"where `suffix` is given by `--suffix` or a random UUID if not "
|
||||
"provided.")
|
||||
|
||||
serialize_parser.add_argument(
|
||||
"--serialization-kwargs",
|
||||
type=tensorizer_kwargs_arg,
|
||||
required=False,
|
||||
help=("A JSON string containing additional keyword arguments to "
|
||||
"pass to Tensorizer's TensorSerializer during "
|
||||
"serialization."))
|
||||
|
||||
serialize_parser.add_argument(
|
||||
"--keyfile",
|
||||
type=str,
|
||||
@ -186,9 +199,17 @@ def parse_args():
|
||||
deserialize_parser.add_argument(
|
||||
"--path-to-tensors",
|
||||
type=str,
|
||||
required=True,
|
||||
required=False,
|
||||
help="The local path or S3 URI to the model tensors to deserialize. ")
|
||||
|
||||
deserialize_parser.add_argument(
|
||||
"--serialized-directory",
|
||||
type=str,
|
||||
required=False,
|
||||
help="Directory with model artifacts for loading. Assumes a "
|
||||
"model.tensors file exists therein. Can supersede "
|
||||
"--path-to-tensors.")
|
||||
|
||||
deserialize_parser.add_argument(
|
||||
"--keyfile",
|
||||
type=str,
|
||||
@ -196,11 +217,27 @@ def parse_args():
|
||||
help=("Path to a binary key to use to decrypt the model weights,"
|
||||
" if the model was serialized with encryption"))
|
||||
|
||||
deserialize_parser.add_argument(
|
||||
"--deserialization-kwargs",
|
||||
type=tensorizer_kwargs_arg,
|
||||
required=False,
|
||||
help=("A JSON string containing additional keyword arguments to "
|
||||
"pass to Tensorizer's `TensorDeserializer` during "
|
||||
"deserialization."))
|
||||
|
||||
TensorizerArgs.add_cli_args(deserialize_parser)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
return parser
|
||||
|
||||
def merge_extra_config_with_tensorizer_config(extra_cfg: dict,
                                              cfg: TensorizerConfig):
    """Copy matching entries of ``extra_cfg`` onto ``cfg`` in place.

    Only keys that already exist as attributes on ``cfg`` are applied; every
    other key is skipped. One info-level log line is emitted per applied key.

    Args:
        extra_cfg: Mapping of option names to values to overlay on ``cfg``.
        cfg: The ``TensorizerConfig`` instance that is mutated.
    """
    for name in extra_cfg:
        # Keys that the config does not know about are ignored.
        if not hasattr(cfg, name):
            continue
        setattr(cfg, name, extra_cfg[name])
        logger.info(
            "Updating TensorizerConfig with %s from "
            "--model-loader-extra-config provided", name
        )
|
||||
|
||||
def deserialize(args, tensorizer_config):
|
||||
if args.lora_path:
|
||||
@ -230,7 +267,8 @@ def deserialize(args, tensorizer_config):
|
||||
lora_request=LoRARequest("sql-lora",
|
||||
1,
|
||||
args.lora_path,
|
||||
tensorizer_config = tensorizer_config)
|
||||
tensorizer_config_dict = tensorizer_config
|
||||
.to_serializable())
|
||||
)
|
||||
)
|
||||
else:
|
||||
@ -243,7 +281,8 @@ def deserialize(args, tensorizer_config):
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
s3_access_key_id = (getattr(args, 's3_access_key_id', None)
|
||||
or os.environ.get("S3_ACCESS_KEY_ID", None))
|
||||
@ -265,13 +304,24 @@ def main():
|
||||
else:
|
||||
keyfile = None
|
||||
|
||||
extra_config = {}
|
||||
if args.model_loader_extra_config:
|
||||
config = json.loads(args.model_loader_extra_config)
|
||||
tensorizer_args = \
|
||||
TensorizerConfig(**config)._construct_tensorizer_args()
|
||||
tensorizer_args.tensorizer_uri = args.path_to_tensors
|
||||
else:
|
||||
tensorizer_args = None
|
||||
extra_config = json.loads(args.model_loader_extra_config)
|
||||
|
||||
|
||||
tensorizer_dir = (args.serialized_directory or
|
||||
extra_config.get("tensorizer_dir"))
|
||||
tensorizer_uri = (getattr(args, "path_to_tensors", None)
|
||||
or extra_config.get("tensorizer_uri"))
|
||||
|
||||
if tensorizer_dir and tensorizer_uri:
|
||||
parser.error("--serialized-directory and --path-to-tensors "
|
||||
"cannot both be provided")
|
||||
|
||||
if not tensorizer_dir and not tensorizer_uri:
|
||||
parser.error("Either --serialized-directory or --path-to-tensors "
|
||||
"must be provided")
|
||||
|
||||
|
||||
if args.command == "serialize":
|
||||
eng_args_dict = {f.name: getattr(args, f.name) for f in
|
||||
@ -281,7 +331,7 @@ def main():
|
||||
argparse.Namespace(**eng_args_dict)
|
||||
)
|
||||
|
||||
input_dir = args.serialized_directory.rstrip('/')
|
||||
input_dir = tensorizer_dir.rstrip('/')
|
||||
suffix = args.suffix if args.suffix else uuid.uuid4().hex
|
||||
base_path = f"{input_dir}/vllm/{model_ref}/{suffix}"
|
||||
if engine_args.tensor_parallel_size > 1:
|
||||
@ -292,21 +342,29 @@ def main():
|
||||
tensorizer_config = TensorizerConfig(
|
||||
tensorizer_uri=model_path,
|
||||
encryption_keyfile=keyfile,
|
||||
**credentials)
|
||||
serialization_kwargs=args.serialization_kwargs or {},
|
||||
**credentials
|
||||
)
|
||||
|
||||
if args.lora_path:
|
||||
tensorizer_config.lora_dir = tensorizer_config.tensorizer_dir
|
||||
tensorize_lora_adapter(args.lora_path, tensorizer_config)
|
||||
|
||||
merge_extra_config_with_tensorizer_config(extra_config,
|
||||
tensorizer_config)
|
||||
tensorize_vllm_model(engine_args, tensorizer_config)
|
||||
|
||||
elif args.command == "deserialize":
|
||||
if not tensorizer_args:
|
||||
tensorizer_config = TensorizerConfig(
|
||||
tensorizer_uri=args.path_to_tensors,
|
||||
encryption_keyfile = keyfile,
|
||||
**credentials
|
||||
)
|
||||
tensorizer_config = TensorizerConfig(
|
||||
tensorizer_uri=args.path_to_tensors,
|
||||
tensorizer_dir=args.serialized_directory,
|
||||
encryption_keyfile=keyfile,
|
||||
deserialization_kwargs=args.deserialization_kwargs or {},
|
||||
**credentials
|
||||
)
|
||||
|
||||
merge_extra_config_with_tensorizer_config(extra_config,
|
||||
tensorizer_config)
|
||||
deserialize(args, tensorizer_config)
|
||||
else:
|
||||
raise ValueError("Either serialize or deserialize must be specified.")
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
# testing
|
||||
pytest
|
||||
tensorizer>=2.9.0
|
||||
tensorizer==2.10.1
|
||||
pytest-forked
|
||||
pytest-asyncio
|
||||
pytest-rerunfailures
|
||||
|
||||
@ -11,7 +11,7 @@ datasets
|
||||
ray>=2.10.0,<2.45.0
|
||||
peft
|
||||
pytest-asyncio
|
||||
tensorizer>=2.9.0
|
||||
tensorizer==2.10.1
|
||||
packaging>=24.2
|
||||
setuptools>=77.0.3,<80.0.0
|
||||
setuptools-scm>=8
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
# testing
|
||||
pytest
|
||||
tensorizer>=2.9.0
|
||||
tensorizer==2.10.1
|
||||
pytest-forked
|
||||
pytest-asyncio
|
||||
pytest-rerunfailures
|
||||
|
||||
@ -739,7 +739,7 @@ tenacity==9.0.0
|
||||
# via
|
||||
# lm-eval
|
||||
# plotly
|
||||
tensorizer==2.9.0
|
||||
tensorizer==2.10.1
|
||||
# via -r requirements/test.in
|
||||
threadpoolctl==3.5.0
|
||||
# via scikit-learn
|
||||
|
||||
2
setup.py
2
setup.py
@ -689,7 +689,7 @@ setup(
|
||||
install_requires=get_requirements(),
|
||||
extras_require={
|
||||
"bench": ["pandas", "datasets"],
|
||||
"tensorizer": ["tensorizer>=2.9.0"],
|
||||
"tensorizer": ["tensorizer==2.10.1"],
|
||||
"fastsafetensors": ["fastsafetensors >= 0.1.10"],
|
||||
"runai": ["runai-model-streamer", "runai-model-streamer-s3", "boto3"],
|
||||
"audio": ["librosa", "soundfile"], # Required for audio processing
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import gc
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import openai
|
||||
@ -58,18 +58,20 @@ def tensorize_model_and_lora(tmp_dir, model_uri):
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server(model_uri, tensorize_model_and_lora):
|
||||
model_loader_extra_config = {
|
||||
"tensorizer_uri": model_uri,
|
||||
}
|
||||
# In this case, model_uri is a directory with a model.tensors
|
||||
# file and all necessary model artifacts, particularly a
|
||||
# HF `config.json` file. In this case, Tensorizer can infer the
|
||||
# `TensorizerConfig` so --model-loader-extra-config can be completely
|
||||
# omitted.
|
||||
|
||||
## Start OpenAI API server
|
||||
args = [
|
||||
"--load-format", "tensorizer", "--device", "cuda",
|
||||
"--model-loader-extra-config",
|
||||
json.dumps(model_loader_extra_config), "--enable-lora"
|
||||
"--load-format", "tensorizer", "--served-model-name", MODEL_NAME,
|
||||
"--enable-lora"
|
||||
]
|
||||
|
||||
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
||||
model_dir = os.path.dirname(model_uri)
|
||||
with RemoteOpenAIServer(model_dir, args) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
|
||||
|
||||
@ -169,7 +169,8 @@ def test_tp2_serialize_and_deserialize_lora(tmp_path, sql_lora_files,
|
||||
f"{VLLM_PATH}/examples/others/tensorize_vllm_model.py", "--model",
|
||||
MODEL_PATH, "--lora-path", lora_path, "--tensor-parallel-size",
|
||||
str(tp_size), "serialize", "--serialized-directory",
|
||||
str(tmp_path), "--suffix", suffix
|
||||
str(tmp_path), "--suffix", suffix, "--serialization-kwargs",
|
||||
'{"limit_cpu_concurrency": 4}'
|
||||
],
|
||||
check=True,
|
||||
capture_output=True,
|
||||
@ -195,7 +196,7 @@ def test_tp2_serialize_and_deserialize_lora(tmp_path, sql_lora_files,
|
||||
tensor_parallel_size=2,
|
||||
max_loras=2)
|
||||
|
||||
tensorizer_config_dict = tensorizer_config.to_dict()
|
||||
tensorizer_config_dict = tensorizer_config.to_serializable()
|
||||
|
||||
print("lora adapter created")
|
||||
assert do_sample(loaded_vllm_model,
|
||||
|
||||
@ -1,9 +1,28 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Callable
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm import LLM, EngineArgs
|
||||
from vllm.distributed import cleanup_dist_env_and_memory
|
||||
from vllm.model_executor.model_loader import tensorizer as tensorizer_mod
|
||||
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
|
||||
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
|
||||
from vllm.v1.executor.abstract import UniProcExecutor
|
||||
from vllm.worker.worker_base import WorkerWrapperBase
|
||||
|
||||
MODEL_REF = "facebook/opt-125m"
|
||||
|
||||
|
||||
@pytest.fixture()
def model_ref():
    """Provide the module-level ``MODEL_REF`` model identifier."""
    return MODEL_REF
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def allow_insecure_serialization(monkeypatch):
    """Set ``VLLM_ALLOW_INSECURE_SERIALIZATION=1`` for each test (autouse)."""
    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
@ -11,7 +30,73 @@ def cleanup():
|
||||
cleanup_dist_env_and_memory(shutdown_ray=True)
|
||||
|
||||
|
||||
@pytest.fixture()
def just_serialize_model_tensors(model_ref, monkeypatch, tmp_path):
    """Tensorize ``model_ref`` into ``tmp_path`` and yield that directory.

    The tensors are written to ``<tmp_path>/model.tensors``. The module's
    ``serialize_extra_artifacts`` is replaced with a no-op before
    serialization runs.
    """
    engine_args = EngineArgs(model=model_ref)
    config = TensorizerConfig(tensorizer_uri=f"{tmp_path}/model.tensors")

    # Disable extra-artifact serialization for this fixture.
    monkeypatch.setattr(tensorizer_mod, "serialize_extra_artifacts",
                        lambda *args, **kwargs: None)

    tensorizer_mod.tensorize_vllm_model(engine_args, config)
    yield tmp_path
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def tensorizer_config():
    """Provide a ``TensorizerConfig`` whose ``tensorizer_uri`` is ``"vllm"``."""
    return TensorizerConfig(tensorizer_uri="vllm")
|
||||
|
||||
|
||||
@pytest.fixture()
def model_path(model_ref, tmp_path):
    """Yield the path ``<tmp_path>/<model_ref>/model.tensors``."""
    tensors_file = tmp_path / model_ref / "model.tensors"
    yield tensors_file
|
||||
|
||||
|
||||
def assert_from_collective_rpc(engine: LLM, closure: Callable,
                               closure_kwargs: dict):
    """Run ``closure`` through ``engine.collective_rpc`` and check results.

    Args:
        engine: Engine whose ``collective_rpc`` is invoked.
        closure: Callable passed as the RPC ``method``.
        closure_kwargs: Keyword arguments forwarded to the RPC.

    Returns:
        ``True`` if every returned result is truthy (also when there are no
        results), otherwise ``False``.
    """
    outcomes = engine.collective_rpc(method=closure, kwargs=closure_kwargs)
    for outcome in outcomes:
        if not outcome:
            return False
    return True
|
||||
|
||||
|
||||
# This is an object pulled from tests/v1/engine/test_engine_core.py
|
||||
# Modified to strip the `load_model` method from its `_init_executor`
|
||||
# method. It's purely used as a dummy utility to run methods that test
|
||||
# Tensorizer functionality
|
||||
class DummyExecutor(UniProcExecutor):
    """Executor that initializes a single worker but never loads the model.

    Used as a test utility so that model loading can be triggered (and
    patched) explicitly by the test.
    """

    def _init_executor(self) -> None:
        """Initialize the worker and its device.

        Unlike the docstring this was copied with, no model is loaded here:
        only ``init_worker`` and ``init_device`` are dispatched.
        """
        self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config,
                                               rpc_rank=0)
        distributed_init_method = get_distributed_init_method(
            get_ip(), get_open_port())
        local_rank = 0
        # set local rank as the device index if specified
        device_info = str(self.vllm_config.device_config.device).split(":")
        if len(device_info) > 1:
            local_rank = int(device_info[1])
        rank = 0
        is_driver_worker = True
        kwargs = dict(
            vllm_config=self.vllm_config,
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=distributed_init_method,
            is_driver_worker=is_driver_worker,
        )
        self.collective_rpc("init_worker", args=([kwargs], ))
        self.collective_rpc("init_device")

    @property
    def max_concurrent_batches(self) -> int:
        """Fixed number of concurrent batches reported by this executor."""
        return 2

    def shutdown(self):
        """Shut down ``self.thread_pool`` without waiting, if it exists."""
        if hasattr(self, 'thread_pool'):
            self.thread_pool.shutdown(wait=False)
|
||||
|
||||
@ -1,36 +1,51 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
import gc
|
||||
import json
|
||||
import os
|
||||
import pathlib
|
||||
import subprocess
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm import SamplingParams
|
||||
import vllm.model_executor.model_loader.tensorizer
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
# yapf conflicts with isort for this docstring
|
||||
# yapf: disable
|
||||
from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig,
|
||||
TensorSerializer,
|
||||
is_vllm_tensorized,
|
||||
open_stream,
|
||||
tensorize_vllm_model)
|
||||
from vllm.model_executor.model_loader.tensorizer_loader import (
|
||||
BLACKLISTED_TENSORIZER_ARGS)
|
||||
# yapf: enable
|
||||
from vllm.utils import PlaceholderModule
|
||||
|
||||
from ..utils import VLLM_PATH
|
||||
from ..utils import VLLM_PATH, RemoteOpenAIServer
|
||||
from .conftest import DummyExecutor, assert_from_collective_rpc
|
||||
|
||||
try:
|
||||
import tensorizer
|
||||
from tensorizer import EncryptionParams
|
||||
except ImportError:
|
||||
tensorizer = PlaceholderModule("tensorizer") # type: ignore[assignment]
|
||||
EncryptionParams = tensorizer.placeholder_attr("EncryptionParams")
|
||||
|
||||
|
||||
class TensorizerCaughtError(Exception):
    """Marker exception raised when a patched call hits its expected error."""
|
||||
|
||||
|
||||
EXAMPLES_PATH = VLLM_PATH / "examples"
|
||||
|
||||
pytest_plugins = "pytest_asyncio",
|
||||
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
@ -40,9 +55,37 @@ prompts = [
|
||||
# Create a sampling params object.
|
||||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0)
|
||||
|
||||
model_ref = "facebook/opt-125m"
|
||||
tensorize_model_for_testing_script = os.path.join(
|
||||
os.path.dirname(__file__), "tensorize_vllm_model_for_testing.py")
|
||||
|
||||
def patch_init_and_catch_error(self, obj, method_name,
                               expected_error: type[Exception]):
    """Patch ``obj.<method_name>`` to flag an expected error, then load.

    The attribute named ``method_name`` on ``obj`` is replaced by a wrapper
    that calls the original and converts ``expected_error`` into
    ``TensorizerCaughtError`` (chained to the original exception). After
    patching, ``self.load_model()`` is called. The patch is not reverted
    here.

    Args:
        self: Object exposing ``load_model()``.
        obj: Object (module, class, ...) owning the attribute to patch.
        method_name: Name of the attribute on ``obj`` to wrap.
        expected_error: Exception type to translate.

    Raises:
        ValueError: If ``obj`` has no attribute ``method_name`` (or it is
            ``None``).
        TensorizerCaughtError: If the wrapped call raises ``expected_error``
            while the model is being loaded.
    """
    original = getattr(obj, method_name, None)
    if original is None:
        raise ValueError(f"Method '{method_name}' not found.")

    def wrapper(*args, **kwargs):
        try:
            return original(*args, **kwargs)
        except expected_error as err:
            raise TensorizerCaughtError from err

    setattr(obj, method_name, wrapper)

    self.load_model()
|
||||
|
||||
|
||||
def assert_specific_tensorizer_error_is_raised(
    executor,
    obj: Any,
    method_name: str,
    expected_error: type[Exception],
):
    """Assert that loading via ``executor`` raises ``TensorizerCaughtError``.

    ``patch_init_and_catch_error`` is dispatched through
    ``executor.collective_rpc`` with ``(obj, method_name, expected_error)``
    as its positional arguments.
    """
    rpc_args = (obj, method_name, expected_error)
    with pytest.raises(TensorizerCaughtError):
        executor.collective_rpc(patch_init_and_catch_error, args=rpc_args)
|
||||
|
||||
|
||||
def is_curl_installed():
|
||||
@ -81,11 +124,10 @@ def test_can_deserialize_s3(vllm_runner):
|
||||
|
||||
@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
|
||||
def test_deserialized_encrypted_vllm_model_has_same_outputs(
|
||||
vllm_runner, tmp_path):
|
||||
model_ref, vllm_runner, tmp_path, model_path):
|
||||
args = EngineArgs(model=model_ref)
|
||||
with vllm_runner(model_ref) as vllm_model:
|
||||
model_path = tmp_path / (model_ref + ".tensors")
|
||||
key_path = tmp_path / (model_ref + ".key")
|
||||
key_path = tmp_path / model_ref / "model.key"
|
||||
write_keyfile(key_path)
|
||||
|
||||
outputs = vllm_model.generate(prompts, sampling_params)
|
||||
@ -111,9 +153,9 @@ def test_deserialized_encrypted_vllm_model_has_same_outputs(
|
||||
|
||||
|
||||
def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner,
|
||||
tmp_path):
|
||||
tmp_path, model_ref,
|
||||
model_path):
|
||||
with hf_runner(model_ref) as hf_model:
|
||||
model_path = tmp_path / (model_ref + ".tensors")
|
||||
max_tokens = 50
|
||||
outputs = hf_model.generate_greedy(prompts, max_tokens=max_tokens)
|
||||
with open_stream(model_path, "wb+") as stream:
|
||||
@ -123,7 +165,7 @@ def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner,
|
||||
with vllm_runner(model_ref,
|
||||
load_format="tensorizer",
|
||||
model_loader_extra_config=TensorizerConfig(
|
||||
tensorizer_uri=model_path,
|
||||
tensorizer_uri=str(model_path),
|
||||
num_readers=1,
|
||||
)) as loaded_hf_model:
|
||||
deserialized_outputs = loaded_hf_model.generate_greedy(
|
||||
@ -132,7 +174,7 @@ def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner,
|
||||
assert outputs == deserialized_outputs
|
||||
|
||||
|
||||
def test_load_without_tensorizer_load_format(vllm_runner, capfd):
|
||||
def test_load_without_tensorizer_load_format(vllm_runner, capfd, model_ref):
|
||||
model = None
|
||||
try:
|
||||
model = vllm_runner(
|
||||
@ -150,7 +192,8 @@ def test_load_without_tensorizer_load_format(vllm_runner, capfd):
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def test_raise_value_error_on_invalid_load_format(vllm_runner, capfd):
|
||||
def test_raise_value_error_on_invalid_load_format(vllm_runner, capfd,
|
||||
model_ref):
|
||||
model = None
|
||||
try:
|
||||
model = vllm_runner(
|
||||
@ -208,7 +251,7 @@ def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs(
|
||||
outputs = base_model.generate(prompts, sampling_params)
|
||||
|
||||
# load model with two shards and serialize with encryption
|
||||
model_path = str(tmp_path / (model_ref + "-%02d.tensors"))
|
||||
model_path = str(tmp_path / model_ref / "model-%02d.tensors")
|
||||
key_path = tmp_path / (model_ref + ".key")
|
||||
|
||||
tensorizer_config = TensorizerConfig(
|
||||
@ -242,13 +285,12 @@ def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs(
|
||||
|
||||
|
||||
@pytest.mark.flaky(reruns=3)
|
||||
def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path):
|
||||
def test_vllm_tensorized_model_has_same_outputs(model_ref, vllm_runner,
|
||||
tmp_path, model_path):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
model_ref = "facebook/opt-125m"
|
||||
model_path = tmp_path / (model_ref + ".tensors")
|
||||
config = TensorizerConfig(tensorizer_uri=str(model_path))
|
||||
args = EngineArgs(model=model_ref, device="cuda")
|
||||
args = EngineArgs(model=model_ref)
|
||||
|
||||
with vllm_runner(model_ref) as vllm_model:
|
||||
outputs = vllm_model.generate(prompts, sampling_params)
|
||||
@ -264,3 +306,243 @@ def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path):
|
||||
# noqa: E501
|
||||
|
||||
assert outputs == deserialized_outputs
|
||||
|
||||
|
||||
def test_load_with_just_model_tensors(just_serialize_model_tensors, model_ref):
    """Serve ``model_ref`` using only a ``tensorizer_uri`` extra config."""
    # For backwards compatibility, ensure Tensorizer can still be loaded
    # for inference by passing the model reference name, not a local/S3 dir,
    # and the location of the model tensors
    tensors_uri = f"{just_serialize_model_tensors}/model.tensors"

    ## Start OpenAI API server
    server_args = [
        "--load-format",
        "tensorizer",
        "--model-loader-extra-config",
        json.dumps({"tensorizer_uri": tensors_uri}),
    ]

    # This test only concerns itself with being able to load the model
    # and successfully initialize the server
    with RemoteOpenAIServer(model_ref, server_args):
        pass
|
||||
|
||||
|
||||
def test_assert_serialization_kwargs_passed_to_tensor_serializer(tmp_path):
    """Check that ``serialization_kwargs`` reach ``TensorSerializer.__init__``."""
    serialization_params = {"limit_cpu_concurrency": 2}
    model_ref = "facebook/opt-125m"
    model_path = tmp_path / (model_ref + ".tensors")
    config = TensorizerConfig(tensorizer_uri=str(model_path),
                              serialization_kwargs=serialization_params)
    llm = LLM(model=model_ref)

    def serialization_test(self, *args, **kwargs):
        # This is performed in the ephemeral worker process, so monkey-patching
        # will actually work, and cleanup is guaranteed so don't
        # need to reset things
        expected = serialization_params
        captured = {}

        real_init = tensorizer.serialization.TensorSerializer.__init__

        def recording_init(self, *init_args, **init_kwargs):
            # Remember the keyword arguments of the most recent construction.
            captured.clear()
            captured.update(init_kwargs)
            return real_init(self, *init_args, **init_kwargs)

        tensorizer.serialization.TensorSerializer.__init__ = recording_init

        tensorizer_config = TensorizerConfig(**kwargs["tensorizer_config"])
        self.save_tensorized_model(tensorizer_config=tensorizer_config)

        # True only when every expected item is among the captured kwargs.
        return (captured | expected) == captured

    rpc_kwargs = {"tensorizer_config": config.to_serializable()}

    assert assert_from_collective_rpc(llm, serialization_test, rpc_kwargs)
|
||||
|
||||
|
||||
def test_assert_deserialization_kwargs_passed_to_tensor_deserializer(
        tmp_path, capfd):
    """An illegal ``deserialization_kwargs`` value must surface on load.

    The model is first tensorized, then loaded with ``num_readers`` set to an
    illegal value; ``TensorDeserializer.__init__`` is expected to raise
    ``TypeError``.
    """
    model_ref = "facebook/opt-125m"
    model_path = tmp_path / (model_ref + ".tensors")

    # Serialize the model so there is something to deserialize.
    save_config = TensorizerConfig(
        tensorizer_uri=str(model_path),
        serialization_kwargs={"limit_cpu_concurrency": 2},
    )
    tensorize_vllm_model(EngineArgs(model=model_ref), save_config)

    loader_tc = TensorizerConfig(
        tensorizer_uri=str(model_path),
        deserialization_kwargs={"num_readers": "bar"},  # illegal value
    )
    engine_args = EngineArgs(
        model=model_ref,
        load_format="tensorizer",
        model_loader_extra_config=loader_tc.to_serializable(),
    )
    executor = DummyExecutor(engine_args.create_engine_config())

    assert_specific_tensorizer_error_is_raised(
        executor,
        tensorizer.serialization.TensorDeserializer,
        "__init__",
        TypeError,
    )
|
||||
|
||||
|
||||
def test_assert_stream_kwargs_passed_to_tensor_deserializer(tmp_path, capfd):
    """An illegal ``stream_kwargs`` value must surface on load.

    The model is first tensorized, then loaded with ``stream_kwargs`` holding
    ``{"mode": "foo"}``; ``open_stream`` is expected to raise ``ValueError``.
    """
    model_ref = "facebook/opt-125m"
    model_path = tmp_path / (model_ref + ".tensors")

    # Serialize the model so there is something to deserialize.
    save_config = TensorizerConfig(
        tensorizer_uri=str(model_path),
        serialization_kwargs={"limit_cpu_concurrency": 2},
    )
    tensorize_vllm_model(EngineArgs(model=model_ref), save_config)

    loader_tc = TensorizerConfig(
        tensorizer_uri=str(model_path),
        deserialization_kwargs={"num_readers": 1},
        stream_kwargs={"mode": "foo"},
    )
    engine_args = EngineArgs(
        model=model_ref,
        load_format="tensorizer",
        model_loader_extra_config=loader_tc.to_serializable(),
    )
    executor = DummyExecutor(engine_args.create_engine_config())

    assert_specific_tensorizer_error_is_raised(
        executor,
        vllm.model_executor.model_loader.tensorizer,
        "open_stream",
        ValueError,
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_serialize_and_serve_entrypoints(tmp_path):
    """Serialize a model with the example script, then serve it.

    The test passes once the server process prints
    ``Application startup complete.`` within 180 seconds.
    """
    model_ref = "facebook/opt-125m"

    suffix = "test"
    try:
        result = subprocess.run([
            sys.executable,
            f"{VLLM_PATH}/examples/others/tensorize_vllm_model.py", "--model",
            model_ref, "serialize", "--serialized-directory",
            str(tmp_path), "--suffix", suffix, "--serialization-kwargs",
            '{"limit_cpu_concurrency": 4}'
        ],
                                check=True,
                                capture_output=True,
                                text=True)
    except subprocess.CalledProcessError as e:
        print("Tensorizing failed.")
        print("STDOUT:\n", e.stdout)
        print("STDERR:\n", e.stderr)
        raise

    assert "Successfully serialized" in result.stdout

    # Next, try to serve with vllm serve
    model_uri = tmp_path / "vllm" / model_ref / suffix / "model.tensors"

    model_loader_extra_config = {
        "tensorizer_uri": str(model_uri),
        "stream_kwargs": {
            "force_http": False,
        },
        "deserialization_kwargs": {
            "verify_hash": True,
            "num_readers": 8,
        }
    }

    cmd = [
        "-m", "vllm.entrypoints.cli.main", "serve", "--host", "localhost",
        "--load-format", "tensorizer", model_ref,
        "--model-loader-extra-config",
        json.dumps(model_loader_extra_config, indent=2)
    ]

    proc = await asyncio.create_subprocess_exec(
        sys.executable,
        *cmd,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.STDOUT,
    )

    assert proc.stdout is not None
    startup_marker = b"Application startup complete."

    async def wait_for_startup(stream) -> bool:
        # Consume output line by line rather than with readuntil():
        # readuntil() raises LimitOverrunError once the buffered output
        # exceeds the stream limit before the marker appears, and
        # IncompleteReadError if the server exits first. Reading lines keeps
        # the buffer drained and turns EOF into a plain False.
        async for line in stream:
            if startup_marker in line:
                return True
        return False

    try:
        started = await asyncio.wait_for(wait_for_startup(proc.stdout), 180)
    except asyncio.TimeoutError:
        pytest.fail("Server did not start successfully")
    else:
        if not started:
            pytest.fail("Server exited before startup completed")
    finally:
        try:
            proc.terminate()
        except ProcessLookupError:
            # The server already exited on its own; nothing to terminate.
            pass
        await proc.communicate()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("illegal_value", BLACKLISTED_TENSORIZER_ARGS)
def test_blacklisted_parameter_for_loading(tmp_path, vllm_runner, capfd,
                                           illegal_value):
    """Loading with a blacklisted Tensorizer argument must be rejected.

    The model is tensorized first, then loaded with an extra config that
    contains ``illegal_value`` as a key. The load must raise ``RuntimeError``
    and the captured output must name the offending argument.
    """
    serialization_params = {
        "limit_cpu_concurrency": 2,
    }

    model_ref = "facebook/opt-125m"
    model_path = tmp_path / (model_ref + ".tensors")
    config = TensorizerConfig(tensorizer_uri=str(model_path),
                              serialization_kwargs=serialization_params)

    args = EngineArgs(model=model_ref)
    tensorize_vllm_model(args, config)

    loader_tc = {"tensorizer_uri": str(model_path), illegal_value: "foo"}

    # pytest.raises (instead of try/except) makes the test fail when the
    # blacklisted argument is silently accepted and nothing is raised.
    with pytest.raises(RuntimeError):
        vllm_runner(
            model_ref,
            load_format="tensorizer",
            model_loader_extra_config=loader_tc,
        )

    out, err = capfd.readouterr()
    combined_output = out + err
    assert (f"ValueError: {illegal_value} is not an allowed "
            f"Tensorizer argument.") in combined_output
|
||||
|
||||
@ -686,8 +686,11 @@ class ModelConfig:
|
||||
|
||||
# If tokenizer is same as model, download to same directory
|
||||
if model == tokenizer:
|
||||
s3_model.pull_files(
|
||||
model, ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
|
||||
s3_model.pull_files(model,
|
||||
ignore_pattern=[
|
||||
"*.pt", "*.safetensors", "*.bin",
|
||||
"*.tensors"
|
||||
])
|
||||
self.tokenizer = s3_model.dir
|
||||
return
|
||||
|
||||
@ -695,7 +698,8 @@ class ModelConfig:
|
||||
if is_s3(tokenizer):
|
||||
s3_tokenizer = S3Model()
|
||||
s3_tokenizer.pull_files(
|
||||
model, ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
|
||||
model,
|
||||
ignore_pattern=["*.pt", "*.safetensors", "*.bin", "*.tensors"])
|
||||
self.tokenizer = s3_tokenizer.dir
|
||||
|
||||
def _init_multimodal_config(self) -> Optional["MultiModalConfig"]:
|
||||
|
||||
@ -58,7 +58,8 @@ def parse_type(return_type: Callable[[str], T]) -> Callable[[str], T]:
|
||||
|
||||
def _parse_type(val: str) -> T:
|
||||
try:
|
||||
if return_type is json.loads and not re.match("^{.*}$", val):
|
||||
if return_type is json.loads and not re.match(
|
||||
r"(?s)^\s*{.*}\s*$", val):
|
||||
return cast(T, nullable_kvs(val))
|
||||
return return_type(val)
|
||||
except ValueError as e:
|
||||
@ -80,7 +81,7 @@ def optional_type(
|
||||
|
||||
|
||||
def union_dict_and_str(val: str) -> Optional[Union[str, dict[str, str]]]:
|
||||
if not re.match("^{.*}$", val):
|
||||
if not re.match(r"(?s)^\s*{.*}\s*$", val):
|
||||
return str(val)
|
||||
return optional_type(json.loads)(val)
|
||||
|
||||
@ -1001,11 +1002,42 @@ class EngineArgs:
|
||||
override_attention_dtype=self.override_attention_dtype,
|
||||
)
|
||||
|
||||
def valid_tensorizer_config_provided(self) -> bool:
    """Report whether no Tensorizer location was given in the extra config.

    Returns ``False`` as soon as ``self.model_loader_extra_config`` contains
    a ``tensorizer_uri`` or ``tensorizer_dir`` key, and ``True`` otherwise,
    including when the extra config is empty or unset. Note that, despite
    the method name, ``True`` therefore means that neither key was supplied.

    Side effect: an extra config that exposes ``to_serializable()`` (e.g. a
    TensorizerConfig object) is replaced on ``self`` by the result of that
    call before it is inspected.
    """
    extra_config = self.model_loader_extra_config
    if not extra_config:
        return True

    if hasattr(extra_config, "to_serializable"):
        extra_config = extra_config.to_serializable()
        self.model_loader_extra_config = extra_config

    for key in ("tensorizer_uri", "tensorizer_dir"):
        try:
            extra_config[key]
        except KeyError:
            continue
        return False
    return True
|
||||
|
||||
def create_load_config(self) -> LoadConfig:
|
||||
|
||||
if self.quantization == "bitsandbytes":
|
||||
self.load_format = "bitsandbytes"
|
||||
|
||||
if (self.load_format == "tensorizer"
|
||||
and self.valid_tensorizer_config_provided()):
|
||||
logger.info("Inferring Tensorizer args from %s", self.model)
|
||||
self.model_loader_extra_config = {"tensorizer_dir": self.model}
|
||||
else:
|
||||
logger.info(
|
||||
"Using Tensorizer args from --model-loader-extra-config. "
|
||||
"Note that you can now simply pass the S3 directory in the "
|
||||
"model tag instead of providing the JSON string.")
|
||||
|
||||
return LoadConfig(
|
||||
load_format=self.load_format,
|
||||
download_dir=self.download_dir,
|
||||
|
||||
@ -245,9 +245,10 @@ class LoRAModel(AdapterModel):
|
||||
lora_tensor_path = os.path.join(tensorizer_config.tensorizer_dir,
|
||||
"adapter_model.tensors")
|
||||
tensorizer_args = tensorizer_config._construct_tensorizer_args()
|
||||
tensors = TensorDeserializer(lora_tensor_path,
|
||||
dtype=tensorizer_config.dtype,
|
||||
**tensorizer_args.deserializer_params)
|
||||
tensors = TensorDeserializer(
|
||||
lora_tensor_path,
|
||||
dtype=tensorizer_config.dtype,
|
||||
**tensorizer_args.deserialization_kwargs)
|
||||
check_unexpected_modules(tensors)
|
||||
|
||||
elif os.path.isfile(lora_tensor_path):
|
||||
|
||||
@ -106,7 +106,7 @@ class PEFTHelper:
|
||||
"adapter_config.json")
|
||||
with open_stream(lora_config_path,
|
||||
mode="rb",
|
||||
**tensorizer_args.stream_params) as f:
|
||||
**tensorizer_args.stream_kwargs) as f:
|
||||
config = json.load(f)
|
||||
|
||||
logger.info("Successfully deserialized LoRA config from %s",
|
||||
|
||||
@ -5,18 +5,18 @@ import argparse
|
||||
import contextlib
|
||||
import contextvars
|
||||
import dataclasses
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Any, BinaryIO, Optional, Union
|
||||
from collections.abc import Generator, MutableMapping
|
||||
from dataclasses import asdict, dataclass, field, fields
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union
|
||||
|
||||
import regex as re
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
from torch import nn
|
||||
from torch.utils._python_dispatch import TorchDispatchMode
|
||||
from transformers import PretrainedConfig
|
||||
@ -39,10 +39,6 @@ try:
|
||||
from tensorizer.utils import (convert_bytes, get_mem_usage,
|
||||
no_init_or_tensor)
|
||||
|
||||
_read_stream, _write_stream = (partial(
|
||||
open_stream,
|
||||
mode=mode,
|
||||
) for mode in ("rb", "wb+"))
|
||||
except ImportError:
|
||||
tensorizer = PlaceholderModule("tensorizer")
|
||||
DecryptionParams = tensorizer.placeholder_attr("DecryptionParams")
|
||||
@ -54,9 +50,6 @@ except ImportError:
|
||||
get_mem_usage = tensorizer.placeholder_attr("utils.get_mem_usage")
|
||||
no_init_or_tensor = tensorizer.placeholder_attr("utils.no_init_or_tensor")
|
||||
|
||||
_read_stream = tensorizer.placeholder_attr("_read_stream")
|
||||
_write_stream = tensorizer.placeholder_attr("_write_stream")
|
||||
|
||||
__all__ = [
|
||||
'EncryptionParams', 'DecryptionParams', 'TensorDeserializer',
|
||||
'TensorSerializer', 'open_stream', 'convert_bytes', 'get_mem_usage',
|
||||
@ -66,6 +59,23 @@ __all__ = [
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def is_valid_deserialization_uri(uri: Optional[str]) -> bool:
    """Return whether ``uri`` can plausibly be opened for deserialization.

    A URI is accepted when it carries an explicit ``s3://``, ``http://`` or
    ``https://`` scheme (case-insensitive), or when it is a path that
    exists on the local filesystem. ``None`` and the empty string are
    rejected.

    Args:
        uri: Remote URI or local path to check.

    Returns:
        ``True`` if the URI has a supported remote scheme or exists
        locally, ``False`` otherwise.
    """
    if not uri:
        return False

    # Only treat the prefix as a scheme when "://" is really present;
    # otherwise a bare relative path such as "s3" would be accepted
    # without ever being checked for existence.
    scheme, separator, _ = uri.partition("://")
    if separator and scheme.lower() in {"s3", "http", "https"}:
        return True
    return os.path.exists(uri)
|
||||
|
||||
|
||||
def tensorizer_kwargs_arg(value):
    """Parse a JSON string into a dict of Tensorizer keyword arguments.

    Written to be used as an argparse ``type=`` callable, so every invalid
    input is reported as ``argparse.ArgumentTypeError``.

    Args:
        value: JSON text expected to decode to an object.

    Returns:
        The decoded dictionary.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not valid JSON, or
            decodes to something other than a dict.
    """
    try:
        loaded = json.loads(value)
    except json.JSONDecodeError as err:
        # Surface malformed JSON the same way as the non-dict case instead
        # of leaking a raw decoder error to the caller.
        raise argparse.ArgumentTypeError(
            f"Not valid JSON: {value}. serialization_kwargs and "
            "deserialization_kwargs must be "
            "deserializable from a JSON string to a dictionary.") from err

    if not isinstance(loaded, dict):
        raise argparse.ArgumentTypeError(
            f"Not deserializable to dict: {value}. serialization_kwargs and "
            "deserialization_kwargs must be "
            "deserializable from a JSON string to a dictionary. ")
    return loaded
|
||||
|
||||
|
||||
class MetaTensorMode(TorchDispatchMode):
|
||||
|
||||
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
||||
@ -137,101 +147,45 @@ class _NoInitOrTensorImpl:
|
||||
|
||||
|
||||
@dataclass
|
||||
class TensorizerConfig:
|
||||
tensorizer_uri: Union[str, None] = None
|
||||
vllm_tensorized: Optional[bool] = False
|
||||
verify_hash: Optional[bool] = False
|
||||
class TensorizerConfig(MutableMapping):
|
||||
tensorizer_uri: Optional[str] = None
|
||||
tensorizer_dir: Optional[str] = None
|
||||
vllm_tensorized: Optional[bool] = None
|
||||
verify_hash: Optional[bool] = None
|
||||
num_readers: Optional[int] = None
|
||||
encryption_keyfile: Optional[str] = None
|
||||
s3_access_key_id: Optional[str] = None
|
||||
s3_secret_access_key: Optional[str] = None
|
||||
s3_endpoint: Optional[str] = None
|
||||
model_class: Optional[type[torch.nn.Module]] = None
|
||||
hf_config: Optional[PretrainedConfig] = None
|
||||
dtype: Optional[Union[str, torch.dtype]] = None
|
||||
lora_dir: Optional[str] = None
|
||||
_is_sharded: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
# check if the configuration is for a sharded vLLM model
|
||||
self._is_sharded = isinstance(self.tensorizer_uri, str) \
|
||||
and re.search(r'%0\dd', self.tensorizer_uri) is not None
|
||||
if not self.tensorizer_uri and not self.lora_dir:
|
||||
raise ValueError("tensorizer_uri must be provided.")
|
||||
if not self.tensorizer_uri and self.lora_dir:
|
||||
self.tensorizer_uri = f"{self.lora_dir}/adapter_model.tensors"
|
||||
assert self.tensorizer_uri is not None, ("tensorizer_uri must be "
|
||||
"provided.")
|
||||
self.tensorizer_dir = os.path.dirname(self.tensorizer_uri)
|
||||
self.lora_dir = self.tensorizer_dir
|
||||
|
||||
@classmethod
|
||||
def as_dict(cls, *args, **kwargs) -> dict[str, Any]:
|
||||
cfg = TensorizerConfig(*args, **kwargs)
|
||||
return dataclasses.asdict(cfg)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return dataclasses.asdict(self)
|
||||
|
||||
def _construct_tensorizer_args(self) -> "TensorizerArgs":
|
||||
tensorizer_args = {
|
||||
"tensorizer_uri": self.tensorizer_uri,
|
||||
"vllm_tensorized": self.vllm_tensorized,
|
||||
"verify_hash": self.verify_hash,
|
||||
"num_readers": self.num_readers,
|
||||
"encryption_keyfile": self.encryption_keyfile,
|
||||
"s3_access_key_id": self.s3_access_key_id,
|
||||
"s3_secret_access_key": self.s3_secret_access_key,
|
||||
"s3_endpoint": self.s3_endpoint,
|
||||
}
|
||||
return TensorizerArgs(**tensorizer_args) # type: ignore
|
||||
|
||||
def verify_with_parallel_config(
|
||||
self,
|
||||
parallel_config: "ParallelConfig",
|
||||
) -> None:
|
||||
if parallel_config.tensor_parallel_size > 1 \
|
||||
and not self._is_sharded:
|
||||
raise ValueError(
|
||||
"For a sharded model, tensorizer_uri should include a"
|
||||
" string format template like '%04d' to be formatted"
|
||||
" with the rank of the shard")
|
||||
|
||||
def verify_with_model_config(self, model_config: "ModelConfig") -> None:
|
||||
if (model_config.quantization is not None
|
||||
and self.tensorizer_uri is not None):
|
||||
logger.warning(
|
||||
"Loading a model using Tensorizer with quantization on vLLM"
|
||||
" is unstable and may lead to errors.")
|
||||
|
||||
def open_stream(self, tensorizer_args: Optional["TensorizerArgs"] = None):
|
||||
if tensorizer_args is None:
|
||||
tensorizer_args = self._construct_tensorizer_args()
|
||||
|
||||
return open_stream(self.tensorizer_uri,
|
||||
**tensorizer_args.stream_params)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TensorizerArgs:
|
||||
tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, BinaryIO, str,
|
||||
bytes, os.PathLike, int]
|
||||
vllm_tensorized: Optional[bool] = False
|
||||
verify_hash: Optional[bool] = False
|
||||
num_readers: Optional[int] = None
|
||||
encryption_keyfile: Optional[str] = None
|
||||
s3_access_key_id: Optional[str] = None
|
||||
s3_secret_access_key: Optional[str] = None
|
||||
s3_endpoint: Optional[str] = None
|
||||
stream_kwargs: Optional[dict[str, Any]] = None
|
||||
serialization_kwargs: Optional[dict[str, Any]] = None
|
||||
deserialization_kwargs: Optional[dict[str, Any]] = None
|
||||
_extra_serialization_attrs: Optional[dict[str, Any]] = field(init=False,
|
||||
default=None)
|
||||
model_class: Optional[type[torch.nn.Module]] = field(init=False,
|
||||
default=None)
|
||||
hf_config: Optional[PretrainedConfig] = field(init=False, default=None)
|
||||
dtype: Optional[Union[str, torch.dtype]] = field(init=False, default=None)
|
||||
_is_sharded: bool = field(init=False, default=False)
|
||||
_fields: ClassVar[tuple[str, ...]]
|
||||
_keys: ClassVar[frozenset[str]]
|
||||
"""
|
||||
Args for the TensorizerAgent class. These are used to configure the behavior
|
||||
of the TensorDeserializer when loading tensors from a serialized model.
|
||||
|
||||
Args:
|
||||
Args for the TensorizerConfig class. These are used to configure the
|
||||
behavior of model serialization and deserialization using Tensorizer.
|
||||
|
||||
Args:
|
||||
tensorizer_uri: Path to serialized model tensors. Can be a local file
|
||||
path or a S3 URI. This is a required field unless lora_dir is
|
||||
provided and the config is meant to be used for the
|
||||
`tensorize_lora_adapter` function.
|
||||
`tensorize_lora_adapter` function. Unless a `tensorizer_dir` or
|
||||
`lora_dir` is passed to this object's initializer, this is a required
|
||||
argument.
|
||||
tensorizer_dir: Path to a directory containing serialized model tensors,
|
||||
and all other potential model artifacts to load the model, such as
|
||||
configs and tokenizer files. Can be passed instead of `tensorizer_uri`
|
||||
where the `model.tensors` file will be assumed to be in this
|
||||
directory.
|
||||
vllm_tensorized: If True, indicates that the serialized model is a
|
||||
vLLM model. This is used to determine the behavior of the
|
||||
TensorDeserializer when loading tensors from a serialized model.
|
||||
@ -256,34 +210,174 @@ class TensorizerArgs:
|
||||
be set via the S3_SECRET_ACCESS_KEY environment variable.
|
||||
s3_endpoint: The endpoint for the S3 bucket. Can also be set via the
|
||||
S3_ENDPOINT_URL environment variable.
|
||||
lora_dir: Path to a directory containing LoRA adapter artifacts for
|
||||
serialization or deserialization. When serializing LoRA adapters
|
||||
this is the only necessary parameter to pass to this object's
|
||||
initializer.
|
||||
"""
|
||||
|
||||
def __post_init__(self):
    """Validate the location arguments and resolve ``tensorizer_uri``.

    Exactly one way of locating the tensors is accepted:
    ``tensorizer_uri``, ``tensorizer_dir`` or ``lora_dir``. When no URI is
    given it is derived as ``<lora_dir>/adapter_model.tensors`` or
    ``<tensorizer_dir>/model.tensors``; when a URI is given,
    ``tensorizer_dir`` is set to its parent directory. Unset
    ``serialization_kwargs`` / ``deserialization_kwargs`` become empty
    dicts.

    Raises:
        ValueError: If ``tensorizer_dir`` is combined with
            ``tensorizer_uri`` or with ``lora_dir``, or if none of the
            three locations is provided.
    """
    # check if the configuration is for a sharded vLLM model
    # NOTE: this looks only at the URI as passed in, before any URI is
    # derived from tensorizer_dir / lora_dir below.
    self._is_sharded = isinstance(self.tensorizer_uri, str) \
        and re.search(r'%0\dd', self.tensorizer_uri) is not None

    if self.tensorizer_dir and self.tensorizer_uri:
        raise ValueError(
            "Either tensorizer_dir or tensorizer_uri must be provided, "
            "not both.")
    if self.tensorizer_dir and self.lora_dir:
        raise ValueError(
            "Only one of tensorizer_dir or lora_dir may be specified. "
            "Use lora_dir exclusively when serializing LoRA adapters, "
            "and tensorizer_dir or tensorizer_uri otherwise.")
    if not self.tensorizer_uri:
        if self.lora_dir:
            self.tensorizer_uri = f"{self.lora_dir}/adapter_model.tensors"
        elif self.tensorizer_dir:
            self.tensorizer_uri = f"{self.tensorizer_dir}/model.tensors"
        else:
            raise ValueError("Unable to resolve tensorizer_uri. "
                             "A valid tensorizer_uri or tensorizer_dir "
                             "must be provided for deserialization, and a "
                             "valid tensorizer_uri, tensorizer_dir, or "
                             "lora_dir for serialization.")
    else:
        self.tensorizer_dir = os.path.dirname(self.tensorizer_uri)

    if not self.serialization_kwargs:
        self.serialization_kwargs = {}
    if not self.deserialization_kwargs:
        self.deserialization_kwargs = {}
|
||||
|
||||
def to_serializable(self) -> dict[str, Any]:
    """Return a plain dict that can re-initialize an equivalent config.

    The result is meant to satisfy
    ``TensorizerConfig(**cfg.to_serializable())``, so it keeps only
    values that are valid initializer input:

    * keys starting with ``_`` (private state) are dropped;
    * ``None`` values (unset parameters) are dropped;
    * ``tensorizer_dir`` is always dropped, because ``__post_init__``
      rejects it alongside ``tensorizer_uri`` or ``lora_dir`` and
      re-derives it from ``tensorizer_uri`` anyway.

    Returns:
        A new dict of the remaining public, non-``None`` fields.
    """
    # asdict() emits every dataclass field, so membership tests on its
    # keys carry no information; the exclusion is stated directly instead.
    excluded = {"tensorizer_dir"}
    return {
        name: value
        for name, value in asdict(self).items()
        if name not in excluded and not name.startswith("_")
        and value is not None
    }
|
||||
|
||||
def _construct_tensorizer_args(self) -> "TensorizerArgs":
    """Build a ``TensorizerArgs`` instance from this config."""
    tensorizer_args = TensorizerArgs(self)  # type: ignore
    return tensorizer_args
|
||||
|
||||
def verify_with_parallel_config(
    self,
    parallel_config: "ParallelConfig",
) -> None:
    """Reject tensor-parallel runs whose URI is not a shard template.

    Args:
        parallel_config: Object exposing ``tensor_parallel_size``.

    Raises:
        ValueError: If ``tensor_parallel_size`` is greater than 1 while
            this config is not flagged as sharded.
    """
    # Nothing to verify for a single rank or an already-sharded config.
    if not parallel_config.tensor_parallel_size > 1 or self._is_sharded:
        return
    raise ValueError(
        "For a sharded model, tensorizer_uri should include a"
        " string format template like '%04d' to be formatted"
        " with the rank of the shard")
|
||||
|
||||
def verify_with_model_config(self, model_config: "ModelConfig") -> None:
    """Log a warning when quantization is combined with a Tensorizer URI.

    Args:
        model_config: Object exposing a ``quantization`` attribute.
    """
    quantized = model_config.quantization is not None
    if quantized and self.tensorizer_uri is not None:
        logger.warning(
            "Loading a model using Tensorizer with quantization on vLLM"
            " is unstable and may lead to errors.")
|
||||
|
||||
def open_stream(self, tensorizer_args: Optional["TensorizerArgs"] = None):
    """Open a stream on ``self.tensorizer_uri``.

    Args:
        tensorizer_args: Source of the ``stream_kwargs`` forwarded to the
            module-level ``open_stream``. Built from this config when
            omitted.

    Returns:
        Whatever the module-level ``open_stream`` returns.
    """
    args = (tensorizer_args if tensorizer_args is not None else
            self._construct_tensorizer_args())
    return open_stream(self.tensorizer_uri, **args.stream_kwargs)
|
||||
|
||||
def keys(self):
    """Return the collection of valid mapping keys stored in ``_keys``."""
    valid_keys = self._keys
    return valid_keys
|
||||
|
||||
def __len__(self):
    """Return the number of mapping keys.

    Counted from ``_fields``, the same tuple ``__iter__`` walks, so the
    reported length always matches the number of iterated keys.
    """
    return len(self._fields)
|
||||
|
||||
def __iter__(self):
    """Iterate over the key names stored in ``_fields``."""
    names = self._fields
    return iter(names)
|
||||
|
||||
def __getitem__(self, item: str) -> Any:
    """Return the attribute named ``item``.

    Raises:
        KeyError: If ``item`` is not one of ``self.keys()``.
    """
    if item in self.keys():
        return getattr(self, item)
    raise KeyError(item)
|
||||
|
||||
def __setitem__(self, key: str, value: Any) -> None:
    """Assign ``value`` to the attribute named ``key``.

    Raises:
        KeyError: If ``key`` is not one of ``self.keys()``; unknown keys
            may not be created through the mapping interface.
    """
    if key in self.keys():
        setattr(self, key, value)
        return
    # Disallow modifying invalid keys
    raise KeyError(key)
|
||||
|
||||
def __delitem__(self, key, /):
    """Delete the attribute named ``key``.

    Raises:
        KeyError: If ``key`` is not one of ``self.keys()``.
    """
    if key in self.keys():
        delattr(self, key)
        return
    raise KeyError(key)
|
||||
|
||||
|
||||
TensorizerConfig._fields = tuple(f.name for f in fields(TensorizerConfig))
|
||||
TensorizerConfig._keys = frozenset(TensorizerConfig._fields)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TensorizerArgs:
|
||||
tensorizer_uri: Optional[str] = None
|
||||
tensorizer_dir: Optional[str] = None
|
||||
encryption_keyfile: Optional[str] = None
|
||||
|
||||
def __init__(self, tensorizer_config: TensorizerConfig):
|
||||
for k, v in tensorizer_config.items():
|
||||
setattr(self, k, v)
|
||||
self.file_obj = tensorizer_config.tensorizer_uri
|
||||
self.s3_access_key_id = (tensorizer_config.s3_access_key_id
|
||||
or envs.S3_ACCESS_KEY_ID)
|
||||
self.s3_secret_access_key = (tensorizer_config.s3_secret_access_key
|
||||
or envs.S3_SECRET_ACCESS_KEY)
|
||||
self.s3_endpoint = self.s3_endpoint or envs.S3_ENDPOINT_URL
|
||||
self.stream_params = {
|
||||
"s3_access_key_id": self.s3_access_key_id,
|
||||
"s3_secret_access_key": self.s3_secret_access_key,
|
||||
"s3_endpoint": self.s3_endpoint,
|
||||
self.s3_endpoint = tensorizer_config.s3_endpoint or envs.S3_ENDPOINT_URL
|
||||
|
||||
self.stream_kwargs = {
|
||||
"s3_access_key_id": tensorizer_config.s3_access_key_id,
|
||||
"s3_secret_access_key": tensorizer_config.s3_secret_access_key,
|
||||
"s3_endpoint": tensorizer_config.s3_endpoint,
|
||||
**(tensorizer_config.stream_kwargs or {})
|
||||
}
|
||||
|
||||
self.deserializer_params = {
|
||||
"verify_hash": self.verify_hash,
|
||||
"encryption": self.encryption_keyfile,
|
||||
"num_readers": self.num_readers
|
||||
self.deserialization_kwargs = {
|
||||
"verify_hash": tensorizer_config.verify_hash,
|
||||
"encryption": tensorizer_config.encryption_keyfile,
|
||||
"num_readers": tensorizer_config.num_readers,
|
||||
**(tensorizer_config.deserialization_kwargs or {})
|
||||
}
|
||||
|
||||
if self.encryption_keyfile:
|
||||
with open_stream(
|
||||
self.encryption_keyfile,
|
||||
**self.stream_params,
|
||||
tensorizer_config.encryption_keyfile,
|
||||
**self.stream_kwargs,
|
||||
) as stream:
|
||||
key = stream.read()
|
||||
decryption_params = DecryptionParams.from_key(key)
|
||||
self.deserializer_params['encryption'] = decryption_params
|
||||
self.deserialization_kwargs['encryption'] = decryption_params
|
||||
|
||||
@staticmethod
|
||||
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
|
||||
@ -405,15 +499,22 @@ def init_tensorizer_model(tensorizer_config: TensorizerConfig,
|
||||
def deserialize_tensorizer_model(model: nn.Module,
|
||||
tensorizer_config: TensorizerConfig) -> None:
|
||||
tensorizer_args = tensorizer_config._construct_tensorizer_args()
|
||||
if not is_valid_deserialization_uri(tensorizer_config.tensorizer_uri):
|
||||
raise ValueError(
|
||||
f"{tensorizer_config.tensorizer_uri} is not a valid "
|
||||
f"tensorizer URI. Please check that the URI is correct. "
|
||||
f"It must either point to a local existing file, or have a "
|
||||
f"S3, HTTP or HTTPS scheme.")
|
||||
before_mem = get_mem_usage()
|
||||
start = time.perf_counter()
|
||||
with _read_stream(
|
||||
with open_stream(
|
||||
tensorizer_config.tensorizer_uri,
|
||||
**tensorizer_args.stream_params) as stream, TensorDeserializer(
|
||||
mode="rb",
|
||||
**tensorizer_args.stream_kwargs) as stream, TensorDeserializer(
|
||||
stream,
|
||||
dtype=tensorizer_config.dtype,
|
||||
device=f'cuda:{torch.cuda.current_device()}',
|
||||
**tensorizer_args.deserializer_params) as deserializer:
|
||||
device=torch.device("cuda", torch.cuda.current_device()),
|
||||
**tensorizer_args.deserialization_kwargs) as deserializer:
|
||||
deserializer.load_into_module(model)
|
||||
end = time.perf_counter()
|
||||
|
||||
@ -442,9 +543,9 @@ def tensorizer_weights_iterator(
|
||||
"examples/others/tensorize_vllm_model.py example script "
|
||||
"for serializing vLLM models.")
|
||||
|
||||
deserializer_args = tensorizer_args.deserializer_params
|
||||
stream_params = tensorizer_args.stream_params
|
||||
stream = open_stream(tensorizer_args.tensorizer_uri, **stream_params)
|
||||
deserializer_args = tensorizer_args.deserialization_kwargs
|
||||
stream_kwargs = tensorizer_args.stream_kwargs
|
||||
stream = open_stream(tensorizer_args.tensorizer_uri, **stream_kwargs)
|
||||
with TensorDeserializer(stream, **deserializer_args,
|
||||
device="cpu") as state:
|
||||
yield from state.items()
|
||||
@ -465,8 +566,8 @@ def is_vllm_tensorized(tensorizer_config: "TensorizerConfig") -> bool:
|
||||
"""
|
||||
tensorizer_args = tensorizer_config._construct_tensorizer_args()
|
||||
deserializer = TensorDeserializer(open_stream(
|
||||
tensorizer_args.tensorizer_uri, **tensorizer_args.stream_params),
|
||||
**tensorizer_args.deserializer_params,
|
||||
tensorizer_args.tensorizer_uri, **tensorizer_args.stream_kwargs),
|
||||
**tensorizer_args.deserialization_kwargs,
|
||||
lazy_load=True)
|
||||
if tensorizer_config.vllm_tensorized:
|
||||
logger.warning(
|
||||
@ -477,13 +578,41 @@ def is_vllm_tensorized(tensorizer_config: "TensorizerConfig") -> bool:
|
||||
return ".vllm_tensorized_marker" in deserializer
|
||||
|
||||
|
||||
def serialize_extra_artifacts(
        tensorizer_args: TensorizerArgs,
        served_model_name: Union[str, list[str], None]) -> None:
    """Copy a model's non-weight files next to the serialized tensors.

    Downloads a snapshot of ``served_model_name`` into a temporary
    directory (skipping weight, cache, ``.gitattributes`` and markdown
    files) and writes every regular file found at its top level to
    ``<tensorizer_args.tensorizer_dir>/<file name>``.

    Args:
        tensorizer_args: Supplies ``tensorizer_dir`` and ``stream_kwargs``.
        served_model_name: Identifier handed to ``snapshot_download``.
            NOTE(review): presumably this must be a resolvable repo id
            rather than an arbitrary serving alias; confirm with callers.

    Raises:
        ValueError: If ``served_model_name`` is not a ``str``.
    """
    if not isinstance(served_model_name, str):
        raise ValueError(
            f"served_model_name must be a str for serialize_extra_artifacts, "
            f"not {type(served_model_name)}.")

    skipped_patterns = [
        "*.pt", "*.safetensors", "*.bin", "*.cache", "*.gitattributes",
        "*.md"
    ]
    with tempfile.TemporaryDirectory() as staging_dir:
        snapshot_download(served_model_name,
                          local_dir=staging_dir,
                          ignore_patterns=skipped_patterns)
        for entry in os.scandir(staging_dir):
            # Only top-level regular files are copied.
            if entry.is_file():
                with open(entry.path, "rb") as source, open_stream(
                        f"{tensorizer_args.tensorizer_dir}/{entry.name}",
                        mode="wb+",
                        **tensorizer_args.stream_kwargs) as destination:
                    logger.info("Writing artifact %s", entry.name)
                    destination.write(source.read())
|
||||
|
||||
|
||||
def serialize_vllm_model(
|
||||
model: nn.Module,
|
||||
tensorizer_config: TensorizerConfig,
|
||||
model_config: "ModelConfig",
|
||||
) -> nn.Module:
|
||||
model.register_parameter(
|
||||
"vllm_tensorized_marker",
|
||||
nn.Parameter(torch.tensor((1, ), device="meta"), requires_grad=False))
|
||||
|
||||
tensorizer_args = tensorizer_config._construct_tensorizer_args()
|
||||
|
||||
encryption_params = None
|
||||
@ -497,10 +626,16 @@ def serialize_vllm_model(
|
||||
from vllm.distributed import get_tensor_model_parallel_rank
|
||||
output_file = output_file % get_tensor_model_parallel_rank()
|
||||
|
||||
with _write_stream(output_file, **tensorizer_args.stream_params) as stream:
|
||||
serializer = TensorSerializer(stream, encryption=encryption_params)
|
||||
with open_stream(output_file, mode="wb+",
|
||||
**tensorizer_args.stream_kwargs) as stream:
|
||||
serializer = TensorSerializer(stream,
|
||||
encryption=encryption_params,
|
||||
**tensorizer_config.serialization_kwargs)
|
||||
serializer.write_module(model)
|
||||
serializer.close()
|
||||
|
||||
serialize_extra_artifacts(tensorizer_args, model_config.served_model_name)
|
||||
|
||||
logger.info("Successfully serialized model to %s", str(output_file))
|
||||
return model
|
||||
|
||||
@ -522,8 +657,9 @@ def tensorize_vllm_model(engine_args: "EngineArgs",
|
||||
if generate_keyfile and (keyfile :=
|
||||
tensorizer_config.encryption_keyfile) is not None:
|
||||
encryption_params = EncryptionParams.random()
|
||||
with _write_stream(
|
||||
with open_stream(
|
||||
keyfile,
|
||||
mode="wb+",
|
||||
s3_access_key_id=tensorizer_config.s3_access_key_id,
|
||||
s3_secret_access_key=tensorizer_config.s3_secret_access_key,
|
||||
s3_endpoint=tensorizer_config.s3_endpoint,
|
||||
@ -537,13 +673,13 @@ def tensorize_vllm_model(engine_args: "EngineArgs",
|
||||
engine = LLMEngine.from_engine_args(engine_args)
|
||||
engine.model_executor.collective_rpc(
|
||||
"save_tensorized_model",
|
||||
kwargs=dict(tensorizer_config=tensorizer_config),
|
||||
kwargs={"tensorizer_config": tensorizer_config.to_serializable()},
|
||||
)
|
||||
else:
|
||||
engine = V1LLMEngine.from_vllm_config(engine_config)
|
||||
engine.collective_rpc(
|
||||
"save_tensorized_model",
|
||||
kwargs=dict(tensorizer_config=tensorizer_config),
|
||||
kwargs={"tensorizer_config": tensorizer_config.to_serializable()},
|
||||
)
|
||||
|
||||
|
||||
@ -586,14 +722,14 @@ def tensorize_lora_adapter(lora_path: str,
|
||||
|
||||
with open_stream(f"{tensorizer_config.lora_dir}/adapter_config.json",
|
||||
mode="wb+",
|
||||
**tensorizer_args.stream_params) as f:
|
||||
**tensorizer_args.stream_kwargs) as f:
|
||||
|
||||
f.write(json.dumps(config).encode("utf-8"))
|
||||
|
||||
lora_uri = (f"{tensorizer_config.lora_dir}"
|
||||
f"/adapter_model.tensors")
|
||||
with open_stream(lora_uri, mode="wb+",
|
||||
**tensorizer_args.stream_params) as f:
|
||||
**tensorizer_args.stream_kwargs) as f:
|
||||
serializer = TensorSerializer(f)
|
||||
serializer.write_state_dict(tensors)
|
||||
serializer.close()
|
||||
|
||||
@ -20,6 +20,18 @@ from vllm.model_executor.model_loader.utils import (get_model_architecture,
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Keys that may not be set through the Tensorizer loader's extra config.
BLACKLISTED_TENSORIZER_ARGS = {
    "device",  # vLLM decides this
    "dtype",  # vLLM decides this
    "mode",  # Not meant to be configurable by the user
}


def validate_config(config: dict):
    """Reject a config that sets any blacklisted Tensorizer argument.

    A blacklisted key whose value is ``None`` is tolerated.

    Args:
        config: Mapping of Tensorizer argument names to values.

    Raises:
        ValueError: For the first blacklisted key with a non-``None``
            value, in the mapping's iteration order.
    """
    for name in config:
        if config[name] is None:
            continue
        if name in BLACKLISTED_TENSORIZER_ARGS:
            raise ValueError(f"{name} is not an allowed Tensorizer argument.")
|
||||
|
||||
|
||||
class TensorizerLoader(BaseModelLoader):
|
||||
"""Model loader using CoreWeave's tensorizer library."""
|
||||
@ -29,6 +41,7 @@ class TensorizerLoader(BaseModelLoader):
|
||||
if isinstance(load_config.model_loader_extra_config, TensorizerConfig):
|
||||
self.tensorizer_config = load_config.model_loader_extra_config
|
||||
else:
|
||||
validate_config(load_config.model_loader_extra_config)
|
||||
self.tensorizer_config = TensorizerConfig(
|
||||
**load_config.model_loader_extra_config)
|
||||
|
||||
@ -118,10 +131,12 @@ class TensorizerLoader(BaseModelLoader):
|
||||
def save_model(
    model: torch.nn.Module,
    tensorizer_config: Union[TensorizerConfig, dict],
    model_config: ModelConfig,
) -> None:
    """Serialize ``model`` through ``serialize_vllm_model``.

    Args:
        model: Module to serialize.
        tensorizer_config: A ``TensorizerConfig``, or a dict of its
            initializer keyword arguments, which is expanded into one.
        model_config: Passed through unchanged to ``serialize_vllm_model``.
    """
    config = tensorizer_config
    if isinstance(config, dict):
        # Dict form: rebuild the config object from its keyword arguments.
        config = TensorizerConfig(**config)
    serialize_vllm_model(
        model=model,
        tensorizer_config=config,
        model_config=model_config,
    )
|
||||
|
||||
@ -1820,6 +1820,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
TensorizerLoader.save_model(
|
||||
self.model,
|
||||
tensorizer_config=tensorizer_config,
|
||||
model_config=self.model_config,
|
||||
)
|
||||
|
||||
def _get_prompt_logprobs_dict(
|
||||
|
||||
@ -1246,6 +1246,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
|
||||
TensorizerLoader.save_model(
|
||||
self.model,
|
||||
tensorizer_config=tensorizer_config,
|
||||
model_config=self.model_config,
|
||||
)
|
||||
|
||||
def get_max_block_per_batch(self) -> int:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user