mirror of
https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
synced 2025-12-09 04:44:22 +08:00
use diffusers LoRA loading to support fusing for DimensionX LoRAs
https://github.com/wenqsun/DimensionX
This commit is contained in:
parent
07defb52b6
commit
1cc6e1f070
1
.gitignore
vendored
1
.gitignore
vendored
@ -7,3 +7,4 @@ master_ip
|
|||||||
logs/
|
logs/
|
||||||
*.DS_Store
|
*.DS_Store
|
||||||
.idea
|
.idea
|
||||||
|
*.pt
|
||||||
@ -31,6 +31,7 @@ from diffusers.models.embeddings import CogVideoXPatchEmbed, TimestepEmbedding,
|
|||||||
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
||||||
from diffusers.models.modeling_utils import ModelMixin
|
from diffusers.models.modeling_utils import ModelMixin
|
||||||
from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero
|
from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero
|
||||||
|
from diffusers.loaders import PeftAdapterMixin
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
@ -361,7 +362,7 @@ class CogVideoXBlock(nn.Module):
|
|||||||
return hidden_states, encoder_hidden_states
|
return hidden_states, encoder_hidden_states
|
||||||
|
|
||||||
|
|
||||||
class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
|
class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||||
"""
|
"""
|
||||||
A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
|
A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
|
||||||
|
|
||||||
|
|||||||
@ -25,6 +25,42 @@ from comfy.utils import load_torch_file
|
|||||||
|
|
||||||
script_directory = os.path.dirname(os.path.abspath(__file__))
|
script_directory = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
|
||||||
|
class CogVideoLoraSelect:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"lora": (folder_paths.get_filename_list("cogvideox_loras"),
|
||||||
|
{"tooltip": "LORA models are expected to be in ComfyUI/models/CogVideo/loras with .safetensors extension"}),
|
||||||
|
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "LORA strength, set to 0.0 to unmerge the LORA"}),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"prev_lora":("COGLORA", {"default": None, "tooltip": "For loading multiple LoRAs"}),
|
||||||
|
"fuse_lora": ("BOOLEAN", {"default": False, "tooltip": "Fuse the LoRA weights into the transformer"}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("COGLORA",)
|
||||||
|
RETURN_NAMES = ("lora", )
|
||||||
|
FUNCTION = "getlorapath"
|
||||||
|
CATEGORY = "CogVideoWrapper"
|
||||||
|
|
||||||
|
def getlorapath(self, lora, strength, prev_lora=None, fuse_lora=False):
|
||||||
|
cog_loras_list = []
|
||||||
|
|
||||||
|
cog_lora = {
|
||||||
|
"path": folder_paths.get_full_path("cogvideox_loras", lora),
|
||||||
|
"strength": strength,
|
||||||
|
"name": lora.split(".")[0],
|
||||||
|
"fuse_lora": fuse_lora
|
||||||
|
}
|
||||||
|
if prev_lora is not None:
|
||||||
|
cog_loras_list.extend(prev_lora)
|
||||||
|
|
||||||
|
cog_loras_list.append(cog_lora)
|
||||||
|
print(cog_loras_list)
|
||||||
|
return (cog_loras_list,)
|
||||||
|
|
||||||
class DownloadAndLoadCogVideoModel:
|
class DownloadAndLoadCogVideoModel:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
@ -143,17 +179,6 @@ class DownloadAndLoadCogVideoModel:
|
|||||||
|
|
||||||
transformer = transformer.to(dtype).to(offload_device)
|
transformer = transformer.to(dtype).to(offload_device)
|
||||||
|
|
||||||
#LoRAs
|
|
||||||
if lora is not None:
|
|
||||||
from .lora_utils import merge_lora, load_lora_into_transformer
|
|
||||||
if "fun" in model.lower():
|
|
||||||
for l in lora:
|
|
||||||
log.info(f"Merging LoRA weights from {l['path']} with strength {l['strength']}")
|
|
||||||
transformer = merge_lora(transformer, l["path"], l["strength"])
|
|
||||||
else:
|
|
||||||
transformer = load_lora_into_transformer(lora, transformer)
|
|
||||||
|
|
||||||
|
|
||||||
if block_edit is not None:
|
if block_edit is not None:
|
||||||
transformer = remove_specific_blocks(transformer, block_edit)
|
transformer = remove_specific_blocks(transformer, block_edit)
|
||||||
|
|
||||||
@ -185,6 +210,39 @@ class DownloadAndLoadCogVideoModel:
|
|||||||
if "cogvideox-2b-img2vid" in model:
|
if "cogvideox-2b-img2vid" in model:
|
||||||
pipe.input_with_padding = False
|
pipe.input_with_padding = False
|
||||||
|
|
||||||
|
#LoRAs
|
||||||
|
if lora is not None:
|
||||||
|
from .lora_utils import merge_lora#, load_lora_into_transformer
|
||||||
|
if "fun" in model.lower():
|
||||||
|
for l in lora:
|
||||||
|
log.info(f"Merging LoRA weights from {l['path']} with strength {l['strength']}")
|
||||||
|
transformer = merge_lora(transformer, l["path"], l["strength"])
|
||||||
|
else:
|
||||||
|
adapter_list = []
|
||||||
|
adapter_weights = []
|
||||||
|
for l in lora:
|
||||||
|
if l["fuse_lora"]:
|
||||||
|
fuse = True
|
||||||
|
lora_sd = load_torch_file(l["path"])
|
||||||
|
for key, val in lora_sd.items():
|
||||||
|
if "lora_B" in key:
|
||||||
|
lora_rank = val.shape[1]
|
||||||
|
break
|
||||||
|
log.info(f"Merging rank {lora_rank} LoRA weights from {l['path']} with strength {l['strength']}")
|
||||||
|
adapter_name = l['path'].split("/")[-1].split(".")[0]
|
||||||
|
adapter_weight = l['strength']
|
||||||
|
pipe.load_lora_weights(l['path'], weight_name=l['path'].split("/")[-1], lora_rank=lora_rank, adapter_name=adapter_name)
|
||||||
|
|
||||||
|
#transformer = load_lora_into_transformer(lora, transformer)
|
||||||
|
adapter_list.append(adapter_name)
|
||||||
|
adapter_weights.append(adapter_weight)
|
||||||
|
for l in lora:
|
||||||
|
pipe.set_adapters(adapter_list, adapter_weights=adapter_weights)
|
||||||
|
if fuse:
|
||||||
|
pipe.fuse_lora(lora_scale=1 / lora_rank, components=["transformer"])
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if enable_sequential_cpu_offload:
|
if enable_sequential_cpu_offload:
|
||||||
pipe.enable_sequential_cpu_offload()
|
pipe.enable_sequential_cpu_offload()
|
||||||
|
|
||||||
@ -567,10 +625,12 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"DownloadAndLoadCogVideoGGUFModel": DownloadAndLoadCogVideoGGUFModel,
|
"DownloadAndLoadCogVideoGGUFModel": DownloadAndLoadCogVideoGGUFModel,
|
||||||
"DownloadAndLoadCogVideoControlNet": DownloadAndLoadCogVideoControlNet,
|
"DownloadAndLoadCogVideoControlNet": DownloadAndLoadCogVideoControlNet,
|
||||||
"DownloadAndLoadToraModel": DownloadAndLoadToraModel,
|
"DownloadAndLoadToraModel": DownloadAndLoadToraModel,
|
||||||
|
"CogVideoLoraSelect": CogVideoLoraSelect,
|
||||||
}
|
}
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
"DownloadAndLoadCogVideoModel": "(Down)load CogVideo Model",
|
"DownloadAndLoadCogVideoModel": "(Down)load CogVideo Model",
|
||||||
"DownloadAndLoadCogVideoGGUFModel": "(Down)load CogVideo GGUF Model",
|
"DownloadAndLoadCogVideoGGUFModel": "(Down)load CogVideo GGUF Model",
|
||||||
"DownloadAndLoadCogVideoControlNet": "(Down)load CogVideo ControlNet",
|
"DownloadAndLoadCogVideoControlNet": "(Down)load CogVideo ControlNet",
|
||||||
"DownloadAndLoadToraModel": "(Down)load Tora Model",
|
"DownloadAndLoadToraModel": "(Down)load Tora Model",
|
||||||
|
"CogVideoLoraSelect": "CogVideo LoraSelect",
|
||||||
}
|
}
|
||||||
36
nodes.py
36
nodes.py
@ -148,40 +148,6 @@ class CogVideoTransformerEdit:
|
|||||||
log.info(f"Blocks selected for removal: {blocks_to_remove}")
|
log.info(f"Blocks selected for removal: {blocks_to_remove}")
|
||||||
return (blocks_to_remove,)
|
return (blocks_to_remove,)
|
||||||
|
|
||||||
class CogVideoLoraSelect:
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(s):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"lora": (folder_paths.get_filename_list("cogvideox_loras"),
|
|
||||||
{"tooltip": "LORA models are expected to be in ComfyUI/models/CogVideo/loras with .safetensors extension"}),
|
|
||||||
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "LORA strength, set to 0.0 to unmerge the LORA"}),
|
|
||||||
},
|
|
||||||
"optional": {
|
|
||||||
"prev_lora":("COGLORA", {"default": None, "tooltip": "For loading multiple LoRAs"}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = ("COGLORA",)
|
|
||||||
RETURN_NAMES = ("lora", )
|
|
||||||
FUNCTION = "getlorapath"
|
|
||||||
CATEGORY = "CogVideoWrapper"
|
|
||||||
|
|
||||||
def getlorapath(self, lora, strength, prev_lora=None):
|
|
||||||
cog_loras_list = []
|
|
||||||
|
|
||||||
cog_lora = {
|
|
||||||
"path": folder_paths.get_full_path("cogvideox_loras", lora),
|
|
||||||
"strength": strength,
|
|
||||||
"name": lora.split(".")[0],
|
|
||||||
}
|
|
||||||
if prev_lora is not None:
|
|
||||||
cog_loras_list.extend(prev_lora)
|
|
||||||
|
|
||||||
cog_loras_list.append(cog_lora)
|
|
||||||
print(cog_loras_list)
|
|
||||||
return (cog_loras_list,)
|
|
||||||
|
|
||||||
class CogVideoXTorchCompileSettings:
|
class CogVideoXTorchCompileSettings:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
@ -1399,7 +1365,6 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"CogVideoPABConfig": CogVideoPABConfig,
|
"CogVideoPABConfig": CogVideoPABConfig,
|
||||||
"CogVideoTransformerEdit": CogVideoTransformerEdit,
|
"CogVideoTransformerEdit": CogVideoTransformerEdit,
|
||||||
"CogVideoControlImageEncode": CogVideoControlImageEncode,
|
"CogVideoControlImageEncode": CogVideoControlImageEncode,
|
||||||
"CogVideoLoraSelect": CogVideoLoraSelect,
|
|
||||||
"CogVideoContextOptions": CogVideoContextOptions,
|
"CogVideoContextOptions": CogVideoContextOptions,
|
||||||
"CogVideoControlNet": CogVideoControlNet,
|
"CogVideoControlNet": CogVideoControlNet,
|
||||||
"ToraEncodeTrajectory": ToraEncodeTrajectory,
|
"ToraEncodeTrajectory": ToraEncodeTrajectory,
|
||||||
@ -1423,7 +1388,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
|||||||
"CogVideoPABConfig": "CogVideo PABConfig",
|
"CogVideoPABConfig": "CogVideo PABConfig",
|
||||||
"CogVideoTransformerEdit": "CogVideo TransformerEdit",
|
"CogVideoTransformerEdit": "CogVideo TransformerEdit",
|
||||||
"CogVideoControlImageEncode": "CogVideo Control ImageEncode",
|
"CogVideoControlImageEncode": "CogVideo Control ImageEncode",
|
||||||
"CogVideoLoraSelect": "CogVideo LoraSelect",
|
|
||||||
"CogVideoContextOptions": "CogVideo Context Options",
|
"CogVideoContextOptions": "CogVideo Context Options",
|
||||||
"ToraEncodeTrajectory": "Tora Encode Trajectory",
|
"ToraEncodeTrajectory": "Tora Encode Trajectory",
|
||||||
"ToraEncodeOpticalFlow": "Tora Encode OpticalFlow",
|
"ToraEncodeOpticalFlow": "Tora Encode OpticalFlow",
|
||||||
|
|||||||
@ -26,6 +26,7 @@ from diffusers.utils import logging
|
|||||||
from diffusers.utils.torch_utils import randn_tensor
|
from diffusers.utils.torch_utils import randn_tensor
|
||||||
from diffusers.video_processor import VideoProcessor
|
from diffusers.video_processor import VideoProcessor
|
||||||
from diffusers.models.embeddings import get_3d_rotary_pos_embed
|
from diffusers.models.embeddings import get_3d_rotary_pos_embed
|
||||||
|
from diffusers.loaders import CogVideoXLoraLoaderMixin
|
||||||
|
|
||||||
from .custom_cogvideox_transformer_3d import CogVideoXTransformer3DModel
|
from .custom_cogvideox_transformer_3d import CogVideoXTransformer3DModel
|
||||||
|
|
||||||
@ -113,7 +114,7 @@ def retrieve_timesteps(
|
|||||||
timesteps = scheduler.timesteps
|
timesteps = scheduler.timesteps
|
||||||
return timesteps, num_inference_steps
|
return timesteps, num_inference_steps
|
||||||
|
|
||||||
class CogVideoXPipeline(VideoSysPipeline):
|
class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
|
||||||
r"""
|
r"""
|
||||||
Pipeline for text-to-video generation using CogVideoX.
|
Pipeline for text-to-video generation using CogVideoX.
|
||||||
|
|
||||||
|
|||||||
2
utils.py
2
utils.py
@ -7,7 +7,7 @@ log = logging.getLogger(__name__)
|
|||||||
def check_diffusers_version():
|
def check_diffusers_version():
|
||||||
try:
|
try:
|
||||||
version = importlib.metadata.version('diffusers')
|
version = importlib.metadata.version('diffusers')
|
||||||
required_version = '0.30.3'
|
required_version = '0.31.0'
|
||||||
if version < required_version:
|
if version < required_version:
|
||||||
raise AssertionError(f"diffusers version {version} is installed, but version {required_version} or higher is required.")
|
raise AssertionError(f"diffusers version {version} is installed, but version {required_version} or higher is required.")
|
||||||
except importlib.metadata.PackageNotFoundError:
|
except importlib.metadata.PackageNotFoundError:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user