mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-15 20:04:59 +08:00
use get_tensor in safe_open (#1696)
This commit is contained in:
parent
edb305584b
commit
e946260cf3
@ -243,8 +243,8 @@ def hf_model_weights_iterator(
|
|||||||
for st_file in hf_weights_files:
|
for st_file in hf_weights_files:
|
||||||
with safe_open(st_file, framework="pt") as f:
|
with safe_open(st_file, framework="pt") as f:
|
||||||
for name in f.keys():
|
for name in f.keys():
|
||||||
param = f.get_slice(name)
|
param = f.get_tensor(name)
|
||||||
yield name, convert_pyslice_to_tensor(param)
|
yield name, param
|
||||||
else:
|
else:
|
||||||
for bin_file in hf_weights_files:
|
for bin_file in hf_weights_files:
|
||||||
state = torch.load(bin_file, map_location="cpu")
|
state = torch.load(bin_file, map_location="cpu")
|
||||||
@ -265,12 +265,7 @@ def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
|
|||||||
tensor first.
|
tensor first.
|
||||||
"""
|
"""
|
||||||
if not isinstance(x, torch.Tensor):
|
if not isinstance(x, torch.Tensor):
|
||||||
try:
|
x = x[:]
|
||||||
x = x[:]
|
|
||||||
except IndexError:
|
|
||||||
# IndexError happens when the tensor is empty.
|
|
||||||
# transformer.h.0.attn.masked_bias is empty in some gpt2 models.
|
|
||||||
return torch.Tensor()
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user