support input latents for vid2vid

kijai 2024-08-06 22:02:29 +03:00
parent 5755a8c7f3
commit 0d69d73c60
2 changed files with 124 additions and 186 deletions

nodes.py

@@ -122,21 +122,82 @@ class CogVideoTextEncode:
embeds = clip.encode_from_tokens(tokens, return_pooled=False, return_dict=False)
return (embeds, )
class CogVideoSampler:
class CogVideoImageEncode:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"pipeline": ("COGVIDEOPIPE",),
"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"height": ("INT", {"default": 480, "min": 128, "max": 2048, "step": 8}),
"width": ("INT", {"default": 720, "min": 128, "max": 2048, "step": 8}),
"num_frames": ("INT", {"default": 48, "min": 8, "max": 100, "step": 8}),
"fps": ("INT", {"default": 8, "min": 1, "max": 100, "step": 1}),
"steps": ("INT", {"default": 25, "min": 1}),
"cfg": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 30.0, "step": 0.01}),
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
"image": ("IMAGE", ),
},
}
RETURN_TYPES = ("LATENT",)
RETURN_NAMES = ("samples",)
FUNCTION = "encode"
CATEGORY = "CogVideoWrapper"
def encode(self, pipeline, image):
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
generator = torch.Generator(device=device).manual_seed(0)
vae = pipeline["pipe"].vae
vae.to(device)
image = image * 2.0 - 1.0
image = image.to(vae.dtype).to(device)
image = image.unsqueeze(0).permute(0, 4, 1, 2, 3) # B, C, T, H, W
B, C, T, H, W = image.shape
chunk_size = 16
latents_list = []
# Loop through the temporal dimension in chunks of 16
for i in range(0, T, chunk_size):
# Get the chunk of 16 frames (or remaining frames if less than 16 are left)
end_index = min(i + chunk_size, T)
image_chunk = image[:, :, i:end_index, :, :] # Shape: [B, C, chunk_size, H, W]
# Encode the chunk of images
latents = vae.encode(image_chunk)
sample_mode = "sample"
if hasattr(latents, "latent_dist") and sample_mode == "sample":
latents = latents.latent_dist.sample(generator)
elif hasattr(latents, "latent_dist") and sample_mode == "argmax":
latents = latents.latent_dist.mode()
elif hasattr(latents, "latents"):
latents = latents.latents
latents = vae.config.scaling_factor * latents
latents = latents.permute(0, 2, 1, 3, 4) # B, T_chunk, C, H, W
latents_list.append(latents)
# Concatenate all the chunks along the temporal dimension
final_latents = torch.cat(latents_list, dim=1)
print("final latents: ", final_latents.shape)
vae.to(offload_device)
return ({"samples": final_latents}, )
class CogVideoSampler:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"pipeline": ("COGVIDEOPIPE",),
"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"height": ("INT", {"default": 480, "min": 128, "max": 2048, "step": 8}),
"width": ("INT", {"default": 720, "min": 128, "max": 2048, "step": 8}),
"num_frames": ("INT", {"default": 48, "min": 8, "max": 100, "step": 8}),
"fps": ("INT", {"default": 8, "min": 1, "max": 100, "step": 1}),
"steps": ("INT", {"default": 25, "min": 1}),
"cfg": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 30.0, "step": 0.01}),
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
},
"optional": {
"samples": ("LATENT", ),
"denoise_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}
}
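
For reference, the chunked VAE encode that the new CogVideoImageEncode node performs can be written as a small standalone helper. This is a minimal sketch of the same pattern, assuming a diffusers-style AutoencoderKLCogVideoX whose encode() returns an object with a latent_dist and which exposes config.scaling_factor; the helper name and argument layout are illustrative, not part of the commit:

import torch

def encode_frames_in_chunks(vae, frames, chunk_size=16, generator=None):
    # frames: [B, C, T, H, W] pixel values already scaled to [-1, 1] and cast to vae.dtype
    latents_list = []
    total_frames = frames.shape[2]
    for i in range(0, total_frames, chunk_size):
        chunk = frames[:, :, i:i + chunk_size]                # [B, C, <=chunk_size, H, W]
        encoded = vae.encode(chunk)
        if hasattr(encoded, "latent_dist"):
            latents = encoded.latent_dist.sample(generator)
        else:                                                 # some VAE wrappers return latents directly
            latents = encoded.latents
        latents = vae.config.scaling_factor * latents
        latents_list.append(latents.permute(0, 2, 1, 3, 4))   # [B, T_chunk, C, H, W]
    return torch.cat(latents_list, dim=1)                     # [B, T_latent, C, H, W]

With the CogVideoX VAE this yields latents in the [B, T, C, H, W] layout that the sampler's new `samples` input expects.
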
@@ -145,7 +206,7 @@ class CogVideoSampler:
FUNCTION = "process"
CATEGORY = "CogVideoWrapper"
def process(self, pipeline, positive, negative, fps, steps, cfg, seed, height, width, num_frames):
def process(self, pipeline, positive, negative, fps, steps, cfg, seed, height, width, num_frames, samples=None, denoise_strength=1.0):
mm.soft_empty_cache()
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
@@ -162,6 +223,8 @@ class CogVideoSampler:
num_frames = num_frames,
fps = fps,
guidance_scale=cfg,
latents=samples["samples"] if samples is not None else None,
denoise_strength=denoise_strength,
prompt_embeds=positive.to(dtype).to(device),
negative_prompt_embeds=negative.to(dtype).to(device),
#negative_prompt_embeds=torch.zeros_like(embeds),
@@ -198,11 +261,15 @@ class CogVideoDecode:
vae = pipeline["pipe"].vae
vae.to(device)
num_frames = pipeline["num_frames"]
fps = pipeline["fps"]
if "num_frames" in pipeline:
num_frames = pipeline["num_frames"]
fps = pipeline["fps"]
else:
num_frames = latents.shape[2]
fps = 8
num_seconds = num_frames // fps
latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
latents = 1 / vae.config.scaling_factor * latents
@@ -217,6 +284,7 @@ class CogVideoDecode:
vae.to(offload_device)
frames = torch.cat(frames, dim=2)
print(frames.min(), frames.max())
video = pipeline["pipe"].video_processor.postprocess_video(video=frames, output_type="pt")
print(video.shape)
video = video[0].permute(0, 2, 3, 1).cpu().float()
@@ -229,11 +297,13 @@ NODE_CLASS_MAPPINGS = {
"DownloadAndLoadCogVideoModel": DownloadAndLoadCogVideoModel,
"CogVideoSampler": CogVideoSampler,
"CogVideoDecode": CogVideoDecode,
"CogVideoTextEncode": CogVideoTextEncode
"CogVideoTextEncode": CogVideoTextEncode,
"CogVideoImageEncode": CogVideoImageEncode
}
NODE_DISPLAY_NAME_MAPPINGS = {
"DownloadAndLoadCogVideoModel": "(Down)load CogVideo Model",
"CogVideoSampler": "CogVideo Sampler",
"CogVideoDecode": "CogVideo Decode",
"CogVideoTextEncode": "CogVideo TextEncode"
"CogVideoTextEncode": "CogVideo TextEncode",
"CogVideoImageEncode": "CogVideo ImageEncode"
}
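
Taken together, the nodes.py changes give the sampler two new optional inputs, `samples` and `denoise_strength`, so a vid2vid run is: encode the source frames to latents, then sample from those latents at a strength below 1.0. A rough sketch of that wiring, calling the node classes directly rather than through the ComfyUI graph (the pipeline, conditioning, and frames objects are assumed to come from the existing loader and text-encode nodes):

def vid2vid(pipeline, positive, negative, frames, seed=0, strength=0.6):
    # 1) Source frames -> latents via the VAE (new CogVideoImageEncode node).
    #    `frames` is a ComfyUI IMAGE batch; its batch dimension is treated as time.
    (latent_dict,) = CogVideoImageEncode().encode(pipeline, frames)

    # 2) Sample starting from those latents. denoise_strength controls how much
    #    of the noise schedule is re-run; 1.0 re-noises them all the way back.
    return CogVideoSampler().process(
        pipeline, positive, negative,
        fps=8, steps=25, cfg=6.0, seed=seed,
        height=480, width=720, num_frames=frames.shape[0],
        samples=latent_dict, denoise_strength=strength,
    )
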

CogVideoX pipeline module (second changed file)

@@ -18,7 +18,6 @@ from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
from transformers import T5EncoderModel, T5Tokenizer
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
@@ -165,8 +164,6 @@ class CogVideoXPipeline(DiffusionPipeline):
def __init__(
self,
tokenizer: T5Tokenizer,
#text_encoder: T5EncoderModel,
vae: AutoencoderKLCogVideoX,
transformer: CogVideoXTransformer3DModel,
scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
@@ -174,7 +171,7 @@ class CogVideoXPipeline(DiffusionPipeline):
super().__init__()
self.register_modules(
tokenizer=tokenizer, vae=vae, transformer=transformer, scheduler=scheduler
vae=vae, transformer=transformer, scheduler=scheduler
)
self.vae_scale_factor_spatial = (
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
@@ -182,136 +179,11 @@ class CogVideoXPipeline(DiffusionPipeline):
self.vae_scale_factor_temporal = (
self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
)
self.tokenizer_max_length = (
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 226
)
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
def _get_t5_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
num_videos_per_prompt: int = 1,
max_sequence_length: int = 226,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
device = device or self._execution_device
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because `max_sequence_length` is set to "
f" {max_sequence_length} tokens: {removed_text}"
)
#prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
# duplicate text embeddings for each generation per prompt, using mps friendly method
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
return prompt_embeds
def encode_prompt(
self,
prompt: Union[str, List[str]],
negative_prompt: Optional[Union[str, List[str]]] = None,
do_classifier_free_guidance: bool = True,
num_videos_per_prompt: int = 1,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
max_sequence_length: int = 226,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
Whether to use classifier free guidance or not.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
device: (`torch.device`, *optional*):
torch device
dtype: (`torch.dtype`, *optional*):
torch dtype
"""
device = device or self._execution_device
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
prompt_embeds = self._get_t5_prompt_embeds(
prompt=prompt,
num_videos_per_prompt=num_videos_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
dtype=dtype,
)
if do_classifier_free_guidance and negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
negative_prompt_embeds = self._get_t5_prompt_embeds(
prompt=negative_prompt,
num_videos_per_prompt=num_videos_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
dtype=dtype,
)
return prompt_embeds, negative_prompt_embeds
def prepare_latents(
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, timesteps, denoise_strength, num_inference_steps, latents=None,
):
shape = (
batch_size,
@@ -328,12 +200,27 @@ class CogVideoXPipeline(DiffusionPipeline):
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
# scale the initial noise by the standard deviation required by the scheduler
else:
latents = latents.to(device)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, denoise_strength, device)
latent_timestep = timesteps[:1]
noise = randn_tensor(shape, generator=generator, device=device, dtype=latents.dtype)
frames_needed = noise.shape[1]
current_frames = latents.shape[1]
if frames_needed > current_frames:
repeat_factor = frames_needed // current_frames
additional_frame = torch.randn((latents.size(0), repeat_factor, latents.size(2), latents.size(3), latents.size(4)), dtype=latents.dtype, device=latents.device)
latents = torch.cat((latents, additional_frame), dim=1)
elif frames_needed < current_frames:
latents = latents[:, :frames_needed, :, :, :]
# scale the initial noise by the standard deviation required by the scheduler
latents = self.scheduler.add_noise(latents, noise, latent_timestep)
latents = latents * self.scheduler.init_noise_sigma
return latents
return latents, timesteps
def decode_latents(self, latents: torch.Tensor, num_seconds: int):
latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
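
One detail in the prepare_latents change above: when caller-supplied latents are shorter or longer in time than the requested clip, the temporal dimension is padded with random latent frames or truncated before noise is added. A minimal variant of that adjustment, assuming latents shaped [B, T, C, H, W] (note the commit pads in multiples of the current length via repeat_factor; this sketch pads exactly the missing number of frames):

import torch

def match_temporal_length(latents, frames_needed):
    current = latents.shape[1]
    if frames_needed > current:
        pad = torch.randn(
            (latents.size(0), frames_needed - current, *latents.shape[2:]),
            dtype=latents.dtype, device=latents.device,
        )
        latents = torch.cat((latents, pad), dim=1)        # append random latent frames
    elif frames_needed < current:
        latents = latents[:, :frames_needed]              # drop the excess frames
    return latents
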
@@ -372,10 +259,8 @@ class CogVideoXPipeline(DiffusionPipeline):
# Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
def check_inputs(
self,
prompt,
height,
width,
negative_prompt,
callback_on_step_end_tensor_inputs,
prompt_embeds=None,
negative_prompt_embeds=None,
@@ -389,29 +274,6 @@ class CogVideoXPipeline(DiffusionPipeline):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
@@ -420,6 +282,16 @@ class CogVideoXPipeline(DiffusionPipeline):
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
def get_timesteps(self, num_inference_steps, strength, device):
# get the original timestep using init_timestep
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
t_start = max(num_inference_steps - init_timestep, 0)
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index(t_start * self.scheduler.order)
return timesteps.to(device), num_inference_steps - t_start
@property
def guidance_scale(self):
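
The new get_timesteps helper follows the usual diffusers img2img convention: `strength` selects how much of the tail of the schedule is actually run, and the first remaining timestep is the noise level applied to the input latents in prepare_latents. A quick worked example with illustrative numbers:

num_inference_steps, strength, order = 25, 0.6, 1   # illustrative values; scheduler.order assumed 1
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)   # 15
t_start = max(num_inference_steps - init_timestep, 0)                           # 10
# scheduler.timesteps[t_start * order:] keeps the last 15 of the 25 timesteps;
# the input latents are noised to the first of them before the denoising loop.
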
@@ -444,8 +316,6 @@ class CogVideoXPipeline(DiffusionPipeline):
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Optional[Union[str, List[str]]] = None,
negative_prompt: Optional[Union[str, List[str]]] = None,
height: int = 480,
width: int = 720,
num_frames: int = 48,
@@ -453,6 +323,7 @@ class CogVideoXPipeline(DiffusionPipeline):
num_inference_steps: int = 50,
timesteps: Optional[List[int]] = None,
guidance_scale: float = 6,
denoise_strength: float = 1.0,
num_videos_per_prompt: int = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
@@ -553,10 +424,8 @@ class CogVideoXPipeline(DiffusionPipeline):
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
height,
width,
negative_prompt,
callback_on_step_end_tensor_inputs,
prompt_embeds,
negative_prompt_embeds,
@@ -565,12 +434,8 @@ class CogVideoXPipeline(DiffusionPipeline):
self._interrupt = False
# 2. Default call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
batch_size = prompt_embeds.shape[0]
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -587,7 +452,7 @@ class CogVideoXPipeline(DiffusionPipeline):
# 5. Prepare latents.
latent_channels = self.transformer.config.in_channels
num_frames += 1
latents = self.prepare_latents(
latents, timesteps = self.prepare_latents(
batch_size * num_videos_per_prompt,
latent_channels,
num_frames,
@@ -596,7 +461,10 @@ class CogVideoXPipeline(DiffusionPipeline):
prompt_embeds.dtype,
device,
generator,
latents,
timesteps,
denoise_strength,
num_inference_steps,
latents
)
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
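
At the pipeline level the sampler node simply forwards the encoded latents and strength, so a direct call now looks roughly like this. The keyword names match the diff; the surrounding objects (pipe, embeddings, device, seed) and the handling of the return value are assumptions about code outside these hunks:

out = pipe(
    height=480, width=720, num_frames=48, fps=8,
    num_inference_steps=25, guidance_scale=6.0,
    denoise_strength=0.6,                       # < 1.0 -> partial re-noise of the input latents
    latents=latent_dict["samples"],             # from CogVideoImageEncode
    prompt_embeds=positive_embeds,
    negative_prompt_embeds=negative_embeds,
    generator=torch.Generator(device=device).manual_seed(seed),
)
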