Update image_nodes.py

This commit is contained in:
kijai 2025-02-16 19:12:42 +02:00
parent 56979210c7
commit 095c8d4b52

View File

@ -1098,13 +1098,14 @@ class ImagePrepForICLora:
def INPUT_TYPES(s): def INPUT_TYPES(s):
return { return {
"required": { "required": {
"image": ("IMAGE",), "reference_image": ("IMAGE",),
"output_width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}), "output_width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
"output_height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}), "output_height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
"border_width": ("INT", {"default": 0, "min": 0, "max": 4096, "step": 1}), "border_width": ("INT", {"default": 0, "min": 0, "max": 4096, "step": 1}),
}, },
"optional": { "optional": {
"mask": ("MASK",), "latent_image": ("IMAGE",),
"reference_mask": ("MASK",),
} }
} }
@ -1113,17 +1114,19 @@ class ImagePrepForICLora:
CATEGORY = "image" CATEGORY = "image"
def expand_image(self, image, output_width, output_height, border_width, mask=None): def expand_image(self, reference_image, output_width, output_height, border_width, latent_image=None, reference_mask=None):
if mask is not None:
if torch.allclose(mask, torch.zeros_like(mask)): if reference_mask is not None:
if torch.allclose(reference_mask, torch.zeros_like(reference_mask)):
print("Warning: The incoming mask is fully black. Handling it as None.") print("Warning: The incoming mask is fully black. Handling it as None.")
mask = None reference_mask = None
image = reference_image
B, H, W, C = image.size() B, H, W, C = image.size()
# Handle mask # Handle mask
if mask is not None: if reference_mask is not None:
resized_mask = torch.nn.functional.interpolate( resized_mask = torch.nn.functional.interpolate(
mask.unsqueeze(1), reference_mask.unsqueeze(1),
size=(image.shape[1], image.shape[2]), size=(image.shape[1], image.shape[2]),
mode='nearest' mode='nearest'
).squeeze(1) ).squeeze(1)
@ -1137,15 +1140,19 @@ class ImagePrepForICLora:
resized_image = common_upscale(image.movedim(-1,1), new_width, output_height, "lanczos", "disabled").movedim(1,-1) resized_image = common_upscale(image.movedim(-1,1), new_width, output_height, "lanczos", "disabled").movedim(1,-1)
# Create padded image # Create padded image
empty_image = torch.zeros((B, output_height, output_width, C), device=image.device) if latent_image is None:
pad_image = torch.zeros((B, output_height, output_width, C), device=image.device)
else:
resized_latent_image = common_upscale(latent_image.movedim(-1,1), output_width, output_height, "lanczos", "disabled").movedim(1,-1)
pad_image = resized_latent_image
if border_width > 0: if border_width > 0:
border = torch.zeros((B, output_height, border_width, C), device=image.device) border = torch.zeros((B, output_height, border_width, C), device=image.device)
padded_image = torch.cat((resized_image, border, empty_image), dim=2) padded_image = torch.cat((resized_image, border, pad_image), dim=2)
padded_mask = torch.ones((B, padded_image.shape[1], padded_image.shape[2]), device=image.device) padded_mask = torch.ones((B, padded_image.shape[1], padded_image.shape[2]), device=image.device)
padded_mask[:, :, :new_width + border_width] = 0 padded_mask[:, :, :new_width + border_width] = 0
else: else:
padded_image = torch.cat((resized_image, empty_image), dim=2) padded_image = torch.cat((resized_image, pad_image), dim=2)
padded_mask = torch.ones((B, padded_image.shape[1], padded_image.shape[2]), device=image.device) padded_mask = torch.ones((B, padded_image.shape[1], padded_image.shape[2]), device=image.device)
padded_mask[:, :, :new_width] = 0 padded_mask[:, :, :new_width] = 0