Experimenting with block removal

kijai 2024-09-23 14:17:34 +03:00
parent d4958e4c36
commit d0d7308da5


@@ -176,6 +176,33 @@ class CogVideoPABConfig:
         return (pab_config, )
+
+def remove_specific_blocks(model, block_indices_to_remove):
+    import torch.nn as nn
+    transformer_blocks = model.transformer_blocks
+    new_blocks = [block for i, block in enumerate(transformer_blocks) if i not in block_indices_to_remove]
+    model.transformer_blocks = nn.ModuleList(new_blocks)
+    return model
+
+class CogVideoTransformerEdit:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {
+            "remove_blocks": ("STRING", {"default": "40", "multiline": True, "tooltip": "Comma-separated list of block indices to remove. 5b model blocks: 0-41, 2b model blocks: 0-29"} ),
+            }
+        }
+    RETURN_TYPES = ("TRANSFORMERBLOCKS",)
+    RETURN_NAMES = ("block_list", )
+    FUNCTION = "process"
+    CATEGORY = "CogVideoWrapper"
+    DESCRIPTION = "EXPERIMENTAL: Remove specific transformer blocks from the model"
+
+    def process(self, remove_blocks):
+        blocks_to_remove = [int(x.strip()) for x in remove_blocks.split(',')]
+        log.info(f"Blocks selected for removal: {blocks_to_remove}")
+        return (blocks_to_remove,)
+
 class DownloadAndLoadCogVideoModel:
     @classmethod
     def INPUT_TYPES(s):
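For intuition: remove_specific_blocks simply rebuilds transformer_blocks as a shorter nn.ModuleList, and CogVideoTransformerEdit turns the comma-separated string into the index list that helper consumes. A minimal sketch with a toy model (ToyModel and its Linear layers are stand-ins for the real CogVideoX blocks, not part of this repo):

    import torch.nn as nn

    class ToyModel(nn.Module):
        def __init__(self, depth=42):
            super().__init__()
            # stand-in for CogVideoX's stack of transformer blocks
            self.transformer_blocks = nn.ModuleList(nn.Linear(8, 8) for _ in range(depth))

    model = remove_specific_blocks(ToyModel(), [40, 41])
    print(len(model.transformer_blocks))  # 40 -- the last two blocks are gone

    # and the node's string parsing, for reference:
    blocks_to_remove = [int(x.strip()) for x in "38, 39, 40, 41".split(',')]
    # -> [38, 39, 40, 41]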
@@ -201,6 +228,7 @@ class DownloadAndLoadCogVideoModel:
                 "compile": (["disabled","onediff","torch"], {"tooltip": "Compile the model for faster inference. Advanced option, only available on Linux; see the readme for more info"}),
                 "enable_sequential_cpu_offload": ("BOOLEAN", {"default": False, "tooltip": "Significantly reduces memory usage, but slows down inference"}),
                 "pab_config": ("PAB_CONFIG", {"default": None}),
+                "block_edit": ("TRANSFORMERBLOCKS", {"default": None}),
             }
         }
@@ -209,7 +237,7 @@ class DownloadAndLoadCogVideoModel:
     FUNCTION = "loadmodel"
     CATEGORY = "CogVideoWrapper"

-    def loadmodel(self, model, precision, fp8_transformer="disabled", compile="disabled", enable_sequential_cpu_offload=False, pab_config=None):
+    def loadmodel(self, model, precision, fp8_transformer="disabled", compile="disabled", enable_sequential_cpu_offload=False, pab_config=None, block_edit=None):
         check_diffusers_version()
@@ -268,6 +296,9 @@ class DownloadAndLoadCogVideoModel:
         transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder="transformer")
         transformer = transformer.to(dtype).to(offload_device)
+
+        if block_edit is not None:
+            transformer = remove_specific_blocks(transformer, block_edit)

         if fp8_transformer == "enabled" or fp8_transformer == "fastmode":
             if "2b" in model:
@@ -348,6 +379,7 @@ class DownloadAndLoadCogVideoGGUFModel:
             },
             "optional": {
                 "pab_config": ("PAB_CONFIG", {"default": None}),
+                "block_edit": ("TRANSFORMERBLOCKS", {"default": None}),
             }
         }
@@ -356,7 +388,7 @@ class DownloadAndLoadCogVideoGGUFModel:
     FUNCTION = "loadmodel"
     CATEGORY = "CogVideoWrapper"

-    def loadmodel(self, model, vae_precision, fp8_fastmode, load_device, enable_sequential_cpu_offload, pab_config=None):
+    def loadmodel(self, model, vae_precision, fp8_fastmode, load_device, enable_sequential_cpu_offload, pab_config=None, block_edit=None):
         check_diffusers_version()
@@ -430,11 +462,17 @@ class DownloadAndLoadCogVideoGGUFModel:
                     param.data = param.data.to(torch.float16)
             else:
                 transformer.to(torch.float8_e4m3fn)
+
+        if block_edit is not None:
+            transformer = remove_specific_blocks(transformer, block_edit)

         transformer = mz_gguf_loader.quantize_load_state_dict(transformer, sd, device="cpu")

         if load_device == "offload_device":
             transformer.to(offload_device)
         else:
             transformer.to(device)

         if fp8_fastmode:
             from .fp8_optimization import convert_fp8_linear
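One subtlety in this GGUF path: pruning happens before quantize_load_state_dict, and nn.ModuleList renumbers its children when rebuilt, so any index-keyed entries in sd shift relative to the pruned module. Whether that matters depends on how the loader matches keys; a toy demonstration of the renumbering itself, unrelated to mz_gguf_loader:

    import torch.nn as nn

    blocks = nn.ModuleList(nn.Linear(2, 2) for _ in range(4))
    kept = nn.ModuleList(b for i, b in enumerate(blocks) if i != 1)  # drop index 1
    print(kept[1] is blocks[2])  # True: old block 2 now answers to index 1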
@@ -731,7 +769,7 @@ class CogVideoImageEncode:
         # Concatenate all the chunks along the temporal dimension
         final_latents = torch.cat(latents_list, dim=1)
-        print("final latents: ", final_latents.shape)
+        log.info(f"Encoded latents shape: {final_latents.shape}")

         if not pipeline["cpu_offloading"]:
             vae.to(offload_device)
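For context on the concatenation above: the chunked VAE encode produces latents shaped (batch, frames, channels, height, width) here, so dim=1 is the temporal axis. A toy illustration (the shapes are made-up examples, not CogVideoX's actual sizes):

    import torch

    chunks = [torch.zeros(1, 4, 16, 60, 90) for _ in range(3)]
    final_latents = torch.cat(chunks, dim=1)
    print(final_latents.shape)  # torch.Size([1, 12, 16, 60, 90])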
@@ -822,7 +860,6 @@ class CogVideoSampler:
         if not pipeline["cpu_offloading"]:
             pipe.transformer.to(offload_device)
         mm.soft_empty_cache()
-        print(latents.shape)

         return (pipeline, {"samples": latents})
@@ -964,7 +1001,7 @@ class CogVideoXFunSampler:
         original_height = opt_empty_latent["samples"][0].shape[-2] * 8
         closest_size, closest_ratio = get_closest_ratio(original_height, original_width, ratios=aspect_ratio_sample_size)
         height, width = [int(x / 16) * 16 for x in closest_size]
-        print(f"Closest size: {width}x{height}")
+        log.info(f"Closest bucket size: {width}x{height}")

         # Load Sampler
         scheduler_config = pipeline["scheduler_config"]
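The int(x / 16) * 16 above floors each bucket dimension to a multiple of 16 before sampling. A quick worked example (the bucket values are hypothetical):

    closest_size = (714, 1270)  # hypothetical bucket from get_closest_ratio
    height, width = [int(x / 16) * 16 for x in closest_size]
    print(height, width)        # 704 1264 -- both multiples of 16, rounded down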
@@ -1001,7 +1038,6 @@ class CogVideoXFunSampler:
         #if not pipeline["cpu_offloading"]:
         #    pipe.transformer.to(offload_device)
         mm.soft_empty_cache()
-        print(latents.shape)

         return (pipeline, {"samples": latents})
@@ -1077,8 +1113,6 @@ class CogVideoXFunVid2VidSampler:
         original_width, original_height = Image.fromarray(validation_video[0]).size
         closest_size, closest_ratio = get_closest_ratio(original_height, original_width, ratios=aspect_ratio_sample_size)
         height, width = [int(x / 16) * 16 for x in closest_size]
-        base_path = pipeline["base_path"]

         # Load Sampler
         scheduler_config = pipeline["scheduler_config"]
@@ -1130,7 +1164,8 @@ NODE_CLASS_MAPPINGS = {
     "CogVideoXFunVid2VidSampler": CogVideoXFunVid2VidSampler,
     "CogVideoTextEncodeCombine": CogVideoTextEncodeCombine,
     "DownloadAndLoadCogVideoGGUFModel": DownloadAndLoadCogVideoGGUFModel,
-    "CogVideoPABConfig": CogVideoPABConfig
+    "CogVideoPABConfig": CogVideoPABConfig,
+    "CogVideoTransformerEdit": CogVideoTransformerEdit
     }

 NODE_DISPLAY_NAME_MAPPINGS = {
     "DownloadAndLoadCogVideoModel": "(Down)load CogVideo Model",
@@ -1143,5 +1178,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
     "CogVideoXFunVid2VidSampler": "CogVideoXFun Vid2Vid Sampler",
     "CogVideoTextEncodeCombine": "CogVideo TextEncode Combine",
     "DownloadAndLoadCogVideoGGUFModel": "(Down)load CogVideo GGUF Model",
-    "CogVideoPABConfig": "CogVideo PABConfig"
+    "CogVideoPABConfig": "CogVideo PABConfig",
+    "CogVideoTransformerEdit": "CogVideo TransformerEdit"
     }