From ea0273c8ecc01b2d5216810cd2773854507b251f Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Mon, 11 Nov 2024 22:34:43 +0200
Subject: [PATCH] VAE fix, allow using fp32 VAE

---
 model_loading.py      | 40 ++++++++++++++++++++++++++++++++++++++++
 nodes.py              | 25 +++++++++++++++++++------
 pipeline_cogvideox.py | 12 ++++++++++--
 3 files changed, 69 insertions(+), 8 deletions(-)

diff --git a/model_loading.py b/model_loading.py
index 72af2a6..7c79d71 100644
--- a/model_loading.py
+++ b/model_loading.py
@@ -535,6 +535,44 @@ class DownloadAndLoadCogVideoGGUFModel:
         }
         return (pipeline,)
 
+
+#region VAE
+
+class CogVideoXVAELoader:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "model_name": (folder_paths.get_filename_list("vae"), {"tooltip": "The name of the checkpoint (vae) to load."}),
+            },
+            "optional": {
+                "precision": (["fp16", "fp32", "bf16"],
+                    {"default": "bf16"}
+                ),
+            }
+        }
+
+    RETURN_TYPES = ("VAE",)
+    RETURN_NAMES = ("vae", )
+    FUNCTION = "loadmodel"
+    CATEGORY = "CogVideoWrapper"
+    DESCRIPTION = "Loads CogVideoX VAE model from 'ComfyUI/models/vae'"
+
+    def loadmodel(self, model_name, precision):
+        device = mm.get_torch_device()
+        offload_device = mm.unet_offload_device()
+
+        dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
+        with open(os.path.join(script_directory, 'configs', 'vae_config.json')) as f:
+            vae_config = json.load(f)
+        model_path = folder_paths.get_full_path("vae", model_name)
+        vae_sd = load_torch_file(model_path)
+
+        vae = AutoencoderKLCogVideoX.from_config(vae_config).to(dtype).to(offload_device)
+        vae.load_state_dict(vae_sd)
+
+        return (vae,)
+
 #region Tora
 class DownloadAndLoadToraModel:
     @classmethod
@@ -698,6 +736,7 @@ NODE_CLASS_MAPPINGS = {
     "DownloadAndLoadCogVideoControlNet": DownloadAndLoadCogVideoControlNet,
     "DownloadAndLoadToraModel": DownloadAndLoadToraModel,
     "CogVideoLoraSelect": CogVideoLoraSelect,
+    "CogVideoXVAELoader": CogVideoXVAELoader,
     }
 NODE_DISPLAY_NAME_MAPPINGS = {
     "DownloadAndLoadCogVideoModel": "(Down)load CogVideo Model",
@@ -705,4 +744,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
     "DownloadAndLoadCogVideoControlNet": "(Down)load CogVideo ControlNet",
     "DownloadAndLoadToraModel": "(Down)load Tora Model",
     "CogVideoLoraSelect": "CogVideo LoraSelect",
+    "CogVideoXVAELoader": "CogVideoX VAE Loader",
     }
\ No newline at end of file
diff --git a/nodes.py b/nodes.py
index df73ae6..3d83bc9 100644
--- a/nodes.py
+++ b/nodes.py
@@ -350,6 +350,7 @@ class CogVideoImageEncode:
                 "enable_tiling": ("BOOLEAN", {"default": False, "tooltip": "Enable tiling for the VAE to reduce memory usage"}),
                 "mask": ("MASK", ),
                 "noise_aug_strength": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001, "tooltip": "Augment image with noise"}),
+                "vae_override" : ("VAE", {"default": None, "tooltip": "Override the VAE model in the pipeline"}),
             },
         }
 
@@ -358,15 +359,21 @@
     FUNCTION = "encode"
     CATEGORY = "CogVideoWrapper"
 
-    def encode(self, pipeline, image, chunk_size=8, enable_tiling=False, mask=None, noise_aug_strength=0.0):
+    def encode(self, pipeline, image, chunk_size=8, enable_tiling=False, mask=None, noise_aug_strength=0.0, vae_override=None):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
         generator = torch.Generator(device=device).manual_seed(0)
 
         B, H, W, C = image.shape
 
-        vae = pipeline["pipe"].vae
+        vae = pipeline["pipe"].vae if vae_override is None else vae_override
         vae.enable_slicing()
+        model_name = pipeline.get("model_name", "")
+
+        if "1.5" in model_name or "1_5" in model_name:
+            vae_scaling_factor = 1 / vae.config.scaling_factor
+        else:
+            vae_scaling_factor = vae.config.scaling_factor
 
         if enable_tiling:
             from .mz_enable_vae_encode_tiling import enable_vae_encode_tiling
@@ -391,10 +398,14 @@
             # input_image = input_image * (1 -mask)
         else:
             pipeline["pipe"].original_mask = None
-
+        #input_image = input_image.permute(0, 3, 1, 2) # B, C, H, W
+        #input_image = pipeline["pipe"].video_processor.preprocess(input_image).to(device, dtype=vae.dtype)
+        #input_image = input_image.unsqueeze(2)
+        input_image = input_image * 2.0 - 1.0
         input_image = input_image.to(vae.dtype).to(device)
         input_image = input_image.unsqueeze(0).permute(0, 4, 1, 2, 3) # B, C, T, H, W
+        B, C, T, H, W = input_image.shape
 
         if noise_aug_strength > 0:
             input_image = add_noise_to_reference_video(input_image, ratio=noise_aug_strength)
@@ -417,7 +428,7 @@
                 elif hasattr(latents, "latents"):
                     latents = latents.latents
 
-                latents = vae.config.scaling_factor * latents
+                latents = vae_scaling_factor * latents
                 latents = latents.permute(0, 2, 1, 3, 4) # B, T_chunk, C, H, W
                 latents_list.append(latents)
 
@@ -972,6 +983,7 @@
                 "tile_overlap_factor_height": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.001}),
                 "tile_overlap_factor_width": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.001}),
                 "auto_tile_size": ("BOOLEAN", {"default": True, "tooltip": "Auto size based on height and width, default is half the size"}),
+                "vae_override": ("VAE", {"default": None}),
             }
         }
 
@@ -980,11 +992,12 @@
     FUNCTION = "decode"
     CATEGORY = "CogVideoWrapper"
 
-    def decode(self, pipeline, samples, enable_vae_tiling, tile_sample_min_height, tile_sample_min_width, tile_overlap_factor_height, tile_overlap_factor_width, auto_tile_size=True):
+    def decode(self, pipeline, samples, enable_vae_tiling, tile_sample_min_height, tile_sample_min_width, tile_overlap_factor_height, tile_overlap_factor_width,
+               auto_tile_size=True, vae_override=None):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
         latents = samples["samples"]
-        vae = pipeline["pipe"].vae
+        vae = pipeline["pipe"].vae if vae_override is None else vae_override
 
         vae.enable_slicing()
 
diff --git a/pipeline_cogvideox.py b/pipeline_cogvideox.py
index 466eecb..694a85e 100644
--- a/pipeline_cogvideox.py
+++ b/pipeline_cogvideox.py
@@ -159,15 +159,17 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
         )
         self.vae_scale_factor_temporal = (
             self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
-        )
+        )
         self.original_mask = original_mask
         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+        self.video_processor.config.do_resize = False
 
         if pab_config is not None:
             set_pab_manager(pab_config)
 
         self.input_with_padding = True
 
+
     def prepare_latents(
         self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, timesteps, denoise_strength,
         num_inference_steps, latents=None, freenoise=True, context_size=None, context_overlap=None
@@ -625,6 +627,9 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
 
         logger.info(f"Sampling {num_frames} frames in {latent_frames} latent frames at {width}x{height} with {num_inference_steps} inference steps")
 
+        from .latent_preview import prepare_callback
+        callback = prepare_callback(self.transformer, num_inference_steps)
+
         # 9. Denoising loop
         comfy_pbar = ProgressBar(len(timesteps))
         with self.progress_bar(total=len(timesteps)) as progress_bar:
@@ -926,7 +931,10 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
 
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
-                    comfy_pbar.update(1)
+                    if callback is not None:
+                        callback(i, latents.detach()[-1], None, num_inference_steps)
+                    else:
+                        comfy_pbar.update(1)
 
         # Offload all models