Update nodes.py

This commit is contained in:
Kijai 2024-01-29 19:07:53 +02:00
parent e227539a2a
commit be8e44db39

View File

@ -34,7 +34,7 @@ class INTConstant:
RETURN_NAMES = ("value",)
FUNCTION = "get_value"
CATEGORY = "KJNodes"
CATEGORY = "KJNodes/constants"
def get_value(self, value):
return (value,)
@ -51,16 +51,27 @@ class FloatConstant:
RETURN_NAMES = ("value",)
FUNCTION = "get_value"
CATEGORY = "KJNodes"
CATEGORY = "KJNodes/constants"
def get_value(self, value):
return (value,)
def gaussian_kernel(kernel_size: int, sigma: float, device=None):
x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size, device=device), torch.linspace(-1, 1, kernel_size, device=device), indexing="ij")
d = torch.sqrt(x * x + y * y)
g = torch.exp(-(d * d) / (2.0 * sigma * sigma))
return g / g.sum()
class StringConstant:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"string": ("STRING", {"default": '', "multiline": False}),
}
}
RETURN_TYPES = ("STRING",)
FUNCTION = "passtring"
CATEGORY = "KJNodes/constants"
def passtring(self, string):
return (string, )
class CreateFluidMask:
@ -526,6 +537,38 @@ class GetImagesFromBatchIndexed:
return (chosen_images,)
class GetLatentsFromBatchIndexed:
RETURN_TYPES = ("LATENT",)
FUNCTION = "indexedlatentsfrombatch"
CATEGORY = "KJNodes"
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"latents": ("LATENT",),
"indexes": ("STRING", {"default": "0, 1, 2", "multiline": True}),
},
}
def indexedlatentsfrombatch(self, latents, indexes):
samples = latents.copy()
latent_samples = samples["samples"]
# Parse the indexes string into a list of integers
index_list = [int(index.strip()) for index in indexes.split(',')]
# Convert list of indices to a PyTorch tensor
indices_tensor = torch.tensor(index_list, dtype=torch.long)
# Select the latents at the specified indices
chosen_latents = latent_samples[indices_tensor]
samples["samples"] = chosen_latents
return (samples,)
class ReplaceImagesInBatch:
RETURN_TYPES = ("IMAGE",)
@ -3271,7 +3314,7 @@ class ImageTransformByNormalizedAmplitude:
return (transformed_batch,)
NODE_CLASS_MAPPINGS = {
"INTConstant": INTConstant,
"FloatConstant": FloatConstant,
@ -3331,7 +3374,9 @@ NODE_CLASS_MAPPINGS = {
"ImageBatchRepeatInterleaving": ImageBatchRepeatInterleaving,
"NormalizedAmplitudeToMask": NormalizedAmplitudeToMask,
"OffsetMaskByNormalizedAmplitude": OffsetMaskByNormalizedAmplitude,
"ImageTransformByNormalizedAmplitude": ImageTransformByNormalizedAmplitude
"ImageTransformByNormalizedAmplitude": ImageTransformByNormalizedAmplitude,
"GetLatentsFromBatchIndexed": GetLatentsFromBatchIndexed,
"StringConstant": StringConstant
}
NODE_DISPLAY_NAME_MAPPINGS = {
"INTConstant": "INT Constant",
@ -3391,5 +3436,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"ImageBatchRepeatInterleaving": "ImageBatchRepeatInterleaving",
"NormalizedAmplitudeToMask": "NormalizedAmplitudeToMask",
"OffsetMaskByNormalizedAmplitude": "OffsetMaskByNormalizedAmplitude",
"ImageTransformByNormalizedAmplitude": "ImageTransformByNormalizedAmplitude"
"ImageTransformByNormalizedAmplitude": "ImageTransformByNormalizedAmplitude",
"GetLatentsFromBatchIndexed": "GetLatentsFromBatchIndexed",
"StringConstant": "StringConstant"
}