[Kernel] Enable fused_qknorm_rope_kernel to support partial RoPE (#30821)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Parent 7e065eba59, commit 097978a15d
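Context for the hunks below: the kernel previously assumed that RoPE covers the full head (rotary_dim == head_dim), so the cos/sin cache row was indexed with head_dim and every lane of the warp rotated its elements. This change threads an explicit rotary_dim through the kernel, the launcher, and the Torch op so that only the first rotary_dim dimensions of each head are rotated and the remaining head_dim - rotary_dim dimensions pass through untouched (partial RoPE). As a mental model, here is a minimal CPU-side sketch of the same rotation for a single head, assuming the kernel's cache row layout (rotary_dim/2 cos values followed by rotary_dim/2 sin values); it is illustrative only, not the vLLM implementation.

// Illustrative CPU reference for partial RoPE on one head. Assumes the same
// cache row layout as the kernel: rotary_dim/2 cos values followed by
// rotary_dim/2 sin values for this token's position.
#include <vector>

void partial_rope_reference(std::vector<float>& head, int head_dim,
                            int rotary_dim, float const* cos_sin_row,
                            bool interleave) {
  (void)head_dim;  // only dims below rotary_dim are touched (partial RoPE)
  int const embed_dim = rotary_dim / 2;
  float const* cos_ptr = cos_sin_row;
  float const* sin_ptr = cos_sin_row + embed_dim;
  for (int d = 0; d < embed_dim; ++d) {
    float const c = cos_ptr[d];
    float const s = sin_ptr[d];
    // interleave == true pairs adjacent dims (GPT-J style, i.e. !is_neox);
    // otherwise dims d and d + rotary_dim/2 form a pair (GPT-NeoX style).
    int const i0 = interleave ? 2 * d : d;
    int const i1 = interleave ? 2 * d + 1 : d + embed_dim;
    float const x0 = head[i0];
    float const x1 = head[i1];
    head[i0] = x0 * c - x1 * s;
    head[i1] = x0 * s + x1 * c;
  }
}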
@@ -107,7 +107,8 @@ __global__ void fusedQKNormRopeKernel(
     void const* k_weight_void,       // RMSNorm weights for key
     void const* cos_sin_cache_void,  // Pre-computed cos/sin cache
     int64_t const* position_ids,     // Position IDs for RoPE
-    int const num_tokens             // Number of tokens
+    int const num_tokens,            // Number of tokens
+    int const rotary_dim             // Dimension for RoPE
 ) {
 #if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM)
   if constexpr ((std::is_same_v<scalar_t_in, c10::BFloat16>) ||
@@ -227,56 +228,59 @@ __global__ void fusedQKNormRopeKernel(

   // Calculate cache pointer for this position - similar to
   // pos_encoding_kernels.cu
-  T_cache const* cache_ptr = cos_sin_cache + pos_id * head_dim;
-  int const embed_dim = head_dim / 2;
+  T_cache const* cache_ptr = cos_sin_cache + pos_id * rotary_dim;
+  int const embed_dim = rotary_dim / 2;
   T_cache const* cos_ptr = cache_ptr;
   T_cache const* sin_ptr = cache_ptr + embed_dim;

-  if constexpr (interleave) {
-    // Perform interleaving. Use pre-computed cos/sin values.
+  int const rotary_lanes = rotary_dim / numElemsPerThread;  // rotary range
+  if (laneId < rotary_lanes) {
+    if constexpr (interleave) {
+      // Perform interleaving. Use pre-computed cos/sin values.
 #pragma unroll
-    for (int i = 0; i < numElemsPerThread / 2; ++i) {
-      int const idx0 = 2 * i;
-      int const idx1 = 2 * i + 1;
-      // Global dimension index in the head
-      int const dim_idx = laneId * numElemsPerThread + idx0;
-
-      float const val0 = elements[idx0];
-      float const val1 = elements[idx1];
-
-      int const half_dim = dim_idx / 2;
-      float const cos_val =
-          CacheConverter::convert(VLLM_LDG(cos_ptr + half_dim));
-      float const sin_val =
-          CacheConverter::convert(VLLM_LDG(sin_ptr + half_dim));
-
-      elements[idx0] = val0 * cos_val - val1 * sin_val;
-      elements[idx1] = val0 * sin_val + val1 * cos_val;
-    }
-  } else {
-    // Before data exchange with in warp, we need to sync.
-    __syncwarp();
-    // Get the data from the other half of the warp. Use pre-computed cos/sin
-    // values.
-#pragma unroll
-    for (int i = 0; i < numElemsPerThread; i++) {
-      elements2[i] = __shfl_xor_sync(FINAL_MASK, elements[i], 16);
-      if (laneId < 16) {
-        elements2[i] = -elements2[i];
-      }
-      int dim_idx = laneId * numElemsPerThread + i;
-      dim_idx = (dim_idx * 2) % head_dim;
-      int half_dim = dim_idx / 2;
-      // Use pre-computed cos/sin from cache
-      float cos_val = CacheConverter::convert(VLLM_LDG(cos_ptr + half_dim));
-      float sin_val = CacheConverter::convert(VLLM_LDG(sin_ptr + half_dim));
-
-      elements[i] = elements[i] * cos_val + elements2[i] * sin_val;
-    }
-    // __shfl_xor_sync does not provide memfence. Need to sync again.
-    __syncwarp();
-  }
+      for (int i = 0; i < numElemsPerThread / 2; ++i) {
+        int const idx0 = 2 * i;
+        int const idx1 = 2 * i + 1;
+
+        float const val0 = elements[idx0];
+        float const val1 = elements[idx1];
+
+        // Global dimension index in the head
+        int const dim_idx = laneId * numElemsPerThread + idx0;
+        int const half_dim = dim_idx / 2;
+        float const cos_val =
+            CacheConverter::convert(VLLM_LDG(cos_ptr + half_dim));
+        float const sin_val =
+            CacheConverter::convert(VLLM_LDG(sin_ptr + half_dim));
+
+        elements[idx0] = val0 * cos_val - val1 * sin_val;
+        elements[idx1] = val0 * sin_val + val1 * cos_val;
+      }
+    } else {
+      // Before data exchange with in warp, we need to sync.
+      __syncwarp();
+      int pairOffset = (rotary_dim / 2) / numElemsPerThread;
+      // Get the data from the other half of the warp. Use pre-computed
+      // cos/sin values.
+#pragma unroll
+      for (int i = 0; i < numElemsPerThread; i++) {
+        elements2[i] = __shfl_xor_sync(FINAL_MASK, elements[i], pairOffset);
+
+        if (laneId < pairOffset) {
+          elements2[i] = -elements2[i];
+        }
+        int dim_idx = laneId * numElemsPerThread + i;
+
+        dim_idx = (dim_idx * 2) % rotary_dim;
+        int half_dim = dim_idx / 2;
+        float cos_val = CacheConverter::convert(VLLM_LDG(cos_ptr + half_dim));
+        float sin_val = CacheConverter::convert(VLLM_LDG(sin_ptr + half_dim));
+
+        elements[i] = elements[i] * cos_val + elements2[i] * sin_val;
+      }
+      // __shfl_xor_sync does not provide memfence. Need to sync again.
+      __syncwarp();
+    }
+  }

   // Store.
   {
     vec_T vec;
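The non-interleaved (NeoX-style) branch above is the subtle part of the change. Each lane holds numElemsPerThread contiguous dimensions, so the partner element of a rotation pair sits pairOffset = (rotary_dim / 2) / numElemsPerThread lanes away and is fetched with __shfl_xor_sync; lanes in the lower half negate the fetched value so that the single expression elements[i] * cos_val + elements2[i] * sin_val yields both halves of the rotation. The old hard-coded offset of 16 only worked when the rotation spanned the whole warp, i.e. rotary_dim == head_dim. The toy host-side simulation below (hypothetical sizes and names, not vLLM code) replays the lane pairing and checks it against the direct NeoX formula.

// Host-side simulation of the shuffle-based NeoX rotation; toy sizes only.
#include <cassert>
#include <cmath>
#include <vector>

int main() {
  int const E = 2;                              // toy numElemsPerThread
  int const rotary_dim = 64;                    // partial-RoPE width
  int const rotary_lanes = rotary_dim / E;      // lanes that rotate at all
  int const pairOffset = (rotary_dim / 2) / E;  // XOR distance between pair halves

  std::vector<float> x(rotary_dim), out(rotary_dim);
  std::vector<float> cosv(rotary_dim / 2), sinv(rotary_dim / 2);
  for (int i = 0; i < rotary_dim; ++i) x[i] = 0.01f * static_cast<float>(i);
  for (int i = 0; i < rotary_dim / 2; ++i) {
    cosv[i] = std::cos(0.1f * static_cast<float>(i));
    sinv[i] = std::sin(0.1f * static_cast<float>(i));
  }
  out = x;

  for (int lane = 0; lane < rotary_lanes; ++lane) {
    for (int i = 0; i < E; ++i) {
      int const idx = lane * E + i;                     // element owned by this lane
      int const partner = (lane ^ pairOffset) * E + i;  // what __shfl_xor_sync would fetch
      float other = x[partner];
      if (lane < pairOffset) other = -other;            // lower half negates the partner
      int const half_dim = (idx * 2) % rotary_dim / 2;  // cos/sin index, as in the kernel
      out[idx] = x[idx] * cosv[half_dim] + other * sinv[half_dim];
    }
  }

  // Cross-check against the direct NeoX formula on pairs (d, d + rotary_dim/2).
  for (int d = 0; d < rotary_dim / 2; ++d) {
    float const lo = x[d] * cosv[d] - x[d + rotary_dim / 2] * sinv[d];
    float const hi = x[d + rotary_dim / 2] * cosv[d] + x[d] * sinv[d];
    assert(std::abs(out[d] - lo) < 1e-5f);
    assert(std::abs(out[d + rotary_dim / 2] - hi) < 1e-5f);
  }
  return 0;
}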
@@ -312,10 +316,10 @@ template <typename scalar_t_in, typename scalar_t_cache>
 void launchFusedQKNormRope(void* qkv, int const num_tokens,
                            int const num_heads_q, int const num_heads_k,
                            int const num_heads_v, int const head_dim,
-                           float const eps, void const* q_weight,
-                           void const* k_weight, void const* cos_sin_cache,
-                           bool const interleave, int64_t const* position_ids,
-                           cudaStream_t stream) {
+                           int const rotary_dim, float const eps,
+                           void const* q_weight, void const* k_weight,
+                           void const* cos_sin_cache, bool const interleave,
+                           int64_t const* position_ids, cudaStream_t stream) {
   constexpr int blockSize = 256;

   int const warpsPerBlock = blockSize / 32;
@@ -332,7 +336,7 @@ void launchFusedQKNormRope(void* qkv, int const num_tokens,
         fusedQKNormRopeKernel<scalar_t_in, scalar_t_cache, 64, INTERLEAVE>
             <<<gridDim, blockDim, 0, stream>>>(
                 qkv, num_heads_q, num_heads_k, num_heads_v, eps, q_weight,
-                k_weight, cos_sin_cache, position_ids, num_tokens);
+                k_weight, cos_sin_cache, position_ids, num_tokens, rotary_dim);
       });
       break;
     case 128:
@@ -340,7 +344,7 @@ void launchFusedQKNormRope(void* qkv, int const num_tokens,
         fusedQKNormRopeKernel<scalar_t_in, scalar_t_cache, 128, INTERLEAVE>
             <<<gridDim, blockDim, 0, stream>>>(
                 qkv, num_heads_q, num_heads_k, num_heads_v, eps, q_weight,
-                k_weight, cos_sin_cache, position_ids, num_tokens);
+                k_weight, cos_sin_cache, position_ids, num_tokens, rotary_dim);
       });
       break;
     case 256:
@@ -348,7 +352,7 @@ void launchFusedQKNormRope(void* qkv, int const num_tokens,
         fusedQKNormRopeKernel<scalar_t_in, scalar_t_cache, 256, INTERLEAVE>
             <<<gridDim, blockDim, 0, stream>>>(
                 qkv, num_heads_q, num_heads_k, num_heads_v, eps, q_weight,
-                k_weight, cos_sin_cache, position_ids, num_tokens);
+                k_weight, cos_sin_cache, position_ids, num_tokens, rotary_dim);
       });
       break;
     default:
@@ -392,12 +396,16 @@ void fused_qk_norm_rope(
               "Query weights size must match head dimension");
   TORCH_CHECK(k_weight.size(0) == head_dim,
               "Key weights size must match head dimension");
-  TORCH_CHECK(cos_sin_cache.size(1) == head_dim,
-              "Cos/sin cache dimension must match head_dim");
+
+  TORCH_CHECK(cos_sin_cache.size(1) % 2 == 0, "rotary_dim must be even");
+  TORCH_CHECK(cos_sin_cache.size(1) <= head_dim,
+              "rotary_dim must be less than or equal to head_dim");
+
   TORCH_CHECK(qkv.scalar_type() == q_weight.scalar_type() &&
                   qkv.scalar_type() == k_weight.scalar_type(),
               "qkv, q_weight and k_weight must have the same dtype");

+  int64_t rotary_dim = cos_sin_cache.size(1);
   int64_t num_tokens = qkv.size(0);
   TORCH_CHECK(position_ids.size(0) == num_tokens,
               "Number of tokens in position_ids must match QKV");
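Read together, these checks define the op's shape contract: q_weight and k_weight are [head_dim], position_ids is [num_tokens], and the partial-RoPE width is never passed in explicitly; it is inferred as rotary_dim = cos_sin_cache.size(1), which must be even and at most head_dim. A hypothetical caller-side helper that mirrors this inference (names are mine, not part of the PR):

// Hypothetical helper mirroring the checks above: the partial-RoPE width is
// read from the cache's last dimension rather than passed as an argument.
#include <cstdint>
#include <stdexcept>

int64_t infer_rotary_dim(int64_t cache_last_dim, int64_t head_dim) {
  if (cache_last_dim % 2 != 0)
    throw std::invalid_argument("rotary_dim must be even");
  if (cache_last_dim > head_dim)
    throw std::invalid_argument("rotary_dim must be less than or equal to head_dim");
  return cache_last_dim;  // rotary_dim == cos_sin_cache.size(1)
}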
@@ -419,7 +427,8 @@ void fused_qk_norm_rope(
         qkv.data_ptr(), static_cast<int>(num_tokens),
         static_cast<int>(num_heads_q), static_cast<int>(num_heads_k),
         static_cast<int>(num_heads_v), static_cast<int>(head_dim),
-        static_cast<float>(eps), q_weight.data_ptr(), k_weight.data_ptr(),
+        static_cast<int>(cos_sin_cache.size(1)), static_cast<float>(eps),
+        q_weight.data_ptr(), k_weight.data_ptr(),
         cos_sin_cache.data_ptr(), !is_neox,
         reinterpret_cast<int64_t const*>(position_ids.data_ptr()),
         stream);
@@ -13,6 +13,7 @@ DTYPES = [torch.bfloat16, torch.float16]
 IS_NEOX = [True, False]
 EPS_VALUES = [1e-5, 1e-6]
 SEEDS = [13]
+PARTIAL_ROPE = [True, False]
 CUDA_DEVICES = ["cuda:0"]


@@ -52,6 +53,7 @@ def _apply_qk_norm_rope(
 @pytest.mark.parametrize("is_neox", IS_NEOX)
 @pytest.mark.parametrize("eps", EPS_VALUES)
 @pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("rotary_ratio", [1.0, 0.5, 0.25])
 @torch.inference_mode()
 def test_fused_qk_norm_rope_matches_reference(
     device: str,
@@ -59,6 +61,7 @@ def test_fused_qk_norm_rope_matches_reference(
     is_neox: bool,
     eps: float,
     seed: int,
+    rotary_ratio: float,
 ):
     torch.set_default_device(device)
     current_platform.seed_everything(seed)
@@ -76,10 +79,10 @@ def test_fused_qk_norm_rope_matches_reference(
     k_norm.weight.data.normal_(mean=1.0, std=0.1)
     q_weight = q_norm.weight.data
     k_weight = k_norm.weight.data
-
+    rotary_dim = int(head_dim * rotary_ratio)
     rope = RotaryEmbedding(
         head_size=head_dim,
-        rotary_dim=head_dim,
+        rotary_dim=rotary_dim,
         max_position_embeddings=4096,
         base=10000.0,
         is_neox_style=is_neox,