add FastPreview

This commit is contained in:
kijai 2024-10-12 21:03:33 +03:00
parent b5419c853c
commit 2fbed0575d
4 changed files with 193 additions and 5 deletions

View File

@ -46,6 +46,7 @@ NODE_CONFIG = {
"GetImagesFromBatchIndexed": {"class": GetImagesFromBatchIndexed, "name": "Get Images From Batch Indexed"},
"GetImageRangeFromBatch": {"class": GetImageRangeFromBatch, "name": "Get Image or Mask Range From Batch"},
"GetImageSizeAndCount": {"class": GetImageSizeAndCount, "name": "Get Image Size & Count"},
"FastPreview": {"class": FastPreview, "name": "Fast Preview"},
"ImageAndMaskPreview": {"class": ImageAndMaskPreview},
"ImageAddMulti": {"class": ImageAddMulti, "name": "Image Add Multi"},
"ImageBatchMulti": {"class": ImageBatchMulti, "name": "Image Batch Multi"},
@ -143,6 +144,7 @@ NODE_CONFIG = {
"FluxBlockLoraLoader": {"class": FluxBlockLoraLoader, "name": "Flux Block Lora Loader"},
"FluxBlockLoraSelect": {"class": FluxBlockLoraSelect, "name": "Flux Block Lora Select"},
"CustomControlNetWeightsFluxFromList": {"class": CustomControlNetWeightsFluxFromList, "name": "Custom ControlNet Weights Flux From List"},
"PatchCublasLinear": {"class": PatchCublasLinear, "name": "Patch Cublas Linear"},
#instance diffusion
"CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking},

View File

@ -3,6 +3,8 @@ import time
import torch
import torch.nn.functional as F
import torchvision.transforms as T
import io
import base64
import random
import math
import os
@ -1445,6 +1447,7 @@ class ImageAddMulti:
{
"default": 'add'
}),
"blend_amount": ("FLOAT", {"default": 0.5, "min": 0, "max": 1, "step": 0.01}),
},
}
@ -1458,16 +1461,16 @@ You can set how many inputs the node has,
with the **inputcount** and clicking update.
"""
def add(self, inputcount, blending, **kwargs):
def add(self, inputcount, blending, blend_amount, **kwargs):
image = kwargs["image_1"]
for c in range(1, inputcount):
new_image = kwargs[f"image_{c + 1}"]
if blending == "add":
image = torch.add(image * 0.5, new_image * 0.5)
image = torch.add(image * blend_amount, new_image * blend_amount)
elif blending == "subtract":
image = torch.sub(image * 0.5, new_image * 0.5)
image = torch.sub(image * blend_amount, new_image * blend_amount)
elif blending == "multiply":
image = torch.mul(image * 0.5, new_image * 0.5)
image = torch.mul(image * blend_amount, new_image * blend_amount)
elif blending == "difference":
image = torch.sub(image, new_image)
return (image,)
@ -2042,4 +2045,36 @@ class SaveImageKJ:
return { "ui": {
"images": results },
"result": (file,) }
"result": (file,) }
# Reusable tensor -> PIL converter; expects a CHW float tensor in [0, 1].
to_pil_image = T.ToPILImage()

class FastPreview:
    """Encode the first image of the batch to base64 and push it to the
    frontend widget, skipping the usual save-to-disk preview path."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE", ),
                "format": (["JPEG", "PNG", "WEBP"], {"default": "JPEG"}),
                "quality" : ("INT", {"default": 75, "min": 1, "max": 100, "step": 1}),
            },
        }

    RETURN_TYPES = ()
    FUNCTION = "preview"
    CATEGORY = "KJNodes/experimental"
    OUTPUT_NODE = True

    def preview(self, image, format, quality):
        # IMAGE tensors are (batch, H, W, C); ToPILImage wants (C, H, W),
        # and only the first image of the batch is previewed.
        pil_image = to_pil_image(image[0].permute(2, 0, 1))
        # JPEG cannot encode an alpha channel; drop it instead of letting
        # Pillow raise OSError on RGBA/LA input. PNG and WEBP keep alpha.
        if format == "JPEG" and pil_image.mode not in ("RGB", "L"):
            pil_image = pil_image.convert("RGB")
        with io.BytesIO() as buffered:
            # `quality` is honoured by JPEG/WEBP and ignored by PNG.
            pil_image.save(buffered, format=format, quality=quality)
            img_bytes = buffered.getvalue()
        img_base64 = base64.b64encode(img_bytes).decode('utf-8')
        # "bg_image" is consumed by web/js/fast_preview.js via onExecuted.
        return {
            "ui": {"bg_image": [img_base64]},
            "result": ()
        }

View File

@ -2119,4 +2119,60 @@ class ModelSaveKJ:
save_torch_file(new_sd, os.path.join(full_output_folder, output_checkpoint))
return {}
class PatchCublasLinear:
    # NOTE(review): assigned but never read; presumably meant to stash the
    # pre-patch Linear class so it could be restored — confirm intent.
    original_linear = None

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "model": ("MODEL",),
            "enabled": ("BOOLEAN", {"default": True, "tooltip": "Enable or disable the patching, won't take effect on already loaded models!"}),
            },
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
    OUTPUT_NODE = True
    DESCRIPTION = "Highly experimental node that simply patches the Linear layer to use torch-cublas-hgemm, won't take effect on already loaded models!"
    CATEGORY = "KJNodes/experimental"

    def patch(self, model, enabled):
        """Globally replace comfy's disable_weight_init.Linear.

        The model itself is passed through unchanged: the patch is a
        module-level side effect on comfy.ops, so it only affects models
        that are loaded *after* this node runs.
        """
        from comfy.ops import disable_weight_init, CastWeightBiasOp, cast_bias_weight
        try:
            from cublas_ops import CublasLinear
        except ImportError:
            raise Exception("Can't import 'torch-cublas-hgemm', install it from here https://github.com/aredden/torch-cublas-hgemm")

        # Mirrors comfy's stock Linear: parameter (re)initialisation is a
        # no-op and weights/bias are cast on the fly when comfy requests it.
        class OriginalLinear(torch.nn.Linear, CastWeightBiasOp):
            def reset_parameters(self):
                return None

            def forward_comfy_cast_weights(self, input):
                weight, bias = cast_bias_weight(self, input)
                return torch.nn.functional.linear(input, weight, bias)

            def forward(self, *args, **kwargs):
                if self.comfy_cast_weights:
                    return self.forward_comfy_cast_weights(*args, **kwargs)
                else:
                    return super().forward(*args, **kwargs)

        # Same shell, but the matmul is delegated to torch-cublas-hgemm
        # through the CublasLinear base class.
        class PatchedLinear(CublasLinear, CastWeightBiasOp):
            def reset_parameters(self):
                return None

            def forward_comfy_cast_weights(self, input):
                weight, bias = cast_bias_weight(self, input)
                return torch.nn.functional.linear(input, weight, bias)

            def forward(self, *args, **kwargs):
                if self.comfy_cast_weights:
                    return self.forward_comfy_cast_weights(*args, **kwargs)
                else:
                    return super().forward(*args, **kwargs)

        if enabled:
            disable_weight_init.Linear = PatchedLinear
        else:
            # NOTE(review): this installs a fresh re-implementation rather
            # than restoring the class that was in place before patching; it
            # should match comfy's stock Linear, but verify against comfy.ops.
            disable_weight_init.Linear = OriginalLinear

        return model,

95
web/js/fast_preview.js Normal file
View File

@ -0,0 +1,95 @@
import { app } from '../../../scripts/app.js'
//from melmass
// Build an RFC 4122 style version-4 UUID string, folding the current clock
// value into Math.random for a little extra entropy per nibble.
export function makeUUID() {
  let ticks = new Date().getTime()
  const template = 'xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx'
  return template.replace(/[xy]/g, (ch) => {
    const nibble = ((ticks + Math.random() * 16) % 16) | 0
    ticks = Math.floor(ticks / 16)
    // 'y' positions must encode the UUID variant bits (10xx).
    const value = ch === 'x' ? nibble : (nibble & 0x3) | 0x8
    return value.toString(16)
  })
}
// Append `callback` to run after any handler already stored at
// object[property]. The original handler's return value is preserved;
// the appended callback's return value is discarded.
function chainCallback(object, property, callback) {
  if (object == undefined) {
    // This should not happen — guard so a bad caller fails loudly here
    // instead of deep inside the chained call.
    console.error("Tried to add callback to non-existent object")
    return;
  }
  if (property in object) {
    const callback_orig = object[property]
    object[property] = function () {
      const r = callback_orig.apply(this, arguments);
      callback.apply(this, arguments);
      return r
    };
  } else {
    // No existing handler: install the callback directly.
    object[property] = callback;
  }
}
// Frontend half of the FastPreview node: attaches a DOM widget that shows
// the base64 image the python node returns, with no file written to disk.
app.registerExtension({
  name: 'KJNodes.FastPreview',

  async beforeRegisterNodeDef(nodeType, nodeData) {
    if (nodeData?.name === 'FastPreview') {
      chainCallback(nodeType.prototype, "onNodeCreated", function () {
        var element = document.createElement("div");
        this.uuid = makeUUID()
        element.id = `fast-preview-${this.uuid}`
        // serialize:false — the preview is transient and never saved with
        // the workflow; hideOnZoom:false keeps it visible when zoomed out.
        this.previewWidget = this.addDOMWidget(nodeData.name, "FastPreviewWidget", element, {
          serialize: false,
          hideOnZoom: false,
        });
        this.previewer = new Previewer(this);
        this.setSize([550, 550]);
        this.resizable = false;
        this.previewWidget.parentEl = document.createElement("div");
        this.previewWidget.parentEl.className = "fast-preview";
        // NOTE(review): parentEl gets the same id as `element` above —
        // duplicate DOM ids; likely only one was meant to carry it.
        this.previewWidget.parentEl.id = `fast-preview-${this.uuid}`
        element.appendChild(this.previewWidget.parentEl);
        // After each execution, stash the "bg_image" payload from the
        // python node's ui dict and repaint the widget background.
        chainCallback(this, "onExecuted", function (message) {
          let bg_image = message["bg_image"];
          this.properties.imgData = {
            name: "bg_image",
            base64: bg_image
          };
          this.previewer.refreshBackgroundImage(this);
        });
      }); // onNodeCreated
    } // FastPreview only
  } // beforeRegisterNodeDef
}) // registerExtension
// Paints the base64 payload sent by the FastPreview python node as the
// widget's CSS background image, resizing the node to match the image.
class Previewer {
  constructor(context) {
    this.node = context;
    // Cache the last applied size so setSize is only called on change.
    this.previousWidth = null;
    this.previousHeight = null;
  }

  refreshBackgroundImage = () => {
    const imgData = this.node?.properties?.imgData;
    if (imgData?.base64) {
      // NOTE(review): the python node never sets imgData.type, so the
      // original URL read "data:undefined;base64,...". Fall back to a
      // generic image MIME so the data URL is well-formed; img decoding
      // sniffs the real format from the bytes regardless — confirm.
      const mime = imgData.type || "image/jpeg";
      const imageUrl = `data:${mime};base64,${imgData.base64}`;
      const img = new Image();
      img.src = imageUrl;
      img.onload = () => {
        const { width, height } = img;
        if (width !== this.previousWidth || height !== this.previousHeight) {
          this.node.setSize([width, height]);
          this.previousWidth = width;
          this.previousHeight = height;
        }
        this.node.previewWidget.element.style.backgroundImage = `url(${imageUrl})`;
      };
    }
  };
}