# Mirror of https://git.datalinker.icu/comfyanonymous/ComfyUI
# (synced 2025-12-08 21:44:33 +08:00; 454 lines, 16 KiB, Python)
import base64
|
|
import logging
|
|
import math
|
|
import uuid
|
|
from io import BytesIO
|
|
from typing import Optional
|
|
|
|
import av
|
|
import numpy as np
|
|
import torch
|
|
from PIL import Image
|
|
|
|
from comfy.utils import common_upscale
|
|
from comfy_api.latest import Input, InputImpl
|
|
from comfy_api.util import VideoContainer, VideoCodec
|
|
|
|
from ._helpers import mimetype_to_extension
|
|
|
|
|
|
def bytesio_to_image_tensor(image_bytesio: BytesIO, mode: str = "RGBA") -> torch.Tensor:
    """Decode an in-memory encoded image into a float tensor with a batch axis.

    Args:
        image_bytesio: Buffer holding encoded image bytes that PIL can open.
        mode: PIL mode the decoded image is converted to first (e.g. "RGB", "RGBA").

    Returns:
        float32 tensor of shape (1, H, W, C), pixel values divided by 255.

    Raises:
        PIL.UnidentifiedImageError: If PIL cannot recognise the image data.
        ValueError: If ``mode`` is not a valid conversion target.
    """
    decoded = Image.open(image_bytesio)
    converted = decoded.convert(mode)
    # astype() copies, so the array handed to torch is writable.
    normalized = np.asarray(converted).astype(np.float32) / 255.0
    return torch.from_numpy(normalized)[None]
|
|
|
|
|
|
def image_tensor_pair_to_batch(image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor:
    """Concatenate two image tensors into one batch along dim 0.

    Both tensors are (B, H, W, C). If their height/width differ, ``image2`` is
    resampled to ``image1``'s height and width (bilinear, ``crop="center"``).
    ``image1`` always defines the target size, whichever image is larger.

    Args:
        image1: Reference images; their height/width are kept.
        image2: Images resized to match ``image1`` when the spatial size differs.

    Returns:
        Tensor of shape (B1 + B2, H1, W1, C).
    """
    if image1.shape[1:3] != image2.shape[1:3]:
        # common_upscale works channels-first: move C next to the batch axis and back.
        image2 = common_upscale(
            image2.movedim(-1, 1),
            image1.shape[2],
            image1.shape[1],
            "bilinear",
            "center",
        ).movedim(1, -1)
    return torch.cat((image1, image2), dim=0)
|
|
|
|
|
|
def tensor_to_bytesio(
    image: torch.Tensor,
    name: Optional[str] = None,
    total_pixels: int = 2048 * 2048,
    mime_type: str = "image/png",
) -> BytesIO:
    """Encode an image tensor into a named, rewound BytesIO.

    Args:
        image: Image tensor handed to ``tensor_to_pil``.
        name: Filename stem; a random UUID is used when empty or None.
        total_pixels: Pixel budget for optional downscaling.
        mime_type: Target MIME type (e.g. 'image/png', 'image/jpeg', 'image/webp', 'video/mp4');
            empty or None falls back to "image/png".

    Returns:
        BytesIO with the encoded bytes, positioned at 0, whose ``.name`` is
        ``<stem>.<extension for mime_type>``.
    """
    target_mime = mime_type or "image/png"
    encoded = pil_to_bytesio(tensor_to_pil(image, total_pixels=total_pixels), mime_type=target_mime)
    stem = name or uuid.uuid4()
    encoded.name = f"{stem}.{mimetype_to_extension(target_mime)}"
    return encoded
|
|
|
|
|
|
def tensor_to_pil(image: torch.Tensor, total_pixels: int = 2048 * 2048) -> Image.Image:
    """Convert a single image tensor to a PIL image, downscaling it if needed.

    Args:
        image: Image tensor in (H, W, C) layout. A batched (B, H, W, C) tensor
            is accepted too; only its first image is used. Values are expected
            in [0, 1]; anything outside is clipped.
        total_pixels: Pixel budget forwarded to ``downscale_image_tensor``.

    Returns:
        An 8-bit PIL image. Single-channel input becomes a 2-D (grayscale) image.
    """
    if image.ndim > 3:
        image = image[0]
    # TODO: remove alpha if not allowed and present
    # float() because numpy has no bfloat16; detach() so grad-tracking tensors convert.
    input_tensor = image.detach().cpu().float()
    # Drop only the batch axis added for downscaling; a bare .squeeze() would
    # also collapse H or W when either of them equals 1.
    input_tensor = downscale_image_tensor(input_tensor.unsqueeze(0), total_pixels=total_pixels).squeeze(0)
    if input_tensor.ndim == 3 and input_tensor.shape[-1] == 1:
        # PIL cannot build an image from an (H, W, 1) array; use (H, W) instead.
        input_tensor = input_tensor.squeeze(-1)
    # Clip before the uint8 cast so out-of-range values saturate instead of wrapping.
    image_np = np.clip(input_tensor.numpy() * 255.0, 0, 255).astype(np.uint8)
    return Image.fromarray(image_np)
|
|
|
|
|
|
def tensor_to_base64_string(
    image_tensor: torch.Tensor,
    total_pixels: int = 2048 * 2048,
    mime_type: str = "image/png",
) -> str:
    """Encode a [B, H, W, C] or [H, W, C] image tensor and return it as base64 text.

    Args:
        image_tensor: Image tensor handed to ``tensor_to_pil``.
        total_pixels: Pixel budget for optional downscaling.
        mime_type: Target MIME type (e.g. 'image/png', 'image/jpeg', 'image/webp', 'video/mp4').

    Returns:
        The encoded image bytes as a base64 string.
    """
    encoded = pil_to_bytesio(
        tensor_to_pil(image_tensor, total_pixels=total_pixels),
        mime_type=mime_type,
    )
    return base64.b64encode(encoded.getvalue()).decode("utf-8")
|
|
|
|
|
|
def pil_to_bytesio(img: Image.Image, mime_type: str = "image/png") -> BytesIO:
    """Encode a PIL image into an in-memory buffer.

    The PIL format name is the upper-cased MIME subtype ("image/png" -> "PNG",
    "image/webp" -> "WEBP"), with "image/jpg" mapped to "JPEG".

    Args:
        img: Image to encode.
        mime_type: Target MIME type; empty or None falls back to "image/png".

    Returns:
        BytesIO holding the encoded bytes, rewound to position 0.
    """
    if not mime_type:
        mime_type = "image/png"

    # Derive PIL format from MIME type (e.g., 'image/png' -> 'PNG')
    pil_format = mime_type.split("/")[-1].upper()
    if pil_format == "JPG":
        pil_format = "JPEG"

    if pil_format == "JPEG" and img.mode in ("RGBA", "LA", "P", "PA"):
        # JPEG has no alpha/palette support; PIL raises OSError for these modes.
        img = img.convert("RGB")

    img_byte_arr = BytesIO()
    img.save(img_byte_arr, format=pil_format)
    img_byte_arr.seek(0)
    return img_byte_arr
|
|
|
|
|
|
def downscale_image_tensor(image, total_pixels=1536 * 1024) -> torch.Tensor:
    """Downscale a channels-last image tensor to roughly ``total_pixels`` pixels.

    Aspect ratio is preserved. Images already within the budget are returned
    unchanged (this never upscales).

    Args:
        image: Channels-last tensor, e.g. (B, H, W, C).
        total_pixels: Target pixel count (H * W) after resizing.

    Returns:
        ``image`` itself, or a Lanczos-resampled tensor in the same layout.
    """
    samples = image.movedim(-1, 1)
    height, width = samples.shape[2], samples.shape[3]
    if height == 0 or width == 0:
        # Nothing to resample; also avoids dividing by a zero pixel count.
        return image

    total = int(total_pixels)
    scale_by = math.sqrt(total / (width * height))
    if scale_by >= 1:
        return image

    # Clamp to 1 so tiny budgets or extreme aspect ratios never request a 0-sized axis.
    new_width = max(1, round(width * scale_by))
    new_height = max(1, round(height * scale_by))

    s = common_upscale(samples, new_width, new_height, "lanczos", "disabled")
    return s.movedim(1, -1)
|
|
|
|
|
|
def tensor_to_data_uri(
    image_tensor: torch.Tensor,
    total_pixels: int = 2048 * 2048,
    mime_type: str = "image/png",
) -> str:
    """Converts a tensor image to a Data URI string.

    Args:
        image_tensor: Input torch.Tensor image.
        total_pixels: Maximum total pixels for potential downscaling.
        mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp');
            empty or None falls back to "image/png".

    Returns:
        Data URI string (e.g., 'data:image/png;base64,...').
    """
    # The encoder already falls back to PNG for an empty MIME type; normalise here
    # too so the URI prefix describes the bytes that are actually produced.
    mime_type = mime_type or "image/png"
    base64_string = tensor_to_base64_string(image_tensor, total_pixels, mime_type)
    return f"data:{mime_type};base64,{base64_string}"
|
|
|
|
|
|
def audio_to_base64_string(audio: Input.Audio, container_format: str = "mp4", codec_name: str = "aac") -> str:
    """Encode a Comfy AUDIO dict with the given container/codec and return base64 text.

    Args:
        audio: Dict with "sample_rate" and "waveform" entries.
        container_format: Container passed to the encoder (default "mp4").
        codec_name: Audio codec passed to the encoder (default "aac").

    Returns:
        The encoded audio bytes as a base64 string.
    """
    rate: int = audio["sample_rate"]
    pcm = audio_tensor_to_contiguous_ndarray(audio["waveform"])
    encoded = audio_ndarray_to_bytesio(pcm, rate, container_format, codec_name)
    return base64.b64encode(encoded.getvalue()).decode("utf-8")
|
|
|
|
|
|
def video_to_base64_string(
    video: Input.Video,
    container_format: Optional[VideoContainer] = None,
    codec: Optional[VideoCodec] = None,
) -> str:
    """
    Converts a video input to a base64 string.

    Args:
        video: The video input to convert.
        container_format: Container to write. When None, ``video.container`` is
            used if the object has that attribute, otherwise MP4.
        codec: Codec to write. When None, ``video.codec`` is used if the object
            has that attribute, otherwise H.264.

    Returns:
        Base64 string of the encoded video bytes.
    """
    # Use provided format/codec if specified, otherwise use video's own if available
    format_to_use = container_format if container_format is not None else getattr(video, "container", VideoContainer.MP4)
    codec_to_use = codec if codec is not None else getattr(video, "codec", VideoCodec.H264)

    video_bytes_io = BytesIO()
    video.save_to(video_bytes_io, format=format_to_use, codec=codec_to_use)
    # getvalue() returns the whole buffer regardless of the stream position.
    return base64.b64encode(video_bytes_io.getvalue()).decode("utf-8")
|
|
|
|
|
|
def audio_ndarray_to_bytesio(
    audio_data_np: np.ndarray,
    sample_rate: int,
    container_format: str = "mp4",
    codec_name: str = "aac",
) -> BytesIO:
    """
    Encode an array of planar float audio samples into an in-memory container.

    Args:
        audio_data_np: Samples wrapped as a single planar ("fltp") frame; axis 0
            is read as the channel count (>1 is labelled "stereo", else "mono").
        sample_rate: Sample rate for both the output stream and the frame.
        container_format: Container format for PyAV (default "mp4").
        codec_name: Encoder name for PyAV (default "aac").

    Returns:
        BytesIO with the encoded container, rewound to position 0.
    """
    buffer = BytesIO()
    with av.open(buffer, mode="w", format=container_format) as container:
        stream = container.add_stream(codec_name, rate=sample_rate)
        channel_layout = "mono" if audio_data_np.shape[0] <= 1 else "stereo"
        pcm_frame = av.AudioFrame.from_ndarray(audio_data_np, format="fltp", layout=channel_layout)
        pcm_frame.sample_rate = sample_rate
        pcm_frame.pts = 0

        # Encode the single frame, then pass None to flush the encoder's buffered packets.
        for pending in (pcm_frame, None):
            for packet in stream.encode(pending):
                container.mux(packet)

    buffer.seek(0)
    return buffer
|
|
|
|
|
|
def audio_tensor_to_contiguous_ndarray(waveform: torch.Tensor) -> np.ndarray:
    """
    Prepare a Comfy ``AUDIO`` waveform for the av library as a contiguous array.

    Args:
        waveform: Tensor of shape (batch, channels, samples) from a Comfy
            ``AUDIO`` dict. If the audio is batched, only the first item is used.

    Returns:
        C-contiguous float32 numpy array of shape (channels, samples).

    Raises:
        ValueError: If ``waveform`` is not 3-D or its batch dimension is empty.
    """
    if waveform.ndim != 3 or waveform.shape[0] < 1:
        raise ValueError("Expected waveform tensor shape (batch, channels, samples)")

    # Index away exactly the batch axis (taking the first item). Unlike
    # squeeze(), this keeps a mono (channels == 1) signal 2-D.
    first = waveform[0]

    # detach() lets grad-tracking tensors convert; numpy has no bfloat16, so cast
    # to float32 on the torch side before calling .numpy().
    return first.detach().cpu().to(torch.float32).contiguous().numpy()
|
|
|
|
|
|
def audio_input_to_mp3(audio: Input.Audio) -> BytesIO:
    """Encode a Comfy AUDIO input as a 320 kbps MP3 held in memory.

    Args:
        audio: Dict with "waveform" and "sample_rate". A 3-D waveform is treated
            as (batch, channels, samples) and only the first batch item is
            encoded; a 2-D (channels, samples) tensor is used as-is.

    Returns:
        BytesIO containing the MP3 bytes, rewound to position 0.
    """
    waveform = audio["waveform"].detach().cpu()
    if waveform.ndim == 3:
        # The layout/interleave logic below needs a (channels, samples) tensor;
        # using the batched tensor would always pick "mono" and mis-order samples.
        waveform = waveform[0]
    num_channels = waveform.shape[0]

    output_buffer = BytesIO()
    with av.open(output_buffer, mode="w", format="mp3") as output_container:
        out_stream = output_container.add_stream("libmp3lame", rate=audio["sample_rate"])
        out_stream.bit_rate = 320000

        # Packed "flt" frames take interleaved samples (c0 c1 c0 c1 ...):
        # (C, T) -> (T, C) -> (1, T * C).
        frame = av.AudioFrame.from_ndarray(
            waveform.movedim(0, 1).reshape(1, -1).float().numpy(),
            format="flt",
            layout="mono" if num_channels == 1 else "stereo",
        )
        frame.sample_rate = audio["sample_rate"]
        frame.pts = 0
        output_container.mux(out_stream.encode(frame))
        # Flush the encoder.
        output_container.mux(out_stream.encode(None))

    output_buffer.seek(0)
    return output_buffer
|
|
|
|
|
|
def trim_video(video: Input.Video, duration_sec: float) -> Input.Video:
    """
    Returns a new VideoInput object trimmed from the beginning to the specified duration,
    using av to avoid loading entire video into memory.

    The result is re-encoded to MP4 (H.264 video, AAC audio). The number of
    video frames kept is rounded DOWN to a multiple of 16. Audio is cut at
    ``duration_sec``. Only the first video stream and the first audio stream
    (if any) are carried over.

    Args:
        video: Input video to trim
        duration_sec: Duration in seconds to keep from the beginning

    Returns:
        VideoFromFile object that owns the output buffer

    Raises:
        RuntimeError: Wraps any failure, including a missing video stream, an
            unknown frame rate, or fewer than 16 frames in ``duration_sec``.
    """
    output_buffer = BytesIO()
    input_container = None
    output_container = None

    try:
        # Get the stream source - this avoids loading entire video into memory
        # when the source is already a file path
        input_source = video.get_stream_source()

        # Open containers
        input_container = av.open(input_source, mode="r")
        output_container = av.open(output_buffer, mode="w", format="mp4")

        if not input_container.streams.video:
            raise ValueError("Video has no video stream to trim")

        # Set up output streams for re-encoding. decode(video=0) / decode(audio=0)
        # below only read the first stream of each type, so only one output
        # stream per type is created (extra ones would have stayed empty).
        video_stream = None
        audio_stream = None

        for stream in input_container.streams:
            logging.info("Found stream: type=%s, class=%s", stream.type, type(stream))
            if isinstance(stream, av.VideoStream) and video_stream is None:
                # Create output video stream with same parameters
                video_stream = output_container.add_stream("h264", rate=stream.average_rate)
                video_stream.width = stream.width
                video_stream.height = stream.height
                video_stream.pix_fmt = "yuv420p"
                logging.info("Added video stream: %sx%s @ %sfps", stream.width, stream.height, stream.average_rate)
            elif isinstance(stream, av.AudioStream) and audio_stream is None:
                # Create output audio stream with same parameters
                audio_stream = output_container.add_stream("aac", rate=stream.sample_rate)
                audio_stream.sample_rate = stream.sample_rate
                audio_stream.layout = stream.layout
                logging.info("Added audio stream: %sHz, %s channels", stream.sample_rate, stream.channels)

        # Calculate target frame count that's divisible by 16
        fps = input_container.streams.video[0].average_rate
        if not fps:
            raise ValueError("Could not determine video frame rate")
        estimated_frames = int(duration_sec * fps)
        target_frames = (estimated_frames // 16) * 16  # Round down to nearest multiple of 16

        if target_frames == 0:
            raise ValueError("Video too short: need at least 16 frames for Moonvalley")

        frame_count = 0
        audio_frame_count = 0

        # Decode and re-encode video frames
        if video_stream is not None:
            for frame in input_container.decode(video=0):
                if frame_count >= target_frames:
                    break

                # Re-encode frame
                for packet in video_stream.encode(frame):
                    output_container.mux(packet)
                frame_count += 1

            # Flush encoder
            for packet in video_stream.encode():
                output_container.mux(packet)

            logging.info("Encoded %s video frames (target: %s)", frame_count, target_frames)

        # Decode and re-encode audio frames
        if audio_stream is not None:
            input_container.seek(0)  # Reset to beginning for audio
            for frame in input_container.decode(audio=0):
                if frame.time >= duration_sec:
                    break

                # Re-encode frame
                for packet in audio_stream.encode(frame):
                    output_container.mux(packet)
                audio_frame_count += 1

            # Flush encoder
            for packet in audio_stream.encode():
                output_container.mux(packet)

            logging.info("Encoded %s audio frames", audio_frame_count)

        # Close the output first so the MP4 is finalized before the buffer is read.
        output_container.close()
        input_container.close()

        # Return as VideoFromFile using the buffer
        output_buffer.seek(0)
        return InputImpl.VideoFromFile(output_buffer)

    except Exception as e:
        # Best-effort cleanup: a failure while closing must not mask the original error.
        for container in (input_container, output_container):
            if container is not None:
                try:
                    container.close()
                except Exception:
                    logging.debug("Ignoring error while closing container during trim_video cleanup", exc_info=True)
        raise RuntimeError(f"Failed to trim video: {str(e)}") from e
|
|
|
|
|
|
def _f32_pcm(wav: torch.Tensor) -> torch.Tensor:
    """Convert PCM audio samples to float32. Adapted from nodes_audio.py.

    Floating-point input is cast to float32 with values unchanged. Integer PCM
    is scaled by its full-scale value into roughly [-1, 1].

    Args:
        wav: Audio samples of any shape.

    Returns:
        float32 tensor with the same shape as ``wav``.

    Raises:
        ValueError: If the dtype is not floating point, uint8, int16 or int32.
    """
    if wav.dtype.is_floating_point:
        # Decoders may yield float64/float16; callers expect float32.
        return wav.float()
    if wav.dtype == torch.uint8:
        # 8-bit PCM is unsigned with its midpoint at 128.
        return (wav.float() - 128) / 128
    if wav.dtype == torch.int16:
        return wav.float() / (2**15)
    if wav.dtype == torch.int32:
        return wav.float() / (2**31)
    raise ValueError(f"Unsupported wav dtype: {wav.dtype}")
|
|
|
|
|
|
def audio_bytes_to_audio_input(audio_bytes: bytes) -> dict:
    """
    Decode encoded audio bytes with PyAV into a Comfy AUDIO dict.

    Only the first audio stream is decoded. Each decoded chunk is rearranged to
    channels-first before the chunks are joined along the time axis.

    Returns:
        {"waveform": tensor of shape [1, C, T] (converted by ``_f32_pcm``),
         "sample_rate": the stream's codec sample rate as an int}

    Raises:
        ValueError: If the data has no audio stream or no frames were decoded.
    """
    with av.open(BytesIO(audio_bytes)) as container:
        if not container.streams.audio:
            raise ValueError("No audio stream found in response.")
        audio_stream = container.streams.audio[0]

        sample_rate = int(audio_stream.codec_context.sample_rate)
        chunks: list[torch.Tensor] = []
        channel_count = audio_stream.channels or 1

        for decoded in container.decode(streams=audio_stream.index):
            # to_ndarray() may be [C, T], [T, C] or [T] depending on the sample format.
            chunk = torch.from_numpy(decoded.to_ndarray())
            if chunk.ndim == 1:
                chunk = chunk[None, :]  # [T] -> [1, T]
            elif chunk.shape[0] != channel_count:
                if chunk.shape[-1] == channel_count:
                    chunk = chunk.transpose(0, 1).contiguous()  # [T, C] -> [C, T]
                else:
                    chunk = chunk.reshape(-1, channel_count).t().contiguous()  # fallback to [C, T]
            chunks.append(chunk)

        if not chunks:
            raise ValueError("Decoded zero audio frames.")

        waveform = _f32_pcm(torch.cat(chunks, dim=1))  # [C, T]
        return {"waveform": waveform.unsqueeze(0).contiguous(), "sample_rate": sample_rate}
|
|
|
|
|
|
def resize_mask_to_image(
    mask: torch.Tensor,
    image: torch.Tensor,
    upscale_method="nearest-exact",
    crop="disabled",
    allow_gradient=True,
    add_channel_dim=False,
):
    """Resize mask to be the same dimensions as an image, while maintaining proper format for API calls.

    Args:
        mask: Mask of shape (B, H, W). A single (H, W) mask is treated as a
            batch of one.
        image: Reference image of shape (B, H, W, C); only its H and W are used.
        upscale_method: Interpolation mode forwarded to ``common_upscale``.
        crop: Crop mode forwarded to ``common_upscale``.
        allow_gradient: When False, binarise the resized mask with a 0.5 threshold.
        add_channel_dim: When True, keep a trailing channel axis of size 1.

    Returns:
        Mask of shape (B, H_img, W_img), or (B, H_img, W_img, 1) when
        ``add_channel_dim`` is True.
    """
    _, height, width, _ = image.shape
    if mask.ndim == 2:
        # Without a batch axis the reshuffle below would treat H as the batch.
        mask = mask.unsqueeze(0)
    # (B, H, W) -> (B, 1, H, W): common_upscale works channels-first.
    mask = mask.unsqueeze(-1).movedim(-1, 1)
    mask = common_upscale(mask, width=width, height=height, upscale_method=upscale_method, crop=crop)
    mask = mask.movedim(1, -1)
    if not add_channel_dim:
        mask = mask.squeeze(-1)
    if not allow_gradient:
        mask = (mask > 0.5).float()
    return mask
|