From ea5ee0b017c7fbfade7a7eb36e05523fc16e14de Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Sun, 10 Nov 2024 18:13:44 +0200
Subject: [PATCH] GGUF Q4 works

---
 model_loading.py  | 13 ++++++++-----
 mz_gguf_loader.py | 51 ++++++++++++++++++++++++++++++++++++++++++++++++------
 2 files changed, 50 insertions(+), 14 deletions(-)

diff --git a/model_loading.py b/model_loading.py
index d89a268..08218ca 100644
--- a/model_loading.py
+++ b/model_loading.py
@@ -206,7 +206,7 @@ class DownloadAndLoadCogVideoModel:
         if fp8_transformer == "enabled" or fp8_transformer == "fastmode":
             params_to_keep = {"patch_embed", "lora", "pos_embedding", "time_embedding"}
             if "1.5" in model:
-                params_to_keep = {"patch_embed", "lora", "pos_embedding", "time_embedding", "norm","ofs_embedding", "norm_final", "norm_out", "proj_out"}
+                params_to_keep.update({"norm1.linear.weight", "norm_k", "norm_q","ofs_embedding", "norm_final", "norm_out", "proj_out"})
             for name, param in transformer.named_parameters():
                 if not any(keyword in name for keyword in params_to_keep):
                     param.data = param.data.to(torch.float8_e4m3fn)
@@ -214,7 +214,7 @@ class DownloadAndLoadCogVideoModel:
             if fp8_transformer == "fastmode":
                 from .fp8_optimization import convert_fp8_linear
                 if "1.5" in model:
-                    params_to_keep.update({"ff"})
+                    params_to_keep.update({"ff"}) #otherwise NaNs
                 convert_fp8_linear(transformer, dtype, params_to_keep=params_to_keep)
 
         with open(scheduler_path) as f:
@@ -422,11 +422,11 @@ class DownloadAndLoadCogVideoGGUFModel:
             params_to_keep = {"patch_embed", "pos_embedding", "time_embedding"}
             cast_dtype = torch.float16
         elif "1_5" in model:
-            params_to_keep = {"patch_embed", "time_embedding", "ofs_embedding", "norm_final", "norm_out", "proj_out", "norm"}
+            params_to_keep = {"norm1.linear.weight", "patch_embed", "time_embedding", "ofs_embedding", "norm_final", "norm_out", "proj_out"}
             cast_dtype = torch.bfloat16
         for name, param in transformer.named_parameters():
             if not any(keyword in name for keyword in params_to_keep):
-                param.data = param.data.to(torch.bfloat16)
+                param.data = param.data.to(torch.float8_e4m3fn)
             else:
                 param.data = param.data.to(cast_dtype)
         #for name, param in transformer.named_parameters():
@@ -438,8 +438,11 @@ class DownloadAndLoadCogVideoGGUFModel:
         transformer.attention_mode = attention_mode
 
         if fp8_fastmode:
+            params_to_keep = {"patch_embed", "lora", "pos_embedding", "time_embedding"}
+            if "1.5" in model:
+                params_to_keep.update({"ff","norm1.linear.weight", "norm_k", "norm_q","ofs_embedding", "norm_final", "norm_out", "proj_out"})
             from .fp8_optimization import convert_fp8_linear
-            convert_fp8_linear(transformer, vae_dtype)
+            convert_fp8_linear(transformer, vae_dtype, params_to_keep=params_to_keep)
 
         if compile == "torch":
             # compilation
diff --git a/mz_gguf_loader.py b/mz_gguf_loader.py
index f5a6059..fd8c640 100644
--- a/mz_gguf_loader.py
+++ b/mz_gguf_loader.py
@@ -19,17 +19,21 @@ class quantize_lazy_load():
 
 
 def quantize_load_state_dict(model, state_dict, device="cpu"):
-    Q4_0_qkey = []
+    quant_keys = []
     for key in state_dict.keys():
         if key.endswith(".Q4_0_qweight"):
-            Q4_0_qkey.append(key.replace(".Q4_0_qweight", ""))
+            quant_keys.append(key.replace(".Q4_0_qweight", ""))
+            qtype = "Q4_0"
+        elif key.endswith(".Q8_0_qweight"):
+            quant_keys.append(key.replace(".Q8_0_qweight", ""))
+            qtype = "Q8_0"
 
     for name, module in model.named_modules():
-        if name in Q4_0_qkey:
+        if name in quant_keys:
             q_linear = WQLinear_GGUF.from_linear(
                 linear=module,
                 device=device,
-                qtype="Q4_0",
+                qtype=qtype,
             )
             set_op_by_name(model, name, q_linear)
 
@@ -117,14 +121,14 @@ class WQLinear_GGUF(nn.Module):
 
     @torch.no_grad()
     def forward(self, x):
-        # x = torch.matmul(x, dequantize_blocks_Q4_0(self.qweight))
         if self.qtype == "Q4_0":
-            x = F.linear(x, dequantize_blocks_Q4_0(
-                self.Q4_0_qweight, x.dtype), self.bias.to(x.dtype) if self.bias is not None else None)
+            dequant = dequantize_blocks_Q4_0(self.Q4_0_qweight, x.dtype)
+        elif self.qtype == "Q8_0":
+            dequant = dequantize_blocks_Q8_0(self.Q8_0_qweight, x.dtype)
         else:
             raise ValueError(f"Unknown qtype: {self.qtype}")
-
-        return x
+
+        return F.linear(x, dequant, bias=self.bias.to(x.dtype) if self.bias is not None else None)
 
 
 def split_block_dims(blocks, *args):
@@ -153,6 +157,7 @@ def quant_shape_from_byte_shape(shape, qtype) -> tuple[int, ...]:
 
 GGML_QUANT_SIZES = {
     "Q4_0": (32, 2 + 16),
+    "Q8_0": (32, 2 + 32),
 }
 
 
@@ -186,3 +191,31 @@ def dequantize_blocks_Q4_0(data, dtype=torch.float16):
     )).to(dtype)
     return out
 
+def dequantize_blocks_Q8_0(data, dtype=torch.float16):
+    block_size, type_size = GGML_QUANT_SIZES["Q8_0"]
+
+    data = data.to(torch.uint8)
+    shape = data.shape
+
+    rows = data.reshape(
+        (-1, data.shape[-1])
+    ).view(torch.uint8)
+
+    n_blocks = rows.numel() // type_size
+    blocks = data.reshape((n_blocks, type_size))
+
+    n_blocks = blocks.shape[0]
+
+    d, qs = split_block_dims(blocks, 2)
+    d = d.view(torch.float16).to(torch.float32)
+
+    qs = qs.view(torch.int8).to(torch.float32)
+
+    out = (d * qs)
+
+    out = out.reshape(quant_shape_from_byte_shape(
+        shape,
+        qtype="Q8_0",
+    )).to(dtype)
+    return out
+
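Note on the Q8_0 path added above: GGML Q8_0 packs 32 weights per block as a 2-byte float16 scale d followed by 32 signed int8 quants, which is where the new "Q8_0": (32, 2 + 32) entry in GGML_QUANT_SIZES comes from; dequantization is simply d * qs per block, exactly what dequantize_blocks_Q8_0 does over all blocks at once. The sketch below is a minimal, self-contained illustration of that block decode under those assumptions, not the repo's API; the function and variable names are hypothetical.

    import torch

    def dequant_q8_0_block(block_bytes: torch.Tensor) -> torch.Tensor:
        # block_bytes: uint8 tensor of shape (34,) = 2-byte fp16 scale + 32 int8 quants
        d = block_bytes[:2].view(torch.float16).to(torch.float32)   # per-block scale
        qs = block_bytes[2:].view(torch.int8).to(torch.float32)     # 32 quantized weights
        return d * qs                                               # dequantized fp32 weights

    # Round-trip example: quantize 32 random weights to a Q8_0 block and decode them back.
    w = torch.randn(32)
    scale = (w.abs().max() / 127.0).item()
    qs = torch.clamp(torch.round(w / scale), -127, 127).to(torch.int8)
    block = torch.cat([torch.tensor([scale], dtype=torch.float16).view(torch.uint8),
                       qs.view(torch.uint8)])
    print((w - dequant_q8_0_block(block)).abs().max())  # small quantization error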