mirror of
https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
synced 2026-05-02 19:13:34 +08:00
Add latent input to the control sampler, allowing vid2vid alongside pose input
This commit is contained in:
parent
3c8183ac65
commit
7af6666c67
@ -214,7 +214,7 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
|
|||||||
set_pab_manager(pab_config)
|
set_pab_manager(pab_config)
|
||||||
|
|
||||||
def prepare_latents(
|
def prepare_latents(
|
||||||
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
|
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, timesteps, denoise_strength, num_inference_steps, latents=None,
|
||||||
):
|
):
|
||||||
shape = (
|
shape = (
|
||||||
batch_size,
|
batch_size,
|
||||||
@ -228,15 +228,28 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
|
|||||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||||
)
|
)
|
||||||
|
noise = randn_tensor(shape, generator=generator, device=device, dtype=self.vae.dtype)
|
||||||
if latents is None:
|
if latents is None:
|
||||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
latents = noise
|
||||||
else:
|
else:
|
||||||
latents = latents.to(device)
|
latents = latents.to(device)
|
||||||
|
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, denoise_strength, device)
|
||||||
|
latent_timestep = timesteps[:1]
|
||||||
|
|
||||||
# scale the initial noise by the standard deviation required by the scheduler
|
noise = randn_tensor(shape, generator=generator, device=device, dtype=self.vae.dtype)
|
||||||
latents = latents * self.scheduler.init_noise_sigma
|
frames_needed = noise.shape[1]
|
||||||
return latents
|
current_frames = latents.shape[1]
|
||||||
|
|
||||||
|
if frames_needed > current_frames:
|
||||||
|
repeat_factor = frames_needed // current_frames
|
||||||
|
additional_frame = torch.randn((latents.size(0), repeat_factor, latents.size(2), latents.size(3), latents.size(4)), dtype=latents.dtype, device=latents.device)
|
||||||
|
latents = torch.cat((latents, additional_frame), dim=1)
|
||||||
|
elif frames_needed < current_frames:
|
||||||
|
latents = latents[:, :frames_needed, :, :, :]
|
||||||
|
|
||||||
|
latents = self.scheduler.add_noise(latents, noise, latent_timestep)
|
||||||
|
latents = latents * self.scheduler.init_noise_sigma # scale the initial noise by the standard deviation required by the scheduler
|
||||||
|
return latents, timesteps, noise
|
||||||
|
|
||||||
def prepare_control_latents(
|
def prepare_control_latents(
|
||||||
self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
|
self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
|
||||||
@ -452,6 +465,7 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
|
|||||||
timesteps: Optional[List[int]] = None,
|
timesteps: Optional[List[int]] = None,
|
||||||
guidance_scale: float = 6,
|
guidance_scale: float = 6,
|
||||||
use_dynamic_cfg: bool = False,
|
use_dynamic_cfg: bool = False,
|
||||||
|
denoise_strength: float = 1.0,
|
||||||
num_videos_per_prompt: int = 1,
|
num_videos_per_prompt: int = 1,
|
||||||
eta: float = 0.0,
|
eta: float = 0.0,
|
||||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||||
@ -601,7 +615,7 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
|
|||||||
|
|
||||||
# 5. Prepare latents.
|
# 5. Prepare latents.
|
||||||
latent_channels = self.vae.config.latent_channels
|
latent_channels = self.vae.config.latent_channels
|
||||||
latents = self.prepare_latents(
|
latents, timesteps, noise = self.prepare_latents(
|
||||||
batch_size * num_videos_per_prompt,
|
batch_size * num_videos_per_prompt,
|
||||||
latent_channels,
|
latent_channels,
|
||||||
num_frames,
|
num_frames,
|
||||||
@ -610,6 +624,9 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
|
|||||||
self.vae.dtype,
|
self.vae.dtype,
|
||||||
device,
|
device,
|
||||||
generator,
|
generator,
|
||||||
|
timesteps,
|
||||||
|
denoise_strength,
|
||||||
|
num_inference_steps,
|
||||||
latents,
|
latents,
|
||||||
)
|
)
|
||||||
if comfyui_progressbar:
|
if comfyui_progressbar:
|
||||||
|
|||||||
34
nodes.py
34
nodes.py
@ -746,7 +746,7 @@ class CogVideoImageEncode:
|
|||||||
},
|
},
|
||||||
"optional": {
|
"optional": {
|
||||||
"chunk_size": ("INT", {"default": 16, "min": 1}),
|
"chunk_size": ("INT", {"default": 16, "min": 1}),
|
||||||
"enable_vae_slicing": ("BOOLEAN", {"default": True, "tooltip": "VAE will split the input tensor in slices to compute decoding in several steps. This is useful to save some memory and allow larger batch sizes."}),
|
"enable_tiling": ("BOOLEAN", {"default": False, "tooltip": "Enable tiling for the VAE to reduce memory usage"}),
|
||||||
"mask": ("MASK", ),
|
"mask": ("MASK", ),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -756,7 +756,7 @@ class CogVideoImageEncode:
|
|||||||
FUNCTION = "encode"
|
FUNCTION = "encode"
|
||||||
CATEGORY = "CogVideoWrapper"
|
CATEGORY = "CogVideoWrapper"
|
||||||
|
|
||||||
def encode(self, pipeline, image, chunk_size=8, enable_vae_slicing=True, mask=None):
|
def encode(self, pipeline, image, chunk_size=8, enable_tiling=False, mask=None):
|
||||||
device = mm.get_torch_device()
|
device = mm.get_torch_device()
|
||||||
offload_device = mm.unet_offload_device()
|
offload_device = mm.unet_offload_device()
|
||||||
generator = torch.Generator(device=device).manual_seed(0)
|
generator = torch.Generator(device=device).manual_seed(0)
|
||||||
@ -764,15 +764,17 @@ class CogVideoImageEncode:
|
|||||||
B, H, W, C = image.shape
|
B, H, W, C = image.shape
|
||||||
|
|
||||||
vae = pipeline["pipe"].vae
|
vae = pipeline["pipe"].vae
|
||||||
|
vae.enable_slicing()
|
||||||
|
|
||||||
if enable_vae_slicing:
|
if enable_tiling:
|
||||||
vae.enable_slicing()
|
from .mz_enable_vae_encode_tiling import enable_vae_encode_tiling
|
||||||
else:
|
enable_vae_encode_tiling(vae)
|
||||||
vae.disable_slicing()
|
|
||||||
|
|
||||||
if not pipeline["cpu_offloading"]:
|
if not pipeline["cpu_offloading"]:
|
||||||
vae.to(device)
|
vae.to(device)
|
||||||
|
|
||||||
|
vae._clear_fake_context_parallel_cache()
|
||||||
|
|
||||||
input_image = image.clone()
|
input_image = image.clone()
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
pipeline["pipe"].original_mask = mask
|
pipeline["pipe"].original_mask = mask
|
||||||
@ -1211,8 +1213,8 @@ class CogVideoControlImageEncode:
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
RETURN_TYPES = ("COGCONTROL_LATENTS",)
|
RETURN_TYPES = ("COGCONTROL_LATENTS", "INT", "INT",)
|
||||||
RETURN_NAMES = ("control_latents",)
|
RETURN_NAMES = ("control_latents", "width", "height")
|
||||||
FUNCTION = "encode"
|
FUNCTION = "encode"
|
||||||
CATEGORY = "CogVideoWrapper"
|
CATEGORY = "CogVideoWrapper"
|
||||||
|
|
||||||
@ -1271,7 +1273,7 @@ class CogVideoControlImageEncode:
|
|||||||
"width" : width,
|
"width" : width,
|
||||||
}
|
}
|
||||||
|
|
||||||
return (control_latents, )
|
return (control_latents, width, height)
|
||||||
|
|
||||||
class CogVideoXFunControlSampler:
|
class CogVideoXFunControlSampler:
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -1309,7 +1311,10 @@ class CogVideoXFunControlSampler:
|
|||||||
"control_end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
"control_end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||||
"t_tile_length": ("INT", {"default": 48, "min": 2, "max": 128, "step": 1, "tooltip": "Length of temporal tiles for extending generations, only in effect with the tiled samplers"}),
|
"t_tile_length": ("INT", {"default": 48, "min": 2, "max": 128, "step": 1, "tooltip": "Length of temporal tiles for extending generations, only in effect with the tiled samplers"}),
|
||||||
"t_tile_overlap": ("INT", {"default": 8, "min": 2, "max": 128, "step": 1, "tooltip": "Overlap of temporal tiling"}),
|
"t_tile_overlap": ("INT", {"default": 8, "min": 2, "max": 128, "step": 1, "tooltip": "Overlap of temporal tiling"}),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"samples": ("LATENT", ),
|
||||||
|
"denoise_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1318,8 +1323,9 @@ class CogVideoXFunControlSampler:
|
|||||||
FUNCTION = "process"
|
FUNCTION = "process"
|
||||||
CATEGORY = "CogVideoWrapper"
|
CATEGORY = "CogVideoWrapper"
|
||||||
|
|
||||||
def process(self, pipeline, positive, negative, seed, steps, cfg, scheduler,
|
def process(self, pipeline, positive, negative, seed, steps, cfg, scheduler, control_latents,
|
||||||
control_latents, control_strength=1.0, control_start_percent=0.0, control_end_percent=1.0, t_tile_length=16, t_tile_overlap=8,):
|
control_strength=1.0, control_start_percent=0.0, control_end_percent=1.0, t_tile_length=16, t_tile_overlap=8,
|
||||||
|
samples=None, denoise_strength=1.0):
|
||||||
device = mm.get_torch_device()
|
device = mm.get_torch_device()
|
||||||
offload_device = mm.unet_offload_device()
|
offload_device = mm.unet_offload_device()
|
||||||
pipe = pipeline["pipe"]
|
pipe = pipeline["pipe"]
|
||||||
@ -1367,7 +1373,9 @@ class CogVideoXFunControlSampler:
|
|||||||
control_end_percent=control_end_percent,
|
control_end_percent=control_end_percent,
|
||||||
t_tile_length=t_tile_length,
|
t_tile_length=t_tile_length,
|
||||||
t_tile_overlap=t_tile_overlap,
|
t_tile_overlap=t_tile_overlap,
|
||||||
scheduler_name=scheduler
|
scheduler_name=scheduler,
|
||||||
|
latents=samples["samples"] if samples is not None else None,
|
||||||
|
denoise_strength=denoise_strength,
|
||||||
)
|
)
|
||||||
|
|
||||||
return (pipeline, {"samples": latents})
|
return (pipeline, {"samples": latents})
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user