// clang-format off
// adapted from https://github.com/state-spaces/mamba/blob/main/csrc/selective_scan/selective_scan_fwd_kernel.cuh
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "selective_scan.h"

#include <c10/util/BFloat16.h>
#include <c10/util/Half.h>
#ifdef USE_ROCM
    #include <c10/hip/HIPException.h>  // For C10_HIP_CHECK and C10_HIP_KERNEL_LAUNCH_CHECK
#else
    #include <c10/cuda/CUDAException.h>  // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
#endif

#ifndef USE_ROCM
    #include <cub/block/block_load.cuh>
    #include <cub/block/block_store.cuh>
    #include <cub/block/block_scan.cuh>
#else
    #include <hipcub/hipcub.hpp>
    namespace cub = hipcub;
#endif

#include "selective_scan.h"
#include "static_switch.h"
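
// Editorial summary (not part of the upstream kernel): per (batch, dim) channel the
// kernel below evaluates the selective-scan (Mamba SSM) recurrence, chunk by chunk:
//   x_t = exp(delta_t * A) * x_{t-1} + delta_t * B_t * u_t
//   y_t = C_t . x_t + D * u_t        (optionally gated as y_t * silu(z_t))
// delta_t may first go through a softplus, and the final state x_T is written back
// to ssm_states so a later call can resume from it (has_initial_state).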

template<int kNThreads_, int kNItems_, int kNRows_, bool kIsEvenLen_,
         bool kIsVariableB_, bool kIsVariableC_,
         bool kHasZ_, bool kVarlen_, typename input_t_, typename weight_t_, typename state_t_>
struct Selective_Scan_fwd_kernel_traits {
    static_assert(kNItems_ % 4 == 0);
    using input_t = input_t_;
    using weight_t = weight_t_;
    using state_t = state_t_;
    static constexpr int kNThreads = kNThreads_;
    // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves occupancy.
    static constexpr int kMinBlocks = kNThreads < 128 ? 5 : 3;
    static constexpr int kNItems = kNItems_;
    static constexpr int kNRows = kNRows_;
    static constexpr int kNBytes = sizeof(input_t);
    static_assert(kNBytes == 2 || kNBytes == 4);
    static constexpr int kNElts = kNBytes == 4 ? 4 : constexpr_min(8, kNItems);
    static_assert(kNItems % kNElts == 0);
    static constexpr int kNLoads = kNItems / kNElts;
    static constexpr bool kIsEvenLen = kVarlen_ ? false : kIsEvenLen_;
    static constexpr bool kIsVariableB = kIsVariableB_;
    static constexpr bool kIsVariableC = kIsVariableC_;
    static constexpr bool kHasZ = kHasZ_;
    static constexpr bool kVarlen = kVarlen_;

    static constexpr bool kDirectIO = kVarlen_ ? false : kIsEvenLen && kNLoads == 1;
    static constexpr int kNLoadsIndex = kNItems / 4;
    using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
    using scan_t = float2;
    using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
    using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads,
        !kDirectIO ? cub::BLOCK_LOAD_WARP_TRANSPOSE : cub::BLOCK_LOAD_DIRECT>;
    using BlockLoadWeightT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
    using BlockLoadWeightVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads,
        !kDirectIO ? cub::BLOCK_LOAD_WARP_TRANSPOSE : cub::BLOCK_LOAD_DIRECT>;
    using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
    using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, kNLoads,
        !kDirectIO ? cub::BLOCK_STORE_WARP_TRANSPOSE : cub::BLOCK_STORE_DIRECT>;
    // using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING_MEMOIZE>;
    // using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING>;
    using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_WARP_SCANS>;
    static constexpr int kSmemIOSize = custom_max({sizeof(typename BlockLoadT::TempStorage),
                                                   sizeof(typename BlockLoadVecT::TempStorage),
                                                   (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage),
                                                   (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage),
                                                   sizeof(typename BlockStoreT::TempStorage),
                                                   sizeof(typename BlockStoreVecT::TempStorage)});
    static constexpr int kSmemSize = kSmemIOSize + sizeof(typename BlockScanT::TempStorage);
};
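
// Editorial note: at launch time the dynamic shared memory request is
//   Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(scan_t)
// i.e. the (aliased) cub load/store temp storage plus the BlockScan temp storage
// computed above, followed by the per-row running-prefix array that carries the
// scan state across chunks (smem_running_prefix in the kernel below).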

template<typename Ktraits>
__global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks)
void selective_scan_fwd_kernel(SSMParamsBase params) {
    constexpr bool kIsVariableB = Ktraits::kIsVariableB;
    constexpr bool kIsVariableC = Ktraits::kIsVariableC;
    constexpr bool kHasZ = Ktraits::kHasZ;
    constexpr bool kVarlen = Ktraits::kVarlen;
    constexpr int kNThreads = Ktraits::kNThreads;
    constexpr int kNItems = Ktraits::kNItems;
    constexpr int kNRows = Ktraits::kNRows;
    constexpr bool kDirectIO = Ktraits::kDirectIO;
    using input_t = typename Ktraits::input_t;
    using weight_t = typename Ktraits::weight_t;
    using scan_t = typename Ktraits::scan_t;

    // Shared memory.
    extern __shared__ char smem_[];
    // cast to lvalue reference of expected type
    // char *smem_loadstorescan = smem_ + 2 * MAX_DSTATE * sizeof(weight_t);
    // auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_ + 2 * MAX_DSTATE * sizeof(weight_t));
    // auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_loadstorescan);
    auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
    auto& smem_load_weight = reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage&>(smem_);
    auto& smem_load_weight1 = *reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage*>(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage));
    auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
    auto& smem_scan = *reinterpret_cast<typename Ktraits::BlockScanT::TempStorage*>(smem_ + Ktraits::kSmemIOSize);
    // weight_t *smem_a = reinterpret_cast<weight_t *>(smem_ + smem_loadstorescan_size);
    // weight_t *smem_bc = reinterpret_cast<weight_t *>(smem_a + MAX_DSTATE);
    scan_t *smem_running_prefix = reinterpret_cast<scan_t *>(smem_ + Ktraits::kSmemSize);

    const int batch_id = blockIdx.x;
    const int dim_id = blockIdx.y;
    const int group_id = dim_id / (params.dim_ngroups_ratio);
    int seqlen = params.seqlen;
    int sequence_start_index = batch_id;
    if constexpr (kVarlen) {
        int *query_start_loc = reinterpret_cast<int *>(params.query_start_loc_ptr);
        sequence_start_index = query_start_loc[batch_id];
        seqlen = query_start_loc[batch_id + 1] - sequence_start_index;
    }
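    // Editorial note: in the varlen path query_start_loc holds cumulative token
    // offsets (e.g. [0, 5, 12] puts sequence 0 at tokens [0, 5) and sequence 1 at
    // [5, 12)), so each block derives its own start index and length above.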
    const bool has_initial_state = params.has_initial_state_ptr == nullptr ? false
        : reinterpret_cast<bool *>(params.has_initial_state_ptr)[batch_id];

    const int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr
        : reinterpret_cast<int *>(params.cache_indices_ptr);
    const int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id];
    // cache_index == params.pad_slot_id is defined as padding, so we exit early
    if (cache_index == params.pad_slot_id) {
        return;
    }
    input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + sequence_start_index * params.u_batch_stride
        + dim_id * kNRows * params.u_d_stride;
    input_t *delta = reinterpret_cast<input_t *>(params.delta_ptr) + sequence_start_index * params.delta_batch_stride
        + dim_id * kNRows * params.delta_d_stride;
    weight_t *A = reinterpret_cast<weight_t *>(params.A_ptr) + dim_id * kNRows * params.A_d_stride;
    weight_t *B = reinterpret_cast<weight_t *>(params.B_ptr) + dim_id * kNRows * params.B_d_stride;
    input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + sequence_start_index * params.B_batch_stride + group_id * params.B_group_stride;
    weight_t *C = reinterpret_cast<weight_t *>(params.C_ptr) + dim_id * kNRows * params.C_d_stride;
    input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + sequence_start_index * params.C_batch_stride + group_id * params.C_group_stride;
    typename Ktraits::state_t *ssm_states = reinterpret_cast<typename Ktraits::state_t *>(params.ssm_states_ptr) +
        cache_index * params.ssm_states_batch_stride +
        dim_id * kNRows * params.ssm_states_dim_stride;

    float D_val[kNRows] = {0};
    if (params.D_ptr != nullptr) {
        #pragma unroll
        for (int r = 0; r < kNRows; ++r) {
            D_val[r] = reinterpret_cast<float *>(params.D_ptr)[dim_id * kNRows + r];
        }
    }
    float delta_bias[kNRows] = {0};
    if (params.delta_bias_ptr != nullptr) {
        #pragma unroll
        for (int r = 0; r < kNRows; ++r) {
            delta_bias[r] = reinterpret_cast<float *>(params.delta_bias_ptr)[dim_id * kNRows + r];
        }
    }

    // for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) {
    //     smem_a[state_idx] = A[state_idx * params.A_dstate_stride];
    //     smem_bc[state_idx] = B[state_idx * params.B_dstate_stride] * C[state_idx * params.C_dstate_stride];
    // }
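
    // Editorial note: the sequence is processed in chunks of kChunkSize =
    // kNThreads * kNItems tokens. Each thread owns kNItems consecutive tokens of a
    // chunk, and the running (A-decay, state) prefix is kept in smem_running_prefix
    // so the next chunk continues where the previous one stopped. The divisor 2048
    // is the largest chunk size produced by the launch heuristics further below.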
    constexpr int kChunkSize = kNThreads * kNItems;
    const int n_chunks = (seqlen + 2048 - 1) / 2048;
    for (int chunk = 0; chunk < n_chunks; ++chunk) {
        input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems];

        __syncthreads();
        #pragma unroll
        for (int r = 0; r < kNRows; ++r) {
            if constexpr (!kDirectIO) {
                if (r > 0) { __syncthreads(); }
            }
            load_input<Ktraits>(u + r * params.u_d_stride, u_vals[r], smem_load, seqlen - chunk * kChunkSize);
            if constexpr (!kDirectIO) { __syncthreads(); }
            load_input<Ktraits>(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, seqlen - chunk * kChunkSize);
        }
        u += kChunkSize;
        delta += kChunkSize;

        float delta_vals[kNRows][kNItems], delta_u_vals[kNRows][kNItems], out_vals[kNRows][kNItems];
        #pragma unroll
        for (int r = 0; r < kNRows; ++r) {
            #pragma unroll
            for (int i = 0; i < kNItems; ++i) {
                float u_val = float(u_vals[r][i]);
                delta_vals[r][i] = float(delta_vals_load[r][i]) + delta_bias[r];
                if (params.delta_softplus) {
                    delta_vals[r][i] = delta_vals[r][i] <= 20.f ? log1pf(expf(delta_vals[r][i])) : delta_vals[r][i];
                }
                delta_u_vals[r][i] = delta_vals[r][i] * u_val;
                out_vals[r][i] = D_val[r] * u_val;
            }
        }

        __syncthreads();
        for (int state_idx = 0; state_idx < params.dstate; ++state_idx) {
            weight_t A_val[kNRows];
            #pragma unroll
            for (int r = 0; r < kNRows; ++r) {
                A_val[r] = A[state_idx * params.A_dstate_stride + r * params.A_d_stride];
                // Multiply the real part of A with LOG2E so we can use exp2f instead of expf.
                constexpr float kLog2e = M_LOG2E;
                A_val[r] *= kLog2e;
            }
            // This variable holds B * C if both B and C are constant across seqlen. If only B varies
            // across seqlen, this holds C. If only C varies across seqlen, this holds B.
            // If both B and C vary, this is unused.
            weight_t BC_val[kNRows];
            weight_t B_vals[kNItems], C_vals[kNItems];
            if constexpr (kIsVariableB) {
                load_weight<Ktraits>(Bvar + state_idx * params.B_dstate_stride, B_vals,
                    smem_load_weight, (seqlen - chunk * kChunkSize));
                if constexpr (!kIsVariableC) {
                    #pragma unroll
                    for (int r = 0; r < kNRows; ++r) {
                        BC_val[r] = C[state_idx * params.C_dstate_stride + r * params.C_d_stride];
                    }
                }
            }
            if constexpr (kIsVariableC) {
                auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1;
                load_weight<Ktraits>(Cvar + state_idx * params.C_dstate_stride, C_vals,
                    smem_load_weight_C, (seqlen - chunk * kChunkSize));
                if constexpr (!kIsVariableB) {
                    #pragma unroll
                    for (int r = 0; r < kNRows; ++r) {
                        BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride];
                    }
                }
            }
            if constexpr (!kIsVariableB && !kIsVariableC) {
                #pragma unroll
                for (int r = 0; r < kNRows; ++r) {
                    BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride] * C[state_idx * params.C_dstate_stride + r * params.C_d_stride];
                }
            }
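
            // Editorial note: each element is folded into the pair
            //   thread_data[i] = (exp(delta * A), delta * B * u)
            // and the block-wide InclusiveScan combines pairs with SSMScanOp
            // (defined in selective_scan.h), the usual composition for a first-order
            // linear recurrence: (a0, b0) then (a1, b1) gives (a1 * a0, a1 * b0 + b1).
            // The .y component of the scanned pair is therefore the SSM state x_t.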
            #pragma unroll
            for (int r = 0; r < kNRows; ++r) {
                if (r > 0) { __syncthreads(); }  // Scan could be using the same smem
                scan_t thread_data[kNItems];
                #pragma unroll
                for (int i = 0; i < kNItems; ++i) {
                    thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]),
                        !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]);

                    if (seqlen % (kNItems * kNThreads) != 0) {  // So that the last state is correct
                        if (threadIdx.x * kNItems + i >= seqlen - chunk * kChunkSize) {
                            thread_data[i] = make_float2(1.f, 0.f);
                        }
                    }
                }
                // Initialize running total

                scan_t running_prefix = chunk > 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE]
                    : make_float2(1.0, has_initial_state ? float(ssm_states[state_idx * params.ssm_states_dstate_stride]) : 0.0);

                SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
                typename Ktraits::BlockScanT(smem_scan).InclusiveScan(
                    thread_data, thread_data, SSMScanOp<weight_t>(), prefix_op
                );
                // There's a syncthreads in the scan op, so we don't need to sync here.
                // Unless there's only 1 warp, but then it's the same thread (0) reading and writing.
                if (threadIdx.x == 0) {
                    smem_running_prefix[state_idx] = prefix_op.running_prefix;
                    if (chunk == n_chunks - 1) {
                        ssm_states[state_idx * params.ssm_states_dstate_stride] = typename Ktraits::state_t(prefix_op.running_prefix.y);
                    }
                }
                #pragma unroll
                for (int i = 0; i < kNItems; ++i) {
                    const weight_t C_val = !kIsVariableC
                        ? BC_val[r]
                        : (!kIsVariableB ? BC_val[r] * C_vals[i] : C_vals[i]);
                    out_vals[r][i] += thread_data[i].y * C_val;
                }
            }
        }

        input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + sequence_start_index * params.out_batch_stride
            + dim_id * kNRows * params.out_d_stride + chunk * kChunkSize;
        __syncthreads();
        #pragma unroll
        for (int r = 0; r < kNRows; ++r) {
            if constexpr (!kDirectIO) {
                if (r > 0) { __syncthreads(); }
            }
            store_output<Ktraits>(out + r * params.out_d_stride, out_vals[r], smem_store, seqlen - chunk * kChunkSize);
        }

        if constexpr (kHasZ) {
            input_t *z = reinterpret_cast<input_t *>(params.z_ptr) + sequence_start_index * params.z_batch_stride
                + dim_id * kNRows * params.z_d_stride + chunk * kChunkSize;
            input_t *out_z = reinterpret_cast<input_t *>(params.out_z_ptr) + sequence_start_index * params.out_z_batch_stride
                + dim_id * kNRows * params.out_z_d_stride + chunk * kChunkSize;
            #pragma unroll
            for (int r = 0; r < kNRows; ++r) {
                input_t z_vals[kNItems];
                __syncthreads();
                load_input<Ktraits>(z + r * params.z_d_stride, z_vals, smem_load, seqlen - chunk * kChunkSize);
                #pragma unroll
                for (int i = 0; i < kNItems; ++i) {
                    float z_val = z_vals[i];
                    out_vals[r][i] *= z_val / (1 + expf(-z_val));
                }
                __syncthreads();
                store_output<Ktraits>(out_z + r * params.out_z_d_stride, out_vals[r], smem_store, seqlen - chunk * kChunkSize);
            }
        }

        Bvar += kChunkSize;
        Cvar += kChunkSize;
    }
}

template<int kNThreads, int kNItems, typename input_t, typename weight_t, typename state_t>
void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {
    // Only kNRows == 1 is tested for now, which is equivalent to the earlier behaviour
    // where each block processed a single row.
    constexpr int kNRows = 1;
    // kIsVariableB, kIsVariableC and kHasZ are all set to True to reduce binary size
    constexpr bool kIsVariableB = true;
    constexpr bool kIsVariableC = true;
    BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
        BOOL_SWITCH(params.z_ptr != nullptr, kHasZ, [&] {
            BOOL_SWITCH(params.query_start_loc_ptr != nullptr, kVarlen, [&] {
                using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, kVarlen, input_t, weight_t, state_t>;
                constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
                dim3 grid(params.batch, params.dim / kNRows);
                auto kernel = &selective_scan_fwd_kernel<Ktraits>;
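                // Editorial note: requesting more than 48 KB of dynamic shared memory
                // requires an explicit per-kernel opt-in (cudaFuncSetAttribute /
                // hipFuncSetAttribute below); without it the launch would fail.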
                if (kSmemSize >= 48 * 1024) {
                    #ifdef USE_ROCM
                        C10_HIP_CHECK(hipFuncSetAttribute(
                            reinterpret_cast<const void*>(kernel), hipFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
                    #else
                        C10_CUDA_CHECK(cudaFuncSetAttribute(
                            kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
                    #endif
                }
                kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
                C10_CUDA_KERNEL_LAUNCH_CHECK();
            });
        });
    });
}

template<typename input_t, typename weight_t, typename state_t>
void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream) {
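    // Editorial note: the thresholds below pick the smallest (threads, items-per-thread)
    // configuration whose chunk size (kNThreads * kNItems) still covers the sequence in
    // a single chunk where possible; longer sequences fall through to 128 x 16 = 2048
    // and are processed in multiple chunks.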
    #ifndef USE_ROCM
        if (params.seqlen <= 128) {
            selective_scan_fwd_launch<32, 4, input_t, weight_t, state_t>(params, stream);
        } else if (params.seqlen <= 256) {
            selective_scan_fwd_launch<32, 8, input_t, weight_t, state_t>(params, stream);
        } else if (params.seqlen <= 512) {
            selective_scan_fwd_launch<32, 16, input_t, weight_t, state_t>(params, stream);
        } else if (params.seqlen <= 1024) {
            selective_scan_fwd_launch<64, 16, input_t, weight_t, state_t>(params, stream);
        } else {
            selective_scan_fwd_launch<128, 16, input_t, weight_t, state_t>(params, stream);
        }
    #else
        if (params.seqlen <= 256) {
            selective_scan_fwd_launch<64, 4, input_t, weight_t, state_t>(params, stream);
        } else if (params.seqlen <= 512) {
            selective_scan_fwd_launch<64, 8, input_t, weight_t, state_t>(params, stream);
        } else if (params.seqlen <= 1024) {
            selective_scan_fwd_launch<64, 16, input_t, weight_t, state_t>(params, stream);
        } else {
            selective_scan_fwd_launch<128, 16, input_t, weight_t, state_t>(params, stream);
        }
    #endif
}

template void selective_scan_fwd_cuda<at::BFloat16, float, at::BFloat16>(SSMParamsBase &params, cudaStream_t stream);
template void selective_scan_fwd_cuda<at::BFloat16, float, float>(SSMParamsBase &params, cudaStream_t stream);
template void selective_scan_fwd_cuda<at::Half, float, at::Half>(SSMParamsBase &params, cudaStream_t stream);
template void selective_scan_fwd_cuda<at::Half, float, float>(SSMParamsBase &params, cudaStream_t stream);
template void selective_scan_fwd_cuda<float, float, float>(SSMParamsBase &params, cudaStream_t stream);

#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")

#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, STYPE, NAME, ...)               \
    if (ITYPE == at::ScalarType::Half) {                                                    \
        using input_t = at::Half;                                                           \
        using weight_t = float;                                                             \
        if (STYPE == at::ScalarType::Half) {                                                \
            using state_t = at::Half;                                                       \
            __VA_ARGS__();                                                                  \
        } else if (STYPE == at::ScalarType::Float) {                                        \
            using state_t = float;                                                          \
            __VA_ARGS__();                                                                  \
        } else {                                                                            \
            AT_ERROR(#NAME, " not implemented for state type '", toString(STYPE), "'");     \
        }                                                                                   \
    } else if (ITYPE == at::ScalarType::BFloat16) {                                         \
        using input_t = at::BFloat16;                                                       \
        using weight_t = float;                                                             \
        if (STYPE == at::ScalarType::BFloat16) {                                            \
            using state_t = at::BFloat16;                                                   \
            __VA_ARGS__();                                                                  \
        } else if (STYPE == at::ScalarType::Float) {                                        \
            using state_t = float;                                                          \
            __VA_ARGS__();                                                                  \
        } else {                                                                            \
            AT_ERROR(#NAME, " not implemented for state type '", toString(STYPE), "'");     \
        }                                                                                   \
    } else if (ITYPE == at::ScalarType::Float) {                                            \
        using input_t = float;                                                              \
        using weight_t = float;                                                             \
        using state_t = float;                                                              \
        __VA_ARGS__();                                                                      \
    } else {                                                                                \
        AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'");         \
    }

template<typename input_t, typename weight_t, typename state_t>
void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream);
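
// Editorial note: strides are recorded here in elements, not bytes. In the varlen
// path u/delta/z/out are laid out as (dim, total_tokens) and B/C as
// (n_groups, dstate, total_tokens); in the padded path they are
// (batch, dim, seqlen) and (batch, n_groups, dstate, seqlen), which is why the two
// branches below read the strides from different tensor dimensions.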
void set_ssm_params_fwd(SSMParamsBase &params,
                        // sizes
                        const size_t batch,
                        const size_t dim,
                        const size_t seqlen,
                        const size_t dstate,
                        const size_t n_groups,
                        const bool is_variable_B,
                        const bool is_variable_C,
                        // device pointers
                        const torch::Tensor u,
                        const torch::Tensor delta,
                        const torch::Tensor A,
                        const torch::Tensor B,
                        const torch::Tensor C,
                        const torch::Tensor out,
                        const torch::Tensor z,
                        const torch::Tensor out_z,
                        const std::optional<at::Tensor>& D,
                        const std::optional<at::Tensor>& delta_bias,
                        const torch::Tensor ssm_states,
                        bool has_z,
                        bool delta_softplus,
                        const std::optional<at::Tensor>& query_start_loc,
                        const std::optional<at::Tensor>& cache_indices,
                        const std::optional<at::Tensor>& has_initial_state,
                        bool varlen,
                        int64_t pad_slot_id) {

    // Reset the parameters
    memset(&params, 0, sizeof(params));

    params.batch = batch;
    params.dim = dim;
    params.seqlen = seqlen;
    params.dstate = dstate;
    params.n_groups = n_groups;
    params.dim_ngroups_ratio = dim / n_groups;
    params.pad_slot_id = pad_slot_id;

    params.delta_softplus = delta_softplus;

    params.is_variable_B = is_variable_B;
    params.is_variable_C = is_variable_C;

    // Set the pointers and strides.
    params.u_ptr = u.data_ptr();
    params.delta_ptr = delta.data_ptr();
    params.A_ptr = A.data_ptr();
    params.B_ptr = B.data_ptr();
    params.C_ptr = C.data_ptr();
    params.D_ptr = D.has_value() ? D.value().data_ptr() : nullptr;
    params.delta_bias_ptr = delta_bias.has_value() ? delta_bias.value().data_ptr() : nullptr;
    params.out_ptr = out.data_ptr();
    params.ssm_states_ptr = ssm_states.data_ptr();
    params.z_ptr = has_z ? z.data_ptr() : nullptr;
    params.out_z_ptr = has_z ? out_z.data_ptr() : nullptr;
    params.query_start_loc_ptr = query_start_loc.has_value() ? query_start_loc.value().data_ptr() : nullptr;
    params.cache_indices_ptr = cache_indices.has_value() ? cache_indices.value().data_ptr() : nullptr;
    params.has_initial_state_ptr = has_initial_state.has_value() ? has_initial_state.value().data_ptr() : nullptr;

    // All strides are in elements, not bytes.
    params.A_d_stride = A.stride(0);
    params.A_dstate_stride = A.stride(1);

    if (varlen) {
        params.B_batch_stride = B.stride(2);
        params.B_group_stride = B.stride(0);
        params.B_dstate_stride = B.stride(1);
        params.C_batch_stride = C.stride(2);
        params.C_group_stride = C.stride(0);
        params.C_dstate_stride = C.stride(1);

        params.u_batch_stride = u.stride(1);
        params.u_d_stride = u.stride(0);
        params.delta_batch_stride = delta.stride(1);
        params.delta_d_stride = delta.stride(0);
        if (has_z) {
            params.z_batch_stride = z.stride(1);
            params.z_d_stride = z.stride(0);
            params.out_z_batch_stride = out_z.stride(1);
            params.out_z_d_stride = out_z.stride(0);
        }
        params.out_batch_stride = out.stride(1);
        params.out_d_stride = out.stride(0);

        params.ssm_states_batch_stride = ssm_states.stride(0);
        params.ssm_states_dim_stride = ssm_states.stride(1);
        params.ssm_states_dstate_stride = ssm_states.stride(2);
    } else {
        if (!is_variable_B) {
            params.B_d_stride = B.stride(0);
        } else {
            params.B_batch_stride = B.stride(0);
            params.B_group_stride = B.stride(1);
        }
        params.B_dstate_stride = !is_variable_B ? B.stride(1) : B.stride(2);
        if (!is_variable_C) {
            params.C_d_stride = C.stride(0);
        } else {
            params.C_batch_stride = C.stride(0);
            params.C_group_stride = C.stride(1);
        }
        params.C_dstate_stride = !is_variable_C ? C.stride(1) : C.stride(2);
        params.u_batch_stride = u.stride(0);
        params.u_d_stride = u.stride(1);
        params.delta_batch_stride = delta.stride(0);
        params.delta_d_stride = delta.stride(1);
        if (has_z) {
            params.z_batch_stride = z.stride(0);
            params.z_d_stride = z.stride(1);
            params.out_z_batch_stride = out_z.stride(0);
            params.out_z_d_stride = out_z.stride(1);
        }
        params.out_batch_stride = out.stride(0);
        params.out_d_stride = out.stride(1);

        params.ssm_states_batch_stride = ssm_states.stride(0);
        params.ssm_states_dim_stride = ssm_states.stride(1);
        params.ssm_states_dstate_stride = ssm_states.stride(2);
    }
}
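
// Editorial summary of the expected shapes (mirrors the CHECK_SHAPE calls below):
//   u, delta, z:  (batch, dim, seqlen)              or (dim, total_tokens) when
//                                                      query_start_loc is given (varlen)
//   A:            (dim, dstate)
//   B, C:         (batch, n_groups, dstate, seqlen) or (n_groups, dstate, total_tokens)
//   ssm_states:   3-D (cache slot, dim, dstate), indexed via cache_indices
// The output is written in-place into delta's buffer (see `out = delta` below).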

void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
                        const torch::Tensor &A, const torch::Tensor &B, const torch::Tensor &C,
                        const std::optional<torch::Tensor> &D_,
                        const std::optional<torch::Tensor> &z_,
                        const std::optional<torch::Tensor> &delta_bias_,
                        bool delta_softplus,
                        const std::optional<torch::Tensor> &query_start_loc,
                        const std::optional<torch::Tensor> &cache_indices,
                        const std::optional<torch::Tensor> &has_initial_state,
                        const torch::Tensor &ssm_states,
                        // used to identify padding entries if cache_indices provided
                        // in case of padding, the kernel will return early
                        int64_t pad_slot_id) {
    auto input_type = u.scalar_type();
    auto weight_type = A.scalar_type();
    TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
    TORCH_CHECK(weight_type == at::ScalarType::Float);

    const bool is_variable_B = B.dim() >= 3;
    const bool is_variable_C = C.dim() >= 3;

    TORCH_CHECK(delta.scalar_type() == input_type);
    TORCH_CHECK(B.scalar_type() == (!is_variable_B ? weight_type : input_type));
    TORCH_CHECK(C.scalar_type() == (!is_variable_C ? weight_type : input_type));

    TORCH_CHECK(u.is_cuda());
    TORCH_CHECK(delta.is_cuda());
    TORCH_CHECK(A.is_cuda());
    TORCH_CHECK(B.is_cuda());
    TORCH_CHECK(C.is_cuda());

    TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1);
    TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1);

    const auto sizes = u.sizes();
    const bool varlen = query_start_loc.has_value();
    const int batch_size = varlen ? query_start_loc.value().sizes()[0] - 1 : sizes[0];
    const int dim = varlen ? sizes[0] : sizes[1];
    const int seqlen = varlen ? sizes[1] : sizes[2];
    const int dstate = A.size(1);
    const int n_groups = varlen ? B.size(0) : B.size(1);

    TORCH_CHECK(dstate <= 256, "selective_scan only supports state dimension <= 256");

    if (varlen) {
        CHECK_SHAPE(u, dim, seqlen);
        CHECK_SHAPE(delta, dim, seqlen);
    } else {
        CHECK_SHAPE(u, batch_size, dim, seqlen);
        CHECK_SHAPE(delta, batch_size, dim, seqlen);
    }
    CHECK_SHAPE(A, dim, dstate);
    TORCH_CHECK(is_variable_B, "is_variable_B = False is disabled in favor of reduced binary size");
    if (varlen) {
        CHECK_SHAPE(B, n_groups, dstate, seqlen);
    } else {
        CHECK_SHAPE(B, batch_size, n_groups, dstate, seqlen);
    }
    TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1);

    TORCH_CHECK(is_variable_C, "is_variable_C = False is disabled in favor of reduced binary size");
    if (varlen) {
        CHECK_SHAPE(C, n_groups, dstate, seqlen);
    } else {
        CHECK_SHAPE(C, batch_size, n_groups, dstate, seqlen);
    }
    TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1);

    if (D_.has_value()) {
        auto D = D_.value();
        TORCH_CHECK(D.scalar_type() == at::ScalarType::Float);
        TORCH_CHECK(D.is_cuda());
        TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1);
        CHECK_SHAPE(D, dim);
    }

    if (delta_bias_.has_value()) {
        auto delta_bias = delta_bias_.value();
        TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float);
        TORCH_CHECK(delta_bias.is_cuda());
        TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1);
        CHECK_SHAPE(delta_bias, dim);
    }

    if (has_initial_state.has_value()) {
        auto has_initial_state_ = has_initial_state.value();
        TORCH_CHECK(has_initial_state_.scalar_type() == at::ScalarType::Bool);
        TORCH_CHECK(has_initial_state_.is_cuda());
        CHECK_SHAPE(has_initial_state_, batch_size);
    }

    if (query_start_loc.has_value()) {
        auto query_start_loc_ = query_start_loc.value();
        TORCH_CHECK(query_start_loc_.scalar_type() == at::ScalarType::Int);
        TORCH_CHECK(query_start_loc_.is_cuda());
    }

    if (cache_indices.has_value()) {
        auto cache_indices_ = cache_indices.value();
        TORCH_CHECK(cache_indices_.scalar_type() == at::ScalarType::Int);
        TORCH_CHECK(cache_indices_.is_cuda());
        CHECK_SHAPE(cache_indices_, batch_size);
    }

    at::Tensor z, out_z;
    const bool has_z = z_.has_value();
    if (has_z) {
        z = z_.value();
        TORCH_CHECK(z.scalar_type() == input_type);
        TORCH_CHECK(z.is_cuda());
        TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1);
        if (varlen) {
            CHECK_SHAPE(z, dim, seqlen);
        } else {
            CHECK_SHAPE(z, batch_size, dim, seqlen);
        }

        out_z = z;
    }

    // Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout
    at::Tensor out = delta;
    // ssm_states can now be either the same type as the input or float32
    auto state_type = ssm_states.scalar_type();
    TORCH_CHECK(state_type == input_type || state_type == at::ScalarType::Float);
    TORCH_CHECK(ssm_states.is_cuda());
    TORCH_CHECK(ssm_states.stride(-1) == 1);

    SSMParamsBase params;
    set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, is_variable_B, is_variable_C,
                       u, delta, A, B, C, out, z, out_z,
                       D_,
                       delta_bias_,
                       ssm_states,
                       has_z,
                       delta_softplus,
                       query_start_loc,
                       cache_indices,
                       has_initial_state,
                       varlen,
                       pad_slot_id);
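
    // Editorial note: OptionalCUDAGuard switches to u's device before the launch, and
    // the kernel is enqueued on the current PyTorch CUDA stream, so it is ordered with
    // the surrounding framework work on that stream.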
    const at::cuda::OptionalCUDAGuard device_guard(device_of(u));
    auto stream = at::cuda::getCurrentCUDAStream().stream();
    DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), ssm_states.scalar_type(), "selective_scan_fwd", [&] {
        selective_scan_fwd_cuda<input_t, weight_t, state_t>(params, stream);
    });
}