From 2a733b9b7d21f8e2d39fec4be874ef83f26c3652 Mon Sep 17 00:00:00 2001 From: Kijai <40791699+kijai@users.noreply.github.com> Date: Tue, 19 Dec 2023 16:18:42 +0200 Subject: [PATCH] Add StableZero123_BatchSchedule -node --- nodes.py | 174 ++++++++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 167 insertions(+), 7 deletions(-) diff --git a/nodes.py b/nodes.py index 4d407ca..83a899c 100644 --- a/nodes.py +++ b/nodes.py @@ -307,7 +307,7 @@ class CreateFadeMaskAdvanced: def INPUT_TYPES(s): return { "required": { - "points_string": ("STRING", {"default": "0:(0.0),\n7:(1.0),\n16:(0.0)\n", "multiline": True}), + "points_string": ("STRING", {"default": "0:(0.0),\n7:(1.0),\n15:(0.0)\n", "multiline": True}), "invert": ("BOOLEAN", {"default": False}), "frames": ("INT", {"default": 16,"min": 2, "max": 255, "step": 1}), "width": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}), @@ -2575,7 +2575,7 @@ class AddLabel: "font_size": ("INT", {"default": 32, "min": 0, "max": 4096, "step": 1}), "font_color": ("STRING", {"default": "white"}), "label_color": ("STRING", {"default": "black"}), - "font_path": ("STRING", {"default": "fonts\\TTNorms-Black.otf"}), + "font": ("STRING", {"default": "TTNorms-Black.otf"}), "text": ("STRING", {"default": "Text"}), "direction": ( [ 'up', @@ -2593,12 +2593,12 @@ class AddLabel: CATEGORY = "KJNodes" - def addlabel(self, image, text_x, text_y, text, height, font_size, font_color, label_color, font_path, direction): + def addlabel(self, image, text_x, text_y, text, height, font_size, font_color, label_color, font, direction): batch_size = image.shape[0] width = image.shape[2] - if font_path == "fonts\\TTNorms-Black.otf": #I don't know why relative path won't work otherwise... - font_path = os.path.join(script_dir, font_path) + if font == "TTNorms-Black.otf": + font_path = os.path.join(script_dir, "fonts", "TTNorms-Black.otf") label_image = Image.new("RGB", (width, height), label_color) draw = ImageDraw.Draw(label_image) @@ -2732,6 +2732,164 @@ class GenerateNoise: noise = noise[0].repeat(batch_size, 1, 1, 1) return ({"samples":noise}, ) +def camera_embeddings(elevation, azimuth): + elevation = torch.as_tensor([elevation]) + azimuth = torch.as_tensor([azimuth]) + embeddings = torch.stack( + [ + torch.deg2rad( + (90 - elevation) - (90) + ), # Zero123 polar is 90-elevation + torch.sin(torch.deg2rad(azimuth)), + torch.cos(torch.deg2rad(azimuth)), + torch.deg2rad( + 90 - torch.full_like(elevation, 0) + ), + ], dim=-1).unsqueeze(1) + + return embeddings + +def interpolate_angle(start, end, fraction): + # Calculate the difference in angles and adjust for wraparound if necessary + diff = (end - start + 540) % 360 - 180 + # Apply fraction to the difference + interpolated = start + fraction * diff + # Normalize the result to be within the range of -180 to 180 + return (interpolated + 180) % 360 - 180 + +class StableZero123_BatchSchedule: + @classmethod + def INPUT_TYPES(s): + return {"required": { "clip_vision": ("CLIP_VISION",), + "init_image": ("IMAGE",), + "vae": ("VAE",), + "width": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), + "height": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), + "interpolation": (["linear", "ease_in", "ease_out", "ease_in_out"],), + "azimuth_points_string": ("STRING", {"default": "0:(0.0),\n7:(1.0),\n15:(0.0)\n", "multiline": True}), + "elevation_points_string": ("STRING", {"default": "0:(0.0),\n7:(0.0),\n15:(0.0)\n", "multiline": True}), + }} + + RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") + RETURN_NAMES = ("positive", "negative", "latent") + + FUNCTION = "encode" + + CATEGORY = "KJNodes" + + def encode(self, clip_vision, init_image, vae, width, height, batch_size, azimuth_points_string, elevation_points_string, interpolation): + output = clip_vision.encode_image(init_image) + pooled = output.image_embeds.unsqueeze(0) + pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1) + encode_pixels = pixels[:,:,:,:3] + t = vae.encode(encode_pixels) + + def ease_in(t): + return t * t + def ease_out(t): + return 1 - (1 - t) * (1 - t) + def ease_in_out(t): + return 3 * t * t - 2 * t * t * t + + # Parse the azimuth input string into a list of tuples + azimuth_points = [] + azimuth_points_string = azimuth_points_string.rstrip(',\n') + for point_str in azimuth_points_string.split(','): + frame_str, azimuth_str = point_str.split(':') + frame = int(frame_str.strip()) + azimuth = float(azimuth_str.strip()[1:-1]) + azimuth_points.append((frame, azimuth)) + # Sort the points by frame number + azimuth_points.sort(key=lambda x: x[0]) + + # Parse the elevation input string into a list of tuples + elevation_points = [] + elevation_points_string = elevation_points_string.rstrip(',\n') + for point_str in elevation_points_string.split(','): + frame_str, elevation_str = point_str.split(':') + frame = int(frame_str.strip()) + elevation_val = float(elevation_str.strip()[1:-1]) + elevation_points.append((frame, elevation_val)) + # Sort the points by frame number + elevation_points.sort(key=lambda x: x[0]) + + # Index of the next point to interpolate towards + next_point = 1 + next_elevation_point = 1 + + positive_cond_out = [] + positive_pooled_out = [] + negative_cond_out = [] + negative_pooled_out = [] + + #azimuth interpolation + for i in range(batch_size): + # Find the interpolated azimuth for the current frame + while next_point < len(azimuth_points) and i >= azimuth_points[next_point][0]: + next_point += 1 + # If next_point is equal to the length of points, we've gone past the last point + if next_point == len(azimuth_points): + next_point -= 1 # Set next_point to the last index of points + prev_point = max(next_point - 1, 0) # Ensure prev_point is not less than 0 + + # Calculate fraction + if azimuth_points[next_point][0] != azimuth_points[prev_point][0]: # Prevent division by zero + fraction = (i - azimuth_points[prev_point][0]) / (azimuth_points[next_point][0] - azimuth_points[prev_point][0]) + if interpolation == "ease_in": + fraction = ease_in(fraction) + elif interpolation == "ease_out": + fraction = ease_out(fraction) + elif interpolation == "ease_in_out": + fraction = ease_in_out(fraction) + + # Use the new interpolate_angle function + interpolated_azimuth = interpolate_angle(azimuth_points[prev_point][1], azimuth_points[next_point][1], fraction) + else: + interpolated_azimuth = azimuth_points[prev_point][1] + + # Interpolate the elevation + next_elevation_point = 1 + while next_elevation_point < len(elevation_points) and i >= elevation_points[next_elevation_point][0]: + next_elevation_point += 1 + if next_elevation_point == len(elevation_points): + next_elevation_point -= 1 + prev_elevation_point = max(next_elevation_point - 1, 0) + + if elevation_points[next_elevation_point][0] != elevation_points[prev_elevation_point][0]: + fraction = (i - elevation_points[prev_elevation_point][0]) / (elevation_points[next_elevation_point][0] - elevation_points[prev_elevation_point][0]) + if interpolation == "ease_in": + fraction = ease_in(fraction) + elif interpolation == "ease_out": + fraction = ease_out(fraction) + elif interpolation == "ease_in_out": + fraction = ease_in_out(fraction) + + interpolated_elevation = interpolate_angle(elevation_points[prev_elevation_point][1], elevation_points[next_elevation_point][1], fraction) + else: + interpolated_elevation = elevation_points[prev_elevation_point][1] + + cam_embeds = camera_embeddings(interpolated_elevation, interpolated_azimuth) + cond = torch.cat([pooled, cam_embeds.repeat((pooled.shape[0], 1, 1))], dim=-1) + + positive_pooled_out.append(t) + positive_cond_out.append(cond) + negative_pooled_out.append(torch.zeros_like(t)) + negative_cond_out.append(torch.zeros_like(pooled)) + + # Concatenate the conditions and pooled outputs + final_positive_cond = torch.cat(positive_cond_out, dim=0) + final_positive_pooled = torch.cat(positive_pooled_out, dim=0) + final_negative_cond = torch.cat(negative_cond_out, dim=0) + final_negative_pooled = torch.cat(negative_pooled_out, dim=0) + + # Structure the final output + final_positive = [[final_positive_cond, {"concat_latent_image": final_positive_pooled}]] + final_negative = [[final_negative_cond, {"concat_latent_image": final_negative_pooled}]] + + latent = torch.zeros([batch_size, 4, height // 8, width // 8]) + return (final_positive, final_negative, {"samples": latent}) + NODE_CLASS_MAPPINGS = { "INTConstant": INTConstant, "FloatConstant": FloatConstant, @@ -2783,7 +2941,8 @@ NODE_CLASS_MAPPINGS = { "AddLabel": AddLabel, "ReferenceOnlySimple3": ReferenceOnlySimple3, "SoundReactive": SoundReactive, - "GenerateNoise": GenerateNoise + "GenerateNoise": GenerateNoise, + "StableZero123_BatchSchedule": StableZero123_BatchSchedule, } NODE_DISPLAY_NAME_MAPPINGS = { "INTConstant": "INT Constant", @@ -2835,5 +2994,6 @@ NODE_DISPLAY_NAME_MAPPINGS = { "AddLabel": "AddLabel", "ReferenceOnlySimple3": "ReferenceOnlySimple3", "SoundReactive": "SoundReactive", - "GenerateNoise": "GenerateNoise" + "GenerateNoise": "GenerateNoise", + "StableZero123_BatchSchedule": "StableZero123_BatchSchedule", } \ No newline at end of file