From 1d642872a27f1c6bedf28669642928cc7eec6532 Mon Sep 17 00:00:00 2001
From: liangel-02
Date: Wed, 19 Nov 2025 19:39:45 -0500
Subject: [PATCH] [torchao] fix safetensors for sharding (#28169)

Signed-off-by: Angel Li
---
 tests/quantization/test_torchao.py |  9 ++++-----
 .../model_loader/default_loader.py |  2 +-
 .../model_loader/weight_utils.py   | 23 ++++++++++++++++++-----
 3 files changed, 23 insertions(+), 11 deletions(-)

diff --git a/tests/quantization/test_torchao.py b/tests/quantization/test_torchao.py
index fb8d6130c3779..f35c3973ab6e6 100644
--- a/tests/quantization/test_torchao.py
+++ b/tests/quantization/test_torchao.py
@@ -225,13 +225,12 @@ def test_reload_weights():
 @pytest.mark.skip(
     reason="since torchao nightly is only compatible with torch nightly"
     "currently https://github.com/pytorch/ao/issues/2919, we'll have to skip "
-    "torchao tests that requires newer versions (0.14.0.dev+) for now"
+    "torchao tests that require newer versions (0.15.0.dev+) for now"
 )
-def test_opt_125m_float8_weight_only_safetensors_model_loading_with_params(vllm_runner):
+def test_safetensors_model_loading_with_params(vllm_runner):
     torch._dynamo.reset()
-    model_name = (
-        "torchao-testing/opt-125m-Float8WeightOnlyConfig-v2-0.14.0.dev-safetensors"
-    )
+    # use this model to test safetensors loading with sharded checkpoint files
+    model_name = "torchao-testing/Qwen3-8B-INT4-0.15.0dev-safetensors"
 
     with vllm_runner(model_name=model_name, dtype="bfloat16") as llm:
         output = llm.generate_greedy(["The capital of France is"], max_tokens=4)
diff --git a/vllm/model_executor/model_loader/default_loader.py b/vllm/model_executor/model_loader/default_loader.py
index b80026741781f..67aa584c6bda2 100644
--- a/vllm/model_executor/model_loader/default_loader.py
+++ b/vllm/model_executor/model_loader/default_loader.py
@@ -279,7 +279,7 @@ class DefaultModelLoader(BaseModelLoader):
         if (
             hasattr(quant_config, "is_checkpoint_torchao_serialized")
             and quant_config.is_checkpoint_torchao_serialized
-            and torchao_version_at_least("0.14.0")
+            and torchao_version_at_least("0.15.0")
         ):
             self.load_config.safetensors_load_strategy = "torchao"
diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py
index 89634cbf41241..4572ebe2ea11b 100644
--- a/vllm/model_executor/model_loader/weight_utils.py
+++ b/vllm/model_executor/model_loader/weight_utils.py
@@ -595,6 +595,9 @@ def safetensors_weights_iterator(
     if safetensors_load_strategy == "eager":
         loading_desc += " (eager)"
 
+    state_dict = {}
+    leftover_state_dict: dict[str, torch.Tensor] = {}
+
     for st_file in tqdm(
         hf_weights_files,
         desc=loading_desc,
@@ -606,9 +609,11 @@ def safetensors_weights_iterator(
                 state_dict = load(f.read())
                 yield from state_dict.items()
         elif safetensors_load_strategy == "torchao":
-            if not torchao_version_at_least("0.14.0"):
+            # we can't load flattened torchao tensor subclasses directly into
+            # the model, so we reconstruct the subclasses here before yielding
+            if not torchao_version_at_least("0.15.0"):
                 raise ValueError(
-                    "Please use torchao version >= 0.14.0 \
+                    "Please use torchao version >= 0.15.0 \
                         to load torchao safetensors checkpoint"
                 )
             from torchao.prototype.safetensors.safetensors_support import (
@@ -616,12 +621,20 @@ def safetensors_weights_iterator(
             )
 
             with safe_open(st_file, framework="pt") as f:
-                state_dict = {}
                 for name in f.keys():  # noqa: SIM118
                     state_dict[name] = f.get_tensor(name)
+
+                # merge in leftover tensor data from earlier files, if any
+                state_dict.update(leftover_state_dict)
                 metadata = f.metadata()
-                updated_state_dict = unflatten_tensor_state_dict(state_dict, metadata)
-                yield from updated_state_dict.items()
+                # a sharded checkpoint may split a tensor subclass's components
+                # across files, so this file is not guaranteed to hold all of
+                # the subclass data; leftover_state_dict carries the incomplete
+                # entries forward until a later file provides the missing tensors
+                unflattened_state_dict, leftover_state_dict = (
+                    unflatten_tensor_state_dict(state_dict, metadata)
+                )
+                yield from unflattened_state_dict.items()
         else:
             with safe_open(st_file, framework="pt") as f:
                 for name in f.keys():  # noqa: SIM118
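
Note (illustration, not part of the patch): the weight_utils.py change works
because unflatten_tensor_state_dict now returns both the reconstructed tensors
and the entries it could not complete yet, and the iterator merges those
leftovers into the next file's state dict. Below is a minimal, runnable sketch
of that leftover-carrying pattern. The `unflatten` helper and its
"<name>:<part>/<total>" key scheme are invented for illustration only;
torchao's real flattened layout and API differ.

import torch


def unflatten(
    state_dict: dict[str, torch.Tensor],
) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
    # Toy stand-in for torchao's unflatten_tensor_state_dict: group the
    # flattened parts by tensor name, rebuild every tensor whose parts are
    # all present, and hand back the incomplete groups as leftovers.
    groups: dict[str, dict[int, torch.Tensor]] = {}
    totals: dict[str, int] = {}
    for key, tensor in state_dict.items():
        name, _, frac = key.partition(":")
        part, _, total = frac.partition("/")
        groups.setdefault(name, {})[int(part)] = tensor
        totals[name] = int(total)
    complete: dict[str, torch.Tensor] = {}
    leftover: dict[str, torch.Tensor] = {}
    for name, parts in groups.items():
        if len(parts) == totals[name]:
            complete[name] = torch.cat([parts[i] for i in sorted(parts)])
        else:
            for part, tensor in parts.items():
                leftover[f"{name}:{part}/{totals[name]}"] = tensor
    return complete, leftover


# Two "shards": 'b' is split across files, so it is only yielded after the
# second shard supplies its missing half.
shards = [
    {"a:0/1": torch.ones(2), "b:0/2": torch.zeros(2)},
    {"b:1/2": torch.zeros(2)},
]
leftover: dict[str, torch.Tensor] = {}
for shard in shards:
    shard.update(leftover)  # mirrors state_dict.update(leftover_state_dict)
    complete, leftover = unflatten(shard)
    for name, tensor in complete.items():
        print(name, tuple(tensor.shape))  # 'a' after shard 0, 'b' after shard 1
assert not leftover  # every tensor was eventually completed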