# Adapted from OpenSora

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# OpenSora: https://github.com/hpcaitech/Open-Sora
# --------------------------------------------------------

import torch
import torch.distributed as dist
from einops import rearrange
from torch.distributions import LogisticNormal
from tqdm import tqdm


def _extract_into_tensor(arr, timesteps, broadcast_shape):
    """
    Extract values from a 1-D numpy array for a batch of indices.

    :param arr: the 1-D numpy array.
    :param timesteps: a tensor of indices into the array to extract.
    :param broadcast_shape: a larger shape of K dimensions with the batch
                            dimension equal to the length of timesteps.
    :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
    """
    res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
    while len(res.shape) < len(broadcast_shape):
        res = res[..., None]
    return res + torch.zeros(broadcast_shape, device=timesteps.device)
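
# Example (illustrative shapes): with `arr` of shape (1000,), `timesteps` of
# shape (B,), and broadcast_shape = (B, C, T, H, W), each sample's scalar is
# broadcast up to (B, C, T, H, W), so it can weight a loss tensor elementwise.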


def mean_flat(tensor: torch.Tensor, mask=None):
    """
    Take the mean over all non-batch dimensions.
    """
    if mask is None:
        return tensor.mean(dim=list(range(1, len(tensor.shape))))
    else:
        assert tensor.dim() == 5
        assert tensor.shape[2] == mask.shape[1]
        tensor = rearrange(tensor, "b c t h w -> b t (c h w)")
        denom = mask.sum(dim=1) * tensor.shape[-1]
        loss = (tensor * mask.unsqueeze(2)).sum(dim=1).sum(dim=1) / denom
        return loss
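
# Example (illustrative shapes): for a (B, C, T, H, W) tensor and a (B, T)
# temporal mask, frames with mask == 0 contribute nothing to the sum, and the
# mean runs over only the kept frames' C * H * W elements.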


def timestep_transform(
    t,
    model_kwargs,
    base_resolution=512 * 512,
    base_num_frames=1,
    scale=1.0,
    num_timesteps=1,
):
    t = t / num_timesteps
    resolution = model_kwargs["height"] * model_kwargs["width"]
    ratio_space = (resolution / base_resolution).sqrt()
    # NOTE: currently, we do not take fps into account
    # NOTE: temporal_reduction is hardcoded; this should be equal to the
    # temporal reduction factor of the VAE
    if model_kwargs["num_frames"][0] == 1:
        num_frames = torch.ones_like(model_kwargs["num_frames"])
    else:
        num_frames = model_kwargs["num_frames"] // 17 * 5
    ratio_time = (num_frames / base_num_frames).sqrt()

    ratio = ratio_space * ratio_time * scale
    new_t = ratio * t / (1 + (ratio - 1) * t)

    new_t = new_t * num_timesteps
    return new_t
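
# Worked example (illustrative numbers): for a single 1024x1024 image with the
# default base_resolution = 512 * 512 and scale = 1.0, ratio_space = 2 and
# ratio_time = 1, so t = 0.5 maps to new_t = 2 * 0.5 / (1 + (2 - 1) * 0.5) ≈ 0.667;
# larger inputs are thus shifted toward noisier timesteps.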


class RFlowScheduler:
    def __init__(
        self,
        num_timesteps=1000,
        num_sampling_steps=10,
        use_discrete_timesteps=False,
        sample_method="uniform",
        loc=0.0,
        scale=1.0,
        use_timestep_transform=False,
        transform_scale=1.0,
    ):
        self.num_timesteps = num_timesteps
        self.num_sampling_steps = num_sampling_steps
        self.use_discrete_timesteps = use_discrete_timesteps

        # sample method
        assert sample_method in ["uniform", "logit-normal"]
        assert (
            sample_method == "uniform" or not use_discrete_timesteps
        ), "Only uniform sampling is supported for discrete timesteps"
        self.sample_method = sample_method
        if sample_method == "logit-normal":
            self.distribution = LogisticNormal(torch.tensor([loc]), torch.tensor([scale]))
            self.sample_t = lambda x: self.distribution.sample((x.shape[0],))[:, 0].to(x.device)

        # timestep transform
        self.use_timestep_transform = use_timestep_transform
        self.transform_scale = transform_scale
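
    # With sample_method="logit-normal", sample_t draws sigmoid(z) for
    # z ~ Normal(loc, scale) (the first stick-breaking coordinate of
    # LogisticNormal), concentrating training timesteps around the middle of
    # [0, 1] instead of sampling them uniformly.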

    def training_losses(self, model, x_start, model_kwargs=None, noise=None, mask=None, weights=None, t=None):
        """
        Compute training losses for a single timestep.
        Arguments format copied from opensora/schedulers/iddpm/gaussian_diffusion.py/training_losses.
        Note: t is an int tensor and should be rescaled from [0, num_timesteps - 1] to [1, 0].
        """
        if t is None:
            if self.use_discrete_timesteps:
                t = torch.randint(0, self.num_timesteps, (x_start.shape[0],), device=x_start.device)
            elif self.sample_method == "uniform":
                t = torch.rand((x_start.shape[0],), device=x_start.device) * self.num_timesteps
            elif self.sample_method == "logit-normal":
                t = self.sample_t(x_start) * self.num_timesteps

            if self.use_timestep_transform:
                t = timestep_transform(t, model_kwargs, scale=self.transform_scale, num_timesteps=self.num_timesteps)

        if model_kwargs is None:
            model_kwargs = {}
        if noise is None:
            noise = torch.randn_like(x_start)
        assert noise.shape == x_start.shape

        x_t = self.add_noise(x_start, noise, t)
        if mask is not None:
            # frames outside the temporal mask are kept clean (noised at t = 0)
            t0 = torch.zeros_like(t)
            x_t0 = self.add_noise(x_start, noise, t0)
            x_t = torch.where(mask[:, None, :, None, None], x_t, x_t0)

        terms = {}
        model_output = model(x_t, t, **model_kwargs)
        # the model outputs 2x channels; the first half is the velocity prediction
        velocity_pred = model_output.chunk(2, dim=1)[0]
        if weights is None:
            loss = mean_flat((velocity_pred - (x_start - noise)).pow(2), mask=mask)
        else:
            weight = _extract_into_tensor(weights, t, x_start.shape)
            loss = mean_flat(weight * (velocity_pred - (x_start - noise)).pow(2), mask=mask)
        terms["loss"] = loss

        return terms
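
    # The target above follows from add_noise below: with s = 1 - t / num_timesteps,
    # x_t = s * x_0 + (1 - s) * noise, so the constant velocity along the
    # interpolation path is dx_t/ds = x_0 - noise, i.e. the (x_start - noise)
    # term regressed in the loss.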

    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.IntTensor,
    ) -> torch.FloatTensor:
        """
        Compatible with diffusers' add_noise().
        """
        timepoints = timesteps.float() / self.num_timesteps
        timepoints = 1 - timepoints  # [1, 1/1000]

        # timepoints: (bsz,)  noise: (bsz, 4, frame, h, w)
        # expand timepoints to the noise shape
        timepoints = timepoints.unsqueeze(1).unsqueeze(1).unsqueeze(1).unsqueeze(1)
        timepoints = timepoints.repeat(1, noise.shape[1], noise.shape[2], noise.shape[3], noise.shape[4])

        return timepoints * original_samples + (1 - timepoints) * noise
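
# Worked example for add_noise (illustrative numbers): with num_timesteps = 1000,
# t = 250 gives timepoints = 0.75 and x_t = 0.75 * x_0 + 0.25 * noise, while
# t = 1000 gives timepoints = 0, i.e. pure noise.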


class RFLOW:
    def __init__(
        self,
        num_sampling_steps=10,
        num_timesteps=1000,
        cfg_scale=4.0,
        use_discrete_timesteps=False,
        use_timestep_transform=False,
        **kwargs,
    ):
        self.num_sampling_steps = num_sampling_steps
        self.num_timesteps = num_timesteps
        self.cfg_scale = cfg_scale
        self.use_discrete_timesteps = use_discrete_timesteps
        self.use_timestep_transform = use_timestep_transform

        self.scheduler = RFlowScheduler(
            num_timesteps=num_timesteps,
            num_sampling_steps=num_sampling_steps,
            use_discrete_timesteps=use_discrete_timesteps,
            use_timestep_transform=use_timestep_transform,
            **kwargs,
        )

    def sample(
        self,
        model,
        z,
        model_args,
        y_null,
        device,
        mask=None,
        guidance_scale=None,
        progress=True,
        verbose=False,
    ):
        # if no guidance scale is provided, fall back to the default set at initialization
        if guidance_scale is None:
            guidance_scale = self.cfg_scale

        # text encoding: concatenate conditional and null embeddings for classifier-free guidance
        model_args["y"] = torch.cat([model_args["y"], y_null], 0)

        # prepare timesteps
        timesteps = [(1.0 - i / self.num_sampling_steps) * self.num_timesteps for i in range(self.num_sampling_steps)]
        if self.use_discrete_timesteps:
            timesteps = [int(round(t)) for t in timesteps]
        timesteps = [torch.tensor([t] * z.shape[0], device=device) for t in timesteps]
        if self.use_timestep_transform:
            timesteps = [timestep_transform(t, model_args, num_timesteps=self.num_timesteps) for t in timesteps]

        if mask is not None:
            noise_added = torch.zeros_like(mask, dtype=torch.bool)
            noise_added = noise_added | (mask == 1)

        progress_wrap = tqdm if progress and dist.get_rank() == 0 else (lambda x: x)

        dtype = model.x_embedder.proj.weight.dtype
        # NOTE: .item() assumes a batch size of 1 per rank
        all_timesteps = [int(t.to(dtype).item()) for t in timesteps]
        for i, t in progress_wrap(list(enumerate(timesteps))):
            # mask for adding noise
            if mask is not None:
                mask_t = mask * self.num_timesteps
                x0 = z.clone()
                x_noise = self.scheduler.add_noise(x0, torch.randn_like(x0), t)

                mask_t_upper = mask_t >= t.unsqueeze(1)
                model_args["x_mask"] = mask_t_upper.repeat(2, 1)
                mask_add_noise = mask_t_upper & ~noise_added

                z = torch.where(mask_add_noise[:, None, :, None, None], x_noise, x0)
                noise_added = mask_t_upper

            # classifier-free guidance
            z_in = torch.cat([z, z], 0)
            t = torch.cat([t, t], 0)

            # pred = model(z_in, t, **model_args).chunk(2, dim=1)[0]
            output = model(z_in, t, all_timesteps, **model_args)

            pred = output.chunk(2, dim=1)[0]
            pred_cond, pred_uncond = pred.chunk(2, dim=0)
            v_pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond)

            # update z
            dt = timesteps[i] - timesteps[i + 1] if i < len(timesteps) - 1 else timesteps[i]
            dt = dt / self.num_timesteps
            z = z + v_pred * dt[:, None, None, None, None]

            if mask is not None:
                z = torch.where(mask_t_upper[:, None, :, None, None], z, x0)

        return z
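
    # Each loop iteration above is one explicit Euler step of the rectified-flow
    # ODE, z <- z + v_pred * dt, where dt is the gap to the next timestep
    # normalized by num_timesteps; the last step uses dt = timesteps[-1] so the
    # trajectory is integrated all the way down to t = 0.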

    def training_losses(self, model, x_start, model_kwargs=None, noise=None, mask=None, weights=None, t=None):
        return self.scheduler.training_losses(model, x_start, model_kwargs, noise, mask, weights, t)
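
# Usage sketch (hypothetical names and shapes; `model` is assumed to be a
# DiT-style video model whose forward accepts (x, t, all_timesteps, **kwargs)
# and returns a tensor with doubled channels, and torch.distributed is assumed
# to be initialized, since sample() calls dist.get_rank()):
#
#   scheduler = RFLOW(num_sampling_steps=30, num_timesteps=1000, cfg_scale=7.0)
#   z = torch.randn(1, 4, 15, 32, 32, device=device)            # latent noise
#   model_args = {
#       "y": y_embed,                                           # text condition
#       "height": torch.tensor([512.0], device=device),
#       "width": torch.tensor([512.0], device=device),
#       "num_frames": torch.tensor([51.0], device=device),
#   }
#   samples = scheduler.sample(model, z, model_args, y_null_embed, device)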