Update nodes.py

This commit is contained in:
kijai 2024-01-31 22:32:58 +02:00
parent 9b17cba78a
commit 58c4c755f7

116
nodes.py
View File

@ -842,6 +842,25 @@ class ConditioningMultiCombine:
cond = cond_combine_node.combine(new_cond, cond)[0]
return (cond, inputcount,)
class CondPassThrough:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
},
}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING",)
RETURN_NAMES = ("positive", "negative")
FUNCTION = "passthrough"
CATEGORY = "KJNodes/misc"
def passthrough(self, positive, negative):
return (positive, negative,)
def append_helper(t, mask, c, set_area_to_bounds, strength):
n = [t[0], t[1].copy()]
_, h, w = mask.shape
@ -3314,6 +3333,73 @@ class ImageTransformByNormalizedAmplitude:
return (transformed_batch,)
def parse_coordinates(coordinates_str):
coordinates = {}
pattern = r'(\d+):\((\d+),(\d+)\)'
matches = re.findall(pattern, coordinates_str)
for match in matches:
index, x, y = map(int, match)
coordinates[index] = (x, y)
return coordinates
def interpolate_coordinates(coordinates_dict, batch_size):
sorted_coords = sorted(coordinates_dict.items())
interpolated = {}
for i, ((index1, (x1, y1)), (index2, (x2, y2))) in enumerate(zip(sorted_coords, sorted_coords[1:])):
distance = index2 - index1
x_step = (x2 - x1) / distance
y_step = (y2 - y1) / distance
for j in range(distance):
interpolated_x = round(x1 + j * x_step)
interpolated_y = round(y1 + j * y_step)
interpolated[index1 + j] = (interpolated_x, interpolated_y)
interpolated[sorted_coords[-1][0]] = sorted_coords[-1][1]
# Ensure we have coordinates for all indices in the batch
last_index, last_coords = sorted_coords[-1]
for i in range(last_index + 1, batch_size):
interpolated[i] = last_coords
return interpolated
def plot_to_tensor(coordinates_dict, interpolated_dict, height, width, box_size):
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
import matplotlib.patches as patches
original_x, original_y = zip(*coordinates_dict.values())
interpolated_x, interpolated_y = zip(*interpolated_dict.values())
fig, ax = plt.subplots(figsize=(width/100, height/100), dpi=100)
ax.scatter(original_x, original_y, color='blue', label='Original Points')
ax.scatter(interpolated_x, interpolated_y, color='red', alpha=0.5, label='Interpolated Points')
ax.plot(interpolated_x, interpolated_y, color='grey', linestyle='--', linewidth=0.5)
# Draw a box at each interpolated coordinate
for x, y in interpolated_dict.values():
rect = patches.Rectangle((x - box_size/2, y - box_size/2), box_size, box_size,
linewidth=1, edgecolor='green', facecolor='none')
ax.add_patch(rect)
ax.set_title('Interpolated Coordinates')
ax.set_xlabel('X Coordinate')
ax.set_ylabel('Y Coordinate')
ax.legend()
ax.set_xlim(0, width) # Set the x-axis to match the input latent width
ax.set_ylim(height, 0) # Set the y-axis to match the input latent height, with (0,0) at top-left
canvas = FigureCanvas(fig)
canvas.draw()
width, height = fig.get_size_inches() * fig.get_dpi()
image_np = np.frombuffer(canvas.tostring_rgb(), dtype='uint8').reshape(int(height), int(width), 3)
image_tensor = torch.from_numpy(image_np).float() / 255.0
image_tensor = image_tensor.unsqueeze(0)
plt.close(fig)
return image_tensor
class GLIGENTextBoxApplyBatch:
@classmethod
def INPUT_TYPES(s):
@ -3324,33 +3410,34 @@ class GLIGENTextBoxApplyBatch:
"text": ("STRING", {"multiline": True}),
"width": ("INT", {"default": 64, "min": 8, "max": MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 64, "min": 8, "max": MAX_RESOLUTION, "step": 8}),
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
"x_increment": ("INT", {"default": 0, "min": -MAX_RESOLUTION, "max": MAX_RESOLUTION, "step": 8}),
"y_increment": ("INT", {"default": 0, "min": -MAX_RESOLUTION, "max": MAX_RESOLUTION, "step": 8}),
"coordinates": ("STRING", {"multiline": True}),
}}
RETURN_TYPES = ("CONDITIONING",)
RETURN_TYPES = ("CONDITIONING", "IMAGE",)
FUNCTION = "append"
CATEGORY = "conditioning/gligen"
def append(self, latents, conditioning_to, clip, gligen_textbox_model, text, width, height, x, y, x_increment, y_increment):
def append(self, latents, conditioning_to, clip, gligen_textbox_model, text, width, height, coordinates):
coordinates_dict = parse_coordinates(coordinates)
batch_size = sum(tensor.size(0) for tensor in latents.values())
print(batch_size)
c = []
cond, cond_pooled = clip.encode_from_tokens(clip.tokenize(text), return_pooled=True)
# Interpolate coordinates for the entire batch
interpolated_coords = interpolate_coordinates(coordinates_dict, batch_size)
plot_image_tensor = plot_to_tensor(coordinates_dict, interpolated_coords, 512, 512, height)
for t in conditioning_to:
n = [t[0], t[1].copy()]
position_params_batch = [[] for _ in range(batch_size)] # Initialize a list of empty lists for each batch item
for i in range(batch_size):
x_position = x + i * x_increment
y_position = y + i * y_increment
x_position, y_position = interpolated_coords[i]
position_param = (cond_pooled, height // 8, width // 8, y_position // 8, x_position // 8)
position_params_batch[i].append(position_param) # Append position_param to the correct sublist
prev = []
if "gligen" in n[1]:
prev = n[1]['gligen'][2]
@ -3362,7 +3449,8 @@ class GLIGENTextBoxApplyBatch:
n[1]['gligen'] = ("position", gligen_textbox_model, combined_position_params)
c.append(n)
return (c, )
return (c, plot_image_tensor,)
NODE_CLASS_MAPPINGS = {
"INTConstant": INTConstant,
@ -3426,7 +3514,8 @@ NODE_CLASS_MAPPINGS = {
"ImageTransformByNormalizedAmplitude": ImageTransformByNormalizedAmplitude,
"GetLatentsFromBatchIndexed": GetLatentsFromBatchIndexed,
"StringConstant": StringConstant,
"GLIGENTextBoxApplyBatch": GLIGENTextBoxApplyBatch
"GLIGENTextBoxApplyBatch": GLIGENTextBoxApplyBatch,
"CondPassThrough": CondPassThrough
}
NODE_DISPLAY_NAME_MAPPINGS = {
"INTConstant": "INT Constant",
@ -3489,5 +3578,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"ImageTransformByNormalizedAmplitude": "ImageTransformByNormalizedAmplitude",
"GetLatentsFromBatchIndexed": "GetLatentsFromBatchIndexed",
"StringConstant": "StringConstant",
"GLIGENTextBoxApplyBatch": "GLIGENTextBoxApplyBatch"
"GLIGENTextBoxApplyBatch": "GLIGENTextBoxApplyBatch",
"CondPassThrough": "CondPassThrough"
}