mirror of
https://git.datalinker.icu/comfyanonymous/ComfyUI
synced 2025-12-09 05:54:24 +08:00
[fix] Fixes non-async public API access (#10857)
It looks like the synchronous version of the public API broke due to an addition of `from __future__ import annotations`. This change updates the async-to-sync adapter to work with both types of type annotations.
This commit is contained in:
parent
cbd68e3d58
commit
f66183a541
@ -8,7 +8,7 @@ import os
|
||||
import textwrap
|
||||
import threading
|
||||
from enum import Enum
|
||||
from typing import Optional, Type, get_origin, get_args
|
||||
from typing import Optional, Type, get_origin, get_args, get_type_hints
|
||||
|
||||
|
||||
class TypeTracker:
|
||||
@ -220,7 +220,14 @@ class AsyncToSyncConverter:
|
||||
self._async_instance = async_class(*args, **kwargs)
|
||||
|
||||
# Handle annotated class attributes (like execution: Execution)
|
||||
# Get all annotations from the class hierarchy
|
||||
# Get all annotations from the class hierarchy and resolve string annotations
|
||||
try:
|
||||
# get_type_hints resolves string annotations to actual type objects
|
||||
# This handles classes using 'from __future__ import annotations'
|
||||
all_annotations = get_type_hints(async_class)
|
||||
except Exception:
|
||||
# Fallback to raw annotations if get_type_hints fails
|
||||
# (e.g., for undefined forward references)
|
||||
all_annotations = {}
|
||||
for base_class in reversed(inspect.getmro(async_class)):
|
||||
if hasattr(base_class, "__annotations__"):
|
||||
@ -625,15 +632,19 @@ class AsyncToSyncConverter:
|
||||
"""Extract class attributes that are classes themselves."""
|
||||
class_attributes = []
|
||||
|
||||
# Get resolved type hints to handle string annotations
|
||||
try:
|
||||
type_hints = get_type_hints(async_class)
|
||||
except Exception:
|
||||
type_hints = {}
|
||||
|
||||
# Look for class attributes that are classes
|
||||
for name, attr in sorted(inspect.getmembers(async_class)):
|
||||
if isinstance(attr, type) and not name.startswith("_"):
|
||||
class_attributes.append((name, attr))
|
||||
elif (
|
||||
hasattr(async_class, "__annotations__")
|
||||
and name in async_class.__annotations__
|
||||
):
|
||||
annotation = async_class.__annotations__[name]
|
||||
elif name in type_hints:
|
||||
# Use resolved type hint instead of raw annotation
|
||||
annotation = type_hints[name]
|
||||
if isinstance(annotation, type):
|
||||
class_attributes.append((name, annotation))
|
||||
|
||||
@ -908,7 +919,11 @@ class AsyncToSyncConverter:
|
||||
attribute_mappings = {}
|
||||
|
||||
# First check annotations for typed attributes (including from parent classes)
|
||||
# Collect all annotations from the class hierarchy
|
||||
# Resolve string annotations to actual types
|
||||
try:
|
||||
all_annotations = get_type_hints(async_class)
|
||||
except Exception:
|
||||
# Fallback to raw annotations
|
||||
all_annotations = {}
|
||||
for base_class in reversed(inspect.getmro(async_class)):
|
||||
if hasattr(base_class, "__annotations__"):
|
||||
|
||||
153
tests/execution/test_public_api.py
Normal file
153
tests/execution/test_public_api.py
Normal file
@ -0,0 +1,153 @@
|
||||
"""
|
||||
Tests for public ComfyAPI and ComfyAPISync functions.
|
||||
|
||||
These tests verify that the public API methods work correctly in both sync and async contexts,
|
||||
ensuring that the sync wrapper generation (via get_type_hints() in async_to_sync.py) correctly
|
||||
handles string annotations from 'from __future__ import annotations'.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import time
|
||||
import subprocess
|
||||
import torch
|
||||
from pytest import fixture
|
||||
from comfy_execution.graph_utils import GraphBuilder
|
||||
from tests.execution.test_execution import ComfyClient
|
||||
|
||||
|
||||
@pytest.mark.execution
|
||||
class TestPublicAPI:
|
||||
"""Test suite for public ComfyAPI and ComfyAPISync methods."""
|
||||
|
||||
@fixture(scope="class", autouse=True)
|
||||
def _server(self, args_pytest):
|
||||
"""Start ComfyUI server for testing."""
|
||||
pargs = [
|
||||
'python', 'main.py',
|
||||
'--output-directory', args_pytest["output_dir"],
|
||||
'--listen', args_pytest["listen"],
|
||||
'--port', str(args_pytest["port"]),
|
||||
'--extra-model-paths-config', 'tests/execution/extra_model_paths.yaml',
|
||||
'--cpu',
|
||||
]
|
||||
p = subprocess.Popen(pargs)
|
||||
yield
|
||||
p.kill()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@fixture(scope="class", autouse=True)
|
||||
def shared_client(self, args_pytest, _server):
|
||||
"""Create shared client with connection retry."""
|
||||
client = ComfyClient()
|
||||
n_tries = 5
|
||||
for i in range(n_tries):
|
||||
time.sleep(4)
|
||||
try:
|
||||
client.connect(listen=args_pytest["listen"], port=args_pytest["port"])
|
||||
break
|
||||
except ConnectionRefusedError:
|
||||
if i == n_tries - 1:
|
||||
raise
|
||||
yield client
|
||||
del client
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@fixture
|
||||
def client(self, shared_client, request):
|
||||
"""Set test name for each test."""
|
||||
shared_client.set_test_name(f"public_api[{request.node.name}]")
|
||||
yield shared_client
|
||||
|
||||
@fixture
|
||||
def builder(self, request):
|
||||
"""Create GraphBuilder for each test."""
|
||||
yield GraphBuilder(prefix=request.node.name)
|
||||
|
||||
def test_sync_progress_update_executes(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test that TestSyncProgressUpdate executes without errors.
|
||||
|
||||
This test validates that api_sync.execution.set_progress() works correctly,
|
||||
which is the primary code path fixed by adding get_type_hints() to async_to_sync.py.
|
||||
"""
|
||||
g = builder
|
||||
image = g.node("StubImage", content="BLACK", height=256, width=256, batch_size=1)
|
||||
|
||||
# Use TestSyncProgressUpdate with short sleep
|
||||
progress_node = g.node("TestSyncProgressUpdate",
|
||||
value=image.out(0),
|
||||
sleep_seconds=0.5)
|
||||
output = g.node("SaveImage", images=progress_node.out(0))
|
||||
|
||||
# Execute workflow
|
||||
result = client.run(g)
|
||||
|
||||
# Verify execution
|
||||
assert result.did_run(progress_node), "Progress node should have executed"
|
||||
assert result.did_run(output), "Output node should have executed"
|
||||
|
||||
# Verify output
|
||||
images = result.get_images(output)
|
||||
assert len(images) == 1, "Should have produced 1 image"
|
||||
|
||||
def test_async_progress_update_executes(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test that TestAsyncProgressUpdate executes without errors.
|
||||
|
||||
This test validates that await api.execution.set_progress() works correctly
|
||||
in async contexts.
|
||||
"""
|
||||
g = builder
|
||||
image = g.node("StubImage", content="WHITE", height=256, width=256, batch_size=1)
|
||||
|
||||
# Use TestAsyncProgressUpdate with short sleep
|
||||
progress_node = g.node("TestAsyncProgressUpdate",
|
||||
value=image.out(0),
|
||||
sleep_seconds=0.5)
|
||||
output = g.node("SaveImage", images=progress_node.out(0))
|
||||
|
||||
# Execute workflow
|
||||
result = client.run(g)
|
||||
|
||||
# Verify execution
|
||||
assert result.did_run(progress_node), "Async progress node should have executed"
|
||||
assert result.did_run(output), "Output node should have executed"
|
||||
|
||||
# Verify output
|
||||
images = result.get_images(output)
|
||||
assert len(images) == 1, "Should have produced 1 image"
|
||||
|
||||
def test_sync_and_async_progress_together(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test both sync and async progress updates in same workflow.
|
||||
|
||||
This test ensures that both ComfyAPISync and ComfyAPI can coexist and work
|
||||
correctly in the same workflow execution.
|
||||
"""
|
||||
g = builder
|
||||
image1 = g.node("StubImage", content="BLACK", height=256, width=256, batch_size=1)
|
||||
image2 = g.node("StubImage", content="WHITE", height=256, width=256, batch_size=1)
|
||||
|
||||
# Use both types of progress nodes
|
||||
sync_progress = g.node("TestSyncProgressUpdate",
|
||||
value=image1.out(0),
|
||||
sleep_seconds=0.3)
|
||||
async_progress = g.node("TestAsyncProgressUpdate",
|
||||
value=image2.out(0),
|
||||
sleep_seconds=0.3)
|
||||
|
||||
# Create outputs
|
||||
output1 = g.node("SaveImage", images=sync_progress.out(0))
|
||||
output2 = g.node("SaveImage", images=async_progress.out(0))
|
||||
|
||||
# Execute workflow
|
||||
result = client.run(g)
|
||||
|
||||
# Both should execute successfully
|
||||
assert result.did_run(sync_progress), "Sync progress node should have executed"
|
||||
assert result.did_run(async_progress), "Async progress node should have executed"
|
||||
assert result.did_run(output1), "First output node should have executed"
|
||||
assert result.did_run(output2), "Second output node should have executed"
|
||||
|
||||
# Verify outputs
|
||||
images1 = result.get_images(output1)
|
||||
images2 = result.get_images(output2)
|
||||
assert len(images1) == 1, "Should have produced 1 image from sync node"
|
||||
assert len(images2) == 1, "Should have produced 1 image from async node"
|
||||
Loading…
x
Reference in New Issue
Block a user