Update nodes.py

This commit is contained in:
kijai 2024-11-22 17:15:50 +02:00
parent 44620cb566
commit 8f057eb563

View File

@ -2172,11 +2172,6 @@ class CheckpointLoaderKJ:
def patch(self, ckpt_name, patch_cublaslinear, sage_attention):
from comfy.ops import disable_weight_init, CastWeightBiasOp, cast_bias_weight
from nodes import CheckpointLoaderSimple
if patch_cublaslinear:
try:
from cublas_ops import CublasLinear
except ImportError:
raise Exception("Can't import 'torch-cublas-hgemm', install it from here https://github.com/aredden/torch-cublas-hgemm")
if sage_attention:
from sageattention import sageattn
@ -2214,22 +2209,25 @@ class CheckpointLoaderKJ:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
class PatchedLinear(CublasLinear, CastWeightBiasOp):
def reset_parameters(self):
return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.linear(input, weight, bias)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
if patch_cublaslinear:
try:
from cublas_ops import CublasLinear
except ImportError:
raise Exception("Can't import 'torch-cublas-hgemm', install it from here https://github.com/aredden/torch-cublas-hgemm")
class PatchedLinear(CublasLinear, CastWeightBiasOp):
def reset_parameters(self):
return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.linear(input, weight, bias)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
disable_weight_init.Linear = PatchedLinear
else:
disable_weight_init.Linear = OriginalLinear