From dbb036cf612a3c9943254182af40597ec107be08 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Tue, 15 Apr 2025 01:35:38 -0400 Subject: [PATCH] [Bugfix] Fix tests/kernels/test_mamba_ssm_ssd.py (#16623) Signed-off-by: Tyler Michael Smith --- tests/kernels/test_mamba_ssm_ssd.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/tests/kernels/test_mamba_ssm_ssd.py b/tests/kernels/test_mamba_ssm_ssd.py index 8f23a9b216e98..ee908105f557f 100644 --- a/tests/kernels/test_mamba_ssm_ssd.py +++ b/tests/kernels/test_mamba_ssm_ssd.py @@ -5,6 +5,8 @@ import torch import torch.nn.functional as F from einops import rearrange, repeat +from vllm.model_executor.layers.mamba.mamba2_metadata import ( + _seq_idx_to_chunk_indices_offsets) from vllm.model_executor.layers.mamba.ops.ssd_combined import ( mamba_chunk_scan_combined) from vllm.platforms import current_platform @@ -160,14 +162,14 @@ def generate_continous_batched_examples(example_lens_by_batch, # get the metadata cu_seqlens = torch.tensor((0, ) + spec, device=device).cumsum(dim=0) - sed_idx = torch.zeros(cu_seqlens[-1], + seq_idx = torch.zeros(cu_seqlens[-1], dtype=torch.int32, device=cu_seqlens.device) for i, (srt, end) in enumerate(zip( cu_seqlens, cu_seqlens[1:], )): - sed_idx[srt:end] = i + seq_idx[srt:end] = i # for cont batch if IND_E is None: @@ -177,7 +179,7 @@ def generate_continous_batched_examples(example_lens_by_batch, IND_E = [end_boundary(x + y) for x, y in zip(IND_S, spec)] yield ([Y_min[s, IND_S[s]:IND_E[s]] for s in range(num_examples)], - cu_seqlens, sed_idx.unsqueeze(0), (A, dt2, X2, B2, C2)) + cu_seqlens, seq_idx.unsqueeze(0), (A, dt2, X2, B2, C2)) @pytest.mark.parametrize("itype", @@ -266,12 +268,15 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases, exhausted: dict = {} # map: eg -> boolean indicating example is exhausted states = None - for Y_min, cu_seqlens, sed_idx, (A, dt, X, B, + for Y_min, cu_seqlens, seq_idx, (A, dt, X, B, C) in generate_continous_batched_examples( cases, num_examples, seqlen, last_taken, exhausted, n_heads, d_head, itype): + chunk_indices, chunk_offsets = _seq_idx_to_chunk_indices_offsets( + seq_idx, chunk_size) + Y, new_states = mamba_chunk_scan_combined( X, dt, @@ -281,7 +286,9 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases, chunk_size, D=None, cu_seqlens=cu_seqlens, - seq_idx=sed_idx, + seq_idx=seq_idx, + chunk_indices=chunk_indices, + chunk_offsets=chunk_offsets, return_varlen_states=True, initial_states=states, )