mirror of
https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
synced 2026-05-02 03:12:20 +08:00
Rudimentary inpainting support, fp8_fast mode
This commit is contained in:
parent
bbbb6e9514
commit
e9c6985b7d
43
fp8_optimization.py
Normal file
43
fp8_optimization.py
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from types import MethodType
|
||||||
|
|
||||||
|
def fp8_linear_forward(cls, original_dtype, input):
|
||||||
|
weight_dtype = cls.weight.dtype
|
||||||
|
if weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
|
||||||
|
if len(input.shape) == 3:
|
||||||
|
if weight_dtype == torch.float8_e4m3fn:
|
||||||
|
inn = input.reshape(-1, input.shape[2]).to(torch.float8_e5m2)
|
||||||
|
else:
|
||||||
|
inn = input.reshape(-1, input.shape[2]).to(torch.float8_e4m3fn)
|
||||||
|
w = cls.weight.t()
|
||||||
|
|
||||||
|
scale_weight = torch.ones((1), device=input.device, dtype=torch.float32)
|
||||||
|
scale_input = scale_weight
|
||||||
|
|
||||||
|
bias = cls.bias.to(original_dtype) if cls.bias is not None else None
|
||||||
|
out_dtype = original_dtype
|
||||||
|
|
||||||
|
if bias is not None:
|
||||||
|
o = torch._scaled_mm(inn, w, out_dtype=out_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
|
||||||
|
else:
|
||||||
|
o = torch._scaled_mm(inn, w, out_dtype=out_dtype, scale_a=scale_input, scale_b=scale_weight)
|
||||||
|
|
||||||
|
if isinstance(o, tuple):
|
||||||
|
o = o[0]
|
||||||
|
|
||||||
|
return o.reshape((-1, input.shape[1], cls.weight.shape[0]))
|
||||||
|
else:
|
||||||
|
cls.to(original_dtype)
|
||||||
|
out = cls.original_forward(input.to(original_dtype))
|
||||||
|
cls.to(original_dtype)
|
||||||
|
return out
|
||||||
|
else:
|
||||||
|
return cls.original_forward(input)
|
||||||
|
|
||||||
|
def convert_fp8_linear(module, original_dtype):
    """Patch every nn.Linear inside ``module`` to route through fp8_linear_forward.

    Each Linear keeps its original bound forward under ``original_forward`` so
    the patched forward can fall back to it for non-fp8 weights or unsupported
    input shapes.

    Args:
        module: root module whose Linear submodules are patched in place.
        original_dtype: dtype forwarded to fp8_linear_forward for the bias,
            output, and fallback path.
    """
    # Renamed loop variable: the original shadowed the `module` parameter,
    # which only worked because named_modules() binds its receiver once.
    for _, submodule in module.named_modules():
        if isinstance(submodule, nn.Linear):
            submodule.original_forward = submodule.forward
            # Bind the layer via a default argument so each lambda captures
            # its own submodule instead of the loop's last value.
            submodule.forward = lambda input, m=submodule: fp8_linear_forward(m, original_dtype, input)
|
||||||
30
nodes.py
30
nodes.py
@ -31,7 +31,7 @@ class DownloadAndLoadCogVideoModel:
|
|||||||
"precision": (["fp16", "fp32", "bf16"],
|
"precision": (["fp16", "fp32", "bf16"],
|
||||||
{"default": "bf16", "tooltip": "official recommendation is that 2b model should be fp16, 5b model should be bf16"}
|
{"default": "bf16", "tooltip": "official recommendation is that 2b model should be fp16, 5b model should be bf16"}
|
||||||
),
|
),
|
||||||
"fp8_transformer": ("BOOLEAN", {"default": False, "tooltip": "cast the transformer to torch.float8_e4m3fn"}),
|
"fp8_transformer": (['disabled', 'enabled', 'fastmode'], {"default": 'disabled', "tooltip": "enabled casts the transformer to torch.float8_e4m3fn, fastmode is only for latest nvidia GPUs"}),
|
||||||
"compile": (["disabled","onediff","torch"], {"tooltip": "compile the model for faster inference, these are advanced options only available on Linux, see readme for more info"}),
|
"compile": (["disabled","onediff","torch"], {"tooltip": "compile the model for faster inference, these are advanced options only available on Linux, see readme for more info"}),
|
||||||
"enable_sequential_cpu_offload": ("BOOLEAN", {"default": False, "tooltip": "significantly reducing memory usage and slows down the inference"}),
|
"enable_sequential_cpu_offload": ("BOOLEAN", {"default": False, "tooltip": "significantly reducing memory usage and slows down the inference"}),
|
||||||
}
|
}
|
||||||
@ -42,13 +42,13 @@ class DownloadAndLoadCogVideoModel:
|
|||||||
FUNCTION = "loadmodel"
|
FUNCTION = "loadmodel"
|
||||||
CATEGORY = "CogVideoWrapper"
|
CATEGORY = "CogVideoWrapper"
|
||||||
|
|
||||||
def loadmodel(self, model, precision, fp8_transformer, compile="disabled", enable_sequential_cpu_offload=False):
|
def loadmodel(self, model, precision, fp8_transformer="disabled", compile="disabled", enable_sequential_cpu_offload=False):
|
||||||
device = mm.get_torch_device()
|
device = mm.get_torch_device()
|
||||||
offload_device = mm.unet_offload_device()
|
offload_device = mm.unet_offload_device()
|
||||||
mm.soft_empty_cache()
|
mm.soft_empty_cache()
|
||||||
|
|
||||||
dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
|
dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
|
||||||
if fp8_transformer:
|
if fp8_transformer != "disabled":
|
||||||
transformer_dtype = torch.float8_e4m3fn
|
transformer_dtype = torch.float8_e4m3fn
|
||||||
else:
|
else:
|
||||||
transformer_dtype = dtype
|
transformer_dtype = dtype
|
||||||
@ -69,6 +69,9 @@ class DownloadAndLoadCogVideoModel:
|
|||||||
local_dir_use_symlinks=False,
|
local_dir_use_symlinks=False,
|
||||||
)
|
)
|
||||||
transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder="transformer").to(transformer_dtype).to(offload_device)
|
transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder="transformer").to(transformer_dtype).to(offload_device)
|
||||||
|
if fp8_transformer == "fastmode":
|
||||||
|
from .fp8_optimization import convert_fp8_linear
|
||||||
|
convert_fp8_linear(transformer, dtype)
|
||||||
vae = AutoencoderKLCogVideoX.from_pretrained(base_path, subfolder="vae").to(dtype).to(offload_device)
|
vae = AutoencoderKLCogVideoX.from_pretrained(base_path, subfolder="vae").to(dtype).to(offload_device)
|
||||||
|
|
||||||
scheduler = CogVideoXDDIMScheduler.from_pretrained(base_path, subfolder="scheduler")
|
scheduler = CogVideoXDDIMScheduler.from_pretrained(base_path, subfolder="scheduler")
|
||||||
@ -92,6 +95,7 @@ class DownloadAndLoadCogVideoModel:
|
|||||||
fuse_qkv_projections=True,
|
fuse_qkv_projections=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
pipeline = {
|
pipeline = {
|
||||||
"pipe": pipe,
|
"pipe": pipe,
|
||||||
@ -177,6 +181,7 @@ class CogVideoImageEncode:
|
|||||||
"optional": {
|
"optional": {
|
||||||
"chunk_size": ("INT", {"default": 16, "min": 1}),
|
"chunk_size": ("INT", {"default": 16, "min": 1}),
|
||||||
"enable_vae_slicing": ("BOOLEAN", {"default": True, "tooltip": "VAE will split the input tensor in slices to compute decoding in several steps. This is useful to save some memory and allow larger batch sizes."}),
|
"enable_vae_slicing": ("BOOLEAN", {"default": True, "tooltip": "VAE will split the input tensor in slices to compute decoding in several steps. This is useful to save some memory and allow larger batch sizes."}),
|
||||||
|
"mask": ("MASK", ),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -185,11 +190,15 @@ class CogVideoImageEncode:
|
|||||||
FUNCTION = "encode"
|
FUNCTION = "encode"
|
||||||
CATEGORY = "CogVideoWrapper"
|
CATEGORY = "CogVideoWrapper"
|
||||||
|
|
||||||
def encode(self, pipeline, image, chunk_size=8, enable_vae_slicing=True):
|
def encode(self, pipeline, image, chunk_size=8, enable_vae_slicing=True, mask=None):
|
||||||
device = mm.get_torch_device()
|
device = mm.get_torch_device()
|
||||||
offload_device = mm.unet_offload_device()
|
offload_device = mm.unet_offload_device()
|
||||||
generator = torch.Generator(device=device).manual_seed(0)
|
generator = torch.Generator(device=device).manual_seed(0)
|
||||||
|
|
||||||
|
B, H, W, C = image.shape
|
||||||
|
|
||||||
vae = pipeline["pipe"].vae
|
vae = pipeline["pipe"].vae
|
||||||
|
|
||||||
if enable_vae_slicing:
|
if enable_vae_slicing:
|
||||||
vae.enable_slicing()
|
vae.enable_slicing()
|
||||||
else:
|
else:
|
||||||
@ -197,8 +206,17 @@ class CogVideoImageEncode:
|
|||||||
|
|
||||||
if not pipeline["cpu_offloading"]:
|
if not pipeline["cpu_offloading"]:
|
||||||
vae.to(device)
|
vae.to(device)
|
||||||
|
|
||||||
input_image = image.clone() * 2.0 - 1.0
|
input_image = image.clone()
|
||||||
|
if mask is not None:
|
||||||
|
pipeline["pipe"].original_mask = mask
|
||||||
|
# print(mask.shape)
|
||||||
|
# mask = mask.repeat(B, 1, 1) # Shape: [B, H, W]
|
||||||
|
# mask = mask.unsqueeze(-1).repeat(1, 1, 1, C)
|
||||||
|
# print(mask.shape)
|
||||||
|
# input_image = input_image * (1 -mask)
|
||||||
|
|
||||||
|
input_image = input_image * 2.0 - 1.0
|
||||||
input_image = input_image.to(vae.dtype).to(device)
|
input_image = input_image.to(vae.dtype).to(device)
|
||||||
input_image = input_image.unsqueeze(0).permute(0, 4, 1, 2, 3) # B, C, T, H, W
|
input_image = input_image.unsqueeze(0).permute(0, 4, 1, 2, 3) # B, C, T, H, W
|
||||||
B, C, T, H, W = input_image.shape
|
B, C, T, H, W = input_image.shape
|
||||||
|
|||||||
@ -17,6 +17,7 @@ import inspect
|
|||||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
import math
|
import math
|
||||||
|
|
||||||
from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
|
from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
|
||||||
@ -138,6 +139,7 @@ class CogVideoXPipeline(DiffusionPipeline):
|
|||||||
vae: AutoencoderKLCogVideoX,
|
vae: AutoencoderKLCogVideoX,
|
||||||
transformer: CogVideoXTransformer3DModel,
|
transformer: CogVideoXTransformer3DModel,
|
||||||
scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
|
scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
|
||||||
|
original_mask = None
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@ -150,7 +152,7 @@ class CogVideoXPipeline(DiffusionPipeline):
|
|||||||
self.vae_scale_factor_temporal = (
|
self.vae_scale_factor_temporal = (
|
||||||
self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
|
self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
|
||||||
)
|
)
|
||||||
|
self.original_mask = original_mask
|
||||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
||||||
|
|
||||||
def prepare_latents(
|
def prepare_latents(
|
||||||
@ -168,9 +170,9 @@ class CogVideoXPipeline(DiffusionPipeline):
|
|||||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||||
)
|
)
|
||||||
|
noise = randn_tensor(shape, generator=generator, device=device, dtype=self.vae.dtype)
|
||||||
if latents is None:
|
if latents is None:
|
||||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=self.vae.dtype)
|
latents = noise
|
||||||
else:
|
else:
|
||||||
latents = latents.to(device)
|
latents = latents.to(device)
|
||||||
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, denoise_strength, device)
|
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, denoise_strength, device)
|
||||||
@ -189,7 +191,7 @@ class CogVideoXPipeline(DiffusionPipeline):
|
|||||||
|
|
||||||
latents = self.scheduler.add_noise(latents, noise, latent_timestep)
|
latents = self.scheduler.add_noise(latents, noise, latent_timestep)
|
||||||
latents = latents * self.scheduler.init_noise_sigma # scale the initial noise by the standard deviation required by the scheduler
|
latents = latents * self.scheduler.init_noise_sigma # scale the initial noise by the standard deviation required by the scheduler
|
||||||
return latents, timesteps
|
return latents, timesteps, noise
|
||||||
|
|
||||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||||
def prepare_extra_step_kwargs(self, generator, eta):
|
def prepare_extra_step_kwargs(self, generator, eta):
|
||||||
@ -419,8 +421,9 @@ class CogVideoXPipeline(DiffusionPipeline):
|
|||||||
|
|
||||||
if latents is None and num_frames == t_tile_length:
|
if latents is None and num_frames == t_tile_length:
|
||||||
num_frames += 1
|
num_frames += 1
|
||||||
|
image_latents = latents
|
||||||
latents, timesteps = self.prepare_latents(
|
original_image_latents = image_latents
|
||||||
|
latents, timesteps, noise = self.prepare_latents(
|
||||||
batch_size * num_videos_per_prompt,
|
batch_size * num_videos_per_prompt,
|
||||||
latent_channels,
|
latent_channels,
|
||||||
num_frames,
|
num_frames,
|
||||||
@ -435,6 +438,7 @@ class CogVideoXPipeline(DiffusionPipeline):
|
|||||||
latents
|
latents
|
||||||
)
|
)
|
||||||
latents = latents.to(self.transformer.dtype)
|
latents = latents.to(self.transformer.dtype)
|
||||||
|
|
||||||
|
|
||||||
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||||
@ -451,10 +455,25 @@ class CogVideoXPipeline(DiffusionPipeline):
|
|||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# masks
|
||||||
|
if self.original_mask is not None:
|
||||||
|
mask = self.original_mask.to(device)
|
||||||
|
print("self.original_mask: ", self.original_mask.shape)
|
||||||
|
|
||||||
|
mask = F.interpolate(self.original_mask.unsqueeze(1), size=(latents.shape[-2], latents.shape[-1]), mode='bilinear', align_corners=False)
|
||||||
|
if mask.shape[0] != latents.shape[1]:
|
||||||
|
mask = mask.unsqueeze(1).repeat(1, latents.shape[1], 16, 1, 1)
|
||||||
|
else:
|
||||||
|
mask = mask.unsqueeze(0).repeat(1, 1, 16, 1, 1)
|
||||||
|
print("latents: ", latents.shape)
|
||||||
|
print("mask: ", mask.shape)
|
||||||
|
|
||||||
# 7. Denoising loop
|
# 7. Denoising loop
|
||||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||||
comfy_pbar = ProgressBar(num_inference_steps)
|
comfy_pbar = ProgressBar(num_inference_steps)
|
||||||
|
|
||||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||||
|
|
||||||
# for DPM-solver++
|
# for DPM-solver++
|
||||||
old_pred_original_sample = None
|
old_pred_original_sample = None
|
||||||
for i, t in enumerate(timesteps):
|
for i, t in enumerate(timesteps):
|
||||||
@ -535,6 +554,19 @@ class CogVideoXPipeline(DiffusionPipeline):
|
|||||||
latents_all /= contributors
|
latents_all /= contributors
|
||||||
|
|
||||||
latents = latents_all
|
latents = latents_all
|
||||||
|
#print("latents",latents.shape)
|
||||||
|
# start diff diff
|
||||||
|
if i < len(timesteps) - 1 and self.original_mask is not None:
|
||||||
|
noise_timestep = timesteps[i + 1]
|
||||||
|
image_latent = self.scheduler.add_noise(original_image_latents, noise, torch.tensor([noise_timestep])
|
||||||
|
)
|
||||||
|
mask = mask.to(latents)
|
||||||
|
ts_from = timesteps[0]
|
||||||
|
ts_to = timesteps[-1]
|
||||||
|
threshold = (t - ts_to) / (ts_from - ts_to)
|
||||||
|
mask = torch.where(mask >= threshold, mask, torch.zeros_like(mask))
|
||||||
|
latents = image_latent * mask + latents * (1 - mask)
|
||||||
|
# end diff diff
|
||||||
|
|
||||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||||
progress_bar.update()
|
progress_bar.update()
|
||||||
@ -577,7 +609,18 @@ class CogVideoXPipeline(DiffusionPipeline):
|
|||||||
**extra_step_kwargs,
|
**extra_step_kwargs,
|
||||||
return_dict=False,
|
return_dict=False,
|
||||||
)
|
)
|
||||||
latents = latents.to(prompt_embeds.dtype)
|
# start diff diff
|
||||||
|
if i < len(timesteps) - 1 and self.original_mask is not None:
|
||||||
|
noise_timestep = timesteps[i + 1]
|
||||||
|
image_latent = self.scheduler.add_noise(original_image_latents, noise, torch.tensor([noise_timestep])
|
||||||
|
)
|
||||||
|
mask = mask.to(latents)
|
||||||
|
ts_from = timesteps[0]
|
||||||
|
ts_to = timesteps[-1]
|
||||||
|
threshold = (t - ts_to) / (ts_from - ts_to)
|
||||||
|
mask = torch.where(mask >= threshold, mask, torch.zeros_like(mask))
|
||||||
|
latents = image_latent * mask + latents * (1 - mask)
|
||||||
|
# end diff diff
|
||||||
|
|
||||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||||
progress_bar.update()
|
progress_bar.update()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user