latent preview tests

kijai 2024-11-07 11:26:19 +02:00
parent 0e8f8140e4
commit 666f7832f9
3 changed files with 158 additions and 16 deletions

View File

@@ -894,8 +894,10 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
if tora is not None:
trajectory_length = tora["video_flow_features"].shape[1]
logger.info(f"Tora trajectory length: {trajectory_length}")
logger.info(f"Tora trajectory shape: {tora["video_flow_features"].shape}")
logger.info(f"latents shape: {latents.shape}")
if trajectory_length != latents.shape[1]:
raise ValueError(f"Tora trajectory length {trajectory_length} does not match inpaint_latents count {latents.shape[2]}")
raise ValueError(f"Tora trajectory length {trajectory_length} does not match latent count {latents.shape[2]}")
for module in self.transformer.fuser_list:
for param in module.parameters():
param.data = param.data.to(device)
@@ -903,6 +905,9 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
# 8. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
from ..latent_preview import prepare_callback
callback = prepare_callback(self.transformer, num_inference_steps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
# for DPM-solver++
old_pred_original_sample = None
@@ -1120,21 +1125,13 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
)
latents = latents.to(prompt_embeds.dtype)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if comfyui_progressbar:
if callback is not None:
callback(i, latents.detach()[-1], None, num_inference_steps)
else:
pbar.update(1)
# if output_type == "numpy":
# video = self.decode_latents(latents)
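
For reference, the change above reduces to this flow (a minimal sketch: prepare_callback and the callback signature come from the new latent_preview.py below, while the loop and denoise_step are illustrative stand-ins for the pipeline's denoising loop):

from ..latent_preview import prepare_callback

callback = prepare_callback(self.transformer, num_inference_steps)
for i, t in enumerate(timesteps):
    latents = denoise_step(latents, t)  # hypothetical: one transformer + scheduler step
    # preview the last latent in the batch; the x argument is unused here
    callback(i, latents.detach()[-1], None, num_inference_steps)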

latent_preview.py (new file, 79 lines)
View File

@@ -0,0 +1,79 @@
import io
import torch
from PIL import Image
import struct
import numpy as np
from comfy.cli_args import args, LatentPreviewMethod
from comfy.taesd.taesd import TAESD
import comfy.model_management
import folder_paths
import comfy.utils
import logging
MAX_PREVIEW_RESOLUTION = args.preview_size
def preview_to_image(latent_image):
latents_ubyte = (((latent_image + 1.0) / 2.0).clamp(0, 1) # change scale from -1..1 to 0..1
.mul(0xFF) # to 0..255
).to(device="cpu", dtype=torch.uint8, non_blocking=comfy.model_management.device_supports_non_blocking(latent_image.device))
return Image.fromarray(latents_ubyte.numpy())
class LatentPreviewer:
def decode_latent_to_preview(self, x0):
pass
def decode_latent_to_preview_image(self, preview_format, x0):
preview_image = self.decode_latent_to_preview(x0)
return ("GIF", preview_image, MAX_PREVIEW_RESOLUTION)
class Latent2RGBPreviewer(LatentPreviewer):
def __init__(self):
latent_rgb_factors = [[0.11945946736445662, 0.09919175788574555, -0.004832707433877734], [-0.0011977028264356232, 0.05496505130267682, 0.021321622433638193], [-0.014088548986590666, -0.008701477861945644, -0.020991313281459367], [0.03063921972519621, 0.12186477097625073, 0.0139593690235148], [0.0927403067854673, 0.030293187650929136, 0.05083134241694003], [0.0379112441305742, 0.04935199882777209, 0.058562766246777774], [0.017749911959153715, 0.008839453404921545, 0.036005638019226294], [0.10610119248526109, 0.02339855688237826, 0.057154257614084596], [0.1273639464837117, -0.010959856130713416, 0.043268631260428896], [-0.01873510946881321, 0.08220930648486932, 0.10613256772247093], [0.008429116376722327, 0.07623856561000408, 0.09295712117576727], [0.12938137079617007, 0.12360403483892413, 0.04478930933220116], [0.04565908794779364, 0.041064156741596365, -0.017695041535528512], [0.00019003240570281826, -0.013965147883381978, 0.05329669529635849], [0.08082391586738358, 0.11548306825496074, -0.021464170006615893], [-0.01517932393230994, -0.0057985555313003236, 0.07216646476618871]]
self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu").transpose(0, 1)
self.latent_rgb_factors_bias = None
# if latent_rgb_factors_bias is not None:
# self.latent_rgb_factors_bias = torch.tensor(latent_rgb_factors_bias, device="cpu")
def decode_latent_to_preview(self, x0):
self.latent_rgb_factors = self.latent_rgb_factors.to(dtype=x0.dtype, device=x0.device)
if self.latent_rgb_factors_bias is not None:
self.latent_rgb_factors_bias = self.latent_rgb_factors_bias.to(dtype=x0.dtype, device=x0.device)
latent_image = torch.nn.functional.linear(x0[0].permute(1, 2, 0), self.latent_rgb_factors,
bias=self.latent_rgb_factors_bias)
return preview_to_image(latent_image)
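# Shape walk-through for decode_latent_to_preview above (a sketch; assumes
# x0[0] is a single [C=16, H, W] latent, which is what the sampler passes):
#   x0[0].permute(1, 2, 0)  -> [H, W, 16]
#   self.latent_rgb_factors -> [3, 16] after the transpose in __init__
#   F.linear(input, weight) computes input @ weight.T -> [H, W, 3] RGB pixels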
def get_previewer():
previewer = None
method = args.preview_method
if method != LatentPreviewMethod.NoPreviews:
# TODO previewer method
if method == LatentPreviewMethod.Auto:
method = LatentPreviewMethod.Latent2RGB
if previewer is None:
previewer = Latent2RGBPreviewer()
return previewer
def prepare_callback(model, steps, x0_output_dict=None):
preview_format = "JPEG"
if preview_format not in ["JPEG", "PNG"]:
preview_format = "JPEG"
previewer = get_previewer()
pbar = comfy.utils.ProgressBar(steps)
def callback(step, x0, x, total_steps):
if x0_output_dict is not None:
x0_output_dict["x0"] = x0
preview_bytes = None
if previewer:
preview_bytes = previewer.decode_latent_to_preview_image(preview_format, x0)
pbar.update_absolute(step + 1, total_steps, preview_bytes)
return callback
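
The optional x0_output_dict argument doubles as a hook for grabbing the last predicted latent. A usage sketch (caller-side names are illustrative, not from this repo):

x0_out = {}
callback = prepare_callback(transformer, steps, x0_output_dict=x0_out)
# ... the sampler calls callback(step, x0, x, total_steps) each step ...
final_x0 = x0_out.get("x0")  # most recent x0 handed to the callback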

View File

@@ -827,7 +827,7 @@ class CogVideoDecode:
)
else:
vae.disable_tiling()
latents = latents.to(vae.dtype).to(device)
latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
latents = 1 / vae.config.scaling_factor * latents
try:
@@ -1296,6 +1296,70 @@ class CogVideoXFunControlSampler:
return (pipeline, {"samples": latents})
class CogVideoLatentPreview:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"samples": ("LATENT",),
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
"min_val": ("FLOAT", {"default": -0.15, "min": -1.0, "max": 0.0, "step": 0.001}),
"max_val": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001}),
"r_bias": ("FLOAT", {"default": 0.0, "min": -1.0, "max": 1.0, "step": 0.001}),
"g_bias": ("FLOAT", {"default": 0.0, "min": -1.0, "max": 1.0, "step": 0.001}),
"b_bias": ("FLOAT", {"default": 0.0, "min": -1.0, "max": 1.0, "step": 0.001}),
},
}
RETURN_TYPES = ("IMAGE", "STRING", )
RETURN_NAMES = ("images", "latent_rgb_factors",)
FUNCTION = "sample"
CATEGORY = "PyramidFlowWrapper"
def sample(self, samples, seed, min_val, max_val, r_bias, g_bias, b_bias):
mm.soft_empty_cache()
latents = samples["samples"].clone()
print("in sample", latents.shape)
latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
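# Note: the commented-out and hard-coded factor sets below are reference
# values only; the node overrides them with seeded random factors a few
# lines further down.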
#[[0.0658900170023352, 0.04687556512203313, -0.056971557475649186], [-0.01265770449940036, -0.02814809569100843, -0.0768912512529372], [0.061456544746314665, 0.0005511617552452358, -0.0652574975291287], [-0.09020669168815276, -0.004755440180558637, -0.023763970904494294], [0.031766964513999865, -0.030959599938418375, 0.08654669098083616], [-0.005981764690055846, -0.08809119252349802, -0.06439852368217663], [-0.0212114426433989, 0.08894281999597677, 0.05155629477559985], [-0.013947446911030725, -0.08987475069900677, -0.08923124751217484], [-0.08235967967978511, 0.07268025379974379, 0.08830486164536037], [-0.08052049179735378, -0.050116143175332195, 0.02023752569687405], [-0.07607527759162447, 0.06827156419895981, 0.08678111754261035], [-0.04689089232553825, 0.017294986041038893, -0.10280492336438908], [-0.06105783150270304, 0.07311850680875913, 0.019995735372550075], [-0.09232589996527711, -0.012869815059053047, -0.04355587834255975], [-0.06679931010802251, 0.018399815879067458, 0.06802404982033876], [-0.013062632927118165, -0.04292991477896661, 0.07476243356192845]]
latent_rgb_factors =[[0.11945946736445662, 0.09919175788574555, -0.004832707433877734], [-0.0011977028264356232, 0.05496505130267682, 0.021321622433638193], [-0.014088548986590666, -0.008701477861945644, -0.020991313281459367], [0.03063921972519621, 0.12186477097625073, 0.0139593690235148], [0.0927403067854673, 0.030293187650929136, 0.05083134241694003], [0.0379112441305742, 0.04935199882777209, 0.058562766246777774], [0.017749911959153715, 0.008839453404921545, 0.036005638019226294], [0.10610119248526109, 0.02339855688237826, 0.057154257614084596], [0.1273639464837117, -0.010959856130713416, 0.043268631260428896], [-0.01873510946881321, 0.08220930648486932, 0.10613256772247093], [0.008429116376722327, 0.07623856561000408, 0.09295712117576727], [0.12938137079617007, 0.12360403483892413, 0.04478930933220116], [0.04565908794779364, 0.041064156741596365, -0.017695041535528512], [0.00019003240570281826, -0.013965147883381978, 0.05329669529635849], [0.08082391586738358, 0.11548306825496074, -0.021464170006615893], [-0.01517932393230994, -0.0057985555313003236, 0.07216646476618871]]
import random
random.seed(seed)
latent_rgb_factors = [[random.uniform(min_val, max_val) for _ in range(3)] for _ in range(16)]
out_factors = latent_rgb_factors
print(latent_rgb_factors)
latent_rgb_factors_bias = [0.085, 0.137, 0.158]
#latent_rgb_factors_bias = [r_bias, g_bias, b_bias]
latent_rgb_factors = torch.tensor(latent_rgb_factors, device=latents.device, dtype=latents.dtype).transpose(0, 1)
latent_rgb_factors_bias = torch.tensor(latent_rgb_factors_bias, device=latents.device, dtype=latents.dtype)
print("latent_rgb_factors", latent_rgb_factors.shape)
latent_images = []
for t in range(latents.shape[2]):
latent = latents[:, :, t, :, :]
latent = latent[0].permute(1, 2, 0)
latent_image = torch.nn.functional.linear(
latent,
latent_rgb_factors,
bias=latent_rgb_factors_bias
)
latent_images.append(latent_image)
latent_images = torch.stack(latent_images, dim=0)
print("latent_images", latent_images.shape)
latent_images_min = latent_images.min()
latent_images_max = latent_images.max()
latent_images = (latent_images - latent_images_min) / (latent_images_max - latent_images_min)
return (latent_images.float().cpu(), str(out_factors),)
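
The latent_rgb_factors string output exists so a random set that happens to produce a readable preview can be copied back into the previewer. A sketch of that round trip (the seed and values are illustrative; the target is Latent2RGBPreviewer.__init__ in latent_preview.py):

import random
import torch

random.seed(42)  # a seed that looked good in the preview node (hypothetical)
found = [[random.uniform(-0.15, 0.15) for _ in range(3)] for _ in range(16)]
# paste `found` over the hard-coded latent_rgb_factors, then, as in __init__:
factors = torch.tensor(found, device="cpu").transpose(0, 1)  # -> [3, 16]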
NODE_CLASS_MAPPINGS = {
"CogVideoSampler": CogVideoSampler,
"CogVideoDecode": CogVideoDecode,
@@ -1316,7 +1380,8 @@ NODE_CLASS_MAPPINGS = {
"ToraEncodeTrajectory": ToraEncodeTrajectory,
"ToraEncodeOpticalFlow": ToraEncodeOpticalFlow,
"CogVideoXFasterCache": CogVideoXFasterCache,
"CogVideoXFunResizeToClosestBucket": CogVideoXFunResizeToClosestBucket
"CogVideoXFunResizeToClosestBucket": CogVideoXFunResizeToClosestBucket,
"CogVideoLatentPreview": CogVideoLatentPreview
}
NODE_DISPLAY_NAME_MAPPINGS = {
"CogVideoSampler": "CogVideo Sampler",
@@ -1337,5 +1402,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"ToraEncodeTrajectory": "Tora Encode Trajectory",
"ToraEncodeOpticalFlow": "Tora Encode OpticalFlow",
"CogVideoXFasterCache": "CogVideoX FasterCache",
"CogVideoXFunResizeToClosestBucket": "CogVideoXFun ResizeToClosestBucket"
"CogVideoXFunResizeToClosestBucket": "CogVideoXFun ResizeToClosestBucket",
"CogVideoLatentPreview": "CogVideo LatentPreview"
}