ComfyUI-KJNodes (mirror: https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git)
commit b1ec996ba3 (parent a5bd3c86c8)
@@ -189,6 +189,7 @@ NODE_CONFIG = {
     "SkipLayerGuidanceWanVideo": {"class": SkipLayerGuidanceWanVideo, "name": "Skip Layer Guidance WanVideo"},
     "TimerNodeKJ": {"class": TimerNodeKJ, "name": "Timer Node KJ"},
     "HunyuanVideoEncodeKeyframesToCond": {"class": HunyuanVideoEncodeKeyframesToCond, "name": "HunyuanVideo Encode Keyframes To Cond"},
+    "CFGZeroStar": {"class": CFGZeroStar, "name": "CFG Zero Star"},

     #instance diffusion
     "CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking},
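For context, the NODE_CONFIG entries above are what ComfyUI ultimately sees as node registrations. Below is a minimal sketch of how a config dict of this shape is typically flattened into the NODE_CLASS_MAPPINGS / NODE_DISPLAY_NAME_MAPPINGS that ComfyUI loads from a custom node pack; the build_mappings helper name is illustrative, not necessarily the one this repo uses.

# Sketch only: flatten a NODE_CONFIG-style dict into the two mappings ComfyUI
# expects from a custom node pack. The helper name build_mappings is
# illustrative and may differ from what the repo actually uses.
def build_mappings(node_config):
    class_mappings = {}
    display_name_mappings = {}
    for node_name, node_info in node_config.items():
        class_mappings[node_name] = node_info["class"]
        # Fall back to the registration key when no display name is given.
        display_name_mappings[node_name] = node_info.get("name", node_name)
    return class_mappings, display_name_mappings

NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS = build_mappings(NODE_CONFIG)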
@@ -485,19 +485,13 @@ class TorchCompileModelWanVideo:
     def patch(self, model, backend, fullgraph, mode, dynamic, dynamo_cache_size_limit, compile_transformer_blocks_only):
         m = model.clone()
         diffusion_model = m.get_model_object("diffusion_model")
+        torch._dynamo.config.cache_size_limit = dynamo_cache_size_limit
-        is_compiled = hasattr(model.model.diffusion_model.blocks[0], "_orig_mod")
-        if is_compiled:
-            logging.info(f"Already compiled, not reapplying")
-        else:
-            logging.info(f"Not compiled, applying")
-            torch._dynamo.config.cache_size_limit = dynamo_cache_size_limit
         try:
             if compile_transformer_blocks_only:
                 for i, block in enumerate(diffusion_model.blocks):
-                    if is_compiled:
-                        compiled_block = torch.compile(block._orig_mod, fullgraph=fullgraph, dynamic=dynamic, backend=backend, mode=mode)
-                    else:
-                        compiled_block = torch.compile(block, fullgraph=fullgraph, dynamic=dynamic, backend=backend, mode=mode)
+                    if hasattr(block, "_orig_mod"):
+                        block = block._orig_mod
+                    compiled_block = torch.compile(block, fullgraph=fullgraph, dynamic=dynamic, backend=backend, mode=mode)
                     m.add_object_patch(f"diffusion_model.blocks.{i}", compiled_block)
             else:
                 compiled_model = torch.compile(diffusion_model, fullgraph=fullgraph, dynamic=dynamic, backend=backend, mode=mode)
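The change above boils down to one pattern: before compiling a transformer block again, unwrap the `_orig_mod` that torch.compile leaves on an already-compiled module, so compile wrappers do not stack across repeated runs of the node. A minimal self-contained sketch of that pattern on a toy block list (the real node applies it to the WanVideo model's diffusion_model.blocks and registers the result via add_object_patch):

import torch
import torch.nn as nn

# Toy stand-in for a list of transformer blocks.
blocks = nn.ModuleList([nn.Linear(8, 8) for _ in range(3)])

def compile_blocks(blocks, backend="inductor", mode="default", fullgraph=False, dynamic=False):
    compiled = []
    for block in blocks:
        # If the block was compiled before, torch.compile wrapped it in an
        # OptimizedModule that keeps the original module in `_orig_mod`.
        # Recompile the original instead of wrapping the wrapper.
        if hasattr(block, "_orig_mod"):
            block = block._orig_mod
        compiled.append(torch.compile(block, backend=backend, mode=mode,
                                      fullgraph=fullgraph, dynamic=dynamic))
    return compiled

compiled_blocks = compile_blocks(blocks)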
@@ -924,9 +918,10 @@ Official recommended values https://github.com/ali-vilab/TeaCache/tree/main/TeaC
                 current_step_index = 0

         if current_step_index == 0:
-            if hasattr(diffusion_model, "teacache_state"):
-                delattr(diffusion_model, "teacache_state")
-                logging.info("\nResetting TeaCache state")
+            if (len(cond_or_uncond) == 1 and cond_or_uncond[0] == 1) or len(cond_or_uncond) == 2:
+                if hasattr(diffusion_model, "teacache_state"):
+                    delattr(diffusion_model, "teacache_state")
+                    logging.info("\nResetting TeaCache state")

         current_percent = current_step_index / (len(sigmas) - 1)
         c["transformer_options"]["current_percent"] = current_percent
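The new guard means the cached TeaCache state is only dropped at step 0, and only for batches that include the unconditional pass. A sketch of that decision in isolation; it assumes the usual ComfyUI convention that cond_or_uncond lists 0 for cond entries and 1 for uncond entries.

# Sketch of the reset decision from the diff, pulled out as a helper.
# Assumption: in transformer_options, cond_or_uncond uses 0 = cond, 1 = uncond.
def should_reset_teacache(current_step_index, cond_or_uncond):
    if current_step_index != 0:
        return False
    uncond_only = len(cond_or_uncond) == 1 and cond_or_uncond[0] == 1
    cond_and_uncond = len(cond_or_uncond) == 2
    return uncond_only or cond_and_uncond

# e.g. should_reset_teacache(0, [1]) -> True, should_reset_teacache(0, [0]) -> False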
@@ -1176,4 +1171,59 @@ class SkipLayerGuidanceWanVideo:
         m.model_options["transformer_options"] = model_options


         return (m, )
+
+class CFGZeroStar:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {
+            "model": ("MODEL",),
+            "use_zero_init": ("BOOLEAN", {"default": True}),
+            "zero_star_steps": ("INT", {"default": 0, "min": 0, "tooltip": "for zero init, starts from 0 so first step is always zeroed out if use_zero_init enabled"}),
+        }}
+    RETURN_TYPES = ("MODEL",)
+    FUNCTION = "patch"
+    DESCRIPTION = "https://github.com/WeichenFan/CFG-Zero-star"
+    CATEGORY = "KJNodes/experimental"
+    EXPERIMENTAL = True
+
+    def patch(self, model, use_zero_init, zero_star_steps):
+        def cfg_zerostar(args):
+            cond = args["cond"]
+            uncond = args["uncond"]
+            cond_scale = args["cond_scale"]
+            timestep = args["timestep"]
+            sigmas = args["model_options"]["transformer_options"]["sample_sigmas"]
+
+            batch_size = cond.shape[0]
+
+            noise_pred_text = cond
+            positive_flat = noise_pred_text.view(batch_size, -1)
+            negative_flat = uncond.view(batch_size, -1)
+
+            dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
+            squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
+            alpha = dot_product / squared_norm
+            alpha = alpha.view(batch_size, 1, 1, 1, 1)
+
+            matched_step_index = (sigmas == timestep[0]).nonzero()
+            if len(matched_step_index) > 0:
+                current_step_index = matched_step_index.item()
+            else:
+                for i in range(len(sigmas) - 1):
+                    # walk from beginning of steps until crossing the timestep
+                    if (sigmas[i] - timestep[0]) * (sigmas[i + 1] - timestep[0]) <= 0:
+                        current_step_index = i
+                        break
+                else:
+                    current_step_index = 0
+
+            if (current_step_index <= zero_star_steps) and use_zero_init:
+                noise_pred = noise_pred_text * 0
+            else:
+                noise_pred = uncond * alpha + cond_scale * (noise_pred_text - uncond * alpha)
+            return noise_pred
+
+        m = model.clone()
+        m.set_model_sampler_cfg_function(cfg_zerostar)
+        return (m, )
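The heart of CFG Zero Star is the per-sample scale alpha = <cond, uncond> / ||uncond||^2, the least-squares factor that projects the conditional prediction onto the unconditional one; guidance then extrapolates from uncond * alpha instead of uncond, and the first zero_star_steps steps are zeroed out entirely when use_zero_init is enabled. A minimal sketch of that arithmetic on dummy tensors (shapes are illustrative; the node applies it to 5-D video latents, hence the view to (b, 1, 1, 1, 1)):

import torch

# Dummy cond/uncond predictions with a video-latent-like layout.
batch_size, channels, frames, height, width = 2, 4, 3, 8, 8
cond = torch.randn(batch_size, channels, frames, height, width)
uncond = torch.randn_like(cond)
cond_scale = 6.0

positive_flat = cond.view(batch_size, -1)
negative_flat = uncond.view(batch_size, -1)

# alpha = <cond, uncond> / ||uncond||^2: the scale that makes alpha * uncond
# the projection of cond onto uncond (epsilon avoids division by zero).
alpha = (positive_flat * negative_flat).sum(dim=1, keepdim=True) \
        / (negative_flat.pow(2).sum(dim=1, keepdim=True) + 1e-8)
alpha = alpha.view(batch_size, 1, 1, 1, 1)

# Standard CFG, but extrapolating from the rescaled unconditional prediction.
noise_pred = uncond * alpha + cond_scale * (cond - uncond * alpha)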