mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2026-05-27 12:45:43 +08:00
Add StableZero123_BatchSchedule -node
This commit is contained in:
parent
004c4120e8
commit
2a733b9b7d
174
nodes.py
174
nodes.py
@ -307,7 +307,7 @@ class CreateFadeMaskAdvanced:
|
|||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {
|
return {
|
||||||
"required": {
|
"required": {
|
||||||
"points_string": ("STRING", {"default": "0:(0.0),\n7:(1.0),\n16:(0.0)\n", "multiline": True}),
|
"points_string": ("STRING", {"default": "0:(0.0),\n7:(1.0),\n15:(0.0)\n", "multiline": True}),
|
||||||
"invert": ("BOOLEAN", {"default": False}),
|
"invert": ("BOOLEAN", {"default": False}),
|
||||||
"frames": ("INT", {"default": 16,"min": 2, "max": 255, "step": 1}),
|
"frames": ("INT", {"default": 16,"min": 2, "max": 255, "step": 1}),
|
||||||
"width": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
|
"width": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
|
||||||
@ -2575,7 +2575,7 @@ class AddLabel:
|
|||||||
"font_size": ("INT", {"default": 32, "min": 0, "max": 4096, "step": 1}),
|
"font_size": ("INT", {"default": 32, "min": 0, "max": 4096, "step": 1}),
|
||||||
"font_color": ("STRING", {"default": "white"}),
|
"font_color": ("STRING", {"default": "white"}),
|
||||||
"label_color": ("STRING", {"default": "black"}),
|
"label_color": ("STRING", {"default": "black"}),
|
||||||
"font_path": ("STRING", {"default": "fonts\\TTNorms-Black.otf"}),
|
"font": ("STRING", {"default": "TTNorms-Black.otf"}),
|
||||||
"text": ("STRING", {"default": "Text"}),
|
"text": ("STRING", {"default": "Text"}),
|
||||||
"direction": (
|
"direction": (
|
||||||
[ 'up',
|
[ 'up',
|
||||||
@ -2593,12 +2593,12 @@ class AddLabel:
|
|||||||
|
|
||||||
CATEGORY = "KJNodes"
|
CATEGORY = "KJNodes"
|
||||||
|
|
||||||
def addlabel(self, image, text_x, text_y, text, height, font_size, font_color, label_color, font_path, direction):
|
def addlabel(self, image, text_x, text_y, text, height, font_size, font_color, label_color, font, direction):
|
||||||
batch_size = image.shape[0]
|
batch_size = image.shape[0]
|
||||||
width = image.shape[2]
|
width = image.shape[2]
|
||||||
|
|
||||||
if font_path == "fonts\\TTNorms-Black.otf": #I don't know why relative path won't work otherwise...
|
if font == "TTNorms-Black.otf":
|
||||||
font_path = os.path.join(script_dir, font_path)
|
font_path = os.path.join(script_dir, "fonts", "TTNorms-Black.otf")
|
||||||
|
|
||||||
label_image = Image.new("RGB", (width, height), label_color)
|
label_image = Image.new("RGB", (width, height), label_color)
|
||||||
draw = ImageDraw.Draw(label_image)
|
draw = ImageDraw.Draw(label_image)
|
||||||
@ -2732,6 +2732,164 @@ class GenerateNoise:
|
|||||||
noise = noise[0].repeat(batch_size, 1, 1, 1)
|
noise = noise[0].repeat(batch_size, 1, 1, 1)
|
||||||
return ({"samples":noise}, )
|
return ({"samples":noise}, )
|
||||||
|
|
||||||
|
def camera_embeddings(elevation, azimuth):
|
||||||
|
elevation = torch.as_tensor([elevation])
|
||||||
|
azimuth = torch.as_tensor([azimuth])
|
||||||
|
embeddings = torch.stack(
|
||||||
|
[
|
||||||
|
torch.deg2rad(
|
||||||
|
(90 - elevation) - (90)
|
||||||
|
), # Zero123 polar is 90-elevation
|
||||||
|
torch.sin(torch.deg2rad(azimuth)),
|
||||||
|
torch.cos(torch.deg2rad(azimuth)),
|
||||||
|
torch.deg2rad(
|
||||||
|
90 - torch.full_like(elevation, 0)
|
||||||
|
),
|
||||||
|
], dim=-1).unsqueeze(1)
|
||||||
|
|
||||||
|
return embeddings
|
||||||
|
|
||||||
|
def interpolate_angle(start, end, fraction):
|
||||||
|
# Calculate the difference in angles and adjust for wraparound if necessary
|
||||||
|
diff = (end - start + 540) % 360 - 180
|
||||||
|
# Apply fraction to the difference
|
||||||
|
interpolated = start + fraction * diff
|
||||||
|
# Normalize the result to be within the range of -180 to 180
|
||||||
|
return (interpolated + 180) % 360 - 180
|
||||||
|
|
||||||
|
class StableZero123_BatchSchedule:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": { "clip_vision": ("CLIP_VISION",),
|
||||||
|
"init_image": ("IMAGE",),
|
||||||
|
"vae": ("VAE",),
|
||||||
|
"width": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
|
||||||
|
"height": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
|
||||||
|
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
||||||
|
"interpolation": (["linear", "ease_in", "ease_out", "ease_in_out"],),
|
||||||
|
"azimuth_points_string": ("STRING", {"default": "0:(0.0),\n7:(1.0),\n15:(0.0)\n", "multiline": True}),
|
||||||
|
"elevation_points_string": ("STRING", {"default": "0:(0.0),\n7:(0.0),\n15:(0.0)\n", "multiline": True}),
|
||||||
|
}}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
|
||||||
|
RETURN_NAMES = ("positive", "negative", "latent")
|
||||||
|
|
||||||
|
FUNCTION = "encode"
|
||||||
|
|
||||||
|
CATEGORY = "KJNodes"
|
||||||
|
|
||||||
|
def encode(self, clip_vision, init_image, vae, width, height, batch_size, azimuth_points_string, elevation_points_string, interpolation):
|
||||||
|
output = clip_vision.encode_image(init_image)
|
||||||
|
pooled = output.image_embeds.unsqueeze(0)
|
||||||
|
pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1)
|
||||||
|
encode_pixels = pixels[:,:,:,:3]
|
||||||
|
t = vae.encode(encode_pixels)
|
||||||
|
|
||||||
|
def ease_in(t):
|
||||||
|
return t * t
|
||||||
|
def ease_out(t):
|
||||||
|
return 1 - (1 - t) * (1 - t)
|
||||||
|
def ease_in_out(t):
|
||||||
|
return 3 * t * t - 2 * t * t * t
|
||||||
|
|
||||||
|
# Parse the azimuth input string into a list of tuples
|
||||||
|
azimuth_points = []
|
||||||
|
azimuth_points_string = azimuth_points_string.rstrip(',\n')
|
||||||
|
for point_str in azimuth_points_string.split(','):
|
||||||
|
frame_str, azimuth_str = point_str.split(':')
|
||||||
|
frame = int(frame_str.strip())
|
||||||
|
azimuth = float(azimuth_str.strip()[1:-1])
|
||||||
|
azimuth_points.append((frame, azimuth))
|
||||||
|
# Sort the points by frame number
|
||||||
|
azimuth_points.sort(key=lambda x: x[0])
|
||||||
|
|
||||||
|
# Parse the elevation input string into a list of tuples
|
||||||
|
elevation_points = []
|
||||||
|
elevation_points_string = elevation_points_string.rstrip(',\n')
|
||||||
|
for point_str in elevation_points_string.split(','):
|
||||||
|
frame_str, elevation_str = point_str.split(':')
|
||||||
|
frame = int(frame_str.strip())
|
||||||
|
elevation_val = float(elevation_str.strip()[1:-1])
|
||||||
|
elevation_points.append((frame, elevation_val))
|
||||||
|
# Sort the points by frame number
|
||||||
|
elevation_points.sort(key=lambda x: x[0])
|
||||||
|
|
||||||
|
# Index of the next point to interpolate towards
|
||||||
|
next_point = 1
|
||||||
|
next_elevation_point = 1
|
||||||
|
|
||||||
|
positive_cond_out = []
|
||||||
|
positive_pooled_out = []
|
||||||
|
negative_cond_out = []
|
||||||
|
negative_pooled_out = []
|
||||||
|
|
||||||
|
#azimuth interpolation
|
||||||
|
for i in range(batch_size):
|
||||||
|
# Find the interpolated azimuth for the current frame
|
||||||
|
while next_point < len(azimuth_points) and i >= azimuth_points[next_point][0]:
|
||||||
|
next_point += 1
|
||||||
|
# If next_point is equal to the length of points, we've gone past the last point
|
||||||
|
if next_point == len(azimuth_points):
|
||||||
|
next_point -= 1 # Set next_point to the last index of points
|
||||||
|
prev_point = max(next_point - 1, 0) # Ensure prev_point is not less than 0
|
||||||
|
|
||||||
|
# Calculate fraction
|
||||||
|
if azimuth_points[next_point][0] != azimuth_points[prev_point][0]: # Prevent division by zero
|
||||||
|
fraction = (i - azimuth_points[prev_point][0]) / (azimuth_points[next_point][0] - azimuth_points[prev_point][0])
|
||||||
|
if interpolation == "ease_in":
|
||||||
|
fraction = ease_in(fraction)
|
||||||
|
elif interpolation == "ease_out":
|
||||||
|
fraction = ease_out(fraction)
|
||||||
|
elif interpolation == "ease_in_out":
|
||||||
|
fraction = ease_in_out(fraction)
|
||||||
|
|
||||||
|
# Use the new interpolate_angle function
|
||||||
|
interpolated_azimuth = interpolate_angle(azimuth_points[prev_point][1], azimuth_points[next_point][1], fraction)
|
||||||
|
else:
|
||||||
|
interpolated_azimuth = azimuth_points[prev_point][1]
|
||||||
|
|
||||||
|
# Interpolate the elevation
|
||||||
|
next_elevation_point = 1
|
||||||
|
while next_elevation_point < len(elevation_points) and i >= elevation_points[next_elevation_point][0]:
|
||||||
|
next_elevation_point += 1
|
||||||
|
if next_elevation_point == len(elevation_points):
|
||||||
|
next_elevation_point -= 1
|
||||||
|
prev_elevation_point = max(next_elevation_point - 1, 0)
|
||||||
|
|
||||||
|
if elevation_points[next_elevation_point][0] != elevation_points[prev_elevation_point][0]:
|
||||||
|
fraction = (i - elevation_points[prev_elevation_point][0]) / (elevation_points[next_elevation_point][0] - elevation_points[prev_elevation_point][0])
|
||||||
|
if interpolation == "ease_in":
|
||||||
|
fraction = ease_in(fraction)
|
||||||
|
elif interpolation == "ease_out":
|
||||||
|
fraction = ease_out(fraction)
|
||||||
|
elif interpolation == "ease_in_out":
|
||||||
|
fraction = ease_in_out(fraction)
|
||||||
|
|
||||||
|
interpolated_elevation = interpolate_angle(elevation_points[prev_elevation_point][1], elevation_points[next_elevation_point][1], fraction)
|
||||||
|
else:
|
||||||
|
interpolated_elevation = elevation_points[prev_elevation_point][1]
|
||||||
|
|
||||||
|
cam_embeds = camera_embeddings(interpolated_elevation, interpolated_azimuth)
|
||||||
|
cond = torch.cat([pooled, cam_embeds.repeat((pooled.shape[0], 1, 1))], dim=-1)
|
||||||
|
|
||||||
|
positive_pooled_out.append(t)
|
||||||
|
positive_cond_out.append(cond)
|
||||||
|
negative_pooled_out.append(torch.zeros_like(t))
|
||||||
|
negative_cond_out.append(torch.zeros_like(pooled))
|
||||||
|
|
||||||
|
# Concatenate the conditions and pooled outputs
|
||||||
|
final_positive_cond = torch.cat(positive_cond_out, dim=0)
|
||||||
|
final_positive_pooled = torch.cat(positive_pooled_out, dim=0)
|
||||||
|
final_negative_cond = torch.cat(negative_cond_out, dim=0)
|
||||||
|
final_negative_pooled = torch.cat(negative_pooled_out, dim=0)
|
||||||
|
|
||||||
|
# Structure the final output
|
||||||
|
final_positive = [[final_positive_cond, {"concat_latent_image": final_positive_pooled}]]
|
||||||
|
final_negative = [[final_negative_cond, {"concat_latent_image": final_negative_pooled}]]
|
||||||
|
|
||||||
|
latent = torch.zeros([batch_size, 4, height // 8, width // 8])
|
||||||
|
return (final_positive, final_negative, {"samples": latent})
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"INTConstant": INTConstant,
|
"INTConstant": INTConstant,
|
||||||
"FloatConstant": FloatConstant,
|
"FloatConstant": FloatConstant,
|
||||||
@ -2783,7 +2941,8 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"AddLabel": AddLabel,
|
"AddLabel": AddLabel,
|
||||||
"ReferenceOnlySimple3": ReferenceOnlySimple3,
|
"ReferenceOnlySimple3": ReferenceOnlySimple3,
|
||||||
"SoundReactive": SoundReactive,
|
"SoundReactive": SoundReactive,
|
||||||
"GenerateNoise": GenerateNoise
|
"GenerateNoise": GenerateNoise,
|
||||||
|
"StableZero123_BatchSchedule": StableZero123_BatchSchedule,
|
||||||
}
|
}
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
"INTConstant": "INT Constant",
|
"INTConstant": "INT Constant",
|
||||||
@ -2835,5 +2994,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
|||||||
"AddLabel": "AddLabel",
|
"AddLabel": "AddLabel",
|
||||||
"ReferenceOnlySimple3": "ReferenceOnlySimple3",
|
"ReferenceOnlySimple3": "ReferenceOnlySimple3",
|
||||||
"SoundReactive": "SoundReactive",
|
"SoundReactive": "SoundReactive",
|
||||||
"GenerateNoise": "GenerateNoise"
|
"GenerateNoise": "GenerateNoise",
|
||||||
|
"StableZero123_BatchSchedule": "StableZero123_BatchSchedule",
|
||||||
}
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user