This commit is contained in:
kijai 2025-06-20 00:10:24 +03:00
parent 1124c77d56
commit 1d11d779c5
5 changed files with 310 additions and 5 deletions

View File

@ -610,6 +610,8 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
controlnet_weights: Optional[Union[float, int, list, np.ndarray, torch.FloatTensor]] = 1.0,
video_flow_features: Optional[torch.Tensor] = None,
tracking_maps: Optional[torch.Tensor] = None,
EF_Net_states: torch.Tensor = None,
EF_Net_weights: Optional[Union[float, int, list, np.ndarray, torch.FloatTensor]] = 1.0,
return_dict: bool = True,
):
batch_size, num_frames, channels, height, width = hidden_states.shape
@ -784,7 +786,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
hidden_states = hidden_states + controlnet_states_block * controlnet_block_weight
#das
if i < len(self.transformer_blocks_copy) and tracking_maps is not None:
if hasattr(self, 'transformer_blocks_copy') and i < len(self.transformer_blocks_copy) and tracking_maps is not None:
tracking_maps, _ = self.transformer_blocks_copy[i](
hidden_states=tracking_maps,
encoder_hidden_states=encoder_hidden_states,
@ -795,6 +797,19 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
tracking_maps = self.combine_linears[i](tracking_maps)
hidden_states = hidden_states + tracking_maps
#Sci-Fi
if (EF_Net_states is not None) and (i < len(EF_Net_states)):
EF_Net_states_block = EF_Net_states[i]
EF_Net_block_weight = 1.0
if isinstance(EF_Net_weights, (float, int)):
EF_Net_block_weight = EF_Net_weights
else:
EF_Net_block_weight = EF_Net_weights[i]
hidden_states = hidden_states + EF_Net_states_block * EF_Net_block_weight
if self.use_teacache:
self.previous_residual = hidden_states - ori_hidden_states
self.previous_residual_encoder = encoder_hidden_states - ori_encoder_hidden_states

View File

@ -172,6 +172,9 @@ class DAS_SpaTracker:
pred_tracks = pred_tracks[:,:,msk_query.squeeze()]
pred_visibility = pred_visibility[:,:,msk_query.squeeze()]
print("pred_tracks: ", pred_tracks.shape)
print(pred_tracks[2])
tracking_video = vis.visualize(
video=video,
tracks=pred_tracks,

View File

@ -125,6 +125,34 @@ class CogVideoLoraSelectComfy:
print(cog_loras_list)
return (cog_loras_list,)
class CogVideoEF_Net:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"ef_net": (folder_paths.get_filename_list("diffusion_models"),
{"tooltip": "LORA models are expected to be in ComfyUI/models/loras with .safetensors extension"}),
"strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.0001, "tooltip": "LORA strength, set to 0.0 to unmerge the LORA"}),
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.0001, "tooltip": "start percent"}),
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.0001, "tooltip": "end percent"}),
},
}
RETURN_TYPES = ("EFNET",)
RETURN_NAMES = ("ef_net", )
FUNCTION = "efnet"
CATEGORY = "CogVideoWrapper"
DESCRIPTION = "Select a EF_Net model from ComfyUI/models/diffusion_models"
def efnet(self, ef_net, strength, start_percent, end_percent):
ef_net_dict = {
"path": folder_paths.get_full_path("diffusion_models", ef_net),
"strength": strength,
"start_percent": start_percent,
"end_percent": end_percent
}
return (ef_net_dict,)
#region DownloadAndLoadCogVideoModel
class DownloadAndLoadCogVideoModel:
@classmethod
@ -177,6 +205,7 @@ class DownloadAndLoadCogVideoModel:
"comfy"
], {"default": "sdpa"}),
"load_device": (["main_device", "offload_device"], {"default": "main_device"}),
"scifi_ef_net": ("EFNET", ),
}
}
@ -188,7 +217,7 @@ class DownloadAndLoadCogVideoModel:
def loadmodel(self, model, precision, quantization="disabled", compile="disabled",
enable_sequential_cpu_offload=False, block_edit=None, lora=None, compile_args=None,
attention_mode="sdpa", load_device="main_device"):
attention_mode="sdpa", load_device="main_device", scifi_ef_net=None):
transformer = None
@ -356,6 +385,14 @@ class DownloadAndLoadCogVideoModel:
if compile_args is not None:
pipe.transformer.to(memory_format=torch.channels_last)
if scifi_ef_net is not None:
from .scifi.EF_Net import EF_Net
EF_Net_model = EF_Net(num_layers=4, downscale_coef=8, in_channels=2, num_attention_heads=48,).requires_grad_(False).eval()
sd = load_torch_file(scifi_ef_net["path"])
EF_Net_model.load_state_dict(sd, strict=True)
pipe.EF_Net_model = EF_Net_model
del sd
#fp8
if quantization == "fp8_e4m3fn" or quantization == "fp8_e4m3fn_fastmode":
params_to_keep = {"patch_embed", "lora", "pos_embedding", "time_embedding", "norm_k", "norm_q", "to_k.bias", "to_q.bias", "to_v.bias"}
@ -692,6 +729,7 @@ class CogVideoXModelLoader:
"fused_sageattn_qk_int8_pv_fp16_triton",
"comfy"
], {"default": "sdpa"}),
"scifi_ef_net": ("EFNET", ),
}
}
@ -701,7 +739,7 @@ class CogVideoXModelLoader:
CATEGORY = "CogVideoWrapper"
def loadmodel(self, model, base_precision, load_device, enable_sequential_cpu_offload,
block_edit=None, compile_args=None, lora=None, attention_mode="sdpa", quantization="disabled"):
block_edit=None, compile_args=None, lora=None, attention_mode="sdpa", quantization="disabled", scifi_ef_net=None):
transformer = None
if "sage" in attention_mode:
try:
@ -860,6 +898,15 @@ class CogVideoXModelLoader:
if compile_args is not None:
pipe.transformer.to(memory_format=torch.channels_last)
if scifi_ef_net is not None:
from .scifi.EF_Net import EF_Net
EF_Net_model = EF_Net(num_layers=4, downscale_coef=8, in_channels=2, num_attention_heads=48,).requires_grad_(False).eval()
sd = load_torch_file(scifi_ef_net["path"])
EF_Net_model.load_state_dict(sd, strict=True)
EF_Net_model.to(base_dtype)
pipe.EF_Net_model = EF_Net_model
del sd
#quantization
if quantization == "fp8_e4m3fn" or quantization == "fp8_e4m3fn_fast":
params_to_keep = {"patch_embed", "lora", "pos_embedding", "time_embedding", "norm_k", "norm_q", "to_k.bias", "to_q.bias", "to_v.bias"}
@ -1138,7 +1185,8 @@ NODE_CLASS_MAPPINGS = {
"CogVideoLoraSelect": CogVideoLoraSelect,
"CogVideoXVAELoader": CogVideoXVAELoader,
"CogVideoXModelLoader": CogVideoXModelLoader,
"CogVideoLoraSelectComfy": CogVideoLoraSelectComfy
"CogVideoLoraSelectComfy": CogVideoLoraSelectComfy,
"CogVideoEF_Net": CogVideoEF_Net
}
NODE_DISPLAY_NAME_MAPPINGS = {
"DownloadAndLoadCogVideoModel": "(Down)load CogVideo Model",
@ -1148,5 +1196,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"CogVideoLoraSelect": "CogVideo LoraSelect",
"CogVideoXVAELoader": "CogVideoX VAE Loader",
"CogVideoXModelLoader": "CogVideoX Model Loader",
"CogVideoLoraSelectComfy": "CogVideo LoraSelect Comfy"
"CogVideoLoraSelectComfy": "CogVideo LoraSelect Comfy",
"CogVideoEF_Net": "CogVideo EF_Net"
}

View File

@ -384,6 +384,9 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
image_cond_end_percent: float = 1.0,
feta_args: Optional[dict] = None,
das_tracking: Optional[dict] = None,
EF_Net_weights: Optional[Union[float, list, torch.FloatTensor]] = 1.0,
EF_Net_guidance_start: float = 0.0,
EF_Net_guidance_end: float = 1.0,
):
"""
@ -853,6 +856,23 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
else:
controlnet_states = controlnet_states.to(dtype=self.vae_dtype)
self.EF_Net_model.to(device)
EF_Net_states = []
if (EF_Net_guidance_start <= current_step_percentage < EF_Net_guidance_end):
# extract EF_Net hidden state
EF_Net_states = self.EF_Net_model(
hidden_states=latent_image_input[:,:,0:16,:,:],
encoder_hidden_states=prompt_embeds,
image_rotary_emb=None,
EF_Net_states=latent_image_input[:,12::,:,:,:],
timestep=timestep,
return_dict=False,
)[0]
if isinstance(EF_Net_states, (tuple, list)):
EF_Net_states = [x.to(dtype=self.transformer.dtype) for x in EF_Net_states]
else:
EF_Net_states = EF_Net_states.to(dtype=self.transformer.dtype)
# predict noise model_output
noise_pred = self.transformer(
hidden_states=latent_model_input,
@ -865,6 +885,8 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
controlnet_weights=control_weights,
video_flow_features=video_flow_features if (tora is not None and tora["start_percent"] <= current_step_percentage <= tora["end_percent"]) else None,
tracking_maps=tracking_maps_input,
EF_Net_states=EF_Net_states,
EF_Net_weights=EF_Net_weights,
)[0]
noise_pred = noise_pred.float()
if isinstance(self.scheduler, CogVideoXDPMScheduler):

216
scifi/EF_Net.py Normal file
View File

@ -0,0 +1,216 @@
from typing import Any, Dict, Optional, Tuple, Union
import torch
from torch import nn
from diffusers.models.transformers.cogvideox_transformer_3d import Transformer2DModelOutput, CogVideoXBlock
from diffusers.utils import is_torch_version
from diffusers.loaders import PeftAdapterMixin
from diffusers.utils.torch_utils import maybe_allow_in_graph
from diffusers.models.embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
from diffusers.models.modeling_utils import ModelMixin
from diffusers.configuration_utils import ConfigMixin, register_to_config
class EF_Net(ModelMixin, ConfigMixin, PeftAdapterMixin):
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
num_attention_heads: int = 30,
attention_head_dim: int = 64,
vae_channels: int = 16,
in_channels: int = 3,
downscale_coef: int = 8,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
time_embed_dim: int = 512,
num_layers: int = 8,
dropout: float = 0.0,
attention_bias: bool = True,
sample_width: int = 90,
sample_height: int = 60,
sample_frames: int = 1,
patch_size: int = 2,
temporal_compression_ratio: int = 4,
max_text_seq_length: int = 226,
activation_fn: str = "gelu-approximate",
timestep_activation_fn: str = "silu",
norm_elementwise_affine: bool = True,
norm_eps: float = 1e-5,
spatial_interpolation_scale: float = 1.875,
temporal_interpolation_scale: float = 1.0,
use_rotary_positional_embeddings: bool = False,
use_learned_positional_embeddings: bool = False,
out_proj_dim = None,
):
super().__init__()
inner_dim = num_attention_heads * attention_head_dim
out_proj_dim = inner_dim
if not use_rotary_positional_embeddings and use_learned_positional_embeddings:
raise ValueError(
"There are no CogVideoX checkpoints available with disable rotary embeddings and learned positional "
"embeddings. If you're using a custom model and/or believe this should be supported, please open an "
"issue at https://github.com/huggingface/diffusers/issues."
)
# 1. Patch embedding
self.patch_embed = CogVideoXPatchEmbed(
patch_size=patch_size,
in_channels=vae_channels,
embed_dim=inner_dim,
bias=True,
sample_width=sample_width,
sample_height=sample_height,
sample_frames=49,
temporal_compression_ratio=temporal_compression_ratio,
spatial_interpolation_scale=spatial_interpolation_scale,
temporal_interpolation_scale=temporal_interpolation_scale,
use_positional_embeddings=not use_rotary_positional_embeddings,
use_learned_positional_embeddings=use_learned_positional_embeddings,
)
self.patch_embed_first = CogVideoXPatchEmbed(
patch_size=patch_size,
in_channels=vae_channels,
embed_dim=inner_dim,
bias=True,
sample_width=sample_width,
sample_height=sample_height,
sample_frames=sample_frames,
temporal_compression_ratio=temporal_compression_ratio,
spatial_interpolation_scale=spatial_interpolation_scale,
temporal_interpolation_scale=temporal_interpolation_scale,
use_positional_embeddings=not use_rotary_positional_embeddings,
use_learned_positional_embeddings=use_learned_positional_embeddings,
)
self.embedding_dropout = nn.Dropout(dropout)
self.weights = nn.ModuleList([nn.Linear(inner_dim, 13) for _ in range(num_layers)])
self.first_weights = nn.ModuleList([nn.Linear(2*inner_dim, inner_dim) for _ in range(num_layers)])
# 2. Time embeddings
self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
# 3. Define spatio-temporal transformers blocks
self.transformer_blocks = nn.ModuleList(
[
CogVideoXBlock(
dim=inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
time_embed_dim=time_embed_dim,
dropout=dropout,
activation_fn=activation_fn,
attention_bias=attention_bias,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
)
for _ in range(num_layers)
]
)
self.out_projectors = None
self.relu = nn.LeakyReLU(negative_slope=0.01)
if out_proj_dim is not None:
self.out_projectors = nn.ModuleList(
[nn.Linear(inner_dim, out_proj_dim) for _ in range(num_layers)]
)
self.gradient_checkpointing = False
def _set_gradient_checkpointing(self, enable=False, gradient_checkpointing_func=None):
self.gradient_checkpointing = enable
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
EF_Net_states: torch.Tensor,
timestep: Union[int, float, torch.LongTensor],
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
timestep_cond: Optional[torch.Tensor] = None,
return_dict: bool = True,
):
batch_size, num_frames, channels, height, width = EF_Net_states.shape
o_hidden_states = hidden_states
hidden_states = EF_Net_states
encoder_hidden_states_ = encoder_hidden_states
# 1. Time embedding
timesteps = timestep
t_emb = self.time_proj(timesteps)
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=hidden_states.dtype)
emb = self.time_embedding(t_emb, timestep_cond)
hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
hidden_states = self.embedding_dropout(hidden_states)
text_seq_length = encoder_hidden_states.shape[1]
encoder_hidden_states = hidden_states[:, :text_seq_length]
hidden_states = hidden_states[:, text_seq_length:]
o_hidden_states = self.patch_embed_first(encoder_hidden_states_, o_hidden_states)
o_hidden_states = self.embedding_dropout(o_hidden_states)
text_seq_length = encoder_hidden_states_.shape[1]
o_hidden_states = o_hidden_states[:, text_seq_length:]
EF_Net_hidden_states = ()
# 2. Transformer blocks
for i, block in enumerate(self.transformer_blocks):
#if self.training and self.gradient_checkpointing:
if self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
emb,
image_rotary_emb,
**ckpt_kwargs,
)
else:
hidden_states, encoder_hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=emb,
image_rotary_emb=image_rotary_emb,
)
if self.out_projectors is not None:
coff = self.weights[i](hidden_states)
temp_list = []
for j in range(coff.shape[2]):
temp_list.append(hidden_states*coff[:,:,j:(j+1)])
out = torch.concat(temp_list, dim=1)
out = torch.concat([out, o_hidden_states], dim=2)
out = self.first_weights[i](out)
out = self.relu(out)
out = self.out_projectors[i](out)
EF_Net_hidden_states += (out,)
else:
out = torch.concat([weight*hidden_states for weight in self.weights], dim=1)
EF_Net_hidden_states += (out,)
if not return_dict:
return (EF_Net_hidden_states,)
return Transformer2DModelOutput(sample=EF_Net_hidden_states)