diff --git a/hy3dgen/shapegen/bpt/model/model.py b/hy3dgen/shapegen/bpt/model/model.py index 8832060..c1260e0 100644 --- a/hy3dgen/shapegen/bpt/model/model.py +++ b/hy3dgen/shapegen/bpt/model/model.py @@ -18,6 +18,7 @@ from .miche_conditioner import PointConditioner from functools import partial from tqdm import tqdm from .data_utils import discretize +from comfy.utils import ProgressBar # helper functions @@ -182,6 +183,10 @@ class MeshTransformer(Module): cache = None eos_iter = None + + # ✅ Initialize ComfyUI progress bar + pbar = ProgressBar(max_seq_len - curr_length) + # predict tokens auto-regressively for i in tqdm(range(curr_length, max_seq_len), position=tqdm_position, desc=f'Process: {tqdm_position}', dynamic_ncols=True, leave=False): @@ -207,6 +212,9 @@ class MeshTransformer(Module): sample = torch.multinomial(probs, 1) codes, _ = pack([codes, sample], 'b *') + # ComfyUI progress bar + pbar.update(1) + # Check if all sequences have encountered EOS at least once is_eos_codes = (codes == self.eos_token_id) if is_eos_codes.any(dim=-1).all(): @@ -216,6 +224,9 @@ class MeshTransformer(Module): # Once we've generated 20% more tokens than eos_iter, break out of the loop if codes.shape[-1] >= int(eos_iter * 1.2): break + + # Ensure progress bar reaches 100% when loop completes + pbar.complete() # mask out to padding anything after the first eos