mirror of
https://git.datalinker.icu/comfyanonymous/ComfyUI
synced 2025-12-08 21:44:33 +08:00
339 lines
12 KiB
Python
339 lines
12 KiB
Python
import asyncio
|
|
import contextlib
|
|
import logging
|
|
import time
|
|
import uuid
|
|
from io import BytesIO
|
|
from typing import Optional, Union
|
|
from urllib.parse import urlparse
|
|
|
|
import aiohttp
|
|
import torch
|
|
from pydantic import BaseModel, Field
|
|
|
|
from comfy_api.latest import IO, Input
|
|
from comfy_api.util import VideoCodec, VideoContainer
|
|
from comfy_api_nodes.apis import request_logger
|
|
|
|
from ._helpers import is_processing_interrupted, sleep_with_interrupt
|
|
from .client import (
|
|
ApiEndpoint,
|
|
_diagnose_connectivity,
|
|
_display_time_progress,
|
|
sync_op,
|
|
)
|
|
from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted
|
|
from .conversions import (
|
|
audio_ndarray_to_bytesio,
|
|
audio_tensor_to_contiguous_ndarray,
|
|
tensor_to_bytesio,
|
|
)
|
|
|
|
|
|
class UploadRequest(BaseModel):
    # Payload describing the file to be uploaded; upload_file_to_comfyapi sends it
    # as the request data of its POST to "/customers/storage".

    # Required: name under which the file is uploaded.
    file_name: str = Field(default=..., description="Filename to upload")
    # Optional: stays None when the caller does not supply a MIME type.
    content_type: Optional[str] = Field(
        default=None,
        description="Mime type of the file. For example: image/png, image/jpeg, video/mp4, etc.",
    )
|
|
|
|
|
|
class UploadResponse(BaseModel):
    # Response model used by upload_file_to_comfyapi: both URLs are required fields.

    # URL from which the uploaded file can later be fetched.
    download_url: str = Field(default=..., description="URL to GET uploaded file")
    # URL the file content must be PUT to.
    upload_url: str = Field(default=..., description="URL to PUT file to upload")
|
|
|
|
|
|
async def upload_images_to_comfyapi(
    cls: type[IO.ComfyNode],
    image: torch.Tensor,
    *,
    max_images: int = 8,
    mime_type: Optional[str] = None,
    wait_label: Optional[str] = "Uploading",
) -> list[str]:
    """
    Upload one image, or the leading images of a batch, and collect their download URLs.

    A tensor with more than three dimensions is treated as a batch indexed along
    its first dimension; anything else is uploaded as a single image. To upload
    several images, stack them in the batch dimension first.

    Args:
        cls: Node class forwarded to ``upload_file_to_comfyapi``.
        image: Image tensor (single image or batch).
        max_images: Upper bound on the number of uploads; a value <= 0 uploads nothing.
        mime_type: MIME type forwarded to the encoder and to the upload request.
        wait_label: Progress label forwarded to the upload helper.

    Returns:
        Download URLs in upload order.
    """
    if len(image.shape) > 3:
        # Batched input: take at most `max_images` entries from the first dimension.
        frames = (image[i] for i in range(min(image.shape[0], max_images)))
    else:
        # Single image: still honour `max_images` (nothing is uploaded when it is <= 0).
        frames = (image for _ in range(min(1, max_images)))

    urls: list[str] = []
    for frame in frames:
        encoded = tensor_to_bytesio(frame, mime_type=mime_type)
        # Uploads run one after another so the returned order matches the input order.
        urls.append(await upload_file_to_comfyapi(cls, encoded, encoded.name, mime_type, wait_label))
    return urls
|
|
|
|
|
|
async def upload_audio_to_comfyapi(
    cls: type[IO.ComfyNode],
    audio: Input.Audio,
    *,
    container_format: str = "mp4",
    codec_name: str = "aac",
    mime_type: str = "audio/mp4",
    filename: str = "uploaded_audio.mp4",
) -> str:
    """
    Encode a single audio input and upload it to ComfyUI API.

    The waveform found under ``audio["waveform"]`` is converted to a contiguous
    ndarray, encoded with the requested container/codec at ``audio["sample_rate"]``,
    and handed to ``upload_file_to_comfyapi``.

    Args:
        cls: Node class forwarded to ``upload_file_to_comfyapi``.
        audio: Mapping providing the ``"sample_rate"`` and ``"waveform"`` entries.
        container_format: Container passed to the encoder.
        codec_name: Codec passed to the encoder.
        mime_type: MIME type declared for the upload.
        filename: File name declared for the upload.

    Returns:
        The download URL of the uploaded audio.
    """
    rate: int = audio["sample_rate"]
    samples = audio_tensor_to_contiguous_ndarray(audio["waveform"])
    encoded = audio_ndarray_to_bytesio(samples, rate, container_format, codec_name)
    download_url = await upload_file_to_comfyapi(cls, encoded, filename, mime_type)
    return download_url
|
|
|
|
|
|
async def upload_video_to_comfyapi(
    cls: type[IO.ComfyNode],
    video: Input.Video,
    *,
    container: VideoContainer = VideoContainer.MP4,
    codec: VideoCodec = VideoCodec.H264,
    max_duration: Optional[int] = None,
) -> str:
    """
    Uploads a single video to ComfyUI API and returns its download URL.

    Uses the specified container and codec for saving the video before upload.

    Args:
        cls: Node class forwarded to ``upload_file_to_comfyapi``.
        video: Video input; must provide ``get_duration()`` (only called when
            ``max_duration`` is set) and ``save_to()``.
        container: Container used when saving the video; its lowercased value
            also forms the MIME subtype and the file extension of the upload.
        codec: Codec used when saving the video.
        max_duration: Optional maximum duration in seconds. ``None`` disables the check.

    Returns:
        The download URL of the uploaded video.

    Raises:
        ValueError: If the duration cannot be determined or compared, or if it
            exceeds ``max_duration``. The two cases carry different messages.
    """
    if max_duration is not None:
        # Guard only the probe and the comparison. The limit violation is raised
        # outside the handler so it is not logged and re-wrapped as a
        # "could not verify duration" failure.
        try:
            actual_duration = video.get_duration()
            exceeds_limit = actual_duration > max_duration
        except Exception as e:
            logging.error("Error getting video duration: %s", str(e))
            raise ValueError(f"Could not verify video duration from source: {e}") from e
        if exceeds_limit:
            raise ValueError(
                f"Video duration ({actual_duration:.2f}s) exceeds the maximum allowed ({max_duration}s)."
            )

    upload_mime_type = f"video/{container.value.lower()}"
    filename = f"uploaded_video.{container.value.lower()}"

    # Convert VideoInput to BytesIO using specified container/codec
    video_bytes_io = BytesIO()
    video.save_to(video_bytes_io, format=container, codec=codec)
    # Rewind so the upload reads the encoded video from the beginning.
    video_bytes_io.seek(0)

    return await upload_file_to_comfyapi(cls, video_bytes_io, filename, upload_mime_type)
|
|
|
|
|
|
async def upload_file_to_comfyapi(
    cls: type[IO.ComfyNode],
    file_bytes_io: BytesIO,
    filename: str,
    upload_mime_type: Optional[str],
    wait_label: Optional[str] = "Uploading",
) -> str:
    """
    Upload one in-memory file to ComfyUI API and return its download URL.

    Two steps: a POST to ``/customers/storage`` obtains an upload/download URL
    pair, then the file content is sent to the upload URL via ``upload_file``.

    Args:
        cls: Node class forwarded to ``sync_op`` and ``upload_file``.
        file_bytes_io: Buffer holding the file content.
        filename: File name declared in the storage request.
        upload_mime_type: MIME type; omitted from the storage request when ``None``.
        wait_label: Progress label forwarded to ``upload_file``.

    Returns:
        The ``download_url`` from the storage response.
    """
    # Only pass content_type when one was given, so the request model is built
    # exactly as if the field had never been supplied.
    request_fields = {"file_name": filename}
    if upload_mime_type is not None:
        request_fields["content_type"] = upload_mime_type
    upload_request = UploadRequest(**request_fields)

    storage_endpoint = ApiEndpoint(path="/customers/storage", method="POST")
    storage_info = await sync_op(
        cls,
        endpoint=storage_endpoint,
        data=upload_request,
        response_model=UploadResponse,
        final_label_on_success=None,
        monitor_progress=False,
    )

    await upload_file(
        cls,
        storage_info.upload_url,
        file_bytes_io,
        content_type=upload_mime_type,
        wait_label=wait_label,
    )
    return storage_info.download_url
|
|
|
|
|
|
async def upload_file(
    cls: type[IO.ComfyNode],
    upload_url: str,
    file: Union[BytesIO, str],
    *,
    content_type: Optional[str] = None,
    max_retries: int = 3,
    retry_delay: float = 1.0,
    retry_backoff: float = 2.0,
    wait_label: Optional[str] = None,
) -> None:
    """
    Upload a file to a signed URL (e.g., S3 pre-signed PUT) with retries, Comfy progress display, and interruption.

    The whole payload is read into memory first and then sent with a single PUT
    per attempt (no total client timeout). While a request is in flight, a
    monitor task polls ``is_processing_interrupted()`` once per second and, when
    ``wait_label`` is set, refreshes the elapsed-time progress display.

    An attempt is retried when the response status is one of
    408/429/500/502/503/504, or when ``aiohttp.ClientError`` / ``OSError`` is
    raised, as long as ``attempt <= max_retries`` (so up to ``max_retries + 1``
    attempts in total). The wait between attempts starts at ``retry_delay`` and
    is multiplied by ``retry_backoff`` after each retry.

    Args:
        cls: Node class, forwarded to the progress-display and interruptible-sleep helpers.
        upload_url: Pre-signed PUT URL.
        file: BytesIO or path string.
        content_type: Explicit MIME type. If None (or empty), we *suppress* Content-Type.
        max_retries: Maximum retry attempts.
        retry_delay: Initial delay in seconds.
        retry_backoff: Exponential backoff factor.
        wait_label: Progress label shown in Comfy UI.

    Raises:
        ValueError: ``file`` is neither a BytesIO nor a str.
        ProcessingInterrupted: The monitor finished before the request, or a
            ``asyncio.CancelledError`` was raised during the attempt.
        LocalNetworkError: Connection errors persisted after all retries and the
            connectivity diagnosis reports no internet access.
        ApiServerError: Connection errors persisted after all retries although
            the connectivity diagnosis reports internet access.
        Exception: The server answered with a status >= 400 that is not retried
            (non-retryable status, or retries exhausted).
    """
    # Load the full payload up front so every retry can resend the same bytes.
    if isinstance(file, BytesIO):
        # Best-effort rewind; a failing seek is ignored and read() starts from the current position.
        with contextlib.suppress(Exception):
            file.seek(0)
        data = file.read()
    elif isinstance(file, str):
        with open(file, "rb") as f:
            data = f.read()
    else:
        raise ValueError("file must be a BytesIO or a filesystem path string")

    headers: dict[str, str] = {}
    skip_auto_headers: set[str] = set()
    if content_type:
        headers["Content-Type"] = content_type
    else:
        skip_auto_headers.add("Content-Type")  # Don't let aiohttp add Content-Type, it can break the signed request

    attempt = 0
    delay = retry_delay
    start_ts = time.monotonic()  # reference point for the elapsed-time progress display
    op_uuid = uuid.uuid4().hex[:8]  # short id shared by the log entries of all attempts of this upload
    while True:
        attempt += 1
        operation_id = _generate_operation_id("PUT", upload_url, attempt, op_uuid)
        timeout = aiohttp.ClientTimeout(total=None)  # no overall time limit for the request
        stop_evt = asyncio.Event()

        async def _monitor():
            # Runs alongside the request: returning early (interruption detected) makes
            # this task complete first, which the code below treats as a cancel request.
            try:
                while not stop_evt.is_set():
                    if is_processing_interrupted():
                        return
                    if wait_label:
                        _display_time_progress(cls, wait_label, int(time.monotonic() - start_ts), None)
                    await asyncio.sleep(1.0)
            except asyncio.CancelledError:
                # Cancelled by the cleanup in `finally`; finish quietly.
                return

        monitor_task = asyncio.create_task(_monitor())
        sess: Optional[aiohttp.ClientSession] = None
        try:
            # Request logging is best-effort and must never fail the upload.
            try:
                request_logger.log_request_response(
                    operation_id=operation_id,
                    request_method="PUT",
                    request_url=upload_url,
                    request_headers=headers or None,
                    request_params=None,
                    request_data=f"[File data {len(data)} bytes]",
                )
            except Exception as e:
                logging.debug("[DEBUG] upload request logging failed: %s", e)

            # A fresh session per attempt; it is closed in `finally`.
            sess = aiohttp.ClientSession(timeout=timeout)
            req = sess.put(upload_url, data=data, headers=headers, skip_auto_headers=skip_auto_headers)
            req_task = asyncio.create_task(req)

            # Race the request against the monitor.
            done, pending = await asyncio.wait({req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED)

            # Monitor finished while the request is still pending: abort the request.
            if monitor_task in done and req_task in pending:
                req_task.cancel()
                raise ProcessingInterrupted("Upload cancelled")

            try:
                resp = await req_task
            except asyncio.CancelledError:
                raise ProcessingInterrupted("Upload cancelled") from None

            async with resp:
                if resp.status >= 400:
                    # Reading and logging the error body is best-effort.
                    with contextlib.suppress(Exception):
                        try:
                            body = await resp.json()
                        except Exception:
                            body = await resp.text()
                        msg = f"Upload failed with status {resp.status}"
                        request_logger.log_request_response(
                            operation_id=operation_id,
                            request_method="PUT",
                            request_url=upload_url,
                            response_status_code=resp.status,
                            response_headers=dict(resp.headers),
                            response_content=body,
                            error_message=msg,
                        )
                    # Retry only the listed statuses, and only while retries remain.
                    if resp.status in {408, 429, 500, 502, 503, 504} and attempt <= max_retries:
                        await sleep_with_interrupt(
                            delay,
                            cls,
                            wait_label,
                            start_ts,
                            None,
                            display_callback=_display_time_progress if wait_label else None,
                        )
                        delay *= retry_backoff
                        # `finally` below still runs before the next attempt starts.
                        continue
                    raise Exception(f"Failed to upload (HTTP {resp.status}).")
                # Success path: log (best-effort) and finish.
                try:
                    request_logger.log_request_response(
                        operation_id=operation_id,
                        request_method="PUT",
                        request_url=upload_url,
                        response_status_code=resp.status,
                        response_headers=dict(resp.headers),
                        response_content="File uploaded successfully.",
                    )
                except Exception as e:
                    logging.debug("[DEBUG] upload response logging failed: %s", e)
                return
        except asyncio.CancelledError:
            # Any cancellation raised during the attempt is reported as an interruption.
            raise ProcessingInterrupted("Task cancelled") from None
        except (aiohttp.ClientError, OSError) as e:
            if attempt <= max_retries:
                with contextlib.suppress(Exception):
                    request_logger.log_request_response(
                        operation_id=operation_id,
                        request_method="PUT",
                        request_url=upload_url,
                        request_headers=headers or None,
                        request_data=f"[File data {len(data)} bytes]",
                        error_message=f"{type(e).__name__}: {str(e)} (will retry)",
                    )
                await sleep_with_interrupt(
                    delay,
                    cls,
                    wait_label,
                    start_ts,
                    None,
                    display_callback=_display_time_progress if wait_label else None,
                )
                delay *= retry_backoff
                continue

            # Retries exhausted: tell a local connectivity problem apart from an unreachable service.
            diag = await _diagnose_connectivity()
            if not diag["internet_accessible"]:
                raise LocalNetworkError(
                    "Unable to connect to the network. Please check your internet connection and try again."
                ) from e
            raise ApiServerError("The API service appears unreachable at this time.") from e
        finally:
            # Per-attempt cleanup; also runs on `continue` and `return`.
            stop_evt.set()
            if monitor_task:
                monitor_task.cancel()
                # NOTE(review): suppress(Exception) does not cover asyncio.CancelledError
                # (a BaseException since Python 3.8). _monitor swallows its own cancellation,
                # but a task cancelled before its first step would presumably still raise
                # here - confirm this is acceptable.
                with contextlib.suppress(Exception):
                    await monitor_task
            if sess:
                with contextlib.suppress(Exception):
                    await sess.close()
|
|
|
|
|
|
def _generate_operation_id(method: str, url: str, attempt: int, op_uuid: str) -> str:
|
|
try:
|
|
parsed = urlparse(url)
|
|
slug = (parsed.path.rsplit("/", 1)[-1] or parsed.netloc or "upload").strip("/").replace("/", "_")
|
|
except Exception:
|
|
slug = "upload"
|
|
return f"{method}_{slug}_{op_uuid}_try{attempt}"
|