From a51ca0e907641cf707eaee195fac89824393da4d Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Fri, 25 Oct 2024 01:28:09 +0300
Subject: [PATCH] Add GGUF_Q8_0

---
 mz_gguf_loader.py | 34 ++++++++++++++++++----------------
 nodes.py          |  3 ++-
 2 files changed, 20 insertions(+), 17 deletions(-)

diff --git a/mz_gguf_loader.py b/mz_gguf_loader.py
index 262f7df..b2044d3 100644
--- a/mz_gguf_loader.py
+++ b/mz_gguf_loader.py
@@ -19,18 +19,20 @@ class quantize_lazy_load():
 
 
 def quantize_load_state_dict(model, state_dict, device="cpu"):
-    Q4_0_qkey = []
+    quant_keys = []
     for key in state_dict.keys():
         if key.endswith(".Q4_0_qweight"):
-            Q4_0_qkey.append(key.replace(".Q4_0_qweight", ""))
+            quant_keys.append(key.replace(".Q4_0_qweight", ""))
+        elif key.endswith(".Q8_0_qweight"):
+            quant_keys.append(key.replace(".Q8_0_qweight", ""))
 
     for name, module in model.named_modules():
-        if name in Q4_0_qkey:
+        if name in quant_keys:
             #print(name)
             q_linear = WQLinear_GGUF.from_linear(
                 linear=module,
                 device=device,
-                qtype="Q4_0",
+                qtype="Q8_0",
             )
             set_op_by_name(model, name, q_linear)
 
@@ -122,6 +124,9 @@ class WQLinear_GGUF(nn.Module):
         if self.qtype == "Q4_0":
             x = F.linear(x, dequantize_blocks_Q4_0(
                 self.Q4_0_qweight, x.dtype), self.bias.to(x.dtype) if self.bias is not None else None)
+        elif self.qtype == "Q8_0":
+            x = F.linear(x, dequantize_blocks_Q8_0(
+                self.Q8_0_qweight, x.dtype), self.bias.to(x.dtype) if self.bias is not None else None)
         else:
             raise ValueError(f"Unknown qtype: {self.qtype}")
 
@@ -139,7 +144,7 @@ def quant_shape_to_byte_shape(shape, qtype) -> tuple[int, ...]:
     block_size, type_size = GGML_QUANT_SIZES[qtype]
     if shape[-1] % block_size != 0:
         raise ValueError(
-            f"Quantized tensor row size ({shape[-1]}) is not a multiple of Q4_0 block size ({block_size})")
+            f"Quantized tensor row size ({shape[-1]}) is not a multiple of {qtype} block size ({block_size})")
     return (*shape[:-1], shape[-1] // block_size * type_size)
 
 
@@ -148,17 +153,17 @@ def quant_shape_from_byte_shape(shape, qtype) -> tuple[int, ...]:
     block_size, type_size = GGML_QUANT_SIZES[qtype]
     if shape[-1] % type_size != 0:
         raise ValueError(
-            f"Quantized tensor bytes per row ({shape[-1]}) is not a multiple of Q4_0 type size ({type_size})")
+            f"Quantized tensor bytes per row ({shape[-1]}) is not a multiple of {qtype} type size ({type_size})")
     return (*shape[:-1], shape[-1] // type_size * block_size)
 
 
 GGML_QUANT_SIZES = {
     "Q4_0": (32, 2 + 16),
+    "Q8_0": (32, 2 + 32),
 }
 
-
-def dequantize_blocks_Q4_0(data, dtype=torch.float16):
-    block_size, type_size = GGML_QUANT_SIZES["Q4_0"]
+def dequantize_blocks_Q8_0(data, dtype=torch.float16):
+    block_size, type_size = GGML_QUANT_SIZES["Q8_0"]
 
     data = data.to(torch.uint8)
     shape = data.shape
@@ -173,17 +178,14 @@ def dequantize_blocks_Q4_0(data, dtype=torch.float16):
     n_blocks = blocks.shape[0]
 
     d, qs = split_block_dims(blocks, 2)
-    d = d.view(torch.float16)
+    d = d.view(torch.float16).to(torch.float32)
 
-    qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor(
-        [0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1))
-    qs = (qs & 0x0F).reshape((n_blocks, -1)).to(torch.int8) - 8
+    qs = qs.view(torch.int8).to(torch.float32)
 
     out = (d * qs)
     out = out.reshape(quant_shape_from_byte_shape(
         shape,
-        qtype="Q4_0",
+        qtype="Q8_0",
     )).to(dtype)
 
-    return out
-
+    return out
\ No newline at end of file
diff --git a/nodes.py b/nodes.py
index 76edd06..2cd1b3d 100644
--- a/nodes.py
+++ b/nodes.py
@@ -49,7 +49,8 @@ class DownloadAndLoadMochiModel:
                 [
                     "mochi_preview_dit_fp8_e4m3fn.safetensors",
                     "mochi_preview_dit_bf16.safetensors",
-                    "mochi_preview_dit_GGUF_Q4_0_v2.safetensors"
+                    "mochi_preview_dit_GGUF_Q4_0_v2.safetensors",
+                    "mochi_preview_dit_GGUF_Q8_0.safetensors",
                 ],
                 {"tooltip": "Downloads from 'https://huggingface.co/Kijai/Mochi_preview_comfy' to 'models/diffusion_models/mochi'", },