kijai/ComfyUI-CogVideoXWrapper (mirror: https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git)

Commit e9c6985b7d (parent bbbb6e9514)
Rudimentary inpainting support, fp8_fast mode
fp8_optimization.py (new file, 43 lines)

@@ -0,0 +1,43 @@
import torch
import torch.nn as nn


def fp8_linear_forward(cls, original_dtype, input):
    # Replacement forward for nn.Linear layers whose weights are stored in fp8.
    weight_dtype = cls.weight.dtype
    if weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
        if len(input.shape) == 3:
            # Cast the activations to the fp8 format the weights are not using,
            # then run the matmul through torch._scaled_mm.
            if weight_dtype == torch.float8_e4m3fn:
                inn = input.reshape(-1, input.shape[2]).to(torch.float8_e5m2)
            else:
                inn = input.reshape(-1, input.shape[2]).to(torch.float8_e4m3fn)
            w = cls.weight.t()

            # No calibration: both scale factors are fixed at 1.0.
            scale_weight = torch.ones((1), device=input.device, dtype=torch.float32)
            scale_input = scale_weight

            bias = cls.bias.to(original_dtype) if cls.bias is not None else None
            out_dtype = original_dtype

            if bias is not None:
                o = torch._scaled_mm(inn, w, out_dtype=out_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
            else:
                o = torch._scaled_mm(inn, w, out_dtype=out_dtype, scale_a=scale_input, scale_b=scale_weight)

            # Older PyTorch releases return an (output, amax) tuple from _scaled_mm.
            if isinstance(o, tuple):
                o = o[0]

            return o.reshape((-1, input.shape[1], cls.weight.shape[0]))
        else:
            # Fallback for non-3D inputs: run the original forward in the compute dtype,
            # then restore the fp8 weights.
            cls.to(original_dtype)
            out = cls.original_forward(input.to(original_dtype))
            cls.to(weight_dtype)
            return out
    else:
        return cls.original_forward(input)


def convert_fp8_linear(module, original_dtype):
    # Patch every nn.Linear inside `module` so its forward goes through fp8_linear_forward.
    for name, submodule in module.named_modules():
        if isinstance(submodule, nn.Linear):
            original_forward = submodule.forward
            setattr(submodule, "original_forward", original_forward)
            setattr(submodule, "forward", lambda input, m=submodule: fp8_linear_forward(m, original_dtype, input))
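For reference, a minimal sketch of how the helper is meant to be used: cast a module's weights to fp8, then patch its Linear layers. The toy module, shapes, and import path below are illustrative assumptions rather than the wrapper's API, and torch._scaled_mm only runs on recent NVIDIA GPUs (roughly Ada/Hopper class), so treat this as a sketch, not a portable test.

import torch
import torch.nn as nn
from fp8_optimization import convert_fp8_linear  # the file added in this commit, assumed importable

# Toy stand-in for a transformer block; the real target is the CogVideoX transformer.
model = nn.Sequential(nn.Linear(64, 128), nn.GELU(), nn.Linear(128, 64))

original_dtype = torch.bfloat16
model.to(original_dtype)
model.to(torch.float8_e4m3fn)            # keep the weights stored in fp8
convert_fp8_linear(model, original_dtype)

model.to("cuda")
x = torch.randn(2, 16, 64, dtype=original_dtype, device="cuda")
with torch.no_grad():
    y = model(x)                         # Linear layers now matmul via torch._scaled_mm
print(y.dtype, y.shape)                  # torch.bfloat16, (2, 16, 64)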
nodes.py (30 changed lines)

@@ -31,7 +31,7 @@ class DownloadAndLoadCogVideoModel:
                 "precision": (["fp16", "fp32", "bf16"],
                     {"default": "bf16", "tooltip": "official recommendation is that 2b model should be fp16, 5b model should be bf16"}
                 ),
-                "fp8_transformer": ("BOOLEAN", {"default": False, "tooltip": "cast the transformer to torch.float8_e4m3fn"}),
+                "fp8_transformer": (['disabled', 'enabled', 'fastmode'], {"default": 'disabled', "tooltip": "enabled casts the transformer to torch.float8_e4m3fn, fastmode is only for latest nvidia GPUs"}),
                 "compile": (["disabled","onediff","torch"], {"tooltip": "compile the model for faster inference, these are advanced options only available on Linux, see readme for more info"}),
                 "enable_sequential_cpu_offload": ("BOOLEAN", {"default": False, "tooltip": "significantly reducing memory usage and slows down the inference"}),
             }
@@ -42,13 +42,13 @@ class DownloadAndLoadCogVideoModel:
     FUNCTION = "loadmodel"
     CATEGORY = "CogVideoWrapper"

-    def loadmodel(self, model, precision, fp8_transformer, compile="disabled", enable_sequential_cpu_offload=False):
+    def loadmodel(self, model, precision, fp8_transformer="disabled", compile="disabled", enable_sequential_cpu_offload=False):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
         mm.soft_empty_cache()

         dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
-        if fp8_transformer:
+        if fp8_transformer != "disabled":
             transformer_dtype = torch.float8_e4m3fn
         else:
             transformer_dtype = dtype
@@ -69,6 +69,9 @@ class DownloadAndLoadCogVideoModel:
                 local_dir_use_symlinks=False,
             )
         transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder="transformer").to(transformer_dtype).to(offload_device)
+        if fp8_transformer == "fastmode":
+            from .fp8_optimization import convert_fp8_linear
+            convert_fp8_linear(transformer, dtype)
         vae = AutoencoderKLCogVideoX.from_pretrained(base_path, subfolder="vae").to(dtype).to(offload_device)

         scheduler = CogVideoXDDIMScheduler.from_pretrained(base_path, subfolder="scheduler")
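In other words, "enabled" only changes the storage dtype of the transformer, while "fastmode" additionally rewires its Linear layers through the new helper. A hedged sketch of the same flow outside ComfyUI; the model id, import path, and setup here are assumptions for illustration, not what the node does verbatim:

import torch
from diffusers import CogVideoXTransformer3DModel
from fp8_optimization import convert_fp8_linear  # file added in this commit, assumed importable

dtype = torch.bfloat16                               # compute / output dtype
transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-5b", subfolder="transformer"
).to(torch.float8_e4m3fn)                            # "enabled": fp8 weight storage
convert_fp8_linear(transformer, dtype)               # "fastmode": fp8 matmul via torch._scaled_mm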
@@ -92,6 +95,7 @@ class DownloadAndLoadCogVideoModel:
             fuse_qkv_projections=True,
         )

         pipeline = {
             "pipe": pipe,
@@ -177,6 +181,7 @@ class CogVideoImageEncode:
             "optional": {
                 "chunk_size": ("INT", {"default": 16, "min": 1}),
                 "enable_vae_slicing": ("BOOLEAN", {"default": True, "tooltip": "VAE will split the input tensor in slices to compute decoding in several steps. This is useful to save some memory and allow larger batch sizes."}),
+                "mask": ("MASK", ),
             },
         }
@@ -185,11 +190,15 @@ class CogVideoImageEncode:
     FUNCTION = "encode"
     CATEGORY = "CogVideoWrapper"

-    def encode(self, pipeline, image, chunk_size=8, enable_vae_slicing=True):
+    def encode(self, pipeline, image, chunk_size=8, enable_vae_slicing=True, mask=None):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
         generator = torch.Generator(device=device).manual_seed(0)

         B, H, W, C = image.shape

         vae = pipeline["pipe"].vae

         if enable_vae_slicing:
             vae.enable_slicing()
         else:
@@ -197,8 +206,17 @@ class CogVideoImageEncode:

         if not pipeline["cpu_offloading"]:
             vae.to(device)

-        input_image = image.clone() * 2.0 - 1.0
+        input_image = image.clone()
+        if mask is not None:
+            pipeline["pipe"].original_mask = mask
+            # print(mask.shape)
+            # mask = mask.repeat(B, 1, 1) # Shape: [B, H, W]
+            # mask = mask.unsqueeze(-1).repeat(1, 1, 1, C)
+            # print(mask.shape)
+            # input_image = input_image * (1 -mask)
+
+        input_image = input_image * 2.0 - 1.0
         input_image = input_image.to(vae.dtype).to(device)
         input_image = input_image.unsqueeze(0).permute(0, 4, 1, 2, 3) # B, C, T, H, W
         B, C, T, H, W = input_image.shape
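The encode node only stashes the optional mask on the pipeline object; the actual masking happens later in the sampling loop. A small sketch of what the new "mask" socket expects: a ComfyUI MASK is a float tensor in [0, 1] with shape (frames or 1, H, W). The rectangle below is purely illustrative.

import torch

H, W = 480, 720
mask = torch.zeros(1, H, W)          # single-frame mask, broadcast over all frames later
mask[:, 120:360, 180:540] = 1.0      # region where the mask is fully "on"

# What CogVideoImageEncode now does with it:
# pipeline["pipe"].original_mask = mask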
pipeline_cogvideox.py

@@ -17,6 +17,7 @@ import inspect
 from typing import Callable, Dict, List, Optional, Tuple, Union

 import torch
+import torch.nn.functional as F
 import math

 from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
@@ -138,6 +139,7 @@ class CogVideoXPipeline(DiffusionPipeline):
         vae: AutoencoderKLCogVideoX,
         transformer: CogVideoXTransformer3DModel,
         scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
+        original_mask = None
     ):
         super().__init__()
@@ -150,7 +152,7 @@ class CogVideoXPipeline(DiffusionPipeline):
         self.vae_scale_factor_temporal = (
             self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
         )

+        self.original_mask = original_mask
         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)

     def prepare_latents(
@@ -168,9 +170,9 @@ class CogVideoXPipeline(DiffusionPipeline):
                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
             )

+        noise = randn_tensor(shape, generator=generator, device=device, dtype=self.vae.dtype)
         if latents is None:
-            latents = randn_tensor(shape, generator=generator, device=device, dtype=self.vae.dtype)
+            latents = noise
         else:
             latents = latents.to(device)
             timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, denoise_strength, device)
@@ -189,7 +191,7 @@ class CogVideoXPipeline(DiffusionPipeline):

             latents = self.scheduler.add_noise(latents, noise, latent_timestep)
         latents = latents * self.scheduler.init_noise_sigma # scale the initial noise by the standard deviation required by the scheduler
-        return latents, timesteps
+        return latents, timesteps, noise

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
     def prepare_extra_step_kwargs(self, generator, eta):
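prepare_latents now also returns the Gaussian noise it drew, so the sampling loop can later re-noise the clean image latents to any intermediate timestep with the same noise sample. A standalone sketch of that idea, using a plain diffusers DDIMScheduler and made-up shapes (illustrative only, not the wrapper's exact objects):

import torch
from diffusers import DDIMScheduler

scheduler = DDIMScheduler(num_train_timesteps=1000)

image_latents = torch.randn(1, 13, 16, 60, 90)   # pretend: clean latents of the input image
noise = torch.randn_like(image_latents)          # pretend: the noise returned by prepare_latents

# Recreate "the input image noised to exactly timestep t" on demand,
# reusing the same noise draw every time.
t = torch.tensor([400])
renoised = scheduler.add_noise(image_latents, noise, t)
print(renoised.shape)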
@@ -419,8 +421,9 @@ class CogVideoXPipeline(DiffusionPipeline):

         if latents is None and num_frames == t_tile_length:
             num_frames += 1

-        latents, timesteps = self.prepare_latents(
+        image_latents = latents
+        original_image_latents = image_latents
+        latents, timesteps, noise = self.prepare_latents(
             batch_size * num_videos_per_prompt,
             latent_channels,
             num_frames,
@@ -435,6 +438,7 @@
             latents
         )
+        latents = latents.to(self.transformer.dtype)

         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
@@ -451,10 +455,25 @@ class CogVideoXPipeline(DiffusionPipeline):
             else None
         )

+        # masks
+        if self.original_mask is not None:
+            mask = self.original_mask.to(device)
+            print("self.original_mask: ", self.original_mask.shape)
+
+            mask = F.interpolate(self.original_mask.unsqueeze(1), size=(latents.shape[-2], latents.shape[-1]), mode='bilinear', align_corners=False)
+            if mask.shape[0] != latents.shape[1]:
+                mask = mask.unsqueeze(1).repeat(1, latents.shape[1], 16, 1, 1)
+            else:
+                mask = mask.unsqueeze(0).repeat(1, 1, 16, 1, 1)
+            print("latents: ", latents.shape)
+            print("mask: ", mask.shape)
+
         # 7. Denoising loop
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
         comfy_pbar = ProgressBar(num_inference_steps)

         with self.progress_bar(total=num_inference_steps) as progress_bar:

             # for DPM-solver++
             old_pred_original_sample = None
             for i, t in enumerate(timesteps):
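The mask is resized to the latent resolution and then repeated across the latent frame and channel axes so it can be multiplied directly against latents of shape (batch, frames, channels, height, width). A shape walk-through under assumed sizes (a single-frame mask, 16 latent channels, 8x spatial compression; all values are illustrative):

import torch
import torch.nn.functional as F

original_mask = (torch.rand(1, 480, 720) > 0.5).float()    # (frames_or_1, H, W)
latents = torch.randn(1, 13, 16, 60, 90)                   # (B, F, C, h, w)

mask = F.interpolate(original_mask.unsqueeze(1),            # (1, 1, 480, 720)
                     size=(latents.shape[-2], latents.shape[-1]),
                     mode='bilinear', align_corners=False)   # (1, 1, 60, 90)

if mask.shape[0] != latents.shape[1]:
    mask = mask.unsqueeze(1).repeat(1, latents.shape[1], 16, 1, 1)   # (1, 13, 16, 60, 90)
else:
    mask = mask.unsqueeze(0).repeat(1, 1, 16, 1, 1)

assert mask.shape == latents.shape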
@@ -535,6 +554,19 @@
                     latents_all /= contributors

                     latents = latents_all
                     #print("latents",latents.shape)
+                    # start diff diff
+                    if i < len(timesteps) - 1 and self.original_mask is not None:
+                        noise_timestep = timesteps[i + 1]
+                        image_latent = self.scheduler.add_noise(original_image_latents, noise, torch.tensor([noise_timestep])
+                        )
+                        mask = mask.to(latents)
+                        ts_from = timesteps[0]
+                        ts_to = timesteps[-1]
+                        threshold = (t - ts_to) / (ts_from - ts_to)
+                        mask = torch.where(mask >= threshold, mask, torch.zeros_like(mask))
+                        latents = image_latent * mask + latents * (1 - mask)
+                    # end diff diff

                     if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                         progress_bar.update()
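The "diff diff" blocks follow the differential-diffusion idea: at every step the latents are blended back toward the input-image latents re-noised to the next timestep (latents = image_latent * mask + latents * (1 - mask)), but the mask is first hard-thresholded with a value that falls linearly from 1 at the first timestep to 0 at the last. So fully masked regions track the input image for the whole schedule, while weaker mask values only start being pulled back later on. A standalone numeric sketch with made-up values:

import torch

timesteps = torch.tensor([999., 800., 600., 400., 200., 0.])   # illustrative schedule
ts_from, ts_to = timesteps[0], timesteps[-1]

mask = torch.tensor([0.0, 0.25, 0.75, 1.0])                     # per-pixel mask strengths

for t in timesteps:
    threshold = (t - ts_to) / (ts_from - ts_to)                 # 1.0 -> 0.0 over the schedule
    active = torch.where(mask >= threshold, mask, torch.zeros_like(mask))
    print(f"t={int(t):4d}  threshold={threshold.item():.2f}  blend weights={active.tolist()}")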
@@ -577,7 +609,18 @@
                         **extra_step_kwargs,
                         return_dict=False,
                     )
                     latents = latents.to(prompt_embeds.dtype)
+                    # start diff diff
+                    if i < len(timesteps) - 1 and self.original_mask is not None:
+                        noise_timestep = timesteps[i + 1]
+                        image_latent = self.scheduler.add_noise(original_image_latents, noise, torch.tensor([noise_timestep])
+                        )
+                        mask = mask.to(latents)
+                        ts_from = timesteps[0]
+                        ts_to = timesteps[-1]
+                        threshold = (t - ts_to) / (ts_from - ts_to)
+                        mask = torch.where(mask >= threshold, mask, torch.zeros_like(mask))
+                        latents = image_latent * mask + latents * (1 - mask)
+                    # end diff diff

                     if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                         progress_bar.update()