From 5a6bcfd6129dc6f07a64c6f67a1620cb02e616d6 Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Thu, 19 Sep 2024 00:38:58 +0300
Subject: [PATCH] fix fp8 for I2V

---
 nodes.py              | 20 ++++++++++++++++----
 pipeline_cogvideox.py | 13 ++-----------
 2 files changed, 18 insertions(+), 15 deletions(-)

diff --git a/nodes.py b/nodes.py
index 6ff9e6c..684cb20 100644
--- a/nodes.py
+++ b/nodes.py
@@ -57,8 +57,8 @@ class DownloadAndLoadCogVideoModel:
         offload_device = mm.unet_offload_device()
         mm.soft_empty_cache()
 
-        if "I2V" in model and fp8_transformer != "disabled":
-            raise NotImplementedError("fp8_transformer is not implemented yet for I2V -model")
+        #if "I2V" in model and fp8_transformer != "disabled":
+        #    raise NotImplementedError("fp8_transformer is not implemented yet for I2V -model")
 
         dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
 
@@ -99,7 +99,15 @@ class DownloadAndLoadCogVideoModel:
                     if name != "pos_embedding":
                         param.data = param.data.to(torch.float8_e4m3fn)
             else:
-                transformer.to(torch.float8_e4m3fn)
+                for name, param in transformer.named_parameters():
+
+                    if "patch_embed" not in name:
+                        param.data = param.data.to(torch.float8_e4m3fn)
+
+                    else:
+                        print(name)
+                        print(param.data.dtype)
+                #transformer.to(torch.float8_e4m3fn)
 
             if fp8_transformer == "fastmode":
                 from .fp8_optimization import convert_fp8_linear
@@ -238,6 +246,9 @@ class CogVideoTextEncode:
         return {"required": {
             "clip": ("CLIP",),
             "prompt": ("STRING", {"default": "", "multiline": True} ),
+            },
+            "optional": {
+                "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
             }
         }
 
@@ -246,7 +257,7 @@ class CogVideoTextEncode:
     FUNCTION = "process"
     CATEGORY = "CogVideoWrapper"
 
-    def process(self, clip, prompt):
+    def process(self, clip, prompt, strength=1.0):
         load_device = mm.text_encoder_device()
         offload_device = mm.text_encoder_offload_device()
         clip.tokenizer.t5xxl.pad_to_max_length = True
@@ -255,6 +266,7 @@ class CogVideoTextEncode:
 
         tokens = clip.tokenize(prompt, return_word_ids=True)
         embeds = clip.encode_from_tokens(tokens, return_pooled=False, return_dict=False)
+        embeds *= strength
 
         clip.cond_stage_model.to(offload_device)
         return (embeds, )
diff --git a/pipeline_cogvideox.py b/pipeline_cogvideox.py
index 06d0efc..30a3e03 100644
--- a/pipeline_cogvideox.py
+++ b/pipeline_cogvideox.py
@@ -412,7 +412,7 @@ class CogVideoXPipeline(DiffusionPipeline):
         if do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
 
-        prompt_embeds = prompt_embeds.to(self.transformer.dtype)
+        prompt_embeds = prompt_embeds.to(self.vae.dtype)
 
         # 4. Prepare timesteps
         timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
@@ -442,14 +442,11 @@ class CogVideoXPipeline(DiffusionPipeline):
             num_inference_steps,
             latents
         )
-        latents = latents.to(self.transformer.dtype)
+        latents = latents.to(self.vae.dtype)
         print("latents", latents.shape)
 
         # 5.5.
         if image_cond_latents is not None:
-            print("image_cond_latents", image_cond_latents.shape)
-            #image_cond_latents = torch.cat(image_cond_latents, dim=0).to(self.transformer.dtype)#.permute(0, 2, 1, 3, 4) # [B, F, C, H, W]
-
             padding_shape = (
                 batch_size,
                 (latents.shape[1] - 1),
@@ -457,10 +454,8 @@ class CogVideoXPipeline(DiffusionPipeline):
                 height // self.vae_scale_factor_spatial,
                 width // self.vae_scale_factor_spatial,
             )
-            print("padding_shape", padding_shape)
             latent_padding = torch.zeros(padding_shape, device=device, dtype=self.transformer.dtype)
             image_cond_latents = torch.cat([image_cond_latents, latent_padding], dim=1)
-            print("image_cond_latents", image_cond_latents.shape)
 
         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
@@ -602,11 +597,7 @@ class CogVideoXPipeline(DiffusionPipeline):
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
                 if image_cond_latents is not None:
-
-
                     latent_image_input = torch.cat([image_cond_latents] * 2) if do_classifier_free_guidance else image_cond_latents
-                    print("latent_model_input",latent_model_input.shape)
-                    print("image_cond_latents",image_cond_latents.shape)
                     latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=2)
 
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
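
Note (not part of the patch): the nodes.py change above replaces the blanket transformer.to(torch.float8_e4m3fn) cast with a per-parameter loop that leaves anything under "patch_embed" in its original dtype, which appears to be what makes fp8 usable for the I2V model. Below is a minimal standalone sketch of that selective cast; the helper name and the generic nn.Module argument are illustrative assumptions, not code from the repository.

    import torch
    import torch.nn as nn

    def cast_fp8_skip_patch_embed(model: nn.Module) -> nn.Module:
        """Cast weights to float8_e4m3fn, leaving patch_embed parameters untouched."""
        for name, param in model.named_parameters():
            # Keep the patch embedding in its original precision; everything
            # else is stored as fp8 to reduce VRAM use.
            if "patch_embed" not in name:
                param.data = param.data.to(torch.float8_e4m3fn)
        return model

    # Hypothetical usage:
    # transformer = cast_fp8_skip_patch_embed(transformer)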