mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 02:25:01 +08:00
[Misc] parametrize 'dtype' in test_flash_mla (#22641)
Signed-off-by: RUTHLESS-BOT <wujiafeng@cmbchina.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
parent
6534d2fc97
commit
53c730286c
@ -35,11 +35,10 @@ FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \
|
|||||||
@pytest.mark.parametrize("block_size", [64])
|
@pytest.mark.parametrize("block_size", [64])
|
||||||
@pytest.mark.parametrize("causal", [True])
|
@pytest.mark.parametrize("causal", [True])
|
||||||
@pytest.mark.parametrize("varlen", [False, True])
|
@pytest.mark.parametrize("varlen", [False, True])
|
||||||
|
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
|
def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
|
||||||
varlen):
|
varlen, dtype):
|
||||||
# TODO: parametrize using pytest
|
|
||||||
dtype = torch.bfloat16
|
|
||||||
device = torch.device("cuda:0")
|
device = torch.device("cuda:0")
|
||||||
torch.set_default_dtype(dtype)
|
torch.set_default_dtype(dtype)
|
||||||
torch.set_default_device(device)
|
torch.set_default_device(device)
|
||||||
@ -48,7 +47,7 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
|
|||||||
random.seed(0)
|
random.seed(0)
|
||||||
|
|
||||||
print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, "
|
print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, "
|
||||||
f"{d=}, {dv=}, {causal=}, {varlen=}")
|
f"{d=}, {dv=}, {causal=}, {varlen=}, {dtype=}")
|
||||||
|
|
||||||
cache_seqlens = torch.full((b, ), mean_sk, dtype=torch.int32)
|
cache_seqlens = torch.full((b, ), mean_sk, dtype=torch.int32)
|
||||||
if varlen:
|
if varlen:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user