mirror of
https://git.datalinker.icu/comfyanonymous/ComfyUI
synced 2025-12-09 22:14:34 +08:00
231 lines
8.4 KiB
Python
231 lines
8.4 KiB
Python
import logging
|
||
from typing import Optional
|
||
|
||
import torch
|
||
|
||
from comfy_api.input.video_types import VideoInput
|
||
from comfy_api.latest import Input
|
||
|
||
|
||
def get_image_dimensions(image: torch.Tensor) -> tuple[int, int]:
    """Return the two spatial sizes of an image tensor as ``(height, width)``.

    For a 4-D tensor the leading axis is treated as a batch axis and axes 1
    and 2 are returned; for a 3-D tensor axes 0 and 1 are returned.

    Raises:
        ValueError: if the tensor is neither 3-D nor 4-D.
    """
    rank = len(image.shape)
    if rank not in (3, 4):
        raise ValueError("Invalid image tensor shape.")
    # A 4-D tensor carries one extra leading (batch) axis to skip over.
    first_spatial = rank - 3
    return image.shape[first_spatial], image.shape[first_spatial + 1]
|
||
|
||
|
||
def validate_image_dimensions(
    image: torch.Tensor,
    min_width: Optional[int] = None,
    max_width: Optional[int] = None,
    min_height: Optional[int] = None,
    max_height: Optional[int] = None,
):
    """Raise ``ValueError`` if the image's width or height is outside the given bounds.

    All bounds are inclusive and any bound left as ``None`` is not checked.
    Width is checked before height, lower bound before upper bound.
    """
    height, width = get_image_dimensions(image)
    for label, actual, lower, upper in (
        ("width", width, min_width, max_width),
        ("height", height, min_height, max_height),
    ):
        if lower is not None and actual < lower:
            raise ValueError(f"Image {label} must be at least {lower}px, got {actual}px")
        if upper is not None and actual > upper:
            raise ValueError(f"Image {label} must be at most {upper}px, got {actual}px")
|
||
|
||
|
||
def validate_image_aspect_ratio(
    image: torch.Tensor,
    min_ratio: Optional[tuple[float, float]] = None,  # e.g. (1, 4)
    max_ratio: Optional[tuple[float, float]] = None,  # e.g. (4, 1)
    *,
    strict: bool = True,  # True -> (min, max); False -> [min, max]
) -> float:
    """Validate that the image's width/height aspect ratio lies within optional bounds.

    Args:
        image: image tensor accepted by ``get_image_dimensions``.
        min_ratio: lower bound as an ``(a, b)`` pair meaning ``a / b``, e.g. ``(1, 4)``;
            ``None`` disables the lower check.
        max_ratio: upper bound as an ``(a, b)`` pair meaning ``a / b``, e.g. ``(4, 1)``;
            ``None`` disables the upper check.
        strict: if True the bounds are exclusive, otherwise inclusive.

    Returns:
        The aspect ratio ``width / height``.

    Raises:
        ValueError: if a dimension is non-positive, a bound is invalid, or the
            ratio falls outside the bounds.
    """
    # get_image_dimensions returns (height, width); unpacking in the other
    # order would invert the ratio (a 2:1 landscape image would read as 1:2).
    h, w = get_image_dimensions(image)
    if w <= 0 or h <= 0:
        raise ValueError(f"Invalid image dimensions: {w}x{h}")
    ar = w / h
    _assert_ratio_bounds(ar, min_ratio=min_ratio, max_ratio=max_ratio, strict=strict)
    return ar
|
||
|
||
|
||
def validate_images_aspect_ratio_closeness(
    first_image: torch.Tensor,
    second_image: torch.Tensor,
    min_rel: float,  # e.g. 0.8
    max_rel: float,  # e.g. 1.25
    *,
    strict: bool = False,  # True -> (min, max); False -> [min, max]
) -> float:
    """
    Validates that the two images' aspect ratios are 'close'.

    Each aspect ratio is width / height. The closeness factor is
    C = max(ar1, ar2) / min(ar1, ar2) (C >= 1). We require C <= limit
    (C < limit when ``strict``), where limit = max(max_rel, 1.0 / min_rel).
    Because both bounds are folded into one symmetric limit, the order of the
    two images does not affect the outcome.

    Args:
        first_image: image tensor accepted by ``get_image_dimensions``.
        second_image: image tensor accepted by ``get_image_dimensions``.
        min_rel: relative lower bound (e.g. 0.8); must be positive.
        max_rel: relative upper bound (e.g. 1.25); must be positive.
        strict: exclusive limit if True, inclusive otherwise.

    Returns the computed closeness factor C.

    Raises:
        ValueError: on non-positive bounds or dimensions, or when the aspect
            ratios are not close enough.
    """
    # Guard before computing 1.0 / min_rel, which would divide by zero.
    if min_rel <= 0 or max_rel <= 0:
        raise ValueError(f"min_rel and max_rel must be positive, got {min_rel} and {max_rel}.")
    # get_image_dimensions returns (height, width).
    h1, w1 = get_image_dimensions(first_image)
    h2, w2 = get_image_dimensions(second_image)
    if min(w1, h1, w2, h2) <= 0:
        raise ValueError("Invalid image dimensions")
    ar1 = w1 / h1
    ar2 = w2 / h2
    closeness = max(ar1, ar2) / min(ar1, ar2)
    limit = max(max_rel, 1.0 / min_rel)
    if (closeness >= limit) if strict else (closeness > limit):
        raise ValueError(
            f"Aspect ratios must be close: ar1/ar2={ar1/ar2:.2g}, "
            f"allowed range {min_rel}–{max_rel} (limit {limit:.2g})."
        )
    return closeness
|
||
|
||
|
||
def validate_aspect_ratio_string(
    aspect_ratio: str,
    min_ratio: Optional[tuple[float, float]] = None,  # e.g. (1, 4)
    max_ratio: Optional[tuple[float, float]] = None,  # e.g. (4, 1)
    *,
    strict: bool = False,  # True -> (min, max); False -> [min, max]
) -> float:
    """Parse an ``"X:Y"`` aspect-ratio string and check it against optional bounds.

    Bounds are ``(a, b)`` pairs meaning ``a / b``; ``None`` skips that side.
    With ``strict`` the bounds are exclusive, otherwise inclusive.

    Returns:
        The parsed ratio ``X / Y`` as a float.
    """
    numeric_ratio = _parse_aspect_ratio_string(aspect_ratio)
    _assert_ratio_bounds(
        numeric_ratio,
        min_ratio=min_ratio,
        max_ratio=max_ratio,
        strict=strict,
    )
    return numeric_ratio
|
||
|
||
|
||
def validate_video_dimensions(
    video: Input.Video,
    min_width: Optional[int] = None,
    max_width: Optional[int] = None,
    min_height: Optional[int] = None,
    max_height: Optional[int] = None,
):
    """Raise ``ValueError`` if the video's width or height is outside the given bounds.

    All bounds are inclusive and any bound left as ``None`` is not checked.
    This is best-effort: if ``video.get_dimensions()`` fails, the error is
    logged and the function returns without validating.
    """
    try:
        width, height = video.get_dimensions()
    except Exception as e:
        logging.error("Error getting dimensions of video: %s", e)
        return

    for label, actual, lower, upper in (
        ("width", width, min_width, max_width),
        ("height", height, min_height, max_height),
    ):
        if lower is not None and actual < lower:
            raise ValueError(f"Video {label} must be at least {lower}px, got {actual}px")
        if upper is not None and actual > upper:
            raise ValueError(f"Video {label} must be at most {upper}px, got {actual}px")
|
||
|
||
|
||
def validate_video_duration(
    video: Input.Video,
    min_duration: Optional[float] = None,
    max_duration: Optional[float] = None,
):
    """Raise ``ValueError`` if the video's duration (seconds) is outside the given bounds.

    Each bound is skipped when ``None`` and is compared with a 0.0001 s
    tolerance. This is best-effort: if ``video.get_duration()`` fails, the
    error is logged and the function returns without validating.
    """
    try:
        duration = video.get_duration()
    except Exception as e:
        logging.error("Error getting duration of video: %s", e)
        return

    # Presumably absorbs float rounding in the reported duration.
    tolerance = 0.0001
    if min_duration is not None and duration < min_duration - tolerance:
        raise ValueError(f"Video duration must be at least {min_duration}s, got {duration}s")
    if max_duration is not None and duration > max_duration + tolerance:
        raise ValueError(f"Video duration must be at most {max_duration}s, got {duration}s")
|
||
|
||
|
||
def get_number_of_images(images):
    """Count the images in ``images``.

    A tensor with 4 or more dimensions is a batch whose size is its first
    axis. A tensor of lower rank counts as one image. Any non-tensor object is
    measured with ``len``.
    """
    if not isinstance(images, torch.Tensor):
        return len(images)
    if images.ndim < 4:
        return 1
    return images.shape[0]
|
||
|
||
|
||
def validate_audio_duration(
    audio: Input.Audio,
    min_duration: Optional[float] = None,
    max_duration: Optional[float] = None,
) -> None:
    """Raise ``ValueError`` if the audio clip's duration is outside the given bounds.

    The duration is the length of the last axis of ``audio["waveform"]``
    divided by ``audio["sample_rate"]``. Each bound is skipped when ``None``
    and is compared with a tolerance of one sample period (``1 / sample_rate``).

    Raises:
        ValueError: if the sample rate is not positive, or the duration is
            outside the bounds.
    """
    sr = int(audio["sample_rate"])
    # Without this guard a zero rate would surface as ZeroDivisionError.
    if sr <= 0:
        raise ValueError(f"Audio sample rate must be positive, got {sr}")
    dur = int(audio["waveform"].shape[-1]) / sr
    # One sample period of slack, presumably to absorb sample-count quantization.
    eps = 1.0 / sr
    # Error messages report the measured duration, not the tolerance-adjusted value.
    if min_duration is not None and dur + eps < min_duration:
        raise ValueError(f"Audio duration must be at least {min_duration}s, got {dur:.2f}s")
    if max_duration is not None and dur - eps > max_duration:
        raise ValueError(f"Audio duration must be at most {max_duration}s, got {dur:.2f}s")
|
||
|
||
|
||
def validate_string(
    string: str,
    strip_whitespace=True,
    field_name="prompt",
    min_length=None,
    max_length=None,
):
    """Validate that a text field is present and within length bounds.

    Args:
        string: value to check; ``None`` is rejected.
        strip_whitespace: strip leading/trailing whitespace before measuring length.
        field_name: name of the field, used in error messages.
        min_length: minimum length in characters; a falsy value (``None`` or 0)
            disables the check.
        max_length: maximum length in characters; a falsy value (``None`` or 0)
            disables the check.

    Raises:
        Exception: if ``string`` is ``None`` or its length is out of bounds.
    """
    if string is None:
        raise Exception(f"Field '{field_name}' cannot be empty.")
    if strip_whitespace:
        string = string.strip()
    if min_length and len(string) < min_length:
        raise Exception(
            f"Field '{field_name}' cannot be shorter than {min_length} characters; was {len(string)} characters long."
        )
    if max_length and len(string) > max_length:
        # Same message shape as the min-length error: no leading space, field name fully quoted.
        raise Exception(
            f"Field '{field_name}' cannot be longer than {max_length} characters; was {len(string)} characters long."
        )
|
||
|
||
|
||
def validate_container_format_is_mp4(video: VideoInput) -> None:
    """Raise ``ValueError`` unless the video's container format is reported as MP4.

    Two names are accepted: the bare ``"mp4"`` and the combined demuxer name
    ``"mov,mp4,m4a,3gp,3g2,mj2"`` (presumably what FFmpeg/PyAV report for
    MP4-family files — confirm against ``get_container_format``).
    """
    container_format = video.get_container_format()
    if container_format in ("mp4", "mov,mp4,m4a,3gp,3g2,mj2"):
        return
    raise ValueError(f"Only MP4 container format supported. Got: {container_format}")
|
||
|
||
|
||
def _ratio_from_tuple(r: tuple[float, float]) -> float:
|
||
a, b = r
|
||
if a <= 0 or b <= 0:
|
||
raise ValueError(f"Ratios must be positive, got {a}:{b}.")
|
||
return a / b
|
||
|
||
|
||
def _assert_ratio_bounds(
    ar: float,
    *,
    min_ratio: Optional[tuple[float, float]] = None,
    max_ratio: Optional[tuple[float, float]] = None,
    strict: bool = True,
) -> None:
    """Validate a numeric aspect ratio against optional min/max ratio bounds.

    Args:
        ar: the aspect ratio to check.
        min_ratio: lower bound as an ``(a, b)`` pair meaning ``a / b``; ``None`` skips it.
        max_ratio: upper bound as an ``(a, b)`` pair meaning ``a / b``; ``None`` skips it.
        strict: if True the bounds are exclusive (``lo < ar < hi``), otherwise
            inclusive (``lo <= ar <= hi``).

    Raises:
        ValueError: if a bound pair has a non-positive component (raised by
            ``_ratio_from_tuple``), or ``ar`` lies outside the bounds.
    """
    lo = _ratio_from_tuple(min_ratio) if min_ratio is not None else None
    hi = _ratio_from_tuple(max_ratio) if max_ratio is not None else None

    if lo is not None and hi is not None and lo > hi:
        lo, hi = hi, lo  # normalize order if caller swapped them

    if lo is not None and ((ar <= lo) if strict else (ar < lo)):
        # ar fell below the lower bound, so the requirement is ar > lo (or >= lo).
        op = ">" if strict else "≥"
        raise ValueError(f"Aspect ratio `{ar:.2g}` must be {op} {lo:.2g}.")
    if hi is not None and ((ar >= hi) if strict else (ar > hi)):
        op = "<" if strict else "≤"
        raise ValueError(f"Aspect ratio `{ar:.2g}` must be {op} {hi:.2g}.")
|
||
|
||
|
||
def _parse_aspect_ratio_string(ar_str: str) -> float:
|
||
"""Parse 'X:Y' with integer parts into a positive float ratio X/Y."""
|
||
parts = ar_str.split(":")
|
||
if len(parts) != 2:
|
||
raise ValueError(f"Aspect ratio must be 'X:Y' (e.g., 16:9), got '{ar_str}'.")
|
||
try:
|
||
a = int(parts[0].strip())
|
||
b = int(parts[1].strip())
|
||
except ValueError as exc:
|
||
raise ValueError(f"Aspect ratio must contain integers separated by ':', got '{ar_str}'.") from exc
|
||
if a <= 0 or b <= 0:
|
||
raise ValueError(f"Aspect ratio parts must be positive integers, got {a}:{b}.")
|
||
return a / b
|