Mirror of https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git, synced 2025-12-09 12:54:40 +08:00
Fix FluxBlockLoraLoader
parent 8f057eb563
commit 5920419f44
nodes/nodes.py | 102 changed lines

--- a/nodes/nodes.py
+++ b/nodes/nodes.py
@@ -1981,7 +1981,7 @@ class FluxBlockLoraLoader:
         keys_to_delete = []

         for block in blocks:
-            for key in list(loaded.keys()): # Convert keys to a list to avoid runtime error due to size change
+            for key in list(loaded.keys()):
                 match = False
                 if isinstance(key, str) and block in key:
                     match = True
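Side note, not part of the commit: the deleted comment above referred to the standard guard against mutating a dict while iterating it. A tiny illustration of why the list() snapshot matters:

d = {"a": 1, "b": 2}
for key in list(d.keys()):  # snapshot the keys so d can shrink inside the loop
    if d[key] == 1:
        del d[key]
print(d)  # {'b': 2}
# Iterating d.keys() directly while deleting raises:
# RuntimeError: dictionary changed size during iteration

The new code keeps the snapshot even though deletions are now deferred via keys_to_delete, which is harmless.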
@@ -1994,40 +1994,33 @@ class FluxBlockLoraLoader:
                 if match:
                     ratio = blocks[block]
                     if ratio == 0:
-                        keys_to_delete.append(key) # Collect keys to delete
+                        keys_to_delete.append(key)
                     else:
                         value = loaded[key]
                         if isinstance(value, tuple) and len(value) > 1 and isinstance(value[1], tuple):
-                            # Handle the tuple format
-                            if len(value[1]) > 3:
-                                loaded[key] = (value[0], value[1][:-3] + (ratio, value[1][-2], value[1][-1]))
-                            else:
-                                loaded[key] = (value[0], value[1][:-2] + (ratio, value[1][-1]))
+                            inner_tuple = value[1]
+                            if len(inner_tuple) >= 3:
+                                inner_tuple = (inner_tuple[0], inner_tuple[1], ratio, *inner_tuple[3:])
+                            loaded[key] = (value[0], inner_tuple)
                         else:
-                            # Handle the simpler format directly
                             loaded[key] = (value[0], ratio)

-        # Now perform the deletion of keys
         for key in keys_to_delete:
             del loaded[key]

         print("loading lora keys:")
         for key, value in loaded.items():
             if isinstance(value, tuple) and len(value) > 1 and isinstance(value[1], tuple):
-                # Handle the tuple format
-                if len(value[1]) > 2:
-                    alpha = value[1][-3] # Assuming the alpha value is the third last element in the tuple
-                else:
-                    alpha = value[1][-2] # Adjust according to the second format's structure
+                inner_tuple = value[1]
+                alpha = inner_tuple[2] if len(inner_tuple) >= 3 else None
             else:
-                # Handle the simpler format directly
                 alpha = value[1] if len(value) > 1 else None
             print(f"Key: {key}, Alpha: {alpha}")


         if model is not None:
             new_modelpatcher = model.clone()
             k = new_modelpatcher.add_patches(loaded, strength_model)

         k = set(k)
         for x in loaded:
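For orientation, not part of the diff: the values in loaded are ComfyUI LoRA patch entries of the form (patch_type, data_tuple), and the rewritten code assumes the per-key strength sits at index 2 of the inner tuple. A standalone sketch of the new rewrite logic with made-up placeholder entries (the key names and tuple contents here are illustrative, not ComfyUI's real data):

loaded = {
    "double_blocks.0.img_attn.qkv": ("lora", ("up", "down", 1.0, None)),
    "single_blocks.3.linear1": ("lora", ("up", "down", 1.0, None)),
}
blocks = {"double_blocks.0.": 0.5, "single_blocks.3.": 0.0}

keys_to_delete = []
for block in blocks:
    for key in list(loaded.keys()):
        if isinstance(key, str) and block in key:
            ratio = blocks[block]
            if ratio == 0:
                keys_to_delete.append(key)  # strength 0 drops the patch entirely
            else:
                patch_type, inner_tuple = loaded[key]
                if len(inner_tuple) >= 3:
                    # overwrite only the strength slot at index 2, keep the rest
                    inner_tuple = (inner_tuple[0], inner_tuple[1], ratio, *inner_tuple[3:])
                loaded[key] = (patch_type, inner_tuple)

for key in keys_to_delete:
    del loaded[key]

print(loaded)
# {'double_blocks.0.img_attn.qkv': ('lora', ('up', 'down', 0.5, None))}

Compared with the old slicing arithmetic (value[1][:-3] + ...), rebuilding the tuple around a fixed index is insensitive to the inner tuple's total length, which appears to be the point of the fix.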
@@ -2196,41 +2189,46 @@ class CheckpointLoaderKJ:
             .reshape(b, -1, heads * dim_head)
         )

-    class OriginalLinear(torch.nn.Linear, CastWeightBiasOp):
-        def reset_parameters(self):
-            return None
+    # class OriginalLinear(torch.nn.Linear, CastWeightBiasOp):
+    #     def reset_parameters(self):
+    #         return None

-        def forward_comfy_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return torch.nn.functional.linear(input, weight, bias)
-
-        def forward(self, *args, **kwargs):
-            if self.comfy_cast_weights:
-                return self.forward_comfy_cast_weights(*args, **kwargs)
-            else:
-                return super().forward(*args, **kwargs)
+    #     def forward_comfy_cast_weights(self, input):
+    #         weight, bias = cast_bias_weight(self, input)
+    #         return torch.nn.functional.linear(input, weight, bias)
+
+    #     def forward(self, *args, **kwargs):
+    #         if self.comfy_cast_weights:
+    #             return self.forward_comfy_cast_weights(*args, **kwargs)
+    #         else:
+    #             return super().forward(*args, **kwargs)

+    cublas_patched = False
     if patch_cublaslinear:
-        try:
-            from cublas_ops import CublasLinear
-        except ImportError:
-            raise Exception("Can't import 'torch-cublas-hgemm', install it from here https://github.com/aredden/torch-cublas-hgemm")
-        class PatchedLinear(CublasLinear, CastWeightBiasOp):
-            def reset_parameters(self):
-                return None
+        if not cublas_patched:
+            original_linear = disable_weight_init.Linear
+            try:
+                from cublas_ops import CublasLinear
+            except ImportError:
+                raise Exception("Can't import 'torch-cublas-hgemm', install it from here https://github.com/aredden/torch-cublas-hgemm")
+            class PatchedLinear(CublasLinear, CastWeightBiasOp):
+                def reset_parameters(self):
+                    return None

-            def forward_comfy_cast_weights(self, input):
-                weight, bias = cast_bias_weight(self, input)
-                return torch.nn.functional.linear(input, weight, bias)
+                def forward_comfy_cast_weights(self, input):
+                    weight, bias = cast_bias_weight(self, input)
+                    return torch.nn.functional.linear(input, weight, bias)

-            def forward(self, *args, **kwargs):
-                if self.comfy_cast_weights:
-                    return self.forward_comfy_cast_weights(*args, **kwargs)
-                else:
-                    return super().forward(*args, **kwargs)
-        disable_weight_init.Linear = PatchedLinear
+                def forward(self, *args, **kwargs):
+                    if self.comfy_cast_weights:
+                        return self.forward_comfy_cast_weights(*args, **kwargs)
+                    else:
+                        return super().forward(*args, **kwargs)
+            disable_weight_init.Linear = PatchedLinear
+            cublas_patched = True
     else:
-        disable_weight_init.Linear = OriginalLinear
+        disable_weight_init.Linear = original_linear
+        cublas_patched = False
     if sage_attention:
         comfy_attention.optimized_attention = attention_sage
     else:
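Aside, not part of the diff: instead of keeping a hand-rolled OriginalLinear to restore on toggle-off, the new code stashes whatever disable_weight_init.Linear currently is in original_linear before swapping in PatchedLinear, guarded by a cublas_patched flag. A minimal, self-contained sketch of that save-and-restore monkey-patching pattern with stand-in classes (no ComfyUI or cublas_ops dependency):

class Registry:  # stands in for comfy.ops.disable_weight_init
    class Linear:
        name = "stock"

class PatchedLinear:  # stands in for the CublasLinear-backed replacement
    name = "patched"

original_linear = None  # saved reference, like the commit's original_linear
cublas_patched = False  # same role as the commit's flag

def set_patch(enable):
    global original_linear, cublas_patched
    if enable and not cublas_patched:
        original_linear = Registry.Linear  # save the stock class first
        Registry.Linear = PatchedLinear    # then swap in the replacement
        cublas_patched = True
    elif not enable and cublas_patched:
        Registry.Linear = original_linear  # restore exactly what was saved
        cublas_patched = False

set_patch(True)
assert Registry.Linear.name == "patched"
set_patch(False)
assert Registry.Linear.name == "stock"

Guarding the restore on the flag avoids referencing original_linear before anything was saved; in the diff itself, original_linear is assigned inside the if branch, so the restore path relies on the patch path having run first.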
@@ -2238,7 +2236,6 @@ class CheckpointLoaderKJ:

         model, clip, vae = CheckpointLoaderSimple.load_checkpoint(self, ckpt_name)
-

         return model, clip, vae

 import comfy.model_patcher
@@ -2509,8 +2506,8 @@ class TorchCompileLTXModel:
         if not self._compiled:
             try:
                 for i, block in enumerate(diffusion_model.transformer_blocks):
-                    #print("Compiling double_block", i)
-                    m.add_object_patch(f"diffusion_model.transformer_blocks.{i}", torch.compile(block, mode=mode, dynamic=dynamic, fullgraph=fullgraph, backend=backend))
+                    compiled_block = torch.compile(block, mode=mode, dynamic=dynamic, fullgraph=fullgraph, backend=backend)
+                    m.add_object_patch(f"diffusion_model.transformer_blocks.{i}", compiled_block)
                 self._compiled = True
                 compile_settings = {
                     "backend": backend,
@@ -2519,8 +2516,9 @@ class TorchCompileLTXModel:
                     "dynamic": dynamic,
                 }
                 setattr(m.model, "compile_settings", compile_settings)

             except:
                 raise RuntimeError("Failed to compile model")
+

         return (m, )
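Background, not part of the commit: add_object_patch is ComfyUI's ModelPatcher hook for swapping an attribute path on the wrapped model, used here to replace each transformer block with its torch.compile wrapper; splitting the torch.compile call onto its own line is a readability change. A rough sketch of per-block compilation on a plain PyTorch module list (the toy Block class is illustrative; outside ComfyUI there is no ModelPatcher, so the list is rebuilt directly):

import torch

class Block(torch.nn.Module):  # toy stand-in for an LTX transformer block
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x):
        return torch.relu(self.linear(x))

blocks = torch.nn.ModuleList(Block() for _ in range(2))

# Compile each block on its own, as the node does, rather than the whole
# model; graphs stay small and any uncompiled glue code remains dynamic.
compiled_blocks = [
    torch.compile(block, mode="default", dynamic=False, fullgraph=False)
    for block in blocks
]

x = torch.randn(1, 8)
for block in compiled_blocks:
    x = block(x)  # compilation happens lazily on this first call
print(x.shape)  # torch.Size([1, 8])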