diff --git a/mochi_preview/__pycache__/t2v_synth_mochi.cpython-312.pyc b/mochi_preview/__pycache__/t2v_synth_mochi.cpython-312.pyc
deleted file mode 100644
index 1be05b3..0000000
Binary files a/mochi_preview/__pycache__/t2v_synth_mochi.cpython-312.pyc and /dev/null differ
diff --git a/mochi_preview/dit/joint_model/__pycache__/asymm_models_joint.cpython-312.pyc b/mochi_preview/dit/joint_model/__pycache__/asymm_models_joint.cpython-312.pyc
deleted file mode 100644
index 68ad0e8..0000000
Binary files a/mochi_preview/dit/joint_model/__pycache__/asymm_models_joint.cpython-312.pyc and /dev/null differ
diff --git a/mochi_preview/dit/joint_model/__pycache__/mod_rmsnorm.cpython-312.pyc b/mochi_preview/dit/joint_model/__pycache__/mod_rmsnorm.cpython-312.pyc
deleted file mode 100644
index 4de68dd..0000000
Binary files a/mochi_preview/dit/joint_model/__pycache__/mod_rmsnorm.cpython-312.pyc and /dev/null differ
diff --git a/mochi_preview/t2v_synth_mochi.py b/mochi_preview/t2v_synth_mochi.py
index ab92689..33cfaa7 100644
--- a/mochi_preview/t2v_synth_mochi.py
+++ b/mochi_preview/t2v_synth_mochi.py
@@ -1,5 +1,33 @@
 import json
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Union
+
+#temporary patch to fix bug in Windows
+def patched_write_atomic(
+    path_: str,
+    content: Union[str, bytes],
+    make_dirs: bool = False,
+    encode_utf_8: bool = False,
+) -> None:
+    # Write into temporary file first to avoid conflicts between threads
+    # Avoid using a named temporary file, as those have restricted permissions
+    from pathlib import Path
+    import os
+    import shutil
+    import threading
+    assert isinstance(
+        content, (str, bytes)
+    ), "Only strings and byte arrays can be saved in the cache"
+    path = Path(path_)
+    if make_dirs:
+        path.parent.mkdir(parents=True, exist_ok=True)
+    tmp_path = path.parent / f".{os.getpid()}.{threading.get_ident()}.tmp"
+    write_mode = "w" if isinstance(content, str) else "wb"
+    with tmp_path.open(write_mode, encoding="utf-8" if encode_utf_8 else None) as f:
+        f.write(content)
+    shutil.copy2(src=tmp_path, dst=path)
+    os.remove(tmp_path)
+import torch._inductor.codecache
+torch._inductor.codecache.write_atomic = patched_write_atomic
 
 import torch
 import torch.nn.functional as F
@@ -48,6 +76,8 @@ def unnormalize_latents(
     assert z.size(1) == mean.size(0) == std.size(0)
     return z * std.to(z) + mean.to(z)
 
+
 def compute_packed_indices(
     N: int,
     text_mask: List[torch.Tensor],