From 4348d1ed2010203e3f07b630932a17a091d249f7 Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Sun, 27 Oct 2024 14:04:51 +0200
Subject: [PATCH] fix Q4 cublas ops

---
 mz_gguf_loader.py | 11 ++---------
 nodes.py          |  7 +++++--
 2 files changed, 7 insertions(+), 11 deletions(-)

diff --git a/mz_gguf_loader.py b/mz_gguf_loader.py
index f340f96..a01791c 100644
--- a/mz_gguf_loader.py
+++ b/mz_gguf_loader.py
@@ -135,20 +135,13 @@ class WQLinear_GGUF(nn.Module):
 
     @torch.no_grad()
     def forward(self, x):
-        # x = torch.matmul(x, dequantize_blocks_Q4_0(self.qweight))
         if self.qtype == "Q4_0":
-            x = F.linear(x, dequantize_blocks_Q4_0(
-                self.Q4_0_qweight, x.dtype), self.bias.to(x.dtype) if self.bias is not None else None)
+            dequant = dequantize_blocks_Q4_0(self.Q4_0_qweight, x.dtype)
         elif self.qtype == "Q8_0":
             dequant = dequantize_blocks_Q8_0(self.Q8_0_qweight, x.dtype)
-
-            #x = F.linear(x, dequant, self.bias.to(x.dtype) if self.bias is not None else None)
-            x = self.linear_ops(x, dequant, bias=self.bias.to(x.dtype) if self.bias is not None else None)
-
         else:
             raise ValueError(f"Unknown qtype: {self.qtype}")
-
-        return x
+        return self.linear_ops(x, dequant, bias=self.bias.to(x.dtype) if self.bias is not None else None)
 
 
 def split_block_dims(blocks, *args):
diff --git a/nodes.py b/nodes.py
index 0287e4b..f8abed5 100644
--- a/nodes.py
+++ b/nodes.py
@@ -169,6 +169,8 @@ class MochiModelLoader:
             "optional": {
                 "trigger": ("CONDITIONING", {"tooltip": "Dummy input for forcing execution order",}),
                 "compile_args": ("MOCHICOMPILEARGS", {"tooltip": "Optional torch.compile arguments",}),
+                "cublas_ops": ("BOOLEAN", {"tooltip": "tested on 4090, unsure of gpu requirements, enables faster linear ops from 'https://github.com/aredden/torch-cublas-hgemm'",}),
+
             },
         }
     RETURN_TYPES = ("MOCHIMODEL",)
@@ -176,7 +178,7 @@ class MochiModelLoader:
     FUNCTION = "loadmodel"
     CATEGORY = "MochiWrapper"
 
-    def loadmodel(self, model_name, precision, attention_mode, trigger=None, compile_args=None):
+    def loadmodel(self, model_name, precision, attention_mode, trigger=None, compile_args=None, cublas_ops=False):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
 
@@ -193,7 +195,8 @@ class MochiModelLoader:
             weight_dtype=dtype,
             fp8_fastmode = True if precision == "fp8_e4m3fn_fast" else False,
             attention_mode=attention_mode,
-            compile_args=compile_args
+            compile_args=compile_args,
+            cublas_ops=cublas_ops
        )

        return (model, )
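
Context for the mz_gguf_loader.py hunk: before this patch the Q4_0 branch called F.linear directly on the dequantized weight, so the cublas-backed self.linear_ops was only ever reached for Q8_0 layers. The patch makes both branches produce a dense `dequant` tensor and funnels them through a single linear_ops call, which is what lets the new cublas_ops toggle take effect for Q4_0 models. For reference, a minimal sketch of the two dequantizers, consistent with the GGUF block formats (per 32-element block: an fp16 scale d, with Q4_0 decoding as d * (q - 8) and Q8_0 as d * q); this illustrates the format rather than reproducing the repository's exact code, and it assumes blocks arrive as a contiguous (n_blocks, block_bytes) uint8 tensor:

import torch

def split_block_dims(blocks, *args):
    # Split the byte dimension of a (n_blocks, block_bytes) uint8 tensor
    # into the GGUF sub-fields; the last split takes the remaining bytes.
    n_max = blocks.shape[-1]
    dims = list(args) + [n_max - sum(args)]
    return torch.split(blocks, dims, dim=-1)

def dequantize_blocks_Q4_0(blocks, dtype=torch.float16):
    # Q4_0 block: 2-byte fp16 scale + 16 bytes packing 32 four-bit quants.
    # Byte j holds element j in its low nibble and element j+16 in its
    # high nibble; values decode as d * (q - 8).
    d, qs = split_block_dims(blocks, 2)
    d = d.view(torch.float16).to(dtype)          # (n_blocks, 1)
    lo = (qs & 0x0F).to(torch.int8) - 8          # elements 0..15
    hi = (qs >> 4).to(torch.int8) - 8            # elements 16..31
    q = torch.cat([lo, hi], dim=-1)              # (n_blocks, 32)
    return d * q.to(dtype)

def dequantize_blocks_Q8_0(blocks, dtype=torch.float16):
    # Q8_0 block: 2-byte fp16 scale + 32 signed 8-bit quants; d * q.
    d, qs = split_block_dims(blocks, 2)
    d = d.view(torch.float16).to(dtype)
    return d * qs.view(torch.int8).to(dtype)

# e.g. a (4096, 18) uint8 Q4_0 tensor decodes to (4096, 32) values,
# which the loader then reshapes back into the full weight matrix.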
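
On the nodes.py side, cublas_ops is a plain boolean that the loader threads through to the GGUF linear layers, where it decides what self.linear_ops points at. Below is one plausible wiring, a sketch only: the import path and call name (cublas_ops.cublas_half_matmul) are assumptions about the linked torch-cublas-hgemm package rather than a verified API, and the fp16 guard reflects that hgemm kernels are half-precision:

import torch
import torch.nn.functional as F

def make_linear_ops(use_cublas: bool):
    # Returns the callable installed as WQLinear_GGUF.linear_ops,
    # matching the forward() call signature: linear_ops(x, weight, bias=...).
    if use_cublas:
        try:
            # Assumed import/name for the package from
            # https://github.com/aredden/torch-cublas-hgemm; verify
            # against the actual library before relying on it.
            from cublas_ops import cublas_half_matmul
        except ImportError:
            use_cublas = False

    def linear_ops(x, weight, bias=None):
        if use_cublas and x.dtype == torch.float16:
            # Assumed to compute x @ weight.T the way F.linear does;
            # a real kernel may require flattening batched inputs to 2-D.
            out = cublas_half_matmul(x, weight)
            if bias is not None:
                out = out + bias
            return out
        # Portable fallback: plain PyTorch linear.
        return F.linear(x, weight, bias)

    return linear_ops

Keeping F.linear as the fallback makes the node's toggle purely opt-in, which matches the tooltip's hedging about GPU requirements.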