diff --git a/tests/kernels/test_flash_attn.py b/tests/kernels/test_flash_attn.py index 0d3edc5d2aaf..374ba0afb5f4 100644 --- a/tests/kernels/test_flash_attn.py +++ b/tests/kernels/test_flash_attn.py @@ -126,7 +126,7 @@ def test_flash_attn_with_paged_kv( scale=scale, soft_cap=soft_cap, ) - assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \ + assert torch.allclose(output, ref_output, atol=2e-2, rtol=1e-2), \ f"{torch.max(torch.abs(output - ref_output))}"