Rudimentary inpainting support, fp8_fast mode

This commit is contained in:
kijai 2024-09-01 19:48:29 +03:00
parent bbbb6e9514
commit e9c6985b7d
3 changed files with 117 additions and 13 deletions

43
fp8_optimization.py Normal file
View File

@ -0,0 +1,43 @@
import torch
import torch.nn as nn
from types import MethodType
def fp8_linear_forward(cls, original_dtype, input):
weight_dtype = cls.weight.dtype
if weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
if len(input.shape) == 3:
if weight_dtype == torch.float8_e4m3fn:
inn = input.reshape(-1, input.shape[2]).to(torch.float8_e5m2)
else:
inn = input.reshape(-1, input.shape[2]).to(torch.float8_e4m3fn)
w = cls.weight.t()
scale_weight = torch.ones((1), device=input.device, dtype=torch.float32)
scale_input = scale_weight
bias = cls.bias.to(original_dtype) if cls.bias is not None else None
out_dtype = original_dtype
if bias is not None:
o = torch._scaled_mm(inn, w, out_dtype=out_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
else:
o = torch._scaled_mm(inn, w, out_dtype=out_dtype, scale_a=scale_input, scale_b=scale_weight)
if isinstance(o, tuple):
o = o[0]
return o.reshape((-1, input.shape[1], cls.weight.shape[0]))
else:
cls.to(original_dtype)
out = cls.original_forward(input.to(original_dtype))
cls.to(original_dtype)
return out
else:
return cls.original_forward(input)
def convert_fp8_linear(module, original_dtype):
for name, module in module.named_modules():
if isinstance(module, nn.Linear):
original_forward = module.forward
setattr(module, "original_forward", original_forward)
setattr(module, "forward", lambda input, m=module: fp8_linear_forward(m, original_dtype, input))

View File

@ -31,7 +31,7 @@ class DownloadAndLoadCogVideoModel:
"precision": (["fp16", "fp32", "bf16"],
{"default": "bf16", "tooltip": "official recommendation is that 2b model should be fp16, 5b model should be bf16"}
),
"fp8_transformer": ("BOOLEAN", {"default": False, "tooltip": "cast the transformer to torch.float8_e4m3fn"}),
"fp8_transformer": (['disabled', 'enabled', 'fastmode'], {"default": 'disabled', "tooltip": "enabled casts the transformer to torch.float8_e4m3fn, fastmode is only for latest nvidia GPUs"}),
"compile": (["disabled","onediff","torch"], {"tooltip": "compile the model for faster inference, these are advanced options only available on Linux, see readme for more info"}),
"enable_sequential_cpu_offload": ("BOOLEAN", {"default": False, "tooltip": "significantly reducing memory usage and slows down the inference"}),
}
@ -42,13 +42,13 @@ class DownloadAndLoadCogVideoModel:
FUNCTION = "loadmodel"
CATEGORY = "CogVideoWrapper"
def loadmodel(self, model, precision, fp8_transformer, compile="disabled", enable_sequential_cpu_offload=False):
def loadmodel(self, model, precision, fp8_transformer="disabled", compile="disabled", enable_sequential_cpu_offload=False):
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
mm.soft_empty_cache()
dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
if fp8_transformer:
if fp8_transformer != "disabled":
transformer_dtype = torch.float8_e4m3fn
else:
transformer_dtype = dtype
@ -69,6 +69,9 @@ class DownloadAndLoadCogVideoModel:
local_dir_use_symlinks=False,
)
transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder="transformer").to(transformer_dtype).to(offload_device)
if fp8_transformer == "fastmode":
from .fp8_optimization import convert_fp8_linear
convert_fp8_linear(transformer, dtype)
vae = AutoencoderKLCogVideoX.from_pretrained(base_path, subfolder="vae").to(dtype).to(offload_device)
scheduler = CogVideoXDDIMScheduler.from_pretrained(base_path, subfolder="scheduler")
@ -92,6 +95,7 @@ class DownloadAndLoadCogVideoModel:
fuse_qkv_projections=True,
)
pipeline = {
"pipe": pipe,
@ -177,6 +181,7 @@ class CogVideoImageEncode:
"optional": {
"chunk_size": ("INT", {"default": 16, "min": 1}),
"enable_vae_slicing": ("BOOLEAN", {"default": True, "tooltip": "VAE will split the input tensor in slices to compute decoding in several steps. This is useful to save some memory and allow larger batch sizes."}),
"mask": ("MASK", ),
},
}
@ -185,11 +190,15 @@ class CogVideoImageEncode:
FUNCTION = "encode"
CATEGORY = "CogVideoWrapper"
def encode(self, pipeline, image, chunk_size=8, enable_vae_slicing=True):
def encode(self, pipeline, image, chunk_size=8, enable_vae_slicing=True, mask=None):
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
generator = torch.Generator(device=device).manual_seed(0)
B, H, W, C = image.shape
vae = pipeline["pipe"].vae
if enable_vae_slicing:
vae.enable_slicing()
else:
@ -197,8 +206,17 @@ class CogVideoImageEncode:
if not pipeline["cpu_offloading"]:
vae.to(device)
input_image = image.clone() * 2.0 - 1.0
input_image = image.clone()
if mask is not None:
pipeline["pipe"].original_mask = mask
# print(mask.shape)
# mask = mask.repeat(B, 1, 1) # Shape: [B, H, W]
# mask = mask.unsqueeze(-1).repeat(1, 1, 1, C)
# print(mask.shape)
# input_image = input_image * (1 -mask)
input_image = input_image * 2.0 - 1.0
input_image = input_image.to(vae.dtype).to(device)
input_image = input_image.unsqueeze(0).permute(0, 4, 1, 2, 3) # B, C, T, H, W
B, C, T, H, W = input_image.shape

View File

@ -17,6 +17,7 @@ import inspect
from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import math
from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
@ -138,6 +139,7 @@ class CogVideoXPipeline(DiffusionPipeline):
vae: AutoencoderKLCogVideoX,
transformer: CogVideoXTransformer3DModel,
scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
original_mask = None
):
super().__init__()
@ -150,7 +152,7 @@ class CogVideoXPipeline(DiffusionPipeline):
self.vae_scale_factor_temporal = (
self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
)
self.original_mask = original_mask
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
def prepare_latents(
@ -168,9 +170,9 @@ class CogVideoXPipeline(DiffusionPipeline):
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
noise = randn_tensor(shape, generator=generator, device=device, dtype=self.vae.dtype)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=self.vae.dtype)
latents = noise
else:
latents = latents.to(device)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, denoise_strength, device)
@ -189,7 +191,7 @@ class CogVideoXPipeline(DiffusionPipeline):
latents = self.scheduler.add_noise(latents, noise, latent_timestep)
latents = latents * self.scheduler.init_noise_sigma # scale the initial noise by the standard deviation required by the scheduler
return latents, timesteps
return latents, timesteps, noise
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
@ -419,8 +421,9 @@ class CogVideoXPipeline(DiffusionPipeline):
if latents is None and num_frames == t_tile_length:
num_frames += 1
latents, timesteps = self.prepare_latents(
image_latents = latents
original_image_latents = image_latents
latents, timesteps, noise = self.prepare_latents(
batch_size * num_videos_per_prompt,
latent_channels,
num_frames,
@ -435,6 +438,7 @@ class CogVideoXPipeline(DiffusionPipeline):
latents
)
latents = latents.to(self.transformer.dtype)
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
@ -451,10 +455,25 @@ class CogVideoXPipeline(DiffusionPipeline):
else None
)
# masks
if self.original_mask is not None:
mask = self.original_mask.to(device)
print("self.original_mask: ", self.original_mask.shape)
mask = F.interpolate(self.original_mask.unsqueeze(1), size=(latents.shape[-2], latents.shape[-1]), mode='bilinear', align_corners=False)
if mask.shape[0] != latents.shape[1]:
mask = mask.unsqueeze(1).repeat(1, latents.shape[1], 16, 1, 1)
else:
mask = mask.unsqueeze(0).repeat(1, 1, 16, 1, 1)
print("latents: ", latents.shape)
print("mask: ", mask.shape)
# 7. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
comfy_pbar = ProgressBar(num_inference_steps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
# for DPM-solver++
old_pred_original_sample = None
for i, t in enumerate(timesteps):
@ -535,6 +554,19 @@ class CogVideoXPipeline(DiffusionPipeline):
latents_all /= contributors
latents = latents_all
#print("latents",latents.shape)
# start diff diff
if i < len(timesteps) - 1 and self.original_mask is not None:
noise_timestep = timesteps[i + 1]
image_latent = self.scheduler.add_noise(original_image_latents, noise, torch.tensor([noise_timestep])
)
mask = mask.to(latents)
ts_from = timesteps[0]
ts_to = timesteps[-1]
threshold = (t - ts_to) / (ts_from - ts_to)
mask = torch.where(mask >= threshold, mask, torch.zeros_like(mask))
latents = image_latent * mask + latents * (1 - mask)
# end diff diff
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
@ -577,7 +609,18 @@ class CogVideoXPipeline(DiffusionPipeline):
**extra_step_kwargs,
return_dict=False,
)
latents = latents.to(prompt_embeds.dtype)
# start diff diff
if i < len(timesteps) - 1 and self.original_mask is not None:
noise_timestep = timesteps[i + 1]
image_latent = self.scheduler.add_noise(original_image_latents, noise, torch.tensor([noise_timestep])
)
mask = mask.to(latents)
ts_from = timesteps[0]
ts_to = timesteps[-1]
threshold = (t - ts_to) / (ts_from - ts_to)
mask = torch.where(mask >= threshold, mask, torch.zeros_like(mask))
latents = image_latent * mask + latents * (1 - mask)
# end diff diff
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()