[Kernel][Perf] fuse QK Norm and RoPE into one cuda kernel for Qwen Model (#27165)
Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
parent: a7ef3eb0cd
commit: 68c09efc37
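For context, the unfused sequence this commit fuses is sketched below, mirroring the QKNormRoPETestModel.forward path added in tests/compile/test_qk_norm_rope_fusion.py later in this diff. The sketch is illustrative only (shapes, eps, and RoPE parameters are placeholders), not part of the change itself:

import torch
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding

# Illustrative shapes: 16 query heads, 4 KV heads, head_dim 128, 4 tokens.
num_heads, num_kv_heads, head_dim, num_tokens = 16, 4, 128, 4
q_size, kv_size = num_heads * head_dim, num_kv_heads * head_dim
dtype, device = torch.bfloat16, "cuda"

qkv = torch.randn(num_tokens, q_size + 2 * kv_size, dtype=dtype, device=device)
positions = torch.arange(num_tokens, device=device)
q_norm = RMSNorm(head_dim, eps=1e-6).to(device=device, dtype=dtype)
k_norm = RMSNorm(head_dim, eps=1e-6).to(device=device, dtype=dtype)
rope = RotaryEmbedding(head_dim, rotary_dim=head_dim, max_position_embeddings=4096,
                       base=10000, is_neox_style=True, dtype=dtype).to(device)

# Unfused path: per-head RMSNorm on Q and K, then RoPE on the flattened heads.
q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
q = q_norm(q.view(num_tokens, num_heads, head_dim)).view(q.shape)
k = k_norm(k.view(num_tokens, num_kv_heads, head_dim)).view(k.shape)
q, k = rope(positions, q, k)
# The pass added in this PR rewrites this sequence into a single
# torch.ops._C.fused_qk_norm_rope call that normalizes and rotates Q/K
# inside the packed QKV buffer in place.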
@@ -451,6 +451,7 @@ steps:
  - pytest -v -s compile/test_decorator.py
  - pytest -v -s compile/test_noop_elimination.py
  - pytest -v -s compile/test_aot_compile.py
  - pytest -v -s compile/test_qk_norm_rope_fusion.py

- label: PyTorch Fullgraph Smoke Test # 15min
  timeout_in_minutes: 30

@@ -265,6 +265,7 @@ set(VLLM_EXT_SRC
  "csrc/pos_encoding_kernels.cu"
  "csrc/activation_kernels.cu"
  "csrc/layernorm_kernels.cu"
  "csrc/fused_qknorm_rope_kernel.cu"
  "csrc/layernorm_quant_kernels.cu"
  "csrc/sampler.cu"
  "csrc/cuda_view.cu"

csrc/fused_qknorm_rope_kernel.cu (new file, 418 lines)
@@ -0,0 +1,418 @@
/*
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <cmath>
#include <cuda_runtime.h>
#include <type_traits>

#include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h>

#include "cuda_compat.h"
#include "dispatch_utils.h"
#include "type_convert.cuh"

#define CHECK_TYPE(x, st)                                               \
  TORCH_CHECK(x.scalar_type() == st, #x " dtype is ", x.scalar_type(), \
              ", while ", st, " is expected")
#define CHECK_TH_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
  TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
  CHECK_TH_CUDA(x);    \
  CHECK_CONTIGUOUS(x)

#define FINAL_MASK 0xffffffff

// TODO: add support for the AMD ROCm platform
#ifndef USE_ROCM
namespace tensorrt_llm::common {
template <typename T, int num>
struct packed_as;
// Specialization for packed_as used in this kernel.
template <>
struct packed_as<uint, 1> {
  using type = uint;
};

template <>
struct packed_as<uint, 2> {
  using type = uint2;
};

template <>
struct packed_as<uint, 4> {
  using type = uint4;
};

template <typename T>
__inline__ __device__ T warpReduceSum(T val) {
#pragma unroll
  for (int mask = 16; mask > 0; mask >>= 1)
    val += __shfl_xor_sync(FINAL_MASK, val, mask, 32);
  return val;
}

template <typename T>
inline __device__ __host__ T divUp(T m, T n) {
  return (m + n - 1) / n;
}

}  // namespace tensorrt_llm::common

namespace tensorrt_llm::kernels {
// NOTE(zhuhaoran): This kernel is adapted from the TensorRT-LLM implementation,
// with added support for passing the cos_sin_cache as an input.
// https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/fusedQKNormRopeKernel.cu

// Perform per-head QK Norm and RoPE in a single kernel.
// scalar_t_in: data type of QKV and RMSNorm weights
// scalar_t_cache: data type of cos/sin cache
// head_dim: the dimension of each head
// interleave: interleave = !is_neox.
template <typename scalar_t_in, typename scalar_t_cache, int head_dim,
          bool interleave>
__global__ void fusedQKNormRopeKernel(
    void* qkv_void,                  // Combined QKV tensor
    int const num_heads_q,           // Number of query heads
    int const num_heads_k,           // Number of key heads
    int const num_heads_v,           // Number of value heads
    float const eps,                 // Epsilon for RMS normalization
    void const* q_weight_void,       // RMSNorm weights for query
    void const* k_weight_void,       // RMSNorm weights for key
    void const* cos_sin_cache_void,  // Pre-computed cos/sin cache
    int64_t const* position_ids,     // Position IDs for RoPE
    int const num_tokens             // Number of tokens
) {
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800
  if constexpr ((std::is_same_v<scalar_t_in, c10::BFloat16>) ||
                std::is_same_v<scalar_t_cache, c10::BFloat16>) {
    return;
  } else {
#endif

  using Converter = vllm::_typeConvert<scalar_t_in>;
  static_assert(Converter::exists,
                "Input QKV data type is not supported for this CUDA "
                "architecture or toolkit version.");
  using T_in = typename Converter::hip_type;
  using T2_in = typename Converter::packed_hip_type;

  using CacheConverter = vllm::_typeConvert<scalar_t_cache>;
  static_assert(CacheConverter::exists,
                "Cache data type is not supported for this CUDA architecture "
                "or toolkit version.");
  using T_cache = typename CacheConverter::hip_type;

  T_in* qkv = reinterpret_cast<T_in*>(qkv_void);
  T_in const* q_weight = reinterpret_cast<T_in const*>(q_weight_void);
  T_in const* k_weight = reinterpret_cast<T_in const*>(k_weight_void);
  T_cache const* cos_sin_cache =
      reinterpret_cast<T_cache const*>(cos_sin_cache_void);

  int const warpsPerBlock = blockDim.x / 32;
  int const warpId = threadIdx.x / 32;
  int const laneId = threadIdx.x % 32;

  // Calculate global warp index to determine which head/token this warp
  // processes
  int const globalWarpIdx = blockIdx.x * warpsPerBlock + warpId;

  // Total number of attention heads (Q and K)
  int const total_qk_heads = num_heads_q + num_heads_k;

  // Determine which token and head type (Q or K) this warp processes
  int const tokenIdx = globalWarpIdx / total_qk_heads;
  int const localHeadIdx = globalWarpIdx % total_qk_heads;

  // Skip if this warp is assigned beyond the number of tokens
  if (tokenIdx >= num_tokens) return;

  bool const isQ = localHeadIdx < num_heads_q;
  int const headIdx = isQ ? localHeadIdx : localHeadIdx - num_heads_q;

  int const num_heads = num_heads_q + num_heads_k + num_heads_v;

  static_assert(head_dim % (32 * 2) == 0,
                "head_dim must be divisible by 64 (each warp processes one "
                "head, and each thread gets an even number of elements)");
  constexpr int numElemsPerThread = head_dim / 32;
  float elements[numElemsPerThread];
  constexpr int elemSizeBytes = numElemsPerThread * sizeof(__nv_bfloat16);
  static_assert(elemSizeBytes % 4 == 0,
                "elemSizeBytes must be a multiple of 4");
  constexpr int vecSize =
      elemSizeBytes /
      4;  // Use packed_as<uint, vecSize> to perform loading/saving.
  using vec_T = typename tensorrt_llm::common::packed_as<uint, vecSize>::type;

  int offsetWarp;  // Offset for the warp
  if (isQ) {
    // Q segment: token offset + head offset within Q segment
    offsetWarp = tokenIdx * num_heads * head_dim + headIdx * head_dim;
  } else {
    // K segment: token offset + entire Q segment + head offset within K
    // segment
    offsetWarp = tokenIdx * num_heads * head_dim + num_heads_q * head_dim +
                 headIdx * head_dim;
  }
  int offsetThread = offsetWarp + laneId * numElemsPerThread;

  // Sum of squares for RMSNorm
  float sumOfSquares = 0.0f;

  // Load.
  {
    vec_T vec = *reinterpret_cast<vec_T const*>(&qkv[offsetThread]);
    constexpr int num_packed_elems = elemSizeBytes / sizeof(T2_in);
#pragma unroll
    for (int i = 0; i < num_packed_elems; i++) {
      // Interpret the generic vector chunk as the specific packed type
      T2_in packed_val = *(reinterpret_cast<T2_in*>(&vec) + i);
      // Convert to float2 for computation
      float2 vals = Converter::convert(packed_val);
      sumOfSquares += vals.x * vals.x;
      sumOfSquares += vals.y * vals.y;

      elements[2 * i] = vals.x;
      elements[2 * i + 1] = vals.y;
    }
  }

  // Reduce sum across warp using the utility function
  sumOfSquares = tensorrt_llm::common::warpReduceSum(sumOfSquares);

  // Compute RMS normalization factor
  float rms_rcp = rsqrtf(sumOfSquares / static_cast<float>(head_dim) + eps);

  // Normalize elements
#pragma unroll
  for (int i = 0; i < numElemsPerThread; i++) {
    int dim = laneId * numElemsPerThread + i;
    float weight = isQ ? Converter::convert(q_weight[dim])
                       : Converter::convert(k_weight[dim]);
    elements[i] *= rms_rcp * weight;
  }

  // Apply RoPE to normalized elements
  float elements2[numElemsPerThread];  // Additional buffer required for RoPE.

  int64_t pos_id = position_ids[tokenIdx];

  // Calculate cache pointer for this position - similar to
  // pos_encoding_kernels.cu
  T_cache const* cache_ptr = cos_sin_cache + pos_id * head_dim;
  int const embed_dim = head_dim / 2;
  T_cache const* cos_ptr = cache_ptr;
  T_cache const* sin_ptr = cache_ptr + embed_dim;

  if constexpr (interleave) {
    // Perform interleaving. Use pre-computed cos/sin values.
#pragma unroll
    for (int i = 0; i < numElemsPerThread / 2; ++i) {
      int const idx0 = 2 * i;
      int const idx1 = 2 * i + 1;

      float const val0 = elements[idx0];
      float const val1 = elements[idx1];

      int const dim_idx = laneId * numElemsPerThread + idx0;
      int const half_dim = dim_idx / 2;
      float const cos_val =
          CacheConverter::convert(VLLM_LDG(cos_ptr + half_dim));
      float const sin_val =
          CacheConverter::convert(VLLM_LDG(sin_ptr + half_dim));

      elements[idx0] = val0 * cos_val - val1 * sin_val;
      elements[idx1] = val0 * sin_val + val1 * cos_val;
    }
  } else {
    // Before exchanging data within the warp, we need to sync.
    __syncwarp();
    // Get the data from the other half of the warp. Use pre-computed cos/sin
    // values.
#pragma unroll
    for (int i = 0; i < numElemsPerThread; i++) {
      elements2[i] = __shfl_xor_sync(0xffffffff, elements[i], 16);
      if (laneId < 16) {
        elements2[i] = -elements2[i];
      }

      int dim_idx = laneId * numElemsPerThread + i;
      dim_idx = (dim_idx * 2) % head_dim;
      int half_dim = dim_idx / 2;
      // Use pre-computed cos/sin from the cache
      float cos_val = CacheConverter::convert(VLLM_LDG(cos_ptr + half_dim));
      float sin_val = CacheConverter::convert(VLLM_LDG(sin_ptr + half_dim));

      elements[i] = elements[i] * cos_val + elements2[i] * sin_val;
    }
    // __shfl_xor_sync does not provide a memory fence. Need to sync again.
    __syncwarp();
  }

  // Store.
  {
    vec_T vec;
    constexpr int num_packed_elems = elemSizeBytes / sizeof(T2_in);
#pragma unroll
    for (int i = 0; i < num_packed_elems; i++) {
      // Convert from float2 back to the specific packed type
      T2_in packed_val = Converter::convert(
          make_float2(elements[2 * i], elements[2 * i + 1]));
      // Place it into the generic vector
      *(reinterpret_cast<T2_in*>(&vec) + i) = packed_val;
    }
    *reinterpret_cast<vec_T*>(&qkv[offsetThread]) = vec;
  }

#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800
  }
#endif
}

// Borrowed from
// https://github.com/flashinfer-ai/flashinfer/blob/8125d079a43e9a0ba463a4ed1b639cefd084cec9/include/flashinfer/pos_enc.cuh#L568
#define DISPATCH_INTERLEAVE(interleave, INTERLEAVE, ...) \
  if (interleave) {                                      \
    const bool INTERLEAVE = true;                        \
    __VA_ARGS__                                          \
  } else {                                               \
    const bool INTERLEAVE = false;                       \
    __VA_ARGS__                                          \
  }

template <typename scalar_t_in, typename scalar_t_cache>
void launchFusedQKNormRope(void* qkv, int const num_tokens,
                           int const num_heads_q, int const num_heads_k,
                           int const num_heads_v, int const head_dim,
                           float const eps, void const* q_weight,
                           void const* k_weight, void const* cos_sin_cache,
                           bool const interleave, int64_t const* position_ids,
                           cudaStream_t stream) {
  constexpr int blockSize = 256;

  int const warpsPerBlock = blockSize / 32;
  int const totalQKHeads = num_heads_q + num_heads_k;
  int const totalWarps = num_tokens * totalQKHeads;

  int const gridSize = common::divUp(totalWarps, warpsPerBlock);
  dim3 gridDim(gridSize);
  dim3 blockDim(blockSize);

  switch (head_dim) {
    case 64:
      DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
        fusedQKNormRopeKernel<scalar_t_in, scalar_t_cache, 64, INTERLEAVE>
            <<<gridDim, blockDim, 0, stream>>>(
                qkv, num_heads_q, num_heads_k, num_heads_v, eps, q_weight,
                k_weight, cos_sin_cache, position_ids, num_tokens);
      });
      break;
    case 128:
      DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
        fusedQKNormRopeKernel<scalar_t_in, scalar_t_cache, 128, INTERLEAVE>
            <<<gridDim, blockDim, 0, stream>>>(
                qkv, num_heads_q, num_heads_k, num_heads_v, eps, q_weight,
                k_weight, cos_sin_cache, position_ids, num_tokens);
      });
      break;
    case 256:
      DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
        fusedQKNormRopeKernel<scalar_t_in, scalar_t_cache, 256, INTERLEAVE>
            <<<gridDim, blockDim, 0, stream>>>(
                qkv, num_heads_q, num_heads_k, num_heads_v, eps, q_weight,
                k_weight, cos_sin_cache, position_ids, num_tokens);
      });
      break;
    default:
      TORCH_CHECK(false,
                  "Unsupported head dimension for fusedQKNormRope: ", head_dim);
  }
}
}  // namespace tensorrt_llm::kernels

void fused_qk_norm_rope(
    torch::Tensor& qkv,   // Combined QKV tensor [num_tokens,
                          // (num_heads_q+num_heads_k+num_heads_v)*head_dim]
    int64_t num_heads_q,  // Number of query heads
    int64_t num_heads_k,  // Number of key heads
    int64_t num_heads_v,  // Number of value heads
    int64_t head_dim,     // Dimension per head
    double eps,           // Epsilon for RMS normalization
    torch::Tensor& q_weight,       // RMSNorm weights for query [head_dim]
    torch::Tensor& k_weight,       // RMSNorm weights for key [head_dim]
    torch::Tensor& cos_sin_cache,  // Cos/sin cache [max_position, head_dim]
    bool is_neox,                  // Whether RoPE is applied in Neox style
    torch::Tensor& position_ids    // Position IDs for RoPE [num_tokens]
) {
  // Input validation
  CHECK_INPUT(qkv);
  CHECK_INPUT(position_ids);
  CHECK_INPUT(q_weight);
  CHECK_INPUT(k_weight);
  CHECK_INPUT(cos_sin_cache);
  CHECK_TYPE(position_ids, torch::kInt64);

  TORCH_CHECK(qkv.dim() == 2,
              "QKV tensor must be 2D: [num_tokens, "
              "(num_heads_q+num_heads_k+num_heads_v)*head_dim]");
  TORCH_CHECK(position_ids.dim() == 1, "Position IDs must be 1D: [num_tokens]");
  TORCH_CHECK(q_weight.dim() == 1, "Query weights must be 1D: [head_dim]");
  TORCH_CHECK(k_weight.dim() == 1, "Key weights must be 1D: [head_dim]");
  TORCH_CHECK(cos_sin_cache.dim() == 2,
              "Cos/sin cache must be 2D: [max_position, head_dim]");
  TORCH_CHECK(q_weight.size(0) == head_dim,
              "Query weights size must match head dimension");
  TORCH_CHECK(k_weight.size(0) == head_dim,
              "Key weights size must match head dimension");
  TORCH_CHECK(cos_sin_cache.size(1) == head_dim,
              "Cos/sin cache dimension must match head_dim");
  TORCH_CHECK(qkv.scalar_type() == q_weight.scalar_type() &&
                  qkv.scalar_type() == k_weight.scalar_type(),
              "qkv, q_weight and k_weight must have the same dtype");

  int64_t num_tokens = qkv.size(0);
  TORCH_CHECK(position_ids.size(0) == num_tokens,
              "Number of tokens in position_ids must match QKV");

  int64_t total_heads = num_heads_q + num_heads_k + num_heads_v;
  TORCH_CHECK(
      qkv.size(1) == total_heads * head_dim,
      "QKV tensor size must match total number of heads and head dimension");

  auto stream = at::cuda::getCurrentCUDAStream(qkv.get_device());

  VLLM_DISPATCH_HALF_TYPES(qkv.scalar_type(), "fused_qk_norm_rope_kernel", [&] {
    using qkv_scalar_t = scalar_t;
    VLLM_DISPATCH_FLOATING_TYPES(
        cos_sin_cache.scalar_type(), "fused_qk_norm_rope_kernel", [&] {
          using cache_scalar_t = scalar_t;
          tensorrt_llm::kernels::launchFusedQKNormRope<qkv_scalar_t,
                                                       cache_scalar_t>(
              qkv.data_ptr(), static_cast<int>(num_tokens),
              static_cast<int>(num_heads_q), static_cast<int>(num_heads_k),
              static_cast<int>(num_heads_v), static_cast<int>(head_dim),
              static_cast<float>(eps), q_weight.data_ptr(), k_weight.data_ptr(),
              cos_sin_cache.data_ptr(), !is_neox,
              reinterpret_cast<int64_t const*>(position_ids.data_ptr()),
              stream);
        });
  });
}

#endif  // not USE_ROCM
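For a sense of the launch configuration in launchFusedQKNormRope above: one warp handles one (token, Q-or-K head) pair, and each block holds 256 threads (8 warps). A small worked example, assuming the shapes used by the tests in this PR (numbers are illustrative):

# Worked example of the grid sizing in launchFusedQKNormRope (illustrative numbers).
num_tokens = 4
num_heads_q, num_heads_k = 16, 4          # one warp per (token, Q-or-K head)
head_dim = 128

block_size = 256
warps_per_block = block_size // 32        # 8 warps per block
total_warps = num_tokens * (num_heads_q + num_heads_k)   # 4 * 20 = 80 warps
grid_size = -(-total_warps // warps_per_block)           # divUp -> 10 blocks

elems_per_lane = head_dim // 32           # 4 fp16/bf16 values per lane
bytes_per_lane = elems_per_lane * 2       # 8 bytes -> one 64-bit vectorized load/store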
@@ -92,6 +92,12 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
                        torch::Tensor& weight, double epsilon);

void fused_qk_norm_rope(torch::Tensor& qkv, int64_t num_heads_q,
                        int64_t num_heads_k, int64_t num_heads_v,
                        int64_t head_dim, double eps, torch::Tensor& q_weight,
                        torch::Tensor& k_weight, torch::Tensor& cos_sin_cache,
                        bool is_neox, torch::Tensor& position_ids);

void apply_repetition_penalties_(torch::Tensor& logits,
                                 const torch::Tensor& prompt_mask,
                                 const torch::Tensor& output_mask,

@@ -175,6 +175,16 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
      "float epsilon) -> ()");
  ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm);

#ifndef USE_ROCM
  // Function for fused QK Norm and RoPE
  ops.def(
      "fused_qk_norm_rope(Tensor! qkv, int num_heads_q, "
      "int num_heads_k, int num_heads_v, int head_dim, float eps, "
      "Tensor q_weight, Tensor k_weight, Tensor cos_sin_cache, "
      "bool is_neox, Tensor position_ids) -> ()");
  ops.impl("fused_qk_norm_rope", torch::kCUDA, &fused_qk_norm_rope);
#endif

  // Apply repetition penalties to logits in-place
  ops.def(
      "apply_repetition_penalties_(Tensor! logits, Tensor prompt_mask, "

@@ -29,6 +29,22 @@ struct _typeConvert {
  static constexpr bool exists = false;
};

template <>
struct _typeConvert<float> {
  static constexpr bool exists = true;
  using hip_type = float;
  using packed_hip_type = float2;
  using packed_hip_type4 = float4;  // For 128-bit vectorization

  __device__ static __forceinline__ float convert(hip_type x) { return x; }
  __device__ static __forceinline__ float2 convert(packed_hip_type x) {
    return x;
  }
  __device__ static __forceinline__ float4 convert(packed_hip_type4 x) {
    return x;
  }
};

#if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
// CUDA < 12.0 runs into issues with packed type conversion
template <>

@@ -37,14 +53,16 @@ struct _typeConvert<c10::Half> {
  using hip_type = __half;
  using packed_hip_type = __half2;

-  __device__ static inline float convert(hip_type x) { return __half2float(x); }
-  __device__ static inline float2 convert(packed_hip_type x) {
+  __device__ static __forceinline__ float convert(hip_type x) {
+    return __half2float(x);
+  }
+  __device__ static __forceinline__ float2 convert(packed_hip_type x) {
    return __half22float2(x);
  }
-  __device__ static inline hip_type convert(float x) {
+  __device__ static __forceinline__ hip_type convert(float x) {
    return __float2half_rn(x);
  }
-  __device__ static inline packed_hip_type convert(float2 x) {
+  __device__ static __forceinline__ packed_hip_type convert(float2 x) {
    return __float22half2_rn(x);
  }
};

@@ -58,16 +76,16 @@ struct _typeConvert<c10::BFloat16> {
  using hip_type = __nv_bfloat16;
  using packed_hip_type = __nv_bfloat162;

-  __device__ static inline float convert(hip_type x) {
+  __device__ static __forceinline__ float convert(hip_type x) {
    return __bfloat162float(x);
  }
-  __device__ static inline float2 convert(packed_hip_type x) {
+  __device__ static __forceinline__ float2 convert(packed_hip_type x) {
    return __bfloat1622float2(x);
  }
-  __device__ static inline hip_type convert(float x) {
+  __device__ static __forceinline__ hip_type convert(float x) {
    return __float2bfloat16(x);
  }
-  __device__ static inline packed_hip_type convert(float2 x) {
+  __device__ static __forceinline__ packed_hip_type convert(float2 x) {
    return __float22bfloat162_rn(x);
  }
};

@@ -95,10 +113,15 @@ struct alignas(16) _f16Vec {
    if constexpr (width % 2 == 0) {
#pragma unroll
      for (int i = 0; i < width; i += 2) {
-        T2 temp{data[i], data[i + 1]};
-        temp += T2{other.data[i], other.data[i + 1]};
-        data[i] = temp.x;
-        data[i + 1] = temp.y;
+        if constexpr (std::is_same_v<T2, float2>) {
+          data[i] += other.data[i];
+          data[i + 1] += other.data[i + 1];
+        } else {
+          T2 temp{data[i], data[i + 1]};
+          temp += T2{other.data[i], other.data[i + 1]};
+          data[i] = temp.x;
+          data[i + 1] = temp.y;
+        }
      }
    } else {
#pragma unroll

@@ -111,10 +134,15 @@ struct alignas(16) _f16Vec {
    if constexpr (width % 2 == 0) {
#pragma unroll
      for (int i = 0; i < width; i += 2) {
-        T2 temp{data[i], data[i + 1]};
-        temp *= T2{other.data[i], other.data[i + 1]};
-        data[i] = temp.x;
-        data[i + 1] = temp.y;
+        if constexpr (std::is_same_v<T2, float2>) {
+          data[i] *= other.data[i];
+          data[i + 1] *= other.data[i + 1];
+        } else {
+          T2 temp{data[i], data[i + 1]};
+          temp *= T2{other.data[i], other.data[i + 1]};
+          data[i] = temp.x;
+          data[i + 1] = temp.y;
+        }
      }
    } else {
#pragma unroll

tests/compile/test_qk_norm_rope_fusion.py (new file, 195 lines)
@@ -0,0 +1,195 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

from tests.compile.backend import TestBackend
from vllm.attention import Attention, AttentionType
from vllm.compilation.matcher_utils import FLASHINFER_ROTARY_OP, RMS_OP, ROTARY_OP
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.compilation.qk_norm_rope_fusion import (
    FUSED_QK_ROPE_OP,
    QKNormRoPEFusionPass,
)
from vllm.config import (
    CompilationConfig,
    CompilationMode,
    ModelConfig,
    PassConfig,
    VllmConfig,
    set_current_vllm_config,
)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.platforms import current_platform

RSQRT_OP = torch.ops.aten.rsqrt.default
INDEX_SELECT_OP = torch.ops.aten.index.Tensor


class QKNormRoPETestModel(torch.nn.Module):
    def __init__(
        self,
        *,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
        eps: float,
        is_neox: bool,
        vllm_config: VllmConfig,
        dtype: torch.dtype,
        prefix: str = "model.layers.0.self_attn.attn",
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.q_size = num_heads * head_dim
        self.kv_size = num_kv_heads * head_dim
        self.rotary_dim = head_dim
        self.eps = eps
        self.dtype = dtype

        # Register layer metadata for the fusion pass via Attention.
        self.attn = Attention(
            num_heads=self.num_heads,
            head_size=self.head_dim,
            scale=1.0 / self.head_dim**0.5,
            num_kv_heads=self.num_kv_heads,
            cache_config=vllm_config.cache_config,
            prefix=prefix,
            attn_type=AttentionType.DECODER,
        )

        self.q_norm = RMSNorm(self.head_dim, eps=self.eps)
        self.k_norm = RMSNorm(self.head_dim, eps=self.eps)
        self.rotary_emb = RotaryEmbedding(
            self.head_dim,
            rotary_dim=self.rotary_dim,
            max_position_embeddings=4096,
            base=10000,
            is_neox_style=is_neox,
            dtype=self.dtype,
        )
        self.enable_rms_norm_custom_op = self.q_norm.enabled()
        self.enable_rope_custom_op = self.rotary_emb.enabled()

    def forward(self, qkv: torch.Tensor, positions: torch.Tensor):
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim)
        q_by_head = self.q_norm(q_by_head)
        q = q_by_head.view(q.shape)
        k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim)
        k_by_head = self.k_norm(k_by_head)
        k = k_by_head.view(k.shape)
        q, k = self.rotary_emb(positions, q, k)
        return q, k, v

    def ops_in_model_before(self) -> list[torch._ops.OpOverload]:
        ops = []
        if self.enable_rms_norm_custom_op:
            ops.append(RMS_OP)
        else:
            ops.append(RSQRT_OP)

        if self.enable_rope_custom_op:
            if self.rotary_emb.use_flashinfer:
                ops.append(FLASHINFER_ROTARY_OP)
            else:
                ops.append(ROTARY_OP)
        else:
            ops.append(INDEX_SELECT_OP)
        return ops

    def ops_in_model_after(self) -> list[torch._ops.OpOverload]:
        return [FUSED_QK_ROPE_OP]


@pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.parametrize("is_neox", [True, False])
@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False])
@pytest.mark.parametrize("enable_rope_custom_op", [True])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.skipif(
    not current_platform.is_cuda(),
    reason="Only test on cuda platform",
)
def test_qk_norm_rope_fusion(
    eps, is_neox, enable_rms_norm_custom_op, enable_rope_custom_op, dtype
):
    if not hasattr(torch.ops._C, "fused_qk_norm_rope"):
        pytest.skip("fused_qk_norm_rope custom op not available")

    torch.set_default_device("cuda")
    torch.set_default_dtype(dtype)
    torch.manual_seed(0)

    custom_ops: list[str] = []
    if enable_rms_norm_custom_op:
        custom_ops.append("+rms_norm")
    if enable_rope_custom_op:
        custom_ops.append("+rotary_embedding")

    vllm_config = VllmConfig(
        model_config=ModelConfig(dtype=dtype),
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            custom_ops=custom_ops,
            pass_config=PassConfig(
                enable_qk_norm_rope_fusion=True,
                enable_noop=True,
            ),
        ),
    )

    num_heads, num_kv_heads, head_dim = 16, 4, 128
    T = 5

    with set_current_vllm_config(vllm_config):
        model = QKNormRoPETestModel(
            num_heads=num_heads,
            num_kv_heads=num_kv_heads,
            head_dim=head_dim,
            eps=eps,
            is_neox=is_neox,
            vllm_config=vllm_config,
            dtype=dtype,
        )

        noop_pass = NoOpEliminationPass(vllm_config)
        fusion_pass = QKNormRoPEFusionPass(vllm_config)
        cleanup_pass = PostCleanupPass(vllm_config)

        backend = TestBackend(noop_pass, fusion_pass, cleanup_pass)
        backend_baseline = TestBackend(noop_pass, cleanup_pass)

        qkv = torch.randn(T, model.q_size + 2 * model.kv_size)
        pos = torch.arange(T, dtype=torch.long, device=qkv.device)
        qkv_unfused = qkv.clone()
        pos_unfused = pos.clone()

        torch._dynamo.mark_dynamic(qkv, 0)
        torch._dynamo.mark_dynamic(pos, 0)
        model_fused = torch.compile(model, backend=backend)
        q_fused, k_fused, v_fused = model_fused(qkv, pos)

        torch._dynamo.mark_dynamic(qkv_unfused, 0)
        torch._dynamo.mark_dynamic(pos_unfused, 0)
        model_unfused = torch.compile(model, backend=backend_baseline)
        q_unfused, k_unfused, v_unfused = model_unfused(qkv_unfused, pos_unfused)

        if dtype == torch.float16:
            ATOL, RTOL = (2e-3, 2e-3)
        else:
            ATOL, RTOL = (1e-2, 1e-2)

        torch.testing.assert_close(q_unfused, q_fused, atol=ATOL, rtol=RTOL)
        torch.testing.assert_close(k_unfused, k_fused, atol=ATOL, rtol=RTOL)
        torch.testing.assert_close(v_unfused, v_fused, atol=ATOL, rtol=RTOL)

        assert fusion_pass.matched_count == 1

        backend.check_before_ops(model.ops_in_model_before())
        backend.check_after_ops(model.ops_in_model_after())
tests/kernels/core/test_fused_qk_norm_rope.py (new file, 141 lines)
@@ -0,0 +1,141 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

from tests.kernels.utils import opcheck
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.platforms import current_platform

DTYPES = [torch.bfloat16, torch.float16]
IS_NEOX = [True, False]
EPS_VALUES = [1e-5, 1e-6]
SEEDS = [13]
CUDA_DEVICES = ["cuda:0"]


def _apply_qk_norm_rope(
    qkv: torch.Tensor,
    positions: torch.Tensor,
    q_norm: RMSNorm,
    k_norm: RMSNorm,
    rope: RotaryEmbedding,
    num_heads_q: int,
    num_heads_kv: int,
    head_dim: int,
) -> torch.Tensor:
    q_size = num_heads_q * head_dim
    kv_size = num_heads_kv * head_dim

    q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)

    q_by_head = q.view(*q.shape[:-1], q.shape[-1] // head_dim, head_dim)
    q_by_head = q_norm.forward_native(q_by_head)
    q = q_by_head.view(q.shape)

    k_by_head = k.view(*k.shape[:-1], k.shape[-1] // head_dim, head_dim)
    k_by_head = k_norm.forward_native(k_by_head)
    k = k_by_head.view(k.shape)

    q, k = rope.forward_native(positions, q, k)
    return torch.cat([q, k, v], dim=-1)


@pytest.mark.skipif(
    not current_platform.is_cuda(),
    reason="fused_qk_norm_rope custom op requires cuda platform",
)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("is_neox", IS_NEOX)
@pytest.mark.parametrize("eps", EPS_VALUES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_fused_qk_norm_rope_matches_reference(
    device: str,
    dtype: torch.dtype,
    is_neox: bool,
    eps: float,
    seed: int,
):
    torch.set_default_device(device)
    current_platform.seed_everything(seed)
    num_heads, num_kv_heads, head_dim = 16, 4, 128
    num_tokens = 4

    total_dim = (num_heads + 2 * num_kv_heads) * head_dim
    qkv_base = torch.randn(num_tokens, total_dim, dtype=dtype, device=device)
    qkv_fused = qkv_base.clone()
    positions = torch.arange(num_tokens, dtype=torch.long, device=device)

    q_norm = RMSNorm(head_dim, eps=eps).to(device=device, dtype=dtype)
    k_norm = RMSNorm(head_dim, eps=eps).to(device=device, dtype=dtype)
    q_norm.weight.data.normal_(mean=1.0, std=0.1)
    k_norm.weight.data.normal_(mean=1.0, std=0.1)
    q_weight = q_norm.weight.data
    k_weight = k_norm.weight.data

    rope = RotaryEmbedding(
        head_size=head_dim,
        rotary_dim=head_dim,
        max_position_embeddings=4096,
        base=10000.0,
        is_neox_style=is_neox,
        dtype=dtype,
    ).to(device)

    ref_result = _apply_qk_norm_rope(
        qkv=qkv_base,
        positions=positions,
        q_norm=q_norm,
        k_norm=k_norm,
        rope=rope,
        num_heads_q=num_heads,
        num_heads_kv=num_kv_heads,
        head_dim=head_dim,
    )

    opcheck(
        torch.ops._C.fused_qk_norm_rope,
        (
            qkv_fused.clone(),
            num_heads,
            num_kv_heads,
            num_kv_heads,
            head_dim,
            eps,
            q_weight,
            k_weight,
            rope.cos_sin_cache,
            is_neox,
            positions.view(-1),
        ),
    )

    torch.ops._C.fused_qk_norm_rope(
        qkv_fused,
        num_heads,
        num_kv_heads,
        num_kv_heads,
        head_dim,
        eps,
        q_weight,
        k_weight,
        rope.cos_sin_cache,
        is_neox,
        positions.view(-1),
    )

    if dtype == torch.float16:
        ATOL, RTOL = (2e-3, 2e-3)
    else:
        ATOL, RTOL = (1e-2, 1e-2)

    torch.testing.assert_close(
        qkv_fused,
        ref_result,
        atol=ATOL,
        rtol=RTOL,
    )
@@ -329,6 +329,7 @@ def rms_norm(
    out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, epsilon: float
) -> None:
    # TODO: Remove this contiguous call when the kernel is updated to support non-contiguous input
    # If removed, also need to remove contiguous in MatcherRMSNorm
    input_contiguous = input.contiguous()
    torch.ops._C.rms_norm(out, input_contiguous, weight, epsilon)

@@ -339,6 +340,34 @@ def fused_add_rms_norm(
    torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon)


def fused_qk_norm_rope(
    qkv: torch.Tensor,
    num_heads_q: int,
    num_heads_k: int,
    num_heads_v: int,
    head_dim: int,
    eps: float,
    q_weight: torch.Tensor,
    k_weight: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
    position_ids: torch.Tensor,
) -> None:
    torch.ops._C.fused_qk_norm_rope(
        qkv,
        num_heads_q,
        num_heads_k,
        num_heads_v,
        head_dim,
        eps,
        q_weight,
        k_weight,
        cos_sin_cache,
        is_neox,
        position_ids,
    )


def apply_repetition_penalties_torch(
    logits: torch.Tensor,
    prompt_mask: torch.Tensor,

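A minimal sketch of driving the new wrapper directly; it mutates qkv in place and returns None, so the caller re-splits the same buffer. The weights and cos/sin cache below are placeholders — in real use they come from the RMSNorm and RotaryEmbedding modules, as in the kernel test above:

import torch
from vllm import _custom_ops as ops

num_heads_q, num_heads_kv, head_dim, num_tokens = 16, 4, 128, 4
dtype, device = torch.float16, "cuda"
qkv = torch.randn(num_tokens, (num_heads_q + 2 * num_heads_kv) * head_dim,
                  dtype=dtype, device=device)
positions = torch.arange(num_tokens, dtype=torch.int64, device=device)
q_weight = torch.ones(head_dim, dtype=dtype, device=device)   # placeholder weights
k_weight = torch.ones(head_dim, dtype=dtype, device=device)
cos_sin_cache = torch.randn(4096, head_dim, dtype=dtype, device=device)  # placeholder

# In-place fused RMSNorm(Q), RMSNorm(K) and RoPE over the packed QKV buffer.
ops.fused_qk_norm_rope(qkv, num_heads_q, num_heads_kv, num_heads_kv, head_dim,
                       1e-6, q_weight, k_weight, cos_sin_cache, True, positions)
q, k, v = qkv.split([num_heads_q * head_dim, num_heads_kv * head_dim,
                     num_heads_kv * head_dim], dim=-1)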
@@ -132,6 +132,23 @@ class FixFunctionalizationPass(VllmInductorPass):
                        "input_global_scale",
                    ),
                )
            # Defunctionalize fused_qk_norm_rope to remove higher-order wrapper.
            elif at_target == torch.ops._C.fused_qk_norm_rope.default:
                mutated_args = {1: "qkv"}
                args = (
                    "qkv",
                    "num_heads_q",
                    "num_heads_k",
                    "num_heads_v",
                    "head_dim",
                    "eps",
                    "q_weight",
                    "k_weight",
                    "cos_sin_cache",
                    "is_neox",
                    "position_ids",
                )
                self.defunctionalize(graph, node, mutated_args=mutated_args, args=args)
            else:
                continue  # skip the count

@@ -44,6 +44,10 @@ def empty_i32(*args, **kwargs):
    return torch.empty(*args, **kwargs, dtype=torch.int32, device="cuda")


def empty_i64(*args, **kwargs):
    return torch.empty(*args, **kwargs, dtype=torch.int64, device="cuda")


RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default

@@ -18,10 +18,13 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
    kFp8StaticTensorSym,
    kNvfp4Quant,
)
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.platforms import current_platform

RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
ROTARY_OP = torch.ops._C.rotary_embedding.default
FLASHINFER_ROTARY_OP = torch.ops.vllm.flashinfer_rotary_embedding.default

QUANT_OPS: dict[QuantKey, OpOverload] = {
    kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default,  # noqa: E501

@@ -58,6 +61,9 @@ class MatcherCustomOp(ABC):
    def empty(self, *args, **kws):
        return torch.empty(*args, dtype=self.model_dtype, device=self.device, **kws)

    def empty_int64(self, *args, **kws):
        return torch.empty(*args, dtype=torch.int64, device=self.device, **kws)

    def empty_f32(self, *args, **kws):
        return torch.empty(*args, dtype=torch.float32, device=self.device, **kws)

@@ -66,6 +72,77 @@ class MatcherCustomOp(ABC):
        raise NotImplementedError


class MatcherRotaryEmbedding(MatcherCustomOp):
    def __init__(
        self,
        is_neox: bool,
        head_size: int,
        num_heads: int,
        num_kv_heads: int,
        use_flashinfer: bool = False,
        enabled: bool | None = None,
    ) -> None:
        if enabled is None:
            enabled = RotaryEmbedding.enabled()

        super().__init__(enabled)
        self.is_neox = is_neox
        self.head_size = head_size
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.q_size = self.num_heads * self.head_size
        self.kv_size = self.num_kv_heads * self.head_size
        self.rotary_dim = head_size
        if use_flashinfer:
            self.rotary_op = FLASHINFER_ROTARY_OP
        else:
            self.rotary_op = ROTARY_OP

    def inputs(self) -> list[torch.Tensor]:
        positions = self.empty_int64(5)
        query = self.empty(5, self.q_size)
        key = self.empty(5, self.kv_size)
        cos_sin_cache = self.empty(4096, self.rotary_dim)
        return [positions, query, key, cos_sin_cache]

    def forward_custom(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor | None,
        cos_sin_cache: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        result = auto_functionalized(
            self.rotary_op,
            positions=positions,
            query=query,
            key=key,
            head_size=self.head_size,
            cos_sin_cache=cos_sin_cache,
            is_neox=self.is_neox,
        )
        query_out = result[1]
        key_out = result[2] if len(result) > 2 else None
        return query_out, key_out

    def forward_native(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor | None,
        cos_sin_cache: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        return RotaryEmbedding.forward_static(
            positions,
            query,
            key,
            self.head_size,
            self.rotary_dim,
            cos_sin_cache,
            self.is_neox,
        )


class MatcherRMSNorm(MatcherCustomOp):
    def __init__(self, epsilon: float, enabled: bool | None = None):
        if enabled is None:

@@ -85,10 +162,12 @@ class MatcherRMSNorm(MatcherCustomOp):
        weight: torch.Tensor,
    ) -> torch.Tensor:
        result = torch.empty_like(input)
+        # TODO: support non-contiguous input for RMSNorm and remove this
+        input_contiguous = input.contiguous()
        _, result = auto_functionalized(
            RMS_OP,
            result=result,
-            input=input,
+            input=input_contiguous,
            weight=weight,
            epsilon=self.epsilon,
        )

@@ -17,6 +17,7 @@ if current_platform.is_cuda_alike():
    from .activation_quant_fusion import ActivationQuantFusionPass
    from .fusion import RMSNormQuantFusionPass
    from .fusion_attn import AttnFusionPass
    from .qk_norm_rope_fusion import QKNormRoPEFusionPass

if current_platform.is_cuda():
    from .collective_fusion import AllReduceFusionPass, AsyncTPPass

@@ -109,6 +110,9 @@ class PostGradPassManager(CustomGraphPass):
        if self.pass_config.enable_attn_fusion:
            self.passes += [AttnFusionPass(config)]

        if self.pass_config.enable_qk_norm_rope_fusion:
            self.passes += [QKNormRoPEFusionPass(config)]

        # needs a functional graph
        self.post_cleanup = PostCleanupPass(config)
        self.fix_functionalization = FixFunctionalizationPass(config)

vllm/compilation/qk_norm_rope_fusion.py (new file, 238 lines)
@@ -0,0 +1,238 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Callable

import torch
import torch._inductor.pattern_matcher as pm
from torch import fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass

from vllm.attention import Attention
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding

from .fusion import empty_bf16, empty_fp32, empty_i64
from .inductor_pass import enable_fake_mode
from .matcher_utils import MatcherRMSNorm, MatcherRotaryEmbedding
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass

logger = init_logger(__name__)

FUSED_QK_ROPE_OP = torch.ops._C.fused_qk_norm_rope.default


class QkNormRopePattern:
    """
    Match the unfused sequence in attention blocks and replace it with the fused op.

    Unfused (conceptually):
        q, k, v = split(qkv, [qsz, kvsz, kvsz], -1)
        qh = reshape(q, [-1, num_heads, head_dim])
        kh = reshape(k, [-1, num_kv_heads, head_dim])
        qn = rms_norm(qh, q_weight, eps)
        kn = rms_norm(kh, k_weight, eps)
        qf = reshape(qn, [-1, num_heads * head_dim])
        kf = reshape(kn, [-1, num_kv_heads * head_dim])
        qf, kf = rotary_embedding(positions, qf, kf, head_dim, cos_sin_cache, is_neox)
        return qf, kf, v

    Fused replacement:
        fused_qk_norm_rope(qkv, num_heads, num_kv_heads, num_kv_heads, head_dim,
                           eps, q_weight, k_weight, cos_sin_cache, is_neox,
                           positions.view(-1))
        return split(qkv, [qsz, kvsz, kvsz], -1)
    """

    def __init__(
        self,
        head_dim: int,
        num_heads: int,
        num_kv_heads: int,
        eps: float,
        is_neox: bool,
        rope_flashinfer: bool = False,
    ) -> None:
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.eps = eps
        self.rmsnorm_matcher = MatcherRMSNorm(eps)
        self.is_neox = is_neox
        self.rope_flashinfer = rope_flashinfer
        self.rope_matcher = MatcherRotaryEmbedding(
            is_neox=is_neox,
            head_size=self.head_dim,
            num_heads=self.num_heads,
            num_kv_heads=self.num_kv_heads,
            use_flashinfer=self.rope_flashinfer,
        )

    def get_inputs(self):
        # Sample inputs to help pattern tracing
        T = 5
        qkv = empty_bf16(T, self.q_size + 2 * self.kv_size)
        positions = empty_i64(T)
        q_weight = empty_bf16(1, self.head_dim)
        k_weight = empty_bf16(1, self.head_dim)
        if self.rope_flashinfer:
            cos_sin_cache = empty_fp32(4096, self.head_dim)
        else:
            cos_sin_cache = empty_bf16(4096, self.head_dim)
        return [
            qkv,
            positions,
            q_weight,
            k_weight,
            cos_sin_cache,
        ]

    @staticmethod
    def wrap_trace_fn(trace_fn, *process_fx_fns: Callable[[fx.GraphModule], None]):
        def wrapped(*args, **kwargs):
            gm = trace_fn(*args, **kwargs)
            for process_fx in process_fx_fns:
                process_fx(gm)

            return gm

        return wrapped

    @staticmethod
    def fx_view_to_reshape(gm: torch.fx.GraphModule):
        from torch._inductor.fx_passes.post_grad import view_to_reshape

        view_to_reshape(gm)

    def register(self, pm_pass: PatternMatcherPass):
        def pattern(
            qkv: torch.Tensor,
            positions: torch.Tensor,
            q_weight: torch.Tensor,
            k_weight: torch.Tensor,
            cos_sin_cache: torch.Tensor,
        ):
            # split qkv -> q,k,v
            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

            # Q path: view -> RMS -> view back to q.shape
            q_by_head = q.view(
                *q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim
            )
            q_normed_by_head = self.rmsnorm_matcher(q_by_head, q_weight)
            q_flat = q_normed_by_head.view(q.shape)

            # K path: view -> RMS -> view back to k.shape
            k_by_head = k.view(
                *k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim
            )
            k_normed_by_head = self.rmsnorm_matcher(k_by_head, k_weight)
            k_flat = k_normed_by_head.view(k.shape)

            # RoPE: apply to flattened q/k
            q_rope, k_rope = self.rope_matcher(positions, q_flat, k_flat, cos_sin_cache)
            return q_rope, k_rope, v

        def replacement(
            qkv: torch.Tensor,
            positions: torch.Tensor,
            q_weight: torch.Tensor,
            k_weight: torch.Tensor,
            cos_sin_cache: torch.Tensor,
        ):
            # Run fused qk_norm_rope op
            result = auto_functionalized(
                FUSED_QK_ROPE_OP,
                qkv=qkv,
                num_heads_q=self.num_heads,
                num_heads_k=self.num_kv_heads,
                num_heads_v=self.num_kv_heads,
                head_dim=self.head_dim,
                eps=self.eps,
                q_weight=q_weight,
                k_weight=k_weight,
                cos_sin_cache=cos_sin_cache,
                is_neox=self.is_neox,
                position_ids=positions.view(-1),
            )
            result_qkv = result[1]

            # Split back to q,k,v and return
            return result_qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

        # NOTE: use fx_view_to_reshape to unify view/reshape to simplify the
        # pattern and increase matching opportunities
        pm.register_replacement(
            pattern,
            replacement,
            self.get_inputs(),
            QkNormRopePattern.wrap_trace_fn(
                pm.fwd_only,
                QkNormRopePattern.fx_view_to_reshape,
            ),
            pm_pass,
        )


class QKNormRoPEFusionPass(VllmPatternMatcherPass):
    """Fuse Q/K RMSNorm + RoPE into fused_qk_norm_rope when the custom op exists."""

    @enable_fake_mode
    def __init__(self, config: VllmConfig):
        super().__init__(config)
        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="qk_norm_rope_fusion_pass"
        )

        dtype = config.model_config.dtype
        if dtype not in (torch.bfloat16, torch.float16):
            logger.warning_once(
                "QK Norm+RoPE fusion not enabled: unsupported dtype %s", dtype
            )
            return

        # use one attn layer to get meta (such as head_dim) for QkNormRopePattern
        attn_layers: dict[str, Attention] = get_layers_from_vllm_config(
            config, Attention
        )
        if len(attn_layers) == 0:
            logger.warning_once(
                "QK Norm+RoPE fusion enabled, but no Attention layers were discovered."
            )
            return
        layer = next(iter(attn_layers.values()))

        for epsilon in [1e-5, 1e-6]:
            for neox in [True, False]:
                if RotaryEmbedding.enabled():
                    for rope_flashinfer in [False, True]:
                        QkNormRopePattern(
                            head_dim=layer.head_size,
                            num_heads=layer.num_heads,
                            num_kv_heads=layer.num_kv_heads,
                            eps=epsilon,
                            is_neox=neox,
                            rope_flashinfer=rope_flashinfer,
                        ).register(self.patterns)
                else:
                    QkNormRopePattern(
                        head_dim=layer.head_size,
                        num_heads=layer.num_heads,
                        num_kv_heads=layer.num_kv_heads,
                        eps=epsilon,
                        is_neox=neox,
                    ).register(self.patterns)

        self.dump_patterns(config, self.patterns)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph) -> None:
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Fused QK Norm+RoPE on %s sites", self.matched_count)

    def uuid(self):
        return VllmInductorPass.hash_source(self, QkNormRopePattern)
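To turn the pass on end to end, the flag goes through PassConfig; a configuration sketch based on the compile test above (note that the config change below also force-enables the +rotary_embedding custom op when the flag is set):

import torch
from vllm.config import (
    CompilationConfig,
    CompilationMode,
    ModelConfig,
    PassConfig,
    VllmConfig,
)

# Illustrative: enable the QK Norm + RoPE fusion pass under VLLM_COMPILE.
vllm_config = VllmConfig(
    model_config=ModelConfig(dtype=torch.bfloat16),
    compilation_config=CompilationConfig(
        mode=CompilationMode.VLLM_COMPILE,
        custom_ops=["+rms_norm", "+rotary_embedding"],
        pass_config=PassConfig(enable_qk_norm_rope_fusion=True, enable_noop=True),
    ),
)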
@@ -129,6 +129,8 @@ class PassConfig:
            8: 1,  # 1MB
        },
    }, where key is the device capability"""
    enable_qk_norm_rope_fusion: bool = False
    """Whether to enable the fused Q/K RMSNorm + RoPE pass."""

    # TODO(luka) better pass enabling system.

@@ -182,6 +184,12 @@ class PassConfig:
                "Fusion enabled but reshape elimination disabled. "
                "Allreduce + rms norm + quant (fp8) fusion might not work"
            )
        if self.enable_qk_norm_rope_fusion and not current_platform.is_cuda():
            logger.warning_once(
                "QK Norm + RoPE fusion enabled but the current platform is not "
                "CUDA. The fusion will be disabled."
            )
            self.enable_qk_norm_rope_fusion = False


@config

@@ -640,6 +648,11 @@ class CompilationConfig:
        if isinstance(self.pass_config, dict):
            self.pass_config = PassConfig(**self.pass_config)

        if self.pass_config.enable_qk_norm_rope_fusion:
            # TODO(zhuhaoran): support rope native forward match and remove this.
            # Linked issue: https://github.com/vllm-project/vllm/issues/28042
            self.custom_ops.append("+rotary_embedding")

        if (
            is_torch_equal_or_newer("2.9.0.dev")
            and "combo_kernels" not in self.inductor_compile_config

@@ -98,6 +98,39 @@ class RotaryEmbedding(RotaryEmbeddingBase):
            head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
        )

    @staticmethod
    def forward_static(
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor | None,
        head_size: int,
        rotary_dim: int,
        cos_sin_cache: torch.Tensor,
        is_neox_style: bool,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """A PyTorch-native implementation of forward()."""
        positions = positions.flatten()
        num_tokens = positions.shape[0]
        cos_sin = cos_sin_cache.index_select(0, positions)
        cos, sin = cos_sin.chunk(2, dim=-1)

        query_shape = query.shape
        query = query.view(num_tokens, -1, head_size)
        query_rot = query[..., :rotary_dim]
        query_pass = query[..., rotary_dim:]
        query_rot = apply_rotary_emb_torch(query_rot, cos, sin, is_neox_style)
        query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)

        # key may be None in some cases, e.g. cross-layer KV sharing
        if key is not None:
            key_shape = key.shape
            key = key.view(num_tokens, -1, head_size)
            key_rot = key[..., :rotary_dim]
            key_pass = key[..., rotary_dim:]
            key_rot = apply_rotary_emb_torch(key_rot, cos, sin, is_neox_style)
            key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
        return query, key

    def forward_native(
        self,
        positions: torch.Tensor,

@@ -105,27 +138,15 @@ class RotaryEmbedding(RotaryEmbeddingBase):
        key: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """A PyTorch-native implementation of forward()."""
-        positions = positions.flatten()
-        num_tokens = positions.shape[0]
-        cos_sin = self.cos_sin_cache.index_select(0, positions)
-        cos, sin = cos_sin.chunk(2, dim=-1)
-
-        query_shape = query.shape
-        query = query.view(num_tokens, -1, self.head_size)
-        query_rot = query[..., : self.rotary_dim]
-        query_pass = query[..., self.rotary_dim :]
-        query_rot = apply_rotary_emb_torch(query_rot, cos, sin, self.is_neox_style)
-        query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
-
-        # key may be None in some cases, e.g. cross-layer KV sharing
-        if key is not None:
-            key_shape = key.shape
-            key = key.view(num_tokens, -1, self.head_size)
-            key_rot = key[..., : self.rotary_dim]
-            key_pass = key[..., self.rotary_dim :]
-            key_rot = apply_rotary_emb_torch(key_rot, cos, sin, self.is_neox_style)
-            key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
-        return query, key
+        return self.forward_static(
+            positions,
+            query,
+            key,
+            self.head_size,
+            self.rotary_dim,
+            self.cos_sin_cache,
+            self.is_neox_style,
+        )

    def forward_cuda(
        self,