Support CublasOps for decoding

Uses https://github.com/aredden/torch-cublas-hgemm; on a 4090, decoding is ~25% faster.
parent a3d0277aed
commit 17d964aa37
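In outline, the change threads an injectable layer class, linear_func, through the VAE's submodules (MLP, the attention blocks, Transformer, CrossAttentionDecoder), and ShapeVAE picks CublasLinear from torch-cublas-hgemm when the new cublas_ops flag is set, falling back to nn.Linear when the package is not installed. Below is a minimal sketch of that pattern, assuming only what the diff shows; TinyMLP and build_tiny_mlp are illustrative names, not code from the repo.

import torch
import torch.nn as nn

# Optional dependency: fall back to the stock layer when
# torch-cublas-hgemm is not installed, as the diff does.
try:
    from cublas_ops import CublasLinear
except ImportError:
    CublasLinear = None


class TinyMLP(nn.Module):
    # Illustrative module: the linear layer class is injected,
    # mirroring the wrapper's linear_func argument.
    def __init__(self, width: int, linear_func: type = nn.Linear):
        super().__init__()
        self.c_fc = linear_func(width, width * 4)
        self.c_proj = linear_func(width * 4, width)
        self.gelu = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.c_proj(self.gelu(self.c_fc(x)))


def build_tiny_mlp(width: int, cublas_ops: bool = False) -> TinyMLP:
    # Same selection as ShapeVAE.__init__ in the diff, with an extra
    # availability guard added for this sketch.
    linear_func = CublasLinear if (cublas_ops and CublasLinear is not None) else nn.Linear
    return TinyMLP(width, linear_func=linear_func)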
@@ -36,6 +36,10 @@ try:
 except ImportError:
     sageattn = None
 from comfy.utils import ProgressBar
+try:
+    from cublas_ops import CublasLinear
+except ImportError:
+    CublasLinear = None
 
 class FourierEmbedder(nn.Module):
     """The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts
@@ -171,12 +175,14 @@ class MLP(nn.Module):
         self, *,
         width: int,
         output_width: int = None,
-        drop_path_rate: float = 0.0
+        drop_path_rate: float = 0.0,
+        linear_func: nn.Module = nn.Linear
     ):
         super().__init__()
 
         self.width = width
-        self.c_fc = nn.Linear(width, width * 4)
-        self.c_proj = nn.Linear(width * 4, output_width if output_width is not None else width)
+        self.c_fc = linear_func(width, width * 4)
+        self.c_proj = linear_func(width * 4, output_width if output_width is not None else width)
         self.gelu = nn.GELU()
         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
 
@@ -233,16 +239,18 @@ class MultiheadCrossAttention(nn.Module):
         data_width: Optional[int] = None,
         norm_layer=nn.LayerNorm,
         qk_norm: bool = False,
-        attention_mode: str = "sdpa"
+        attention_mode: str = "sdpa",
+        linear_func: nn.Module = nn.Linear
     ):
         super().__init__()
 
         self.n_data = n_data
         self.width = width
         self.heads = heads
         self.data_width = width if data_width is None else data_width
-        self.c_q = nn.Linear(width, width, bias=qkv_bias)
-        self.c_kv = nn.Linear(self.data_width, width * 2, bias=qkv_bias)
-        self.c_proj = nn.Linear(width, width)
+        self.c_q = linear_func(width, width, bias=qkv_bias)
+        self.c_kv = linear_func(self.data_width, width * 2, bias=qkv_bias)
+        self.c_proj = linear_func(width, width)
         self.attention = QKVMultiheadCrossAttention(
             heads=heads,
             n_data=n_data,
@@ -271,9 +279,11 @@ class ResidualCrossAttentionBlock(nn.Module):
         qkv_bias: bool = True,
         norm_layer=nn.LayerNorm,
         qk_norm: bool = False,
-        attention_mode: str = "sdpa"
+        attention_mode: str = "sdpa",
+        linear_func: nn.Module = nn.Linear
     ):
         super().__init__()
 
         if data_width is None:
             data_width = width
@@ -288,12 +298,13 @@ class ResidualCrossAttentionBlock(nn.Module):
             qkv_bias=qkv_bias,
             norm_layer=norm_layer,
             qk_norm=qk_norm,
-            attention_mode=self.attention_mode
+            attention_mode=self.attention_mode,
+            linear_func=linear_func
         )
         self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
         self.ln_2 = norm_layer(data_width, elementwise_affine=True, eps=1e-6)
         self.ln_3 = norm_layer(width, elementwise_affine=True, eps=1e-6)
-        self.mlp = MLP(width=width)
+        self.mlp = MLP(width=width, linear_func=linear_func)
 
     def forward(self, x: torch.Tensor, data: torch.Tensor):
         x = x + self.attn(self.ln_1(x), self.ln_2(data))
@@ -346,14 +357,15 @@ class MultiheadAttention(nn.Module):
         qkv_bias: bool,
         norm_layer=nn.LayerNorm,
         qk_norm: bool = False,
-        drop_path_rate: float = 0.0
+        drop_path_rate: float = 0.0,
+        linear_func: nn.Module = nn.Linear
     ):
         super().__init__()
         self.n_ctx = n_ctx
         self.width = width
         self.heads = heads
-        self.c_qkv = nn.Linear(width, width * 3, bias=qkv_bias)
-        self.c_proj = nn.Linear(width, width)
+        self.c_qkv = linear_func(width, width * 3, bias=qkv_bias)
+        self.c_proj = linear_func(width, width)
         self.attention = QKVMultiheadAttention(
             heads=heads,
             n_ctx=n_ctx,
@@ -381,6 +393,7 @@ class ResidualAttentionBlock(nn.Module):
         norm_layer=nn.LayerNorm,
         qk_norm: bool = False,
         drop_path_rate: float = 0.0,
+        linear_func: nn.Module = nn.Linear
     ):
         super().__init__()
         self.attn = MultiheadAttention(
@@ -390,10 +403,11 @@ class ResidualAttentionBlock(nn.Module):
             qkv_bias=qkv_bias,
             norm_layer=norm_layer,
             qk_norm=qk_norm,
-            drop_path_rate=drop_path_rate
+            drop_path_rate=drop_path_rate,
+            linear_func=linear_func
         )
         self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
-        self.mlp = MLP(width=width, drop_path_rate=drop_path_rate)
+        self.mlp = MLP(width=width, drop_path_rate=drop_path_rate, linear_func=linear_func)
         self.ln_2 = norm_layer(width, elementwise_affine=True, eps=1e-6)
 
     def forward(self, x: torch.Tensor):
@@ -413,9 +427,11 @@ class Transformer(nn.Module):
         qkv_bias: bool = True,
         norm_layer=nn.LayerNorm,
         qk_norm: bool = False,
-        drop_path_rate: float = 0.0
+        drop_path_rate: float = 0.0,
+        linear_func: nn.Module = nn.Linear
     ):
         super().__init__()
 
         self.n_ctx = n_ctx
         self.width = width
         self.layers = layers
@@ -428,7 +444,8 @@ class Transformer(nn.Module):
                 qkv_bias=qkv_bias,
                 norm_layer=norm_layer,
                 qk_norm=qk_norm,
-                drop_path_rate=drop_path_rate
+                drop_path_rate=drop_path_rate,
+                linear_func=linear_func
             )
             for _ in range(layers)
         ]
@@ -453,13 +470,15 @@ class CrossAttentionDecoder(nn.Module):
         qkv_bias: bool = True,
         qk_norm: bool = False,
         label_type: str = "binary",
-        attention_mode: str = "sdpa"
+        attention_mode: str = "sdpa",
+        linear_func: nn.Module = nn.Linear
     ):
         super().__init__()
 
         self.fourier_embedder = fourier_embedder
 
-        self.query_proj = nn.Linear(self.fourier_embedder.out_dim, width)
+        self.query_proj = linear_func(self.fourier_embedder.out_dim, width)
 
         self.cross_attn_decoder = ResidualCrossAttentionBlock(
             n_data=num_latents,
@@ -467,11 +486,12 @@ class CrossAttentionDecoder(nn.Module):
             heads=heads,
             qkv_bias=qkv_bias,
             qk_norm=qk_norm,
-            attention_mode=attention_mode
+            attention_mode=attention_mode,
+            linear_func=linear_func
         )
 
         self.ln_post = nn.LayerNorm(width)
-        self.output_proj = nn.Linear(width, out_channels)
+        self.output_proj = linear_func(width, out_channels)
         self.label_type = label_type
 
     def forward(self, queries: torch.FloatTensor, latents: torch.FloatTensor):
@@ -535,12 +555,16 @@ class ShapeVAE(nn.Module):
         label_type: str = "binary",
         drop_path_rate: float = 0.0,
         scale_factor: float = 1.0,
-        attention_mode: str = "sdpa"
+        attention_mode: str = "sdpa",
+        cublas_ops: bool = False
     ):
         super().__init__()
+        linear_func = CublasLinear if cublas_ops else nn.Linear
 
         self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
 
-        self.post_kl = nn.Linear(embed_dim, width)
+        self.post_kl = linear_func(embed_dim, width)
 
         self.transformer = Transformer(
             n_ctx=num_latents,
@@ -549,7 +573,8 @@ class ShapeVAE(nn.Module):
             heads=heads,
             qkv_bias=qkv_bias,
             qk_norm=qk_norm,
-            drop_path_rate=drop_path_rate
+            drop_path_rate=drop_path_rate,
+            linear_func=linear_func
         )
 
         self.geo_decoder = CrossAttentionDecoder(
@@ -561,7 +586,8 @@ class ShapeVAE(nn.Module):
             qkv_bias=qkv_bias,
             qk_norm=qk_norm,
             label_type=label_type,
-            attention_mode=attention_mode
+            attention_mode=attention_mode,
+            linear_func=linear_func
         )
 
         self.scale_factor = scale_factor

@@ -154,6 +154,7 @@ class Hunyuan3DDiTPipeline:
         use_safetensors=None,
         compile_args=None,
         attention_mode="sdpa",
+        cublas_ops=False,
         **kwargs,
     ):
         # load config
@@ -188,6 +189,9 @@ class Hunyuan3DDiTPipeline:
         config['model']['params']['guidance_embed'] = True
         config['model']['params']['attention_mode'] = attention_mode
         config['vae']['params']['attention_mode'] = attention_mode
+
+        if cublas_ops:
+            config['vae']['params']['cublas_ops'] = True
 
         with init_empty_weights():
             model = instantiate_from_config(config['model'])

nodes.py
@@ -112,6 +112,7 @@ class Hy3DModelLoader:
            "optional": {
                "compile_args": ("HY3DCOMPILEARGS", {"tooltip": "torch.compile settings, when connected to the model loader, torch.compile of the selected models is attempted. Requires Triton and torch 2.5.0 is recommended"}),
                "attention_mode": (["sdpa", "sageattn"], {"default": "sdpa"}),
+               "cublas_ops": ("BOOLEAN", {"default": False, "tooltip": "Enable optimized cublas linear layers, speeds up decoding: https://github.com/aredden/torch-cublas-hgemm"}),
            }
        }
 
@@ -120,7 +121,7 @@ class Hy3DModelLoader:
     FUNCTION = "loadmodel"
     CATEGORY = "Hunyuan3DWrapper"
 
-    def loadmodel(self, model, compile_args=None, attention_mode="sdpa"):
+    def loadmodel(self, model, compile_args=None, attention_mode="sdpa", cublas_ops=False):
         device = mm.get_torch_device()
         offload_device=mm.unet_offload_device()
 
@@ -133,7 +134,8 @@ class Hy3DModelLoader:
            device=device,
            offload_device=offload_device,
            compile_args=compile_args,
-           attention_mode=attention_mode)
+           attention_mode=attention_mode,
+           cublas_ops=cublas_ops)
 
         return (pipe, vae,)
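The ~25% figure is specific to decoding on a 4090; the kernel-level gain can be roughly checked on another GPU by timing a large matmul with both layer types. This is a sketch under the assumption, implied by the diff's drop-in swap, that CublasLinear takes the same (in_features, out_features) constructor arguments and exposes nn.Linear-compatible parameter names; the shapes and iteration counts are illustrative and this is not a reproduction of the decode benchmark.

import time

import torch
import torch.nn as nn

try:
    from cublas_ops import CublasLinear
except ImportError:
    CublasLinear = None


@torch.inference_mode()
def bench(layer: nn.Module, x: torch.Tensor, iters: int = 50) -> float:
    # Rough wall-clock average over repeated forward passes.
    for _ in range(5):  # warmup
        layer(x)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        layer(x)
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters


if torch.cuda.is_available() and CublasLinear is not None:
    x = torch.randn(8192, 1024, device="cuda", dtype=torch.float16)
    baseline = nn.Linear(1024, 4096).to("cuda", torch.float16)
    fast = CublasLinear(1024, 4096).to("cuda", torch.float16)
    # Assumes CublasLinear uses nn.Linear-compatible weight/bias names,
    # which the checkpoint-loading path also relies on.
    fast.load_state_dict(baseline.state_dict())
    print(f"nn.Linear:    {bench(baseline, x) * 1e3:.3f} ms")
    print(f"CublasLinear: {bench(fast, x) * 1e3:.3f} ms")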