Cleanup, fix seed gen, better warnings for decoder

parent f714748ad4
commit 257c526125
@@ -1,17 +1,14 @@
 import json
-import random
 from typing import Dict, List

-import numpy as np
 import torch
-import torch.nn as nn
 import torch.nn.functional as F
 import torch.utils.data
-from torch import nn

 from .dit.joint_model.context_parallel import get_cp_rank_size
 from tqdm import tqdm
 from comfy.utils import ProgressBar, load_torch_file
+import comfy.model_management as mm

 from contextlib import nullcontext
 try:
@@ -86,29 +83,6 @@ def compute_packed_indices(
         "valid_token_indices_kv": valid_token_indices,
     }

-
-def shift_sigma(
-    sigma: np.ndarray,
-    shift: float,
-):
-    """Shift noise standard deviation toward higher values.
-
-    Useful for training a model at high resolutions,
-    or sampling more finely at high noise levels.
-
-    Equivalent to:
-        sigma_shift = shift / (shift + 1 / sigma - 1)
-    except for sigma = 0.
-
-    Args:
-        sigma: noise standard deviation in [0, 1]
-        shift: shift factor >= 1.
-            For shift > 1, shifts sigma to higher values.
-            For shift = 1, identity function.
-    """
-    return shift * sigma / (shift * sigma + 1 - sigma)
-
-
 class T2VSynthMochiModel:
     def __init__(
         self,
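For reference, the removed shift_sigma is just the docstring identity rearranged: multiplying the numerator and denominator of shift / (shift + 1 / sigma - 1) by sigma gives shift * sigma / (shift * sigma + 1 - sigma), which stays defined at sigma = 0 (returning 0). A quick sanity check, not part of the commit:

import numpy as np

def shift_sigma(sigma, shift):
    return shift * sigma / (shift * sigma + 1 - sigma)

sigma = np.linspace(0.01, 1.0, 5)
shift = 3.0
alt = shift / (shift + 1 / sigma - 1)  # docstring form, undefined at sigma = 0
assert np.allclose(shift_sigma(sigma, shift), alt)
assert shift_sigma(np.float64(0.0), shift) == 0.0  # the sigma = 0 special case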
@@ -239,23 +213,16 @@ class T2VSynthMochiModel:

    @torch.inference_mode(mode=True)
    def run(self, args, stream_results):
-        random.seed(args["seed"])
-        np.random.seed(args["seed"])
        torch.manual_seed(args["seed"])
+        torch.cuda.manual_seed(args["seed"])

        generator = torch.Generator(device=self.device)
        generator.manual_seed(args["seed"])

-        # assert (
-        #     len(args["prompt"]) == 1
-        # ), f"Expected exactly one prompt, got {len(args['prompt'])}"
-        #prompt = args["prompt"][0]
-        #neg_prompt = args["negative_prompt"][0] if len(args["negative_prompt"]) else ""
-        B = 1
-
-        w = args["width"]
-        h = args["height"]
-        t = args["num_frames"]
+        num_frames = args["num_frames"]
+        height = args["height"]
+        width = args["width"]

        batch_cfg = args["mochi_args"]["batch_cfg"]
        sample_steps = args["mochi_args"]["num_inference_steps"]
        cfg_schedule = args["mochi_args"].get("cfg_schedule")
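The seed-generation fix: the Python random and NumPy seeding (now unused) is dropped, the CUDA RNG is seeded alongside the CPU one, and the latent noise below is drawn from a dedicated per-device torch.Generator, which keeps results reproducible even if other code consumes the global RNG. A minimal sketch of the pattern (function name is mine):

import torch

def seeded_noise(seed: int, shape, device: str = "cpu"):
    # A dedicated generator isolates this draw from the global RNG state,
    # so unrelated torch.rand* calls elsewhere cannot change the result.
    generator = torch.Generator(device=device)
    generator.manual_seed(seed)
    return torch.randn(shape, device=device, generator=generator, dtype=torch.float32)

a = seeded_noise(42, (1, 12, 28, 60, 106))
b = seeded_noise(42, (1, 12, 28, 60, 106))
assert torch.equal(a, b)  # same seed, same noise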
@@ -267,7 +234,7 @@ class T2VSynthMochiModel:
        assert (
            len(sigma_schedule) == sample_steps + 1
        ), f"sigma_schedule must have length {sample_steps + 1}, got {len(sigma_schedule)}"
-        assert (t - 1) % 6 == 0, f"t - 1 must be divisible by 6, got {t - 1}"
+        assert (num_frames - 1) % 6 == 0, f"num_frames - 1 must be divisible by 6, got {num_frames - 1}"

        # if batch_cfg:
        #     sample_batched = self.get_conditioning(
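The schedule length check exists because n integration steps consume n + 1 sigma boundaries, one pair (sigma_i, sigma_{i+1}) per step. A worked example with values of my choosing:

sample_steps = 4
sigma_schedule = [1.0, 0.75, 0.5, 0.25, 0.0]  # 5 entries
assert len(sigma_schedule) == sample_steps + 1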
@@ -277,15 +244,19 @@ class T2VSynthMochiModel:
        # sample = self.get_conditioning([prompt], zero_last_n_prompts=0)
        # sample_null = self.get_conditioning([neg_prompt] * B, zero_last_n_prompts=B if neg_prompt == "" else 0)
+
+        # create z
        spatial_downsample = 8
        temporal_downsample = 6
-        latent_t = (t - 1) // temporal_downsample + 1
-        latent_w, latent_h = w // spatial_downsample, h // spatial_downsample
-
-        latent_dims = dict(lT=latent_t, lW=latent_w, lH=latent_h)
        in_channels = 12
+        B = 1
+        C = in_channels
+        T = (num_frames - 1) // temporal_downsample + 1
+        H = height // spatial_downsample
+        W = width // spatial_downsample
+        latent_dims = dict(lT=T, lW=W, lH=H)

        z = torch.randn(
-            (B, in_channels, latent_t, latent_h, latent_w),
+            (B, C, T, H, W),
            device=self.device,
            generator=generator,
            dtype=torch.float32,
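The latent shape math: the VAE downsamples 8x spatially and 6x temporally, with one extra leading frame. Worked through with values of my choosing (num_frames=163, height=480, width=848):

T = (163 - 1) // 6 + 1  # 28 latent frames
H = 480 // 8            # 60
W = 848 // 8            # 106
# z is sampled with shape (B, C, T, H, W) = (1, 12, 28, 60, 106)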
@@ -307,22 +278,6 @@ class T2VSynthMochiModel:
            "y_feat": [args["negative_embeds"]["embeds"].to(self.device)]
        }

-        # print(sample["y_mask"])
-        # print(type(sample["y_mask"]))
-        # print(sample["y_mask"][0].shape)
-
-        # print(sample["y_feat"])
-        # print(type(sample["y_feat"]))
-        # print(sample["y_feat"][0].shape)
-
-        # print(sample_null["y_mask"])
-        # print(type(sample_null["y_mask"]))
-        # print(sample_null["y_mask"][0].shape)
-
-        # print(sample_null["y_feat"])
-        # print(type(sample_null["y_feat"]))
-        # print(sample_null["y_feat"][0].shape)
-
        sample["packed_indices"] = self.get_packed_indices(
            sample["y_mask"], **latent_dims
        )
@@ -331,8 +286,6 @@ class T2VSynthMochiModel:
        )

        def model_fn(*, z, sigma, cfg_scale):
-            #print("z", z.dtype, z.device)
-            #print("sigma", sigma.dtype, sigma.device)
            self.dit.to(self.device)
            # if batch_cfg:
            #     with torch.autocast("cuda", dtype=torch.bfloat16):
@@ -341,7 +294,7 @@ class T2VSynthMochiModel:
            #else:

            nonlocal sample, sample_null
-            with torch.autocast("cuda", dtype=torch.bfloat16):
+            with torch.autocast(mm.get_autocast_device(self.device), dtype=torch.bfloat16):
                out_cond = self.dit(z, sigma, **sample)
                out_uncond = self.dit(z, sigma, **sample_null)
                assert out_cond.shape == out_uncond.shape
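Hard-coding "cuda" broke autocast on non-CUDA backends; mm.get_autocast_device resolves the device-type string for whatever device ComfyUI selected. A rough sketch of what such a helper amounts to (the actual comfy.model_management implementation may differ):

import torch

def get_autocast_device(device) -> str:
    # torch.autocast expects a device *type* string ("cuda", "cpu", ...),
    # not a torch.device object such as cuda:0.
    if hasattr(device, "type"):
        return device.type
    return "cuda"

device = torch.device("cpu")
with torch.autocast(get_autocast_device(device), dtype=torch.bfloat16):
    y = torch.randn(4, 4) @ torch.randn(4, 4)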
@@ -364,8 +317,6 @@ class T2VSynthMochiModel:
            pred = pred.to(z)
            output_cond = output_cond.to(z)

-            #if stream_results:
-            #    yield i / sample_steps, None, False
            z = z + dsigma * pred
            comfy_pbar.update(1)

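The update z = z + dsigma * pred is the Euler step of the sampler: each iteration moves the latent along the model's predicted direction by the sigma decrement. Schematically (a simplified sketch, not the wrapper's exact loop):

import torch

def euler_sample(model_fn, z, sigma_schedule, cfg_scale):
    # sigma_schedule runs from high noise to zero, one entry per step boundary
    for i in range(len(sigma_schedule) - 1):
        sigma = sigma_schedule[i]
        dsigma = sigma - sigma_schedule[i + 1]  # positive step size
        pred = model_fn(z=z, sigma=sigma, cfg_scale=cfg_scale)
        z = z + dsigma * pred
    return z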
nodes.py (123 changes)
@@ -248,7 +248,7 @@ class MochiDecode:
                "samples": ("LATENT", ),
                "enable_vae_tiling": ("BOOLEAN", {"default": False, "tooltip": "Drastically reduces memory use but may introduce seams"}),
                "auto_tile_size": ("BOOLEAN", {"default": True, "tooltip": "Auto size based on height and width, default is half the size"}),
-                "frame_batch_size": ("INT", {"default": 6, "min": 1, "max": 64, "step": 1}),
+                "frame_batch_size": ("INT", {"default": 6, "min": 1, "max": 64, "step": 1, "tooltip": "Number of frames in latent space (downscale factor is 6) to decode at once"}),
                "tile_sample_min_height": ("INT", {"default": 240, "min": 16, "max": 2048, "step": 8, "tooltip": "Minimum tile height, default is half the height"}),
                "tile_sample_min_width": ("INT", {"default": 424, "min": 16, "max": 2048, "step": 8, "tooltip": "Minimum tile width, default is half the width"}),
                "tile_overlap_factor_height": ("FLOAT", {"default": 0.1666, "min": 0.0, "max": 1.0, "step": 0.001}),
@@ -268,6 +268,17 @@ class MochiDecode:
        samples = samples["samples"]
        samples = samples.to(torch.bfloat16).to(device)

+        B, C, T, H, W = samples.shape
+
+        self.tile_overlap_factor_height = tile_overlap_factor_height if not auto_tile_size else 1 / 6
+        self.tile_overlap_factor_width = tile_overlap_factor_width if not auto_tile_size else 1 / 5
+
+        self.tile_sample_min_height = tile_sample_min_height if not auto_tile_size else H // 2 * 8
+        self.tile_sample_min_width = tile_sample_min_width if not auto_tile_size else W // 2 * 8
+
+        self.tile_latent_min_height = int(self.tile_sample_min_height / 8)
+        self.tile_latent_min_width = int(self.tile_sample_min_width / 8)
+

        def blend_v(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
            blend_extent = min(a.shape[3], b.shape[3], blend_extent)
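With auto_tile_size, tiles default to half the latent frame in each dimension, converted back to pixel units. Worked through for a 480x848 video (latent H=60, W=106), which reproduces the node defaults of 240 and 424:

H, W = 60, 106
tile_sample_min_height = H // 2 * 8  # 240 px (the node's default)
tile_sample_min_width = W // 2 * 8   # 424 px (the node's default)
tile_latent_min_height = tile_sample_min_height // 8  # 30 latent rows
tile_latent_min_width = tile_sample_min_width // 8    # 53 latent columns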
@@ -284,70 +295,68 @@ class MochiDecode:
                x / blend_extent
            )
            return b

+        def decode_tiled(samples):
+            overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height))
+            overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width))
+            blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height)
+            blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width)
+            row_limit_height = self.tile_sample_min_height - blend_extent_height
+            row_limit_width = self.tile_sample_min_width - blend_extent_width
+
-        self.tile_overlap_factor_height = tile_overlap_factor_height if not auto_tile_size else 1 / 6
-        self.tile_overlap_factor_width = tile_overlap_factor_width if not auto_tile_size else 1 / 5
-
-        self.tile_sample_min_height = tile_sample_min_height if not auto_tile_size else samples.shape[3] // 2 * 8
-        self.tile_sample_min_width = tile_sample_min_width if not auto_tile_size else samples.shape[4] // 2 * 8
-
-        self.tile_latent_min_height = int(self.tile_sample_min_height / 8)
-        self.tile_latent_min_width = int(self.tile_sample_min_width / 8)
+            # Split z into overlapping tiles and decode them separately.
+            # The tiles have an overlap to avoid seams between tiles.
+            comfy_pbar = ProgressBar(len(range(0, H, overlap_height)))
+            rows = []
+            for i in tqdm(range(0, H, overlap_height), desc="Processing rows"):
+                row = []
+                for j in tqdm(range(0, W, overlap_width), desc="Processing columns", leave=False):
+                    time = []
+                    for k in tqdm(range(T // frame_batch_size), desc="Processing frames", leave=False):
+                        remaining_frames = T % frame_batch_size
+                        start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
+                        end_frame = frame_batch_size * (k + 1) + remaining_frames
+                        tile = samples[
+                            :,
+                            :,
+                            start_frame:end_frame,
+                            i : i + self.tile_latent_min_height,
+                            j : j + self.tile_latent_min_width,
+                        ]
+                        tile = vae(tile)
+                        time.append(tile)
+                    row.append(torch.cat(time, dim=2))
+                rows.append(row)
+                comfy_pbar.update(1)
+
+            result_rows = []
+            for i, row in enumerate(tqdm(rows, desc="Blending rows")):
+                result_row = []
+                for j, tile in enumerate(tqdm(row, desc="Blending tiles", leave=False)):
+                    # blend the above tile and the left tile
+                    # to the current tile and add the current tile to the result row
+                    if i > 0:
+                        tile = blend_v(rows[i - 1][j], tile, blend_extent_height)
+                    if j > 0:
+                        tile = blend_h(row[j - 1], tile, blend_extent_width)
+                    result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
+                result_rows.append(torch.cat(result_row, dim=4))
+
+            return torch.cat(result_rows, dim=3)

        vae.to(device)
-        with torch.amp.autocast("cuda", dtype=torch.bfloat16):
-            if not enable_vae_tiling:
+        with torch.autocast(mm.get_autocast_device(device), dtype=torch.bfloat16):
+            if enable_vae_tiling and frame_batch_size > T:
+                logging.warning(f"Frame batch size is larger than the number of samples ({T}), disabling tiling")
+                samples = vae(samples)
+            elif not enable_vae_tiling:
+                logging.warning("Attempting to decode without tiling, very memory intensive")
                samples = vae(samples)
            else:
-                batch_size, num_channels, num_frames, height, width = samples.shape
-                overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height))
-                overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width))
-                blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height)
-                blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width)
-                row_limit_height = self.tile_sample_min_height - blend_extent_height
-                row_limit_width = self.tile_sample_min_width - blend_extent_width
-
-                # Split z into overlapping tiles and decode them separately.
-                # The tiles have an overlap to avoid seams between tiles.
-                comfy_pbar = ProgressBar(len(range(0, height, overlap_height)))
-                rows = []
-                for i in tqdm(range(0, height, overlap_height), desc="Processing rows"):
-                    row = []
-                    for j in tqdm(range(0, width, overlap_width), desc="Processing columns", leave=False):
-                        time = []
-                        for k in tqdm(range(num_frames // frame_batch_size), desc="Processing frames", leave=False):
-                            remaining_frames = num_frames % frame_batch_size
-                            start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
-                            end_frame = frame_batch_size * (k + 1) + remaining_frames
-                            tile = samples[
-                                :,
-                                :,
-                                start_frame:end_frame,
-                                i : i + self.tile_latent_min_height,
-                                j : j + self.tile_latent_min_width,
-                            ]
-                            tile = vae(tile)
-                            time.append(tile)
-                        row.append(torch.cat(time, dim=2))
-                    rows.append(row)
-                    comfy_pbar.update(1)
-
-                result_rows = []
-                for i, row in enumerate(tqdm(rows, desc="Blending rows")):
-                    result_row = []
-                    for j, tile in enumerate(tqdm(row, desc="Blending tiles", leave=False)):
-                        # blend the above tile and the left tile
-                        # to the current tile and add the current tile to the result row
-                        if i > 0:
-                            tile = blend_v(rows[i - 1][j], tile, blend_extent_height)
-                        if j > 0:
-                            tile = blend_h(row[j - 1], tile, blend_extent_width)
-                        result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
-                    result_rows.append(torch.cat(result_row, dim=4))
-
-                samples = torch.cat(result_rows, dim=3)
+                logging.info("Decoding with tiling")
+                samples = decode_tiled(samples)
        vae.to(offload_device)
-        #print("samples", samples.shape, samples.dtype, samples.device)

        samples = samples.float()
        samples = (samples + 1.0) / 2.0
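In decode_tiled, the frame loop folds any remainder into the first batch so every latent frame is decoded exactly once. A quick trace with values of my choosing, T = 28 and frame_batch_size = 6:

T, frame_batch_size = 28, 6
remaining = T % frame_batch_size  # 4 leftover frames
for k in range(T // frame_batch_size):  # 4 batches
    start = frame_batch_size * k + (0 if k == 0 else remaining)
    end = frame_batch_size * (k + 1) + remaining
    print(k, start, end)  # (0, 0, 10) (1, 10, 16) (2, 16, 22) (3, 22, 28)

Spatially, neighbouring tiles overlap, blend_v/blend_h crossfade the overlap linearly, and each tile is then cropped to row_limit_height x row_limit_width before concatenation, which is what hides the seams.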