* flux: math: Use addcmul_ to avoid an expensive VRAM intermediate. The RoPE step can be the VRAM peak, and allocating a separate tensor for the addition result before the original is released can OOM. Fuse the multiply-add with addcmul_ instead.
* wan: Delete the self-attention tensor before the cross attention runs. This saves VRAM when the cross attention and FFN form the VRAM peak.
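A minimal sketch of the fused update (illustrative tensors, not from the repo): Tensor.addcmul_(t1, t2) computes self += t1 * t2 in place, so the naive a * b + c * d, which holds two full-size products alive at once, shrinks to one product plus an in-place multiply-add.

import torch

a, b, c, d = (torch.randn(2, 4096, 64) for _ in range(4))

# Naive: c * d materializes as a second full-size temporary next to a * b.
out = a * b + c * d

# Fused: one temporary for a * b, then an in-place multiply-add of c * d.
out_fused = a * b
out_fused.addcmul_(c, d)

# The two forms agree numerically; only the allocation pattern differs.
assert torch.allclose(out, out_fused)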
48 lines · 1.9 KiB · Python
import torch
from einops import rearrange
from torch import Tensor

from comfy.ldm.modules.attention import optimized_attention
import comfy.model_management
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor:
    q_shape = q.shape
    k_shape = k.shape

    if pe is not None:
        # Apply rotary embeddings: view the head dim as (dim//2, 1, 2) pairs,
        # then rotate each pair by the per-position 2x2 matrices in pe.
        q = q.to(dtype=pe.dtype).reshape(*q.shape[:-1], -1, 1, 2)
        k = k.to(dtype=pe.dtype).reshape(*k.shape[:-1], -1, 1, 2)
        q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v)
        k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)

    heads = q.shape[1]
    x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options)
    return x

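# Usage sketch (shapes assumed from skip_reshape=True, not stated in the file):
#   q, k, v: (batch, heads, seq_len, head_dim); pe is built from rope() below
#   (callers typically add a broadcast dim for heads).
#   attention(q, k, v, pe) returns (batch, seq_len, heads * head_dim).
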
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
    assert dim % 2 == 0
    # Some backends can't handle the float64 math below, so compute on CPU
    # and move the result back to pos.device at the end.
    if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu() or comfy.model_management.is_directml_enabled():
        device = torch.device("cpu")
    else:
        device = pos.device

    # Per-pair frequencies: omega_i = theta^(-2i/dim) for i in [0, dim/2).
    scale = torch.linspace(0, (dim - 2) / dim, steps=dim // 2, dtype=torch.float64, device=device)
    omega = 1.0 / (theta**scale)
    # Outer product of positions and frequencies, then 2x2 rotation matrices.
    out = torch.einsum("...n,d->...nd", pos.to(dtype=torch.float32, device=device), omega)
    out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
    out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
    return out.to(dtype=torch.float32, device=pos.device)

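# Worked example (a sketch; batch and sequence sizes are illustrative):
#   pos = torch.arange(4096, dtype=torch.float32)[None]  # (b=1, n=4096)
#   rope(pos, dim=64, theta=10000)  # -> (1, 4096, 32, 2, 2) float32,
#   i.e. one 2x2 rotation matrix per position per frequency pair.
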
def apply_rope1(x: Tensor, freqs_cis: Tensor):
    # View the last dim as (dim//2, 1, 2) pairs, matching rope()'s output.
    x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)

    x_out = freqs_cis[..., 0] * x_[..., 0]
    # In-place fused multiply-add: x_out += freqs_cis[..., 1] * x_[..., 1].
    # Avoids allocating a second full-size product tensor at the VRAM peak.
    x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])

    return x_out.reshape(*x.shape).type_as(x)

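# Note: this is the same per-pair rotation as the inline pe branch in
# attention() above; only the allocation pattern differs (see commit notes).
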
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
    # Rotate queries and keys with the same precomputed frequencies.
    return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
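
# End-to-end sketch (illustrative; the head broadcast dim on pe mirrors what
# upstream callers add, and all shapes here are assumptions):
#   pos: (b, n) positions; q, k, v: (b, heads, n, head_dim)
#   pe = rope(pos, head_dim, 10000)[:, None]  # (b, 1, n, head_dim//2, 2, 2)
#   q, k = apply_rope(q, k, pe)
#   out = attention(q, k, v, pe=None)         # pe already applied to q and k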