mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-06 19:55:42 +08:00
[Kernel] Triton implementation of causal-conv1d for Mamba-based models (#18218)
Signed-off-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com> Co-authored-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com> Co-authored-by: Tyler Michael Smith <tysmith@redhat.com> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
parent
31b96d1c64
commit
47043eb678
@ -232,7 +232,6 @@ endif()
|
|||||||
|
|
||||||
set(VLLM_EXT_SRC
|
set(VLLM_EXT_SRC
|
||||||
"csrc/mamba/mamba_ssm/selective_scan_fwd.cu"
|
"csrc/mamba/mamba_ssm/selective_scan_fwd.cu"
|
||||||
"csrc/mamba/causal_conv1d/causal_conv1d.cu"
|
|
||||||
"csrc/cache_kernels.cu"
|
"csrc/cache_kernels.cu"
|
||||||
"csrc/attention/paged_attention_v1.cu"
|
"csrc/attention/paged_attention_v1.cu"
|
||||||
"csrc/attention/paged_attention_v2.cu"
|
"csrc/attention/paged_attention_v2.cu"
|
||||||
|
|||||||
@ -1,656 +0,0 @@
|
|||||||
// clang-format off
|
|
||||||
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d_fwd.cu
|
|
||||||
// and https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d_update.cu
|
|
||||||
#include <torch/all.h>
|
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
|
||||||
#include <c10/cuda/CUDAGuard.h>
|
|
||||||
|
|
||||||
#include "causal_conv1d.h"
|
|
||||||
#include <c10/util/BFloat16.h>
|
|
||||||
#include <c10/util/Half.h>
|
|
||||||
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
|
|
||||||
|
|
||||||
#include <cub/block/block_load.cuh>
|
|
||||||
#include <cub/block/block_store.cuh>
|
|
||||||
|
|
||||||
#ifdef USE_ROCM
|
|
||||||
namespace cub = hipcub;
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#include "static_switch.h"
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
|
|
||||||
|
|
||||||
#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \
|
|
||||||
if (ITYPE == at::ScalarType::Half) { \
|
|
||||||
using input_t = at::Half; \
|
|
||||||
using weight_t = at::Half; \
|
|
||||||
__VA_ARGS__(); \
|
|
||||||
} else if (ITYPE == at::ScalarType::BFloat16) { \
|
|
||||||
using input_t = at::BFloat16; \
|
|
||||||
using weight_t = at::BFloat16; \
|
|
||||||
__VA_ARGS__(); \
|
|
||||||
} else if (ITYPE == at::ScalarType::Float) { \
|
|
||||||
using input_t = float; \
|
|
||||||
using weight_t = float; \
|
|
||||||
__VA_ARGS__(); \
|
|
||||||
} else { \
|
|
||||||
AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
template<typename input_t, typename weight_t>
|
|
||||||
void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream);
|
|
||||||
|
|
||||||
template<typename input_t, typename weight_t>
|
|
||||||
void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream);
|
|
||||||
|
|
||||||
void set_conv_params_fwd(ConvParamsBase ¶ms,
|
|
||||||
// sizes
|
|
||||||
const size_t batch,
|
|
||||||
const size_t dim,
|
|
||||||
const size_t seqlen,
|
|
||||||
const size_t width,
|
|
||||||
// device pointers
|
|
||||||
const at::Tensor x,
|
|
||||||
const at::Tensor weight,
|
|
||||||
const at::Tensor out,
|
|
||||||
const std::optional<at::Tensor>& bias,
|
|
||||||
bool silu_activation,
|
|
||||||
int64_t pad_slot_id,
|
|
||||||
const std::optional<at::Tensor>& query_start_loc = std::nullopt,
|
|
||||||
const std::optional<at::Tensor>& cache_indices = std::nullopt,
|
|
||||||
const std::optional<at::Tensor>& has_initial_state = std::nullopt) {
|
|
||||||
|
|
||||||
// Reset the parameters
|
|
||||||
memset(¶ms, 0, sizeof(params));
|
|
||||||
|
|
||||||
params.batch = batch;
|
|
||||||
params.dim = dim;
|
|
||||||
params.seqlen = seqlen;
|
|
||||||
params.width = width;
|
|
||||||
params.pad_slot_id = pad_slot_id;
|
|
||||||
|
|
||||||
params.silu_activation = silu_activation;
|
|
||||||
|
|
||||||
// Set the pointers and strides.
|
|
||||||
params.x_ptr = x.data_ptr();
|
|
||||||
params.weight_ptr = weight.data_ptr();
|
|
||||||
params.bias_ptr = bias.has_value() ? bias.value().data_ptr() : nullptr;
|
|
||||||
params.out_ptr = out.data_ptr();
|
|
||||||
// All stride are in elements, not bytes.
|
|
||||||
params.query_start_loc_ptr = query_start_loc.has_value() ? query_start_loc.value().data_ptr() : nullptr;
|
|
||||||
params.cache_indices_ptr = cache_indices.has_value() ? cache_indices.value().data_ptr() : nullptr;
|
|
||||||
params.has_initial_state_ptr = has_initial_state.has_value() ? has_initial_state.value().data_ptr() : nullptr;
|
|
||||||
const bool varlen = params.query_start_loc_ptr != nullptr;
|
|
||||||
params.x_batch_stride = x.stride(varlen ? 1 : 0);
|
|
||||||
params.x_c_stride = x.stride(varlen ? 0 : 1);
|
|
||||||
params.x_l_stride = x.stride(varlen ? 1 : -1);
|
|
||||||
params.weight_c_stride = weight.stride(0);
|
|
||||||
params.weight_width_stride = weight.stride(1);
|
|
||||||
params.out_batch_stride = out.stride(varlen ? 1 : 0);
|
|
||||||
params.out_c_stride = out.stride(varlen ? 0 : 1);
|
|
||||||
params.out_l_stride = out.stride(varlen ? 1 : -1);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
void causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
|
|
||||||
const std::optional<at::Tensor> &bias_,
|
|
||||||
const std::optional<at::Tensor> &conv_states,
|
|
||||||
const std::optional<at::Tensor> &query_start_loc,
|
|
||||||
const std::optional<at::Tensor> &cache_indices,
|
|
||||||
const std::optional<at::Tensor> &has_initial_state,
|
|
||||||
bool silu_activation,
|
|
||||||
// used to identify padding entries if cache_indices provided
|
|
||||||
// in case of padding, the kernel will return early
|
|
||||||
int64_t pad_slot_id) {
|
|
||||||
auto input_type = x.scalar_type();
|
|
||||||
auto weight_type = weight.scalar_type();
|
|
||||||
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
|
|
||||||
TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
|
|
||||||
|
|
||||||
TORCH_CHECK(x.is_cuda());
|
|
||||||
TORCH_CHECK(weight.is_cuda());
|
|
||||||
|
|
||||||
const bool varlen = query_start_loc.has_value() ? true : false;
|
|
||||||
const auto sizes = x.sizes();
|
|
||||||
const int batch_size = varlen ? query_start_loc.value().sizes()[0] - 1 : sizes[0];
|
|
||||||
const int dim = varlen ? sizes[0] : sizes[1];
|
|
||||||
const int seqlen = varlen ? sizes[1] : sizes[2];
|
|
||||||
const int width = weight.size(-1);
|
|
||||||
if (varlen){
|
|
||||||
CHECK_SHAPE(x, dim, seqlen);
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
CHECK_SHAPE(x, batch_size, dim, seqlen);
|
|
||||||
}
|
|
||||||
CHECK_SHAPE(weight, dim, width);
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if (bias_.has_value()) {
|
|
||||||
auto bias = bias_.value();
|
|
||||||
TORCH_CHECK(bias.scalar_type() == weight_type);
|
|
||||||
TORCH_CHECK(bias.is_cuda());
|
|
||||||
TORCH_CHECK(bias.stride(-1) == 1);
|
|
||||||
CHECK_SHAPE(bias, dim);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
if (has_initial_state.has_value()) {
|
|
||||||
auto has_initial_state_ = has_initial_state.value();
|
|
||||||
TORCH_CHECK(has_initial_state_.scalar_type() == at::ScalarType::Bool);
|
|
||||||
TORCH_CHECK(has_initial_state_.is_cuda());
|
|
||||||
CHECK_SHAPE(has_initial_state_, batch_size);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
if (query_start_loc.has_value()) {
|
|
||||||
auto query_start_loc_ = query_start_loc.value();
|
|
||||||
TORCH_CHECK(query_start_loc_.scalar_type() == at::ScalarType::Int);
|
|
||||||
TORCH_CHECK(query_start_loc_.is_cuda());
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
if (cache_indices.has_value()) {
|
|
||||||
auto cache_indices_ = cache_indices.value();
|
|
||||||
TORCH_CHECK(cache_indices_.scalar_type() == at::ScalarType::Int);
|
|
||||||
TORCH_CHECK(cache_indices_.is_cuda());
|
|
||||||
CHECK_SHAPE(cache_indices_, batch_size);
|
|
||||||
}
|
|
||||||
|
|
||||||
at::Tensor out = x;
|
|
||||||
|
|
||||||
ConvParamsBase params;
|
|
||||||
set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
|
|
||||||
bias_,
|
|
||||||
silu_activation,
|
|
||||||
pad_slot_id,
|
|
||||||
query_start_loc,
|
|
||||||
cache_indices,
|
|
||||||
has_initial_state
|
|
||||||
);
|
|
||||||
|
|
||||||
if (conv_states.has_value()) {
|
|
||||||
auto conv_states_ = conv_states.value();
|
|
||||||
TORCH_CHECK(conv_states_.scalar_type() == input_type);
|
|
||||||
TORCH_CHECK(conv_states_.is_cuda());
|
|
||||||
params.conv_states_ptr = conv_states_.data_ptr();
|
|
||||||
params.conv_states_batch_stride = conv_states_.stride(0);
|
|
||||||
params.conv_states_c_stride = conv_states_.stride(1);
|
|
||||||
params.conv_states_l_stride = conv_states_.stride(2);
|
|
||||||
} else {
|
|
||||||
params.conv_states_ptr = nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
|
||||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
|
||||||
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] {
|
|
||||||
causal_conv1d_fwd_cuda<input_t, weight_t>(params, stream);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
void causal_conv1d_update(const at::Tensor &x,
|
|
||||||
const at::Tensor &conv_state,
|
|
||||||
const at::Tensor &weight,
|
|
||||||
const std::optional<at::Tensor> &bias_,
|
|
||||||
bool silu_activation,
|
|
||||||
const std::optional<at::Tensor> &cache_seqlens_,
|
|
||||||
const std::optional<at::Tensor> &conv_state_indices_,
|
|
||||||
// used to identify padding entries if cache_indices provided
|
|
||||||
// in case of padding, the kernel will return early
|
|
||||||
int64_t pad_slot_id) {
|
|
||||||
auto input_type = x.scalar_type();
|
|
||||||
auto weight_type = weight.scalar_type();
|
|
||||||
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
|
|
||||||
TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
|
|
||||||
TORCH_CHECK(weight_type == input_type, "weight type must equal to input type, other variations are disabled due to binary size limitations");
|
|
||||||
TORCH_CHECK(conv_state.scalar_type() == input_type);
|
|
||||||
|
|
||||||
TORCH_CHECK(x.is_cuda());
|
|
||||||
TORCH_CHECK(conv_state.is_cuda());
|
|
||||||
TORCH_CHECK(weight.is_cuda());
|
|
||||||
|
|
||||||
const auto sizes = x.sizes();
|
|
||||||
const int batch_size = sizes[0];
|
|
||||||
const int dim = sizes[1];
|
|
||||||
const int seqlen = sizes[2];
|
|
||||||
const int width = weight.size(-1);
|
|
||||||
const int conv_state_len = conv_state.size(2);
|
|
||||||
TORCH_CHECK(conv_state_len >= width - 1);
|
|
||||||
|
|
||||||
CHECK_SHAPE(x, batch_size, dim, seqlen);
|
|
||||||
CHECK_SHAPE(weight, dim, width);
|
|
||||||
|
|
||||||
TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");
|
|
||||||
|
|
||||||
if (bias_.has_value()) {
|
|
||||||
auto bias = bias_.value();
|
|
||||||
TORCH_CHECK(bias.scalar_type() == weight_type);
|
|
||||||
TORCH_CHECK(bias.is_cuda());
|
|
||||||
TORCH_CHECK(bias.stride(-1) == 1);
|
|
||||||
CHECK_SHAPE(bias, dim);
|
|
||||||
}
|
|
||||||
|
|
||||||
at::Tensor out = x;
|
|
||||||
|
|
||||||
ConvParamsBase params;
|
|
||||||
set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
|
|
||||||
bias_,
|
|
||||||
silu_activation,
|
|
||||||
pad_slot_id);
|
|
||||||
params.conv_state_ptr = conv_state.data_ptr();
|
|
||||||
params.conv_state_len = conv_state_len;
|
|
||||||
// All stride are in elements, not bytes.
|
|
||||||
params.conv_state_batch_stride = conv_state.stride(0);
|
|
||||||
params.conv_state_c_stride = conv_state.stride(1);
|
|
||||||
params.conv_state_l_stride = conv_state.stride(2);
|
|
||||||
|
|
||||||
if (cache_seqlens_.has_value()) {
|
|
||||||
auto cache_seqlens = cache_seqlens_.value();
|
|
||||||
TORCH_CHECK(cache_seqlens.scalar_type() == torch::kInt32);
|
|
||||||
TORCH_CHECK(cache_seqlens.is_cuda());
|
|
||||||
TORCH_CHECK(cache_seqlens.stride(-1) == 1);
|
|
||||||
CHECK_SHAPE(cache_seqlens, batch_size);
|
|
||||||
params.cache_seqlens = cache_seqlens.data_ptr<int32_t>();
|
|
||||||
} else {
|
|
||||||
params.cache_seqlens = nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (conv_state_indices_.has_value()) {
|
|
||||||
auto conv_state_indices = conv_state_indices_.value();
|
|
||||||
TORCH_CHECK(conv_state_indices.scalar_type() == torch::kInt32)
|
|
||||||
TORCH_CHECK(conv_state_indices.is_cuda());
|
|
||||||
TORCH_CHECK(conv_state_indices.stride(0) == 1)
|
|
||||||
CHECK_SHAPE(conv_state_indices, batch_size);
|
|
||||||
|
|
||||||
int conv_state_entries = conv_state.size(0);
|
|
||||||
CHECK_SHAPE(conv_state, conv_state_entries, dim, conv_state_len);
|
|
||||||
|
|
||||||
params.conv_state_indices_ptr = conv_state_indices.data_ptr<int32_t>();
|
|
||||||
} else {
|
|
||||||
CHECK_SHAPE(conv_state, batch_size, dim, conv_state_len);
|
|
||||||
params.conv_state_indices_ptr = nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
|
||||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
|
||||||
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] {
|
|
||||||
causal_conv1d_update_cuda<input_t, weight_t>(params, stream);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
template<int kNThreads_, int kWidth_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
|
|
||||||
struct Causal_conv1d_fwd_kernel_traits {
|
|
||||||
using input_t = input_t_;
|
|
||||||
using weight_t = weight_t_;
|
|
||||||
static constexpr int kNThreads = kNThreads_;
|
|
||||||
static constexpr int kWidth = kWidth_;
|
|
||||||
static constexpr int kNBytes = sizeof(input_t);
|
|
||||||
static_assert(kNBytes == 2 || kNBytes == 4);
|
|
||||||
static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
|
|
||||||
static_assert(kWidth <= kNElts);
|
|
||||||
static constexpr bool kIsVecLoad = kIsVecLoad_;
|
|
||||||
using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
|
|
||||||
using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNElts, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
|
||||||
using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, 1, cub::BLOCK_LOAD_DIRECT>;
|
|
||||||
using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNElts, cub::BLOCK_STORE_WARP_TRANSPOSE>;
|
|
||||||
using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, 1, cub::BLOCK_STORE_DIRECT>;
|
|
||||||
static constexpr int kSmemIOSize = kIsVecLoad
|
|
||||||
? 0
|
|
||||||
: custom_max({sizeof(typename BlockLoadT::TempStorage), sizeof(typename BlockStoreT::TempStorage)});
|
|
||||||
static constexpr int kSmemExchangeSize = kNThreads * kNBytes * kNElts;
|
|
||||||
static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize;
|
|
||||||
};
|
|
||||||
|
|
||||||
template<typename Ktraits>
|
|
||||||
__global__ __launch_bounds__(Ktraits::kNThreads)
|
|
||||||
void causal_conv1d_fwd_kernel(ConvParamsBase params) {
|
|
||||||
constexpr int kWidth = Ktraits::kWidth;
|
|
||||||
constexpr int kNThreads = Ktraits::kNThreads;
|
|
||||||
constexpr int kNElts = Ktraits::kNElts;
|
|
||||||
constexpr bool kIsVecLoad = Ktraits::kIsVecLoad;
|
|
||||||
using input_t = typename Ktraits::input_t;
|
|
||||||
using vec_t = typename Ktraits::vec_t;
|
|
||||||
using weight_t = typename Ktraits::weight_t;
|
|
||||||
|
|
||||||
// Shared memory.
|
|
||||||
extern __shared__ char smem_[];
|
|
||||||
auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
|
|
||||||
auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_);
|
|
||||||
auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
|
|
||||||
auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_);
|
|
||||||
vec_t *smem_exchange = reinterpret_cast<vec_t *>(smem_ + Ktraits::kSmemIOSize);
|
|
||||||
|
|
||||||
const bool kVarlen = params.query_start_loc_ptr != nullptr;
|
|
||||||
const int tidx = threadIdx.x;
|
|
||||||
const int batch_id = blockIdx.x;
|
|
||||||
const int channel_id = blockIdx.y;
|
|
||||||
const int *query_start_loc = kVarlen ? reinterpret_cast<int *>(params.query_start_loc_ptr) : nullptr;
|
|
||||||
const int sequence_start_index = kVarlen ? query_start_loc[batch_id] : batch_id;
|
|
||||||
const int seqlen = kVarlen ? query_start_loc[batch_id + 1] - sequence_start_index : params.seqlen;
|
|
||||||
|
|
||||||
input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + sequence_start_index * params.x_batch_stride
|
|
||||||
+ channel_id * params.x_c_stride;
|
|
||||||
weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
|
|
||||||
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + sequence_start_index * params.out_batch_stride
|
|
||||||
+ channel_id * params.out_c_stride;
|
|
||||||
float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);
|
|
||||||
|
|
||||||
bool has_initial_state = params.has_initial_state_ptr == nullptr ? false
|
|
||||||
: reinterpret_cast<bool *>(params.has_initial_state_ptr)[batch_id];
|
|
||||||
|
|
||||||
int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr
|
|
||||||
: reinterpret_cast<int *>(params.cache_indices_ptr);
|
|
||||||
int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id];
|
|
||||||
// cache_index == params.pad_slot_id is defined as padding, so we exit early
|
|
||||||
if (cache_index == params.pad_slot_id){
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
input_t *conv_states = params.conv_states_ptr == nullptr ? nullptr
|
|
||||||
: reinterpret_cast<input_t *>(params.conv_states_ptr) + cache_index * params.conv_states_batch_stride + channel_id * params.conv_states_c_stride;
|
|
||||||
|
|
||||||
// Thread 0 will load the last elements of the previous chunk, so we initialize those to 0.
|
|
||||||
if (tidx == 0) {
|
|
||||||
input_t initial_state[kNElts] = {0};
|
|
||||||
if (has_initial_state) {
|
|
||||||
#pragma unroll
|
|
||||||
for (int w = 0; w < kWidth - 1; ++w){ initial_state[kNElts - 1 - (kWidth - 2) + w ] = conv_states[w]; }
|
|
||||||
}
|
|
||||||
smem_exchange[kNThreads - 1] = reinterpret_cast<vec_t *>(initial_state)[0];
|
|
||||||
}
|
|
||||||
|
|
||||||
float weight_vals[kWidth];
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }
|
|
||||||
|
|
||||||
constexpr int kChunkSize = kNThreads * kNElts;
|
|
||||||
const int n_chunks = (seqlen + kChunkSize - 1) / kChunkSize;
|
|
||||||
for (int chunk = 0; chunk < n_chunks; ++chunk) {
|
|
||||||
input_t x_vals_load[2 * kNElts] = {0};
|
|
||||||
if constexpr(kIsVecLoad) {
|
|
||||||
typename Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast<vec_t*>(x), *reinterpret_cast<vec_t (*)[1]>(&x_vals_load[kNElts]), (seqlen - chunk * kChunkSize) / kNElts);
|
|
||||||
} else {
|
|
||||||
__syncthreads();
|
|
||||||
typename Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast<input_t (*)[kNElts]>(&x_vals_load[kNElts]), seqlen - chunk * kChunkSize);
|
|
||||||
}
|
|
||||||
x += kChunkSize;
|
|
||||||
__syncthreads();
|
|
||||||
// Thread kNThreads - 1 don't write yet, so that thread 0 can read
|
|
||||||
// the last elements of the previous chunk.
|
|
||||||
if (tidx < kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1]; }
|
|
||||||
__syncthreads();
|
|
||||||
reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[tidx > 0 ? tidx - 1 : kNThreads - 1];
|
|
||||||
__syncthreads();
|
|
||||||
// Now thread kNThreads - 1 can write the last elements of the current chunk.
|
|
||||||
if (tidx == kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1]; }
|
|
||||||
|
|
||||||
float x_vals[2 * kNElts];
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < 2 * kNElts; ++i) { x_vals[i] = float(x_vals_load[i]); }
|
|
||||||
|
|
||||||
float out_vals[kNElts];
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < kNElts; ++i) {
|
|
||||||
out_vals[i] = bias_val;
|
|
||||||
#pragma unroll
|
|
||||||
for (int w = 0; w < kWidth; ++w) {
|
|
||||||
out_vals[i] += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (params.silu_activation) {
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < kNElts; ++i) {
|
|
||||||
out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i]));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
input_t out_vals_store[kNElts];
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < kNElts; ++i) { out_vals_store[i] = out_vals[i]; }
|
|
||||||
if constexpr(kIsVecLoad) {
|
|
||||||
typename Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast<vec_t*>(out), reinterpret_cast<vec_t (&)[1]>(out_vals_store), (seqlen - chunk * kChunkSize) / kNElts);
|
|
||||||
} else {
|
|
||||||
typename Ktraits::BlockStoreT(smem_store).Store(out, out_vals_store, seqlen - chunk * kChunkSize);
|
|
||||||
}
|
|
||||||
out += kChunkSize;
|
|
||||||
|
|
||||||
int final_state_position = ((seqlen - (kWidth - 1)) - (n_chunks - 1) * kChunkSize);
|
|
||||||
// in case the final state is separated between the last "smem_exchange" and
|
|
||||||
// and the one before it (chunk = n_chunks - 1 and chunk = n_chunks - 2),
|
|
||||||
// (which occurs when `final_state_position` is a non-positive index)
|
|
||||||
// we load the correct data from smem_exchange from both chunks, the last chunk iteration and the one before it
|
|
||||||
if (conv_states != nullptr && final_state_position < 0 && seqlen > kWidth){
|
|
||||||
input_t vals_load[kNElts] = {0};
|
|
||||||
if ((chunk == n_chunks - 2) && (tidx == kNThreads - 1)){
|
|
||||||
// chunk = n_chunks - 2, a segment of the final state sits in the last index
|
|
||||||
reinterpret_cast<vec_t *>(vals_load)[0] = smem_exchange[kNThreads - 1];
|
|
||||||
#pragma unroll
|
|
||||||
for (int w = 0; w < -final_state_position; ++w){
|
|
||||||
conv_states[w] = vals_load[kNElts + final_state_position + w];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if ((chunk == n_chunks - 1) && tidx == 0){
|
|
||||||
// chunk = n_chunks - 1, the second segment of the final state first positions
|
|
||||||
reinterpret_cast<vec_t *>(vals_load)[0] = smem_exchange[0];
|
|
||||||
for (int w = -final_state_position; w < kWidth - 1; ++w){
|
|
||||||
conv_states[w] = vals_load[w + final_state_position];
|
|
||||||
}
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Final state is stored in the smem_exchange last token slot,
|
|
||||||
// in case seqlen < kWidth, we would need to take the final state from the
|
|
||||||
// initial state which is stored in conv_states
|
|
||||||
// in case seqlen > kWidth, we would need to load the last kWidth - 1 data
|
|
||||||
// and load it into conv_state accordingly
|
|
||||||
int last_thread = ((seqlen - (kWidth - 1)) - (n_chunks - 1) * kChunkSize) / kNElts;
|
|
||||||
if (conv_states != nullptr && tidx == last_thread) {
|
|
||||||
input_t x_vals_load[kNElts * 2] = {0};
|
|
||||||
// in case we are on the first kWidth tokens
|
|
||||||
if (last_thread == 0 && seqlen < kWidth){
|
|
||||||
// Need to take the initial state
|
|
||||||
reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[0];
|
|
||||||
const int offset = seqlen - (kWidth - 1);
|
|
||||||
#pragma unroll
|
|
||||||
for (int w = 0; w < kWidth - 1; ++w){
|
|
||||||
// pad the existing state
|
|
||||||
if ((w - seqlen) >= 0 && has_initial_state) { conv_states[w - seqlen] = conv_states[w]; }
|
|
||||||
else if ((w - seqlen) >= 0 && !has_initial_state) { conv_states[w - seqlen] = input_t(0.0f); }
|
|
||||||
}
|
|
||||||
#pragma unroll
|
|
||||||
for (int w = 0; w < kWidth - 1; ++w){
|
|
||||||
if (offset + w >= 0)
|
|
||||||
conv_states[w] = x_vals_load[offset + w ];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
// in case the final state is in between the threads data
|
|
||||||
const int offset = ((seqlen - (kWidth - 1)) % (kNElts));
|
|
||||||
if ((offset + kWidth - 2) >= kNElts && (last_thread + 1 < kNThreads)){
|
|
||||||
// In case last_thread == kNThreads - 1, accessing last_thread + 1 will result in a
|
|
||||||
// illegal access error on H100.
|
|
||||||
// Therefore, we access last_thread + 1, only if the final state data sits there
|
|
||||||
reinterpret_cast<vec_t *>(x_vals_load)[1] = smem_exchange[last_thread + 1];
|
|
||||||
}
|
|
||||||
reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[last_thread];
|
|
||||||
#pragma unroll
|
|
||||||
for (int w = 0; w < kWidth - 1; ++w){
|
|
||||||
conv_states[w] = x_vals_load[offset + w ];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
|
|
||||||
void causal_conv1d_fwd_launch(ConvParamsBase ¶ms, cudaStream_t stream) {
|
|
||||||
static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8;
|
|
||||||
const bool kVarlen = params.query_start_loc_ptr != nullptr;
|
|
||||||
BOOL_SWITCH(params.seqlen % kNElts == 0 && !kVarlen, kIsVecLoad, [&] {
|
|
||||||
using Ktraits = Causal_conv1d_fwd_kernel_traits<kNThreads, kWidth, kIsVecLoad, input_t, weight_t>;
|
|
||||||
constexpr int kSmemSize = Ktraits::kSmemSize;
|
|
||||||
dim3 grid(params.batch, params.dim);
|
|
||||||
|
|
||||||
auto kernel = &causal_conv1d_fwd_kernel<Ktraits>;
|
|
||||||
|
|
||||||
if (kSmemSize >= 48 * 1024) {
|
|
||||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
|
||||||
(void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
|
||||||
std::cerr << "Warning (causal_conv1d fwd launch): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl;
|
|
||||||
}
|
|
||||||
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
|
|
||||||
|
|
||||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename input_t, typename weight_t>
|
|
||||||
void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream) {
|
|
||||||
if (params.width == 2) {
|
|
||||||
causal_conv1d_fwd_launch<128, 2, input_t, weight_t>(params, stream);
|
|
||||||
} else if (params.width == 3) {
|
|
||||||
causal_conv1d_fwd_launch<128, 3, input_t, weight_t>(params, stream);
|
|
||||||
} else if (params.width == 4) {
|
|
||||||
causal_conv1d_fwd_launch<128, 4, input_t, weight_t>(params, stream);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
template void causal_conv1d_fwd_cuda<float, float>(ConvParamsBase ¶ms, cudaStream_t stream);
|
|
||||||
template void causal_conv1d_fwd_cuda<at::Half, at::Half>(ConvParamsBase ¶ms, cudaStream_t stream);
|
|
||||||
template void causal_conv1d_fwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase ¶ms, cudaStream_t stream);
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
template<int kNThreads_, int kWidth_, typename input_t_, typename weight_t_>
|
|
||||||
struct Causal_conv1d_update_kernel_traits {
|
|
||||||
using input_t = input_t_;
|
|
||||||
using weight_t = weight_t_;
|
|
||||||
static constexpr int kNThreads = kNThreads_;
|
|
||||||
static constexpr int kWidth = kWidth_;
|
|
||||||
static constexpr int kNBytes = sizeof(input_t);
|
|
||||||
static_assert(kNBytes == 2 || kNBytes == 4);
|
|
||||||
};
|
|
||||||
|
|
||||||
template<typename Ktraits, bool kIsCircularBuffer>
|
|
||||||
__global__ __launch_bounds__(Ktraits::kNThreads)
|
|
||||||
void causal_conv1d_update_kernel(ConvParamsBase params) {
|
|
||||||
constexpr int kWidth = Ktraits::kWidth;
|
|
||||||
constexpr int kNThreads = Ktraits::kNThreads;
|
|
||||||
using input_t = typename Ktraits::input_t;
|
|
||||||
using weight_t = typename Ktraits::weight_t;
|
|
||||||
|
|
||||||
const int tidx = threadIdx.x;
|
|
||||||
const int batch_id = blockIdx.x;
|
|
||||||
const int channel_id = blockIdx.y * kNThreads + tidx;
|
|
||||||
if (channel_id >= params.dim) return;
|
|
||||||
|
|
||||||
input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
|
|
||||||
+ channel_id * params.x_c_stride;
|
|
||||||
|
|
||||||
// If params.conv_state_batch_indices is set, then the conv state is gathered from the conv state tensor
|
|
||||||
// along the batch axis. Otherwise, the conv state coordinate is the same as the batch id.
|
|
||||||
const int conv_state_batch_coord = params.conv_state_indices_ptr == nullptr
|
|
||||||
? batch_id
|
|
||||||
: params.conv_state_indices_ptr[batch_id];
|
|
||||||
// conv_state_batch_coord == params.pad_slot_id is defined as padding so we exit early
|
|
||||||
if (conv_state_batch_coord == params.pad_slot_id){
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
input_t *conv_state = reinterpret_cast<input_t *>(params.conv_state_ptr)
|
|
||||||
+ conv_state_batch_coord * params.conv_state_batch_stride
|
|
||||||
+ channel_id * params.conv_state_c_stride;
|
|
||||||
|
|
||||||
weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
|
|
||||||
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
|
|
||||||
+ channel_id * params.out_c_stride;
|
|
||||||
float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);
|
|
||||||
|
|
||||||
int state_len = params.conv_state_len;
|
|
||||||
int advance_len = params.seqlen;
|
|
||||||
int cache_seqlen = kIsCircularBuffer ? params.cache_seqlens[batch_id] % state_len : 0;
|
|
||||||
int update_idx = cache_seqlen - (kWidth - 1);
|
|
||||||
update_idx = update_idx < 0 ? update_idx + state_len : update_idx;
|
|
||||||
|
|
||||||
float weight_vals[kWidth] = {0};
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }
|
|
||||||
|
|
||||||
float x_vals[kWidth] = {0};
|
|
||||||
if constexpr (!kIsCircularBuffer) {
|
|
||||||
#pragma unroll 2
|
|
||||||
for (int i = 0; i < state_len - advance_len - (kWidth - 1); ++i) {
|
|
||||||
conv_state[i * params.conv_state_l_stride] = conv_state[(i + advance_len) * params.conv_state_l_stride];
|
|
||||||
}
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < kWidth - 1; ++i) {
|
|
||||||
input_t state_val = conv_state[(state_len - (kWidth - 1) + i) * params.conv_state_l_stride];
|
|
||||||
if (i < advance_len + (kWidth - 1) && state_len - advance_len - (kWidth - 1) + i >= 0) {
|
|
||||||
conv_state[(state_len - advance_len - (kWidth - 1) + i) * params.conv_state_l_stride] = state_val;
|
|
||||||
}
|
|
||||||
x_vals[i] = float(state_val);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < kWidth - 1; ++i, update_idx = update_idx + 1 >= state_len ? update_idx + 1 - state_len : update_idx + 1) {
|
|
||||||
input_t state_val = conv_state[update_idx * params.conv_state_l_stride];
|
|
||||||
x_vals[i] = float(state_val);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
#pragma unroll 2
|
|
||||||
for (int i = 0; i < params.seqlen; ++i) {
|
|
||||||
input_t x_val = x[i * params.x_l_stride];
|
|
||||||
if constexpr (!kIsCircularBuffer) {
|
|
||||||
if (i < advance_len && state_len - advance_len + i >= 0) {
|
|
||||||
conv_state[(state_len - advance_len + i) * params.conv_state_l_stride] = x_val;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
conv_state[update_idx * params.conv_state_l_stride] = x_val;
|
|
||||||
++update_idx;
|
|
||||||
update_idx = update_idx >= state_len ? update_idx - state_len : update_idx;
|
|
||||||
}
|
|
||||||
x_vals[kWidth - 1] = float(x_val);
|
|
||||||
float out_val = bias_val;
|
|
||||||
#pragma unroll
|
|
||||||
for (int j = 0; j < kWidth; ++j) { out_val += weight_vals[j] * x_vals[j]; }
|
|
||||||
if (params.silu_activation) { out_val = out_val / (1 + expf(-out_val)); }
|
|
||||||
out[i * params.out_l_stride] = input_t(out_val);
|
|
||||||
// Shift the input buffer by 1
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < kWidth - 1; ++i) { x_vals[i] = x_vals[i + 1]; }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
|
|
||||||
void causal_conv1d_update_launch(ConvParamsBase ¶ms, cudaStream_t stream) {
|
|
||||||
using Ktraits = Causal_conv1d_update_kernel_traits<kNThreads, kWidth, input_t, weight_t>;
|
|
||||||
dim3 grid(params.batch, (params.dim + kNThreads - 1) / kNThreads);
|
|
||||||
auto kernel = params.cache_seqlens == nullptr
|
|
||||||
? &causal_conv1d_update_kernel<Ktraits, false>
|
|
||||||
: &causal_conv1d_update_kernel<Ktraits, true>;
|
|
||||||
kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
|
|
||||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename input_t, typename weight_t>
|
|
||||||
void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream) {
|
|
||||||
if (params.width == 2) {
|
|
||||||
causal_conv1d_update_launch<64, 2, input_t, weight_t>(params, stream);
|
|
||||||
} else if (params.width == 3) {
|
|
||||||
causal_conv1d_update_launch<64, 3, input_t, weight_t>(params, stream);
|
|
||||||
} else if (params.width == 4) {
|
|
||||||
causal_conv1d_update_launch<64, 4, input_t, weight_t>(params, stream);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template void causal_conv1d_update_cuda<float, float>(ConvParamsBase ¶ms, cudaStream_t stream);
|
|
||||||
template void causal_conv1d_update_cuda<at::Half, at::Half>(ConvParamsBase ¶ms, cudaStream_t stream);
|
|
||||||
template void causal_conv1d_update_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase ¶ms, cudaStream_t stream);
|
|
||||||
@ -1,159 +0,0 @@
|
|||||||
/******************************************************************************
|
|
||||||
* Copyright (c) 2024, Tri Dao.
|
|
||||||
******************************************************************************/
|
|
||||||
// clang-format off
|
|
||||||
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d.h
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include <cuda_bf16.h>
|
|
||||||
#include <cuda_fp16.h>
|
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
struct ConvParamsBase {
|
|
||||||
using index_t = uint32_t;
|
|
||||||
|
|
||||||
int batch, dim, seqlen, width;
|
|
||||||
int64_t pad_slot_id;
|
|
||||||
bool silu_activation;
|
|
||||||
|
|
||||||
index_t x_batch_stride;
|
|
||||||
index_t x_c_stride;
|
|
||||||
index_t x_l_stride;
|
|
||||||
index_t weight_c_stride;
|
|
||||||
index_t weight_width_stride;
|
|
||||||
index_t out_batch_stride;
|
|
||||||
index_t out_c_stride;
|
|
||||||
index_t out_l_stride;
|
|
||||||
|
|
||||||
int conv_state_len;
|
|
||||||
index_t conv_state_batch_stride;
|
|
||||||
index_t conv_state_c_stride;
|
|
||||||
index_t conv_state_l_stride;
|
|
||||||
|
|
||||||
// Common data pointers.
|
|
||||||
void *__restrict__ x_ptr;
|
|
||||||
void *__restrict__ weight_ptr;
|
|
||||||
void *__restrict__ bias_ptr;
|
|
||||||
void *__restrict__ out_ptr;
|
|
||||||
|
|
||||||
void *__restrict__ conv_state_ptr;
|
|
||||||
void *__restrict__ query_start_loc_ptr;
|
|
||||||
void *__restrict__ has_initial_state_ptr;
|
|
||||||
void *__restrict__ cache_indices_ptr;
|
|
||||||
int32_t *__restrict__ cache_seqlens;
|
|
||||||
|
|
||||||
// For the continuous batching case. Makes it so that the mamba state for
|
|
||||||
// the current batch doesn't need to be a contiguous tensor.
|
|
||||||
int32_t *__restrict__ conv_state_indices_ptr;
|
|
||||||
|
|
||||||
void *__restrict__ seq_idx_ptr;
|
|
||||||
|
|
||||||
// No __restrict__ since initial_states could be the same as final_states.
|
|
||||||
void * initial_states_ptr;
|
|
||||||
index_t initial_states_batch_stride;
|
|
||||||
index_t initial_states_l_stride;
|
|
||||||
index_t initial_states_c_stride;
|
|
||||||
|
|
||||||
void * final_states_ptr;
|
|
||||||
index_t final_states_batch_stride;
|
|
||||||
index_t final_states_l_stride;
|
|
||||||
index_t final_states_c_stride;
|
|
||||||
|
|
||||||
void * conv_states_ptr;
|
|
||||||
index_t conv_states_batch_stride;
|
|
||||||
index_t conv_states_l_stride;
|
|
||||||
index_t conv_states_c_stride;
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
#ifndef USE_ROCM
|
|
||||||
#include <cuda_bf16.h>
|
|
||||||
|
|
||||||
template<typename T>
|
|
||||||
__device__ inline T shuffle_xor(T val, int offset) {
|
|
||||||
return __shfl_xor_sync(uint32_t(-1), val, offset);
|
|
||||||
}
|
|
||||||
|
|
||||||
constexpr size_t custom_max(std::initializer_list<size_t> ilist)
|
|
||||||
{
|
|
||||||
return std::max(ilist);
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename T>
|
|
||||||
constexpr T constexpr_min(T a, T b) {
|
|
||||||
return std::min(a, b);
|
|
||||||
}
|
|
||||||
|
|
||||||
#else
|
|
||||||
#include <hip/hip_bf16.h>
|
|
||||||
|
|
||||||
template<typename T>
|
|
||||||
__device__ inline T shuffle_xor(T val, int offset) {
|
|
||||||
return __shfl_xor(val, offset);
|
|
||||||
}
|
|
||||||
constexpr size_t custom_max(std::initializer_list<size_t> ilist)
|
|
||||||
{
|
|
||||||
return *std::max_element(ilist.begin(), ilist.end());
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename T>
|
|
||||||
constexpr T constexpr_min(T a, T b) {
|
|
||||||
return a < b ? a : b;
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
template<int BYTES> struct BytesToType {};
|
|
||||||
|
|
||||||
template<> struct BytesToType<16> {
|
|
||||||
using Type = uint4;
|
|
||||||
static_assert(sizeof(Type) == 16);
|
|
||||||
};
|
|
||||||
|
|
||||||
template<> struct BytesToType<8> {
|
|
||||||
using Type = uint64_t;
|
|
||||||
static_assert(sizeof(Type) == 8);
|
|
||||||
};
|
|
||||||
|
|
||||||
template<> struct BytesToType<4> {
|
|
||||||
using Type = uint32_t;
|
|
||||||
static_assert(sizeof(Type) == 4);
|
|
||||||
};
|
|
||||||
|
|
||||||
template<> struct BytesToType<2> {
|
|
||||||
using Type = uint16_t;
|
|
||||||
static_assert(sizeof(Type) == 2);
|
|
||||||
};
|
|
||||||
|
|
||||||
template<> struct BytesToType<1> {
|
|
||||||
using Type = uint8_t;
|
|
||||||
static_assert(sizeof(Type) == 1);
|
|
||||||
};
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
template<typename T>
|
|
||||||
struct SumOp {
|
|
||||||
__device__ inline T operator()(T const & x, T const & y) { return x + y; }
|
|
||||||
};
|
|
||||||
|
|
||||||
template<int THREADS>
|
|
||||||
struct Allreduce {
|
|
||||||
static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
|
|
||||||
template<typename T, typename Operator>
|
|
||||||
static __device__ inline T run(T x, Operator &op) {
|
|
||||||
constexpr int OFFSET = THREADS / 2;
|
|
||||||
x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
|
|
||||||
return Allreduce<OFFSET>::run(x, op);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template<>
|
|
||||||
struct Allreduce<2> {
|
|
||||||
template<typename T, typename Operator>
|
|
||||||
static __device__ inline T run(T x, Operator &op) {
|
|
||||||
x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
|
|
||||||
return x;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
@ -1,28 +0,0 @@
|
|||||||
// Inspired by
|
|
||||||
// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
|
|
||||||
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
|
|
||||||
// clang-format off
|
|
||||||
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/static_switch.h
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
/// @param COND - a boolean expression to switch by
|
|
||||||
/// @param CONST_NAME - a name given for the constexpr bool variable.
|
|
||||||
/// @param ... - code to execute for true and false
|
|
||||||
///
|
|
||||||
/// Usage:
|
|
||||||
/// ```
|
|
||||||
/// BOOL_SWITCH(flag, BoolConst, [&] {
|
|
||||||
/// some_function<BoolConst>(...);
|
|
||||||
/// });
|
|
||||||
/// ```
|
|
||||||
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
|
|
||||||
[&] { \
|
|
||||||
if (COND) { \
|
|
||||||
static constexpr bool CONST_NAME = true; \
|
|
||||||
return __VA_ARGS__(); \
|
|
||||||
} else { \
|
|
||||||
static constexpr bool CONST_NAME = false; \
|
|
||||||
return __VA_ARGS__(); \
|
|
||||||
} \
|
|
||||||
}()
|
|
||||||
16
csrc/ops.h
16
csrc/ops.h
@ -326,22 +326,6 @@ void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta,
|
|||||||
const std::optional<torch::Tensor>& has_initial_state,
|
const std::optional<torch::Tensor>& has_initial_state,
|
||||||
const torch::Tensor& ssm_states, int64_t pad_slot_id);
|
const torch::Tensor& ssm_states, int64_t pad_slot_id);
|
||||||
|
|
||||||
void causal_conv1d_update(const at::Tensor& x, const at::Tensor& conv_state,
|
|
||||||
const at::Tensor& weight,
|
|
||||||
const std::optional<at::Tensor>& bias_,
|
|
||||||
bool silu_activation,
|
|
||||||
const std::optional<at::Tensor>& cache_seqlens_,
|
|
||||||
const std::optional<at::Tensor>& conv_state_indices_,
|
|
||||||
int64_t pad_slot_id);
|
|
||||||
|
|
||||||
void causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
|
|
||||||
const std::optional<at::Tensor>& bias_,
|
|
||||||
const std::optional<at::Tensor>& conv_states,
|
|
||||||
const std::optional<at::Tensor>& query_start_loc,
|
|
||||||
const std::optional<at::Tensor>& cache_indices,
|
|
||||||
const std::optional<at::Tensor>& has_initial_state,
|
|
||||||
bool silu_activation, int64_t pad_slot_id);
|
|
||||||
|
|
||||||
using fptr_t = int64_t;
|
using fptr_t = int64_t;
|
||||||
fptr_t init_custom_ar(const std::vector<int64_t>& fake_ipc_ptrs,
|
fptr_t init_custom_ar(const std::vector<int64_t>& fake_ipc_ptrs,
|
||||||
torch::Tensor& rank_data, int64_t rank,
|
torch::Tensor& rank_data, int64_t rank,
|
||||||
|
|||||||
@ -594,28 +594,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
"int pad_slot_id) -> ()");
|
"int pad_slot_id) -> ()");
|
||||||
ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
|
ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
|
||||||
|
|
||||||
ops.def(
|
|
||||||
"causal_conv1d_update(Tensor! x,"
|
|
||||||
"Tensor! conv_state,"
|
|
||||||
"Tensor! weight,"
|
|
||||||
"Tensor? bias_,"
|
|
||||||
"bool silu_activation,"
|
|
||||||
"Tensor? cache_seqlens_,"
|
|
||||||
"Tensor? conv_state_indices,"
|
|
||||||
"int pad_slot_id) -> ()");
|
|
||||||
ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update);
|
|
||||||
|
|
||||||
ops.def(
|
|
||||||
"causal_conv1d_fwd(Tensor! x, Tensor! weight,"
|
|
||||||
"Tensor? bias_,"
|
|
||||||
"Tensor!? conv_states,"
|
|
||||||
"Tensor? query_start_loc,"
|
|
||||||
"Tensor? cache_indices,"
|
|
||||||
"Tensor? has_initial_state,"
|
|
||||||
"bool silu_activation,"
|
|
||||||
"int pad_slot_id) -> ()");
|
|
||||||
ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
|
|
||||||
|
|
||||||
#ifndef USE_ROCM
|
#ifndef USE_ROCM
|
||||||
// reorder weight for AllSpark Ampere W8A16 Fused Gemm kernel
|
// reorder weight for AllSpark Ampere W8A16 Fused Gemm kernel
|
||||||
ops.def(
|
ops.def(
|
||||||
|
|||||||
@ -6,9 +6,8 @@ from typing import Optional
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
from einops import rearrange
|
||||||
|
|
||||||
from tests.kernels.utils import opcheck
|
|
||||||
from vllm import _custom_ops as ops # noqa: F401
|
|
||||||
from vllm.attention.backends.utils import PAD_SLOT_ID
|
from vllm.attention.backends.utils import PAD_SLOT_ID
|
||||||
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
|
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
|
||||||
causal_conv1d_fn, causal_conv1d_update)
|
causal_conv1d_fn, causal_conv1d_update)
|
||||||
@ -144,79 +143,6 @@ def causal_conv1d_opcheck_fn(x: torch.Tensor,
|
|||||||
x = x.contiguous()
|
x = x.contiguous()
|
||||||
bias = bias.contiguous() if bias is not None else None
|
bias = bias.contiguous() if bias is not None else None
|
||||||
|
|
||||||
opcheck(torch.ops._C.causal_conv1d_fwd,
|
|
||||||
(x, weight, bias, conv_states, cu_seq_len, cache_indices,
|
|
||||||
has_initial_state, activation in ["silu", "swish"], pad_slot_id))
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("itype", [torch.bfloat16, torch.float])
|
|
||||||
@pytest.mark.parametrize("silu_activation", [True])
|
|
||||||
@pytest.mark.parametrize("has_bias", [True])
|
|
||||||
@pytest.mark.parametrize("has_initial_state", [True, False])
|
|
||||||
@pytest.mark.parametrize("width", [4])
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
'seqlen', [1, 8, 16, 32, 64, 128, 256, 512, 784, 1024, 1025, 2048, 4096])
|
|
||||||
@pytest.mark.parametrize('dim', [64])
|
|
||||||
@pytest.mark.parametrize('batch', [1])
|
|
||||||
def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation,
|
|
||||||
has_initial_state, itype):
|
|
||||||
device = "cuda"
|
|
||||||
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
|
|
||||||
if itype == torch.bfloat16:
|
|
||||||
rtol, atol = 1e-2, 5e-2
|
|
||||||
# set seed
|
|
||||||
current_platform.seed_everything(0)
|
|
||||||
x = torch.randn(batch, dim, seqlen, device=device,
|
|
||||||
dtype=itype).contiguous()
|
|
||||||
|
|
||||||
weight = torch.randn(dim, width, device=device, dtype=itype)
|
|
||||||
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
|
|
||||||
if has_initial_state:
|
|
||||||
initial_states = torch.randn(batch,
|
|
||||||
dim,
|
|
||||||
width - 1,
|
|
||||||
device=device,
|
|
||||||
dtype=itype)
|
|
||||||
has_initial_state_tensor = torch.ones(batch,
|
|
||||||
dtype=torch.bool,
|
|
||||||
device=x.device)
|
|
||||||
else:
|
|
||||||
initial_states = None
|
|
||||||
has_initial_state_tensor = None
|
|
||||||
x_ref = x.clone()
|
|
||||||
weight_ref = weight.clone()
|
|
||||||
bias_ref = bias.clone() if bias is not None else None
|
|
||||||
initial_states_ref = initial_states.clone(
|
|
||||||
) if initial_states is not None else None
|
|
||||||
activation = None if not silu_activation else "silu"
|
|
||||||
out = causal_conv1d_fn(x,
|
|
||||||
weight,
|
|
||||||
bias,
|
|
||||||
activation=activation,
|
|
||||||
conv_states=initial_states,
|
|
||||||
has_initial_state=has_initial_state_tensor)
|
|
||||||
out_ref, final_states_ref = causal_conv1d_ref(
|
|
||||||
x_ref,
|
|
||||||
weight_ref,
|
|
||||||
bias_ref,
|
|
||||||
initial_states=initial_states_ref,
|
|
||||||
return_final_states=True,
|
|
||||||
activation=activation)
|
|
||||||
if has_initial_state:
|
|
||||||
assert initial_states is not None and final_states_ref is not None
|
|
||||||
assert torch.allclose(initial_states,
|
|
||||||
final_states_ref,
|
|
||||||
rtol=rtol,
|
|
||||||
atol=atol)
|
|
||||||
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
|
||||||
|
|
||||||
causal_conv1d_opcheck_fn(x,
|
|
||||||
weight,
|
|
||||||
bias,
|
|
||||||
activation=activation,
|
|
||||||
conv_states=initial_states,
|
|
||||||
has_initial_state=has_initial_state_tensor)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("itype", [torch.bfloat16])
|
@pytest.mark.parametrize("itype", [torch.bfloat16])
|
||||||
@pytest.mark.parametrize("silu_activation", [False, True])
|
@pytest.mark.parametrize("silu_activation", [False, True])
|
||||||
@ -255,22 +181,19 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation,
|
|||||||
assert torch.equal(conv_state, conv_state_ref)
|
assert torch.equal(conv_state, conv_state_ref)
|
||||||
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
||||||
|
|
||||||
opcheck(torch.ops._C.causal_conv1d_update,
|
|
||||||
(x, conv_state, weight, bias, activation
|
|
||||||
in ["silu", "swish"], None, None, PAD_SLOT_ID))
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("itype",
|
@pytest.mark.parametrize("itype",
|
||||||
[torch.float32, torch.float16, torch.bfloat16])
|
[torch.float32, torch.float16, torch.bfloat16])
|
||||||
@pytest.mark.parametrize("silu_activation", [False, True])
|
@pytest.mark.parametrize("silu_activation", [False, True])
|
||||||
@pytest.mark.parametrize("has_bias", [False, True])
|
@pytest.mark.parametrize("has_bias", [False, True])
|
||||||
@pytest.mark.parametrize("seqlen", [1, 4, 5])
|
@pytest.mark.parametrize("seqlen", [1, 3])
|
||||||
@pytest.mark.parametrize("width", [2, 3, 4])
|
@pytest.mark.parametrize("width", [3, 4])
|
||||||
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
|
@pytest.mark.parametrize("dim", [2048 + 16, 4096])
|
||||||
# tests correctness in case subset of the sequences are padded
|
# tests correctness in case subset of the sequences are padded
|
||||||
@pytest.mark.parametrize("with_padding", [True, False])
|
@pytest.mark.parametrize("with_padding", [True, False])
|
||||||
def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width,
|
@pytest.mark.parametrize("batch_size", [3])
|
||||||
seqlen, has_bias,
|
def test_causal_conv1d_update_with_batch_gather(batch_size, with_padding, dim,
|
||||||
|
width, seqlen, has_bias,
|
||||||
silu_activation, itype):
|
silu_activation, itype):
|
||||||
device = "cuda"
|
device = "cuda"
|
||||||
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
|
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
|
||||||
@ -280,12 +203,15 @@ def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width,
|
|||||||
# set seed
|
# set seed
|
||||||
current_platform.seed_everything(0)
|
current_platform.seed_everything(0)
|
||||||
|
|
||||||
batch_size = 3
|
|
||||||
padding = 5 if with_padding else 0
|
padding = 5 if with_padding else 0
|
||||||
padded_batch_size = batch_size + padding
|
padded_batch_size = batch_size + padding
|
||||||
|
# total_entries = number of cache line
|
||||||
total_entries = 10 * batch_size
|
total_entries = 10 * batch_size
|
||||||
|
|
||||||
x = torch.randn(padded_batch_size, dim, 1, device=device, dtype=itype)
|
# x will be (batch, dim, seqlen) with contiguous along dim-axis
|
||||||
|
x = torch.randn(padded_batch_size, seqlen, dim, device=device,
|
||||||
|
dtype=itype).transpose(1, 2)
|
||||||
|
|
||||||
x_ref = x.clone()
|
x_ref = x.clone()
|
||||||
|
|
||||||
conv_state_indices = torch.randperm(total_entries)[:batch_size].to(
|
conv_state_indices = torch.randperm(total_entries)[:batch_size].to(
|
||||||
@ -300,17 +226,22 @@ def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width,
|
|||||||
[PAD_SLOT_ID] * padding, dtype=torch.int32, device=device)
|
[PAD_SLOT_ID] * padding, dtype=torch.int32, device=device)
|
||||||
],
|
],
|
||||||
dim=0)
|
dim=0)
|
||||||
|
|
||||||
|
# conv_state will be (cache_lines, dim, state_len)
|
||||||
|
# with contiguous along dim-axis
|
||||||
conv_state = torch.randn(total_entries,
|
conv_state = torch.randn(total_entries,
|
||||||
dim,
|
|
||||||
width - 1,
|
width - 1,
|
||||||
|
dim,
|
||||||
device=device,
|
device=device,
|
||||||
dtype=itype)
|
dtype=itype).transpose(1, 2)
|
||||||
|
|
||||||
conv_state_for_padding_test = conv_state.clone()
|
conv_state_for_padding_test = conv_state.clone()
|
||||||
|
|
||||||
weight = torch.randn(dim, width, device=device, dtype=itype)
|
weight = torch.randn(dim, width, device=device, dtype=itype)
|
||||||
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
|
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
|
||||||
conv_state_ref = conv_state[conv_state_indices, :].detach().clone()
|
conv_state_ref = conv_state[conv_state_indices, :].detach().clone()
|
||||||
activation = None if not silu_activation else "silu"
|
activation = None if not silu_activation else "silu"
|
||||||
|
|
||||||
out = causal_conv1d_update(x,
|
out = causal_conv1d_update(x,
|
||||||
conv_state,
|
conv_state,
|
||||||
weight,
|
weight,
|
||||||
@ -325,26 +256,21 @@ def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width,
|
|||||||
activation=activation)
|
activation=activation)
|
||||||
|
|
||||||
assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref)
|
assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref)
|
||||||
assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol)
|
|
||||||
assert torch.equal(conv_state[unused_states_bool],
|
assert torch.equal(conv_state[unused_states_bool],
|
||||||
conv_state_for_padding_test[unused_states_bool])
|
conv_state_for_padding_test[unused_states_bool])
|
||||||
|
assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol)
|
||||||
opcheck(torch.ops._C.causal_conv1d_update,
|
|
||||||
(x, conv_state, weight, bias, activation
|
|
||||||
in ["silu", "swish"], None, padded_state_indices, PAD_SLOT_ID))
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("itype", [torch.bfloat16])
|
@pytest.mark.parametrize("itype", [torch.bfloat16])
|
||||||
@pytest.mark.parametrize("silu_activation", [True])
|
@pytest.mark.parametrize("silu_activation", [True])
|
||||||
@pytest.mark.parametrize("has_bias", [True])
|
@pytest.mark.parametrize("has_bias", [True])
|
||||||
@pytest.mark.parametrize("width", [4])
|
@pytest.mark.parametrize("width", [4])
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize('seqlen', [8, 30, 249, 2049, 4096])
|
||||||
'seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 2049, 4096])
|
|
||||||
@pytest.mark.parametrize('dim', [64, 4096])
|
@pytest.mark.parametrize('dim', [64, 4096])
|
||||||
# tests correctness in case subset of the sequences are padded
|
|
||||||
@pytest.mark.parametrize('with_padding', [True, False])
|
@pytest.mark.parametrize('with_padding', [True, False])
|
||||||
def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,
|
@pytest.mark.parametrize('batch', [4, 10])
|
||||||
silu_activation, itype):
|
def test_causal_conv1d_varlen(batch, with_padding, dim, seqlen, width,
|
||||||
|
has_bias, silu_activation, itype):
|
||||||
device = "cuda"
|
device = "cuda"
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
|
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
|
||||||
@ -353,14 +279,13 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,
|
|||||||
# set seed
|
# set seed
|
||||||
current_platform.seed_everything(0)
|
current_platform.seed_everything(0)
|
||||||
seqlens = []
|
seqlens = []
|
||||||
batch_size = 4
|
batch_size = batch
|
||||||
if seqlen < 10:
|
|
||||||
batch_size = 1
|
|
||||||
padding = 3 if with_padding else 0
|
padding = 3 if with_padding else 0
|
||||||
padded_batch_size = batch_size + padding
|
padded_batch_size = batch_size + padding
|
||||||
nsplits = padded_batch_size - 1
|
nsplits = padded_batch_size - 1
|
||||||
|
|
||||||
eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values
|
eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values
|
||||||
|
|
||||||
seqlens.append(
|
seqlens.append(
|
||||||
torch.diff(
|
torch.diff(
|
||||||
torch.cat(
|
torch.cat(
|
||||||
@ -373,19 +298,22 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,
|
|||||||
cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32)
|
cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32)
|
||||||
cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum],
|
cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum],
|
||||||
dim=0)
|
dim=0)
|
||||||
x = torch.randn(1, 4096 + dim + 64, seqlen, device=device,
|
x = rearrange(
|
||||||
dtype=itype)[:, 4096:4096 + dim, :]
|
torch.randn(1, seqlen, 4096 + dim + 64, device=device, dtype=itype),
|
||||||
|
"b s d -> b d s")[:, 4096:4096 + dim, :]
|
||||||
|
|
||||||
weight = torch.randn(dim, width, device=device, dtype=itype)
|
weight = torch.randn(dim, width, device=device, dtype=itype)
|
||||||
|
|
||||||
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
|
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
|
||||||
x_ref = x.clone()
|
x_ref = x.clone()
|
||||||
weight_ref = weight.clone()
|
weight_ref = weight.clone()
|
||||||
bias_ref = bias.clone() if bias is not None else None
|
bias_ref = bias.clone() if bias is not None else None
|
||||||
activation = None if not silu_activation else "silu"
|
activation = None if not silu_activation else "silu"
|
||||||
final_states = torch.randn(total_entries,
|
final_states = torch.randn(total_entries,
|
||||||
dim,
|
|
||||||
width - 1,
|
width - 1,
|
||||||
|
dim,
|
||||||
device=x.device,
|
device=x.device,
|
||||||
dtype=x.dtype)
|
dtype=x.dtype).transpose(1, 2)
|
||||||
final_states_ref = final_states.clone()
|
final_states_ref = final_states.clone()
|
||||||
has_initial_states = torch.randint(0,
|
has_initial_states = torch.randint(0,
|
||||||
2, (cumsum.shape[0] - 1, ),
|
2, (cumsum.shape[0] - 1, ),
|
||||||
@ -400,10 +328,16 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,
|
|||||||
[PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
|
[PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
|
||||||
],
|
],
|
||||||
dim=-1)
|
dim=-1)
|
||||||
|
out = causal_conv1d_fn(x.squeeze(0),
|
||||||
|
weight,
|
||||||
|
bias=bias,
|
||||||
|
conv_states=final_states,
|
||||||
|
query_start_loc=cumsum.cuda(),
|
||||||
|
cache_indices=padded_state_indices,
|
||||||
|
has_initial_state=has_initial_states,
|
||||||
|
activation=activation,
|
||||||
|
pad_slot_id=PAD_SLOT_ID)
|
||||||
|
|
||||||
out = causal_conv1d_fn(x.squeeze(0), weight, bias, cumsum.cuda(),
|
|
||||||
padded_state_indices, has_initial_states,
|
|
||||||
final_states, activation, PAD_SLOT_ID)
|
|
||||||
out_ref = []
|
out_ref = []
|
||||||
out_ref_b = []
|
out_ref_b = []
|
||||||
|
|
||||||
@ -426,13 +360,9 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,
|
|||||||
out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=2))
|
out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=2))
|
||||||
out_ref_tensor = torch.cat(out_ref, dim=0)
|
out_ref_tensor = torch.cat(out_ref, dim=0)
|
||||||
|
|
||||||
unpadded_out = out[:, :out_ref_tensor.shape[-1]]
|
|
||||||
assert torch.allclose(unpadded_out, out_ref_tensor, rtol=rtol, atol=atol)
|
|
||||||
assert torch.allclose(final_states[state_indices],
|
assert torch.allclose(final_states[state_indices],
|
||||||
final_states_ref[state_indices],
|
final_states_ref[state_indices],
|
||||||
rtol=rtol,
|
rtol=rtol,
|
||||||
atol=atol)
|
atol=atol)
|
||||||
|
unpadded_out = out[:, :out_ref_tensor.shape[-1]]
|
||||||
causal_conv1d_opcheck_fn(x.squeeze(0), weight, bias, cumsum.cuda(),
|
assert torch.allclose(unpadded_out, out_ref_tensor, rtol=rtol, atol=atol)
|
||||||
padded_state_indices, has_initial_states,
|
|
||||||
final_states, activation)
|
|
||||||
|
|||||||
@ -6,11 +6,11 @@ import torch
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
|
|
||||||
from vllm.model_executor.layers.mamba.mamba2_metadata import (
|
|
||||||
_query_start_loc_to_chunk_indices_offsets)
|
|
||||||
from vllm.model_executor.layers.mamba.ops.ssd_combined import (
|
from vllm.model_executor.layers.mamba.ops.ssd_combined import (
|
||||||
mamba_chunk_scan_combined)
|
mamba_chunk_scan_combined)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.v1.attention.backends.mamba_attn import (
|
||||||
|
_query_start_loc_to_chunk_indices_offsets)
|
||||||
|
|
||||||
# Added by the IBM Team, 2024
|
# Added by the IBM Team, 2024
|
||||||
|
|
||||||
|
|||||||
@ -963,17 +963,17 @@ def cutlass_fp4_moe_mm(a_tensors: torch.Tensor, b_tensors: torch.Tensor,
|
|||||||
expert_offsets: torch.Tensor, sf_offsets: torch.Tensor,
|
expert_offsets: torch.Tensor, sf_offsets: torch.Tensor,
|
||||||
out_dtype: torch.dtype, device: torch.device):
|
out_dtype: torch.dtype, device: torch.device):
|
||||||
"""
|
"""
|
||||||
An FP4 Blockscaled Group Gemm that takes in a_tensors, b_tensors and runs
|
An FP4 Blockscaled Group Gemm that takes in a_tensors, b_tensors and runs
|
||||||
the gemms for each combination based on the specified problem sizes.
|
the gemms for each combination based on the specified problem sizes.
|
||||||
|
|
||||||
This is used as the MoE gemm during NVFP4 Quantized FusedMoE forward.
|
This is used as the MoE gemm during NVFP4 Quantized FusedMoE forward.
|
||||||
- a/b_tensors: the NVFP4 a_ptrs and b_ptrs tensors which are quantized
|
- a/b_tensors: the NVFP4 a_ptrs and b_ptrs tensors which are quantized
|
||||||
input and expert weights.
|
input and expert weights.
|
||||||
- a_/b_scales: The blockscales in FP8-E4M3 precision
|
- a_/b_scales: The blockscales in FP8-E4M3 precision
|
||||||
- expert_offsets/sf_offsets: Indices that mark at which token index
|
- expert_offsets/sf_offsets: Indices that mark at which token index
|
||||||
each expert begins its computation. The number of tokens
|
each expert begins its computation. The number of tokens
|
||||||
computed with expert E is expert_offsets[E + 1] -
|
computed with expert E is expert_offsets[E + 1] -
|
||||||
expert_offsets[E] And the sf_size per expert is
|
expert_offsets[E] And the sf_size per expert is
|
||||||
sf_offset[E+1] - sf_offset[E]
|
sf_offset[E+1] - sf_offset[E]
|
||||||
- problem_sizes: MxNxK sizes of each expert's multiplication in two grouped
|
- problem_sizes: MxNxK sizes of each expert's multiplication in two grouped
|
||||||
MMs used in the fused MoE operation.
|
MMs used in the fused MoE operation.
|
||||||
@ -1464,30 +1464,6 @@ def ggml_moe_get_block_size(quant_type: int) -> int:
|
|||||||
|
|
||||||
|
|
||||||
# mamba
|
# mamba
|
||||||
def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor,
|
|
||||||
bias_: Optional[torch.Tensor],
|
|
||||||
conv_states: Optional[torch.Tensor],
|
|
||||||
query_start_loc: Optional[torch.Tensor],
|
|
||||||
cache_indices: Optional[torch.Tensor],
|
|
||||||
has_initial_state: Optional[torch.Tensor],
|
|
||||||
silu_activation: bool, pad_slot_id: int):
|
|
||||||
torch.ops._C.causal_conv1d_fwd(x, weight, bias_, conv_states,
|
|
||||||
query_start_loc, cache_indices,
|
|
||||||
has_initial_state, silu_activation,
|
|
||||||
pad_slot_id)
|
|
||||||
|
|
||||||
|
|
||||||
def causal_conv1d_update(x: torch.Tensor, conv_state: torch.Tensor,
|
|
||||||
weight: torch.Tensor, bias_: Optional[torch.Tensor],
|
|
||||||
silu_activation: bool,
|
|
||||||
cache_seqlens: Optional[torch.Tensor],
|
|
||||||
conv_state_indices: Optional[torch.Tensor],
|
|
||||||
pad_slot_id: int):
|
|
||||||
torch.ops._C.causal_conv1d_update(x, conv_state, weight, bias_,
|
|
||||||
silu_activation, cache_seqlens,
|
|
||||||
conv_state_indices, pad_slot_id)
|
|
||||||
|
|
||||||
|
|
||||||
def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
|
def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
|
||||||
B: torch.Tensor, C: torch.Tensor,
|
B: torch.Tensor, C: torch.Tensor,
|
||||||
D_: Optional[torch.Tensor], z_: Optional[torch.Tensor],
|
D_: Optional[torch.Tensor], z_: Optional[torch.Tensor],
|
||||||
|
|||||||
@ -1,14 +1,18 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
import math
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.attention.backends.abstract import AttentionMetadata
|
from vllm.attention.backends.abstract import AttentionMetadata
|
||||||
from vllm.attention.backends.placeholder_attn import (
|
from vllm.attention.backends.placeholder_attn import (
|
||||||
PlaceholderAttentionMetadata)
|
PlaceholderAttentionMetadata)
|
||||||
|
from vllm.attention.backends.utils import PAD_SLOT_ID
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.v1.attention.backends.mamba_attn import (
|
||||||
|
Mamba2AttentionMetadata, _query_start_loc_to_chunk_indices_offsets)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -21,6 +25,29 @@ class Mamba2Metadata:
|
|||||||
seq_idx: torch.Tensor
|
seq_idx: torch.Tensor
|
||||||
chunk_indices: torch.Tensor
|
chunk_indices: torch.Tensor
|
||||||
chunk_offsets: torch.Tensor
|
chunk_offsets: torch.Tensor
|
||||||
|
"""
|
||||||
|
With continuous batching layout of `x` in vLLM, to enable a Triton program
|
||||||
|
to handle a request in parallel, two supporting tensors are used
|
||||||
|
(batch_ptr, token_chunk_offset_ptr)
|
||||||
|
BLOCK_M = the # tokens to be handled by a Triton program
|
||||||
|
(can be customized for different hardware)
|
||||||
|
|
||||||
|
nums_dict:
|
||||||
|
tracks the data associated with a given value of BLOCK_M
|
||||||
|
BLOCK_M = #tokens handled by a Triton program
|
||||||
|
cu_seqlen: total tokens per batch
|
||||||
|
(used as flag to update other data at each new input)
|
||||||
|
batch_ptr: tracks batch-id handled by the Triton program
|
||||||
|
token_chunk_offset_ptr: tracks token group_idx handled by the Triton program
|
||||||
|
(Triton implementation of causal_conv1d handles parallelism in 3-axes
|
||||||
|
- feature-axis
|
||||||
|
- batch-axis
|
||||||
|
- sequence-axis)
|
||||||
|
"""
|
||||||
|
nums_dict: Optional[dict] = None
|
||||||
|
cu_seqlen: Optional[int] = None
|
||||||
|
batch_ptr: Optional[torch.tensor] = None
|
||||||
|
token_chunk_offset_ptr: Optional[torch.tensor] = None
|
||||||
|
|
||||||
|
|
||||||
def get_platform_metadata_classes() -> tuple[type[AttentionMetadata], ...]:
|
def get_platform_metadata_classes() -> tuple[type[AttentionMetadata], ...]:
|
||||||
@ -38,45 +65,10 @@ def get_platform_metadata_classes() -> tuple[type[AttentionMetadata], ...]:
|
|||||||
f"Unsupported platform for Mamba2: {current_platform.device_type}")
|
f"Unsupported platform for Mamba2: {current_platform.device_type}")
|
||||||
|
|
||||||
|
|
||||||
def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor,
|
|
||||||
chunk_size: int,
|
|
||||||
total_seqlens: int):
|
|
||||||
|
|
||||||
cu_seqlens = query_start_loc[1:] # remove prepended 0
|
|
||||||
|
|
||||||
# outputs will have length expansion of chunks that do not divide
|
|
||||||
# chunk_size
|
|
||||||
N = math.ceil(total_seqlens / chunk_size) + (cu_seqlens[:-1] % chunk_size
|
|
||||||
> 0).sum()
|
|
||||||
chunk_indices = torch.arange(N,
|
|
||||||
dtype=torch.int,
|
|
||||||
device=query_start_loc.device)
|
|
||||||
chunk_offsets = torch.zeros((N, ),
|
|
||||||
dtype=torch.int,
|
|
||||||
device=query_start_loc.device)
|
|
||||||
|
|
||||||
p = 0 # num of insertions
|
|
||||||
for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]):
|
|
||||||
|
|
||||||
# if does not divide chunk_size, then there is one chunk insertion
|
|
||||||
p += (s % chunk_size > 0)
|
|
||||||
|
|
||||||
# get the dimensions
|
|
||||||
# - the + 1 for _e is to shift the boundary by one chunk
|
|
||||||
# - this shifting is not needed if chunk_size divides e
|
|
||||||
_s, _e = s // chunk_size + p, e // chunk_size + p + (e % chunk_size
|
|
||||||
> 0)
|
|
||||||
|
|
||||||
# adjust inidces and offsets
|
|
||||||
chunk_indices[_s:_e] -= p
|
|
||||||
chunk_offsets[_s] = s % chunk_size
|
|
||||||
|
|
||||||
return chunk_indices, chunk_offsets
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_mamba2_metadata(
|
def prepare_mamba2_metadata(
|
||||||
chunk_size: int,
|
chunk_size: int,
|
||||||
attn_metadata: AttentionMetadata,
|
attn_metadata: AttentionMetadata,
|
||||||
|
mamba2_metadata=None,
|
||||||
) -> Mamba2Metadata:
|
) -> Mamba2Metadata:
|
||||||
|
|
||||||
# compute number of prefill and decode requests
|
# compute number of prefill and decode requests
|
||||||
@ -96,12 +88,12 @@ def prepare_mamba2_metadata(
|
|||||||
attn_metadata_instances = get_platform_metadata_classes()
|
attn_metadata_instances = get_platform_metadata_classes()
|
||||||
if (isinstance(attn_metadata, attn_metadata_instances)
|
if (isinstance(attn_metadata, attn_metadata_instances)
|
||||||
and attn_metadata.context_lens_tensor is not None):
|
and attn_metadata.context_lens_tensor is not None):
|
||||||
has_initial_states = \
|
# precompute flag to avoid device syncs later in mamba2 layer
|
||||||
attn_metadata.context_lens_tensor[:num_prefills] > 0 #[batch,]
|
# forwards
|
||||||
# precompute flag to avoid device syncs in mamba2 layer forwards
|
|
||||||
# prep is only needed for mamba2 ssd prefill processing
|
# prep is only needed for mamba2 ssd prefill processing
|
||||||
prep_initial_states = torch.any(has_initial_states).item()
|
has_initial_states = attn_metadata.context_lens_tensor > 0
|
||||||
|
prep_initial_states = torch.any(
|
||||||
|
has_initial_states[:num_prefills]).item()
|
||||||
query_start_loc = attn_metadata.query_start_loc[:num_prefills + 1]
|
query_start_loc = attn_metadata.query_start_loc[:num_prefills + 1]
|
||||||
seq_idx = torch.repeat_interleave(torch.arange(
|
seq_idx = torch.repeat_interleave(torch.arange(
|
||||||
num_prefills, dtype=torch.int32, device=query_start_loc.device),
|
num_prefills, dtype=torch.int32, device=query_start_loc.device),
|
||||||
@ -117,9 +109,78 @@ def prepare_mamba2_metadata(
|
|||||||
_query_start_loc_to_chunk_indices_offsets(
|
_query_start_loc_to_chunk_indices_offsets(
|
||||||
query_start_loc, chunk_size, num_prefill_tokens)
|
query_start_loc, chunk_size, num_prefill_tokens)
|
||||||
|
|
||||||
|
if mamba2_metadata is not None:
|
||||||
|
mamba2_metadata.has_initial_states = has_initial_states
|
||||||
|
mamba2_metadata.prep_initial_states = prep_initial_states
|
||||||
|
mamba2_metadata.chunk_size = chunk_size
|
||||||
|
mamba2_metadata.seq_idx = seq_idx
|
||||||
|
mamba2_metadata.chunk_indices = chunk_indices
|
||||||
|
mamba2_metadata.chunk_offsets = chunk_offsets
|
||||||
|
# We use 1 reset flag:
|
||||||
|
# * mamba2_metadata.cu_seqlen is None
|
||||||
|
# update config specific to (each input)
|
||||||
|
# (become available at first layer, e.g. conv_weights)
|
||||||
|
mamba2_metadata.cu_seqlen = None # suppose to be updated at each input
|
||||||
|
|
||||||
|
return mamba2_metadata
|
||||||
return Mamba2Metadata(has_initial_states=has_initial_states,
|
return Mamba2Metadata(has_initial_states=has_initial_states,
|
||||||
prep_initial_states=prep_initial_states,
|
prep_initial_states=prep_initial_states,
|
||||||
chunk_size=chunk_size,
|
chunk_size=chunk_size,
|
||||||
seq_idx=seq_idx,
|
seq_idx=seq_idx,
|
||||||
chunk_indices=chunk_indices,
|
chunk_indices=chunk_indices,
|
||||||
chunk_offsets=chunk_offsets)
|
chunk_offsets=chunk_offsets)
|
||||||
|
|
||||||
|
|
||||||
|
def update_metadata(x: torch.Tensor, query_start_loc: torch.Tensor,
|
||||||
|
mamba2_metadata: Union[Mamba2Metadata,
|
||||||
|
Mamba2AttentionMetadata]):
|
||||||
|
"""
|
||||||
|
this is triggered upon handling a new input at the first layer
|
||||||
|
"""
|
||||||
|
dim, cu_seqlen = x.shape
|
||||||
|
mamba2_metadata.cu_seqlen = cu_seqlen
|
||||||
|
seqlens = np.diff(query_start_loc.to('cpu'))
|
||||||
|
nums_dict = {} # type: ignore
|
||||||
|
for BLOCK_M in [8]: # cover all BLOCK_M values
|
||||||
|
nums = -(-seqlens // BLOCK_M)
|
||||||
|
nums_dict[BLOCK_M] = {}
|
||||||
|
nums_dict[BLOCK_M]['nums'] = nums
|
||||||
|
nums_dict[BLOCK_M]['tot'] = nums.sum().item()
|
||||||
|
mlist = torch.from_numpy(np.repeat(np.arange(len(nums)), nums))
|
||||||
|
nums_dict[BLOCK_M]['mlist'] = mlist
|
||||||
|
mlist_len = len(nums_dict[BLOCK_M]['mlist'])
|
||||||
|
nums_dict[BLOCK_M]['mlist_len'] = mlist_len
|
||||||
|
MAX_NUM_PROGRAMS = max(1024, mlist_len) * 2
|
||||||
|
offsetlist = [] # type: ignore
|
||||||
|
for idx, num in enumerate(nums):
|
||||||
|
offsetlist.extend(range(num))
|
||||||
|
offsetlist = torch.tensor(offsetlist, dtype=torch.int32)
|
||||||
|
nums_dict[BLOCK_M]['offsetlist'] = offsetlist
|
||||||
|
|
||||||
|
if mamba2_metadata.batch_ptr is None:
|
||||||
|
# Update default value after class definition
|
||||||
|
#mamba2_metadata.MAX_NUM_PROGRAMS *= 2
|
||||||
|
mamba2_metadata.batch_ptr = torch.full((MAX_NUM_PROGRAMS, ),
|
||||||
|
PAD_SLOT_ID,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device='cuda')
|
||||||
|
mamba2_metadata.token_chunk_offset_ptr = torch.full(
|
||||||
|
(MAX_NUM_PROGRAMS, ),
|
||||||
|
PAD_SLOT_ID,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device='cuda')
|
||||||
|
else:
|
||||||
|
if mamba2_metadata.batch_ptr.nelement() < MAX_NUM_PROGRAMS:
|
||||||
|
mamba2_metadata.batch_ptr.resize_(MAX_NUM_PROGRAMS).fill_(
|
||||||
|
PAD_SLOT_ID)
|
||||||
|
mamba2_metadata.token_chunk_offset_ptr.resize_( # type: ignore
|
||||||
|
MAX_NUM_PROGRAMS).fill_(PAD_SLOT_ID)
|
||||||
|
|
||||||
|
mamba2_metadata.batch_ptr[0:mlist_len].copy_(mlist)
|
||||||
|
mamba2_metadata.token_chunk_offset_ptr[ # type: ignore
|
||||||
|
0:mlist_len].copy_(offsetlist)
|
||||||
|
nums_dict[BLOCK_M]['batch_ptr'] = mamba2_metadata.batch_ptr
|
||||||
|
nums_dict[BLOCK_M]['token_chunk_offset_ptr'] = (
|
||||||
|
mamba2_metadata.token_chunk_offset_ptr) # type: ignore
|
||||||
|
mamba2_metadata.nums_dict = nums_dict
|
||||||
|
return mamba2_metadata
|
||||||
|
|||||||
@ -159,7 +159,7 @@ class MambaMixer(CustomOp):
|
|||||||
hidden_states = causal_conv1d_fn(
|
hidden_states = causal_conv1d_fn(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
conv_weights,
|
conv_weights,
|
||||||
self.conv1d.bias,
|
bias=self.conv1d.bias,
|
||||||
activation=self.activation,
|
activation=self.activation,
|
||||||
conv_states=mamba_cache_params.conv_state,
|
conv_states=mamba_cache_params.conv_state,
|
||||||
has_initial_state=attn_metadata.context_lens_tensor > 0,
|
has_initial_state=attn_metadata.context_lens_tensor > 0,
|
||||||
|
|||||||
@ -17,7 +17,8 @@ from vllm.forward_context import get_forward_context
|
|||||||
from vllm.model_executor.custom_op import CustomOp
|
from vllm.model_executor.custom_op import CustomOp
|
||||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||||
RowParallelLinear)
|
RowParallelLinear)
|
||||||
from vllm.model_executor.layers.mamba.mamba2_metadata import Mamba2Metadata
|
from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata,
|
||||||
|
update_metadata)
|
||||||
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
|
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
|
||||||
causal_conv1d_fn, causal_conv1d_update)
|
causal_conv1d_fn, causal_conv1d_update)
|
||||||
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
|
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
|
||||||
@ -161,9 +162,9 @@ def mamba_v2_sharded_weight_loader(
|
|||||||
tp_size: int,
|
tp_size: int,
|
||||||
tp_rank: int,
|
tp_rank: int,
|
||||||
) -> LoaderFunction:
|
) -> LoaderFunction:
|
||||||
"""Create a weight loader for mamba v2. This ensures that the projections
|
"""Create a weight loader for mamba v2. This ensures that the projections
|
||||||
are correctly sharded so that they can be split into x, B, C. It also
|
are correctly sharded so that they can be split into x, B, C. It also
|
||||||
ensures that all the groups corresponding to a head shard is placed
|
ensures that all the groups corresponding to a head shard is placed
|
||||||
together with it.
|
together with it.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -458,9 +459,11 @@ class MambaMixer2(CustomOp):
|
|||||||
if attn_metadata is not None:
|
if attn_metadata is not None:
|
||||||
assert isinstance(attn_metadata, dict)
|
assert isinstance(attn_metadata, dict)
|
||||||
attn_metadata = attn_metadata[self.prefix]
|
attn_metadata = attn_metadata[self.prefix]
|
||||||
|
mamba2_metadata = attn_metadata
|
||||||
assert isinstance(attn_metadata, Mamba2AttentionMetadata)
|
assert isinstance(attn_metadata, Mamba2AttentionMetadata)
|
||||||
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
|
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||||
conv_state = self_kv_cache[0]
|
# conv_state = (..., dim, width-1) yet contiguous along 'dim'
|
||||||
|
conv_state = self_kv_cache[0].transpose(-1, -2)
|
||||||
ssm_state = self_kv_cache[1]
|
ssm_state = self_kv_cache[1]
|
||||||
state_indices_tensor = attn_metadata.state_indices_tensor
|
state_indices_tensor = attn_metadata.state_indices_tensor
|
||||||
has_initial_states_p = attn_metadata.has_initial_states
|
has_initial_states_p = attn_metadata.has_initial_states
|
||||||
@ -531,6 +534,7 @@ class MambaMixer2(CustomOp):
|
|||||||
# NOTE: V0 put prefill before decode, v1 puts decode before prefill
|
# NOTE: V0 put prefill before decode, v1 puts decode before prefill
|
||||||
# Separate prefill and decode by splitting varlen input
|
# Separate prefill and decode by splitting varlen input
|
||||||
# Split along token dimension
|
# Split along token dimension
|
||||||
|
# NOTE: V0 put prefill before decode, v1 puts decode before prefill
|
||||||
if envs.VLLM_USE_V1:
|
if envs.VLLM_USE_V1:
|
||||||
hidden_states_B_C_d, hidden_states_B_C_p = torch.split(
|
hidden_states_B_C_d, hidden_states_B_C_p = torch.split(
|
||||||
hidden_states_B_C,
|
hidden_states_B_C,
|
||||||
@ -579,8 +583,13 @@ class MambaMixer2(CustomOp):
|
|||||||
# 2. Convolution sequence transformation
|
# 2. Convolution sequence transformation
|
||||||
# - "cache_indices" updates the conv_state cache in positions
|
# - "cache_indices" updates the conv_state cache in positions
|
||||||
# pointed to by "state_indices_tensor"
|
# pointed to by "state_indices_tensor"
|
||||||
|
x = hidden_states_B_C_p.transpose(
|
||||||
|
0, 1) # this is the form that causal-conv see
|
||||||
|
if mamba2_metadata.cu_seqlen is None:
|
||||||
|
mamba2_metadata = update_metadata(
|
||||||
|
x, attn_metadata.query_start_loc, mamba2_metadata)
|
||||||
hidden_states_B_C_p = causal_conv1d_fn(
|
hidden_states_B_C_p = causal_conv1d_fn(
|
||||||
hidden_states_B_C_p.transpose(0, 1),
|
x,
|
||||||
conv_weights,
|
conv_weights,
|
||||||
self.conv1d.bias,
|
self.conv1d.bias,
|
||||||
activation=self.activation,
|
activation=self.activation,
|
||||||
@ -590,8 +599,6 @@ class MambaMixer2(CustomOp):
|
|||||||
query_start_loc=query_start_loc_p).transpose(
|
query_start_loc=query_start_loc_p).transpose(
|
||||||
0, 1)[:num_prefill_tokens]
|
0, 1)[:num_prefill_tokens]
|
||||||
|
|
||||||
# TODO: Why is this needed?
|
|
||||||
hidden_states_B_C_p = hidden_states_B_C_p.contiguous()
|
|
||||||
hidden_states_p, B_p, C_p = split_hidden_states_B_C_fn(
|
hidden_states_p, B_p, C_p = split_hidden_states_B_C_fn(
|
||||||
hidden_states_B_C_p)
|
hidden_states_B_C_p)
|
||||||
|
|
||||||
@ -715,9 +722,10 @@ class MambaMixer2(CustomOp):
|
|||||||
# - heads and n_groups are TP-ed
|
# - heads and n_groups are TP-ed
|
||||||
conv_dim = (self.intermediate_size +
|
conv_dim = (self.intermediate_size +
|
||||||
2 * n_groups * self.ssm_state_size)
|
2 * n_groups * self.ssm_state_size)
|
||||||
|
# contiguous along 'dim' axis
|
||||||
conv_state_shape = (
|
conv_state_shape = (
|
||||||
divide(conv_dim, world_size),
|
|
||||||
self.conv_kernel_size - 1,
|
self.conv_kernel_size - 1,
|
||||||
|
divide(conv_dim, world_size),
|
||||||
)
|
)
|
||||||
|
|
||||||
# These are not TP-ed as they depend on A, dt_bias, D
|
# These are not TP-ed as they depend on A, dt_bias, D
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@ -36,10 +36,12 @@ class MambaCacheManager(ConstantSizeCache):
|
|||||||
# Initialize parent class
|
# Initialize parent class
|
||||||
super().__init__(max_batch_size)
|
super().__init__(max_batch_size)
|
||||||
|
|
||||||
|
# assume conv_state = (dim, state_len)
|
||||||
|
assert conv_state_shape[0] > conv_state_shape[1]
|
||||||
conv_state = torch.empty(size=(num_mamba_layers, max_batch_size) +
|
conv_state = torch.empty(size=(num_mamba_layers, max_batch_size) +
|
||||||
conv_state_shape,
|
(conv_state_shape[1], conv_state_shape[0]),
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device="cuda")
|
device="cuda").transpose(-1, -2)
|
||||||
temporal_state = torch.empty(size=(num_mamba_layers, max_batch_size) +
|
temporal_state = torch.empty(size=(num_mamba_layers, max_batch_size) +
|
||||||
temporal_state_shape,
|
temporal_state_shape,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
|
|||||||
@ -1,14 +1,13 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import math
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.attention.backends.abstract import AttentionBackend
|
from vllm.attention.backends.abstract import AttentionBackend
|
||||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||||
from vllm.model_executor.layers.mamba.mamba2_metadata import (
|
|
||||||
_query_start_loc_to_chunk_indices_offsets)
|
|
||||||
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
|
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
|
||||||
CommonAttentionMetadata)
|
CommonAttentionMetadata)
|
||||||
from vllm.v1.kv_cache_interface import MambaSpec
|
from vllm.v1.kv_cache_interface import MambaSpec
|
||||||
@ -29,6 +28,42 @@ def get_mamba2_chunk_size(vllm_config: VllmConfig) -> int:
|
|||||||
return chunk_sizes.pop()
|
return chunk_sizes.pop()
|
||||||
|
|
||||||
|
|
||||||
|
def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor,
|
||||||
|
chunk_size: int,
|
||||||
|
total_seqlens: int):
|
||||||
|
|
||||||
|
cu_seqlens = query_start_loc[1:] # remove prepended 0
|
||||||
|
|
||||||
|
# outputs will have length expansion of chunks that do not divide
|
||||||
|
# chunk_size
|
||||||
|
N = math.ceil(total_seqlens / chunk_size) + (cu_seqlens[:-1] % chunk_size
|
||||||
|
> 0).sum()
|
||||||
|
chunk_indices = torch.arange(N,
|
||||||
|
dtype=torch.int,
|
||||||
|
device=query_start_loc.device)
|
||||||
|
chunk_offsets = torch.zeros((N, ),
|
||||||
|
dtype=torch.int,
|
||||||
|
device=query_start_loc.device)
|
||||||
|
|
||||||
|
p = 0 # num of insertions
|
||||||
|
for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]):
|
||||||
|
|
||||||
|
# if does not divide chunk_size, then there is one chunk insertion
|
||||||
|
p += (s % chunk_size > 0)
|
||||||
|
|
||||||
|
# get the dimensions
|
||||||
|
# - the + 1 for _e is to shift the boundary by one chunk
|
||||||
|
# - this shifting is not needed if chunk_size divides e
|
||||||
|
_s, _e = s // chunk_size + p, e // chunk_size + p + (e % chunk_size
|
||||||
|
> 0)
|
||||||
|
|
||||||
|
# adjust indices and offsets
|
||||||
|
chunk_indices[_s:_e] -= p
|
||||||
|
chunk_offsets[_s] = s % chunk_size
|
||||||
|
|
||||||
|
return chunk_indices, chunk_offsets
|
||||||
|
|
||||||
|
|
||||||
class Mamba2AttentionBackend(AttentionBackend):
|
class Mamba2AttentionBackend(AttentionBackend):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -53,6 +88,10 @@ class Mamba2AttentionMetadata:
|
|||||||
chunk_offsets: torch.Tensor
|
chunk_offsets: torch.Tensor
|
||||||
|
|
||||||
state_indices_tensor: torch.Tensor # shape: [batch,]
|
state_indices_tensor: torch.Tensor # shape: [batch,]
|
||||||
|
nums_dict: Optional[dict] = None
|
||||||
|
cu_seqlen: Optional[int] = None
|
||||||
|
batch_ptr: Optional[torch.tensor] = None
|
||||||
|
token_chunk_offset_ptr: Optional[torch.tensor] = None
|
||||||
|
|
||||||
|
|
||||||
class Mamba2AttentionMetadataBuilder(
|
class Mamba2AttentionMetadataBuilder(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user