mirror of
https://git.datalinker.icu/comfyanonymous/ComfyUI
synced 2025-12-08 21:44:33 +08:00
74 lines
2.5 KiB
Python
74 lines
2.5 KiB
Python
from __future__ import annotations
|
||
import aiohttp
|
||
import mimetypes
|
||
from typing import Union
|
||
from server import PromptServer
|
||
|
||
import numpy as np
|
||
from PIL import Image
|
||
import torch
|
||
import base64
|
||
from io import BytesIO
|
||
|
||
|
||
async def validate_and_cast_response(
    response, timeout: Union[int, None] = None, node_id: Union[str, None] = None
) -> torch.Tensor:
    """Validate an image response and convert its images to a torch.Tensor.

    Every entry of ``response.data`` is decoded from its ``b64_json`` field
    or, when that is empty, downloaded from its ``url`` field. Each image is
    converted to RGBA and scaled to float32 values in the range [0, 1].

    Args:
        response: Object exposing a ``data`` sequence whose items carry
            ``b64_json`` and ``url`` attributes.
        timeout: Total timeout in seconds for the download session.
            Defaults to None (no timeout).
        node_id: When set, the URL of each downloaded image is sent as
            progress text for this node through ``PromptServer``.

    Returns:
        A torch.Tensor of shape (N, H, W, 4) holding the N decoded images.
        The images are combined with ``torch.stack``, so they are expected
        to share the same height and width.

    Raises:
        ValueError: If the response holds no images, a download answers with
            a status other than 200, or an entry has neither base64 data
            nor a URL.
    """
    # Validate the raw response: an empty or missing data array is an error.
    data = response.data
    if not data:
        raise ValueError("No images returned from API endpoint")

    image_tensors: list[torch.Tensor] = []

    # A single session is shared by every download of this response.
    async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=timeout)) as session:
        for img_data in data:
            img_bytes: bytes
            if img_data.b64_json:
                # Inline base64 payload takes precedence over a URL.
                img_bytes = base64.b64decode(img_data.b64_json)
            elif img_data.url:
                if node_id:
                    PromptServer.instance.send_progress_text(f"Result URL: {img_data.url}", node_id)
                async with session.get(img_data.url) as resp:
                    if resp.status != 200:
                        raise ValueError("Failed to download generated image")
                    img_bytes = await resp.read()
            else:
                raise ValueError("Invalid image payload – neither URL nor base64 data present.")

            # The context manager releases the decoder once the RGBA copy exists.
            with Image.open(BytesIO(img_bytes)) as raw_img:
                pil_img = raw_img.convert("RGBA")
            arr = np.asarray(pil_img).astype(np.float32) / 255.0
            image_tensors.append(torch.from_numpy(arr))

    return torch.stack(image_tensors, dim=0)
|
||
|
||
|
||
def text_filepath_to_base64_string(filepath: str) -> str:
    """Return the contents of the file at ``filepath`` as a base64 string.

    The file is read in binary mode, so its bytes are encoded unchanged.
    """
    with open(filepath, "rb") as handle:
        encoded = base64.b64encode(handle.read())
    # base64 output is pure ASCII; decode it to hand back a str.
    return encoded.decode("utf-8")
|
||
|
||
|
||
def text_filepath_to_data_uri(filepath: str) -> str:
    """Build a ``data:`` URI embedding the file at ``filepath`` as base64.

    The MIME type is guessed from the file name; when no type can be
    guessed, ``application/octet-stream`` is used.
    """
    payload = text_filepath_to_base64_string(filepath)
    guessed_type = mimetypes.guess_type(filepath)[0]
    mime_type = "application/octet-stream" if guessed_type is None else guessed_type
    return f"data:{mime_type};base64,{payload}"
|