diff --git a/hy3dgen/shapegen/postprocessors.py b/hy3dgen/shapegen/postprocessors.py index cc5f75f..175dfc3 100755 --- a/hy3dgen/shapegen/postprocessors.py +++ b/hy3dgen/shapegen/postprocessors.py @@ -158,13 +158,13 @@ def import_mesh(mesh: Union[pymeshlab.MeshSet, trimesh.Trimesh, Latent2MeshOutpu return mesh -def bpt_remesh(self, mesh: trimesh.Trimesh, verbose: bool = False, max_seq_len:int=10000, cond_dim:int=768, pc_num: int=8192, with_normal: bool = True, kwarg_k: int = 50, kwarg_p: float = 0.95): +def bpt_remesh(self, mesh: trimesh.Trimesh, verbose: bool = False, with_normal: bool = True, temperature: float = 0.5, batch_size: int = 1): from .bpt.model import data_utils from .bpt.model.model import MeshTransformer from .bpt.model.serializaiton import BPT_deserialize from .bpt.utils import sample_pc, joint_filter - pc_normal = sample_pc(mesh, pc_num=pc_num, with_normal=with_normal) + pc_normal = sample_pc(mesh, with_normal=with_normal) pc_normal = pc_normal[None, :, :] if len(pc_normal.shape) == 2 else pc_normal @@ -175,7 +175,7 @@ def bpt_remesh(self, mesh: trimesh.Trimesh, verbose: bool = False, max_seq_len:i add_safe_globals([LossScaler, fragment_address, ZeroStageEnum]) - model = MeshTransformer(cond_dim=cond_dim, max_seq_len=max_seq_len) + model = MeshTransformer() comfyui_dir = os.path.dirname(os.path.abspath(__file__)) model_path = os.path.join(comfyui_dir, 'bpt/bpt-8-16-500m.pt') @@ -195,6 +195,8 @@ def bpt_remesh(self, mesh: trimesh.Trimesh, verbose: bool = False, max_seq_len:i filter_logits_fn=joint_filter, filter_kwargs=dict(k=50, p=0.95), return_codes=True, + temperature=temperature, + batch_size=batch_size ) coords = [] @@ -227,13 +229,12 @@ class BptMesh: def __call__( self, mesh: Union[pymeshlab.MeshSet, trimesh.Trimesh, Latent2MeshOutput, str], - max_seq_len: int = 10000, - cond_dim: int = 768, - kwarg_k: int = 50, - kwarg_p: float = 0.95, + temperature: float = 0.5, + batch_size: int = 1, + with_normal: bool = True, verbose: bool = False ) -> Union[pymeshlab.MeshSet, trimesh.Trimesh]: - mesh = bpt_remesh(self, mesh=mesh, cond_dim=cond_dim, max_seq_len=max_seq_len, kwarg_k=kwarg_k, kwarg_p=kwarg_p) + mesh = bpt_remesh(self, mesh=mesh, temperature=temperature, batch_size=batch_size, with_normal=with_normal) return mesh class FaceReducer: