mirror of
https://git.datalinker.icu/kijai/ComfyUI-Hunyuan3DWrapper.git
synced 2025-12-22 03:04:25 +08:00
Update vae.py
This commit is contained in:
parent
44b9d0d676
commit
89536cbd42
@ -636,13 +636,22 @@ class ShapeVAE(nn.Module):
|
|||||||
batch_logits = []
|
batch_logits = []
|
||||||
batch_size = latents.shape[0]
|
batch_size = latents.shape[0]
|
||||||
comfy_pbar = ProgressBar(xyz_samples.shape[0])
|
comfy_pbar = ProgressBar(xyz_samples.shape[0])
|
||||||
for start in tqdm(range(0, xyz_samples.shape[0], num_chunks),
|
|
||||||
desc=f"MC Level {mc_level} Implicit Function:"):
|
|
||||||
if mc_algo == 'odc':
|
if mc_algo == 'odc':
|
||||||
imp_func = lambda xyz: torch.flatten(self.geo_decoder(repeat(xyz, "p c -> b p c", b=batch_size).to(latents.dtype), latents))
|
imp_func = lambda xyz: torch.flatten(self.geo_decoder(repeat(xyz, "p c -> b p c", b=batch_size).to(latents.dtype), latents))
|
||||||
vertices, faces = odc.extract_mesh(imp_func, num_grid = octree_resolution, isolevel=mc_level, batch_size=num_chunks, min_coord=bbox_min, max_coord=bbox_max)
|
vertices, faces = odc.extract_mesh(imp_func, num_grid = octree_resolution, isolevel=mc_level, batch_size=num_chunks, min_coord=bbox_min, max_coord=bbox_max)
|
||||||
comfy_pbar.update(num_chunks)
|
comfy_pbar.update(num_chunks)
|
||||||
|
|
||||||
|
vertices = vertices.detach().cpu().numpy()
|
||||||
|
faces = faces.detach().cpu().numpy()
|
||||||
|
outputs = [
|
||||||
|
Latent2MeshOutput(
|
||||||
|
mesh_v=vertices.astype(np.float32),
|
||||||
|
mesh_f=np.ascontiguousarray(faces)
|
||||||
|
)]
|
||||||
|
return outputs
|
||||||
else:
|
else:
|
||||||
|
for start in tqdm(range(0, xyz_samples.shape[0], num_chunks),
|
||||||
|
desc=f"MC Level {mc_level} Implicit Function:"):
|
||||||
queries = xyz_samples[start: start + num_chunks, :].to(device)
|
queries = xyz_samples[start: start + num_chunks, :].to(device)
|
||||||
queries = queries.half()
|
queries = queries.half()
|
||||||
batch_queries = repeat(queries, "p c -> b p c", b=batch_size)
|
batch_queries = repeat(queries, "p c -> b p c", b=batch_size)
|
||||||
@ -654,16 +663,6 @@ class ShapeVAE(nn.Module):
|
|||||||
batch_logits.append(logits)
|
batch_logits.append(logits)
|
||||||
comfy_pbar.update(num_chunks)
|
comfy_pbar.update(num_chunks)
|
||||||
|
|
||||||
if mc_algo == 'odc':
|
|
||||||
vertices = vertices.detach().cpu().numpy()
|
|
||||||
faces = faces.detach().cpu().numpy()
|
|
||||||
outputs = [
|
|
||||||
Latent2MeshOutput(
|
|
||||||
mesh_v=vertices.astype(np.float32),
|
|
||||||
mesh_f=np.ascontiguousarray(faces)
|
|
||||||
)]
|
|
||||||
return outputs
|
|
||||||
else:
|
|
||||||
grid_logits = torch.cat(batch_logits, dim=1)
|
grid_logits = torch.cat(batch_logits, dim=1)
|
||||||
grid_logits = grid_logits.view((batch_size, grid_size[0], grid_size[1], grid_size[2])).float()
|
grid_logits = grid_logits.view((batch_size, grid_size[0], grid_size[1], grid_size[2])).float()
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user