diff --git a/custom_cogvideox_transformer_3d.py b/custom_cogvideox_transformer_3d.py
index 21784f6..98f2880 100644
--- a/custom_cogvideox_transformer_3d.py
+++ b/custom_cogvideox_transformer_3d.py
@@ -19,10 +19,6 @@
 import torch
 from torch import nn
 import torch.nn.functional as F
-import os
-import json
-import glob
-
 import numpy as np
 from einops import rearrange
@@ -416,7 +412,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         patch_bias: bool = True,
         is_train_face: bool = False,
         is_kps: bool = False,
-        cross_attn_interval: int = 1,
+        cross_attn_interval: int = 2,
         LFE_num_tokens: int = 32,
         LFE_output_dim: int = 768,
         LFE_heads: int = 12,
@@ -519,7 +515,6 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         self.LFE_output_dim = LFE_output_dim
         self.LFE_heads = LFE_heads
         self.LFE_final_output_dim = int(self.inner_dim / 3 * 2)
-        self.local_face_scale = local_face_scale

         self._init_face_inputs()
@@ -534,137 +529,6 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         self.perceiver_cross_attention = nn.ModuleList([
             PerceiverCrossAttention(dim=self.inner_dim, dim_head=128, heads=16, kv_dim=self.LFE_final_output_dim).to(device, dtype=weight_dtype) for _ in range(self.num_ca)
         ])
-    @classmethod
-    def from_pretrained_cus(cls, pretrained_model_path, subfolder=None, config_path=None, transformer_additional_kwargs={}):
-        if subfolder:
-            config_path = config_path or pretrained_model_path
-            config_file = os.path.join(config_path, subfolder, 'config.json')
-            pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
-        else:
-            config_file = os.path.join(config_path or pretrained_model_path, 'config.json')
-
-        print(f"Loading 3D transformer's pretrained weights from {pretrained_model_path} ...")
-
-        # Check if config file exists
-        if not os.path.isfile(config_file):
-            raise RuntimeError(f"Configuration file '{config_file}' does not exist")
-
-        # Load the configuration
-        with open(config_file, "r") as f:
-            config = json.load(f)
-
-        from diffusers.utils import WEIGHTS_NAME
-        model = cls.from_config(config, **transformer_additional_kwargs)
-        model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
-        model_file_safetensors = model_file.replace(".bin", ".safetensors")
-        if os.path.exists(model_file):
-            state_dict = torch.load(model_file, map_location="cpu")
-        elif os.path.exists(model_file_safetensors):
-            from safetensors.torch import load_file
-            state_dict = load_file(model_file_safetensors)
-        else:
-            from safetensors.torch import load_file
-            model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
-            state_dict = {}
-            for model_file_safetensors in model_files_safetensors:
-                _state_dict = load_file(model_file_safetensors)
-                for key in _state_dict:
-                    state_dict[key] = _state_dict[key]
-
-        if model.state_dict()['patch_embed.proj.weight'].size() != state_dict['patch_embed.proj.weight'].size():
-            new_shape = model.state_dict()['patch_embed.proj.weight'].size()
-            if len(new_shape) == 5:
-                state_dict['patch_embed.proj.weight'] = state_dict['patch_embed.proj.weight'].unsqueeze(2).expand(new_shape).clone()
-                state_dict['patch_embed.proj.weight'][:, :, :-1] = 0
-            else:
-                if model.state_dict()['patch_embed.proj.weight'].size()[1] > state_dict['patch_embed.proj.weight'].size()[1]:
-                    model.state_dict()['patch_embed.proj.weight'][:, :state_dict['patch_embed.proj.weight'].size()[1], :, :] = state_dict['patch_embed.proj.weight']
-                    model.state_dict()['patch_embed.proj.weight'][:, state_dict['patch_embed.proj.weight'].size()[1]:, :, :] = 0
-                    state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
-                else:
-                    model.state_dict()['patch_embed.proj.weight'][:, :, :, :] = state_dict['patch_embed.proj.weight'][:, :model.state_dict()['patch_embed.proj.weight'].size()[1], :, :]
-                    state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
-
-        tmp_state_dict = {}
-        for key in state_dict:
-            if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size():
-                tmp_state_dict[key] = state_dict[key]
-            else:
-                print(key, "Size don't match, skip")
-        state_dict = tmp_state_dict
-
-        m, u = model.load_state_dict(state_dict, strict=False)
-        print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
-        print(m)
-
-        params = [p.numel() if "mamba" in n else 0 for n, p in model.named_parameters()]
-        print(f"### Mamba Parameters: {sum(params) / 1e6} M")
-
-        params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
-        print(f"### attn1 Parameters: {sum(params) / 1e6} M")
-
-        return model
-
-    @property
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
-    def attn_processors(self) -> Dict[str, AttentionProcessor]:
-        r"""
-        Returns:
-            `dict` of attention processors: A dictionary containing all attention processors used in the model with
-            indexed by its weight name.
-        """
-        # set recursively
-        processors = {}
-
-        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
-            if hasattr(module, "get_processor"):
-                processors[f"{name}.processor"] = module.get_processor()
-
-            for sub_name, child in module.named_children():
-                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
-            return processors
-
-        for name, module in self.named_children():
-            fn_recursive_add_processors(name, module, processors)
-
-        return processors
-
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
-        r"""
-        Sets the attention processor to use to compute attention.
-
-        Parameters:
-            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
-                The instantiated processor class or a dictionary of processor classes that will be set as the processor
-                for **all** `Attention` layers.
-
-                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
-                processor. This is strongly recommended when setting trainable attention processors.
-
-        """
-        count = len(self.attn_processors.keys())
-
-        if isinstance(processor, dict) and len(processor) != count:
-            raise ValueError(
-                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
-                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
-            )
-
-        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
-            if hasattr(module, "set_processor"):
-                if not isinstance(processor, dict):
-                    module.set_processor(processor)
-                else:
-                    module.set_processor(processor.pop(f"{name}.processor"))
-
-            for sub_name, child in module.named_children():
-                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
-        for name, module in self.named_children():
-            fn_recursive_attn_processor(name, module, processor)
-
     def forward(
         self,
@@ -724,6 +588,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
             self.fastercache_counter+=1
         if self.fastercache_counter >= self.fastercache_start_step + 3 and self.fastercache_counter % 5 !=0:
             # 3. Transformer blocks
+            ca_idx = 0
             for i, block in enumerate(self.transformer_blocks):
                 hidden_states, encoder_hidden_states = block(
                     hidden_states=hidden_states[:1],
@@ -748,6 +613,12 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
                         controlnet_block_weight = controlnet_weights

                     hidden_states = hidden_states + controlnet_states_block * controlnet_block_weight
+
+                # ConsisID
+                if self.is_train_face:
+                    if i % self.cross_attn_interval == 0 and valid_face_emb is not None:
+                        hidden_states = hidden_states + consis_id["scale"] * self.perceiver_cross_attention[ca_idx](valid_face_emb, hidden_states) # torch.Size([2, 32, 2048]) torch.Size([2, 17550, 3072])
+                        ca_idx += 1

             if not self.config.use_rotary_positional_embeddings:
                 # CogVideoX-2B
@@ -828,7 +699,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
                 # ConsisID
                 if self.is_train_face:
                     if i % self.cross_attn_interval == 0 and valid_face_emb is not None:
-                        hidden_states = hidden_states + self.local_face_scale * self.perceiver_cross_attention[ca_idx](valid_face_emb, hidden_states) # torch.Size([2, 32, 2048]) torch.Size([2, 17550, 3072])
+                        hidden_states = hidden_states + consis_id["scale"] * self.perceiver_cross_attention[ca_idx](valid_face_emb, hidden_states) # torch.Size([2, 32, 2048]) torch.Size([2, 17550, 3072])
                        ca_idx += 1

            if not self.config.use_rotary_positional_embeddings:
diff --git a/model_loading.py b/model_loading.py
index 68c3ccf..c226a6b 100644
--- a/model_loading.py
+++ b/model_loading.py
@@ -218,24 +218,8 @@ class DownloadAndLoadCogVideoModel:
                     local_dir=download_path,
                     local_dir_use_symlinks=False,
                 )
-            # transformer_additional_kwargs={}
-            if "consisid" in model.lower():
-                # transformer_additional_kwargs={
-                #     'torch_dtype': dtype,
-                #     'revision': None,
-                #     'variant': None,
-                #     'is_train_face': True,
-                #     'is_kps': False,
-                #     'LFE_num_tokens': 32,
-                #     'LFE_output_dim': 768,
-                #     'LFE_heads': 12,
-                #     'cross_attn_interval': 2,
-                # }
-                transformer = CogVideoXTransformer3DModel.from_pretrained_cus(base_path, subfolder=subfolder)
-            else:
-                transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder=subfolder)
-
+            transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder=subfolder)
             transformer = transformer.to(dtype).to(transformer_load_device)

             if "1.5" in model:
diff --git a/nodes_consis_id.py b/nodes_consis_id.py
index 5a38d71..23c480f 100644
--- a/nodes_consis_id.py
+++ b/nodes_consis_id.py
@@ -66,8 +66,6 @@ class DownloadAndLoadConsisIDModel:
             eva_transform_mean = (eva_transform_mean,) * 3
         if not isinstance(eva_transform_std, (list, tuple)):
             eva_transform_std = (eva_transform_std,) * 3
-        eva_transform_mean = eva_transform_mean
-        eva_transform_std = eva_transform_std

         face_main_model = FaceAnalysis(name='antelopev2', root=face_encoder_path, providers=[onnx_device + 'ExecutionProvider',])
         handler_ante = insightface.model_zoo.get_model(f'{face_encoder_path}/models/antelopev2/glintr100.onnx', providers=[onnx_device + 'ExecutionProvider',])
@@ -100,6 +98,7 @@ class ConsisIDFaceEncode:
             "required": {
                 "consis_id_model": ("CONSISIDMODEL",),
                 "image": ("IMAGE",),
+                "face_scale": ("FLOAT", {"default": 1.0,"step": 0.01},),
             },
         }
@@ -109,7 +108,7 @@ class ConsisIDFaceEncode:
     CATEGORY = "CogVideoWrapper"
     DESCRIPTION = "Downloads and loads the selected CogVideo model from Huggingface to 'ComfyUI/models/CogVideo'"

-    def faceencode(self, image, consis_id_model):
+    def faceencode(self, image, consis_id_model, face_scale):
         from .consis_id.models.utils import process_face_embeddings
         device = mm.get_torch_device()
@@ -117,24 +116,29 @@ class ConsisIDFaceEncode:

         id_image = image[0].cpu().numpy() * 255

-        face_helper = consis_id_model["face_helper"]
-        face_clip_model = consis_id_model["face_clip_model"]
-        handler_ante = consis_id_model["handler_ante"]
-        eva_transform_mean = consis_id_model["eva_transform_mean"]
-        eva_transform_std = consis_id_model["eva_transform_std"]
-        face_main_model = consis_id_model["face_main_model"]
-        id_cond, id_vit_hidden, align_crop_face_image, face_kps = process_face_embeddings(face_helper, face_clip_model, handler_ante,
-                                                                                          eva_transform_mean, eva_transform_std,
-                                                                                          face_main_model, device, dtype, id_image,
-                                                                                          original_id_image=id_image, is_align_face=True,
-                                                                                          cal_uncond=False)
+        id_cond, id_vit_hidden, align_crop_face_image, face_kps = process_face_embeddings(
+            consis_id_model["face_helper"],
+            consis_id_model["face_clip_model"],
+            consis_id_model["handler_ante"],
+            consis_id_model["eva_transform_mean"],
+            consis_id_model["eva_transform_std"],
+            consis_id_model["face_main_model"],
+            device,
+            dtype,
+            id_image,
+            original_id_image=id_image,
+            is_align_face=True,
+            cal_uncond=False
+        )
+
         consis_id_conds = {
             "id_cond": id_cond,
             "id_vit_hidden": id_vit_hidden,
+            "scale": face_scale,
             #"align_crop_face_image": align_crop_face_image,
             #"face_kps": face_kps
         }
-        print(align_crop_face_image.shape)
+        #print(align_crop_face_image.shape)
         align_crop_face_image = align_crop_face_image.permute(0, 2, 3, 1).float().cpu()

         return consis_id_conds, align_crop_face_image,
diff --git a/pipeline_cogvideox.py b/pipeline_cogvideox.py
index 98bc47c..a11185c 100644
--- a/pipeline_cogvideox.py
+++ b/pipeline_cogvideox.py
@@ -677,6 +677,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
                     timestep=timestep,
                     image_rotary_emb=image_rotary_emb,
                     video_flow_features=partial_video_flow_features,
+                    consis_id=consis_id,
                     return_dict=False
                 )[0]
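
For reference, a minimal sketch (not part of the patch) of how the new `"scale"` entry set by `ConsisIDFaceEncode`'s `face_scale` input is consumed inside the transformer blocks above. The standalone helper name is hypothetical; the logic mirrors the ConsisID injection in `CogVideoXTransformer3DModel.forward`, where `consis_id["scale"]` replaces the removed `self.local_face_scale` attribute.

```python
# Hypothetical helper illustrating the ConsisID injection from the diff above:
# every `cross_attn_interval`-th block blends identity features into
# hidden_states, scaled by consis_id["scale"] (the node's face_scale input).
import torch
from torch import nn


def apply_consis_id(hidden_states: torch.Tensor,
                    valid_face_emb: torch.Tensor,
                    consis_id: dict,
                    perceiver_cross_attention: nn.ModuleList,
                    block_idx: int,
                    cross_attn_interval: int,
                    ca_idx: int):
    """Return updated (hidden_states, ca_idx) for one transformer block."""
    if block_idx % cross_attn_interval == 0 and valid_face_emb is not None:
        # Cross-attend the face tokens into the video tokens, then add the
        # result back with the user-controlled scale.
        hidden_states = hidden_states + consis_id["scale"] * \
            perceiver_cross_attention[ca_idx](valid_face_emb, hidden_states)
        ca_idx += 1
    return hidden_states, ca_idx
```

With `face_scale` defaulting to 1.0, behavior matches the previous hard-coded scaling; lowering it weakens the identity conditioning, raising it strengthens it.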