fix Q4 cublas ops

This commit is contained in:
kijai 2024-10-27 14:04:51 +02:00
parent d1155ad305
commit 4348d1ed20
2 changed files with 7 additions and 11 deletions

View File

@@ -135,20 +135,13 @@ class WQLinear_GGUF(nn.Module):
@torch.no_grad()
def forward(self, x):
# x = torch.matmul(x, dequantize_blocks_Q4_0(self.qweight))
if self.qtype == "Q4_0":
x = F.linear(x, dequantize_blocks_Q4_0(
self.Q4_0_qweight, x.dtype), self.bias.to(x.dtype) if self.bias is not None else None)
dequant = dequantize_blocks_Q4_0(self.Q4_0_qweight, x.dtype)
elif self.qtype == "Q8_0":
dequant = dequantize_blocks_Q8_0(self.Q8_0_qweight, x.dtype)
#x = F.linear(x, dequant, self.bias.to(x.dtype) if self.bias is not None else None)
x = self.linear_ops(x, dequant, bias=self.bias.to(x.dtype) if self.bias is not None else None)
else:
raise ValueError(f"Unknown qtype: {self.qtype}")
return x
return self.linear_ops(x, dequant, bias=self.bias.to(x.dtype) if self.bias is not None else None)
def split_block_dims(blocks, *args):

View File

@@ -169,6 +169,8 @@ class MochiModelLoader:
"optional": {
"trigger": ("CONDITIONING", {"tooltip": "Dummy input for forcing execution order",}),
"compile_args": ("MOCHICOMPILEARGS", {"tooltip": "Optional torch.compile arguments",}),
"cublas_ops": ("BOOLEAN", {"tooltip": "tested on 4090, unsure of gpu requirements, enables faster linear ops from 'https://github.com/aredden/torch-cublas-hgemm'",}),
},
}
RETURN_TYPES = ("MOCHIMODEL",)
@@ -176,7 +178,7 @@ class MochiModelLoader:
FUNCTION = "loadmodel"
CATEGORY = "MochiWrapper"
def loadmodel(self, model_name, precision, attention_mode, trigger=None, compile_args=None):
def loadmodel(self, model_name, precision, attention_mode, trigger=None, compile_args=None, cublas_ops=False):
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
@@ -193,7 +195,8 @@ class MochiModelLoader:
weight_dtype=dtype,
fp8_fastmode = True if precision == "fp8_e4m3fn_fast" else False,
attention_mode=attention_mode,
compile_args=compile_args
compile_args=compile_args,
cublas_ops=cublas_ops
)
return (model, )