diff --git a/nodes/nodes.py b/nodes/nodes.py
index 7422db3..68e79ab 100644
--- a/nodes/nodes.py
+++ b/nodes/nodes.py
@@ -2172,11 +2172,6 @@ class CheckpointLoaderKJ:
     def patch(self, ckpt_name, patch_cublaslinear, sage_attention):
         from comfy.ops import disable_weight_init, CastWeightBiasOp, cast_bias_weight
         from nodes import CheckpointLoaderSimple
-        if patch_cublaslinear:
-            try:
-                from cublas_ops import CublasLinear
-            except ImportError:
-                raise Exception("Can't import 'torch-cublas-hgemm', install it from here https://github.com/aredden/torch-cublas-hgemm")
 
         if sage_attention:
             from sageattention import sageattn
@@ -2214,22 +2209,25 @@ class CheckpointLoaderKJ:
                     return self.forward_comfy_cast_weights(*args, **kwargs)
                 else:
                     return super().forward(*args, **kwargs)
-
-        class PatchedLinear(CublasLinear, CastWeightBiasOp):
-            def reset_parameters(self):
-                return None
-
-            def forward_comfy_cast_weights(self, input):
-                weight, bias = cast_bias_weight(self, input)
-                return torch.nn.functional.linear(input, weight, bias)
-
-            def forward(self, *args, **kwargs):
-                if self.comfy_cast_weights:
-                    return self.forward_comfy_cast_weights(*args, **kwargs)
-                else:
-                    return super().forward(*args, **kwargs)
 
         if patch_cublaslinear:
+            try:
+                from cublas_ops import CublasLinear
+            except ImportError:
+                raise Exception("Can't import 'torch-cublas-hgemm', install it from here https://github.com/aredden/torch-cublas-hgemm")
+            class PatchedLinear(CublasLinear, CastWeightBiasOp):
+                def reset_parameters(self):
+                    return None
+
+                def forward_comfy_cast_weights(self, input):
+                    weight, bias = cast_bias_weight(self, input)
+                    return torch.nn.functional.linear(input, weight, bias)
+
+                def forward(self, *args, **kwargs):
+                    if self.comfy_cast_weights:
+                        return self.forward_comfy_cast_weights(*args, **kwargs)
+                    else:
+                        return super().forward(*args, **kwargs)
             disable_weight_init.Linear = PatchedLinear
         else:
             disable_weight_init.Linear = OriginalLinear
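The diff moves both the `cublas_ops` import and the `PatchedLinear` class (which inherits from `CublasLinear`) inside the `if patch_cublaslinear:` branch, so the optional `torch-cublas-hgemm` dependency is only touched when the option is actually enabled. Below is a minimal, self-contained sketch of that guarded-import pattern; the function name `select_linear_class` and the flag name are hypothetical and stand in for the node's real structure.

```python
import torch


def select_linear_class(use_cublas: bool):
    """Hypothetical sketch: pick the Linear class to patch in,
    importing the optional backend only when the flag is on."""
    if use_cublas:
        try:
            # Optional dependency: only imported when requested.
            from cublas_ops import CublasLinear
        except ImportError:
            raise Exception(
                "Can't import 'torch-cublas-hgemm', install it from here "
                "https://github.com/aredden/torch-cublas-hgemm"
            )

        # Defined inside the branch because its base class only exists
        # when the import above succeeds.
        class PatchedLinear(CublasLinear):
            pass

        return PatchedLinear
    return torch.nn.Linear
```

Keeping the class definition next to the import means nothing referencing `CublasLinear` is evaluated unless the flag is set, so the ImportError surfaces only for users who opt into the cuBLAS path rather than on every call.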