diff --git a/scanner.py b/scanner.py index 711013b5..36e59323 100644 --- a/scanner.py +++ b/scanner.py @@ -7,6 +7,7 @@ import concurrent import datetime import concurrent.futures import requests +import warnings builtin_nodes = set() @@ -74,10 +75,13 @@ def extract_nodes(code_text): parse_cnt += 1 code_text = re.sub(r'\\[^"\']', '', code_text) - parsed_code = ast.parse(code_text) + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', category=SyntaxWarning) + warnings.filterwarnings('ignore', category=DeprecationWarning) + parsed_code = ast.parse(code_text) assignments = (node for node in parsed_code.body if isinstance(node, ast.Assign)) - + for assignment in assignments: if isinstance(assignment.targets[0], ast.Name) and assignment.targets[0].id in ['NODE_CONFIG', 'NODE_CLASS_MAPPINGS']: node_class_mappings = assignment.value @@ -91,7 +95,7 @@ def extract_nodes(code_text): for key in node_class_mappings.keys: if key is not None and isinstance(key.value, str): s.add(key.value.strip()) - + return s else: return set() @@ -99,6 +103,99 @@ def extract_nodes(code_text): return set() +def has_comfy_node_base(class_node): + """Check if class inherits from io.ComfyNode or ComfyNode""" + for base in class_node.bases: + # Case 1: ComfyNode + if isinstance(base, ast.Name) and base.id == 'ComfyNode': + return True + # Case 2: io.ComfyNode + elif isinstance(base, ast.Attribute): + if base.attr == 'ComfyNode': + return True + return False + + +def extract_keyword_value(call_node, keyword): + """ + Extract string value of keyword argument + Schema(node_id="MyNode") -> "MyNode" + """ + for kw in call_node.keywords: + if kw.arg == keyword: + # ast.Constant (Python 3.8+) + if isinstance(kw.value, ast.Constant): + if isinstance(kw.value.value, str): + return kw.value.value + # ast.Str (Python 3.7-) - suppress deprecation warning + else: + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', category=DeprecationWarning) + if hasattr(ast, 'Str') and isinstance(kw.value, ast.Str): + return kw.value.s + return None + + +def is_schema_call(call_node): + """Check if ast.Call is io.Schema() or Schema()""" + func = call_node.func + if isinstance(func, ast.Name) and func.id == 'Schema': + return True + elif isinstance(func, ast.Attribute) and func.attr == 'Schema': + return True + return False + + +def extract_node_id_from_schema(class_node): + """ + Extract node_id from define_schema() method + """ + for item in class_node.body: + if isinstance(item, ast.FunctionDef) and item.name == 'define_schema': + # Walk through function body + for stmt in ast.walk(item): + if isinstance(stmt, ast.Call): + # Check if it's Schema() call + if is_schema_call(stmt): + node_id = extract_keyword_value(stmt, 'node_id') + if node_id: + return node_id + return None + + +def extract_v3_nodes(code_text): + """ + Extract V3 node IDs using AST parsing + Returns: set of node_id strings + """ + global parse_cnt + + try: + if parse_cnt % 100 == 0: + print(".", end="", flush=True) + parse_cnt += 1 + + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', category=SyntaxWarning) + warnings.filterwarnings('ignore', category=DeprecationWarning) + tree = ast.parse(code_text) + except (SyntaxError, UnicodeDecodeError): + return set() + + nodes = set() + + # Find io.ComfyNode subclasses + for node in ast.walk(tree): + if isinstance(node, ast.ClassDef): + # Check if inherits from ComfyNode + if has_comfy_node_base(node): + node_id = extract_node_id_from_schema(node) + if node_id: + nodes.add(node_id) + + return nodes + + # scan def scan_in_file(filename, is_builtin=False): global builtin_nodes @@ -112,7 +209,11 @@ def scan_in_file(filename, is_builtin=False): nodes = set() class_dict = {} + # V1 nodes detection nodes |= extract_nodes(code) + + # V3 nodes detection + nodes |= extract_v3_nodes(code) code = re.sub(r'^#.*?$', '', code, flags=re.MULTILINE) def extract_keys(pattern, code):