Mirror of https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git (synced 2026-05-15 20:56:56 +08:00)

Commit d0d7308da5 (parent d4958e4c36): Experimenting with block removal

nodes.py (56 lines changed)
@@ -176,6 +176,33 @@ class CogVideoPABConfig:
         return (pab_config, )
 
 
+def remove_specific_blocks(model, block_indices_to_remove):
+    import torch.nn as nn
+    transformer_blocks = model.transformer_blocks
+    new_blocks = [block for i, block in enumerate(transformer_blocks) if i not in block_indices_to_remove]
+    model.transformer_blocks = nn.ModuleList(new_blocks)
+
+    return model
+
+class CogVideoTransformerEdit:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {
+            "remove_blocks": ("STRING", {"default": "40", "multiline": True, "tooltip": "Comma separated list of block indices to remove, 5b blocks: 0-41, 2b model blocks 0-29"} ),
+            }
+        }
+
+    RETURN_TYPES = ("TRANSFORMERBLOCKS",)
+    RETURN_NAMES = ("block_list", )
+    FUNCTION = "process"
+    CATEGORY = "CogVideoWrapper"
+    DESCRIPTION = "EXPERIMENTAL:Remove specific transformer blocks from the model"
+
+    def process(self, remove_blocks):
+        blocks_to_remove = [int(x.strip()) for x in remove_blocks.split(',')]
+        log.info(f"Blocks selected for removal: {blocks_to_remove}")
+        return (blocks_to_remove,)
+
 class DownloadAndLoadCogVideoModel:
     @classmethod
     def INPUT_TYPES(s):
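For reference, here is a minimal, self-contained sketch (not part of the commit) of the pruning idea behind remove_specific_blocks. The toy blocks are plain nn.Linear layers standing in for CogVideoX transformer blocks; only the ModuleList rebuilding mirrors the new function.

    # Sketch only: rebuild an nn.ModuleList without the unwanted indices.
    import torch.nn as nn

    class ToyTransformer(nn.Module):
        def __init__(self, num_blocks=42):
            super().__init__()
            self.transformer_blocks = nn.ModuleList(nn.Linear(16, 16) for _ in range(num_blocks))

    def remove_specific_blocks(model, block_indices_to_remove):
        # Keep every block whose index is not in the removal list; later blocks shift down.
        kept = [b for i, b in enumerate(model.transformer_blocks) if i not in block_indices_to_remove]
        model.transformer_blocks = nn.ModuleList(kept)
        return model

    model = ToyTransformer()
    print(len(model.transformer_blocks))         # 42, matching the 5b model's 0-41 range
    model = remove_specific_blocks(model, [40])  # the node's default "40"
    print(len(model.transformer_blocks))         # 41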
@@ -201,6 +228,7 @@ class DownloadAndLoadCogVideoModel:
                 "compile": (["disabled","onediff","torch"], {"tooltip": "compile the model for faster inference, these are advanced options only available on Linux, see readme for more info"}),
                 "enable_sequential_cpu_offload": ("BOOLEAN", {"default": False, "tooltip": "significantly reducing memory usage and slows down the inference"}),
                 "pab_config": ("PAB_CONFIG", {"default": None}),
+                "block_edit": ("TRANSFORMERBLOCKS", {"default": None}),
             }
         }
 
@@ -209,7 +237,7 @@ class DownloadAndLoadCogVideoModel:
     FUNCTION = "loadmodel"
     CATEGORY = "CogVideoWrapper"
 
-    def loadmodel(self, model, precision, fp8_transformer="disabled", compile="disabled", enable_sequential_cpu_offload=False, pab_config=None):
+    def loadmodel(self, model, precision, fp8_transformer="disabled", compile="disabled", enable_sequential_cpu_offload=False, pab_config=None, block_edit=None):
 
         check_diffusers_version()
 
@@ -268,6 +296,9 @@ class DownloadAndLoadCogVideoModel:
             transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder="transformer")
 
         transformer = transformer.to(dtype).to(offload_device)
 
+        if block_edit is not None:
+            transformer = remove_specific_blocks(transformer, block_edit)
+
         if fp8_transformer == "enabled" or fp8_transformer == "fastmode":
             if "2b" in model:
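As a usage sketch (illustrative, calling the node classes directly rather than through a ComfyUI graph), the new TRANSFORMERBLOCKS connection carries the parsed index list from CogVideoTransformerEdit into the loader's block_edit input; the loader then prunes the transformer right after the weights are loaded and cast:

    # Illustrative wiring only; in practice these nodes are connected in the ComfyUI graph.
    edit_node = CogVideoTransformerEdit()
    (block_list,) = edit_node.process(remove_blocks="40, 41")   # -> [40, 41]

    # loader = DownloadAndLoadCogVideoModel()
    # pipeline = loader.loadmodel(model=..., precision=..., block_edit=block_list)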
@@ -348,6 +379,7 @@ class DownloadAndLoadCogVideoGGUFModel:
             },
             "optional": {
                 "pab_config": ("PAB_CONFIG", {"default": None}),
+                "block_edit": ("TRANSFORMERBLOCKS", {"default": None}),
             }
         }
 
@@ -356,7 +388,7 @@ class DownloadAndLoadCogVideoGGUFModel:
     FUNCTION = "loadmodel"
     CATEGORY = "CogVideoWrapper"
 
-    def loadmodel(self, model, vae_precision, fp8_fastmode, load_device, enable_sequential_cpu_offload, pab_config=None):
+    def loadmodel(self, model, vae_precision, fp8_fastmode, load_device, enable_sequential_cpu_offload, pab_config=None, block_edit=None):
 
         check_diffusers_version()
 
@@ -430,11 +462,17 @@ class DownloadAndLoadCogVideoGGUFModel:
                    param.data = param.data.to(torch.float16)
        else:
            transformer.to(torch.float8_e4m3fn)
 
+        if block_edit is not None:
+            transformer = remove_specific_blocks(transformer, block_edit)
+
        transformer = mz_gguf_loader.quantize_load_state_dict(transformer, sd, device="cpu")
        if load_device == "offload_device":
            transformer.to(offload_device)
        else:
            transformer.to(device)
+
+
+
        if fp8_fastmode:
            from .fp8_optimization import convert_fp8_linear
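In the GGUF path the pruning runs before quantize_load_state_dict, so the surviving blocks are re-indexed before the quantized weights are matched to them. Whether any key handling is needed depends on how mz_gguf_loader matches names; purely as an illustration of what a name-based remap would look like (the "transformer_blocks.<idx>.<rest>" key layout is an assumption, not verified against this repository):

    # Illustrative only: drop keys for pruned blocks and renumber the rest,
    # assuming keys look like "transformer_blocks.<idx>.<rest>".
    import re

    def remap_pruned_keys(sd, removed_indices, total_blocks):
        removed = set(removed_indices)
        # New position of each surviving block index.
        new_index = {old: new for new, old in
                     enumerate(i for i in range(total_blocks) if i not in removed)}
        pattern = re.compile(r"^(transformer_blocks\.)(\d+)(\..+)$")
        remapped = {}
        for key, value in sd.items():
            m = pattern.match(key)
            if m is None:
                remapped[key] = value            # not a block weight, keep as-is
            elif int(m.group(2)) in removed:
                continue                         # weight of a pruned block, drop it
            else:
                remapped[f"{m.group(1)}{new_index[int(m.group(2))]}{m.group(3)}"] = value
        return remapped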
@@ -731,7 +769,7 @@ class CogVideoImageEncode:
 
        # Concatenate all the chunks along the temporal dimension
        final_latents = torch.cat(latents_list, dim=1)
-       print("final latents: ", final_latents.shape)
+       log.info(f"Encoded latents shape: {final_latents.shape}")
        if not pipeline["cpu_offloading"]:
            vae.to(offload_device)
 
@@ -822,7 +860,6 @@ class CogVideoSampler:
        if not pipeline["cpu_offloading"]:
            pipe.transformer.to(offload_device)
        mm.soft_empty_cache()
-       print(latents.shape)
 
        return (pipeline, {"samples": latents})
 
@@ -964,7 +1001,7 @@ class CogVideoXFunSampler:
        original_height = opt_empty_latent["samples"][0].shape[-2] * 8
        closest_size, closest_ratio = get_closest_ratio(original_height, original_width, ratios=aspect_ratio_sample_size)
        height, width = [int(x / 16) * 16 for x in closest_size]
-       print(f"Closest size: {width}x{height}")
+       log.info(f"Closest bucket size: {width}x{height}")
 
        # Load Sampler
        scheduler_config = pipeline["scheduler_config"]
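The Fun sampler's resolution handling shown above picks the aspect-ratio bucket closest to the requested resolution (derived from the empty latent) and then snaps both sides down to multiples of 16. A simplified, self-contained sketch of that idea (this get_closest_ratio and the bucket table are stand-ins, not the wrapper's actual aspect_ratio_sample_size or implementation):

    # Stand-in for the closest-ratio bucketing; values are illustrative.
    def get_closest_ratio(height, width, ratios):
        target = height / width
        key = min(ratios, key=lambda r: abs(float(r) - target))
        return ratios[key], key

    buckets = {"0.5625": (540, 960), "1.0": (720, 720), "1.7778": (960, 540)}
    closest_size, closest_ratio = get_closest_ratio(544, 960, buckets)   # 544/960 ~= 0.567
    height, width = [int(x / 16) * 16 for x in closest_size]
    print(closest_ratio, height, width)   # 0.5625 528 960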
@@ -1001,7 +1038,6 @@ class CogVideoXFunSampler:
        #if not pipeline["cpu_offloading"]:
        #    pipe.transformer.to(offload_device)
        mm.soft_empty_cache()
-       print(latents.shape)
 
        return (pipeline, {"samples": latents})
 
@@ -1077,8 +1113,6 @@ class CogVideoXFunVid2VidSampler:
        original_width, original_height = Image.fromarray(validation_video[0]).size
        closest_size, closest_ratio = get_closest_ratio(original_height, original_width, ratios=aspect_ratio_sample_size)
        height, width = [int(x / 16) * 16 for x in closest_size]
 
-       base_path = pipeline["base_path"]
-
        # Load Sampler
        scheduler_config = pipeline["scheduler_config"]
@@ -1130,7 +1164,8 @@ NODE_CLASS_MAPPINGS = {
     "CogVideoXFunVid2VidSampler": CogVideoXFunVid2VidSampler,
     "CogVideoTextEncodeCombine": CogVideoTextEncodeCombine,
     "DownloadAndLoadCogVideoGGUFModel": DownloadAndLoadCogVideoGGUFModel,
-    "CogVideoPABConfig": CogVideoPABConfig
+    "CogVideoPABConfig": CogVideoPABConfig,
+    "CogVideoTransformerEdit": CogVideoTransformerEdit
     }
 NODE_DISPLAY_NAME_MAPPINGS = {
     "DownloadAndLoadCogVideoModel": "(Down)load CogVideo Model",
@@ -1143,5 +1178,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
     "CogVideoXFunVid2VidSampler": "CogVideoXFun Vid2Vid Sampler",
     "CogVideoTextEncodeCombine": "CogVideo TextEncode Combine",
     "DownloadAndLoadCogVideoGGUFModel": "(Down)load CogVideo GGUF Model",
-    "CogVideoPABConfig": "CogVideo PABConfig"
+    "CogVideoPABConfig": "CogVideo PABConfig",
+    "CogVideoTransformerEdit": "CogVideo TransformerEdit"
     }