diff --git a/__init__.py b/__init__.py index 07fac28..4bf3951 100644 --- a/__init__.py +++ b/__init__.py @@ -189,6 +189,7 @@ NODE_CONFIG = { "SkipLayerGuidanceWanVideo": {"class": SkipLayerGuidanceWanVideo, "name": "Skip Layer Guidance WanVideo"}, "TimerNodeKJ": {"class": TimerNodeKJ, "name": "Timer Node KJ"}, "HunyuanVideoEncodeKeyframesToCond": {"class": HunyuanVideoEncodeKeyframesToCond, "name": "HunyuanVideo Encode Keyframes To Cond"}, + "CFGZeroStar": {"class": CFGZeroStar, "name": "CFG Zero Star"}, #instance diffusion "CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking}, diff --git a/nodes/model_optimization_nodes.py b/nodes/model_optimization_nodes.py index ac7c296..2cf00e6 100644 --- a/nodes/model_optimization_nodes.py +++ b/nodes/model_optimization_nodes.py @@ -485,19 +485,13 @@ class TorchCompileModelWanVideo: def patch(self, model, backend, fullgraph, mode, dynamic, dynamo_cache_size_limit, compile_transformer_blocks_only): m = model.clone() diffusion_model = m.get_model_object("diffusion_model") - torch._dynamo.config.cache_size_limit = dynamo_cache_size_limit - is_compiled = hasattr(model.model.diffusion_model.blocks[0], "_orig_mod") - if is_compiled: - logging.info(f"Already compiled, not reapplying") - else: - logging.info(f"Not compiled, applying") + torch._dynamo.config.cache_size_limit = dynamo_cache_size_limit try: if compile_transformer_blocks_only: for i, block in enumerate(diffusion_model.blocks): - if is_compiled: - compiled_block = torch.compile(block._orig_mod, fullgraph=fullgraph, dynamic=dynamic, backend=backend, mode=mode) - else: - compiled_block = torch.compile(block, fullgraph=fullgraph, dynamic=dynamic, backend=backend, mode=mode) + if hasattr(block, "_orig_mod"): + block = block._orig_mod + compiled_block = torch.compile(block, fullgraph=fullgraph, dynamic=dynamic, backend=backend, mode=mode) m.add_object_patch(f"diffusion_model.blocks.{i}", compiled_block) else: compiled_model = 
class CFGZeroStar:
    """Node that installs a CFG-Zero* guidance function on a cloned model.

    The installed function rescales the unconditional prediction by its
    per-sample projection coefficient onto the conditional prediction before
    applying the guidance scale, and can optionally return an all-zero
    prediction for the first sampling steps ("zero init").
    Reference implementation: https://github.com/WeichenFan/CFG-Zero-star
    """

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "model": ("MODEL",),
            "use_zero_init": ("BOOLEAN", {"default": True}),
            "zero_star_steps": ("INT", {"default": 0, "min": 0, "tooltip": "for zero init, starts from 0 so first step is always zeroed out if use_zero_init enabled"}),
        }}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
    DESCRIPTION = "https://github.com/WeichenFan/CFG-Zero-star"
    CATEGORY = "KJNodes/experimental"
    EXPERIMENTAL = True

    @staticmethod
    def _find_step_index(sigmas, timestep):
        """Return the index of the sampling step that `timestep` belongs to.

        An exact match in `sigmas` wins (the first one, if the value occurs
        more than once). Otherwise the schedule is walked from the start and
        the first interval [sigmas[i], sigmas[i + 1]] that brackets
        `timestep` is used. Falls back to 0 when nothing matches.
        Assumes `sigmas` is a 1-D tensor -- TODO confirm against the caller.
        """
        exact_matches = (sigmas == timestep).nonzero()
        if len(exact_matches) > 0:
            # Index the first hit explicitly: calling .item() on the whole
            # result raises as soon as the value appears more than once.
            return int(exact_matches[0, 0].item())
        for i in range(len(sigmas) - 1):
            # walk from beginning of steps until crossing the timestep
            if (sigmas[i] - timestep) * (sigmas[i + 1] - timestep) <= 0:
                return i
        return 0

    def patch(self, model, use_zero_init, zero_star_steps):
        """Clone `model` and register the CFG-Zero* function on the clone.

        Args:
            model: object exposing `clone()` and
                `set_model_sampler_cfg_function(fn)`.
            use_zero_init: when true, steps with index <= `zero_star_steps`
                return an all-zero prediction.
            zero_star_steps: last step index (0-based) that is zeroed out.

        Returns:
            A one-element tuple holding the patched clone.
        """
        def cfg_zerostar(args):
            cond = args["cond"]
            timestep = args["timestep"]
            sigmas = args["model_options"]["transformer_options"]["sample_sigmas"]

            # Resolve the step first so zeroed steps skip the projection math.
            if use_zero_init:
                current_step_index = self._find_step_index(sigmas, timestep[0])
                if current_step_index <= zero_star_steps:
                    return cond * 0

            uncond = args["uncond"]
            cond_scale = args["cond_scale"]
            batch_size = cond.shape[0]

            # reshape (not view) so non-contiguous predictions are accepted.
            positive_flat = cond.reshape(batch_size, -1)
            negative_flat = uncond.reshape(batch_size, -1)

            # Per-sample least-squares scale of uncond onto cond.
            dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
            squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
            alpha = dot_product / squared_norm
            # One singleton axis per non-batch dim so alpha broadcasts over
            # inputs of any rank instead of only 5-D ones.
            alpha = alpha.view(batch_size, *([1] * (cond.ndim - 1)))

            scaled_uncond = uncond * alpha
            return scaled_uncond + cond_scale * (cond - scaled_uncond)

        m = model.clone()
        m.set_model_sampler_cfg_function(cfg_zerostar)
        return (m, )