Repeat last frame when concatenating different length videos

This commit is contained in:
kijai 2025-01-14 15:08:34 +02:00
parent 3adcc529f2
commit 28f0470a9a
2 changed files with 15 additions and 9 deletions

View File

@ -217,13 +217,13 @@ class ImageConcanate:
}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "concanate"
FUNCTION = "concatenate"
CATEGORY = "KJNodes/image"
DESCRIPTION = """
Concatenates the image2 to image1 in the specified direction.
"""
def concanate(self, image1, image2, direction, match_image_size, first_image_shape=None):
def concatenate(self, image1, image2, direction, match_image_size, first_image_shape=None):
# Check if the batch sizes are different
batch_size1 = image1.shape[0]
batch_size2 = image2.shape[0]
@ -231,12 +231,16 @@ Concatenates the image2 to image1 in the specified direction.
if batch_size1 != batch_size2:
# Calculate the number of repetitions needed
max_batch_size = max(batch_size1, batch_size2)
repeats1 = max_batch_size // batch_size1
repeats2 = max_batch_size // batch_size2
repeats1 = max_batch_size - batch_size1
repeats2 = max_batch_size - batch_size2
# Repeat the images to match the largest batch size
image1 = image1.repeat(repeats1, 1, 1, 1)
image2 = image2.repeat(repeats2, 1, 1, 1)
# Repeat the last image to match the largest batch size
if repeats1 > 0:
last_image1 = image1[-1].unsqueeze(0).repeat(repeats1, 1, 1, 1)
image1 = torch.cat([image1, last_image1], dim=0)
if repeats2 > 0:
last_image2 = image2[-1].unsqueeze(0).repeat(repeats2, 1, 1, 1)
image2 = torch.cat([image2, last_image2], dim=0)
if match_image_size:
# Use first_image_shape if provided; otherwise, default to image1's shape
@ -1845,7 +1849,7 @@ with the **inputcount** and clicking update.
first_image_shape = image.shape
for c in range(1, inputcount):
new_image = kwargs[f"image_{c + 1}"]
image, = ImageConcanate.concanate(self, image, new_image, direction, match_image_size, first_image_shape=first_image_shape)
image, = ImageConcanate.concatenate(self, image, new_image, direction, match_image_size, first_image_shape=first_image_shape)
first_image_shape = None
return (image,)

View File

@ -475,6 +475,7 @@ class TorchCompileCosmosModel:
"fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}),
"mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}),
"dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}),
"dynamo_cache_size_limit": ("INT", {"default": 64, "tooltip": "Set the dynamo cache size limit"}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
@ -482,10 +483,11 @@ class TorchCompileCosmosModel:
CATEGORY = "KJNodes/experimental"
EXPERIMENTAL = True
def patch(self, model, backend, mode, fullgraph, dynamic):
def patch(self, model, backend, mode, fullgraph, dynamic, dynamo_cache_size_limit):
m = model.clone()
diffusion_model = m.get_model_object("diffusion_model")
torch._dynamo.config.cache_size_limit = dynamo_cache_size_limit
if not self._compiled:
try: