temporary monkey patch for torch compile Windows bug
This commit is contained in:
parent
ca5dfdf79c
commit
aa30132268
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -1,5 +1,33 @@
|
|||||||
import json
|
import json
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
|
#temporary patch to fix bug in Windows
|
||||||
|
def patched_write_atomic(
|
||||||
|
path_: str,
|
||||||
|
content: Union[str, bytes],
|
||||||
|
make_dirs: bool = False,
|
||||||
|
encode_utf_8: bool = False,
|
||||||
|
) -> None:
|
||||||
|
# Write into temporary file first to avoid conflicts between threads
|
||||||
|
# Avoid using a named temporary file, as those have restricted permissions
|
||||||
|
from pathlib import Path
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import threading
|
||||||
|
assert isinstance(
|
||||||
|
content, (str, bytes)
|
||||||
|
), "Only strings and byte arrays can be saved in the cache"
|
||||||
|
path = Path(path_)
|
||||||
|
if make_dirs:
|
||||||
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
tmp_path = path.parent / f".{os.getpid()}.{threading.get_ident()}.tmp"
|
||||||
|
write_mode = "w" if isinstance(content, str) else "wb"
|
||||||
|
with tmp_path.open(write_mode, encoding="utf-8" if encode_utf_8 else None) as f:
|
||||||
|
f.write(content)
|
||||||
|
shutil.copy2(src=tmp_path, dst=path)
|
||||||
|
os.remove(tmp_path)
|
||||||
|
import torch._inductor.codecache
|
||||||
|
torch._inductor.codecache.write_atomic = patched_write_atomic
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@ -48,6 +76,8 @@ def unnormalize_latents(
|
|||||||
assert z.size(1) == mean.size(0) == std.size(0)
|
assert z.size(1) == mean.size(0) == std.size(0)
|
||||||
return z * std.to(z) + mean.to(z)
|
return z * std.to(z) + mean.to(z)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def compute_packed_indices(
|
def compute_packed_indices(
|
||||||
N: int,
|
N: int,
|
||||||
text_mask: List[torch.Tensor],
|
text_mask: List[torch.Tensor],
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user