support loading VAE decoder/encoder from the combined VAE model

parent d3287d61b7
commit 8ebae92057
@@ -31,10 +31,23 @@ class LatentPreviewer:

 class Latent2RGBPreviewer(LatentPreviewer):
     def __init__(self):
         #latent_rgb_factors = [[0.05389399697934166, 0.025018778505575393, -0.009193515248318657], [0.02318250640590553, -0.026987363837713156, 0.040172639061236956], [0.046035451343323666, -0.02039565868920197, 0.01275569344290342], [-0.015559161155025095, 0.051403973219861246, 0.03179031307996347], [-0.02766167769640129, 0.03749545161530447, 0.003335141009473408], [0.05824598730479011, 0.021744367381243884, -0.01578925627951616], [0.05260929401500947, 0.0560165014956886, -0.027477296572565126], [0.018513891242931686, 0.041961785217662514, 0.004490763489747966], [0.024063060899760215, 0.065082853069653, 0.044343437673514896], [0.05250992323006226, 0.04361117432588933, 0.01030076055524387], [0.0038921710021782366, -0.025299228133723792, 0.019370764014574535], [-0.00011950534333568519, 0.06549370069727675, -0.03436712163379723], [-0.026020578032683626, -0.013341758571090847, -0.009119046570271953], [0.024412451175602937, 0.030135064560817174, -0.008355486384198006], [0.04002209845752687, -0.017341304390739463, 0.02818338690302971], [-0.032575108695213684, -0.009588338926775117, -0.03077312160940468]]
-        latent_rgb_factors = [[0.1236769792512748, 0.11775175335219157, -0.17700629766423637], [-0.08504104329270078, 0.026605813147523694, -0.006843165704926019], [-0.17093308616366876, 0.027991854696200386, 0.14179146288816308], [-0.17179555328757623, 0.09844317368603078, 0.14470997015982784], [-0.16975067171668484, -0.10739852629856643, -0.1894254942909962], [-0.19315259266769888, -0.011029760569485209, -0.08519702054654255], [-0.08399895091432583, -0.0964246452052032, -0.033622359523655665], [0.08148916330842498, 0.027500645903400067, -0.06593099749891196], [0.0456603103902293, -0.17844808072462398, 0.04204775167149785], [0.001751626383204502, -0.030567890189647867, -0.022078082809772193], [0.05110631095056278, -0.0709677393548804, 0.08963683539504264], [0.010515800868829, -0.18382052841762514, -0.08554553339721907]]
+        #latent_rgb_factors = [[0.1236769792512748, 0.11775175335219157, -0.17700629766423637], [-0.08504104329270078, 0.026605813147523694, -0.006843165704926019], [-0.17093308616366876, 0.027991854696200386, 0.14179146288816308], [-0.17179555328757623, 0.09844317368603078, 0.14470997015982784], [-0.16975067171668484, -0.10739852629856643, -0.1894254942909962], [-0.19315259266769888, -0.011029760569485209, -0.08519702054654255], [-0.08399895091432583, -0.0964246452052032, -0.033622359523655665], [0.08148916330842498, 0.027500645903400067, -0.06593099749891196], [0.0456603103902293, -0.17844808072462398, 0.04204775167149785], [0.001751626383204502, -0.030567890189647867, -0.022078082809772193], [0.05110631095056278, -0.0709677393548804, 0.08963683539504264], [0.010515800868829, -0.18382052841762514, -0.08554553339721907]]
+        latent_rgb_factors = [
+            [-0.0069, -0.0045, 0.0018],
+            [ 0.0154, -0.0692, -0.0274],
+            [ 0.0333, 0.0019, 0.0206],
+            [-0.1390, 0.0628, 0.1678],
+            [-0.0725, 0.0134, -0.1898],
+            [ 0.0074, -0.0270, -0.0209],
+            [-0.0176, -0.0277, -0.0221],
+            [ 0.5294, 0.5204, 0.3852],
+            [-0.0326, -0.0446, -0.0143],
+            [-0.0659, 0.0153, -0.0153],
+            [ 0.0185, -0.0217, 0.0014],
+            [-0.0396, -0.0495, -0.0281]
+        ]
         self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu").transpose(0, 1)
-        self.latent_rgb_factors_bias = None
+        self.latent_rgb_factors_bias = [-0.0940, -0.1418, -0.1453]
         # if latent_rgb_factors_bias is not None:
         #     self.latent_rgb_factors_bias = torch.tensor(latent_rgb_factors_bias, device="cpu")
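The new table is one [r, g, b] row per latent channel, so the preview is a linear map from the 12 Mochi latent channels to RGB plus a constant per-channel bias. A minimal sketch of that computation for a single frame; latent_to_rgb, the random stand-in factors, and the final scaling into [0, 1] are illustrative, not the repository's API:

import torch

factors = torch.randn(12, 3)                      # stand-in for the 12x3 table above
bias = torch.tensor([-0.0940, -0.1418, -0.1453])  # per-RGB-channel offset from the diff

def latent_to_rgb(latent: torch.Tensor) -> torch.Tensor:
    # latent: [12, H, W] -> rgb: [H, W, 3]; each output pixel is a weighted
    # sum of the 12 latent channels, shifted by the per-channel bias
    rgb = torch.einsum("chw,cr->hwr", latent, factors) + bias
    return ((rgb + 1.0) / 2.0).clamp(0, 1)         # squeeze into [0, 1] for display

print(latent_to_rgb(torch.randn(12, 60, 106)).shape)  # torch.Size([60, 106, 3])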
nodes.py (39 changed lines)
@@ -315,6 +315,16 @@ class MochiVAELoader:
             dtype=dtype,
         )
         vae_sd = load_torch_file(vae_path)
+
+        #support loading from combined VAE
+        if vae_sd.get("decoder.blocks.0.0.bias") is not None:
+            new_vae_sd = {}
+            for k, v in vae_sd.items():
+                if k.startswith("decoder."):
+                    new_k = k[len("decoder."):]
+                    new_vae_sd[new_k] = v
+            vae_sd = new_vae_sd
+
         if is_accelerate_available:
             for name, param in vae.named_parameters():
                 set_module_tensor_to_device(vae, name, dtype=dtype, device=offload_device, value=vae_sd[name])
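The detection-and-rename step above can be exercised on its own: the loader probes for decoder.blocks.0.0.bias, a key that only exists while the weights are still namespaced under decoder., then rebuilds the state dict without that prefix, dropping the encoder entries along the way. A standalone sketch with toy tensors in place of real checkpoint weights (real combined checkpoints hold many more keys):

import torch

vae_sd = {
    "decoder.blocks.0.0.bias": torch.zeros(4),     # marker key the loader probes for
    "decoder.blocks.0.0.weight": torch.zeros(4, 4),
    "encoder.layers.0.bias": torch.zeros(4),       # encoder entries get filtered out
}

if vae_sd.get("decoder.blocks.0.0.bias") is not None:
    # dict-comprehension equivalent of the loop in the diff
    vae_sd = {k[len("decoder."):]: v
              for k, v in vae_sd.items() if k.startswith("decoder.")}

print(sorted(vae_sd))  # ['blocks.0.0.bias', 'blocks.0.0.weight']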
@@ -383,6 +393,16 @@ class MochiVAEEncoderLoader:
         )

         encoder_sd = load_torch_file(vae_path)
+
+        #support loading from combined VAE
+        if encoder_sd.get("encoder.layers.0.bias") is not None:
+            new_vae_sd = {}
+            for k, v in encoder_sd.items():
+                if k.startswith("encoder."):
+                    new_k = k[len("encoder."):]
+                    new_vae_sd[new_k] = v
+            encoder_sd = new_vae_sd
+
         if is_accelerate_available:
             for name, param in encoder.named_parameters():
                 set_module_tensor_to_device(encoder, name, dtype=dtype, device=offload_device, value=encoder_sd[name])
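The encoder branch is the same pattern with a different prefix and marker key, so the two loaders could share one helper. A hedged sketch of that refactor; strip_prefix is hypothetical, not something nodes.py defines:

def strip_prefix(sd: dict, prefix: str, marker: str) -> dict:
    """Strip `prefix` from matching keys if `marker` shows the dict is a combined checkpoint."""
    if sd.get(marker) is None:
        return sd  # already a standalone checkpoint; leave untouched
    return {k[len(prefix):]: v for k, v in sd.items() if k.startswith(prefix)}

# vae_sd = strip_prefix(vae_sd, "decoder.", "decoder.blocks.0.0.bias")
# encoder_sd = strip_prefix(encoder_sd, "encoder.", "encoder.layers.0.bias")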
@@ -850,8 +870,21 @@ class MochiLatentPreview:
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()

-        latent_rgb_factors = [[0.1236769792512748, 0.11775175335219157, -0.17700629766423637], [-0.08504104329270078, 0.026605813147523694, -0.006843165704926019], [-0.17093308616366876, 0.027991854696200386, 0.14179146288816308], [-0.17179555328757623, 0.09844317368603078, 0.14470997015982784], [-0.16975067171668484, -0.10739852629856643, -0.1894254942909962], [-0.19315259266769888, -0.011029760569485209, -0.08519702054654255], [-0.08399895091432583, -0.0964246452052032, -0.033622359523655665], [0.08148916330842498, 0.027500645903400067, -0.06593099749891196], [0.0456603103902293, -0.17844808072462398, 0.04204775167149785], [0.001751626383204502, -0.030567890189647867, -0.022078082809772193], [0.05110631095056278, -0.0709677393548804, 0.08963683539504264], [0.010515800868829, -0.18382052841762514, -0.08554553339721907]]
+        #latent_rgb_factors = [[0.1236769792512748, 0.11775175335219157, -0.17700629766423637], [-0.08504104329270078, 0.026605813147523694, -0.006843165704926019], [-0.17093308616366876, 0.027991854696200386, 0.14179146288816308], [-0.17179555328757623, 0.09844317368603078, 0.14470997015982784], [-0.16975067171668484, -0.10739852629856643, -0.1894254942909962], [-0.19315259266769888, -0.011029760569485209, -0.08519702054654255], [-0.08399895091432583, -0.0964246452052032, -0.033622359523655665], [0.08148916330842498, 0.027500645903400067, -0.06593099749891196], [0.0456603103902293, -0.17844808072462398, 0.04204775167149785], [0.001751626383204502, -0.030567890189647867, -0.022078082809772193], [0.05110631095056278, -0.0709677393548804, 0.08963683539504264], [0.010515800868829, -0.18382052841762514, -0.08554553339721907]]
+        latent_rgb_factors = [
+            [-0.0069, -0.0045, 0.0018],
+            [ 0.0154, -0.0692, -0.0274],
+            [ 0.0333, 0.0019, 0.0206],
+            [-0.1390, 0.0628, 0.1678],
+            [-0.0725, 0.0134, -0.1898],
+            [ 0.0074, -0.0270, -0.0209],
+            [-0.0176, -0.0277, -0.0221],
+            [ 0.5294, 0.5204, 0.3852],
+            [-0.0326, -0.0446, -0.0143],
+            [-0.0659, 0.0153, -0.0153],
+            [ 0.0185, -0.0217, 0.0014],
+            [-0.0396, -0.0495, -0.0281]
+        ]
         # import random
         # random.seed(seed)
         # latent_rgb_factors = [[random.uniform(min_val, max_val) for _ in range(3)] for _ in range(12)]
@@ -859,7 +892,7 @@ class MochiLatentPreview:
         # print(latent_rgb_factors)

-        latent_rgb_factors_bias = [0,0,0]
+        latent_rgb_factors_bias = [-0.0940, -0.1418, -0.1453]

         latent_rgb_factors = torch.tensor(latent_rgb_factors, device=latents.device, dtype=latents.dtype).transpose(0, 1)
         latent_rgb_factors_bias = torch.tensor(latent_rgb_factors_bias, device=latents.device, dtype=latents.dtype)
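Put together, the two tensor lines at the end of this hunk perform the same channel-to-RGB projection over a whole video latent, now with a non-zero bias. A sketch under the assumption that latents is [B, C, T, H, W] with C = 12 (the layout is an assumption; variable names follow the diff):

import torch

latents = torch.randn(1, 12, 4, 60, 106)                   # [B, C, T, H, W], assumed
latent_rgb_factors = torch.randn(12, 3).transpose(0, 1)    # [3, 12], as in the diff
latent_rgb_factors_bias = torch.tensor([-0.0940, -0.1418, -0.1453])

# collapse the 12 latent channels to 3 RGB channels for every frame and pixel,
# then shift each RGB channel by its bias
rgb = torch.einsum("rc,bcthw->brthw", latent_rgb_factors, latents)
rgb = rgb + latent_rgb_factors_bias[None, :, None, None, None]
print(rgb.shape)  # torch.Size([1, 3, 4, 60, 106])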