Support CublasOps for decoding

https://github.com/aredden/torch-cublas-hgemm

On an RTX 4090 this is ~25% faster; a minimal sketch of the underlying nn.Linear swap follows the commit metadata below.
kijai 2025-02-04 16:28:35 +02:00
parent a3d0277aed
commit 17d964aa37
3 changed files with 59 additions and 27 deletions
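
The whole change reduces to picking the linear-layer class once and threading it through the VAE modules as linear_func. The sketch below mirrors that selection logic; pick_linear and the layer sizes are illustrative, not part of the commit.

import torch.nn as nn

# Optional dependency (https://github.com/aredden/torch-cublas-hgemm);
# fall back to the stock layer when it is not installed.
try:
    from cublas_ops import CublasLinear
except ImportError:
    CublasLinear = None

def pick_linear(cublas_ops: bool = False):
    # Hypothetical helper: the commit inlines this one-liner in ShapeVAE.
    # An explicit error is friendlier in a standalone sketch; the commit
    # itself assumes the package is installed whenever the flag is on.
    if cublas_ops and CublasLinear is None:
        raise ImportError("cublas_ops=True but torch-cublas-hgemm is not installed")
    return CublasLinear if cublas_ops else nn.Linear

# CublasLinear is used as a drop-in replacement, so it takes the same
# (in_features, out_features, bias=...) arguments as nn.Linear.
linear_func = pick_linear(cublas_ops=False)
layer = linear_func(1024, 1024 * 4, bias=True)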


@@ -36,6 +36,10 @@ try:
except ImportError:
sageattn = None
from comfy.utils import ProgressBar
+try:
+from cublas_ops import CublasLinear
+except ImportError:
+CublasLinear = None
class FourierEmbedder(nn.Module):
"""The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts
@@ -171,12 +175,14 @@ class MLP(nn.Module):
self, *,
width: int,
output_width: int = None,
-drop_path_rate: float = 0.0
+drop_path_rate: float = 0.0,
+linear_func: nn.Module = nn.Linear
):
super().__init__()
self.width = width
self.c_fc = nn.Linear(width, width * 4)
self.c_proj = nn.Linear(width * 4, output_width if output_width is not None else width)
self.c_fc = linear_func(width, width * 4)
self.c_proj = linear_func(width * 4, output_width if output_width is not None else width)
self.gelu = nn.GELU()
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
@@ -233,16 +239,18 @@ class MultiheadCrossAttention(nn.Module):
data_width: Optional[int] = None,
norm_layer=nn.LayerNorm,
qk_norm: bool = False,
attention_mode: str = "sdpa"
attention_mode: str = "sdpa",
linear_func: nn.Module = nn.Linear
):
super().__init__()
self.n_data = n_data
self.width = width
self.heads = heads
self.data_width = width if data_width is None else data_width
-self.c_q = nn.Linear(width, width, bias=qkv_bias)
-self.c_kv = nn.Linear(self.data_width, width * 2, bias=qkv_bias)
-self.c_proj = nn.Linear(width, width)
+self.c_q = linear_func(width, width, bias=qkv_bias)
+self.c_kv = linear_func(self.data_width, width * 2, bias=qkv_bias)
+self.c_proj = linear_func(width, width)
self.attention = QKVMultiheadCrossAttention(
heads=heads,
n_data=n_data,
@@ -271,9 +279,11 @@ class ResidualCrossAttentionBlock(nn.Module):
qkv_bias: bool = True,
norm_layer=nn.LayerNorm,
qk_norm: bool = False,
-attention_mode: str = "sdpa"
+attention_mode: str = "sdpa",
+linear_func: nn.Module = nn.Linear
):
super().__init__()
if data_width is None:
data_width = width
@@ -288,12 +298,13 @@ class ResidualCrossAttentionBlock(nn.Module):
qkv_bias=qkv_bias,
norm_layer=norm_layer,
qk_norm=qk_norm,
-attention_mode=self.attention_mode
+attention_mode=self.attention_mode,
+linear_func=linear_func
)
self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
self.ln_2 = norm_layer(data_width, elementwise_affine=True, eps=1e-6)
self.ln_3 = norm_layer(width, elementwise_affine=True, eps=1e-6)
-self.mlp = MLP(width=width)
+self.mlp = MLP(width=width, linear_func=linear_func)
def forward(self, x: torch.Tensor, data: torch.Tensor):
x = x + self.attn(self.ln_1(x), self.ln_2(data))
@@ -346,14 +357,15 @@ class MultiheadAttention(nn.Module):
qkv_bias: bool,
norm_layer=nn.LayerNorm,
qk_norm: bool = False,
-drop_path_rate: float = 0.0
+drop_path_rate: float = 0.0,
+linear_func: nn.Module = nn.Linear
):
super().__init__()
self.n_ctx = n_ctx
self.width = width
self.heads = heads
-self.c_qkv = nn.Linear(width, width * 3, bias=qkv_bias)
-self.c_proj = nn.Linear(width, width)
+self.c_qkv = linear_func(width, width * 3, bias=qkv_bias)
+self.c_proj = linear_func(width, width)
self.attention = QKVMultiheadAttention(
heads=heads,
n_ctx=n_ctx,
@@ -381,6 +393,7 @@ class ResidualAttentionBlock(nn.Module):
norm_layer=nn.LayerNorm,
qk_norm: bool = False,
drop_path_rate: float = 0.0,
+linear_func: nn.Module = nn.Linear
):
super().__init__()
self.attn = MultiheadAttention(
@@ -390,10 +403,11 @@ class ResidualAttentionBlock(nn.Module):
qkv_bias=qkv_bias,
norm_layer=norm_layer,
qk_norm=qk_norm,
-drop_path_rate=drop_path_rate
+drop_path_rate=drop_path_rate,
+linear_func=linear_func
)
self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
-self.mlp = MLP(width=width, drop_path_rate=drop_path_rate)
+self.mlp = MLP(width=width, drop_path_rate=drop_path_rate, linear_func=linear_func)
self.ln_2 = norm_layer(width, elementwise_affine=True, eps=1e-6)
def forward(self, x: torch.Tensor):
@@ -413,9 +427,11 @@ class Transformer(nn.Module):
qkv_bias: bool = True,
norm_layer=nn.LayerNorm,
qk_norm: bool = False,
-drop_path_rate: float = 0.0
+drop_path_rate: float = 0.0,
+linear_func: nn.Module = nn.Linear
):
super().__init__()
self.n_ctx = n_ctx
self.width = width
self.layers = layers
@@ -428,7 +444,8 @@ class Transformer(nn.Module):
qkv_bias=qkv_bias,
norm_layer=norm_layer,
qk_norm=qk_norm,
-drop_path_rate=drop_path_rate
+drop_path_rate=drop_path_rate,
+linear_func=linear_func
)
for _ in range(layers)
]
@@ -453,13 +470,15 @@ class CrossAttentionDecoder(nn.Module):
qkv_bias: bool = True,
qk_norm: bool = False,
label_type: str = "binary",
attention_mode: str = "sdpa"
attention_mode: str = "sdpa",
linear_func: nn.Module = nn.Linear
):
super().__init__()
self.fourier_embedder = fourier_embedder
-self.query_proj = nn.Linear(self.fourier_embedder.out_dim, width)
+self.query_proj = linear_func(self.fourier_embedder.out_dim, width)
self.cross_attn_decoder = ResidualCrossAttentionBlock(
n_data=num_latents,
@@ -467,11 +486,12 @@ class CrossAttentionDecoder(nn.Module):
heads=heads,
qkv_bias=qkv_bias,
qk_norm=qk_norm,
-attention_mode=attention_mode
+attention_mode=attention_mode,
+linear_func=linear_func
)
self.ln_post = nn.LayerNorm(width)
-self.output_proj = nn.Linear(width, out_channels)
+self.output_proj = linear_func(width, out_channels)
self.label_type = label_type
def forward(self, queries: torch.FloatTensor, latents: torch.FloatTensor):
@@ -535,12 +555,16 @@ class ShapeVAE(nn.Module):
label_type: str = "binary",
drop_path_rate: float = 0.0,
scale_factor: float = 1.0,
attention_mode: str = "sdpa"
attention_mode: str = "sdpa",
cublas_ops: bool = False
):
super().__init__()
linear_func = CublasLinear if cublas_ops else nn.Linear
self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
self.post_kl = nn.Linear(embed_dim, width)
self.post_kl = linear_func(embed_dim, width)
self.transformer = Transformer(
n_ctx=num_latents,
@@ -549,7 +573,8 @@ class ShapeVAE(nn.Module):
heads=heads,
qkv_bias=qkv_bias,
qk_norm=qk_norm,
-drop_path_rate=drop_path_rate
+drop_path_rate=drop_path_rate,
+linear_func=linear_func
)
self.geo_decoder = CrossAttentionDecoder(
@@ -561,7 +586,8 @@ class ShapeVAE(nn.Module):
qkv_bias=qkv_bias,
qk_norm=qk_norm,
label_type=label_type,
-attention_mode=attention_mode
+attention_mode=attention_mode,
+linear_func=linear_func
)
self.scale_factor = scale_factor
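
Every nn.Linear on the VAE decode path is now constructed through linear_func, so the ~25% claim can be sanity-checked on a single layer of the MLP shape (width to 4*width). A rough micro-benchmark sketch, assuming CublasLinear is a drop-in nn.Linear subclass (same parameters, same forward signature) as the diff relies on; it needs a CUDA GPU, fp16 tensors, and the cublas_ops package, and the sizes are illustrative.

import time
import torch
import torch.nn as nn

try:
    from cublas_ops import CublasLinear
except ImportError:
    CublasLinear = None

@torch.no_grad()
def bench(layer, x, iters=100):
    # Warm up, then average the forward-pass time over `iters` runs.
    for _ in range(10):
        layer(x)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        layer(x)
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters

if torch.cuda.is_available() and CublasLinear is not None:
    width, tokens = 1024, 4096  # illustrative, not the model's actual sizes
    x = torch.randn(tokens, width, device="cuda", dtype=torch.float16)
    ref = nn.Linear(width, width * 4, bias=True).to("cuda", torch.float16)
    fast = CublasLinear(width, width * 4, bias=True).to("cuda", torch.float16)
    fast.load_state_dict(ref.state_dict())  # drop-in: identical weight/bias names
    print(f"nn.Linear:    {bench(ref, x) * 1e3:.3f} ms")
    print(f"CublasLinear: {bench(fast, x) * 1e3:.3f} ms")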


@@ -154,6 +154,7 @@ class Hunyuan3DDiTPipeline:
use_safetensors=None,
compile_args=None,
attention_mode="sdpa",
+cublas_ops=False,
**kwargs,
):
# load config
@@ -188,6 +189,9 @@ class Hunyuan3DDiTPipeline:
config['model']['params']['guidance_embed'] = True
config['model']['params']['attention_mode'] = attention_mode
config['vae']['params']['attention_mode'] = attention_mode
+if cublas_ops:
+config['vae']['params']['cublas_ops'] = True
with init_empty_weights():
model = instantiate_from_config(config['model'])
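
The pipeline change is pure plumbing: when cublas_ops is set, the flag is written into config['vae']['params'] so that instantiate_from_config hands it to ShapeVAE, which then picks the linear class. A self-contained mimic of that flow (TinyDecoder and build_from_config are stand-ins for the real ShapeVAE and instantiate_from_config):

import torch.nn as nn

try:
    from cublas_ops import CublasLinear
except ImportError:
    CublasLinear = None

class TinyDecoder(nn.Module):
    # Stand-in for ShapeVAE: chooses its linear class from the cublas_ops flag.
    def __init__(self, width: int = 64, cublas_ops: bool = False):
        super().__init__()
        # The real code uses `CublasLinear if cublas_ops else nn.Linear`;
        # the extra None check just keeps this sketch runnable without the package.
        linear_func = CublasLinear if (cublas_ops and CublasLinear is not None) else nn.Linear
        self.post_kl = linear_func(width, width)

def build_from_config(config: dict, cublas_ops: bool = False) -> TinyDecoder:
    # Mirror of the loader: inject the flag into the vae params, then
    # instantiate with those params.
    if cublas_ops:
        config['vae']['params']['cublas_ops'] = True
    return TinyDecoder(**config['vae']['params'])

vae = build_from_config({'vae': {'params': {'width': 64}}}, cublas_ops=True)
print(type(vae.post_kl).__name__)  # "CublasLinear" if installed, else "Linear"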


@@ -112,6 +112,7 @@ class Hy3DModelLoader:
"optional": {
"compile_args": ("HY3DCOMPILEARGS", {"tooltip": "torch.compile settings, when connected to the model loader, torch.compile of the selected models is attempted. Requires Triton and torch 2.5.0 is recommended"}),
"attention_mode": (["sdpa", "sageattn"], {"default": "sdpa"}),
+"cublas_ops": ("BOOLEAN", {"default": False, "tooltip": "Enable optimized cublas linear layers, speeds up decoding: https://github.com/aredden/torch-cublas-hgemm"}),
}
}
@@ -120,7 +121,7 @@ class Hy3DModelLoader:
FUNCTION = "loadmodel"
CATEGORY = "Hunyuan3DWrapper"
-def loadmodel(self, model, compile_args=None, attention_mode="sdpa"):
+def loadmodel(self, model, compile_args=None, attention_mode="sdpa", cublas_ops=False):
device = mm.get_torch_device()
offload_device=mm.unet_offload_device()
@@ -133,7 +134,8 @@ class Hy3DModelLoader:
device=device,
offload_device=offload_device,
compile_args=compile_args,
-attention_mode=attention_mode)
+attention_mode=attention_mode,
+cublas_ops=cublas_ops)
return (pipe, vae,)
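
In ComfyUI the new cublas_ops toggle sits in the optional inputs of Hy3DModelLoader next to attention_mode; it defaults to False, so existing workflows keep the previous nn.Linear behavior unless the toggle is enabled and torch-cublas-hgemm is installed.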