Cleanup, fix seed gen, better warnings for decoder
This commit is contained in:
parent
f714748ad4
commit
257c526125
@ -1,17 +1,14 @@
|
||||
import json
|
||||
import random
|
||||
from typing import Dict, List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.data
|
||||
from torch import nn
|
||||
|
||||
from .dit.joint_model.context_parallel import get_cp_rank_size
|
||||
from tqdm import tqdm
|
||||
from comfy.utils import ProgressBar, load_torch_file
|
||||
import comfy.model_management as mm
|
||||
|
||||
from contextlib import nullcontext
|
||||
try:
|
||||
@ -86,29 +83,6 @@ def compute_packed_indices(
|
||||
"valid_token_indices_kv": valid_token_indices,
|
||||
}
|
||||
|
||||
|
||||
def shift_sigma(
|
||||
sigma: np.ndarray,
|
||||
shift: float,
|
||||
):
|
||||
"""Shift noise standard deviation toward higher values.
|
||||
|
||||
Useful for training a model at high resolutions,
|
||||
or sampling more finely at high noise levels.
|
||||
|
||||
Equivalent to:
|
||||
sigma_shift = shift / (shift + 1 / sigma - 1)
|
||||
except for sigma = 0.
|
||||
|
||||
Args:
|
||||
sigma: noise standard deviation in [0, 1]
|
||||
shift: shift factor >= 1.
|
||||
For shift > 1, shifts sigma to higher values.
|
||||
For shift = 1, identity function.
|
||||
"""
|
||||
return shift * sigma / (shift * sigma + 1 - sigma)
|
||||
|
||||
|
||||
class T2VSynthMochiModel:
|
||||
def __init__(
|
||||
self,
|
||||
@ -239,23 +213,16 @@ class T2VSynthMochiModel:
|
||||
|
||||
@torch.inference_mode(mode=True)
|
||||
def run(self, args, stream_results):
|
||||
random.seed(args["seed"])
|
||||
np.random.seed(args["seed"])
|
||||
torch.manual_seed(args["seed"])
|
||||
torch.cuda.manual_seed(args["seed"])
|
||||
|
||||
generator = torch.Generator(device=self.device)
|
||||
generator.manual_seed(args["seed"])
|
||||
|
||||
# assert (
|
||||
# len(args["prompt"]) == 1
|
||||
# ), f"Expected exactly one prompt, got {len(args['prompt'])}"
|
||||
#prompt = args["prompt"][0]
|
||||
#neg_prompt = args["negative_prompt"][0] if len(args["negative_prompt"]) else ""
|
||||
B = 1
|
||||
|
||||
w = args["width"]
|
||||
h = args["height"]
|
||||
t = args["num_frames"]
|
||||
num_frames = args["num_frames"]
|
||||
height = args["height"]
|
||||
width = args["width"]
|
||||
|
||||
batch_cfg = args["mochi_args"]["batch_cfg"]
|
||||
sample_steps = args["mochi_args"]["num_inference_steps"]
|
||||
cfg_schedule = args["mochi_args"].get("cfg_schedule")
|
||||
@ -267,7 +234,7 @@ class T2VSynthMochiModel:
|
||||
assert (
|
||||
len(sigma_schedule) == sample_steps + 1
|
||||
), f"sigma_schedule must have length {sample_steps + 1}, got {len(sigma_schedule)}"
|
||||
assert (t - 1) % 6 == 0, f"t - 1 must be divisible by 6, got {t - 1}"
|
||||
assert (num_frames - 1) % 6 == 0, f"t - 1 must be divisible by 6, got {num_frames - 1}"
|
||||
|
||||
# if batch_cfg:
|
||||
# sample_batched = self.get_conditioning(
|
||||
@ -277,15 +244,19 @@ class T2VSynthMochiModel:
|
||||
# sample = self.get_conditioning([prompt], zero_last_n_prompts=0)
|
||||
# sample_null = self.get_conditioning([neg_prompt] * B, zero_last_n_prompts=B if neg_prompt == "" else 0)
|
||||
|
||||
# create z
|
||||
spatial_downsample = 8
|
||||
temporal_downsample = 6
|
||||
latent_t = (t - 1) // temporal_downsample + 1
|
||||
latent_w, latent_h = w // spatial_downsample, h // spatial_downsample
|
||||
|
||||
latent_dims = dict(lT=latent_t, lW=latent_w, lH=latent_h)
|
||||
in_channels = 12
|
||||
B = 1
|
||||
C = in_channels
|
||||
T = (num_frames - 1) // temporal_downsample + 1
|
||||
H = height // spatial_downsample
|
||||
W = width // spatial_downsample
|
||||
latent_dims = dict(lT=T, lW=W, lH=H)
|
||||
|
||||
z = torch.randn(
|
||||
(B, in_channels, latent_t, latent_h, latent_w),
|
||||
(B, C, T, H, W),
|
||||
device=self.device,
|
||||
generator=generator,
|
||||
dtype=torch.float32,
|
||||
@ -307,22 +278,6 @@ class T2VSynthMochiModel:
|
||||
"y_feat": [args["negative_embeds"]["embeds"].to(self.device)]
|
||||
}
|
||||
|
||||
# print(sample["y_mask"])
|
||||
# print(type(sample["y_mask"]))
|
||||
# print(sample["y_mask"][0].shape)
|
||||
|
||||
# print(sample["y_feat"])
|
||||
# print(type(sample["y_feat"]))
|
||||
# print(sample["y_feat"][0].shape)
|
||||
|
||||
# print(sample_null["y_mask"])
|
||||
# print(type(sample_null["y_mask"]))
|
||||
# print(sample_null["y_mask"][0].shape)
|
||||
|
||||
# print(sample_null["y_feat"])
|
||||
# print(type(sample_null["y_feat"]))
|
||||
# print(sample_null["y_feat"][0].shape)
|
||||
|
||||
sample["packed_indices"] = self.get_packed_indices(
|
||||
sample["y_mask"], **latent_dims
|
||||
)
|
||||
@ -331,8 +286,6 @@ class T2VSynthMochiModel:
|
||||
)
|
||||
|
||||
def model_fn(*, z, sigma, cfg_scale):
|
||||
#print("z", z.dtype, z.device)
|
||||
#print("sigma", sigma.dtype, sigma.device)
|
||||
self.dit.to(self.device)
|
||||
# if batch_cfg:
|
||||
# with torch.autocast("cuda", dtype=torch.bfloat16):
|
||||
@ -341,7 +294,7 @@ class T2VSynthMochiModel:
|
||||
#else:
|
||||
|
||||
nonlocal sample, sample_null
|
||||
with torch.autocast("cuda", dtype=torch.bfloat16):
|
||||
with torch.autocast(mm.get_autocast_device(self.device), dtype=torch.bfloat16):
|
||||
out_cond = self.dit(z, sigma, **sample)
|
||||
out_uncond = self.dit(z, sigma, **sample_null)
|
||||
assert out_cond.shape == out_uncond.shape
|
||||
@ -364,8 +317,6 @@ class T2VSynthMochiModel:
|
||||
pred = pred.to(z)
|
||||
output_cond = output_cond.to(z)
|
||||
|
||||
#if stream_results:
|
||||
# yield i / sample_steps, None, False
|
||||
z = z + dsigma * pred
|
||||
comfy_pbar.update(1)
|
||||
|
||||
|
||||
123
nodes.py
123
nodes.py
@ -248,7 +248,7 @@ class MochiDecode:
|
||||
"samples": ("LATENT", ),
|
||||
"enable_vae_tiling": ("BOOLEAN", {"default": False, "tooltip": "Drastically reduces memory use but may introduce seams"}),
|
||||
"auto_tile_size": ("BOOLEAN", {"default": True, "tooltip": "Auto size based on height and width, default is half the size"}),
|
||||
"frame_batch_size": ("INT", {"default": 6, "min": 1, "max": 64, "step": 1}),
|
||||
"frame_batch_size": ("INT", {"default": 6, "min": 1, "max": 64, "step": 1, "tooltip": "Number of frames in latent space (downscale factor is 6) to decode at once"}),
|
||||
"tile_sample_min_height": ("INT", {"default": 240, "min": 16, "max": 2048, "step": 8, "tooltip": "Minimum tile height, default is half the height"}),
|
||||
"tile_sample_min_width": ("INT", {"default": 424, "min": 16, "max": 2048, "step": 8, "tooltip": "Minimum tile width, default is half the width"}),
|
||||
"tile_overlap_factor_height": ("FLOAT", {"default": 0.1666, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||
@ -268,6 +268,17 @@ class MochiDecode:
|
||||
samples = samples["samples"]
|
||||
samples = samples.to(torch.bfloat16).to(device)
|
||||
|
||||
B, C, T, H, W = samples.shape
|
||||
|
||||
self.tile_overlap_factor_height = tile_overlap_factor_height if not auto_tile_size else 1 / 6
|
||||
self.tile_overlap_factor_width = tile_overlap_factor_width if not auto_tile_size else 1 / 5
|
||||
|
||||
self.tile_sample_min_height = tile_sample_min_height if not auto_tile_size else H // 2 * 8
|
||||
self.tile_sample_min_width = tile_sample_min_width if not auto_tile_size else W // 2 * 8
|
||||
|
||||
self.tile_latent_min_height = int(self.tile_sample_min_height / 8)
|
||||
self.tile_latent_min_width = int(self.tile_sample_min_width / 8)
|
||||
|
||||
|
||||
def blend_v(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
||||
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
|
||||
@ -284,70 +295,68 @@ class MochiDecode:
|
||||
x / blend_extent
|
||||
)
|
||||
return b
|
||||
|
||||
def decode_tiled(samples):
|
||||
overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height))
|
||||
overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width))
|
||||
blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height)
|
||||
blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width)
|
||||
row_limit_height = self.tile_sample_min_height - blend_extent_height
|
||||
row_limit_width = self.tile_sample_min_width - blend_extent_width
|
||||
|
||||
self.tile_overlap_factor_height = tile_overlap_factor_height if not auto_tile_size else 1 / 6
|
||||
self.tile_overlap_factor_width = tile_overlap_factor_width if not auto_tile_size else 1 / 5
|
||||
# Split z into overlapping tiles and decode them separately.
|
||||
# The tiles have an overlap to avoid seams between tiles.
|
||||
comfy_pbar = ProgressBar(len(range(0, H, overlap_height)))
|
||||
rows = []
|
||||
for i in tqdm(range(0, H, overlap_height), desc="Processing rows"):
|
||||
row = []
|
||||
for j in tqdm(range(0, W, overlap_width), desc="Processing columns", leave=False):
|
||||
time = []
|
||||
for k in tqdm(range(T // frame_batch_size), desc="Processing frames", leave=False):
|
||||
remaining_frames = T % frame_batch_size
|
||||
start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
|
||||
end_frame = frame_batch_size * (k + 1) + remaining_frames
|
||||
tile = samples[
|
||||
:,
|
||||
:,
|
||||
start_frame:end_frame,
|
||||
i : i + self.tile_latent_min_height,
|
||||
j : j + self.tile_latent_min_width,
|
||||
]
|
||||
tile = vae(tile)
|
||||
time.append(tile)
|
||||
row.append(torch.cat(time, dim=2))
|
||||
rows.append(row)
|
||||
comfy_pbar.update(1)
|
||||
|
||||
self.tile_sample_min_height = tile_sample_min_height if not auto_tile_size else samples.shape[3] // 2 * 8
|
||||
self.tile_sample_min_width = tile_sample_min_width if not auto_tile_size else samples.shape[4] // 2 * 8
|
||||
result_rows = []
|
||||
for i, row in enumerate(tqdm(rows, desc="Blending rows")):
|
||||
result_row = []
|
||||
for j, tile in enumerate(tqdm(row, desc="Blending tiles", leave=False)):
|
||||
# blend the above tile and the left tile
|
||||
# to the current tile and add the current tile to the result row
|
||||
if i > 0:
|
||||
tile = blend_v(rows[i - 1][j], tile, blend_extent_height)
|
||||
if j > 0:
|
||||
tile = blend_h(row[j - 1], tile, blend_extent_width)
|
||||
result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
|
||||
result_rows.append(torch.cat(result_row, dim=4))
|
||||
|
||||
self.tile_latent_min_height = int(self.tile_sample_min_height / 8)
|
||||
self.tile_latent_min_width = int(self.tile_sample_min_width / 8)
|
||||
return torch.cat(result_rows, dim=3)
|
||||
|
||||
vae.to(device)
|
||||
with torch.amp.autocast("cuda", dtype=torch.bfloat16):
|
||||
if not enable_vae_tiling:
|
||||
with torch.autocast(mm.get_autocast_device(device), dtype=torch.bfloat16):
|
||||
if enable_vae_tiling and frame_batch_size > T:
|
||||
logging.warning(f"Frame batch size is larger than the number of samples ({T}), disabling tiling")
|
||||
samples = vae(samples)
|
||||
elif not enable_vae_tiling:
|
||||
logging.warning("Attempting to decode without tiling, very memory intensive")
|
||||
samples = vae(samples)
|
||||
else:
|
||||
batch_size, num_channels, num_frames, height, width = samples.shape
|
||||
overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height))
|
||||
overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width))
|
||||
blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height)
|
||||
blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width)
|
||||
row_limit_height = self.tile_sample_min_height - blend_extent_height
|
||||
row_limit_width = self.tile_sample_min_width - blend_extent_width
|
||||
|
||||
# Split z into overlapping tiles and decode them separately.
|
||||
# The tiles have an overlap to avoid seams between tiles.
|
||||
comfy_pbar = ProgressBar(len(range(0, height, overlap_height)))
|
||||
rows = []
|
||||
for i in tqdm(range(0, height, overlap_height), desc="Processing rows"):
|
||||
row = []
|
||||
for j in tqdm(range(0, width, overlap_width), desc="Processing columns", leave=False):
|
||||
time = []
|
||||
for k in tqdm(range(num_frames // frame_batch_size), desc="Processing frames", leave=False):
|
||||
remaining_frames = num_frames % frame_batch_size
|
||||
start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
|
||||
end_frame = frame_batch_size * (k + 1) + remaining_frames
|
||||
tile = samples[
|
||||
:,
|
||||
:,
|
||||
start_frame:end_frame,
|
||||
i : i + self.tile_latent_min_height,
|
||||
j : j + self.tile_latent_min_width,
|
||||
]
|
||||
tile = vae(tile)
|
||||
time.append(tile)
|
||||
row.append(torch.cat(time, dim=2))
|
||||
rows.append(row)
|
||||
comfy_pbar.update(1)
|
||||
|
||||
result_rows = []
|
||||
for i, row in enumerate(tqdm(rows, desc="Blending rows")):
|
||||
result_row = []
|
||||
for j, tile in enumerate(tqdm(row, desc="Blending tiles", leave=False)):
|
||||
# blend the above tile and the left tile
|
||||
# to the current tile and add the current tile to the result row
|
||||
if i > 0:
|
||||
tile = blend_v(rows[i - 1][j], tile, blend_extent_height)
|
||||
if j > 0:
|
||||
tile = blend_h(row[j - 1], tile, blend_extent_width)
|
||||
result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
|
||||
result_rows.append(torch.cat(result_row, dim=4))
|
||||
|
||||
samples = torch.cat(result_rows, dim=3)
|
||||
logging.info("Decoding with tiling")
|
||||
samples = decode_tiled(samples)
|
||||
|
||||
vae.to(offload_device)
|
||||
#print("samples", samples.shape, samples.dtype, samples.device)
|
||||
|
||||
samples = samples.float()
|
||||
samples = (samples + 1.0) / 2.0
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user