From 00b31a36a2d0de6d197a473280b2304d482714af Mon Sep 17 00:00:00 2001 From: Asaf Joseph Gardin <39553475+Josephasafg@users.noreply.github.com> Date: Sun, 2 Nov 2025 14:16:23 +0200 Subject: [PATCH] [V1] [Hybrid] Mamba1 Automatic Prefix Caching (#26377) Signed-off-by: asafg <39553475+Josephasafg@users.noreply.github.com> --- csrc/mamba/mamba_ssm/selective_scan.h | 8 +- csrc/mamba/mamba_ssm/selective_scan_fwd.cu | 134 +++++++++++++++--- csrc/ops.h | 24 ++-- csrc/torch_bindings.cpp | 6 +- tests/kernels/mamba/test_mamba_ssm.py | 15 ++ .../models/language/generation/test_hybrid.py | 34 ++--- vllm/_custom_ops.py | 8 ++ vllm/config/model.py | 6 + .../layers/mamba/mamba_mixer.py | 91 ++++++++---- .../layers/mamba/ops/mamba_ssm.py | 24 +++- vllm/model_executor/models/config.py | 2 +- vllm/model_executor/models/jamba.py | 21 ++- vllm/model_executor/models/mamba.py | 9 +- vllm/v1/attention/backends/mamba1_attn.py | 111 ++++++++++++--- vllm/v1/attention/backends/mamba2_attn.py | 40 +----- vllm/v1/attention/backends/mamba_attn.py | 62 +++++++- 16 files changed, 442 insertions(+), 153 deletions(-) diff --git a/csrc/mamba/mamba_ssm/selective_scan.h b/csrc/mamba/mamba_ssm/selective_scan.h index 13c6178941cf8..7d22dd8b84a39 100644 --- a/csrc/mamba/mamba_ssm/selective_scan.h +++ b/csrc/mamba/mamba_ssm/selective_scan.h @@ -24,6 +24,8 @@ struct SSMParamsBase { int64_t pad_slot_id; bool delta_softplus; + bool cache_enabled; + int block_size; index_t A_d_stride; index_t A_dstate_stride; @@ -46,8 +48,9 @@ struct SSMParamsBase { index_t out_z_batch_stride; index_t out_z_d_stride; index_t ssm_states_batch_stride; - index_t ssm_states_dim_stride; + index_t ssm_states_dim_stride; index_t ssm_states_dstate_stride; + index_t cache_indices_stride; // Common data pointers. void *__restrict__ A_ptr; @@ -66,6 +69,9 @@ struct SSMParamsBase { void *__restrict__ cache_indices_ptr; void *__restrict__ has_initial_state_ptr; + void *__restrict__ block_idx_first_scheduled_token_ptr; // (batch,) - first block to write + void *__restrict__ block_idx_last_scheduled_token_ptr; // (batch,) - last block to write + void *__restrict__ initial_state_idx_ptr; // (batch,) - index of the initial state to use }; diff --git a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu index d534e138d26d6..fb2a2e5789999 100644 --- a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu +++ b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu @@ -119,7 +119,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { const int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr : reinterpret_cast(params.cache_indices_ptr); - const int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id]; + const int cache_index = cache_indices == nullptr ? 
batch_id : cache_indices[batch_id]; // cache_index == params.pad_slot_id is defined as padding, so we exit early if (cache_index == params.pad_slot_id){ return; @@ -133,9 +133,18 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { input_t *Bvar = reinterpret_cast(params.B_ptr) + sequence_start_index * params.B_batch_stride + group_id * params.B_group_stride; weight_t *C = reinterpret_cast(params.C_ptr) + dim_id * kNRows * params.C_d_stride; input_t *Cvar = reinterpret_cast(params.C_ptr) + sequence_start_index * params.C_batch_stride + group_id * params.C_group_stride; - typename Ktraits::state_t *ssm_states = reinterpret_cast(params.ssm_states_ptr) + - cache_index * params.ssm_states_batch_stride + - dim_id * kNRows * params.ssm_states_dim_stride; + + typename Ktraits::state_t *ssm_states; + if (params.cache_enabled) { + // APC mode: ssm_states points to the base, we'll use absolute cache slots later + ssm_states = reinterpret_cast(params.ssm_states_ptr) + + dim_id * kNRows * params.ssm_states_dim_stride; + } else { + // Non-APC mode: offset by cache_index as before + ssm_states = reinterpret_cast(params.ssm_states_ptr) + + cache_index * params.ssm_states_batch_stride + + dim_id * kNRows * params.ssm_states_dim_stride; + } float D_val[kNRows] = {0}; if (params.D_ptr != nullptr) { @@ -159,7 +168,22 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { // } constexpr int kChunkSize = kNThreads * kNItems; - const int n_chunks = (seqlen + 2048 - 1) / 2048; + + // Use block_size for chunking when APC is enabled, otherwise use 2048 for backwards compatibility + const int iteration_chunk_size = params.cache_enabled ? params.block_size : 2048; + const int n_chunks = (seqlen + iteration_chunk_size - 1) / iteration_chunk_size; + + const int* batch_cache_indices = cache_indices != nullptr ? + cache_indices + batch_id * params.cache_indices_stride : nullptr; + const int* block_idx_first_scheduled = params.block_idx_first_scheduled_token_ptr != nullptr ? + reinterpret_cast(params.block_idx_first_scheduled_token_ptr) : nullptr; + const int* block_idx_last_scheduled = params.block_idx_last_scheduled_token_ptr != nullptr ? + reinterpret_cast(params.block_idx_last_scheduled_token_ptr) : nullptr; + const int* initial_state_idx = params.initial_state_idx_ptr != nullptr ? + reinterpret_cast(params.initial_state_idx_ptr) : nullptr; + + const size_t load_cache_slot = params.cache_enabled && batch_cache_indices != nullptr ? batch_cache_indices[initial_state_idx[batch_id]] : cache_index; + for (int chunk = 0; chunk < n_chunks; ++chunk) { input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems]; @@ -219,7 +243,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { if constexpr (kIsVariableC) { auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1; load_weight(Cvar + state_idx * params.C_dstate_stride, C_vals, - smem_load_weight_C, (seqlen - chunk * kChunkSize) * (1 )); + smem_load_weight_C, (seqlen - chunk * kChunkSize) * (1)); if constexpr (!kIsVariableB) { #pragma unroll for (int r = 0; r < kNRows; ++r) { @@ -242,7 +266,6 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { for (int i = 0; i < kNItems; ++i) { thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]), !kIsVariableB ? 
delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]); - if (seqlen % (kNItems * kNThreads) != 0) { // So that the last state is correct if (threadIdx.x * kNItems + i >= seqlen - chunk * kChunkSize) { thread_data[i] = make_float2(1.f, 0.f); @@ -250,8 +273,24 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { } } // Initialize running total - - scan_t running_prefix = chunk > 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.0, has_initial_state ? float(ssm_states[state_idx * params.ssm_states_dstate_stride]): 0.0); + scan_t running_prefix; + if (chunk > 0) { + running_prefix = smem_running_prefix[state_idx + r * MAX_DSTATE]; + } else { + // Load initial state + if (params.cache_enabled && has_initial_state && batch_cache_indices != nullptr) { + size_t state_offset = load_cache_slot * params.ssm_states_batch_stride + + r * params.ssm_states_dim_stride + + state_idx * params.ssm_states_dstate_stride; + running_prefix = make_float2(1.0, float(ssm_states[state_offset])); + } else if (has_initial_state) { + // Non-APC mode: load from current batch position + running_prefix = make_float2(1.0, float(ssm_states[state_idx * params.ssm_states_dstate_stride])); + } else { + // No initial state + running_prefix = make_float2(1.0, 0.0); + } + } SSMScanPrefixCallbackOp prefix_op(running_prefix); typename Ktraits::BlockScanT(smem_scan).InclusiveScan( @@ -260,8 +299,25 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { // There's a syncthreads in the scan op, so we don't need to sync here. // Unless there's only 1 warp, but then it's the same thread (0) reading and writing. if (threadIdx.x == 0) { - smem_running_prefix[state_idx] = prefix_op.running_prefix; - if (chunk == n_chunks - 1) { + smem_running_prefix[state_idx + r * MAX_DSTATE] = prefix_op.running_prefix; + + // Store state at the end of each chunk when cache is enabled + if (params.cache_enabled && batch_cache_indices != nullptr) { + + size_t cache_slot; + if (chunk == n_chunks - 1) { + cache_slot = batch_cache_indices[block_idx_last_scheduled[batch_id]]; + } else { + cache_slot = batch_cache_indices[block_idx_first_scheduled[batch_id] + chunk]; + } + + size_t state_offset = cache_slot * params.ssm_states_batch_stride + + r * params.ssm_states_dim_stride + + state_idx * params.ssm_states_dstate_stride; + + ssm_states[state_offset] = typename Ktraits::state_t(prefix_op.running_prefix.y); + } else if (!params.cache_enabled && chunk == n_chunks - 1) { + // Non-APC mode: store only final state at current batch position ssm_states[state_idx * params.ssm_states_dstate_stride] = typename Ktraits::state_t(prefix_op.running_prefix.y); } } @@ -274,7 +330,6 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { } } } - input_t *out = reinterpret_cast(params.out_ptr) + sequence_start_index * params.out_batch_stride + dim_id * kNRows * params.out_d_stride + chunk * kChunkSize; __syncthreads(); @@ -346,7 +401,9 @@ template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream) { #ifndef USE_ROCM - if (params.seqlen <= 128) { + if (params.cache_enabled && params.block_size == 1024) { + selective_scan_fwd_launch<64, 16, input_t, weight_t, state_t>(params, stream); + } else if (params.seqlen <= 128) { selective_scan_fwd_launch<32, 4, input_t, weight_t, state_t>(params, stream); } else if (params.seqlen <= 256) { selective_scan_fwd_launch<32, 8, input_t, weight_t, state_t>(params, stream); @@ -358,7 +415,9 @@ void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream) { selective_scan_fwd_launch<128, 16, 
input_t, weight_t, state_t>(params, stream); } #else - if (params.seqlen <= 256) { + if (params.cache_enabled && params.block_size == 1024) { + selective_scan_fwd_launch<64, 16, input_t, weight_t, state_t>(params, stream); + } else if (params.seqlen <= 256) { selective_scan_fwd_launch<64, 4, input_t, weight_t, state_t>(params, stream); } else if (params.seqlen <= 512) { selective_scan_fwd_launch<64, 8, input_t, weight_t, state_t>(params, stream); @@ -437,13 +496,17 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms, const std::optional& D, const std::optional& delta_bias, const torch::Tensor ssm_states, - bool has_z, + bool has_z, bool delta_softplus, const std::optional& query_start_loc, const std::optional& cache_indices, const std::optional& has_initial_state, bool varlen, - int64_t pad_slot_id) { + int64_t pad_slot_id, + int64_t block_size, + const std::optional &block_idx_first_scheduled_token, + const std::optional &block_idx_last_scheduled_token, + const std::optional &initial_state_idx) { // Reset the parameters memset(¶ms, 0, sizeof(params)); @@ -477,6 +540,14 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms, params.cache_indices_ptr = cache_indices.has_value() ? cache_indices.value().data_ptr() : nullptr; params.has_initial_state_ptr = has_initial_state.has_value() ? has_initial_state.value().data_ptr() : nullptr; + // Set cache parameters - cache is enabled if we have direct cache writing params + params.cache_enabled = block_idx_first_scheduled_token.has_value(); + params.block_size = static_cast(block_size); + + // Set direct cache writing pointers + params.block_idx_first_scheduled_token_ptr = block_idx_first_scheduled_token.has_value() ? block_idx_first_scheduled_token.value().data_ptr() : nullptr; + params.block_idx_last_scheduled_token_ptr = block_idx_last_scheduled_token.has_value() ? block_idx_last_scheduled_token.value().data_ptr() : nullptr; + params.initial_state_idx_ptr = initial_state_idx.has_value() ? initial_state_idx.value().data_ptr() : nullptr; // All stride are in elements, not bytes. params.A_d_stride = A.stride(0); @@ -504,9 +575,11 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms, params.out_d_stride = out.stride(0); params.ssm_states_batch_stride = ssm_states.stride(0); - params.ssm_states_dim_stride = ssm_states.stride(1); + params.ssm_states_dim_stride = ssm_states.stride(1); params.ssm_states_dstate_stride = ssm_states.stride(2); + params.cache_indices_stride = cache_indices.has_value() ? cache_indices.value().stride(0) : 0; + } else{ if (!is_variable_B) { @@ -537,8 +610,10 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms, params.out_d_stride = out.stride(1); params.ssm_states_batch_stride = ssm_states.stride(0); - params.ssm_states_dim_stride = ssm_states.stride(1); + params.ssm_states_dim_stride = ssm_states.stride(1); params.ssm_states_dstate_stride = ssm_states.stride(2); + + params.cache_indices_stride = cache_indices.has_value() ? 
cache_indices.value().stride(0) : 0; } } @@ -554,7 +629,11 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, const torch::Tensor &ssm_states, // used to identify padding entries if cache_indices provided // in case of padding, the kernel will return early - int64_t pad_slot_id) { + int64_t pad_slot_id, + int64_t block_size, + const std::optional &block_idx_first_scheduled_token, + const std::optional &block_idx_last_scheduled_token, + const std::optional &initial_state_idx) { auto input_type = u.scalar_type(); auto weight_type = A.scalar_type(); TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); @@ -646,7 +725,16 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, auto cache_indices_ = cache_indices.value(); TORCH_CHECK(cache_indices_.scalar_type() == at::ScalarType::Int); TORCH_CHECK(cache_indices_.is_cuda()); - CHECK_SHAPE(cache_indices_, batch_size); + + // cache_indices can be either 1D (batch_size,) for non-APC mode + // or 2D (batch_size, max_positions) for APC mode + const bool is_apc_mode = block_idx_first_scheduled_token.has_value(); + if (is_apc_mode) { + TORCH_CHECK(cache_indices_.dim() == 2, "cache_indices must be 2D for APC mode"); + TORCH_CHECK(cache_indices_.size(0) == batch_size, "cache_indices first dimension must match batch_size"); + } else { + CHECK_SHAPE(cache_indices_, batch_size); + } } @@ -686,7 +774,11 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, cache_indices, has_initial_state, varlen, - pad_slot_id + pad_slot_id, + block_size, + block_idx_first_scheduled_token, + block_idx_last_scheduled_token, + initial_state_idx ); diff --git a/csrc/ops.h b/csrc/ops.h index 0bed7492f6616..3f5cb799b774c 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -321,17 +321,19 @@ void dynamic_per_token_scaled_fp8_quant( torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale, std::optional const& scale_ub); -void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta, - const torch::Tensor& A, const torch::Tensor& B, - const torch::Tensor& C, - const std::optional& D_, - const std::optional& z_, - const std::optional& delta_bias_, - bool delta_softplus, - const std::optional& query_start_loc, - const std::optional& cache_indices, - const std::optional& has_initial_state, - const torch::Tensor& ssm_states, int64_t pad_slot_id); +void selective_scan_fwd( + const torch::Tensor& u, const torch::Tensor& delta, const torch::Tensor& A, + const torch::Tensor& B, const torch::Tensor& C, + const std::optional& D_, + const std::optional& z_, + const std::optional& delta_bias_, bool delta_softplus, + const std::optional& query_start_loc, + const std::optional& cache_indices, + const std::optional& has_initial_state, + const torch::Tensor& ssm_states, int64_t pad_slot_id, int64_t block_size, + const std::optional& block_idx_first_scheduled_token, + const std::optional& block_idx_last_scheduled_token, + const std::optional& initial_state_idx); torch::Tensor dynamic_4bit_int_moe_cpu( torch::Tensor x, torch::Tensor topk_ids, torch::Tensor topk_weights, diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 8f091a429fbef..9c0f524dcab11 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -611,7 +611,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "Tensor? cache_indices," "Tensor? has_initial_state," "Tensor! 
ssm_states," - "int pad_slot_id) -> ()"); + "int pad_slot_id," + "int block_size," + "Tensor? block_idx_first_scheduled_token," + "Tensor? block_idx_last_scheduled_token," + "Tensor? initial_state_idx) -> ()"); ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd); // Hadamard transforms diff --git a/tests/kernels/mamba/test_mamba_ssm.py b/tests/kernels/mamba/test_mamba_ssm.py index c59fc7af0c897..98edc959957d0 100644 --- a/tests/kernels/mamba/test_mamba_ssm.py +++ b/tests/kernels/mamba/test_mamba_ssm.py @@ -179,6 +179,10 @@ def selective_scan_opcheck_fn( has_initial_state=None, ssm_states=None, pad_slot_id=PAD_SLOT_ID, + block_size=2048, + block_idx_first_scheduled_token=None, + block_idx_last_scheduled_token=None, + initial_state_idx=None, ): """if return_last_state is True, returns (out, last_state) last_state has shape (batch, dim, dstate). @@ -223,6 +227,10 @@ def selective_scan_opcheck_fn( has_initial_state, ssm_states, pad_slot_id, + block_size, + block_idx_first_scheduled_token, + block_idx_last_scheduled_token, + initial_state_idx, ), test_utils=["test_schema", "test_faketensor"], ) @@ -338,6 +346,11 @@ def test_selective_scan( has_initial_state=torch.ones(batch_size, device=u.device, dtype=torch.bool) if c > 0 else None, + pad_slot_id=PAD_SLOT_ID, + block_size=2048, + block_idx_first_scheduled_token=None, + block_idx_last_scheduled_token=None, + initial_state_idx=None, ) outs.append(out) if len(outs) > 1: @@ -372,6 +385,7 @@ def test_selective_scan( delta_bias=delta_bias, delta_softplus=delta_softplus, ssm_states=state, + block_size=2048, ) @@ -586,6 +600,7 @@ def test_selective_scan_varlen( padded_state_indices, has_initial_state, prev_state, + block_size=2048, ) diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index fd2df329f17f9..681b380e6a155 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -19,6 +19,8 @@ pytestmark = pytest.mark.hybrid_model # meaning that it will be used in all tests in this file # The rest of the models will only be tested by test_models +APC_MULTIPLY_BY = 300 + SSM_MODELS = [ "state-spaces/mamba-130m-hf", "tiiuae/falcon-mamba-tiny-dev", @@ -380,7 +382,7 @@ def _get_vLLM_output( return outs, vllm_model -@pytest.mark.parametrize("model", [HYBRID_MODELS[3]]) +@pytest.mark.parametrize("model", [HYBRID_MODELS[0], HYBRID_MODELS[3]]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("n_repetitions", [2]) # If num_logprobs is set to -1, then the stringent version @@ -410,10 +412,8 @@ def test_apc_single_prompt( check_logprobs_close if num_logprobs > 0 else check_outputs_equal # type: ignore ) - MULTIPLE = 300 - # Sample prompts. - generated_prompts = [MULTIPLE * example_prompts[0]] + generated_prompts = [APC_MULTIPLY_BY * example_prompts[0]] max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts) vllm_runner_kwargs = _get_vllm_runner_params( @@ -446,7 +446,7 @@ def test_apc_single_prompt( ) -@pytest.mark.parametrize("model", [HYBRID_MODELS[3]]) +@pytest.mark.parametrize("model", [HYBRID_MODELS[0], HYBRID_MODELS[3]]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("n_repetitions", [2]) # If num_logprobs is set to -1, then the stringent version @@ -476,10 +476,8 @@ def test_apc_single_prompt_block_align_alignment( check_logprobs_close if num_logprobs > 0 else check_outputs_equal # type: ignore ) - MULTIPLE = 300 - # Sample prompts. 
This custom prompt is used, as it causes the most issues - generated_prompts = ["The president of the United States is " * MULTIPLE] + generated_prompts = ["The president of the United States is " * APC_MULTIPLY_BY] max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts) vllm_runner_kwargs = _get_vllm_runner_params( @@ -528,7 +526,7 @@ def test_apc_single_prompt_block_align_alignment( ) -@pytest.mark.parametrize("model", [HYBRID_MODELS[3]]) +@pytest.mark.parametrize("model", [HYBRID_MODELS[0], HYBRID_MODELS[3]]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("n_repetitions", [2]) # If num_logprobs is set to -1, then the stringent version @@ -558,10 +556,8 @@ def test_apc_multiple_prompts_all_cached_outputs( check_logprobs_close if num_logprobs > 0 else check_outputs_equal # type: ignore ) - MULTIPLE = 300 - # Sample prompts. - generated_prompts = [MULTIPLE * prompt for prompt in example_prompts] + generated_prompts = [APC_MULTIPLY_BY * prompt for prompt in example_prompts] max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts) vllm_runner_kwargs = _get_vllm_runner_params( @@ -595,7 +591,7 @@ def test_apc_multiple_prompts_all_cached_outputs( ) -@pytest.mark.parametrize("model", [HYBRID_MODELS[3]]) +@pytest.mark.parametrize("model", [HYBRID_MODELS[0], HYBRID_MODELS[3]]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("n_repetitions", [2]) # If num_logprobs is set to -1, then the stringent version @@ -625,12 +621,12 @@ def test_apc_multiple_prompts_block_align_alignment( check_logprobs_close if num_logprobs > 0 else check_outputs_equal # type: ignore ) - MULTIPLE = 300 - # Sample prompts. This custom prompt is used, as it causes the most issues prompt_text = "The president of the United States is " prompt_offsets = [0, 3, 7, 13, 17, 22, 25, 31] - generated_prompts = [prompt_text[offset:] * MULTIPLE for offset in prompt_offsets] + generated_prompts = [ + prompt_text[offset:] * APC_MULTIPLY_BY for offset in prompt_offsets + ] max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts) vllm_runner_kwargs = _get_vllm_runner_params( @@ -679,7 +675,7 @@ def test_apc_multiple_prompts_block_align_alignment( ) -@pytest.mark.parametrize("model", [HYBRID_MODELS[3]]) +@pytest.mark.parametrize("model", [HYBRID_MODELS[0], HYBRID_MODELS[3]]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("n_repetitions", [2]) # If num_logprobs is set to -1, then the stringent version @@ -709,10 +705,8 @@ def test_apc_multiple_prompts_partial_cached_outputs( check_logprobs_close if num_logprobs > 0 else check_outputs_equal # type: ignore ) - MULTIPLE = 300 - # Sample prompts. 
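A note on the kernel change above: with prefix caching enabled, cache_indices becomes a per-request row of physical cache slots, and the scalar tensors initial_state_idx, block_idx_first_scheduled_token, and block_idx_last_scheduled_token select positions within that row. A rough Python sketch of the addressing scheme (illustrative only, not part of this patch):

    import torch

    def reference_state_slots(
        cache_indices: torch.Tensor,              # (batch, max_blocks), physical slots
        initial_state_idx: torch.Tensor,          # (batch,), position of initial state
        block_idx_first_scheduled: torch.Tensor,  # (batch,)
        block_idx_last_scheduled: torch.Tensor,   # (batch,)
        n_chunks: int,
        batch_id: int,
    ) -> tuple[int, list[int]]:
        # The initial state is read from the block holding the last computed token.
        load_slot = int(cache_indices[batch_id, initial_state_idx[batch_id]])
        # Intermediate chunks write their states into consecutive scheduled blocks;
        # the final chunk always writes into the last scheduled block.
        store_slots = [
            int(
                cache_indices[
                    batch_id,
                    block_idx_last_scheduled[batch_id]
                    if chunk == n_chunks - 1
                    else block_idx_first_scheduled[batch_id] + chunk,
                ]
            )
            for chunk in range(n_chunks)
        ]
        return load_slot, store_slots

This mirrors the store logic in selective_scan_fwd_kernel: one state write per iteration chunk, so every mamba block boundary leaves a reusable snapshot in the cache.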
- generated_prompts = [MULTIPLE * prompt for prompt in example_prompts] + generated_prompts = [APC_MULTIPLY_BY * prompt for prompt in example_prompts] max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts) vllm_runner_kwargs = _get_vllm_runner_params( diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 9110b0573fc92..61cf54fcfa39a 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1719,6 +1719,10 @@ def selective_scan_fwd( has_initial_state: torch.Tensor | None, ssm_states: torch.Tensor, pad_slot_id: int, + block_size: int = 1024, + block_idx_first_scheduled_token: torch.Tensor | None = None, + block_idx_last_scheduled_token: torch.Tensor | None = None, + initial_state_idx: torch.Tensor | None = None, ): torch.ops._C.selective_scan_fwd( u, @@ -1735,6 +1739,10 @@ def selective_scan_fwd( has_initial_state, ssm_states, pad_slot_id, + block_size, + block_idx_first_scheduled_token, + block_idx_last_scheduled_token, + initial_state_idx, ) diff --git a/vllm/config/model.py b/vllm/config/model.py index 082f90653f5af..2e80df4311035 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -1483,6 +1483,12 @@ class ModelConfig: if chunk_size is None: # used by e.g. Mamba2, NemotronH, Zamba chunk_size = getattr(self.hf_text_config, "chunk_size", None) + + # Since Mamba1 does not have a notion of chunks, + # we fall back to a default chunk size of 2048. + if chunk_size is None: + chunk_size = 2048 + return chunk_size def get_multimodal_config(self) -> MultiModalConfig: diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index a9a0c216474bc..b6345b8af7f0a 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -241,18 +241,21 @@ class MambaMixer(MambaBase, CustomOp): forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata + assert self.cache_config is not None + mamba_block_size = self.cache_config.mamba_block_size + prefix_caching_enabled = self.cache_config.enable_prefix_caching + if attn_metadata is not None: assert isinstance(attn_metadata, dict) attn_metadata = attn_metadata[self.prefix] - mamba1_metadata = attn_metadata - assert isinstance(mamba1_metadata, Mamba1AttentionMetadata) - query_start_loc = mamba1_metadata.query_start_loc - state_indices_tensor = mamba1_metadata.state_indices_tensor + assert isinstance(attn_metadata, Mamba1AttentionMetadata) + query_start_loc_p = attn_metadata.query_start_loc_p + state_indices_tensor = attn_metadata.state_indices_tensor self_kv_cache = self.kv_cache[forward_context.virtual_engine] conv_state = self_kv_cache[0].transpose(-1, -2) ssm_state = self_kv_cache[1] - has_initial_states = mamba1_metadata.has_initial_states - num_padded_decodes = mamba1_metadata.num_padded_decodes + has_initial_states_p = attn_metadata.has_initial_states_p + num_padded_decodes = attn_metadata.num_padded_decodes # 1.
Gated MLP's linear projection projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1) @@ -279,12 +282,8 @@ class MambaMixer(MambaBase, CustomOp): hidden_states_BC, gate, state_indices_tensor, - query_start_loc, - has_initial_states, num_prefill_tokens, - num_decode_tokens, num_prefills, - num_decodes, num_padded_decodes, ) hidden_states_BC_p = prefill_decode_split.hidden_states_BC_p @@ -293,8 +292,34 @@ class MambaMixer(MambaBase, CustomOp): gate_d = prefill_decode_split.gate_d state_indices_tensor_p = prefill_decode_split.state_indices_tensor_p state_indices_tensor_d = prefill_decode_split.state_indices_tensor_d - query_start_loc_p = prefill_decode_split.query_start_loc_p - has_initial_states_p = prefill_decode_split.has_initial_states_p + + if prefix_caching_enabled: + block_idx_last_computed_token_d, block_idx_last_computed_token_p = ( + torch.split( + attn_metadata.block_idx_last_computed_token, + [num_decodes, num_prefills], + dim=0, + ) + ) + block_idx_last_scheduled_token_d, block_idx_last_scheduled_token_p = ( + torch.split( + attn_metadata.block_idx_last_scheduled_token, + [num_decodes, num_prefills], + dim=0, + ) + ) + + block_idx_first_scheduled_token_p = ( + attn_metadata.block_idx_first_scheduled_token_p + ) + num_computed_tokens_p = attn_metadata.num_computed_tokens_p + else: + block_idx_last_computed_token_d = None + block_idx_last_computed_token_p = None + block_idx_last_scheduled_token_d = None + block_idx_last_scheduled_token_p = None + block_idx_first_scheduled_token_p = None + num_computed_tokens_p = None ssm_outputs = [] @@ -309,6 +334,11 @@ class MambaMixer(MambaBase, CustomOp): has_initial_state=has_initial_states_p, cache_indices=state_indices_tensor_p, query_start_loc=query_start_loc_p, + block_idx_first_scheduled_token=block_idx_first_scheduled_token_p, + block_idx_last_scheduled_token=block_idx_last_scheduled_token_p, + initial_state_idx=block_idx_last_computed_token_p, + num_computed_tokens=num_computed_tokens_p, + block_size_to_align=mamba_block_size, ) # 3. State Space Model sequence transformations. discrete_time_step_p, B_p, C_p = self._ssm_transform( @@ -331,10 +361,24 @@ class MambaMixer(MambaBase, CustomOp): cache_indices=state_indices_tensor_p, has_initial_state=has_initial_states_p, query_start_loc=query_start_loc_p, + block_size=mamba_block_size, + block_idx_first_scheduled_token=block_idx_first_scheduled_token_p, + block_idx_last_scheduled_token=block_idx_last_scheduled_token_p, + initial_state_idx=block_idx_last_computed_token_p, ) ssm_outputs.append(scan_out_p) if has_decode: + if prefix_caching_enabled: + state_indices_tensor_d_input = state_indices_tensor_d.gather( + 1, block_idx_last_computed_token_d.unsqueeze(1) + ).squeeze(1) + state_indices_tensor_d_output = state_indices_tensor_d.gather( + 1, block_idx_last_scheduled_token_d.unsqueeze(1) + ).squeeze(1) + else: + state_indices_tensor_d_input = state_indices_tensor_d + state_indices_tensor_d_output = state_indices_tensor_d # 2. Convolution sequence transformation conv_out_d = causal_conv1d_update( hidden_states_BC_d.transpose(0, 1), @@ -343,6 +387,8 @@ class MambaMixer(MambaBase, CustomOp): self.conv1d.bias, self.activation, conv_state_indices=state_indices_tensor_d, + block_idx_last_scheduled_token=block_idx_last_scheduled_token_d, + initial_state_idx=block_idx_last_computed_token_d, ).transpose(0, 1) # 3. State Space Model sequence transformation. 
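In the decode branch above, each request reads its SSM state from the block holding the last computed token and writes the updated state to the block holding the last scheduled token; the two differ exactly when a decode step crosses a mamba block boundary. A small sketch of the gather with made-up values (not part of this patch):

    import torch

    # Hypothetical block tables for 3 decode requests (rows of cache slots).
    state_indices_tensor_d = torch.tensor(
        [[7, 8, 9], [3, 4, 0], [5, 0, 0]], dtype=torch.int64
    )
    block_idx_last_computed_token_d = torch.tensor([1, 1, 0])   # read positions
    block_idx_last_scheduled_token_d = torch.tensor([2, 1, 0])  # write positions

    # Same gather as in MambaMixer.forward_cuda: one slot per request.
    # src feeds state_batch_indices (state read), dst feeds
    # dst_state_batch_indices (state write) in selective_state_update.
    src = state_indices_tensor_d.gather(
        1, block_idx_last_computed_token_d.unsqueeze(1)
    ).squeeze(1)
    dst = state_indices_tensor_d.gather(
        1, block_idx_last_scheduled_token_d.unsqueeze(1)
    ).squeeze(1)
    print(src.tolist(), dst.tolist())  # [8, 4, 5] [9, 4, 5]

Here only request 0 crosses a block boundary, so it reads slot 8 and writes slot 9, while the other requests update their current block in place.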
@@ -364,7 +410,8 @@ class MambaMixer(MambaBase, CustomOp): gate_d.transpose(0, 1), time_proj_bias, dt_softplus=True, - state_batch_indices=state_indices_tensor_d, + state_batch_indices=state_indices_tensor_d_input, + dst_state_batch_indices=state_indices_tensor_d_output, out=scan_outputs_d, ) scan_outputs_d = scan_outputs_d.transpose(0, 1) @@ -423,20 +470,14 @@ class PrefillDecodeSplit(NamedTuple): gate_d: torch.Tensor state_indices_tensor_p: torch.Tensor state_indices_tensor_d: torch.Tensor - query_start_loc_p: torch.Tensor | None - has_initial_states_p: torch.Tensor | None def split_batch_to_prefill_and_decode( hidden_states_BC: torch.Tensor, gate: torch.Tensor, state_indices_tensor: torch.Tensor, - query_start_loc: torch.Tensor, - has_initial_states: torch.Tensor | None, num_prefill_tokens: int, - num_decode_tokens: int, num_prefills: int, - num_decodes: int, num_padded_decodes: int, ) -> PrefillDecodeSplit: num_actual_tokens = num_prefill_tokens + num_padded_decodes @@ -457,16 +498,6 @@ def split_batch_to_prefill_and_decode( [num_padded_decodes, num_prefills], dim=0, ) - query_start_loc_p = ( - query_start_loc[-num_prefills - 1 :] - num_padded_decodes - if num_prefills > 0 - else None - ) - has_initial_states_p = ( - has_initial_states[-num_prefills:] - if (has_initial_states is not None and num_prefills > 0) - else None - ) return PrefillDecodeSplit( hidden_states_BC_p=hidden_states_BC_p, @@ -475,8 +506,6 @@ def split_batch_to_prefill_and_decode( gate_d=gate_d, state_indices_tensor_p=state_indices_tensor_p, state_indices_tensor_d=state_indices_tensor_d, - query_start_loc_p=query_start_loc_p, - has_initial_states_p=has_initial_states_p, ) diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py index 8722eb9a7b22f..53fd5d5458b09 100644 --- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py +++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py @@ -375,6 +375,10 @@ def selective_scan_fn( cache_indices=None, has_initial_state=None, pad_slot_id=PAD_SLOT_ID, + block_size=1024, + block_idx_first_scheduled_token=None, + block_idx_last_scheduled_token=None, + initial_state_idx=None, ) -> torch.Tensor: """ u: (dim, total_length) for varlen or (batch, dim, seqlen) @@ -397,7 +401,10 @@ def selective_scan_fn( x.shape=(dim,17) cache_indices: (batch) int32 A tensor with each cell is a correspondent - input and output ssm_state index + input and output ssm_state indices + - Without APC: (batch,) - single state index per batch item + - With APC: (batch, max_positions) - cache block indices for read/write + Each non-zero value indicates a cache block to load from and/or write to. has_initial_state: (batch) bool A tensor populated with ones and zeros, indicate if the ssm_state at the corresponding index should be @@ -408,6 +415,17 @@ def selective_scan_fn( that will not be processed, for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id] in this case, the kernel will not process entries at indices 0 and 3 + block_size: int + The block size to align the cached states to + block_idx_first_scheduled_token: (batch,), dtype int32 + The pointer into cache_indices, where the first + cache block to be filled is located. + block_idx_last_scheduled_token: (batch,), dtype int32 + The pointer into cache_indices, where the last cache block + to be filled is located. + initial_state_idx: (batch,), dtype int32 + The pointer into cache_indices, where the cache block + containing the initial state is located. 
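As an illustration of these parameters (made-up values): with block_size = 1024, a request i that already has 1024 tokens cached and 2048 newly scheduled tokens would be driven by:

    i = 0  # request index, illustrative values only
    cache_indices = [[12, 13, 14]]        # slots for block positions 0..2
    initial_state_idx = [0]               # initial state read from slot 12
    block_idx_first_scheduled_token = [1] # first chunk state written to slot 13
    block_idx_last_scheduled_token = [2]  # final state written to slot 14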
returns output: (dim, total_length) for varlen or (batch, dim, seqlen) supports inplace replacement @@ -448,6 +466,10 @@ def selective_scan_fn( has_initial_state, ssm_states, pad_slot_id, + block_size, + block_idx_first_scheduled_token, + block_idx_last_scheduled_token, + initial_state_idx, ) if z is None: diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 7150977e9266b..5dda2ec97875f 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -299,7 +299,7 @@ class MambaModelConfig(VerifyAndUpdateConfig): if model_config.supports_mamba_prefix_caching: logger.info( "Warning: Prefix caching is currently enabled. " - "Its support for Mamba2 layers is experimental. " + "Its support for Mamba layers is experimental. " "Please report any issues you may observe." ) else: diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index f8a87cf6965f8..ba95021b0b542 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -38,7 +38,13 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.llama import LlamaMLP as JambaMLP from vllm.sequence import IntermediateTensors -from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP +from .interfaces import ( + HasInnerState, + IsHybrid, + SupportsLoRA, + SupportsMambaPrefixCaching, + SupportsPP, +) from .utils import ( AutoWeightsLoader, WeightsMapper, @@ -454,7 +460,14 @@ class JambaModel(nn.Module): return loaded_params -class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, IsHybrid): +class JambaForCausalLM( + nn.Module, + HasInnerState, + SupportsLoRA, + SupportsPP, + IsHybrid, + SupportsMambaPrefixCaching, +): hf_to_vllm_mapper = WeightsMapper( orig_to_new_substr={".self_attn.": ".", ".A_log": ".A"}, ) @@ -477,12 +490,8 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, IsHyb def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config lora_config = vllm_config.lora_config scheduler_config = vllm_config.scheduler_config - assert not cache_config.enable_prefix_caching, ( - "Jamba currently does not support prefix caching" - ) super().__init__() self.config = config diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index fb145289fbfe9..f684203f6d35e 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -29,6 +29,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import ( HasInnerState, IsAttentionFree, + SupportsMambaPrefixCaching, SupportsPP, ) from vllm.sequence import IntermediateTensors @@ -193,15 +194,13 @@ class MambaModel(nn.Module): return loaded_params -class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP): +class MambaForCausalLM( + nn.Module, HasInnerState, IsAttentionFree, SupportsPP, SupportsMambaPrefixCaching +): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config lora_config = vllm_config.lora_config self.scheduler_config = vllm_config.scheduler_config - assert not cache_config.enable_prefix_caching, ( - "Mamba does not support prefix caching" - ) super().__init__() self.config = config diff --git 
a/vllm/v1/attention/backends/mamba1_attn.py b/vllm/v1/attention/backends/mamba1_attn.py index 30c63e0ded8e7..909af09be255a 100644 --- a/vllm/v1/attention/backends/mamba1_attn.py +++ b/vllm/v1/attention/backends/mamba1_attn.py @@ -7,11 +7,13 @@ import torch from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.utils import PAD_SLOT_ID +from vllm.config import VllmConfig from vllm.v1.attention.backends.mamba_attn import BaseMambaAttentionMetadataBuilder from vllm.v1.attention.backends.utils import ( CommonAttentionMetadata, split_decodes_and_prefills, ) +from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec class Mamba1AttentionBackend(AttentionBackend): @@ -22,32 +24,41 @@ class Mamba1AttentionBackend(AttentionBackend): @dataclass class Mamba1AttentionMetadata: - query_start_loc: torch.Tensor - context_lens_tensor: torch.Tensor + query_start_loc_p: torch.Tensor state_indices_tensor: torch.Tensor - has_initial_states: torch.Tensor | None + has_initial_states_p: torch.Tensor | None num_prefills: int num_prefill_tokens: int num_decodes: int num_decode_tokens: int num_padded_decodes: int + block_idx_last_scheduled_token: torch.Tensor # shape: [batch,] + block_idx_first_scheduled_token_p: torch.Tensor # shape: [batch,] + block_idx_last_computed_token: torch.Tensor # shape: [batch,] + num_computed_tokens_p: torch.Tensor # shape: [batch,] + class Mamba1AttentionMetadataBuilder( BaseMambaAttentionMetadataBuilder[Mamba1AttentionMetadata] ): + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): + super().__init__(kv_cache_spec, layer_names, vllm_config, device) + assert isinstance(kv_cache_spec, MambaSpec) + def build( self, common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata, fast_build: bool = False, ) -> Mamba1AttentionMetadata: - query_start_loc = common_attn_metadata.query_start_loc - - state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] - context_lens_tensor = common_attn_metadata.num_computed_tokens_cpu.to( - query_start_loc.device - ) + num_reqs = common_attn_metadata.num_reqs num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( split_decodes_and_prefills( @@ -55,32 +66,100 @@ class Mamba1AttentionMetadataBuilder( ) ) - has_initial_states = None + has_initial_states_p = None + query_start_loc_p = None padded_decodes = num_decodes + num_computed_tokens, num_computed_tokens_p = None, None + block_idx_first_scheduled_token = None + block_idx_first_scheduled_token_p = None + + # TODO(@Josephasafg) Mamba1 and Mamba2 have a lot of code in common here. 
+ # We should consolidate this code + if self.vllm_config.cache_config.enable_prefix_caching: + # Return a tensor of shape (#requests, #max blocks) + state_indices_tensor = common_attn_metadata.block_table_tensor + mamba_block_size = self.kv_cache_spec.block_size + num_computed_tokens = common_attn_metadata.num_computed_tokens_cpu.to( + self.device + ) + ( + block_idx_last_computed_token, + block_idx_first_scheduled_token, + block_idx_last_scheduled_token, + ) = self._compute_prefix_caching_block_indices( + common_attn_metadata, mamba_block_size + ) + else: + # Always return just a single block per each request: + state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] + block_idx_last_scheduled_token = None + block_idx_last_computed_token = None if num_prefills > 0: - has_initial_states = context_lens_tensor > 0 + query_start_loc_p = ( + common_attn_metadata.query_start_loc[-num_prefills - 1 :] + - num_decode_tokens + ) + has_initial_states_cpu = ( + common_attn_metadata.num_computed_tokens_cpu[ + num_reqs - num_prefills : num_reqs + ] + > 0 + ) + has_initial_states_p = has_initial_states_cpu.to( + common_attn_metadata.query_start_loc.device + ) + + if self.vllm_config.cache_config.enable_prefix_caching: + assert num_computed_tokens is not None + num_computed_tokens_p = num_computed_tokens[ + num_reqs - num_prefills : num_reqs + ] + assert block_idx_first_scheduled_token is not None + block_idx_first_scheduled_token_p = block_idx_first_scheduled_token[ + num_reqs - num_prefills : num_reqs + ] + elif ( num_decodes > 0 and num_decodes <= self.decode_cudagraph_max_bs and self.compilation_config.full_cuda_graph ): - state_indices_for_decode = state_indices_tensor[:num_decodes] padded_decodes = self.vllm_config.pad_for_cudagraph(num_decodes) self.state_indices_tensor[:num_decodes].copy_( - state_indices_for_decode, non_blocking=True + state_indices_tensor, non_blocking=True ) state_indices_tensor = self.state_indices_tensor[:padded_decodes] state_indices_tensor[num_decodes:] = PAD_SLOT_ID + if self.vllm_config.cache_config.enable_prefix_caching: + self.block_idx_last_scheduled_token[:num_decodes].copy_( + block_idx_last_scheduled_token, non_blocking=True + ) + block_idx_last_scheduled_token = self.block_idx_last_scheduled_token[ + :padded_decodes + ] + block_idx_last_scheduled_token[num_decodes:] = 0 + + self.block_idx_last_computed_token[:num_decodes].copy_( + block_idx_last_computed_token, non_blocking=True + ) + block_idx_last_computed_token = self.block_idx_last_computed_token[ + :padded_decodes + ] + block_idx_last_computed_token[num_decodes:] = 0 + return Mamba1AttentionMetadata( - query_start_loc=query_start_loc, - context_lens_tensor=context_lens_tensor, - has_initial_states=has_initial_states, + query_start_loc_p=query_start_loc_p, + has_initial_states_p=has_initial_states_p, state_indices_tensor=state_indices_tensor, num_prefills=num_prefills, num_prefill_tokens=num_prefill_tokens, num_decodes=num_decodes, num_decode_tokens=num_decode_tokens, num_padded_decodes=padded_decodes, + block_idx_last_scheduled_token=block_idx_last_scheduled_token, + block_idx_first_scheduled_token_p=block_idx_first_scheduled_token_p, + block_idx_last_computed_token=block_idx_last_computed_token, + num_computed_tokens_p=num_computed_tokens_p, ) diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py index f9d2426eaf632..4bc1057333a50 100644 --- a/vllm/v1/attention/backends/mamba2_attn.py +++ b/vllm/v1/attention/backends/mamba2_attn.py @@ -147,27 +147,6
@@ class Mamba2AttentionMetadataBuilder( assert self.chunk_size is not None, ( "chunk_size needs to be set in the model config for Mamba2 models" ) - if self.vllm_config.cache_config.enable_prefix_caching: - self.state_indices_tensor = torch.empty( - ( - self.decode_cudagraph_max_bs, - cdiv( - vllm_config.model_config.max_model_len, kv_cache_spec.block_size - ), - ), - dtype=torch.int32, - device=device, - ) - self.block_idx_last_scheduled_token = torch.empty( - (self.decode_cudagraph_max_bs,), - dtype=torch.int32, - device=device, - ) - self.block_idx_last_computed_token = torch.empty( - (self.decode_cudagraph_max_bs,), - dtype=torch.int32, - device=device, - ) def build( self, @@ -202,20 +181,13 @@ class Mamba2AttentionMetadataBuilder( num_computed_tokens = common_attn_metadata.num_computed_tokens_cpu.to( self.device ) - # Block index of the last computed token - block_idx_last_computed_token = ( - cdiv(num_computed_tokens, mamba_block_size) - 1 + ( + block_idx_last_computed_token, + block_idx_first_scheduled_token, + block_idx_last_scheduled_token, + ) = self._compute_prefix_caching_block_indices( + common_attn_metadata, mamba_block_size ) - # which is <= block index for the first scheduled token - block_idx_first_scheduled_token = ( - cdiv(num_computed_tokens + 1, mamba_block_size) - 1 - ) - # which is <= block index of the last scheduled token - block_idx_last_scheduled_token = ( - cdiv(common_attn_metadata.seq_lens, mamba_block_size) - 1 - ) - # -1 in case it's non-computed and causes later issues with indexing - block_idx_last_computed_token = block_idx_last_computed_token.clamp(min=0) else: # Always return just a single block per each request: state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py index 52f26a9e61cab..49d7d6c31b9a0 100644 --- a/vllm/v1/attention/backends/mamba_attn.py +++ b/vllm/v1/attention/backends/mamba_attn.py @@ -7,6 +7,7 @@ from typing import ClassVar, TypeVar import torch from vllm.config import VllmConfig +from vllm.utils.math_utils import cdiv from vllm.v1.attention.backends.utils import ( AttentionCGSupport, AttentionMetadataBuilder, @@ -38,11 +39,35 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC): self.vllm_config.scheduler_config.max_num_seqs, self.compilation_config.max_cudagraph_capture_size, ) - self.state_indices_tensor = torch.empty( - (self.decode_cudagraph_max_bs,), - dtype=torch.int32, - device=device, - ) + + if self.vllm_config.cache_config.enable_prefix_caching: + self.state_indices_tensor = torch.empty( + ( + self.decode_cudagraph_max_bs, + cdiv( + self.vllm_config.model_config.max_model_len, + self.kv_cache_spec.block_size, + ), + ), + dtype=torch.int32, + device=device, + ) + self.block_idx_last_scheduled_token = torch.empty( + (self.decode_cudagraph_max_bs,), + dtype=torch.int32, + device=device, + ) + self.block_idx_last_computed_token = torch.empty( + (self.decode_cudagraph_max_bs,), + dtype=torch.int32, + device=device, + ) + else: + self.state_indices_tensor = torch.empty( + (self.decode_cudagraph_max_bs,), + dtype=torch.int32, + device=device, + ) def build_for_cudagraph_capture( self, common_attn_metadata: CommonAttentionMetadata @@ -61,3 +86,30 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC): m.max_query_len = 1 # decode-only return self.build(0, m) + + def _compute_prefix_caching_block_indices( + self, + common_attn_metadata: CommonAttentionMetadata, + 
mamba_block_size: int, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + num_computed_tokens = common_attn_metadata.num_computed_tokens_cpu.to( + self.device + ) + # Block index of the last computed token + block_idx_last_computed_token = cdiv(num_computed_tokens, mamba_block_size) - 1 + # which is <= block index for the first scheduled token + block_idx_first_scheduled_token = ( + cdiv(num_computed_tokens + 1, mamba_block_size) - 1 + ) + # which is <= block index of the last scheduled token + block_idx_last_scheduled_token = ( + cdiv(common_attn_metadata.seq_lens, mamba_block_size) - 1 + ) + # -1 in case it's non-computed and causes later issues with indexing + block_idx_last_computed_token = block_idx_last_computed_token.clamp(min=0) + + return ( + block_idx_last_computed_token, + block_idx_first_scheduled_token, + block_idx_last_scheduled_token, + )
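A quick sanity check of the index arithmetic in _compute_prefix_caching_block_indices, as a standalone sketch with made-up sizes (cdiv re-implemented locally to mirror vllm.utils.math_utils.cdiv):

    import torch

    def cdiv(a, b):
        # ceil-division, same contract as vllm.utils.math_utils.cdiv
        return -(-a // b)

    block_size = 16
    num_computed_tokens = torch.tensor([0, 15, 16, 40])  # tokens already cached
    seq_lens = torch.tensor([10, 20, 32, 48])            # computed + scheduled

    last_computed = (cdiv(num_computed_tokens, block_size) - 1).clamp(min=0)
    first_scheduled = cdiv(num_computed_tokens + 1, block_size) - 1
    last_scheduled = cdiv(seq_lens, block_size) - 1

    print(last_computed.tolist())    # [0, 0, 0, 2]
    print(first_scheduled.tolist())  # [0, 0, 1, 2]
    print(last_scheduled.tolist())   # [0, 1, 1, 2]

Per request the three indices are non-decreasing (last computed <= first scheduled <= last scheduled), which is the invariant the comments in the helper rely on.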