From 17d964aa3777a8dac71b77aecaa853c9808c0574 Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Tue, 4 Feb 2025 16:28:35 +0200
Subject: [PATCH] Support CublasOps for decoding

Use the optimized CublasLinear layers from
https://github.com/aredden/torch-cublas-hgemm for the shape VAE's linear
projections when enabled. On an RTX 4090 this makes VAE decoding ~25% faster.
---
 hy3dgen/shapegen/models/vae.py | 76 +++++++++++++++++++++++-----------
 hy3dgen/shapegen/pipelines.py  |  4 ++
 nodes.py                       |  6 ++-
 3 files changed, 59 insertions(+), 27 deletions(-)

diff --git a/hy3dgen/shapegen/models/vae.py b/hy3dgen/shapegen/models/vae.py
index 03a26a0..31e98be 100755
--- a/hy3dgen/shapegen/models/vae.py
+++ b/hy3dgen/shapegen/models/vae.py
@@ -36,6 +36,10 @@ try:
 except ImportError:
     sageattn = None
 from comfy.utils import ProgressBar
+try:
+    from cublas_ops import CublasLinear
+except ImportError:
+    CublasLinear = None
 
 class FourierEmbedder(nn.Module):
     """The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts
@@ -171,12 +175,14 @@ class MLP(nn.Module):
         self, *,
         width: int,
         output_width: int = None,
-        drop_path_rate: float = 0.0
+        drop_path_rate: float = 0.0,
+        linear_func: nn.Module = nn.Linear
     ):
         super().__init__()
+
         self.width = width
-        self.c_fc = nn.Linear(width, width * 4)
-        self.c_proj = nn.Linear(width * 4, output_width if output_width is not None else width)
+        self.c_fc = linear_func(width, width * 4)
+        self.c_proj = linear_func(width * 4, output_width if output_width is not None else width)
         self.gelu = nn.GELU()
         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
 
@@ -233,16 +239,18 @@ class MultiheadCrossAttention(nn.Module):
         data_width: Optional[int] = None,
         norm_layer=nn.LayerNorm,
         qk_norm: bool = False,
-        attention_mode: str = "sdpa"
+        attention_mode: str = "sdpa",
+        linear_func: nn.Module = nn.Linear
     ):
         super().__init__()
+
         self.n_data = n_data
         self.width = width
         self.heads = heads
         self.data_width = width if data_width is None else data_width
-        self.c_q = nn.Linear(width, width, bias=qkv_bias)
-        self.c_kv = nn.Linear(self.data_width, width * 2, bias=qkv_bias)
-        self.c_proj = nn.Linear(width, width)
+        self.c_q = linear_func(width, width, bias=qkv_bias)
+        self.c_kv = linear_func(self.data_width, width * 2, bias=qkv_bias)
+        self.c_proj = linear_func(width, width)
         self.attention = QKVMultiheadCrossAttention(
             heads=heads,
             n_data=n_data,
@@ -271,9 +279,11 @@ class ResidualCrossAttentionBlock(nn.Module):
         qkv_bias: bool = True,
         norm_layer=nn.LayerNorm,
         qk_norm: bool = False,
-        attention_mode: str = "sdpa"
+        attention_mode: str = "sdpa",
+        linear_func: nn.Module = nn.Linear
     ):
         super().__init__()
+
         if data_width is None:
             data_width = width
 
@@ -288,12 +298,13 @@ class ResidualCrossAttentionBlock(nn.Module):
             qkv_bias=qkv_bias,
             norm_layer=norm_layer,
             qk_norm=qk_norm,
-            attention_mode=self.attention_mode
+            attention_mode=self.attention_mode,
+            linear_func=linear_func
         )
         self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
         self.ln_2 = norm_layer(data_width, elementwise_affine=True, eps=1e-6)
         self.ln_3 = norm_layer(width, elementwise_affine=True, eps=1e-6)
-        self.mlp = MLP(width=width)
+        self.mlp = MLP(width=width, linear_func=linear_func)
 
     def forward(self, x: torch.Tensor, data: torch.Tensor):
         x = x + self.attn(self.ln_1(x), self.ln_2(data))
@@ -346,14 +357,15 @@ class MultiheadAttention(nn.Module):
         qkv_bias: bool,
         norm_layer=nn.LayerNorm,
         qk_norm: bool = False,
-        drop_path_rate: float = 0.0
+        drop_path_rate: float = 0.0,
+        linear_func: nn.Module = nn.Linear
    ):
         super().__init__()
         self.n_ctx = n_ctx
         self.width = width
         self.heads = heads
-        self.c_qkv = nn.Linear(width, width * 3, bias=qkv_bias)
-        self.c_proj = nn.Linear(width, width)
+        self.c_qkv = linear_func(width, width * 3, bias=qkv_bias)
+        self.c_proj = linear_func(width, width)
         self.attention = QKVMultiheadAttention(
             heads=heads,
             n_ctx=n_ctx,
@@ -381,6 +393,7 @@ class ResidualAttentionBlock(nn.Module):
         norm_layer=nn.LayerNorm,
         qk_norm: bool = False,
         drop_path_rate: float = 0.0,
+        linear_func: nn.Module = nn.Linear
     ):
         super().__init__()
         self.attn = MultiheadAttention(
@@ -390,10 +403,11 @@ class ResidualAttentionBlock(nn.Module):
             qkv_bias=qkv_bias,
             norm_layer=norm_layer,
             qk_norm=qk_norm,
-            drop_path_rate=drop_path_rate
+            drop_path_rate=drop_path_rate,
+            linear_func=linear_func
         )
         self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
-        self.mlp = MLP(width=width, drop_path_rate=drop_path_rate)
+        self.mlp = MLP(width=width, drop_path_rate=drop_path_rate, linear_func=linear_func)
         self.ln_2 = norm_layer(width, elementwise_affine=True, eps=1e-6)
 
     def forward(self, x: torch.Tensor):
@@ -413,9 +427,11 @@ class Transformer(nn.Module):
         qkv_bias: bool = True,
         norm_layer=nn.LayerNorm,
         qk_norm: bool = False,
-        drop_path_rate: float = 0.0
+        drop_path_rate: float = 0.0,
+        linear_func: nn.Module = nn.Linear
     ):
         super().__init__()
+
         self.n_ctx = n_ctx
         self.width = width
         self.layers = layers
@@ -428,7 +444,8 @@ class Transformer(nn.Module):
                 qkv_bias=qkv_bias,
                 norm_layer=norm_layer,
                 qk_norm=qk_norm,
-                drop_path_rate=drop_path_rate
+                drop_path_rate=drop_path_rate,
+                linear_func=linear_func
             )
             for _ in range(layers)
         ]
@@ -453,13 +470,15 @@ class CrossAttentionDecoder(nn.Module):
         qkv_bias: bool = True,
         qk_norm: bool = False,
         label_type: str = "binary",
-        attention_mode: str = "sdpa"
+        attention_mode: str = "sdpa",
+        linear_func: nn.Module = nn.Linear
     ):
         super().__init__()
+
         self.fourier_embedder = fourier_embedder
 
-        self.query_proj = nn.Linear(self.fourier_embedder.out_dim, width)
+        self.query_proj = linear_func(self.fourier_embedder.out_dim, width)
 
         self.cross_attn_decoder = ResidualCrossAttentionBlock(
             n_data=num_latents,
             width=width,
             heads=heads,
             qkv_bias=qkv_bias,
             qk_norm=qk_norm,
-            attention_mode=attention_mode
+            attention_mode=attention_mode,
+            linear_func=linear_func
         )
 
         self.ln_post = nn.LayerNorm(width)
-        self.output_proj = nn.Linear(width, out_channels)
+        self.output_proj = linear_func(width, out_channels)
         self.label_type = label_type
 
     def forward(self, queries: torch.FloatTensor, latents: torch.FloatTensor):
@@ -535,12 +555,16 @@ class ShapeVAE(nn.Module):
         label_type: str = "binary",
         drop_path_rate: float = 0.0,
         scale_factor: float = 1.0,
-        attention_mode: str = "sdpa"
+        attention_mode: str = "sdpa",
+        cublas_ops: bool = False
     ):
         super().__init__()
+        linear_func = CublasLinear if cublas_ops else nn.Linear
+
+
         self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
 
-        self.post_kl = nn.Linear(embed_dim, width)
+        self.post_kl = linear_func(embed_dim, width)
 
         self.transformer = Transformer(
             n_ctx=num_latents,
             width=width,
             layers=layers,
             heads=heads,
             qkv_bias=qkv_bias,
             qk_norm=qk_norm,
-            drop_path_rate=drop_path_rate
+            drop_path_rate=drop_path_rate,
+            linear_func=linear_func
         )
 
         self.geo_decoder = CrossAttentionDecoder(
             qkv_bias=qkv_bias,
             qk_norm=qk_norm,
             label_type=label_type,
-            attention_mode=attention_mode
+            attention_mode=attention_mode,
+            linear_func=linear_func
         )
 
         self.scale_factor = scale_factor
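
Editor's note on the vae.py changes above: every nn.Linear in the decode path (MLP, the attention projections, Transformer and CrossAttentionDecoder) is now constructed through linear_func, and ShapeVAE picks that class with linear_func = CublasLinear if cublas_ops else nn.Linear. A minimal sketch of that selection follows. The pick_linear_func helper and its availability check are illustrative assumptions, not part of the patch; as written, the patch would raise a TypeError if cublas_ops=True is requested while the cublas_ops package is missing, because CublasLinear is then None.

import torch.nn as nn

try:
    # optimized fp16 linear layer from https://github.com/aredden/torch-cublas-hgemm
    from cublas_ops import CublasLinear
except ImportError:
    CublasLinear = None


def pick_linear_func(cublas_ops: bool) -> type:
    # The patch itself just does `CublasLinear if cublas_ops else nn.Linear`;
    # the None check here is an extra (hypothetical) guard for the case where
    # cublas_ops=True but the package is not installed.
    if cublas_ops and CublasLinear is not None:
        return CublasLinear
    return nn.Linear


# The selected class is used exactly like nn.Linear, which is what the patch
# relies on when it calls linear_func(width, width * 4) or passes bias=qkv_bias.
linear_func = pick_linear_func(cublas_ops=True)
post_kl = linear_func(64, 1024)  # drop-in stand-in for ShapeVAE's post_kl projection
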
diff --git a/hy3dgen/shapegen/pipelines.py b/hy3dgen/shapegen/pipelines.py
index 8c3f856..6af6475 100755
--- a/hy3dgen/shapegen/pipelines.py
+++ b/hy3dgen/shapegen/pipelines.py
@@ -154,6 +154,7 @@ class Hunyuan3DDiTPipeline:
         use_safetensors=None,
         compile_args=None,
         attention_mode="sdpa",
+        cublas_ops=False,
         **kwargs,
     ):
         # load config
@@ -188,6 +189,9 @@ class Hunyuan3DDiTPipeline:
         config['model']['params']['guidance_embed'] = True
         config['model']['params']['attention_mode'] = attention_mode
         config['vae']['params']['attention_mode'] = attention_mode
+
+        if cublas_ops:
+            config['vae']['params']['cublas_ops'] = True
 
         with init_empty_weights():
             model = instantiate_from_config(config['model'])
diff --git a/nodes.py b/nodes.py
index ba943bb..6066528 100644
--- a/nodes.py
+++ b/nodes.py
@@ -112,6 +112,7 @@ class Hy3DModelLoader:
             "optional": {
                 "compile_args": ("HY3DCOMPILEARGS", {"tooltip": "torch.compile settings, when connected to the model loader, torch.compile of the selected models is attempted. Requires Triton and torch 2.5.0 is recommended"}),
                 "attention_mode": (["sdpa", "sageattn"], {"default": "sdpa"}),
+                "cublas_ops": ("BOOLEAN", {"default": False, "tooltip": "Enable optimized cublas linear layers, speeds up decoding: https://github.com/aredden/torch-cublas-hgemm"}),
             }
         }
 
@@ -120,7 +121,7 @@ class Hy3DModelLoader:
     FUNCTION = "loadmodel"
     CATEGORY = "Hunyuan3DWrapper"
 
-    def loadmodel(self, model, compile_args=None, attention_mode="sdpa"):
+    def loadmodel(self, model, compile_args=None, attention_mode="sdpa", cublas_ops=False):
         device = mm.get_torch_device()
         offload_device=mm.unet_offload_device()
 
@@ -133,7 +134,8 @@ class Hy3DModelLoader:
             device=device,
             offload_device=offload_device,
             compile_args=compile_args,
-            attention_mode=attention_mode)
+            attention_mode=attention_mode,
+            cublas_ops=cublas_ops)
 
         return (pipe, vae,)
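
Editor's note: a quick way to sanity-check the ~25% decoding speedup reported in the commit message on your own GPU is a standalone micro-benchmark of the swapped layer type. The sketch below is not part of the patch; the layer width and token count are made-up stand-ins rather than values read from the VAE config, and it assumes CublasLinear takes the same constructor arguments as nn.Linear, which is what the patch itself relies on.

import time

import torch
import torch.nn as nn
from cublas_ops import CublasLinear  # https://github.com/aredden/torch-cublas-hgemm


def bench(layer: nn.Module, x: torch.Tensor, iters: int = 100) -> float:
    # average per-call latency in seconds, with warmup and CUDA synchronization
    with torch.no_grad():
        for _ in range(10):
            layer(x)
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(iters):
            layer(x)
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters


width = 1024  # assumed width, not taken from the model config
x = torch.randn(1, 8192, width, device="cuda", dtype=torch.float16)

baseline = nn.Linear(width, width * 4).cuda().half()
cublas = CublasLinear(width, width * 4).cuda().half()

print(f"nn.Linear:    {bench(baseline, x) * 1e3:.3f} ms")
print(f"CublasLinear: {bench(cublas, x) * 1e3:.3f} ms")
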