[Kernel][Model] Improve continuous batching for Jamba and Mamba (#9189)

Mor Zusman authored on 2024-10-17 00:12:43 +08:00; committed by GitHub
parent 415f76a9cb
commit fb60ae9b91
15 changed files with 504 additions and 432 deletions
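In short: the causal-conv1d and selective-scan kernels now take a pad_slot_id argument and write their results in place, so a continuously batched input (possibly padded, e.g. up to a CUDA-graph batch size) can mark unused entries in cache_indices / conv_state_indices and have the kernels skip them entirely. A minimal sketch of the convention, using PAD_SLOT_ID from vllm.attention.backends.utils as the tests below do (slot numbers are illustrative):

import torch

from vllm.attention.backends.utils import PAD_SLOT_ID

# Three live sequences own cache slots 7, 1 and 20; the batch is padded to
# five entries. The two pad entries carry PAD_SLOT_ID and are skipped by the
# kernels, leaving their cache slots and their rows of the input untouched.
state_indices = torch.tensor([7, 1, 20], dtype=torch.int32, device="cuda")
padding = 2
padded_state_indices = torch.concat([
    state_indices,
    torch.full((padding, ), PAD_SLOT_ID, dtype=torch.int32, device="cuda"),
])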

View File

@@ -55,6 +55,7 @@ void set_conv_params_fwd(ConvParamsBase &params,
                          const at::Tensor out,
                          const c10::optional<at::Tensor>& bias,
                          bool silu_activation,
+                         int64_t pad_slot_id,
                          const c10::optional<at::Tensor>& query_start_loc = std::nullopt,
                          const c10::optional<at::Tensor>& cache_indices = std::nullopt,
                          const c10::optional<at::Tensor>& has_initial_state = std::nullopt) {
@@ -66,6 +67,7 @@ void set_conv_params_fwd(ConvParamsBase &params,
     params.dim = dim;
     params.seqlen = seqlen;
     params.width = width;
+    params.pad_slot_id = pad_slot_id;

     params.silu_activation = silu_activation;
@@ -90,14 +92,16 @@ void set_conv_params_fwd(ConvParamsBase &params,
 }

-at::Tensor
-causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
+void causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
                   const c10::optional<at::Tensor> &bias_,
                   const c10::optional<at::Tensor> &conv_states,
                   const c10::optional<at::Tensor> &query_start_loc,
                   const c10::optional<at::Tensor> &cache_indices,
                   const c10::optional<at::Tensor> &has_initial_state,
-                  bool silu_activation) {
+                  bool silu_activation,
+                  // used to identify padding entries if cache_indices provided
+                  // in case of padding, the kernel will return early
+                  int64_t pad_slot_id) {
     auto input_type = x.scalar_type();
     auto weight_type = weight.scalar_type();
     TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
@@ -153,12 +157,13 @@ causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
         CHECK_SHAPE(cache_indices_, batch_size);
     }

-    at::Tensor out = torch::empty_like(x);
+    at::Tensor out = x;

     ConvParamsBase params;
     set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
                         bias_,
                         silu_activation,
+                        pad_slot_id,
                         query_start_loc,
                         cache_indices,
                         has_initial_state
@@ -183,18 +188,19 @@ causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
     DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] {
         causal_conv1d_fwd_cuda<input_t, weight_t>(params, stream);
     });
-    return out;
 }

-at::Tensor
-causal_conv1d_update(const at::Tensor &x,
+void causal_conv1d_update(const at::Tensor &x,
                      const at::Tensor &conv_state,
                      const at::Tensor &weight,
                      const c10::optional<at::Tensor> &bias_,
                      bool silu_activation,
                      const c10::optional<at::Tensor> &cache_seqlens_,
-                     const c10::optional<at::Tensor> &conv_state_indices_) {
+                     const c10::optional<at::Tensor> &conv_state_indices_,
+                     // used to identify padding entries if cache_indices provided
+                     // in case of padding, the kernel will return early
+                     int64_t pad_slot_id) {
     auto input_type = x.scalar_type();
     auto weight_type = weight.scalar_type();
     TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
@@ -227,12 +233,13 @@ causal_conv1d_update(const at::Tensor &x,
         CHECK_SHAPE(bias, dim);
     }

-    at::Tensor out = torch::empty_like(x);
+    at::Tensor out = x;

     ConvParamsBase params;
     set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
                         bias_,
-                        silu_activation);
+                        silu_activation,
+                        pad_slot_id);
     params.conv_state_ptr = conv_state.data_ptr();
     params.conv_state_len = conv_state_len;
     // All stride are in elements, not bytes.
@@ -274,7 +281,6 @@ causal_conv1d_update(const at::Tensor &x,
     DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] {
         causal_conv1d_update_cuda<input_t, weight_t>(params, stream);
     });
-    return out;
 }

 template<int kNThreads_, int kWidth_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
@@ -340,7 +346,10 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) {
     int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr
                          : reinterpret_cast<int *>(params.cache_indices_ptr);
     int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id];
+    // cache_index == params.pad_slot_id is defined as padding, so we exit early
+    if (cache_index == params.pad_slot_id){
+        return;
+    }
     input_t *conv_states = params.conv_states_ptr == nullptr ? nullptr
                            : reinterpret_cast<input_t *>(params.conv_states_ptr) + cache_index * params.conv_states_batch_stride + channel_id * params.conv_states_c_stride;
@@ -528,6 +537,10 @@ void causal_conv1d_update_kernel(ConvParamsBase params) {
     const int conv_state_batch_coord = params.conv_state_indices_ptr == nullptr
                                        ? batch_id
                                        : params.conv_state_indices_ptr[batch_id];
+    // conv_state_batch_coord == params.pad_slot_id is defined as padding so we exit early
+    if (conv_state_batch_coord == params.pad_slot_id){
+        return;
+    }
     input_t *conv_state = reinterpret_cast<input_t *>(params.conv_state_ptr)
                           + conv_state_batch_coord * params.conv_state_batch_stride
                           + channel_id * params.conv_state_c_stride;
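The two early-exit guards above are the heart of the change: a padded batch entry returns before any load or store, so neither the conv-state cache nor the in-place output row is touched. A rough Python restatement of what the update kernel now does per batch entry (illustrative reference only, ignoring bias and activation; not the CUDA code):

import torch

def conv1d_update_reference(x, conv_state, weight, state_indices, pad_slot_id):
    # x: (batch, dim, 1), written in place; conv_state: (cache_size, dim, width - 1)
    for b, slot in enumerate(state_indices.tolist()):
        if slot == pad_slot_id:
            continue  # padding: leave x[b] and all cache slots untouched
        window = torch.cat([conv_state[slot], x[b]], dim=-1)  # (dim, width)
        conv_state[slot] = window[:, 1:]                      # shift state in place
        x[b] = (window * weight).sum(dim=-1, keepdim=True)    # output over input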

View File

@@ -13,6 +13,7 @@ struct ConvParamsBase {
     using index_t = uint32_t;

     int batch, dim, seqlen, width;
+    int64_t pad_slot_id;
     bool silu_activation;

     index_t x_batch_stride;

View File

@@ -21,6 +21,7 @@ struct SSMParamsBase {
     int dim_ngroups_ratio;
     bool is_variable_B;
     bool is_variable_C;
+    int64_t pad_slot_id;

     bool delta_softplus;

View File

@@ -115,6 +115,10 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
     const int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr
                                : reinterpret_cast<int *>(params.cache_indices_ptr);
     const int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id];
+    // cache_index == params.pad_slot_id is defined as padding, so we exit early
+    if (cache_index == params.pad_slot_id){
+        return;
+    }
     input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + sequence_start_index * params.u_batch_stride
         + dim_id * kNRows * params.u_d_stride;
     input_t *delta = reinterpret_cast<input_t *>(params.delta_ptr) + sequence_start_index * params.delta_batch_stride
@@ -387,7 +391,6 @@ void set_ssm_params_fwd(SSMParamsBase &params,
                         const size_t seqlen,
                         const size_t dstate,
                         const size_t n_groups,
-                        const size_t n_chunks,
                         const bool is_variable_B,
                         const bool is_variable_C,
                         // device pointers
@@ -407,7 +410,8 @@ void set_ssm_params_fwd(SSMParamsBase &params,
                         const c10::optional<at::Tensor>& query_start_loc,
                         const c10::optional<at::Tensor>& cache_indices,
                         const c10::optional<at::Tensor>& has_initial_state,
-                        bool varlen) {
+                        bool varlen,
+                        int64_t pad_slot_id) {

     // Reset the parameters
     memset(&params, 0, sizeof(params));
@@ -417,8 +421,8 @@ void set_ssm_params_fwd(SSMParamsBase &params,
     params.seqlen = seqlen;
     params.dstate = dstate;
     params.n_groups = n_groups;
-    params.n_chunks = n_chunks;
     params.dim_ngroups_ratio = dim / n_groups;
+    params.pad_slot_id = pad_slot_id;

     params.delta_softplus = delta_softplus;
@@ -507,7 +511,10 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
                         const c10::optional<torch::Tensor> &query_start_loc,
                         const c10::optional<torch::Tensor> &cache_indices,
                         const c10::optional<torch::Tensor> &has_initial_state,
-                        const torch::Tensor &ssm_states) {
+                        const torch::Tensor &ssm_states,
+                        // used to identify padding entries if cache_indices provided
+                        // in case of padding, the kernel will return early
+                        int64_t pad_slot_id) {
     auto input_type = u.scalar_type();
     auto weight_type = A.scalar_type();
     TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
@@ -618,18 +625,14 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
         out_z = z;

-    const int n_chunks = (seqlen + 2048 - 1) / 2048;
-    // const int n_chunks = (seqlen + 1024 - 1) / 1024;
-    // at::Tensor out = torch::empty_like(u);
     // Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout
     at::Tensor out = delta;
     TORCH_CHECK(ssm_states.scalar_type() == input_type);
     TORCH_CHECK(ssm_states.is_cuda());
     TORCH_CHECK(ssm_states.stride(-1) == 1);
-    CHECK_SHAPE(ssm_states, batch_size, dim, dstate);

     SSMParamsBase params;
-    set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C,
+    set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, is_variable_B, is_variable_C,
                        u, delta, A, B, C, out, z, out_z,
                        D_,
                        delta_bias_,
@@ -639,7 +642,8 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
                        query_start_loc,
                        cache_indices,
                        has_initial_state,
-                       varlen
+                       varlen,
+                       pad_slot_id
                        );
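Worth flagging in this hunk: the scan no longer allocates an output buffer (out now aliases delta, mirroring out = x in the conv kernels), the dead n_chunks plumbing is gone, and the CHECK_SHAPE on ssm_states is dropped because its leading dimension is now the cache size rather than the batch size. From Python, the result comes back in the buffer you passed; a minimal end-to-end sketch (shapes illustrative, requires a CUDA build of vLLM):

import torch

from vllm.model_executor.layers.mamba.ops.mamba_ssm import selective_scan_fn

dim, dstate, seqlen = 4, 8, 16
u = torch.randn(dim, seqlen, device="cuda")        # varlen layout (dim, total_length)
delta = torch.rand(dim, seqlen, device="cuda")
A = -torch.rand(dim, dstate, device="cuda")
B = torch.randn(1, dstate, seqlen, device="cuda")  # (ngroups, dstate, total_length)
C = torch.randn(1, dstate, seqlen, device="cuda")
ssm_states = torch.zeros(1, dim, dstate, device="cuda")
query_start_loc = torch.tensor([0, seqlen], dtype=torch.int32, device="cuda")
cache_indices = torch.tensor([0], dtype=torch.int32, device="cuda")
has_initial_state = torch.tensor([False], device="cuda")

out = selective_scan_fn(u, ssm_states, delta, A, B, C, None, None, None,
                        True, query_start_loc, cache_indices,
                        has_initial_state)
# With z=None the scan output is written over delta and returned as-is.
assert out.data_ptr() == delta.data_ptr()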

View File

@@ -157,21 +157,23 @@ void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta,
                         const c10::optional<torch::Tensor>& query_start_loc,
                         const c10::optional<torch::Tensor>& cache_indices,
                         const c10::optional<torch::Tensor>& has_initial_state,
-                        const torch::Tensor& ssm_states);
+                        const torch::Tensor& ssm_states, int64_t pad_slot_id);

-at::Tensor causal_conv1d_update(
-    const at::Tensor& x, const at::Tensor& conv_state, const at::Tensor& weight,
-    const c10::optional<at::Tensor>& bias_, bool silu_activation,
-    const c10::optional<at::Tensor>& cache_seqlens_,
-    const c10::optional<at::Tensor>& conv_state_indices_);
+void causal_conv1d_update(const at::Tensor& x, const at::Tensor& conv_state,
+                          const at::Tensor& weight,
+                          const c10::optional<at::Tensor>& bias_,
+                          bool silu_activation,
+                          const c10::optional<at::Tensor>& cache_seqlens_,
+                          const c10::optional<at::Tensor>& conv_state_indices_,
+                          int64_t pad_slot_id);

-at::Tensor causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
-                             const c10::optional<at::Tensor>& bias_,
-                             const c10::optional<at::Tensor>& conv_states,
-                             const c10::optional<at::Tensor>& query_start_loc,
-                             const c10::optional<at::Tensor>& cache_indices,
-                             const c10::optional<at::Tensor>& has_initial_state,
-                             bool silu_activation);
+void causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
+                       const c10::optional<at::Tensor>& bias_,
+                       const c10::optional<at::Tensor>& conv_states,
+                       const c10::optional<at::Tensor>& query_start_loc,
+                       const c10::optional<at::Tensor>& cache_indices,
+                       const c10::optional<at::Tensor>& has_initial_state,
+                       bool silu_activation, int64_t pad_slot_id);

 #ifndef USE_ROCM
 using fptr_t = int64_t;

View File

@@ -278,7 +278,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "Tensor? query_start_loc,"
       "Tensor? cache_indices,"
       "Tensor? has_initial_state,"
-      "Tensor! ssm_states) -> ()");
+      "Tensor! ssm_states,"
+      "int pad_slot_id) -> ()");
   ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);

   ops.def(
@@ -288,7 +289,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "Tensor? bias_,"
       "bool silu_activation,"
       "Tensor? cache_seqlens_,"
-      "Tensor? conv_state_indices) -> Tensor");
+      "Tensor? conv_state_indices,"
+      "int pad_slot_id) -> ()");
   ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update);

   ops.def(
@@ -298,7 +300,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "Tensor? query_start_loc,"
      "Tensor? cache_indices,"
       "Tensor? has_initial_state,"
-      "bool silu_activation) -> Tensor");
+      "bool silu_activation,"
+      "int pad_slot_id) -> ()");
   ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
 #endif
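All three schemas now end in "int pad_slot_id) -> ()": the ops mutate their Tensor! arguments and return nothing, since the outputs are written in place. The registrations are exercised through opcheck in the tests; a minimal standalone call mirroring them (sizes illustrative, requires a CUDA build):

import torch

from tests.kernels.utils import opcheck
from vllm.attention.backends.utils import PAD_SLOT_ID

x = torch.randn(2, 64, 1, device="cuda")
conv_state = torch.randn(2, 64, 3, device="cuda")  # width - 1 = 3
weight = torch.randn(64, 4, device="cuda")

# (x, conv_state, weight, bias_, silu_activation, cache_seqlens_,
#  conv_state_indices, pad_slot_id) -> ()
opcheck(torch.ops._C.causal_conv1d_update,
        (x, conv_state, weight, None, True, None, None, PAD_SLOT_ID))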

View File

@@ -6,6 +6,7 @@ import torch.nn.functional as F

 from tests.kernels.utils import opcheck
 from vllm import _custom_ops as ops  # noqa: F401
+from vllm.attention.backends.utils import PAD_SLOT_ID
 from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
     causal_conv1d_fn, causal_conv1d_update)
 from vllm.utils import seed_everything
@@ -114,8 +115,7 @@ def causal_conv1d_update_ref(x,
 @pytest.mark.parametrize("itype", [torch.bfloat16, torch.float])
 @pytest.mark.parametrize("silu_activation", [True])
 @pytest.mark.parametrize("has_bias", [True])
-def causal_conv1d_opcheck_fn(
-        x: torch.Tensor,
+def causal_conv1d_opcheck_fn(x: torch.Tensor,
                              weight: torch.Tensor,
                              bias: Optional[torch.Tensor] = None,
                              cu_seq_len: Optional[torch.Tensor] = None,
@@ -123,7 +123,7 @@ def causal_conv1d_opcheck_fn(
                              has_initial_state: Optional[torch.Tensor] = None,
                              conv_states: Optional[torch.Tensor] = None,
                              activation: Optional[str] = "silu",
-                             ):
+                             pad_slot_id: int = PAD_SLOT_ID):
     """
     x: (batch, dim, seqlen)
     weight: (dim, width)
@@ -141,16 +141,9 @@ def causal_conv1d_opcheck_fn(
         x = x.contiguous()
     bias = bias.contiguous() if bias is not None else None

-    opcheck(torch.ops._C.causal_conv1d_fwd, (
-        x,
-        weight,
-        bias,
-        conv_states,
-        cu_seq_len,
-        cache_indices,
-        has_initial_state,
-        activation in ["silu", "swish"],
-    ))
+    opcheck(torch.ops._C.causal_conv1d_fwd,
+            (x, weight, bias, conv_states, cu_seq_len, cache_indices,
+             has_initial_state, activation in ["silu", "swish"], pad_slot_id))

 @pytest.mark.parametrize("itype", [torch.bfloat16, torch.float])
@@ -233,17 +226,11 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation,
     seed_everything(0)
     batch = 2
     x = torch.randn(batch, dim, seqlen, device=device, dtype=itype)
+    x_ref = x.clone()
     conv_state = torch.randn(batch, dim, width - 1, device=device, dtype=itype)
-    weight = torch.randn(dim,
-                         width,
-                         device=device,
-                         dtype=itype,
-                         requires_grad=True)
-    if has_bias:
-        bias = torch.randn(dim, device=device, dtype=itype, requires_grad=True)
-    else:
-        bias = None
+    weight = torch.randn(dim, width, device=device, dtype=itype)
+    bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
     conv_state_ref = conv_state.detach().clone()
     activation = None if not silu_activation else "silu"
     out = causal_conv1d_update(x,
@@ -251,7 +238,7 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation,
                                weight,
                                bias,
                                activation=activation)
-    out_ref = causal_conv1d_update_ref(x,
+    out_ref = causal_conv1d_update_ref(x_ref,
                                        conv_state_ref,
                                        weight,
                                        bias,
@@ -260,15 +247,9 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation,
     assert torch.equal(conv_state, conv_state_ref)
     assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)

-    opcheck(torch.ops._C.causal_conv1d_update, (
-        x,
-        conv_state,
-        weight,
-        bias,
-        activation in ["silu", "swish"],
-        None,
-        None,
-    ))
+    opcheck(torch.ops._C.causal_conv1d_update,
+            (x, conv_state, weight, bias, activation
+             in ["silu", "swish"], None, None, PAD_SLOT_ID))

 @pytest.mark.parametrize("itype",
@@ -278,37 +259,48 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation,
 @pytest.mark.parametrize("seqlen", [1, 4, 5])
 @pytest.mark.parametrize("width", [2, 3, 4])
 @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
-def test_causal_conv1d_update_with_batch_gather(dim, width, seqlen, has_bias,
-                                                silu_activation, itype):
+# tests correctness in case subset of the sequences are padded
+@pytest.mark.parametrize("with_padding", [True, False])
+def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width,
+                                                seqlen, has_bias,
+                                                silu_activation, itype):
     device = "cuda"
     rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
     if itype == torch.bfloat16:
         rtol, atol = 1e-2, 5e-2

-    # set )seed
+    # set seed
     seed_everything(0)
-    batch = 64

-    x = torch.randn(batch, dim, 1, device=device, dtype=itype)
+    batch_size = 3
+    padding = 5 if with_padding else 0
+    padded_batch_size = batch_size + padding
+    total_entries = 10 * batch_size

-    total_entries = 10 * batch
+    x = torch.randn(padded_batch_size, dim, 1, device=device, dtype=itype)
+    x_ref = x.clone()

+    conv_state_indices = torch.randperm(total_entries)[:batch_size].to(
+        dtype=torch.int32, device=device)
+    unused_states_bool = torch.ones(total_entries,
+                                    dtype=torch.bool,
+                                    device=device)
+    unused_states_bool[conv_state_indices] = False
+    padded_state_indices = torch.concat([
+        conv_state_indices,
+        torch.as_tensor(
+            [PAD_SLOT_ID] * padding, dtype=torch.int32, device=device)
+    ],
+                                        dim=0)
     conv_state = torch.randn(total_entries,
                              dim,
                              width - 1,
                              device=device,
                              dtype=itype)
-    conv_state_indices = torch.randperm(total_entries)[:batch].to(
-        dtype=torch.int32, device=device)
+    conv_state_for_padding_test = conv_state.clone()

-    weight = torch.randn(dim,
-                         width,
-                         device=device,
-                         dtype=itype,
-                         requires_grad=True)
-    if has_bias:
-        bias = torch.randn(dim, device=device, dtype=itype, requires_grad=True)
-    else:
-        bias = None
+    weight = torch.randn(dim, width, device=device, dtype=itype)
+    bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
     conv_state_ref = conv_state[conv_state_indices, :].detach().clone()
     activation = None if not silu_activation else "silu"
     out = causal_conv1d_update(x,
@@ -316,45 +308,50 @@
                                weight,
                                bias,
                                activation=activation,
-                               conv_state_indices=conv_state_indices)
-    out_ref = causal_conv1d_update_ref(x,
+                               conv_state_indices=padded_state_indices,
+                               pad_slot_id=PAD_SLOT_ID)
+    out_ref = causal_conv1d_update_ref(x_ref[:batch_size],
                                        conv_state_ref,
                                        weight,
                                        bias,
                                        activation=activation)

     assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref)
-    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
+    assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol)
+    assert torch.equal(conv_state[unused_states_bool],
+                       conv_state_for_padding_test[unused_states_bool])

-    opcheck(torch.ops._C.causal_conv1d_update, (
-        x,
-        conv_state,
-        weight,
-        bias,
-        activation in ["silu", "swish"],
-        None,
-        conv_state_indices,
-    ))
+    opcheck(torch.ops._C.causal_conv1d_update,
+            (x, conv_state, weight, bias, activation
+             in ["silu", "swish"], None, padded_state_indices, PAD_SLOT_ID))

 @pytest.mark.parametrize("itype", [torch.bfloat16])
 @pytest.mark.parametrize("silu_activation", [True])
 @pytest.mark.parametrize("has_bias", [True])
 @pytest.mark.parametrize("width", [4])
-@pytest.mark.parametrize('seqlen',
-                         [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096])
+@pytest.mark.parametrize(
+    'seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 2049, 4096])
 @pytest.mark.parametrize('dim', [64, 4096])
-def test_causal_conv1d_varlen(dim, seqlen, width, has_bias, silu_activation,
-                              itype):
+# tests correctness in case subset of the sequences are padded
+@pytest.mark.parametrize('with_padding', [True, False])
+def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,
+                              silu_activation, itype):
     device = "cuda"
+    torch.cuda.empty_cache()
     rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
     if itype == torch.bfloat16:
         rtol, atol = 1e-2, 5e-2
     # set seed
     seed_everything(0)
-    batch = 1
     seqlens = []
-    nsplits = 3
+    batch_size = 4
+    if seqlen < 10:
+        batch_size = 1
+    padding = 3 if with_padding else 0
+    padded_batch_size = batch_size + padding
+    nsplits = padded_batch_size - 1
     eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values
     seqlens.append(
         torch.diff(
@@ -364,10 +361,11 @@ def test_causal_conv1d_varlen(dim, seqlen, width, has_bias, silu_activation,
     assert sum(seqlens[-1]) == seqlen
     assert all(s > 0 for s in seqlens[-1])

+    total_entries = batch_size * 10
     cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32)
     cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum],
                           dim=0)
-    x = torch.randn(batch, 4096 + dim + 64, seqlen, device=device,
+    x = torch.randn(1, 4096 + dim + 64, seqlen, device=device,
                     dtype=itype)[:, 4096:4096 + dim, :]
     weight = torch.randn(dim, width, device=device, dtype=itype)
     bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
@@ -375,7 +373,7 @@ def test_causal_conv1d_varlen(dim, seqlen, width, has_bias, silu_activation,
     weight_ref = weight.clone()
     bias_ref = bias.clone() if bias is not None else None
     activation = None if not silu_activation else "silu"
-    final_states = torch.randn(nsplits + 1,
+    final_states = torch.randn(total_entries,
                                dim,
                                width - 1,
                                device=x.device,
@@ -385,18 +383,27 @@ def test_causal_conv1d_varlen(dim, seqlen, width, has_bias, silu_activation,
         2, (cumsum.shape[0] - 1, ),
         dtype=torch.bool,
         device=x.device)
-    cache_indices = torch.randperm(cumsum.shape[0] - 1,
-                                   dtype=torch.int32,
-                                   device=x.device)
+    state_indices = torch.randperm(total_entries,
+                                   dtype=torch.int32,
+                                   device=x.device)[:batch_size]
+    padded_state_indices = torch.concat([
+        state_indices,
+        torch.as_tensor(
+            [PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
+    ],
+                                        dim=-1)
     out = causal_conv1d_fn(x.squeeze(0), weight, bias, cumsum.cuda(),
-                           cache_indices, has_initial_states, final_states,
-                           activation)
+                           padded_state_indices, has_initial_states,
+                           final_states, activation, PAD_SLOT_ID)

     out_ref = []
     out_ref_b = []

     splits = [torch.split(var, seqlens[0], dim=-1) for var in (x_ref)]
     for i in range(len(seqlens[0])):
         x_s = [v[i].unsqueeze(0) for v in splits][0]
+        if padded_state_indices[i] == PAD_SLOT_ID:
+            continue
         out_ref_b.append(
             causal_conv1d_ref(
                 x_s,
@@ -404,21 +411,17 @@ def test_causal_conv1d_varlen(dim, seqlen, width, has_bias, silu_activation,
                 bias_ref,
                 activation=activation,
                 return_final_states=True,
-                final_states_out=final_states_ref[cache_indices[i]].unsqueeze(
-                    0),
-                initial_states=final_states_ref[cache_indices[i]].unsqueeze(0)
-                if has_initial_states[i] else None))
+                final_states_out=final_states_ref[
+                    padded_state_indices[i]].unsqueeze(0),
+                initial_states=final_states_ref[padded_state_indices[i]].
+                unsqueeze(0) if has_initial_states[i] else None))
     out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=2))
-    out_ref = torch.cat(out_ref, dim=0)
-    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
-    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
-    print("Output state max diff"
-          f":{(final_states - final_states_ref).abs().max()}")
-    print("Output state mean diff"
-          f":{(final_states - final_states_ref).abs().mean()}")
-    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
+    out_ref_tensor = torch.cat(out_ref, dim=0)
+    unpadded_out = out[:, :out_ref_tensor.shape[-1]]
+    assert torch.allclose(unpadded_out, out_ref_tensor, rtol=rtol, atol=atol)
     assert torch.allclose(final_states, final_states_ref, rtol=rtol, atol=atol)
     causal_conv1d_opcheck_fn(x.squeeze(0), weight, bias, cumsum.cuda(),
-                             cache_indices, has_initial_states, final_states,
-                             activation)
+                             padded_state_indices, has_initial_states,
+                             final_states, activation)
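The pattern above is the template for all the padded-batch tests: run the kernel on the padded batch, compare only the first batch_size rows against the unpadded reference, and assert that cache slots not owned by a live sequence kept their previous contents. Condensed into a reusable check (names follow the tests; tolerances illustrative):

import torch

def check_padding_invariants(out, out_ref, state, state_before,
                             state_indices, batch_size,
                             rtol=1e-2, atol=5e-2):
    # Live rows must match the unpadded reference computation...
    assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol)
    # ...and every cache slot not owned by a live sequence stays untouched.
    unused = torch.ones(state.shape[0], dtype=torch.bool, device=state.device)
    unused[state_indices.long()] = False
    assert torch.equal(state[unused], state_before[unused])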

View File

@@ -5,6 +5,7 @@ from einops import rearrange, repeat

 from tests.kernels.utils import opcheck
 from vllm import _custom_ops as ops  # noqa: F401
+from vllm.attention.backends.utils import PAD_SLOT_ID
 from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
     selective_scan_fn, selective_state_update)
 from vllm.utils import seed_everything
@@ -174,7 +175,8 @@ def selective_scan_opcheck_fn(u,
                               cu_seq_len=None,
                               cache_indices=None,
                               has_initial_state=None,
-                              ssm_states=None):
+                              ssm_states=None,
+                              pad_slot_id=PAD_SLOT_ID):
     """if return_last_state is True, returns (out, last_state)
     last_state has shape (batch, dim, dstate).
     """
@@ -203,7 +205,7 @@ def selective_scan_opcheck_fn(u,
     # a bogus error.
     opcheck(torch.ops._C.selective_scan_fwd,
             (u, delta, A, B, C, D, z, delta_bias, delta_softplus, cu_seq_len,
-             cache_indices, has_initial_state, ssm_states),
+             cache_indices, has_initial_state, ssm_states, pad_slot_id),
             test_utils=["test_schema", "test_faketensor"])

@@ -404,9 +406,12 @@ def test_selective_state_update(dim, dstate, has_z, itype):
 @pytest.mark.parametrize("varBC_groups", [1, 2])
 @pytest.mark.parametrize("is_variable_C", [True])
 @pytest.mark.parametrize("is_variable_B", [True])
-def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups,
-                               has_D, has_z, has_delta_bias, delta_softplus,
-                               return_last_state, seqlen, itype, wtype):
+# tests correctness in case subset of the sequences are padded
+@pytest.mark.parametrize("with_padding", [False, True])
+def test_selective_scan_varlen(with_padding, is_variable_B, is_variable_C,
+                               varBC_groups, has_D, has_z, has_delta_bias,
+                               delta_softplus, return_last_state, seqlen,
+                               itype, wtype):
     if varBC_groups > 1 and (not is_variable_B or not is_variable_C):
         pytest.skip()  # This config is not applicable
     device = 'cuda'
@@ -420,18 +425,27 @@
     # set seed
     torch.random.manual_seed(0)
     seqlens = []
-    nsplits = 3
+    batch_size = 4
     if seqlen < 10:
-        nsplits = 0
+        batch_size = 1
+    padding = 3 if with_padding else 0
+    padded_batch_size = batch_size + padding
+    if with_padding and seqlen < padded_batch_size:
+        pytest.skip()
+    nsplits = padded_batch_size - 1
     eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values
     seqlens.append(
         torch.diff(
             torch.cat(
                 [torch.tensor([-1]), eos_pos,
                  torch.tensor([seqlen - 1])])).tolist())
     assert sum(seqlens[-1]) == seqlen
     assert all(s > 0 for s in seqlens[-1])

+    total_entries = batch_size * 10
     cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32)
     cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum],
                           dim=0).cuda()
@@ -462,22 +476,33 @@ def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups,
     delta_ref = delta.clone()
     out = None
     out_ref = None
-    prev_state_shape = (cumsum.shape[0] - 1, u.shape[0], int(A.shape[1]))
+
+    prev_state_shape = (total_entries, u.shape[0], int(A.shape[1]))
     prev_state = torch.randn(prev_state_shape,
                              device=u.device,
                              dtype=itype,
                              requires_grad=False)
     prev_state_ref = prev_state.clone()
-    cache_indices = torch.randperm(cumsum.shape[0] - 1,
-                                   dtype=torch.int32,
-                                   device=u.device)
+    state_indices = torch.randperm(total_entries,
+                                   dtype=torch.int32,
+                                   device=u.device)[:batch_size]
+    unused_states_bool = torch.ones(total_entries,
+                                    dtype=torch.bool,
+                                    device=device)
+    unused_states_bool[state_indices] = False
+    padded_state_indices = torch.concat([
+        state_indices,
+        torch.as_tensor(
+            [PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
+    ],
+                                        dim=-1)

     has_initial_state = torch.randint(0,
                                       2, (cumsum.shape[0] - 1, ),
                                       dtype=torch.bool,
                                       device=u.device)
     out = selective_scan_fn(u, prev_state, delta, A, B, C, D, z, delta_bias,
-                            delta_softplus, cumsum, cache_indices,
+                            delta_softplus, cumsum, padded_state_indices,
                             has_initial_state)
     outs_ref = []
     splits = [
@@ -486,6 +511,8 @@ def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups,
     ]
     for i in range(len(seqlens[0])):
         u_s, delta_s, B_s, C_s, z_s = [v[i].unsqueeze(0) for v in splits]
+        if padded_state_indices[i] == PAD_SLOT_ID:
+            continue
         out_ref_s, _ = selective_scan_ref(
             u_s,
             delta_s,
@@ -497,21 +524,22 @@ def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups,
             delta_bias=delta_bias,
             delta_softplus=delta_softplus,
             return_last_state=return_last_state,
-            prev_state=prev_state_ref[cache_indices[i]].unsqueeze(0)
+            prev_state=prev_state_ref[padded_state_indices[i]].unsqueeze(0)
             if has_initial_state[i] else None,
-            final_state_out=prev_state_ref[cache_indices[i]].unsqueeze(0))
+            final_state_out=prev_state_ref[padded_state_indices[i]].unsqueeze(
+                0))
         outs_ref.append(out_ref_s)
-    out_ref = torch.cat(outs_ref, dim=-1) if len(outs_ref) > 1 else outs_ref[0]
+    out_ref = torch.cat(outs_ref, dim=-1)[0]

-    print("Output diff max", (out - out_ref[0]).max())
-    print("Output diff mean", (out - out_ref[0]).mean())
+    unpadded_out = out[:, :out_ref[0].shape[-1]]
+    print("Output diff max", (unpadded_out - out_ref).max())
+    print("Output diff mean", (unpadded_out - out_ref).mean())
     print("Output state diff max", (prev_state - prev_state_ref).max())
     print("Output state diff mean", (prev_state - prev_state_ref).mean())
     assert torch.allclose(prev_state, prev_state_ref, rtol=rtol, atol=atol)
-    assert torch.allclose(out, out_ref[0], rtol=rtol, atol=atol)
+    assert torch.allclose(unpadded_out, out_ref, rtol=rtol, atol=atol)

     selective_scan_opcheck_fn(u, delta, A, B, C, D, z, delta_bias,
-                              delta_softplus, cumsum, cache_indices,
+                              delta_softplus, cumsum, padded_state_indices,
                               has_initial_state, prev_state)

@@ -520,7 +548,10 @@ def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups,
 @pytest.mark.parametrize("has_z", [True])
 @pytest.mark.parametrize("dstate", [16, 32, 64])
 @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
-def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype):
+# tests correctness in case subset of the sequences are padded
+@pytest.mark.parametrize("with_padding", [True, False])
+def test_selective_state_update_with_batch_indices(with_padding, dim, dstate,
+                                                   has_z, itype):
     device = "cuda"
     rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
     if itype == torch.bfloat16:
@@ -530,21 +561,32 @@ def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype):
     # set seed
     torch.random.manual_seed(0)
     batch_size = 3
+    padding = 5 if with_padding else 0
+    padded_batch_size = batch_size + padding
     total_entries = 10 * batch_size
     state = torch.randn(total_entries, dim, dstate, dtype=itype, device=device)
     state_indices = torch.randperm(total_entries)[:batch_size].to(
         dtype=torch.int32, device=device)
-
-    x = torch.randn(batch_size, dim, device=device, dtype=itype)
-    dt = torch.randn(batch_size, dim, device=device, dtype=itype)
+    unused_states_bool = torch.ones(total_entries,
+                                    dtype=torch.bool,
+                                    device=device)
+    unused_states_bool[state_indices] = False
+    padded_state_indices = torch.concat([
+        state_indices,
+        torch.as_tensor(
+            [PAD_SLOT_ID] * padding, dtype=torch.int32, device=device)
+    ],
+                                        dim=0)
+    x = torch.randn(padded_batch_size, dim, device=device, dtype=itype)
+    dt = torch.randn(padded_batch_size, dim, device=device, dtype=itype)
     dt_bias = torch.rand(dim, device=device) - 4.0
     A = -torch.rand(dim, dstate, device=device) - 1.0
-    B = torch.randn(batch_size, dstate, device=device)
-    C = torch.randn(batch_size, dstate, device=device)
+    B = torch.randn(padded_batch_size, dstate, device=device)
+    C = torch.randn(padded_batch_size, dstate, device=device)
     D = torch.randn(dim, device=device)
     z = torch.randn_like(x) if has_z else None
-    state_ref = state[state_indices, :].detach().clone()
+    state_ref = state[state_indices, :].clone()
+    state_before = state.clone()
     out = selective_state_update(state,
                                  x,
                                  dt,
@@ -555,15 +597,16 @@ def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype):
                                  z=z,
                                  dt_bias=dt_bias,
                                  dt_softplus=True,
-                                 state_batch_indices=state_indices)
+                                 state_batch_indices=padded_state_indices,
+                                 pad_slot_id=PAD_SLOT_ID)
     out_ref = selective_state_update_ref(state_ref,
-                                         x,
-                                         dt,
+                                         x[:batch_size],
+                                         dt[:batch_size],
                                          A,
-                                         B,
-                                         C,
+                                         B[:batch_size],
+                                         C[:batch_size],
                                          D=D,
-                                         z=z,
+                                         z=z[:batch_size],
                                          dt_bias=dt_bias,
                                          dt_softplus=True)

@@ -572,11 +615,21 @@ def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype):
     print("Output state diff max", (state[state_indices, :] - state_ref).max())
     print("Output state diff mean",
           (state[state_indices, :] - state_ref).mean())
+    # test padded entries stay the same
+    if with_padding:
+        assert torch.equal(state_before[unused_states_bool],
+                           state[unused_states_bool])
+        assert torch.equal(x[batch_size + 1:], x[batch_size + 1:])
+        assert torch.equal(dt[batch_size + 1:], dt[batch_size + 1:])
+        assert torch.equal(B[batch_size + 1:], B[batch_size + 1:])
+        assert torch.equal(C[batch_size + 1:], C[batch_size + 1:])
+    # test "real" entries
     assert torch.allclose(state[state_indices, :],
                           state_ref,
                           rtol=rtol,
                           atol=atol)
-    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
+    assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol)

 @pytest.mark.parametrize("itype",
@@ -645,7 +698,8 @@ def test_selective_state_update_with_heads_with_batch_indices(
                                  z=z,
                                  dt_bias=dt_bias,
                                  dt_softplus=True,
-                                 state_batch_indices=state_indices)
+                                 state_batch_indices=state_indices,
+                                 pad_slot_id=PAD_SLOT_ID)
     out_ref = selective_state_update_ref(state_ref,
                                          x,
                                          dt,

View File

@@ -1,5 +1,6 @@
 import pytest

+from tests.utils import multi_gpu_test
 from vllm.sampling_params import SamplingParams
 from vllm.worker.model_runner import _get_graph_batch_size

@@ -270,6 +271,30 @@
             "could be related to finished_requests_ids")

+@multi_gpu_test(num_gpus=2)
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["float"])
+@pytest.mark.parametrize("max_tokens", [64])
+def test_jamba_distributed_produces_identical_generation(
+        vllm_runner, model: str, dtype: str, max_tokens: int,
+        example_prompts) -> None:
+
+    with vllm_runner(model, dtype=dtype, tensor_parallel_size=2) as vllm_model:
+        vllm_outputs_tp_2 = vllm_model.generate_greedy(example_prompts,
+                                                       max_tokens)
+
+    with vllm_runner(model, dtype=dtype, tensor_parallel_size=1) as vllm_model:
+        vllm_outputs_tp_1 = vllm_model.generate_greedy(example_prompts,
+                                                       max_tokens)
+
+    check_outputs_equal(
+        outputs_0_lst=vllm_outputs_tp_1,
+        outputs_1_lst=vllm_outputs_tp_2,
+        name_0="vllm_tp_1",
+        name_1="vllm_tp_2",
+    )
+
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["float"])
 def test_model_print(

View File

@@ -464,16 +464,18 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
                                 cu_seq_len: Optional[torch.Tensor],
                                 cache_indices: Optional[torch.Tensor],
                                 has_initial_state: Optional[torch.Tensor],
-                                silu_activation: bool) -> torch.Tensor:
-        return torch.empty_like(x)
+                                silu_activation: bool, pad_slot_id: int):
+        return None

     @register_fake("_C::causal_conv1d_update")
-    def causal_conv1d_update_fake(
-            x: torch.Tensor, conv_state: torch.Tensor, weight: torch.Tensor,
-            bias_: Optional[torch.Tensor], silu_activation: bool,
-            cache_seqlens: Optional[torch.Tensor],
-            conv_state_indices: Optional[torch.Tensor]) -> torch.Tensor:
-        return torch.empty_like(x)
+    def causal_conv1d_update_fake(x: torch.Tensor, conv_state: torch.Tensor,
+                                  weight: torch.Tensor,
+                                  bias_: Optional[torch.Tensor],
+                                  silu_activation: bool,
+                                  cache_seqlens: Optional[torch.Tensor],
+                                  conv_state_indices: Optional[torch.Tensor],
+                                  pad_slot_id: int) -> None:
+        return None

     @register_fake("_C::selective_scan_fwd")
     def selective_scan_fwd_fake(u: torch.Tensor, delta: torch.Tensor,
@@ -485,7 +487,8 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
                                 cu_seq_len: Optional[torch.Tensor],
                                 cache_indices: Optional[torch.Tensor],
                                 has_initial_state: Optional[torch.Tensor],
-                                ssm_states: Optional[torch.Tensor]) -> None:
+                                ssm_states: Optional[torch.Tensor],
+                                pad_slot_id: int) -> None:
         return None

@@ -800,33 +803,37 @@ def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor,
                       query_start_loc: Optional[torch.Tensor],
                       cache_indices: Optional[torch.Tensor],
                       has_initial_state: Optional[torch.Tensor],
-                      silu_activation: bool) -> torch.Tensor:
-    return torch.ops._C.causal_conv1d_fwd(x, weight, bias_, conv_states,
-                                          query_start_loc, cache_indices,
-                                          has_initial_state, silu_activation)
+                      silu_activation: bool, pad_slot_id: int):
+    torch.ops._C.causal_conv1d_fwd(x, weight, bias_, conv_states,
+                                   query_start_loc, cache_indices,
+                                   has_initial_state, silu_activation,
+                                   pad_slot_id)

-def causal_conv1d_update(
-        x: torch.Tensor, conv_state: torch.Tensor, weight: torch.Tensor,
-        bias_: Optional[torch.Tensor], silu_activation: bool,
-        cache_seqlens: Optional[torch.Tensor],
-        conv_state_indices: Optional[torch.Tensor]) -> torch.Tensor:
-    return torch.ops._C.causal_conv1d_update(x, conv_state, weight, bias_,
-                                             silu_activation, cache_seqlens,
-                                             conv_state_indices)
+def causal_conv1d_update(x: torch.Tensor, conv_state: torch.Tensor,
+                         weight: torch.Tensor, bias_: Optional[torch.Tensor],
+                         silu_activation: bool,
+                         cache_seqlens: Optional[torch.Tensor],
+                         conv_state_indices: Optional[torch.Tensor],
+                         pad_slot_id: int):
+    torch.ops._C.causal_conv1d_update(x, conv_state, weight, bias_,
+                                      silu_activation, cache_seqlens,
+                                      conv_state_indices, pad_slot_id)

-def selective_scan_fwd(
-        u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, B: torch.Tensor,
-        C: torch.Tensor, D_: Optional[torch.Tensor],
-        z_: Optional[torch.Tensor], delta_bias_: Optional[torch.Tensor],
-        delta_softplus: bool, query_start_loc: Optional[torch.Tensor],
-        cache_indices: Optional[torch.Tensor],
-        has_initial_state: Optional[torch.Tensor], ssm_states: torch.Tensor):
+def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
+                       B: torch.Tensor, C: torch.Tensor,
+                       D_: Optional[torch.Tensor], z_: Optional[torch.Tensor],
+                       delta_bias_: Optional[torch.Tensor],
+                       delta_softplus: bool,
+                       query_start_loc: Optional[torch.Tensor],
+                       cache_indices: Optional[torch.Tensor],
+                       has_initial_state: Optional[torch.Tensor],
+                       ssm_states: torch.Tensor, pad_slot_id: int):
     torch.ops._C.selective_scan_fwd(u, delta, A, B, C, D_, z_, delta_bias_,
                                     delta_softplus, query_start_loc,
                                     cache_indices, has_initial_state,
-                                    ssm_states)
+                                    ssm_states, pad_slot_id)

 # moe
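Because the underlying ops now mutate their arguments, these wrappers (and the fake registrations above) return None; code still written as out = ops.causal_conv1d_fwd(...) would silently receive None. The result lives in the input buffer, e.g. for the varlen forward (a sketch; shapes illustrative, requires a CUDA build):

import torch

from vllm import _custom_ops as ops
from vllm.attention.backends.utils import PAD_SLOT_ID

dim, seqlen, width = 64, 8, 4
x = torch.randn(dim, seqlen, device="cuda")  # varlen layout (dim, total_length)
weight = torch.randn(dim, width, device="cuda")
conv_states = torch.zeros(4, dim, width - 1, device="cuda")
query_start_loc = torch.tensor([0, seqlen], dtype=torch.int32, device="cuda")
cache_indices = torch.tensor([2], dtype=torch.int32, device="cuda")
has_initial_state = torch.tensor([False], device="cuda")

ret = ops.causal_conv1d_fwd(x, weight, None, conv_states, query_start_loc,
                            cache_indices, has_initial_state, True,
                            PAD_SLOT_ID)
assert ret is None
out = x  # the convolution output was written over x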

View File

@@ -6,10 +6,10 @@ from typing import Optional
 import torch

 from vllm import _custom_ops as ops
+from vllm.attention.backends.utils import PAD_SLOT_ID

-def causal_conv1d_fn(
-    x: torch.Tensor,
+def causal_conv1d_fn(x: torch.Tensor,
                      weight: torch.Tensor,
                      bias: Optional[torch.Tensor] = None,
                      query_start_loc: Optional[torch.Tensor] = None,
@@ -17,7 +17,7 @@ def causal_conv1d_fn(
                      has_initial_state: Optional[torch.Tensor] = None,
                      conv_states: Optional[torch.Tensor] = None,
                      activation: Optional[str] = "silu",
-                     ):
+                     pad_slot_id: int = PAD_SLOT_ID):
     """
     x: (batch, dim, seqlen) or (dim,cu_seq_len) for varlen
         sequences are concatenated from left to right for varlen
@@ -37,6 +37,13 @@ def causal_conv1d_fn(
     conv_states: (...,dim,width - 1) itype
         updated inplace if provided
     activation: either None or "silu" or "swish"
+    pad_slot_id: int
+        if cache_indices is passed, lets the kernel identify padded
+        entries that will not be processed,
+        for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
+        in this case, the kernel will not process entries at
+        indices 0 and 3

     out: (batch, dim, seqlen)
     """
@@ -46,10 +53,10 @@ def causal_conv1d_fn(
         x = x.contiguous()
     bias = bias.contiguous() if bias is not None else None

-    out = ops.causal_conv1d_fwd(x, weight, bias, conv_states, query_start_loc,
-                                cache_indices, has_initial_state, activation
-                                in ["silu", "swish"])
-    return out
+    ops.causal_conv1d_fwd(x, weight, bias, conv_states, query_start_loc,
+                          cache_indices, has_initial_state, activation
+                          in ["silu", "swish"], pad_slot_id)
+    return x

 def causal_conv1d_update(x: torch.Tensor,
@@ -58,7 +65,8 @@ def causal_conv1d_update(x: torch.Tensor,
                          bias: Optional[torch.Tensor] = None,
                          activation: Optional[str] = None,
                          cache_seqlens: Optional[torch.Tensor] = None,
-                         conv_state_indices: Optional[torch.Tensor] = None):
+                         conv_state_indices: Optional[torch.Tensor] = None,
+                         pad_slot_id: int = PAD_SLOT_ID):
     """
     x: (batch, dim) or (batch, dim, seqlen)
     conv_state: (batch, dim, state_len), where state_len >= width - 1
@@ -73,7 +81,12 @@ def causal_conv1d_update(x: torch.Tensor,
         If not None, the conv_state is a larger tensor along the batch dim,
         and we are selecting the batch coords specified by conv_state_indices.
         Useful for a continuous batching scenario.
+    pad_slot_id: int
+        if cache_indices is passed, lets the kernel identify padded
+        entries that will not be processed,
+        for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id]
+        in this case, the kernel will not process entries at
+        indices 0 and 3
     out: (batch, dim) or (batch, dim, seqlen)
     """
     if activation not in [None, "silu", "swish"]:
@@ -82,8 +95,8 @@ def causal_conv1d_update(x: torch.Tensor,
     unsqueeze = x.dim() == 2
     if unsqueeze:
         x = x.unsqueeze(-1)
-    out = ops.causal_conv1d_update(x, conv_state, weight, bias, activation_val,
-                                   cache_seqlens, conv_state_indices)
+    ops.causal_conv1d_update(x, conv_state, weight, bias, activation_val,
+                             cache_seqlens, conv_state_indices, pad_slot_id)
     if unsqueeze:
-        out = out.squeeze(-1)
-    return out
+        x = x.squeeze(-1)
+    return x
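Putting the Python surface together, a single decode step under continuous batching now looks roughly like this (a sketch; slot numbers and sizes are made up):

import torch

from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
    causal_conv1d_update)

dim, width, cache_size = 64, 4, 8
conv_state = torch.zeros(cache_size, dim, width - 1, device="cuda")
# Two live sequences in cache slots 5 and 2; the batch is padded to four.
slots = torch.tensor([5, 2, PAD_SLOT_ID, PAD_SLOT_ID],
                     dtype=torch.int32,
                     device="cuda")
x = torch.randn(4, dim, device="cuda")
weight = torch.randn(dim, width, device="cuda")

out = causal_conv1d_update(x,
                           conv_state,
                           weight,
                           bias=None,
                           activation="silu",
                           conv_state_indices=slots,
                           pad_slot_id=PAD_SLOT_ID)
# out shares x's storage (the kernel wrote in place); rows 2 and 3 of x and
# every cache slot except 5 and 2 were never touched.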

View File

@ -1,14 +1,13 @@
# Copyright (c) 2024, Tri Dao, Albert Gu. # Copyright (c) 2024, Tri Dao, Albert Gu.
# Adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/selective_state_update.py # Adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/selective_state_update.py
from typing import Tuple
import torch import torch
import triton import triton
import triton.language as tl import triton.language as tl
from packaging import version from packaging import version
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.attention.backends.utils import PAD_SLOT_ID
TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0") TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0")
@ -50,6 +49,7 @@ def _selective_scan_update_kernel(
z_ptr, z_ptr,
out_ptr, out_ptr,
state_batch_indices_ptr, state_batch_indices_ptr,
pad_slot_id,
# Matrix dimensions # Matrix dimensions
batch, batch,
nheads, nheads,
@ -143,10 +143,11 @@ def _selective_scan_update_kernel(
if HAS_Z: if HAS_Z:
z_ptrs = z_ptr + offs_m * stride_z_dim z_ptrs = z_ptr + offs_m * stride_z_dim
out_ptrs = out_ptr + offs_m * stride_out_dim out_ptrs = out_ptr + offs_m * stride_out_dim
mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
if HAS_STATE_BATCH_INDICES:
mask &= (state_batch_idx != pad_slot_id)
state = tl.load(state_ptrs, mask=mask, other=0.0)
state = tl.load(state_ptrs,
mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate),
other=0.0)
x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
if not TIE_HDIM: if not TIE_HDIM:
dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
@ -177,9 +178,11 @@ def _selective_scan_update_kernel(
dB = B[None, :] * dt[:, None] if not TIE_HDIM else B * dt dB = B[None, :] * dt[:, None] if not TIE_HDIM else B * dt
state = state * dA + dB * x[:, None] state = state * dA + dB * x[:, None]
tl.store(state_ptrs,
state, mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate)) if HAS_STATE_BATCH_INDICES:
mask &= (state_batch_idx != pad_slot_id)
tl.store(state_ptrs, state, mask=mask)
out = tl.sum(state * C[None, :], axis=1) out = tl.sum(state * C[None, :], axis=1)
if HAS_D: if HAS_D:
out += x * D out += x * D
@ -198,7 +201,8 @@ def selective_state_update(state,
z=None, z=None,
dt_bias=None, dt_bias=None,
dt_softplus=False, dt_softplus=False,
state_batch_indices=None): state_batch_indices=None,
pad_slot_id=PAD_SLOT_ID):
""" """
Argument: Argument:
state: (batch, dim, dstate) or (batch, nheads, dim, dstate) state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
@ -210,6 +214,12 @@ def selective_state_update(state,
D: (dim,) or (nheads, dim) D: (dim,) or (nheads, dim)
z: (batch, dim) or (batch, nheads, dim) z: (batch, dim) or (batch, nheads, dim)
dt_bias: (dim,) or (nheads, dim) dt_bias: (dim,) or (nheads, dim)
pad_slot_id: int
if cache_indices is passed, lets the kernel identify padded
entries that will not be processed,
for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
in this case, the kernel will not process entries at
indices 0 and 3
Return: Return:
out: (batch, dim) or (batch, nheads, dim) out: (batch, dim) or (batch, nheads, dim)
""" """
@ -276,6 +286,7 @@ def selective_state_update(state,
z, z,
out, out,
state_batch_indices, state_batch_indices,
pad_slot_id,
batch, batch,
nheads, nheads,
dim, dim,
@ -319,8 +330,7 @@ def selective_state_update(state,
return out return out
def selective_scan_fn( def selective_scan_fn(u,
u,
ssm_states, ssm_states,
delta, delta,
A, A,
@ -332,9 +342,13 @@ def selective_scan_fn(
delta_softplus=False, delta_softplus=False,
query_start_loc=None, query_start_loc=None,
cache_indices=None, cache_indices=None,
has_initial_state=None) -> Tuple[torch.Tensor, torch.Tensor]: has_initial_state=None,
pad_slot_id=PAD_SLOT_ID) -> torch.Tensor:
""" """
u: (dim, total_length) for varlen or (batch, dim, seqlen) u: (dim, total_length) for varlen or (batch, dim, seqlen)
modified in place.
ssm_states: (batch, dim, dstate) or (batch, nheads, dim, dstate)
modified in place.
delta: (dim, total_length) for varlen or (batch, dim, seqlen) delta: (dim, total_length) for varlen or (batch, dim, seqlen)
A: (dim, dstate) A: (dim, dstate)
B: (ngroups, dstate, total_length) for varlen or B: (ngroups, dstate, total_length) for varlen or
@ -357,12 +371,14 @@ def selective_scan_fn(
indicate if the ssm_state at the corresponding index should be indicate if the ssm_state at the corresponding index should be
used as initial state. Not providing argument assumes used as initial state. Not providing argument assumes
there's no initial state there's no initial state
pad_slot_id: int
if cache_indices is passed, lets the kernel identify padded entries
that will not be processed,
for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
in this case, the kernel will not process entries at indices 0 and 3
returns returns
output: (dim, total_length) for varlen or (batch, dim, seqlen) output: (dim, total_length) for varlen or (batch, dim, seqlen)
supports inplace replacement supports inplace replacement
last_state has shape (batch, dim, dstate).
supports inplace replacement if ssm_state was provided
""" """
if u.stride(-1) != 1: if u.stride(-1) != 1:
u = u.contiguous() u = u.contiguous()
@ -387,7 +403,7 @@ def selective_scan_fn(
ops.selective_scan_fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus, ops.selective_scan_fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus,
query_start_loc, cache_indices, has_initial_state, query_start_loc, cache_indices, has_initial_state,
ssm_states) ssm_states, pad_slot_id)
if z is None: if z is None:
return delta # output written inplace to delta return delta # output written inplace to delta
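A hedged usage sketch of the varlen entry point with a padded batch slot; all shapes and slot numbers below are hypothetical, only the keyword names come from this diff:

import torch
from vllm.attention.backends.utils import PAD_SLOT_ID

# Two live sequences (lengths 5 and 3) own cache slots 3 and 7; the
# third batch entry is a pad the kernel must skip.
query_start_loc = torch.tensor([0, 5, 8, 8], dtype=torch.int32,
                               device="cuda")
cache_indices = torch.tensor([3, 7, PAD_SLOT_ID], dtype=torch.int32,
                             device="cuda")
has_initial_state = torch.tensor([True, False, False], device="cuda")
# out = selective_scan_fn(u, ssm_states, delta, A, B, C, D, z,
#                         delta_bias, delta_softplus=True,
#                         query_start_loc=query_start_loc,
#                         cache_indices=cache_indices,
#                         has_initial_state=has_initial_state)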

View File

@ -1,6 +1,5 @@
# coding=utf-8 # coding=utf-8
"""Inference-only Jamba model.""" """Inference-only Jamba model."""
from dataclasses import dataclass
from typing import Iterable, List, Optional, Tuple from typing import Iterable, List, Optional, Tuple
import torch import torch
@ -29,7 +28,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
composed_weight_loader, default_weight_loader, sharded_weight_loader) composed_weight_loader, default_weight_loader, sharded_weight_loader)
from vllm.model_executor.models.mamba_cache import MambaCacheManager from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
@ -41,13 +41,6 @@ from .interfaces import HasInnerState, SupportsLoRA
KVCache = Tuple[torch.Tensor, torch.Tensor] KVCache = Tuple[torch.Tensor, torch.Tensor]
@dataclass
class MambaCacheParams:
is_prompt: bool = False
conv_state: torch.Tensor = torch.Tensor()
ssm_state: torch.Tensor = torch.Tensor()
# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer # Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
class JambaMambaMixer(nn.Module): class JambaMambaMixer(nn.Module):
""" """
@ -60,10 +53,9 @@ class JambaMambaMixer(nn.Module):
**selective** state spaces) **selective** state spaces)
""" """
def __init__(self, config: JambaConfig, layer_idx): def __init__(self, config: JambaConfig):
super().__init__() super().__init__()
self.config = config self.config = config
self.layer_idx = layer_idx
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
self.ssm_state_size = config.mamba_d_state self.ssm_state_size = config.mamba_d_state
self.conv_kernel_size = config.mamba_d_conv self.conv_kernel_size = config.mamba_d_conv
@ -129,8 +121,8 @@ class JambaMambaMixer(nn.Module):
eps=config.rms_norm_eps) eps=config.rms_norm_eps)
def forward(self, hidden_states: torch.Tensor, def forward(self, hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata, conv_state: torch.Tensor, attn_metadata: AttentionMetadata,
ssm_state: torch.Tensor): mamba_cache_params: MambaCacheParams):
# 1. Gated MLP's linear projection # 1. Gated MLP's linear projection
projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1) projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
@ -153,17 +145,18 @@ class JambaMambaMixer(nn.Module):
conv_weights, conv_weights,
self.conv1d.bias, self.conv1d.bias,
activation=self.activation, activation=self.activation,
conv_states=conv_state, conv_states=mamba_cache_params.conv_state,
has_initial_state=attn_metadata.context_lens_tensor > 0, has_initial_state=attn_metadata.context_lens_tensor > 0,
cache_indices=mamba_cache_params.state_indices_tensor,
query_start_loc=attn_metadata.query_start_loc) query_start_loc=attn_metadata.query_start_loc)
else: else:
hidden_states = causal_conv1d_update( hidden_states = causal_conv1d_update(
hidden_states.transpose(0, 1), hidden_states.transpose(0, 1),
conv_state, mamba_cache_params.conv_state,
conv_weights, conv_weights,
self.conv1d.bias, self.conv1d.bias,
self.activation, self.activation,
) conv_state_indices=mamba_cache_params.state_indices_tensor)
hidden_states = hidden_states.transpose(0, 1) hidden_states = hidden_states.transpose(0, 1)
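The decode branch above now routes every batch row through conv_state_indices; a pure-PyTorch sketch of the assumed cache-update semantics (an illustrative helper, not the actual kernel):

import torch

def conv_state_update_ref(conv_state, new_col, conv_state_indices,
                          pad_slot_id=-1):
    # conv_state: (cache_size, dim, width); new_col: (batch, dim)
    # Each row updates the cache line its index points at; pads are skipped.
    for b, slot in enumerate(conv_state_indices.tolist()):
        if slot == pad_slot_id:
            continue
        conv_state[slot] = torch.roll(conv_state[slot], shifts=-1, dims=-1)
        conv_state[slot, :, -1] = new_col[b]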
# 3. State Space Model sequence transformation # 3. State Space Model sequence transformation
@ -188,7 +181,7 @@ class JambaMambaMixer(nn.Module):
and attn_metadata.context_lens_tensor is not None: and attn_metadata.context_lens_tensor is not None:
scan_outputs = selective_scan_fn( scan_outputs = selective_scan_fn(
hidden_states, hidden_states,
ssm_state, mamba_cache_params.ssm_state,
discrete_time_step, discrete_time_step,
self.A, self.A,
B.transpose(-2, -1), B.transpose(-2, -1),
@ -197,11 +190,12 @@ class JambaMambaMixer(nn.Module):
gate, gate,
time_proj_bias, time_proj_bias,
delta_softplus=True, delta_softplus=True,
cache_indices=mamba_cache_params.state_indices_tensor,
has_initial_state=attn_metadata.context_lens_tensor > 0, has_initial_state=attn_metadata.context_lens_tensor > 0,
query_start_loc=attn_metadata.query_start_loc) query_start_loc=attn_metadata.query_start_loc)
else: else:
scan_outputs = selective_state_update( scan_outputs = selective_state_update(
ssm_state, mamba_cache_params.ssm_state,
hidden_states.transpose(0, 1), hidden_states.transpose(0, 1),
discrete_time_step.transpose(0, 1), discrete_time_step.transpose(0, 1),
self.A, self.A,
@ -211,7 +205,7 @@ class JambaMambaMixer(nn.Module):
gate.transpose(0, 1), gate.transpose(0, 1),
time_proj_bias, time_proj_bias,
dt_softplus=True, dt_softplus=True,
) state_batch_indices=mamba_cache_params.state_indices_tensor)
scan_outputs = scan_outputs.transpose(0, 1) scan_outputs = scan_outputs.transpose(0, 1)
# 4. Final linear projection # 4. Final linear projection
@ -292,7 +286,7 @@ class JambaMambaDecoderLayer(nn.Module):
super().__init__() super().__init__()
self.layer_idx = layer_idx self.layer_idx = layer_idx
self.config = config self.config = config
self.mamba = JambaMambaMixer(config, layer_idx) self.mamba = JambaMambaMixer(config)
num_experts = config.layers_num_experts[layer_idx] num_experts = config.layers_num_experts[layer_idx]
ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP
@ -307,8 +301,7 @@ class JambaMambaDecoderLayer(nn.Module):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
conv_state: torch.Tensor, mamba_cache_params: MambaCacheParams,
ssm_state: torch.Tensor,
**kwargs, **kwargs,
): ):
if residual is None: if residual is None:
@ -318,8 +311,8 @@ class JambaMambaDecoderLayer(nn.Module):
hidden_states, residual = self.input_layernorm( hidden_states, residual = self.input_layernorm(
hidden_states, residual) hidden_states, residual)
hidden_states = self.mamba(hidden_states, attn_metadata, conv_state, hidden_states = self.mamba(hidden_states, attn_metadata,
ssm_state) mamba_cache_params)
# Fully Connected # Fully Connected
hidden_states, residual = self.pre_ff_layernorm( hidden_states, residual = self.pre_ff_layernorm(
hidden_states, residual) hidden_states, residual)
@ -476,17 +469,14 @@ class JambaModel(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
conv_state: torch.Tensor, mamba_cache_params: MambaCacheParams,
ssm_state: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids) hidden_states = self.embed_tokens(input_ids)
residual = None residual = None
for i in range(len(self.layers)): for i in range(len(self.layers)):
layer = self.layers[i] layer = self.layers[i]
kv_cache = None kv_cache = None
current_ssm_state = None layer_mamba_cache_params = None
current_conv_state = None
if isinstance(layer, JambaAttentionDecoderLayer): if isinstance(layer, JambaAttentionDecoderLayer):
kv_cache = kv_caches[(i - self.config.attn_layer_offset) // kv_cache = kv_caches[(i - self.config.attn_layer_offset) //
self.config.attn_layer_period] self.config.attn_layer_period]
@ -494,8 +484,8 @@ class JambaModel(nn.Module):
current_state_layer = i - (1 + current_state_layer = i - (1 +
(i - self.config.attn_layer_offset) (i - self.config.attn_layer_offset)
// self.config.attn_layer_period) // self.config.attn_layer_period)
current_ssm_state = ssm_state[current_state_layer] layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
current_conv_state = conv_state[current_state_layer] current_state_layer)
hidden_states, residual = layer( hidden_states, residual = layer(
positions=positions, positions=positions,
@ -503,9 +493,7 @@ class JambaModel(nn.Module):
kv_cache=kv_cache, kv_cache=kv_cache,
attn_metadata=attn_metadata, attn_metadata=attn_metadata,
residual=residual, residual=residual,
conv_state=current_conv_state, mamba_cache_params=layer_mamba_cache_params)
ssm_state=current_ssm_state,
)
hidden_states, _ = self.final_layernorm(hidden_states, residual) hidden_states, _ = self.final_layernorm(hidden_states, residual)
return hidden_states return hidden_states
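The current_state_layer arithmetic above subtracts the number of attention layers encountered so far; with hypothetical config values it works out as:

attn_layer_offset, attn_layer_period = 4, 8  # hypothetical config values
i = 5  # a Mamba layer that sits right after attention layer 4
current_state_layer = i - (1 + (i - attn_layer_offset) // attn_layer_period)
assert current_state_layer == 4  # the 5th Mamba layer -> cache slice 4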
@ -588,13 +576,16 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
self.mamba_cache = MambaCacheManager( self.mamba_cache = MambaCacheManager(
self.lm_head.weight.dtype, num_mamba_layers, max_batch_size, self.lm_head.weight.dtype, num_mamba_layers, max_batch_size,
*self._get_mamba_cache_shape()) *self._get_mamba_cache_shape())
(
mamba_cache_tensors = self.mamba_cache.current_run_tensors( mamba_cache_tensors,
input_ids, attn_metadata, **kwargs) state_indices_tensor,
) = self.mamba_cache.current_run_tensors(input_ids, attn_metadata,
**kwargs)
mamba_cache_params = MambaCacheParams(mamba_cache_tensors[0],
mamba_cache_tensors[1],
state_indices_tensor)
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, mamba_cache_tensors[0], attn_metadata, mamba_cache_params)
mamba_cache_tensors[1])
return hidden_states return hidden_states
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):

View File

@ -27,7 +27,8 @@ from vllm.model_executor.model_loader.weight_utils import (
composed_weight_loader, default_weight_loader, sharded_weight_loader) composed_weight_loader, default_weight_loader, sharded_weight_loader)
from vllm.model_executor.models.interfaces import (HasInnerState, from vllm.model_executor.models.interfaces import (HasInnerState,
IsAttentionFree) IsAttentionFree)
from vllm.model_executor.models.mamba_cache import MambaCacheManager from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
@ -110,8 +111,8 @@ class MambaMixer(nn.Module):
self.activation = config.hidden_act self.activation = config.hidden_act
def forward(self, hidden_states: torch.Tensor, def forward(self, hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata, conv_state: torch.Tensor, attn_metadata: AttentionMetadata,
ssm_state: torch.Tensor): mamba_cache_params: MambaCacheParams):
# 1. Gated MLP's linear projection # 1. Gated MLP's linear projection
projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1) projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
@ -134,17 +135,18 @@ class MambaMixer(nn.Module):
conv_weights, conv_weights,
self.conv1d.bias, self.conv1d.bias,
activation=self.activation, activation=self.activation,
conv_states=conv_state, conv_states=mamba_cache_params.conv_state,
has_initial_state=attn_metadata.context_lens_tensor > 0, has_initial_state=attn_metadata.context_lens_tensor > 0,
cache_indices=mamba_cache_params.state_indices_tensor,
query_start_loc=attn_metadata.query_start_loc) query_start_loc=attn_metadata.query_start_loc)
else: else:
hidden_states = causal_conv1d_update( hidden_states = causal_conv1d_update(
hidden_states.transpose(0, 1), hidden_states.transpose(0, 1),
conv_state, mamba_cache_params.conv_state,
conv_weights, conv_weights,
self.conv1d.bias, self.conv1d.bias,
self.activation, self.activation,
) conv_state_indices=mamba_cache_params.state_indices_tensor)
hidden_states = hidden_states.transpose(0, 1) hidden_states = hidden_states.transpose(0, 1)
# 3. State Space Model sequence transformation # 3. State Space Model sequence transformation
@ -168,7 +170,7 @@ class MambaMixer(nn.Module):
and attn_metadata.context_lens_tensor is not None: and attn_metadata.context_lens_tensor is not None:
scan_outputs = selective_scan_fn( scan_outputs = selective_scan_fn(
hidden_states, hidden_states,
ssm_state, mamba_cache_params.ssm_state,
discrete_time_step, discrete_time_step,
self.A, self.A,
B.transpose(-2, -1), B.transpose(-2, -1),
@ -177,11 +179,12 @@ class MambaMixer(nn.Module):
gate, gate,
time_proj_bias, time_proj_bias,
delta_softplus=True, delta_softplus=True,
cache_indices=mamba_cache_params.state_indices_tensor,
has_initial_state=attn_metadata.context_lens_tensor > 0, has_initial_state=attn_metadata.context_lens_tensor > 0,
query_start_loc=attn_metadata.query_start_loc) query_start_loc=attn_metadata.query_start_loc)
else: else:
scan_outputs = selective_state_update( scan_outputs = selective_state_update(
ssm_state, mamba_cache_params.ssm_state,
hidden_states.transpose(0, 1), hidden_states.transpose(0, 1),
discrete_time_step.transpose(0, 1), discrete_time_step.transpose(0, 1),
self.A, self.A,
@ -191,7 +194,7 @@ class MambaMixer(nn.Module):
gate.transpose(0, 1), gate.transpose(0, 1),
time_proj_bias, time_proj_bias,
dt_softplus=True, dt_softplus=True,
) state_batch_indices=mamba_cache_params.state_indices_tensor)
scan_outputs = scan_outputs.transpose(0, 1) scan_outputs = scan_outputs.transpose(0, 1)
# 4. Final linear projection # 4. Final linear projection
@ -221,8 +224,7 @@ class MambaDecoderLayer(nn.Module):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
conv_state: torch.Tensor, mamba_cache_params: MambaCacheParams,
ssm_state: torch.Tensor,
**kwargs, **kwargs,
): ):
if residual is None: if residual is None:
@ -231,8 +233,8 @@ class MambaDecoderLayer(nn.Module):
else: else:
hidden_states, residual = self.norm(hidden_states, residual) hidden_states, residual = self.norm(hidden_states, residual)
hidden_states = self.mixer(hidden_states, attn_metadata, conv_state, hidden_states = self.mixer(hidden_states, attn_metadata,
ssm_state) mamba_cache_params)
return hidden_states, residual return hidden_states, residual
@ -275,25 +277,20 @@ class MambaModel(nn.Module):
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
conv_state: torch.Tensor, mamba_cache_params: MambaCacheParams,
ssm_state: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.embeddings(input_ids) hidden_states = self.embeddings(input_ids)
residual = None residual = None
for i in range(len(self.layers)): for i in range(len(self.layers)):
layer = self.layers[i] layer = self.layers[i]
current_ssm_state = ssm_state[i]
current_conv_state = conv_state[i]
hidden_states, residual = layer( hidden_states, residual = layer(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
attn_metadata=attn_metadata, attn_metadata=attn_metadata,
residual=residual, residual=residual,
conv_state=current_conv_state, mamba_cache_params=mamba_cache_params.at_layer_idx(i))
ssm_state=current_ssm_state,
)
hidden_states, _ = self.norm_f(hidden_states, residual) hidden_states, _ = self.norm_f(hidden_states, residual)
return hidden_states return hidden_states
@ -347,12 +344,18 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
self.lm_head.weight.dtype, self.config.num_hidden_layers, self.lm_head.weight.dtype, self.config.num_hidden_layers,
max_batch_size, *self._get_mamba_cache_shape()) max_batch_size, *self._get_mamba_cache_shape())
mamba_cache_tensors = self.mamba_cache.current_run_tensors( (
input_ids, attn_metadata, **kwargs) mamba_cache_tensors,
state_indices_tensor,
) = self.mamba_cache.current_run_tensors(input_ids, attn_metadata,
**kwargs)
mamba_cache_params = MambaCacheParams(mamba_cache_tensors[0],
mamba_cache_tensors[1],
state_indices_tensor)
hidden_states = self.backbone(input_ids, positions, attn_metadata, hidden_states = self.backbone(input_ids, positions, attn_metadata,
mamba_cache_tensors[0], mamba_cache_params)
mamba_cache_tensors[1])
return hidden_states return hidden_states

View File

@ -1,8 +1,22 @@
from typing import Dict, List, Optional from dataclasses import dataclass
from typing import Dict, List
import torch import torch
from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.backends.utils import PAD_SLOT_ID
@dataclass
class MambaCacheParams:
conv_state: torch.Tensor = torch.Tensor()
ssm_state: torch.Tensor = torch.Tensor()
state_indices_tensor: torch.Tensor = torch.Tensor()
def at_layer_idx(self, layer_idx):
return MambaCacheParams(self.conv_state[layer_idx],
self.ssm_state[layer_idx],
self.state_indices_tensor)
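A brief usage sketch of the dataclass (shapes are assumptions: both caches carry a leading layer dimension, which at_layer_idx slices away while the index tensor is shared):

import torch

conv_state = torch.zeros(2, 8, 16, 4)   # (layers, max_batch, dim, width)
ssm_state = torch.zeros(2, 8, 16, 16)   # (layers, max_batch, dim, dstate)
indices = torch.tensor([3, 7], dtype=torch.int32)

params = MambaCacheParams(conv_state, ssm_state, indices)
layer0 = params.at_layer_idx(0)
# layer0.conv_state is conv_state[0]; state_indices_tensor is passed
# through unchanged, so every layer resolves the same cache slots.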
class MambaCacheManager: class MambaCacheManager:
@ -24,6 +38,7 @@ class MambaCacheManager:
# Maps between the request id and a dict that maps between the seq_id # Maps between the request id and a dict that maps between the seq_id
# and its index inside the self.mamba_cache # and its index inside the self.mamba_cache
self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {} self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {}
self.free_cache_indices = list(range(max_batch_size))
def current_run_tensors(self, input_ids: torch.Tensor, def current_run_tensors(self, input_ids: torch.Tensor,
attn_metadata: AttentionMetadata, **kwargs): attn_metadata: AttentionMetadata, **kwargs):
@ -36,30 +51,43 @@ class MambaCacheManager:
finished_requests_ids = kwargs["finished_requests_ids"] finished_requests_ids = kwargs["finished_requests_ids"]
self._release_finished_requests(finished_requests_ids) self._release_finished_requests(finished_requests_ids)
mamba_cache_tensors = self._prepare_current_run_mamba_cache( state_indices = self._prepare_current_run_mamba_cache(
request_ids_to_seq_ids, finished_requests_ids) request_ids_to_seq_ids, finished_requests_ids)
state_indices_tensor = torch.as_tensor(state_indices,
dtype=torch.int32,
device="cuda")
mamba_cache_tensors = self.mamba_cache
else: else:
# CUDA graph capturing runs # CUDA graph capturing runs
mamba_cache_tensors = kwargs["seqlen_agnostic_capture_inputs"] (mamba_cache_tensors,
state_indices_tensor) = kwargs["seqlen_agnostic_capture_inputs"]
return mamba_cache_tensors return (mamba_cache_tensors, state_indices_tensor)
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
""" """
Copy the relevant Mamba cache into the CUDA graph input buffer Copy the relevant state_indices into the CUDA graph input buffer
that was provided during the capture runs
(JambaForCausalLM.mamba_gc_cache_buffer).
""" """
assert all( assert all(
key in kwargs key in kwargs
for key in ["request_ids_to_seq_ids", "finished_requests_ids"]) for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
finished_requests_ids = kwargs["finished_requests_ids"] finished_requests_ids = kwargs["finished_requests_ids"]
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
assert "seqlen_agnostic_capture_inputs" in input_buffers
_, input_state_indices_buffer = input_buffers[
"seqlen_agnostic_capture_inputs"]
self._release_finished_requests(finished_requests_ids) self._release_finished_requests(finished_requests_ids)
self._prepare_current_run_mamba_cache(request_ids_to_seq_ids, state_indices = self._prepare_current_run_mamba_cache(
finished_requests_ids) request_ids_to_seq_ids, finished_requests_ids)
cuda_graph_pad_len = input_state_indices_buffer.shape[0] - len(
state_indices)
state_indices.extend([PAD_SLOT_ID] * cuda_graph_pad_len)
input_state_indices_buffer.copy_(
torch.as_tensor(state_indices, dtype=torch.int32, device="cuda"))
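A toy trace of the padding step above, with hypothetical sizes:

from vllm.attention.backends.utils import PAD_SLOT_ID

# Graph captured for batch_size = 8, but only 5 sequences are live.
state_indices = [3, 7, 0, 4, 1]
cuda_graph_pad_len = 8 - len(state_indices)
state_indices.extend([PAD_SLOT_ID] * cuda_graph_pad_len)
# -> [3, 7, 0, 4, 1, PAD_SLOT_ID, PAD_SLOT_ID, PAD_SLOT_ID]
# Padded slots take the kernels' early-return path, so a graph captured
# at a larger batch size replays safely.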
def get_seqlen_agnostic_capture_inputs(self, batch_size: int): def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
""" """
@ -67,13 +95,10 @@ class MambaCacheManager:
The buffer is used to maintain the Mamba Cache during the CUDA graph The buffer is used to maintain the Mamba Cache during the CUDA graph
replay runs. replay runs.
""" """
return tuple(buffer[:, :batch_size] for buffer in self.mamba_cache) state_indices_tensor = torch.as_tensor([PAD_SLOT_ID] * batch_size,
dtype=torch.int32,
def _swap_mamba_cache(self, from_index: int, to_index: int): device="cuda")
assert len(self.mamba_cache) > 0 return (self.mamba_cache, state_indices_tensor)
for cache_t in self.mamba_cache:
cache_t[:, [to_index,from_index]] = \
cache_t[:, [from_index,to_index]]
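Putting the two CUDA-graph hooks together, a condensed sketch of capture versus replay (the manager instance and request ids are hypothetical; the method names come from this diff):

# Capture: every slot is PAD_SLOT_ID, so the graphed kernels touch no state.
buffers = manager.get_seqlen_agnostic_capture_inputs(batch_size=8)
# Replay: live slot indices (padded back up to 8) are copied into the
# same buffer before the graph is launched.
manager.copy_inputs_before_cuda_graphs(
    input_buffers={"seqlen_agnostic_capture_inputs": buffers},
    request_ids_to_seq_ids={"req-a": [0]},
    finished_requests_ids=[])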
def _copy_mamba_cache(self, from_index: int, to_index: int): def _copy_mamba_cache(self, from_index: int, to_index: int):
assert len(self.mamba_cache) > 0 assert len(self.mamba_cache) > 0
@ -81,142 +106,53 @@ class MambaCacheManager:
cache_t[:, to_index].copy_(cache_t[:, from_index], cache_t[:, to_index].copy_(cache_t[:, from_index],
non_blocking=True) non_blocking=True)
def _move_out_if_already_occupied(self, index: int, def _assign_seq_id_to_cache_index(self, cur_rid: str, seq_id: int,
all_occupied_indices: List[int]): finished_requests_ids) -> int:
if index in all_occupied_indices:
first_free_index = self._first_free_index_in_mamba_cache()
# In case occupied, move the occupied to a new empty block
self._move_cache_index_and_mappings(from_index=index,
to_index=first_free_index)
def _assign_seq_id_to_mamba_cache_in_specific_dest(self, cur_rid: str,
seq_id: int,
destination_index: int):
""" """
Assign (req_id,seq_id) pair to a `destination_index` index, if Assign (req_id,seq_id) pair to an index taken from the free cache
already occupied, move the occupying index to a free index. indices; requests in `finished_requests_ids` get PAD_SLOT_ID instead.
""" """
all_occupied_indices = self._get_all_occupied_indices() if cur_rid in finished_requests_ids:
if cur_rid not in self.mamba_cache_indices_mapping: # set as pad, do not allocate destination index
self._move_out_if_already_occupied( return PAD_SLOT_ID
index=destination_index, elif cur_rid not in self.mamba_cache_indices_mapping:
all_occupied_indices=all_occupied_indices) destination_index = self.free_cache_indices.pop()
self.mamba_cache_indices_mapping[cur_rid] = { self.mamba_cache_indices_mapping[cur_rid] = {
seq_id: destination_index seq_id: destination_index
} }
return destination_index
elif seq_id not in (seq_ids2indices := elif seq_id not in (seq_ids2indices :=
self.mamba_cache_indices_mapping[cur_rid]): self.mamba_cache_indices_mapping[cur_rid]):
# parallel sampling, where n > 1, assume prefill has # parallel sampling, where n > 1, assume prefill has
# already happened now we only need to copy the already # already happened, so we copy the
# existing cache into the sibling seq_ids' caches # existing cache into the sibling seq_ids' caches
self._move_out_if_already_occupied( index_exists = next(iter(seq_ids2indices.values()))
index=destination_index,
all_occupied_indices=all_occupied_indices)
index_exists = list(seq_ids2indices.values())[0]
# case of decoding n>1, copy prefill cache to decoding indices # case of decoding n>1, copy prefill cache to decoding indices
destination_index = self.free_cache_indices.pop()
self._copy_mamba_cache(from_index=index_exists, self._copy_mamba_cache(from_index=index_exists,
to_index=destination_index) to_index=destination_index)
self.mamba_cache_indices_mapping[cur_rid][ self.mamba_cache_indices_mapping[cur_rid][
seq_id] = destination_index seq_id] = destination_index
return destination_index
else: else:
# already exists # already exists
cache_index_already_exists = self.mamba_cache_indices_mapping[ return self.mamba_cache_indices_mapping[cur_rid][seq_id]
cur_rid][seq_id]
if cache_index_already_exists != destination_index:
# In case the seq id already exists but not in
# the right destination, swap it with what's occupying it
self._swap_pair_indices_and_mappings(
from_index=cache_index_already_exists,
to_index=destination_index)
def _prepare_current_run_mamba_cache( def _prepare_current_run_mamba_cache(
self, request_ids_to_seq_ids: Dict[str, list[int]], self, request_ids_to_seq_ids: Dict[str, list[int]],
finished_requests_ids: List[str]): finished_requests_ids: List[str]) -> List[int]:
running_indices = [] return [
request_ids_to_seq_ids_flatten = [ self._assign_seq_id_to_cache_index(req_id, seq_id,
(req_id, seq_id) finished_requests_ids)
for req_id, seq_ids in request_ids_to_seq_ids.items() for req_id, seq_ids in request_ids_to_seq_ids.items()
for seq_id in seq_ids for seq_id in seq_ids
] ]
batch_size = len(request_ids_to_seq_ids_flatten)
for dest_index, (request_id,
seq_id) in enumerate(request_ids_to_seq_ids_flatten):
if request_id in finished_requests_ids:
# Do not allocate cache index for requests that run
# and finish right after
continue
self._assign_seq_id_to_mamba_cache_in_specific_dest(
request_id, seq_id, dest_index)
running_indices.append(dest_index)
self._clean_up_first_bs_blocks(batch_size, running_indices)
conv_state = self.mamba_cache[0][:, :batch_size]
temporal_state = self.mamba_cache[1][:, :batch_size]
return (conv_state, temporal_state)
def _get_all_occupied_indices(self):
return [
cache_idx
for seq_ids2indices in self.mamba_cache_indices_mapping.values()
for cache_idx in seq_ids2indices.values()
]
def _clean_up_first_bs_blocks(self, batch_size: int,
indices_for_current_run: List[int]):
# move out all of the occupied but currently not running blocks
# outside of the first n blocks
destination_indices = range(batch_size)
max_possible_batch_size = self.mamba_cache[0].shape[1]
for destination_index in destination_indices:
if destination_index in self._get_all_occupied_indices() and \
destination_index not in indices_for_current_run:
# move not running indices outside of the batch
all_other_indices = list(
range(batch_size, max_possible_batch_size))
first_avail_index = self._first_free_index_in_mamba_cache(
all_other_indices)
self._swap_indices(from_index=destination_index,
to_index=first_avail_index)
def _move_cache_index_and_mappings(self, from_index: int, to_index: int):
self._copy_mamba_cache(from_index=from_index, to_index=to_index)
self._update_mapping_index(from_index=from_index, to_index=to_index)
def _swap_pair_indices_and_mappings(self, from_index: int, to_index: int):
self._swap_mamba_cache(from_index=from_index, to_index=to_index)
self._swap_mapping_index(from_index=from_index, to_index=to_index)
def _swap_mapping_index(self, from_index: int, to_index: int):
for seq_ids2index in self.mamba_cache_indices_mapping.values():
for seq_id, index in seq_ids2index.items():
if from_index == index:
seq_ids2index.update({seq_id: to_index})
elif to_index == index:
seq_ids2index.update({seq_id: from_index})
def _update_mapping_index(self, from_index: int, to_index: int):
for seq_ids2index in self.mamba_cache_indices_mapping.values():
for seq_id, index in seq_ids2index.items():
if from_index == index:
seq_ids2index.update({seq_id: to_index})
return
def _release_finished_requests(self, def _release_finished_requests(self,
finished_seq_groups_req_ids: List[str]): finished_seq_groups_req_ids: List[str]):
for req_id in finished_seq_groups_req_ids: for req_id in finished_seq_groups_req_ids:
if req_id in self.mamba_cache_indices_mapping: if req_id in self.mamba_cache_indices_mapping:
for seq_id in self.mamba_cache_indices_mapping[req_id]:
self.free_cache_indices.append(
self.mamba_cache_indices_mapping[req_id][seq_id])
self.mamba_cache_indices_mapping.pop(req_id) self.mamba_cache_indices_mapping.pop(req_id)
def _first_free_index_in_mamba_cache(
self, indices_range: Optional[List[int]] = None) -> int:
assert self.mamba_cache is not None
if indices_range is None:
max_possible_batch_size = self.mamba_cache[0].shape[1]
indices_range = list(range(max_possible_batch_size))
all_occupied_indices = self._get_all_occupied_indices()
for i in indices_range:
if i not in all_occupied_indices:
return i
raise Exception("Couldn't find a free spot in the mamba cache! This"
"should never happen")