#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

namespace cacheflow {

template<typename T>
__device__ __forceinline__ T silu(const T& x) {
  // x * sigmoid(x)
  return (T) (((float) x) / (1.0f + expf((float) -x)));
}

template<typename scalar_t>
__global__ void silu_and_mul_kernel(
  scalar_t* __restrict__ out,             // [num_tokens, d]
  const scalar_t* __restrict__ input,     // [num_tokens, 2, d]
  const int d) {
  // One thread block per token; threads stride over the hidden dimension d.
  const int token_idx = blockIdx.x;
  for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
    const scalar_t x = __ldg(&input[token_idx * 2 * d + idx]);
    const scalar_t y = __ldg(&input[token_idx * 2 * d + d + idx]);
    out[token_idx * d + idx] = silu(x) * y;
  }
}

} // namespace cacheflow

void silu_and_mul(
  torch::Tensor& out,    // [num_tokens, d]
  torch::Tensor& input)  // [num_tokens, 2 * d]
{
  int num_tokens = input.size(0);
  int d = input.size(1) / 2;

  dim3 grid(num_tokens);
  dim3 block(std::min(d, 1024));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    input.scalar_type(),
    "silu_and_mul_kernel",
    [&] {
      cacheflow::silu_and_mul_kernel<scalar_t><<<grid, block, 0, stream>>>(
        out.data_ptr<scalar_t>(),
        input.data_ptr<scalar_t>(),
        d);
    });
}
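
// ---------------------------------------------------------------------------
// Illustrative binding sketch. This registration is an assumption added for
// the sake of a self-contained example; the project's actual Python bindings
// are defined elsewhere. With it, the file could be built as a standalone
// PyTorch extension via torch.utils.cpp_extension and called from Python as
// shown in the comment below the module definition.
// ---------------------------------------------------------------------------
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def(
    "silu_and_mul",
    &silu_and_mul,
    "Writes out[i, j] = silu(input[i, j]) * input[i, d + j] into `out`");
}

// Example Python usage (hypothetical module name `activation_ops`):
//   out = torch.empty(num_tokens, d, dtype=torch.half, device='cuda')
//   x = torch.randn(num_tokens, 2 * d, dtype=torch.half, device='cuda')
//   activation_ops.silu_and_mul(out, x)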