support sageattn for mesh gen

kijai 2025-01-25 19:56:54 +02:00
parent d454e92b9b
commit a86f12236e
4 changed files with 71 additions and 17 deletions

View File

@@ -31,11 +31,21 @@ from einops import rearrange
from torch import Tensor, nn
def attention(q: Tensor, k: Tensor, v: Tensor, **kwargs) -> Tensor:
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
x = rearrange(x, "B H L D -> B L (H D)")
return x
try:
from sageattention import sageattn
except ImportError:
sageattn = None
def attention_sage(q: Tensor, k: Tensor, v: Tensor, **kwargs) -> Tensor:
x = sageattn(q, k, v)
x = rearrange(x, "B H L D -> B L (H D)")
return x
def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
"""
@@ -151,7 +161,12 @@ class DoubleStreamBlock(nn.Module):
num_heads: int,
mlp_ratio: float,
qkv_bias: bool = False,
attention_mode: str = "sdpa",
):
if attention_mode == "sdpa":
self.attention_func = attention
elif attention_mode == "sageattn":
self.attention_func = attention_sage
super().__init__()
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.num_heads = num_heads
@@ -198,7 +213,7 @@ class DoubleStreamBlock(nn.Module):
k = torch.cat((txt_k, img_k), dim=2)
v = torch.cat((txt_v, img_v), dim=2)
attn = attention(q, k, v, pe=pe)
attn = self.attention_func(q, k, v, pe=pe)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
@@ -221,8 +236,13 @@ class SingleStreamBlock(nn.Module):
num_heads: int,
mlp_ratio: float = 4.0,
qk_scale: Optional[float] = None,
attention_mode: str = "sdpa",
):
super().__init__()
if attention_mode == "sdpa":
self.attention_func = attention
elif attention_mode == "sageattn":
self.attention_func = attention_sage
self.hidden_dim = hidden_size
self.num_heads = num_heads
@@ -253,7 +273,7 @@ class SingleStreamBlock(nn.Module):
q, k = self.norm(q, k, v)
# compute attention
attn = attention(q, k, v, pe=pe)
attn = self.attention_func(q, k, v, pe=pe)
# compute activation in mlp stream, cat again and run second linear layer
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
return x + mod.gate * output
@@ -288,6 +308,7 @@ class Hunyuan3DDiT(nn.Module):
qkv_bias: bool = True,
time_factor: float = 1000,
ckpt_path: Optional[str] = None,
attention_mode: str = "sdpa",
**kwargs,
):
super().__init__()
@@ -303,6 +324,7 @@ class Hunyuan3DDiT(nn.Module):
self.qkv_bias = qkv_bias
self.time_factor = time_factor
self.out_channels = self.in_channels
self.attention_mode = attention_mode
if hidden_size % num_heads != 0:
raise ValueError(
@@ -324,6 +346,7 @@ class Hunyuan3DDiT(nn.Module):
self.num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
attention_mode=self.attention_mode,
)
for _ in range(depth)
]
@@ -335,6 +358,7 @@ class Hunyuan3DDiT(nn.Module):
self.hidden_size,
self.num_heads,
mlp_ratio=mlp_ratio,
attention_mode=self.attention_mode,
)
for _ in range(depth_single_blocks)
]
@@ -342,6 +366,8 @@ class Hunyuan3DDiT(nn.Module):
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
if ckpt_path is not None:
print('restored denoiser ckpt', ckpt_path)
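The block constructors in this file dispatch purely on the mode string, so requesting "sageattn" without the package installed only fails later, when attention_sage calls the None sageattn. A hedged sketch of a graceful fallback, reusing the attention and attention_sage helpers defined earlier in the file (the helper name select_attention is illustrative, not from this commit):

# Sketch only: degrade to SDPA when sageattention is unavailable, instead of
# dispatching on the raw mode string inside every block constructor.
def select_attention(attention_mode: str = "sdpa"):
    if attention_mode == "sageattn":
        if sageattn is not None:
            return attention_sage
        print("sageattention not installed, falling back to sdpa")
    return attention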

View File

@@ -31,7 +31,7 @@ import torch.nn.functional as F
from einops import rearrange, repeat
from skimage import measure
from tqdm import tqdm
from sageattention import sageattn
from comfy.utils import ProgressBar
class FourierEmbedder(nn.Module):
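Unlike the guarded try/except import in the DiT file above, this file imports sageattn unconditionally, so the VAE module needs the package even in "sdpa" mode. A guarded variant would mirror the earlier pattern (sketch, not part of this diff):

# Guarded import sketch, mirroring the DiT file's pattern:
try:
    from sageattention import sageattn
except ImportError:
    sageattn = None  # "sdpa" mode still works; "sageattn" mode requires the package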
@@ -189,13 +189,18 @@ class QKVMultiheadCrossAttention(nn.Module):
n_data: Optional[int] = None,
width=None,
qk_norm=False,
norm_layer=nn.LayerNorm
norm_layer=nn.LayerNorm,
attention_mode: str = "sdpa"
):
super().__init__()
self.heads = heads
self.n_data = n_data
self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
if attention_mode == "sdpa":
self.attention_func = F.scaled_dot_product_attention
elif attention_mode == "sageattn":
self.attention_func = sageattn
def forward(self, q, kv):
_, n_ctx, _ = q.shape
@@ -209,7 +214,7 @@ class QKVMultiheadCrossAttention(nn.Module):
k = self.k_norm(k)
q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
out = F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1)
out = self.attention_func(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1)
return out
@@ -224,7 +229,8 @@ class MultiheadCrossAttention(nn.Module):
n_data: Optional[int] = None,
data_width: Optional[int] = None,
norm_layer=nn.LayerNorm,
qk_norm: bool = False
qk_norm: bool = False,
attention_mode: str = "sdpa"
):
super().__init__()
self.n_data = n_data
@@ -239,7 +245,8 @@ class MultiheadCrossAttention(nn.Module):
n_data=n_data,
width=width,
norm_layer=norm_layer,
qk_norm=qk_norm
qk_norm=qk_norm,
attention_mode=attention_mode
)
def forward(self, x, data):
@@ -260,12 +267,15 @@ class ResidualCrossAttentionBlock(nn.Module):
data_width: Optional[int] = None,
qkv_bias: bool = True,
norm_layer=nn.LayerNorm,
qk_norm: bool = False
qk_norm: bool = False,
attention_mode: str = "sdpa"
):
super().__init__()
if data_width is None:
data_width = width
self.attention_mode = attention_mode
self.attn = MultiheadCrossAttention(
n_data=n_data,
@@ -274,7 +284,8 @@ class ResidualCrossAttentionBlock(nn.Module):
data_width=data_width,
qkv_bias=qkv_bias,
norm_layer=norm_layer,
qk_norm=qk_norm
qk_norm=qk_norm,
attention_mode=self.attention_mode
)
self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
self.ln_2 = norm_layer(data_width, elementwise_affine=True, eps=1e-6)
@@ -295,13 +306,18 @@ class QKVMultiheadAttention(nn.Module):
n_ctx: int,
width=None,
qk_norm=False,
norm_layer=nn.LayerNorm
norm_layer=nn.LayerNorm,
attention_mode: str = "sdpa"
):
super().__init__()
self.heads = heads
self.n_ctx = n_ctx
self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
if attention_mode == "sdpa":
self.attention_func = F.scaled_dot_product_attention
elif attention_mode == "sageattn":
self.attention_func = sageattn
def forward(self, qkv):
bs, n_ctx, width = qkv.shape
@@ -313,7 +329,7 @@ class QKVMultiheadAttention(nn.Module):
k = self.k_norm(k)
q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
out = F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1)
out = self.attention_func(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1)
return out
@@ -433,7 +449,8 @@ class CrossAttentionDecoder(nn.Module):
heads: int,
qkv_bias: bool = True,
qk_norm: bool = False,
label_type: str = "binary"
label_type: str = "binary",
attention_mode: str = "sdpa"
):
super().__init__()
@@ -446,7 +463,8 @@ class CrossAttentionDecoder(nn.Module):
width=width,
heads=heads,
qkv_bias=qkv_bias,
qk_norm=qk_norm
qk_norm=qk_norm,
attention_mode=attention_mode
)
self.ln_post = nn.LayerNorm(width)
@@ -514,6 +532,7 @@ class ShapeVAE(nn.Module):
label_type: str = "binary",
drop_path_rate: float = 0.0,
scale_factor: float = 1.0,
attention_mode: str = "sdpa"
):
super().__init__()
self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
@@ -539,11 +558,14 @@ class ShapeVAE(nn.Module):
qkv_bias=qkv_bias,
qk_norm=qk_norm,
label_type=label_type,
attention_mode=attention_mode
)
self.scale_factor = scale_factor
self.latent_shape = (num_latents, embed_dim)
self.attention_mode = attention_mode
def forward(self, latents):
latents = self.post_kl(latents)
latents = self.transformer(latents)

View File

@@ -153,6 +153,7 @@ class Hunyuan3DDiTPipeline:
dtype=torch.float16,
use_safetensors=None,
compile_args=None,
attention_mode="sdpa",
**kwargs,
):
# load config
@@ -179,17 +180,19 @@ class Hunyuan3DDiTPipeline:
ckpt[model_name][new_key] = value
else:
ckpt = torch.load(ckpt_path, map_location='cpu')
# load model
config['model']['params']['attention_mode'] = attention_mode
config['vae']['params']['attention_mode'] = attention_mode
with init_empty_weights():
model = instantiate_from_config(config['model'])
vae = instantiate_from_config(config['vae'])
conditioner = instantiate_from_config(config['conditioner'])
#model
#model.load_state_dict(ckpt['model'])
for name, param in model.named_parameters():
set_module_tensor_to_device(model, name, device=offload_device, dtype=dtype, value=ckpt['model'][name])
#vae
#vae.load_state_dict(ckpt['vae'])
for name, param in vae.named_parameters():
set_module_tensor_to_device(vae, name, device=offload_device, dtype=dtype, value=ckpt['vae'][name])

View File

@@ -73,6 +73,7 @@ class Hy3DModelLoader:
},
"optional": {
"compile_args": ("HY3DCOMPILEARGS", {"tooltip": "torch.compile settings, when connected to the model loader, torch.compile of the selected models is attempted. Requires Triton and torch 2.5.0 is recommended"}),
"attention_mode": (["sdpa", "sageattn"], {"default": "sdpa"}),
}
}
@@ -81,7 +82,7 @@ class Hy3DModelLoader:
FUNCTION = "loadmodel"
CATEGORY = "Hunyuan3DWrapper"
def loadmodel(self, model, compile_args=None):
def loadmodel(self, model, compile_args=None, attention_mode="sdpa"):
device = mm.get_torch_device()
offload_device=mm.unet_offload_device()
@@ -93,7 +94,9 @@ class Hy3DModelLoader:
use_safetensors=True,
device=device,
offload_device=offload_device,
compile_args=compile_args)
compile_args=compile_args,
attention_mode=attention_mode)
return (pipe,)
class DownloadAndLoadHy3DDelightModel:
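Because both the node's optional input and the loadmodel signature default to "sdpa", existing workflows keep working unchanged; selecting "sageattn" simply forwards the string to the pipeline loader shown above. An illustrative call outside the ComfyUI graph, assuming the model input is a checkpoint filename from the node's dropdown (the filename below is a placeholder):

# Illustrative only: invoking the loader node outside the ComfyUI graph.
# The checkpoint filename is a placeholder, not a real model name.
loader = Hy3DModelLoader()
(pipe,) = loader.loadmodel(
    "hunyuan3d_model.safetensors",  # placeholder filename
    compile_args=None,
    attention_mode="sageattn",
)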