mirror of
https://git.datalinker.icu/comfyanonymous/ComfyUI
synced 2025-12-09 22:14:34 +08:00
[fix] Fixes non-async public API access (#10857)
It looks like the synchronous version of the public API broke due to an addition of `from __future__ import annotations`. This change updates the async-to-sync adapter to work with both types of type annotations.
This commit is contained in:
parent
cbd68e3d58
commit
f66183a541
@ -8,7 +8,7 @@ import os
|
|||||||
import textwrap
|
import textwrap
|
||||||
import threading
|
import threading
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Optional, Type, get_origin, get_args
|
from typing import Optional, Type, get_origin, get_args, get_type_hints
|
||||||
|
|
||||||
|
|
||||||
class TypeTracker:
|
class TypeTracker:
|
||||||
@ -220,7 +220,14 @@ class AsyncToSyncConverter:
|
|||||||
self._async_instance = async_class(*args, **kwargs)
|
self._async_instance = async_class(*args, **kwargs)
|
||||||
|
|
||||||
# Handle annotated class attributes (like execution: Execution)
|
# Handle annotated class attributes (like execution: Execution)
|
||||||
# Get all annotations from the class hierarchy
|
# Get all annotations from the class hierarchy and resolve string annotations
|
||||||
|
try:
|
||||||
|
# get_type_hints resolves string annotations to actual type objects
|
||||||
|
# This handles classes using 'from __future__ import annotations'
|
||||||
|
all_annotations = get_type_hints(async_class)
|
||||||
|
except Exception:
|
||||||
|
# Fallback to raw annotations if get_type_hints fails
|
||||||
|
# (e.g., for undefined forward references)
|
||||||
all_annotations = {}
|
all_annotations = {}
|
||||||
for base_class in reversed(inspect.getmro(async_class)):
|
for base_class in reversed(inspect.getmro(async_class)):
|
||||||
if hasattr(base_class, "__annotations__"):
|
if hasattr(base_class, "__annotations__"):
|
||||||
@ -625,15 +632,19 @@ class AsyncToSyncConverter:
|
|||||||
"""Extract class attributes that are classes themselves."""
|
"""Extract class attributes that are classes themselves."""
|
||||||
class_attributes = []
|
class_attributes = []
|
||||||
|
|
||||||
|
# Get resolved type hints to handle string annotations
|
||||||
|
try:
|
||||||
|
type_hints = get_type_hints(async_class)
|
||||||
|
except Exception:
|
||||||
|
type_hints = {}
|
||||||
|
|
||||||
# Look for class attributes that are classes
|
# Look for class attributes that are classes
|
||||||
for name, attr in sorted(inspect.getmembers(async_class)):
|
for name, attr in sorted(inspect.getmembers(async_class)):
|
||||||
if isinstance(attr, type) and not name.startswith("_"):
|
if isinstance(attr, type) and not name.startswith("_"):
|
||||||
class_attributes.append((name, attr))
|
class_attributes.append((name, attr))
|
||||||
elif (
|
elif name in type_hints:
|
||||||
hasattr(async_class, "__annotations__")
|
# Use resolved type hint instead of raw annotation
|
||||||
and name in async_class.__annotations__
|
annotation = type_hints[name]
|
||||||
):
|
|
||||||
annotation = async_class.__annotations__[name]
|
|
||||||
if isinstance(annotation, type):
|
if isinstance(annotation, type):
|
||||||
class_attributes.append((name, annotation))
|
class_attributes.append((name, annotation))
|
||||||
|
|
||||||
@ -908,7 +919,11 @@ class AsyncToSyncConverter:
|
|||||||
attribute_mappings = {}
|
attribute_mappings = {}
|
||||||
|
|
||||||
# First check annotations for typed attributes (including from parent classes)
|
# First check annotations for typed attributes (including from parent classes)
|
||||||
# Collect all annotations from the class hierarchy
|
# Resolve string annotations to actual types
|
||||||
|
try:
|
||||||
|
all_annotations = get_type_hints(async_class)
|
||||||
|
except Exception:
|
||||||
|
# Fallback to raw annotations
|
||||||
all_annotations = {}
|
all_annotations = {}
|
||||||
for base_class in reversed(inspect.getmro(async_class)):
|
for base_class in reversed(inspect.getmro(async_class)):
|
||||||
if hasattr(base_class, "__annotations__"):
|
if hasattr(base_class, "__annotations__"):
|
||||||
|
|||||||
153
tests/execution/test_public_api.py
Normal file
153
tests/execution/test_public_api.py
Normal file
@ -0,0 +1,153 @@
|
|||||||
|
"""
|
||||||
|
Tests for public ComfyAPI and ComfyAPISync functions.
|
||||||
|
|
||||||
|
These tests verify that the public API methods work correctly in both sync and async contexts,
|
||||||
|
ensuring that the sync wrapper generation (via get_type_hints() in async_to_sync.py) correctly
|
||||||
|
handles string annotations from 'from __future__ import annotations'.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import time
|
||||||
|
import subprocess
|
||||||
|
import torch
|
||||||
|
from pytest import fixture
|
||||||
|
from comfy_execution.graph_utils import GraphBuilder
|
||||||
|
from tests.execution.test_execution import ComfyClient
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.execution
|
||||||
|
class TestPublicAPI:
|
||||||
|
"""Test suite for public ComfyAPI and ComfyAPISync methods."""
|
||||||
|
|
||||||
|
@fixture(scope="class", autouse=True)
|
||||||
|
def _server(self, args_pytest):
|
||||||
|
"""Start ComfyUI server for testing."""
|
||||||
|
pargs = [
|
||||||
|
'python', 'main.py',
|
||||||
|
'--output-directory', args_pytest["output_dir"],
|
||||||
|
'--listen', args_pytest["listen"],
|
||||||
|
'--port', str(args_pytest["port"]),
|
||||||
|
'--extra-model-paths-config', 'tests/execution/extra_model_paths.yaml',
|
||||||
|
'--cpu',
|
||||||
|
]
|
||||||
|
p = subprocess.Popen(pargs)
|
||||||
|
yield
|
||||||
|
p.kill()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
@fixture(scope="class", autouse=True)
|
||||||
|
def shared_client(self, args_pytest, _server):
|
||||||
|
"""Create shared client with connection retry."""
|
||||||
|
client = ComfyClient()
|
||||||
|
n_tries = 5
|
||||||
|
for i in range(n_tries):
|
||||||
|
time.sleep(4)
|
||||||
|
try:
|
||||||
|
client.connect(listen=args_pytest["listen"], port=args_pytest["port"])
|
||||||
|
break
|
||||||
|
except ConnectionRefusedError:
|
||||||
|
if i == n_tries - 1:
|
||||||
|
raise
|
||||||
|
yield client
|
||||||
|
del client
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
@fixture
|
||||||
|
def client(self, shared_client, request):
|
||||||
|
"""Set test name for each test."""
|
||||||
|
shared_client.set_test_name(f"public_api[{request.node.name}]")
|
||||||
|
yield shared_client
|
||||||
|
|
||||||
|
@fixture
|
||||||
|
def builder(self, request):
|
||||||
|
"""Create GraphBuilder for each test."""
|
||||||
|
yield GraphBuilder(prefix=request.node.name)
|
||||||
|
|
||||||
|
def test_sync_progress_update_executes(self, client: ComfyClient, builder: GraphBuilder):
|
||||||
|
"""Test that TestSyncProgressUpdate executes without errors.
|
||||||
|
|
||||||
|
This test validates that api_sync.execution.set_progress() works correctly,
|
||||||
|
which is the primary code path fixed by adding get_type_hints() to async_to_sync.py.
|
||||||
|
"""
|
||||||
|
g = builder
|
||||||
|
image = g.node("StubImage", content="BLACK", height=256, width=256, batch_size=1)
|
||||||
|
|
||||||
|
# Use TestSyncProgressUpdate with short sleep
|
||||||
|
progress_node = g.node("TestSyncProgressUpdate",
|
||||||
|
value=image.out(0),
|
||||||
|
sleep_seconds=0.5)
|
||||||
|
output = g.node("SaveImage", images=progress_node.out(0))
|
||||||
|
|
||||||
|
# Execute workflow
|
||||||
|
result = client.run(g)
|
||||||
|
|
||||||
|
# Verify execution
|
||||||
|
assert result.did_run(progress_node), "Progress node should have executed"
|
||||||
|
assert result.did_run(output), "Output node should have executed"
|
||||||
|
|
||||||
|
# Verify output
|
||||||
|
images = result.get_images(output)
|
||||||
|
assert len(images) == 1, "Should have produced 1 image"
|
||||||
|
|
||||||
|
def test_async_progress_update_executes(self, client: ComfyClient, builder: GraphBuilder):
|
||||||
|
"""Test that TestAsyncProgressUpdate executes without errors.
|
||||||
|
|
||||||
|
This test validates that await api.execution.set_progress() works correctly
|
||||||
|
in async contexts.
|
||||||
|
"""
|
||||||
|
g = builder
|
||||||
|
image = g.node("StubImage", content="WHITE", height=256, width=256, batch_size=1)
|
||||||
|
|
||||||
|
# Use TestAsyncProgressUpdate with short sleep
|
||||||
|
progress_node = g.node("TestAsyncProgressUpdate",
|
||||||
|
value=image.out(0),
|
||||||
|
sleep_seconds=0.5)
|
||||||
|
output = g.node("SaveImage", images=progress_node.out(0))
|
||||||
|
|
||||||
|
# Execute workflow
|
||||||
|
result = client.run(g)
|
||||||
|
|
||||||
|
# Verify execution
|
||||||
|
assert result.did_run(progress_node), "Async progress node should have executed"
|
||||||
|
assert result.did_run(output), "Output node should have executed"
|
||||||
|
|
||||||
|
# Verify output
|
||||||
|
images = result.get_images(output)
|
||||||
|
assert len(images) == 1, "Should have produced 1 image"
|
||||||
|
|
||||||
|
def test_sync_and_async_progress_together(self, client: ComfyClient, builder: GraphBuilder):
|
||||||
|
"""Test both sync and async progress updates in same workflow.
|
||||||
|
|
||||||
|
This test ensures that both ComfyAPISync and ComfyAPI can coexist and work
|
||||||
|
correctly in the same workflow execution.
|
||||||
|
"""
|
||||||
|
g = builder
|
||||||
|
image1 = g.node("StubImage", content="BLACK", height=256, width=256, batch_size=1)
|
||||||
|
image2 = g.node("StubImage", content="WHITE", height=256, width=256, batch_size=1)
|
||||||
|
|
||||||
|
# Use both types of progress nodes
|
||||||
|
sync_progress = g.node("TestSyncProgressUpdate",
|
||||||
|
value=image1.out(0),
|
||||||
|
sleep_seconds=0.3)
|
||||||
|
async_progress = g.node("TestAsyncProgressUpdate",
|
||||||
|
value=image2.out(0),
|
||||||
|
sleep_seconds=0.3)
|
||||||
|
|
||||||
|
# Create outputs
|
||||||
|
output1 = g.node("SaveImage", images=sync_progress.out(0))
|
||||||
|
output2 = g.node("SaveImage", images=async_progress.out(0))
|
||||||
|
|
||||||
|
# Execute workflow
|
||||||
|
result = client.run(g)
|
||||||
|
|
||||||
|
# Both should execute successfully
|
||||||
|
assert result.did_run(sync_progress), "Sync progress node should have executed"
|
||||||
|
assert result.did_run(async_progress), "Async progress node should have executed"
|
||||||
|
assert result.did_run(output1), "First output node should have executed"
|
||||||
|
assert result.did_run(output2), "Second output node should have executed"
|
||||||
|
|
||||||
|
# Verify outputs
|
||||||
|
images1 = result.get_images(output1)
|
||||||
|
images2 = result.get_images(output2)
|
||||||
|
assert len(images1) == 1, "Should have produced 1 image from sync node"
|
||||||
|
assert len(images2) == 1, "Should have produced 1 image from async node"
|
||||||
Loading…
x
Reference in New Issue
Block a user