#pragma once

namespace cacheflow {

// Sums `sum` across every thread of the block and returns the total in every thread.
//
// Template parameters:
//   WARPS_PER_BLOCK - number of warps whose partial sums are combined. It must be a
//                     power of two no larger than WARP_SIZE, because the second stage
//                     is an xor-shuffle tree run inside a single warp.
//   WARP_SIZE       - lanes per warp (default 32).
// Parameters:
//   red_smem - scratch buffer of at least WARPS_PER_BLOCK floats. Each warp leader
//              writes its partial sum to red_smem[warp]. Presumably a __shared__ array
//              owned by the caller (TODO confirm at call sites).
//   sum      - this thread's contribution.
// Returns the block-wide sum. Every warp rereads all partials, so every thread gets it.
//
// NOTE(review): the shuffles use a full 32-lane mask. This assumes every warp is fully
// populated, i.e. blockDim.x == WARPS_PER_BLOCK * WARP_SIZE. Confirm with callers.
// Contains __syncthreads(), so all threads of the block must call it. There is no
// barrier before returning: a caller reusing red_smem must __syncthreads() first.
template<int WARPS_PER_BLOCK, int WARP_SIZE = 32>
inline __device__ float block_sum(float* red_smem, float sum) {
  static_assert(WARPS_PER_BLOCK > 0 && (WARPS_PER_BLOCK & (WARPS_PER_BLOCK - 1)) == 0,
                "block_sum: WARPS_PER_BLOCK must be a power of two");
  static_assert(WARPS_PER_BLOCK <= WARP_SIZE,
                "block_sum: the per-warp partial sums must fit in one warp");

  // Decompose the thread index into warp / lane.
  int warp = threadIdx.x / WARP_SIZE;
  int lane = threadIdx.x % WARP_SIZE;

  // Butterfly reduction inside each warp: afterwards every lane holds its warp's sum.
#pragma unroll
  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
    sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
  }

  // Warp leaders publish their partial sums.
  if (lane == 0) {
    red_smem[warp] = sum;
  }

  // Make the partial sums visible to every warp.
  __syncthreads();

  // Lanes [0, WARPS_PER_BLOCK) load one partial each. Higher lanes keep their own
  // warp sum; that value never reaches lane 0, because the xor masks below are all
  // < WARPS_PER_BLOCK and so only mix lanes within the same WARPS_PER_BLOCK-sized group.
  if (lane < WARPS_PER_BLOCK) {
    sum = red_smem[lane];
  }

  // Reduce the partials; lane 0 ends up with the block total.
#pragma unroll
  for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) {
    sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
  }

  // Broadcast lane 0's total to the rest of the warp.
  return __shfl_sync(uint32_t(-1), sum, 0);
}

#define FINAL_MASK 0xffffffff

// Xor-shuffle sum over the 32 lanes of a warp; every lane receives the warp total.
// Uses FINAL_MASK (all 32 lanes), so all lanes are expected to be present and to call
// this together.
template<typename T>
__inline__ __device__ T warpReduceSum(T val) {
#pragma unroll
  for (int mask = 16; mask > 0; mask >>= 1)
    val += __shfl_xor_sync(FINAL_MASK, val, mask, 32);
  return val;
}

// Calculate the sum of `val` over all threads of the block.
//
// The result is valid only in the first warp (threadIdx.x < 32). Threads of the other
// warps feed zeros into the final warp reduction and therefore return 0.
// Uses a function-scope __shared__ array of 32 slots, so at most 32 warps
// (1024 threads) are supported.
// Contains __syncthreads(), so all threads of the block must call it. There is no
// trailing barrier: back-to-back calls need a __syncthreads() in between, otherwise a
// fast warp can overwrite `shared` before warp 0 has read it.
template<typename T>
__inline__ __device__ T blockReduceSum(T val) {
  static __shared__ T shared[32];
  int lane = threadIdx.x & 0x1f;
  int wid = threadIdx.x >> 5;

  val = warpReduceSum<T>(val);

  // One partial sum per warp.
  if (lane == 0)
    shared[wid] = val;

  __syncthreads();

  // Number of warps, rounded up so that a trailing partial warp is counted when
  // blockDim.x is not a multiple of 32. This is an integer ceil-div, equivalent to the
  // previous `threadIdx.x < blockDim.x / 32.f` float comparison.
  const unsigned num_warps = (blockDim.x + 31u) >> 5;
  val = (threadIdx.x < num_warps) ? shared[lane] : (T)(0.0f);
  val = warpReduceSum<T>(val);
  return val;
}

} // namespace cacheflow