mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2025-12-15 15:54:38 +08:00
Update nodes.py
This commit is contained in:
parent
44620cb566
commit
8f057eb563
@ -2172,11 +2172,6 @@ class CheckpointLoaderKJ:
|
|||||||
def patch(self, ckpt_name, patch_cublaslinear, sage_attention):
|
def patch(self, ckpt_name, patch_cublaslinear, sage_attention):
|
||||||
from comfy.ops import disable_weight_init, CastWeightBiasOp, cast_bias_weight
|
from comfy.ops import disable_weight_init, CastWeightBiasOp, cast_bias_weight
|
||||||
from nodes import CheckpointLoaderSimple
|
from nodes import CheckpointLoaderSimple
|
||||||
if patch_cublaslinear:
|
|
||||||
try:
|
|
||||||
from cublas_ops import CublasLinear
|
|
||||||
except ImportError:
|
|
||||||
raise Exception("Can't import 'torch-cublas-hgemm', install it from here https://github.com/aredden/torch-cublas-hgemm")
|
|
||||||
|
|
||||||
if sage_attention:
|
if sage_attention:
|
||||||
from sageattention import sageattn
|
from sageattention import sageattn
|
||||||
@ -2214,22 +2209,25 @@ class CheckpointLoaderKJ:
|
|||||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||||
else:
|
else:
|
||||||
return super().forward(*args, **kwargs)
|
return super().forward(*args, **kwargs)
|
||||||
|
|
||||||
class PatchedLinear(CublasLinear, CastWeightBiasOp):
|
|
||||||
def reset_parameters(self):
|
|
||||||
return None
|
|
||||||
|
|
||||||
def forward_comfy_cast_weights(self, input):
|
|
||||||
weight, bias = cast_bias_weight(self, input)
|
|
||||||
return torch.nn.functional.linear(input, weight, bias)
|
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
|
||||||
if self.comfy_cast_weights:
|
|
||||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
|
||||||
else:
|
|
||||||
return super().forward(*args, **kwargs)
|
|
||||||
|
|
||||||
if patch_cublaslinear:
|
if patch_cublaslinear:
|
||||||
|
try:
|
||||||
|
from cublas_ops import CublasLinear
|
||||||
|
except ImportError:
|
||||||
|
raise Exception("Can't import 'torch-cublas-hgemm', install it from here https://github.com/aredden/torch-cublas-hgemm")
|
||||||
|
class PatchedLinear(CublasLinear, CastWeightBiasOp):
|
||||||
|
def reset_parameters(self):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def forward_comfy_cast_weights(self, input):
|
||||||
|
weight, bias = cast_bias_weight(self, input)
|
||||||
|
return torch.nn.functional.linear(input, weight, bias)
|
||||||
|
|
||||||
|
def forward(self, *args, **kwargs):
|
||||||
|
if self.comfy_cast_weights:
|
||||||
|
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||||
|
else:
|
||||||
|
return super().forward(*args, **kwargs)
|
||||||
disable_weight_init.Linear = PatchedLinear
|
disable_weight_init.Linear = PatchedLinear
|
||||||
else:
|
else:
|
||||||
disable_weight_init.Linear = OriginalLinear
|
disable_weight_init.Linear = OriginalLinear
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user