mirror of
https://git.datalinker.icu/comfyanonymous/ComfyUI
synced 2025-12-08 21:44:33 +08:00
937 lines
38 KiB
Python
937 lines
38 KiB
Python
import asyncio
|
||
import contextlib
|
||
import json
|
||
import logging
|
||
import time
|
||
import uuid
|
||
from dataclasses import dataclass
|
||
from enum import Enum
|
||
from io import BytesIO
|
||
from typing import Any, Callable, Iterable, Literal, Optional, Type, TypeVar, Union
|
||
from urllib.parse import urljoin, urlparse
|
||
|
||
import aiohttp
|
||
from aiohttp.client_exceptions import ClientError, ContentTypeError
|
||
from pydantic import BaseModel
|
||
|
||
from comfy import utils
|
||
from comfy_api.latest import IO
|
||
from comfy_api_nodes.apis import request_logger
|
||
from server import PromptServer
|
||
|
||
from ._helpers import (
|
||
default_base_url,
|
||
get_auth_header,
|
||
get_node_id,
|
||
is_processing_interrupted,
|
||
sleep_with_interrupt,
|
||
)
|
||
from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted
|
||
|
||
# Type variable for the pydantic model class a caller passes as `response_model`;
# it ties that argument to the return type of `sync_op` / `poll_op`.
M = TypeVar("M", bound=BaseModel)
|
||
|
||
|
||
class ApiEndpoint:
    """Plain description of one HTTP endpoint: where to send the request and how.

    Attributes:
        path: Relative path or absolute URL of the endpoint.
        method: HTTP verb used for the request.
        query_params: Query-string parameters (an empty dict when none are given).
        headers: Extra request headers (an empty dict when none are given).
    """

    def __init__(
        self,
        path: str,
        method: Literal["GET", "POST", "PUT", "DELETE", "PATCH"] = "GET",
        *,
        query_params: Optional[dict[str, Any]] = None,
        headers: Optional[dict[str, str]] = None,
    ):
        self.path = path
        self.method = method
        # A missing or empty mapping is replaced by a fresh empty dict.
        self.query_params = query_params if query_params else {}
        self.headers = headers if headers else {}
|
||
|
||
|
||
@dataclass
class _RequestConfig:
    # Bundle of everything `_request_base` needs to perform one (retried) HTTP request.
    node_cls: type[IO.ComfyNode]  # node the request runs for; used for auth headers and progress text
    endpoint: ApiEndpoint  # path/URL, HTTP method, query params and extra headers
    timeout: float  # passed to aiohttp.ClientTimeout(total=...)
    content_type: str  # "multipart/form-data", "application/x-www-form-urlencoded", anything else is sent as JSON
    data: Optional[dict[str, Any]]  # request body fields (merged into the query string for GET)
    files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]]  # multipart attachments
    multipart_parser: Optional[Callable]  # optional builder that must return aiohttp.FormData from `data`
    max_retries: int  # retries allowed for retryable HTTP statuses and connection errors
    retry_delay: float  # delay before the first retry, in seconds
    retry_backoff: float  # multiplier applied to the delay after every retry
    wait_label: str = "Waiting"  # status text shown while the request is in flight
    monitor_progress: bool = True  # run the per-second progress/interruption monitor
    estimated_total: Optional[int] = None  # estimated duration in seconds, used for the "remaining" hint
    final_label_on_success: Optional[str] = "Completed"  # status text shown after success (None/empty: nothing shown)
    progress_origin_ts: Optional[float] = None  # time.monotonic() value to measure elapsed time from (None: request start)
|
||
|
||
|
||
@dataclass
class _PollUIState:
    # Mutable state shared between the polling loop (writer) and its per-second UI ticker (reader).
    started: float  # time.monotonic() value taken when polling began
    status_label: str = "Queued"  # status text currently shown to the user
    is_queued: bool = True  # whether the last poll reported a queued status
    price: Optional[float] = None  # last price reported by the price extractor, if any
    estimated_duration: Optional[int] = None  # estimated processing time in seconds, if known
    base_processing_elapsed: float = 0.0  # sum of completed active intervals
    active_since: Optional[float] = None  # start time of current active interval (None if queued)
|
||
|
||
|
||
_RETRY_STATUS = {408, 429, 500, 502, 503, 504}
|
||
COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed"]
|
||
FAILED_STATUSES = ["cancelled", "canceled", "fail", "failed", "error"]
|
||
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted"]
|
||
|
||
|
||
async def sync_op(
    cls: type[IO.ComfyNode],
    endpoint: ApiEndpoint,
    *,
    response_model: Type[M],
    data: Optional[BaseModel] = None,
    files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None,
    content_type: str = "application/json",
    timeout: float = 3600.0,
    multipart_parser: Optional[Callable] = None,
    max_retries: int = 3,
    retry_delay: float = 1.0,
    retry_backoff: float = 2.0,
    wait_label: str = "Waiting for server",
    estimated_duration: Optional[int] = None,
    final_label_on_success: Optional[str] = "Completed",
    progress_origin_ts: Optional[float] = None,
    monitor_progress: bool = True,
) -> M:
    """Perform one request via `sync_op_raw` and validate the JSON reply into `response_model`.

    All keyword options are forwarded unchanged to `sync_op_raw`; the request is
    always made in JSON mode (`as_binary=False`).

    Raises:
        Exception: if the reply is not a dict, or if it fails model validation.
    """
    forwarded = dict(
        data=data,
        files=files,
        content_type=content_type,
        timeout=timeout,
        multipart_parser=multipart_parser,
        max_retries=max_retries,
        retry_delay=retry_delay,
        retry_backoff=retry_backoff,
        wait_label=wait_label,
        estimated_duration=estimated_duration,
        as_binary=False,
        final_label_on_success=final_label_on_success,
        progress_origin_ts=progress_origin_ts,
        monitor_progress=monitor_progress,
    )
    reply = await sync_op_raw(cls, endpoint, **forwarded)
    if isinstance(reply, dict):
        return _validate_or_raise(response_model, reply)
    raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).")
|
||
|
||
|
||
async def poll_op(
    cls: type[IO.ComfyNode],
    poll_endpoint: ApiEndpoint,
    *,
    response_model: Type[M],
    status_extractor: Callable[[M], Optional[Union[str, int]]],
    progress_extractor: Optional[Callable[[M], Optional[int]]] = None,
    price_extractor: Optional[Callable[[M], Optional[float]]] = None,
    completed_statuses: Optional[list[Union[str, int]]] = None,
    failed_statuses: Optional[list[Union[str, int]]] = None,
    queued_statuses: Optional[list[Union[str, int]]] = None,
    data: Optional[BaseModel] = None,
    poll_interval: float = 5.0,
    max_poll_attempts: int = 120,
    timeout_per_poll: float = 120.0,
    max_retries_per_poll: int = 3,
    retry_delay_per_poll: float = 1.0,
    retry_backoff_per_poll: float = 2.0,
    estimated_duration: Optional[int] = None,
    cancel_endpoint: Optional[ApiEndpoint] = None,
    cancel_timeout: float = 10.0,
) -> M:
    """Typed front-end for `poll_op_raw`.

    The model-based extractors are adapted with `_wrap_model_extractor` so the
    dict-based poller can call them, and the final poll response is validated
    into `response_model` before being returned.

    Raises:
        Exception: if the final response is not a dict, or if it fails model validation.
    """
    # Adapt each typed extractor to the dict-based poller (None stays None).
    extract_status = _wrap_model_extractor(response_model, status_extractor)
    extract_progress = _wrap_model_extractor(response_model, progress_extractor)
    extract_price = _wrap_model_extractor(response_model, price_extractor)

    final_response = await poll_op_raw(
        cls,
        poll_endpoint=poll_endpoint,
        status_extractor=extract_status,
        progress_extractor=extract_progress,
        price_extractor=extract_price,
        completed_statuses=completed_statuses,
        failed_statuses=failed_statuses,
        queued_statuses=queued_statuses,
        data=data,
        poll_interval=poll_interval,
        max_poll_attempts=max_poll_attempts,
        timeout_per_poll=timeout_per_poll,
        max_retries_per_poll=max_retries_per_poll,
        retry_delay_per_poll=retry_delay_per_poll,
        retry_backoff_per_poll=retry_backoff_per_poll,
        estimated_duration=estimated_duration,
        cancel_endpoint=cancel_endpoint,
        cancel_timeout=cancel_timeout,
    )
    if isinstance(final_response, dict):
        return _validate_or_raise(response_model, final_response)
    raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).")
|
||
|
||
|
||
async def sync_op_raw(
    cls: type[IO.ComfyNode],
    endpoint: ApiEndpoint,
    *,
    data: Optional[Union[dict[str, Any], BaseModel]] = None,
    files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None,
    content_type: str = "application/json",
    timeout: float = 3600.0,
    multipart_parser: Optional[Callable] = None,
    max_retries: int = 3,
    retry_delay: float = 1.0,
    retry_backoff: float = 2.0,
    wait_label: str = "Waiting for server",
    estimated_duration: Optional[int] = None,
    as_binary: bool = False,
    final_label_on_success: Optional[str] = "Completed",
    progress_origin_ts: Optional[float] = None,
    monitor_progress: bool = True,
) -> Union[dict[str, Any], bytes]:
    """
    Make a single network request.

    A pydantic model passed as `data` is dumped to a dict (None fields dropped)
    and its top-level Enum values are replaced by their `.value`; a plain dict
    is sent as given.

    - If as_binary=False (default): returns JSON dict (or {'_raw': '<text>'} if non-JSON).
    - If as_binary=True: returns bytes.
    """
    if isinstance(data, BaseModel):
        dumped = data.model_dump(exclude_none=True)
        data = {key: (val.value if isinstance(val, Enum) else val) for key, val in dumped.items()}

    request_cfg = _RequestConfig(
        node_cls=cls,
        endpoint=endpoint,
        timeout=timeout,
        content_type=content_type,
        data=data,
        files=files,
        multipart_parser=multipart_parser,
        max_retries=max_retries,
        retry_delay=retry_delay,
        retry_backoff=retry_backoff,
        wait_label=wait_label,
        monitor_progress=monitor_progress,
        estimated_total=estimated_duration,
        final_label_on_success=final_label_on_success,
        progress_origin_ts=progress_origin_ts,
    )
    result = await _request_base(request_cfg, expect_binary=as_binary)
    return result
|
||
|
||
|
||
async def poll_op_raw(
    cls: type[IO.ComfyNode],
    poll_endpoint: ApiEndpoint,
    *,
    status_extractor: Callable[[dict[str, Any]], Optional[Union[str, int]]],
    progress_extractor: Optional[Callable[[dict[str, Any]], Optional[int]]] = None,
    price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None,
    completed_statuses: Optional[list[Union[str, int]]] = None,
    failed_statuses: Optional[list[Union[str, int]]] = None,
    queued_statuses: Optional[list[Union[str, int]]] = None,
    data: Optional[Union[dict[str, Any], BaseModel]] = None,
    poll_interval: float = 5.0,
    max_poll_attempts: int = 120,
    timeout_per_poll: float = 120.0,
    max_retries_per_poll: int = 3,
    retry_delay_per_poll: float = 1.0,
    retry_backoff_per_poll: float = 2.0,
    estimated_duration: Optional[int] = None,
    cancel_endpoint: Optional[ApiEndpoint] = None,
    cancel_timeout: float = 10.0,
) -> dict[str, Any]:
    """
    Polls an endpoint until the task reaches a terminal state. Displays time while queued/processing,
    checks interruption every second, and calls Cancel endpoint (if provided) on interruption.

    Uses default complete, failed and queued states assumption.

    Only polls that report a non-queued status count towards `max_poll_attempts`.

    Returns the final JSON response from the poll endpoint.

    Raises:
        ProcessingInterrupted: the user interrupted processing (re-raised after the
            best-effort cancel request).
        LocalNetworkError, ApiServerError: propagated unchanged from the poll request.
        Exception: any other failure (failed status, non-JSON response, attempts
            exhausted, extractor error), wrapped as "Polling aborted due to error: ...".
    """
    completed_states = _normalize_statuses(COMPLETED_STATUSES if completed_statuses is None else completed_statuses)
    failed_states = _normalize_statuses(FAILED_STATUSES if failed_statuses is None else failed_statuses)
    queued_states = _normalize_statuses(QUEUED_STATUSES if queued_statuses is None else queued_statuses)
    started = time.monotonic()
    consumed_attempts = 0  # counts only non-queued polls

    progress_bar = utils.ProgressBar(100) if progress_extractor else None
    last_progress: Optional[int] = None

    state = _PollUIState(started=started, estimated_duration=estimated_duration)
    stop_ticker = asyncio.Event()

    async def _ticker():
        """Emit a UI update every second while polling is in progress."""
        try:
            while not stop_ticker.is_set():
                if is_processing_interrupted():
                    break
                now = time.monotonic()
                proc_elapsed = state.base_processing_elapsed + (
                    (now - state.active_since) if state.active_since is not None else 0.0
                )
                _display_time_progress(
                    cls,
                    status=state.status_label,
                    elapsed_seconds=int(now - state.started),
                    estimated_total=state.estimated_duration,
                    price=state.price,
                    is_queued=state.is_queued,
                    processing_elapsed_seconds=int(proc_elapsed),
                )
                # Wait up to one second, but wake immediately once polling stops so that
                # awaiting this task never delays the caller by a leftover sleep.
                with contextlib.suppress(asyncio.TimeoutError):
                    await asyncio.wait_for(stop_ticker.wait(), timeout=1.0)
        except Exception as exc:
            logging.debug("Polling ticker exited: %s", exc)

    async def _cancel_remote_task() -> None:
        """Best-effort call to the cancel endpoint (if any); failures are ignored."""
        if not cancel_endpoint:
            return
        with contextlib.suppress(Exception):
            await sync_op_raw(
                cls,
                cancel_endpoint,
                timeout=cancel_timeout,
                max_retries=0,
                wait_label="Cancelling task",
                estimated_duration=None,
                as_binary=False,
                final_label_on_success=None,
                monitor_progress=False,
            )

    ticker_task = asyncio.create_task(_ticker())
    try:
        while consumed_attempts < max_poll_attempts:
            try:
                resp_json = await sync_op_raw(
                    cls,
                    poll_endpoint,
                    data=data,
                    timeout=timeout_per_poll,
                    max_retries=max_retries_per_poll,
                    retry_delay=retry_delay_per_poll,
                    retry_backoff=retry_backoff_per_poll,
                    wait_label="Checking",
                    estimated_duration=None,
                    as_binary=False,
                    final_label_on_success=None,
                    monitor_progress=False,
                )
                if not isinstance(resp_json, dict):
                    raise Exception("Polling endpoint returned non-JSON response.")
            except ProcessingInterrupted:
                await _cancel_remote_task()
                raise

            # A failing status extractor is logged and treated as "no status" for this poll.
            try:
                status = _normalize_status_value(status_extractor(resp_json))
            except Exception as e:
                logging.error("Status extraction failed: %s", e)
                status = None

            if price_extractor:
                new_price = price_extractor(resp_json)
                if new_price is not None:
                    state.price = new_price

            if progress_extractor:
                new_progress = progress_extractor(resp_json)
                if new_progress is not None and last_progress != new_progress:
                    progress_bar.update_absolute(new_progress, total=100)
                    last_progress = new_progress

            now_ts = time.monotonic()
            is_queued = status in queued_states

            if is_queued:
                if state.active_since is not None:  # If we just moved from active -> queued, close the active interval
                    state.base_processing_elapsed += now_ts - state.active_since
                    state.active_since = None
            else:
                if state.active_since is None:  # If we just moved from queued -> active, open a new active interval
                    state.active_since = now_ts

            state.is_queued = is_queued
            state.status_label = status or ("Queued" if is_queued else "Processing")
            if status in completed_states:
                if state.active_since is not None:
                    state.base_processing_elapsed += now_ts - state.active_since
                    state.active_since = None
                # Stop the ticker first so it cannot overwrite the final status line below.
                stop_ticker.set()
                with contextlib.suppress(Exception):
                    await ticker_task

                if progress_bar and last_progress != 100:
                    progress_bar.update_absolute(100, total=100)

                _display_time_progress(
                    cls,
                    status=status if status else "Completed",
                    elapsed_seconds=int(now_ts - started),
                    estimated_total=estimated_duration,
                    price=state.price,
                    is_queued=False,
                    processing_elapsed_seconds=int(state.base_processing_elapsed),
                )
                return resp_json

            if status in failed_states:
                msg = f"Task failed: {json.dumps(resp_json)}"
                logging.error(msg)
                raise Exception(msg)

            try:
                await sleep_with_interrupt(poll_interval, cls, None, None, None)
            except ProcessingInterrupted:
                await _cancel_remote_task()
                raise
            if not is_queued:
                consumed_attempts += 1

        raise Exception(
            f"Polling timed out after {max_poll_attempts} non-queued attempts "
            f"(~{int(max_poll_attempts * poll_interval)}s of active polling)."
        )
    except ProcessingInterrupted:
        raise
    except (LocalNetworkError, ApiServerError):
        raise
    except Exception as e:
        raise Exception(f"Polling aborted due to error: {e}") from e
    finally:
        stop_ticker.set()
        with contextlib.suppress(Exception):
            await ticker_task
|
||
|
||
|
||
def _display_text(
    node_cls: type[IO.ComfyNode],
    text: Optional[str],
    *,
    status: Optional[Union[str, int]] = None,
    price: Optional[float] = None,
) -> None:
    """Send a progress-text message for the node, built from optional status, price and free text.

    The lines appear in the order status, price, text; nothing is sent when all
    three are absent. String statuses are capitalized for display.
    """
    parts: list[str] = []
    if status:
        shown_status = status.capitalize() if isinstance(status, str) else status
        parts.append(f"Status: {shown_status}")
    if price is not None:
        parts.append(f"Price: ${float(price):,.4f}")
    if text is not None:
        parts.append(text)
    if not parts:
        return
    message = "\n".join(parts)
    PromptServer.instance.send_progress_text(message, get_node_id(node_cls))
|
||
|
||
|
||
def _display_time_progress(
    node_cls: type[IO.ComfyNode],
    status: Optional[Union[str, int]],
    elapsed_seconds: int,
    estimated_total: Optional[int] = None,
    *,
    price: Optional[float] = None,
    is_queued: Optional[bool] = None,
    processing_elapsed_seconds: Optional[int] = None,
) -> None:
    """Show the elapsed time (and, when possible, an estimate of the time remaining) for the node.

    The "remaining" hint is added only when a positive `estimated_total` is known
    and `is_queued` is explicitly False. It is computed from
    `processing_elapsed_seconds` when given, otherwise from `elapsed_seconds`,
    and never goes below zero.
    """
    elapsed = int(elapsed_seconds)
    can_estimate = estimated_total is not None and estimated_total > 0 and is_queued is False
    if can_estimate:
        basis = elapsed_seconds if processing_elapsed_seconds is None else processing_elapsed_seconds
        remaining = max(0, int(estimated_total) - int(basis))
        time_line = f"Time elapsed: {elapsed}s (~{remaining}s remaining)"
    else:
        time_line = f"Time elapsed: {elapsed}s"
    _display_text(node_cls, time_line, status=status, price=price)
|
||
|
||
|
||
async def _diagnose_connectivity() -> dict[str, bool]:
    """Best-effort connectivity diagnostics to distinguish local vs. server issues.

    Probes a public website first and, only if that succeeds, the `/health`
    URL on the host of `default_base_url()`. A probe counts as reachable when
    it answers with a status below 500. Probe failures, including timeouts,
    leave the corresponding flag False instead of raising.

    Returns:
        {"internet_accessible": bool, "api_accessible": bool}
    """
    results = {
        "internet_accessible": False,
        "api_accessible": False,
    }
    # An expired total timeout surfaces as asyncio.TimeoutError, which is an OSError
    # subclass only from Python 3.11 on; list it explicitly so a slow probe can never
    # escape and mask the connection error that is being diagnosed.
    probe_errors = (ClientError, OSError, asyncio.TimeoutError)
    timeout = aiohttp.ClientTimeout(total=5.0)
    async with aiohttp.ClientSession(timeout=timeout) as session:
        with contextlib.suppress(*probe_errors):
            async with session.get("https://www.google.com") as resp:
                results["internet_accessible"] = resp.status < 500
        if not results["internet_accessible"]:
            return results

        parsed = urlparse(default_base_url())
        health_url = f"{parsed.scheme}://{parsed.netloc}/health"
        with contextlib.suppress(*probe_errors):
            async with session.get(health_url) as resp:
                results["api_accessible"] = resp.status < 500
    return results
|
||
|
||
|
||
def _unpack_tuple(t: tuple) -> tuple[str, Any, str]:
|
||
"""Normalize (filename, value, content_type)."""
|
||
if len(t) == 2:
|
||
return t[0], t[1], "application/octet-stream"
|
||
if len(t) == 3:
|
||
return t[0], t[1], t[2]
|
||
raise ValueError("files tuple must be (filename, file[, content_type])")
|
||
|
||
|
||
def _merge_params(endpoint_params: dict[str, Any], method: str, data: Optional[dict[str, Any]]) -> dict[str, Any]:
|
||
params = dict(endpoint_params or {})
|
||
if method.upper() == "GET" and data:
|
||
for k, v in data.items():
|
||
if v is not None:
|
||
params[k] = v
|
||
return params
|
||
|
||
|
||
def _friendly_http_message(status: int, body: Any) -> str:
|
||
if status == 401:
|
||
return "Unauthorized: Please login first to use this node."
|
||
if status == 402:
|
||
return "Payment Required: Please add credits to your account to use this node."
|
||
if status == 409:
|
||
return "There is a problem with your account. Please contact support@comfy.org."
|
||
if status == 429:
|
||
return "Rate Limit Exceeded: Please try again later."
|
||
try:
|
||
if isinstance(body, dict):
|
||
err = body.get("error")
|
||
if isinstance(err, dict):
|
||
msg = err.get("message")
|
||
typ = err.get("type")
|
||
if msg and typ:
|
||
return f"API Error: {msg} (Type: {typ})"
|
||
if msg:
|
||
return f"API Error: {msg}"
|
||
return f"API Error: {json.dumps(body)}"
|
||
else:
|
||
txt = str(body)
|
||
if len(txt) <= 200:
|
||
return f"API Error (raw): {txt}"
|
||
return f"API Error (status {status})"
|
||
except Exception:
|
||
return f"HTTP {status}: Unknown error"
|
||
|
||
|
||
def _generate_operation_id(method: str, path: str, attempt: int) -> str:
|
||
slug = path.strip("/").replace("/", "_") or "op"
|
||
return f"{method}_{slug}_try{attempt}_{uuid.uuid4().hex[:8]}"
|
||
|
||
|
||
def _snapshot_request_body_for_logging(
|
||
content_type: str,
|
||
method: str,
|
||
data: Optional[dict[str, Any]],
|
||
files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]],
|
||
) -> Optional[Union[dict[str, Any], str]]:
|
||
if method.upper() == "GET":
|
||
return None
|
||
if content_type == "multipart/form-data":
|
||
form_fields = sorted([k for k, v in (data or {}).items() if v is not None])
|
||
file_fields: list[dict[str, str]] = []
|
||
if files:
|
||
file_iter = files if isinstance(files, list) else list(files.items())
|
||
for field_name, file_obj in file_iter:
|
||
if file_obj is None:
|
||
continue
|
||
if isinstance(file_obj, tuple):
|
||
filename = file_obj[0]
|
||
else:
|
||
filename = getattr(file_obj, "name", field_name)
|
||
file_fields.append({"field": field_name, "filename": str(filename or "")})
|
||
return {"_multipart": True, "form_fields": form_fields, "file_fields": file_fields}
|
||
if content_type == "application/x-www-form-urlencoded":
|
||
return data or {}
|
||
return data or {}
|
||
|
||
|
||
def _build_multipart_form(cfg: _RequestConfig) -> aiohttp.FormData:
    """Assemble the multipart body: fields from `cfg.data`, attachments from `cfg.files`."""
    if cfg.multipart_parser and cfg.data:
        form = cfg.multipart_parser(cfg.data)
        if not isinstance(form, aiohttp.FormData):
            raise ValueError("multipart_parser must return aiohttp.FormData")
    else:
        form = aiohttp.FormData(default_to_multipart=True)
        if cfg.data:
            for k, v in cfg.data.items():
                if v is None:
                    continue
                form.add_field(k, str(v) if not isinstance(v, (bytes, bytearray)) else v)
    if cfg.files:
        file_iter = cfg.files if isinstance(cfg.files, list) else cfg.files.items()
        for field_name, file_obj in file_iter:
            if file_obj is None:
                continue
            if isinstance(file_obj, tuple):
                filename, file_value, content_type = _unpack_tuple(file_obj)
            else:
                filename = getattr(file_obj, "name", field_name)
                file_value = file_obj
                content_type = "application/octet-stream"
            # Attempt to rewind BytesIO for retries
            if isinstance(file_value, BytesIO):
                with contextlib.suppress(Exception):
                    file_value.seek(0)
            form.add_field(field_name, file_value, filename=filename, content_type=content_type)
    return form


async def _request_base(cfg: _RequestConfig, expect_binary: bool):
    """Core request with retries, per-second interruption monitoring, true cancellation, and friendly errors.

    Args:
        cfg: Full description of the request (endpoint, body, retry policy, progress display).
        expect_binary: When True the response body is streamed and returned as bytes;
            otherwise it is decoded as JSON (an empty body becomes {}, non-JSON text
            becomes {"_raw": text}).

    Returns:
        The response payload: bytes, or the decoded JSON value.

    Raises:
        ProcessingInterrupted: processing was interrupted while waiting or downloading.
        LocalNetworkError: the connection kept failing and the connectivity probe
            found no internet access.
        ApiServerError: the connection kept failing although the internet is reachable.
        Exception: the server answered with status >= 400 and no retry was left or
            allowed; the message comes from `_friendly_http_message`.
    """
    url = cfg.endpoint.path
    parsed_url = urlparse(url)
    is_relative = not parsed_url.scheme and not parsed_url.netloc
    if is_relative:
        url = urljoin(default_base_url().rstrip("/") + "/", url.lstrip("/"))

    method = cfg.endpoint.method
    params = _merge_params(cfg.endpoint.query_params, method, cfg.data if method == "GET" else None)
    start_time = cfg.progress_origin_ts if cfg.progress_origin_ts is not None else time.monotonic()

    def _log(kind: str, operation_id: str, **fields: Any) -> None:
        """Write one request-log entry; logging problems must never break the request."""
        try:
            request_logger.log_request_response(
                operation_id=operation_id,
                request_method=method,
                request_url=url,
                **fields,
            )
        except Exception as log_exc:
            logging.debug("[DEBUG] %s logging failed: %s", kind, log_exc)

    def _log_request(
        kind: str, operation_id: str, headers: dict[str, Any], body_log: Any, error_message: Optional[str] = None
    ) -> None:
        """Log the outgoing request (headers, params, body snapshot) and optionally an error."""
        fields: dict[str, Any] = {
            "request_headers": dict(headers) if headers else None,
            "request_params": dict(params) if params else None,
            "request_data": body_log,
        }
        if error_message is not None:
            fields["error_message"] = error_message
        _log(kind, operation_id, **fields)

    def _log_response(operation_id: str, resp: Any, content: Any, error_message: Optional[str] = None) -> None:
        """Log a received response (status, headers, content) and optionally an error."""
        fields: dict[str, Any] = {
            "response_status_code": resp.status,
            "response_headers": dict(resp.headers),
            "response_content": content,
        }
        if error_message is not None:
            fields["error_message"] = error_message
        _log("response", operation_id, **fields)

    async def _wait_before_retry(seconds: float) -> None:
        """Interruptible pause between attempts, keeping the progress display alive if enabled."""
        await sleep_with_interrupt(
            seconds,
            cfg.node_cls,
            cfg.wait_label if cfg.monitor_progress else None,
            start_time if cfg.monitor_progress else None,
            cfg.estimated_total,
            display_callback=_display_time_progress if cfg.monitor_progress else None,
        )

    async def _monitor(stop_evt: asyncio.Event, start_ts: float):
        """Every second: update elapsed time and signal interruption."""
        try:
            while not stop_evt.is_set():
                if is_processing_interrupted():
                    return
                if cfg.monitor_progress:
                    _display_time_progress(
                        cfg.node_cls, cfg.wait_label, int(time.monotonic() - start_ts), cfg.estimated_total
                    )
                await asyncio.sleep(1.0)
        except asyncio.CancelledError:
            return  # normal shutdown

    attempt = 0
    delay = cfg.retry_delay
    operation_succeeded: bool = False
    final_elapsed_seconds: Optional[int] = None
    while True:
        attempt += 1
        stop_event = asyncio.Event()
        monitor_task: Optional[asyncio.Task] = None
        sess: Optional[aiohttp.ClientSession] = None

        operation_id = _generate_operation_id(method, cfg.endpoint.path, attempt)
        logging.debug("[DEBUG] HTTP %s %s (attempt %d)", method, url, attempt)

        payload_headers = {"Accept": "*/*"}
        if is_relative:
            # Auth headers are added only for relative paths (resolved against the default base URL).
            payload_headers.update(get_auth_header(cfg.node_cls))
        if cfg.endpoint.headers:
            payload_headers.update(cfg.endpoint.headers)

        payload_kw: dict[str, Any] = {"headers": payload_headers}
        if method == "GET":
            payload_headers.pop("Content-Type", None)
        request_body_log = _snapshot_request_body_for_logging(cfg.content_type, method, cfg.data, cfg.files)
        try:
            if cfg.monitor_progress:
                monitor_task = asyncio.create_task(_monitor(stop_event, start_time))

            timeout = aiohttp.ClientTimeout(total=cfg.timeout)
            sess = aiohttp.ClientSession(timeout=timeout)

            if cfg.content_type == "multipart/form-data" and method != "GET":
                # aiohttp will set Content-Type boundary; remove any fixed Content-Type
                payload_headers.pop("Content-Type", None)
                payload_kw["data"] = _build_multipart_form(cfg)
            elif cfg.content_type == "application/x-www-form-urlencoded" and method != "GET":
                payload_headers["Content-Type"] = "application/x-www-form-urlencoded"
                payload_kw["data"] = cfg.data or {}
            elif method != "GET":
                payload_headers["Content-Type"] = "application/json"
                payload_kw["json"] = cfg.data or {}

            _log_request("request", operation_id, payload_headers, request_body_log)

            req_coro = sess.request(method, url, params=params, **payload_kw)
            req_task = asyncio.create_task(req_coro)

            # Race: request vs. monitor (interruption)
            tasks = {req_task}
            if monitor_task:
                tasks.add(monitor_task)
            done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)

            if monitor_task and monitor_task in done:
                # Interrupted – cancel the request and abort
                if req_task in pending:
                    req_task.cancel()
                raise ProcessingInterrupted("Task cancelled")

            # Otherwise, request finished
            resp = await req_task
            async with resp:
                if resp.status >= 400:
                    try:
                        body = await resp.json()
                    except (ContentTypeError, json.JSONDecodeError):
                        body = await resp.text()
                    if resp.status in _RETRY_STATUS and attempt <= cfg.max_retries:
                        logging.warning(
                            "HTTP %s %s -> %s. Retrying in %.2fs (retry %d of %d).",
                            method,
                            url,
                            resp.status,
                            delay,
                            attempt,
                            cfg.max_retries,
                        )
                        _log_response(
                            operation_id, resp, body, error_message=_friendly_http_message(resp.status, body)
                        )
                        await _wait_before_retry(delay)
                        delay *= cfg.retry_backoff
                        continue
                    msg = _friendly_http_message(resp.status, body)
                    _log_response(operation_id, resp, body, error_message=msg)
                    raise Exception(msg)

                if expect_binary:
                    buff = bytearray()
                    last_tick = time.monotonic()
                    async for chunk in resp.content.iter_chunked(64 * 1024):
                        buff.extend(chunk)
                        now = time.monotonic()
                        # At most once per second: honour interruption and refresh the display.
                        if now - last_tick >= 1.0:
                            last_tick = now
                            if is_processing_interrupted():
                                raise ProcessingInterrupted("Task cancelled")
                            if cfg.monitor_progress:
                                _display_time_progress(
                                    cfg.node_cls, cfg.wait_label, int(now - start_time), cfg.estimated_total
                                )
                    bytes_payload = bytes(buff)
                    operation_succeeded = True
                    final_elapsed_seconds = int(time.monotonic() - start_time)
                    _log_response(operation_id, resp, bytes_payload)
                    return bytes_payload
                else:
                    try:
                        payload = await resp.json()
                        response_content_to_log: Any = payload
                    except (ContentTypeError, json.JSONDecodeError):
                        text = await resp.text()
                        try:
                            payload = json.loads(text) if text else {}
                        except json.JSONDecodeError:
                            payload = {"_raw": text}
                        response_content_to_log = payload if isinstance(payload, dict) else text
                    operation_succeeded = True
                    final_elapsed_seconds = int(time.monotonic() - start_time)
                    _log_response(operation_id, resp, response_content_to_log)
                    return payload

        except ProcessingInterrupted:
            logging.debug("Polling was interrupted by user")
            raise
        # asyncio.TimeoutError is an OSError subclass only from Python 3.11 on; naming it
        # keeps timeouts on the same retry / friendly-error path on older interpreters.
        except (ClientError, OSError, asyncio.TimeoutError) as e:
            if attempt <= cfg.max_retries:
                logging.warning(
                    "Connection error calling %s %s. Retrying in %.2fs (%d/%d): %s",
                    method,
                    url,
                    delay,
                    attempt,
                    cfg.max_retries,
                    str(e),
                )
                _log_request(
                    "request error",
                    operation_id,
                    payload_headers,
                    request_body_log,
                    error_message=f"{type(e).__name__}: {str(e)} (will retry)",
                )
                await _wait_before_retry(delay)
                delay *= cfg.retry_backoff
                continue
            # Out of retries: find out whether the problem is local or on the server side.
            diag = await _diagnose_connectivity()
            if not diag["internet_accessible"]:
                _log_request(
                    "final error",
                    operation_id,
                    payload_headers,
                    request_body_log,
                    error_message=f"LocalNetworkError: {str(e)}",
                )
                raise LocalNetworkError(
                    "Unable to connect to the API server due to local network issues. "
                    "Please check your internet connection and try again."
                ) from e
            _log_request(
                "final error",
                operation_id,
                payload_headers,
                request_body_log,
                error_message=f"ApiServerError: {str(e)}",
            )
            raise ApiServerError(
                f"The API server at {default_base_url()} is currently unreachable. "
                f"The service may be experiencing issues."
            ) from e
        finally:
            stop_event.set()
            if monitor_task:
                monitor_task.cancel()
                # If the attempt failed before its first await, the monitor never started and
                # awaiting it directly would raise CancelledError (a BaseException), replacing
                # the real error. gather(return_exceptions=True) absorbs the monitor's own
                # outcome while still propagating a cancellation of *this* task.
                await asyncio.gather(monitor_task, return_exceptions=True)
            if sess:
                with contextlib.suppress(Exception):
                    await sess.close()
            if operation_succeeded and cfg.monitor_progress and cfg.final_label_on_success:
                _display_time_progress(
                    cfg.node_cls,
                    status=cfg.final_label_on_success,
                    elapsed_seconds=(
                        final_elapsed_seconds
                        if final_elapsed_seconds is not None
                        else int(time.monotonic() - start_time)
                    ),
                    estimated_total=cfg.estimated_total,
                    price=None,
                    is_queued=False,
                    processing_elapsed_seconds=final_elapsed_seconds,
                )
|
||
|
||
|
||
def _validate_or_raise(response_model: Type[M], payload: Any) -> M:
|
||
try:
|
||
return response_model.model_validate(payload)
|
||
except Exception as e:
|
||
logging.error(
|
||
"Response validation failed for %s: %s",
|
||
getattr(response_model, "__name__", response_model),
|
||
e,
|
||
)
|
||
raise Exception(
|
||
f"Response validation failed for {getattr(response_model, '__name__', response_model)}: {e}"
|
||
) from e
|
||
|
||
|
||
def _wrap_model_extractor(
|
||
response_model: Type[M],
|
||
extractor: Optional[Callable[[M], Any]],
|
||
) -> Optional[Callable[[dict[str, Any]], Any]]:
|
||
"""Wrap a typed extractor so it can be used by the dict-based poller.
|
||
Validates the dict into `response_model` before invoking `extractor`.
|
||
Uses a small per-wrapper cache keyed by `id(dict)` to avoid re-validating
|
||
the same response for multiple extractors in a single poll attempt.
|
||
"""
|
||
if extractor is None:
|
||
return None
|
||
_cache: dict[int, M] = {}
|
||
|
||
def _wrapped(d: dict[str, Any]) -> Any:
|
||
try:
|
||
key = id(d)
|
||
model = _cache.get(key)
|
||
if model is None:
|
||
model = response_model.model_validate(d)
|
||
_cache[key] = model
|
||
return extractor(model)
|
||
except Exception as e:
|
||
logging.error("Extractor failed (typed -> dict wrapper): %s", e)
|
||
raise
|
||
|
||
return _wrapped
|
||
|
||
|
||
def _normalize_statuses(values: Optional[Iterable[Union[str, int]]]) -> set[Union[str, int]]:
    """Return the set of normalized status values, skipping entries that normalize to None.

    A missing or empty input gives an empty set.
    """
    if not values:
        return set()
    return {status for status in map(_normalize_status_value, values) if status is not None}
|
||
|
||
|
||
def _normalize_status_value(val: Union[str, int, None]) -> Union[str, int, None]:
|
||
if isinstance(val, str):
|
||
return val.strip().lower()
|
||
return val
|