import logging
import os

import numpy as np
import safetensors.torch  # imported as safetensors.torch so safetensors.torch.save_file resolves below
import torch
import torch.utils.checkpoint
from tqdm.auto import trange
from PIL import Image, ImageDraw, ImageFont
from typing_extensions import override

import comfy.samplers
import comfy.sd
import comfy.utils
import comfy.model_management
import comfy_extras.nodes_custom_sampler
import folder_paths
import node_helpers
from comfy.weight_adapter import adapters, adapter_maps
from comfy_api.latest import ComfyExtension, io, ui
from comfy.utils import ProgressBar


def make_batch_extra_option_dict(d, indices, full_size=None):
    new_dict = {}
    for k, v in d.items():
        newv = v
        if isinstance(v, dict):
            newv = make_batch_extra_option_dict(v, indices, full_size=full_size)
        elif isinstance(v, torch.Tensor):
            if full_size is None or v.size(0) == full_size:
                newv = v[indices]
        elif isinstance(v, (list, tuple)) and len(v) == full_size:
            newv = [v[i] for i in indices]
        new_dict[k] = newv
    return new_dict


def process_cond_list(d, prefix=""):
    if hasattr(d, "__iter__") and not hasattr(d, "items"):
        for index, item in enumerate(d):
            process_cond_list(item, f"{prefix}.{index}")
        return d
    elif hasattr(d, "items"):
        for k, v in list(d.items()):
            if isinstance(v, dict):
                process_cond_list(v, f"{prefix}.{k}")
            elif isinstance(v, torch.Tensor):
                d[k] = v.clone()
            elif isinstance(v, (list, tuple)):
                for index, item in enumerate(v):
                    process_cond_list(item, f"{prefix}.{k}.{index}")
    return d


class TrainSampler(comfy.samplers.Sampler):
    def __init__(
        self,
        loss_fn,
        optimizer,
        loss_callback=None,
        batch_size=1,
        grad_acc=1,
        total_steps=1,
        seed=0,
        training_dtype=torch.bfloat16,
        real_dataset=None,
    ):
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        self.loss_callback = loss_callback
        self.batch_size = batch_size
        self.total_steps = total_steps
        self.grad_acc = grad_acc
        self.seed = seed
        self.training_dtype = training_dtype
        self.real_dataset: list[torch.Tensor] | None = real_dataset

    def fwd_bwd(
        self,
        model_wrap,
        batch_sigmas,
        batch_noise,
        batch_latent,
        cond,
        indices,
        extra_args,
        dataset_size,
        bwd=True,
    ):
        xt = model_wrap.inner_model.model_sampling.noise_scaling(
            batch_sigmas, batch_noise, batch_latent, False
        )
        x0 = model_wrap.inner_model.model_sampling.noise_scaling(
            torch.zeros_like(batch_sigmas),
            torch.zeros_like(batch_noise),
            batch_latent,
            False,
        )

        model_wrap.conds["positive"] = [cond[i] for i in indices]
        batch_extra_args = make_batch_extra_option_dict(
            extra_args, indices, full_size=dataset_size
        )

        with torch.autocast(xt.device.type, dtype=self.training_dtype):
            x0_pred = model_wrap(
                xt.requires_grad_(True),
                batch_sigmas.requires_grad_(True),
                **batch_extra_args,
            )
            loss = self.loss_fn(x0_pred, x0)
        if bwd:
            bwd_loss = loss / self.grad_acc
            bwd_loss.backward()
        return loss

    def sample(
        self,
        model_wrap,
        sigmas,
        extra_args,
        callback,
        noise,
        latent_image=None,
        denoise_mask=None,
        disable_pbar=False,
    ):
        model_wrap.conds = process_cond_list(model_wrap.conds)
        cond = model_wrap.conds["positive"]
        dataset_size = sigmas.size(0)
        torch.cuda.empty_cache()
        ui_pbar = ProgressBar(self.total_steps)

        for i in (
            pbar := trange(
                self.total_steps,
                desc="Training LoRA",
                smoothing=0.01,
                disable=not comfy.utils.PROGRESS_BAR_ENABLED,
            )
        ):
            noisegen = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(
                self.seed + i * 1000
            )
            indices = torch.randperm(dataset_size)[: self.batch_size].tolist()

            if self.real_dataset is None:
                batch_latent = torch.stack([latent_image[i] for i in indices])
                batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(
                    batch_latent.device
                )
                batch_sigmas = [
                    model_wrap.inner_model.model_sampling.percent_to_sigma(
                        torch.rand((1,)).item()
                    )
                    for _ in range(min(self.batch_size, dataset_size))
                ]
                batch_sigmas = torch.tensor(batch_sigmas).to(batch_latent.device)

                loss = self.fwd_bwd(
                    model_wrap,
                    batch_sigmas,
                    batch_noise,
                    batch_latent,
                    cond,
                    indices,
                    extra_args,
                    dataset_size,
                    bwd=True,
                )
                if self.loss_callback:
                    self.loss_callback(loss.item())
                pbar.set_postfix({"loss": f"{loss.item():.4f}"})
            else:
                total_loss = 0
                for index in indices:
                    single_latent = self.real_dataset[index].to(latent_image)
                    batch_noise = noisegen.generate_noise(
                        {"samples": single_latent}
                    ).to(single_latent.device)
                    batch_sigmas = (
                        model_wrap.inner_model.model_sampling.percent_to_sigma(
                            torch.rand((1,)).item()
                        )
                    )
                    batch_sigmas = torch.tensor([batch_sigmas]).to(
                        single_latent.device
                    )

                    loss = self.fwd_bwd(
                        model_wrap,
                        batch_sigmas,
                        batch_noise,
                        single_latent,
                        cond,
                        [index],
                        extra_args,
                        dataset_size,
                        bwd=False,
                    )
                    total_loss += loss

                total_loss = total_loss / self.grad_acc / len(indices)
                total_loss.backward()
                if self.loss_callback:
                    self.loss_callback(total_loss.item())
                pbar.set_postfix({"loss": f"{total_loss.item():.4f}"})

            if (i + 1) % self.grad_acc == 0:
                self.optimizer.step()
                self.optimizer.zero_grad()
            ui_pbar.update(1)

        torch.cuda.empty_cache()
        return torch.zeros_like(latent_image)

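
# The essence of TrainSampler.fwd_bwd, shown in isolation: corrupt the clean
# latent to x_t at a random noise level, ask the model for an x0 prediction,
# and regress against the clean latent. This is a minimal sketch, not the
# real call path: `scale_fn` stands in for model_sampling.noise_scaling and
# `model` for the wrapped denoiser, both assumptions for illustration.
def _x0_loss_sketch(model, x0, sigma, noise, scale_fn,
                    loss_fn=torch.nn.functional.mse_loss):
    xt = scale_fn(sigma, noise, x0)  # forward process: noise x0 at level sigma
    x0_pred = model(xt, sigma)       # denoiser predicts the clean latent
    return loss_fn(x0_pred, x0)      # gradient flows into the LoRA wrappers
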

class BiasDiff(torch.nn.Module):
    def __init__(self, bias):
        super().__init__()
        self.bias = bias

    def __call__(self, b):
        org_dtype = b.dtype
        return (b.to(self.bias) + self.bias).to(org_dtype)

    def passive_memory_usage(self):
        return self.bias.nelement() * self.bias.element_size()

    def move_to(self, device):
        self.to(device=device)
        return self.passive_memory_usage()


def draw_loss_graph(loss_map, steps):
    width, height = 500, 300
    img = Image.new("RGB", (width, height), "white")
    draw = ImageDraw.Draw(img)

    min_loss, max_loss = min(loss_map.values()), max(loss_map.values())
    scaled_loss = [(l - min_loss) / (max_loss - min_loss) for l in loss_map.values()]

    prev_point = (0, height - int(scaled_loss[0] * height))
    for i, l in enumerate(scaled_loss[1:], start=1):
        x = int(i / (steps - 1) * width)
        y = height - int(l * height)
        draw.line([prev_point, (x, y)], fill="blue", width=2)
        prev_point = (x, y)

    return img


def find_all_highest_child_module_with_forward(
    model: torch.nn.Module, result=None, name=None
):
    if result is None:
        result = []
    elif hasattr(model, "forward") and not isinstance(
        model, (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict)
    ):
        result.append(model)
        logging.debug(
            f"Found module with forward: {name} ({model.__class__.__name__})"
        )
        return result
    name = name or "root"
    for next_name, child in model.named_children():
        find_all_highest_child_module_with_forward(child, result, f"{name}.{next_name}")
    return result


def patch(m):
    if not hasattr(m, "forward"):
        return
    org_forward = m.forward

    def fwd(args, kwargs):
        return org_forward(*args, **kwargs)

    def checkpointing_fwd(*args, **kwargs):
        return torch.utils.checkpoint.checkpoint(
            fwd, args, kwargs, use_reentrant=False
        )

    m.org_forward = org_forward
    m.forward = checkpointing_fwd


def unpatch(m):
    if hasattr(m, "org_forward"):
        m.forward = m.org_forward
        del m.org_forward

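
# Usage sketch for the two helpers above (illustrative only; `mlp` is a toy
# module, not part of this file):
#
#     mlp = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.GELU())
#     for m in find_all_highest_child_module_with_forward(mlp):
#         patch(m)    # forward now recomputes activations via torch.utils.checkpoint
#     ...             # run the training steps
#     for m in mlp.modules():
#         unpatch(m)  # restore the original forward methods
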

class TrainLoraNode(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="TrainLoraNode",
            display_name="Train LoRA",
            category="training",
            is_experimental=True,
            is_input_list=True,  # All inputs become lists
            inputs=[
                io.Model.Input("model", tooltip="The model to train the LoRA on."),
                io.Latent.Input(
                    "latents",
                    tooltip="The latents to use for training; they serve as the dataset/input of the model.",
                ),
                io.Conditioning.Input(
                    "positive", tooltip="The positive conditioning to use for training."
                ),
                io.Int.Input(
                    "batch_size",
                    default=1,
                    min=1,
                    max=10000,
                    tooltip="The batch size to use for training.",
                ),
                io.Int.Input(
                    "grad_accumulation_steps",
                    default=1,
                    min=1,
                    max=1024,
                    tooltip="The number of gradient accumulation steps to use for training.",
                ),
                io.Int.Input(
                    "steps",
                    default=16,
                    min=1,
                    max=100000,
                    tooltip="The number of steps to train the LoRA for.",
                ),
                io.Float.Input(
                    "learning_rate",
                    default=0.0005,
                    min=0.0000001,
                    max=1.0,
                    step=0.0000001,
                    tooltip="The learning rate to use for training.",
                ),
                io.Int.Input(
                    "rank",
                    default=8,
                    min=1,
                    max=128,
                    tooltip="The rank of the LoRA layers.",
                ),
                io.Combo.Input(
                    "optimizer",
                    options=["AdamW", "Adam", "SGD", "RMSprop"],
                    default="AdamW",
                    tooltip="The optimizer to use for training.",
                ),
                io.Combo.Input(
                    "loss_function",
                    options=["MSE", "L1", "Huber", "SmoothL1"],
                    default="MSE",
                    tooltip="The loss function to use for training.",
                ),
                io.Int.Input(
                    "seed",
                    default=0,
                    min=0,
                    max=0xFFFFFFFFFFFFFFFF,
                    tooltip="The seed to use for training (used in the generator for LoRA weight initialization and noise sampling).",
                ),
                io.Combo.Input(
                    "training_dtype",
                    options=["bf16", "fp32"],
                    default="bf16",
                    tooltip="The dtype to use for training.",
                ),
                io.Combo.Input(
                    "lora_dtype",
                    options=["bf16", "fp32"],
                    default="bf16",
                    tooltip="The dtype to use for the LoRA.",
                ),
                io.Combo.Input(
                    "algorithm",
                    options=list(adapter_maps.keys()),
                    default=list(adapter_maps.keys())[0],
                    tooltip="The algorithm to use for training.",
                ),
                io.Boolean.Input(
                    "gradient_checkpointing",
                    default=True,
                    tooltip="Use gradient checkpointing for training.",
                ),
                io.Combo.Input(
                    "existing_lora",
                    options=folder_paths.get_filename_list("loras") + ["[None]"],
                    default="[None]",
                    tooltip="The existing LoRA to append to. Set to [None] for a new LoRA.",
                ),
            ],
            outputs=[
                io.Model.Output(display_name="model", tooltip="Model with LoRA applied"),
                io.Custom("LORA_MODEL").Output(display_name="lora", tooltip="LoRA weights"),
                io.Custom("LOSS_MAP").Output(display_name="loss_map", tooltip="Loss history"),
                io.Int.Output(display_name="steps", tooltip="Total training steps"),
            ],
        )
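
    # Because is_input_list=True above, the runtime delivers *every* input as a
    # list (e.g. batch_size arrives as [1]); execute() therefore unwraps the
    # scalar inputs and merges the latent/conditioning lists before training.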
    @classmethod
    def execute(
        cls,
        model,
        latents,
        positive,
        batch_size,
        steps,
        grad_accumulation_steps,
        learning_rate,
        rank,
        optimizer,
        loss_function,
        seed,
        training_dtype,
        lora_dtype,
        algorithm,
        gradient_checkpointing,
        existing_lora,
    ):
        # Extract scalars from lists (due to is_input_list=True)
        model = model[0]
        batch_size = batch_size[0]
        steps = steps[0]
        grad_accumulation_steps = grad_accumulation_steps[0]
        learning_rate = learning_rate[0]
        rank = rank[0]
        optimizer = optimizer[0]
        loss_function = loss_function[0]
        seed = seed[0]
        training_dtype = training_dtype[0]
        lora_dtype = lora_dtype[0]
        algorithm = algorithm[0]
        gradient_checkpointing = gradient_checkpointing[0]
        existing_lora = existing_lora[0]

        # Handle latents - either a single dict or a list of dicts
        if len(latents) == 1:
            latents = latents[0]["samples"]  # Single latent dict
        else:
            latent_list = []
            for latent in latents:
                latent = latent["samples"]
                bs = latent.shape[0]
                if bs != 1:
                    for sub_latent in latent:
                        latent_list.append(sub_latent[None])
                else:
                    latent_list.append(latent)
            latents = latent_list

        # Handle conditioning - either a single list or a list of lists
        if len(positive) == 1:
            positive = positive[0]  # Single conditioning list
        else:
            # Multiple conditioning lists - flatten
            flat_positive = []
            for cond in positive:
                if isinstance(cond, list):
                    flat_positive.extend(cond)
                else:
                    flat_positive.append(cond)
            positive = flat_positive

        mp = model.clone()
        dtype = node_helpers.string_to_torch_dtype(training_dtype)
        lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
        mp.set_model_compute_dtype(dtype)

        # `latents` here can be a list of differently sized latents or one large batch
        if isinstance(latents, list):
            all_shapes = set()
            latents = [t.to(dtype) for t in latents]
            for latent in latents:
                all_shapes.add(latent.shape)
            logging.info(f"Latent shapes: {all_shapes}")
            if len(all_shapes) > 1:
                multi_res = True
            else:
                multi_res = False
                latents = torch.cat(latents, dim=0)
            num_images = len(latents)
        elif isinstance(latents, torch.Tensor):
            multi_res = False  # uniform batch; set explicitly so multi_res is always bound
            latents = latents.to(dtype)
            num_images = latents.shape[0]
        else:
            # Fail fast instead of logging and crashing later on an unbound name
            raise ValueError(f"Invalid latents type: {type(latents)}")

        logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}")
        if len(positive) == 1 and num_images > 1:
            positive = positive * num_images
        elif len(positive) != num_images:
            raise ValueError(
                f"Number of positive conditions ({len(positive)}) does not match number of images ({num_images})."
            )
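
        # Invariants past this point: `latents` is either one [N, C, H, W]
        # batch (uniform resolution, multi_res=False) or a list of N
        # single-sample tensors (mixed resolutions, multi_res=True), and
        # `positive` holds exactly N conditionings, one per image.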
        with torch.inference_mode(False):
            lora_sd = {}
            generator = torch.Generator()
            generator.manual_seed(seed)

            # Load existing LoRA weights if provided
            existing_weights = {}
            existing_steps = 0
            if existing_lora != "[None]":
                lora_path = folder_paths.get_full_path_or_raise("loras", existing_lora)
                # Extract steps from a filename like "trained_lora_10_steps_20250225_203716"
                existing_steps = int(existing_lora.split("_steps_")[0].split("_")[-1])
                if lora_path:
                    existing_weights = comfy.utils.load_torch_file(lora_path)

            all_weight_adapters = []
            for n, m in mp.model.named_modules():
                if hasattr(m, "weight_function"):
                    if m.weight is not None:
                        key = "{}.weight".format(n)
                        shape = m.weight.shape
                        if len(shape) >= 2:
                            alpha = float(existing_weights.get(f"{key}.alpha", 1.0))
                            dora_scale = existing_weights.get(f"{key}.dora_scale", None)
                            for adapter_cls in adapters:
                                existing_adapter = adapter_cls.load(
                                    n, existing_weights, alpha, dora_scale
                                )
                                if existing_adapter is not None:
                                    break
                            else:
                                existing_adapter = None
                                adapter_cls = adapter_maps[algorithm]

                            if existing_adapter is not None:
                                train_adapter = existing_adapter.to_train().to(lora_dtype)
                            else:
                                # Use LoRA with alpha=1.0 by default
                                train_adapter = adapter_cls.create_train(
                                    m.weight, rank=rank, alpha=1.0
                                ).to(lora_dtype)

                            for name, parameter in train_adapter.named_parameters():
                                lora_sd[f"{n}.{name}"] = parameter

                            mp.add_weight_wrapper(key, train_adapter)
                            all_weight_adapters.append(train_adapter)
                        else:
                            diff = torch.nn.Parameter(
                                torch.zeros(
                                    m.weight.shape, dtype=lora_dtype, requires_grad=True
                                )
                            )
                            diff_module = BiasDiff(diff)
                            mp.add_weight_wrapper(key, diff_module)
                            all_weight_adapters.append(diff_module)
                            lora_sd["{}.diff".format(n)] = diff

                    if hasattr(m, "bias") and m.bias is not None:
                        key = "{}.bias".format(n)
                        bias = torch.nn.Parameter(
                            torch.zeros(m.bias.shape, dtype=lora_dtype, requires_grad=True)
                        )
                        bias_module = BiasDiff(bias)
                        lora_sd["{}.diff_b".format(n)] = bias
                        mp.add_weight_wrapper(key, bias_module)
                        all_weight_adapters.append(bias_module)

            if optimizer == "Adam":
                optimizer = torch.optim.Adam(lora_sd.values(), lr=learning_rate)
            elif optimizer == "AdamW":
                optimizer = torch.optim.AdamW(lora_sd.values(), lr=learning_rate)
            elif optimizer == "SGD":
                optimizer = torch.optim.SGD(lora_sd.values(), lr=learning_rate)
            elif optimizer == "RMSprop":
                optimizer = torch.optim.RMSprop(lora_sd.values(), lr=learning_rate)

            # Set up the loss function based on the selection
            if loss_function == "MSE":
                criterion = torch.nn.MSELoss()
            elif loss_function == "L1":
                criterion = torch.nn.L1Loss()
            elif loss_function == "Huber":
                criterion = torch.nn.HuberLoss()
            elif loss_function == "SmoothL1":
                criterion = torch.nn.SmoothL1Loss()

            # Set up the models
            if gradient_checkpointing:
                for m in find_all_highest_child_module_with_forward(
                    mp.model.diffusion_model
                ):
                    patch(m)
            mp.model.requires_grad_(False)
            comfy.model_management.load_models_gpu(
                [mp], memory_required=1e20, force_full_load=True
            )

            # Set up the sampler and guider
            loss_map = {"loss": []}

            def loss_callback(loss):
                loss_map["loss"].append(loss)

            train_sampler = TrainSampler(
                criterion,
                optimizer,
                loss_callback=loss_callback,
                batch_size=batch_size,
                grad_acc=grad_accumulation_steps,
                total_steps=steps * grad_accumulation_steps,
                seed=seed,
                training_dtype=dtype,
                real_dataset=latents if multi_res else None,
            )
            guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp)
            guider.set_conds(positive)  # Set conditioning from input
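
            # Note: the sampler API is repurposed as a training loop below.
            # `sigmas` only encodes the dataset size (one entry per image) and
            # the generated noise is a placeholder; TrainSampler.sample ignores
            # both and drives its own batching, noising, and optimizer steps.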
            # Training loop
            try:
                # Generate dummy sigmas and noise
                sigmas = torch.tensor(range(num_images))
                noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed)
                if multi_res:
                    # Use the first latent as a dummy latent when multi_res
                    latents = latents[0].repeat(
                        (num_images,) + ((1,) * (latents[0].ndim - 1))
                    )
                guider.sample(
                    noise.generate_noise({"samples": latents}),
                    latents,
                    train_sampler,
                    sigmas,
                    seed=noise.seed,
                )
            finally:
                for m in mp.model.modules():
                    unpatch(m)
                del train_sampler, optimizer

            for adapter in all_weight_adapters:
                adapter.requires_grad_(False)

            for param in lora_sd:
                lora_sd[param] = lora_sd[param].to(lora_dtype)

        return io.NodeOutput(mp, lora_sd, loss_map, steps + existing_steps)


class LoraModelLoader(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="LoraModelLoader",
            display_name="Load LoRA Model",
            category="loaders",
            is_experimental=True,
            inputs=[
                io.Model.Input(
                    "model", tooltip="The diffusion model the LoRA will be applied to."
                ),
                io.Custom("LORA_MODEL").Input(
                    "lora", tooltip="The LoRA model to apply to the diffusion model."
                ),
                io.Float.Input(
                    "strength_model",
                    default=1.0,
                    min=-100.0,
                    max=100.0,
                    tooltip="How strongly to modify the diffusion model. This value can be negative.",
                ),
            ],
            outputs=[
                io.Model.Output(
                    display_name="model", tooltip="The modified diffusion model."
                ),
            ],
        )

    @classmethod
    def execute(cls, model, lora, strength_model):
        if strength_model == 0:
            return io.NodeOutput(model)

        model_lora, _ = comfy.sd.load_lora_for_models(
            model, None, lora, strength_model, 0
        )
        return io.NodeOutput(model_lora)


class SaveLoRA(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="SaveLoRA",
            display_name="Save LoRA Weights",
            category="loaders",
            is_experimental=True,
            is_output_node=True,
            inputs=[
                io.Custom("LORA_MODEL").Input(
                    "lora",
                    tooltip="The LoRA model to save. Do not use the model with LoRA layers.",
                ),
                io.String.Input(
                    "prefix",
                    default="loras/ComfyUI_trained_lora",
                    tooltip="The prefix to use for the saved LoRA file.",
                ),
                io.Int.Input(
                    "steps",
                    optional=True,
                    tooltip="Optional: The number of steps the LoRA has been trained for, used to name the saved file.",
                ),
            ],
            outputs=[],
        )

    @classmethod
    def execute(cls, lora, prefix, steps=None):
        output_dir = folder_paths.get_output_directory()
        full_output_folder, filename, counter, subfolder, filename_prefix = (
            folder_paths.get_save_image_path(prefix, output_dir)
        )
        if steps is None:
            output_checkpoint = f"{filename}_{counter:05}_.safetensors"
        else:
            output_checkpoint = f"{filename}_{steps}_steps_{counter:05}_.safetensors"
        output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
        safetensors.torch.save_file(lora, output_checkpoint)
        return io.NodeOutput()

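
# Round-trip sketch (illustrative): the LORA_MODEL dict produced by
# TrainLoraNode is a plain name->tensor state dict, so it can be saved with
# safetensors and re-applied with the same loader the nodes above use:
#
#     safetensors.torch.save_file(lora_sd, "my_lora.safetensors")
#     lora_sd = comfy.utils.load_torch_file("my_lora.safetensors")
#     model_lora, _ = comfy.sd.load_lora_for_models(model, None, lora_sd, 1.0, 0)
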

class LossGraphNode(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="LossGraphNode",
            display_name="Plot Loss Graph",
            category="training",
            is_experimental=True,
            is_output_node=True,
            inputs=[
                io.Custom("LOSS_MAP").Input(
                    "loss", tooltip="Loss map from the training node."
                ),
                io.String.Input(
                    "filename_prefix",
                    default="loss_graph",
                    tooltip="Prefix for the saved loss graph image.",
                ),
            ],
            outputs=[],
            hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
        )

    @classmethod
    def execute(cls, loss, filename_prefix, prompt=None, extra_pnginfo=None):
        loss_values = loss["loss"]
        width, height = 800, 480
        margin = 40

        img = Image.new(
            "RGB", (width + margin, height + margin), "white"
        )  # Extend the canvas
        draw = ImageDraw.Draw(img)

        min_loss, max_loss = min(loss_values), max(loss_values)
        scaled_loss = [(l - min_loss) / (max_loss - min_loss) for l in loss_values]

        steps = len(loss_values)
        prev_point = (margin, height - int(scaled_loss[0] * height))
        for i, l in enumerate(scaled_loss[1:], start=1):
            x = margin + int(i / steps * width)  # Scale X properly
            y = height - int(l * height)
            draw.line([prev_point, (x, y)], fill="blue", width=2)
            prev_point = (x, y)

        draw.line([(margin, 0), (margin, height)], fill="black", width=2)  # Y-axis
        draw.line(
            [(margin, height), (width + margin, height)], fill="black", width=2
        )  # X-axis

        try:
            font = ImageFont.truetype("arial.ttf", 12)
        except IOError:
            font = ImageFont.load_default()

        # Add axis labels
        draw.text((5, height // 2), "Loss", font=font, fill="black")
        draw.text((width // 2, height + 10), "Steps", font=font, fill="black")

        # Add min/max loss values
        draw.text((margin - 30, 0), f"{max_loss:.2f}", font=font, fill="black")
        draw.text(
            (margin - 30, height - 10), f"{min_loss:.2f}", font=font, fill="black"
        )

        # Convert the PIL image to a tensor for PreviewImage
        img_array = np.array(img).astype(np.float32) / 255.0
        img_tensor = torch.from_numpy(img_array)[None,]  # [1, H, W, 3]

        # Return the preview UI
        return io.NodeOutput(ui=ui.PreviewImage(img_tensor, cls=cls))


# ========== Extension Setup ==========
class TrainingExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return [
            TrainLoraNode,
            LoraModelLoader,
            SaveLoRA,
            LossGraphNode,
        ]


async def comfy_entrypoint() -> TrainingExtension:
    return TrainingExtension()
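

# To inspect a LOSS_MAP outside ComfyUI, the raw history is just
# loss_map["loss"], a list of floats; e.g. with matplotlib (an external
# dependency, assumed only for this sketch):
#
#     import matplotlib.pyplot as plt
#     plt.plot(loss_map["loss"])
#     plt.xlabel("Steps"); plt.ylabel("Loss")
#     plt.show()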