From e227539a2a6dd938d069e8eef827c5e6ffa2afeb Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Sun, 28 Jan 2024 22:19:31 +0200
Subject: [PATCH] Update nodes.py

---
 nodes.py | 247 ++++++++++++++++++++++++++++++++++++++++++++++++++-----
 1 file changed, 226 insertions(+), 21 deletions(-)

diff --git a/nodes.py b/nodes.py
index d2695c9..01e4cec 100644
--- a/nodes.py
+++ b/nodes.py
@@ -705,9 +705,10 @@ class GrowMaskWithBlur:
                 pil_image = pil_image.filter(ImageFilter.GaussianBlur(blur_radius))
                 # Convert back to tensor
                 out[idx] = pil2tensor(pil_image)
-
-        blurred = torch.cat(out, dim=0)
-        return (blurred, 1.0 - blurred)
+            blurred = torch.cat(out, dim=0)
+            return (blurred, 1.0 - blurred)
+        else:
+            return (torch.stack(out, dim=0), 1.0 - torch.stack(out, dim=0),)
 
 
 
@@ -2195,7 +2196,7 @@ class OffsetMask:
 
     def offset(self, mask, x, y, angle, roll=False, incremental=False, duplication_factor=1, padding_mode="empty"):
         # Create duplicates of the mask batch
-        mask = mask.repeat(duplication_factor, 1, 1)
+        mask = mask.repeat(duplication_factor, 1, 1).clone()
 
         batch_size, height, width = mask.shape
 
@@ -2275,9 +2276,7 @@ class WidgetToString:
                 "widget_name": ("STRING", {"multiline": False}),
                 "return_all": ("BOOLEAN", {"default": False}),
             },
-            "optional": {
-                "source": (any, {}),
-            },
+            "hidden": {"extra_pnginfo": "EXTRA_PNGINFO", "prompt": "PROMPT"},
         }
 
 
@@ -2327,17 +2326,18 @@ class CreateShapeMask:
                 {
                 "default": 'circle'
                  }),
-                "frames": ("INT", {"default": 1,"min": 1, "max": 4096, "step": 1}),
-                "location_x": ("INT", {"default": 256,"min": 0, "max": 4096, "step": 1}),
-                "location_y": ("INT", {"default": 256,"min": 0, "max": 4096, "step": 1}),
-                "size": ("INT", {"default": 128,"min": 8, "max": 4096, "step": 1}),
-                "grow": ("INT", {"default": 0, "min": -512, "max": 512, "step": 1}),
-                "frame_width": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
-                "frame_height": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
+                "frames": ("INT", {"default": 1,"min": 1, "max": 4096, "step": 1}),
+                "location_x": ("INT", {"default": 256,"min": 0, "max": 4096, "step": 1}),
+                "location_y": ("INT", {"default": 256,"min": 0, "max": 4096, "step": 1}),
+                "grow": ("INT", {"default": 0, "min": -512, "max": 512, "step": 1}),
+                "frame_width": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
+                "frame_height": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
+                "shape_width": ("INT", {"default": 128,"min": 8, "max": 4096, "step": 1}),
+                "shape_height": ("INT", {"default": 128,"min": 8, "max": 4096, "step": 1}),
         },
     }
 
-    def createshapemask(self, frames, frame_width, frame_height, location_x, location_y, size, grow, shape):
+    def createshapemask(self, frames, frame_width, frame_height, location_x, location_y, shape_width, shape_height, grow, shape):
         # Define the number of images in the batch
         batch_size = frames
         out = []
@@ -2347,12 +2347,13 @@ class CreateShapeMask:
             draw = ImageDraw.Draw(image)
 
             # Calculate the size for this frame and ensure it's not less than 0
-            current_size = max(0, size + i*grow)
+            current_width = max(0, shape_width + i*grow)
+            current_height = max(0, shape_height + i*grow)
 
             if shape == 'circle' or shape == 'square':
                 # Define the bounding box for the shape
-                left_up_point = (location_x - current_size // 2, location_y - current_size // 2)
-                right_down_point = (location_x + current_size // 2, location_y + current_size // 2)
+                left_up_point = (location_x - current_width // 2, location_y - current_height // 2)
+                right_down_point = (location_x + current_width // 2, location_y + current_height // 2)
                 two_points = [left_up_point, right_down_point]
 
                 if shape == 'circle':
@@ -2362,9 +2363,9 @@ class CreateShapeMask:
 
             elif shape == 'triangle':
                 # Define the points for the triangle
-                left_up_point = (location_x - current_size // 2, location_y + current_size // 2) # bottom left
-                right_down_point = (location_x + current_size // 2, location_y + current_size // 2) # bottom right
-                top_point = (location_x, location_y - current_size // 2) # top point
+                left_up_point = (location_x - current_width // 2, location_y + current_height // 2) # bottom left
+                right_down_point = (location_x + current_width // 2, location_y + current_height // 2) # bottom right
+                top_point = (location_x, location_y - current_height // 2) # top point
                 draw.polygon([top_point, left_up_point, right_down_point], fill=color)
 
             image = pil2tensor(image)
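The change above splits CreateShapeMask's single size input into independent shape_width and shape_height, so non-square shapes become possible while grow still expands both axes per frame. A minimal standalone sketch of the resulting bounding-box math, assuming Pillow is installed; the input values are hypothetical stand-ins for the node's widgets:

    from PIL import Image, ImageDraw

    # Hypothetical stand-ins for the node inputs (not taken from the patch)
    frame_width, frame_height = 512, 512
    location_x, location_y = 256, 256
    shape_width, shape_height = 128, 64   # independent axes, new in this commit
    grow, frame_index = 8, 3              # per-frame growth, as in the loop above

    # Same math as createshapemask: grow both axes by i*grow, clamp at 0
    current_width = max(0, shape_width + frame_index * grow)
    current_height = max(0, shape_height + frame_index * grow)

    image = Image.new("RGB", (frame_width, frame_height), "black")
    draw = ImageDraw.Draw(image)

    # Centered bounding box: each axis extends half its size from the center
    left_up_point = (location_x - current_width // 2, location_y - current_height // 2)
    right_down_point = (location_x + current_width // 2, location_y + current_height // 2)
    draw.ellipse([left_up_point, right_down_point], fill="white")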
@@ -3073,6 +3074,204 @@ class ImageBatchRepeatInterleaving:
         repeated_images = torch.repeat_interleave(images, repeats=repeats, dim=0)
         return (repeated_images, )
 
+class NormalizedAmplitudeToMask:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {
+                    "normalized_amp": ("NORMALIZED_AMPLITUDE",),
+                    "width": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
+                    "height": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
+                    "frame_offset": ("INT", {"default": 0,"min": -255, "max": 255, "step": 1}),
+                    "location_x": ("INT", {"default": 256,"min": 0, "max": 4096, "step": 1}),
+                    "location_y": ("INT", {"default": 256,"min": 0, "max": 4096, "step": 1}),
+                    "size": ("INT", {"default": 128,"min": 8, "max": 4096, "step": 1}),
+                    "shape": (
+                        [
+                            'none',
+                            'circle',
+                            'square',
+                            'triangle',
+                        ],
+                        {
+                        "default": 'none'
+                        }),
+                    "color": (
+                        [
+                            'white',
+                            'amplitude',
+                        ],
+                        {
+                        "default": 'amplitude'
+                        }),
+                    },}
+
+    CATEGORY = "AudioScheduler/Amplitude"
+
+    RETURN_TYPES = ("MASK",)
+    FUNCTION = "convert"
+
+    def convert(self, normalized_amp, width, height, frame_offset, shape, location_x, location_y, size, color):
+        # Ensure normalized_amp is an array and within the range [0, 1]
+        normalized_amp = np.clip(normalized_amp, 0.0, 1.0)
+
+        # Offset the amplitude values by rolling the array
+        normalized_amp = np.roll(normalized_amp, frame_offset)
+
+        # Initialize an empty list to hold the image tensors
+        out = []
+        # Iterate over each amplitude value to create an image
+        for amp in normalized_amp:
+            # Scale the amplitude value to cover the full range of grayscale values
+            if color == 'amplitude':
+                grayscale_value = int(amp * 255)
+            elif color == 'white':
+                grayscale_value = 255
+            # Convert the grayscale value to an RGB format
+            gray_color = (grayscale_value, grayscale_value, grayscale_value)
+            finalsize = size * amp
+
+            if shape == 'none':
+                shapeimage = Image.new("RGB", (width, height), gray_color)
+            else:
+                shapeimage = Image.new("RGB", (width, height), "black")
+
+            draw = ImageDraw.Draw(shapeimage)
+            if shape == 'circle' or shape == 'square':
+                # Define the bounding box for the shape
+                left_up_point = (location_x - finalsize, location_y - finalsize)
+                right_down_point = (location_x + finalsize,location_y + finalsize)
+                two_points = [left_up_point, right_down_point]
+
+                if shape == 'circle':
+                    draw.ellipse(two_points, fill=gray_color)
+                elif shape == 'square':
+                    draw.rectangle(two_points, fill=gray_color)
+
+            elif shape == 'triangle':
+                # Define the points for the triangle
+                left_up_point = (location_x - finalsize, location_y + finalsize) # bottom left
+                right_down_point = (location_x + finalsize, location_y + finalsize) # bottom right
+                top_point = (location_x, location_y) # top point
+                draw.polygon([top_point, left_up_point, right_down_point], fill=gray_color)
+
+            shapeimage = pil2tensor(shapeimage)
+            mask = shapeimage[:, :, :, 0]
+            out.append(mask)
+
+        return (torch.cat(out, dim=0),)
+
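NormalizedAmplitudeToMask maps one amplitude value per frame to both the mask's brightness (in 'amplitude' color mode) and the drawn shape's half-extent, with frame_offset rolling the envelope so the masks can lead or lag the audio. A minimal sketch of that per-frame scaling, using a synthetic ramp in place of a real NORMALIZED_AMPLITUDE input:

    import numpy as np

    # Synthetic stand-in for a NORMALIZED_AMPLITUDE input (one value per frame)
    normalized_amp = np.clip(np.linspace(0.0, 1.2, 8), 0.0, 1.0)

    # frame_offset rolls the envelope forward or backward in time
    normalized_amp = np.roll(normalized_amp, 2)

    size = 128
    for amp in normalized_amp:
        grayscale_value = int(amp * 255)  # 'amplitude' mode: brightness tracks loudness
        finalsize = size * amp            # shape half-extent also tracks loudness
        print(f"amp={amp:.2f} -> gray={grayscale_value:3d}, half-extent={finalsize:.1f}px")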
+class OffsetMaskByNormalizedAmplitude:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "normalized_amp": ("NORMALIZED_AMPLITUDE",),
+                "mask": ("MASK",),
+                "x": ("INT", { "default": 0, "min": -4096, "max": MAX_RESOLUTION, "step": 1, "display": "number" }),
+                "y": ("INT", { "default": 0, "min": -4096, "max": MAX_RESOLUTION, "step": 1, "display": "number" }),
+                "rotate": ("BOOLEAN", { "default": False }),
+                "angle_multiplier": ("FLOAT", { "default": 0.0, "min": -1.0, "max": 1.0, "step": 0.001, "display": "number" }),
+            }
+        }
+
+    RETURN_TYPES = ("MASK",)
+    RETURN_NAMES = ("mask",)
+    FUNCTION = "offset"
+    CATEGORY = "KJNodes/masking"
+
+    def offset(self, mask, x, y, angle_multiplier, rotate, normalized_amp):
+
+        # Ensure normalized_amp is an array and within the range [0, 1]
+        offsetmask = mask.clone()
+        normalized_amp = np.clip(normalized_amp, 0.0, 1.0)
+
+        batch_size, height, width = mask.shape
+
+        if rotate:
+            for i in range(batch_size):
+                rotation_amp = int(normalized_amp[i] * (360 * angle_multiplier))
+                rotation_angle = rotation_amp
+                offsetmask[i] = TF.rotate(offsetmask[i].unsqueeze(0), rotation_angle).squeeze(0)
+        if x != 0 or y != 0:
+            for i in range(batch_size):
+                offset_amp = normalized_amp[i] * 10
+                shift_x = min(x*offset_amp, width-1)
+                shift_y = min(y*offset_amp, height-1)
+                if shift_x != 0:
+                    offsetmask[i] = torch.roll(offsetmask[i], shifts=int(shift_x), dims=1)
+                if shift_y != 0:
+                    offsetmask[i] = torch.roll(offsetmask[i], shifts=int(shift_y), dims=0)
+
+        return offsetmask,
+
+
+class ImageTransformByNormalizedAmplitude:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {
+            "normalized_amp": ("NORMALIZED_AMPLITUDE",),
+            "zoom_scale": ("FLOAT", { "default": 0.0, "min": -1.0, "max": 1.0, "step": 0.001, "display": "number" }),
+            "cumulative": ("BOOLEAN", { "default": False }),
+            "image": ("IMAGE",),
+        }}
+
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "amptransform"
+    CATEGORY = "KJNodes"
+
+    def amptransform(self, image, normalized_amp, zoom_scale, cumulative):
+        # Ensure normalized_amp is an array and within the range [0, 1]
+        normalized_amp = np.clip(normalized_amp, 0.0, 1.0)
+        transformed_images = []
+
+        # Initialize the cumulative zoom factor
+        prev_amp = 0.0
+
+        for i in range(image.shape[0]):
+            img = image[i]  # Get the i-th image in the batch
+            amp = normalized_amp[i]  # Get the corresponding amplitude value
+
+            # Incrementally increase the cumulative zoom factor
+            if cumulative:
+                prev_amp += amp
+                amp += prev_amp
+
+            # Convert the image tensor from BxHxWxC to CxHxW format expected by torchvision
+            img = img.permute(2, 0, 1)
+
+            # Convert PyTorch tensor to PIL Image for processing
+            pil_img = TF.to_pil_image(img)
+
+            # Calculate the crop size based on the amplitude
+            width, height = pil_img.size
+            crop_size = int(min(width, height) * (1 - amp * zoom_scale))
+            crop_size = max(crop_size, 1)
+
+            # Calculate the crop box coordinates (centered crop)
+            left = (width - crop_size) // 2
+            top = (height - crop_size) // 2
+            right = (width + crop_size) // 2
+            bottom = (height + crop_size) // 2
+
+            # Crop and resize back to original size
+            cropped_img = TF.crop(pil_img, top, left, crop_size, crop_size)
+            resized_img = TF.resize(cropped_img, (height, width))
+
+            # Convert back to tensor in CxHxW format
+            tensor_img = TF.to_tensor(resized_img)
+
+            # Convert the tensor back to BxHxWxC format
+            tensor_img = tensor_img.permute(1, 2, 0)
+
+            # Add to the list
+            transformed_images.append(tensor_img)
+
+        # Stack all transformed images into a batch
+        transformed_batch = torch.stack(transformed_images)
+
+        return (transformed_batch,)
+
+
 NODE_CLASS_MAPPINGS = {
     "INTConstant": INTConstant,
     "FloatConstant": FloatConstant,
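The zoom in amptransform is a centered crop followed by a resize back to the original resolution, so larger amplitudes shrink the crop window and magnify the frame. A minimal sketch of that crop-box arithmetic, assuming torchvision is installed and using a random tensor in place of a batch image; the amplitude and scale are hypothetical:

    import torch
    import torchvision.transforms.functional as TF

    zoom_scale, amp = 0.5, 0.8     # hypothetical values; amp comes from the envelope
    img = torch.rand(3, 480, 640)  # CxHxW, as torchvision expects

    height, width = img.shape[1], img.shape[2]
    # Higher amplitude -> smaller crop window -> stronger zoom; clamped to 1px minimum
    crop_size = max(int(min(width, height) * (1 - amp * zoom_scale)), 1)

    top = (height - crop_size) // 2
    left = (width - crop_size) // 2

    # Crop the center square, then scale back up to the original resolution
    zoomed = TF.resize(TF.crop(img, top, left, crop_size, crop_size), (height, width))
    print(zoomed.shape)  # torch.Size([3, 480, 640])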
@@ -3130,6 +3329,9 @@ NODE_CLASS_MAPPINGS = {
     "StableZero123_BatchSchedule": StableZero123_BatchSchedule,
     "GetImagesFromBatchIndexed": GetImagesFromBatchIndexed,
     "ImageBatchRepeatInterleaving": ImageBatchRepeatInterleaving,
+    "NormalizedAmplitudeToMask": NormalizedAmplitudeToMask,
+    "OffsetMaskByNormalizedAmplitude": OffsetMaskByNormalizedAmplitude,
+    "ImageTransformByNormalizedAmplitude": ImageTransformByNormalizedAmplitude
 }
 NODE_DISPLAY_NAME_MAPPINGS = {
     "INTConstant": "INT Constant",
@@ -3187,4 +3389,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
     "StableZero123_BatchSchedule": "StableZero123_BatchSchedule",
     "GetImagesFromBatchIndexed": "GetImagesFromBatchIndexed",
     "ImageBatchRepeatInterleaving": "ImageBatchRepeatInterleaving",
+    "NormalizedAmplitudeToMask": "NormalizedAmplitudeToMask",
+    "OffsetMaskByNormalizedAmplitude": "OffsetMaskByNormalizedAmplitude",
+    "ImageTransformByNormalizedAmplitude": "ImageTransformByNormalizedAmplitude"
 }
\ No newline at end of file
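For context, ComfyUI exposes nodes by reading NODE_CLASS_MAPPINGS (and the optional display names) from the module and dispatching to each class's FUNCTION attribute, so registering the three new classes above is the only wiring this patch needs. A simplified sketch of that pattern, using a hypothetical node class rather than the real ones:

    # Hypothetical stand-in for a registered node class; the real mappings point
    # at the classes defined earlier in nodes.py
    class ExampleNode:
        FUNCTION = "run"
        RETURN_TYPES = ("MASK",)

        def run(self, value):
            return (value,)

    NODE_CLASS_MAPPINGS = {"ExampleNode": ExampleNode}
    NODE_DISPLAY_NAME_MAPPINGS = {"ExampleNode": "Example Node"}

    # ComfyUI-style dispatch: look up the class, then call its FUNCTION attribute
    node = NODE_CLASS_MAPPINGS["ExampleNode"]()
    result = getattr(node, node.FUNCTION)(42)
    print(result)  # (42,)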