temporary monkey patch for torch compile Windows bug

This commit is contained in:
kijai 2024-10-25 19:41:10 +03:00
parent ca5dfdf79c
commit aa30132268
4 changed files with 31 additions and 1 deletions

View File

@ -1,5 +1,33 @@
import json
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Union
#temporary patch to fix bug in Windows
def patched_write_atomic(
path_: str,
content: Union[str, bytes],
make_dirs: bool = False,
encode_utf_8: bool = False,
) -> None:
# Write into temporary file first to avoid conflicts between threads
# Avoid using a named temporary file, as those have restricted permissions
from pathlib import Path
import os
import shutil
import threading
assert isinstance(
content, (str, bytes)
), "Only strings and byte arrays can be saved in the cache"
path = Path(path_)
if make_dirs:
path.parent.mkdir(parents=True, exist_ok=True)
tmp_path = path.parent / f".{os.getpid()}.{threading.get_ident()}.tmp"
write_mode = "w" if isinstance(content, str) else "wb"
with tmp_path.open(write_mode, encoding="utf-8" if encode_utf_8 else None) as f:
f.write(content)
shutil.copy2(src=tmp_path, dst=path)
os.remove(tmp_path)
import torch._inductor.codecache
torch._inductor.codecache.write_atomic = patched_write_atomic
import torch
import torch.nn.functional as F
@ -48,6 +76,8 @@ def unnormalize_latents(
assert z.size(1) == mean.size(0) == std.size(0)
return z * std.to(z) + mean.to(z)
def compute_packed_indices(
N: int,
text_mask: List[torch.Tensor],