mirror of
https://git.datalinker.icu/comfyanonymous/ComfyUI
synced 2025-12-09 14:04:26 +08:00
If this suffers an exception (such as a VRAM OOM), execution leaves the encode() and decode() methods early, which skips the cleanup of the WAN feature cache. The comfy node cache then ultimately keeps a reference to this object, which in turn references large tensors from the failed execution. The feature cache is currently set up as a class variable on the encoder/decoder; however, the encode and decode functions always clear it on both entry and exit of normal execution. It's likely the design intent was for this to be usable as a streaming encoder where the input comes in batches, but the functions as they are today don't support that. So simplify by making the cache a local variable again, so that if a VRAM OOM does occur, the cache itself becomes garbage once the encode()/decode() frames disappear from the stack.
514 lines
18 KiB
Python
514 lines
18 KiB
Python
# original version: https://github.com/Wan-Video/Wan2.1/blob/main/wan/modules/vae.py
|
||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||
|
||
import torch
|
||
import torch.nn as nn
|
||
import torch.nn.functional as F
|
||
from einops import rearrange
|
||
from comfy.ldm.modules.diffusionmodules.model import vae_attention
|
||
|
||
import comfy.ops
|
||
# Provides the Conv2d / Conv3d layer classes used by the modules below.
ops = comfy.ops.disable_weight_init
|
||
|
||
# Number of trailing frames (along dim 2) saved as causal-conv cache between chunks.
CACHE_T = 2
|
||
|
||
|
||
class CausalConv3d(ops.Conv3d):
    """
    Causal 3d convolution.

    The padding configured on the base convolution is captured and then
    disabled; ``forward`` pads the input explicitly instead: symmetrically on
    the last two axes and only at the front of axis 2, so an output position
    never depends on later positions along that axis.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        pad_t = self.padding[0]
        pad_h = self.padding[1]
        pad_w = self.padding[2]
        # F.pad ordering: last axis first, (before, after) for each axis.
        self._padding = (pad_w, pad_w, pad_h, pad_h, 2 * pad_t, 0)
        # The base conv must not pad again.
        self.padding = (0, 0, 0)

    def forward(self, x, cache_x=None, cache_list=None, cache_idx=None):
        """
        Apply the convolution, optionally prepending cached frames.

        Args:
            x: input tensor; frames are concatenated along dim 2.
            cache_x: frames to prepend in place of (part of) the front padding.
            cache_list: when given, the cache is taken from
                ``cache_list[cache_idx]`` (overriding ``cache_x``) and that
                slot is set to ``None``.
            cache_idx: index into ``cache_list``.
        """
        if cache_list is not None:
            # Take ownership of the cached frames and empty the slot.
            cache_x, cache_list[cache_idx] = cache_list[cache_idx], None

        pad = list(self._padding)
        has_front_pad = self._padding[4] > 0
        if cache_x is not None and has_front_pad:
            cache_x = cache_x.to(x.device)
            x = torch.cat([cache_x, x], dim=2)
            # Cached frames stand in for that many frames of front padding.
            pad[4] -= cache_x.shape[2]
            del cache_x

        return super().forward(F.pad(x, pad))
|
||
|
||
|
||
class RMS_norm(nn.Module):
    """
    L2-normalise along the channel axis, then rescale.

    The output is ``normalize(x) * sqrt(dim) * gamma (+ bias)``.

    Args:
        dim: number of channels.
        channel_first: normalise along dim 1 if True, else along the last dim.
        images: with ``channel_first``, parameters get two trailing singleton
            dims if True and three if False.
        bias: add a learnable bias when True.
    """

    def __init__(self, dim, channel_first=True, images=True, bias=False):
        super().__init__()
        if channel_first:
            trailing_ones = (1, 1) if images else (1, 1, 1)
            shape = (dim,) + trailing_ones
        else:
            shape = (dim,)

        self.channel_first = channel_first
        self.scale = dim**0.5
        self.gamma = nn.Parameter(torch.ones(shape))
        self.bias = nn.Parameter(torch.zeros(shape)) if bias else None

    def forward(self, x):
        axis = 1 if self.channel_first else -1
        unit = F.normalize(x, dim=axis)
        # Parameters are cast to match x's dtype/device before use.
        shift = self.bias.to(x) if self.bias is not None else 0
        return unit * self.scale * self.gamma.to(x) + shift
|
||
|
||
|
||
class Resample(nn.Module):
    """
    Spatial resampling with optional causal temporal resampling.

    Modes:
        'none'         - identity.
        'upsample2d'   - 2x nearest-exact upsample + 3x3 conv (dim -> dim // 2).
        'upsample3d'   - as 'upsample2d', preceded (when a feature cache is
                         supplied) by a causal time conv that doubles the
                         number of frames.
        'downsample2d' - zero-pad + stride-2 3x3 conv.
        'downsample3d' - as 'downsample2d', followed (when a feature cache is
                         supplied) by a stride-2 causal time conv.
    """

    def __init__(self, dim, mode):
        assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
                        'downsample3d')
        super().__init__()
        self.dim = dim
        self.mode = mode

        # layers
        if mode == 'upsample2d':
            self.resample = nn.Sequential(
                nn.Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
                ops.Conv2d(dim, dim // 2, 3, padding=1))
        elif mode == 'upsample3d':
            self.resample = nn.Sequential(
                nn.Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
                ops.Conv2d(dim, dim // 2, 3, padding=1))
            self.time_conv = CausalConv3d(
                dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))

        elif mode == 'downsample2d':
            self.resample = nn.Sequential(
                nn.ZeroPad2d((0, 1, 0, 1)),
                ops.Conv2d(dim, dim, 3, stride=(2, 2)))
        elif mode == 'downsample3d':
            self.resample = nn.Sequential(
                nn.ZeroPad2d((0, 1, 0, 1)),
                ops.Conv2d(dim, dim, 3, stride=(2, 2)))
            self.time_conv = CausalConv3d(
                dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))

        else:
            self.resample = nn.Identity()

    def forward(self, x, feat_cache=None, feat_idx=None):
        """
        Resample one chunk of frames.

        Args:
            x: tensor of shape (b, c, t, h, w).
            feat_cache: list of per-conv cache slots, mutated in place; when
                ``None`` no temporal resampling is performed.
            feat_idx: single-element list holding the index of the next cache
                slot; incremented in place for every slot consumed. A fresh
                ``[0]`` is used when omitted.
        """
        if feat_idx is None:
            # A shared ``[0]`` default would keep its increments across calls.
            feat_idx = [0]

        b, c, t, h, w = x.size()
        if self.mode == 'upsample3d' and feat_cache is not None:
            idx = feat_idx[0]
            if feat_cache[idx] is None:
                # First chunk: skip the time conv, just mark the slot.
                feat_cache[idx] = 'Rep'
                feat_idx[0] += 1
            else:
                # Test the marker by type so a tensor is never compared
                # against a string.
                is_rep = isinstance(feat_cache[idx],
                                    str) and feat_cache[idx] == 'Rep'

                cache_x = x[:, :, -CACHE_T:, :, :].clone()
                if cache_x.shape[2] < 2:
                    if is_rep:
                        # No earlier frame is available: pad with zeros.
                        lead = torch.zeros_like(cache_x).to(cache_x.device)
                    else:
                        # cache last frame of last two chunk
                        lead = feat_cache[idx][:, :, -1, :, :].unsqueeze(
                            2).to(cache_x.device)
                    cache_x = torch.cat([lead, cache_x], dim=2)

                if is_rep:
                    x = self.time_conv(x)
                else:
                    x = self.time_conv(x, feat_cache[idx])
                feat_cache[idx] = cache_x
                feat_idx[0] += 1

                # time_conv emits 2*c channels; interleave the two halves
                # along time to double the frame count.
                x = x.reshape(b, 2, c, t, h, w)
                x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
                                3)
                x = x.reshape(b, c, t * 2, h, w)

        # Apply the 2d resampling frame by frame.
        t = x.shape[2]
        x = rearrange(x, 'b c t h w -> (b t) c h w')
        x = self.resample(x)
        x = rearrange(x, '(b t) c h w -> b c t h w', t=t)

        if self.mode == 'downsample3d' and feat_cache is not None:
            idx = feat_idx[0]
            if feat_cache[idx] is None:
                # First chunk: only remember it, no temporal downsampling.
                feat_cache[idx] = x.clone()
                feat_idx[0] += 1
            else:
                cache_x = x[:, :, -1:, :, :].clone()
                # Prepend the last cached frame before the strided time conv.
                x = self.time_conv(
                    torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
                feat_cache[idx] = cache_x
                feat_idx[0] += 1
        return x
|
||
|
||
|
||
class ResidualBlock(nn.Module):
    """
    Residual block built from two causal 3x3x3 convolutions.

    Layout: RMS_norm -> SiLU -> conv -> RMS_norm -> SiLU -> Dropout -> conv,
    added to a shortcut (a 1x1x1 causal conv when ``in_dim != out_dim``,
    identity otherwise).
    """

    def __init__(self, in_dim, out_dim, dropout=0.0):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        # layers
        self.residual = nn.Sequential(
            RMS_norm(in_dim, images=False), nn.SiLU(),
            CausalConv3d(in_dim, out_dim, 3, padding=1),
            RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
            CausalConv3d(out_dim, out_dim, 3, padding=1))
        self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
            if in_dim != out_dim else nn.Identity()

    def forward(self, x, feat_cache=None, feat_idx=None):
        """
        Run the block on one chunk of frames.

        Args:
            x: input tensor; frames are along dim 2.
            feat_cache: list of per-conv cache slots, mutated in place; when
                ``None`` the convolutions run without cached frames.
            feat_idx: single-element list holding the index of the next cache
                slot; incremented once per cached convolution. A fresh ``[0]``
                is used when omitted.
        """
        if feat_idx is None:
            # A shared ``[0]`` default would keep its increments across calls.
            feat_idx = [0]

        old_x = x
        for layer in self.residual:
            if isinstance(layer, CausalConv3d) and feat_cache is not None:
                idx = feat_idx[0]
                # Build the next cache BEFORE calling the conv: the conv
                # empties feat_cache[idx] when it consumes it.
                cache_x = x[:, :, -CACHE_T:, :, :].clone()
                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                    # cache last frame of last two chunk
                    cache_x = torch.cat([
                        feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                            cache_x.device), cache_x
                    ],
                                        dim=2)
                x = layer(x, cache_list=feat_cache, cache_idx=idx)
                feat_cache[idx] = cache_x
                feat_idx[0] += 1
            else:
                x = layer(x)
        return x + self.shortcut(old_x)
|
||
|
||
|
||
class AttentionBlock(nn.Module):
    """
    Causal self-attention with a single head.

    Every frame is processed independently: the time axis is folded into the
    batch axis before the attention call and unfolded afterwards.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

        # layers
        self.norm = RMS_norm(dim)
        self.to_qkv = ops.Conv2d(dim, dim * 3, 1)
        self.proj = ops.Conv2d(dim, dim, 1)
        self.optimized_attention = vae_attention()

    def forward(self, x):
        """Return ``x`` plus the attention output; ``x`` is (b, c, t, h, w)."""
        residual = x
        _, _, frames, _, _ = x.size()

        # Fold time into batch, then normalise.
        flat = self.norm(rearrange(x, 'b c t h w -> (b t) c h w'))

        # compute query, key, value
        query, key, value = self.to_qkv(flat).chunk(3, dim=1)
        attended = self.optimized_attention(query, key, value)

        # output
        projected = self.proj(attended)
        restored = rearrange(projected, '(b t) c h w-> b c t h w', t=frames)
        return restored + residual
|
||
|
||
|
||
class Encoder3d(nn.Module):
    """
    Causal 3d encoder: conv1 -> downsample stages -> middle -> head.

    Args:
        dim: base channel count.
        z_dim: number of output channels of the head convolution.
        dim_mult: per-stage channel multipliers.
        num_res_blocks: residual blocks per stage.
        attn_scales: scales (1.0, 0.5, ...) at which an AttentionBlock follows
            each residual block.
        temperal_downsample: per-stage flag selecting 'downsample3d' over
            'downsample2d'.
        dropout: dropout probability inside the residual blocks.
    """

    def __init__(self,
                 dim=128,
                 z_dim=4,
                 dim_mult=[1, 2, 4, 4],
                 num_res_blocks=2,
                 attn_scales=[],
                 temperal_downsample=[True, True, False],
                 dropout=0.0):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_downsample = temperal_downsample

        # dimensions
        dims = [dim * u for u in [1] + dim_mult]
        scale = 1.0

        # init block
        self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)

        # downsample blocks
        downsamples = []
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            # residual (+attention) blocks
            for _ in range(num_res_blocks):
                downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
                if scale in attn_scales:
                    downsamples.append(AttentionBlock(out_dim))
                in_dim = out_dim

            # downsample block
            if i != len(dim_mult) - 1:
                mode = 'downsample3d' if temperal_downsample[
                    i] else 'downsample2d'
                downsamples.append(Resample(out_dim, mode=mode))
                scale /= 2.0
        self.downsamples = nn.Sequential(*downsamples)

        # middle blocks
        self.middle = nn.Sequential(
            ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim),
            ResidualBlock(out_dim, out_dim, dropout))

        # output blocks
        self.head = nn.Sequential(
            RMS_norm(out_dim, images=False), nn.SiLU(),
            CausalConv3d(out_dim, z_dim, 3, padding=1))

    @staticmethod
    def _cached_conv(conv, x, feat_cache, feat_idx):
        """Run ``conv`` with its cache slot, then refresh the slot and advance ``feat_idx``."""
        idx = feat_idx[0]
        cache_x = x[:, :, -CACHE_T:, :, :].clone()
        if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
            # cache last frame of last two chunk
            cache_x = torch.cat([
                feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                    cache_x.device), cache_x
            ],
                                dim=2)
        x = conv(x, feat_cache[idx])
        feat_cache[idx] = cache_x
        feat_idx[0] += 1
        return x

    def forward(self, x, feat_cache=None, feat_idx=None):
        """
        Encode one chunk of frames.

        Args:
            x: input tensor; frames are along dim 2.
            feat_cache: list of per-conv cache slots, mutated in place; when
                ``None`` all layers run without cached frames.
            feat_idx: single-element list holding the index of the next cache
                slot; incremented in place. A fresh ``[0]`` is used when
                omitted.
        """
        if feat_idx is None:
            # A shared ``[0]`` default would keep its increments across calls.
            feat_idx = [0]
        use_cache = feat_cache is not None

        if use_cache:
            x = self._cached_conv(self.conv1, x, feat_cache, feat_idx)
        else:
            x = self.conv1(x)

        ## downsamples
        for layer in self.downsamples:
            # AttentionBlock.forward only accepts x, so the cache is handed
            # to the cache-aware layer types only.
            if use_cache and isinstance(layer, (ResidualBlock, Resample)):
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## middle
        for layer in self.middle:
            if isinstance(layer, ResidualBlock) and use_cache:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## head
        for layer in self.head:
            if isinstance(layer, CausalConv3d) and use_cache:
                x = self._cached_conv(layer, x, feat_cache, feat_idx)
            else:
                x = layer(x)
        return x
|
||
|
||
|
||
class Decoder3d(nn.Module):
    """
    Causal 3d decoder: conv1 -> middle -> upsample stages -> head.

    Args:
        dim: base channel count.
        z_dim: number of input (latent) channels of conv1.
        dim_mult: per-stage channel multipliers (applied in reverse order).
        num_res_blocks: each stage has ``num_res_blocks + 1`` residual blocks.
        attn_scales: scales at which an AttentionBlock follows each residual
            block.
        temperal_upsample: per-stage flag selecting 'upsample3d' over
            'upsample2d'.
        dropout: dropout probability inside the residual blocks.
    """

    def __init__(self,
                 dim=128,
                 z_dim=4,
                 dim_mult=[1, 2, 4, 4],
                 num_res_blocks=2,
                 attn_scales=[],
                 temperal_upsample=[False, True, True],
                 dropout=0.0):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_upsample = temperal_upsample

        # dimensions
        dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
        scale = 1.0 / 2**(len(dim_mult) - 2)

        # init block
        self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)

        # middle blocks
        self.middle = nn.Sequential(
            ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]),
            ResidualBlock(dims[0], dims[0], dropout))

        # upsample blocks
        upsamples = []
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            # residual (+attention) blocks
            if i in (1, 2, 3):
                # The preceding Resample conv halves the channel count.
                in_dim = in_dim // 2
            for _ in range(num_res_blocks + 1):
                upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
                if scale in attn_scales:
                    upsamples.append(AttentionBlock(out_dim))
                in_dim = out_dim

            # upsample block
            if i != len(dim_mult) - 1:
                mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
                upsamples.append(Resample(out_dim, mode=mode))
                scale *= 2.0
        self.upsamples = nn.Sequential(*upsamples)

        # output blocks
        self.head = nn.Sequential(
            RMS_norm(out_dim, images=False), nn.SiLU(),
            CausalConv3d(out_dim, 3, 3, padding=1))

    @staticmethod
    def _cached_conv(conv, x, feat_cache, feat_idx):
        """Run ``conv`` with its cache slot, then refresh the slot and advance ``feat_idx``."""
        idx = feat_idx[0]
        cache_x = x[:, :, -CACHE_T:, :, :].clone()
        if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
            # cache last frame of last two chunk
            cache_x = torch.cat([
                feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                    cache_x.device), cache_x
            ],
                                dim=2)
        x = conv(x, feat_cache[idx])
        feat_cache[idx] = cache_x
        feat_idx[0] += 1
        return x

    def forward(self, x, feat_cache=None, feat_idx=None):
        """
        Decode one chunk of latent frames.

        Args:
            x: input tensor; frames are along dim 2.
            feat_cache: list of per-conv cache slots, mutated in place; when
                ``None`` all layers run without cached frames.
            feat_idx: single-element list holding the index of the next cache
                slot; incremented in place. A fresh ``[0]`` is used when
                omitted.
        """
        if feat_idx is None:
            # A shared ``[0]`` default would keep its increments across calls.
            feat_idx = [0]
        use_cache = feat_cache is not None

        ## conv1
        if use_cache:
            x = self._cached_conv(self.conv1, x, feat_cache, feat_idx)
        else:
            x = self.conv1(x)

        ## middle
        for layer in self.middle:
            if isinstance(layer, ResidualBlock) and use_cache:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## upsamples
        for layer in self.upsamples:
            # AttentionBlock.forward only accepts x, so the cache is handed
            # to the cache-aware layer types only.
            if use_cache and isinstance(layer, (ResidualBlock, Resample)):
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## head
        for layer in self.head:
            if isinstance(layer, CausalConv3d) and use_cache:
                x = self._cached_conv(layer, x, feat_cache, feat_idx)
            else:
                x = layer(x)
        return x
|
||
|
||
|
||
def count_conv3d(model):
    """Return how many CausalConv3d modules ``model.modules()`` yields."""
    return sum(1 for module in model.modules()
               if isinstance(module, CausalConv3d))
|
||
|
||
|
||
class WanVAE(nn.Module):
    """
    Video VAE that encodes and decodes in temporal chunks.

    A per-call feature cache carries trailing frames from one chunk to the
    next for the causal convolutions. The cache is a local variable of
    ``encode`` / ``decode``, so it is released when those calls exit,
    including when they exit through an exception.

    Args:
        dim, dim_mult, num_res_blocks, attn_scales, dropout: forwarded to
            both Encoder3d and Decoder3d.
        z_dim: latent channel count; the encoder emits ``2 * z_dim`` channels.
        temperal_downsample: per-stage temporal downsampling flags for the
            encoder; the decoder receives the reversed list.
    """

    def __init__(self,
                 dim=128,
                 z_dim=4,
                 dim_mult=[1, 2, 4, 4],
                 num_res_blocks=2,
                 attn_scales=[],
                 temperal_downsample=[True, True, False],
                 dropout=0.0):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_downsample = temperal_downsample
        self.temperal_upsample = temperal_downsample[::-1]

        # modules
        self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
                                 attn_scales, self.temperal_downsample, dropout)
        self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
        self.conv2 = CausalConv3d(z_dim, z_dim, 1)
        self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
                                 attn_scales, self.temperal_upsample, dropout)

    def encode(self, x):
        """
        Encode ``x`` (frames along dim 2) and return the first half of the
        channels produced by ``conv1`` (``mu``).
        """
        # One slot per causal conv in the encoder (the module actually run
        # here), rather than borrowing the decoder's count.
        feat_map = [None] * count_conv3d(self.encoder)

        t = x.shape[2]
        iter_ = 1 + (t - 1) // 4

        # Split the input along time into chunks of 1, 4, 4, 4, ... frames.
        chunks = []
        for i in range(iter_):
            # Every chunk walks the cache slots from the start.
            conv_idx = [0]
            if i == 0:
                frames = x[:, :, :1, :, :]
            else:
                frames = x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :]
            chunks.append(
                self.encoder(frames, feat_cache=feat_map, feat_idx=conv_idx))

        # Concatenate once instead of re-copying the running result per chunk.
        out = chunks[0] if len(chunks) == 1 else torch.cat(chunks, 2)
        mu, _ = self.conv1(out).chunk(2, dim=1)
        return mu

    def decode(self, z):
        """
        Decode latent ``z`` of shape [b, c, t, h, w] one latent frame at a
        time and return the outputs concatenated along dim 2.
        """
        feat_map = [None] * count_conv3d(self.decoder)

        iter_ = z.shape[2]
        x = self.conv2(z)

        chunks = []
        for i in range(iter_):
            # Every chunk walks the cache slots from the start.
            conv_idx = [0]
            chunks.append(
                self.decoder(
                    x[:, :, i:i + 1, :, :],
                    feat_cache=feat_map,
                    feat_idx=conv_idx))

        # Concatenate once instead of re-copying the running result per chunk.
        return chunks[0] if len(chunks) == 1 else torch.cat(chunks, 2)
|