diff --git a/custom_cogvideox_transformer_3d.py b/custom_cogvideox_transformer_3d.py index 47b9488..12633b1 100644 --- a/custom_cogvideox_transformer_3d.py +++ b/custom_cogvideox_transformer_3d.py @@ -235,6 +235,7 @@ class CogVideoXBlock(nn.Module): image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, video_flow_feature: Optional[torch.Tensor] = None, fuser=None, + block_use_fastercache=False, fastercache_counter=0, fastercache_start_step=15, fastercache_device="cuda:0", @@ -257,18 +258,32 @@ class CogVideoXBlock(nn.Module): del h, fuser #region fastercache - B = norm_hidden_states.shape[0] - if fastercache_counter >= fastercache_start_step + 3 and fastercache_counter%3!=0 and self.cached_hidden_states[-1].shape[0] >= B: - attn_hidden_states = ( - self.cached_hidden_states[1][:B] + - (self.cached_hidden_states[1][:B] - self.cached_hidden_states[0][:B]) - * 0.3 - ).to(norm_hidden_states.device, non_blocking=True) - attn_encoder_hidden_states = ( - self.cached_encoder_hidden_states[1][:B] + - (self.cached_encoder_hidden_states[1][:B] - self.cached_encoder_hidden_states[0][:B]) - * 0.3 - ).to(norm_hidden_states.device, non_blocking=True) + if block_use_fastercache: + B = norm_hidden_states.shape[0] + if fastercache_counter >= fastercache_start_step + 3 and fastercache_counter%3!=0 and self.cached_hidden_states[-1].shape[0] >= B: + attn_hidden_states = ( + self.cached_hidden_states[1][:B] + + (self.cached_hidden_states[1][:B] - self.cached_hidden_states[0][:B]) + * 0.3 + ).to(norm_hidden_states.device, non_blocking=True) + attn_encoder_hidden_states = ( + self.cached_encoder_hidden_states[1][:B] + + (self.cached_encoder_hidden_states[1][:B] - self.cached_encoder_hidden_states[0][:B]) + * 0.3 + ).to(norm_hidden_states.device, non_blocking=True) + else: + attn_hidden_states, attn_encoder_hidden_states = self.attn1( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + attention_mode=self.attention_mode, + ) + if fastercache_counter == fastercache_start_step: + self.cached_hidden_states = [attn_hidden_states.to(fastercache_device), attn_hidden_states.to(fastercache_device)] + self.cached_encoder_hidden_states = [attn_encoder_hidden_states.to(fastercache_device), attn_encoder_hidden_states.to(fastercache_device)] + elif fastercache_counter > fastercache_start_step: + self.cached_hidden_states[-1].copy_(attn_hidden_states.to(fastercache_device)) + self.cached_encoder_hidden_states[-1].copy_(attn_encoder_hidden_states.to(fastercache_device)) else: attn_hidden_states, attn_encoder_hidden_states = self.attn1( hidden_states=norm_hidden_states, @@ -276,12 +291,6 @@ class CogVideoXBlock(nn.Module): image_rotary_emb=image_rotary_emb, attention_mode=self.attention_mode, ) - if fastercache_counter == fastercache_start_step: - self.cached_hidden_states = [attn_hidden_states.to(fastercache_device), attn_hidden_states.to(fastercache_device)] - self.cached_encoder_hidden_states = [attn_encoder_hidden_states.to(fastercache_device), attn_encoder_hidden_states.to(fastercache_device)] - elif fastercache_counter > fastercache_start_step: - self.cached_hidden_states[-1].copy_(attn_hidden_states.to(fastercache_device)) - self.cached_encoder_hidden_states[-1].copy_(attn_encoder_hidden_states.to(fastercache_device)) hidden_states = hidden_states + gate_msa * attn_hidden_states encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states @@ -477,6 +486,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): self.fastercache_lf_step = 40 self.fastercache_hf_step = 30 self.fastercache_device = "cuda" + self.fastercache_num_blocks_to_cache = len(self.transformer_blocks) self.attention_mode = attention_mode def _set_gradient_checkpointing(self, module, value=False): @@ -577,7 +587,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): # 2. Patch embedding p = self.config.patch_size p_t = self.config.patch_size_t - + hidden_states = self.patch_embed(encoder_hidden_states, hidden_states) hidden_states = self.embedding_dropout(hidden_states) @@ -597,7 +607,9 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): image_rotary_emb=image_rotary_emb, video_flow_feature=video_flow_features[i][:1] if video_flow_features is not None else None, fuser = self.fuser_list[i] if self.fuser_list is not None else None, + block_use_fastercache = i <= self.fastercache_num_blocks_to_cache, fastercache_counter = self.fastercache_counter, + fastercache_start_step = self.fastercache_start_step, fastercache_device = self.fastercache_device ) @@ -665,7 +677,9 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): image_rotary_emb=image_rotary_emb, video_flow_feature=video_flow_features[i] if video_flow_features is not None else None, fuser = self.fuser_list[i] if self.fuser_list is not None else None, + block_use_fastercache = i <= self.fastercache_num_blocks_to_cache, fastercache_counter = self.fastercache_counter, + fastercache_start_step = self.fastercache_start_step, fastercache_device = self.fastercache_device ) #has_nan = torch.isnan(hidden_states).any() diff --git a/nodes.py b/nodes.py index 1f1a80b..29ffe2c 100644 --- a/nodes.py +++ b/nodes.py @@ -452,6 +452,8 @@ class CogVideoImageInterpolationEncode: "optional": { "enable_tiling": ("BOOLEAN", {"default": False, "tooltip": "Enable tiling for the VAE to reduce memory usage"}), "mask": ("MASK", ), + "vae_override" : ("VAE", {"default": None, "tooltip": "Override the VAE model in the pipeline"}), + }, } @@ -460,14 +462,21 @@ class CogVideoImageInterpolationEncode: FUNCTION = "encode" CATEGORY = "CogVideoWrapper" - def encode(self, pipeline, start_image, end_image, chunk_size=8, enable_tiling=False, mask=None): + def encode(self, pipeline, start_image, end_image, enable_tiling=False, mask=None, vae_override=None): device = mm.get_torch_device() offload_device = mm.unet_offload_device() generator = torch.Generator(device=device).manual_seed(0) B, H, W, C = start_image.shape - vae = pipeline["pipe"].vae + vae = pipeline["pipe"].vae if vae_override is None else vae_override + vae.enable_slicing() + model_name = pipeline.get("model_name", "") + + if ("1.5" in model_name or "1_5" in model_name): + vae_scaling_factor = 1 / vae.config.scaling_factor + else: + vae_scaling_factor = vae.config.scaling_factor vae.enable_slicing() if enable_tiling: @@ -500,8 +509,8 @@ class CogVideoImageInterpolationEncode: latents_list = [] # Encode the chunk of images - start_latents = vae.encode(start_image).latent_dist.sample(generator) * vae.config.scaling_factor - end_latents = vae.encode(end_image).latent_dist.sample(generator) * vae.config.scaling_factor + start_latents = vae.encode(start_image).latent_dist.sample(generator) * vae_scaling_factor + end_latents = vae.encode(end_image).latent_dist.sample(generator) * vae_scaling_factor start_latents = start_latents.permute(0, 2, 1, 3, 4) # B, T, C, H, W end_latents = end_latents.permute(0, 2, 1, 3, 4) # B, T, C, H, W @@ -769,6 +778,7 @@ class CogVideoXFasterCache: "hf_step": ("INT", {"default": 30, "min": 0, "max": 1024, "step": 1}), "lf_step": ("INT", {"default": 40, "min": 0, "max": 1024, "step": 1}), "cache_device": (["main_device", "offload_device", "cuda:1"], {"default": "main_device", "tooltip": "The device to use for the cache, main_device is on GPU and uses a lot of VRAM"}), + "num_blocks_to_cache": ("INT", {"default": 42, "min": 0, "max": 1024, "step": 1, "tooltip": "Number of transformer blocks to cache, 5b model has 42 blocks, tradeoff between speed and memory"}), }, } @@ -777,7 +787,7 @@ class CogVideoXFasterCache: FUNCTION = "args" CATEGORY = "CogVideoWrapper" - def args(self, start_step, hf_step, lf_step, cache_device): + def args(self, start_step, hf_step, lf_step, cache_device, num_blocks_to_cache): device = mm.get_torch_device() offload_device = mm.unet_offload_device() if cache_device == "cuda:1": @@ -786,7 +796,8 @@ class CogVideoXFasterCache: "start_step" : start_step, "hf_step" : hf_step, "lf_step" : lf_step, - "cache_device" : device if cache_device != "offload_device" else offload_device + "cache_device" : device if cache_device != "offload_device" else offload_device, + "num_blocks_to_cache" : num_blocks_to_cache, } return (fastercache,) @@ -832,20 +843,25 @@ class CogVideoSampler: mm.soft_empty_cache() base_path = pipeline["base_path"] + model_name = pipeline.get("model_name", "") + supports_image_conds = True if "I2V" in model_name or "interpolation" in model_name.lower() else False assert "fun" not in base_path.lower(), "'Fun' models not supported in 'CogVideoSampler', use the 'CogVideoXFunSampler'" assert ( - "I2V" not in pipeline.get("model_name", "") or - "1.5" in pipeline.get("model_name", "") or - "1_5" in pipeline.get("model_name", "") or + "I2V" not in model_name or + "1.5" in model_name or + "1_5" in model_name or num_frames == 49 or context_options is not None ), "1.0 I2V model can only do 49 frames" if image_cond_latents is not None: - assert image_cond_latents["samples"].shape[0] == 1, "Image condition latents must be a single latent" - assert "I2V" in pipeline.get("model_name", ""), "Image condition latents only supported for I2V models" + assert supports_image_conds, "Image condition latents only supported for I2V and Interpolation models" + if "I2V" in model_name: + assert image_cond_latents["samples"].shape[1] == 1, "I2V model only supports single image condition latent" + elif "interpolation" in model_name.lower(): + assert image_cond_latents["samples"].shape[1] == 2, "Interpolation model needs two image condition latents" else: - assert "I2V" not in pipeline.get("model_name", ""), "Image condition latents required for I2V models" + assert not supports_image_conds, "Image condition latents required for I2V models" device = mm.get_torch_device() offload_device = mm.unet_offload_device() @@ -885,6 +901,8 @@ class CogVideoSampler: pipe.transformer.fastercache_lf_step = fastercache["lf_step"] pipe.transformer.fastercache_hf_step = fastercache["hf_step"] pipe.transformer.fastercache_device = fastercache["cache_device"] + pipe.transformer.fastercache_num_blocks_to_cache = fastercache["num_blocks_to_cache"] + log.info(f"FasterCache enabled for {pipe.transformer.fastercache_num_blocks_to_cache} blocks out of {len(pipe.transformer.transformer_blocks)}") else: pipe.transformer.use_fastercache = False pipe.transformer.fastercache_counter = 0 @@ -1001,6 +1019,8 @@ class CogVideoDecode: latents = samples["samples"] vae = pipeline["pipe"].vae if vae_override is None else vae_override + additional_frames = getattr(pipeline["pipe"], "additional_frames", 0) + vae.enable_slicing() if not pipeline["cpu_offloading"]: @@ -1024,7 +1044,8 @@ class CogVideoDecode: vae._clear_fake_context_parallel_cache() except: pass - frames = vae.decode(latents[:, :, pipeline["pipe"].additional_frames:]).sample + + frames = vae.decode(latents[:, :, additional_frames:]).sample vae.disable_tiling() if not pipeline["cpu_offloading"]: vae.to(offload_device)