diff --git a/mochi_preview/dit/joint_model/asymm_models_joint.py b/mochi_preview/dit/joint_model/asymm_models_joint.py index e812869..9aeb81a 100644 --- a/mochi_preview/dit/joint_model/asymm_models_joint.py +++ b/mochi_preview/dit/joint_model/asymm_models_joint.py @@ -46,7 +46,9 @@ except ImportError: backends = [] if torch.cuda.get_device_properties(0).major < 7: backends.append(SDPBackend.MATH) + backends.append(SDPBackend.EFFICIENT_ATTENTION) if torch.cuda.get_device_properties(0).major >= 9.0: + backends.append(SDPBackend.EFFICIENT_ATTENTION) backends.append(SDPBackend.CUDNN_ATTENTION) else: backends.append(SDPBackend.EFFICIENT_ATTENTION)