diff --git a/nodes.py b/nodes.py index cbcb03a..396a44b 100644 --- a/nodes.py +++ b/nodes.py @@ -493,7 +493,35 @@ class GetImageRangeFromBatch: raise ValueError("GetImageRangeFromBatch: End index is out of range") chosen_images = images[start_index:end_index] return (chosen_images, ) - + +class GetImagesFromBatchIndexed: + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "indexedimagesfrombatch" + CATEGORY = "KJNodes" + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "images": ("IMAGE",), + "indexes": ("STRING", {"default": "0, 1, 2", "multiline": True}), + }, + } + + def indexedimagesfrombatch(self, images, indexes): + + # Parse the indexes string into a list of integers + index_list = [int(index.strip()) for index in indexes.split(',')] + + # Convert list of indices to a PyTorch tensor + indices_tensor = torch.tensor(index_list, dtype=torch.long) + + # Select the images at the specified indices + chosen_images = images[indices_tensor] + + return (chosen_images,) + class ReplaceImagesInBatch: RETURN_TYPES = ("IMAGE",) @@ -2965,6 +2993,7 @@ NODE_CLASS_MAPPINGS = { "SoundReactive": SoundReactive, "GenerateNoise": GenerateNoise, "StableZero123_BatchSchedule": StableZero123_BatchSchedule, + "GetImagesFromBatchIndexed": GetImagesFromBatchIndexed, "ImageBatchRepeatEveryNth": ImageBatchRepeatEveryNth } NODE_DISPLAY_NAME_MAPPINGS = { @@ -3019,5 +3048,6 @@ NODE_DISPLAY_NAME_MAPPINGS = { "SoundReactive": "SoundReactive", "GenerateNoise": "GenerateNoise", "StableZero123_BatchSchedule": "StableZero123_BatchSchedule", + "GetImagesFromBatchIndexed": "GetImagesFromBatchIndexed", "ImageBatchRepeatEveryNth": "ImageBatchRepeatEveryNth" } \ No newline at end of file