Mirror of https://git.datalinker.icu/comfyanonymous/ComfyUI (synced 2025-12-08 21:44:33 +08:00)
feat(API-nodes): move Rodin3D nodes to new client; removed old api client.py (#10645)
This commit is contained in:
parent c4a6b389de
commit bda0eb2448
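Besides deleting the generated Pixverse models and the legacy client framework (imported as comfy_api_nodes.apis.client: ApiClient, ApiEndpoint, SynchronousOperation, PollingOperation), this commit ports the Rodin3D nodes to the sync_op/poll_op helpers in comfy_api_nodes.util. The per-call auth_kwargs plumbing is gone: the helpers now take the node class (cls) and, judging by the hunks below, resolve the hidden auth fields themselves; IO.Hidden.unique_id is added to every Rodin node's hidden list, presumably so the helpers can report progress to the node. A rough, runnable sketch of the calling-convention change (everything here is stubbed; the real helpers are not part of this diff):

import asyncio
from dataclasses import dataclass


@dataclass
class ApiEndpoint:
    # Matches the shape used in the new-style calls below, e.g.
    # ApiEndpoint(path="/proxy/rodin/api/v2/rodin", method="POST").
    path: str
    method: str


async def sync_op(cls, endpoint: ApiEndpoint, response_model=None, data=None, **kwargs):
    # Stub standing in for comfy_api_nodes.util.sync_op; the real helper
    # issues the HTTP call and validates the response against response_model.
    return {"node": cls.__name__, "called": f"{endpoint.method} {endpoint.path}"}


class DummyNode:
    """Stand-in for an IO.ComfyNode subclass carrying hidden auth fields."""


# Old style (removed below): op = SynchronousOperation(endpoint=..., request=...,
# auth_kwargs={"auth_token": ..., "comfy_api_key": ...}); await op.execute()
# New style: one call, with auth resolved from the node class.
print(asyncio.run(sync_op(DummyNode, ApiEndpoint("/proxy/rodin/api/v2/rodin", "POST"))))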
@@ -1,17 +0,0 @@
# generated by datamodel-codegen:
#   filename:  filtered-openapi.yaml
#   timestamp: 2025-04-29T23:44:54+00:00

from __future__ import annotations

from typing import Optional

from pydantic import BaseModel

from . import PixverseDto


class ResponseData(BaseModel):
    ErrCode: Optional[int] = None
    ErrMsg: Optional[str] = None
    Resp: Optional[PixverseDto.V2OpenAPII2VResp] = None
@@ -1,57 +0,0 @@
# generated by datamodel-codegen:
#   filename:  filtered-openapi.yaml
#   timestamp: 2025-04-29T23:44:54+00:00

from __future__ import annotations

from typing import Optional

from pydantic import BaseModel, Field


class V2OpenAPII2VResp(BaseModel):
    video_id: Optional[int] = Field(None, description='Video_id')


class V2OpenAPIT2VReq(BaseModel):
    aspect_ratio: str = Field(
        ..., description='Aspect ratio (16:9, 4:3, 1:1, 3:4, 9:16)', examples=['16:9']
    )
    duration: int = Field(
        ...,
        description='Video duration (5, 8 seconds, --model=v3.5 only allows 5,8; --quality=1080p does not support 8s)',
        examples=[5],
    )
    model: str = Field(
        ..., description='Model version (only supports v3.5)', examples=['v3.5']
    )
    motion_mode: Optional[str] = Field(
        'normal',
        description='Motion mode (normal, fast, --fast only available when duration=5; --quality=1080p does not support fast)',
        examples=['normal'],
    )
    negative_prompt: Optional[str] = Field(
        None, description='Negative prompt\n', max_length=2048
    )
    prompt: str = Field(..., description='Prompt', max_length=2048)
    quality: str = Field(
        ...,
        description='Video quality ("360p"(Turbo model), "540p", "720p", "1080p")',
        examples=['540p'],
    )
    seed: Optional[int] = Field(None, description='Random seed, range: 0 - 2147483647')
    style: Optional[str] = Field(
        None,
        description='Style (effective when model=v3.5, "anime", "3d_animation", "clay", "comic", "cyberpunk") Do not include style parameter unless needed',
        examples=['anime'],
    )
    template_id: Optional[int] = Field(
        None,
        description='Template ID (template_id must be activated before use)',
        examples=[302325299692608],
    )
    water_mark: Optional[bool] = Field(
        False,
        description='Watermark (true: add watermark, false: no watermark)',
        examples=[False],
    )
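Both deleted files above are datamodel-codegen output: plain Pydantic models with no behavior. For reference, a trimmed, runnable sketch of what they provided (Field descriptions dropped; only fields shown above are used):

from typing import Optional

from pydantic import BaseModel


class V2OpenAPII2VResp(BaseModel):
    video_id: Optional[int] = None


class ResponseData(BaseModel):
    ErrCode: Optional[int] = None
    ErrMsg: Optional[str] = None
    Resp: Optional[V2OpenAPII2VResp] = None


print(ResponseData(ErrCode=0, ErrMsg="ok", Resp=V2OpenAPII2VResp(video_id=123)).model_dump())
# -> {'ErrCode': 0, 'ErrMsg': 'ok', 'Resp': {'video_id': 123}}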
@@ -1,981 +0,0 @@
"""
API Client Framework for api.comfy.org.

This module provides a flexible framework for making API requests from ComfyUI nodes.
It supports both synchronous and asynchronous API operations with proper type validation.

Key Components:
--------------
1. ApiClient - Handles HTTP requests with authentication and error handling
2. ApiEndpoint - Defines a single HTTP endpoint with its request/response models
3. ApiOperation - Executes a single synchronous API operation

Usage Examples:
--------------

# Example 1: Synchronous API Operation
# ------------------------------------
# For a simple API call that returns the result immediately:

# 1. Create the API client
api_client = ApiClient(
    base_url="https://api.example.com",
    auth_token="your_auth_token_here",
    comfy_api_key="your_comfy_api_key_here",
    timeout=30.0,
    verify_ssl=True
)

# 2. Define the endpoint
user_info_endpoint = ApiEndpoint(
    path="/v1/users/me",
    method=HttpMethod.GET,
    request_model=EmptyRequest,  # No request body needed
    response_model=UserProfile,  # Pydantic model for the response
    query_params=None
)

# 3. Create the request object
request = EmptyRequest()

# 4. Create and execute the operation
operation = ApiOperation(
    endpoint=user_info_endpoint,
    request=request
)
user_profile = await operation.execute(client=api_client)  # Returns immediately with the result


# Example 2: Asynchronous API Operation with Polling
# -------------------------------------------------
# For an API that starts a task and requires polling for completion:

# 1. Define the endpoints (initial request and polling)
generate_image_endpoint = ApiEndpoint(
    path="/v1/images/generate",
    method=HttpMethod.POST,
    request_model=ImageGenerationRequest,
    response_model=TaskCreatedResponse,
    query_params=None
)

check_task_endpoint = ApiEndpoint(
    path="/v1/tasks/{task_id}",
    method=HttpMethod.GET,
    request_model=EmptyRequest,
    response_model=ImageGenerationResult,
    query_params=None
)

# 2. Create the request object
request = ImageGenerationRequest(
    prompt="a beautiful sunset over mountains",
    width=1024,
    height=1024,
    num_images=1
)

# 3. Create and execute the polling operation
operation = PollingOperation(
    initial_endpoint=generate_image_endpoint,
    initial_request=request,
    poll_endpoint=check_task_endpoint,
    task_id_field="task_id",
    status_field="status",
    completed_statuses=["completed"],
    failed_statuses=["failed", "error"]
)

# This will make the initial request and then poll until completion
result = await operation.execute(client=api_client)  # Returns the final ImageGenerationResult when done
"""

from __future__ import annotations
import aiohttp
import asyncio
import logging
import io
import os
import socket
from aiohttp.client_exceptions import ClientError, ClientResponseError
from typing import Type, Optional, Any, TypeVar, Generic, Callable
from enum import Enum
import json
from urllib.parse import urljoin, urlparse
from pydantic import BaseModel, Field
import uuid  # For generating unique operation IDs

from server import PromptServer
from comfy.cli_args import args
from comfy import utils
from . import request_logger

T = TypeVar("T", bound=BaseModel)
R = TypeVar("R", bound=BaseModel)
P = TypeVar("P", bound=BaseModel)  # For poll response

PROGRESS_BAR_MAX = 100


class NetworkError(Exception):
    """Base exception for network-related errors with diagnostic information."""
    pass


class LocalNetworkError(NetworkError):
    """Exception raised when local network connectivity issues are detected."""
    pass


class ApiServerError(NetworkError):
    """Exception raised when the API server is unreachable but internet is working."""
    pass


class EmptyRequest(BaseModel):
    """Base class for empty request bodies.
    For GET requests, fields will be sent as query parameters."""

    pass


class UploadRequest(BaseModel):
    file_name: str = Field(..., description="Filename to upload")
    content_type: Optional[str] = Field(
        None,
        description="Mime type of the file. For example: image/png, image/jpeg, video/mp4, etc.",
    )


class UploadResponse(BaseModel):
    download_url: str = Field(..., description="URL to GET uploaded file")
    upload_url: str = Field(..., description="URL to PUT file to upload")


class HttpMethod(str, Enum):
    GET = "GET"
    POST = "POST"
    PUT = "PUT"
    DELETE = "DELETE"
    PATCH = "PATCH"


class ApiClient:
    """
    Client for making HTTP requests to an API with authentication, error handling, and retry logic.
    """

    def __init__(
        self,
        base_url: str,
        auth_token: Optional[str] = None,
        comfy_api_key: Optional[str] = None,
        timeout: float = 3600.0,
        verify_ssl: bool = True,
        max_retries: int = 3,
        retry_delay: float = 1.0,
        retry_backoff_factor: float = 2.0,
        retry_status_codes: Optional[tuple[int, ...]] = None,
        session: Optional[aiohttp.ClientSession] = None,
    ):
        self.base_url = base_url
        self.auth_token = auth_token
        self.comfy_api_key = comfy_api_key
        self.timeout = timeout
        self.verify_ssl = verify_ssl
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.retry_backoff_factor = retry_backoff_factor
        # Default retry status codes: 408 (Request Timeout), 429 (Too Many Requests),
        # 500, 502, 503, 504 (Server Errors)
        self.retry_status_codes = retry_status_codes or (408, 429, 500, 502, 503, 504)
        self._session: Optional[aiohttp.ClientSession] = session
        self._owns_session = session is None  # Track if we have to close it

    @staticmethod
    def _generate_operation_id(path: str) -> str:
        """Generates a unique operation ID for logging."""
        return f"{path.strip('/').replace('/', '_')}_{uuid.uuid4().hex[:8]}"

    @staticmethod
    def _create_json_payload_args(
        data: Optional[dict[str, Any]] = None,
        headers: Optional[dict[str, str]] = None,
    ) -> dict[str, Any]:
        return {
            "json": data,
            "headers": headers,
        }

    def _create_form_data_args(
        self,
        data: dict[str, Any] | None,
        files: dict[str, Any] | None,
        headers: Optional[dict[str, str]] = None,
        multipart_parser: Callable | None = None,
    ) -> dict[str, Any]:
        if headers and "Content-Type" in headers:
            del headers["Content-Type"]

        if multipart_parser and data:
            data = multipart_parser(data)

        if isinstance(data, aiohttp.FormData):
            form = data  # If the parser already returned a FormData, pass it through
        else:
            form = aiohttp.FormData(default_to_multipart=True)
            if data:  # regular text fields
                for k, v in data.items():
                    if v is None:
                        continue  # aiohttp fails to serialize "None" values
                    # aiohttp expects strings or bytes; convert enums etc.
                    form.add_field(k, str(v) if not isinstance(v, (bytes, bytearray)) else v)

        if files:
            file_iter = files if isinstance(files, list) else files.items()
            for field_name, file_obj in file_iter:
                if file_obj is None:
                    continue  # aiohttp fails to serialize "None" values
                # file_obj can be (filename, bytes/io.BytesIO, content_type) tuple
                if isinstance(file_obj, tuple):
                    filename, file_value, content_type = self._unpack_tuple(file_obj)
                else:
                    file_value = file_obj
                    filename = getattr(file_obj, "name", field_name)
                    content_type = "application/octet-stream"

                form.add_field(
                    name=field_name,
                    value=file_value,
                    filename=filename,
                    content_type=content_type,
                )
        return {"data": form, "headers": headers or {}}

    @staticmethod
    def _create_urlencoded_form_data_args(
        data: dict[str, Any],
        headers: Optional[dict[str, str]] = None,
    ) -> dict[str, Any]:
        headers = headers or {}
        headers["Content-Type"] = "application/x-www-form-urlencoded"
        return {
            "data": data,
            "headers": headers,
        }

    def get_headers(self) -> dict[str, str]:
        """Get headers for API requests, including authentication if available"""
        headers = {"Content-Type": "application/json", "Accept": "application/json"}

        if self.auth_token:
            headers["Authorization"] = f"Bearer {self.auth_token}"
        elif self.comfy_api_key:
            headers["X-API-KEY"] = self.comfy_api_key

        return headers

    async def _check_connectivity(self, target_url: str) -> dict[str, bool]:
        """
        Check connectivity to determine if network issues are local or server-related.

        Args:
            target_url: URL to check connectivity to

        Returns:
            Dictionary with connectivity status details
        """
        results = {
            "internet_accessible": False,
            "api_accessible": False,
            "is_local_issue": False,
            "is_api_issue": False,
        }
        timeout = aiohttp.ClientTimeout(total=5.0)
        async with aiohttp.ClientSession(timeout=timeout) as session:
            try:
                async with session.get("https://www.google.com", ssl=self.verify_ssl) as resp:
                    results["internet_accessible"] = resp.status < 500
            except (ClientError, asyncio.TimeoutError, socket.gaierror):
                results["is_local_issue"] = True
                return results  # cannot reach the internet – early exit

            # Now check API health endpoint
            parsed = urlparse(target_url)
            health_url = f"{parsed.scheme}://{parsed.netloc}/health"
            try:
                async with session.get(health_url, ssl=self.verify_ssl) as resp:
                    results["api_accessible"] = resp.status < 500
            except ClientError:
                pass  # leave as False

        results["is_api_issue"] = results["internet_accessible"] and not results["api_accessible"]
        return results

    async def request(
        self,
        method: str,
        path: str,
        params: Optional[dict[str, Any]] = None,
        data: Optional[dict[str, Any]] = None,
        files: Optional[dict[str, Any] | list[tuple[str, Any]]] = None,
        headers: Optional[dict[str, str]] = None,
        content_type: str = "application/json",
        multipart_parser: Callable | None = None,
        retry_count: int = 0,  # Used internally for tracking retries
    ) -> dict[str, Any]:
        """
        Make an HTTP request to the API with automatic retries for transient errors.

        Args:
            method: HTTP method (GET, POST, etc.)
            path: API endpoint path (will be joined with base_url)
            params: Query parameters
            data: body data
            files: Files to upload
            headers: Additional headers
            content_type: Content type of the request. Defaults to application/json.
            retry_count: Internal parameter for tracking retries, do not set manually

        Returns:
            Parsed JSON response

        Raises:
            LocalNetworkError: If local network connectivity issues are detected
            ApiServerError: If the API server is unreachable but internet is working
            Exception: For other request failures
        """

        # Build full URL and merge headers
        relative_path = path.lstrip("/")
        url = urljoin(self.base_url, relative_path)
        self._check_auth(self.auth_token, self.comfy_api_key)

        request_headers = self.get_headers()
        if headers:
            request_headers.update(headers)
        if files:
            request_headers.pop("Content-Type", None)
        if params:
            params = {k: v for k, v in params.items() if v is not None}  # aiohttp fails to serialize None values

        logging.debug("[DEBUG] Request Headers: %s", request_headers)
        logging.debug("[DEBUG] Files: %s", files)
        logging.debug("[DEBUG] Params: %s", params)
        logging.debug("[DEBUG] Data: %s", data)

        if content_type == "application/x-www-form-urlencoded":
            payload_args = self._create_urlencoded_form_data_args(data or {}, request_headers)
        elif content_type == "multipart/form-data":
            payload_args = self._create_form_data_args(data, files, request_headers, multipart_parser)
        else:
            payload_args = self._create_json_payload_args(data, request_headers)

        operation_id = self._generate_operation_id(path)
        request_logger.log_request_response(
            operation_id=operation_id,
            request_method=method,
            request_url=url,
            request_headers=request_headers,
            request_params=params,
            request_data=data if content_type == "application/json" else "[form-data or other]",
        )

        session = await self._get_session()
        try:
            async with session.request(
                method,
                url,
                params=params,
                ssl=self.verify_ssl,
                **payload_args,
            ) as resp:
                if resp.status >= 400:
                    try:
                        error_data = await resp.json()
                    except (aiohttp.ContentTypeError, json.JSONDecodeError):
                        error_data = await resp.text()

                    return await self._handle_http_error(
                        ClientResponseError(resp.request_info, resp.history, status=resp.status, message=error_data),
                        operation_id,
                        method,
                        url,
                        params,
                        data,
                        files,
                        headers,
                        content_type,
                        multipart_parser,
                        retry_count=retry_count,
                        response_content=error_data,
                    )

                # Success – parse JSON (safely) and log
                try:
                    payload = await resp.json()
                    response_content_to_log = payload
                except (aiohttp.ContentTypeError, json.JSONDecodeError):
                    payload = {}
                    response_content_to_log = await resp.text()

                request_logger.log_request_response(
                    operation_id=operation_id,
                    request_method=method,
                    request_url=url,
                    response_status_code=resp.status,
                    response_headers=dict(resp.headers),
                    response_content=response_content_to_log,
                )
                return payload

        except (ClientError, asyncio.TimeoutError, socket.gaierror) as e:
            # Treat as *connection* problem – optionally retry, else escalate
            if retry_count < self.max_retries:
                delay = self.retry_delay * (self.retry_backoff_factor ** retry_count)
                logging.warning("Connection error. Retrying in %.2fs (%s/%s): %s", delay, retry_count + 1,
                                self.max_retries, str(e))
                await asyncio.sleep(delay)
                return await self.request(
                    method,
                    path,
                    params=params,
                    data=data,
                    files=files,
                    headers=headers,
                    content_type=content_type,
                    multipart_parser=multipart_parser,
                    retry_count=retry_count + 1,
                )
            # One final connectivity check for diagnostics
            connectivity = await self._check_connectivity(self.base_url)
            if connectivity["is_local_issue"]:
                raise LocalNetworkError(
                    "Unable to connect to the API server due to local network issues. "
                    "Please check your internet connection and try again."
                ) from e
            raise ApiServerError(
                f"The API server at {self.base_url} is currently unreachable. "
                f"The service may be experiencing issues. Please try again later."
            ) from e

    @staticmethod
    def _check_auth(auth_token, comfy_api_key):
        """Verify that an auth token is present or comfy_api_key is present"""
        if auth_token is None and comfy_api_key is None:
            raise Exception("Unauthorized: Please login first to use this node.")
        return auth_token or comfy_api_key

    @staticmethod
    async def upload_file(
        upload_url: str,
        file: io.BytesIO | str,
        content_type: str | None = None,
        max_retries: int = 3,
        retry_delay: float = 1.0,
        retry_backoff_factor: float = 2.0,
    ) -> aiohttp.ClientResponse:
        """Upload a file to the API with retry logic.

        Args:
            upload_url: The URL to upload to
            file: Either a file path string, BytesIO object, or tuple of (file_path, filename)
            content_type: Optional mime type to set for the upload
            max_retries: Maximum number of retry attempts
            retry_delay: Initial delay between retries in seconds
            retry_backoff_factor: Multiplier for the delay after each retry
        """
        headers: dict[str, str] = {}
        skip_auto_headers: set[str] = set()
        if content_type:
            headers["Content-Type"] = content_type
        else:
            # tell aiohttp not to add Content-Type that will break the request signature and result in a 403 status.
            skip_auto_headers.add("Content-Type")

        # Extract file bytes
        if isinstance(file, io.BytesIO):
            file.seek(0)
            data = file.read()
        elif isinstance(file, str):
            with open(file, "rb") as f:
                data = f.read()
        else:
            raise ValueError("File must be BytesIO or str path")

        parsed = urlparse(upload_url)
        basename = os.path.basename(parsed.path) or parsed.netloc or "upload"
        operation_id = f"upload_{basename}_{uuid.uuid4().hex[:8]}"
        request_logger.log_request_response(
            operation_id=operation_id,
            request_method="PUT",
            request_url=upload_url,
            request_headers=headers,
            request_data=f"[File data {len(data)} bytes]",
        )

        delay = retry_delay
        for attempt in range(max_retries + 1):
            try:
                timeout = aiohttp.ClientTimeout(total=None)  # honour server side timeouts
                async with aiohttp.ClientSession(timeout=timeout) as session:
                    async with session.put(
                        upload_url, data=data, headers=headers, skip_auto_headers=skip_auto_headers,
                    ) as resp:
                        resp.raise_for_status()
                        request_logger.log_request_response(
                            operation_id=operation_id,
                            request_method="PUT",
                            request_url=upload_url,
                            response_status_code=resp.status,
                            response_headers=dict(resp.headers),
                            response_content="File uploaded successfully.",
                        )
                        return resp
            except (ClientError, asyncio.TimeoutError) as e:
                request_logger.log_request_response(
                    operation_id=operation_id,
                    request_method="PUT",
                    request_url=upload_url,
                    response_status_code=e.status if hasattr(e, "status") else None,
                    response_headers=dict(e.headers) if hasattr(e, "headers") else None,
                    response_content=None,
                    error_message=f"{type(e).__name__}: {str(e)}",
                )
                if attempt < max_retries:
                    logging.warning(
                        "Upload failed (%s/%s). Retrying in %.2fs. %s", attempt + 1, max_retries, delay, str(e)
                    )
                    await asyncio.sleep(delay)
                    delay *= retry_backoff_factor
                else:
                    raise NetworkError(f"Failed to upload file after {max_retries + 1} attempts: {e}") from e

    async def _handle_http_error(
        self,
        exc: ClientResponseError,
        operation_id: str,
        *req_meta,
        retry_count: int,
        response_content: dict | str = "",
    ) -> dict[str, Any]:
        status_code = exc.status
        if status_code == 401:
            user_friendly = "Unauthorized: Please login first to use this node."
        elif status_code == 402:
            user_friendly = "Payment Required: Please add credits to your account to use this node."
        elif status_code == 409:
            user_friendly = "There is a problem with your account. Please contact support@comfy.org."
        elif status_code == 429:
            user_friendly = "Rate Limit Exceeded: Please try again later."
        else:
            if isinstance(response_content, dict):
                if "error" in response_content and "message" in response_content["error"]:
                    user_friendly = f"API Error: {response_content['error']['message']}"
                    if "type" in response_content["error"]:
                        user_friendly += f" (Type: {response_content['error']['type']})"
                else:  # Handle cases where error is just a JSON dict with unknown format
                    user_friendly = f"API Error: {json.dumps(response_content)}"
            else:
                if len(response_content) < 200:  # Arbitrary limit for display
                    user_friendly = f"API Error (raw): {response_content}"
                else:
                    user_friendly = f"API Error (raw, status {response_content})"

        request_logger.log_request_response(
            operation_id=operation_id,
            request_method=req_meta[0],
            request_url=req_meta[1],
            response_status_code=exc.status,
            response_headers=dict(req_meta[5]) if req_meta[5] else None,
            response_content=response_content,
            error_message=f"HTTP Error {exc.status}",
        )

        logging.debug("[DEBUG] API Error: %s (Status: %s)", user_friendly, status_code)
        if response_content:
            logging.debug("[DEBUG] Response content: %s", response_content)

        # Retry if eligible
        if status_code in self.retry_status_codes and retry_count < self.max_retries:
            delay = self.retry_delay * (self.retry_backoff_factor ** retry_count)
            logging.warning(
                "HTTP error %s. Retrying in %.2fs (%s/%s)",
                status_code,
                delay,
                retry_count + 1,
                self.max_retries,
            )
            await asyncio.sleep(delay)
            return await self.request(
                req_meta[0],  # method
                req_meta[1].replace(self.base_url, ""),  # path
                params=req_meta[2],
                data=req_meta[3],
                files=req_meta[4],
                headers=req_meta[5],
                content_type=req_meta[6],
                multipart_parser=req_meta[7],
                retry_count=retry_count + 1,
            )

        raise Exception(user_friendly) from exc

    @staticmethod
    def _unpack_tuple(t):
        """Helper to normalise (filename, file, content_type) tuples."""
        if len(t) == 3:
            return t
        elif len(t) == 2:
            return t[0], t[1], "application/octet-stream"
        else:
            raise ValueError("files tuple must be (filename, file[, content_type])")

    async def _get_session(self) -> aiohttp.ClientSession:
        if self._session is None or self._session.closed:
            timeout = aiohttp.ClientTimeout(total=self.timeout)
            self._session = aiohttp.ClientSession(timeout=timeout)
            self._owns_session = True
        return self._session

    async def close(self) -> None:
        if self._owns_session and self._session and not self._session.closed:
            await self._session.close()

    async def __aenter__(self) -> "ApiClient":
        """Allow usage as async‑context‑manager – ensures clean teardown"""
        return self

    async def __aexit__(self, exc_type, exc, tb):
        await self.close()


class ApiEndpoint(Generic[T, R]):
    """Defines an API endpoint with its request and response types"""

    def __init__(
        self,
        path: str,
        method: HttpMethod,
        request_model: Type[T],
        response_model: Type[R],
        query_params: Optional[dict[str, Any]] = None,
    ):
        """Initialize an API endpoint definition.

        Args:
            path: The URL path for this endpoint, can include placeholders like {id}
            method: The HTTP method to use (GET, POST, etc.)
            request_model: Pydantic model class that defines the structure and validation rules for API requests to this endpoint
            response_model: Pydantic model class that defines the structure and validation rules for API responses from this endpoint
            query_params: Optional dictionary of query parameters to include in the request
        """
        self.path = path
        self.method = method
        self.request_model = request_model
        self.response_model = response_model
        self.query_params = query_params or {}


class SynchronousOperation(Generic[T, R]):
    """Represents a single synchronous API operation."""

    def __init__(
        self,
        endpoint: ApiEndpoint[T, R],
        request: T,
        files: Optional[dict[str, Any] | list[tuple[str, Any]]] = None,
        api_base: str | None = None,
        auth_token: Optional[str] = None,
        comfy_api_key: Optional[str] = None,
        auth_kwargs: Optional[dict[str, str]] = None,
        timeout: float = 7200.0,
        verify_ssl: bool = True,
        content_type: str = "application/json",
        multipart_parser: Callable | None = None,
        max_retries: int = 3,
        retry_delay: float = 1.0,
        retry_backoff_factor: float = 2.0,
    ) -> None:
        self.endpoint = endpoint
        self.request = request
        self.files = files
        self.api_base: str = api_base or args.comfy_api_base
        self.auth_token = auth_token
        self.comfy_api_key = comfy_api_key
        if auth_kwargs is not None:
            self.auth_token = auth_kwargs.get("auth_token", self.auth_token)
            self.comfy_api_key = auth_kwargs.get("comfy_api_key", self.comfy_api_key)
        self.timeout = timeout
        self.verify_ssl = verify_ssl
        self.content_type = content_type
        self.multipart_parser = multipart_parser
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.retry_backoff_factor = retry_backoff_factor

    async def execute(self, client: Optional[ApiClient] = None) -> R:
        owns_client = client is None
        if owns_client:
            client = ApiClient(
                base_url=self.api_base,
                auth_token=self.auth_token,
                comfy_api_key=self.comfy_api_key,
                timeout=self.timeout,
                verify_ssl=self.verify_ssl,
                max_retries=self.max_retries,
                retry_delay=self.retry_delay,
                retry_backoff_factor=self.retry_backoff_factor,
            )

        try:
            request_dict: Optional[dict[str, Any]]
            if isinstance(self.request, EmptyRequest):
                request_dict = None
            else:
                request_dict = self.request.model_dump(exclude_none=True)
                for k, v in list(request_dict.items()):
                    if isinstance(v, Enum):
                        request_dict[k] = v.value

            logging.debug("[DEBUG] API Request: %s %s", self.endpoint.method.value, self.endpoint.path)
            logging.debug("[DEBUG] Request Data: %s", json.dumps(request_dict, indent=2))
            logging.debug("[DEBUG] Query Params: %s", self.endpoint.query_params)

            response_json = await client.request(
                self.endpoint.method.value,
                self.endpoint.path,
                params=self.endpoint.query_params,
                data=request_dict,
                files=self.files,
                content_type=self.content_type,
                multipart_parser=self.multipart_parser,
            )

            logging.debug("=" * 50)
            logging.debug("[DEBUG] RESPONSE DETAILS:")
            logging.debug("[DEBUG] Status Code: 200 (Success)")
            logging.debug("[DEBUG] Response Body: %s", json.dumps(response_json, indent=2))
            logging.debug("=" * 50)

            parsed_response = self.endpoint.response_model.model_validate(response_json)
            logging.debug("[DEBUG] Parsed Response: %s", parsed_response)
            return parsed_response
        finally:
            if owns_client:
                await client.close()


class TaskStatus(str, Enum):
    """Enum for task status values"""

    COMPLETED = "completed"
    FAILED = "failed"
    PENDING = "pending"


class PollingOperation(Generic[T, R]):
    """Represents an asynchronous API operation that requires polling for completion."""

    def __init__(
        self,
        poll_endpoint: ApiEndpoint[EmptyRequest, R],
        completed_statuses: list[str],
        failed_statuses: list[str],
        *,
        status_extractor: Callable[[R], Optional[str]],
        progress_extractor: Callable[[R], Optional[float]] | None = None,
        result_url_extractor: Callable[[R], Optional[str]] | None = None,
        price_extractor: Callable[[R], Optional[float]] | None = None,
        request: Optional[T] = None,
        api_base: str | None = None,
        auth_token: Optional[str] = None,
        comfy_api_key: Optional[str] = None,
        auth_kwargs: Optional[dict[str, str]] = None,
        poll_interval: float = 5.0,
        max_poll_attempts: int = 120,  # Default max polling attempts (10 minutes with 5s interval)
        max_retries: int = 3,  # Max retries per individual API call
        retry_delay: float = 1.0,
        retry_backoff_factor: float = 2.0,
        estimated_duration: Optional[float] = None,
        node_id: Optional[str] = None,
    ) -> None:
        self.poll_endpoint = poll_endpoint
        self.request = request
        self.api_base: str = api_base or args.comfy_api_base
        self.auth_token = auth_token
        self.comfy_api_key = comfy_api_key
        if auth_kwargs is not None:
            self.auth_token = auth_kwargs.get("auth_token", self.auth_token)
            self.comfy_api_key = auth_kwargs.get("comfy_api_key", self.comfy_api_key)
        self.poll_interval = poll_interval
        self.max_poll_attempts = max_poll_attempts
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.retry_backoff_factor = retry_backoff_factor
        self.estimated_duration = estimated_duration
        self.status_extractor = status_extractor or (lambda x: getattr(x, "status", None))
        self.progress_extractor = progress_extractor
        self.result_url_extractor = result_url_extractor
        self.price_extractor = price_extractor
        self.node_id = node_id
        self.completed_statuses = completed_statuses
        self.failed_statuses = failed_statuses
        self.final_response: Optional[R] = None
        self.extracted_price: Optional[float] = None

    async def execute(self, client: Optional[ApiClient] = None) -> R:
        owns_client = client is None
        if owns_client:
            client = ApiClient(
                base_url=self.api_base,
                auth_token=self.auth_token,
                comfy_api_key=self.comfy_api_key,
                max_retries=self.max_retries,
                retry_delay=self.retry_delay,
                retry_backoff_factor=self.retry_backoff_factor,
            )
        try:
            return await self._poll_until_complete(client)
        finally:
            if owns_client:
                await client.close()

    def _display_text_on_node(self, text: str):
        if not self.node_id:
            return
        if self.extracted_price is not None:
            text = f"Price: ${self.extracted_price}\n{text}"
        PromptServer.instance.send_progress_text(text, self.node_id)

    def _display_time_progress_on_node(self, time_completed: int | float):
        if not self.node_id:
            return
        if self.estimated_duration is not None:
            remaining = max(0, int(self.estimated_duration) - time_completed)
            message = f"Task in progress: {time_completed}s (~{remaining}s remaining)"
        else:
            message = f"Task in progress: {time_completed}s"
        self._display_text_on_node(message)

    def _check_task_status(self, response: R) -> TaskStatus:
        try:
            status = self.status_extractor(response)
            if status in self.completed_statuses:
                return TaskStatus.COMPLETED
            if status in self.failed_statuses:
                return TaskStatus.FAILED
            return TaskStatus.PENDING
        except Exception as e:
            logging.error("Error extracting status: %s", e)
            return TaskStatus.PENDING

    async def _poll_until_complete(self, client: ApiClient) -> R:
        """Poll until the task is complete"""
        consecutive_errors = 0
        max_consecutive_errors = min(5, self.max_retries * 2)  # Limit consecutive errors

        if self.progress_extractor:
            progress = utils.ProgressBar(PROGRESS_BAR_MAX)

        status = TaskStatus.PENDING
        for poll_count in range(1, self.max_poll_attempts + 1):
            try:
                logging.debug("[DEBUG] Polling attempt #%s", poll_count)

                request_dict = None if self.request is None else self.request.model_dump(exclude_none=True)

                if poll_count == 1:
                    logging.debug(
                        "[DEBUG] Poll Request: %s %s",
                        self.poll_endpoint.method.value,
                        self.poll_endpoint.path,
                    )
                    logging.debug(
                        "[DEBUG] Poll Request Data: %s",
                        json.dumps(request_dict, indent=2) if request_dict else "None",
                    )

                # Query task status
                resp = await client.request(
                    self.poll_endpoint.method.value,
                    self.poll_endpoint.path,
                    params=self.poll_endpoint.query_params,
                    data=request_dict,
                )
                consecutive_errors = 0  # reset on success
                response_obj: R = self.poll_endpoint.response_model.model_validate(resp)

                # Check if task is complete
                status = self._check_task_status(response_obj)
                logging.debug("[DEBUG] Task Status: %s", status)

                # If progress extractor is provided, extract progress
                if self.progress_extractor:
                    new_progress = self.progress_extractor(response_obj)
                    if new_progress is not None:
                        progress.update_absolute(new_progress, total=PROGRESS_BAR_MAX)

                if self.price_extractor:
                    price = self.price_extractor(response_obj)
                    if price is not None:
                        self.extracted_price = price

                if status == TaskStatus.COMPLETED:
                    message = "Task completed successfully"
                    if self.result_url_extractor:
                        result_url = self.result_url_extractor(response_obj)
                        if result_url:
                            message = f"Result URL: {result_url}"
                    logging.debug("[DEBUG] %s", message)
                    self._display_text_on_node(message)
                    self.final_response = response_obj
                    if self.progress_extractor:
                        progress.update(100)
                    return self.final_response
                if status == TaskStatus.FAILED:
                    message = f"Task failed: {json.dumps(resp)}"
                    logging.error("[DEBUG] %s", message)
                    raise Exception(message)
                logging.debug("[DEBUG] Task still pending, continuing to poll...")
                # Task pending – wait
                for i in range(int(self.poll_interval)):
                    self._display_time_progress_on_node((poll_count - 1) * self.poll_interval + i)
                    await asyncio.sleep(1)

            except (LocalNetworkError, ApiServerError, NetworkError) as e:
                consecutive_errors += 1
                if consecutive_errors >= max_consecutive_errors:
                    raise Exception(
                        f"Polling aborted after {consecutive_errors} network errors: {str(e)}"
                    ) from e
                logging.warning(
                    "Network error (%s/%s): %s",
                    consecutive_errors,
                    max_consecutive_errors,
                    str(e),
                )
                await asyncio.sleep(self.poll_interval)
            except Exception as e:
                # For other errors, increment count and potentially abort
                consecutive_errors += 1
                if consecutive_errors >= max_consecutive_errors or status == TaskStatus.FAILED:
                    raise Exception(
                        f"Polling aborted after {consecutive_errors} consecutive errors: {str(e)}"
                    ) from e

                logging.error("[DEBUG] Polling error: %s", str(e))
                logging.warning(
                    "Error during polling (attempt %s/%s): %s. Will retry in %s seconds.",
                    poll_count,
                    self.max_poll_attempts,
                    str(e),
                    self.poll_interval,
                )
                await asyncio.sleep(self.poll_interval)

        # If we've exhausted all polling attempts
        raise Exception(
            f"Polling timed out after {self.max_poll_attempts} attempts (" f"{self.max_poll_attempts * self.poll_interval} seconds). "
            "The operation may still be running on the server but is taking longer than expected."
        )
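Both ApiClient.request and ApiClient.upload_file above retry transient failures with the same exponential backoff, and the new util module later in this diff keeps the identical retry status set (_RETRY_STATUS = {408, 429, 500, 502, 503, 504}). A minimal standalone sketch of the schedule, assuming the defaults shown (retry_delay=1.0, retry_backoff_factor=2.0, max_retries=3):

# delay before retry N (0-indexed) = retry_delay * retry_backoff_factor ** N
retry_delay = 1.0
retry_backoff_factor = 2.0
max_retries = 3

delays = [retry_delay * (retry_backoff_factor ** n) for n in range(max_retries)]
print(delays)  # [1.0, 2.0, 4.0] -- seconds waited before retries 1, 2 and 3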
@@ -5,12 +5,9 @@ Rodin API docs: https://developer.hyper3d.ai/
"""

from __future__ import annotations
from inspect import cleandoc
import folder_paths as comfy_paths
import aiohttp
import os
import asyncio
import logging
import math
from typing import Optional
@@ -26,11 +23,11 @@ from comfy_api_nodes.apis.rodin_api import (
    Rodin3DDownloadResponse,
    JobStatus,
)
from comfy_api_nodes.apis.client import (
from comfy_api_nodes.util import (
    sync_op,
    poll_op,
    ApiEndpoint,
    HttpMethod,
    SynchronousOperation,
    PollingOperation,
    download_url_to_bytesio,
)
from comfy_api.latest import ComfyExtension, IO

@@ -121,35 +118,31 @@ def tensor_to_filelike(tensor, max_pixels: int = 2048*2048):


async def create_generate_task(
    cls: type[IO.ComfyNode],
    images=None,
    seed=1,
    material="PBR",
    quality_override=18000,
    tier="Regular",
    mesh_mode="Quad",
    TAPose = False,
    auth_kwargs: Optional[dict[str, str]] = None,
    ta_pose: bool = False,
):
    if images is None:
        raise Exception("Rodin 3D generate requires at least 1 image.")
    if len(images) > 5:
        raise Exception("Rodin 3D generate requires up to 5 image.")

    path = "/proxy/rodin/api/v2/rodin"
    operation = SynchronousOperation(
        endpoint=ApiEndpoint(
            path=path,
            method=HttpMethod.POST,
            request_model=Rodin3DGenerateRequest,
            response_model=Rodin3DGenerateResponse,
        ),
        request=Rodin3DGenerateRequest(
    response = await sync_op(
        cls,
        ApiEndpoint(path="/proxy/rodin/api/v2/rodin", method="POST"),
        response_model=Rodin3DGenerateResponse,
        data=Rodin3DGenerateRequest(
            seed=seed,
            tier=tier,
            material=material,
            quality_override=quality_override,
            mesh_mode=mesh_mode,
            TAPose=TAPose,
            TAPose=ta_pose,
        ),
        files=[
            (
@@ -159,11 +152,8 @@ async def create_generate_task(
            for image in images if image is not None
        ],
        content_type="multipart/form-data",
        auth_kwargs=auth_kwargs,
    )

    response = await operation.execute()

    if hasattr(response, "error"):
        error_message = f"Rodin3D Create 3D generate Task Failed. Message: {response.message}, error: {response.error}"
        logging.error(error_message)
@@ -187,74 +177,46 @@ def check_rodin_status(response: Rodin3DCheckStatusResponse) -> str:
        return "DONE"
    return "Generating"

def extract_progress(response: Rodin3DCheckStatusResponse) -> Optional[int]:
    if not response.jobs:
        return None
    completed_count = sum(1 for job in response.jobs if job.status == JobStatus.Done)
    return int((completed_count / len(response.jobs)) * 100)

async def poll_for_task_status(
    subscription_key, auth_kwargs: Optional[dict[str, str]] = None,
) -> Rodin3DCheckStatusResponse:
    poll_operation = PollingOperation(
        poll_endpoint=ApiEndpoint(
            path="/proxy/rodin/api/v2/status",
            method=HttpMethod.POST,
            request_model=Rodin3DCheckStatusRequest,
            response_model=Rodin3DCheckStatusResponse,
        ),
        request=Rodin3DCheckStatusRequest(subscription_key=subscription_key),
        completed_statuses=["DONE"],
        failed_statuses=["FAILED"],
        status_extractor=check_rodin_status,
        poll_interval=3.0,
        auth_kwargs=auth_kwargs,
    )

async def poll_for_task_status(subscription_key: str, cls: type[IO.ComfyNode]) -> Rodin3DCheckStatusResponse:
    logging.info("[ Rodin3D API - CheckStatus ] Generate Start!")
    return await poll_operation.execute()


async def get_rodin_download_list(uuid, auth_kwargs: Optional[dict[str, str]] = None) -> Rodin3DDownloadResponse:
    logging.info("[ Rodin3D API - Downloading ] Generate Successfully!")
    operation = SynchronousOperation(
        endpoint=ApiEndpoint(
            path="/proxy/rodin/api/v2/download",
            method=HttpMethod.POST,
            request_model=Rodin3DDownloadRequest,
            response_model=Rodin3DDownloadResponse,
        ),
        request=Rodin3DDownloadRequest(task_uuid=uuid),
        auth_kwargs=auth_kwargs,
    return await poll_op(
        cls,
        ApiEndpoint(path="/proxy/rodin/api/v2/status", method="POST"),
        response_model=Rodin3DCheckStatusResponse,
        data=Rodin3DCheckStatusRequest(subscription_key=subscription_key),
        status_extractor=check_rodin_status,
        progress_extractor=extract_progress,
    )
    return await operation.execute()


async def download_files(url_list, task_uuid):
async def get_rodin_download_list(uuid: str, cls: type[IO.ComfyNode]) -> Rodin3DDownloadResponse:
    logging.info("[ Rodin3D API - Downloading ] Generate Successfully!")
    return await sync_op(
        cls,
        ApiEndpoint(path="/proxy/rodin/api/v2/download", method="POST"),
        response_model=Rodin3DDownloadResponse,
        data=Rodin3DDownloadRequest(task_uuid=uuid),
        monitor_progress=False,
    )


async def download_files(url_list, task_uuid: str):
    result_folder_name = f"Rodin3D_{task_uuid}"
    save_path = os.path.join(comfy_paths.get_output_directory(), result_folder_name)
    os.makedirs(save_path, exist_ok=True)
    model_file_path = None
    async with aiohttp.ClientSession() as session:
        for i in url_list.list:
            file_path = os.path.join(save_path, i.name)
            if file_path.endswith(".glb"):
                model_file_path = os.path.join(result_folder_name, i.name)
            logging.info("[ Rodin3D API - download_files ] Downloading file: %s", file_path)
            max_retries = 5
            for attempt in range(max_retries):
                try:
                    async with session.get(i.url) as resp:
                        resp.raise_for_status()
                        with open(file_path, "wb") as f:
                            async for chunk in resp.content.iter_chunked(32 * 1024):
                                f.write(chunk)
                    break
                except Exception as e:
                    logging.info("[ Rodin3D API - download_files ] Error downloading %s:%s", file_path, str(e))
                    if attempt < max_retries - 1:
                        logging.info("Retrying...")
                        await asyncio.sleep(2)
                    else:
                        logging.info(
                            "[ Rodin3D API - download_files ] Failed to download %s after %s attempts.",
                            file_path,
                            max_retries,
                        )
    for i in url_list.list:
        file_path = os.path.join(save_path, i.name)
        if file_path.endswith(".glb"):
            model_file_path = os.path.join(result_folder_name, i.name)
        await download_url_to_bytesio(i.url, file_path)
    return model_file_path

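The new extract_progress helper above converts the fraction of finished jobs into the 0-100 integer that poll_op's progress_extractor expects. A standalone worked example (stand-in types, not the real Rodin3DCheckStatusResponse/JobStatus):

from dataclasses import dataclass
from enum import Enum


class JobStatus(str, Enum):
    Done = "Done"
    Generating = "Generating"


@dataclass
class Job:
    status: JobStatus


jobs = [Job(JobStatus.Done), Job(JobStatus.Done), Job(JobStatus.Done), Job(JobStatus.Generating)]
completed_count = sum(1 for job in jobs if job.status == JobStatus.Done)
print(int((completed_count / len(jobs)) * 100))  # 75
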
@@ -276,6 +238,7 @@ class Rodin3D_Regular(IO.ComfyNode):
            hidden=[
                IO.Hidden.auth_token_comfy_org,
                IO.Hidden.api_key_comfy_org,
                IO.Hidden.unique_id,
            ],
            is_api_node=True,
        )
@@ -294,21 +257,17 @@ class Rodin3D_Regular(IO.ComfyNode):
        for i in range(num_images):
            m_images.append(Images[i])
        mesh_mode, quality_override = get_quality_mode(Polygon_count)
        auth = {
            "auth_token": cls.hidden.auth_token_comfy_org,
            "comfy_api_key": cls.hidden.api_key_comfy_org,
        }
        task_uuid, subscription_key = await create_generate_task(
            cls,
            images=m_images,
            seed=Seed,
            material=Material_Type,
            quality_override=quality_override,
            tier=tier,
            mesh_mode=mesh_mode,
            auth_kwargs=auth,
        )
        await poll_for_task_status(subscription_key, auth_kwargs=auth)
        download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth)
        await poll_for_task_status(subscription_key, cls)
        download_list = await get_rodin_download_list(task_uuid, cls)
        model = await download_files(download_list, task_uuid)

        return IO.NodeOutput(model)
@@ -332,6 +291,7 @@ class Rodin3D_Detail(IO.ComfyNode):
            hidden=[
                IO.Hidden.auth_token_comfy_org,
                IO.Hidden.api_key_comfy_org,
                IO.Hidden.unique_id,
            ],
            is_api_node=True,
        )
@@ -350,21 +310,17 @@ class Rodin3D_Detail(IO.ComfyNode):
        for i in range(num_images):
            m_images.append(Images[i])
        mesh_mode, quality_override = get_quality_mode(Polygon_count)
        auth = {
            "auth_token": cls.hidden.auth_token_comfy_org,
            "comfy_api_key": cls.hidden.api_key_comfy_org,
        }
        task_uuid, subscription_key = await create_generate_task(
            cls,
            images=m_images,
            seed=Seed,
            material=Material_Type,
            quality_override=quality_override,
            tier=tier,
            mesh_mode=mesh_mode,
            auth_kwargs=auth,
        )
        await poll_for_task_status(subscription_key, auth_kwargs=auth)
        download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth)
        await poll_for_task_status(subscription_key, cls)
        download_list = await get_rodin_download_list(task_uuid, cls)
        model = await download_files(download_list, task_uuid)

        return IO.NodeOutput(model)
@@ -388,6 +344,7 @@ class Rodin3D_Smooth(IO.ComfyNode):
            hidden=[
                IO.Hidden.auth_token_comfy_org,
                IO.Hidden.api_key_comfy_org,
                IO.Hidden.unique_id,
            ],
            is_api_node=True,
        )
@@ -400,27 +357,22 @@ class Rodin3D_Smooth(IO.ComfyNode):
        Material_Type,
        Polygon_count,
    ) -> IO.NodeOutput:
        tier = "Smooth"
        num_images = Images.shape[0]
        m_images = []
        for i in range(num_images):
            m_images.append(Images[i])
        mesh_mode, quality_override = get_quality_mode(Polygon_count)
        auth = {
            "auth_token": cls.hidden.auth_token_comfy_org,
            "comfy_api_key": cls.hidden.api_key_comfy_org,
        }
        task_uuid, subscription_key = await create_generate_task(
            cls,
            images=m_images,
            seed=Seed,
            material=Material_Type,
            quality_override=quality_override,
            tier=tier,
            tier="Smooth",
            mesh_mode=mesh_mode,
            auth_kwargs=auth,
        )
        await poll_for_task_status(subscription_key, auth_kwargs=auth)
        download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth)
        await poll_for_task_status(subscription_key, cls)
        download_list = await get_rodin_download_list(task_uuid, cls)
        model = await download_files(download_list, task_uuid)

        return IO.NodeOutput(model)
@@ -451,6 +403,7 @@ class Rodin3D_Sketch(IO.ComfyNode):
            hidden=[
                IO.Hidden.auth_token_comfy_org,
                IO.Hidden.api_key_comfy_org,
                IO.Hidden.unique_id,
            ],
            is_api_node=True,
        )
@@ -461,29 +414,21 @@ class Rodin3D_Sketch(IO.ComfyNode):
        Images,
        Seed,
    ) -> IO.NodeOutput:
        tier = "Sketch"
        num_images = Images.shape[0]
        m_images = []
        for i in range(num_images):
            m_images.append(Images[i])
        material_type = "PBR"
        quality_override = 18000
        mesh_mode = "Quad"
        auth = {
            "auth_token": cls.hidden.auth_token_comfy_org,
            "comfy_api_key": cls.hidden.api_key_comfy_org,
        }
        task_uuid, subscription_key = await create_generate_task(
            cls,
            images=m_images,
            seed=Seed,
            material=material_type,
            quality_override=quality_override,
            tier=tier,
            mesh_mode=mesh_mode,
            auth_kwargs=auth,
            material="PBR",
            quality_override=18000,
            tier="Sketch",
            mesh_mode="Quad",
        )
        await poll_for_task_status(subscription_key, auth_kwargs=auth)
        download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth)
        await poll_for_task_status(subscription_key, cls)
        download_list = await get_rodin_download_list(task_uuid, cls)
        model = await download_files(download_list, task_uuid)

        return IO.NodeOutput(model)
@@ -522,6 +467,7 @@ class Rodin3D_Gen2(IO.ComfyNode):
            hidden=[
                IO.Hidden.auth_token_comfy_org,
                IO.Hidden.api_key_comfy_org,
                IO.Hidden.unique_id,
            ],
            is_api_node=True,
        )
@@ -541,22 +487,18 @@ class Rodin3D_Gen2(IO.ComfyNode):
        for i in range(num_images):
            m_images.append(Images[i])
        mesh_mode, quality_override = get_quality_mode(Polygon_count)
        auth = {
            "auth_token": cls.hidden.auth_token_comfy_org,
            "comfy_api_key": cls.hidden.api_key_comfy_org,
        }
        task_uuid, subscription_key = await create_generate_task(
            cls,
            images=m_images,
            seed=Seed,
            material=Material_Type,
            quality_override=quality_override,
            tier=tier,
            mesh_mode=mesh_mode,
            TAPose=TAPose,
            auth_kwargs=auth,
            ta_pose=TAPose,
        )
        await poll_for_task_status(subscription_key, auth_kwargs=auth)
        download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth)
        await poll_for_task_status(subscription_key, cls)
        download_list = await get_rodin_download_list(task_uuid, cls)
        model = await download_files(download_list, task_uuid)

        return IO.NodeOutput(model)

@@ -16,9 +16,9 @@ from pydantic import BaseModel

from comfy import utils
from comfy_api.latest import IO
from comfy_api_nodes.apis import request_logger
from server import PromptServer

from . import request_logger
from ._helpers import (
    default_base_url,
    get_auth_header,
@@ -77,7 +77,7 @@ class _PollUIState:


_RETRY_STATUS = {408, 429, 500, 502, 503, 504}
COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished"]
COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done"]
FAILED_STATUSES = ["cancelled", "canceled", "fail", "failed", "error"]
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted"]

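Appending "done" to COMPLETED_STATUSES above is what lets the shared poller recognize Rodin's terminal state; since check_rodin_status returns "DONE" in upper case, the helper presumably lowercases statuses before comparing. That normalization is an assumption here; it is not visible in this hunk. A sketch of the implied check:

COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done"]


def is_completed(status: str) -> bool:
    # Assumed case normalization; the real comparison lives in
    # comfy_api_nodes/util and is not part of this diff.
    return status.lower() in COMPLETED_STATUSES


print(is_completed("DONE"))  # True
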
@@ -12,8 +12,8 @@ from aiohttp.client_exceptions import ClientError, ContentTypeError

from comfy_api.input_impl import VideoFromFile
from comfy_api.latest import IO as COMFY_IO
from comfy_api_nodes.apis import request_logger

from . import request_logger
from ._helpers import (
    default_base_url,
    get_auth_header,

@@ -1,11 +1,11 @@
from __future__ import annotations

import os
import datetime
import hashlib
import json
import logging
import os
import re
import hashlib
from typing import Any

import folder_paths
@@ -13,8 +13,8 @@ from pydantic import BaseModel, Field

from comfy_api.latest import IO, Input
from comfy_api.util import VideoCodec, VideoContainer
from comfy_api_nodes.apis import request_logger

from . import request_logger
from ._helpers import is_processing_interrupted, sleep_with_interrupt
from .client import (
    ApiEndpoint,