Update mz_gguf_loader.py
This commit is contained in:
parent
c5c136cb11
commit
e82e6ee3f7
@@ -24,12 +24,14 @@ def quantize_load_state_dict(model, state_dict, device="cpu", cublas_ops=False):
|
|||||||
try:
|
try:
|
||||||
from cublas_ops import cublas_half_matmul
|
from cublas_ops import cublas_half_matmul
|
||||||
linear_ops = cublas_half_matmul
|
linear_ops = cublas_half_matmul
|
||||||
|
setattr(model, "cublas_half_matmul", True)
|
||||||
print("Using cublas_ops")
|
print("Using cublas_ops")
|
||||||
except:
|
except:
|
||||||
raise ImportError("Install cublas_ops (https://github.com/aredden/torch-cublas-hgemm) to use cublas_ops")
|
raise ImportError("Install cublas_ops (https://github.com/aredden/torch-cublas-hgemm) to use cublas_ops")
|
||||||
else:
|
else:
|
||||||
linear_ops = F.linear
|
linear_ops = F.linear
|
||||||
pass
|
setattr(model, "cublas_half_matmul", False)
|
||||||
|
|
||||||
quant_keys = []
|
quant_keys = []
|
||||||
for key in state_dict.keys():
|
for key in state_dict.keys():
|
||||||
if key.endswith(".Q4_0_qweight"):
|
if key.endswith(".Q4_0_qweight"):
|
||||||
@@ -52,14 +54,6 @@ def quantize_load_state_dict(model, state_dict, device="cpu", cublas_ops=False):
|
|||||||
|
|
||||||
model.to_empty(device=device)
|
model.to_empty(device=device)
|
||||||
model.load_state_dict(state_dict, strict=False)
|
model.load_state_dict(state_dict, strict=False)
|
||||||
try:
|
|
||||||
if linear_ops == cublas_half_matmul:
|
|
||||||
setattr(model, "cublas_half_matmul", True)
|
|
||||||
else:
|
|
||||||
setattr(model, "cublas_half_matmul", False)
|
|
||||||
except:
|
|
||||||
setattr(model, "cublas_half_matmul", False)
|
|
||||||
pass
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user