mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-06 21:42:26 +08:00
[Kernel] Triton implementation of causal-conv1d for Mamba-based models (#18218)
Signed-off-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com> Co-authored-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com> Co-authored-by: Tyler Michael Smith <tysmith@redhat.com> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
parent
31b96d1c64
commit
47043eb678
@ -232,7 +232,6 @@ endif()
|
|||||||
|
|
||||||
set(VLLM_EXT_SRC
|
set(VLLM_EXT_SRC
|
||||||
"csrc/mamba/mamba_ssm/selective_scan_fwd.cu"
|
"csrc/mamba/mamba_ssm/selective_scan_fwd.cu"
|
||||||
"csrc/mamba/causal_conv1d/causal_conv1d.cu"
|
|
||||||
"csrc/cache_kernels.cu"
|
"csrc/cache_kernels.cu"
|
||||||
"csrc/attention/paged_attention_v1.cu"
|
"csrc/attention/paged_attention_v1.cu"
|
||||||
"csrc/attention/paged_attention_v2.cu"
|
"csrc/attention/paged_attention_v2.cu"
|
||||||
|
|||||||
@ -1,656 +0,0 @@
|
|||||||
// clang-format off
|
|
||||||
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d_fwd.cu
|
|
||||||
// and https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d_update.cu
|
|
||||||
#include <torch/all.h>
|
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
|
||||||
#include <c10/cuda/CUDAGuard.h>
|
|
||||||
|
|
||||||
#include "causal_conv1d.h"
|
|
||||||
#include <c10/util/BFloat16.h>
|
|
||||||
#include <c10/util/Half.h>
|
|
||||||
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
|
|
||||||
|
|
||||||
#include <cub/block/block_load.cuh>
|
|
||||||
#include <cub/block/block_store.cuh>
|
|
||||||
|
|
||||||
#ifdef USE_ROCM
|
|
||||||
namespace cub = hipcub;
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#include "static_switch.h"
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
// Asserts (via TORCH_CHECK) that tensor `x` has exactly the shape listed in the trailing arguments; the failure message names the tensor and the expected extents.
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
|
|
||||||
|
|
||||||
// Runtime-to-compile-time dtype dispatch.
// For ITYPE equal to Half, BFloat16 or Float, aliases `input_t` and `weight_t`
// to the corresponding C++ type (both aliases always name the same type) and
// then invokes the callable supplied as the trailing argument. Any other
// scalar type raises an error whose message contains NAME and the type name.
#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...)                  \
    switch (ITYPE) {                                                                    \
        case at::ScalarType::Half: {                                                    \
            using input_t = at::Half;                                                   \
            using weight_t = at::Half;                                                  \
            __VA_ARGS__();                                                              \
            break;                                                                      \
        }                                                                               \
        case at::ScalarType::BFloat16: {                                                \
            using input_t = at::BFloat16;                                               \
            using weight_t = at::BFloat16;                                              \
            __VA_ARGS__();                                                              \
            break;                                                                      \
        }                                                                               \
        case at::ScalarType::Float: {                                                   \
            using input_t = float;                                                      \
            using weight_t = float;                                                     \
            __VA_ARGS__();                                                              \
            break;                                                                      \
        }                                                                               \
        default: {                                                                      \
            AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
        }                                                                               \
    }
|
|
||||||
|
|
||||||
|
|
||||||
template<typename input_t, typename weight_t>
|
|
||||||
void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream);
|
|
||||||
|
|
||||||
template<typename input_t, typename weight_t>
|
|
||||||
void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream);
|
|
||||||
|
|
||||||
void set_conv_params_fwd(ConvParamsBase ¶ms,
|
|
||||||
// sizes
|
|
||||||
const size_t batch,
|
|
||||||
const size_t dim,
|
|
||||||
const size_t seqlen,
|
|
||||||
const size_t width,
|
|
||||||
// device pointers
|
|
||||||
const at::Tensor x,
|
|
||||||
const at::Tensor weight,
|
|
||||||
const at::Tensor out,
|
|
||||||
const std::optional<at::Tensor>& bias,
|
|
||||||
bool silu_activation,
|
|
||||||
int64_t pad_slot_id,
|
|
||||||
const std::optional<at::Tensor>& query_start_loc = std::nullopt,
|
|
||||||
const std::optional<at::Tensor>& cache_indices = std::nullopt,
|
|
||||||
const std::optional<at::Tensor>& has_initial_state = std::nullopt) {
|
|
||||||
|
|
||||||
// Reset the parameters
|
|
||||||
memset(¶ms, 0, sizeof(params));
|
|
||||||
|
|
||||||
params.batch = batch;
|
|
||||||
params.dim = dim;
|
|
||||||
params.seqlen = seqlen;
|
|
||||||
params.width = width;
|
|
||||||
params.pad_slot_id = pad_slot_id;
|
|
||||||
|
|
||||||
params.silu_activation = silu_activation;
|
|
||||||
|
|
||||||
// Set the pointers and strides.
|
|
||||||
params.x_ptr = x.data_ptr();
|
|
||||||
params.weight_ptr = weight.data_ptr();
|
|
||||||
params.bias_ptr = bias.has_value() ? bias.value().data_ptr() : nullptr;
|
|
||||||
params.out_ptr = out.data_ptr();
|
|
||||||
// All stride are in elements, not bytes.
|
|
||||||
params.query_start_loc_ptr = query_start_loc.has_value() ? query_start_loc.value().data_ptr() : nullptr;
|
|
||||||
params.cache_indices_ptr = cache_indices.has_value() ? cache_indices.value().data_ptr() : nullptr;
|
|
||||||
params.has_initial_state_ptr = has_initial_state.has_value() ? has_initial_state.value().data_ptr() : nullptr;
|
|
||||||
const bool varlen = params.query_start_loc_ptr != nullptr;
|
|
||||||
params.x_batch_stride = x.stride(varlen ? 1 : 0);
|
|
||||||
params.x_c_stride = x.stride(varlen ? 0 : 1);
|
|
||||||
params.x_l_stride = x.stride(varlen ? 1 : -1);
|
|
||||||
params.weight_c_stride = weight.stride(0);
|
|
||||||
params.weight_width_stride = weight.stride(1);
|
|
||||||
params.out_batch_stride = out.stride(varlen ? 1 : 0);
|
|
||||||
params.out_c_stride = out.stride(varlen ? 0 : 1);
|
|
||||||
params.out_l_stride = out.stride(varlen ? 1 : -1);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
// Host entry point for the causal conv1d forward pass. The result is written
// in place into `x` (the output tensor aliases the input).
//
// x                  (batch, dim, seqlen), or (dim, seqlen) when
//                    `query_start_loc` is given (variable-length mode).
// weight             (dim, width); must have the same dtype as `x` because the
//                    dispatch below instantiates the kernel with one type for
//                    both, and width must be in [2, 4].
// bias_              optional (dim,), same dtype as `weight`.
// conv_states        optional, same dtype as `x`; its first three strides are
//                    forwarded to the kernel.
// query_start_loc    optional int32 tensor; its length minus one is the batch
//                    size in variable-length mode.
// cache_indices      optional int32 (batch,).
// has_initial_state  optional bool (batch,).
// silu_activation    apply SiLU to the convolution output.
// pad_slot_id        used to identify padding entries if cache_indices is
//                    provided; in case of padding, the kernel returns early.
//
// All tensors must live on a CUDA device. Violations raise via TORCH_CHECK.
void causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
                       const std::optional<at::Tensor> &bias_,
                       const std::optional<at::Tensor> &conv_states,
                       const std::optional<at::Tensor> &query_start_loc,
                       const std::optional<at::Tensor> &cache_indices,
                       const std::optional<at::Tensor> &has_initial_state,
                       bool silu_activation,
                       // used to identify padding entries if cache_indices provided
                       // in case of padding, the kernel will return early
                       int64_t pad_slot_id) {
    auto input_type = x.scalar_type();
    auto weight_type = weight.scalar_type();
    TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
    TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
    // The dispatch macro derives weight_t from the input dtype, so a mismatch
    // would make the kernel reinterpret the weight/bias memory.
    TORCH_CHECK(weight_type == input_type, "weight type must equal to input type, other variations are disabled due to binary size limitations");

    TORCH_CHECK(x.is_cuda());
    TORCH_CHECK(weight.is_cuda());

    const bool varlen = query_start_loc.has_value();
    // Validate the rank before indexing into sizes() below.
    TORCH_CHECK(x.dim() == (varlen ? 2 : 3),
                "x must be 2-D (dim, seqlen) when query_start_loc is given, otherwise 3-D (batch, dim, seqlen)");
    const auto sizes = x.sizes();
    const int batch_size = varlen ? query_start_loc.value().size(0) - 1 : sizes[0];
    const int dim = varlen ? sizes[0] : sizes[1];
    const int seqlen = varlen ? sizes[1] : sizes[2];
    const int width = weight.size(-1);
    if (varlen) {
        CHECK_SHAPE(x, dim, seqlen);
    } else {
        CHECK_SHAPE(x, batch_size, dim, seqlen);
    }
    CHECK_SHAPE(weight, dim, width);

    // Only widths 2..4 have a kernel instantiation; without this check any
    // other width would return without launching anything.
    TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");

    if (bias_.has_value()) {
        auto bias = bias_.value();
        TORCH_CHECK(bias.scalar_type() == weight_type);
        TORCH_CHECK(bias.is_cuda());
        TORCH_CHECK(bias.stride(-1) == 1);
        CHECK_SHAPE(bias, dim);
    }

    if (has_initial_state.has_value()) {
        auto has_initial_state_ = has_initial_state.value();
        TORCH_CHECK(has_initial_state_.scalar_type() == at::ScalarType::Bool);
        TORCH_CHECK(has_initial_state_.is_cuda());
        CHECK_SHAPE(has_initial_state_, batch_size);
    }

    if (query_start_loc.has_value()) {
        auto query_start_loc_ = query_start_loc.value();
        TORCH_CHECK(query_start_loc_.scalar_type() == at::ScalarType::Int);
        TORCH_CHECK(query_start_loc_.is_cuda());
    }

    if (cache_indices.has_value()) {
        auto cache_indices_ = cache_indices.value();
        TORCH_CHECK(cache_indices_.scalar_type() == at::ScalarType::Int);
        TORCH_CHECK(cache_indices_.is_cuda());
        CHECK_SHAPE(cache_indices_, batch_size);
    }

    // The kernel writes its result over the input.
    at::Tensor out = x;

    ConvParamsBase params;
    set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
                        bias_,
                        silu_activation,
                        pad_slot_id,
                        query_start_loc,
                        cache_indices,
                        has_initial_state
                        );

    if (conv_states.has_value()) {
        auto conv_states_ = conv_states.value();
        TORCH_CHECK(conv_states_.scalar_type() == input_type);
        TORCH_CHECK(conv_states_.is_cuda());
        params.conv_states_ptr = conv_states_.data_ptr();
        params.conv_states_batch_stride = conv_states_.stride(0);
        params.conv_states_c_stride = conv_states_.stride(1);
        params.conv_states_l_stride = conv_states_.stride(2);
    } else {
        params.conv_states_ptr = nullptr;
    }

    // Launch on the device that owns `x`, on the current stream.
    const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
    auto stream = at::cuda::getCurrentCUDAStream().stream();
    DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] {
        causal_conv1d_fwd_cuda<input_t, weight_t>(params, stream);
    });
}
|
|
||||||
|
|
||||||
|
|
||||||
// Host entry point for the incremental (decode-time) causal conv1d update.
// The result is written in place into `x`, and `conv_state` is updated by the
// kernel.
//
// x                    (batch, dim, seqlen).
// conv_state           (entries, dim, state_len) with state_len >= width - 1
//                      and the same dtype as `x`. `entries` must equal batch
//                      unless `conv_state_indices_` is given.
// weight               (dim, width), same dtype as `x`, width in [2, 4].
// bias_                optional (dim,), same dtype as `weight`.
// silu_activation      apply SiLU to the convolution output.
// cache_seqlens_       optional int32 (batch,); when present the kernel treats
//                      `conv_state` as a circular buffer.
// conv_state_indices_  optional int32 (batch,) selecting the `conv_state` row
//                      used by each batch element.
// pad_slot_id          used to identify padding entries if cache indices are
//                      provided; in case of padding, the kernel returns early.
//
// All tensors must live on a CUDA device. Violations raise via TORCH_CHECK.
void causal_conv1d_update(const at::Tensor &x,
                          const at::Tensor &conv_state,
                          const at::Tensor &weight,
                          const std::optional<at::Tensor> &bias_,
                          bool silu_activation,
                          const std::optional<at::Tensor> &cache_seqlens_,
                          const std::optional<at::Tensor> &conv_state_indices_,
                          // used to identify padding entries if cache_indices provided
                          // in case of padding, the kernel will return early
                          int64_t pad_slot_id) {
    auto input_type = x.scalar_type();
    auto weight_type = weight.scalar_type();
    TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
    TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
    TORCH_CHECK(weight_type == input_type, "weight type must equal to input type, other variations are disabled due to binary size limitations");
    TORCH_CHECK(conv_state.scalar_type() == input_type);

    TORCH_CHECK(x.is_cuda());
    TORCH_CHECK(conv_state.is_cuda());
    TORCH_CHECK(weight.is_cuda());

    // Validate the rank before indexing into sizes() below.
    TORCH_CHECK(x.dim() == 3, "x must be 3-D (batch, dim, seqlen)");
    const auto sizes = x.sizes();
    const int batch_size = sizes[0];
    const int dim = sizes[1];
    const int seqlen = sizes[2];
    const int width = weight.size(-1);
    const int conv_state_len = conv_state.size(2);
    TORCH_CHECK(conv_state_len >= width - 1);

    CHECK_SHAPE(x, batch_size, dim, seqlen);
    CHECK_SHAPE(weight, dim, width);

    TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");

    if (bias_.has_value()) {
        auto bias = bias_.value();
        TORCH_CHECK(bias.scalar_type() == weight_type);
        TORCH_CHECK(bias.is_cuda());
        TORCH_CHECK(bias.stride(-1) == 1);
        CHECK_SHAPE(bias, dim);
    }

    // The kernel writes its result over the input.
    at::Tensor out = x;

    ConvParamsBase params;
    set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
                        bias_,
                        silu_activation,
                        pad_slot_id);
    params.conv_state_ptr = conv_state.data_ptr();
    params.conv_state_len = conv_state_len;
    // All stride are in elements, not bytes.
    params.conv_state_batch_stride = conv_state.stride(0);
    params.conv_state_c_stride = conv_state.stride(1);
    params.conv_state_l_stride = conv_state.stride(2);

    if (cache_seqlens_.has_value()) {
        auto cache_seqlens = cache_seqlens_.value();
        TORCH_CHECK(cache_seqlens.scalar_type() == torch::kInt32);
        TORCH_CHECK(cache_seqlens.is_cuda());
        TORCH_CHECK(cache_seqlens.stride(-1) == 1);
        CHECK_SHAPE(cache_seqlens, batch_size);
        params.cache_seqlens = cache_seqlens.data_ptr<int32_t>();
    } else {
        params.cache_seqlens = nullptr;
    }

    if (conv_state_indices_.has_value()) {
        auto conv_state_indices = conv_state_indices_.value();
        TORCH_CHECK(conv_state_indices.scalar_type() == torch::kInt32);
        TORCH_CHECK(conv_state_indices.is_cuda());
        TORCH_CHECK(conv_state_indices.stride(0) == 1);
        CHECK_SHAPE(conv_state_indices, batch_size);

        // With an index tensor the state may hold any number of entries.
        int conv_state_entries = conv_state.size(0);
        CHECK_SHAPE(conv_state, conv_state_entries, dim, conv_state_len);

        params.conv_state_indices_ptr = conv_state_indices.data_ptr<int32_t>();
    } else {
        CHECK_SHAPE(conv_state, batch_size, dim, conv_state_len);
        params.conv_state_indices_ptr = nullptr;
    }

    // Launch on the device that owns `x`, on the current stream.
    const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
    auto stream = at::cuda::getCurrentCUDAStream().stream();
    DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] {
        causal_conv1d_update_cuda<input_t, weight_t>(params, stream);
    });
}
|
|
||||||
|
|
||||||
template<int kNThreads_, int kWidth_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
|
|
||||||
struct Causal_conv1d_fwd_kernel_traits {
|
|
||||||
using input_t = input_t_;
|
|
||||||
using weight_t = weight_t_;
|
|
||||||
static constexpr int kNThreads = kNThreads_;
|
|
||||||
static constexpr int kWidth = kWidth_;
|
|
||||||
static constexpr int kNBytes = sizeof(input_t);
|
|
||||||
static_assert(kNBytes == 2 || kNBytes == 4);
|
|
||||||
static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
|
|
||||||
static_assert(kWidth <= kNElts);
|
|
||||||
static constexpr bool kIsVecLoad = kIsVecLoad_;
|
|
||||||
using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
|
|
||||||
using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNElts, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
|
||||||
using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, 1, cub::BLOCK_LOAD_DIRECT>;
|
|
||||||
using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNElts, cub::BLOCK_STORE_WARP_TRANSPOSE>;
|
|
||||||
using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, 1, cub::BLOCK_STORE_DIRECT>;
|
|
||||||
static constexpr int kSmemIOSize = kIsVecLoad
|
|
||||||
? 0
|
|
||||||
: custom_max({sizeof(typename BlockLoadT::TempStorage), sizeof(typename BlockStoreT::TempStorage)});
|
|
||||||
static constexpr int kSmemExchangeSize = kNThreads * kNBytes * kNElts;
|
|
||||||
static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize;
|
|
||||||
};
|
|
||||||
|
|
||||||
template<typename Ktraits>
|
|
||||||
__global__ __launch_bounds__(Ktraits::kNThreads)
|
|
||||||
void causal_conv1d_fwd_kernel(ConvParamsBase params) {
|
|
||||||
constexpr int kWidth = Ktraits::kWidth;
|
|
||||||
constexpr int kNThreads = Ktraits::kNThreads;
|
|
||||||
constexpr int kNElts = Ktraits::kNElts;
|
|
||||||
constexpr bool kIsVecLoad = Ktraits::kIsVecLoad;
|
|
||||||
using input_t = typename Ktraits::input_t;
|
|
||||||
using vec_t = typename Ktraits::vec_t;
|
|
||||||
using weight_t = typename Ktraits::weight_t;
|
|
||||||
|
|
||||||
// Shared memory.
|
|
||||||
extern __shared__ char smem_[];
|
|
||||||
auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
|
|
||||||
auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_);
|
|
||||||
auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
|
|
||||||
auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_);
|
|
||||||
vec_t *smem_exchange = reinterpret_cast<vec_t *>(smem_ + Ktraits::kSmemIOSize);
|
|
||||||
|
|
||||||
const bool kVarlen = params.query_start_loc_ptr != nullptr;
|
|
||||||
const int tidx = threadIdx.x;
|
|
||||||
const int batch_id = blockIdx.x;
|
|
||||||
const int channel_id = blockIdx.y;
|
|
||||||
const int *query_start_loc = kVarlen ? reinterpret_cast<int *>(params.query_start_loc_ptr) : nullptr;
|
|
||||||
const int sequence_start_index = kVarlen ? query_start_loc[batch_id] : batch_id;
|
|
||||||
const int seqlen = kVarlen ? query_start_loc[batch_id + 1] - sequence_start_index : params.seqlen;
|
|
||||||
|
|
||||||
input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + sequence_start_index * params.x_batch_stride
|
|
||||||
+ channel_id * params.x_c_stride;
|
|
||||||
weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
|
|
||||||
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + sequence_start_index * params.out_batch_stride
|
|
||||||
+ channel_id * params.out_c_stride;
|
|
||||||
float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);
|
|
||||||
|
|
||||||
bool has_initial_state = params.has_initial_state_ptr == nullptr ? false
|
|
||||||
: reinterpret_cast<bool *>(params.has_initial_state_ptr)[batch_id];
|
|
||||||
|
|
||||||
int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr
|
|
||||||
: reinterpret_cast<int *>(params.cache_indices_ptr);
|
|
||||||
int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id];
|
|
||||||
// cache_index == params.pad_slot_id is defined as padding, so we exit early
|
|
||||||
if (cache_index == params.pad_slot_id){
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
input_t *conv_states = params.conv_states_ptr == nullptr ? nullptr
|
|
||||||
: reinterpret_cast<input_t *>(params.conv_states_ptr) + cache_index * params.conv_states_batch_stride + channel_id * params.conv_states_c_stride;
|
|
||||||
|
|
||||||
// Thread 0 will load the last elements of the previous chunk, so we initialize those to 0.
|
|
||||||
if (tidx == 0) {
|
|
||||||
input_t initial_state[kNElts] = {0};
|
|
||||||
if (has_initial_state) {
|
|
||||||
#pragma unroll
|
|
||||||
for (int w = 0; w < kWidth - 1; ++w){ initial_state[kNElts - 1 - (kWidth - 2) + w ] = conv_states[w]; }
|
|
||||||
}
|
|
||||||
smem_exchange[kNThreads - 1] = reinterpret_cast<vec_t *>(initial_state)[0];
|
|
||||||
}
|
|
||||||
|
|
||||||
float weight_vals[kWidth];
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }
|
|
||||||
|
|
||||||
constexpr int kChunkSize = kNThreads * kNElts;
|
|
||||||
const int n_chunks = (seqlen + kChunkSize - 1) / kChunkSize;
|
|
||||||
for (int chunk = 0; chunk < n_chunks; ++chunk) {
|
|
||||||
input_t x_vals_load[2 * kNElts] = {0};
|
|
||||||
if constexpr(kIsVecLoad) {
|
|
||||||
typename Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast<vec_t*>(x), *reinterpret_cast<vec_t (*)[1]>(&x_vals_load[kNElts]), (seqlen - chunk * kChunkSize) / kNElts);
|
|
||||||
} else {
|
|
||||||
__syncthreads();
|
|
||||||
typename Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast<input_t (*)[kNElts]>(&x_vals_load[kNElts]), seqlen - chunk * kChunkSize);
|
|
||||||
}
|
|
||||||
x += kChunkSize;
|
|
||||||
__syncthreads();
|
|
||||||
// Thread kNThreads - 1 don't write yet, so that thread 0 can read
|
|
||||||
// the last elements of the previous chunk.
|
|
||||||
if (tidx < kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1]; }
|
|
||||||
__syncthreads();
|
|
||||||
reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[tidx > 0 ? tidx - 1 : kNThreads - 1];
|
|
||||||
__syncthreads();
|
|
||||||
// Now thread kNThreads - 1 can write the last elements of the current chunk.
|
|
||||||
if (tidx == kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1]; }
|
|
||||||
|
|
||||||
float x_vals[2 * kNElts];
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < 2 * kNElts; ++i) { x_vals[i] = float(x_vals_load[i]); }
|
|
||||||
|
|
||||||
float out_vals[kNElts];
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < kNElts; ++i) {
|
|
||||||
out_vals[i] = bias_val;
|
|
||||||
#pragma unroll
|
|
||||||
for (int w = 0; w < kWidth; ++w) {
|
|
||||||
out_vals[i] += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (params.silu_activation) {
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < kNElts; ++i) {
|
|
||||||
out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i]));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
input_t out_vals_store[kNElts];
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < kNElts; ++i) { out_vals_store[i] = out_vals[i]; }
|
|
||||||
if constexpr(kIsVecLoad) {
|
|
||||||
typename Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast<vec_t*>(out), reinterpret_cast<vec_t (&)[1]>(out_vals_store), (seqlen - chunk * kChunkSize) / kNElts);
|
|
||||||
} else {
|
|
||||||
typename Ktraits::BlockStoreT(smem_store).Store(out, out_vals_store, seqlen - chunk * kChunkSize);
|
|
||||||
}
|
|
||||||
out += kChunkSize;
|
|
||||||
|
|
||||||
int final_state_position = ((seqlen - (kWidth - 1)) - (n_chunks - 1) * kChunkSize);
|
|
||||||
// in case the final state is separated between the last "smem_exchange" and
|
|
||||||
// and the one before it (chunk = n_chunks - 1 and chunk = n_chunks - 2),
|
|
||||||
// (which occurs when `final_state_position` is a non-positive index)
|
|
||||||
// we load the correct data from smem_exchange from both chunks, the last chunk iteration and the one before it
|
|
||||||
if (conv_states != nullptr && final_state_position < 0 && seqlen > kWidth){
|
|
||||||
input_t vals_load[kNElts] = {0};
|
|
||||||
if ((chunk == n_chunks - 2) && (tidx == kNThreads - 1)){
|
|
||||||
// chunk = n_chunks - 2, a segment of the final state sits in the last index
|
|
||||||
reinterpret_cast<vec_t *>(vals_load)[0] = smem_exchange[kNThreads - 1];
|
|
||||||
#pragma unroll
|
|
||||||
for (int w = 0; w < -final_state_position; ++w){
|
|
||||||
conv_states[w] = vals_load[kNElts + final_state_position + w];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if ((chunk == n_chunks - 1) && tidx == 0){
|
|
||||||
// chunk = n_chunks - 1, the second segment of the final state first positions
|
|
||||||
reinterpret_cast<vec_t *>(vals_load)[0] = smem_exchange[0];
|
|
||||||
for (int w = -final_state_position; w < kWidth - 1; ++w){
|
|
||||||
conv_states[w] = vals_load[w + final_state_position];
|
|
||||||
}
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Final state is stored in the smem_exchange last token slot,
|
|
||||||
// in case seqlen < kWidth, we would need to take the final state from the
|
|
||||||
// initial state which is stored in conv_states
|
|
||||||
// in case seqlen > kWidth, we would need to load the last kWidth - 1 data
|
|
||||||
// and load it into conv_state accordingly
|
|
||||||
int last_thread = ((seqlen - (kWidth - 1)) - (n_chunks - 1) * kChunkSize) / kNElts;
|
|
||||||
if (conv_states != nullptr && tidx == last_thread) {
|
|
||||||
input_t x_vals_load[kNElts * 2] = {0};
|
|
||||||
// in case we are on the first kWidth tokens
|
|
||||||
if (last_thread == 0 && seqlen < kWidth){
|
|
||||||
// Need to take the initial state
|
|
||||||
reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[0];
|
|
||||||
const int offset = seqlen - (kWidth - 1);
|
|
||||||
#pragma unroll
|
|
||||||
for (int w = 0; w < kWidth - 1; ++w){
|
|
||||||
// pad the existing state
|
|
||||||
if ((w - seqlen) >= 0 && has_initial_state) { conv_states[w - seqlen] = conv_states[w]; }
|
|
||||||
else if ((w - seqlen) >= 0 && !has_initial_state) { conv_states[w - seqlen] = input_t(0.0f); }
|
|
||||||
}
|
|
||||||
#pragma unroll
|
|
||||||
for (int w = 0; w < kWidth - 1; ++w){
|
|
||||||
if (offset + w >= 0)
|
|
||||||
conv_states[w] = x_vals_load[offset + w ];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
// in case the final state is in between the threads data
|
|
||||||
const int offset = ((seqlen - (kWidth - 1)) % (kNElts));
|
|
||||||
if ((offset + kWidth - 2) >= kNElts && (last_thread + 1 < kNThreads)){
|
|
||||||
// In case last_thread == kNThreads - 1, accessing last_thread + 1 will result in a
|
|
||||||
// illegal access error on H100.
|
|
||||||
// Therefore, we access last_thread + 1, only if the final state data sits there
|
|
||||||
reinterpret_cast<vec_t *>(x_vals_load)[1] = smem_exchange[last_thread + 1];
|
|
||||||
}
|
|
||||||
reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[last_thread];
|
|
||||||
#pragma unroll
|
|
||||||
for (int w = 0; w < kWidth - 1; ++w){
|
|
||||||
conv_states[w] = x_vals_load[offset + w ];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
|
|
||||||
void causal_conv1d_fwd_launch(ConvParamsBase ¶ms, cudaStream_t stream) {
|
|
||||||
static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8;
|
|
||||||
const bool kVarlen = params.query_start_loc_ptr != nullptr;
|
|
||||||
BOOL_SWITCH(params.seqlen % kNElts == 0 && !kVarlen, kIsVecLoad, [&] {
|
|
||||||
using Ktraits = Causal_conv1d_fwd_kernel_traits<kNThreads, kWidth, kIsVecLoad, input_t, weight_t>;
|
|
||||||
constexpr int kSmemSize = Ktraits::kSmemSize;
|
|
||||||
dim3 grid(params.batch, params.dim);
|
|
||||||
|
|
||||||
auto kernel = &causal_conv1d_fwd_kernel<Ktraits>;
|
|
||||||
|
|
||||||
if (kSmemSize >= 48 * 1024) {
|
|
||||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
|
||||||
(void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
|
||||||
std::cerr << "Warning (causal_conv1d fwd launch): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl;
|
|
||||||
}
|
|
||||||
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
|
|
||||||
|
|
||||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename input_t, typename weight_t>
|
|
||||||
void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream) {
|
|
||||||
if (params.width == 2) {
|
|
||||||
causal_conv1d_fwd_launch<128, 2, input_t, weight_t>(params, stream);
|
|
||||||
} else if (params.width == 3) {
|
|
||||||
causal_conv1d_fwd_launch<128, 3, input_t, weight_t>(params, stream);
|
|
||||||
} else if (params.width == 4) {
|
|
||||||
causal_conv1d_fwd_launch<128, 4, input_t, weight_t>(params, stream);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
// Explicit instantiations of the forward entry point for the three dtypes the
// dispatch macro can select (input and weight types are always identical).
template void causal_conv1d_fwd_cuda<float, float>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_fwd_cuda<at::Half, at::Half>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_fwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
// Compile-time configuration of the update kernel: element types, threads per
// block and filter width.
template<int kNThreads_, int kWidth_, typename input_t_, typename weight_t_>
struct Causal_conv1d_update_kernel_traits {
    using input_t = input_t_;
    using weight_t = weight_t_;

    static constexpr int kNThreads = kNThreads_;
    static constexpr int kWidth = kWidth_;

    // Only 2-byte and 4-byte element types are supported.
    static constexpr int kNBytes = sizeof(input_t);
    static_assert(kNBytes == 2 || kNBytes == 4);
};
|
|
||||||
|
|
||||||
template<typename Ktraits, bool kIsCircularBuffer>
|
|
||||||
__global__ __launch_bounds__(Ktraits::kNThreads)
|
|
||||||
void causal_conv1d_update_kernel(ConvParamsBase params) {
|
|
||||||
constexpr int kWidth = Ktraits::kWidth;
|
|
||||||
constexpr int kNThreads = Ktraits::kNThreads;
|
|
||||||
using input_t = typename Ktraits::input_t;
|
|
||||||
using weight_t = typename Ktraits::weight_t;
|
|
||||||
|
|
||||||
const int tidx = threadIdx.x;
|
|
||||||
const int batch_id = blockIdx.x;
|
|
||||||
const int channel_id = blockIdx.y * kNThreads + tidx;
|
|
||||||
if (channel_id >= params.dim) return;
|
|
||||||
|
|
||||||
input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
|
|
||||||
+ channel_id * params.x_c_stride;
|
|
||||||
|
|
||||||
// If params.conv_state_batch_indices is set, then the conv state is gathered from the conv state tensor
|
|
||||||
// along the batch axis. Otherwise, the conv state coordinate is the same as the batch id.
|
|
||||||
const int conv_state_batch_coord = params.conv_state_indices_ptr == nullptr
|
|
||||||
? batch_id
|
|
||||||
: params.conv_state_indices_ptr[batch_id];
|
|
||||||
// conv_state_batch_coord == params.pad_slot_id is defined as padding so we exit early
|
|
||||||
if (conv_state_batch_coord == params.pad_slot_id){
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
input_t *conv_state = reinterpret_cast<input_t *>(params.conv_state_ptr)
|
|
||||||
+ conv_state_batch_coord * params.conv_state_batch_stride
|
|
||||||
+ channel_id * params.conv_state_c_stride;
|
|
||||||
|
|
||||||
weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
|
|
||||||
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
|
|
||||||
+ channel_id * params.out_c_stride;
|
|
||||||
float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);
|
|
||||||
|
|
||||||
int state_len = params.conv_state_len;
|
|
||||||
int advance_len = params.seqlen;
|
|
||||||
int cache_seqlen = kIsCircularBuffer ? params.cache_seqlens[batch_id] % state_len : 0;
|
|
||||||
int update_idx = cache_seqlen - (kWidth - 1);
|
|
||||||
update_idx = update_idx < 0 ? update_idx + state_len : update_idx;
|
|
||||||
|
|
||||||
float weight_vals[kWidth] = {0};
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }
|
|
||||||
|
|
||||||
float x_vals[kWidth] = {0};
|
|
||||||
if constexpr (!kIsCircularBuffer) {
|
|
||||||
#pragma unroll 2
|
|
||||||
for (int i = 0; i < state_len - advance_len - (kWidth - 1); ++i) {
|
|
||||||
conv_state[i * params.conv_state_l_stride] = conv_state[(i + advance_len) * params.conv_state_l_stride];
|
|
||||||
}
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < kWidth - 1; ++i) {
|
|
||||||
input_t state_val = conv_state[(state_len - (kWidth - 1) + i) * params.conv_state_l_stride];
|
|
||||||
if (i < advance_len + (kWidth - 1) && state_len - advance_len - (kWidth - 1) + i >= 0) {
|
|
||||||
conv_state[(state_len - advance_len - (kWidth - 1) + i) * params.conv_state_l_stride] = state_val;
|
|
||||||
}
|
|
||||||
x_vals[i] = float(state_val);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < kWidth - 1; ++i, update_idx = update_idx + 1 >= state_len ? update_idx + 1 - state_len : update_idx + 1) {
|
|
||||||
input_t state_val = conv_state[update_idx * params.conv_state_l_stride];
|
|
||||||
x_vals[i] = float(state_val);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
#pragma unroll 2
|
|
||||||
for (int i = 0; i < params.seqlen; ++i) {
|
|
||||||
input_t x_val = x[i * params.x_l_stride];
|
|
||||||
if constexpr (!kIsCircularBuffer) {
|
|
||||||
if (i < advance_len && state_len - advance_len + i >= 0) {
|
|
||||||
conv_state[(state_len - advance_len + i) * params.conv_state_l_stride] = x_val;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
conv_state[update_idx * params.conv_state_l_stride] = x_val;
|
|
||||||
++update_idx;
|
|
||||||
update_idx = update_idx >= state_len ? update_idx - state_len : update_idx;
|
|
||||||
}
|
|
||||||
x_vals[kWidth - 1] = float(x_val);
|
|
||||||
float out_val = bias_val;
|
|
||||||
#pragma unroll
|
|
||||||
for (int j = 0; j < kWidth; ++j) { out_val += weight_vals[j] * x_vals[j]; }
|
|
||||||
if (params.silu_activation) { out_val = out_val / (1 + expf(-out_val)); }
|
|
||||||
out[i * params.out_l_stride] = input_t(out_val);
|
|
||||||
// Shift the input buffer by 1
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < kWidth - 1; ++i) { x_vals[i] = x_vals[i + 1]; }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
|
|
||||||
void causal_conv1d_update_launch(ConvParamsBase ¶ms, cudaStream_t stream) {
|
|
||||||
using Ktraits = Causal_conv1d_update_kernel_traits<kNThreads, kWidth, input_t, weight_t>;
|
|
||||||
dim3 grid(params.batch, (params.dim + kNThreads - 1) / kNThreads);
|
|
||||||
auto kernel = params.cache_seqlens == nullptr
|
|
||||||
? &causal_conv1d_update_kernel<Ktraits, false>
|
|
||||||
: &causal_conv1d_update_kernel<Ktraits, true>;
|
|
||||||
kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
|
|
||||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename input_t, typename weight_t>
|
|
||||||
void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream) {
|
|
||||||
if (params.width == 2) {
|
|
||||||
causal_conv1d_update_launch<64, 2, input_t, weight_t>(params, stream);
|
|
||||||
} else if (params.width == 3) {
|
|
||||||
causal_conv1d_update_launch<64, 3, input_t, weight_t>(params, stream);
|
|
||||||
} else if (params.width == 4) {
|
|
||||||
causal_conv1d_update_launch<64, 4, input_t, weight_t>(params, stream);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Explicit instantiations of the update entry point: input and weight share
// the same dtype (float, half, bfloat16).
template void causal_conv1d_update_cuda<float, float>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_update_cuda<at::Half, at::Half>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_update_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
|
|
||||||
@ -1,159 +0,0 @@
|
|||||||
/******************************************************************************
|
|
||||||
* Copyright (c) 2024, Tri Dao.
|
|
||||||
******************************************************************************/
|
|
||||||
// clang-format off
|
|
||||||
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d.h
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include <cuda_bf16.h>
|
|
||||||
#include <cuda_fp16.h>
|
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
/// Plain parameter bundle handed to the causal-conv1d kernels.
///
/// All strides use `index_t` (32-bit unsigned). The update kernel multiplies
/// them into typed-pointer indices, so they are presumably expressed in
/// elements rather than bytes -- confirm against the host-side setup code.
struct ConvParamsBase {
    using index_t = uint32_t;

    // Problem sizes.
    int batch, dim, seqlen, width;
    int64_t pad_slot_id;
    // When set, the kernel applies x / (1 + exp(-x)) to each output value.
    bool silu_activation;

    // Strides of the input, weight and output tensors.
    index_t x_batch_stride;
    index_t x_c_stride;
    index_t x_l_stride;
    index_t weight_c_stride;
    index_t weight_width_stride;
    index_t out_batch_stride;
    index_t out_c_stride;
    index_t out_l_stride;

    // Length and strides of the rolling convolution state.
    int conv_state_len;
    index_t conv_state_batch_stride;
    index_t conv_state_c_stride;
    index_t conv_state_l_stride;

    // Common data pointers.
    void *__restrict__ x_ptr;
    void *__restrict__ weight_ptr;
    void *__restrict__ bias_ptr;
    void *__restrict__ out_ptr;

    void *__restrict__ conv_state_ptr;
    void *__restrict__ query_start_loc_ptr;
    void *__restrict__ has_initial_state_ptr;
    void *__restrict__ cache_indices_ptr;
    // nullptr selects the non-circular update kernel variant at launch time.
    int32_t *__restrict__ cache_seqlens;

    // For the continuous batching case. Makes it so that the mamba state for
    // the current batch doesn't need to be a contiguous tensor.
    int32_t *__restrict__ conv_state_indices_ptr;

    void *__restrict__ seq_idx_ptr;

    // No __restrict__ since initial_states could be the same as final_states.
    void * initial_states_ptr;
    index_t initial_states_batch_stride;
    index_t initial_states_l_stride;
    index_t initial_states_c_stride;

    void * final_states_ptr;
    index_t final_states_batch_stride;
    index_t final_states_l_stride;
    index_t final_states_c_stride;

    // Distinct from conv_state_ptr above (note the plural).
    void * conv_states_ptr;
    index_t conv_states_batch_stride;
    index_t conv_states_l_stride;
    index_t conv_states_c_stride;
};
|
|
||||||
|
|
||||||
|
|
||||||
#ifndef USE_ROCM
|
|
||||||
#include <cuda_bf16.h>
|
|
||||||
|
|
||||||
template<typename T>
|
|
||||||
__device__ inline T shuffle_xor(T val, int offset) {
|
|
||||||
return __shfl_xor_sync(uint32_t(-1), val, offset);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Compile-time maximum of a brace-enclosed list of sizes (CUDA build).
/// The list must not be empty.
constexpr size_t custom_max(std::initializer_list<size_t> ilist)
{
    return std::max(ilist);
}
|
|
||||||
|
|
||||||
/// Smaller of two values, usable in constant expressions (CUDA build).
template<typename T>
constexpr T constexpr_min(T a, T b) {
    return std::min(a, b);
}
|
|
||||||
|
|
||||||
#else
|
|
||||||
#include <hip/hip_bf16.h>
|
|
||||||
|
|
||||||
template<typename T>
|
|
||||||
__device__ inline T shuffle_xor(T val, int offset) {
|
|
||||||
return __shfl_xor(val, offset);
|
|
||||||
}
|
|
||||||
/// Compile-time maximum of a brace-enclosed list of sizes (ROCm build).
/// Uses std::max_element instead of the initializer-list overload of
/// std::max; the list must not be empty because the result is dereferenced.
constexpr size_t custom_max(std::initializer_list<size_t> ilist)
{
    return *std::max_element(ilist.begin(), ilist.end());
}
|
|
||||||
|
|
||||||
/// Smaller of two values, usable in constant expressions (ROCm build).
/// Written with a plain comparison rather than std::min.
template<typename T>
constexpr T constexpr_min(T a, T b) {
    return a < b ? a : b;
}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
template<int BYTES> struct BytesToType {};
|
|
||||||
|
|
||||||
template<> struct BytesToType<16> {
|
|
||||||
using Type = uint4;
|
|
||||||
static_assert(sizeof(Type) == 16);
|
|
||||||
};
|
|
||||||
|
|
||||||
template<> struct BytesToType<8> {
|
|
||||||
using Type = uint64_t;
|
|
||||||
static_assert(sizeof(Type) == 8);
|
|
||||||
};
|
|
||||||
|
|
||||||
template<> struct BytesToType<4> {
|
|
||||||
using Type = uint32_t;
|
|
||||||
static_assert(sizeof(Type) == 4);
|
|
||||||
};
|
|
||||||
|
|
||||||
template<> struct BytesToType<2> {
|
|
||||||
using Type = uint16_t;
|
|
||||||
static_assert(sizeof(Type) == 2);
|
|
||||||
};
|
|
||||||
|
|
||||||
template<> struct BytesToType<1> {
|
|
||||||
using Type = uint8_t;
|
|
||||||
static_assert(sizeof(Type) == 1);
|
|
||||||
};
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
template<typename T>
|
|
||||||
struct SumOp {
|
|
||||||
__device__ inline T operator()(T const & x, T const & y) { return x + y; }
|
|
||||||
};
|
|
||||||
|
|
||||||
template<int THREADS>
|
|
||||||
struct Allreduce {
|
|
||||||
static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
|
|
||||||
template<typename T, typename Operator>
|
|
||||||
static __device__ inline T run(T x, Operator &op) {
|
|
||||||
constexpr int OFFSET = THREADS / 2;
|
|
||||||
x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
|
|
||||||
return Allreduce<OFFSET>::run(x, op);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template<>
|
|
||||||
struct Allreduce<2> {
|
|
||||||
template<typename T, typename Operator>
|
|
||||||
static __device__ inline T run(T x, Operator &op) {
|
|
||||||
x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
|
|
||||||
return x;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
@ -1,28 +0,0 @@
|
|||||||
// Inspired by
|
|
||||||
// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
|
|
||||||
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
|
|
||||||
// clang-format off
|
|
||||||
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/static_switch.h
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
/// @param COND       - a boolean expression to switch by
/// @param CONST_NAME - a name given for the constexpr bool variable.
/// @param ...        - code to execute for true and false
///
/// Turns a runtime boolean into a compile-time constant: the callable passed
/// as the last argument is invoked in whichever branch matches COND, with
/// CONST_NAME visible inside it as a `static constexpr bool`. The callable
/// is therefore instantiated for both true and false, and both
/// instantiations must return the same type.
///
/// Usage:
/// ```
/// BOOL_SWITCH(flag, BoolConst, [&] {
///     some_function<BoolConst>(...);
/// });
/// ```
#define BOOL_SWITCH(COND, CONST_NAME, ...)            \
    [&] {                                             \
        if (COND) {                                   \
            static constexpr bool CONST_NAME = true;  \
            return __VA_ARGS__();                     \
        } else {                                      \
            static constexpr bool CONST_NAME = false; \
            return __VA_ARGS__();                     \
        }                                             \
    }()
|
|
||||||
16
csrc/ops.h
16
csrc/ops.h
@ -326,22 +326,6 @@ void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta,
|
|||||||
const std::optional<torch::Tensor>& has_initial_state,
|
const std::optional<torch::Tensor>& has_initial_state,
|
||||||
const torch::Tensor& ssm_states, int64_t pad_slot_id);
|
const torch::Tensor& ssm_states, int64_t pad_slot_id);
|
||||||
|
|
||||||
void causal_conv1d_update(const at::Tensor& x, const at::Tensor& conv_state,
|
|
||||||
const at::Tensor& weight,
|
|
||||||
const std::optional<at::Tensor>& bias_,
|
|
||||||
bool silu_activation,
|
|
||||||
const std::optional<at::Tensor>& cache_seqlens_,
|
|
||||||
const std::optional<at::Tensor>& conv_state_indices_,
|
|
||||||
int64_t pad_slot_id);
|
|
||||||
|
|
||||||
void causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
|
|
||||||
const std::optional<at::Tensor>& bias_,
|
|
||||||
const std::optional<at::Tensor>& conv_states,
|
|
||||||
const std::optional<at::Tensor>& query_start_loc,
|
|
||||||
const std::optional<at::Tensor>& cache_indices,
|
|
||||||
const std::optional<at::Tensor>& has_initial_state,
|
|
||||||
bool silu_activation, int64_t pad_slot_id);
|
|
||||||
|
|
||||||
using fptr_t = int64_t;
|
using fptr_t = int64_t;
|
||||||
fptr_t init_custom_ar(const std::vector<int64_t>& fake_ipc_ptrs,
|
fptr_t init_custom_ar(const std::vector<int64_t>& fake_ipc_ptrs,
|
||||||
torch::Tensor& rank_data, int64_t rank,
|
torch::Tensor& rank_data, int64_t rank,
|
||||||
|
|||||||
@ -594,28 +594,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
"int pad_slot_id) -> ()");
|
"int pad_slot_id) -> ()");
|
||||||
ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
|
ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
|
||||||
|
|
||||||
ops.def(
|
|
||||||
"causal_conv1d_update(Tensor! x,"
|
|
||||||
"Tensor! conv_state,"
|
|
||||||
"Tensor! weight,"
|
|
||||||
"Tensor? bias_,"
|
|
||||||
"bool silu_activation,"
|
|
||||||
"Tensor? cache_seqlens_,"
|
|
||||||
"Tensor? conv_state_indices,"
|
|
||||||
"int pad_slot_id) -> ()");
|
|
||||||
ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update);
|
|
||||||
|
|
||||||
ops.def(
|
|
||||||
"causal_conv1d_fwd(Tensor! x, Tensor! weight,"
|
|
||||||
"Tensor? bias_,"
|
|
||||||
"Tensor!? conv_states,"
|
|
||||||
"Tensor? query_start_loc,"
|
|
||||||
"Tensor? cache_indices,"
|
|
||||||
"Tensor? has_initial_state,"
|
|
||||||
"bool silu_activation,"
|
|
||||||
"int pad_slot_id) -> ()");
|
|
||||||
ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
|
|
||||||
|
|
||||||
#ifndef USE_ROCM
|
#ifndef USE_ROCM
|
||||||
// reorder weight for AllSpark Ampere W8A16 Fused Gemm kernel
|
// reorder weight for AllSpark Ampere W8A16 Fused Gemm kernel
|
||||||
ops.def(
|
ops.def(
|
||||||
|
|||||||
@ -6,9 +6,8 @@ from typing import Optional
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
from einops import rearrange
|
||||||
|
|
||||||
from tests.kernels.utils import opcheck
|
|
||||||
from vllm import _custom_ops as ops # noqa: F401
|
|
||||||
from vllm.attention.backends.utils import PAD_SLOT_ID
|
from vllm.attention.backends.utils import PAD_SLOT_ID
|
||||||
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
|
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
|
||||||
causal_conv1d_fn, causal_conv1d_update)
|
causal_conv1d_fn, causal_conv1d_update)
|
||||||
@ -144,79 +143,6 @@ def causal_conv1d_opcheck_fn(x: torch.Tensor,
|
|||||||
x = x.contiguous()
|
x = x.contiguous()
|
||||||
bias = bias.contiguous() if bias is not None else None
|
bias = bias.contiguous() if bias is not None else None
|
||||||
|
|
||||||
opcheck(torch.ops._C.causal_conv1d_fwd,
|
|
||||||
(x, weight, bias, conv_states, cu_seq_len, cache_indices,
|
|
||||||
has_initial_state, activation in ["silu", "swish"], pad_slot_id))
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("itype", [torch.bfloat16, torch.float])
@pytest.mark.parametrize("silu_activation", [True])
@pytest.mark.parametrize("has_bias", [True])
@pytest.mark.parametrize("has_initial_state", [True, False])
@pytest.mark.parametrize("width", [4])
@pytest.mark.parametrize(
    'seqlen', [1, 8, 16, 32, 64, 128, 256, 512, 784, 1024, 1025, 2048, 4096])
@pytest.mark.parametrize('dim', [64])
@pytest.mark.parametrize('batch', [1])
def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation,
                       has_initial_state, itype):
    """Compare causal_conv1d_fn against causal_conv1d_ref on a dense batch.

    Checks the output, and (when an initial state is supplied) that the state
    buffer passed as ``conv_states`` matches the reference final states.
    Finishes by running the opcheck helper on the same arguments.
    """
    device = "cuda"
    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
    if itype == torch.bfloat16:
        # bfloat16 gets looser tolerances than the generic low-precision pair.
        rtol, atol = 1e-2, 5e-2
    # set seed
    current_platform.seed_everything(0)

    # Tensor creation order is kept fixed so the seeded values are stable.
    x = torch.randn(batch, dim, seqlen, device=device,
                    dtype=itype).contiguous()
    weight = torch.randn(dim, width, device=device, dtype=itype)
    bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
    if has_initial_state:
        initial_states = torch.randn(batch,
                                     dim,
                                     width - 1,
                                     device=device,
                                     dtype=itype)
        has_initial_state_tensor = torch.ones(batch,
                                              dtype=torch.bool,
                                              device=x.device)
    else:
        initial_states = None
        has_initial_state_tensor = None

    # Independent copies for the reference path.
    x_ref = x.clone()
    weight_ref = weight.clone()
    bias_ref = bias.clone() if bias is not None else None
    initial_states_ref = initial_states.clone(
    ) if initial_states is not None else None
    activation = None if not silu_activation else "silu"

    out = causal_conv1d_fn(x,
                           weight,
                           bias,
                           activation=activation,
                           conv_states=initial_states,
                           has_initial_state=has_initial_state_tensor)
    out_ref, final_states_ref = causal_conv1d_ref(
        x_ref,
        weight_ref,
        bias_ref,
        initial_states=initial_states_ref,
        return_final_states=True,
        activation=activation)

    if has_initial_state:
        assert initial_states is not None and final_states_ref is not None
        # The buffer handed in as conv_states is compared to the reference
        # final states, i.e. it is expected to have been updated in place.
        assert torch.allclose(initial_states,
                              final_states_ref,
                              rtol=rtol,
                              atol=atol)
    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)

    causal_conv1d_opcheck_fn(x,
                             weight,
                             bias,
                             activation=activation,
                             conv_states=initial_states,
                             has_initial_state=has_initial_state_tensor)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("itype", [torch.bfloat16])
|
@pytest.mark.parametrize("itype", [torch.bfloat16])
|
||||||
@pytest.mark.parametrize("silu_activation", [False, True])
|
@pytest.mark.parametrize("silu_activation", [False, True])
|
||||||
@ -255,22 +181,19 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation,
|
|||||||
assert torch.equal(conv_state, conv_state_ref)
|
assert torch.equal(conv_state, conv_state_ref)
|
||||||
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
||||||
|
|
||||||
opcheck(torch.ops._C.causal_conv1d_update,
|
|
||||||
(x, conv_state, weight, bias, activation
|
|
||||||
in ["silu", "swish"], None, None, PAD_SLOT_ID))
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("itype",
|
@pytest.mark.parametrize("itype",
|
||||||
[torch.float32, torch.float16, torch.bfloat16])
|
[torch.float32, torch.float16, torch.bfloat16])
|
||||||
@pytest.mark.parametrize("silu_activation", [False, True])
|
@pytest.mark.parametrize("silu_activation", [False, True])
|
||||||
@pytest.mark.parametrize("has_bias", [False, True])
|
@pytest.mark.parametrize("has_bias", [False, True])
|
||||||
@pytest.mark.parametrize("seqlen", [1, 4, 5])
|
@pytest.mark.parametrize("seqlen", [1, 3])
|
||||||
@pytest.mark.parametrize("width", [2, 3, 4])
|
@pytest.mark.parametrize("width", [3, 4])
|
||||||
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
|
@pytest.mark.parametrize("dim", [2048 + 16, 4096])
|
||||||
# tests correctness in case subset of the sequences are padded
|
# tests correctness in case subset of the sequences are padded
|
||||||
@pytest.mark.parametrize("with_padding", [True, False])
|
@pytest.mark.parametrize("with_padding", [True, False])
|
||||||
def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width,
|
@pytest.mark.parametrize("batch_size", [3])
|
||||||
seqlen, has_bias,
|
def test_causal_conv1d_update_with_batch_gather(batch_size, with_padding, dim,
|
||||||
|
width, seqlen, has_bias,
|
||||||
silu_activation, itype):
|
silu_activation, itype):
|
||||||
device = "cuda"
|
device = "cuda"
|
||||||
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
|
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
|
||||||
@ -280,12 +203,15 @@ def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width,
|
|||||||
# set seed
|
# set seed
|
||||||
current_platform.seed_everything(0)
|
current_platform.seed_everything(0)
|
||||||
|
|
||||||
batch_size = 3
|
|
||||||
padding = 5 if with_padding else 0
|
padding = 5 if with_padding else 0
|
||||||
padded_batch_size = batch_size + padding
|
padded_batch_size = batch_size + padding
|
||||||
|
# total_entries = number of cache line
|
||||||
total_entries = 10 * batch_size
|
total_entries = 10 * batch_size
|
||||||
|
|
||||||
x = torch.randn(padded_batch_size, dim, 1, device=device, dtype=itype)
|
# x will be (batch, dim, seqlen) with contiguous along dim-axis
|
||||||
|
x = torch.randn(padded_batch_size, seqlen, dim, device=device,
|
||||||
|
dtype=itype).transpose(1, 2)
|
||||||
|
|
||||||
x_ref = x.clone()
|
x_ref = x.clone()
|
||||||
|
|
||||||
conv_state_indices = torch.randperm(total_entries)[:batch_size].to(
|
conv_state_indices = torch.randperm(total_entries)[:batch_size].to(
|
||||||
@ -300,17 +226,22 @@ def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width,
|
|||||||
[PAD_SLOT_ID] * padding, dtype=torch.int32, device=device)
|
[PAD_SLOT_ID] * padding, dtype=torch.int32, device=device)
|
||||||
],
|
],
|
||||||
dim=0)
|
dim=0)
|
||||||
|
|
||||||
|
# conv_state will be (cache_lines, dim, state_len)
|
||||||
|
# with contiguous along dim-axis
|
||||||
conv_state = torch.randn(total_entries,
|
conv_state = torch.randn(total_entries,
|
||||||
dim,
|
|
||||||
width - 1,
|
width - 1,
|
||||||
|
dim,
|
||||||
device=device,
|
device=device,
|
||||||
dtype=itype)
|
dtype=itype).transpose(1, 2)
|
||||||
|
|
||||||
conv_state_for_padding_test = conv_state.clone()
|
conv_state_for_padding_test = conv_state.clone()
|
||||||
|
|
||||||
weight = torch.randn(dim, width, device=device, dtype=itype)
|
weight = torch.randn(dim, width, device=device, dtype=itype)
|
||||||
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
|
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
|
||||||
conv_state_ref = conv_state[conv_state_indices, :].detach().clone()
|
conv_state_ref = conv_state[conv_state_indices, :].detach().clone()
|
||||||
activation = None if not silu_activation else "silu"
|
activation = None if not silu_activation else "silu"
|
||||||
|
|
||||||
out = causal_conv1d_update(x,
|
out = causal_conv1d_update(x,
|
||||||
conv_state,
|
conv_state,
|
||||||
weight,
|
weight,
|
||||||
@ -325,26 +256,21 @@ def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width,
|
|||||||
activation=activation)
|
activation=activation)
|
||||||
|
|
||||||
assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref)
|
assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref)
|
||||||
assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol)
|
|
||||||
assert torch.equal(conv_state[unused_states_bool],
|
assert torch.equal(conv_state[unused_states_bool],
|
||||||
conv_state_for_padding_test[unused_states_bool])
|
conv_state_for_padding_test[unused_states_bool])
|
||||||
|
assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol)
|
||||||
opcheck(torch.ops._C.causal_conv1d_update,
|
|
||||||
(x, conv_state, weight, bias, activation
|
|
||||||
in ["silu", "swish"], None, padded_state_indices, PAD_SLOT_ID))
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("itype", [torch.bfloat16])
|
@pytest.mark.parametrize("itype", [torch.bfloat16])
|
||||||
@pytest.mark.parametrize("silu_activation", [True])
|
@pytest.mark.parametrize("silu_activation", [True])
|
||||||
@pytest.mark.parametrize("has_bias", [True])
|
@pytest.mark.parametrize("has_bias", [True])
|
||||||
@pytest.mark.parametrize("width", [4])
|
@pytest.mark.parametrize("width", [4])
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize('seqlen', [8, 30, 249, 2049, 4096])
|
||||||
'seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 2049, 4096])
|
|
||||||
@pytest.mark.parametrize('dim', [64, 4096])
|
@pytest.mark.parametrize('dim', [64, 4096])
|
||||||
# tests correctness in case subset of the sequences are padded
|
|
||||||
@pytest.mark.parametrize('with_padding', [True, False])
|
@pytest.mark.parametrize('with_padding', [True, False])
|
||||||
def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,
|
@pytest.mark.parametrize('batch', [4, 10])
|
||||||
silu_activation, itype):
|
def test_causal_conv1d_varlen(batch, with_padding, dim, seqlen, width,
|
||||||
|
has_bias, silu_activation, itype):
|
||||||
device = "cuda"
|
device = "cuda"
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
|
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
|
||||||
@ -353,14 +279,13 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,
|
|||||||
# set seed
|
# set seed
|
||||||
current_platform.seed_everything(0)
|
current_platform.seed_everything(0)
|
||||||
seqlens = []
|
seqlens = []
|
||||||
batch_size = 4
|
batch_size = batch
|
||||||
if seqlen < 10:
|
|
||||||
batch_size = 1
|
|
||||||
padding = 3 if with_padding else 0
|
padding = 3 if with_padding else 0
|
||||||
padded_batch_size = batch_size + padding
|
padded_batch_size = batch_size + padding
|
||||||
nsplits = padded_batch_size - 1
|
nsplits = padded_batch_size - 1
|
||||||
|
|
||||||
eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values
|
eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values
|
||||||
|
|
||||||
seqlens.append(
|
seqlens.append(
|
||||||
torch.diff(
|
torch.diff(
|
||||||
torch.cat(
|
torch.cat(
|
||||||
@ -373,19 +298,22 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,
|
|||||||
cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32)
|
cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32)
|
||||||
cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum],
|
cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum],
|
||||||
dim=0)
|
dim=0)
|
||||||
x = torch.randn(1, 4096 + dim + 64, seqlen, device=device,
|
x = rearrange(
|
||||||
dtype=itype)[:, 4096:4096 + dim, :]
|
torch.randn(1, seqlen, 4096 + dim + 64, device=device, dtype=itype),
|
||||||
|
"b s d -> b d s")[:, 4096:4096 + dim, :]
|
||||||
|
|
||||||
weight = torch.randn(dim, width, device=device, dtype=itype)
|
weight = torch.randn(dim, width, device=device, dtype=itype)
|
||||||
|
|
||||||
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
|
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
|
||||||
x_ref = x.clone()
|
x_ref = x.clone()
|
||||||
weight_ref = weight.clone()
|
weight_ref = weight.clone()
|
||||||
bias_ref = bias.clone() if bias is not None else None
|
bias_ref = bias.clone() if bias is not None else None
|
||||||
activation = None if not silu_activation else "silu"
|
activation = None if not silu_activation else "silu"
|
||||||
final_states = torch.randn(total_entries,
|
final_states = torch.randn(total_entries,
|
||||||
dim,
|
|
||||||
width - 1,
|
width - 1,
|
||||||
|
dim,
|
||||||
device=x.device,
|
device=x.device,
|
||||||
dtype=x.dtype)
|
dtype=x.dtype).transpose(1, 2)
|
||||||
final_states_ref = final_states.clone()
|
final_states_ref = final_states.clone()
|
||||||
has_initial_states = torch.randint(0,
|
has_initial_states = torch.randint(0,
|
||||||
2, (cumsum.shape[0] - 1, ),
|
2, (cumsum.shape[0] - 1, ),
|
||||||
@ -400,10 +328,16 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,
|
|||||||
[PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
|
[PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
|
||||||
],
|
],
|
||||||
dim=-1)
|
dim=-1)
|
||||||
|
out = causal_conv1d_fn(x.squeeze(0),
|
||||||
|
weight,
|
||||||
|
bias=bias,
|
||||||
|
conv_states=final_states,
|
||||||
|
query_start_loc=cumsum.cuda(),
|
||||||
|
cache_indices=padded_state_indices,
|
||||||
|
has_initial_state=has_initial_states,
|
||||||
|
activation=activation,
|
||||||
|
pad_slot_id=PAD_SLOT_ID)
|
||||||
|
|
||||||
out = causal_conv1d_fn(x.squeeze(0), weight, bias, cumsum.cuda(),
|
|
||||||
padded_state_indices, has_initial_states,
|
|
||||||
final_states, activation, PAD_SLOT_ID)
|
|
||||||
out_ref = []
|
out_ref = []
|
||||||
out_ref_b = []
|
out_ref_b = []
|
||||||
|
|
||||||
@ -426,13 +360,9 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,
|
|||||||
out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=2))
|
out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=2))
|
||||||
out_ref_tensor = torch.cat(out_ref, dim=0)
|
out_ref_tensor = torch.cat(out_ref, dim=0)
|
||||||
|
|
||||||
unpadded_out = out[:, :out_ref_tensor.shape[-1]]
|
|
||||||
assert torch.allclose(unpadded_out, out_ref_tensor, rtol=rtol, atol=atol)
|
|
||||||
assert torch.allclose(final_states[state_indices],
|
assert torch.allclose(final_states[state_indices],
|
||||||
final_states_ref[state_indices],
|
final_states_ref[state_indices],
|
||||||
rtol=rtol,
|
rtol=rtol,
|
||||||
atol=atol)
|
atol=atol)
|
||||||
|
unpadded_out = out[:, :out_ref_tensor.shape[-1]]
|
||||||
causal_conv1d_opcheck_fn(x.squeeze(0), weight, bias, cumsum.cuda(),
|
assert torch.allclose(unpadded_out, out_ref_tensor, rtol=rtol, atol=atol)
|
||||||
padded_state_indices, has_initial_states,
|
|
||||||
final_states, activation)
|
|
||||||
|
|||||||
@ -6,11 +6,11 @@ import torch
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
|
|
||||||
from vllm.model_executor.layers.mamba.mamba2_metadata import (
|
|
||||||
_query_start_loc_to_chunk_indices_offsets)
|
|
||||||
from vllm.model_executor.layers.mamba.ops.ssd_combined import (
|
from vllm.model_executor.layers.mamba.ops.ssd_combined import (
|
||||||
mamba_chunk_scan_combined)
|
mamba_chunk_scan_combined)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.v1.attention.backends.mamba_attn import (
|
||||||
|
_query_start_loc_to_chunk_indices_offsets)
|
||||||
|
|
||||||
# Added by the IBM Team, 2024
|
# Added by the IBM Team, 2024
|
||||||
|
|
||||||
|
|||||||
@ -1464,30 +1464,6 @@ def ggml_moe_get_block_size(quant_type: int) -> int:
|
|||||||
|
|
||||||
|
|
||||||
# mamba
|
# mamba
|
||||||
def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor,
                      bias_: Optional[torch.Tensor],
                      conv_states: Optional[torch.Tensor],
                      query_start_loc: Optional[torch.Tensor],
                      cache_indices: Optional[torch.Tensor],
                      has_initial_state: Optional[torch.Tensor],
                      silu_activation: bool, pad_slot_id: int):
    """Thin wrapper around the registered ``_C.causal_conv1d_fwd`` op.

    Every argument is forwarded unchanged, in the same order. Nothing is
    returned; the op is registered with a ``-> ()`` schema.
    """
    torch.ops._C.causal_conv1d_fwd(x, weight, bias_, conv_states,
                                   query_start_loc, cache_indices,
                                   has_initial_state, silu_activation,
                                   pad_slot_id)
|
|
||||||
|
|
||||||
|
|
||||||
def causal_conv1d_update(x: torch.Tensor, conv_state: torch.Tensor,
                         weight: torch.Tensor, bias_: Optional[torch.Tensor],
                         silu_activation: bool,
                         cache_seqlens: Optional[torch.Tensor],
                         conv_state_indices: Optional[torch.Tensor],
                         pad_slot_id: int):
    """Thin wrapper around the registered ``_C.causal_conv1d_update`` op.

    Every argument is forwarded unchanged, in the same order. Nothing is
    returned; the op is registered with a ``-> ()`` schema.
    """
    torch.ops._C.causal_conv1d_update(x, conv_state, weight, bias_,
                                      silu_activation, cache_seqlens,
                                      conv_state_indices, pad_slot_id)
|
|
||||||
|
|
||||||
|
|
||||||
def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
|
def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
|
||||||
B: torch.Tensor, C: torch.Tensor,
|
B: torch.Tensor, C: torch.Tensor,
|
||||||
D_: Optional[torch.Tensor], z_: Optional[torch.Tensor],
|
D_: Optional[torch.Tensor], z_: Optional[torch.Tensor],
|
||||||
|
|||||||
@ -1,14 +1,18 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
import math
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.attention.backends.abstract import AttentionMetadata
|
from vllm.attention.backends.abstract import AttentionMetadata
|
||||||
from vllm.attention.backends.placeholder_attn import (
|
from vllm.attention.backends.placeholder_attn import (
|
||||||
PlaceholderAttentionMetadata)
|
PlaceholderAttentionMetadata)
|
||||||
|
from vllm.attention.backends.utils import PAD_SLOT_ID
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.v1.attention.backends.mamba_attn import (
|
||||||
|
Mamba2AttentionMetadata, _query_start_loc_to_chunk_indices_offsets)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -21,6 +25,29 @@ class Mamba2Metadata:
|
|||||||
seq_idx: torch.Tensor
|
seq_idx: torch.Tensor
|
||||||
chunk_indices: torch.Tensor
|
chunk_indices: torch.Tensor
|
||||||
chunk_offsets: torch.Tensor
|
chunk_offsets: torch.Tensor
|
||||||
|
"""
|
||||||
|
With continuous batching layout of `x` in vLLM, to enable a Triton program
|
||||||
|
to handle a request in parallel, two supporting tensors are used
|
||||||
|
(batch_ptr, token_chunk_offset_ptr)
|
||||||
|
BLOCK_M = the # tokens to be handled by a Triton program
|
||||||
|
(can be customized for different hardware)
|
||||||
|
|
||||||
|
nums_dict:
|
||||||
|
tracks the data associated with a given value of BLOCK_M
|
||||||
|
BLOCK_M = #tokens handled by a Triton program
|
||||||
|
cu_seqlen: total tokens per batch
|
||||||
|
(used as flag to update other data at each new input)
|
||||||
|
batch_ptr: tracks batch-id handled by the Triton program
|
||||||
|
token_chunk_offset_ptr: tracks token group_idx handled by the Triton program
|
||||||
|
(Triton implementation of causal_conv1d handles parallelism in 3-axes
|
||||||
|
- feature-axis
|
||||||
|
- batch-axis
|
||||||
|
- sequence-axis)
|
||||||
|
"""
|
||||||
|
nums_dict: Optional[dict] = None
|
||||||
|
cu_seqlen: Optional[int] = None
|
||||||
|
batch_ptr: Optional[torch.tensor] = None
|
||||||
|
token_chunk_offset_ptr: Optional[torch.tensor] = None
|
||||||
|
|
||||||
|
|
||||||
def get_platform_metadata_classes() -> tuple[type[AttentionMetadata], ...]:
|
def get_platform_metadata_classes() -> tuple[type[AttentionMetadata], ...]:
|
||||||
@ -38,45 +65,10 @@ def get_platform_metadata_classes() -> tuple[type[AttentionMetadata], ...]:
|
|||||||
f"Unsupported platform for Mamba2: {current_platform.device_type}")
|
f"Unsupported platform for Mamba2: {current_platform.device_type}")
|
||||||
|
|
||||||
|
|
||||||
def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor,
|
|
||||||
chunk_size: int,
|
|
||||||
total_seqlens: int):
|
|
||||||
|
|
||||||
cu_seqlens = query_start_loc[1:] # remove prepended 0
|
|
||||||
|
|
||||||
# outputs will have length expansion of chunks that do not divide
|
|
||||||
# chunk_size
|
|
||||||
N = math.ceil(total_seqlens / chunk_size) + (cu_seqlens[:-1] % chunk_size
|
|
||||||
> 0).sum()
|
|
||||||
chunk_indices = torch.arange(N,
|
|
||||||
dtype=torch.int,
|
|
||||||
device=query_start_loc.device)
|
|
||||||
chunk_offsets = torch.zeros((N, ),
|
|
||||||
dtype=torch.int,
|
|
||||||
device=query_start_loc.device)
|
|
||||||
|
|
||||||
p = 0 # num of insertions
|
|
||||||
for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]):
|
|
||||||
|
|
||||||
# if does not divide chunk_size, then there is one chunk insertion
|
|
||||||
p += (s % chunk_size > 0)
|
|
||||||
|
|
||||||
# get the dimensions
|
|
||||||
# - the + 1 for _e is to shift the boundary by one chunk
|
|
||||||
# - this shifting is not needed if chunk_size divides e
|
|
||||||
_s, _e = s // chunk_size + p, e // chunk_size + p + (e % chunk_size
|
|
||||||
> 0)
|
|
||||||
|
|
||||||
# adjust indices and offsets
|
|
||||||
chunk_indices[_s:_e] -= p
|
|
||||||
chunk_offsets[_s] = s % chunk_size
|
|
||||||
|
|
||||||
return chunk_indices, chunk_offsets
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_mamba2_metadata(
|
def prepare_mamba2_metadata(
|
||||||
chunk_size: int,
|
chunk_size: int,
|
||||||
attn_metadata: AttentionMetadata,
|
attn_metadata: AttentionMetadata,
|
||||||
|
mamba2_metadata=None,
|
||||||
) -> Mamba2Metadata:
|
) -> Mamba2Metadata:
|
||||||
|
|
||||||
# compute number of prefill and decode requests
|
# compute number of prefill and decode requests
|
||||||
@ -96,12 +88,12 @@ def prepare_mamba2_metadata(
|
|||||||
attn_metadata_instances = get_platform_metadata_classes()
|
attn_metadata_instances = get_platform_metadata_classes()
|
||||||
if (isinstance(attn_metadata, attn_metadata_instances)
|
if (isinstance(attn_metadata, attn_metadata_instances)
|
||||||
and attn_metadata.context_lens_tensor is not None):
|
and attn_metadata.context_lens_tensor is not None):
|
||||||
has_initial_states = \
|
# precompute flag to avoid device syncs later in mamba2 layer
|
||||||
attn_metadata.context_lens_tensor[:num_prefills] > 0 #[batch,]
|
# forwards
|
||||||
# precompute flag to avoid device syncs in mamba2 layer forwards
|
|
||||||
# prep is only needed for mamba2 ssd prefill processing
|
# prep is only needed for mamba2 ssd prefill processing
|
||||||
prep_initial_states = torch.any(has_initial_states).item()
|
has_initial_states = attn_metadata.context_lens_tensor > 0
|
||||||
|
prep_initial_states = torch.any(
|
||||||
|
has_initial_states[:num_prefills]).item()
|
||||||
query_start_loc = attn_metadata.query_start_loc[:num_prefills + 1]
|
query_start_loc = attn_metadata.query_start_loc[:num_prefills + 1]
|
||||||
seq_idx = torch.repeat_interleave(torch.arange(
|
seq_idx = torch.repeat_interleave(torch.arange(
|
||||||
num_prefills, dtype=torch.int32, device=query_start_loc.device),
|
num_prefills, dtype=torch.int32, device=query_start_loc.device),
|
||||||
@ -117,9 +109,78 @@ def prepare_mamba2_metadata(
|
|||||||
_query_start_loc_to_chunk_indices_offsets(
|
_query_start_loc_to_chunk_indices_offsets(
|
||||||
query_start_loc, chunk_size, num_prefill_tokens)
|
query_start_loc, chunk_size, num_prefill_tokens)
|
||||||
|
|
||||||
|
if mamba2_metadata is not None:
|
||||||
|
mamba2_metadata.has_initial_states = has_initial_states
|
||||||
|
mamba2_metadata.prep_initial_states = prep_initial_states
|
||||||
|
mamba2_metadata.chunk_size = chunk_size
|
||||||
|
mamba2_metadata.seq_idx = seq_idx
|
||||||
|
mamba2_metadata.chunk_indices = chunk_indices
|
||||||
|
mamba2_metadata.chunk_offsets = chunk_offsets
|
||||||
|
# We use 1 reset flag:
|
||||||
|
# * mamba2_metadata.cu_seqlen is None
|
||||||
|
# update config specific to (each input)
|
||||||
|
# (become available at first layer, e.g. conv_weights)
|
||||||
|
mamba2_metadata.cu_seqlen = None # suppose to be updated at each input
|
||||||
|
|
||||||
|
return mamba2_metadata
|
||||||
return Mamba2Metadata(has_initial_states=has_initial_states,
|
return Mamba2Metadata(has_initial_states=has_initial_states,
|
||||||
prep_initial_states=prep_initial_states,
|
prep_initial_states=prep_initial_states,
|
||||||
chunk_size=chunk_size,
|
chunk_size=chunk_size,
|
||||||
seq_idx=seq_idx,
|
seq_idx=seq_idx,
|
||||||
chunk_indices=chunk_indices,
|
chunk_indices=chunk_indices,
|
||||||
chunk_offsets=chunk_offsets)
|
chunk_offsets=chunk_offsets)
|
||||||
|
|
||||||
|
|
||||||
|
def update_metadata(x: torch.Tensor, query_start_loc: torch.Tensor,
|
||||||
|
mamba2_metadata: Union[Mamba2Metadata,
|
||||||
|
Mamba2AttentionMetadata]):
|
||||||
|
"""
|
||||||
|
this is triggered upon handling a new input at the first layer
|
||||||
|
"""
|
||||||
|
dim, cu_seqlen = x.shape
|
||||||
|
mamba2_metadata.cu_seqlen = cu_seqlen
|
||||||
|
seqlens = np.diff(query_start_loc.to('cpu'))
|
||||||
|
nums_dict = {} # type: ignore
|
||||||
|
for BLOCK_M in [8]: # cover all BLOCK_M values
|
||||||
|
nums = -(-seqlens // BLOCK_M)
|
||||||
|
nums_dict[BLOCK_M] = {}
|
||||||
|
nums_dict[BLOCK_M]['nums'] = nums
|
||||||
|
nums_dict[BLOCK_M]['tot'] = nums.sum().item()
|
||||||
|
mlist = torch.from_numpy(np.repeat(np.arange(len(nums)), nums))
|
||||||
|
nums_dict[BLOCK_M]['mlist'] = mlist
|
||||||
|
mlist_len = len(nums_dict[BLOCK_M]['mlist'])
|
||||||
|
nums_dict[BLOCK_M]['mlist_len'] = mlist_len
|
||||||
|
MAX_NUM_PROGRAMS = max(1024, mlist_len) * 2
|
||||||
|
offsetlist = [] # type: ignore
|
||||||
|
for idx, num in enumerate(nums):
|
||||||
|
offsetlist.extend(range(num))
|
||||||
|
offsetlist = torch.tensor(offsetlist, dtype=torch.int32)
|
||||||
|
nums_dict[BLOCK_M]['offsetlist'] = offsetlist
|
||||||
|
|
||||||
|
if mamba2_metadata.batch_ptr is None:
|
||||||
|
# Update default value after class definition
|
||||||
|
#mamba2_metadata.MAX_NUM_PROGRAMS *= 2
|
||||||
|
mamba2_metadata.batch_ptr = torch.full((MAX_NUM_PROGRAMS, ),
|
||||||
|
PAD_SLOT_ID,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device='cuda')
|
||||||
|
mamba2_metadata.token_chunk_offset_ptr = torch.full(
|
||||||
|
(MAX_NUM_PROGRAMS, ),
|
||||||
|
PAD_SLOT_ID,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device='cuda')
|
||||||
|
else:
|
||||||
|
if mamba2_metadata.batch_ptr.nelement() < MAX_NUM_PROGRAMS:
|
||||||
|
mamba2_metadata.batch_ptr.resize_(MAX_NUM_PROGRAMS).fill_(
|
||||||
|
PAD_SLOT_ID)
|
||||||
|
mamba2_metadata.token_chunk_offset_ptr.resize_( # type: ignore
|
||||||
|
MAX_NUM_PROGRAMS).fill_(PAD_SLOT_ID)
|
||||||
|
|
||||||
|
mamba2_metadata.batch_ptr[0:mlist_len].copy_(mlist)
|
||||||
|
mamba2_metadata.token_chunk_offset_ptr[ # type: ignore
|
||||||
|
0:mlist_len].copy_(offsetlist)
|
||||||
|
nums_dict[BLOCK_M]['batch_ptr'] = mamba2_metadata.batch_ptr
|
||||||
|
nums_dict[BLOCK_M]['token_chunk_offset_ptr'] = (
|
||||||
|
mamba2_metadata.token_chunk_offset_ptr) # type: ignore
|
||||||
|
mamba2_metadata.nums_dict = nums_dict
|
||||||
|
return mamba2_metadata
|
||||||
|
|||||||
@ -159,7 +159,7 @@ class MambaMixer(CustomOp):
|
|||||||
hidden_states = causal_conv1d_fn(
|
hidden_states = causal_conv1d_fn(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
conv_weights,
|
conv_weights,
|
||||||
self.conv1d.bias,
|
bias=self.conv1d.bias,
|
||||||
activation=self.activation,
|
activation=self.activation,
|
||||||
conv_states=mamba_cache_params.conv_state,
|
conv_states=mamba_cache_params.conv_state,
|
||||||
has_initial_state=attn_metadata.context_lens_tensor > 0,
|
has_initial_state=attn_metadata.context_lens_tensor > 0,
|
||||||
|
|||||||
@ -17,7 +17,8 @@ from vllm.forward_context import get_forward_context
|
|||||||
from vllm.model_executor.custom_op import CustomOp
|
from vllm.model_executor.custom_op import CustomOp
|
||||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||||
RowParallelLinear)
|
RowParallelLinear)
|
||||||
from vllm.model_executor.layers.mamba.mamba2_metadata import Mamba2Metadata
|
from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata,
|
||||||
|
update_metadata)
|
||||||
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
|
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
|
||||||
causal_conv1d_fn, causal_conv1d_update)
|
causal_conv1d_fn, causal_conv1d_update)
|
||||||
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
|
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
|
||||||
@ -458,9 +459,11 @@ class MambaMixer2(CustomOp):
|
|||||||
if attn_metadata is not None:
|
if attn_metadata is not None:
|
||||||
assert isinstance(attn_metadata, dict)
|
assert isinstance(attn_metadata, dict)
|
||||||
attn_metadata = attn_metadata[self.prefix]
|
attn_metadata = attn_metadata[self.prefix]
|
||||||
|
mamba2_metadata = attn_metadata
|
||||||
assert isinstance(attn_metadata, Mamba2AttentionMetadata)
|
assert isinstance(attn_metadata, Mamba2AttentionMetadata)
|
||||||
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
|
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||||
conv_state = self_kv_cache[0]
|
# conv_state = (..., dim, width-1) yet contiguous along 'dim'
|
||||||
|
conv_state = self_kv_cache[0].transpose(-1, -2)
|
||||||
ssm_state = self_kv_cache[1]
|
ssm_state = self_kv_cache[1]
|
||||||
state_indices_tensor = attn_metadata.state_indices_tensor
|
state_indices_tensor = attn_metadata.state_indices_tensor
|
||||||
has_initial_states_p = attn_metadata.has_initial_states
|
has_initial_states_p = attn_metadata.has_initial_states
|
||||||
@ -531,6 +534,7 @@ class MambaMixer2(CustomOp):
|
|||||||
# NOTE: V0 put prefill before decode, v1 puts decode before prefill
|
# NOTE: V0 put prefill before decode, v1 puts decode before prefill
|
||||||
# Separate prefill and decode by splitting varlen input
|
# Separate prefill and decode by splitting varlen input
|
||||||
# Split along token dimension
|
# Split along token dimension
|
||||||
|
# NOTE: V0 put prefill before decode, v1 puts decode before prefill
|
||||||
if envs.VLLM_USE_V1:
|
if envs.VLLM_USE_V1:
|
||||||
hidden_states_B_C_d, hidden_states_B_C_p = torch.split(
|
hidden_states_B_C_d, hidden_states_B_C_p = torch.split(
|
||||||
hidden_states_B_C,
|
hidden_states_B_C,
|
||||||
@ -579,8 +583,13 @@ class MambaMixer2(CustomOp):
|
|||||||
# 2. Convolution sequence transformation
|
# 2. Convolution sequence transformation
|
||||||
# - "cache_indices" updates the conv_state cache in positions
|
# - "cache_indices" updates the conv_state cache in positions
|
||||||
# pointed to by "state_indices_tensor"
|
# pointed to by "state_indices_tensor"
|
||||||
|
x = hidden_states_B_C_p.transpose(
|
||||||
|
0, 1) # this is the form that causal-conv see
|
||||||
|
if mamba2_metadata.cu_seqlen is None:
|
||||||
|
mamba2_metadata = update_metadata(
|
||||||
|
x, attn_metadata.query_start_loc, mamba2_metadata)
|
||||||
hidden_states_B_C_p = causal_conv1d_fn(
|
hidden_states_B_C_p = causal_conv1d_fn(
|
||||||
hidden_states_B_C_p.transpose(0, 1),
|
x,
|
||||||
conv_weights,
|
conv_weights,
|
||||||
self.conv1d.bias,
|
self.conv1d.bias,
|
||||||
activation=self.activation,
|
activation=self.activation,
|
||||||
@ -590,8 +599,6 @@ class MambaMixer2(CustomOp):
|
|||||||
query_start_loc=query_start_loc_p).transpose(
|
query_start_loc=query_start_loc_p).transpose(
|
||||||
0, 1)[:num_prefill_tokens]
|
0, 1)[:num_prefill_tokens]
|
||||||
|
|
||||||
# TODO: Why is this needed?
|
|
||||||
hidden_states_B_C_p = hidden_states_B_C_p.contiguous()
|
|
||||||
hidden_states_p, B_p, C_p = split_hidden_states_B_C_fn(
|
hidden_states_p, B_p, C_p = split_hidden_states_B_C_fn(
|
||||||
hidden_states_B_C_p)
|
hidden_states_B_C_p)
|
||||||
|
|
||||||
@ -715,9 +722,10 @@ class MambaMixer2(CustomOp):
|
|||||||
# - heads and n_groups are TP-ed
|
# - heads and n_groups are TP-ed
|
||||||
conv_dim = (self.intermediate_size +
|
conv_dim = (self.intermediate_size +
|
||||||
2 * n_groups * self.ssm_state_size)
|
2 * n_groups * self.ssm_state_size)
|
||||||
|
# contiguous along 'dim' axis
|
||||||
conv_state_shape = (
|
conv_state_shape = (
|
||||||
divide(conv_dim, world_size),
|
|
||||||
self.conv_kernel_size - 1,
|
self.conv_kernel_size - 1,
|
||||||
|
divide(conv_dim, world_size),
|
||||||
)
|
)
|
||||||
|
|
||||||
# These are not TP-ed as they depend on A, dt_bias, D
|
# These are not TP-ed as they depend on A, dt_bias, D
|
||||||
|
|||||||
@ -4,31 +4,394 @@
|
|||||||
# Copyright (c) 2024, Tri Dao.
|
# Copyright (c) 2024, Tri Dao.
|
||||||
# Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py
|
# Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
|
||||||
from vllm import _custom_ops as ops
|
|
||||||
from vllm.attention.backends.utils import PAD_SLOT_ID
|
from vllm.attention.backends.utils import PAD_SLOT_ID
|
||||||
|
|
||||||
|
|
||||||
def causal_conv1d_fn(x: torch.Tensor,
|
@triton.jit()
|
||||||
weight: torch.Tensor,
|
def _causal_conv1d_fwd_kernel( # continuous batching
|
||||||
bias: Optional[torch.Tensor] = None,
|
# Pointers to matrices
|
||||||
query_start_loc: Optional[torch.Tensor] = None,
|
x_ptr, # (dim, cu_seqlen) holding `batch` of actual sequences + padded sequences
|
||||||
cache_indices: Optional[torch.Tensor] = None,
|
w_ptr, # (dim, width)
|
||||||
has_initial_state: Optional[torch.Tensor] = None,
|
bias_ptr,
|
||||||
conv_states: Optional[torch.Tensor] = None,
|
initial_states_ptr, # conv_states_ptr
|
||||||
activation: Optional[str] = "silu",
|
cache_indices_ptr, # conv_state_indices_ptr
|
||||||
pad_slot_id: int = PAD_SLOT_ID):
|
has_initial_states_ptr,
|
||||||
"""
|
query_start_loc_ptr,
|
||||||
x: (batch, dim, seqlen) or (dim,cu_seq_len) for varlen
|
batch_ptr,
|
||||||
|
token_chunk_offset_ptr,
|
||||||
|
o_ptr, # (dim, seqlen) - actually pointing to x_ptr
|
||||||
|
# Matrix dimensions
|
||||||
|
batch: tl.int32, # actually padded_batch
|
||||||
|
dim: tl.constexpr,
|
||||||
|
seqlen: tl.int32, # cu_seqlen
|
||||||
|
num_cache_lines: tl.constexpr, # added to support vLLM larger cache lines
|
||||||
|
# Strides
|
||||||
|
stride_x_seq: tl.constexpr, # stride to get to next sequence,
|
||||||
|
stride_x_dim: tl.constexpr, # stride to get to next feature-value,
|
||||||
|
stride_x_token: tl.
|
||||||
|
constexpr, # stride to get to next token (same feature-index, same sequence-index)
|
||||||
|
stride_w_dim: tl.constexpr, # stride to get to next dim-axis value
|
||||||
|
stride_w_width: tl.constexpr, # stride to get to next width-axis value
|
||||||
|
stride_istate_seq: tl.constexpr,
|
||||||
|
stride_istate_dim: tl.constexpr,
|
||||||
|
stride_istate_token: tl.constexpr,
|
||||||
|
stride_o_seq: tl.constexpr,
|
||||||
|
stride_o_dim: tl.constexpr,
|
||||||
|
stride_o_token: tl.constexpr,
|
||||||
|
# others
|
||||||
|
pad_slot_id: tl.constexpr,
|
||||||
|
# Meta-parameters
|
||||||
|
HAS_BIAS: tl.constexpr,
|
||||||
|
KERNEL_WIDTH: tl.constexpr,
|
||||||
|
SILU_ACTIVATION: tl.constexpr,
|
||||||
|
HAS_INITIAL_STATES: tl.constexpr,
|
||||||
|
HAS_CACHE: tl.constexpr,
|
||||||
|
IS_CONTINUOUS_BATCHING: tl.constexpr,
|
||||||
|
USE_PAD_SLOT: tl.constexpr,
|
||||||
|
NP2_STATELEN: tl.constexpr,
|
||||||
|
DECODE_SEQLEN: tl.constexpr,
|
||||||
|
BLOCK_M: tl.constexpr,
|
||||||
|
BLOCK_N: tl.constexpr,
|
||||||
|
):
|
||||||
|
conv_states_ptr = initial_states_ptr
|
||||||
|
conv_state_indices_ptr = cache_indices_ptr
|
||||||
|
stride_conv_state_seq = stride_istate_seq
|
||||||
|
stride_conv_state_dim = stride_istate_dim
|
||||||
|
stride_conv_state_tok = stride_istate_token
|
||||||
|
state_len = KERNEL_WIDTH - 1 # can be passed via argument if it's not the same as this value
|
||||||
|
|
||||||
|
# one program handles one chunk in a single sequence
|
||||||
|
# rather than mixing sequences - to make updating initial_states across sequences efficient
|
||||||
|
|
||||||
|
# single-sequence id
|
||||||
|
idx_seq = tl.load(batch_ptr + tl.program_id(0))
|
||||||
|
chunk_offset = tl.load(token_chunk_offset_ptr + tl.program_id(0))
|
||||||
|
|
||||||
|
# BLOCK_N elements along the feature-dimension (channel)
|
||||||
|
idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||||
|
|
||||||
|
if idx_seq == pad_slot_id:
|
||||||
|
return
|
||||||
|
|
||||||
|
sequence_start_index = tl.load(query_start_loc_ptr + idx_seq)
|
||||||
|
sequence_end_index = tl.load(query_start_loc_ptr + idx_seq + 1)
|
||||||
|
# find the actual sequence length
|
||||||
|
seqlen = sequence_end_index - sequence_start_index
|
||||||
|
|
||||||
|
token_offset = BLOCK_M * chunk_offset
|
||||||
|
segment_len = min(BLOCK_M, seqlen - token_offset)
|
||||||
|
|
||||||
|
# base of the sequence
|
||||||
|
x_base = x_ptr + sequence_start_index * stride_x_token + idx_feats * stride_x_dim # [BLOCK_N,]
|
||||||
|
|
||||||
|
if IS_CONTINUOUS_BATCHING:
|
||||||
|
# cache_idx
|
||||||
|
conv_state_batch_coord = tl.load(conv_state_indices_ptr + idx_seq)
|
||||||
|
else:
|
||||||
|
# cache_idx
|
||||||
|
conv_state_batch_coord = idx_seq
|
||||||
|
if USE_PAD_SLOT: # noqa
|
||||||
|
if conv_state_batch_coord == pad_slot_id:
|
||||||
|
# not processing as this is not the actual sequence
|
||||||
|
return
|
||||||
|
conv_states_base = (conv_states_ptr +
|
||||||
|
(conv_state_batch_coord * stride_conv_state_seq) +
|
||||||
|
(idx_feats * stride_conv_state_dim)) # [BLOCK_N,]
|
||||||
|
|
||||||
|
w_base = w_ptr + (idx_feats * stride_w_dim) # [BLOCK_N,]
|
||||||
|
|
||||||
|
# Does 2 things:
|
||||||
|
# 1. READ prior-block init-state data - [done by every Triton program]
|
||||||
|
# 2. update conv_state with new data [only by the Triton program that handles chunk_offset=0]
|
||||||
|
if chunk_offset == 0:
|
||||||
|
# read from conv_states
|
||||||
|
load_init_state = False
|
||||||
|
if HAS_INITIAL_STATES: # the new HAS_INITIAL_STATES
|
||||||
|
load_init_state = tl.load(has_initial_states_ptr + idx_seq).to(
|
||||||
|
tl.int1)
|
||||||
|
if load_init_state:
|
||||||
|
# load from conv_states
|
||||||
|
prior_tokens = conv_states_base + (state_len -
|
||||||
|
1) * stride_conv_state_tok
|
||||||
|
mask_w = idx_feats < dim
|
||||||
|
if KERNEL_WIDTH == 2:
|
||||||
|
conv_states_ptrs = prior_tokens # [BLOCK_N]
|
||||||
|
col0 = tl.load(conv_states_ptrs, mask_w, 0.0)
|
||||||
|
if KERNEL_WIDTH == 3:
|
||||||
|
conv_states_ptrs = prior_tokens # [BLOCK_N]
|
||||||
|
col1 = tl.load(conv_states_ptrs, mask_w, 0.0)
|
||||||
|
conv_states_ptrs = prior_tokens - 1 * stride_conv_state_tok # [BLOCK_N]
|
||||||
|
col0 = tl.load(conv_states_ptrs, mask_w, 0.0)
|
||||||
|
if KERNEL_WIDTH == 4:
|
||||||
|
conv_states_ptrs = prior_tokens # [BLOCK_N]
|
||||||
|
col2 = tl.load(conv_states_ptrs, mask_w, 0.0)
|
||||||
|
conv_states_ptrs = prior_tokens - 1 * stride_conv_state_tok # [BLOCK_N]
|
||||||
|
col1 = tl.load(conv_states_ptrs, mask_w, 0.0)
|
||||||
|
conv_states_ptrs = prior_tokens - 2 * stride_conv_state_tok # [BLOCK_N]
|
||||||
|
col0 = tl.load(conv_states_ptrs, mask_w, 0.0)
|
||||||
|
if KERNEL_WIDTH == 5:
|
||||||
|
conv_states_ptrs = prior_tokens # [BLOCK_N]
|
||||||
|
col3 = tl.load(conv_states_ptrs, mask_w, 0.0)
|
||||||
|
conv_states_ptrs = prior_tokens - 1 * stride_conv_state_tok # [BLOCK_N]
|
||||||
|
col2 = tl.load(conv_states_ptrs, mask_w, 0.0)
|
||||||
|
conv_states_ptrs = prior_tokens - 2 * stride_conv_state_tok # [BLOCK_N]
|
||||||
|
col1 = tl.load(conv_states_ptrs, mask_w, 0.0)
|
||||||
|
conv_states_ptrs = prior_tokens - 3 * stride_conv_state_tok # [BLOCK_N]
|
||||||
|
col0 = tl.load(conv_states_ptrs, mask_w, 0.0)
|
||||||
|
else:
|
||||||
|
# prior-tokens are zeros
|
||||||
|
if KERNEL_WIDTH >= 2: # STRATEGY1
|
||||||
|
# first chunk and does not have prior-token, so just set to 0
|
||||||
|
col0 = tl.zeros((BLOCK_N, ), dtype=x_ptr.dtype.element_ty)
|
||||||
|
if KERNEL_WIDTH >= 3: # STRATEGY1
|
||||||
|
col1 = tl.zeros((BLOCK_N, ), dtype=x_ptr.dtype.element_ty)
|
||||||
|
if KERNEL_WIDTH >= 4: # STRATEGY1
|
||||||
|
col2 = tl.zeros((BLOCK_N, ), dtype=x_ptr.dtype.element_ty)
|
||||||
|
if KERNEL_WIDTH >= 5: # STRATEGY1
|
||||||
|
col3 = tl.zeros((BLOCK_N, ), dtype=x_ptr.dtype.element_ty)
|
||||||
|
|
||||||
|
# STEP 2:
|
||||||
|
# here prepare data for updating conv_state
|
||||||
|
if state_len <= seqlen: # SMALL_CACHE=True (only move part of 'x' into conv_state cache)
|
||||||
|
# just read from 'x'
|
||||||
|
# copy 'x' data to conv_state
|
||||||
|
# load only 'x' data (and set 0 before 'x' if seqlen < state_len)
|
||||||
|
idx_tokens_last = (seqlen - state_len) + tl.arange(
|
||||||
|
0, NP2_STATELEN) # [BLOCK_M]
|
||||||
|
x_ptrs = x_ptr + (
|
||||||
|
(sequence_start_index + idx_tokens_last) *
|
||||||
|
stride_x_token)[:, None] + (
|
||||||
|
idx_feats * stride_x_dim)[None, :] # [BLOCK_M,BLOCK_N,]
|
||||||
|
mask_x = ((idx_tokens_last >= 0)[:, None] &
|
||||||
|
(idx_tokens_last < seqlen)[:, None] &
|
||||||
|
(idx_feats < dim)[None, :]
|
||||||
|
) # token-index # token-index # feature-index
|
||||||
|
loaded_x = tl.load(x_ptrs, mask_x, 0.0)
|
||||||
|
new_conv_state = tl.load(x_ptrs, mask_x, 0.0)
|
||||||
|
idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M]
|
||||||
|
conv_states_ptrs_target = conv_states_base[None, :] + (
|
||||||
|
idx_tokens_conv * stride_conv_state_tok)[:, None]
|
||||||
|
|
||||||
|
mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats
|
||||||
|
< dim)[None, :]
|
||||||
|
tl.debug_barrier() # NOTE: use this due to bug in Triton compiler
|
||||||
|
tl.store(conv_states_ptrs_target, new_conv_state, mask)
|
||||||
|
|
||||||
|
else:
|
||||||
|
if load_init_state:
|
||||||
|
# update conv_state by shifting left, i.e. take last few cols from conv_state + cols from 'x'
|
||||||
|
idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M]
|
||||||
|
|
||||||
|
conv_states_ptrs_source = (
|
||||||
|
conv_states_ptr +
|
||||||
|
(conv_state_batch_coord * stride_conv_state_seq) +
|
||||||
|
(idx_feats * stride_conv_state_dim)[None, :] +
|
||||||
|
((idx_tokens_conv + seqlen) * stride_conv_state_tok)[:,
|
||||||
|
None]
|
||||||
|
) # [BLOCK_M, BLOCK_N]
|
||||||
|
mask = ((conv_state_batch_coord < num_cache_lines)
|
||||||
|
& ((idx_tokens_conv + seqlen) < state_len)[:, None]
|
||||||
|
& (idx_feats < dim)[None, :])
|
||||||
|
conv_state = tl.load(conv_states_ptrs_source, mask, other=0.0)
|
||||||
|
|
||||||
|
VAL = state_len - seqlen
|
||||||
|
|
||||||
|
x_ptrs = x_base[None, :] + (
|
||||||
|
(idx_tokens_conv - VAL) *
|
||||||
|
stride_x_token)[:, None] # [BLOCK_M, BLOCK_N]
|
||||||
|
|
||||||
|
mask_x = ((idx_tokens_conv - VAL >= 0)[:, None] &
|
||||||
|
(idx_tokens_conv - VAL < seqlen)[:, None] &
|
||||||
|
(idx_feats < dim)[None, :]
|
||||||
|
) # token-index # token-index # feature-index
|
||||||
|
loaded_x = tl.load(x_ptrs, mask_x, 0.0)
|
||||||
|
|
||||||
|
tl.debug_barrier(
|
||||||
|
) # need this due to the bug in tl.where not enforcing this when data is the result of another tl.load
|
||||||
|
new_conv_state = tl.where(
|
||||||
|
mask, conv_state, loaded_x
|
||||||
|
) # BUG in 'tl.where' which requires a barrier before this
|
||||||
|
conv_states_ptrs_target = conv_states_base + (
|
||||||
|
idx_tokens_conv *
|
||||||
|
stride_conv_state_tok)[:, None] # [BLOCK_M, BLOCK_N]
|
||||||
|
mask = (idx_tokens_conv
|
||||||
|
< state_len)[:, None] & (idx_feats < dim)[None, :]
|
||||||
|
tl.store(conv_states_ptrs_target, new_conv_state, mask)
|
||||||
|
else: # load_init_state == False
|
||||||
|
# update conv_state by shifting left, BUT
|
||||||
|
# set cols prior to 'x' as zeros + cols from 'x'
|
||||||
|
idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M]
|
||||||
|
|
||||||
|
VAL = state_len - seqlen
|
||||||
|
|
||||||
|
x_ptrs = x_base[None, :] + (
|
||||||
|
(idx_tokens_conv - VAL) *
|
||||||
|
stride_x_token)[:, None] # [BLOCK_M, BLOCK_N]
|
||||||
|
|
||||||
|
mask_x = ((idx_tokens_conv - VAL >= 0)[:, None] &
|
||||||
|
(idx_tokens_conv - VAL < seqlen)[:, None] &
|
||||||
|
(idx_feats < dim)[None, :]
|
||||||
|
) # token-index # token-index # feature-index
|
||||||
|
new_conv_state = tl.load(x_ptrs, mask_x, 0.0)
|
||||||
|
|
||||||
|
conv_states_ptrs_target = conv_states_base + (
|
||||||
|
idx_tokens_conv *
|
||||||
|
stride_conv_state_tok)[:, None] # [BLOCK_M, BLOCK_N]
|
||||||
|
mask = (idx_tokens_conv
|
||||||
|
< state_len)[:, None] & (idx_feats < dim)[None, :]
|
||||||
|
tl.store(conv_states_ptrs_target, new_conv_state, mask)
|
||||||
|
|
||||||
|
else: # chunk_offset > 0
|
||||||
|
# read prior-token data from `x`
|
||||||
|
load_init_state = True
|
||||||
|
prior_tokens = x_base + (token_offset - 1) * stride_x_token
|
||||||
|
mask_w = idx_feats < dim
|
||||||
|
if KERNEL_WIDTH == 2:
|
||||||
|
conv_states_ptrs = prior_tokens # [BLOCK_N]
|
||||||
|
col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca')
|
||||||
|
if KERNEL_WIDTH == 3:
|
||||||
|
conv_states_ptrs = prior_tokens # [BLOCK_N]
|
||||||
|
col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca')
|
||||||
|
conv_states_ptrs = prior_tokens - 1 * stride_x_token # [BLOCK_N]
|
||||||
|
col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca')
|
||||||
|
if KERNEL_WIDTH == 4:
|
||||||
|
conv_states_ptrs = prior_tokens # [BLOCK_N]
|
||||||
|
col2 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca')
|
||||||
|
conv_states_ptrs = prior_tokens - 1 * stride_x_token # [BLOCK_N]
|
||||||
|
col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca')
|
||||||
|
conv_states_ptrs = prior_tokens - 2 * stride_x_token # [BLOCK_N]
|
||||||
|
col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca')
|
||||||
|
if KERNEL_WIDTH == 5:
|
||||||
|
# ruff: noqa: F841
|
||||||
|
conv_states_ptrs = prior_tokens # [BLOCK_N]
|
||||||
|
col3 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca')
|
||||||
|
conv_states_ptrs = prior_tokens - 1 * stride_x_token # [BLOCK_N]
|
||||||
|
col2 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca')
|
||||||
|
conv_states_ptrs = prior_tokens - 2 * stride_x_token # [BLOCK_N]
|
||||||
|
col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca')
|
||||||
|
conv_states_ptrs = prior_tokens - 3 * stride_x_token # [BLOCK_N]
|
||||||
|
col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca')
|
||||||
|
|
||||||
|
if HAS_BIAS:
|
||||||
|
bias = bias_ptr + idx_feats
|
||||||
|
mask_bias = idx_feats < dim
|
||||||
|
acc_preload = tl.load(bias, mask=mask_bias,
|
||||||
|
other=0.0).to(tl.float32) # [BLOCK_N]
|
||||||
|
else:
|
||||||
|
acc_preload = tl.zeros((BLOCK_N, ), dtype=tl.float32)
|
||||||
|
|
||||||
|
x_base_1d = x_base + token_offset * stride_x_token # starting of chunk
|
||||||
|
|
||||||
|
# PRE-LOAD WEIGHTS
|
||||||
|
mask_w = idx_feats < dim
|
||||||
|
if KERNEL_WIDTH >= 2:
|
||||||
|
w_ptrs = w_base + (0 * stride_w_width) # [BLOCK_N] tensor
|
||||||
|
w_col0 = tl.load(w_ptrs, mask_w, other=0.0)
|
||||||
|
w_ptrs = w_base + (1 * stride_w_width) # [BLOCK_N] tensor
|
||||||
|
w_col1 = tl.load(w_ptrs, mask_w, other=0.0)
|
||||||
|
if KERNEL_WIDTH >= 3:
|
||||||
|
w_ptrs = w_base + (2 * stride_w_width) # [BLOCK_N] tensor
|
||||||
|
w_col2 = tl.load(w_ptrs, mask_w, other=0.0)
|
||||||
|
if KERNEL_WIDTH >= 4:
|
||||||
|
w_ptrs = w_base + (3 * stride_w_width) # [BLOCK_N] tensor
|
||||||
|
w_col3 = tl.load(w_ptrs, mask_w, other=0.0)
|
||||||
|
mask_x_1d = idx_feats < dim
|
||||||
|
for idx_token in range(segment_len):
|
||||||
|
acc = acc_preload
|
||||||
|
|
||||||
|
matrix_w = w_col0
|
||||||
|
matrix_x = col0
|
||||||
|
for j in tl.static_range(KERNEL_WIDTH):
|
||||||
|
|
||||||
|
if KERNEL_WIDTH == 2:
|
||||||
|
if j == 1: # KERNEL_WIDTH-1:
|
||||||
|
matrix_w = w_col1
|
||||||
|
x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N]
|
||||||
|
matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
|
||||||
|
elif KERNEL_WIDTH == 3:
|
||||||
|
if j == 1:
|
||||||
|
matrix_w = w_col1
|
||||||
|
matrix_x = col1
|
||||||
|
elif j == 2:
|
||||||
|
matrix_w = w_col2
|
||||||
|
x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N]
|
||||||
|
matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
|
||||||
|
elif KERNEL_WIDTH == 4:
|
||||||
|
if j == 1:
|
||||||
|
matrix_w = w_col1
|
||||||
|
matrix_x = col1
|
||||||
|
elif j == 2:
|
||||||
|
matrix_w = w_col2
|
||||||
|
matrix_x = col2
|
||||||
|
elif j == 3:
|
||||||
|
matrix_w = w_col3
|
||||||
|
x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N]
|
||||||
|
matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
|
||||||
|
|
||||||
|
acc += matrix_x * matrix_w # [BLOCK_N]
|
||||||
|
|
||||||
|
if KERNEL_WIDTH == 2:
|
||||||
|
col0 = matrix_x
|
||||||
|
elif KERNEL_WIDTH == 3:
|
||||||
|
col0 = col1
|
||||||
|
col1 = matrix_x
|
||||||
|
elif KERNEL_WIDTH == 4:
|
||||||
|
col0 = col1
|
||||||
|
col1 = col2
|
||||||
|
col2 = matrix_x
|
||||||
|
|
||||||
|
if SILU_ACTIVATION:
|
||||||
|
acc = acc / (1 + tl.exp(-acc))
|
||||||
|
mask_1d = (idx_token < segment_len) & (
|
||||||
|
idx_feats < dim) # token-index # feature-index
|
||||||
|
o_ptrs = o_ptr + (sequence_start_index + token_offset + idx_token
|
||||||
|
) * stride_o_token + (idx_feats * stride_o_dim)
|
||||||
|
|
||||||
|
tl.store(o_ptrs, acc, mask=mask_1d)
|
||||||
|
|
||||||
|
|
||||||
|
def causal_conv1d_fn(
|
||||||
|
x: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
bias: Union[torch.Tensor, None],
|
||||||
|
conv_states: torch.Tensor,
|
||||||
|
query_start_loc: torch.Tensor,
|
||||||
|
cache_indices: Optional[torch.Tensor] = None,
|
||||||
|
has_initial_state: Optional[torch.Tensor] = None,
|
||||||
|
activation: Optional[str] = "silu",
|
||||||
|
pad_slot_id: int = PAD_SLOT_ID,
|
||||||
|
metadata=None,
|
||||||
|
validate_data=False,
|
||||||
|
):
|
||||||
|
"""support varlen + continuous batching when x is 2D tensor
|
||||||
|
|
||||||
|
x: (dim,cu_seq_len)
|
||||||
|
cu_seq_len = total tokens of all seqs in that batch
|
||||||
sequences are concatenated from left to right for varlen
|
sequences are concatenated from left to right for varlen
|
||||||
weight: (dim, width)
|
weight: (dim, width)
|
||||||
bias: (dim,)
|
conv_states: (...,dim,width - 1) itype
|
||||||
|
updated inplace if provided
|
||||||
|
[it uses `cache_indices` to get the index to the cache of conv_state for that sequence
|
||||||
|
|
||||||
|
conv_state[cache_indices[i]] for seq-i - to be used as initial_state when has_initial_state[i] = True
|
||||||
|
and after that conv_state[cache_indices[i]] needs to be shifted left and updated with values from 'x'
|
||||||
|
]
|
||||||
query_start_loc: (batch + 1) int32
|
query_start_loc: (batch + 1) int32
|
||||||
The cumulative sequence lengths of the sequences in
|
The cumulative sequence lengths of the sequences in
|
||||||
the batch, used to index into sequence. prepended by 0.
|
the batch, used to index into sequence. prepended by 0.
|
||||||
|
if
|
||||||
|
x = [5, 1, 1, 1] <- continuous batching (batch=4)
|
||||||
|
then
|
||||||
|
query_start_loc = [0, 5, 6, 7, 8] <- the starting index of the next sequence; while the last value is
|
||||||
|
the ending index of the last sequence
|
||||||
|
[length(query_start_loc)-1 == batch]
|
||||||
for example: query_start_loc = torch.Tensor([0,10,16,17]),
|
for example: query_start_loc = torch.Tensor([0,10,16,17]),
|
||||||
x.shape=(dim,17)
|
x.shape=(dim,17)
|
||||||
cache_indices: (batch) int32
|
cache_indices: (batch) int32
|
||||||
@ -37,42 +400,436 @@ def causal_conv1d_fn(x: torch.Tensor,
|
|||||||
has_initial_state: (batch) bool
|
has_initial_state: (batch) bool
|
||||||
indicates whether should the kernel take the current state as initial
|
indicates whether should the kernel take the current state as initial
|
||||||
state for the calculations
|
state for the calculations
|
||||||
conv_states: (...,dim,width - 1) itype
|
[single boolean for each sequence in the batch: True or False]
|
||||||
updated inplace if provided
|
bias: (dim,)
|
||||||
activation: either None or "silu" or "swish"
|
activation: either None or "silu" or "swish" or True
|
||||||
pad_slot_id: int
|
pad_slot_id: int
|
||||||
if cache_indices is passed, lets the kernel identify padded
|
if cache_indices is passed, lets the kernel identify padded
|
||||||
entries that will not be processed,
|
entries that will not be processed,
|
||||||
for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
|
for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
|
||||||
in this case, the kernel will not process entries at
|
in this case, the kernel will not process entries at
|
||||||
indices 0 and 3
|
indices 0 and 3
|
||||||
|
|
||||||
|
out: same shape as `x`
|
||||||
out: (batch, dim, seqlen)
|
|
||||||
"""
|
"""
|
||||||
if activation not in [None, "silu", "swish"]:
|
if isinstance(activation, bool) and activation:
|
||||||
raise NotImplementedError("activation must be None, silu, or swish")
|
activation = "silu"
|
||||||
if x.stride(-1) != 1:
|
|
||||||
x = x.contiguous()
|
|
||||||
bias = bias.contiguous() if bias is not None else None
|
|
||||||
|
|
||||||
ops.causal_conv1d_fwd(x, weight, bias, conv_states, query_start_loc,
|
args = None
|
||||||
cache_indices, has_initial_state, activation
|
out = torch.zeros_like(x)
|
||||||
in ["silu", "swish"], pad_slot_id)
|
if metadata is not None:
|
||||||
return x
|
cu_seqlen = metadata.cu_seqlen
|
||||||
|
nums_dict = metadata.nums_dict
|
||||||
|
#x = metadata.x
|
||||||
|
args = nums_dict
|
||||||
|
batch_ptr = metadata.batch_ptr
|
||||||
|
token_chunk_offset_ptr = metadata.token_chunk_offset_ptr
|
||||||
|
else:
|
||||||
|
seqlens = np.diff(query_start_loc.to('cpu'))
|
||||||
|
args = seqlens
|
||||||
|
MAX_NUM_PROGRAMS = 1024
|
||||||
|
|
||||||
|
batch_ptr = torch.full(
|
||||||
|
(MAX_NUM_PROGRAMS, ),
|
||||||
|
PAD_SLOT_ID,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=x.device
|
||||||
|
) # tracking which seq-idx the Triton program is handling
|
||||||
|
token_chunk_offset_ptr = torch.full(
|
||||||
|
(MAX_NUM_PROGRAMS, ),
|
||||||
|
PAD_SLOT_ID,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=x.device
|
||||||
|
) # tracking BLOCK_M-based index in the sequence the Triton program is handling
|
||||||
|
|
||||||
|
is_channel_last = (x.stride(0) == 1) & (x.stride(1) > 1)
|
||||||
|
dim, cu_seqlen = x.shape
|
||||||
|
_, width = weight.shape
|
||||||
|
state_len = width - 1
|
||||||
|
np2_statelen = triton.next_power_of_2(state_len)
|
||||||
|
|
||||||
|
padded_batch = query_start_loc.size(0) - 1
|
||||||
|
stride_x_seq = 0
|
||||||
|
stride_x_dim = x.stride(0)
|
||||||
|
stride_x_token = x.stride(1)
|
||||||
|
stride_w_dim = weight.stride(0)
|
||||||
|
stride_w_width = weight.stride(1)
|
||||||
|
stride_istate_seq = 0
|
||||||
|
stride_istate_dim = 0
|
||||||
|
stride_istate_token = 0
|
||||||
|
num_cache_lines = 0
|
||||||
|
if conv_states is not None:
|
||||||
|
# extensions to support vLLM:
|
||||||
|
# 1. conv_states is used to replaced initial_states
|
||||||
|
# 2. conv_states serve as a cache with num cache lines can be larger than batch size
|
||||||
|
# 3. mapping from sequence x[idx] to a cache line at index as specified via cache_indices[idx]
|
||||||
|
# 4. computation can be skipped if cache_indices[idx] == pad_slot_id
|
||||||
|
num_cache_lines = conv_states.size(0)
|
||||||
|
assert (num_cache_lines, dim, width - 1) == conv_states.shape
|
||||||
|
stride_istate_seq = conv_states.stride(0)
|
||||||
|
stride_istate_dim = conv_states.stride(1)
|
||||||
|
stride_istate_token = conv_states.stride(2)
|
||||||
|
assert stride_istate_dim == 1
|
||||||
|
if out.dim() == 2:
|
||||||
|
stride_o_seq = 0
|
||||||
|
stride_o_dim = out.stride(0)
|
||||||
|
stride_o_token = out.stride(1)
|
||||||
|
else:
|
||||||
|
stride_o_seq = out.stride(0)
|
||||||
|
stride_o_dim = out.stride(1)
|
||||||
|
stride_o_token = out.stride(2)
|
||||||
|
|
||||||
|
if validate_data:
|
||||||
|
assert x.dim() == 2
|
||||||
|
assert query_start_loc is not None
|
||||||
|
assert query_start_loc.dim() == 1
|
||||||
|
assert x.stride(0) == 1 or x.stride(1) == 1
|
||||||
|
if bias is not None:
|
||||||
|
assert bias.dim() == 1
|
||||||
|
assert dim == bias.size(0)
|
||||||
|
if cache_indices is not None:
|
||||||
|
assert cache_indices.dim() == 1
|
||||||
|
assert padded_batch == cache_indices.size(0)
|
||||||
|
if has_initial_state is not None:
|
||||||
|
assert has_initial_state.size() == (padded_batch, )
|
||||||
|
assert conv_states is not None, "ERROR: `has_initial_state` is used, which needs also `conv_states`"
|
||||||
|
assert weight.stride(1) == 1
|
||||||
|
assert (dim, width) == weight.shape
|
||||||
|
assert is_channel_last, "Need to run in channel-last layout"
|
||||||
|
|
||||||
|
if metadata is None:
|
||||||
|
|
||||||
|
def num_program(META, seqlens):
|
||||||
|
tot = 0
|
||||||
|
|
||||||
|
mlist = []
|
||||||
|
offsetlist = [] # type: ignore
|
||||||
|
|
||||||
|
nums = -(-seqlens // META["BLOCK_M"])
|
||||||
|
|
||||||
|
tot = nums.sum().item()
|
||||||
|
mlist = np.repeat(np.arange(len(nums)), nums)
|
||||||
|
for idx, num in enumerate(nums):
|
||||||
|
offsetlist.extend(
|
||||||
|
range(num)
|
||||||
|
) # chunk-idx if a sequence is split into multiple chunks
|
||||||
|
|
||||||
|
if META["batch_ptr"].nelement() < len(mlist):
|
||||||
|
newlen = len(mlist) + 1
|
||||||
|
META["batch_ptr"].resize_(newlen).fill_(PAD_SLOT_ID)
|
||||||
|
META["token_chunk_offset_ptr"].resize_(newlen).fill_(
|
||||||
|
PAD_SLOT_ID)
|
||||||
|
|
||||||
|
if META["batch_ptr"].nelement() >= len(mlist):
|
||||||
|
META["batch_ptr"][0:len(mlist)].copy_(
|
||||||
|
torch.from_numpy(np.array(mlist)))
|
||||||
|
META["token_chunk_offset_ptr"][0:len(mlist)].copy_(
|
||||||
|
torch.from_numpy(np.array(offsetlist)))
|
||||||
|
|
||||||
|
META["batch_ptr"] = META["batch_ptr"].to(META["x_ptr"].device)
|
||||||
|
META["token_chunk_offset_ptr"] = META["token_chunk_offset_ptr"].to(
|
||||||
|
META["x_ptr"].device)
|
||||||
|
return tot
|
||||||
|
else:
|
||||||
|
|
||||||
|
def num_program(META, nums_dict):
|
||||||
|
tot = nums_dict[META["BLOCK_M"]]['tot']
|
||||||
|
|
||||||
|
mlist = nums_dict[META["BLOCK_M"]]['mlist']
|
||||||
|
mlist_len = nums_dict[META["BLOCK_M"]]['mlist_len']
|
||||||
|
|
||||||
|
offsetlist = nums_dict[META["BLOCK_M"]]['offsetlist']
|
||||||
|
|
||||||
|
if nums_dict[META["BLOCK_M"]]["batch_ptr"] is not None:
|
||||||
|
META["batch_ptr"] = nums_dict[META["BLOCK_M"]]["batch_ptr"]
|
||||||
|
META["token_chunk_offset_ptr"] = nums_dict[
|
||||||
|
META["BLOCK_M"]]["token_chunk_offset_ptr"]
|
||||||
|
else:
|
||||||
|
if META["batch_ptr"].nelement() < mlist_len:
|
||||||
|
newlen = mlist_len + 1
|
||||||
|
META["batch_ptr"].resize_(newlen).fill_(PAD_SLOT_ID)
|
||||||
|
META["token_chunk_offset_ptr"].resize_(newlen).fill_(
|
||||||
|
PAD_SLOT_ID)
|
||||||
|
|
||||||
|
if META["batch_ptr"].nelement() >= mlist_len:
|
||||||
|
META["batch_ptr"][0:mlist_len].copy_(mlist)
|
||||||
|
META["token_chunk_offset_ptr"][0:mlist_len].copy_(
|
||||||
|
offsetlist)
|
||||||
|
return tot
|
||||||
|
|
||||||
|
def grid(META):
|
||||||
|
return (
|
||||||
|
num_program(META, args),
|
||||||
|
triton.cdiv(dim, META["BLOCK_N"]),
|
||||||
|
)
|
||||||
|
|
||||||
|
if batch_ptr.device != x.device:
|
||||||
|
batch_ptr = batch_ptr.to(x.device)
|
||||||
|
token_chunk_offset_ptr = token_chunk_offset_ptr.to(x.device)
|
||||||
|
|
||||||
|
_causal_conv1d_fwd_kernel[grid](
|
||||||
|
# Pointers to matrices
|
||||||
|
x,
|
||||||
|
weight,
|
||||||
|
bias,
|
||||||
|
conv_states,
|
||||||
|
cache_indices,
|
||||||
|
has_initial_state,
|
||||||
|
query_start_loc,
|
||||||
|
batch_ptr,
|
||||||
|
token_chunk_offset_ptr,
|
||||||
|
out,
|
||||||
|
# Matrix dimensions
|
||||||
|
padded_batch,
|
||||||
|
dim,
|
||||||
|
cu_seqlen,
|
||||||
|
num_cache_lines,
|
||||||
|
# stride
|
||||||
|
stride_x_seq,
|
||||||
|
stride_x_dim,
|
||||||
|
stride_x_token,
|
||||||
|
stride_w_dim,
|
||||||
|
stride_w_width,
|
||||||
|
stride_istate_seq,
|
||||||
|
stride_istate_dim,
|
||||||
|
stride_istate_token,
|
||||||
|
stride_o_seq,
|
||||||
|
stride_o_dim,
|
||||||
|
stride_o_token,
|
||||||
|
# others
|
||||||
|
pad_slot_id,
|
||||||
|
# META
|
||||||
|
HAS_BIAS=bias is not None,
|
||||||
|
KERNEL_WIDTH=width,
|
||||||
|
SILU_ACTIVATION=activation in ["silu", "swish"],
|
||||||
|
HAS_INITIAL_STATES=has_initial_state is not None,
|
||||||
|
HAS_CACHE=conv_states is not None,
|
||||||
|
IS_CONTINUOUS_BATCHING=cache_indices is not None,
|
||||||
|
USE_PAD_SLOT=pad_slot_id is not None,
|
||||||
|
NP2_STATELEN=np2_statelen,
|
||||||
|
DECODE_SEQLEN=1,
|
||||||
|
#launch_cooperative_grid=True
|
||||||
|
BLOCK_M=8,
|
||||||
|
BLOCK_N=256,
|
||||||
|
num_stages=2,
|
||||||
|
)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
def causal_conv1d_update(x: torch.Tensor,
|
@triton.jit()
|
||||||
conv_state: torch.Tensor,
|
def _causal_conv1d_update_kernel(
|
||||||
weight: torch.Tensor,
|
# Pointers to matrices
|
||||||
bias: Optional[torch.Tensor] = None,
|
x_ptr, # (batch, dim, seqlen)
|
||||||
activation: Optional[str] = None,
|
w_ptr, # (dim, width)
|
||||||
cache_seqlens: Optional[torch.Tensor] = None,
|
bias_ptr,
|
||||||
conv_state_indices: Optional[torch.Tensor] = None,
|
conv_state_ptr,
|
||||||
pad_slot_id: int = PAD_SLOT_ID):
|
cache_seqlens_ptr, # circular buffer
|
||||||
|
conv_state_indices_ptr,
|
||||||
|
o_ptr, # (batch, dim, seqlen)
|
||||||
|
# Matrix dimensions
|
||||||
|
batch: int,
|
||||||
|
dim: tl.constexpr,
|
||||||
|
seqlen: tl.constexpr,
|
||||||
|
state_len: tl.constexpr,
|
||||||
|
num_cache_lines: tl.constexpr, # added to support vLLM larger cache lines
|
||||||
|
# Strides
|
||||||
|
stride_x_seq: tl.constexpr,
|
||||||
|
stride_x_dim: tl.constexpr,
|
||||||
|
stride_x_token: tl.constexpr,
|
||||||
|
stride_w_dim: tl.constexpr,
|
||||||
|
stride_w_width: tl.constexpr,
|
||||||
|
stride_conv_state_seq: tl.constexpr,
|
||||||
|
stride_conv_state_dim: tl.constexpr,
|
||||||
|
stride_conv_state_tok: tl.constexpr,
|
||||||
|
stride_o_seq: tl.constexpr,
|
||||||
|
stride_o_dim: tl.constexpr,
|
||||||
|
stride_o_token: tl.constexpr,
|
||||||
|
# others
|
||||||
|
pad_slot_id: tl.constexpr,
|
||||||
|
# Meta-parameters
|
||||||
|
HAS_BIAS: tl.constexpr,
|
||||||
|
KERNEL_WIDTH: tl.constexpr,
|
||||||
|
SILU_ACTIVATION: tl.constexpr,
|
||||||
|
IS_CONTINUOUS_BATCHING: tl.constexpr,
|
||||||
|
NP2_STATELEN: tl.constexpr,
|
||||||
|
USE_PAD_SLOT: tl.constexpr,
|
||||||
|
BLOCK_N: tl.constexpr,
|
||||||
|
):
|
||||||
|
# ruff: noqa: E501
|
||||||
|
idx_seq = tl.program_id(0)
|
||||||
|
if idx_seq >= batch:
|
||||||
|
return
|
||||||
|
|
||||||
|
# [BLOCK_N,] elements along the feature-dimension (channel)
|
||||||
|
idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||||
|
|
||||||
|
if IS_CONTINUOUS_BATCHING:
|
||||||
|
# mask = idx_seq < batch
|
||||||
|
conv_state_batch_coord = tl.load(conv_state_indices_ptr + idx_seq)
|
||||||
|
else:
|
||||||
|
conv_state_batch_coord = idx_seq
|
||||||
|
if USE_PAD_SLOT: # noqa
|
||||||
|
if conv_state_batch_coord == pad_slot_id:
|
||||||
|
# not processing as this is not the actual sequence
|
||||||
|
return
|
||||||
|
|
||||||
|
# STEP 1: READ init_state data
|
||||||
|
conv_states_base = (conv_state_ptr +
|
||||||
|
(conv_state_batch_coord * stride_conv_state_seq) +
|
||||||
|
(idx_feats * stride_conv_state_dim))
|
||||||
|
mask_w = idx_feats < dim
|
||||||
|
|
||||||
|
prior_tokens = conv_states_base
|
||||||
|
if KERNEL_WIDTH >= 2:
|
||||||
|
conv_states_ptrs = prior_tokens # [BLOCK_N]
|
||||||
|
col0 = tl.load(conv_states_ptrs, mask_w, 0.0)
|
||||||
|
if KERNEL_WIDTH >= 3:
|
||||||
|
conv_states_ptrs = prior_tokens + 1 * stride_conv_state_tok # [BLOCK_N]
|
||||||
|
col1 = tl.load(conv_states_ptrs, mask_w, 0.0)
|
||||||
|
if KERNEL_WIDTH >= 4:
|
||||||
|
conv_states_ptrs = prior_tokens + 2 * stride_conv_state_tok # [BLOCK_N]
|
||||||
|
col2 = tl.load(conv_states_ptrs, mask_w, 0.0)
|
||||||
|
if KERNEL_WIDTH == 5:
|
||||||
|
conv_states_ptrs = prior_tokens + 3 * stride_conv_state_tok # [BLOCK_N]
|
||||||
|
col3 = tl.load(conv_states_ptrs, mask_w, 0.0)
|
||||||
|
|
||||||
|
# STEP 2: assume state_len > seqlen
|
||||||
|
idx_tokens = tl.arange(0, NP2_STATELEN) # [BLOCK_M]
|
||||||
|
|
||||||
|
conv_state_ptrs_source = (
|
||||||
|
conv_state_ptr + (conv_state_batch_coord * stride_conv_state_seq) +
|
||||||
|
(idx_feats * stride_conv_state_dim)[None, :] +
|
||||||
|
((idx_tokens + seqlen) * stride_conv_state_tok)[:, None]
|
||||||
|
) # [BLOCK_M, BLOCK_N]
|
||||||
|
mask = ((conv_state_batch_coord < num_cache_lines)
|
||||||
|
& ((idx_tokens + seqlen) < state_len)[:, None]
|
||||||
|
& (idx_feats < dim)[None, :])
|
||||||
|
conv_state = tl.load(conv_state_ptrs_source, mask, other=0.0)
|
||||||
|
|
||||||
|
VAL = state_len - seqlen
|
||||||
|
x_base = x_ptr + (idx_seq * stride_x_seq) + (idx_feats * stride_x_dim
|
||||||
|
) # [BLOCK_N]
|
||||||
|
|
||||||
|
x_ptrs = x_base[None, :] + (
|
||||||
|
(idx_tokens - VAL) * stride_x_token)[:, None] # [BLOCK_M, BLOCK_N]
|
||||||
|
|
||||||
|
mask_x = ((idx_tokens - VAL >= 0)[:, None] &
|
||||||
|
(idx_tokens - VAL < seqlen)[:, None] & (idx_feats < dim)[None, :]
|
||||||
|
) # token-index # token-index # feature-index
|
||||||
|
loaded_x = tl.load(x_ptrs, mask_x, 0.0)
|
||||||
|
tl.debug_barrier()
|
||||||
|
|
||||||
|
new_conv_state = tl.where(mask, conv_state, loaded_x)
|
||||||
|
|
||||||
|
conv_state_base = (conv_state_ptr +
|
||||||
|
(conv_state_batch_coord * stride_conv_state_seq) +
|
||||||
|
(idx_feats * stride_conv_state_dim)) # [BLOCK_N,]
|
||||||
|
conv_state_ptrs_target = conv_state_base + (
|
||||||
|
idx_tokens * stride_conv_state_tok)[:, None] # [BLOCK_M, BLOCK_N]
|
||||||
|
mask = (idx_tokens < state_len)[:, None] & (idx_feats < dim)[None, :]
|
||||||
|
tl.store(conv_state_ptrs_target, new_conv_state, mask)
|
||||||
|
|
||||||
|
# STEP 3: init accumulator
|
||||||
|
if HAS_BIAS:
|
||||||
|
bias = bias_ptr + idx_feats
|
||||||
|
mask_bias = idx_feats < dim
|
||||||
|
acc_preload = tl.load(bias, mask=mask_bias,
|
||||||
|
other=0.0).to(tl.float32) # [BLOCK_N]
|
||||||
|
else:
|
||||||
|
acc_preload = tl.zeros((BLOCK_N, ), dtype=tl.float32)
|
||||||
|
|
||||||
|
# STEP 4:
|
||||||
|
# PRE-LOAD WEIGHTS
|
||||||
|
# first kernel column, configured for weights to handle BLOCK_N features in range
|
||||||
|
w_base = w_ptr + (idx_feats * stride_w_dim) # [BLOCK_N,]
|
||||||
|
mask_w = idx_feats < dim
|
||||||
|
if KERNEL_WIDTH >= 2:
|
||||||
|
w_ptrs = w_base + (0 * stride_w_width) # [BLOCK_N] tensor
|
||||||
|
w_col0 = tl.load(w_ptrs, mask_w, other=0.0)
|
||||||
|
w_ptrs = w_base + (1 * stride_w_width) # [BLOCK_N] tensor
|
||||||
|
w_col1 = tl.load(w_ptrs, mask_w, other=0.0)
|
||||||
|
if KERNEL_WIDTH >= 3:
|
||||||
|
w_ptrs = w_base + (2 * stride_w_width) # [BLOCK_N] tensor
|
||||||
|
w_col2 = tl.load(w_ptrs, mask_w, other=0.0)
|
||||||
|
if KERNEL_WIDTH >= 4:
|
||||||
|
w_ptrs = w_base + (3 * stride_w_width) # [BLOCK_N] tensor
|
||||||
|
w_col3 = tl.load(w_ptrs, mask_w, other=0.0)
|
||||||
|
|
||||||
|
x_base_1d = x_base # starting of chunk [BLOCK_N]
|
||||||
|
mask_x_1d = idx_feats < dim
|
||||||
|
|
||||||
|
# STEP 5: compute each token
|
||||||
|
for idx_token in tl.static_range(seqlen):
|
||||||
|
acc = acc_preload
|
||||||
|
|
||||||
|
matrix_w = w_col0
|
||||||
|
matrix_x = col0
|
||||||
|
for j in tl.static_range(KERNEL_WIDTH):
|
||||||
|
if KERNEL_WIDTH == 2:
|
||||||
|
if j == 1: # KERNEL_WIDTH-1:
|
||||||
|
matrix_w = w_col1
|
||||||
|
x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N]
|
||||||
|
matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
|
||||||
|
elif KERNEL_WIDTH == 3:
|
||||||
|
if j == 1:
|
||||||
|
matrix_w = w_col1
|
||||||
|
matrix_x = col1
|
||||||
|
elif j == 2:
|
||||||
|
matrix_w = w_col2
|
||||||
|
x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N]
|
||||||
|
matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
|
||||||
|
elif KERNEL_WIDTH == 4:
|
||||||
|
if j == 1:
|
||||||
|
matrix_w = w_col1
|
||||||
|
matrix_x = col1
|
||||||
|
elif j == 2:
|
||||||
|
matrix_w = w_col2
|
||||||
|
matrix_x = col2
|
||||||
|
elif j == 3:
|
||||||
|
matrix_w = w_col3
|
||||||
|
x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N]
|
||||||
|
matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
|
||||||
|
|
||||||
|
acc += matrix_x * matrix_w # [BLOCK_N]
|
||||||
|
|
||||||
|
if KERNEL_WIDTH == 2:
|
||||||
|
col0 = matrix_x
|
||||||
|
elif KERNEL_WIDTH == 3:
|
||||||
|
col0 = col1
|
||||||
|
col1 = matrix_x
|
||||||
|
elif KERNEL_WIDTH == 4:
|
||||||
|
col0 = col1
|
||||||
|
col1 = col2
|
||||||
|
col2 = matrix_x
|
||||||
|
|
||||||
|
if SILU_ACTIVATION:
|
||||||
|
acc = acc / (1 + tl.exp(-acc))
|
||||||
|
mask_1d = (idx_token < seqlen) & (idx_feats < dim
|
||||||
|
) # token-index # feature-index
|
||||||
|
o_ptrs = o_ptr + (
|
||||||
|
idx_seq) * stride_o_seq + idx_token * stride_o_token + (
|
||||||
|
idx_feats * stride_o_dim)
|
||||||
|
|
||||||
|
tl.store(o_ptrs, acc, mask=mask_1d)
|
||||||
|
|
||||||
|
|
||||||
|
def causal_conv1d_update(
|
||||||
|
x: torch.Tensor,
|
||||||
|
conv_state: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
bias: Optional[torch.Tensor] = None,
|
||||||
|
activation: Union[bool, str, None] = None,
|
||||||
|
cache_seqlens: Optional[torch.Tensor] = None,
|
||||||
|
conv_state_indices: Optional[torch.Tensor] = None,
|
||||||
|
pad_slot_id: int = PAD_SLOT_ID,
|
||||||
|
metadata=None,
|
||||||
|
validate_data=False,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
x: (batch, dim) or (batch, dim, seqlen)
|
x: (batch, dim) or (batch, dim, seqlen)
|
||||||
conv_state: (batch, dim, state_len), where state_len >= width - 1
|
[shape=2: single token prediction]
|
||||||
|
[shape=3: single or multiple tokens prediction]
|
||||||
|
conv_state: (..., dim, state_len), where state_len >= width - 1
|
||||||
weight: (dim, width)
|
weight: (dim, width)
|
||||||
bias: (dim,)
|
bias: (dim,)
|
||||||
cache_seqlens: (batch,), dtype int32.
|
cache_seqlens: (batch,), dtype int32.
|
||||||
@ -92,14 +849,98 @@ def causal_conv1d_update(x: torch.Tensor,
|
|||||||
indices 0 and 3
|
indices 0 and 3
|
||||||
out: (batch, dim) or (batch, dim, seqlen)
|
out: (batch, dim) or (batch, dim, seqlen)
|
||||||
"""
|
"""
|
||||||
if activation not in [None, "silu", "swish"]:
|
if validate_data:
|
||||||
raise NotImplementedError("activation must be None, silu, or swish")
|
assert cache_seqlens is None # not implemented yet - ok for vLLM
|
||||||
activation_val = activation in ["silu", "swish"]
|
assert pad_slot_id is not None
|
||||||
|
assert x.stride(1) == 1
|
||||||
|
if isinstance(activation, bool):
|
||||||
|
activation = "silu" if activation is True else None
|
||||||
|
elif activation is not None:
|
||||||
|
assert activation in ["silu", "swish"]
|
||||||
unsqueeze = x.dim() == 2
|
unsqueeze = x.dim() == 2
|
||||||
if unsqueeze:
|
if unsqueeze:
|
||||||
|
# make it (batch, dim, seqlen) with seqlen == 1
|
||||||
x = x.unsqueeze(-1)
|
x = x.unsqueeze(-1)
|
||||||
ops.causal_conv1d_update(x, conv_state, weight, bias, activation_val,
|
batch, dim, seqlen = x.shape
|
||||||
cache_seqlens, conv_state_indices, pad_slot_id)
|
_, width = weight.shape
|
||||||
|
# conv_state: (..., dim, state_len), where state_len >= width - 1
|
||||||
|
num_cache_lines, _, state_len = conv_state.size()
|
||||||
|
|
||||||
|
if validate_data:
|
||||||
|
assert dim == weight.size(0)
|
||||||
|
assert conv_state.stride(
|
||||||
|
-2
|
||||||
|
) == 1, f"ERROR: expect contiguous along feat-dim of conv_state (currently stride={conv_state.stride()})"
|
||||||
|
assert state_len >= width - 1
|
||||||
|
# when above happens, we don't shift-left to keep any records in conv_state
|
||||||
|
assert dim == conv_state.size(1)
|
||||||
|
if conv_state_indices is None:
|
||||||
|
assert conv_state.size(0) >= batch
|
||||||
|
else:
|
||||||
|
assert (batch, ) == conv_state_indices.shape
|
||||||
|
|
||||||
|
assert num_cache_lines >= batch
|
||||||
|
assert weight.stride(1) == 1 # Need this
|
||||||
|
assert cache_seqlens is None # not needed for vLLM - circular buffer
|
||||||
|
|
||||||
|
# adopt the strategy in vLLM that overwrite on 'x' directly, rather than creating a new tensor 'o'
|
||||||
|
out = x
|
||||||
|
stride_w_dim, stride_w_width = weight.stride()
|
||||||
|
|
||||||
|
stride_x_seq, stride_x_dim, stride_x_token = x.stride(
|
||||||
|
) # X (batch, dim, seqlen)
|
||||||
|
|
||||||
|
stride_o_seq, stride_o_dim, stride_o_token = out.stride()
|
||||||
|
|
||||||
|
stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride(
|
||||||
|
)
|
||||||
|
state_len = width - 1
|
||||||
|
np2_statelen = triton.next_power_of_2(state_len)
|
||||||
|
|
||||||
|
def grid(META):
|
||||||
|
return (
|
||||||
|
batch,
|
||||||
|
triton.cdiv(dim, META["BLOCK_N"]),
|
||||||
|
)
|
||||||
|
|
||||||
|
_causal_conv1d_update_kernel[grid](
|
||||||
|
# Pointers to matrices
|
||||||
|
x,
|
||||||
|
weight,
|
||||||
|
bias,
|
||||||
|
conv_state,
|
||||||
|
cache_seqlens,
|
||||||
|
conv_state_indices,
|
||||||
|
out,
|
||||||
|
# Matrix dimensions
|
||||||
|
batch,
|
||||||
|
dim,
|
||||||
|
seqlen,
|
||||||
|
state_len,
|
||||||
|
num_cache_lines,
|
||||||
|
# stride
|
||||||
|
stride_x_seq,
|
||||||
|
stride_x_dim,
|
||||||
|
stride_x_token,
|
||||||
|
stride_w_dim,
|
||||||
|
stride_w_width,
|
||||||
|
stride_istate_seq,
|
||||||
|
stride_istate_dim,
|
||||||
|
stride_istate_token,
|
||||||
|
stride_o_seq,
|
||||||
|
stride_o_dim,
|
||||||
|
stride_o_token,
|
||||||
|
# others
|
||||||
|
pad_slot_id,
|
||||||
|
# META
|
||||||
|
HAS_BIAS=bias is not None,
|
||||||
|
KERNEL_WIDTH=width,
|
||||||
|
SILU_ACTIVATION=activation in ["silu", "swish"],
|
||||||
|
IS_CONTINUOUS_BATCHING=conv_state_indices is not None,
|
||||||
|
NP2_STATELEN=np2_statelen,
|
||||||
|
USE_PAD_SLOT=pad_slot_id is not None,
|
||||||
|
BLOCK_N=256,
|
||||||
|
)
|
||||||
if unsqueeze:
|
if unsqueeze:
|
||||||
x = x.squeeze(-1)
|
out = out.squeeze(-1)
|
||||||
return x
|
return out
|
||||||
|
|||||||
@ -36,10 +36,12 @@ class MambaCacheManager(ConstantSizeCache):
|
|||||||
# Initialize parent class
|
# Initialize parent class
|
||||||
super().__init__(max_batch_size)
|
super().__init__(max_batch_size)
|
||||||
|
|
||||||
|
# assume conv_state = (dim, state_len)
|
||||||
|
assert conv_state_shape[0] > conv_state_shape[1]
|
||||||
conv_state = torch.empty(size=(num_mamba_layers, max_batch_size) +
|
conv_state = torch.empty(size=(num_mamba_layers, max_batch_size) +
|
||||||
conv_state_shape,
|
(conv_state_shape[1], conv_state_shape[0]),
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device="cuda")
|
device="cuda").transpose(-1, -2)
|
||||||
temporal_state = torch.empty(size=(num_mamba_layers, max_batch_size) +
|
temporal_state = torch.empty(size=(num_mamba_layers, max_batch_size) +
|
||||||
temporal_state_shape,
|
temporal_state_shape,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
|
|||||||
@ -1,14 +1,13 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import math
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.attention.backends.abstract import AttentionBackend
|
from vllm.attention.backends.abstract import AttentionBackend
|
||||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||||
from vllm.model_executor.layers.mamba.mamba2_metadata import (
|
|
||||||
_query_start_loc_to_chunk_indices_offsets)
|
|
||||||
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
|
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
|
||||||
CommonAttentionMetadata)
|
CommonAttentionMetadata)
|
||||||
from vllm.v1.kv_cache_interface import MambaSpec
|
from vllm.v1.kv_cache_interface import MambaSpec
|
||||||
@ -29,6 +28,42 @@ def get_mamba2_chunk_size(vllm_config: VllmConfig) -> int:
|
|||||||
return chunk_sizes.pop()
|
return chunk_sizes.pop()
|
||||||
|
|
||||||
|
|
||||||
|
def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor,
|
||||||
|
chunk_size: int,
|
||||||
|
total_seqlens: int):
|
||||||
|
|
||||||
|
cu_seqlens = query_start_loc[1:] # remove prepended 0
|
||||||
|
|
||||||
|
# outputs will have length expansion of chunks that do not divide
|
||||||
|
# chunk_size
|
||||||
|
N = math.ceil(total_seqlens / chunk_size) + (cu_seqlens[:-1] % chunk_size
|
||||||
|
> 0).sum()
|
||||||
|
chunk_indices = torch.arange(N,
|
||||||
|
dtype=torch.int,
|
||||||
|
device=query_start_loc.device)
|
||||||
|
chunk_offsets = torch.zeros((N, ),
|
||||||
|
dtype=torch.int,
|
||||||
|
device=query_start_loc.device)
|
||||||
|
|
||||||
|
p = 0 # num of insertions
|
||||||
|
for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]):
|
||||||
|
|
||||||
|
# if does not divide chunk_size, then there is one chunk insertion
|
||||||
|
p += (s % chunk_size > 0)
|
||||||
|
|
||||||
|
# get the dimensions
|
||||||
|
# - the + 1 for _e is to shift the boundary by one chunk
|
||||||
|
# - this shifting is not needed if chunk_size divides e
|
||||||
|
_s, _e = s // chunk_size + p, e // chunk_size + p + (e % chunk_size
|
||||||
|
> 0)
|
||||||
|
|
||||||
|
# adjust indices and offsets
|
||||||
|
chunk_indices[_s:_e] -= p
|
||||||
|
chunk_offsets[_s] = s % chunk_size
|
||||||
|
|
||||||
|
return chunk_indices, chunk_offsets
|
||||||
|
|
||||||
|
|
||||||
class Mamba2AttentionBackend(AttentionBackend):
|
class Mamba2AttentionBackend(AttentionBackend):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -53,6 +88,10 @@ class Mamba2AttentionMetadata:
|
|||||||
chunk_offsets: torch.Tensor
|
chunk_offsets: torch.Tensor
|
||||||
|
|
||||||
state_indices_tensor: torch.Tensor # shape: [batch,]
|
state_indices_tensor: torch.Tensor # shape: [batch,]
|
||||||
|
nums_dict: Optional[dict] = None
|
||||||
|
cu_seqlen: Optional[int] = None
|
||||||
|
batch_ptr: Optional[torch.tensor] = None
|
||||||
|
token_chunk_offset_ptr: Optional[torch.tensor] = None
|
||||||
|
|
||||||
|
|
||||||
class Mamba2AttentionMetadataBuilder(
|
class Mamba2AttentionMetadataBuilder(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user