fix Q4 cublas ops

commit 4348d1ed20
parent d1155ad305
Author: kijai
Date: 2024-10-27 14:04:51 +02:00

2 changed files with 7 additions and 11 deletions

View File

@@ -135,20 +135,13 @@ class WQLinear_GGUF(nn.Module):
     @torch.no_grad()
     def forward(self, x):
-        # x = torch.matmul(x, dequantize_blocks_Q4_0(self.qweight))
         if self.qtype == "Q4_0":
-            x = F.linear(x, dequantize_blocks_Q4_0(
-                self.Q4_0_qweight, x.dtype), self.bias.to(x.dtype) if self.bias is not None else None)
+            dequant = dequantize_blocks_Q4_0(self.Q4_0_qweight, x.dtype)
         elif self.qtype == "Q8_0":
             dequant = dequantize_blocks_Q8_0(self.Q8_0_qweight, x.dtype)
-            #x = F.linear(x, dequant, self.bias.to(x.dtype) if self.bias is not None else None)
-            x = self.linear_ops(x, dequant, bias=self.bias.to(x.dtype) if self.bias is not None else None)
         else:
             raise ValueError(f"Unknown qtype: {self.qtype}")
-        return x
+        return self.linear_ops(x, dequant, bias=self.bias.to(x.dtype) if self.bias is not None else None)

 def split_block_dims(blocks, *args):
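
For reference, a minimal stand-in for the fixed forward() above, runnable on its own. The packed GGUF buffers and dequantize helpers are stubbed with a plain weight tensor, and linear_ops is assumed to fall back to F.linear when the cublas backend is not in use:

import torch
import torch.nn as nn
import torch.nn.functional as F

class WQLinearSketch(nn.Module):
    # Stand-in for WQLinear_GGUF: after this commit, both qtypes produce
    # `dequant` and share one self.linear_ops call, so the Q4_0 path no
    # longer bypasses the (possibly cublas-backed) linear op.
    def __init__(self, weight, bias=None, qtype="Q4_0"):
        super().__init__()
        self.qtype = qtype
        self.register_buffer("qweight", weight)  # real class stores packed GGUF blocks
        self.bias = bias
        self.linear_ops = F.linear  # swapped for a cublas-backed op when cublas_ops=True

    def dequantize(self, dtype):
        # stub for dequantize_blocks_Q4_0 / dequantize_blocks_Q8_0
        return self.qweight.to(dtype)

    @torch.no_grad()
    def forward(self, x):
        if self.qtype == "Q4_0":
            dequant = self.dequantize(x.dtype)
        elif self.qtype == "Q8_0":
            dequant = self.dequantize(x.dtype)
        else:
            raise ValueError(f"Unknown qtype: {self.qtype}")
        return self.linear_ops(x, dequant, bias=self.bias.to(x.dtype) if self.bias is not None else None)

layer = WQLinearSketch(torch.randn(8, 16))
print(layer(torch.randn(2, 16)).shape)  # torch.Size([2, 8])

Hoisting the shared call out of the branches is the actual fix here: before this change, Q4_0 always went through F.linear, so the cublas path was never hit for Q4_0 weights.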

View File

@@ -169,6 +169,8 @@ class MochiModelLoader:
             "optional": {
                 "trigger": ("CONDITIONING", {"tooltip": "Dummy input for forcing execution order",}),
                 "compile_args": ("MOCHICOMPILEARGS", {"tooltip": "Optional torch.compile arguments",}),
+                "cublas_ops": ("BOOLEAN", {"tooltip": "Enables faster linear ops from 'https://github.com/aredden/torch-cublas-hgemm'; tested on a 4090, GPU requirements unclear",}),
             },
         }
     RETURN_TYPES = ("MOCHIMODEL",)
@@ -176,7 +178,7 @@ class MochiModelLoader:
     FUNCTION = "loadmodel"
     CATEGORY = "MochiWrapper"

-    def loadmodel(self, model_name, precision, attention_mode, trigger=None, compile_args=None):
+    def loadmodel(self, model_name, precision, attention_mode, trigger=None, compile_args=None, cublas_ops=False):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
@@ -193,7 +195,8 @@ class MochiModelLoader:
             weight_dtype=dtype,
             fp8_fastmode = True if precision == "fp8_e4m3fn_fast" else False,
             attention_mode=attention_mode,
-            compile_args=compile_args
+            compile_args=compile_args,
+            cublas_ops=cublas_ops
         )
         return (model, )
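
Downstream, the flag presumably selects the callable stored as linear_ops on each quantized layer. A hedged sketch of that wiring only: make_linear_ops is an illustrative helper, and the cublas_half_matmul import and signature are assumptions about torch-cublas-hgemm, not confirmed API:

import torch.nn.functional as F

def make_linear_ops(cublas_ops: bool):
    # Pick the callable used as WQLinear_GGUF.linear_ops. Falls back to
    # F.linear unless the cublas backend is requested and importable.
    if cublas_ops:
        try:
            # hypothetical import: the linked repo ships a `cublas_ops`
            # package, but this symbol and signature are assumptions
            from cublas_ops import cublas_half_matmul
            def fast_linear(x, weight, bias=None):
                out = cublas_half_matmul(x, weight.t())  # assumed: raw fp16 GEMM, no fused bias
                return out if bias is None else out + bias
            return fast_linear
        except ImportError:
            print("torch-cublas-hgemm not installed, falling back to F.linear")
    return F.linear

Per the hunk above, loadmodel itself just forwards cublas_ops=cublas_ops into the model constructor alongside compile_args; the selection would happen inside the model wrapper.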