From 1d11d779c5f645620f2c051f0017fdfe4a845846 Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Fri, 20 Jun 2025 00:10:24 +0300
Subject: [PATCH] push

---
 custom_cogvideox_transformer_3d.py |  17 ++-
 das/das_nodes.py                   |   3 +
 model_loading.py                   |  57 +++++++-
 pipeline_cogvideox.py              |  22 +++
 scifi/EF_Net.py                    | 216 +++++++++++++++++++++++++++++
 5 files changed, 310 insertions(+), 5 deletions(-)
 create mode 100644 scifi/EF_Net.py

diff --git a/custom_cogvideox_transformer_3d.py b/custom_cogvideox_transformer_3d.py
index a459aba..4844a2d 100644
--- a/custom_cogvideox_transformer_3d.py
+++ b/custom_cogvideox_transformer_3d.py
@@ -610,6 +610,8 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         controlnet_weights: Optional[Union[float, int, list, np.ndarray, torch.FloatTensor]] = 1.0,
         video_flow_features: Optional[torch.Tensor] = None,
         tracking_maps: Optional[torch.Tensor] = None,
+        EF_Net_states: torch.Tensor = None,
+        EF_Net_weights: Optional[Union[float, int, list, np.ndarray, torch.FloatTensor]] = 1.0,
         return_dict: bool = True,
     ):
         batch_size, num_frames, channels, height, width = hidden_states.shape
@@ -784,7 +786,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
                     hidden_states = hidden_states + controlnet_states_block * controlnet_block_weight
 
             #das
-            if i < len(self.transformer_blocks_copy) and tracking_maps is not None:
+            if hasattr(self, 'transformer_blocks_copy') and i < len(self.transformer_blocks_copy) and tracking_maps is not None:
                 tracking_maps, _ = self.transformer_blocks_copy[i](
                     hidden_states=tracking_maps,
                     encoder_hidden_states=encoder_hidden_states,
@@ -795,6 +797,19 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
                 tracking_maps = self.combine_linears[i](tracking_maps)
                 hidden_states = hidden_states + tracking_maps
 
+            #Sci-Fi
+            if (EF_Net_states is not None) and (i < len(EF_Net_states)):
+                EF_Net_states_block = EF_Net_states[i]
+                EF_Net_block_weight = 1.0
+
+                if isinstance(EF_Net_weights, (float, int)):
+                    EF_Net_block_weight = EF_Net_weights
+                else:
+                    EF_Net_block_weight = EF_Net_weights[i]
+
+
+                hidden_states = hidden_states + EF_Net_states_block * EF_Net_block_weight
+
             if self.use_teacache:
                 self.previous_residual = hidden_states - ori_hidden_states
                 self.previous_residual_encoder = encoder_hidden_states - ori_encoder_hidden_states
diff --git a/das/das_nodes.py b/das/das_nodes.py
index 529e16e..50b280d 100644
--- a/das/das_nodes.py
+++ b/das/das_nodes.py
@@ -171,6 +171,9 @@ class DAS_SpaTracker:
         msk_query = (T_Firsts == 0)
         pred_tracks = pred_tracks[:,:,msk_query.squeeze()]
         pred_visibility = pred_visibility[:,:,msk_query.squeeze()]
+
+        print("pred_tracks: ", pred_tracks.shape)
+        print(pred_tracks[2])
 
         tracking_video = vis.visualize(
             video=video,
diff --git a/model_loading.py b/model_loading.py
index 236b428..0fd0630 100644
--- a/model_loading.py
+++ b/model_loading.py
@@ -125,6 +125,34 @@ class CogVideoLoraSelectComfy:
             print(cog_loras_list)
         return (cog_loras_list,)
 
+class CogVideoEF_Net:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "ef_net": (folder_paths.get_filename_list("diffusion_models"),
+                {"tooltip": "EF_Net models are expected to be in ComfyUI/models/diffusion_models with .safetensors extension"}),
+                "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.0001, "tooltip": "EF_Net strength"}),
+                "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.0001, "tooltip": "start percent"}),
+                "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.0001, "tooltip": "end percent"}),
+            },
+        }
+
+    RETURN_TYPES = ("EFNET",)
+    RETURN_NAMES = ("ef_net", )
+    FUNCTION = "efnet"
+    CATEGORY = "CogVideoWrapper"
+    DESCRIPTION = "Select an EF_Net model from ComfyUI/models/diffusion_models"
+
+    def efnet(self, ef_net, strength, start_percent, end_percent):
+        ef_net_dict = {
+            "path": folder_paths.get_full_path("diffusion_models", ef_net),
+            "strength": strength,
+            "start_percent": start_percent,
+            "end_percent": end_percent
+        }
+        return (ef_net_dict,)
+
 #region DownloadAndLoadCogVideoModel
 class DownloadAndLoadCogVideoModel:
     @classmethod
@@ -177,6 +205,7 @@ class DownloadAndLoadCogVideoModel:
                     "comfy"
                     ], {"default": "sdpa"}),
                 "load_device": (["main_device", "offload_device"], {"default": "main_device"}),
+                "scifi_ef_net": ("EFNET", ),
             }
         }
 
@@ -188,7 +217,7 @@
     def loadmodel(self, model, precision, quantization="disabled", compile="disabled", enable_sequential_cpu_offload=False, block_edit=None, lora=None, compile_args=None,
-        attention_mode="sdpa", load_device="main_device"):
+        attention_mode="sdpa", load_device="main_device", scifi_ef_net=None):
 
         transformer = None
@@ -356,6 +385,14 @@ class DownloadAndLoadCogVideoModel:
         if compile_args is not None:
             pipe.transformer.to(memory_format=torch.channels_last)
 
+        if scifi_ef_net is not None:
+            from .scifi.EF_Net import EF_Net
+            EF_Net_model = EF_Net(num_layers=4, downscale_coef=8, in_channels=2, num_attention_heads=48,).requires_grad_(False).eval()
+            sd = load_torch_file(scifi_ef_net["path"])
+            EF_Net_model.load_state_dict(sd, strict=True)
+            pipe.EF_Net_model = EF_Net_model
+            del sd
+
         #fp8
         if quantization == "fp8_e4m3fn" or quantization == "fp8_e4m3fn_fastmode":
             params_to_keep = {"patch_embed", "lora", "pos_embedding", "time_embedding", "norm_k", "norm_q", "to_k.bias", "to_q.bias", "to_v.bias"}
@@ -692,6 +729,7 @@ class CogVideoXModelLoader:
                     "fused_sageattn_qk_int8_pv_fp16_triton",
                     "comfy"
                     ], {"default": "sdpa"}),
+                "scifi_ef_net": ("EFNET", ),
             }
         }
 
@@ -701,7 +739,7 @@
     def loadmodel(self, model, base_precision, load_device, enable_sequential_cpu_offload,
-        block_edit=None, compile_args=None, lora=None, attention_mode="sdpa", quantization="disabled"):
+        block_edit=None, compile_args=None, lora=None, attention_mode="sdpa", quantization="disabled", scifi_ef_net=None):
         transformer = None
         if "sage" in attention_mode:
             try:
@@ -860,6 +898,15 @@ class CogVideoXModelLoader:
         if compile_args is not None:
             pipe.transformer.to(memory_format=torch.channels_last)
 
+        if scifi_ef_net is not None:
+            from .scifi.EF_Net import EF_Net
+            EF_Net_model = EF_Net(num_layers=4, downscale_coef=8, in_channels=2, num_attention_heads=48,).requires_grad_(False).eval()
+            sd = load_torch_file(scifi_ef_net["path"])
+            EF_Net_model.load_state_dict(sd, strict=True)
+            EF_Net_model.to(base_dtype)
+            pipe.EF_Net_model = EF_Net_model
+            del sd
+
         #quantization
         if quantization == "fp8_e4m3fn" or quantization == "fp8_e4m3fn_fast":
             params_to_keep = {"patch_embed", "lora", "pos_embedding", "time_embedding", "norm_k", "norm_q", "to_k.bias", "to_q.bias", "to_v.bias"}
@@ -1138,7 +1185,8 @@ NODE_CLASS_MAPPINGS = {
     "CogVideoLoraSelect": CogVideoLoraSelect,
     "CogVideoXVAELoader": CogVideoXVAELoader,
     "CogVideoXModelLoader": CogVideoXModelLoader,
-    "CogVideoLoraSelectComfy": CogVideoLoraSelectComfy
+    "CogVideoLoraSelectComfy": CogVideoLoraSelectComfy,
+    "CogVideoEF_Net": CogVideoEF_Net
     }
 NODE_DISPLAY_NAME_MAPPINGS = {
     "DownloadAndLoadCogVideoModel": "(Down)load CogVideo Model",
@@ -1148,5 +1196,6 @@
     "CogVideoLoraSelect": "CogVideo LoraSelect",
     "CogVideoXVAELoader": "CogVideoX VAE Loader",
     "CogVideoXModelLoader": "CogVideoX Model Loader",
-    "CogVideoLoraSelectComfy": "CogVideo LoraSelect Comfy"
+    "CogVideoLoraSelectComfy": "CogVideo LoraSelect Comfy",
+    "CogVideoEF_Net": "CogVideo EF_Net"
     }
\ No newline at end of file
diff --git a/pipeline_cogvideox.py b/pipeline_cogvideox.py
index bf182b4..6e6696c 100644
--- a/pipeline_cogvideox.py
+++ b/pipeline_cogvideox.py
@@ -384,6 +384,9 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
         image_cond_end_percent: float = 1.0,
         feta_args: Optional[dict] = None,
         das_tracking: Optional[dict] = None,
+        EF_Net_weights: Optional[Union[float, list, torch.FloatTensor]] = 1.0,
+        EF_Net_guidance_start: float = 0.0,
+        EF_Net_guidance_end: float = 1.0,
     ):
         """
 
@@ -853,6 +856,23 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
                     else:
                         controlnet_states = controlnet_states.to(dtype=self.vae_dtype)
 
+                EF_Net_states = []
+                if hasattr(self, "EF_Net_model") and (EF_Net_guidance_start <= current_step_percentage < EF_Net_guidance_end):
+                    self.EF_Net_model.to(device)
+                    # extract EF_Net hidden state
+                    EF_Net_states = self.EF_Net_model(
+                        hidden_states=latent_image_input[:,:,0:16,:,:],
+                        encoder_hidden_states=prompt_embeds,
+                        image_rotary_emb=None,
+                        EF_Net_states=latent_image_input[:,12::,:,:,:],
+                        timestep=timestep,
+                        return_dict=False,
+                    )[0]
+                if isinstance(EF_Net_states, (tuple, list)):
+                    EF_Net_states = [x.to(dtype=self.transformer.dtype) for x in EF_Net_states]
+                else:
+                    EF_Net_states = EF_Net_states.to(dtype=self.transformer.dtype)
+
                 # predict noise model_output
                 noise_pred = self.transformer(
                     hidden_states=latent_model_input,
@@ -865,6 +885,8 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
                     controlnet_weights=control_weights,
                     video_flow_features=video_flow_features if (tora is not None and tora["start_percent"] <= current_step_percentage <= tora["end_percent"]) else None,
                     tracking_maps=tracking_maps_input,
+                    EF_Net_states=EF_Net_states,
+                    EF_Net_weights=EF_Net_weights,
                 )[0]
                 noise_pred = noise_pred.float()
                 if isinstance(self.scheduler, CogVideoXDPMScheduler):
diff --git a/scifi/EF_Net.py b/scifi/EF_Net.py
new file mode 100644
index 0000000..9f51669
--- /dev/null
+++ b/scifi/EF_Net.py
@@ -0,0 +1,216 @@
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from diffusers.models.transformers.cogvideox_transformer_3d import Transformer2DModelOutput, CogVideoXBlock
+from diffusers.utils import is_torch_version
+from diffusers.loaders import PeftAdapterMixin
+from diffusers.utils.torch_utils import maybe_allow_in_graph
+from diffusers.models.embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+
+
+class EF_Net(ModelMixin, ConfigMixin, PeftAdapterMixin):
+    _supports_gradient_checkpointing = True
+
+    @register_to_config
+    def __init__(
+        self,
+        num_attention_heads: int = 30,
+        attention_head_dim: int = 64,
+        vae_channels: int = 16,
+        in_channels: int = 3,
+        downscale_coef: int = 8,
+        flip_sin_to_cos: bool = True,
+        freq_shift: int = 0,
+        time_embed_dim: int = 512,
+        num_layers: int = 8,
+        dropout: float = 0.0,
+        attention_bias: bool = True,
+        sample_width: int = 90,
+        sample_height: int = 60,
+        sample_frames: int = 1,
+        patch_size: int = 2,
+        temporal_compression_ratio: int = 4,
+        max_text_seq_length: int = 226,
+        activation_fn: str = "gelu-approximate",
+        timestep_activation_fn: str = "silu",
+        norm_elementwise_affine: bool = True,
+        norm_eps: float = 1e-5,
+        spatial_interpolation_scale: float = 1.875,
+        temporal_interpolation_scale: float = 1.0,
+        use_rotary_positional_embeddings: bool = False,
+        use_learned_positional_embeddings: bool = False,
+        out_proj_dim = None,
+    ):
+        super().__init__()
+        inner_dim = num_attention_heads * attention_head_dim
+        out_proj_dim = inner_dim
+
+        if not use_rotary_positional_embeddings and use_learned_positional_embeddings:
+            raise ValueError(
+                "There are no CogVideoX checkpoints available with disable rotary embeddings and learned positional "
+                "embeddings. If you're using a custom model and/or believe this should be supported, please open an "
+                "issue at https://github.com/huggingface/diffusers/issues."
+            )
+
+        # 1. Patch embedding
+        self.patch_embed = CogVideoXPatchEmbed(
+            patch_size=patch_size,
+            in_channels=vae_channels,
+            embed_dim=inner_dim,
+            bias=True,
+            sample_width=sample_width,
+            sample_height=sample_height,
+            sample_frames=49,
+            temporal_compression_ratio=temporal_compression_ratio,
+            spatial_interpolation_scale=spatial_interpolation_scale,
+            temporal_interpolation_scale=temporal_interpolation_scale,
+            use_positional_embeddings=not use_rotary_positional_embeddings,
+            use_learned_positional_embeddings=use_learned_positional_embeddings,
+        )
+
+        self.patch_embed_first = CogVideoXPatchEmbed(
+            patch_size=patch_size,
+            in_channels=vae_channels,
+            embed_dim=inner_dim,
+            bias=True,
+            sample_width=sample_width,
+            sample_height=sample_height,
+            sample_frames=sample_frames,
+            temporal_compression_ratio=temporal_compression_ratio,
+            spatial_interpolation_scale=spatial_interpolation_scale,
+            temporal_interpolation_scale=temporal_interpolation_scale,
+            use_positional_embeddings=not use_rotary_positional_embeddings,
+            use_learned_positional_embeddings=use_learned_positional_embeddings,
+        )
+
+        self.embedding_dropout = nn.Dropout(dropout)
+        self.weights = nn.ModuleList([nn.Linear(inner_dim, 13) for _ in range(num_layers)])
+        self.first_weights = nn.ModuleList([nn.Linear(2*inner_dim, inner_dim) for _ in range(num_layers)])
+
+        # 2. Time embeddings
+        self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
+        self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
+
+        # 3. Define spatio-temporal transformers blocks
+        self.transformer_blocks = nn.ModuleList(
+            [
+                CogVideoXBlock(
+                    dim=inner_dim,
+                    num_attention_heads=num_attention_heads,
+                    attention_head_dim=attention_head_dim,
+                    time_embed_dim=time_embed_dim,
+                    dropout=dropout,
+                    activation_fn=activation_fn,
+                    attention_bias=attention_bias,
+                    norm_elementwise_affine=norm_elementwise_affine,
+                    norm_eps=norm_eps,
+                )
+                for _ in range(num_layers)
+            ]
+        )
+
+        self.out_projectors = None
+        self.relu = nn.LeakyReLU(negative_slope=0.01)
+
+        if out_proj_dim is not None:
+            self.out_projectors = nn.ModuleList(
+                [nn.Linear(inner_dim, out_proj_dim) for _ in range(num_layers)]
+            )
+
+        self.gradient_checkpointing = False
+
+    def _set_gradient_checkpointing(self, enable=False, gradient_checkpointing_func=None):
+        self.gradient_checkpointing = enable
+
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        EF_Net_states: torch.Tensor,
+        timestep: Union[int, float, torch.LongTensor],
+        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        timestep_cond: Optional[torch.Tensor] = None,
+        return_dict: bool = True,
+    ):
+        batch_size, num_frames, channels, height, width = EF_Net_states.shape
+        o_hidden_states = hidden_states
+        hidden_states = EF_Net_states
+        encoder_hidden_states_ = encoder_hidden_states
+
+        # 1. Time embedding
+        timesteps = timestep
+        t_emb = self.time_proj(timesteps)
+
+
+        # timesteps does not contain any weights and will always return f32 tensors
+        # but time_embedding might actually be running in fp16. so we need to cast here.
+        # there might be better ways to encapsulate this.
+        t_emb = t_emb.to(dtype=hidden_states.dtype)
+        emb = self.time_embedding(t_emb, timestep_cond)
+
+        hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
+        hidden_states = self.embedding_dropout(hidden_states)
+
+        text_seq_length = encoder_hidden_states.shape[1]
+        encoder_hidden_states = hidden_states[:, :text_seq_length]
+        hidden_states = hidden_states[:, text_seq_length:]
+
+        o_hidden_states = self.patch_embed_first(encoder_hidden_states_, o_hidden_states)
+        o_hidden_states = self.embedding_dropout(o_hidden_states)
+
+        text_seq_length = encoder_hidden_states_.shape[1]
+        o_hidden_states = o_hidden_states[:, text_seq_length:]
+
+        EF_Net_hidden_states = ()
+        # 2. Transformer blocks
+        for i, block in enumerate(self.transformer_blocks):
+            #if self.training and self.gradient_checkpointing:
+            if self.gradient_checkpointing:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states,
+                    encoder_hidden_states,
+                    emb,
+                    image_rotary_emb,
+                    **ckpt_kwargs,
+                )
+            else:
+                hidden_states, encoder_hidden_states = block(
+                    hidden_states=hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    temb=emb,
+                    image_rotary_emb=image_rotary_emb,
+                )
+
+
+            if self.out_projectors is not None:
+                coff = self.weights[i](hidden_states)
+                temp_list = []
+                for j in range(coff.shape[2]):
+                    temp_list.append(hidden_states*coff[:,:,j:(j+1)])
+                out = torch.concat(temp_list, dim=1)
+                out = torch.concat([out, o_hidden_states], dim=2)
+                out = self.first_weights[i](out)
+                out = self.relu(out)
+                out = self.out_projectors[i](out)
+                EF_Net_hidden_states += (out,)
+            else:
+                out = torch.concat([weight*hidden_states for weight in self.weights], dim=1)
+                EF_Net_hidden_states += (out,)
+
+        if not return_dict:
+            return (EF_Net_hidden_states,)
+        return Transformer2DModelOutput(sample=EF_Net_hidden_states)
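
For reference, below is a minimal, self-contained sketch of the per-block injection that the "#Sci-Fi" branch above adds to custom_cogvideox_transformer_3d.py: after block i runs, EF_Net_states[i] is scaled by a scalar or per-layer weight and added to hidden_states. The toy shapes and the nn.Identity() stand-ins for CogVideoXBlock are illustrative assumptions only, not part of the patch.

import torch
from torch import nn

# Toy dimensions; the real model uses the CogVideoX hidden size and token count.
batch, seq_len, dim, num_ef_layers = 1, 128, 64, 4

blocks = nn.ModuleList([nn.Identity() for _ in range(8)])  # stand-in for CogVideoXBlock
hidden_states = torch.randn(batch, seq_len, dim)
EF_Net_states = tuple(torch.randn(batch, seq_len, dim) for _ in range(num_ef_layers))
EF_Net_weights = 1.0  # or a list with one weight per EF_Net layer

for i, block in enumerate(blocks):
    hidden_states = block(hidden_states)
    # Mirror of the "#Sci-Fi" branch: only the first len(EF_Net_states) blocks get a residual.
    if EF_Net_states is not None and i < len(EF_Net_states):
        if isinstance(EF_Net_weights, (float, int)):
            EF_Net_block_weight = EF_Net_weights
        else:
            EF_Net_block_weight = EF_Net_weights[i]
        hidden_states = hidden_states + EF_Net_states[i] * EF_Net_block_weight

print(hidden_states.shape)  # torch.Size([1, 128, 64])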