import json
import os
import subprocess
import tempfile
import time

import click
import numpy as np
#import ray
from einops import rearrange
from PIL import Image
from tqdm import tqdm

from mochi_preview.t2v_synth_mochi import T2VSynthMochiModel

# Lazily-initialized global model handle and the default weights directory.
model = None
model_path = "weights"


def noexcept(f):
    # Run f and swallow any exception; used in the (commented-out) Ray path
    # below to drain failed workers without masking the primary error.
    try:
        return f()
    except Exception:
        pass


# class MochiWrapper:
#     def __init__(self, *, num_workers, **actor_kwargs):
#         super().__init__()
#         RemoteClass = ray.remote(T2VSynthMochiModel)
#         self.workers = [
#             RemoteClass.options(num_gpus=1).remote(
#                 device_id=0, world_size=num_workers, local_rank=i, **actor_kwargs
#             )
#             for i in range(num_workers)
#         ]
#         # Ensure the __init__ method has finished on all workers
#         for worker in self.workers:
#             ray.get(worker.__ray_ready__.remote())
#         self.is_loaded = True
#
#     def __call__(self, args):
#         work_refs = [
#             worker.run.remote(args, i == 0) for i, worker in enumerate(self.workers)
#         ]
#
#         try:
#             for result in work_refs[0]:
#                 yield ray.get(result)
#
#             # Handle the (very unlikely) edge-case where a worker that's not the 1st one
#             # fails (don't want an uncaught error)
#             for result in work_refs[1:]:
#                 ray.get(result)
#         except Exception as e:
#             # Get exception from other workers
#             for ref in work_refs[1:]:
#                 noexcept(lambda: ray.get(ref))
#             raise e


def set_model_path(path):
    global model_path
    model_path = path


def load_model():
    # Construct the single-GPU model on first use; subsequent calls reuse it.
    global model, model_path
    if model is None:
        #ray.init()
        MOCHI_DIR = model_path
        VAE_CHECKPOINT_PATH = f"{MOCHI_DIR}/mochi_preview_vae_bf16.safetensors"
        MODEL_CONFIG_PATH = f"{MOCHI_DIR}/dit-config.yaml"
        MODEL_CHECKPOINT_PATH = f"{MOCHI_DIR}/mochi_preview_dit_fp8_e4m3fn.safetensors"

        model = T2VSynthMochiModel(
            device_id=0,
            world_size=1,
            local_rank=0,
            vae_stats_path=f"{MOCHI_DIR}/vae_stats.json",
            vae_checkpoint_path=VAE_CHECKPOINT_PATH,
            dit_config_path=MODEL_CONFIG_PATH,
            dit_checkpoint_path=MODEL_CHECKPOINT_PATH,
        )


def linear_quadratic_schedule(num_steps, threshold_noise, linear_steps=None):
    # Build an increasing schedule on [0, 1]: linear up to threshold_noise over
    # the first linear_steps steps, then quadratic (continuously joined) up to
    # 1.0. Finally flip it so the returned sigmas descend from 1.0 to 0.0.
    if linear_steps is None:
        linear_steps = num_steps // 2
    linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)]
    threshold_noise_step_diff = linear_steps - threshold_noise * num_steps
    quadratic_steps = num_steps - linear_steps
    quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps ** 2)
    linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps ** 2)
    const = quadratic_coef * (linear_steps ** 2)
    quadratic_sigma_schedule = [
        quadratic_coef * (i ** 2) + linear_coef * i + const
        for i in range(linear_steps, num_steps)
    ]
    sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule + [1.0]
    sigma_schedule = [1.0 - x for x in sigma_schedule]
    return sigma_schedule
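
# Worked example (hand-computed from the formulas above, for illustration):
# linear_quadratic_schedule(8, 0.025) returns 9 values, monotonically
# decreasing from 1.0 to 0.0 -- linear near 1.0, quadratic below:
#   [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0]
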
def generate_video(
    prompt,
    negative_prompt,
    width,
    height,
    num_frames,
    seed,
    cfg_scale,
    num_inference_steps,
):
    load_model()

    # sigma_schedule should be a list of floats of length (num_inference_steps + 1),
    # such that sigma_schedule[0] == 1.0, sigma_schedule[-1] == 0.0, and the
    # values are monotonically decreasing.
    sigma_schedule = linear_quadratic_schedule(num_inference_steps, 0.025)

    # cfg_schedule should be a list of floats of length num_inference_steps.
    # For simplicity, we just use the same cfg scale at all timesteps,
    # but more optimal schedules may use varying cfg, e.g.:
    # [5.0] * (num_inference_steps // 2) + [4.5] * (num_inference_steps // 2)
    cfg_schedule = [cfg_scale] * num_inference_steps

    args = {
        "height": height,
        "width": width,
        "num_frames": num_frames,
        "mochi_args": {
            "sigma_schedule": sigma_schedule,
            "cfg_schedule": cfg_schedule,
            "num_inference_steps": num_inference_steps,
            "batch_cfg": False,
        },
        "prompt": [prompt],
        "negative_prompt": [negative_prompt],
        "seed": seed,
    }

    # Stream intermediate results; the final iteration leaves the finished
    # video in final_frames.
    final_frames = None
    for cur_progress, frames, finished in tqdm(model.run(args, stream_results=True), total=num_inference_steps + 1):
        final_frames = frames

    assert isinstance(final_frames, np.ndarray)
    assert final_frames.dtype == np.float32

    # (time, batch, h, w, c) -> (batch, time, h, w, c); keep the only batch item.
    final_frames = rearrange(final_frames, "t b h w c -> b t h w c")
    final_frames = final_frames[0]

    os.makedirs("outputs", exist_ok=True)
    output_path = os.path.join("outputs", f"output_{int(time.time())}.mp4")

    with tempfile.TemporaryDirectory() as tmpdir:
        # Write each float32 frame as an 8-bit PNG, then stitch them into an
        # MP4 with ffmpeg at 30 fps.
        for i, frame in enumerate(final_frames):
            frame = (frame * 255).astype(np.uint8)
            frame_img = Image.fromarray(frame)
            frame_img.save(os.path.join(tmpdir, f"frame_{i:04d}.png"))

        frame_pattern = os.path.join(tmpdir, "frame_%04d.png")
        subprocess.run(
            ["ffmpeg", "-y", "-r", "30", "-i", frame_pattern,
             "-vcodec", "libx264", "-pix_fmt", "yuv420p", output_path],
            check=True,
        )

    # Save the generation settings alongside the video.
    json_path = os.path.splitext(output_path)[0] + ".json"
    with open(json_path, "w") as f:
        json.dump(args, f, indent=4)

    return output_path
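
# Programmatic use (a sketch; values are illustrative, and the weights are
# assumed to live in the default "weights" directory unless set_model_path
# is called first):
#
#   set_model_path("weights")
#   video_path = generate_video("a corgi surfing", "", 848, 480, 31, 12345, 4.5, 64)
#   print(video_path)
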
@click.command()
@click.option(
    "--prompt",
    default="""
a high-motion drone POV flying at high speed through a vast desert environment, with dynamic camera movements capturing sweeping sand dunes,
rocky terrain, and the occasional dry brush. The camera smoothly glides over the rugged landscape, weaving between towering rock formations and
diving low across the sand. As the drone zooms forward, the motion gradually slows down, shifting into a close-up, hyper-detailed shot of a spider
resting on a sunlit rock. The scene emphasizes cinematic motion, natural lighting, and intricate texture details on both the rock and the spider’s body,
with a shallow depth of field to focus on the fine details of the spider’s legs and the rough surface beneath it. The atmosphere should feel immersive and alive,
with the wind subtly blowing sand grains across the frame.""",
    required=False,
    help="Prompt for video generation.",
)
@click.option(
    "--negative_prompt", default="", help="Negative prompt for video generation."
)
@click.option("--width", default=848, type=int, help="Width of the video.")
@click.option("--height", default=480, type=int, help="Height of the video.")
@click.option("--num_frames", default=163, type=int, help="Number of frames.")
@click.option("--seed", default=12345, type=int, help="Random seed.")
@click.option("--cfg_scale", default=4.5, type=float, help="CFG scale.")
@click.option(
    "--num_steps", default=64, type=int, help="Number of inference steps."
)
@click.option("--model_dir", required=True, help="Path to the model directory.")
def generate_cli(
    prompt,
    negative_prompt,
    width,
    height,
    num_frames,
    seed,
    cfg_scale,
    num_steps,
    model_dir,
):
    set_model_path(model_dir)
    output = generate_video(
        prompt,
        negative_prompt,
        width,
        height,
        num_frames,
        seed,
        cfg_scale,
        num_steps,
    )
    click.echo(f"Video generated at: {output}")


if __name__ == "__main__":
    generate_cli()
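
# Example invocation (the script name is illustrative; the flags are the
# click options defined above):
#   python cli.py --model_dir weights --num_steps 64 --prompt "a corgi surfing"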