mirror of
https://git.datalinker.icu/kijai/ComfyUI-Hunyuan3DWrapper.git
synced 2025-12-10 05:14:29 +08:00
Added Node Progress Bar
This commit is contained in:
parent
f5408c9493
commit
2b44651d83
@ -18,6 +18,7 @@ from .miche_conditioner import PointConditioner
|
|||||||
from functools import partial
|
from functools import partial
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from .data_utils import discretize
|
from .data_utils import discretize
|
||||||
|
from comfy.utils import ProgressBar
|
||||||
|
|
||||||
# helper functions
|
# helper functions
|
||||||
|
|
||||||
@ -182,6 +183,10 @@ class MeshTransformer(Module):
|
|||||||
|
|
||||||
cache = None
|
cache = None
|
||||||
eos_iter = None
|
eos_iter = None
|
||||||
|
|
||||||
|
# ✅ Initialize ComfyUI progress bar
|
||||||
|
pbar = ProgressBar(max_seq_len - curr_length)
|
||||||
|
|
||||||
# predict tokens auto-regressively
|
# predict tokens auto-regressively
|
||||||
for i in tqdm(range(curr_length, max_seq_len), position=tqdm_position,
|
for i in tqdm(range(curr_length, max_seq_len), position=tqdm_position,
|
||||||
desc=f'Process: {tqdm_position}', dynamic_ncols=True, leave=False):
|
desc=f'Process: {tqdm_position}', dynamic_ncols=True, leave=False):
|
||||||
@ -207,6 +212,9 @@ class MeshTransformer(Module):
|
|||||||
sample = torch.multinomial(probs, 1)
|
sample = torch.multinomial(probs, 1)
|
||||||
codes, _ = pack([codes, sample], 'b *')
|
codes, _ = pack([codes, sample], 'b *')
|
||||||
|
|
||||||
|
# ComfyUI progress bar
|
||||||
|
pbar.update(1)
|
||||||
|
|
||||||
# Check if all sequences have encountered EOS at least once
|
# Check if all sequences have encountered EOS at least once
|
||||||
is_eos_codes = (codes == self.eos_token_id)
|
is_eos_codes = (codes == self.eos_token_id)
|
||||||
if is_eos_codes.any(dim=-1).all():
|
if is_eos_codes.any(dim=-1).all():
|
||||||
@ -216,6 +224,9 @@ class MeshTransformer(Module):
|
|||||||
# Once we've generated 20% more tokens than eos_iter, break out of the loop
|
# Once we've generated 20% more tokens than eos_iter, break out of the loop
|
||||||
if codes.shape[-1] >= int(eos_iter * 1.2):
|
if codes.shape[-1] >= int(eos_iter * 1.2):
|
||||||
break
|
break
|
||||||
|
|
||||||
|
# Ensure progress bar reaches 100% when loop completes
|
||||||
|
pbar.complete()
|
||||||
|
|
||||||
# mask out to padding anything after the first eos
|
# mask out to padding anything after the first eos
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user