Mirror of https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git (synced 2025-12-16 08:14:32 +08:00)
Fix FluxBlockLoraLoader
This commit is contained in:
parent 8f057eb563
commit 5920419f44

Changed file: nodes/nodes.py (102 lines changed)
@@ -1981,7 +1981,7 @@ class FluxBlockLoraLoader:
         keys_to_delete = []
 
         for block in blocks:
-            for key in list(loaded.keys()): # Convert keys to a list to avoid runtime error due to size change
+            for key in list(loaded.keys()):
                 match = False
                 if isinstance(key, str) and block in key:
                     match = True
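Note: this hunk only drops a stale comment from the block-matching loop. For orientation, here is a minimal standalone sketch of that matching step; blocks and loaded are illustrative stand-ins for the node's real inputs, not data from the repo.

# Illustrative sketch only; `blocks` maps block-name substrings to ratios,
# `loaded` plays the role of the LoRA patch dict.
blocks = {"double_blocks.0.": 0.0, "single_blocks.3.": 0.5}
loaded = {
    "diffusion_model.double_blocks.0.img_attn.qkv.lora_up.weight": ("lora", ()),
    "diffusion_model.single_blocks.3.linear1.lora_up.weight": ("lora", ()),
}

for block in blocks:
    for key in list(loaded.keys()):  # snapshot of keys, so the dict can be edited afterwards
        if isinstance(key, str) and block in key:
            print(f"{key} matches {block} -> ratio {blocks[block]}")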
@@ -1994,40 +1994,33 @@ class FluxBlockLoraLoader:
                 if match:
                     ratio = blocks[block]
                     if ratio == 0:
-                        keys_to_delete.append(key) # Collect keys to delete
+                        keys_to_delete.append(key)
                     else:
                         value = loaded[key]
                         if isinstance(value, tuple) and len(value) > 1 and isinstance(value[1], tuple):
-                            # Handle the tuple format
-                            if len(value[1]) > 3:
-                                loaded[key] = (value[0], value[1][:-3] + (ratio, value[1][-2], value[1][-1]))
-                            else:
-                                loaded[key] = (value[0], value[1][:-2] + (ratio, value[1][-1]))
+                            inner_tuple = value[1]
+                            if len(inner_tuple) >= 3:
+                                inner_tuple = (inner_tuple[0], inner_tuple[1], ratio, *inner_tuple[3:])
+                            loaded[key] = (value[0], inner_tuple)
                         else:
-                            # Handle the simpler format directly
                             loaded[key] = (value[0], ratio)
 
-        # Now perform the deletion of keys
         for key in keys_to_delete:
             del loaded[key]
 
         print("loading lora keys:")
         for key, value in loaded.items():
             if isinstance(value, tuple) and len(value) > 1 and isinstance(value[1], tuple):
-                # Handle the tuple format
-                if len(value[1]) > 2:
-                    alpha = value[1][-3] # Assuming the alpha value is the third last element in the tuple
-                else:
-                    alpha = value[1][-2] # Adjust according to the second format's structure
+                inner_tuple = value[1]
+                alpha = inner_tuple[2] if len(inner_tuple) >= 3 else None
             else:
-                # Handle the simpler format directly
                 alpha = value[1] if len(value) > 1 else None
             print(f"Key: {key}, Alpha: {alpha}")
 
 
         if model is not None:
             new_modelpatcher = model.clone()
             k = new_modelpatcher.add_patches(loaded, strength_model)
 
             k = set(k)
             for x in loaded:
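The rewritten ratio handling above swaps the alpha slot of each matched patch entry instead of rebuilding the tuple by slicing from the end. A minimal standalone sketch of that rewrite, assuming a patch entry shaped like ("lora", (up, down, alpha, *rest)) in the style of ComfyUI LoRA patch dicts; the names are illustrative.

def apply_block_ratio(value, ratio):
    # Return a copy of one patch entry with its alpha slot replaced by `ratio`.
    if isinstance(value, tuple) and len(value) > 1 and isinstance(value[1], tuple):
        inner = value[1]
        if len(inner) >= 3:
            # keep the up/down weights and any trailing fields, swap only the alpha
            inner = (inner[0], inner[1], ratio, *inner[3:])
        return (value[0], inner)
    # simpler (type, strength) style entry
    return (value[0], ratio)

print(apply_block_ratio(("lora", ("up", "down", 1.0, None)), 0.5))
# -> ('lora', ('up', 'down', 0.5, None))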
@@ -2196,41 +2189,46 @@ class CheckpointLoaderKJ:
                 .reshape(b, -1, heads * dim_head)
             )
 
-        class OriginalLinear(torch.nn.Linear, CastWeightBiasOp):
-            def reset_parameters(self):
-                return None
+        # class OriginalLinear(torch.nn.Linear, CastWeightBiasOp):
+        #     def reset_parameters(self):
+        #         return None
 
-            def forward_comfy_cast_weights(self, input):
-                weight, bias = cast_bias_weight(self, input)
-                return torch.nn.functional.linear(input, weight, bias)
+        #     def forward_comfy_cast_weights(self, input):
+        #         weight, bias = cast_bias_weight(self, input)
+        #         return torch.nn.functional.linear(input, weight, bias)
 
-            def forward(self, *args, **kwargs):
-                if self.comfy_cast_weights:
-                    return self.forward_comfy_cast_weights(*args, **kwargs)
-                else:
-                    return super().forward(*args, **kwargs)
 
+        #     def forward(self, *args, **kwargs):
+        #         if self.comfy_cast_weights:
+        #             return self.forward_comfy_cast_weights(*args, **kwargs)
+        #         else:
+        #             return super().forward(*args, **kwargs)
+        cublas_patched = False
         if patch_cublaslinear:
-            try:
-                from cublas_ops import CublasLinear
-            except ImportError:
-                raise Exception("Can't import 'torch-cublas-hgemm', install it from here https://github.com/aredden/torch-cublas-hgemm")
-            class PatchedLinear(CublasLinear, CastWeightBiasOp):
-                def reset_parameters(self):
-                    return None
+            if not cublas_patched:
+                original_linear = disable_weight_init.Linear
+                try:
+                    from cublas_ops import CublasLinear
+                except ImportError:
+                    raise Exception("Can't import 'torch-cublas-hgemm', install it from here https://github.com/aredden/torch-cublas-hgemm")
+                class PatchedLinear(CublasLinear, CastWeightBiasOp):
+                    def reset_parameters(self):
+                        return None
 
                     def forward_comfy_cast_weights(self, input):
                         weight, bias = cast_bias_weight(self, input)
                         return torch.nn.functional.linear(input, weight, bias)
 
+                    def forward(self, *args, **kwargs):
+                        if self.comfy_cast_weights:
+                            return self.forward_comfy_cast_weights(*args, **kwargs)
+                        else:
+                            return super().forward(*args, **kwargs)
+                disable_weight_init.Linear = PatchedLinear
+                cublas_patched = True
+        else:
+            disable_weight_init.Linear = original_linear
+            cublas_patched = False
 
-            def forward(self, *args, **kwargs):
-                if self.comfy_cast_weights:
-                    return self.forward_comfy_cast_weights(*args, **kwargs)
-                else:
-                    return super().forward(*args, **kwargs)
-            disable_weight_init.Linear = PatchedLinear
-        else:
-            disable_weight_init.Linear = OriginalLinear
 
         if sage_attention:
             comfy_attention.optimized_attention = attention_sage
         else:
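The reworked cuBLAS branch above guards against patching twice and keeps a handle to the original Linear class so it can be restored when the option is turned off. A standalone sketch of that patch/restore toggle; ops, BaseLinear and FastLinear are stand-ins for disable_weight_init.Linear and the CublasLinear-backed PatchedLinear, not ComfyUI's actual objects.

import types

ops = types.SimpleNamespace()          # stands in for comfy's ops container

class BaseLinear:                      # stands in for disable_weight_init.Linear
    pass

class FastLinear(BaseLinear):          # stands in for the CublasLinear-backed PatchedLinear
    pass

ops.Linear = BaseLinear
original_linear = ops.Linear           # handle kept so the patch can be undone
patched = False                        # guard flag, like cublas_patched

def set_fast_linear(enable):
    global patched
    if enable and not patched:         # swap the class in only once
        ops.Linear = FastLinear
        patched = True
    elif not enable and patched:       # restore the saved original class
        ops.Linear = original_linear
        patched = False

set_fast_linear(True)
assert ops.Linear is FastLinear
set_fast_linear(False)
assert ops.Linear is BaseLinear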
@@ -2238,7 +2236,6 @@ class CheckpointLoaderKJ:
 
         model, clip, vae = CheckpointLoaderSimple.load_checkpoint(self, ckpt_name)
 
-
         return model, clip, vae
 
 import comfy.model_patcher
@@ -2509,8 +2506,8 @@ class TorchCompileLTXModel:
         if not self._compiled:
             try:
                 for i, block in enumerate(diffusion_model.transformer_blocks):
-                    #print("Compiling double_block", i)
-                    m.add_object_patch(f"diffusion_model.transformer_blocks.{i}", torch.compile(block, mode=mode, dynamic=dynamic, fullgraph=fullgraph, backend=backend))
+                    compiled_block = torch.compile(block, mode=mode, dynamic=dynamic, fullgraph=fullgraph, backend=backend)
+                    m.add_object_patch(f"diffusion_model.transformer_blocks.{i}", compiled_block)
                 self._compiled = True
                 compile_settings = {
                     "backend": backend,
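The change above only factors the torch.compile call into a local compiled_block before handing it to add_object_patch, so each block is compiled exactly once. A standalone sketch of compiling per-block submodules and collecting them by path, assuming torch >= 2.0; TinyBlock and TinyModel are illustrative stand-ins for the LTX transformer.

import torch

class TinyBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)
    def forward(self, x):
        return torch.nn.functional.gelu(self.proj(x))

class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.transformer_blocks = torch.nn.ModuleList(TinyBlock() for _ in range(2))

model = TinyModel()
patches = {}
for i, block in enumerate(model.transformer_blocks):
    compiled_block = torch.compile(block)                 # compile once per block
    patches[f"transformer_blocks.{i}"] = compiled_block   # register the compiled object under its path

print(list(patches))  # ['transformer_blocks.0', 'transformer_blocks.1']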
@@ -2519,8 +2516,9 @@ class TorchCompileLTXModel:
                     "dynamic": dynamic,
                 }
                 setattr(m.model, "compile_settings", compile_settings)
+
             except:
                 raise RuntimeError("Failed to compile model")
 
         return (m, )
 