[Kernel][Perf] fuse QK Norm and RoPE into one cuda kernel for Qwen Model (#27165)

Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
zhrrr 2025-11-12 01:00:31 +08:00 committed by GitHub
parent a7ef3eb0cd
commit 68c09efc37
16 changed files with 1243 additions and 38 deletions

View File

@@ -451,6 +451,7 @@ steps:
- pytest -v -s compile/test_decorator.py
- pytest -v -s compile/test_noop_elimination.py
- pytest -v -s compile/test_aot_compile.py
- pytest -v -s compile/test_qk_norm_rope_fusion.py
- label: PyTorch Fullgraph Smoke Test # 15min
timeout_in_minutes: 30

View File

@@ -265,6 +265,7 @@ set(VLLM_EXT_SRC
"csrc/pos_encoding_kernels.cu"
"csrc/activation_kernels.cu"
"csrc/layernorm_kernels.cu"
"csrc/fused_qknorm_rope_kernel.cu"
"csrc/layernorm_quant_kernels.cu"
"csrc/sampler.cu"
"csrc/cuda_view.cu"

View File

@@ -0,0 +1,418 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cmath>
#include <cuda_runtime.h>
#include <type_traits>
#include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h>
#include "cuda_compat.h"
#include "dispatch_utils.h"
#include "type_convert.cuh"
#define CHECK_TYPE(x, st) \
TORCH_CHECK(x.scalar_type() == st, #x " dtype is ", x.scalar_type(), \
", while ", st, " is expected")
#define CHECK_TH_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_TH_CUDA(x); \
CHECK_CONTIGUOUS(x)
#define FINAL_MASK 0xffffffff
// TODO: support the AMD ROCm platform
#ifndef USE_ROCM
namespace tensorrt_llm::common {
template <typename T, int num>
struct packed_as;
// Specialization for packed_as used in this kernel.
template <>
struct packed_as<uint, 1> {
using type = uint;
};
template <>
struct packed_as<uint, 2> {
using type = uint2;
};
template <>
struct packed_as<uint, 4> {
using type = uint4;
};
template <typename T>
__inline__ __device__ T warpReduceSum(T val) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1)
val += __shfl_xor_sync(FINAL_MASK, val, mask, 32);
return val;
}
template <typename T>
inline __device__ __host__ T divUp(T m, T n) {
return (m + n - 1) / n;
}
} // namespace tensorrt_llm::common
namespace tensorrt_llm::kernels {
// NOTE(zhuhaoran): This kernel is adapted from the TensorRT-LLM
// implementation, with added support for passing the cos_sin_cache as an
// input.
// https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/fusedQKNormRopeKernel.cu
// Perform per-head QK Norm and RoPE in a single kernel.
// scalar_t_in: data type of QKV and RMSNorm weights
// scalar_t_cache: data type of cos/sin cache
// head_dim: the dimension of each head
// interleave: interleave=!is_neox.
template <typename scalar_t_in, typename scalar_t_cache, int head_dim,
bool interleave>
__global__ void fusedQKNormRopeKernel(
void* qkv_void, // Combined QKV tensor
int const num_heads_q, // Number of query heads
int const num_heads_k, // Number of key heads
int const num_heads_v, // Number of value heads
float const eps, // Epsilon for RMS normalization
void const* q_weight_void, // RMSNorm weights for query
void const* k_weight_void, // RMSNorm weights for key
void const* cos_sin_cache_void, // Pre-computed cos/sin cache
int64_t const* position_ids, // Position IDs for RoPE
int const num_tokens // Number of tokens
) {
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800
if constexpr ((std::is_same_v<scalar_t_in, c10::BFloat16>) ||
std::is_same_v<scalar_t_cache, c10::BFloat16>) {
return;
} else {
#endif
using Converter = vllm::_typeConvert<scalar_t_in>;
static_assert(Converter::exists,
"Input QKV data type is not supported for this CUDA "
"architecture or toolkit version.");
using T_in = typename Converter::hip_type;
using T2_in = typename Converter::packed_hip_type;
using CacheConverter = vllm::_typeConvert<scalar_t_cache>;
static_assert(CacheConverter::exists,
"Cache data type is not supported for this CUDA architecture "
"or toolkit version.");
using T_cache = typename CacheConverter::hip_type;
T_in* qkv = reinterpret_cast<T_in*>(qkv_void);
T_in const* q_weight = reinterpret_cast<T_in const*>(q_weight_void);
T_in const* k_weight = reinterpret_cast<T_in const*>(k_weight_void);
T_cache const* cos_sin_cache =
reinterpret_cast<T_cache const*>(cos_sin_cache_void);
int const warpsPerBlock = blockDim.x / 32;
int const warpId = threadIdx.x / 32;
int const laneId = threadIdx.x % 32;
// Calculate global warp index to determine which head/token this warp
// processes
int const globalWarpIdx = blockIdx.x * warpsPerBlock + warpId;
// Total number of attention heads (Q and K)
int const total_qk_heads = num_heads_q + num_heads_k;
// Determine which token and head type (Q or K) this warp processes
int const tokenIdx = globalWarpIdx / total_qk_heads;
int const localHeadIdx = globalWarpIdx % total_qk_heads;
// Skip if this warp is assigned beyond the number of tokens
if (tokenIdx >= num_tokens) return;
bool const isQ = localHeadIdx < num_heads_q;
int const headIdx = isQ ? localHeadIdx : localHeadIdx - num_heads_q;
int const num_heads = num_heads_q + num_heads_k + num_heads_v;
static_assert(head_dim % (32 * 2) == 0,
"head_dim must be divisible by 64 (each warp processes one "
"head, and each thread gets even number of "
"elements)");
constexpr int numElemsPerThread = head_dim / 32;
float elements[numElemsPerThread];
constexpr int elemSizeBytes = numElemsPerThread * sizeof(__nv_bfloat16);
static_assert(elemSizeBytes % 4 == 0,
"numSizeBytes must be a multiple of 4");
constexpr int vecSize =
elemSizeBytes /
4; // Use packed_as<uint, vecSize> to perform loading/saving.
using vec_T = typename tensorrt_llm::common::packed_as<uint, vecSize>::type;
int offsetWarp; // Offset for the warp
if (isQ) {
// Q segment: token offset + head offset within Q segment
offsetWarp = tokenIdx * num_heads * head_dim + headIdx * head_dim;
} else {
// K segment: token offset + entire Q segment + head offset within K
// segment
offsetWarp = tokenIdx * num_heads * head_dim + num_heads_q * head_dim +
headIdx * head_dim;
}
int offsetThread = offsetWarp + laneId * numElemsPerThread;
// Sum of squares for RMSNorm
float sumOfSquares = 0.0f;
// Load.
{
vec_T vec = *reinterpret_cast<vec_T const*>(&qkv[offsetThread]);
constexpr int num_packed_elems = elemSizeBytes / sizeof(T2_in);
#pragma unroll
for (int i = 0; i < num_packed_elems; i++) {
// Interpret the generic vector chunk as the specific packed type
T2_in packed_val = *(reinterpret_cast<T2_in*>(&vec) + i);
// Convert to float2 for computation
float2 vals = Converter::convert(packed_val);
sumOfSquares += vals.x * vals.x;
sumOfSquares += vals.y * vals.y;
elements[2 * i] = vals.x;
elements[2 * i + 1] = vals.y;
}
}
// Reduce sum across warp using the utility function
sumOfSquares = tensorrt_llm::common::warpReduceSum(sumOfSquares);
// Compute RMS normalization factor
float rms_rcp = rsqrtf(sumOfSquares / static_cast<float>(head_dim) + eps);
// Normalize elements
#pragma unroll
for (int i = 0; i < numElemsPerThread; i++) {
int dim = laneId * numElemsPerThread + i;
float weight = isQ ? Converter::convert(q_weight[dim])
: Converter::convert(k_weight[dim]);
elements[i] *= rms_rcp * weight;
}
// Apply RoPE to normalized elements
float elements2[numElemsPerThread]; // Additional buffer required for RoPE.
int64_t pos_id = position_ids[tokenIdx];
// Calculate cache pointer for this position - similar to
// pos_encoding_kernels.cu
T_cache const* cache_ptr = cos_sin_cache + pos_id * head_dim;
int const embed_dim = head_dim / 2;
T_cache const* cos_ptr = cache_ptr;
T_cache const* sin_ptr = cache_ptr + embed_dim;
if constexpr (interleave) {
// Perform interleaving. Use pre-computed cos/sin values.
#pragma unroll
for (int i = 0; i < numElemsPerThread / 2; ++i) {
int const idx0 = 2 * i;
int const idx1 = 2 * i + 1;
float const val0 = elements[idx0];
float const val1 = elements[idx1];
int const dim_idx = laneId * numElemsPerThread + idx0;
int const half_dim = dim_idx / 2;
float const cos_val =
CacheConverter::convert(VLLM_LDG(cos_ptr + half_dim));
float const sin_val =
CacheConverter::convert(VLLM_LDG(sin_ptr + half_dim));
elements[idx0] = val0 * cos_val - val1 * sin_val;
elements[idx1] = val0 * sin_val + val1 * cos_val;
}
} else {
// Before exchanging data within the warp, we need to sync.
__syncwarp();
// Get the data from the other half of the warp. Use pre-computed cos/sin
// values.
#pragma unroll
for (int i = 0; i < numElemsPerThread; i++) {
elements2[i] = __shfl_xor_sync(0xffffffff, elements[i], 16);
if (laneId < 16) {
elements2[i] = -elements2[i];
}
int dim_idx = laneId * numElemsPerThread + i;
dim_idx = (dim_idx * 2) % head_dim;
int half_dim = dim_idx / 2;
// Use pre-computed cos/sin from cache
float cos_val = CacheConverter::convert(VLLM_LDG(cos_ptr + half_dim));
float sin_val = CacheConverter::convert(VLLM_LDG(sin_ptr + half_dim));
elements[i] = elements[i] * cos_val + elements2[i] * sin_val;
}
// __shfl_xor_sync does not provide memfence. Need to sync again.
__syncwarp();
}
// Store.
{
vec_T vec;
constexpr int num_packed_elems = elemSizeBytes / sizeof(T2_in);
#pragma unroll
for (int i = 0; i < num_packed_elems; i++) {
// Convert from float2 back to the specific packed type
T2_in packed_val = Converter::convert(
make_float2(elements[2 * i], elements[2 * i + 1]));
// Place it into the generic vector
*(reinterpret_cast<T2_in*>(&vec) + i) = packed_val;
}
*reinterpret_cast<vec_T*>(&qkv[offsetThread]) = vec;
}
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800
}
#endif
}
// Borrowed from
// https://github.com/flashinfer-ai/flashinfer/blob/8125d079a43e9a0ba463a4ed1b639cefd084cec9/include/flashinfer/pos_enc.cuh#L568
#define DISPATCH_INTERLEAVE(interleave, INTERLEAVE, ...) \
if (interleave) { \
const bool INTERLEAVE = true; \
__VA_ARGS__ \
} else { \
const bool INTERLEAVE = false; \
__VA_ARGS__ \
}
template <typename scalar_t_in, typename scalar_t_cache>
void launchFusedQKNormRope(void* qkv, int const num_tokens,
int const num_heads_q, int const num_heads_k,
int const num_heads_v, int const head_dim,
float const eps, void const* q_weight,
void const* k_weight, void const* cos_sin_cache,
bool const interleave, int64_t const* position_ids,
cudaStream_t stream) {
constexpr int blockSize = 256;
int const warpsPerBlock = blockSize / 32;
int const totalQKHeads = num_heads_q + num_heads_k;
int const totalWarps = num_tokens * totalQKHeads;
int const gridSize = common::divUp(totalWarps, warpsPerBlock);
dim3 gridDim(gridSize);
dim3 blockDim(blockSize);
switch (head_dim) {
case 64:
DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
fusedQKNormRopeKernel<scalar_t_in, scalar_t_cache, 64, INTERLEAVE>
<<<gridDim, blockDim, 0, stream>>>(
qkv, num_heads_q, num_heads_k, num_heads_v, eps, q_weight,
k_weight, cos_sin_cache, position_ids, num_tokens);
});
break;
case 128:
DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
fusedQKNormRopeKernel<scalar_t_in, scalar_t_cache, 128, INTERLEAVE>
<<<gridDim, blockDim, 0, stream>>>(
qkv, num_heads_q, num_heads_k, num_heads_v, eps, q_weight,
k_weight, cos_sin_cache, position_ids, num_tokens);
});
break;
case 256:
DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
fusedQKNormRopeKernel<scalar_t_in, scalar_t_cache, 256, INTERLEAVE>
<<<gridDim, blockDim, 0, stream>>>(
qkv, num_heads_q, num_heads_k, num_heads_v, eps, q_weight,
k_weight, cos_sin_cache, position_ids, num_tokens);
});
break;
default:
TORCH_CHECK(false,
"Unsupported head dimension for fusedQKNormRope: ", head_dim);
}
}
} // namespace tensorrt_llm::kernels
void fused_qk_norm_rope(
torch::Tensor& qkv, // Combined QKV tensor [num_tokens,
// (num_heads_q+num_heads_k+num_heads_v)*head_dim]
int64_t num_heads_q, // Number of query heads
int64_t num_heads_k, // Number of key heads
int64_t num_heads_v, // Number of value heads
int64_t head_dim, // Dimension per head
double eps, // Epsilon for RMS normalization
torch::Tensor& q_weight, // RMSNorm weights for query [head_dim]
torch::Tensor& k_weight, // RMSNorm weights for key [head_dim]
torch::Tensor& cos_sin_cache, // Cos/sin cache [max_position, head_dim]
bool is_neox, // Whether RoPE is applied in Neox style
torch::Tensor& position_ids // Position IDs for RoPE [num_tokens]
) {
// Input validation
CHECK_INPUT(qkv);
CHECK_INPUT(position_ids);
CHECK_INPUT(q_weight);
CHECK_INPUT(k_weight);
CHECK_INPUT(cos_sin_cache);
CHECK_TYPE(position_ids, torch::kInt64);
TORCH_CHECK(qkv.dim() == 2,
"QKV tensor must be 2D: [num_tokens, "
"(num_heads_q+num_heads_k+num_heads_v)*head_dim]");
TORCH_CHECK(position_ids.dim() == 1, "Position IDs must be 1D: [num_tokens]");
TORCH_CHECK(q_weight.dim() == 1, "Query weights must be 1D: [head_dim]");
TORCH_CHECK(k_weight.dim() == 1, "Key weights must be 1D: [head_dim]");
TORCH_CHECK(cos_sin_cache.dim() == 2,
"Cos/sin cache must be 2D: [max_position, head_dim]");
TORCH_CHECK(q_weight.size(0) == head_dim,
"Query weights size must match head dimension");
TORCH_CHECK(k_weight.size(0) == head_dim,
"Key weights size must match head dimension");
TORCH_CHECK(cos_sin_cache.size(1) == head_dim,
"Cos/sin cache dimension must match head_dim");
TORCH_CHECK(qkv.scalar_type() == q_weight.scalar_type() &&
qkv.scalar_type() == k_weight.scalar_type(),
"qkv, q_weight and k_weight must have the same dtype");
int64_t num_tokens = qkv.size(0);
TORCH_CHECK(position_ids.size(0) == num_tokens,
"Number of tokens in position_ids must match QKV");
int64_t total_heads = num_heads_q + num_heads_k + num_heads_v;
TORCH_CHECK(
qkv.size(1) == total_heads * head_dim,
"QKV tensor size must match total number of heads and head dimension");
auto stream = at::cuda::getCurrentCUDAStream(qkv.get_device());
VLLM_DISPATCH_HALF_TYPES(qkv.scalar_type(), "fused_qk_norm_rope_kernel", [&] {
using qkv_scalar_t = scalar_t;
VLLM_DISPATCH_FLOATING_TYPES(
cos_sin_cache.scalar_type(), "fused_qk_norm_rope_kernel", [&] {
using cache_scalar_t = scalar_t;
tensorrt_llm::kernels::launchFusedQKNormRope<qkv_scalar_t,
cache_scalar_t>(
qkv.data_ptr(), static_cast<int>(num_tokens),
static_cast<int>(num_heads_q), static_cast<int>(num_heads_k),
static_cast<int>(num_heads_v), static_cast<int>(head_dim),
static_cast<float>(eps), q_weight.data_ptr(), k_weight.data_ptr(),
cos_sin_cache.data_ptr(), !is_neox,
reinterpret_cast<int64_t const*>(position_ids.data_ptr()),
stream);
});
});
}
#endif // not USE_ROCM
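
For orientation, here is a minimal PyTorch sketch of the math the kernel above fuses, assuming the layouts described in its comments: a packed [num_tokens, (Hq+Hk+Hv)*head_dim] qkv tensor, [head_dim] RMSNorm weights for Q and K, and a [max_position, head_dim] cos/sin cache with cos in the first half. The launch wrapper assigns one warp per (token, Q-or-K head), each lane holding head_dim/32 elements, and leaves the V segment untouched. The helper below (name and signature are illustrative, not part of the diff) mirrors the unfused reference used by the tests further down.

import torch

def qk_norm_rope_reference(qkv, positions, q_w, k_w, cos_sin_cache,
                           hq, hk, hv, head_dim, eps, is_neox):
    # Illustrative helper: per-head RMSNorm on the Q/K slices, then RoPE, in fp32.
    n = qkv.shape[0]
    q, k, v = qkv.float().split([hq * head_dim, hk * head_dim, hv * head_dim], dim=-1)

    def rms(x, w):  # [n, heads * head_dim] -> [n, heads, head_dim], normalized per head
        x = x.view(n, -1, head_dim)
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * w.float()

    q, k = rms(q, q_w), rms(k, k_w)
    # Cache layout matches the kernel: cos in the first head_dim/2, sin in the second.
    cos, sin = cos_sin_cache.float()[positions].chunk(2, dim=-1)
    cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)  # broadcast over heads

    def rope(x):
        if is_neox:  # neox style: first/second halves of each head form the pairs
            x1, x2 = x.chunk(2, dim=-1)
            return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)
        # interleaved (is_neox=False): even/odd elements form the pairs
        x1, x2 = x[..., 0::2], x[..., 1::2]
        o1, o2 = x1 * cos - x2 * sin, x2 * cos + x1 * sin
        return torch.stack([o1, o2], dim=-1).flatten(-2)

    q, k = rope(q).reshape(n, -1), rope(k).reshape(n, -1)
    return torch.cat([q, k, v], dim=-1).to(qkv.dtype)

The fused op performs the same computation but writes the rotated, normalized Q/K back into qkv in place.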

View File

@@ -92,6 +92,12 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
torch::Tensor& weight, double epsilon);
void fused_qk_norm_rope(torch::Tensor& qkv, int64_t num_heads_q,
int64_t num_heads_k, int64_t num_heads_v,
int64_t head_dim, double eps, torch::Tensor& q_weight,
torch::Tensor& k_weight, torch::Tensor& cos_sin_cache,
bool is_neox, torch::Tensor& position_ids);
void apply_repetition_penalties_(torch::Tensor& logits,
const torch::Tensor& prompt_mask,
const torch::Tensor& output_mask,

View File

@@ -175,6 +175,16 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"float epsilon) -> ()");
ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm);
#ifndef USE_ROCM
// Function for fused QK Norm and RoPE
ops.def(
"fused_qk_norm_rope(Tensor! qkv, int num_heads_q, "
"int num_heads_k, int num_heads_v, int head_dim, float eps, "
"Tensor q_weight, Tensor k_weight, Tensor cos_sin_cache, "
"bool is_neox, Tensor position_ids) -> ()");
ops.impl("fused_qk_norm_rope", torch::kCUDA, &fused_qk_norm_rope);
#endif
// Apply repetition penalties to logits in-place
ops.def(
"apply_repetition_penalties_(Tensor! logits, Tensor prompt_mask, "

View File

@@ -29,6 +29,22 @@ struct _typeConvert {
static constexpr bool exists = false;
};
template <>
struct _typeConvert<float> {
static constexpr bool exists = true;
using hip_type = float;
using packed_hip_type = float2;
using packed_hip_type4 = float4; // For 128-bit vectorization
__device__ static __forceinline__ float convert(hip_type x) { return x; }
__device__ static __forceinline__ float2 convert(packed_hip_type x) {
return x;
}
__device__ static __forceinline__ float4 convert(packed_hip_type4 x) {
return x;
}
};
#if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
// CUDA < 12.0 runs into issues with packed type conversion
template <>
@@ -37,14 +53,16 @@ struct _typeConvert<c10::Half> {
using hip_type = __half;
using packed_hip_type = __half2;
__device__ static inline float convert(hip_type x) { return __half2float(x); }
__device__ static inline float2 convert(packed_hip_type x) {
__device__ static __forceinline__ float convert(hip_type x) {
return __half2float(x);
}
__device__ static __forceinline__ float2 convert(packed_hip_type x) {
return __half22float2(x);
}
__device__ static inline hip_type convert(float x) {
__device__ static __forceinline__ hip_type convert(float x) {
return __float2half_rn(x);
}
__device__ static inline packed_hip_type convert(float2 x) {
__device__ static __forceinline__ packed_hip_type convert(float2 x) {
return __float22half2_rn(x);
}
};
@@ -58,16 +76,16 @@ struct _typeConvert<c10::BFloat16> {
using hip_type = __nv_bfloat16;
using packed_hip_type = __nv_bfloat162;
__device__ static inline float convert(hip_type x) {
__device__ static __forceinline__ float convert(hip_type x) {
return __bfloat162float(x);
}
__device__ static inline float2 convert(packed_hip_type x) {
__device__ static __forceinline__ float2 convert(packed_hip_type x) {
return __bfloat1622float2(x);
}
__device__ static inline hip_type convert(float x) {
__device__ static __forceinline__ hip_type convert(float x) {
return __float2bfloat16(x);
}
__device__ static inline packed_hip_type convert(float2 x) {
__device__ static __forceinline__ packed_hip_type convert(float2 x) {
return __float22bfloat162_rn(x);
}
};
@@ -95,10 +113,15 @@ struct alignas(16) _f16Vec {
if constexpr (width % 2 == 0) {
#pragma unroll
for (int i = 0; i < width; i += 2) {
T2 temp{data[i], data[i + 1]};
temp += T2{other.data[i], other.data[i + 1]};
data[i] = temp.x;
data[i + 1] = temp.y;
if constexpr (std::is_same_v<T2, float2>) {
data[i] += other.data[i];
data[i + 1] += other.data[i + 1];
} else {
T2 temp{data[i], data[i + 1]};
temp += T2{other.data[i], other.data[i + 1]};
data[i] = temp.x;
data[i + 1] = temp.y;
}
}
} else {
#pragma unroll
@@ -111,10 +134,15 @@ struct alignas(16) _f16Vec {
if constexpr (width % 2 == 0) {
#pragma unroll
for (int i = 0; i < width; i += 2) {
T2 temp{data[i], data[i + 1]};
temp *= T2{other.data[i], other.data[i + 1]};
data[i] = temp.x;
data[i + 1] = temp.y;
if constexpr (std::is_same_v<T2, float2>) {
data[i] *= other.data[i];
data[i + 1] *= other.data[i + 1];
} else {
T2 temp{data[i], data[i + 1]};
temp *= T2{other.data[i], other.data[i + 1]};
data[i] = temp.x;
data[i + 1] = temp.y;
}
}
} else {
#pragma unroll

View File

@@ -0,0 +1,195 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from tests.compile.backend import TestBackend
from vllm.attention import Attention, AttentionType
from vllm.compilation.matcher_utils import FLASHINFER_ROTARY_OP, RMS_OP, ROTARY_OP
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.compilation.qk_norm_rope_fusion import (
FUSED_QK_ROPE_OP,
QKNormRoPEFusionPass,
)
from vllm.config import (
CompilationConfig,
CompilationMode,
ModelConfig,
PassConfig,
VllmConfig,
set_current_vllm_config,
)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.platforms import current_platform
RSQRT_OP = torch.ops.aten.rsqrt.default
INDEX_SELECT_OP = torch.ops.aten.index.Tensor
class QKNormRoPETestModel(torch.nn.Module):
def __init__(
self,
*,
num_heads: int,
num_kv_heads: int,
head_dim: int,
eps: float,
is_neox: bool,
vllm_config: VllmConfig,
dtype: torch.dtype,
prefix: str = "model.layers.0.self_attn.attn",
) -> None:
super().__init__()
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
self.head_dim = head_dim
self.q_size = num_heads * head_dim
self.kv_size = num_kv_heads * head_dim
self.rotary_dim = head_dim
self.eps = eps
self.dtype = dtype
# Register layer metadata for the fusion pass via Attention.
self.attn = Attention(
num_heads=self.num_heads,
head_size=self.head_dim,
scale=1.0 / self.head_dim**0.5,
num_kv_heads=self.num_kv_heads,
cache_config=vllm_config.cache_config,
prefix=prefix,
attn_type=AttentionType.DECODER,
)
self.q_norm = RMSNorm(self.head_dim, eps=self.eps)
self.k_norm = RMSNorm(self.head_dim, eps=self.eps)
self.rotary_emb = RotaryEmbedding(
self.head_dim,
rotary_dim=self.rotary_dim,
max_position_embeddings=4096,
base=10000,
is_neox_style=is_neox,
dtype=self.dtype,
)
self.enable_rms_norm_custom_op = self.q_norm.enabled()
self.enable_rope_custom_op = self.rotary_emb.enabled()
def forward(self, qkv: torch.Tensor, positions: torch.Tensor):
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim)
q_by_head = self.q_norm(q_by_head)
q = q_by_head.view(q.shape)
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim)
k_by_head = self.k_norm(k_by_head)
k = k_by_head.view(k.shape)
q, k = self.rotary_emb(positions, q, k)
return q, k, v
def ops_in_model_before(self) -> list[torch._ops.OpOverload]:
ops = []
if self.enable_rms_norm_custom_op:
ops.append(RMS_OP)
else:
ops.append(RSQRT_OP)
if self.enable_rope_custom_op:
if self.rotary_emb.use_flashinfer:
ops.append(FLASHINFER_ROTARY_OP)
else:
ops.append(ROTARY_OP)
else:
ops.append(INDEX_SELECT_OP)
return ops
def ops_in_model_after(self) -> list[torch._ops.OpOverload]:
return [FUSED_QK_ROPE_OP]
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.parametrize("is_neox", [True, False])
@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False])
@pytest.mark.parametrize("enable_rope_custom_op", [True])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.skipif(
not current_platform.is_cuda(),
reason="Only test on cuda platform",
)
def test_qk_norm_rope_fusion(
eps, is_neox, enable_rms_norm_custom_op, enable_rope_custom_op, dtype
):
if not hasattr(torch.ops._C, "fused_qk_norm_rope"):
pytest.skip("fused_qk_norm_rope custom op not available")
torch.set_default_device("cuda")
torch.set_default_dtype(dtype)
torch.manual_seed(0)
custom_ops: list[str] = []
if enable_rms_norm_custom_op:
custom_ops.append("+rms_norm")
if enable_rope_custom_op:
custom_ops.append("+rotary_embedding")
vllm_config = VllmConfig(
model_config=ModelConfig(dtype=dtype),
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
custom_ops=custom_ops,
pass_config=PassConfig(
enable_qk_norm_rope_fusion=True,
enable_noop=True,
),
),
)
num_heads, num_kv_heads, head_dim = 16, 4, 128
T = 5
with set_current_vllm_config(vllm_config):
model = QKNormRoPETestModel(
num_heads=num_heads,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
eps=eps,
is_neox=is_neox,
vllm_config=vllm_config,
dtype=dtype,
)
noop_pass = NoOpEliminationPass(vllm_config)
fusion_pass = QKNormRoPEFusionPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)
backend = TestBackend(noop_pass, fusion_pass, cleanup_pass)
backend_baseline = TestBackend(noop_pass, cleanup_pass)
qkv = torch.randn(T, model.q_size + 2 * model.kv_size)
pos = torch.arange(T, dtype=torch.long, device=qkv.device)
qkv_unfused = qkv.clone()
pos_unfused = pos.clone()
torch._dynamo.mark_dynamic(qkv, 0)
torch._dynamo.mark_dynamic(pos, 0)
model_fused = torch.compile(model, backend=backend)
q_fused, k_fused, v_fused = model_fused(qkv, pos)
torch._dynamo.mark_dynamic(qkv_unfused, 0)
torch._dynamo.mark_dynamic(pos_unfused, 0)
model_unfused = torch.compile(model, backend=backend_baseline)
q_unfused, k_unfused, v_unfused = model_unfused(qkv_unfused, pos_unfused)
if dtype == torch.float16:
ATOL, RTOL = (2e-3, 2e-3)
else:
ATOL, RTOL = (1e-2, 1e-2)
torch.testing.assert_close(q_unfused, q_fused, atol=ATOL, rtol=RTOL)
torch.testing.assert_close(k_unfused, k_fused, atol=ATOL, rtol=RTOL)
torch.testing.assert_close(v_unfused, v_fused, atol=ATOL, rtol=RTOL)
assert fusion_pass.matched_count == 1
backend.check_before_ops(model.ops_in_model_before())
backend.check_after_ops(model.ops_in_model_after())

View File

@@ -0,0 +1,141 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from tests.kernels.utils import opcheck
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.platforms import current_platform
DTYPES = [torch.bfloat16, torch.float16]
IS_NEOX = [True, False]
EPS_VALUES = [1e-5, 1e-6]
SEEDS = [13]
CUDA_DEVICES = ["cuda:0"]
def _apply_qk_norm_rope(
qkv: torch.Tensor,
positions: torch.Tensor,
q_norm: RMSNorm,
k_norm: RMSNorm,
rope: RotaryEmbedding,
num_heads_q: int,
num_heads_kv: int,
head_dim: int,
) -> torch.Tensor:
q_size = num_heads_q * head_dim
kv_size = num_heads_kv * head_dim
q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // head_dim, head_dim)
q_by_head = q_norm.forward_native(q_by_head)
q = q_by_head.view(q.shape)
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // head_dim, head_dim)
k_by_head = k_norm.forward_native(k_by_head)
k = k_by_head.view(k.shape)
q, k = rope.forward_native(positions, q, k)
return torch.cat([q, k, v], dim=-1)
@pytest.mark.skipif(
not current_platform.is_cuda(),
reason="fused_qk_norm_rope custom op requires cuda platform",
)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("is_neox", IS_NEOX)
@pytest.mark.parametrize("eps", EPS_VALUES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_fused_qk_norm_rope_matches_reference(
device: str,
dtype: torch.dtype,
is_neox: bool,
eps: float,
seed: int,
):
torch.set_default_device(device)
current_platform.seed_everything(seed)
num_heads, num_kv_heads, head_dim = 16, 4, 128
num_tokens = 4
total_dim = (num_heads + 2 * num_kv_heads) * head_dim
qkv_base = torch.randn(num_tokens, total_dim, dtype=dtype, device=device)
qkv_fused = qkv_base.clone()
positions = torch.arange(num_tokens, dtype=torch.long, device=device)
q_norm = RMSNorm(head_dim, eps=eps).to(device=device, dtype=dtype)
k_norm = RMSNorm(head_dim, eps=eps).to(device=device, dtype=dtype)
q_norm.weight.data.normal_(mean=1.0, std=0.1)
k_norm.weight.data.normal_(mean=1.0, std=0.1)
q_weight = q_norm.weight.data
k_weight = k_norm.weight.data
rope = RotaryEmbedding(
head_size=head_dim,
rotary_dim=head_dim,
max_position_embeddings=4096,
base=10000.0,
is_neox_style=is_neox,
dtype=dtype,
).to(device)
ref_result = _apply_qk_norm_rope(
qkv=qkv_base,
positions=positions,
q_norm=q_norm,
k_norm=k_norm,
rope=rope,
num_heads_q=num_heads,
num_heads_kv=num_kv_heads,
head_dim=head_dim,
)
opcheck(
torch.ops._C.fused_qk_norm_rope,
(
qkv_fused.clone(),
num_heads,
num_kv_heads,
num_kv_heads,
head_dim,
eps,
q_weight,
k_weight,
rope.cos_sin_cache,
is_neox,
positions.view(-1),
),
)
torch.ops._C.fused_qk_norm_rope(
qkv_fused,
num_heads,
num_kv_heads,
num_kv_heads,
head_dim,
eps,
q_weight,
k_weight,
rope.cos_sin_cache,
is_neox,
positions.view(-1),
)
if dtype == torch.float16:
ATOL, RTOL = (2e-3, 2e-3)
else:
ATOL, RTOL = (1e-2, 1e-2)
torch.testing.assert_close(
qkv_fused,
ref_result,
atol=ATOL,
rtol=RTOL,
)

View File

@@ -329,6 +329,7 @@ def rms_norm(
out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, epsilon: float
) -> None:
# TODO: Remove this contiguous call when the kernel is updated to support non-contiguous input
# If removed, also need to remove contiguous in MatcherRMSNorm
input_contiguous = input.contiguous()
torch.ops._C.rms_norm(out, input_contiguous, weight, epsilon)
@@ -339,6 +340,34 @@ def fused_add_rms_norm(
torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon)
def fused_qk_norm_rope(
qkv: torch.Tensor,
num_heads_q: int,
num_heads_k: int,
num_heads_v: int,
head_dim: int,
eps: float,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
cos_sin_cache: torch.Tensor,
is_neox: bool,
position_ids: torch.Tensor,
) -> None:
torch.ops._C.fused_qk_norm_rope(
qkv,
num_heads_q,
num_heads_k,
num_heads_v,
head_dim,
eps,
q_weight,
k_weight,
cos_sin_cache,
is_neox,
position_ids,
)
def apply_repetition_penalties_torch(
logits: torch.Tensor,
prompt_mask: torch.Tensor,

View File

@@ -132,6 +132,23 @@ class FixFunctionalizationPass(VllmInductorPass):
"input_global_scale",
),
)
# Defunctionalize fused_qk_norm_rope to remove its higher-order wrapper.
elif at_target == torch.ops._C.fused_qk_norm_rope.default:
mutated_args = {1: "qkv"}
args = (
"qkv",
"num_heads_q",
"num_heads_k",
"num_heads_v",
"head_dim",
"eps",
"q_weight",
"k_weight",
"cos_sin_cache",
"is_neox",
"position_ids",
)
self.defunctionalize(graph, node, mutated_args=mutated_args, args=args)
else:
continue # skip the count
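
Conceptually, this defunctionalization turns the functionalized graph node back into a direct in-place call. A rough before/after sketch of the FX graph is shown below; node and argument names are illustrative, not taken from a real trace.

# Before: the in-place custom op is wrapped by the auto_functionalized
# higher-order op, and users read its outputs through getitem:
#   af = auto_functionalized(torch.ops._C.fused_qk_norm_rope.default,
#                            qkv=qkv, num_heads_q=..., ..., position_ids=pos)
#   qkv_out = af[1]   # downstream nodes consume the "functional" copy of qkv
#
# After: the op is called directly and mutates qkv in place, and former
# getitem users are rewired to the original tensor:
#   torch.ops._C.fused_qk_norm_rope.default(qkv, num_heads_q=..., ...,
#                                           position_ids=pos)
#   # downstream nodes consume qkv itself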

View File

@@ -44,6 +44,10 @@ def empty_i32(*args, **kwargs):
return torch.empty(*args, **kwargs, dtype=torch.int32, device="cuda")
def empty_i64(*args, **kwargs):
return torch.empty(*args, **kwargs, dtype=torch.int64, device="cuda")
RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default

View File

@@ -18,10 +18,13 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticTensorSym,
kNvfp4Quant,
)
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.platforms import current_platform
RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
ROTARY_OP = torch.ops._C.rotary_embedding.default
FLASHINFER_ROTARY_OP = torch.ops.vllm.flashinfer_rotary_embedding.default
QUANT_OPS: dict[QuantKey, OpOverload] = {
kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501
@@ -58,6 +61,9 @@ class MatcherCustomOp(ABC):
def empty(self, *args, **kws):
return torch.empty(*args, dtype=self.model_dtype, device=self.device, **kws)
def empty_int64(self, *args, **kws):
return torch.empty(*args, dtype=torch.int64, device=self.device, **kws)
def empty_f32(self, *args, **kws):
return torch.empty(*args, dtype=torch.float32, device=self.device, **kws)
@@ -66,6 +72,77 @@ class MatcherCustomOp(ABC):
raise NotImplementedError
class MatcherRotaryEmbedding(MatcherCustomOp):
def __init__(
self,
is_neox: bool,
head_size: int,
num_heads: int,
num_kv_heads: int,
use_flashinfer: bool = False,
enabled: bool | None = None,
) -> None:
if enabled is None:
enabled = RotaryEmbedding.enabled()
super().__init__(enabled)
self.is_neox = is_neox
self.head_size = head_size
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
self.q_size = self.num_heads * self.head_size
self.kv_size = self.num_kv_heads * self.head_size
self.rotary_dim = head_size
if use_flashinfer:
self.rotary_op = FLASHINFER_ROTARY_OP
else:
self.rotary_op = ROTARY_OP
def inputs(self) -> list[torch.Tensor]:
positions = self.empty_int64(5)
query = self.empty(5, self.q_size)
key = self.empty(5, self.kv_size)
cos_sin_cache = self.empty(4096, self.rotary_dim)
return [positions, query, key, cos_sin_cache]
def forward_custom(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor | None,
cos_sin_cache: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor | None]:
result = auto_functionalized(
self.rotary_op,
positions=positions,
query=query,
key=key,
head_size=self.head_size,
cos_sin_cache=cos_sin_cache,
is_neox=self.is_neox,
)
query_out = result[1]
key_out = result[2] if len(result) > 2 else None
return query_out, key_out
def forward_native(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor | None,
cos_sin_cache: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor | None]:
return RotaryEmbedding.forward_static(
positions,
query,
key,
self.head_size,
self.rotary_dim,
cos_sin_cache,
self.is_neox,
)
class MatcherRMSNorm(MatcherCustomOp):
def __init__(self, epsilon: float, enabled: bool | None = None):
if enabled is None:
@@ -85,10 +162,12 @@ class MatcherRMSNorm(MatcherCustomOp):
weight: torch.Tensor,
) -> torch.Tensor:
result = torch.empty_like(input)
# TODO: support non-contiguous input for RMSNorm and remove this
input_contiguous = input.contiguous()
_, result = auto_functionalized(
RMS_OP,
result=result,
input=input,
input=input_contiguous,
weight=weight,
epsilon=self.epsilon,
)

View File

@@ -17,6 +17,7 @@ if current_platform.is_cuda_alike():
from .activation_quant_fusion import ActivationQuantFusionPass
from .fusion import RMSNormQuantFusionPass
from .fusion_attn import AttnFusionPass
from .qk_norm_rope_fusion import QKNormRoPEFusionPass
if current_platform.is_cuda():
from .collective_fusion import AllReduceFusionPass, AsyncTPPass
@@ -109,6 +110,9 @@ class PostGradPassManager(CustomGraphPass):
if self.pass_config.enable_attn_fusion:
self.passes += [AttnFusionPass(config)]
if self.pass_config.enable_qk_norm_rope_fusion:
self.passes += [QKNormRoPEFusionPass(config)]
# needs a functional graph
self.post_cleanup = PostCleanupPass(config)
self.fix_functionalization = FixFunctionalizationPass(config)

View File

@@ -0,0 +1,238 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch
import torch._inductor.pattern_matcher as pm
from torch import fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass
from vllm.attention import Attention
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from .fusion import empty_bf16, empty_fp32, empty_i64
from .inductor_pass import enable_fake_mode
from .matcher_utils import MatcherRMSNorm, MatcherRotaryEmbedding
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
logger = init_logger(__name__)
FUSED_QK_ROPE_OP = torch.ops._C.fused_qk_norm_rope.default
class QkNormRopePattern:
"""
Match the unfused sequence in attention blocks and replace it with the fused op.
Unfused (conceptually):
q, k, v = split(qkv, [qsz, kvsz, kvsz], -1)
qh = reshape(q, [-1, num_heads, head_dim])
kh = reshape(k, [-1, num_kv_heads, head_dim])
qn = rms_norm(qh, q_weight, eps)
kn = rms_norm(kh, k_weight, eps)
qf = reshape(qn, [-1, num_heads * head_dim])
kf = reshape(kn, [-1, num_kv_heads * head_dim])
qf, kf = rotary_embedding(positions, qf, kf, head_dim, cos_sin_cache, is_neox)
return qf, kf, v
Fused replacement:
fused_qk_norm_rope(qkv, num_heads, num_kv_heads, num_kv_heads, head_dim,
eps, q_weight, k_weight, cos_sin_cache, is_neox,
positions.view(-1))
return split(qkv, [qsz, kvsz, kvsz], -1)
"""
def __init__(
self,
head_dim: int,
num_heads: int,
num_kv_heads: int,
eps: float,
is_neox: bool,
rope_flashinfer: bool = False,
) -> None:
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
self.head_dim = head_dim
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.eps = eps
self.rmsnorm_matcher = MatcherRMSNorm(eps)
self.is_neox = is_neox
self.rope_flashinfer = rope_flashinfer
self.rope_matcher = MatcherRotaryEmbedding(
is_neox=is_neox,
head_size=self.head_dim,
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
use_flashinfer=self.rope_flashinfer,
)
def get_inputs(self):
# Sample inputs to help pattern tracing
T = 5
qkv = empty_bf16(T, self.q_size + 2 * self.kv_size)
positions = empty_i64(T)
q_weight = empty_bf16(1, self.head_dim)
k_weight = empty_bf16(1, self.head_dim)
if self.rope_flashinfer:
cos_sin_cache = empty_fp32(4096, self.head_dim)
else:
cos_sin_cache = empty_bf16(4096, self.head_dim)
return [
qkv,
positions,
q_weight,
k_weight,
cos_sin_cache,
]
@staticmethod
def wrap_trace_fn(trace_fn, *process_fx_fns: Callable[[fx.GraphModule], None]):
def wrapped(*args, **kwargs):
gm = trace_fn(*args, **kwargs)
for process_fx in process_fx_fns:
process_fx(gm)
return gm
return wrapped
@staticmethod
def fx_view_to_reshape(gm: torch.fx.GraphModule):
from torch._inductor.fx_passes.post_grad import view_to_reshape
view_to_reshape(gm)
def register(self, pm_pass: PatternMatcherPass):
def pattern(
qkv: torch.Tensor,
positions: torch.Tensor,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
cos_sin_cache: torch.Tensor,
):
# split qkv -> q,k,v
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
# Q path: view -> RMS -> view back to q.shape
q_by_head = q.view(
*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim
)
q_normed_by_head = self.rmsnorm_matcher(q_by_head, q_weight)
q_flat = q_normed_by_head.view(q.shape)
# K path: view -> RMS -> view back to k.shape
k_by_head = k.view(
*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim
)
k_normed_by_head = self.rmsnorm_matcher(k_by_head, k_weight)
k_flat = k_normed_by_head.view(k.shape)
# RoPE: apply to flattened q/k
q_rope, k_rope = self.rope_matcher(positions, q_flat, k_flat, cos_sin_cache)
return q_rope, k_rope, v
def replacement(
qkv: torch.Tensor,
positions: torch.Tensor,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
cos_sin_cache: torch.Tensor,
):
# Run fused qk_norm_rope op
result = auto_functionalized(
FUSED_QK_ROPE_OP,
qkv=qkv,
num_heads_q=self.num_heads,
num_heads_k=self.num_kv_heads,
num_heads_v=self.num_kv_heads,
head_dim=self.head_dim,
eps=self.eps,
q_weight=q_weight,
k_weight=k_weight,
cos_sin_cache=cos_sin_cache,
is_neox=self.is_neox,
position_ids=positions.view(-1),
)
result_qkv = result[1]
# Split back to q,k,v and return
return result_qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
# NOTE: use fx_view_to_reshape to unify view/reshape ops, which simplifies
# the pattern and increases matching opportunities
pm.register_replacement(
pattern,
replacement,
self.get_inputs(),
QkNormRopePattern.wrap_trace_fn(
pm.fwd_only,
QkNormRopePattern.fx_view_to_reshape,
),
pm_pass,
)
class QKNormRoPEFusionPass(VllmPatternMatcherPass):
"""Fuse Q/K RMSNorm + RoPE into fused_qk_norm_rope when the custom op exists."""
@enable_fake_mode
def __init__(self, config: VllmConfig):
super().__init__(config)
self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="qk_norm_rope_fusion_pass"
)
dtype = config.model_config.dtype
if dtype not in (torch.bfloat16, torch.float16):
logger.warning_once(
"QK Norm+RoPE fusion not enabled: unsupported dtype %s", dtype
)
return
# Use one attn layer to get metadata (such as head_dim) for QkNormRopePattern
attn_layers: dict[str, Attention] = get_layers_from_vllm_config(
config, Attention
)
if len(attn_layers) == 0:
logger.warning_once(
"QK Norm+RoPE fusion enabled, but no Attention layers were discovered."
)
return
layer = next(iter(attn_layers.values()))
for epsilon in [1e-5, 1e-6]:
for neox in [True, False]:
if RotaryEmbedding.enabled():
for rope_flashinfer in [False, True]:
QkNormRopePattern(
head_dim=layer.head_size,
num_heads=layer.num_heads,
num_kv_heads=layer.num_kv_heads,
eps=epsilon,
is_neox=neox,
rope_flashinfer=rope_flashinfer,
).register(self.patterns)
else:
QkNormRopePattern(
head_dim=layer.head_size,
num_heads=layer.num_heads,
num_kv_heads=layer.num_kv_heads,
eps=epsilon,
is_neox=neox,
).register(self.patterns)
self.dump_patterns(config, self.patterns)
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph) -> None:
self.matched_count = self.patterns.apply(graph)
logger.debug("Fused QK Norm+RoPE on %s sites", self.matched_count)
def uuid(self):
return VllmInductorPass.hash_source(self, QkNormRopePattern)
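
For completeness, a minimal sketch of how this pass might be enabled through the config objects exercised by the compile test above; only construction is shown, and wiring it into a full engine run is left out. When the flag is set, CompilationConfig also appends "+rotary_embedding" to custom_ops (see the config diff below), and PassConfig resets the flag with a warning on non-CUDA platforms.

from vllm.config import CompilationConfig, CompilationMode, PassConfig

compilation_config = CompilationConfig(
    mode=CompilationMode.VLLM_COMPILE,
    pass_config=PassConfig(
        enable_qk_norm_rope_fusion=True,  # registers QKNormRoPEFusionPass
        enable_noop=True,  # noop elimination, as used in the compile test
    ),
)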

View File

@@ -129,6 +129,8 @@ class PassConfig:
8: 1, # 1MB
},
}, where key is the device capability"""
enable_qk_norm_rope_fusion: bool = False
"""Whether to enable the fused Q/K RMSNorm + RoPE pass."""
# TODO(luka) better pass enabling system.
@@ -182,6 +184,12 @@ class PassConfig:
"Fusion enabled but reshape elimination disabled. "
"Allreduce + rms norm + quant (fp8) fusion might not work"
)
if self.enable_qk_norm_rope_fusion and not current_platform.is_cuda():
logger.warning_once(
"QK Norm + RoPE fusion enabled but the current platform is not "
"CUDA. The fusion will be disabled."
)
self.enable_qk_norm_rope_fusion = False
@config
@@ -640,6 +648,11 @@ class CompilationConfig:
if isinstance(self.pass_config, dict):
self.pass_config = PassConfig(**self.pass_config)
if self.pass_config.enable_qk_norm_rope_fusion:
# TODO(zhuhaoran): support matching the native RoPE forward and remove this.
# Linked issue: https://github.com/vllm-project/vllm/issues/28042
self.custom_ops.append("+rotary_embedding")
if (
is_torch_equal_or_newer("2.9.0.dev")
and "combo_kernels" not in self.inductor_compile_config

View File

@@ -98,6 +98,39 @@ class RotaryEmbedding(RotaryEmbeddingBase):
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
)
@staticmethod
def forward_static(
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor | None,
head_size: int,
rotary_dim: int,
cos_sin_cache: torch.Tensor,
is_neox_style: bool,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""A PyTorch-native implementation of forward()."""
positions = positions.flatten()
num_tokens = positions.shape[0]
cos_sin = cos_sin_cache.index_select(0, positions)
cos, sin = cos_sin.chunk(2, dim=-1)
query_shape = query.shape
query = query.view(num_tokens, -1, head_size)
query_rot = query[..., :rotary_dim]
query_pass = query[..., rotary_dim:]
query_rot = apply_rotary_emb_torch(query_rot, cos, sin, is_neox_style)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
# key may be None in some cases, e.g. cross-layer KV sharing
if key is not None:
key_shape = key.shape
key = key.view(num_tokens, -1, head_size)
key_rot = key[..., :rotary_dim]
key_pass = key[..., rotary_dim:]
key_rot = apply_rotary_emb_torch(key_rot, cos, sin, is_neox_style)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key
def forward_native(
self,
positions: torch.Tensor,
@@ -105,27 +138,15 @@ class RotaryEmbedding(RotaryEmbeddingBase):
key: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""A PyTorch-native implementation of forward()."""
positions = positions.flatten()
num_tokens = positions.shape[0]
cos_sin = self.cos_sin_cache.index_select(0, positions)
cos, sin = cos_sin.chunk(2, dim=-1)
query_shape = query.shape
query = query.view(num_tokens, -1, self.head_size)
query_rot = query[..., : self.rotary_dim]
query_pass = query[..., self.rotary_dim :]
query_rot = apply_rotary_emb_torch(query_rot, cos, sin, self.is_neox_style)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
# key may be None in some cases, e.g. cross-layer KV sharing
if key is not None:
key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size)
key_rot = key[..., : self.rotary_dim]
key_pass = key[..., self.rotary_dim :]
key_rot = apply_rotary_emb_torch(key_rot, cos, sin, self.is_neox_style)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key
return self.forward_static(
positions,
query,
key,
self.head_size,
self.rotary_dim,
self.cos_sin_cache,
self.is_neox_style,
)
def forward_cuda(
self,