From e42682b24ef033a93001ba27cc5c5aa461a61d8d Mon Sep 17 00:00:00 2001
From: rattus128 <46076784+rattus128@users.noreply.github.com>
Date: Wed, 17 Sep 2025 09:21:14 +1000
Subject: [PATCH] Reduce Peak WAN inference VRAM usage (#9898)

* flux: Do the xq and xk ropes one at a time

This was doing independent interleaved tensor math on the q and k
tensors, holding more than the minimum number of intermediates in VRAM.
On a bad day, it would VRAM-OOM on the xk intermediates.

Do everything for q and then everything for k, so torch can garbage
collect all of q's intermediates before k allocates its own.

This reduces peak VRAM usage for some WAN2.2 inferences (at least).

* wan: Optimize qkv intermediates on attention

As commented. The former logic computed independent pieces of QKV in
parallel, which held more inference intermediates in VRAM at once and
spiked VRAM usage.

Fully roping Q and garbage collecting the intermediates before touching
K reduces the peak inference VRAM usage.
---
 comfy/ldm/flux/math.py | 11 +++++------
 comfy/ldm/wan/model.py | 22 +++++++++++++---------
 2 files changed, 18 insertions(+), 15 deletions(-)

diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py
index 4d743cda2..fb7cd7586 100644
--- a/comfy/ldm/flux/math.py
+++ b/comfy/ldm/flux/math.py
@@ -35,11 +35,10 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
     out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
     return out.to(dtype=torch.float32, device=pos.device)
 
+def apply_rope1(x: Tensor, freqs_cis: Tensor):
+    x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
+    x_out = freqs_cis[..., 0] * x_[..., 0] + freqs_cis[..., 1] * x_[..., 1]
+    return x_out.reshape(*x.shape).type_as(x)
 
 def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
-    xq_ = xq.to(dtype=freqs_cis.dtype).reshape(*xq.shape[:-1], -1, 1, 2)
-    xk_ = xk.to(dtype=freqs_cis.dtype).reshape(*xk.shape[:-1], -1, 1, 2)
-    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
-    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
-    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
-
+    return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py
index 63472ada2..67dcf8f1e 100644
--- a/comfy/ldm/wan/model.py
+++ b/comfy/ldm/wan/model.py
@@ -8,7 +8,7 @@ from einops import rearrange
 from comfy.ldm.modules.attention import optimized_attention
 from comfy.ldm.flux.layers import EmbedND
-from comfy.ldm.flux.math import apply_rope
+from comfy.ldm.flux.math import apply_rope1
 import comfy.ldm.common_dit
 import comfy.model_management
 import comfy.patcher_extension
@@ -60,20 +60,24 @@ class WanSelfAttention(nn.Module):
         """
         b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
 
-        # query, key, value function
-        def qkv_fn(x):
+        def qkv_fn_q(x):
             q = self.norm_q(self.q(x)).view(b, s, n, d)
-            k = self.norm_k(self.k(x)).view(b, s, n, d)
-            v = self.v(x).view(b, s, n * d)
-            return q, k, v
+            return apply_rope1(q, freqs)
 
-        q, k, v = qkv_fn(x)
-        q, k = apply_rope(q, k, freqs)
+        def qkv_fn_k(x):
+            k = self.norm_k(self.k(x)).view(b, s, n, d)
+            return apply_rope1(k, freqs)
+
+        # These two are VRAM hogs, so we want to do all of q's computation and
+        # have pytorch garbage collect the intermediates on the sub-function
+        # return before we touch k.
+        q = qkv_fn_q(x)
+        k = qkv_fn_k(x)
 
         x = optimized_attention(
             q.view(b, s, n * d),
             k.view(b, s, n * d),
-            v,
+            self.v(x).view(b, s, n * d),
             heads=self.num_heads,
             transformer_options=transformer_options,
         )
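
Below is a minimal, self-contained sketch (not part of the patch) of the allocation
pattern the commit messages describe: producing both q's and k's large upcast
intermediates before finishing either keeps two big temporaries alive at once, while
finishing q first lets PyTorch's caching allocator reuse that memory before k's
temporary is created. The helper names (big_intermediate, finish, peak_mib), the toy
math, and the tensor shape are illustrative only; it assumes a CUDA device and reads
the allocator's peak via torch.cuda.max_memory_allocated.

import torch

def big_intermediate(x):
    # Stand-in for the fp32 upcast/reshape that the rope math performs:
    # allocates a large float32 temporary.
    return x.to(torch.float32)

def finish(tmp, x):
    # Stand-in for the rope multiply plus downcast back to the input dtype.
    return (tmp * 2.0).to(x.dtype)

def interleaved(q, k):
    # Old shape of the code: both fp32 temporaries are alive at the same time
    # before either result is finished.
    q_, k_ = big_intermediate(q), big_intermediate(k)
    return finish(q_, q), finish(k_, k)

def sequential(q, k):
    # New shape of the code: q's temporary goes out of scope (and can be
    # reused by the caching allocator) before k's temporary is allocated.
    q_out = finish(big_intermediate(q), q)
    k_out = finish(big_intermediate(k), k)
    return q_out, k_out

def peak_mib(fn, *args):
    torch.cuda.reset_peak_memory_stats()
    fn(*args)
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / (1024 ** 2)

if __name__ == "__main__" and torch.cuda.is_available():
    # Arbitrary [b, s, n, d] half-precision tensors, sized only so the
    # difference is visible in the allocator statistics.
    q = torch.randn(1, 8192, 40, 128, dtype=torch.float16, device="cuda")
    k = torch.randn_like(q)
    print("interleaved peak:", peak_mib(interleaved, q, k), "MiB")
    print("sequential  peak:", peak_mib(sequential, q, k), "MiB")

The same reasoning drives the patch: q is fully roped inside qkv_fn_q before
qkv_fn_k runs, so q's rope intermediates fall out of scope on the sub-function
return and never coexist with k's.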