From 7fe4716f2decbbe3fec8803152f014e1f6513fbc Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Fri, 25 Oct 2024 03:16:33 +0300
Subject: [PATCH 1/6] tile encode fix

---
 mz_enable_vae_encode_tiling.py | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/mz_enable_vae_encode_tiling.py b/mz_enable_vae_encode_tiling.py
index 90b1d7d..a038bec 100644
--- a/mz_enable_vae_encode_tiling.py
+++ b/mz_enable_vae_encode_tiling.py
@@ -79,11 +79,15 @@ def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
                     i: i + self.tile_sample_min_height,
                     j: j + self.tile_sample_min_width,
                 ]
+
                 tile = self.encoder(tile)
                 if self.quant_conv is not None:
                     tile = self.quant_conv(tile)
-                time.append(tile)
-                self._clear_fake_context_parallel_cache()
+                time.append(tile[0])
+                try:
+                    self._clear_fake_context_parallel_cache()
+                except:
+                    pass
             row.append(torch.cat(time, dim=2))
         rows.append(row)
     result_rows = []
@@ -130,7 +134,10 @@ def _encode(
         if self.quant_conv is not None:
             z_intermediate = self.quant_conv(z_intermediate)
         h.append(z_intermediate)
-        self._clear_fake_context_parallel_cache()
+        try:
+            self._clear_fake_context_parallel_cache()
+        except:
+            pass
 
     h = torch.cat(h, dim=2)
     return h

From f9c1e11851bb064603548ccb3690f4309f005abe Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Fri, 25 Oct 2024 21:49:48 +0300
Subject: [PATCH 2/6] diffusers backwards compatibility on tiled encode

for some reason it's a tuple in 0.31.0
---
 mz_enable_vae_encode_tiling.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/mz_enable_vae_encode_tiling.py b/mz_enable_vae_encode_tiling.py
index a038bec..544a649 100644
--- a/mz_enable_vae_encode_tiling.py
+++ b/mz_enable_vae_encode_tiling.py
@@ -81,6 +81,8 @@ def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
                 ]
 
                 tile = self.encoder(tile)
+                if not isinstance(tile, tuple):
+                    tile = (tile,)
                 if self.quant_conv is not None:
                     tile = self.quant_conv(tile)
                 time.append(tile[0])

From 249e8d54d1d334f78c4fe6e00cdbf641b8077c5c Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Fri, 25 Oct 2024 22:46:37 +0300
Subject: [PATCH 3/6] Add cogvideox-5b-controlnet-canny-v1

---
 nodes.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/nodes.py b/nodes.py
index 351dc53..f86c047 100644
--- a/nodes.py
+++ b/nodes.py
@@ -723,7 +723,8 @@ class DownloadAndLoadCogVideoControlNet:
                 [
                     "TheDenk/cogvideox-2b-controlnet-hed-v1",
                     "TheDenk/cogvideox-2b-controlnet-canny-v1",
-                    "TheDenk/cogvideox-5b-controlnet-hed-v1"
+                    "TheDenk/cogvideox-5b-controlnet-hed-v1",
+                    "TheDenk/cogvideox-5b-controlnet-canny-v1"
                 ],
             ),
 
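The two VAE patches above guard against diffusers API drift in two places: the encoder's return type (a tuple in 0.31.0, a bare tensor before) and the availability of `_clear_fake_context_parallel_cache()`, which the try/except suggests can be absent in some diffusers versions. A minimal standalone sketch of the combined behavior, assuming `vae` is a diffusers `AutoencoderKLCogVideoX`-style instance; the function name is illustrative, not part of the patch:

```python
import torch

def encode_tile(vae, tile: torch.Tensor) -> torch.Tensor:
    # diffusers 0.31.0 returns a tuple from the encoder; older versions
    # return the tensor directly, so normalize before indexing with [0].
    out = vae.encoder(tile)
    if not isinstance(out, tuple):
        out = (out,)
    # The cache-clearing method may not exist on newer diffusers releases,
    # hence the defensive guard mirroring PATCH 1/6 (which uses a bare except).
    try:
        vae._clear_fake_context_parallel_cache()
    except AttributeError:
        pass
    return out[0]
```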
From 25f16462aa6a4b563a24e6eeb4ab37a026de86ac Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Fri, 25 Oct 2024 22:50:13 +0300
Subject: [PATCH 4/6] torch compile maybe

---
 nodes.py | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/nodes.py b/nodes.py
index f86c047..d279033 100644
--- a/nodes.py
+++ b/nodes.py
@@ -408,7 +408,7 @@ class DownloadAndLoadCogVideoModel:
         if compile == "torch":
             torch._dynamo.config.suppress_errors = True
             pipe.transformer.to(memory_format=torch.channels_last)
-            pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
+            pipe.transformer = torch.compile(pipe.transformer, mode="default", fullgraph=False, backend="inductor")
         elif compile == "onediff":
             from onediffx import compile_pipe
             os.environ['NEXFORT_FX_FORCE_TRITON_SDPA'] = '1'
@@ -458,6 +458,8 @@ class DownloadAndLoadCogVideoGGUFModel:
             "optional": {
                 "pab_config": ("PAB_CONFIG", {"default": None}),
                 "block_edit": ("TRANSFORMERBLOCKS", {"default": None}),
+                "compile": (["disabled","onediff","torch"], {"tooltip": "compile the model for faster inference, these are advanced options only available on Linux, see readme for more info"}),
+
             }
         }
 
@@ -466,7 +468,7 @@ class DownloadAndLoadCogVideoGGUFModel:
     FUNCTION = "loadmodel"
     CATEGORY = "CogVideoWrapper"
 
-    def loadmodel(self, model, vae_precision, fp8_fastmode, load_device, enable_sequential_cpu_offload, pab_config=None, block_edit=None):
+    def loadmodel(self, model, vae_precision, fp8_fastmode, load_device, enable_sequential_cpu_offload, pab_config=None, block_edit=None, compile="disabled"):
 
         check_diffusers_version()
 
@@ -556,7 +558,11 @@ class DownloadAndLoadCogVideoGGUFModel:
             from .fp8_optimization import convert_fp8_linear
             convert_fp8_linear(transformer, vae_dtype)
-
+        # compilation
+        if compile == "torch":
+            torch._dynamo.config.suppress_errors = True
+            pipe.transformer.to(memory_format=torch.channels_last)
+            pipe.transformer = torch.compile(pipe.transformer, mode="default", fullgraph=False, backend="inductor")
 
         with open(scheduler_path) as f:
             scheduler_config = json.load(f)
 
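PATCH 4/6 swaps `mode="max-autotune", fullgraph=True` for the more forgiving `mode="default", fullgraph=False`, which tolerates graph breaks instead of erroring on them. A small sketch of that call pattern on a toy module, assuming PyTorch 2.0+ and a working inductor toolchain; the `nn.Conv2d` here is a stand-in for `pipe.transformer`:

```python
import torch
import torch.nn as nn

model = nn.Conv2d(3, 8, kernel_size=3, padding=1)
# Matches the channels_last memory-format change made before compiling.
model.to(memory_format=torch.channels_last)
# fullgraph=False allows dynamo to fall back to eager at graph breaks.
compiled = torch.compile(model, mode="default", fullgraph=False, backend="inductor")
out = compiled(torch.randn(1, 3, 16, 16).to(memory_format=torch.channels_last))
```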
From dcca0957439ae7cf3c6847f7bf62b0e453561609 Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Sat, 26 Oct 2024 02:33:29 +0300
Subject: [PATCH 5/6] make torch compile work better

---
 custom_cogvideox_transformer_3d.py   |  4 +-
 nodes.py                             | 12 +++++++++---
 videosys/cogvideox_transformer_3d.py | 56 +++++++++++++---------------
 3 files changed, 35 insertions(+), 37 deletions(-)

diff --git a/custom_cogvideox_transformer_3d.py b/custom_cogvideox_transformer_3d.py
index 7f0cf3b..9499094 100644
--- a/custom_cogvideox_transformer_3d.py
+++ b/custom_cogvideox_transformer_3d.py
@@ -52,7 +52,7 @@ class CogVideoXAttnProcessor2_0:
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
+    @torch.compiler.disable()
     def __call__(
         self,
         attn: Attention,
@@ -126,7 +126,7 @@ class FusedCogVideoXAttnProcessor2_0:
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
+    @torch.compiler.disable()
     def __call__(
         self,
         attn: Attention,
diff --git a/nodes.py b/nodes.py
index d279033..5f74b48 100644
--- a/nodes.py
+++ b/nodes.py
@@ -408,7 +408,11 @@ class DownloadAndLoadCogVideoModel:
         if compile == "torch":
             torch._dynamo.config.suppress_errors = True
             pipe.transformer.to(memory_format=torch.channels_last)
-            pipe.transformer = torch.compile(pipe.transformer, mode="default", fullgraph=False, backend="inductor")
+            #pipe.transformer = torch.compile(pipe.transformer, mode="default", fullgraph=False, backend="inductor")
+            for i, block in enumerate(pipe.transformer.transformer_blocks):
+                if "CogVideoXBlock" in str(block):
+                    print(block)
+                    pipe.transformer.transformer_blocks[i] = torch.compile(block, fullgraph=False, dynamic=False, backend="inductor")
         elif compile == "onediff":
             from onediffx import compile_pipe
             os.environ['NEXFORT_FX_FORCE_TRITON_SDPA'] = '1'
@@ -559,10 +563,8 @@ class DownloadAndLoadCogVideoGGUFModel:
             convert_fp8_linear(transformer, vae_dtype)
         # compilation
-        if compile == "torch":
-            torch._dynamo.config.suppress_errors = True
-            pipe.transformer.to(memory_format=torch.channels_last)
-            pipe.transformer = torch.compile(pipe.transformer, mode="default", fullgraph=False, backend="inductor")
+        for i, block in enumerate(transformer.transformer_blocks):
+            transformer.transformer_blocks[i] = torch.compile(block, fullgraph=False, dynamic=False, backend="inductor")
 
         with open(scheduler_path) as f:
             scheduler_config = json.load(f)
 
diff --git a/videosys/cogvideox_transformer_3d.py b/videosys/cogvideox_transformer_3d.py
index 6a482fa..b0e1aa5 100644
--- a/videosys/cogvideox_transformer_3d.py
+++ b/videosys/cogvideox_transformer_3d.py
@@ -9,7 +9,7 @@
 # --------------------------------------------------------
 
 from typing import Any, Dict, Optional, Tuple, Union
-
+from einops import rearrange
 import torch
 import torch.nn.functional as F
 from diffusers.configuration_utils import ConfigMixin, register_to_config
@@ -42,7 +42,7 @@ class CogVideoXAttnProcessor2_0:
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
+    @torch.compiler.disable()
     def __call__(
         self,
         attn: Attention,
@@ -134,7 +134,7 @@ class FusedCogVideoXAttnProcessor2_0:
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
+    @torch.compiler.disable()
     def __call__(
         self,
         attn: Attention,
@@ -286,7 +286,7 @@ class CogVideoXBlock(nn.Module):
         self.attn_count = 0
         self.last_attn = None
         self.block_idx = block_idx
-
+    #@torch.compiler.disable()
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -294,6 +294,8 @@ class CogVideoXBlock(nn.Module):
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         timestep=None,
+        video_flow_feature: Optional[torch.Tensor] = None,
+        fuser=None,
     ) -> torch.Tensor:
         text_seq_length = encoder_hidden_states.size(1)
 
@@ -301,7 +303,14 @@ class CogVideoXBlock(nn.Module):
         norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
             hidden_states, encoder_hidden_states, temb
         )
-
+        # Tora Motion-guidance Fuser
+        if video_flow_feature is not None:
+            H, W = video_flow_feature.shape[-2:]
+            T = norm_hidden_states.shape[1] // H // W
+            h = rearrange(norm_hidden_states, "B (T H W) C -> (B T) C H W", H=H, W=W)
+            h = fuser(h, video_flow_feature.to(h), T=T)
+            norm_hidden_states = rearrange(h, "(B T) C H W -> B (T H W) C", T=T)
+            del h, fuser
         # attention
         if enable_pab():
             broadcast_attn, self.attn_count = if_broadcast_spatial(int(timestep[0]), self.attn_count, self.block_idx)
@@ -494,6 +503,8 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
 
         self.gradient_checkpointing = False
 
+        self.fuser_list = None
+
         # parallel
         #self.parallel_manager = None
 
@@ -524,6 +535,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
         return_dict: bool = True,
         controlnet_states: torch.Tensor = None,
         controlnet_weights: Optional[Union[float, int, list, torch.FloatTensor]] = 1.0,
+        video_flow_features: Optional[torch.Tensor] = None,
     ):
         # if self.parallel_manager.cp_size > 1:
         #     (
@@ -574,31 +586,15 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
 
         # 4. Transformer blocks
         for i, block in enumerate(self.transformer_blocks):
-            if self.training and self.gradient_checkpointing:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
-                    hidden_states,
-                    encoder_hidden_states,
-                    emb,
-                    image_rotary_emb,
-                    **ckpt_kwargs,
-                )
-            else:
-                hidden_states, encoder_hidden_states = block(
-                    hidden_states=hidden_states,
-                    encoder_hidden_states=encoder_hidden_states,
-                    temb=emb,
-                    image_rotary_emb=image_rotary_emb,
-                    timestep=timesteps if enable_pab() else None,
-                )
+            hidden_states, encoder_hidden_states = block(
+                hidden_states=hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                temb=emb,
+                image_rotary_emb=image_rotary_emb,
+                timestep=timesteps if enable_pab() else None,
+                video_flow_feature=video_flow_features[i] if video_flow_features is not None else None,
+                fuser = self.fuser_list[i] if self.fuser_list is not None else None,
+            )
             if (controlnet_states is not None) and (i < len(controlnet_states)):
                 controlnet_states_block = controlnet_states[i]
                 controlnet_block_weight = 1.0
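The Tora motion-guidance hook added to `CogVideoXBlock.forward` in PATCH 5/6 round-trips the token sequence through a per-frame layout so a 2D fuser module can consume the flow features. A shape-only sketch of that einops round-trip; the dimensions are illustrative and the addition stands in for the real `fuser(frames, flow, T=T)` call:

```python
import torch
from einops import rearrange

B, T, H, W, C = 1, 2, 4, 4, 8
tokens = torch.randn(B, T * H * W, C)   # norm_hidden_states layout: B (T H W) C
flow = torch.randn(B * T, C, H, W)      # per-frame motion-flow features

# Unflatten video tokens into (B*T) frames so a 2D module can process them...
frames = rearrange(tokens, "B (T H W) C -> (B T) C H W", H=H, W=W)
frames = frames + flow                   # placeholder for the fuser call
# ...then flatten back to the transformer's token layout.
tokens = rearrange(frames, "(B T) C H W -> B (T H W) C", T=T)
```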
From c17750ea0a995cf79fb049669cc85451642e5173 Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Sat, 26 Oct 2024 02:35:29 +0300
Subject: [PATCH 6/6] remove print

---
 nodes.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/nodes.py b/nodes.py
index 5f74b48..4648470 100644
--- a/nodes.py
+++ b/nodes.py
@@ -411,7 +411,6 @@ class DownloadAndLoadCogVideoModel:
             #pipe.transformer = torch.compile(pipe.transformer, mode="default", fullgraph=False, backend="inductor")
             for i, block in enumerate(pipe.transformer.transformer_blocks):
                 if "CogVideoXBlock" in str(block):
-                    print(block)
                     pipe.transformer.transformer_blocks[i] = torch.compile(block, fullgraph=False, dynamic=False, backend="inductor")
         elif compile == "onediff":
             from onediffx import compile_pipe
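After PATCH 6/6, the torch-compile path reduces to compiling each transformer block in place rather than the whole model; one plausible motivation is that graph breaks and recompilation then stay local to a single block. A sketch of that per-block strategy, assuming the model exposes an `nn.ModuleList` named `transformer_blocks` as the CogVideoX transformer does:

```python
import torch

def compile_blocks(transformer: torch.nn.Module) -> torch.nn.Module:
    # Replace each block with its compiled wrapper, mirroring the loop in nodes.py.
    for i, block in enumerate(transformer.transformer_blocks):
        transformer.transformer_blocks[i] = torch.compile(
            block, fullgraph=False, dynamic=False, backend="inductor"
        )
    return transformer
```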