mirror of https://git.datalinker.icu/comfyanonymous/ComfyUI, synced 2025-12-16 01:25:08 +08:00
* Looking into a @wrap_attn decorator to look for 'optimized_attention_override' entry in transformer_options (a sketch of the idea follows this list)
* Created logging code for this branch so that it can be used to track down all the code paths where transformer_options would need to be added
* Fix memory usage issue with inspect
* Made WAN attention receive transformer_options, test node added to wan to test out attention override later
* Added **kwargs to all attention functions so transformer_options could potentially be passed through
* Make sure wrap_attn doesn't make itself recurse infinitely, attempt to load SageAttention and FlashAttention if not enabled so that they can be marked as available or not, create registry for available attention
* Turn off attention logging for now, make AttentionOverrideTestNode have a dropdown with available attention (this is a test node only)
* Make flux work with optimized_attention_override
* Add logs to verify optimized_attention_override is passed all the way into attention function
* Make Qwen work with optimized_attention_override
* Made hidream work with optimized_attention_override
* Made wan patches_replace work with optimized_attention_override
* Made SD3 work with optimized_attention_override
* Made HunyuanVideo work with optimized_attention_override
* Made Mochi work with optimized_attention_override
* Made LTX work with optimized_attention_override
* Made StableAudio work with optimized_attention_override
* Made optimized_attention_override work with ACE Step
* Made Hunyuan3D work with optimized_attention_override
* Make CosmosPredict2 work with optimized_attention_override
* Made CosmosVideo work with optimized_attention_override
* Made Omnigen 2 work with optimized_attention_override
* Made StableCascade work with optimized_attention_override
* Made AuraFlow work with optimized_attention_override
* Made Lumina work with optimized_attention_override
* Made Chroma work with optimized_attention_override
* Made SVD work with optimized_attention_override
* Fix WanI2VCrossAttention so that it expects to receive transformer_options
* Fixed Wan2.1 Fun Camera transformer_options passthrough
* Fixed WAN 2.1 VACE transformer_options passthrough
* Add optimized to get_attention_function
* Disable attention logs for now
* Remove attention logging code
* Remove _register_core_attention_functions, as we wouldn't want someone to call that, just in case
* Satisfy ruff
* Remove AttentionOverrideTest node, that's something to cook up for later
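For orientation, a minimal sketch of the decorator idea from the first item: only the names wrap_attn, transformer_options, and 'optimized_attention_override' come from the notes above; the calling convention and everything else are illustrative assumptions, not the actual ComfyUI implementation.

import functools

def wrap_attn(attn_fn):
    # Hypothetical wrapper: check transformer_options for an override before
    # falling back to the wrapped attention function.
    @functools.wraps(attn_fn)
    def wrapper(*args, transformer_options=None, **kwargs):
        transformer_options = transformer_options or {}
        override = transformer_options.get("optimized_attention_override")
        if override is not None:
            # Assumed convention: hand the wrapped function to the override so it
            # can delegate back to it. The notes above mention that the real code
            # also guards against the wrapper recursing into itself.
            return override(attn_fn, *args, transformer_options=transformer_options, **kwargs)
        return attn_fn(*args, transformer_options=transformer_options, **kwargs)
    return wrapper

@wrap_attn
def attention_basic(q, k, v, heads, transformer_options=None, **kwargs):
    ...  # placeholder for an actual attention implementation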
46 lines
1.9 KiB
Python
import torch
from einops import rearrange
from torch import Tensor

from comfy.ldm.modules.attention import optimized_attention
import comfy.model_management


def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor:
    q_shape = q.shape
    k_shape = k.shape

    # Apply the rotary position embedding (pe) to q and k before running attention.
    if pe is not None:
        q = q.to(dtype=pe.dtype).reshape(*q.shape[:-1], -1, 1, 2)
        k = k.to(dtype=pe.dtype).reshape(*k.shape[:-1], -1, 1, 2)
        q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v)
        k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)

    heads = q.shape[1]
    # transformer_options is passed through so an 'optimized_attention_override'
    # entry can swap out the attention implementation.
    x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options)
    return x


def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
    assert dim % 2 == 0
    # Some backends have limited float64 support, so build the frequency table on the CPU there.
    if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu() or comfy.model_management.is_directml_enabled():
        device = torch.device("cpu")
    else:
        device = pos.device

    scale = torch.linspace(0, (dim - 2) / dim, steps=dim//2, dtype=torch.float64, device=device)
    omega = 1.0 / (theta**scale)
    out = torch.einsum("...n,d->...nd", pos.to(dtype=torch.float32, device=device), omega)
    # Stack cos/sin terms into a 2x2 rotation matrix per position and frequency.
    out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
    out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
    return out.to(dtype=torch.float32, device=pos.device)


def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
    # Apply precomputed rotary embedding matrices (freqs_cis) to query and key tensors.
    xq_ = xq.to(dtype=freqs_cis.dtype).reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.to(dtype=freqs_cis.dtype).reshape(*xk.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
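A quick usage sketch of the helpers above, run with the functions defined in this file in scope. The shapes, the theta value, and the empty transformer_options are illustrative assumptions, not values taken from the repository.

import torch

# 2 heads of dimension 64 over a sequence of 16 tokens (assumed shapes).
b, heads, n, head_dim = 1, 2, 16, 64
q = torch.randn(b, heads, n, head_dim)
k = torch.randn(b, heads, n, head_dim)
v = torch.randn(b, heads, n, head_dim)

# rope() returns (b, n, head_dim // 2, 2, 2) rotation matrices; unsqueeze a head
# dimension so they broadcast against q and k of shape (b, heads, n, head_dim).
pos = torch.arange(n)[None, :]
pe = rope(pos, head_dim, theta=10000).unsqueeze(1)

out = attention(q, k, v, pe=pe, transformer_options={})
print(out.shape)  # (1, 16, 128): optimized_attention merges the heads back into the last dim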