diff --git a/mochi_preview/dit/joint_model/asymm_models_joint.py b/mochi_preview/dit/joint_model/asymm_models_joint.py index 68b49b0..ba0d67d 100644 --- a/mochi_preview/dit/joint_model/asymm_models_joint.py +++ b/mochi_preview/dit/joint_model/asymm_models_joint.py @@ -43,7 +43,7 @@ except ImportError: SAGEATTN_IS_AVAILABLE = False backends = [] -if torch.cuda.get_device_properties(0).major <= 7.5: +if torch.cuda.get_device_properties(0).major < 7: backends.append(SDPBackend.MATH) if torch.cuda.get_device_properties(0).major >= 9.0: backends.append(SDPBackend.CUDNN_ATTENTION)