[fix] Fixes non-async public API access (#10857)

It looks like the synchronous version of the public API broke due to the
addition of `from __future__ import annotations`. This change updates
the async-to-sync adapter to work with both styles of type annotation.
This commit is contained in:
guill 2025-11-23 22:56:20 -08:00 committed by GitHub
parent cbd68e3d58
commit f66183a541
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 184 additions and 16 deletions

View File

@ -8,7 +8,7 @@ import os
import textwrap
import threading
from enum import Enum
from typing import Optional, Type, get_origin, get_args
from typing import Optional, Type, get_origin, get_args, get_type_hints
class TypeTracker:
@ -220,7 +220,14 @@ class AsyncToSyncConverter:
self._async_instance = async_class(*args, **kwargs)
# Handle annotated class attributes (like execution: Execution)
# Get all annotations from the class hierarchy
# Get all annotations from the class hierarchy and resolve string annotations
try:
# get_type_hints resolves string annotations to actual type objects
# This handles classes using 'from __future__ import annotations'
all_annotations = get_type_hints(async_class)
except Exception:
# Fallback to raw annotations if get_type_hints fails
# (e.g., for undefined forward references)
all_annotations = {}
for base_class in reversed(inspect.getmro(async_class)):
if hasattr(base_class, "__annotations__"):
@ -625,15 +632,19 @@ class AsyncToSyncConverter:
"""Extract class attributes that are classes themselves."""
class_attributes = []
# Get resolved type hints to handle string annotations
try:
type_hints = get_type_hints(async_class)
except Exception:
type_hints = {}
# Look for class attributes that are classes
for name, attr in sorted(inspect.getmembers(async_class)):
if isinstance(attr, type) and not name.startswith("_"):
class_attributes.append((name, attr))
elif (
hasattr(async_class, "__annotations__")
and name in async_class.__annotations__
):
annotation = async_class.__annotations__[name]
elif name in type_hints:
# Use resolved type hint instead of raw annotation
annotation = type_hints[name]
if isinstance(annotation, type):
class_attributes.append((name, annotation))
@ -908,7 +919,11 @@ class AsyncToSyncConverter:
attribute_mappings = {}
# First check annotations for typed attributes (including from parent classes)
# Collect all annotations from the class hierarchy
# Resolve string annotations to actual types
try:
all_annotations = get_type_hints(async_class)
except Exception:
# Fallback to raw annotations
all_annotations = {}
for base_class in reversed(inspect.getmro(async_class)):
if hasattr(base_class, "__annotations__"):

View File

@ -0,0 +1,153 @@
"""
Tests for public ComfyAPI and ComfyAPISync functions.
These tests verify that the public API methods work correctly in both sync and async contexts,
ensuring that the sync wrapper generation (via get_type_hints() in async_to_sync.py) correctly
handles string annotations from 'from __future__ import annotations'.
"""
import pytest
import time
import subprocess
import torch
from pytest import fixture
from comfy_execution.graph_utils import GraphBuilder
from tests.execution.test_execution import ComfyClient
@pytest.mark.execution
class TestPublicAPI:
    """Test suite for public ComfyAPI and ComfyAPISync methods.

    Spins up a real ComfyUI server subprocess once per class and runs
    workflows through it via ``ComfyClient``, exercising both the sync and
    async progress-update code paths.
    """

    @fixture(scope="class", autouse=True)
    def _server(self, args_pytest):
        """Start a ComfyUI server subprocess for the duration of the class.

        Yields with the server running; on teardown the process is killed
        and reaped, and CUDA cache is released.
        """
        pargs = [
            'python', 'main.py',
            '--output-directory', args_pytest["output_dir"],
            '--listen', args_pytest["listen"],
            '--port', str(args_pytest["port"]),
            '--extra-model-paths-config', 'tests/execution/extra_model_paths.yaml',
            '--cpu',
        ]
        p = subprocess.Popen(pargs)
        yield
        p.kill()
        # Reap the child so it does not linger as a zombie process after kill().
        p.wait()
        torch.cuda.empty_cache()

    @fixture(scope="class", autouse=True)
    def shared_client(self, args_pytest, _server):
        """Create a shared client, retrying while the server boots.

        Attempts to connect up to 5 times with a 4s pause before each try;
        the last failure is re-raised so the class errors out clearly.
        """
        client = ComfyClient()
        n_tries = 5
        for i in range(n_tries):
            time.sleep(4)
            try:
                client.connect(listen=args_pytest["listen"], port=args_pytest["port"])
                break
            except ConnectionRefusedError:
                if i == n_tries - 1:
                    raise
        yield client
        del client
        torch.cuda.empty_cache()

    @fixture
    def client(self, shared_client, request):
        """Tag the shared client with the current test name, then hand it out."""
        shared_client.set_test_name(f"public_api[{request.node.name}]")
        yield shared_client

    @fixture
    def builder(self, request):
        """Create a fresh GraphBuilder per test, prefixed with the test name."""
        yield GraphBuilder(prefix=request.node.name)

    def test_sync_progress_update_executes(self, client: ComfyClient, builder: GraphBuilder):
        """Test that TestSyncProgressUpdate executes without errors.

        This test validates that api_sync.execution.set_progress() works correctly,
        which is the primary code path fixed by adding get_type_hints() to async_to_sync.py.
        """
        g = builder
        image = g.node("StubImage", content="BLACK", height=256, width=256, batch_size=1)
        # Use TestSyncProgressUpdate with short sleep
        progress_node = g.node("TestSyncProgressUpdate",
                               value=image.out(0),
                               sleep_seconds=0.5)
        output = g.node("SaveImage", images=progress_node.out(0))

        # Execute workflow
        result = client.run(g)

        # Verify execution
        assert result.did_run(progress_node), "Progress node should have executed"
        assert result.did_run(output), "Output node should have executed"

        # Verify output
        images = result.get_images(output)
        assert len(images) == 1, "Should have produced 1 image"

    def test_async_progress_update_executes(self, client: ComfyClient, builder: GraphBuilder):
        """Test that TestAsyncProgressUpdate executes without errors.

        This test validates that await api.execution.set_progress() works correctly
        in async contexts.
        """
        g = builder
        image = g.node("StubImage", content="WHITE", height=256, width=256, batch_size=1)
        # Use TestAsyncProgressUpdate with short sleep
        progress_node = g.node("TestAsyncProgressUpdate",
                               value=image.out(0),
                               sleep_seconds=0.5)
        output = g.node("SaveImage", images=progress_node.out(0))

        # Execute workflow
        result = client.run(g)

        # Verify execution
        assert result.did_run(progress_node), "Async progress node should have executed"
        assert result.did_run(output), "Output node should have executed"

        # Verify output
        images = result.get_images(output)
        assert len(images) == 1, "Should have produced 1 image"

    def test_sync_and_async_progress_together(self, client: ComfyClient, builder: GraphBuilder):
        """Test both sync and async progress updates in same workflow.

        This test ensures that both ComfyAPISync and ComfyAPI can coexist and work
        correctly in the same workflow execution.
        """
        g = builder
        image1 = g.node("StubImage", content="BLACK", height=256, width=256, batch_size=1)
        image2 = g.node("StubImage", content="WHITE", height=256, width=256, batch_size=1)

        # Use both types of progress nodes
        sync_progress = g.node("TestSyncProgressUpdate",
                               value=image1.out(0),
                               sleep_seconds=0.3)
        async_progress = g.node("TestAsyncProgressUpdate",
                                value=image2.out(0),
                                sleep_seconds=0.3)

        # Create outputs
        output1 = g.node("SaveImage", images=sync_progress.out(0))
        output2 = g.node("SaveImage", images=async_progress.out(0))

        # Execute workflow
        result = client.run(g)

        # Both should execute successfully
        assert result.did_run(sync_progress), "Sync progress node should have executed"
        assert result.did_run(async_progress), "Async progress node should have executed"
        assert result.did_run(output1), "First output node should have executed"
        assert result.did_run(output2), "Second output node should have executed"

        # Verify outputs
        images1 = result.get_images(output1)
        images2 = result.get_images(output2)
        assert len(images1) == 1, "Should have produced 1 image from sync node"
        assert len(images2) == 1, "Should have produced 1 image from async node"