initial 5B support

This commit is contained in:
kijai 2024-08-27 17:06:04 +03:00
parent 8457fa7a4d
commit 7b80e61e36
3 changed files with 193 additions and 84 deletions

View File

@ -16,6 +16,12 @@ class DownloadAndLoadCogVideoModel:
def INPUT_TYPES(s): def INPUT_TYPES(s):
return { return {
"required": { "required": {
"model": (
[
"THUDM/CogVideoX-2b",
"THUDM/CogVideoX-5b",
],
),
}, },
"optional": { "optional": {
@ -35,21 +41,24 @@ class DownloadAndLoadCogVideoModel:
FUNCTION = "loadmodel" FUNCTION = "loadmodel"
CATEGORY = "CogVideoWrapper" CATEGORY = "CogVideoWrapper"
def loadmodel(self, precision): def loadmodel(self, model, precision):
device = mm.get_torch_device() device = mm.get_torch_device()
offload_device = mm.unet_offload_device() offload_device = mm.unet_offload_device()
mm.soft_empty_cache() mm.soft_empty_cache()
dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision] dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
base_path = os.path.join(folder_paths.models_dir, "CogVideo", "CogVideo2B") if "2b" in model:
base_path = os.path.join(folder_paths.models_dir, "CogVideo", "CogVideo2B")
elif "5b" in model:
base_path = os.path.join(folder_paths.models_dir, "CogVideo", "CogVideoX-5b")
if not os.path.exists(base_path): if not os.path.exists(base_path):
log.info(f"Downloading model to: {base_path}") log.info(f"Downloading model to: {base_path}")
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
snapshot_download( snapshot_download(
repo_id="THUDM/CogVideoX-2b", repo_id=model,
ignore_patterns=["*text_encoder*"], ignore_patterns=["*text_encoder*"],
local_dir=base_path, local_dir=base_path,
local_dir_use_symlinks=False, local_dir_use_symlinks=False,
@ -199,14 +208,14 @@ class CogVideoSampler:
"negative": ("CONDITIONING", ), "negative": ("CONDITIONING", ),
"height": ("INT", {"default": 480, "min": 128, "max": 2048, "step": 8}), "height": ("INT", {"default": 480, "min": 128, "max": 2048, "step": 8}),
"width": ("INT", {"default": 720, "min": 128, "max": 2048, "step": 8}), "width": ("INT", {"default": 720, "min": 128, "max": 2048, "step": 8}),
"num_frames": ("INT", {"default": 48, "min": 8, "max": 1024, "step": 8}), "num_frames": ("INT", {"default": 48, "min": 8, "max": 1024, "step": 1}),
"fps": ("INT", {"default": 8, "min": 1, "max": 100, "step": 1}), "fps": ("INT", {"default": 8, "min": 1, "max": 100, "step": 1}),
"steps": ("INT", {"default": 25, "min": 1}), "steps": ("INT", {"default": 25, "min": 1}),
"cfg": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 30.0, "step": 0.01}), "cfg": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 30.0, "step": 0.01}),
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
"scheduler": (["DDIM", "DPM"],), "scheduler": (["DDIM", "DPM"],),
"t_tile_length": ("INT", {"default": 16, "min": 16, "max": 128, "step": 4}), "t_tile_length": ("INT", {"default": 16, "min": 2, "max": 128, "step": 1}),
"t_tile_overlap": ("INT", {"default": 8, "min": 8, "max": 128, "step": 2}), "t_tile_overlap": ("INT", {"default": 8, "min": 2, "max": 128, "step": 1}),
}, },
"optional": { "optional": {
"samples": ("LATENT", ), "samples": ("LATENT", ),
@ -276,10 +285,10 @@ class CogVideoDecode:
RETURN_TYPES = ("IMAGE",) RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("images",) RETURN_NAMES = ("images",)
FUNCTION = "process" FUNCTION = "decode"
CATEGORY = "CogVideoWrapper" CATEGORY = "CogVideoWrapper"
def process(self, pipeline, samples): def decode(self, pipeline, samples):
device = mm.get_torch_device() device = mm.get_torch_device()
offload_device = mm.unet_offload_device() offload_device = mm.unet_offload_device()
latents = samples["samples"] latents = samples["samples"]
@ -299,19 +308,20 @@ class CogVideoDecode:
frames = [] frames = []
pbar = ProgressBar(num_seconds) pbar = ProgressBar(num_seconds)
for i in range(num_seconds): # for i in range(num_seconds):
start_frame, end_frame = (0, 3) if i == 0 else (2 * i + 1, 2 * i + 3) # start_frame, end_frame = (0, 3) if i == 0 else (2 * i + 1, 2 * i + 3)
current_frames = vae.decode(latents[:, :, start_frame:end_frame]).sample # current_frames = vae.decode(latents[:, :, start_frame:end_frame]).sample
frames.append(current_frames) # frames.append(current_frames)
pbar.update(1) # pbar.update(1)
vae.clear_fake_context_parallel_cache() frames = vae.decode(latents).sample
vae.to(offload_device) vae.to(offload_device)
mm.soft_empty_cache() mm.soft_empty_cache()
frames = torch.cat(frames, dim=2) #frames = torch.cat(frames, dim=2)
video = pipeline["pipe"].video_processor.postprocess_video(video=frames, output_type="pt") video = pipeline["pipe"].video_processor.postprocess_video(video=frames, output_type="pt")
video = video[0].permute(0, 2, 3, 1).cpu().float() video = video[0].permute(0, 2, 3, 1).cpu().float()
print(video.min(), video.max())
return (video,) return (video,)

View File

@ -17,6 +17,7 @@ import inspect
from typing import Callable, Dict, List, Optional, Tuple, Union from typing import Callable, Dict, List, Optional, Tuple, Union
import torch import torch
import math
from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
from diffusers.pipelines.pipeline_utils import DiffusionPipeline from diffusers.pipelines.pipeline_utils import DiffusionPipeline
@ -24,11 +25,29 @@ from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
from diffusers.utils import logging from diffusers.utils import logging
from diffusers.utils.torch_utils import randn_tensor from diffusers.utils.torch_utils import randn_tensor
from diffusers.video_processor import VideoProcessor from diffusers.video_processor import VideoProcessor
from diffusers.models.embeddings import get_3d_rotary_pos_embed
from comfy.utils import ProgressBar from comfy.utils import ProgressBar
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
tw = tgt_width
th = tgt_height
h, w = src
r = h / w
if r > (th / tw):
resize_height = th
resize_width = int(round(th / h * w))
else:
resize_width = tw
resize_height = int(round(tw / w * h))
crop_top = int(round((th - resize_height) / 2.0))
crop_left = int(round((tw - resize_width) / 2.0))
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps( def retrieve_timesteps(
scheduler, scheduler,
@ -229,6 +248,46 @@ class CogVideoXPipeline(DiffusionPipeline):
weights = weights.unsqueeze(0).unsqueeze(2).unsqueeze(3).unsqueeze(4).repeat(1, t_batch_size,1, 1, 1) weights = weights.unsqueeze(0).unsqueeze(2).unsqueeze(3).unsqueeze(4).repeat(1, t_batch_size,1, 1, 1)
return weights return weights
def fuse_qkv_projections(self) -> None:
r"""Enables fused QKV projections."""
self.fusing_transformer = True
self.transformer.fuse_qkv_projections()
def unfuse_qkv_projections(self) -> None:
r"""Disable QKV projection fusion if enabled."""
if not self.fusing_transformer:
logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
else:
self.transformer.unfuse_qkv_projections()
self.fusing_transformer = False
def _prepare_rotary_positional_embeddings(
self,
height: int,
width: int,
num_frames: int,
device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor]:
grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
grid_crops_coords = get_resize_crop_region_for_grid(
(grid_height, grid_width), base_size_width, base_size_height
)
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
embed_dim=self.transformer.config.attention_head_dim,
crops_coords=grid_crops_coords,
grid_size=(grid_height, grid_width),
temporal_size=num_frames,
use_real=True,
)
freqs_cos = freqs_cos.to(device=device)
freqs_sin = freqs_sin.to(device=device)
return freqs_cos, freqs_sin
@property @property
def guidance_scale(self): def guidance_scale(self):
return self._guidance_scale return self._guidance_scale
@ -374,6 +433,15 @@ class CogVideoXPipeline(DiffusionPipeline):
t_tile_weights = self._gaussian_weights(t_tile_length=t_tile_length, t_batch_size=1).to(latents.device).to(latents.dtype) t_tile_weights = self._gaussian_weights(t_tile_length=t_tile_length, t_batch_size=1).to(latents.device).to(latents.dtype)
print("latents.shape", latents.shape) print("latents.shape", latents.shape)
print("latents.device", latents.device) print("latents.device", latents.device)
# 6.5. Create rotary embeds if required
image_rotary_emb = (
self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
if self.transformer.config.use_rotary_positional_embeddings
else None
)
# 7. Denoising loop # 7. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
comfy_pbar = ProgressBar(num_inference_steps) comfy_pbar = ProgressBar(num_inference_steps)
@ -383,94 +451,125 @@ class CogVideoXPipeline(DiffusionPipeline):
for i, t in enumerate(timesteps): for i, t in enumerate(timesteps):
if self.interrupt: if self.interrupt:
continue continue
if not isinstance(self.scheduler, CogVideoXDPMScheduler):
#temporal tiling code based on https://github.com/mayuelala/FollowYourEmoji/blob/main/models/video_pipeline.py
# =====================================================
grid_ts = 0
cur_t = 0
while cur_t < latents.shape[1]:
cur_t = max(grid_ts * t_tile_length - t_tile_overlap * grid_ts, 0) + t_tile_length
grid_ts += 1
#temporal tiling code based on https://github.com/mayuelala/FollowYourEmoji/blob/main/models/video_pipeline.py all_t = latents.shape[1]
# ===================================================== latents_all_list = []
grid_ts = 0 # =====================================================
cur_t = 0
while cur_t < latents.shape[1]:
cur_t = max(grid_ts * t_tile_length - t_tile_overlap * grid_ts, 0) + t_tile_length
grid_ts += 1
all_t = latents.shape[1] for t_i in range(grid_ts):
latents_all_list = [] if t_i < grid_ts - 1:
# ===================================================== ofs_t = max(t_i * t_tile_length - t_tile_overlap * t_i, 0)
if t_i == grid_ts - 1:
ofs_t = all_t - t_tile_length
for t_i in range(grid_ts): input_start_t = ofs_t
if t_i < grid_ts - 1: input_end_t = ofs_t + t_tile_length
ofs_t = max(t_i * t_tile_length - t_tile_overlap * t_i, 0)
if t_i == grid_ts - 1:
ofs_t = all_t - t_tile_length
input_start_t = ofs_t #latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
input_end_t = ofs_t + t_tile_length #latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
#latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents latents_tile = latents[:, input_start_t:input_end_t,:, :, :]
#latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) latent_model_input_tile = torch.cat([latents_tile] * 2) if do_classifier_free_guidance else latents_tile
latent_model_input_tile = self.scheduler.scale_model_input(latent_model_input_tile, t)
latents_tile = latents[:, input_start_t:input_end_t,:, :, :] #t_input = t[None].to(device)
latent_model_input_tile = torch.cat([latents_tile] * 2) if do_classifier_free_guidance else latents_tile t_input = t.expand(latent_model_input_tile.shape[0]) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
latent_model_input_tile = self.scheduler.scale_model_input(latent_model_input_tile, t)
#t_input = t[None].to(device) # predict noise model_output
t_input = t.expand(latent_model_input_tile.shape[0]) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML noise_pred = self.transformer(
hidden_states=latent_model_input_tile,
encoder_hidden_states=prompt_embeds,
timestep=t_input,
image_rotary_emb=image_rotary_emb,
return_dict=False,
)[0]
noise_pred = noise_pred.float()
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents_tile = self.scheduler.step(noise_pred, t, latents_tile, **extra_step_kwargs, return_dict=False)[0]
latents_all_list.append(latents_tile)
# ==========================================
latents_all = torch.zeros(latents.shape, device=latents.device, dtype=latents.dtype)
contributors = torch.zeros(latents.shape, device=latents.device, dtype=latents.dtype)
# Add each tile contribution to overall latents
for t_i in range(grid_ts):
if t_i < grid_ts - 1:
ofs_t = max(t_i * t_tile_length - t_tile_overlap * t_i, 0)
if t_i == grid_ts - 1:
ofs_t = all_t - t_tile_length
input_start_t = ofs_t
input_end_t = ofs_t + t_tile_length
latents_all[:, input_start_t:input_end_t,:, :, :] += latents_all_list[t_i] * t_tile_weights
contributors[:, input_start_t:input_end_t,:, :, :] += t_tile_weights
latents_all /= contributors
latents = latents_all
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
comfy_pbar.update(1)
# ==========================================
else:
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0])
# predict noise model_output # predict noise model_output
noise_pred = self.transformer( noise_pred = self.transformer(
hidden_states=latent_model_input_tile, hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds, encoder_hidden_states=prompt_embeds,
timestep=t_input, timestep=timestep,
image_rotary_emb=image_rotary_emb,
return_dict=False, return_dict=False,
)[0] )[0]
noise_pred = noise_pred.float() noise_pred = noise_pred.float()
if self.do_classifier_free_guidance:
self._guidance_scale = 1 + guidance_scale * (
(1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
)
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1 # compute the previous noisy sample x_t -> x_t-1
if not isinstance(self.scheduler, CogVideoXDPMScheduler): if not isinstance(self.scheduler, CogVideoXDPMScheduler):
latents_tile = self.scheduler.step(noise_pred, t, latents_tile, **extra_step_kwargs, return_dict=False)[0] latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
else: else:
raise NotImplementedError("DPM is not supported with temporal tiling") latents, old_pred_original_sample = self.scheduler.step(
# else: noise_pred,
# latents_tile, old_pred_original_sample = self.scheduler.step( old_pred_original_sample,
# noise_pred, t,
# old_pred_original_sample, timesteps[i - 1] if i > 0 else None,
# t, latents,
# t_input[t_i - 1] if t_i > 0 else None, **extra_step_kwargs,
# latents_tile, return_dict=False,
# **extra_step_kwargs, )
# return_dict=False, latents = latents.to(prompt_embeds.dtype)
# )
latents_all_list.append(latents_tile) if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
comfy_pbar.update(1)
# ==========================================
latents_all = torch.zeros(latents.shape, device=latents.device, dtype=latents.dtype)
contributors = torch.zeros(latents.shape, device=latents.device, dtype=latents.dtype)
# Add each tile contribution to overall latents
for t_i in range(grid_ts):
if t_i < grid_ts - 1:
ofs_t = max(t_i * t_tile_length - t_tile_overlap * t_i, 0)
if t_i == grid_ts - 1:
ofs_t = all_t - t_tile_length
input_start_t = ofs_t
input_end_t = ofs_t + t_tile_length
latents_all[:, input_start_t:input_end_t,:, :, :] += latents_all_list[t_i] * t_tile_weights
contributors[:, input_start_t:input_end_t,:, :, :] += t_tile_weights
latents_all /= contributors
latents = latents_all
# ==========================================
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
comfy_pbar.update(1)
# Offload all models # Offload all models
self.maybe_free_model_hooks() self.maybe_free_model_hooks()

View File

@ -1,2 +1,2 @@
huggingface_hub huggingface_hub
diffusers>=0.30.0 diffusers>=0.30.1