diff --git a/csrc/attention/attention_kernels.cuh b/csrc/attention/attention_kernels.cuh
index 8f24be89578b8..57382c1ddc65b 100644
--- a/csrc/attention/attention_kernels.cuh
+++ b/csrc/attention/attention_kernels.cuh
@@ -24,7 +24,7 @@
 
 #include "attention_dtypes.h"
 #include "attention_utils.cuh"
-#include "cuda_compat.h"
+#include "../cuda_compat.h"
 
 #ifdef USE_ROCM
   #include <hip/hip_bf16.h>
diff --git a/csrc/attention/paged_attention_v1.cu b/csrc/attention/paged_attention_v1.cu
index 7a5ef10f8ef3b..307300e556660 100644
--- a/csrc/attention/paged_attention_v1.cu
+++ b/csrc/attention/paged_attention_v1.cu
@@ -16,9 +16,8 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-
 #include "attention_kernels.cuh"
-#include "cuda_compat.h"
+#include "../cuda_compat.h"
 
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
@@ -75,7 +74,7 @@ void paged_attention_v1_launcher(
   const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
   const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
 
-  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
+  const int NUM_WARPS = NUM_THREADS / WARP_SIZE;
   int padded_max_seq_len =
       DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
   int logits_size = padded_max_seq_len * sizeof(float);
diff --git a/csrc/attention/paged_attention_v2.cu b/csrc/attention/paged_attention_v2.cu
index b45b28dad05ea..eb9b4feb4a892 100644
--- a/csrc/attention/paged_attention_v2.cu
+++ b/csrc/attention/paged_attention_v2.cu
@@ -16,9 +16,8 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-
 #include "attention_kernels.cuh"
-#include "cuda_compat.h"
+#include "../cuda_compat.h"
 
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
@@ -79,7 +78,7 @@ void paged_attention_v2_launcher(
   const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
   const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
 
-  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
+  const int NUM_WARPS = NUM_THREADS / WARP_SIZE;
   int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
   int logits_size = PARTITION_SIZE * sizeof(float);
   int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
diff --git a/csrc/cuda_compat.h b/csrc/cuda_compat.h
index affa051c75951..d7d589db62cfe 100644
--- a/csrc/cuda_compat.h
+++ b/csrc/cuda_compat.h
@@ -4,8 +4,35 @@
   #include <hip/hip_runtime.h>
 #endif
 
-#if defined(USE_ROCM) && defined(__GFX9__)
-  #define WARP_SIZE 64
+#ifdef USE_ROCM
+struct Utils {
+  static __host__ int get_warp_size() {
+    static bool is_cached = false;
+    static int result;
+
+    if (!is_cached) {
+      int device_id;
+      cudaDeviceProp deviceProp;
+      cudaGetDevice(&device_id);
+      cudaGetDeviceProperties(&deviceProp, device_id);
+
+      result = deviceProp.warpSize;
+      is_cached = true;
+    }
+
+    return result;
+  }
+
+  static __device__ constexpr int get_warp_size() {
+  #ifdef __GFX9__
+    return 64;
+  #else
+    return 32;
+  #endif
+  }
+};
+
+  #define WARP_SIZE Utils::get_warp_size()
 #else
   #define WARP_SIZE 32
 #endif
diff --git a/csrc/moe/topk_softmax_kernels.cu b/csrc/moe/topk_softmax_kernels.cu
index 064b76c9cd427..0b505d2e04a21 100644
--- a/csrc/moe/topk_softmax_kernels.cu
+++ b/csrc/moe/topk_softmax_kernels.cu
@@ -190,8 +190,8 @@ __launch_bounds__(TPB) __global__ void moeTopK(
     2) This implementation assumes k is small, but will work for any k.
 */
 
-template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG, typename IndType>
-__launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
+template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG, int WARP_SIZE_PARAM, typename IndType>
+__launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__
     void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows,
         IndType* indices, int* source_rows, const int k, const int start_expert, const int end_expert)
 {
@@ -209,12 +209,12 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__
     // Restrictions based on previous section.
     static_assert(VPT % ELTS_PER_LDG == 0, "The elements per thread must be a multiple of the elements per ldg");
-    static_assert(WARP_SIZE % THREADS_PER_ROW == 0, "The threads per row must cleanly divide the threads per warp");
+    static_assert(WARP_SIZE_PARAM % THREADS_PER_ROW == 0, "The threads per row must cleanly divide the threads per warp");
     static_assert(THREADS_PER_ROW == (THREADS_PER_ROW & -THREADS_PER_ROW), "THREADS_PER_ROW must be power of 2");
-    static_assert(THREADS_PER_ROW <= WARP_SIZE, "THREADS_PER_ROW can be at most warp size");
+    static_assert(THREADS_PER_ROW <= WARP_SIZE_PARAM, "THREADS_PER_ROW can be at most warp size");
 
     // We have NUM_EXPERTS elements per row. We specialize for small #experts
-    static constexpr int ELTS_PER_WARP = WARP_SIZE * VPT;
+    static constexpr int ELTS_PER_WARP = WARP_SIZE_PARAM * VPT;
     static constexpr int ROWS_PER_WARP = ELTS_PER_WARP / ELTS_PER_ROW;
     static constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP;
 
@@ -393,41 +393,51 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__
 namespace detail
 {
 // Constructs some constants needed to partition the work across threads at compile time.
-template <int EXPERTS, int BYTES_PER_LDG>
+template <int EXPERTS, int BYTES_PER_LDG, int WARP_SIZE_PARAM>
 struct TopkConstants
 {
     static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
-    static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, "");
-    static constexpr int VECs_PER_THREAD = MAX(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE));
+    static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE_PARAM) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE_PARAM) == 0, "");
+    static constexpr int VECs_PER_THREAD = MAX(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE_PARAM));
     static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG;
     static constexpr int THREADS_PER_ROW = EXPERTS / VPT;
-    static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW;
+    static const int ROWS_PER_WARP = WARP_SIZE_PARAM / THREADS_PER_ROW;
 };
 } // namespace detail
 
-template <int EXPERTS, int WARPS_PER_TB, typename IndType>
+template <int EXPERTS, int WARPS_PER_TB, int WARP_SIZE_PARAM, typename IndType>
 void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output,
     IndType* indices, int* source_row, const int num_rows, const int k, const int start_expert,
     const int end_expert, cudaStream_t stream)
 {
     static constexpr std::size_t MAX_BYTES_PER_LDG = 16;
     static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS);
-    using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG>;
+    using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG, WARP_SIZE_PARAM>;
     static constexpr int VPT = Constants::VPT;
     static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP;
     const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP;
     const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB;
 
-    dim3 block_dim(WARP_SIZE, WARPS_PER_TB);
-    topkGatingSoftmax<VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG><<<num_blocks, block_dim, 0, stream>>>(
+    dim3 block_dim(WARP_SIZE_PARAM, WARPS_PER_TB);
+    topkGatingSoftmax<VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG, WARP_SIZE_PARAM><<<num_blocks, block_dim, 0, stream>>>(
         input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert);
 }
 
-#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB)                       \
-    topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB>(         \
-        gating_output, nullptr, topk_weights, topk_indices,             \
-        token_expert_indices, num_tokens, topk, 0, num_experts,         \
-        stream);
+#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB)                              \
+  switch (warpSize) {                                                          \
+    case 32:                                                                   \
+      topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 32>(          \
+          gating_output, nullptr, topk_weights, topk_indices,                  \
+          token_expert_indices, num_tokens, topk, 0, num_experts, stream);     \
+      break;                                                                   \
+    case 64:                                                                   \
+      topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 64>(          \
+          gating_output, nullptr, topk_weights, topk_indices,                  \
+          token_expert_indices, num_tokens, topk, 0, num_experts, stream);     \
+      break;                                                                   \
+    default:                                                                   \
+      TORCH_CHECK(false, "Unsupported warp size: ", warpSize);                 \
+  }
 
 template <typename IndType>
 void topkGatingSoftmaxKernelLauncher(
@@ -441,6 +451,7 @@ void topkGatingSoftmaxKernelLauncher(
     const int topk,
     cudaStream_t stream) {
     static constexpr int WARPS_PER_TB = 4;
+    auto warpSize = WARP_SIZE;
     switch (num_experts) {
         case 1:
             LAUNCH_SOFTMAX(1, WARPS_PER_TB);
diff --git a/csrc/quantization/activation_kernels.cu b/csrc/quantization/activation_kernels.cu
index 67e9149c13795..8bc2b9bff3d5a 100644
--- a/csrc/quantization/activation_kernels.cu
+++ b/csrc/quantization/activation_kernels.cu
@@ -4,7 +4,7 @@
 #include <cmath>
 
 #include "core/math.hpp"
-#include "cuda_compat.h"
+#include "../cuda_compat.h"
 #include "dispatch_utils.h"
 #include "quantization/fp8/common.cuh"
 
diff --git a/csrc/quantization/gguf/gguf_kernel.cu b/csrc/quantization/gguf/gguf_kernel.cu
index 3b5180b516239..76fe73e950404 100644
--- a/csrc/quantization/gguf/gguf_kernel.cu
+++ b/csrc/quantization/gguf/gguf_kernel.cu
@@ -4,7 +4,7 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/all.h>
 #include <c10/cuda/CUDAGuard.h>
 
-#include "cuda_compat.h"
+#include "../../cuda_compat.h"
 #include "dispatch_utils.h"
 #include "ggml-common.h"
diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu
index 3bddd12cad077..65cb1c1d1478d 100644
--- a/csrc/rocm/attention.cu
+++ b/csrc/rocm/attention.cu
@@ -19,7 +19,7 @@
 #include <torch/all.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <algorithm>
-#include "cuda_compat.h"
+#include "../cuda_compat.h"
 
 #include <hip/hip_bf16.h>
 #include "../attention/dtype_fp8.cuh"
diff --git a/csrc/rocm/skinny_gemms.cu b/csrc/rocm/skinny_gemms.cu
index 6212570c79d1f..eb47139208c91 100644
--- a/csrc/rocm/skinny_gemms.cu
+++ b/csrc/rocm/skinny_gemms.cu
@@ -9,7 +9,7 @@
 #include <hip/hip_bf16.h>
 #include <hip/hip_fp16.h>
 
-#include "cuda_compat.h"
+#include "../cuda_compat.h"
 #include "dispatch_utils.h"
 #include "quantization/fp8/common.cuh"