diff --git a/tests/kernels/mamba/test_causal_conv1d.py b/tests/kernels/mamba/test_causal_conv1d.py index d9023490d7fc..4647b97c4771 100644 --- a/tests/kernels/mamba/test_causal_conv1d.py +++ b/tests/kernels/mamba/test_causal_conv1d.py @@ -183,7 +183,7 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, ity assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) -@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16]) @pytest.mark.parametrize("silu_activation", [False, True]) @pytest.mark.parametrize("has_bias", [False, True]) @pytest.mark.parametrize("seqlen", [1, 3]) @@ -265,7 +265,7 @@ def test_causal_conv1d_update_with_batch_gather( @pytest.mark.parametrize("silu_activation", [True]) @pytest.mark.parametrize("has_bias", [True]) @pytest.mark.parametrize("width", [4]) -@pytest.mark.parametrize("seqlen", [8, 30, 249, 2049, 4096]) +@pytest.mark.parametrize("seqlen", [8, 249, 4096]) @pytest.mark.parametrize("dim", [64, 4096]) @pytest.mark.parametrize("with_padding", [True, False]) @pytest.mark.parametrize("batch", [4, 10]) diff --git a/tests/kernels/mamba/test_mamba_mixer2.py b/tests/kernels/mamba/test_mamba_mixer2.py index d23daefa7b43..25934c409744 100644 --- a/tests/kernels/mamba/test_mamba_mixer2.py +++ b/tests/kernels/mamba/test_mamba_mixer2.py @@ -25,7 +25,6 @@ from vllm.utils import update_environment_variables (64, 1), (64, 2), (64, 4), # hidden_size be divisible by num_gpus - (100, 5), # and n_groups must divide hidden_size ], ) @pytest.mark.parametrize("dtype", [torch.float16]) diff --git a/tests/kernels/mamba/test_mamba_ssm.py b/tests/kernels/mamba/test_mamba_ssm.py index 9a6137239ebf..c59fc7af0c89 100644 --- a/tests/kernels/mamba/test_mamba_ssm.py +++ b/tests/kernels/mamba/test_mamba_ssm.py @@ -229,8 +229,8 @@ def selective_scan_opcheck_fn( @pytest.mark.parametrize("wtype", [torch.float32]) -@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("seqlen", [128, 256, 512, 1024, 2048, 4096]) +@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16]) +@pytest.mark.parametrize("seqlen", [128, 1024, 4096]) @pytest.mark.parametrize("has_delta_bias", [True]) @pytest.mark.parametrize("delta_softplus", [True]) @pytest.mark.parametrize("has_z", [True]) @@ -238,7 +238,7 @@ def selective_scan_opcheck_fn( @pytest.mark.parametrize("varBC_groups", [1, 2]) @pytest.mark.parametrize("is_variable_C", [True]) @pytest.mark.parametrize("is_variable_B", [True]) -@pytest.mark.parametrize("scan_chunks", [1, 2, 3]) +@pytest.mark.parametrize("scan_chunks", [1, 3]) def test_selective_scan( is_variable_B, is_variable_C, @@ -375,9 +375,9 @@ def test_selective_scan( ) -@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16]) @pytest.mark.parametrize("has_z", [False, True]) -@pytest.mark.parametrize("dstate", [16, 32, 64]) +@pytest.mark.parametrize("dstate", [16, 64]) @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) def test_selective_state_update(dim, dstate, has_z, itype): device = "cuda" @@ -413,7 +413,7 @@ def test_selective_state_update(dim, dstate, has_z, itype): @pytest.mark.parametrize("wtype", [torch.float32]) @pytest.mark.parametrize("itype", [torch.float32]) -@pytest.mark.parametrize("seqlen", [1, 128, 129, 256, 512, 1024, 2048, 4096]) +@pytest.mark.parametrize("seqlen", [1, 256, 1024, 4096]) @pytest.mark.parametrize("return_last_state", [True]) @pytest.mark.parametrize("has_delta_bias", [True]) @pytest.mark.parametrize("delta_softplus", [True]) @@ -589,9 +589,9 @@ def test_selective_scan_varlen( ) -@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16]) @pytest.mark.parametrize("has_z", [True]) -@pytest.mark.parametrize("dstate", [16, 32, 64]) +@pytest.mark.parametrize("dstate", [16, 64]) @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) # tests correctness in case subset of the sequences are padded @pytest.mark.parametrize("with_padding", [True, False]) @@ -679,11 +679,11 @@ def test_selective_state_update_with_batch_indices( assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol) -@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16]) @pytest.mark.parametrize("has_z", [False, True]) @pytest.mark.parametrize("tie_hdim", [False, True]) -@pytest.mark.parametrize("ngroups", [1, 2, 4]) -@pytest.mark.parametrize("dstate", [16, 32, 64]) +@pytest.mark.parametrize("ngroups", [1, 4]) +@pytest.mark.parametrize("dstate", [16, 64]) @pytest.mark.parametrize("dim", [2048, 4096]) def test_selective_state_update_with_heads_with_batch_indices( dim, dstate, ngroups, has_z, tie_hdim, itype diff --git a/tests/kernels/mamba/test_mamba_ssm_ssd.py b/tests/kernels/mamba/test_mamba_ssm_ssd.py index 57dcb789e97b..0b0b82e484a1 100644 --- a/tests/kernels/mamba/test_mamba_ssm_ssd.py +++ b/tests/kernels/mamba/test_mamba_ssm_ssd.py @@ -188,9 +188,9 @@ def generate_continuous_batched_examples( ) -@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("n_heads", [3, 4, 11, 16, 32]) -@pytest.mark.parametrize("d_head", [5, 8, 19, 32, 128]) +@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16]) +@pytest.mark.parametrize("n_heads", [4, 16, 32]) +@pytest.mark.parametrize("d_head", [5, 8, 32, 128]) @pytest.mark.parametrize("seq_len_chunk_size", [(112, 16), (128, 32)]) def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size, itype): # this tests the kernels on a single example (bs=1) @@ -254,15 +254,14 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size, it ) -@pytest.mark.parametrize("itype", [torch.float32, torch.float16]) -@pytest.mark.parametrize("n_heads", [4, 8, 13]) -@pytest.mark.parametrize("d_head", [5, 16, 21, 32]) +@pytest.mark.parametrize("itype", [torch.float32]) +@pytest.mark.parametrize("n_heads", [4, 8]) +@pytest.mark.parametrize("d_head", [5, 16, 32]) @pytest.mark.parametrize( "seq_len_chunk_size_cases", [ # small-ish chunk_size (8) (64, 8, 2, [(64, 32), (64, 32)]), - (64, 8, 2, [(32, 32), (32, 32), (32, 32)]), (64, 8, 2, [(8, 8), (8, 8), (8, 8)]), # chunk size boundary ( 64, @@ -270,16 +269,7 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size, it 2, [(4, 4), (4, 4), (4, 4), (4, 4)], ), # chunk_size larger than cont batches - ( - 64, - 8, - 5, - [ - (64, 32, 16, 8, 8), - (8, 16, 32, 16, 8), - (8, 8, 16, 32, 16), - ], - ), # mode examples with varied lengths + (64, 8, 5, [(64, 32, 16, 8, 8)]), # large-ish chunk_size (256) (64, 256, 1, [(5,), (1,), (1,), (1,)]), # irregular sizes with small sequences ( @@ -359,11 +349,7 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases, @pytest.mark.parametrize("chunk_size", [8, 256]) @pytest.mark.parametrize( "seqlens", - [ - (16, 2, 8, 13), - (270, 88, 212, 203), - (16, 20), - ], + [(16, 20), (270, 88, 212, 203)], ) def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens): # This test verifies the correctness of the chunked prefill implementation