mirror of
https://git.datalinker.icu/kijai/ComfyUI-Hunyuan3DWrapper.git
synced 2025-12-08 20:34:28 +08:00
Merge pull request #71 from Easymode-ai/main
BPT node, seeded determinism, node sample count
This commit is contained in:
commit
d72f2e9f3f
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -6,9 +6,12 @@ import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
import sys
|
||||
sys.path.append(r"C:\Remade\ComfyUI_windows_portable\ComfyUI\custom_nodes\ComfyUI-Hunyuan3DWrapper-main")
|
||||
import os
|
||||
custom_node_path = os.path.dirname(os.path.abspath(__file__))
|
||||
custom_node_path = os.path.abspath(os.path.join(custom_node_path, "..", "..", "..", "..", "..",".."))
|
||||
sys.path.append(custom_node_path)
|
||||
|
||||
from hy3dgen.shapegen.bpt.miche.michelangelo.models.tsal import asl_pl_module
|
||||
#from hy3dgen.shapegen.bpt.miche.michelangelo.models.tsal import asl_pl_module
|
||||
|
||||
def get_obj_from_str(string, reload=False):
|
||||
module, cls = string.rsplit(".", 1)
|
||||
@ -26,6 +29,10 @@ def get_obj_from_config(config):
|
||||
|
||||
|
||||
def instantiate_from_config(config, **kwargs):
|
||||
|
||||
print(" custom path :")
|
||||
print(custom_node_path)
|
||||
print("\n")
|
||||
if "target" not in config:
|
||||
raise KeyError("Expected key `target` to instantiate.")
|
||||
|
||||
|
||||
@ -18,6 +18,7 @@ from .miche_conditioner import PointConditioner
|
||||
from functools import partial
|
||||
from tqdm import tqdm
|
||||
from .data_utils import discretize
|
||||
from comfy.utils import ProgressBar
|
||||
|
||||
# helper functions
|
||||
|
||||
@ -182,6 +183,10 @@ class MeshTransformer(Module):
|
||||
|
||||
cache = None
|
||||
eos_iter = None
|
||||
|
||||
# ✅ Initialize ComfyUI progress bar
|
||||
pbar = ProgressBar(max_seq_len - curr_length)
|
||||
|
||||
# predict tokens auto-regressively
|
||||
for i in tqdm(range(curr_length, max_seq_len), position=tqdm_position,
|
||||
desc=f'Process: {tqdm_position}', dynamic_ncols=True, leave=False):
|
||||
@ -207,6 +212,9 @@ class MeshTransformer(Module):
|
||||
sample = torch.multinomial(probs, 1)
|
||||
codes, _ = pack([codes, sample], 'b *')
|
||||
|
||||
# ComfyUI progress bar
|
||||
pbar.update(1)
|
||||
|
||||
# Check if all sequences have encountered EOS at least once
|
||||
is_eos_codes = (codes == self.eos_token_id)
|
||||
if is_eos_codes.any(dim=-1).all():
|
||||
@ -216,6 +224,9 @@ class MeshTransformer(Module):
|
||||
# Once we've generated 20% more tokens than eos_iter, break out of the loop
|
||||
if codes.shape[-1] >= int(eos_iter * 1.2):
|
||||
break
|
||||
|
||||
# Ensure progress bar reaches 100% when loop completes
|
||||
#pbar.complete()
|
||||
|
||||
# mask out to padding anything after the first eos
|
||||
|
||||
|
||||
@ -9,7 +9,7 @@ pytorch-warmup
|
||||
torch_geometric
|
||||
torchtyping
|
||||
vector-quantize-pytorch>=1.12.8
|
||||
x-transformers>=1.26.6
|
||||
x-transformers==1.26.6
|
||||
tqdm
|
||||
matplotlib
|
||||
wandb
|
||||
@ -26,4 +26,7 @@ setuptools
|
||||
pytorch_lightning
|
||||
mesh2sdf
|
||||
numpy
|
||||
point-cloud-utils
|
||||
point-cloud-utils
|
||||
transformers==4.48.0
|
||||
networkx==3.4.2
|
||||
deepspeed==0.16.3
|
||||
|
||||
@ -66,20 +66,23 @@ def apply_normalize(mesh):
|
||||
|
||||
|
||||
|
||||
def sample_pc(trimesh, pc_num, with_normal=False):
|
||||
mesh = apply_normalize(trimesh)
|
||||
def sample_pc(mesh, pc_num, with_normal=False, seed=1234, samples=50000):
|
||||
mesh = apply_normalize(mesh)
|
||||
|
||||
if not with_normal:
|
||||
points, _ = mesh.sample(pc_num, return_index=True)
|
||||
return points
|
||||
|
||||
points, face_idx = mesh.sample(50000, return_index=True)
|
||||
points, face_idx = trimesh.sample.sample_surface(mesh=mesh, count=samples, seed=seed)
|
||||
#points, face_idx = mesh.sample(50000, return_index=True)
|
||||
normals = mesh.face_normals[face_idx]
|
||||
pc_normal = np.concatenate([points, normals], axis=-1, dtype=np.float16)
|
||||
|
||||
# random sample point cloud
|
||||
np.random.seed(seed)
|
||||
ind = np.random.choice(pc_normal.shape[0], pc_num, replace=False)
|
||||
pc_normal = pc_normal[ind]
|
||||
|
||||
|
||||
return pc_normal
|
||||
|
||||
|
||||
@ -158,13 +158,13 @@ def import_mesh(mesh: Union[pymeshlab.MeshSet, trimesh.Trimesh, Latent2MeshOutpu
|
||||
|
||||
return mesh
|
||||
|
||||
def bpt_remesh(self, mesh: trimesh.Trimesh, verbose: bool = False, with_normal: bool = True, temperature: float = 0.5, batch_size: int = 1):
|
||||
def bpt_remesh(self, mesh: trimesh.Trimesh, verbose: bool = False, with_normal: bool = True, temperature: float = 0.5, batch_size: int = 1, pc_num: int = 4096, seed: int = 1234, samples: int = 50000):
|
||||
from .bpt.model import data_utils
|
||||
from .bpt.model.model import MeshTransformer
|
||||
from .bpt.model.serializaiton import BPT_deserialize
|
||||
from .bpt.utils import sample_pc, joint_filter
|
||||
|
||||
pc_normal = sample_pc(mesh, pc_num=8192, with_normal=with_normal)
|
||||
pc_normal = sample_pc(mesh, pc_num=pc_num, with_normal=with_normal, seed=seed, samples=samples)
|
||||
|
||||
pc_normal = pc_normal[None, :, :] if len(pc_normal.shape) == 2 else pc_normal
|
||||
|
||||
@ -181,9 +181,7 @@ def bpt_remesh(self, mesh: trimesh.Trimesh, verbose: bool = False, with_normal:
|
||||
model_path = os.path.join(comfyui_dir, 'bpt/bpt-8-16-500m.pt')
|
||||
print(model_path)
|
||||
model.load(model_path)
|
||||
model = model.eval()
|
||||
model = model.half()
|
||||
model = model.cuda()
|
||||
model = model.eval().cuda().half()
|
||||
|
||||
import torch
|
||||
pc_tensor = torch.from_numpy(pc_normal).cuda().half()
|
||||
@ -196,9 +194,9 @@ def bpt_remesh(self, mesh: trimesh.Trimesh, verbose: bool = False, with_normal:
|
||||
filter_kwargs=dict(k=50, p=0.95),
|
||||
return_codes=True,
|
||||
temperature=temperature,
|
||||
batch_size=batch_size
|
||||
batch_size=batch_size,
|
||||
)
|
||||
|
||||
|
||||
coords = []
|
||||
try:
|
||||
for i in range(len(codes)):
|
||||
@ -214,8 +212,7 @@ def bpt_remesh(self, mesh: trimesh.Trimesh, verbose: bool = False, with_normal:
|
||||
except:
|
||||
coords.append(np.zeros(3, 3))
|
||||
|
||||
# convert coordinates to mesh
|
||||
vertices = coords[0]
|
||||
vertices = coords[i]
|
||||
faces = torch.arange(1, len(vertices) + 1).view(-1, 3)
|
||||
|
||||
# Move to CPU
|
||||
@ -225,6 +222,7 @@ def bpt_remesh(self, mesh: trimesh.Trimesh, verbose: bool = False, with_normal:
|
||||
|
||||
return data_utils.to_mesh(vertices, faces, transpose=False, post_process=True)
|
||||
|
||||
|
||||
class BptMesh:
|
||||
def __call__(
|
||||
self,
|
||||
@ -232,9 +230,12 @@ class BptMesh:
|
||||
temperature: float = 0.5,
|
||||
batch_size: int = 1,
|
||||
with_normal: bool = True,
|
||||
verbose: bool = False
|
||||
verbose: bool = False,
|
||||
pc_num: int = 4096,
|
||||
seed: int = 1234,
|
||||
samples: int = 50000
|
||||
) -> Union[pymeshlab.MeshSet, trimesh.Trimesh]:
|
||||
mesh = bpt_remesh(self, mesh=mesh, temperature=temperature, batch_size=batch_size, with_normal=with_normal)
|
||||
mesh = bpt_remesh(self, mesh=mesh, temperature=temperature, batch_size=batch_size, with_normal=with_normal, pc_num=pc_num, seed=seed, samples=samples)
|
||||
return mesh
|
||||
|
||||
class FaceReducer:
|
||||
|
||||
11
nodes.py
11
nodes.py
@ -1279,8 +1279,10 @@ class Hy3DBPT:
|
||||
"required": {
|
||||
"trimesh": ("TRIMESH",),
|
||||
"enable_bpt": ("BOOLEAN", {"default": True}),
|
||||
"seed": ("INT", {"default": 42, "min": 0, "max": 0xffffffffffffffff}),
|
||||
"temperature": ("FLOAT", {"default": 0.5}),
|
||||
"batch_size": ("INT", {"default": 1}),
|
||||
"pc_num": ("INT", {"default": 4096, "min": 1024, "max": 8192, "step": 1024}),
|
||||
"samples": ("INT", {"default": 100000})
|
||||
},
|
||||
}
|
||||
|
||||
@ -1290,12 +1292,13 @@ class Hy3DBPT:
|
||||
CATEGORY = "Hunyuan3DWrapper"
|
||||
DESCRIPTION = "BPT the mesh using bpt: https://github.com/whaohan/bpt"
|
||||
|
||||
def bpt(self, trimesh, enable_bpt, temperature, batch_size):
|
||||
|
||||
def bpt(self, trimesh, enable_bpt, temperature, pc_num, seed, samples):
|
||||
mm.unload_all_models()
|
||||
mm.soft_empty_cache()
|
||||
new_mesh = trimesh.copy()
|
||||
if enable_bpt:
|
||||
from .hy3dgen.shapegen.postprocessors import BptMesh
|
||||
new_mesh = BptMesh()(new_mesh, with_normal=True, temperature=temperature, batch_size=batch_size)
|
||||
new_mesh = BptMesh()(new_mesh, with_normal=True, temperature=temperature, batch_size=1, pc_num=pc_num, verbose=False, seed=seed, samples=samples)
|
||||
mm.unload_all_models()
|
||||
mm.soft_empty_cache()
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user