mirror of https://git.datalinker.icu/ali-vilab/TeaCache
synced 2025-12-09 04:44:23 +08:00

Merge pull request #51 from zishen-ucap/teacache4Wan2.1_v2

Add --use_ret_steps Mode to Accelerate Inference and Make Generated Results Closer to Wan2.1

This commit is contained in:
commit a9489fbf78
@@ -57,6 +57,49 @@ For I2V with 720P resolution, you can use the following command:
 python teacache_generate.py --task i2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-I2V-14B-720P --image examples/i2v_input.JPG --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." --base_seed 42 --offload_model True --t5_cpu --frame_num 61 --teacache_thresh 0.3
 ```
+
+## Faster Video Generation Using the `use_ret_steps` Parameter
+
+Using Retention Steps gives faster generation and better quality (except for t2v-1.3B).
+
+https://github.com/user-attachments/assets/f241b5f5-1044-4223-b2a4-449dc6dc1ad7
+
+https://github.com/user-attachments/assets/01db60f9-4aaf-43c4-8f1b-6e050cfa1180
+
+https://github.com/user-attachments/assets/e03621f2-1085-4571-8eca-51889f47ce18
+
+https://github.com/user-attachments/assets/d1340197-20c1-4f9e-a780-31f789af0893
+
+| use_ret_steps | Wan2.1 t2v 1.3B (thresh) | Slow (thresh) | Fast (thresh) |
+|:--------------------------:|:----------------------------:|:---------------------:|:---------------------:|
+| False | ~97 s (0.00) | ~64 s (0.05) | ~49 s (0.08) |
+| True | ~97 s (0.00) | ~61 s (0.05) | ~41 s (0.10) |
+
+| use_ret_steps | Wan2.1 t2v 14B (thresh) | Slow (thresh) | Fast (thresh) |
+|:--------------------------:|:----------------------------:|:---------------------:|:---------------------:|
+| False | ~1829 s (0.00) | ~1234 s (0.14) | ~909 s (0.20) |
+| True | ~1829 s (0.00) | ~915 s (0.10) | ~578 s (0.20) |
+
+| use_ret_steps | Wan2.1 i2v 480p (thresh) | Slow (thresh) | Fast (thresh) |
+|:--------------------------:|:----------------------------:|:---------------------:|:---------------------:|
+| False | ~385 s (0.00) | ~241 s (0.13) | ~156 s (0.26) |
+| True | ~385 s (0.00) | ~212 s (0.20) | ~164 s (0.30) |
+
+| use_ret_steps | Wan2.1 i2v 720p (thresh) | Slow (thresh) | Fast (thresh) |
+|:--------------------------:|:----------------------------:|:---------------------:|:---------------------:|
+| False | ~903 s (0.00) | ~476 s (0.20) | ~363 s (0.30) |
+| True | ~903 s (0.00) | ~430 s (0.20) | ~340 s (0.30) |
+
+You can follow the earlier video generation instructions and add the `use_ret_steps` parameter to speed up generation while keeping results closer to [Wan2.1](https://github.com/Wan-Video/Wan2.1). Simply append `--use_ret_steps` to the original command and tune `--teacache_thresh` for the desired speed/quality trade-off; the tables above list suitable threshold values for each model and setting.
+
+### Example Command:
+
+```bash
+python teacache_generate.py --task t2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-T2V-14B --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." --base_seed 42 --offload_model True --t5_cpu --teacache_thresh 0.3 --use_ret_steps
+```
 
 ## Acknowledgements
 
 We would like to thank the contributors to [Wan2.1](https://github.com/Wan-Video/Wan2.1).
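The `--use_ret_steps` flag applies equally to the I2V commands above. As a sketch (assuming the 720P checkpoint, the threshold from the i2v 720p table, and an abbreviated prompt), the earlier I2V command would become:

```bash
python teacache_generate.py --task i2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-I2V-14B-720P --image examples/i2v_input.JPG --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard." --base_seed 42 --offload_model True --t5_cpu --frame_num 61 --teacache_thresh 0.3 --use_ret_steps
```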
@@ -29,6 +29,7 @@ from wan.utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
 from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
 from tqdm import tqdm
+
 
 EXAMPLE_PROMPT = {
     "t2v-1.3B": {
         "prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
@@ -48,183 +49,6 @@ EXAMPLE_PROMPT = {
 }
 
 
-def _validate_args(args):
-    # Basic check
-    assert args.ckpt_dir is not None, "Please specify the checkpoint directory."
-    assert args.task in WAN_CONFIGS, f"Unsupport task: {args.task}"
-    assert args.task in EXAMPLE_PROMPT, f"Unsupport task: {args.task}"
-
-    # The default sampling steps are 40 for image-to-video tasks and 50 for text-to-video tasks.
-    if args.sample_steps is None:
-        args.sample_steps = 40 if "i2v" in args.task else 50
-
-    if args.sample_shift is None:
-        args.sample_shift = 5.0
-        if "i2v" in args.task and args.size in ["832*480", "480*832"]:
-            args.sample_shift = 3.0
-
-    # The default number of frames are 1 for text-to-image tasks and 81 for other tasks.
-    if args.frame_num is None:
-        args.frame_num = 1 if "t2i" in args.task else 81
-
-    # T2I frame_num check
-    if "t2i" in args.task:
-        assert args.frame_num == 1, f"Unsupport frame_num {args.frame_num} for task {args.task}"
-
-    args.base_seed = args.base_seed if args.base_seed >= 0 else random.randint(
-        0, sys.maxsize)
-    # Size check
-    assert args.size in SUPPORTED_SIZES[
-        args.
-        task], f"Unsupport size {args.size} for task {args.task}, supported sizes are: {', '.join(SUPPORTED_SIZES[args.task])}"
-
-
-def _parse_args():
-    parser = argparse.ArgumentParser(
-        description="Generate a image or video from a text prompt or image using Wan"
-    )
-    parser.add_argument(
-        "--task",
-        type=str,
-        default="t2v-14B",
-        choices=list(WAN_CONFIGS.keys()),
-        help="The task to run.")
-    parser.add_argument(
-        "--size",
-        type=str,
-        default="1280*720",
-        choices=list(SIZE_CONFIGS.keys()),
-        help="The area (width*height) of the generated video. For the I2V task, the aspect ratio of the output video will follow that of the input image."
-    )
-    parser.add_argument(
-        "--frame_num",
-        type=int,
-        default=None,
-        help="How many frames to sample from a image or video. The number should be 4n+1"
-    )
-    parser.add_argument(
-        "--ckpt_dir",
-        type=str,
-        default=None,
-        help="The path to the checkpoint directory.")
-    parser.add_argument(
-        "--offload_model",
-        type=str2bool,
-        default=None,
-        help="Whether to offload the model to CPU after each model forward, reducing GPU memory usage."
-    )
-    parser.add_argument(
-        "--ulysses_size",
-        type=int,
-        default=1,
-        help="The size of the ulysses parallelism in DiT.")
-    parser.add_argument(
-        "--ring_size",
-        type=int,
-        default=1,
-        help="The size of the ring attention parallelism in DiT.")
-    parser.add_argument(
-        "--t5_fsdp",
-        action="store_true",
-        default=False,
-        help="Whether to use FSDP for T5.")
-    parser.add_argument(
-        "--t5_cpu",
-        action="store_true",
-        default=False,
-        help="Whether to place T5 model on CPU.")
-    parser.add_argument(
-        "--dit_fsdp",
-        action="store_true",
-        default=False,
-        help="Whether to use FSDP for DiT.")
-    parser.add_argument(
-        "--save_file",
-        type=str,
-        default=None,
-        help="The file to save the generated image or video to.")
-    parser.add_argument(
-        "--prompt",
-        type=str,
-        default=None,
-        help="The prompt to generate the image or video from.")
-    parser.add_argument(
-        "--use_prompt_extend",
-        action="store_true",
-        default=False,
-        help="Whether to use prompt extend.")
-    parser.add_argument(
-        "--prompt_extend_method",
-        type=str,
-        default="local_qwen",
-        choices=["dashscope", "local_qwen"],
-        help="The prompt extend method to use.")
-    parser.add_argument(
-        "--prompt_extend_model",
-        type=str,
-        default=None,
-        help="The prompt extend model to use.")
-    parser.add_argument(
-        "--prompt_extend_target_lang",
-        type=str,
-        default="ch",
-        choices=["ch", "en"],
-        help="The target language of prompt extend.")
-    parser.add_argument(
-        "--base_seed",
-        type=int,
-        default=-1,
-        help="The seed to use for generating the image or video.")
-    parser.add_argument(
-        "--image",
-        type=str,
-        default=None,
-        help="The image to generate the video from.")
-    parser.add_argument(
-        "--sample_solver",
-        type=str,
-        default='unipc',
-        choices=['unipc', 'dpm++'],
-        help="The solver used to sample.")
-    parser.add_argument(
-        "--sample_steps", type=int, default=None, help="The sampling steps.")
-    parser.add_argument(
-        "--sample_shift",
-        type=float,
-        default=None,
-        help="Sampling shift factor for flow matching schedulers.")
-    parser.add_argument(
-        "--sample_guide_scale",
-        type=float,
-        default=5.0,
-        help="Classifier free guidance scale.")
-    parser.add_argument(
-        "--teacache_thresh",
-        type=float,
-        default=0.05,
-        help="The size of the ulysses parallelism in DiT.")
-
-    args = parser.parse_args()
-
-    _validate_args(args)
-
-    return args
-
-
-def _init_logging(rank):
-    # logging
-    if rank == 0:
-        # set format
-        logging.basicConfig(
-            level=logging.INFO,
-            format="[%(asctime)s] %(levelname)s: %(message)s",
-            handlers=[logging.StreamHandler(stream=sys.stdout)])
-    else:
-        logging.basicConfig(level=logging.ERROR)
-
-
-# add a cond_flag
 def t2v_generate(self,
                  input_prompt,
                  size=(1280, 720),
@@ -341,8 +165,8 @@ def t2v_generate(self,
         # sample videos
         latents = noise
 
-        arg_c = {'context': context, 'seq_len': seq_len, 'cond_flag': True}
-        arg_null = {'context': context_null, 'seq_len': seq_len, 'cond_flag': False}
+        arg_c = {'context': context, 'seq_len': seq_len}
+        arg_null = {'context': context_null, 'seq_len': seq_len}
 
         for _, t in enumerate(tqdm(timesteps)):
             latent_model_input = latents
@@ -385,7 +209,7 @@ def t2v_generate(self,
     return videos[0] if self.rank == 0 else None
 
 
-# add a cond_flag
+
 def i2v_generate(self,
                  input_prompt,
                  img,
@@ -543,7 +367,7 @@ def i2v_generate(self,
             'clip_fea': clip_context,
             'seq_len': max_seq_len,
             'y': [y],
-            'cond_flag': True,
+            # 'cond_flag': True,
         }
 
         arg_null = {
@@ -551,7 +375,7 @@ def i2v_generate(self,
             'clip_fea': clip_context,
             'seq_len': max_seq_len,
             'y': [y],
-            'cond_flag': False,
+            # 'cond_flag': False,
         }
 
         if offload_model:
@@ -610,6 +434,7 @@ def i2v_generate(self,
     return videos[0] if self.rank == 0 else None
 
 
+
 def teacache_forward(
     self,
     x,
@@ -618,7 +443,6 @@ def teacache_forward(
     seq_len,
     clip_fea=None,
     y=None,
-    cond_flag=False,
 ):
     r"""
     Forward pass through the diffusion model
@@ -662,6 +486,7 @@ def teacache_forward(
         torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
                   dim=1) for u in x
     ])
+
     # time embeddings
     with amp.autocast(dtype=torch.float32):
         e = self.time_embedding(
@@ -692,39 +517,56 @@ def teacache_forward(
         context_lens=context_lens)
 
     if self.enable_teacache:
-        if cond_flag:
-            modulated_inp = e
-            if self.cnt == 0 or self.cnt == self.num_steps-1:
-                should_calc = True
-                self.accumulated_rel_l1_distance = 0
-            else:
-                rescale_func = np.poly1d(self.coefficients)
-                self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
-                if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
-                    should_calc = False
-                else:
-                    should_calc = True
-                    self.accumulated_rel_l1_distance = 0
-            self.previous_modulated_input = modulated_inp
-            self.cnt = 0 if self.cnt == self.num_steps-1 else self.cnt + 1
-            self.should_calc = should_calc
-        else:
-            should_calc = self.should_calc
-            # if not cond_flag:
-            #     self.cnt = 0 if self.cnt == self.num_steps-1 else self.cnt + 1
+        modulated_inp = e0 if self.use_ref_steps else e
+        # teacache
+        if self.cnt%2==0: # even -> conditional
+            self.is_even = True
+            if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
+                should_calc_even = True
+                self.accumulated_rel_l1_distance_even = 0
+            else:
+                rescale_func = np.poly1d(self.coefficients)
+                self.accumulated_rel_l1_distance_even += rescale_func(((modulated_inp-self.previous_e0_even).abs().mean() / self.previous_e0_even.abs().mean()).cpu().item())
+                if self.accumulated_rel_l1_distance_even < self.teacache_thresh:
+                    should_calc_even = False
+                else:
+                    should_calc_even = True
+                    self.accumulated_rel_l1_distance_even = 0
+            self.previous_e0_even = modulated_inp.clone()
+        else: # odd -> unconditional
+            self.is_even = False
+            if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
+                should_calc_odd = True
+                self.accumulated_rel_l1_distance_odd = 0
+            else:
+                rescale_func = np.poly1d(self.coefficients)
+                self.accumulated_rel_l1_distance_odd += rescale_func(((modulated_inp-self.previous_e0_odd).abs().mean() / self.previous_e0_odd.abs().mean()).cpu().item())
+                if self.accumulated_rel_l1_distance_odd < self.teacache_thresh:
+                    should_calc_odd = False
+                else:
+                    should_calc_odd = True
+                    self.accumulated_rel_l1_distance_odd = 0
+            self.previous_e0_odd = modulated_inp.clone()
 
     if self.enable_teacache:
-        if not should_calc:
-            x = x + self.previous_residual_cond if cond_flag else x + self.previous_residual_uncond
-        else:
-            ori_x = x.clone()
-            for block in self.blocks:
-                x = block(x, **kwargs)
-            if cond_flag:
-                self.previous_residual_cond = x - ori_x
-            else:
-                self.previous_residual_uncond = x - ori_x
+        if self.is_even:
+            if not should_calc_even:
+                x += self.previous_residual_even
+            else:
+                ori_x = x.clone()
+                for block in self.blocks:
+                    x = block(x, **kwargs)
+                self.previous_residual_even = x - ori_x
+        else:
+            if not should_calc_odd:
+                x += self.previous_residual_odd
+            else:
+                ori_x = x.clone()
+                for block in self.blocks:
+                    x = block(x, **kwargs)
+                self.previous_residual_odd = x - ori_x
     else:
         for block in self.blocks:
             x = block(x, **kwargs)
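The hunk above is the core of the v2 scheme: the old cond_flag plumbing is gone, and teacache_forward instead treats even values of self.cnt as the conditional CFG pass and odd values as the unconditional pass, each stream keeping its own accumulated distance and cached block residual. In isolation the skip test reduces to roughly the following sketch (should_recompute and the tensor shape are hypothetical, not the repo's API; the real code feeds in the modulated timestep embedding e0 or e, with the coefficients and threshold set in generate() below):

```python
import numpy as np
import torch

def should_recompute(e, prev_e, accumulated, coefficients, thresh):
    # Polynomial rescaling of the relative L1 change in the timestep embedding.
    rescale = np.poly1d(coefficients)
    rel_l1 = ((e - prev_e).abs().mean() / prev_e.abs().mean()).item()
    accumulated += rescale(rel_l1)
    if accumulated < thresh:
        return False, accumulated  # small drift: reuse the cached block residual
    return True, 0.0               # large drift: run the blocks, reset accumulator

# Toy usage; shape and values are placeholders, coefficients are the
# t2v-14B use_ret_steps values from generate() below.
coeffs = [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01]
prev = torch.randn(1, 6, 5120)
cur = prev + 0.01 * torch.randn_like(prev)
calc, acc = should_recompute(cur, prev, 0.0, coeffs, thresh=0.2)
```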
@@ -734,8 +576,191 @@ def teacache_forward(
 
     # unpatchify
     x = self.unpatchify(x, grid_sizes)
+    self.cnt += 1
+    if self.cnt >= self.num_steps:
+        self.cnt = 0
     return [u.float() for u in x]
 
 
+def _validate_args(args):
+    # Basic check
+    assert args.ckpt_dir is not None, "Please specify the checkpoint directory."
+    assert args.task in WAN_CONFIGS, f"Unsupport task: {args.task}"
+    assert args.task in EXAMPLE_PROMPT, f"Unsupport task: {args.task}"
+
+    # The default sampling steps are 40 for image-to-video tasks and 50 for text-to-video tasks.
+    if args.sample_steps is None:
+        args.sample_steps = 40 if "i2v" in args.task else 50
+
+    if args.sample_shift is None:
+        args.sample_shift = 5.0
+        if "i2v" in args.task and args.size in ["832*480", "480*832"]:
+            args.sample_shift = 3.0
+
+    # The default number of frames are 1 for text-to-image tasks and 81 for other tasks.
+    if args.frame_num is None:
+        args.frame_num = 1 if "t2i" in args.task else 81
+
+    # T2I frame_num check
+    if "t2i" in args.task:
+        assert args.frame_num == 1, f"Unsupport frame_num {args.frame_num} for task {args.task}"
+
+    args.base_seed = args.base_seed if args.base_seed >= 0 else random.randint(
+        0, sys.maxsize)
+    # Size check
+    assert args.size in SUPPORTED_SIZES[
+        args.
+        task], f"Unsupport size {args.size} for task {args.task}, supported sizes are: {', '.join(SUPPORTED_SIZES[args.task])}"
+
+
+def _parse_args():
+    parser = argparse.ArgumentParser(
+        description="Generate a image or video from a text prompt or image using Wan"
+    )
+    parser.add_argument(
+        "--task",
+        type=str,
+        default="t2v-14B",
+        choices=list(WAN_CONFIGS.keys()),
+        help="The task to run.")
+    parser.add_argument(
+        "--size",
+        type=str,
+        default="1280*720",
+        choices=list(SIZE_CONFIGS.keys()),
+        help="The area (width*height) of the generated video. For the I2V task, the aspect ratio of the output video will follow that of the input image."
+    )
+    parser.add_argument(
+        "--frame_num",
+        type=int,
+        default=None,
+        help="How many frames to sample from a image or video. The number should be 4n+1"
+    )
+    parser.add_argument(
+        "--ckpt_dir",
+        type=str,
+        default=None,
+        help="The path to the checkpoint directory.")
+    parser.add_argument(
+        "--offload_model",
+        type=str2bool,
+        default=None,
+        help="Whether to offload the model to CPU after each model forward, reducing GPU memory usage."
+    )
+    parser.add_argument(
+        "--ulysses_size",
+        type=int,
+        default=1,
+        help="The size of the ulysses parallelism in DiT.")
+    parser.add_argument(
+        "--ring_size",
+        type=int,
+        default=1,
+        help="The size of the ring attention parallelism in DiT.")
+    parser.add_argument(
+        "--t5_fsdp",
+        action="store_true",
+        default=False,
+        help="Whether to use FSDP for T5.")
+    parser.add_argument(
+        "--t5_cpu",
+        action="store_true",
+        default=False,
+        help="Whether to place T5 model on CPU.")
+    parser.add_argument(
+        "--dit_fsdp",
+        action="store_true",
+        default=False,
+        help="Whether to use FSDP for DiT.")
+    parser.add_argument(
+        "--save_file",
+        type=str,
+        default=None,
+        help="The file to save the generated image or video to.")
+    parser.add_argument(
+        "--prompt",
+        type=str,
+        default=None,
+        help="The prompt to generate the image or video from.")
+    parser.add_argument(
+        "--use_prompt_extend",
+        action="store_true",
+        default=False,
+        help="Whether to use prompt extend.")
+    parser.add_argument(
+        "--prompt_extend_method",
+        type=str,
+        default="local_qwen",
+        choices=["dashscope", "local_qwen"],
+        help="The prompt extend method to use.")
+    parser.add_argument(
+        "--prompt_extend_model",
+        type=str,
+        default=None,
+        help="The prompt extend model to use.")
+    parser.add_argument(
+        "--prompt_extend_target_lang",
+        type=str,
+        default="ch",
+        choices=["ch", "en"],
+        help="The target language of prompt extend.")
+    parser.add_argument(
+        "--base_seed",
+        type=int,
+        default=-1,
+        help="The seed to use for generating the image or video.")
+    parser.add_argument(
+        "--image",
+        type=str,
+        default=None,
+        help="The image to generate the video from.")
+    parser.add_argument(
+        "--sample_solver",
+        type=str,
+        default='unipc',
+        choices=['unipc', 'dpm++'],
+        help="The solver used to sample.")
+    parser.add_argument(
+        "--sample_steps", type=int, default=None, help="The sampling steps.")
+    parser.add_argument(
+        "--sample_shift",
+        type=float,
+        default=None,
+        help="Sampling shift factor for flow matching schedulers.")
+    parser.add_argument(
+        "--sample_guide_scale",
+        type=float,
+        default=5.0,
+        help="Classifier free guidance scale.")
+    parser.add_argument(
+        "--teacache_thresh",
+        type=float,
+        default=0.2,
+        help="Higher speedup will cause worse quality -- 0.1 for 2.0x speedup -- 0.2 for 3.0x speedup")
+    parser.add_argument(
+        "--use_ret_steps",
+        action="store_true",
+        default=False,
+        help="Using Retention Steps will result in faster generation speed and better generation quality.")
+
+    args = parser.parse_args()
+
+    _validate_args(args)
+
+    return args
+
+
+def _init_logging(rank):
+    # logging
+    if rank == 0:
+        # set format
+        logging.basicConfig(
+            level=logging.INFO,
+            format="[%(asctime)s] %(levelname)s: %(message)s",
+            handlers=[logging.StreamHandler(stream=sys.stdout)])
+    else:
+        logging.basicConfig(level=logging.ERROR)
+
+
 def generate(args):
     rank = int(os.getenv("RANK", 0))
@@ -841,21 +866,32 @@ def generate(args):
 
     # TeaCache
     wan_t2v.__class__.generate = t2v_generate
-    wan_t2v.model.__class__.cnt = 0
     wan_t2v.model.__class__.enable_teacache = True
-    wan_t2v.model.__class__.num_steps = args.sample_steps if args.sample_steps is not None else 50
-    wan_t2v.model.__class__.rel_l1_thresh = args.teacache_thresh # 2min54s, 0.05: 1min 55s(1.5x), 0.1, 1min 24s(2.1x) 0.15, 1min 6s, 0.08: 1min 27s(2x), 0.07: 1min 48s(1.6x), 0.06: 1min 51s
-    wan_t2v.model.__class__.accumulated_rel_l1_distance = 0
-    wan_t2v.model.__class__.previous_modulated_input = None
-    wan_t2v.model.__class__.previous_residual = None
-    wan_t2v.model.__class__.previous_residual_uncond = None
-    wan_t2v.model.__class__.should_calc = True
-    if '1.3B' in args.ckpt_dir:
-        wan_t2v.model.__class__.coefficients = [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01]
-    if '14B' in args.ckpt_dir:
-        wan_t2v.model.__class__.coefficients = [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404]
-    wan_t2v.model.__class__.forward = teacache_forward
+    wan_t2v.model.__class__.forward = teacache_forward
+    wan_t2v.model.__class__.cnt = 0
+    wan_t2v.model.__class__.num_steps = args.sample_steps*2
+    wan_t2v.model.__class__.teacache_thresh = args.teacache_thresh
+    wan_t2v.model.__class__.accumulated_rel_l1_distance_even = 0
+    wan_t2v.model.__class__.accumulated_rel_l1_distance_odd = 0
+    wan_t2v.model.__class__.previous_e0_even = None
+    wan_t2v.model.__class__.previous_e0_odd = None
+    wan_t2v.model.__class__.previous_residual_even = None
+    wan_t2v.model.__class__.previous_residual_odd = None
+    wan_t2v.model.__class__.use_ref_steps = args.use_ret_steps
+    if args.use_ret_steps:
+        if '1.3B' in args.ckpt_dir:
+            wan_t2v.model.__class__.coefficients = [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02]
+        if '14B' in args.ckpt_dir:
+            wan_t2v.model.__class__.coefficients = [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01]
+        wan_t2v.model.__class__.ret_steps = 5*2
+        wan_t2v.model.__class__.cutoff_steps = args.sample_steps*2
+    else:
+        if '1.3B' in args.ckpt_dir:
+            wan_t2v.model.__class__.coefficients = [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01]
+        if '14B' in args.ckpt_dir:
+            wan_t2v.model.__class__.coefficients = [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404]
+        wan_t2v.model.__class__.ret_steps = 1*2
+        wan_t2v.model.__class__.cutoff_steps = args.sample_steps*2 - 2
     logging.info(
         f"Generating {'image' if 't2i' in args.task else 'video'} ...")
     video = wan_t2v.generate(
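One step of reasoning worth making explicit here: under classifier-free guidance the model forward runs twice per denoising step (conditional, then unconditional), which is why num_steps is set to args.sample_steps*2 and teacache_forward splits the two streams on self.cnt%2. The ret_steps/cutoff_steps pair then bounds the calls that may never be skipped; a minimal illustration, assuming a hypothetical 50-step run:

```python
# Illustration only; 50 is a hypothetical sampling-step count.
sample_steps = 50
num_steps = sample_steps * 2      # cond + uncond forward call per denoising step

# --use_ret_steps: the first 5 denoising steps (10 calls) are always computed,
# and caching is allowed through the final step.
ret_steps, cutoff_steps = 5 * 2, sample_steps * 2
# Default: only the first step is forced, and the final step is always recomputed:
# ret_steps, cutoff_steps = 1 * 2, sample_steps * 2 - 2

forced = [cnt < ret_steps or cnt >= cutoff_steps for cnt in range(num_steps)]
assert sum(forced) == 10          # the ten retention calls
```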
@@ -912,23 +948,34 @@ def generate(args):
         use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
         t5_cpu=args.t5_cpu,
     )
 
     # TeaCache
     wan_i2v.__class__.generate = i2v_generate
-    wan_i2v.model.__class__.cnt = 0
     wan_i2v.model.__class__.enable_teacache = True
-    wan_i2v.model.__class__.num_steps = args.sample_steps if args.sample_steps is not None else 40
-    wan_i2v.model.__class__.rel_l1_thresh = args.teacache_thresh # 12min 26s
-    wan_i2v.model.__class__.accumulated_rel_l1_distance = 0
-    wan_i2v.model.__class__.previous_modulated_input = None
-    wan_i2v.model.__class__.previous_residual_cond = None
-    wan_i2v.model.__class__.previous_residual_uncond = None
-    wan_i2v.model.__class__.should_calc = True
-    if '480P' in args.ckpt_dir:
-        wan_i2v.model.__class__.coefficients = [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01]
-    if '720P' in args.ckpt_dir:
-        wan_i2v.model.__class__.coefficients = [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683]
-    wan_i2v.model.__class__.forward = teacache_forward
+    wan_i2v.model.__class__.forward = teacache_forward
+    wan_i2v.model.__class__.cnt = 0
+    wan_i2v.model.__class__.num_steps = args.sample_steps*2
+    wan_i2v.model.__class__.teacache_thresh = args.teacache_thresh
+    wan_i2v.model.__class__.accumulated_rel_l1_distance_even = 0
+    wan_i2v.model.__class__.accumulated_rel_l1_distance_odd = 0
+    wan_i2v.model.__class__.previous_e0_even = None
+    wan_i2v.model.__class__.previous_e0_odd = None
+    wan_i2v.model.__class__.previous_residual_even = None
+    wan_i2v.model.__class__.previous_residual_odd = None
+    wan_i2v.model.__class__.use_ref_steps = args.use_ret_steps
+    if args.use_ret_steps:
+        if '480P' in args.ckpt_dir:
+            wan_i2v.model.__class__.coefficients = [ 2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01]
+        if '720P' in args.ckpt_dir:
+            wan_i2v.model.__class__.coefficients = [ 8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02]
+        wan_i2v.model.__class__.ret_steps = 5*2
+        wan_i2v.model.__class__.cutoff_steps = args.sample_steps*2
+    else:
+        if '480P' in args.ckpt_dir:
+            wan_i2v.model.__class__.coefficients = [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01]
+        if '720P' in args.ckpt_dir:
+            wan_i2v.model.__class__.coefficients = [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683]
+        wan_i2v.model.__class__.ret_steps = 1*2
+        wan_i2v.model.__class__.cutoff_steps = args.sample_steps*2 - 2
 
     logging.info("Generating video ...")
     video = wan_i2v.generate(
@@ -971,6 +1018,7 @@ def generate(args):
     logging.info("Finished.")
 
 
+
 if __name__ == "__main__":
     args = _parse_args()
     generate(args)