diff --git a/comfy_api_nodes/apis/PixverseController.py b/comfy_api_nodes/apis/PixverseController.py deleted file mode 100644 index 310c0f546..000000000 --- a/comfy_api_nodes/apis/PixverseController.py +++ /dev/null @@ -1,17 +0,0 @@ -# generated by datamodel-codegen: -# filename: filtered-openapi.yaml -# timestamp: 2025-04-29T23:44:54+00:00 - -from __future__ import annotations - -from typing import Optional - -from pydantic import BaseModel - -from . import PixverseDto - - -class ResponseData(BaseModel): - ErrCode: Optional[int] = None - ErrMsg: Optional[str] = None - Resp: Optional[PixverseDto.V2OpenAPII2VResp] = None diff --git a/comfy_api_nodes/apis/PixverseDto.py b/comfy_api_nodes/apis/PixverseDto.py deleted file mode 100644 index 323c38e96..000000000 --- a/comfy_api_nodes/apis/PixverseDto.py +++ /dev/null @@ -1,57 +0,0 @@ -# generated by datamodel-codegen: -# filename: filtered-openapi.yaml -# timestamp: 2025-04-29T23:44:54+00:00 - -from __future__ import annotations - -from typing import Optional - -from pydantic import BaseModel, Field - - -class V2OpenAPII2VResp(BaseModel): - video_id: Optional[int] = Field(None, description='Video_id') - - -class V2OpenAPIT2VReq(BaseModel): - aspect_ratio: str = Field( - ..., description='Aspect ratio (16:9, 4:3, 1:1, 3:4, 9:16)', examples=['16:9'] - ) - duration: int = Field( - ..., - description='Video duration (5, 8 seconds, --model=v3.5 only allows 5,8; --quality=1080p does not support 8s)', - examples=[5], - ) - model: str = Field( - ..., description='Model version (only supports v3.5)', examples=['v3.5'] - ) - motion_mode: Optional[str] = Field( - 'normal', - description='Motion mode (normal, fast, --fast only available when duration=5; --quality=1080p does not support fast)', - examples=['normal'], - ) - negative_prompt: Optional[str] = Field( - None, description='Negative prompt\n', max_length=2048 - ) - prompt: str = Field(..., description='Prompt', max_length=2048) - quality: str = Field( - ..., - description='Video quality ("360p"(Turbo model), "540p", "720p", "1080p")', - examples=['540p'], - ) - seed: Optional[int] = Field(None, description='Random seed, range: 0 - 2147483647') - style: Optional[str] = Field( - None, - description='Style (effective when model=v3.5, "anime", "3d_animation", "clay", "comic", "cyberpunk") Do not include style parameter unless needed', - examples=['anime'], - ) - template_id: Optional[int] = Field( - None, - description='Template ID (template_id must be activated before use)', - examples=[302325299692608], - ) - water_mark: Optional[bool] = Field( - False, - description='Watermark (true: add watermark, false: no watermark)', - examples=[False], - ) diff --git a/comfy_api_nodes/apis/client.py b/comfy_api_nodes/apis/client.py deleted file mode 100644 index bdaddcc88..000000000 --- a/comfy_api_nodes/apis/client.py +++ /dev/null @@ -1,981 +0,0 @@ -""" -API Client Framework for api.comfy.org. - -This module provides a flexible framework for making API requests from ComfyUI nodes. -It supports both synchronous and asynchronous API operations with proper type validation. - -Key Components: --------------- -1. ApiClient - Handles HTTP requests with authentication and error handling -2. ApiEndpoint - Defines a single HTTP endpoint with its request/response models -3. ApiOperation - Executes a single synchronous API operation - -Usage Examples: --------------- - -# Example 1: Synchronous API Operation -# ------------------------------------ -# For a simple API call that returns the result immediately: - -# 1. Create the API client -api_client = ApiClient( - base_url="https://api.example.com", - auth_token="your_auth_token_here", - comfy_api_key="your_comfy_api_key_here", - timeout=30.0, - verify_ssl=True -) - -# 2. Define the endpoint -user_info_endpoint = ApiEndpoint( - path="/v1/users/me", - method=HttpMethod.GET, - request_model=EmptyRequest, # No request body needed - response_model=UserProfile, # Pydantic model for the response - query_params=None -) - -# 3. Create the request object -request = EmptyRequest() - -# 4. Create and execute the operation -operation = ApiOperation( - endpoint=user_info_endpoint, - request=request -) -user_profile = await operation.execute(client=api_client) # Returns immediately with the result - - -# Example 2: Asynchronous API Operation with Polling -# ------------------------------------------------- -# For an API that starts a task and requires polling for completion: - -# 1. Define the endpoints (initial request and polling) -generate_image_endpoint = ApiEndpoint( - path="/v1/images/generate", - method=HttpMethod.POST, - request_model=ImageGenerationRequest, - response_model=TaskCreatedResponse, - query_params=None -) - -check_task_endpoint = ApiEndpoint( - path="/v1/tasks/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=ImageGenerationResult, - query_params=None -) - -# 2. Create the request object -request = ImageGenerationRequest( - prompt="a beautiful sunset over mountains", - width=1024, - height=1024, - num_images=1 -) - -# 3. Create and execute the polling operation -operation = PollingOperation( - initial_endpoint=generate_image_endpoint, - initial_request=request, - poll_endpoint=check_task_endpoint, - task_id_field="task_id", - status_field="status", - completed_statuses=["completed"], - failed_statuses=["failed", "error"] -) - -# This will make the initial request and then poll until completion -result = await operation.execute(client=api_client) # Returns the final ImageGenerationResult when done -""" - -from __future__ import annotations -import aiohttp -import asyncio -import logging -import io -import os -import socket -from aiohttp.client_exceptions import ClientError, ClientResponseError -from typing import Type, Optional, Any, TypeVar, Generic, Callable -from enum import Enum -import json -from urllib.parse import urljoin, urlparse -from pydantic import BaseModel, Field -import uuid # For generating unique operation IDs - -from server import PromptServer -from comfy.cli_args import args -from comfy import utils -from . import request_logger - -T = TypeVar("T", bound=BaseModel) -R = TypeVar("R", bound=BaseModel) -P = TypeVar("P", bound=BaseModel) # For poll response - -PROGRESS_BAR_MAX = 100 - - -class NetworkError(Exception): - """Base exception for network-related errors with diagnostic information.""" - pass - - -class LocalNetworkError(NetworkError): - """Exception raised when local network connectivity issues are detected.""" - pass - - -class ApiServerError(NetworkError): - """Exception raised when the API server is unreachable but internet is working.""" - pass - - -class EmptyRequest(BaseModel): - """Base class for empty request bodies. - For GET requests, fields will be sent as query parameters.""" - - pass - - -class UploadRequest(BaseModel): - file_name: str = Field(..., description="Filename to upload") - content_type: Optional[str] = Field( - None, - description="Mime type of the file. For example: image/png, image/jpeg, video/mp4, etc.", - ) - - -class UploadResponse(BaseModel): - download_url: str = Field(..., description="URL to GET uploaded file") - upload_url: str = Field(..., description="URL to PUT file to upload") - - -class HttpMethod(str, Enum): - GET = "GET" - POST = "POST" - PUT = "PUT" - DELETE = "DELETE" - PATCH = "PATCH" - - -class ApiClient: - """ - Client for making HTTP requests to an API with authentication, error handling, and retry logic. - """ - - def __init__( - self, - base_url: str, - auth_token: Optional[str] = None, - comfy_api_key: Optional[str] = None, - timeout: float = 3600.0, - verify_ssl: bool = True, - max_retries: int = 3, - retry_delay: float = 1.0, - retry_backoff_factor: float = 2.0, - retry_status_codes: Optional[tuple[int, ...]] = None, - session: Optional[aiohttp.ClientSession] = None, - ): - self.base_url = base_url - self.auth_token = auth_token - self.comfy_api_key = comfy_api_key - self.timeout = timeout - self.verify_ssl = verify_ssl - self.max_retries = max_retries - self.retry_delay = retry_delay - self.retry_backoff_factor = retry_backoff_factor - # Default retry status codes: 408 (Request Timeout), 429 (Too Many Requests), - # 500, 502, 503, 504 (Server Errors) - self.retry_status_codes = retry_status_codes or (408, 429, 500, 502, 503, 504) - self._session: Optional[aiohttp.ClientSession] = session - self._owns_session = session is None # Track if we have to close it - - @staticmethod - def _generate_operation_id(path: str) -> str: - """Generates a unique operation ID for logging.""" - return f"{path.strip('/').replace('/', '_')}_{uuid.uuid4().hex[:8]}" - - @staticmethod - def _create_json_payload_args( - data: Optional[dict[str, Any]] = None, - headers: Optional[dict[str, str]] = None, - ) -> dict[str, Any]: - return { - "json": data, - "headers": headers, - } - - def _create_form_data_args( - self, - data: dict[str, Any] | None, - files: dict[str, Any] | None, - headers: Optional[dict[str, str]] = None, - multipart_parser: Callable | None = None, - ) -> dict[str, Any]: - if headers and "Content-Type" in headers: - del headers["Content-Type"] - - if multipart_parser and data: - data = multipart_parser(data) - - if isinstance(data, aiohttp.FormData): - form = data # If the parser already returned a FormData, pass it through - else: - form = aiohttp.FormData(default_to_multipart=True) - if data: # regular text fields - for k, v in data.items(): - if v is None: - continue # aiohttp fails to serialize "None" values - # aiohttp expects strings or bytes; convert enums etc. - form.add_field(k, str(v) if not isinstance(v, (bytes, bytearray)) else v) - - if files: - file_iter = files if isinstance(files, list) else files.items() - for field_name, file_obj in file_iter: - if file_obj is None: - continue # aiohttp fails to serialize "None" values - # file_obj can be (filename, bytes/io.BytesIO, content_type) tuple - if isinstance(file_obj, tuple): - filename, file_value, content_type = self._unpack_tuple(file_obj) - else: - file_value = file_obj - filename = getattr(file_obj, "name", field_name) - content_type = "application/octet-stream" - - form.add_field( - name=field_name, - value=file_value, - filename=filename, - content_type=content_type, - ) - return {"data": form, "headers": headers or {}} - - @staticmethod - def _create_urlencoded_form_data_args( - data: dict[str, Any], - headers: Optional[dict[str, str]] = None, - ) -> dict[str, Any]: - headers = headers or {} - headers["Content-Type"] = "application/x-www-form-urlencoded" - return { - "data": data, - "headers": headers, - } - - def get_headers(self) -> dict[str, str]: - """Get headers for API requests, including authentication if available""" - headers = {"Content-Type": "application/json", "Accept": "application/json"} - - if self.auth_token: - headers["Authorization"] = f"Bearer {self.auth_token}" - elif self.comfy_api_key: - headers["X-API-KEY"] = self.comfy_api_key - - return headers - - async def _check_connectivity(self, target_url: str) -> dict[str, bool]: - """ - Check connectivity to determine if network issues are local or server-related. - - Args: - target_url: URL to check connectivity to - - Returns: - Dictionary with connectivity status details - """ - results = { - "internet_accessible": False, - "api_accessible": False, - "is_local_issue": False, - "is_api_issue": False, - } - timeout = aiohttp.ClientTimeout(total=5.0) - async with aiohttp.ClientSession(timeout=timeout) as session: - try: - async with session.get("https://www.google.com", ssl=self.verify_ssl) as resp: - results["internet_accessible"] = resp.status < 500 - except (ClientError, asyncio.TimeoutError, socket.gaierror): - results["is_local_issue"] = True - return results # cannot reach the internet – early exit - - # Now check API health endpoint - parsed = urlparse(target_url) - health_url = f"{parsed.scheme}://{parsed.netloc}/health" - try: - async with session.get(health_url, ssl=self.verify_ssl) as resp: - results["api_accessible"] = resp.status < 500 - except ClientError: - pass # leave as False - - results["is_api_issue"] = results["internet_accessible"] and not results["api_accessible"] - return results - - async def request( - self, - method: str, - path: str, - params: Optional[dict[str, Any]] = None, - data: Optional[dict[str, Any]] = None, - files: Optional[dict[str, Any] | list[tuple[str, Any]]] = None, - headers: Optional[dict[str, str]] = None, - content_type: str = "application/json", - multipart_parser: Callable | None = None, - retry_count: int = 0, # Used internally for tracking retries - ) -> dict[str, Any]: - """ - Make an HTTP request to the API with automatic retries for transient errors. - - Args: - method: HTTP method (GET, POST, etc.) - path: API endpoint path (will be joined with base_url) - params: Query parameters - data: body data - files: Files to upload - headers: Additional headers - content_type: Content type of the request. Defaults to application/json. - retry_count: Internal parameter for tracking retries, do not set manually - - Returns: - Parsed JSON response - - Raises: - LocalNetworkError: If local network connectivity issues are detected - ApiServerError: If the API server is unreachable but internet is working - Exception: For other request failures - """ - - # Build full URL and merge headers - relative_path = path.lstrip("/") - url = urljoin(self.base_url, relative_path) - self._check_auth(self.auth_token, self.comfy_api_key) - - request_headers = self.get_headers() - if headers: - request_headers.update(headers) - if files: - request_headers.pop("Content-Type", None) - if params: - params = {k: v for k, v in params.items() if v is not None} # aiohttp fails to serialize None values - - logging.debug("[DEBUG] Request Headers: %s", request_headers) - logging.debug("[DEBUG] Files: %s", files) - logging.debug("[DEBUG] Params: %s", params) - logging.debug("[DEBUG] Data: %s", data) - - if content_type == "application/x-www-form-urlencoded": - payload_args = self._create_urlencoded_form_data_args(data or {}, request_headers) - elif content_type == "multipart/form-data": - payload_args = self._create_form_data_args(data, files, request_headers, multipart_parser) - else: - payload_args = self._create_json_payload_args(data, request_headers) - - operation_id = self._generate_operation_id(path) - request_logger.log_request_response( - operation_id=operation_id, - request_method=method, - request_url=url, - request_headers=request_headers, - request_params=params, - request_data=data if content_type == "application/json" else "[form-data or other]", - ) - - session = await self._get_session() - try: - async with session.request( - method, - url, - params=params, - ssl=self.verify_ssl, - **payload_args, - ) as resp: - if resp.status >= 400: - try: - error_data = await resp.json() - except (aiohttp.ContentTypeError, json.JSONDecodeError): - error_data = await resp.text() - - return await self._handle_http_error( - ClientResponseError(resp.request_info, resp.history, status=resp.status, message=error_data), - operation_id, - method, - url, - params, - data, - files, - headers, - content_type, - multipart_parser, - retry_count=retry_count, - response_content=error_data, - ) - - # Success – parse JSON (safely) and log - try: - payload = await resp.json() - response_content_to_log = payload - except (aiohttp.ContentTypeError, json.JSONDecodeError): - payload = {} - response_content_to_log = await resp.text() - - request_logger.log_request_response( - operation_id=operation_id, - request_method=method, - request_url=url, - response_status_code=resp.status, - response_headers=dict(resp.headers), - response_content=response_content_to_log, - ) - return payload - - except (ClientError, asyncio.TimeoutError, socket.gaierror) as e: - # Treat as *connection* problem – optionally retry, else escalate - if retry_count < self.max_retries: - delay = self.retry_delay * (self.retry_backoff_factor ** retry_count) - logging.warning("Connection error. Retrying in %.2fs (%s/%s): %s", delay, retry_count + 1, - self.max_retries, str(e)) - await asyncio.sleep(delay) - return await self.request( - method, - path, - params=params, - data=data, - files=files, - headers=headers, - content_type=content_type, - multipart_parser=multipart_parser, - retry_count=retry_count + 1, - ) - # One final connectivity check for diagnostics - connectivity = await self._check_connectivity(self.base_url) - if connectivity["is_local_issue"]: - raise LocalNetworkError( - "Unable to connect to the API server due to local network issues. " - "Please check your internet connection and try again." - ) from e - raise ApiServerError( - f"The API server at {self.base_url} is currently unreachable. " - f"The service may be experiencing issues. Please try again later." - ) from e - - @staticmethod - def _check_auth(auth_token, comfy_api_key): - """Verify that an auth token is present or comfy_api_key is present""" - if auth_token is None and comfy_api_key is None: - raise Exception("Unauthorized: Please login first to use this node.") - return auth_token or comfy_api_key - - @staticmethod - async def upload_file( - upload_url: str, - file: io.BytesIO | str, - content_type: str | None = None, - max_retries: int = 3, - retry_delay: float = 1.0, - retry_backoff_factor: float = 2.0, - ) -> aiohttp.ClientResponse: - """Upload a file to the API with retry logic. - - Args: - upload_url: The URL to upload to - file: Either a file path string, BytesIO object, or tuple of (file_path, filename) - content_type: Optional mime type to set for the upload - max_retries: Maximum number of retry attempts - retry_delay: Initial delay between retries in seconds - retry_backoff_factor: Multiplier for the delay after each retry - """ - headers: dict[str, str] = {} - skip_auto_headers: set[str] = set() - if content_type: - headers["Content-Type"] = content_type - else: - # tell aiohttp not to add Content-Type that will break the request signature and result in a 403 status. - skip_auto_headers.add("Content-Type") - - # Extract file bytes - if isinstance(file, io.BytesIO): - file.seek(0) - data = file.read() - elif isinstance(file, str): - with open(file, "rb") as f: - data = f.read() - else: - raise ValueError("File must be BytesIO or str path") - - parsed = urlparse(upload_url) - basename = os.path.basename(parsed.path) or parsed.netloc or "upload" - operation_id = f"upload_{basename}_{uuid.uuid4().hex[:8]}" - request_logger.log_request_response( - operation_id=operation_id, - request_method="PUT", - request_url=upload_url, - request_headers=headers, - request_data=f"[File data {len(data)} bytes]", - ) - - delay = retry_delay - for attempt in range(max_retries + 1): - try: - timeout = aiohttp.ClientTimeout(total=None) # honour server side timeouts - async with aiohttp.ClientSession(timeout=timeout) as session: - async with session.put( - upload_url, data=data, headers=headers, skip_auto_headers=skip_auto_headers, - ) as resp: - resp.raise_for_status() - request_logger.log_request_response( - operation_id=operation_id, - request_method="PUT", - request_url=upload_url, - response_status_code=resp.status, - response_headers=dict(resp.headers), - response_content="File uploaded successfully.", - ) - return resp - except (ClientError, asyncio.TimeoutError) as e: - request_logger.log_request_response( - operation_id=operation_id, - request_method="PUT", - request_url=upload_url, - response_status_code=e.status if hasattr(e, "status") else None, - response_headers=dict(e.headers) if hasattr(e, "headers") else None, - response_content=None, - error_message=f"{type(e).__name__}: {str(e)}", - ) - if attempt < max_retries: - logging.warning( - "Upload failed (%s/%s). Retrying in %.2fs. %s", attempt + 1, max_retries, delay, str(e) - ) - await asyncio.sleep(delay) - delay *= retry_backoff_factor - else: - raise NetworkError(f"Failed to upload file after {max_retries + 1} attempts: {e}") from e - - async def _handle_http_error( - self, - exc: ClientResponseError, - operation_id: str, - *req_meta, - retry_count: int, - response_content: dict | str = "", - ) -> dict[str, Any]: - status_code = exc.status - if status_code == 401: - user_friendly = "Unauthorized: Please login first to use this node." - elif status_code == 402: - user_friendly = "Payment Required: Please add credits to your account to use this node." - elif status_code == 409: - user_friendly = "There is a problem with your account. Please contact support@comfy.org." - elif status_code == 429: - user_friendly = "Rate Limit Exceeded: Please try again later." - else: - if isinstance(response_content, dict): - if "error" in response_content and "message" in response_content["error"]: - user_friendly = f"API Error: {response_content['error']['message']}" - if "type" in response_content["error"]: - user_friendly += f" (Type: {response_content['error']['type']})" - else: # Handle cases where error is just a JSON dict with unknown format - user_friendly = f"API Error: {json.dumps(response_content)}" - else: - if len(response_content) < 200: # Arbitrary limit for display - user_friendly = f"API Error (raw): {response_content}" - else: - user_friendly = f"API Error (raw, status {response_content})" - - request_logger.log_request_response( - operation_id=operation_id, - request_method=req_meta[0], - request_url=req_meta[1], - response_status_code=exc.status, - response_headers=dict(req_meta[5]) if req_meta[5] else None, - response_content=response_content, - error_message=f"HTTP Error {exc.status}", - ) - - logging.debug("[DEBUG] API Error: %s (Status: %s)", user_friendly, status_code) - if response_content: - logging.debug("[DEBUG] Response content: %s", response_content) - - # Retry if eligible - if status_code in self.retry_status_codes and retry_count < self.max_retries: - delay = self.retry_delay * (self.retry_backoff_factor ** retry_count) - logging.warning( - "HTTP error %s. Retrying in %.2fs (%s/%s)", - status_code, - delay, - retry_count + 1, - self.max_retries, - ) - await asyncio.sleep(delay) - return await self.request( - req_meta[0], # method - req_meta[1].replace(self.base_url, ""), # path - params=req_meta[2], - data=req_meta[3], - files=req_meta[4], - headers=req_meta[5], - content_type=req_meta[6], - multipart_parser=req_meta[7], - retry_count=retry_count + 1, - ) - - raise Exception(user_friendly) from exc - - @staticmethod - def _unpack_tuple(t): - """Helper to normalise (filename, file, content_type) tuples.""" - if len(t) == 3: - return t - elif len(t) == 2: - return t[0], t[1], "application/octet-stream" - else: - raise ValueError("files tuple must be (filename, file[, content_type])") - - async def _get_session(self) -> aiohttp.ClientSession: - if self._session is None or self._session.closed: - timeout = aiohttp.ClientTimeout(total=self.timeout) - self._session = aiohttp.ClientSession(timeout=timeout) - self._owns_session = True - return self._session - - async def close(self) -> None: - if self._owns_session and self._session and not self._session.closed: - await self._session.close() - - async def __aenter__(self) -> "ApiClient": - """Allow usage as async‑context‑manager – ensures clean teardown""" - return self - - async def __aexit__(self, exc_type, exc, tb): - await self.close() - - -class ApiEndpoint(Generic[T, R]): - """Defines an API endpoint with its request and response types""" - - def __init__( - self, - path: str, - method: HttpMethod, - request_model: Type[T], - response_model: Type[R], - query_params: Optional[dict[str, Any]] = None, - ): - """Initialize an API endpoint definition. - - Args: - path: The URL path for this endpoint, can include placeholders like {id} - method: The HTTP method to use (GET, POST, etc.) - request_model: Pydantic model class that defines the structure and validation rules for API requests to this endpoint - response_model: Pydantic model class that defines the structure and validation rules for API responses from this endpoint - query_params: Optional dictionary of query parameters to include in the request - """ - self.path = path - self.method = method - self.request_model = request_model - self.response_model = response_model - self.query_params = query_params or {} - - -class SynchronousOperation(Generic[T, R]): - """Represents a single synchronous API operation.""" - - def __init__( - self, - endpoint: ApiEndpoint[T, R], - request: T, - files: Optional[dict[str, Any] | list[tuple[str, Any]]] = None, - api_base: str | None = None, - auth_token: Optional[str] = None, - comfy_api_key: Optional[str] = None, - auth_kwargs: Optional[dict[str, str]] = None, - timeout: float = 7200.0, - verify_ssl: bool = True, - content_type: str = "application/json", - multipart_parser: Callable | None = None, - max_retries: int = 3, - retry_delay: float = 1.0, - retry_backoff_factor: float = 2.0, - ) -> None: - self.endpoint = endpoint - self.request = request - self.files = files - self.api_base: str = api_base or args.comfy_api_base - self.auth_token = auth_token - self.comfy_api_key = comfy_api_key - if auth_kwargs is not None: - self.auth_token = auth_kwargs.get("auth_token", self.auth_token) - self.comfy_api_key = auth_kwargs.get("comfy_api_key", self.comfy_api_key) - self.timeout = timeout - self.verify_ssl = verify_ssl - self.content_type = content_type - self.multipart_parser = multipart_parser - self.max_retries = max_retries - self.retry_delay = retry_delay - self.retry_backoff_factor = retry_backoff_factor - - async def execute(self, client: Optional[ApiClient] = None) -> R: - owns_client = client is None - if owns_client: - client = ApiClient( - base_url=self.api_base, - auth_token=self.auth_token, - comfy_api_key=self.comfy_api_key, - timeout=self.timeout, - verify_ssl=self.verify_ssl, - max_retries=self.max_retries, - retry_delay=self.retry_delay, - retry_backoff_factor=self.retry_backoff_factor, - ) - - try: - request_dict: Optional[dict[str, Any]] - if isinstance(self.request, EmptyRequest): - request_dict = None - else: - request_dict = self.request.model_dump(exclude_none=True) - for k, v in list(request_dict.items()): - if isinstance(v, Enum): - request_dict[k] = v.value - - logging.debug("[DEBUG] API Request: %s %s", self.endpoint.method.value, self.endpoint.path) - logging.debug("[DEBUG] Request Data: %s", json.dumps(request_dict, indent=2)) - logging.debug("[DEBUG] Query Params: %s", self.endpoint.query_params) - - response_json = await client.request( - self.endpoint.method.value, - self.endpoint.path, - params=self.endpoint.query_params, - data=request_dict, - files=self.files, - content_type=self.content_type, - multipart_parser=self.multipart_parser, - ) - - logging.debug("=" * 50) - logging.debug("[DEBUG] RESPONSE DETAILS:") - logging.debug("[DEBUG] Status Code: 200 (Success)") - logging.debug("[DEBUG] Response Body: %s", json.dumps(response_json, indent=2)) - logging.debug("=" * 50) - - parsed_response = self.endpoint.response_model.model_validate(response_json) - logging.debug("[DEBUG] Parsed Response: %s", parsed_response) - return parsed_response - finally: - if owns_client: - await client.close() - - -class TaskStatus(str, Enum): - """Enum for task status values""" - - COMPLETED = "completed" - FAILED = "failed" - PENDING = "pending" - - -class PollingOperation(Generic[T, R]): - """Represents an asynchronous API operation that requires polling for completion.""" - - def __init__( - self, - poll_endpoint: ApiEndpoint[EmptyRequest, R], - completed_statuses: list[str], - failed_statuses: list[str], - *, - status_extractor: Callable[[R], Optional[str]], - progress_extractor: Callable[[R], Optional[float]] | None = None, - result_url_extractor: Callable[[R], Optional[str]] | None = None, - price_extractor: Callable[[R], Optional[float]] | None = None, - request: Optional[T] = None, - api_base: str | None = None, - auth_token: Optional[str] = None, - comfy_api_key: Optional[str] = None, - auth_kwargs: Optional[dict[str, str]] = None, - poll_interval: float = 5.0, - max_poll_attempts: int = 120, # Default max polling attempts (10 minutes with 5s interval) - max_retries: int = 3, # Max retries per individual API call - retry_delay: float = 1.0, - retry_backoff_factor: float = 2.0, - estimated_duration: Optional[float] = None, - node_id: Optional[str] = None, - ) -> None: - self.poll_endpoint = poll_endpoint - self.request = request - self.api_base: str = api_base or args.comfy_api_base - self.auth_token = auth_token - self.comfy_api_key = comfy_api_key - if auth_kwargs is not None: - self.auth_token = auth_kwargs.get("auth_token", self.auth_token) - self.comfy_api_key = auth_kwargs.get("comfy_api_key", self.comfy_api_key) - self.poll_interval = poll_interval - self.max_poll_attempts = max_poll_attempts - self.max_retries = max_retries - self.retry_delay = retry_delay - self.retry_backoff_factor = retry_backoff_factor - self.estimated_duration = estimated_duration - self.status_extractor = status_extractor or (lambda x: getattr(x, "status", None)) - self.progress_extractor = progress_extractor - self.result_url_extractor = result_url_extractor - self.price_extractor = price_extractor - self.node_id = node_id - self.completed_statuses = completed_statuses - self.failed_statuses = failed_statuses - self.final_response: Optional[R] = None - self.extracted_price: Optional[float] = None - - async def execute(self, client: Optional[ApiClient] = None) -> R: - owns_client = client is None - if owns_client: - client = ApiClient( - base_url=self.api_base, - auth_token=self.auth_token, - comfy_api_key=self.comfy_api_key, - max_retries=self.max_retries, - retry_delay=self.retry_delay, - retry_backoff_factor=self.retry_backoff_factor, - ) - try: - return await self._poll_until_complete(client) - finally: - if owns_client: - await client.close() - - def _display_text_on_node(self, text: str): - if not self.node_id: - return - if self.extracted_price is not None: - text = f"Price: ${self.extracted_price}\n{text}" - PromptServer.instance.send_progress_text(text, self.node_id) - - def _display_time_progress_on_node(self, time_completed: int | float): - if not self.node_id: - return - if self.estimated_duration is not None: - remaining = max(0, int(self.estimated_duration) - time_completed) - message = f"Task in progress: {time_completed}s (~{remaining}s remaining)" - else: - message = f"Task in progress: {time_completed}s" - self._display_text_on_node(message) - - def _check_task_status(self, response: R) -> TaskStatus: - try: - status = self.status_extractor(response) - if status in self.completed_statuses: - return TaskStatus.COMPLETED - if status in self.failed_statuses: - return TaskStatus.FAILED - return TaskStatus.PENDING - except Exception as e: - logging.error("Error extracting status: %s", e) - return TaskStatus.PENDING - - async def _poll_until_complete(self, client: ApiClient) -> R: - """Poll until the task is complete""" - consecutive_errors = 0 - max_consecutive_errors = min(5, self.max_retries * 2) # Limit consecutive errors - - if self.progress_extractor: - progress = utils.ProgressBar(PROGRESS_BAR_MAX) - - status = TaskStatus.PENDING - for poll_count in range(1, self.max_poll_attempts + 1): - try: - logging.debug("[DEBUG] Polling attempt #%s", poll_count) - - request_dict = None if self.request is None else self.request.model_dump(exclude_none=True) - - if poll_count == 1: - logging.debug( - "[DEBUG] Poll Request: %s %s", - self.poll_endpoint.method.value, - self.poll_endpoint.path, - ) - logging.debug( - "[DEBUG] Poll Request Data: %s", - json.dumps(request_dict, indent=2) if request_dict else "None", - ) - - # Query task status - resp = await client.request( - self.poll_endpoint.method.value, - self.poll_endpoint.path, - params=self.poll_endpoint.query_params, - data=request_dict, - ) - consecutive_errors = 0 # reset on success - response_obj: R = self.poll_endpoint.response_model.model_validate(resp) - - # Check if task is complete - status = self._check_task_status(response_obj) - logging.debug("[DEBUG] Task Status: %s", status) - - # If progress extractor is provided, extract progress - if self.progress_extractor: - new_progress = self.progress_extractor(response_obj) - if new_progress is not None: - progress.update_absolute(new_progress, total=PROGRESS_BAR_MAX) - - if self.price_extractor: - price = self.price_extractor(response_obj) - if price is not None: - self.extracted_price = price - - if status == TaskStatus.COMPLETED: - message = "Task completed successfully" - if self.result_url_extractor: - result_url = self.result_url_extractor(response_obj) - if result_url: - message = f"Result URL: {result_url}" - logging.debug("[DEBUG] %s", message) - self._display_text_on_node(message) - self.final_response = response_obj - if self.progress_extractor: - progress.update(100) - return self.final_response - if status == TaskStatus.FAILED: - message = f"Task failed: {json.dumps(resp)}" - logging.error("[DEBUG] %s", message) - raise Exception(message) - logging.debug("[DEBUG] Task still pending, continuing to poll...") - # Task pending – wait - for i in range(int(self.poll_interval)): - self._display_time_progress_on_node((poll_count - 1) * self.poll_interval + i) - await asyncio.sleep(1) - - except (LocalNetworkError, ApiServerError, NetworkError) as e: - consecutive_errors += 1 - if consecutive_errors >= max_consecutive_errors: - raise Exception( - f"Polling aborted after {consecutive_errors} network errors: {str(e)}" - ) from e - logging.warning( - "Network error (%s/%s): %s", - consecutive_errors, - max_consecutive_errors, - str(e), - ) - await asyncio.sleep(self.poll_interval) - except Exception as e: - # For other errors, increment count and potentially abort - consecutive_errors += 1 - if consecutive_errors >= max_consecutive_errors or status == TaskStatus.FAILED: - raise Exception( - f"Polling aborted after {consecutive_errors} consecutive errors: {str(e)}" - ) from e - - logging.error("[DEBUG] Polling error: %s", str(e)) - logging.warning( - "Error during polling (attempt %s/%s): %s. Will retry in %s seconds.", - poll_count, - self.max_poll_attempts, - str(e), - self.poll_interval, - ) - await asyncio.sleep(self.poll_interval) - - # If we've exhausted all polling attempts - raise Exception( - f"Polling timed out after {self.max_poll_attempts} attempts (" f"{self.max_poll_attempts * self.poll_interval} seconds). " - "The operation may still be running on the server but is taking longer than expected." - ) diff --git a/comfy_api_nodes/nodes_rodin.py b/comfy_api_nodes/nodes_rodin.py index ad4029236..e60e7a6d6 100644 --- a/comfy_api_nodes/nodes_rodin.py +++ b/comfy_api_nodes/nodes_rodin.py @@ -5,12 +5,9 @@ Rodin API docs: https://developer.hyper3d.ai/ """ -from __future__ import annotations from inspect import cleandoc import folder_paths as comfy_paths -import aiohttp import os -import asyncio import logging import math from typing import Optional @@ -26,11 +23,11 @@ from comfy_api_nodes.apis.rodin_api import ( Rodin3DDownloadResponse, JobStatus, ) -from comfy_api_nodes.apis.client import ( +from comfy_api_nodes.util import ( + sync_op, + poll_op, ApiEndpoint, - HttpMethod, - SynchronousOperation, - PollingOperation, + download_url_to_bytesio, ) from comfy_api.latest import ComfyExtension, IO @@ -121,35 +118,31 @@ def tensor_to_filelike(tensor, max_pixels: int = 2048*2048): async def create_generate_task( + cls: type[IO.ComfyNode], images=None, seed=1, material="PBR", quality_override=18000, tier="Regular", mesh_mode="Quad", - TAPose = False, - auth_kwargs: Optional[dict[str, str]] = None, + ta_pose: bool = False, ): if images is None: raise Exception("Rodin 3D generate requires at least 1 image.") if len(images) > 5: raise Exception("Rodin 3D generate requires up to 5 image.") - path = "/proxy/rodin/api/v2/rodin" - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=path, - method=HttpMethod.POST, - request_model=Rodin3DGenerateRequest, - response_model=Rodin3DGenerateResponse, - ), - request=Rodin3DGenerateRequest( + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/rodin/api/v2/rodin", method="POST"), + response_model=Rodin3DGenerateResponse, + data=Rodin3DGenerateRequest( seed=seed, tier=tier, material=material, quality_override=quality_override, mesh_mode=mesh_mode, - TAPose=TAPose, + TAPose=ta_pose, ), files=[ ( @@ -159,11 +152,8 @@ async def create_generate_task( for image in images if image is not None ], content_type="multipart/form-data", - auth_kwargs=auth_kwargs, ) - response = await operation.execute() - if hasattr(response, "error"): error_message = f"Rodin3D Create 3D generate Task Failed. Message: {response.message}, error: {response.error}" logging.error(error_message) @@ -187,74 +177,46 @@ def check_rodin_status(response: Rodin3DCheckStatusResponse) -> str: return "DONE" return "Generating" +def extract_progress(response: Rodin3DCheckStatusResponse) -> Optional[int]: + if not response.jobs: + return None + completed_count = sum(1 for job in response.jobs if job.status == JobStatus.Done) + return int((completed_count / len(response.jobs)) * 100) -async def poll_for_task_status( - subscription_key, auth_kwargs: Optional[dict[str, str]] = None, -) -> Rodin3DCheckStatusResponse: - poll_operation = PollingOperation( - poll_endpoint=ApiEndpoint( - path="/proxy/rodin/api/v2/status", - method=HttpMethod.POST, - request_model=Rodin3DCheckStatusRequest, - response_model=Rodin3DCheckStatusResponse, - ), - request=Rodin3DCheckStatusRequest(subscription_key=subscription_key), - completed_statuses=["DONE"], - failed_statuses=["FAILED"], - status_extractor=check_rodin_status, - poll_interval=3.0, - auth_kwargs=auth_kwargs, - ) + +async def poll_for_task_status(subscription_key: str, cls: type[IO.ComfyNode]) -> Rodin3DCheckStatusResponse: logging.info("[ Rodin3D API - CheckStatus ] Generate Start!") - return await poll_operation.execute() - - -async def get_rodin_download_list(uuid, auth_kwargs: Optional[dict[str, str]] = None) -> Rodin3DDownloadResponse: - logging.info("[ Rodin3D API - Downloading ] Generate Successfully!") - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/rodin/api/v2/download", - method=HttpMethod.POST, - request_model=Rodin3DDownloadRequest, - response_model=Rodin3DDownloadResponse, - ), - request=Rodin3DDownloadRequest(task_uuid=uuid), - auth_kwargs=auth_kwargs, + return await poll_op( + cls, + ApiEndpoint(path="/proxy/rodin/api/v2/status", method="POST"), + response_model=Rodin3DCheckStatusResponse, + data=Rodin3DCheckStatusRequest(subscription_key=subscription_key), + status_extractor=check_rodin_status, + progress_extractor=extract_progress, ) - return await operation.execute() -async def download_files(url_list, task_uuid): +async def get_rodin_download_list(uuid: str, cls: type[IO.ComfyNode]) -> Rodin3DDownloadResponse: + logging.info("[ Rodin3D API - Downloading ] Generate Successfully!") + return await sync_op( + cls, + ApiEndpoint(path="/proxy/rodin/api/v2/download", method="POST"), + response_model=Rodin3DDownloadResponse, + data=Rodin3DDownloadRequest(task_uuid=uuid), + monitor_progress=False, + ) + + +async def download_files(url_list, task_uuid: str): result_folder_name = f"Rodin3D_{task_uuid}" save_path = os.path.join(comfy_paths.get_output_directory(), result_folder_name) os.makedirs(save_path, exist_ok=True) model_file_path = None - async with aiohttp.ClientSession() as session: - for i in url_list.list: - file_path = os.path.join(save_path, i.name) - if file_path.endswith(".glb"): - model_file_path = os.path.join(result_folder_name, i.name) - logging.info("[ Rodin3D API - download_files ] Downloading file: %s", file_path) - max_retries = 5 - for attempt in range(max_retries): - try: - async with session.get(i.url) as resp: - resp.raise_for_status() - with open(file_path, "wb") as f: - async for chunk in resp.content.iter_chunked(32 * 1024): - f.write(chunk) - break - except Exception as e: - logging.info("[ Rodin3D API - download_files ] Error downloading %s:%s", file_path, str(e)) - if attempt < max_retries - 1: - logging.info("Retrying...") - await asyncio.sleep(2) - else: - logging.info( - "[ Rodin3D API - download_files ] Failed to download %s after %s attempts.", - file_path, - max_retries, - ) + for i in url_list.list: + file_path = os.path.join(save_path, i.name) + if file_path.endswith(".glb"): + model_file_path = os.path.join(result_folder_name, i.name) + await download_url_to_bytesio(i.url, file_path) return model_file_path @@ -276,6 +238,7 @@ class Rodin3D_Regular(IO.ComfyNode): hidden=[ IO.Hidden.auth_token_comfy_org, IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, ], is_api_node=True, ) @@ -294,21 +257,17 @@ class Rodin3D_Regular(IO.ComfyNode): for i in range(num_images): m_images.append(Images[i]) mesh_mode, quality_override = get_quality_mode(Polygon_count) - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } task_uuid, subscription_key = await create_generate_task( + cls, images=m_images, seed=Seed, material=Material_Type, quality_override=quality_override, tier=tier, mesh_mode=mesh_mode, - auth_kwargs=auth, ) - await poll_for_task_status(subscription_key, auth_kwargs=auth) - download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth) + await poll_for_task_status(subscription_key, cls) + download_list = await get_rodin_download_list(task_uuid, cls) model = await download_files(download_list, task_uuid) return IO.NodeOutput(model) @@ -332,6 +291,7 @@ class Rodin3D_Detail(IO.ComfyNode): hidden=[ IO.Hidden.auth_token_comfy_org, IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, ], is_api_node=True, ) @@ -350,21 +310,17 @@ class Rodin3D_Detail(IO.ComfyNode): for i in range(num_images): m_images.append(Images[i]) mesh_mode, quality_override = get_quality_mode(Polygon_count) - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } task_uuid, subscription_key = await create_generate_task( + cls, images=m_images, seed=Seed, material=Material_Type, quality_override=quality_override, tier=tier, mesh_mode=mesh_mode, - auth_kwargs=auth, ) - await poll_for_task_status(subscription_key, auth_kwargs=auth) - download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth) + await poll_for_task_status(subscription_key, cls) + download_list = await get_rodin_download_list(task_uuid, cls) model = await download_files(download_list, task_uuid) return IO.NodeOutput(model) @@ -388,6 +344,7 @@ class Rodin3D_Smooth(IO.ComfyNode): hidden=[ IO.Hidden.auth_token_comfy_org, IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, ], is_api_node=True, ) @@ -400,27 +357,22 @@ class Rodin3D_Smooth(IO.ComfyNode): Material_Type, Polygon_count, ) -> IO.NodeOutput: - tier = "Smooth" num_images = Images.shape[0] m_images = [] for i in range(num_images): m_images.append(Images[i]) mesh_mode, quality_override = get_quality_mode(Polygon_count) - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } task_uuid, subscription_key = await create_generate_task( + cls, images=m_images, seed=Seed, material=Material_Type, quality_override=quality_override, - tier=tier, + tier="Smooth", mesh_mode=mesh_mode, - auth_kwargs=auth, ) - await poll_for_task_status(subscription_key, auth_kwargs=auth) - download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth) + await poll_for_task_status(subscription_key, cls) + download_list = await get_rodin_download_list(task_uuid, cls) model = await download_files(download_list, task_uuid) return IO.NodeOutput(model) @@ -451,6 +403,7 @@ class Rodin3D_Sketch(IO.ComfyNode): hidden=[ IO.Hidden.auth_token_comfy_org, IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, ], is_api_node=True, ) @@ -461,29 +414,21 @@ class Rodin3D_Sketch(IO.ComfyNode): Images, Seed, ) -> IO.NodeOutput: - tier = "Sketch" num_images = Images.shape[0] m_images = [] for i in range(num_images): m_images.append(Images[i]) - material_type = "PBR" - quality_override = 18000 - mesh_mode = "Quad" - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } task_uuid, subscription_key = await create_generate_task( + cls, images=m_images, seed=Seed, - material=material_type, - quality_override=quality_override, - tier=tier, - mesh_mode=mesh_mode, - auth_kwargs=auth, + material="PBR", + quality_override=18000, + tier="Sketch", + mesh_mode="Quad", ) - await poll_for_task_status(subscription_key, auth_kwargs=auth) - download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth) + await poll_for_task_status(subscription_key, cls) + download_list = await get_rodin_download_list(task_uuid, cls) model = await download_files(download_list, task_uuid) return IO.NodeOutput(model) @@ -522,6 +467,7 @@ class Rodin3D_Gen2(IO.ComfyNode): hidden=[ IO.Hidden.auth_token_comfy_org, IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, ], is_api_node=True, ) @@ -541,22 +487,18 @@ class Rodin3D_Gen2(IO.ComfyNode): for i in range(num_images): m_images.append(Images[i]) mesh_mode, quality_override = get_quality_mode(Polygon_count) - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } task_uuid, subscription_key = await create_generate_task( + cls, images=m_images, seed=Seed, material=Material_Type, quality_override=quality_override, tier=tier, mesh_mode=mesh_mode, - TAPose=TAPose, - auth_kwargs=auth, + ta_pose=TAPose, ) - await poll_for_task_status(subscription_key, auth_kwargs=auth) - download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth) + await poll_for_task_status(subscription_key, cls) + download_list = await get_rodin_download_list(task_uuid, cls) model = await download_files(download_list, task_uuid) return IO.NodeOutput(model) diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py index 65bb35f0f..2d5dcd648 100644 --- a/comfy_api_nodes/util/client.py +++ b/comfy_api_nodes/util/client.py @@ -16,9 +16,9 @@ from pydantic import BaseModel from comfy import utils from comfy_api.latest import IO -from comfy_api_nodes.apis import request_logger from server import PromptServer +from . import request_logger from ._helpers import ( default_base_url, get_auth_header, @@ -77,7 +77,7 @@ class _PollUIState: _RETRY_STATUS = {408, 429, 500, 502, 503, 504} -COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished"] +COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done"] FAILED_STATUSES = ["cancelled", "canceled", "fail", "failed", "error"] QUEUED_STATUSES = ["created", "queued", "queueing", "submitted"] diff --git a/comfy_api_nodes/util/download_helpers.py b/comfy_api_nodes/util/download_helpers.py index 364874bed..14207dc68 100644 --- a/comfy_api_nodes/util/download_helpers.py +++ b/comfy_api_nodes/util/download_helpers.py @@ -12,8 +12,8 @@ from aiohttp.client_exceptions import ClientError, ContentTypeError from comfy_api.input_impl import VideoFromFile from comfy_api.latest import IO as COMFY_IO -from comfy_api_nodes.apis import request_logger +from . import request_logger from ._helpers import ( default_base_url, get_auth_header, diff --git a/comfy_api_nodes/apis/request_logger.py b/comfy_api_nodes/util/request_logger.py similarity index 100% rename from comfy_api_nodes/apis/request_logger.py rename to comfy_api_nodes/util/request_logger.py index c6974d35c..ac52e2eab 100644 --- a/comfy_api_nodes/apis/request_logger.py +++ b/comfy_api_nodes/util/request_logger.py @@ -1,11 +1,11 @@ from __future__ import annotations -import os import datetime +import hashlib import json import logging +import os import re -import hashlib from typing import Any import folder_paths diff --git a/comfy_api_nodes/util/upload_helpers.py b/comfy_api_nodes/util/upload_helpers.py index 7bfc61704..632450d9b 100644 --- a/comfy_api_nodes/util/upload_helpers.py +++ b/comfy_api_nodes/util/upload_helpers.py @@ -13,8 +13,8 @@ from pydantic import BaseModel, Field from comfy_api.latest import IO, Input from comfy_api.util import VideoCodec, VideoContainer -from comfy_api_nodes.apis import request_logger +from . import request_logger from ._helpers import is_processing_interrupted, sleep_with_interrupt from .client import ( ApiEndpoint,