diff --git a/tests/quantization/test_torchao.py b/tests/quantization/test_torchao.py index 45ee94119bbb4..d1cf7e1635960 100644 --- a/tests/quantization/test_torchao.py +++ b/tests/quantization/test_torchao.py @@ -216,5 +216,22 @@ def test_reload_weights(): # print("-" * 60) +@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available") +@pytest.mark.skip( + reason="since torchao nightly is only compatible with torch nightly" + "currently https://github.com/pytorch/ao/issues/2919, we'll have to skip " + "torchao tests that requires newer versions (0.14.0.dev+) for now" +) +def test_opt_125m_float8_weight_only_safetensors_model_loading_with_params(vllm_runner): + torch._dynamo.reset() + model_name = ( + "torchao-testing/opt-125m-Float8WeightOnlyConfig-v2-0.14.0.dev-safetensors" + ) + with vllm_runner(model_name=model_name, dtype="bfloat16") as llm: + output = llm.generate_greedy(["The capital of France is"], max_tokens=32) + + assert output + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/vllm/config/load.py b/vllm/config/load.py index 6aacff60157b0..23ce29e3983d8 100644 --- a/vllm/config/load.py +++ b/vllm/config/load.py @@ -59,6 +59,10 @@ class LoadConfig: This is recommended for models on network filesystems (e.g., Lustre, NFS) as it avoids inefficient random reads, significantly speeding up model initialization. However, it uses more CPU RAM. + - "torchao": Weights are loaded in upfront and then reconstructed + into torchao tensor subclasses. This is used when the checkpoint + was quantized using torchao and saved using safetensors. + Needs torchao >= 0.14.0 """ model_loader_extra_config: Union[dict, TensorizerConfig] = field( default_factory=dict diff --git a/vllm/model_executor/layers/quantization/torchao.py b/vllm/model_executor/layers/quantization/torchao.py index 629d0b8630412..55eb2890bb2f7 100644 --- a/vllm/model_executor/layers/quantization/torchao.py +++ b/vllm/model_executor/layers/quantization/torchao.py @@ -1,10 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import importlib import json +from importlib.util import find_spec from typing import Any, Optional import torch import torch.nn.functional as F +from packaging import version from torch.nn.parameter import Parameter from vllm.logger import init_logger @@ -23,6 +26,18 @@ from vllm.model_executor.utils import set_weight_attrs logger = init_logger(__name__) +def torchao_version_at_least(torchao_version: str) -> bool: + if find_spec("torchao"): + try: + if version.parse(importlib.metadata.version("torchao")) >= version.parse( + torchao_version + ): + return True + except (ImportError, version.InvalidVersion): + return False + return False + + def should_skip(prefix: str, skip_modules: list[str]) -> bool: """ Robust skipping logic: diff --git a/vllm/model_executor/model_loader/default_loader.py b/vllm/model_executor/model_loader/default_loader.py index 206b8244569f0..b1fd579da6b21 100644 --- a/vllm/model_executor/model_loader/default_loader.py +++ b/vllm/model_executor/model_loader/default_loader.py @@ -14,6 +14,7 @@ from transformers.utils import SAFE_WEIGHTS_INDEX_NAME from vllm.config import ModelConfig from vllm.config.load import LoadConfig from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.torchao import torchao_version_at_least from vllm.model_executor.model_loader.base_loader import BaseModelLoader from vllm.model_executor.model_loader.weight_utils import ( download_safetensors_index_file_from_hf, @@ -272,6 +273,10 @@ class DefaultModelLoader(BaseModelLoader): ) def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None: + if model_config.quantization == "torchao" and torchao_version_at_least( + "0.14.0" + ): + self.load_config.safetensors_load_strategy = "torchao" weights_to_load = {name for name, _ in model.named_parameters()} # if we don't have `model.weight_metadata_and_attr_saved` defined and diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index c40185c1c0840..5f83482bec3a0 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -54,6 +54,8 @@ except ImportError: SafeTensorsFileLoader = fastsafetensors.placeholder_attr("SafeTensorsFileLoader") SingleGroup = fastsafetensors.placeholder_attr("SingleGroup") +from vllm.model_executor.layers.quantization.torchao import torchao_version_at_least + logger = init_logger(__name__) # use system-level temp directory for file locks, so that multiple users @@ -602,6 +604,23 @@ def safetensors_weights_iterator( with open(st_file, "rb") as f: state_dict = load(f.read()) yield from state_dict.items() + elif safetensors_load_strategy == "torchao": + if not torchao_version_at_least("0.14.0"): + raise ValueError( + "Please use torchao version >= 0.14.0 \ + to load torchao safetensors checkpoint" + ) + from torchao.prototype.safetensors.safetensors_support import ( + unflatten_tensor_state_dict, + ) + + with safe_open(st_file, framework="pt") as f: + state_dict = {} + for name in f.keys(): # noqa: SIM118 + state_dict[name] = f.get_tensor(name) + metadata = f.metadata() + updated_state_dict = unflatten_tensor_state_dict(state_dict, metadata) + yield from updated_state_dict.items() else: with safe_open(st_file, framework="pt") as f: for name in f.keys(): # noqa: SIM118