Add batched image model upscale node

This commit is contained in:
kijai 2024-02-05 21:25:26 +02:00
parent 98d6af1ada
commit 0105e9d080

View File

@ -2140,7 +2140,7 @@ class BatchCLIPSeg:
model.to(device) # Ensure the model is on the correct device
images = images.to(device)
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
pbar = comfy.utils.ProgressBar(images.shape[0])
for image in images:
image = (image* 255).type(torch.uint8)
prompt = text
@ -2165,7 +2165,7 @@ class BatchCLIPSeg:
# Remove the extra dimensions
resized_tensor = resized_tensor[0, 0, :, :]
pbar.update(1)
out.append(resized_tensor)
results = torch.stack(out).cpu()
@ -3266,13 +3266,14 @@ class OffsetMaskByNormalizedAmplitude:
return offsetmask,
class ImageTransformByNormalizedAmplitude:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"normalized_amp": ("NORMALIZED_AMPLITUDE",),
"zoom_scale": ("FLOAT", { "default": 0.0, "min": -1.0, "max": 1.0, "step": 0.001, "display": "number" }),
"x_offset": ("INT", { "default": 0, "min": (1 -MAX_RESOLUTION), "max": MAX_RESOLUTION, "step": 1, "display": "number" }),
"y_offset": ("INT", { "default": 0, "min": (1 -MAX_RESOLUTION), "max": MAX_RESOLUTION, "step": 1, "display": "number" }),
"cumulative": ("BOOLEAN", { "default": False }),
"image": ("IMAGE",),
}}
@ -3281,7 +3282,7 @@ class ImageTransformByNormalizedAmplitude:
FUNCTION = "amptransform"
CATEGORY = "KJNodes"
def amptransform(self, image, normalized_amp, zoom_scale, cumulative):
def amptransform(self, image, normalized_amp, zoom_scale, cumulative, x_offset, y_offset):
# Ensure normalized_amp is an array and within the range [0, 1]
normalized_amp = np.clip(normalized_amp, 0.0, 1.0)
transformed_images = []
@ -3325,6 +3326,17 @@ class ImageTransformByNormalizedAmplitude:
# Convert the tensor back to BxHxWxC format
tensor_img = tensor_img.permute(1, 2, 0)
# Offset the image based on the amplitude
offset_amp = amp * 10 # Calculate the offset magnitude based on the amplitude
shift_x = min(x_offset * offset_amp, img.shape[1] - 1) # Calculate the shift in x direction
shift_y = min(y_offset * offset_amp, img.shape[0] - 1) # Calculate the shift in y direction
# Apply the offset to the image tensor
if shift_x != 0:
tensor_img = torch.roll(tensor_img, shifts=int(shift_x), dims=1)
if shift_y != 0:
tensor_img = torch.roll(tensor_img, shifts=int(shift_y), dims=0)
# Add to the list
transformed_images.append(tensor_img)
@ -3460,7 +3472,7 @@ class GLIGENTextBoxApplyBatch:
interpolated_coords = interpolate_coordinates_with_curves(coordinates_dict, batch_size)
if interpolation == 'straight':
interpolated_coords = interpolate_coordinates(coordinates_dict, batch_size)
plot_image_tensor = plot_to_tensor(coordinates_dict, interpolated_coords, 512, 512, height)
for t in conditioning_to:
n = [t[0], t[1].copy()]
@ -3471,6 +3483,7 @@ class GLIGENTextBoxApplyBatch:
x_position, y_position = interpolated_coords[i]
position_param = (cond_pooled, height // 8, width // 8, y_position // 8, x_position // 8)
position_params_batch[i].append(position_param) # Append position_param to the correct sublist
print("x ",x_position, "y ", y_position)
prev = []
if "gligen" in n[1]:
prev = n[1]['gligen'][2]
@ -3484,6 +3497,41 @@ class GLIGENTextBoxApplyBatch:
return (c, plot_image_tensor,)
class ImageUpscaleWithModelBatched:
@classmethod
def INPUT_TYPES(s):
return {"required": { "upscale_model": ("UPSCALE_MODEL",),
"images": ("IMAGE",),
"per_batch": ("INT", {"default": 16, "min": 1, "max": 4096, "step": 1}),
}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "upscale"
CATEGORY = "KJNodes"
def upscale(self, upscale_model, images, per_batch):
device = comfy.model_management.get_torch_device()
upscale_model.to(device)
in_img = images.movedim(-1,-3).to(device)
steps = in_img.shape[0]
pbar = comfy.utils.ProgressBar(steps)
t = []
for start_idx in range(0, in_img.shape[0], per_batch):
sub_images = upscale_model(in_img[start_idx:start_idx+per_batch])
t.append(sub_images)
# Calculate the number of images processed in this batch
batch_count = sub_images.shape[0]
# Update the progress bar by the number of images processed in this batch
pbar.update(batch_count)
upscale_model.cpu()
t = torch.cat(t, dim=0).permute(0, 2, 3, 1)
return (t,)
NODE_CLASS_MAPPINGS = {
"INTConstant": INTConstant,
@ -3548,7 +3596,8 @@ NODE_CLASS_MAPPINGS = {
"GetLatentsFromBatchIndexed": GetLatentsFromBatchIndexed,
"StringConstant": StringConstant,
"GLIGENTextBoxApplyBatch": GLIGENTextBoxApplyBatch,
"CondPassThrough": CondPassThrough
"CondPassThrough": CondPassThrough,
"ImageUpscaleWithModelBatched": ImageUpscaleWithModelBatched
}
NODE_DISPLAY_NAME_MAPPINGS = {
"INTConstant": "INT Constant",
@ -3612,5 +3661,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"GetLatentsFromBatchIndexed": "GetLatentsFromBatchIndexed",
"StringConstant": "StringConstant",
"GLIGENTextBoxApplyBatch": "GLIGENTextBoxApplyBatch",
"CondPassThrough": "CondPassThrough"
"CondPassThrough": "CondPassThrough",
"ImageUpscaleWithModelBatched": "ImageUpscaleWithModelBatched"
}