From e9c6985b7d9d6fd8aacdb05a37abf2cc40c208d3 Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Sun, 1 Sep 2024 19:48:29 +0300
Subject: [PATCH] Rudimentary inpainting support, fp8_fast mode

---
 fp8_optimization.py   | 43 ++++++++++++++++++++++++++++++++
 nodes.py              | 30 ++++++++++++++++++-----
 pipeline_cogvideox.py | 57 +++++++++++++++++++++++++++++++++++++------
 3 files changed, 117 insertions(+), 13 deletions(-)
 create mode 100644 fp8_optimization.py

diff --git a/fp8_optimization.py b/fp8_optimization.py
new file mode 100644
index 0000000..1562bf5
--- /dev/null
+++ b/fp8_optimization.py
@@ -0,0 +1,43 @@
+import torch
+import torch.nn as nn
+from types import MethodType
+
+def fp8_linear_forward(cls, original_dtype, input):
+    weight_dtype = cls.weight.dtype
+    if weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
+        if len(input.shape) == 3:
+            if weight_dtype == torch.float8_e4m3fn:
+                inn = input.reshape(-1, input.shape[2]).to(torch.float8_e5m2)
+            else:
+                inn = input.reshape(-1, input.shape[2]).to(torch.float8_e4m3fn)
+            w = cls.weight.t()
+
+            scale_weight = torch.ones((1), device=input.device, dtype=torch.float32)
+            scale_input = scale_weight
+
+            bias = cls.bias.to(original_dtype) if cls.bias is not None else None
+            out_dtype = original_dtype
+
+            if bias is not None:
+                o = torch._scaled_mm(inn, w, out_dtype=out_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
+            else:
+                o = torch._scaled_mm(inn, w, out_dtype=out_dtype, scale_a=scale_input, scale_b=scale_weight)
+
+            if isinstance(o, tuple):
+                o = o[0]
+
+            return o.reshape((-1, input.shape[1], cls.weight.shape[0]))
+        else:
+            cls.to(original_dtype)
+            out = cls.original_forward(input.to(original_dtype))
+            cls.to(original_dtype)
+            return out
+    else:
+        return cls.original_forward(input)
+
+def convert_fp8_linear(module, original_dtype):
+    for name, module in module.named_modules():
+        if isinstance(module, nn.Linear):
+            original_forward = module.forward
+            setattr(module, "original_forward", original_forward)
+            setattr(module, "forward", lambda input, m=module: fp8_linear_forward(m, original_dtype, input))
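The new module swaps every nn.Linear forward for a wrapper that routes the matmul through torch._scaled_mm whenever the weights are stored in fp8. A minimal standalone sketch of exercising it (not part of the patch), assuming a recent PyTorch build that ships torch._scaled_mm and an fp8-capable NVIDIA GPU (RTX 40 series / Hopper or newer); the layer sizes are placeholders, and inside ComfyUI the module is imported as .fp8_optimization:

```python
# Hedged sketch: exercises convert_fp8_linear the same way the loader's "fastmode"
# path does (cast the weights to fp8 first, then patch the linear forwards).
import torch
import torch.nn as nn
from fp8_optimization import convert_fp8_linear  # placeholder import path

model = nn.Sequential(nn.Linear(64, 128), nn.GELU(), nn.Linear(128, 64)).cuda()
model.to(torch.float8_e4m3fn)              # weights stored in fp8, mirroring the transformer cast
convert_fp8_linear(model, torch.bfloat16)  # patched forwards now run 3-D inputs through torch._scaled_mm

x = torch.randn(2, 16, 64, device="cuda", dtype=torch.bfloat16)  # [batch, tokens, channels]
with torch.no_grad():
    out = model(x)
print(out.shape, out.dtype)  # torch.Size([2, 16, 64]) torch.bfloat16
```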
diff --git a/nodes.py b/nodes.py
index a1f1e3a..5da4dac 100644
--- a/nodes.py
+++ b/nodes.py
@@ -31,7 +31,7 @@ class DownloadAndLoadCogVideoModel:
                 "precision": (["fp16", "fp32", "bf16"],
                     {"default": "bf16", "tooltip": "official recommendation is that 2b model should be fp16, 5b model should be bf16"}
                 ),
-                "fp8_transformer": ("BOOLEAN", {"default": False, "tooltip": "cast the transformer to torch.float8_e4m3fn"}),
+                "fp8_transformer": (['disabled', 'enabled', 'fastmode'], {"default": 'disabled', "tooltip": "'enabled' casts the transformer to torch.float8_e4m3fn, 'fastmode' is only for the latest NVIDIA GPUs"}),
                 "compile": (["disabled","onediff","torch"], {"tooltip": "compile the model for faster inference, these are advanced options only available on Linux, see readme for more info"}),
                 "enable_sequential_cpu_offload": ("BOOLEAN", {"default": False, "tooltip": "significantly reducing memory usage and slows down the inference"}),
             }
@@ -42,13 +42,13 @@ class DownloadAndLoadCogVideoModel:
     FUNCTION = "loadmodel"
     CATEGORY = "CogVideoWrapper"
 
-    def loadmodel(self, model, precision, fp8_transformer, compile="disabled", enable_sequential_cpu_offload=False):
+    def loadmodel(self, model, precision, fp8_transformer="disabled", compile="disabled", enable_sequential_cpu_offload=False):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
         mm.soft_empty_cache()
 
         dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
-        if fp8_transformer:
+        if fp8_transformer != "disabled":
             transformer_dtype = torch.float8_e4m3fn
         else:
             transformer_dtype = dtype
@@ -69,6 +69,9 @@ class DownloadAndLoadCogVideoModel:
                 local_dir_use_symlinks=False,
             )
         transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder="transformer").to(transformer_dtype).to(offload_device)
+        if fp8_transformer == "fastmode":
+            from .fp8_optimization import convert_fp8_linear
+            convert_fp8_linear(transformer, dtype)
 
         vae = AutoencoderKLCogVideoX.from_pretrained(base_path, subfolder="vae").to(dtype).to(offload_device)
         scheduler = CogVideoXDDIMScheduler.from_pretrained(base_path, subfolder="scheduler")
@@ -92,6 +95,7 @@ class DownloadAndLoadCogVideoModel:
                 fuse_qkv_projections=True,
             )
 
+
 
         pipeline = {
             "pipe": pipe,
@@ -177,6 +181,7 @@ class CogVideoImageEncode:
             "optional": {
                 "chunk_size": ("INT", {"default": 16, "min": 1}),
                 "enable_vae_slicing": ("BOOLEAN", {"default": True, "tooltip": "VAE will split the input tensor in slices to compute decoding in several steps. This is useful to save some memory and allow larger batch sizes."}),
+                "mask": ("MASK", ),
             },
         }
 
@@ -185,11 +190,15 @@ class CogVideoImageEncode:
     FUNCTION = "encode"
     CATEGORY = "CogVideoWrapper"
 
-    def encode(self, pipeline, image, chunk_size=8, enable_vae_slicing=True):
+    def encode(self, pipeline, image, chunk_size=8, enable_vae_slicing=True, mask=None):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
         generator = torch.Generator(device=device).manual_seed(0)
+
+        B, H, W, C = image.shape
+        vae = pipeline["pipe"].vae
+
 
         if enable_vae_slicing:
             vae.enable_slicing()
         else:
@@ -197,8 +206,17 @@ class CogVideoImageEncode:
 
         if not pipeline["cpu_offloading"]:
             vae.to(device)
-
-        input_image = image.clone() * 2.0 - 1.0
+
+        input_image = image.clone()
+        if mask is not None:
+            pipeline["pipe"].original_mask = mask
+            # print(mask.shape)
+            # mask = mask.repeat(B, 1, 1) # Shape: [B, H, W]
+            # mask = mask.unsqueeze(-1).repeat(1, 1, 1, C)
+            # print(mask.shape)
+            # input_image = input_image * (1 -mask)
+
+        input_image = input_image * 2.0 - 1.0
         input_image = input_image.to(vae.dtype).to(device)
         input_image = input_image.unsqueeze(0).permute(0, 4, 1, 2, 3) # B, C, T, H, W
         B, C, T, H, W = input_image.shape
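CogVideoImageEncode now also accepts a ComfyUI MASK, which is stashed on the pipeline as original_mask rather than applied to the pixels at encode time. A short shape walk-through under an assumed 480x720 example (IMAGE is [B, H, W, C] in 0..1, MASK is [B, H, W]):

```python
# Illustrative only; the resolution and mask region are made-up example values.
import torch

image = torch.rand(1, 480, 720, 3)   # ComfyUI IMAGE: [B, H, W, C], values in 0..1
mask = torch.zeros(1, 480, 720)      # ComfyUI MASK:  [B, H, W]
mask[:, 120:360, 180:540] = 1.0

input_image = image.clone() * 2.0 - 1.0                        # VAE expects [-1, 1]
input_image = input_image.unsqueeze(0).permute(0, 4, 1, 2, 3)  # -> [B, C, T, H, W]
print(input_image.shape)  # torch.Size([1, 3, 1, 480, 720])
# The mask itself is only stored as pipe.original_mask here; it is resized to the
# latent grid later, inside the denoising loop (see the pipeline_cogvideox.py diff below).
```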
diff --git a/pipeline_cogvideox.py b/pipeline_cogvideox.py
index fbd3927..b19c06c 100644
--- a/pipeline_cogvideox.py
+++ b/pipeline_cogvideox.py
@@ -17,6 +17,7 @@ import inspect
 from typing import Callable, Dict, List, Optional, Tuple, Union
 
 import torch
+import torch.nn.functional as F
 import math
 
 from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
@@ -138,6 +139,7 @@ class CogVideoXPipeline(DiffusionPipeline):
         vae: AutoencoderKLCogVideoX,
         transformer: CogVideoXTransformer3DModel,
         scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
+        original_mask = None
     ):
         super().__init__()
 
@@ -150,7 +152,7 @@ class CogVideoXPipeline(DiffusionPipeline):
         self.vae_scale_factor_temporal = (
             self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
         )
-
+        self.original_mask = original_mask
         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
 
     def prepare_latents(
@@ -168,9 +170,9 @@
                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
             )
-
+        noise = randn_tensor(shape, generator=generator, device=device, dtype=self.vae.dtype)
         if latents is None:
-            latents = randn_tensor(shape, generator=generator, device=device, dtype=self.vae.dtype)
+            latents = noise
         else:
             latents = latents.to(device)
             timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, denoise_strength, device)
@@ -189,7 +191,7 @@
             latents = self.scheduler.add_noise(latents, noise, latent_timestep)
 
         latents = latents * self.scheduler.init_noise_sigma  # scale the initial noise by the standard deviation required by the scheduler
-        return latents, timesteps
+        return latents, timesteps, noise
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
     def prepare_extra_step_kwargs(self, generator, eta):
@@ -419,8 +421,9 @@
 
         if latents is None and num_frames == t_tile_length:
             num_frames += 1
-
-        latents, timesteps = self.prepare_latents(
+        image_latents = latents
+        original_image_latents = image_latents
+        latents, timesteps, noise = self.prepare_latents(
             batch_size * num_videos_per_prompt,
             latent_channels,
             num_frames,
@@ -435,6 +438,7 @@
             latents
         )
         latents = latents.to(self.transformer.dtype)
+
 
         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
@@ -451,10 +455,25 @@
             else None
         )
 
+        # masks
+        if self.original_mask is not None:
+            mask = self.original_mask.to(device)
+            print("self.original_mask: ", self.original_mask.shape)
+
+            mask = F.interpolate(self.original_mask.unsqueeze(1), size=(latents.shape[-2], latents.shape[-1]), mode='bilinear', align_corners=False)
+            if mask.shape[0] != latents.shape[1]:
+                mask = mask.unsqueeze(1).repeat(1, latents.shape[1], 16, 1, 1)
+            else:
+                mask = mask.unsqueeze(0).repeat(1, 1, 16, 1, 1)
+            print("latents: ", latents.shape)
+            print("mask: ", mask.shape)
+
         # 7. Denoising loop
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
         comfy_pbar = ProgressBar(num_inference_steps)
+
         with self.progress_bar(total=num_inference_steps) as progress_bar:
+
             # for DPM-solver++
             old_pred_original_sample = None
             for i, t in enumerate(timesteps):
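To make the interpolate/repeat logic above concrete, here is a hedged restatement with assumed dimensions: a 480x720, 49-frame generation gives latents of shape [1, 13, 16, 60, 90] under CogVideoX's 8x spatial and 4x temporal compression.

```python
# Standalone restatement of the mask preparation above; all sizes are example values.
import torch
import torch.nn.functional as F

latents = torch.randn(1, 13, 16, 60, 90)   # [B, F, C, H/8, W/8]
original_mask = torch.rand(1, 480, 720)    # [B, H, W], as stored by CogVideoImageEncode

mask = F.interpolate(original_mask.unsqueeze(1),
                     size=(latents.shape[-2], latents.shape[-1]),
                     mode='bilinear', align_corners=False)          # [1, 1, 60, 90]
if mask.shape[0] != latents.shape[1]:
    mask = mask.unsqueeze(1).repeat(1, latents.shape[1], 16, 1, 1)  # [1, 13, 16, 60, 90]
else:
    mask = mask.unsqueeze(0).repeat(1, 1, 16, 1, 1)
print(mask.shape)  # torch.Size([1, 13, 16, 60, 90]) -- same shape as the latents
```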
@@ -535,6 +554,19 @@
                     latents_all /= contributors
 
                     latents = latents_all
+                    #print("latents",latents.shape)
+                    # start diff diff
+                    if i < len(timesteps) - 1 and self.original_mask is not None:
+                        noise_timestep = timesteps[i + 1]
+                        image_latent = self.scheduler.add_noise(original_image_latents, noise, torch.tensor([noise_timestep])
+                        )
+                        mask = mask.to(latents)
+                        ts_from = timesteps[0]
+                        ts_to = timesteps[-1]
+                        threshold = (t - ts_to) / (ts_from - ts_to)
+                        mask = torch.where(mask >= threshold, mask, torch.zeros_like(mask))
+                        latents = image_latent * mask + latents * (1 - mask)
+                    # end diff diff
 
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
@@ -577,7 +609,18 @@
                         **extra_step_kwargs,
                         return_dict=False,
                     )
-                latents = latents.to(prompt_embeds.dtype)
+                # start diff diff
+                if i < len(timesteps) - 1 and self.original_mask is not None:
+                    noise_timestep = timesteps[i + 1]
+                    image_latent = self.scheduler.add_noise(original_image_latents, noise, torch.tensor([noise_timestep])
+                    )
+                    mask = mask.to(latents)
+                    ts_from = timesteps[0]
+                    ts_to = timesteps[-1]
+                    threshold = (t - ts_to) / (ts_from - ts_to)
+                    mask = torch.where(mask >= threshold, mask, torch.zeros_like(mask))
+                    latents = image_latent * mask + latents * (1 - mask)
+                # end diff diff
 
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
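Both sampling branches now finish each step with the same "diff diff" blending: wherever the mask value passes a per-step threshold, the latents are reset to a re-noised copy of the encoded input; everywhere else the freshly denoised result is kept, and the thresholded mask is carried into the next step. A hedged, self-contained restatement (the helper name composite_step is illustrative, not part of the wrapper):

```python
import torch

def composite_step(latents, original_image_latents, noise, mask, scheduler, timesteps, i, t):
    # Skip on the final step: the denoised latents are returned untouched.
    if i >= len(timesteps) - 1:
        return latents, mask
    # Re-noise the encoded input latents to the noise level of the *next* timestep.
    noise_timestep = timesteps[i + 1]
    image_latent = scheduler.add_noise(original_image_latents, noise, torch.tensor([noise_timestep]))
    mask = mask.to(latents)
    # The threshold decays from 1.0 at the first step to 0.0 at the last;
    # mask values below it are zeroed before blending.
    ts_from, ts_to = timesteps[0], timesteps[-1]
    threshold = (t - ts_to) / (ts_from - ts_to)
    mask = torch.where(mask >= threshold, mask, torch.zeros_like(mask))
    latents = image_latent * mask + latents * (1 - mask)
    return latents, mask
```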