diff --git a/tests/kernels/mamba/test_mamba_ssm_ssd.py b/tests/kernels/mamba/test_mamba_ssm_ssd.py index 67b14a7faa89f..d2b893ffff7c3 100644 --- a/tests/kernels/mamba/test_mamba_ssm_ssd.py +++ b/tests/kernels/mamba/test_mamba_ssm_ssd.py @@ -187,7 +187,7 @@ def generate_continuous_batched_examples(example_lens_by_batch, [torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.parametrize("n_heads", [3, 4, 11, 16, 32]) @pytest.mark.parametrize("d_head", [5, 8, 19, 32, 128]) -@pytest.mark.parametrize("seq_len_chunk_size", [(119, 17), (128, 32)]) +@pytest.mark.parametrize("seq_len_chunk_size", [(112, 16), (128, 32)]) def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size, itype): @@ -253,15 +253,15 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size, (8, 8, 16, 32, 16), ]), # mode examples with varied lengths - # odd chunk_size - (64, 29, 2, [(11, 4), (13, 23), (19, 22), - (21, 15)]), # irregular sizes - # large-ish chunk_size (256) (64, 256, 1, [(5, ), (1, ), (1, ), (1, )]), # irregular sizes with small sequences (64, 256, 2, [(5, 30), (1, 2), (1, 2), (1, 2)]), # irregular sizes with small sequences + + # we also need to test some large seqlen + # to catch errors with init states decay + (768, 128, 2, [(138, 225), (138, 225)]), ]) def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases, itype): @@ -271,10 +271,9 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases, seqlen, chunk_size, num_examples, cases = seq_len_chunk_size_cases - # TODO: the irregular chunk size cases have some issues and require higher - # tolerance. This is to be invesigated - if chunk_size not in {8, 256}: - atol, rtol = 5e-1, 5e-1 + # This test can have larger error for longer sequences + if seqlen > 256: + atol, rtol = 1e-2, 5e-3 else: atol, rtol = 5e-3, 5e-3 diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py index fc2b3b25fd0a8..365139e237c66 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py @@ -290,10 +290,8 @@ def _chunk_scan_fwd_kernel( # get the cs at the offset boundary # - c_off == 0 is a passthrough dA_cs_m_boundary = tl.load( - dA_cumsum_ptr + - (pid_m * BLOCK_SIZE_M + c_off - 1) * stride_dA_cs_csize, - mask=(((pid_m * BLOCK_SIZE_M + c_off - 1) > -1) - and ((pid_m * BLOCK_SIZE_M + c_off) < chunk_size)), + dA_cumsum_ptr + (c_off - 1) * stride_dA_cs_csize, + mask=(((c_off - 1) > -1) and ((c_off) < chunk_size)), other=0.0).to(tl.float32) if HAS_SEQ_IDX: diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py index ad2853a3d8a8b..fd74cb837290b 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -21,6 +21,10 @@ from .ssd_state_passing import _state_passing_fwd TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0') +def is_int_pow_2(n): + return isinstance(n, int) and n > 0 and (n & (n - 1)) == 0 + + def _mamba_chunk_scan_combined_fwd(x, dt, A, @@ -38,6 +42,7 @@ def _mamba_chunk_scan_combined_fwd(x, dt_softplus=False, dt_limit=(0.0, float("inf")), out=None): + assert is_int_pow_2(chunk_size), "chunk_size must be integer power of 2" batch, seqlen, nheads, headdim = x.shape _, _, ngroups, dstate = B.shape assert nheads % ngroups == 0