From 7b80e61e36f7c5dead1cef3adfe6717e91f3ae9d Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Tue, 27 Aug 2024 17:06:04 +0300
Subject: [PATCH] initial 5B support

---
 nodes.py              |  40 ++++---
 pipeline_cogvideox.py | 235 ++++++++++++++++++++++++++++++------------
 requirements.txt      |   2 +-
 3 files changed, 193 insertions(+), 84 deletions(-)

diff --git a/nodes.py b/nodes.py
index 7bd8186..4ca561a 100644
--- a/nodes.py
+++ b/nodes.py
@@ -16,6 +16,12 @@ class DownloadAndLoadCogVideoModel:
     def INPUT_TYPES(s):
         return {
             "required": {
+                "model": (
+                    [
+                        "THUDM/CogVideoX-2b",
+                        "THUDM/CogVideoX-5b",
+                    ],
+                ),

             },
             "optional": {
@@ -35,21 +41,24 @@ class DownloadAndLoadCogVideoModel:

     FUNCTION = "loadmodel"
     CATEGORY = "CogVideoWrapper"

-    def loadmodel(self, precision):
+    def loadmodel(self, model, precision):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
         mm.soft_empty_cache()

         dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]

-        base_path = os.path.join(folder_paths.models_dir, "CogVideo", "CogVideo2B")
+        if "2b" in model:
+            base_path = os.path.join(folder_paths.models_dir, "CogVideo", "CogVideo2B")
+        elif "5b" in model:
+            base_path = os.path.join(folder_paths.models_dir, "CogVideo", "CogVideoX-5b")
         if not os.path.exists(base_path):
             log.info(f"Downloading model to: {base_path}")
             from huggingface_hub import snapshot_download

             snapshot_download(
-                repo_id="THUDM/CogVideoX-2b",
+                repo_id=model,
                 ignore_patterns=["*text_encoder*"],
                 local_dir=base_path,
                 local_dir_use_symlinks=False,
@@ -199,14 +208,14 @@ class CogVideoSampler:
                 "negative": ("CONDITIONING", ),
                 "height": ("INT", {"default": 480, "min": 128, "max": 2048, "step": 8}),
                 "width": ("INT", {"default": 720, "min": 128, "max": 2048, "step": 8}),
-                "num_frames": ("INT", {"default": 48, "min": 8, "max": 1024, "step": 8}),
+                "num_frames": ("INT", {"default": 48, "min": 8, "max": 1024, "step": 1}),
                 "fps": ("INT", {"default": 8, "min": 1, "max": 100, "step": 1}),
                 "steps": ("INT", {"default": 25, "min": 1}),
                 "cfg": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 30.0, "step": 0.01}),
                 "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
                 "scheduler": (["DDIM", "DPM"],),
-                "t_tile_length": ("INT", {"default": 16, "min": 16, "max": 128, "step": 4}),
-                "t_tile_overlap": ("INT", {"default": 8, "min": 8, "max": 128, "step": 2}),
+                "t_tile_length": ("INT", {"default": 16, "min": 2, "max": 128, "step": 1}),
+                "t_tile_overlap": ("INT", {"default": 8, "min": 2, "max": 128, "step": 1}),
             },
             "optional": {
                 "samples": ("LATENT", ),
@@ -276,10 +285,10 @@ class CogVideoDecode:

     RETURN_TYPES = ("IMAGE",)
     RETURN_NAMES = ("images",)
-    FUNCTION = "process"
+    FUNCTION = "decode"
     CATEGORY = "CogVideoWrapper"

-    def process(self, pipeline, samples):
+    def decode(self, pipeline, samples):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
         latents = samples["samples"]
@@ -299,19 +308,20 @@ class CogVideoDecode:

         frames = []
         pbar = ProgressBar(num_seconds)
-        for i in range(num_seconds):
-            start_frame, end_frame = (0, 3) if i == 0 else (2 * i + 1, 2 * i + 3)
-            current_frames = vae.decode(latents[:, :, start_frame:end_frame]).sample
-            frames.append(current_frames)
+        # for i in range(num_seconds):
+        #     start_frame, end_frame = (0, 3) if i == 0 else (2 * i + 1, 2 * i + 3)
+        #     current_frames = vae.decode(latents[:, :, start_frame:end_frame]).sample
+        #     frames.append(current_frames)

-            pbar.update(1)
-        vae.clear_fake_context_parallel_cache()
+        #     pbar.update(1)
+        frames = vae.decode(latents).sample
         vae.to(offload_device)
         mm.soft_empty_cache()
-        frames = torch.cat(frames, dim=2)
+        #frames = torch.cat(frames, dim=2)

         video = pipeline["pipe"].video_processor.postprocess_video(video=frames, output_type="pt")
         video = video[0].permute(0, 2, 3, 1).cpu().float()
+        print(video.min(), video.max())

         return (video,)
diff --git a/pipeline_cogvideox.py b/pipeline_cogvideox.py
index b36846a..edceb4c 100644
--- a/pipeline_cogvideox.py
+++ b/pipeline_cogvideox.py
@@ -17,6 +17,7 @@ import inspect
 from typing import Callable, Dict, List, Optional, Tuple, Union

 import torch
+import math

 from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
 from diffusers.pipelines.pipeline_utils import DiffusionPipeline
@@ -24,11 +25,29 @@ from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
 from diffusers.utils import logging
 from diffusers.utils.torch_utils import randn_tensor
 from diffusers.video_processor import VideoProcessor
+from diffusers.models.embeddings import get_3d_rotary_pos_embed

 from comfy.utils import ProgressBar

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

+def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
+    tw = tgt_width
+    th = tgt_height
+    h, w = src
+    r = h / w
+    if r > (th / tw):
+        resize_height = th
+        resize_width = int(round(th / h * w))
+    else:
+        resize_width = tw
+        resize_height = int(round(tw / w * h))
+
+    crop_top = int(round((th - resize_height) / 2.0))
+    crop_left = int(round((tw - resize_width) / 2.0))
+
+    return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
+
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
 def retrieve_timesteps(
     scheduler,
@@ -228,6 +247,46 @@ class CogVideoXPipeline(DiffusionPipeline):
         weights = torch.tensor(t_probs)
         weights = weights.unsqueeze(0).unsqueeze(2).unsqueeze(3).unsqueeze(4).repeat(1, t_batch_size,1, 1, 1)
         return weights
+
+    def fuse_qkv_projections(self) -> None:
+        r"""Enables fused QKV projections."""
+        self.fusing_transformer = True
+        self.transformer.fuse_qkv_projections()
+
+    def unfuse_qkv_projections(self) -> None:
+        r"""Disable QKV projection fusion if enabled."""
+        if not self.fusing_transformer:
+            logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
+        else:
+            self.transformer.unfuse_qkv_projections()
+            self.fusing_transformer = False
+
+    def _prepare_rotary_positional_embeddings(
+        self,
+        height: int,
+        width: int,
+        num_frames: int,
+        device: torch.device,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
+        grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
+        base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
+        base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
+
+        grid_crops_coords = get_resize_crop_region_for_grid(
+            (grid_height, grid_width), base_size_width, base_size_height
+        )
+        freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
+            embed_dim=self.transformer.config.attention_head_dim,
+            crops_coords=grid_crops_coords,
+            grid_size=(grid_height, grid_width),
+            temporal_size=num_frames,
+            use_real=True,
+        )
+
+        freqs_cos = freqs_cos.to(device=device)
+        freqs_sin = freqs_sin.to(device=device)
+        return freqs_cos, freqs_sin

     @property
     def guidance_scale(self):
@@ -374,6 +433,15 @@ class CogVideoXPipeline(DiffusionPipeline):
         t_tile_weights = self._gaussian_weights(t_tile_length=t_tile_length, t_batch_size=1).to(latents.device).to(latents.dtype)
         print("latents.shape", latents.shape)
         print("latents.device", latents.device)
+
+
+        # 6.5. Create rotary embeds if required
+        image_rotary_emb = (
+            self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
+            if self.transformer.config.use_rotary_positional_embeddings
+            else None
+        )
+
         # 7. Denoising loop
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
         comfy_pbar = ProgressBar(num_inference_steps)
@@ -383,94 +451,125 @@ class CogVideoXPipeline(DiffusionPipeline):
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue
+                if not isinstance(self.scheduler, CogVideoXDPMScheduler):
+                    #temporal tiling code based on https://github.com/mayuelala/FollowYourEmoji/blob/main/models/video_pipeline.py
+                    # =====================================================
+                    grid_ts = 0
+                    cur_t = 0
+                    while cur_t < latents.shape[1]:
+                        cur_t = max(grid_ts * t_tile_length - t_tile_overlap * grid_ts, 0) + t_tile_length
+                        grid_ts += 1
+
+                    all_t = latents.shape[1]
+                    latents_all_list = []
+                    # =====================================================
+
+                    for t_i in range(grid_ts):
+                        if t_i < grid_ts - 1:
+                            ofs_t = max(t_i * t_tile_length - t_tile_overlap * t_i, 0)
+                        if t_i == grid_ts - 1:
+                            ofs_t = all_t - t_tile_length
+
+                        input_start_t = ofs_t
+                        input_end_t = ofs_t + t_tile_length
+
+                        #latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+                        #latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                        latents_tile = latents[:, input_start_t:input_end_t,:, :, :]
+                        latent_model_input_tile = torch.cat([latents_tile] * 2) if do_classifier_free_guidance else latents_tile
+                        latent_model_input_tile = self.scheduler.scale_model_input(latent_model_input_tile, t)
+
+                        #t_input = t[None].to(device)
+                        t_input = t.expand(latent_model_input_tile.shape[0])  # broadcast to batch dimension in a way that's compatible with ONNX/Core ML

-                #temporal tiling code based on https://github.com/mayuelala/FollowYourEmoji/blob/main/models/video_pipeline.py
-                # =====================================================
-                grid_ts = 0
-                cur_t = 0
-                while cur_t < latents.shape[1]:
-                    cur_t = max(grid_ts * t_tile_length - t_tile_overlap * grid_ts, 0) + t_tile_length
-                    grid_ts += 1

+                        # predict noise model_output
+                        noise_pred = self.transformer(
+                            hidden_states=latent_model_input_tile,
+                            encoder_hidden_states=prompt_embeds,
+                            timestep=t_input,
+                            image_rotary_emb=image_rotary_emb,
+                            return_dict=False,
+                        )[0]
+                        noise_pred = noise_pred.float()

-                all_t = latents.shape[1]
-                latents_all_list = []
-                # =====================================================

+                        if self.do_classifier_free_guidance:
+                            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                            noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

-                for t_i in range(grid_ts):
-                    if t_i < grid_ts - 1:
-                        ofs_t = max(t_i * t_tile_length - t_tile_overlap * t_i, 0)
-                    if t_i == grid_ts - 1:
-                        ofs_t = all_t - t_tile_length

+                        # compute the previous noisy sample x_t -> x_t-1
+                        latents_tile = self.scheduler.step(noise_pred, t, latents_tile, **extra_step_kwargs, return_dict=False)[0]
+                        latents_all_list.append(latents_tile)

-                    input_start_t = ofs_t
-                    input_end_t = ofs_t + t_tile_length

+                    # ==========================================
+                    latents_all = torch.zeros(latents.shape, device=latents.device, dtype=latents.dtype)
+                    contributors = torch.zeros(latents.shape, device=latents.device, dtype=latents.dtype)
+                    # Add each tile contribution to overall latents
+                    for t_i in range(grid_ts):
+                        if t_i < grid_ts - 1:
+                            ofs_t = max(t_i * t_tile_length - t_tile_overlap * t_i, 0)
+                        if t_i == grid_ts - 1:
+                            ofs_t = all_t - t_tile_length

-                    #latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
-                    #latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

+                        input_start_t = ofs_t
+                        input_end_t = ofs_t + t_tile_length

-                    latents_tile = latents[:, input_start_t:input_end_t,:, :, :]
-                    latent_model_input_tile = torch.cat([latents_tile] * 2) if do_classifier_free_guidance else latents_tile
-                    latent_model_input_tile = self.scheduler.scale_model_input(latent_model_input_tile, t)

+                        latents_all[:, input_start_t:input_end_t,:, :, :] += latents_all_list[t_i] * t_tile_weights
+                        contributors[:, input_start_t:input_end_t,:, :, :] += t_tile_weights
+
+                    latents_all /= contributors
+
+                    latents = latents_all
+
+                    if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                        progress_bar.update()
+                        comfy_pbar.update(1)
+                    # ==========================================
+                else:
+                    latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                    # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+                    timestep = t.expand(latent_model_input.shape[0])

-                    #t_input = t[None].to(device)
-                    t_input = t.expand(latent_model_input_tile.shape[0])  # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-                    # predict noise model_output
                     noise_pred = self.transformer(
-                        hidden_states=latent_model_input_tile,
+                        hidden_states=latent_model_input,
                         encoder_hidden_states=prompt_embeds,
-                        timestep=t_input,
+                        timestep=timestep,
+                        image_rotary_emb=image_rotary_emb,
                         return_dict=False,
                     )[0]
-                    noise_pred = noise_pred.float()
+                    noise_pred = noise_pred.float()

-                    if self.do_classifier_free_guidance:
+
+                    self._guidance_scale = 1 + guidance_scale * (
+                        (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
+                    )
+
+                    if do_classifier_free_guidance:
                         noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                         noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                     # compute the previous noisy sample x_t -> x_t-1
                     if not isinstance(self.scheduler, CogVideoXDPMScheduler):
-                        latents_tile = self.scheduler.step(noise_pred, t, latents_tile, **extra_step_kwargs, return_dict=False)[0]
+                        latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
                     else:
-                        raise NotImplementedError("DPM is not supported with temporal tiling")
-                    # else:
-                    #     latents_tile, old_pred_original_sample = self.scheduler.step(
-                    #         noise_pred,
-                    #         old_pred_original_sample,
-                    #         t,
-                    #         t_input[t_i - 1] if t_i > 0 else None,
-                    #         latents_tile,
-                    #         **extra_step_kwargs,
-                    #         return_dict=False,
-                    #     )
-
-                    latents_all_list.append(latents_tile)
+                        latents, old_pred_original_sample = self.scheduler.step(
+                            noise_pred,
+                            old_pred_original_sample,
+                            t,
+                            timesteps[i - 1] if i > 0 else None,
+                            latents,
+                            **extra_step_kwargs,
+                            return_dict=False,
+                        )
+                    latents = latents.to(prompt_embeds.dtype)

-                # ==========================================
-                latents_all = torch.zeros(latents.shape, device=latents.device, dtype=latents.dtype)
-                contributors = torch.zeros(latents.shape, device=latents.device, dtype=latents.dtype)
-                # Add each tile contribution to overall latents
-                for t_i in range(grid_ts):
-                    if t_i < grid_ts - 1:
-                        ofs_t = max(t_i * t_tile_length - t_tile_overlap * t_i, 0)
-                    if t_i == grid_ts - 1:
-                        ofs_t = all_t - t_tile_length
-
-                    input_start_t = ofs_t
-                    input_end_t = ofs_t + t_tile_length
-
-                    latents_all[:, input_start_t:input_end_t,:, :, :] += latents_all_list[t_i] * t_tile_weights
-                    contributors[:, input_start_t:input_end_t,:, :, :] += t_tile_weights
-
-                latents_all /= contributors
-
-                latents = latents_all
-                # ==========================================
-
-
-                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
-                    progress_bar.update()
-                    comfy_pbar.update(1)
+                    if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                        progress_bar.update()
+                        comfy_pbar.update(1)
+
         # Offload all models
         self.maybe_free_model_hooks()
diff --git a/requirements.txt b/requirements.txt
index a4e18bd..db806f5 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,2 +1,2 @@
 huggingface_hub
-diffusers>=0.30.0
\ No newline at end of file
+diffusers>=0.30.1
\ No newline at end of file