kijai 2024-11-28 17:32:49 +02:00
parent d57ec230d5
commit 3663162f02
4 changed files with 30 additions and 170 deletions

View File

@@ -19,10 +19,6 @@ import torch
from torch import nn
import torch.nn.functional as F
import os
import json
import glob
import numpy as np
from einops import rearrange
@@ -416,7 +412,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
patch_bias: bool = True,
is_train_face: bool = False,
is_kps: bool = False,
cross_attn_interval: int = 1,
cross_attn_interval: int = 2,
LFE_num_tokens: int = 32,
LFE_output_dim: int = 768,
LFE_heads: int = 12,
@@ -519,7 +515,6 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
self.LFE_output_dim = LFE_output_dim
self.LFE_heads = LFE_heads
self.LFE_final_output_dim = int(self.inner_dim / 3 * 2)
self.local_face_scale = local_face_scale
self._init_face_inputs()
@@ -534,137 +529,6 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
self.perceiver_cross_attention = nn.ModuleList([
PerceiverCrossAttention(dim=self.inner_dim, dim_head=128, heads=16, kv_dim=self.LFE_final_output_dim).to(device, dtype=weight_dtype) for _ in range(self.num_ca)
])
@classmethod
def from_pretrained_cus(cls, pretrained_model_path, subfolder=None, config_path=None, transformer_additional_kwargs={}):
if subfolder:
config_path = config_path or pretrained_model_path
config_file = os.path.join(config_path, subfolder, 'config.json')
pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
else:
config_file = os.path.join(config_path or pretrained_model_path, 'config.json')
print(f"Loading 3D transformer's pretrained weights from {pretrained_model_path} ...")
# Check if config file exists
if not os.path.isfile(config_file):
raise RuntimeError(f"Configuration file '{config_file}' does not exist")
# Load the configuration
with open(config_file, "r") as f:
config = json.load(f)
from diffusers.utils import WEIGHTS_NAME
model = cls.from_config(config, **transformer_additional_kwargs)
model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
model_file_safetensors = model_file.replace(".bin", ".safetensors")
if os.path.exists(model_file):
state_dict = torch.load(model_file, map_location="cpu")
elif os.path.exists(model_file_safetensors):
from safetensors.torch import load_file
state_dict = load_file(model_file_safetensors)
else:
from safetensors.torch import load_file
model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
state_dict = {}
for model_file_safetensors in model_files_safetensors:
_state_dict = load_file(model_file_safetensors)
for key in _state_dict:
state_dict[key] = _state_dict[key]
if model.state_dict()['patch_embed.proj.weight'].size() != state_dict['patch_embed.proj.weight'].size():
new_shape = model.state_dict()['patch_embed.proj.weight'].size()
if len(new_shape) == 5:
state_dict['patch_embed.proj.weight'] = state_dict['patch_embed.proj.weight'].unsqueeze(2).expand(new_shape).clone()
state_dict['patch_embed.proj.weight'][:, :, :-1] = 0
else:
if model.state_dict()['patch_embed.proj.weight'].size()[1] > state_dict['patch_embed.proj.weight'].size()[1]:
model.state_dict()['patch_embed.proj.weight'][:, :state_dict['patch_embed.proj.weight'].size()[1], :, :] = state_dict['patch_embed.proj.weight']
model.state_dict()['patch_embed.proj.weight'][:, state_dict['patch_embed.proj.weight'].size()[1]:, :, :] = 0
state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
else:
model.state_dict()['patch_embed.proj.weight'][:, :, :, :] = state_dict['patch_embed.proj.weight'][:, :model.state_dict()['patch_embed.proj.weight'].size()[1], :, :]
state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
tmp_state_dict = {}
for key in state_dict:
if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size():
tmp_state_dict[key] = state_dict[key]
else:
print(key, "Size don't match, skip")
state_dict = tmp_state_dict
m, u = model.load_state_dict(state_dict, strict=False)
print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
print(m)
params = [p.numel() if "mamba" in n else 0 for n, p in model.named_parameters()]
print(f"### Mamba Parameters: {sum(params) / 1e6} M")
params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
print(f"### attn1 Parameters: {sum(params) / 1e6} M")
return model
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
def forward(
self,
@@ -724,6 +588,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
self.fastercache_counter+=1
if self.fastercache_counter >= self.fastercache_start_step + 3 and self.fastercache_counter % 5 !=0:
# 3. Transformer blocks
ca_idx = 0
for i, block in enumerate(self.transformer_blocks):
hidden_states, encoder_hidden_states = block(
hidden_states=hidden_states[:1],
@@ -748,6 +613,12 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
controlnet_block_weight = controlnet_weights
hidden_states = hidden_states + controlnet_states_block * controlnet_block_weight
# ConsisID
if self.is_train_face:
if i % self.cross_attn_interval == 0 and valid_face_emb is not None:
hidden_states = hidden_states + consis_id["scale"] * self.perceiver_cross_attention[ca_idx](valid_face_emb, hidden_states) # torch.Size([2, 32, 2048]) torch.Size([2, 17550, 3072])
ca_idx += 1
if not self.config.use_rotary_positional_embeddings:
# CogVideoX-2B
@@ -828,7 +699,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
# ConsisID
if self.is_train_face:
if i % self.cross_attn_interval == 0 and valid_face_emb is not None:
hidden_states = hidden_states + self.local_face_scale * self.perceiver_cross_attention[ca_idx](valid_face_emb, hidden_states) # torch.Size([2, 32, 2048]) torch.Size([2, 17550, 3072])
hidden_states = hidden_states + consis_id["scale"] * self.perceiver_cross_attention[ca_idx](valid_face_emb, hidden_states) # torch.Size([2, 32, 2048]) torch.Size([2, 17550, 3072])
ca_idx += 1
if not self.config.use_rotary_positional_embeddings:
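
For reference, a minimal sketch of the injection pattern shared by the two ConsisID hunks above: every cross_attn_interval transformer blocks, a perceiver cross-attention fuses the face embedding into the video tokens, and the residual is now weighted by the scale carried in the consis_id dict rather than the old local_face_scale config value. DummyPerceiverCrossAttention, the tensor shapes, and the num_layers // cross_attn_interval sizing are illustrative assumptions, not the wrapper's actual implementation.

# Illustrative sketch only; not part of this commit.
import torch
from torch import nn

class DummyPerceiverCrossAttention(nn.Module):
    """Stand-in for the real PerceiverCrossAttention: video tokens attend to
    face tokens and return a residual with the video-token shape."""
    def __init__(self, dim: int, kv_dim: int, heads: int = 8):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, heads, kdim=kv_dim, vdim=kv_dim, batch_first=True)

    def forward(self, face_emb: torch.Tensor, hidden_states: torch.Tensor) -> torch.Tensor:
        out, _ = self.attn(hidden_states, face_emb, face_emb)  # queries are the video tokens
        return out

num_layers, cross_attn_interval = 8, 2            # new default interval is 2 (was 1)
num_ca = num_layers // cross_attn_interval        # assumed sizing of the ModuleList
perceiver_cross_attention = nn.ModuleList(
    DummyPerceiverCrossAttention(dim=64, kv_dim=32) for _ in range(num_ca)
)

hidden_states = torch.randn(2, 128, 64)           # (batch, video tokens, inner_dim)
valid_face_emb = torch.randn(2, 32, 32)           # (batch, LFE_num_tokens, face dim)
consis_id = {"scale": 1.0}                        # face_scale from the new node input

ca_idx = 0
for i in range(num_layers):
    # ... a regular CogVideoX transformer block would run here ...
    if i % cross_attn_interval == 0:
        hidden_states = hidden_states + consis_id["scale"] * perceiver_cross_attention[ca_idx](
            valid_face_emb, hidden_states
        )
        ca_idx += 1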

View File

@@ -218,24 +218,8 @@ class DownloadAndLoadCogVideoModel:
local_dir=download_path,
local_dir_use_symlinks=False,
)
# transformer_additional_kwargs={}
if "consisid" in model.lower():
# transformer_additional_kwargs={
# 'torch_dtype': dtype,
# 'revision': None,
# 'variant': None,
# 'is_train_face': True,
# 'is_kps': False,
# 'LFE_num_tokens': 32,
# 'LFE_output_dim': 768,
# 'LFE_heads': 12,
# 'cross_attn_interval': 2,
# }
transformer = CogVideoXTransformer3DModel.from_pretrained_cus(base_path, subfolder=subfolder)
else:
transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder=subfolder)
transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder=subfolder)
transformer = transformer.to(dtype).to(transformer_load_device)
if "1.5" in model:

View File

@@ -66,8 +66,6 @@ class DownloadAndLoadConsisIDModel:
eva_transform_mean = (eva_transform_mean,) * 3
if not isinstance(eva_transform_std, (list, tuple)):
eva_transform_std = (eva_transform_std,) * 3
eva_transform_mean = eva_transform_mean
eva_transform_std = eva_transform_std
face_main_model = FaceAnalysis(name='antelopev2', root=face_encoder_path, providers=[onnx_device + 'ExecutionProvider',])
handler_ante = insightface.model_zoo.get_model(f'{face_encoder_path}/models/antelopev2/glintr100.onnx', providers=[onnx_device + 'ExecutionProvider',])
@@ -100,6 +98,7 @@ class ConsisIDFaceEncode:
"required": {
"consis_id_model": ("CONSISIDMODEL",),
"image": ("IMAGE",),
"face_scale": ("FLOAT", {"default": 1.0,"step": 0.01},),
},
}
@@ -109,7 +108,7 @@ class ConsisIDFaceEncode:
CATEGORY = "CogVideoWrapper"
DESCRIPTION = "Encodes a face image into ConsisID conditioning embeddings for CogVideo sampling"
def faceencode(self, image, consis_id_model):
def faceencode(self, image, consis_id_model, face_scale):
from .consis_id.models.utils import process_face_embeddings
device = mm.get_torch_device()
@@ -117,24 +116,29 @@ class ConsisIDFaceEncode:
id_image = image[0].cpu().numpy() * 255
face_helper = consis_id_model["face_helper"]
face_clip_model = consis_id_model["face_clip_model"]
handler_ante = consis_id_model["handler_ante"]
eva_transform_mean = consis_id_model["eva_transform_mean"]
eva_transform_std = consis_id_model["eva_transform_std"]
face_main_model = consis_id_model["face_main_model"]
id_cond, id_vit_hidden, align_crop_face_image, face_kps = process_face_embeddings(face_helper, face_clip_model, handler_ante,
eva_transform_mean, eva_transform_std,
face_main_model, device, dtype, id_image,
original_id_image=id_image, is_align_face=True,
cal_uncond=False)
id_cond, id_vit_hidden, align_crop_face_image, face_kps = process_face_embeddings(
consis_id_model["face_helper"],
consis_id_model["face_clip_model"],
consis_id_model["handler_ante"],
consis_id_model["eva_transform_mean"],
consis_id_model["eva_transform_std"],
consis_id_model["face_main_model"],
device,
dtype,
id_image,
original_id_image=id_image,
is_align_face=True,
cal_uncond=False
)
consis_id_conds = {
"id_cond": id_cond,
"id_vit_hidden": id_vit_hidden,
"scale": face_scale,
#"align_crop_face_image": align_crop_face_image,
#"face_kps": face_kps
}
print(align_crop_face_image.shape)
#print(align_crop_face_image.shape)
align_crop_face_image = align_crop_face_image.permute(0, 2, 3, 1).float().cpu()
return consis_id_conds, align_crop_face_image,
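
The node above also converts between ComfyUI's image layout and the layout the face pipeline expects. A small sketch of those two conversions, assuming ComfyUI IMAGE tensors are float32 in [0, 1] with shape (batch, height, width, channels) and that process_face_embeddings returns the aligned crop in BCHW; the crop resolution below is hypothetical.

# Illustrative sketch only; not part of this commit.
import torch

image = torch.rand(1, 512, 512, 3)                    # ComfyUI-style IMAGE input (BHWC, values in 0..1)
id_image = image[0].cpu().numpy() * 255               # HWC array scaled to 0..255 for the face models

align_crop_face_image = torch.rand(1, 3, 480, 480)    # hypothetical aligned crop in BCHW
image_out = align_crop_face_image.permute(0, 2, 3, 1).float().cpu()  # back to BHWC for ComfyUI
print(image_out.shape)                                # torch.Size([1, 480, 480, 3])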

View File

@@ -677,6 +677,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
timestep=timestep,
image_rotary_emb=image_rotary_emb,
video_flow_features=partial_video_flow_features,
consis_id=consis_id,
return_dict=False
)[0]
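
A minimal sketch of the signature change this hunk implies: the pipeline forwards an optional consis_id dict, and the transformer should fall back to an unmodified pass when no face conditioning is connected. The function name, latent shape, and zero residual are illustrative stand-ins for the wrapper's real forward.

# Illustrative sketch only; not part of this commit.
from typing import Any, Dict, Optional

import torch

def transformer_forward(hidden_states: torch.Tensor,
                        consis_id: Optional[Dict[str, Any]] = None) -> torch.Tensor:
    # With no ConsisID conds the residual is zero, so behaviour is unchanged.
    scale = 1.0 if consis_id is None else consis_id.get("scale", 1.0)
    face_residual = torch.zeros_like(hidden_states)   # the perceiver cross-attention output would go here
    return hidden_states + scale * face_residual

latents = torch.randn(2, 128, 3072)                   # arbitrary (batch, tokens, inner_dim)
out = transformer_forward(latents, consis_id={"scale": 1.0})
print(out.shape)                                      # torch.Size([2, 128, 3072])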