mirror of
https://git.datalinker.icu/kijai/ComfyUI-Hunyuan3DWrapper.git
synced 2026-01-04 02:37:29 +08:00
242 lines
8.9 KiB
Python
242 lines
8.9 KiB
Python
import trimesh
|
|
import numpy as np
|
|
from .data_utils import discretize, undiscretize
|
|
|
|
|
|
def patchified_mesh(mesh: trimesh.Trimesh, special_token = -2, fix_orient=True):
    '''
    Serialize a mesh's faces as a sequence of vertex-centred triangle fans (patches).

    Repeatedly takes the first still-unvisited face and picks its vertex with
    the highest remaining degree as the patch center. It then walks the
    unvisited faces around that center as a chain of rim vertices. Each patch
    is emitted as the center's coordinates, then the rim vertices' coordinates
    in chain order, then a separator row ``[special_token] * 3``.

    Args:
        mesh: mesh to serialize. ``faces``, ``vertices``, ``vertex_degree``
            and ``vertex_faces`` are read; the mesh is not modified.
        special_token: value filling the separator row written after each patch.
        fix_orient: if True, a patch's rim is reversed whenever
            ``(center, rim[0], rim[1])`` is listed with the opposite winding
            of the face it belongs to.

    Returns:
        Array with one row of three values per emitted vertex or separator
        (a 1-D empty array if the mesh has no faces).

    Raises:
        AssertionError: if the remaining vertex degrees do not all reach zero.
    '''
    sequence = []
    unvisited = np.full(len(mesh.faces), True)
    degrees = mesh.vertex_degree.copy()

    # with fix_orient=True, the normal would be correct.
    # but this may increase the difficulty for learning.
    if fix_orient:
        # Ordered vertex triple -> True if it is a cyclic rotation of a face's
        # winding, False if it is the reversed winding. A later face overwrites
        # entries it shares with an earlier one.
        face_orient = {}
        for v0, v1, v2 in mesh.faces.tolist():
            face_orient[(v0, v1, v2)] = True
            face_orient[(v1, v2, v0)] = True
            face_orient[(v2, v0, v1)] = True
            face_orient[(v2, v1, v0)] = False
            face_orient[(v1, v0, v2)] = False
            face_orient[(v0, v2, v1)] = False

    # unvisited.any() and np.argmax run in C. Calling sum() on the array
    # instead iterates it in Python once per patch.
    while unvisited.any():
        # Select the patch center: the highest-degree vertex of the first
        # unvisited face (argmax of a bool array is its first True index).
        cur_face = mesh.faces[np.argmax(unvisited)]
        max_deg_vertex = cur_face[np.argmax(degrees[cur_face])]

        # Collect all unvisited faces around the center as [u, v, face_idx],
        # where (u, v) are the face's other two vertices in ascending order.
        # Entries equal to -1 in vertex_faces are padding and are skipped.
        selected_faces = []
        for face_idx in mesh.vertex_faces[max_deg_vertex]:
            if face_idx != -1 and unvisited[face_idx]:
                u, v = sorted(vertex for vertex in mesh.faces[face_idx] if vertex != max_deg_vertex)
                selected_faces.append([u, v, face_idx])
        selected_faces.sort()

        # Select the start vertex: one that appears in only one face (an open
        # end of the chain), else the lowest vertex index.
        cnt = {}
        for u, v, _ in selected_faces:
            cnt[u] = cnt.get(u, 0) + 1
            cnt[v] = cnt.get(v, 0) + 1
        starts = [vertex for vertex, num in cnt.items() if num == 1]
        start_idx = min(starts) if starts else selected_faces[0][0]

        # Walk the fan: from the current rim vertex, take the first unused
        # face containing it and step to that face's other rim vertex. Stop
        # when every face is used or no unused face continues the chain.
        face_patch = set()
        res = [start_idx]
        while len(res) <= len(selected_faces):
            vertex = res[-1]
            for u_i, v_i, face_idx_i in selected_faces:
                if face_idx_i not in face_patch and vertex in (u_i, v_i):
                    res.append(v_i if vertex == u_i else u_i)
                    face_patch.add(face_idx_i)
                    break

            if res[-1] == vertex:
                break

        if fix_orient and len(res) >= 2 and not face_orient[(int(max_deg_vertex), int(res[0]), int(res[1]))]:
            res = res[::-1]

        # Reduce the degree of related vertices and mark the visited faces.
        # Chain ends lose one face and interior rim vertices lose two. The
        # center keeps only the faces this patch did not consume.
        degrees[max_deg_vertex] = len(selected_faces) - len(res) + 1
        for pos_idx, vertex in enumerate(res):
            if pos_idx in [0, len(res) - 1]:
                degrees[vertex] -= 1
            else:
                degrees[vertex] -= 2
        for face_idx in face_patch:
            unvisited[face_idx] = False

        sequence.extend(
            [mesh.vertices[max_deg_vertex]] +
            [mesh.vertices[vertex_idx] for vertex_idx in res] +
            [[special_token] * 3]
        )

    assert sum(degrees) == 0, 'All degrees should be zero'

    return np.array(sequence)
|
|
|
|
|
|
|
|
def get_block_representation(
    sequence,
    block_size=8,
    offset_size=16,
    block_compressed=True,
    special_token=-2,
    use_special_block=True
):
    '''
    Convert a patchified coordinate sequence into flat block/offset codes.

    Each vertex row is discretized with ``discretize`` and split into two ids:
    - a coarse block id in ``[0, block_size**3)``;
    - a fine offset id in ``[block_size**3, block_size**3 + offset_size**3)``.

    Args:
        sequence: array-like of shape (N, 3), e.g. the output of
            ``patchified_mesh``. Patch-separator rows are filled with
            ``special_token``. The caller's array is left unmodified.
        block_size: number of blocks per axis.
        offset_size: number of offset cells per block per axis. Presumably
            ``block_size * offset_size`` must equal the resolution used by
            ``discretize``. TODO confirm against data_utils.
        block_compressed: if True, a block token is emitted only when the
            block changes. Otherwise every vertex emits ``block, offset``.
        special_token: value marking patch-separator rows.
        use_special_block: if True, separators are not emitted. Instead, the
            first block token of the next patch is shifted up by
            ``block_size**3 + offset_size**3``. If False, ``special_token``
            is emitted for each separator.

    Returns:
        1-D int64 array of codes; empty if ``sequence`` has no rows.
    '''
    # Work on a copy. Block/offset ids are written into the first two columns
    # below, which would otherwise mutate the caller's array through a view.
    sequence = np.array(sequence)
    if len(sequence) == 0:
        return np.zeros(0, dtype=np.int64)

    special_block_base = block_size**3 + offset_size**3

    # prepare coordinates (separator rows are excluded)
    sp_mask = np.all(sequence != special_token, axis=1)
    coords = sequence[sp_mask].reshape(-1, 3)
    coords = discretize(coords)

    # convert [x, y, z] to [block_id, offset_id]
    block_id = coords // offset_size
    block_id = block_id[:, 0] * block_size**2 + block_id[:, 1] * block_size + block_id[:, 2]
    offset_id = coords % offset_size
    offset_id = offset_id[:, 0] * offset_size**2 + offset_id[:, 1] * offset_size + offset_id[:, 2]
    offset_id += block_size**3
    block_coords = np.concatenate([block_id[..., None], offset_id[..., None]], axis=-1).astype(np.int64)
    sequence[:, :2][sp_mask] = block_coords
    sequence = sequence[:, :2]

    # Convert to codes. In compressed mode the first vertex only emits its
    # offset, so the leading block token is written up front. In
    # uncompressed mode every vertex already emits its own block token, and a
    # leading one would break the [block, offset] pairing.
    cur_block_id = sequence[0, 0]
    codes = [cur_block_id] if block_compressed else []
    for block, offset in sequence:
        if block == special_token:
            if not use_special_block:
                codes.append(special_token)
            cur_block_id = special_token

        elif block == cur_block_id:
            if block_compressed:
                codes.append(offset)
            else:
                codes.extend([block, offset])

        else:
            if use_special_block and cur_block_id == special_token:
                emitted_block = block + special_block_base
            else:
                emitted_block = block
            codes.extend([emitted_block, offset])
            # NOTE(review): after a separator this stores the *shifted* id, so
            # the next vertex in the same block never matches and re-emits its
            # block token. Decoding is unaffected. This is kept as-is because
            # it defines the token format.
            cur_block_id = emitted_block

    return np.array(codes).astype(np.int64).flatten()
|
|
|
|
|
|
def BPT_serialize(mesh: trimesh.Trimesh):
    '''
    Serialize a mesh into BPT token codes.

    Runs two stages:
    1. ``patchified_mesh`` groups the faces into vertex-centred patches,
       separated by rows of -2.
    2. ``get_block_representation`` turns the coordinates into compressed
       block/offset codes. It uses 8 blocks and 16 offset cells per axis and
       marks each new patch with a shifted block token.
    '''
    patch_sequence = patchified_mesh(mesh, special_token=-2)
    return get_block_representation(
        patch_sequence,
        block_size=8,
        offset_size=16,
        block_compressed=True,
        special_token=-2,
        use_special_block=True,
    )
|
|
|
|
|
|
def decode_block(sequence, compressed=True, block_size=8, offset_size=16):
    '''
    Decode the block/offset tokens of one patch into continuous coordinates.

    Args:
        sequence: tokens of a single patch; the caller's data is not modified.
            * ``compressed=True``: a 1-D mix of two token kinds.
              - Block tokens in ``[0, block_size**3)`` set the current block.
              - Offset tokens in ``[block_size**3, block_size**3 + offset_size**3)``
                each yield one vertex in the current block.
              Any other token is reported with a warning and skipped.
            * ``compressed=False``: ``block, offset`` pairs, either as a flat
              1-D stream ``[b0, o0, b1, o1, ...]`` or as an (N, 2) array.
        compressed: selects the token layout above.
        block_size: number of blocks per axis used by the encoder.
        offset_size: number of offset cells per block per axis used by the encoder.

    Returns:
        ``undiscretize`` applied to the (N, 3) integer grid coordinates.
    '''
    num_blocks = block_size**3
    num_offsets = offset_size**3

    if compressed:
        pairs = []
        cur_block = 0
        for token_id, token in enumerate(sequence):
            if num_blocks + num_offsets > token >= num_blocks:
                pairs.append([cur_block, token])
            elif num_blocks > token >= 0:
                cur_block = token
            else:
                print('[Warning] too large offset idx!', token_id, token)
        pairs = np.array(pairs).reshape(-1, 2)
    else:
        # Pair up consecutive tokens. Splitting a flat stream into two halves
        # would match block ids with unrelated tokens.
        pairs = np.asarray(sequence).reshape(-1, 2)

    # Column slices of pairs. All arithmetic below is out-of-place, so the
    # caller's array is never written to.
    block_id = pairs[:, :1]
    offset_id = pairs[:, 1:] - num_blocks

    # From hash representation to xyz. Each axis takes the next most
    # significant digit of both the block id (base block_size) and the
    # offset id (base offset_size).
    coords = []
    for i in (2, 1, 0):
        coords.append((block_id // block_size**i) * offset_size + offset_id // offset_size**i)
        block_id = block_id % block_size**i
        offset_id = offset_id % offset_size**i

    coords = np.concatenate(coords, axis=-1)  # (N, 3)

    # back to continuous space
    return undiscretize(coords)
|
|
|
|
|
|
def _fan_triangles(patch_coords):
    '''
    Expand one decoded patch ``[center, rim_0, ..., rim_k]`` into its triangle
    fan ``(center, rim_j, rim_j+1)`` for ``j < k``.

    Returns an array of shape (3 * k, 3) with three consecutive rows per
    triangle. It is empty when the patch has fewer than two rim vertices.
    '''
    patch_coords = np.asarray(patch_coords)
    center, rim = patch_coords[0], patch_coords[1:]
    centers = np.broadcast_to(center.reshape(1, 3), rim[:-1].shape)
    # (k, 3 corners, 3 xyz) -> (3k, 3); corner order per triangle is
    # center, rim_j, rim_j+1.
    return np.stack([centers, rim[:-1], rim[1:]], axis=1).reshape(-1, 3)


def BPT_deserialize(sequence, block_size=8, offset_size=16, compressed=True, special_token=-2, use_special_block=True):
    '''
    Decode BPT codes back into per-triangle vertex coordinates.

    The code stream is split into patches. Each patch is decoded with
    ``decode_block`` into ``[center, rim_0, rim_1, ...]`` and expanded into
    the triangle fan ``(center, rim_j, rim_j+1)``.

    Args:
        sequence: 1-D sequence of integer codes (e.g. from ``BPT_serialize``).
            It is not modified.
        block_size: number of blocks per axis; must match the encoder.
        offset_size: number of offset cells per block per axis; must match the encoder.
        compressed: whether block tokens appear only when the block changes.
        special_token: patch separator; used only when ``use_special_block`` is False.
        use_special_block: if True, a token in ``[base, base + block_size**3)``
            with ``base = block_size**3 + offset_size**3`` starts a new patch.
            Such a token is that patch's first block id, shifted by ``base``.

    Returns:
        ndarray of shape (3 * num_faces, 3), where every three consecutive
        rows are the corners of one triangle. The shape is (0, 3) if no
        triangle is decoded.
    '''
    special_block_base = block_size**3 + offset_size**3
    special_block_end = special_block_base + block_size**3
    last_idx = len(sequence) - 1

    triangles = []
    start_idx = 0
    for i in range(len(sequence)):
        token = sequence[i]
        if use_special_block:
            # A shifted block token starts a new patch, closing the previous
            # one, and the end of the stream closes the final patch. Index 0
            # never closes anything. A trailing patch-start token with no
            # content is not part of the previous patch.
            is_patch_start = special_block_base <= token < special_block_end
            if i == 0 or not (is_patch_start or i == last_idx):
                continue
            end_idx = i if is_patch_start else i + 1
            next_start = i
        else:
            # An explicit separator closes the current patch and is dropped.
            # The end of the stream closes the final patch, and its last token
            # is kept unless it is itself a separator.
            is_separator = token == special_token
            if not (is_separator or i == last_idx):
                continue
            end_idx = i if is_separator else i + 1
            next_start = i + 1

        # Copy, so un-shifting the first token never writes into the caller's codes.
        patch = np.array(sequence[start_idx:end_idx])
        start_idx = next_start
        if len(patch) == 0:
            continue
        if use_special_block and special_block_base <= patch[0] < special_block_end:
            patch[0] -= special_block_base

        coords = decode_block(patch, compressed=compressed, block_size=block_size, offset_size=offset_size)
        if len(coords):
            triangles.append(_fan_triangles(coords))

    # (nf, 3)
    if not triangles:
        return np.zeros((0, 3))
    return np.concatenate(triangles, axis=0)
|
|
|
|
|
|
if __name__ == '__main__':
    # A simple demo: serialize a mesh with BPT, then deserialize it again.
    # The module imports data_utils package-relatively at the top, so it only
    # loads as part of its package (python -m <package>.<module>). The demo
    # helpers are imported the same way; a bare `data_utils` import is not
    # resolvable in that setting.
    from .data_utils import load_process_mesh, to_mesh
    import torch

    mesh = load_process_mesh('/path/to/your/mesh', quantization_bits=7)
    mesh['faces'] = np.array(mesh['faces'])
    mesh = to_mesh(mesh['vertices'], mesh['faces'], transpose=True)
    mesh.export('gt.obj')

    codes = BPT_serialize(mesh)
    coordinates = BPT_deserialize(codes)

    # One index triple per decoded triangle; indices start at 1 as passed to to_mesh.
    faces = torch.arange(1, len(coordinates) + 1).view(-1, 3)
    mesh = to_mesh(coordinates, faces, transpose=False, post_process=False)
    mesh.export('reconstructed.obj')
|