Fix FluxBlockLoraLoader

This commit is contained in:
kijai 2024-11-25 16:28:29 +02:00
parent 8f057eb563
commit 5920419f44

View File

@ -1981,7 +1981,7 @@ class FluxBlockLoraLoader:
keys_to_delete = [] keys_to_delete = []
for block in blocks: for block in blocks:
for key in list(loaded.keys()): # Convert keys to a list to avoid runtime error due to size change for key in list(loaded.keys()):
match = False match = False
if isinstance(key, str) and block in key: if isinstance(key, str) and block in key:
match = True match = True
@ -1994,40 +1994,33 @@ class FluxBlockLoraLoader:
if match: if match:
ratio = blocks[block] ratio = blocks[block]
if ratio == 0: if ratio == 0:
keys_to_delete.append(key) # Collect keys to delete keys_to_delete.append(key)
else: else:
value = loaded[key] value = loaded[key]
if isinstance(value, tuple) and len(value) > 1 and isinstance(value[1], tuple): if isinstance(value, tuple) and len(value) > 1 and isinstance(value[1], tuple):
# Handle the tuple format inner_tuple = value[1]
if len(value[1]) > 3: if len(inner_tuple) >= 3:
loaded[key] = (value[0], value[1][:-3] + (ratio, value[1][-2], value[1][-1])) inner_tuple = (inner_tuple[0], inner_tuple[1], ratio, *inner_tuple[3:])
else: loaded[key] = (value[0], inner_tuple)
loaded[key] = (value[0], value[1][:-2] + (ratio, value[1][-1]))
else: else:
# Handle the simpler format directly
loaded[key] = (value[0], ratio) loaded[key] = (value[0], ratio)
# Now perform the deletion of keys
for key in keys_to_delete: for key in keys_to_delete:
del loaded[key] del loaded[key]
print("loading lora keys:") print("loading lora keys:")
for key, value in loaded.items(): for key, value in loaded.items():
if isinstance(value, tuple) and len(value) > 1 and isinstance(value[1], tuple): if isinstance(value, tuple) and len(value) > 1 and isinstance(value[1], tuple):
# Handle the tuple format inner_tuple = value[1]
if len(value[1]) > 2: alpha = inner_tuple[2] if len(inner_tuple) >= 3 else None
alpha = value[1][-3] # Assuming the alpha value is the third last element in the tuple
else:
alpha = value[1][-2] # Adjust according to the second format's structure
else: else:
# Handle the simpler format directly
alpha = value[1] if len(value) > 1 else None alpha = value[1] if len(value) > 1 else None
print(f"Key: {key}, Alpha: {alpha}") print(f"Key: {key}, Alpha: {alpha}")
if model is not None: if model is not None:
new_modelpatcher = model.clone() new_modelpatcher = model.clone()
k = new_modelpatcher.add_patches(loaded, strength_model) k = new_modelpatcher.add_patches(loaded, strength_model)
k = set(k) k = set(k)
for x in loaded: for x in loaded:
@ -2196,41 +2189,46 @@ class CheckpointLoaderKJ:
.reshape(b, -1, heads * dim_head) .reshape(b, -1, heads * dim_head)
) )
class OriginalLinear(torch.nn.Linear, CastWeightBiasOp): # class OriginalLinear(torch.nn.Linear, CastWeightBiasOp):
def reset_parameters(self): # def reset_parameters(self):
return None # return None
def forward_comfy_cast_weights(self, input): # def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input) # weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.linear(input, weight, bias) # return torch.nn.functional.linear(input, weight, bias)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
# def forward(self, *args, **kwargs):
# if self.comfy_cast_weights:
# return self.forward_comfy_cast_weights(*args, **kwargs)
# else:
# return super().forward(*args, **kwargs)
cublas_patched = False
if patch_cublaslinear: if patch_cublaslinear:
try: if not cublas_patched:
from cublas_ops import CublasLinear original_linear = disable_weight_init.Linear
except ImportError: try:
raise Exception("Can't import 'torch-cublas-hgemm', install it from here https://github.com/aredden/torch-cublas-hgemm") from cublas_ops import CublasLinear
class PatchedLinear(CublasLinear, CastWeightBiasOp): except ImportError:
def reset_parameters(self): raise Exception("Can't import 'torch-cublas-hgemm', install it from here https://github.com/aredden/torch-cublas-hgemm")
return None class PatchedLinear(CublasLinear, CastWeightBiasOp):
def reset_parameters(self):
return None
def forward_comfy_cast_weights(self, input): def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input) weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.linear(input, weight, bias) return torch.nn.functional.linear(input, weight, bias)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
disable_weight_init.Linear = PatchedLinear
cublas_patched = True
else:
disable_weight_init.Linear = original_linear
cublas_patched = False
def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
disable_weight_init.Linear = PatchedLinear
else:
disable_weight_init.Linear = OriginalLinear
if sage_attention: if sage_attention:
comfy_attention.optimized_attention = attention_sage comfy_attention.optimized_attention = attention_sage
else: else:
@ -2238,7 +2236,6 @@ class CheckpointLoaderKJ:
model, clip, vae = CheckpointLoaderSimple.load_checkpoint(self, ckpt_name) model, clip, vae = CheckpointLoaderSimple.load_checkpoint(self, ckpt_name)
return model, clip, vae return model, clip, vae
import comfy.model_patcher import comfy.model_patcher
@ -2509,8 +2506,8 @@ class TorchCompileLTXModel:
if not self._compiled: if not self._compiled:
try: try:
for i, block in enumerate(diffusion_model.transformer_blocks): for i, block in enumerate(diffusion_model.transformer_blocks):
#print("Compiling double_block", i) compiled_block = torch.compile(block, mode=mode, dynamic=dynamic, fullgraph=fullgraph, backend=backend)
m.add_object_patch(f"diffusion_model.transformer_blocks.{i}", torch.compile(block, mode=mode, dynamic=dynamic, fullgraph=fullgraph, backend=backend)) m.add_object_patch(f"diffusion_model.transformer_blocks.{i}", compiled_block)
self._compiled = True self._compiled = True
compile_settings = { compile_settings = {
"backend": backend, "backend": backend,
@ -2519,8 +2516,9 @@ class TorchCompileLTXModel:
"dynamic": dynamic, "dynamic": dynamic,
} }
setattr(m.model, "compile_settings", compile_settings) setattr(m.model, "compile_settings", compile_settings)
except: except:
raise RuntimeError("Failed to compile model") raise RuntimeError("Failed to compile model")
return (m, ) return (m, )