diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py
index 681a55db5..97c274f99 100644
--- a/comfy/ldm/modules/diffusionmodules/model.py
+++ b/comfy/ldm/modules/diffusionmodules/model.py
@@ -335,7 +335,7 @@ def vae_attention():
     if model_management.xformers_enabled_vae():
         logging.info("Using xformers attention in VAE")
         return xformers_attention
-    elif model_management.pytorch_attention_enabled_vae():
+    elif model_management.pytorch_attention_enabled():
         logging.info("Using pytorch attention in VAE")
         return pytorch_attention
     else:
diff --git a/comfy/model_management.py b/comfy/model_management.py
index 40717b1e4..57831b0bc 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -354,8 +354,8 @@ try:
                 if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]):  # TODO: more arches, TODO: gfx950
                     ENABLE_PYTORCH_ATTENTION = True
                 if rocm_version >= (7, 0):
-                    if any((a in arch) for a in ["gfx1201"]):
-                        ENABLE_PYTORCH_ATTENTION = True
+                    if any((a in arch) for a in ["gfx1200", "gfx1201"]):
+                        ENABLE_PYTORCH_ATTENTION = True
         if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4):
             if any((a in arch) for a in ["gfx1200", "gfx1201", "gfx950"]): # TODO: more arches, "gfx942" gives error on pytorch nightly 2.10 1013 rocm7.0
                 SUPPORT_FP8_OPS = True
@@ -1221,11 +1221,6 @@ def pytorch_attention_enabled():
     global ENABLE_PYTORCH_ATTENTION
     return ENABLE_PYTORCH_ATTENTION
 
-def pytorch_attention_enabled_vae():
-    if is_amd():
-        return False # enabling pytorch attention on AMD currently causes crash when doing high res
-    return pytorch_attention_enabled()
-
 def pytorch_attention_flash_attention():
     global ENABLE_PYTORCH_ATTENTION
     if ENABLE_PYTORCH_ATTENTION:
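
For reference, a minimal sketch of the ROCm arch gating that the second hunk changes: with ROCm >= 7.0, gfx1200 now enables PyTorch attention alongside gfx1201. The helper below is hypothetical and not part of the PR; it condenses the logic and ignores the surrounding version and argument gates in `model_management.py`.

```python
# Hypothetical, condensed version of the arch check touched by this diff
# (not a function that exists in comfy.model_management).
def rocm_pytorch_attention_default(arch: str, rocm_version: tuple) -> bool:
    # Arches that already defaulted to PyTorch attention before this change.
    enable = any(a in arch for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"])
    # On ROCm 7.0+, gfx1200 is now included next to gfx1201.
    if rocm_version >= (7, 0):
        enable = enable or any(a in arch for a in ["gfx1200", "gfx1201"])
    return enable

# Example: a gfx1200 card defaults to PyTorch attention on ROCm 7.0 but not on 6.4.
assert rocm_pytorch_attention_default("gfx1200", (7, 0)) is True
assert rocm_pytorch_attention_default("gfx1200", (6, 4)) is False
```

The first and third hunks then let the VAE use this default directly: `vae_attention()` calls `pytorch_attention_enabled()` instead of the removed `pytorch_attention_enabled_vae()`, which had unconditionally disabled PyTorch attention in the VAE on AMD.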