mirror of https://git.datalinker.icu/comfyanonymous/ComfyUI (synced 2025-12-09 05:54:24 +08:00)
flux: reduce VRAM usage (#10737)
Clean up a bunch of tensors kept alive on the stack in Flux. This takes me from B=19 to B=22 for 1600x1600 on an RTX 5090.
This commit is contained in:
parent 2fde9597f4
commit 94c298f962
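The change applies a simple pattern: drop references to large intermediates as soon as they have been consumed, so PyTorch's caching allocator can reuse their memory for the next allocation instead of growing the peak. The sketch below illustrates the idea with made-up tensor sizes; it is not the Flux code itself and assumes a CUDA device for the peak-memory counters.

import torch

def run(free_early: bool) -> int:
    # Build a small chain of large intermediates and report peak VRAM in bytes.
    torch.cuda.reset_peak_memory_stats()
    x = torch.randn(4096, 4096, device="cuda")   # ~64 MiB in fp32
    y = x @ x
    if free_early:
        del x                                     # freed block can back the next allocation
    z = y * 2
    if free_early:
        del y
    z.sum().item()
    return torch.cuda.max_memory_allocated()

if torch.cuda.is_available():
    print("peak, keep everything:  ", run(free_early=False))
    print("peak, del intermediates:", run(free_early=True))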
@@ -167,39 +167,55 @@ class DoubleStreamBlock(nn.Module):
         img_modulated = self.img_norm1(img)
         img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img)
         img_qkv = self.img_attn.qkv(img_modulated)
+        del img_modulated
         img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+        del img_qkv
         img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

         # prepare txt for attention
         txt_modulated = self.txt_norm1(txt)
         txt_modulated = apply_mod(txt_modulated, (1 + txt_mod1.scale), txt_mod1.shift, modulation_dims_txt)
         txt_qkv = self.txt_attn.qkv(txt_modulated)
+        del txt_modulated
         txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+        del txt_qkv
         txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

         if self.flipped_img_txt:
+            q = torch.cat((img_q, txt_q), dim=2)
+            del img_q, txt_q
+            k = torch.cat((img_k, txt_k), dim=2)
+            del img_k, txt_k
+            v = torch.cat((img_v, txt_v), dim=2)
+            del img_v, txt_v
             # run actual attention
-            attn = attention(torch.cat((img_q, txt_q), dim=2),
-                             torch.cat((img_k, txt_k), dim=2),
-                             torch.cat((img_v, txt_v), dim=2),
+            attn = attention(q, k, v,
                              pe=pe, mask=attn_mask, transformer_options=transformer_options)
+            del q, k, v

             img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
         else:
+            q = torch.cat((txt_q, img_q), dim=2)
+            del txt_q, img_q
+            k = torch.cat((txt_k, img_k), dim=2)
+            del txt_k, img_k
+            v = torch.cat((txt_v, img_v), dim=2)
+            del txt_v, img_v
             # run actual attention
-            attn = attention(torch.cat((txt_q, img_q), dim=2),
-                             torch.cat((txt_k, img_k), dim=2),
-                             torch.cat((txt_v, img_v), dim=2),
+            attn = attention(q, k, v,
                              pe=pe, mask=attn_mask, transformer_options=transformer_options)
+            del q, k, v

             txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]

         # calculate the img bloks
         img += apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
+        del img_attn
         img += apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)

         # calculate the txt bloks
         txt += apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims_txt)
+        del txt_attn
         txt += apply_mod(self.txt_mlp(apply_mod(self.txt_norm2(txt), (1 + txt_mod2.scale), txt_mod2.shift, modulation_dims_txt)), txt_mod2.gate, None, modulation_dims_txt)

         if txt.dtype == torch.float16:
@@ -249,12 +265,15 @@ class SingleStreamBlock(nn.Module):
         qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)

         q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+        del qkv
         q, k = self.norm(q, k, v)

         # compute attention
         attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
+        del q, k, v
         # compute activation in mlp stream, cat again and run second linear layer
-        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
+        mlp = self.mlp_act(mlp)
+        output = self.linear2(torch.cat((attn, mlp), 2))
         x += apply_mod(output, mod.gate, None, modulation_dims)
         if x.dtype == torch.float16:
             x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
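In the SingleStreamBlock hunk, splitting the activation out into mlp = self.mlp_act(mlp) rebinds the name to the activated tensor, which drops the last reference to the pre-activation buffer before linear2 allocates its output; in the old one-liner that buffer stayed alive through the cat and the second linear. A tiny sketch of the rebinding effect, using GELU as a stand-in activation and made-up sizes rather than the Flux tensors, assuming a CUDA device:

import torch
import torch.nn.functional as F

if torch.cuda.is_available():
    mlp = torch.randn(1, 4096, 12288, device="cuda")  # pre-activation buffer, ~192 MiB in fp32
    act = F.gelu(mlp)                                  # keep both: `mlp` still references the old buffer
    print(torch.cuda.memory_allocated())               # roughly 2x the buffer size
    del act
    mlp = F.gelu(mlp)                                  # rebind: the old buffer loses its last reference
    print(torch.cuda.memory_allocated())               # back to roughly 1x; the freed block is reusable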