[torchao] safetensors integration (#25969)

Signed-off-by: Angel Li <liangel@meta.com>
This commit is contained in:
liangel-02 2025-10-07 19:12:35 -07:00 committed by GitHub
parent f80e7866c0
commit b32260ab85
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 60 additions and 0 deletions

View File

@ -216,5 +216,22 @@ def test_reload_weights():
# print("-" * 60)
@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
@pytest.mark.skip(
    # Adjacent literals are concatenated verbatim, so each fragment must end
    # with a space to keep the rendered skip reason readable.
    reason="since torchao nightly is only compatible with torch nightly "
    "currently https://github.com/pytorch/ao/issues/2919, we'll have to skip "
    "torchao tests that requires newer versions (0.14.0.dev+) for now"
)
def test_opt_125m_float8_weight_only_safetensors_model_loading_with_params(vllm_runner):
    """Smoke-test loading a torchao-quantized safetensors checkpoint.

    Loads the ``opt-125m`` Float8WeightOnlyConfig checkpoint through the
    ``vllm_runner`` fixture in bfloat16, runs one greedy generation and
    checks that some output is returned.
    """
    # Clear torch._dynamo state before building the model.
    torch._dynamo.reset()
    model_name = (
        "torchao-testing/opt-125m-Float8WeightOnlyConfig-v2-0.14.0.dev-safetensors"
    )
    with vllm_runner(model_name=model_name, dtype="bfloat16") as llm:
        output = llm.generate_greedy(["The capital of France is"], max_tokens=32)
    assert output
# Allow running this test module directly as a script.
if __name__ == "__main__":
    pytest.main(args=[__file__])

View File

@ -59,6 +59,10 @@ class LoadConfig:
This is recommended for models on network filesystems (e.g., Lustre, NFS)
as it avoids inefficient random reads, significantly speeding up model
initialization. However, it uses more CPU RAM.
- "torchao": Weights are loaded upfront and then reconstructed
into torchao tensor subclasses. This is used when the checkpoint
was quantized using torchao and saved using safetensors.
Requires torchao >= 0.14.0.
"""
model_loader_extra_config: Union[dict, TensorizerConfig] = field(
default_factory=dict

View File

@ -1,10 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib
import json
from importlib.util import find_spec
from typing import Any, Optional
import torch
import torch.nn.functional as F
from packaging import version
from torch.nn.parameter import Parameter
from vllm.logger import init_logger
@ -23,6 +26,18 @@ from vllm.model_executor.utils import set_weight_attrs
logger = init_logger(__name__)
def torchao_version_at_least(torchao_version: str) -> bool:
    """Return whether the installed torchao is at least ``torchao_version``.

    Args:
        torchao_version: Minimum required version string, e.g. ``"0.14.0"``.

    Returns:
        ``True`` only when torchao can be found, its installed version can be
        read and parsed, and that version is ``>= torchao_version``.
        ``False`` when torchao is absent, its distribution metadata cannot be
        looked up, or the installed version string is not parseable.
    """
    if not find_spec("torchao"):
        return False

    # A bare `import importlib` does not guarantee that the `metadata`
    # submodule is loaded; import it explicitly so the lookup below cannot
    # fail with AttributeError.
    import importlib.metadata

    try:
        installed = version.parse(importlib.metadata.version("torchao"))
        return installed >= version.parse(torchao_version)
    except (ImportError, version.InvalidVersion):
        # PackageNotFoundError is an ImportError subclass, so a module that is
        # importable but has no distribution metadata lands here as well.
        return False
def should_skip(prefix: str, skip_modules: list[str]) -> bool:
"""
Robust skipping logic:

View File

@ -14,6 +14,7 @@ from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
from vllm.config import ModelConfig
from vllm.config.load import LoadConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.torchao import torchao_version_at_least
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.weight_utils import (
download_safetensors_index_file_from_hf,
@ -272,6 +273,10 @@ class DefaultModelLoader(BaseModelLoader):
)
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
if model_config.quantization == "torchao" and torchao_version_at_least(
"0.14.0"
):
self.load_config.safetensors_load_strategy = "torchao"
weights_to_load = {name for name, _ in model.named_parameters()}
# if we don't have `model.weight_metadata_and_attr_saved` defined and

View File

@ -54,6 +54,8 @@ except ImportError:
SafeTensorsFileLoader = fastsafetensors.placeholder_attr("SafeTensorsFileLoader")
SingleGroup = fastsafetensors.placeholder_attr("SingleGroup")
from vllm.model_executor.layers.quantization.torchao import torchao_version_at_least
logger = init_logger(__name__)
# use system-level temp directory for file locks, so that multiple users
@ -602,6 +604,23 @@ def safetensors_weights_iterator(
with open(st_file, "rb") as f:
state_dict = load(f.read())
yield from state_dict.items()
elif safetensors_load_strategy == "torchao":
if not torchao_version_at_least("0.14.0"):
raise ValueError(
"Please use torchao version >= 0.14.0 \
to load torchao safetensors checkpoint"
)
from torchao.prototype.safetensors.safetensors_support import (
unflatten_tensor_state_dict,
)
with safe_open(st_file, framework="pt") as f:
state_dict = {}
for name in f.keys(): # noqa: SIM118
state_dict[name] = f.get_tensor(name)
metadata = f.metadata()
updated_state_dict = unflatten_tensor_state_dict(state_dict, metadata)
yield from updated_state_dict.items()
else:
with safe_open(st_file, framework="pt") as f:
for name in f.keys(): # noqa: SIM118