From 6829b0ccf8892f7c4fa8789e4dd551904deb32f8 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Sun, 25 Feb 2024 19:04:28 +0200 Subject: [PATCH] add batched encode/decode --- nodes.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/nodes.py b/nodes.py index e3ff772..6f86da4 100644 --- a/nodes.py +++ b/nodes.py @@ -3686,6 +3686,7 @@ class Intrinsic_lora_sampling: "clip": ("CLIP", ), "vae": ("VAE", ), "image": ("IMAGE",), + "per_batch": ("INT", {"default": 16, "min": 1, "max": 4096, "step": 1}), }} @@ -3693,14 +3694,17 @@ class Intrinsic_lora_sampling: FUNCTION = "onestepsample" CATEGORY = "KJNodes" - def onestepsample(self, model, lora_name, clip, vae, image, text, task): + def onestepsample(self, model, lora_name, clip, vae, image, text, task, per_batch): pbar = comfy.utils.ProgressBar(3) - encoded_latent, = VAEEncode.encode(self, vae, image[:,:,:,:3]) - sample = encoded_latent["samples"] - noise = torch.zeros(sample.size(), dtype=sample.dtype, layout=sample.layout, device="cpu") + image_list = [] + for start_idx in range(0, image.shape[0], per_batch): + sub_pixels = vae.vae_encode_crop_pixels(image[start_idx:start_idx+per_batch]) + image_list.append(vae.encode(sub_pixels[:,:,:,:3])) + + sample = torch.cat(image_list, dim=0) + noise = torch.zeros(sample.size(), dtype=sample.dtype, layout=sample.layout, device="cpu") prompt = task + "," + text - print(prompt) positive, = CLIPTextEncode.encode(self, clip, prompt) pbar.update(1) negative = positive #negative shouldn't do anything in this scenario @@ -3726,8 +3730,15 @@ class Intrinsic_lora_sampling: denoise=1.0, disable_noise=True, start_step=0, last_step=1, force_full_denoise=True, noise_mask=None, callback=None, disable_pbar=True, seed=None)} pbar.update(1) - image_out, = VAEDecode.decode(self, vae, samples) + + decoded = [] + for start_idx in range(0, samples["samples"].shape[0], per_batch): + decoded.append(vae.decode(samples["samples"][start_idx:start_idx+per_batch])) + image_out = torch.cat(decoded, dim=0) + #image_out, = VAEDecode.decode(self, vae, samples) + pbar.update(1) + if task == 'depth map': imax = image_out.max() imin = image_out.min()