From 693157691657cd49776f45648ed326ca522eaef4 Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Mon, 11 Nov 2024 18:53:12 +0200
Subject: [PATCH] update from upstream, ofs embeds

---
 custom_cogvideox_transformer_3d.py | 20 ++++++++------------
 nodes.py                           |  7 ++++++-
 pipeline_cogvideox.py              | 15 ++++++++++++---
 3 files changed, 26 insertions(+), 16 deletions(-)

diff --git a/custom_cogvideox_transformer_3d.py b/custom_cogvideox_transformer_3d.py
index 2fa191a..10b9e4f 100644
--- a/custom_cogvideox_transformer_3d.py
+++ b/custom_cogvideox_transformer_3d.py
@@ -425,9 +425,11 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
 
         self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
         self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
 
+        self.ofs_proj = None
         self.ofs_embedding = None
         if ofs_embed_dim:
+            self.ofs_proj = Timesteps(ofs_embed_dim, flip_sin_to_cos, freq_shift)
             self.ofs_embedding = TimestepEmbedding(ofs_embed_dim, ofs_embed_dim, timestep_activation_fn)  # same as time embeddings, for ofs
 
         # 3. Define spatio-temporal transformers blocks
@@ -547,6 +549,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         encoder_hidden_states: torch.Tensor,
         timestep: Union[int, float, torch.LongTensor],
         timestep_cond: Optional[torch.Tensor] = None,
+        ofs: Optional[Union[int, float, torch.LongTensor]] = None,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         controlnet_states: torch.Tensor = None,
         controlnet_weights: Optional[Union[float, int, list, np.ndarray, torch.FloatTensor]] = 1.0,
@@ -563,26 +566,21 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         # but time_embedding might actually be running in fp16. so we need to cast here.
         # there might be better ways to encapsulate this.
         t_emb = t_emb.to(dtype=hidden_states.dtype)
+
         emb = self.time_embedding(t_emb, timestep_cond)
 
         if self.ofs_embedding is not None: #1.5 I2V
-            emb_ofs = self.ofs_embedding(emb, timestep_cond)
-            emb = emb + emb_ofs
+            ofs_emb = self.ofs_proj(ofs)
+            ofs_emb = ofs_emb.to(dtype=hidden_states.dtype)
+            ofs_emb = self.ofs_embedding(ofs_emb)
+            emb = emb + ofs_emb
 
         # 2. Patch embedding
         p = self.config.patch_size
         p_t = self.config.patch_size_t
 
-        # We know that the hidden states height and width will always be divisible by patch_size.
-        # But, the number of frames may not be divisible by patch_size_t. So, we pad with the beginning frames.
-        # if p_t is not None:
-        #     remaining_frames = 0 if num_frames % 2 == 0 else 1
-        #     first_frame = hidden_states[:, :1].repeat(1, 1 + remaining_frames, 1, 1, 1)
-        #     hidden_states = torch.cat([first_frame, hidden_states[:, 1:]], dim=1)
-
         hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
         hidden_states = self.embedding_dropout(hidden_states)
-
         text_seq_length = encoder_hidden_states.shape[1]
         encoder_hidden_states = hidden_states[:, :text_seq_length]
         hidden_states = hidden_states[:, text_seq_length:]
@@ -639,7 +637,6 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
                 batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
             )
             output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
-            output = output[:, remaining_frames:]
 
             (bb, tt, cc, hh, ww) = output.shape
             cond = rearrange(output, "B T C H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
@@ -711,7 +708,6 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
                 batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
             )
             output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
-            #output = output[:, remaining_frames:]
 
         if self.fastercache_counter >= self.fastercache_start_step + 1:
             (bb, tt, cc, hh, ww) = output.shape
diff --git a/nodes.py b/nodes.py
index 42af0ad..df73ae6 100644
--- a/nodes.py
+++ b/nodes.py
@@ -346,7 +346,7 @@ class CogVideoImageEncode:
                 "image": ("IMAGE", ),
             },
             "optional": {
-                "chunk_size": ("INT", {"default": 16, "min": 4}),
+                "chunk_size": ("INT", {"default": 16, "min": 4, "tooltip": "How many images to encode at once, lower values use less memory"}),
                 "enable_tiling": ("BOOLEAN", {"default": False, "tooltip": "Enable tiling for the VAE to reduce memory usage"}),
                 "mask": ("MASK", ),
                 "noise_aug_strength": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001, "tooltip": "Augment image with noise"}),
@@ -806,6 +806,7 @@ class CogVideoSampler:
                 "controlnet": ("COGVIDECONTROLNET",),
                 "tora_trajectory": ("TORAFEATURES", ),
                 "fastercache": ("FASTERCACHEARGS", ),
+                #"sigmas": ("SIGMAS", ),
             }
         }
 
@@ -879,6 +880,9 @@ class CogVideoSampler:
             cfg = [cfg for _ in range(steps)]
         else:
             assert len(cfg) == steps, "Length of cfg list must match number of steps"
+
+        # if sigmas is not None:
+        #     sigma_list = sigmas.tolist()
 
         autocastcondition = not pipeline["onediff"] or not dtype == torch.float32
         autocast_context = torch.autocast(mm.get_autocast_device(device), dtype=dtype) if autocastcondition else nullcontext()
@@ -889,6 +893,7 @@ class CogVideoSampler:
                 width = width,
                 num_frames = num_frames,
                 guidance_scale=cfg,
+                #sigmas=sigma_list if sigmas is not None else None,
                 latents=samples["samples"] if samples is not None else None,
                 image_cond_latents=image_cond_latents["samples"] if image_cond_latents is not None else None,
                 denoise_strength=denoise_strength,
diff --git a/pipeline_cogvideox.py b/pipeline_cogvideox.py
index 007987e..466eecb 100644
--- a/pipeline_cogvideox.py
+++ b/pipeline_cogvideox.py
@@ -369,6 +369,7 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
         timesteps: Optional[List[int]] = None,
         guidance_scale: float = 6,
         denoise_strength: float = 1.0,
+        sigmas: Optional[List[float]] = None,
         num_videos_per_prompt: int = 1,
         eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
@@ -429,7 +430,7 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
                 weighting.
                 If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
        """
-        
+
        height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial
        width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial
        num_videos_per_prompt = 1
@@ -460,7 +461,10 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
        prompt_embeds = prompt_embeds.to(self.vae.dtype)
 
        # 4. Prepare timesteps
-        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+        if sigmas is None:
+            timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+        else:
+            timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, sigmas=sigmas, device=device)
        self._num_timesteps = len(timesteps)
 
        # 5. Prepare latents.
@@ -499,7 +503,6 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
            freenoise=freenoise,
        )
        latents = latents.to(self.vae.dtype)
-        #print("latents", latents.shape)
 
        # 5.5.
        if image_cond_latents is not None:
@@ -579,6 +582,9 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
            if self.transformer.config.use_rotary_positional_embeddings
            else None
        )
 
+        # 7.6. Create ofs embeds if required
+        ofs_emb = None if self.transformer.config.ofs_embed_dim is None else latents.new_full((1,), fill_value=2.0)
+
        if tora is not None and do_classifier_free_guidance:
            video_flow_features = tora["video_flow_features"].repeat(1, 2, 1, 1, 1).contiguous()
@@ -617,6 +623,8 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
                    for param in module.parameters():
                        param.data = param.data.to(device)
 
+        logger.info(f"Sampling {num_frames} frames in {latent_frames} latent frames at {width}x{height} with {num_inference_steps} inference steps")
+
        # 9. Denoising loop
        comfy_pbar = ProgressBar(len(timesteps))
        with self.progress_bar(total=len(timesteps)) as progress_bar:
@@ -873,6 +881,7 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
                    encoder_hidden_states=prompt_embeds,
                    timestep=timestep,
                    image_rotary_emb=image_rotary_emb,
+                    ofs=ofs_emb,
                    return_dict=False,
                    controlnet_states=controlnet_states,
                    controlnet_weights=control_weights,
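
Note (not part of the patch): the core of this change is the new "ofs" conditioning path for CogVideoX 1.5 I2V. Below is a minimal sketch of how the ofs value flows through the modules the patch touches, assuming the diffusers Timesteps and TimestepEmbedding modules that the transformer already uses. The dimension value, the stand-in latents tensor, and its shape are illustrative assumptions, not taken from the patch; in the model these come from the transformer config and the prepared latents.

# Sketch only: mirrors the ofs-embedding path introduced in this patch.
import torch
from diffusers.models.embeddings import Timesteps, TimestepEmbedding

# Assumed example value; in the model this is transformer.config.ofs_embed_dim
# and must equal the time embedding dimension so the two embeddings can be added.
ofs_embed_dim = 512
ofs_proj = Timesteps(ofs_embed_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
ofs_embedding = TimestepEmbedding(ofs_embed_dim, ofs_embed_dim, act_fn="silu")

# The pipeline passes a constant ofs of 2.0 whenever ofs_embed_dim is configured
# (see the "7.6. Create ofs embeds" hunk); latents here is only a stand-in tensor
# used to pick up dtype/device, its shape is illustrative.
latents = torch.randn(1, 13, 16, 60, 90)
ofs = latents.new_full((1,), fill_value=2.0)

# Inside the transformer forward: sinusoidal projection, dtype cast, MLP,
# then the result is added to the regular timestep embedding.
ofs_emb = ofs_proj(ofs)                    # sinusoidal projection, shape (1, ofs_embed_dim)
ofs_emb = ofs_emb.to(dtype=latents.dtype)
ofs_emb = ofs_embedding(ofs_emb)           # MLP output, shape (1, ofs_embed_dim)
# emb = self.time_embedding(t_emb, timestep_cond) + ofs_emb   # combined conditioning

This replaces the previous behavior, where the ofs embedding was (incorrectly) computed from the already-built time embedding; the patch instead projects the scalar ofs value the same way timesteps are projected before embedding it.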