diff --git a/configs/scheduler_config_2b.json b/configs/scheduler_config_2b.json
new file mode 100644
index 0000000..1cd0fa6
--- /dev/null
+++ b/configs/scheduler_config_2b.json
@@ -0,0 +1,18 @@
+{
+  "_class_name": "CogVideoXDDIMScheduler",
+  "_diffusers_version": "0.30.0.dev0",
+  "beta_end": 0.012,
+  "beta_schedule": "scaled_linear",
+  "beta_start": 0.00085,
+  "clip_sample": false,
+  "clip_sample_range": 1.0,
+  "num_train_timesteps": 1000,
+  "prediction_type": "v_prediction",
+  "rescale_betas_zero_snr": true,
+  "sample_max_value": 1.0,
+  "set_alpha_to_one": true,
+  "snr_shift_scale": 3.0,
+  "steps_offset": 0,
+  "timestep_spacing": "trailing",
+  "trained_betas": null
+}
diff --git a/configs/scheduler_config_5b.json b/configs/scheduler_config_5b.json
new file mode 100644
index 0000000..6e4f799
--- /dev/null
+++ b/configs/scheduler_config_5b.json
@@ -0,0 +1,18 @@
+{
+  "_class_name": "CogVideoXDDIMScheduler",
+  "_diffusers_version": "0.31.0.dev0",
+  "beta_end": 0.012,
+  "beta_schedule": "scaled_linear",
+  "beta_start": 0.00085,
+  "clip_sample": false,
+  "clip_sample_range": 1.0,
+  "num_train_timesteps": 1000,
+  "prediction_type": "v_prediction",
+  "rescale_betas_zero_snr": true,
+  "sample_max_value": 1.0,
+  "set_alpha_to_one": true,
+  "snr_shift_scale": 1.0,
+  "steps_offset": 0,
+  "timestep_spacing": "trailing",
+  "trained_betas": null
+}
\ No newline at end of file
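Note: the two scheduler configs are identical apart from the pinned diffusers version and "snr_shift_scale" (3.0 for the 2b model, 1.0 for the 5b model). Bundling them lets the loader construct a scheduler from a plain dict instead of reading the "scheduler" subfolder of a downloaded checkpoint, which is exactly what nodes.py does further down. A minimal sketch of that pattern (the config path is this repo's; everything else is stock diffusers):

import json
from diffusers import CogVideoXDDIMScheduler

with open("configs/scheduler_config_5b.json") as f:
    scheduler_config = json.load(f)

# from_config() only needs the dict, so no checkpoint has to be on disk
# just to build a scheduler.
scheduler = CogVideoXDDIMScheduler.from_config(scheduler_config)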
"_diffusers_version": "0.31.0.dev0", + "activation_fn": "gelu-approximate", + "attention_bias": true, + "attention_head_dim": 64, + "dropout": 0.0, + "flip_sin_to_cos": true, + "freq_shift": 0, + "in_channels": 32, + "max_text_seq_length": 226, + "norm_elementwise_affine": true, + "norm_eps": 1e-05, + "num_attention_heads": 48, + "num_layers": 42, + "out_channels": 16, + "patch_size": 2, + "sample_frames": 49, + "sample_height": 60, + "sample_width": 90, + "spatial_interpolation_scale": 1.875, + "temporal_compression_ratio": 4, + "temporal_interpolation_scale": 1.0, + "text_embed_dim": 4096, + "time_embed_dim": 512, + "timestep_activation_fn": "silu", + "use_learned_positional_embeddings": true, + "use_rotary_positional_embeddings": true + } \ No newline at end of file diff --git a/configs/vae_config.json b/configs/vae_config.json new file mode 100644 index 0000000..99da9d2 --- /dev/null +++ b/configs/vae_config.json @@ -0,0 +1,39 @@ +{ + "_class_name": "AutoencoderKLCogVideoX", + "_diffusers_version": "0.31.0.dev0", + "act_fn": "silu", + "block_out_channels": [ + 128, + 256, + 256, + 512 + ], + "down_block_types": [ + "CogVideoXDownBlock3D", + "CogVideoXDownBlock3D", + "CogVideoXDownBlock3D", + "CogVideoXDownBlock3D" + ], + "force_upcast": true, + "in_channels": 3, + "latent_channels": 16, + "latents_mean": null, + "latents_std": null, + "layers_per_block": 3, + "norm_eps": 1e-06, + "norm_num_groups": 32, + "out_channels": 3, + "sample_height": 480, + "sample_width": 720, + "scaling_factor": 0.7, + "shift_factor": null, + "temporal_compression_ratio": 4, + "up_block_types": [ + "CogVideoXUpBlock3D", + "CogVideoXUpBlock3D", + "CogVideoXUpBlock3D", + "CogVideoXUpBlock3D" + ], + "use_post_quant_conv": false, + "use_quant_conv": false +} diff --git a/mz_gguf_loader.py b/mz_gguf_loader.py new file mode 100644 index 0000000..f5a6059 --- /dev/null +++ b/mz_gguf_loader.py @@ -0,0 +1,188 @@ +# https://github.com/MinusZoneAI/ComfyUI-CogVideoX-MZ/blob/9616415220fd09388622f40f6609e4ed81f048a5/mz_gguf_loader.py + +import torch +import torch.nn as nn +import gc + + +class quantize_lazy_load(): + def __init__(self): + self.device = None + + def __enter__(self): + self.device = torch.device("meta") + self.device.__enter__() + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.device.__exit__(exc_type, exc_value, traceback) + + +def quantize_load_state_dict(model, state_dict, device="cpu"): + Q4_0_qkey = [] + for key in state_dict.keys(): + if key.endswith(".Q4_0_qweight"): + Q4_0_qkey.append(key.replace(".Q4_0_qweight", "")) + + for name, module in model.named_modules(): + if name in Q4_0_qkey: + q_linear = WQLinear_GGUF.from_linear( + linear=module, + device=device, + qtype="Q4_0", + ) + set_op_by_name(model, name, q_linear) + + model.to_empty(device=device) + model.load_state_dict(state_dict, strict=False) + model.to(device) + return model + + +def set_op_by_name(layer, name, new_module): + levels = name.split(".") + if len(levels) > 1: + mod_ = layer + for l_idx in range(len(levels) - 1): + if levels[l_idx].isdigit(): + mod_ = mod_[int(levels[l_idx])] + else: + mod_ = getattr(mod_, levels[l_idx]) + setattr(mod_, levels[-1], new_module) + else: + setattr(layer, name, new_module) + + +import torch.nn.functional as F + + +class WQLinear_GGUF(nn.Module): + def __init__( + self, in_features, out_features, bias, dev, qtype="Q4_0" + ): + super().__init__() + + self.in_features = in_features + self.out_features = out_features + self.qtype = qtype + + qweight_shape = 
diff --git a/mz_gguf_loader.py b/mz_gguf_loader.py
new file mode 100644
index 0000000..f5a6059
--- /dev/null
+++ b/mz_gguf_loader.py
@@ -0,0 +1,188 @@
+# https://github.com/MinusZoneAI/ComfyUI-CogVideoX-MZ/blob/9616415220fd09388622f40f6609e4ed81f048a5/mz_gguf_loader.py
+
+import torch
+import torch.nn as nn
+import gc
+
+
+class quantize_lazy_load():
+    def __init__(self):
+        self.device = None
+
+    def __enter__(self):
+        self.device = torch.device("meta")
+        self.device.__enter__()
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.device.__exit__(exc_type, exc_value, traceback)
+
+
+def quantize_load_state_dict(model, state_dict, device="cpu"):
+    Q4_0_qkey = []
+    for key in state_dict.keys():
+        if key.endswith(".Q4_0_qweight"):
+            Q4_0_qkey.append(key.replace(".Q4_0_qweight", ""))
+
+    for name, module in model.named_modules():
+        if name in Q4_0_qkey:
+            q_linear = WQLinear_GGUF.from_linear(
+                linear=module,
+                device=device,
+                qtype="Q4_0",
+            )
+            set_op_by_name(model, name, q_linear)
+
+    model.to_empty(device=device)
+    model.load_state_dict(state_dict, strict=False)
+    model.to(device)
+    return model
+
+
+def set_op_by_name(layer, name, new_module):
+    levels = name.split(".")
+    if len(levels) > 1:
+        mod_ = layer
+        for l_idx in range(len(levels) - 1):
+            if levels[l_idx].isdigit():
+                mod_ = mod_[int(levels[l_idx])]
+            else:
+                mod_ = getattr(mod_, levels[l_idx])
+        setattr(mod_, levels[-1], new_module)
+    else:
+        setattr(layer, name, new_module)
+
+
+import torch.nn.functional as F
+
+
+class WQLinear_GGUF(nn.Module):
+    def __init__(
+        self, in_features, out_features, bias, dev, qtype="Q4_0"
+    ):
+        super().__init__()
+
+        self.in_features = in_features
+        self.out_features = out_features
+        self.qtype = qtype
+
+        qweight_shape = quant_shape_to_byte_shape(
+            (out_features, in_features), qtype
+        )
+        self.register_buffer(
+            f"{qtype}_qweight",
+            torch.zeros(
+                qweight_shape,
+                dtype=torch.uint8,
+                device=dev,
+            ),
+        )
+        if bias:
+            self.register_buffer(
+                "bias",
+                torch.zeros(
+                    (out_features),
+                    dtype=torch.float16,
+                    device=dev,
+                ),
+            )
+        else:
+            self.bias = None
+
+    @classmethod
+    def from_linear(
+        cls, linear,
+        device="cpu",
+        qtype="Q4_0",
+    ):
+        q_linear = cls(
+            linear.in_features,
+            linear.out_features,
+            linear.bias is not None,
+            device,
+            qtype=qtype,
+        )
+        return q_linear
+
+    def extra_repr(self) -> str:
+        return (
+            "in_features={}, out_features={}, bias={}, qtype={}".format(
+                self.in_features,
+                self.out_features,
+                self.bias is not None,
+                self.qtype,
+            )
+        )
+
+    @torch.no_grad()
+    def forward(self, x):
+        # x = torch.matmul(x, dequantize_blocks_Q4_0(self.qweight))
+        if self.qtype == "Q4_0":
+            x = F.linear(x, dequantize_blocks_Q4_0(
+                self.Q4_0_qweight, x.dtype), self.bias.to(x.dtype) if self.bias is not None else None)
+        else:
+            raise ValueError(f"Unknown qtype: {self.qtype}")
+
+        return x
+
+
+def split_block_dims(blocks, *args):
+    n_max = blocks.shape[1]
+    dims = list(args) + [n_max - sum(args)]
+    return torch.split(blocks, dims, dim=1)
+
+
+def quant_shape_to_byte_shape(shape, qtype) -> tuple[int, ...]:
+    # shape = shape[::-1]
+    block_size, type_size = GGML_QUANT_SIZES[qtype]
+    if shape[-1] % block_size != 0:
+        raise ValueError(
+            f"Quantized tensor row size ({shape[-1]}) is not a multiple of Q4_0 block size ({block_size})")
+    return (*shape[:-1], shape[-1] // block_size * type_size)
+
+
+def quant_shape_from_byte_shape(shape, qtype) -> tuple[int, ...]:
+    # shape = shape[::-1]
+    block_size, type_size = GGML_QUANT_SIZES[qtype]
+    if shape[-1] % type_size != 0:
+        raise ValueError(
+            f"Quantized tensor bytes per row ({shape[-1]}) is not a multiple of Q4_0 type size ({type_size})")
+    return (*shape[:-1], shape[-1] // type_size * block_size)
+
+
+GGML_QUANT_SIZES = {
+    "Q4_0": (32, 2 + 16),
+}
+
+
+def dequantize_blocks_Q4_0(data, dtype=torch.float16):
+    block_size, type_size = GGML_QUANT_SIZES["Q4_0"]
+
+    data = data.to(torch.uint8)
+    shape = data.shape
+
+    rows = data.reshape(
+        (-1, data.shape[-1])
+    ).view(torch.uint8)
+
+    n_blocks = rows.numel() // type_size
+    blocks = data.reshape((n_blocks, type_size))
+
+    n_blocks = blocks.shape[0]
+
+    d, qs = split_block_dims(blocks, 2)
+    d = d.view(torch.float16)
+
+    qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor(
+        [0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1))
+    qs = (qs & 0x0F).reshape((n_blocks, -1)).to(torch.int8) - 8
+
+    out = (d * qs)
+
+    out = out.reshape(quant_shape_from_byte_shape(
+        shape,
+        qtype="Q4_0",
+    )).to(dtype)
+    return out
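Note: GGML_QUANT_SIZES encodes the Q4_0 layout: each block of 32 weights is stored as a float16 scale d (2 bytes) plus 16 bytes of packed 4-bit codes, and dequantize_blocks_Q4_0 recovers (code - 8) * d, with the low nibbles holding the first 16 weights of a block. A round-trip sketch of one block follows; the quantizer side mirrors llama.cpp's reference Q4_0 encoder and is an assumption, only the dequantizer is part of this module (run from the repo root so the import resolves):

import torch
from mz_gguf_loader import dequantize_blocks_Q4_0

w = torch.randn(32)                                   # one block of 32 weights
d = (w.abs().max() / -8).item()                       # scale, llama.cpp Q4_0 convention (assumed)
q = (w / d + 8.5).clamp(0, 15).to(torch.uint8)        # 4-bit codes in 0..15
packed = q[:16] | (q[16:] << 4)                       # low nibbles: weights 0..15, high: 16..31
block = torch.cat([torch.tensor([d], dtype=torch.float16).view(torch.uint8), packed])
w_hat = dequantize_blocks_Q4_0(block.reshape(1, 18))  # 18 bytes per block; result (1, 32) is close to w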
os.path.join(folder_paths.models_dir, "CogVideoX_Fun", "CogVideoX-Fun-2b-InP") # location of the official model + scheduler_path = os.path.join(script_directory, 'configs', 'scheduler_config_2b.json') if not os.path.exists(base_path): base_path = os.path.join(download_path, "CogVideoX-Fun-2b-InP") elif "5b" in model: base_path = os.path.join(folder_paths.models_dir, "CogVideoX_Fun", "CogVideoX-Fun-5b-InP") # location of the official model + scheduler_path = os.path.join(script_directory, 'configs', 'scheduler_config_5b.json') if not os.path.exists(base_path): base_path = os.path.join(download_path, "CogVideoX-Fun-5b-InP") elif "2b" in model: base_path = os.path.join(folder_paths.models_dir, "CogVideo", "CogVideo2B") + scheduler_path = os.path.join(script_directory, 'configs', 'scheduler_config_2b.json') download_path = base_path repo_id = model elif "5b" in model: base_path = os.path.join(folder_paths.models_dir, "CogVideo", (model.split("/")[-1])) + scheduler_path = os.path.join(script_directory, 'configs', 'scheduler_config_5b.json') download_path = base_path repo_id = model @@ -91,7 +97,8 @@ class DownloadAndLoadCogVideoModel: local_dir=download_path, local_dir_use_symlinks=False, ) - + + # transformer if "Fun" in model: transformer = CogVideoXTransformer3DModelFun.from_pretrained(base_path, subfolder="transformer") else: @@ -114,9 +121,12 @@ class DownloadAndLoadCogVideoModel: if fp8_transformer == "fastmode": from .fp8_optimization import convert_fp8_linear convert_fp8_linear(transformer, dtype) - - scheduler = CogVideoXDDIMScheduler.from_pretrained(base_path, subfolder="scheduler") + with open(scheduler_path) as f: + scheduler_config = json.load(f) + scheduler = CogVideoXDDIMScheduler.from_config(scheduler_config) + + # VAE if "Fun" in model: vae = AutoencoderKLCogVideoXFun.from_pretrained(base_path, subfolder="vae").to(dtype).to(offload_device) pipe = CogVideoX_Fun_Pipeline_Inpaint(vae, transformer, scheduler) @@ -127,6 +137,7 @@ class DownloadAndLoadCogVideoModel: if enable_sequential_cpu_offload: pipe.enable_sequential_cpu_offload() + # compilation if compile == "torch": torch._dynamo.config.suppress_errors = True pipe.transformer.to(memory_format=torch.channels_last) @@ -148,7 +159,149 @@ class DownloadAndLoadCogVideoModel: "dtype": dtype, "base_path": base_path, "onediff": True if compile == "onediff" else False, - "cpu_offloading": enable_sequential_cpu_offload + "cpu_offloading": enable_sequential_cpu_offload, + "scheduler_config": scheduler_config + } + + return (pipeline,) + +class DownloadAndLoadCogVideoGGUFModel: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model": ( + [ + "CogVideoX_5b_GGUF_Q4_0.safetensors", + "CogVideoX_5b_fun_GGUF_Q4_0.safetensors", + ], + ), + "vae_precision": (["fp16", "fp32", "bf16"], {"default": "bf16", "tooltip": "VAE dtype"}), + "fp8_fastmode": ("BOOLEAN", {"default": False, "tooltip": "only supported on 4090 and later GPUs"}), + "compile": (["disabled","onediff","torch"], {"tooltip": "UNTESTED WITH GGUF"}), + }, + } + + RETURN_TYPES = ("COGVIDEOPIPE",) + RETURN_NAMES = ("cogvideo_pipe", ) + FUNCTION = "loadmodel" + CATEGORY = "CogVideoWrapper" + + def loadmodel(self, model, vae_precision, compile, fp8_fastmode): + device = mm.get_torch_device() + offload_device = mm.unet_offload_device() + mm.soft_empty_cache() + + vae_dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[vae_precision] + download_path = os.path.join(folder_paths.models_dir, 'CogVideo', 'GGUF') + gguf_path = 
@@ -148,7 +159,149 @@ class DownloadAndLoadCogVideoModel:
             "dtype": dtype,
             "base_path": base_path,
             "onediff": True if compile == "onediff" else False,
-            "cpu_offloading": enable_sequential_cpu_offload
+            "cpu_offloading": enable_sequential_cpu_offload,
+            "scheduler_config": scheduler_config
         }
 
         return (pipeline,)
 
+class DownloadAndLoadCogVideoGGUFModel:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "model": (
+                    [
+                        "CogVideoX_5b_GGUF_Q4_0.safetensors",
+                        "CogVideoX_5b_fun_GGUF_Q4_0.safetensors",
+                    ],
+                ),
+                "vae_precision": (["fp16", "fp32", "bf16"], {"default": "bf16", "tooltip": "VAE dtype"}),
+                "fp8_fastmode": ("BOOLEAN", {"default": False, "tooltip": "only supported on 4090 and later GPUs"}),
+                "compile": (["disabled", "onediff", "torch"], {"tooltip": "UNTESTED WITH GGUF"}),
+            },
+        }
+
+    RETURN_TYPES = ("COGVIDEOPIPE",)
+    RETURN_NAMES = ("cogvideo_pipe", )
+    FUNCTION = "loadmodel"
+    CATEGORY = "CogVideoWrapper"
+
+    def loadmodel(self, model, vae_precision, compile, fp8_fastmode):
+        device = mm.get_torch_device()
+        offload_device = mm.unet_offload_device()
+        mm.soft_empty_cache()
+
+        vae_dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[vae_precision]
+
+        download_path = os.path.join(folder_paths.models_dir, 'CogVideo', 'GGUF')
+        gguf_path = os.path.join(folder_paths.models_dir, 'diffusion_models', model) # check MinusZone's model path first
+        if not os.path.exists(gguf_path):
+            gguf_path = os.path.join(download_path, model)
+            if not os.path.exists(gguf_path):
+                log.info(f"Downloading model to: {gguf_path}")
+                from huggingface_hub import snapshot_download
+
+                snapshot_download(
+                    repo_id="MinusZoneAI/ComfyUI-CogVideoX-MZ",
+                    allow_patterns=[f"*{model}*"],
+                    local_dir=download_path,
+                    local_dir_use_symlinks=False,
+                )
+
+        with open(os.path.join(script_directory, 'configs', 'transformer_config_5b.json')) as f:
+            transformer_config = json.load(f)
+        sd = load_torch_file(gguf_path)
+
+        from . import mz_gguf_loader
+        import importlib
+        importlib.reload(mz_gguf_loader)
+
+        with mz_gguf_loader.quantize_lazy_load():
+            if "fun" in model:
+                transformer_config["in_channels"] = 33
+                transformer = CogVideoXTransformer3DModelFun.from_config(transformer_config)
+            else:
+                transformer_config["in_channels"] = 16
+                transformer = CogVideoXTransformer3DModel.from_config(transformer_config)
+
+            transformer.to(torch.float8_e4m3fn)
+            transformer = mz_gguf_loader.quantize_load_state_dict(transformer, sd, device="cpu")
+            transformer.to(device)
+
+        # transformer
+        # if fp8_transformer == "fastmode":
+        #     if "2b" in model:
+        #         for name, param in transformer.named_parameters():
+        #             if name != "pos_embedding":
+        #                 param.data = param.data.to(torch.float8_e4m3fn)
+        #     elif "I2V" in model:
+        #         for name, param in transformer.named_parameters():
+        #             if "patch_embed" not in name:
+        #                 param.data = param.data.to(torch.float8_e4m3fn)
+        #     else:
+        #         transformer.to(torch.float8_e4m3fn)
+
+        if fp8_fastmode:
+            from .fp8_optimization import convert_fp8_linear
+            convert_fp8_linear(transformer, vae_dtype)
+
+        scheduler_path = os.path.join(script_directory, 'configs', 'scheduler_config_5b.json')
+        with open(scheduler_path) as f:
+            scheduler_config = json.load(f)
+
+        scheduler = CogVideoXDDIMScheduler.from_config(scheduler_config)
+
+        # VAE
+        vae_dl_path = os.path.join(folder_paths.models_dir, 'CogVideo', 'VAE')
+        vae_path = os.path.join(vae_dl_path, "cogvideox_vae.safetensors")
+        if not os.path.exists(vae_path):
+            log.info(f"Downloading VAE model to: {vae_path}")
+            from huggingface_hub import snapshot_download
+
+            snapshot_download(
+                repo_id="Kijai/CogVideoX-Fun-pruned",
+                allow_patterns=["*cogvideox_vae.safetensors*"],
+                local_dir=vae_dl_path,
+                local_dir_use_symlinks=False,
+            )
+        with open(os.path.join(script_directory, 'configs', 'vae_config.json')) as f:
+            vae_config = json.load(f)
+
+        vae_sd = load_torch_file(vae_path)
+        if "fun" in model:
+            vae = AutoencoderKLCogVideoXFun.from_config(vae_config).to(vae_dtype).to(offload_device)
+            vae.load_state_dict(vae_sd)
+            pipe = CogVideoX_Fun_Pipeline_Inpaint(vae, transformer, scheduler)
+        else:
+            vae = AutoencoderKLCogVideoX.from_config(vae_config).to(vae_dtype).to(offload_device)
+            vae.load_state_dict(vae_sd)
+            pipe = CogVideoXPipeline(vae, transformer, scheduler)
+
+        # compilation
+        if compile == "torch":
+            torch._dynamo.config.suppress_errors = True
+            pipe.transformer.to(memory_format=torch.channels_last)
+            pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
+        elif compile == "onediff":
+            from onediffx import compile_pipe
+            os.environ['NEXFORT_FX_FORCE_TRITON_SDPA'] = '1'
+
+            pipe = compile_pipe(
+                pipe,
+                backend="nexfort",
+                options={"mode": "max-optimize:max-autotune:max-autotune", "memory_format": "channels_last", "options": {"inductor.optimize_linear_epilogue": False, "triton.fuse_attention_allow_fp16_reduction": False}},
+                ignores=["vae"],
+                fuse_qkv_projections=True,
+            )
+
+        pipeline = {
+            "pipe": pipe,
+            "dtype": vae_dtype,
+            "base_path": "Fun" if "fun" in model else "sad",
+            "onediff": True if compile == "onediff" else False,
+            "cpu_offloading": False,
+            "scheduler_config": scheduler_config
+        }
+
+        return (pipeline,)
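Note: quantize_lazy_load() is a thin wrapper around torch's meta device, and the load sequence above follows the standard materialize-then-load pattern. A condensed, self-contained sketch of that flow, with a toy Sequential as a hypothetical stand-in for the transformer:

import torch
import torch.nn as nn

with torch.device("meta"):                 # parameters are allocated without real storage
    model = nn.Sequential(nn.Linear(64, 64), nn.Linear(64, 64))

# In the node, quantize_load_state_dict() swaps matching nn.Linear layers
# for WQLinear_GGUF here via set_op_by_name(), then materializes and loads:
model.to_empty(device="cpu")               # real (uninitialized) tensors on cpu
sd = {k: torch.zeros_like(v, device="cpu") for k, v in model.state_dict().items()}
model.load_state_dict(sd, strict=False)    # strict=False mirrors the node, where quantized key names differ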
False, "triton.fuse_attention_allow_fp16_reduction": False}}, + ignores=["vae"], + fuse_qkv_projections=True, + ) + + pipeline = { + "pipe": pipe, + "dtype": vae_dtype, + "base_path": "Fun" if "fun" in model else "sad", + "onediff": True if compile == "onediff" else False, + "cpu_offloading": False, + "scheduler_config": scheduler_config } return (pipeline,) @@ -613,20 +766,21 @@ class CogVideoXFunSampler: print(f"Closest size: {width}x{height}") # Load Sampler + scheduler_config = pipeline["scheduler_config"] if scheduler == "DPM++": - noise_scheduler = DPMSolverMultistepScheduler.from_pretrained(base_path, subfolder= 'scheduler') + noise_scheduler = DPMSolverMultistepScheduler.from_config(scheduler_config) elif scheduler == "Euler": - noise_scheduler = EulerDiscreteScheduler.from_pretrained(base_path, subfolder= 'scheduler') + noise_scheduler = EulerDiscreteScheduler.from_config(scheduler_config) elif scheduler == "Euler A": - noise_scheduler = EulerAncestralDiscreteScheduler.from_pretrained(base_path, subfolder= 'scheduler') + noise_scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler_config) elif scheduler == "PNDM": - noise_scheduler = PNDMScheduler.from_pretrained(base_path, subfolder= 'scheduler') + noise_scheduler = PNDMScheduler.from_config(scheduler_config) elif scheduler == "DDIM": - noise_scheduler = DDIMScheduler.from_pretrained(base_path, subfolder= 'scheduler') + noise_scheduler = DDIMScheduler.from_config(scheduler_config) elif scheduler == "CogVideoXDDIM": - noise_scheduler = CogVideoXDDIMScheduler.from_pretrained(base_path, subfolder= 'scheduler') + noise_scheduler = CogVideoXDDIMScheduler.from_config(scheduler_config) elif scheduler == "CogVideoXDPMScheduler": - noise_scheduler = CogVideoXDPMScheduler.from_pretrained(base_path, subfolder= 'scheduler') + noise_scheduler = CogVideoXDPMScheduler.from_config(scheduler_config) pipe.scheduler = noise_scheduler #if not pipeline["cpu_offloading"]: @@ -784,7 +938,8 @@ NODE_CLASS_MAPPINGS = { "CogVideoImageEncode": CogVideoImageEncode, "CogVideoXFunSampler": CogVideoXFunSampler, "CogVideoXFunVid2VidSampler": CogVideoXFunVid2VidSampler, - "CogVideoTextEncodeCombine": CogVideoTextEncodeCombine + "CogVideoTextEncodeCombine": CogVideoTextEncodeCombine, + "DownloadAndLoadCogVideoGGUFModel": DownloadAndLoadCogVideoGGUFModel } NODE_DISPLAY_NAME_MAPPINGS = { "DownloadAndLoadCogVideoModel": "(Down)load CogVideo Model", @@ -795,5 +950,6 @@ NODE_DISPLAY_NAME_MAPPINGS = { "CogVideoImageEncode": "CogVideo ImageEncode", "CogVideoXFunSampler": "CogVideoXFun Sampler", "CogVideoXFunVid2VidSampler": "CogVideoXFun Vid2Vid Sampler", - "CogVideoTextEncodeCombine": "CogVideo TextEncode Combine" + "CogVideoTextEncodeCombine": "CogVideo TextEncode Combine", + "DownloadAndLoadCogVideoGGUFModel": "(Down)load CogVideo GGUF Model" }