mirror of https://git.datalinker.icu/kijai/ComfyUI-Hunyuan3DWrapper.git (synced 2025-12-09 04:44:26 +08:00)
support sageattn for mesh gen
This commit is contained in:
parent d454e92b9b
commit a86f12236e
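
In outline, the commit threads a new attention_mode option ("sdpa" or "sageattn") from the Hy3DModelLoader node through the pipeline loader and into every attention layer of the DiT and the shape VAE, so SageAttention can stand in for torch's scaled_dot_product_attention during mesh generation. A minimal sanity-check sketch for the two attention helpers introduced in the first hunk below (hedged: assumes sageattention is installed, a supported CUDA GPU, fp16 tensors, and SageAttention's default (B, H, L, D) layout, which matches SDPA):

    # Hypothetical check, not part of the commit: compare the sdpa and sageattn paths.
    import torch
    from torch.nn.functional import scaled_dot_product_attention as sdpa
    from sageattention import sageattn
    from einops import rearrange

    B, H, L, D = 1, 8, 512, 64
    q, k, v = (torch.randn(B, H, L, D, device="cuda", dtype=torch.float16) for _ in range(3))

    ref = rearrange(sdpa(q, k, v), "B H L D -> B L (H D)")
    out = rearrange(sageattn(q, k, v), "B H L D -> B L (H D)")
    print((ref - out).abs().max())  # small but nonzero: SageAttention quantizes Q/K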
@@ -31,11 +31,21 @@ from einops import rearrange
 from torch import Tensor, nn
 
 
 def attention(q: Tensor, k: Tensor, v: Tensor, **kwargs) -> Tensor:
     x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
     x = rearrange(x, "B H L D -> B L (H D)")
     return x
 
+try:
+    from sageattention import sageattn
+except ImportError:
+    sageattn = None
+
+def attention_sage(q: Tensor, k: Tensor, v: Tensor, **kwargs) -> Tensor:
+    x = sageattn(q, k, v)
+    x = rearrange(x, "B H L D -> B L (H D)")
+    return x
+
 
 def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
     """
@@ -151,7 +161,12 @@ class DoubleStreamBlock(nn.Module):
         num_heads: int,
         mlp_ratio: float,
         qkv_bias: bool = False,
+        attention_mode: str = "sdpa",
     ):
+        if attention_mode == "sdpa":
+            self.attention_func = attention
+        elif attention_mode == "sageattn":
+            self.attention_func = attention_sage
         super().__init__()
         mlp_hidden_dim = int(hidden_size * mlp_ratio)
         self.num_heads = num_heads
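
The constructor binds self.attention_func once at init time; as written, any other mode string leaves it unset until the first forward call fails. An illustrative guard (not part of the commit, shown only as a sketch) that would fail early instead:

    # Hypothetical early validation for the attention_mode argument.
    if attention_mode not in ("sdpa", "sageattn"):
        raise ValueError(f"Unknown attention_mode: {attention_mode!r}")
    if attention_mode == "sageattn" and sageattn is None:
        raise ImportError("attention_mode='sageattn' requires the sageattention package")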
@@ -198,7 +213,7 @@ class DoubleStreamBlock(nn.Module):
         k = torch.cat((txt_k, img_k), dim=2)
         v = torch.cat((txt_v, img_v), dim=2)
 
-        attn = attention(q, k, v, pe=pe)
+        attn = self.attention_func(q, k, v, pe=pe)
         txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
 
         img = img + img_mod1.gate * self.img_attn.proj(img_attn)
@@ -221,8 +236,13 @@ class SingleStreamBlock(nn.Module):
         num_heads: int,
         mlp_ratio: float = 4.0,
         qk_scale: Optional[float] = None,
+        attention_mode: str = "sdpa",
     ):
         super().__init__()
+        if attention_mode == "sdpa":
+            self.attention_func = attention
+        elif attention_mode == "sageattn":
+            self.attention_func = attention_sage
 
         self.hidden_dim = hidden_size
         self.num_heads = num_heads
@@ -253,7 +273,7 @@ class SingleStreamBlock(nn.Module):
         q, k = self.norm(q, k, v)
 
         # compute attention
-        attn = attention(q, k, v, pe=pe)
+        attn = self.attention_func(q, k, v, pe=pe)
         # compute activation in mlp stream, cat again and run second linear layer
         output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
         return x + mod.gate * output
@@ -288,6 +308,7 @@ class Hunyuan3DDiT(nn.Module):
         qkv_bias: bool = True,
         time_factor: float = 1000,
         ckpt_path: Optional[str] = None,
+        attention_mode: str = "sdpa",
         **kwargs,
     ):
         super().__init__()
@@ -303,6 +324,7 @@ class Hunyuan3DDiT(nn.Module):
         self.qkv_bias = qkv_bias
         self.time_factor = time_factor
         self.out_channels = self.in_channels
+        self.attention_mode = attention_mode
 
         if hidden_size % num_heads != 0:
             raise ValueError(
@@ -324,6 +346,7 @@ class Hunyuan3DDiT(nn.Module):
                     self.num_heads,
                     mlp_ratio=mlp_ratio,
                     qkv_bias=qkv_bias,
+                    attention_mode=self.attention_mode,
                 )
                 for _ in range(depth)
             ]
@@ -335,6 +358,7 @@ class Hunyuan3DDiT(nn.Module):
                     self.hidden_size,
                     self.num_heads,
                     mlp_ratio=mlp_ratio,
+                    attention_mode=self.attention_mode,
                 )
                 for _ in range(depth_single_blocks)
             ]
@@ -342,6 +366,8 @@ class Hunyuan3DDiT(nn.Module):
 
         self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
 
+
+
         if ckpt_path is not None:
             print('restored denoiser ckpt', ckpt_path)
 
@@ -31,7 +31,7 @@ import torch.nn.functional as F
 from einops import rearrange, repeat
 from skimage import measure
 from tqdm import tqdm
-
+from sageattention import sageattn
 from comfy.utils import ProgressBar
 
 class FourierEmbedder(nn.Module):
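
Unlike the DiT module above, this file imports sageattn unconditionally, so it only imports when the sageattention package is present. A guarded variant mirroring the first hunk would keep the module importable without it (illustrative sketch, not part of the commit):

    try:
        from sageattention import sageattn
    except ImportError:
        sageattn = None  # the "sageattn" attention_mode is then unavailable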
@@ -189,13 +189,18 @@ class QKVMultiheadCrossAttention(nn.Module):
         n_data: Optional[int] = None,
         width=None,
         qk_norm=False,
-        norm_layer=nn.LayerNorm
+        norm_layer=nn.LayerNorm,
+        attention_mode: str = "sdpa"
     ):
         super().__init__()
         self.heads = heads
         self.n_data = n_data
         self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
         self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
+        if attention_mode == "sdpa":
+            self.attention_func = F.scaled_dot_product_attention
+        elif attention_mode == "sageattn":
+            self.attention_func = sageattn
 
     def forward(self, q, kv):
         _, n_ctx, _ = q.shape
@@ -209,7 +214,7 @@ class QKVMultiheadCrossAttention(nn.Module):
         k = self.k_norm(k)
 
         q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
-        out = F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1)
+        out = self.attention_func(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1)
 
         return out
 
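
At this call site the tensors are already in (batch, heads, tokens, head_dim) layout, which is what F.scaled_dot_product_attention expects and also SageAttention's default "HND" layout, so the trailing .transpose(1, 2).reshape(bs, n_ctx, -1) is backend-agnostic. A quick hedged shape check (assumes a CUDA device, fp16 tensors, and sageattention installed; equal query/key lengths are used only for brevity):

    import torch
    import torch.nn.functional as F
    from sageattention import sageattn

    bs, heads, n_ctx, head_dim = 1, 8, 1024, 64
    q, k, v = (torch.randn(bs, heads, n_ctx, head_dim, device="cuda", dtype=torch.float16) for _ in range(3))

    for attention_func in (F.scaled_dot_product_attention, sageattn):
        out = attention_func(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1)
        print(out.shape)  # torch.Size([1, 1024, 512]) for either backend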
@@ -224,7 +229,8 @@ class MultiheadCrossAttention(nn.Module):
         n_data: Optional[int] = None,
         data_width: Optional[int] = None,
         norm_layer=nn.LayerNorm,
-        qk_norm: bool = False
+        qk_norm: bool = False,
+        attention_mode: str = "sdpa"
     ):
         super().__init__()
         self.n_data = n_data
@@ -239,7 +245,8 @@ class MultiheadCrossAttention(nn.Module):
             n_data=n_data,
             width=width,
             norm_layer=norm_layer,
-            qk_norm=qk_norm
+            qk_norm=qk_norm,
+            attention_mode=attention_mode
         )
 
     def forward(self, x, data):
@@ -260,12 +267,15 @@ class ResidualCrossAttentionBlock(nn.Module):
         data_width: Optional[int] = None,
         qkv_bias: bool = True,
         norm_layer=nn.LayerNorm,
-        qk_norm: bool = False
+        qk_norm: bool = False,
+        attention_mode: str = "sdpa"
     ):
         super().__init__()
 
         if data_width is None:
             data_width = width
 
+        self.attention_mode = attention_mode
+
         self.attn = MultiheadCrossAttention(
             n_data=n_data,
@@ -274,7 +284,8 @@ class ResidualCrossAttentionBlock(nn.Module):
             data_width=data_width,
             qkv_bias=qkv_bias,
             norm_layer=norm_layer,
-            qk_norm=qk_norm
+            qk_norm=qk_norm,
+            attention_mode=self.attention_mode
         )
         self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
         self.ln_2 = norm_layer(data_width, elementwise_affine=True, eps=1e-6)
@@ -295,13 +306,18 @@ class QKVMultiheadAttention(nn.Module):
         n_ctx: int,
         width=None,
         qk_norm=False,
-        norm_layer=nn.LayerNorm
+        norm_layer=nn.LayerNorm,
+        attention_mode: str = "sdpa"
     ):
         super().__init__()
         self.heads = heads
         self.n_ctx = n_ctx
         self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
         self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
+        if attention_mode == "sdpa":
+            self.attention_func = F.scaled_dot_product_attention
+        elif attention_mode == "sageattn":
+            self.attention_func = sageattn
 
     def forward(self, qkv):
         bs, n_ctx, width = qkv.shape
@@ -313,7 +329,7 @@ class QKVMultiheadAttention(nn.Module):
         k = self.k_norm(k)
 
         q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
-        out = F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1)
+        out = self.attention_func(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1)
         return out
 
 
@@ -433,7 +449,8 @@ class CrossAttentionDecoder(nn.Module):
         heads: int,
         qkv_bias: bool = True,
         qk_norm: bool = False,
-        label_type: str = "binary"
+        label_type: str = "binary",
+        attention_mode: str = "sdpa"
     ):
         super().__init__()
 
@@ -446,7 +463,8 @@ class CrossAttentionDecoder(nn.Module):
             width=width,
             heads=heads,
             qkv_bias=qkv_bias,
-            qk_norm=qk_norm
+            qk_norm=qk_norm,
+            attention_mode=attention_mode
         )
 
         self.ln_post = nn.LayerNorm(width)
@@ -514,6 +532,7 @@ class ShapeVAE(nn.Module):
         label_type: str = "binary",
         drop_path_rate: float = 0.0,
         scale_factor: float = 1.0,
+        attention_mode: str = "sdpa"
     ):
         super().__init__()
         self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
@@ -539,11 +558,14 @@ class ShapeVAE(nn.Module):
             qkv_bias=qkv_bias,
             qk_norm=qk_norm,
             label_type=label_type,
+            attention_mode=attention_mode
         )
 
         self.scale_factor = scale_factor
         self.latent_shape = (num_latents, embed_dim)
 
+        self.attention_mode = attention_mode
+
     def forward(self, latents):
         latents = self.post_kl(latents)
         latents = self.transformer(latents)
@@ -153,6 +153,7 @@ class Hunyuan3DDiTPipeline:
         dtype=torch.float16,
         use_safetensors=None,
         compile_args=None,
+        attention_mode="sdpa",
         **kwargs,
     ):
         # load config
@@ -179,17 +180,19 @@ class Hunyuan3DDiTPipeline:
                 ckpt[model_name][new_key] = value
         else:
             ckpt = torch.load(ckpt_path, map_location='cpu')
 
         # load model
+        config['model']['params']['attention_mode'] = attention_mode
+        config['vae']['params']['attention_mode'] = attention_mode
         with init_empty_weights():
             model = instantiate_from_config(config['model'])
             vae = instantiate_from_config(config['vae'])
             conditioner = instantiate_from_config(config['conditioner'])
         #model
         #model.load_state_dict(ckpt['model'])
         for name, param in model.named_parameters():
             set_module_tensor_to_device(model, name, device=offload_device, dtype=dtype, value=ckpt['model'][name])
         #vae
         #vae.load_state_dict(ckpt['vae'])
         for name, param in vae.named_parameters():
             set_module_tensor_to_device(vae, name, device=offload_device, dtype=dtype, value=ckpt['vae'][name])
 
nodes.py
@@ -73,6 +73,7 @@ class Hy3DModelLoader:
             },
             "optional": {
                 "compile_args": ("HY3DCOMPILEARGS", {"tooltip": "torch.compile settings, when connected to the model loader, torch.compile of the selected models is attempted. Requires Triton and torch 2.5.0 is recommended"}),
+                "attention_mode": (["sdpa", "sageattn"], {"default": "sdpa"}),
             }
         }
 
@@ -81,7 +82,7 @@ class Hy3DModelLoader:
     FUNCTION = "loadmodel"
     CATEGORY = "Hunyuan3DWrapper"
 
-    def loadmodel(self, model, compile_args=None):
+    def loadmodel(self, model, compile_args=None, attention_mode="sdpa"):
         device = mm.get_torch_device()
         offload_device=mm.unet_offload_device()
 
@@ -93,7 +94,9 @@ class Hy3DModelLoader:
             use_safetensors=True,
             device=device,
             offload_device=offload_device,
-            compile_args=compile_args)
+            compile_args=compile_args,
+            attention_mode=attention_mode)
 
         return (pipe,)
 
 class DownloadAndLoadHy3DDelightModel:
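
On the ComfyUI side the loader node now exposes the option as a dropdown, so choosing "sageattn" routes every attention call in the DiT and shape VAE through SageAttention (installed separately, e.g. pip install sageattention). A hedged sketch of the equivalent direct call, with a placeholder model filename:

    # Hypothetical usage outside the graph editor; "model.safetensors" is a placeholder.
    loader = Hy3DModelLoader()
    (pipe,) = loader.loadmodel("model.safetensors", compile_args=None, attention_mode="sageattn")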