version bump

kijai 2025-03-21 20:52:56 +02:00
parent 875b0b0420
commit 921b0d78a9
5 changed files with 23 additions and 10 deletions

View File

@@ -25,6 +25,7 @@ from .attention_blocks import CrossAttentionDecoder
from .attention_processors import FlashVDMCrossAttentionProcessor, FlashVDMTopMCrossAttentionProcessor
from ...utils import logger
from comfy.utils import ProgressBar
def extract_near_surface_volume_fn(input_tensor: torch.Tensor, alpha: float):
device = input_tensor.device
@@ -318,7 +319,7 @@ class FlashVDMVolumeDecoding:
for i, resolution in enumerate(resolutions[1:]):
resolutions[i + 1] = resolutions[0] * 2 ** (i + 1)
logger.info(f"FlashVDMVolumeDecoding Resolution: {resolutions}")
#logger.info(f"FlashVDMVolumeDecoding Resolution: {resolutions}")
# 1. generate query points
if isinstance(bounds, float):
@@ -354,6 +355,7 @@ class FlashVDMVolumeDecoding:
)
batch_logits = []
num_batchs = max(num_chunks // xyz_samples.shape[1], 1)
comfy_pbar = ProgressBar(xyz_samples.shape[0])
for start in tqdm(range(0, xyz_samples.shape[0], num_batchs),
desc=f"FlashVDM Volume Decoding", disable=not enable_pbar):
queries = xyz_samples[start: start + num_batchs, :]
@@ -362,13 +364,17 @@ class FlashVDMVolumeDecoding:
processor.topk = True
logits = geo_decoder(queries=queries, latents=batch_latents)
batch_logits.append(logits)
comfy_pbar.update(num_chunks)
grid_logits = torch.cat(batch_logits, dim=0).reshape(
mini_grid_num, mini_grid_num, mini_grid_num,
mini_grid_size, mini_grid_size,
mini_grid_size
).permute(0, 3, 1, 4, 2, 5).contiguous().view(
(batch_size, grid_size[0], grid_size[1], grid_size[2])
)
)
for octree_depth_now in resolutions[1:]:
grid_size = np.array([octree_depth_now + 1] * 3)
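
Note: the ProgressBar added in this file reports decoding progress to the ComfyUI frontend alongside the existing tqdm bar. A minimal standalone sketch of the same usage pattern, assuming ComfyUI is importable; decode_in_chunks and its doubling step are illustrative stand-ins, not code from this commit:

    import torch
    from comfy.utils import ProgressBar  # ComfyUI's built-in progress reporter

    def decode_in_chunks(samples: torch.Tensor, chunk: int = 8) -> torch.Tensor:
        pbar = ProgressBar(samples.shape[0])      # total units the bar expects
        outputs = []
        for start in range(0, samples.shape[0], chunk):
            batch = samples[start:start + chunk]
            outputs.append(batch * 2.0)           # stand-in for the real geo_decoder call
            pbar.update(batch.shape[0])           # advance by the items just processed
        return torch.cat(outputs, dim=0)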

View File

@@ -112,6 +112,7 @@ class ImageEncoder(nn.Module):
image = (image - low) / (high - low)
image = image.to(self.model.device, dtype=self.model.dtype)
print("image shape", image.shape)
if mask is not None:
mask = mask.to(image)
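
Note: the (image - low) / (high - low) line above rescales the conditioning image into the [0, 1] range before casting it to the encoder's device and dtype. A minimal sketch of that kind of rescale; deriving low and high from the tensor's min and max is an assumption here, the encoder may obtain them differently:

    import torch

    def rescale_to_unit_range(image: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
        low, high = image.amin(), image.amax()
        # eps guards against a flat image where high == low
        return (image - low) / (high - low + eps)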

View File

@@ -206,8 +206,8 @@ class Hunyuan3DDiTPipeline:
config['model']['params']['attention_mode'] = attention_mode
#config['vae']['params']['attention_mode'] = attention_mode
if cublas_ops:
config['vae']['params']['cublas_ops'] = True
#if cublas_ops:
# config['vae']['params']['cublas_ops'] = True
with init_empty_weights():
model = instantiate_from_config(config['model'])
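
Note: init_empty_weights() instantiates the model on the meta device, so no memory is allocated for parameters until real checkpoint weights are loaded into the module afterwards. A minimal sketch of that pattern with accelerate; the Linear stand-in and hand-built state_dict are illustrative, not the pipeline's actual loading code:

    import torch
    from accelerate import init_empty_weights
    from accelerate.utils import set_module_tensor_to_device

    with init_empty_weights():
        model = torch.nn.Linear(4, 4)             # parameters live on the "meta" device
    assert model.weight.device.type == "meta"

    state_dict = {"weight": torch.randn(4, 4), "bias": torch.zeros(4)}
    for name, tensor in state_dict.items():
        # materialize each parameter on a real device before the model can be used
        set_module_tensor_to_device(model, name, device="cpu", value=tensor)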

View File

@@ -1086,6 +1086,7 @@ class Hy3DGenerateMesh:
"optional": {
"mask": ("MASK", ),
"scheduler": (["FlowMatchEulerDiscreteScheduler", "ConsistencyFlowMatchEulerDiscreteScheduler"],),
"force_offload": ("BOOLEAN", {"default": True, "tooltip": "Offloads the model to the offload device once the process is done."}),
}
}
@@ -1094,7 +1095,8 @@ class Hy3DGenerateMesh:
FUNCTION = "process"
CATEGORY = "Hunyuan3DWrapper"
def process(self, pipeline, image, steps, guidance_scale, seed, mask=None, front=None, back=None, left=None, right=None, scheduler="FlowMatchEulerDiscreteScheduler"):
def process(self, pipeline, image, steps, guidance_scale, seed, mask=None, front=None, back=None, left=None, right=None,
scheduler="FlowMatchEulerDiscreteScheduler", force_offload=True):
mm.unload_all_models()
mm.soft_empty_cache()
@@ -1136,8 +1138,9 @@ class Hy3DGenerateMesh:
torch.cuda.reset_peak_memory_stats(device)
except:
pass
pipeline.to(offload_device)
if not force_offload:
pipeline.to(offload_device)
return (latents, )
@@ -1254,6 +1257,8 @@ class Hy3DVAEDecode:
},
"optional": {
"enable_flash_vdm": ("BOOLEAN", {"default": True}),
"force_offload": ("BOOLEAN", {"default": True, "tooltip": "Offloads the model to the offload device once the process is done."}),
}
}
@@ -1262,7 +1267,7 @@ class Hy3DVAEDecode:
FUNCTION = "process"
CATEGORY = "Hunyuan3DWrapper"
def process(self, vae, latents, box_v, octree_resolution, mc_level, num_chunks, mc_algo, enable_flash_vdm=True):
def process(self, vae, latents, box_v, octree_resolution, mc_level, num_chunks, mc_algo, enable_flash_vdm=True, force_offload=True):
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
@@ -1283,7 +1288,8 @@ class Hy3DVAEDecode:
octree_resolution=octree_resolution,
mc_algo=mc_algo,
)[0]
vae.to(offload_device)
if force_offload:
vae.to(offload_device)
outputs.mesh_f = outputs.mesh_f[:, ::-1]
mesh_output = Trimesh.Trimesh(outputs.mesh_v, outputs.mesh_f)
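
Note: the new force_offload input (default True) on both nodes controls whether the model is moved back to ComfyUI's offload device once the node finishes, freeing VRAM between runs at the cost of moving the weights again on the next run. A minimal sketch of the pattern using ComfyUI's model management helpers; run_decode and model.decode are illustrative stand-ins, not the nodes' actual code:

    import comfy.model_management as mm

    def run_decode(model, latents, force_offload: bool = True):
        device = mm.get_torch_device()             # main compute device (usually the GPU)
        offload_device = mm.unet_offload_device()  # where idle models are parked (usually CPU)

        model.to(device)
        result = model.decode(latents)             # stand-in for the node's real work

        if force_offload:
            model.to(offload_device)               # park the model once the work is done
            mm.soft_empty_cache()
        return result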

View File

@@ -1,7 +1,7 @@
[project]
name = "comfyui-hunyan3dwrapper"
description = "Wrapper nodes for https://github.com/Tencent/Hunyuan3D-2, additional installation steps needed, please check the github repository"
version = "1.0.4"
version = "1.0.5"
license = {file = "LICENSE"}
dependencies = ["trimesh", "diffusers>=0.31.0","accelerate","huggingface_hub","einops","opencv-python","transformers","xatlas","pymeshlab","pygltflib","scikit-learn","scikit-image","pybind11"]