diff --git a/mz_gguf_loader.py b/mz_gguf_loader.py index 9130c7e..a6fdf85 100644 --- a/mz_gguf_loader.py +++ b/mz_gguf_loader.py @@ -23,8 +23,10 @@ def quantize_load_state_dict(model, state_dict, device="cpu"): for key in state_dict.keys(): if key.endswith(".Q4_0_qweight"): quant_keys.append(key.replace(".Q4_0_qweight", "")) + qtype = "Q4_0" elif key.endswith(".Q8_0_qweight"): quant_keys.append(key.replace(".Q8_0_qweight", "")) + qtype = "Q8_0" for name, module in model.named_modules(): if name in quant_keys: @@ -32,7 +34,7 @@ def quantize_load_state_dict(model, state_dict, device="cpu"): q_linear = WQLinear_GGUF.from_linear( linear=module, device=device, - qtype="Q8_0", + qtype=qtype, ) set_op_by_name(model, name, q_linear)