Mirror of https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git, synced 2025-12-10 05:14:22 +08:00

Merge branch 'kijai:main' into main
Commit: e71ae285ef
@@ -52,7 +52,7 @@ class CogVideoXAttnProcessor2_0:
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
+    @torch.compiler.disable()
     def __call__(
         self,
         attn: Attention,
@@ -126,7 +126,7 @@ class FusedCogVideoXAttnProcessor2_0:
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
+    @torch.compiler.disable()
     def __call__(
         self,
         attn: Attention,
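Note: @torch.compiler.disable() marks the processors' __call__ so Dynamo executes it eagerly instead of tracing it into a compiled graph, which lets the per-block torch.compile added elsewhere in this commit coexist with attention code that would otherwise trigger graph breaks or recompiles. A minimal sketch of the decorator's effect, with toy function names (not from this repo):

import torch

@torch.compiler.disable()
def eager_region(x):
    # Dynamo skips tracing this body and runs it eagerly at call time.
    return x.sin()

@torch.compile
def model_step(x):
    # The compiled graph breaks around eager_region and resumes after it.
    return eager_region(x) + 1

print(model_step(torch.randn(4)))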
@@ -79,11 +79,17 @@ def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
                     i: i + self.tile_sample_min_height,
                     j: j + self.tile_sample_min_width,
                 ]

                 tile = self.encoder(tile)
+                if not isinstance(tile, tuple):
+                    tile = (tile,)
                 if self.quant_conv is not None:
                     tile = self.quant_conv(tile)
-                time.append(tile)
-            self._clear_fake_context_parallel_cache()
+                time.append(tile[0])
+            try:
+                self._clear_fake_context_parallel_cache()
+            except:
+                pass
             row.append(torch.cat(time, dim=2))
         rows.append(row)
     result_rows = []
@@ -130,7 +136,10 @@ def _encode(
            if self.quant_conv is not None:
                z_intermediate = self.quant_conv(z_intermediate)
            h.append(z_intermediate)
-        self._clear_fake_context_parallel_cache()
+        try:
+            self._clear_fake_context_parallel_cache()
+        except:
+            pass
        h = torch.cat(h, dim=2)
        return h

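Note: both encode paths wrap _clear_fake_context_parallel_cache() in a bare try/except because newer diffusers releases removed that private helper, and tiled_encode now normalizes the encoder output, which may be a bare tensor or a tuple depending on the diffusers version. A sketch of the same compatibility idea, using an explicit getattr guard in place of the bare except (the helper name encode_tile is hypothetical):

import torch

def encode_tile(vae, tile: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper condensing both fixes.
    out = vae.encoder(tile)
    if not isinstance(out, tuple):  # some diffusers versions return a tuple
        out = (out,)
    # The private cache helper was removed upstream; call it only if present.
    clear = getattr(vae, "_clear_fake_context_parallel_cache", None)
    if callable(clear):
        clear()
    return out[0]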
nodes.py (16 lines changed)
@@ -416,7 +416,10 @@ class DownloadAndLoadCogVideoModel:
         if compile == "torch":
             torch._dynamo.config.suppress_errors = True
             pipe.transformer.to(memory_format=torch.channels_last)
-            pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
+            #pipe.transformer = torch.compile(pipe.transformer, mode="default", fullgraph=False, backend="inductor")
+            for i, block in enumerate(pipe.transformer.transformer_blocks):
+                if "CogVideoXBlock" in str(block):
+                    pipe.transformer.transformer_blocks[i] = torch.compile(block, fullgraph=False, dynamic=False, backend="inductor")
         elif compile == "onediff":
             from onediffx import compile_pipe
             os.environ['NEXFORT_FX_FORCE_TRITON_SDPA'] = '1'
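Note: whole-model compilation with mode="max-autotune", fullgraph=True is replaced by compiling each CogVideoXBlock on its own with fullgraph=False, so a graph break or shape change inside one block only invalidates that block's compiled artifact rather than the whole transformer. The general pattern on a toy module list (a sketch; running it needs a working inductor toolchain):

import torch
import torch.nn as nn

class ToyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(16, 16)

    def forward(self, x):
        return torch.relu(self.lin(x))

blocks = nn.ModuleList(ToyBlock() for _ in range(4))
# Compile block-by-block: each entry becomes its own compiled unit.
for i, block in enumerate(blocks):
    blocks[i] = torch.compile(block, fullgraph=False, dynamic=False, backend="inductor")

x = torch.randn(2, 16)
for block in blocks:
    x = block(x)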
@@ -466,6 +469,8 @@ class DownloadAndLoadCogVideoGGUFModel:
             "optional": {
                 "pab_config": ("PAB_CONFIG", {"default": None}),
                 "block_edit": ("TRANSFORMERBLOCKS", {"default": None}),
+                "compile": (["disabled","onediff","torch"], {"tooltip": "compile the model for faster inference, these are advanced options only available on Linux, see readme for more info"}),
+
             }
         }

@@ -474,7 +479,7 @@ class DownloadAndLoadCogVideoGGUFModel:
     FUNCTION = "loadmodel"
     CATEGORY = "CogVideoWrapper"

-    def loadmodel(self, model, vae_precision, fp8_fastmode, load_device, enable_sequential_cpu_offload, pab_config=None, block_edit=None):
+    def loadmodel(self, model, vae_precision, fp8_fastmode, load_device, enable_sequential_cpu_offload, pab_config=None, block_edit=None, compile="disabled"):

        check_diffusers_version()

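Note: in ComfyUI, a new optional widget only reaches the node if the FUNCTION method accepts it, so the "compile" entry in INPUT_TYPES is mirrored by compile="disabled" in loadmodel's signature; the Python-side default is what keeps workflows saved before this change loading cleanly. A stripped-down sketch of that node contract (the class and field values are illustrative, not the wrapper's real definitions):

class ExampleLoader:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {"model": (["model-a", "model-b"],)},
            "optional": {
                # Combo widget: the list is the choices, the dict holds UI hints.
                "compile": (["disabled", "onediff", "torch"],
                            {"tooltip": "compile the model for faster inference"}),
            },
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "loadmodel"
    CATEGORY = "CogVideoWrapper"

    def loadmodel(self, model, compile="disabled"):
        # Optional widgets arrive as kwargs; the default must cover old workflows.
        ...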
@@ -564,7 +569,9 @@ class DownloadAndLoadCogVideoGGUFModel:
             from .fp8_optimization import convert_fp8_linear
             convert_fp8_linear(transformer, vae_dtype)

-
+        # compilation
+        for i, block in enumerate(transformer.transformer_blocks):
+            transformer.transformer_blocks[i] = torch.compile(block, fullgraph=False, dynamic=False, backend="inductor")
         with open(scheduler_path) as f:
             scheduler_config = json.load(f)

@@ -731,7 +738,8 @@ class DownloadAndLoadCogVideoControlNet:
            [
                "TheDenk/cogvideox-2b-controlnet-hed-v1",
                "TheDenk/cogvideox-2b-controlnet-canny-v1",
-                "TheDenk/cogvideox-5b-controlnet-hed-v1"
+                "TheDenk/cogvideox-5b-controlnet-hed-v1",
+                "TheDenk/cogvideox-5b-controlnet-canny-v1"
            ],
        ),

@@ -9,7 +9,7 @@
 # --------------------------------------------------------

 from typing import Any, Dict, Optional, Tuple, Union
-
+from einops import rearrange
 import torch
 import torch.nn.functional as F
 from diffusers.configuration_utils import ConfigMixin, register_to_config
@@ -42,7 +42,7 @@ class CogVideoXAttnProcessor2_0:
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
+    @torch.compiler.disable()
     def __call__(
         self,
         attn: Attention,
@@ -134,7 +134,7 @@ class FusedCogVideoXAttnProcessor2_0:
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
+    @torch.compiler.disable()
     def __call__(
         self,
         attn: Attention,
@@ -286,7 +286,7 @@ class CogVideoXBlock(nn.Module):
         self.attn_count = 0
         self.last_attn = None
         self.block_idx = block_idx
-
+    #@torch.compiler.disable()
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -294,6 +294,8 @@ class CogVideoXBlock(nn.Module):
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         timestep=None,
+        video_flow_feature: Optional[torch.Tensor] = None,
+        fuser=None,
     ) -> torch.Tensor:
         text_seq_length = encoder_hidden_states.size(1)

@@ -301,7 +303,14 @@ class CogVideoXBlock(nn.Module):
         norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
             hidden_states, encoder_hidden_states, temb
         )
-
+        # Tora Motion-guidance Fuser
+        if video_flow_feature is not None:
+            H, W = video_flow_feature.shape[-2:]
+            T = norm_hidden_states.shape[1] // H // W
+            h = rearrange(norm_hidden_states, "B (T H W) C -> (B T) C H W", H=H, W=W)
+            h = fuser(h, video_flow_feature.to(h), T=T)
+            norm_hidden_states = rearrange(h, "(B T) C H W -> B (T H W) C", T=T)
+            del h, fuser
         # attention
         if enable_pab():
             broadcast_attn, self.attn_count = if_broadcast_spatial(int(timestep[0]), self.attn_count, self.block_idx)
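Note: the fuser works per frame in spatial layout, so the token sequence B (T H W) C is reshaped to (B T) C H W, fused with the flow feature, and flattened back; T is recovered from the flow feature's spatial size. A standalone sketch of that round trip, where toy_fuser stands in for Tora's real motion-guidance module:

import torch
from einops import rearrange

B, T, H, W, C = 1, 4, 8, 8, 64
tokens = torch.randn(B, T * H * W, C)   # transformer token layout
flow = torch.randn(B * T, C, H, W)      # per-frame flow feature

def toy_fuser(h, flow_feat, T):
    # Stand-in: any (B*T, C, H, W) -> (B*T, C, H, W) fusion works here.
    return h + flow_feat

h = rearrange(tokens, "B (T H W) C -> (B T) C H W", H=H, W=W)
h = toy_fuser(h, flow.to(h), T=T)
tokens = rearrange(h, "(B T) C H W -> B (T H W) C", T=T)
print(tokens.shape)  # torch.Size([1, 256, 64])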
@@ -494,6 +503,8 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):

         self.gradient_checkpointing = False

+        self.fuser_list = None
+
         # parallel
         #self.parallel_manager = None

@@ -524,6 +535,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
         return_dict: bool = True,
         controlnet_states: torch.Tensor = None,
         controlnet_weights: Optional[Union[float, int, list, torch.FloatTensor]] = 1.0,
+        video_flow_features: Optional[torch.Tensor] = None,
     ):
         # if self.parallel_manager.cp_size > 1:
         #     (
@@ -574,31 +586,15 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):

         # 4. Transformer blocks
         for i, block in enumerate(self.transformer_blocks):
-            if self.training and self.gradient_checkpointing:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
-                    hidden_states,
-                    encoder_hidden_states,
-                    emb,
-                    image_rotary_emb,
-                    **ckpt_kwargs,
-                )
-            else:
-                hidden_states, encoder_hidden_states = block(
-                    hidden_states=hidden_states,
-                    encoder_hidden_states=encoder_hidden_states,
-                    temb=emb,
-                    image_rotary_emb=image_rotary_emb,
-                    timestep=timesteps if enable_pab() else None,
-                )
+            hidden_states, encoder_hidden_states = block(
+                hidden_states=hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                temb=emb,
+                image_rotary_emb=image_rotary_emb,
+                timestep=timesteps if enable_pab() else None,
+                video_flow_feature=video_flow_features[i] if video_flow_features is not None else None,
+                fuser = self.fuser_list[i] if self.fuser_list is not None else None,
+            )

         if (controlnet_states is not None) and (i < len(controlnet_states)):
             controlnet_states_block = controlnet_states[i]
             controlnet_block_weight = 1.0
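Note: dropping the torch.utils.checkpoint branch removes training-time activation checkpointing, which appears unused in this inference wrapper (gradient_checkpointing is initialized to False above); every call now takes a single keyword-argument path, with the Tora extras indexed lazily so they stay None when no flow features are supplied. The convention in miniature, on toy modules:

import torch
import torch.nn as nn

class ToyBlock(nn.Module):
    def forward(self, x, extra=None):
        # Extras default to None, so callers without them are unaffected.
        return x + extra if extra is not None else x

blocks = nn.ModuleList(ToyBlock() for _ in range(3))
extras = None  # e.g. per-block flow features; None means the feature is off

x = torch.zeros(2, 8)
for i, block in enumerate(blocks):
    x = block(x, extra=extras[i] if extras is not None else None)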