Mirror of https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git (synced 2025-12-09 04:44:22 +08:00)

Commit 666f7832f9 ("latent preview tests"), parent 0e8f8140e4
@@ -894,8 +894,10 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
         if tora is not None:
             trajectory_length = tora["video_flow_features"].shape[1]
             logger.info(f"Tora trajectory length: {trajectory_length}")
+            logger.info(f"Tora trajectory shape: {tora['video_flow_features'].shape}")
+            logger.info(f"latents shape: {latents.shape}")
             if trajectory_length != latents.shape[1]:
-                raise ValueError(f"Tora trajectory length {trajectory_length} does not match inpaint_latents count {latents.shape[2]}")
+                raise ValueError(f"Tora trajectory length {trajectory_length} does not match latent count {latents.shape[1]}")
             for module in self.transformer.fuser_list:
                 for param in module.parameters():
                     param.data = param.data.to(device)
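(Aside, not part of the commit: the guard above compares frame counts on axis 1. A minimal sketch of the check, assuming the pipeline's [batch, frames, channels, height, width] latent layout; all sizes are invented.)

import torch

latents = torch.randn(1, 13, 16, 60, 90)               # assumed [B, F, C, H, W]
video_flow_features = torch.randn(1, 13, 16, 60, 90)   # Tora flow features, same layout
trajectory_length = video_flow_features.shape[1]       # frames in the trajectory
if trajectory_length != latents.shape[1]:              # frame axes must agree
    raise ValueError(f"Tora trajectory length {trajectory_length} does not match latent count {latents.shape[1]}")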
@@ -903,6 +905,9 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
         # 8. Denoising loop
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

+        from ..latent_preview import prepare_callback
+        callback = prepare_callback(self.transformer, num_inference_steps)
+
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             # for DPM-solver++
             old_pred_original_sample = None
@@ -1120,21 +1125,13 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
                 )
                 latents = latents.to(prompt_embeds.dtype)

-                # call the callback, if provided
-                if callback_on_step_end is not None:
-                    callback_kwargs = {}
-                    for k in callback_on_step_end_tensor_inputs:
-                        callback_kwargs[k] = locals()[k]
-                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-
-                    latents = callback_outputs.pop("latents", latents)
-                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
-                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
-
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
                     if comfyui_progressbar:
-                        pbar.update(1)
+                        if callback is not None:
+                            callback(i, latents.detach()[-1], None, num_inference_steps)
+                        else:
+                            pbar.update(1)

                 # if output_type == "numpy":
                 #     video = self.decode_latents(latents)
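(Aside, not part of the commit: each step hands the previewer the last batch element of the latents. A shape sketch under the same assumed [B, F, C, H, W] layout, sizes invented.)

import torch

latents = torch.randn(2, 13, 16, 60, 90)   # assumed [B, F, C, H, W]
x0 = latents.detach()[-1]                  # last batch element -> [F, C, H, W], what callback() receives
frame = x0[0].permute(1, 2, 0)             # first frame as [H, W, C], the layout the RGB projection expects
print(frame.shape)                         # torch.Size([60, 90, 16])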
latent_preview.py (new file, 79 lines)
@@ -0,0 +1,79 @@
import io
import torch
from PIL import Image
import struct
import numpy as np
from comfy.cli_args import args, LatentPreviewMethod
from comfy.taesd.taesd import TAESD
import comfy.model_management
import folder_paths
import comfy.utils
import logging

MAX_PREVIEW_RESOLUTION = args.preview_size

def preview_to_image(latent_image):
    latents_ubyte = (((latent_image + 1.0) / 2.0).clamp(0, 1)  # change scale from -1..1 to 0..1
                     .mul(0xFF)  # to 0..255
                     ).to(device="cpu", dtype=torch.uint8,
                          non_blocking=comfy.model_management.device_supports_non_blocking(latent_image.device))

    return Image.fromarray(latents_ubyte.numpy())

class LatentPreviewer:
    def decode_latent_to_preview(self, x0):
        pass

    def decode_latent_to_preview_image(self, preview_format, x0):
        preview_image = self.decode_latent_to_preview(x0)
        return ("GIF", preview_image, MAX_PREVIEW_RESOLUTION)

class Latent2RGBPreviewer(LatentPreviewer):
    def __init__(self):
        latent_rgb_factors = [[0.11945946736445662, 0.09919175788574555, -0.004832707433877734], [-0.0011977028264356232, 0.05496505130267682, 0.021321622433638193], [-0.014088548986590666, -0.008701477861945644, -0.020991313281459367], [0.03063921972519621, 0.12186477097625073, 0.0139593690235148], [0.0927403067854673, 0.030293187650929136, 0.05083134241694003], [0.0379112441305742, 0.04935199882777209, 0.058562766246777774], [0.017749911959153715, 0.008839453404921545, 0.036005638019226294], [0.10610119248526109, 0.02339855688237826, 0.057154257614084596], [0.1273639464837117, -0.010959856130713416, 0.043268631260428896], [-0.01873510946881321, 0.08220930648486932, 0.10613256772247093], [0.008429116376722327, 0.07623856561000408, 0.09295712117576727], [0.12938137079617007, 0.12360403483892413, 0.04478930933220116], [0.04565908794779364, 0.041064156741596365, -0.017695041535528512], [0.00019003240570281826, -0.013965147883381978, 0.05329669529635849], [0.08082391586738358, 0.11548306825496074, -0.021464170006615893], [-0.01517932393230994, -0.0057985555313003236, 0.07216646476618871]]

        self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu").transpose(0, 1)
        self.latent_rgb_factors_bias = None
        # if latent_rgb_factors_bias is not None:
        #     self.latent_rgb_factors_bias = torch.tensor(latent_rgb_factors_bias, device="cpu")

    def decode_latent_to_preview(self, x0):
        self.latent_rgb_factors = self.latent_rgb_factors.to(dtype=x0.dtype, device=x0.device)
        if self.latent_rgb_factors_bias is not None:
            self.latent_rgb_factors_bias = self.latent_rgb_factors_bias.to(dtype=x0.dtype, device=x0.device)

        latent_image = torch.nn.functional.linear(x0[0].permute(1, 2, 0), self.latent_rgb_factors,
                                                  bias=self.latent_rgb_factors_bias)
        return preview_to_image(latent_image)


def get_previewer():
    previewer = None
    method = args.preview_method
    if method != LatentPreviewMethod.NoPreviews:
        # TODO previewer method
        if method == LatentPreviewMethod.Auto:
            method = LatentPreviewMethod.Latent2RGB

        if previewer is None:
            previewer = Latent2RGBPreviewer()
    return previewer

def prepare_callback(model, steps, x0_output_dict=None):
    preview_format = "JPEG"
    if preview_format not in ["JPEG", "PNG"]:
        preview_format = "JPEG"

    previewer = get_previewer()

    pbar = comfy.utils.ProgressBar(steps)
    def callback(step, x0, x, total_steps):
        if x0_output_dict is not None:
            x0_output_dict["x0"] = x0
        preview_bytes = None
        if previewer:
            preview_bytes = previewer.decode_latent_to_preview_image(preview_format, x0)
        pbar.update_absolute(step + 1, total_steps, preview_bytes)
    return callback
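(A minimal usage sketch, not part of the commit; it assumes a running ComfyUI environment, since importing this module pulls in comfy.cli_args and comfy.utils. The import path, tensor sizes, and file name are invented.)

import torch
from latent_preview import Latent2RGBPreviewer  # hypothetical import path

previewer = Latent2RGBPreviewer()
x0 = torch.randn(13, 16, 60, 90)    # fake [F, C, H, W] latent; only frame 0 is previewed
fmt, image, max_res = previewer.decode_latent_to_preview_image("JPEG", x0)
image.save("preview_frame0.png")    # a PIL.Image built from the 16 -> 3 channel projection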
nodes.py (72 changed lines)
@@ -827,7 +827,7 @@ class CogVideoDecode:
             )
         else:
             vae.disable_tiling()
-        latents = latents.to(vae.dtype)
+        latents = latents.to(vae.dtype).to(device)
         latents = latents.permute(0, 2, 1, 3, 4)  # [batch_size, num_channels, num_frames, height, width]
         latents = 1 / vae.config.scaling_factor * latents
         try:
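(Shape aside, not part of the commit: the permute swaps the frame and channel axes so the 3D VAE receives [batch, channels, frames, height, width]. Sizes invented.)

import torch

latents = torch.randn(1, 13, 16, 60, 90)    # [B, F, C, H, W] as stored by the sampler
latents = latents.permute(0, 2, 1, 3, 4)    # -> [B, C, F, H, W] for vae.decode
print(latents.shape)                        # torch.Size([1, 16, 13, 60, 90])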
@@ -1296,6 +1296,70 @@ class CogVideoXFunControlSampler:

         return (pipeline, {"samples": latents})

+class CogVideoLatentPreview:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "samples": ("LATENT",),
+                "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
+                "min_val": ("FLOAT", {"default": -0.15, "min": -1.0, "max": 0.0, "step": 0.001}),
+                "max_val": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001}),
+                "r_bias": ("FLOAT", {"default": 0.0, "min": -1.0, "max": 1.0, "step": 0.001}),
+                "g_bias": ("FLOAT", {"default": 0.0, "min": -1.0, "max": 1.0, "step": 0.001}),
+                "b_bias": ("FLOAT", {"default": 0.0, "min": -1.0, "max": 1.0, "step": 0.001}),
+            },
+        }
+
+    RETURN_TYPES = ("IMAGE", "STRING",)
+    RETURN_NAMES = ("images", "latent_rgb_factors",)
+    FUNCTION = "sample"
+    CATEGORY = "CogVideoWrapper"
+
+    def sample(self, samples, seed, min_val, max_val, r_bias, g_bias, b_bias):
+        mm.soft_empty_cache()
+
+        latents = samples["samples"].clone()
+        print("in sample", latents.shape)
+        latents = latents.permute(0, 2, 1, 3, 4)  # [batch_size, num_channels, num_frames, height, width]
+
+        device = mm.get_torch_device()
+        offload_device = mm.unet_offload_device()
+
+        #[[0.0658900170023352, 0.04687556512203313, -0.056971557475649186], [-0.01265770449940036, -0.02814809569100843, -0.0768912512529372], [0.061456544746314665, 0.0005511617552452358, -0.0652574975291287], [-0.09020669168815276, -0.004755440180558637, -0.023763970904494294], [0.031766964513999865, -0.030959599938418375, 0.08654669098083616], [-0.005981764690055846, -0.08809119252349802, -0.06439852368217663], [-0.0212114426433989, 0.08894281999597677, 0.05155629477559985], [-0.013947446911030725, -0.08987475069900677, -0.08923124751217484], [-0.08235967967978511, 0.07268025379974379, 0.08830486164536037], [-0.08052049179735378, -0.050116143175332195, 0.02023752569687405], [-0.07607527759162447, 0.06827156419895981, 0.08678111754261035], [-0.04689089232553825, 0.017294986041038893, -0.10280492336438908], [-0.06105783150270304, 0.07311850680875913, 0.019995735372550075], [-0.09232589996527711, -0.012869815059053047, -0.04355587834255975], [-0.06679931010802251, 0.018399815879067458, 0.06802404982033876], [-0.013062632927118165, -0.04292991477896661, 0.07476243356192845]]
+        latent_rgb_factors = [[0.11945946736445662, 0.09919175788574555, -0.004832707433877734], [-0.0011977028264356232, 0.05496505130267682, 0.021321622433638193], [-0.014088548986590666, -0.008701477861945644, -0.020991313281459367], [0.03063921972519621, 0.12186477097625073, 0.0139593690235148], [0.0927403067854673, 0.030293187650929136, 0.05083134241694003], [0.0379112441305742, 0.04935199882777209, 0.058562766246777774], [0.017749911959153715, 0.008839453404921545, 0.036005638019226294], [0.10610119248526109, 0.02339855688237826, 0.057154257614084596], [0.1273639464837117, -0.010959856130713416, 0.043268631260428896], [-0.01873510946881321, 0.08220930648486932, 0.10613256772247093], [0.008429116376722327, 0.07623856561000408, 0.09295712117576727], [0.12938137079617007, 0.12360403483892413, 0.04478930933220116], [0.04565908794779364, 0.041064156741596365, -0.017695041535528512], [0.00019003240570281826, -0.013965147883381978, 0.05329669529635849], [0.08082391586738358, 0.11548306825496074, -0.021464170006615893], [-0.01517932393230994, -0.0057985555313003236, 0.07216646476618871]]
+        import random
+        random.seed(seed)
+        latent_rgb_factors = [[random.uniform(min_val, max_val) for _ in range(3)] for _ in range(16)]
+        out_factors = latent_rgb_factors
+        print(latent_rgb_factors)
+
+        latent_rgb_factors_bias = [0.085, 0.137, 0.158]
+        #latent_rgb_factors_bias = [r_bias, g_bias, b_bias]
+
+        latent_rgb_factors = torch.tensor(latent_rgb_factors, device=latents.device, dtype=latents.dtype).transpose(0, 1)
+        latent_rgb_factors_bias = torch.tensor(latent_rgb_factors_bias, device=latents.device, dtype=latents.dtype)
+
+        print("latent_rgb_factors", latent_rgb_factors.shape)
+
+        latent_images = []
+        for t in range(latents.shape[2]):
+            latent = latents[:, :, t, :, :]
+            latent = latent[0].permute(1, 2, 0)
+            latent_image = torch.nn.functional.linear(
+                latent,
+                latent_rgb_factors,
+                bias=latent_rgb_factors_bias
+            )
+            latent_images.append(latent_image)
+        latent_images = torch.stack(latent_images, dim=0)
+        print("latent_images", latent_images.shape)
+        latent_images_min = latent_images.min()
+        latent_images_max = latent_images.max()
+        latent_images = (latent_images - latent_images_min) / (latent_images_max - latent_images_min)
+
+        return (latent_images.float().cpu(), out_factors)
+
 NODE_CLASS_MAPPINGS = {
     "CogVideoSampler": CogVideoSampler,
     "CogVideoDecode": CogVideoDecode,
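(A hypothetical direct invocation of the new node, not part of the commit; outside ComfyUI's executor it still needs comfy.model_management importable as mm. Sizes invented; the channel axis must be 16 to match the 16x3 factor matrix.)

import torch

node = CogVideoLatentPreview()
samples = {"samples": torch.randn(1, 13, 16, 60, 90)}  # [B, F, C, H, W] as stored in a LATENT dict
images, factors = node.sample(samples, seed=42, min_val=-0.15, max_val=0.15,
                              r_bias=0.0, g_bias=0.0, b_bias=0.0)
print(images.shape)  # torch.Size([13, 60, 90, 3]), frames normalized to 0..1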
@@ -1316,7 +1380,8 @@ NODE_CLASS_MAPPINGS = {
     "ToraEncodeTrajectory": ToraEncodeTrajectory,
     "ToraEncodeOpticalFlow": ToraEncodeOpticalFlow,
     "CogVideoXFasterCache": CogVideoXFasterCache,
-    "CogVideoXFunResizeToClosestBucket": CogVideoXFunResizeToClosestBucket
+    "CogVideoXFunResizeToClosestBucket": CogVideoXFunResizeToClosestBucket,
+    "CogVideoLatentPreview": CogVideoLatentPreview
 }
 NODE_DISPLAY_NAME_MAPPINGS = {
     "CogVideoSampler": "CogVideo Sampler",
@@ -1337,5 +1402,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
     "ToraEncodeTrajectory": "Tora Encode Trajectory",
     "ToraEncodeOpticalFlow": "Tora Encode OpticalFlow",
     "CogVideoXFasterCache": "CogVideoX FasterCache",
-    "CogVideoXFunResizeToClosestBucket": "CogVideoXFun ResizeToClosestBucket"
+    "CogVideoXFunResizeToClosestBucket": "CogVideoXFun ResizeToClosestBucket",
+    "CogVideoLatentPreview": "CogVideo LatentPreview"
 }