[API Nodes]: fixes and refactor (#11104)
* chore(api-nodes): applied ruff's pyupgrade (python3.10) rules to the api-nodes client folder
* chore(api-nodes): add the validate_video_frame_count function from the LTX PR
* chore(api-nodes): replace deprecated V1 imports
* fix(api-nodes): the types returned by the "poll_op" function are now correct
This commit is contained in:
parent 9bc893c5bb
commit 3c8456223c
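Most of the diff below is mechanical: ruff's pyupgrade rules (targeting Python 3.10) rewrite typing.Optional/typing.Union annotations into PEP 604 unions and move Callable/Iterable imports from typing to collections.abc. A minimal before/after sketch of the pattern, modeled on get_fs_object_size but not taken verbatim from the diff:

# Before: pre-3.10 typing style
from io import BytesIO
from typing import Optional, Union

def get_size(path_or_object: Union[str, BytesIO]) -> Optional[int]:
    ...

# After: PEP 604 unions, as produced by ruff's pyupgrade on a py310 target
from io import BytesIO

def get_size(path_or_object: str | BytesIO) -> int | None:
    ...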
@@ -47,6 +47,7 @@ from .validation_utils import (
     validate_string,
     validate_video_dimensions,
     validate_video_duration,
+    validate_video_frame_count,
 )
 
 __all__ = [
@@ -94,6 +95,7 @@ __all__ = [
     "validate_string",
     "validate_video_dimensions",
     "validate_video_duration",
+    "validate_video_frame_count",
     # Misc functions
     "get_fs_object_size",
 ]

@@ -2,8 +2,8 @@ import asyncio
 import contextlib
 import os
 import time
+from collections.abc import Callable
 from io import BytesIO
-from typing import Callable, Optional, Union
 
 from comfy.cli_args import args
 from comfy.model_management import processing_interrupted
@@ -35,12 +35,12 @@ def default_base_url() -> str:
 
 async def sleep_with_interrupt(
     seconds: float,
-    node_cls: Optional[type[IO.ComfyNode]],
-    label: Optional[str] = None,
-    start_ts: Optional[float] = None,
-    estimated_total: Optional[int] = None,
+    node_cls: type[IO.ComfyNode] | None,
+    label: str | None = None,
+    start_ts: float | None = None,
+    estimated_total: int | None = None,
     *,
-    display_callback: Optional[Callable[[type[IO.ComfyNode], str, int, Optional[int]], None]] = None,
+    display_callback: Callable[[type[IO.ComfyNode], str, int, int | None], None] | None = None,
 ):
     """
     Sleep in 1s slices while:
@@ -65,7 +65,7 @@ def mimetype_to_extension(mime_type: str) -> str:
     return mime_type.split("/")[-1].lower()
 
 
-def get_fs_object_size(path_or_object: Union[str, BytesIO]) -> int:
+def get_fs_object_size(path_or_object: str | BytesIO) -> int:
     if isinstance(path_or_object, str):
         return os.path.getsize(path_or_object)
     return len(path_or_object.getvalue())

@@ -4,10 +4,11 @@ import json
 import logging
 import time
 import uuid
+from collections.abc import Callable, Iterable
 from dataclasses import dataclass
 from enum import Enum
 from io import BytesIO
-from typing import Any, Callable, Iterable, Literal, Optional, Type, TypeVar, Union
+from typing import Any, Literal, TypeVar
 from urllib.parse import urljoin, urlparse
 
 import aiohttp
@@ -37,8 +38,8 @@ class ApiEndpoint:
         path: str,
         method: Literal["GET", "POST", "PUT", "DELETE", "PATCH"] = "GET",
         *,
-        query_params: Optional[dict[str, Any]] = None,
-        headers: Optional[dict[str, str]] = None,
+        query_params: dict[str, Any] | None = None,
+        headers: dict[str, str] | None = None,
     ):
         self.path = path
         self.method = method
@@ -52,18 +53,18 @@ class _RequestConfig:
     endpoint: ApiEndpoint
     timeout: float
     content_type: str
-    data: Optional[dict[str, Any]]
-    files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]]
-    multipart_parser: Optional[Callable]
+    data: dict[str, Any] | None
+    files: dict[str, Any] | list[tuple[str, Any]] | None
+    multipart_parser: Callable | None
     max_retries: int
     retry_delay: float
     retry_backoff: float
     wait_label: str = "Waiting"
     monitor_progress: bool = True
-    estimated_total: Optional[int] = None
-    final_label_on_success: Optional[str] = "Completed"
-    progress_origin_ts: Optional[float] = None
-    price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None
+    estimated_total: int | None = None
+    final_label_on_success: str | None = "Completed"
+    progress_origin_ts: float | None = None
+    price_extractor: Callable[[dict[str, Any]], float | None] | None = None
 
 
 @dataclass
@@ -71,10 +72,10 @@ class _PollUIState:
     started: float
     status_label: str = "Queued"
     is_queued: bool = True
-    price: Optional[float] = None
-    estimated_duration: Optional[int] = None
+    price: float | None = None
+    estimated_duration: int | None = None
     base_processing_elapsed: float = 0.0  # sum of completed active intervals
-    active_since: Optional[float] = None  # start time of current active interval (None if queued)
+    active_since: float | None = None  # start time of current active interval (None if queued)
 
 
 _RETRY_STATUS = {408, 429, 500, 502, 503, 504}
@@ -87,20 +88,20 @@ async def sync_op(
     cls: type[IO.ComfyNode],
     endpoint: ApiEndpoint,
     *,
-    response_model: Type[M],
-    price_extractor: Optional[Callable[[M], Optional[float]]] = None,
-    data: Optional[BaseModel] = None,
-    files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None,
+    response_model: type[M],
+    price_extractor: Callable[[M | Any], float | None] | None = None,
+    data: BaseModel | None = None,
+    files: dict[str, Any] | list[tuple[str, Any]] | None = None,
     content_type: str = "application/json",
     timeout: float = 3600.0,
-    multipart_parser: Optional[Callable] = None,
+    multipart_parser: Callable | None = None,
     max_retries: int = 3,
     retry_delay: float = 1.0,
     retry_backoff: float = 2.0,
     wait_label: str = "Waiting for server",
-    estimated_duration: Optional[int] = None,
-    final_label_on_success: Optional[str] = "Completed",
-    progress_origin_ts: Optional[float] = None,
+    estimated_duration: int | None = None,
+    final_label_on_success: str | None = "Completed",
+    progress_origin_ts: float | None = None,
     monitor_progress: bool = True,
 ) -> M:
     raw = await sync_op_raw(
@@ -131,22 +132,22 @@ async def poll_op(
     cls: type[IO.ComfyNode],
     poll_endpoint: ApiEndpoint,
     *,
-    response_model: Type[M],
-    status_extractor: Callable[[M], Optional[Union[str, int]]],
-    progress_extractor: Optional[Callable[[M], Optional[int]]] = None,
-    price_extractor: Optional[Callable[[M], Optional[float]]] = None,
-    completed_statuses: Optional[list[Union[str, int]]] = None,
-    failed_statuses: Optional[list[Union[str, int]]] = None,
-    queued_statuses: Optional[list[Union[str, int]]] = None,
-    data: Optional[BaseModel] = None,
+    response_model: type[M],
+    status_extractor: Callable[[M | Any], str | int | None],
+    progress_extractor: Callable[[M | Any], int | None] | None = None,
+    price_extractor: Callable[[M | Any], float | None] | None = None,
+    completed_statuses: list[str | int] | None = None,
+    failed_statuses: list[str | int] | None = None,
+    queued_statuses: list[str | int] | None = None,
+    data: BaseModel | None = None,
     poll_interval: float = 5.0,
     max_poll_attempts: int = 120,
     timeout_per_poll: float = 120.0,
     max_retries_per_poll: int = 3,
     retry_delay_per_poll: float = 1.0,
     retry_backoff_per_poll: float = 2.0,
-    estimated_duration: Optional[int] = None,
-    cancel_endpoint: Optional[ApiEndpoint] = None,
+    estimated_duration: int | None = None,
+    cancel_endpoint: ApiEndpoint | None = None,
     cancel_timeout: float = 10.0,
 ) -> M:
     raw = await poll_op_raw(
@@ -178,22 +179,22 @@ async def sync_op_raw(
     cls: type[IO.ComfyNode],
     endpoint: ApiEndpoint,
     *,
-    price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None,
-    data: Optional[Union[dict[str, Any], BaseModel]] = None,
-    files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None,
+    price_extractor: Callable[[dict[str, Any]], float | None] | None = None,
+    data: dict[str, Any] | BaseModel | None = None,
+    files: dict[str, Any] | list[tuple[str, Any]] | None = None,
     content_type: str = "application/json",
     timeout: float = 3600.0,
-    multipart_parser: Optional[Callable] = None,
+    multipart_parser: Callable | None = None,
     max_retries: int = 3,
     retry_delay: float = 1.0,
     retry_backoff: float = 2.0,
     wait_label: str = "Waiting for server",
-    estimated_duration: Optional[int] = None,
+    estimated_duration: int | None = None,
     as_binary: bool = False,
-    final_label_on_success: Optional[str] = "Completed",
-    progress_origin_ts: Optional[float] = None,
+    final_label_on_success: str | None = "Completed",
+    progress_origin_ts: float | None = None,
     monitor_progress: bool = True,
-) -> Union[dict[str, Any], bytes]:
+) -> dict[str, Any] | bytes:
     """
     Make a single network request.
     - If as_binary=False (default): returns JSON dict (or {'_raw': '<text>'} if non-JSON).
@@ -229,21 +230,21 @@ async def poll_op_raw(
     cls: type[IO.ComfyNode],
     poll_endpoint: ApiEndpoint,
     *,
-    status_extractor: Callable[[dict[str, Any]], Optional[Union[str, int]]],
-    progress_extractor: Optional[Callable[[dict[str, Any]], Optional[int]]] = None,
-    price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None,
-    completed_statuses: Optional[list[Union[str, int]]] = None,
-    failed_statuses: Optional[list[Union[str, int]]] = None,
-    queued_statuses: Optional[list[Union[str, int]]] = None,
-    data: Optional[Union[dict[str, Any], BaseModel]] = None,
+    status_extractor: Callable[[dict[str, Any]], str | int | None],
+    progress_extractor: Callable[[dict[str, Any]], int | None] | None = None,
+    price_extractor: Callable[[dict[str, Any]], float | None] | None = None,
+    completed_statuses: list[str | int] | None = None,
+    failed_statuses: list[str | int] | None = None,
+    queued_statuses: list[str | int] | None = None,
+    data: dict[str, Any] | BaseModel | None = None,
     poll_interval: float = 5.0,
     max_poll_attempts: int = 120,
     timeout_per_poll: float = 120.0,
     max_retries_per_poll: int = 3,
     retry_delay_per_poll: float = 1.0,
     retry_backoff_per_poll: float = 2.0,
-    estimated_duration: Optional[int] = None,
-    cancel_endpoint: Optional[ApiEndpoint] = None,
+    estimated_duration: int | None = None,
+    cancel_endpoint: ApiEndpoint | None = None,
     cancel_timeout: float = 10.0,
 ) -> dict[str, Any]:
     """
@@ -261,7 +262,7 @@ async def poll_op_raw(
     consumed_attempts = 0  # counts only non-queued polls
 
     progress_bar = utils.ProgressBar(100) if progress_extractor else None
-    last_progress: Optional[int] = None
+    last_progress: int | None = None
 
     state = _PollUIState(started=started, estimated_duration=estimated_duration)
     stop_ticker = asyncio.Event()
@@ -420,10 +421,10 @@ async def poll_op_raw(
 
 def _display_text(
     node_cls: type[IO.ComfyNode],
-    text: Optional[str],
+    text: str | None,
     *,
-    status: Optional[Union[str, int]] = None,
-    price: Optional[float] = None,
+    status: str | int | None = None,
+    price: float | None = None,
 ) -> None:
     display_lines: list[str] = []
     if status:
@@ -440,13 +441,13 @@ def _display_text(
 
 def _display_time_progress(
     node_cls: type[IO.ComfyNode],
-    status: Optional[Union[str, int]],
+    status: str | int | None,
     elapsed_seconds: int,
-    estimated_total: Optional[int] = None,
+    estimated_total: int | None = None,
     *,
-    price: Optional[float] = None,
-    is_queued: Optional[bool] = None,
-    processing_elapsed_seconds: Optional[int] = None,
+    price: float | None = None,
+    is_queued: bool | None = None,
+    processing_elapsed_seconds: int | None = None,
 ) -> None:
     if estimated_total is not None and estimated_total > 0 and is_queued is False:
         pe = processing_elapsed_seconds if processing_elapsed_seconds is not None else elapsed_seconds
@@ -488,7 +489,7 @@ def _unpack_tuple(t: tuple) -> tuple[str, Any, str]:
     raise ValueError("files tuple must be (filename, file[, content_type])")
 
 
-def _merge_params(endpoint_params: dict[str, Any], method: str, data: Optional[dict[str, Any]]) -> dict[str, Any]:
+def _merge_params(endpoint_params: dict[str, Any], method: str, data: dict[str, Any] | None) -> dict[str, Any]:
     params = dict(endpoint_params or {})
     if method.upper() == "GET" and data:
         for k, v in data.items():
@@ -534,9 +535,9 @@ def _generate_operation_id(method: str, path: str, attempt: int) -> str:
 def _snapshot_request_body_for_logging(
     content_type: str,
     method: str,
-    data: Optional[dict[str, Any]],
-    files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]],
-) -> Optional[Union[dict[str, Any], str]]:
+    data: dict[str, Any] | None,
+    files: dict[str, Any] | list[tuple[str, Any]] | None,
+) -> dict[str, Any] | str | None:
     if method.upper() == "GET":
         return None
     if content_type == "multipart/form-data":
@@ -586,13 +587,13 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
     attempt = 0
     delay = cfg.retry_delay
     operation_succeeded: bool = False
-    final_elapsed_seconds: Optional[int] = None
-    extracted_price: Optional[float] = None
+    final_elapsed_seconds: int | None = None
+    extracted_price: float | None = None
     while True:
         attempt += 1
         stop_event = asyncio.Event()
-        monitor_task: Optional[asyncio.Task] = None
-        sess: Optional[aiohttp.ClientSession] = None
+        monitor_task: asyncio.Task | None = None
+        sess: aiohttp.ClientSession | None = None
 
         operation_id = _generate_operation_id(method, cfg.endpoint.path, attempt)
         logging.debug("[DEBUG] HTTP %s %s (attempt %d)", method, url, attempt)
@@ -887,7 +888,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
     )
 
 
-def _validate_or_raise(response_model: Type[M], payload: Any) -> M:
+def _validate_or_raise(response_model: type[M], payload: Any) -> M:
     try:
         return response_model.model_validate(payload)
     except Exception as e:
@@ -902,9 +903,9 @@ def _validate_or_raise(response_model: Type[M], payload: Any) -> M:
 
 
 def _wrap_model_extractor(
-    response_model: Type[M],
-    extractor: Optional[Callable[[M], Any]],
-) -> Optional[Callable[[dict[str, Any]], Any]]:
+    response_model: type[M],
+    extractor: Callable[[M], Any] | None,
+) -> Callable[[dict[str, Any]], Any] | None:
     """Wrap a typed extractor so it can be used by the dict-based poller.
     Validates the dict into `response_model` before invoking `extractor`.
     Uses a small per-wrapper cache keyed by `id(dict)` to avoid re-validating
@@ -929,10 +930,10 @@ def _wrap_model_extractor(
     return _wrapped
 
 
-def _normalize_statuses(values: Optional[Iterable[Union[str, int]]]) -> set[Union[str, int]]:
+def _normalize_statuses(values: Iterable[str | int] | None) -> set[str | int]:
     if not values:
         return set()
-    out: set[Union[str, int]] = set()
+    out: set[str | int] = set()
     for v in values:
         nv = _normalize_status_value(v)
         if nv is not None:
@@ -940,7 +941,7 @@ def _normalize_statuses(values: Optional[Iterable[Union[str, int]]]) -> set[Union[str, int]]:
     return out
 
 
-def _normalize_status_value(val: Union[str, int, None]) -> Union[str, int, None]:
+def _normalize_status_value(val: str | int | None) -> str | int | None:
     if isinstance(val, str):
         return val.strip().lower()
     return val
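
The poll_op fix called out in the commit message is visible in the hunks above: response_model is now type[M], the extractors receive the validated model (typed M | Any), and poll_op returns M via _validate_or_raise. A sketch of a caller under the new signatures (the TaskResponse model, the task path, and the status strings are invented for illustration; poll_op and ApiEndpoint are the names defined in the client module patched above):

from pydantic import BaseModel

class TaskResponse(BaseModel):
    status: str
    progress: int | None = None

async def wait_for_task(cls, task_id: str) -> TaskResponse:
    # status_extractor/progress_extractor get a validated TaskResponse, not a raw dict
    return await poll_op(
        cls,
        ApiEndpoint(f"/proxy/tasks/{task_id}"),  # hypothetical endpoint path
        response_model=TaskResponse,
        status_extractor=lambda r: r.status,
        progress_extractor=lambda r: r.progress,
        completed_statuses=["completed"],
        failed_statuses=["failed"],
    )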

@@ -4,7 +4,6 @@ import math
 import mimetypes
 import uuid
 from io import BytesIO
-from typing import Optional
 
 import av
 import numpy as np
@@ -12,8 +11,8 @@ import torch
 from PIL import Image
 
 from comfy.utils import common_upscale
-from comfy_api.latest import Input, InputImpl
-from comfy_api.util import VideoCodec, VideoContainer
+from comfy_api.latest import Input, InputImpl, Types
 
 from ._helpers import mimetype_to_extension
 
@@ -57,7 +55,7 @@ def image_tensor_pair_to_batch(image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor:
 
 def tensor_to_bytesio(
     image: torch.Tensor,
-    name: Optional[str] = None,
+    name: str | None = None,
     total_pixels: int = 2048 * 2048,
     mime_type: str = "image/png",
 ) -> BytesIO:
@@ -177,8 +175,8 @@ def audio_to_base64_string(
 
 def video_to_base64_string(
     video: Input.Video,
-    container_format: VideoContainer = None,
-    codec: VideoCodec = None
+    container_format: Types.VideoContainer | None = None,
+    codec: Types.VideoCodec | None = None,
 ) -> str:
     """
     Converts a video input to a base64 string.
@@ -189,12 +187,11 @@ def video_to_base64_string(
         codec: Optional codec to use (defaults to video.codec if available)
     """
     video_bytes_io = BytesIO()
-    # Use provided format/codec if specified, otherwise use video's own if available
-    format_to_use = container_format if container_format is not None else getattr(video, 'container', VideoContainer.MP4)
-    codec_to_use = codec if codec is not None else getattr(video, 'codec', VideoCodec.H264)
-    video.save_to(video_bytes_io, format=format_to_use, codec=codec_to_use)
+    video.save_to(
+        video_bytes_io,
+        format=container_format or getattr(video, "container", Types.VideoContainer.MP4),
+        codec=codec or getattr(video, "codec", Types.VideoCodec.H264),
+    )
     video_bytes_io.seek(0)
     return base64.b64encode(video_bytes_io.getvalue()).decode("utf-8")
 

@@ -3,15 +3,15 @@ import contextlib
 import uuid
 from io import BytesIO
 from pathlib import Path
-from typing import IO, Optional, Union
+from typing import IO
 from urllib.parse import urljoin, urlparse
 
 import aiohttp
 import torch
 from aiohttp.client_exceptions import ClientError, ContentTypeError
 
-from comfy_api.input_impl import VideoFromFile
 from comfy_api.latest import IO as COMFY_IO
+from comfy_api.latest import InputImpl
 
 from . import request_logger
 from ._helpers import (
@@ -29,9 +29,9 @@ _RETRY_STATUS = {408, 429, 500, 502, 503, 504}
 
 async def download_url_to_bytesio(
     url: str,
-    dest: Optional[Union[BytesIO, IO[bytes], str, Path]],
+    dest: BytesIO | IO[bytes] | str | Path | None,
     *,
-    timeout: Optional[float] = None,
+    timeout: float | None = None,
     max_retries: int = 5,
     retry_delay: float = 1.0,
     retry_backoff: float = 2.0,
@@ -71,10 +71,10 @@ async def download_url_to_bytesio(
 
     is_path_sink = isinstance(dest, (str, Path))
     fhandle = None
-    session: Optional[aiohttp.ClientSession] = None
-    stop_evt: Optional[asyncio.Event] = None
-    monitor_task: Optional[asyncio.Task] = None
-    req_task: Optional[asyncio.Task] = None
+    session: aiohttp.ClientSession | None = None
+    stop_evt: asyncio.Event | None = None
+    monitor_task: asyncio.Task | None = None
+    req_task: asyncio.Task | None = None
 
     try:
         with contextlib.suppress(Exception):
@@ -234,11 +234,11 @@ async def download_url_to_video_output(
     timeout: float = None,
     max_retries: int = 5,
     cls: type[COMFY_IO.ComfyNode] = None,
-) -> VideoFromFile:
+) -> InputImpl.VideoFromFile:
     """Downloads a video from a URL and returns a `VIDEO` output."""
     result = BytesIO()
     await download_url_to_bytesio(video_url, result, timeout=timeout, max_retries=max_retries, cls=cls)
-    return VideoFromFile(result)
+    return InputImpl.VideoFromFile(result)
 
 
 async def download_url_as_bytesio(

@@ -1,5 +1,3 @@
-from __future__ import annotations
-
 import datetime
 import hashlib
 import json

@@ -4,15 +4,13 @@ import logging
 import time
 import uuid
 from io import BytesIO
-from typing import Optional
 from urllib.parse import urlparse
 
 import aiohttp
 import torch
 from pydantic import BaseModel, Field
 
-from comfy_api.latest import IO, Input
-from comfy_api.util import VideoCodec, VideoContainer
+from comfy_api.latest import IO, Input, Types
 
 from . import request_logger
 from ._helpers import is_processing_interrupted, sleep_with_interrupt
@@ -32,7 +30,7 @@ from .conversions import (
 
 class UploadRequest(BaseModel):
     file_name: str = Field(..., description="Filename to upload")
-    content_type: Optional[str] = Field(
+    content_type: str | None = Field(
         None,
         description="Mime type of the file. For example: image/png, image/jpeg, video/mp4, etc.",
     )
@@ -56,7 +54,7 @@ async def upload_images_to_comfyapi(
     Uploads images to ComfyUI API and returns download URLs.
     To upload multiple images, stack them in the batch dimension first.
     """
-    # if batch, try to upload each file if max_images is greater than 0
+    # if batched, try to upload each file if max_images is greater than 0
     download_urls: list[str] = []
     is_batch = len(image.shape) > 3
     batch_len = image.shape[0] if is_batch else 1
@@ -100,9 +98,9 @@ async def upload_video_to_comfyapi(
     cls: type[IO.ComfyNode],
     video: Input.Video,
     *,
-    container: VideoContainer = VideoContainer.MP4,
-    codec: VideoCodec = VideoCodec.H264,
-    max_duration: Optional[int] = None,
+    container: Types.VideoContainer = Types.VideoContainer.MP4,
+    codec: Types.VideoCodec = Types.VideoCodec.H264,
+    max_duration: int | None = None,
     wait_label: str | None = "Uploading",
 ) -> str:
     """
@@ -220,7 +218,7 @@ async def upload_file(
         return
 
     monitor_task = asyncio.create_task(_monitor())
-    sess: Optional[aiohttp.ClientSession] = None
+    sess: aiohttp.ClientSession | None = None
    try:
        try:
            request_logger.log_request_response(

@@ -1,9 +1,7 @@
 import logging
-from typing import Optional
 
 import torch
 
-from comfy_api.input.video_types import VideoInput
 from comfy_api.latest import Input
 
 
@@ -18,10 +16,10 @@ def get_image_dimensions(image: torch.Tensor) -> tuple[int, int]:
 
 def validate_image_dimensions(
     image: torch.Tensor,
-    min_width: Optional[int] = None,
-    max_width: Optional[int] = None,
-    min_height: Optional[int] = None,
-    max_height: Optional[int] = None,
+    min_width: int | None = None,
+    max_width: int | None = None,
+    min_height: int | None = None,
+    max_height: int | None = None,
 ):
     height, width = get_image_dimensions(image)
 
@@ -37,8 +35,8 @@ def validate_image_dimensions(
 
 def validate_image_aspect_ratio(
     image: torch.Tensor,
-    min_ratio: Optional[tuple[float, float]] = None,  # e.g. (1, 4)
-    max_ratio: Optional[tuple[float, float]] = None,  # e.g. (4, 1)
+    min_ratio: tuple[float, float] | None = None,  # e.g. (1, 4)
+    max_ratio: tuple[float, float] | None = None,  # e.g. (4, 1)
     *,
     strict: bool = True,  # True -> (min, max); False -> [min, max]
 ) -> float:
@@ -54,8 +52,8 @@
 def validate_images_aspect_ratio_closeness(
     first_image: torch.Tensor,
     second_image: torch.Tensor,
     min_rel: float,  # e.g. 0.8
     max_rel: float,  # e.g. 1.25
     *,
     strict: bool = False,  # True -> (min, max); False -> [min, max]
 ) -> float:
@@ -84,8 +82,8 @@ def validate_images_aspect_ratio_closeness(
 
 def validate_aspect_ratio_string(
     aspect_ratio: str,
-    min_ratio: Optional[tuple[float, float]] = None,  # e.g. (1, 4)
-    max_ratio: Optional[tuple[float, float]] = None,  # e.g. (4, 1)
+    min_ratio: tuple[float, float] | None = None,  # e.g. (1, 4)
+    max_ratio: tuple[float, float] | None = None,  # e.g. (4, 1)
     *,
     strict: bool = False,  # True -> (min, max); False -> [min, max]
 ) -> float:
@@ -97,10 +95,10 @@ def validate_aspect_ratio_string(
 
 def validate_video_dimensions(
     video: Input.Video,
-    min_width: Optional[int] = None,
-    max_width: Optional[int] = None,
-    min_height: Optional[int] = None,
-    max_height: Optional[int] = None,
+    min_width: int | None = None,
+    max_width: int | None = None,
+    min_height: int | None = None,
+    max_height: int | None = None,
 ):
     try:
         width, height = video.get_dimensions()
@@ -120,8 +118,8 @@ def validate_video_dimensions(
 
 def validate_video_duration(
     video: Input.Video,
-    min_duration: Optional[float] = None,
-    max_duration: Optional[float] = None,
+    min_duration: float | None = None,
+    max_duration: float | None = None,
 ):
     try:
         duration = video.get_duration()
@@ -136,6 +134,23 @@ def validate_video_duration(
         raise ValueError(f"Video duration must be at most {max_duration}s, got {duration}s")
 
 
+def validate_video_frame_count(
+    video: Input.Video,
+    min_frame_count: int | None = None,
+    max_frame_count: int | None = None,
+):
+    try:
+        frame_count = video.get_frame_count()
+    except Exception as e:
+        logging.error("Error getting frame count of video: %s", e)
+        return
+
+    if min_frame_count is not None and min_frame_count > frame_count:
+        raise ValueError(f"Video frame count must be at least {min_frame_count}, got {frame_count}")
+    if max_frame_count is not None and frame_count > max_frame_count:
+        raise ValueError(f"Video frame count must be at most {max_frame_count}, got {frame_count}")
+
+
 def get_number_of_images(images):
     if isinstance(images, torch.Tensor):
         return images.shape[0] if images.ndim >= 4 else 1
@@ -144,8 +159,8 @@ def get_number_of_images(images):
 
 def validate_audio_duration(
     audio: Input.Audio,
-    min_duration: Optional[float] = None,
-    max_duration: Optional[float] = None,
+    min_duration: float | None = None,
+    max_duration: float | None = None,
 ) -> None:
     sr = int(audio["sample_rate"])
     dur = int(audio["waveform"].shape[-1]) / sr
@@ -177,7 +192,7 @@ def validate_string(
     )
 
 
-def validate_container_format_is_mp4(video: VideoInput) -> None:
+def validate_container_format_is_mp4(video: Input.Video) -> None:
     """Validates video container format is MP4."""
     container_format = video.get_container_format()
     if container_format not in ["mp4", "mov,mp4,m4a,3gp,3g2,mj2"]:
@@ -194,8 +209,8 @@ def _ratio_from_tuple(r: tuple[float, float]) -> float:
 def _assert_ratio_bounds(
     ar: float,
     *,
-    min_ratio: Optional[tuple[float, float]] = None,
-    max_ratio: Optional[tuple[float, float]] = None,
+    min_ratio: tuple[float, float] | None = None,
+    max_ratio: tuple[float, float] | None = None,
     strict: bool = True,
 ) -> None:
     """Validate a numeric aspect ratio against optional min/max ratio bounds."""
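
The new validate_video_frame_count follows the same convention as the other validators: failures to probe the video are logged and swallowed, while out-of-range frame counts raise ValueError. A usage sketch (the relative import matches the package __init__ in the first hunk; the 16/240 bounds are invented):

from .validation_utils import validate_video_frame_count

def check_clip(video) -> None:
    # Raises ValueError if the clip has fewer than 16 or more than 240 frames;
    # returns silently if the frame count cannot be determined.
    validate_video_frame_count(video, min_frame_count=16, max_frame_count=240)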