diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 5051f821d..b1faa63f4 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -136,7 +136,7 @@ def detect_unet_config(state_dict, key_prefix): if '{}txt_in.individual_token_refiner.blocks.0.norm1.weight'.format(key_prefix) in state_dict_keys: #Hunyuan Video dit_config = {} dit_config["image_model"] = "hunyuan_video" - dit_config["in_channels"] = state_dict["img_in.proj.weight"].shape[1] #SkyReels img2video has 32 input channels + dit_config["in_channels"] = state_dict['{}img_in.proj.weight'.format(key_prefix)].shape[1] #SkyReels img2video has 32 input channels dit_config["patch_size"] = [1, 2, 2] dit_config["out_channels"] = 16 dit_config["vec_in_dim"] = 768 diff --git a/comfy_extras/nodes_video.py b/comfy_extras/nodes_video.py new file mode 100644 index 000000000..f3922e03d --- /dev/null +++ b/comfy_extras/nodes_video.py @@ -0,0 +1,75 @@ +import os +import av +import torch +import folder_paths +import json +from fractions import Fraction + + +class SaveWEBM: + def __init__(self): + self.output_dir = folder_paths.get_output_directory() + self.type = "output" + self.prefix_append = "" + + @classmethod + def INPUT_TYPES(s): + return {"required": + {"images": ("IMAGE", ), + "filename_prefix": ("STRING", {"default": "ComfyUI"}), + "codec": (["vp9", "av1"],), + "fps": ("FLOAT", {"default": 24.0, "min": 0.01, "max": 1000.0, "step": 0.01}), + "crf": ("FLOAT", {"default": 32.0, "min": 0, "max": 63.0, "step": 1, "tooltip": "Higher crf means lower quality with a smaller file size, lower crf means higher quality higher filesize."}), + }, + "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, + } + + RETURN_TYPES = () + FUNCTION = "save_images" + + OUTPUT_NODE = True + + CATEGORY = "image/video" + + EXPERIMENTAL = True + + def save_images(self, images, codec, fps, filename_prefix, crf, prompt=None, extra_pnginfo=None): + filename_prefix += self.prefix_append + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]) + + file = f"{filename}_{counter:05}_.webm" + container = av.open(os.path.join(full_output_folder, file), mode="w") + + if prompt is not None: + container.metadata["prompt"] = json.dumps(prompt) + + if extra_pnginfo is not None: + for x in extra_pnginfo: + container.metadata[x] = json.dumps(extra_pnginfo[x]) + + codec_map = {"vp9": "libvpx-vp9", "av1": "libaom-av1"} + stream = container.add_stream(codec_map[codec], rate=Fraction(round(fps * 1000), 1000)) + stream.width = images.shape[-2] + stream.height = images.shape[-3] + stream.pix_fmt = "yuv420p" + stream.bit_rate = 0 + stream.options = {'crf': str(crf)} + + for frame in images: + frame = av.VideoFrame.from_ndarray(torch.clamp(frame[..., :3] * 255, min=0, max=255).to(device=torch.device("cpu"), dtype=torch.uint8).numpy(), format="rgb24") + for packet in stream.encode(frame): + container.mux(packet) + container.close() + + results = [{ + "filename": file, + "subfolder": subfolder, + "type": self.type + }] + + return {"ui": {"images": results, "animated": (True,)}} # TODO: frontend side + + +NODE_CLASS_MAPPINGS = { + "SaveWEBM": SaveWEBM, +} diff --git a/nodes.py b/nodes.py index b39adc654..4e68af79d 100644 --- a/nodes.py +++ b/nodes.py @@ -2265,6 +2265,7 @@ def init_builtin_extra_nodes(): "nodes_hooks.py", "nodes_load_3d.py", "nodes_cosmos.py", + "nodes_video.py", "nodes_lumina2.py", ] diff --git a/requirements.txt b/requirements.txt index 3bc945a1b..afbcb7cba 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,7 +8,8 @@ transformers>=4.28.1 tokenizers>=0.13.3 sentencepiece safetensors>=0.4.2 -aiohttp +aiohttp>=3.11.8 +yarl>=1.18.0 pyyaml Pillow scipy @@ -19,3 +20,4 @@ psutil kornia>=0.7.1 spandrel soundfile +av