[torchao] fix safetensors for sharding (#28169)

Signed-off-by: Angel Li <liangel@meta.com>
Author: liangel-02
Date: 2025-11-19 19:39:45 -05:00 (committed by GitHub)
Parent: 9ccef8e333
Commit: 1d642872a2
3 changed files with 23 additions and 11 deletions


@@ -225,13 +225,12 @@ def test_reload_weights():
 @pytest.mark.skip(
     reason="since torchao nightly is only compatible with torch nightly"
     "currently https://github.com/pytorch/ao/issues/2919, we'll have to skip "
-    "torchao tests that requires newer versions (0.14.0.dev+) for now"
+    "torchao tests that requires newer versions (0.15.0.dev+) for now"
 )
-def test_opt_125m_float8_weight_only_safetensors_model_loading_with_params(vllm_runner):
+def test_safetensors_model_loading_with_params(vllm_runner):
     torch._dynamo.reset()
-    model_name = (
-        "torchao-testing/opt-125m-Float8WeightOnlyConfig-v2-0.14.0.dev-safetensors"
-    )
+    # using this model to test safetensors loading with file sharding
+    model_name = "torchao-testing/Qwen3-8B-INT4-0.15.0dev-safetensors"
     with vllm_runner(model_name=model_name, dtype="bfloat16") as llm:
         output = llm.generate_greedy(["The capital of France is"], max_tokens=4)
 
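Note on the new test model: it is a sharded safetensors export, meaning the weights are split across several model-XXXXX-of-XXXXX.safetensors files plus an index. A minimal sketch of walking such a checkpoint, assuming the standard HF shard naming rather than the actual file names in the test repo:

    # Sketch: iterate every tensor in a sharded safetensors checkpoint.
    # The model-*-of-*.safetensors pattern is the conventional HF shard
    # layout; the exact file names in the test repo are an assumption.
    from glob import glob

    from safetensors import safe_open

    for shard in sorted(glob("model-*-of-*.safetensors")):
        with safe_open(shard, framework="pt") as f:
            for name in f.keys():
                tensor = f.get_tensor(name)  # materialize one tensor from this shard
                print(shard, name, tuple(tensor.shape))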


@@ -279,7 +279,7 @@ class DefaultModelLoader(BaseModelLoader):
         if (
             hasattr(quant_config, "is_checkpoint_torchao_serialized")
             and quant_config.is_checkpoint_torchao_serialized
-            and torchao_version_at_least("0.14.0")
+            and torchao_version_at_least("0.15.0")
         ):
             self.load_config.safetensors_load_strategy = "torchao"
 
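The gate above now keys the "torchao" load strategy on torchao >= 0.15.0, the first version whose unflatten_tensor_state_dict returns the leftover tensor data used by the iterator change below. For illustration, a version check in the spirit of torchao_version_at_least could look like the following; this is a hypothetical stand-in, not vLLM's actual helper:

    # Hypothetical sketch of a torchao_version_at_least-style gate;
    # vLLM's real helper may differ.
    from importlib.metadata import PackageNotFoundError, version

    from packaging.version import parse


    def torchao_version_at_least(min_version: str) -> bool:
        try:
            installed = version("torchao")
        except PackageNotFoundError:
            return False
        # Caveat: per PEP 440 a dev pre-release such as 0.15.0.dev20251119
        # sorts *below* 0.15.0, so a gate that must accept nightlies would
        # compare against "0.15.0.dev0" or strip the dev tag first.
        return parse(installed) >= parse(min_version)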


@@ -595,6 +595,9 @@ def safetensors_weights_iterator(
     if safetensors_load_strategy == "eager":
         loading_desc += " (eager)"
 
+    state_dict = {}
+    leftover_state_dict: dict[str, torch.Tensor] = {}
+
     for st_file in tqdm(
         hf_weights_files,
         desc=loading_desc,
@@ -606,9 +609,11 @@
                 state_dict = load(f.read())
             yield from state_dict.items()
         elif safetensors_load_strategy == "torchao":
-            if not torchao_version_at_least("0.14.0"):
+            # we can't load flattened torchao tensor subclasses directly into the model
+            # instead we reconstruct the subclasses here before returning
+            if not torchao_version_at_least("0.15.0"):
                 raise ValueError(
-                    "Please use torchao version >= 0.14.0 \
+                    "Please use torchao version >= 0.15.0 \
                     to load torchao safetensors checkpoint"
                 )
             from torchao.prototype.safetensors.safetensors_support import (
@@ -616,12 +621,20 @@
             )
 
             with safe_open(st_file, framework="pt") as f:
-                state_dict = {}
                 for name in f.keys():  # noqa: SIM118
                     state_dict[name] = f.get_tensor(name)
+                # update with leftover tensor data from previous iteration, if any
+                state_dict.update(leftover_state_dict)
+
                 metadata = f.metadata()
-                updated_state_dict = unflatten_tensor_state_dict(state_dict, metadata)
-                yield from updated_state_dict.items()
+                # due to sharded checkpoints, we are not guaranteed that we have all
+                # tensor subclass data on one file
+                # state_dict has the leftover data from this step and we wait for
+                # missing information to be provided in a future iteration
+                unflattened_state_dict, leftover_state_dict = (
+                    unflatten_tensor_state_dict(state_dict, metadata)
+                )
+                yield from unflattened_state_dict.items()
         else:
             with safe_open(st_file, framework="pt") as f:
                 for name in f.keys():  # noqa: SIM118
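
The core of the fix is the leftover_state_dict carry: with a sharded checkpoint, the flattened pieces of one torchao tensor subclass (e.g. quantized data in one file, scales in another) are not guaranteed to land in the same shard, so unflatten_tensor_state_dict now returns whatever it could not reconstruct yet and the iterator feeds that into the next shard's state dict. A toy sketch of the same carry-forward pattern (the key scheme and completeness rule here are invented for illustration; torchao's real flattened layout is more involved):

    # Toy version of the carry-forward pattern: only yield a logical tensor
    # once all of its flattened pieces have arrived; carry partials forward.
    from collections import defaultdict

    REQUIRED_PIECES = {"qdata", "scale"}  # invented completeness rule


    def toy_unflatten(flat: dict) -> tuple[dict, dict]:
        grouped: dict[str, dict] = defaultdict(dict)
        for key, value in flat.items():
            name, piece = key.rsplit(":", 1)  # invented "<name>:<piece>" scheme
            grouped[name][piece] = value
        done, leftover = {}, {}
        for name, pieces in grouped.items():
            if set(pieces) >= REQUIRED_PIECES:
                done[name] = pieces  # fully reconstructed
            else:
                for piece, value in pieces.items():
                    leftover[f"{name}:{piece}"] = value  # re-flatten for next shard
        return done, leftover


    shards = [
        {"w1:qdata": 1, "w1:scale": 2, "w2:qdata": 3},  # shard 1: w2 incomplete
        {"w2:scale": 4},                                # shard 2 completes w2
    ]
    leftover: dict = {}
    for shard in shards:
        shard.update(leftover)  # merge leftovers from the previous shard
        done, leftover = toy_unflatten(shard)
        for name in done:
            print("reconstructed", name)  # w1 after shard 1, w2 after shard 2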