diff --git a/hy3dgen/shapegen/postprocessors.py b/hy3dgen/shapegen/postprocessors.py index 4e554a9..6864bc3 100755 --- a/hy3dgen/shapegen/postprocessors.py +++ b/hy3dgen/shapegen/postprocessors.py @@ -158,13 +158,13 @@ def import_mesh(mesh: Union[pymeshlab.MeshSet, trimesh.Trimesh, Latent2MeshOutpu return mesh -def bpt_remesh(self, mesh: trimesh.Trimesh, verbose: bool = False, with_normal: bool = True, temperature: float = 0.5, batch_size: int = 1): +def bpt_remesh(self, mesh: trimesh.Trimesh, verbose: bool = False, with_normal: bool = True, temperature: float = 0.5, batch_size: int = 1, pc_num: int = 4096): from .bpt.model import data_utils from .bpt.model.model import MeshTransformer from .bpt.model.serializaiton import BPT_deserialize from .bpt.utils import sample_pc, joint_filter - pc_normal = sample_pc(mesh, pc_num=8192, with_normal=with_normal) + pc_normal = sample_pc(mesh, pc_num=pc_num, with_normal=with_normal) pc_normal = pc_normal[None, :, :] if len(pc_normal.shape) == 2 else pc_normal @@ -181,9 +181,7 @@ def bpt_remesh(self, mesh: trimesh.Trimesh, verbose: bool = False, with_normal: model_path = os.path.join(comfyui_dir, 'bpt/bpt-8-16-500m.pt') print(model_path) model.load(model_path) - model = model.eval() - model = model.half() - model = model.cuda() + model = model.eval().cuda().half() import torch pc_tensor = torch.from_numpy(pc_normal).cuda().half() @@ -198,7 +196,7 @@ def bpt_remesh(self, mesh: trimesh.Trimesh, verbose: bool = False, with_normal: temperature=temperature, batch_size=batch_size ) - + coords = [] try: for i in range(len(codes)): @@ -214,8 +212,7 @@ def bpt_remesh(self, mesh: trimesh.Trimesh, verbose: bool = False, with_normal: except: coords.append(np.zeros(3, 3)) - # convert coordinates to mesh - vertices = coords[0] + vertices = coords[i] faces = torch.arange(1, len(vertices) + 1).view(-1, 3) # Move to CPU @@ -225,6 +222,7 @@ def bpt_remesh(self, mesh: trimesh.Trimesh, verbose: bool = False, with_normal: return data_utils.to_mesh(vertices, faces, transpose=False, post_process=True) + class BptMesh: def __call__( self, @@ -232,9 +230,10 @@ class BptMesh: temperature: float = 0.5, batch_size: int = 1, with_normal: bool = True, - verbose: bool = False + verbose: bool = False, + pc_num: int = 4096 ) -> Union[pymeshlab.MeshSet, trimesh.Trimesh]: - mesh = bpt_remesh(self, mesh=mesh, temperature=temperature, batch_size=batch_size, with_normal=with_normal) + mesh = bpt_remesh(self, mesh=mesh, temperature=temperature, batch_size=batch_size, with_normal=with_normal, pc_num=pc_num) return mesh class FaceReducer: