Cleanup, fix seed gen, better warnings for decoder

kijai 2024-10-24 12:45:01 +03:00
parent f714748ad4
commit 257c526125
2 changed files with 83 additions and 123 deletions


@@ -1,17 +1,14 @@
 import json
-import random
 from typing import Dict, List
-import numpy as np
 import torch
-import torch.nn as nn
 import torch.nn.functional as F
 import torch.utils.data
+from torch import nn
 from .dit.joint_model.context_parallel import get_cp_rank_size
 from tqdm import tqdm
 from comfy.utils import ProgressBar, load_torch_file
+import comfy.model_management as mm
 from contextlib import nullcontext

 try:
@@ -86,29 +83,6 @@ def compute_packed_indices(
         "valid_token_indices_kv": valid_token_indices,
     }

-def shift_sigma(
-    sigma: np.ndarray,
-    shift: float,
-):
-    """Shift noise standard deviation toward higher values.
-
-    Useful for training a model at high resolutions,
-    or sampling more finely at high noise levels.
-
-    Equivalent to:
-        sigma_shift = shift / (shift + 1 / sigma - 1)
-    except for sigma = 0.
-
-    Args:
-        sigma: noise standard deviation in [0, 1]
-        shift: shift factor >= 1.
-            For shift > 1, shifts sigma to higher values.
-            For shift = 1, identity function.
-    """
-    return shift * sigma / (shift * sigma + 1 - sigma)

 class T2VSynthMochiModel:
     def __init__(
         self,
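The removed shift_sigma helper is pure math; here is a minimal standalone sketch verifying the equivalence its docstring states (values are illustrative, not from the commit):

```python
import numpy as np

def shift_sigma(sigma: np.ndarray, shift: float) -> np.ndarray:
    # Rational remap that pushes sigma toward 1.0 when shift > 1.
    return shift * sigma / (shift * sigma + 1 - sigma)

sigma = np.array([0.0, 0.25, 0.5, 0.75, 1.0])
print(shift_sigma(sigma, shift=3.0))  # [0.   0.5  0.75 0.9  1.  ]
# For sigma > 0 this equals shift / (shift + 1 / sigma - 1),
# e.g. sigma = 0.5, shift = 3: 3 / (3 + 2 - 1) = 0.75.
```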
@@ -239,23 +213,16 @@ class T2VSynthMochiModel:
     @torch.inference_mode(mode=True)
     def run(self, args, stream_results):
-        random.seed(args["seed"])
-        np.random.seed(args["seed"])
         torch.manual_seed(args["seed"])
-        torch.cuda.manual_seed(args["seed"])
         generator = torch.Generator(device=self.device)
         generator.manual_seed(args["seed"])

-        # assert (
-        #     len(args["prompt"]) == 1
-        # ), f"Expected exactly one prompt, got {len(args['prompt'])}"
-        #prompt = args["prompt"][0]
-        #neg_prompt = args["negative_prompt"][0] if len(args["negative_prompt"]) else ""
-        B = 1
-        w = args["width"]
-        h = args["height"]
-        t = args["num_frames"]
+        num_frames = args["num_frames"]
+        height = args["height"]
+        width = args["width"]

         batch_cfg = args["mochi_args"]["batch_cfg"]
         sample_steps = args["mochi_args"]["num_inference_steps"]
         cfg_schedule = args["mochi_args"].get("cfg_schedule")
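The seed fix drops the global Python, NumPy, and CUDA reseeding in favor of a single dedicated torch.Generator; a minimal sketch of the pattern (seed and shape are illustrative):

```python
import torch

seed = 42  # illustrative
device = "cuda" if torch.cuda.is_available() else "cpu"

# A per-run generator keeps the noise reproducible without mutating
# global RNG state that other nodes in the graph might depend on.
generator = torch.Generator(device=device).manual_seed(seed)
z = torch.randn((1, 12, 28, 60, 106), device=device,
                generator=generator, dtype=torch.float32)
```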
@@ -267,7 +234,7 @@ class T2VSynthMochiModel:
         assert (
             len(sigma_schedule) == sample_steps + 1
         ), f"sigma_schedule must have length {sample_steps + 1}, got {len(sigma_schedule)}"
-        assert (t - 1) % 6 == 0, f"t - 1 must be divisible by 6, got {t - 1}"
+        assert (num_frames - 1) % 6 == 0, f"num_frames - 1 must be divisible by 6, got {num_frames - 1}"

         # if batch_cfg:
         #     sample_batched = self.get_conditioning(
@@ -277,15 +244,19 @@ class T2VSynthMochiModel:
         # sample = self.get_conditioning([prompt], zero_last_n_prompts=0)
         # sample_null = self.get_conditioning([neg_prompt] * B, zero_last_n_prompts=B if neg_prompt == "" else 0)

-        # create z
         spatial_downsample = 8
         temporal_downsample = 6
-        latent_t = (t - 1) // temporal_downsample + 1
-        latent_w, latent_h = w // spatial_downsample, h // spatial_downsample
-        latent_dims = dict(lT=latent_t, lW=latent_w, lH=latent_h)
         in_channels = 12
+        B = 1
+        C = in_channels
+        T = (num_frames - 1) // temporal_downsample + 1
+        H = height // spatial_downsample
+        W = width // spatial_downsample
+        latent_dims = dict(lT=T, lW=W, lH=H)
         z = torch.randn(
-            (B, in_channels, latent_t, latent_h, latent_w),
+            (B, C, T, H, W),
            device=self.device,
            generator=generator,
            dtype=torch.float32,
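For intuition on the renamed dimension math, a quick worked example (the 848x480, 163-frame request is illustrative, not from the commit):

```python
num_frames, height, width = 163, 480, 848  # illustrative request
spatial_downsample, temporal_downsample = 8, 6

T = (num_frames - 1) // temporal_downsample + 1  # 162 // 6 + 1 = 28 latent frames
H = height // spatial_downsample                 # 60
W = width // spatial_downsample                  # 106
print(T, H, W)  # 28 60 106 -> z has shape (1, 12, 28, 60, 106)
```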
@@ -307,22 +278,6 @@ class T2VSynthMochiModel:
             "y_feat": [args["negative_embeds"]["embeds"].to(self.device)]
         }

-        # print(sample["y_mask"])
-        # print(type(sample["y_mask"]))
-        # print(sample["y_mask"][0].shape)
-        # print(sample["y_feat"])
-        # print(type(sample["y_feat"]))
-        # print(sample["y_feat"][0].shape)
-
-        # print(sample_null["y_mask"])
-        # print(type(sample_null["y_mask"]))
-        # print(sample_null["y_mask"][0].shape)
-        # print(sample_null["y_feat"])
-        # print(type(sample_null["y_feat"]))
-        # print(sample_null["y_feat"][0].shape)

         sample["packed_indices"] = self.get_packed_indices(
             sample["y_mask"], **latent_dims
         )
@@ -331,8 +286,6 @@
         )

         def model_fn(*, z, sigma, cfg_scale):
-            #print("z", z.dtype, z.device)
-            #print("sigma", sigma.dtype, sigma.device)
             self.dit.to(self.device)
             # if batch_cfg:
             #     with torch.autocast("cuda", dtype=torch.bfloat16):
@@ -341,7 +294,7 @@
             #else:
             nonlocal sample, sample_null
-            with torch.autocast("cuda", dtype=torch.bfloat16):
+            with torch.autocast(mm.get_autocast_device(self.device), dtype=torch.bfloat16):
                 out_cond = self.dit(z, sigma, **sample)
                 out_uncond = self.dit(z, sigma, **sample_null)
             assert out_cond.shape == out_uncond.shape
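Both autocast sites now resolve the device string through ComfyUI's comfy.model_management helper rather than hardcoding "cuda". A minimal sketch of the pattern (the context manager body is illustrative):

```python
import torch
import comfy.model_management as mm

device = mm.get_torch_device()
# get_autocast_device returns the device-type string autocast expects
# (e.g. "cuda" or "cpu"), so the same code path works on non-CUDA backends.
with torch.autocast(mm.get_autocast_device(device), dtype=torch.bfloat16):
    pass  # model forward passes run here in bfloat16
```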
@@ -364,8 +317,6 @@
             pred = pred.to(z)
             output_cond = output_cond.to(z)

-            #if stream_results:
-            #    yield i / sample_steps, None, False

             z = z + dsigma * pred
             comfy_pbar.update(1)

nodes.py

@@ -248,7 +248,7 @@ class MochiDecode:
             "samples": ("LATENT", ),
             "enable_vae_tiling": ("BOOLEAN", {"default": False, "tooltip": "Drastically reduces memory use but may introduce seams"}),
             "auto_tile_size": ("BOOLEAN", {"default": True, "tooltip": "Auto size based on height and width, default is half the size"}),
-            "frame_batch_size": ("INT", {"default": 6, "min": 1, "max": 64, "step": 1}),
+            "frame_batch_size": ("INT", {"default": 6, "min": 1, "max": 64, "step": 1, "tooltip": "Number of latent-space frames (temporal downscale factor is 6) to decode at once"}),
             "tile_sample_min_height": ("INT", {"default": 240, "min": 16, "max": 2048, "step": 8, "tooltip": "Minimum tile height, default is half the height"}),
             "tile_sample_min_width": ("INT", {"default": 424, "min": 16, "max": 2048, "step": 8, "tooltip": "Minimum tile width, default is half the width"}),
             "tile_overlap_factor_height": ("FLOAT", {"default": 0.1666, "min": 0.0, "max": 1.0, "step": 0.001}),
@@ -268,6 +268,17 @@ class MochiDecode:
         samples = samples["samples"]
         samples = samples.to(torch.bfloat16).to(device)

+        B, C, T, H, W = samples.shape
+
+        self.tile_overlap_factor_height = tile_overlap_factor_height if not auto_tile_size else 1 / 6
+        self.tile_overlap_factor_width = tile_overlap_factor_width if not auto_tile_size else 1 / 5
+
+        self.tile_sample_min_height = tile_sample_min_height if not auto_tile_size else H // 2 * 8
+        self.tile_sample_min_width = tile_sample_min_width if not auto_tile_size else W // 2 * 8
+
+        self.tile_latent_min_height = int(self.tile_sample_min_height / 8)
+        self.tile_latent_min_width = int(self.tile_sample_min_width / 8)
+
         def blend_v(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
             blend_extent = min(a.shape[3], b.shape[3], blend_extent)
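With auto_tile_size enabled, the tile extents fall out of the latent shape; a worked example (the 60x106 latent, i.e. a 480x848 video, is illustrative):

```python
H, W = 60, 106  # latent height/width, illustrative
tile_sample_min_height = H // 2 * 8       # 240, matches the node's default
tile_sample_min_width = W // 2 * 8        # 424, matches the node's default
tile_latent_min_height = int(tile_sample_min_height / 8)  # 30
tile_latent_min_width = int(tile_sample_min_width / 8)    # 53
```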
@@ -284,70 +295,68 @@ class MochiDecode:
                 x / blend_extent
             )
             return b

-        self.tile_overlap_factor_height = tile_overlap_factor_height if not auto_tile_size else 1 / 6
-        self.tile_overlap_factor_width = tile_overlap_factor_width if not auto_tile_size else 1 / 5
-
-        self.tile_sample_min_height = tile_sample_min_height if not auto_tile_size else samples.shape[3] // 2 * 8
-        self.tile_sample_min_width = tile_sample_min_width if not auto_tile_size else samples.shape[4] // 2 * 8
-
-        self.tile_latent_min_height = int(self.tile_sample_min_height / 8)
-        self.tile_latent_min_width = int(self.tile_sample_min_width / 8)
+        def decode_tiled(samples):
+            overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height))
+            overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width))
+            blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height)
+            blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width)
+            row_limit_height = self.tile_sample_min_height - blend_extent_height
+            row_limit_width = self.tile_sample_min_width - blend_extent_width
+
+            # Split z into overlapping tiles and decode them separately.
+            # The tiles have an overlap to avoid seams between tiles.
+            comfy_pbar = ProgressBar(len(range(0, H, overlap_height)))
+            rows = []
+            for i in tqdm(range(0, H, overlap_height), desc="Processing rows"):
+                row = []
+                for j in tqdm(range(0, W, overlap_width), desc="Processing columns", leave=False):
+                    time = []
+                    for k in tqdm(range(T // frame_batch_size), desc="Processing frames", leave=False):
+                        remaining_frames = T % frame_batch_size
+                        start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
+                        end_frame = frame_batch_size * (k + 1) + remaining_frames
+                        tile = samples[
+                            :,
+                            :,
+                            start_frame:end_frame,
+                            i : i + self.tile_latent_min_height,
+                            j : j + self.tile_latent_min_width,
+                        ]
+                        tile = vae(tile)
+                        time.append(tile)
+                    row.append(torch.cat(time, dim=2))
+                rows.append(row)
+                comfy_pbar.update(1)
+
+            result_rows = []
+            for i, row in enumerate(tqdm(rows, desc="Blending rows")):
+                result_row = []
+                for j, tile in enumerate(tqdm(row, desc="Blending tiles", leave=False)):
+                    # blend the above tile and the left tile
+                    # into the current tile and add the current tile to the result row
+                    if i > 0:
+                        tile = blend_v(rows[i - 1][j], tile, blend_extent_height)
+                    if j > 0:
+                        tile = blend_h(row[j - 1], tile, blend_extent_width)
+                    result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
+                result_rows.append(torch.cat(result_row, dim=4))
+
+            return torch.cat(result_rows, dim=3)

         vae.to(device)
-        with torch.amp.autocast("cuda", dtype=torch.bfloat16):
-            if not enable_vae_tiling:
+        with torch.autocast(mm.get_autocast_device(device), dtype=torch.bfloat16):
+            if enable_vae_tiling and frame_batch_size > T:
+                logging.warning(f"frame_batch_size ({frame_batch_size}) is larger than the number of latent frames ({T}), decoding without tiling")
+                samples = vae(samples)
+            elif not enable_vae_tiling:
+                logging.warning("Attempting to decode without tiling, very memory intensive")
                 samples = vae(samples)
             else:
-                batch_size, num_channels, num_frames, height, width = samples.shape
-                overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height))
-                overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width))
-                blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height)
-                blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width)
-                row_limit_height = self.tile_sample_min_height - blend_extent_height
-                row_limit_width = self.tile_sample_min_width - blend_extent_width
-
-                # Split z into overlapping tiles and decode them separately.
-                # The tiles have an overlap to avoid seams between tiles.
-                comfy_pbar = ProgressBar(len(range(0, height, overlap_height)))
-                rows = []
-                for i in tqdm(range(0, height, overlap_height), desc="Processing rows"):
-                    row = []
-                    for j in tqdm(range(0, width, overlap_width), desc="Processing columns", leave=False):
-                        time = []
-                        for k in tqdm(range(num_frames // frame_batch_size), desc="Processing frames", leave=False):
-                            remaining_frames = num_frames % frame_batch_size
-                            start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
-                            end_frame = frame_batch_size * (k + 1) + remaining_frames
-                            tile = samples[
-                                :,
-                                :,
-                                start_frame:end_frame,
-                                i : i + self.tile_latent_min_height,
-                                j : j + self.tile_latent_min_width,
-                            ]
-                            tile = vae(tile)
-                            time.append(tile)
-                        row.append(torch.cat(time, dim=2))
-                    rows.append(row)
-                    comfy_pbar.update(1)
-
-                result_rows = []
-                for i, row in enumerate(tqdm(rows, desc="Blending rows")):
-                    result_row = []
-                    for j, tile in enumerate(tqdm(row, desc="Blending tiles", leave=False)):
-                        # blend the above tile and the left tile
-                        # to the current tile and add the current tile to the result row
-                        if i > 0:
-                            tile = blend_v(rows[i - 1][j], tile, blend_extent_height)
-                        if j > 0:
-                            tile = blend_h(row[j - 1], tile, blend_extent_width)
-                        result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
-                    result_rows.append(torch.cat(result_row, dim=4))
-                samples = torch.cat(result_rows, dim=3)
+                logging.info("Decoding with tiling")
+                samples = decode_tiled(samples)

         vae.to(offload_device)
-        #print("samples", samples.shape, samples.dtype, samples.device)

        samples = samples.float()
        samples = (samples + 1.0) / 2.0
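decode_tiled folds any remainder frames into the first temporal batch so every latent frame is decoded exactly once; a standalone sketch of that index math (T and frame_batch_size values are illustrative):

```python
T, frame_batch_size = 28, 6  # illustrative
remaining_frames = T % frame_batch_size  # 4, absorbed by the first batch
for k in range(T // frame_batch_size):
    start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
    end_frame = frame_batch_size * (k + 1) + remaining_frames
    print(start_frame, end_frame)
# (0, 10) (10, 16) (16, 22) (22, 28) -> covers all 28 frames with no overlap
```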