mirror of
https://git.datalinker.icu/comfyanonymous/ComfyUI
synced 2025-12-11 23:14:25 +08:00
Reduce Peak WAN inference VRAM usage (#9898)
* flux: Do the xq and xk ropes one at a time. The previous code did independent, interleaved tensor math on the q and k tensors, holding more than the minimum number of intermediates in VRAM; on a bad day it would OOM on the xk intermediates. Doing all of q and then all of k lets torch garbage collect q's intermediates before k allocates its own. This reduces peak VRAM usage for at least some WAN2.2 inferences.
* wan: Optimize qkv intermediates in attention. As commented in the code, the former logic computed independent pieces of Q, K and V in parallel, which held more inference intermediates in VRAM and spiked VRAM usage. Fully roping Q and letting its intermediates be garbage collected before touching K reduces peak inference VRAM usage.
This commit is contained in:
parent
a39ac59c3e
commit
e42682b24e
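Not part of the commit, but a minimal sketch of how the claimed reduction could be checked on your own hardware: reset PyTorch's CUDA peak-allocation counter, run one inference (for example a single WAN sampling step) with and without this change, and compare the high-water marks. The helper name and the wrapped call are illustrative.

import torch

def report_peak_vram(fn, *args, **kwargs):
    # Run fn once and report the peak CUDA allocation it caused.
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    out = fn(*args, **kwargs)
    torch.cuda.synchronize()
    peak_mib = torch.cuda.max_memory_allocated() / (1024 ** 2)
    print(f"peak CUDA allocation: {peak_mib:.1f} MiB")
    return out

# e.g. report_peak_vram(run_one_wan_step)  # run_one_wan_step is hypothetical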
comfy/ldm/flux/math.py
@@ -35,11 +35,10 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
     out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
     return out.to(dtype=torch.float32, device=pos.device)
 
+def apply_rope1(x: Tensor, freqs_cis: Tensor):
+    x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
+    x_out = freqs_cis[..., 0] * x_[..., 0] + freqs_cis[..., 1] * x_[..., 1]
+    return x_out.reshape(*x.shape).type_as(x)
+
 def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
-    xq_ = xq.to(dtype=freqs_cis.dtype).reshape(*xq.shape[:-1], -1, 1, 2)
-    xk_ = xk.to(dtype=freqs_cis.dtype).reshape(*xk.shape[:-1], -1, 1, 2)
-    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
-    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
-    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
+    return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
 
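Not part of the commit, but a quick way to convince yourself the refactor only changes allocation order: the split apply_rope1 produces exactly the same values as the old interleaved apply_rope. A minimal, self-contained sketch with illustrative shapes (batch 1, 8 heads, 16 tokens, head_dim 64):

import torch
from torch import Tensor

def apply_rope1(x: Tensor, freqs_cis: Tensor):
    # new single-tensor rope from this commit
    x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
    x_out = freqs_cis[..., 0] * x_[..., 0] + freqs_cis[..., 1] * x_[..., 1]
    return x_out.reshape(*x.shape).type_as(x)

def apply_rope_old(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
    # previous interleaved version, reproduced here only for comparison
    xq_ = xq.to(dtype=freqs_cis.dtype).reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.to(dtype=freqs_cis.dtype).reshape(*xk.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)

xq = torch.randn(1, 8, 16, 64)
xk = torch.randn(1, 8, 16, 64)
freqs_cis = torch.randn(1, 1, 16, 32, 2, 2)  # illustrative rope table shape

q_old, k_old = apply_rope_old(xq, xk, freqs_cis)
assert torch.allclose(q_old, apply_rope1(xq, freqs_cis))
assert torch.allclose(k_old, apply_rope1(xk, freqs_cis))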
comfy/ldm/wan/model.py
@@ -8,7 +8,7 @@ from einops import rearrange
 
 from comfy.ldm.modules.attention import optimized_attention
 from comfy.ldm.flux.layers import EmbedND
-from comfy.ldm.flux.math import apply_rope
+from comfy.ldm.flux.math import apply_rope1
 import comfy.ldm.common_dit
 import comfy.model_management
 import comfy.patcher_extension
@@ -60,20 +60,24 @@ class WanSelfAttention(nn.Module):
         """
         b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
 
-        # query, key, value function
-        def qkv_fn(x):
+        def qkv_fn_q(x):
             q = self.norm_q(self.q(x)).view(b, s, n, d)
-            k = self.norm_k(self.k(x)).view(b, s, n, d)
-            v = self.v(x).view(b, s, n * d)
-            return q, k, v
+            return apply_rope1(q, freqs)
 
-        q, k, v = qkv_fn(x)
-        q, k = apply_rope(q, k, freqs)
+        def qkv_fn_k(x):
+            k = self.norm_k(self.k(x)).view(b, s, n, d)
+            return apply_rope1(k, freqs)
+
+        #These two are VRAM hogs, so we want to do all of q computation and
+        #have pytorch garbage collect the intermediates on the sub function
+        #return before we touch k
+        q = qkv_fn_q(x)
+        k = qkv_fn_k(x)
 
         x = optimized_attention(
             q.view(b, s, n * d),
             k.view(b, s, n * d),
-            v,
+            self.v(x).view(b, s, n * d),
             heads=self.num_heads,
             transformer_options=transformer_options,
         )
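A toy illustration (not from the repository) of why the q-then-k ordering lowers the peak: temporaries created inside a helper are released when it returns, so on CUDA the caching allocator can hand that memory to the next call, whereas computing both sets of intermediates side by side keeps them all alive at once. The helper and tensor shapes below are made up for the sketch.

import torch

def rope_half(x, table):
    # The upcast and multiplies create temporaries that exist only inside this
    # frame; once the function returns, nothing references them and their
    # memory can be reused (on CUDA, by the caching allocator) by the next call.
    x32 = x.float()
    out = table[..., 0] * x32 + table[..., 1] * x32
    return out.to(x.dtype)

x_q = torch.randn(1, 8, 16, 64, dtype=torch.float16)
x_k = torch.randn(1, 8, 16, 64, dtype=torch.float16)
table = torch.randn(16, 64, 2)

# Sequential calls: q's temporaries are gone before k's are allocated,
# mirroring the qkv_fn_q / qkv_fn_k split in the diff above.
q = rope_half(x_q, table)
k = rope_half(x_k, table)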