diff --git a/hy3dgen/shapegen/models/vae.py b/hy3dgen/shapegen/models/vae.py index b9b13bf..0a85785 100755 --- a/hy3dgen/shapegen/models/vae.py +++ b/hy3dgen/shapegen/models/vae.py @@ -636,13 +636,22 @@ class ShapeVAE(nn.Module): batch_logits = [] batch_size = latents.shape[0] comfy_pbar = ProgressBar(xyz_samples.shape[0]) - for start in tqdm(range(0, xyz_samples.shape[0], num_chunks), - desc=f"MC Level {mc_level} Implicit Function:"): - if mc_algo == 'odc': + if mc_algo == 'odc': imp_func = lambda xyz: torch.flatten(self.geo_decoder(repeat(xyz, "p c -> b p c", b=batch_size).to(latents.dtype), latents)) vertices, faces = odc.extract_mesh(imp_func, num_grid = octree_resolution, isolevel=mc_level, batch_size=num_chunks, min_coord=bbox_min, max_coord=bbox_max) comfy_pbar.update(num_chunks) - else: + + vertices = vertices.detach().cpu().numpy() + faces = faces.detach().cpu().numpy() + outputs = [ + Latent2MeshOutput( + mesh_v=vertices.astype(np.float32), + mesh_f=np.ascontiguousarray(faces) + )] + return outputs + else: + for start in tqdm(range(0, xyz_samples.shape[0], num_chunks), + desc=f"MC Level {mc_level} Implicit Function:"): queries = xyz_samples[start: start + num_chunks, :].to(device) queries = queries.half() batch_queries = repeat(queries, "p c -> b p c", b=batch_size) @@ -654,16 +663,6 @@ class ShapeVAE(nn.Module): batch_logits.append(logits) comfy_pbar.update(num_chunks) - if mc_algo == 'odc': - vertices = vertices.detach().cpu().numpy() - faces = faces.detach().cpu().numpy() - outputs = [ - Latent2MeshOutput( - mesh_v=vertices.astype(np.float32), - mesh_f=np.ascontiguousarray(faces) - )] - return outputs - else: grid_logits = torch.cat(batch_logits, dim=1) grid_logits = grid_logits.view((batch_size, grid_size[0], grid_size[1], grid_size[2])).float()