import os
import torch
import folder_paths
import comfy.model_management as mm
from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
from .pipeline_cogvideox import CogVideoXPipeline
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
log = logging.getLogger(__name__)


class DownloadAndLoadCogVideoModel:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {},
            "optional": {
                "precision": (
                    ["fp16", "fp32", "bf16"],
                    {"default": "fp16"},
                ),
            },
        }

    RETURN_TYPES = ("COGVIDEOPIPE",)
    RETURN_NAMES = ("cogvideo_pipe",)
    FUNCTION = "loadmodel"
    CATEGORY = "CogVideoWrapper"
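
    # Downloads CogVideoX-2b from the Hugging Face Hub on first use (text encoder
    # excluded) and builds a diffusers pipeline from the transformer, VAE and
    # scheduler, kept on the offload device until sampling.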
    def loadmodel(self, precision):
        device = mm.get_torch_device()
        offload_device = mm.unet_offload_device()
        mm.soft_empty_cache()

        dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
        base_path = os.path.join(folder_paths.models_dir, "CogVideo", "CogVideo2B")
        if not os.path.exists(base_path):
            log.info(f"Downloading model to: {base_path}")
            from huggingface_hub import snapshot_download
            snapshot_download(
                repo_id="THUDM/CogVideoX-2b",
                ignore_patterns=["*text_encoder*"],
                local_dir=base_path,
                local_dir_use_symlinks=False,
            )

        transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder="transformer").to(dtype).to(offload_device)
        vae = AutoencoderKLCogVideoX.from_pretrained(base_path, subfolder="vae").to(dtype).to(offload_device)
        scheduler = CogVideoXDDIMScheduler.from_pretrained(base_path, subfolder="scheduler")
        pipe = CogVideoXPipeline(vae, transformer, scheduler)

        pipeline = {
            "pipe": pipe,
            "dtype": dtype,
            "base_path": base_path,
        }
        return (pipeline,)

class CogVideoEncodePrompt:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "pipeline": ("COGVIDEOPIPE",),
                "prompt": ("STRING", {"default": "", "multiline": True}),
                "negative_prompt": ("STRING", {"default": "", "multiline": True}),
            }
        }

    RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
    RETURN_NAMES = ("positive", "negative")
    FUNCTION = "process"
    CATEGORY = "CogVideoWrapper"
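
    # Encodes prompts with the pipeline's own T5 text encoder. NOTE: the loader
    # above downloads the model with the text encoder excluded, so this node
    # assumes a pipeline that bundles one; CogVideoTextEncode below uses a text
    # encoder loaded through ComfyUI's CLIP loader instead.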
    def process(self, pipeline, prompt, negative_prompt):
        device = mm.get_torch_device()
        offload_device = mm.unet_offload_device()
        pipe = pipeline["pipe"]
        dtype = pipeline["dtype"]

        pipe.text_encoder.to(device)
        pipe.transformer.to(offload_device)
        positive, negative = pipe.encode_prompt(
            prompt=prompt,
            negative_prompt=negative_prompt,
            do_classifier_free_guidance=True,
            num_videos_per_prompt=1,
            max_sequence_length=226,
            device=device,
            dtype=dtype,
        )
        pipe.text_encoder.to(offload_device)
        return (positive, negative)

class CogVideoTextEncode:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "clip": ("CLIP",),
                "prompt": ("STRING", {"default": "", "multiline": True}),
            }
        }

    RETURN_TYPES = ("CONDITIONING",)
    RETURN_NAMES = ("conditioning",)
    FUNCTION = "process"
    CATEGORY = "CogVideoWrapper"
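
    # Encodes the prompt with a T5 model loaded via ComfyUI's CLIP loader, padding
    # to the 226-token sequence length CogVideoX expects.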
    def process(self, clip, prompt):
        load_device = mm.text_encoder_device()
        offload_device = mm.text_encoder_offload_device()
        clip.tokenizer.t5xxl.pad_to_max_length = True
        clip.tokenizer.t5xxl.max_length = 226
        clip.cond_stage_model.to(load_device)
        tokens = clip.tokenize(prompt, return_word_ids=True)
        embeds = clip.encode_from_tokens(tokens, return_pooled=False, return_dict=False)
        clip.cond_stage_model.to(offload_device)
        return (embeds,)

class CogVideoImageEncode:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "pipeline": ("COGVIDEOPIPE",),
                "image": ("IMAGE",),
            },
        }

    RETURN_TYPES = ("LATENT",)
    RETURN_NAMES = ("samples",)
    FUNCTION = "encode"
    CATEGORY = "CogVideoWrapper"
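
    # Encodes input frames to CogVideoX latents with the pipeline's VAE, processing
    # the temporal dimension in chunks of 16 frames to limit peak VRAM use.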
    def encode(self, pipeline, image):
        device = mm.get_torch_device()
        offload_device = mm.unet_offload_device()
        generator = torch.Generator(device=device).manual_seed(0)
        vae = pipeline["pipe"].vae
        vae.to(device)

        # ComfyUI images arrive as [T, H, W, C] in the 0..1 range; rescale to -1..1
        # and rearrange to the [B, C, T, H, W] layout the VAE expects.
        image = image * 2.0 - 1.0
        image = image.to(vae.dtype).to(device)
        image = image.unsqueeze(0).permute(0, 4, 1, 2, 3)  # B, C, T, H, W
        B, C, T, H, W = image.shape

        chunk_size = 16
        latents_list = []
        # Loop through the temporal dimension in chunks of 16 frames
        for i in range(0, T, chunk_size):
            # Take the next chunk (or the remaining frames if fewer than 16 are left)
            end_index = min(i + chunk_size, T)
            image_chunk = image[:, :, i:end_index, :, :]  # [B, C, chunk, H, W]

            # Encode the chunk and draw a sample from the latent distribution
            latents = vae.encode(image_chunk)
            sample_mode = "sample"
            if hasattr(latents, "latent_dist") and sample_mode == "sample":
                latents = latents.latent_dist.sample(generator)
            elif hasattr(latents, "latent_dist") and sample_mode == "argmax":
                latents = latents.latent_dist.mode()
            elif hasattr(latents, "latents"):
                latents = latents.latents

            latents = vae.config.scaling_factor * latents
            latents = latents.permute(0, 2, 1, 3, 4)  # [B, T_chunk, C, H, W]
            latents_list.append(latents)

        # Concatenate all chunks along the temporal dimension
        final_latents = torch.cat(latents_list, dim=1)
        print("final latents: ", final_latents.shape)
        vae.to(offload_device)
        return ({"samples": final_latents},)

class CogVideoSampler:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "pipeline": ("COGVIDEOPIPE",),
                "positive": ("CONDITIONING",),
                "negative": ("CONDITIONING",),
                "height": ("INT", {"default": 480, "min": 128, "max": 2048, "step": 8}),
                "width": ("INT", {"default": 720, "min": 128, "max": 2048, "step": 8}),
                "num_frames": ("INT", {"default": 48, "min": 8, "max": 100, "step": 8}),
                "fps": ("INT", {"default": 8, "min": 1, "max": 100, "step": 1}),
                "steps": ("INT", {"default": 25, "min": 1}),
                "cfg": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 30.0, "step": 0.01}),
                "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
                "scheduler": (["DDIM", "DPM"],),
            },
            "optional": {
                "samples": ("LATENT",),
                "denoise_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
            },
        }

    RETURN_TYPES = ("COGVIDEOPIPE", "LATENT",)
    RETURN_NAMES = ("cogvideo_pipe", "samples",)
    FUNCTION = "process"
    CATEGORY = "CogVideoWrapper"
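
    # Runs the CogVideoX denoising loop with the selected scheduler; optional input
    # latents together with denoise_strength allow partial denoising of existing
    # latents rather than starting from pure noise.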
    def process(self, pipeline, positive, negative, fps, steps, cfg, seed, height, width, num_frames, scheduler, samples=None, denoise_strength=1.0):
        mm.soft_empty_cache()
        device = mm.get_torch_device()
        offload_device = mm.unet_offload_device()
        pipe = pipeline["pipe"]
        dtype = pipeline["dtype"]
        base_path = pipeline["base_path"]

        pipe.transformer.to(device)
        generator = torch.Generator(device=device).manual_seed(seed)

        # Swap in the requested scheduler before sampling
        if scheduler == "DDIM":
            pipe.scheduler = CogVideoXDDIMScheduler.from_pretrained(base_path, subfolder="scheduler")
        elif scheduler == "DPM":
            pipe.scheduler = CogVideoXDPMScheduler.from_pretrained(base_path, subfolder="scheduler")

        latents = pipe(
            num_inference_steps=steps,
            height=height,
            width=width,
            num_frames=num_frames,
            fps=fps,
            guidance_scale=cfg,
            latents=samples["samples"] if samples is not None else None,
            denoise_strength=denoise_strength,
            prompt_embeds=positive.to(dtype).to(device),
            negative_prompt_embeds=negative.to(dtype).to(device),
            generator=generator,
            device=device,
        )
        pipe.transformer.to(offload_device)
        mm.soft_empty_cache()
        print(latents.shape)

        # Carry fps/num_frames along so the decode node can reconstruct the video
        pipeline["fps"] = fps
        pipeline["num_frames"] = num_frames
        return (pipeline, {"samples": latents})

class CogVideoDecode:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "pipeline": ("COGVIDEOPIPE",),
                "samples": ("LATENT",),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("images",)
    FUNCTION = "process"
    CATEGORY = "CogVideoWrapper"
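
    # Decodes latents back to frames roughly one second at a time, managing the
    # VAE's fake context parallel cache between chunks to keep memory in check.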
    def process(self, pipeline, samples):
        device = mm.get_torch_device()
        offload_device = mm.unet_offload_device()
        latents = samples["samples"]
        vae = pipeline["pipe"].vae
        vae.to(device)

        # Fall back to defaults when the latents didn't come from CogVideoSampler
        if "num_frames" in pipeline:
            num_frames = pipeline["num_frames"]
            fps = pipeline["fps"]
        else:
            num_frames = latents.shape[2]
            fps = 8
        num_seconds = num_frames // fps

        latents = latents.permute(0, 2, 1, 3, 4)  # [batch_size, num_channels, num_frames, height, width]
        latents = 1 / vae.config.scaling_factor * latents

        frames = []
        for i in range(num_seconds):
            # Whether or not to clear the fake context parallel cache
            fake_cp = i + 1 < num_seconds
            start_frame, end_frame = (0, 3) if i == 0 else (2 * i + 1, 2 * i + 3)
            current_frames = vae.decode(latents[:, :, start_frame:end_frame], fake_cp=fake_cp).sample
            frames.append(current_frames)
            mm.soft_empty_cache()
        vae.to(offload_device)

        frames = torch.cat(frames, dim=2)
        video = pipeline["pipe"].video_processor.postprocess_video(video=frames, output_type="pt")
        video = video[0].permute(0, 2, 3, 1).cpu().float()  # [T, H, W, C] for ComfyUI
        return (video,)

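
# Register the node classes and their display names with ComfyUI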
NODE_CLASS_MAPPINGS = {
    "DownloadAndLoadCogVideoModel": DownloadAndLoadCogVideoModel,
    "CogVideoSampler": CogVideoSampler,
    "CogVideoDecode": CogVideoDecode,
    "CogVideoTextEncode": CogVideoTextEncode,
    "CogVideoImageEncode": CogVideoImageEncode,
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "DownloadAndLoadCogVideoModel": "(Down)load CogVideo Model",
    "CogVideoSampler": "CogVideo Sampler",
    "CogVideoDecode": "CogVideo Decode",
    "CogVideoTextEncode": "CogVideo TextEncode",
    "CogVideoImageEncode": "CogVideo ImageEncode",
}