import logging
import os

import numpy as np
import safetensors.torch  # import the submodule explicitly; save_file lives in safetensors.torch
import torch
import torch.utils.checkpoint
from tqdm.auto import trange
from PIL import Image, ImageDraw, ImageFont
from typing_extensions import override

import comfy.samplers
import comfy.sd
import comfy.utils
import comfy.model_management
import comfy_extras.nodes_custom_sampler
import folder_paths
import node_helpers
from comfy.weight_adapter import adapters, adapter_maps
from comfy_api.latest import ComfyExtension, io, ui
from comfy.utils import ProgressBar


def make_batch_extra_option_dict(d, indices, full_size=None):
    """Recursively slice batch-sized values in an extra-options dict down to `indices`."""
    new_dict = {}
    for k, v in d.items():
        newv = v
        if isinstance(v, dict):
            newv = make_batch_extra_option_dict(v, indices, full_size=full_size)
        elif isinstance(v, torch.Tensor):
            if full_size is None or v.size(0) == full_size:
                newv = v[indices]
        elif isinstance(v, (list, tuple)) and len(v) == full_size:
            newv = [v[i] for i in indices]
        new_dict[k] = newv
    return new_dict


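# Illustration (editorial, not part of the original module): a minimal sketch
# of how make_batch_extra_option_dict picks a mini-batch out of nested extra
# args. Values whose leading dimension does not match full_size pass through
# unchanged.
#
#   extra = {"scale": 1.0, "opts": {"t": torch.zeros(4, 2)}}
#   sub = make_batch_extra_option_dict(extra, [0, 2], full_size=4)
#   # sub["opts"]["t"].shape == (2, 2); sub["scale"] == 1.0

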
def process_cond_list(d, prefix=""):
    # Recursively clone every tensor in the conditioning structure so training
    # neither mutates the caller's conds nor holds references to them.
    if hasattr(d, "__iter__") and not hasattr(d, "items"):
        for index, item in enumerate(d):
            process_cond_list(item, f"{prefix}.{index}")
        return d
    elif hasattr(d, "items"):
        for k, v in list(d.items()):
            if isinstance(v, dict):
                process_cond_list(v, f"{prefix}.{k}")
            elif isinstance(v, torch.Tensor):
                d[k] = v.clone()
            elif isinstance(v, (list, tuple)):
                for index, item in enumerate(v):
                    process_cond_list(item, f"{prefix}.{k}.{index}")
    return d


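# Editorial note: the cloning above also matters because ComfyUI evaluates
# most nodes under torch.inference_mode(), and inference tensors cannot
# participate in autograd. A minimal reproduction of the failure the clone
# (performed outside inference mode) avoids:
#
#   with torch.inference_mode():
#       t = torch.ones(2)
#   # t.clone() outside inference mode is an ordinary tensor; feeding t
#   # itself into a training graph would raise a RuntimeError.

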
class TrainSampler(comfy.samplers.Sampler):
    def __init__(
        self,
        loss_fn,
        optimizer,
        loss_callback=None,
        batch_size=1,
        grad_acc=1,
        total_steps=1,
        seed=0,
        training_dtype=torch.bfloat16,
        real_dataset=None,
    ):
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        self.loss_callback = loss_callback
        self.batch_size = batch_size
        self.total_steps = total_steps
        self.grad_acc = grad_acc
        self.seed = seed
        self.training_dtype = training_dtype
        self.real_dataset: list[torch.Tensor] | None = real_dataset

    def fwd_bwd(
        self,
        model_wrap,
        batch_sigmas,
        batch_noise,
        batch_latent,
        cond,
        indices,
        extra_args,
        dataset_size,
        bwd=True,
    ):
        # Build the noised input x_t and the clean target x_0 for this batch.
        xt = model_wrap.inner_model.model_sampling.noise_scaling(
            batch_sigmas, batch_noise, batch_latent, False
        )
        x0 = model_wrap.inner_model.model_sampling.noise_scaling(
            torch.zeros_like(batch_sigmas),
            torch.zeros_like(batch_noise),
            batch_latent,
            False,
        )

        model_wrap.conds["positive"] = [cond[i] for i in indices]
        batch_extra_args = make_batch_extra_option_dict(
            extra_args, indices, full_size=dataset_size
        )

        with torch.autocast(xt.device.type, dtype=self.training_dtype):
            x0_pred = model_wrap(
                xt.requires_grad_(True),
                batch_sigmas.requires_grad_(True),
                **batch_extra_args,
            )
            loss = self.loss_fn(x0_pred, x0)
        if bwd:
            # Scale by 1/grad_acc so accumulated gradients average over the
            # accumulation window instead of summing.
            bwd_loss = loss / self.grad_acc
            bwd_loss.backward()
        return loss

    def sample(
        self,
        model_wrap,
        sigmas,
        extra_args,
        callback,
        noise,
        latent_image=None,
        denoise_mask=None,
        disable_pbar=False,
    ):
        model_wrap.conds = process_cond_list(model_wrap.conds)
        cond = model_wrap.conds["positive"]
        dataset_size = sigmas.size(0)
        torch.cuda.empty_cache()
        ui_pbar = ProgressBar(self.total_steps)
        for i in (
            pbar := trange(
                self.total_steps,
                desc="Training LoRA",
                smoothing=0.01,
                disable=not comfy.utils.PROGRESS_BAR_ENABLED,
            )
        ):
            noisegen = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(
                self.seed + i * 1000
            )
            indices = torch.randperm(dataset_size)[: self.batch_size].tolist()

            if self.real_dataset is None:
                # Uniform-resolution path: stack a mini-batch from the latent batch.
                batch_latent = torch.stack([latent_image[i] for i in indices])
                batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(
                    batch_latent.device
                )
                batch_sigmas = [
                    model_wrap.inner_model.model_sampling.percent_to_sigma(
                        torch.rand((1,)).item()
                    )
                    for _ in range(min(self.batch_size, dataset_size))
                ]
                batch_sigmas = torch.tensor(batch_sigmas).to(batch_latent.device)

                loss = self.fwd_bwd(
                    model_wrap,
                    batch_sigmas,
                    batch_noise,
                    batch_latent,
                    cond,
                    indices,
                    extra_args,
                    dataset_size,
                    bwd=True,
                )
                if self.loss_callback:
                    self.loss_callback(loss.item())
                    pbar.set_postfix({"loss": f"{loss.item():.4f}"})
            else:
                # Multi-resolution path: latents differ in shape, so run them
                # one at a time and average the loss before the backward pass.
                total_loss = 0
                for index in indices:
                    single_latent = self.real_dataset[index].to(latent_image)
                    batch_noise = noisegen.generate_noise(
                        {"samples": single_latent}
                    ).to(single_latent.device)
                    sigma = model_wrap.inner_model.model_sampling.percent_to_sigma(
                        torch.rand((1,)).item()
                    )
                    batch_sigmas = torch.tensor([sigma]).to(single_latent.device)
                    loss = self.fwd_bwd(
                        model_wrap,
                        batch_sigmas,
                        batch_noise,
                        single_latent,
                        cond,
                        [index],
                        extra_args,
                        dataset_size,
                        bwd=False,
                    )
                    total_loss += loss
                total_loss = total_loss / self.grad_acc / len(indices)
                total_loss.backward()
                if self.loss_callback:
                    self.loss_callback(total_loss.item())
                    pbar.set_postfix({"loss": f"{total_loss.item():.4f}"})

            if (i + 1) % self.grad_acc == 0:
                # Apply the accumulated gradients once every grad_acc iterations.
                self.optimizer.step()
                self.optimizer.zero_grad()
            ui_pbar.update(1)
        torch.cuda.empty_cache()
        return torch.zeros_like(latent_image)


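# Editorial note: with gradient accumulation, the optimizer steps once every
# grad_acc iterations, so:
#
#   effective_batch = batch_size * grad_acc    # samples per optimizer step
#   optimizer_steps = total_steps // grad_acc  # steps actually applied
#
# TrainLoraNode below passes total_steps = steps * grad_accumulation_steps,
# so the user-facing "steps" input counts optimizer steps, not forward passes.

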
class BiasDiff(torch.nn.Module):
    def __init__(self, bias):
        super().__init__()
        self.bias = bias

    def __call__(self, b):
        # Add the trainable offset in its own dtype, then cast back.
        org_dtype = b.dtype
        return (b.to(self.bias) + self.bias).to(org_dtype)

    def passive_memory_usage(self):
        return self.bias.nelement() * self.bias.element_size()

    def move_to(self, device):
        self.to(device=device)
        return self.passive_memory_usage()


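# Editorial sketch (the exact calling convention of weight wrappers is an
# assumption here): BiasDiff is registered via mp.add_weight_wrapper and is
# called with the original parameter, returning the patched value, e.g.:
#
#   delta = torch.nn.Parameter(torch.zeros(8, dtype=torch.bfloat16))
#   wrapper = BiasDiff(delta)
#   patched = wrapper(torch.randn(8, dtype=torch.float32))  # stays float32

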
def draw_loss_graph(loss_map, steps):
    width, height = 500, 300
    img = Image.new("RGB", (width, height), "white")
    draw = ImageDraw.Draw(img)

    min_loss, max_loss = min(loss_map.values()), max(loss_map.values())
    loss_range = max(max_loss - min_loss, 1e-8)  # avoid division by zero on flat loss
    scaled_loss = [(l - min_loss) / loss_range for l in loss_map.values()]

    prev_point = (0, height - int(scaled_loss[0] * height))
    for i, l in enumerate(scaled_loss[1:], start=1):
        x = int(i / max(steps - 1, 1) * width)
        y = height - int(l * height)
        draw.line([prev_point, (x, y)], fill="blue", width=2)
        prev_point = (x, y)

    return img


def find_all_highest_child_module_with_forward(
    model: torch.nn.Module, result=None, name=None
):
    if result is None:
        # First call: never collect the root itself, only descend into children.
        result = []
    elif hasattr(model, "forward") and not isinstance(
        model, (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict)
    ):
        # Highest-level descendant with a usable forward(): collect it and stop
        # descending into this subtree.
        result.append(model)
        logging.debug(f"Found module with forward: {name} ({model.__class__.__name__})")
        return result
    name = name or "root"
    for next_name, child in model.named_children():
        find_all_highest_child_module_with_forward(child, result, f"{name}.{next_name}")
    return result


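# Editorial sketch of the traversal above: for a model like
#
#   class Block(torch.nn.Module):
#       def forward(self, x): return x
#
#   class Net(torch.nn.Module):
#       def __init__(self):
#           super().__init__()
#           self.blocks = torch.nn.ModuleList([Block(), Block()])
#
# find_all_highest_child_module_with_forward(Net()) returns the two Block
# instances: the ModuleList container is skipped, and recursion stops at the
# first module in each branch that defines a usable forward().

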
def patch(m):
    if not hasattr(m, "forward"):
        return
    org_forward = m.forward

    def fwd(args, kwargs):
        # torch.utils.checkpoint passes the captured args/kwargs positionally.
        return org_forward(*args, **kwargs)

    def checkpointing_fwd(*args, **kwargs):
        return torch.utils.checkpoint.checkpoint(fwd, args, kwargs, use_reentrant=False)

    m.org_forward = org_forward
    m.forward = checkpointing_fwd


def unpatch(m):
    if hasattr(m, "org_forward"):
        m.forward = m.org_forward
        del m.org_forward


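# Editorial sketch: patch()/unpatch() toggle activation (gradient)
# checkpointing on a module by swapping its forward. Assuming some
# torch.nn.Module `mod`:
#
#   patch(mod)      # forward now re-runs under torch.utils.checkpoint
#   out = mod(x)    # activations are recomputed in backward, saving memory
#   unpatch(mod)    # original forward restored

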
class TrainLoraNode(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="TrainLoraNode",
            display_name="Train LoRA",
            category="training",
            is_experimental=True,
            is_input_list=True,  # All inputs become lists
            inputs=[
                io.Model.Input("model", tooltip="The model to train the LoRA on."),
                io.Latent.Input(
                    "latents",
                    tooltip="The latents to use for training; they serve as the dataset/input of the model.",
                ),
                io.Conditioning.Input(
                    "positive", tooltip="The positive conditioning to use for training."
                ),
                io.Int.Input(
                    "batch_size",
                    default=1,
                    min=1,
                    max=10000,
                    tooltip="The batch size to use for training.",
                ),
                io.Int.Input(
                    "grad_accumulation_steps",
                    default=1,
                    min=1,
                    max=1024,
                    tooltip="The number of gradient accumulation steps to use for training.",
                ),
                io.Int.Input(
                    "steps",
                    default=16,
                    min=1,
                    max=100000,
                    tooltip="The number of steps to train the LoRA for.",
                ),
                io.Float.Input(
                    "learning_rate",
                    default=0.0005,
                    min=0.0000001,
                    max=1.0,
                    step=0.0000001,
                    tooltip="The learning rate to use for training.",
                ),
                io.Int.Input(
                    "rank",
                    default=8,
                    min=1,
                    max=128,
                    tooltip="The rank of the LoRA layers.",
                ),
                io.Combo.Input(
                    "optimizer",
                    options=["AdamW", "Adam", "SGD", "RMSprop"],
                    default="AdamW",
                    tooltip="The optimizer to use for training.",
                ),
                io.Combo.Input(
                    "loss_function",
                    options=["MSE", "L1", "Huber", "SmoothL1"],
                    default="MSE",
                    tooltip="The loss function to use for training.",
                ),
                io.Int.Input(
                    "seed",
                    default=0,
                    min=0,
                    max=0xFFFFFFFFFFFFFFFF,
                    tooltip="The seed to use for training (used for LoRA weight initialization and noise sampling).",
                ),
                io.Combo.Input(
                    "training_dtype",
                    options=["bf16", "fp32"],
                    default="bf16",
                    tooltip="The dtype to use for training.",
                ),
                io.Combo.Input(
                    "lora_dtype",
                    options=["bf16", "fp32"],
                    default="bf16",
                    tooltip="The dtype to use for the LoRA weights.",
                ),
                io.Combo.Input(
                    "algorithm",
                    options=list(adapter_maps.keys()),
                    default=list(adapter_maps.keys())[0],
                    tooltip="The algorithm to use for training.",
                ),
                io.Boolean.Input(
                    "gradient_checkpointing",
                    default=True,
                    tooltip="Use gradient checkpointing for training.",
                ),
                io.Combo.Input(
                    "existing_lora",
                    options=folder_paths.get_filename_list("loras") + ["[None]"],
                    default="[None]",
                    tooltip="The existing LoRA to continue training from. Select [None] to train a new LoRA.",
                ),
            ],
            outputs=[
                io.Model.Output(
                    display_name="model", tooltip="Model with LoRA applied"
                ),
                io.Custom("LORA_MODEL").Output(
                    display_name="lora", tooltip="LoRA weights"
                ),
                io.Custom("LOSS_MAP").Output(
                    display_name="loss_map", tooltip="Loss history"
                ),
                io.Int.Output(display_name="steps", tooltip="Total training steps"),
            ],
        )

    @classmethod
    def execute(
        cls,
        model,
        latents,
        positive,
        batch_size,
        steps,
        grad_accumulation_steps,
        learning_rate,
        rank,
        optimizer,
        loss_function,
        seed,
        training_dtype,
        lora_dtype,
        algorithm,
        gradient_checkpointing,
        existing_lora,
    ):
        # Extract scalars from lists (due to is_input_list=True)
        model = model[0]
        batch_size = batch_size[0]
        steps = steps[0]
        grad_accumulation_steps = grad_accumulation_steps[0]
        learning_rate = learning_rate[0]
        rank = rank[0]
        optimizer = optimizer[0]
        loss_function = loss_function[0]
        seed = seed[0]
        training_dtype = training_dtype[0]
        lora_dtype = lora_dtype[0]
        algorithm = algorithm[0]
        gradient_checkpointing = gradient_checkpointing[0]
        existing_lora = existing_lora[0]

        # Handle latents - either a single dict or a list of dicts
        if len(latents) == 1:
            latents = latents[0]["samples"]  # Single latent dict
        else:
            latent_list = []
            for latent in latents:
                latent = latent["samples"]
                bs = latent.shape[0]
                if bs != 1:
                    for sub_latent in latent:
                        latent_list.append(sub_latent[None])
                else:
                    latent_list.append(latent)
            latents = latent_list

        # Handle conditioning - either a single list or a list of lists
        if len(positive) == 1:
            positive = positive[0]  # Single conditioning list
        else:
            # Multiple conditioning lists - flatten
            flat_positive = []
            for cond in positive:
                if isinstance(cond, list):
                    flat_positive.extend(cond)
                else:
                    flat_positive.append(cond)
            positive = flat_positive

        mp = model.clone()
        dtype = node_helpers.string_to_torch_dtype(training_dtype)
        lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
        mp.set_model_compute_dtype(dtype)

        # latents can be a list of differently sized latents or one large batch
        if isinstance(latents, list):
            all_shapes = set()
            latents = [t.to(dtype) for t in latents]
            for latent in latents:
                all_shapes.add(latent.shape)
            logging.info(f"Latent shapes: {all_shapes}")
            multi_res = len(all_shapes) > 1
            if not multi_res:
                latents = torch.cat(latents, dim=0)
            num_images = len(latents)
        elif isinstance(latents, torch.Tensor):
            multi_res = False  # one uniform batch
            latents = latents.to(dtype)
            num_images = latents.shape[0]
        else:
            raise ValueError(f"Invalid latents type: {type(latents)}")

        logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}")
        if len(positive) == 1 and num_images > 1:
            positive = positive * num_images
        elif len(positive) != num_images:
            raise ValueError(
                f"Number of positive conditions ({len(positive)}) does not match number of images ({num_images})."
            )

        with torch.inference_mode(False):
            lora_sd = {}
            generator = torch.Generator()
            generator.manual_seed(seed)

            # Load existing LoRA weights if provided
            existing_weights = {}
            existing_steps = 0
            if existing_lora != "[None]":
                lora_path = folder_paths.get_full_path_or_raise("loras", existing_lora)
                # Extract steps from a filename like "trained_lora_10_steps_20250225_203716"
                existing_steps = int(existing_lora.split("_steps_")[0].split("_")[-1])
                if lora_path:
                    existing_weights = comfy.utils.load_torch_file(lora_path)

            all_weight_adapters = []
            for n, m in mp.model.named_modules():
                if hasattr(m, "weight_function"):
                    if m.weight is not None:
                        key = "{}.weight".format(n)
                        shape = m.weight.shape
                        if len(shape) >= 2:
                            alpha = float(existing_weights.get(f"{key}.alpha", 1.0))
                            dora_scale = existing_weights.get(f"{key}.dora_scale", None)
                            for adapter_cls in adapters:
                                existing_adapter = adapter_cls.load(
                                    n, existing_weights, alpha, dora_scale
                                )
                                if existing_adapter is not None:
                                    break
                            else:
                                existing_adapter = None
                                adapter_cls = adapter_maps[algorithm]

                            if existing_adapter is not None:
                                train_adapter = existing_adapter.to_train().to(
                                    lora_dtype
                                )
                            else:
                                # Use LoRA with alpha=1.0 by default
                                train_adapter = adapter_cls.create_train(
                                    m.weight, rank=rank, alpha=1.0
                                ).to(lora_dtype)
                            for name, parameter in train_adapter.named_parameters():
                                lora_sd[f"{n}.{name}"] = parameter

                            mp.add_weight_wrapper(key, train_adapter)
                            all_weight_adapters.append(train_adapter)
                        else:
                            # 1-D weights (e.g. norms) are trained as a full diff.
                            diff = torch.nn.Parameter(
                                torch.zeros(
                                    m.weight.shape, dtype=lora_dtype, requires_grad=True
                                )
                            )
                            diff_module = BiasDiff(diff)
                            mp.add_weight_wrapper(key, diff_module)
                            all_weight_adapters.append(diff_module)
                            lora_sd["{}.diff".format(n)] = diff
                    if hasattr(m, "bias") and m.bias is not None:
                        key = "{}.bias".format(n)
                        bias = torch.nn.Parameter(
                            torch.zeros(
                                m.bias.shape, dtype=lora_dtype, requires_grad=True
                            )
                        )
                        bias_module = BiasDiff(bias)
                        lora_sd["{}.diff_b".format(n)] = bias
                        mp.add_weight_wrapper(key, bias_module)
                        all_weight_adapters.append(bias_module)

            if optimizer == "Adam":
                optimizer = torch.optim.Adam(lora_sd.values(), lr=learning_rate)
            elif optimizer == "AdamW":
                optimizer = torch.optim.AdamW(lora_sd.values(), lr=learning_rate)
            elif optimizer == "SGD":
                optimizer = torch.optim.SGD(lora_sd.values(), lr=learning_rate)
            elif optimizer == "RMSprop":
                optimizer = torch.optim.RMSprop(lora_sd.values(), lr=learning_rate)

            # Set up the loss function based on the selection
            if loss_function == "MSE":
                criterion = torch.nn.MSELoss()
            elif loss_function == "L1":
                criterion = torch.nn.L1Loss()
            elif loss_function == "Huber":
                criterion = torch.nn.HuberLoss()
            elif loss_function == "SmoothL1":
                criterion = torch.nn.SmoothL1Loss()

            # Set up the model
            if gradient_checkpointing:
                for m in find_all_highest_child_module_with_forward(
                    mp.model.diffusion_model
                ):
                    patch(m)
            mp.model.requires_grad_(False)
            comfy.model_management.load_models_gpu(
                [mp], memory_required=1e20, force_full_load=True
            )

            # Set up the sampler and guider
            loss_map = {"loss": []}

            def loss_callback(loss):
                loss_map["loss"].append(loss)

            train_sampler = TrainSampler(
                criterion,
                optimizer,
                loss_callback=loss_callback,
                batch_size=batch_size,
                grad_acc=grad_accumulation_steps,
                total_steps=steps * grad_accumulation_steps,
                seed=seed,
                training_dtype=dtype,
                real_dataset=latents if multi_res else None,
            )
            guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp)
            guider.set_conds(positive)  # Set conditioning from input

            # Training loop
            try:
                # Generate dummy sigmas and noise
                sigmas = torch.arange(num_images)
                noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed)
                if multi_res:
                    # Use the first latent as a dummy latent if multi_res
                    latents = latents[0].repeat(
                        (num_images,) + ((1,) * (latents[0].ndim - 1))
                    )
                guider.sample(
                    noise.generate_noise({"samples": latents}),
                    latents,
                    train_sampler,
                    sigmas,
                    seed=noise.seed,
                )
            finally:
                for m in mp.model.modules():
                    unpatch(m)
            del train_sampler, optimizer

            for adapter in all_weight_adapters:
                adapter.requires_grad_(False)

            for param in lora_sd:
                lora_sd[param] = lora_sd[param].to(lora_dtype)

            return io.NodeOutput(mp, lora_sd, loss_map, steps + existing_steps)


class LoraModelLoader(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="LoraModelLoader",
            display_name="Load LoRA Model",
            category="loaders",
            is_experimental=True,
            inputs=[
                io.Model.Input(
                    "model", tooltip="The diffusion model the LoRA will be applied to."
                ),
                io.Custom("LORA_MODEL").Input(
                    "lora", tooltip="The LoRA model to apply to the diffusion model."
                ),
                io.Float.Input(
                    "strength_model",
                    default=1.0,
                    min=-100.0,
                    max=100.0,
                    tooltip="How strongly to modify the diffusion model. This value can be negative.",
                ),
            ],
            outputs=[
                io.Model.Output(
                    display_name="model", tooltip="The modified diffusion model."
                ),
            ],
        )

    @classmethod
    def execute(cls, model, lora, strength_model):
        if strength_model == 0:
            return io.NodeOutput(model)

        model_lora, _ = comfy.sd.load_lora_for_models(
            model, None, lora, strength_model, 0
        )
        return io.NodeOutput(model_lora)


class SaveLoRA(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="SaveLoRA",
            display_name="Save LoRA Weights",
            category="loaders",
            is_experimental=True,
            is_output_node=True,
            inputs=[
                io.Custom("LORA_MODEL").Input(
                    "lora",
                    tooltip="The LoRA model to save. Do not use the model with LoRA layers.",
                ),
                io.String.Input(
                    "prefix",
                    default="loras/ComfyUI_trained_lora",
                    tooltip="The prefix to use for the saved LoRA file.",
                ),
                io.Int.Input(
                    "steps",
                    optional=True,
                    tooltip="Optional: the number of steps the LoRA has been trained for, used to name the saved file.",
                ),
            ],
            outputs=[],
        )

    @classmethod
    def execute(cls, lora, prefix, steps=None):
        output_dir = folder_paths.get_output_directory()
        full_output_folder, filename, counter, subfolder, filename_prefix = (
            folder_paths.get_save_image_path(prefix, output_dir)
        )
        if steps is None:
            output_checkpoint = f"{filename}_{counter:05}_.safetensors"
        else:
            output_checkpoint = f"{filename}_{steps}_steps_{counter:05}_.safetensors"
        output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
        safetensors.torch.save_file(lora, output_checkpoint)
        return io.NodeOutput()


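# Editorial note: the "_steps_" infix written above is what TrainLoraNode's
# existing_lora handling relies on; it reads the step count back with
#
#   int(existing_lora.split("_steps_")[0].split("_")[-1])
#
# so renaming a saved file to drop the "<N>_steps_" part will break resuming
# training from that file.

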
class LossGraphNode(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="LossGraphNode",
            display_name="Plot Loss Graph",
            category="training",
            is_experimental=True,
            is_output_node=True,
            inputs=[
                io.Custom("LOSS_MAP").Input(
                    "loss", tooltip="Loss map from training node."
                ),
                io.String.Input(
                    "filename_prefix",
                    default="loss_graph",
                    tooltip="Prefix for the saved loss graph image.",
                ),
            ],
            outputs=[],
            hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
        )

    @classmethod
    def execute(cls, loss, filename_prefix, prompt=None, extra_pnginfo=None):
        loss_values = loss["loss"]
        width, height = 800, 480
        margin = 40

        img = Image.new(
            "RGB", (width + margin, height + margin), "white"
        )  # Extend canvas
        draw = ImageDraw.Draw(img)

        min_loss, max_loss = min(loss_values), max(loss_values)
        loss_range = max(max_loss - min_loss, 1e-8)  # avoid division by zero on flat loss
        scaled_loss = [(l - min_loss) / loss_range for l in loss_values]

        steps = len(loss_values)

        prev_point = (margin, height - int(scaled_loss[0] * height))
        for i, l in enumerate(scaled_loss[1:], start=1):
            x = margin + int(i / steps * width)  # Scale X properly
            y = height - int(l * height)
            draw.line([prev_point, (x, y)], fill="blue", width=2)
            prev_point = (x, y)

        draw.line([(margin, 0), (margin, height)], fill="black", width=2)  # Y-axis
        draw.line(
            [(margin, height), (width + margin, height)], fill="black", width=2
        )  # X-axis

        try:
            font = ImageFont.truetype("arial.ttf", 12)
        except IOError:
            font = ImageFont.load_default()

        # Add axis labels
        draw.text((5, height // 2), "Loss", font=font, fill="black")
        draw.text((width // 2, height + 10), "Steps", font=font, fill="black")

        # Add min/max loss values
        draw.text((margin - 30, 0), f"{max_loss:.2f}", font=font, fill="black")
        draw.text(
            (margin - 30, height - 10), f"{min_loss:.2f}", font=font, fill="black"
        )

        # Convert the PIL image to a tensor for PreviewImage
        img_array = np.array(img).astype(np.float32) / 255.0
        img_tensor = torch.from_numpy(img_array)[None,]  # [1, H, W, 3]

        # Return preview UI
        return io.NodeOutput(ui=ui.PreviewImage(img_tensor, cls=cls))


# ========== Extension Setup ==========


class TrainingExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return [
            TrainLoraNode,
            LoraModelLoader,
            SaveLoRA,
            LossGraphNode,
        ]


async def comfy_entrypoint() -> TrainingExtension:
    return TrainingExtension()