use diffusers LoRA loading to support fusing for DimensionX LoRAs

https://github.com/wenqsun/DimensionX
This commit is contained in:
kijai 2024-11-08 14:24:32 +02:00
parent 07defb52b6
commit 1cc6e1f070
6 changed files with 80 additions and 53 deletions

1
.gitignore vendored
View File

@ -7,3 +7,4 @@ master_ip
logs/
*.DS_Store
.idea
*.pt

View File

@ -31,6 +31,7 @@ from diffusers.models.embeddings import CogVideoXPatchEmbed, TimestepEmbedding,
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero
from diffusers.loaders import PeftAdapterMixin
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@ -361,7 +362,7 @@ class CogVideoXBlock(nn.Module):
return hidden_states, encoder_hidden_states
class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
"""
A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).

View File

@ -25,6 +25,42 @@ from comfy.utils import load_torch_file
script_directory = os.path.dirname(os.path.abspath(__file__))
class CogVideoLoraSelect:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"lora": (folder_paths.get_filename_list("cogvideox_loras"),
{"tooltip": "LORA models are expected to be in ComfyUI/models/CogVideo/loras with .safetensors extension"}),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "LORA strength, set to 0.0 to unmerge the LORA"}),
},
"optional": {
"prev_lora":("COGLORA", {"default": None, "tooltip": "For loading multiple LoRAs"}),
"fuse_lora": ("BOOLEAN", {"default": False, "tooltip": "Fuse the LoRA weights into the transformer"}),
}
}
RETURN_TYPES = ("COGLORA",)
RETURN_NAMES = ("lora", )
FUNCTION = "getlorapath"
CATEGORY = "CogVideoWrapper"
def getlorapath(self, lora, strength, prev_lora=None, fuse_lora=False):
cog_loras_list = []
cog_lora = {
"path": folder_paths.get_full_path("cogvideox_loras", lora),
"strength": strength,
"name": lora.split(".")[0],
"fuse_lora": fuse_lora
}
if prev_lora is not None:
cog_loras_list.extend(prev_lora)
cog_loras_list.append(cog_lora)
print(cog_loras_list)
return (cog_loras_list,)
class DownloadAndLoadCogVideoModel:
@classmethod
def INPUT_TYPES(s):
@ -143,17 +179,6 @@ class DownloadAndLoadCogVideoModel:
transformer = transformer.to(dtype).to(offload_device)
#LoRAs
if lora is not None:
from .lora_utils import merge_lora, load_lora_into_transformer
if "fun" in model.lower():
for l in lora:
log.info(f"Merging LoRA weights from {l['path']} with strength {l['strength']}")
transformer = merge_lora(transformer, l["path"], l["strength"])
else:
transformer = load_lora_into_transformer(lora, transformer)
if block_edit is not None:
transformer = remove_specific_blocks(transformer, block_edit)
@ -185,6 +210,39 @@ class DownloadAndLoadCogVideoModel:
if "cogvideox-2b-img2vid" in model:
pipe.input_with_padding = False
#LoRAs
if lora is not None:
from .lora_utils import merge_lora#, load_lora_into_transformer
if "fun" in model.lower():
for l in lora:
log.info(f"Merging LoRA weights from {l['path']} with strength {l['strength']}")
transformer = merge_lora(transformer, l["path"], l["strength"])
else:
adapter_list = []
adapter_weights = []
for l in lora:
if l["fuse_lora"]:
fuse = True
lora_sd = load_torch_file(l["path"])
for key, val in lora_sd.items():
if "lora_B" in key:
lora_rank = val.shape[1]
break
log.info(f"Merging rank {lora_rank} LoRA weights from {l['path']} with strength {l['strength']}")
adapter_name = l['path'].split("/")[-1].split(".")[0]
adapter_weight = l['strength']
pipe.load_lora_weights(l['path'], weight_name=l['path'].split("/")[-1], lora_rank=lora_rank, adapter_name=adapter_name)
#transformer = load_lora_into_transformer(lora, transformer)
adapter_list.append(adapter_name)
adapter_weights.append(adapter_weight)
for l in lora:
pipe.set_adapters(adapter_list, adapter_weights=adapter_weights)
if fuse:
pipe.fuse_lora(lora_scale=1 / lora_rank, components=["transformer"])
if enable_sequential_cpu_offload:
pipe.enable_sequential_cpu_offload()
@ -567,10 +625,12 @@ NODE_CLASS_MAPPINGS = {
"DownloadAndLoadCogVideoGGUFModel": DownloadAndLoadCogVideoGGUFModel,
"DownloadAndLoadCogVideoControlNet": DownloadAndLoadCogVideoControlNet,
"DownloadAndLoadToraModel": DownloadAndLoadToraModel,
"CogVideoLoraSelect": CogVideoLoraSelect,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"DownloadAndLoadCogVideoModel": "(Down)load CogVideo Model",
"DownloadAndLoadCogVideoGGUFModel": "(Down)load CogVideo GGUF Model",
"DownloadAndLoadCogVideoControlNet": "(Down)load CogVideo ControlNet",
"DownloadAndLoadToraModel": "(Down)load Tora Model",
"CogVideoLoraSelect": "CogVideo LoraSelect",
}

View File

@ -148,40 +148,6 @@ class CogVideoTransformerEdit:
log.info(f"Blocks selected for removal: {blocks_to_remove}")
return (blocks_to_remove,)
class CogVideoLoraSelect:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"lora": (folder_paths.get_filename_list("cogvideox_loras"),
{"tooltip": "LORA models are expected to be in ComfyUI/models/CogVideo/loras with .safetensors extension"}),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "LORA strength, set to 0.0 to unmerge the LORA"}),
},
"optional": {
"prev_lora":("COGLORA", {"default": None, "tooltip": "For loading multiple LoRAs"}),
}
}
RETURN_TYPES = ("COGLORA",)
RETURN_NAMES = ("lora", )
FUNCTION = "getlorapath"
CATEGORY = "CogVideoWrapper"
def getlorapath(self, lora, strength, prev_lora=None):
cog_loras_list = []
cog_lora = {
"path": folder_paths.get_full_path("cogvideox_loras", lora),
"strength": strength,
"name": lora.split(".")[0],
}
if prev_lora is not None:
cog_loras_list.extend(prev_lora)
cog_loras_list.append(cog_lora)
print(cog_loras_list)
return (cog_loras_list,)
class CogVideoXTorchCompileSettings:
@classmethod
def INPUT_TYPES(s):
@ -1399,7 +1365,6 @@ NODE_CLASS_MAPPINGS = {
"CogVideoPABConfig": CogVideoPABConfig,
"CogVideoTransformerEdit": CogVideoTransformerEdit,
"CogVideoControlImageEncode": CogVideoControlImageEncode,
"CogVideoLoraSelect": CogVideoLoraSelect,
"CogVideoContextOptions": CogVideoContextOptions,
"CogVideoControlNet": CogVideoControlNet,
"ToraEncodeTrajectory": ToraEncodeTrajectory,
@ -1423,7 +1388,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"CogVideoPABConfig": "CogVideo PABConfig",
"CogVideoTransformerEdit": "CogVideo TransformerEdit",
"CogVideoControlImageEncode": "CogVideo Control ImageEncode",
"CogVideoLoraSelect": "CogVideo LoraSelect",
"CogVideoContextOptions": "CogVideo Context Options",
"ToraEncodeTrajectory": "Tora Encode Trajectory",
"ToraEncodeOpticalFlow": "Tora Encode OpticalFlow",

View File

@ -26,6 +26,7 @@ from diffusers.utils import logging
from diffusers.utils.torch_utils import randn_tensor
from diffusers.video_processor import VideoProcessor
from diffusers.models.embeddings import get_3d_rotary_pos_embed
from diffusers.loaders import CogVideoXLoraLoaderMixin
from .custom_cogvideox_transformer_3d import CogVideoXTransformer3DModel
@ -113,7 +114,7 @@ def retrieve_timesteps(
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class CogVideoXPipeline(VideoSysPipeline):
class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
r"""
Pipeline for text-to-video generation using CogVideoX.

View File

@ -7,7 +7,7 @@ log = logging.getLogger(__name__)
def check_diffusers_version():
try:
version = importlib.metadata.version('diffusers')
required_version = '0.30.3'
required_version = '0.31.0'
if version < required_version:
raise AssertionError(f"diffusers version {version} is installed, but version {required_version} or higher is required.")
except importlib.metadata.PackageNotFoundError: