Compare commits

2 Commits

| Author | SHA1 | Message | Date |
|:---|:---|:---|:---|
| Feng Liu | 1ba2be0a9b | Update README.md | 2025-03-05 18:59:44 +08:00 |
| LiewFeng | 3cddb36896 | Support Wan2.1 | 2025-03-05 16:59:18 +08:00 |

3 changed files with 1044 additions and 2 deletions

README.md

@@ -64,6 +64,7 @@ We introduce Timestep Embedding Aware Cache (TeaCache), a training-free caching
## 🔥 Latest News
- **If you like our project, please give us a star ⭐ on GitHub for the latest update.**
- [2025/03/05] 🔥 Support [Wan2.1](https://github.com/Wan-Video/Wan2.1) for both T2V and I2V.
- [2025/02/27] 🎉 Accepted in CVPR 2025.
- [2025/01/24] 🔥 Support [Cosmos](https://github.com/NVIDIA/Cosmos) for both T2V and I2V. Thanks [@zishen-ucap](https://github.com/zishen-ucap).
- [2025/01/20] 🔥 Support [CogVideoX1.5-5B](https://github.com/THUDM/CogVideo) for both T2V and I2V. Thanks [@zishen-ucap](https://github.com/zishen-ucap).
@@ -92,6 +93,9 @@ If you develop/use TeaCache in your projects, welcome to let us know.
- [ComfyUI_Patches_ll](https://github.com/lldacing/ComfyUI_Patches_ll) supports TeaCache. Thanks [@lldacing](https://github.com/lldacing).
- [ComfyUI-TangoFlux](https://github.com/LucipherDev/ComfyUI-TangoFlux) supports TeaCache. Thanks [@LucipherDev](https://github.com/LucipherDev).
**Parallelism**
- [Teacache-xDiT](https://github.com/MingXiangL/Teacache-xDiT) for multi-gpu inference. Thanks [@MingXiangL](https://github.com/MingXiangL).
## 🎉 Supported Models
@@ -106,6 +110,7 @@ If you develop/use TeaCache in your projects, welcome to let us know.
- [TeaCache4CogVideoX1.5](./TeaCache4CogVideoX1.5/README.md)
- EasyAnimate, see [here](https://github.com/aigc-apps/EasyAnimate).
- [TeaCache4Cosmos](./eval/TeaCache4Cosmos/README.md)
- [TeaCache4Wan2.1](./TeaCache4Wan2.1/README.md)
**Image to Video**
- [TeaCache4ConsisID](./TeaCache4ConsisID/README.md)
@@ -113,6 +118,7 @@ If you develop/use TeaCache in your projects, welcome to let us know.
- Ruyi-Models. See [here](https://github.com/IamCreateAI/Ruyi-Models).
- EasyAnimate, see [here](https://github.com/aigc-apps/EasyAnimate).
- [TeaCache4Cosmos](./eval/TeaCache4Cosmos/README.md)
- [TeaCache4Wan2.1](./TeaCache4Wan2.1/README.md)
**Video to Video**
- EasyAnimate, see [here](https://github.com/aigc-apps/EasyAnimate).
@@ -131,12 +137,12 @@ If you develop/use TeaCache in your projects, welcome to let us know.
## 💐 Acknowledgement
This repository is built based on [VideoSys](https://github.com/NUS-HPC-AI-Lab/VideoSys), [Diffusers](https://github.com/huggingface/diffusers), [Open-Sora](https://github.com/hpcaitech/Open-Sora), [Open-Sora-Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan), [Latte](https://github.com/Vchitect/Latte), [CogVideoX](https://github.com/THUDM/CogVideo), [HunyuanVideo](https://github.com/Tencent/HunyuanVideo), [ConsisID](https://github.com/PKU-YuanGroup/ConsisID), [FLUX](https://github.com/black-forest-labs/flux), [Mochi](https://github.com/genmoai/mochi), [LTX-Video](https://github.com/Lightricks/LTX-Video), [Lumina-T2X](https://github.com/Alpha-VLLM/Lumina-T2X), [TangoFlux](https://github.com/declare-lab/TangoFlux) and [Cosmos](https://github.com/NVIDIA/Cosmos). Thanks for their contributions!
This repository is built based on [VideoSys](https://github.com/NUS-HPC-AI-Lab/VideoSys), [Diffusers](https://github.com/huggingface/diffusers), [Open-Sora](https://github.com/hpcaitech/Open-Sora), [Open-Sora-Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan), [Latte](https://github.com/Vchitect/Latte), [CogVideoX](https://github.com/THUDM/CogVideo), [HunyuanVideo](https://github.com/Tencent/HunyuanVideo), [ConsisID](https://github.com/PKU-YuanGroup/ConsisID), [FLUX](https://github.com/black-forest-labs/flux), [Mochi](https://github.com/genmoai/mochi), [LTX-Video](https://github.com/Lightricks/LTX-Video), [Lumina-T2X](https://github.com/Alpha-VLLM/Lumina-T2X), [TangoFlux](https://github.com/declare-lab/TangoFlux), [Cosmos](https://github.com/NVIDIA/Cosmos) and [Wan2.1](https://github.com/Wan-Video/Wan2.1). Thanks for their contributions!
## 🔒 License
* The majority of this project is released under the Apache 2.0 license as found in the [LICENSE](./LICENSE) file.
* For [VideoSys](https://github.com/NUS-HPC-AI-Lab/VideoSys), [Diffusers](https://github.com/huggingface/diffusers), [Open-Sora](https://github.com/hpcaitech/Open-Sora), [Open-Sora-Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan), [Latte](https://github.com/Vchitect/Latte), [CogVideoX](https://github.com/THUDM/CogVideo), [HunyuanVideo](https://github.com/Tencent/HunyuanVideo), [ConsisID](https://github.com/PKU-YuanGroup/ConsisID), [FLUX](https://github.com/black-forest-labs/flux), [Mochi](https://github.com/genmoai/mochi), [LTX-Video](https://github.com/Lightricks/LTX-Video), [Lumina-T2X](https://github.com/Alpha-VLLM/Lumina-T2X), [TangoFlux](https://github.com/declare-lab/TangoFlux) and [Cosmos](https://github.com/NVIDIA/Cosmos), please follow their LICENSE.
* For [VideoSys](https://github.com/NUS-HPC-AI-Lab/VideoSys), [Diffusers](https://github.com/huggingface/diffusers), [Open-Sora](https://github.com/hpcaitech/Open-Sora), [Open-Sora-Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan), [Latte](https://github.com/Vchitect/Latte), [CogVideoX](https://github.com/THUDM/CogVideo), [HunyuanVideo](https://github.com/Tencent/HunyuanVideo), [ConsisID](https://github.com/PKU-YuanGroup/ConsisID), [FLUX](https://github.com/black-forest-labs/flux), [Mochi](https://github.com/genmoai/mochi), [LTX-Video](https://github.com/Lightricks/LTX-Video), [Lumina-T2X](https://github.com/Alpha-VLLM/Lumina-T2X), [TangoFlux](https://github.com/declare-lab/TangoFlux), [Cosmos](https://github.com/NVIDIA/Cosmos) and [Wan2.1](https://github.com/Wan-Video/Wan2.1), please follow their LICENSE.
* The service is a research preview. Please contact us if you find any potential violations. (liufeng20@mails.ucas.ac.cn)
## 📖 Citation

TeaCache4Wan2.1/README.md (new file, 62 lines)

@@ -0,0 +1,62 @@
<!-- ## **TeaCache4Wan2.1** -->
# TeaCache4Wan2.1
[TeaCache](https://github.com/ali-vilab/TeaCache) can speed up [Wan2.1](https://github.com/Wan-Video/Wan2.1) by about 2x, without much visual quality degradation and in a training-free manner. The following videos were generated by TeaCache-Wan2.1 at various `teacache_thresh` values; the corresponding thresholds and latencies are listed in the tables below.
https://github.com/user-attachments/assets/5ae5d6dd-bf87-4f8f-91b8-ccc5980c56ad
https://github.com/user-attachments/assets/dfd047a9-e3ca-4a73-a282-4dadda8dbd43
https://github.com/user-attachments/assets/7c20bd54-96a8-4bd7-b4fa-ea4c9da81562
https://github.com/user-attachments/assets/72085f45-6b78-4fae-b58f-492360a6e55e
## 📈 Inference Latency Comparisons on a Single A800
| Wan2.1 t2v 1.3B | TeaCache (0.05) | TeaCache (0.07) | TeaCache (0.08) |
|:---:|:---:|:---:|:---:|
| ~175 s | ~117 s | ~110 s | ~88 s |

| Wan2.1 t2v 14B | TeaCache (0.14) | TeaCache (0.15) | TeaCache (0.2) |
|:---:|:---:|:---:|:---:|
| ~55 min | ~38 min | ~30 min | ~27 min |

| Wan2.1 i2v 480P | TeaCache (0.13) | TeaCache (0.19) | TeaCache (0.26) |
|:---:|:---:|:---:|:---:|
| ~735 s | ~464 s | ~372 s | ~300 s |

| Wan2.1 i2v 720P | TeaCache (0.18) | TeaCache (0.2) | TeaCache (0.3) |
|:---:|:---:|:---:|:---:|
| ~29 min | ~17 min | ~15 min | ~12 min |
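Larger `teacache_thresh` values let TeaCache skip more denoising steps, which is why latency drops as the threshold grows. As a minimal sketch of the decision rule implemented in `teacache_generate.py` below (hypothetical helper names, not the shipped API): each step accumulates a polynomial-rescaled relative L1 distance between consecutive timestep embeddings, and the transformer blocks are skipped while the accumulator stays under the threshold.

```python
import numpy as np

def should_recompute(step, num_steps, e, state, coefficients, thresh):
    """Sketch of the TeaCache skip rule (hypothetical helper mirroring
    teacache_forward below). Returns True if the transformer blocks must
    run this step, False if the cached residual can be reused."""
    rescale = np.poly1d(coefficients)  # model-specific polynomial fit
    if step == 0 or step == num_steps - 1:
        recompute = True               # always compute first and last steps
        state["acc"] = 0.0
    else:
        rel_l1 = np.abs(e - state["prev_e"]).mean() / np.abs(state["prev_e"]).mean()
        state["acc"] += rescale(rel_l1)
        recompute = state["acc"] >= thresh
        if recompute:
            state["acc"] = 0.0         # reset after a full forward pass
    state["prev_e"] = e
    return recompute
```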
## Usage
Follow [Wan2.1](https://github.com/Wan-Video/Wan2.1) to clone the repo and finish the installation, then copy `teacache_generate.py` from this folder into the Wan2.1 repo.
For T2V with 1.3B model, you can use the following command:
```bash
python teacache_generate.py --task t2v-1.3B --size 832*480 --ckpt_dir ./Wan2.1-T2V-1.3B --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." --base_seed 42 --offload_model True --t5_cpu --teacache_thresh 0.08
```
For T2V with 14B model, you can use the following command:
```bash
python teacache_generate.py --task t2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-T2V-14B --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." --base_seed 42 --offload_model True --t5_cpu --teacache_thresh 0.2
```
For I2V with 480P resolution, you can use the following command:
```bash
python teacache_generate.py --task i2v-14B --size 832*480 --ckpt_dir ./Wan2.1-I2V-14B-480P --image examples/i2v_input.JPG --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." --base_seed 42 --offload_model True --t5_cpu --teacache_thresh 0.26
```
For I2V with 720P resolution, you can use the following command:
```bash
python teacache_generate.py --task i2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-I2V-14B-720P --image examples/i2v_input.JPG --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." --base_seed 42 --offload_model True --t5_cpu --frame_num 61 --teacache_thresh 0.3
```
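To compare several thresholds on the same prompt and seed, the command can be wrapped in a small driver. The snippet below is a hypothetical helper (not part of this repo); it assumes `teacache_generate.py` has already been copied into the Wan2.1 repo and uses only the flags shown above.

```python
# Hypothetical sweep driver: run the 1.3B T2V example at several
# teacache_thresh values so the outputs can be compared side by side.
import subprocess

PROMPT = ("Two anthropomorphic cats in comfy boxing gear and bright gloves "
          "fight intensely on a spotlighted stage.")

for thresh in (0.05, 0.07, 0.08):
    subprocess.run(
        ["python", "teacache_generate.py",
         "--task", "t2v-1.3B", "--size", "832*480",
         "--ckpt_dir", "./Wan2.1-T2V-1.3B",
         "--prompt", PROMPT,
         "--base_seed", "42",
         "--offload_model", "True", "--t5_cpu",
         "--teacache_thresh", str(thresh)],
        check=True,
    )
```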
## Acknowledgements
We would like to thank the contributors to [Wan2.1](https://github.com/Wan-Video/Wan2.1).

TeaCache4Wan2.1/teacache_generate.py (new file, 974 lines)

@@ -0,0 +1,974 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
from datetime import datetime
import logging
import os
import sys
import warnings
warnings.filterwarnings('ignore')
import torch, random
import torch.distributed as dist
from PIL import Image
import wan
from wan.configs import WAN_CONFIGS, SIZE_CONFIGS, MAX_AREA_CONFIGS, SUPPORTED_SIZES
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
from wan.utils.utils import cache_video, cache_image, str2bool
import gc
from contextlib import contextmanager
import torchvision.transforms.functional as TF
import torch.cuda.amp as amp
import numpy as np
import math
from wan.modules.model import sinusoidal_embedding_1d
from wan.utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
get_sampling_sigmas, retrieve_timesteps)
from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
from tqdm import tqdm
EXAMPLE_PROMPT = {
"t2v-1.3B": {
"prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
},
"t2v-14B": {
"prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
},
"t2i-14B": {
"prompt": "一个朴素端庄的美人",
},
"i2v-14B": {
"prompt":
"Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.",
"image":
"examples/i2v_input.JPG",
},
}
def _validate_args(args):
# Basic check
assert args.ckpt_dir is not None, "Please specify the checkpoint directory."
assert args.task in WAN_CONFIGS, f"Unsupported task: {args.task}"
assert args.task in EXAMPLE_PROMPT, f"Unsupported task: {args.task}"
# The default sampling steps are 40 for image-to-video tasks and 50 for text-to-video tasks.
if args.sample_steps is None:
args.sample_steps = 40 if "i2v" in args.task else 50
if args.sample_shift is None:
args.sample_shift = 5.0
if "i2v" in args.task and args.size in ["832*480", "480*832"]:
args.sample_shift = 3.0
# The default number of frames are 1 for text-to-image tasks and 81 for other tasks.
if args.frame_num is None:
args.frame_num = 1 if "t2i" in args.task else 81
# T2I frame_num check
if "t2i" in args.task:
assert args.frame_num == 1, f"Unsupported frame_num {args.frame_num} for task {args.task}"
args.base_seed = args.base_seed if args.base_seed >= 0 else random.randint(
0, sys.maxsize)
# Size check
assert args.size in SUPPORTED_SIZES[args.task], f"Unsupported size {args.size} for task {args.task}, supported sizes are: {', '.join(SUPPORTED_SIZES[args.task])}"
def _parse_args():
parser = argparse.ArgumentParser(
description="Generate a image or video from a text prompt or image using Wan"
)
parser.add_argument(
"--task",
type=str,
default="t2v-14B",
choices=list(WAN_CONFIGS.keys()),
help="The task to run.")
parser.add_argument(
"--size",
type=str,
default="1280*720",
choices=list(SIZE_CONFIGS.keys()),
help="The area (width*height) of the generated video. For the I2V task, the aspect ratio of the output video will follow that of the input image."
)
parser.add_argument(
"--frame_num",
type=int,
default=None,
help="How many frames to sample from a image or video. The number should be 4n+1"
)
parser.add_argument(
"--ckpt_dir",
type=str,
default=None,
help="The path to the checkpoint directory.")
parser.add_argument(
"--offload_model",
type=str2bool,
default=None,
help="Whether to offload the model to CPU after each model forward, reducing GPU memory usage."
)
parser.add_argument(
"--ulysses_size",
type=int,
default=1,
help="The size of the ulysses parallelism in DiT.")
parser.add_argument(
"--ring_size",
type=int,
default=1,
help="The size of the ring attention parallelism in DiT.")
parser.add_argument(
"--t5_fsdp",
action="store_true",
default=False,
help="Whether to use FSDP for T5.")
parser.add_argument(
"--t5_cpu",
action="store_true",
default=False,
help="Whether to place T5 model on CPU.")
parser.add_argument(
"--dit_fsdp",
action="store_true",
default=False,
help="Whether to use FSDP for DiT.")
parser.add_argument(
"--save_file",
type=str,
default=None,
help="The file to save the generated image or video to.")
parser.add_argument(
"--prompt",
type=str,
default=None,
help="The prompt to generate the image or video from.")
parser.add_argument(
"--use_prompt_extend",
action="store_true",
default=False,
help="Whether to use prompt extend.")
parser.add_argument(
"--prompt_extend_method",
type=str,
default="local_qwen",
choices=["dashscope", "local_qwen"],
help="The prompt extend method to use.")
parser.add_argument(
"--prompt_extend_model",
type=str,
default=None,
help="The prompt extend model to use.")
parser.add_argument(
"--prompt_extend_target_lang",
type=str,
default="ch",
choices=["ch", "en"],
help="The target language of prompt extend.")
parser.add_argument(
"--base_seed",
type=int,
default=-1,
help="The seed to use for generating the image or video.")
parser.add_argument(
"--image",
type=str,
default=None,
help="The image to generate the video from.")
parser.add_argument(
"--sample_solver",
type=str,
default='unipc',
choices=['unipc', 'dpm++'],
help="The solver used to sample.")
parser.add_argument(
"--sample_steps", type=int, default=None, help="The sampling steps.")
parser.add_argument(
"--sample_shift",
type=float,
default=None,
help="Sampling shift factor for flow matching schedulers.")
parser.add_argument(
"--sample_guide_scale",
type=float,
default=5.0,
help="Classifier free guidance scale.")
parser.add_argument(
"--teacache_thresh",
type=float,
default=0.05,
help="The size of the ulysses parallelism in DiT.")
args = parser.parse_args()
_validate_args(args)
return args
def _init_logging(rank):
# logging
if rank == 0:
# set format
logging.basicConfig(
level=logging.INFO,
format="[%(asctime)s] %(levelname)s: %(message)s",
handlers=[logging.StreamHandler(stream=sys.stdout)])
else:
logging.basicConfig(level=logging.ERROR)
# Patched WanT2V.generate: identical to the upstream version except that a
# cond_flag is passed to the model, so TeaCache can distinguish the
# conditional and unconditional CFG passes.
def t2v_generate(self,
input_prompt,
size=(1280, 720),
frame_num=81,
shift=5.0,
sample_solver='unipc',
sampling_steps=50,
guide_scale=5.0,
n_prompt="",
seed=-1,
offload_model=True):
r"""
Generates video frames from text prompt using diffusion process.
Args:
input_prompt (`str`):
Text prompt for content generation
size (tuple[`int`], *optional*, defaults to (1280,720)):
Controls video resolution, (width,height).
frame_num (`int`, *optional*, defaults to 81):
How many frames to sample from a video. The number should be 4n+1
shift (`float`, *optional*, defaults to 5.0):
Noise schedule shift parameter. Affects temporal dynamics
sample_solver (`str`, *optional*, defaults to 'unipc'):
Solver used to sample the video.
sampling_steps (`int`, *optional*, defaults to 50):
Number of diffusion sampling steps. Higher values improve quality but slow generation
guide_scale (`float`, *optional*, defaults to 5.0):
Classifier-free guidance scale. Controls prompt adherence vs. creativity
n_prompt (`str`, *optional*, defaults to ""):
Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
seed (`int`, *optional*, defaults to -1):
Random seed for noise generation. If -1, use random seed.
offload_model (`bool`, *optional*, defaults to True):
If True, offloads models to CPU during generation to save VRAM
Returns:
torch.Tensor:
Generated video frames tensor. Dimensions: (C, N, H, W) where:
- C: Color channels (3 for RGB)
- N: Number of frames (81)
- H: Frame height (from size)
- W: Frame width (from size)
"""
# preprocess
F = frame_num
target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
size[1] // self.vae_stride[1],
size[0] // self.vae_stride[2])
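# seq_len = total number of latent patches (latent frames x H_lat/patch x
# W_lat/patch), rounded up to a multiple of sp_size so the sequence splits
# evenly under sequence parallelism.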
seq_len = math.ceil((target_shape[2] * target_shape[3]) /
(self.patch_size[1] * self.patch_size[2]) *
target_shape[1] / self.sp_size) * self.sp_size
if n_prompt == "":
n_prompt = self.sample_neg_prompt
seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
seed_g = torch.Generator(device=self.device)
seed_g.manual_seed(seed)
if not self.t5_cpu:
self.text_encoder.model.to(self.device)
context = self.text_encoder([input_prompt], self.device)
context_null = self.text_encoder([n_prompt], self.device)
if offload_model:
self.text_encoder.model.cpu()
else:
context = self.text_encoder([input_prompt], torch.device('cpu'))
context_null = self.text_encoder([n_prompt], torch.device('cpu'))
context = [t.to(self.device) for t in context]
context_null = [t.to(self.device) for t in context_null]
noise = [
torch.randn(
target_shape[0],
target_shape[1],
target_shape[2],
target_shape[3],
dtype=torch.float32,
device=self.device,
generator=seed_g)
]
@contextmanager
def noop_no_sync():
yield
no_sync = getattr(self.model, 'no_sync', noop_no_sync)
# evaluation mode
with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync():
if sample_solver == 'unipc':
sample_scheduler = FlowUniPCMultistepScheduler(
num_train_timesteps=self.num_train_timesteps,
shift=1,
use_dynamic_shifting=False)
sample_scheduler.set_timesteps(
sampling_steps, device=self.device, shift=shift)
timesteps = sample_scheduler.timesteps
elif sample_solver == 'dpm++':
sample_scheduler = FlowDPMSolverMultistepScheduler(
num_train_timesteps=self.num_train_timesteps,
shift=1,
use_dynamic_shifting=False)
sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
timesteps, _ = retrieve_timesteps(
sample_scheduler,
device=self.device,
sigmas=sampling_sigmas)
else:
raise NotImplementedError("Unsupported solver.")
# sample videos
latents = noise
arg_c = {'context': context, 'seq_len': seq_len, 'cond_flag': True}
arg_null = {'context': context_null, 'seq_len': seq_len, 'cond_flag': False}
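# cond_flag tells teacache_forward which CFG branch is running, so separate
# residual caches are kept for the conditional and unconditional passes.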
for _, t in enumerate(tqdm(timesteps)):
latent_model_input = latents
timestep = [t]
timestep = torch.stack(timestep)
self.model.to(self.device)
noise_pred_cond = self.model(
latent_model_input, t=timestep, **arg_c)[0]
noise_pred_uncond = self.model(
latent_model_input, t=timestep, **arg_null)[0]
noise_pred = noise_pred_uncond + guide_scale * (
noise_pred_cond - noise_pred_uncond)
temp_x0 = sample_scheduler.step(
noise_pred.unsqueeze(0),
t,
latents[0].unsqueeze(0),
return_dict=False,
generator=seed_g)[0]
latents = [temp_x0.squeeze(0)]
x0 = latents
if offload_model:
self.model.cpu()
torch.cuda.empty_cache()
if self.rank == 0:
videos = self.vae.decode(x0)
del noise, latents
del sample_scheduler
if offload_model:
gc.collect()
torch.cuda.synchronize()
if dist.is_initialized():
dist.barrier()
return videos[0] if self.rank == 0 else None
# Patched WanI2V.generate: as above, but for image-to-video; cond_flag again
# distinguishes the conditional and unconditional CFG passes.
def i2v_generate(self,
input_prompt,
img,
max_area=720 * 1280,
frame_num=81,
shift=5.0,
sample_solver='unipc',
sampling_steps=40,
guide_scale=5.0,
n_prompt="",
seed=-1,
offload_model=True):
r"""
Generates video frames from input image and text prompt using diffusion process.
Args:
input_prompt (`str`):
Text prompt for content generation.
img (PIL.Image.Image):
Input image tensor. Shape: [3, H, W]
max_area (`int`, *optional*, defaults to 720*1280):
Maximum pixel area for latent space calculation. Controls video resolution scaling
frame_num (`int`, *optional*, defaults to 81):
How many frames to sample from a video. The number should be 4n+1
shift (`float`, *optional*, defaults to 5.0):
Noise schedule shift parameter. Affects temporal dynamics
[NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0.
sample_solver (`str`, *optional*, defaults to 'unipc'):
Solver used to sample the video.
sampling_steps (`int`, *optional*, defaults to 40):
Number of diffusion sampling steps. Higher values improve quality but slow generation
guide_scale (`float`, *optional*, defaults to 5.0):
Classifier-free guidance scale. Controls prompt adherence vs. creativity
n_prompt (`str`, *optional*, defaults to ""):
Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
seed (`int`, *optional*, defaults to -1):
Random seed for noise generation. If -1, use random seed
offload_model (`bool`, *optional*, defaults to True):
If True, offloads models to CPU during generation to save VRAM
Returns:
torch.Tensor:
Generated video frames tensor. Dimensions: (C, N, H, W) where:
- C: Color channels (3 for RGB)
- N: Number of frames (81)
- H: Frame height (from max_area)
- W: Frame width (from max_area)
"""
img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device)
F = frame_num
h, w = img.shape[1:]
aspect_ratio = h / w
lat_h = round(
np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] //
self.patch_size[1] * self.patch_size[1])
lat_w = round(
np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] //
self.patch_size[2] * self.patch_size[2])
h = lat_h * self.vae_stride[1]
w = lat_w * self.vae_stride[2]
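# max_seq_len = latent frames x lat_h/patch x lat_w/patch, rounded up to a
# multiple of sp_size (same patch-count calculation as in t2v_generate).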
max_seq_len = ((F - 1) // self.vae_stride[0] + 1) * lat_h * lat_w // (
self.patch_size[1] * self.patch_size[2])
max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size
seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
seed_g = torch.Generator(device=self.device)
seed_g.manual_seed(seed)
noise = torch.randn(
self.vae.model.z_dim,
(F - 1) // self.vae_stride[0] + 1,
lat_h,
lat_w,
dtype=torch.float32,
generator=seed_g,
device=self.device)
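# Conditioning mask over latent frames: 1 for the input-image frame, 0 for
# frames to be generated; the first frame is repeated 4x to match the VAE's
# temporal stride before reshaping to the latent layout.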
msk = torch.ones(1, F, lat_h, lat_w, device=self.device)
msk[:, 1:] = 0
msk = torch.concat([
torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]
],
dim=1)
msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
msk = msk.transpose(1, 2)[0]
if n_prompt == "":
n_prompt = self.sample_neg_prompt
# preprocess
if not self.t5_cpu:
self.text_encoder.model.to(self.device)
context = self.text_encoder([input_prompt], self.device)
context_null = self.text_encoder([n_prompt], self.device)
if offload_model:
self.text_encoder.model.cpu()
else:
context = self.text_encoder([input_prompt], torch.device('cpu'))
context_null = self.text_encoder([n_prompt], torch.device('cpu'))
context = [t.to(self.device) for t in context]
context_null = [t.to(self.device) for t in context_null]
self.clip.model.to(self.device)
clip_context = self.clip.visual([img[:, None, :, :]])
if offload_model:
self.clip.model.cpu()
y = self.vae.encode([
torch.concat([
torch.nn.functional.interpolate(
img[None].cpu(), size=(h, w), mode='bicubic').transpose(
0, 1),
torch.zeros(3, F-1, h, w)
],
dim=1).to(self.device)
])[0]
y = torch.concat([msk, y])
@contextmanager
def noop_no_sync():
yield
no_sync = getattr(self.model, 'no_sync', noop_no_sync)
# evaluation mode
with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync():
if sample_solver == 'unipc':
sample_scheduler = FlowUniPCMultistepScheduler(
num_train_timesteps=self.num_train_timesteps,
shift=1,
use_dynamic_shifting=False)
sample_scheduler.set_timesteps(
sampling_steps, device=self.device, shift=shift)
timesteps = sample_scheduler.timesteps
elif sample_solver == 'dpm++':
sample_scheduler = FlowDPMSolverMultistepScheduler(
num_train_timesteps=self.num_train_timesteps,
shift=1,
use_dynamic_shifting=False)
sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
timesteps, _ = retrieve_timesteps(
sample_scheduler,
device=self.device,
sigmas=sampling_sigmas)
else:
raise NotImplementedError("Unsupported solver.")
# sample videos
latent = noise
arg_c = {
'context': [context[0]],
'clip_fea': clip_context,
'seq_len': max_seq_len,
'y': [y],
'cond_flag': True,
}
arg_null = {
'context': context_null,
'clip_fea': clip_context,
'seq_len': max_seq_len,
'y': [y],
'cond_flag': False,
}
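# As in t2v_generate, cond_flag lets teacache_forward cache the conditional
# and unconditional residuals separately.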
if offload_model:
torch.cuda.empty_cache()
self.model.to(self.device)
for _, t in enumerate(tqdm(timesteps)):
latent_model_input = [latent.to(self.device)]
timestep = [t]
timestep = torch.stack(timestep).to(self.device)
noise_pred_cond = self.model(
latent_model_input, t=timestep, **arg_c)[0].to(
torch.device('cpu') if offload_model else self.device)
if offload_model:
torch.cuda.empty_cache()
noise_pred_uncond = self.model(
latent_model_input, t=timestep, **arg_null)[0].to(
torch.device('cpu') if offload_model else self.device)
if offload_model:
torch.cuda.empty_cache()
noise_pred = noise_pred_uncond + guide_scale * (
noise_pred_cond - noise_pred_uncond)
latent = latent.to(
torch.device('cpu') if offload_model else self.device)
temp_x0 = sample_scheduler.step(
noise_pred.unsqueeze(0),
t,
latent.unsqueeze(0),
return_dict=False,
generator=seed_g)[0]
latent = temp_x0.squeeze(0)
x0 = [latent.to(self.device)]
del latent_model_input, timestep
if offload_model:
self.model.cpu()
torch.cuda.empty_cache()
if self.rank == 0:
videos = self.vae.decode(x0)
del noise, latent
del sample_scheduler
if offload_model:
gc.collect()
torch.cuda.synchronize()
if dist.is_initialized():
dist.barrier()
return videos[0] if self.rank == 0 else None
def teacache_forward(
self,
x,
t,
context,
seq_len,
clip_fea=None,
y=None,
cond_flag=False,
):
r"""
Forward pass through the diffusion model
Args:
x (List[Tensor]):
List of input video tensors, each with shape [C_in, F, H, W]
t (Tensor):
Diffusion timesteps tensor of shape [B]
context (List[Tensor]):
List of text embeddings each with shape [L, C]
seq_len (`int`):
Maximum sequence length for positional encoding
clip_fea (Tensor, *optional*):
CLIP image features for image-to-video mode
y (List[Tensor], *optional*):
Conditional video inputs for image-to-video mode, same shape as x
Returns:
List[Tensor]:
List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
"""
if self.model_type == 'i2v':
assert clip_fea is not None and y is not None
# params
device = self.patch_embedding.weight.device
if self.freqs.device != device:
self.freqs = self.freqs.to(device)
if y is not None:
x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
# embeddings
x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
grid_sizes = torch.stack(
[torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
x = [u.flatten(2).transpose(1, 2) for u in x]
seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
assert seq_lens.max() <= seq_len
x = torch.cat([
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
dim=1) for u in x
])
# time embeddings
with amp.autocast(dtype=torch.float32):
e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, t).float())
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
assert e.dtype == torch.float32 and e0.dtype == torch.float32
# context
context_lens = None
context = self.text_embedding(
torch.stack([
torch.cat(
[u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
for u in context
]))
if clip_fea is not None:
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
context = torch.concat([context_clip, context], dim=1)
# arguments
kwargs = dict(
e=e0,
seq_lens=seq_lens,
grid_sizes=grid_sizes,
freqs=self.freqs,
context=context,
context_lens=context_lens)
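# TeaCache decision (made on the conditional pass only): accumulate a
# polynomial-rescaled relative L1 distance between the current and previous
# timestep embeddings. While the accumulator stays below rel_l1_thresh, the
# transformer blocks are skipped and the residual cached from the last full
# pass is reused; the first and last steps always compute. The unconditional
# pass reuses the decision stored in self.should_calc.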
if self.enable_teacache:
if cond_flag:
modulated_inp = e
if self.cnt == 0 or self.cnt == self.num_steps-1:
should_calc = True
self.accumulated_rel_l1_distance = 0
else:
rescale_func = np.poly1d(self.coefficients)
if cond_flag:
self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
should_calc = False
else:
should_calc = True
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = modulated_inp
self.cnt = 0 if self.cnt == self.num_steps-1 else self.cnt + 1
self.should_calc = should_calc
else:
should_calc = self.should_calc
# if not cond_flag:
# self.cnt = 0 if self.cnt == self.num_steps-1 else self.cnt + 1
if self.enable_teacache:
if not should_calc:
x = x + self.previous_residual_cond if cond_flag else x + self.previous_residual_uncond
else:
ori_x = x.clone()
for block in self.blocks:
x = block(x, **kwargs)
if cond_flag:
self.previous_residual_cond = x - ori_x
else:
self.previous_residual_uncond = x - ori_x
else:
for block in self.blocks:
x = block(x, **kwargs)
# head
x = self.head(x, e)
# unpatchify
x = self.unpatchify(x, grid_sizes)
return [u.float() for u in x]
def generate(args):
rank = int(os.getenv("RANK", 0))
world_size = int(os.getenv("WORLD_SIZE", 1))
local_rank = int(os.getenv("LOCAL_RANK", 0))
device = local_rank
_init_logging(rank)
if args.offload_model is None:
args.offload_model = False if world_size > 1 else True
logging.info(
f"offload_model is not specified, set to {args.offload_model}.")
if world_size > 1:
torch.cuda.set_device(local_rank)
dist.init_process_group(
backend="nccl",
init_method="env://",
rank=rank,
world_size=world_size)
else:
assert not (
args.t5_fsdp or args.dit_fsdp
), "t5_fsdp and dit_fsdp are not supported in non-distributed environments."
assert not (
args.ulysses_size > 1 or args.ring_size > 1
), "context parallelism is not supported in non-distributed environments."
if args.ulysses_size > 1 or args.ring_size > 1:
assert args.ulysses_size * args.ring_size == world_size, "The product of ulysses_size and ring_size should equal the world size."
from xfuser.core.distributed import (initialize_model_parallel,
init_distributed_environment)
init_distributed_environment(
rank=dist.get_rank(), world_size=dist.get_world_size())
initialize_model_parallel(
sequence_parallel_degree=dist.get_world_size(),
ring_degree=args.ring_size,
ulysses_degree=args.ulysses_size,
)
if args.use_prompt_extend:
if args.prompt_extend_method == "dashscope":
prompt_expander = DashScopePromptExpander(
model_name=args.prompt_extend_model, is_vl="i2v" in args.task)
elif args.prompt_extend_method == "local_qwen":
prompt_expander = QwenPromptExpander(
model_name=args.prompt_extend_model,
is_vl="i2v" in args.task,
device=rank)
else:
raise NotImplementedError(
f"Unsupport prompt_extend_method: {args.prompt_extend_method}")
cfg = WAN_CONFIGS[args.task]
if args.ulysses_size > 1:
assert cfg.num_heads % args.ulysses_size == 0, f"`num_heads` must be divisible by `ulysses_size`."
logging.info(f"Generation job args: {args}")
logging.info(f"Generation model config: {cfg}")
if dist.is_initialized():
base_seed = [args.base_seed] if rank == 0 else [None]
dist.broadcast_object_list(base_seed, src=0)
args.base_seed = base_seed[0]
if "t2v" in args.task or "t2i" in args.task:
if args.prompt is None:
args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
logging.info(f"Input prompt: {args.prompt}")
if args.use_prompt_extend:
logging.info("Extending prompt ...")
if rank == 0:
prompt_output = prompt_expander(
args.prompt,
tar_lang=args.prompt_extend_target_lang,
seed=args.base_seed)
if not prompt_output.status:
logging.info(
f"Extending prompt failed: {prompt_output.message}")
logging.info("Falling back to original prompt.")
input_prompt = args.prompt
else:
input_prompt = prompt_output.prompt
input_prompt = [input_prompt]
else:
input_prompt = [None]
if dist.is_initialized():
dist.broadcast_object_list(input_prompt, src=0)
args.prompt = input_prompt[0]
logging.info(f"Extended prompt: {args.prompt}")
logging.info("Creating WanT2V pipeline.")
wan_t2v = wan.WanT2V(
config=cfg,
checkpoint_dir=args.ckpt_dir,
device_id=device,
rank=rank,
t5_fsdp=args.t5_fsdp,
dit_fsdp=args.dit_fsdp,
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
t5_cpu=args.t5_cpu,
)
# TeaCache
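# generate() is swapped for the cond_flag-aware version above, and the caching
# state plus the model-size-specific polynomial coefficients are attached at
# class level before replacing forward() with teacache_forward.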
wan_t2v.__class__.generate = t2v_generate
wan_t2v.model.__class__.enable_teacache = True
wan_t2v.model.__class__.num_steps = args.sample_steps if args.sample_steps is not None else 50
wan_t2v.model.__class__.rel_l1_thresh = args.teacache_thresh  # 1.3B reference timings: baseline ~2 min 54 s; 0.05 -> ~1 min 55 s (1.5x); 0.06 -> ~1 min 51 s; 0.07 -> ~1 min 48 s (1.6x); 0.08 -> ~1 min 27 s (2x); 0.1 -> ~1 min 24 s (2.1x); 0.15 -> ~1 min 6 s
wan_t2v.model.__class__.accumulated_rel_l1_distance = 0
wan_t2v.model.__class__.previous_modulated_input = None
wan_t2v.model.__class__.previous_residual_cond = None
wan_t2v.model.__class__.previous_residual_uncond = None
wan_t2v.model.__class__.should_calc = True
if '1.3B' in args.ckpt_dir:
wan_t2v.model.__class__.coefficients = [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01]
if '14B' in args.ckpt_dir:
wan_t2v.model.__class__.coefficients = [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404]
wan_t2v.model.__class__.forward = teacache_forward
logging.info(
f"Generating {'image' if 't2i' in args.task else 'video'} ...")
video = wan_t2v.generate(
args.prompt,
size=SIZE_CONFIGS[args.size],
frame_num=args.frame_num,
shift=args.sample_shift,
sample_solver=args.sample_solver,
sampling_steps=args.sample_steps,
guide_scale=args.sample_guide_scale,
seed=args.base_seed,
offload_model=args.offload_model)
else:
if args.prompt is None:
args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
if args.image is None:
args.image = EXAMPLE_PROMPT[args.task]["image"]
logging.info(f"Input prompt: {args.prompt}")
logging.info(f"Input image: {args.image}")
img = Image.open(args.image).convert("RGB")
if args.use_prompt_extend:
logging.info("Extending prompt ...")
if rank == 0:
prompt_output = prompt_expander(
args.prompt,
tar_lang=args.prompt_extend_target_lang,
image=img,
seed=args.base_seed)
if not prompt_output.status:
logging.info(
f"Extending prompt failed: {prompt_output.message}")
logging.info("Falling back to original prompt.")
input_prompt = args.prompt
else:
input_prompt = prompt_output.prompt
input_prompt = [input_prompt]
else:
input_prompt = [None]
if dist.is_initialized():
dist.broadcast_object_list(input_prompt, src=0)
args.prompt = input_prompt[0]
logging.info(f"Extended prompt: {args.prompt}")
logging.info("Creating WanI2V pipeline.")
wan_i2v = wan.WanI2V(
config=cfg,
checkpoint_dir=args.ckpt_dir,
device_id=device,
rank=rank,
t5_fsdp=args.t5_fsdp,
dit_fsdp=args.dit_fsdp,
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
t5_cpu=args.t5_cpu,
)
# TeaCache
wan_i2v.__class__.generate = i2v_generate
wan_i2v.model.__class__.enable_teacache = True
wan_i2v.model.__class__.num_steps = args.sample_steps if args.sample_steps is not None else 40
wan_i2v.model.__class__.rel_l1_thresh = args.teacache_thresh  # reference run: ~12 min 26 s
wan_i2v.model.__class__.accumulated_rel_l1_distance = 0
wan_i2v.model.__class__.previous_modulated_input = None
wan_i2v.model.__class__.previous_residual_cond = None
wan_i2v.model.__class__.previous_residual_uncond = None
wan_i2v.model.__class__.should_calc = True
if '480P' in args.ckpt_dir:
wan_i2v.model.__class__.coefficients = [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01]
if '720P' in args.ckpt_dir:
wan_i2v.model.__class__.coefficients = [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683]
wan_i2v.model.__class__.forward = teacache_forward
logging.info("Generating video ...")
video = wan_i2v.generate(
args.prompt,
img,
max_area=MAX_AREA_CONFIGS[args.size],
frame_num=args.frame_num,
shift=args.sample_shift,
sample_solver=args.sample_solver,
sampling_steps=args.sample_steps,
guide_scale=args.sample_guide_scale,
seed=args.base_seed,
offload_model=args.offload_model)
if rank == 0:
if args.save_file is None:
formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S")
formatted_prompt = args.prompt.replace(" ", "_").replace("/", "_")[:50]
suffix = '.png' if "t2i" in args.task else '.mp4'
args.save_file = f"{args.task}_{args.size}_{args.ulysses_size}_{args.ring_size}_{formatted_prompt}_{formatted_time}" + suffix
if "t2i" in args.task:
logging.info(f"Saving generated image to {args.save_file}")
cache_image(
tensor=video.squeeze(1)[None],
save_file=args.save_file,
nrow=1,
normalize=True,
value_range=(-1, 1))
else:
logging.info(f"Saving generated video to {args.save_file}")
cache_video(
tensor=video[None],
save_file=args.save_file,
fps=cfg.sample_fps,
nrow=1,
normalize=True,
value_range=(-1, 1))
logging.info("Finished.")
if __name__ == "__main__":
args = _parse_args()
generate(args)