mirror of
https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
synced 2025-12-08 20:34:23 +08:00
VAE fix, allow using fp32 VAE
This commit is contained in:
parent
6d4c99e77d
commit
ea0273c8ec
@ -535,6 +535,44 @@ class DownloadAndLoadCogVideoGGUFModel:
|
||||
}
|
||||
|
||||
return (pipeline,)
|
||||
|
||||
#revion VAE
|
||||
|
||||
class CogVideoXVAELoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"model_name": (folder_paths.get_filename_list("vae"), {"tooltip": "The name of the checkpoint (vae) to load."}),
|
||||
},
|
||||
"optional": {
|
||||
"precision": (["fp16", "fp32", "bf16"],
|
||||
{"default": "bf16"}
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("VAE",)
|
||||
RETURN_NAMES = ("vae", )
|
||||
FUNCTION = "loadmodel"
|
||||
CATEGORY = "CogVideoWrapper"
|
||||
DESCRIPTION = "Loads CogVideoX VAE model from 'ComfyUI/models/vae'"
|
||||
|
||||
def loadmodel(self, model_name, precision):
|
||||
device = mm.get_torch_device()
|
||||
offload_device = mm.unet_offload_device()
|
||||
|
||||
dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
|
||||
with open(os.path.join(script_directory, 'configs', 'vae_config.json')) as f:
|
||||
vae_config = json.load(f)
|
||||
model_path = folder_paths.get_full_path("vae", model_name)
|
||||
vae_sd = load_torch_file(model_path)
|
||||
|
||||
vae = AutoencoderKLCogVideoX.from_config(vae_config).to(dtype).to(offload_device)
|
||||
vae.load_state_dict(vae_sd)
|
||||
|
||||
return (vae,)
|
||||
|
||||
#region Tora
|
||||
class DownloadAndLoadToraModel:
|
||||
@classmethod
|
||||
@ -698,6 +736,7 @@ NODE_CLASS_MAPPINGS = {
|
||||
"DownloadAndLoadCogVideoControlNet": DownloadAndLoadCogVideoControlNet,
|
||||
"DownloadAndLoadToraModel": DownloadAndLoadToraModel,
|
||||
"CogVideoLoraSelect": CogVideoLoraSelect,
|
||||
"CogVideoXVAELoader": CogVideoXVAELoader,
|
||||
}
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"DownloadAndLoadCogVideoModel": "(Down)load CogVideo Model",
|
||||
@ -705,4 +744,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"DownloadAndLoadCogVideoControlNet": "(Down)load CogVideo ControlNet",
|
||||
"DownloadAndLoadToraModel": "(Down)load Tora Model",
|
||||
"CogVideoLoraSelect": "CogVideo LoraSelect",
|
||||
"CogVideoXVAELoader": "CogVideoX VAE Loader",
|
||||
}
|
||||
23
nodes.py
23
nodes.py
@ -350,6 +350,7 @@ class CogVideoImageEncode:
|
||||
"enable_tiling": ("BOOLEAN", {"default": False, "tooltip": "Enable tiling for the VAE to reduce memory usage"}),
|
||||
"mask": ("MASK", ),
|
||||
"noise_aug_strength": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001, "tooltip": "Augment image with noise"}),
|
||||
"vae_override" : ("VAE", {"default": None, "tooltip": "Override the VAE model in the pipeline"}),
|
||||
},
|
||||
}
|
||||
|
||||
@ -358,15 +359,21 @@ class CogVideoImageEncode:
|
||||
FUNCTION = "encode"
|
||||
CATEGORY = "CogVideoWrapper"
|
||||
|
||||
def encode(self, pipeline, image, chunk_size=8, enable_tiling=False, mask=None, noise_aug_strength=0.0):
|
||||
def encode(self, pipeline, image, chunk_size=8, enable_tiling=False, mask=None, noise_aug_strength=0.0, vae_override=None):
|
||||
device = mm.get_torch_device()
|
||||
offload_device = mm.unet_offload_device()
|
||||
generator = torch.Generator(device=device).manual_seed(0)
|
||||
|
||||
B, H, W, C = image.shape
|
||||
|
||||
vae = pipeline["pipe"].vae
|
||||
vae = pipeline["pipe"].vae if vae_override is None else vae_override
|
||||
vae.enable_slicing()
|
||||
model_name = pipeline.get("model_name", "")
|
||||
|
||||
if "1.5" in model_name or "1_5" in model_name:
|
||||
vae_scaling_factor = 1 / vae.config.scaling_factor
|
||||
else:
|
||||
vae_scaling_factor = vae.config.scaling_factor
|
||||
|
||||
if enable_tiling:
|
||||
from .mz_enable_vae_encode_tiling import enable_vae_encode_tiling
|
||||
@ -391,10 +398,14 @@ class CogVideoImageEncode:
|
||||
# input_image = input_image * (1 -mask)
|
||||
else:
|
||||
pipeline["pipe"].original_mask = None
|
||||
#input_image = input_image.permute(0, 3, 1, 2) # B, C, H, W
|
||||
#input_image = pipeline["pipe"].video_processor.preprocess(input_image).to(device, dtype=vae.dtype)
|
||||
#input_image = input_image.unsqueeze(2)
|
||||
|
||||
input_image = input_image * 2.0 - 1.0
|
||||
input_image = input_image.to(vae.dtype).to(device)
|
||||
input_image = input_image.unsqueeze(0).permute(0, 4, 1, 2, 3) # B, C, T, H, W
|
||||
|
||||
B, C, T, H, W = input_image.shape
|
||||
if noise_aug_strength > 0:
|
||||
input_image = add_noise_to_reference_video(input_image, ratio=noise_aug_strength)
|
||||
@ -417,7 +428,7 @@ class CogVideoImageEncode:
|
||||
elif hasattr(latents, "latents"):
|
||||
latents = latents.latents
|
||||
|
||||
latents = vae.config.scaling_factor * latents
|
||||
latents = vae_scaling_factor * latents
|
||||
latents = latents.permute(0, 2, 1, 3, 4) # B, T_chunk, C, H, W
|
||||
latents_list.append(latents)
|
||||
|
||||
@ -972,6 +983,7 @@ class CogVideoDecode:
|
||||
"tile_overlap_factor_height": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||
"tile_overlap_factor_width": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||
"auto_tile_size": ("BOOLEAN", {"default": True, "tooltip": "Auto size based on height and width, default is half the size"}),
|
||||
"vae_override": ("VAE", {"default": None}),
|
||||
}
|
||||
}
|
||||
|
||||
@ -980,11 +992,12 @@ class CogVideoDecode:
|
||||
FUNCTION = "decode"
|
||||
CATEGORY = "CogVideoWrapper"
|
||||
|
||||
def decode(self, pipeline, samples, enable_vae_tiling, tile_sample_min_height, tile_sample_min_width, tile_overlap_factor_height, tile_overlap_factor_width, auto_tile_size=True):
|
||||
def decode(self, pipeline, samples, enable_vae_tiling, tile_sample_min_height, tile_sample_min_width, tile_overlap_factor_height, tile_overlap_factor_width,
|
||||
auto_tile_size=True, vae_override=None):
|
||||
device = mm.get_torch_device()
|
||||
offload_device = mm.unet_offload_device()
|
||||
latents = samples["samples"]
|
||||
vae = pipeline["pipe"].vae
|
||||
vae = pipeline["pipe"].vae if vae_override is None else vae_override
|
||||
|
||||
vae.enable_slicing()
|
||||
|
||||
|
||||
@ -162,12 +162,14 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
|
||||
)
|
||||
self.original_mask = original_mask
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
||||
self.video_processor.config.do_resize = False
|
||||
|
||||
if pab_config is not None:
|
||||
set_pab_manager(pab_config)
|
||||
|
||||
self.input_with_padding = True
|
||||
|
||||
|
||||
def prepare_latents(
|
||||
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, timesteps, denoise_strength,
|
||||
num_inference_steps, latents=None, freenoise=True, context_size=None, context_overlap=None
|
||||
@ -625,6 +627,9 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
|
||||
|
||||
logger.info(f"Sampling {num_frames} frames in {latent_frames} latent frames at {width}x{height} with {num_inference_steps} inference steps")
|
||||
|
||||
from .latent_preview import prepare_callback
|
||||
callback = prepare_callback(self.transformer, num_inference_steps)
|
||||
|
||||
# 9. Denoising loop
|
||||
comfy_pbar = ProgressBar(len(timesteps))
|
||||
with self.progress_bar(total=len(timesteps)) as progress_bar:
|
||||
@ -926,6 +931,9 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
|
||||
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if callback is not None:
|
||||
callback(i, latents.detach()[-1], None, num_inference_steps)
|
||||
else:
|
||||
comfy_pbar.update(1)
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user