diff --git a/tests/kernels/mamba/test_mamba_ssm.py b/tests/kernels/mamba/test_mamba_ssm.py
index 98edc959957d0..50e48aad6ebaa 100644
--- a/tests/kernels/mamba/test_mamba_ssm.py
+++ b/tests/kernels/mamba/test_mamba_ssm.py
@@ -425,6 +425,80 @@ def test_selective_state_update(dim, dstate, has_z, itype):
     assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
 
 
+@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16])
+@pytest.mark.parametrize("has_z", [False, True])
+@pytest.mark.parametrize("dstate", [16, 64])
+@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
+@pytest.mark.parametrize("max_seq_len", [1, 2, 4])
+def test_selective_state_update_varlen(dim, dstate, has_z, itype, max_seq_len):
+    device = "cuda"
+    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
+    if itype == torch.bfloat16:
+        rtol, atol = 5e-2, 1.5e-1
+    if torch.version.hip:
+        atol *= 2
+    # set seed
+    current_platform.seed_everything(0)
+    batch_size = 4
+    token_counts = torch.randint(1, max_seq_len + 1, (batch_size,), device=device)
+    total_tokens = int(token_counts.sum().item())
+    cu_seqlens = torch.tensor(
+        [0] + torch.cumsum(token_counts, dim=0).tolist(),
+        dtype=torch.int32,
+        device=device,
+    )
+    state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device)
+    x = torch.randn(total_tokens, dim, device=device, dtype=itype)
+    out = torch.empty_like(x)
+    dt = torch.randn(total_tokens, dim, device=device, dtype=itype)
+    dt_bias = torch.rand(dim, device=device) - 4.0
+    A = -torch.rand(dim, dstate, device=device) - 1.0
+    B = torch.randn(total_tokens, dstate, device=device)
+    C = torch.randn(total_tokens, dstate, device=device)
+    D = torch.randn(dim, device=device)
+    z = torch.randn_like(x) if has_z else None
+    state_ref = state.detach().clone()
+    selective_state_update(
+        state,
+        x,
+        dt,
+        A,
+        B,
+        C,
+        D=D,
+        z=z,
+        dt_bias=dt_bias,
+        dt_softplus=True,
+        out=out,
+        cu_seqlens=cu_seqlens,
+    )
+
+    out_ref_list = []
+    for seq_idx in range(batch_size):
+        start_idx = cu_seqlens[seq_idx].item()
+        end_idx = cu_seqlens[seq_idx + 1].item()
+        num_tokens = end_idx - start_idx
+        for token_idx in range(num_tokens):
+            idx = start_idx + token_idx
+            out_ref_list.append(
+                selective_state_update_ref(
+                    state_ref[seq_idx : seq_idx + 1],
+                    x[idx : idx + 1],
+                    dt[idx : idx + 1],
+                    A,
+                    B[idx : idx + 1],
+                    C[idx : idx + 1],
+                    D=D,
+                    z=z[idx : idx + 1] if has_z else None,
+                    dt_bias=dt_bias,
+                    dt_softplus=True,
+                )
+            )
+    out_ref = torch.cat(out_ref_list, dim=0)
+    assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
+    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
+
+
 @pytest.mark.parametrize("wtype", [torch.float32])
 @pytest.mark.parametrize("itype", [torch.float32])
 @pytest.mark.parametrize("seqlen", [1, 256, 1024, 4096])
@@ -766,3 +840,254 @@ def test_selective_state_update_with_heads_with_batch_indices(
     print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
     assert torch.allclose(state[state_indices, :], state_ref, rtol=rtol, atol=atol)
     assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
+
+
+@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16])
+@pytest.mark.parametrize("has_z", [False, True])
+@pytest.mark.parametrize("dstate", [16, 64])
+@pytest.mark.parametrize("dim", [2048, 4096])
+@pytest.mark.parametrize("max_seq_len", [2, 4])
+def test_selective_state_update_with_num_accepted_tokens(
+    dim, dstate, has_z, itype, max_seq_len
+):
+    device = "cuda"
+    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
+    if itype == torch.bfloat16:
+        rtol, atol = 5e-2, 1.5e-1
+    if torch.version.hip:
+        atol *= 2
+
+    current_platform.seed_everything(0)
+    batch_size = 4
+
+    tokens_per_seq = torch.randint(1, max_seq_len + 1, (batch_size,), device=device)
+    total_tokens = int(tokens_per_seq.sum().item())
+
+    num_accepted_tokens = torch.randint(0, max_seq_len, (batch_size,), device=device)
+    num_accepted_tokens[0] = 0  # Add edge-case of no accepted tokens
+    num_accepted_tokens[1] = max_seq_len  # Add edge-case of all tokens accepted
+
+    cu_seqlens = torch.tensor(
+        [0] + torch.cumsum(tokens_per_seq, dim=0).tolist(),
+        dtype=torch.int32,
+        device=device,
+    )
+
+    total_state_slots = 50
+    state = torch.randn(total_state_slots, dim, dstate, dtype=itype, device=device)
+
+    state_batch_indices = torch.full(
+        (batch_size, max_seq_len), PAD_SLOT_ID, dtype=torch.int32, device=device
+    )
+    initial_state_slots = torch.randint(
+        0, 15, (batch_size,), device=device, dtype=torch.int32
+    )
+    for seq_idx in range(batch_size):
+        token_pos = max(num_accepted_tokens[seq_idx].item() - 1, 0)
+        state_batch_indices[seq_idx, token_pos] = initial_state_slots[seq_idx]
+
+    dst_state_batch_indices = torch.full(
+        (batch_size, max_seq_len), PAD_SLOT_ID, dtype=torch.int32, device=device
+    )
+    slot_offset = 15
+    dst_slots_map = {}
+    for seq_idx in range(batch_size):
+        for token_idx in range(tokens_per_seq[seq_idx].item()):
+            dst_state_batch_indices[seq_idx, token_idx] = slot_offset
+            dst_slots_map[(seq_idx, token_idx)] = slot_offset
+            slot_offset += 1
+
+    x = torch.randn(total_tokens, dim, device=device, dtype=itype)
+    out = torch.empty_like(x)
+    dt = torch.randn(total_tokens, dim, device=device, dtype=itype)
+    dt_bias = torch.rand(dim, device=device) - 4.0
+    A = -torch.rand(dim, dstate, device=device) - 1.0
+    B = torch.randn(total_tokens, dstate, device=device)
+    C = torch.randn(total_tokens, dstate, device=device)
+    D = torch.randn(dim, device=device)
+    z = torch.randn_like(x) if has_z else None
+
+    state_ref_intermediate = {}
+    out_ref_list = []
+
+    for seq_idx in range(batch_size):
+        seq_start = cu_seqlens[seq_idx].item()
+        seq_end = cu_seqlens[seq_idx + 1].item()
+        num_tokens = seq_end - seq_start
+
+        token_pos = max(num_accepted_tokens[seq_idx].item() - 1, 0)
+        initial_slot = state_batch_indices[seq_idx, token_pos].item()
+        state_seq = state[initial_slot : initial_slot + 1].clone()
+
+        for token_idx in range(num_tokens):
+            global_idx = seq_start + token_idx
+
+            out_token = selective_state_update_ref(
+                state_seq,
+                x[global_idx : global_idx + 1],
+                dt[global_idx : global_idx + 1],
+                A,
+                B[global_idx : global_idx + 1],
+                C[global_idx : global_idx + 1],
+                D=D,
+                z=z[global_idx : global_idx + 1] if has_z else None,
+                dt_bias=dt_bias,
+                dt_softplus=True,
+            )
+            out_ref_list.append(out_token)
+            state_ref_intermediate[(seq_idx, token_idx)] = state_seq.clone()
+
+    out_ref = torch.cat(out_ref_list, dim=0)
+
+    selective_state_update(
+        state,
+        x,
+        dt,
+        A,
+        B,
+        C,
+        D=D,
+        z=z,
+        dt_bias=dt_bias,
+        dt_softplus=True,
+        out=out,
+        cu_seqlens=cu_seqlens,
+        state_batch_indices=state_batch_indices,
+        dst_state_batch_indices=dst_state_batch_indices,
+        num_accepted_tokens=num_accepted_tokens,
+        pad_slot_id=PAD_SLOT_ID,
+    )
+
+    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
+
+    for seq_idx in range(batch_size):
+        num_tokens = tokens_per_seq[seq_idx].item()
+        for token_idx in range(num_tokens):
+            dst_slot = dst_slots_map[(seq_idx, token_idx)]
+            state_ref = state_ref_intermediate[(seq_idx, token_idx)].squeeze(0)
+            assert torch.allclose(state[dst_slot], state_ref, rtol=rtol, atol=atol)
+
+
+@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16])
+@pytest.mark.parametrize("has_z", [False, True])
+@pytest.mark.parametrize("dstate", [16, 64])
+@pytest.mark.parametrize("dim", [2048, 4096])
+@pytest.mark.parametrize("max_seq_len", [2, 4])
+def test_selective_state_update_varlen_with_num_accepted(
+    dim, dstate, has_z, itype, max_seq_len
+):
+    device = "cuda"
+    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
+    if itype == torch.bfloat16:
+        rtol, atol = 5e-2, 1.5e-1
+    if torch.version.hip:
+        atol *= 2
+
+    current_platform.seed_everything(0)
+    batch_size = 4
+
+    tokens_per_seq = torch.randint(1, max_seq_len + 1, (batch_size,), device=device)
+    total_tokens = int(tokens_per_seq.sum().item())
+
+    num_accepted_tokens = torch.randint(0, max_seq_len, (batch_size,), device=device)
+    num_accepted_tokens[0] = 0  # Add edge-case of no accepted tokens
+    num_accepted_tokens[1] = max_seq_len  # Add edge-case of all tokens accepted
+
+    cu_seqlens = torch.tensor(
+        [0] + torch.cumsum(tokens_per_seq, dim=0).tolist(),
+        dtype=torch.int32,
+        device=device,
+    )
+
+    total_state_slots = 50
+    state = torch.randn(total_state_slots, dim, dstate, dtype=itype, device=device)
+
+    state_batch_indices = torch.full(
+        (batch_size, max_seq_len), PAD_SLOT_ID, dtype=torch.int32, device=device
+    )
+
+    initial_state_slots = torch.randint(
+        0, 15, (batch_size,), device=device, dtype=torch.int32
+    )
+    for seq_idx in range(batch_size):
+        token_pos = max(num_accepted_tokens[seq_idx].item() - 1, 0)
+        state_batch_indices[seq_idx, token_pos] = initial_state_slots[seq_idx]
+
+    dst_state_batch_indices = torch.full(
+        (batch_size, max_seq_len), PAD_SLOT_ID, dtype=torch.int32, device=device
+    )
+
+    slot_offset = 15
+    dst_slots_map = {}
+    for seq_idx in range(batch_size):
+        for token_idx in range(tokens_per_seq[seq_idx].item()):
+            dst_state_batch_indices[seq_idx, token_idx] = slot_offset
+            dst_slots_map[(seq_idx, token_idx)] = slot_offset
+            slot_offset += 1
+
+    x = torch.randn(total_tokens, dim, device=device, dtype=itype)
+    out = torch.empty_like(x)
+    dt = torch.randn(total_tokens, dim, device=device, dtype=itype)
+    dt_bias = torch.rand(dim, device=device) - 4.0
+    A = -torch.rand(dim, dstate, device=device) - 1.0
+    B = torch.randn(total_tokens, dstate, device=device)
+    C = torch.randn(total_tokens, dstate, device=device)
+    D = torch.randn(dim, device=device)
+    z = torch.randn_like(x) if has_z else None
+
+    state_ref_intermediate = {}
+
+    for seq_idx in range(batch_size):
+        seq_start = cu_seqlens[seq_idx].item()
+        seq_end = cu_seqlens[seq_idx + 1].item()
+        num_tokens = seq_end - seq_start
+
+        token_pos = max(num_accepted_tokens[seq_idx].item() - 1, 0)
+        initial_slot = state_batch_indices[seq_idx, token_pos].item()
+        state_seq = state[initial_slot : initial_slot + 1].clone()
+
+        for token_idx in range(num_tokens):
+            global_idx = seq_start + token_idx
+
+            selective_state_update_ref(
+                state_seq,
+                x[global_idx : global_idx + 1],
+                dt[global_idx : global_idx + 1],
+                A,
+                B[global_idx : global_idx + 1],
+                C[global_idx : global_idx + 1],
+                D=D,
+                z=z[global_idx : global_idx + 1] if has_z else None,
+                dt_bias=dt_bias,
+                dt_softplus=True,
+            )
+
+            state_ref_intermediate[(seq_idx, token_idx)] = state_seq.clone()
+
+    selective_state_update(
+        state,
+        x,
+        dt,
+        A,
+        B,
+        C,
+        D=D,
+        z=z,
+        dt_bias=dt_bias,
+        dt_softplus=True,
+        out=out,
+        cu_seqlens=cu_seqlens,
+        state_batch_indices=state_batch_indices,
+        dst_state_batch_indices=dst_state_batch_indices,
+        num_accepted_tokens=num_accepted_tokens,
+        pad_slot_id=PAD_SLOT_ID,
+    )
+
+    for seq_idx in range(batch_size):
+        num_tokens = tokens_per_seq[seq_idx].item()
+
+        for token_idx in range(num_tokens):
+            dst_slot = dst_slots_map[(seq_idx, token_idx)]
+            state_ref = state_ref_intermediate[(seq_idx, token_idx)].squeeze(0)
+
+            assert torch.allclose(state[dst_slot], state_ref, rtol=rtol, atol=atol)
diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py
index 53fd5d5458b09..800f8bd840792 100644
--- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py
+++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py
@@ -36,10 +36,14 @@ else:
         is not None
     }
 )
+@triton.heuristics(
+    {"IS_SPEC_DECODING": lambda args: args["num_accepted_tokens_ptr"] is not None}
+)
+@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens_ptr"] is not None})
 @triton.heuristics(
     {"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])}
 )
-@triton.jit
+@triton.jit(do_not_specialize=["N"])
 def _selective_scan_update_kernel(
     # Pointers to matrices
     state_ptr,
@@ -55,8 +59,10 @@ def _selective_scan_update_kernel(
     state_batch_indices_ptr,
     dst_state_batch_indices_ptr,
     pad_slot_id,
+    num_accepted_tokens_ptr,
+    cu_seqlens_ptr,
     # Matrix dimensions
-    batch,
+    N,
     nheads,
     dim,
     dstate,
@@ -91,6 +97,10 @@ def _selective_scan_update_kernel(
     stride_out_batch,
     stride_out_head,
     stride_out_dim,
+    stride_state_indices_batch,
+    stride_state_indices_T,
+    stride_dst_state_indices_batch,
+    stride_dst_state_indices_T,
     # Meta-parameters
     DT_SOFTPLUS: tl.constexpr,
     TIE_HDIM: tl.constexpr,
@@ -99,22 +109,50 @@ def _selective_scan_update_kernel(
     HAS_D: tl.constexpr,
     HAS_Z: tl.constexpr,
     HAS_STATE_BATCH_INDICES: tl.constexpr,
+    IS_SPEC_DECODING: tl.constexpr,
+    IS_VARLEN: tl.constexpr,
     BLOCK_SIZE_DSTATE: tl.constexpr,
 ):
     pid_m = tl.program_id(axis=0)
     pid_b = tl.program_id(axis=1)
     pid_h = tl.program_id(axis=2)
+
+    if IS_VARLEN:
+        bos = tl.load(cu_seqlens_ptr + pid_b).to(tl.int64)
+        eos = tl.load(cu_seqlens_ptr + pid_b + 1).to(tl.int64)
+        seq_len = eos - bos
+
+        if seq_len == 0:
+            return
+    else:
+        bos = pid_b
+        seq_len = 1
+
+    state_ptr_base = state_ptr
+
     # If HAS_STATE_BATCH_INDICES is true, then the ssm state's batch coordinate
     # is taken from the state_batch_indices_ptr Otherwise, the state coordinate
     # is the same as the batch id.
     if HAS_STATE_BATCH_INDICES:
-        dst_state_batch_indices_ptr += pid_b
-        dst_state_batch_idx = tl.load(dst_state_batch_indices_ptr).to(tl.int64)
-        dst_state_ptr = state_ptr + (
-            dst_state_batch_idx * stride_state_batch + pid_h * stride_state_head
+        if IS_SPEC_DECODING:
+            num_accepted = tl.load(num_accepted_tokens_ptr + pid_b).to(tl.int64)
+            init_token_idx = tl.maximum(num_accepted - 1, 0)
+        else:
+            init_token_idx = 0
+
+        dst_state_batch_indices_ptr += pid_b * stride_dst_state_indices_batch
+        if not IS_SPEC_DECODING:
+            dst_state_batch_idx = tl.load(
+                dst_state_batch_indices_ptr
+                + init_token_idx * stride_dst_state_indices_T
+            ).to(tl.int64)
+            dst_state_ptr = state_ptr + (
+                dst_state_batch_idx * stride_state_batch + pid_h * stride_state_head
+            )
+
+        state_batch_indices_ptr += (
+            pid_b * stride_state_indices_batch + init_token_idx * stride_state_indices_T
         )
-        state_batch_indices_ptr += pid_b
         state_batch_idx = tl.load(state_batch_indices_ptr).to(tl.int64)
         state_ptr += state_batch_idx * stride_state_batch + pid_h * stride_state_head
     else:
@@ -123,86 +161,112 @@ def _selective_scan_update_kernel(
         )
         state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
 
-    x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head
-    dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head
+    x_ptr += bos * stride_x_batch + pid_h * stride_x_head
+    dt_ptr += bos * stride_dt_batch + pid_h * stride_dt_head
     if HAS_DT_BIAS:
         dt_bias_ptr += pid_h * stride_dt_bias_head
     A_ptr += pid_h * stride_A_head
-    B_ptr += pid_b * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group
-    C_ptr += pid_b * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group
+    B_ptr += bos * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group
+    C_ptr += bos * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group
     if HAS_Z:
-        z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head
-    out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
+        z_ptr += bos * stride_z_batch + pid_h * stride_z_head
+    out_ptr += bos * stride_out_batch + pid_h * stride_out_head
 
     offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
     offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
     state_ptrs = state_ptr + (
         offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate
     )
-    dst_state_ptrs = dst_state_ptr + (
-        offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate
-    )
-    x_ptrs = x_ptr + offs_m * stride_x_dim
-    dt_ptrs = dt_ptr + offs_m * stride_dt_dim
+    if not IS_SPEC_DECODING:
+        dst_state_ptrs = dst_state_ptr + (
+            offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate
+        )
+
+    mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
+    if HAS_STATE_BATCH_INDICES:
+        mask &= state_batch_idx != pad_slot_id
+    state = tl.load(state_ptrs, mask=mask, other=0.0).to(tl.float32)
+
     if HAS_DT_BIAS:
         dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim
     if HAS_D:
         D_ptr += pid_h * stride_D_head
-    A_ptrs = A_ptr + (
-        offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate
-    )
-    B_ptrs = B_ptr + offs_n * stride_B_dstate
-    C_ptrs = C_ptr + offs_n * stride_C_dstate
-    if HAS_D:
         D_ptrs = D_ptr + offs_m * stride_D_dim
-    if HAS_Z:
-        z_ptrs = z_ptr + offs_m * stride_z_dim
-    out_ptrs = out_ptr + offs_m * stride_out_dim
-    mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
-    if HAS_STATE_BATCH_INDICES:
-        mask &= state_batch_idx != pad_slot_id
-    state = tl.load(state_ptrs, mask=mask, other=0.0)
+    A_ptrs = A_ptr + offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate
 
-    x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
-    if not TIE_HDIM:
-        dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
-        if HAS_DT_BIAS:
-            dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
-        if DT_SOFTPLUS:
-            dt = softplus(dt)
-        A = tl.load(
-            A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0
-        ).to(tl.float32)
-        dA = tl.exp(A * dt[:, None])
-    else:
-        dt = tl.load(dt_ptr).to(tl.float32)
-        if HAS_DT_BIAS:
-            dt += tl.load(dt_bias_ptr).to(tl.float32)
-        if DT_SOFTPLUS:
-            dt = softplus(dt)
-        A = tl.load(A_ptr).to(tl.float32)
-        dA = tl.exp(A * dt)  # scalar, not a matrix
+    for i_t in range(seq_len):
+        x_ptrs = x_ptr + offs_m * stride_x_dim
+        dt_ptrs = dt_ptr + offs_m * stride_dt_dim
+        B_ptrs = B_ptr + offs_n * stride_B_dstate
+        C_ptrs = C_ptr + offs_n * stride_C_dstate
+        if HAS_Z:
+            z_ptrs = z_ptr + offs_m * stride_z_dim
+        out_ptrs = out_ptr + offs_m * stride_out_dim
 
-    B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
-    C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
-    if HAS_D:
-        D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
-    if HAS_Z:
-        z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
+        x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
+        if not TIE_HDIM:
+            dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
+            if HAS_DT_BIAS:
+                dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
+            if DT_SOFTPLUS:
+                dt = softplus(dt)
+            A = tl.load(
+                A_ptrs,
+                mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate),
+                other=0.0,
+            ).to(tl.float32)
+            dA = tl.exp(A * dt[:, None])
+        else:
+            dt = tl.load(dt_ptr).to(tl.float32)
+            if HAS_DT_BIAS:
+                dt += tl.load(dt_bias_ptr).to(tl.float32)
+            if DT_SOFTPLUS:
+                dt = softplus(dt)
+            A = tl.load(A_ptr).to(tl.float32)
+            dA = tl.exp(A * dt)  # scalar, not a matrix
 
-    dB = B[None, :] * dt[:, None] if not TIE_HDIM else B * dt
-    state = state * dA + dB * x[:, None]
+        B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
+        C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
+        if HAS_D:
+            D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
+        if HAS_Z:
+            z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
 
-    mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
-    if HAS_STATE_BATCH_INDICES:
-        mask &= state_batch_idx != pad_slot_id
-    tl.store(dst_state_ptrs, state, mask=mask)
-    out = tl.sum(state * C[None, :], axis=1)
-    if HAS_D:
-        out += x * D
-    if HAS_Z:
-        out *= z * tl.sigmoid(z)
-    tl.store(out_ptrs, out, mask=offs_m < dim)
+        dB = B[None, :] * dt[:, None] if not TIE_HDIM else B * dt
+        state = state * dA + dB * x[:, None]
+
+        if IS_SPEC_DECODING:
+            dst_idx_ptr = dst_state_batch_indices_ptr + i_t * stride_dst_state_indices_T
+            token_dst_idx = tl.load(dst_idx_ptr).to(tl.int64)
+            if token_dst_idx != pad_slot_id:
+                token_dst_ptrs = (
+                    state_ptr_base
+                    + token_dst_idx * stride_state_batch
+                    + pid_h * stride_state_head
+                    + offs_m[:, None] * stride_state_dim
+                    + offs_n[None, :] * stride_state_dstate
+                )
+                tl.store(
+                    token_dst_ptrs, state.to(token_dst_ptrs.dtype.element_ty), mask=mask
+                )
+
+        out = tl.sum(state * C[None, :], axis=1)
+        if HAS_D:
+            out += x * D
+        if HAS_Z:
+            out *= z * tl.sigmoid(z)
+        tl.store(out_ptrs, out, mask=offs_m < dim)
+
+        x_ptr += stride_x_batch
+        dt_ptr += stride_dt_batch
+        B_ptr += stride_B_batch
+        C_ptr += stride_C_batch
+        out_ptr += stride_out_batch
+        if HAS_Z:
+            z_ptr += stride_z_batch
+
+    if not IS_SPEC_DECODING:
+        tl.store(dst_state_ptrs, state.to(dst_state_ptrs.dtype.element_ty), mask=mask)
 
 
 def selective_state_update(
@@ -220,6 +284,8 @@ def selective_state_update(
     dst_state_batch_indices=None,
     pad_slot_id=PAD_SLOT_ID,
     out=None,
+    num_accepted_tokens=None,
+    cu_seqlens=None,
 ):
     """
     Argument:
@@ -240,6 +306,11 @@ def selective_state_update(
         indices 0 and 3
     out: Preallocated ssm output tensor. Assume same shape as x.
         In-place updated.
+    num_accepted_tokens: (batch,)
+        number of accepted tokens from the previous verification step;
+        tells the kernel which initial state to use
+    cu_seqlens: (batch + 1,)
+        cumulative sequence lengths, for variable length in speculative decoding cases
     """
     if state.dim() == 3:
         state = state.unsqueeze(1)
@@ -261,9 +332,26 @@ def selective_state_update(
         dt_bias = dt_bias.unsqueeze(0)
     if out.dim() == 2:
         out = out.unsqueeze(1)
+    if num_accepted_tokens is not None:
+        assert state_batch_indices is not None and state_batch_indices.dim() == 2
+        assert dst_state_batch_indices is None or dst_state_batch_indices.dim() == 2
+    if state_batch_indices is not None and state_batch_indices.dim() == 1:
+        state_batch_indices = state_batch_indices.unsqueeze(1)
+    if dst_state_batch_indices is not None and dst_state_batch_indices.dim() == 1:
+        dst_state_batch_indices = dst_state_batch_indices.unsqueeze(1)
 
     _, nheads, dim, dstate = state.shape
     batch = x.shape[0]
+    if cu_seqlens is not None:
+        N = len(cu_seqlens) - 1
+        # Only used to verify the shape of
+        # state_batch_indices and dst_state_batch_indices
+        max_seqlen = (
+            state_batch_indices.size(-1) if state_batch_indices is not None else 1
+        )
+    else:
+        N = batch
+        max_seqlen = 1
 
     assert x.shape == (batch, nheads, dim)
     assert dt.shape == x.shape
@@ -279,16 +367,30 @@ def selective_state_update(
     if dt_bias is not None:
         assert dt_bias.shape == (nheads, dim)
     if state_batch_indices is not None:
-        assert state_batch_indices.shape == (batch,)
+        assert state_batch_indices.shape[0] >= N
+        assert state_batch_indices.shape[1] >= max_seqlen
     if dst_state_batch_indices is not None:
-        assert dst_state_batch_indices.shape == (batch,)
+        assert dst_state_batch_indices.shape[0] >= N
+        assert dst_state_batch_indices.shape[1] >= max_seqlen
     else:
         # revert to the default behavior of in-place state updates
        dst_state_batch_indices = state_batch_indices
     assert out.shape == x.shape
+    if num_accepted_tokens is not None:
+        assert num_accepted_tokens.shape == (N,)
 
-    grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE_M"]), batch, nheads)
+    grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE_M"]), N, nheads)
     z_strides = (z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0)
+    state_batch_indices_strides = (
+        (state_batch_indices.stride(0), state_batch_indices.stride(1))
+        if state_batch_indices is not None
+        else (0, 0)
+    )
+    dst_state_batch_indices_strides = (
+        (dst_state_batch_indices.stride(0), dst_state_batch_indices.stride(1))
+        if dst_state_batch_indices is not None
+        else (0, 0)
+    )
     # We don't want autotune since it will overwrite the state
     # We instead tune by hand.
     BLOCK_SIZE_M, num_warps = (
@@ -321,7 +423,9 @@ def selective_state_update(
         state_batch_indices,
         dst_state_batch_indices,
         pad_slot_id,
-        batch,
+        num_accepted_tokens,
+        cu_seqlens,
+        N,
         nheads,
         dim,
         dstate,
@@ -353,6 +457,10 @@ def selective_state_update(
         out.stride(0),
         out.stride(1),
        out.stride(2),
+        state_batch_indices_strides[0],
+        state_batch_indices_strides[1],
+        dst_state_batch_indices_strides[0],
+        dst_state_batch_indices_strides[1],
         dt_softplus,
         tie_hdim,
         BLOCK_SIZE_M,
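
Note (not part of the patch): the per-token recurrence that the new "for i_t in range(seq_len)" loop evaluates, and that the tests reconstruct with selective_state_update_ref, can be written in plain PyTorch as below. The helper names (ssm_step, run_sequence, checkpoints) are illustrative only; shapes follow the no-head, non-TIE_HDIM case with dt_softplus=True.

    import torch
    import torch.nn.functional as F

    def ssm_step(state, x, dt, A, B, C, D, dt_bias):
        # state: (dim, dstate); x, dt, D, dt_bias: (dim,); A: (dim, dstate); B, C: (dstate,)
        dt = F.softplus(dt + dt_bias)               # dt_softplus=True path
        dA = torch.exp(A * dt[:, None])             # discretized decay
        dB = B[None, :] * dt[:, None]               # discretized input matrix
        state = state * dA + dB * x[:, None]        # recurrence
        out = (state * C[None, :]).sum(-1) + D * x  # readout plus skip connection
        return state, out

    def run_sequence(init_state, xs, dts, A, B, C, D, dt_bias):
        # Process one sequence's draft tokens, keeping a checkpoint per token,
        # the way the IS_SPEC_DECODING path writes one state slot per token.
        state, checkpoints, outs = init_state.clone(), [], []
        for t in range(xs.shape[0]):
            state, out = ssm_step(state, xs[t], dts[t], A, B[t], C[t], D, dt_bias)
            checkpoints.append(state.clone())       # -> dst_state_batch_indices[seq, t]
            outs.append(out)
        return torch.stack(checkpoints), torch.stack(outs)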
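
Note (not part of the patch): a sketch of how a caller might assemble the new arguments, mirroring the tests above. selective_state_update and its keyword arguments come from this patch; the sizes, the slot layout, and the -1 pad sentinel (a stand-in for vLLM's PAD_SLOT_ID) are assumptions for illustration.

    import torch
    from vllm.model_executor.layers.mamba.ops.mamba_ssm import selective_state_update

    device = "cuda"
    pad_slot_id = -1
    dim, dstate = 64, 16
    tokens_per_seq = torch.tensor([2, 3], device=device)       # draft tokens per sequence
    total_tokens = int(tokens_per_seq.sum())
    cu_seqlens = torch.tensor([0, 2, 5], dtype=torch.int32, device=device)  # (batch + 1,)
    num_accepted_tokens = torch.tensor([1, 2], device=device)  # from the verifier

    # A pool of SSM state slots: slots 0-1 hold the surviving states, slots 2-6
    # receive one checkpoint per draft token.
    state = torch.randn(8, dim, dstate, device=device)
    max_seq_len = int(tokens_per_seq.max())
    state_batch_indices = torch.full(
        (2, max_seq_len), pad_slot_id, dtype=torch.int32, device=device
    )
    dst_state_batch_indices = torch.full_like(state_batch_indices, pad_slot_id)
    state_batch_indices[0, 0] = 0   # read slot of the last accepted token (num_accepted=1)
    state_batch_indices[1, 1] = 1   # read slot of the last accepted token (num_accepted=2)
    dst_state_batch_indices[0, :2] = torch.tensor([2, 3], dtype=torch.int32, device=device)
    dst_state_batch_indices[1, :3] = torch.tensor([4, 5, 6], dtype=torch.int32, device=device)

    x = torch.randn(total_tokens, dim, device=device)
    dt = torch.randn(total_tokens, dim, device=device)
    A = -torch.rand(dim, dstate, device=device) - 1.0
    B = torch.randn(total_tokens, dstate, device=device)
    C = torch.randn(total_tokens, dstate, device=device)
    out = torch.empty_like(x)

    selective_state_update(
        state, x, dt, A, B, C,
        dt_softplus=True,
        out=out,
        cu_seqlens=cu_seqlens,
        state_batch_indices=state_batch_indices,
        dst_state_batch_indices=dst_state_batch_indices,
        num_accepted_tokens=num_accepted_tokens,
        pad_slot_id=pad_slot_id,
    )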