mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-29 08:27:13 +08:00
[torchao] safetensors integration (#25969)
Signed-off-by: Angel Li <liangel@meta.com>
This commit is contained in:
parent
f80e7866c0
commit
b32260ab85
@ -216,5 +216,22 @@ def test_reload_weights():
|
||||
# print("-" * 60)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
|
||||
@pytest.mark.skip(
|
||||
reason="since torchao nightly is only compatible with torch nightly"
|
||||
"currently https://github.com/pytorch/ao/issues/2919, we'll have to skip "
|
||||
"torchao tests that requires newer versions (0.14.0.dev+) for now"
|
||||
)
|
||||
def test_opt_125m_float8_weight_only_safetensors_model_loading_with_params(vllm_runner):
|
||||
torch._dynamo.reset()
|
||||
model_name = (
|
||||
"torchao-testing/opt-125m-Float8WeightOnlyConfig-v2-0.14.0.dev-safetensors"
|
||||
)
|
||||
with vllm_runner(model_name=model_name, dtype="bfloat16") as llm:
|
||||
output = llm.generate_greedy(["The capital of France is"], max_tokens=32)
|
||||
|
||||
assert output
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
|
||||
@ -59,6 +59,10 @@ class LoadConfig:
|
||||
This is recommended for models on network filesystems (e.g., Lustre, NFS)
|
||||
as it avoids inefficient random reads, significantly speeding up model
|
||||
initialization. However, it uses more CPU RAM.
|
||||
- "torchao": Weights are loaded in upfront and then reconstructed
|
||||
into torchao tensor subclasses. This is used when the checkpoint
|
||||
was quantized using torchao and saved using safetensors.
|
||||
Needs torchao >= 0.14.0
|
||||
"""
|
||||
model_loader_extra_config: Union[dict, TensorizerConfig] = field(
|
||||
default_factory=dict
|
||||
|
||||
@ -1,10 +1,13 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import importlib
|
||||
import json
|
||||
from importlib.util import find_spec
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from packaging import version
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm.logger import init_logger
|
||||
@ -23,6 +26,18 @@ from vllm.model_executor.utils import set_weight_attrs
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def torchao_version_at_least(torchao_version: str) -> bool:
|
||||
if find_spec("torchao"):
|
||||
try:
|
||||
if version.parse(importlib.metadata.version("torchao")) >= version.parse(
|
||||
torchao_version
|
||||
):
|
||||
return True
|
||||
except (ImportError, version.InvalidVersion):
|
||||
return False
|
||||
return False
|
||||
|
||||
|
||||
def should_skip(prefix: str, skip_modules: list[str]) -> bool:
|
||||
"""
|
||||
Robust skipping logic:
|
||||
|
||||
@ -14,6 +14,7 @@ from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.torchao import torchao_version_at_least
|
||||
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
download_safetensors_index_file_from_hf,
|
||||
@ -272,6 +273,10 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
)
|
||||
|
||||
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
|
||||
if model_config.quantization == "torchao" and torchao_version_at_least(
|
||||
"0.14.0"
|
||||
):
|
||||
self.load_config.safetensors_load_strategy = "torchao"
|
||||
weights_to_load = {name for name, _ in model.named_parameters()}
|
||||
|
||||
# if we don't have `model.weight_metadata_and_attr_saved` defined and
|
||||
|
||||
@ -54,6 +54,8 @@ except ImportError:
|
||||
SafeTensorsFileLoader = fastsafetensors.placeholder_attr("SafeTensorsFileLoader")
|
||||
SingleGroup = fastsafetensors.placeholder_attr("SingleGroup")
|
||||
|
||||
from vllm.model_executor.layers.quantization.torchao import torchao_version_at_least
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# use system-level temp directory for file locks, so that multiple users
|
||||
@ -602,6 +604,23 @@ def safetensors_weights_iterator(
|
||||
with open(st_file, "rb") as f:
|
||||
state_dict = load(f.read())
|
||||
yield from state_dict.items()
|
||||
elif safetensors_load_strategy == "torchao":
|
||||
if not torchao_version_at_least("0.14.0"):
|
||||
raise ValueError(
|
||||
"Please use torchao version >= 0.14.0 \
|
||||
to load torchao safetensors checkpoint"
|
||||
)
|
||||
from torchao.prototype.safetensors.safetensors_support import (
|
||||
unflatten_tensor_state_dict,
|
||||
)
|
||||
|
||||
with safe_open(st_file, framework="pt") as f:
|
||||
state_dict = {}
|
||||
for name in f.keys(): # noqa: SIM118
|
||||
state_dict[name] = f.get_tensor(name)
|
||||
metadata = f.metadata()
|
||||
updated_state_dict = unflatten_tensor_state_dict(state_dict, metadata)
|
||||
yield from updated_state_dict.items()
|
||||
else:
|
||||
with safe_open(st_file, framework="pt") as f:
|
||||
for name in f.keys(): # noqa: SIM118
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user