diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 0b701260f..7f378da72 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -880,9 +880,9 @@ class ComboDynamic(ComfyTypeI): @comfytype(io_type="COMFY_MATCHTYPE_V3") class MatchType(ComfyTypeIO): class Template: - def __init__(self, template_id: str, allowed_types: _ComfyType | list[_ComfyType]): + def __init__(self, template_id: str, allowed_types: _ComfyType | list[_ComfyType] = AnyType): self.template_id = template_id - self.allowed_types = [allowed_types] if isinstance(allowed_types, _ComfyType) else allowed_types + self.allowed_types = [allowed_types] if issubclass(allowed_types, _ComfyType) else allowed_types def as_dict(self): return { @@ -979,6 +979,7 @@ class NodeInfoV1: output_is_list: list[bool]=None output_name: list[str]=None output_tooltips: list[str]=None + output_matchtypes: list[str]=None name: str=None display_name: str=None description: str=None @@ -1118,12 +1119,24 @@ class Schema: output_is_list = [] output_name = [] output_tooltips = [] + output_matchtypes = [] + any_matchtypes = False if self.outputs: for o in self.outputs: output.append(o.io_type) output_is_list.append(o.is_output_list) output_name.append(o.display_name if o.display_name else o.io_type) output_tooltips.append(o.tooltip if o.tooltip else None) + # special handling for MatchType + if isinstance(o, MatchType.Output): + output_matchtypes.append(o.template.template_id) + any_matchtypes = True + else: + output_matchtypes.append(None) + + # clear out lists that are all None + if not any_matchtypes: + output_matchtypes = None info = NodeInfoV1( input=input, @@ -1132,6 +1145,7 @@ class Schema: output_is_list=output_is_list, output_name=output_name, output_tooltips=output_tooltips, + output_matchtypes=output_matchtypes, name=self.node_id, display_name=self.display_name, category=self.category, @@ -1646,6 +1660,8 @@ __all__ = [ "SEGS", "AnyType", "MultiType", + # Dynamic Types + "MatchType", # Other classes "HiddenHolder", "Hidden", diff --git a/comfy_extras/nodes_logic.py b/comfy_extras/nodes_logic.py new file mode 100644 index 000000000..35bec72ea --- /dev/null +++ b/comfy_extras/nodes_logic.py @@ -0,0 +1,44 @@ +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io + + + +class SwitchNode(io.ComfyNode): + @classmethod + def define_schema(cls): + template = io.MatchType.Template("switch") + return io.Schema( + node_id="ComfySwitchNode", + display_name="Switch", + category="logic", + inputs=[ + io.Boolean.Input("switch"), + io.MatchType.Input("on_false", template=template, lazy=True), + io.MatchType.Input("on_true", template=template, lazy=True), + ], + outputs=[ + io.MatchType.Output("output", template=template, display_name="output"), + ], + ) + + @classmethod + def check_lazy_status(cls, switch, on_false=None, on_true=None): + if switch and on_true is None: + return ["on_true"] + if not switch and on_false is None: + return ["on_false"] + + @classmethod + def execute(cls, switch, on_true, on_false) -> io.NodeOutput: + return io.NodeOutput(on_true if switch else on_false) + + +class LogicExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + SwitchNode, + ] + +async def comfy_entrypoint() -> LogicExtension: + return LogicExtension() diff --git a/nodes.py b/nodes.py index 5689f6fe1..6c4783fa4 100644 --- a/nodes.py +++ b/nodes.py @@ -2330,6 +2330,7 @@ async def init_builtin_extra_nodes(): "nodes_easycache.py", "nodes_audio_encoder.py", "nodes_rope.py", + "nodes_logic.py", ] import_failed = []