Add batched image model upscale node

kijai 2024-02-05 21:25:26 +02:00
parent 98d6af1ada
commit 0105e9d080


@@ -2140,7 +2140,7 @@ class BatchCLIPSeg:
         model.to(device)  # Ensure the model is on the correct device
         images = images.to(device)
         processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
+        pbar = comfy.utils.ProgressBar(images.shape[0])
         for image in images:
             image = (image* 255).type(torch.uint8)
             prompt = text
@@ -2165,7 +2165,7 @@ class BatchCLIPSeg:
             # Remove the extra dimensions
             resized_tensor = resized_tensor[0, 0, :, :]
+            pbar.update(1)
             out.append(resized_tensor)
         results = torch.stack(out).cpu()
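Note on the two BatchCLIPSeg hunks above: they follow ComfyUI's usual progress-reporting pattern, where comfy.utils.ProgressBar is constructed with the total step count and update(1) is called once per processed item so the UI can track long batches. A minimal sketch of that pattern in isolation (process_one is a hypothetical stand-in for the per-image work):

import comfy.utils

def run_batch(images):
    pbar = comfy.utils.ProgressBar(images.shape[0])  # total = number of images
    out = []
    for image in images:
        out.append(process_one(image))  # hypothetical per-image work
        pbar.update(1)  # advance the UI progress bar by one step
    return out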
@@ -3266,13 +3266,14 @@ class OffsetMaskByNormalizedAmplitude:
         return offsetmask,

 class ImageTransformByNormalizedAmplitude:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": {
             "normalized_amp": ("NORMALIZED_AMPLITUDE",),
             "zoom_scale": ("FLOAT", { "default": 0.0, "min": -1.0, "max": 1.0, "step": 0.001, "display": "number" }),
+            "x_offset": ("INT", { "default": 0, "min": (1 - MAX_RESOLUTION), "max": MAX_RESOLUTION, "step": 1, "display": "number" }),
+            "y_offset": ("INT", { "default": 0, "min": (1 - MAX_RESOLUTION), "max": MAX_RESOLUTION, "step": 1, "display": "number" }),
             "cumulative": ("BOOLEAN", { "default": False }),
             "image": ("IMAGE",),
         }}
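The new x_offset/y_offset inputs use ComfyUI's standard widget spec: each entry in the "required" dict maps an input name to a (TYPE, options) tuple, with keys such as default, min, max, step, and display configuring the number widget shown on the node. A minimal sketch of the same spec in isolation (class and input names are illustrative):

class ExampleNode:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            # Integer widget clamped to a symmetric pixel range
            "offset": ("INT", {"default": 0, "min": -4096, "max": 4096, "step": 1, "display": "number"}),
        }}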
@@ -3281,7 +3282,7 @@ class ImageTransformByNormalizedAmplitude:
     FUNCTION = "amptransform"
     CATEGORY = "KJNodes"

-    def amptransform(self, image, normalized_amp, zoom_scale, cumulative):
+    def amptransform(self, image, normalized_amp, zoom_scale, cumulative, x_offset, y_offset):
         # Ensure normalized_amp is an array and within the range [0, 1]
         normalized_amp = np.clip(normalized_amp, 0.0, 1.0)
         transformed_images = []
@@ -3325,6 +3326,17 @@ class ImageTransformByNormalizedAmplitude:
             # Convert the tensor back to BxHxWxC format
             tensor_img = tensor_img.permute(1, 2, 0)

+            # Offset the image based on the amplitude
+            offset_amp = amp * 10  # Calculate the offset magnitude based on the amplitude
+            shift_x = min(x_offset * offset_amp, img.shape[1] - 1)  # Calculate the shift in x direction
+            shift_y = min(y_offset * offset_amp, img.shape[0] - 1)  # Calculate the shift in y direction
+
+            # Apply the offset to the image tensor
+            if shift_x != 0:
+                tensor_img = torch.roll(tensor_img, shifts=int(shift_x), dims=1)
+            if shift_y != 0:
+                tensor_img = torch.roll(tensor_img, shifts=int(shift_y), dims=0)
+
             # Add to the list
             transformed_images.append(tensor_img)
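For context on the offset hunk above: torch.roll shifts elements along the given dimension and wraps the overflow around to the opposite edge, so the amplitude-driven offset is cyclic rather than cropping or padding. A standalone illustration of that wrap-around behavior:

import torch

img = torch.arange(12).reshape(3, 4)  # stand-in for one HxW image plane
shifted = torch.roll(img, shifts=1, dims=1)  # shift one column to the right
print(shifted)
# Column 3 wraps around to column 0:
# tensor([[ 3,  0,  1,  2],
#         [ 7,  4,  5,  6],
#         [11,  8,  9, 10]])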
@@ -3460,7 +3472,7 @@ class GLIGENTextBoxApplyBatch:
             interpolated_coords = interpolate_coordinates_with_curves(coordinates_dict, batch_size)
         if interpolation == 'straight':
             interpolated_coords = interpolate_coordinates(coordinates_dict, batch_size)

         plot_image_tensor = plot_to_tensor(coordinates_dict, interpolated_coords, 512, 512, height)
         for t in conditioning_to:
             n = [t[0], t[1].copy()]
@@ -3471,6 +3483,7 @@ class GLIGENTextBoxApplyBatch:
                 x_position, y_position = interpolated_coords[i]
                 position_param = (cond_pooled, height // 8, width // 8, y_position // 8, x_position // 8)
                 position_params_batch[i].append(position_param)  # Append position_param to the correct sublist
+                print("x ", x_position, "y ", y_position)
             prev = []
             if "gligen" in n[1]:
                 prev = n[1]['gligen'][2]
@@ -3484,6 +3497,41 @@ class GLIGENTextBoxApplyBatch:
         return (c, plot_image_tensor,)

+class ImageUpscaleWithModelBatched:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "upscale_model": ("UPSCALE_MODEL",),
+                              "images": ("IMAGE",),
+                              "per_batch": ("INT", {"default": 16, "min": 1, "max": 4096, "step": 1}),
+                              }}
+
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "upscale"
+    CATEGORY = "KJNodes"
+
+    def upscale(self, upscale_model, images, per_batch):
+        device = comfy.model_management.get_torch_device()
+        upscale_model.to(device)
+        in_img = images.movedim(-1,-3).to(device)
+
+        steps = in_img.shape[0]
+        pbar = comfy.utils.ProgressBar(steps)
+        t = []
+        for start_idx in range(0, in_img.shape[0], per_batch):
+            sub_images = upscale_model(in_img[start_idx:start_idx+per_batch])
+            t.append(sub_images)
+            # Calculate the number of images processed in this batch
+            batch_count = sub_images.shape[0]
+            # Update the progress bar by the number of images processed in this batch
+            pbar.update(batch_count)
+        upscale_model.cpu()
+        t = torch.cat(t, dim=0).permute(0, 2, 3, 1)
+        return (t,)
+
 NODE_CLASS_MAPPINGS = {
     "INTConstant": INTConstant,
@@ -3548,7 +3596,8 @@ NODE_CLASS_MAPPINGS = {
     "GetLatentsFromBatchIndexed": GetLatentsFromBatchIndexed,
     "StringConstant": StringConstant,
     "GLIGENTextBoxApplyBatch": GLIGENTextBoxApplyBatch,
-    "CondPassThrough": CondPassThrough
+    "CondPassThrough": CondPassThrough,
+    "ImageUpscaleWithModelBatched": ImageUpscaleWithModelBatched
 }
 NODE_DISPLAY_NAME_MAPPINGS = {
     "INTConstant": "INT Constant",
@@ -3612,5 +3661,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
     "GetLatentsFromBatchIndexed": "GetLatentsFromBatchIndexed",
     "StringConstant": "StringConstant",
     "GLIGENTextBoxApplyBatch": "GLIGENTextBoxApplyBatch",
-    "CondPassThrough": "CondPassThrough"
+    "CondPassThrough": "CondPassThrough",
+    "ImageUpscaleWithModelBatched": "ImageUpscaleWithModelBatched"
 }
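For reference, ComfyUI discovers custom nodes through exactly these two exported dicts: NODE_CLASS_MAPPINGS binds each internal node name to its class, and NODE_DISPLAY_NAME_MAPPINGS supplies the label shown in the node menu. A minimal sketch of a package exposing them (import path is illustrative):

# __init__.py of a custom node package
from .nodes import ImageUpscaleWithModelBatched  # illustrative module path

NODE_CLASS_MAPPINGS = {
    "ImageUpscaleWithModelBatched": ImageUpscaleWithModelBatched,
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "ImageUpscaleWithModelBatched": "ImageUpscaleWithModelBatched",
}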