Merge 02d15cc85f5f34322938786788ac1ac2b6cf5961 into fd271dedfde6e192a1f1a025521070876e89e04a

Kacper Michajłow, 2025-12-08 18:22:17 +08:00, committed by GitHub
commit d9d1597ba8
2 changed files with 3 additions and 8 deletions


@@ -335,7 +335,7 @@ def vae_attention():
     if model_management.xformers_enabled_vae():
         logging.info("Using xformers attention in VAE")
         return xformers_attention
-    elif model_management.pytorch_attention_enabled_vae():
+    elif model_management.pytorch_attention_enabled():
         logging.info("Using pytorch attention in VAE")
         return pytorch_attention
     else:
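
The change above routes VAE attention selection through the generic pytorch_attention_enabled() check instead of the VAE-specific wrapper that the second file deletes. A minimal sketch of the resulting branch order, assuming the names shown in the hunk (the final else branch is elided in the diff, so its body is a placeholder here):

    def vae_attention():
        if model_management.xformers_enabled_vae():
            logging.info("Using xformers attention in VAE")
            return xformers_attention
        elif model_management.pytorch_attention_enabled():  # was pytorch_attention_enabled_vae()
            logging.info("Using pytorch attention in VAE")
            return pytorch_attention
        else:
            ...  # fallback branch elided in the diff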


@@ -354,8 +354,8 @@ try:
         if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
             ENABLE_PYTORCH_ATTENTION = True
         if rocm_version >= (7, 0):
-            if any((a in arch) for a in ["gfx1201"]):
+            if any((a in arch) for a in ["gfx1200", "gfx1201"]):
                 ENABLE_PYTORCH_ATTENTION = True
         if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4):
             if any((a in arch) for a in ["gfx1200", "gfx1201", "gfx950"]): # TODO: more arches, "gfx942" gives error on pytorch nightly 2.10 1013 rocm7.0
                 SUPPORT_FP8_OPS = True
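
Taken on its own, the gating above now enables PyTorch attention for gfx1200 in addition to gfx1201 once ROCm 7.0 is detected. A standalone, runnable sketch with illustrative values (arch and the version tuples are examples, not values read from a real device):

    # Illustrative inputs; in model_management these come from the detected device.
    arch = "gfx1200"
    rocm_version = (7, 0)
    torch_version_numeric = (2, 7)

    ENABLE_PYTORCH_ATTENTION = False
    SUPPORT_FP8_OPS = False

    if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]):
        ENABLE_PYTORCH_ATTENTION = True
    if rocm_version >= (7, 0):
        if any((a in arch) for a in ["gfx1200", "gfx1201"]):  # gfx1200 newly included
            ENABLE_PYTORCH_ATTENTION = True
    if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4):
        if any((a in arch) for a in ["gfx1200", "gfx1201", "gfx950"]):
            SUPPORT_FP8_OPS = True

    print(ENABLE_PYTORCH_ATTENTION, SUPPORT_FP8_OPS)  # -> True True for this example
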
@@ -1221,11 +1221,6 @@ def pytorch_attention_enabled():
     global ENABLE_PYTORCH_ATTENTION
     return ENABLE_PYTORCH_ATTENTION
 
-def pytorch_attention_enabled_vae():
-    if is_amd():
-        return False # enabling pytorch attention on AMD currently causes crash when doing high res
-    return pytorch_attention_enabled()
-
 def pytorch_attention_flash_attention():
     global ENABLE_PYTORCH_ATTENTION
     if ENABLE_PYTORCH_ATTENTION:
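
With pytorch_attention_enabled_vae() removed (it unconditionally returned False on AMD because PyTorch attention in the VAE used to crash at high resolutions there), pytorch_attention_enabled() is the single check left. A hypothetical migration for any out-of-tree call site, assuming model_management is imported as in the first file:

    # use_pt = model_management.pytorch_attention_enabled_vae()  # removed in this commit
    use_pt = model_management.pytorch_attention_enabled()  # generic check, no AMD carve-out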