mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-15 06:04:57 +08:00
[Frontend] [Core] perf: Automatically detect vLLM-tensorized model, update tensorizer to version 2.9.0 (#4208)
This commit is contained in:
parent
0fca3cdcf2
commit
8bc68e198c
@ -60,11 +60,13 @@ steps:
|
|||||||
mirror_hardwares: [amd]
|
mirror_hardwares: [amd]
|
||||||
commands:
|
commands:
|
||||||
# install aws cli for llava_example.py
|
# install aws cli for llava_example.py
|
||||||
- pip install awscli
|
# install tensorizer for tensorize_vllm_model.py
|
||||||
|
- pip install awscli tensorizer
|
||||||
- python3 offline_inference.py
|
- python3 offline_inference.py
|
||||||
- python3 offline_inference_with_prefix.py
|
- python3 offline_inference_with_prefix.py
|
||||||
- python3 llm_engine_example.py
|
- python3 llm_engine_example.py
|
||||||
- python3 llava_example.py
|
- python3 llava_example.py
|
||||||
|
- python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
|
||||||
|
|
||||||
- label: Kernels Test %N
|
- label: Kernels Test %N
|
||||||
command: pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
|
command: pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
|
||||||
|
|||||||
@ -1,23 +1,20 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import dataclasses
|
import dataclasses
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
import time
|
|
||||||
import uuid
|
import uuid
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Type
|
|
||||||
|
|
||||||
import torch
|
from tensorizer import stream_io
|
||||||
import torch.nn as nn
|
|
||||||
from tensorizer import (DecryptionParams, EncryptionParams, TensorDeserializer,
|
|
||||||
TensorSerializer, stream_io)
|
|
||||||
from tensorizer.utils import convert_bytes, get_mem_usage, no_init_or_tensor
|
|
||||||
from transformers import AutoConfig, PretrainedConfig
|
|
||||||
|
|
||||||
from vllm.distributed import initialize_model_parallel
|
from vllm import LLM
|
||||||
|
from vllm.distributed import (init_distributed_environment,
|
||||||
|
initialize_model_parallel)
|
||||||
from vllm.engine.arg_utils import EngineArgs
|
from vllm.engine.arg_utils import EngineArgs
|
||||||
from vllm.engine.llm_engine import LLMEngine
|
from vllm.engine.llm_engine import LLMEngine
|
||||||
from vllm.model_executor.model_loader.tensorizer import TensorizerArgs
|
from vllm.model_executor.model_loader.tensorizer import (TensorizerArgs,
|
||||||
from vllm.model_executor.models import ModelRegistry
|
TensorizerConfig,
|
||||||
|
serialize_vllm_model)
|
||||||
|
|
||||||
# yapf conflicts with isort for this docstring
|
# yapf conflicts with isort for this docstring
|
||||||
# yapf: disable
|
# yapf: disable
|
||||||
@ -27,25 +24,25 @@ deserialize vLLM models. These models can be loaded using tensorizer
|
|||||||
to the GPU extremely quickly over an HTTP/HTTPS endpoint, an S3 endpoint,
|
to the GPU extremely quickly over an HTTP/HTTPS endpoint, an S3 endpoint,
|
||||||
or locally. Tensor encryption and decryption is also supported, although
|
or locally. Tensor encryption and decryption is also supported, although
|
||||||
libsodium must be installed to use it. Install vllm with tensorizer support
|
libsodium must be installed to use it. Install vllm with tensorizer support
|
||||||
using `pip install vllm[tensorizer]`.
|
using `pip install vllm[tensorizer]`. To learn more about tensorizer, visit
|
||||||
|
https://github.com/coreweave/tensorizer
|
||||||
|
|
||||||
To serialize a model, install vLLM from source, then run something
|
To serialize a model, install vLLM from source, then run something
|
||||||
like this from the root level of this repository:
|
like this from the root level of this repository:
|
||||||
|
|
||||||
python -m examples.tensorize_vllm_model \
|
python -m examples.tensorize_vllm_model \
|
||||||
--model EleutherAI/gpt-j-6B \
|
--model facebook/opt-125m \
|
||||||
--dtype float16 \
|
|
||||||
serialize \
|
serialize \
|
||||||
--serialized-directory s3://my-bucket/ \
|
--serialized-directory s3://my-bucket \
|
||||||
--suffix vllm
|
--suffix v1
|
||||||
|
|
||||||
Which downloads the model from HuggingFace, loads it into vLLM, serializes it,
|
Which downloads the model from HuggingFace, loads it into vLLM, serializes it,
|
||||||
and saves it to your S3 bucket. A local directory can also be used. This
|
and saves it to your S3 bucket. A local directory can also be used. This
|
||||||
assumes your S3 credentials are specified as environment variables
|
assumes your S3 credentials are specified as environment variables
|
||||||
in the form of `S3_ACCESS_KEY_ID`, `S3_SECRET_ACCESS_KEY`, and `S3_ENDPOINT`.
|
in the form of `S3_ACCESS_KEY_ID`, `S3_SECRET_ACCESS_KEY`, and
|
||||||
To provide S3 credentials directly, you can provide `--s3-access-key-id` and
|
`S3_ENDPOINT_URL`. To provide S3 credentials directly, you can provide
|
||||||
`--s3-secret-access-key`, as well as `--s3-endpoint` as CLI args to this
|
`--s3-access-key-id` and `--s3-secret-access-key`, as well as `--s3-endpoint`
|
||||||
script.
|
as CLI args to this script.
|
||||||
|
|
||||||
You can also encrypt the model weights with a randomly-generated key by
|
You can also encrypt the model weights with a randomly-generated key by
|
||||||
providing a `--keyfile` argument.
|
providing a `--keyfile` argument.
|
||||||
@ -57,7 +54,7 @@ python -m examples.tensorize_vllm_model \
|
|||||||
--model EleutherAI/gpt-j-6B \
|
--model EleutherAI/gpt-j-6B \
|
||||||
--dtype float16 \
|
--dtype float16 \
|
||||||
deserialize \
|
deserialize \
|
||||||
--path-to-tensors s3://my-bucket/vllm/EleutherAI/gpt-j-6B/vllm/model.tensors
|
--path-to-tensors s3://my-bucket/vllm/EleutherAI/gpt-j-6B/v1/model.tensors
|
||||||
|
|
||||||
Which downloads the model tensors from your S3 bucket and deserializes them.
|
Which downloads the model tensors from your S3 bucket and deserializes them.
|
||||||
|
|
||||||
@ -71,26 +68,30 @@ Or for deserializing:
|
|||||||
|
|
||||||
`python -m examples.tensorize_vllm_model deserialize --help`.
|
`python -m examples.tensorize_vllm_model deserialize --help`.
|
||||||
|
|
||||||
Once a model is serialized, it can be used to load the model when running the
|
Once a model is serialized, tensorizer can be invoked with the `LLM` class
|
||||||
OpenAI inference client at `vllm/entrypoints/openai/api_server.py` by providing
|
directly to load models:
|
||||||
the `--tensorizer-uri` CLI argument that is functionally the same as the
|
|
||||||
`--path-to-tensors` argument in this script, along with `--vllm-tensorized`, to
|
|
||||||
signify that the model to be deserialized is a vLLM model, rather than a
|
|
||||||
HuggingFace `PreTrainedModel`, which can also be deserialized using tensorizer
|
|
||||||
in the same inference server, albeit without the speed optimizations. To
|
|
||||||
deserialize an encrypted file, the `--encryption-keyfile` argument can be used
|
|
||||||
to provide the path to the keyfile used to encrypt the model weights. For
|
|
||||||
information on all the arguments that can be used to configure tensorizer's
|
|
||||||
deserialization, check out the tensorizer options argument group in the
|
|
||||||
`vllm/entrypoints/openai/api_server.py` script with `--help`.
|
|
||||||
|
|
||||||
Tensorizer can also be invoked with the `LLM` class directly to load models:
|
|
||||||
|
|
||||||
llm = LLM(model="facebook/opt-125m",
|
llm = LLM(model="facebook/opt-125m",
|
||||||
load_format="tensorizer",
|
load_format="tensorizer",
|
||||||
tensorizer_uri=path_to_opt_tensors,
|
model_loader_extra_config=TensorizerConfig(
|
||||||
|
tensorizer_uri = path_to_tensors,
|
||||||
num_readers=3,
|
num_readers=3,
|
||||||
vllm_tensorized=True)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
A serialized model can be used during model loading for the vLLM OpenAI
|
||||||
|
inference server. `model_loader_extra_config` is exposed as the CLI arg
|
||||||
|
`--model-loader-extra-config`, and accepts a JSON string literal of the
|
||||||
|
TensorizerConfig arguments desired.
|
||||||
|
|
||||||
|
In order to see all of the available arguments usable to configure
|
||||||
|
loading with tensorizer that are given to `TensorizerConfig`, run:
|
||||||
|
|
||||||
|
`python -m examples.tensorize_vllm_model deserialize --help`
|
||||||
|
|
||||||
|
under the `tensorizer options` section. These can also be used for
|
||||||
|
deserialization in this example script, although `--tensorizer-uri` and
|
||||||
|
`--path-to-tensors` are functionally the same in this case.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@ -158,95 +159,35 @@ def parse_args():
|
|||||||
help=("Path to a binary key to use to decrypt the model weights,"
|
help=("Path to a binary key to use to decrypt the model weights,"
|
||||||
" if the model was serialized with encryption"))
|
" if the model was serialized with encryption"))
|
||||||
|
|
||||||
|
TensorizerArgs.add_cli_args(deserialize_parser)
|
||||||
|
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
def make_model_contiguous(model):
|
|
||||||
# Ensure tensors are saved in memory contiguously
|
|
||||||
for param in model.parameters():
|
|
||||||
param.data = param.data.contiguous()
|
|
||||||
|
|
||||||
|
|
||||||
def _get_vllm_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
|
|
||||||
architectures = getattr(config, "architectures", [])
|
|
||||||
for arch in architectures:
|
|
||||||
model_cls = ModelRegistry.load_model_cls(arch)
|
|
||||||
if model_cls is not None:
|
|
||||||
return model_cls
|
|
||||||
raise ValueError(
|
|
||||||
f"Model architectures {architectures} are not supported for now. "
|
|
||||||
f"Supported architectures: {ModelRegistry.get_supported_archs()}")
|
|
||||||
|
|
||||||
|
|
||||||
def serialize():
|
|
||||||
|
|
||||||
eng_args_dict = {f.name: getattr(args, f.name) for f in
|
|
||||||
dataclasses.fields(EngineArgs)}
|
|
||||||
engine_args = EngineArgs.from_cli_args(argparse.Namespace(**eng_args_dict))
|
|
||||||
engine = LLMEngine.from_engine_args(engine_args)
|
|
||||||
|
|
||||||
model = (engine.model_executor.driver_worker.
|
|
||||||
model_runner.model)
|
|
||||||
|
|
||||||
encryption_params = EncryptionParams.random() if keyfile else None
|
|
||||||
if keyfile:
|
|
||||||
with _write_stream(keyfile) as stream:
|
|
||||||
stream.write(encryption_params.key)
|
|
||||||
|
|
||||||
with _write_stream(model_path) as stream:
|
|
||||||
serializer = TensorSerializer(stream, encryption=encryption_params)
|
|
||||||
serializer.write_module(model)
|
|
||||||
serializer.close()
|
|
||||||
|
|
||||||
print("Serialization complete. Model tensors saved to", model_path)
|
|
||||||
if keyfile:
|
|
||||||
print("Key saved to", keyfile)
|
|
||||||
|
|
||||||
|
|
||||||
def deserialize():
|
def deserialize():
|
||||||
config = AutoConfig.from_pretrained(model_ref)
|
llm = LLM(model=args.model,
|
||||||
|
load_format="tensorizer",
|
||||||
with no_init_or_tensor():
|
model_loader_extra_config=tensorizer_config
|
||||||
model_class = _get_vllm_model_architecture(config)
|
|
||||||
model = model_class(config)
|
|
||||||
|
|
||||||
before_mem = get_mem_usage()
|
|
||||||
start = time.time()
|
|
||||||
|
|
||||||
if keyfile:
|
|
||||||
with _read_stream(keyfile) as stream:
|
|
||||||
key = stream.read()
|
|
||||||
decryption_params = DecryptionParams.from_key(key)
|
|
||||||
tensorizer_args.deserializer_params['encryption'] = \
|
|
||||||
decryption_params
|
|
||||||
|
|
||||||
with (_read_stream(model_path)) as stream, TensorDeserializer(
|
|
||||||
stream, **tensorizer_args.deserializer_params) as deserializer:
|
|
||||||
deserializer.load_into_module(model)
|
|
||||||
end = time.time()
|
|
||||||
|
|
||||||
# Brag about how fast we are.
|
|
||||||
total_bytes_str = convert_bytes(deserializer.total_tensor_bytes)
|
|
||||||
duration = end - start
|
|
||||||
per_second = convert_bytes(deserializer.total_tensor_bytes / duration)
|
|
||||||
after_mem = get_mem_usage()
|
|
||||||
print(
|
|
||||||
f"Deserialized {total_bytes_str} in {end - start:0.2f}s, {per_second}/s"
|
|
||||||
)
|
)
|
||||||
print(f"Memory usage before: {before_mem}")
|
return llm
|
||||||
print(f"Memory usage after: {after_mem}")
|
|
||||||
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
|
|
||||||
s3_access_key_id = (args.s3_access_key_id or os.environ.get("S3_ACCESS_KEY_ID")
|
s3_access_key_id = (getattr(args, 's3_access_key_id', None)
|
||||||
or None)
|
or os.environ.get("S3_ACCESS_KEY_ID", None))
|
||||||
s3_secret_access_key = (args.s3_secret_access_key
|
s3_secret_access_key = (getattr(args, 's3_secret_access_key', None)
|
||||||
or os.environ.get("S3_SECRET_ACCESS_KEY") or None)
|
or os.environ.get("S3_SECRET_ACCESS_KEY", None))
|
||||||
|
s3_endpoint = (getattr(args, 's3_endpoint', None)
|
||||||
|
or os.environ.get("S3_ENDPOINT_URL", None))
|
||||||
|
|
||||||
s3_endpoint = (args.s3_endpoint or os.environ.get("S3_ENDPOINT_URL") or None)
|
credentials = {
|
||||||
|
"s3_access_key_id": s3_access_key_id,
|
||||||
|
"s3_secret_access_key": s3_secret_access_key,
|
||||||
|
"s3_endpoint": s3_endpoint
|
||||||
|
}
|
||||||
|
|
||||||
_read_stream, _write_stream = (partial(
|
_read_stream, _write_stream = (partial(
|
||||||
stream_io.open_stream,
|
stream_io.open_stream,
|
||||||
@ -263,20 +204,41 @@ model_name = model_ref.split("/")[1]
|
|||||||
os.environ["MASTER_ADDR"] = "127.0.0.1"
|
os.environ["MASTER_ADDR"] = "127.0.0.1"
|
||||||
os.environ["MASTER_PORT"] = "8080"
|
os.environ["MASTER_PORT"] = "8080"
|
||||||
|
|
||||||
torch.distributed.init_process_group(world_size=1, rank=0)
|
init_distributed_environment(world_size=1, rank=0, local_rank=0)
|
||||||
initialize_model_parallel()
|
initialize_model_parallel()
|
||||||
|
|
||||||
keyfile = args.keyfile if args.keyfile else None
|
keyfile = args.keyfile if args.keyfile else None
|
||||||
|
|
||||||
|
|
||||||
|
if args.model_loader_extra_config:
|
||||||
|
config = json.loads(args.model_loader_extra_config)
|
||||||
|
tensorizer_args = TensorizerConfig(**config)._construct_tensorizer_args()
|
||||||
|
tensorizer_args.tensorizer_uri = args.path_to_tensors
|
||||||
|
else:
|
||||||
|
tensorizer_args = None
|
||||||
|
|
||||||
if args.command == "serialize":
|
if args.command == "serialize":
|
||||||
|
eng_args_dict = {f.name: getattr(args, f.name) for f in
|
||||||
|
dataclasses.fields(EngineArgs)}
|
||||||
|
|
||||||
|
engine_args = EngineArgs.from_cli_args(argparse.Namespace(**eng_args_dict))
|
||||||
|
engine = LLMEngine.from_engine_args(engine_args)
|
||||||
|
|
||||||
input_dir = args.serialized_directory.rstrip('/')
|
input_dir = args.serialized_directory.rstrip('/')
|
||||||
suffix = args.suffix if args.suffix else uuid.uuid4().hex
|
suffix = args.suffix if args.suffix else uuid.uuid4().hex
|
||||||
base_path = f"{input_dir}/vllm/{model_ref}/{suffix}"
|
base_path = f"{input_dir}/vllm/{model_ref}/{suffix}"
|
||||||
model_path = f"{base_path}/model.tensors"
|
model_path = f"{base_path}/model.tensors"
|
||||||
serialize()
|
tensorizer_config = TensorizerConfig(
|
||||||
|
tensorizer_uri=model_path,
|
||||||
|
**credentials)
|
||||||
|
serialize_vllm_model(engine, tensorizer_config, keyfile)
|
||||||
elif args.command == "deserialize":
|
elif args.command == "deserialize":
|
||||||
tensorizer_args = TensorizerArgs.from_cli_args(args)
|
if not tensorizer_args:
|
||||||
model_path = args.path_to_tensors
|
tensorizer_config = TensorizerConfig(
|
||||||
|
tensorizer_uri=args.path_to_tensors,
|
||||||
|
encryption_keyfile = keyfile,
|
||||||
|
**credentials
|
||||||
|
)
|
||||||
deserialize()
|
deserialize()
|
||||||
else:
|
else:
|
||||||
raise ValueError("Either serialize or deserialize must be specified.")
|
raise ValueError("Either serialize or deserialize must be specified.")
|
||||||
|
|||||||
@ -14,7 +14,7 @@ types-setuptools
|
|||||||
|
|
||||||
# testing
|
# testing
|
||||||
pytest
|
pytest
|
||||||
tensorizer==2.9.0
|
tensorizer>=2.9.0
|
||||||
pytest-forked
|
pytest-forked
|
||||||
pytest-asyncio
|
pytest-asyncio
|
||||||
pytest-rerunfailures
|
pytest-rerunfailures
|
||||||
|
|||||||
2
setup.py
2
setup.py
@ -426,7 +426,7 @@ setup(
|
|||||||
install_requires=get_requirements(),
|
install_requires=get_requirements(),
|
||||||
ext_modules=ext_modules,
|
ext_modules=ext_modules,
|
||||||
extras_require={
|
extras_require={
|
||||||
"tensorizer": ["tensorizer==2.9.0"],
|
"tensorizer": ["tensorizer>=2.9.0"],
|
||||||
},
|
},
|
||||||
cmdclass={"build_ext": cmake_build_ext} if not _is_neuron() else {},
|
cmdclass={"build_ext": cmake_build_ext} if not _is_neuron() else {},
|
||||||
package_data=package_data,
|
package_data=package_data,
|
||||||
|
|||||||
@ -1,245 +0,0 @@
|
|||||||
import argparse
|
|
||||||
import dataclasses
|
|
||||||
import os
|
|
||||||
import time
|
|
||||||
import uuid
|
|
||||||
from functools import partial
|
|
||||||
from typing import Type
|
|
||||||
|
|
||||||
import torch.nn as nn
|
|
||||||
from tensorizer import (DecryptionParams, EncryptionParams, TensorDeserializer,
|
|
||||||
TensorSerializer, stream_io)
|
|
||||||
from tensorizer.utils import convert_bytes, get_mem_usage, no_init_or_tensor
|
|
||||||
from transformers import AutoConfig, PretrainedConfig
|
|
||||||
|
|
||||||
from vllm.distributed import (init_distributed_environment,
|
|
||||||
initialize_model_parallel)
|
|
||||||
from vllm.engine.arg_utils import EngineArgs
|
|
||||||
from vllm.engine.llm_engine import LLMEngine
|
|
||||||
from vllm.model_executor.model_loader.tensorizer import TensorizerArgs
|
|
||||||
from vllm.model_executor.models import ModelRegistry
|
|
||||||
|
|
||||||
# yapf conflicts with isort for this docstring
|
|
||||||
# yapf: disable
|
|
||||||
"""
|
|
||||||
tensorize_vllm_model.py is a script that can be used to serialize and
|
|
||||||
deserialize vLLM models. These models can be loaded using tensorizer directly
|
|
||||||
to the GPU extremely quickly. Tensor encryption and decryption is also
|
|
||||||
supported, although libsodium must be installed to use it. Install
|
|
||||||
vllm with tensorizer support using `pip install vllm[tensorizer]`.
|
|
||||||
|
|
||||||
To serialize a model, you can run something like this:
|
|
||||||
|
|
||||||
python tensorize_vllm_model.py \
|
|
||||||
--model EleutherAI/gpt-j-6B \
|
|
||||||
--dtype float16 \
|
|
||||||
serialize \
|
|
||||||
--serialized-directory s3://my-bucket/ \
|
|
||||||
--suffix vllm
|
|
||||||
|
|
||||||
Which downloads the model from HuggingFace, loads it into vLLM, serializes it,
|
|
||||||
and saves it to your S3 bucket. A local directory can also be used.
|
|
||||||
|
|
||||||
You can also encrypt the model weights with a randomly-generated key by
|
|
||||||
providing a `--keyfile` argument.
|
|
||||||
|
|
||||||
To deserialize a model, you can run something like this:
|
|
||||||
|
|
||||||
python tensorize_vllm_model.py \
|
|
||||||
--model EleutherAI/gpt-j-6B \
|
|
||||||
--dtype float16 \
|
|
||||||
deserialize \
|
|
||||||
--path-to-tensors s3://my-bucket/vllm/EleutherAI/gpt-j-6B/vllm/model.tensors
|
|
||||||
|
|
||||||
Which downloads the model tensors from your S3 bucket and deserializes them.
|
|
||||||
To provide S3 credentials, you can provide `--s3-access-key-id` and
|
|
||||||
`--s3-secret-access-key`, as well as `--s3-endpoint` as CLI args to this script,
|
|
||||||
the OpenAI entrypoint, as arguments for LLM(), or as environment variables
|
|
||||||
in the form of `S3_ACCESS_KEY_ID`, `S3_SECRET_ACCESS_KEY`, and `S3_ENDPOINT`.
|
|
||||||
|
|
||||||
|
|
||||||
You can also provide a `--keyfile` argument to decrypt the model weights if
|
|
||||||
they were serialized with encryption.
|
|
||||||
|
|
||||||
For more information on the available arguments, run
|
|
||||||
`python tensorize_vllm_model.py --help`.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
description="An example script that can be used to serialize and "
|
|
||||||
"deserialize vLLM models. These models "
|
|
||||||
"can be loaded using tensorizer directly to the GPU "
|
|
||||||
"extremely quickly. Tensor encryption and decryption is "
|
|
||||||
"also supported, although libsodium must be installed to "
|
|
||||||
"use it.")
|
|
||||||
parser = TensorizerArgs.add_cli_args(EngineArgs.add_cli_args(parser))
|
|
||||||
subparsers = parser.add_subparsers(dest='command')
|
|
||||||
|
|
||||||
serialize_parser = subparsers.add_parser(
|
|
||||||
'serialize', help="Serialize a model to `--serialized-directory`")
|
|
||||||
|
|
||||||
serialize_parser.add_argument(
|
|
||||||
"--suffix",
|
|
||||||
type=str,
|
|
||||||
required=False,
|
|
||||||
help=(
|
|
||||||
"The suffix to append to the serialized model directory, which is "
|
|
||||||
"used to construct the location of the serialized model tensors, "
|
|
||||||
"e.g. if `--serialized-directory` is `s3://my-bucket/` and "
|
|
||||||
"`--suffix` is `v1`, the serialized model tensors will be "
|
|
||||||
"saved to "
|
|
||||||
"`s3://my-bucket/vllm/EleutherAI/gpt-j-6B/v1/model.tensors`. "
|
|
||||||
"If none is provided, a random UUID will be used."))
|
|
||||||
serialize_parser.add_argument(
|
|
||||||
"--serialized-directory",
|
|
||||||
type=str,
|
|
||||||
required=True)
|
|
||||||
|
|
||||||
serialize_parser.add_argument(
|
|
||||||
"--keyfile",
|
|
||||||
type=str,
|
|
||||||
required=False,
|
|
||||||
help=("Encrypt the model weights with a randomly-generated binary key,"
|
|
||||||
" and save the key at this path"))
|
|
||||||
|
|
||||||
deserialize_parser = subparsers.add_parser(
|
|
||||||
'deserialize',
|
|
||||||
help=("Deserialize a model from `--path-to-tensors`"
|
|
||||||
" to verify it can be loaded and used."))
|
|
||||||
|
|
||||||
deserialize_parser.add_argument(
|
|
||||||
"--path-to-tensors",
|
|
||||||
type=str,
|
|
||||||
required=True,
|
|
||||||
help="The local path or S3 URI to the model tensors to deserialize. ")
|
|
||||||
|
|
||||||
deserialize_parser.add_argument(
|
|
||||||
"--keyfile",
|
|
||||||
type=str,
|
|
||||||
required=False,
|
|
||||||
help=("Path to a binary key to use to decrypt the model weights,"
|
|
||||||
" if the model was serialized with encryption"))
|
|
||||||
|
|
||||||
return parser.parse_args()
|
|
||||||
|
|
||||||
|
|
||||||
def make_model_contiguous(model):
|
|
||||||
# Ensure tensors are saved in memory contiguously
|
|
||||||
for param in model.parameters():
|
|
||||||
param.data = param.data.contiguous()
|
|
||||||
|
|
||||||
|
|
||||||
def _get_vllm_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
|
|
||||||
architectures = getattr(config, "architectures", [])
|
|
||||||
for arch in architectures:
|
|
||||||
model_cls = ModelRegistry.load_model_cls(arch)
|
|
||||||
if model_cls is not None:
|
|
||||||
return model_cls
|
|
||||||
raise ValueError(
|
|
||||||
f"Model architectures {architectures} are not supported for now. "
|
|
||||||
f"Supported architectures: {ModelRegistry.get_supported_archs()}")
|
|
||||||
|
|
||||||
|
|
||||||
def serialize():
|
|
||||||
eng_args_dict = {f.name: getattr(args, f.name) for f in
|
|
||||||
dataclasses.fields(EngineArgs)}
|
|
||||||
engine_args = EngineArgs.from_cli_args(argparse.Namespace(**eng_args_dict))
|
|
||||||
engine = LLMEngine.from_engine_args(engine_args)
|
|
||||||
|
|
||||||
model = (engine.model_executor.driver_worker.
|
|
||||||
model_runner.model)
|
|
||||||
|
|
||||||
encryption_params = EncryptionParams.random() if keyfile else None
|
|
||||||
if keyfile:
|
|
||||||
with _write_stream(keyfile) as stream:
|
|
||||||
stream.write(encryption_params.key)
|
|
||||||
|
|
||||||
with _write_stream(model_path) as stream:
|
|
||||||
serializer = TensorSerializer(stream, encryption=encryption_params)
|
|
||||||
serializer.write_module(model)
|
|
||||||
serializer.close()
|
|
||||||
|
|
||||||
print("Serialization complete. Model tensors saved to", model_path)
|
|
||||||
if keyfile:
|
|
||||||
print("Key saved to", keyfile)
|
|
||||||
|
|
||||||
|
|
||||||
def deserialize():
|
|
||||||
config = AutoConfig.from_pretrained(model_ref)
|
|
||||||
|
|
||||||
with no_init_or_tensor():
|
|
||||||
model_class = _get_vllm_model_architecture(config)
|
|
||||||
model = model_class(config)
|
|
||||||
|
|
||||||
before_mem = get_mem_usage()
|
|
||||||
start = time.time()
|
|
||||||
|
|
||||||
if keyfile:
|
|
||||||
with _read_stream(keyfile) as stream:
|
|
||||||
key = stream.read()
|
|
||||||
decryption_params = DecryptionParams.from_key(key)
|
|
||||||
tensorizer_args.deserializer_params['encryption'] = \
|
|
||||||
decryption_params
|
|
||||||
|
|
||||||
with (_read_stream(model_path)) as stream, TensorDeserializer(
|
|
||||||
stream, **tensorizer_args.deserializer_params) as deserializer:
|
|
||||||
deserializer.load_into_module(model)
|
|
||||||
end = time.time()
|
|
||||||
|
|
||||||
# Brag about how fast we are.
|
|
||||||
total_bytes_str = convert_bytes(deserializer.total_tensor_bytes)
|
|
||||||
duration = end - start
|
|
||||||
per_second = convert_bytes(deserializer.total_tensor_bytes / duration)
|
|
||||||
after_mem = get_mem_usage()
|
|
||||||
print(
|
|
||||||
f"Deserialized {total_bytes_str} in {end - start:0.2f}s, {per_second}/s"
|
|
||||||
)
|
|
||||||
print(f"Memory usage before: {before_mem}")
|
|
||||||
print(f"Memory usage after: {after_mem}")
|
|
||||||
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
args = parse_args()
|
|
||||||
|
|
||||||
s3_access_key_id = (args.s3_access_key_id or os.environ.get("S3_ACCESS_KEY_ID")
|
|
||||||
or None)
|
|
||||||
s3_secret_access_key = (args.s3_secret_access_key
|
|
||||||
or os.environ.get("S3_SECRET_ACCESS_KEY") or None)
|
|
||||||
|
|
||||||
s3_endpoint = (args.s3_endpoint or os.environ.get("S3_ENDPOINT_URL") or None)
|
|
||||||
|
|
||||||
_read_stream, _write_stream = (partial(
|
|
||||||
stream_io.open_stream,
|
|
||||||
mode=mode,
|
|
||||||
s3_access_key_id=s3_access_key_id,
|
|
||||||
s3_secret_access_key=s3_secret_access_key,
|
|
||||||
s3_endpoint=s3_endpoint,
|
|
||||||
) for mode in ("rb", "wb+"))
|
|
||||||
|
|
||||||
model_ref = args.model
|
|
||||||
|
|
||||||
model_name = model_ref.split("/")[1]
|
|
||||||
|
|
||||||
os.environ["MASTER_ADDR"] = "127.0.0.1"
|
|
||||||
os.environ["MASTER_PORT"] = "8080"
|
|
||||||
|
|
||||||
init_distributed_environment(world_size=1, rank=0, local_rank=0)
|
|
||||||
initialize_model_parallel()
|
|
||||||
|
|
||||||
keyfile = args.keyfile if args.keyfile else None
|
|
||||||
|
|
||||||
if args.command == "serialize":
|
|
||||||
input_dir = args.serialized_directory.rstrip('/')
|
|
||||||
suffix = args.suffix if args.suffix else uuid.uuid4().hex
|
|
||||||
base_path = f"{input_dir}/vllm/{model_ref}/{suffix}"
|
|
||||||
model_path = f"{base_path}/model.tensors"
|
|
||||||
serialize()
|
|
||||||
elif args.command == "deserialize":
|
|
||||||
tensorizer_args = TensorizerArgs.from_cli_args(args)
|
|
||||||
model_path = args.path_to_tensors
|
|
||||||
deserialize()
|
|
||||||
else:
|
|
||||||
raise ValueError("Either serialize or deserialize must be specified.")
|
|
||||||
@ -10,12 +10,19 @@ import ray
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm import SamplingParams
|
from vllm import SamplingParams
|
||||||
from vllm.model_executor.model_loader.tensorizer import (
|
# yapf: disable
|
||||||
EncryptionParams, TensorizerConfig, TensorSerializer,
|
from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig,
|
||||||
is_vllm_serialized_tensorizer, load_with_tensorizer, open_stream)
|
TensorSerializer,
|
||||||
|
is_vllm_tensorized,
|
||||||
|
load_with_tensorizer,
|
||||||
|
open_stream,
|
||||||
|
serialize_vllm_model)
|
||||||
|
|
||||||
from ..utils import ServerRunner
|
from ..utils import ServerRunner
|
||||||
|
|
||||||
|
# yapf conflicts with isort for this docstring
|
||||||
|
|
||||||
|
|
||||||
prompts = [
|
prompts = [
|
||||||
"Hello, my name is",
|
"Hello, my name is",
|
||||||
"The president of the United States is",
|
"The president of the United States is",
|
||||||
@ -40,7 +47,7 @@ def is_curl_installed():
|
|||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def tensorizer_config():
|
def tensorizer_config():
|
||||||
config = TensorizerConfig(tensorizer_uri="vllm", vllm_tensorized=True)
|
config = TensorizerConfig(tensorizer_uri="vllm")
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
@ -59,47 +66,6 @@ def test_load_with_tensorizer(mock_agent, tensorizer_config):
|
|||||||
assert result == mock_agent_instance.deserialize.return_value
|
assert result == mock_agent_instance.deserialize.return_value
|
||||||
|
|
||||||
|
|
||||||
def test_is_vllm_model_with_vllm_in_uri(tensorizer_config):
|
|
||||||
tensorizer_config.vllm_tensorized = True
|
|
||||||
|
|
||||||
result = is_vllm_serialized_tensorizer(tensorizer_config)
|
|
||||||
|
|
||||||
assert result is True
|
|
||||||
|
|
||||||
|
|
||||||
def test_is_vllm_model_without_vllm_in_uri(tensorizer_config):
|
|
||||||
tensorizer_config.vllm_tensorized = False
|
|
||||||
|
|
||||||
result = is_vllm_serialized_tensorizer(tensorizer_config)
|
|
||||||
|
|
||||||
assert result is False
|
|
||||||
|
|
||||||
|
|
||||||
def test_deserialized_vllm_model_has_same_outputs(vllm_runner, tmp_path):
|
|
||||||
vllm_model = vllm_runner(model_ref)
|
|
||||||
model_path = tmp_path / (model_ref + ".tensors")
|
|
||||||
outputs = vllm_model.generate(prompts, sampling_params)
|
|
||||||
model = (vllm_model.model.llm_engine.model_executor.driver_worker.
|
|
||||||
model_runner.model)
|
|
||||||
with open_stream(model_path, "wb+") as stream:
|
|
||||||
serializer = TensorSerializer(stream)
|
|
||||||
serializer.write_module(model)
|
|
||||||
del vllm_model, model
|
|
||||||
gc.collect()
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
loaded_vllm_model = vllm_runner(
|
|
||||||
model_ref,
|
|
||||||
load_format="tensorizer",
|
|
||||||
model_loader_extra_config=TensorizerConfig(tensorizer_uri=model_path,
|
|
||||||
num_readers=1,
|
|
||||||
vllm_tensorized=True),
|
|
||||||
)
|
|
||||||
deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params)
|
|
||||||
|
|
||||||
# Assumes SamplingParams being seeded ensures the outputs are deterministic
|
|
||||||
assert outputs == deserialized_outputs
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
|
@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
|
||||||
def test_can_deserialize_s3(vllm_runner):
|
def test_can_deserialize_s3(vllm_runner):
|
||||||
model_ref = "EleutherAI/pythia-1.4b"
|
model_ref = "EleutherAI/pythia-1.4b"
|
||||||
@ -110,7 +76,6 @@ def test_can_deserialize_s3(vllm_runner):
|
|||||||
model_loader_extra_config=TensorizerConfig(
|
model_loader_extra_config=TensorizerConfig(
|
||||||
tensorizer_uri=tensorized_path,
|
tensorizer_uri=tensorized_path,
|
||||||
num_readers=1,
|
num_readers=1,
|
||||||
vllm_tensorized=False,
|
|
||||||
s3_endpoint="object.ord1.coreweave.com",
|
s3_endpoint="object.ord1.coreweave.com",
|
||||||
))
|
))
|
||||||
|
|
||||||
@ -126,29 +91,26 @@ def test_deserialized_encrypted_vllm_model_has_same_outputs(
|
|||||||
model_path = tmp_path / (model_ref + ".tensors")
|
model_path = tmp_path / (model_ref + ".tensors")
|
||||||
key_path = tmp_path / (model_ref + ".key")
|
key_path = tmp_path / (model_ref + ".key")
|
||||||
outputs = vllm_model.generate(prompts, sampling_params)
|
outputs = vllm_model.generate(prompts, sampling_params)
|
||||||
model = (vllm_model.model.llm_engine.model_executor.driver_worker.
|
|
||||||
model_runner.model)
|
|
||||||
|
|
||||||
encryption_params = EncryptionParams.random()
|
config_for_serializing = TensorizerConfig(tensorizer_uri=model_path)
|
||||||
with open_stream(model_path, "wb+") as stream:
|
serialize_vllm_model(vllm_model.model.llm_engine,
|
||||||
serializer = TensorSerializer(stream, encryption=encryption_params)
|
config_for_serializing,
|
||||||
serializer.write_module(model)
|
encryption_key_path=key_path)
|
||||||
with open_stream(key_path, "wb+") as stream:
|
|
||||||
stream.write(encryption_params.key)
|
del vllm_model
|
||||||
del vllm_model, model
|
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
loaded_vllm_model = vllm_runner(model_ref,
|
|
||||||
|
config_for_deserializing = TensorizerConfig(tensorizer_uri=model_path,
|
||||||
|
encryption_keyfile=key_path)
|
||||||
|
|
||||||
|
loaded_vllm_model = vllm_runner(
|
||||||
|
model_ref,
|
||||||
load_format="tensorizer",
|
load_format="tensorizer",
|
||||||
model_loader_extra_config=TensorizerConfig(
|
model_loader_extra_config=config_for_deserializing)
|
||||||
tensorizer_uri=model_path,
|
|
||||||
encryption_keyfile=key_path,
|
|
||||||
num_readers=1,
|
|
||||||
vllm_tensorized=True))
|
|
||||||
|
|
||||||
deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params)
|
deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params)
|
||||||
|
|
||||||
# Assumes SamplingParams being seeded ensures the outputs are deterministic
|
|
||||||
assert outputs == deserialized_outputs
|
assert outputs == deserialized_outputs
|
||||||
|
|
||||||
|
|
||||||
@ -169,7 +131,7 @@ def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner,
|
|||||||
model_loader_extra_config=TensorizerConfig(
|
model_loader_extra_config=TensorizerConfig(
|
||||||
tensorizer_uri=model_path,
|
tensorizer_uri=model_path,
|
||||||
num_readers=1,
|
num_readers=1,
|
||||||
vllm_tensorized=False))
|
))
|
||||||
|
|
||||||
deserialized_outputs = loaded_hf_model.generate_greedy(
|
deserialized_outputs = loaded_hf_model.generate_greedy(
|
||||||
prompts, max_tokens=max_tokens)
|
prompts, max_tokens=max_tokens)
|
||||||
@ -190,12 +152,11 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
|
|||||||
# Serialize model before deserializing and binding LoRA adapters
|
# Serialize model before deserializing and binding LoRA adapters
|
||||||
vllm_model = vllm_runner(model_ref, )
|
vllm_model = vllm_runner(model_ref, )
|
||||||
model_path = tmp_path / (model_ref + ".tensors")
|
model_path = tmp_path / (model_ref + ".tensors")
|
||||||
model = (vllm_model.model.llm_engine.model_executor.driver_worker.
|
|
||||||
model_runner.model)
|
serialize_vllm_model(vllm_model.model.llm_engine,
|
||||||
with open_stream(model_path, "wb+") as stream:
|
TensorizerConfig(tensorizer_uri=model_path))
|
||||||
serializer = TensorSerializer(stream)
|
|
||||||
serializer.write_module(model)
|
del vllm_model
|
||||||
del vllm_model, model
|
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
loaded_vllm_model = vllm_runner(
|
loaded_vllm_model = vllm_runner(
|
||||||
@ -204,7 +165,6 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
|
|||||||
model_loader_extra_config=TensorizerConfig(
|
model_loader_extra_config=TensorizerConfig(
|
||||||
tensorizer_uri=model_path,
|
tensorizer_uri=model_path,
|
||||||
num_readers=1,
|
num_readers=1,
|
||||||
vllm_tensorized=True,
|
|
||||||
),
|
),
|
||||||
enable_lora=True,
|
enable_lora=True,
|
||||||
max_loras=1,
|
max_loras=1,
|
||||||
@ -220,58 +180,28 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
|
|||||||
|
|
||||||
def test_load_without_tensorizer_load_format(vllm_runner):
|
def test_load_without_tensorizer_load_format(vllm_runner):
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
vllm_runner(model_ref,
|
vllm_runner(
|
||||||
model_loader_extra_config=TensorizerConfig(
|
model_ref,
|
||||||
tensorizer_uri="test", vllm_tensorized=False))
|
model_loader_extra_config=TensorizerConfig(tensorizer_uri="test"))
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
|
@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
|
||||||
def test_tensorize_vllm_model(tmp_path):
|
def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path):
|
||||||
# Test serialize command
|
|
||||||
serialize_args = [
|
|
||||||
"python3", tensorize_model_for_testing_script, "--model", model_ref,
|
|
||||||
"--dtype", "float16", "serialize", "--serialized-directory", tmp_path,
|
|
||||||
"--suffix", "tests"
|
|
||||||
]
|
|
||||||
result = subprocess.run(serialize_args, capture_output=True, text=True)
|
|
||||||
print(result.stdout) # Print the output of the serialize command
|
|
||||||
|
|
||||||
assert result.returncode == 0, (f"Serialize command failed with output:"
|
|
||||||
f"\n{result.stdout}\n{result.stderr}")
|
|
||||||
|
|
||||||
path_to_tensors = f"{tmp_path}/vllm/{model_ref}/tests/model.tensors"
|
|
||||||
|
|
||||||
# Test deserialize command
|
|
||||||
deserialize_args = [
|
|
||||||
"python3", tensorize_model_for_testing_script, "--model", model_ref,
|
|
||||||
"--dtype", "float16", "deserialize", "--path-to-tensors",
|
|
||||||
path_to_tensors
|
|
||||||
]
|
|
||||||
result = subprocess.run(deserialize_args, capture_output=True, text=True)
|
|
||||||
assert result.returncode == 0, (f"Deserialize command failed with output:"
|
|
||||||
f"\n{result.stdout}\n{result.stderr}")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
|
|
||||||
def test_openai_apiserver_with_tensorizer(tmp_path):
|
|
||||||
## Serialize model
|
## Serialize model
|
||||||
serialize_args = [
|
vllm_model = vllm_runner(model_ref, )
|
||||||
"python3", tensorize_model_for_testing_script, "--model", model_ref,
|
model_path = tmp_path / (model_ref + ".tensors")
|
||||||
"--dtype", "float16", "serialize", "--serialized-directory", tmp_path,
|
|
||||||
"--suffix", "tests"
|
|
||||||
]
|
|
||||||
result = subprocess.run(serialize_args, capture_output=True, text=True)
|
|
||||||
print(result.stdout) # Print the output of the serialize command
|
|
||||||
|
|
||||||
assert result.returncode == 0, (f"Serialize command failed with output:"
|
serialize_vllm_model(vllm_model.model.llm_engine,
|
||||||
f"\n{result.stdout}\n{result.stderr}")
|
TensorizerConfig(tensorizer_uri=model_path))
|
||||||
|
|
||||||
path_to_tensors = f"{tmp_path}/vllm/{model_ref}/tests/model.tensors"
|
|
||||||
model_loader_extra_config = {
|
model_loader_extra_config = {
|
||||||
"tensorizer_uri": path_to_tensors,
|
"tensorizer_uri": str(model_path),
|
||||||
"vllm_tensorized": True
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
del vllm_model
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
## Start OpenAI API server
|
## Start OpenAI API server
|
||||||
openai_args = [
|
openai_args = [
|
||||||
"--model", model_ref, "--dtype", "float16", "--load-format",
|
"--model", model_ref, "--dtype", "float16", "--load-format",
|
||||||
@ -304,10 +234,10 @@ def test_openai_apiserver_with_tensorizer(tmp_path):
|
|||||||
|
|
||||||
def test_raise_value_error_on_invalid_load_format(vllm_runner):
|
def test_raise_value_error_on_invalid_load_format(vllm_runner):
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
vllm_runner(model_ref,
|
vllm_runner(
|
||||||
|
model_ref,
|
||||||
load_format="safetensors",
|
load_format="safetensors",
|
||||||
model_loader_extra_config=TensorizerConfig(
|
model_loader_extra_config=TensorizerConfig(tensorizer_uri="test"))
|
||||||
tensorizer_uri="test", vllm_tensorized=False))
|
|
||||||
|
|
||||||
|
|
||||||
def test_tensorizer_with_tp(vllm_runner):
|
def test_tensorizer_with_tp(vllm_runner):
|
||||||
@ -321,8 +251,29 @@ def test_tensorizer_with_tp(vllm_runner):
|
|||||||
model_loader_extra_config=TensorizerConfig(
|
model_loader_extra_config=TensorizerConfig(
|
||||||
tensorizer_uri=tensorized_path,
|
tensorizer_uri=tensorized_path,
|
||||||
num_readers=1,
|
num_readers=1,
|
||||||
vllm_tensorized=False,
|
|
||||||
s3_endpoint="object.ord1.coreweave.com",
|
s3_endpoint="object.ord1.coreweave.com",
|
||||||
),
|
),
|
||||||
tensor_parallel_size=2,
|
tensor_parallel_size=2,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path):
|
||||||
|
model_ref = "facebook/opt-125m"
|
||||||
|
model_path = tmp_path / (model_ref + ".tensors")
|
||||||
|
config = TensorizerConfig(tensorizer_uri=str(model_path))
|
||||||
|
|
||||||
|
vllm_model = vllm_runner(model_ref)
|
||||||
|
outputs = vllm_model.generate(prompts, sampling_params)
|
||||||
|
serialize_vllm_model(vllm_model.model.llm_engine, config)
|
||||||
|
|
||||||
|
assert is_vllm_tensorized(config)
|
||||||
|
del vllm_model
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
loaded_vllm_model = vllm_runner(model_ref,
|
||||||
|
load_format="tensorizer",
|
||||||
|
model_loader_extra_config=config)
|
||||||
|
deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params)
|
||||||
|
|
||||||
|
assert outputs == deserialized_outputs
|
||||||
|
|||||||
@ -167,8 +167,8 @@ class EngineArgs:
|
|||||||
'* "dummy" will initialize the weights with random values, '
|
'* "dummy" will initialize the weights with random values, '
|
||||||
'which is mainly for profiling.\n'
|
'which is mainly for profiling.\n'
|
||||||
'* "tensorizer" will load the weights using tensorizer from '
|
'* "tensorizer" will load the weights using tensorizer from '
|
||||||
'CoreWeave which assumes tensorizer_uri is set to the location of '
|
'CoreWeave. See the Tensorize vLLM Model script in the Examples'
|
||||||
'the serialized weights.')
|
'section for more information.\n')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--dtype',
|
'--dtype',
|
||||||
type=str,
|
type=str,
|
||||||
|
|||||||
@ -145,7 +145,7 @@ environment_variables: Dict[str, Callable[[], Any]] = {
|
|||||||
|
|
||||||
# S3 access information, used for tensorizer to load model from S3
|
# S3 access information, used for tensorizer to load model from S3
|
||||||
"S3_ACCESS_KEY_ID":
|
"S3_ACCESS_KEY_ID":
|
||||||
lambda: os.environ.get("S3_ACCESS_KEY", None),
|
lambda: os.environ.get("S3_ACCESS_KEY_ID", None),
|
||||||
"S3_SECRET_ACCESS_KEY":
|
"S3_SECRET_ACCESS_KEY":
|
||||||
lambda: os.environ.get("S3_SECRET_ACCESS_KEY", None),
|
lambda: os.environ.get("S3_SECRET_ACCESS_KEY", None),
|
||||||
"S3_ENDPOINT_URL":
|
"S3_ENDPOINT_URL":
|
||||||
|
|||||||
@ -17,7 +17,7 @@ from vllm.logger import init_logger
|
|||||||
from vllm.model_executor.layers.quantization.base_config import (
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
QuantizationConfig)
|
QuantizationConfig)
|
||||||
from vllm.model_executor.model_loader.tensorizer import (
|
from vllm.model_executor.model_loader.tensorizer import (
|
||||||
TensorizerConfig, is_vllm_serialized_tensorizer, load_with_tensorizer,
|
TensorizerConfig, is_vllm_tensorized, load_with_tensorizer,
|
||||||
tensorizer_weights_iterator)
|
tensorizer_weights_iterator)
|
||||||
from vllm.model_executor.model_loader.utils import (get_model_architecture,
|
from vllm.model_executor.model_loader.utils import (get_model_architecture,
|
||||||
set_default_torch_dtype)
|
set_default_torch_dtype)
|
||||||
@ -291,7 +291,7 @@ class TensorizerLoader(BaseModelLoader):
|
|||||||
tensorizer_args = self.tensorizer_config._construct_tensorizer_args()
|
tensorizer_args = self.tensorizer_config._construct_tensorizer_args()
|
||||||
return tensorizer_weights_iterator(tensorizer_args)
|
return tensorizer_weights_iterator(tensorizer_args)
|
||||||
|
|
||||||
def _load_model_unserialized(
|
def _load_model_serialized_cpu(
|
||||||
self,
|
self,
|
||||||
model_config: ModelConfig,
|
model_config: ModelConfig,
|
||||||
device_config: DeviceConfig,
|
device_config: DeviceConfig,
|
||||||
@ -299,11 +299,12 @@ class TensorizerLoader(BaseModelLoader):
|
|||||||
vision_language_config: Optional[VisionLanguageConfig],
|
vision_language_config: Optional[VisionLanguageConfig],
|
||||||
cache_config: CacheConfig,
|
cache_config: CacheConfig,
|
||||||
) -> nn.Module:
|
) -> nn.Module:
|
||||||
"""Load an unserialized model with tensorizer.
|
"""Load a serialized model with tensorizer to the CPU.
|
||||||
|
|
||||||
Unserialized here means "not serialized with tensorizer". This
|
This is only necessary when the model isn't vLLM-tensorized (see
|
||||||
should still be faster than default HuggingFace loading, but will
|
examples/tensorize_vllm_model.py) This should still be faster than
|
||||||
be slower than loading a tensorizer-serialized model.
|
default HuggingFace loading, but will be slower than loading a
|
||||||
|
vLLM-tensorized model.
|
||||||
"""
|
"""
|
||||||
with set_default_torch_dtype(model_config.dtype):
|
with set_default_torch_dtype(model_config.dtype):
|
||||||
with torch.device(device_config.device):
|
with torch.device(device_config.device):
|
||||||
@ -324,8 +325,9 @@ class TensorizerLoader(BaseModelLoader):
|
|||||||
) -> nn.Module:
|
) -> nn.Module:
|
||||||
"""Load a serialized model with tensorizer.
|
"""Load a serialized model with tensorizer.
|
||||||
|
|
||||||
See the examples/tensorize_vllm_model.py example "
|
Expects a vLLM-tensorized model. See the
|
||||||
script for serializing vLLM models."""
|
examples/tensorize_vllm_model.py example script
|
||||||
|
for serializing vLLM models."""
|
||||||
with set_default_torch_dtype(model_config.dtype):
|
with set_default_torch_dtype(model_config.dtype):
|
||||||
with torch.device(device_config.device):
|
with torch.device(device_config.device):
|
||||||
model_class = get_model_architecture(model_config)[0]
|
model_class = get_model_architecture(model_config)[0]
|
||||||
@ -353,12 +355,12 @@ class TensorizerLoader(BaseModelLoader):
|
|||||||
cache_config: CacheConfig) -> nn.Module:
|
cache_config: CacheConfig) -> nn.Module:
|
||||||
self._verify_config(model_config, parallel_config)
|
self._verify_config(model_config, parallel_config)
|
||||||
|
|
||||||
if is_vllm_serialized_tensorizer(self.tensorizer_config):
|
if is_vllm_tensorized(self.tensorizer_config):
|
||||||
return self._load_model_serialized(model_config, device_config,
|
return self._load_model_serialized(model_config, device_config,
|
||||||
lora_config,
|
lora_config,
|
||||||
vision_language_config,
|
vision_language_config,
|
||||||
cache_config)
|
cache_config)
|
||||||
return self._load_model_unserialized(model_config, device_config,
|
return self._load_model_serialized_cpu(model_config, device_config,
|
||||||
lora_config,
|
lora_config,
|
||||||
vision_language_config,
|
vision_language_config,
|
||||||
cache_config)
|
cache_config)
|
||||||
|
|||||||
@ -5,6 +5,7 @@ import os
|
|||||||
import time
|
import time
|
||||||
import typing
|
import typing
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from functools import partial
|
||||||
from typing import Generator, Optional, Tuple, Type, Union
|
from typing import Generator, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -13,6 +14,7 @@ from transformers import PretrainedConfig
|
|||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.config import ModelConfig, ParallelConfig
|
from vllm.config import ModelConfig, ParallelConfig
|
||||||
|
from vllm.engine.llm_engine import LLMEngine
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.quantization.base_config import (
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
QuantizationConfig)
|
QuantizationConfig)
|
||||||
@ -27,6 +29,11 @@ try:
|
|||||||
from tensorizer.stream_io import open_stream
|
from tensorizer.stream_io import open_stream
|
||||||
from tensorizer.utils import (convert_bytes, get_mem_usage,
|
from tensorizer.utils import (convert_bytes, get_mem_usage,
|
||||||
no_init_or_tensor)
|
no_init_or_tensor)
|
||||||
|
|
||||||
|
_read_stream, _write_stream = (partial(
|
||||||
|
open_stream,
|
||||||
|
mode=mode,
|
||||||
|
) for mode in ("rb", "wb+"))
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
tensorizer_error_msg = str(e)
|
tensorizer_error_msg = str(e)
|
||||||
|
|
||||||
@ -43,7 +50,7 @@ logger = init_logger(__name__)
|
|||||||
class TensorizerConfig:
|
class TensorizerConfig:
|
||||||
tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, typing.BinaryIO,
|
tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, typing.BinaryIO,
|
||||||
str, bytes, os.PathLike, int]
|
str, bytes, os.PathLike, int]
|
||||||
vllm_tensorized: bool
|
vllm_tensorized: Optional[bool] = False
|
||||||
verify_hash: Optional[bool] = False
|
verify_hash: Optional[bool] = False
|
||||||
num_readers: Optional[int] = None
|
num_readers: Optional[int] = None
|
||||||
encryption_keyfile: Optional[str] = None
|
encryption_keyfile: Optional[str] = None
|
||||||
@ -93,17 +100,11 @@ def load_with_tensorizer(tensorizer_config: TensorizerConfig,
|
|||||||
return tensorizer.deserialize()
|
return tensorizer.deserialize()
|
||||||
|
|
||||||
|
|
||||||
def is_vllm_serialized_tensorizer(tensorizer_config: TensorizerConfig) -> bool:
|
|
||||||
if tensorizer_config is None:
|
|
||||||
return False
|
|
||||||
return tensorizer_config.vllm_tensorized
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TensorizerArgs:
|
class TensorizerArgs:
|
||||||
tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, typing.BinaryIO,
|
tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, typing.BinaryIO,
|
||||||
str, bytes, os.PathLike, int]
|
str, bytes, os.PathLike, int]
|
||||||
vllm_tensorized: bool
|
vllm_tensorized: Optional[bool] = False
|
||||||
verify_hash: Optional[bool] = False
|
verify_hash: Optional[bool] = False
|
||||||
num_readers: Optional[int] = None
|
num_readers: Optional[int] = None
|
||||||
encryption_keyfile: Optional[str] = None
|
encryption_keyfile: Optional[str] = None
|
||||||
@ -121,7 +122,9 @@ class TensorizerArgs:
|
|||||||
vLLM model. This is used to determine the behavior of the
|
vLLM model. This is used to determine the behavior of the
|
||||||
TensorDeserializer when loading tensors from a serialized model.
|
TensorDeserializer when loading tensors from a serialized model.
|
||||||
It is far faster to deserialize a vLLM model as it utilizes
|
It is far faster to deserialize a vLLM model as it utilizes
|
||||||
tensorizer's optimized GPU loading.
|
tensorizer's optimized GPU loading. Note that this is now
|
||||||
|
deprecated, as serialized vLLM models are now automatically
|
||||||
|
inferred as vLLM models.
|
||||||
verify_hash: If True, the hashes of each tensor will be verified against
|
verify_hash: If True, the hashes of each tensor will be verified against
|
||||||
the hashes stored in the metadata. A `HashMismatchError` will be
|
the hashes stored in the metadata. A `HashMismatchError` will be
|
||||||
raised if any of the hashes do not match.
|
raised if any of the hashes do not match.
|
||||||
@ -158,6 +161,7 @@ class TensorizerArgs:
|
|||||||
"encryption": self.encryption_keyfile,
|
"encryption": self.encryption_keyfile,
|
||||||
"num_readers": self.num_readers
|
"num_readers": self.num_readers
|
||||||
}
|
}
|
||||||
|
|
||||||
if self.encryption_keyfile:
|
if self.encryption_keyfile:
|
||||||
with open_stream(
|
with open_stream(
|
||||||
self.encryption_keyfile,
|
self.encryption_keyfile,
|
||||||
@ -177,7 +181,14 @@ class TensorizerArgs:
|
|||||||
'tensorizer options',
|
'tensorizer options',
|
||||||
description=('Options for configuring the behavior of the'
|
description=('Options for configuring the behavior of the'
|
||||||
' tensorizer deserializer when '
|
' tensorizer deserializer when '
|
||||||
'--load-format=tensorizer'))
|
'load_format=tensorizer is specified when '
|
||||||
|
'initializing an LLMEngine, either via the CLI '
|
||||||
|
'when running the vLLM OpenAI inference server '
|
||||||
|
'with a JSON string passed to '
|
||||||
|
'--model-loader-extra-config or as arguments given '
|
||||||
|
'to TensorizerConfig when passed to '
|
||||||
|
'model_loader_extra_config in the constructor '
|
||||||
|
'for LLMEngine.'))
|
||||||
|
|
||||||
group.add_argument(
|
group.add_argument(
|
||||||
"--tensorizer-uri",
|
"--tensorizer-uri",
|
||||||
@ -222,13 +233,6 @@ class TensorizerArgs:
|
|||||||
help="The endpoint for the S3 bucket. Can also be set via the "
|
help="The endpoint for the S3 bucket. Can also be set via the "
|
||||||
"S3_ENDPOINT_URL environment variable.",
|
"S3_ENDPOINT_URL environment variable.",
|
||||||
)
|
)
|
||||||
group.add_argument(
|
|
||||||
"--vllm-tensorized",
|
|
||||||
action="store_true",
|
|
||||||
help="If enabled, indicates that the serialized model is a vLLM "
|
|
||||||
"model. This is used to determine the behavior of the "
|
|
||||||
"TensorDeserializer when loading tensors from a "
|
|
||||||
"serialized model.")
|
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
@ -322,10 +326,9 @@ class TensorizerAgent:
|
|||||||
"""
|
"""
|
||||||
before_mem = get_mem_usage()
|
before_mem = get_mem_usage()
|
||||||
start = time.perf_counter()
|
start = time.perf_counter()
|
||||||
with open_stream(
|
with _read_stream(
|
||||||
self.tensorizer_args.tensorizer_uri,
|
self.tensorizer_config.tensorizer_uri,
|
||||||
mode="rb",
|
**self.tensorizer_args.stream_params
|
||||||
**self.tensorizer_args.stream_params,
|
|
||||||
) as stream, TensorDeserializer(
|
) as stream, TensorDeserializer(
|
||||||
stream,
|
stream,
|
||||||
dtype=self.tensorizer_config.dtype,
|
dtype=self.tensorizer_config.dtype,
|
||||||
@ -345,6 +348,7 @@ class TensorizerAgent:
|
|||||||
|
|
||||||
self._check_tensors_on_meta_device()
|
self._check_tensors_on_meta_device()
|
||||||
self._resize_lora_embeddings()
|
self._resize_lora_embeddings()
|
||||||
|
del self.model.vllm_tensorized_marker
|
||||||
return self.model.eval()
|
return self.model.eval()
|
||||||
|
|
||||||
|
|
||||||
@ -366,3 +370,63 @@ def tensorizer_weights_iterator(
|
|||||||
for name, param in state.items():
|
for name, param in state.items():
|
||||||
yield name, param
|
yield name, param
|
||||||
del state
|
del state
|
||||||
|
|
||||||
|
|
||||||
|
def is_vllm_tensorized(tensorizer_config: "TensorizerConfig") -> bool:
|
||||||
|
"""
|
||||||
|
Infer if the model is a vLLM model by checking the weights for
|
||||||
|
a vLLM tensorized marker.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tensorizer_config: The TensorizerConfig object containing the
|
||||||
|
tensorizer_uri to the serialized model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the model is a vLLM model, False otherwise.
|
||||||
|
"""
|
||||||
|
tensorizer_args = tensorizer_config._construct_tensorizer_args()
|
||||||
|
deserializer = TensorDeserializer(open_stream(
|
||||||
|
tensorizer_args.tensorizer_uri, **tensorizer_args.stream_params),
|
||||||
|
**tensorizer_args.deserializer_params,
|
||||||
|
lazy_load=True)
|
||||||
|
if tensorizer_config.vllm_tensorized:
|
||||||
|
logger.warning(
|
||||||
|
"Please note that newly serialized vLLM models are automatically "
|
||||||
|
"inferred as vLLM models, so setting vllm_tensorized=True is "
|
||||||
|
"only necessary for models serialized prior to this change.")
|
||||||
|
return True
|
||||||
|
if (".vllm_tensorized_marker" in deserializer):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def get_pretensorized_vllm_model(engine: "LLMEngine") -> nn.Module:
|
||||||
|
model = (engine.model_executor.driver_worker.model_runner.model)
|
||||||
|
model.register_parameter(
|
||||||
|
"vllm_tensorized_marker",
|
||||||
|
nn.Parameter(torch.tensor((1, ), device="meta"), requires_grad=False))
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def serialize_vllm_model(engine: "LLMEngine",
|
||||||
|
tensorizer_config : TensorizerConfig,
|
||||||
|
encryption_key_path: Optional[str] = None) \
|
||||||
|
-> nn.Module:
|
||||||
|
|
||||||
|
model = get_pretensorized_vllm_model(engine)
|
||||||
|
tensorizer_args = tensorizer_config._construct_tensorizer_args()
|
||||||
|
encryption_params = None
|
||||||
|
if encryption_key_path is not None:
|
||||||
|
encryption_params = EncryptionParams.random()
|
||||||
|
with _write_stream(encryption_key_path,
|
||||||
|
**tensorizer_args.stream_params) as stream:
|
||||||
|
stream.write(encryption_params.key)
|
||||||
|
|
||||||
|
with _write_stream(tensorizer_args.tensorizer_uri,
|
||||||
|
**tensorizer_args.stream_params) as stream:
|
||||||
|
serializer = TensorSerializer(stream, encryption=encryption_params)
|
||||||
|
serializer.write_module(model)
|
||||||
|
serializer.close()
|
||||||
|
logger.info("Successfully serialized model to %s",
|
||||||
|
str(tensorizer_args.tensorizer_uri))
|
||||||
|
return model
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user