Mirror of https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
This commit is contained in:
parent 1124c77d56
commit 1d11d779c5
@@ -610,6 +610,8 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         controlnet_weights: Optional[Union[float, int, list, np.ndarray, torch.FloatTensor]] = 1.0,
         video_flow_features: Optional[torch.Tensor] = None,
         tracking_maps: Optional[torch.Tensor] = None,
+        EF_Net_states: torch.Tensor = None,
+        EF_Net_weights: Optional[Union[float, int, list, np.ndarray, torch.FloatTensor]] = 1.0,
         return_dict: bool = True,
     ):
         batch_size, num_frames, channels, height, width = hidden_states.shape
@@ -784,7 +786,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
                    hidden_states = hidden_states + controlnet_states_block * controlnet_block_weight

            #das
-            if i < len(self.transformer_blocks_copy) and tracking_maps is not None:
+            if hasattr(self, 'transformer_blocks_copy') and i < len(self.transformer_blocks_copy) and tracking_maps is not None:
                tracking_maps, _ = self.transformer_blocks_copy[i](
                    hidden_states=tracking_maps,
                    encoder_hidden_states=encoder_hidden_states,
@@ -795,6 +797,19 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
                tracking_maps = self.combine_linears[i](tracking_maps)
                hidden_states = hidden_states + tracking_maps

+            #Sci-Fi
+            if (EF_Net_states is not None) and (i < len(EF_Net_states)):
+                EF_Net_states_block = EF_Net_states[i]
+                EF_Net_block_weight = 1.0
+
+                if isinstance(EF_Net_weights, (float, int)):
+                    EF_Net_block_weight = EF_Net_weights
+                else:
+                    EF_Net_block_weight = EF_Net_weights[i]
+
+
+                hidden_states = hidden_states + EF_Net_states_block * EF_Net_block_weight
+
        if self.use_teacache:
            self.previous_residual = hidden_states - ori_hidden_states
            self.previous_residual_encoder = encoder_hidden_states - ori_encoder_hidden_states
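
Note on the added block above: EF_Net_weights mirrors the existing controlnet_weights handling, so it can be a single scalar applied to every block or a per-block sequence indexed by the block counter i. A small sketch of both forms, assuming a 42-block transformer (the CogVideoX-5b depth); the values are purely illustrative and not part of the commit:

    num_blocks = 42                                                           # assumption: CogVideoX-5b transformer depth
    ef_net_weights = 0.8                                                      # scalar: same residual strength in every block
    ef_net_weights = [0.8 * (1 - i / num_blocks) for i in range(num_blocks)]  # per-block: fade out in deeper blocks
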
@@ -172,6 +172,9 @@ class DAS_SpaTracker:
        pred_tracks = pred_tracks[:,:,msk_query.squeeze()]
        pred_visibility = pred_visibility[:,:,msk_query.squeeze()]

+        print("pred_tracks: ", pred_tracks.shape)
+        print(pred_tracks[2])
+
        tracking_video = vis.visualize(
            video=video,
            tracks=pred_tracks,
@@ -125,6 +125,34 @@ class CogVideoLoraSelectComfy:
        print(cog_loras_list)
        return (cog_loras_list,)

+class CogVideoEF_Net:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "ef_net": (folder_paths.get_filename_list("diffusion_models"),
+                    {"tooltip": "EF_Net models are expected to be in ComfyUI/models/diffusion_models with .safetensors extension"}),
+                "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.0001, "tooltip": "EF_Net strength"}),
+                "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.0001, "tooltip": "start percent"}),
+                "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.0001, "tooltip": "end percent"}),
+            },
+        }
+
+    RETURN_TYPES = ("EFNET",)
+    RETURN_NAMES = ("ef_net", )
+    FUNCTION = "efnet"
+    CATEGORY = "CogVideoWrapper"
+    DESCRIPTION = "Select an EF_Net model from ComfyUI/models/diffusion_models"
+
+    def efnet(self, ef_net, strength, start_percent, end_percent):
+        ef_net_dict = {
+            "path": folder_paths.get_full_path("diffusion_models", ef_net),
+            "strength": strength,
+            "start_percent": start_percent,
+            "end_percent": end_percent
+        }
+        return (ef_net_dict,)
+
 #region DownloadAndLoadCogVideoModel
 class DownloadAndLoadCogVideoModel:
     @classmethod
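
For reference, the new CogVideoEF_Net node only resolves the selected checkpoint path and packs the widget values into a plain dict, which the loader nodes below receive through their optional scifi_ef_net input; of these keys, only "path" is consumed in the loader hunks shown here. A sketch of the returned value (the file name is hypothetical):

    ef_net_dict = {
        "path": "ComfyUI/models/diffusion_models/ef_net.safetensors",  # resolved via folder_paths.get_full_path
        "strength": 1.0,
        "start_percent": 0.0,
        "end_percent": 1.0,
    }
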
@@ -177,6 +205,7 @@ class DownloadAndLoadCogVideoModel:
                    "comfy"
                    ], {"default": "sdpa"}),
                "load_device": (["main_device", "offload_device"], {"default": "main_device"}),
+                "scifi_ef_net": ("EFNET", ),
            }
        }

@@ -188,7 +217,7 @@ class DownloadAndLoadCogVideoModel:

    def loadmodel(self, model, precision, quantization="disabled", compile="disabled",
                  enable_sequential_cpu_offload=False, block_edit=None, lora=None, compile_args=None,
-                  attention_mode="sdpa", load_device="main_device"):
+                  attention_mode="sdpa", load_device="main_device", scifi_ef_net=None):

        transformer = None

@@ -356,6 +385,14 @@ class DownloadAndLoadCogVideoModel:
        if compile_args is not None:
            pipe.transformer.to(memory_format=torch.channels_last)

+        if scifi_ef_net is not None:
+            from .scifi.EF_Net import EF_Net
+            EF_Net_model = EF_Net(num_layers=4, downscale_coef=8, in_channels=2, num_attention_heads=48,).requires_grad_(False).eval()
+            sd = load_torch_file(scifi_ef_net["path"])
+            EF_Net_model.load_state_dict(sd, strict=True)
+            pipe.EF_Net_model = EF_Net_model
+            del sd
+
        #fp8
        if quantization == "fp8_e4m3fn" or quantization == "fp8_e4m3fn_fastmode":
            params_to_keep = {"patch_embed", "lora", "pos_embedding", "time_embedding", "norm_k", "norm_q", "to_k.bias", "to_q.bias", "to_v.bias"}
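
Since load_state_dict is called with strict=True, selecting a non-EF_Net .safetensors file fails with a key mismatch. A small diagnostic sketch, not part of the commit, for inspecting a checkpoint before loading (the file path is hypothetical):

    from safetensors.torch import load_file

    sd = load_file("ComfyUI/models/diffusion_models/ef_net.safetensors")   # hypothetical path
    print(len(sd), "tensors, e.g.", next(iter(sd)))                        # tensor count and an example key
    missing, unexpected = EF_Net_model.load_state_dict(sd, strict=False)   # non-strict load reports mismatches
    print("missing:", missing, "unexpected:", unexpected)
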
@@ -692,6 +729,7 @@ class CogVideoXModelLoader:
                    "fused_sageattn_qk_int8_pv_fp16_triton",
                    "comfy"
                    ], {"default": "sdpa"}),
+                "scifi_ef_net": ("EFNET", ),
            }
        }

@@ -701,7 +739,7 @@ class CogVideoXModelLoader:
    CATEGORY = "CogVideoWrapper"

    def loadmodel(self, model, base_precision, load_device, enable_sequential_cpu_offload,
-                  block_edit=None, compile_args=None, lora=None, attention_mode="sdpa", quantization="disabled"):
+                  block_edit=None, compile_args=None, lora=None, attention_mode="sdpa", quantization="disabled", scifi_ef_net=None):
        transformer = None
        if "sage" in attention_mode:
            try:
@@ -860,6 +898,15 @@ class CogVideoXModelLoader:
        if compile_args is not None:
            pipe.transformer.to(memory_format=torch.channels_last)

+        if scifi_ef_net is not None:
+            from .scifi.EF_Net import EF_Net
+            EF_Net_model = EF_Net(num_layers=4, downscale_coef=8, in_channels=2, num_attention_heads=48,).requires_grad_(False).eval()
+            sd = load_torch_file(scifi_ef_net["path"])
+            EF_Net_model.load_state_dict(sd, strict=True)
+            EF_Net_model.to(base_dtype)
+            pipe.EF_Net_model = EF_Net_model
+            del sd
+
        #quantization
        if quantization == "fp8_e4m3fn" or quantization == "fp8_e4m3fn_fast":
            params_to_keep = {"patch_embed", "lora", "pos_embedding", "time_embedding", "norm_k", "norm_q", "to_k.bias", "to_q.bias", "to_v.bias"}
@@ -1138,7 +1185,8 @@ NODE_CLASS_MAPPINGS = {
    "CogVideoLoraSelect": CogVideoLoraSelect,
    "CogVideoXVAELoader": CogVideoXVAELoader,
    "CogVideoXModelLoader": CogVideoXModelLoader,
-    "CogVideoLoraSelectComfy": CogVideoLoraSelectComfy
+    "CogVideoLoraSelectComfy": CogVideoLoraSelectComfy,
+    "CogVideoEF_Net": CogVideoEF_Net
    }
NODE_DISPLAY_NAME_MAPPINGS = {
    "DownloadAndLoadCogVideoModel": "(Down)load CogVideo Model",
@@ -1148,5 +1196,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
    "CogVideoLoraSelect": "CogVideo LoraSelect",
    "CogVideoXVAELoader": "CogVideoX VAE Loader",
    "CogVideoXModelLoader": "CogVideoX Model Loader",
-    "CogVideoLoraSelectComfy": "CogVideo LoraSelect Comfy"
+    "CogVideoLoraSelectComfy": "CogVideo LoraSelect Comfy",
+    "CogVideoEF_Net": "CogVideo EF_Net"
    }
@@ -384,6 +384,9 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
        image_cond_end_percent: float = 1.0,
        feta_args: Optional[dict] = None,
        das_tracking: Optional[dict] = None,
+        EF_Net_weights: Optional[Union[float, list, torch.FloatTensor]] = 1.0,
+        EF_Net_guidance_start: float = 0.0,
+        EF_Net_guidance_end: float = 1.0,

    ):
        """
@@ -853,6 +856,23 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
                    else:
                        controlnet_states = controlnet_states.to(dtype=self.vae_dtype)

+                self.EF_Net_model.to(device)
+                EF_Net_states = []
+                if (EF_Net_guidance_start <= current_step_percentage < EF_Net_guidance_end):
+                    # extract EF_Net hidden state
+                    EF_Net_states = self.EF_Net_model(
+                        hidden_states=latent_image_input[:,:,0:16,:,:],
+                        encoder_hidden_states=prompt_embeds,
+                        image_rotary_emb=None,
+                        EF_Net_states=latent_image_input[:,12::,:,:,:],
+                        timestep=timestep,
+                        return_dict=False,
+                    )[0]
+                    if isinstance(EF_Net_states, (tuple, list)):
+                        EF_Net_states = [x.to(dtype=self.transformer.dtype) for x in EF_Net_states]
+                    else:
+                        EF_Net_states = EF_Net_states.to(dtype=self.transformer.dtype)
+
                # predict noise model_output
                noise_pred = self.transformer(
                    hidden_states=latent_model_input,
@@ -865,6 +885,8 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
                    controlnet_weights=control_weights,
                    video_flow_features=video_flow_features if (tora is not None and tora["start_percent"] <= current_step_percentage <= tora["end_percent"]) else None,
                    tracking_maps=tracking_maps_input,
+                    EF_Net_states=EF_Net_states,
+                    EF_Net_weights=EF_Net_weights,
                )[0]
                noise_pred = noise_pred.float()
                if isinstance(self.scheduler, CogVideoXDPMScheduler):
scifi/EF_Net.py (new file, 216 lines)
@@ -0,0 +1,216 @@
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from diffusers.models.transformers.cogvideox_transformer_3d import Transformer2DModelOutput, CogVideoXBlock
+from diffusers.utils import is_torch_version
+from diffusers.loaders import PeftAdapterMixin
+from diffusers.utils.torch_utils import maybe_allow_in_graph
+from diffusers.models.embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+
+
+class EF_Net(ModelMixin, ConfigMixin, PeftAdapterMixin):
+    _supports_gradient_checkpointing = True
+
+    @register_to_config
+    def __init__(
+        self,
+        num_attention_heads: int = 30,
+        attention_head_dim: int = 64,
+        vae_channels: int = 16,
+        in_channels: int = 3,
+        downscale_coef: int = 8,
+        flip_sin_to_cos: bool = True,
+        freq_shift: int = 0,
+        time_embed_dim: int = 512,
+        num_layers: int = 8,
+        dropout: float = 0.0,
+        attention_bias: bool = True,
+        sample_width: int = 90,
+        sample_height: int = 60,
+        sample_frames: int = 1,
+        patch_size: int = 2,
+        temporal_compression_ratio: int = 4,
+        max_text_seq_length: int = 226,
+        activation_fn: str = "gelu-approximate",
+        timestep_activation_fn: str = "silu",
+        norm_elementwise_affine: bool = True,
+        norm_eps: float = 1e-5,
+        spatial_interpolation_scale: float = 1.875,
+        temporal_interpolation_scale: float = 1.0,
+        use_rotary_positional_embeddings: bool = False,
+        use_learned_positional_embeddings: bool = False,
+        out_proj_dim = None,
+    ):
+        super().__init__()
+        inner_dim = num_attention_heads * attention_head_dim
+        out_proj_dim = inner_dim
+
+        if not use_rotary_positional_embeddings and use_learned_positional_embeddings:
+            raise ValueError(
+                "There are no CogVideoX checkpoints available with disable rotary embeddings and learned positional "
+                "embeddings. If you're using a custom model and/or believe this should be supported, please open an "
+                "issue at https://github.com/huggingface/diffusers/issues."
+            )
+
+        # 1. Patch embedding
+        self.patch_embed = CogVideoXPatchEmbed(
+            patch_size=patch_size,
+            in_channels=vae_channels,
+            embed_dim=inner_dim,
+            bias=True,
+            sample_width=sample_width,
+            sample_height=sample_height,
+            sample_frames=49,
+            temporal_compression_ratio=temporal_compression_ratio,
+            spatial_interpolation_scale=spatial_interpolation_scale,
+            temporal_interpolation_scale=temporal_interpolation_scale,
+            use_positional_embeddings=not use_rotary_positional_embeddings,
+            use_learned_positional_embeddings=use_learned_positional_embeddings,
+        )
+
+        self.patch_embed_first = CogVideoXPatchEmbed(
+            patch_size=patch_size,
+            in_channels=vae_channels,
+            embed_dim=inner_dim,
+            bias=True,
+            sample_width=sample_width,
+            sample_height=sample_height,
+            sample_frames=sample_frames,
+            temporal_compression_ratio=temporal_compression_ratio,
+            spatial_interpolation_scale=spatial_interpolation_scale,
+            temporal_interpolation_scale=temporal_interpolation_scale,
+            use_positional_embeddings=not use_rotary_positional_embeddings,
+            use_learned_positional_embeddings=use_learned_positional_embeddings,
+        )
+
+        self.embedding_dropout = nn.Dropout(dropout)
+        self.weights = nn.ModuleList([nn.Linear(inner_dim, 13) for _ in range(num_layers)])
+        self.first_weights = nn.ModuleList([nn.Linear(2*inner_dim, inner_dim) for _ in range(num_layers)])
+
+        # 2. Time embeddings
+        self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
+        self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
+
+        # 3. Define spatio-temporal transformers blocks
+        self.transformer_blocks = nn.ModuleList(
+            [
+                CogVideoXBlock(
+                    dim=inner_dim,
+                    num_attention_heads=num_attention_heads,
+                    attention_head_dim=attention_head_dim,
+                    time_embed_dim=time_embed_dim,
+                    dropout=dropout,
+                    activation_fn=activation_fn,
+                    attention_bias=attention_bias,
+                    norm_elementwise_affine=norm_elementwise_affine,
+                    norm_eps=norm_eps,
+                )
+                for _ in range(num_layers)
+            ]
+        )
+
+        self.out_projectors = None
+        self.relu = nn.LeakyReLU(negative_slope=0.01)
+
+        if out_proj_dim is not None:
+            self.out_projectors = nn.ModuleList(
+                [nn.Linear(inner_dim, out_proj_dim) for _ in range(num_layers)]
+            )
+
+        self.gradient_checkpointing = False
+
+    def _set_gradient_checkpointing(self, enable=False, gradient_checkpointing_func=None):
+        self.gradient_checkpointing = enable
+
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        EF_Net_states: torch.Tensor,
+        timestep: Union[int, float, torch.LongTensor],
+        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        timestep_cond: Optional[torch.Tensor] = None,
+        return_dict: bool = True,
+    ):
+        batch_size, num_frames, channels, height, width = EF_Net_states.shape
+        o_hidden_states = hidden_states
+        hidden_states = EF_Net_states
+        encoder_hidden_states_ = encoder_hidden_states
+
+        # 1. Time embedding
+        timesteps = timestep
+        t_emb = self.time_proj(timesteps)
+
+
+        # timesteps does not contain any weights and will always return f32 tensors
+        # but time_embedding might actually be running in fp16. so we need to cast here.
+        # there might be better ways to encapsulate this.
+        t_emb = t_emb.to(dtype=hidden_states.dtype)
+        emb = self.time_embedding(t_emb, timestep_cond)
+
+        hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
+        hidden_states = self.embedding_dropout(hidden_states)
+
+        text_seq_length = encoder_hidden_states.shape[1]
+        encoder_hidden_states = hidden_states[:, :text_seq_length]
+        hidden_states = hidden_states[:, text_seq_length:]
+
+        o_hidden_states = self.patch_embed_first(encoder_hidden_states_, o_hidden_states)
+        o_hidden_states = self.embedding_dropout(o_hidden_states)
+
+        text_seq_length = encoder_hidden_states_.shape[1]
+        o_hidden_states = o_hidden_states[:, text_seq_length:]
+
+        EF_Net_hidden_states = ()
+        # 2. Transformer blocks
+        for i, block in enumerate(self.transformer_blocks):
+            #if self.training and self.gradient_checkpointing:
+            if self.gradient_checkpointing:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states,
+                    encoder_hidden_states,
+                    emb,
+                    image_rotary_emb,
+                    **ckpt_kwargs,
+                )
+            else:
+                hidden_states, encoder_hidden_states = block(
+                    hidden_states=hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    temb=emb,
+                    image_rotary_emb=image_rotary_emb,
+                )
+
+
+            if self.out_projectors is not None:
+                coff = self.weights[i](hidden_states)
+                temp_list = []
+                for j in range(coff.shape[2]):
+                    temp_list.append(hidden_states*coff[:,:,j:(j+1)])
+                out = torch.concat(temp_list, dim=1)
+                out = torch.concat([out, o_hidden_states], dim=2)
+                out = self.first_weights[i](out)
+                out = self.relu(out)
+                out = self.out_projectors[i](out)
+                EF_Net_hidden_states += (out,)
+            else:
+                out = torch.concat([weight*hidden_states for weight in self.weights], dim=1)
+                EF_Net_hidden_states += (out,)
+
+        if not return_dict:
+            return (EF_Net_hidden_states,)
+        return Transformer2DModelOutput(sample=EF_Net_hidden_states)
+
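
Both loader nodes above construct this module as EF_Net(num_layers=4, downscale_coef=8, in_channels=2, num_attention_heads=48), i.e. four CogVideoXBlocks of width 48 * 64 = 3072, matching the CogVideoX-5b hidden size; downscale_coef and in_channels appear unused in this file. A minimal construction sketch (an assumption-laden illustration, not part of the commit: it presumes diffusers is installed and the module is importable as scifi.EF_Net, and it skips the forward pass because the expected input shapes depend on the pipeline's latent layout):

    from scifi.EF_Net import EF_Net

    # Same constructor arguments as the loader nodes in this commit
    ef_net = EF_Net(num_layers=4, downscale_coef=8, in_channels=2, num_attention_heads=48)
    ef_net = ef_net.requires_grad_(False).eval()

    n_params = sum(p.numel() for p in ef_net.parameters())
    print(f"blocks: {len(ef_net.transformer_blocks)}, params: {n_params / 1e6:.1f}M")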