diff --git a/examples/cogvideo_2b_vid2vid_test_example_02.json b/examples/cogvideo_2b_vid2vid_test_example_02.json
index 87e3a74..2f70df7 100644
--- a/examples/cogvideo_2b_vid2vid_test_example_02.json
+++ b/examples/cogvideo_2b_vid2vid_test_example_02.json
@@ -818,7 +818,7 @@
       "widgets_values": {
         "frame_rate": 8,
         "loop_count": 0,
-        "filename_prefix": "AnimateDiff",
+        "filename_prefix": "CogVideoX_vid2vid",
         "format": "video/nvenc_h264-mp4",
         "pix_fmt": "yuv420p",
         "bitrate": 10,
diff --git a/nodes.py b/nodes.py
index e431445..f38fee6 100644
--- a/nodes.py
+++ b/nodes.py
@@ -25,15 +25,11 @@ class DownloadAndLoadCogVideoModel:
             },
             "optional": {
-                "precision": (
-                    [
-                        "fp16",
-                        "fp32",
-                        "bf16",
-                    ],
-                    {"default": "bf16", "tooltip": "official recommendation is that 2b model should be fp16, 5b model should be bf16"},
+                "precision": (["fp16", "fp32", "bf16"],
+                    {"default": "bf16", "tooltip": "official recommendation is that 2b model should be fp16, 5b model should be bf16"}
                 ),
-            },
+                "fp8_transformer": ("BOOLEAN", {"default": False, "tooltip": "cast the transformer to torch.float8_e4m3fn"}),
+            }
         }
 
     RETURN_TYPES = ("COGVIDEOPIPE",)
@@ -41,12 +37,16 @@ class DownloadAndLoadCogVideoModel:
     FUNCTION = "loadmodel"
     CATEGORY = "CogVideoWrapper"
 
-    def loadmodel(self, model, precision):
+    def loadmodel(self, model, precision, fp8_transformer):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
         mm.soft_empty_cache()
 
         dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
+        if fp8_transformer:
+            transformer_dtype = torch.float8_e4m3fn
+        else:
+            transformer_dtype = dtype
 
         if "2b" in model:
             base_path = os.path.join(folder_paths.models_dir, "CogVideo", "CogVideo2B")
@@ -63,7 +63,7 @@ class DownloadAndLoadCogVideoModel:
                 local_dir=base_path,
                 local_dir_use_symlinks=False,
             )
-        transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder="transformer").to(dtype).to(offload_device)
+        transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder="transformer").to(transformer_dtype).to(offload_device)
         vae = AutoencoderKLCogVideoX.from_pretrained(base_path, subfolder="vae").to(dtype).to(offload_device)
 
         scheduler = CogVideoXDDIMScheduler.from_pretrained(base_path, subfolder="scheduler")
@@ -247,22 +247,22 @@ class CogVideoSampler:
             pipe.scheduler = CogVideoXDDIMScheduler.from_pretrained(base_path, subfolder="scheduler")
         elif scheduler == "DPM":
             pipe.scheduler = CogVideoXDPMScheduler.from_pretrained(base_path, subfolder="scheduler")
-
-        latents = pipeline["pipe"](
-            num_inference_steps=steps,
-            height = height,
-            width = width,
-            num_frames = num_frames,
-            t_tile_length = t_tile_length,
-            t_tile_overlap = t_tile_overlap,
-            guidance_scale=cfg,
-            latents=samples["samples"] if samples is not None else None,
-            denoise_strength=denoise_strength,
-            prompt_embeds=positive.to(dtype).to(device),
-            negative_prompt_embeds=negative.to(dtype).to(device),
-            generator=generator,
-            device=device
-        )
+        with torch.autocast(mm.get_autocast_device(device)):
+            latents = pipeline["pipe"](
+                num_inference_steps=steps,
+                height = height,
+                width = width,
+                num_frames = num_frames,
+                t_tile_length = t_tile_length,
+                t_tile_overlap = t_tile_overlap,
+                guidance_scale=cfg,
+                latents=samples["samples"] if samples is not None else None,
+                denoise_strength=denoise_strength,
+                prompt_embeds=positive.to(dtype).to(device),
+                negative_prompt_embeds=negative.to(dtype).to(device),
+                generator=generator,
+                device=device
+            )
         pipe.transformer.to(offload_device)
         mm.soft_empty_cache()
         print(latents.shape)
@@ -297,7 +297,7 @@ class CogVideoDecode:
             tile_overlap_factor_height=1 / 12,
             tile_overlap_factor_width=1 / 12,
         )
-
+        latents = latents.to(vae.dtype)
         latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
         latents = 1 / vae.config.scaling_factor * latents
diff --git a/pipeline_cogvideox.py b/pipeline_cogvideox.py
index 496a3c3..e46759d 100644
--- a/pipeline_cogvideox.py
+++ b/pipeline_cogvideox.py
@@ -170,13 +170,13 @@ class CogVideoXPipeline(DiffusionPipeline):
         )
 
         if latents is None:
-            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+            latents = randn_tensor(shape, generator=generator, device=device, dtype=self.vae.dtype)
         else:
             latents = latents.to(device)
 
         timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, denoise_strength, device)
         latent_timestep = timesteps[:1]
 
-        noise = randn_tensor(shape, generator=generator, device=device, dtype=latents.dtype)
+        noise = randn_tensor(shape, generator=generator, device=device, dtype=self.vae.dtype)
         frames_needed = noise.shape[1]
         current_frames = latents.shape[1]
@@ -400,6 +400,7 @@ class CogVideoXPipeline(DiffusionPipeline):
         )
         if do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+        prompt_embeds = prompt_embeds.to(self.transformer.dtype)
 
         # 4. Prepare timesteps
         timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
@@ -425,6 +426,7 @@ class CogVideoXPipeline(DiffusionPipeline):
             num_inference_steps,
             latents
         )
+        latents = latents.to(self.transformer.dtype)
 
         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
@@ -551,18 +553,15 @@ class CogVideoXPipeline(DiffusionPipeline):
                     noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
 
                 # compute the previous noisy sample x_t -> x_t-1
-                if not isinstance(self.scheduler, CogVideoXDPMScheduler):
-                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
-                else:
-                    latents, old_pred_original_sample = self.scheduler.step(
-                        noise_pred,
-                        old_pred_original_sample,
-                        t,
-                        timesteps[i - 1] if i > 0 else None,
-                        latents,
-                        **extra_step_kwargs,
-                        return_dict=False,
-                    )
+                latents, old_pred_original_sample = self.scheduler.step(
+                    noise_pred,
+                    old_pred_original_sample,
+                    t,
+                    timesteps[i - 1] if i > 0 else None,
+                    latents.to(self.vae.dtype),
+                    **extra_step_kwargs,
+                    return_dict=False,
+                )
                 latents = latents.to(prompt_embeds.dtype)
 
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
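
For context on the loader change above, here is a minimal sketch (not part of the patch; pick_transformer_dtype is a hypothetical helper, and float8_e4m3fn assumes torch >= 2.1) of the dtype selection that the new fp8_transformer flag performs:

    import torch

    def pick_transformer_dtype(precision: str, fp8_transformer: bool) -> torch.dtype:
        # Mirrors the logic added to DownloadAndLoadCogVideoModel.loadmodel: weights keep
        # the chosen precision unless the fp8 flag is set, in which case the transformer
        # is cast to float8_e4m3fn and the sampler's torch.autocast context handles the
        # compute dtype during the matmuls.
        dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
        return torch.float8_e4m3fn if fp8_transformer else dtype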