diff --git a/mz_gguf_loader.py b/mz_gguf_loader.py index 8182896..f340f96 100644 --- a/mz_gguf_loader.py +++ b/mz_gguf_loader.py @@ -24,12 +24,14 @@ def quantize_load_state_dict(model, state_dict, device="cpu", cublas_ops=False): try: from cublas_ops import cublas_half_matmul linear_ops = cublas_half_matmul + setattr(model, "cublas_half_matmul", True) print("Using cublas_ops") except: raise ImportError("Install cublas_ops (https://github.com/aredden/torch-cublas-hgemm) to use cublas_ops") else: linear_ops = F.linear - pass + setattr(model, "cublas_half_matmul", False) + quant_keys = [] for key in state_dict.keys(): if key.endswith(".Q4_0_qweight"): @@ -52,14 +54,6 @@ def quantize_load_state_dict(model, state_dict, device="cpu", cublas_ops=False): model.to_empty(device=device) model.load_state_dict(state_dict, strict=False) - try: - if linear_ops == cublas_half_matmul: - setattr(model, "cublas_half_matmul", True) - else: - setattr(model, "cublas_half_matmul", False) - except: - setattr(model, "cublas_half_matmul", False) - pass return model