From 653ceab4148a9fbc050ebceb674acef760792b77 Mon Sep 17 00:00:00 2001 From: rattus128 <46076784+rattus128@users.noreply.github.com> Date: Sun, 28 Sep 2025 08:14:16 +1000 Subject: [PATCH] Reduce Peak WAN inference VRAM usage - part II (#10062) * flux: math: Use _addcmul to avoid expensive VRAM intermediate The rope process can be the VRAM peak and this intermediate for the addition result before releasing the original can OOM. addcmul_ it. * wan: Delete the self attention before cross attention This saves VRAM when the cross attention and FFN are in play as the VRAM peak. --- comfy/ldm/flux/math.py | 5 ++++- comfy/ldm/wan/model.py | 1 + 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py index fb7cd7586..8deda0d4a 100644 --- a/comfy/ldm/flux/math.py +++ b/comfy/ldm/flux/math.py @@ -37,7 +37,10 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor: def apply_rope1(x: Tensor, freqs_cis: Tensor): x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2) - x_out = freqs_cis[..., 0] * x_[..., 0] + freqs_cis[..., 1] * x_[..., 1] + + x_out = freqs_cis[..., 0] * x_[..., 0] + x_out.addcmul_(freqs_cis[..., 1], x_[..., 1]) + return x_out.reshape(*x.shape).type_as(x) def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor): diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 54616e6eb..0dc650ced 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -237,6 +237,7 @@ class WanAttentionBlock(nn.Module): freqs, transformer_options=transformer_options) x = torch.addcmul(x, y, repeat_e(e[2], x)) + del y # cross-attention & ffn x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)