temporary monkey patch for torch compile Windows bug
This commit is contained in:
parent
ca5dfdf79c
commit
aa30132268
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -1,5 +1,33 @@
|
||||
import json
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
#temporary patch to fix bug in Windows
|
||||
def patched_write_atomic(
|
||||
path_: str,
|
||||
content: Union[str, bytes],
|
||||
make_dirs: bool = False,
|
||||
encode_utf_8: bool = False,
|
||||
) -> None:
|
||||
# Write into temporary file first to avoid conflicts between threads
|
||||
# Avoid using a named temporary file, as those have restricted permissions
|
||||
from pathlib import Path
|
||||
import os
|
||||
import shutil
|
||||
import threading
|
||||
assert isinstance(
|
||||
content, (str, bytes)
|
||||
), "Only strings and byte arrays can be saved in the cache"
|
||||
path = Path(path_)
|
||||
if make_dirs:
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
tmp_path = path.parent / f".{os.getpid()}.{threading.get_ident()}.tmp"
|
||||
write_mode = "w" if isinstance(content, str) else "wb"
|
||||
with tmp_path.open(write_mode, encoding="utf-8" if encode_utf_8 else None) as f:
|
||||
f.write(content)
|
||||
shutil.copy2(src=tmp_path, dst=path)
|
||||
os.remove(tmp_path)
|
||||
import torch._inductor.codecache
|
||||
torch._inductor.codecache.write_atomic = patched_write_atomic
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@ -48,6 +76,8 @@ def unnormalize_latents(
|
||||
assert z.size(1) == mean.size(0) == std.size(0)
|
||||
return z * std.to(z) + mean.to(z)
|
||||
|
||||
|
||||
|
||||
def compute_packed_indices(
|
||||
N: int,
|
||||
text_mask: List[torch.Tensor],
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user