diff --git a/latent_preview.py b/latent_preview.py index b1d08ad..ce8db6f 100644 --- a/latent_preview.py +++ b/latent_preview.py @@ -48,13 +48,11 @@ class Latent2RGBPreviewer(LatentPreviewer): ] self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu").transpose(0, 1) self.latent_rgb_factors_bias = [-0.0940, -0.1418, -0.1453] - # if latent_rgb_factors_bias is not None: - # self.latent_rgb_factors_bias = torch.tensor(latent_rgb_factors_bias, device="cpu") def decode_latent_to_preview(self, x0): self.latent_rgb_factors = self.latent_rgb_factors.to(dtype=x0.dtype, device=x0.device) if self.latent_rgb_factors_bias is not None: - self.latent_rgb_factors_bias = self.latent_rgb_factors_bias.to(dtype=x0.dtype, device=x0.device) + self.latent_rgb_factors_bias = torch.tensor(self.latent_rgb_factors_bias, device="cpu").to(dtype=x0.dtype, device=x0.device) latent_image = torch.nn.functional.linear(x0[0].permute(1, 2, 0), self.latent_rgb_factors, bias=self.latent_rgb_factors_bias)