diff --git a/.github/workflows/test-launch.yml b/.github/workflows/test-launch.yml index 42f1dbe99..5d665d6af 100644 --- a/.github/workflows/test-launch.yml +++ b/.github/workflows/test-launch.yml @@ -28,7 +28,7 @@ jobs: - name: Start ComfyUI server run: | python main.py --cpu 2>&1 | tee console_output.log & - wait-for-it --service 127.0.0.1:8188 -t 600 + wait-for-it --service 127.0.0.1:8188 -t 30 working-directory: ComfyUI - name: Check for unhandled exceptions in server log run: | diff --git a/comfy/comfy_types/README.md b/comfy/comfy_types/README.md new file mode 100644 index 000000000..869851e7c --- /dev/null +++ b/comfy/comfy_types/README.md @@ -0,0 +1,43 @@ +# Comfy Typing +## Type hinting for ComfyUI Node development + +This module provides type hinting and concrete convenience types for node developers. +If cloned to the custom_nodes directory of ComfyUI, types can be imported using: + +```python +from comfy_types import IO, ComfyNodeABC, CheckLazyMixin + +class ExampleNode(ComfyNodeABC): + @classmethod + def INPUT_TYPES(s) -> InputTypeDict: + return {"required": {}} +``` + +Full example is in [examples/example_nodes.py](examples/example_nodes.py). + +# Types +A few primary types are documented below. More complete information is available via the docstrings on each type. + +## `IO` + +A string enum of built-in and a few custom data types. Includes the following special types and their requisite plumbing: + +- `ANY`: `"*"` +- `NUMBER`: `"FLOAT,INT"` +- `PRIMITIVE`: `"STRING,FLOAT,INT,BOOLEAN"` + +## `ComfyNodeABC` + +An abstract base class for nodes, offering type-hinting / autocomplete, and somewhat-alright docstrings. + +### Type hinting for `INPUT_TYPES` + +![INPUT_TYPES auto-completion in Visual Studio Code](examples/input_types.png) + +### `INPUT_TYPES` return dict + +![INPUT_TYPES return value type hinting in Visual Studio Code](examples/required_hint.png) + +### Options for individual inputs + +![INPUT_TYPES return value option auto-completion in Visual Studio Code](examples/input_options.png) diff --git a/comfy/comfy_types.py b/comfy/comfy_types/__init__.py similarity index 75% rename from comfy/comfy_types.py rename to comfy/comfy_types/__init__.py index 70cf4b158..19ec33f98 100644 --- a/comfy/comfy_types.py +++ b/comfy/comfy_types/__init__.py @@ -1,5 +1,6 @@ import torch from typing import Callable, Protocol, TypedDict, Optional, List +from .node_typing import IO, InputTypeDict, ComfyNodeABC, CheckLazyMixin class UnetApplyFunction(Protocol): @@ -30,3 +31,15 @@ class UnetParams(TypedDict): UnetWrapperFunction = Callable[[UnetApplyFunction, UnetParams], torch.Tensor] + + +__all__ = [ + "UnetWrapperFunction", + UnetApplyConds.__name__, + UnetParams.__name__, + UnetApplyFunction.__name__, + IO.__name__, + InputTypeDict.__name__, + ComfyNodeABC.__name__, + CheckLazyMixin.__name__, +] diff --git a/comfy/comfy_types/examples/example_nodes.py b/comfy/comfy_types/examples/example_nodes.py new file mode 100644 index 000000000..b6465f39e --- /dev/null +++ b/comfy/comfy_types/examples/example_nodes.py @@ -0,0 +1,28 @@ +from comfy_types import IO, ComfyNodeABC, InputTypeDict +from inspect import cleandoc + + +class ExampleNode(ComfyNodeABC): + """An example node that just adds 1 to an input integer. + + * Requires an IDE configured with analysis paths etc to be worth looking at. + * Not intended for use in ComfyUI. + """ + + DESCRIPTION = cleandoc(__doc__) + CATEGORY = "examples" + + @classmethod + def INPUT_TYPES(s) -> InputTypeDict: + return { + "required": { + "input_int": (IO.INT, {"defaultInput": True}), + } + } + + RETURN_TYPES = (IO.INT,) + RETURN_NAMES = ("input_plus_one",) + FUNCTION = "execute" + + def execute(self, input_int: int): + return (input_int + 1,) diff --git a/comfy/comfy_types/examples/input_options.png b/comfy/comfy_types/examples/input_options.png new file mode 100644 index 000000000..ac859bbc0 Binary files /dev/null and b/comfy/comfy_types/examples/input_options.png differ diff --git a/comfy/comfy_types/examples/input_types.png b/comfy/comfy_types/examples/input_types.png new file mode 100644 index 000000000..27e031ccf Binary files /dev/null and b/comfy/comfy_types/examples/input_types.png differ diff --git a/comfy/comfy_types/examples/required_hint.png b/comfy/comfy_types/examples/required_hint.png new file mode 100644 index 000000000..22c0182a0 Binary files /dev/null and b/comfy/comfy_types/examples/required_hint.png differ diff --git a/comfy/comfy_types/node_typing.py b/comfy/comfy_types/node_typing.py new file mode 100644 index 000000000..056b1aa65 --- /dev/null +++ b/comfy/comfy_types/node_typing.py @@ -0,0 +1,274 @@ +"""Comfy-specific type hinting""" + +from __future__ import annotations +from typing import Literal, TypedDict +from abc import ABC, abstractmethod +from enum import Enum + + +class StrEnum(str, Enum): + """Base class for string enums. Python's StrEnum is not available until 3.11.""" + + def __str__(self) -> str: + return self.value + + +class IO(StrEnum): + """Node input/output data types. + + Includes functionality for ``"*"`` (`ANY`) and ``"MULTI,TYPES"``. + """ + + STRING = "STRING" + IMAGE = "IMAGE" + MASK = "MASK" + LATENT = "LATENT" + BOOLEAN = "BOOLEAN" + INT = "INT" + FLOAT = "FLOAT" + CONDITIONING = "CONDITIONING" + SAMPLER = "SAMPLER" + SIGMAS = "SIGMAS" + GUIDER = "GUIDER" + NOISE = "NOISE" + CLIP = "CLIP" + CONTROL_NET = "CONTROL_NET" + VAE = "VAE" + MODEL = "MODEL" + CLIP_VISION = "CLIP_VISION" + CLIP_VISION_OUTPUT = "CLIP_VISION_OUTPUT" + STYLE_MODEL = "STYLE_MODEL" + GLIGEN = "GLIGEN" + UPSCALE_MODEL = "UPSCALE_MODEL" + AUDIO = "AUDIO" + WEBCAM = "WEBCAM" + POINT = "POINT" + FACE_ANALYSIS = "FACE_ANALYSIS" + BBOX = "BBOX" + SEGS = "SEGS" + + ANY = "*" + """Always matches any type, but at a price. + + Causes some functionality issues (e.g. reroutes, link types), and should be avoided whenever possible. + """ + NUMBER = "FLOAT,INT" + """A float or an int - could be either""" + PRIMITIVE = "STRING,FLOAT,INT,BOOLEAN" + """Could be any of: string, float, int, or bool""" + + def __ne__(self, value: object) -> bool: + if self == "*" or value == "*": + return False + if not isinstance(value, str): + return True + a = frozenset(self.split(",")) + b = frozenset(value.split(",")) + return not (b.issubset(a) or a.issubset(b)) + + +class InputTypeOptions(TypedDict): + """Provides type hinting for the return type of the INPUT_TYPES node function. + + Due to IDE limitations with unions, for now all options are available for all types (e.g. `label_on` is hinted even when the type is not `IO.BOOLEAN`). + + Comfy Docs: https://docs.comfy.org/essentials/custom_node_datatypes + """ + + default: bool | str | float | int | list | tuple + """The default value of the widget""" + defaultInput: bool + """Defaults to an input slot rather than a widget""" + forceInput: bool + """`defaultInput` and also don't allow converting to a widget""" + lazy: bool + """Declares that this input uses lazy evaluation""" + rawLink: bool + """When a link exists, rather than receiving the evaluated value, you will receive the link (i.e. `["nodeId", ]`). Designed for node expansion.""" + tooltip: str + """Tooltip for the input (or widget), shown on pointer hover""" + # class InputTypeNumber(InputTypeOptions): + # default: float | int + min: float + """The minimum value of a number (``FLOAT`` | ``INT``)""" + max: float + """The maximum value of a number (``FLOAT`` | ``INT``)""" + step: float + """The amount to increment or decrement a widget by when stepping up/down (``FLOAT`` | ``INT``)""" + round: float + """Floats are rounded by this value (``FLOAT``)""" + # class InputTypeBoolean(InputTypeOptions): + # default: bool + label_on: str + """The label to use in the UI when the bool is True (``BOOLEAN``)""" + label_on: str + """The label to use in the UI when the bool is False (``BOOLEAN``)""" + # class InputTypeString(InputTypeOptions): + # default: str + multiline: bool + """Use a multiline text box (``STRING``)""" + placeholder: str + """Placeholder text to display in the UI when empty (``STRING``)""" + # Deprecated: + # defaultVal: str + dynamicPrompts: bool + """Causes the front-end to evaluate dynamic prompts (``STRING``)""" + + +class HiddenInputTypeDict(TypedDict): + """Provides type hinting for the hidden entry of node INPUT_TYPES.""" + + node_id: Literal["UNIQUE_ID"] + """UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages).""" + unique_id: Literal["UNIQUE_ID"] + """UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages).""" + prompt: Literal["PROMPT"] + """PROMPT is the complete prompt sent by the client to the server. See the prompt object for a full description.""" + extra_pnginfo: Literal["EXTRA_PNGINFO"] + """EXTRA_PNGINFO is a dictionary that will be copied into the metadata of any .png files saved. Custom nodes can store additional information in this dictionary for saving (or as a way to communicate with a downstream node).""" + dynprompt: Literal["DYNPROMPT"] + """DYNPROMPT is an instance of comfy_execution.graph.DynamicPrompt. It differs from PROMPT in that it may mutate during the course of execution in response to Node Expansion.""" + + +class InputTypeDict(TypedDict): + """Provides type hinting for node INPUT_TYPES. + + Comfy Docs: https://docs.comfy.org/essentials/custom_node_more_on_inputs + """ + + required: dict[str, tuple[IO, InputTypeOptions]] + """Describes all inputs that must be connected for the node to execute.""" + optional: dict[str, tuple[IO, InputTypeOptions]] + """Describes inputs which do not need to be connected.""" + hidden: HiddenInputTypeDict + """Offers advanced functionality and server-client communication. + + Comfy Docs: https://docs.comfy.org/essentials/custom_node_more_on_inputs#hidden-inputs + """ + + +class ComfyNodeABC(ABC): + """Abstract base class for Comfy nodes. Includes the names and expected types of attributes. + + Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview + """ + + DESCRIPTION: str + """Node description, shown as a tooltip when hovering over the node. + + Usage:: + + # Explicitly define the description + DESCRIPTION = "Example description here." + + # Use the docstring of the node class. + DESCRIPTION = cleandoc(__doc__) + """ + CATEGORY: str + """The category of the node, as per the "Add Node" menu. + + Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#category + """ + EXPERIMENTAL: bool + """Flags a node as experimental, informing users that it may change or not work as expected.""" + DEPRECATED: bool + """Flags a node as deprecated, indicating to users that they should find alternatives to this node.""" + + @classmethod + @abstractmethod + def INPUT_TYPES(s) -> InputTypeDict: + """Defines node inputs. + + * Must include the ``required`` key, which describes all inputs that must be connected for the node to execute. + * The ``optional`` key can be added to describe inputs which do not need to be connected. + * The ``hidden`` key offers some advanced functionality. More info at: https://docs.comfy.org/essentials/custom_node_more_on_inputs#hidden-inputs + + Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#input-types + """ + return {"required": {}} + + OUTPUT_NODE: bool + """Flags this node as an output node, causing any inputs it requires to be executed. + + If a node is not connected to any output nodes, that node will not be executed. Usage:: + + OUTPUT_NODE = True + + From the docs: + + By default, a node is not considered an output. Set ``OUTPUT_NODE = True`` to specify that it is. + + Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#output-node + """ + INPUT_IS_LIST: bool + """A flag indicating if this node implements the additional code necessary to deal with OUTPUT_IS_LIST nodes. + + All inputs of ``type`` will become ``list[type]``, regardless of how many items are passed in. This also affects ``check_lazy_status``. + + From the docs: + + A node can also override the default input behaviour and receive the whole list in a single call. This is done by setting a class attribute `INPUT_IS_LIST` to ``True``. + + Comfy Docs: https://docs.comfy.org/essentials/custom_node_lists#list-processing + """ + OUTPUT_IS_LIST: tuple[bool] + """A tuple indicating which node outputs are lists, but will be connected to nodes that expect individual items. + + Connected nodes that do not implement `INPUT_IS_LIST` will be executed once for every item in the list. + + A ``tuple[bool]``, where the items match those in `RETURN_TYPES`:: + + RETURN_TYPES = (IO.INT, IO.INT, IO.STRING) + OUTPUT_IS_LIST = (True, True, False) # The string output will be handled normally + + From the docs: + + In order to tell Comfy that the list being returned should not be wrapped, but treated as a series of data for sequential processing, + the node should provide a class attribute `OUTPUT_IS_LIST`, which is a ``tuple[bool]``, of the same length as `RETURN_TYPES`, + specifying which outputs which should be so treated. + + Comfy Docs: https://docs.comfy.org/essentials/custom_node_lists#list-processing + """ + + RETURN_TYPES: tuple[IO] + """A tuple representing the outputs of this node. + + Usage:: + + RETURN_TYPES = (IO.INT, "INT", "CUSTOM_TYPE") + + Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#return-types + """ + RETURN_NAMES: tuple[str] + """The output slot names for each item in `RETURN_TYPES`, e.g. ``RETURN_NAMES = ("count", "filter_string")`` + + Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#return-names + """ + OUTPUT_TOOLTIPS: tuple[str] + """A tuple of strings to use as tooltips for node outputs, one for each item in `RETURN_TYPES`.""" + FUNCTION: str + """The name of the function to execute as a literal string, e.g. `FUNCTION = "execute"` + + Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#function + """ + + +class CheckLazyMixin: + """Provides a basic check_lazy_status implementation and type hinting for nodes that use lazy inputs.""" + + def check_lazy_status(self, **kwargs) -> list[str]: + """Returns a list of input names that should be evaluated. + + This basic mixin impl. requires all inputs. + + :kwargs: All node inputs will be included here. If the input is ``None``, it should be assumed that it has not yet been evaluated. \ + When using ``INPUT_IS_LIST = True``, unevaluated will instead be ``(None,)``. + + Params should match the nodes execution ``FUNCTION`` (self, and all inputs by name). + Will be executed repeatedly until it returns an empty list, or all requested items were already evaluated (and sent as params). + + Comfy Docs: https://docs.comfy.org/essentials/custom_node_lazy_evaluation#defining-check-lazy-status + """ + + need = [name for name in kwargs if kwargs[name] is None] + return need diff --git a/comfy_execution/validation.py b/comfy_execution/validation.py new file mode 100644 index 000000000..cec105fc9 --- /dev/null +++ b/comfy_execution/validation.py @@ -0,0 +1,39 @@ +from __future__ import annotations + + +def validate_node_input( + received_type: str, input_type: str, strict: bool = False +) -> bool: + """ + received_type and input_type are both strings of the form "T1,T2,...". + + If strict is True, the input_type must contain the received_type. + For example, if received_type is "STRING" and input_type is "STRING,INT", + this will return True. But if received_type is "STRING,INT" and input_type is + "INT", this will return False. + + If strict is False, the input_type must have overlap with the received_type. + For example, if received_type is "STRING,BOOLEAN" and input_type is "STRING,INT", + this will return True. + + Supports pre-union type extension behaviour of ``__ne__`` overrides. + """ + # If the types are exactly the same, we can return immediately + # Use pre-union behaviour: inverse of `__ne__` + if not received_type != input_type: + return True + + # Not equal, and not strings + if not isinstance(received_type, str) or not isinstance(input_type, str): + return False + + # Split the type strings into sets for comparison + received_types = set(t.strip() for t in received_type.split(",")) + input_types = set(t.strip() for t in input_type.split(",")) + + if strict: + # In strict mode, all received types must be in the input types + return received_types.issubset(input_types) + else: + # In non-strict mode, there must be at least one type in common + return len(received_types.intersection(input_types)) > 0 diff --git a/execution.py b/execution.py index 768e35abc..929ef85fa 100644 --- a/execution.py +++ b/execution.py @@ -16,6 +16,7 @@ import comfy.model_management from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker from comfy_execution.graph_utils import is_link, GraphBuilder from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID +from comfy_execution.validation import validate_node_input from comfy.cli_args import args class ExecutionResult(Enum): @@ -527,7 +528,6 @@ class PromptExecutor: comfy.model_management.unload_all_models() - def validate_inputs(prompt, item, validated): unique_id = item if unique_id in validated: @@ -589,8 +589,8 @@ def validate_inputs(prompt, item, validated): r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES received_type = r[val[1]] received_types[x] = received_type - if 'input_types' not in validate_function_inputs and received_type != type_input: - details = f"{x}, {received_type} != {type_input}" + if 'input_types' not in validate_function_inputs and not validate_node_input(received_type, type_input): + details = f"{x}, received_type({received_type}) mismatch input_type({type_input})" error = { "type": "return_type_mismatch", "message": "Return type mismatch between linked nodes", diff --git a/nodes.py b/nodes.py index 260bb5e15..1cb4b5a5a 100644 --- a/nodes.py +++ b/nodes.py @@ -1,3 +1,4 @@ +from __future__ import annotations import torch import os @@ -24,6 +25,7 @@ import comfy.sample import comfy.sd import comfy.utils import comfy.controlnet +from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict import comfy.clip_vision @@ -44,16 +46,16 @@ def interrupt_processing(value=True): MAX_RESOLUTION=16384 -class CLIPTextEncode: +class CLIPTextEncode(ComfyNodeABC): @classmethod - def INPUT_TYPES(s): + def INPUT_TYPES(s) -> InputTypeDict: return { "required": { - "text": ("STRING", {"multiline": True, "dynamicPrompts": True, "tooltip": "The text to be encoded."}), - "clip": ("CLIP", {"tooltip": "The CLIP model used for encoding the text."}) + "text": (IO.STRING, {"multiline": True, "dynamicPrompts": True, "tooltip": "The text to be encoded."}), + "clip": (IO.CLIP, {"tooltip": "The CLIP model used for encoding the text."}) } } - RETURN_TYPES = ("CONDITIONING",) + RETURN_TYPES = (IO.CONDITIONING,) OUTPUT_TOOLTIPS = ("A conditioning containing the embedded text used to guide the diffusion model.",) FUNCTION = "encode" diff --git a/tests-unit/execution_test/validate_node_input_test.py b/tests-unit/execution_test/validate_node_input_test.py new file mode 100644 index 000000000..85a0c9601 --- /dev/null +++ b/tests-unit/execution_test/validate_node_input_test.py @@ -0,0 +1,119 @@ +import pytest +from comfy_execution.validation import validate_node_input + + +def test_exact_match(): + """Test cases where types match exactly""" + assert validate_node_input("STRING", "STRING") + assert validate_node_input("STRING,INT", "STRING,INT") + assert validate_node_input("INT,STRING", "STRING,INT") # Order shouldn't matter + + +def test_strict_mode(): + """Test strict mode validation""" + # Should pass - received type is subset of input type + assert validate_node_input("STRING", "STRING,INT", strict=True) + assert validate_node_input("INT", "STRING,INT", strict=True) + assert validate_node_input("STRING,INT", "STRING,INT,BOOLEAN", strict=True) + + # Should fail - received type is not subset of input type + assert not validate_node_input("STRING,INT", "STRING", strict=True) + assert not validate_node_input("STRING,BOOLEAN", "STRING", strict=True) + assert not validate_node_input("INT,BOOLEAN", "STRING,INT", strict=True) + + +def test_non_strict_mode(): + """Test non-strict mode validation (default behavior)""" + # Should pass - types have overlap + assert validate_node_input("STRING,BOOLEAN", "STRING,INT") + assert validate_node_input("STRING,INT", "INT,BOOLEAN") + assert validate_node_input("STRING", "STRING,INT") + + # Should fail - no overlap in types + assert not validate_node_input("BOOLEAN", "STRING,INT") + assert not validate_node_input("FLOAT", "STRING,INT") + assert not validate_node_input("FLOAT,BOOLEAN", "STRING,INT") + + +def test_whitespace_handling(): + """Test that whitespace is handled correctly""" + assert validate_node_input("STRING, INT", "STRING,INT") + assert validate_node_input("STRING,INT", "STRING, INT") + assert validate_node_input(" STRING , INT ", "STRING,INT") + assert validate_node_input("STRING,INT", " STRING , INT ") + + +def test_empty_strings(): + """Test behavior with empty strings""" + assert validate_node_input("", "") + assert not validate_node_input("STRING", "") + assert not validate_node_input("", "STRING") + + +def test_single_vs_multiple(): + """Test single type against multiple types""" + assert validate_node_input("STRING", "STRING,INT,BOOLEAN") + assert validate_node_input("STRING,INT,BOOLEAN", "STRING", strict=False) + assert not validate_node_input("STRING,INT,BOOLEAN", "STRING", strict=True) + + +def test_non_string(): + """Test non-string types""" + obj1 = object() + obj2 = object() + assert validate_node_input(obj1, obj1) + assert not validate_node_input(obj1, obj2) + + +class NotEqualsOverrideTest(str): + """Test class for ``__ne__`` override.""" + + def __ne__(self, value: object) -> bool: + if self == "*" or value == "*": + return False + if self == "LONGER_THAN_2": + return not len(value) > 2 + raise TypeError("This is a class for unit tests only.") + + +def test_ne_override(): + """Test ``__ne__`` any override""" + any = NotEqualsOverrideTest("*") + invalid_type = "INVALID_TYPE" + obj = object() + assert validate_node_input(any, any) + assert validate_node_input(any, invalid_type) + assert validate_node_input(any, obj) + assert validate_node_input(any, {}) + assert validate_node_input(any, []) + assert validate_node_input(any, [1, 2, 3]) + + +def test_ne_custom_override(): + """Test ``__ne__`` custom override""" + special = NotEqualsOverrideTest("LONGER_THAN_2") + + assert validate_node_input(special, special) + assert validate_node_input(special, "*") + assert validate_node_input(special, "INVALID_TYPE") + assert validate_node_input(special, [1, 2, 3]) + + # Should fail + assert not validate_node_input(special, [1, 2]) + assert not validate_node_input(special, "TY") + + +@pytest.mark.parametrize( + "received,input_type,strict,expected", + [ + ("STRING", "STRING", False, True), + ("STRING,INT", "STRING,INT", False, True), + ("STRING", "STRING,INT", True, True), + ("STRING,INT", "STRING", True, False), + ("BOOLEAN", "STRING,INT", False, False), + ("STRING,BOOLEAN", "STRING,INT", False, True), + ], +) +def test_parametrized_cases(received, input_type, strict, expected): + """Parametrized test cases for various scenarios""" + assert validate_node_input(received, input_type, strict) == expected