Mirror of https://git.datalinker.icu/ali-vilab/TeaCache (synced 2025-12-08 20:34:24 +08:00)

Merge pull request #51 from zishen-ucap/teacache4Wan2.1_v2: Add --use_ret_steps Mode to Accelerate Inference and Make Generated Results Closer to Wan2.1

This commit is contained in: commit a9489fbf78
@ -57,6 +57,49 @@ For I2V with 720P resolution, you can use the following command:
python teacache_generate.py --task i2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-I2V-14B-720P --image examples/i2v_input.JPG --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." --base_seed 42 --offload_model True --t5_cpu --frame_num 61 --teacache_thresh 0.3
```

## Faster Video Generation Using the `use_ret_steps` Parameter

Using Retention Steps (`--use_ret_steps`) results in faster generation and better generation quality (except for t2v-1.3B).

https://github.com/user-attachments/assets/f241b5f5-1044-4223-b2a4-449dc6dc1ad7

https://github.com/user-attachments/assets/01db60f9-4aaf-43c4-8f1b-6e050cfa1180

https://github.com/user-attachments/assets/e03621f2-1085-4571-8eca-51889f47ce18

https://github.com/user-attachments/assets/d1340197-20c1-4f9e-a780-31f789af0893

| use_ret_steps | Wan2.1 t2v 1.3B baseline (thresh) | TeaCache slow (thresh) | TeaCache fast (thresh) |
|:---:|:---:|:---:|:---:|
| False | ~97 s (0.00) | ~64 s (0.05) | ~49 s (0.08) |
| True | ~97 s (0.00) | ~61 s (0.05) | ~41 s (0.10) |

| use_ret_steps | Wan2.1 t2v 14B baseline (thresh) | TeaCache slow (thresh) | TeaCache fast (thresh) |
|:---:|:---:|:---:|:---:|
| False | ~1829 s (0.00) | ~1234 s (0.14) | ~909 s (0.20) |
| True | ~1829 s (0.00) | ~915 s (0.10) | ~578 s (0.20) |

| use_ret_steps | Wan2.1 i2v 480p baseline (thresh) | TeaCache slow (thresh) | TeaCache fast (thresh) |
|:---:|:---:|:---:|:---:|
| False | ~385 s (0.00) | ~241 s (0.13) | ~156 s (0.26) |
| True | ~385 s (0.00) | ~212 s (0.20) | ~164 s (0.30) |

| use_ret_steps | Wan2.1 i2v 720p baseline (thresh) | TeaCache slow (thresh) | TeaCache fast (thresh) |
|:---:|:---:|:---:|:---:|
| False | ~903 s (0.00) | ~476 s (0.20) | ~363 s (0.30) |
| True | ~903 s (0.00) | ~430 s (0.20) | ~340 s (0.30) |

You can follow the earlier video-generation instructions and use `--use_ret_steps` to speed up generation while producing results closer to [Wan2.1](https://github.com/Wan-Video/Wan2.1). Simply add `--use_ret_steps` to the original command and adjust `--teacache_thresh`; the tables above give reference values for each model and setting. A minimal sketch of the underlying skip rule follows the example command below.

### Example Command:

```bash
python teacache_generate.py --task t2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-T2V-14B --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." --base_seed 42 --offload_model True --t5_cpu --teacache_thresh 0.3 --use_ret_steps
```

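The following is a minimal, hedged sketch (not the repository API; the helper name and arguments are illustrative) of the per-call decision made by `teacache_forward` in this PR: the first `ret_steps` calls and every call from `cutoff_steps` onward are always computed, and in between the transformer blocks are skipped whenever the accumulated, polynomial-rescaled relative L1 change of the timestep embedding (`e0` when `--use_ret_steps` is set, `e` otherwise) stays below `--teacache_thresh`.

```python
import numpy as np

def should_compute(cnt, acc, emb, prev_emb, coefficients,
                   teacache_thresh, ret_steps, cutoff_steps):
    """Return (run_blocks, new_accumulated_distance) for one model call."""
    # Warm-up calls and the final calls are always computed.
    if cnt < ret_steps or cnt >= cutoff_steps:
        return True, 0.0
    # Relative L1 change of the timestep embedding, rescaled by the
    # model-specific polynomial shipped with TeaCache.
    rescale = np.poly1d(coefficients)
    acc += float(rescale(abs(emb - prev_emb).mean() / abs(prev_emb).mean()))
    if acc < teacache_thresh:
        return False, acc   # skip the DiT blocks and reuse the cached residual
    return True, 0.0        # run the blocks and refresh the cache
```

With `--use_ret_steps` the patch uses `ret_steps = 5*2` and `cutoff_steps = sample_steps*2`; without it, `ret_steps = 1*2` and `cutoff_steps = sample_steps*2 - 2` (counts are doubled because the model is called twice per sampling step under classifier-free guidance).
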
## Acknowledgements

We would like to thank the contributors to [Wan2.1](https://github.com/Wan-Video/Wan2.1).
@ -29,6 +29,7 @@ from wan.utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
from tqdm import tqdm

EXAMPLE_PROMPT = {
    "t2v-1.3B": {
        "prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
@ -48,183 +49,6 @@ EXAMPLE_PROMPT = {
}


def _validate_args(args):
    # Basic check
    assert args.ckpt_dir is not None, "Please specify the checkpoint directory."
    assert args.task in WAN_CONFIGS, f"Unsupport task: {args.task}"
    assert args.task in EXAMPLE_PROMPT, f"Unsupport task: {args.task}"

    # The default sampling steps are 40 for image-to-video tasks and 50 for text-to-video tasks.
    if args.sample_steps is None:
        args.sample_steps = 40 if "i2v" in args.task else 50

    if args.sample_shift is None:
        args.sample_shift = 5.0
        if "i2v" in args.task and args.size in ["832*480", "480*832"]:
            args.sample_shift = 3.0

    # The default number of frames are 1 for text-to-image tasks and 81 for other tasks.
    if args.frame_num is None:
        args.frame_num = 1 if "t2i" in args.task else 81

    # T2I frame_num check
    if "t2i" in args.task:
        assert args.frame_num == 1, f"Unsupport frame_num {args.frame_num} for task {args.task}"

    args.base_seed = args.base_seed if args.base_seed >= 0 else random.randint(
        0, sys.maxsize)
    # Size check
    assert args.size in SUPPORTED_SIZES[
        args.
        task], f"Unsupport size {args.size} for task {args.task}, supported sizes are: {', '.join(SUPPORTED_SIZES[args.task])}"

def _parse_args():
    parser = argparse.ArgumentParser(
        description="Generate a image or video from a text prompt or image using Wan"
    )
    parser.add_argument(
        "--task",
        type=str,
        default="t2v-14B",
        choices=list(WAN_CONFIGS.keys()),
        help="The task to run.")
    parser.add_argument(
        "--size",
        type=str,
        default="1280*720",
        choices=list(SIZE_CONFIGS.keys()),
        help="The area (width*height) of the generated video. For the I2V task, the aspect ratio of the output video will follow that of the input image."
    )
    parser.add_argument(
        "--frame_num",
        type=int,
        default=None,
        help="How many frames to sample from a image or video. The number should be 4n+1"
    )
    parser.add_argument(
        "--ckpt_dir",
        type=str,
        default=None,
        help="The path to the checkpoint directory.")
    parser.add_argument(
        "--offload_model",
        type=str2bool,
        default=None,
        help="Whether to offload the model to CPU after each model forward, reducing GPU memory usage."
    )
    parser.add_argument(
        "--ulysses_size",
        type=int,
        default=1,
        help="The size of the ulysses parallelism in DiT.")
    parser.add_argument(
        "--ring_size",
        type=int,
        default=1,
        help="The size of the ring attention parallelism in DiT.")
    parser.add_argument(
        "--t5_fsdp",
        action="store_true",
        default=False,
        help="Whether to use FSDP for T5.")
    parser.add_argument(
        "--t5_cpu",
        action="store_true",
        default=False,
        help="Whether to place T5 model on CPU.")
    parser.add_argument(
        "--dit_fsdp",
        action="store_true",
        default=False,
        help="Whether to use FSDP for DiT.")
    parser.add_argument(
        "--save_file",
        type=str,
        default=None,
        help="The file to save the generated image or video to.")
    parser.add_argument(
        "--prompt",
        type=str,
        default=None,
        help="The prompt to generate the image or video from.")
    parser.add_argument(
        "--use_prompt_extend",
        action="store_true",
        default=False,
        help="Whether to use prompt extend.")
    parser.add_argument(
        "--prompt_extend_method",
        type=str,
        default="local_qwen",
        choices=["dashscope", "local_qwen"],
        help="The prompt extend method to use.")
    parser.add_argument(
        "--prompt_extend_model",
        type=str,
        default=None,
        help="The prompt extend model to use.")
    parser.add_argument(
        "--prompt_extend_target_lang",
        type=str,
        default="ch",
        choices=["ch", "en"],
        help="The target language of prompt extend.")
    parser.add_argument(
        "--base_seed",
        type=int,
        default=-1,
        help="The seed to use for generating the image or video.")
    parser.add_argument(
        "--image",
        type=str,
        default=None,
        help="The image to generate the video from.")
    parser.add_argument(
        "--sample_solver",
        type=str,
        default='unipc',
        choices=['unipc', 'dpm++'],
        help="The solver used to sample.")
    parser.add_argument(
        "--sample_steps", type=int, default=None, help="The sampling steps.")
    parser.add_argument(
        "--sample_shift",
        type=float,
        default=None,
        help="Sampling shift factor for flow matching schedulers.")
    parser.add_argument(
        "--sample_guide_scale",
        type=float,
        default=5.0,
        help="Classifier free guidance scale.")
    parser.add_argument(
        "--teacache_thresh",
        type=float,
        default=0.05,
        help="The size of the ulysses parallelism in DiT.")

    args = parser.parse_args()

    _validate_args(args)

    return args

def _init_logging(rank):
    # logging
    if rank == 0:
        # set format
        logging.basicConfig(
            level=logging.INFO,
            format="[%(asctime)s] %(levelname)s: %(message)s",
            handlers=[logging.StreamHandler(stream=sys.stdout)])
    else:
        logging.basicConfig(level=logging.ERROR)


# add a cond_flag
def t2v_generate(self,
                 input_prompt,
                 size=(1280, 720),
@ -341,8 +165,8 @@ def t2v_generate(self,
    # sample videos
    latents = noise

    arg_c = {'context': context, 'seq_len': seq_len, 'cond_flag': True}
    arg_null = {'context': context_null, 'seq_len': seq_len, 'cond_flag': False}
    arg_c = {'context': context, 'seq_len': seq_len}
    arg_null = {'context': context_null, 'seq_len': seq_len}

    for _, t in enumerate(tqdm(timesteps)):
        latent_model_input = latents
@ -385,7 +209,7 @@ def t2v_generate(self,
    return videos[0] if self.rank == 0 else None


# add a cond_flag

def i2v_generate(self,
                 input_prompt,
                 img,
@ -543,7 +367,7 @@ def i2v_generate(self,
        'clip_fea': clip_context,
        'seq_len': max_seq_len,
        'y': [y],
        'cond_flag': True,
        # 'cond_flag': True,
    }

    arg_null = {
@ -551,7 +375,7 @@ def i2v_generate(self,
        'clip_fea': clip_context,
        'seq_len': max_seq_len,
        'y': [y],
        'cond_flag': False,
        # 'cond_flag': False,
    }

    if offload_model:
@ -610,131 +434,332 @@ def i2v_generate(self,
    return videos[0] if self.rank == 0 else None

def teacache_forward(
    self,
    x,
    t,
    context,
    seq_len,
    clip_fea=None,
    y=None,
    cond_flag=False,
):
    r"""
    Forward pass through the diffusion model
    self,
    x,
    t,
    context,
    seq_len,
    clip_fea=None,
    y=None,
):
    r"""
    Forward pass through the diffusion model

    Args:
        x (List[Tensor]):
            List of input video tensors, each with shape [C_in, F, H, W]
        t (Tensor):
            Diffusion timesteps tensor of shape [B]
        context (List[Tensor]):
            List of text embeddings each with shape [L, C]
        seq_len (`int`):
            Maximum sequence length for positional encoding
        clip_fea (Tensor, *optional*):
            CLIP image features for image-to-video mode
        y (List[Tensor], *optional*):
            Conditional video inputs for image-to-video mode, same shape as x
    Args:
        x (List[Tensor]):
            List of input video tensors, each with shape [C_in, F, H, W]
        t (Tensor):
            Diffusion timesteps tensor of shape [B]
        context (List[Tensor]):
            List of text embeddings each with shape [L, C]
        seq_len (`int`):
            Maximum sequence length for positional encoding
        clip_fea (Tensor, *optional*):
            CLIP image features for image-to-video mode
        y (List[Tensor], *optional*):
            Conditional video inputs for image-to-video mode, same shape as x

    Returns:
        List[Tensor]:
            List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
    """
    if self.model_type == 'i2v':
        assert clip_fea is not None and y is not None
    # params
    device = self.patch_embedding.weight.device
    if self.freqs.device != device:
        self.freqs = self.freqs.to(device)
    Returns:
        List[Tensor]:
            List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
    """
    if self.model_type == 'i2v':
        assert clip_fea is not None and y is not None
    # params
    device = self.patch_embedding.weight.device
    if self.freqs.device != device:
        self.freqs = self.freqs.to(device)

    if y is not None:
        x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
    if y is not None:
        x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]

    # embeddings
    x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
    grid_sizes = torch.stack(
        [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
    x = [u.flatten(2).transpose(1, 2) for u in x]
    seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
    assert seq_lens.max() <= seq_len
    x = torch.cat([
        torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
                  dim=1) for u in x
    ])
    # time embeddings
    with amp.autocast(dtype=torch.float32):
        e = self.time_embedding(
            sinusoidal_embedding_1d(self.freq_dim, t).float())
        e0 = self.time_projection(e).unflatten(1, (6, self.dim))
        assert e.dtype == torch.float32 and e0.dtype == torch.float32
    # embeddings
    x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
    grid_sizes = torch.stack(
        [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
    x = [u.flatten(2).transpose(1, 2) for u in x]
    seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
    assert seq_lens.max() <= seq_len
    x = torch.cat([
        torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
                  dim=1) for u in x
    ])

    # context
    context_lens = None
    context = self.text_embedding(
        torch.stack([
            torch.cat(
                [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
            for u in context
        ]))
    # time embeddings
    with amp.autocast(dtype=torch.float32):
        e = self.time_embedding(
            sinusoidal_embedding_1d(self.freq_dim, t).float())
        e0 = self.time_projection(e).unflatten(1, (6, self.dim))
        assert e.dtype == torch.float32 and e0.dtype == torch.float32

    if clip_fea is not None:
        context_clip = self.img_emb(clip_fea)  # bs x 257 x dim
        context = torch.concat([context_clip, context], dim=1)
    # context
    context_lens = None
    context = self.text_embedding(
        torch.stack([
            torch.cat(
                [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
            for u in context
        ]))

    # arguments
    kwargs = dict(
        e=e0,
        seq_lens=seq_lens,
        grid_sizes=grid_sizes,
        freqs=self.freqs,
        context=context,
        context_lens=context_lens)
    if clip_fea is not None:
        context_clip = self.img_emb(clip_fea)  # bs x 257 x dim
        context = torch.concat([context_clip, context], dim=1)

    if self.enable_teacache:
        if cond_flag:
            modulated_inp = e
            if self.cnt == 0 or self.cnt == self.num_steps-1:
                should_calc = True
                self.accumulated_rel_l1_distance = 0
            else:
                rescale_func = np.poly1d(self.coefficients)
                if cond_flag:
                    self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
                if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
                    should_calc = False
                else:
                    should_calc = True
                    self.accumulated_rel_l1_distance = 0
            self.previous_modulated_input = modulated_inp
            self.cnt = 0 if self.cnt == self.num_steps-1 else self.cnt + 1
            self.should_calc = should_calc
    # arguments
    kwargs = dict(
        e=e0,
        seq_lens=seq_lens,
        grid_sizes=grid_sizes,
        freqs=self.freqs,
        context=context,
        context_lens=context_lens)

    if self.enable_teacache:
        modulated_inp = e0 if self.use_ref_steps else e
        # teacache
        if self.cnt % 2 == 0:  # even -> condition
            self.is_even = True
            if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
                should_calc_even = True
                self.accumulated_rel_l1_distance_even = 0
            else:
                should_calc = self.should_calc
                # if not cond_flag:
                #     self.cnt = 0 if self.cnt == self.num_steps-1 else self.cnt + 1
                rescale_func = np.poly1d(self.coefficients)
                self.accumulated_rel_l1_distance_even += rescale_func(((modulated_inp-self.previous_e0_even).abs().mean() / self.previous_e0_even.abs().mean()).cpu().item())
                if self.accumulated_rel_l1_distance_even < self.teacache_thresh:
                    should_calc_even = False
                else:
                    should_calc_even = True
                    self.accumulated_rel_l1_distance_even = 0
            self.previous_e0_even = modulated_inp.clone()

    if self.enable_teacache:
        if not should_calc:
            x = x + self.previous_residual_cond if cond_flag else x + self.previous_residual_uncond
        else:  # odd -> uncondition
            self.is_even = False
            if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
                should_calc_odd = True
                self.accumulated_rel_l1_distance_odd = 0
            else:
                rescale_func = np.poly1d(self.coefficients)
                self.accumulated_rel_l1_distance_odd += rescale_func(((modulated_inp-self.previous_e0_odd).abs().mean() / self.previous_e0_odd.abs().mean()).cpu().item())
                if self.accumulated_rel_l1_distance_odd < self.teacache_thresh:
                    should_calc_odd = False
                else:
                    should_calc_odd = True
                    self.accumulated_rel_l1_distance_odd = 0
            self.previous_e0_odd = modulated_inp.clone()

    if self.enable_teacache:
        if self.is_even:
            if not should_calc_even:
                x += self.previous_residual_even
            else:
                ori_x = x.clone()
                for block in self.blocks:
                    x = block(x, **kwargs)
                if cond_flag:
                    self.previous_residual_cond = x - ori_x
                else:
                    self.previous_residual_uncond = x - ori_x
                self.previous_residual_even = x - ori_x
        else:
            for block in self.blocks:
                x = block(x, **kwargs)
            if not should_calc_odd:
                x += self.previous_residual_odd
            else:
                ori_x = x.clone()
                for block in self.blocks:
                    x = block(x, **kwargs)
                self.previous_residual_odd = x - ori_x

    # head
    x = self.head(x, e)
    else:
        for block in self.blocks:
            x = block(x, **kwargs)

    # unpatchify
    x = self.unpatchify(x, grid_sizes)
    return [u.float() for u in x]
    # head
    x = self.head(x, e)

    # unpatchify
    x = self.unpatchify(x, grid_sizes)
    self.cnt += 1
    if self.cnt >= self.num_steps:
        self.cnt = 0
    return [u.float() for u in x]

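# A minimal illustration of the call accounting used by teacache_forward above
# (the helper name is illustrative, not part of the patch): with classifier-free
# guidance the model runs twice per sampling step, conditional pass first, then
# unconditional, so the patched forward counts calls rather than steps. Even
# counts belong to the conditional branch, odd counts to the unconditional
# branch, and the counter wraps after num_steps = 2 * sample_steps calls.
def _call_parity_sketch(cnt, sample_steps):
    num_calls = 2 * sample_steps      # one conditional + one unconditional call per step
    is_cond = (cnt % 2 == 0)          # even call -> conditional branch
    next_cnt = (cnt + 1) % num_calls  # wrap once a full denoising run ends
    return is_cond, next_cnt
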
def _validate_args(args):
    # Basic check
    assert args.ckpt_dir is not None, "Please specify the checkpoint directory."
    assert args.task in WAN_CONFIGS, f"Unsupport task: {args.task}"
    assert args.task in EXAMPLE_PROMPT, f"Unsupport task: {args.task}"

    # The default sampling steps are 40 for image-to-video tasks and 50 for text-to-video tasks.
    if args.sample_steps is None:
        args.sample_steps = 40 if "i2v" in args.task else 50

    if args.sample_shift is None:
        args.sample_shift = 5.0
        if "i2v" in args.task and args.size in ["832*480", "480*832"]:
            args.sample_shift = 3.0

    # The default number of frames are 1 for text-to-image tasks and 81 for other tasks.
    if args.frame_num is None:
        args.frame_num = 1 if "t2i" in args.task else 81

    # T2I frame_num check
    if "t2i" in args.task:
        assert args.frame_num == 1, f"Unsupport frame_num {args.frame_num} for task {args.task}"

    args.base_seed = args.base_seed if args.base_seed >= 0 else random.randint(
        0, sys.maxsize)
    # Size check
    assert args.size in SUPPORTED_SIZES[
        args.
        task], f"Unsupport size {args.size} for task {args.task}, supported sizes are: {', '.join(SUPPORTED_SIZES[args.task])}"

def _parse_args():
    parser = argparse.ArgumentParser(
        description="Generate a image or video from a text prompt or image using Wan"
    )
    parser.add_argument(
        "--task",
        type=str,
        default="t2v-14B",
        choices=list(WAN_CONFIGS.keys()),
        help="The task to run.")
    parser.add_argument(
        "--size",
        type=str,
        default="1280*720",
        choices=list(SIZE_CONFIGS.keys()),
        help="The area (width*height) of the generated video. For the I2V task, the aspect ratio of the output video will follow that of the input image."
    )
    parser.add_argument(
        "--frame_num",
        type=int,
        default=None,
        help="How many frames to sample from a image or video. The number should be 4n+1"
    )
    parser.add_argument(
        "--ckpt_dir",
        type=str,
        default=None,
        help="The path to the checkpoint directory.")
    parser.add_argument(
        "--offload_model",
        type=str2bool,
        default=None,
        help="Whether to offload the model to CPU after each model forward, reducing GPU memory usage."
    )
    parser.add_argument(
        "--ulysses_size",
        type=int,
        default=1,
        help="The size of the ulysses parallelism in DiT.")
    parser.add_argument(
        "--ring_size",
        type=int,
        default=1,
        help="The size of the ring attention parallelism in DiT.")
    parser.add_argument(
        "--t5_fsdp",
        action="store_true",
        default=False,
        help="Whether to use FSDP for T5.")
    parser.add_argument(
        "--t5_cpu",
        action="store_true",
        default=False,
        help="Whether to place T5 model on CPU.")
    parser.add_argument(
        "--dit_fsdp",
        action="store_true",
        default=False,
        help="Whether to use FSDP for DiT.")
    parser.add_argument(
        "--save_file",
        type=str,
        default=None,
        help="The file to save the generated image or video to.")
    parser.add_argument(
        "--prompt",
        type=str,
        default=None,
        help="The prompt to generate the image or video from.")
    parser.add_argument(
        "--use_prompt_extend",
        action="store_true",
        default=False,
        help="Whether to use prompt extend.")
    parser.add_argument(
        "--prompt_extend_method",
        type=str,
        default="local_qwen",
        choices=["dashscope", "local_qwen"],
        help="The prompt extend method to use.")
    parser.add_argument(
        "--prompt_extend_model",
        type=str,
        default=None,
        help="The prompt extend model to use.")
    parser.add_argument(
        "--prompt_extend_target_lang",
        type=str,
        default="ch",
        choices=["ch", "en"],
        help="The target language of prompt extend.")
    parser.add_argument(
        "--base_seed",
        type=int,
        default=-1,
        help="The seed to use for generating the image or video.")
    parser.add_argument(
        "--image",
        type=str,
        default=None,
        help="The image to generate the video from.")
    parser.add_argument(
        "--sample_solver",
        type=str,
        default='unipc',
        choices=['unipc', 'dpm++'],
        help="The solver used to sample.")
    parser.add_argument(
        "--sample_steps", type=int, default=None, help="The sampling steps.")
    parser.add_argument(
        "--sample_shift",
        type=float,
        default=None,
        help="Sampling shift factor for flow matching schedulers.")
    parser.add_argument(
        "--sample_guide_scale",
        type=float,
        default=5.0,
        help="Classifier free guidance scale.")
    parser.add_argument(
        "--teacache_thresh",
        type=float,
        default=0.2,
        help="Higher speedup will cause to worse quality -- 0.1 for 2.0x speedup -- 0.2 for 3.0x speedup")
    parser.add_argument(
        "--use_ret_steps",
        action="store_true",
        default=False,
        help="Using Retention Steps will result in faster generation speed and better generation quality.")

    args = parser.parse_args()

    _validate_args(args)

    return args


def _init_logging(rank):
    # logging
    if rank == 0:
        # set format
        logging.basicConfig(
            level=logging.INFO,
            format="[%(asctime)s] %(levelname)s: %(message)s",
            handlers=[logging.StreamHandler(stream=sys.stdout)])
    else:
        logging.basicConfig(level=logging.ERROR)


def generate(args):
@ -841,21 +866,32 @@ def generate(args):

    # TeaCache
    wan_t2v.__class__.generate = t2v_generate
    wan_t2v.model.__class__.cnt = 0
    wan_t2v.model.__class__.enable_teacache = True
    wan_t2v.model.__class__.num_steps = args.sample_steps if args.sample_steps is not None else 50
    wan_t2v.model.__class__.rel_l1_thresh = args.teacache_thresh # 2min54s, 0.05: 1min 55s(1.5x), 0.1, 1min 24s(2.1x) 0.15, 1min 6s, 0.08: 1min 27s(2x), 0.07: 1min 48s(1.6x), 0.06: 1min 51s
    wan_t2v.model.__class__.accumulated_rel_l1_distance = 0
    wan_t2v.model.__class__.previous_modulated_input = None
    wan_t2v.model.__class__.previous_residual = None
    wan_t2v.model.__class__.previous_residual_uncond = None
    wan_t2v.model.__class__.should_calc = True
    if '1.3B' in args.ckpt_dir:
        wan_t2v.model.__class__.coefficients = [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01]
    if '14B' in args.ckpt_dir:
        wan_t2v.model.__class__.coefficients = [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404]
    wan_t2v.model.__class__.forward = teacache_forward

    wan_t2v.model.__class__.cnt = 0
    wan_t2v.model.__class__.num_steps = args.sample_steps*2
    wan_t2v.model.__class__.teacache_thresh = args.teacache_thresh
    wan_t2v.model.__class__.accumulated_rel_l1_distance_even = 0
    wan_t2v.model.__class__.accumulated_rel_l1_distance_odd = 0
    wan_t2v.model.__class__.previous_e0_even = None
    wan_t2v.model.__class__.previous_e0_odd = None
    wan_t2v.model.__class__.previous_residual_even = None
    wan_t2v.model.__class__.previous_residual_odd = None
    wan_t2v.model.__class__.use_ref_steps = args.use_ret_steps
    if args.use_ret_steps:
        if '1.3B' in args.ckpt_dir:
            wan_t2v.model.__class__.coefficients = [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02]
        if '14B' in args.ckpt_dir:
            wan_t2v.model.__class__.coefficients = [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01]
        wan_t2v.model.__class__.ret_steps = 5*2
        wan_t2v.model.__class__.cutoff_steps = args.sample_steps*2
    else:
        if '1.3B' in args.ckpt_dir:
            wan_t2v.model.__class__.coefficients = [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01]
        if '14B' in args.ckpt_dir:
            wan_t2v.model.__class__.coefficients = [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404]
        wan_t2v.model.__class__.ret_steps = 1*2
        wan_t2v.model.__class__.cutoff_steps = args.sample_steps*2 - 2
    logging.info(
        f"Generating {'image' if 't2i' in args.task else 'video'} ...")
    video = wan_t2v.generate(
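# A hedged sketch of the integration pattern used in generate() above (class and
# function names here are illustrative, not part of the patch): TeaCache is wired
# in by assigning its state and the patched forward onto the model *class*
# (model.__class__), so every instance of that DiT class picks up teacache_forward
# and the cached residuals without modifying the Wan2.1 sources.
class _DiTSketch:
    def forward(self):
        return "original forward"

def _teacache_forward_sketch(self):
    return "teacache forward"

_model = _DiTSketch()
_DiTSketch.forward = _teacache_forward_sketch  # like model.__class__.forward = teacache_forward
_DiTSketch.cnt = 0                             # class-wide state, as with cnt/num_steps above
assert _model.forward() == "teacache forward"
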
@ -912,23 +948,34 @@ def generate(args):
        use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
        t5_cpu=args.t5_cpu,
    )

    # TeaCache
    wan_i2v.__class__.generate = i2v_generate
    wan_i2v.model.__class__.cnt = 0
    wan_i2v.model.__class__.enable_teacache = True
    wan_i2v.model.__class__.num_steps = args.sample_steps if args.sample_steps is not None else 40
    wan_i2v.model.__class__.rel_l1_thresh = args.teacache_thresh # 12min 26s
    wan_i2v.model.__class__.accumulated_rel_l1_distance = 0
    wan_i2v.model.__class__.previous_modulated_input = None
    wan_i2v.model.__class__.previous_residual_cond = None
    wan_i2v.model.__class__.previous_residual_uncond = None
    wan_i2v.model.__class__.should_calc = True
    if '480P' in args.ckpt_dir:
        wan_i2v.model.__class__.coefficients = [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01]
    if '720P' in args.ckpt_dir:
        wan_i2v.model.__class__.coefficients = [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683]
    wan_i2v.model.__class__.forward = teacache_forward
    wan_i2v.model.__class__.cnt = 0
    wan_i2v.model.__class__.num_steps = args.sample_steps*2
    wan_i2v.model.__class__.teacache_thresh = args.teacache_thresh
    wan_i2v.model.__class__.accumulated_rel_l1_distance_even = 0
    wan_i2v.model.__class__.accumulated_rel_l1_distance_odd = 0
    wan_i2v.model.__class__.previous_e0_even = None
    wan_i2v.model.__class__.previous_e0_odd = None
    wan_i2v.model.__class__.previous_residual_even = None
    wan_i2v.model.__class__.previous_residual_odd = None
    wan_i2v.model.__class__.use_ref_steps = args.use_ret_steps
    if args.use_ret_steps:
        if '480P' in args.ckpt_dir:
            wan_i2v.model.__class__.coefficients = [ 2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01]
        if '720P' in args.ckpt_dir:
            wan_i2v.model.__class__.coefficients = [ 8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02]
        wan_i2v.model.__class__.ret_steps = 5*2
        wan_i2v.model.__class__.cutoff_steps = args.sample_steps*2
    else:
        if '480P' in args.ckpt_dir:
            wan_i2v.model.__class__.coefficients = [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01]
        if '720P' in args.ckpt_dir:
            wan_i2v.model.__class__.coefficients = [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683]
        wan_i2v.model.__class__.ret_steps = 1*2
        wan_i2v.model.__class__.cutoff_steps = args.sample_steps*2 - 2

    logging.info("Generating video ...")
    video = wan_i2v.generate(
@ -971,6 +1018,7 @@ def generate(args):
    logging.info("Finished.")


if __name__ == "__main__":
    args = _parse_args()
    generate(args)