mirror of
https://git.datalinker.icu/comfyanonymous/ComfyUI
synced 2025-12-08 21:44:33 +08:00
Set OCL_SET_SVM_SIZE on AMD.
This commit is contained in:
parent
913f86b727
commit
a28293eba7
@ -63,18 +63,22 @@ def cuda_malloc_supported():
|
||||
return True
|
||||
|
||||
|
||||
version = ""
|
||||
|
||||
try:
|
||||
torch_spec = importlib.util.find_spec("torch")
|
||||
for folder in torch_spec.submodule_search_locations:
|
||||
ver_file = os.path.join(folder, "version.py")
|
||||
if os.path.isfile(ver_file):
|
||||
spec = importlib.util.spec_from_file_location("torch_version_import", ver_file)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
version = module.__version__
|
||||
except:
|
||||
pass
|
||||
|
||||
if not args.cuda_malloc:
|
||||
try:
|
||||
version = ""
|
||||
torch_spec = importlib.util.find_spec("torch")
|
||||
for folder in torch_spec.submodule_search_locations:
|
||||
ver_file = os.path.join(folder, "version.py")
|
||||
if os.path.isfile(ver_file):
|
||||
spec = importlib.util.spec_from_file_location("torch_version_import", ver_file)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
version = module.__version__
|
||||
|
||||
if int(version[0]) >= 2 and "+cu" in version: # enable by default for torch version 2.0 and up only on cuda torch
|
||||
if PerformanceFeature.AutoTune not in args.fast: # Autotune has issues with cuda malloc
|
||||
args.cuda_malloc = cuda_malloc_supported()
|
||||
@ -90,3 +94,6 @@ if args.cuda_malloc and not args.disable_cuda_malloc:
|
||||
env_var += ",backend:cudaMallocAsync"
|
||||
|
||||
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = env_var
|
||||
|
||||
def get_torch_version_noimport():
|
||||
return str(version)
|
||||
|
||||
3
main.py
3
main.py
@ -167,6 +167,9 @@ if __name__ == "__main__":
|
||||
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"
|
||||
|
||||
import cuda_malloc
|
||||
if "rocm" in cuda_malloc.get_torch_version_noimport():
|
||||
os.environ['OCL_SET_SVM_SIZE'] = '262144' # set at the request of AMD
|
||||
|
||||
|
||||
if 'torch' in sys.modules:
|
||||
logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user