Mirror of https://git.datalinker.icu/kijai/ComfyUI-Hunyuan3DWrapper.git (synced 2025-12-16 16:24:26 +08:00)

support sageattn for mesh gen

This commit is contained in:
parent d454e92b9b
commit a86f12236e
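The change adds SageAttention as a selectable attention backend for mesh generation: the DiT denoiser, the shape VAE attention blocks, the pipeline loader, and the Hy3DModelLoader node all gain an attention_mode switch ("sdpa" or "sageattn"). In the DiT file the new dependency is import-guarded so SDPA keeps working when the package is absent. A minimal, self-contained sketch of that optional-import pattern (pick_attention is an illustrative helper, not part of the commit):

# Sketch only: guarded SageAttention import with an SDPA fallback, assuming the
# sageattention package is optionally installed; mirrors the guarded import this
# commit adds to the DiT file.
import torch.nn.functional as F

try:
    from sageattention import sageattn
except ImportError:
    sageattn = None

def pick_attention(mode: str = "sdpa"):
    # Returns a callable taking (q, k, v) in (batch, heads, seq_len, head_dim) layout.
    if mode == "sageattn" and sageattn is not None:
        return sageattn
    return F.scaled_dot_product_attention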
@@ -31,11 +31,21 @@ from einops import rearrange
 from torch import Tensor, nn
 
 
 
 def attention(q: Tensor, k: Tensor, v: Tensor, **kwargs) -> Tensor:
     x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
     x = rearrange(x, "B H L D -> B L (H D)")
     return x
 
+try:
+    from sageattention import sageattn
+except ImportError:
+    sageattn = None
+
+def attention_sage(q: Tensor, k: Tensor, v: Tensor, **kwargs) -> Tensor:
+    x = sageattn(q, k, v)
+    x = rearrange(x, "B H L D -> B L (H D)")
+    return x
+
 def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
     """
@@ -151,7 +161,12 @@ class DoubleStreamBlock(nn.Module):
         num_heads: int,
         mlp_ratio: float,
         qkv_bias: bool = False,
+        attention_mode: str = "sdpa",
     ):
+        if attention_mode == "sdpa":
+            self.attention_func = attention
+        elif attention_mode == "sageattn":
+            self.attention_func = attention_sage
         super().__init__()
         mlp_hidden_dim = int(hidden_size * mlp_ratio)
         self.num_heads = num_heads
@@ -198,7 +213,7 @@ class DoubleStreamBlock(nn.Module):
         k = torch.cat((txt_k, img_k), dim=2)
         v = torch.cat((txt_v, img_v), dim=2)
 
-        attn = attention(q, k, v, pe=pe)
+        attn = self.attention_func(q, k, v, pe=pe)
         txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
 
         img = img + img_mod1.gate * self.img_attn.proj(img_attn)
@@ -221,8 +236,13 @@ class SingleStreamBlock(nn.Module):
         num_heads: int,
         mlp_ratio: float = 4.0,
         qk_scale: Optional[float] = None,
+        attention_mode: str = "sdpa",
     ):
         super().__init__()
+        if attention_mode == "sdpa":
+            self.attention_func = attention
+        elif attention_mode == "sageattn":
+            self.attention_func = attention_sage
 
         self.hidden_dim = hidden_size
         self.num_heads = num_heads
@@ -253,7 +273,7 @@ class SingleStreamBlock(nn.Module):
         q, k = self.norm(q, k, v)
 
         # compute attention
-        attn = attention(q, k, v, pe=pe)
+        attn = self.attention_func(q, k, v, pe=pe)
         # compute activation in mlp stream, cat again and run second linear layer
         output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
         return x + mod.gate * output
@@ -288,6 +308,7 @@ class Hunyuan3DDiT(nn.Module):
         qkv_bias: bool = True,
         time_factor: float = 1000,
         ckpt_path: Optional[str] = None,
+        attention_mode: str = "sdpa",
         **kwargs,
     ):
         super().__init__()
@@ -303,6 +324,7 @@ class Hunyuan3DDiT(nn.Module):
         self.qkv_bias = qkv_bias
         self.time_factor = time_factor
         self.out_channels = self.in_channels
+        self.attention_mode = attention_mode
 
         if hidden_size % num_heads != 0:
             raise ValueError(
@@ -324,6 +346,7 @@ class Hunyuan3DDiT(nn.Module):
                     self.num_heads,
                     mlp_ratio=mlp_ratio,
                     qkv_bias=qkv_bias,
+                    attention_mode=self.attention_mode,
                 )
                 for _ in range(depth)
             ]
@@ -335,6 +358,7 @@ class Hunyuan3DDiT(nn.Module):
                     self.hidden_size,
                     self.num_heads,
                     mlp_ratio=mlp_ratio,
+                    attention_mode=self.attention_mode,
                 )
                 for _ in range(depth_single_blocks)
             ]
@@ -342,6 +366,8 @@ class Hunyuan3DDiT(nn.Module):
 
         self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
 
+
+
         if ckpt_path is not None:
             print('restored denoiser ckpt', ckpt_path)
 
@@ -31,7 +31,7 @@ import torch.nn.functional as F
 from einops import rearrange, repeat
 from skimage import measure
 from tqdm import tqdm
+from sageattention import sageattn
 from comfy.utils import ProgressBar
 
 class FourierEmbedder(nn.Module):
@@ -189,13 +189,18 @@ class QKVMultiheadCrossAttention(nn.Module):
         n_data: Optional[int] = None,
         width=None,
         qk_norm=False,
-        norm_layer=nn.LayerNorm
+        norm_layer=nn.LayerNorm,
+        attention_mode: str = "sdpa"
     ):
         super().__init__()
         self.heads = heads
         self.n_data = n_data
         self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
         self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
+        if attention_mode == "sdpa":
+            self.attention_func = F.scaled_dot_product_attention
+        elif attention_mode == "sageattn":
+            self.attention_func = sageattn
 
     def forward(self, q, kv):
         _, n_ctx, _ = q.shape
@@ -209,7 +214,7 @@ class QKVMultiheadCrossAttention(nn.Module):
         k = self.k_norm(k)
 
         q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
-        out = F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1)
+        out = self.attention_func(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1)
 
         return out
 
@@ -224,7 +229,8 @@ class MultiheadCrossAttention(nn.Module):
         n_data: Optional[int] = None,
         data_width: Optional[int] = None,
         norm_layer=nn.LayerNorm,
-        qk_norm: bool = False
+        qk_norm: bool = False,
+        attention_mode: str = "sdpa"
     ):
         super().__init__()
         self.n_data = n_data
@@ -239,7 +245,8 @@ class MultiheadCrossAttention(nn.Module):
             n_data=n_data,
             width=width,
             norm_layer=norm_layer,
-            qk_norm=qk_norm
+            qk_norm=qk_norm,
+            attention_mode=attention_mode
         )
 
     def forward(self, x, data):
@@ -260,13 +267,16 @@ class ResidualCrossAttentionBlock(nn.Module):
         data_width: Optional[int] = None,
         qkv_bias: bool = True,
         norm_layer=nn.LayerNorm,
-        qk_norm: bool = False
+        qk_norm: bool = False,
+        attention_mode: str = "sdpa"
     ):
         super().__init__()
 
         if data_width is None:
             data_width = width
 
+        self.attention_mode = attention_mode
+
         self.attn = MultiheadCrossAttention(
             n_data=n_data,
             width=width,
@@ -274,7 +284,8 @@ class ResidualCrossAttentionBlock(nn.Module):
             data_width=data_width,
             qkv_bias=qkv_bias,
             norm_layer=norm_layer,
-            qk_norm=qk_norm
+            qk_norm=qk_norm,
+            attention_mode=self.attention_mode
         )
         self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
         self.ln_2 = norm_layer(data_width, elementwise_affine=True, eps=1e-6)
@@ -295,13 +306,18 @@ class QKVMultiheadAttention(nn.Module):
         n_ctx: int,
         width=None,
         qk_norm=False,
-        norm_layer=nn.LayerNorm
+        norm_layer=nn.LayerNorm,
+        attention_mode: str = "sdpa"
     ):
         super().__init__()
         self.heads = heads
         self.n_ctx = n_ctx
         self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
         self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
+        if attention_mode == "sdpa":
+            self.attention_func = F.scaled_dot_product_attention
+        elif attention_mode == "sageattn":
+            self.attention_func = sageattn
 
     def forward(self, qkv):
         bs, n_ctx, width = qkv.shape
@@ -313,7 +329,7 @@ class QKVMultiheadAttention(nn.Module):
         k = self.k_norm(k)
 
         q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
-        out = F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1)
+        out = self.attention_func(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1)
         return out
 
 
@@ -433,7 +449,8 @@ class CrossAttentionDecoder(nn.Module):
         heads: int,
         qkv_bias: bool = True,
         qk_norm: bool = False,
-        label_type: str = "binary"
+        label_type: str = "binary",
+        attention_mode: str = "sdpa"
     ):
         super().__init__()
 
@@ -446,7 +463,8 @@ class CrossAttentionDecoder(nn.Module):
             width=width,
             heads=heads,
             qkv_bias=qkv_bias,
-            qk_norm=qk_norm
+            qk_norm=qk_norm,
+            attention_mode=attention_mode
         )
 
         self.ln_post = nn.LayerNorm(width)
@@ -514,6 +532,7 @@ class ShapeVAE(nn.Module):
         label_type: str = "binary",
         drop_path_rate: float = 0.0,
         scale_factor: float = 1.0,
+        attention_mode: str = "sdpa"
     ):
         super().__init__()
         self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
@@ -539,11 +558,14 @@ class ShapeVAE(nn.Module):
             qkv_bias=qkv_bias,
             qk_norm=qk_norm,
             label_type=label_type,
+            attention_mode=attention_mode
        )
 
         self.scale_factor = scale_factor
         self.latent_shape = (num_latents, embed_dim)
 
+        self.attention_mode = attention_mode
+
     def forward(self, latents):
         latents = self.post_kl(latents)
         latents = self.transformer(latents)
@@ -153,6 +153,7 @@ class Hunyuan3DDiTPipeline:
         dtype=torch.float16,
         use_safetensors=None,
         compile_args=None,
+        attention_mode="sdpa",
         **kwargs,
     ):
         # load config
@@ -179,17 +180,19 @@ class Hunyuan3DDiTPipeline:
                     ckpt[model_name][new_key] = value
         else:
             ckpt = torch.load(ckpt_path, map_location='cpu')
 
+
         # load model
+        config['model']['params']['attention_mode'] = attention_mode
+        config['vae']['params']['attention_mode'] = attention_mode
         with init_empty_weights():
             model = instantiate_from_config(config['model'])
             vae = instantiate_from_config(config['vae'])
             conditioner = instantiate_from_config(config['conditioner'])
         #model
-        #model.load_state_dict(ckpt['model'])
         for name, param in model.named_parameters():
             set_module_tensor_to_device(model, name, device=offload_device, dtype=dtype, value=ckpt['model'][name])
         #vae
-        #vae.load_state_dict(ckpt['vae'])
         for name, param in vae.named_parameters():
             set_module_tensor_to_device(vae, name, device=offload_device, dtype=dtype, value=ckpt['vae'][name])
+
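The two config['...']['params']['attention_mode'] assignments above work because instantiate_from_config-style factories expand the params dict into constructor keyword arguments, which is how the new option reaches the patched __init__ signatures. A generic sketch of such a helper, assumed to match the wrapper's behaviour rather than copied from it:

# Generic instantiate_from_config sketch (assumption: the wrapper's helper behaves
# like the common latent-diffusion version): 'params' become constructor kwargs,
# so config['model']['params']['attention_mode'] lands in Hunyuan3DDiT.__init__.
import importlib

def get_obj_from_str(path: str):
    module_name, cls_name = path.rsplit(".", 1)
    return getattr(importlib.import_module(module_name), cls_name)

def instantiate_from_config(config: dict):
    return get_obj_from_str(config["target"])(**config.get("params", dict()))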
nodes.py (7 changed lines)
@@ -73,6 +73,7 @@ class Hy3DModelLoader:
             },
             "optional": {
                 "compile_args": ("HY3DCOMPILEARGS", {"tooltip": "torch.compile settings, when connected to the model loader, torch.compile of the selected models is attempted. Requires Triton and torch 2.5.0 is recommended"}),
+                "attention_mode": (["sdpa", "sageattn"], {"default": "sdpa"}),
             }
         }
 
@@ -81,7 +82,7 @@ class Hy3DModelLoader:
     FUNCTION = "loadmodel"
     CATEGORY = "Hunyuan3DWrapper"
 
-    def loadmodel(self, model, compile_args=None):
+    def loadmodel(self, model, compile_args=None, attention_mode="sdpa"):
         device = mm.get_torch_device()
         offload_device=mm.unet_offload_device()
 
@@ -93,7 +94,9 @@ class Hy3DModelLoader:
             use_safetensors=True,
             device=device,
             offload_device=offload_device,
-            compile_args=compile_args)
+            compile_args=compile_args,
+            attention_mode=attention_mode)
+
         return (pipe,)
 
 class DownloadAndLoadHy3DDelightModel:
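With the node inputs and loadmodel signature updated as above, the backend can be chosen per workflow from the Hy3DModelLoader dropdown; a hypothetical direct call inside a running ComfyUI session (the checkpoint filename is a placeholder):

# Hypothetical usage of the updated node; requires ComfyUI's runtime to be
# initialised, and the checkpoint name below is a placeholder.
loader = Hy3DModelLoader()
(pipe,) = loader.loadmodel(
    model="hunyuan3d-dit-v2-0.safetensors",  # placeholder checkpoint name
    compile_args=None,
    attention_mode="sageattn",  # or "sdpa" (default)
)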