[CI] Retry flaky fp8 cutlass mla tests (#24536)

Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Nick Hill 2025-09-09 20:33:10 -07:00 committed by GitHub
parent 41f160b974
commit 7e7db04310
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -49,7 +49,13 @@ CUTLASS_MLA_UNSUPPORTED_REASON = \
@pytest.mark.parametrize("block_size", [64])
@pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("varlen", [False, True])
@pytest.mark.parametrize("torch_dtype", [torch.bfloat16, torch.float8_e4m3fn])
@pytest.mark.parametrize(
"torch_dtype",
[
torch.bfloat16,
# fp8 can have occasional precision-related failures.
pytest.param(torch.float8_e4m3fn, marks=pytest.mark.flaky(reruns=2))
])
@torch.inference_mode()
def test_cutlass_mla_decode(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size,
causal, varlen, torch_dtype):