diff --git a/cogvideox_fun/fun_pab_transformer_3d.py b/cogvideox_fun/fun_pab_transformer_3d.py index 079798e..cb524e8 100644 --- a/cogvideox_fun/fun_pab_transformer_3d.py +++ b/cogvideox_fun/fun_pab_transformer_3d.py @@ -453,6 +453,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin): spatial_interpolation_scale: float = 1.875, temporal_interpolation_scale: float = 1.0, use_rotary_positional_embeddings: bool = False, + add_noise_in_inpaint_model: bool = False, ): super().__init__() inner_dim = num_attention_heads * attention_head_dim @@ -568,6 +569,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin): timestep: Union[int, float, torch.LongTensor], timestep_cond: Optional[torch.Tensor] = None, inpaint_latents: Optional[torch.Tensor] = None, + control_latents: Optional[torch.Tensor] = None, image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, return_dict: bool = True, ): @@ -586,6 +588,8 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin): # 2. Patch embedding if inpaint_latents is not None: hidden_states = torch.concat([hidden_states, inpaint_latents], 2) + if control_latents is not None: + hidden_states = torch.concat([hidden_states, control_latents], 2) hidden_states = self.patch_embed(encoder_hidden_states, hidden_states) # 3. Position embedding diff --git a/cogvideox_fun/pipeline_cogvideox_control.py b/cogvideox_fun/pipeline_cogvideox_control.py new file mode 100644 index 0000000..7130d50 --- /dev/null +++ b/cogvideox_fun/pipeline_cogvideox_control.py @@ -0,0 +1,707 @@ +# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
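# --- Editor's illustrative sketch, not part of this patch ---
# The transformer change above concatenates `control_latents` to the noisy latents
# along the channel axis (dim 2 of a [batch, frames, channels, height, width] tensor),
# mirroring how `inpaint_latents` is handled, so the patch embedding sees twice as many
# input channels. Shapes below are dummy values for a 49-frame, 480x720 run.
import torch

batch, frames, latent_channels, height, width = 1, 13, 16, 60, 90   # 13 = (49 - 1) // 4 + 1
hidden_states = torch.randn(batch, frames, latent_channels, height, width)
control_latents = torch.randn(batch, frames, latent_channels, height, width)

conditioned = torch.concat([hidden_states, control_latents], 2)
assert conditioned.shape == (batch, frames, 2 * latent_channels, height, width)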
+ +import inspect +import math +from dataclasses import dataclass +from typing import Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from einops import rearrange + +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel +from diffusers.models.embeddings import get_3d_rotary_pos_embed +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler +from diffusers.utils import BaseOutput, logging, replace_example_docstring +from diffusers.utils.torch_utils import randn_tensor +from diffusers.video_processor import VideoProcessor +from diffusers.image_processor import VaeImageProcessor +from einops import rearrange + +from ..videosys.core.pipeline import VideoSysPipeline +from ..videosys.cogvideox_transformer_3d import CogVideoXTransformer3DModel as CogVideoXTransformer3DModelPAB +from ..videosys.core.pab_mgr import set_pab_manager + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import CogVideoX_Fun_Pipeline + >>> from diffusers.utils import export_to_video + + >>> # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b" + >>> pipe = CogVideoX_Fun_Pipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda") + >>> prompt = ( + ... "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. " + ... "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other " + ... "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, " + ... "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. " + ... "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical " + ... "atmosphere of this unique musical performance." + ... ) + >>> video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0] + >>> export_to_video(video, "output.mp4", fps=8) + ``` +""" + + +# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid +def get_resize_crop_region_for_grid(src, tgt_width, tgt_height): + tw = tgt_width + th = tgt_height + h, w = src + r = h / w + if r > (th / tw): + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) + + crop_top = int(round((th - resize_height) / 2.0)) + crop_left = int(round((tw - resize_width) / 2.0)) + + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. 
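# --- Editor's illustrative sketch, not part of this patch ---
# A quick worked example of the `get_resize_crop_region_for_grid` helper defined above:
# it fits the latent grid of the requested resolution into the 720x480 base grid the
# rotary embeddings are built for (resize to fit, preserving aspect ratio, then take
# the centered crop region). Dummy numbers for a 480x720 video with VAE spatial
# factor 8 and patch size 2:
grid_height, grid_width = 480 // (8 * 2), 720 // (8 * 2)    # 30, 45
base_width, base_height = 720 // (8 * 2), 480 // (8 * 2)    # 45, 30
top_left, bottom_right = get_resize_crop_region_for_grid(
    (grid_height, grid_width), base_width, base_height
)
assert (top_left, bottom_right) == ((0, 0), (30, 45))       # native aspect ratio -> full base grid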
+ num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +@dataclass +class CogVideoX_Fun_PipelineOutput(BaseOutput): + r""" + Output class for CogVideo pipelines. + + Args: + video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + videos: torch.Tensor + + +class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline): + r""" + Pipeline for text-to-video generation using CogVideoX. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + transformer ([`CogVideoXTransformer3DModel`]): + A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. 
+ """ + + _optional_components = [] + model_cpu_offload_seq = "vae->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + vae: AutoencoderKLCogVideoX, + transformer: CogVideoXTransformer3DModel, + scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler], + pab_config = None + ): + super().__init__() + + self.register_modules( + vae=vae, transformer=transformer, scheduler=scheduler + ) + self.vae_scale_factor_spatial = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.vae_scale_factor_temporal = ( + self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4 + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True + ) + + if pab_config is not None: + set_pab_manager(pab_config) + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + shape = ( + batch_size, + (num_frames - 1) // self.vae_scale_factor_temporal + 1, + num_channels_latents, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def prepare_control_latents( + self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + + if mask is not None: + mask = mask.to(device=device, dtype=self.vae.dtype) + bs = 1 + new_mask = [] + for i in range(0, mask.shape[0], bs): + mask_bs = mask[i : i + bs] + mask_bs = self.vae.encode(mask_bs)[0] + mask_bs = mask_bs.mode() + new_mask.append(mask_bs) + mask = torch.cat(new_mask, dim = 0) + mask = mask * self.vae.config.scaling_factor + + if masked_image is not None: + masked_image = masked_image.to(device=device, dtype=self.vae.dtype) + bs = 1 + new_mask_pixel_values = [] + for i in range(0, masked_image.shape[0], bs): + mask_pixel_values_bs = masked_image[i : i + bs] + mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0] + mask_pixel_values_bs = mask_pixel_values_bs.mode() + new_mask_pixel_values.append(mask_pixel_values_bs) + masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0) + masked_image_latents = masked_image_latents * self.vae.config.scaling_factor + else: + masked_image_latents = None + + return mask, masked_image_latents + + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width] + latents = 1 / self.vae.config.scaling_factor * latents + + frames = self.vae.decode(latents).sample + frames = (frames / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + frames = frames.cpu().float().numpy() + return frames + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
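# --- Editor's illustrative sketch, not part of this patch ---
# The encoding pattern used by `prepare_control_latents` above: the control video is
# run through the VAE in micro-batches of one sample, the distribution mode is taken
# instead of a random sample, and the result is multiplied by the VAE scaling factor.
# The encoder and scaling factor below are stand-ins for `self.vae.encode(...)[0].mode()`
# and `self.vae.config.scaling_factor`.
import torch
import torch.nn.functional as F

def fake_vae_encode_mode(x):
    # stand-in encoder: just downsample (frames, h, w) by (4, 8, 8) like the real VAE would
    return F.avg_pool3d(x, kernel_size=(4, 8, 8))

scaling_factor = 0.7                                   # placeholder value
control_video = torch.randn(2, 3, 8, 64, 64)           # (batch, channels, frames, h, w), dummy sizes

chunks = []
for i in range(0, control_video.shape[0], 1):          # micro-batches of size 1, as in the code above
    chunks.append(fake_vae_encode_mode(control_video[i : i + 1]))
control_latents = torch.cat(chunks, dim=0) * scaling_factor
assert control_latents.shape == (2, 3, 2, 8, 8)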
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def fuse_qkv_projections(self) -> None: + r"""Enables fused QKV projections.""" + self.fusing_transformer = True + self.transformer.fuse_qkv_projections() + + def unfuse_qkv_projections(self) -> None: + r"""Disable QKV projection fusion if enabled.""" + if not self.fusing_transformer: + logger.warning("The Transformer was not initially fused for QKV projections. 
Doing nothing.") + else: + self.transformer.unfuse_qkv_projections() + self.fusing_transformer = False + + def _prepare_rotary_positional_embeddings( + self, + height: int, + width: int, + num_frames: int, + device: torch.device, + ) -> Tuple[torch.Tensor, torch.Tensor]: + grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + + grid_crops_coords = get_resize_crop_region_for_grid( + (grid_height, grid_width), base_size_width, base_size_height + ) + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=self.transformer.config.attention_head_dim, + crops_coords=grid_crops_coords, + grid_size=(grid_height, grid_width), + temporal_size=num_frames, + use_real=True, + ) + + freqs_cos = freqs_cos.to(device=device) + freqs_sin = freqs_sin.to(device=device) + return freqs_cos, freqs_sin + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 480, + width: int = 720, + video: Union[torch.FloatTensor] = None, + control_video: Union[torch.FloatTensor] = None, + num_frames: int = 49, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + guidance_scale: float = 6, + use_dynamic_cfg: bool = False, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: str = "numpy", + return_dict: bool = False, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 226, + comfyui_progressbar: bool = False, + ) -> Union[CogVideoX_Fun_PipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). 
+            height (`int`, *optional*, defaults to `480`):
+                The height in pixels of the generated video.
+            width (`int`, *optional*, defaults to `720`):
+                The width in pixels of the generated video.
+            num_frames (`int`, defaults to `49`):
+                Number of frames to generate. The effective length is rounded so that `num_frames - 1` is divisible by
+                `self.vae_scale_factor_temporal`. The generated video contains 1 extra frame because CogVideoX_Fun is
+                conditioned with (num_seconds * fps + 1) frames where num_seconds is 6 and fps is 8. However, since
+                videos can be saved at any fps, the only condition that needs to be satisfied is the divisibility
+                mentioned above.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to a higher quality video at the
+                expense of slower inference.
+            timesteps (`List[int]`, *optional*):
+                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+                passed will be used. Must be in descending order.
+            guidance_scale (`float`, *optional*, defaults to `6`):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` in equation 2 of the [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. A higher guidance scale encourages the model to generate videos closely linked to the text
+                `prompt`, usually at the expense of lower visual quality.
+            num_videos_per_prompt (`int`, *optional*, defaults to 1):
+                The number of videos to generate per prompt.
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                to make generation deterministic.
+            latents (`torch.FloatTensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from the `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from the `negative_prompt` input
+                argument.
+            output_type (`str`, *optional*, defaults to `"numpy"`):
+                The output format of the generated video. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `False`):
+                Whether or not to return a [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoX_Fun_PipelineOutput`]
+                instead of a plain tuple.
+            callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference.
The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `226`): + Maximum sequence length in encoded prompt. Must be consistent with + `self.transformer.config.max_text_seq_length` otherwise may lead to poor results. + + Examples: + + Returns: + [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoX_Fun_PipelineOutput`] or `tuple`: + [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoX_Fun_PipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + if num_frames > 49: + raise ValueError( + "The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation." + ) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial + width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._interrupt = False + + # 2. Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + self._num_timesteps = len(timesteps) + if comfyui_progressbar: + from comfy.utils import ProgressBar + pbar = ProgressBar(num_inference_steps + 2) + + # 5. Prepare latents. 
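# --- Editor's illustrative sketch, not part of this patch ---
# The classifier-free-guidance plumbing set up above and used in the denoising loop
# below: negative and positive prompt embeddings are stacked into one batch, the
# latents are duplicated to match, and the two noise predictions are recombined with
# `guidance_scale`. All tensor sizes are dummies.
import torch

guidance_scale = 6.0
pos = torch.randn(1, 226, 4096)                    # (batch, text tokens, embed dim)
neg = torch.randn(1, 226, 4096)
prompt_embeds = torch.cat([neg, pos], dim=0)       # unconditional batch first

latents = torch.randn(1, 13, 16, 60, 90)
latent_model_input = torch.cat([latents] * 2)      # duplicated for the two branches

noise_pred = torch.randn_like(latent_model_input)  # stand-in for the transformer output
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)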
+ latent_channels = self.vae.config.latent_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + latent_channels, + num_frames, + height, + width, + self.vae.dtype, + device, + generator, + latents, + ) + if comfyui_progressbar: + pbar.update(1) + + if control_video is not None: + video_length = control_video.shape[2] + control_video = self.image_processor.preprocess(rearrange(control_video, "b c f h w -> (b f) c h w"), height=height, width=width) + control_video = control_video.to(dtype=torch.float32) + control_video = rearrange(control_video, "(b f) c h w -> b c f h w", f=video_length) + else: + control_video = None + control_video_latents = self.prepare_control_latents( + None, + control_video, + batch_size, + height, + width, + self.vae.dtype, + device, + generator, + do_classifier_free_guidance + )[1] + control_video_latents_input = ( + torch.cat([control_video_latents] * 2) if do_classifier_free_guidance else control_video_latents + ) + control_latents = rearrange(control_video_latents_input, "b c f h w -> b f c h w") + + if comfyui_progressbar: + pbar.update(1) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Create rotary embeds if required + image_rotary_emb = ( + self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device) + if self.transformer.config.use_rotary_positional_embeddings + else None + ) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + # for DPM-solver++ + old_pred_original_sample = None + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # predict noise model_output + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + image_rotary_emb=image_rotary_emb, + return_dict=False, + control_latents=control_latents, + )[0] + noise_pred = noise_pred.float() + + # perform guidance + if use_dynamic_cfg: + self._guidance_scale = 1 + guidance_scale * ( + (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2 + ) + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + if not isinstance(self.scheduler, CogVideoXDPMScheduler): + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + else: + latents, old_pred_original_sample = self.scheduler.step( + noise_pred, + old_pred_original_sample, + t, + timesteps[i - 1] if i > 0 else None, + latents, + **extra_step_kwargs, + return_dict=False, + ) + latents = latents.to(prompt_embeds.dtype) + + # call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = 
callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if comfyui_progressbar: + pbar.update(1) + + # if output_type == "numpy": + # video = self.decode_latents(latents) + # elif not output_type == "latent": + # video = self.decode_latents(latents) + # video = self.video_processor.postprocess_video(video=video, output_type=output_type) + # else: + # video = latents + + # Offload all models + self.maybe_free_model_hooks() + + # if not return_dict: + # video = torch.from_numpy(video) + + return latents \ No newline at end of file diff --git a/cogvideox_fun/pipeline_cogvideox_inpaint.py b/cogvideox_fun/pipeline_cogvideox_inpaint.py index 5420762..5e56432 100644 --- a/cogvideox_fun/pipeline_cogvideox_inpaint.py +++ b/cogvideox_fun/pipeline_cogvideox_inpaint.py @@ -179,6 +179,17 @@ def resize_mask(mask, latent, process_first_frame_only=True): ) return resized_mask +def add_noise_to_reference_video(image, ratio=None): + if ratio is None: + sigma = torch.normal(mean=-3.0, std=0.5, size=(image.shape[0],)).to(image.device) + sigma = torch.exp(sigma).to(image.dtype) + else: + sigma = torch.ones((image.shape[0],)).to(image.device, image.dtype) * ratio + + image_noise = torch.randn_like(image) * sigma[:, None, None, None, None] + image_noise = torch.where(image==-1, torch.zeros_like(image), image_noise) + image = image + image_noise + return image @dataclass class CogVideoX_Fun_PipelineOutput(BaseOutput): @@ -212,7 +223,7 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline): """ _optional_components = [] - model_cpu_offload_seq = ">vae->transformer->vae" + model_cpu_offload_seq = "vae->transformer->vae" _callback_tensor_inputs = [ "latents", @@ -319,7 +330,7 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline): return outputs def prepare_mask_latents( - self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance, noise_aug_strength ): # resize the mask to latents shape as we concatenate the mask to the latents # we do that before converting to dtype to avoid breaking in case we're using cpu_offload @@ -338,6 +349,8 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline): mask = mask * self.vae.config.scaling_factor if masked_image is not None: + if self.transformer.config.add_noise_in_inpaint_model: + masked_image = add_noise_to_reference_video(masked_image, ratio=noise_aug_strength) masked_image = masked_image.to(device=device, dtype=self.vae.dtype) bs = 1 new_mask_pixel_values = [] @@ -525,6 +538,7 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline): callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 226, strength: float = 1, + noise_aug_strength: float = 0.0563, comfyui_progressbar: bool = False, ) -> Union[CogVideoX_Fun_PipelineOutput, Tuple]: """ @@ -732,6 +746,7 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline): device, generator, do_classifier_free_guidance, + noise_aug_strength=noise_aug_strength, ) mask_latents = resize_mask(1 - mask_condition, masked_video_latents) mask_latents = mask_latents.to(masked_video_latents.device) * self.vae.config.scaling_factor diff --git 
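# --- Editor's illustrative sketch, not part of this patch ---
# What `add_noise_to_reference_video` above does with a fixed `noise_aug_strength`:
# a per-sample sigma scales Gaussian noise, and pixels equal to -1 (the padded /
# fully masked regions) are left untouched. When `ratio` is None, sigma is instead
# drawn per sample from exp(Normal(-3, 0.5)).
import torch

video = torch.rand(1, 3, 9, 32, 32) * 2 - 1            # dummy reference video in [-1, 1]
video[:, :, :, :16] = -1                                # pretend the top half is masked out

sigma = torch.full((video.shape[0],), 0.0563)           # fixed noise_aug_strength
noise = torch.randn_like(video) * sigma[:, None, None, None, None]
noise = torch.where(video == -1, torch.zeros_like(video), noise)
noisy_reference = video + noise
assert torch.all(noisy_reference[:, :, :, :16] == -1)   # masked area stays exactly -1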
a/cogvideox_fun/transformer_3d.py b/cogvideox_fun/transformer_3d.py index b80af91..88c8013 100644 --- a/cogvideox_fun/transformer_3d.py +++ b/cogvideox_fun/transformer_3d.py @@ -277,6 +277,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin): spatial_interpolation_scale: float = 1.875, temporal_interpolation_scale: float = 1.0, use_rotary_positional_embeddings: bool = False, + add_noise_in_inpaint_model: bool = False, ): super().__init__() inner_dim = num_attention_heads * attention_head_dim @@ -452,6 +453,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin): timestep: Union[int, float, torch.LongTensor], timestep_cond: Optional[torch.Tensor] = None, inpaint_latents: Optional[torch.Tensor] = None, + control_latents: Optional[torch.Tensor] = None, image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, return_dict: bool = True, ): @@ -470,6 +472,8 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin): # 2. Patch embedding if inpaint_latents is not None: hidden_states = torch.concat([hidden_states, inpaint_latents], 2) + if control_latents is not None: + hidden_states = torch.concat([hidden_states, control_latents], 2) hidden_states = self.patch_embed(encoder_hidden_states, hidden_states) # 3. Position embedding diff --git a/cogvideox_fun/utils.py b/cogvideox_fun/utils.py index 0b71f55..f790161 100644 --- a/cogvideox_fun/utils.py +++ b/cogvideox_fun/utils.py @@ -159,13 +159,21 @@ def get_image_to_video_latent(validation_image_start, validation_image_end, vide return input_video, input_video_mask, clip_image -def get_video_to_video_latent(input_video_path, video_length, sample_size): +def get_video_to_video_latent(input_video_path, video_length, sample_size, validation_video_mask=None): input_video = input_video_path input_video = torch.from_numpy(np.array(input_video))[:video_length] input_video = input_video.permute([3, 0, 1, 2]).unsqueeze(0) / 255 - input_video_mask = torch.zeros_like(input_video[:, :1]) - input_video_mask[:, :, :] = 255 + if validation_video_mask is not None: + validation_video_mask = Image.open(validation_video_mask).convert('L').resize((sample_size[1], sample_size[0])) + input_video_mask = np.where(np.array(validation_video_mask) < 240, 0, 255) + + input_video_mask = torch.from_numpy(np.array(input_video_mask)).unsqueeze(0).unsqueeze(-1).permute([3, 0, 1, 2]).unsqueeze(0) + input_video_mask = torch.tile(input_video_mask, [1, 1, input_video.size()[2], 1, 1]) + input_video_mask = input_video_mask.to(input_video.device, input_video.dtype) + else: + input_video_mask = torch.zeros_like(input_video[:, :1]) + input_video_mask[:, :, :] = 255 return input_video, input_video_mask, None \ No newline at end of file diff --git a/nodes.py b/nodes.py index 14767b3..c2a13b0 100644 --- a/nodes.py +++ b/nodes.py @@ -52,6 +52,7 @@ from .cogvideox_fun.fun_pab_transformer_3d import CogVideoXTransformer3DModel as from .cogvideox_fun.autoencoder_magvit import AutoencoderKLCogVideoX as AutoencoderKLCogVideoXFun from .cogvideox_fun.utils import get_image_to_video_latent, get_video_to_video_latent, ASPECT_RATIO_512, get_closest_ratio, to_pil from .cogvideox_fun.pipeline_cogvideox_inpaint import CogVideoX_Fun_Pipeline_Inpaint +from .cogvideox_fun.pipeline_cogvideox_control import CogVideoX_Fun_Pipeline_Control from PIL import Image import numpy as np import json @@ -216,6 +217,10 @@ class DownloadAndLoadCogVideoModel: "bertjiazheng/KoolCogVideoX-5b", "kijai/CogVideoX-Fun-2b", "kijai/CogVideoX-Fun-5b", + "alibaba-pai/CogVideoX-Fun-V1.1-2b-InP", + 
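# --- Editor's illustrative sketch, not part of this patch ---
# The new mask path in `get_video_to_video_latent` above: an optional grayscale mask
# image is binarized at a threshold of 240 and tiled across every frame, instead of
# the previous all-255 mask. Dummy array in place of the loaded PIL image:
import numpy as np
import torch

frames = 13
mask_img = np.random.randint(0, 256, size=(384, 672), dtype=np.uint8)   # stand-in for Image.open(...).convert('L')
mask = np.where(mask_img < 240, 0, 255)

mask = torch.from_numpy(mask).unsqueeze(0).unsqueeze(-1).permute([3, 0, 1, 2]).unsqueeze(0)
mask = torch.tile(mask, [1, 1, frames, 1, 1])            # repeat over the frame axis
assert mask.shape == (1, 1, 13, 384, 672)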
"alibaba-pai/CogVideoX-Fun-V1.1-5b-InP", + "alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose", + "alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose", ], ), @@ -246,31 +251,38 @@ class DownloadAndLoadCogVideoModel: mm.soft_empty_cache() dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision] - + download_path = os.path.join(folder_paths.models_dir, "CogVideo") if "Fun" in model: - repo_id = "kijai/CogVideoX-Fun-pruned" - download_path = os.path.join(folder_paths.models_dir, "CogVideo") - if "2b" in model: - base_path = os.path.join(folder_paths.models_dir, "CogVideoX_Fun", "CogVideoX-Fun-2b-InP") # location of the official model - scheduler_path = os.path.join(script_directory, 'configs', 'scheduler_config_2b.json') + if not "1.1" in model: + repo_id = "kijai/CogVideoX-Fun-pruned" + if "2b" in model: + base_path = os.path.join(folder_paths.models_dir, "CogVideoX_Fun", "CogVideoX-Fun-2b-InP") # location of the official model + if not os.path.exists(base_path): + base_path = os.path.join(download_path, "CogVideoX-Fun-2b-InP") + elif "5b" in model: + base_path = os.path.join(folder_paths.models_dir, "CogVideoX_Fun", "CogVideoX-Fun-5b-InP") # location of the official model + if not os.path.exists(base_path): + base_path = os.path.join(download_path, "CogVideoX-Fun-5b-InP") + elif "1.1" in model: + repo_id = model + base_path = os.path.join(folder_paths.models_dir, "CogVideoX_Fun", (model.split("/")[-1])) # location of the official model if not os.path.exists(base_path): - base_path = os.path.join(download_path, "CogVideoX-Fun-2b-InP") - elif "5b" in model: - base_path = os.path.join(folder_paths.models_dir, "CogVideoX_Fun", "CogVideoX-Fun-5b-InP") # location of the official model - scheduler_path = os.path.join(script_directory, 'configs', 'scheduler_config_5b.json') - if not os.path.exists(base_path): - base_path = os.path.join(download_path, "CogVideoX-Fun-5b-InP") - + base_path = os.path.join(download_path, (model.split("/")[-1])) + download_path = base_path + elif "2b" in model: base_path = os.path.join(folder_paths.models_dir, "CogVideo", "CogVideo2B") - scheduler_path = os.path.join(script_directory, 'configs', 'scheduler_config_2b.json') download_path = base_path repo_id = model elif "5b" in model: base_path = os.path.join(folder_paths.models_dir, "CogVideo", (model.split("/")[-1])) - scheduler_path = os.path.join(script_directory, 'configs', 'scheduler_config_5b.json') download_path = base_path repo_id = model + + if "2b" in model: + scheduler_path = os.path.join(script_directory, 'configs', 'scheduler_config_2b.json') + elif "5b" in model: + scheduler_path = os.path.join(script_directory, 'configs', 'scheduler_config_5b.json') if not os.path.exists(base_path): log.info(f"Downloading model to: {base_path}") @@ -323,7 +335,10 @@ class DownloadAndLoadCogVideoModel: # VAE if "Fun" in model: vae = AutoencoderKLCogVideoXFun.from_pretrained(base_path, subfolder="vae").to(dtype).to(offload_device) - pipe = CogVideoX_Fun_Pipeline_Inpaint(vae, transformer, scheduler, pab_config=pab_config) + if "Pose" in model: + pipe = CogVideoX_Fun_Pipeline_Control(vae, transformer, scheduler, pab_config=pab_config) + else: + pipe = CogVideoX_Fun_Pipeline_Inpaint(vae, transformer, scheduler, pab_config=pab_config) else: vae = AutoencoderKLCogVideoX.from_pretrained(base_path, subfolder="vae").to(dtype).to(offload_device) pipe = CogVideoXPipeline(vae, transformer, scheduler, pab_config=pab_config) @@ -369,7 +384,7 @@ class DownloadAndLoadCogVideoGGUFModel: "CogVideoX_5b_GGUF_Q4_0.safetensors", 
"CogVideoX_5b_I2V_GGUF_Q4_0.safetensors", "CogVideoX_5b_fun_GGUF_Q4_0.safetensors", - #"CogVideoX_2b_fun_GGUF_Q4_0.safetensors" + "CogVideoX_5b_fun_1_1_GGUF_Q4_0.safetensors" ], ), "vae_precision": (["fp16", "fp32", "bf16"], {"default": "bf16", "tooltip": "VAE dtype"}), @@ -967,6 +982,7 @@ class CogVideoXFunSampler: "start_img": ("IMAGE",), "end_img": ("IMAGE",), "opt_empty_latent": ("LATENT",), + "noise_aug_strength": ("FLOAT", {"default": 0.0563, "min": 0.0, "max": 1.0, "step": 0.001}), }, } @@ -976,7 +992,7 @@ class CogVideoXFunSampler: CATEGORY = "CogVideoWrapper" def process(self, pipeline, positive, negative, video_length, base_resolution, seed, steps, cfg, scheduler, - start_img=None, end_img=None, opt_empty_latent=None): + start_img=None, end_img=None, opt_empty_latent=None, noise_aug_strength=0.0563): device = mm.get_torch_device() offload_device = mm.unet_offload_device() pipe = pipeline["pipe"] @@ -1034,6 +1050,7 @@ class CogVideoXFunSampler: video = input_video, mask_video = input_video_mask, comfyui_progressbar = True, + noise_aug_strength = noise_aug_strength, ) #if not pipeline["cpu_offloading"]: # pipe.transformer.to(offload_device) @@ -1084,8 +1101,11 @@ class CogVideoXFunVid2VidSampler: } ), "denoise_strength": ("FLOAT", {"default": 0.70, "min": 0.05, "max": 1.00, "step": 0.01}), + }, + "optional":{ "validation_video": ("IMAGE",), - } + "control_video": ("IMAGE",), + }, } RETURN_TYPES = ("COGVIDEOPIPE", "LATENT",) @@ -1093,7 +1113,7 @@ class CogVideoXFunVid2VidSampler: FUNCTION = "process" CATEGORY = "CogVideoWrapper" - def process(self, pipeline, positive, negative, video_length, base_resolution, seed, steps, cfg, denoise_strength, scheduler, validation_video): + def process(self, pipeline, positive, negative, video_length, base_resolution, seed, steps, cfg, denoise_strength, scheduler, validation_video=None, control_video=None): device = mm.get_torch_device() offload_device = mm.unet_offload_device() pipe = pipeline["pipe"] @@ -1109,8 +1129,14 @@ class CogVideoXFunVid2VidSampler: # Count most suitable height and width aspect_ratio_sample_size = {key : [x / 512 * base_resolution for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()} - validation_video = np.array(validation_video.cpu().numpy() * 255, np.uint8) - original_width, original_height = Image.fromarray(validation_video[0]).size + + if validation_video is not None: + validation_video = np.array(validation_video.cpu().numpy() * 255, np.uint8) + original_width, original_height = Image.fromarray(validation_video[0]).size + elif control_video is not None: + control_video = np.array(control_video.cpu().numpy() * 255, np.uint8) + original_width, original_height = Image.fromarray(control_video[0]).size + closest_size, closest_ratio = get_closest_ratio(original_height, original_width, ratios=aspect_ratio_sample_size) height, width = [int(x / 16) * 16 for x in closest_size] @@ -1128,26 +1154,38 @@ class CogVideoXFunVid2VidSampler: autocast_context = torch.autocast(mm.get_autocast_device(device)) if autocastcondition else nullcontext() with autocast_context: video_length = int((video_length - 1) // pipe.vae.config.temporal_compression_ratio * pipe.vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1 - input_video, input_video_mask, clip_image = get_video_to_video_latent(validation_video, video_length=video_length, sample_size=(height, width)) + if validation_video is not None: + input_video, input_video_mask, clip_image = get_video_to_video_latent(validation_video, video_length=video_length, 
sample_size=(height, width)) + elif control_video is not None: + input_video, input_video_mask, clip_image = get_video_to_video_latent(control_video, video_length=video_length, sample_size=(height, width)) # for _lora_path, _lora_weight in zip(cogvideoxfun_model.get("loras", []), cogvideoxfun_model.get("strength_model", [])): # pipeline = merge_lora(pipeline, _lora_path, _lora_weight) - latents = pipe( - prompt_embeds=positive.to(dtype).to(device), - negative_prompt_embeds=negative.to(dtype).to(device), - num_frames = video_length, - height = height, - width = width, - generator = generator, - guidance_scale = cfg, - num_inference_steps = steps, + common_params = { + "prompt_embeds": positive.to(dtype).to(device), + "negative_prompt_embeds": negative.to(dtype).to(device), + "num_frames": video_length, + "height": height, + "width": width, + "generator": generator, + "guidance_scale": cfg, + "num_inference_steps": steps, + "comfyui_progressbar": True, + } - video = input_video, - mask_video = input_video_mask, - strength = float(denoise_strength), - comfyui_progressbar = True, - ) + if control_video is not None: + latents = pipe( + **common_params, + control_video=input_video + ) + else: + latents = pipe( + **common_params, + video=input_video, + mask_video=input_video_mask, + strength=float(denoise_strength) + ) # for _lora_path, _lora_weight in zip(cogvideoxfun_model.get("loras", []), cogvideoxfun_model.get("strength_model", [])): # pipeline = unmerge_lora(pipeline, _lora_path, _lora_weight)
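# --- Editor's illustrative sketch, not part of this patch ---
# Tying the nodes.py changes together: "Pose" checkpoints are loaded into the new
# CogVideoX_Fun_Pipeline_Control, and the vid2vid sampler then forwards either
# `control_video` (control path) or `video` / `mask_video` / `strength` (inpaint path)
# on top of the shared keyword arguments. A condensed dispatch with the pipeline call
# stubbed out:
def run_fun_vid2vid(pipe, common_params, input_video, input_video_mask,
                    use_control=False, denoise_strength=1.0):
    if use_control:
        # control ("Pose") models: the preprocessed video is passed as control_video
        return pipe(**common_params, control_video=input_video)
    # inpaint ("InP") models: video + mask + partial-denoise strength
    return pipe(**common_params, video=input_video,
                mask_video=input_video_mask, strength=float(denoise_strength))

echo_pipe = lambda **kwargs: sorted(kwargs)          # stand-in pipeline: just report the kwargs it received
print(run_fun_vid2vid(echo_pipe, {"num_inference_steps": 50}, "video", "mask", use_control=True))
# -> ['control_video', 'num_inference_steps']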