From 8ebae92057a951d1c95547c4e1f533e17fe80156 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Mon, 11 Nov 2024 12:42:49 +0200 Subject: [PATCH] support loading VAE decoder/encoder from the combined VAE model --- latent_preview.py | 19 ++++++++++++++++--- nodes.py | 39 ++++++++++++++++++++++++++++++++++++--- 2 files changed, 52 insertions(+), 6 deletions(-) diff --git a/latent_preview.py b/latent_preview.py index b23eda2..b1d08ad 100644 --- a/latent_preview.py +++ b/latent_preview.py @@ -31,10 +31,23 @@ class LatentPreviewer: class Latent2RGBPreviewer(LatentPreviewer): def __init__(self): #latent_rgb_factors = [[0.05389399697934166, 0.025018778505575393, -0.009193515248318657], [0.02318250640590553, -0.026987363837713156, 0.040172639061236956], [0.046035451343323666, -0.02039565868920197, 0.01275569344290342], [-0.015559161155025095, 0.051403973219861246, 0.03179031307996347], [-0.02766167769640129, 0.03749545161530447, 0.003335141009473408], [0.05824598730479011, 0.021744367381243884, -0.01578925627951616], [0.05260929401500947, 0.0560165014956886, -0.027477296572565126], [0.018513891242931686, 0.041961785217662514, 0.004490763489747966], [0.024063060899760215, 0.065082853069653, 0.044343437673514896], [0.05250992323006226, 0.04361117432588933, 0.01030076055524387], [0.0038921710021782366, -0.025299228133723792, 0.019370764014574535], [-0.00011950534333568519, 0.06549370069727675, -0.03436712163379723], [-0.026020578032683626, -0.013341758571090847, -0.009119046570271953], [0.024412451175602937, 0.030135064560817174, -0.008355486384198006], [0.04002209845752687, -0.017341304390739463, 0.02818338690302971], [-0.032575108695213684, -0.009588338926775117, -0.03077312160940468]] - latent_rgb_factors = [[0.1236769792512748, 0.11775175335219157, -0.17700629766423637], [-0.08504104329270078, 0.026605813147523694, -0.006843165704926019], [-0.17093308616366876, 0.027991854696200386, 0.14179146288816308], [-0.17179555328757623, 0.09844317368603078, 0.14470997015982784], [-0.16975067171668484, -0.10739852629856643, -0.1894254942909962], [-0.19315259266769888, -0.011029760569485209, -0.08519702054654255], [-0.08399895091432583, -0.0964246452052032, -0.033622359523655665], [0.08148916330842498, 0.027500645903400067, -0.06593099749891196], [0.0456603103902293, -0.17844808072462398, 0.04204775167149785], [0.001751626383204502, -0.030567890189647867, -0.022078082809772193], [0.05110631095056278, -0.0709677393548804, 0.08963683539504264], [0.010515800868829, -0.18382052841762514, -0.08554553339721907]] - + #latent_rgb_factors = [[0.1236769792512748, 0.11775175335219157, -0.17700629766423637], [-0.08504104329270078, 0.026605813147523694, -0.006843165704926019], [-0.17093308616366876, 0.027991854696200386, 0.14179146288816308], [-0.17179555328757623, 0.09844317368603078, 0.14470997015982784], [-0.16975067171668484, -0.10739852629856643, -0.1894254942909962], [-0.19315259266769888, -0.011029760569485209, -0.08519702054654255], [-0.08399895091432583, -0.0964246452052032, -0.033622359523655665], [0.08148916330842498, 0.027500645903400067, -0.06593099749891196], [0.0456603103902293, -0.17844808072462398, 0.04204775167149785], [0.001751626383204502, -0.030567890189647867, -0.022078082809772193], [0.05110631095056278, -0.0709677393548804, 0.08963683539504264], [0.010515800868829, -0.18382052841762514, -0.08554553339721907]] + latent_rgb_factors =[ + [-0.0069, -0.0045, 0.0018], + [ 0.0154, -0.0692, -0.0274], + [ 0.0333, 0.0019, 0.0206], + [-0.1390, 0.0628, 0.1678], + [-0.0725, 0.0134, -0.1898], + [ 0.0074, -0.0270, -0.0209], + [-0.0176, -0.0277, -0.0221], + [ 0.5294, 0.5204, 0.3852], + [-0.0326, -0.0446, -0.0143], + [-0.0659, 0.0153, -0.0153], + [ 0.0185, -0.0217, 0.0014], + [-0.0396, -0.0495, -0.0281] + ] self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu").transpose(0, 1) - self.latent_rgb_factors_bias = None + self.latent_rgb_factors_bias = [-0.0940, -0.1418, -0.1453] # if latent_rgb_factors_bias is not None: # self.latent_rgb_factors_bias = torch.tensor(latent_rgb_factors_bias, device="cpu") diff --git a/nodes.py b/nodes.py index 89c72c1..f1b10b9 100644 --- a/nodes.py +++ b/nodes.py @@ -315,6 +315,16 @@ class MochiVAELoader: dtype=dtype, ) vae_sd = load_torch_file(vae_path) + + #support loading from combined VAE + if vae_sd.get("decoder.blocks.0.0.bias") is not None: + new_vae_sd = {} + for k, v in vae_sd.items(): + if k.startswith("decoder."): + new_k = k[len("decoder."):] + new_vae_sd[new_k] = v + vae_sd = new_vae_sd + if is_accelerate_available: for name, param in vae.named_parameters(): set_module_tensor_to_device(vae, name, dtype=dtype, device=offload_device, value=vae_sd[name]) @@ -383,6 +393,16 @@ class MochiVAEEncoderLoader: ) encoder_sd = load_torch_file(vae_path) + + #support loading from combined VAE + if encoder_sd.get("encoder.layers.0.bias") is not None: + new_vae_sd = {} + for k, v in encoder_sd.items(): + if k.startswith("encoder."): + new_k = k[len("encoder."):] + new_vae_sd[new_k] = v + encoder_sd = new_vae_sd + if is_accelerate_available: for name, param in encoder.named_parameters(): set_module_tensor_to_device(encoder, name, dtype=dtype, device=offload_device, value=encoder_sd[name]) @@ -850,8 +870,21 @@ class MochiLatentPreview: device = mm.get_torch_device() offload_device = mm.unet_offload_device() - latent_rgb_factors = [[0.1236769792512748, 0.11775175335219157, -0.17700629766423637], [-0.08504104329270078, 0.026605813147523694, -0.006843165704926019], [-0.17093308616366876, 0.027991854696200386, 0.14179146288816308], [-0.17179555328757623, 0.09844317368603078, 0.14470997015982784], [-0.16975067171668484, -0.10739852629856643, -0.1894254942909962], [-0.19315259266769888, -0.011029760569485209, -0.08519702054654255], [-0.08399895091432583, -0.0964246452052032, -0.033622359523655665], [0.08148916330842498, 0.027500645903400067, -0.06593099749891196], [0.0456603103902293, -0.17844808072462398, 0.04204775167149785], [0.001751626383204502, -0.030567890189647867, -0.022078082809772193], [0.05110631095056278, -0.0709677393548804, 0.08963683539504264], [0.010515800868829, -0.18382052841762514, -0.08554553339721907]] - + #latent_rgb_factors = [[0.1236769792512748, 0.11775175335219157, -0.17700629766423637], [-0.08504104329270078, 0.026605813147523694, -0.006843165704926019], [-0.17093308616366876, 0.027991854696200386, 0.14179146288816308], [-0.17179555328757623, 0.09844317368603078, 0.14470997015982784], [-0.16975067171668484, -0.10739852629856643, -0.1894254942909962], [-0.19315259266769888, -0.011029760569485209, -0.08519702054654255], [-0.08399895091432583, -0.0964246452052032, -0.033622359523655665], [0.08148916330842498, 0.027500645903400067, -0.06593099749891196], [0.0456603103902293, -0.17844808072462398, 0.04204775167149785], [0.001751626383204502, -0.030567890189647867, -0.022078082809772193], [0.05110631095056278, -0.0709677393548804, 0.08963683539504264], [0.010515800868829, -0.18382052841762514, -0.08554553339721907]] + latent_rgb_factors =[ + [-0.0069, -0.0045, 0.0018], + [ 0.0154, -0.0692, -0.0274], + [ 0.0333, 0.0019, 0.0206], + [-0.1390, 0.0628, 0.1678], + [-0.0725, 0.0134, -0.1898], + [ 0.0074, -0.0270, -0.0209], + [-0.0176, -0.0277, -0.0221], + [ 0.5294, 0.5204, 0.3852], + [-0.0326, -0.0446, -0.0143], + [-0.0659, 0.0153, -0.0153], + [ 0.0185, -0.0217, 0.0014], + [-0.0396, -0.0495, -0.0281] + ] # import random # random.seed(seed) # latent_rgb_factors = [[random.uniform(min_val, max_val) for _ in range(3)] for _ in range(12)] @@ -859,7 +892,7 @@ class MochiLatentPreview: # print(latent_rgb_factors) - latent_rgb_factors_bias = [0,0,0] + latent_rgb_factors_bias = [-0.0940, -0.1418, -0.1453] latent_rgb_factors = torch.tensor(latent_rgb_factors, device=latents.device, dtype=latents.dtype).transpose(0, 1) latent_rgb_factors_bias = torch.tensor(latent_rgb_factors_bias, device=latents.device, dtype=latents.dtype)