mirror of https://git.datalinker.icu/kijai/ComfyUI-Hunyuan3DWrapper.git (synced 2025-12-16 08:14:25 +08:00)
91 lines · 2.6 KiB · Python
import os

import torch
from torch import nn
from beartype import beartype

from ..miche.encode import load_model
from ..miche.michelangelo.models.tsal import asl_pl_module

# helper functions

def exists(val):
    return val is not None

def default(*values):
    for value in values:
        if exists(value):
            return value
    return None
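# e.g. default(None, None, 1024) -> 1024 and default(None) -> None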

# point-cloud encoder from Michelangelo
@beartype
class PointConditioner(torch.nn.Module):
    def __init__(
        self,
        *,
        dim_latent = None,
        model_name = 'miche-256-feature',
        cond_dim = 768,
        freeze = True,
    ):
        super().__init__()

        # open-source version of miche
        if model_name == 'miche-256-feature':
            base_dir = os.path.dirname(os.path.abspath(__file__))
            # no checkpoint by default; point ckpt_path at shapevae-256.ckpt to load weights
            ckpt_path = None  # os.path.join(base_dir, '..', 'shapevae-256.ckpt')
            config_path = os.path.join(base_dir, '..', 'shapevae-256.yaml')

            self.feature_dim = 1024  # embedding dimension
            self.cond_length = 257   # length of embedding
            self.point_encoder = load_model(ckpt_path=ckpt_path, config_path=config_path)

            # additional layers to connect miche and GPT
            self.cond_head_proj = nn.Linear(cond_dim, self.feature_dim)
            self.cond_proj = nn.Linear(cond_dim, self.feature_dim)

        else:
            raise NotImplementedError(f'unknown model name: {model_name}')

        # whether to fine-tune the point-cloud encoder
        if freeze:
            for parameter in self.point_encoder.parameters():
                parameter.requires_grad = False

        self.freeze = freeze
        self.model_name = model_name
        self.dim_latent = default(dim_latent, self.feature_dim)

        # dummy buffer so the device property below always has a buffer to inspect
        self.register_buffer('_device_param', torch.tensor(0.), persistent = False)

    @property
    def device(self):
        return next(self.buffers()).device

    def embed_pc(self, pc_normal):
        # encode point cloud to embeddings
        if self.model_name == 'miche-256-feature':
            point_feature = self.point_encoder.encode_latents(pc_normal)
            # project the leading (head) token and the remaining tokens
            # separately, then re-concatenate along the sequence dimension
            pc_embed_head = self.cond_head_proj(point_feature[:, 0:1])
            pc_embed = self.cond_proj(point_feature[:, 1:])
            pc_embed = torch.cat([pc_embed_head, pc_embed], dim=1)

        return pc_embed

    def forward(
        self,
        pc = None,
        pc_embeds = None,
    ):
        if pc_embeds is None:
            pc_embeds = self.embed_pc(pc.to(next(self.buffers()).dtype))

        assert not torch.any(torch.isnan(pc_embeds)), 'NaN values in pc embeddings'

        return pc_embeds
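
# Minimal usage sketch (added; not part of the original module). It assumes
# shapevae-256.yaml sits one directory above this file, as load_model in
# __init__ expects, and that pc_normal is a (batch, num_points, 6) tensor of
# xyz coordinates plus normals, the layout Michelangelo's encode_latents
# consumes. Run as a module (python -m ...) so the relative imports resolve.
if __name__ == '__main__':
    conditioner = PointConditioner().eval()
    pc_normal = torch.rand(1, 4096, 6)  # hypothetical random point cloud, shape-check only
    with torch.no_grad():
        pc_embeds = conditioner(pc=pc_normal)
    # expected shape: (1, cond_length, feature_dim) = (1, 257, 1024)
    print(pc_embeds.shape)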