Updated BPT code

This commit is contained in:
Easymode 2025-02-19 17:18:06 +00:00 committed by GitHub
parent d34f251c36
commit 2ff95dbd71
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -158,13 +158,13 @@ def import_mesh(mesh: Union[pymeshlab.MeshSet, trimesh.Trimesh, Latent2MeshOutpu
return mesh
def bpt_remesh(self, mesh: trimesh.Trimesh, verbose: bool = False, with_normal: bool = True, temperature: float = 0.5, batch_size: int = 1):
def bpt_remesh(self, mesh: trimesh.Trimesh, verbose: bool = False, with_normal: bool = True, temperature: float = 0.5, batch_size: int = 1, pc_num: int = 4096):
from .bpt.model import data_utils
from .bpt.model.model import MeshTransformer
from .bpt.model.serializaiton import BPT_deserialize
from .bpt.utils import sample_pc, joint_filter
pc_normal = sample_pc(mesh, pc_num=8192, with_normal=with_normal)
pc_normal = sample_pc(mesh, pc_num=pc_num, with_normal=with_normal)
pc_normal = pc_normal[None, :, :] if len(pc_normal.shape) == 2 else pc_normal
@ -181,9 +181,7 @@ def bpt_remesh(self, mesh: trimesh.Trimesh, verbose: bool = False, with_normal:
model_path = os.path.join(comfyui_dir, 'bpt/bpt-8-16-500m.pt')
print(model_path)
model.load(model_path)
model = model.eval()
model = model.half()
model = model.cuda()
model = model.eval().cuda().half()
import torch
pc_tensor = torch.from_numpy(pc_normal).cuda().half()
@ -198,7 +196,7 @@ def bpt_remesh(self, mesh: trimesh.Trimesh, verbose: bool = False, with_normal:
temperature=temperature,
batch_size=batch_size
)
coords = []
try:
for i in range(len(codes)):
@ -214,8 +212,7 @@ def bpt_remesh(self, mesh: trimesh.Trimesh, verbose: bool = False, with_normal:
except:
coords.append(np.zeros(3, 3))
# convert coordinates to mesh
vertices = coords[0]
vertices = coords[i]
faces = torch.arange(1, len(vertices) + 1).view(-1, 3)
# Move to CPU
@ -225,6 +222,7 @@ def bpt_remesh(self, mesh: trimesh.Trimesh, verbose: bool = False, with_normal:
return data_utils.to_mesh(vertices, faces, transpose=False, post_process=True)
class BptMesh:
def __call__(
self,
@ -232,9 +230,10 @@ class BptMesh:
temperature: float = 0.5,
batch_size: int = 1,
with_normal: bool = True,
verbose: bool = False
verbose: bool = False,
pc_num: int = 4096
) -> Union[pymeshlab.MeshSet, trimesh.Trimesh]:
mesh = bpt_remesh(self, mesh=mesh, temperature=temperature, batch_size=batch_size, with_normal=with_normal)
mesh = bpt_remesh(self, mesh=mesh, temperature=temperature, batch_size=batch_size, with_normal=with_normal, pc_num=pc_num)
return mesh
class FaceReducer: