mirror of
https://git.datalinker.icu/comfyanonymous/ComfyUI
synced 2025-12-09 05:54:24 +08:00
WAN: Fix cache VRAM leak on error (#10141)
If this suffers an exception (such as a VRAM oom) it will leave the encode() and decode() methods which skips the cleanup of the WAN feature cache. The comfy node cache then ultimately keeps a reference this object which is in turn reffing large tensors from the failed execution. The feature cache is currently setup at a class variable on the encoder/decoder however, the encode and decode functions always clear it on both entry and exit of normal execution. Its likely the design intent is this is usable as a streaming encoder where the input comes in batches, however the functions as they are today don't support that. So simplify by bringing the cache back to local variable, so that if it does VRAM OOM the cache itself is properly garbage when the encode()/decode() functions dissappear from the stack.
This commit is contained in:
parent
911331c06c
commit
4965c0e2ac
@ -468,55 +468,46 @@ class WanVAE(nn.Module):
|
||||
attn_scales, self.temperal_upsample, dropout)
|
||||
|
||||
def encode(self, x):
|
||||
self.clear_cache()
|
||||
conv_idx = [0]
|
||||
feat_map = [None] * count_conv3d(self.decoder)
|
||||
## cache
|
||||
t = x.shape[2]
|
||||
iter_ = 1 + (t - 1) // 4
|
||||
## 对encode输入的x,按时间拆分为1、4、4、4....
|
||||
for i in range(iter_):
|
||||
self._enc_conv_idx = [0]
|
||||
conv_idx = [0]
|
||||
if i == 0:
|
||||
out = self.encoder(
|
||||
x[:, :, :1, :, :],
|
||||
feat_cache=self._enc_feat_map,
|
||||
feat_idx=self._enc_conv_idx)
|
||||
feat_cache=feat_map,
|
||||
feat_idx=conv_idx)
|
||||
else:
|
||||
out_ = self.encoder(
|
||||
x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
|
||||
feat_cache=self._enc_feat_map,
|
||||
feat_idx=self._enc_conv_idx)
|
||||
feat_cache=feat_map,
|
||||
feat_idx=conv_idx)
|
||||
out = torch.cat([out, out_], 2)
|
||||
mu, log_var = self.conv1(out).chunk(2, dim=1)
|
||||
self.clear_cache()
|
||||
return mu
|
||||
|
||||
def decode(self, z):
|
||||
self.clear_cache()
|
||||
conv_idx = [0]
|
||||
feat_map = [None] * count_conv3d(self.decoder)
|
||||
# z: [b,c,t,h,w]
|
||||
|
||||
iter_ = z.shape[2]
|
||||
x = self.conv2(z)
|
||||
for i in range(iter_):
|
||||
self._conv_idx = [0]
|
||||
conv_idx = [0]
|
||||
if i == 0:
|
||||
out = self.decoder(
|
||||
x[:, :, i:i + 1, :, :],
|
||||
feat_cache=self._feat_map,
|
||||
feat_idx=self._conv_idx)
|
||||
feat_cache=feat_map,
|
||||
feat_idx=conv_idx)
|
||||
else:
|
||||
out_ = self.decoder(
|
||||
x[:, :, i:i + 1, :, :],
|
||||
feat_cache=self._feat_map,
|
||||
feat_idx=self._conv_idx)
|
||||
feat_cache=feat_map,
|
||||
feat_idx=conv_idx)
|
||||
out = torch.cat([out, out_], 2)
|
||||
self.clear_cache()
|
||||
return out
|
||||
|
||||
def clear_cache(self):
|
||||
self._conv_num = count_conv3d(self.decoder)
|
||||
self._conv_idx = [0]
|
||||
self._feat_map = [None] * self._conv_num
|
||||
#cache encode
|
||||
self._enc_conv_num = count_conv3d(self.encoder)
|
||||
self._enc_conv_idx = [0]
|
||||
self._enc_feat_map = [None] * self._enc_conv_num
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user