This commit is contained in:
kijai 2024-11-11 15:54:52 +02:00
parent 8ebae92057
commit e1bd05240a

View File

@ -48,13 +48,11 @@ class Latent2RGBPreviewer(LatentPreviewer):
]
self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu").transpose(0, 1)
self.latent_rgb_factors_bias = [-0.0940, -0.1418, -0.1453]
# if latent_rgb_factors_bias is not None:
# self.latent_rgb_factors_bias = torch.tensor(latent_rgb_factors_bias, device="cpu")
def decode_latent_to_preview(self, x0):
self.latent_rgb_factors = self.latent_rgb_factors.to(dtype=x0.dtype, device=x0.device)
if self.latent_rgb_factors_bias is not None:
self.latent_rgb_factors_bias = self.latent_rgb_factors_bias.to(dtype=x0.dtype, device=x0.device)
self.latent_rgb_factors_bias = torch.tensor(self.latent_rgb_factors_bias, device="cpu").to(dtype=x0.dtype, device=x0.device)
latent_image = torch.nn.functional.linear(x0[0].permute(1, 2, 0), self.latent_rgb_factors,
bias=self.latent_rgb_factors_bias)