diff --git a/nodes.py b/nodes.py
index b99adf8..584fc17 100644
--- a/nodes.py
+++ b/nodes.py
@@ -122,21 +122,82 @@ class CogVideoTextEncode:
         embeds = clip.encode_from_tokens(tokens, return_pooled=False, return_dict=False)
         return (embeds, )
-
-class CogVideoSampler:
+
+class CogVideoImageEncode:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": {
             "pipeline": ("COGVIDEOPIPE",),
-            "positive": ("CONDITIONING", ),
-            "negative": ("CONDITIONING", ),
-            "height": ("INT", {"default": 480, "min": 128, "max": 2048, "step": 8}),
-            "width": ("INT", {"default": 720, "min": 128, "max": 2048, "step": 8}),
-            "num_frames": ("INT", {"default": 48, "min": 8, "max": 100, "step": 8}),
-            "fps": ("INT", {"default": 8, "min": 1, "max": 100, "step": 1}),
-            "steps": ("INT", {"default": 25, "min": 1}),
-            "cfg": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 30.0, "step": 0.01}),
-            "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
+            "image": ("IMAGE", ),
+            },
+        }
+
+    RETURN_TYPES = ("LATENT",)
+    RETURN_NAMES = ("samples",)
+    FUNCTION = "encode"
+    CATEGORY = "CogVideoWrapper"
+
+    def encode(self, pipeline, image):
+        device = mm.get_torch_device()
+        offload_device = mm.unet_offload_device()
+        generator = torch.Generator(device=device).manual_seed(0)
+        vae = pipeline["pipe"].vae
+        vae.to(device)
+
+        image = image * 2.0 - 1.0
+        image = image.to(vae.dtype).to(device)
+        image = image.unsqueeze(0).permute(0, 4, 1, 2, 3) # B, C, T, H, W
+        B, C, T, H, W = image.shape
+        chunk_size = 16
+        latents_list = []
+        # Loop through the temporal dimension in chunks of 16
+        for i in range(0, T, chunk_size):
+            # Get the chunk of 16 frames (or remaining frames if less than 16 are left)
+            end_index = min(i + chunk_size, T)
+            image_chunk = image[:, :, i:end_index, :, :] # Shape: [B, C, chunk_size, H, W]
+
+            # Encode the chunk of images
+            latents = vae.encode(image_chunk)
+
+            sample_mode = "sample"
+            if hasattr(latents, "latent_dist") and sample_mode == "sample":
+                latents = latents.latent_dist.sample(generator)
+            elif hasattr(latents, "latent_dist") and sample_mode == "argmax":
+                latents = latents.latent_dist.mode()
+            elif hasattr(latents, "latents"):
+                latents = latents.latents
+
+            latents = vae.config.scaling_factor * latents
+            latents = latents.permute(0, 2, 1, 3, 4) # B, T_chunk, C, H, W
+            latents_list.append(latents)
+
+        # Concatenate all the chunks along the temporal dimension
+        final_latents = torch.cat(latents_list, dim=1)
+        print("final latents: ", final_latents.shape)
+
+        vae.to(offload_device)
+
+        return ({"samples": final_latents}, )
+
+class CogVideoSampler:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "pipeline": ("COGVIDEOPIPE",),
+                "positive": ("CONDITIONING", ),
+                "negative": ("CONDITIONING", ),
+                "height": ("INT", {"default": 480, "min": 128, "max": 2048, "step": 8}),
+                "width": ("INT", {"default": 720, "min": 128, "max": 2048, "step": 8}),
+                "num_frames": ("INT", {"default": 48, "min": 8, "max": 100, "step": 8}),
+                "fps": ("INT", {"default": 8, "min": 1, "max": 100, "step": 1}),
+                "steps": ("INT", {"default": 25, "min": 1}),
+                "cfg": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 30.0, "step": 0.01}),
+                "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
+            },
+            "optional": {
+                "samples": ("LATENT", ),
+                "denoise_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
             }
         }
@@ -145,7 +206,7 @@ class CogVideoSampler:
     FUNCTION = "process"
     CATEGORY = "CogVideoWrapper"
 
-    def process(self, pipeline, positive, negative, fps, steps, cfg, seed, height, width, num_frames):
+    def process(self, pipeline, positive, negative, fps, steps, cfg, seed, height, width, num_frames, samples=None, denoise_strength=1.0):
         mm.soft_empty_cache()
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
@@ -162,6 +223,8 @@ class CogVideoSampler:
             num_frames = num_frames,
             fps = fps,
             guidance_scale=cfg,
+            latents=samples["samples"] if samples is not None else None,
+            denoise_strength=denoise_strength,
             prompt_embeds=positive.to(dtype).to(device),
             negative_prompt_embeds=negative.to(dtype).to(device),
             #negative_prompt_embeds=torch.zeros_like(embeds),
@@ -198,11 +261,15 @@ class CogVideoDecode:
         vae = pipeline["pipe"].vae
         vae.to(device)
 
-        num_frames = pipeline["num_frames"]
-        fps = pipeline["fps"]
+        if "num_frames" in pipeline:
+            num_frames = pipeline["num_frames"]
+            fps = pipeline["fps"]
+
+        else:
+            num_frames = latents.shape[2]
+            fps = 8
         num_seconds = num_frames // fps
 
-        latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
         latents = 1 / vae.config.scaling_factor * latents
@@ -217,6 +284,7 @@ class CogVideoDecode:
         vae.to(offload_device)
 
         frames = torch.cat(frames, dim=2)
+        print(frames.min(), frames.max())
         video = pipeline["pipe"].video_processor.postprocess_video(video=frames, output_type="pt")
         print(video.shape)
         video = video[0].permute(0, 2, 3, 1).cpu().float()
@@ -229,11 +297,13 @@ NODE_CLASS_MAPPINGS = {
     "DownloadAndLoadCogVideoModel": DownloadAndLoadCogVideoModel,
     "CogVideoSampler": CogVideoSampler,
     "CogVideoDecode": CogVideoDecode,
-    "CogVideoTextEncode": CogVideoTextEncode
+    "CogVideoTextEncode": CogVideoTextEncode,
+    "CogVideoImageEncode": CogVideoImageEncode
 }
 NODE_DISPLAY_NAME_MAPPINGS = {
     "DownloadAndLoadCogVideoModel": "(Down)load CogVideo Model",
     "CogVideoSampler": "CogVideo Sampler",
     "CogVideoDecode": "CogVideo Decode",
-    "CogVideoTextEncode": "CogVideo TextEncode"
+    "CogVideoTextEncode": "CogVideo TextEncode",
+    "CogVideoImageEncode": "CogVideo ImageEncode"
 }
\ No newline at end of file
diff --git a/pipeline_cogvideox.py b/pipeline_cogvideox.py
index d9fea75..4383322 100644
--- a/pipeline_cogvideox.py
+++ b/pipeline_cogvideox.py
@@ -18,7 +18,6 @@ from dataclasses import dataclass
 from typing import Callable, Dict, List, Optional, Tuple, Union
 
 import torch
-from transformers import T5EncoderModel, T5Tokenizer
 
 from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
 from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
@@ -165,8 +164,6 @@ class CogVideoXPipeline(DiffusionPipeline):
 
     def __init__(
         self,
-        tokenizer: T5Tokenizer,
-        #text_encoder: T5EncoderModel,
         vae: AutoencoderKLCogVideoX,
         transformer: CogVideoXTransformer3DModel,
         scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
@@ -174,7 +171,7 @@ class CogVideoXPipeline(DiffusionPipeline):
         super().__init__()
 
         self.register_modules(
-            tokenizer=tokenizer, vae=vae, transformer=transformer, scheduler=scheduler
+            vae=vae, transformer=transformer, scheduler=scheduler
         )
         self.vae_scale_factor_spatial = (
             2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
@@ -182,136 +179,11 @@ class CogVideoXPipeline(DiffusionPipeline):
         self.vae_scale_factor_temporal = (
             self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
         )
-        self.tokenizer_max_length = (
-            self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 226
-        )
 
         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
 
-    def _get_t5_prompt_embeds(
-        self,
-        prompt: Union[str, List[str]] = None,
-        num_videos_per_prompt: int = 1,
-        max_sequence_length: int = 226,
-        device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
-    ):
-        device = device or self._execution_device
-
-        prompt = [prompt] if isinstance(prompt, str) else prompt
-        batch_size = len(prompt)
-
-        text_inputs = self.tokenizer(
-            prompt,
-            padding="max_length",
-            max_length=max_sequence_length,
-            truncation=True,
-            add_special_tokens=True,
-            return_tensors="pt",
-        )
-        text_input_ids = text_inputs.input_ids
-        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
-
-        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
-            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
-            logger.warning(
-                "The following part of your input was truncated because `max_sequence_length` is set to "
-                f" {max_sequence_length} tokens: {removed_text}"
-            )
-
-        #prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
-        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
-
-        # duplicate text embeddings for each generation per prompt, using mps friendly method
-        _, seq_len, _ = prompt_embeds.shape
-        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
-        prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
-
-        return prompt_embeds
-
-    def encode_prompt(
-        self,
-        prompt: Union[str, List[str]],
-        negative_prompt: Optional[Union[str, List[str]]] = None,
-        do_classifier_free_guidance: bool = True,
-        num_videos_per_prompt: int = 1,
-        prompt_embeds: Optional[torch.Tensor] = None,
-        negative_prompt_embeds: Optional[torch.Tensor] = None,
-        max_sequence_length: int = 226,
-        device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
-    ):
-        r"""
-        Encodes the prompt into text encoder hidden states.
-
-        Args:
-            prompt (`str` or `List[str]`, *optional*):
-                prompt to be encoded
-            negative_prompt (`str` or `List[str]`, *optional*):
-                The prompt or prompts not to guide the image generation. If not defined, one has to pass
-                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
-                less than `1`).
-            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
-                Whether to use classifier free guidance or not.
-            num_videos_per_prompt (`int`, *optional*, defaults to 1):
-                Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
-            prompt_embeds (`torch.Tensor`, *optional*):
-                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
-                provided, text embeddings will be generated from `prompt` input argument.
-            negative_prompt_embeds (`torch.Tensor`, *optional*):
-                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
-                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
-                argument.
-            device: (`torch.device`, *optional*):
-                torch device
-            dtype: (`torch.dtype`, *optional*):
-                torch dtype
-        """
-        device = device or self._execution_device
-
-        prompt = [prompt] if isinstance(prompt, str) else prompt
-        if prompt is not None:
-            batch_size = len(prompt)
-        else:
-            batch_size = prompt_embeds.shape[0]
-
-        if prompt_embeds is None:
-            prompt_embeds = self._get_t5_prompt_embeds(
-                prompt=prompt,
-                num_videos_per_prompt=num_videos_per_prompt,
-                max_sequence_length=max_sequence_length,
-                device=device,
-                dtype=dtype,
-            )
-
-        if do_classifier_free_guidance and negative_prompt_embeds is None:
-            negative_prompt = negative_prompt or ""
-            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
-
-            if prompt is not None and type(prompt) is not type(negative_prompt):
-                raise TypeError(
-                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
-                    f" {type(prompt)}."
-                )
-            elif batch_size != len(negative_prompt):
-                raise ValueError(
-                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
-                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
-                    " the batch size of `prompt`."
-                )
-
-            negative_prompt_embeds = self._get_t5_prompt_embeds(
-                prompt=negative_prompt,
-                num_videos_per_prompt=num_videos_per_prompt,
-                max_sequence_length=max_sequence_length,
-                device=device,
-                dtype=dtype,
-            )
-
-        return prompt_embeds, negative_prompt_embeds
-
     def prepare_latents(
-        self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
+        self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, timesteps, denoise_strength, num_inference_steps, latents=None,
     ):
         shape = (
             batch_size,
@@ -328,12 +200,27 @@ class CogVideoXPipeline(DiffusionPipeline):
 
         if latents is None:
             latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+            # scale the initial noise by the standard deviation required by the scheduler
+
         else:
             latents = latents.to(device)
+            timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, denoise_strength, device)
+            latent_timestep = timesteps[:1]
+
+            noise = randn_tensor(shape, generator=generator, device=device, dtype=latents.dtype)
+            frames_needed = noise.shape[1]
+            current_frames = latents.shape[1]
+
+            if frames_needed > current_frames:
+                repeat_factor = frames_needed // current_frames
+                additional_frame = torch.randn((latents.size(0), repeat_factor, latents.size(2), latents.size(3), latents.size(4)), dtype=latents.dtype, device=latents.device)
+                latents = torch.cat((latents, additional_frame), dim=1)
+            elif frames_needed < current_frames:
+                latents = latents[:, :frames_needed, :, :, :]
 
-        # scale the initial noise by the standard deviation required by the scheduler
+            latents = self.scheduler.add_noise(latents, noise, latent_timestep)
         latents = latents * self.scheduler.init_noise_sigma
-        return latents
+        return latents, timesteps
 
     def decode_latents(self, latents: torch.Tensor, num_seconds: int):
         latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
@@ -372,10 +259,8 @@ class CogVideoXPipeline(DiffusionPipeline):
     # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
     def check_inputs(
         self,
-        prompt,
         height,
         width,
-        negative_prompt,
         callback_on_step_end_tensor_inputs,
         prompt_embeds=None,
         negative_prompt_embeds=None,
@@ -389,29 +274,6 @@ class CogVideoXPipeline(DiffusionPipeline):
             raise ValueError(
                 f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
             )
-        if prompt is not None and prompt_embeds is not None:
-            raise ValueError(
-                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
-                " only forward one of the two."
-            )
-        elif prompt is None and prompt_embeds is None:
-            raise ValueError(
-                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
-            )
-        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
-            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
-
-        if prompt is not None and negative_prompt_embeds is not None:
-            raise ValueError(
-                f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
-                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
-            )
-
-        if negative_prompt is not None and negative_prompt_embeds is not None:
-            raise ValueError(
-                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
-                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
-            )
 
         if prompt_embeds is not None and negative_prompt_embeds is not None:
@@ -420,6 +282,16 @@ class CogVideoXPipeline(DiffusionPipeline):
                 f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                 f" {negative_prompt_embeds.shape}."
             )
+    def get_timesteps(self, num_inference_steps, strength, device):
+        # get the original timestep using init_timestep
+        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+        t_start = max(num_inference_steps - init_timestep, 0)
+        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+        if hasattr(self.scheduler, "set_begin_index"):
+            self.scheduler.set_begin_index(t_start * self.scheduler.order)
+
+        return timesteps.to(device), num_inference_steps - t_start
 
     @property
     def guidance_scale(self):
@@ -444,8 +316,6 @@ class CogVideoXPipeline(DiffusionPipeline):
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
         self,
-        prompt: Optional[Union[str, List[str]]] = None,
-        negative_prompt: Optional[Union[str, List[str]]] = None,
         height: int = 480,
         width: int = 720,
         num_frames: int = 48,
@@ -453,6 +323,7 @@ class CogVideoXPipeline(DiffusionPipeline):
         num_inference_steps: int = 50,
         timesteps: Optional[List[int]] = None,
         guidance_scale: float = 6,
+        denoise_strength: float = 1.0,
         num_videos_per_prompt: int = 1,
         eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
@@ -553,10 +424,8 @@ class CogVideoXPipeline(DiffusionPipeline):
 
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
-            prompt,
             height,
             width,
-            negative_prompt,
             callback_on_step_end_tensor_inputs,
             prompt_embeds,
             negative_prompt_embeds,
@@ -565,12 +434,8 @@ class CogVideoXPipeline(DiffusionPipeline):
         self._interrupt = False
 
         # 2. Default call parameters
-        if prompt is not None and isinstance(prompt, str):
-            batch_size = 1
-        elif prompt is not None and isinstance(prompt, list):
-            batch_size = len(prompt)
-        else:
-            batch_size = prompt_embeds.shape[0]
+
+        batch_size = prompt_embeds.shape[0]
 
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -587,7 +452,7 @@ class CogVideoXPipeline(DiffusionPipeline):
         # 5. Prepare latents.
         latent_channels = self.transformer.config.in_channels
         num_frames += 1
-        latents = self.prepare_latents(
+        latents, timesteps = self.prepare_latents(
             batch_size * num_videos_per_prompt,
             latent_channels,
             num_frames,
@@ -596,7 +461,10 @@ class CogVideoXPipeline(DiffusionPipeline):
             prompt_embeds.dtype,
             device,
             generator,
-            latents,
+            timesteps,
+            denoise_strength,
+            num_inference_steps,
+            latents
         )
 
         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
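
Note (not part of the patch): a minimal, self-contained sketch of how the new `get_timesteps` helper truncates the schedule for a given `denoise_strength`, assuming a plain 50-step schedule with scheduler order 1; the linspace tensor below is only a stand-in for `scheduler.timesteps`.

import torch

# stand-in for scheduler.timesteps (descending noise levels), assumption only
num_inference_steps = 50
order = 1
timesteps = torch.linspace(999, 0, num_inference_steps).long()

denoise_strength = 0.6
init_timestep = min(int(num_inference_steps * denoise_strength), num_inference_steps)  # 30
t_start = max(num_inference_steps - init_timestep, 0)                                  # 20

truncated = timesteps[t_start * order:]           # keep only the last 30 timesteps
remaining_steps = num_inference_steps - t_start   # 30 denoising steps actually run

print(t_start, remaining_steps, truncated[0].item())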