diff --git a/tests/kernels/attention/test_attention_selector.py b/tests/kernels/attention/test_attention_selector.py
index 436cb430817e3..58da01f0ebbf3 100644
--- a/tests/kernels/attention/test_attention_selector.py
+++ b/tests/kernels/attention/test_attention_selector.py
@@ -188,8 +188,9 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
         m.setenv(STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL)
 
         # Unsupported CUDA arch
-        monkeypatch.setattr(torch.cuda, "get_device_capability", lambda:
-                            (7, 5))
+        monkeypatch.setattr(torch.cuda,
+                            "get_device_capability",
+                            lambda _=None: (7, 5))
 
         backend = get_attn_backend(16, torch.float16, None, 16, False)
         assert backend.get_name() != STR_FLASH_ATTN_VAL
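
Why the signature change matters (a hedged sketch, not code from vLLM itself): the real torch.cuda.get_device_capability(device=None) takes an optional device argument, so a zero-argument lambda raises TypeError if any call site passes a device. Replacing it with lambda _=None: (7, 5) tolerates both call forms. The fake_capability name below is hypothetical, chosen only for illustration:

    # Sketch assuming callers may pass an optional device, matching the
    # real signature torch.cuda.get_device_capability(device=None).
    import torch

    def fake_capability(_=None):
        # Report a Turing GPU (compute capability 7.5), an arch the
        # FlashAttention backend rejects, so the selector falls back.
        return (7, 5)

    torch.cuda.get_device_capability = fake_capability
    assert torch.cuda.get_device_capability() == (7, 5)          # zero-arg call
    assert torch.cuda.get_device_capability("cuda:0") == (7, 5)  # device-arg call

The default-argument form is a common pattern when monkeypatching: it keeps the stub compatible with every arity the production code might use, without depending on how the patched function happens to be called today.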