mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 01:15:26 +08:00
[Misc] parametrize 'dtype' in test_flash_mla (#22641)
Signed-off-by: RUTHLESS-BOT <wujiafeng@cmbchina.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
parent
6534d2fc97
commit
53c730286c
@ -35,11 +35,10 @@ FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \
|
||||
@pytest.mark.parametrize("block_size", [64])
|
||||
@pytest.mark.parametrize("causal", [True])
|
||||
@pytest.mark.parametrize("varlen", [False, True])
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
|
||||
@torch.inference_mode()
|
||||
def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
|
||||
varlen):
|
||||
# TODO: parametrize using pytest
|
||||
dtype = torch.bfloat16
|
||||
varlen, dtype):
|
||||
device = torch.device("cuda:0")
|
||||
torch.set_default_dtype(dtype)
|
||||
torch.set_default_device(device)
|
||||
@ -48,7 +47,7 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
|
||||
random.seed(0)
|
||||
|
||||
print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, "
|
||||
f"{d=}, {dv=}, {causal=}, {varlen=}")
|
||||
f"{d=}, {dv=}, {causal=}, {varlen=}, {dtype=}")
|
||||
|
||||
cache_seqlens = torch.full((b, ), mean_sk, dtype=torch.int32)
|
||||
if varlen:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user