Mirror of https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git (synced 2025-12-09 04:44:22 +08:00)

Compare commits: 24 commits, 2871f9b9d5...b185059ba8
| SHA1 |
|---|
| b185059ba8 |
| fed499e971 |
| f3dda43cdf |
| 126322139f |
| 90e8367f5e |
| 76f7930d07 |
| 3d2ee02d83 |
| 51daeef1b7 |
| 5bca0548d9 |
| 97b7b18f35 |
| f5454aa806 |
| 3a38d01414 |
| 8c5e4f812d |
| eaaa0f6e1a |
| 25d0ede406 |
| f16d38a5d2 |
| fcc0f3e65a |
| 0758d2d016 |
| b5eefbf4d4 |
| 795f8b0565 |
| d9d30f24bb |
| 729a6485ea |
| 411791c748 |
| 7a10e732bb |
LICENSE (new file, 201 lines)
@@ -0,0 +1,201 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
@@ -35,6 +35,9 @@ from diffusers.loaders import PeftAdapterMixin
from diffusers.models.embeddings import apply_rotary_emb
from .embeddings import CogVideoXPatchEmbed

from .enhance_a_video.enhance import get_feta_scores
from .enhance_a_video.globals import is_enhance_enabled, set_num_frames


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

@@ -60,27 +63,28 @@ def set_attention_func(attention_mode, heads):
    elif attention_mode == "sageattn" or attention_mode == "fused_sageattn":
        @torch.compiler.disable()
        def func(q, k, v, is_causal=False, attn_mask=None):
-            return sageattn(q, k, v, is_causal=is_causal, attn_mask=attn_mask)
+            return sageattn(q.to(v), k.to(v), v, is_causal=is_causal, attn_mask=attn_mask)
        return func
    elif attention_mode == "sageattn_qk_int8_pv_fp16_cuda":
        from sageattention import sageattn_qk_int8_pv_fp16_cuda
        @torch.compiler.disable()
        def func(q, k, v, is_causal=False, attn_mask=None):
-            return sageattn_qk_int8_pv_fp16_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32")
+            return sageattn_qk_int8_pv_fp16_cuda(q.to(v), k.to(v), v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32")
        return func
    elif attention_mode == "sageattn_qk_int8_pv_fp16_triton":
        from sageattention import sageattn_qk_int8_pv_fp16_triton
        @torch.compiler.disable()
        def func(q, k, v, is_causal=False, attn_mask=None):
-            return sageattn_qk_int8_pv_fp16_triton(q, k, v, is_causal=is_causal, attn_mask=attn_mask)
+            return sageattn_qk_int8_pv_fp16_triton(q.to(v), k.to(v), v, is_causal=is_causal, attn_mask=attn_mask)
        return func
    elif attention_mode == "sageattn_qk_int8_pv_fp8_cuda":
        from sageattention import sageattn_qk_int8_pv_fp8_cuda
        @torch.compiler.disable()
        def func(q, k, v, is_causal=False, attn_mask=None):
-            return sageattn_qk_int8_pv_fp8_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32+fp32")
+            return sageattn_qk_int8_pv_fp8_cuda(q.to(v), k.to(v), v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32+fp32")
        return func

#for fastercache
def fft(tensor):
    tensor_fft = torch.fft.fft2(tensor)
    tensor_fft_shifted = torch.fft.fftshift(tensor_fft)
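Note: the change above casts query and key to the value tensor's dtype before handing them to the SageAttention kernels, so all three inputs share one dtype when upstream layers produce mixed precision. A minimal sketch of that wrapper pattern, assuming the sageattention package is installed; the SDPA fallback is illustrative and not part of this repository:

    import torch
    import torch.nn.functional as F

    def make_attn_func(use_sage: bool):
        if use_sage:
            from sageattention import sageattn

            @torch.compiler.disable()
            def func(q, k, v, is_causal=False, attn_mask=None):
                # align q/k dtype with v before the quantized kernel
                return sageattn(q.to(v), k.to(v), v, is_causal=is_causal, attn_mask=attn_mask)
        else:
            def func(q, k, v, is_causal=False, attn_mask=None):
                return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, is_causal=is_causal)
        return func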
@@ -98,6 +102,13 @@ def fft(tensor):

    return low_freq_fft, high_freq_fft

#for teacache
def poly1d(coefficients, x):
    result = torch.zeros_like(x)
    for i, coeff in enumerate(coefficients):
        result += coeff * (x ** (len(coefficients) - 1 - i))
    return result.abs()

#region Attention
class CogVideoXAttnProcessor2_0:
    r"""
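Note: poly1d above evaluates a polynomial whose coefficients are given from highest to lowest power and returns the absolute value; TeaCache later feeds it the relative change of the timestep embedding. A quick sanity check of the evaluation order (the coefficients here are illustrative, not the shipped ones):

    import torch

    def poly1d(coefficients, x):
        result = torch.zeros_like(x)
        for i, coeff in enumerate(coefficients):
            result += coeff * (x ** (len(coefficients) - 1 - i))
        return result.abs()

    x = torch.tensor(0.5)
    # 2*x**2 + 3*x + 1 at x=0.5 -> 3.0
    print(poly1d([2.0, 3.0, 1.0], x))  # tensor(3.)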
@@ -159,6 +170,10 @@ class CogVideoXAttnProcessor2_0:
        query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
        if not attn.is_cross_attention:
            key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)

        #feta
        if is_enhance_enabled():
            feta_scores = get_feta_scores(attn, query, key, head_dim, text_seq_length)

        hidden_states = self.attn_func(query, key, value, attn_mask=attention_mask, is_causal=False)

@@ -173,6 +188,10 @@ class CogVideoXAttnProcessor2_0:
        encoder_hidden_states, hidden_states = hidden_states.split(
            [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
        )

        if is_enhance_enabled():
            hidden_states *= feta_scores

        return hidden_states, encoder_hidden_states

#region Blocks
@@ -515,7 +534,12 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):

        self.gradient_checkpointing = False

        self.attention_mode = attention_mode

        #tora
        self.fuser_list = None

        #fastercache
        self.use_fastercache = False
        self.fastercache_counter = 0
        self.fastercache_start_step = 15

@@ -523,7 +547,16 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
        self.fastercache_hf_step = 30
        self.fastercache_device = "cuda"
        self.fastercache_num_blocks_to_cache = len(self.transformer_blocks)
        self.attention_mode = attention_mode

        #teacache
        self.use_teacache = False
        self.teacache_rel_l1_thresh = 0.0
        if not self.config.use_rotary_positional_embeddings:
            #CogVideoX-2B
            self.teacache_coefficients = [-3.10658903e+01, 2.54732368e+01, -5.92380459e+00, 1.75769064e+00, -3.61568434e-03]
        else:
            #CogVideoX-5B
            self.teacache_coefficients = [-1.53880483e+03, 8.43202495e+02, -1.34363087e+02, 7.97131516e+00, -5.23162339e-02]


    def _set_gradient_checkpointing(self, module, value=False):

@@ -543,6 +576,8 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
        return_dict: bool = True,
    ):
        batch_size, num_frames, channels, height, width = hidden_states.shape

        set_num_frames(num_frames) #enhance a video global

        # 1. Time embedding
        timesteps = timestep
@@ -649,33 +684,56 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
                recovered_uncond = rearrange(recovered_uncond.to(output.dtype), "(B T) C H W -> B T C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
                output = torch.cat([output, recovered_uncond])
        else:
            for i, block in enumerate(self.transformer_blocks):
                hidden_states, encoder_hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=emb,
                    image_rotary_emb=image_rotary_emb,
                    video_flow_feature=video_flow_features[i] if video_flow_features is not None else None,
                    fuser = self.fuser_list[i] if self.fuser_list is not None else None,
                    block_use_fastercache = i <= self.fastercache_num_blocks_to_cache,
                    fastercache_counter = self.fastercache_counter,
                    fastercache_start_step = self.fastercache_start_step,
                    fastercache_device = self.fastercache_device
                )
                #has_nan = torch.isnan(hidden_states).any()
                #if has_nan:
                #    raise ValueError(f"block output hidden_states has nan: {has_nan}")
            if self.use_teacache:
                if not hasattr(self, 'accumulated_rel_l1_distance'):
                    should_calc = True
                    self.accumulated_rel_l1_distance = 0
                else:
                    self.accumulated_rel_l1_distance += poly1d(self.teacache_coefficients, ((emb-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()))
                    if self.accumulated_rel_l1_distance < self.teacache_rel_l1_thresh:
                        should_calc = False
                        self.teacache_counter += 1
                    else:
                        should_calc = True
                        self.accumulated_rel_l1_distance = 0
                #print("self.accumulated_rel_l1_distance ", self.accumulated_rel_l1_distance)
                self.previous_modulated_input = emb
                if not should_calc:
                    hidden_states += self.previous_residual
                    encoder_hidden_states += self.previous_residual_encoder

            if not self.use_teacache or (self.use_teacache and should_calc):
                if self.use_teacache:
                    ori_hidden_states = hidden_states.clone()
                    ori_encoder_hidden_states = encoder_hidden_states.clone()
                for i, block in enumerate(self.transformer_blocks):
                    hidden_states, encoder_hidden_states = block(
                        hidden_states=hidden_states,
                        encoder_hidden_states=encoder_hidden_states,
                        temb=emb,
                        image_rotary_emb=image_rotary_emb,
                        video_flow_feature=video_flow_features[i] if video_flow_features is not None else None,
                        fuser = self.fuser_list[i] if self.fuser_list is not None else None,
                        block_use_fastercache = i <= self.fastercache_num_blocks_to_cache,
                        fastercache_counter = self.fastercache_counter,
                        fastercache_start_step = self.fastercache_start_step,
                        fastercache_device = self.fastercache_device
                    )

                    #controlnet
                    if (controlnet_states is not None) and (i < len(controlnet_states)):
                        controlnet_states_block = controlnet_states[i]
                        controlnet_block_weight = 1.0
                        if isinstance(controlnet_weights, (list, np.ndarray)) or torch.is_tensor(controlnet_weights):
                            controlnet_block_weight = controlnet_weights[i]
                            print(controlnet_block_weight)
                        elif isinstance(controlnet_weights, (float, int)):
                            controlnet_block_weight = controlnet_weights
                        hidden_states = hidden_states + controlnet_states_block * controlnet_block_weight
                #controlnet
                if (controlnet_states is not None) and (i < len(controlnet_states)):
                    controlnet_states_block = controlnet_states[i]
                    controlnet_block_weight = 1.0
                    if isinstance(controlnet_weights, (list, np.ndarray)) or torch.is_tensor(controlnet_weights):
                        controlnet_block_weight = controlnet_weights[i]
                        print(controlnet_block_weight)
                    elif isinstance(controlnet_weights, (float, int)):
                        controlnet_block_weight = controlnet_weights
                    hidden_states = hidden_states + controlnet_states_block * controlnet_block_weight

                if self.use_teacache:
                    self.previous_residual = hidden_states - ori_hidden_states
                    self.previous_residual_encoder = encoder_hidden_states - ori_encoder_hidden_states

        if not self.config.use_rotary_positional_embeddings:
            # CogVideoX-2B
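Note: the TeaCache branch above accumulates a predicted relative change between consecutive timestep embeddings and, while the running total stays below rel_l1_thresh, reuses the residual from the last fully computed step instead of running the transformer blocks. A simplified standalone sketch of that decision (the class name and the raw relative change stand in for the wrapper's polynomial-rescaled estimate):

    import torch

    class TeaCacheState:
        def __init__(self, threshold: float):
            self.threshold = threshold
            self.accumulated = 0.0
            self.prev_emb = None

        def should_compute(self, emb: torch.Tensor) -> bool:
            if self.prev_emb is None:
                self.prev_emb = emb
                return True
            rel_change = ((emb - self.prev_emb).abs().mean() / self.prev_emb.abs().mean()).item()
            self.accumulated += rel_change  # the wrapper rescales this with fitted coefficients
            self.prev_emb = emb
            if self.accumulated < self.threshold:
                return False  # skip the blocks, reuse the cached residual
            self.accumulated = 0.0
            return True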
@@ -718,4 +776,4 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)

enhance_a_video/__init__.py (new file, empty)
enhance_a_video/enhance.py (new file, 82 lines)
@@ -0,0 +1,82 @@
import torch
from einops import rearrange
from diffusers.models.attention import Attention
from .globals import get_enhance_weight, get_num_frames

# def get_feta_scores(query, key):
#     img_q, img_k = query, key

#     num_frames = get_num_frames()

#     B, S, N, C = img_q.shape

#     # Calculate spatial dimension
#     spatial_dim = S // num_frames

#     # Add time dimension between spatial and head dims
#     query_image = img_q.reshape(B, spatial_dim, num_frames, N, C)
#     key_image = img_k.reshape(B, spatial_dim, num_frames, N, C)

#     # Expand time dimension
#     query_image = query_image.expand(-1, -1, num_frames, -1, -1)  # [B, S, T, N, C]
#     key_image = key_image.expand(-1, -1, num_frames, -1, -1)  # [B, S, T, N, C]

#     # Reshape to match feta_score input format: [(B S) N T C]
#     query_image = rearrange(query_image, "b s t n c -> (b s) n t c")  #torch.Size([3200, 24, 5, 128])
#     key_image = rearrange(key_image, "b s t n c -> (b s) n t c")

#     return feta_score(query_image, key_image, C, num_frames)

def get_feta_scores(
    attn: Attention,
    query: torch.Tensor,
    key: torch.Tensor,
    head_dim: int,
    text_seq_length: int,
) -> torch.Tensor:
    num_frames = get_num_frames()
    spatial_dim = int((query.shape[2] - text_seq_length) / num_frames)

    query_image = rearrange(
        query[:, :, text_seq_length:],
        "B N (T S) C -> (B S) N T C",
        N=attn.heads,
        T=num_frames,
        S=spatial_dim,
        C=head_dim,
    )
    key_image = rearrange(
        key[:, :, text_seq_length:],
        "B N (T S) C -> (B S) N T C",
        N=attn.heads,
        T=num_frames,
        S=spatial_dim,
        C=head_dim,
    )
    return feta_score(query_image, key_image, head_dim, num_frames)

def feta_score(query_image, key_image, head_dim, num_frames):
    scale = head_dim**-0.5
    query_image = query_image * scale
    attn_temp = query_image @ key_image.transpose(-2, -1)  # translate attn to float32
    attn_temp = attn_temp.to(torch.float32)
    attn_temp = attn_temp.softmax(dim=-1)

    # Reshape to [batch_size * num_tokens, num_frames, num_frames]
    attn_temp = attn_temp.reshape(-1, num_frames, num_frames)

    # Create a mask for diagonal elements
    diag_mask = torch.eye(num_frames, device=attn_temp.device).bool()
    diag_mask = diag_mask.unsqueeze(0).expand(attn_temp.shape[0], -1, -1)

    # Zero out diagonal elements
    attn_wo_diag = attn_temp.masked_fill(diag_mask, 0)

    # Calculate mean for each token's attention matrix
    # Number of off-diagonal elements per matrix is n*n - n
    num_off_diag = num_frames * num_frames - num_frames
    mean_scores = attn_wo_diag.sum(dim=(1, 2)) / num_off_diag

    enhance_scores = mean_scores.mean() * (num_frames + get_enhance_weight())
    enhance_scores = enhance_scores.clamp(min=1)
    return enhance_scores
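Note: feta_score builds a per-head temporal attention map, averages the off-diagonal (cross-frame) probabilities, scales by num_frames plus the configured weight and clamps at 1, so the returned scalar only ever amplifies the image hidden states. A shape check with dummy tensors (all dimensions are illustrative):

    import torch

    batch_spatial, heads, num_frames, head_dim = 4, 2, 5, 8
    q = torch.randn(batch_spatial, heads, num_frames, head_dim)
    k = torch.randn(batch_spatial, heads, num_frames, head_dim)

    attn = (q * head_dim**-0.5) @ k.transpose(-2, -1)   # [(B S), N, T, T]
    attn = attn.float().softmax(dim=-1).reshape(-1, num_frames, num_frames)
    diag = torch.eye(num_frames, dtype=torch.bool).expand(attn.shape[0], -1, -1)
    mean_off_diag = attn.masked_fill(diag, 0).sum(dim=(1, 2)) / (num_frames**2 - num_frames)
    score = (mean_off_diag.mean() * (num_frames + 1.0)).clamp(min=1)  # enhance weight assumed 1.0
    print(score.item())  # scalar >= 1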
enhance_a_video/globals.py (new file, 31 lines)
@@ -0,0 +1,31 @@
NUM_FRAMES = None
FETA_WEIGHT = None
ENABLE_FETA = False

def set_num_frames(num_frames: int):
    global NUM_FRAMES
    NUM_FRAMES = num_frames


def get_num_frames() -> int:
    return NUM_FRAMES


def enable_enhance():
    global ENABLE_FETA
    ENABLE_FETA = True

def disable_enhance():
    global ENABLE_FETA
    ENABLE_FETA = False

def is_enhance_enabled() -> bool:
    return ENABLE_FETA

def set_enhance_weight(feta_weight: float):
    global FETA_WEIGHT
    FETA_WEIGHT = feta_weight


def get_enhance_weight() -> float:
    return FETA_WEIGHT
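Note: these module-level globals are how the wrapper passes Enhance-A-Video state into the attention processor without threading extra arguments through the diffusers call chain. A sketch of the typical call order from the sampler side (the exact wiring lives in the sampler and pipeline):

    from enhance_a_video.globals import (
        enable_enhance, disable_enhance, set_enhance_weight, set_num_frames,
    )

    set_num_frames(49)        # latent frame count for the current run
    set_enhance_weight(1.0)   # "feta" weight from the CogVideoEnhanceAVideo node
    enable_enhance()          # attention processors now apply the score
    # ... run sampling ...
    disable_enhance()         # reset so later runs are unaffected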
example_workflows/cogvideox_1_0_5b_I2V_noise_warp_01.json (new file, 1291 lines; diff suppressed because it is too large)
example_workflows/cut_and_drag_for_noisewarp_01.json (new file, 454 lines; diff suppressed because one or more lines are too long)
example_workflows/noise_warp_example_input_video.mp4 (new binary file, not shown)
@@ -1,79 +0,0 @@
import io

import torch
from PIL import Image
import struct
import numpy as np
from comfy.cli_args import args, LatentPreviewMethod
from comfy.taesd.taesd import TAESD
import comfy.model_management
import folder_paths
import comfy.utils
import logging

MAX_PREVIEW_RESOLUTION = args.preview_size

def preview_to_image(latent_image):
    latents_ubyte = (((latent_image + 1.0) / 2.0).clamp(0, 1)  # change scale from -1..1 to 0..1
                     .mul(0xFF)  # to 0..255
                     ).to(device="cpu", dtype=torch.uint8, non_blocking=comfy.model_management.device_supports_non_blocking(latent_image.device))

    return Image.fromarray(latents_ubyte.numpy())

class LatentPreviewer:
    def decode_latent_to_preview(self, x0):
        pass

    def decode_latent_to_preview_image(self, preview_format, x0):
        preview_image = self.decode_latent_to_preview(x0)
        return ("GIF", preview_image, MAX_PREVIEW_RESOLUTION)

class Latent2RGBPreviewer(LatentPreviewer):
    def __init__(self):
latent_rgb_factors = [[0.11945946736445662, 0.09919175788574555, -0.004832707433877734], [-0.0011977028264356232, 0.05496505130267682, 0.021321622433638193], [-0.014088548986590666, -0.008701477861945644, -0.020991313281459367], [0.03063921972519621, 0.12186477097625073, 0.0139593690235148], [0.0927403067854673, 0.030293187650929136, 0.05083134241694003], [0.0379112441305742, 0.04935199882777209, 0.058562766246777774], [0.017749911959153715, 0.008839453404921545, 0.036005638019226294], [0.10610119248526109, 0.02339855688237826, 0.057154257614084596], [0.1273639464837117, -0.010959856130713416, 0.043268631260428896], [-0.01873510946881321, 0.08220930648486932, 0.10613256772247093], [0.008429116376722327, 0.07623856561000408, 0.09295712117576727], [0.12938137079617007, 0.12360403483892413, 0.04478930933220116], [0.04565908794779364, 0.041064156741596365, -0.017695041535528512], [0.00019003240570281826, -0.013965147883381978, 0.05329669529635849], [0.08082391586738358, 0.11548306825496074, -0.021464170006615893], [-0.01517932393230994, -0.0057985555313003236, 0.07216646476618871]]

        self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu").transpose(0, 1)
        self.latent_rgb_factors_bias = None
        # if latent_rgb_factors_bias is not None:
        #     self.latent_rgb_factors_bias = torch.tensor(latent_rgb_factors_bias, device="cpu")

    def decode_latent_to_preview(self, x0):
        self.latent_rgb_factors = self.latent_rgb_factors.to(dtype=x0.dtype, device=x0.device)
        if self.latent_rgb_factors_bias is not None:
            self.latent_rgb_factors_bias = self.latent_rgb_factors_bias.to(dtype=x0.dtype, device=x0.device)

        latent_image = torch.nn.functional.linear(x0[0].permute(1, 2, 0), self.latent_rgb_factors,
                                                  bias=self.latent_rgb_factors_bias)
        return preview_to_image(latent_image)


def get_previewer():
    previewer = None
    method = args.preview_method
    if method != LatentPreviewMethod.NoPreviews:
        # TODO previewer method

        if method == LatentPreviewMethod.Auto:
            method = LatentPreviewMethod.Latent2RGB

        if previewer is None:
            previewer = Latent2RGBPreviewer()
    return previewer

def prepare_callback(model, steps, x0_output_dict=None):
    preview_format = "JPEG"
    if preview_format not in ["JPEG", "PNG"]:
        preview_format = "JPEG"

    previewer = get_previewer()

    pbar = comfy.utils.ProgressBar(steps)
    def callback(step, x0, x, total_steps):
        if x0_output_dict is not None:
            x0_output_dict["x0"] = x0
        preview_bytes = None
        if previewer:
            preview_bytes = previewer.decode_latent_to_preview_image(preview_format, x0)
        pbar.update_absolute(step + 1, total_steps, preview_bytes)
    return callback

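Note: the deleted previewer above maps each 16-channel latent pixel to RGB with a fixed 16x3 matrix via torch.nn.functional.linear, then rescales from [-1, 1] to 8-bit for display. The same projection in isolation (random factors stand in for the tuned ones):

    import torch
    import torch.nn.functional as F

    latent = torch.randn(16, 60, 90)                  # [C, H, W] latent frame
    factors = torch.randn(3, 16)                      # rows: R, G, B weights per latent channel
    rgb = F.linear(latent.permute(1, 2, 0), factors)  # [H, W, 3]
    img = ((rgb + 1.0) / 2.0).clamp(0, 1).mul(255).to(torch.uint8)
    print(img.shape)  # torch.Size([60, 90, 3])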
@@ -70,6 +70,7 @@ class CogVideoLoraSelect:
    RETURN_NAMES = ("lora", )
    FUNCTION = "getlorapath"
    CATEGORY = "CogVideoWrapper"
    DESCRIPTION = "Select a LoRA model from ComfyUI/models/CogVideo/loras"

    def getlorapath(self, lora, strength, prev_lora=None, fuse_lora=False):
        cog_loras_list = []
@@ -86,6 +87,43 @@ class CogVideoLoraSelect:
        cog_loras_list.append(cog_lora)
        print(cog_loras_list)
        return (cog_loras_list,)

class CogVideoLoraSelectComfy:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "lora": (folder_paths.get_filename_list("loras"),
                {"tooltip": "LORA models are expected to be in ComfyUI/models/loras with .safetensors extension"}),
                "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.0001, "tooltip": "LORA strength, set to 0.0 to unmerge the LORA"}),
            },
            "optional": {
                "prev_lora":("COGLORA", {"default": None, "tooltip": "For loading multiple LoRAs"}),
                "fuse_lora": ("BOOLEAN", {"default": False, "tooltip": "Fuse the LoRA weights into the transformer"}),
            }
        }

    RETURN_TYPES = ("COGLORA",)
    RETURN_NAMES = ("lora", )
    FUNCTION = "getlorapath"
    CATEGORY = "CogVideoWrapper"
    DESCRIPTION = "Select a LoRA model from ComfyUI/models/loras"

    def getlorapath(self, lora, strength, prev_lora=None, fuse_lora=False):
        cog_loras_list = []

        cog_lora = {
            "path": folder_paths.get_full_path("loras", lora),
            "strength": strength,
            "name": lora.split(".")[0],
            "fuse_lora": fuse_lora
        }
        if prev_lora is not None:
            cog_loras_list.extend(prev_lora)

        cog_loras_list.append(cog_lora)
        print(cog_loras_list)
        return (cog_loras_list,)

#region DownloadAndLoadCogVideoModel
class DownloadAndLoadCogVideoModel:
@@ -109,6 +147,7 @@ class DownloadAndLoadCogVideoModel:
                "alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose",
                "alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose",
                "alibaba-pai/CogVideoX-Fun-V1.1-5b-Control",
                "alibaba-pai/CogVideoX-Fun-V1.5-5b-InP",
                "feizhengcong/CogvideoX-Interpolation",
                "NimVideo/cogvideox-2b-img2vid"
            ],
@@ -177,7 +216,7 @@ class DownloadAndLoadCogVideoModel:
        download_path = folder_paths.get_folder_paths("CogVideo")[0]

        if "Fun" in model:
-            if not "1.1" in model:
+            if "1.1" not in model and "1.5" not in model:
                repo_id = "kijai/CogVideoX-Fun-pruned"
                if "2b" in model:
                    base_path = os.path.join(folder_paths.models_dir, "CogVideoX_Fun", "CogVideoX-Fun-2b-InP")  # location of the official model

@@ -187,7 +226,7 @@ class DownloadAndLoadCogVideoModel:
                    base_path = os.path.join(folder_paths.models_dir, "CogVideoX_Fun", "CogVideoX-Fun-5b-InP")  # location of the official model
                    if not os.path.exists(base_path):
                        base_path = os.path.join(download_path, "CogVideoX-Fun-5b-InP")
-            elif "1.1" in model:
+            else:
                repo_id = model
                base_path = os.path.join(folder_paths.models_dir, "CogVideoX_Fun", (model.split("/")[-1]))  # location of the official model
                if not os.path.exists(base_path):
@@ -240,7 +279,7 @@ class DownloadAndLoadCogVideoModel:
        transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder=subfolder, attention_mode=attention_mode)
        transformer = transformer.to(dtype).to(transformer_load_device)

-        if "1.5" in model:
+        if "1.5" in model and not "fun" in model:
            transformer.config.sample_height = 300
            transformer.config.sample_width = 300

@@ -295,6 +334,8 @@ class DownloadAndLoadCogVideoModel:
                    pipe.transformer = merge_lora(pipe.transformer, l["path"], l["strength"], device=transformer_load_device, state_dict=lora_sd)
                except:
                    raise ValueError(f"Can't recognize LoRA {l['path']}")
                del lora_sd
                mm.soft_empty_cache()
            if adapter_list:
                pipe.set_adapters(adapter_list, adapter_weights=adapter_weights)
                if fuse:

@@ -302,6 +343,7 @@ class DownloadAndLoadCogVideoModel:
                    if dimensionx_lora:
                        lora_scale = lora_scale / lora_rank
                    pipe.fuse_lora(lora_scale=lora_scale, components=["transformer"])
                    pipe.delete_adapters(adapter_list)


        if "fused" in attention_mode:
@@ -660,7 +702,7 @@ class CogVideoXModelLoader:

    def loadmodel(self, model, base_precision, load_device, enable_sequential_cpu_offload,
                  block_edit=None, compile_args=None, lora=None, attention_mode="sdpa", quantization="disabled"):

        transformer = None
        if "sage" in attention_mode:
            try:
                from sageattention import sageattn

@@ -689,6 +731,8 @@ class CogVideoXModelLoader:
            model_type = "5b_I2V_1_5"
        elif sd["patch_embed.proj.weight"].shape == (1920, 33, 2, 2):
            model_type = "fun_2b"
        elif sd["patch_embed.proj.weight"].shape == (1920, 32, 2, 2):
            model_type = "cogvideox-2b-img2vid"
        elif sd["patch_embed.proj.weight"].shape == (1920, 16, 2, 2):
            model_type = "2b"
        elif sd["patch_embed.proj.weight"].shape == (3072, 32, 2, 2):

@@ -710,7 +754,7 @@ class CogVideoXModelLoader:
        with open(transformer_config_path) as f:
            transformer_config = json.load(f)

-        if model_type in ["I2V", "I2V_5b", "fun_5b_pose", "5b_I2V_1_5"]:
+        if model_type in ["I2V", "I2V_5b", "fun_5b_pose", "5b_I2V_1_5", "cogvideox-2b-img2vid"]:
            transformer_config["in_channels"] = 32
            if "1_5" in model_type:
                transformer_config["ofs_embed_dim"] = 512

@@ -736,6 +780,10 @@ class CogVideoXModelLoader:
                #dtype_to_use = base_dtype if any(keyword in name for keyword in params_to_keep) else dtype
                set_module_tensor_to_device(transformer, name, device=transformer_load_device, dtype=base_dtype, value=sd[name])
        del sd
        # TODO fix for transformer model patch_embed.pos_embedding dtype
        # or at add line ComfyUI-CogVideoXWrapper/embeddings.py:129 code
        # pos_embedding = pos_embedding.to(embeds.device, dtype=embeds.dtype)
        transformer = transformer.to(base_dtype).to(transformer_load_device)

        #scheduler
        with open(scheduler_config_path) as f:

@@ -759,7 +807,8 @@ class CogVideoXModelLoader:
            dtype=base_dtype,
            is_fun_inpaint="fun" in model.lower() and not ("pose" in model.lower() or "control" in model.lower())
        )

        if "cogvideox-2b-img2vid" == model_type:
            pipe.input_with_padding = False
        if enable_sequential_cpu_offload:
            pipe.enable_sequential_cpu_offload()
@@ -925,6 +974,7 @@ class DownloadAndLoadToraModel:
                "model": (
                    [
                        "kijai/CogVideoX-5b-Tora",
                        "kijai/CogVideoX-5b-Tora-I2V",
                    ],
                ),
            },

@@ -954,14 +1004,17 @@ class DownloadAndLoadToraModel:
            pass

        download_path = os.path.join(folder_paths.models_dir, 'CogVideo', "CogVideoX-5b-Tora")
-        fuser_path = os.path.join(download_path, "fuser", "fuser.safetensors")

+        fuser_model = "fuser.safetensors" if not "I2V" in model else "fuser_I2V.safetensors"
+        fuser_path = os.path.join(download_path, "fuser", fuser_model)
        if not os.path.exists(fuser_path):
            log.info(f"Downloading Fuser model to: {fuser_path}")
            from huggingface_hub import snapshot_download

            snapshot_download(
                repo_id=model,
-                allow_patterns=["*fuser.safetensors*"],
+                allow_patterns=[fuser_model],
                local_dir=download_path,
                local_dir_use_symlinks=False,
            )

@@ -983,14 +1036,15 @@ class DownloadAndLoadToraModel:
            param.data = param.data.to(torch.bfloat16).to(device)
        del fuser_sd

-        traj_extractor_path = os.path.join(download_path, "traj_extractor", "traj_extractor.safetensors")
+        traj_extractor_model = "traj_extractor.safetensors" if not "I2V" in model else "traj_extractor_I2V.safetensors"
+        traj_extractor_path = os.path.join(download_path, "traj_extractor", traj_extractor_model)
        if not os.path.exists(traj_extractor_path):
            log.info(f"Downloading trajectory extractor model to: {traj_extractor_path}")
            from huggingface_hub import snapshot_download

            snapshot_download(
                repo_id="kijai/CogVideoX-5b-Tora",
-                allow_patterns=["*traj_extractor.safetensors*"],
+                allow_patterns=[traj_extractor_model],
                local_dir=download_path,
                local_dir_use_symlinks=False,
            )
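Note: both download blocks above switch from hard-coded "fuser.safetensors"/"traj_extractor.safetensors" names to filenames derived from the model name, so the I2V Tora variant fetches its own weights. A standalone sketch of that selection logic (the helper name is hypothetical):

    def tora_weight_names(model: str) -> tuple[str, str]:
        # I2V checkpoints ship separate fuser / trajectory-extractor weights
        suffix = "_I2V" if "I2V" in model else ""
        return f"fuser{suffix}.safetensors", f"traj_extractor{suffix}.safetensors"

    print(tora_weight_names("kijai/CogVideoX-5b-Tora-I2V"))
    # ('fuser_I2V.safetensors', 'traj_extractor_I2V.safetensors')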
@@ -1078,6 +1132,7 @@ NODE_CLASS_MAPPINGS = {
    "CogVideoLoraSelect": CogVideoLoraSelect,
    "CogVideoXVAELoader": CogVideoXVAELoader,
    "CogVideoXModelLoader": CogVideoXModelLoader,
    "CogVideoLoraSelectComfy": CogVideoLoraSelectComfy
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "DownloadAndLoadCogVideoModel": "(Down)load CogVideo Model",

@@ -1087,4 +1142,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
    "CogVideoLoraSelect": "CogVideo LoraSelect",
    "CogVideoXVAELoader": "CogVideoX VAE Loader",
    "CogVideoXModelLoader": "CogVideoX Model Loader",
    "CogVideoLoraSelectComfy": "CogVideo LoraSelect Comfy"
}
nodes.py (98 lines changed)
@@ -49,6 +49,25 @@ if not "CogVideo" in folder_paths.folder_names_and_paths:
if not "cogvideox_loras" in folder_paths.folder_names_and_paths:
    folder_paths.add_model_folder_path("cogvideox_loras", os.path.join(folder_paths.models_dir, "CogVideo", "loras"))

class CogVideoEnhanceAVideo:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "weight": ("FLOAT", {"default": 1.0, "min": 0, "max": 100, "step": 0.01, "tooltip": "The feta Weight of the Enhance-A-Video"}),
                "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percentage of the steps to apply Enhance-A-Video"}),
                "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percentage of the steps to apply Enhance-A-Video"}),
            },
        }
    RETURN_TYPES = ("FETAARGS",)
    RETURN_NAMES = ("feta_args",)
    FUNCTION = "setargs"
    CATEGORY = "CogVideoWrapper"
    DESCRIPTION = "https://github.com/NUS-HPC-AI-Lab/Enhance-A-Video"

    def setargs(self, **kwargs):
        return (kwargs, )

class CogVideoContextOptions:
    @classmethod
    def INPUT_TYPES(s):
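Note: CogVideoEnhanceAVideo just packs its widget values into a "FETAARGS" dict that the sampler later unpacks into the enhance_a_video globals, with start_percent/end_percent gating which denoising steps apply the score. A sketch of that downstream gating (illustrative, not the sampler's exact code):

    def apply_feta_for_step(feta_args: dict, step: int, total_steps: int) -> bool:
        progress = step / max(total_steps - 1, 1)
        return feta_args["start_percent"] <= progress <= feta_args["end_percent"]

    feta_args = {"weight": 1.0, "start_percent": 0.0, "end_percent": 0.5}
    print([apply_feta_for_step(feta_args, s, 10) for s in range(10)])
    # roughly the first half of the steps -> True, the rest -> False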
@@ -263,13 +282,14 @@ class CogVideoImageEncode:
        start_latents = vae.encode(start_image).latent_dist.sample(generator)
        start_latents = start_latents.permute(0, 2, 1, 3, 4)  # B, T, C, H, W


        if end_image is not None:
            end_image = (end_image * 2.0 - 1.0).to(vae.dtype).to(device).unsqueeze(0).permute(0, 4, 1, 2, 3)
            if noise_aug_strength > 0:
                end_image = add_noise_to_reference_video(end_image, ratio=noise_aug_strength)
            end_latents = vae.encode(end_image).latent_dist.sample(generator)
            end_latents = end_latents.permute(0, 2, 1, 3, 4)  # B, T, C, H, W
-            latents_list.append(end_latents)
+            latents_list = [start_latents, end_latents]
            final_latents = torch.cat(latents_list, dim=1)
        else:
            final_latents = start_latents

@@ -284,32 +304,6 @@ class CogVideoImageEncode:
            "start_percent": start_percent,
            "end_percent": end_percent
        }, )

class CogVideoConcatLatent:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "samples_to": ("LATENT", ),
            "samples_from": ("LATENT",),
            },
        }

    RETURN_TYPES = ("LATENT",)
    RETURN_NAMES = ("samples",)
    FUNCTION = "encode"
    CATEGORY = "CogVideoWrapper"

    def encode(self, samples_from, samples_to):

        insert_from = samples_from["samples"].clone()
        insert_to = samples_to["samples"].clone()
        new_latents = torch.cat((insert_to, insert_from), dim=1)
        print("new latents shape: ", new_latents.shape)
        return ({
            "samples": new_latents,
            "start_percent": samples_from["start_percent"],
            "end_percent": samples_from["end_percent"]
        }, )

class CogVideoImageEncodeFunInP:
    @classmethod

@@ -385,8 +379,8 @@ class CogVideoImageEncodeFunInP:
        masked_image_latents = masked_image_latents.permute(0, 2, 1, 3, 4)  # B, T, C, H, W

        mask = torch.zeros_like(masked_image_latents[:, :, :1, :, :])
-        if end_image is not None:
-            mask[:, -1, :, :, :] = 0
+        #if end_image is not None:
+        #    mask[:, -1, :, :, :] = 0
        mask[:, 0, :, :, :] = vae_scaling_factor

        final_latents = masked_image_latents * vae_scaling_factor
@@ -590,6 +584,26 @@ class CogVideoXFasterCache:
            "num_blocks_to_cache" : num_blocks_to_cache,
        }
        return (fastercache,)

class CogVideoXTeaCache:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "rel_l1_thresh": ("FLOAT", {"default": 0.3, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "Cache threshold, higher values are faster while sacrificing quality"}),
            }
        }

    RETURN_TYPES = ("TEACACHEARGS",)
    RETURN_NAMES = ("teacache_args",)
    FUNCTION = "args"
    CATEGORY = "CogVideoWrapper"

    def args(self, rel_l1_thresh):
        teacache = {
            "rel_l1_thresh": rel_l1_thresh
        }
        return (teacache,)

#region Sampler
class CogVideoSampler:
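Note: the CogVideoXTeaCache node only carries rel_l1_thresh; CogVideoSampler copies it onto the transformer, which is where the skip logic from the transformer diff actually runs. A minimal wiring sketch (attribute names as in the diff, the helper itself is hypothetical):

    def configure_teacache(transformer, teacache_args=None):
        # mirrors what CogVideoSampler does before sampling
        transformer.use_teacache = teacache_args is not None
        if teacache_args is not None:
            transformer.teacache_rel_l1_thresh = teacache_args["rel_l1_thresh"]
            transformer.teacache_counter = 0  # reset the skipped-step counter for logging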
@@ -617,6 +631,8 @@ class CogVideoSampler:
                "controlnet": ("COGVIDECONTROLNET",),
                "tora_trajectory": ("TORAFEATURES", ),
                "fastercache": ("FASTERCACHEARGS", ),
                "feta_args": ("FETAARGS", ),
                "teacache_args": ("TEACACHEARGS", ),
            }
        }

@@ -626,7 +642,7 @@ class CogVideoSampler:
    CATEGORY = "CogVideoWrapper"

    def process(self, model, positive, negative, steps, cfg, seed, scheduler, num_frames, samples=None,
-                denoise_strength=1.0, image_cond_latents=None, context_options=None, controlnet=None, tora_trajectory=None, fastercache=None):
+                denoise_strength=1.0, image_cond_latents=None, context_options=None, controlnet=None, tora_trajectory=None, fastercache=None, feta_args=None, teacache_args=None):
        mm.unload_all_models()
        mm.soft_empty_cache()

@@ -648,7 +664,7 @@ class CogVideoSampler:
            image_conds = image_cond_latents["samples"]
            image_cond_start_percent = image_cond_latents.get("start_percent", 0.0)
            image_cond_end_percent = image_cond_latents.get("end_percent", 1.0)
-            if "1.5" in model_name or "1_5" in model_name:
+            if ("1.5" in model_name or "1_5" in model_name) and not "fun" in model_name.lower():
                image_conds = image_conds / 0.7  # needed for 1.5 models
            else:
                if not "fun" in model_name.lower():

@@ -711,6 +727,13 @@ class CogVideoSampler:
            pipe.transformer.use_fastercache = False
            pipe.transformer.fastercache_counter = 0

        if teacache_args is not None:
            pipe.transformer.use_teacache = True
            pipe.transformer.teacache_rel_l1_thresh = teacache_args["rel_l1_thresh"]
            log.info(f"TeaCache enabled with rel_l1_thresh: {pipe.transformer.teacache_rel_l1_thresh}")
        else:
            pipe.transformer.use_teacache = False

        if not isinstance(cfg, list):
            cfg = [cfg for _ in range(steps)]
        else:

@@ -747,6 +770,7 @@ class CogVideoSampler:
                tora=tora_trajectory if tora_trajectory is not None else None,
                image_cond_start_percent=image_cond_start_percent if image_cond_latents is not None else 0.0,
                image_cond_end_percent=image_cond_end_percent if image_cond_latents is not None else 1.0,
                feta_args=feta_args,
            )
        if not model["cpu_offloading"] and model["manual_offloading"]:
            pipe.transformer.to(offload_device)

@@ -758,6 +782,9 @@ class CogVideoSampler:
                block.cached_encoder_hidden_states = None

        print_memory(device)

        if teacache_args is not None:
            log.info(f"TeaCache skipped steps: {pipe.transformer.teacache_counter}")
        mm.soft_empty_cache()
        try:
            torch.cuda.reset_peak_memory_stats(device)
@@ -936,7 +963,8 @@ class CogVideoLatentPreview:
        latents = latents.permute(0, 2, 1, 3, 4)  # [batch_size, num_channels, num_frames, height, width]

#[[0.0658900170023352, 0.04687556512203313, -0.056971557475649186], [-0.01265770449940036, -0.02814809569100843, -0.0768912512529372], [0.061456544746314665, 0.0005511617552452358, -0.0652574975291287], [-0.09020669168815276, -0.004755440180558637, -0.023763970904494294], [0.031766964513999865, -0.030959599938418375, 0.08654669098083616], [-0.005981764690055846, -0.08809119252349802, -0.06439852368217663], [-0.0212114426433989, 0.08894281999597677, 0.05155629477559985], [-0.013947446911030725, -0.08987475069900677, -0.08923124751217484], [-0.08235967967978511, 0.07268025379974379, 0.08830486164536037], [-0.08052049179735378, -0.050116143175332195, 0.02023752569687405], [-0.07607527759162447, 0.06827156419895981, 0.08678111754261035], [-0.04689089232553825, 0.017294986041038893, -0.10280492336438908], [-0.06105783150270304, 0.07311850680875913, 0.019995735372550075], [-0.09232589996527711, -0.012869815059053047, -0.04355587834255975], [-0.06679931010802251, 0.018399815879067458, 0.06802404982033876], [-0.013062632927118165, -0.04292991477896661, 0.07476243356192845]]
latent_rgb_factors =[[0.11945946736445662, 0.09919175788574555, -0.004832707433877734], [-0.0011977028264356232, 0.05496505130267682, 0.021321622433638193], [-0.014088548986590666, -0.008701477861945644, -0.020991313281459367], [0.03063921972519621, 0.12186477097625073, 0.0139593690235148], [0.0927403067854673, 0.030293187650929136, 0.05083134241694003], [0.0379112441305742, 0.04935199882777209, 0.058562766246777774], [0.017749911959153715, 0.008839453404921545, 0.036005638019226294], [0.10610119248526109, 0.02339855688237826, 0.057154257614084596], [0.1273639464837117, -0.010959856130713416, 0.043268631260428896], [-0.01873510946881321, 0.08220930648486932, 0.10613256772247093], [0.008429116376722327, 0.07623856561000408, 0.09295712117576727], [0.12938137079617007, 0.12360403483892413, 0.04478930933220116], [0.04565908794779364, 0.041064156741596365, -0.017695041535528512], [0.00019003240570281826, -0.013965147883381978, 0.05329669529635849], [0.08082391586738358, 0.11548306825496074, -0.021464170006615893], [-0.01517932393230994, -0.0057985555313003236, 0.07216646476618871]]
#latent_rgb_factors =[[0.11945946736445662, 0.09919175788574555, -0.004832707433877734], [-0.0011977028264356232, 0.05496505130267682, 0.021321622433638193], [-0.014088548986590666, -0.008701477861945644, -0.020991313281459367], [0.03063921972519621, 0.12186477097625073, 0.0139593690235148], [0.0927403067854673, 0.030293187650929136, 0.05083134241694003], [0.0379112441305742, 0.04935199882777209, 0.058562766246777774], [0.017749911959153715, 0.008839453404921545, 0.036005638019226294], [0.10610119248526109, 0.02339855688237826, 0.057154257614084596], [0.1273639464837117, -0.010959856130713416, 0.043268631260428896], [-0.01873510946881321, 0.08220930648486932, 0.10613256772247093], [0.008429116376722327, 0.07623856561000408, 0.09295712117576727], [0.12938137079617007, 0.12360403483892413, 0.04478930933220116], [0.04565908794779364, 0.041064156741596365, -0.017695041535528512], [0.00019003240570281826, -0.013965147883381978, 0.05329669529635849], [0.08082391586738358, 0.11548306825496074, -0.021464170006615893], [-0.01517932393230994, -0.0057985555313003236, 0.07216646476618871]]
latent_rgb_factors = [[0.03197404301362048, 0.04091260743347359, 0.0015679806301828524], [0.005517101026578029, 0.0052348639043457755, -0.005613441650464035], [0.0012485338264583965, -0.016096744206117782, 0.025023940031635054], [0.01760126794276171, 0.0036818415416642893, -0.0006019202528157255], [0.000444954842288864, 0.006102128982092191, 0.0008457999272962447], [-0.010531904354560697, -0.0032275501924977175, -0.00886595780267917], [-0.0001454543946122991, 0.010199210750845965, -0.00012702234832386188], [0.02078497279904325, -0.001669617778939972, 0.006712703698951264], [0.005529571599763264, 0.009733929789086743, 0.001887302765339838], [0.012138415094654218, 0.024684961927224837, 0.037211249767461915], [0.0010364484570000384, 0.01983636315929172, 0.009864602025627755], [0.006802862648143341, -0.0010509255113510681, -0.007026003345126021], [0.0003532208468418043, 0.005351971582801936, -0.01845912126717106], [-0.009045079994694397, -0.01127941143183089, 0.0042294057970470806], [0.002548289972720752, 0.025224244654428216, -0.0006086130121693347], [-0.011135669222532816, 0.0018181308593668505, 0.02794541485349922]]
        import random
        random.seed(seed)
        latent_rgb_factors = [[random.uniform(min_val, max_val) for _ in range(3)] for _ in range(16)]

@@ -985,7 +1013,8 @@ NODE_CLASS_MAPPINGS = {
    "CogVideoLatentPreview": CogVideoLatentPreview,
    "CogVideoXTorchCompileSettings": CogVideoXTorchCompileSettings,
    "CogVideoImageEncodeFunInP": CogVideoImageEncodeFunInP,
-    "CogVideoConcatLatent": CogVideoConcatLatent,
+    "CogVideoEnhanceAVideo": CogVideoEnhanceAVideo,
+    "CogVideoXTeaCache": CogVideoXTeaCache,
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "CogVideoSampler": "CogVideo Sampler",

@@ -1002,5 +1031,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
    "CogVideoLatentPreview": "CogVideo LatentPreview",
    "CogVideoXTorchCompileSettings": "CogVideo TorchCompileSettings",
    "CogVideoImageEncodeFunInP": "CogVideo ImageEncode FunInP",
-    "CogVideoConcatLatent": "CogVideo Concat Latent",
+    "CogVideoEnhanceAVideo": "CogVideo Enhance-A-Video",
+    "CogVideoXTeaCache": "CogVideoX TeaCache",
}

@@ -29,6 +29,7 @@ from diffusers.loaders import CogVideoXLoraLoaderMixin

from .embeddings import get_3d_rotary_pos_embed
from .custom_cogvideox_transformer_3d import CogVideoXTransformer3DModel
from .enhance_a_video.globals import enable_enhance, disable_enhance, set_enhance_weight

from comfy.utils import ProgressBar

@@ -110,6 +111,34 @@ def retrieve_timesteps(
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps

class CogVideoXLatentFormat():
    latent_channels = 16
    latent_dimensions = 3
    scale_factor = 0.7
    taesd_decoder_name = None

    latent_rgb_factors = [[0.03197404301362048, 0.04091260743347359, 0.0015679806301828524],
                          [0.005517101026578029, 0.0052348639043457755, -0.005613441650464035],
                          [0.0012485338264583965, -0.016096744206117782, 0.025023940031635054],
                          [0.01760126794276171, 0.0036818415416642893, -0.0006019202528157255],
                          [0.000444954842288864, 0.006102128982092191, 0.0008457999272962447],
                          [-0.010531904354560697, -0.0032275501924977175, -0.00886595780267917],
                          [-0.0001454543946122991, 0.010199210750845965, -0.00012702234832386188],
                          [0.02078497279904325, -0.001669617778939972, 0.006712703698951264],
                          [0.005529571599763264, 0.009733929789086743, 0.001887302765339838],
                          [0.012138415094654218, 0.024684961927224837, 0.037211249767461915],
                          [0.0010364484570000384, 0.01983636315929172, 0.009864602025627755],
                          [0.006802862648143341, -0.0010509255113510681, -0.007026003345126021],
                          [0.0003532208468418043, 0.005351971582801936, -0.01845912126717106],
                          [-0.009045079994694397, -0.01127941143183089, 0.0042294057970470806],
                          [0.002548289972720752, 0.025224244654428216, -0.0006086130121693347],
                          [-0.011135669222532816, 0.0018181308593668505, 0.02794541485349922]]
    latent_rgb_factors_bias = [ -0.023, 0.0, -0.017]

class CogVideoXModelPlaceholder():
    def __init__(self):
        self.latent_format = CogVideoXLatentFormat

class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
    r"""
    Pipeline for text-to-video generation using CogVideoX.
@@ -195,7 +224,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
                noise[:, place_idx:place_idx + delta, :, :, :] = noise[:, list_idx, :, :, :]
        if latents is None:
            latents = noise.to(device)
-        else:
+        elif denoise_strength < 1.0:
            latents = latents.to(device)
            timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, denoise_strength, device)
            latent_timestep = timesteps[:1]

@@ -212,6 +241,8 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
                latents = latents[:, :frames_needed, :, :, :]

            latents = self.scheduler.add_noise(latents, noise.to(device), latent_timestep)
        else:
            latents = latents.to(device)
        latents = latents * self.scheduler.init_noise_sigma  # scale the initial noise by the standard deviation required by the scheduler
        return latents, timesteps

@@ -351,6 +382,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
        tora: Optional[dict] = None,
        image_cond_start_percent: float = 0.0,
        image_cond_end_percent: float = 1.0,
        feta_args: Optional[dict] = None,

    ):
        """
@@ -471,50 +503,9 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
# 5.5.
if image_cond_latents is not None:
if image_cond_latents.shape[1] == 3:
logger.info("More than one image conditioning frame received, interpolating")
total_padding = latents.shape[1] - 3
half_padding = total_padding // 2

padding_shape = (
batch_size,
half_padding,
self.vae_latent_channels,
height // self.vae_scale_factor_spatial,
width // self.vae_scale_factor_spatial,
)
latent_padding = torch.zeros(padding_shape, device=device, dtype=self.vae_dtype)
middle_frame = image_cond_latents[:, 1, :, :, :].unsqueeze(1)

image_cond_latents = torch.cat([
image_cond_latents[:, 0, :, :, :].unsqueeze(1),
latent_padding,
middle_frame,
latent_padding,
image_cond_latents[:, -1, :, :, :].unsqueeze(1)
], dim=1)

# If total_padding is odd, add one more padding after the middle frame
if total_padding % 2 != 0:
extra_padding = torch.zeros(
(batch_size, 1, self.vae_latent_channels,
height // self.vae_scale_factor_spatial,
width // self.vae_scale_factor_spatial),
device=device, dtype=self.vae_dtype
)
image_cond_latents = torch.cat([image_cond_latents, extra_padding], dim=1)

if self.transformer.config.patch_size_t is not None:
first_frame = image_cond_latents[:, : image_cond_latents.size(1) % self.transformer.config.patch_size_t, ...]
image_cond_latents = torch.cat([first_frame, image_cond_latents], dim=1)

middle_frame_idx = image_cond_latents.shape[1] // 2
print("middle_frame_idx", middle_frame_idx)
print(middle_frame.shape)
print(image_cond_latents.shape)

elif image_cond_latents.shape[1] == 2:
image_cond_frame_count = image_cond_latents.size(1)
patch_size_t = self.transformer.config.patch_size_t
if image_cond_frame_count == 2:
logger.info("More than one image conditioning frame received, interpolating")
padding_shape = (
batch_size,
@@ -525,12 +516,12 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
)
latent_padding = torch.zeros(padding_shape, device=device, dtype=self.vae_dtype)
image_cond_latents = torch.cat([image_cond_latents[:, 0, :, :, :].unsqueeze(1), latent_padding, image_cond_latents[:, -1, :, :, :].unsqueeze(1)], dim=1)
if self.transformer.config.patch_size_t is not None:
first_frame = image_cond_latents[:, : image_cond_latents.size(1) % self.transformer.config.patch_size_t, ...]
if patch_size_t:
first_frame = image_cond_latents[:, : image_cond_latents.size(1) % patch_size_t, ...]
image_cond_latents = torch.cat([first_frame, image_cond_latents], dim=1)

logger.info(f"image cond latents shape: {image_cond_latents.shape}")
elif image_cond_latents.shape[1] == 1:
elif image_cond_frame_count == 1:
logger.info("Only one image conditioning frame received, img2vid")
if self.input_with_padding:
padding_shape = (
@@ -543,13 +534,20 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
latent_padding = torch.zeros(padding_shape, device=device, dtype=self.vae_dtype)
image_cond_latents = torch.cat([image_cond_latents, latent_padding], dim=1)
# Select the first frame along the second dimension
if self.transformer.config.patch_size_t is not None:
first_frame = image_cond_latents[:, : image_cond_latents.size(1) % self.transformer.config.patch_size_t, ...]
if patch_size_t:
first_frame = image_cond_latents[:, : image_cond_latents.size(1) % patch_size_t, ...]
image_cond_latents = torch.cat([first_frame, image_cond_latents], dim=1)
else:
image_cond_latents = image_cond_latents.repeat(1, latents.shape[1], 1, 1, 1)
else:
logger.info(f"Received {image_cond_latents.shape[1]} image conditioning frames")
if fun_mask is not None and patch_size_t:
logger.info(f"1.5 model received {fun_mask.shape[1]} masks")
first_frame = image_cond_latents[:, : image_cond_frame_count % patch_size_t, ...]
image_cond_latents = torch.cat([first_frame, image_cond_latents], dim=1)
fun_mask_first_frame = fun_mask[:, : image_cond_frame_count % patch_size_t, ...]
fun_mask = torch.cat([fun_mask_first_frame, fun_mask], dim=1)
fun_mask[:, 1:, ...] = 0
image_cond_latents = image_cond_latents.to(self.vae_dtype)
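The repeated `size(1) % patch_size_t` slices exist because the transformer patches the temporal axis: when the latent frame count is not divisible by `patch_size_t`, a copy of the first frame is prepended until it is. A small worked example with illustrative shapes (13 latent frames, `patch_size_t = 2`):

```python
import torch

# Illustrative shapes: batch 1, 13 latent frames, 16 channels, 60x90 latent pixels.
image_cond_latents = torch.zeros(1, 13, 16, 60, 90)
patch_size_t = 2  # illustrative temporal patch size

# 13 % 2 == 1, so one leading frame is duplicated to reach 14 frames (divisible by 2).
# If the count were already divisible, the slice would be empty and nothing is added.
first_frame = image_cond_latents[:, : image_cond_latents.size(1) % patch_size_t, ...]
image_cond_latents = torch.cat([first_frame, image_cond_latents], dim=1)
print(image_cond_latents.shape[1])  # 14
```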
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
@@ -607,7 +605,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
else:
controlnet_states = None
control_weights= None

# 9. Tora
if tora is not None:
trajectory_length = tora["video_flow_features"].shape[1]
logger.info(f"Tora trajectory length: {trajectory_length}")
@@ -619,16 +617,41 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):

logger.info(f"Sampling {num_frames} frames in {latent_frames} latent frames at {width}x{height} with {num_inference_steps} inference steps")

from .latent_preview import prepare_callback
callback = prepare_callback(self.transformer, num_inference_steps)
if feta_args is not None:
set_enhance_weight(feta_args["weight"])
feta_start_percent = feta_args["start_percent"]
feta_end_percent = feta_args["end_percent"]
enable_enhance()
else:
disable_enhance()

# reset TeaCache
if hasattr(self.transformer, 'accumulated_rel_l1_distance'):
delattr(self.transformer, 'accumulated_rel_l1_distance')
self.transformer.teacache_counter = 0

# 11. Denoising loop
#from .latent_preview import prepare_callback
#callback = prepare_callback(self.transformer, num_inference_steps)
from latent_preview import prepare_callback
self.model = CogVideoXModelPlaceholder()
self.load_device = device
callback = prepare_callback(self, num_inference_steps)

# 9. Denoising loop
comfy_pbar = ProgressBar(len(timesteps))
with self.progress_bar(total=len(timesteps)) as progress_bar:
old_pred_original_sample = None # for DPM-solver++
for i, t in enumerate(timesteps):
if self.interrupt:
continue

current_step_percentage = i / num_inference_steps

if feta_args is not None:
if feta_start_percent <= current_step_percentage <= feta_end_percent:
enable_enhance()
else:
disable_enhance()
# region context schedule sampling
if use_context_schedule:
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
@@ -636,31 +659,13 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
counter = torch.zeros_like(latent_model_input)
noise_pred = torch.zeros_like(latent_model_input)

current_step_percentage = i / num_inference_steps

if image_cond_latents is not None:
if not image_cond_start_percent <= current_step_percentage <= image_cond_end_percent:
latent_image_input = torch.zeros_like(latent_model_input)
else:
latent_image_input = torch.cat([image_cond_latents] * 2) if do_classifier_free_guidance else image_cond_latents
if fun_mask is not None: #for fun img2vid and interpolation
fun_inpaint_mask = torch.cat([fun_mask] * 2) if do_classifier_free_guidance else fun_mask
masks_input = torch.cat([fun_inpaint_mask, latent_image_input], dim=2)
latent_model_input = torch.cat([latent_model_input, masks_input], dim=2)
else:
latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=2)
else: # for Fun inpaint vid2vid
if fun_mask is not None:
fun_inpaint_mask = torch.cat([fun_mask] * 2) if do_classifier_free_guidance else fun_mask
fun_inpaint_masked_video_latents = torch.cat([fun_masked_video_latents] * 2) if do_classifier_free_guidance else fun_masked_video_latents
fun_inpaint_latents = torch.cat([fun_inpaint_mask, fun_inpaint_masked_video_latents], dim=2).to(latents.dtype)
latent_model_input = torch.cat([latent_model_input, fun_inpaint_latents], dim=2)
latent_image_input = torch.cat([image_cond_latents] * 2) if do_classifier_free_guidance else image_cond_latents
latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=2)

# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0])

current_step_percentage = i / num_inference_steps

# use same rotary embeddings for all context windows
image_rotary_emb = (
self._prepare_rotary_positional_embeddings(height, width, context_frames, device)
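With `use_context_schedule`, the model is evaluated over overlapping windows of latent frames and the per-window outputs are blended through the `counter` / `noise_pred` buffers allocated above (the windowing loop itself falls outside this hunk). A minimal sketch of that accumulate-and-normalize pattern; `get_context_windows` and `run_model` are hypothetical stand-ins for the wrapper's context scheduler and transformer call:

```python
import torch

def blend_context_windows(latents, run_model, get_context_windows, context_frames, overlap):
    # Accumulate overlapping per-window predictions, then divide by how many
    # windows covered each frame. run_model and get_context_windows are
    # hypothetical stand-ins, not functions from this repository.
    noise_pred = torch.zeros_like(latents)
    counter = torch.zeros_like(latents)
    for window in get_context_windows(latents.shape[1], context_frames, overlap):
        window_pred = run_model(latents[:, window])   # predict only these frames
        noise_pred[:, window] += window_pred
        counter[:, window] += 1
    # clamp avoids division by zero for frames no window happened to cover
    return noise_pred / counter.clamp(min=1)
```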
@@ -770,8 +775,6 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

current_step_percentage = i / num_inference_steps

if image_cond_latents is not None:
if not image_cond_start_percent <= current_step_percentage <= image_cond_end_percent:
latent_image_input = torch.zeros_like(latent_model_input)
@@ -849,8 +852,11 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):

if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None:
callback(i, latents.detach()[-1], None, num_inference_steps)
if callback is not None:
alpha_prod_t = self.scheduler.alphas_cumprod[t]
beta_prod_t = 1 - alpha_prod_t
callback_tensor = (alpha_prod_t**0.5) * latent_model_input[0][:, :16, :, :] - (beta_prod_t**0.5) * noise_pred.detach()[0]
callback(i, callback_tensor * 5, None, num_inference_steps)
else:
comfy_pbar.update(1)
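The reworked callback previews an estimate of the denoised latent instead of the raw noisy `latents`: for a v-prediction model (the prediction type the CogVideoX schedulers default to), `sqrt(alpha_bar_t) * x_t - sqrt(1 - alpha_bar_t) * v` recovers the predicted clean sample, the `[:, :16, ...]` slice keeps only the 16 latent channels and drops any concatenated conditioning channels, and the `* 5` appears to be an empirical brightness scale for the preview. A compact sketch of the same recovery, with hypothetical tensor arguments:

```python
import torch

def preview_estimate(x_t, v_pred, alphas_cumprod, t, latent_channels=16, gain=5.0):
    # Recover the denoised sample for a v-prediction model:
    #   x0_hat = sqrt(alpha_bar_t) * x_t - sqrt(1 - alpha_bar_t) * v
    # Only the first `latent_channels` channels of x_t are latents; any extra
    # channels are concatenated image/mask conditioning and are dropped here.
    alpha_prod_t = alphas_cumprod[t]
    beta_prod_t = 1 - alpha_prod_t
    x0_hat = (alpha_prod_t**0.5) * x_t[:, :latent_channels] - (beta_prod_t**0.5) * v_pred
    return x0_hat * gain  # gain is an assumed, purely cosmetic preview scale
```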
pyproject.toml
@@ -1,7 +1,7 @@
[project]
name = "comfyui-cogvideoxwrapper"
description = "Diffusers wrapper for CogVideoX -models: [a/https://github.com/THUDM/CogVideo](https://github.com/THUDM/CogVideo)"
version = "1.5.0"
description = "Diffusers wrapper for CogVideoX -models: https://github.com/THUDM/CogVideo"
version = "1.5.1"
license = {file = "LICENSE"}
dependencies = ["huggingface_hub", "diffusers>=0.31.0", "accelerate>=0.33.0"]
readme.md
@@ -2,6 +2,19 @@
Spreadsheet (WIP) of supported models and their supported features: https://docs.google.com/spreadsheets/d/16eA6mSL8XkTcu9fSWkPSHfRIqyAKJbR1O99xnuGdCKY/edit?usp=sharing

## Update 9
Added preliminary support for [Go-with-the-Flow](https://github.com/VGenAI-Netflix-Eyeline-Research/Go-with-the-Flow).

This uses the LoRA weights available here: https://huggingface.co/Eyeline-Research/Go-with-the-Flow/tree/main

To create the input videos for the NoiseWarp process, I've added a node to KJNodes that works alongside my SplineEditor and uses either [comfyui-inpaint-nodes](https://github.com/Acly/comfyui-inpaint-nodes) or plain cv2 inpainting to create the cut-and-drag input videos.

The workflows are in the example_workflows folder.

Quick video to showcase the process: first mask the subject, then use the cut-and-drag workflow to create a video as seen here; that video is then used as input to the NoiseWarp node in the main workflow.

https://github.com/user-attachments/assets/112706b0-a38b-4c3c-b779-deba0827af4f

## BREAKING Update 8

This is a big one, and unfortunately the necessary cleanup and refactoring will break every old workflow as it is.