mirror of
https://git.datalinker.icu/kijai/ComfyUI-Hunyuan3DWrapper.git
synced 2026-05-18 17:37:17 +08:00
added bpt
This commit is contained in:
parent
2a742b0617
commit
a9c6b5cbf0
@ -158,6 +158,83 @@ def import_mesh(mesh: Union[pymeshlab.MeshSet, trimesh.Trimesh, Latent2MeshOutpu
|
|||||||
|
|
||||||
return mesh
|
return mesh
|
||||||
|
|
||||||
|
def bpt_remesh(self, mesh: trimesh.Trimesh, verbose: bool = False, max_seq_len:int=10000, cond_dim:int=768, pc_num: int=8192, with_normal: bool = True, kwarg_k: int = 50, kwarg_p: float = 0.95):
|
||||||
|
from .bpt.model import data_utils
|
||||||
|
from .bpt.model.model import MeshTransformer
|
||||||
|
from .bpt.model.serializaiton import BPT_deserialize
|
||||||
|
from .bpt.utils import sample_pc, joint_filter
|
||||||
|
|
||||||
|
pc_normal = sample_pc(mesh, pc_num=pc_num, with_normal=with_normal)
|
||||||
|
|
||||||
|
pc_normal = pc_normal[None, :, :] if len(pc_normal.shape) == 2 else pc_normal
|
||||||
|
|
||||||
|
from torch.serialization import add_safe_globals
|
||||||
|
from deepspeed.runtime.fp16.loss_scaler import LossScaler
|
||||||
|
from deepspeed.runtime.zero.config import ZeroStageEnum
|
||||||
|
from deepspeed.utils.tensor_fragment import fragment_address
|
||||||
|
|
||||||
|
add_safe_globals([LossScaler, fragment_address, ZeroStageEnum])
|
||||||
|
|
||||||
|
model = MeshTransformer(cond_dim=cond_dim, max_seq_len=max_seq_len)
|
||||||
|
|
||||||
|
comfyui_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
model_path = os.path.join(comfyui_dir, 'bpt/bpt-8-16-500m.pt')
|
||||||
|
print(model_path)
|
||||||
|
model.load(model_path)
|
||||||
|
model = model.eval()
|
||||||
|
model = model.half()
|
||||||
|
model = model.cuda()
|
||||||
|
|
||||||
|
import torch
|
||||||
|
pc_tensor = torch.from_numpy(pc_normal).cuda().half()
|
||||||
|
if len(pc_tensor.shape) == 2:
|
||||||
|
pc_tensor = pc_tensor.unsqueeze(0)
|
||||||
|
|
||||||
|
codes = model.generate(
|
||||||
|
pc=pc_tensor,
|
||||||
|
filter_logits_fn=joint_filter,
|
||||||
|
filter_kwargs=dict(k=50, p=0.95),
|
||||||
|
return_codes=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
coords = []
|
||||||
|
try:
|
||||||
|
for i in range(len(codes)):
|
||||||
|
code = codes[i]
|
||||||
|
code = code[code != model.pad_id].cpu().numpy()
|
||||||
|
vertices = BPT_deserialize(
|
||||||
|
code,
|
||||||
|
block_size=model.block_size,
|
||||||
|
offset_size=model.offset_size,
|
||||||
|
use_special_block=model.use_special_block,
|
||||||
|
)
|
||||||
|
coords.append(vertices)
|
||||||
|
except:
|
||||||
|
coords.append(np.zeros(3, 3))
|
||||||
|
|
||||||
|
# convert coordinates to mesh
|
||||||
|
vertices = coords[0]
|
||||||
|
faces = torch.arange(1, len(vertices) + 1).view(-1, 3)
|
||||||
|
|
||||||
|
# Move to CPU
|
||||||
|
faces = faces.cpu().numpy()
|
||||||
|
|
||||||
|
del model
|
||||||
|
|
||||||
|
return data_utils.to_mesh(vertices, faces, transpose=False, post_process=True)
|
||||||
|
|
||||||
|
class BptMesh:
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
mesh: Union[pymeshlab.MeshSet, trimesh.Trimesh, Latent2MeshOutput, str],
|
||||||
|
max_seq_len: int = 10000,
|
||||||
|
cond_dim: int = 768,
|
||||||
|
kwarg_k: int = 50,
|
||||||
|
kwarg_p: float = 0.95,
|
||||||
|
verbose: bool = False
|
||||||
|
) -> Union[pymeshlab.MeshSet, trimesh.Trimesh]:
|
||||||
|
mesh = bpt_remesh(self, mesh=mesh, cond_dim=cond_dim, max_seq_len=max_seq_len, kwarg_k=kwarg_k, kwarg_p=kwarg_p)
|
||||||
|
return mesh
|
||||||
|
|
||||||
class FaceReducer:
|
class FaceReducer:
|
||||||
def __call__(
|
def __call__(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user