From a9c6b5cbf06e3e128e6705d3a502e8e684264059 Mon Sep 17 00:00:00 2001
From: Easymode <76738305+Easymode-ai@users.noreply.github.com>
Date: Wed, 19 Feb 2025 00:35:59 +0000
Subject: [PATCH] Add BPT remeshing postprocessor (BptMesh)

---
 hy3dgen/shapegen/postprocessors.py | 77 ++++++++++++++++++++++++++++++
 1 file changed, 77 insertions(+)

diff --git a/hy3dgen/shapegen/postprocessors.py b/hy3dgen/shapegen/postprocessors.py
index f392a9e..cc5f75f 100755
--- a/hy3dgen/shapegen/postprocessors.py
+++ b/hy3dgen/shapegen/postprocessors.py
@@ -158,6 +158,83 @@ def import_mesh(mesh: Union[pymeshlab.MeshSet, trimesh.Trimesh, Latent2MeshOutpu
     return mesh
 
 
+def bpt_remesh(mesh: trimesh.Trimesh, verbose: bool = False, max_seq_len: int = 10000, cond_dim: int = 768, pc_num: int = 8192, with_normal: bool = True, kwarg_k: int = 50, kwarg_p: float = 0.95) -> trimesh.Trimesh:
+    import torch
+    from .bpt.model import data_utils
+    from .bpt.model.model import MeshTransformer
+    from .bpt.model.serializaiton import BPT_deserialize  # module filename spelled this way in the vendored bpt code
+    from .bpt.utils import sample_pc, joint_filter
+
+    # Sample a point cloud (optionally with normals) from the input mesh.
+    pc_normal = sample_pc(mesh, pc_num=pc_num, with_normal=with_normal)
+
+    pc_normal = pc_normal[None, :, :] if len(pc_normal.shape) == 2 else pc_normal
+
+    # Allowlist the deepspeed objects stored in the checkpoint so it can be
+    # deserialized by torch.load with weights_only=True.
+    from torch.serialization import add_safe_globals
+    from deepspeed.runtime.fp16.loss_scaler import LossScaler
+    from deepspeed.runtime.zero.config import ZeroStageEnum
+    from deepspeed.utils.tensor_fragment import fragment_address
+
+    add_safe_globals([LossScaler, fragment_address, ZeroStageEnum])
+
+    model = MeshTransformer(cond_dim=cond_dim, max_seq_len=max_seq_len)
+
+    module_dir = os.path.dirname(os.path.abspath(__file__))
+    model_path = os.path.join(module_dir, 'bpt/bpt-8-16-500m.pt')
+    if verbose:
+        print(model_path)
+    model.load(model_path)
+    model = model.eval().half().cuda()
+
+    pc_tensor = torch.from_numpy(pc_normal).cuda().half()
+    if len(pc_tensor.shape) == 2:
+        pc_tensor = pc_tensor.unsqueeze(0)
+
+    codes = model.generate(
+        pc=pc_tensor,
+        filter_logits_fn=joint_filter,
+        filter_kwargs=dict(k=kwarg_k, p=kwarg_p),
+        return_codes=True,
+    )
+
+    # Decode the generated token sequences back into vertex coordinates.
+    coords = []
+    try:
+        for code in codes:
+            code = code[code != model.pad_id].cpu().numpy()
+            vertices = BPT_deserialize(
+                code,
+                block_size=model.block_size,
+                offset_size=model.offset_size,
+                use_special_block=model.use_special_block,
+            )
+            coords.append(vertices)
+    except Exception:
+        coords.append(np.zeros((3, 3)))
+
+    # Convert the coordinates to a mesh: every three consecutive vertices form one face.
+    vertices = coords[0]
+    faces = torch.arange(1, len(vertices) + 1).view(-1, 3).cpu().numpy()
+
+    del model
+
+    return data_utils.to_mesh(vertices, faces, transpose=False, post_process=True)
+
+class BptMesh:
+    def __call__(
+        self,
+        mesh: trimesh.Trimesh,
+        max_seq_len: int = 10000,
+        cond_dim: int = 768,
+        pc_num: int = 8192,
+        with_normal: bool = True,
+        kwarg_k: int = 50,
+        kwarg_p: float = 0.95,
+        verbose: bool = False,
+    ) -> trimesh.Trimesh:
+        return bpt_remesh(mesh=mesh, max_seq_len=max_seq_len, cond_dim=cond_dim, pc_num=pc_num, with_normal=with_normal, kwarg_k=kwarg_k, kwarg_p=kwarg_p, verbose=verbose)
 
 class FaceReducer:
     def __call__(
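For reference, a minimal usage sketch of the new postprocessor (not part of the patch; assumes the bpt weights are present at hy3dgen/shapegen/bpt/bpt-8-16-500m.pt, a CUDA device is available, and `mesh` is a trimesh.Trimesh produced by the shape pipeline):

    import trimesh
    from hy3dgen.shapegen.postprocessors import BptMesh

    mesh = trimesh.load('generated_shape.glb', force='mesh')  # hypothetical input file
    remesher = BptMesh()
    # Regenerate the mesh with the BPT mesh transformer; kwarg_k / kwarg_p control
    # the top-k / top-p joint filtering used during token generation.
    remeshed = remesher(mesh, pc_num=8192, with_normal=True, kwarg_k=50, kwarg_p=0.95)
    remeshed.export('bpt_remeshed.glb')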