[API Nodes]: fixes and refactor (#11104)

* chore(api-nodes): applied ruff's pyupgrade(python3.10) to api-nodes client's to folder

* chore(api-nodes): add validate_video_frame_count function from LTX PR

* chore(api-nodes): replace deprecated V1 imports

* fix(api-nodes): the types returned by the "poll_op" function are now correct.
This commit is contained in:
Alexander Piskun 2025-12-05 00:05:28 +02:00 committed by GitHub
parent 9bc893c5bb
commit 3c8456223c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 146 additions and 135 deletions

View File

@ -47,6 +47,7 @@ from .validation_utils import (
validate_string, validate_string,
validate_video_dimensions, validate_video_dimensions,
validate_video_duration, validate_video_duration,
validate_video_frame_count,
) )
__all__ = [ __all__ = [
@ -94,6 +95,7 @@ __all__ = [
"validate_string", "validate_string",
"validate_video_dimensions", "validate_video_dimensions",
"validate_video_duration", "validate_video_duration",
"validate_video_frame_count",
# Misc functions # Misc functions
"get_fs_object_size", "get_fs_object_size",
] ]

View File

@ -2,8 +2,8 @@ import asyncio
import contextlib import contextlib
import os import os
import time import time
from collections.abc import Callable
from io import BytesIO from io import BytesIO
from typing import Callable, Optional, Union
from comfy.cli_args import args from comfy.cli_args import args
from comfy.model_management import processing_interrupted from comfy.model_management import processing_interrupted
@ -35,12 +35,12 @@ def default_base_url() -> str:
async def sleep_with_interrupt( async def sleep_with_interrupt(
seconds: float, seconds: float,
node_cls: Optional[type[IO.ComfyNode]], node_cls: type[IO.ComfyNode] | None,
label: Optional[str] = None, label: str | None = None,
start_ts: Optional[float] = None, start_ts: float | None = None,
estimated_total: Optional[int] = None, estimated_total: int | None = None,
*, *,
display_callback: Optional[Callable[[type[IO.ComfyNode], str, int, Optional[int]], None]] = None, display_callback: Callable[[type[IO.ComfyNode], str, int, int | None], None] | None = None,
): ):
""" """
Sleep in 1s slices while: Sleep in 1s slices while:
@ -65,7 +65,7 @@ def mimetype_to_extension(mime_type: str) -> str:
return mime_type.split("/")[-1].lower() return mime_type.split("/")[-1].lower()
def get_fs_object_size(path_or_object: Union[str, BytesIO]) -> int: def get_fs_object_size(path_or_object: str | BytesIO) -> int:
if isinstance(path_or_object, str): if isinstance(path_or_object, str):
return os.path.getsize(path_or_object) return os.path.getsize(path_or_object)
return len(path_or_object.getvalue()) return len(path_or_object.getvalue())

View File

@ -4,10 +4,11 @@ import json
import logging import logging
import time import time
import uuid import uuid
from collections.abc import Callable, Iterable
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
from io import BytesIO from io import BytesIO
from typing import Any, Callable, Iterable, Literal, Optional, Type, TypeVar, Union from typing import Any, Literal, TypeVar
from urllib.parse import urljoin, urlparse from urllib.parse import urljoin, urlparse
import aiohttp import aiohttp
@ -37,8 +38,8 @@ class ApiEndpoint:
path: str, path: str,
method: Literal["GET", "POST", "PUT", "DELETE", "PATCH"] = "GET", method: Literal["GET", "POST", "PUT", "DELETE", "PATCH"] = "GET",
*, *,
query_params: Optional[dict[str, Any]] = None, query_params: dict[str, Any] | None = None,
headers: Optional[dict[str, str]] = None, headers: dict[str, str] | None = None,
): ):
self.path = path self.path = path
self.method = method self.method = method
@ -52,18 +53,18 @@ class _RequestConfig:
endpoint: ApiEndpoint endpoint: ApiEndpoint
timeout: float timeout: float
content_type: str content_type: str
data: Optional[dict[str, Any]] data: dict[str, Any] | None
files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] files: dict[str, Any] | list[tuple[str, Any]] | None
multipart_parser: Optional[Callable] multipart_parser: Callable | None
max_retries: int max_retries: int
retry_delay: float retry_delay: float
retry_backoff: float retry_backoff: float
wait_label: str = "Waiting" wait_label: str = "Waiting"
monitor_progress: bool = True monitor_progress: bool = True
estimated_total: Optional[int] = None estimated_total: int | None = None
final_label_on_success: Optional[str] = "Completed" final_label_on_success: str | None = "Completed"
progress_origin_ts: Optional[float] = None progress_origin_ts: float | None = None
price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None price_extractor: Callable[[dict[str, Any]], float | None] | None = None
@dataclass @dataclass
@ -71,10 +72,10 @@ class _PollUIState:
started: float started: float
status_label: str = "Queued" status_label: str = "Queued"
is_queued: bool = True is_queued: bool = True
price: Optional[float] = None price: float | None = None
estimated_duration: Optional[int] = None estimated_duration: int | None = None
base_processing_elapsed: float = 0.0 # sum of completed active intervals base_processing_elapsed: float = 0.0 # sum of completed active intervals
active_since: Optional[float] = None # start time of current active interval (None if queued) active_since: float | None = None # start time of current active interval (None if queued)
_RETRY_STATUS = {408, 429, 500, 502, 503, 504} _RETRY_STATUS = {408, 429, 500, 502, 503, 504}
@ -87,20 +88,20 @@ async def sync_op(
cls: type[IO.ComfyNode], cls: type[IO.ComfyNode],
endpoint: ApiEndpoint, endpoint: ApiEndpoint,
*, *,
response_model: Type[M], response_model: type[M],
price_extractor: Optional[Callable[[M], Optional[float]]] = None, price_extractor: Callable[[M | Any], float | None] | None = None,
data: Optional[BaseModel] = None, data: BaseModel | None = None,
files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None, files: dict[str, Any] | list[tuple[str, Any]] | None = None,
content_type: str = "application/json", content_type: str = "application/json",
timeout: float = 3600.0, timeout: float = 3600.0,
multipart_parser: Optional[Callable] = None, multipart_parser: Callable | None = None,
max_retries: int = 3, max_retries: int = 3,
retry_delay: float = 1.0, retry_delay: float = 1.0,
retry_backoff: float = 2.0, retry_backoff: float = 2.0,
wait_label: str = "Waiting for server", wait_label: str = "Waiting for server",
estimated_duration: Optional[int] = None, estimated_duration: int | None = None,
final_label_on_success: Optional[str] = "Completed", final_label_on_success: str | None = "Completed",
progress_origin_ts: Optional[float] = None, progress_origin_ts: float | None = None,
monitor_progress: bool = True, monitor_progress: bool = True,
) -> M: ) -> M:
raw = await sync_op_raw( raw = await sync_op_raw(
@ -131,22 +132,22 @@ async def poll_op(
cls: type[IO.ComfyNode], cls: type[IO.ComfyNode],
poll_endpoint: ApiEndpoint, poll_endpoint: ApiEndpoint,
*, *,
response_model: Type[M], response_model: type[M],
status_extractor: Callable[[M], Optional[Union[str, int]]], status_extractor: Callable[[M | Any], str | int | None],
progress_extractor: Optional[Callable[[M], Optional[int]]] = None, progress_extractor: Callable[[M | Any], int | None] | None = None,
price_extractor: Optional[Callable[[M], Optional[float]]] = None, price_extractor: Callable[[M | Any], float | None] | None = None,
completed_statuses: Optional[list[Union[str, int]]] = None, completed_statuses: list[str | int] | None = None,
failed_statuses: Optional[list[Union[str, int]]] = None, failed_statuses: list[str | int] | None = None,
queued_statuses: Optional[list[Union[str, int]]] = None, queued_statuses: list[str | int] | None = None,
data: Optional[BaseModel] = None, data: BaseModel | None = None,
poll_interval: float = 5.0, poll_interval: float = 5.0,
max_poll_attempts: int = 120, max_poll_attempts: int = 120,
timeout_per_poll: float = 120.0, timeout_per_poll: float = 120.0,
max_retries_per_poll: int = 3, max_retries_per_poll: int = 3,
retry_delay_per_poll: float = 1.0, retry_delay_per_poll: float = 1.0,
retry_backoff_per_poll: float = 2.0, retry_backoff_per_poll: float = 2.0,
estimated_duration: Optional[int] = None, estimated_duration: int | None = None,
cancel_endpoint: Optional[ApiEndpoint] = None, cancel_endpoint: ApiEndpoint | None = None,
cancel_timeout: float = 10.0, cancel_timeout: float = 10.0,
) -> M: ) -> M:
raw = await poll_op_raw( raw = await poll_op_raw(
@ -178,22 +179,22 @@ async def sync_op_raw(
cls: type[IO.ComfyNode], cls: type[IO.ComfyNode],
endpoint: ApiEndpoint, endpoint: ApiEndpoint,
*, *,
price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None, price_extractor: Callable[[dict[str, Any]], float | None] | None = None,
data: Optional[Union[dict[str, Any], BaseModel]] = None, data: dict[str, Any] | BaseModel | None = None,
files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None, files: dict[str, Any] | list[tuple[str, Any]] | None = None,
content_type: str = "application/json", content_type: str = "application/json",
timeout: float = 3600.0, timeout: float = 3600.0,
multipart_parser: Optional[Callable] = None, multipart_parser: Callable | None = None,
max_retries: int = 3, max_retries: int = 3,
retry_delay: float = 1.0, retry_delay: float = 1.0,
retry_backoff: float = 2.0, retry_backoff: float = 2.0,
wait_label: str = "Waiting for server", wait_label: str = "Waiting for server",
estimated_duration: Optional[int] = None, estimated_duration: int | None = None,
as_binary: bool = False, as_binary: bool = False,
final_label_on_success: Optional[str] = "Completed", final_label_on_success: str | None = "Completed",
progress_origin_ts: Optional[float] = None, progress_origin_ts: float | None = None,
monitor_progress: bool = True, monitor_progress: bool = True,
) -> Union[dict[str, Any], bytes]: ) -> dict[str, Any] | bytes:
""" """
Make a single network request. Make a single network request.
- If as_binary=False (default): returns JSON dict (or {'_raw': '<text>'} if non-JSON). - If as_binary=False (default): returns JSON dict (or {'_raw': '<text>'} if non-JSON).
@ -229,21 +230,21 @@ async def poll_op_raw(
cls: type[IO.ComfyNode], cls: type[IO.ComfyNode],
poll_endpoint: ApiEndpoint, poll_endpoint: ApiEndpoint,
*, *,
status_extractor: Callable[[dict[str, Any]], Optional[Union[str, int]]], status_extractor: Callable[[dict[str, Any]], str | int | None],
progress_extractor: Optional[Callable[[dict[str, Any]], Optional[int]]] = None, progress_extractor: Callable[[dict[str, Any]], int | None] | None = None,
price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None, price_extractor: Callable[[dict[str, Any]], float | None] | None = None,
completed_statuses: Optional[list[Union[str, int]]] = None, completed_statuses: list[str | int] | None = None,
failed_statuses: Optional[list[Union[str, int]]] = None, failed_statuses: list[str | int] | None = None,
queued_statuses: Optional[list[Union[str, int]]] = None, queued_statuses: list[str | int] | None = None,
data: Optional[Union[dict[str, Any], BaseModel]] = None, data: dict[str, Any] | BaseModel | None = None,
poll_interval: float = 5.0, poll_interval: float = 5.0,
max_poll_attempts: int = 120, max_poll_attempts: int = 120,
timeout_per_poll: float = 120.0, timeout_per_poll: float = 120.0,
max_retries_per_poll: int = 3, max_retries_per_poll: int = 3,
retry_delay_per_poll: float = 1.0, retry_delay_per_poll: float = 1.0,
retry_backoff_per_poll: float = 2.0, retry_backoff_per_poll: float = 2.0,
estimated_duration: Optional[int] = None, estimated_duration: int | None = None,
cancel_endpoint: Optional[ApiEndpoint] = None, cancel_endpoint: ApiEndpoint | None = None,
cancel_timeout: float = 10.0, cancel_timeout: float = 10.0,
) -> dict[str, Any]: ) -> dict[str, Any]:
""" """
@ -261,7 +262,7 @@ async def poll_op_raw(
consumed_attempts = 0 # counts only non-queued polls consumed_attempts = 0 # counts only non-queued polls
progress_bar = utils.ProgressBar(100) if progress_extractor else None progress_bar = utils.ProgressBar(100) if progress_extractor else None
last_progress: Optional[int] = None last_progress: int | None = None
state = _PollUIState(started=started, estimated_duration=estimated_duration) state = _PollUIState(started=started, estimated_duration=estimated_duration)
stop_ticker = asyncio.Event() stop_ticker = asyncio.Event()
@ -420,10 +421,10 @@ async def poll_op_raw(
def _display_text( def _display_text(
node_cls: type[IO.ComfyNode], node_cls: type[IO.ComfyNode],
text: Optional[str], text: str | None,
*, *,
status: Optional[Union[str, int]] = None, status: str | int | None = None,
price: Optional[float] = None, price: float | None = None,
) -> None: ) -> None:
display_lines: list[str] = [] display_lines: list[str] = []
if status: if status:
@ -440,13 +441,13 @@ def _display_text(
def _display_time_progress( def _display_time_progress(
node_cls: type[IO.ComfyNode], node_cls: type[IO.ComfyNode],
status: Optional[Union[str, int]], status: str | int | None,
elapsed_seconds: int, elapsed_seconds: int,
estimated_total: Optional[int] = None, estimated_total: int | None = None,
*, *,
price: Optional[float] = None, price: float | None = None,
is_queued: Optional[bool] = None, is_queued: bool | None = None,
processing_elapsed_seconds: Optional[int] = None, processing_elapsed_seconds: int | None = None,
) -> None: ) -> None:
if estimated_total is not None and estimated_total > 0 and is_queued is False: if estimated_total is not None and estimated_total > 0 and is_queued is False:
pe = processing_elapsed_seconds if processing_elapsed_seconds is not None else elapsed_seconds pe = processing_elapsed_seconds if processing_elapsed_seconds is not None else elapsed_seconds
@ -488,7 +489,7 @@ def _unpack_tuple(t: tuple) -> tuple[str, Any, str]:
raise ValueError("files tuple must be (filename, file[, content_type])") raise ValueError("files tuple must be (filename, file[, content_type])")
def _merge_params(endpoint_params: dict[str, Any], method: str, data: Optional[dict[str, Any]]) -> dict[str, Any]: def _merge_params(endpoint_params: dict[str, Any], method: str, data: dict[str, Any] | None) -> dict[str, Any]:
params = dict(endpoint_params or {}) params = dict(endpoint_params or {})
if method.upper() == "GET" and data: if method.upper() == "GET" and data:
for k, v in data.items(): for k, v in data.items():
@ -534,9 +535,9 @@ def _generate_operation_id(method: str, path: str, attempt: int) -> str:
def _snapshot_request_body_for_logging( def _snapshot_request_body_for_logging(
content_type: str, content_type: str,
method: str, method: str,
data: Optional[dict[str, Any]], data: dict[str, Any] | None,
files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]], files: dict[str, Any] | list[tuple[str, Any]] | None,
) -> Optional[Union[dict[str, Any], str]]: ) -> dict[str, Any] | str | None:
if method.upper() == "GET": if method.upper() == "GET":
return None return None
if content_type == "multipart/form-data": if content_type == "multipart/form-data":
@ -586,13 +587,13 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
attempt = 0 attempt = 0
delay = cfg.retry_delay delay = cfg.retry_delay
operation_succeeded: bool = False operation_succeeded: bool = False
final_elapsed_seconds: Optional[int] = None final_elapsed_seconds: int | None = None
extracted_price: Optional[float] = None extracted_price: float | None = None
while True: while True:
attempt += 1 attempt += 1
stop_event = asyncio.Event() stop_event = asyncio.Event()
monitor_task: Optional[asyncio.Task] = None monitor_task: asyncio.Task | None = None
sess: Optional[aiohttp.ClientSession] = None sess: aiohttp.ClientSession | None = None
operation_id = _generate_operation_id(method, cfg.endpoint.path, attempt) operation_id = _generate_operation_id(method, cfg.endpoint.path, attempt)
logging.debug("[DEBUG] HTTP %s %s (attempt %d)", method, url, attempt) logging.debug("[DEBUG] HTTP %s %s (attempt %d)", method, url, attempt)
@ -887,7 +888,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
) )
def _validate_or_raise(response_model: Type[M], payload: Any) -> M: def _validate_or_raise(response_model: type[M], payload: Any) -> M:
try: try:
return response_model.model_validate(payload) return response_model.model_validate(payload)
except Exception as e: except Exception as e:
@ -902,9 +903,9 @@ def _validate_or_raise(response_model: Type[M], payload: Any) -> M:
def _wrap_model_extractor( def _wrap_model_extractor(
response_model: Type[M], response_model: type[M],
extractor: Optional[Callable[[M], Any]], extractor: Callable[[M], Any] | None,
) -> Optional[Callable[[dict[str, Any]], Any]]: ) -> Callable[[dict[str, Any]], Any] | None:
"""Wrap a typed extractor so it can be used by the dict-based poller. """Wrap a typed extractor so it can be used by the dict-based poller.
Validates the dict into `response_model` before invoking `extractor`. Validates the dict into `response_model` before invoking `extractor`.
Uses a small per-wrapper cache keyed by `id(dict)` to avoid re-validating Uses a small per-wrapper cache keyed by `id(dict)` to avoid re-validating
@ -929,10 +930,10 @@ def _wrap_model_extractor(
return _wrapped return _wrapped
def _normalize_statuses(values: Optional[Iterable[Union[str, int]]]) -> set[Union[str, int]]: def _normalize_statuses(values: Iterable[str | int] | None) -> set[str | int]:
if not values: if not values:
return set() return set()
out: set[Union[str, int]] = set() out: set[str | int] = set()
for v in values: for v in values:
nv = _normalize_status_value(v) nv = _normalize_status_value(v)
if nv is not None: if nv is not None:
@ -940,7 +941,7 @@ def _normalize_statuses(values: Optional[Iterable[Union[str, int]]]) -> set[Unio
return out return out
def _normalize_status_value(val: Union[str, int, None]) -> Union[str, int, None]: def _normalize_status_value(val: str | int | None) -> str | int | None:
if isinstance(val, str): if isinstance(val, str):
return val.strip().lower() return val.strip().lower()
return val return val

View File

@ -4,7 +4,6 @@ import math
import mimetypes import mimetypes
import uuid import uuid
from io import BytesIO from io import BytesIO
from typing import Optional
import av import av
import numpy as np import numpy as np
@ -12,8 +11,7 @@ import torch
from PIL import Image from PIL import Image
from comfy.utils import common_upscale from comfy.utils import common_upscale
from comfy_api.latest import Input, InputImpl from comfy_api.latest import Input, InputImpl, Types
from comfy_api.util import VideoCodec, VideoContainer
from ._helpers import mimetype_to_extension from ._helpers import mimetype_to_extension
@ -57,7 +55,7 @@ def image_tensor_pair_to_batch(image1: torch.Tensor, image2: torch.Tensor) -> to
def tensor_to_bytesio( def tensor_to_bytesio(
image: torch.Tensor, image: torch.Tensor,
name: Optional[str] = None, name: str | None = None,
total_pixels: int = 2048 * 2048, total_pixels: int = 2048 * 2048,
mime_type: str = "image/png", mime_type: str = "image/png",
) -> BytesIO: ) -> BytesIO:
@ -177,8 +175,8 @@ def audio_to_base64_string(audio: Input.Audio, container_format: str = "mp4", co
def video_to_base64_string( def video_to_base64_string(
video: Input.Video, video: Input.Video,
container_format: VideoContainer = None, container_format: Types.VideoContainer | None = None,
codec: VideoCodec = None codec: Types.VideoCodec | None = None,
) -> str: ) -> str:
""" """
Converts a video input to a base64 string. Converts a video input to a base64 string.
@ -189,12 +187,11 @@ def video_to_base64_string(
codec: Optional codec to use (defaults to video.codec if available) codec: Optional codec to use (defaults to video.codec if available)
""" """
video_bytes_io = BytesIO() video_bytes_io = BytesIO()
video.save_to(
# Use provided format/codec if specified, otherwise use video's own if available video_bytes_io,
format_to_use = container_format if container_format is not None else getattr(video, 'container', VideoContainer.MP4) format=container_format or getattr(video, "container", Types.VideoContainer.MP4),
codec_to_use = codec if codec is not None else getattr(video, 'codec', VideoCodec.H264) codec=codec or getattr(video, "codec", Types.VideoCodec.H264),
)
video.save_to(video_bytes_io, format=format_to_use, codec=codec_to_use)
video_bytes_io.seek(0) video_bytes_io.seek(0)
return base64.b64encode(video_bytes_io.getvalue()).decode("utf-8") return base64.b64encode(video_bytes_io.getvalue()).decode("utf-8")

View File

@ -3,15 +3,15 @@ import contextlib
import uuid import uuid
from io import BytesIO from io import BytesIO
from pathlib import Path from pathlib import Path
from typing import IO, Optional, Union from typing import IO
from urllib.parse import urljoin, urlparse from urllib.parse import urljoin, urlparse
import aiohttp import aiohttp
import torch import torch
from aiohttp.client_exceptions import ClientError, ContentTypeError from aiohttp.client_exceptions import ClientError, ContentTypeError
from comfy_api.input_impl import VideoFromFile
from comfy_api.latest import IO as COMFY_IO from comfy_api.latest import IO as COMFY_IO
from comfy_api.latest import InputImpl
from . import request_logger from . import request_logger
from ._helpers import ( from ._helpers import (
@ -29,9 +29,9 @@ _RETRY_STATUS = {408, 429, 500, 502, 503, 504}
async def download_url_to_bytesio( async def download_url_to_bytesio(
url: str, url: str,
dest: Optional[Union[BytesIO, IO[bytes], str, Path]], dest: BytesIO | IO[bytes] | str | Path | None,
*, *,
timeout: Optional[float] = None, timeout: float | None = None,
max_retries: int = 5, max_retries: int = 5,
retry_delay: float = 1.0, retry_delay: float = 1.0,
retry_backoff: float = 2.0, retry_backoff: float = 2.0,
@ -71,10 +71,10 @@ async def download_url_to_bytesio(
is_path_sink = isinstance(dest, (str, Path)) is_path_sink = isinstance(dest, (str, Path))
fhandle = None fhandle = None
session: Optional[aiohttp.ClientSession] = None session: aiohttp.ClientSession | None = None
stop_evt: Optional[asyncio.Event] = None stop_evt: asyncio.Event | None = None
monitor_task: Optional[asyncio.Task] = None monitor_task: asyncio.Task | None = None
req_task: Optional[asyncio.Task] = None req_task: asyncio.Task | None = None
try: try:
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
@ -234,11 +234,11 @@ async def download_url_to_video_output(
timeout: float = None, timeout: float = None,
max_retries: int = 5, max_retries: int = 5,
cls: type[COMFY_IO.ComfyNode] = None, cls: type[COMFY_IO.ComfyNode] = None,
) -> VideoFromFile: ) -> InputImpl.VideoFromFile:
"""Downloads a video from a URL and returns a `VIDEO` output.""" """Downloads a video from a URL and returns a `VIDEO` output."""
result = BytesIO() result = BytesIO()
await download_url_to_bytesio(video_url, result, timeout=timeout, max_retries=max_retries, cls=cls) await download_url_to_bytesio(video_url, result, timeout=timeout, max_retries=max_retries, cls=cls)
return VideoFromFile(result) return InputImpl.VideoFromFile(result)
async def download_url_as_bytesio( async def download_url_as_bytesio(

View File

@ -1,5 +1,3 @@
from __future__ import annotations
import datetime import datetime
import hashlib import hashlib
import json import json

View File

@ -4,15 +4,13 @@ import logging
import time import time
import uuid import uuid
from io import BytesIO from io import BytesIO
from typing import Optional
from urllib.parse import urlparse from urllib.parse import urlparse
import aiohttp import aiohttp
import torch import torch
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from comfy_api.latest import IO, Input from comfy_api.latest import IO, Input, Types
from comfy_api.util import VideoCodec, VideoContainer
from . import request_logger from . import request_logger
from ._helpers import is_processing_interrupted, sleep_with_interrupt from ._helpers import is_processing_interrupted, sleep_with_interrupt
@ -32,7 +30,7 @@ from .conversions import (
class UploadRequest(BaseModel): class UploadRequest(BaseModel):
file_name: str = Field(..., description="Filename to upload") file_name: str = Field(..., description="Filename to upload")
content_type: Optional[str] = Field( content_type: str | None = Field(
None, None,
description="Mime type of the file. For example: image/png, image/jpeg, video/mp4, etc.", description="Mime type of the file. For example: image/png, image/jpeg, video/mp4, etc.",
) )
@ -56,7 +54,7 @@ async def upload_images_to_comfyapi(
Uploads images to ComfyUI API and returns download URLs. Uploads images to ComfyUI API and returns download URLs.
To upload multiple images, stack them in the batch dimension first. To upload multiple images, stack them in the batch dimension first.
""" """
# if batch, try to upload each file if max_images is greater than 0 # if batched, try to upload each file if max_images is greater than 0
download_urls: list[str] = [] download_urls: list[str] = []
is_batch = len(image.shape) > 3 is_batch = len(image.shape) > 3
batch_len = image.shape[0] if is_batch else 1 batch_len = image.shape[0] if is_batch else 1
@ -100,9 +98,9 @@ async def upload_video_to_comfyapi(
cls: type[IO.ComfyNode], cls: type[IO.ComfyNode],
video: Input.Video, video: Input.Video,
*, *,
container: VideoContainer = VideoContainer.MP4, container: Types.VideoContainer = Types.VideoContainer.MP4,
codec: VideoCodec = VideoCodec.H264, codec: Types.VideoCodec = Types.VideoCodec.H264,
max_duration: Optional[int] = None, max_duration: int | None = None,
wait_label: str | None = "Uploading", wait_label: str | None = "Uploading",
) -> str: ) -> str:
""" """
@ -220,7 +218,7 @@ async def upload_file(
return return
monitor_task = asyncio.create_task(_monitor()) monitor_task = asyncio.create_task(_monitor())
sess: Optional[aiohttp.ClientSession] = None sess: aiohttp.ClientSession | None = None
try: try:
try: try:
request_logger.log_request_response( request_logger.log_request_response(

View File

@ -1,9 +1,7 @@
import logging import logging
from typing import Optional
import torch import torch
from comfy_api.input.video_types import VideoInput
from comfy_api.latest import Input from comfy_api.latest import Input
@ -18,10 +16,10 @@ def get_image_dimensions(image: torch.Tensor) -> tuple[int, int]:
def validate_image_dimensions( def validate_image_dimensions(
image: torch.Tensor, image: torch.Tensor,
min_width: Optional[int] = None, min_width: int | None = None,
max_width: Optional[int] = None, max_width: int | None = None,
min_height: Optional[int] = None, min_height: int | None = None,
max_height: Optional[int] = None, max_height: int | None = None,
): ):
height, width = get_image_dimensions(image) height, width = get_image_dimensions(image)
@ -37,8 +35,8 @@ def validate_image_dimensions(
def validate_image_aspect_ratio( def validate_image_aspect_ratio(
image: torch.Tensor, image: torch.Tensor,
min_ratio: Optional[tuple[float, float]] = None, # e.g. (1, 4) min_ratio: tuple[float, float] | None = None, # e.g. (1, 4)
max_ratio: Optional[tuple[float, float]] = None, # e.g. (4, 1) max_ratio: tuple[float, float] | None = None, # e.g. (4, 1)
*, *,
strict: bool = True, # True -> (min, max); False -> [min, max] strict: bool = True, # True -> (min, max); False -> [min, max]
) -> float: ) -> float:
@ -84,8 +82,8 @@ def validate_images_aspect_ratio_closeness(
def validate_aspect_ratio_string( def validate_aspect_ratio_string(
aspect_ratio: str, aspect_ratio: str,
min_ratio: Optional[tuple[float, float]] = None, # e.g. (1, 4) min_ratio: tuple[float, float] | None = None, # e.g. (1, 4)
max_ratio: Optional[tuple[float, float]] = None, # e.g. (4, 1) max_ratio: tuple[float, float] | None = None, # e.g. (4, 1)
*, *,
strict: bool = False, # True -> (min, max); False -> [min, max] strict: bool = False, # True -> (min, max); False -> [min, max]
) -> float: ) -> float:
@ -97,10 +95,10 @@ def validate_aspect_ratio_string(
def validate_video_dimensions( def validate_video_dimensions(
video: Input.Video, video: Input.Video,
min_width: Optional[int] = None, min_width: int | None = None,
max_width: Optional[int] = None, max_width: int | None = None,
min_height: Optional[int] = None, min_height: int | None = None,
max_height: Optional[int] = None, max_height: int | None = None,
): ):
try: try:
width, height = video.get_dimensions() width, height = video.get_dimensions()
@ -120,8 +118,8 @@ def validate_video_dimensions(
def validate_video_duration( def validate_video_duration(
video: Input.Video, video: Input.Video,
min_duration: Optional[float] = None, min_duration: float | None = None,
max_duration: Optional[float] = None, max_duration: float | None = None,
): ):
try: try:
duration = video.get_duration() duration = video.get_duration()
@ -136,6 +134,23 @@ def validate_video_duration(
raise ValueError(f"Video duration must be at most {max_duration}s, got {duration}s") raise ValueError(f"Video duration must be at most {max_duration}s, got {duration}s")
def validate_video_frame_count(
video: Input.Video,
min_frame_count: int | None = None,
max_frame_count: int | None = None,
):
try:
frame_count = video.get_frame_count()
except Exception as e:
logging.error("Error getting frame count of video: %s", e)
return
if min_frame_count is not None and min_frame_count > frame_count:
raise ValueError(f"Video frame count must be at least {min_frame_count}, got {frame_count}")
if max_frame_count is not None and frame_count > max_frame_count:
raise ValueError(f"Video frame count must be at most {max_frame_count}, got {frame_count}")
def get_number_of_images(images): def get_number_of_images(images):
if isinstance(images, torch.Tensor): if isinstance(images, torch.Tensor):
return images.shape[0] if images.ndim >= 4 else 1 return images.shape[0] if images.ndim >= 4 else 1
@ -144,8 +159,8 @@ def get_number_of_images(images):
def validate_audio_duration( def validate_audio_duration(
audio: Input.Audio, audio: Input.Audio,
min_duration: Optional[float] = None, min_duration: float | None = None,
max_duration: Optional[float] = None, max_duration: float | None = None,
) -> None: ) -> None:
sr = int(audio["sample_rate"]) sr = int(audio["sample_rate"])
dur = int(audio["waveform"].shape[-1]) / sr dur = int(audio["waveform"].shape[-1]) / sr
@ -177,7 +192,7 @@ def validate_string(
) )
def validate_container_format_is_mp4(video: VideoInput) -> None: def validate_container_format_is_mp4(video: Input.Video) -> None:
"""Validates video container format is MP4.""" """Validates video container format is MP4."""
container_format = video.get_container_format() container_format = video.get_container_format()
if container_format not in ["mp4", "mov,mp4,m4a,3gp,3g2,mj2"]: if container_format not in ["mp4", "mov,mp4,m4a,3gp,3g2,mj2"]:
@ -194,8 +209,8 @@ def _ratio_from_tuple(r: tuple[float, float]) -> float:
def _assert_ratio_bounds( def _assert_ratio_bounds(
ar: float, ar: float,
*, *,
min_ratio: Optional[tuple[float, float]] = None, min_ratio: tuple[float, float] | None = None,
max_ratio: Optional[tuple[float, float]] = None, max_ratio: tuple[float, float] | None = None,
strict: bool = True, strict: bool = True,
) -> None: ) -> None:
"""Validate a numeric aspect ratio against optional min/max ratio bounds.""" """Validate a numeric aspect ratio against optional min/max ratio bounds."""