diff --git a/cogvideox_fun/pipeline_cogvideox_inpaint.py b/cogvideox_fun/pipeline_cogvideox_inpaint.py
index f6e1241..9bd8813 100644
--- a/cogvideox_fun/pipeline_cogvideox_inpaint.py
+++ b/cogvideox_fun/pipeline_cogvideox_inpaint.py
@@ -894,8 +894,10 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
         if tora is not None:
             trajectory_length = tora["video_flow_features"].shape[1]
             logger.info(f"Tora trajectory length: {trajectory_length}")
+            logger.info(f"Tora trajectory shape: {tora['video_flow_features'].shape}")
+            logger.info(f"latents shape: {latents.shape}")
             if trajectory_length != latents.shape[1]:
-                raise ValueError(f"Tora trajectory length {trajectory_length} does not match inpaint_latents count {latents.shape[2]}")
+                raise ValueError(f"Tora trajectory length {trajectory_length} does not match latent count {latents.shape[1]}")
             for module in self.transformer.fuser_list:
                 for param in module.parameters():
                     param.data = param.data.to(device)
@@ -903,6 +905,9 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
         # 8. Denoising loop
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
 
+        from ..latent_preview import prepare_callback
+        callback = prepare_callback(self.transformer, num_inference_steps)
+
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             # for DPM-solver++
             old_pred_original_sample = None
@@ -1120,21 +1125,13 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
                 )
                 latents = latents.to(prompt_embeds.dtype)
 
-                # call the callback, if provided
-                if callback_on_step_end is not None:
-                    callback_kwargs = {}
-                    for k in callback_on_step_end_tensor_inputs:
-                        callback_kwargs[k] = locals()[k]
-                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-
-                    latents = callback_outputs.pop("latents", latents)
-                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
-                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
-
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
                 if comfyui_progressbar:
-                    pbar.update(1)
+                    if callback is not None:
+                        callback(i, latents.detach()[-1], None, num_inference_steps)
+                    else:
+                        pbar.update(1)
 
         # if output_type == "numpy":
         #     video = self.decode_latents(latents)
diff --git a/latent_preview.py b/latent_preview.py
new file mode 100644
index 0000000..5ed78d6
--- /dev/null
+++ b/latent_preview.py
@@ -0,0 +1,79 @@
+import io
+
+import torch
+from PIL import Image
+import struct
+import numpy as np
+from comfy.cli_args import args, LatentPreviewMethod
+from comfy.taesd.taesd import TAESD
+import comfy.model_management
+import folder_paths
+import comfy.utils
+import logging
+
+MAX_PREVIEW_RESOLUTION = args.preview_size
+
+def preview_to_image(latent_image):
+    latents_ubyte = (((latent_image + 1.0) / 2.0).clamp(0, 1)  # change scale from -1..1 to 0..1
+                     .mul(0xFF)  # to 0..255
+                     ).to(device="cpu", dtype=torch.uint8, non_blocking=comfy.model_management.device_supports_non_blocking(latent_image.device))
+
+    return Image.fromarray(latents_ubyte.numpy())
+
+class LatentPreviewer:
+    def decode_latent_to_preview(self, x0):
+        pass
+
+    def decode_latent_to_preview_image(self, preview_format, x0):
+        preview_image = self.decode_latent_to_preview(x0)
+        return ("GIF", preview_image, MAX_PREVIEW_RESOLUTION)
+
+class Latent2RGBPreviewer(LatentPreviewer):
+    def __init__(self):
+        latent_rgb_factors = [[0.11945946736445662, 0.09919175788574555, -0.004832707433877734], [-0.0011977028264356232, 0.05496505130267682, 0.021321622433638193], [-0.014088548986590666, -0.008701477861945644, -0.020991313281459367], [0.03063921972519621, 0.12186477097625073, 0.0139593690235148], [0.0927403067854673, 0.030293187650929136, 0.05083134241694003], [0.0379112441305742, 0.04935199882777209, 0.058562766246777774], [0.017749911959153715, 0.008839453404921545, 0.036005638019226294], [0.10610119248526109, 0.02339855688237826, 0.057154257614084596], [0.1273639464837117, -0.010959856130713416, 0.043268631260428896], [-0.01873510946881321, 0.08220930648486932, 0.10613256772247093], [0.008429116376722327, 0.07623856561000408, 0.09295712117576727], [0.12938137079617007, 0.12360403483892413, 0.04478930933220116], [0.04565908794779364, 0.041064156741596365, -0.017695041535528512], [0.00019003240570281826, -0.013965147883381978, 0.05329669529635849], [0.08082391586738358, 0.11548306825496074, -0.021464170006615893], [-0.01517932393230994, -0.0057985555313003236, 0.07216646476618871]]
+
+        self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu").transpose(0, 1)
+        self.latent_rgb_factors_bias = None
+        # if latent_rgb_factors_bias is not None:
+        #     self.latent_rgb_factors_bias = torch.tensor(latent_rgb_factors_bias, device="cpu")
+
+    def decode_latent_to_preview(self, x0):
+        self.latent_rgb_factors = self.latent_rgb_factors.to(dtype=x0.dtype, device=x0.device)
+        if self.latent_rgb_factors_bias is not None:
+            self.latent_rgb_factors_bias = self.latent_rgb_factors_bias.to(dtype=x0.dtype, device=x0.device)
+
+        latent_image = torch.nn.functional.linear(x0[0].permute(1, 2, 0), self.latent_rgb_factors,
+                                                  bias=self.latent_rgb_factors_bias)
+        return preview_to_image(latent_image)
+
+
+def get_previewer():
+    previewer = None
+    method = args.preview_method
+    if method != LatentPreviewMethod.NoPreviews:
+        # TODO previewer method
+
+        if method == LatentPreviewMethod.Auto:
+            method = LatentPreviewMethod.Latent2RGB
+
+        if previewer is None:
+            previewer = Latent2RGBPreviewer()
+    return previewer
+
+def prepare_callback(model, steps, x0_output_dict=None):
+    preview_format = "JPEG"
+    if preview_format not in ["JPEG", "PNG"]:
+        preview_format = "JPEG"
+
+    previewer = get_previewer()
+
+    pbar = comfy.utils.ProgressBar(steps)
+    def callback(step, x0, x, total_steps):
+        if x0_output_dict is not None:
+            x0_output_dict["x0"] = x0
+        preview_bytes = None
+        if previewer:
+            preview_bytes = previewer.decode_latent_to_preview_image(preview_format, x0)
+        pbar.update_absolute(step + 1, total_steps, preview_bytes)
+    return callback
+
diff --git a/nodes.py b/nodes.py
index 5eb5bdf..a742dc5 100644
--- a/nodes.py
+++ b/nodes.py
@@ -827,7 +827,7 @@ class CogVideoDecode:
             )
         else:
             vae.disable_tiling()
-        latents = latents.to(vae.dtype)
+        latents = latents.to(vae.dtype).to(device)
         latents = latents.permute(0, 2, 1, 3, 4)  # [batch_size, num_channels, num_frames, height, width]
         latents = 1 / vae.config.scaling_factor * latents
         try:
@@ -1296,6 +1296,70 @@ class CogVideoXFunControlSampler:
 
         return (pipeline, {"samples": latents})
 
+class CogVideoLatentPreview:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "samples": ("LATENT",),
+                "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
+                "min_val": ("FLOAT", {"default": -0.15, "min": -1.0, "max": 0.0, "step": 0.001}),
+                "max_val": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001}),
+                "r_bias": ("FLOAT", {"default": 0.0, "min": -1.0, "max": 1.0, "step": 0.001}),
+                "g_bias": ("FLOAT", {"default": 0.0, "min": -1.0, "max": 1.0, "step": 0.001}),
"max": 1.0, "step": 0.001}), + "b_bias": ("FLOAT", {"default": 0.0, "min": -1.0, "max": 1.0, "step": 0.001}), + }, + } + + RETURN_TYPES = ("IMAGE", "STRING", ) + RETURN_NAMES = ("images", "latent_rgb_factors",) + FUNCTION = "sample" + CATEGORY = "PyramidFlowWrapper" + + def sample(self, samples, seed, min_val, max_val, r_bias, g_bias, b_bias): + mm.soft_empty_cache() + + latents = samples["samples"].clone() + print("in sample", latents.shape) + latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width] + + device = mm.get_torch_device() + offload_device = mm.unet_offload_device() + + #[[0.0658900170023352, 0.04687556512203313, -0.056971557475649186], [-0.01265770449940036, -0.02814809569100843, -0.0768912512529372], [0.061456544746314665, 0.0005511617552452358, -0.0652574975291287], [-0.09020669168815276, -0.004755440180558637, -0.023763970904494294], [0.031766964513999865, -0.030959599938418375, 0.08654669098083616], [-0.005981764690055846, -0.08809119252349802, -0.06439852368217663], [-0.0212114426433989, 0.08894281999597677, 0.05155629477559985], [-0.013947446911030725, -0.08987475069900677, -0.08923124751217484], [-0.08235967967978511, 0.07268025379974379, 0.08830486164536037], [-0.08052049179735378, -0.050116143175332195, 0.02023752569687405], [-0.07607527759162447, 0.06827156419895981, 0.08678111754261035], [-0.04689089232553825, 0.017294986041038893, -0.10280492336438908], [-0.06105783150270304, 0.07311850680875913, 0.019995735372550075], [-0.09232589996527711, -0.012869815059053047, -0.04355587834255975], [-0.06679931010802251, 0.018399815879067458, 0.06802404982033876], [-0.013062632927118165, -0.04292991477896661, 0.07476243356192845]] + latent_rgb_factors =[[0.11945946736445662, 0.09919175788574555, -0.004832707433877734], [-0.0011977028264356232, 0.05496505130267682, 0.021321622433638193], [-0.014088548986590666, -0.008701477861945644, -0.020991313281459367], [0.03063921972519621, 0.12186477097625073, 0.0139593690235148], [0.0927403067854673, 0.030293187650929136, 0.05083134241694003], [0.0379112441305742, 0.04935199882777209, 0.058562766246777774], [0.017749911959153715, 0.008839453404921545, 0.036005638019226294], [0.10610119248526109, 0.02339855688237826, 0.057154257614084596], [0.1273639464837117, -0.010959856130713416, 0.043268631260428896], [-0.01873510946881321, 0.08220930648486932, 0.10613256772247093], [0.008429116376722327, 0.07623856561000408, 0.09295712117576727], [0.12938137079617007, 0.12360403483892413, 0.04478930933220116], [0.04565908794779364, 0.041064156741596365, -0.017695041535528512], [0.00019003240570281826, -0.013965147883381978, 0.05329669529635849], [0.08082391586738358, 0.11548306825496074, -0.021464170006615893], [-0.01517932393230994, -0.0057985555313003236, 0.07216646476618871]] + import random + random.seed(seed) + latent_rgb_factors = [[random.uniform(min_val, max_val) for _ in range(3)] for _ in range(16)] + out_factors = latent_rgb_factors + print(latent_rgb_factors) + + latent_rgb_factors_bias = [0.085, 0.137, 0.158] + #latent_rgb_factors_bias = [r_bias, g_bias, b_bias] + + latent_rgb_factors = torch.tensor(latent_rgb_factors, device=latents.device, dtype=latents.dtype).transpose(0, 1) + latent_rgb_factors_bias = torch.tensor(latent_rgb_factors_bias, device=latents.device, dtype=latents.dtype) + + print("latent_rgb_factors", latent_rgb_factors.shape) + + latent_images = [] + for t in range(latents.shape[2]): + latent = latents[:, :, t, :, :] + latent = latent[0].permute(1, 2, 0) + latent_image = 
+            latent_image = torch.nn.functional.linear(
+                latent,
+                latent_rgb_factors,
+                bias=latent_rgb_factors_bias
+            )
+            latent_images.append(latent_image)
+        latent_images = torch.stack(latent_images, dim=0)
+        print("latent_images", latent_images.shape)
+        latent_images_min = latent_images.min()
+        latent_images_max = latent_images.max()
+        latent_images = (latent_images - latent_images_min) / (latent_images_max - latent_images_min)
+
+        return (latent_images.float().cpu(), out_factors)
+
 NODE_CLASS_MAPPINGS = {
     "CogVideoSampler": CogVideoSampler,
     "CogVideoDecode": CogVideoDecode,
@@ -1316,7 +1380,8 @@ NODE_CLASS_MAPPINGS = {
     "ToraEncodeTrajectory": ToraEncodeTrajectory,
     "ToraEncodeOpticalFlow": ToraEncodeOpticalFlow,
     "CogVideoXFasterCache": CogVideoXFasterCache,
-    "CogVideoXFunResizeToClosestBucket": CogVideoXFunResizeToClosestBucket
+    "CogVideoXFunResizeToClosestBucket": CogVideoXFunResizeToClosestBucket,
+    "CogVideoLatentPreview": CogVideoLatentPreview
 }
 NODE_DISPLAY_NAME_MAPPINGS = {
     "CogVideoSampler": "CogVideo Sampler",
@@ -1337,5 +1402,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
     "ToraEncodeTrajectory": "Tora Encode Trajectory",
     "ToraEncodeOpticalFlow": "Tora Encode OpticalFlow",
     "CogVideoXFasterCache": "CogVideoX FasterCache",
-    "CogVideoXFunResizeToClosestBucket": "CogVideoXFun ResizeToClosestBucket"
+    "CogVideoXFunResizeToClosestBucket": "CogVideoXFun ResizeToClosestBucket",
+    "CogVideoLatentPreview": "CogVideo LatentPreview"
 }
\ No newline at end of file