This commit is contained in:
LiewFeng 2024-12-27 10:55:22 +08:00
parent d74927d337
commit 202ae9fdfe
5 changed files with 12 additions and 9 deletions

View File

@ -85,7 +85,7 @@ conda create -n teacache python=3.10 -y
conda activate teacache
```
Install VideoSys:
Install TeaCache:
```shell
git clone https://github.com/LiewFeng/TeaCache

View File

@ -200,7 +200,7 @@ def eval_teacache_slow(prompt_list):
engine.driver_worker.transformer.__class__.previous_modulated_input = None
engine.driver_worker.transformer.__class__.previous_residual = None
engine.driver_worker.transformer.__class__.previous_residual_encoder = None
engine.driver_worker.transformer.__class__.__class__.forward = teacache_forward
engine.driver_worker.transformer.__class__.forward = teacache_forward
generate_func(engine, prompt_list, "./samples/cogvideox_teacache_slow", loop=5)
def eval_teacache_fast(prompt_list):
@ -212,7 +212,7 @@ def eval_teacache_fast(prompt_list):
engine.driver_worker.transformer.__class__.previous_modulated_input = None
engine.driver_worker.transformer.__class__.previous_residual = None
engine.driver_worker.transformer.__class__.previous_residual_encoder = None
engine.driver_worker.transformer.__class__.__class__.forward = teacache_forward
engine.driver_worker.transformer.__class__.forward = teacache_forward
generate_func(engine, prompt_list, "./samples/cogvideox_teacache_fast", loop=5)

View File

@ -6,6 +6,9 @@ from torch import nn
import numpy as np
from typing import Any, Dict, Optional, Tuple
from videosys.core.comm import all_to_all_with_pad, gather_sequence, get_pad, set_pad, split_sequence
from videosys.models.transformers.latte_transformer_3d import Transformer3DModelOutput
from videosys.utils.utils import batch_func
from functools import partial
def teacache_forward(
self,
@ -502,7 +505,7 @@ def eval_teacache_slow(prompt_list):
engine.driver_worker.transformer.__class__.accumulated_rel_l1_distance = 0
engine.driver_worker.transformer.__class__.previous_modulated_input = None
engine.driver_worker.transformer.__class__.previous_residual = None
engine.driver_worker.transformer.__class__.__class__.forward = teacache_forward
engine.driver_worker.transformer.__class__.forward = teacache_forward
generate_func(engine, prompt_list, "./samples/latte_teacache_slow", loop=5)
def eval_teacache_fast(prompt_list):
@ -513,7 +516,7 @@ def eval_teacache_fast(prompt_list):
engine.driver_worker.transformer.__class__.accumulated_rel_l1_distance = 0
engine.driver_worker.transformer.__class__.previous_modulated_input = None
engine.driver_worker.transformer.__class__.previous_residual = None
engine.driver_worker.transformer.__class__.__class__.forward = teacache_forward
engine.driver_worker.transformer.__class__.forward = teacache_forward
generate_func(engine, prompt_list, "./samples/latte_teacache_fast", loop=5)

View File

@ -216,7 +216,7 @@ def eval_teacache_slow(prompt_list):
engine.driver_worker.transformer.__class__.accumulated_rel_l1_distance = 0
engine.driver_worker.transformer.__class__.previous_modulated_input = None
engine.driver_worker.transformer.__class__.previous_residual = None
engine.driver_worker.transformer.__class__.__class__.forward = teacache_forward
engine.driver_worker.transformer.__class__.forward = teacache_forward
generate_func(engine, prompt_list, "./samples/opensora_teacache_slow", loop=5)
def eval_teacache_fast(prompt_list):
@ -227,7 +227,7 @@ def eval_teacache_fast(prompt_list):
engine.driver_worker.transformer.__class__.accumulated_rel_l1_distance = 0
engine.driver_worker.transformer.__class__.previous_modulated_input = None
engine.driver_worker.transformer.__class__.previous_residual = None
engine.driver_worker.transformer.__class__.__class__.forward = teacache_forward
engine.driver_worker.transformer.__class__.forward = teacache_forward
generate_func(engine, prompt_list, "./samples/opensora_teacache_fast", loop=5)

View File

@ -565,7 +565,7 @@ def eval_teacache_slow(prompt_list):
engine.driver_worker.transformer.__class__.accumulated_rel_l1_distance = 0
engine.driver_worker.transformer.__class__.previous_modulated_input = None
engine.driver_worker.transformer.__class__.previous_residual = None
engine.driver_worker.transformer.__class__.__class__.forward = teacache_forward
engine.driver_worker.transformer.__class__.forward = teacache_forward
generate_func(engine, prompt_list, "./samples/opensoraplan_teacache_slow", loop=5)
def eval_teacache_fast(prompt_list):
@ -576,7 +576,7 @@ def eval_teacache_fast(prompt_list):
engine.driver_worker.transformer.__class__.accumulated_rel_l1_distance = 0
engine.driver_worker.transformer.__class__.previous_modulated_input = None
engine.driver_worker.transformer.__class__.previous_residual = None
engine.driver_worker.transformer.__class__.__class__.forward = teacache_forward
engine.driver_worker.transformer.__class__.forward = teacache_forward
generate_func(engine, prompt_list, "./samples/opensoraplan_teacache_fast", loop=5)