initial 5B support

This commit is contained in:
kijai 2024-08-27 17:06:04 +03:00
parent 8457fa7a4d
commit 7b80e61e36
3 changed files with 193 additions and 84 deletions

View File

@ -16,6 +16,12 @@ class DownloadAndLoadCogVideoModel:
def INPUT_TYPES(s): def INPUT_TYPES(s):
return { return {
"required": { "required": {
"model": (
[
"THUDM/CogVideoX-2b",
"THUDM/CogVideoX-5b",
],
),
}, },
"optional": { "optional": {
@ -35,21 +41,24 @@ class DownloadAndLoadCogVideoModel:
FUNCTION = "loadmodel" FUNCTION = "loadmodel"
CATEGORY = "CogVideoWrapper" CATEGORY = "CogVideoWrapper"
def loadmodel(self, precision): def loadmodel(self, model, precision):
device = mm.get_torch_device() device = mm.get_torch_device()
offload_device = mm.unet_offload_device() offload_device = mm.unet_offload_device()
mm.soft_empty_cache() mm.soft_empty_cache()
dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision] dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
if "2b" in model:
base_path = os.path.join(folder_paths.models_dir, "CogVideo", "CogVideo2B") base_path = os.path.join(folder_paths.models_dir, "CogVideo", "CogVideo2B")
elif "5b" in model:
base_path = os.path.join(folder_paths.models_dir, "CogVideo", "CogVideoX-5b")
if not os.path.exists(base_path): if not os.path.exists(base_path):
log.info(f"Downloading model to: {base_path}") log.info(f"Downloading model to: {base_path}")
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
snapshot_download( snapshot_download(
repo_id="THUDM/CogVideoX-2b", repo_id=model,
ignore_patterns=["*text_encoder*"], ignore_patterns=["*text_encoder*"],
local_dir=base_path, local_dir=base_path,
local_dir_use_symlinks=False, local_dir_use_symlinks=False,
@ -199,14 +208,14 @@ class CogVideoSampler:
"negative": ("CONDITIONING", ), "negative": ("CONDITIONING", ),
"height": ("INT", {"default": 480, "min": 128, "max": 2048, "step": 8}), "height": ("INT", {"default": 480, "min": 128, "max": 2048, "step": 8}),
"width": ("INT", {"default": 720, "min": 128, "max": 2048, "step": 8}), "width": ("INT", {"default": 720, "min": 128, "max": 2048, "step": 8}),
"num_frames": ("INT", {"default": 48, "min": 8, "max": 1024, "step": 8}), "num_frames": ("INT", {"default": 48, "min": 8, "max": 1024, "step": 1}),
"fps": ("INT", {"default": 8, "min": 1, "max": 100, "step": 1}), "fps": ("INT", {"default": 8, "min": 1, "max": 100, "step": 1}),
"steps": ("INT", {"default": 25, "min": 1}), "steps": ("INT", {"default": 25, "min": 1}),
"cfg": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 30.0, "step": 0.01}), "cfg": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 30.0, "step": 0.01}),
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
"scheduler": (["DDIM", "DPM"],), "scheduler": (["DDIM", "DPM"],),
"t_tile_length": ("INT", {"default": 16, "min": 16, "max": 128, "step": 4}), "t_tile_length": ("INT", {"default": 16, "min": 2, "max": 128, "step": 1}),
"t_tile_overlap": ("INT", {"default": 8, "min": 8, "max": 128, "step": 2}), "t_tile_overlap": ("INT", {"default": 8, "min": 2, "max": 128, "step": 1}),
}, },
"optional": { "optional": {
"samples": ("LATENT", ), "samples": ("LATENT", ),
@ -276,10 +285,10 @@ class CogVideoDecode:
RETURN_TYPES = ("IMAGE",) RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("images",) RETURN_NAMES = ("images",)
FUNCTION = "process" FUNCTION = "decode"
CATEGORY = "CogVideoWrapper" CATEGORY = "CogVideoWrapper"
def process(self, pipeline, samples): def decode(self, pipeline, samples):
device = mm.get_torch_device() device = mm.get_torch_device()
offload_device = mm.unet_offload_device() offload_device = mm.unet_offload_device()
latents = samples["samples"] latents = samples["samples"]
@ -299,19 +308,20 @@ class CogVideoDecode:
frames = [] frames = []
pbar = ProgressBar(num_seconds) pbar = ProgressBar(num_seconds)
for i in range(num_seconds): # for i in range(num_seconds):
start_frame, end_frame = (0, 3) if i == 0 else (2 * i + 1, 2 * i + 3) # start_frame, end_frame = (0, 3) if i == 0 else (2 * i + 1, 2 * i + 3)
current_frames = vae.decode(latents[:, :, start_frame:end_frame]).sample # current_frames = vae.decode(latents[:, :, start_frame:end_frame]).sample
frames.append(current_frames) # frames.append(current_frames)
pbar.update(1) # pbar.update(1)
vae.clear_fake_context_parallel_cache() frames = vae.decode(latents).sample
vae.to(offload_device) vae.to(offload_device)
mm.soft_empty_cache() mm.soft_empty_cache()
frames = torch.cat(frames, dim=2) #frames = torch.cat(frames, dim=2)
video = pipeline["pipe"].video_processor.postprocess_video(video=frames, output_type="pt") video = pipeline["pipe"].video_processor.postprocess_video(video=frames, output_type="pt")
video = video[0].permute(0, 2, 3, 1).cpu().float() video = video[0].permute(0, 2, 3, 1).cpu().float()
print(video.min(), video.max())
return (video,) return (video,)

View File

@ -17,6 +17,7 @@ import inspect
from typing import Callable, Dict, List, Optional, Tuple, Union from typing import Callable, Dict, List, Optional, Tuple, Union
import torch import torch
import math
from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
from diffusers.pipelines.pipeline_utils import DiffusionPipeline from diffusers.pipelines.pipeline_utils import DiffusionPipeline
@ -24,11 +25,29 @@ from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
from diffusers.utils import logging from diffusers.utils import logging
from diffusers.utils.torch_utils import randn_tensor from diffusers.utils.torch_utils import randn_tensor
from diffusers.video_processor import VideoProcessor from diffusers.video_processor import VideoProcessor
from diffusers.models.embeddings import get_3d_rotary_pos_embed
from comfy.utils import ProgressBar from comfy.utils import ProgressBar
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
tw = tgt_width
th = tgt_height
h, w = src
r = h / w
if r > (th / tw):
resize_height = th
resize_width = int(round(th / h * w))
else:
resize_width = tw
resize_height = int(round(tw / w * h))
crop_top = int(round((th - resize_height) / 2.0))
crop_left = int(round((tw - resize_width) / 2.0))
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps( def retrieve_timesteps(
scheduler, scheduler,
@ -229,6 +248,46 @@ class CogVideoXPipeline(DiffusionPipeline):
weights = weights.unsqueeze(0).unsqueeze(2).unsqueeze(3).unsqueeze(4).repeat(1, t_batch_size,1, 1, 1) weights = weights.unsqueeze(0).unsqueeze(2).unsqueeze(3).unsqueeze(4).repeat(1, t_batch_size,1, 1, 1)
return weights return weights
def fuse_qkv_projections(self) -> None:
r"""Enables fused QKV projections."""
self.fusing_transformer = True
self.transformer.fuse_qkv_projections()
def unfuse_qkv_projections(self) -> None:
r"""Disable QKV projection fusion if enabled."""
if not self.fusing_transformer:
logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
else:
self.transformer.unfuse_qkv_projections()
self.fusing_transformer = False
def _prepare_rotary_positional_embeddings(
self,
height: int,
width: int,
num_frames: int,
device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor]:
grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
grid_crops_coords = get_resize_crop_region_for_grid(
(grid_height, grid_width), base_size_width, base_size_height
)
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
embed_dim=self.transformer.config.attention_head_dim,
crops_coords=grid_crops_coords,
grid_size=(grid_height, grid_width),
temporal_size=num_frames,
use_real=True,
)
freqs_cos = freqs_cos.to(device=device)
freqs_sin = freqs_sin.to(device=device)
return freqs_cos, freqs_sin
@property @property
def guidance_scale(self): def guidance_scale(self):
return self._guidance_scale return self._guidance_scale
@ -374,6 +433,15 @@ class CogVideoXPipeline(DiffusionPipeline):
t_tile_weights = self._gaussian_weights(t_tile_length=t_tile_length, t_batch_size=1).to(latents.device).to(latents.dtype) t_tile_weights = self._gaussian_weights(t_tile_length=t_tile_length, t_batch_size=1).to(latents.device).to(latents.dtype)
print("latents.shape", latents.shape) print("latents.shape", latents.shape)
print("latents.device", latents.device) print("latents.device", latents.device)
# 6.5. Create rotary embeds if required
image_rotary_emb = (
self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
if self.transformer.config.use_rotary_positional_embeddings
else None
)
# 7. Denoising loop # 7. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
comfy_pbar = ProgressBar(num_inference_steps) comfy_pbar = ProgressBar(num_inference_steps)
@ -383,7 +451,7 @@ class CogVideoXPipeline(DiffusionPipeline):
for i, t in enumerate(timesteps): for i, t in enumerate(timesteps):
if self.interrupt: if self.interrupt:
continue continue
if not isinstance(self.scheduler, CogVideoXDPMScheduler):
#temporal tiling code based on https://github.com/mayuelala/FollowYourEmoji/blob/main/models/video_pipeline.py #temporal tiling code based on https://github.com/mayuelala/FollowYourEmoji/blob/main/models/video_pipeline.py
# ===================================================== # =====================================================
grid_ts = 0 grid_ts = 0
@ -420,6 +488,7 @@ class CogVideoXPipeline(DiffusionPipeline):
hidden_states=latent_model_input_tile, hidden_states=latent_model_input_tile,
encoder_hidden_states=prompt_embeds, encoder_hidden_states=prompt_embeds,
timestep=t_input, timestep=t_input,
image_rotary_emb=image_rotary_emb,
return_dict=False, return_dict=False,
)[0] )[0]
noise_pred = noise_pred.float() noise_pred = noise_pred.float()
@ -429,21 +498,7 @@ class CogVideoXPipeline(DiffusionPipeline):
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1 # compute the previous noisy sample x_t -> x_t-1
if not isinstance(self.scheduler, CogVideoXDPMScheduler):
latents_tile = self.scheduler.step(noise_pred, t, latents_tile, **extra_step_kwargs, return_dict=False)[0] latents_tile = self.scheduler.step(noise_pred, t, latents_tile, **extra_step_kwargs, return_dict=False)[0]
else:
raise NotImplementedError("DPM is not supported with temporal tiling")
# else:
# latents_tile, old_pred_original_sample = self.scheduler.step(
# noise_pred,
# old_pred_original_sample,
# t,
# t_input[t_i - 1] if t_i > 0 else None,
# latents_tile,
# **extra_step_kwargs,
# return_dict=False,
# )
latents_all_list.append(latents_tile) latents_all_list.append(latents_tile)
# ========================================== # ==========================================
@ -465,13 +520,57 @@ class CogVideoXPipeline(DiffusionPipeline):
latents_all /= contributors latents_all /= contributors
latents = latents_all latents = latents_all
# ==========================================
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
comfy_pbar.update(1)
# ==========================================
else:
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0])
# predict noise model_output
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
image_rotary_emb=image_rotary_emb,
return_dict=False,
)[0]
noise_pred = noise_pred.float()
self._guidance_scale = 1 + guidance_scale * (
(1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
)
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
if not isinstance(self.scheduler, CogVideoXDPMScheduler):
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
else:
latents, old_pred_original_sample = self.scheduler.step(
noise_pred,
old_pred_original_sample,
t,
timesteps[i - 1] if i > 0 else None,
latents,
**extra_step_kwargs,
return_dict=False,
)
latents = latents.to(prompt_embeds.dtype)
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update() progress_bar.update()
comfy_pbar.update(1) comfy_pbar.update(1)
# Offload all models # Offload all models
self.maybe_free_model_hooks() self.maybe_free_model_hooks()

View File

@ -1,2 +1,2 @@
huggingface_hub huggingface_hub
diffusers>=0.30.0 diffusers>=0.30.1