Fix mc_algo selection for flash_vdm

This commit is contained in:
kijai 2025-03-19 14:07:34 +02:00
parent 280ab59834
commit e74f261037
2 changed files with 5 additions and 1 deletions

View File

@ -66,6 +66,7 @@ class SurfaceExtractor:
class MCSurfaceExtractor(SurfaceExtractor): class MCSurfaceExtractor(SurfaceExtractor):
def run(self, grid_logit, *, mc_level, bounds, octree_resolution, **kwargs): def run(self, grid_logit, *, mc_level, bounds, octree_resolution, **kwargs):
print("MC Surface Extractor")
vertices, faces, normals, _ = measure.marching_cubes( vertices, faces, normals, _ = measure.marching_cubes(
grid_logit.cpu().numpy(), grid_logit.cpu().numpy(),
mc_level, mc_level,
@ -79,6 +80,7 @@ class MCSurfaceExtractor(SurfaceExtractor):
class DMCSurfaceExtractor(SurfaceExtractor): class DMCSurfaceExtractor(SurfaceExtractor):
def run(self, grid_logit, *, octree_resolution, **kwargs): def run(self, grid_logit, *, octree_resolution, **kwargs):
device = grid_logit.device device = grid_logit.device
print("DMC Surface Extractor")
if not hasattr(self, 'dmc'): if not hasattr(self, 'dmc'):
try: try:
from diso import DiffDMC from diso import DiffDMC

View File

@ -1213,7 +1213,9 @@ class Hy3DVAEDecode:
vae.to(device) vae.to(device)
vae.enable_flashvdm_decoder(enabled=enable_flash_vdm) vae.enable_flashvdm_decoder(
enabled=enable_flash_vdm,
mc_algo=mc_algo,)
latents = 1. / vae.scale_factor * latents latents = 1. / vae.scale_factor * latents
latents = vae(latents) latents = vae(latents)