From 097978a15dc3757e6fdbbae6ad752b97691581bd Mon Sep 17 00:00:00 2001
From: Jee Jee Li
Date: Mon, 22 Dec 2025 10:39:22 +0800
Subject: [PATCH] [Kernel] Enable fused_qknorm_rope_kernel to support partial
 RoPE (#30821)

Signed-off-by: Jee Jee Li
---
 csrc/fused_qknorm_rope_kernel.cu              | 109 ++++++++++--------
 tests/kernels/core/test_fused_qk_norm_rope.py |   7 +-
 2 files changed, 64 insertions(+), 52 deletions(-)

diff --git a/csrc/fused_qknorm_rope_kernel.cu b/csrc/fused_qknorm_rope_kernel.cu
index baff8363162ef..5c23a90794594 100644
--- a/csrc/fused_qknorm_rope_kernel.cu
+++ b/csrc/fused_qknorm_rope_kernel.cu
@@ -107,7 +107,8 @@ __global__ void fusedQKNormRopeKernel(
     void const* k_weight_void,       // RMSNorm weights for key
     void const* cos_sin_cache_void,  // Pre-computed cos/sin cache
     int64_t const* position_ids,     // Position IDs for RoPE
-    int const num_tokens             // Number of tokens
+    int const num_tokens,            // Number of tokens
+    int const rotary_dim             // Dimension for RoPE
 ) {
 #if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM)
   if constexpr ((std::is_same_v<...>) ||
@@ -227,56 +228,59 @@ __global__ void fusedQKNormRopeKernel(
 
   // Calculate cache pointer for this position - similar to
   // pos_encoding_kernels.cu
-  T_cache const* cache_ptr = cos_sin_cache + pos_id * head_dim;
-  int const embed_dim = head_dim / 2;
+  T_cache const* cache_ptr = cos_sin_cache + pos_id * rotary_dim;
+  int const embed_dim = rotary_dim / 2;
   T_cache const* cos_ptr = cache_ptr;
   T_cache const* sin_ptr = cache_ptr + embed_dim;
-
-  if constexpr (interleave) {
-    // Perform interleaving. Use pre-computed cos/sin values.
+  int const rotary_lanes = rotary_dim / numElemsPerThread;  // lanes covering the rotary dims
+  if (laneId < rotary_lanes) {
+    if constexpr (interleave) {
+      // Perform interleaving. Use pre-computed cos/sin values.
 #pragma unroll
-    for (int i = 0; i < numElemsPerThread / 2; ++i) {
-      int const idx0 = 2 * i;
-      int const idx1 = 2 * i + 1;
+      for (int i = 0; i < numElemsPerThread / 2; ++i) {
+        int const idx0 = 2 * i;
+        int const idx1 = 2 * i + 1;
+        // Global dimension index in the head
+        int const dim_idx = laneId * numElemsPerThread + idx0;
 
-      float const val0 = elements[idx0];
-      float const val1 = elements[idx1];
+        float const val0 = elements[idx0];
+        float const val1 = elements[idx1];
 
-      int const dim_idx = laneId * numElemsPerThread + idx0;
-      int const half_dim = dim_idx / 2;
-      float const cos_val =
-          CacheConverter::convert(VLLM_LDG(cos_ptr + half_dim));
-      float const sin_val =
-          CacheConverter::convert(VLLM_LDG(sin_ptr + half_dim));
+        int const half_dim = dim_idx / 2;
+        float const cos_val =
+            CacheConverter::convert(VLLM_LDG(cos_ptr + half_dim));
+        float const sin_val =
+            CacheConverter::convert(VLLM_LDG(sin_ptr + half_dim));
 
-      elements[idx0] = val0 * cos_val - val1 * sin_val;
-      elements[idx1] = val0 * sin_val + val1 * cos_val;
-    }
-  } else {
-    // Before data exchange with in warp, we need to sync.
-    __syncwarp();
-    // Get the data from the other half of the warp. Use pre-computed cos/sin
-    // values.
-#pragma unroll
-    for (int i = 0; i < numElemsPerThread; i++) {
-      elements2[i] = __shfl_xor_sync(FINAL_MASK, elements[i], 16);
-      if (laneId < 16) {
-        elements2[i] = -elements2[i];
+        elements[idx0] = val0 * cos_val - val1 * sin_val;
+        elements[idx1] = val0 * sin_val + val1 * cos_val;
       }
+    } else {
+      // Before exchanging data within the warp, we need to sync.
+      __syncwarp();
+      int pairOffset = (rotary_dim / 2) / numElemsPerThread;
+      // Get the data from the other half of the warp. Use pre-computed
+      // cos/sin values.
+#pragma unroll
+      for (int i = 0; i < numElemsPerThread; i++) {
+        elements2[i] = __shfl_xor_sync(FINAL_MASK, elements[i], pairOffset);
-      int dim_idx = laneId * numElemsPerThread + i;
-      dim_idx = (dim_idx * 2) % head_dim;
-      int half_dim = dim_idx / 2;
-      // Use pre-computed cos/sin from cache
-      float cos_val = CacheConverter::convert(VLLM_LDG(cos_ptr + half_dim));
-      float sin_val = CacheConverter::convert(VLLM_LDG(sin_ptr + half_dim));
+        if (laneId < pairOffset) {
+          elements2[i] = -elements2[i];
+        }
+        int dim_idx = laneId * numElemsPerThread + i;
-      elements[i] = elements[i] * cos_val + elements2[i] * sin_val;
+        dim_idx = (dim_idx * 2) % rotary_dim;
+        int half_dim = dim_idx / 2;
+        float cos_val = CacheConverter::convert(VLLM_LDG(cos_ptr + half_dim));
+        float sin_val = CacheConverter::convert(VLLM_LDG(sin_ptr + half_dim));
+
+        elements[i] = elements[i] * cos_val + elements2[i] * sin_val;
+      }
+      // __shfl_xor_sync does not provide memfence. Need to sync again.
+      __syncwarp();
     }
-    // __shfl_xor_sync does not provide memfence. Need to sync again.
-    __syncwarp();
   }
-
   // Store.
   {
     vec_T vec;
@@ -312,10 +316,10 @@ template <...>
 void launchFusedQKNormRope(void* qkv, int const num_tokens,
                            int const num_heads_q, int const num_heads_k,
                            int const num_heads_v, int const head_dim,
-                           float const eps, void const* q_weight,
-                           void const* k_weight, void const* cos_sin_cache,
-                           bool const interleave, int64_t const* position_ids,
-                           cudaStream_t stream) {
+                           int const rotary_dim, float const eps,
+                           void const* q_weight, void const* k_weight,
+                           void const* cos_sin_cache, bool const interleave,
+                           int64_t const* position_ids, cudaStream_t stream) {
   constexpr int blockSize = 256;
   int const warpsPerBlock = blockSize / 32;
 
@@ -332,7 +336,7 @@ void launchFusedQKNormRope(void* qkv, int const num_tokens,
         fusedQKNormRopeKernel<...>
             <<<...>>>(
                 qkv, num_heads_q, num_heads_k, num_heads_v, eps, q_weight,
-                k_weight, cos_sin_cache, position_ids, num_tokens);
+                k_weight, cos_sin_cache, position_ids, num_tokens, rotary_dim);
       });
       break;
     case 128:
@@ -340,7 +344,7 @@ void launchFusedQKNormRope(void* qkv, int const num_tokens,
         fusedQKNormRopeKernel<...>
            <<<...>>>(
                qkv, num_heads_q, num_heads_k, num_heads_v, eps, q_weight,
-               k_weight, cos_sin_cache, position_ids, num_tokens);
+               k_weight, cos_sin_cache, position_ids, num_tokens, rotary_dim);
       });
       break;
    case 256:
@@ -348,7 +352,7 @@ void launchFusedQKNormRope(void* qkv, int const num_tokens,
         fusedQKNormRopeKernel<...>
            <<<...>>>(
                qkv, num_heads_q, num_heads_k, num_heads_v, eps, q_weight,
-               k_weight, cos_sin_cache, position_ids, num_tokens);
+               k_weight, cos_sin_cache, position_ids, num_tokens, rotary_dim);
       });
       break;
     default:
@@ -392,12 +396,16 @@ void fused_qk_norm_rope(
               "Query weights size must match head dimension");
   TORCH_CHECK(k_weight.size(0) == head_dim,
               "Key weights size must match head dimension");
-  TORCH_CHECK(cos_sin_cache.size(1) == head_dim,
-              "Cos/sin cache dimension must match head_dim");
+
+  TORCH_CHECK(cos_sin_cache.size(1) % 2 == 0, "rotary_dim must be even");
+  TORCH_CHECK(cos_sin_cache.size(1) <= head_dim,
+              "rotary_dim must be less than or equal to head_dim");
+
   TORCH_CHECK(qkv.scalar_type() == q_weight.scalar_type() &&
                   qkv.scalar_type() == k_weight.scalar_type(),
               "qkv, q_weight and k_weight must have the same dtype");
+  int64_t rotary_dim = cos_sin_cache.size(1);
   int64_t num_tokens = qkv.size(0);
   TORCH_CHECK(position_ids.size(0) == num_tokens,
               "Number of tokens in position_ids must match QKV");
@@ -419,7 +427,8 @@ void fused_qk_norm_rope(
       qkv.data_ptr(), static_cast<int>(num_tokens),
      static_cast<int>(num_heads_q), static_cast<int>(num_heads_k),
       static_cast<int>(num_heads_v), static_cast<int>(head_dim),
-      static_cast<float>(eps), q_weight.data_ptr(), k_weight.data_ptr(),
+      static_cast<int>(cos_sin_cache.size(1)), static_cast<float>(eps),
+      q_weight.data_ptr(), k_weight.data_ptr(),
       cos_sin_cache.data_ptr(), !is_neox,
       reinterpret_cast<int64_t const*>(position_ids.data_ptr()), stream);
 
diff --git a/tests/kernels/core/test_fused_qk_norm_rope.py b/tests/kernels/core/test_fused_qk_norm_rope.py
index a23959e353da9..05d61ec02fd29 100644
--- a/tests/kernels/core/test_fused_qk_norm_rope.py
+++ b/tests/kernels/core/test_fused_qk_norm_rope.py
@@ -13,6 +13,7 @@
 DTYPES = [torch.bfloat16, torch.float16]
 IS_NEOX = [True, False]
 EPS_VALUES = [1e-5, 1e-6]
 SEEDS = [13]
+PARTIAL_ROPE = [True, False]
 CUDA_DEVICES = ["cuda:0"]
@@ -52,6 +53,7 @@ def _apply_qk_norm_rope(
 @pytest.mark.parametrize("is_neox", IS_NEOX)
 @pytest.mark.parametrize("eps", EPS_VALUES)
 @pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("rotary_ratio", [1.0, 0.5, 0.25])
 @torch.inference_mode()
 def test_fused_qk_norm_rope_matches_reference(
     device: str,
@@ -59,6 +61,7 @@ def test_fused_qk_norm_rope_matches_reference(
     is_neox: bool,
     eps: float,
     seed: int,
+    rotary_ratio: float,
 ):
     torch.set_default_device(device)
     current_platform.seed_everything(seed)
@@ -76,10 +79,10 @@ def test_fused_qk_norm_rope_matches_reference(
     k_norm.weight.data.normal_(mean=1.0, std=0.1)
     q_weight = q_norm.weight.data
     k_weight = k_norm.weight.data
-
+    rotary_dim = int(head_dim * rotary_ratio)
     rope = RotaryEmbedding(
         head_size=head_dim,
-        rotary_dim=head_dim,
+        rotary_dim=rotary_dim,
         max_position_embeddings=4096,
         base=10000.0,
         is_neox_style=is_neox,
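---

Editor's note on the semantics, since the diff only shows the kernel-side masking: "partial RoPE" means only the first rotary_dim dimensions of each head are rotated, and the remaining head_dim - rotary_dim dimensions pass through unchanged. The sketch below is not part of the patch; it is a minimal PyTorch rendering of that contract, assuming vLLM's cos/sin cache layout (each cache row holds rotary_dim // 2 cos values followed by rotary_dim // 2 sin values, matching cos_ptr/sin_ptr in the kernel). partial_rope_reference is a hypothetical helper name, not a vLLM API.

```python
import torch


def partial_rope_reference(
    x: torch.Tensor,              # [num_tokens, num_heads, head_dim]
    cos_sin_cache: torch.Tensor,  # [max_position, rotary_dim]; cos half then sin half
    positions: torch.Tensor,      # [num_tokens], int64 position IDs
    rotary_dim: int,
    is_neox: bool,
) -> torch.Tensor:
    """Rotate only the first `rotary_dim` dims of each head; pass the rest through."""
    # Gather per-token cos/sin, each [num_tokens, rotary_dim // 2],
    # and add a heads axis for broadcasting.
    cos, sin = cos_sin_cache[positions].chunk(2, dim=-1)
    cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)

    # Split each head into the rotated prefix and the untouched remainder.
    rot, rest = x[..., :rotary_dim], x[..., rotary_dim:]
    if is_neox:
        # NeoX style: dim i is paired with dim i + rotary_dim // 2.
        x1, x2 = rot.chunk(2, dim=-1)
        out = torch.cat((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1)
    else:
        # GPT-J / interleaved style: dim 2i is paired with dim 2i + 1.
        x1, x2 = rot[..., 0::2], rot[..., 1::2]
        out = torch.stack(
            (x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1
        ).flatten(-2)
    return torch.cat((out, rest), dim=-1)
```

This also shows why the kernel's lane masking works: with numElemsPerThread contiguous elements per lane (one warp per head), lanes with laneId < rotary_dim / numElemsPerThread own exactly the first rotary_dim dims, and for the NeoX path the shuffle partner sits (rotary_dim / 2) / numElemsPerThread lanes away instead of the fixed offset of 16 that full-head RoPE used.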