From 3cddb3689697ba949b689d81b3d508e68b68307d Mon Sep 17 00:00:00 2001 From: LiewFeng <1361871897@qq.com> Date: Wed, 5 Mar 2025 16:59:18 +0800 Subject: [PATCH] Support Wan2.1 --- README.md | 7 +- TeaCache4Wan2.1/README.md | 62 ++ TeaCache4Wan2.1/teacache_generate.py | 974 +++++++++++++++++++++++++++ 3 files changed, 1041 insertions(+), 2 deletions(-) create mode 100644 TeaCache4Wan2.1/README.md create mode 100644 TeaCache4Wan2.1/teacache_generate.py diff --git a/README.md b/README.md index 7a9a576..7c72bcb 100644 --- a/README.md +++ b/README.md @@ -64,6 +64,7 @@ We introduce Timestep Embedding Aware Cache (TeaCache), a training-free caching ## πŸ”₯ Latest News - **If you like our project, please give us a star ⭐ on GitHub for the latest update.** +- [2025/03/05] πŸ”₯ Support [Wan2.1](https://github.com/Wan-Video/Wan2.1) for both T2V and I2V. - [2025/02/27] πŸŽ‰ Accepted in CVPR 2025. - [2025/01/24] πŸ”₯ Support [Cosmos](https://github.com/NVIDIA/Cosmos) for both T2V and I2V. Thanks [@zishen-ucap](https://github.com/zishen-ucap). - [2025/01/20] πŸ”₯ Support [CogVideoX1.5-5B](https://github.com/THUDM/CogVideo) for both T2V and I2V. Thanks [@zishen-ucap](https://github.com/zishen-ucap). @@ -106,6 +107,7 @@ If you develop/use TeaCache in your projects, welcome to let us know. - [TeaCache4CogVideoX1.5](./TeaCache4CogVideoX1.5/README.md) - EasyAnimate, see [here](https://github.com/aigc-apps/EasyAnimate). - [TeaCache4Cosmos](./eval/TeaCache4Cosmos/README.md) +- [TeaCache4Wan2.1](./TeaCache4Wan2.1/README.md) **Image to Video** - [TeaCache4ConsisID](./TeaCache4ConsisID/README.md) @@ -113,6 +115,7 @@ If you develop/use TeaCache in your projects, welcome to let us know. - Ruyi-Models. See [here](https://github.com/IamCreateAI/Ruyi-Models). - EasyAnimate, see [here](https://github.com/aigc-apps/EasyAnimate). - [TeaCache4Cosmos](./eval/TeaCache4Cosmos/README.md) +- [TeaCache4Wan2.1](./TeaCache4Wan2.1/README.md) **Video to Video** - EasyAnimate, see [here](https://github.com/aigc-apps/EasyAnimate). @@ -131,12 +134,12 @@ If you develop/use TeaCache in your projects, welcome to let us know. ## πŸ’ Acknowledgement -This repository is built based on [VideoSys](https://github.com/NUS-HPC-AI-Lab/VideoSys), [Diffusers](https://github.com/huggingface/diffusers), [Open-Sora](https://github.com/hpcaitech/Open-Sora), [Open-Sora-Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan), [Latte](https://github.com/Vchitect/Latte), [CogVideoX](https://github.com/THUDM/CogVideo), [HunyuanVideo](https://github.com/Tencent/HunyuanVideo), [ConsisID](https://github.com/PKU-YuanGroup/ConsisID), [FLUX](https://github.com/black-forest-labs/flux), [Mochi](https://github.com/genmoai/mochi), [LTX-Video](https://github.com/Lightricks/LTX-Video), [Lumina-T2X](https://github.com/Alpha-VLLM/Lumina-T2X), [TangoFlux](https://github.com/declare-lab/TangoFlux) and [Cosmos](https://github.com/NVIDIA/Cosmos). Thanks for their contributions! 
+This repository is built based on [VideoSys](https://github.com/NUS-HPC-AI-Lab/VideoSys), [Diffusers](https://github.com/huggingface/diffusers), [Open-Sora](https://github.com/hpcaitech/Open-Sora), [Open-Sora-Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan), [Latte](https://github.com/Vchitect/Latte), [CogVideoX](https://github.com/THUDM/CogVideo), [HunyuanVideo](https://github.com/Tencent/HunyuanVideo), [ConsisID](https://github.com/PKU-YuanGroup/ConsisID), [FLUX](https://github.com/black-forest-labs/flux), [Mochi](https://github.com/genmoai/mochi), [LTX-Video](https://github.com/Lightricks/LTX-Video), [Lumina-T2X](https://github.com/Alpha-VLLM/Lumina-T2X), [TangoFlux](https://github.com/declare-lab/TangoFlux), [Cosmos](https://github.com/NVIDIA/Cosmos) and [Wan2.1](https://github.com/Wan-Video/Wan2.1). Thanks for their contributions! ## πŸ”’ License * The majority of this project is released under the Apache 2.0 license as found in the [LICENSE](./LICENSE) file. -* For [VideoSys](https://github.com/NUS-HPC-AI-Lab/VideoSys), [Diffusers](https://github.com/huggingface/diffusers), [Open-Sora](https://github.com/hpcaitech/Open-Sora), [Open-Sora-Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan), [Latte](https://github.com/Vchitect/Latte), [CogVideoX](https://github.com/THUDM/CogVideo), [HunyuanVideo](https://github.com/Tencent/HunyuanVideo), [ConsisID](https://github.com/PKU-YuanGroup/ConsisID), [FLUX](https://github.com/black-forest-labs/flux), [Mochi](https://github.com/genmoai/mochi), [LTX-Video](https://github.com/Lightricks/LTX-Video), [Lumina-T2X](https://github.com/Alpha-VLLM/Lumina-T2X), [TangoFlux](https://github.com/declare-lab/TangoFlux) and [Cosmos](https://github.com/NVIDIA/Cosmos), please follow their LICENSE. +* For [VideoSys](https://github.com/NUS-HPC-AI-Lab/VideoSys), [Diffusers](https://github.com/huggingface/diffusers), [Open-Sora](https://github.com/hpcaitech/Open-Sora), [Open-Sora-Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan), [Latte](https://github.com/Vchitect/Latte), [CogVideoX](https://github.com/THUDM/CogVideo), [HunyuanVideo](https://github.com/Tencent/HunyuanVideo), [ConsisID](https://github.com/PKU-YuanGroup/ConsisID), [FLUX](https://github.com/black-forest-labs/flux), [Mochi](https://github.com/genmoai/mochi), [LTX-Video](https://github.com/Lightricks/LTX-Video), [Lumina-T2X](https://github.com/Alpha-VLLM/Lumina-T2X), [TangoFlux](https://github.com/declare-lab/TangoFlux), [Cosmos](https://github.com/NVIDIA/Cosmos) and [Wan2.1](https://github.com/Wan-Video/Wan2.1), please follow their LICENSE. * The service is a research preview. Please contact us if you find any potential violations. (liufeng20@mails.ucas.ac.cn) ## πŸ“– Citation diff --git a/TeaCache4Wan2.1/README.md b/TeaCache4Wan2.1/README.md new file mode 100644 index 0000000..9ef917e --- /dev/null +++ b/TeaCache4Wan2.1/README.md @@ -0,0 +1,62 @@ + +# TeaCache4Wan2.1 + +[TeaCache](https://github.com/ali-vilab/TeaCache) can speed up [Wan2.1](https://github.com/Wan-Video/Wan2.1) 2x without much visual quality degradation, in a training-free manner. The following videos show the results generated by TeaCache-Wan2.1 with various teacache_thresh values; the corresponding teacache_thresh values are shown in the tables below.
+ +https://github.com/user-attachments/assets/5ae5d6dd-bf87-4f8f-91b8-ccc5980c56ad + +https://github.com/user-attachments/assets/dfd047a9-e3ca-4a73-a282-4dadda8dbd43 + +https://github.com/user-attachments/assets/7c20bd54-96a8-4bd7-b4fa-ea4c9da81562 + +https://github.com/user-attachments/assets/72085f45-6b78-4fae-b58f-492360a6e55e +## πŸ“ˆ Inference Latency Comparisons on a Single A800 + + +| Wan2.1 t2v 1.3B | TeaCache (0.05) | TeaCache (0.07) | TeaCache (0.08) | +|:--------------------------:|:----------------------------:|:---------------------:|:---------------------:| +| ~175 s | ~117 s | ~110 s | ~88 s | + +| Wan2.1 t2v 14B | TeaCache (0.14) | TeaCache (0.15) | TeaCache (0.2) | +|:--------------------------:|:----------------------------:|:---------------------:|:---------------------:| +| ~55 min | ~38 min | ~30 min | ~27 min | + +| Wan2.1 i2v 480P | TeaCache (0.13) | TeaCache (0.19) | TeaCache (0.26) | +|:--------------------------:|:----------------------------:|:---------------------:|:---------------------:| +| ~735 s | ~464 s | ~372 s | ~300 s | + +| Wan2.1 i2v 720P | TeaCache (0.18) | TeaCache (0.2) | TeaCache (0.3) | +|:--------------------------:|:----------------------------:|:---------------------:|:---------------------:| +| ~29 min | ~17 min | ~15 min | ~12 min | + +## Usage + +Follow [Wan2.1](https://github.com/Wan-Video/Wan2.1) to clone the repo and finish the installation, then copy `teacache_generate.py` from this repo into the Wan2.1 repo. + +For T2V with the 1.3B model, you can use the following command: + +```bash +python teacache_generate.py --task t2v-1.3B --size 832*480 --ckpt_dir ./Wan2.1-T2V-1.3B --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." --base_seed 42 --offload_model True --t5_cpu --teacache_thresh 0.08 +``` + +For T2V with the 14B model, you can use the following command: + +```bash +python teacache_generate.py --task t2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-T2V-14B --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." --base_seed 42 --offload_model True --t5_cpu --teacache_thresh 0.2 +``` + +For I2V at 480P resolution, you can use the following command: + +```bash +python teacache_generate.py --task i2v-14B --size 832*480 --ckpt_dir ./Wan2.1-I2V-14B-480P --image examples/i2v_input.JPG --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." --base_seed 42 --offload_model True --t5_cpu --teacache_thresh 0.26 +``` + +For I2V at 720P resolution, you can use the following command: + +```bash +python teacache_generate.py --task i2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-I2V-14B-720P --image examples/i2v_input.JPG --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds.
The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." --base_seed 42 --offload_model True --t5_cpu --frame_num 61 --teacache_thresh 0.3 +``` + +## Acknowledgements + +We would like to thank the contributors to [Wan2.1](https://github.com/Wan-Video/Wan2.1). \ No newline at end of file diff --git a/TeaCache4Wan2.1/teacache_generate.py b/TeaCache4Wan2.1/teacache_generate.py new file mode 100644 index 0000000..bfbaea5 --- /dev/null +++ b/TeaCache4Wan2.1/teacache_generate.py @@ -0,0 +1,974 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import argparse +from datetime import datetime +import logging +import os +import sys +import warnings + +warnings.filterwarnings('ignore') + +import torch, random +import torch.distributed as dist +from PIL import Image + +import wan +from wan.configs import WAN_CONFIGS, SIZE_CONFIGS, MAX_AREA_CONFIGS, SUPPORTED_SIZES +from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander +from wan.utils.utils import cache_video, cache_image, str2bool + +import gc +from contextlib import contextmanager +import torchvision.transforms.functional as TF +import torch.cuda.amp as amp +import numpy as np +import math +from wan.modules.model import sinusoidal_embedding_1d +from wan.utils.fm_solvers import (FlowDPMSolverMultistepScheduler, + get_sampling_sigmas, retrieve_timesteps) +from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler +from tqdm import tqdm + +EXAMPLE_PROMPT = { + "t2v-1.3B": { + "prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.", + }, + "t2v-14B": { + "prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.", + }, + "t2i-14B": { + "prompt": "δΈ€δΈͺζœ΄η΄ η«―εΊ„ηš„ηΎŽδΊΊ", + }, + "i2v-14B": { + "prompt": + "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.", + "image": + "examples/i2v_input.JPG", + }, +} + + + +def _validate_args(args): + # Basic check + assert args.ckpt_dir is not None, "Please specify the checkpoint directory." + assert args.task in WAN_CONFIGS, f"Unsupported task: {args.task}" + assert args.task in EXAMPLE_PROMPT, f"Unsupported task: {args.task}" + + # The default sampling steps are 40 for image-to-video tasks and 50 for text-to-video tasks. + if args.sample_steps is None: + args.sample_steps = 40 if "i2v" in args.task else 50 + + if args.sample_shift is None: + args.sample_shift = 5.0 + if "i2v" in args.task and args.size in ["832*480", "480*832"]: + args.sample_shift = 3.0 + + # The default number of frames is 1 for text-to-image tasks and 81 for other tasks.
+ if args.frame_num is None: + args.frame_num = 1 if "t2i" in args.task else 81 + + # T2I frame_num check + if "t2i" in args.task: + assert args.frame_num == 1, f"Unsupported frame_num {args.frame_num} for task {args.task}" + + args.base_seed = args.base_seed if args.base_seed >= 0 else random.randint( + 0, sys.maxsize) + # Size check + assert args.size in SUPPORTED_SIZES[ + args. + task], f"Unsupported size {args.size} for task {args.task}, supported sizes are: {', '.join(SUPPORTED_SIZES[args.task])}" + + +def _parse_args(): + parser = argparse.ArgumentParser( + description="Generate an image or video from a text prompt or image using Wan" + ) + parser.add_argument( + "--task", + type=str, + default="t2v-14B", + choices=list(WAN_CONFIGS.keys()), + help="The task to run.") + parser.add_argument( + "--size", + type=str, + default="1280*720", + choices=list(SIZE_CONFIGS.keys()), + help="The area (width*height) of the generated video. For the I2V task, the aspect ratio of the output video will follow that of the input image." + ) + parser.add_argument( + "--frame_num", + type=int, + default=None, + help="How many frames to sample from an image or video. The number should be 4n+1" + ) + parser.add_argument( + "--ckpt_dir", + type=str, + default=None, + help="The path to the checkpoint directory.") + parser.add_argument( + "--offload_model", + type=str2bool, + default=None, + help="Whether to offload the model to CPU after each model forward, reducing GPU memory usage." + ) + parser.add_argument( + "--ulysses_size", + type=int, + default=1, + help="The size of the ulysses parallelism in DiT.") + parser.add_argument( + "--ring_size", + type=int, + default=1, + help="The size of the ring attention parallelism in DiT.") + parser.add_argument( + "--t5_fsdp", + action="store_true", + default=False, + help="Whether to use FSDP for T5.") + parser.add_argument( + "--t5_cpu", + action="store_true", + default=False, + help="Whether to place T5 model on CPU.") + parser.add_argument( + "--dit_fsdp", + action="store_true", + default=False, + help="Whether to use FSDP for DiT.") + parser.add_argument( + "--save_file", + type=str, + default=None, + help="The file to save the generated image or video to.") + parser.add_argument( + "--prompt", + type=str, + default=None, + help="The prompt to generate the image or video from.") + parser.add_argument( + "--use_prompt_extend", + action="store_true", + default=False, + help="Whether to use prompt extend.") + parser.add_argument( + "--prompt_extend_method", + type=str, + default="local_qwen", + choices=["dashscope", "local_qwen"], + help="The prompt extend method to use.") + parser.add_argument( + "--prompt_extend_model", + type=str, + default=None, + help="The prompt extend model to use.") + parser.add_argument( + "--prompt_extend_target_lang", + type=str, + default="ch", + choices=["ch", "en"], + help="The target language of prompt extend.") + parser.add_argument( + "--base_seed", + type=int, + default=-1, + help="The seed to use for generating the image or video.") + parser.add_argument( + "--image", + type=str, + default=None, + help="The image to generate the video from.") + parser.add_argument( + "--sample_solver", + type=str, + default='unipc', + choices=['unipc', 'dpm++'], + help="The solver used to sample.") + parser.add_argument( + "--sample_steps", type=int, default=None, help="The sampling steps.") + parser.add_argument( + "--sample_shift", + type=float, + default=None, + help="Sampling shift factor for flow matching schedulers.") + parser.add_argument(
"--sample_guide_scale", + type=float, + default=5.0, + help="Classifier free guidance scale.") + parser.add_argument( + "--teacache_thresh", + type=float, + default=0.05, + help="The size of the ulysses parallelism in DiT.") + + args = parser.parse_args() + + _validate_args(args) + + return args + + +def _init_logging(rank): + # logging + if rank == 0: + # set format + logging.basicConfig( + level=logging.INFO, + format="[%(asctime)s] %(levelname)s: %(message)s", + handlers=[logging.StreamHandler(stream=sys.stdout)]) + else: + logging.basicConfig(level=logging.ERROR) + + +# add a cond_flag +def t2v_generate(self, + input_prompt, + size=(1280, 720), + frame_num=81, + shift=5.0, + sample_solver='unipc', + sampling_steps=50, + guide_scale=5.0, + n_prompt="", + seed=-1, + offload_model=True): + r""" + Generates video frames from text prompt using diffusion process. + + Args: + input_prompt (`str`): + Text prompt for content generation + size (tupele[`int`], *optional*, defaults to (1280,720)): + Controls video resolution, (width,height). + frame_num (`int`, *optional*, defaults to 81): + How many frames to sample from a video. The number should be 4n+1 + shift (`float`, *optional*, defaults to 5.0): + Noise schedule shift parameter. Affects temporal dynamics + sample_solver (`str`, *optional*, defaults to 'unipc'): + Solver used to sample the video. + sampling_steps (`int`, *optional*, defaults to 40): + Number of diffusion sampling steps. Higher values improve quality but slow generation + guide_scale (`float`, *optional*, defaults 5.0): + Classifier-free guidance scale. Controls prompt adherence vs. creativity + n_prompt (`str`, *optional*, defaults to ""): + Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt` + seed (`int`, *optional*, defaults to -1): + Random seed for noise generation. If -1, use random seed. + offload_model (`bool`, *optional*, defaults to True): + If True, offloads models to CPU during generation to save VRAM + + Returns: + torch.Tensor: + Generated video frames tensor. 
Dimensions: (C, N H, W) where: + - C: Color channels (3 for RGB) + - N: Number of frames (81) + - H: Frame height (from size) + - W: Frame width from size) + """ + # preprocess + F = frame_num + target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1, + size[1] // self.vae_stride[1], + size[0] // self.vae_stride[2]) + + seq_len = math.ceil((target_shape[2] * target_shape[3]) / + (self.patch_size[1] * self.patch_size[2]) * + target_shape[1] / self.sp_size) * self.sp_size + + if n_prompt == "": + n_prompt = self.sample_neg_prompt + seed = seed if seed >= 0 else random.randint(0, sys.maxsize) + seed_g = torch.Generator(device=self.device) + seed_g.manual_seed(seed) + + if not self.t5_cpu: + self.text_encoder.model.to(self.device) + context = self.text_encoder([input_prompt], self.device) + context_null = self.text_encoder([n_prompt], self.device) + if offload_model: + self.text_encoder.model.cpu() + else: + context = self.text_encoder([input_prompt], torch.device('cpu')) + context_null = self.text_encoder([n_prompt], torch.device('cpu')) + context = [t.to(self.device) for t in context] + context_null = [t.to(self.device) for t in context_null] + + noise = [ + torch.randn( + target_shape[0], + target_shape[1], + target_shape[2], + target_shape[3], + dtype=torch.float32, + device=self.device, + generator=seed_g) + ] + + @contextmanager + def noop_no_sync(): + yield + + no_sync = getattr(self.model, 'no_sync', noop_no_sync) + + # evaluation mode + with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync(): + + if sample_solver == 'unipc': + sample_scheduler = FlowUniPCMultistepScheduler( + num_train_timesteps=self.num_train_timesteps, + shift=1, + use_dynamic_shifting=False) + sample_scheduler.set_timesteps( + sampling_steps, device=self.device, shift=shift) + timesteps = sample_scheduler.timesteps + elif sample_solver == 'dpm++': + sample_scheduler = FlowDPMSolverMultistepScheduler( + num_train_timesteps=self.num_train_timesteps, + shift=1, + use_dynamic_shifting=False) + sampling_sigmas = get_sampling_sigmas(sampling_steps, shift) + timesteps, _ = retrieve_timesteps( + sample_scheduler, + device=self.device, + sigmas=sampling_sigmas) + else: + raise NotImplementedError("Unsupported solver.") + + # sample videos + latents = noise + + arg_c = {'context': context, 'seq_len': seq_len, 'cond_flag': True} + arg_null = {'context': context_null, 'seq_len': seq_len, 'cond_flag': False} + + for _, t in enumerate(tqdm(timesteps)): + latent_model_input = latents + timestep = [t] + + timestep = torch.stack(timestep) + + self.model.to(self.device) + noise_pred_cond = self.model( + latent_model_input, t=timestep, **arg_c)[0] + noise_pred_uncond = self.model( + latent_model_input, t=timestep, **arg_null)[0] + + noise_pred = noise_pred_uncond + guide_scale * ( + noise_pred_cond - noise_pred_uncond) + + temp_x0 = sample_scheduler.step( + noise_pred.unsqueeze(0), + t, + latents[0].unsqueeze(0), + return_dict=False, + generator=seed_g)[0] + latents = [temp_x0.squeeze(0)] + + x0 = latents + if offload_model: + self.model.cpu() + torch.cuda.empty_cache() + if self.rank == 0: + videos = self.vae.decode(x0) + + del noise, latents + del sample_scheduler + if offload_model: + gc.collect() + torch.cuda.synchronize() + if dist.is_initialized(): + dist.barrier() + + return videos[0] if self.rank == 0 else None + + +# add a cond_flag +def i2v_generate(self, + input_prompt, + img, + max_area=720 * 1280, + frame_num=81, + shift=5.0, + sample_solver='unipc', + sampling_steps=40, + guide_scale=5.0, + 
n_prompt="", + seed=-1, + offload_model=True): + r""" + Generates video frames from input image and text prompt using diffusion process. + + Args: + input_prompt (`str`): + Text prompt for content generation. + img (PIL.Image.Image): + Input image tensor. Shape: [3, H, W] + max_area (`int`, *optional*, defaults to 720*1280): + Maximum pixel area for latent space calculation. Controls video resolution scaling + frame_num (`int`, *optional*, defaults to 81): + How many frames to sample from a video. The number should be 4n+1 + shift (`float`, *optional*, defaults to 5.0): + Noise schedule shift parameter. Affects temporal dynamics + [NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0. + sample_solver (`str`, *optional*, defaults to 'unipc'): + Solver used to sample the video. + sampling_steps (`int`, *optional*, defaults to 40): + Number of diffusion sampling steps. Higher values improve quality but slow generation + guide_scale (`float`, *optional*, defaults 5.0): + Classifier-free guidance scale. Controls prompt adherence vs. creativity + n_prompt (`str`, *optional*, defaults to ""): + Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt` + seed (`int`, *optional*, defaults to -1): + Random seed for noise generation. If -1, use random seed + offload_model (`bool`, *optional*, defaults to True): + If True, offloads models to CPU during generation to save VRAM + + Returns: + torch.Tensor: + Generated video frames tensor. Dimensions: (C, N H, W) where: + - C: Color channels (3 for RGB) + - N: Number of frames (81) + - H: Frame height (from max_area) + - W: Frame width from max_area) + """ + img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device) + + F = frame_num + h, w = img.shape[1:] + aspect_ratio = h / w + lat_h = round( + np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] // + self.patch_size[1] * self.patch_size[1]) + lat_w = round( + np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] // + self.patch_size[2] * self.patch_size[2]) + h = lat_h * self.vae_stride[1] + w = lat_w * self.vae_stride[2] + + max_seq_len = ((F - 1) // self.vae_stride[0] + 1) * lat_h * lat_w // ( + self.patch_size[1] * self.patch_size[2]) + max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size + + seed = seed if seed >= 0 else random.randint(0, sys.maxsize) + seed_g = torch.Generator(device=self.device) + seed_g.manual_seed(seed) + noise = torch.randn( + self.vae.model.z_dim, + (F - 1) // self.vae_stride[0] + 1, + lat_h, + lat_w, + dtype=torch.float32, + generator=seed_g, + device=self.device) + + msk = torch.ones(1, F, lat_h, lat_w, device=self.device) + msk[:, 1:] = 0 + msk = torch.concat([ + torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:] + ], + dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w) + msk = msk.transpose(1, 2)[0] + + if n_prompt == "": + n_prompt = self.sample_neg_prompt + + # preprocess + if not self.t5_cpu: + self.text_encoder.model.to(self.device) + context = self.text_encoder([input_prompt], self.device) + context_null = self.text_encoder([n_prompt], self.device) + if offload_model: + self.text_encoder.model.cpu() + else: + context = self.text_encoder([input_prompt], torch.device('cpu')) + context_null = self.text_encoder([n_prompt], torch.device('cpu')) + context = [t.to(self.device) for t in context] + context_null = [t.to(self.device) for t in context_null] + + self.clip.model.to(self.device) + clip_context = self.clip.visual([img[:, None, :, :]]) + if 
offload_model: + self.clip.model.cpu() + + y = self.vae.encode([ + torch.concat([ + torch.nn.functional.interpolate( + img[None].cpu(), size=(h, w), mode='bicubic').transpose( + 0, 1), + torch.zeros(3, F-1, h, w) + ], + dim=1).to(self.device) + ])[0] + y = torch.concat([msk, y]) + + @contextmanager + def noop_no_sync(): + yield + + no_sync = getattr(self.model, 'no_sync', noop_no_sync) + + # evaluation mode + with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync(): + + if sample_solver == 'unipc': + sample_scheduler = FlowUniPCMultistepScheduler( + num_train_timesteps=self.num_train_timesteps, + shift=1, + use_dynamic_shifting=False) + sample_scheduler.set_timesteps( + sampling_steps, device=self.device, shift=shift) + timesteps = sample_scheduler.timesteps + elif sample_solver == 'dpm++': + sample_scheduler = FlowDPMSolverMultistepScheduler( + num_train_timesteps=self.num_train_timesteps, + shift=1, + use_dynamic_shifting=False) + sampling_sigmas = get_sampling_sigmas(sampling_steps, shift) + timesteps, _ = retrieve_timesteps( + sample_scheduler, + device=self.device, + sigmas=sampling_sigmas) + else: + raise NotImplementedError("Unsupported solver.") + + # sample videos + latent = noise + + arg_c = { + 'context': [context[0]], + 'clip_fea': clip_context, + 'seq_len': max_seq_len, + 'y': [y], + 'cond_flag': True, + } + + arg_null = { + 'context': context_null, + 'clip_fea': clip_context, + 'seq_len': max_seq_len, + 'y': [y], + 'cond_flag': False, + } + + if offload_model: + torch.cuda.empty_cache() + + self.model.to(self.device) + for _, t in enumerate(tqdm(timesteps)): + latent_model_input = [latent.to(self.device)] + timestep = [t] + + timestep = torch.stack(timestep).to(self.device) + + noise_pred_cond = self.model( + latent_model_input, t=timestep, **arg_c)[0].to( + torch.device('cpu') if offload_model else self.device) + if offload_model: + torch.cuda.empty_cache() + noise_pred_uncond = self.model( + latent_model_input, t=timestep, **arg_null)[0].to( + torch.device('cpu') if offload_model else self.device) + if offload_model: + torch.cuda.empty_cache() + + noise_pred = noise_pred_uncond + guide_scale * ( + noise_pred_cond - noise_pred_uncond) + + latent = latent.to( + torch.device('cpu') if offload_model else self.device) + + temp_x0 = sample_scheduler.step( + noise_pred.unsqueeze(0), + t, + latent.unsqueeze(0), + return_dict=False, + generator=seed_g)[0] + latent = temp_x0.squeeze(0) + + x0 = [latent.to(self.device)] + del latent_model_input, timestep + + if offload_model: + self.model.cpu() + torch.cuda.empty_cache() + + if self.rank == 0: + videos = self.vae.decode(x0) + + del noise, latent + del sample_scheduler + if offload_model: + gc.collect() + torch.cuda.synchronize() + if dist.is_initialized(): + dist.barrier() + + return videos[0] if self.rank == 0 else None + + +def teacache_forward( + self, + x, + t, + context, + seq_len, + clip_fea=None, + y=None, + cond_flag=False, + ): + r""" + Forward pass through the diffusion model + + Args: + x (List[Tensor]): + List of input video tensors, each with shape [C_in, F, H, W] + t (Tensor): + Diffusion timesteps tensor of shape [B] + context (List[Tensor]): + List of text embeddings each with shape [L, C] + seq_len (`int`): + Maximum sequence length for positional encoding + clip_fea (Tensor, *optional*): + CLIP image features for image-to-video mode + y (List[Tensor], *optional*): + Conditional video inputs for image-to-video mode, same shape as x + + Returns: + List[Tensor]: + List of denoised video tensors with 
original input shapes [C_out, F, H / 8, W / 8] + """ + if self.model_type == 'i2v': + assert clip_fea is not None and y is not None + # params + device = self.patch_embedding.weight.device + if self.freqs.device != device: + self.freqs = self.freqs.to(device) + + if y is not None: + x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)] + + # embeddings + x = [self.patch_embedding(u.unsqueeze(0)) for u in x] + grid_sizes = torch.stack( + [torch.tensor(u.shape[2:], dtype=torch.long) for u in x]) + x = [u.flatten(2).transpose(1, 2) for u in x] + seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long) + assert seq_lens.max() <= seq_len + x = torch.cat([ + torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], + dim=1) for u in x + ]) + # time embeddings + with amp.autocast(dtype=torch.float32): + e = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, t).float()) + e0 = self.time_projection(e).unflatten(1, (6, self.dim)) + assert e.dtype == torch.float32 and e0.dtype == torch.float32 + + # context + context_lens = None + context = self.text_embedding( + torch.stack([ + torch.cat( + [u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) + for u in context + ])) + + if clip_fea is not None: + context_clip = self.img_emb(clip_fea) # bs x 257 x dim + context = torch.concat([context_clip, context], dim=1) + + # arguments + kwargs = dict( + e=e0, + seq_lens=seq_lens, + grid_sizes=grid_sizes, + freqs=self.freqs, + context=context, + context_lens=context_lens) + + if self.enable_teacache: + if cond_flag: + modulated_inp = e + if self.cnt == 0 or self.cnt == self.num_steps-1: + should_calc = True + self.accumulated_rel_l1_distance = 0 + else: + rescale_func = np.poly1d(self.coefficients) + if cond_flag: + self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()) + if self.accumulated_rel_l1_distance < self.rel_l1_thresh: + should_calc = False + else: + should_calc = True + self.accumulated_rel_l1_distance = 0 + self.previous_modulated_input = modulated_inp + self.cnt = 0 if self.cnt == self.num_steps-1 else self.cnt + 1 + self.should_calc = should_calc + else: + should_calc = self.should_calc + # if not cond_flag: + # self.cnt = 0 if self.cnt == self.num_steps-1 else self.cnt + 1 + + if self.enable_teacache: + if not should_calc: + x = x + self.previous_residual_cond if cond_flag else x + self.previous_residual_uncond + else: + ori_x = x.clone() + for block in self.blocks: + x = block(x, **kwargs) + if cond_flag: + self.previous_residual_cond = x - ori_x + else: + self.previous_residual_uncond = x - ori_x + else: + for block in self.blocks: + x = block(x, **kwargs) + + # head + x = self.head(x, e) + + # unpatchify + x = self.unpatchify(x, grid_sizes) + return [u.float() for u in x] + + +def generate(args): + rank = int(os.getenv("RANK", 0)) + world_size = int(os.getenv("WORLD_SIZE", 1)) + local_rank = int(os.getenv("LOCAL_RANK", 0)) + device = local_rank + _init_logging(rank) + + if args.offload_model is None: + args.offload_model = False if world_size > 1 else True + logging.info( + f"offload_model is not specified, set to {args.offload_model}.") + if world_size > 1: + torch.cuda.set_device(local_rank) + dist.init_process_group( + backend="nccl", + init_method="env://", + rank=rank, + world_size=world_size) + else: + assert not ( + args.t5_fsdp or args.dit_fsdp + ), f"t5_fsdp and dit_fsdp are not supported in non-distributed environments." 
+ assert not ( + args.ulysses_size > 1 or args.ring_size > 1 + ), f"context parallelism is not supported in non-distributed environments." + + if args.ulysses_size > 1 or args.ring_size > 1: + assert args.ulysses_size * args.ring_size == world_size, f"The product of ulysses_size and ring_size should be equal to the world size." + from xfuser.core.distributed import (initialize_model_parallel, + init_distributed_environment) + init_distributed_environment( + rank=dist.get_rank(), world_size=dist.get_world_size()) + + initialize_model_parallel( + sequence_parallel_degree=dist.get_world_size(), + ring_degree=args.ring_size, + ulysses_degree=args.ulysses_size, + ) + + if args.use_prompt_extend: + if args.prompt_extend_method == "dashscope": + prompt_expander = DashScopePromptExpander( + model_name=args.prompt_extend_model, is_vl="i2v" in args.task) + elif args.prompt_extend_method == "local_qwen": + prompt_expander = QwenPromptExpander( + model_name=args.prompt_extend_model, + is_vl="i2v" in args.task, + device=rank) + else: + raise NotImplementedError( + f"Unsupported prompt_extend_method: {args.prompt_extend_method}") + + cfg = WAN_CONFIGS[args.task] + if args.ulysses_size > 1: + assert cfg.num_heads % args.ulysses_size == 0, f"`num_heads` must be divisible by `ulysses_size`." + + logging.info(f"Generation job args: {args}") + logging.info(f"Generation model config: {cfg}") + + if dist.is_initialized(): + base_seed = [args.base_seed] if rank == 0 else [None] + dist.broadcast_object_list(base_seed, src=0) + args.base_seed = base_seed[0] + + if "t2v" in args.task or "t2i" in args.task: + if args.prompt is None: + args.prompt = EXAMPLE_PROMPT[args.task]["prompt"] + logging.info(f"Input prompt: {args.prompt}") + if args.use_prompt_extend: + logging.info("Extending prompt ...") + if rank == 0: + prompt_output = prompt_expander( + args.prompt, + tar_lang=args.prompt_extend_target_lang, + seed=args.base_seed) + if prompt_output.status == False: + logging.info( + f"Extending prompt failed: {prompt_output.message}") + logging.info("Falling back to original prompt.") + input_prompt = args.prompt + else: + input_prompt = prompt_output.prompt + input_prompt = [input_prompt] + else: + input_prompt = [None] + if dist.is_initialized(): + dist.broadcast_object_list(input_prompt, src=0) + args.prompt = input_prompt[0] + logging.info(f"Extended prompt: {args.prompt}") + + logging.info("Creating WanT2V pipeline.") + wan_t2v = wan.WanT2V( + config=cfg, + checkpoint_dir=args.ckpt_dir, + device_id=device, + rank=rank, + t5_fsdp=args.t5_fsdp, + dit_fsdp=args.dit_fsdp, + use_usp=(args.ulysses_size > 1 or args.ring_size > 1), + t5_cpu=args.t5_cpu, + ) + + # TeaCache + wan_t2v.__class__.generate = t2v_generate + wan_t2v.model.__class__.enable_teacache = True + wan_t2v.model.__class__.num_steps = args.sample_steps if args.sample_steps is not None else 50 + wan_t2v.model.__class__.rel_l1_thresh = args.teacache_thresh # timing notes (1.3B T2V): baseline ~2min54s; 0.05: ~1min55s (1.5x); 0.06: ~1min51s; 0.07: ~1min48s (1.6x); 0.08: ~1min27s (2x); 0.1: ~1min24s (2.1x); 0.15: ~1min6s + wan_t2v.model.__class__.accumulated_rel_l1_distance = 0 + wan_t2v.model.__class__.previous_modulated_input = None + wan_t2v.model.__class__.previous_residual_cond = None + wan_t2v.model.__class__.previous_residual_uncond = None + wan_t2v.model.__class__.should_calc = True + if '1.3B' in args.ckpt_dir: + wan_t2v.model.__class__.coefficients = [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01] + if '14B' in args.ckpt_dir: + wan_t2v.model.__class__.coefficients =
[-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404] + wan_t2v.model.__class__.forward = teacache_forward + + logging.info( + f"Generating {'image' if 't2i' in args.task else 'video'} ...") + video = wan_t2v.generate( + args.prompt, + size=SIZE_CONFIGS[args.size], + frame_num=args.frame_num, + shift=args.sample_shift, + sample_solver=args.sample_solver, + sampling_steps=args.sample_steps, + guide_scale=args.sample_guide_scale, + seed=args.base_seed, + offload_model=args.offload_model) + + else: + if args.prompt is None: + args.prompt = EXAMPLE_PROMPT[args.task]["prompt"] + if args.image is None: + args.image = EXAMPLE_PROMPT[args.task]["image"] + logging.info(f"Input prompt: {args.prompt}") + logging.info(f"Input image: {args.image}") + + img = Image.open(args.image).convert("RGB") + if args.use_prompt_extend: + logging.info("Extending prompt ...") + if rank == 0: + prompt_output = prompt_expander( + args.prompt, + tar_lang=args.prompt_extend_target_lang, + image=img, + seed=args.base_seed) + if prompt_output.status == False: + logging.info( + f"Extending prompt failed: {prompt_output.message}") + logging.info("Falling back to original prompt.") + input_prompt = args.prompt + else: + input_prompt = prompt_output.prompt + input_prompt = [input_prompt] + else: + input_prompt = [None] + if dist.is_initialized(): + dist.broadcast_object_list(input_prompt, src=0) + args.prompt = input_prompt[0] + logging.info(f"Extended prompt: {args.prompt}") + + logging.info("Creating WanI2V pipeline.") + wan_i2v = wan.WanI2V( + config=cfg, + checkpoint_dir=args.ckpt_dir, + device_id=device, + rank=rank, + t5_fsdp=args.t5_fsdp, + dit_fsdp=args.dit_fsdp, + use_usp=(args.ulysses_size > 1 or args.ring_size > 1), + t5_cpu=args.t5_cpu, + ) + + # TeaCache + wan_i2v.__class__.generate = i2v_generate + wan_i2v.model.__class__.enable_teacache = True + wan_i2v.model.__class__.num_steps = args.sample_steps if args.sample_steps is not None else 40 + wan_i2v.model.__class__.rel_l1_thresh = args.teacache_thresh # 12min 26s + wan_i2v.model.__class__.accumulated_rel_l1_distance = 0 + wan_i2v.model.__class__.previous_modulated_input = None + wan_i2v.model.__class__.previous_residual_cond = None + wan_i2v.model.__class__.previous_residual_uncond = None + wan_i2v.model.__class__.should_calc = True + if '480P' in args.ckpt_dir: + wan_i2v.model.__class__.coefficients = [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01] + if '720P' in args.ckpt_dir: + wan_i2v.model.__class__.coefficients = [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683] + wan_i2v.model.__class__.forward = teacache_forward + + logging.info("Generating video ...") + video = wan_i2v.generate( + args.prompt, + img, + max_area=MAX_AREA_CONFIGS[args.size], + frame_num=args.frame_num, + shift=args.sample_shift, + sample_solver=args.sample_solver, + sampling_steps=args.sample_steps, + guide_scale=args.sample_guide_scale, + seed=args.base_seed, + offload_model=args.offload_model) + + if rank == 0: + if args.save_file is None: + formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S") + formatted_prompt = args.prompt.replace(" ", "_").replace("/", + "_")[:50] + suffix = '.png' if "t2i" in args.task else '.mp4' + args.save_file = f"{args.task}_{args.size}_{args.ulysses_size}_{args.ring_size}_{formatted_prompt}_{formatted_time}" + suffix + + if "t2i" in args.task: + logging.info(f"Saving generated image to {args.save_file}") + cache_image( + tensor=video.squeeze(1)[None], + save_file=args.save_file, 
+ nrow=1, + normalize=True, + value_range=(-1, 1)) + else: + logging.info(f"Saving generated video to {args.save_file}") + cache_video( + tensor=video[None], + save_file=args.save_file, + fps=cfg.sample_fps, + nrow=1, + normalize=True, + value_range=(-1, 1)) + logging.info("Finished.") + + +if __name__ == "__main__": + args = _parse_args() + generate(args)
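For reference, the caching rule that `teacache_forward` applies (and that `--teacache_thresh` controls) reduces to a small check: the relative L1 change of the timestep embedding between consecutive conditional steps is rescaled by a fitted polynomial and accumulated, and the transformer blocks are recomputed only when the accumulated value reaches the threshold. The snippet below is a minimal sketch of that decision, assuming hypothetical stand-in embeddings; the coefficient set is the 1.3B T2V fit copied from `generate()` above, and the function name is illustrative only.

```python
import numpy as np
import torch

# Polynomial fit for the Wan2.1 1.3B T2V model (copied from generate() above);
# other model variants use the other coefficient sets assigned there.
COEFFICIENTS = [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01]
rescale = np.poly1d(COEFFICIENTS)

def should_recompute(curr_emb, prev_emb, accumulated, thresh):
    """Decide whether one conditional denoising step recomputes the transformer blocks."""
    rel_l1 = ((curr_emb - prev_emb).abs().mean() / prev_emb.abs().mean()).item()
    accumulated += float(rescale(rel_l1))
    if accumulated < thresh:
        return False, accumulated  # skip: reuse the cached block residual
    return True, 0.0               # recompute the blocks and reset the accumulator

# Hypothetical usage with stand-in timestep embeddings:
e_prev, e_curr = torch.randn(1, 1536), torch.randn(1, 1536)
calc, acc = should_recompute(e_curr, e_prev, accumulated=0.0, thresh=0.08)
print(calc, acc)
```

In `teacache_forward` itself, the first and last sampling steps always recompute, and the decision made on the conditional pass is reused for the matching unconditional pass via `self.should_calc`.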