diff --git a/mochi_preview/t2v_synth_mochi.py b/mochi_preview/t2v_synth_mochi.py index 33cfaa7..7504afa 100644 --- a/mochi_preview/t2v_synth_mochi.py +++ b/mochi_preview/t2v_synth_mochi.py @@ -24,10 +24,13 @@ def patched_write_atomic( write_mode = "w" if isinstance(content, str) else "wb" with tmp_path.open(write_mode, encoding="utf-8" if encode_utf_8 else None) as f: f.write(content) - shutil.copy2(src=tmp_path, dst=path) + shutil.copy2(src=tmp_path, dst=path) #to allow overwriting cache files os.remove(tmp_path) -import torch._inductor.codecache -torch._inductor.codecache.write_atomic = patched_write_atomic +try: + import torch._inductor.codecache + torch._inductor.codecache.write_atomic = patched_write_atomic +except: + pass import torch import torch.nn.functional as F