fix small bug

This commit is contained in:
LiewFeng 2024-12-27 11:06:52 +08:00
parent 202ae9fdfe
commit 27ecce2b3a
2 changed files with 8 additions and 4 deletions

View File

@ -88,7 +88,9 @@ def teacache_forward(
should_calc = True
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = emb
self.cnt = 0 if self.cnt == self.num_steps-1 else self.cnt + 1
self.cnt += 1
if self.cnt == self.num_steps-1:
self.cnt = 0
if self.enable_teacache:
if not should_calc:
@ -236,7 +238,7 @@ def main(args):
# TeaCache Config
pipe.transformer.__class__.enable_teacache = True
pipe.transformer.__class__.cnt = 0
pipe.transformer.__class__.num_steps = num_infer_steps - 1
pipe.transformer.__class__.num_steps = num_infer_steps
pipe.transformer.__class__.rel_l1_thresh = rel_l1_thresh # 0.1 for 1.6x speedup -- 0.15 for 2.1x speedup -- 0.2 for 2.5x speedup
pipe.transformer.__class__.accumulated_rel_l1_distance = 0
pipe.transformer.__class__.previous_modulated_input = None

View File

@ -108,7 +108,9 @@ def teacache_forward(
should_calc = True
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = modulated_inp
self.cnt = 0 if self.cnt == self.num_steps-1 else self.cnt + 1
self.cnt += 1
if self.cnt == self.num_steps:
self.cnt = 0
if self.enable_teacache:
if not should_calc:
@ -216,7 +218,7 @@ def main():
# TeaCache
hunyuan_video_sampler.pipeline.transformer.__class__.enable_teacache = True
hunyuan_video_sampler.pipeline.transformer.__class__.cnt = 0
hunyuan_video_sampler.pipeline.transformer.__class__.num_steps = args.infer_steps - 1
hunyuan_video_sampler.pipeline.transformer.__class__.num_steps = args.infer_steps
hunyuan_video_sampler.pipeline.transformer.__class__.rel_l1_thresh = 0.15 # 0.1 for 1.6x speedup, 0.15 for 2.1x speedup
hunyuan_video_sampler.pipeline.transformer.__class__.accumulated_rel_l1_distance = 0
hunyuan_video_sampler.pipeline.transformer.__class__.previous_modulated_input = None