diff --git a/hy3dgen/shapegen/models/hunyuan3ddit.py b/hy3dgen/shapegen/models/hunyuan3ddit.py
index d1c7786..d2fd277 100755
--- a/hy3dgen/shapegen/models/hunyuan3ddit.py
+++ b/hy3dgen/shapegen/models/hunyuan3ddit.py
@@ -31,11 +31,21 @@
 from einops import rearrange
 from torch import Tensor, nn
 
+
 def attention(q: Tensor, k: Tensor, v: Tensor, **kwargs) -> Tensor:
     x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
     x = rearrange(x, "B H L D -> B L (H D)")
     return x
 
+try:
+    from sageattention import sageattn
+except ImportError:
+    sageattn = None
+
+def attention_sage(q: Tensor, k: Tensor, v: Tensor, **kwargs) -> Tensor:
+    x = sageattn(q, k, v)
+    x = rearrange(x, "B H L D -> B L (H D)")
+    return x
 
 def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
     """
@@ -151,7 +161,12 @@ class DoubleStreamBlock(nn.Module):
         num_heads: int,
         mlp_ratio: float,
         qkv_bias: bool = False,
+        attention_mode: str = "sdpa",
     ):
         super().__init__()
+        if attention_mode == "sdpa":
+            self.attention_func = attention
+        elif attention_mode == "sageattn":
+            self.attention_func = attention_sage
         mlp_hidden_dim = int(hidden_size * mlp_ratio)
         self.num_heads = num_heads
@@ -198,7 +213,7 @@ class DoubleStreamBlock(nn.Module):
         k = torch.cat((txt_k, img_k), dim=2)
         v = torch.cat((txt_v, img_v), dim=2)
 
-        attn = attention(q, k, v, pe=pe)
+        attn = self.attention_func(q, k, v, pe=pe)
         txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
 
         img = img + img_mod1.gate * self.img_attn.proj(img_attn)
@@ -221,8 +236,13 @@ class SingleStreamBlock(nn.Module):
         num_heads: int,
         mlp_ratio: float = 4.0,
         qk_scale: Optional[float] = None,
+        attention_mode: str = "sdpa",
     ):
         super().__init__()
+        if attention_mode == "sdpa":
+            self.attention_func = attention
+        elif attention_mode == "sageattn":
+            self.attention_func = attention_sage
         self.hidden_dim = hidden_size
         self.num_heads = num_heads
@@ -253,7 +273,7 @@ class SingleStreamBlock(nn.Module):
         q, k = self.norm(q, k, v)
 
         # compute attention
-        attn = attention(q, k, v, pe=pe)
+        attn = self.attention_func(q, k, v, pe=pe)
         # compute activation in mlp stream, cat again and run second linear layer
         output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
         return x + mod.gate * output
@@ -288,6 +308,7 @@ class Hunyuan3DDiT(nn.Module):
         qkv_bias: bool = True,
         time_factor: float = 1000,
         ckpt_path: Optional[str] = None,
+        attention_mode: str = "sdpa",
         **kwargs,
     ):
         super().__init__()
@@ -303,6 +324,7 @@ class Hunyuan3DDiT(nn.Module):
         self.qkv_bias = qkv_bias
         self.time_factor = time_factor
         self.out_channels = self.in_channels
+        self.attention_mode = attention_mode
 
         if hidden_size % num_heads != 0:
             raise ValueError(
@@ -324,6 +346,7 @@ class Hunyuan3DDiT(nn.Module):
                     self.num_heads,
                     mlp_ratio=mlp_ratio,
                     qkv_bias=qkv_bias,
+                    attention_mode=self.attention_mode,
                 )
                 for _ in range(depth)
             ]
@@ -335,6 +358,7 @@ class Hunyuan3DDiT(nn.Module):
                     self.hidden_size,
                     self.num_heads,
                     mlp_ratio=mlp_ratio,
+                    attention_mode=self.attention_mode,
                 )
                 for _ in range(depth_single_blocks)
             ]
@@ -342,6 +366,8 @@ class Hunyuan3DDiT(nn.Module):
 
         self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
 
+
+
         if ckpt_path is not None:
             print('restored denoiser ckpt', ckpt_path)
diff --git a/hy3dgen/shapegen/models/vae.py b/hy3dgen/shapegen/models/vae.py
index 0d15caa..e1761ac 100755
--- a/hy3dgen/shapegen/models/vae.py
+++ b/hy3dgen/shapegen/models/vae.py
@@ -31,7 +31,10 @@
 import torch.nn.functional as F
 from einops import rearrange, repeat
 from skimage import measure
 from tqdm import tqdm
-
+try:
+    from sageattention import sageattn
+except ImportError:
+    sageattn = None
 from comfy.utils import ProgressBar
 
 class FourierEmbedder(nn.Module):
@@ -189,13 +189,18 @@ class QKVMultiheadCrossAttention(nn.Module):
         n_data: Optional[int] = None,
         width=None,
         qk_norm=False,
-        norm_layer=nn.LayerNorm
+        norm_layer=nn.LayerNorm,
+        attention_mode: str = "sdpa"
     ):
         super().__init__()
         self.heads = heads
         self.n_data = n_data
         self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
         self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
+        if attention_mode == "sdpa":
+            self.attention_func = F.scaled_dot_product_attention
+        elif attention_mode == "sageattn":
+            self.attention_func = sageattn
 
     def forward(self, q, kv):
         _, n_ctx, _ = q.shape
@@ -209,7 +214,7 @@ class QKVMultiheadCrossAttention(nn.Module):
             k = self.k_norm(k)
 
         q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
-        out = F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1)
+        out = self.attention_func(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1)
         return out
 
@@ -224,7 +229,8 @@ class MultiheadCrossAttention(nn.Module):
         n_data: Optional[int] = None,
         data_width: Optional[int] = None,
         norm_layer=nn.LayerNorm,
-        qk_norm: bool = False
+        qk_norm: bool = False,
+        attention_mode: str = "sdpa"
     ):
         super().__init__()
         self.n_data = n_data
@@ -239,7 +245,8 @@ class MultiheadCrossAttention(nn.Module):
             n_data=n_data,
             width=width,
             norm_layer=norm_layer,
-            qk_norm=qk_norm
+            qk_norm=qk_norm,
+            attention_mode=attention_mode
         )
 
     def forward(self, x, data):
@@ -260,12 +267,15 @@ class ResidualCrossAttentionBlock(nn.Module):
         data_width: Optional[int] = None,
         qkv_bias: bool = True,
         norm_layer=nn.LayerNorm,
-        qk_norm: bool = False
+        qk_norm: bool = False,
+        attention_mode: str = "sdpa"
     ):
         super().__init__()
 
         if data_width is None:
             data_width = width
+
+        self.attention_mode = attention_mode
 
         self.attn = MultiheadCrossAttention(
             n_data=n_data,
@@ -274,7 +284,8 @@ class ResidualCrossAttentionBlock(nn.Module):
             data_width=data_width,
             qkv_bias=qkv_bias,
             norm_layer=norm_layer,
-            qk_norm=qk_norm
+            qk_norm=qk_norm,
+            attention_mode=self.attention_mode
         )
         self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
         self.ln_2 = norm_layer(data_width, elementwise_affine=True, eps=1e-6)
@@ -295,13 +306,18 @@ class QKVMultiheadAttention(nn.Module):
         n_ctx: int,
         width=None,
         qk_norm=False,
-        norm_layer=nn.LayerNorm
+        norm_layer=nn.LayerNorm,
+        attention_mode: str = "sdpa"
     ):
         super().__init__()
         self.heads = heads
         self.n_ctx = n_ctx
         self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
         self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
+        if attention_mode == "sdpa":
+            self.attention_func = F.scaled_dot_product_attention
+        elif attention_mode == "sageattn":
+            self.attention_func = sageattn
 
     def forward(self, qkv):
         bs, n_ctx, width = qkv.shape
@@ -313,7 +329,7 @@ class QKVMultiheadAttention(nn.Module):
             k = self.k_norm(k)
 
         q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
-        out = F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1)
+        out = self.attention_func(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1)
         return out
 
@@ -433,7 +449,8 @@ class CrossAttentionDecoder(nn.Module):
         heads: int,
         qkv_bias: bool = True,
         qk_norm: bool = False,
-        label_type: str = "binary"
+        label_type: str = "binary",
+        attention_mode: str = "sdpa"
"sdpa" ): super().__init__() @@ -446,7 +463,8 @@ class CrossAttentionDecoder(nn.Module): width=width, heads=heads, qkv_bias=qkv_bias, - qk_norm=qk_norm + qk_norm=qk_norm, + attention_mode=attention_mode ) self.ln_post = nn.LayerNorm(width) @@ -514,6 +532,7 @@ class ShapeVAE(nn.Module): label_type: str = "binary", drop_path_rate: float = 0.0, scale_factor: float = 1.0, + attention_mode: str = "sdpa" ): super().__init__() self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi) @@ -539,11 +558,14 @@ class ShapeVAE(nn.Module): qkv_bias=qkv_bias, qk_norm=qk_norm, label_type=label_type, + attention_mode=attention_mode ) self.scale_factor = scale_factor self.latent_shape = (num_latents, embed_dim) + self.attention_mode = attention_mode + def forward(self, latents): latents = self.post_kl(latents) latents = self.transformer(latents) diff --git a/hy3dgen/shapegen/pipelines.py b/hy3dgen/shapegen/pipelines.py index 5b5cd89..4435bf2 100755 --- a/hy3dgen/shapegen/pipelines.py +++ b/hy3dgen/shapegen/pipelines.py @@ -153,6 +153,7 @@ class Hunyuan3DDiTPipeline: dtype=torch.float16, use_safetensors=None, compile_args=None, + attention_mode="sdpa", **kwargs, ): # load config @@ -179,17 +180,19 @@ class Hunyuan3DDiTPipeline: ckpt[model_name][new_key] = value else: ckpt = torch.load(ckpt_path, map_location='cpu') + + # load model + config['model']['params']['attention_mode'] = attention_mode + config['vae']['params']['attention_mode'] = attention_mode with init_empty_weights(): model = instantiate_from_config(config['model']) vae = instantiate_from_config(config['vae']) conditioner = instantiate_from_config(config['conditioner']) #model - #model.load_state_dict(ckpt['model']) for name, param in model.named_parameters(): set_module_tensor_to_device(model, name, device=offload_device, dtype=dtype, value=ckpt['model'][name]) #vae - #vae.load_state_dict(ckpt['vae']) for name, param in vae.named_parameters(): set_module_tensor_to_device(vae, name, device=offload_device, dtype=dtype, value=ckpt['vae'][name]) diff --git a/nodes.py b/nodes.py index 84d793e..ce4abbb 100644 --- a/nodes.py +++ b/nodes.py @@ -73,6 +73,7 @@ class Hy3DModelLoader: }, "optional": { "compile_args": ("HY3DCOMPILEARGS", {"tooltip": "torch.compile settings, when connected to the model loader, torch.compile of the selected models is attempted. Requires Triton and torch 2.5.0 is recommended"}), + "attention_mode": (["sdpa", "sageattn"], {"default": "sdpa"}), } } @@ -81,7 +82,7 @@ class Hy3DModelLoader: FUNCTION = "loadmodel" CATEGORY = "Hunyuan3DWrapper" - def loadmodel(self, model, compile_args=None): + def loadmodel(self, model, compile_args=None, attention_mode="sdpa"): device = mm.get_torch_device() offload_device=mm.unet_offload_device() @@ -93,7 +94,9 @@ class Hy3DModelLoader: use_safetensors=True, device=device, offload_device=offload_device, - compile_args=compile_args) + compile_args=compile_args, + attention_mode=attention_mode) + return (pipe,) class DownloadAndLoadHy3DDelightModel: