Update mz_gguf_loader.py

This commit is contained in:
kijai 2024-10-26 18:10:45 +03:00
parent c5c136cb11
commit e82e6ee3f7

View File

@@ -24,12 +24,14 @@ def quantize_load_state_dict(model, state_dict, device="cpu", cublas_ops=False):
try:
from cublas_ops import cublas_half_matmul
linear_ops = cublas_half_matmul
setattr(model, "cublas_half_matmul", True)
print("Using cublas_ops")
except:
raise ImportError("Install cublas_ops (https://github.com/aredden/torch-cublas-hgemm) to use cublas_ops")
else:
linear_ops = F.linear
pass
setattr(model, "cublas_half_matmul", False)
quant_keys = []
for key in state_dict.keys():
if key.endswith(".Q4_0_qweight"):
@@ -52,14 +54,6 @@ def quantize_load_state_dict(model, state_dict, device="cpu", cublas_ops=False):
model.to_empty(device=device)
model.load_state_dict(state_dict, strict=False)
try:
if linear_ops == cublas_half_matmul:
setattr(model, "cublas_half_matmul", True)
else:
setattr(model, "cublas_half_matmul", False)
except:
setattr(model, "cublas_half_matmul", False)
pass
return model