[Kernel] Triton implementation of causal-conv1d for Mamba-based models (#18218)

Signed-off-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com>
Co-authored-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com>
Co-authored-by: Tyler Michael Smith <tysmith@redhat.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
Tuan, Hoang-Trong 2025-07-09 15:53:55 -04:00 committed by GitHub
parent 31b96d1c64
commit 47043eb678
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 1117 additions and 1142 deletions

View File

@ -232,7 +232,6 @@ endif()
set(VLLM_EXT_SRC set(VLLM_EXT_SRC
"csrc/mamba/mamba_ssm/selective_scan_fwd.cu" "csrc/mamba/mamba_ssm/selective_scan_fwd.cu"
"csrc/mamba/causal_conv1d/causal_conv1d.cu"
"csrc/cache_kernels.cu" "csrc/cache_kernels.cu"
"csrc/attention/paged_attention_v1.cu" "csrc/attention/paged_attention_v1.cu"
"csrc/attention/paged_attention_v2.cu" "csrc/attention/paged_attention_v2.cu"

View File

@ -1,656 +0,0 @@
// clang-format off
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d_fwd.cu
// and https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d_update.cu
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "causal_conv1d.h"
#include <c10/util/BFloat16.h>
#include <c10/util/Half.h>
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
#include <cub/block/block_load.cuh>
#include <cub/block/block_store.cuh>
#ifdef USE_ROCM
namespace cub = hipcub;
#endif
#include "static_switch.h"
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \
if (ITYPE == at::ScalarType::Half) { \
using input_t = at::Half; \
using weight_t = at::Half; \
__VA_ARGS__(); \
} else if (ITYPE == at::ScalarType::BFloat16) { \
using input_t = at::BFloat16; \
using weight_t = at::BFloat16; \
__VA_ARGS__(); \
} else if (ITYPE == at::ScalarType::Float) { \
using input_t = float; \
using weight_t = float; \
__VA_ARGS__(); \
} else { \
AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
}
template<typename input_t, typename weight_t>
void causal_conv1d_fwd_cuda(ConvParamsBase &params, cudaStream_t stream);
template<typename input_t, typename weight_t>
void causal_conv1d_update_cuda(ConvParamsBase &params, cudaStream_t stream);
void set_conv_params_fwd(ConvParamsBase &params,
// sizes
const size_t batch,
const size_t dim,
const size_t seqlen,
const size_t width,
// device pointers
const at::Tensor x,
const at::Tensor weight,
const at::Tensor out,
const std::optional<at::Tensor>& bias,
bool silu_activation,
int64_t pad_slot_id,
const std::optional<at::Tensor>& query_start_loc = std::nullopt,
const std::optional<at::Tensor>& cache_indices = std::nullopt,
const std::optional<at::Tensor>& has_initial_state = std::nullopt) {
// Reset the parameters
memset(&params, 0, sizeof(params));
params.batch = batch;
params.dim = dim;
params.seqlen = seqlen;
params.width = width;
params.pad_slot_id = pad_slot_id;
params.silu_activation = silu_activation;
// Set the pointers and strides.
params.x_ptr = x.data_ptr();
params.weight_ptr = weight.data_ptr();
params.bias_ptr = bias.has_value() ? bias.value().data_ptr() : nullptr;
params.out_ptr = out.data_ptr();
// All stride are in elements, not bytes.
params.query_start_loc_ptr = query_start_loc.has_value() ? query_start_loc.value().data_ptr() : nullptr;
params.cache_indices_ptr = cache_indices.has_value() ? cache_indices.value().data_ptr() : nullptr;
params.has_initial_state_ptr = has_initial_state.has_value() ? has_initial_state.value().data_ptr() : nullptr;
const bool varlen = params.query_start_loc_ptr != nullptr;
params.x_batch_stride = x.stride(varlen ? 1 : 0);
params.x_c_stride = x.stride(varlen ? 0 : 1);
params.x_l_stride = x.stride(varlen ? 1 : -1);
params.weight_c_stride = weight.stride(0);
params.weight_width_stride = weight.stride(1);
params.out_batch_stride = out.stride(varlen ? 1 : 0);
params.out_c_stride = out.stride(varlen ? 0 : 1);
params.out_l_stride = out.stride(varlen ? 1 : -1);
}
void causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
const std::optional<at::Tensor> &bias_,
const std::optional<at::Tensor> &conv_states,
const std::optional<at::Tensor> &query_start_loc,
const std::optional<at::Tensor> &cache_indices,
const std::optional<at::Tensor> &has_initial_state,
bool silu_activation,
// used to identify padding entries if cache_indices provided
// in case of padding, the kernel will return early
int64_t pad_slot_id) {
auto input_type = x.scalar_type();
auto weight_type = weight.scalar_type();
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
TORCH_CHECK(x.is_cuda());
TORCH_CHECK(weight.is_cuda());
const bool varlen = query_start_loc.has_value() ? true : false;
const auto sizes = x.sizes();
const int batch_size = varlen ? query_start_loc.value().sizes()[0] - 1 : sizes[0];
const int dim = varlen ? sizes[0] : sizes[1];
const int seqlen = varlen ? sizes[1] : sizes[2];
const int width = weight.size(-1);
if (varlen){
CHECK_SHAPE(x, dim, seqlen);
}
else {
CHECK_SHAPE(x, batch_size, dim, seqlen);
}
CHECK_SHAPE(weight, dim, width);
if (bias_.has_value()) {
auto bias = bias_.value();
TORCH_CHECK(bias.scalar_type() == weight_type);
TORCH_CHECK(bias.is_cuda());
TORCH_CHECK(bias.stride(-1) == 1);
CHECK_SHAPE(bias, dim);
}
if (has_initial_state.has_value()) {
auto has_initial_state_ = has_initial_state.value();
TORCH_CHECK(has_initial_state_.scalar_type() == at::ScalarType::Bool);
TORCH_CHECK(has_initial_state_.is_cuda());
CHECK_SHAPE(has_initial_state_, batch_size);
}
if (query_start_loc.has_value()) {
auto query_start_loc_ = query_start_loc.value();
TORCH_CHECK(query_start_loc_.scalar_type() == at::ScalarType::Int);
TORCH_CHECK(query_start_loc_.is_cuda());
}
if (cache_indices.has_value()) {
auto cache_indices_ = cache_indices.value();
TORCH_CHECK(cache_indices_.scalar_type() == at::ScalarType::Int);
TORCH_CHECK(cache_indices_.is_cuda());
CHECK_SHAPE(cache_indices_, batch_size);
}
at::Tensor out = x;
ConvParamsBase params;
set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
bias_,
silu_activation,
pad_slot_id,
query_start_loc,
cache_indices,
has_initial_state
);
if (conv_states.has_value()) {
auto conv_states_ = conv_states.value();
TORCH_CHECK(conv_states_.scalar_type() == input_type);
TORCH_CHECK(conv_states_.is_cuda());
params.conv_states_ptr = conv_states_.data_ptr();
params.conv_states_batch_stride = conv_states_.stride(0);
params.conv_states_c_stride = conv_states_.stride(1);
params.conv_states_l_stride = conv_states_.stride(2);
} else {
params.conv_states_ptr = nullptr;
}
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
auto stream = at::cuda::getCurrentCUDAStream().stream();
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] {
causal_conv1d_fwd_cuda<input_t, weight_t>(params, stream);
});
}
void causal_conv1d_update(const at::Tensor &x,
const at::Tensor &conv_state,
const at::Tensor &weight,
const std::optional<at::Tensor> &bias_,
bool silu_activation,
const std::optional<at::Tensor> &cache_seqlens_,
const std::optional<at::Tensor> &conv_state_indices_,
// used to identify padding entries if cache_indices provided
// in case of padding, the kernel will return early
int64_t pad_slot_id) {
auto input_type = x.scalar_type();
auto weight_type = weight.scalar_type();
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
TORCH_CHECK(weight_type == input_type, "weight type must equal to input type, other variations are disabled due to binary size limitations");
TORCH_CHECK(conv_state.scalar_type() == input_type);
TORCH_CHECK(x.is_cuda());
TORCH_CHECK(conv_state.is_cuda());
TORCH_CHECK(weight.is_cuda());
const auto sizes = x.sizes();
const int batch_size = sizes[0];
const int dim = sizes[1];
const int seqlen = sizes[2];
const int width = weight.size(-1);
const int conv_state_len = conv_state.size(2);
TORCH_CHECK(conv_state_len >= width - 1);
CHECK_SHAPE(x, batch_size, dim, seqlen);
CHECK_SHAPE(weight, dim, width);
TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");
if (bias_.has_value()) {
auto bias = bias_.value();
TORCH_CHECK(bias.scalar_type() == weight_type);
TORCH_CHECK(bias.is_cuda());
TORCH_CHECK(bias.stride(-1) == 1);
CHECK_SHAPE(bias, dim);
}
at::Tensor out = x;
ConvParamsBase params;
set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
bias_,
silu_activation,
pad_slot_id);
params.conv_state_ptr = conv_state.data_ptr();
params.conv_state_len = conv_state_len;
// All stride are in elements, not bytes.
params.conv_state_batch_stride = conv_state.stride(0);
params.conv_state_c_stride = conv_state.stride(1);
params.conv_state_l_stride = conv_state.stride(2);
if (cache_seqlens_.has_value()) {
auto cache_seqlens = cache_seqlens_.value();
TORCH_CHECK(cache_seqlens.scalar_type() == torch::kInt32);
TORCH_CHECK(cache_seqlens.is_cuda());
TORCH_CHECK(cache_seqlens.stride(-1) == 1);
CHECK_SHAPE(cache_seqlens, batch_size);
params.cache_seqlens = cache_seqlens.data_ptr<int32_t>();
} else {
params.cache_seqlens = nullptr;
}
if (conv_state_indices_.has_value()) {
auto conv_state_indices = conv_state_indices_.value();
TORCH_CHECK(conv_state_indices.scalar_type() == torch::kInt32)
TORCH_CHECK(conv_state_indices.is_cuda());
TORCH_CHECK(conv_state_indices.stride(0) == 1)
CHECK_SHAPE(conv_state_indices, batch_size);
int conv_state_entries = conv_state.size(0);
CHECK_SHAPE(conv_state, conv_state_entries, dim, conv_state_len);
params.conv_state_indices_ptr = conv_state_indices.data_ptr<int32_t>();
} else {
CHECK_SHAPE(conv_state, batch_size, dim, conv_state_len);
params.conv_state_indices_ptr = nullptr;
}
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
auto stream = at::cuda::getCurrentCUDAStream().stream();
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] {
causal_conv1d_update_cuda<input_t, weight_t>(params, stream);
});
}
template<int kNThreads_, int kWidth_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
struct Causal_conv1d_fwd_kernel_traits {
using input_t = input_t_;
using weight_t = weight_t_;
static constexpr int kNThreads = kNThreads_;
static constexpr int kWidth = kWidth_;
static constexpr int kNBytes = sizeof(input_t);
static_assert(kNBytes == 2 || kNBytes == 4);
static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
static_assert(kWidth <= kNElts);
static constexpr bool kIsVecLoad = kIsVecLoad_;
using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNElts, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, 1, cub::BLOCK_LOAD_DIRECT>;
using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNElts, cub::BLOCK_STORE_WARP_TRANSPOSE>;
using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, 1, cub::BLOCK_STORE_DIRECT>;
static constexpr int kSmemIOSize = kIsVecLoad
? 0
: custom_max({sizeof(typename BlockLoadT::TempStorage), sizeof(typename BlockStoreT::TempStorage)});
static constexpr int kSmemExchangeSize = kNThreads * kNBytes * kNElts;
static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize;
};
template<typename Ktraits>
__global__ __launch_bounds__(Ktraits::kNThreads)
void causal_conv1d_fwd_kernel(ConvParamsBase params) {
constexpr int kWidth = Ktraits::kWidth;
constexpr int kNThreads = Ktraits::kNThreads;
constexpr int kNElts = Ktraits::kNElts;
constexpr bool kIsVecLoad = Ktraits::kIsVecLoad;
using input_t = typename Ktraits::input_t;
using vec_t = typename Ktraits::vec_t;
using weight_t = typename Ktraits::weight_t;
// Shared memory.
extern __shared__ char smem_[];
auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_);
auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_);
vec_t *smem_exchange = reinterpret_cast<vec_t *>(smem_ + Ktraits::kSmemIOSize);
const bool kVarlen = params.query_start_loc_ptr != nullptr;
const int tidx = threadIdx.x;
const int batch_id = blockIdx.x;
const int channel_id = blockIdx.y;
const int *query_start_loc = kVarlen ? reinterpret_cast<int *>(params.query_start_loc_ptr) : nullptr;
const int sequence_start_index = kVarlen ? query_start_loc[batch_id] : batch_id;
const int seqlen = kVarlen ? query_start_loc[batch_id + 1] - sequence_start_index : params.seqlen;
input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + sequence_start_index * params.x_batch_stride
+ channel_id * params.x_c_stride;
weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + sequence_start_index * params.out_batch_stride
+ channel_id * params.out_c_stride;
float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);
bool has_initial_state = params.has_initial_state_ptr == nullptr ? false
: reinterpret_cast<bool *>(params.has_initial_state_ptr)[batch_id];
int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr
: reinterpret_cast<int *>(params.cache_indices_ptr);
int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id];
// cache_index == params.pad_slot_id is defined as padding, so we exit early
if (cache_index == params.pad_slot_id){
return;
}
input_t *conv_states = params.conv_states_ptr == nullptr ? nullptr
: reinterpret_cast<input_t *>(params.conv_states_ptr) + cache_index * params.conv_states_batch_stride + channel_id * params.conv_states_c_stride;
// Thread 0 will load the last elements of the previous chunk, so we initialize those to 0.
if (tidx == 0) {
input_t initial_state[kNElts] = {0};
if (has_initial_state) {
#pragma unroll
for (int w = 0; w < kWidth - 1; ++w){ initial_state[kNElts - 1 - (kWidth - 2) + w ] = conv_states[w]; }
}
smem_exchange[kNThreads - 1] = reinterpret_cast<vec_t *>(initial_state)[0];
}
float weight_vals[kWidth];
#pragma unroll
for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }
constexpr int kChunkSize = kNThreads * kNElts;
const int n_chunks = (seqlen + kChunkSize - 1) / kChunkSize;
for (int chunk = 0; chunk < n_chunks; ++chunk) {
input_t x_vals_load[2 * kNElts] = {0};
if constexpr(kIsVecLoad) {
typename Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast<vec_t*>(x), *reinterpret_cast<vec_t (*)[1]>(&x_vals_load[kNElts]), (seqlen - chunk * kChunkSize) / kNElts);
} else {
__syncthreads();
typename Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast<input_t (*)[kNElts]>(&x_vals_load[kNElts]), seqlen - chunk * kChunkSize);
}
x += kChunkSize;
__syncthreads();
// Thread kNThreads - 1 don't write yet, so that thread 0 can read
// the last elements of the previous chunk.
if (tidx < kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1]; }
__syncthreads();
reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[tidx > 0 ? tidx - 1 : kNThreads - 1];
__syncthreads();
// Now thread kNThreads - 1 can write the last elements of the current chunk.
if (tidx == kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1]; }
float x_vals[2 * kNElts];
#pragma unroll
for (int i = 0; i < 2 * kNElts; ++i) { x_vals[i] = float(x_vals_load[i]); }
float out_vals[kNElts];
#pragma unroll
for (int i = 0; i < kNElts; ++i) {
out_vals[i] = bias_val;
#pragma unroll
for (int w = 0; w < kWidth; ++w) {
out_vals[i] += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)];
}
}
if (params.silu_activation) {
#pragma unroll
for (int i = 0; i < kNElts; ++i) {
out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i]));
}
}
input_t out_vals_store[kNElts];
#pragma unroll
for (int i = 0; i < kNElts; ++i) { out_vals_store[i] = out_vals[i]; }
if constexpr(kIsVecLoad) {
typename Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast<vec_t*>(out), reinterpret_cast<vec_t (&)[1]>(out_vals_store), (seqlen - chunk * kChunkSize) / kNElts);
} else {
typename Ktraits::BlockStoreT(smem_store).Store(out, out_vals_store, seqlen - chunk * kChunkSize);
}
out += kChunkSize;
int final_state_position = ((seqlen - (kWidth - 1)) - (n_chunks - 1) * kChunkSize);
// in case the final state is separated between the last "smem_exchange" and
// and the one before it (chunk = n_chunks - 1 and chunk = n_chunks - 2),
// (which occurs when `final_state_position` is a non-positive index)
// we load the correct data from smem_exchange from both chunks, the last chunk iteration and the one before it
if (conv_states != nullptr && final_state_position < 0 && seqlen > kWidth){
input_t vals_load[kNElts] = {0};
if ((chunk == n_chunks - 2) && (tidx == kNThreads - 1)){
// chunk = n_chunks - 2, a segment of the final state sits in the last index
reinterpret_cast<vec_t *>(vals_load)[0] = smem_exchange[kNThreads - 1];
#pragma unroll
for (int w = 0; w < -final_state_position; ++w){
conv_states[w] = vals_load[kNElts + final_state_position + w];
}
}
if ((chunk == n_chunks - 1) && tidx == 0){
// chunk = n_chunks - 1, the second segment of the final state first positions
reinterpret_cast<vec_t *>(vals_load)[0] = smem_exchange[0];
for (int w = -final_state_position; w < kWidth - 1; ++w){
conv_states[w] = vals_load[w + final_state_position];
}
return;
}
}
}
// Final state is stored in the smem_exchange last token slot,
// in case seqlen < kWidth, we would need to take the final state from the
// initial state which is stored in conv_states
// in case seqlen > kWidth, we would need to load the last kWidth - 1 data
// and load it into conv_state accordingly
int last_thread = ((seqlen - (kWidth - 1)) - (n_chunks - 1) * kChunkSize) / kNElts;
if (conv_states != nullptr && tidx == last_thread) {
input_t x_vals_load[kNElts * 2] = {0};
// in case we are on the first kWidth tokens
if (last_thread == 0 && seqlen < kWidth){
// Need to take the initial state
reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[0];
const int offset = seqlen - (kWidth - 1);
#pragma unroll
for (int w = 0; w < kWidth - 1; ++w){
// pad the existing state
if ((w - seqlen) >= 0 && has_initial_state) { conv_states[w - seqlen] = conv_states[w]; }
else if ((w - seqlen) >= 0 && !has_initial_state) { conv_states[w - seqlen] = input_t(0.0f); }
}
#pragma unroll
for (int w = 0; w < kWidth - 1; ++w){
if (offset + w >= 0)
conv_states[w] = x_vals_load[offset + w ];
}
}
else {
// in case the final state is in between the threads data
const int offset = ((seqlen - (kWidth - 1)) % (kNElts));
if ((offset + kWidth - 2) >= kNElts && (last_thread + 1 < kNThreads)){
// In case last_thread == kNThreads - 1, accessing last_thread + 1 will result in a
// illegal access error on H100.
// Therefore, we access last_thread + 1, only if the final state data sits there
reinterpret_cast<vec_t *>(x_vals_load)[1] = smem_exchange[last_thread + 1];
}
reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[last_thread];
#pragma unroll
for (int w = 0; w < kWidth - 1; ++w){
conv_states[w] = x_vals_load[offset + w ];
}
}
}
}
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
void causal_conv1d_fwd_launch(ConvParamsBase &params, cudaStream_t stream) {
static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8;
const bool kVarlen = params.query_start_loc_ptr != nullptr;
BOOL_SWITCH(params.seqlen % kNElts == 0 && !kVarlen, kIsVecLoad, [&] {
using Ktraits = Causal_conv1d_fwd_kernel_traits<kNThreads, kWidth, kIsVecLoad, input_t, weight_t>;
constexpr int kSmemSize = Ktraits::kSmemSize;
dim3 grid(params.batch, params.dim);
auto kernel = &causal_conv1d_fwd_kernel<Ktraits>;
if (kSmemSize >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
(void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
std::cerr << "Warning (causal_conv1d fwd launch): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl;
}
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
}
template<typename input_t, typename weight_t>
void causal_conv1d_fwd_cuda(ConvParamsBase &params, cudaStream_t stream) {
if (params.width == 2) {
causal_conv1d_fwd_launch<128, 2, input_t, weight_t>(params, stream);
} else if (params.width == 3) {
causal_conv1d_fwd_launch<128, 3, input_t, weight_t>(params, stream);
} else if (params.width == 4) {
causal_conv1d_fwd_launch<128, 4, input_t, weight_t>(params, stream);
}
}
template void causal_conv1d_fwd_cuda<float, float>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_fwd_cuda<at::Half, at::Half>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_fwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
template<int kNThreads_, int kWidth_, typename input_t_, typename weight_t_>
struct Causal_conv1d_update_kernel_traits {
using input_t = input_t_;
using weight_t = weight_t_;
static constexpr int kNThreads = kNThreads_;
static constexpr int kWidth = kWidth_;
static constexpr int kNBytes = sizeof(input_t);
static_assert(kNBytes == 2 || kNBytes == 4);
};
template<typename Ktraits, bool kIsCircularBuffer>
__global__ __launch_bounds__(Ktraits::kNThreads)
void causal_conv1d_update_kernel(ConvParamsBase params) {
constexpr int kWidth = Ktraits::kWidth;
constexpr int kNThreads = Ktraits::kNThreads;
using input_t = typename Ktraits::input_t;
using weight_t = typename Ktraits::weight_t;
const int tidx = threadIdx.x;
const int batch_id = blockIdx.x;
const int channel_id = blockIdx.y * kNThreads + tidx;
if (channel_id >= params.dim) return;
input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
+ channel_id * params.x_c_stride;
// If params.conv_state_batch_indices is set, then the conv state is gathered from the conv state tensor
// along the batch axis. Otherwise, the conv state coordinate is the same as the batch id.
const int conv_state_batch_coord = params.conv_state_indices_ptr == nullptr
? batch_id
: params.conv_state_indices_ptr[batch_id];
// conv_state_batch_coord == params.pad_slot_id is defined as padding so we exit early
if (conv_state_batch_coord == params.pad_slot_id){
return;
}
input_t *conv_state = reinterpret_cast<input_t *>(params.conv_state_ptr)
+ conv_state_batch_coord * params.conv_state_batch_stride
+ channel_id * params.conv_state_c_stride;
weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
+ channel_id * params.out_c_stride;
float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);
int state_len = params.conv_state_len;
int advance_len = params.seqlen;
int cache_seqlen = kIsCircularBuffer ? params.cache_seqlens[batch_id] % state_len : 0;
int update_idx = cache_seqlen - (kWidth - 1);
update_idx = update_idx < 0 ? update_idx + state_len : update_idx;
float weight_vals[kWidth] = {0};
#pragma unroll
for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }
float x_vals[kWidth] = {0};
if constexpr (!kIsCircularBuffer) {
#pragma unroll 2
for (int i = 0; i < state_len - advance_len - (kWidth - 1); ++i) {
conv_state[i * params.conv_state_l_stride] = conv_state[(i + advance_len) * params.conv_state_l_stride];
}
#pragma unroll
for (int i = 0; i < kWidth - 1; ++i) {
input_t state_val = conv_state[(state_len - (kWidth - 1) + i) * params.conv_state_l_stride];
if (i < advance_len + (kWidth - 1) && state_len - advance_len - (kWidth - 1) + i >= 0) {
conv_state[(state_len - advance_len - (kWidth - 1) + i) * params.conv_state_l_stride] = state_val;
}
x_vals[i] = float(state_val);
}
} else {
#pragma unroll
for (int i = 0; i < kWidth - 1; ++i, update_idx = update_idx + 1 >= state_len ? update_idx + 1 - state_len : update_idx + 1) {
input_t state_val = conv_state[update_idx * params.conv_state_l_stride];
x_vals[i] = float(state_val);
}
}
#pragma unroll 2
for (int i = 0; i < params.seqlen; ++i) {
input_t x_val = x[i * params.x_l_stride];
if constexpr (!kIsCircularBuffer) {
if (i < advance_len && state_len - advance_len + i >= 0) {
conv_state[(state_len - advance_len + i) * params.conv_state_l_stride] = x_val;
}
} else {
conv_state[update_idx * params.conv_state_l_stride] = x_val;
++update_idx;
update_idx = update_idx >= state_len ? update_idx - state_len : update_idx;
}
x_vals[kWidth - 1] = float(x_val);
float out_val = bias_val;
#pragma unroll
for (int j = 0; j < kWidth; ++j) { out_val += weight_vals[j] * x_vals[j]; }
if (params.silu_activation) { out_val = out_val / (1 + expf(-out_val)); }
out[i * params.out_l_stride] = input_t(out_val);
// Shift the input buffer by 1
#pragma unroll
for (int i = 0; i < kWidth - 1; ++i) { x_vals[i] = x_vals[i + 1]; }
}
}
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
void causal_conv1d_update_launch(ConvParamsBase &params, cudaStream_t stream) {
using Ktraits = Causal_conv1d_update_kernel_traits<kNThreads, kWidth, input_t, weight_t>;
dim3 grid(params.batch, (params.dim + kNThreads - 1) / kNThreads);
auto kernel = params.cache_seqlens == nullptr
? &causal_conv1d_update_kernel<Ktraits, false>
: &causal_conv1d_update_kernel<Ktraits, true>;
kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
template<typename input_t, typename weight_t>
void causal_conv1d_update_cuda(ConvParamsBase &params, cudaStream_t stream) {
if (params.width == 2) {
causal_conv1d_update_launch<64, 2, input_t, weight_t>(params, stream);
} else if (params.width == 3) {
causal_conv1d_update_launch<64, 3, input_t, weight_t>(params, stream);
} else if (params.width == 4) {
causal_conv1d_update_launch<64, 4, input_t, weight_t>(params, stream);
}
}
template void causal_conv1d_update_cuda<float, float>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_update_cuda<at::Half, at::Half>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_update_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);

View File

@ -1,159 +0,0 @@
/******************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/
// clang-format off
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d.h
#pragma once
#include <cuda_bf16.h>
#include <cuda_fp16.h>
////////////////////////////////////////////////////////////////////////////////////////////////////
struct ConvParamsBase {
using index_t = uint32_t;
int batch, dim, seqlen, width;
int64_t pad_slot_id;
bool silu_activation;
index_t x_batch_stride;
index_t x_c_stride;
index_t x_l_stride;
index_t weight_c_stride;
index_t weight_width_stride;
index_t out_batch_stride;
index_t out_c_stride;
index_t out_l_stride;
int conv_state_len;
index_t conv_state_batch_stride;
index_t conv_state_c_stride;
index_t conv_state_l_stride;
// Common data pointers.
void *__restrict__ x_ptr;
void *__restrict__ weight_ptr;
void *__restrict__ bias_ptr;
void *__restrict__ out_ptr;
void *__restrict__ conv_state_ptr;
void *__restrict__ query_start_loc_ptr;
void *__restrict__ has_initial_state_ptr;
void *__restrict__ cache_indices_ptr;
int32_t *__restrict__ cache_seqlens;
// For the continuous batching case. Makes it so that the mamba state for
// the current batch doesn't need to be a contiguous tensor.
int32_t *__restrict__ conv_state_indices_ptr;
void *__restrict__ seq_idx_ptr;
// No __restrict__ since initial_states could be the same as final_states.
void * initial_states_ptr;
index_t initial_states_batch_stride;
index_t initial_states_l_stride;
index_t initial_states_c_stride;
void * final_states_ptr;
index_t final_states_batch_stride;
index_t final_states_l_stride;
index_t final_states_c_stride;
void * conv_states_ptr;
index_t conv_states_batch_stride;
index_t conv_states_l_stride;
index_t conv_states_c_stride;
};
#ifndef USE_ROCM
#include <cuda_bf16.h>
template<typename T>
__device__ inline T shuffle_xor(T val, int offset) {
return __shfl_xor_sync(uint32_t(-1), val, offset);
}
constexpr size_t custom_max(std::initializer_list<size_t> ilist)
{
return std::max(ilist);
}
template<typename T>
constexpr T constexpr_min(T a, T b) {
return std::min(a, b);
}
#else
#include <hip/hip_bf16.h>
template<typename T>
__device__ inline T shuffle_xor(T val, int offset) {
return __shfl_xor(val, offset);
}
constexpr size_t custom_max(std::initializer_list<size_t> ilist)
{
return *std::max_element(ilist.begin(), ilist.end());
}
template<typename T>
constexpr T constexpr_min(T a, T b) {
return a < b ? a : b;
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int BYTES> struct BytesToType {};
template<> struct BytesToType<16> {
using Type = uint4;
static_assert(sizeof(Type) == 16);
};
template<> struct BytesToType<8> {
using Type = uint64_t;
static_assert(sizeof(Type) == 8);
};
template<> struct BytesToType<4> {
using Type = uint32_t;
static_assert(sizeof(Type) == 4);
};
template<> struct BytesToType<2> {
using Type = uint16_t;
static_assert(sizeof(Type) == 2);
};
template<> struct BytesToType<1> {
using Type = uint8_t;
static_assert(sizeof(Type) == 1);
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct SumOp {
__device__ inline T operator()(T const & x, T const & y) { return x + y; }
};
template<int THREADS>
struct Allreduce {
static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
template<typename T, typename Operator>
static __device__ inline T run(T x, Operator &op) {
constexpr int OFFSET = THREADS / 2;
x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
return Allreduce<OFFSET>::run(x, op);
}
};
template<>
struct Allreduce<2> {
template<typename T, typename Operator>
static __device__ inline T run(T x, Operator &op) {
x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
return x;
}
};

View File

@ -1,28 +0,0 @@
// Inspired by
// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
// clang-format off
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/static_switch.h
#pragma once
/// @param COND - a boolean expression to switch by
/// @param CONST_NAME - a name given for the constexpr bool variable.
/// @param ... - code to execute for true and false
///
/// Usage:
/// ```
/// BOOL_SWITCH(flag, BoolConst, [&] {
/// some_function<BoolConst>(...);
/// });
/// ```
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
[&] { \
if (COND) { \
static constexpr bool CONST_NAME = true; \
return __VA_ARGS__(); \
} else { \
static constexpr bool CONST_NAME = false; \
return __VA_ARGS__(); \
} \
}()

View File

@ -326,22 +326,6 @@ void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta,
const std::optional<torch::Tensor>& has_initial_state, const std::optional<torch::Tensor>& has_initial_state,
const torch::Tensor& ssm_states, int64_t pad_slot_id); const torch::Tensor& ssm_states, int64_t pad_slot_id);
void causal_conv1d_update(const at::Tensor& x, const at::Tensor& conv_state,
const at::Tensor& weight,
const std::optional<at::Tensor>& bias_,
bool silu_activation,
const std::optional<at::Tensor>& cache_seqlens_,
const std::optional<at::Tensor>& conv_state_indices_,
int64_t pad_slot_id);
void causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
const std::optional<at::Tensor>& bias_,
const std::optional<at::Tensor>& conv_states,
const std::optional<at::Tensor>& query_start_loc,
const std::optional<at::Tensor>& cache_indices,
const std::optional<at::Tensor>& has_initial_state,
bool silu_activation, int64_t pad_slot_id);
using fptr_t = int64_t; using fptr_t = int64_t;
fptr_t init_custom_ar(const std::vector<int64_t>& fake_ipc_ptrs, fptr_t init_custom_ar(const std::vector<int64_t>& fake_ipc_ptrs,
torch::Tensor& rank_data, int64_t rank, torch::Tensor& rank_data, int64_t rank,

View File

@ -594,28 +594,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"int pad_slot_id) -> ()"); "int pad_slot_id) -> ()");
ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd); ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
ops.def(
"causal_conv1d_update(Tensor! x,"
"Tensor! conv_state,"
"Tensor! weight,"
"Tensor? bias_,"
"bool silu_activation,"
"Tensor? cache_seqlens_,"
"Tensor? conv_state_indices,"
"int pad_slot_id) -> ()");
ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update);
ops.def(
"causal_conv1d_fwd(Tensor! x, Tensor! weight,"
"Tensor? bias_,"
"Tensor!? conv_states,"
"Tensor? query_start_loc,"
"Tensor? cache_indices,"
"Tensor? has_initial_state,"
"bool silu_activation,"
"int pad_slot_id) -> ()");
ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
#ifndef USE_ROCM #ifndef USE_ROCM
// reorder weight for AllSpark Ampere W8A16 Fused Gemm kernel // reorder weight for AllSpark Ampere W8A16 Fused Gemm kernel
ops.def( ops.def(

View File

@ -6,9 +6,8 @@ from typing import Optional
import pytest import pytest
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from einops import rearrange
from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops # noqa: F401
from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update) causal_conv1d_fn, causal_conv1d_update)
@ -144,79 +143,6 @@ def causal_conv1d_opcheck_fn(x: torch.Tensor,
x = x.contiguous() x = x.contiguous()
bias = bias.contiguous() if bias is not None else None bias = bias.contiguous() if bias is not None else None
opcheck(torch.ops._C.causal_conv1d_fwd,
(x, weight, bias, conv_states, cu_seq_len, cache_indices,
has_initial_state, activation in ["silu", "swish"], pad_slot_id))
@pytest.mark.parametrize("itype", [torch.bfloat16, torch.float])
@pytest.mark.parametrize("silu_activation", [True])
@pytest.mark.parametrize("has_bias", [True])
@pytest.mark.parametrize("has_initial_state", [True, False])
@pytest.mark.parametrize("width", [4])
@pytest.mark.parametrize(
'seqlen', [1, 8, 16, 32, 64, 128, 256, 512, 784, 1024, 1025, 2048, 4096])
@pytest.mark.parametrize('dim', [64])
@pytest.mark.parametrize('batch', [1])
def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation,
has_initial_state, itype):
device = "cuda"
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
if itype == torch.bfloat16:
rtol, atol = 1e-2, 5e-2
# set seed
current_platform.seed_everything(0)
x = torch.randn(batch, dim, seqlen, device=device,
dtype=itype).contiguous()
weight = torch.randn(dim, width, device=device, dtype=itype)
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
if has_initial_state:
initial_states = torch.randn(batch,
dim,
width - 1,
device=device,
dtype=itype)
has_initial_state_tensor = torch.ones(batch,
dtype=torch.bool,
device=x.device)
else:
initial_states = None
has_initial_state_tensor = None
x_ref = x.clone()
weight_ref = weight.clone()
bias_ref = bias.clone() if bias is not None else None
initial_states_ref = initial_states.clone(
) if initial_states is not None else None
activation = None if not silu_activation else "silu"
out = causal_conv1d_fn(x,
weight,
bias,
activation=activation,
conv_states=initial_states,
has_initial_state=has_initial_state_tensor)
out_ref, final_states_ref = causal_conv1d_ref(
x_ref,
weight_ref,
bias_ref,
initial_states=initial_states_ref,
return_final_states=True,
activation=activation)
if has_initial_state:
assert initial_states is not None and final_states_ref is not None
assert torch.allclose(initial_states,
final_states_ref,
rtol=rtol,
atol=atol)
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
causal_conv1d_opcheck_fn(x,
weight,
bias,
activation=activation,
conv_states=initial_states,
has_initial_state=has_initial_state_tensor)
@pytest.mark.parametrize("itype", [torch.bfloat16]) @pytest.mark.parametrize("itype", [torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [False, True]) @pytest.mark.parametrize("silu_activation", [False, True])
@ -255,22 +181,19 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation,
assert torch.equal(conv_state, conv_state_ref) assert torch.equal(conv_state, conv_state_ref)
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
opcheck(torch.ops._C.causal_conv1d_update,
(x, conv_state, weight, bias, activation
in ["silu", "swish"], None, None, PAD_SLOT_ID))
@pytest.mark.parametrize("itype", @pytest.mark.parametrize("itype",
[torch.float32, torch.float16, torch.bfloat16]) [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [False, True]) @pytest.mark.parametrize("silu_activation", [False, True])
@pytest.mark.parametrize("has_bias", [False, True]) @pytest.mark.parametrize("has_bias", [False, True])
@pytest.mark.parametrize("seqlen", [1, 4, 5]) @pytest.mark.parametrize("seqlen", [1, 3])
@pytest.mark.parametrize("width", [2, 3, 4]) @pytest.mark.parametrize("width", [3, 4])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) @pytest.mark.parametrize("dim", [2048 + 16, 4096])
# tests correctness in case subset of the sequences are padded # tests correctness in case subset of the sequences are padded
@pytest.mark.parametrize("with_padding", [True, False]) @pytest.mark.parametrize("with_padding", [True, False])
def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width, @pytest.mark.parametrize("batch_size", [3])
seqlen, has_bias, def test_causal_conv1d_update_with_batch_gather(batch_size, with_padding, dim,
width, seqlen, has_bias,
silu_activation, itype): silu_activation, itype):
device = "cuda" device = "cuda"
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
@ -280,12 +203,15 @@ def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width,
# set seed # set seed
current_platform.seed_everything(0) current_platform.seed_everything(0)
batch_size = 3
padding = 5 if with_padding else 0 padding = 5 if with_padding else 0
padded_batch_size = batch_size + padding padded_batch_size = batch_size + padding
# total_entries = number of cache line
total_entries = 10 * batch_size total_entries = 10 * batch_size
x = torch.randn(padded_batch_size, dim, 1, device=device, dtype=itype) # x will be (batch, dim, seqlen) with contiguous along dim-axis
x = torch.randn(padded_batch_size, seqlen, dim, device=device,
dtype=itype).transpose(1, 2)
x_ref = x.clone() x_ref = x.clone()
conv_state_indices = torch.randperm(total_entries)[:batch_size].to( conv_state_indices = torch.randperm(total_entries)[:batch_size].to(
@ -300,17 +226,22 @@ def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width,
[PAD_SLOT_ID] * padding, dtype=torch.int32, device=device) [PAD_SLOT_ID] * padding, dtype=torch.int32, device=device)
], ],
dim=0) dim=0)
# conv_state will be (cache_lines, dim, state_len)
# with contiguous along dim-axis
conv_state = torch.randn(total_entries, conv_state = torch.randn(total_entries,
dim,
width - 1, width - 1,
dim,
device=device, device=device,
dtype=itype) dtype=itype).transpose(1, 2)
conv_state_for_padding_test = conv_state.clone() conv_state_for_padding_test = conv_state.clone()
weight = torch.randn(dim, width, device=device, dtype=itype) weight = torch.randn(dim, width, device=device, dtype=itype)
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
conv_state_ref = conv_state[conv_state_indices, :].detach().clone() conv_state_ref = conv_state[conv_state_indices, :].detach().clone()
activation = None if not silu_activation else "silu" activation = None if not silu_activation else "silu"
out = causal_conv1d_update(x, out = causal_conv1d_update(x,
conv_state, conv_state,
weight, weight,
@ -325,26 +256,21 @@ def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width,
activation=activation) activation=activation)
assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref) assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref)
assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol)
assert torch.equal(conv_state[unused_states_bool], assert torch.equal(conv_state[unused_states_bool],
conv_state_for_padding_test[unused_states_bool]) conv_state_for_padding_test[unused_states_bool])
assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol)
opcheck(torch.ops._C.causal_conv1d_update,
(x, conv_state, weight, bias, activation
in ["silu", "swish"], None, padded_state_indices, PAD_SLOT_ID))
@pytest.mark.parametrize("itype", [torch.bfloat16]) @pytest.mark.parametrize("itype", [torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [True]) @pytest.mark.parametrize("silu_activation", [True])
@pytest.mark.parametrize("has_bias", [True]) @pytest.mark.parametrize("has_bias", [True])
@pytest.mark.parametrize("width", [4]) @pytest.mark.parametrize("width", [4])
@pytest.mark.parametrize( @pytest.mark.parametrize('seqlen', [8, 30, 249, 2049, 4096])
'seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 2049, 4096])
@pytest.mark.parametrize('dim', [64, 4096]) @pytest.mark.parametrize('dim', [64, 4096])
# tests correctness in case subset of the sequences are padded
@pytest.mark.parametrize('with_padding', [True, False]) @pytest.mark.parametrize('with_padding', [True, False])
def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias, @pytest.mark.parametrize('batch', [4, 10])
silu_activation, itype): def test_causal_conv1d_varlen(batch, with_padding, dim, seqlen, width,
has_bias, silu_activation, itype):
device = "cuda" device = "cuda"
torch.cuda.empty_cache() torch.cuda.empty_cache()
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
@ -353,14 +279,13 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,
# set seed # set seed
current_platform.seed_everything(0) current_platform.seed_everything(0)
seqlens = [] seqlens = []
batch_size = 4 batch_size = batch
if seqlen < 10:
batch_size = 1
padding = 3 if with_padding else 0 padding = 3 if with_padding else 0
padded_batch_size = batch_size + padding padded_batch_size = batch_size + padding
nsplits = padded_batch_size - 1 nsplits = padded_batch_size - 1
eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values
seqlens.append( seqlens.append(
torch.diff( torch.diff(
torch.cat( torch.cat(
@ -373,19 +298,22 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,
cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32) cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32)
cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum], cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum],
dim=0) dim=0)
x = torch.randn(1, 4096 + dim + 64, seqlen, device=device, x = rearrange(
dtype=itype)[:, 4096:4096 + dim, :] torch.randn(1, seqlen, 4096 + dim + 64, device=device, dtype=itype),
"b s d -> b d s")[:, 4096:4096 + dim, :]
weight = torch.randn(dim, width, device=device, dtype=itype) weight = torch.randn(dim, width, device=device, dtype=itype)
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
x_ref = x.clone() x_ref = x.clone()
weight_ref = weight.clone() weight_ref = weight.clone()
bias_ref = bias.clone() if bias is not None else None bias_ref = bias.clone() if bias is not None else None
activation = None if not silu_activation else "silu" activation = None if not silu_activation else "silu"
final_states = torch.randn(total_entries, final_states = torch.randn(total_entries,
dim,
width - 1, width - 1,
dim,
device=x.device, device=x.device,
dtype=x.dtype) dtype=x.dtype).transpose(1, 2)
final_states_ref = final_states.clone() final_states_ref = final_states.clone()
has_initial_states = torch.randint(0, has_initial_states = torch.randint(0,
2, (cumsum.shape[0] - 1, ), 2, (cumsum.shape[0] - 1, ),
@ -400,10 +328,16 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,
[PAD_SLOT_ID] * padding, dtype=torch.int32, device=device), [PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
], ],
dim=-1) dim=-1)
out = causal_conv1d_fn(x.squeeze(0),
weight,
bias=bias,
conv_states=final_states,
query_start_loc=cumsum.cuda(),
cache_indices=padded_state_indices,
has_initial_state=has_initial_states,
activation=activation,
pad_slot_id=PAD_SLOT_ID)
out = causal_conv1d_fn(x.squeeze(0), weight, bias, cumsum.cuda(),
padded_state_indices, has_initial_states,
final_states, activation, PAD_SLOT_ID)
out_ref = [] out_ref = []
out_ref_b = [] out_ref_b = []
@ -426,13 +360,9 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,
out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=2)) out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=2))
out_ref_tensor = torch.cat(out_ref, dim=0) out_ref_tensor = torch.cat(out_ref, dim=0)
unpadded_out = out[:, :out_ref_tensor.shape[-1]]
assert torch.allclose(unpadded_out, out_ref_tensor, rtol=rtol, atol=atol)
assert torch.allclose(final_states[state_indices], assert torch.allclose(final_states[state_indices],
final_states_ref[state_indices], final_states_ref[state_indices],
rtol=rtol, rtol=rtol,
atol=atol) atol=atol)
unpadded_out = out[:, :out_ref_tensor.shape[-1]]
causal_conv1d_opcheck_fn(x.squeeze(0), weight, bias, cumsum.cuda(), assert torch.allclose(unpadded_out, out_ref_tensor, rtol=rtol, atol=atol)
padded_state_indices, has_initial_states,
final_states, activation)

View File

@ -6,11 +6,11 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from einops import rearrange, repeat from einops import rearrange, repeat
from vllm.model_executor.layers.mamba.mamba2_metadata import (
_query_start_loc_to_chunk_indices_offsets)
from vllm.model_executor.layers.mamba.ops.ssd_combined import ( from vllm.model_executor.layers.mamba.ops.ssd_combined import (
mamba_chunk_scan_combined) mamba_chunk_scan_combined)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.v1.attention.backends.mamba_attn import (
_query_start_loc_to_chunk_indices_offsets)
# Added by the IBM Team, 2024 # Added by the IBM Team, 2024

View File

@ -1464,30 +1464,6 @@ def ggml_moe_get_block_size(quant_type: int) -> int:
# mamba # mamba
def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor,
bias_: Optional[torch.Tensor],
conv_states: Optional[torch.Tensor],
query_start_loc: Optional[torch.Tensor],
cache_indices: Optional[torch.Tensor],
has_initial_state: Optional[torch.Tensor],
silu_activation: bool, pad_slot_id: int):
torch.ops._C.causal_conv1d_fwd(x, weight, bias_, conv_states,
query_start_loc, cache_indices,
has_initial_state, silu_activation,
pad_slot_id)
def causal_conv1d_update(x: torch.Tensor, conv_state: torch.Tensor,
weight: torch.Tensor, bias_: Optional[torch.Tensor],
silu_activation: bool,
cache_seqlens: Optional[torch.Tensor],
conv_state_indices: Optional[torch.Tensor],
pad_slot_id: int):
torch.ops._C.causal_conv1d_update(x, conv_state, weight, bias_,
silu_activation, cache_seqlens,
conv_state_indices, pad_slot_id)
def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
B: torch.Tensor, C: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
D_: Optional[torch.Tensor], z_: Optional[torch.Tensor], D_: Optional[torch.Tensor], z_: Optional[torch.Tensor],

View File

@ -1,14 +1,18 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Union
import numpy as np
import torch import torch
from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.backends.placeholder_attn import ( from vllm.attention.backends.placeholder_attn import (
PlaceholderAttentionMetadata) PlaceholderAttentionMetadata)
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.v1.attention.backends.mamba_attn import (
Mamba2AttentionMetadata, _query_start_loc_to_chunk_indices_offsets)
@dataclass @dataclass
@ -21,6 +25,29 @@ class Mamba2Metadata:
seq_idx: torch.Tensor seq_idx: torch.Tensor
chunk_indices: torch.Tensor chunk_indices: torch.Tensor
chunk_offsets: torch.Tensor chunk_offsets: torch.Tensor
"""
With continuous batching layout of `x` in vLLM, to enable a Triton program
to handle a request in parallel, two supporting tensors are used
(batch_ptr, token_chunk_offset_ptr)
BLOCK_M = the # tokens to be handled by a Triton program
(can be customized for different hardware)
nums_dict:
tracks the data associated with a given value of BLOCK_M
BLOCK_M = #tokens handled by a Triton program
cu_seqlen: total tokens per batch
(used as flag to update other data at each new input)
batch_ptr: tracks batch-id handled by the Triton program
token_chunk_offset_ptr: tracks token group_idx handled by the Triton program
(Triton implementation of causal_conv1d handles parallelism in 3-axes
- feature-axis
- batch-axis
- sequence-axis)
"""
nums_dict: Optional[dict] = None
cu_seqlen: Optional[int] = None
batch_ptr: Optional[torch.tensor] = None
token_chunk_offset_ptr: Optional[torch.tensor] = None
def get_platform_metadata_classes() -> tuple[type[AttentionMetadata], ...]: def get_platform_metadata_classes() -> tuple[type[AttentionMetadata], ...]:
@ -38,45 +65,10 @@ def get_platform_metadata_classes() -> tuple[type[AttentionMetadata], ...]:
f"Unsupported platform for Mamba2: {current_platform.device_type}") f"Unsupported platform for Mamba2: {current_platform.device_type}")
def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor,
chunk_size: int,
total_seqlens: int):
cu_seqlens = query_start_loc[1:] # remove prepended 0
# outputs will have length expansion of chunks that do not divide
# chunk_size
N = math.ceil(total_seqlens / chunk_size) + (cu_seqlens[:-1] % chunk_size
> 0).sum()
chunk_indices = torch.arange(N,
dtype=torch.int,
device=query_start_loc.device)
chunk_offsets = torch.zeros((N, ),
dtype=torch.int,
device=query_start_loc.device)
p = 0 # num of insertions
for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]):
# if does not divide chunk_size, then there is one chunk insertion
p += (s % chunk_size > 0)
# get the dimensions
# - the + 1 for _e is to shift the boundary by one chunk
# - this shifting is not needed if chunk_size divides e
_s, _e = s // chunk_size + p, e // chunk_size + p + (e % chunk_size
> 0)
# adjust inidces and offsets
chunk_indices[_s:_e] -= p
chunk_offsets[_s] = s % chunk_size
return chunk_indices, chunk_offsets
def prepare_mamba2_metadata( def prepare_mamba2_metadata(
chunk_size: int, chunk_size: int,
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
mamba2_metadata=None,
) -> Mamba2Metadata: ) -> Mamba2Metadata:
# compute number of prefill and decode requests # compute number of prefill and decode requests
@ -96,12 +88,12 @@ def prepare_mamba2_metadata(
attn_metadata_instances = get_platform_metadata_classes() attn_metadata_instances = get_platform_metadata_classes()
if (isinstance(attn_metadata, attn_metadata_instances) if (isinstance(attn_metadata, attn_metadata_instances)
and attn_metadata.context_lens_tensor is not None): and attn_metadata.context_lens_tensor is not None):
has_initial_states = \ # precompute flag to avoid device syncs later in mamba2 layer
attn_metadata.context_lens_tensor[:num_prefills] > 0 #[batch,] # forwards
# precompute flag to avoid device syncs in mamba2 layer forwards
# prep is only needed for mamba2 ssd prefill processing # prep is only needed for mamba2 ssd prefill processing
prep_initial_states = torch.any(has_initial_states).item() has_initial_states = attn_metadata.context_lens_tensor > 0
prep_initial_states = torch.any(
has_initial_states[:num_prefills]).item()
query_start_loc = attn_metadata.query_start_loc[:num_prefills + 1] query_start_loc = attn_metadata.query_start_loc[:num_prefills + 1]
seq_idx = torch.repeat_interleave(torch.arange( seq_idx = torch.repeat_interleave(torch.arange(
num_prefills, dtype=torch.int32, device=query_start_loc.device), num_prefills, dtype=torch.int32, device=query_start_loc.device),
@ -117,9 +109,78 @@ def prepare_mamba2_metadata(
_query_start_loc_to_chunk_indices_offsets( _query_start_loc_to_chunk_indices_offsets(
query_start_loc, chunk_size, num_prefill_tokens) query_start_loc, chunk_size, num_prefill_tokens)
if mamba2_metadata is not None:
mamba2_metadata.has_initial_states = has_initial_states
mamba2_metadata.prep_initial_states = prep_initial_states
mamba2_metadata.chunk_size = chunk_size
mamba2_metadata.seq_idx = seq_idx
mamba2_metadata.chunk_indices = chunk_indices
mamba2_metadata.chunk_offsets = chunk_offsets
# We use 1 reset flag:
# * mamba2_metadata.cu_seqlen is None
# update config specific to (each input)
# (become available at first layer, e.g. conv_weights)
mamba2_metadata.cu_seqlen = None # suppose to be updated at each input
return mamba2_metadata
return Mamba2Metadata(has_initial_states=has_initial_states, return Mamba2Metadata(has_initial_states=has_initial_states,
prep_initial_states=prep_initial_states, prep_initial_states=prep_initial_states,
chunk_size=chunk_size, chunk_size=chunk_size,
seq_idx=seq_idx, seq_idx=seq_idx,
chunk_indices=chunk_indices, chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets) chunk_offsets=chunk_offsets)
def update_metadata(x: torch.Tensor, query_start_loc: torch.Tensor,
mamba2_metadata: Union[Mamba2Metadata,
Mamba2AttentionMetadata]):
"""
this is triggered upon handling a new input at the first layer
"""
dim, cu_seqlen = x.shape
mamba2_metadata.cu_seqlen = cu_seqlen
seqlens = np.diff(query_start_loc.to('cpu'))
nums_dict = {} # type: ignore
for BLOCK_M in [8]: # cover all BLOCK_M values
nums = -(-seqlens // BLOCK_M)
nums_dict[BLOCK_M] = {}
nums_dict[BLOCK_M]['nums'] = nums
nums_dict[BLOCK_M]['tot'] = nums.sum().item()
mlist = torch.from_numpy(np.repeat(np.arange(len(nums)), nums))
nums_dict[BLOCK_M]['mlist'] = mlist
mlist_len = len(nums_dict[BLOCK_M]['mlist'])
nums_dict[BLOCK_M]['mlist_len'] = mlist_len
MAX_NUM_PROGRAMS = max(1024, mlist_len) * 2
offsetlist = [] # type: ignore
for idx, num in enumerate(nums):
offsetlist.extend(range(num))
offsetlist = torch.tensor(offsetlist, dtype=torch.int32)
nums_dict[BLOCK_M]['offsetlist'] = offsetlist
if mamba2_metadata.batch_ptr is None:
# Update default value after class definition
#mamba2_metadata.MAX_NUM_PROGRAMS *= 2
mamba2_metadata.batch_ptr = torch.full((MAX_NUM_PROGRAMS, ),
PAD_SLOT_ID,
dtype=torch.int32,
device='cuda')
mamba2_metadata.token_chunk_offset_ptr = torch.full(
(MAX_NUM_PROGRAMS, ),
PAD_SLOT_ID,
dtype=torch.int32,
device='cuda')
else:
if mamba2_metadata.batch_ptr.nelement() < MAX_NUM_PROGRAMS:
mamba2_metadata.batch_ptr.resize_(MAX_NUM_PROGRAMS).fill_(
PAD_SLOT_ID)
mamba2_metadata.token_chunk_offset_ptr.resize_( # type: ignore
MAX_NUM_PROGRAMS).fill_(PAD_SLOT_ID)
mamba2_metadata.batch_ptr[0:mlist_len].copy_(mlist)
mamba2_metadata.token_chunk_offset_ptr[ # type: ignore
0:mlist_len].copy_(offsetlist)
nums_dict[BLOCK_M]['batch_ptr'] = mamba2_metadata.batch_ptr
nums_dict[BLOCK_M]['token_chunk_offset_ptr'] = (
mamba2_metadata.token_chunk_offset_ptr) # type: ignore
mamba2_metadata.nums_dict = nums_dict
return mamba2_metadata

View File

@ -159,7 +159,7 @@ class MambaMixer(CustomOp):
hidden_states = causal_conv1d_fn( hidden_states = causal_conv1d_fn(
hidden_states, hidden_states,
conv_weights, conv_weights,
self.conv1d.bias, bias=self.conv1d.bias,
activation=self.activation, activation=self.activation,
conv_states=mamba_cache_params.conv_state, conv_states=mamba_cache_params.conv_state,
has_initial_state=attn_metadata.context_lens_tensor > 0, has_initial_state=attn_metadata.context_lens_tensor > 0,

View File

@ -17,7 +17,8 @@ from vllm.forward_context import get_forward_context
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.mamba.mamba2_metadata import Mamba2Metadata from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata,
update_metadata)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update) causal_conv1d_fn, causal_conv1d_update)
from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
@ -458,9 +459,11 @@ class MambaMixer2(CustomOp):
if attn_metadata is not None: if attn_metadata is not None:
assert isinstance(attn_metadata, dict) assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix] attn_metadata = attn_metadata[self.prefix]
mamba2_metadata = attn_metadata
assert isinstance(attn_metadata, Mamba2AttentionMetadata) assert isinstance(attn_metadata, Mamba2AttentionMetadata)
self_kv_cache = self.kv_cache[forward_context.virtual_engine] self_kv_cache = self.kv_cache[forward_context.virtual_engine]
conv_state = self_kv_cache[0] # conv_state = (..., dim, width-1) yet contiguous along 'dim'
conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1] ssm_state = self_kv_cache[1]
state_indices_tensor = attn_metadata.state_indices_tensor state_indices_tensor = attn_metadata.state_indices_tensor
has_initial_states_p = attn_metadata.has_initial_states has_initial_states_p = attn_metadata.has_initial_states
@ -531,6 +534,7 @@ class MambaMixer2(CustomOp):
# NOTE: V0 put prefill before decode, v1 puts decode before prefill # NOTE: V0 put prefill before decode, v1 puts decode before prefill
# Separate prefill and decode by splitting varlen input # Separate prefill and decode by splitting varlen input
# Split along token dimension # Split along token dimension
# NOTE: V0 put prefill before decode, v1 puts decode before prefill
if envs.VLLM_USE_V1: if envs.VLLM_USE_V1:
hidden_states_B_C_d, hidden_states_B_C_p = torch.split( hidden_states_B_C_d, hidden_states_B_C_p = torch.split(
hidden_states_B_C, hidden_states_B_C,
@ -579,8 +583,13 @@ class MambaMixer2(CustomOp):
# 2. Convolution sequence transformation # 2. Convolution sequence transformation
# - "cache_indices" updates the conv_state cache in positions # - "cache_indices" updates the conv_state cache in positions
# pointed to by "state_indices_tensor" # pointed to by "state_indices_tensor"
x = hidden_states_B_C_p.transpose(
0, 1) # this is the form that causal-conv see
if mamba2_metadata.cu_seqlen is None:
mamba2_metadata = update_metadata(
x, attn_metadata.query_start_loc, mamba2_metadata)
hidden_states_B_C_p = causal_conv1d_fn( hidden_states_B_C_p = causal_conv1d_fn(
hidden_states_B_C_p.transpose(0, 1), x,
conv_weights, conv_weights,
self.conv1d.bias, self.conv1d.bias,
activation=self.activation, activation=self.activation,
@ -590,8 +599,6 @@ class MambaMixer2(CustomOp):
query_start_loc=query_start_loc_p).transpose( query_start_loc=query_start_loc_p).transpose(
0, 1)[:num_prefill_tokens] 0, 1)[:num_prefill_tokens]
# TODO: Why is this needed?
hidden_states_B_C_p = hidden_states_B_C_p.contiguous()
hidden_states_p, B_p, C_p = split_hidden_states_B_C_fn( hidden_states_p, B_p, C_p = split_hidden_states_B_C_fn(
hidden_states_B_C_p) hidden_states_B_C_p)
@ -715,9 +722,10 @@ class MambaMixer2(CustomOp):
# - heads and n_groups are TP-ed # - heads and n_groups are TP-ed
conv_dim = (self.intermediate_size + conv_dim = (self.intermediate_size +
2 * n_groups * self.ssm_state_size) 2 * n_groups * self.ssm_state_size)
# contiguous along 'dim' axis
conv_state_shape = ( conv_state_shape = (
divide(conv_dim, world_size),
self.conv_kernel_size - 1, self.conv_kernel_size - 1,
divide(conv_dim, world_size),
) )
# These are not TP-ed as they depend on A, dt_bias, D # These are not TP-ed as they depend on A, dt_bias, D

View File

@ -4,31 +4,394 @@
# Copyright (c) 2024, Tri Dao. # Copyright (c) 2024, Tri Dao.
# Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py # Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py
from typing import Optional from typing import Optional, Union
import numpy as np
import torch import torch
import triton
import triton.language as tl
from vllm import _custom_ops as ops
from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.attention.backends.utils import PAD_SLOT_ID
def causal_conv1d_fn(x: torch.Tensor, @triton.jit()
weight: torch.Tensor, def _causal_conv1d_fwd_kernel( # continuous batching
bias: Optional[torch.Tensor] = None, # Pointers to matrices
query_start_loc: Optional[torch.Tensor] = None, x_ptr, # (dim, cu_seqlen) holding `batch` of actual sequences + padded sequences
cache_indices: Optional[torch.Tensor] = None, w_ptr, # (dim, width)
has_initial_state: Optional[torch.Tensor] = None, bias_ptr,
conv_states: Optional[torch.Tensor] = None, initial_states_ptr, # conv_states_ptr
activation: Optional[str] = "silu", cache_indices_ptr, # conv_state_indices_ptr
pad_slot_id: int = PAD_SLOT_ID): has_initial_states_ptr,
""" query_start_loc_ptr,
x: (batch, dim, seqlen) or (dim,cu_seq_len) for varlen batch_ptr,
token_chunk_offset_ptr,
o_ptr, # (dim, seqlen) - actually pointing to x_ptr
# Matrix dimensions
batch: tl.int32, # actually padded_batch
dim: tl.constexpr,
seqlen: tl.int32, # cu_seqlen
num_cache_lines: tl.constexpr, # added to support vLLM larger cache lines
# Strides
stride_x_seq: tl.constexpr, # stride to get to next sequence,
stride_x_dim: tl.constexpr, # stride to get to next feature-value,
stride_x_token: tl.
constexpr, # stride to get to next token (same feature-index, same sequence-index)
stride_w_dim: tl.constexpr, # stride to get to next dim-axis value
stride_w_width: tl.constexpr, # stride to get to next width-axis value
stride_istate_seq: tl.constexpr,
stride_istate_dim: tl.constexpr,
stride_istate_token: tl.constexpr,
stride_o_seq: tl.constexpr,
stride_o_dim: tl.constexpr,
stride_o_token: tl.constexpr,
# others
pad_slot_id: tl.constexpr,
# Meta-parameters
HAS_BIAS: tl.constexpr,
KERNEL_WIDTH: tl.constexpr,
SILU_ACTIVATION: tl.constexpr,
HAS_INITIAL_STATES: tl.constexpr,
HAS_CACHE: tl.constexpr,
IS_CONTINUOUS_BATCHING: tl.constexpr,
USE_PAD_SLOT: tl.constexpr,
NP2_STATELEN: tl.constexpr,
DECODE_SEQLEN: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
):
conv_states_ptr = initial_states_ptr
conv_state_indices_ptr = cache_indices_ptr
stride_conv_state_seq = stride_istate_seq
stride_conv_state_dim = stride_istate_dim
stride_conv_state_tok = stride_istate_token
state_len = KERNEL_WIDTH - 1 # can be passed via argument if it's not the same as this value
# one program handles one chunk in a single sequence
# rather than mixing sequences - to make updating initial_states across sequences efficiently
# single-sequence id
idx_seq = tl.load(batch_ptr + tl.program_id(0))
chunk_offset = tl.load(token_chunk_offset_ptr + tl.program_id(0))
# BLOCK_N elements along the feature-dimension (channel)
idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)
if idx_seq == pad_slot_id:
return
sequence_start_index = tl.load(query_start_loc_ptr + idx_seq)
sequence_end_index = tl.load(query_start_loc_ptr + idx_seq + 1)
# find the actual sequence length
seqlen = sequence_end_index - sequence_start_index
token_offset = BLOCK_M * chunk_offset
segment_len = min(BLOCK_M, seqlen - token_offset)
# base of the sequence
x_base = x_ptr + sequence_start_index * stride_x_token + idx_feats * stride_x_dim # [BLOCK_N,]
if IS_CONTINUOUS_BATCHING:
# cache_idx
conv_state_batch_coord = tl.load(conv_state_indices_ptr + idx_seq)
else:
# cache_idx
conv_state_batch_coord = idx_seq
if USE_PAD_SLOT: # noqa
if conv_state_batch_coord == pad_slot_id:
# not processing as this is not the actual sequence
return
conv_states_base = (conv_states_ptr +
(conv_state_batch_coord * stride_conv_state_seq) +
(idx_feats * stride_conv_state_dim)) # [BLOCK_N,]
w_base = w_ptr + (idx_feats * stride_w_dim) # [BLOCK_N,]
# Does 2 things:
# 1. READ prior-block init-state data - [done by every Triton programs]
# 2. update conv_state with new data [only by the Triton program handles chunk_offset=0]
if chunk_offset == 0:
# read from conv_states
load_init_state = False
if HAS_INITIAL_STATES: # the new HAS_INITIAL_STATES
load_init_state = tl.load(has_initial_states_ptr + idx_seq).to(
tl.int1)
if load_init_state:
# load from conv_states
prior_tokens = conv_states_base + (state_len -
1) * stride_conv_state_tok
mask_w = idx_feats < dim
if KERNEL_WIDTH == 2:
conv_states_ptrs = prior_tokens # [BLOCK_N]
col0 = tl.load(conv_states_ptrs, mask_w, 0.0)
if KERNEL_WIDTH == 3:
conv_states_ptrs = prior_tokens # [BLOCK_N]
col1 = tl.load(conv_states_ptrs, mask_w, 0.0)
conv_states_ptrs = prior_tokens - 1 * stride_conv_state_tok # [BLOCK_N]
col0 = tl.load(conv_states_ptrs, mask_w, 0.0)
if KERNEL_WIDTH == 4:
conv_states_ptrs = prior_tokens # [BLOCK_N]
col2 = tl.load(conv_states_ptrs, mask_w, 0.0)
conv_states_ptrs = prior_tokens - 1 * stride_conv_state_tok # [BLOCK_N]
col1 = tl.load(conv_states_ptrs, mask_w, 0.0)
conv_states_ptrs = prior_tokens - 2 * stride_conv_state_tok # [BLOCK_N]
col0 = tl.load(conv_states_ptrs, mask_w, 0.0)
if KERNEL_WIDTH == 5:
conv_states_ptrs = prior_tokens # [BLOCK_N]
col3 = tl.load(conv_states_ptrs, mask_w, 0.0)
conv_states_ptrs = prior_tokens - 1 * stride_conv_state_tok # [BLOCK_N]
col2 = tl.load(conv_states_ptrs, mask_w, 0.0)
conv_states_ptrs = prior_tokens - 2 * stride_conv_state_tok # [BLOCK_N]
col1 = tl.load(conv_states_ptrs, mask_w, 0.0)
conv_states_ptrs = prior_tokens - 3 * stride_conv_state_tok # [BLOCK_N]
col0 = tl.load(conv_states_ptrs, mask_w, 0.0)
else:
# prior-tokens are zeros
if KERNEL_WIDTH >= 2: # STRATEGY1
# first chunk and does not have prior-token, so just set to 0
col0 = tl.zeros((BLOCK_N, ), dtype=x_ptr.dtype.element_ty)
if KERNEL_WIDTH >= 3: # STRATEGY1
col1 = tl.zeros((BLOCK_N, ), dtype=x_ptr.dtype.element_ty)
if KERNEL_WIDTH >= 4: # STRATEGY1
col2 = tl.zeros((BLOCK_N, ), dtype=x_ptr.dtype.element_ty)
if KERNEL_WIDTH >= 5: # STRATEGY1
col3 = tl.zeros((BLOCK_N, ), dtype=x_ptr.dtype.element_ty)
# STEP 2:
# here prepare data for updating conv_state
if state_len <= seqlen: # SMALL_CACHE=True (only move part of 'x' into conv_state cache)
# just read from 'x'
# copy 'x' data to conv_state
# load only 'x' data (and set 0 before 'x' if seqlen < state_len)
idx_tokens_last = (seqlen - state_len) + tl.arange(
0, NP2_STATELEN) # [BLOCK_M]
x_ptrs = x_ptr + (
(sequence_start_index + idx_tokens_last) *
stride_x_token)[:, None] + (
idx_feats * stride_x_dim)[None, :] # [BLOCK_M,BLOCK_N,]
mask_x = ((idx_tokens_last >= 0)[:, None] &
(idx_tokens_last < seqlen)[:, None] &
(idx_feats < dim)[None, :]
) # token-index # token-index # feature-index
loaded_x = tl.load(x_ptrs, mask_x, 0.0)
new_conv_state = tl.load(x_ptrs, mask_x, 0.0)
idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M]
conv_states_ptrs_target = conv_states_base[None, :] + (
idx_tokens_conv * stride_conv_state_tok)[:, None]
mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats
< dim)[None, :]
tl.debug_barrier() # NOTE: use this due to bug in Triton compiler
tl.store(conv_states_ptrs_target, new_conv_state, mask)
else:
if load_init_state:
# update conv_state by shifting left, i.e. take last few cols from conv_state + cols from 'x'
idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M]
conv_states_ptrs_source = (
conv_states_ptr +
(conv_state_batch_coord * stride_conv_state_seq) +
(idx_feats * stride_conv_state_dim)[None, :] +
((idx_tokens_conv + seqlen) * stride_conv_state_tok)[:,
None]
) # [BLOCK_M, BLOCK_N]
mask = ((conv_state_batch_coord < num_cache_lines)
& ((idx_tokens_conv + seqlen) < state_len)[:, None]
& (idx_feats < dim)[None, :])
conv_state = tl.load(conv_states_ptrs_source, mask, other=0.0)
VAL = state_len - seqlen
x_ptrs = x_base[None, :] + (
(idx_tokens_conv - VAL) *
stride_x_token)[:, None] # [BLOCK_M, BLOCK_N]
mask_x = ((idx_tokens_conv - VAL >= 0)[:, None] &
(idx_tokens_conv - VAL < seqlen)[:, None] &
(idx_feats < dim)[None, :]
) # token-index # token-index # feature-index
loaded_x = tl.load(x_ptrs, mask_x, 0.0)
tl.debug_barrier(
) # need this due to the bug in tl.where not enforcing this when data is the result of another tl.load
new_conv_state = tl.where(
mask, conv_state, loaded_x
) # BUG in 'tl.where' which requires a barrier before this
conv_states_ptrs_target = conv_states_base + (
idx_tokens_conv *
stride_conv_state_tok)[:, None] # [BLOCK_M, BLOCK_N]
mask = (idx_tokens_conv
< state_len)[:, None] & (idx_feats < dim)[None, :]
tl.store(conv_states_ptrs_target, new_conv_state, mask)
else: # load_init_state == False
# update conv_state by shifting left, BUT
# set cols prior to 'x' as zeros + cols from 'x'
idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M]
VAL = state_len - seqlen
x_ptrs = x_base[None, :] + (
(idx_tokens_conv - VAL) *
stride_x_token)[:, None] # [BLOCK_M, BLOCK_N]
mask_x = ((idx_tokens_conv - VAL >= 0)[:, None] &
(idx_tokens_conv - VAL < seqlen)[:, None] &
(idx_feats < dim)[None, :]
) # token-index # token-index # feature-index
new_conv_state = tl.load(x_ptrs, mask_x, 0.0)
conv_states_ptrs_target = conv_states_base + (
idx_tokens_conv *
stride_conv_state_tok)[:, None] # [BLOCK_M, BLOCK_N]
mask = (idx_tokens_conv
< state_len)[:, None] & (idx_feats < dim)[None, :]
tl.store(conv_states_ptrs_target, new_conv_state, mask)
else: # chunk_offset > 0
# read prior-token data from `x`
load_init_state = True
prior_tokens = x_base + (token_offset - 1) * stride_x_token
mask_w = idx_feats < dim
if KERNEL_WIDTH == 2:
conv_states_ptrs = prior_tokens # [BLOCK_N]
col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca')
if KERNEL_WIDTH == 3:
conv_states_ptrs = prior_tokens # [BLOCK_N]
col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca')
conv_states_ptrs = prior_tokens - 1 * stride_x_token # [BLOCK_N]
col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca')
if KERNEL_WIDTH == 4:
conv_states_ptrs = prior_tokens # [BLOCK_N]
col2 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca')
conv_states_ptrs = prior_tokens - 1 * stride_x_token # [BLOCK_N]
col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca')
conv_states_ptrs = prior_tokens - 2 * stride_x_token # [BLOCK_N]
col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca')
if KERNEL_WIDTH == 5:
# ruff: noqa: F841
conv_states_ptrs = prior_tokens # [BLOCK_N]
col3 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca')
conv_states_ptrs = prior_tokens - 1 * stride_x_token # [BLOCK_N]
col2 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca')
conv_states_ptrs = prior_tokens - 2 * stride_x_token # [BLOCK_N]
col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca')
conv_states_ptrs = prior_tokens - 3 * stride_x_token # [BLOCK_N]
col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca')
if HAS_BIAS:
bias = bias_ptr + idx_feats
mask_bias = idx_feats < dim
acc_preload = tl.load(bias, mask=mask_bias,
other=0.0).to(tl.float32) # [BLOCK_N]
else:
acc_preload = tl.zeros((BLOCK_N, ), dtype=tl.float32)
x_base_1d = x_base + token_offset * stride_x_token # starting of chunk
# PRE-LOAD WEIGHTS
mask_w = idx_feats < dim
if KERNEL_WIDTH >= 2:
w_ptrs = w_base + (0 * stride_w_width) # [BLOCK_N] tensor
w_col0 = tl.load(w_ptrs, mask_w, other=0.0)
w_ptrs = w_base + (1 * stride_w_width) # [BLOCK_N] tensor
w_col1 = tl.load(w_ptrs, mask_w, other=0.0)
if KERNEL_WIDTH >= 3:
w_ptrs = w_base + (2 * stride_w_width) # [BLOCK_N] tensor
w_col2 = tl.load(w_ptrs, mask_w, other=0.0)
if KERNEL_WIDTH >= 4:
w_ptrs = w_base + (3 * stride_w_width) # [BLOCK_N] tensor
w_col3 = tl.load(w_ptrs, mask_w, other=0.0)
mask_x_1d = idx_feats < dim
for idx_token in range(segment_len):
acc = acc_preload
matrix_w = w_col0
matrix_x = col0
for j in tl.static_range(KERNEL_WIDTH):
if KERNEL_WIDTH == 2:
if j == 1: # KERNEL_WIDTH-1:
matrix_w = w_col1
x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N]
matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
elif KERNEL_WIDTH == 3:
if j == 1:
matrix_w = w_col1
matrix_x = col1
elif j == 2:
matrix_w = w_col2
x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N]
matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
elif KERNEL_WIDTH == 4:
if j == 1:
matrix_w = w_col1
matrix_x = col1
elif j == 2:
matrix_w = w_col2
matrix_x = col2
elif j == 3:
matrix_w = w_col3
x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N]
matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
acc += matrix_x * matrix_w # [BLOCK_N]
if KERNEL_WIDTH == 2:
col0 = matrix_x
elif KERNEL_WIDTH == 3:
col0 = col1
col1 = matrix_x
elif KERNEL_WIDTH == 4:
col0 = col1
col1 = col2
col2 = matrix_x
if SILU_ACTIVATION:
acc = acc / (1 + tl.exp(-acc))
mask_1d = (idx_token < segment_len) & (
idx_feats < dim) # token-index # feature-index
o_ptrs = o_ptr + (sequence_start_index + token_offset + idx_token
) * stride_o_token + (idx_feats * stride_o_dim)
tl.store(o_ptrs, acc, mask=mask_1d)
def causal_conv1d_fn(
x: torch.Tensor,
weight: torch.Tensor,
bias: Union[torch.Tensor, None],
conv_states: torch.Tensor,
query_start_loc: torch.Tensor,
cache_indices: Optional[torch.Tensor] = None,
has_initial_state: Optional[torch.Tensor] = None,
activation: Optional[str] = "silu",
pad_slot_id: int = PAD_SLOT_ID,
metadata=None,
validate_data=False,
):
"""support varlen + continuous batching when x is 2D tensor
x: (dim,cu_seq_len)
cu_seq_len = total tokens of all seqs in that batch
sequences are concatenated from left to right for varlen sequences are concatenated from left to right for varlen
weight: (dim, width) weight: (dim, width)
bias: (dim,) conv_states: (...,dim,width - 1) itype
updated inplace if provided
[it use `cache_indices` to get the index to the cache of conv_state for that sequence
conv_state[cache_indices[i]] for seq-i - to be used as initial_state when has_initial_state[i] = True
and after that conv_state[cache_indices[i]] need to be shift-left and updated with values from 'x'
]
query_start_loc: (batch + 1) int32 query_start_loc: (batch + 1) int32
The cumulative sequence lengths of the sequences in The cumulative sequence lengths of the sequences in
the batch, used to index into sequence. prepended by 0. the batch, used to index into sequence. prepended by 0.
if
x = [5, 1, 1, 1] <- continuous batching (batch=4)
then
query_start_loc = [0, 5, 6, 7, 8] <- the starting index of the next sequence; while the last value is
the ending index of the last sequence
[length(query_start_loc)-1 == batch]
for example: query_start_loc = torch.Tensor([0,10,16,17]), for example: query_start_loc = torch.Tensor([0,10,16,17]),
x.shape=(dim,17) x.shape=(dim,17)
cache_indices: (batch) int32 cache_indices: (batch) int32
@ -37,42 +400,436 @@ def causal_conv1d_fn(x: torch.Tensor,
has_initial_state: (batch) bool has_initial_state: (batch) bool
indicates whether should the kernel take the current state as initial indicates whether should the kernel take the current state as initial
state for the calculations state for the calculations
conv_states: (...,dim,width - 1) itype [single boolean for each sequence in the batch: True or False]
updated inplace if provided bias: (dim,)
activation: either None or "silu" or "swish" activation: either None or "silu" or "swish" or True
pad_slot_id: int pad_slot_id: int
if cache_indices is passed, lets the kernel identify padded if cache_indices is passed, lets the kernel identify padded
entries that will not be processed, entries that will not be processed,
for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id] for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
in this case, the kernel will not process entries at in this case, the kernel will not process entries at
indices 0 and 3 indices 0 and 3
out: same shape as `x`
out: (batch, dim, seqlen)
""" """
if activation not in [None, "silu", "swish"]: if isinstance(activation, bool) and activation:
raise NotImplementedError("activation must be None, silu, or swish") activation = "silu"
if x.stride(-1) != 1:
x = x.contiguous()
bias = bias.contiguous() if bias is not None else None
ops.causal_conv1d_fwd(x, weight, bias, conv_states, query_start_loc, args = None
cache_indices, has_initial_state, activation out = torch.zeros_like(x)
in ["silu", "swish"], pad_slot_id) if metadata is not None:
return x cu_seqlen = metadata.cu_seqlen
nums_dict = metadata.nums_dict
#x = metadata.x
args = nums_dict
batch_ptr = metadata.batch_ptr
token_chunk_offset_ptr = metadata.token_chunk_offset_ptr
else:
seqlens = np.diff(query_start_loc.to('cpu'))
args = seqlens
MAX_NUM_PROGRAMS = 1024
batch_ptr = torch.full(
(MAX_NUM_PROGRAMS, ),
PAD_SLOT_ID,
dtype=torch.int32,
device=x.device
) # tracking which seq-idx the Triton program is handling
token_chunk_offset_ptr = torch.full(
(MAX_NUM_PROGRAMS, ),
PAD_SLOT_ID,
dtype=torch.int32,
device=x.device
) # tracking BLOCK_M-based index in the sequence the Triton program is handling
is_channel_last = (x.stride(0) == 1) & (x.stride(1) > 1)
dim, cu_seqlen = x.shape
_, width = weight.shape
state_len = width - 1
np2_statelen = triton.next_power_of_2(state_len)
padded_batch = query_start_loc.size(0) - 1
stride_x_seq = 0
stride_x_dim = x.stride(0)
stride_x_token = x.stride(1)
stride_w_dim = weight.stride(0)
stride_w_width = weight.stride(1)
stride_istate_seq = 0
stride_istate_dim = 0
stride_istate_token = 0
num_cache_lines = 0
if conv_states is not None:
# extensions to support vLLM:
# 1. conv_states is used to replaced initial_states
# 2. conv_states serve as a cache with num cache lines can be larger than batch size
# 3. mapping from sequence x[idx] to a cache line at index as specified via cache_indices[idx]
# 4. computation can be skipped if cache_indices[idx] == pad_slot_id
num_cache_lines = conv_states.size(0)
assert (num_cache_lines, dim, width - 1) == conv_states.shape
stride_istate_seq = conv_states.stride(0)
stride_istate_dim = conv_states.stride(1)
stride_istate_token = conv_states.stride(2)
assert stride_istate_dim == 1
if out.dim() == 2:
stride_o_seq = 0
stride_o_dim = out.stride(0)
stride_o_token = out.stride(1)
else:
stride_o_seq = out.stride(0)
stride_o_dim = out.stride(1)
stride_o_token = out.stride(2)
if validate_data:
assert x.dim() == 2
assert query_start_loc is not None
assert query_start_loc.dim() == 1
assert x.stride(0) == 1 or x.stride(1) == 1
if bias is not None:
assert bias.dim() == 1
assert dim == bias.size(0)
if cache_indices is not None:
assert cache_indices.dim() == 1
assert padded_batch == cache_indices.size(0)
if has_initial_state is not None:
assert has_initial_state.size() == (padded_batch, )
assert conv_states is not None, "ERROR: `has_initial_state` is used, which needs also `conv_states`"
assert weight.stride(1) == 1
assert (dim, width) == weight.shape
assert is_channel_last, "Need to run in channel-last layout"
if metadata is None:
def num_program(META, seqlens):
tot = 0
mlist = []
offsetlist = [] # type: ignore
nums = -(-seqlens // META["BLOCK_M"])
tot = nums.sum().item()
mlist = np.repeat(np.arange(len(nums)), nums)
for idx, num in enumerate(nums):
offsetlist.extend(
range(num)
) # chunk-idx if a sequence is split into multiple chunks
if META["batch_ptr"].nelement() < len(mlist):
newlen = len(mlist) + 1
META["batch_ptr"].resize_(newlen).fill_(PAD_SLOT_ID)
META["token_chunk_offset_ptr"].resize_(newlen).fill_(
PAD_SLOT_ID)
if META["batch_ptr"].nelement() >= len(mlist):
META["batch_ptr"][0:len(mlist)].copy_(
torch.from_numpy(np.array(mlist)))
META["token_chunk_offset_ptr"][0:len(mlist)].copy_(
torch.from_numpy(np.array(offsetlist)))
META["batch_ptr"] = META["batch_ptr"].to(META["x_ptr"].device)
META["token_chunk_offset_ptr"] = META["token_chunk_offset_ptr"].to(
META["x_ptr"].device)
return tot
else:
def num_program(META, nums_dict):
tot = nums_dict[META["BLOCK_M"]]['tot']
mlist = nums_dict[META["BLOCK_M"]]['mlist']
mlist_len = nums_dict[META["BLOCK_M"]]['mlist_len']
offsetlist = nums_dict[META["BLOCK_M"]]['offsetlist']
if nums_dict[META["BLOCK_M"]]["batch_ptr"] is not None:
META["batch_ptr"] = nums_dict[META["BLOCK_M"]]["batch_ptr"]
META["token_chunk_offset_ptr"] = nums_dict[
META["BLOCK_M"]]["token_chunk_offset_ptr"]
else:
if META["batch_ptr"].nelement() < mlist_len:
newlen = mlist_len + 1
META["batch_ptr"].resize_(newlen).fill_(PAD_SLOT_ID)
META["token_chunk_offset_ptr"].resize_(newlen).fill_(
PAD_SLOT_ID)
if META["batch_ptr"].nelement() >= mlist_len:
META["batch_ptr"][0:mlist_len].copy_(mlist)
META["token_chunk_offset_ptr"][0:mlist_len].copy_(
offsetlist)
return tot
def grid(META):
return (
num_program(META, args),
triton.cdiv(dim, META["BLOCK_N"]),
)
if batch_ptr.device != x.device:
batch_ptr = batch_ptr.to(x.device)
token_chunk_offset_ptr = token_chunk_offset_ptr.to(x.device)
_causal_conv1d_fwd_kernel[grid](
# Pointers to matrices
x,
weight,
bias,
conv_states,
cache_indices,
has_initial_state,
query_start_loc,
batch_ptr,
token_chunk_offset_ptr,
out,
# Matrix dimensions
padded_batch,
dim,
cu_seqlen,
num_cache_lines,
# stride
stride_x_seq,
stride_x_dim,
stride_x_token,
stride_w_dim,
stride_w_width,
stride_istate_seq,
stride_istate_dim,
stride_istate_token,
stride_o_seq,
stride_o_dim,
stride_o_token,
# others
pad_slot_id,
# META
HAS_BIAS=bias is not None,
KERNEL_WIDTH=width,
SILU_ACTIVATION=activation in ["silu", "swish"],
HAS_INITIAL_STATES=has_initial_state is not None,
HAS_CACHE=conv_states is not None,
IS_CONTINUOUS_BATCHING=cache_indices is not None,
USE_PAD_SLOT=pad_slot_id is not None,
NP2_STATELEN=np2_statelen,
DECODE_SEQLEN=1,
#launch_cooperative_grid=True
BLOCK_M=8,
BLOCK_N=256,
num_stages=2,
)
return out
def causal_conv1d_update(x: torch.Tensor, @triton.jit()
conv_state: torch.Tensor, def _causal_conv1d_update_kernel(
weight: torch.Tensor, # Pointers to matrices
bias: Optional[torch.Tensor] = None, x_ptr, # (batch, dim, seqlen)
activation: Optional[str] = None, w_ptr, # (dim, width)
cache_seqlens: Optional[torch.Tensor] = None, bias_ptr,
conv_state_indices: Optional[torch.Tensor] = None, conv_state_ptr,
pad_slot_id: int = PAD_SLOT_ID): cache_seqlens_ptr, # circular buffer
conv_state_indices_ptr,
o_ptr, # (batch, dim, seqlen)
# Matrix dimensions
batch: int,
dim: tl.constexpr,
seqlen: tl.constexpr,
state_len: tl.constexpr,
num_cache_lines: tl.constexpr, # added to support vLLM larger cache lines
# Strides
stride_x_seq: tl.constexpr,
stride_x_dim: tl.constexpr,
stride_x_token: tl.constexpr,
stride_w_dim: tl.constexpr,
stride_w_width: tl.constexpr,
stride_conv_state_seq: tl.constexpr,
stride_conv_state_dim: tl.constexpr,
stride_conv_state_tok: tl.constexpr,
stride_o_seq: tl.constexpr,
stride_o_dim: tl.constexpr,
stride_o_token: tl.constexpr,
# others
pad_slot_id: tl.constexpr,
# Meta-parameters
HAS_BIAS: tl.constexpr,
KERNEL_WIDTH: tl.constexpr,
SILU_ACTIVATION: tl.constexpr,
IS_CONTINUOUS_BATCHING: tl.constexpr,
NP2_STATELEN: tl.constexpr,
USE_PAD_SLOT: tl.constexpr,
BLOCK_N: tl.constexpr,
):
# ruff: noqa: E501
idx_seq = tl.program_id(0)
if idx_seq >= batch:
return
# [BLOCK_N,] elements along the feature-dimension (channel)
idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)
if IS_CONTINUOUS_BATCHING:
# mask = idx_seq < batch
conv_state_batch_coord = tl.load(conv_state_indices_ptr + idx_seq)
else:
conv_state_batch_coord = idx_seq
if USE_PAD_SLOT: # noqa
if conv_state_batch_coord == pad_slot_id:
# not processing as this is not the actual sequence
return
# STEP 1: READ init_state data
conv_states_base = (conv_state_ptr +
(conv_state_batch_coord * stride_conv_state_seq) +
(idx_feats * stride_conv_state_dim))
mask_w = idx_feats < dim
prior_tokens = conv_states_base
if KERNEL_WIDTH >= 2:
conv_states_ptrs = prior_tokens # [BLOCK_N]
col0 = tl.load(conv_states_ptrs, mask_w, 0.0)
if KERNEL_WIDTH >= 3:
conv_states_ptrs = prior_tokens + 1 * stride_conv_state_tok # [BLOCK_N]
col1 = tl.load(conv_states_ptrs, mask_w, 0.0)
if KERNEL_WIDTH >= 4:
conv_states_ptrs = prior_tokens + 2 * stride_conv_state_tok # [BLOCK_N]
col2 = tl.load(conv_states_ptrs, mask_w, 0.0)
if KERNEL_WIDTH == 5:
conv_states_ptrs = prior_tokens + 3 * stride_conv_state_tok # [BLOCK_N]
col3 = tl.load(conv_states_ptrs, mask_w, 0.0)
# STEP 2: assume state_len > seqlen
idx_tokens = tl.arange(0, NP2_STATELEN) # [BLOCK_M]
conv_state_ptrs_source = (
conv_state_ptr + (conv_state_batch_coord * stride_conv_state_seq) +
(idx_feats * stride_conv_state_dim)[None, :] +
((idx_tokens + seqlen) * stride_conv_state_tok)[:, None]
) # [BLOCK_M, BLOCK_N]
mask = ((conv_state_batch_coord < num_cache_lines)
& ((idx_tokens + seqlen) < state_len)[:, None]
& (idx_feats < dim)[None, :])
conv_state = tl.load(conv_state_ptrs_source, mask, other=0.0)
VAL = state_len - seqlen
x_base = x_ptr + (idx_seq * stride_x_seq) + (idx_feats * stride_x_dim
) # [BLOCK_N]
x_ptrs = x_base[None, :] + (
(idx_tokens - VAL) * stride_x_token)[:, None] # [BLOCK_M, BLOCK_N]
mask_x = ((idx_tokens - VAL >= 0)[:, None] &
(idx_tokens - VAL < seqlen)[:, None] & (idx_feats < dim)[None, :]
) # token-index # token-index # feature-index
loaded_x = tl.load(x_ptrs, mask_x, 0.0)
tl.debug_barrier()
new_conv_state = tl.where(mask, conv_state, loaded_x)
conv_state_base = (conv_state_ptr +
(conv_state_batch_coord * stride_conv_state_seq) +
(idx_feats * stride_conv_state_dim)) # [BLOCK_N,]
conv_state_ptrs_target = conv_state_base + (
idx_tokens * stride_conv_state_tok)[:, None] # [BLOCK_M, BLOCK_N]
mask = (idx_tokens < state_len)[:, None] & (idx_feats < dim)[None, :]
tl.store(conv_state_ptrs_target, new_conv_state, mask)
# STEP 3: init accumulator
if HAS_BIAS:
bias = bias_ptr + idx_feats
mask_bias = idx_feats < dim
acc_preload = tl.load(bias, mask=mask_bias,
other=0.0).to(tl.float32) # [BLOCK_N]
else:
acc_preload = tl.zeros((BLOCK_N, ), dtype=tl.float32)
# STEP 4:
# PRE-LOAD WEIGHTS
# first kernel column, configured for weights to handle BLOCK_N features in range
w_base = w_ptr + (idx_feats * stride_w_dim) # [BLOCK_N,]
mask_w = idx_feats < dim
if KERNEL_WIDTH >= 2:
w_ptrs = w_base + (0 * stride_w_width) # [BLOCK_N] tensor
w_col0 = tl.load(w_ptrs, mask_w, other=0.0)
w_ptrs = w_base + (1 * stride_w_width) # [BLOCK_N] tensor
w_col1 = tl.load(w_ptrs, mask_w, other=0.0)
if KERNEL_WIDTH >= 3:
w_ptrs = w_base + (2 * stride_w_width) # [BLOCK_N] tensor
w_col2 = tl.load(w_ptrs, mask_w, other=0.0)
if KERNEL_WIDTH >= 4:
w_ptrs = w_base + (3 * stride_w_width) # [BLOCK_N] tensor
w_col3 = tl.load(w_ptrs, mask_w, other=0.0)
x_base_1d = x_base # starting of chunk [BLOCK_N]
mask_x_1d = idx_feats < dim
# STEP 5: compute each token
for idx_token in tl.static_range(seqlen):
acc = acc_preload
matrix_w = w_col0
matrix_x = col0
for j in tl.static_range(KERNEL_WIDTH):
if KERNEL_WIDTH == 2:
if j == 1: # KERNEL_WIDTH-1:
matrix_w = w_col1
x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N]
matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
elif KERNEL_WIDTH == 3:
if j == 1:
matrix_w = w_col1
matrix_x = col1
elif j == 2:
matrix_w = w_col2
x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N]
matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
elif KERNEL_WIDTH == 4:
if j == 1:
matrix_w = w_col1
matrix_x = col1
elif j == 2:
matrix_w = w_col2
matrix_x = col2
elif j == 3:
matrix_w = w_col3
x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N]
matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
acc += matrix_x * matrix_w # [BLOCK_N]
if KERNEL_WIDTH == 2:
col0 = matrix_x
elif KERNEL_WIDTH == 3:
col0 = col1
col1 = matrix_x
elif KERNEL_WIDTH == 4:
col0 = col1
col1 = col2
col2 = matrix_x
if SILU_ACTIVATION:
acc = acc / (1 + tl.exp(-acc))
mask_1d = (idx_token < seqlen) & (idx_feats < dim
) # token-index # feature-index
o_ptrs = o_ptr + (
idx_seq) * stride_o_seq + idx_token * stride_o_token + (
idx_feats * stride_o_dim)
tl.store(o_ptrs, acc, mask=mask_1d)
def causal_conv1d_update(
x: torch.Tensor,
conv_state: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
activation: Union[bool, str, None] = None,
cache_seqlens: Optional[torch.Tensor] = None,
conv_state_indices: Optional[torch.Tensor] = None,
pad_slot_id: int = PAD_SLOT_ID,
metadata=None,
validate_data=False,
):
""" """
x: (batch, dim) or (batch, dim, seqlen) x: (batch, dim) or (batch, dim, seqlen)
conv_state: (batch, dim, state_len), where state_len >= width - 1 [shape=2: single token prediction]
[shape=3: single or multiple tokens prediction]
conv_state: (..., dim, state_len), where state_len >= width - 1
weight: (dim, width) weight: (dim, width)
bias: (dim,) bias: (dim,)
cache_seqlens: (batch,), dtype int32. cache_seqlens: (batch,), dtype int32.
@ -92,14 +849,98 @@ def causal_conv1d_update(x: torch.Tensor,
indices 0 and 3 indices 0 and 3
out: (batch, dim) or (batch, dim, seqlen) out: (batch, dim) or (batch, dim, seqlen)
""" """
if activation not in [None, "silu", "swish"]: if validate_data:
raise NotImplementedError("activation must be None, silu, or swish") assert cache_seqlens is None # not implemented yet - ok for vLLM
activation_val = activation in ["silu", "swish"] assert pad_slot_id is not None
assert x.stride(1) == 1
if isinstance(activation, bool):
activation = "silu" if activation is True else None
elif activation is not None:
assert activation in ["silu", "swish"]
unsqueeze = x.dim() == 2 unsqueeze = x.dim() == 2
if unsqueeze: if unsqueeze:
# make it (batch, dim, seqlen) with seqlen == 1
x = x.unsqueeze(-1) x = x.unsqueeze(-1)
ops.causal_conv1d_update(x, conv_state, weight, bias, activation_val, batch, dim, seqlen = x.shape
cache_seqlens, conv_state_indices, pad_slot_id) _, width = weight.shape
# conv_state: (..., dim, state_len), where state_len >= width - 1
num_cache_lines, _, state_len = conv_state.size()
if validate_data:
assert dim == weight.size(0)
assert conv_state.stride(
-2
) == 1, f"ERROR: expect contiguous along feat-dim of conv_state (currently stride={conv_state.stride()})"
assert state_len >= width - 1
# when above happens, we don't shift-left to keep any records in conv_state
assert dim == conv_state.size(1)
if conv_state_indices is None:
assert conv_state.size(0) >= batch
else:
assert (batch, ) == conv_state_indices.shape
assert num_cache_lines >= batch
assert weight.stride(1) == 1 # Need this
assert cache_seqlens is None # not needed for vLLM - circular buffer
# adopt the strategy in vLLM that overwrite on 'x' directly, rather than creating a new tensor 'o'
out = x
stride_w_dim, stride_w_width = weight.stride()
stride_x_seq, stride_x_dim, stride_x_token = x.stride(
) # X (batch, dim, seqlen)
stride_o_seq, stride_o_dim, stride_o_token = out.stride()
stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride(
)
state_len = width - 1
np2_statelen = triton.next_power_of_2(state_len)
def grid(META):
return (
batch,
triton.cdiv(dim, META["BLOCK_N"]),
)
_causal_conv1d_update_kernel[grid](
# Pointers to matrices
x,
weight,
bias,
conv_state,
cache_seqlens,
conv_state_indices,
out,
# Matrix dimensions
batch,
dim,
seqlen,
state_len,
num_cache_lines,
# stride
stride_x_seq,
stride_x_dim,
stride_x_token,
stride_w_dim,
stride_w_width,
stride_istate_seq,
stride_istate_dim,
stride_istate_token,
stride_o_seq,
stride_o_dim,
stride_o_token,
# others
pad_slot_id,
# META
HAS_BIAS=bias is not None,
KERNEL_WIDTH=width,
SILU_ACTIVATION=activation in ["silu", "swish"],
IS_CONTINUOUS_BATCHING=conv_state_indices is not None,
NP2_STATELEN=np2_statelen,
USE_PAD_SLOT=pad_slot_id is not None,
BLOCK_N=256,
)
if unsqueeze: if unsqueeze:
x = x.squeeze(-1) out = out.squeeze(-1)
return x return out

View File

@ -36,10 +36,12 @@ class MambaCacheManager(ConstantSizeCache):
# Initialize parent class # Initialize parent class
super().__init__(max_batch_size) super().__init__(max_batch_size)
# assume conv_state = (dim, state_len)
assert conv_state_shape[0] > conv_state_shape[1]
conv_state = torch.empty(size=(num_mamba_layers, max_batch_size) + conv_state = torch.empty(size=(num_mamba_layers, max_batch_size) +
conv_state_shape, (conv_state_shape[1], conv_state_shape[0]),
dtype=dtype, dtype=dtype,
device="cuda") device="cuda").transpose(-1, -2)
temporal_state = torch.empty(size=(num_mamba_layers, max_batch_size) + temporal_state = torch.empty(size=(num_mamba_layers, max_batch_size) +
temporal_state_shape, temporal_state_shape,
dtype=dtype, dtype=dtype,

View File

@ -1,14 +1,13 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Optional
import torch import torch
from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.model_executor.layers.mamba.mamba2_metadata import (
_query_start_loc_to_chunk_indices_offsets)
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata) CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import MambaSpec from vllm.v1.kv_cache_interface import MambaSpec
@ -29,6 +28,42 @@ def get_mamba2_chunk_size(vllm_config: VllmConfig) -> int:
return chunk_sizes.pop() return chunk_sizes.pop()
def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor,
chunk_size: int,
total_seqlens: int):
cu_seqlens = query_start_loc[1:] # remove prepended 0
# outputs will have length expansion of chunks that do not divide
# chunk_size
N = math.ceil(total_seqlens / chunk_size) + (cu_seqlens[:-1] % chunk_size
> 0).sum()
chunk_indices = torch.arange(N,
dtype=torch.int,
device=query_start_loc.device)
chunk_offsets = torch.zeros((N, ),
dtype=torch.int,
device=query_start_loc.device)
p = 0 # num of insertions
for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]):
# if does not divide chunk_size, then there is one chunk insertion
p += (s % chunk_size > 0)
# get the dimensions
# - the + 1 for _e is to shift the boundary by one chunk
# - this shifting is not needed if chunk_size divides e
_s, _e = s // chunk_size + p, e // chunk_size + p + (e % chunk_size
> 0)
# adjust indices and offsets
chunk_indices[_s:_e] -= p
chunk_offsets[_s] = s % chunk_size
return chunk_indices, chunk_offsets
class Mamba2AttentionBackend(AttentionBackend): class Mamba2AttentionBackend(AttentionBackend):
@staticmethod @staticmethod
@ -53,6 +88,10 @@ class Mamba2AttentionMetadata:
chunk_offsets: torch.Tensor chunk_offsets: torch.Tensor
state_indices_tensor: torch.Tensor # shape: [batch,] state_indices_tensor: torch.Tensor # shape: [batch,]
nums_dict: Optional[dict] = None
cu_seqlen: Optional[int] = None
batch_ptr: Optional[torch.tensor] = None
token_chunk_offset_ptr: Optional[torch.tensor] = None
class Mamba2AttentionMetadataBuilder( class Mamba2AttentionMetadataBuilder(