mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2025-12-09 21:04:41 +08:00
Add batched image model upscale node
This commit is contained in:
parent
98d6af1ada
commit
0105e9d080
64
nodes.py
64
nodes.py
@ -2140,7 +2140,7 @@ class BatchCLIPSeg:
|
|||||||
model.to(device) # Ensure the model is on the correct device
|
model.to(device) # Ensure the model is on the correct device
|
||||||
images = images.to(device)
|
images = images.to(device)
|
||||||
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
|
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
|
||||||
|
pbar = comfy.utils.ProgressBar(images.shape[0])
|
||||||
for image in images:
|
for image in images:
|
||||||
image = (image* 255).type(torch.uint8)
|
image = (image* 255).type(torch.uint8)
|
||||||
prompt = text
|
prompt = text
|
||||||
@ -2165,7 +2165,7 @@ class BatchCLIPSeg:
|
|||||||
|
|
||||||
# Remove the extra dimensions
|
# Remove the extra dimensions
|
||||||
resized_tensor = resized_tensor[0, 0, :, :]
|
resized_tensor = resized_tensor[0, 0, :, :]
|
||||||
|
pbar.update(1)
|
||||||
out.append(resized_tensor)
|
out.append(resized_tensor)
|
||||||
|
|
||||||
results = torch.stack(out).cpu()
|
results = torch.stack(out).cpu()
|
||||||
@ -3266,13 +3266,14 @@ class OffsetMaskByNormalizedAmplitude:
|
|||||||
|
|
||||||
return offsetmask,
|
return offsetmask,
|
||||||
|
|
||||||
|
|
||||||
class ImageTransformByNormalizedAmplitude:
|
class ImageTransformByNormalizedAmplitude:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": {
|
return {"required": {
|
||||||
"normalized_amp": ("NORMALIZED_AMPLITUDE",),
|
"normalized_amp": ("NORMALIZED_AMPLITUDE",),
|
||||||
"zoom_scale": ("FLOAT", { "default": 0.0, "min": -1.0, "max": 1.0, "step": 0.001, "display": "number" }),
|
"zoom_scale": ("FLOAT", { "default": 0.0, "min": -1.0, "max": 1.0, "step": 0.001, "display": "number" }),
|
||||||
|
"x_offset": ("INT", { "default": 0, "min": (1 -MAX_RESOLUTION), "max": MAX_RESOLUTION, "step": 1, "display": "number" }),
|
||||||
|
"y_offset": ("INT", { "default": 0, "min": (1 -MAX_RESOLUTION), "max": MAX_RESOLUTION, "step": 1, "display": "number" }),
|
||||||
"cumulative": ("BOOLEAN", { "default": False }),
|
"cumulative": ("BOOLEAN", { "default": False }),
|
||||||
"image": ("IMAGE",),
|
"image": ("IMAGE",),
|
||||||
}}
|
}}
|
||||||
@ -3281,7 +3282,7 @@ class ImageTransformByNormalizedAmplitude:
|
|||||||
FUNCTION = "amptransform"
|
FUNCTION = "amptransform"
|
||||||
CATEGORY = "KJNodes"
|
CATEGORY = "KJNodes"
|
||||||
|
|
||||||
def amptransform(self, image, normalized_amp, zoom_scale, cumulative):
|
def amptransform(self, image, normalized_amp, zoom_scale, cumulative, x_offset, y_offset):
|
||||||
# Ensure normalized_amp is an array and within the range [0, 1]
|
# Ensure normalized_amp is an array and within the range [0, 1]
|
||||||
normalized_amp = np.clip(normalized_amp, 0.0, 1.0)
|
normalized_amp = np.clip(normalized_amp, 0.0, 1.0)
|
||||||
transformed_images = []
|
transformed_images = []
|
||||||
@ -3325,6 +3326,17 @@ class ImageTransformByNormalizedAmplitude:
|
|||||||
# Convert the tensor back to BxHxWxC format
|
# Convert the tensor back to BxHxWxC format
|
||||||
tensor_img = tensor_img.permute(1, 2, 0)
|
tensor_img = tensor_img.permute(1, 2, 0)
|
||||||
|
|
||||||
|
# Offset the image based on the amplitude
|
||||||
|
offset_amp = amp * 10 # Calculate the offset magnitude based on the amplitude
|
||||||
|
shift_x = min(x_offset * offset_amp, img.shape[1] - 1) # Calculate the shift in x direction
|
||||||
|
shift_y = min(y_offset * offset_amp, img.shape[0] - 1) # Calculate the shift in y direction
|
||||||
|
|
||||||
|
# Apply the offset to the image tensor
|
||||||
|
if shift_x != 0:
|
||||||
|
tensor_img = torch.roll(tensor_img, shifts=int(shift_x), dims=1)
|
||||||
|
if shift_y != 0:
|
||||||
|
tensor_img = torch.roll(tensor_img, shifts=int(shift_y), dims=0)
|
||||||
|
|
||||||
# Add to the list
|
# Add to the list
|
||||||
transformed_images.append(tensor_img)
|
transformed_images.append(tensor_img)
|
||||||
|
|
||||||
@ -3460,7 +3472,7 @@ class GLIGENTextBoxApplyBatch:
|
|||||||
interpolated_coords = interpolate_coordinates_with_curves(coordinates_dict, batch_size)
|
interpolated_coords = interpolate_coordinates_with_curves(coordinates_dict, batch_size)
|
||||||
if interpolation == 'straight':
|
if interpolation == 'straight':
|
||||||
interpolated_coords = interpolate_coordinates(coordinates_dict, batch_size)
|
interpolated_coords = interpolate_coordinates(coordinates_dict, batch_size)
|
||||||
|
|
||||||
plot_image_tensor = plot_to_tensor(coordinates_dict, interpolated_coords, 512, 512, height)
|
plot_image_tensor = plot_to_tensor(coordinates_dict, interpolated_coords, 512, 512, height)
|
||||||
for t in conditioning_to:
|
for t in conditioning_to:
|
||||||
n = [t[0], t[1].copy()]
|
n = [t[0], t[1].copy()]
|
||||||
@ -3471,6 +3483,7 @@ class GLIGENTextBoxApplyBatch:
|
|||||||
x_position, y_position = interpolated_coords[i]
|
x_position, y_position = interpolated_coords[i]
|
||||||
position_param = (cond_pooled, height // 8, width // 8, y_position // 8, x_position // 8)
|
position_param = (cond_pooled, height // 8, width // 8, y_position // 8, x_position // 8)
|
||||||
position_params_batch[i].append(position_param) # Append position_param to the correct sublist
|
position_params_batch[i].append(position_param) # Append position_param to the correct sublist
|
||||||
|
print("x ",x_position, "y ", y_position)
|
||||||
prev = []
|
prev = []
|
||||||
if "gligen" in n[1]:
|
if "gligen" in n[1]:
|
||||||
prev = n[1]['gligen'][2]
|
prev = n[1]['gligen'][2]
|
||||||
@ -3484,6 +3497,41 @@ class GLIGENTextBoxApplyBatch:
|
|||||||
|
|
||||||
return (c, plot_image_tensor,)
|
return (c, plot_image_tensor,)
|
||||||
|
|
||||||
|
class ImageUpscaleWithModelBatched:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": { "upscale_model": ("UPSCALE_MODEL",),
|
||||||
|
"images": ("IMAGE",),
|
||||||
|
"per_batch": ("INT", {"default": 16, "min": 1, "max": 4096, "step": 1}),
|
||||||
|
}}
|
||||||
|
RETURN_TYPES = ("IMAGE",)
|
||||||
|
FUNCTION = "upscale"
|
||||||
|
|
||||||
|
CATEGORY = "KJNodes"
|
||||||
|
|
||||||
|
def upscale(self, upscale_model, images, per_batch):
|
||||||
|
|
||||||
|
device = comfy.model_management.get_torch_device()
|
||||||
|
upscale_model.to(device)
|
||||||
|
in_img = images.movedim(-1,-3).to(device)
|
||||||
|
|
||||||
|
steps = in_img.shape[0]
|
||||||
|
pbar = comfy.utils.ProgressBar(steps)
|
||||||
|
t = []
|
||||||
|
|
||||||
|
for start_idx in range(0, in_img.shape[0], per_batch):
|
||||||
|
sub_images = upscale_model(in_img[start_idx:start_idx+per_batch])
|
||||||
|
t.append(sub_images)
|
||||||
|
# Calculate the number of images processed in this batch
|
||||||
|
batch_count = sub_images.shape[0]
|
||||||
|
# Update the progress bar by the number of images processed in this batch
|
||||||
|
pbar.update(batch_count)
|
||||||
|
upscale_model.cpu()
|
||||||
|
|
||||||
|
|
||||||
|
t = torch.cat(t, dim=0).permute(0, 2, 3, 1)
|
||||||
|
|
||||||
|
return (t,)
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"INTConstant": INTConstant,
|
"INTConstant": INTConstant,
|
||||||
@ -3548,7 +3596,8 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"GetLatentsFromBatchIndexed": GetLatentsFromBatchIndexed,
|
"GetLatentsFromBatchIndexed": GetLatentsFromBatchIndexed,
|
||||||
"StringConstant": StringConstant,
|
"StringConstant": StringConstant,
|
||||||
"GLIGENTextBoxApplyBatch": GLIGENTextBoxApplyBatch,
|
"GLIGENTextBoxApplyBatch": GLIGENTextBoxApplyBatch,
|
||||||
"CondPassThrough": CondPassThrough
|
"CondPassThrough": CondPassThrough,
|
||||||
|
"ImageUpscaleWithModelBatched": ImageUpscaleWithModelBatched
|
||||||
}
|
}
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
"INTConstant": "INT Constant",
|
"INTConstant": "INT Constant",
|
||||||
@ -3612,5 +3661,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
|||||||
"GetLatentsFromBatchIndexed": "GetLatentsFromBatchIndexed",
|
"GetLatentsFromBatchIndexed": "GetLatentsFromBatchIndexed",
|
||||||
"StringConstant": "StringConstant",
|
"StringConstant": "StringConstant",
|
||||||
"GLIGENTextBoxApplyBatch": "GLIGENTextBoxApplyBatch",
|
"GLIGENTextBoxApplyBatch": "GLIGENTextBoxApplyBatch",
|
||||||
"CondPassThrough": "CondPassThrough"
|
"CondPassThrough": "CondPassThrough",
|
||||||
|
"ImageUpscaleWithModelBatched": "ImageUpscaleWithModelBatched"
|
||||||
}
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user