Mirror of https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
VAE fix, allow using fp32 VAE
Commit ea0273c8ec (parent 6d4c99e77d)
@@ -535,6 +535,44 @@ class DownloadAndLoadCogVideoGGUFModel:
         }
         return (pipeline,)
 
+#region VAE
+
+class CogVideoXVAELoader:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "model_name": (folder_paths.get_filename_list("vae"), {"tooltip": "The name of the checkpoint (vae) to load."}),
+            },
+            "optional": {
+                "precision": (["fp16", "fp32", "bf16"],
+                    {"default": "bf16"}
+                ),
+            }
+        }
+
+    RETURN_TYPES = ("VAE",)
+    RETURN_NAMES = ("vae", )
+    FUNCTION = "loadmodel"
+    CATEGORY = "CogVideoWrapper"
+    DESCRIPTION = "Loads CogVideoX VAE model from 'ComfyUI/models/vae'"
+
+    def loadmodel(self, model_name, precision):
+        device = mm.get_torch_device()
+        offload_device = mm.unet_offload_device()
+
+        dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
+        with open(os.path.join(script_directory, 'configs', 'vae_config.json')) as f:
+            vae_config = json.load(f)
+        model_path = folder_paths.get_full_path("vae", model_name)
+        vae_sd = load_torch_file(model_path)
+
+        vae = AutoencoderKLCogVideoX.from_config(vae_config).to(dtype).to(offload_device)
+        vae.load_state_dict(vae_sd)
+
+        return (vae,)
+
 #region Tora
 class DownloadAndLoadToraModel:
     @classmethod
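The new loader returns a plain diffusers AutoencoderKLCogVideoX cast to the requested precision, so the fp32 path can be sanity-checked by calling the node class directly. A minimal sketch, assuming a running ComfyUI environment where the class above is importable; the checkpoint filename is a placeholder, not a file referenced by this commit:

    # Hypothetical direct call of the new node; "cogvideox_vae.safetensors" is a
    # made-up filename that would have to exist in ComfyUI/models/vae.
    import torch

    loader = CogVideoXVAELoader()
    (vae,) = loader.loadmodel(model_name="cogvideox_vae.safetensors", precision="fp32")
    assert vae.dtype == torch.float32  # full-precision weights, the point of this commit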
@@ -698,6 +736,7 @@ NODE_CLASS_MAPPINGS = {
     "DownloadAndLoadCogVideoControlNet": DownloadAndLoadCogVideoControlNet,
     "DownloadAndLoadToraModel": DownloadAndLoadToraModel,
     "CogVideoLoraSelect": CogVideoLoraSelect,
+    "CogVideoXVAELoader": CogVideoXVAELoader,
 }
 NODE_DISPLAY_NAME_MAPPINGS = {
     "DownloadAndLoadCogVideoModel": "(Down)load CogVideo Model",
@@ -705,4 +744,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
     "DownloadAndLoadCogVideoControlNet": "(Down)load CogVideo ControlNet",
     "DownloadAndLoadToraModel": "(Down)load Tora Model",
     "CogVideoLoraSelect": "CogVideo LoraSelect",
+    "CogVideoXVAELoader": "CogVideoX VAE Loader",
 }
nodes.py (25 lines changed)
@@ -350,6 +350,7 @@ class CogVideoImageEncode:
                 "enable_tiling": ("BOOLEAN", {"default": False, "tooltip": "Enable tiling for the VAE to reduce memory usage"}),
                 "mask": ("MASK", ),
                 "noise_aug_strength": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001, "tooltip": "Augment image with noise"}),
+                "vae_override": ("VAE", {"default": None, "tooltip": "Override the VAE model in the pipeline"}),
             },
         }
 
@@ -358,15 +359,21 @@ class CogVideoImageEncode:
     FUNCTION = "encode"
     CATEGORY = "CogVideoWrapper"
 
-    def encode(self, pipeline, image, chunk_size=8, enable_tiling=False, mask=None, noise_aug_strength=0.0):
+    def encode(self, pipeline, image, chunk_size=8, enable_tiling=False, mask=None, noise_aug_strength=0.0, vae_override=None):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
         generator = torch.Generator(device=device).manual_seed(0)
 
         B, H, W, C = image.shape
 
-        vae = pipeline["pipe"].vae
+        vae = pipeline["pipe"].vae if vae_override is None else vae_override
         vae.enable_slicing()
+        model_name = pipeline.get("model_name", "")
+
+        if "1.5" in model_name or "1_5" in model_name:
+            vae_scaling_factor = 1 / vae.config.scaling_factor
+        else:
+            vae_scaling_factor = vae.config.scaling_factor
 
         if enable_tiling:
             from .mz_enable_vae_encode_tiling import enable_vae_encode_tiling
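For 1.5-series model names the encode path now multiplies latents by the inverse of the configured factor rather than the factor itself (see the vae_scaling_factor use two hunks below). A stand-alone sketch of that selection logic; the 0.7 config value is an arbitrary placeholder, not the real CogVideoX VAE setting:

    # Sketch of the vae_scaling_factor selection added to encode().
    # scaling_factor = 0.7 is a made-up number purely for illustration.
    class _FakeVAEConfig:
        scaling_factor = 0.7

    def pick_scaling_factor(model_name: str, config=_FakeVAEConfig) -> float:
        # CogVideoX 1.5 style names ("1.5" or "1_5") get the inverse factor.
        if "1.5" in model_name or "1_5" in model_name:
            return 1 / config.scaling_factor
        return config.scaling_factor

    print(pick_scaling_factor("CogVideoX1.5-5B-I2V"))  # ~1.43 (inverse)
    print(pick_scaling_factor("CogVideoX-5b-I2V"))     # 0.7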
@@ -391,10 +398,14 @@ class CogVideoImageEncode:
             # input_image = input_image * (1 -mask)
         else:
             pipeline["pipe"].original_mask = None
+        #input_image = input_image.permute(0, 3, 1, 2) # B, C, H, W
+        #input_image = pipeline["pipe"].video_processor.preprocess(input_image).to(device, dtype=vae.dtype)
+        #input_image = input_image.unsqueeze(2)
+
         input_image = input_image * 2.0 - 1.0
         input_image = input_image.to(vae.dtype).to(device)
         input_image = input_image.unsqueeze(0).permute(0, 4, 1, 2, 3) # B, C, T, H, W
 
         B, C, T, H, W = input_image.shape
         if noise_aug_strength > 0:
             input_image = add_noise_to_reference_video(input_image, ratio=noise_aug_strength)
@@ -417,7 +428,7 @@ class CogVideoImageEncode:
             elif hasattr(latents, "latents"):
                 latents = latents.latents
 
-            latents = vae.config.scaling_factor * latents
+            latents = vae_scaling_factor * latents
             latents = latents.permute(0, 2, 1, 3, 4) # B, T_chunk, C, H, W
             latents_list.append(latents)
 
@@ -972,6 +983,7 @@ class CogVideoDecode:
                 "tile_overlap_factor_height": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.001}),
                 "tile_overlap_factor_width": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.001}),
                 "auto_tile_size": ("BOOLEAN", {"default": True, "tooltip": "Auto size based on height and width, default is half the size"}),
+                "vae_override": ("VAE", {"default": None}),
             }
         }
 
@@ -980,11 +992,12 @@ class CogVideoDecode:
     FUNCTION = "decode"
     CATEGORY = "CogVideoWrapper"
 
-    def decode(self, pipeline, samples, enable_vae_tiling, tile_sample_min_height, tile_sample_min_width, tile_overlap_factor_height, tile_overlap_factor_width, auto_tile_size=True):
+    def decode(self, pipeline, samples, enable_vae_tiling, tile_sample_min_height, tile_sample_min_width, tile_overlap_factor_height, tile_overlap_factor_width,
+               auto_tile_size=True, vae_override=None):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
         latents = samples["samples"]
-        vae = pipeline["pipe"].vae
+        vae = pipeline["pipe"].vae if vae_override is None else vae_override
 
         vae.enable_slicing()
 
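With vae_override exposed on both CogVideoImageEncode and CogVideoDecode, a VAE loaded in fp32 by the new loader can replace the pipeline's own VAE at both ends of the workflow. A minimal sketch of that wiring with the node classes called directly; pipeline, image_batch and latent_samples stand in for outputs of other nodes, and the tile sizes are example values, not defaults taken from this commit:

    # Hypothetical wiring of the new vae_override inputs; placeholder objects throughout.
    (fp32_vae,) = CogVideoXVAELoader().loadmodel("cogvideox_vae.safetensors", precision="fp32")

    # Encode with the override instead of pipeline["pipe"].vae
    latent_out = CogVideoImageEncode().encode(pipeline, image_batch, vae_override=fp32_vae)

    # Decode the sampled latents with the same fp32 VAE
    frames = CogVideoDecode().decode(
        pipeline, latent_samples,
        enable_vae_tiling=True,
        tile_sample_min_height=240, tile_sample_min_width=360,
        tile_overlap_factor_height=0.2, tile_overlap_factor_width=0.2,
        vae_override=fp32_vae,
    )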
@@ -159,15 +159,17 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
         )
         self.vae_scale_factor_temporal = (
             self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
         )
         self.original_mask = original_mask
         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+        self.video_processor.config.do_resize = False
 
         if pab_config is not None:
             set_pab_manager(pab_config)
 
         self.input_with_padding = True
 
+
     def prepare_latents(
         self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, timesteps, denoise_strength,
         num_inference_steps, latents=None, freenoise=True, context_size=None, context_overlap=None
@@ -625,6 +627,9 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
 
         logger.info(f"Sampling {num_frames} frames in {latent_frames} latent frames at {width}x{height} with {num_inference_steps} inference steps")
 
+        from .latent_preview import prepare_callback
+        callback = prepare_callback(self.transformer, num_inference_steps)
+
         # 9. Denoising loop
         comfy_pbar = ProgressBar(len(timesteps))
         with self.progress_bar(total=len(timesteps)) as progress_bar:
@@ -926,7 +931,10 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
 
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
-                    comfy_pbar.update(1)
+                    if callback is not None:
+                        callback(i, latents.detach()[-1], None, num_inference_steps)
+                    else:
+                        comfy_pbar.update(1)
 
 
         # Offload all models
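The denoising loop now reports progress through the latent preview callback when prepare_callback returns one, and falls back to the plain ComfyUI progress bar otherwise. A sketch of a callback with the call shape used above; the body is illustrative, not what latent_preview.prepare_callback actually builds, and the third argument is simply passed as None by the loop:

    # Illustrative callback matching the call made in the loop:
    #   callback(i, latents.detach()[-1], None, num_inference_steps)
    def example_callback(step, latent, _unused, total_steps):
        # latent is the last sample of the current latent batch (still a torch tensor)
        print(f"step {step + 1}/{total_steps}, latent shape {tuple(latent.shape)}")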
|
|||||||
Loading…
x
Reference in New Issue
Block a user