diff --git a/__init__.py b/__init__.py
index c18ce7c..833f43b 100644
--- a/__init__.py
+++ b/__init__.py
@@ -186,6 +186,7 @@ NODE_CONFIG = {
     "WanVideoTeaCacheKJ": {"class": WanVideoTeaCacheKJ, "name": "WanVideo Tea Cache (native)"},
     "WanVideoEnhanceAVideoKJ": {"class": WanVideoEnhanceAVideoKJ, "name": "WanVideo Enhance A Video (native)"},
     "TimerNodeKJ": {"class": TimerNodeKJ, "name": "Timer Node KJ"},
+    "HunyuanVideoEncodeKeyframesToCond": {"class": HunyuanVideoEncodeKeyframesToCond, "name": "HunyuanVideo Encode Keyframes To Cond"},
 
     #instance diffusion
     "CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking},
diff --git a/nodes/nodes.py b/nodes/nodes.py
index 54cb5ab..e775f00 100644
--- a/nodes/nodes.py
+++ b/nodes/nodes.py
@@ -2644,4 +2644,71 @@ class TimerNodeKJ:
             timer.start_time = None
 
         return (any_input, timer, timer.elapsed)
-    
\ No newline at end of file
+class HunyuanVideoEncodeKeyframesToCond:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {
+            "model": ("MODEL",),
+            "positive": ("CONDITIONING", ),
+            "vae": ("VAE", ),
+            "start_frame": ("IMAGE", ),
+            "end_frame": ("IMAGE", ),
+            "num_frames": ("INT", {"default": 33, "min": 2, "max": 4096, "step": 1}),
+            "tile_size": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 64}),
+            "overlap": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 32}),
+            "temporal_size": ("INT", {"default": 64, "min": 8, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to encode at a time."}),
+            "temporal_overlap": ("INT", {"default": 8, "min": 4, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to overlap."}),
+            },
+            "optional": {
+                "negative": ("CONDITIONING", ),
+            }
+        }
+
+    RETURN_TYPES = ("MODEL", "CONDITIONING", "CONDITIONING", "LATENT")
+    RETURN_NAMES = ("model", "positive", "negative", "latent")
+    FUNCTION = "encode"
+
+    CATEGORY = "KJNodes/videomodels"
+
+    def encode(self, model, positive, start_frame, end_frame, num_frames, vae, tile_size, overlap, temporal_size, temporal_overlap, negative=None):
+
+        model_clone = model.clone()
+
+        model_clone.add_object_patch("concat_keys", ("concat_image",))
+
+        # Center-crop both keyframes so height and width are divisible by 8.
+        x = (start_frame.shape[1] // 8) * 8
+        y = (start_frame.shape[2] // 8) * 8
+
+        if start_frame.shape[1] != x or start_frame.shape[2] != y:
+            x_offset = (start_frame.shape[1] % 8) // 2
+            y_offset = (start_frame.shape[2] % 8) // 2
+            start_frame = start_frame[:, x_offset:x + x_offset, y_offset:y + y_offset, :]
+        if end_frame.shape[1] != x or end_frame.shape[2] != y:
+            x_offset = (end_frame.shape[1] % 8) // 2
+            y_offset = (end_frame.shape[2] % 8) // 2
+            end_frame = end_frame[:, x_offset:x + x_offset, y_offset:y + y_offset, :]
+
+        # Insert num_frames - 2 black frames between the two keyframes.
+        video_frames = torch.zeros(num_frames - 2, start_frame.shape[1], start_frame.shape[2], start_frame.shape[3], device=start_frame.device, dtype=start_frame.dtype)
+        video_frames = torch.cat([start_frame, video_frames, end_frame], dim=0)
+
+        # Tile-encode the RGB channels only; any alpha channel is dropped.
+        concat_latent = vae.encode_tiled(video_frames[:, :, :, :3], tile_x=tile_size, tile_y=tile_size, overlap=overlap, tile_t=temporal_size, overlap_t=temporal_overlap)
+
+        out_latent = {}
+        out_latent["samples"] = torch.zeros_like(concat_latent)
+
+        # Attach the encoded latent to every conditioning entry as concat_latent_image.
+        out = []
+        for conditioning in ([positive, negative] if negative is not None else [positive]):
+            c = []
+            for t in conditioning:
+                d = t[1].copy()
+                d["concat_latent_image"] = concat_latent
+                n = [t[0], d]
+                c.append(n)
+            out.append(c)
+        if len(out) == 1:
+            out.append(out[0])
+        return (model_clone, out[0], out[1], out_latent)
\ No newline at end of file
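
The individual steps of encode() can be exercised in isolation. First, the center-cropping math generalizes to the following standalone helper; a minimal sketch assuming ComfyUI's (N, H, W, C) IMAGE layout, with an illustrative helper name not taken from the patch:

import torch

def center_crop_to_multiple(frames, multiple=8):
    # Center-crop an (N, H, W, C) batch so H and W are divisible by `multiple`,
    # mirroring the (shape % 8) // 2 offset math in encode() above.
    h = (frames.shape[1] // multiple) * multiple
    w = (frames.shape[2] // multiple) * multiple
    y_off = (frames.shape[1] % multiple) // 2
    x_off = (frames.shape[2] % multiple) // 2
    return frames[:, y_off:y_off + h, x_off:x_off + w, :]

print(center_crop_to_multiple(torch.rand(1, 517, 333, 3)).shape)  # torch.Size([1, 512, 328, 3])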
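
The keyframe padding step likewise needs no VAE to verify; a sketch with dummy tensors (shapes assumed for illustration):

import torch

start_frame = torch.rand(1, 512, 512, 3)   # (N, H, W, C) IMAGE tensors
end_frame = torch.rand(1, 512, 512, 3)
num_frames = 33

# num_frames - 2 black frames between the two keyframes, as in encode() above.
middle = torch.zeros(num_frames - 2, *start_frame.shape[1:],
                     device=start_frame.device, dtype=start_frame.dtype)
video_frames = torch.cat([start_frame, middle, end_frame], dim=0)
print(video_frames.shape)  # torch.Size([33, 512, 512, 3])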
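
Finally, the conditioning loop follows the common ComfyUI pattern of shallow-copying each entry's options dict before attaching the encoded latent, so the caller's conditioning is never mutated. A standalone sketch with dummy tensors standing in for real embeddings and latents (the function name is illustrative):

import torch

def attach_concat_latent(conditioning, concat_latent):
    # conditioning is ComfyUI-style: [[cond_tensor, options_dict], ...]
    out = []
    for cond, opts in conditioning:
        opts = opts.copy()  # copy so the caller's dict is left untouched
        opts["concat_latent_image"] = concat_latent
        out.append([cond, opts])
    return out

positive = [[torch.zeros(1, 77, 4096), {}]]
latent = torch.zeros(1, 16, 9, 64, 64)  # dummy (B, C, T, H, W) video latent
patched = attach_concat_latent(positive, latent)
assert patched[0][1]["concat_latent_image"] is latent
assert "concat_latent_image" not in positive[0][1]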