Cleanup, fix seed gen, better warnings for decoder

This commit is contained in:
kijai 2024-10-24 12:45:01 +03:00
parent f714748ad4
commit 257c526125
2 changed files with 83 additions and 123 deletions

View File

@ -1,17 +1,14 @@
import json
import random
from typing import Dict, List
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
from torch import nn
from .dit.joint_model.context_parallel import get_cp_rank_size
from tqdm import tqdm
from comfy.utils import ProgressBar, load_torch_file
import comfy.model_management as mm
from contextlib import nullcontext
try:
@ -86,29 +83,6 @@ def compute_packed_indices(
"valid_token_indices_kv": valid_token_indices,
}
def shift_sigma(
sigma: np.ndarray,
shift: float,
):
"""Shift noise standard deviation toward higher values.
Useful for training a model at high resolutions,
or sampling more finely at high noise levels.
Equivalent to:
sigma_shift = shift / (shift + 1 / sigma - 1)
except for sigma = 0.
Args:
sigma: noise standard deviation in [0, 1]
shift: shift factor >= 1.
For shift > 1, shifts sigma to higher values.
For shift = 1, identity function.
"""
return shift * sigma / (shift * sigma + 1 - sigma)
class T2VSynthMochiModel:
def __init__(
self,
@ -239,23 +213,16 @@ class T2VSynthMochiModel:
@torch.inference_mode(mode=True)
def run(self, args, stream_results):
random.seed(args["seed"])
np.random.seed(args["seed"])
torch.manual_seed(args["seed"])
torch.cuda.manual_seed(args["seed"])
generator = torch.Generator(device=self.device)
generator.manual_seed(args["seed"])
# assert (
# len(args["prompt"]) == 1
# ), f"Expected exactly one prompt, got {len(args['prompt'])}"
#prompt = args["prompt"][0]
#neg_prompt = args["negative_prompt"][0] if len(args["negative_prompt"]) else ""
B = 1
w = args["width"]
h = args["height"]
t = args["num_frames"]
num_frames = args["num_frames"]
height = args["height"]
width = args["width"]
batch_cfg = args["mochi_args"]["batch_cfg"]
sample_steps = args["mochi_args"]["num_inference_steps"]
cfg_schedule = args["mochi_args"].get("cfg_schedule")
@ -267,7 +234,7 @@ class T2VSynthMochiModel:
assert (
len(sigma_schedule) == sample_steps + 1
), f"sigma_schedule must have length {sample_steps + 1}, got {len(sigma_schedule)}"
assert (t - 1) % 6 == 0, f"t - 1 must be divisible by 6, got {t - 1}"
assert (num_frames - 1) % 6 == 0, f"t - 1 must be divisible by 6, got {num_frames - 1}"
# if batch_cfg:
# sample_batched = self.get_conditioning(
@ -277,15 +244,19 @@ class T2VSynthMochiModel:
# sample = self.get_conditioning([prompt], zero_last_n_prompts=0)
# sample_null = self.get_conditioning([neg_prompt] * B, zero_last_n_prompts=B if neg_prompt == "" else 0)
# create z
spatial_downsample = 8
temporal_downsample = 6
latent_t = (t - 1) // temporal_downsample + 1
latent_w, latent_h = w // spatial_downsample, h // spatial_downsample
latent_dims = dict(lT=latent_t, lW=latent_w, lH=latent_h)
in_channels = 12
B = 1
C = in_channels
T = (num_frames - 1) // temporal_downsample + 1
H = height // spatial_downsample
W = width // spatial_downsample
latent_dims = dict(lT=T, lW=W, lH=H)
z = torch.randn(
(B, in_channels, latent_t, latent_h, latent_w),
(B, C, T, H, W),
device=self.device,
generator=generator,
dtype=torch.float32,
@ -307,22 +278,6 @@ class T2VSynthMochiModel:
"y_feat": [args["negative_embeds"]["embeds"].to(self.device)]
}
# print(sample["y_mask"])
# print(type(sample["y_mask"]))
# print(sample["y_mask"][0].shape)
# print(sample["y_feat"])
# print(type(sample["y_feat"]))
# print(sample["y_feat"][0].shape)
# print(sample_null["y_mask"])
# print(type(sample_null["y_mask"]))
# print(sample_null["y_mask"][0].shape)
# print(sample_null["y_feat"])
# print(type(sample_null["y_feat"]))
# print(sample_null["y_feat"][0].shape)
sample["packed_indices"] = self.get_packed_indices(
sample["y_mask"], **latent_dims
)
@ -331,8 +286,6 @@ class T2VSynthMochiModel:
)
def model_fn(*, z, sigma, cfg_scale):
#print("z", z.dtype, z.device)
#print("sigma", sigma.dtype, sigma.device)
self.dit.to(self.device)
# if batch_cfg:
# with torch.autocast("cuda", dtype=torch.bfloat16):
@ -341,7 +294,7 @@ class T2VSynthMochiModel:
#else:
nonlocal sample, sample_null
with torch.autocast("cuda", dtype=torch.bfloat16):
with torch.autocast(mm.get_autocast_device(self.device), dtype=torch.bfloat16):
out_cond = self.dit(z, sigma, **sample)
out_uncond = self.dit(z, sigma, **sample_null)
assert out_cond.shape == out_uncond.shape
@ -364,8 +317,6 @@ class T2VSynthMochiModel:
pred = pred.to(z)
output_cond = output_cond.to(z)
#if stream_results:
# yield i / sample_steps, None, False
z = z + dsigma * pred
comfy_pbar.update(1)

123
nodes.py
View File

@ -248,7 +248,7 @@ class MochiDecode:
"samples": ("LATENT", ),
"enable_vae_tiling": ("BOOLEAN", {"default": False, "tooltip": "Drastically reduces memory use but may introduce seams"}),
"auto_tile_size": ("BOOLEAN", {"default": True, "tooltip": "Auto size based on height and width, default is half the size"}),
"frame_batch_size": ("INT", {"default": 6, "min": 1, "max": 64, "step": 1}),
"frame_batch_size": ("INT", {"default": 6, "min": 1, "max": 64, "step": 1, "tooltip": "Number of frames in latent space (downscale factor is 6) to decode at once"}),
"tile_sample_min_height": ("INT", {"default": 240, "min": 16, "max": 2048, "step": 8, "tooltip": "Minimum tile height, default is half the height"}),
"tile_sample_min_width": ("INT", {"default": 424, "min": 16, "max": 2048, "step": 8, "tooltip": "Minimum tile width, default is half the width"}),
"tile_overlap_factor_height": ("FLOAT", {"default": 0.1666, "min": 0.0, "max": 1.0, "step": 0.001}),
@ -268,6 +268,17 @@ class MochiDecode:
samples = samples["samples"]
samples = samples.to(torch.bfloat16).to(device)
B, C, T, H, W = samples.shape
self.tile_overlap_factor_height = tile_overlap_factor_height if not auto_tile_size else 1 / 6
self.tile_overlap_factor_width = tile_overlap_factor_width if not auto_tile_size else 1 / 5
self.tile_sample_min_height = tile_sample_min_height if not auto_tile_size else H // 2 * 8
self.tile_sample_min_width = tile_sample_min_width if not auto_tile_size else W // 2 * 8
self.tile_latent_min_height = int(self.tile_sample_min_height / 8)
self.tile_latent_min_width = int(self.tile_sample_min_width / 8)
def blend_v(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
@ -284,70 +295,68 @@ class MochiDecode:
x / blend_extent
)
return b
def decode_tiled(samples):
overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height))
overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width))
blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height)
blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width)
row_limit_height = self.tile_sample_min_height - blend_extent_height
row_limit_width = self.tile_sample_min_width - blend_extent_width
self.tile_overlap_factor_height = tile_overlap_factor_height if not auto_tile_size else 1 / 6
self.tile_overlap_factor_width = tile_overlap_factor_width if not auto_tile_size else 1 / 5
# Split z into overlapping tiles and decode them separately.
# The tiles have an overlap to avoid seams between tiles.
comfy_pbar = ProgressBar(len(range(0, H, overlap_height)))
rows = []
for i in tqdm(range(0, H, overlap_height), desc="Processing rows"):
row = []
for j in tqdm(range(0, W, overlap_width), desc="Processing columns", leave=False):
time = []
for k in tqdm(range(T // frame_batch_size), desc="Processing frames", leave=False):
remaining_frames = T % frame_batch_size
start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
end_frame = frame_batch_size * (k + 1) + remaining_frames
tile = samples[
:,
:,
start_frame:end_frame,
i : i + self.tile_latent_min_height,
j : j + self.tile_latent_min_width,
]
tile = vae(tile)
time.append(tile)
row.append(torch.cat(time, dim=2))
rows.append(row)
comfy_pbar.update(1)
self.tile_sample_min_height = tile_sample_min_height if not auto_tile_size else samples.shape[3] // 2 * 8
self.tile_sample_min_width = tile_sample_min_width if not auto_tile_size else samples.shape[4] // 2 * 8
result_rows = []
for i, row in enumerate(tqdm(rows, desc="Blending rows")):
result_row = []
for j, tile in enumerate(tqdm(row, desc="Blending tiles", leave=False)):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = blend_v(rows[i - 1][j], tile, blend_extent_height)
if j > 0:
tile = blend_h(row[j - 1], tile, blend_extent_width)
result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
result_rows.append(torch.cat(result_row, dim=4))
self.tile_latent_min_height = int(self.tile_sample_min_height / 8)
self.tile_latent_min_width = int(self.tile_sample_min_width / 8)
return torch.cat(result_rows, dim=3)
vae.to(device)
with torch.amp.autocast("cuda", dtype=torch.bfloat16):
if not enable_vae_tiling:
with torch.autocast(mm.get_autocast_device(device), dtype=torch.bfloat16):
if enable_vae_tiling and frame_batch_size > T:
logging.warning(f"Frame batch size is larger than the number of samples ({T}), disabling tiling")
samples = vae(samples)
elif not enable_vae_tiling:
logging.warning("Attempting to decode without tiling, very memory intensive")
samples = vae(samples)
else:
batch_size, num_channels, num_frames, height, width = samples.shape
overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height))
overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width))
blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height)
blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width)
row_limit_height = self.tile_sample_min_height - blend_extent_height
row_limit_width = self.tile_sample_min_width - blend_extent_width
# Split z into overlapping tiles and decode them separately.
# The tiles have an overlap to avoid seams between tiles.
comfy_pbar = ProgressBar(len(range(0, height, overlap_height)))
rows = []
for i in tqdm(range(0, height, overlap_height), desc="Processing rows"):
row = []
for j in tqdm(range(0, width, overlap_width), desc="Processing columns", leave=False):
time = []
for k in tqdm(range(num_frames // frame_batch_size), desc="Processing frames", leave=False):
remaining_frames = num_frames % frame_batch_size
start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
end_frame = frame_batch_size * (k + 1) + remaining_frames
tile = samples[
:,
:,
start_frame:end_frame,
i : i + self.tile_latent_min_height,
j : j + self.tile_latent_min_width,
]
tile = vae(tile)
time.append(tile)
row.append(torch.cat(time, dim=2))
rows.append(row)
comfy_pbar.update(1)
result_rows = []
for i, row in enumerate(tqdm(rows, desc="Blending rows")):
result_row = []
for j, tile in enumerate(tqdm(row, desc="Blending tiles", leave=False)):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = blend_v(rows[i - 1][j], tile, blend_extent_height)
if j > 0:
tile = blend_h(row[j - 1], tile, blend_extent_width)
result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
result_rows.append(torch.cat(result_row, dim=4))
samples = torch.cat(result_rows, dim=3)
logging.info("Decoding with tiling")
samples = decode_tiled(samples)
vae.to(offload_device)
#print("samples", samples.shape, samples.dtype, samples.device)
samples = samples.float()
samples = (samples + 1.0) / 2.0