support loading VAE decoder/encoder from the combined VAE model

This commit is contained in:
kijai 2024-11-11 12:42:49 +02:00
parent d3287d61b7
commit 8ebae92057
2 changed files with 52 additions and 6 deletions

View File

@ -31,10 +31,23 @@ class LatentPreviewer:
class Latent2RGBPreviewer(LatentPreviewer): class Latent2RGBPreviewer(LatentPreviewer):
def __init__(self): def __init__(self):
#latent_rgb_factors = [[0.05389399697934166, 0.025018778505575393, -0.009193515248318657], [0.02318250640590553, -0.026987363837713156, 0.040172639061236956], [0.046035451343323666, -0.02039565868920197, 0.01275569344290342], [-0.015559161155025095, 0.051403973219861246, 0.03179031307996347], [-0.02766167769640129, 0.03749545161530447, 0.003335141009473408], [0.05824598730479011, 0.021744367381243884, -0.01578925627951616], [0.05260929401500947, 0.0560165014956886, -0.027477296572565126], [0.018513891242931686, 0.041961785217662514, 0.004490763489747966], [0.024063060899760215, 0.065082853069653, 0.044343437673514896], [0.05250992323006226, 0.04361117432588933, 0.01030076055524387], [0.0038921710021782366, -0.025299228133723792, 0.019370764014574535], [-0.00011950534333568519, 0.06549370069727675, -0.03436712163379723], [-0.026020578032683626, -0.013341758571090847, -0.009119046570271953], [0.024412451175602937, 0.030135064560817174, -0.008355486384198006], [0.04002209845752687, -0.017341304390739463, 0.02818338690302971], [-0.032575108695213684, -0.009588338926775117, -0.03077312160940468]] #latent_rgb_factors = [[0.05389399697934166, 0.025018778505575393, -0.009193515248318657], [0.02318250640590553, -0.026987363837713156, 0.040172639061236956], [0.046035451343323666, -0.02039565868920197, 0.01275569344290342], [-0.015559161155025095, 0.051403973219861246, 0.03179031307996347], [-0.02766167769640129, 0.03749545161530447, 0.003335141009473408], [0.05824598730479011, 0.021744367381243884, -0.01578925627951616], [0.05260929401500947, 0.0560165014956886, -0.027477296572565126], [0.018513891242931686, 0.041961785217662514, 0.004490763489747966], [0.024063060899760215, 0.065082853069653, 0.044343437673514896], [0.05250992323006226, 0.04361117432588933, 0.01030076055524387], [0.0038921710021782366, -0.025299228133723792, 0.019370764014574535], [-0.00011950534333568519, 0.06549370069727675, -0.03436712163379723], [-0.026020578032683626, -0.013341758571090847, -0.009119046570271953], [0.024412451175602937, 0.030135064560817174, -0.008355486384198006], [0.04002209845752687, -0.017341304390739463, 0.02818338690302971], [-0.032575108695213684, -0.009588338926775117, -0.03077312160940468]]
latent_rgb_factors = [[0.1236769792512748, 0.11775175335219157, -0.17700629766423637], [-0.08504104329270078, 0.026605813147523694, -0.006843165704926019], [-0.17093308616366876, 0.027991854696200386, 0.14179146288816308], [-0.17179555328757623, 0.09844317368603078, 0.14470997015982784], [-0.16975067171668484, -0.10739852629856643, -0.1894254942909962], [-0.19315259266769888, -0.011029760569485209, -0.08519702054654255], [-0.08399895091432583, -0.0964246452052032, -0.033622359523655665], [0.08148916330842498, 0.027500645903400067, -0.06593099749891196], [0.0456603103902293, -0.17844808072462398, 0.04204775167149785], [0.001751626383204502, -0.030567890189647867, -0.022078082809772193], [0.05110631095056278, -0.0709677393548804, 0.08963683539504264], [0.010515800868829, -0.18382052841762514, -0.08554553339721907]] #latent_rgb_factors = [[0.1236769792512748, 0.11775175335219157, -0.17700629766423637], [-0.08504104329270078, 0.026605813147523694, -0.006843165704926019], [-0.17093308616366876, 0.027991854696200386, 0.14179146288816308], [-0.17179555328757623, 0.09844317368603078, 0.14470997015982784], [-0.16975067171668484, -0.10739852629856643, -0.1894254942909962], [-0.19315259266769888, -0.011029760569485209, -0.08519702054654255], [-0.08399895091432583, -0.0964246452052032, -0.033622359523655665], [0.08148916330842498, 0.027500645903400067, -0.06593099749891196], [0.0456603103902293, -0.17844808072462398, 0.04204775167149785], [0.001751626383204502, -0.030567890189647867, -0.022078082809772193], [0.05110631095056278, -0.0709677393548804, 0.08963683539504264], [0.010515800868829, -0.18382052841762514, -0.08554553339721907]]
latent_rgb_factors =[
[-0.0069, -0.0045, 0.0018],
[ 0.0154, -0.0692, -0.0274],
[ 0.0333, 0.0019, 0.0206],
[-0.1390, 0.0628, 0.1678],
[-0.0725, 0.0134, -0.1898],
[ 0.0074, -0.0270, -0.0209],
[-0.0176, -0.0277, -0.0221],
[ 0.5294, 0.5204, 0.3852],
[-0.0326, -0.0446, -0.0143],
[-0.0659, 0.0153, -0.0153],
[ 0.0185, -0.0217, 0.0014],
[-0.0396, -0.0495, -0.0281]
]
self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu").transpose(0, 1) self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu").transpose(0, 1)
self.latent_rgb_factors_bias = None self.latent_rgb_factors_bias = [-0.0940, -0.1418, -0.1453]
# if latent_rgb_factors_bias is not None: # if latent_rgb_factors_bias is not None:
# self.latent_rgb_factors_bias = torch.tensor(latent_rgb_factors_bias, device="cpu") # self.latent_rgb_factors_bias = torch.tensor(latent_rgb_factors_bias, device="cpu")

View File

@ -315,6 +315,16 @@ class MochiVAELoader:
dtype=dtype, dtype=dtype,
) )
vae_sd = load_torch_file(vae_path) vae_sd = load_torch_file(vae_path)
#support loading from combined VAE
if vae_sd.get("decoder.blocks.0.0.bias") is not None:
new_vae_sd = {}
for k, v in vae_sd.items():
if k.startswith("decoder."):
new_k = k[len("decoder."):]
new_vae_sd[new_k] = v
vae_sd = new_vae_sd
if is_accelerate_available: if is_accelerate_available:
for name, param in vae.named_parameters(): for name, param in vae.named_parameters():
set_module_tensor_to_device(vae, name, dtype=dtype, device=offload_device, value=vae_sd[name]) set_module_tensor_to_device(vae, name, dtype=dtype, device=offload_device, value=vae_sd[name])
@ -383,6 +393,16 @@ class MochiVAEEncoderLoader:
) )
encoder_sd = load_torch_file(vae_path) encoder_sd = load_torch_file(vae_path)
#support loading from combined VAE
if encoder_sd.get("encoder.layers.0.bias") is not None:
new_vae_sd = {}
for k, v in encoder_sd.items():
if k.startswith("encoder."):
new_k = k[len("encoder."):]
new_vae_sd[new_k] = v
encoder_sd = new_vae_sd
if is_accelerate_available: if is_accelerate_available:
for name, param in encoder.named_parameters(): for name, param in encoder.named_parameters():
set_module_tensor_to_device(encoder, name, dtype=dtype, device=offload_device, value=encoder_sd[name]) set_module_tensor_to_device(encoder, name, dtype=dtype, device=offload_device, value=encoder_sd[name])
@ -850,8 +870,21 @@ class MochiLatentPreview:
device = mm.get_torch_device() device = mm.get_torch_device()
offload_device = mm.unet_offload_device() offload_device = mm.unet_offload_device()
latent_rgb_factors = [[0.1236769792512748, 0.11775175335219157, -0.17700629766423637], [-0.08504104329270078, 0.026605813147523694, -0.006843165704926019], [-0.17093308616366876, 0.027991854696200386, 0.14179146288816308], [-0.17179555328757623, 0.09844317368603078, 0.14470997015982784], [-0.16975067171668484, -0.10739852629856643, -0.1894254942909962], [-0.19315259266769888, -0.011029760569485209, -0.08519702054654255], [-0.08399895091432583, -0.0964246452052032, -0.033622359523655665], [0.08148916330842498, 0.027500645903400067, -0.06593099749891196], [0.0456603103902293, -0.17844808072462398, 0.04204775167149785], [0.001751626383204502, -0.030567890189647867, -0.022078082809772193], [0.05110631095056278, -0.0709677393548804, 0.08963683539504264], [0.010515800868829, -0.18382052841762514, -0.08554553339721907]] #latent_rgb_factors = [[0.1236769792512748, 0.11775175335219157, -0.17700629766423637], [-0.08504104329270078, 0.026605813147523694, -0.006843165704926019], [-0.17093308616366876, 0.027991854696200386, 0.14179146288816308], [-0.17179555328757623, 0.09844317368603078, 0.14470997015982784], [-0.16975067171668484, -0.10739852629856643, -0.1894254942909962], [-0.19315259266769888, -0.011029760569485209, -0.08519702054654255], [-0.08399895091432583, -0.0964246452052032, -0.033622359523655665], [0.08148916330842498, 0.027500645903400067, -0.06593099749891196], [0.0456603103902293, -0.17844808072462398, 0.04204775167149785], [0.001751626383204502, -0.030567890189647867, -0.022078082809772193], [0.05110631095056278, -0.0709677393548804, 0.08963683539504264], [0.010515800868829, -0.18382052841762514, -0.08554553339721907]]
latent_rgb_factors =[
[-0.0069, -0.0045, 0.0018],
[ 0.0154, -0.0692, -0.0274],
[ 0.0333, 0.0019, 0.0206],
[-0.1390, 0.0628, 0.1678],
[-0.0725, 0.0134, -0.1898],
[ 0.0074, -0.0270, -0.0209],
[-0.0176, -0.0277, -0.0221],
[ 0.5294, 0.5204, 0.3852],
[-0.0326, -0.0446, -0.0143],
[-0.0659, 0.0153, -0.0153],
[ 0.0185, -0.0217, 0.0014],
[-0.0396, -0.0495, -0.0281]
]
# import random # import random
# random.seed(seed) # random.seed(seed)
# latent_rgb_factors = [[random.uniform(min_val, max_val) for _ in range(3)] for _ in range(12)] # latent_rgb_factors = [[random.uniform(min_val, max_val) for _ in range(3)] for _ in range(12)]
@ -859,7 +892,7 @@ class MochiLatentPreview:
# print(latent_rgb_factors) # print(latent_rgb_factors)
latent_rgb_factors_bias = [0,0,0] latent_rgb_factors_bias = [-0.0940, -0.1418, -0.1453]
latent_rgb_factors = torch.tensor(latent_rgb_factors, device=latents.device, dtype=latents.dtype).transpose(0, 1) latent_rgb_factors = torch.tensor(latent_rgb_factors, device=latents.device, dtype=latents.dtype).transpose(0, 1)
latent_rgb_factors_bias = torch.tensor(latent_rgb_factors_bias, device=latents.device, dtype=latents.dtype) latent_rgb_factors_bias = torch.tensor(latent_rgb_factors_bias, device=latents.device, dtype=latents.dtype)