mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-11 00:44:57 +08:00
[torchao] fix safetensors for sharding (#28169)
Signed-off-by: Angel Li <liangel@meta.com>
This commit is contained in:
parent
9ccef8e333
commit
1d642872a2
@ -225,13 +225,12 @@ def test_reload_weights():
|
|||||||
@pytest.mark.skip(
|
@pytest.mark.skip(
|
||||||
reason="since torchao nightly is only compatible with torch nightly"
|
reason="since torchao nightly is only compatible with torch nightly"
|
||||||
"currently https://github.com/pytorch/ao/issues/2919, we'll have to skip "
|
"currently https://github.com/pytorch/ao/issues/2919, we'll have to skip "
|
||||||
"torchao tests that requires newer versions (0.14.0.dev+) for now"
|
"torchao tests that requires newer versions (0.15.0.dev+) for now"
|
||||||
)
|
)
|
||||||
def test_opt_125m_float8_weight_only_safetensors_model_loading_with_params(vllm_runner):
|
def test_safetensors_model_loading_with_params(vllm_runner):
|
||||||
torch._dynamo.reset()
|
torch._dynamo.reset()
|
||||||
model_name = (
|
# using this model to test safetensors loading with file sharding
|
||||||
"torchao-testing/opt-125m-Float8WeightOnlyConfig-v2-0.14.0.dev-safetensors"
|
model_name = "torchao-testing/Qwen3-8B-INT4-0.15.0dev-safetensors"
|
||||||
)
|
|
||||||
with vllm_runner(model_name=model_name, dtype="bfloat16") as llm:
|
with vllm_runner(model_name=model_name, dtype="bfloat16") as llm:
|
||||||
output = llm.generate_greedy(["The capital of France is"], max_tokens=4)
|
output = llm.generate_greedy(["The capital of France is"], max_tokens=4)
|
||||||
|
|
||||||
|
|||||||
@ -279,7 +279,7 @@ class DefaultModelLoader(BaseModelLoader):
|
|||||||
if (
|
if (
|
||||||
hasattr(quant_config, "is_checkpoint_torchao_serialized")
|
hasattr(quant_config, "is_checkpoint_torchao_serialized")
|
||||||
and quant_config.is_checkpoint_torchao_serialized
|
and quant_config.is_checkpoint_torchao_serialized
|
||||||
and torchao_version_at_least("0.14.0")
|
and torchao_version_at_least("0.15.0")
|
||||||
):
|
):
|
||||||
self.load_config.safetensors_load_strategy = "torchao"
|
self.load_config.safetensors_load_strategy = "torchao"
|
||||||
|
|
||||||
|
|||||||
@ -595,6 +595,9 @@ def safetensors_weights_iterator(
|
|||||||
if safetensors_load_strategy == "eager":
|
if safetensors_load_strategy == "eager":
|
||||||
loading_desc += " (eager)"
|
loading_desc += " (eager)"
|
||||||
|
|
||||||
|
state_dict = {}
|
||||||
|
leftover_state_dict: dict[str, torch.Tensor] = {}
|
||||||
|
|
||||||
for st_file in tqdm(
|
for st_file in tqdm(
|
||||||
hf_weights_files,
|
hf_weights_files,
|
||||||
desc=loading_desc,
|
desc=loading_desc,
|
||||||
@ -606,9 +609,11 @@ def safetensors_weights_iterator(
|
|||||||
state_dict = load(f.read())
|
state_dict = load(f.read())
|
||||||
yield from state_dict.items()
|
yield from state_dict.items()
|
||||||
elif safetensors_load_strategy == "torchao":
|
elif safetensors_load_strategy == "torchao":
|
||||||
if not torchao_version_at_least("0.14.0"):
|
# we can't load flattened torchao tensor subclasses directly into the model
|
||||||
|
# instead we reconstruct the subclasses here before returning
|
||||||
|
if not torchao_version_at_least("0.15.0"):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Please use torchao version >= 0.14.0 \
|
"Please use torchao version >= 0.15.0 \
|
||||||
to load torchao safetensors checkpoint"
|
to load torchao safetensors checkpoint"
|
||||||
)
|
)
|
||||||
from torchao.prototype.safetensors.safetensors_support import (
|
from torchao.prototype.safetensors.safetensors_support import (
|
||||||
@ -616,12 +621,20 @@ def safetensors_weights_iterator(
|
|||||||
)
|
)
|
||||||
|
|
||||||
with safe_open(st_file, framework="pt") as f:
|
with safe_open(st_file, framework="pt") as f:
|
||||||
state_dict = {}
|
|
||||||
for name in f.keys(): # noqa: SIM118
|
for name in f.keys(): # noqa: SIM118
|
||||||
state_dict[name] = f.get_tensor(name)
|
state_dict[name] = f.get_tensor(name)
|
||||||
|
|
||||||
|
# update with leftover tensor data from previous iteration, if any
|
||||||
|
state_dict.update(leftover_state_dict)
|
||||||
metadata = f.metadata()
|
metadata = f.metadata()
|
||||||
updated_state_dict = unflatten_tensor_state_dict(state_dict, metadata)
|
# due to sharded checkpoints, we are not guaranteed that we have all
|
||||||
yield from updated_state_dict.items()
|
# tensor subclass data on one file
|
||||||
|
# state_dict has the leftover data from this step and we wait for
|
||||||
|
# missing information to be provided in a future iteration
|
||||||
|
unflattened_state_dict, leftover_state_dict = (
|
||||||
|
unflatten_tensor_state_dict(state_dict, metadata)
|
||||||
|
)
|
||||||
|
yield from unflattened_state_dict.items()
|
||||||
else:
|
else:
|
||||||
with safe_open(st_file, framework="pt") as f:
|
with safe_open(st_file, framework="pt") as f:
|
||||||
for name in f.keys(): # noqa: SIM118
|
for name in f.keys(): # noqa: SIM118
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user