[Misc] parametrize 'dtype' in test_flash_mla (#22641)

Signed-off-by: RUTHLESS-BOT <wujiafeng@cmbchina.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
RUTHLESS-BOT authored 2025-08-13 04:31:48 +08:00, committed by GitHub
parent 6534d2fc97
commit 53c730286c


@@ -35,11 +35,10 @@ FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \
 @pytest.mark.parametrize("block_size", [64])
 @pytest.mark.parametrize("causal", [True])
 @pytest.mark.parametrize("varlen", [False, True])
+@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
 @torch.inference_mode()
 def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
-                   varlen):
-    # TODO: parametrize using pytest
-    dtype = torch.bfloat16
+                   varlen, dtype):
     device = torch.device("cuda:0")
     torch.set_default_dtype(dtype)
     torch.set_default_device(device)
@@ -48,7 +47,7 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
     random.seed(0)
     print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, "
-          f"{d=}, {dv=}, {causal=}, {varlen=}")
+          f"{d=}, {dv=}, {causal=}, {varlen=}, {dtype=}")
     cache_seqlens = torch.full((b, ), mean_sk, dtype=torch.int32)
     if varlen:
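
For context, a minimal sketch (not part of this commit; the test name and assertion are illustrative) of how a parametrized dtype argument reaches the test body: pytest generates one test case per value, so stacking it with the existing varlen parametrization yields the full cross product (bfloat16/float16 x varlen False/True), which is what replaces the hard-coded dtype removed above.

    # illustrative_parametrize_sketch.py (hypothetical, standalone)
    import pytest
    import torch


    @pytest.mark.parametrize("varlen", [False, True])
    @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
    def test_dtype_is_injected(varlen, dtype):
        # Each parametrized run sets the default dtype, mirroring the
        # torch.set_default_dtype(dtype) call in test_flash_mla.
        torch.set_default_dtype(dtype)
        x = torch.ones(2, 2)
        assert x.dtype == dtype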