diff --git a/nodes/image_nodes.py b/nodes/image_nodes.py index c4d1590..c409ac5 100644 --- a/nodes/image_nodes.py +++ b/nodes/image_nodes.py @@ -1322,19 +1322,22 @@ class TransitionImagesMulti: "interpolation": (["linear", "ease_in", "ease_out", "ease_in_out", "bounce", "elastic", "glitchy", "exponential_ease_out"],), "transition_type": (["horizontal slide", "vertical slide", "box", "circle", "horizontal bar", "vertical bar", "horizontal door", "vertical door", "fade"],), "transitioning_frames": ("INT", {"default": 1,"min": 0, "max": 4096, "step": 1}), + "blur_radius": ("FLOAT", {"default": 0.0,"min": 0.0, "max": 100.0, "step": 0.1}), "device": (["CPU", "GPU"], {"default": "CPU"}), }, } #transitions from matteo's essential nodes - def transition(self, inputcount, transitioning_frames, transition_type, interpolation, device, **kwargs): + def transition(self, inputcount, transitioning_frames, transition_type, interpolation, device, blur_radius, **kwargs): - device = model_management.get_torch_device() + gpu = model_management.get_torch_device() - def wipe(images_1, images_2, alpha, transition_type): + def wipe(images_1, images_2, alpha, transition_type, blur_radius): width = images_1.shape[1] height = images_1.shape[0] - mask = torch.zeros_like(images_1) + + mask = torch.zeros_like(images_1, device=images_1.device) + alpha = alpha.item() if "horizontal slide" in transition_type: @@ -1383,6 +1386,8 @@ class TransitionImagesMulti: elif "fade" in transition_type: mask[:, :, :] = alpha + mask = gaussian_blur(mask, blur_radius) + return images_1 * (1 - mask) + images_2 * mask def ease_in(t): @@ -1402,6 +1407,24 @@ class TransitionImagesMulti: return t + 0.1 * math.sin(40 * t) def exponential_ease_out(t): return 1 - (1 - t) ** 4 + + def gaussian_blur(mask, blur_radius): + print(mask.device) + if blur_radius > 0: + kernel_size = int(blur_radius * 2) + 1 + if kernel_size % 2 == 0: + kernel_size += 1 # Ensure kernel size is odd + sigma = blur_radius / 3 + x = torch.arange(-kernel_size // 2 + 1, kernel_size // 2 + 1, dtype=torch.float32) + x = torch.exp(-0.5 * (x / sigma) ** 2) + kernel1d = x / x.sum() + kernel2d = kernel1d[:, None] * kernel1d[None, :] + kernel2d = kernel2d.to(mask.device) + kernel2d = kernel2d.expand(mask.shape[2], 1, kernel2d.shape[0], kernel2d.shape[1]) + mask = mask.permute(2, 0, 1).unsqueeze(0) # Change to [C, H, W] and add batch dimension + mask = F.conv2d(mask, kernel2d, padding=kernel_size // 2, groups=mask.shape[1]) + mask = mask.squeeze(0).permute(1, 2, 0) # Change back to [H, W, C] + return mask easing_functions = { "linear": lambda t: t, @@ -1433,17 +1456,17 @@ class TransitionImagesMulti: last_frame_image_1 = image_1[-1] first_frame_image_2 = new_image[0] if device == "GPU": - last_frame_image_1 = last_frame_image_1.to(device) - first_frame_image_2 = first_frame_image_2.to(device) + last_frame_image_1 = last_frame_image_1.to(gpu) + first_frame_image_2 = first_frame_image_2.to(gpu) for frame in range(transitioning_frames): t = frame / (transitioning_frames - 1) alpha = easing_function(t) alpha_tensor = torch.tensor(alpha, dtype=last_frame_image_1.dtype, device=last_frame_image_1.device) - frame_image = wipe(last_frame_image_1, first_frame_image_2, alpha_tensor, transition_type) + frame_image = wipe(last_frame_image_1, first_frame_image_2, alpha_tensor, transition_type, blur_radius) frames.append(frame_image) - frames = torch.stack(frames) + frames = torch.stack(frames).cpu() image_1 = torch.cat((image_1, frames, new_image), dim=0) return image_1.cpu(),