From e82e6ee3f7e6adcc2aa009e8f46ea5dab3c6459f Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Sat, 26 Oct 2024 18:10:45 +0300 Subject: [PATCH] Update mz_gguf_loader.py --- mz_gguf_loader.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/mz_gguf_loader.py b/mz_gguf_loader.py index 8182896..f340f96 100644 --- a/mz_gguf_loader.py +++ b/mz_gguf_loader.py @@ -24,12 +24,14 @@ def quantize_load_state_dict(model, state_dict, device="cpu", cublas_ops=False): try: from cublas_ops import cublas_half_matmul linear_ops = cublas_half_matmul + setattr(model, "cublas_half_matmul", True) print("Using cublas_ops") except: raise ImportError("Install cublas_ops (https://github.com/aredden/torch-cublas-hgemm) to use cublas_ops") else: linear_ops = F.linear - pass + setattr(model, "cublas_half_matmul", False) + quant_keys = [] for key in state_dict.keys(): if key.endswith(".Q4_0_qweight"): @@ -52,14 +54,6 @@ def quantize_load_state_dict(model, state_dict, device="cpu", cublas_ops=False): model.to_empty(device=device) model.load_state_dict(state_dict, strict=False) - try: - if linear_ops == cublas_half_matmul: - setattr(model, "cublas_half_matmul", True) - else: - setattr(model, "cublas_half_matmul", False) - except: - setattr(model, "cublas_half_matmul", False) - pass return model