diff --git a/.ci/windows_amd_base_files/README_VERY_IMPORTANT.txt b/.ci/windows_amd_base_files/README_VERY_IMPORTANT.txt index 96a500be2..2cbb00d99 100755 --- a/.ci/windows_amd_base_files/README_VERY_IMPORTANT.txt +++ b/.ci/windows_amd_base_files/README_VERY_IMPORTANT.txt @@ -1,5 +1,5 @@ -As of the time of writing this you need this preview driver for best results: -https://www.amd.com/en/resources/support-articles/release-notes/RN-AMDGPU-WINDOWS-PYTORCH-PREVIEW.html +As of the time of writing this you need this driver for best results: +https://www.amd.com/en/resources/support-articles/release-notes/RN-AMDGPU-WINDOWS-PYTORCH-7-1-1.html HOW TO RUN: @@ -25,3 +25,4 @@ In the ComfyUI directory you will find a file: extra_model_paths.yaml.example Rename this file to: extra_model_paths.yaml and edit it with your favorite text editor. + diff --git a/.github/workflows/release-stable-all.yml b/.github/workflows/release-stable-all.yml index 9274b4170..d72ece2ce 100644 --- a/.github/workflows/release-stable-all.yml +++ b/.github/workflows/release-stable-all.yml @@ -65,11 +65,11 @@ jobs: contents: "write" packages: "write" pull-requests: "read" - name: "Release AMD ROCm 6.4.4" + name: "Release AMD ROCm 7.1.1" uses: ./.github/workflows/stable-release.yml with: git_tag: ${{ inputs.git_tag }} - cache_tag: "rocm644" + cache_tag: "rocm711" python_minor: "12" python_patch: "10" rel_name: "amd" diff --git a/CODEOWNERS b/CODEOWNERS index b7aca9b26..51acc4986 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -1,3 +1,4 @@ # Admins * @comfyanonymous * @kosinkadink +* @guill diff --git a/README.md b/README.md index b9300ab07..91fb510e1 100644 --- a/README.md +++ b/README.md @@ -68,6 +68,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith - [Qwen Image](https://comfyanonymous.github.io/ComfyUI_examples/qwen_image/) - [Hunyuan Image 2.1](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_image/) - [Flux 2](https://comfyanonymous.github.io/ComfyUI_examples/flux2/) + - [Z Image](https://comfyanonymous.github.io/ComfyUI_examples/z_image/) - Image Editing Models - [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/) - [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model) diff --git a/app/user_manager.py b/app/user_manager.py index a2d376c0c..e2c00dab2 100644 --- a/app/user_manager.py +++ b/app/user_manager.py @@ -59,6 +59,9 @@ class UserManager(): user = "default" if args.multi_user and "comfy-user" in request.headers: user = request.headers["comfy-user"] + # Block System Users (use same error message to prevent probing) + if user.startswith(folder_paths.SYSTEM_USER_PREFIX): + raise KeyError("Unknown user: " + user) if user not in self.users: raise KeyError("Unknown user: " + user) @@ -66,15 +69,16 @@ class UserManager(): return user def get_request_user_filepath(self, request, file, type="userdata", create_dir=True): - user_directory = folder_paths.get_user_directory() - if type == "userdata": - root_dir = user_directory + root_dir = folder_paths.get_user_directory() else: raise KeyError("Unknown filepath type:" + type) user = self.get_request_user_id(request) - path = user_root = os.path.abspath(os.path.join(root_dir, user)) + user_root = folder_paths.get_public_user_directory(user) + if user_root is None: + return None + path = user_root # prevent leaving /{type} if os.path.commonpath((root_dir, user_root)) != root_dir: @@ -101,7 +105,11 @@ class UserManager(): name = name.strip() if not name: raise ValueError("username not provided") + if name.startswith(folder_paths.SYSTEM_USER_PREFIX): + raise ValueError("System User prefix not allowed") user_id = re.sub("[^a-zA-Z0-9-_]+", '-', name) + if user_id.startswith(folder_paths.SYSTEM_USER_PREFIX): + raise ValueError("System User prefix not allowed") user_id = user_id + "_" + str(uuid.uuid4()) self.users[user_id] = name @@ -132,7 +140,10 @@ class UserManager(): if username in self.users.values(): return web.json_response({"error": "Duplicate username."}, status=400) - user_id = self.add_user(username) + try: + user_id = self.add_user(username) + except ValueError as e: + return web.json_response({"error": str(e)}, status=400) return web.json_response(user_id) @routes.get("/userdata") @@ -424,7 +435,7 @@ class UserManager(): return source dest = get_user_data_path(request, check_exists=False, param="dest") - if not isinstance(source, str): + if not isinstance(dest, str): return dest overwrite = request.query.get("overwrite", 'true') != "false" diff --git a/comfy/cli_args.py b/comfy/cli_args.py index d2b60e347..209fc185b 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -121,6 +121,12 @@ upcast.add_argument("--force-upcast-attention", action="store_true", help="Force upcast.add_argument("--dont-upcast-attention", action="store_true", help="Disable all upcasting of attention. Should be unnecessary except for debugging.") +parser.add_argument("--enable-manager", action="store_true", help="Enable the ComfyUI-Manager feature.") +manager_group = parser.add_mutually_exclusive_group() +manager_group.add_argument("--disable-manager-ui", action="store_true", help="Disables only the ComfyUI-Manager UI and endpoints. Scheduled installations and similar background tasks will still operate.") +manager_group.add_argument("--enable-manager-legacy-ui", action="store_true", help="Enables the legacy UI of ComfyUI-Manager") + + vram_group = parser.add_mutually_exclusive_group() vram_group.add_argument("--gpu-only", action="store_true", help="Store and run everything (text encoders/CLIP models, etc... on the GPU).") vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.") @@ -131,7 +137,8 @@ vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for e parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS.") -parser.add_argument("--async-offload", action="store_true", help="Use async weight offloading.") +parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=None, metavar="NUM_STREAMS", help="Use async weight offloading. An optional argument controls the amount of offload streams. Default is 2. Enabled by default on Nvidia.") +parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.") parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.") @@ -167,6 +174,7 @@ parser.add_argument("--multi-user", action="store_true", help="Enables per-user parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level') parser.add_argument("--log-stdout", action="store_true", help="Send normal process output to stdout instead of stderr (default).") + # The default built-in provider hosted under web/ DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest" diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 8e110f45d..f1ca0151e 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -431,6 +431,7 @@ class HunyuanVideo(LatentFormat): ] latent_rgb_factors_bias = [ 0.0259, -0.0192, -0.0761] + taesd_decoder_name = "taehv" class Cosmos1CV8x8x8(LatentFormat): latent_channels = 16 @@ -494,7 +495,7 @@ class Wan21(LatentFormat): ]).view(1, self.latent_channels, 1, 1, 1) - self.taesd_decoder_name = None #TODO + self.taesd_decoder_name = "lighttaew2_1" def process_in(self, latent): latents_mean = self.latents_mean.to(latent.device, latent.dtype) @@ -565,6 +566,7 @@ class Wan22(Wan21): def __init__(self): self.scale_factor = 1.0 + self.taesd_decoder_name = "lighttaew2_2" self.latents_mean = torch.tensor([ -0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557, -0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098, 0.0375, -0.1825, @@ -719,6 +721,7 @@ class HunyuanVideo15(LatentFormat): latent_channels = 32 latent_dimensions = 3 scale_factor = 1.03682 + taesd_decoder_name = "lighttaehy1_5" class Hunyuan3Dv2(LatentFormat): latent_channels = 64 diff --git a/comfy/ldm/chroma/model.py b/comfy/ldm/chroma/model.py index a72f8cc47..2e8ef0687 100644 --- a/comfy/ldm/chroma/model.py +++ b/comfy/ldm/chroma/model.py @@ -40,7 +40,8 @@ class ChromaParams: out_dim: int hidden_dim: int n_layers: int - + txt_ids_dims: list + vec_in_dim: int diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py index 2472ab79c..60f2bdae2 100644 --- a/comfy/ldm/flux/layers.py +++ b/comfy/ldm/flux/layers.py @@ -57,6 +57,35 @@ class MLPEmbedder(nn.Module): def forward(self, x: Tensor) -> Tensor: return self.out_layer(self.silu(self.in_layer(x))) +class YakMLP(nn.Module): + def __init__(self, hidden_size: int, intermediate_size: int, dtype=None, device=None, operations=None): + super().__init__() + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.gate_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=True, dtype=dtype, device=device) + self.up_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=True, dtype=dtype, device=device) + self.down_proj = operations.Linear(self.intermediate_size, self.hidden_size, bias=True, dtype=dtype, device=device) + self.act_fn = nn.SiLU() + + def forward(self, x: Tensor) -> Tensor: + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + +def build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=False, yak_mlp=False, dtype=None, device=None, operations=None): + if yak_mlp: + return YakMLP(hidden_size, mlp_hidden_dim, dtype=dtype, device=device, operations=operations) + if mlp_silu_act: + return nn.Sequential( + operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device), + SiLUActivation(), + operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device), + ) + else: + return nn.Sequential( + operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device), + nn.GELU(approximate="tanh"), + operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device), + ) class RMSNorm(torch.nn.Module): def __init__(self, dim: int, dtype=None, device=None, operations=None): @@ -140,7 +169,7 @@ class SiLUActivation(nn.Module): class DoubleStreamBlock(nn.Module): - def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, dtype=None, device=None, operations=None): + def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None): super().__init__() mlp_hidden_dim = int(hidden_size * mlp_ratio) @@ -156,18 +185,7 @@ class DoubleStreamBlock(nn.Module): self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - if mlp_silu_act: - self.img_mlp = nn.Sequential( - operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device), - SiLUActivation(), - operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device), - ) - else: - self.img_mlp = nn.Sequential( - operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device), - nn.GELU(approximate="tanh"), - operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device), - ) + self.img_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations) if self.modulation: self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations) @@ -177,18 +195,7 @@ class DoubleStreamBlock(nn.Module): self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - if mlp_silu_act: - self.txt_mlp = nn.Sequential( - operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device), - SiLUActivation(), - operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device), - ) - else: - self.txt_mlp = nn.Sequential( - operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device), - nn.GELU(approximate="tanh"), - operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device), - ) + self.txt_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations) self.flipped_img_txt = flipped_img_txt @@ -275,6 +282,7 @@ class SingleStreamBlock(nn.Module): modulation=True, mlp_silu_act=False, bias=True, + yak_mlp=False, dtype=None, device=None, operations=None @@ -288,12 +296,17 @@ class SingleStreamBlock(nn.Module): self.mlp_hidden_dim = int(hidden_size * mlp_ratio) self.mlp_hidden_dim_first = self.mlp_hidden_dim + self.yak_mlp = yak_mlp if mlp_silu_act: self.mlp_hidden_dim_first = int(hidden_size * mlp_ratio * 2) self.mlp_act = SiLUActivation() else: self.mlp_act = nn.GELU(approximate="tanh") + if self.yak_mlp: + self.mlp_hidden_dim_first *= 2 + self.mlp_act = nn.SiLU() + # qkv and mlp_in self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim_first, bias=bias, dtype=dtype, device=device) # proj and mlp_out @@ -325,7 +338,10 @@ class SingleStreamBlock(nn.Module): attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options) del q, k, v # compute activation in mlp stream, cat again and run second linear layer - mlp = self.mlp_act(mlp) + if self.yak_mlp: + mlp = self.mlp_act(mlp[..., self.mlp_hidden_dim_first // 2:]) * mlp[..., :self.mlp_hidden_dim_first // 2] + else: + mlp = self.mlp_act(mlp) output = self.linear2(torch.cat((attn, mlp), 2)) x += apply_mod(output, mod.gate, None, modulation_dims) if x.dtype == torch.float16: diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index 1a24e6d95..f40c2a7a9 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -15,7 +15,8 @@ from .layers import ( MLPEmbedder, SingleStreamBlock, timestep_embedding, - Modulation + Modulation, + RMSNorm ) @dataclass @@ -34,11 +35,14 @@ class FluxParams: patch_size: int qkv_bias: bool guidance_embed: bool + txt_ids_dims: list global_modulation: bool = False mlp_silu_act: bool = False ops_bias: bool = True default_ref_method: str = "offset" ref_index_scale: float = 1.0 + yak_mlp: bool = False + txt_norm: bool = False class Flux(nn.Module): @@ -76,6 +80,11 @@ class Flux(nn.Module): ) self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device) + if params.txt_norm: + self.txt_norm = RMSNorm(params.context_in_dim, dtype=dtype, device=device, operations=operations) + else: + self.txt_norm = None + self.double_blocks = nn.ModuleList( [ DoubleStreamBlock( @@ -86,6 +95,7 @@ class Flux(nn.Module): modulation=params.global_modulation is False, mlp_silu_act=params.mlp_silu_act, proj_bias=params.ops_bias, + yak_mlp=params.yak_mlp, dtype=dtype, device=device, operations=operations ) for _ in range(params.depth) @@ -94,7 +104,7 @@ class Flux(nn.Module): self.single_blocks = nn.ModuleList( [ - SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=params.global_modulation is False, mlp_silu_act=params.mlp_silu_act, bias=params.ops_bias, dtype=dtype, device=device, operations=operations) + SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=params.global_modulation is False, mlp_silu_act=params.mlp_silu_act, bias=params.ops_bias, yak_mlp=params.yak_mlp, dtype=dtype, device=device, operations=operations) for _ in range(params.depth_single_blocks) ] ) @@ -150,6 +160,8 @@ class Flux(nn.Module): y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype) vec = vec + self.vector_in(y[:, :self.params.vec_in_dim]) + if self.txt_norm is not None: + txt = self.txt_norm(txt) txt = self.txt_in(txt) vec_orig = vec @@ -171,7 +183,10 @@ class Flux(nn.Module): pe = None blocks_replace = patches_replace.get("dit", {}) + transformer_options["total_blocks"] = len(self.double_blocks) + transformer_options["block_type"] = "double" for i, block in enumerate(self.double_blocks): + transformer_options["block_index"] = i if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} @@ -215,7 +230,10 @@ class Flux(nn.Module): if self.params.global_modulation: vec, _ = self.single_stream_modulation(vec_orig) + transformer_options["total_blocks"] = len(self.single_blocks) + transformer_options["block_type"] = "single" for i, block in enumerate(self.single_blocks): + transformer_options["block_index"] = i if ("single_block", i) in blocks_replace: def block_wrap(args): out = {} @@ -326,8 +344,9 @@ class Flux(nn.Module): txt_ids = torch.zeros((bs, context.shape[1], len(self.params.axes_dim)), device=x.device, dtype=torch.float32) - if len(self.params.axes_dim) == 4: # Flux 2 - txt_ids[:, :, 3] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32) + if len(self.params.txt_ids_dims) > 0: + for i in self.params.txt_ids_dims: + txt_ids[:, :, i] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32) out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None)) out = out[:, :img_tokens] diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index c8643eb82..7d7e9112c 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -509,7 +509,7 @@ class NextDiT(nn.Module): if self.pad_tokens_multiple is not None: pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple - cap_feats = torch.cat((cap_feats, self.cap_pad_token.to(device=cap_feats.device, dtype=cap_feats.dtype).unsqueeze(0).repeat(cap_feats.shape[0], pad_extra, 1)), dim=1) + cap_feats = torch.cat((cap_feats, self.cap_pad_token.to(device=cap_feats.device, dtype=cap_feats.dtype, copy=True).unsqueeze(0).repeat(cap_feats.shape[0], pad_extra, 1)), dim=1) cap_pos_ids = torch.zeros(bsz, cap_feats.shape[1], 3, dtype=torch.float32, device=device) cap_pos_ids[:, :, 0] = torch.arange(cap_feats.shape[1], dtype=torch.float32, device=device) + 1.0 @@ -517,15 +517,27 @@ class NextDiT(nn.Module): B, C, H, W = x.shape x = self.x_embedder(x.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2)) + rope_options = transformer_options.get("rope_options", None) + h_scale = 1.0 + w_scale = 1.0 + h_start = 0 + w_start = 0 + if rope_options is not None: + h_scale = rope_options.get("scale_y", 1.0) + w_scale = rope_options.get("scale_x", 1.0) + + h_start = rope_options.get("shift_y", 0.0) + w_start = rope_options.get("shift_x", 0.0) + H_tokens, W_tokens = H // pH, W // pW x_pos_ids = torch.zeros((bsz, x.shape[1], 3), dtype=torch.float32, device=device) x_pos_ids[:, :, 0] = cap_feats.shape[1] + 1 - x_pos_ids[:, :, 1] = torch.arange(H_tokens, dtype=torch.float32, device=device).view(-1, 1).repeat(1, W_tokens).flatten() - x_pos_ids[:, :, 2] = torch.arange(W_tokens, dtype=torch.float32, device=device).view(1, -1).repeat(H_tokens, 1).flatten() + x_pos_ids[:, :, 1] = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten() + x_pos_ids[:, :, 2] = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten() if self.pad_tokens_multiple is not None: pad_extra = (-x.shape[1]) % self.pad_tokens_multiple - x = torch.cat((x, self.x_pad_token.to(device=x.device, dtype=x.dtype).unsqueeze(0).repeat(x.shape[0], pad_extra, 1)), dim=1) + x = torch.cat((x, self.x_pad_token.to(device=x.device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(x.shape[0], pad_extra, 1)), dim=1) x_pos_ids = torch.nn.functional.pad(x_pos_ids, (0, 0, 0, pad_extra)) freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2) diff --git a/comfy/lora.py b/comfy/lora.py index 36d26293a..3a9077869 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -313,6 +313,15 @@ def model_lora_keys_unet(model, key_map={}): key_map["transformer.{}".format(key_lora)] = k key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k #SimpleTuner lycoris format + if isinstance(model, comfy.model_base.Lumina2): + diffusers_keys = comfy.utils.z_image_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.") + for k in diffusers_keys: + if k.endswith(".weight"): + to = diffusers_keys[k] + key_lora = k[:-len(".weight")] + key_map["diffusion_model.{}".format(key_lora)] = to + key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to + return key_map diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 7afe4a798..7d0517e61 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -208,12 +208,12 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["theta"] = 2000 dit_config["out_channels"] = 128 dit_config["global_modulation"] = True - dit_config["vec_in_dim"] = None dit_config["mlp_silu_act"] = True dit_config["qkv_bias"] = False dit_config["ops_bias"] = False dit_config["default_ref_method"] = "index" dit_config["ref_index_scale"] = 10.0 + dit_config["txt_ids_dims"] = [3] patch_size = 1 else: dit_config["image_model"] = "flux" @@ -223,6 +223,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["theta"] = 10000 dit_config["out_channels"] = 16 dit_config["qkv_bias"] = True + dit_config["txt_ids_dims"] = [] patch_size = 2 dit_config["in_channels"] = 16 @@ -245,6 +246,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): vec_in_key = '{}vector_in.in_layer.weight'.format(key_prefix) if vec_in_key in state_dict_keys: dit_config["vec_in_dim"] = state_dict[vec_in_key].shape[1] + else: + dit_config["vec_in_dim"] = None dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.') dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.') @@ -270,6 +273,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["nerf_embedder_dtype"] = torch.float32 else: dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys + dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys + dit_config["txt_norm"] = "{}txt_norm.scale".format(key_prefix) in state_dict_keys + if dit_config["yak_mlp"] and dit_config["txt_norm"]: # Ovis model + dit_config["txt_ids_dims"] = [1, 2] + return dit_config if '{}t5_yproj.weight'.format(key_prefix) in state_dict_keys: #Genmo mochi preview diff --git a/comfy/model_management.py b/comfy/model_management.py index a9327ac80..aeddbaefe 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -689,7 +689,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu loaded_memory = loaded_model.model_loaded_memory() current_free_mem = get_free_memory(torch_dev) + loaded_memory - lowvram_model_memory = max(128 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory())) + lowvram_model_memory = max(0, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory())) lowvram_model_memory = lowvram_model_memory - loaded_memory if lowvram_model_memory == 0: @@ -1012,9 +1012,18 @@ def force_channels_last(): STREAMS = {} -NUM_STREAMS = 1 -if args.async_offload: - NUM_STREAMS = 2 +NUM_STREAMS = 0 +if args.async_offload is not None: + NUM_STREAMS = args.async_offload +else: + # Enable by default on Nvidia + if is_nvidia(): + NUM_STREAMS = 2 + +if args.disable_async_offload: + NUM_STREAMS = 0 + +if NUM_STREAMS > 0: logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS)) def current_stream(device): @@ -1030,7 +1039,10 @@ def current_stream(device): stream_counters = {} def get_offload_stream(device): stream_counter = stream_counters.get(device, 0) - if NUM_STREAMS <= 1: + if NUM_STREAMS == 0: + return None + + if torch.compiler.is_compiling(): return None if device in STREAMS: @@ -1043,7 +1055,9 @@ def get_offload_stream(device): elif is_device_cuda(device): ss = [] for k in range(NUM_STREAMS): - ss.append(torch.cuda.Stream(device=device, priority=0)) + s1 = torch.cuda.Stream(device=device, priority=0) + s1.as_context = torch.cuda.stream + ss.append(s1) STREAMS[device] = ss s = ss[stream_counter] stream_counters[device] = stream_counter @@ -1051,7 +1065,9 @@ def get_offload_stream(device): elif is_device_xpu(device): ss = [] for k in range(NUM_STREAMS): - ss.append(torch.xpu.Stream(device=device, priority=0)) + s1 = torch.xpu.Stream(device=device, priority=0) + s1.as_context = torch.xpu.stream + ss.append(s1) STREAMS[device] = ss s = ss[stream_counter] stream_counters[device] = stream_counter @@ -1069,12 +1085,19 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str if dtype is None or weight.dtype == dtype: return weight if stream is not None: - with stream: + wf_context = stream + if hasattr(wf_context, "as_context"): + wf_context = wf_context.as_context(stream) + with wf_context: return weight.to(dtype=dtype, copy=copy) return weight.to(dtype=dtype, copy=copy) + if stream is not None: - with stream: + wf_context = stream + if hasattr(wf_context, "as_context"): + wf_context = wf_context.as_context(stream) + with wf_context: r = torch.empty_like(weight, dtype=dtype, device=device) r.copy_(weight, non_blocking=non_blocking) else: diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 73adc7f70..3eac77275 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -148,6 +148,15 @@ class LowVramPatch: else: return out +#The above patch logic may cast up the weight to fp32, and do math. Go with fp32 x 3 +LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 3 + +def low_vram_patch_estimate_vram(model, key): + weight, set_func, convert_func = get_key_weight(model, key) + if weight is None: + return 0 + return weight.numel() * torch.float32.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR + def get_key_weight(model, key): set_func = None convert_func = None @@ -269,6 +278,9 @@ class ModelPatcher: if not hasattr(self.model, 'current_weight_patches_uuid'): self.model.current_weight_patches_uuid = None + if not hasattr(self.model, 'model_offload_buffer_memory'): + self.model.model_offload_buffer_memory = 0 + def model_size(self): if self.size > 0: return self.size @@ -662,7 +674,16 @@ class ModelPatcher: skip = True # skip random weights in non leaf modules break if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0): - loading.append((comfy.model_management.module_size(m), n, m, params)) + module_mem = comfy.model_management.module_size(m) + module_offload_mem = module_mem + if hasattr(m, "comfy_cast_weights"): + weight_key = "{}.weight".format(n) + bias_key = "{}.bias".format(n) + if weight_key in self.patches: + module_offload_mem += low_vram_patch_estimate_vram(self.model, weight_key) + if bias_key in self.patches: + module_offload_mem += low_vram_patch_estimate_vram(self.model, bias_key) + loading.append((module_offload_mem, module_mem, n, m, params)) return loading def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False): @@ -676,20 +697,22 @@ class ModelPatcher: load_completely = [] offloaded = [] + offload_buffer = 0 loading.sort(reverse=True) for x in loading: - n = x[1] - m = x[2] - params = x[3] - module_mem = x[0] + module_offload_mem, module_mem, n, m, params = x lowvram_weight = False + potential_offload = max(offload_buffer, module_offload_mem * (comfy.model_management.NUM_STREAMS + 1)) + lowvram_fits = mem_counter + module_mem + potential_offload < lowvram_model_memory + weight_key = "{}.weight".format(n) bias_key = "{}.bias".format(n) if not full_load and hasattr(m, "comfy_cast_weights"): - if mem_counter + module_mem >= lowvram_model_memory: + if not lowvram_fits: + offload_buffer = potential_offload lowvram_weight = True lowvram_counter += 1 lowvram_mem_counter += module_mem @@ -723,9 +746,11 @@ class ModelPatcher: if hasattr(m, "comfy_cast_weights"): wipe_lowvram_weight(m) - if full_load or mem_counter + module_mem < lowvram_model_memory: + if full_load or lowvram_fits: mem_counter += module_mem load_completely.append((module_mem, n, m, params)) + else: + offload_buffer = potential_offload if cast_weight and hasattr(m, "comfy_cast_weights"): m.prev_comfy_cast_weights = m.comfy_cast_weights @@ -766,7 +791,7 @@ class ModelPatcher: self.pin_weight_to_device("{}.{}".format(n, param)) if lowvram_counter > 0: - logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, lowvram patches: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), patch_counter)) + logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), offload_buffer / (1024 * 1024), patch_counter)) self.model.model_lowvram = True else: logging.info("loaded completely; {:.2f} MB usable, {:.2f} MB loaded, full load: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load)) @@ -778,6 +803,7 @@ class ModelPatcher: self.model.lowvram_patch_counter += patch_counter self.model.device = device_to self.model.model_loaded_weight_memory = mem_counter + self.model.model_offload_buffer_memory = offload_buffer self.model.current_weight_patches_uuid = self.patches_uuid for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD): @@ -831,6 +857,7 @@ class ModelPatcher: self.model.to(device_to) self.model.device = device_to self.model.model_loaded_weight_memory = 0 + self.model.model_offload_buffer_memory = 0 for m in self.model.modules(): if hasattr(m, "comfy_patched_weights"): @@ -849,13 +876,14 @@ class ModelPatcher: patch_counter = 0 unload_list = self._load_list() unload_list.sort() + offload_buffer = self.model.model_offload_buffer_memory + for unload in unload_list: - if memory_to_free < memory_freed: + if memory_to_free + offload_buffer - self.model.model_offload_buffer_memory < memory_freed: break - module_mem = unload[0] - n = unload[1] - m = unload[2] - params = unload[3] + module_offload_mem, module_mem, n, m, params = unload + + potential_offload = (comfy.model_management.NUM_STREAMS + 1) * module_offload_mem lowvram_possible = hasattr(m, "comfy_cast_weights") if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True: @@ -906,15 +934,18 @@ class ModelPatcher: m.comfy_cast_weights = True m.comfy_patched_weights = False memory_freed += module_mem + offload_buffer = max(offload_buffer, potential_offload) logging.debug("freed {}".format(n)) for param in params: self.pin_weight_to_device("{}.{}".format(n, param)) + self.model.model_lowvram = True self.model.lowvram_patch_counter += patch_counter self.model.model_loaded_weight_memory -= memory_freed - logging.info("loaded partially: {:.2f} MB loaded, lowvram patches: {}".format(self.model.model_loaded_weight_memory / (1024 * 1024), self.model.lowvram_patch_counter)) + self.model.model_offload_buffer_memory = offload_buffer + logging.info("Unloaded partially: {:.2f} MB freed, {:.2f} MB remains loaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(memory_freed / (1024 * 1024), self.model.model_loaded_weight_memory / (1024 * 1024), offload_buffer / (1024 * 1024), self.model.lowvram_patch_counter)) return memory_freed def partially_load(self, device_to, extra_memory=0, force_patch_weights=False): diff --git a/comfy/ops.py b/comfy/ops.py index a0ff4e8f1..61a2f0754 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -95,6 +95,8 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of if offload_stream is not None: wf_context = offload_stream + if hasattr(wf_context, "as_context"): + wf_context = wf_context.as_context(offload_stream) else: wf_context = contextlib.nullcontext() diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index d2f3e7397..bb1fb860c 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -235,8 +235,8 @@ class QuantizedTensor(torch.Tensor): def is_pinned(self): return self._qdata.is_pinned() - def is_contiguous(self): - return self._qdata.is_contiguous() + def is_contiguous(self, *arg, **kwargs): + return self._qdata.is_contiguous(*arg, **kwargs) # ============================================================================== # Generic Utilities (Layout-Agnostic Operations) @@ -425,7 +425,8 @@ class TensorCoreFP8Layout(QuantizedLayout): @staticmethod def dequantize(qdata, scale, orig_dtype, **kwargs): plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype) - return plain_tensor * scale + plain_tensor.mul_(scale) + return plain_tensor @classmethod def get_plain_tensors(cls, qtensor): diff --git a/comfy/sd.py b/comfy/sd.py index 350fae92b..f9e5efab5 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -53,6 +53,7 @@ import comfy.text_encoders.omnigen2 import comfy.text_encoders.qwen_image import comfy.text_encoders.hunyuan_image import comfy.text_encoders.z_image +import comfy.text_encoders.ovis import comfy.model_patcher import comfy.lora @@ -60,6 +61,8 @@ import comfy.lora_convert import comfy.hooks import comfy.t2i_adapter.adapter import comfy.taesd.taesd +import comfy.taesd.taehv +import comfy.latent_formats import comfy.ldm.flux.redux @@ -508,13 +511,14 @@ class VAE: self.memory_used_encode = lambda shape, dtype: 3300 * shape[3] * shape[4] * model_management.dtype_size(dtype) self.memory_used_decode = lambda shape, dtype: 8000 * shape[3] * shape[4] * (16 * 16) * model_management.dtype_size(dtype) else: # Wan 2.1 VAE + dim = sd["decoder.head.0.gamma"].shape[0] self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8) self.upscale_index_formula = (4, 8, 8) self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8) self.downscale_index_formula = (4, 8, 8) self.latent_dim = 3 self.latent_channels = 16 - ddconfig = {"dim": 96, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0} + ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0} self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig) self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32] self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype) @@ -584,6 +588,35 @@ class VAE: self.process_input = lambda audio: audio self.working_dtypes = [torch.float32] self.crop_input = False + elif "decoder.22.bias" in sd: # taehv, taew and lighttae + self.latent_channels = sd["decoder.1.weight"].shape[1] + self.latent_dim = 3 + self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16) + self.upscale_index_formula = (4, 16, 16) + self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16) + self.downscale_index_formula = (4, 16, 16) + if self.latent_channels == 48: # Wan 2.2 + self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=None) # taehv doesn't need scaling + self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently")) + self.process_output = lambda image: image + self.memory_used_decode = lambda shape, dtype: (1800 * (max(1, (shape[-3] ** 0.7 * 0.1)) * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)) + elif self.latent_channels == 32 and sd["decoder.22.bias"].shape[0] == 12: # lighttae_hv15 + self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=comfy.latent_formats.HunyuanVideo15) + self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently")) + self.memory_used_decode = lambda shape, dtype: (1200 * (max(1, (shape[-3] ** 0.7 * 0.05)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype)) + else: + if sd["decoder.1.weight"].dtype == torch.float16: # taehv currently only available in float16, so assume it's not lighttaew2_1 as otherwise state dicts are identical + latent_format=comfy.latent_formats.HunyuanVideo + else: + latent_format=None # lighttaew2_1 doesn't need scaling + self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=latent_format) + self.process_input = self.process_output = lambda image: image + self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8) + self.upscale_index_formula = (4, 8, 8) + self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8) + self.downscale_index_formula = (4, 8, 8) + self.memory_used_encode = lambda shape, dtype: (700 * (max(1, (shape[-3] ** 0.66 * 0.11)) * shape[-2] * shape[-1]) * model_management.dtype_size(dtype)) + self.memory_used_decode = lambda shape, dtype: (50 * (max(1, (shape[-3] ** 0.65 * 0.26)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype)) else: logging.warning("WARNING: No VAE weights detected, VAE not initalized.") self.first_stage_model = None @@ -924,6 +957,7 @@ class CLIPType(Enum): QWEN_IMAGE = 18 HUNYUAN_IMAGE = 19 HUNYUAN_VIDEO_15 = 20 + OVIS = 21 def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}): @@ -955,6 +989,7 @@ class TEModel(Enum): MISTRAL3_24B = 14 MISTRAL3_24B_PRUNED_FLUX2 = 15 QWEN3_4B = 16 + QWEN3_2B = 17 def detect_te_model(sd): @@ -988,9 +1023,12 @@ def detect_te_model(sd): if weight.shape[0] == 512: return TEModel.QWEN25_7B if "model.layers.0.post_attention_layernorm.weight" in sd: - if 'model.layers.0.self_attn.q_norm.weight' in sd: - return TEModel.QWEN3_4B weight = sd['model.layers.0.post_attention_layernorm.weight'] + if 'model.layers.0.self_attn.q_norm.weight' in sd: + if weight.shape[0] == 2560: + return TEModel.QWEN3_4B + elif weight.shape[0] == 2048: + return TEModel.QWEN3_2B if weight.shape[0] == 5120: if "model.layers.39.post_attention_layernorm.weight" in sd: return TEModel.MISTRAL3_24B @@ -1118,6 +1156,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip elif te_model == TEModel.QWEN3_4B: clip_target.clip = comfy.text_encoders.z_image.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.z_image.ZImageTokenizer + elif te_model == TEModel.QWEN3_2B: + clip_target.clip = comfy.text_encoders.ovis.te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.ovis.OvisTokenizer else: # clip_l if clip_type == CLIPType.SD3: diff --git a/comfy/taesd/taehv.py b/comfy/taesd/taehv.py new file mode 100644 index 000000000..3dfe1e4d4 --- /dev/null +++ b/comfy/taesd/taehv.py @@ -0,0 +1,171 @@ +# Tiny AutoEncoder for HunyuanVideo and WanVideo https://github.com/madebyollin/taehv + +import torch +import torch.nn as nn +import torch.nn.functional as F +from tqdm.auto import tqdm +from collections import namedtuple, deque + +import comfy.ops +operations=comfy.ops.disable_weight_init + +DecoderResult = namedtuple("DecoderResult", ("frame", "memory")) +TWorkItem = namedtuple("TWorkItem", ("input_tensor", "block_index")) + +def conv(n_in, n_out, **kwargs): + return operations.Conv2d(n_in, n_out, 3, padding=1, **kwargs) + +class Clamp(nn.Module): + def forward(self, x): + return torch.tanh(x / 3) * 3 + +class MemBlock(nn.Module): + def __init__(self, n_in, n_out, act_func): + super().__init__() + self.conv = nn.Sequential(conv(n_in * 2, n_out), act_func, conv(n_out, n_out), act_func, conv(n_out, n_out)) + self.skip = operations.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity() + self.act = act_func + def forward(self, x, past): + return self.act(self.conv(torch.cat([x, past], 1)) + self.skip(x)) + +class TPool(nn.Module): + def __init__(self, n_f, stride): + super().__init__() + self.stride = stride + self.conv = operations.Conv2d(n_f*stride,n_f, 1, bias=False) + def forward(self, x): + _NT, C, H, W = x.shape + return self.conv(x.reshape(-1, self.stride * C, H, W)) + +class TGrow(nn.Module): + def __init__(self, n_f, stride): + super().__init__() + self.stride = stride + self.conv = operations.Conv2d(n_f, n_f*stride, 1, bias=False) + def forward(self, x): + _NT, C, H, W = x.shape + x = self.conv(x) + return x.reshape(-1, C, H, W) + +def apply_model_with_memblocks(model, x, parallel, show_progress_bar): + + B, T, C, H, W = x.shape + if parallel: + x = x.reshape(B*T, C, H, W) + # parallel over input timesteps, iterate over blocks + for b in tqdm(model, disable=not show_progress_bar): + if isinstance(b, MemBlock): + BT, C, H, W = x.shape + T = BT // B + _x = x.reshape(B, T, C, H, W) + mem = F.pad(_x, (0,0,0,0,0,0,1,0), value=0)[:,:T].reshape(x.shape) + x = b(x, mem) + else: + x = b(x) + BT, C, H, W = x.shape + T = BT // B + x = x.view(B, T, C, H, W) + else: + out = [] + work_queue = deque([TWorkItem(xt, 0) for t, xt in enumerate(x.reshape(B, T * C, H, W).chunk(T, dim=1))]) + progress_bar = tqdm(range(T), disable=not show_progress_bar) + mem = [None] * len(model) + while work_queue: + xt, i = work_queue.popleft() + if i == 0: + progress_bar.update(1) + if i == len(model): + out.append(xt) + del xt + else: + b = model[i] + if isinstance(b, MemBlock): + if mem[i] is None: + xt_new = b(xt, xt * 0) + mem[i] = xt.detach().clone() + else: + xt_new = b(xt, mem[i]) + mem[i] = xt.detach().clone() + del xt + work_queue.appendleft(TWorkItem(xt_new, i+1)) + elif isinstance(b, TPool): + if mem[i] is None: + mem[i] = [] + mem[i].append(xt.detach().clone()) + if len(mem[i]) == b.stride: + B, C, H, W = xt.shape + xt = b(torch.cat(mem[i], 1).view(B*b.stride, C, H, W)) + mem[i] = [] + work_queue.appendleft(TWorkItem(xt, i+1)) + elif isinstance(b, TGrow): + xt = b(xt) + NT, C, H, W = xt.shape + for xt_next in reversed(xt.view(B, b.stride*C, H, W).chunk(b.stride, 1)): + work_queue.appendleft(TWorkItem(xt_next, i+1)) + del xt + else: + xt = b(xt) + work_queue.appendleft(TWorkItem(xt, i+1)) + progress_bar.close() + x = torch.stack(out, 1) + return x + + +class TAEHV(nn.Module): + def __init__(self, latent_channels, parallel=False, decoder_time_upscale=(True, True), decoder_space_upscale=(True, True, True), latent_format=None, show_progress_bar=True): + super().__init__() + self.image_channels = 3 + self.patch_size = 1 + self.latent_channels = latent_channels + self.parallel = parallel + self.latent_format = latent_format + self.show_progress_bar = show_progress_bar + self.process_in = latent_format().process_in if latent_format is not None else (lambda x: x) + self.process_out = latent_format().process_out if latent_format is not None else (lambda x: x) + if self.latent_channels in [48, 32]: # Wan 2.2 and HunyuanVideo1.5 + self.patch_size = 2 + if self.latent_channels == 32: # HunyuanVideo1.5 + act_func = nn.LeakyReLU(0.2, inplace=True) + else: # HunyuanVideo, Wan 2.1 + act_func = nn.ReLU(inplace=True) + + self.encoder = nn.Sequential( + conv(self.image_channels*self.patch_size**2, 64), act_func, + TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), + TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), + TPool(64, 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), + conv(64, self.latent_channels), + ) + n_f = [256, 128, 64, 64] + self.frames_to_trim = 2**sum(decoder_time_upscale) - 1 + self.decoder = nn.Sequential( + Clamp(), conv(self.latent_channels, n_f[0]), act_func, + MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1), TGrow(n_f[0], 1), conv(n_f[0], n_f[1], bias=False), + MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1), TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1), conv(n_f[1], n_f[2], bias=False), + MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1), conv(n_f[2], n_f[3], bias=False), + act_func, conv(n_f[3], self.image_channels*self.patch_size**2), + ) + @property + def show_progress_bar(self): + return self._show_progress_bar + + @show_progress_bar.setter + def show_progress_bar(self, value): + self._show_progress_bar = value + + def encode(self, x, **kwargs): + if self.patch_size > 1: x = F.pixel_unshuffle(x, self.patch_size) + x = x.movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W] + if x.shape[1] % 4 != 0: + # pad at end to multiple of 4 + n_pad = 4 - x.shape[1] % 4 + padding = x[:, -1:].repeat_interleave(n_pad, dim=1) + x = torch.cat([x, padding], 1) + x = apply_model_with_memblocks(self.encoder, x, self.parallel, self.show_progress_bar).movedim(2, 1) + return self.process_out(x) + + def decode(self, x, **kwargs): + x = self.process_in(x).movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W] + x = apply_model_with_memblocks(self.decoder, x, self.parallel, self.show_progress_bar) + if self.patch_size > 1: x = F.pixel_shuffle(x, self.patch_size) + return x[:, self.frames_to_trim:].movedim(2, 1) diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index cd4b5f76c..0d07ac8c6 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -100,6 +100,28 @@ class Qwen3_4BConfig: rope_scale = None final_norm: bool = True +@dataclass +class Ovis25_2BConfig: + vocab_size: int = 151936 + hidden_size: int = 2048 + intermediate_size: int = 6144 + num_hidden_layers: int = 28 + num_attention_heads: int = 16 + num_key_value_heads: int = 8 + max_position_embeddings: int = 40960 + rms_norm_eps: float = 1e-6 + rope_theta: float = 1000000.0 + transformer_type: str = "llama" + head_dim = 128 + rms_norm_add = False + mlp_activation = "silu" + qkv_bias = False + rope_dims = None + q_norm = "gemma3" + k_norm = "gemma3" + rope_scale = None + final_norm: bool = True + @dataclass class Qwen25_7BVLI_Config: vocab_size: int = 152064 @@ -542,6 +564,15 @@ class Qwen3_4B(BaseLlama, torch.nn.Module): self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.dtype = dtype +class Ovis25_2B(BaseLlama, torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + config = Ovis25_2BConfig(**config_dict) + self.num_layers = config.num_hidden_layers + + self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) + self.dtype = dtype + class Qwen25_7BVLI(BaseLlama, torch.nn.Module): def __init__(self, config_dict, dtype, device, operations): super().__init__() diff --git a/comfy/text_encoders/ovis.py b/comfy/text_encoders/ovis.py new file mode 100644 index 000000000..81c9bd51c --- /dev/null +++ b/comfy/text_encoders/ovis.py @@ -0,0 +1,69 @@ +from transformers import Qwen2Tokenizer +import comfy.text_encoders.llama +from comfy import sd1_clip +import os +import torch +import numbers + +class Qwen3Tokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer") + super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='qwen3_2b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=284, pad_token=151643, tokenizer_data=tokenizer_data) + + +class OvisTokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen3_2b", tokenizer=Qwen3Tokenizer) + self.llama_template = "<|im_start|>user\nDescribe the image by detailing the color, quantity, text, shape, size, texture, spatial relationships of the objects and background: {}<|im_end|>\n<|im_start|>assistant\n\n\n\n\n" + + def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs): + if llama_template is None: + llama_text = self.llama_template.format(text) + else: + llama_text = llama_template.format(text) + + tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs) + return tokens + +class Ovis25_2BModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}): + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Ovis25_2B, enable_attention_masks=attention_mask, return_attention_masks=False, zero_out_masked=True, model_options=model_options) + + +class OvisTEModel(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + super().__init__(device=device, dtype=dtype, name="qwen3_2b", clip_model=Ovis25_2BModel, model_options=model_options) + + def encode_token_weights(self, token_weight_pairs, template_end=-1): + out, pooled = super().encode_token_weights(token_weight_pairs) + tok_pairs = token_weight_pairs["qwen3_2b"][0] + count_im_start = 0 + if template_end == -1: + for i, v in enumerate(tok_pairs): + elem = v[0] + if not torch.is_tensor(elem): + if isinstance(elem, numbers.Integral): + if elem == 4004 and count_im_start < 1: + template_end = i + count_im_start += 1 + + if out.shape[1] > (template_end + 1): + if tok_pairs[template_end + 1][0] == 25: + template_end += 1 + + out = out[:, template_end:] + return out, pooled, {} + + +def te(dtype_llama=None, llama_scaled_fp8=None, llama_quantization_metadata=None): + class OvisTEModel_(OvisTEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options: + model_options = model_options.copy() + model_options["scaled_fp8"] = llama_scaled_fp8 + if dtype_llama is not None: + dtype = dtype_llama + if llama_quantization_metadata is not None: + model_options["quantization_metadata"] = llama_quantization_metadata + super().__init__(device=device, dtype=dtype, model_options=model_options) + return OvisTEModel_ diff --git a/comfy/text_encoders/qwen25_tokenizer/tokenizer_config.json b/comfy/text_encoders/qwen25_tokenizer/tokenizer_config.json index 67688e82c..df5b5d7fe 100644 --- a/comfy/text_encoders/qwen25_tokenizer/tokenizer_config.json +++ b/comfy/text_encoders/qwen25_tokenizer/tokenizer_config.json @@ -179,36 +179,36 @@ "special": false }, "151665": { - "content": "<|img|>", + "content": "", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, - "special": true + "special": false }, "151666": { - "content": "<|endofimg|>", + "content": "", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, - "special": true + "special": false }, "151667": { - "content": "<|meta|>", + "content": "", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, - "special": true + "special": false }, "151668": { - "content": "<|endofmeta|>", + "content": "", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, - "special": true + "special": false } }, "additional_special_tokens": [ diff --git a/comfy/utils.py b/comfy/utils.py index 4bd281057..37485e497 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -675,6 +675,72 @@ def flux_to_diffusers(mmdit_config, output_prefix=""): return key_map +def z_image_to_diffusers(mmdit_config, output_prefix=""): + n_layers = mmdit_config.get("n_layers", 0) + hidden_size = mmdit_config.get("dim", 0) + n_context_refiner = mmdit_config.get("n_refiner_layers", 2) + n_noise_refiner = mmdit_config.get("n_refiner_layers", 2) + key_map = {} + + def add_block_keys(prefix_from, prefix_to, has_adaln=True): + for end in ("weight", "bias"): + k = "{}.attention.".format(prefix_from) + qkv = "{}.attention.qkv.{}".format(prefix_to, end) + key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, hidden_size)) + key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size)) + key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size)) + + block_map = { + "attention.norm_q.weight": "attention.q_norm.weight", + "attention.norm_k.weight": "attention.k_norm.weight", + "attention.to_out.0.weight": "attention.out.weight", + "attention.to_out.0.bias": "attention.out.bias", + "attention_norm1.weight": "attention_norm1.weight", + "attention_norm2.weight": "attention_norm2.weight", + "feed_forward.w1.weight": "feed_forward.w1.weight", + "feed_forward.w2.weight": "feed_forward.w2.weight", + "feed_forward.w3.weight": "feed_forward.w3.weight", + "ffn_norm1.weight": "ffn_norm1.weight", + "ffn_norm2.weight": "ffn_norm2.weight", + } + if has_adaln: + block_map["adaLN_modulation.0.weight"] = "adaLN_modulation.0.weight" + block_map["adaLN_modulation.0.bias"] = "adaLN_modulation.0.bias" + for k, v in block_map.items(): + key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, v) + + for i in range(n_layers): + add_block_keys("layers.{}".format(i), "{}layers.{}".format(output_prefix, i)) + + for i in range(n_context_refiner): + add_block_keys("context_refiner.{}".format(i), "{}context_refiner.{}".format(output_prefix, i)) + + for i in range(n_noise_refiner): + add_block_keys("noise_refiner.{}".format(i), "{}noise_refiner.{}".format(output_prefix, i)) + + MAP_BASIC = [ + ("final_layer.linear.weight", "all_final_layer.2-1.linear.weight"), + ("final_layer.linear.bias", "all_final_layer.2-1.linear.bias"), + ("final_layer.adaLN_modulation.1.weight", "all_final_layer.2-1.adaLN_modulation.1.weight"), + ("final_layer.adaLN_modulation.1.bias", "all_final_layer.2-1.adaLN_modulation.1.bias"), + ("x_embedder.weight", "all_x_embedder.2-1.weight"), + ("x_embedder.bias", "all_x_embedder.2-1.bias"), + ("x_pad_token", "x_pad_token"), + ("cap_embedder.0.weight", "cap_embedder.0.weight"), + ("cap_embedder.1.weight", "cap_embedder.1.weight"), + ("cap_embedder.1.bias", "cap_embedder.1.bias"), + ("cap_pad_token", "cap_pad_token"), + ("t_embedder.mlp.0.weight", "t_embedder.mlp.0.weight"), + ("t_embedder.mlp.0.bias", "t_embedder.mlp.0.bias"), + ("t_embedder.mlp.2.weight", "t_embedder.mlp.2.weight"), + ("t_embedder.mlp.2.bias", "t_embedder.mlp.2.bias"), + ] + + for c, diffusers in MAP_BASIC: + key_map[diffusers] = "{}{}".format(output_prefix, c) + + return key_map + def repeat_to_batch_size(tensor, batch_size, dim=0): if tensor.shape[dim] > batch_size: return tensor.narrow(dim, 0, batch_size) diff --git a/comfy_api/feature_flags.py b/comfy_api/feature_flags.py index 0d4389a6e..bfb77eb5f 100644 --- a/comfy_api/feature_flags.py +++ b/comfy_api/feature_flags.py @@ -13,6 +13,7 @@ from comfy.cli_args import args SERVER_FEATURE_FLAGS: Dict[str, Any] = { "supports_preview_metadata": True, "max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes + "extension": {"manager": {"supports_v4": True}}, } diff --git a/comfy_api/latest/_input_impl/video_types.py b/comfy_api/latest/_input_impl/video_types.py index bde37f90a..7231bf13c 100644 --- a/comfy_api/latest/_input_impl/video_types.py +++ b/comfy_api/latest/_input_impl/video_types.py @@ -336,7 +336,10 @@ class VideoFromComponents(VideoInput): raise ValueError("Only MP4 format is supported for now") if codec != VideoCodec.AUTO and codec != VideoCodec.H264: raise ValueError("Only H264 codec is supported for now") - with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}) as output: + extra_kwargs = {} + if format != VideoContainer.AUTO: + extra_kwargs["format"] = format.value + with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}, **extra_kwargs) as output: # Add metadata before writing any streams if metadata is not None: for key, value in metadata.items(): diff --git a/comfy_api_nodes/apis/gemini_api.py b/comfy_api_nodes/apis/gemini_api.py index d34590d28..a380ecc86 100644 --- a/comfy_api_nodes/apis/gemini_api.py +++ b/comfy_api_nodes/apis/gemini_api.py @@ -58,8 +58,14 @@ class GeminiInlineData(BaseModel): mimeType: GeminiMimeType | None = Field(None) +class GeminiFileData(BaseModel): + fileUri: str | None = Field(None) + mimeType: GeminiMimeType | None = Field(None) + + class GeminiPart(BaseModel): inlineData: GeminiInlineData | None = Field(None) + fileData: GeminiFileData | None = Field(None) text: str | None = Field(None) diff --git a/comfy_api_nodes/apis/kling_api.py b/comfy_api_nodes/apis/kling_api.py new file mode 100644 index 000000000..0a3b447c5 --- /dev/null +++ b/comfy_api_nodes/apis/kling_api.py @@ -0,0 +1,66 @@ +from pydantic import BaseModel, Field + + +class OmniProText2VideoRequest(BaseModel): + model_name: str = Field(..., description="kling-video-o1") + aspect_ratio: str = Field(..., description="'16:9', '9:16' or '1:1'") + duration: str = Field(..., description="'5' or '10'") + prompt: str = Field(...) + mode: str = Field("pro") + + +class OmniParamImage(BaseModel): + image_url: str = Field(...) + type: str | None = Field(None, description="Can be 'first_frame' or 'end_frame'") + + +class OmniParamVideo(BaseModel): + video_url: str = Field(...) + refer_type: str | None = Field(..., description="Can be 'base' or 'feature'") + keep_original_sound: str = Field(..., description="'yes' or 'no'") + + +class OmniProFirstLastFrameRequest(BaseModel): + model_name: str = Field(..., description="kling-video-o1") + image_list: list[OmniParamImage] = Field(..., min_length=1, max_length=7) + duration: str = Field(..., description="'5' or '10'") + prompt: str = Field(...) + mode: str = Field("pro") + + +class OmniProReferences2VideoRequest(BaseModel): + model_name: str = Field(..., description="kling-video-o1") + aspect_ratio: str | None = Field(..., description="'16:9', '9:16' or '1:1'") + image_list: list[OmniParamImage] | None = Field( + None, max_length=7, description="Max length 4 when video is present." + ) + video_list: list[OmniParamVideo] | None = Field(None, max_length=1) + duration: str | None = Field(..., description="From 3 to 10.") + prompt: str = Field(...) + mode: str = Field("pro") + + +class TaskStatusVideoResult(BaseModel): + duration: str | None = Field(None, description="Total video duration") + id: str | None = Field(None, description="Generated video ID") + url: str | None = Field(None, description="URL for generated video") + + +class TaskStatusVideoResults(BaseModel): + videos: list[TaskStatusVideoResult] | None = Field(None) + + +class TaskStatusVideoResponseData(BaseModel): + created_at: int | None = Field(None, description="Task creation time") + updated_at: int | None = Field(None, description="Task update time") + task_status: str | None = None + task_status_msg: str | None = Field(None, description="Additional failure reason. Only for polling endpoint.") + task_id: str | None = Field(None, description="Task ID") + task_result: TaskStatusVideoResults | None = Field(None) + + +class TaskStatusVideoResponse(BaseModel): + code: int | None = Field(None, description="Error code") + message: str | None = Field(None, description="Error message") + request_id: str | None = Field(None, description="Request ID") + data: TaskStatusVideoResponseData | None = Field(None) diff --git a/comfy_api_nodes/apis/veo_api.py b/comfy_api_nodes/apis/veo_api.py index a55137afb..8328d1aa4 100644 --- a/comfy_api_nodes/apis/veo_api.py +++ b/comfy_api_nodes/apis/veo_api.py @@ -1,34 +1,21 @@ -from typing import Optional, Union -from enum import Enum +from typing import Optional from pydantic import BaseModel, Field -class Image2(BaseModel): - bytesBase64Encoded: str - gcsUri: Optional[str] = None - mimeType: Optional[str] = None +class VeoRequestInstanceImage(BaseModel): + bytesBase64Encoded: str | None = Field(None) + gcsUri: str | None = Field(None) + mimeType: str | None = Field(None) -class Image3(BaseModel): - bytesBase64Encoded: Optional[str] = None - gcsUri: str - mimeType: Optional[str] = None - - -class Instance1(BaseModel): - image: Optional[Union[Image2, Image3]] = Field( - None, description='Optional image to guide video generation' - ) +class VeoRequestInstance(BaseModel): + image: VeoRequestInstanceImage | None = Field(None) + lastFrame: VeoRequestInstanceImage | None = Field(None) prompt: str = Field(..., description='Text description of the video') -class PersonGeneration1(str, Enum): - ALLOW = 'ALLOW' - BLOCK = 'BLOCK' - - -class Parameters1(BaseModel): +class VeoRequestParameters(BaseModel): aspectRatio: Optional[str] = Field(None, examples=['16:9']) durationSeconds: Optional[int] = None enhancePrompt: Optional[bool] = None @@ -37,17 +24,18 @@ class Parameters1(BaseModel): description='Generate audio for the video. Only supported by veo 3 models.', ) negativePrompt: Optional[str] = None - personGeneration: Optional[PersonGeneration1] = None + personGeneration: str | None = Field(None, description="ALLOW or BLOCK") sampleCount: Optional[int] = None seed: Optional[int] = None storageUri: Optional[str] = Field( None, description='Optional Cloud Storage URI to upload the video' ) + resolution: str | None = Field(None) class VeoGenVidRequest(BaseModel): - instances: Optional[list[Instance1]] = None - parameters: Optional[Parameters1] = None + instances: list[VeoRequestInstance] | None = Field(None) + parameters: VeoRequestParameters | None = Field(None) class VeoGenVidResponse(BaseModel): diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index 938a20f84..08f7b0f64 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -4,10 +4,7 @@ See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/infer """ import base64 -import json import os -import time -import uuid from enum import Enum from io import BytesIO from typing import Literal @@ -20,6 +17,7 @@ from comfy_api.latest import IO, ComfyExtension, Input from comfy_api.util import VideoCodec, VideoContainer from comfy_api_nodes.apis.gemini_api import ( GeminiContent, + GeminiFileData, GeminiGenerateContentRequest, GeminiGenerateContentResponse, GeminiImageConfig, @@ -38,10 +36,10 @@ from comfy_api_nodes.util import ( get_number_of_images, sync_op, tensor_to_base64_string, + upload_images_to_comfyapi, validate_string, video_to_base64_string, ) -from server import PromptServer GEMINI_BASE_ENDPOINT = "/proxy/vertexai/gemini" GEMINI_MAX_INPUT_FILE_SIZE = 20 * 1024 * 1024 # 20 MB @@ -68,24 +66,43 @@ class GeminiImageModel(str, Enum): gemini_2_5_flash_image = "gemini-2.5-flash-image" -def create_image_parts(image_input: torch.Tensor) -> list[GeminiPart]: - """ - Convert image tensor input to Gemini API compatible parts. - - Args: - image_input: Batch of image tensors from ComfyUI. - - Returns: - List of GeminiPart objects containing the encoded images. - """ +async def create_image_parts( + cls: type[IO.ComfyNode], + images: torch.Tensor, + image_limit: int = 0, +) -> list[GeminiPart]: image_parts: list[GeminiPart] = [] - for image_index in range(image_input.shape[0]): - image_as_b64 = tensor_to_base64_string(image_input[image_index].unsqueeze(0)) + if image_limit < 0: + raise ValueError("image_limit must be greater than or equal to 0 when creating Gemini image parts.") + total_images = get_number_of_images(images) + if total_images <= 0: + raise ValueError("No images provided to create_image_parts; at least one image is required.") + + # If image_limit == 0 --> use all images; otherwise clamp to image_limit. + effective_max = total_images if image_limit == 0 else min(total_images, image_limit) + + # Number of images we'll send as URLs (fileData) + num_url_images = min(effective_max, 10) # Vertex API max number of image links + reference_images_urls = await upload_images_to_comfyapi( + cls, + images, + max_images=num_url_images, + ) + for reference_image_url in reference_images_urls: + image_parts.append( + GeminiPart( + fileData=GeminiFileData( + mimeType=GeminiMimeType.image_png, + fileUri=reference_image_url, + ) + ) + ) + for idx in range(num_url_images, effective_max): image_parts.append( GeminiPart( inlineData=GeminiInlineData( mimeType=GeminiMimeType.image_png, - data=image_as_b64, + data=tensor_to_base64_string(images[idx]), ) ) ) @@ -338,8 +355,7 @@ class GeminiNode(IO.ComfyNode): # Add other modal parts if images is not None: - image_parts = create_image_parts(images) - parts.extend(image_parts) + parts.extend(await create_image_parts(cls, images)) if audio is not None: parts.extend(cls.create_audio_parts(audio)) if video is not None: @@ -364,29 +380,6 @@ class GeminiNode(IO.ComfyNode): ) output_text = get_text_from_response(response) - if output_text: - # Not a true chat history like the OpenAI Chat node. It is emulated so the frontend can show a copy button. - render_spec = { - "node_id": cls.hidden.unique_id, - "component": "ChatHistoryWidget", - "props": { - "history": json.dumps( - [ - { - "prompt": prompt, - "response": output_text, - "response_id": str(uuid.uuid4()), - "timestamp": time.time(), - } - ] - ), - }, - } - PromptServer.instance.send_sync( - "display_component", - render_spec, - ) - return IO.NodeOutput(output_text or "Empty response from Gemini model...") @@ -562,8 +555,7 @@ class GeminiImage(IO.ComfyNode): image_config = GeminiImageConfig(aspectRatio=aspect_ratio) if images is not None: - image_parts = create_image_parts(images) - parts.extend(image_parts) + parts.extend(await create_image_parts(cls, images)) if files is not None: parts.extend(files) @@ -582,30 +574,7 @@ class GeminiImage(IO.ComfyNode): response_model=GeminiGenerateContentResponse, price_extractor=calculate_tokens_price, ) - - output_text = get_text_from_response(response) - if output_text: - render_spec = { - "node_id": cls.hidden.unique_id, - "component": "ChatHistoryWidget", - "props": { - "history": json.dumps( - [ - { - "prompt": prompt, - "response": output_text, - "response_id": str(uuid.uuid4()), - "timestamp": time.time(), - } - ] - ), - }, - } - PromptServer.instance.send_sync( - "display_component", - render_spec, - ) - return IO.NodeOutput(get_image_from_response(response), output_text) + return IO.NodeOutput(get_image_from_response(response), get_text_from_response(response)) class GeminiImage2(IO.ComfyNode): @@ -702,7 +671,7 @@ class GeminiImage2(IO.ComfyNode): if images is not None: if get_number_of_images(images) > 14: raise ValueError("The current maximum number of supported images is 14.") - parts.extend(create_image_parts(images)) + parts.extend(await create_image_parts(cls, images)) if files is not None: parts.extend(files) @@ -725,30 +694,7 @@ class GeminiImage2(IO.ComfyNode): response_model=GeminiGenerateContentResponse, price_extractor=calculate_tokens_price, ) - - output_text = get_text_from_response(response) - if output_text: - render_spec = { - "node_id": cls.hidden.unique_id, - "component": "ChatHistoryWidget", - "props": { - "history": json.dumps( - [ - { - "prompt": prompt, - "response": output_text, - "response_id": str(uuid.uuid4()), - "timestamp": time.time(), - } - ] - ), - }, - } - PromptServer.instance.send_sync( - "display_component", - render_spec, - ) - return IO.NodeOutput(get_image_from_response(response), output_text) + return IO.NodeOutput(get_image_from_response(response), get_text_from_response(response)) class GeminiExtension(ComfyExtension): diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 36852038b..850c44db6 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -4,15 +4,13 @@ For source of truth on the allowed permutations of request fields, please refere - [Compatibility Table](https://app.klingai.com/global/dev/document-api/apiReference/model/skillsMap) """ -from __future__ import annotations -from typing import Optional, TypeVar -import math import logging - -from typing_extensions import override +import math import torch +from typing_extensions import override +from comfy_api.latest import IO, ComfyExtension, Input, InputImpl from comfy_api_nodes.apis import ( KlingCameraControl, KlingCameraConfig, @@ -50,25 +48,31 @@ from comfy_api_nodes.apis import ( KlingCharacterEffectModelName, KlingSingleImageEffectModelName, ) +from comfy_api_nodes.apis.kling_api import ( + OmniParamImage, + OmniParamVideo, + OmniProFirstLastFrameRequest, + OmniProReferences2VideoRequest, + OmniProText2VideoRequest, + TaskStatusVideoResponse, +) from comfy_api_nodes.util import ( - validate_image_dimensions, + ApiEndpoint, + download_url_to_image_tensor, + download_url_to_video_output, + get_number_of_images, + poll_op, + sync_op, + tensor_to_base64_string, + upload_audio_to_comfyapi, + upload_images_to_comfyapi, + upload_video_to_comfyapi, validate_image_aspect_ratio, + validate_image_dimensions, + validate_string, validate_video_dimensions, validate_video_duration, - tensor_to_base64_string, - validate_string, - upload_audio_to_comfyapi, - download_url_to_image_tensor, - upload_video_to_comfyapi, - download_url_to_video_output, - sync_op, - ApiEndpoint, - poll_op, ) -from comfy_api.input_impl import VideoFromFile -from comfy_api.input.basic_types import AudioInput -from comfy_api.input.video_types import VideoInput -from comfy_api.latest import ComfyExtension, IO KLING_API_VERSION = "v1" PATH_TEXT_TO_VIDEO = f"/proxy/kling/{KLING_API_VERSION}/videos/text2video" @@ -94,8 +98,6 @@ AVERAGE_DURATION_IMAGE_GEN = 32 AVERAGE_DURATION_VIDEO_EFFECTS = 320 AVERAGE_DURATION_VIDEO_EXTEND = 320 -R = TypeVar("R") - MODE_TEXT2VIDEO = { "standard mode / 5s duration / kling-v1": ("std", "5", "kling-v1"), @@ -130,6 +132,8 @@ MODE_START_END_FRAME = { "pro mode / 10s duration / kling-v1-6": ("pro", "10", "kling-v1-6"), "pro mode / 5s duration / kling-v2-1": ("pro", "5", "kling-v2-1"), "pro mode / 10s duration / kling-v2-1": ("pro", "10", "kling-v2-1"), + "pro mode / 5s duration / kling-v2-5-turbo": ("pro", "5", "kling-v2-5-turbo"), + "pro mode / 10s duration / kling-v2-5-turbo": ("pro", "10", "kling-v2-5-turbo"), } """ Returns a mapping of mode strings to their corresponding (mode, duration, model_name) tuples. @@ -206,6 +210,20 @@ VOICES_CONFIG = { } +async def finish_omni_video_task(cls: type[IO.ComfyNode], response: TaskStatusVideoResponse) -> IO.NodeOutput: + if response.code: + raise RuntimeError( + f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}" + ) + final_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/kling/v1/videos/omni-video/{response.data.task_id}"), + response_model=TaskStatusVideoResponse, + status_extractor=lambda r: (r.data.task_status if r.data else None), + ) + return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) + + def is_valid_camera_control_configs(configs: list[float]) -> bool: """Verifies that at least one camera control configuration is non-zero.""" return any(not math.isclose(value, 0.0) for value in configs) @@ -296,7 +314,7 @@ def get_video_from_response(response) -> KlingVideoResult: return video -def get_video_url_from_response(response) -> Optional[str]: +def get_video_url_from_response(response) -> str | None: """Returns the first video url from the Kling video generation task result. Will not raise an error if the response is not valid. """ @@ -315,7 +333,7 @@ def get_images_from_response(response) -> list[KlingImageResult]: return images -def get_images_urls_from_response(response) -> Optional[str]: +def get_images_urls_from_response(response) -> str | None: """Returns the list of image urls from the Kling image generation task result. Will not raise an error if the response is not valid. If there is only one image, returns the url as a string. If there are multiple images, returns a list of urls. """ @@ -349,7 +367,7 @@ async def execute_text2video( model_mode: str, duration: str, aspect_ratio: str, - camera_control: Optional[KlingCameraControl] = None, + camera_control: KlingCameraControl | None = None, ) -> IO.NodeOutput: validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_T2V) task_creation_response = await sync_op( @@ -394,8 +412,8 @@ async def execute_image2video( model_mode: str, aspect_ratio: str, duration: str, - camera_control: Optional[KlingCameraControl] = None, - end_frame: Optional[torch.Tensor] = None, + camera_control: KlingCameraControl | None = None, + end_frame: torch.Tensor | None = None, ) -> IO.NodeOutput: validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_I2V) validate_input_image(start_frame) @@ -451,9 +469,9 @@ async def execute_video_effect( model_name: str, duration: KlingVideoGenDuration, image_1: torch.Tensor, - image_2: Optional[torch.Tensor] = None, - model_mode: Optional[KlingVideoGenMode] = None, -) -> tuple[VideoFromFile, str, str]: + image_2: torch.Tensor | None = None, + model_mode: KlingVideoGenMode | None = None, +) -> tuple[InputImpl.VideoFromFile, str, str]: if dual_character: request_input_field = KlingDualCharacterEffectInput( model_name=model_name, @@ -499,13 +517,13 @@ async def execute_video_effect( async def execute_lipsync( cls: type[IO.ComfyNode], - video: VideoInput, - audio: Optional[AudioInput] = None, - voice_language: Optional[str] = None, - model_mode: Optional[str] = None, - text: Optional[str] = None, - voice_speed: Optional[float] = None, - voice_id: Optional[str] = None, + video: Input.Video, + audio: Input.Audio | None = None, + voice_language: str | None = None, + model_mode: str | None = None, + text: str | None = None, + voice_speed: float | None = None, + voice_id: str | None = None, ) -> IO.NodeOutput: if text: validate_string(text, field_name="Text", max_length=MAX_PROMPT_LENGTH_LIP_SYNC) @@ -740,6 +758,386 @@ class KlingTextToVideoNode(IO.ComfyNode): ) +class OmniProTextToVideoNode(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="KlingOmniProTextToVideoNode", + display_name="Kling Omni Text to Video (Pro)", + category="api node/video/Kling", + description="Use text prompts to generate videos with the latest Kling model.", + inputs=[ + IO.Combo.Input("model_name", options=["kling-video-o1"]), + IO.String.Input( + "prompt", + multiline=True, + tooltip="A text prompt describing the video content. " + "This can include both positive and negative descriptions.", + ), + IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]), + IO.Combo.Input("duration", options=[5, 10]), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model_name: str, + prompt: str, + aspect_ratio: str, + duration: int, + ) -> IO.NodeOutput: + validate_string(prompt, min_length=1, max_length=2500) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), + response_model=TaskStatusVideoResponse, + data=OmniProText2VideoRequest( + model_name=model_name, + prompt=prompt, + aspect_ratio=aspect_ratio, + duration=str(duration), + ), + ) + return await finish_omni_video_task(cls, response) + + +class OmniProFirstLastFrameNode(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="KlingOmniProFirstLastFrameNode", + display_name="Kling Omni First-Last-Frame to Video (Pro)", + category="api node/video/Kling", + description="Use a start frame, an optional end frame, or reference images with the latest Kling model.", + inputs=[ + IO.Combo.Input("model_name", options=["kling-video-o1"]), + IO.String.Input( + "prompt", + multiline=True, + tooltip="A text prompt describing the video content. " + "This can include both positive and negative descriptions.", + ), + IO.Combo.Input("duration", options=["5", "10"]), + IO.Image.Input("first_frame"), + IO.Image.Input( + "end_frame", + optional=True, + tooltip="An optional end frame for the video. " + "This cannot be used simultaneously with 'reference_images'.", + ), + IO.Image.Input( + "reference_images", + optional=True, + tooltip="Up to 6 additional reference images.", + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model_name: str, + prompt: str, + duration: int, + first_frame: Input.Image, + end_frame: Input.Image | None = None, + reference_images: Input.Image | None = None, + ) -> IO.NodeOutput: + validate_string(prompt, min_length=1, max_length=2500) + if end_frame is not None and reference_images is not None: + raise ValueError("The 'end_frame' input cannot be used simultaneously with 'reference_images'.") + validate_image_dimensions(first_frame, min_width=300, min_height=300) + validate_image_aspect_ratio(first_frame, (1, 2.5), (2.5, 1)) + image_list: list[OmniParamImage] = [ + OmniParamImage( + image_url=(await upload_images_to_comfyapi(cls, first_frame, wait_label="Uploading first frame"))[0], + type="first_frame", + ) + ] + if end_frame is not None: + validate_image_dimensions(end_frame, min_width=300, min_height=300) + validate_image_aspect_ratio(end_frame, (1, 2.5), (2.5, 1)) + image_list.append( + OmniParamImage( + image_url=(await upload_images_to_comfyapi(cls, end_frame, wait_label="Uploading end frame"))[0], + type="end_frame", + ) + ) + if reference_images is not None: + if get_number_of_images(reference_images) > 6: + raise ValueError("The maximum number of reference images allowed is 6.") + for i in reference_images: + validate_image_dimensions(i, min_width=300, min_height=300) + validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1)) + for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference frame(s)"): + image_list.append(OmniParamImage(image_url=i)) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), + response_model=TaskStatusVideoResponse, + data=OmniProFirstLastFrameRequest( + model_name=model_name, + prompt=prompt, + duration=str(duration), + image_list=image_list, + ), + ) + return await finish_omni_video_task(cls, response) + + +class OmniProImageToVideoNode(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="KlingOmniProImageToVideoNode", + display_name="Kling Omni Image to Video (Pro)", + category="api node/video/Kling", + description="Use up to 7 reference images to generate a video with the latest Kling model.", + inputs=[ + IO.Combo.Input("model_name", options=["kling-video-o1"]), + IO.String.Input( + "prompt", + multiline=True, + tooltip="A text prompt describing the video content. " + "This can include both positive and negative descriptions.", + ), + IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]), + IO.Int.Input("duration", default=3, min=3, max=10, display_mode=IO.NumberDisplay.slider), + IO.Image.Input( + "reference_images", + tooltip="Up to 7 reference images.", + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model_name: str, + prompt: str, + aspect_ratio: str, + duration: int, + reference_images: Input.Image, + ) -> IO.NodeOutput: + validate_string(prompt, min_length=1, max_length=2500) + if get_number_of_images(reference_images) > 7: + raise ValueError("The maximum number of reference images is 7.") + for i in reference_images: + validate_image_dimensions(i, min_width=300, min_height=300) + validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1)) + image_list: list[OmniParamImage] = [] + for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"): + image_list.append(OmniParamImage(image_url=i)) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), + response_model=TaskStatusVideoResponse, + data=OmniProReferences2VideoRequest( + model_name=model_name, + prompt=prompt, + aspect_ratio=aspect_ratio, + duration=str(duration), + image_list=image_list, + ), + ) + return await finish_omni_video_task(cls, response) + + +class OmniProVideoToVideoNode(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="KlingOmniProVideoToVideoNode", + display_name="Kling Omni Video to Video (Pro)", + category="api node/video/Kling", + description="Use a video and up to 4 reference images to generate a video with the latest Kling model.", + inputs=[ + IO.Combo.Input("model_name", options=["kling-video-o1"]), + IO.String.Input( + "prompt", + multiline=True, + tooltip="A text prompt describing the video content. " + "This can include both positive and negative descriptions.", + ), + IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]), + IO.Int.Input("duration", default=3, min=3, max=10, display_mode=IO.NumberDisplay.slider), + IO.Video.Input("reference_video", tooltip="Video to use as a reference."), + IO.Boolean.Input("keep_original_sound", default=True), + IO.Image.Input( + "reference_images", + tooltip="Up to 4 additional reference images.", + optional=True, + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model_name: str, + prompt: str, + aspect_ratio: str, + duration: int, + reference_video: Input.Video, + keep_original_sound: bool, + reference_images: Input.Image | None = None, + ) -> IO.NodeOutput: + validate_string(prompt, min_length=1, max_length=2500) + validate_video_duration(reference_video, min_duration=3.0, max_duration=10.05) + validate_video_dimensions(reference_video, min_width=720, min_height=720, max_width=2160, max_height=2160) + image_list: list[OmniParamImage] = [] + if reference_images is not None: + if get_number_of_images(reference_images) > 4: + raise ValueError("The maximum number of reference images allowed with a video input is 4.") + for i in reference_images: + validate_image_dimensions(i, min_width=300, min_height=300) + validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1)) + for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"): + image_list.append(OmniParamImage(image_url=i)) + video_list = [ + OmniParamVideo( + video_url=await upload_video_to_comfyapi(cls, reference_video, wait_label="Uploading reference video"), + refer_type="feature", + keep_original_sound="yes" if keep_original_sound else "no", + ) + ] + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), + response_model=TaskStatusVideoResponse, + data=OmniProReferences2VideoRequest( + model_name=model_name, + prompt=prompt, + aspect_ratio=aspect_ratio, + duration=str(duration), + image_list=image_list if image_list else None, + video_list=video_list, + ), + ) + return await finish_omni_video_task(cls, response) + + +class OmniProEditVideoNode(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="KlingOmniProEditVideoNode", + display_name="Kling Omni Edit Video (Pro)", + category="api node/video/Kling", + description="Edit an existing video with the latest model from Kling.", + inputs=[ + IO.Combo.Input("model_name", options=["kling-video-o1"]), + IO.String.Input( + "prompt", + multiline=True, + tooltip="A text prompt describing the video content. " + "This can include both positive and negative descriptions.", + ), + IO.Video.Input("video", tooltip="Video for editing. The output video length will be the same."), + IO.Boolean.Input("keep_original_sound", default=True), + IO.Image.Input( + "reference_images", + tooltip="Up to 4 additional reference images.", + optional=True, + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model_name: str, + prompt: str, + video: Input.Video, + keep_original_sound: bool, + reference_images: Input.Image | None = None, + ) -> IO.NodeOutput: + validate_string(prompt, min_length=1, max_length=2500) + validate_video_duration(video, min_duration=3.0, max_duration=10.05) + validate_video_dimensions(video, min_width=720, min_height=720, max_width=2160, max_height=2160) + image_list: list[OmniParamImage] = [] + if reference_images is not None: + if get_number_of_images(reference_images) > 4: + raise ValueError("The maximum number of reference images allowed with a video input is 4.") + for i in reference_images: + validate_image_dimensions(i, min_width=300, min_height=300) + validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1)) + for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"): + image_list.append(OmniParamImage(image_url=i)) + video_list = [ + OmniParamVideo( + video_url=await upload_video_to_comfyapi(cls, video, wait_label="Uploading base video"), + refer_type="base", + keep_original_sound="yes" if keep_original_sound else "no", + ) + ] + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), + response_model=TaskStatusVideoResponse, + data=OmniProReferences2VideoRequest( + model_name=model_name, + prompt=prompt, + aspect_ratio=None, + duration=None, + image_list=image_list if image_list else None, + video_list=video_list, + ), + ) + return await finish_omni_video_task(cls, response) + + class KlingCameraControlT2VNode(IO.ComfyNode): """ Kling Text to Video Camera Control Node. This node is a text to video node, but it supports controlling the camera. @@ -787,7 +1185,7 @@ class KlingCameraControlT2VNode(IO.ComfyNode): negative_prompt: str, cfg_scale: float, aspect_ratio: str, - camera_control: Optional[KlingCameraControl] = None, + camera_control: KlingCameraControl | None = None, ) -> IO.NodeOutput: return await execute_text2video( cls, @@ -854,8 +1252,8 @@ class KlingImage2VideoNode(IO.ComfyNode): mode: str, aspect_ratio: str, duration: str, - camera_control: Optional[KlingCameraControl] = None, - end_frame: Optional[torch.Tensor] = None, + camera_control: KlingCameraControl | None = None, + end_frame: torch.Tensor | None = None, ) -> IO.NodeOutput: return await execute_image2video( cls, @@ -965,15 +1363,11 @@ class KlingStartEndFrameNode(IO.ComfyNode): IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"), IO.String.Input("negative_prompt", multiline=True, tooltip="Negative text prompt"), IO.Float.Input("cfg_scale", default=0.5, min=0.0, max=1.0), - IO.Combo.Input( - "aspect_ratio", - options=[i.value for i in KlingVideoGenAspectRatio], - default="16:9", - ), + IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]), IO.Combo.Input( "mode", options=modes, - default=modes[2], + default=modes[8], tooltip="The configuration to use for the video generation following the format: mode / duration / model_name.", ), ], @@ -1170,7 +1564,10 @@ class KlingSingleImageVideoEffectNode(IO.ComfyNode): category="api node/video/Kling", description="Achieve different special effects when generating a video based on the effect_scene.", inputs=[ - IO.Image.Input("image", tooltip=" Reference Image. URL or Base64 encoded string (without data:image prefix). File size cannot exceed 10MB, resolution not less than 300*300px, aspect ratio between 1:2.5 ~ 2.5:1"), + IO.Image.Input( + "image", + tooltip=" Reference Image. URL or Base64 encoded string (without data:image prefix). File size cannot exceed 10MB, resolution not less than 300*300px, aspect ratio between 1:2.5 ~ 2.5:1", + ), IO.Combo.Input( "effect_scene", options=[i.value for i in KlingSingleImageEffectsScene], @@ -1254,8 +1651,8 @@ class KlingLipSyncAudioToVideoNode(IO.ComfyNode): @classmethod async def execute( cls, - video: VideoInput, - audio: AudioInput, + video: Input.Video, + audio: Input.Audio, voice_language: str, ) -> IO.NodeOutput: return await execute_lipsync( @@ -1314,7 +1711,7 @@ class KlingLipSyncTextToVideoNode(IO.ComfyNode): @classmethod async def execute( cls, - video: VideoInput, + video: Input.Video, text: str, voice: str, voice_speed: float, @@ -1471,7 +1868,7 @@ class KlingImageGenerationNode(IO.ComfyNode): human_fidelity: float, n: int, aspect_ratio: KlingImageGenAspectRatio, - image: Optional[torch.Tensor] = None, + image: torch.Tensor | None = None, ) -> IO.NodeOutput: validate_string(prompt, field_name="prompt", min_length=1, max_length=MAX_PROMPT_LENGTH_IMAGE_GEN) validate_string(negative_prompt, field_name="negative_prompt", max_length=MAX_PROMPT_LENGTH_IMAGE_GEN) @@ -1533,6 +1930,11 @@ class KlingExtension(ComfyExtension): KlingImageGenerationNode, KlingSingleImageVideoEffectNode, KlingDualCharacterVideoEffectNode, + OmniProTextToVideoNode, + OmniProFirstLastFrameNode, + OmniProImageToVideoNode, + OmniProVideoToVideoNode, + OmniProEditVideoNode, ] diff --git a/comfy_api_nodes/nodes_openai.py b/comfy_api_nodes/nodes_openai.py index acf35d276..c8da5464b 100644 --- a/comfy_api_nodes/nodes_openai.py +++ b/comfy_api_nodes/nodes_openai.py @@ -1,15 +1,10 @@ from io import BytesIO -from typing import Optional, Union -import json import os -import time -import uuid from enum import Enum from inspect import cleandoc import numpy as np import torch from PIL import Image -from server import PromptServer import folder_paths import base64 from comfy_api.latest import IO, ComfyExtension @@ -587,11 +582,11 @@ class OpenAIChatNode(IO.ComfyNode): def create_input_message_contents( cls, prompt: str, - image: Optional[torch.Tensor] = None, - files: Optional[list[InputFileContent]] = None, + image: torch.Tensor | None = None, + files: list[InputFileContent] | None = None, ) -> InputMessageContentList: """Create a list of input message contents from prompt and optional image.""" - content_list: list[Union[InputContent, InputTextContent, InputImageContent, InputFileContent]] = [ + content_list: list[InputContent | InputTextContent | InputImageContent | InputFileContent] = [ InputTextContent(text=prompt, type="input_text"), ] if image is not None: @@ -617,9 +612,9 @@ class OpenAIChatNode(IO.ComfyNode): prompt: str, persist_context: bool = False, model: SupportedOpenAIModel = SupportedOpenAIModel.gpt_5.value, - images: Optional[torch.Tensor] = None, - files: Optional[list[InputFileContent]] = None, - advanced_options: Optional[CreateModelResponseProperties] = None, + images: torch.Tensor | None = None, + files: list[InputFileContent] | None = None, + advanced_options: CreateModelResponseProperties | None = None, ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False) @@ -660,30 +655,7 @@ class OpenAIChatNode(IO.ComfyNode): status_extractor=lambda response: response.status, completed_statuses=["incomplete", "completed"] ) - output_text = cls.get_text_from_message_content(cls.get_message_content_from_response(result_response)) - - # Update history - render_spec = { - "node_id": cls.hidden.unique_id, - "component": "ChatHistoryWidget", - "props": { - "history": json.dumps( - [ - { - "prompt": prompt, - "response": output_text, - "response_id": str(uuid.uuid4()), - "timestamp": time.time(), - } - ] - ), - }, - } - PromptServer.instance.send_sync( - "display_component", - render_spec, - ) - return IO.NodeOutput(output_text) + return IO.NodeOutput(cls.get_text_from_message_content(cls.get_message_content_from_response(result_response))) class OpenAIInputFiles(IO.ComfyNode): @@ -790,8 +762,8 @@ class OpenAIChatConfig(IO.ComfyNode): def execute( cls, truncation: bool, - instructions: Optional[str] = None, - max_output_tokens: Optional[int] = None, + instructions: str | None = None, + max_output_tokens: int | None = None, ) -> IO.NodeOutput: """ Configure advanced options for the OpenAI Chat Node. diff --git a/comfy_api_nodes/nodes_veo2.py b/comfy_api_nodes/nodes_veo2.py index d37e9e9b4..a54dc13ab 100644 --- a/comfy_api_nodes/nodes_veo2.py +++ b/comfy_api_nodes/nodes_veo2.py @@ -1,6 +1,7 @@ import base64 from io import BytesIO +import torch from typing_extensions import override from comfy_api.input_impl.video_types import VideoFromFile @@ -10,6 +11,9 @@ from comfy_api_nodes.apis.veo_api import ( VeoGenVidPollResponse, VeoGenVidRequest, VeoGenVidResponse, + VeoRequestInstance, + VeoRequestInstanceImage, + VeoRequestParameters, ) from comfy_api_nodes.util import ( ApiEndpoint, @@ -346,12 +350,163 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode): ) +class Veo3FirstLastFrameNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="Veo3FirstLastFrameNode", + display_name="Google Veo 3 First-Last-Frame to Video", + category="api node/video/Veo", + description="Generate video using prompt and first and last frames.", + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Text description of the video", + ), + IO.String.Input( + "negative_prompt", + multiline=True, + default="", + tooltip="Negative text prompt to guide what to avoid in the video", + ), + IO.Combo.Input("resolution", options=["720p", "1080p"]), + IO.Combo.Input( + "aspect_ratio", + options=["16:9", "9:16"], + default="16:9", + tooltip="Aspect ratio of the output video", + ), + IO.Int.Input( + "duration", + default=8, + min=4, + max=8, + step=2, + display_mode=IO.NumberDisplay.slider, + tooltip="Duration of the output video in seconds", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFF, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed for video generation", + ), + IO.Image.Input("first_frame", tooltip="Start frame"), + IO.Image.Input("last_frame", tooltip="End frame"), + IO.Combo.Input( + "model", + options=["veo-3.1-generate", "veo-3.1-fast-generate"], + default="veo-3.1-fast-generate", + ), + IO.Boolean.Input( + "generate_audio", + default=True, + tooltip="Generate audio for the video.", + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + prompt: str, + negative_prompt: str, + resolution: str, + aspect_ratio: str, + duration: int, + seed: int, + first_frame: torch.Tensor, + last_frame: torch.Tensor, + model: str, + generate_audio: bool, + ): + model = MODELS_MAP[model] + initial_response = await sync_op( + cls, + ApiEndpoint(path=f"/proxy/veo/{model}/generate", method="POST"), + response_model=VeoGenVidResponse, + data=VeoGenVidRequest( + instances=[ + VeoRequestInstance( + prompt=prompt, + image=VeoRequestInstanceImage( + bytesBase64Encoded=tensor_to_base64_string(first_frame), mimeType="image/png" + ), + lastFrame=VeoRequestInstanceImage( + bytesBase64Encoded=tensor_to_base64_string(last_frame), mimeType="image/png" + ), + ), + ], + parameters=VeoRequestParameters( + aspectRatio=aspect_ratio, + personGeneration="ALLOW", + durationSeconds=duration, + enhancePrompt=True, # cannot be False for Veo3 + seed=seed, + generateAudio=generate_audio, + negativePrompt=negative_prompt, + resolution=resolution, + ), + ), + ) + poll_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/veo/{model}/poll", method="POST"), + response_model=VeoGenVidPollResponse, + status_extractor=lambda r: "completed" if r.done else "pending", + data=VeoGenVidPollRequest( + operationName=initial_response.name, + ), + poll_interval=5.0, + estimated_duration=AVERAGE_DURATION_VIDEO_GEN, + ) + + if poll_response.error: + raise Exception(f"Veo API error: {poll_response.error.message} (code: {poll_response.error.code})") + + response = poll_response.response + filtered_count = response.raiMediaFilteredCount + if filtered_count: + reasons = response.raiMediaFilteredReasons or [] + reason_part = f": {reasons[0]}" if reasons else "" + raise Exception( + f"Content blocked by Google's Responsible AI filters{reason_part} " + f"({filtered_count} video{'s' if filtered_count != 1 else ''} filtered)." + ) + + if response.videos: + video = response.videos[0] + if video.bytesBase64Encoded: + return IO.NodeOutput(VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded)))) + if video.gcsUri: + return IO.NodeOutput(await download_url_to_video_output(video.gcsUri)) + raise Exception("Video returned but no data or URL was provided") + raise Exception("Video generation completed but no video was returned") + + class VeoExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: return [ VeoVideoGenerationNode, Veo3VideoGenerationNode, + Veo3FirstLastFrameNode, ] diff --git a/comfy_api_nodes/util/upload_helpers.py b/comfy_api_nodes/util/upload_helpers.py index 632450d9b..0532bea9a 100644 --- a/comfy_api_nodes/util/upload_helpers.py +++ b/comfy_api_nodes/util/upload_helpers.py @@ -4,7 +4,7 @@ import logging import time import uuid from io import BytesIO -from typing import Optional, Union +from typing import Optional from urllib.parse import urlparse import aiohttp @@ -48,8 +48,9 @@ async def upload_images_to_comfyapi( image: torch.Tensor, *, max_images: int = 8, - mime_type: Optional[str] = None, - wait_label: Optional[str] = "Uploading", + mime_type: str | None = None, + wait_label: str | None = "Uploading", + show_batch_index: bool = True, ) -> list[str]: """ Uploads images to ComfyUI API and returns download URLs. @@ -59,11 +60,18 @@ async def upload_images_to_comfyapi( download_urls: list[str] = [] is_batch = len(image.shape) > 3 batch_len = image.shape[0] if is_batch else 1 + num_to_upload = min(batch_len, max_images) + batch_start_ts = time.monotonic() - for idx in range(min(batch_len, max_images)): + for idx in range(num_to_upload): tensor = image[idx] if is_batch else image img_io = tensor_to_bytesio(tensor, mime_type=mime_type) - url = await upload_file_to_comfyapi(cls, img_io, img_io.name, mime_type, wait_label) + + effective_label = wait_label + if wait_label and show_batch_index and num_to_upload > 1: + effective_label = f"{wait_label} ({idx + 1}/{num_to_upload})" + + url = await upload_file_to_comfyapi(cls, img_io, img_io.name, mime_type, effective_label, batch_start_ts) download_urls.append(url) return download_urls @@ -95,6 +103,7 @@ async def upload_video_to_comfyapi( container: VideoContainer = VideoContainer.MP4, codec: VideoCodec = VideoCodec.H264, max_duration: Optional[int] = None, + wait_label: str | None = "Uploading", ) -> str: """ Uploads a single video to ComfyUI API and returns its download URL. @@ -119,15 +128,16 @@ async def upload_video_to_comfyapi( video.save_to(video_bytes_io, format=container, codec=codec) video_bytes_io.seek(0) - return await upload_file_to_comfyapi(cls, video_bytes_io, filename, upload_mime_type) + return await upload_file_to_comfyapi(cls, video_bytes_io, filename, upload_mime_type, wait_label) async def upload_file_to_comfyapi( cls: type[IO.ComfyNode], file_bytes_io: BytesIO, filename: str, - upload_mime_type: Optional[str], - wait_label: Optional[str] = "Uploading", + upload_mime_type: str | None, + wait_label: str | None = "Uploading", + progress_origin_ts: float | None = None, ) -> str: """Uploads a single file to ComfyUI API and returns its download URL.""" if upload_mime_type is None: @@ -148,6 +158,7 @@ async def upload_file_to_comfyapi( file_bytes_io, content_type=upload_mime_type, wait_label=wait_label, + progress_origin_ts=progress_origin_ts, ) return create_resp.download_url @@ -155,27 +166,18 @@ async def upload_file_to_comfyapi( async def upload_file( cls: type[IO.ComfyNode], upload_url: str, - file: Union[BytesIO, str], + file: BytesIO | str, *, - content_type: Optional[str] = None, + content_type: str | None = None, max_retries: int = 3, retry_delay: float = 1.0, retry_backoff: float = 2.0, - wait_label: Optional[str] = None, + wait_label: str | None = None, + progress_origin_ts: float | None = None, ) -> None: """ Upload a file to a signed URL (e.g., S3 pre-signed PUT) with retries, Comfy progress display, and interruption. - Args: - cls: Node class (provides auth context + UI progress hooks). - upload_url: Pre-signed PUT URL. - file: BytesIO or path string. - content_type: Explicit MIME type. If None, we *suppress* Content-Type. - max_retries: Maximum retry attempts. - retry_delay: Initial delay in seconds. - retry_backoff: Exponential backoff factor. - wait_label: Progress label shown in Comfy UI. - Raises: ProcessingInterrupted, LocalNetworkError, ApiServerError, Exception """ @@ -198,7 +200,7 @@ async def upload_file( attempt = 0 delay = retry_delay - start_ts = time.monotonic() + start_ts = progress_origin_ts if progress_origin_ts is not None else time.monotonic() op_uuid = uuid.uuid4().hex[:8] while True: attempt += 1 diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index c6ff6a30a..fbb080886 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -3,272 +3,312 @@ import comfy.samplers import comfy.sample from comfy.k_diffusion import sampling as k_diffusion_sampling from comfy.k_diffusion import sa_solver -from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict import latent_preview import torch import comfy.utils import node_helpers +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io -class BasicScheduler: +class BasicScheduler(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"model": ("MODEL",), - "scheduler": (comfy.samplers.SCHEDULER_NAMES, ), - "steps": ("INT", {"default": 20, "min": 1, "max": 10000}), - "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), - } - } - RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling/schedulers" + def define_schema(cls): + return io.Schema( + node_id="BasicScheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Model.Input("model"), + io.Combo.Input("scheduler", options=comfy.samplers.SCHEDULER_NAMES), + io.Int.Input("steps", default=20, min=1, max=10000), + io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01), + ], + outputs=[io.Sigmas.Output()] + ) - FUNCTION = "get_sigmas" - - def get_sigmas(self, model, scheduler, steps, denoise): + @classmethod + def execute(cls, model, scheduler, steps, denoise) -> io.NodeOutput: total_steps = steps if denoise < 1.0: if denoise <= 0.0: - return (torch.FloatTensor([]),) + return io.NodeOutput(torch.FloatTensor([])) total_steps = int(steps/denoise) sigmas = comfy.samplers.calculate_sigmas(model.get_model_object("model_sampling"), scheduler, total_steps).cpu() sigmas = sigmas[-(steps + 1):] - return (sigmas, ) + return io.NodeOutput(sigmas) + + get_sigmas = execute -class KarrasScheduler: +class KarrasScheduler(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"steps": ("INT", {"default": 20, "min": 1, "max": 10000}), - "sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), - "sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), - "rho": ("FLOAT", {"default": 7.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - } - } - RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling/schedulers" + def define_schema(cls): + return io.Schema( + node_id="KarrasScheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Int.Input("steps", default=20, min=1, max=10000), + io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False), + io.Float.Input("sigma_min", default=0.0291675, min=0.0, max=5000.0, step=0.01, round=False), + io.Float.Input("rho", default=7.0, min=0.0, max=100.0, step=0.01, round=False), + ], + outputs=[io.Sigmas.Output()] + ) - FUNCTION = "get_sigmas" - - def get_sigmas(self, steps, sigma_max, sigma_min, rho): + @classmethod + def execute(cls, steps, sigma_max, sigma_min, rho) -> io.NodeOutput: sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho) - return (sigmas, ) + return io.NodeOutput(sigmas) -class ExponentialScheduler: + get_sigmas = execute + +class ExponentialScheduler(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"steps": ("INT", {"default": 20, "min": 1, "max": 10000}), - "sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), - "sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), - } - } - RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling/schedulers" + def define_schema(cls): + return io.Schema( + node_id="ExponentialScheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Int.Input("steps", default=20, min=1, max=10000), + io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False), + io.Float.Input("sigma_min", default=0.0291675, min=0.0, max=5000.0, step=0.01, round=False), + ], + outputs=[io.Sigmas.Output()] + ) - FUNCTION = "get_sigmas" - - def get_sigmas(self, steps, sigma_max, sigma_min): + @classmethod + def execute(cls, steps, sigma_max, sigma_min) -> io.NodeOutput: sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max) - return (sigmas, ) + return io.NodeOutput(sigmas) -class PolyexponentialScheduler: + get_sigmas = execute + +class PolyexponentialScheduler(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"steps": ("INT", {"default": 20, "min": 1, "max": 10000}), - "sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), - "sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), - "rho": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - } - } - RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling/schedulers" + def define_schema(cls): + return io.Schema( + node_id="PolyexponentialScheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Int.Input("steps", default=20, min=1, max=10000), + io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False), + io.Float.Input("sigma_min", default=0.0291675, min=0.0, max=5000.0, step=0.01, round=False), + io.Float.Input("rho", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + ], + outputs=[io.Sigmas.Output()] + ) - FUNCTION = "get_sigmas" - - def get_sigmas(self, steps, sigma_max, sigma_min, rho): + @classmethod + def execute(cls, steps, sigma_max, sigma_min, rho) -> io.NodeOutput: sigmas = k_diffusion_sampling.get_sigmas_polyexponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho) - return (sigmas, ) + return io.NodeOutput(sigmas) -class LaplaceScheduler: + get_sigmas = execute + +class LaplaceScheduler(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"steps": ("INT", {"default": 20, "min": 1, "max": 10000}), - "sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), - "sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), - "mu": ("FLOAT", {"default": 0.0, "min": -10.0, "max": 10.0, "step":0.1, "round": False}), - "beta": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step":0.1, "round": False}), - } - } - RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling/schedulers" + def define_schema(cls): + return io.Schema( + node_id="LaplaceScheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Int.Input("steps", default=20, min=1, max=10000), + io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False), + io.Float.Input("sigma_min", default=0.0291675, min=0.0, max=5000.0, step=0.01, round=False), + io.Float.Input("mu", default=0.0, min=-10.0, max=10.0, step=0.1, round=False), + io.Float.Input("beta", default=0.5, min=0.0, max=10.0, step=0.1, round=False), + ], + outputs=[io.Sigmas.Output()] + ) - FUNCTION = "get_sigmas" - - def get_sigmas(self, steps, sigma_max, sigma_min, mu, beta): + @classmethod + def execute(cls, steps, sigma_max, sigma_min, mu, beta) -> io.NodeOutput: sigmas = k_diffusion_sampling.get_sigmas_laplace(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, mu=mu, beta=beta) - return (sigmas, ) + return io.NodeOutput(sigmas) + + get_sigmas = execute -class SDTurboScheduler: +class SDTurboScheduler(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"model": ("MODEL",), - "steps": ("INT", {"default": 1, "min": 1, "max": 10}), - "denoise": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}), - } - } - RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling/schedulers" + def define_schema(cls): + return io.Schema( + node_id="SDTurboScheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Model.Input("model"), + io.Int.Input("steps", default=1, min=1, max=10), + io.Float.Input("denoise", default=1.0, min=0, max=1.0, step=0.01), + ], + outputs=[io.Sigmas.Output()] + ) - FUNCTION = "get_sigmas" - - def get_sigmas(self, model, steps, denoise): + @classmethod + def execute(cls, model, steps, denoise) -> io.NodeOutput: start_step = 10 - int(10 * denoise) timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[start_step:start_step + steps] sigmas = model.get_model_object("model_sampling").sigma(timesteps) sigmas = torch.cat([sigmas, sigmas.new_zeros([1])]) - return (sigmas, ) + return io.NodeOutput(sigmas) -class BetaSamplingScheduler: + get_sigmas = execute + +class BetaSamplingScheduler(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"model": ("MODEL",), - "steps": ("INT", {"default": 20, "min": 1, "max": 10000}), - "alpha": ("FLOAT", {"default": 0.6, "min": 0.0, "max": 50.0, "step":0.01, "round": False}), - "beta": ("FLOAT", {"default": 0.6, "min": 0.0, "max": 50.0, "step":0.01, "round": False}), - } - } - RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling/schedulers" + def define_schema(cls): + return io.Schema( + node_id="BetaSamplingScheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Model.Input("model"), + io.Int.Input("steps", default=20, min=1, max=10000), + io.Float.Input("alpha", default=0.6, min=0.0, max=50.0, step=0.01, round=False), + io.Float.Input("beta", default=0.6, min=0.0, max=50.0, step=0.01, round=False), + ], + outputs=[io.Sigmas.Output()] + ) - FUNCTION = "get_sigmas" - - def get_sigmas(self, model, steps, alpha, beta): + @classmethod + def execute(cls, model, steps, alpha, beta) -> io.NodeOutput: sigmas = comfy.samplers.beta_scheduler(model.get_model_object("model_sampling"), steps, alpha=alpha, beta=beta) - return (sigmas, ) + return io.NodeOutput(sigmas) -class VPScheduler: + get_sigmas = execute + +class VPScheduler(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"steps": ("INT", {"default": 20, "min": 1, "max": 10000}), - "beta_d": ("FLOAT", {"default": 19.9, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), #TODO: fix default values - "beta_min": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), - "eps_s": ("FLOAT", {"default": 0.001, "min": 0.0, "max": 1.0, "step":0.0001, "round": False}), - } - } - RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling/schedulers" + def define_schema(cls): + return io.Schema( + node_id="VPScheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Int.Input("steps", default=20, min=1, max=10000), + io.Float.Input("beta_d", default=19.9, min=0.0, max=5000.0, step=0.01, round=False), #TODO: fix default values + io.Float.Input("beta_min", default=0.1, min=0.0, max=5000.0, step=0.01, round=False), + io.Float.Input("eps_s", default=0.001, min=0.0, max=1.0, step=0.0001, round=False), + ], + outputs=[io.Sigmas.Output()] + ) - FUNCTION = "get_sigmas" - - def get_sigmas(self, steps, beta_d, beta_min, eps_s): + @classmethod + def execute(cls, steps, beta_d, beta_min, eps_s) -> io.NodeOutput: sigmas = k_diffusion_sampling.get_sigmas_vp(n=steps, beta_d=beta_d, beta_min=beta_min, eps_s=eps_s) - return (sigmas, ) + return io.NodeOutput(sigmas) -class SplitSigmas: + get_sigmas = execute + +class SplitSigmas(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"sigmas": ("SIGMAS", ), - "step": ("INT", {"default": 0, "min": 0, "max": 10000}), - } - } - RETURN_TYPES = ("SIGMAS","SIGMAS") - RETURN_NAMES = ("high_sigmas", "low_sigmas") - CATEGORY = "sampling/custom_sampling/sigmas" + def define_schema(cls): + return io.Schema( + node_id="SplitSigmas", + category="sampling/custom_sampling/sigmas", + inputs=[ + io.Sigmas.Input("sigmas"), + io.Int.Input("step", default=0, min=0, max=10000), + ], + outputs=[ + io.Sigmas.Output(display_name="high_sigmas"), + io.Sigmas.Output(display_name="low_sigmas"), + ] + ) - FUNCTION = "get_sigmas" - - def get_sigmas(self, sigmas, step): + @classmethod + def execute(cls, sigmas, step) -> io.NodeOutput: sigmas1 = sigmas[:step + 1] sigmas2 = sigmas[step:] - return (sigmas1, sigmas2) + return io.NodeOutput(sigmas1, sigmas2) -class SplitSigmasDenoise: + get_sigmas = execute + +class SplitSigmasDenoise(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"sigmas": ("SIGMAS", ), - "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), - } - } - RETURN_TYPES = ("SIGMAS","SIGMAS") - RETURN_NAMES = ("high_sigmas", "low_sigmas") - CATEGORY = "sampling/custom_sampling/sigmas" + def define_schema(cls): + return io.Schema( + node_id="SplitSigmasDenoise", + category="sampling/custom_sampling/sigmas", + inputs=[ + io.Sigmas.Input("sigmas"), + io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01), + ], + outputs=[ + io.Sigmas.Output(display_name="high_sigmas"), + io.Sigmas.Output(display_name="low_sigmas"), + ] + ) - FUNCTION = "get_sigmas" - - def get_sigmas(self, sigmas, denoise): + @classmethod + def execute(cls, sigmas, denoise) -> io.NodeOutput: steps = max(sigmas.shape[-1] - 1, 0) total_steps = round(steps * denoise) sigmas1 = sigmas[:-(total_steps)] sigmas2 = sigmas[-(total_steps + 1):] - return (sigmas1, sigmas2) + return io.NodeOutput(sigmas1, sigmas2) -class FlipSigmas: + get_sigmas = execute + +class FlipSigmas(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"sigmas": ("SIGMAS", ), - } - } - RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling/sigmas" + def define_schema(cls): + return io.Schema( + node_id="FlipSigmas", + category="sampling/custom_sampling/sigmas", + inputs=[io.Sigmas.Input("sigmas")], + outputs=[io.Sigmas.Output()] + ) - FUNCTION = "get_sigmas" - - def get_sigmas(self, sigmas): + @classmethod + def execute(cls, sigmas) -> io.NodeOutput: if len(sigmas) == 0: - return (sigmas,) + return io.NodeOutput(sigmas) sigmas = sigmas.flip(0) if sigmas[0] == 0: sigmas[0] = 0.0001 - return (sigmas,) + return io.NodeOutput(sigmas) -class SetFirstSigma: + get_sigmas = execute + +class SetFirstSigma(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"sigmas": ("SIGMAS", ), - "sigma": ("FLOAT", {"default": 136.0, "min": 0.0, "max": 20000.0, "step": 0.001, "round": False}), - } - } - RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling/sigmas" + def define_schema(cls): + return io.Schema( + node_id="SetFirstSigma", + category="sampling/custom_sampling/sigmas", + inputs=[ + io.Sigmas.Input("sigmas"), + io.Float.Input("sigma", default=136.0, min=0.0, max=20000.0, step=0.001, round=False), + ], + outputs=[io.Sigmas.Output()] + ) - FUNCTION = "set_first_sigma" - - def set_first_sigma(self, sigmas, sigma): + @classmethod + def execute(cls, sigmas, sigma) -> io.NodeOutput: sigmas = sigmas.clone() sigmas[0] = sigma - return (sigmas, ) + return io.NodeOutput(sigmas) -class ExtendIntermediateSigmas: + set_first_sigma = execute + +class ExtendIntermediateSigmas(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"sigmas": ("SIGMAS", ), - "steps": ("INT", {"default": 2, "min": 1, "max": 100}), - "start_at_sigma": ("FLOAT", {"default": -1.0, "min": -1.0, "max": 20000.0, "step": 0.01, "round": False}), - "end_at_sigma": ("FLOAT", {"default": 12.0, "min": 0.0, "max": 20000.0, "step": 0.01, "round": False}), - "spacing": (['linear', 'cosine', 'sine'],), - } - } - RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling/sigmas" + def define_schema(cls): + return io.Schema( + node_id="ExtendIntermediateSigmas", + category="sampling/custom_sampling/sigmas", + inputs=[ + io.Sigmas.Input("sigmas"), + io.Int.Input("steps", default=2, min=1, max=100), + io.Float.Input("start_at_sigma", default=-1.0, min=-1.0, max=20000.0, step=0.01, round=False), + io.Float.Input("end_at_sigma", default=12.0, min=0.0, max=20000.0, step=0.01, round=False), + io.Combo.Input("spacing", options=['linear', 'cosine', 'sine']), + ], + outputs=[io.Sigmas.Output()] + ) - FUNCTION = "extend" - - def extend(self, sigmas: torch.Tensor, steps: int, start_at_sigma: float, end_at_sigma: float, spacing: str): + @classmethod + def execute(cls, sigmas: torch.Tensor, steps: int, start_at_sigma: float, end_at_sigma: float, spacing: str) -> io.NodeOutput: if start_at_sigma < 0: start_at_sigma = float("inf") @@ -299,27 +339,27 @@ class ExtendIntermediateSigmas: extended_sigmas = torch.FloatTensor(extended_sigmas) - return (extended_sigmas,) + return io.NodeOutput(extended_sigmas) + + extend = execute -class SamplingPercentToSigma: +class SamplingPercentToSigma(io.ComfyNode): @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: - return { - "required": { - "model": (IO.MODEL, {}), - "sampling_percent": (IO.FLOAT, {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.0001}), - "return_actual_sigma": (IO.BOOLEAN, {"default": False, "tooltip": "Return the actual sigma value instead of the value used for interval checks.\nThis only affects results at 0.0 and 1.0."}), - } - } + def define_schema(cls): + return io.Schema( + node_id="SamplingPercentToSigma", + category="sampling/custom_sampling/sigmas", + inputs=[ + io.Model.Input("model"), + io.Float.Input("sampling_percent", default=0.0, min=0.0, max=1.0, step=0.0001), + io.Boolean.Input("return_actual_sigma", default=False, tooltip="Return the actual sigma value instead of the value used for interval checks.\nThis only affects results at 0.0 and 1.0."), + ], + outputs=[io.Float.Output(display_name="sigma_value")] + ) - RETURN_TYPES = (IO.FLOAT,) - RETURN_NAMES = ("sigma_value",) - CATEGORY = "sampling/custom_sampling/sigmas" - - FUNCTION = "get_sigma" - - def get_sigma(self, model, sampling_percent, return_actual_sigma): + @classmethod + def execute(cls, model, sampling_percent, return_actual_sigma) -> io.NodeOutput: model_sampling = model.get_model_object("model_sampling") sigma_val = model_sampling.percent_to_sigma(sampling_percent) if return_actual_sigma: @@ -327,212 +367,234 @@ class SamplingPercentToSigma: sigma_val = model_sampling.sigma_max.item() elif sampling_percent == 1.0: sigma_val = model_sampling.sigma_min.item() - return (sigma_val,) + return io.NodeOutput(sigma_val) + + get_sigma = execute -class KSamplerSelect: +class KSamplerSelect(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"sampler_name": (comfy.samplers.SAMPLER_NAMES, ), - } - } - RETURN_TYPES = ("SAMPLER",) - CATEGORY = "sampling/custom_sampling/samplers" + def define_schema(cls): + return io.Schema( + node_id="KSamplerSelect", + category="sampling/custom_sampling/samplers", + inputs=[io.Combo.Input("sampler_name", options=comfy.samplers.SAMPLER_NAMES)], + outputs=[io.Sampler.Output()] + ) - FUNCTION = "get_sampler" - - def get_sampler(self, sampler_name): + @classmethod + def execute(cls, sampler_name) -> io.NodeOutput: sampler = comfy.samplers.sampler_object(sampler_name) - return (sampler, ) + return io.NodeOutput(sampler) -class SamplerDPMPP_3M_SDE: + get_sampler = execute + +class SamplerDPMPP_3M_SDE(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "noise_device": (['gpu', 'cpu'], ), - } - } - RETURN_TYPES = ("SAMPLER",) - CATEGORY = "sampling/custom_sampling/samplers" + def define_schema(cls): + return io.Schema( + node_id="SamplerDPMPP_3M_SDE", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Combo.Input("noise_device", options=['gpu', 'cpu']), + ], + outputs=[io.Sampler.Output()] + ) - FUNCTION = "get_sampler" - - def get_sampler(self, eta, s_noise, noise_device): + @classmethod + def execute(cls, eta, s_noise, noise_device) -> io.NodeOutput: if noise_device == 'cpu': sampler_name = "dpmpp_3m_sde" else: sampler_name = "dpmpp_3m_sde_gpu" sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise}) - return (sampler, ) + return io.NodeOutput(sampler) -class SamplerDPMPP_2M_SDE: + get_sampler = execute + +class SamplerDPMPP_2M_SDE(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"solver_type": (['midpoint', 'heun'], ), - "eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "noise_device": (['gpu', 'cpu'], ), - } - } - RETURN_TYPES = ("SAMPLER",) - CATEGORY = "sampling/custom_sampling/samplers" + def define_schema(cls): + return io.Schema( + node_id="SamplerDPMPP_2M_SDE", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Combo.Input("solver_type", options=['midpoint', 'heun']), + io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Combo.Input("noise_device", options=['gpu', 'cpu']), + ], + outputs=[io.Sampler.Output()] + ) - FUNCTION = "get_sampler" - - def get_sampler(self, solver_type, eta, s_noise, noise_device): + @classmethod + def execute(cls, solver_type, eta, s_noise, noise_device) -> io.NodeOutput: if noise_device == 'cpu': sampler_name = "dpmpp_2m_sde" else: sampler_name = "dpmpp_2m_sde_gpu" sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "solver_type": solver_type}) - return (sampler, ) + return io.NodeOutput(sampler) + + get_sampler = execute -class SamplerDPMPP_SDE: +class SamplerDPMPP_SDE(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "r": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "noise_device": (['gpu', 'cpu'], ), - } - } - RETURN_TYPES = ("SAMPLER",) - CATEGORY = "sampling/custom_sampling/samplers" + def define_schema(cls): + return io.Schema( + node_id="SamplerDPMPP_SDE", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("r", default=0.5, min=0.0, max=100.0, step=0.01, round=False), + io.Combo.Input("noise_device", options=['gpu', 'cpu']), + ], + outputs=[io.Sampler.Output()] + ) - FUNCTION = "get_sampler" - - def get_sampler(self, eta, s_noise, r, noise_device): + @classmethod + def execute(cls, eta, s_noise, r, noise_device) -> io.NodeOutput: if noise_device == 'cpu': sampler_name = "dpmpp_sde" else: sampler_name = "dpmpp_sde_gpu" sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r}) - return (sampler, ) + return io.NodeOutput(sampler) -class SamplerDPMPP_2S_Ancestral: + get_sampler = execute + +class SamplerDPMPP_2S_Ancestral(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - } - } - RETURN_TYPES = ("SAMPLER",) - CATEGORY = "sampling/custom_sampling/samplers" + def define_schema(cls): + return io.Schema( + node_id="SamplerDPMPP_2S_Ancestral", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + ], + outputs=[io.Sampler.Output()] + ) - FUNCTION = "get_sampler" - - def get_sampler(self, eta, s_noise): + @classmethod + def execute(cls, eta, s_noise) -> io.NodeOutput: sampler = comfy.samplers.ksampler("dpmpp_2s_ancestral", {"eta": eta, "s_noise": s_noise}) - return (sampler, ) + return io.NodeOutput(sampler) -class SamplerEulerAncestral: + get_sampler = execute + +class SamplerEulerAncestral(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - } - } - RETURN_TYPES = ("SAMPLER",) - CATEGORY = "sampling/custom_sampling/samplers" + def define_schema(cls): + return io.Schema( + node_id="SamplerEulerAncestral", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + ], + outputs=[io.Sampler.Output()] + ) - FUNCTION = "get_sampler" - - def get_sampler(self, eta, s_noise): + @classmethod + def execute(cls, eta, s_noise) -> io.NodeOutput: sampler = comfy.samplers.ksampler("euler_ancestral", {"eta": eta, "s_noise": s_noise}) - return (sampler, ) + return io.NodeOutput(sampler) -class SamplerEulerAncestralCFGPP: + get_sampler = execute + +class SamplerEulerAncestralCFGPP(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step":0.01, "round": False}), - "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step":0.01, "round": False}), - }} - RETURN_TYPES = ("SAMPLER",) - CATEGORY = "sampling/custom_sampling/samplers" + def define_schema(cls): + return io.Schema( + node_id="SamplerEulerAncestralCFGPP", + display_name="SamplerEulerAncestralCFG++", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Float.Input("eta", default=1.0, min=0.0, max=1.0, step=0.01, round=False), + io.Float.Input("s_noise", default=1.0, min=0.0, max=10.0, step=0.01, round=False), + ], + outputs=[io.Sampler.Output()] + ) - FUNCTION = "get_sampler" - - def get_sampler(self, eta, s_noise): + @classmethod + def execute(cls, eta, s_noise) -> io.NodeOutput: sampler = comfy.samplers.ksampler( "euler_ancestral_cfg_pp", {"eta": eta, "s_noise": s_noise}) - return (sampler, ) + return io.NodeOutput(sampler) -class SamplerLMS: + get_sampler = execute + +class SamplerLMS(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"order": ("INT", {"default": 4, "min": 1, "max": 100}), - } - } - RETURN_TYPES = ("SAMPLER",) - CATEGORY = "sampling/custom_sampling/samplers" + def define_schema(cls): + return io.Schema( + node_id="SamplerLMS", + category="sampling/custom_sampling/samplers", + inputs=[io.Int.Input("order", default=4, min=1, max=100)], + outputs=[io.Sampler.Output()] + ) - FUNCTION = "get_sampler" - - def get_sampler(self, order): + @classmethod + def execute(cls, order) -> io.NodeOutput: sampler = comfy.samplers.ksampler("lms", {"order": order}) - return (sampler, ) + return io.NodeOutput(sampler) -class SamplerDPMAdaptative: + get_sampler = execute + +class SamplerDPMAdaptative(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"order": ("INT", {"default": 3, "min": 2, "max": 3}), - "rtol": ("FLOAT", {"default": 0.05, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "atol": ("FLOAT", {"default": 0.0078, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "h_init": ("FLOAT", {"default": 0.05, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "pcoeff": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "icoeff": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "dcoeff": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "accept_safety": ("FLOAT", {"default": 0.81, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "eta": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - } - } - RETURN_TYPES = ("SAMPLER",) - CATEGORY = "sampling/custom_sampling/samplers" + def define_schema(cls): + return io.Schema( + node_id="SamplerDPMAdaptative", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Int.Input("order", default=3, min=2, max=3), + io.Float.Input("rtol", default=0.05, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("atol", default=0.0078, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("h_init", default=0.05, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("pcoeff", default=0.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("icoeff", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("dcoeff", default=0.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("accept_safety", default=0.81, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("eta", default=0.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + ], + outputs=[io.Sampler.Output()] + ) - FUNCTION = "get_sampler" - - def get_sampler(self, order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_safety, eta, s_noise): + @classmethod + def execute(cls, order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_safety, eta, s_noise) -> io.NodeOutput: sampler = comfy.samplers.ksampler("dpm_adaptive", {"order": order, "rtol": rtol, "atol": atol, "h_init": h_init, "pcoeff": pcoeff, "icoeff": icoeff, "dcoeff": dcoeff, "accept_safety": accept_safety, "eta": eta, "s_noise":s_noise }) - return (sampler, ) + return io.NodeOutput(sampler) + + get_sampler = execute -class SamplerER_SDE(ComfyNodeABC): +class SamplerER_SDE(io.ComfyNode): @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: - return { - "required": { - "solver_type": (IO.COMBO, {"options": ["ER-SDE", "Reverse-time SDE", "ODE"]}), - "max_stage": (IO.INT, {"default": 3, "min": 1, "max": 3}), - "eta": ( - IO.FLOAT, - {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": False, "tooltip": "Stochastic strength of reverse-time SDE.\nWhen eta=0, it reduces to deterministic ODE. This setting doesn't apply to ER-SDE solver type."}, - ), - "s_noise": (IO.FLOAT, {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": False}), - } - } + def define_schema(cls): + return io.Schema( + node_id="SamplerER_SDE", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Combo.Input("solver_type", options=["ER-SDE", "Reverse-time SDE", "ODE"]), + io.Int.Input("max_stage", default=3, min=1, max=3), + io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="Stochastic strength of reverse-time SDE.\nWhen eta=0, it reduces to deterministic ODE. This setting doesn't apply to ER-SDE solver type."), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + ], + outputs=[io.Sampler.Output()] + ) - RETURN_TYPES = (IO.SAMPLER,) - CATEGORY = "sampling/custom_sampling/samplers" - - FUNCTION = "get_sampler" - - def get_sampler(self, solver_type, max_stage, eta, s_noise): + @classmethod + def execute(cls, solver_type, max_stage, eta, s_noise) -> io.NodeOutput: if solver_type == "ODE" or (solver_type == "Reverse-time SDE" and eta == 0): eta = 0 s_noise = 0 @@ -548,32 +610,33 @@ class SamplerER_SDE(ComfyNodeABC): sampler_name = "er_sde" sampler = comfy.samplers.ksampler(sampler_name, {"s_noise": s_noise, "noise_scaler": noise_scaler, "max_stage": max_stage}) - return (sampler,) + return io.NodeOutput(sampler) + + get_sampler = execute -class SamplerSASolver(ComfyNodeABC): +class SamplerSASolver(io.ComfyNode): @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: - return { - "required": { - "model": (IO.MODEL, {}), - "eta": (IO.FLOAT, {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01, "round": False},), - "sde_start_percent": (IO.FLOAT, {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.001},), - "sde_end_percent": (IO.FLOAT, {"default": 0.8, "min": 0.0, "max": 1.0, "step": 0.001},), - "s_noise": (IO.FLOAT, {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": False},), - "predictor_order": (IO.INT, {"default": 3, "min": 1, "max": 6}), - "corrector_order": (IO.INT, {"default": 4, "min": 0, "max": 6}), - "use_pece": (IO.BOOLEAN, {}), - "simple_order_2": (IO.BOOLEAN, {}), - } - } + def define_schema(cls): + return io.Schema( + node_id="SamplerSASolver", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Model.Input("model"), + io.Float.Input("eta", default=1.0, min=0.0, max=10.0, step=0.01, round=False), + io.Float.Input("sde_start_percent", default=0.2, min=0.0, max=1.0, step=0.001), + io.Float.Input("sde_end_percent", default=0.8, min=0.0, max=1.0, step=0.001), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Int.Input("predictor_order", default=3, min=1, max=6), + io.Int.Input("corrector_order", default=4, min=0, max=6), + io.Boolean.Input("use_pece"), + io.Boolean.Input("simple_order_2"), + ], + outputs=[io.Sampler.Output()] + ) - RETURN_TYPES = (IO.SAMPLER,) - CATEGORY = "sampling/custom_sampling/samplers" - - FUNCTION = "get_sampler" - - def get_sampler(self, model, eta, sde_start_percent, sde_end_percent, s_noise, predictor_order, corrector_order, use_pece, simple_order_2): + @classmethod + def execute(cls, model, eta, sde_start_percent, sde_end_percent, s_noise, predictor_order, corrector_order, use_pece, simple_order_2) -> io.NodeOutput: model_sampling = model.get_model_object("model_sampling") start_sigma = model_sampling.percent_to_sigma(sde_start_percent) end_sigma = model_sampling.percent_to_sigma(sde_end_percent) @@ -591,7 +654,9 @@ class SamplerSASolver(ComfyNodeABC): "simple_order_2": simple_order_2, }, ) - return (sampler,) + return io.NodeOutput(sampler) + + get_sampler = execute class Noise_EmptyNoise: @@ -612,30 +677,31 @@ class Noise_RandomNoise: batch_inds = input_latent["batch_index"] if "batch_index" in input_latent else None return comfy.sample.prepare_noise(latent_image, self.seed, batch_inds) -class SamplerCustom: +class SamplerCustom(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"model": ("MODEL",), - "add_noise": ("BOOLEAN", {"default": True}), - "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "control_after_generate": True}), - "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}), - "positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "sampler": ("SAMPLER", ), - "sigmas": ("SIGMAS", ), - "latent_image": ("LATENT", ), - } - } + def define_schema(cls): + return io.Schema( + node_id="SamplerCustom", + category="sampling/custom_sampling", + inputs=[ + io.Model.Input("model"), + io.Boolean.Input("add_noise", default=True), + io.Int.Input("noise_seed", default=0, min=0, max=0xffffffffffffffff, control_after_generate=True), + io.Float.Input("cfg", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01), + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Sampler.Input("sampler"), + io.Sigmas.Input("sigmas"), + io.Latent.Input("latent_image"), + ], + outputs=[ + io.Latent.Output(display_name="output"), + io.Latent.Output(display_name="denoised_output"), + ] + ) - RETURN_TYPES = ("LATENT","LATENT") - RETURN_NAMES = ("output", "denoised_output") - - FUNCTION = "sample" - - CATEGORY = "sampling/custom_sampling" - - def sample(self, model, add_noise, noise_seed, cfg, positive, negative, sampler, sigmas, latent_image): + @classmethod + def execute(cls, model, add_noise, noise_seed, cfg, positive, negative, sampler, sigmas, latent_image) -> io.NodeOutput: latent = latent_image latent_image = latent["samples"] latent = latent.copy() @@ -664,52 +730,58 @@ class SamplerCustom: out_denoised["samples"] = model.model.process_latent_out(x0_output["x0"].cpu()) else: out_denoised = out - return (out, out_denoised) + return io.NodeOutput(out, out_denoised) + + sample = execute class Guider_Basic(comfy.samplers.CFGGuider): def set_conds(self, positive): self.inner_set_conds({"positive": positive}) -class BasicGuider: +class BasicGuider(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"model": ("MODEL",), - "conditioning": ("CONDITIONING", ), - } - } + def define_schema(cls): + return io.Schema( + node_id="BasicGuider", + category="sampling/custom_sampling/guiders", + inputs=[ + io.Model.Input("model"), + io.Conditioning.Input("conditioning"), + ], + outputs=[io.Guider.Output()] + ) - RETURN_TYPES = ("GUIDER",) - - FUNCTION = "get_guider" - CATEGORY = "sampling/custom_sampling/guiders" - - def get_guider(self, model, conditioning): + @classmethod + def execute(cls, model, conditioning) -> io.NodeOutput: guider = Guider_Basic(model) guider.set_conds(conditioning) - return (guider,) + return io.NodeOutput(guider) -class CFGGuider: + get_guider = execute + +class CFGGuider(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"model": ("MODEL",), - "positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}), - } - } + def define_schema(cls): + return io.Schema( + node_id="CFGGuider", + category="sampling/custom_sampling/guiders", + inputs=[ + io.Model.Input("model"), + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Float.Input("cfg", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01), + ], + outputs=[io.Guider.Output()] + ) - RETURN_TYPES = ("GUIDER",) - - FUNCTION = "get_guider" - CATEGORY = "sampling/custom_sampling/guiders" - - def get_guider(self, model, positive, negative, cfg): + @classmethod + def execute(cls, model, positive, negative, cfg) -> io.NodeOutput: guider = comfy.samplers.CFGGuider(model) guider.set_conds(positive, negative) guider.set_cfg(cfg) - return (guider,) + return io.NodeOutput(guider) + + get_guider = execute class Guider_DualCFG(comfy.samplers.CFGGuider): def set_cfg(self, cfg1, cfg2, nested=False): @@ -740,84 +812,88 @@ class Guider_DualCFG(comfy.samplers.CFGGuider): out = comfy.samplers.calc_cond_batch(self.inner_model, [negative_cond, middle_cond, positive_cond], x, timestep, model_options) return comfy.samplers.cfg_function(self.inner_model, out[1], out[0], self.cfg2, x, timestep, model_options=model_options, cond=middle_cond, uncond=negative_cond) + (out[2] - out[1]) * self.cfg1 -class DualCFGGuider: +class DualCFGGuider(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"model": ("MODEL",), - "cond1": ("CONDITIONING", ), - "cond2": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "cfg_conds": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}), - "cfg_cond2_negative": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}), - "style": (["regular", "nested"],), - } - } + def define_schema(cls): + return io.Schema( + node_id="DualCFGGuider", + category="sampling/custom_sampling/guiders", + inputs=[ + io.Model.Input("model"), + io.Conditioning.Input("cond1"), + io.Conditioning.Input("cond2"), + io.Conditioning.Input("negative"), + io.Float.Input("cfg_conds", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01), + io.Float.Input("cfg_cond2_negative", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01), + io.Combo.Input("style", options=["regular", "nested"]), + ], + outputs=[io.Guider.Output()] + ) - RETURN_TYPES = ("GUIDER",) - - FUNCTION = "get_guider" - CATEGORY = "sampling/custom_sampling/guiders" - - def get_guider(self, model, cond1, cond2, negative, cfg_conds, cfg_cond2_negative, style): + @classmethod + def execute(cls, model, cond1, cond2, negative, cfg_conds, cfg_cond2_negative, style) -> io.NodeOutput: guider = Guider_DualCFG(model) guider.set_conds(cond1, cond2, negative) guider.set_cfg(cfg_conds, cfg_cond2_negative, nested=(style == "nested")) - return (guider,) + return io.NodeOutput(guider) -class DisableNoise: + get_guider = execute + +class DisableNoise(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required":{ - } - } + def define_schema(cls): + return io.Schema( + node_id="DisableNoise", + category="sampling/custom_sampling/noise", + inputs=[], + outputs=[io.Noise.Output()] + ) - RETURN_TYPES = ("NOISE",) - FUNCTION = "get_noise" - CATEGORY = "sampling/custom_sampling/noise" - - def get_noise(self): - return (Noise_EmptyNoise(),) - - -class RandomNoise(DisableNoise): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "noise_seed": ("INT", { - "default": 0, - "min": 0, - "max": 0xffffffffffffffff, - "control_after_generate": True, - }), - } - } + def execute(cls) -> io.NodeOutput: + return io.NodeOutput(Noise_EmptyNoise()) - def get_noise(self, noise_seed): - return (Noise_RandomNoise(noise_seed),) + get_noise = execute -class SamplerCustomAdvanced: +class RandomNoise(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"noise": ("NOISE", ), - "guider": ("GUIDER", ), - "sampler": ("SAMPLER", ), - "sigmas": ("SIGMAS", ), - "latent_image": ("LATENT", ), - } - } + def define_schema(cls): + return io.Schema( + node_id="RandomNoise", + category="sampling/custom_sampling/noise", + inputs=[io.Int.Input("noise_seed", default=0, min=0, max=0xffffffffffffffff, control_after_generate=True)], + outputs=[io.Noise.Output()] + ) - RETURN_TYPES = ("LATENT","LATENT") - RETURN_NAMES = ("output", "denoised_output") + @classmethod + def execute(cls, noise_seed) -> io.NodeOutput: + return io.NodeOutput(Noise_RandomNoise(noise_seed)) - FUNCTION = "sample" + get_noise = execute - CATEGORY = "sampling/custom_sampling" - def sample(self, noise, guider, sampler, sigmas, latent_image): +class SamplerCustomAdvanced(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SamplerCustomAdvanced", + category="sampling/custom_sampling", + inputs=[ + io.Noise.Input("noise"), + io.Guider.Input("guider"), + io.Sampler.Input("sampler"), + io.Sigmas.Input("sigmas"), + io.Latent.Input("latent_image"), + ], + outputs=[ + io.Latent.Output(display_name="output"), + io.Latent.Output(display_name="denoised_output"), + ] + ) + + @classmethod + def execute(cls, noise, guider, sampler, sigmas, latent_image) -> io.NodeOutput: latent = latent_image latent_image = latent["samples"] latent = latent.copy() @@ -842,29 +918,32 @@ class SamplerCustomAdvanced: out_denoised["samples"] = guider.model_patcher.model.process_latent_out(x0_output["x0"].cpu()) else: out_denoised = out - return (out, out_denoised) + return io.NodeOutput(out, out_denoised) + sample = execute -class AddNoise: +class AddNoise(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"model": ("MODEL",), - "noise": ("NOISE", ), - "sigmas": ("SIGMAS", ), - "latent_image": ("LATENT", ), - } - } + def define_schema(cls): + return io.Schema( + node_id="AddNoise", + category="_for_testing/custom_sampling/noise", + is_experimental=True, + inputs=[ + io.Model.Input("model"), + io.Noise.Input("noise"), + io.Sigmas.Input("sigmas"), + io.Latent.Input("latent_image"), + ], + outputs=[ + io.Latent.Output(), + ] + ) - RETURN_TYPES = ("LATENT",) - - FUNCTION = "add_noise" - - CATEGORY = "_for_testing/custom_sampling/noise" - - def add_noise(self, model, noise, sigmas, latent_image): + @classmethod + def execute(cls, model, noise, sigmas, latent_image) -> io.NodeOutput: if len(sigmas) == 0: - return latent_image + return io.NodeOutput(latent_image) latent = latent_image latent_image = latent["samples"] @@ -888,46 +967,50 @@ class AddNoise: out = latent.copy() out["samples"] = noisy - return (out,) + return io.NodeOutput(out) + + add_noise = execute -NODE_CLASS_MAPPINGS = { - "SamplerCustom": SamplerCustom, - "BasicScheduler": BasicScheduler, - "KarrasScheduler": KarrasScheduler, - "ExponentialScheduler": ExponentialScheduler, - "PolyexponentialScheduler": PolyexponentialScheduler, - "LaplaceScheduler": LaplaceScheduler, - "VPScheduler": VPScheduler, - "BetaSamplingScheduler": BetaSamplingScheduler, - "SDTurboScheduler": SDTurboScheduler, - "KSamplerSelect": KSamplerSelect, - "SamplerEulerAncestral": SamplerEulerAncestral, - "SamplerEulerAncestralCFGPP": SamplerEulerAncestralCFGPP, - "SamplerLMS": SamplerLMS, - "SamplerDPMPP_3M_SDE": SamplerDPMPP_3M_SDE, - "SamplerDPMPP_2M_SDE": SamplerDPMPP_2M_SDE, - "SamplerDPMPP_SDE": SamplerDPMPP_SDE, - "SamplerDPMPP_2S_Ancestral": SamplerDPMPP_2S_Ancestral, - "SamplerDPMAdaptative": SamplerDPMAdaptative, - "SamplerER_SDE": SamplerER_SDE, - "SamplerSASolver": SamplerSASolver, - "SplitSigmas": SplitSigmas, - "SplitSigmasDenoise": SplitSigmasDenoise, - "FlipSigmas": FlipSigmas, - "SetFirstSigma": SetFirstSigma, - "ExtendIntermediateSigmas": ExtendIntermediateSigmas, - "SamplingPercentToSigma": SamplingPercentToSigma, +class CustomSamplersExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + SamplerCustom, + BasicScheduler, + KarrasScheduler, + ExponentialScheduler, + PolyexponentialScheduler, + LaplaceScheduler, + VPScheduler, + BetaSamplingScheduler, + SDTurboScheduler, + KSamplerSelect, + SamplerEulerAncestral, + SamplerEulerAncestralCFGPP, + SamplerLMS, + SamplerDPMPP_3M_SDE, + SamplerDPMPP_2M_SDE, + SamplerDPMPP_SDE, + SamplerDPMPP_2S_Ancestral, + SamplerDPMAdaptative, + SamplerER_SDE, + SamplerSASolver, + SplitSigmas, + SplitSigmasDenoise, + FlipSigmas, + SetFirstSigma, + ExtendIntermediateSigmas, + SamplingPercentToSigma, + CFGGuider, + DualCFGGuider, + BasicGuider, + RandomNoise, + DisableNoise, + AddNoise, + SamplerCustomAdvanced, + ] - "CFGGuider": CFGGuider, - "DualCFGGuider": DualCFGGuider, - "BasicGuider": BasicGuider, - "RandomNoise": RandomNoise, - "DisableNoise": DisableNoise, - "AddNoise": AddNoise, - "SamplerCustomAdvanced": SamplerCustomAdvanced, -} -NODE_DISPLAY_NAME_MAPPINGS = { - "SamplerEulerAncestralCFGPP": "SamplerEulerAncestralCFG++", -} +async def comfy_entrypoint() -> CustomSamplersExtension: + return CustomSamplersExtension() diff --git a/comfy_extras/nodes_dataset.py b/comfy_extras/nodes_dataset.py new file mode 100644 index 000000000..4789d7d53 --- /dev/null +++ b/comfy_extras/nodes_dataset.py @@ -0,0 +1,1432 @@ +import logging +import os +import json + +import numpy as np +import torch +from PIL import Image +from typing_extensions import override + +import folder_paths +import node_helpers +from comfy_api.latest import ComfyExtension, io + + +def load_and_process_images(image_files, input_dir): + """Utility function to load and process a list of images. + + Args: + image_files: List of image filenames + input_dir: Base directory containing the images + resize_method: How to handle images of different sizes ("None", "Stretch", "Crop", "Pad") + + Returns: + torch.Tensor: Batch of processed images + """ + if not image_files: + raise ValueError("No valid images found in input") + + output_images = [] + + for file in image_files: + image_path = os.path.join(input_dir, file) + img = node_helpers.pillow(Image.open, image_path) + + if img.mode == "I": + img = img.point(lambda i: i * (1 / 255)) + img = img.convert("RGB") + img_array = np.array(img).astype(np.float32) / 255.0 + img_tensor = torch.from_numpy(img_array)[None,] + output_images.append(img_tensor) + + return output_images + + +class LoadImageDataSetFromFolderNode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LoadImageDataSetFromFolder", + display_name="Load Image Dataset from Folder", + category="dataset", + is_experimental=True, + inputs=[ + io.Combo.Input( + "folder", + options=folder_paths.get_input_subfolders(), + tooltip="The folder to load images from.", + ) + ], + outputs=[ + io.Image.Output( + display_name="images", + is_output_list=True, + tooltip="List of loaded images", + ) + ], + ) + + @classmethod + def execute(cls, folder): + sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder) + valid_extensions = [".png", ".jpg", ".jpeg", ".webp"] + image_files = [ + f + for f in os.listdir(sub_input_dir) + if any(f.lower().endswith(ext) for ext in valid_extensions) + ] + output_tensor = load_and_process_images(image_files, sub_input_dir) + return io.NodeOutput(output_tensor) + + +class LoadImageTextDataSetFromFolderNode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LoadImageTextDataSetFromFolder", + display_name="Load Image and Text Dataset from Folder", + category="dataset", + is_experimental=True, + inputs=[ + io.Combo.Input( + "folder", + options=folder_paths.get_input_subfolders(), + tooltip="The folder to load images from.", + ) + ], + outputs=[ + io.Image.Output( + display_name="images", + is_output_list=True, + tooltip="List of loaded images", + ), + io.String.Output( + display_name="texts", + is_output_list=True, + tooltip="List of text captions", + ), + ], + ) + + @classmethod + def execute(cls, folder): + logging.info(f"Loading images from folder: {folder}") + + sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder) + valid_extensions = [".png", ".jpg", ".jpeg", ".webp"] + + image_files = [] + for item in os.listdir(sub_input_dir): + path = os.path.join(sub_input_dir, item) + if any(item.lower().endswith(ext) for ext in valid_extensions): + image_files.append(path) + elif os.path.isdir(path): + # Support kohya-ss/sd-scripts folder structure + repeat = 1 + if item.split("_")[0].isdigit(): + repeat = int(item.split("_")[0]) + image_files.extend( + [ + os.path.join(path, f) + for f in os.listdir(path) + if any(f.lower().endswith(ext) for ext in valid_extensions) + ] + * repeat + ) + + caption_file_path = [ + f.replace(os.path.splitext(f)[1], ".txt") for f in image_files + ] + captions = [] + for caption_file in caption_file_path: + caption_path = os.path.join(sub_input_dir, caption_file) + if os.path.exists(caption_path): + with open(caption_path, "r", encoding="utf-8") as f: + caption = f.read().strip() + captions.append(caption) + else: + captions.append("") + + output_tensor = load_and_process_images(image_files, sub_input_dir) + + logging.info(f"Loaded {len(output_tensor)} images from {sub_input_dir}.") + return io.NodeOutput(output_tensor, captions) + + +def save_images_to_folder(image_list, output_dir, prefix="image"): + """Utility function to save a list of image tensors to disk. + + Args: + image_list: List of image tensors (each [1, H, W, C] or [H, W, C] or [C, H, W]) + output_dir: Directory to save images to + prefix: Filename prefix + + Returns: + List of saved filenames + """ + os.makedirs(output_dir, exist_ok=True) + saved_files = [] + + for idx, img_tensor in enumerate(image_list): + # Handle different tensor shapes + if isinstance(img_tensor, torch.Tensor): + # Remove batch dimension if present [1, H, W, C] -> [H, W, C] + if img_tensor.dim() == 4 and img_tensor.shape[0] == 1: + img_tensor = img_tensor.squeeze(0) + + # If tensor is [C, H, W], permute to [H, W, C] + if img_tensor.dim() == 3 and img_tensor.shape[0] in [1, 3, 4]: + if ( + img_tensor.shape[0] <= 4 + and img_tensor.shape[1] > 4 + and img_tensor.shape[2] > 4 + ): + img_tensor = img_tensor.permute(1, 2, 0) + + # Convert to numpy and scale to 0-255 + img_array = img_tensor.cpu().numpy() + img_array = np.clip(img_array * 255.0, 0, 255).astype(np.uint8) + + # Convert to PIL Image + img = Image.fromarray(img_array) + else: + raise ValueError(f"Expected torch.Tensor, got {type(img_tensor)}") + + # Save image + filename = f"{prefix}_{idx:05d}.png" + filepath = os.path.join(output_dir, filename) + img.save(filepath) + saved_files.append(filename) + + return saved_files + + +class SaveImageDataSetToFolderNode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SaveImageDataSetToFolder", + display_name="Save Image Dataset to Folder", + category="dataset", + is_experimental=True, + is_output_node=True, + is_input_list=True, # Receive images as list + inputs=[ + io.Image.Input("images", tooltip="List of images to save."), + io.String.Input( + "folder_name", + default="dataset", + tooltip="Name of the folder to save images to (inside output directory).", + ), + io.String.Input( + "filename_prefix", + default="image", + tooltip="Prefix for saved image filenames.", + ), + ], + outputs=[], + ) + + @classmethod + def execute(cls, images, folder_name, filename_prefix): + # Extract scalar values + folder_name = folder_name[0] + filename_prefix = filename_prefix[0] + + output_dir = os.path.join(folder_paths.get_output_directory(), folder_name) + saved_files = save_images_to_folder(images, output_dir, filename_prefix) + + logging.info(f"Saved {len(saved_files)} images to {output_dir}.") + return io.NodeOutput() + + +class SaveImageTextDataSetToFolderNode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SaveImageTextDataSetToFolder", + display_name="Save Image and Text Dataset to Folder", + category="dataset", + is_experimental=True, + is_output_node=True, + is_input_list=True, # Receive both images and texts as lists + inputs=[ + io.Image.Input("images", tooltip="List of images to save."), + io.String.Input("texts", tooltip="List of text captions to save."), + io.String.Input( + "folder_name", + default="dataset", + tooltip="Name of the folder to save images to (inside output directory).", + ), + io.String.Input( + "filename_prefix", + default="image", + tooltip="Prefix for saved image filenames.", + ), + ], + outputs=[], + ) + + @classmethod + def execute(cls, images, texts, folder_name, filename_prefix): + # Extract scalar values + folder_name = folder_name[0] + filename_prefix = filename_prefix[0] + + output_dir = os.path.join(folder_paths.get_output_directory(), folder_name) + saved_files = save_images_to_folder(images, output_dir, filename_prefix) + + # Save captions + for idx, (filename, caption) in enumerate(zip(saved_files, texts)): + caption_filename = filename.replace(".png", ".txt") + caption_path = os.path.join(output_dir, caption_filename) + with open(caption_path, "w", encoding="utf-8") as f: + f.write(caption) + + logging.info(f"Saved {len(saved_files)} images and captions to {output_dir}.") + return io.NodeOutput() + + +# ========== Helper Functions for Transform Nodes ========== + + +def tensor_to_pil(img_tensor): + """Convert tensor to PIL Image.""" + if img_tensor.dim() == 4 and img_tensor.shape[0] == 1: + img_tensor = img_tensor.squeeze(0) + img_array = (img_tensor.cpu().numpy() * 255).clip(0, 255).astype(np.uint8) + return Image.fromarray(img_array) + + +def pil_to_tensor(img): + """Convert PIL Image to tensor.""" + img_array = np.array(img).astype(np.float32) / 255.0 + return torch.from_numpy(img_array)[None,] + + +# ========== Base Classes for Transform Nodes ========== + + +class ImageProcessingNode(io.ComfyNode): + """Base class for image processing nodes that operate on images. + + Child classes should set: + node_id: Unique node identifier (required) + display_name: Display name (optional, defaults to node_id) + description: Node description (optional) + extra_inputs: List of additional io.Input objects beyond "images" (optional) + is_group_process: None (auto-detect), True (group), or False (individual) (optional) + is_output_list: True (list output) or False (single output) (optional, default True) + + Child classes must implement ONE of: + _process(cls, image, **kwargs) -> tensor (for single-item processing) + _group_process(cls, images, **kwargs) -> list[tensor] (for group processing) + """ + + node_id = None + display_name = None + description = None + extra_inputs = [] + is_group_process = None # None = auto-detect, True/False = explicit + is_output_list = None # None = auto-detect based on processing mode + + @classmethod + def _detect_processing_mode(cls): + """Detect whether this node uses group or individual processing. + + Returns: + bool: True if group processing, False if individual processing + """ + # Explicit setting takes precedence + if cls.is_group_process is not None: + return cls.is_group_process + + # Check which method is overridden by looking at the defining class in MRO + base_class = ImageProcessingNode + + # Find which class in MRO defines _process + process_definer = None + for klass in cls.__mro__: + if "_process" in klass.__dict__: + process_definer = klass + break + + # Find which class in MRO defines _group_process + group_definer = None + for klass in cls.__mro__: + if "_group_process" in klass.__dict__: + group_definer = klass + break + + # Check what was overridden (not defined in base class) + has_process = process_definer is not None and process_definer is not base_class + has_group = group_definer is not None and group_definer is not base_class + + if has_process and has_group: + raise ValueError( + f"{cls.__name__}: Cannot override both _process and _group_process. " + "Override only one, or set is_group_process explicitly." + ) + if not has_process and not has_group: + raise ValueError( + f"{cls.__name__}: Must override either _process or _group_process" + ) + + return has_group + + @classmethod + def define_schema(cls): + if cls.node_id is None: + raise NotImplementedError(f"{cls.__name__} must set node_id class variable") + + is_group = cls._detect_processing_mode() + + # Auto-detect is_output_list if not explicitly set + # Single processing: False (backend collects results into list) + # Group processing: True by default (can be False for single-output nodes) + output_is_list = ( + cls.is_output_list if cls.is_output_list is not None else is_group + ) + + inputs = [ + io.Image.Input( + "images", + tooltip=( + "List of images to process." if is_group else "Image to process." + ), + ) + ] + inputs.extend(cls.extra_inputs) + + return io.Schema( + node_id=cls.node_id, + display_name=cls.display_name or cls.node_id, + category="dataset/image", + is_experimental=True, + is_input_list=is_group, # True for group, False for individual + inputs=inputs, + outputs=[ + io.Image.Output( + display_name="images", + is_output_list=output_is_list, + tooltip="Processed images", + ) + ], + ) + + @classmethod + def execute(cls, images, **kwargs): + """Execute the node. Routes to _process or _group_process based on mode.""" + is_group = cls._detect_processing_mode() + + # Extract scalar values from lists for parameters + params = {} + for k, v in kwargs.items(): + if isinstance(v, list) and len(v) == 1: + params[k] = v[0] + else: + params[k] = v + + if is_group: + # Group processing: images is list, call _group_process + result = cls._group_process(images, **params) + else: + # Individual processing: images is single item, call _process + result = cls._process(images, **params) + + return io.NodeOutput(result) + + @classmethod + def _process(cls, image, **kwargs): + """Override this method for single-item processing. + + Args: + image: tensor - Single image tensor + **kwargs: Additional parameters (already extracted from lists) + + Returns: + tensor - Processed image + """ + raise NotImplementedError(f"{cls.__name__} must implement _process method") + + @classmethod + def _group_process(cls, images, **kwargs): + """Override this method for group processing. + + Args: + images: list[tensor] - List of image tensors + **kwargs: Additional parameters (already extracted from lists) + + Returns: + list[tensor] - Processed images + """ + raise NotImplementedError( + f"{cls.__name__} must implement _group_process method" + ) + + +class TextProcessingNode(io.ComfyNode): + """Base class for text processing nodes that operate on texts. + + Child classes should set: + node_id: Unique node identifier (required) + display_name: Display name (optional, defaults to node_id) + description: Node description (optional) + extra_inputs: List of additional io.Input objects beyond "texts" (optional) + is_group_process: None (auto-detect), True (group), or False (individual) (optional) + is_output_list: True (list output) or False (single output) (optional, default True) + + Child classes must implement ONE of: + _process(cls, text, **kwargs) -> str (for single-item processing) + _group_process(cls, texts, **kwargs) -> list[str] (for group processing) + """ + + node_id = None + display_name = None + description = None + extra_inputs = [] + is_group_process = None # None = auto-detect, True/False = explicit + is_output_list = None # None = auto-detect based on processing mode + + @classmethod + def _detect_processing_mode(cls): + """Detect whether this node uses group or individual processing. + + Returns: + bool: True if group processing, False if individual processing + """ + # Explicit setting takes precedence + if cls.is_group_process is not None: + return cls.is_group_process + + # Check which method is overridden by looking at the defining class in MRO + base_class = TextProcessingNode + + # Find which class in MRO defines _process + process_definer = None + for klass in cls.__mro__: + if "_process" in klass.__dict__: + process_definer = klass + break + + # Find which class in MRO defines _group_process + group_definer = None + for klass in cls.__mro__: + if "_group_process" in klass.__dict__: + group_definer = klass + break + + # Check what was overridden (not defined in base class) + has_process = process_definer is not None and process_definer is not base_class + has_group = group_definer is not None and group_definer is not base_class + + if has_process and has_group: + raise ValueError( + f"{cls.__name__}: Cannot override both _process and _group_process. " + "Override only one, or set is_group_process explicitly." + ) + if not has_process and not has_group: + raise ValueError( + f"{cls.__name__}: Must override either _process or _group_process" + ) + + return has_group + + @classmethod + def define_schema(cls): + if cls.node_id is None: + raise NotImplementedError(f"{cls.__name__} must set node_id class variable") + + is_group = cls._detect_processing_mode() + + inputs = [ + io.String.Input( + "texts", + tooltip="List of texts to process." if is_group else "Text to process.", + ) + ] + inputs.extend(cls.extra_inputs) + + return io.Schema( + node_id=cls.node_id, + display_name=cls.display_name or cls.node_id, + category="dataset/text", + is_experimental=True, + is_input_list=is_group, # True for group, False for individual + inputs=inputs, + outputs=[ + io.String.Output( + display_name="texts", + is_output_list=cls.is_output_list, + tooltip="Processed texts", + ) + ], + ) + + @classmethod + def execute(cls, texts, **kwargs): + """Execute the node. Routes to _process or _group_process based on mode.""" + is_group = cls._detect_processing_mode() + + # Extract scalar values from lists for parameters + params = {} + for k, v in kwargs.items(): + if isinstance(v, list) and len(v) == 1: + params[k] = v[0] + else: + params[k] = v + + if is_group: + # Group processing: texts is list, call _group_process + result = cls._group_process(texts, **params) + else: + # Individual processing: texts is single item, call _process + result = cls._process(texts, **params) + + # Wrap result based on is_output_list + if cls.is_output_list: + # Result should already be a list (or will be for individual) + return io.NodeOutput(result if is_group else [result]) + else: + # Single output - wrap in list for NodeOutput + return io.NodeOutput([result]) + + @classmethod + def _process(cls, text, **kwargs): + """Override this method for single-item processing. + + Args: + text: str - Single text string + **kwargs: Additional parameters (already extracted from lists) + + Returns: + str - Processed text + """ + raise NotImplementedError(f"{cls.__name__} must implement _process method") + + @classmethod + def _group_process(cls, texts, **kwargs): + """Override this method for group processing. + + Args: + texts: list[str] - List of text strings + **kwargs: Additional parameters (already extracted from lists) + + Returns: + list[str] - Processed texts + """ + raise NotImplementedError( + f"{cls.__name__} must implement _group_process method" + ) + + +# ========== Image Transform Nodes ========== + + +class ResizeImagesByShorterEdgeNode(ImageProcessingNode): + node_id = "ResizeImagesByShorterEdge" + display_name = "Resize Images by Shorter Edge" + description = "Resize images so that the shorter edge matches the specified length while preserving aspect ratio." + extra_inputs = [ + io.Int.Input( + "shorter_edge", + default=512, + min=1, + max=8192, + tooltip="Target length for the shorter edge.", + ), + ] + + @classmethod + def _process(cls, image, shorter_edge): + img = tensor_to_pil(image) + w, h = img.size + if w < h: + new_w = shorter_edge + new_h = int(h * (shorter_edge / w)) + else: + new_h = shorter_edge + new_w = int(w * (shorter_edge / h)) + img = img.resize((new_w, new_h), Image.Resampling.LANCZOS) + return pil_to_tensor(img) + + +class ResizeImagesByLongerEdgeNode(ImageProcessingNode): + node_id = "ResizeImagesByLongerEdge" + display_name = "Resize Images by Longer Edge" + description = "Resize images so that the longer edge matches the specified length while preserving aspect ratio." + extra_inputs = [ + io.Int.Input( + "longer_edge", + default=1024, + min=1, + max=8192, + tooltip="Target length for the longer edge.", + ), + ] + + @classmethod + def _process(cls, image, longer_edge): + img = tensor_to_pil(image) + w, h = img.size + if w > h: + new_w = longer_edge + new_h = int(h * (longer_edge / w)) + else: + new_h = longer_edge + new_w = int(w * (longer_edge / h)) + img = img.resize((new_w, new_h), Image.Resampling.LANCZOS) + return pil_to_tensor(img) + + +class CenterCropImagesNode(ImageProcessingNode): + node_id = "CenterCropImages" + display_name = "Center Crop Images" + description = "Center crop all images to the specified dimensions." + extra_inputs = [ + io.Int.Input("width", default=512, min=1, max=8192, tooltip="Crop width."), + io.Int.Input("height", default=512, min=1, max=8192, tooltip="Crop height."), + ] + + @classmethod + def _process(cls, image, width, height): + img = tensor_to_pil(image) + left = max(0, (img.width - width) // 2) + top = max(0, (img.height - height) // 2) + right = min(img.width, left + width) + bottom = min(img.height, top + height) + img = img.crop((left, top, right, bottom)) + return pil_to_tensor(img) + + +class RandomCropImagesNode(ImageProcessingNode): + node_id = "RandomCropImages" + display_name = "Random Crop Images" + description = ( + "Randomly crop all images to the specified dimensions (for data augmentation)." + ) + extra_inputs = [ + io.Int.Input("width", default=512, min=1, max=8192, tooltip="Crop width."), + io.Int.Input("height", default=512, min=1, max=8192, tooltip="Crop height."), + io.Int.Input( + "seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, tooltip="Random seed." + ), + ] + + @classmethod + def _process(cls, image, width, height, seed): + np.random.seed(seed % (2**32 - 1)) + img = tensor_to_pil(image) + max_left = max(0, img.width - width) + max_top = max(0, img.height - height) + left = np.random.randint(0, max_left + 1) if max_left > 0 else 0 + top = np.random.randint(0, max_top + 1) if max_top > 0 else 0 + right = min(img.width, left + width) + bottom = min(img.height, top + height) + img = img.crop((left, top, right, bottom)) + return pil_to_tensor(img) + + +class NormalizeImagesNode(ImageProcessingNode): + node_id = "NormalizeImages" + display_name = "Normalize Images" + description = "Normalize images using mean and standard deviation." + extra_inputs = [ + io.Float.Input( + "mean", + default=0.5, + min=0.0, + max=1.0, + tooltip="Mean value for normalization.", + ), + io.Float.Input( + "std", + default=0.5, + min=0.001, + max=1.0, + tooltip="Standard deviation for normalization.", + ), + ] + + @classmethod + def _process(cls, image, mean, std): + return (image - mean) / std + + +class AdjustBrightnessNode(ImageProcessingNode): + node_id = "AdjustBrightness" + display_name = "Adjust Brightness" + description = "Adjust brightness of all images." + extra_inputs = [ + io.Float.Input( + "factor", + default=1.0, + min=0.0, + max=2.0, + tooltip="Brightness factor. 1.0 = no change, <1.0 = darker, >1.0 = brighter.", + ), + ] + + @classmethod + def _process(cls, image, factor): + return (image * factor).clamp(0.0, 1.0) + + +class AdjustContrastNode(ImageProcessingNode): + node_id = "AdjustContrast" + display_name = "Adjust Contrast" + description = "Adjust contrast of all images." + extra_inputs = [ + io.Float.Input( + "factor", + default=1.0, + min=0.0, + max=2.0, + tooltip="Contrast factor. 1.0 = no change, <1.0 = less contrast, >1.0 = more contrast.", + ), + ] + + @classmethod + def _process(cls, image, factor): + return ((image - 0.5) * factor + 0.5).clamp(0.0, 1.0) + + +class ShuffleDatasetNode(ImageProcessingNode): + node_id = "ShuffleDataset" + display_name = "Shuffle Image Dataset" + description = "Randomly shuffle the order of images in the dataset." + is_group_process = True # Requires full list to shuffle + extra_inputs = [ + io.Int.Input( + "seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, tooltip="Random seed." + ), + ] + + @classmethod + def _group_process(cls, images, seed): + np.random.seed(seed % (2**32 - 1)) + indices = np.random.permutation(len(images)) + return [images[i] for i in indices] + + +class ShuffleImageTextDatasetNode(io.ComfyNode): + """Special node that shuffles both images and texts together.""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ShuffleImageTextDataset", + display_name="Shuffle Image-Text Dataset", + category="dataset/image", + is_experimental=True, + is_input_list=True, + inputs=[ + io.Image.Input("images", tooltip="List of images to shuffle."), + io.String.Input("texts", tooltip="List of texts to shuffle."), + io.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + tooltip="Random seed.", + ), + ], + outputs=[ + io.Image.Output( + display_name="images", + is_output_list=True, + tooltip="Shuffled images", + ), + io.String.Output( + display_name="texts", is_output_list=True, tooltip="Shuffled texts" + ), + ], + ) + + @classmethod + def execute(cls, images, texts, seed): + seed = seed[0] # Extract scalar + np.random.seed(seed % (2**32 - 1)) + indices = np.random.permutation(len(images)) + shuffled_images = [images[i] for i in indices] + shuffled_texts = [texts[i] for i in indices] + return io.NodeOutput(shuffled_images, shuffled_texts) + + +# ========== Text Transform Nodes ========== + + +class TextToLowercaseNode(TextProcessingNode): + node_id = "TextToLowercase" + display_name = "Text to Lowercase" + description = "Convert all texts to lowercase." + + @classmethod + def _process(cls, text): + return text.lower() + + +class TextToUppercaseNode(TextProcessingNode): + node_id = "TextToUppercase" + display_name = "Text to Uppercase" + description = "Convert all texts to uppercase." + + @classmethod + def _process(cls, text): + return text.upper() + + +class TruncateTextNode(TextProcessingNode): + node_id = "TruncateText" + display_name = "Truncate Text" + description = "Truncate all texts to a maximum length." + extra_inputs = [ + io.Int.Input( + "max_length", default=77, min=1, max=10000, tooltip="Maximum text length." + ), + ] + + @classmethod + def _process(cls, text, max_length): + return text[:max_length] + + +class AddTextPrefixNode(TextProcessingNode): + node_id = "AddTextPrefix" + display_name = "Add Text Prefix" + description = "Add a prefix to all texts." + extra_inputs = [ + io.String.Input("prefix", default="", tooltip="Prefix to add."), + ] + + @classmethod + def _process(cls, text, prefix): + return prefix + text + + +class AddTextSuffixNode(TextProcessingNode): + node_id = "AddTextSuffix" + display_name = "Add Text Suffix" + description = "Add a suffix to all texts." + extra_inputs = [ + io.String.Input("suffix", default="", tooltip="Suffix to add."), + ] + + @classmethod + def _process(cls, text, suffix): + return text + suffix + + +class ReplaceTextNode(TextProcessingNode): + node_id = "ReplaceText" + display_name = "Replace Text" + description = "Replace text in all texts." + extra_inputs = [ + io.String.Input("find", default="", tooltip="Text to find."), + io.String.Input("replace", default="", tooltip="Text to replace with."), + ] + + @classmethod + def _process(cls, text, find, replace): + return text.replace(find, replace) + + +class StripWhitespaceNode(TextProcessingNode): + node_id = "StripWhitespace" + display_name = "Strip Whitespace" + description = "Strip leading and trailing whitespace from all texts." + + @classmethod + def _process(cls, text): + return text.strip() + + +# ========== Group Processing Example Nodes ========== + + +class ImageDeduplicationNode(ImageProcessingNode): + """Remove duplicate or very similar images from the dataset using perceptual hashing.""" + + node_id = "ImageDeduplication" + display_name = "Image Deduplication" + description = "Remove duplicate or very similar images from the dataset." + is_group_process = True # Requires full list to compare images + extra_inputs = [ + io.Float.Input( + "similarity_threshold", + default=0.95, + min=0.0, + max=1.0, + tooltip="Similarity threshold (0-1). Higher means more similar. Images above this threshold are considered duplicates.", + ), + ] + + @classmethod + def _group_process(cls, images, similarity_threshold): + """Remove duplicate images using perceptual hashing.""" + if len(images) == 0: + return [] + + # Compute simple perceptual hash for each image + def compute_hash(img_tensor): + """Compute a simple perceptual hash by resizing to 8x8 and comparing to average.""" + img = tensor_to_pil(img_tensor) + # Resize to 8x8 + img_small = img.resize((8, 8), Image.Resampling.LANCZOS).convert("L") + # Get pixels + pixels = list(img_small.getdata()) + # Compute average + avg = sum(pixels) / len(pixels) + # Create hash (1 if above average, 0 otherwise) + hash_bits = "".join("1" if p > avg else "0" for p in pixels) + return hash_bits + + def hamming_distance(hash1, hash2): + """Compute Hamming distance between two hash strings.""" + return sum(c1 != c2 for c1, c2 in zip(hash1, hash2)) + + # Compute hashes for all images + hashes = [compute_hash(img) for img in images] + + # Find duplicates + keep_indices = [] + for i in range(len(images)): + is_duplicate = False + for j in keep_indices: + # Compare hashes + distance = hamming_distance(hashes[i], hashes[j]) + similarity = 1.0 - (distance / 64.0) # 64 bits total + if similarity >= similarity_threshold: + is_duplicate = True + logging.info( + f"Image {i} is similar to image {j} (similarity: {similarity:.3f}), skipping" + ) + break + + if not is_duplicate: + keep_indices.append(i) + + # Return only unique images + unique_images = [images[i] for i in keep_indices] + logging.info( + f"Deduplication: kept {len(unique_images)} out of {len(images)} images" + ) + return unique_images + + +class ImageGridNode(ImageProcessingNode): + """Combine multiple images into a single grid/collage.""" + + node_id = "ImageGrid" + display_name = "Image Grid" + description = "Arrange multiple images into a grid layout." + is_group_process = True # Requires full list to create grid + is_output_list = False # Outputs single grid image + extra_inputs = [ + io.Int.Input( + "columns", + default=4, + min=1, + max=20, + tooltip="Number of columns in the grid.", + ), + io.Int.Input( + "cell_width", + default=256, + min=32, + max=2048, + tooltip="Width of each cell in the grid.", + ), + io.Int.Input( + "cell_height", + default=256, + min=32, + max=2048, + tooltip="Height of each cell in the grid.", + ), + io.Int.Input( + "padding", default=4, min=0, max=50, tooltip="Padding between images." + ), + ] + + @classmethod + def _group_process(cls, images, columns, cell_width, cell_height, padding): + """Arrange images into a grid.""" + if len(images) == 0: + raise ValueError("Cannot create grid from empty image list") + + # Calculate grid dimensions + num_images = len(images) + rows = (num_images + columns - 1) // columns # Ceiling division + + # Calculate total grid size + grid_width = columns * cell_width + (columns - 1) * padding + grid_height = rows * cell_height + (rows - 1) * padding + + # Create blank grid + grid = Image.new("RGB", (grid_width, grid_height), (0, 0, 0)) + + # Place images + for idx, img_tensor in enumerate(images): + row = idx // columns + col = idx % columns + + # Convert to PIL and resize to cell size + img = tensor_to_pil(img_tensor) + img = img.resize((cell_width, cell_height), Image.Resampling.LANCZOS) + + # Calculate position + x = col * (cell_width + padding) + y = row * (cell_height + padding) + + # Paste into grid + grid.paste(img, (x, y)) + + logging.info( + f"Created {columns}x{rows} grid with {num_images} images ({grid_width}x{grid_height})" + ) + return pil_to_tensor(grid) + + +class MergeImageListsNode(ImageProcessingNode): + """Merge multiple image lists into a single list.""" + + node_id = "MergeImageLists" + display_name = "Merge Image Lists" + description = "Concatenate multiple image lists into one." + is_group_process = True # Receives images as list + + @classmethod + def _group_process(cls, images): + """Simply return the images list (already merged by input handling).""" + # When multiple list inputs are connected, they're concatenated + # For now, this is a simple pass-through + logging.info(f"Merged image list contains {len(images)} images") + return images + + +class MergeTextListsNode(TextProcessingNode): + """Merge multiple text lists into a single list.""" + + node_id = "MergeTextLists" + display_name = "Merge Text Lists" + description = "Concatenate multiple text lists into one." + is_group_process = True # Receives texts as list + + @classmethod + def _group_process(cls, texts): + """Simply return the texts list (already merged by input handling).""" + # When multiple list inputs are connected, they're concatenated + # For now, this is a simple pass-through + logging.info(f"Merged text list contains {len(texts)} texts") + return texts + + +# ========== Training Dataset Nodes ========== + + +class MakeTrainingDataset(io.ComfyNode): + """Encode images with VAE and texts with CLIP to create a training dataset.""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="MakeTrainingDataset", + display_name="Make Training Dataset", + category="dataset", + is_experimental=True, + is_input_list=True, # images and texts as lists + inputs=[ + io.Image.Input("images", tooltip="List of images to encode."), + io.Vae.Input( + "vae", tooltip="VAE model for encoding images to latents." + ), + io.Clip.Input( + "clip", tooltip="CLIP model for encoding text to conditioning." + ), + io.String.Input( + "texts", + optional=True, + tooltip="List of text captions. Can be length n (matching images), 1 (repeated for all), or omitted (uses empty string).", + ), + ], + outputs=[ + io.Latent.Output( + display_name="latents", + is_output_list=True, + tooltip="List of latent dicts", + ), + io.Conditioning.Output( + display_name="conditioning", + is_output_list=True, + tooltip="List of conditioning lists", + ), + ], + ) + + @classmethod + def execute(cls, images, vae, clip, texts=None): + # Extract scalars (vae and clip are single values wrapped in lists) + vae = vae[0] + clip = clip[0] + + # Handle text list + num_images = len(images) + + if texts is None or len(texts) == 0: + # Treat as [""] for unconditional training + texts = [""] + + if len(texts) == 1 and num_images > 1: + # Repeat single text for all images + texts = texts * num_images + elif len(texts) != num_images: + raise ValueError( + f"Number of texts ({len(texts)}) does not match number of images ({num_images}). " + f"Text list should have length {num_images}, 1, or 0." + ) + + # Encode images with VAE + logging.info(f"Encoding {num_images} images with VAE...") + latents_list = [] # list[{"samples": tensor}] + for img_tensor in images: + # img_tensor is [1, H, W, 3] + latent_tensor = vae.encode(img_tensor[:, :, :, :3]) + latents_list.append({"samples": latent_tensor}) + + # Encode texts with CLIP + logging.info(f"Encoding {len(texts)} texts with CLIP...") + conditioning_list = [] # list[list[cond]] + for text in texts: + if text == "": + cond = clip.encode_from_tokens_scheduled(clip.tokenize("")) + else: + tokens = clip.tokenize(text) + cond = clip.encode_from_tokens_scheduled(tokens) + conditioning_list.append(cond) + + logging.info( + f"Created dataset with {len(latents_list)} latents and {len(conditioning_list)} conditioning." + ) + return io.NodeOutput(latents_list, conditioning_list) + + +class SaveTrainingDataset(io.ComfyNode): + """Save encoded training dataset (latents + conditioning) to disk.""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SaveTrainingDataset", + display_name="Save Training Dataset", + category="dataset", + is_experimental=True, + is_output_node=True, + is_input_list=True, # Receive lists + inputs=[ + io.Latent.Input( + "latents", + tooltip="List of latent dicts from MakeTrainingDataset.", + ), + io.Conditioning.Input( + "conditioning", + tooltip="List of conditioning lists from MakeTrainingDataset.", + ), + io.String.Input( + "folder_name", + default="training_dataset", + tooltip="Name of folder to save dataset (inside output directory).", + ), + io.Int.Input( + "shard_size", + default=1000, + min=1, + max=100000, + tooltip="Number of samples per shard file.", + ), + ], + outputs=[], + ) + + @classmethod + def execute(cls, latents, conditioning, folder_name, shard_size): + # Extract scalars + folder_name = folder_name[0] + shard_size = shard_size[0] + + # latents: list[{"samples": tensor}] + # conditioning: list[list[cond]] + + # Validate lengths match + if len(latents) != len(conditioning): + raise ValueError( + f"Number of latents ({len(latents)}) does not match number of conditions ({len(conditioning)}). " + f"Something went wrong in dataset preparation." + ) + + # Create output directory + output_dir = os.path.join(folder_paths.get_output_directory(), folder_name) + os.makedirs(output_dir, exist_ok=True) + + # Prepare data pairs + num_samples = len(latents) + num_shards = (num_samples + shard_size - 1) // shard_size # Ceiling division + + logging.info( + f"Saving {num_samples} samples to {num_shards} shards in {output_dir}..." + ) + + # Save data in shards + for shard_idx in range(num_shards): + start_idx = shard_idx * shard_size + end_idx = min(start_idx + shard_size, num_samples) + + # Get shard data (list of latent dicts and conditioning lists) + shard_data = { + "latents": latents[start_idx:end_idx], + "conditioning": conditioning[start_idx:end_idx], + } + + # Save shard + shard_filename = f"shard_{shard_idx:04d}.pkl" + shard_path = os.path.join(output_dir, shard_filename) + + with open(shard_path, "wb") as f: + torch.save(shard_data, f) + + logging.info( + f"Saved shard {shard_idx + 1}/{num_shards}: {shard_filename} ({end_idx - start_idx} samples)" + ) + + # Save metadata + metadata = { + "num_samples": num_samples, + "num_shards": num_shards, + "shard_size": shard_size, + } + metadata_path = os.path.join(output_dir, "metadata.json") + with open(metadata_path, "w") as f: + json.dump(metadata, f, indent=2) + + logging.info(f"Successfully saved {num_samples} samples to {output_dir}.") + return io.NodeOutput() + + +class LoadTrainingDataset(io.ComfyNode): + """Load encoded training dataset from disk.""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LoadTrainingDataset", + display_name="Load Training Dataset", + category="dataset", + is_experimental=True, + inputs=[ + io.String.Input( + "folder_name", + default="training_dataset", + tooltip="Name of folder containing the saved dataset (inside output directory).", + ), + ], + outputs=[ + io.Latent.Output( + display_name="latents", + is_output_list=True, + tooltip="List of latent dicts", + ), + io.Conditioning.Output( + display_name="conditioning", + is_output_list=True, + tooltip="List of conditioning lists", + ), + ], + ) + + @classmethod + def execute(cls, folder_name): + # Get dataset directory + dataset_dir = os.path.join(folder_paths.get_output_directory(), folder_name) + + if not os.path.exists(dataset_dir): + raise ValueError(f"Dataset directory not found: {dataset_dir}") + + # Find all shard files + shard_files = sorted( + [ + f + for f in os.listdir(dataset_dir) + if f.startswith("shard_") and f.endswith(".pkl") + ] + ) + + if not shard_files: + raise ValueError(f"No shard files found in {dataset_dir}") + + logging.info(f"Loading {len(shard_files)} shards from {dataset_dir}...") + + # Load all shards + all_latents = [] # list[{"samples": tensor}] + all_conditioning = [] # list[list[cond]] + + for shard_file in shard_files: + shard_path = os.path.join(dataset_dir, shard_file) + + with open(shard_path, "rb") as f: + shard_data = torch.load(f, weights_only=True) + + all_latents.extend(shard_data["latents"]) + all_conditioning.extend(shard_data["conditioning"]) + + logging.info(f"Loaded {shard_file}: {len(shard_data['latents'])} samples") + + logging.info( + f"Successfully loaded {len(all_latents)} samples from {dataset_dir}." + ) + return io.NodeOutput(all_latents, all_conditioning) + + +# ========== Extension Setup ========== + + +class DatasetExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + # Data loading/saving nodes + LoadImageDataSetFromFolderNode, + LoadImageTextDataSetFromFolderNode, + SaveImageDataSetToFolderNode, + SaveImageTextDataSetToFolderNode, + # Image transform nodes + ResizeImagesByShorterEdgeNode, + ResizeImagesByLongerEdgeNode, + CenterCropImagesNode, + RandomCropImagesNode, + NormalizeImagesNode, + AdjustBrightnessNode, + AdjustContrastNode, + ShuffleDatasetNode, + ShuffleImageTextDatasetNode, + # Text transform nodes + TextToLowercaseNode, + TextToUppercaseNode, + TruncateTextNode, + AddTextPrefixNode, + AddTextSuffixNode, + ReplaceTextNode, + StripWhitespaceNode, + # Group processing examples + ImageDeduplicationNode, + ImageGridNode, + MergeImageListsNode, + MergeTextListsNode, + # Training dataset nodes + MakeTrainingDataset, + SaveTrainingDataset, + LoadTrainingDataset, + ] + + +async def comfy_entrypoint() -> DatasetExtension: + return DatasetExtension() diff --git a/comfy_extras/nodes_load_3d.py b/comfy_extras/nodes_load_3d.py index 899608149..54c66ef68 100644 --- a/comfy_extras/nodes_load_3d.py +++ b/comfy_extras/nodes_load_3d.py @@ -7,6 +7,10 @@ from comfy_api.input_impl import VideoFromFile from pathlib import Path +from PIL import Image +import numpy as np + +import uuid def normalize_path(path): return path.replace('\\', '/') @@ -34,58 +38,6 @@ class Load3D(): "height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}), }} - RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "IMAGE", "LOAD3D_CAMERA", IO.VIDEO) - RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "lineart", "camera_info", "recording_video") - - FUNCTION = "process" - EXPERIMENTAL = True - - CATEGORY = "3d" - - def process(self, model_file, image, **kwargs): - image_path = folder_paths.get_annotated_filepath(image['image']) - mask_path = folder_paths.get_annotated_filepath(image['mask']) - normal_path = folder_paths.get_annotated_filepath(image['normal']) - lineart_path = folder_paths.get_annotated_filepath(image['lineart']) - - load_image_node = nodes.LoadImage() - output_image, ignore_mask = load_image_node.load_image(image=image_path) - ignore_image, output_mask = load_image_node.load_image(image=mask_path) - normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path) - lineart_image, ignore_mask3 = load_image_node.load_image(image=lineart_path) - - video = None - - if image['recording'] != "": - recording_video_path = folder_paths.get_annotated_filepath(image['recording']) - - video = VideoFromFile(recording_video_path) - - return output_image, output_mask, model_file, normal_image, lineart_image, image['camera_info'], video - -class Load3DAnimation(): - @classmethod - def INPUT_TYPES(s): - input_dir = os.path.join(folder_paths.get_input_directory(), "3d") - - os.makedirs(input_dir, exist_ok=True) - - input_path = Path(input_dir) - base_path = Path(folder_paths.get_input_directory()) - - files = [ - normalize_path(str(file_path.relative_to(base_path))) - for file_path in input_path.rglob("*") - if file_path.suffix.lower() in {'.gltf', '.glb', '.fbx'} - ] - - return {"required": { - "model_file": (sorted(files), {"file_upload": True}), - "image": ("LOAD_3D_ANIMATION", {}), - "width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}), - "height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}), - }} - RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "LOAD3D_CAMERA", IO.VIDEO) RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "camera_info", "recording_video") @@ -120,7 +72,8 @@ class Preview3D(): "model_file": ("STRING", {"default": "", "multiline": False}), }, "optional": { - "camera_info": ("LOAD3D_CAMERA", {}) + "camera_info": ("LOAD3D_CAMERA", {}), + "bg_image": ("IMAGE", {}) }} OUTPUT_NODE = True @@ -133,50 +86,33 @@ class Preview3D(): def process(self, model_file, **kwargs): camera_info = kwargs.get("camera_info", None) + bg_image = kwargs.get("bg_image", None) + + bg_image_path = None + if bg_image is not None: + + img_array = (bg_image[0].cpu().numpy() * 255).astype(np.uint8) + img = Image.fromarray(img_array) + + temp_dir = folder_paths.get_temp_directory() + filename = f"bg_{uuid.uuid4().hex}.png" + bg_image_path = os.path.join(temp_dir, filename) + img.save(bg_image_path, compress_level=1) + + bg_image_path = f"temp/{filename}" return { "ui": { - "result": [model_file, camera_info] - } - } - -class Preview3DAnimation(): - @classmethod - def INPUT_TYPES(s): - return {"required": { - "model_file": ("STRING", {"default": "", "multiline": False}), - }, - "optional": { - "camera_info": ("LOAD3D_CAMERA", {}) - }} - - OUTPUT_NODE = True - RETURN_TYPES = () - - CATEGORY = "3d" - - FUNCTION = "process" - EXPERIMENTAL = True - - def process(self, model_file, **kwargs): - camera_info = kwargs.get("camera_info", None) - - return { - "ui": { - "result": [model_file, camera_info] + "result": [model_file, camera_info, bg_image_path] } } NODE_CLASS_MAPPINGS = { "Load3D": Load3D, - "Load3DAnimation": Load3DAnimation, "Preview3D": Preview3D, - "Preview3DAnimation": Preview3DAnimation } NODE_DISPLAY_NAME_MAPPINGS = { - "Load3D": "Load 3D", - "Load3DAnimation": "Load 3D - Animation", - "Preview3D": "Preview 3D", - "Preview3DAnimation": "Preview 3D - Animation" + "Load3D": "Load 3D & Animation", + "Preview3D": "Preview 3D & Animation", } diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py index 9e6ec6780..cb24ab709 100644 --- a/comfy_extras/nodes_train.py +++ b/comfy_extras/nodes_train.py @@ -1,15 +1,13 @@ -import datetime -import json import logging import os import numpy as np import safetensors import torch -from PIL import Image, ImageDraw, ImageFont -from PIL.PngImagePlugin import PngInfo import torch.utils.checkpoint -import tqdm +from tqdm.auto import trange +from PIL import Image, ImageDraw, ImageFont +from typing_extensions import override import comfy.samplers import comfy.sd @@ -18,9 +16,9 @@ import comfy.model_management import comfy_extras.nodes_custom_sampler import folder_paths import node_helpers -from comfy.cli_args import args -from comfy.comfy_types.node_typing import IO from comfy.weight_adapter import adapters, adapter_maps +from comfy_api.latest import ComfyExtension, io, ui +from comfy.utils import ProgressBar def make_batch_extra_option_dict(d, indicies, full_size=None): @@ -56,7 +54,18 @@ def process_cond_list(d, prefix=""): class TrainSampler(comfy.samplers.Sampler): - def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, grad_acc=1, total_steps=1, seed=0, training_dtype=torch.bfloat16): + def __init__( + self, + loss_fn, + optimizer, + loss_callback=None, + batch_size=1, + grad_acc=1, + total_steps=1, + seed=0, + training_dtype=torch.bfloat16, + real_dataset=None, + ): self.loss_fn = loss_fn self.optimizer = optimizer self.loss_callback = loss_callback @@ -65,54 +74,138 @@ class TrainSampler(comfy.samplers.Sampler): self.grad_acc = grad_acc self.seed = seed self.training_dtype = training_dtype + self.real_dataset: list[torch.Tensor] | None = real_dataset - def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False): + def fwd_bwd( + self, + model_wrap, + batch_sigmas, + batch_noise, + batch_latent, + cond, + indicies, + extra_args, + dataset_size, + bwd=True, + ): + xt = model_wrap.inner_model.model_sampling.noise_scaling( + batch_sigmas, batch_noise, batch_latent, False + ) + x0 = model_wrap.inner_model.model_sampling.noise_scaling( + torch.zeros_like(batch_sigmas), + torch.zeros_like(batch_noise), + batch_latent, + False, + ) + + model_wrap.conds["positive"] = [cond[i] for i in indicies] + batch_extra_args = make_batch_extra_option_dict( + extra_args, indicies, full_size=dataset_size + ) + + with torch.autocast(xt.device.type, dtype=self.training_dtype): + x0_pred = model_wrap( + xt.requires_grad_(True), + batch_sigmas.requires_grad_(True), + **batch_extra_args, + ) + loss = self.loss_fn(x0_pred, x0) + if bwd: + bwd_loss = loss / self.grad_acc + bwd_loss.backward() + return loss + + def sample( + self, + model_wrap, + sigmas, + extra_args, + callback, + noise, + latent_image=None, + denoise_mask=None, + disable_pbar=False, + ): model_wrap.conds = process_cond_list(model_wrap.conds) cond = model_wrap.conds["positive"] dataset_size = sigmas.size(0) torch.cuda.empty_cache() - for i in (pbar:=tqdm.trange(self.total_steps, desc="Training LoRA", smoothing=0.01, disable=not comfy.utils.PROGRESS_BAR_ENABLED)): - noisegen = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(self.seed + i * 1000) - indicies = torch.randperm(dataset_size)[:self.batch_size].tolist() - - batch_latent = torch.stack([latent_image[i] for i in indicies]) - batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(batch_latent.device) - batch_sigmas = [ - model_wrap.inner_model.model_sampling.percent_to_sigma( - torch.rand((1,)).item() - ) for _ in range(min(self.batch_size, dataset_size)) - ] - batch_sigmas = torch.tensor(batch_sigmas).to(batch_latent.device) - - xt = model_wrap.inner_model.model_sampling.noise_scaling( - batch_sigmas, - batch_noise, - batch_latent, - False + ui_pbar = ProgressBar(self.total_steps) + for i in ( + pbar := trange( + self.total_steps, + desc="Training LoRA", + smoothing=0.01, + disable=not comfy.utils.PROGRESS_BAR_ENABLED, ) - x0 = model_wrap.inner_model.model_sampling.noise_scaling( - torch.zeros_like(batch_sigmas), - torch.zeros_like(batch_noise), - batch_latent, - False + ): + noisegen = comfy_extras.nodes_custom_sampler.Noise_RandomNoise( + self.seed + i * 1000 ) + indicies = torch.randperm(dataset_size)[: self.batch_size].tolist() - model_wrap.conds["positive"] = [ - cond[i] for i in indicies - ] - batch_extra_args = make_batch_extra_option_dict(extra_args, indicies, full_size=dataset_size) + if self.real_dataset is None: + batch_latent = torch.stack([latent_image[i] for i in indicies]) + batch_noise = noisegen.generate_noise({"samples": batch_latent}).to( + batch_latent.device + ) + batch_sigmas = [ + model_wrap.inner_model.model_sampling.percent_to_sigma( + torch.rand((1,)).item() + ) + for _ in range(min(self.batch_size, dataset_size)) + ] + batch_sigmas = torch.tensor(batch_sigmas).to(batch_latent.device) - with torch.autocast(xt.device.type, dtype=self.training_dtype): - x0_pred = model_wrap(xt, batch_sigmas, **batch_extra_args) - loss = self.loss_fn(x0_pred, x0) - loss.backward() - if self.loss_callback: - self.loss_callback(loss.item()) - pbar.set_postfix({"loss": f"{loss.item():.4f}"}) + loss = self.fwd_bwd( + model_wrap, + batch_sigmas, + batch_noise, + batch_latent, + cond, + indicies, + extra_args, + dataset_size, + bwd=True, + ) + if self.loss_callback: + self.loss_callback(loss.item()) + pbar.set_postfix({"loss": f"{loss.item():.4f}"}) + else: + total_loss = 0 + for index in indicies: + single_latent = self.real_dataset[index].to(latent_image) + batch_noise = noisegen.generate_noise( + {"samples": single_latent} + ).to(single_latent.device) + batch_sigmas = ( + model_wrap.inner_model.model_sampling.percent_to_sigma( + torch.rand((1,)).item() + ) + ) + batch_sigmas = torch.tensor([batch_sigmas]).to(single_latent.device) + loss = self.fwd_bwd( + model_wrap, + batch_sigmas, + batch_noise, + single_latent, + cond, + [index], + extra_args, + dataset_size, + bwd=False, + ) + total_loss += loss + total_loss = total_loss / self.grad_acc / len(indicies) + total_loss.backward() + if self.loss_callback: + self.loss_callback(total_loss.item()) + pbar.set_postfix({"loss": f"{total_loss.item():.4f}"}) - if (i+1) % self.grad_acc == 0: + if (i + 1) % self.grad_acc == 0: self.optimizer.step() self.optimizer.zero_grad() + ui_pbar.update(1) torch.cuda.empty_cache() return torch.zeros_like(latent_image) @@ -134,233 +227,6 @@ class BiasDiff(torch.nn.Module): return self.passive_memory_usage() -def load_and_process_images(image_files, input_dir, resize_method="None", w=None, h=None): - """Utility function to load and process a list of images. - - Args: - image_files: List of image filenames - input_dir: Base directory containing the images - resize_method: How to handle images of different sizes ("None", "Stretch", "Crop", "Pad") - - Returns: - torch.Tensor: Batch of processed images - """ - if not image_files: - raise ValueError("No valid images found in input") - - output_images = [] - - for file in image_files: - image_path = os.path.join(input_dir, file) - img = node_helpers.pillow(Image.open, image_path) - - if img.mode == "I": - img = img.point(lambda i: i * (1 / 255)) - img = img.convert("RGB") - - if w is None and h is None: - w, h = img.size[0], img.size[1] - - # Resize image to first image - if img.size[0] != w or img.size[1] != h: - if resize_method == "Stretch": - img = img.resize((w, h), Image.Resampling.LANCZOS) - elif resize_method == "Crop": - img = img.crop((0, 0, w, h)) - elif resize_method == "Pad": - img = img.resize((w, h), Image.Resampling.LANCZOS) - elif resize_method == "None": - raise ValueError( - "Your input image size does not match the first image in the dataset. Either select a valid resize method or use the same size for all images." - ) - - img_array = np.array(img).astype(np.float32) / 255.0 - img_tensor = torch.from_numpy(img_array)[None,] - output_images.append(img_tensor) - - return torch.cat(output_images, dim=0) - - -class LoadImageSetNode: - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "images": ( - [ - f - for f in os.listdir(folder_paths.get_input_directory()) - if f.endswith((".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif", ".jpe", ".apng", ".tif", ".tiff")) - ], - {"image_upload": True, "allow_batch": True}, - ) - }, - "optional": { - "resize_method": ( - ["None", "Stretch", "Crop", "Pad"], - {"default": "None"}, - ), - }, - } - - INPUT_IS_LIST = True - RETURN_TYPES = ("IMAGE",) - FUNCTION = "load_images" - CATEGORY = "loaders" - EXPERIMENTAL = True - DESCRIPTION = "Loads a batch of images from a directory for training." - - @classmethod - def VALIDATE_INPUTS(s, images, resize_method): - filenames = images[0] if isinstance(images[0], list) else images - - for image in filenames: - if not folder_paths.exists_annotated_filepath(image): - return "Invalid image file: {}".format(image) - return True - - def load_images(self, input_files, resize_method): - input_dir = folder_paths.get_input_directory() - valid_extensions = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif", ".jpe", ".apng", ".tif", ".tiff"] - image_files = [ - f - for f in input_files - if any(f.lower().endswith(ext) for ext in valid_extensions) - ] - output_tensor = load_and_process_images(image_files, input_dir, resize_method) - return (output_tensor,) - - -class LoadImageSetFromFolderNode: - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "folder": (folder_paths.get_input_subfolders(), {"tooltip": "The folder to load images from."}) - }, - "optional": { - "resize_method": ( - ["None", "Stretch", "Crop", "Pad"], - {"default": "None"}, - ), - }, - } - - RETURN_TYPES = ("IMAGE",) - FUNCTION = "load_images" - CATEGORY = "loaders" - EXPERIMENTAL = True - DESCRIPTION = "Loads a batch of images from a directory for training." - - def load_images(self, folder, resize_method): - sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder) - valid_extensions = [".png", ".jpg", ".jpeg", ".webp"] - image_files = [ - f - for f in os.listdir(sub_input_dir) - if any(f.lower().endswith(ext) for ext in valid_extensions) - ] - output_tensor = load_and_process_images(image_files, sub_input_dir, resize_method) - return (output_tensor,) - - -class LoadImageTextSetFromFolderNode: - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "folder": (folder_paths.get_input_subfolders(), {"tooltip": "The folder to load images from."}), - "clip": (IO.CLIP, {"tooltip": "The CLIP model used for encoding the text."}), - }, - "optional": { - "resize_method": ( - ["None", "Stretch", "Crop", "Pad"], - {"default": "None"}, - ), - "width": ( - IO.INT, - { - "default": -1, - "min": -1, - "max": 10000, - "step": 1, - "tooltip": "The width to resize the images to. -1 means use the original width.", - }, - ), - "height": ( - IO.INT, - { - "default": -1, - "min": -1, - "max": 10000, - "step": 1, - "tooltip": "The height to resize the images to. -1 means use the original height.", - }, - ) - }, - } - - RETURN_TYPES = ("IMAGE", IO.CONDITIONING,) - FUNCTION = "load_images" - CATEGORY = "loaders" - EXPERIMENTAL = True - DESCRIPTION = "Loads a batch of images and caption from a directory for training." - - def load_images(self, folder, clip, resize_method, width=None, height=None): - if clip is None: - raise RuntimeError("ERROR: clip input is invalid: None\n\nIf the clip is from a checkpoint loader node your checkpoint does not contain a valid clip or text encoder model.") - - logging.info(f"Loading images from folder: {folder}") - - sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder) - valid_extensions = [".png", ".jpg", ".jpeg", ".webp"] - - image_files = [] - for item in os.listdir(sub_input_dir): - path = os.path.join(sub_input_dir, item) - if any(item.lower().endswith(ext) for ext in valid_extensions): - image_files.append(path) - elif os.path.isdir(path): - # Support kohya-ss/sd-scripts folder structure - repeat = 1 - if item.split("_")[0].isdigit(): - repeat = int(item.split("_")[0]) - image_files.extend([ - os.path.join(path, f) for f in os.listdir(path) if any(f.lower().endswith(ext) for ext in valid_extensions) - ] * repeat) - - caption_file_path = [ - f.replace(os.path.splitext(f)[1], ".txt") - for f in image_files - ] - captions = [] - for caption_file in caption_file_path: - caption_path = os.path.join(sub_input_dir, caption_file) - if os.path.exists(caption_path): - with open(caption_path, "r", encoding="utf-8") as f: - caption = f.read().strip() - captions.append(caption) - else: - captions.append("") - - width = width if width != -1 else None - height = height if height != -1 else None - output_tensor = load_and_process_images(image_files, sub_input_dir, resize_method, width, height) - - logging.info(f"Loaded {len(output_tensor)} images from {sub_input_dir}.") - - logging.info(f"Encoding captions from {sub_input_dir}.") - conditions = [] - empty_cond = clip.encode_from_tokens_scheduled(clip.tokenize("")) - for text in captions: - if text == "": - conditions.append(empty_cond) - tokens = clip.tokenize(text) - conditions.extend(clip.encode_from_tokens_scheduled(tokens)) - logging.info(f"Encoded {len(conditions)} captions from {sub_input_dir}.") - return (output_tensor, conditions) - - def draw_loss_graph(loss_map, steps): width, height = 500, 300 img = Image.new("RGB", (width, height), "white") @@ -379,10 +245,14 @@ def draw_loss_graph(loss_map, steps): return img -def find_all_highest_child_module_with_forward(model: torch.nn.Module, result = None, name = None): +def find_all_highest_child_module_with_forward( + model: torch.nn.Module, result=None, name=None +): if result is None: result = [] - elif hasattr(model, "forward") and not isinstance(model, (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict)): + elif hasattr(model, "forward") and not isinstance( + model, (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict) + ): result.append(model) logging.debug(f"Found module with forward: {name} ({model.__class__.__name__})") return result @@ -396,12 +266,13 @@ def patch(m): if not hasattr(m, "forward"): return org_forward = m.forward + def fwd(args, kwargs): return org_forward(*args, **kwargs) + def checkpointing_fwd(*args, **kwargs): - return torch.utils.checkpoint.checkpoint( - fwd, args, kwargs, use_reentrant=False - ) + return torch.utils.checkpoint.checkpoint(fwd, args, kwargs, use_reentrant=False) + m.org_forward = org_forward m.forward = checkpointing_fwd @@ -412,130 +283,126 @@ def unpatch(m): del m.org_forward -class TrainLoraNode: +class TrainLoraNode(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "model": (IO.MODEL, {"tooltip": "The model to train the LoRA on."}), - "latents": ( - "LATENT", - { - "tooltip": "The Latents to use for training, serve as dataset/input of the model." - }, + def define_schema(cls): + return io.Schema( + node_id="TrainLoraNode", + display_name="Train LoRA", + category="training", + is_experimental=True, + is_input_list=True, # All inputs become lists + inputs=[ + io.Model.Input("model", tooltip="The model to train the LoRA on."), + io.Latent.Input( + "latents", + tooltip="The Latents to use for training, serve as dataset/input of the model.", ), - "positive": ( - IO.CONDITIONING, - {"tooltip": "The positive conditioning to use for training."}, + io.Conditioning.Input( + "positive", tooltip="The positive conditioning to use for training." ), - "batch_size": ( - IO.INT, - { - "default": 1, - "min": 1, - "max": 10000, - "step": 1, - "tooltip": "The batch size to use for training.", - }, + io.Int.Input( + "batch_size", + default=1, + min=1, + max=10000, + tooltip="The batch size to use for training.", ), - "grad_accumulation_steps": ( - IO.INT, - { - "default": 1, - "min": 1, - "max": 1024, - "step": 1, - "tooltip": "The number of gradient accumulation steps to use for training.", - } + io.Int.Input( + "grad_accumulation_steps", + default=1, + min=1, + max=1024, + tooltip="The number of gradient accumulation steps to use for training.", ), - "steps": ( - IO.INT, - { - "default": 16, - "min": 1, - "max": 100000, - "tooltip": "The number of steps to train the LoRA for.", - }, + io.Int.Input( + "steps", + default=16, + min=1, + max=100000, + tooltip="The number of steps to train the LoRA for.", ), - "learning_rate": ( - IO.FLOAT, - { - "default": 0.0005, - "min": 0.0000001, - "max": 1.0, - "step": 0.000001, - "tooltip": "The learning rate to use for training.", - }, + io.Float.Input( + "learning_rate", + default=0.0005, + min=0.0000001, + max=1.0, + step=0.0000001, + tooltip="The learning rate to use for training.", ), - "rank": ( - IO.INT, - { - "default": 8, - "min": 1, - "max": 128, - "tooltip": "The rank of the LoRA layers.", - }, + io.Int.Input( + "rank", + default=8, + min=1, + max=128, + tooltip="The rank of the LoRA layers.", ), - "optimizer": ( - ["AdamW", "Adam", "SGD", "RMSprop"], - { - "default": "AdamW", - "tooltip": "The optimizer to use for training.", - }, + io.Combo.Input( + "optimizer", + options=["AdamW", "Adam", "SGD", "RMSprop"], + default="AdamW", + tooltip="The optimizer to use for training.", ), - "loss_function": ( - ["MSE", "L1", "Huber", "SmoothL1"], - { - "default": "MSE", - "tooltip": "The loss function to use for training.", - }, + io.Combo.Input( + "loss_function", + options=["MSE", "L1", "Huber", "SmoothL1"], + default="MSE", + tooltip="The loss function to use for training.", ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "tooltip": "The seed to use for training (used in generator for LoRA weight initialization and noise sampling)", - }, + io.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + tooltip="The seed to use for training (used in generator for LoRA weight initialization and noise sampling)", ), - "training_dtype": ( - ["bf16", "fp32"], - {"default": "bf16", "tooltip": "The dtype to use for training."}, + io.Combo.Input( + "training_dtype", + options=["bf16", "fp32"], + default="bf16", + tooltip="The dtype to use for training.", ), - "lora_dtype": ( - ["bf16", "fp32"], - {"default": "bf16", "tooltip": "The dtype to use for lora."}, + io.Combo.Input( + "lora_dtype", + options=["bf16", "fp32"], + default="bf16", + tooltip="The dtype to use for lora.", ), - "algorithm": ( - list(adapter_maps.keys()), - {"default": list(adapter_maps.keys())[0], "tooltip": "The algorithm to use for training."}, + io.Combo.Input( + "algorithm", + options=list(adapter_maps.keys()), + default=list(adapter_maps.keys())[0], + tooltip="The algorithm to use for training.", ), - "gradient_checkpointing": ( - IO.BOOLEAN, - { - "default": True, - "tooltip": "Use gradient checkpointing for training.", - } + io.Boolean.Input( + "gradient_checkpointing", + default=True, + tooltip="Use gradient checkpointing for training.", ), - "existing_lora": ( - folder_paths.get_filename_list("loras") + ["[None]"], - { - "default": "[None]", - "tooltip": "The existing LoRA to append to. Set to None for new LoRA.", - }, + io.Combo.Input( + "existing_lora", + options=folder_paths.get_filename_list("loras") + ["[None]"], + default="[None]", + tooltip="The existing LoRA to append to. Set to None for new LoRA.", ), - }, - } + ], + outputs=[ + io.Model.Output( + display_name="model", tooltip="Model with LoRA applied" + ), + io.Custom("LORA_MODEL").Output( + display_name="lora", tooltip="LoRA weights" + ), + io.Custom("LOSS_MAP").Output( + display_name="loss_map", tooltip="Loss history" + ), + io.Int.Output(display_name="steps", tooltip="Total training steps"), + ], + ) - RETURN_TYPES = (IO.MODEL, IO.LORA_MODEL, IO.LOSS_MAP, IO.INT) - RETURN_NAMES = ("model_with_lora", "lora", "loss", "steps") - FUNCTION = "train" - CATEGORY = "training" - EXPERIMENTAL = True - - def train( - self, + @classmethod + def execute( + cls, model, latents, positive, @@ -553,13 +420,74 @@ class TrainLoraNode: gradient_checkpointing, existing_lora, ): + # Extract scalars from lists (due to is_input_list=True) + model = model[0] + batch_size = batch_size[0] + steps = steps[0] + grad_accumulation_steps = grad_accumulation_steps[0] + learning_rate = learning_rate[0] + rank = rank[0] + optimizer = optimizer[0] + loss_function = loss_function[0] + seed = seed[0] + training_dtype = training_dtype[0] + lora_dtype = lora_dtype[0] + algorithm = algorithm[0] + gradient_checkpointing = gradient_checkpointing[0] + existing_lora = existing_lora[0] + + # Handle latents - either single dict or list of dicts + if len(latents) == 1: + latents = latents[0]["samples"] # Single latent dict + else: + latent_list = [] + for latent in latents: + latent = latent["samples"] + bs = latent.shape[0] + if bs != 1: + for sub_latent in latent: + latent_list.append(sub_latent[None]) + else: + latent_list.append(latent) + latents = latent_list + + # Handle conditioning - either single list or list of lists + if len(positive) == 1: + positive = positive[0] # Single conditioning list + else: + # Multiple conditioning lists - flatten + flat_positive = [] + for cond in positive: + if isinstance(cond, list): + flat_positive.extend(cond) + else: + flat_positive.append(cond) + positive = flat_positive + mp = model.clone() dtype = node_helpers.string_to_torch_dtype(training_dtype) lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype) mp.set_model_compute_dtype(dtype) - latents = latents["samples"].to(dtype) - num_images = latents.shape[0] + # latents here can be list of different size latent or one large batch + if isinstance(latents, list): + all_shapes = set() + latents = [t.to(dtype) for t in latents] + for latent in latents: + all_shapes.add(latent.shape) + logging.info(f"Latent shapes: {all_shapes}") + if len(all_shapes) > 1: + multi_res = True + else: + multi_res = False + latents = torch.cat(latents, dim=0) + num_images = len(latents) + elif isinstance(latents, torch.Tensor): + latents = latents.to(dtype) + num_images = latents.shape[0] + else: + logging.error(f"Invalid latents type: {type(latents)}") + logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}") if len(positive) == 1 and num_images > 1: positive = positive * num_images @@ -591,9 +519,7 @@ class TrainLoraNode: shape = m.weight.shape if len(shape) >= 2: alpha = float(existing_weights.get(f"{key}.alpha", 1.0)) - dora_scale = existing_weights.get( - f"{key}.dora_scale", None - ) + dora_scale = existing_weights.get(f"{key}.dora_scale", None) for adapter_cls in adapters: existing_adapter = adapter_cls.load( n, existing_weights, alpha, dora_scale @@ -605,7 +531,9 @@ class TrainLoraNode: adapter_cls = adapter_maps[algorithm] if existing_adapter is not None: - train_adapter = existing_adapter.to_train().to(lora_dtype) + train_adapter = existing_adapter.to_train().to( + lora_dtype + ) else: # Use LoRA with alpha=1.0 by default train_adapter = adapter_cls.create_train( @@ -629,7 +557,9 @@ class TrainLoraNode: if hasattr(m, "bias") and m.bias is not None: key = "{}.bias".format(n) bias = torch.nn.Parameter( - torch.zeros(m.bias.shape, dtype=lora_dtype, requires_grad=True) + torch.zeros( + m.bias.shape, dtype=lora_dtype, requires_grad=True + ) ) bias_module = BiasDiff(bias) lora_sd["{}.diff_b".format(n)] = bias @@ -657,24 +587,31 @@ class TrainLoraNode: # setup models if gradient_checkpointing: - for m in find_all_highest_child_module_with_forward(mp.model.diffusion_model): + for m in find_all_highest_child_module_with_forward( + mp.model.diffusion_model + ): patch(m) mp.model.requires_grad_(False) - comfy.model_management.load_models_gpu([mp], memory_required=1e20, force_full_load=True) + comfy.model_management.load_models_gpu( + [mp], memory_required=1e20, force_full_load=True + ) # Setup sampler and guider like in test script loss_map = {"loss": []} + def loss_callback(loss): loss_map["loss"].append(loss) + train_sampler = TrainSampler( criterion, optimizer, loss_callback=loss_callback, batch_size=batch_size, grad_acc=grad_accumulation_steps, - total_steps=steps*grad_accumulation_steps, + total_steps=steps * grad_accumulation_steps, seed=seed, - training_dtype=dtype + training_dtype=dtype, + real_dataset=latents if multi_res else None, ) guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp) guider.set_conds(positive) # Set conditioning from input @@ -684,12 +621,15 @@ class TrainLoraNode: # Generate dummy sigmas and noise sigmas = torch.tensor(range(num_images)) noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed) + if multi_res: + # use first latent as dummy latent if multi_res + latents = latents[0].repeat(num_images, 1, 1, 1) guider.sample( noise.generate_noise({"samples": latents}), latents, train_sampler, sigmas, - seed=noise.seed + seed=noise.seed, ) finally: for m in mp.model.modules(): @@ -702,111 +642,118 @@ class TrainLoraNode: for param in lora_sd: lora_sd[param] = lora_sd[param].to(lora_dtype) - return (mp, lora_sd, loss_map, steps + existing_steps) + return io.NodeOutput(mp, lora_sd, loss_map, steps + existing_steps) -class LoraModelLoader: - def __init__(self): - self.loaded_lora = None +class LoraModelLoader(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LoraModelLoader", + display_name="Load LoRA Model", + category="loaders", + is_experimental=True, + inputs=[ + io.Model.Input( + "model", tooltip="The diffusion model the LoRA will be applied to." + ), + io.Custom("LORA_MODEL").Input( + "lora", tooltip="The LoRA model to apply to the diffusion model." + ), + io.Float.Input( + "strength_model", + default=1.0, + min=-100.0, + max=100.0, + tooltip="How strongly to modify the diffusion model. This value can be negative.", + ), + ], + outputs=[ + io.Model.Output( + display_name="model", tooltip="The modified diffusion model." + ), + ], + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "model": ("MODEL", {"tooltip": "The diffusion model the LoRA will be applied to."}), - "lora": (IO.LORA_MODEL, {"tooltip": "The LoRA model to apply to the diffusion model."}), - "strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the diffusion model. This value can be negative."}), - } - } - - RETURN_TYPES = ("MODEL",) - OUTPUT_TOOLTIPS = ("The modified diffusion model.",) - FUNCTION = "load_lora_model" - - CATEGORY = "loaders" - DESCRIPTION = "Load Trained LoRA weights from Train LoRA node." - EXPERIMENTAL = True - - def load_lora_model(self, model, lora, strength_model): + def execute(cls, model, lora, strength_model): if strength_model == 0: - return (model, ) + return io.NodeOutput(model) - model_lora, _ = comfy.sd.load_lora_for_models(model, None, lora, strength_model, 0) - return (model_lora, ) + model_lora, _ = comfy.sd.load_lora_for_models( + model, None, lora, strength_model, 0 + ) + return io.NodeOutput(model_lora) -class SaveLoRA: - def __init__(self): - self.output_dir = folder_paths.get_output_directory() +class SaveLoRA(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SaveLoRA", + display_name="Save LoRA Weights", + category="loaders", + is_experimental=True, + is_output_node=True, + inputs=[ + io.Custom("LORA_MODEL").Input( + "lora", + tooltip="The LoRA model to save. Do not use the model with LoRA layers.", + ), + io.String.Input( + "prefix", + default="loras/ComfyUI_trained_lora", + tooltip="The prefix to use for the saved LoRA file.", + ), + io.Int.Input( + "steps", + optional=True, + tooltip="Optional: The number of steps to LoRA has been trained for, used to name the saved file.", + ), + ], + outputs=[], + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "lora": ( - IO.LORA_MODEL, - { - "tooltip": "The LoRA model to save. Do not use the model with LoRA layers." - }, - ), - "prefix": ( - "STRING", - { - "default": "loras/ComfyUI_trained_lora", - "tooltip": "The prefix to use for the saved LoRA file.", - }, - ), - }, - "optional": { - "steps": ( - IO.INT, - { - "forceInput": True, - "tooltip": "Optional: The number of steps to LoRA has been trained for, used to name the saved file.", - }, - ), - }, - } - - RETURN_TYPES = () - FUNCTION = "save" - CATEGORY = "loaders" - EXPERIMENTAL = True - OUTPUT_NODE = True - - def save(self, lora, prefix, steps=None): - full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(prefix, self.output_dir) + def execute(cls, lora, prefix, steps=None): + output_dir = folder_paths.get_output_directory() + full_output_folder, filename, counter, subfolder, filename_prefix = ( + folder_paths.get_save_image_path(prefix, output_dir) + ) if steps is None: output_checkpoint = f"{filename}_{counter:05}_.safetensors" else: output_checkpoint = f"{filename}_{steps}_steps_{counter:05}_.safetensors" output_checkpoint = os.path.join(full_output_folder, output_checkpoint) safetensors.torch.save_file(lora, output_checkpoint) - return {} + return io.NodeOutput() -class LossGraphNode: - def __init__(self): - self.output_dir = folder_paths.get_temp_directory() +class LossGraphNode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LossGraphNode", + display_name="Plot Loss Graph", + category="training", + is_experimental=True, + is_output_node=True, + inputs=[ + io.Custom("LOSS_MAP").Input( + "loss", tooltip="Loss map from training node." + ), + io.String.Input( + "filename_prefix", + default="loss_graph", + tooltip="Prefix for the saved loss graph image.", + ), + ], + outputs=[], + hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo], + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "loss": (IO.LOSS_MAP, {"default": {}}), - "filename_prefix": (IO.STRING, {"default": "loss_graph"}), - }, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, - } - - RETURN_TYPES = () - FUNCTION = "plot_loss" - OUTPUT_NODE = True - CATEGORY = "training" - EXPERIMENTAL = True - DESCRIPTION = "Plots the loss graph and saves it to the output directory." - - def plot_loss(self, loss, filename_prefix, prompt=None, extra_pnginfo=None): + def execute(cls, loss, filename_prefix, prompt=None, extra_pnginfo=None): loss_values = loss["loss"] width, height = 800, 480 margin = 40 @@ -849,47 +796,27 @@ class LossGraphNode: (margin - 30, height - 10), f"{min_loss:.2f}", font=font, fill="black" ) - metadata = None - if not args.disable_metadata: - metadata = PngInfo() - if prompt is not None: - metadata.add_text("prompt", json.dumps(prompt)) - if extra_pnginfo is not None: - for x in extra_pnginfo: - metadata.add_text(x, json.dumps(extra_pnginfo[x])) + # Convert PIL image to tensor for PreviewImage + img_array = np.array(img).astype(np.float32) / 255.0 + img_tensor = torch.from_numpy(img_array)[None,] # [1, H, W, 3] - date = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") - img.save( - os.path.join(self.output_dir, f"{filename_prefix}_{date}.png"), - pnginfo=metadata, - ) - return { - "ui": { - "images": [ - { - "filename": f"{filename_prefix}_{date}.png", - "subfolder": "", - "type": "temp", - } - ] - } - } + # Return preview UI + return io.NodeOutput(ui=ui.PreviewImage(img_tensor, cls=cls)) -NODE_CLASS_MAPPINGS = { - "TrainLoraNode": TrainLoraNode, - "SaveLoRANode": SaveLoRA, - "LoraModelLoader": LoraModelLoader, - "LoadImageSetFromFolderNode": LoadImageSetFromFolderNode, - "LoadImageTextSetFromFolderNode": LoadImageTextSetFromFolderNode, - "LossGraphNode": LossGraphNode, -} +# ========== Extension Setup ========== -NODE_DISPLAY_NAME_MAPPINGS = { - "TrainLoraNode": "Train LoRA", - "SaveLoRANode": "Save LoRA Weights", - "LoraModelLoader": "Load LoRA Model", - "LoadImageSetFromFolderNode": "Load Image Dataset from Folder", - "LoadImageTextSetFromFolderNode": "Load Image and Text Dataset from Folder", - "LossGraphNode": "Plot Loss Graph", -} + +class TrainingExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + TrainLoraNode, + LoraModelLoader, + SaveLoRA, + LossGraphNode, + ] + + +async def comfy_entrypoint() -> TrainingExtension: + return TrainingExtension() diff --git a/comfyui_version.py b/comfyui_version.py index fa4b4f4b0..4b039356e 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.3.75" +__version__ = "0.3.76" diff --git a/folder_paths.py b/folder_paths.py index ffdc4d020..9c96540e3 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -137,6 +137,71 @@ def set_user_directory(user_dir: str) -> None: user_directory = user_dir +# System User Protection - Protects system directories from HTTP endpoint access +# System Users are internal-only users that cannot be accessed via HTTP endpoints. +# They use the '__' prefix convention (similar to Python's private member convention). +SYSTEM_USER_PREFIX = "__" + + +def get_system_user_directory(name: str = "system") -> str: + """ + Get the path to a System User directory. + + System User directories (prefixed with '__') are only accessible via internal API, + not through HTTP endpoints. Use this for storing system-internal data that + should not be exposed to users. + + Args: + name: System user name (e.g., "system", "cache"). Must be alphanumeric + with underscores allowed, but cannot start with underscore. + + Returns: + Absolute path to the system user directory. + + Raises: + ValueError: If name is empty, invalid, or starts with underscore. + + Example: + >>> get_system_user_directory("cache") + '/path/to/user/__cache' + """ + if not name or not isinstance(name, str): + raise ValueError("System user name cannot be empty") + if not name.replace("_", "").isalnum(): + raise ValueError(f"Invalid system user name: '{name}'") + if name.startswith("_"): + raise ValueError("System user name should not start with underscore") + return os.path.join(get_user_directory(), f"{SYSTEM_USER_PREFIX}{name}") + + +def get_public_user_directory(user_id: str) -> str | None: + """ + Get the path to a Public User directory for HTTP endpoint access. + + This function provides structural security by returning None for any + System User (prefixed with '__'). All HTTP endpoints should use this + function instead of directly constructing user paths. + + Args: + user_id: User identifier from HTTP request. + + Returns: + Absolute path to the user directory, or None if user_id is invalid + or refers to a System User. + + Example: + >>> get_public_user_directory("default") + '/path/to/user/default' + >>> get_public_user_directory("__system") + None + """ + if not user_id or not isinstance(user_id, str): + return None + if user_id.startswith(SYSTEM_USER_PREFIX): + return None + return os.path.join(get_user_directory(), user_id) + + #NOTE: used in http server so don't put folders that should not be accessed remotely def get_directory_by_type(type_name: str) -> str | None: if type_name == "output": diff --git a/latent_preview.py b/latent_preview.py index ddf6dcf49..66bded4b9 100644 --- a/latent_preview.py +++ b/latent_preview.py @@ -2,17 +2,24 @@ import torch from PIL import Image from comfy.cli_args import args, LatentPreviewMethod from comfy.taesd.taesd import TAESD +from comfy.sd import VAE import comfy.model_management import folder_paths import comfy.utils import logging MAX_PREVIEW_RESOLUTION = args.preview_size +VIDEO_TAES = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5"] -def preview_to_image(latent_image): - latents_ubyte = (((latent_image + 1.0) / 2.0).clamp(0, 1) # change scale from -1..1 to 0..1 - .mul(0xFF) # to 0..255 - ) +def preview_to_image(latent_image, do_scale=True): + if do_scale: + latents_ubyte = (((latent_image + 1.0) / 2.0).clamp(0, 1) # change scale from -1..1 to 0..1 + .mul(0xFF) # to 0..255 + ) + else: + latents_ubyte = (latent_image.clamp(0, 1) + .mul(0xFF) # to 0..255 + ) if comfy.model_management.directml_enabled: latents_ubyte = latents_ubyte.to(dtype=torch.uint8) latents_ubyte = latents_ubyte.to(device="cpu", dtype=torch.uint8, non_blocking=comfy.model_management.device_supports_non_blocking(latent_image.device)) @@ -35,6 +42,10 @@ class TAESDPreviewerImpl(LatentPreviewer): x_sample = self.taesd.decode(x0[:1])[0].movedim(0, 2) return preview_to_image(x_sample) +class TAEHVPreviewerImpl(TAESDPreviewerImpl): + def decode_latent_to_preview(self, x0): + x_sample = self.taesd.decode(x0[:1, :, :1])[0][0] + return preview_to_image(x_sample, do_scale=False) class Latent2RGBPreviewer(LatentPreviewer): def __init__(self, latent_rgb_factors, latent_rgb_factors_bias=None, latent_rgb_factors_reshape=None): @@ -81,8 +92,13 @@ def get_previewer(device, latent_format): if method == LatentPreviewMethod.TAESD: if taesd_decoder_path: - taesd = TAESD(None, taesd_decoder_path, latent_channels=latent_format.latent_channels).to(device) - previewer = TAESDPreviewerImpl(taesd) + if latent_format.taesd_decoder_name in VIDEO_TAES: + taesd = VAE(comfy.utils.load_torch_file(taesd_decoder_path)) + taesd.first_stage_model.show_progress_bar = False + previewer = TAEHVPreviewerImpl(taesd) + else: + taesd = TAESD(None, taesd_decoder_path, latent_channels=latent_format.latent_channels).to(device) + previewer = TAESDPreviewerImpl(taesd) else: logging.warning("Warning: TAESD previews enabled, but could not find models/vae_approx/{}".format(latent_format.taesd_decoder_name)) diff --git a/main.py b/main.py index e1b0f1620..0cd815d9e 100644 --- a/main.py +++ b/main.py @@ -15,6 +15,7 @@ from comfy_execution.progress import get_progress_state from comfy_execution.utils import get_executing_context from comfy_api import feature_flags + if __name__ == "__main__": #NOTE: These do not do anything on core ComfyUI, they are for custom nodes. os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1' @@ -22,6 +23,23 @@ if __name__ == "__main__": setup_logger(log_level=args.verbose, use_stdout=args.log_stdout) + +def handle_comfyui_manager_unavailable(): + if not args.windows_standalone_build: + logging.warning(f"\n\nYou appear to be running comfyui-manager from source, this is not recommended. Please install comfyui-manager using the following command:\ncommand:\n\t{sys.executable} -m pip install --pre comfyui_manager\n") + args.enable_manager = False + + +if args.enable_manager: + if importlib.util.find_spec("comfyui_manager"): + import comfyui_manager + + if not comfyui_manager.__file__ or not comfyui_manager.__file__.endswith('__init__.py'): + handle_comfyui_manager_unavailable() + else: + handle_comfyui_manager_unavailable() + + def apply_custom_paths(): # extra model paths extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml") @@ -79,6 +97,11 @@ def execute_prestartup_script(): for possible_module in possible_modules: module_path = os.path.join(custom_node_path, possible_module) + + if args.enable_manager: + if comfyui_manager.should_be_disabled(module_path): + continue + if os.path.isfile(module_path) or module_path.endswith(".disabled") or module_path == "__pycache__": continue @@ -101,6 +124,10 @@ def execute_prestartup_script(): logging.info("") apply_custom_paths() + +if args.enable_manager: + comfyui_manager.prestartup() + execute_prestartup_script() @@ -323,6 +350,9 @@ def start_comfyui(asyncio_loop=None): asyncio.set_event_loop(asyncio_loop) prompt_server = server.PromptServer(asyncio_loop) + if args.enable_manager and not args.disable_manager_ui: + comfyui_manager.start() + hook_breaker_ac10a0.save_functions() asyncio_loop.run_until_complete(nodes.init_extra_nodes( init_custom_nodes=(not args.disable_all_custom_nodes) or len(args.whitelist_custom_nodes) > 0, diff --git a/manager_requirements.txt b/manager_requirements.txt new file mode 100644 index 000000000..52cc5389c --- /dev/null +++ b/manager_requirements.txt @@ -0,0 +1 @@ +comfyui_manager==4.0.3b3 diff --git a/nodes.py b/nodes.py index f4835c02e..4c910a34b 100644 --- a/nodes.py +++ b/nodes.py @@ -43,6 +43,9 @@ import folder_paths import latent_preview import node_helpers +if args.enable_manager: + import comfyui_manager + def before_node_execution(): comfy.model_management.throw_exception_if_processing_interrupted() @@ -692,8 +695,10 @@ class LoraLoaderModelOnly(LoraLoader): return (self.load_lora(model, None, lora_name, strength_model, 0)[0],) class VAELoader: + video_taes = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5"] + image_taes = ["taesd", "taesdxl", "taesd3", "taef1"] @staticmethod - def vae_list(): + def vae_list(s): vaes = folder_paths.get_filename_list("vae") approx_vaes = folder_paths.get_filename_list("vae_approx") sdxl_taesd_enc = False @@ -722,6 +727,11 @@ class VAELoader: f1_taesd_dec = True elif v.startswith("taef1_decoder."): f1_taesd_enc = True + else: + for tae in s.video_taes: + if v.startswith(tae): + vaes.append(v) + if sd1_taesd_dec and sd1_taesd_enc: vaes.append("taesd") if sdxl_taesd_dec and sdxl_taesd_enc: @@ -765,7 +775,7 @@ class VAELoader: @classmethod def INPUT_TYPES(s): - return {"required": { "vae_name": (s.vae_list(), )}} + return {"required": { "vae_name": (s.vae_list(s), )}} RETURN_TYPES = ("VAE",) FUNCTION = "load_vae" @@ -776,10 +786,13 @@ class VAELoader: if vae_name == "pixel_space": sd = {} sd["pixel_space_vae"] = torch.tensor(1.0) - elif vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]: + elif vae_name in self.image_taes: sd = self.load_taesd(vae_name) else: - vae_path = folder_paths.get_full_path_or_raise("vae", vae_name) + if os.path.splitext(vae_name)[0] in self.video_taes: + vae_path = folder_paths.get_full_path_or_raise("vae_approx", vae_name) + else: + vae_path = folder_paths.get_full_path_or_raise("vae", vae_name) sd = comfy.utils.load_torch_file(vae_path) vae = comfy.sd.VAE(sd=sd) vae.throw_exception_if_invalid() @@ -929,7 +942,7 @@ class CLIPLoader: @classmethod def INPUT_TYPES(s): return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ), - "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2"], ), + "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis"], ), }, "optional": { "device": (["default", "cpu"], {"advanced": True}), @@ -2233,6 +2246,12 @@ async def init_external_custom_nodes(): if args.disable_all_custom_nodes and possible_module not in args.whitelist_custom_nodes: logging.info(f"Skipping {possible_module} due to disable_all_custom_nodes and whitelist_custom_nodes") continue + + if args.enable_manager: + if comfyui_manager.should_be_disabled(module_path): + logging.info(f"Blocked by policy: {module_path}") + continue + time_before = time.perf_counter() success = await load_custom_node(module_path, base_node_names, module_parent="custom_nodes") node_import_times.append((time.perf_counter() - time_before, module_path, success)) @@ -2278,6 +2297,7 @@ async def init_builtin_extra_nodes(): "nodes_images.py", "nodes_video_model.py", "nodes_train.py", + "nodes_dataset.py", "nodes_sag.py", "nodes_perpneg.py", "nodes_stable3d.py", diff --git a/pyproject.toml b/pyproject.toml index 9009e65fe..02b94a0ce 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.75" +version = "0.3.76" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" diff --git a/requirements.txt b/requirements.txt index 5f20816d6..f98848e20 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ -comfyui-frontend-package==1.30.6 -comfyui-workflow-templates==0.7.20 +comfyui-frontend-package==1.33.10 +comfyui-workflow-templates==0.7.25 comfyui-embedded-docs==0.3.1 torch torchsde diff --git a/server.py b/server.py index 0fd2e49e3..e3bd056d9 100644 --- a/server.py +++ b/server.py @@ -44,6 +44,9 @@ from protocol import BinaryEventTypes # Import cache control middleware from middleware.cache_middleware import cache_control +if args.enable_manager: + import comfyui_manager + async def send_socket_catch_exception(function, message): try: await function(message) @@ -174,7 +177,7 @@ def create_block_external_middleware(): else: response = await handler(request) - response.headers['Content-Security-Policy'] = "default-src 'self'; script-src 'self' 'unsafe-inline' blob:; style-src 'self' 'unsafe-inline'; img-src 'self' data: blob:; font-src 'self'; connect-src 'self'; frame-src 'self'; object-src 'self';" + response.headers['Content-Security-Policy'] = "default-src 'self'; script-src 'self' 'unsafe-inline' 'unsafe-eval' blob:; style-src 'self' 'unsafe-inline'; img-src 'self' data: blob:; font-src 'self'; connect-src 'self'; frame-src 'self'; object-src 'self';" return response return block_external_middleware @@ -212,6 +215,9 @@ class PromptServer(): if args.disable_api_nodes: middlewares.append(create_block_external_middleware()) + if args.enable_manager: + middlewares.append(comfyui_manager.create_middleware()) + max_upload_size = round(args.max_upload_size * 1024 * 1024) self.app = web.Application(client_max_size=max_upload_size, middlewares=middlewares) self.sockets = dict() @@ -599,7 +605,7 @@ class PromptServer(): system_stats = { "system": { - "os": os.name, + "os": sys.platform, "ram_total": ram_total, "ram_free": ram_free, "comfyui_version": __version__, diff --git a/tests-unit/app_test/user_manager_system_user_test.py b/tests-unit/app_test/user_manager_system_user_test.py new file mode 100644 index 000000000..63b1ac5e5 --- /dev/null +++ b/tests-unit/app_test/user_manager_system_user_test.py @@ -0,0 +1,193 @@ +"""Tests for System User Protection in user_manager.py + +Tests cover: +- get_request_user_id(): 1st defense layer - blocks System Users from HTTP headers +- get_request_user_filepath(): 2nd defense layer - structural blocking via get_public_user_directory() +- add_user(): 3rd defense layer - prevents creation of System User names +- Defense layers integration tests +""" + +import pytest +from unittest.mock import MagicMock, patch +import tempfile + +import folder_paths +from app.user_manager import UserManager + + +@pytest.fixture +def mock_user_directory(): + """Create a temporary user directory.""" + with tempfile.TemporaryDirectory() as temp_dir: + original_dir = folder_paths.get_user_directory() + folder_paths.set_user_directory(temp_dir) + yield temp_dir + folder_paths.set_user_directory(original_dir) + + +@pytest.fixture +def user_manager(mock_user_directory): + """Create a UserManager instance for testing.""" + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + manager = UserManager() + # Add a default user for testing + manager.users = {"default": "default", "test_user_123": "Test User"} + yield manager + + +@pytest.fixture +def mock_request(): + """Create a mock request object.""" + request = MagicMock() + request.headers = {} + return request + + +class TestGetRequestUserId: + """Tests for get_request_user_id() - 1st defense layer. + + Verifies: + - System Users (__ prefix) in HTTP header are rejected with KeyError + - Public Users pass through successfully + """ + + def test_system_user_raises_error(self, user_manager, mock_request): + """Test System User in header raises KeyError.""" + mock_request.headers = {"comfy-user": "__system"} + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + with pytest.raises(KeyError, match="Unknown user"): + user_manager.get_request_user_id(mock_request) + + def test_system_user_cache_raises_error(self, user_manager, mock_request): + """Test System User cache raises KeyError.""" + mock_request.headers = {"comfy-user": "__cache"} + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + with pytest.raises(KeyError, match="Unknown user"): + user_manager.get_request_user_id(mock_request) + + def test_normal_user_works(self, user_manager, mock_request): + """Test normal user access works.""" + mock_request.headers = {"comfy-user": "default"} + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + user_id = user_manager.get_request_user_id(mock_request) + assert user_id == "default" + + def test_unknown_user_raises_error(self, user_manager, mock_request): + """Test unknown user raises KeyError.""" + mock_request.headers = {"comfy-user": "unknown_user"} + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + with pytest.raises(KeyError, match="Unknown user"): + user_manager.get_request_user_id(mock_request) + + +class TestGetRequestUserFilepath: + """Tests for get_request_user_filepath() - 2nd defense layer. + + Verifies: + - Returns None when get_public_user_directory() returns None (System User) + - Acts as backup defense if 1st layer is bypassed + """ + + def test_system_user_returns_none(self, user_manager, mock_request, mock_user_directory): + """Test System User returns None (structural blocking).""" + # First, we need to mock get_request_user_id to return System User + # But actually, get_request_user_id will raise KeyError first + # So we test via get_public_user_directory returning None + mock_request.headers = {"comfy-user": "default"} + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + # Patch get_public_user_directory to return None for testing + with patch.object(folder_paths, 'get_public_user_directory', return_value=None): + result = user_manager.get_request_user_filepath(mock_request, "test.txt") + assert result is None + + def test_normal_user_gets_path(self, user_manager, mock_request, mock_user_directory): + """Test normal user gets valid filepath.""" + mock_request.headers = {"comfy-user": "default"} + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + path = user_manager.get_request_user_filepath(mock_request, "test.txt") + assert path is not None + assert "default" in path + assert path.endswith("test.txt") + + +class TestAddUser: + """Tests for add_user() - 3rd defense layer (creation-time blocking). + + Verifies: + - System User name (__ prefix) creation is rejected with ValueError + - Sanitized usernames that become System User are also rejected + """ + + def test_system_user_prefix_name_raises(self, user_manager): + """Test System User prefix in name raises ValueError.""" + with pytest.raises(ValueError, match="System User prefix not allowed"): + user_manager.add_user("__system") + + def test_system_user_prefix_cache_raises(self, user_manager): + """Test System User cache prefix raises ValueError.""" + with pytest.raises(ValueError, match="System User prefix not allowed"): + user_manager.add_user("__cache") + + def test_sanitized_system_user_prefix_raises(self, user_manager): + """Test sanitized name becoming System User prefix raises ValueError (bypass prevention).""" + # "__test" directly starts with System User prefix + with pytest.raises(ValueError, match="System User prefix not allowed"): + user_manager.add_user("__test") + + def test_normal_user_creation(self, user_manager, mock_user_directory): + """Test normal user creation works.""" + user_id = user_manager.add_user("Normal User") + assert user_id is not None + assert not user_id.startswith("__") + assert "Normal-User" in user_id or "Normal_User" in user_id + + def test_empty_name_raises(self, user_manager): + """Test empty name raises ValueError.""" + with pytest.raises(ValueError, match="username not provided"): + user_manager.add_user("") + + def test_whitespace_only_raises(self, user_manager): + """Test whitespace-only name raises ValueError.""" + with pytest.raises(ValueError, match="username not provided"): + user_manager.add_user(" ") + + +class TestDefenseLayers: + """Integration tests for all three defense layers. + + Verifies: + - Each defense layer blocks System Users independently + - System User bypass is impossible through any layer + """ + + def test_layer1_get_request_user_id(self, user_manager, mock_request): + """Test 1st defense layer blocks System Users.""" + mock_request.headers = {"comfy-user": "__system"} + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + with pytest.raises(KeyError): + user_manager.get_request_user_id(mock_request) + + def test_layer2_get_public_user_directory(self): + """Test 2nd defense layer blocks System Users.""" + result = folder_paths.get_public_user_directory("__system") + assert result is None + + def test_layer3_add_user(self, user_manager): + """Test 3rd defense layer blocks System User creation.""" + with pytest.raises(ValueError): + user_manager.add_user("__system") diff --git a/tests-unit/folder_paths_test/system_user_test.py b/tests-unit/folder_paths_test/system_user_test.py new file mode 100644 index 000000000..cd46459f1 --- /dev/null +++ b/tests-unit/folder_paths_test/system_user_test.py @@ -0,0 +1,206 @@ +"""Tests for System User Protection in folder_paths.py + +Tests cover: +- get_system_user_directory(): Internal API for custom nodes to access System User directories +- get_public_user_directory(): HTTP endpoint access with System User blocking +- Backward compatibility: Existing APIs unchanged +- Security: Path traversal and injection prevention +""" + +import pytest +import os +import tempfile + +from folder_paths import ( + get_system_user_directory, + get_public_user_directory, + get_user_directory, + set_user_directory, +) + + +@pytest.fixture(scope="module") +def mock_user_directory(): + """Create a temporary user directory for testing.""" + with tempfile.TemporaryDirectory() as temp_dir: + original_dir = get_user_directory() + set_user_directory(temp_dir) + yield temp_dir + set_user_directory(original_dir) + + +class TestGetSystemUserDirectory: + """Tests for get_system_user_directory() - internal API for System User directories. + + Verifies: + - Custom nodes can access System User directories via internal API + - Input validation prevents path traversal attacks + """ + + def test_default_name(self, mock_user_directory): + """Test default 'system' name.""" + path = get_system_user_directory() + assert path.endswith("__system") + assert mock_user_directory in path + + def test_custom_name(self, mock_user_directory): + """Test custom system user name.""" + path = get_system_user_directory("cache") + assert path.endswith("__cache") + assert "__cache" in path + + def test_name_with_underscore(self, mock_user_directory): + """Test name with underscore in middle.""" + path = get_system_user_directory("my_cache") + assert "__my_cache" in path + + def test_empty_name_raises(self): + """Test empty name raises ValueError.""" + with pytest.raises(ValueError, match="cannot be empty"): + get_system_user_directory("") + + def test_none_name_raises(self): + """Test None name raises ValueError.""" + with pytest.raises(ValueError, match="cannot be empty"): + get_system_user_directory(None) + + def test_name_starting_with_underscore_raises(self): + """Test name starting with underscore raises ValueError.""" + with pytest.raises(ValueError, match="should not start with underscore"): + get_system_user_directory("_system") + + def test_path_traversal_raises(self): + """Test path traversal attempt raises ValueError (security).""" + with pytest.raises(ValueError, match="Invalid system user name"): + get_system_user_directory("../escape") + + def test_path_traversal_middle_raises(self): + """Test path traversal in middle raises ValueError (security).""" + with pytest.raises(ValueError, match="Invalid system user name"): + get_system_user_directory("system/../other") + + def test_special_chars_raise(self): + """Test special characters raise ValueError (security).""" + with pytest.raises(ValueError, match="Invalid system user name"): + get_system_user_directory("system!") + + def test_returns_absolute_path(self, mock_user_directory): + """Test returned path is absolute.""" + path = get_system_user_directory("test") + assert os.path.isabs(path) + + +class TestGetPublicUserDirectory: + """Tests for get_public_user_directory() - HTTP endpoint access with System User blocking. + + Verifies: + - System Users (__ prefix) return None, blocking HTTP access + - Public Users get valid paths + - New endpoints using this function are automatically protected + """ + + def test_normal_user(self, mock_user_directory): + """Test normal user returns valid path.""" + path = get_public_user_directory("default") + assert path is not None + assert "default" in path + assert mock_user_directory in path + + def test_system_user_returns_none(self): + """Test System User (__ prefix) returns None - blocks HTTP access.""" + assert get_public_user_directory("__system") is None + + def test_system_user_cache_returns_none(self): + """Test System User cache returns None.""" + assert get_public_user_directory("__cache") is None + + def test_empty_user_returns_none(self): + """Test empty user returns None.""" + assert get_public_user_directory("") is None + + def test_none_user_returns_none(self): + """Test None user returns None.""" + assert get_public_user_directory(None) is None + + def test_header_injection_returns_none(self): + """Test header injection attempt returns None (security).""" + assert get_public_user_directory("__system\r\nX-Injected: true") is None + + def test_null_byte_injection_returns_none(self): + """Test null byte injection handling (security).""" + # Note: startswith check happens before any path operations + result = get_public_user_directory("user\x00__system") + # This should return a path since it doesn't start with __ + # The actual security comes from the path not being __* + assert result is not None or result is None # Depends on validation + + def test_path_traversal_attempt(self, mock_user_directory): + """Test path traversal attempt handling.""" + # This function doesn't validate paths, only reserved prefix + # Path traversal should be handled by the caller + path = get_public_user_directory("../../../etc/passwd") + # Returns path but doesn't start with __, so not None + # Actual path validation happens in user_manager + assert path is not None or "__" not in "../../../etc/passwd" + + def test_returns_absolute_path(self, mock_user_directory): + """Test returned path is absolute.""" + path = get_public_user_directory("testuser") + assert path is not None + assert os.path.isabs(path) + + +class TestBackwardCompatibility: + """Tests for backward compatibility with existing APIs. + + Verifies: + - get_user_directory() API unchanged + - Existing user data remains accessible + """ + + def test_get_user_directory_unchanged(self, mock_user_directory): + """Test get_user_directory() still works as before.""" + user_dir = get_user_directory() + assert user_dir is not None + assert os.path.isabs(user_dir) + assert user_dir == mock_user_directory + + def test_existing_user_accessible(self, mock_user_directory): + """Test existing users can access their directories.""" + path = get_public_user_directory("default") + assert path is not None + assert "default" in path + + +class TestEdgeCases: + """Tests for edge cases in System User detection. + + Verifies: + - Only __ prefix is blocked (not _, not middle __) + - Bypass attempts are prevented + """ + + def test_prefix_only(self): + """Test prefix-only string is blocked.""" + assert get_public_user_directory("__") is None + + def test_single_underscore_allowed(self): + """Test single underscore prefix is allowed (not System User).""" + path = get_public_user_directory("_system") + assert path is not None + assert "_system" in path + + def test_triple_underscore_blocked(self): + """Test triple underscore is blocked (starts with __).""" + assert get_public_user_directory("___system") is None + + def test_underscore_in_middle_allowed(self): + """Test underscore in middle is allowed.""" + path = get_public_user_directory("my__system") + assert path is not None + assert "my__system" in path + + def test_leading_space_allowed(self): + """Test leading space + prefix is allowed (doesn't start with __).""" + path = get_public_user_directory(" __system") + assert path is not None diff --git a/tests-unit/prompt_server_test/system_user_endpoint_test.py b/tests-unit/prompt_server_test/system_user_endpoint_test.py new file mode 100644 index 000000000..22ac00af9 --- /dev/null +++ b/tests-unit/prompt_server_test/system_user_endpoint_test.py @@ -0,0 +1,375 @@ +"""E2E Tests for System User Protection HTTP Endpoints + +Tests cover: +- HTTP endpoint blocking: System Users cannot access /userdata (GET, POST, DELETE, move) +- User creation blocking: System User names cannot be created via POST /users +- Backward compatibility: Public Users work as before +- Custom node scenario: Internal API works while HTTP is blocked +- Structural security: get_public_user_directory() provides automatic protection +""" + +import pytest +import os +from aiohttp import web +from app.user_manager import UserManager +from unittest.mock import patch +import folder_paths + + +@pytest.fixture +def mock_user_directory(tmp_path): + """Create a temporary user directory.""" + original_dir = folder_paths.get_user_directory() + folder_paths.set_user_directory(str(tmp_path)) + yield tmp_path + folder_paths.set_user_directory(original_dir) + + +@pytest.fixture +def user_manager_multi_user(mock_user_directory): + """Create UserManager in multi-user mode.""" + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + um = UserManager() + # Add test users + um.users = {"default": "default", "test_user_123": "Test User"} + yield um + + +@pytest.fixture +def app_multi_user(user_manager_multi_user): + """Create app with multi-user mode enabled.""" + app = web.Application() + routes = web.RouteTableDef() + user_manager_multi_user.add_routes(routes) + app.add_routes(routes) + return app + + +class TestSystemUserEndpointBlocking: + """E2E tests for System User blocking on all HTTP endpoints. + + Verifies: + - GET /userdata blocked for System Users + - POST /userdata blocked for System Users + - DELETE /userdata blocked for System Users + - POST /userdata/.../move/... blocked for System Users + """ + + @pytest.mark.asyncio + async def test_userdata_get_blocks_system_user( + self, aiohttp_client, app_multi_user, mock_user_directory + ): + """ + GET /userdata with System User header should be blocked. + """ + # Create test directory for System User (simulating internal creation) + system_user_dir = mock_user_directory / "__system" + system_user_dir.mkdir() + (system_user_dir / "secret.txt").write_text("sensitive data") + + client = await aiohttp_client(app_multi_user) + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + # Attempt to access System User's data via HTTP + resp = await client.get( + "/userdata?dir=.", + headers={"comfy-user": "__system"} + ) + + # Should be blocked (403 Forbidden or similar error) + assert resp.status in [400, 403, 500], \ + f"System User access should be blocked, got {resp.status}" + + @pytest.mark.asyncio + async def test_userdata_post_blocks_system_user( + self, aiohttp_client, app_multi_user, mock_user_directory + ): + """ + POST /userdata with System User header should be blocked. + """ + client = await aiohttp_client(app_multi_user) + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + resp = await client.post( + "/userdata/test.txt", + headers={"comfy-user": "__system"}, + data=b"malicious content" + ) + + assert resp.status in [400, 403, 500], \ + f"System User write should be blocked, got {resp.status}" + + # Verify no file was created + assert not (mock_user_directory / "__system" / "test.txt").exists() + + @pytest.mark.asyncio + async def test_userdata_delete_blocks_system_user( + self, aiohttp_client, app_multi_user, mock_user_directory + ): + """ + DELETE /userdata with System User header should be blocked. + """ + # Create a file in System User directory + system_user_dir = mock_user_directory / "__system" + system_user_dir.mkdir() + secret_file = system_user_dir / "secret.txt" + secret_file.write_text("do not delete") + + client = await aiohttp_client(app_multi_user) + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + resp = await client.delete( + "/userdata/secret.txt", + headers={"comfy-user": "__system"} + ) + + assert resp.status in [400, 403, 500], \ + f"System User delete should be blocked, got {resp.status}" + + # Verify file still exists + assert secret_file.exists() + + @pytest.mark.asyncio + async def test_v2_userdata_blocks_system_user( + self, aiohttp_client, app_multi_user, mock_user_directory + ): + """ + GET /v2/userdata with System User header should be blocked. + """ + client = await aiohttp_client(app_multi_user) + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + resp = await client.get( + "/v2/userdata", + headers={"comfy-user": "__system"} + ) + + assert resp.status in [400, 403, 500], \ + f"System User v2 access should be blocked, got {resp.status}" + + @pytest.mark.asyncio + async def test_move_userdata_blocks_system_user( + self, aiohttp_client, app_multi_user, mock_user_directory + ): + """ + POST /userdata/{file}/move/{dest} with System User header should be blocked. + """ + system_user_dir = mock_user_directory / "__system" + system_user_dir.mkdir() + (system_user_dir / "source.txt").write_text("sensitive data") + + client = await aiohttp_client(app_multi_user) + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + resp = await client.post( + "/userdata/source.txt/move/dest.txt", + headers={"comfy-user": "__system"} + ) + + assert resp.status in [400, 403, 500], \ + f"System User move should be blocked, got {resp.status}" + + # Verify source file still exists (move was blocked) + assert (system_user_dir / "source.txt").exists() + + +class TestSystemUserCreationBlocking: + """E2E tests for blocking System User name creation via POST /users. + + Verifies: + - POST /users returns 400 for System User name (not 500) + """ + + @pytest.mark.asyncio + async def test_post_users_blocks_system_user_name( + self, aiohttp_client, app_multi_user + ): + """POST /users with System User name should return 400 Bad Request.""" + client = await aiohttp_client(app_multi_user) + + resp = await client.post( + "/users", + json={"username": "__system"} + ) + + assert resp.status == 400, \ + f"System User creation should return 400, got {resp.status}" + + @pytest.mark.asyncio + async def test_post_users_blocks_system_user_prefix_variations( + self, aiohttp_client, app_multi_user + ): + """POST /users with any System User prefix variation should return 400 Bad Request.""" + client = await aiohttp_client(app_multi_user) + + system_user_names = ["__system", "__cache", "__config", "__anything"] + + for name in system_user_names: + resp = await client.post("/users", json={"username": name}) + assert resp.status == 400, \ + f"System User name '{name}' should return 400, got {resp.status}" + + +class TestPublicUserStillWorks: + """E2E tests for backward compatibility - Public Users should work as before. + + Verifies: + - Public Users can access their data via HTTP + - Public Users can create files via HTTP + """ + + @pytest.mark.asyncio + async def test_public_user_can_access_userdata( + self, aiohttp_client, app_multi_user, mock_user_directory + ): + """ + Public Users should still be able to access their data. + """ + # Create test directory for Public User + user_dir = mock_user_directory / "default" + user_dir.mkdir() + test_dir = user_dir / "workflows" + test_dir.mkdir() + (test_dir / "test.json").write_text('{"test": true}') + + client = await aiohttp_client(app_multi_user) + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + resp = await client.get( + "/userdata?dir=workflows", + headers={"comfy-user": "default"} + ) + + assert resp.status == 200 + data = await resp.json() + assert "test.json" in data + + @pytest.mark.asyncio + async def test_public_user_can_create_files( + self, aiohttp_client, app_multi_user, mock_user_directory + ): + """ + Public Users should still be able to create files. + """ + # Create user directory + user_dir = mock_user_directory / "default" + user_dir.mkdir() + + client = await aiohttp_client(app_multi_user) + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + resp = await client.post( + "/userdata/newfile.txt", + headers={"comfy-user": "default"}, + data=b"user content" + ) + + assert resp.status == 200 + assert (user_dir / "newfile.txt").exists() + + +class TestCustomNodeScenario: + """Tests for custom node use case: internal API access vs HTTP blocking. + + Verifies: + - Internal API (get_system_user_directory) works for custom nodes + - HTTP endpoint cannot access data created via internal API + """ + + def test_internal_api_can_access_system_user(self, mock_user_directory): + """ + Internal API (get_system_user_directory) should work for custom nodes. + """ + # Custom node uses internal API + system_path = folder_paths.get_system_user_directory("mynode_config") + + assert system_path is not None + assert "__mynode_config" in system_path + + # Can create and write to System User directory + os.makedirs(system_path, exist_ok=True) + config_file = os.path.join(system_path, "settings.json") + with open(config_file, "w") as f: + f.write('{"api_key": "secret"}') + + assert os.path.exists(config_file) + + @pytest.mark.asyncio + async def test_http_cannot_access_internal_data( + self, aiohttp_client, app_multi_user, mock_user_directory + ): + """ + HTTP endpoint cannot access data created via internal API. + """ + # Custom node creates data via internal API + system_path = folder_paths.get_system_user_directory("mynode_config") + os.makedirs(system_path, exist_ok=True) + with open(os.path.join(system_path, "secret.json"), "w") as f: + f.write('{"api_key": "secret"}') + + client = await aiohttp_client(app_multi_user) + + # Attacker tries to access via HTTP + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + resp = await client.get( + "/userdata/secret.json", + headers={"comfy-user": "__mynode_config"} + ) + + # Should be blocked + assert resp.status in [400, 403, 500] + + +class TestStructuralSecurity: + """Tests for structural security pattern. + + Verifies: + - get_public_user_directory() automatically blocks System Users + - New endpoints using this function are automatically protected + """ + + def test_get_public_user_directory_blocks_system_user(self): + """ + Any code using get_public_user_directory() is automatically protected. + """ + # This is the structural security - any new endpoint using this function + # will automatically block System Users + assert folder_paths.get_public_user_directory("__system") is None + assert folder_paths.get_public_user_directory("__cache") is None + assert folder_paths.get_public_user_directory("__anything") is None + + # Public Users work + assert folder_paths.get_public_user_directory("default") is not None + assert folder_paths.get_public_user_directory("user123") is not None + + def test_structural_security_pattern(self, mock_user_directory): + """ + Demonstrate the structural security pattern for new endpoints. + + Any new endpoint should follow this pattern: + 1. Get user from request + 2. Use get_public_user_directory() - automatically blocks System Users + 3. If None, return error + """ + def new_endpoint_handler(user_id: str) -> str | None: + """Example of how new endpoints should be implemented.""" + user_path = folder_paths.get_public_user_directory(user_id) + if user_path is None: + return None # Blocked + return user_path + + # System Users are automatically blocked + assert new_endpoint_handler("__system") is None + assert new_endpoint_handler("__secret") is None + + # Public Users work + assert new_endpoint_handler("default") is not None