fix Q4 cublas ops
This commit is contained in:
parent
d1155ad305
commit
4348d1ed20
@ -135,20 +135,13 @@ class WQLinear_GGUF(nn.Module):
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, x):
|
||||
# x = torch.matmul(x, dequantize_blocks_Q4_0(self.qweight))
|
||||
if self.qtype == "Q4_0":
|
||||
x = F.linear(x, dequantize_blocks_Q4_0(
|
||||
self.Q4_0_qweight, x.dtype), self.bias.to(x.dtype) if self.bias is not None else None)
|
||||
dequant = dequantize_blocks_Q4_0(self.Q4_0_qweight, x.dtype)
|
||||
elif self.qtype == "Q8_0":
|
||||
dequant = dequantize_blocks_Q8_0(self.Q8_0_qweight, x.dtype)
|
||||
|
||||
#x = F.linear(x, dequant, self.bias.to(x.dtype) if self.bias is not None else None)
|
||||
x = self.linear_ops(x, dequant, bias=self.bias.to(x.dtype) if self.bias is not None else None)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown qtype: {self.qtype}")
|
||||
|
||||
return x
|
||||
return self.linear_ops(x, dequant, bias=self.bias.to(x.dtype) if self.bias is not None else None)
|
||||
|
||||
|
||||
def split_block_dims(blocks, *args):
|
||||
|
||||
7
nodes.py
7
nodes.py
@ -169,6 +169,8 @@ class MochiModelLoader:
|
||||
"optional": {
|
||||
"trigger": ("CONDITIONING", {"tooltip": "Dummy input for forcing execution order",}),
|
||||
"compile_args": ("MOCHICOMPILEARGS", {"tooltip": "Optional torch.compile arguments",}),
|
||||
"cublas_ops": ("BOOLEAN", {"tooltip": "tested on 4090, unsure of gpu requirements, enables faster linear ops from'https://github.com/aredden/torch-cublas-hgemm'",}),
|
||||
|
||||
},
|
||||
}
|
||||
RETURN_TYPES = ("MOCHIMODEL",)
|
||||
@ -176,7 +178,7 @@ class MochiModelLoader:
|
||||
FUNCTION = "loadmodel"
|
||||
CATEGORY = "MochiWrapper"
|
||||
|
||||
def loadmodel(self, model_name, precision, attention_mode, trigger=None, compile_args=None):
|
||||
def loadmodel(self, model_name, precision, attention_mode, trigger=None, compile_args=None, cublas_ops=False):
|
||||
|
||||
device = mm.get_torch_device()
|
||||
offload_device = mm.unet_offload_device()
|
||||
@ -193,7 +195,8 @@ class MochiModelLoader:
|
||||
weight_dtype=dtype,
|
||||
fp8_fastmode = True if precision == "fp8_e4m3fn_fast" else False,
|
||||
attention_mode=attention_mode,
|
||||
compile_args=compile_args
|
||||
compile_args=compile_args,
|
||||
cublas_ops=cublas_ops
|
||||
)
|
||||
|
||||
return (model, )
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user