Fix FluxBlockLoraLoader

This commit is contained in:
kijai 2024-11-25 16:28:29 +02:00
parent 8f057eb563
commit 5920419f44

View File

@ -1981,7 +1981,7 @@ class FluxBlockLoraLoader:
keys_to_delete = []
for block in blocks:
for key in list(loaded.keys()): # Convert keys to a list to avoid runtime error due to size change
for key in list(loaded.keys()):
match = False
if isinstance(key, str) and block in key:
match = True
@ -1994,40 +1994,33 @@ class FluxBlockLoraLoader:
if match:
ratio = blocks[block]
if ratio == 0:
keys_to_delete.append(key) # Collect keys to delete
keys_to_delete.append(key)
else:
value = loaded[key]
if isinstance(value, tuple) and len(value) > 1 and isinstance(value[1], tuple):
# Handle the tuple format
if len(value[1]) > 3:
loaded[key] = (value[0], value[1][:-3] + (ratio, value[1][-2], value[1][-1]))
else:
loaded[key] = (value[0], value[1][:-2] + (ratio, value[1][-1]))
inner_tuple = value[1]
if len(inner_tuple) >= 3:
inner_tuple = (inner_tuple[0], inner_tuple[1], ratio, *inner_tuple[3:])
loaded[key] = (value[0], inner_tuple)
else:
# Handle the simpler format directly
loaded[key] = (value[0], ratio)
# Now perform the deletion of keys
for key in keys_to_delete:
del loaded[key]
print("loading lora keys:")
for key, value in loaded.items():
if isinstance(value, tuple) and len(value) > 1 and isinstance(value[1], tuple):
# Handle the tuple format
if len(value[1]) > 2:
alpha = value[1][-3] # Assuming the alpha value is the third last element in the tuple
else:
alpha = value[1][-2] # Adjust according to the second format's structure
inner_tuple = value[1]
alpha = inner_tuple[2] if len(inner_tuple) >= 3 else None
else:
# Handle the simpler format directly
alpha = value[1] if len(value) > 1 else None
print(f"Key: {key}, Alpha: {alpha}")
if model is not None:
new_modelpatcher = model.clone()
k = new_modelpatcher.add_patches(loaded, strength_model)
if model is not None:
new_modelpatcher = model.clone()
k = new_modelpatcher.add_patches(loaded, strength_model)
k = set(k)
for x in loaded:
@ -2196,41 +2189,46 @@ class CheckpointLoaderKJ:
.reshape(b, -1, heads * dim_head)
)
class OriginalLinear(torch.nn.Linear, CastWeightBiasOp):
def reset_parameters(self):
return None
# class OriginalLinear(torch.nn.Linear, CastWeightBiasOp):
# def reset_parameters(self):
# return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.linear(input, weight, bias)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
# def forward_comfy_cast_weights(self, input):
# weight, bias = cast_bias_weight(self, input)
# return torch.nn.functional.linear(input, weight, bias)
# def forward(self, *args, **kwargs):
# if self.comfy_cast_weights:
# return self.forward_comfy_cast_weights(*args, **kwargs)
# else:
# return super().forward(*args, **kwargs)
cublas_patched = False
if patch_cublaslinear:
try:
from cublas_ops import CublasLinear
except ImportError:
raise Exception("Can't import 'torch-cublas-hgemm', install it from here https://github.com/aredden/torch-cublas-hgemm")
class PatchedLinear(CublasLinear, CastWeightBiasOp):
def reset_parameters(self):
return None
if not cublas_patched:
original_linear = disable_weight_init.Linear
try:
from cublas_ops import CublasLinear
except ImportError:
raise Exception("Can't import 'torch-cublas-hgemm', install it from here https://github.com/aredden/torch-cublas-hgemm")
class PatchedLinear(CublasLinear, CastWeightBiasOp):
def reset_parameters(self):
return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.linear(input, weight, bias)
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.linear(input, weight, bias)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
disable_weight_init.Linear = PatchedLinear
cublas_patched = True
else:
disable_weight_init.Linear = original_linear
cublas_patched = False
def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
disable_weight_init.Linear = PatchedLinear
else:
disable_weight_init.Linear = OriginalLinear
if sage_attention:
comfy_attention.optimized_attention = attention_sage
else:
@ -2238,7 +2236,6 @@ class CheckpointLoaderKJ:
model, clip, vae = CheckpointLoaderSimple.load_checkpoint(self, ckpt_name)
return model, clip, vae
import comfy.model_patcher
@ -2509,8 +2506,8 @@ class TorchCompileLTXModel:
if not self._compiled:
try:
for i, block in enumerate(diffusion_model.transformer_blocks):
#print("Compiling double_block", i)
m.add_object_patch(f"diffusion_model.transformer_blocks.{i}", torch.compile(block, mode=mode, dynamic=dynamic, fullgraph=fullgraph, backend=backend))
compiled_block = torch.compile(block, mode=mode, dynamic=dynamic, fullgraph=fullgraph, backend=backend)
m.add_object_patch(f"diffusion_model.transformer_blocks.{i}", compiled_block)
self._compiled = True
compile_settings = {
"backend": backend,
@ -2519,8 +2516,9 @@ class TorchCompileLTXModel:
"dynamic": dynamic,
}
setattr(m.model, "compile_settings", compile_settings)
except:
raise RuntimeError("Failed to compile model")
raise RuntimeError("Failed to compile model")
return (m, )