Add CogVideoXFunResizeToClosestBucket

This commit is contained in:
kijai 2024-10-30 22:19:26 +02:00
parent 5ba9b1d634
commit 133b42eb4f

View File

@ -838,6 +838,40 @@ class CogVideoDecode:
return (video,)
class CogVideoXFunResizeToClosestBucket:
upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
@classmethod
def INPUT_TYPES(s):
return {"required": {
"images": ("IMAGE", ),
"base_resolution": ("INT", {"min": 64, "max": 1280, "step": 64, "default": 512, "tooltip": "Base resolution, closest training data bucket resolution is chosen based on the selection."}),
"upscale_method": (s.upscale_methods, {"default": "lanczos", "tooltip": "Upscale method to use"}),
"crop": (["disabled","center"],),
},
}
RETURN_TYPES = ("IMAGE", "INT", "INT")
RETURN_NAMES = ("images", "width", "height")
FUNCTION = "resize"
CATEGORY = "CogVideoWrapper"
def resize(self, images, base_resolution, upscale_method, crop):
from comfy.utils import common_upscale
B, H, W, C = images.shape
# Count most suitable height and width
aspect_ratio_sample_size = {key : [x / 512 * base_resolution for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()}
closest_size, closest_ratio = get_closest_ratio(H, W, ratios=aspect_ratio_sample_size)
height, width = [int(x / 16) * 16 for x in closest_size]
log.info(f"Closest bucket size: {width}x{height}")
resized_images = images.clone().movedim(-1,1)
resized_images = common_upscale(resized_images, width, height, upscale_method, crop)
resized_images = resized_images.movedim(1,-1)
return (resized_images, width, height)
class CogVideoXFunSampler:
@classmethod
def INPUT_TYPES(s):
@ -1266,7 +1300,8 @@ NODE_CLASS_MAPPINGS = {
"CogVideoControlNet": CogVideoControlNet,
"ToraEncodeTrajectory": ToraEncodeTrajectory,
"ToraEncodeOpticalFlow": ToraEncodeOpticalFlow,
"CogVideoXFasterCache": CogVideoXFasterCache
"CogVideoXFasterCache": CogVideoXFasterCache,
"CogVideoXFunResizeToClosestBucket": CogVideoXFunResizeToClosestBucket
}
NODE_DISPLAY_NAME_MAPPINGS = {
"CogVideoSampler": "CogVideo Sampler",
@ -1286,5 +1321,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"CogVideoContextOptions": "CogVideo Context Options",
"ToraEncodeTrajectory": "Tora Encode Trajectory",
"ToraEncodeOpticalFlow": "Tora Encode OpticalFlow",
"CogVideoXFasterCache": "CogVideoX FasterCache"
"CogVideoXFasterCache": "CogVideoX FasterCache",
"CogVideoXFunResizeToClosestBucket": "CogVideoXFun ResizeToClosestBucket"
}