support sageattn for mesh gen

kijai 2025-01-25 19:56:54 +02:00
parent d454e92b9b
commit a86f12236e
4 changed files with 71 additions and 17 deletions

View File

@@ -31,11 +31,21 @@ from einops import rearrange
 from torch import Tensor, nn
 
 
 def attention(q: Tensor, k: Tensor, v: Tensor, **kwargs) -> Tensor:
     x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
     x = rearrange(x, "B H L D -> B L (H D)")
     return x
 
+try:
+    from sageattention import sageattn
+except ImportError:
+    sageattn = None
+
+def attention_sage(q: Tensor, k: Tensor, v: Tensor, **kwargs) -> Tensor:
+    x = sageattn(q, k, v)
+    x = rearrange(x, "B H L D -> B L (H D)")
+    return x
+
 
 def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
     """
@@ -151,7 +161,12 @@ class DoubleStreamBlock(nn.Module):
         num_heads: int,
         mlp_ratio: float,
         qkv_bias: bool = False,
+        attention_mode: str = "sdpa",
     ):
+        if attention_mode == "sdpa":
+            self.attention_func = attention
+        elif attention_mode == "sageattn":
+            self.attention_func = attention_sage
         super().__init__()
         mlp_hidden_dim = int(hidden_size * mlp_ratio)
         self.num_heads = num_heads
@@ -198,7 +213,7 @@ class DoubleStreamBlock(nn.Module):
         k = torch.cat((txt_k, img_k), dim=2)
         v = torch.cat((txt_v, img_v), dim=2)
 
-        attn = attention(q, k, v, pe=pe)
+        attn = self.attention_func(q, k, v, pe=pe)
         txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
 
         img = img + img_mod1.gate * self.img_attn.proj(img_attn)
@@ -221,8 +236,13 @@ class SingleStreamBlock(nn.Module):
         num_heads: int,
         mlp_ratio: float = 4.0,
         qk_scale: Optional[float] = None,
+        attention_mode: str = "sdpa",
     ):
         super().__init__()
+        if attention_mode == "sdpa":
+            self.attention_func = attention
+        elif attention_mode == "sageattn":
+            self.attention_func = attention_sage
         self.hidden_dim = hidden_size
         self.num_heads = num_heads
@@ -253,7 +273,7 @@ class SingleStreamBlock(nn.Module):
         q, k = self.norm(q, k, v)
 
         # compute attention
-        attn = attention(q, k, v, pe=pe)
+        attn = self.attention_func(q, k, v, pe=pe)
         # compute activation in mlp stream, cat again and run second linear layer
         output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
         return x + mod.gate * output
@@ -288,6 +308,7 @@ class Hunyuan3DDiT(nn.Module):
         qkv_bias: bool = True,
         time_factor: float = 1000,
         ckpt_path: Optional[str] = None,
+        attention_mode: str = "sdpa",
         **kwargs,
     ):
         super().__init__()
@@ -303,6 +324,7 @@ class Hunyuan3DDiT(nn.Module):
         self.qkv_bias = qkv_bias
         self.time_factor = time_factor
         self.out_channels = self.in_channels
+        self.attention_mode = attention_mode
 
         if hidden_size % num_heads != 0:
             raise ValueError(
@@ -324,6 +346,7 @@ class Hunyuan3DDiT(nn.Module):
                     self.num_heads,
                     mlp_ratio=mlp_ratio,
                     qkv_bias=qkv_bias,
+                    attention_mode=self.attention_mode,
                 )
                 for _ in range(depth)
             ]
@@ -335,6 +358,7 @@ class Hunyuan3DDiT(nn.Module):
                     self.hidden_size,
                     self.num_heads,
                     mlp_ratio=mlp_ratio,
+                    attention_mode=self.attention_mode,
                 )
                 for _ in range(depth_single_blocks)
             ]
@@ -342,6 +366,8 @@ class Hunyuan3DDiT(nn.Module):
         self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
 
         if ckpt_path is not None:
             print('restored denoiser ckpt', ckpt_path)
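The new attention_sage path only works when the optional sageattention package is importable; with the guarded import above, sageattn is None and attention_sage would fail if "sageattn" mode is still selected. Below is a minimal sketch, not part of the commit, of a stricter dispatcher (get_attention_func is a hypothetical helper) that falls back to SDPA with a warning instead:

import logging

import torch
from einops import rearrange
from torch import Tensor

try:
    from sageattention import sageattn
except ImportError:
    sageattn = None


def get_attention_func(attention_mode: str = "sdpa"):
    # Hypothetical helper, not in the commit: pick the attention callable once,
    # falling back to SDPA (with a warning) when sageattention is unavailable.
    if attention_mode == "sageattn" and sageattn is not None:
        def attention_sage(q: Tensor, k: Tensor, v: Tensor, **kwargs) -> Tensor:
            # sageattn's default tensor layout matches the (B, H, L, D) inputs used here
            x = sageattn(q, k, v)
            return rearrange(x, "B H L D -> B L (H D)")
        return attention_sage
    if attention_mode == "sageattn":
        logging.warning("sageattention is not installed, falling back to sdpa")

    def attention_sdpa(q: Tensor, k: Tensor, v: Tensor, **kwargs) -> Tensor:
        x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
        return rearrange(x, "B H L D -> B L (H D)")
    return attention_sdpa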

View File

@@ -31,7 +31,7 @@ import torch.nn.functional as F
 from einops import rearrange, repeat
 from skimage import measure
 from tqdm import tqdm
+from sageattention import sageattn
 
 from comfy.utils import ProgressBar
 
 class FourierEmbedder(nn.Module):
@@ -189,13 +189,18 @@ class QKVMultiheadCrossAttention(nn.Module):
         n_data: Optional[int] = None,
         width=None,
         qk_norm=False,
-        norm_layer=nn.LayerNorm
+        norm_layer=nn.LayerNorm,
+        attention_mode: str = "sdpa"
     ):
         super().__init__()
         self.heads = heads
         self.n_data = n_data
         self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
         self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
+        if attention_mode == "sdpa":
+            self.attention_func = F.scaled_dot_product_attention
+        elif attention_mode == "sageattn":
+            self.attention_func = sageattn
 
     def forward(self, q, kv):
         _, n_ctx, _ = q.shape
@@ -209,7 +214,7 @@ class QKVMultiheadCrossAttention(nn.Module):
         k = self.k_norm(k)
 
         q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
-        out = F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1)
+        out = self.attention_func(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1)
         return out
@@ -224,7 +229,8 @@ class MultiheadCrossAttention(nn.Module):
         n_data: Optional[int] = None,
         data_width: Optional[int] = None,
         norm_layer=nn.LayerNorm,
-        qk_norm: bool = False
+        qk_norm: bool = False,
+        attention_mode: str = "sdpa"
     ):
         super().__init__()
         self.n_data = n_data
@@ -239,7 +245,8 @@ class MultiheadCrossAttention(nn.Module):
             n_data=n_data,
             width=width,
             norm_layer=norm_layer,
-            qk_norm=qk_norm
+            qk_norm=qk_norm,
+            attention_mode=attention_mode
         )
 
     def forward(self, x, data):
@@ -260,13 +267,16 @@ class ResidualCrossAttentionBlock(nn.Module):
         data_width: Optional[int] = None,
         qkv_bias: bool = True,
         norm_layer=nn.LayerNorm,
-        qk_norm: bool = False
+        qk_norm: bool = False,
+        attention_mode: str = "sdpa"
     ):
         super().__init__()
 
         if data_width is None:
             data_width = width
 
+        self.attention_mode = attention_mode
+
         self.attn = MultiheadCrossAttention(
             n_data=n_data,
             width=width,
@@ -274,7 +284,8 @@ class ResidualCrossAttentionBlock(nn.Module):
             data_width=data_width,
             qkv_bias=qkv_bias,
             norm_layer=norm_layer,
-            qk_norm=qk_norm
+            qk_norm=qk_norm,
+            attention_mode=self.attention_mode
         )
         self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
         self.ln_2 = norm_layer(data_width, elementwise_affine=True, eps=1e-6)
@@ -295,13 +306,18 @@ class QKVMultiheadAttention(nn.Module):
         n_ctx: int,
         width=None,
         qk_norm=False,
-        norm_layer=nn.LayerNorm
+        norm_layer=nn.LayerNorm,
+        attention_mode: str = "sdpa"
     ):
         super().__init__()
         self.heads = heads
         self.n_ctx = n_ctx
         self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
         self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
+        if attention_mode == "sdpa":
+            self.attention_func = F.scaled_dot_product_attention
+        elif attention_mode == "sageattn":
+            self.attention_func = sageattn
 
     def forward(self, qkv):
         bs, n_ctx, width = qkv.shape
@@ -313,7 +329,7 @@ class QKVMultiheadAttention(nn.Module):
         k = self.k_norm(k)
 
         q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
-        out = F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1)
+        out = self.attention_func(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1)
         return out
@@ -433,7 +449,8 @@ class CrossAttentionDecoder(nn.Module):
         heads: int,
         qkv_bias: bool = True,
         qk_norm: bool = False,
-        label_type: str = "binary"
+        label_type: str = "binary",
+        attention_mode: str = "sdpa"
     ):
         super().__init__()
@@ -446,7 +463,8 @@ class CrossAttentionDecoder(nn.Module):
             width=width,
             heads=heads,
             qkv_bias=qkv_bias,
-            qk_norm=qk_norm
+            qk_norm=qk_norm,
+            attention_mode=attention_mode
         )
 
         self.ln_post = nn.LayerNorm(width)
@@ -514,6 +532,7 @@ class ShapeVAE(nn.Module):
         label_type: str = "binary",
         drop_path_rate: float = 0.0,
         scale_factor: float = 1.0,
+        attention_mode: str = "sdpa"
     ):
         super().__init__()
         self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
@@ -539,11 +558,14 @@ class ShapeVAE(nn.Module):
             qkv_bias=qkv_bias,
             qk_norm=qk_norm,
             label_type=label_type,
+            attention_mode=attention_mode
        )
 
         self.scale_factor = scale_factor
         self.latent_shape = (num_latents, embed_dim)
+        self.attention_mode = attention_mode
 
     def forward(self, latents):
         latents = self.post_kl(latents)
         latents = self.transformer(latents)
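Unlike the DiT file above, this module pulls in sageattention with a plain top-level import, so the VAE module fails to import when the package is missing, even if attention_mode stays at its "sdpa" default. A small sketch, not part of the commit, of the same guarded-import pattern the first file already uses:

try:
    from sageattention import sageattn
except ImportError:
    # "sageattn" mode should then be rejected (or fall back to SDPA) before use
    sageattn = None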

View File

@@ -153,6 +153,7 @@ class Hunyuan3DDiTPipeline:
         dtype=torch.float16,
         use_safetensors=None,
         compile_args=None,
+        attention_mode="sdpa",
         **kwargs,
     ):
         # load config
@@ -179,17 +180,19 @@ class Hunyuan3DDiTPipeline:
                    ckpt[model_name][new_key] = value
         else:
             ckpt = torch.load(ckpt_path, map_location='cpu')
 
         # load model
+        config['model']['params']['attention_mode'] = attention_mode
+        config['vae']['params']['attention_mode'] = attention_mode
         with init_empty_weights():
             model = instantiate_from_config(config['model'])
             vae = instantiate_from_config(config['vae'])
             conditioner = instantiate_from_config(config['conditioner'])
 
         #model
-        #model.load_state_dict(ckpt['model'])
         for name, param in model.named_parameters():
             set_module_tensor_to_device(model, name, device=offload_device, dtype=dtype, value=ckpt['model'][name])
 
         #vae
-        #vae.load_state_dict(ckpt['vae'])
         for name, param in vae.named_parameters():
             set_module_tensor_to_device(vae, name, device=offload_device, dtype=dtype, value=ckpt['vae'][name])
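The pipeline threads the new option through by patching the parsed config before instantiate_from_config builds the modules, which has the same effect as adding attention_mode to the model and VAE params in the YAML by hand. A rough illustration, with placeholder targets and params rather than the real config:

# Hypothetical config shape, for illustration only.
config = {
    "model": {"target": "<dit target path>", "params": {}},
    "vae": {"target": "<vae target path>", "params": {}},
}

attention_mode = "sageattn"
# Same effect as the two lines added in the commit: both modules now receive
# attention_mode in __init__ and select their attention_func from it.
config["model"]["params"]["attention_mode"] = attention_mode
config["vae"]["params"]["attention_mode"] = attention_mode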

View File

@@ -73,6 +73,7 @@ class Hy3DModelLoader:
             },
             "optional": {
                 "compile_args": ("HY3DCOMPILEARGS", {"tooltip": "torch.compile settings, when connected to the model loader, torch.compile of the selected models is attempted. Requires Triton and torch 2.5.0 is recommended"}),
+                "attention_mode": (["sdpa", "sageattn"], {"default": "sdpa"}),
             }
         }
@@ -81,7 +82,7 @@ class Hy3DModelLoader:
     FUNCTION = "loadmodel"
     CATEGORY = "Hunyuan3DWrapper"
 
-    def loadmodel(self, model, compile_args=None):
+    def loadmodel(self, model, compile_args=None, attention_mode="sdpa"):
         device = mm.get_torch_device()
         offload_device=mm.unet_offload_device()
@@ -93,7 +94,9 @@ class Hy3DModelLoader:
             use_safetensors=True,
             device=device,
             offload_device=offload_device,
-            compile_args=compile_args)
+            compile_args=compile_args,
+            attention_mode=attention_mode)
 
         return (pipe,)
 
 class DownloadAndLoadHy3DDelightModel:
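With this change the Hy3DModelLoader node exposes an optional attention_mode dropdown ("sdpa" or "sageattn") that is passed straight through to the pipeline loader. A minimal usage sketch, assuming the node is invoked directly from Python rather than through the ComfyUI graph; the checkpoint name below is a placeholder:

loader = Hy3DModelLoader()
(pipe,) = loader.loadmodel(
    model="hunyuan3d-dit.safetensors",  # placeholder checkpoint name
    compile_args=None,
    attention_mode="sageattn",          # needs the sageattention package; default is "sdpa"
)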