Mirror of https://git.datalinker.icu/comfyanonymous/ComfyUI (synced 2025-12-09 22:14:34 +08:00)
* init
* update
* Update model.py
* Update model.py
* remove print
* Fix text encoding
* Prevent empty negative prompt (really doesn't work otherwise)
* fp16 works
* I2V
* Update model_base.py
* Update nodes_hunyuan.py
* Better latent rgb factors
* Use the correct sigclip output...
* Support HunyuanVideo1.5 SR model
* whitespaces...
* Proper latent channel count
* SR model fixes (still needs timestep scheduling based on the noise scale; can already be used with two samplers)
* vae_refiner: roll the convolution through temporal (work in progress: roll the convolution through time using 2-latent-frame chunks and a FIFO queue for the convolution seams)
* Support HunyuanVideo15 latent resampler
* fix
* Some cleanup (Co-authored-by: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>)
* Proper hyvid15 I2V channels (Co-authored-by: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>)
* Fix TokenRefiner for fp16 (otherwise x.sum has infs; casting only when the input is fp16, in case the cast is unnecessary)
* Bugfix for the HunyuanVideo15 SR model
* vae_refiner: roll the convolution through temporal II (roll the convolution through time using 2-latent-frame chunks and a FIFO queue for the convolution seams; added encoder support, lowered to 1 latent frame to save more VRAM, made it work for Hunyuan Image 3.0 since the code is shared; fixed names and cleaned up code)
* Allow any number of input frames in VAE
* Better VAE encode mem estimation
* Lowvram fix
* Fix hunyuan image 2.1 refiner
* Fix mistake
* Name changes
* Rename
* Whitespace
* Fix
* Fix

Co-authored-by: kijai <40791699+kijai@users.noreply.github.com>
Co-authored-by: Rattus <rattus128@gmail.com>
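The "roll the convolution through temporal" commits describe streaming the VAE's temporal convolutions through time in small chunks, with a FIFO queue carrying context across chunk seams so the result matches a full-sequence pass. Below is a minimal sketch of that idea, assuming a causal temporal conv that expects (kernel_t - 1) frames of left context; the name `roll_conv3d` and the chunk size are illustrative, not the actual vae_refiner API:

```python
import torch

def roll_conv3d(conv, x, chunk=2):
    # x: (B, C, T, H, W). Feed `chunk` frames at a time; a FIFO of the last
    # (kernel_t - 1) input frames bridges each seam, so the chunked output
    # matches convolving the whole sequence at once. Assumes kernel_t > 1 and
    # a conv with no temporal padding of its own.
    pad_t = conv.kernel_size[0] - 1
    fifo = x.new_zeros(x.shape[0], x.shape[1], pad_t, *x.shape[3:])  # causal zero left-pad
    outs = []
    for t in range(0, x.shape[2], chunk):
        piece = x[:, :, t:t + chunk]
        outs.append(conv(torch.cat([fifo, piece], dim=2)))
        fifo = torch.cat([fifo, piece], dim=2)[:, :, -pad_t:]  # keep newest context
    return torch.cat(outs, dim=2)

# conv = torch.nn.Conv3d(8, 8, kernel_size=3, padding=(0, 1, 1))  # spatial-only padding
# roll_conv3d(conv, torch.randn(1, 8, 9, 16, 16)).shape  # -> (1, 8, 9, 16, 16)
```

Bounding each forward pass to `chunk + kernel_t - 1` frames is what keeps peak VRAM low on long videos.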
121 lines
4.0 KiB
Python
import torch
import torch.nn as nn
import torch.nn.functional as F

from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm, ResnetBlock, VideoConv3d
from comfy import model_management, model_patcher


class SRResidualCausalBlock3D(nn.Module):
    def __init__(self, channels: int):
        super().__init__()
        self.block = nn.Sequential(
            VideoConv3d(channels, channels, kernel_size=3),
            nn.SiLU(inplace=True),
            VideoConv3d(channels, channels, kernel_size=3),
            nn.SiLU(inplace=True),
            VideoConv3d(channels, channels, kernel_size=3),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Shape-preserving residual stack of three causal 3D convolutions.
        return x + self.block(x)


class SRModel3DV2(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        hidden_channels: int = 64,
        num_blocks: int = 6,
        global_residual: bool = False,
    ):
        super().__init__()
        self.in_conv = VideoConv3d(in_channels, hidden_channels, kernel_size=3)
        self.blocks = nn.ModuleList([SRResidualCausalBlock3D(hidden_channels) for _ in range(num_blocks)])
        self.out_conv = VideoConv3d(hidden_channels, out_channels, kernel_size=3)
        self.global_residual = bool(global_residual)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        y = self.in_conv(x)
        for blk in self.blocks:
            y = blk(y)
        y = self.out_conv(y)
        # The global skip is only applied when input and output shapes match.
        if self.global_residual and (y.shape == residual.shape):
            y = y + residual
        return y


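# Shape smoke test for SRModel3DV2 (a sketch; the channel counts below are
# illustrative, not taken from a real HunyuanVideo 1.5 checkpoint):
#   m = SRModel3DV2(in_channels=16, out_channels=16, global_residual=True)
#   x = torch.randn(1, 16, 2, 32, 32)   # (B, C, T, H, W)
#   assert m(x).shape == x.shape        # global residual requires matching shapes
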
class Upsampler(nn.Module):
    def __init__(
        self,
        z_channels: int,
        out_channels: int,
        block_out_channels: tuple[int, ...],
        num_res_blocks: int = 2,
    ):
        super().__init__()
        self.num_res_blocks = num_res_blocks
        self.block_out_channels = block_out_channels
        self.z_channels = z_channels

        ch = block_out_channels[0]
        self.conv_in = VideoConv3d(z_channels, ch, kernel_size=3)

        self.up = nn.ModuleList()

        for tgt in block_out_channels:
            stage = nn.Module()
            stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
                                                     out_channels=tgt,
                                                     temb_channels=0,
                                                     conv_shortcut=False,
                                                     conv_op=VideoConv3d, norm_op=RMS_norm)
                                         for j in range(num_res_blocks + 1)])
            ch = tgt
            self.up.append(stage)

        self.norm_out = RMS_norm(ch)
        self.conv_out = VideoConv3d(ch, out_channels, kernel_size=3)
    def forward(self, z):
        """
        Args:
            z: (B, C, T, H, W) latent tensor.
        """
        # Project z to the first block width; a channel-tiled copy of z serves as
        # a skip connection (block_out_channels[0] must be a multiple of z_channels).
        repeats = self.block_out_channels[0] // self.z_channels
        x = self.conv_in(z) + z.repeat_interleave(repeats=repeats, dim=1)

        # Refinement stages.
        for stage in self.up:
            for blk in stage.block:
                x = blk(x)

        out = self.conv_out(F.silu(self.norm_out(x)))
        return out


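# Usage sketch for Upsampler (illustrative config, not from a real checkpoint).
# The stages hold only ResnetBlocks, so spatial/temporal size is preserved here;
# block_out_channels[0] must be a multiple of z_channels for the
# repeat_interleave skip in forward():
#   up = Upsampler(z_channels=32, out_channels=32, block_out_channels=(128, 128))
#   y = up(torch.randn(1, 32, 1, 45, 80))   # -> (1, 32, 1, 45, 80)
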
UPSAMPLERS = {
    "720p": SRModel3DV2,
    "1080p": Upsampler,
}

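# The resolution tier picks the architecture: "720p" maps to the lightweight
# residual CNN (SRModel3DV2), "1080p" to the ResnetBlock-based Upsampler.
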
class HunyuanVideo15SRModel:
    def __init__(self, model_type, config):
        self.load_device = model_management.vae_device()
        offload_device = model_management.vae_offload_device()
        self.dtype = model_management.vae_dtype(self.load_device)
        # model_type is the resolution tier key into UPSAMPLERS ("720p" or "1080p").
        self.model_class = UPSAMPLERS.get(model_type)
        self.model = self.model_class(**config).eval()

        self.patcher = model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
    def load_sd(self, sd):
        return self.model.load_state_dict(sd, strict=True)

    def get_sd(self):
        return self.model.state_dict()
    def resample_latent(self, latent):
        # Make sure the weights are on the load device before running.
        model_management.load_model_gpu(self.patcher)
        return self.model(latent.to(self.load_device))
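# End-to-end usage sketch (hypothetical names: `config` must match the
# checkpoint's keys, `state_dict` is its loaded weights; latents are (B, C, T, H, W)):
#   sr = HunyuanVideo15SRModel("720p", config)
#   sr.load_sd(state_dict)
#   upscaled = sr.resample_latent(latent)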