fix Q4 cublas ops
parent d1155ad305
commit 4348d1ed20
@@ -135,20 +135,13 @@ class WQLinear_GGUF(nn.Module):
 
     @torch.no_grad()
     def forward(self, x):
-        # x = torch.matmul(x, dequantize_blocks_Q4_0(self.qweight))
         if self.qtype == "Q4_0":
-            x = F.linear(x, dequantize_blocks_Q4_0(
-                self.Q4_0_qweight, x.dtype), self.bias.to(x.dtype) if self.bias is not None else None)
+            dequant = dequantize_blocks_Q4_0(self.Q4_0_qweight, x.dtype)
         elif self.qtype == "Q8_0":
             dequant = dequantize_blocks_Q8_0(self.Q8_0_qweight, x.dtype)
-
-            #x = F.linear(x, dequant, self.bias.to(x.dtype) if self.bias is not None else None)
-            x = self.linear_ops(x, dequant, bias=self.bias.to(x.dtype) if self.bias is not None else None)
-
         else:
             raise ValueError(f"Unknown qtype: {self.qtype}")
-        return x
+        return self.linear_ops(x, dequant, bias=self.bias.to(x.dtype) if self.bias is not None else None)
 
 
 def split_block_dims(blocks, *args):
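
In short, the old code sent Q4_0 through F.linear directly while only Q8_0 used self.linear_ops; after this change both branches just build a dequantized weight and a single return routes through self.linear_ops, so an accelerated linear op also covers Q4_0. A minimal toy sketch of that control flow, with dequant_q4/dequant_q8 and the free-standing forward as illustrative stand-ins rather than the repo's kernels:

import torch
import torch.nn.functional as F

# Stand-in dequantizers: the repo's real kernels are dequantize_blocks_Q4_0
# and dequantize_blocks_Q8_0; these just cast for the sake of the example.
def dequant_q4(qweight, dtype):
    return qweight.to(dtype)

def dequant_q8(qweight, dtype):
    return qweight.to(dtype)

def forward(x, qtype, qweight, bias=None, linear_ops=F.linear):
    # Each branch only produces the dequantized weight; the single call
    # through linear_ops below is what lets a faster op (e.g. a cublas
    # hgemm wrapper) apply to Q4_0 as well as Q8_0.
    if qtype == "Q4_0":
        dequant = dequant_q4(qweight, x.dtype)
    elif qtype == "Q8_0":
        dequant = dequant_q8(qweight, x.dtype)
    else:
        raise ValueError(f"Unknown qtype: {qtype}")
    return linear_ops(x, dequant, bias=bias)

# quick check with plain F.linear as the op
x = torch.randn(2, 8)
w = torch.randn(4, 8)  # pretend this is the stored qweight
print(forward(x, "Q4_0", w).shape)  # torch.Size([2, 4])
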
nodes.py
@@ -169,6 +169,8 @@ class MochiModelLoader:
             "optional": {
                 "trigger": ("CONDITIONING", {"tooltip": "Dummy input for forcing execution order",}),
                 "compile_args": ("MOCHICOMPILEARGS", {"tooltip": "Optional torch.compile arguments",}),
+                "cublas_ops": ("BOOLEAN", {"tooltip": "tested on 4090, unsure of gpu requirements, enables faster linear ops from 'https://github.com/aredden/torch-cublas-hgemm'",}),
+
             },
         }
     RETURN_TYPES = ("MOCHIMODEL",)
@@ -176,7 +178,7 @@ class MochiModelLoader:
     FUNCTION = "loadmodel"
     CATEGORY = "MochiWrapper"
 
-    def loadmodel(self, model_name, precision, attention_mode, trigger=None, compile_args=None):
+    def loadmodel(self, model_name, precision, attention_mode, trigger=None, compile_args=None, cublas_ops=False):
 
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
@@ -193,7 +195,8 @@ class MochiModelLoader:
             weight_dtype=dtype,
             fp8_fastmode = True if precision == "fp8_e4m3fn_fast" else False,
             attention_mode=attention_mode,
-            compile_args=compile_args
+            compile_args=compile_args,
+            cublas_ops=cublas_ops
         )
 
         return (model, )
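
For context on the new cublas_ops toggle in nodes.py, a hedged sketch of how the boolean could be turned into the linear_ops callable used in forward(); the cublas_ops import is an assumption about torch-cublas-hgemm's module name (not confirmed by this diff), and the cublas branch is left as a placeholder with plain F.linear as the fallback:

import torch.nn.functional as F

def select_linear_ops(cublas_ops: bool = False):
    # Sketch only: keep the (x, weight, bias) signature of F.linear so the
    # op can be swapped without touching forward(). The import below is an
    # assumed entry point for torch-cublas-hgemm; if it is missing we fall
    # back to plain F.linear.
    if cublas_ops:
        try:
            import cublas_ops  # assumed module name

            def linear_ops(x, weight, bias=None):
                # placeholder: a real integration would call the package's
                # half-precision GEMM here instead of F.linear
                return F.linear(x, weight, bias)

            return linear_ops
        except ImportError:
            pass
    return F.linear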