diff --git a/nodes/nodes.py b/nodes/nodes.py index 68e79ab..df5e951 100644 --- a/nodes/nodes.py +++ b/nodes/nodes.py @@ -1981,7 +1981,7 @@ class FluxBlockLoraLoader: keys_to_delete = [] for block in blocks: - for key in list(loaded.keys()): # Convert keys to a list to avoid runtime error due to size change + for key in list(loaded.keys()): match = False if isinstance(key, str) and block in key: match = True @@ -1994,40 +1994,33 @@ class FluxBlockLoraLoader: if match: ratio = blocks[block] if ratio == 0: - keys_to_delete.append(key) # Collect keys to delete + keys_to_delete.append(key) else: value = loaded[key] if isinstance(value, tuple) and len(value) > 1 and isinstance(value[1], tuple): - # Handle the tuple format - if len(value[1]) > 3: - loaded[key] = (value[0], value[1][:-3] + (ratio, value[1][-2], value[1][-1])) - else: - loaded[key] = (value[0], value[1][:-2] + (ratio, value[1][-1])) + inner_tuple = value[1] + if len(inner_tuple) >= 3: + inner_tuple = (inner_tuple[0], inner_tuple[1], ratio, *inner_tuple[3:]) + loaded[key] = (value[0], inner_tuple) else: - # Handle the simpler format directly loaded[key] = (value[0], ratio) - # Now perform the deletion of keys for key in keys_to_delete: del loaded[key] print("loading lora keys:") for key, value in loaded.items(): if isinstance(value, tuple) and len(value) > 1 and isinstance(value[1], tuple): - # Handle the tuple format - if len(value[1]) > 2: - alpha = value[1][-3] # Assuming the alpha value is the third last element in the tuple - else: - alpha = value[1][-2] # Adjust according to the second format's structure + inner_tuple = value[1] + alpha = inner_tuple[2] if len(inner_tuple) >= 3 else None else: - # Handle the simpler format directly alpha = value[1] if len(value) > 1 else None print(f"Key: {key}, Alpha: {alpha}") - if model is not None: - new_modelpatcher = model.clone() - k = new_modelpatcher.add_patches(loaded, strength_model) + if model is not None: + new_modelpatcher = model.clone() + k = new_modelpatcher.add_patches(loaded, strength_model) k = set(k) for x in loaded: @@ -2196,41 +2189,46 @@ class CheckpointLoaderKJ: .reshape(b, -1, heads * dim_head) ) - class OriginalLinear(torch.nn.Linear, CastWeightBiasOp): - def reset_parameters(self): - return None + # class OriginalLinear(torch.nn.Linear, CastWeightBiasOp): + # def reset_parameters(self): + # return None - def forward_comfy_cast_weights(self, input): - weight, bias = cast_bias_weight(self, input) - return torch.nn.functional.linear(input, weight, bias) - - def forward(self, *args, **kwargs): - if self.comfy_cast_weights: - return self.forward_comfy_cast_weights(*args, **kwargs) - else: - return super().forward(*args, **kwargs) + # def forward_comfy_cast_weights(self, input): + # weight, bias = cast_bias_weight(self, input) + # return torch.nn.functional.linear(input, weight, bias) + # def forward(self, *args, **kwargs): + # if self.comfy_cast_weights: + # return self.forward_comfy_cast_weights(*args, **kwargs) + # else: + # return super().forward(*args, **kwargs) + cublas_patched = False if patch_cublaslinear: - try: - from cublas_ops import CublasLinear - except ImportError: - raise Exception("Can't import 'torch-cublas-hgemm', install it from here https://github.com/aredden/torch-cublas-hgemm") - class PatchedLinear(CublasLinear, CastWeightBiasOp): - def reset_parameters(self): - return None + if not cublas_patched: + original_linear = disable_weight_init.Linear + try: + from cublas_ops import CublasLinear + except ImportError: + raise Exception("Can't import 'torch-cublas-hgemm', install it from here https://github.com/aredden/torch-cublas-hgemm") + class PatchedLinear(CublasLinear, CastWeightBiasOp): + def reset_parameters(self): + return None - def forward_comfy_cast_weights(self, input): - weight, bias = cast_bias_weight(self, input) - return torch.nn.functional.linear(input, weight, bias) + def forward_comfy_cast_weights(self, input): + weight, bias = cast_bias_weight(self, input) + return torch.nn.functional.linear(input, weight, bias) + + def forward(self, *args, **kwargs): + if self.comfy_cast_weights: + return self.forward_comfy_cast_weights(*args, **kwargs) + else: + return super().forward(*args, **kwargs) + disable_weight_init.Linear = PatchedLinear + cublas_patched = True + else: + disable_weight_init.Linear = original_linear + cublas_patched = False - def forward(self, *args, **kwargs): - if self.comfy_cast_weights: - return self.forward_comfy_cast_weights(*args, **kwargs) - else: - return super().forward(*args, **kwargs) - disable_weight_init.Linear = PatchedLinear - else: - disable_weight_init.Linear = OriginalLinear if sage_attention: comfy_attention.optimized_attention = attention_sage else: @@ -2238,7 +2236,6 @@ class CheckpointLoaderKJ: model, clip, vae = CheckpointLoaderSimple.load_checkpoint(self, ckpt_name) - return model, clip, vae import comfy.model_patcher @@ -2509,8 +2506,8 @@ class TorchCompileLTXModel: if not self._compiled: try: for i, block in enumerate(diffusion_model.transformer_blocks): - #print("Compiling double_block", i) - m.add_object_patch(f"diffusion_model.transformer_blocks.{i}", torch.compile(block, mode=mode, dynamic=dynamic, fullgraph=fullgraph, backend=backend)) + compiled_block = torch.compile(block, mode=mode, dynamic=dynamic, fullgraph=fullgraph, backend=backend) + m.add_object_patch(f"diffusion_model.transformer_blocks.{i}", compiled_block) self._compiled = True compile_settings = { "backend": backend, @@ -2519,8 +2516,9 @@ class TorchCompileLTXModel: "dynamic": dynamic, } setattr(m.model, "compile_settings", compile_settings) + except: - raise RuntimeError("Failed to compile model") + raise RuntimeError("Failed to compile model") return (m, )