import asyncio
import contextlib
import uuid
from io import BytesIO
from pathlib import Path
from typing import IO, Optional, Union
from urllib.parse import urljoin, urlparse

import aiohttp
import torch
from aiohttp.client_exceptions import ClientError, ContentTypeError

from comfy_api.input_impl import VideoFromFile
from comfy_api.latest import IO as COMFY_IO

from . import request_logger
from ._helpers import (
    default_base_url,
    get_auth_header,
    is_processing_interrupted,
    sleep_with_interrupt,
)
from .client import _diagnose_connectivity
from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted
from .conversions import bytesio_to_image_tensor

_RETRY_STATUS = {408, 429, 500, 502, 503, 504}


async def download_url_to_bytesio(
    url: str,
    dest: Union[BytesIO, IO[bytes], str, Path],
    *,
    timeout: Optional[float] = None,
    max_retries: int = 5,
    retry_delay: float = 1.0,
    retry_backoff: float = 2.0,
    cls: Optional[type[COMFY_IO.ComfyNode]] = None,
) -> None:
    """Stream-download a URL to `dest`.

    `dest` must be one of:
      - a BytesIO (rewound to 0 after the write),
      - a file-like object opened in binary write mode (must implement `.write()`),
      - a filesystem path (str | pathlib.Path), which will be opened with 'wb'.

    If `url` is relative (e.g. a `/proxy/...` path), `cls` must be provided so
    the URL can be expanded to an absolute URL and authentication headers can
    be applied.

    Raises:
        ProcessingInterrupted, LocalNetworkError, ApiServerError,
        Exception (HTTP and other errors)
    """
    if not isinstance(dest, (str, Path)) and not hasattr(dest, "write"):
        raise ValueError("dest must be a path (str|Path) or a binary-writable object providing .write().")
    attempt = 0
    delay = retry_delay
    headers: dict[str, str] = {}
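    # NOTE: a relative path such as "/proxy/jobs/123/result.png" (path shown
    # here for illustration only) is joined onto the API base URL returned by
    # default_base_url() and signed with the node's auth header below.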
    parsed_url = urlparse(url)
    if not parsed_url.scheme and not parsed_url.netloc:  # is URL relative?
        if cls is None:
            raise ValueError("For relative 'cloud' paths, the `cls` parameter is required.")
        url = urljoin(default_base_url().rstrip("/") + "/", url.lstrip("/"))
        headers = get_auth_header(cls)

    while True:
        attempt += 1
        op_id = _generate_operation_id("GET", url, attempt)
        timeout_cfg = aiohttp.ClientTimeout(total=timeout)
        is_path_sink = isinstance(dest, (str, Path))
        fhandle = None
        session: Optional[aiohttp.ClientSession] = None
        stop_evt: Optional[asyncio.Event] = None
        monitor_task: Optional[asyncio.Task] = None
        req_task: Optional[asyncio.Task] = None
        try:
            with contextlib.suppress(Exception):
                request_logger.log_request_response(operation_id=op_id, request_method="GET", request_url=url)
            session = aiohttp.ClientSession(timeout=timeout_cfg)
            stop_evt = asyncio.Event()

            async def _monitor():
                # Poll once a second so a user interruption is noticed even
                # while the request itself is still in flight.
                try:
                    while not stop_evt.is_set():
                        if is_processing_interrupted():
                            return
                        await asyncio.sleep(1.0)
                except asyncio.CancelledError:
                    return

            monitor_task = asyncio.create_task(_monitor())
            req_task = asyncio.create_task(session.get(url, headers=headers))
            # Race the request against the interruption monitor; whichever
            # finishes first decides whether we proceed or abort.
            done, pending = await asyncio.wait({req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED)
            if monitor_task in done and req_task in pending:
                req_task.cancel()
                with contextlib.suppress(Exception):
                    await req_task
                raise ProcessingInterrupted("Task cancelled")
            try:
                resp = await req_task
            except asyncio.CancelledError:
                raise ProcessingInterrupted("Task cancelled") from None
            async with resp:
                if resp.status >= 400:
                    with contextlib.suppress(Exception):
                        try:
                            body = await resp.json()
                        except (ContentTypeError, ValueError):
                            text = await resp.text()
                            body = text if len(text) <= 4096 else f"[text {len(text)} bytes]"
                        request_logger.log_request_response(
                            operation_id=op_id,
                            request_method="GET",
                            request_url=url,
                            response_status_code=resp.status,
                            response_headers=dict(resp.headers),
                            response_content=body,
                            error_message=f"HTTP {resp.status}",
                        )
                    if resp.status in _RETRY_STATUS and attempt <= max_retries:
                        await sleep_with_interrupt(delay, cls, None, None, None)
                        delay *= retry_backoff
                        continue
                    raise Exception(f"Failed to download (HTTP {resp.status}).")
                if is_path_sink:
                    p = Path(str(dest))
                    with contextlib.suppress(Exception):
                        p.parent.mkdir(parents=True, exist_ok=True)
                    fhandle = open(p, "wb")
                    sink = fhandle
                else:
                    sink = dest  # BytesIO or file-like
                written = 0
                while True:
                    # Read in 1 MiB chunks with a 1 s timeout so interruption
                    # can be checked between reads of a slow stream.
                    try:
                        chunk = await asyncio.wait_for(resp.content.read(1024 * 1024), timeout=1.0)
                    except asyncio.TimeoutError:
                        chunk = b""
                    except asyncio.CancelledError:
                        raise ProcessingInterrupted("Task cancelled") from None
                    if is_processing_interrupted():
                        raise ProcessingInterrupted("Task cancelled")
                    if not chunk:
                        if resp.content.at_eof():
                            break
                        continue
                    sink.write(chunk)
                    written += len(chunk)
                if isinstance(dest, BytesIO):
                    with contextlib.suppress(Exception):
                        dest.seek(0)
                with contextlib.suppress(Exception):
                    request_logger.log_request_response(
                        operation_id=op_id,
                        request_method="GET",
                        request_url=url,
                        response_status_code=resp.status,
                        response_headers=dict(resp.headers),
                        response_content=f"[streamed {written} bytes to dest]",
                    )
                return
        except asyncio.CancelledError:
            raise ProcessingInterrupted("Task cancelled") from None
        except (ClientError, OSError) as e:
            if attempt <= max_retries:
                with contextlib.suppress(Exception):
                    request_logger.log_request_response(
                        operation_id=op_id,
                        request_method="GET",
                        request_url=url,
                        error_message=f"{type(e).__name__}: {str(e)} (will retry)",
                    )
                await sleep_with_interrupt(delay, cls, None, None, None)
                delay *= retry_backoff
                continue
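            # Retries exhausted. With the defaults (retry_delay=1.0,
            # retry_backoff=2.0, max_retries=5) the waits between attempts were
            # 1 s, 2 s, 4 s, 8 s and 16 s. Distinguish a local connectivity
            # failure from a remote outage so callers get an actionable error.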
diag["internet_accessible"]: raise LocalNetworkError( "Unable to connect to the network. Please check your internet connection and try again." ) from e raise ApiServerError("The remote service appears unreachable at this time.") from e finally: if stop_evt is not None: stop_evt.set() if monitor_task: monitor_task.cancel() with contextlib.suppress(Exception): await monitor_task if req_task and not req_task.done(): req_task.cancel() with contextlib.suppress(Exception): await req_task if session: with contextlib.suppress(Exception): await session.close() if fhandle: with contextlib.suppress(Exception): fhandle.flush() fhandle.close() async def download_url_to_image_tensor( url: str, *, timeout: float = None, cls: type[COMFY_IO.ComfyNode] = None, ) -> torch.Tensor: """Downloads an image from a URL and returns a [B, H, W, C] tensor.""" result = BytesIO() await download_url_to_bytesio(url, result, timeout=timeout, cls=cls) return bytesio_to_image_tensor(result) async def download_url_to_video_output( video_url: str, *, timeout: float = None, max_retries: int = 5, cls: type[COMFY_IO.ComfyNode] = None, ) -> VideoFromFile: """Downloads a video from a URL and returns a `VIDEO` output.""" result = BytesIO() await download_url_to_bytesio(video_url, result, timeout=timeout, max_retries=max_retries, cls=cls) return VideoFromFile(result) async def download_url_as_bytesio( url: str, *, timeout: float = None, cls: type[COMFY_IO.ComfyNode] = None, ) -> BytesIO: """Downloads content from a URL and returns a new BytesIO (rewound to 0).""" result = BytesIO() await download_url_to_bytesio(url, result, timeout=timeout, cls=cls) return result def _generate_operation_id(method: str, url: str, attempt: int) -> str: try: parsed = urlparse(url) slug = (parsed.path.rsplit("/", 1)[-1] or parsed.netloc or "download").strip("/").replace("/", "_") except Exception: slug = "download" return f"{method}_{slug}_try{attempt}_{uuid.uuid4().hex[:8]}"