kijai 2024-09-22 16:07:32 +03:00
parent ffece2db59
commit 21675b296b
3 changed files with 11 additions and 4 deletions

View File

@ -21,7 +21,6 @@ from typing import Callable, Dict, List, Optional, Tuple, Union
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from einops import rearrange from einops import rearrange
from transformers import T5EncoderModel, T5Tokenizer
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel

View File

@ -1,6 +1,4 @@
{ {
"_class_name": "CogVideoXTransformer3DModel",
"_diffusers_version": "0.31.0.dev0",
"activation_fn": "gelu-approximate", "activation_fn": "gelu-approximate",
"attention_bias": true, "attention_bias": true,
"attention_head_dim": 64, "attention_head_dim": 64,

View File

@ -173,6 +173,7 @@ class DownloadAndLoadCogVideoGGUFModel:
"model": ( "model": (
[ [
"CogVideoX_5b_GGUF_Q4_0.safetensors", "CogVideoX_5b_GGUF_Q4_0.safetensors",
"CogVideoX_5b_I2V_GGUF_Q4_0.safetensors",
"CogVideoX_5b_fun_GGUF_Q4_0.safetensors", "CogVideoX_5b_fun_GGUF_Q4_0.safetensors",
], ],
), ),
@ -198,11 +199,15 @@ class DownloadAndLoadCogVideoGGUFModel:
if not os.path.exists(gguf_path): if not os.path.exists(gguf_path):
gguf_path = os.path.join(download_path, model) gguf_path = os.path.join(download_path, model)
if not os.path.exists(gguf_path): if not os.path.exists(gguf_path):
if "I2V" in model:
repo_id = "Kijai/CogVideoX_GGUF"
else:
repo_id = "MinusZoneAI/ComfyUI-CogVideoX-MZ"
log.info(f"Downloading model to: {gguf_path}") log.info(f"Downloading model to: {gguf_path}")
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
snapshot_download( snapshot_download(
repo_id="MinusZoneAI/ComfyUI-CogVideoX-MZ", repo_id=repo_id,
allow_patterns=[f"*{model}*"], allow_patterns=[f"*{model}*"],
local_dir=download_path, local_dir=download_path,
local_dir_use_symlinks=False, local_dir_use_symlinks=False,
@ -212,6 +217,8 @@ class DownloadAndLoadCogVideoGGUFModel:
with open(os.path.join(script_directory, 'configs', 'transformer_config_5b.json')) as f: with open(os.path.join(script_directory, 'configs', 'transformer_config_5b.json')) as f:
transformer_config = json.load(f) transformer_config = json.load(f)
sd = load_torch_file(gguf_path) sd = load_torch_file(gguf_path)
for key, value in sd.items():
print(key, value.shape, value.dtype)
from . import mz_gguf_loader from . import mz_gguf_loader
import importlib import importlib
@ -221,6 +228,9 @@ class DownloadAndLoadCogVideoGGUFModel:
if "fun" in model: if "fun" in model:
transformer_config["in_channels"] = 33 transformer_config["in_channels"] = 33
transformer = CogVideoXTransformer3DModelFun.from_config(transformer_config) transformer = CogVideoXTransformer3DModelFun.from_config(transformer_config)
elif "I2V" in model:
transformer_config["in_channels"] = 32
transformer = CogVideoXTransformer3DModel.from_config(transformer_config)
else: else:
transformer_config["in_channels"] = 16 transformer_config["in_channels"] = 16
transformer = CogVideoXTransformer3DModel.from_config(transformer_config) transformer = CogVideoXTransformer3DModel.from_config(transformer_config)