VAE fix, allow using fp32 VAE

This commit is contained in:
kijai 2024-11-11 22:34:43 +02:00
parent 6d4c99e77d
commit ea0273c8ec
3 changed files with 69 additions and 8 deletions

View File

@ -535,6 +535,44 @@ class DownloadAndLoadCogVideoGGUFModel:
} }
return (pipeline,) return (pipeline,)
#revion VAE
class CogVideoXVAELoader:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model_name": (folder_paths.get_filename_list("vae"), {"tooltip": "The name of the checkpoint (vae) to load."}),
},
"optional": {
"precision": (["fp16", "fp32", "bf16"],
{"default": "bf16"}
),
}
}
RETURN_TYPES = ("VAE",)
RETURN_NAMES = ("vae", )
FUNCTION = "loadmodel"
CATEGORY = "CogVideoWrapper"
DESCRIPTION = "Loads CogVideoX VAE model from 'ComfyUI/models/vae'"
def loadmodel(self, model_name, precision):
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
with open(os.path.join(script_directory, 'configs', 'vae_config.json')) as f:
vae_config = json.load(f)
model_path = folder_paths.get_full_path("vae", model_name)
vae_sd = load_torch_file(model_path)
vae = AutoencoderKLCogVideoX.from_config(vae_config).to(dtype).to(offload_device)
vae.load_state_dict(vae_sd)
return (vae,)
#region Tora #region Tora
class DownloadAndLoadToraModel: class DownloadAndLoadToraModel:
@classmethod @classmethod
@ -698,6 +736,7 @@ NODE_CLASS_MAPPINGS = {
"DownloadAndLoadCogVideoControlNet": DownloadAndLoadCogVideoControlNet, "DownloadAndLoadCogVideoControlNet": DownloadAndLoadCogVideoControlNet,
"DownloadAndLoadToraModel": DownloadAndLoadToraModel, "DownloadAndLoadToraModel": DownloadAndLoadToraModel,
"CogVideoLoraSelect": CogVideoLoraSelect, "CogVideoLoraSelect": CogVideoLoraSelect,
"CogVideoXVAELoader": CogVideoXVAELoader,
} }
NODE_DISPLAY_NAME_MAPPINGS = { NODE_DISPLAY_NAME_MAPPINGS = {
"DownloadAndLoadCogVideoModel": "(Down)load CogVideo Model", "DownloadAndLoadCogVideoModel": "(Down)load CogVideo Model",
@ -705,4 +744,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"DownloadAndLoadCogVideoControlNet": "(Down)load CogVideo ControlNet", "DownloadAndLoadCogVideoControlNet": "(Down)load CogVideo ControlNet",
"DownloadAndLoadToraModel": "(Down)load Tora Model", "DownloadAndLoadToraModel": "(Down)load Tora Model",
"CogVideoLoraSelect": "CogVideo LoraSelect", "CogVideoLoraSelect": "CogVideo LoraSelect",
"CogVideoXVAELoader": "CogVideoX VAE Loader",
} }

View File

@ -350,6 +350,7 @@ class CogVideoImageEncode:
"enable_tiling": ("BOOLEAN", {"default": False, "tooltip": "Enable tiling for the VAE to reduce memory usage"}), "enable_tiling": ("BOOLEAN", {"default": False, "tooltip": "Enable tiling for the VAE to reduce memory usage"}),
"mask": ("MASK", ), "mask": ("MASK", ),
"noise_aug_strength": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001, "tooltip": "Augment image with noise"}), "noise_aug_strength": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001, "tooltip": "Augment image with noise"}),
"vae_override" : ("VAE", {"default": None, "tooltip": "Override the VAE model in the pipeline"}),
}, },
} }
@ -358,15 +359,21 @@ class CogVideoImageEncode:
FUNCTION = "encode" FUNCTION = "encode"
CATEGORY = "CogVideoWrapper" CATEGORY = "CogVideoWrapper"
def encode(self, pipeline, image, chunk_size=8, enable_tiling=False, mask=None, noise_aug_strength=0.0): def encode(self, pipeline, image, chunk_size=8, enable_tiling=False, mask=None, noise_aug_strength=0.0, vae_override=None):
device = mm.get_torch_device() device = mm.get_torch_device()
offload_device = mm.unet_offload_device() offload_device = mm.unet_offload_device()
generator = torch.Generator(device=device).manual_seed(0) generator = torch.Generator(device=device).manual_seed(0)
B, H, W, C = image.shape B, H, W, C = image.shape
vae = pipeline["pipe"].vae vae = pipeline["pipe"].vae if vae_override is None else vae_override
vae.enable_slicing() vae.enable_slicing()
model_name = pipeline.get("model_name", "")
if "1.5" in model_name or "1_5" in model_name:
vae_scaling_factor = 1 / vae.config.scaling_factor
else:
vae_scaling_factor = vae.config.scaling_factor
if enable_tiling: if enable_tiling:
from .mz_enable_vae_encode_tiling import enable_vae_encode_tiling from .mz_enable_vae_encode_tiling import enable_vae_encode_tiling
@ -391,10 +398,14 @@ class CogVideoImageEncode:
# input_image = input_image * (1 -mask) # input_image = input_image * (1 -mask)
else: else:
pipeline["pipe"].original_mask = None pipeline["pipe"].original_mask = None
#input_image = input_image.permute(0, 3, 1, 2) # B, C, H, W
#input_image = pipeline["pipe"].video_processor.preprocess(input_image).to(device, dtype=vae.dtype)
#input_image = input_image.unsqueeze(2)
input_image = input_image * 2.0 - 1.0 input_image = input_image * 2.0 - 1.0
input_image = input_image.to(vae.dtype).to(device) input_image = input_image.to(vae.dtype).to(device)
input_image = input_image.unsqueeze(0).permute(0, 4, 1, 2, 3) # B, C, T, H, W input_image = input_image.unsqueeze(0).permute(0, 4, 1, 2, 3) # B, C, T, H, W
B, C, T, H, W = input_image.shape B, C, T, H, W = input_image.shape
if noise_aug_strength > 0: if noise_aug_strength > 0:
input_image = add_noise_to_reference_video(input_image, ratio=noise_aug_strength) input_image = add_noise_to_reference_video(input_image, ratio=noise_aug_strength)
@ -417,7 +428,7 @@ class CogVideoImageEncode:
elif hasattr(latents, "latents"): elif hasattr(latents, "latents"):
latents = latents.latents latents = latents.latents
latents = vae.config.scaling_factor * latents latents = vae_scaling_factor * latents
latents = latents.permute(0, 2, 1, 3, 4) # B, T_chunk, C, H, W latents = latents.permute(0, 2, 1, 3, 4) # B, T_chunk, C, H, W
latents_list.append(latents) latents_list.append(latents)
@ -972,6 +983,7 @@ class CogVideoDecode:
"tile_overlap_factor_height": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.001}), "tile_overlap_factor_height": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.001}),
"tile_overlap_factor_width": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.001}), "tile_overlap_factor_width": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.001}),
"auto_tile_size": ("BOOLEAN", {"default": True, "tooltip": "Auto size based on height and width, default is half the size"}), "auto_tile_size": ("BOOLEAN", {"default": True, "tooltip": "Auto size based on height and width, default is half the size"}),
"vae_override": ("VAE", {"default": None}),
} }
} }
@ -980,11 +992,12 @@ class CogVideoDecode:
FUNCTION = "decode" FUNCTION = "decode"
CATEGORY = "CogVideoWrapper" CATEGORY = "CogVideoWrapper"
def decode(self, pipeline, samples, enable_vae_tiling, tile_sample_min_height, tile_sample_min_width, tile_overlap_factor_height, tile_overlap_factor_width, auto_tile_size=True): def decode(self, pipeline, samples, enable_vae_tiling, tile_sample_min_height, tile_sample_min_width, tile_overlap_factor_height, tile_overlap_factor_width,
auto_tile_size=True, vae_override=None):
device = mm.get_torch_device() device = mm.get_torch_device()
offload_device = mm.unet_offload_device() offload_device = mm.unet_offload_device()
latents = samples["samples"] latents = samples["samples"]
vae = pipeline["pipe"].vae vae = pipeline["pipe"].vae if vae_override is None else vae_override
vae.enable_slicing() vae.enable_slicing()

View File

@ -159,15 +159,17 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
) )
self.vae_scale_factor_temporal = ( self.vae_scale_factor_temporal = (
self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4 self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
) )
self.original_mask = original_mask self.original_mask = original_mask
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
self.video_processor.config.do_resize = False
if pab_config is not None: if pab_config is not None:
set_pab_manager(pab_config) set_pab_manager(pab_config)
self.input_with_padding = True self.input_with_padding = True
def prepare_latents( def prepare_latents(
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, timesteps, denoise_strength, self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, timesteps, denoise_strength,
num_inference_steps, latents=None, freenoise=True, context_size=None, context_overlap=None num_inference_steps, latents=None, freenoise=True, context_size=None, context_overlap=None
@ -625,6 +627,9 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
logger.info(f"Sampling {num_frames} frames in {latent_frames} latent frames at {width}x{height} with {num_inference_steps} inference steps") logger.info(f"Sampling {num_frames} frames in {latent_frames} latent frames at {width}x{height} with {num_inference_steps} inference steps")
from .latent_preview import prepare_callback
callback = prepare_callback(self.transformer, num_inference_steps)
# 9. Denoising loop # 9. Denoising loop
comfy_pbar = ProgressBar(len(timesteps)) comfy_pbar = ProgressBar(len(timesteps))
with self.progress_bar(total=len(timesteps)) as progress_bar: with self.progress_bar(total=len(timesteps)) as progress_bar:
@ -926,7 +931,10 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update() progress_bar.update()
comfy_pbar.update(1) if callback is not None:
callback(i, latents.detach()[-1], None, num_inference_steps)
else:
comfy_pbar.update(1)
# Offload all models # Offload all models