From 485e1d69245c87dba07b5d6d2fa54451803392f5 Mon Sep 17 00:00:00 2001 From: SHYuanBest Date: Wed, 25 Dec 2024 22:11:06 +0800 Subject: [PATCH] add consisid --- README.md | 11 +- TeaCache4ConsisID/README.md | 56 +++++ TeaCache4ConsisID/teacache_sample_video.py | 276 +++++++++++++++++++++ requirements.txt | 8 +- 4 files changed, 342 insertions(+), 9 deletions(-) create mode 100644 TeaCache4ConsisID/README.md create mode 100644 TeaCache4ConsisID/teacache_sample_video.py diff --git a/README.md b/README.md index 7e8ef69..5b37ac0 100644 --- a/README.md +++ b/README.md @@ -54,6 +54,7 @@ ![visualization](./assets/tisser.png) ## Latest News 🔥 +- [2024/12/25] 🔥 Support [ConsisID](https://github.com/PKU-YuanGroup/ConsisID). - [2024/12/24] 🔥 Support [HunyuanVideo](https://github.com/Tencent/HunyuanVideo). - [2024/12/19] 🔥 Support [CogVideoX](https://github.com/THUDM/CogVideo). - [2024/12/06] 🎉 Release the [code](https://github.com/LiewFeng/TeaCache) of TeaCache. Support [Open-Sora](https://github.com/hpcaitech/Open-Sora), [Open-Sora-Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan) and [Latte](https://github.com/Vchitect/Latte). @@ -65,6 +66,10 @@ We introduce Timestep Embedding Aware Cache (TeaCache), a training-free caching ## TeaCache for HunyuanVideo Please refer to [TeaCache4HunyuanVideo](./TeaCache4HunyuanVideo/README.md). +## TeaCache for ConsisID + +Please refer to [TeaCache4ConsisID](./TeaCache4ConsisID/README.md). + ## Installation Prerequisites: @@ -121,10 +126,6 @@ python vbench/cal_vbench.py --score_dir bbb python common_metrics/eval.py --gt_video_dir aa --generated_video_dir bb ``` - - - - ## Citation If you find TeaCache is useful in your research or applications, please consider giving us a star 🌟 and citing it by the following BibTeX entry. @@ -139,4 +140,4 @@ If you find TeaCache is useful in your research or applications, please consider ## Acknowledgement -This repository is built based on [VideoSys](https://github.com/NUS-HPC-AI-Lab/VideoSys), [Open-Sora](https://github.com/hpcaitech/Open-Sora), [Open-Sora-Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan), [Latte](https://github.com/Vchitect/Latte), [CogVideoX](https://github.com/THUDM/CogVideo) and [HunyuanVideo](https://github.com/Tencent/HunyuanVideo). Thanks for their contributions! +This repository is built based on [VideoSys](https://github.com/NUS-HPC-AI-Lab/VideoSys), [Open-Sora](https://github.com/hpcaitech/Open-Sora), [Open-Sora-Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan), [Latte](https://github.com/Vchitect/Latte), [CogVideoX](https://github.com/THUDM/CogVideo), [HunyuanVideo](https://github.com/Tencent/HunyuanVideo) and [ConsisID](https://github.com/PKU-YuanGroup/ConsisID). Thanks for their contributions! diff --git a/TeaCache4ConsisID/README.md b/TeaCache4ConsisID/README.md new file mode 100644 index 0000000..1777c10 --- /dev/null +++ b/TeaCache4ConsisID/README.md @@ -0,0 +1,56 @@ + +# TeaCache4ConsisID + +[TeaCache](https://github.com/LiewFeng/TeaCache) can speedup [ConsisID](https://github.com/PKU-YuanGroup/ConsisID) 2x without much visual quality degradation, in a training-free manner. 
+ +## 📈 Inference Latency Comparisons on a Single H100 GPU + +| ConsisID | TeaCache (0.1) | TeaCache (0.15) | TeaCache (0.20) | +| :------: | :------------: | :-------------: | :-------------: | +| ~110 s | ~70 s | ~53 s | ~41 s | + + +## Usage + +Follow [ConsisID](https://github.com/PKU-YuanGroup/ConsisID) to clone the repo and finish the installation, then you can modify the `rel_l1_thresh` to obtain your desired trade-off between latency and visul quality, and change the `ckpts_path`, `prompt`, `image` to customize your identity-preserving video. + +For single-gpu inference, you can use the following command: + +```bash +cd TeaCache4ConsisID + +python3 teacache_sample_video.py \ + --rel_l1_thresh 0.1 \ + --ckpts_path BestWishYsh/ConsisID-preview \ + --image "https://github.com/PKU-YuanGroup/ConsisID/blob/main/asserts/example_images/2.png?raw=true" \ + --prompt "The video captures a boy walking along a city street, filmed in black and white on a classic 35mm camera. His expression is thoughtful, his brow slightly furrowed as if he's lost in contemplation. The film grain adds a textured, timeless quality to the image, evoking a sense of nostalgia. Around him, the cityscape is filled with vintage buildings, cobblestone sidewalks, and softly blurred figures passing by, their outlines faint and indistinct. Streetlights cast a gentle glow, while shadows play across the boy\'s path, adding depth to the scene. The lighting highlights the boy\'s subtle smile, hinting at a fleeting moment of curiosity. The overall cinematic atmosphere, complete with classic film still aesthetics and dramatic contrasts, gives the scene an evocative and introspective feel." \ + --seed 42 \ + --num_infer_steps 50 \ + --output_path ./teacache_results +``` + +To generate a video with 8 GPUs, you can use the following [here](https://github.com/PKU-YuanGroup/ConsisID/tree/main/parallel_inference). + +## Resources + +Learn more about ConsisID with the following resources. +- A [video](https://www.youtube.com/watch?v=PhlgC-bI5SQ) demonstrating ConsisID's main features. +- The research paper, [Identity-Preserving Text-to-Video Generation by Frequency Decomposition](https://hf.co/papers/2411.17440) for more details. + +## Citation + +If you find TeaCache is useful in your research or applications, please consider giving us a star 🌟 and citing it by the following BibTeX entry. + +``` +@article{liu2024timestep, + title={Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model}, + author={Liu, Feng and Zhang, Shiwei and Wang, Xiaofeng and Wei, Yujie and Qiu, Haonan and Zhao, Yuzhong and Zhang, Yingya and Ye, Qixiang and Wan, Fang}, + journal={arXiv preprint arXiv:2411.19108}, + year={2024} +} +``` + + +## Acknowledgements + +We would like to thank the contributors to the [ConsisID](https://github.com/PKU-YuanGroup/ConsisID). 
diff --git a/TeaCache4ConsisID/teacache_sample_video.py b/TeaCache4ConsisID/teacache_sample_video.py new file mode 100644 index 0000000..a56c367 --- /dev/null +++ b/TeaCache4ConsisID/teacache_sample_video.py @@ -0,0 +1,276 @@ +import os +import argparse +import numpy as np +from typing import Any, Dict, Optional, Tuple, Union + +import torch + +from diffusers import ConsisIDPipeline +from diffusers.pipelines.consisid.consisid_utils import prepare_face_models, process_face_embeddings_infer +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from diffusers.utils import export_to_video +from huggingface_hub import snapshot_download + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def teacache_forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + timestep: Union[int, float, torch.LongTensor], + timestep_cond: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + id_cond: Optional[torch.Tensor] = None, + id_vit_hidden: Optional[torch.Tensor] = None, + return_dict: bool = True, +): + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." + ) + + # fuse clip and insightface + if self.is_train_face: + assert id_cond is not None and id_vit_hidden is not None + id_cond = id_cond.to(device=hidden_states.device, dtype=hidden_states.dtype) + id_vit_hidden = [ + tensor.to(device=hidden_states.device, dtype=hidden_states.dtype) for tensor in id_vit_hidden + ] + valid_face_emb = self.local_facial_extractor( + id_cond, id_vit_hidden + ) # torch.Size([1, 1280]), list[5](torch.Size([1, 577, 1024])) -> torch.Size([1, 32, 2048]) + + batch_size, num_frames, channels, height, width = hidden_states.shape + + # 1. Time embedding + timesteps = timestep + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=hidden_states.dtype) + emb = self.time_embedding(t_emb, timestep_cond) + + # 2. 
Patch embedding + # torch.Size([1, 226, 4096]) torch.Size([1, 13, 32, 60, 90]) + hidden_states = self.patch_embed(encoder_hidden_states, hidden_states) # torch.Size([1, 17776, 3072]) + hidden_states = self.embedding_dropout(hidden_states) # torch.Size([1, 17776, 3072]) + + text_seq_length = encoder_hidden_states.shape[1] + encoder_hidden_states = hidden_states[:, :text_seq_length] # torch.Size([1, 226, 3072]) + hidden_states = hidden_states[:, text_seq_length:] # torch.Size([1, 17550, 3072]) + + if self.enable_teacache: + if self.cnt == 0 or self.cnt == self.num_steps-1: + should_calc = True + self.accumulated_rel_l1_distance = 0 + else: + coefficients = [-1.53880483e+03, 8.43202495e+02, -1.34363087e+02, 7.97131516e+00, -5.23162339e-02] + rescale_func = np.poly1d(coefficients) + self.accumulated_rel_l1_distance += rescale_func(((emb-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()) + if self.accumulated_rel_l1_distance < self.rel_l1_thresh: + should_calc = False + else: + should_calc = True + self.accumulated_rel_l1_distance = 0 + self.previous_modulated_input = emb + self.cnt = 0 if self.cnt == self.num_steps-1 else self.cnt + 1 + + if self.enable_teacache: + if not should_calc: + hidden_states += self.previous_residual + encoder_hidden_states += self.previous_residual_encoder + else: + ori_hidden_states = hidden_states.clone() + ori_encoder_hidden_states = encoder_hidden_states.clone() + # 3. Transformer blocks + ca_idx = 0 + for i, block in enumerate(self.transformer_blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + emb, + image_rotary_emb, + **ckpt_kwargs, + ) + else: + hidden_states, encoder_hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=emb, + image_rotary_emb=image_rotary_emb, + ) + + if self.is_train_face: + if i % self.cross_attn_interval == 0 and valid_face_emb is not None: + hidden_states = hidden_states + self.local_face_scale * self.perceiver_cross_attention[ca_idx]( + valid_face_emb, hidden_states + ) # torch.Size([2, 32, 2048]) torch.Size([2, 17550, 3072]) + ca_idx += 1 + + self.previous_residual = hidden_states - ori_hidden_states + self.previous_residual_encoder = encoder_hidden_states - ori_encoder_hidden_states + else: + # 3. 
Transformer blocks + ca_idx = 0 + for i, block in enumerate(self.transformer_blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + emb, + image_rotary_emb, + **ckpt_kwargs, + ) + else: + hidden_states, encoder_hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=emb, + image_rotary_emb=image_rotary_emb, + ) + + if self.is_train_face: + if i % self.cross_attn_interval == 0 and valid_face_emb is not None: + hidden_states = hidden_states + self.local_face_scale * self.perceiver_cross_attention[ca_idx]( + valid_face_emb, hidden_states + ) # torch.Size([2, 32, 2048]) torch.Size([2, 17550, 3072]) + ca_idx += 1 + + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + hidden_states = self.norm_final(hidden_states) + hidden_states = hidden_states[:, text_seq_length:] + + # 4. Final block + hidden_states = self.norm_out(hidden_states, temb=emb) + hidden_states = self.proj_out(hidden_states) + + # 5. Unpatchify + # Note: we use `-1` instead of `channels`: + # - It is okay to `channels` use for ConsisID (number of input channels is equal to output channels) + p = self.config.patch_size + output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p) + output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (output,) + return Transformer2DModelOutput(sample=output) + + +def main(args): + seed = args.seed + num_infer_steps = args.num_infer_steps + output_path = args.output_path + rel_l1_thresh = args.rel_l1_thresh # higher speedup will cause to worse quality -- 0.1 for 1.6x speedup -- 0.15 for 2.1x speedup -- 0.2 for 2.5x speedup + ckpts_path = args.ckpts_path + # ConsisID works well with long and well-described prompts. Make sure the face in the image is clearly visible (e.g., preferably half-body or full-body). 
+ prompt = args.prompt + image = args.image + + if not os.path.exists(ckpts_path): + print("Base Model not found, downloading from Hugging Face...") + snapshot_download(repo_id="BestWishYsh/ConsisID-preview", local_dir=ckpts_path) + else: + print(f"Base Model already exists in {ckpts_path}, skipping download.") + + if not os.path.exists(output_path): + os.makedirs(output_path, exist_ok=True) + + face_helper_1, face_helper_2, face_clip_model, face_main_model, eva_transform_mean, eva_transform_std = ( + prepare_face_models(ckpts_path, device="cuda", dtype=torch.bfloat16) + ) + pipe = ConsisIDPipeline.from_pretrained(ckpts_path, torch_dtype=torch.bfloat16) + pipe.to("cuda") + + id_cond, id_vit_hidden, image, face_kps = process_face_embeddings_infer( + face_helper_1, + face_clip_model, + face_helper_2, + eva_transform_mean, + eva_transform_std, + face_main_model, + "cuda", + torch.bfloat16, + image, + is_align_face=True, + ) + + # TeaCache Config + pipe.transformer.__class__.enable_teacache = True + pipe.transformer.__class__.cnt = 0 + pipe.transformer.__class__.num_steps = num_infer_steps - 1 + pipe.transformer.__class__.rel_l1_thresh = rel_l1_thresh # 0.1 for 1.6x speedup -- 0.15 for 2.1x speedup -- 0.2 for 2.5x speedup + pipe.transformer.__class__.accumulated_rel_l1_distance = 0 + pipe.transformer.__class__.previous_modulated_input = None + pipe.transformer.__class__.previous_residual = None + pipe.transformer.__class__.previous_residual_encoder = None + pipe.transformer.__class__.forward = teacache_forward + + video = pipe( + image=image, + prompt=prompt, + num_inference_steps=num_infer_steps, + guidance_scale=6.0, + use_dynamic_cfg=False, + id_vit_hidden=id_vit_hidden, + id_cond=id_cond, + kps_cond=face_kps, + generator=torch.Generator("cuda").manual_seed(seed), + ) + file_count = len([f for f in os.listdir(output_path) if os.path.isfile(os.path.join(output_path, f))]) + video_path = f"{output_path}/{seed}_{file_count:04d}.mp4" + export_to_video(video.frames[0], video_path, fps=8) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Run ConsisID with given parameters") + + parser.add_argument('--seed', type=int, default=42, help='Random seed') + parser.add_argument('--num_infer_steps', type=int, default=50, help='Number of inference steps') + parser.add_argument("--output_path", type=str, default="./teacache_results", help="The path where the generated video will be saved") + # higher speedup will cause to worse quality -- 0.1 for 1.6x speedup -- 0.15 for 2.1x speedup -- 0.2 for 2.5x speedup + parser.add_argument('--rel_l1_thresh', type=float, default=0.1, help='Higher speedup will cause to worse quality -- 0.1 for 1.6x speedup -- 0.15 for 2.1x speedup -- 0.2 for 2.5x speedup') + parser.add_argument('--ckpts_path', type=str, default="/storage/ysh/Code/ID_Consistency/Code/2_offen_codes/0_temp_hf/ConsisID/BestWishYsh/ConsisID-preview", help='Path to checkpoint') + # ConsisID works well with long and well-described prompts. Make sure the face in the image is clearly visible (e.g., preferably half-body or full-body). + parser.add_argument('--prompt', type=str, default="The video captures a boy walking along a city street, filmed in black and white on a classic 35mm camera. His expression is thoughtful, his brow slightly furrowed as if he's lost in contemplation. The film grain adds a textured, timeless quality to the image, evoking a sense of nostalgia. 
Around him, the cityscape is filled with vintage buildings, cobblestone sidewalks, and softly blurred figures passing by, their outlines faint and indistinct. Streetlights cast a gentle glow, while shadows play across the boy\'s path, adding depth to the scene. The lighting highlights the boy\'s subtle smile, hinting at a fleeting moment of curiosity. The overall cinematic atmosphere, complete with classic film still aesthetics and dramatic contrasts, gives the scene an evocative and introspective feel.", help='Description of the scene for the model to interpret') + parser.add_argument('--image', type=str, default="/storage/ysh/Code/ID_Consistency/Code/2_offen_codes/ConsisID_upload/asserts/example_images/2.png", help='URL or path to input image') + args = parser.parse_args() + + main(args) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 3b8a98d..cc05d6c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,8 @@ accelerate>0.17.0 bs4 click -colossalai==0.4.0 -diffusers==0.30.0 +colossalai +diffusers einops fabric ftfy @@ -22,5 +22,5 @@ sentencepiece timm torch>=1.13 tqdm -peft==0.13.2 -transformers==4.39.3 +peft +transformers
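
Note on the requirements change above: the hard pins (`diffusers==0.30.0`, `peft==0.13.2`, `transformers==4.39.3`) are relaxed because `teacache_sample_video.py` imports `ConsisIDPipeline` and `diffusers.pipelines.consisid.consisid_utils`, which releases predating ConsisID support do not ship. A quick sanity check for the installed environment (a sketch, not part of the patch):

```python
# Confirm the installed diffusers build exposes the ConsisID components used by
# teacache_sample_video.py; releases that predate ConsisID support raise ImportError here.
import diffusers
from diffusers import ConsisIDPipeline
from diffusers.pipelines.consisid.consisid_utils import prepare_face_models, process_face_embeddings_infer

print(f"diffusers {diffusers.__version__}: ConsisID components available")
```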