latent input for the control sampler, allows vid2vid along with pose input

Author: kijai
Date:   2024-10-04 13:14:54 +03:00
Parent: 3c8183ac65
Commit: 7af6666c67

2 changed files with 45 additions and 20 deletions
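In outline: the control pipeline previously always started denoising from pure noise. This commit lets it start from an encoded input video instead, noising those latents to an intermediate timestep and denoising only the remaining steps while the pose control latents steer the result. A minimal sketch of that img2img/vid2vid recipe with a stock diffusers scheduler (scheduler choice and shapes are illustrative stand-ins, not this repo's exact setup):

import torch
from diffusers import DDIMScheduler

scheduler = DDIMScheduler()
scheduler.set_timesteps(50)

video_latents = torch.randn(1, 13, 16, 60, 90)  # stand-in for VAE-encoded input video (B, F, C, H, W)
noise = torch.randn_like(video_latents)

# Noising to a mid-schedule timestep keeps the coarse structure of the input;
# the sampler then only has to "finish" the video from there.
start_t = scheduler.timesteps[20:21]            # skip the first 20 of 50 steps
noisy_latents = scheduler.add_noise(video_latents, noise, start_t)
print(noisy_latents.shape)                      # torch.Size([1, 13, 16, 60, 90])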

[File 1 of 2]

@@ -214,7 +214,7 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
         set_pab_manager(pab_config)

     def prepare_latents(
-        self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
+        self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, timesteps, denoise_strength, num_inference_steps, latents=None,
     ):
         shape = (
             batch_size,
@@ -228,15 +228,28 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
             )

+        noise = randn_tensor(shape, generator=generator, device=device, dtype=self.vae.dtype)
         if latents is None:
-            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+            latents = noise
         else:
             latents = latents.to(device)
+            timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, denoise_strength, device)
+            latent_timestep = timesteps[:1]
+
+            frames_needed = noise.shape[1]
+            current_frames = latents.shape[1]
+
+            if frames_needed > current_frames:
+                repeat_factor = frames_needed // current_frames
+                additional_frame = torch.randn((latents.size(0), repeat_factor, latents.size(2), latents.size(3), latents.size(4)), dtype=latents.dtype, device=latents.device)
+                latents = torch.cat((latents, additional_frame), dim=1)
+            elif frames_needed < current_frames:
+                latents = latents[:, :frames_needed, :, :, :]
+
+            latents = self.scheduler.add_noise(latents, noise, latent_timestep)

-        # scale the initial noise by the standard deviation required by the scheduler
-        latents = latents * self.scheduler.init_noise_sigma
-        return latents
+        latents = latents * self.scheduler.init_noise_sigma  # scale the initial noise by the standard deviation required by the scheduler
+        return latents, timesteps, noise

     def prepare_control_latents(
         self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
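Taken in isolation, the frame-matching branch above pads the input latents with random frames (an integer-division repeat factor) or truncates them so their frame dimension matches the freshly sampled noise. A standalone toy run, with shapes chosen as stand-ins for CogVideoX video latents:

import torch

noise = torch.randn(1, 13, 16, 60, 90)    # frames_needed = 13
latents = torch.randn(1, 12, 16, 60, 90)  # current_frames = 12

frames_needed = noise.shape[1]
current_frames = latents.shape[1]

if frames_needed > current_frames:
    repeat_factor = frames_needed // current_frames  # 13 // 12 == 1 extra random frame
    additional_frame = torch.randn(
        (latents.size(0), repeat_factor, latents.size(2), latents.size(3), latents.size(4)),
        dtype=latents.dtype, device=latents.device,
    )
    latents = torch.cat((latents, additional_frame), dim=1)
elif frames_needed < current_frames:
    latents = latents[:, :frames_needed, :, :, :]

print(latents.shape)  # torch.Size([1, 13, 16, 60, 90])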
@@ -452,6 +465,7 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
         timesteps: Optional[List[int]] = None,
         guidance_scale: float = 6,
         use_dynamic_cfg: bool = False,
+        denoise_strength: float = 1.0,
         num_videos_per_prompt: int = 1,
         eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
@@ -601,7 +615,7 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
         # 5. Prepare latents.
         latent_channels = self.vae.config.latent_channels
-        latents = self.prepare_latents(
+        latents, timesteps, noise = self.prepare_latents(
             batch_size * num_videos_per_prompt,
             latent_channels,
             num_frames,
@@ -610,6 +624,9 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
             self.vae.dtype,
             device,
             generator,
+            timesteps,
+            denoise_strength,
+            num_inference_steps,
             latents,
         )
         if comfyui_progressbar:
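prepare_latents calls a self.get_timesteps helper that this diff does not include. A sketch of what it presumably does, assuming the standard diffusers img2img convention of keeping only the final denoise_strength fraction of the schedule (the helper name and signature are the only things taken from the diff; the body is an assumption):

from diffusers import DDIMScheduler

def get_timesteps(scheduler, num_inference_steps, strength):
    # keep only the final `strength` fraction of the schedule
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    t_start = max(num_inference_steps - init_timestep, 0)
    timesteps = scheduler.timesteps[t_start * scheduler.order:]
    return timesteps, num_inference_steps - t_start

scheduler = DDIMScheduler()
scheduler.set_timesteps(50)
timesteps, remaining = get_timesteps(scheduler, 50, 0.6)
print(remaining, timesteps[0].item())  # 30 steps remain, starting 40% into the schedule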

[File 2 of 2]

@@ -746,7 +746,7 @@ class CogVideoImageEncode:
             },
             "optional": {
                 "chunk_size": ("INT", {"default": 16, "min": 1}),
-                "enable_vae_slicing": ("BOOLEAN", {"default": True, "tooltip": "VAE will split the input tensor in slices to compute decoding in several steps. This is useful to save some memory and allow larger batch sizes."}),
+                "enable_tiling": ("BOOLEAN", {"default": False, "tooltip": "Enable tiling for the VAE to reduce memory usage"}),
                 "mask": ("MASK", ),
             },
         }
@@ -756,7 +756,7 @@ class CogVideoImageEncode:
     FUNCTION = "encode"
     CATEGORY = "CogVideoWrapper"

-    def encode(self, pipeline, image, chunk_size=8, enable_vae_slicing=True, mask=None):
+    def encode(self, pipeline, image, chunk_size=8, enable_tiling=False, mask=None):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
         generator = torch.Generator(device=device).manual_seed(0)
@@ -764,14 +764,16 @@ class CogVideoImageEncode:
         B, H, W, C = image.shape

         vae = pipeline["pipe"].vae
-        if enable_vae_slicing:
-            vae.enable_slicing()
-        else:
-            vae.disable_slicing()
+        vae.enable_slicing()
+
+        if enable_tiling:
+            from .mz_enable_vae_encode_tiling import enable_vae_encode_tiling
+            enable_vae_encode_tiling(vae)
+
         if not pipeline["cpu_offloading"]:
             vae.to(device)
         vae._clear_fake_context_parallel_cache()

         input_image = image.clone()
         if mask is not None:
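For context on the two memory switches used above: slicing processes the batch one sample at a time, while tiling splits oversized frames into overlapping spatial tiles (here wired into the encode path via the repo's mz_enable_vae_encode_tiling helper). A generic illustration on a stock diffusers AutoencoderKL, not this wrapper's CogVideoX VAE:

import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL()   # randomly initialized; enough to demonstrate the API
vae.enable_slicing()    # encode/decode one sample of the batch at a time
vae.enable_tiling()     # process oversized inputs in overlapping tiles

with torch.no_grad():
    frames = torch.randn(4, 3, 256, 256)              # a small batch of frames
    latents = vae.encode(frames).latent_dist.sample()
print(latents.shape)    # latent resolution depends on the VAE config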
@@ -1211,8 +1213,8 @@ class CogVideoControlImageEncode:
             },
         }

-    RETURN_TYPES = ("COGCONTROL_LATENTS",)
-    RETURN_NAMES = ("control_latents",)
+    RETURN_TYPES = ("COGCONTROL_LATENTS", "INT", "INT",)
+    RETURN_NAMES = ("control_latents", "width", "height")
     FUNCTION = "encode"
     CATEGORY = "CogVideoWrapper"
@@ -1271,7 +1273,7 @@ class CogVideoControlImageEncode:
             "width" : width,
         }

-        return (control_latents, )
+        return (control_latents, width, height)

 class CogVideoXFunControlSampler:
     @classmethod
@@ -1309,7 +1311,10 @@ class CogVideoXFunControlSampler:
                 "control_end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                 "t_tile_length": ("INT", {"default": 48, "min": 2, "max": 128, "step": 1, "tooltip": "Length of temporal tiles for extending generations, only in effect with the tiled samplers"}),
                 "t_tile_overlap": ("INT", {"default": 8, "min": 2, "max": 128, "step": 1, "tooltip": "Overlap of temporal tiling"}),
             },
+            "optional": {
+                "samples": ("LATENT", ),
+                "denoise_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
+            },
         }
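The new "optional" group means the sampler node runs with or without an incoming latent: ComfyUI only passes optional inputs that are actually connected, so the method falls back to its Python defaults otherwise. A self-contained toy illustrating the pattern (hypothetical node, not part of the repo):

class ToyVid2VidNode:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {"steps": ("INT", {"default": 50})},
            "optional": {
                "samples": ("LATENT",),
                "denoise_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
            },
        }

    def process(self, steps, samples=None, denoise_strength=1.0):
        # vid2vid only engages when an input latent is actually wired in
        return ("vid2vid" if samples is not None else "txt2vid", denoise_strength)

print(ToyVid2VidNode().process(50))                                # ('txt2vid', 1.0)
print(ToyVid2VidNode().process(50, samples={"samples": None}, denoise_strength=0.7))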
@@ -1318,8 +1323,9 @@ class CogVideoXFunControlSampler:
     FUNCTION = "process"
     CATEGORY = "CogVideoWrapper"

-    def process(self, pipeline, positive, negative, seed, steps, cfg, scheduler,
-                control_latents, control_strength=1.0, control_start_percent=0.0, control_end_percent=1.0, t_tile_length=16, t_tile_overlap=8,):
+    def process(self, pipeline, positive, negative, seed, steps, cfg, scheduler, control_latents,
+                control_strength=1.0, control_start_percent=0.0, control_end_percent=1.0, t_tile_length=16, t_tile_overlap=8,
+                samples=None, denoise_strength=1.0):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
         pipe = pipeline["pipe"]
@@ -1367,7 +1373,9 @@ class CogVideoXFunControlSampler:
             control_end_percent=control_end_percent,
             t_tile_length=t_tile_length,
             t_tile_overlap=t_tile_overlap,
-            scheduler_name=scheduler
+            scheduler_name=scheduler,
+            latents=samples["samples"] if samples is not None else None,
+            denoise_strength=denoise_strength,
         )

         return (pipeline, {"samples": latents})
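A closing note on backward compatibility: with samples left unconnected, latents is None and the pipeline samples from pure noise as before; and at denoise_strength 1.0 the truncated schedule is the full schedule, whose first timestep carries near-maximal noise. A quick check of that second claim with a stock scheduler (illustrative, not this repo's code):

import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler()
scheduler.set_timesteps(50)

latents = torch.randn(1, 13, 16, 30, 45)
noise = torch.randn_like(latents)

t_full = scheduler.timesteps[:1]  # strength 1.0 -> start at the very first timestep
noisy = scheduler.add_noise(latents, noise, t_full)

# at t ~ T the cumulative signal coefficient is tiny, i.e. the result is
# essentially the fresh noise the old code started from
print(scheduler.alphas_cumprod[t_full].item())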