36 lines
1.1 KiB
Python
36 lines
1.1 KiB
Python
"""Container for latent space posterior."""
|
|
|
|
from typing import Optional

import torch
|
|
|
|
|
|
class LatentDistribution:
    """Gaussian latent posterior described by a per-element mean and log-variance."""

    def __init__(self, mean: torch.Tensor, logvar: torch.Tensor):
        """Initialize latent distribution.

        Args:
            mean: Mean of the distribution. Shape: [B, C, T, H, W].
            logvar: Logarithm of variance of the distribution. Shape: [B, C, T, H, W].

        Raises:
            AssertionError: If ``mean`` and ``logvar`` do not have the same shape.
        """
        assert mean.shape == logvar.shape, (
            f"mean shape {tuple(mean.shape)} does not match "
            f"logvar shape {tuple(logvar.shape)}"
        )
        self.mean = mean
        self.logvar = logvar

    def sample(
        self,
        temperature: float = 1.0,
        generator: Optional[torch.Generator] = None,
        noise: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Draw a sample via ``mean + noise * exp(0.5 * logvar)``.

        Args:
            temperature: Only ``0.0`` and ``1.0`` are supported. ``0.0`` returns
                the mean tensor itself (no noise is drawn); ``1.0`` draws an
                unscaled Gaussian sample.
            generator: Optional RNG used when ``noise`` is not supplied.
            noise: Optional pre-drawn noise. It must live on the same device as
                the mean and is cast to the mean's dtype before use.

        Returns:
            The mean tensor when ``temperature == 0.0``, otherwise the
            reparameterized sample.

        Raises:
            NotImplementedError: If ``temperature`` is neither ``0.0`` nor ``1.0``.
            AssertionError: If ``noise`` is on a different device than the mean.
        """
        if temperature == 0.0:
            return self.mean

        # Reject unsupported temperatures *before* drawing noise, so a failing
        # call neither allocates a full-size tensor nor advances `generator`.
        if temperature != 1.0:
            raise NotImplementedError(f"Temperature {temperature} is not supported.")

        if noise is None:
            noise = torch.randn(
                self.mean.shape,
                device=self.mean.device,
                dtype=self.mean.dtype,
                generator=generator,
            )
        else:
            assert noise.device == self.mean.device, (
                f"noise is on {noise.device} but mean is on {self.mean.device}"
            )
            noise = noise.to(self.mean.dtype)

        # Just Gaussian sample with no scaling of variance.
        std = torch.exp(self.logvar * 0.5)
        return noise * std + self.mean

    def mode(self) -> torch.Tensor:
        """Return the mode of the distribution, which is its mean."""
        return self.mean
|