Merge pull request #71 from Easymode-ai/main

BPT node, seeded determinism, node sample count
This commit is contained in:
Jukka Seppänen 2025-02-20 12:36:06 +02:00 committed by GitHub
commit d72f2e9f3f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
22 changed files with 50 additions and 22 deletions

View File

@ -6,9 +6,12 @@ import torch
import torch.distributed as dist
import sys
sys.path.append(r"C:\Remade\ComfyUI_windows_portable\ComfyUI\custom_nodes\ComfyUI-Hunyuan3DWrapper-main")
import os
custom_node_path = os.path.dirname(os.path.abspath(__file__))
custom_node_path = os.path.abspath(os.path.join(custom_node_path, "..", "..", "..", "..", "..",".."))
sys.path.append(custom_node_path)
from hy3dgen.shapegen.bpt.miche.michelangelo.models.tsal import asl_pl_module
#from hy3dgen.shapegen.bpt.miche.michelangelo.models.tsal import asl_pl_module
def get_obj_from_str(string, reload=False):
module, cls = string.rsplit(".", 1)
@ -26,6 +29,10 @@ def get_obj_from_config(config):
def instantiate_from_config(config, **kwargs):
print(" custom path :")
print(custom_node_path)
print("\n")
if "target" not in config:
raise KeyError("Expected key `target` to instantiate.")

View File

@ -18,6 +18,7 @@ from .miche_conditioner import PointConditioner
from functools import partial
from tqdm import tqdm
from .data_utils import discretize
from comfy.utils import ProgressBar
# helper functions
@ -182,6 +183,10 @@ class MeshTransformer(Module):
cache = None
eos_iter = None
# ✅ Initialize ComfyUI progress bar
pbar = ProgressBar(max_seq_len - curr_length)
# predict tokens auto-regressively
for i in tqdm(range(curr_length, max_seq_len), position=tqdm_position,
desc=f'Process: {tqdm_position}', dynamic_ncols=True, leave=False):
@ -207,6 +212,9 @@ class MeshTransformer(Module):
sample = torch.multinomial(probs, 1)
codes, _ = pack([codes, sample], 'b *')
# ComfyUI progress bar
pbar.update(1)
# Check if all sequences have encountered EOS at least once
is_eos_codes = (codes == self.eos_token_id)
if is_eos_codes.any(dim=-1).all():
@ -216,6 +224,9 @@ class MeshTransformer(Module):
# Once we've generated 20% more tokens than eos_iter, break out of the loop
if codes.shape[-1] >= int(eos_iter * 1.2):
break
# Ensure progress bar reaches 100% when loop completes
#pbar.complete()
# mask out to padding anything after the first eos

View File

@ -9,7 +9,7 @@ pytorch-warmup
torch_geometric
torchtyping
vector-quantize-pytorch>=1.12.8
x-transformers>=1.26.6
x-transformers==1.26.6
tqdm
matplotlib
wandb
@ -26,4 +26,7 @@ setuptools
pytorch_lightning
mesh2sdf
numpy
point-cloud-utils
point-cloud-utils
transformers==4.48.0
networkx==3.4.2
deepspeed==0.16.3

View File

@ -66,20 +66,23 @@ def apply_normalize(mesh):
def sample_pc(trimesh, pc_num, with_normal=False):
mesh = apply_normalize(trimesh)
def sample_pc(mesh, pc_num, with_normal=False, seed=1234, samples=50000):
mesh = apply_normalize(mesh)
if not with_normal:
points, _ = mesh.sample(pc_num, return_index=True)
return points
points, face_idx = mesh.sample(50000, return_index=True)
points, face_idx = trimesh.sample.sample_surface(mesh=mesh, count=samples, seed=seed)
#points, face_idx = mesh.sample(50000, return_index=True)
normals = mesh.face_normals[face_idx]
pc_normal = np.concatenate([points, normals], axis=-1, dtype=np.float16)
# random sample point cloud
np.random.seed(seed)
ind = np.random.choice(pc_normal.shape[0], pc_num, replace=False)
pc_normal = pc_normal[ind]
return pc_normal

View File

@ -158,13 +158,13 @@ def import_mesh(mesh: Union[pymeshlab.MeshSet, trimesh.Trimesh, Latent2MeshOutpu
return mesh
def bpt_remesh(self, mesh: trimesh.Trimesh, verbose: bool = False, with_normal: bool = True, temperature: float = 0.5, batch_size: int = 1):
def bpt_remesh(self, mesh: trimesh.Trimesh, verbose: bool = False, with_normal: bool = True, temperature: float = 0.5, batch_size: int = 1, pc_num: int = 4096, seed: int = 1234, samples: int = 50000):
from .bpt.model import data_utils
from .bpt.model.model import MeshTransformer
from .bpt.model.serializaiton import BPT_deserialize
from .bpt.utils import sample_pc, joint_filter
pc_normal = sample_pc(mesh, pc_num=8192, with_normal=with_normal)
pc_normal = sample_pc(mesh, pc_num=pc_num, with_normal=with_normal, seed=seed, samples=samples)
pc_normal = pc_normal[None, :, :] if len(pc_normal.shape) == 2 else pc_normal
@ -181,9 +181,7 @@ def bpt_remesh(self, mesh: trimesh.Trimesh, verbose: bool = False, with_normal:
model_path = os.path.join(comfyui_dir, 'bpt/bpt-8-16-500m.pt')
print(model_path)
model.load(model_path)
model = model.eval()
model = model.half()
model = model.cuda()
model = model.eval().cuda().half()
import torch
pc_tensor = torch.from_numpy(pc_normal).cuda().half()
@ -196,9 +194,9 @@ def bpt_remesh(self, mesh: trimesh.Trimesh, verbose: bool = False, with_normal:
filter_kwargs=dict(k=50, p=0.95),
return_codes=True,
temperature=temperature,
batch_size=batch_size
batch_size=batch_size,
)
coords = []
try:
for i in range(len(codes)):
@ -214,8 +212,7 @@ def bpt_remesh(self, mesh: trimesh.Trimesh, verbose: bool = False, with_normal:
except:
coords.append(np.zeros(3, 3))
# convert coordinates to mesh
vertices = coords[0]
vertices = coords[i]
faces = torch.arange(1, len(vertices) + 1).view(-1, 3)
# Move to CPU
@ -225,6 +222,7 @@ def bpt_remesh(self, mesh: trimesh.Trimesh, verbose: bool = False, with_normal:
return data_utils.to_mesh(vertices, faces, transpose=False, post_process=True)
class BptMesh:
def __call__(
self,
@ -232,9 +230,12 @@ class BptMesh:
temperature: float = 0.5,
batch_size: int = 1,
with_normal: bool = True,
verbose: bool = False
verbose: bool = False,
pc_num: int = 4096,
seed: int = 1234,
samples: int = 50000
) -> Union[pymeshlab.MeshSet, trimesh.Trimesh]:
mesh = bpt_remesh(self, mesh=mesh, temperature=temperature, batch_size=batch_size, with_normal=with_normal)
mesh = bpt_remesh(self, mesh=mesh, temperature=temperature, batch_size=batch_size, with_normal=with_normal, pc_num=pc_num, seed=seed, samples=samples)
return mesh
class FaceReducer:

View File

@ -1279,8 +1279,10 @@ class Hy3DBPT:
"required": {
"trimesh": ("TRIMESH",),
"enable_bpt": ("BOOLEAN", {"default": True}),
"seed": ("INT", {"default": 42, "min": 0, "max": 0xffffffffffffffff}),
"temperature": ("FLOAT", {"default": 0.5}),
"batch_size": ("INT", {"default": 1}),
"pc_num": ("INT", {"default": 4096, "min": 1024, "max": 8192, "step": 1024}),
"samples": ("INT", {"default": 100000})
},
}
@ -1290,12 +1292,13 @@ class Hy3DBPT:
CATEGORY = "Hunyuan3DWrapper"
DESCRIPTION = "BPT the mesh using bpt: https://github.com/whaohan/bpt"
def bpt(self, trimesh, enable_bpt, temperature, batch_size):
def bpt(self, trimesh, enable_bpt, temperature, pc_num, seed, samples):
mm.unload_all_models()
mm.soft_empty_cache()
new_mesh = trimesh.copy()
if enable_bpt:
from .hy3dgen.shapegen.postprocessors import BptMesh
new_mesh = BptMesh()(new_mesh, with_normal=True, temperature=temperature, batch_size=batch_size)
new_mesh = BptMesh()(new_mesh, with_normal=True, temperature=temperature, batch_size=1, pc_num=pc_num, verbose=False, seed=seed, samples=samples)
mm.unload_all_models()
mm.soft_empty_cache()