diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 25f711dd60b37..8d2a7bc5a8029 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -451,6 +451,7 @@ steps: - pytest -v -s compile/test_decorator.py - pytest -v -s compile/test_noop_elimination.py - pytest -v -s compile/test_aot_compile.py + - pytest -v -s compile/test_qk_norm_rope_fusion.py - label: PyTorch Fullgraph Smoke Test # 15min timeout_in_minutes: 30 diff --git a/CMakeLists.txt b/CMakeLists.txt index 0e9fa63b178ea..5cddf81a4b4aa 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -265,6 +265,7 @@ set(VLLM_EXT_SRC "csrc/pos_encoding_kernels.cu" "csrc/activation_kernels.cu" "csrc/layernorm_kernels.cu" + "csrc/fused_qknorm_rope_kernel.cu" "csrc/layernorm_quant_kernels.cu" "csrc/sampler.cu" "csrc/cuda_view.cu" diff --git a/csrc/fused_qknorm_rope_kernel.cu b/csrc/fused_qknorm_rope_kernel.cu new file mode 100644 index 0000000000000..cbd23975a7739 --- /dev/null +++ b/csrc/fused_qknorm_rope_kernel.cu @@ -0,0 +1,418 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#include +#include + +#include "cuda_compat.h" +#include "dispatch_utils.h" +#include "type_convert.cuh" + +#define CHECK_TYPE(x, st) \ + TORCH_CHECK(x.scalar_type() == st, #x " dtype is ", x.scalar_type(), \ + ", while ", st, " is expected") +#define CHECK_TH_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) \ + TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) \ + CHECK_TH_CUDA(x); \ + CHECK_CONTIGUOUS(x) + +#define FINAL_MASK 0xffffffff + +// TODO: suport for AMD ROCM platform +#ifndef USE_ROCM +namespace tensorrt_llm::common { +template +struct packed_as; +// Specialization for packed_as used in this kernel. +template <> +struct packed_as { + using type = uint; +}; + +template <> +struct packed_as { + using type = uint2; +}; + +template <> +struct packed_as { + using type = uint4; +}; + +template +__inline__ __device__ T warpReduceSum(T val) { + #pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) + val += __shfl_xor_sync(FINAL_MASK, val, mask, 32); + return val; +} + +template +inline __device__ __host__ T divUp(T m, T n) { + return (m + n - 1) / n; +} + +} // namespace tensorrt_llm::common + +namespace tensorrt_llm::kernels { +// NOTE(zhuhaoran): This kernel is adapted from TensorRT-LLM implementation, +// with added support for passing the cos_sin_cache as an input. +// https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/fusedQKNormRopeKernel.cu + +// Perform per-head QK Norm and RoPE in a single kernel. +// scalar_t_in: data type of QKV and RMSNorm weights +// scalar_t_cache: data type of cos/sin cache +// head_dim: the dimension of each head +// interleave: interleave=!is_neox. 
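+// Parallelization (as implemented below): each 32-thread warp processes one
+// (token, Q-or-K head) pair; V heads are left untouched. Each thread owns
+// head_dim / 32 contiguous elements of that head, which is why head_dim must
+// be a multiple of 64.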
+template +__global__ void fusedQKNormRopeKernel( + void* qkv_void, // Combined QKV tensor + int const num_heads_q, // Number of query heads + int const num_heads_k, // Number of key heads + int const num_heads_v, // Number of value heads + float const eps, // Epsilon for RMS normalization + void const* q_weight_void, // RMSNorm weights for query + void const* k_weight_void, // RMSNorm weights for key + void const* cos_sin_cache_void, // Pre-computed cos/sin cache + int64_t const* position_ids, // Position IDs for RoPE + int const num_tokens // Number of tokens +) { + #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800 + if constexpr ((std::is_same_v) || + std::is_same_v) { + return; + } else { + #endif + + using Converter = vllm::_typeConvert; + static_assert(Converter::exists, + "Input QKV data type is not supported for this CUDA " + "architecture or toolkit version."); + using T_in = typename Converter::hip_type; + using T2_in = typename Converter::packed_hip_type; + + using CacheConverter = vllm::_typeConvert; + static_assert(CacheConverter::exists, + "Cache data type is not supported for this CUDA architecture " + "or toolkit version."); + using T_cache = typename CacheConverter::hip_type; + + T_in* qkv = reinterpret_cast(qkv_void); + T_in const* q_weight = reinterpret_cast(q_weight_void); + T_in const* k_weight = reinterpret_cast(k_weight_void); + T_cache const* cos_sin_cache = + reinterpret_cast(cos_sin_cache_void); + + int const warpsPerBlock = blockDim.x / 32; + int const warpId = threadIdx.x / 32; + int const laneId = threadIdx.x % 32; + + // Calculate global warp index to determine which head/token this warp + // processes + int const globalWarpIdx = blockIdx.x * warpsPerBlock + warpId; + + // Total number of attention heads (Q and K) + int const total_qk_heads = num_heads_q + num_heads_k; + + // Determine which token and head type (Q or K) this warp processes + int const tokenIdx = globalWarpIdx / total_qk_heads; + int const localHeadIdx = globalWarpIdx % total_qk_heads; + + // Skip if this warp is assigned beyond the number of tokens + if (tokenIdx >= num_tokens) return; + + bool const isQ = localHeadIdx < num_heads_q; + int const headIdx = isQ ? localHeadIdx : localHeadIdx - num_heads_q; + + int const num_heads = num_heads_q + num_heads_k + num_heads_v; + + static_assert(head_dim % (32 * 2) == 0, + "head_dim must be divisible by 64 (each warp processes one " + "head, and each thread gets even number of " + "elements)"); + constexpr int numElemsPerThread = head_dim / 32; + float elements[numElemsPerThread]; + constexpr int elemSizeBytes = numElemsPerThread * sizeof(__nv_bfloat16); + static_assert(elemSizeBytes % 4 == 0, + "numSizeBytes must be a multiple of 4"); + constexpr int vecSize = + elemSizeBytes / + 4; // Use packed_as to perform loading/saving. + using vec_T = typename tensorrt_llm::common::packed_as::type; + + int offsetWarp; // Offset for the warp + if (isQ) { + // Q segment: token offset + head offset within Q segment + offsetWarp = tokenIdx * num_heads * head_dim + headIdx * head_dim; + } else { + // K segment: token offset + entire Q segment + head offset within K + // segment + offsetWarp = tokenIdx * num_heads * head_dim + num_heads_q * head_dim + + headIdx * head_dim; + } + int offsetThread = offsetWarp + laneId * numElemsPerThread; + + // Sum of squares for RMSNorm + float sumOfSquares = 0.0f; + + // Load. 
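+  // One vectorized load per thread: reinterpret this thread's
+  // numElemsPerThread values as a single vec_T, unpack them to float, and
+  // accumulate the RMSNorm sum of squares on the fly.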
+ { + vec_T vec = *reinterpret_cast(&qkv[offsetThread]); + constexpr int num_packed_elems = elemSizeBytes / sizeof(T2_in); + #pragma unroll + for (int i = 0; i < num_packed_elems; i++) { + // Interpret the generic vector chunk as the specific packed type + T2_in packed_val = *(reinterpret_cast(&vec) + i); + // Convert to float2 for computation + float2 vals = Converter::convert(packed_val); + sumOfSquares += vals.x * vals.x; + sumOfSquares += vals.y * vals.y; + + elements[2 * i] = vals.x; + elements[2 * i + 1] = vals.y; + } + } + + // Reduce sum across warp using the utility function + sumOfSquares = tensorrt_llm::common::warpReduceSum(sumOfSquares); + + // Compute RMS normalization factor + float rms_rcp = rsqrtf(sumOfSquares / static_cast(head_dim) + eps); + + // Normalize elements + #pragma unroll + for (int i = 0; i < numElemsPerThread; i++) { + int dim = laneId * numElemsPerThread + i; + float weight = isQ ? Converter::convert(q_weight[dim]) + : Converter::convert(k_weight[dim]); + elements[i] *= rms_rcp * weight; + } + + // Apply RoPE to normalized elements + float elements2[numElemsPerThread]; // Additional buffer required for RoPE. + + int64_t pos_id = position_ids[tokenIdx]; + + // Calculate cache pointer for this position - similar to + // pos_encoding_kernels.cu + T_cache const* cache_ptr = cos_sin_cache + pos_id * head_dim; + int const embed_dim = head_dim / 2; + T_cache const* cos_ptr = cache_ptr; + T_cache const* sin_ptr = cache_ptr + embed_dim; + + if constexpr (interleave) { + // Perform interleaving. Use pre-computed cos/sin values. + #pragma unroll + for (int i = 0; i < numElemsPerThread / 2; ++i) { + int const idx0 = 2 * i; + int const idx1 = 2 * i + 1; + + float const val0 = elements[idx0]; + float const val1 = elements[idx1]; + + int const dim_idx = laneId * numElemsPerThread + idx0; + int const half_dim = dim_idx / 2; + float const cos_val = + CacheConverter::convert(VLLM_LDG(cos_ptr + half_dim)); + float const sin_val = + CacheConverter::convert(VLLM_LDG(sin_ptr + half_dim)); + + elements[idx0] = val0 * cos_val - val1 * sin_val; + elements[idx1] = val0 * sin_val + val1 * cos_val; + } + } else { + // Before data exchange with in warp, we need to sync. + __syncwarp(); + // Get the data from the other half of the warp. Use pre-computed cos/sin + // values. + #pragma unroll + for (int i = 0; i < numElemsPerThread; i++) { + elements2[i] = __shfl_xor_sync(0xffffffff, elements[i], 16); + if (laneId < 16) { + elements2[i] = -elements2[i]; + } + + int dim_idx = laneId * numElemsPerThread + i; + dim_idx = (dim_idx * 2) % head_dim; + int half_dim = dim_idx / 2; + // Use pre-computed cos/sin from cache + float cos_val = CacheConverter::convert(VLLM_LDG(cos_ptr + half_dim)); + float sin_val = CacheConverter::convert(VLLM_LDG(sin_ptr + half_dim)); + + elements[i] = elements[i] * cos_val + elements2[i] * sin_val; + } + // __shfl_xor_sync does not provide memfence. Need to sync again. + __syncwarp(); + } + + // Store. 
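+  // Pack the normalized and rotated floats back into the input dtype and
+  // write them back to qkv in place with a single vectorized store.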
+ { + vec_T vec; + constexpr int num_packed_elems = elemSizeBytes / sizeof(T2_in); + #pragma unroll + for (int i = 0; i < num_packed_elems; i++) { + // Convert from float2 back to the specific packed type + T2_in packed_val = Converter::convert( + make_float2(elements[2 * i], elements[2 * i + 1])); + // Place it into the generic vector + *(reinterpret_cast(&vec) + i) = packed_val; + } + *reinterpret_cast(&qkv[offsetThread]) = vec; + } + + #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800 + } + #endif +} + + // Borrowed from + // https://github.com/flashinfer-ai/flashinfer/blob/8125d079a43e9a0ba463a4ed1b639cefd084cec9/include/flashinfer/pos_enc.cuh#L568 + #define DISPATCH_INTERLEAVE(interleave, INTERLEAVE, ...) \ + if (interleave) { \ + const bool INTERLEAVE = true; \ + __VA_ARGS__ \ + } else { \ + const bool INTERLEAVE = false; \ + __VA_ARGS__ \ + } + +template +void launchFusedQKNormRope(void* qkv, int const num_tokens, + int const num_heads_q, int const num_heads_k, + int const num_heads_v, int const head_dim, + float const eps, void const* q_weight, + void const* k_weight, void const* cos_sin_cache, + bool const interleave, int64_t const* position_ids, + cudaStream_t stream) { + constexpr int blockSize = 256; + + int const warpsPerBlock = blockSize / 32; + int const totalQKHeads = num_heads_q + num_heads_k; + int const totalWarps = num_tokens * totalQKHeads; + + int const gridSize = common::divUp(totalWarps, warpsPerBlock); + dim3 gridDim(gridSize); + dim3 blockDim(blockSize); + + switch (head_dim) { + case 64: + DISPATCH_INTERLEAVE(interleave, INTERLEAVE, { + fusedQKNormRopeKernel + <<>>( + qkv, num_heads_q, num_heads_k, num_heads_v, eps, q_weight, + k_weight, cos_sin_cache, position_ids, num_tokens); + }); + break; + case 128: + DISPATCH_INTERLEAVE(interleave, INTERLEAVE, { + fusedQKNormRopeKernel + <<>>( + qkv, num_heads_q, num_heads_k, num_heads_v, eps, q_weight, + k_weight, cos_sin_cache, position_ids, num_tokens); + }); + break; + case 256: + DISPATCH_INTERLEAVE(interleave, INTERLEAVE, { + fusedQKNormRopeKernel + <<>>( + qkv, num_heads_q, num_heads_k, num_heads_v, eps, q_weight, + k_weight, cos_sin_cache, position_ids, num_tokens); + }); + break; + default: + TORCH_CHECK(false, + "Unsupported head dimension for fusedQKNormRope: ", head_dim); + } +} +} // namespace tensorrt_llm::kernels + +void fused_qk_norm_rope( + torch::Tensor& qkv, // Combined QKV tensor [num_tokens, + // (num_heads_q+num_heads_k+num_heads_v)*head_dim] + int64_t num_heads_q, // Number of query heads + int64_t num_heads_k, // Number of key heads + int64_t num_heads_v, // Number of value heads + int64_t head_dim, // Dimension per head + double eps, // Epsilon for RMS normalization + torch::Tensor& q_weight, // RMSNorm weights for query [head_dim] + torch::Tensor& k_weight, // RMSNorm weights for key [head_dim] + torch::Tensor& cos_sin_cache, // Cos/sin cache [max_position, head_dim] + bool is_neox, // Whether RoPE is applied in Neox style + torch::Tensor& position_ids // Position IDs for RoPE [num_tokens] +) { + // Input validation + CHECK_INPUT(qkv); + CHECK_INPUT(position_ids); + CHECK_INPUT(q_weight); + CHECK_INPUT(k_weight); + CHECK_INPUT(cos_sin_cache); + CHECK_TYPE(position_ids, torch::kInt64); + + TORCH_CHECK(qkv.dim() == 2, + "QKV tensor must be 2D: [num_tokens, " + "(num_heads_q+num_heads_k+num_heads_v)*head_dim]"); + TORCH_CHECK(position_ids.dim() == 1, "Position IDs must be 1D: [num_tokens]"); + TORCH_CHECK(q_weight.dim() == 1, "Query weights must be 1D: [head_dim]"); + TORCH_CHECK(k_weight.dim() == 
1, "Key weights must be 1D: [head_dim]"); + TORCH_CHECK(cos_sin_cache.dim() == 2, + "Cos/sin cache must be 2D: [max_position, head_dim]"); + TORCH_CHECK(q_weight.size(0) == head_dim, + "Query weights size must match head dimension"); + TORCH_CHECK(k_weight.size(0) == head_dim, + "Key weights size must match head dimension"); + TORCH_CHECK(cos_sin_cache.size(1) == head_dim, + "Cos/sin cache dimension must match head_dim"); + TORCH_CHECK(qkv.scalar_type() == q_weight.scalar_type() && + qkv.scalar_type() == k_weight.scalar_type(), + "qkv, q_weight and k_weight must have the same dtype"); + + int64_t num_tokens = qkv.size(0); + TORCH_CHECK(position_ids.size(0) == num_tokens, + "Number of tokens in position_ids must match QKV"); + + int64_t total_heads = num_heads_q + num_heads_k + num_heads_v; + TORCH_CHECK( + qkv.size(1) == total_heads * head_dim, + "QKV tensor size must match total number of heads and head dimension"); + + auto stream = at::cuda::getCurrentCUDAStream(qkv.get_device()); + + VLLM_DISPATCH_HALF_TYPES(qkv.scalar_type(), "fused_qk_norm_rope_kernel", [&] { + using qkv_scalar_t = scalar_t; + VLLM_DISPATCH_FLOATING_TYPES( + cos_sin_cache.scalar_type(), "fused_qk_norm_rope_kernel", [&] { + using cache_scalar_t = scalar_t; + tensorrt_llm::kernels::launchFusedQKNormRope( + qkv.data_ptr(), static_cast(num_tokens), + static_cast(num_heads_q), static_cast(num_heads_k), + static_cast(num_heads_v), static_cast(head_dim), + static_cast(eps), q_weight.data_ptr(), k_weight.data_ptr(), + cos_sin_cache.data_ptr(), !is_neox, + reinterpret_cast(position_ids.data_ptr()), + stream); + }); + }); +} + +#endif // not USE_ROCM \ No newline at end of file diff --git a/csrc/ops.h b/csrc/ops.h index 3f5cb799b774c..f8bdc61aaa8ec 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -92,6 +92,12 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual, torch::Tensor& weight, double epsilon); +void fused_qk_norm_rope(torch::Tensor& qkv, int64_t num_heads_q, + int64_t num_heads_k, int64_t num_heads_v, + int64_t head_dim, double eps, torch::Tensor& q_weight, + torch::Tensor& k_weight, torch::Tensor& cos_sin_cache, + bool is_neox, torch::Tensor& position_ids); + void apply_repetition_penalties_(torch::Tensor& logits, const torch::Tensor& prompt_mask, const torch::Tensor& output_mask, diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 9c0f524dcab11..d4a69cbe7971d 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -175,6 +175,16 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "float epsilon) -> ()"); ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm); +#ifndef USE_ROCM + // Function for fused QK Norm and RoPE + ops.def( + "fused_qk_norm_rope(Tensor! qkv, int num_heads_q, " + "int num_heads_k, int num_heads_v, int head_dim, float eps, " + "Tensor q_weight, Tensor k_weight, Tensor cos_sin_cache, " + "bool is_neox, Tensor position_ids) -> ()"); + ops.impl("fused_qk_norm_rope", torch::kCUDA, &fused_qk_norm_rope); +#endif + // Apply repetition penalties to logits in-place ops.def( "apply_repetition_penalties_(Tensor! 
logits, Tensor prompt_mask, " diff --git a/csrc/type_convert.cuh b/csrc/type_convert.cuh index 21b9d0ae515df..6da06f1e66cf5 100644 --- a/csrc/type_convert.cuh +++ b/csrc/type_convert.cuh @@ -29,6 +29,22 @@ struct _typeConvert { static constexpr bool exists = false; }; +template <> +struct _typeConvert { + static constexpr bool exists = true; + using hip_type = float; + using packed_hip_type = float2; + using packed_hip_type4 = float4; // For 128-bit vectorization + + __device__ static __forceinline__ float convert(hip_type x) { return x; } + __device__ static __forceinline__ float2 convert(packed_hip_type x) { + return x; + } + __device__ static __forceinline__ float4 convert(packed_hip_type4 x) { + return x; + } +}; + #if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)) // CUDA < 12.0 runs into issues with packed type conversion template <> @@ -37,14 +53,16 @@ struct _typeConvert { using hip_type = __half; using packed_hip_type = __half2; - __device__ static inline float convert(hip_type x) { return __half2float(x); } - __device__ static inline float2 convert(packed_hip_type x) { + __device__ static __forceinline__ float convert(hip_type x) { + return __half2float(x); + } + __device__ static __forceinline__ float2 convert(packed_hip_type x) { return __half22float2(x); } - __device__ static inline hip_type convert(float x) { + __device__ static __forceinline__ hip_type convert(float x) { return __float2half_rn(x); } - __device__ static inline packed_hip_type convert(float2 x) { + __device__ static __forceinline__ packed_hip_type convert(float2 x) { return __float22half2_rn(x); } }; @@ -58,16 +76,16 @@ struct _typeConvert { using hip_type = __nv_bfloat16; using packed_hip_type = __nv_bfloat162; - __device__ static inline float convert(hip_type x) { + __device__ static __forceinline__ float convert(hip_type x) { return __bfloat162float(x); } - __device__ static inline float2 convert(packed_hip_type x) { + __device__ static __forceinline__ float2 convert(packed_hip_type x) { return __bfloat1622float2(x); } - __device__ static inline hip_type convert(float x) { + __device__ static __forceinline__ hip_type convert(float x) { return __float2bfloat16(x); } - __device__ static inline packed_hip_type convert(float2 x) { + __device__ static __forceinline__ packed_hip_type convert(float2 x) { return __float22bfloat162_rn(x); } }; @@ -95,10 +113,15 @@ struct alignas(16) _f16Vec { if constexpr (width % 2 == 0) { #pragma unroll for (int i = 0; i < width; i += 2) { - T2 temp{data[i], data[i + 1]}; - temp += T2{other.data[i], other.data[i + 1]}; - data[i] = temp.x; - data[i + 1] = temp.y; + if constexpr (std::is_same_v) { + data[i] += other.data[i]; + data[i + 1] += other.data[i + 1]; + } else { + T2 temp{data[i], data[i + 1]}; + temp += T2{other.data[i], other.data[i + 1]}; + data[i] = temp.x; + data[i + 1] = temp.y; + } } } else { #pragma unroll @@ -111,10 +134,15 @@ struct alignas(16) _f16Vec { if constexpr (width % 2 == 0) { #pragma unroll for (int i = 0; i < width; i += 2) { - T2 temp{data[i], data[i + 1]}; - temp *= T2{other.data[i], other.data[i + 1]}; - data[i] = temp.x; - data[i + 1] = temp.y; + if constexpr (std::is_same_v) { + data[i] *= other.data[i]; + data[i + 1] *= other.data[i + 1]; + } else { + T2 temp{data[i], data[i + 1]}; + temp *= T2{other.data[i], other.data[i + 1]}; + data[i] = temp.x; + data[i + 1] = temp.y; + } } } else { #pragma unroll diff --git a/tests/compile/test_qk_norm_rope_fusion.py b/tests/compile/test_qk_norm_rope_fusion.py new file mode 100644 
index 0000000000000..973123a3af920 --- /dev/null +++ b/tests/compile/test_qk_norm_rope_fusion.py @@ -0,0 +1,195 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch + +from tests.compile.backend import TestBackend +from vllm.attention import Attention, AttentionType +from vllm.compilation.matcher_utils import FLASHINFER_ROTARY_OP, RMS_OP, ROTARY_OP +from vllm.compilation.noop_elimination import NoOpEliminationPass +from vllm.compilation.post_cleanup import PostCleanupPass +from vllm.compilation.qk_norm_rope_fusion import ( + FUSED_QK_ROPE_OP, + QKNormRoPEFusionPass, +) +from vllm.config import ( + CompilationConfig, + CompilationMode, + ModelConfig, + PassConfig, + VllmConfig, + set_current_vllm_config, +) +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding +from vllm.platforms import current_platform + +RSQRT_OP = torch.ops.aten.rsqrt.default +INDEX_SELECT_OP = torch.ops.aten.index.Tensor + + +class QKNormRoPETestModel(torch.nn.Module): + def __init__( + self, + *, + num_heads: int, + num_kv_heads: int, + head_dim: int, + eps: float, + is_neox: bool, + vllm_config: VllmConfig, + dtype: torch.dtype, + prefix: str = "model.layers.0.self_attn.attn", + ) -> None: + super().__init__() + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = head_dim + self.q_size = num_heads * head_dim + self.kv_size = num_kv_heads * head_dim + self.rotary_dim = head_dim + self.eps = eps + self.dtype = dtype + + # Register layer metadata for the fusion pass via Attention. + self.attn = Attention( + num_heads=self.num_heads, + head_size=self.head_dim, + scale=1.0 / self.head_dim**0.5, + num_kv_heads=self.num_kv_heads, + cache_config=vllm_config.cache_config, + prefix=prefix, + attn_type=AttentionType.DECODER, + ) + + self.q_norm = RMSNorm(self.head_dim, eps=self.eps) + self.k_norm = RMSNorm(self.head_dim, eps=self.eps) + self.rotary_emb = RotaryEmbedding( + self.head_dim, + rotary_dim=self.rotary_dim, + max_position_embeddings=4096, + base=10000, + is_neox_style=is_neox, + dtype=self.dtype, + ) + self.enable_rms_norm_custom_op = self.q_norm.enabled() + self.enable_rope_custom_op = self.rotary_emb.enabled() + + def forward(self, qkv: torch.Tensor, positions: torch.Tensor): + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim) + q_by_head = self.q_norm(q_by_head) + q = q_by_head.view(q.shape) + k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim) + k_by_head = self.k_norm(k_by_head) + k = k_by_head.view(k.shape) + q, k = self.rotary_emb(positions, q, k) + return q, k, v + + def ops_in_model_before(self) -> list[torch._ops.OpOverload]: + ops = [] + if self.enable_rms_norm_custom_op: + ops.append(RMS_OP) + else: + ops.append(RSQRT_OP) + + if self.enable_rope_custom_op: + if self.rotary_emb.use_flashinfer: + ops.append(FLASHINFER_ROTARY_OP) + else: + ops.append(ROTARY_OP) + else: + ops.append(INDEX_SELECT_OP) + return ops + + def ops_in_model_after(self) -> list[torch._ops.OpOverload]: + return [FUSED_QK_ROPE_OP] + + +@pytest.mark.parametrize("eps", [1e-5, 1e-6]) +@pytest.mark.parametrize("is_neox", [True, False]) +@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False]) +@pytest.mark.parametrize("enable_rope_custom_op", [True]) +@pytest.mark.parametrize("dtype", [torch.bfloat16, 
torch.float16]) +@pytest.mark.skipif( + not current_platform.is_cuda(), + reason="Only test on cuda platform", +) +def test_qk_norm_rope_fusion( + eps, is_neox, enable_rms_norm_custom_op, enable_rope_custom_op, dtype +): + if not hasattr(torch.ops._C, "fused_qk_norm_rope"): + pytest.skip("fused_qk_norm_rope custom op not available") + + torch.set_default_device("cuda") + torch.set_default_dtype(dtype) + torch.manual_seed(0) + + custom_ops: list[str] = [] + if enable_rms_norm_custom_op: + custom_ops.append("+rms_norm") + if enable_rope_custom_op: + custom_ops.append("+rotary_embedding") + + vllm_config = VllmConfig( + model_config=ModelConfig(dtype=dtype), + compilation_config=CompilationConfig( + mode=CompilationMode.VLLM_COMPILE, + custom_ops=custom_ops, + pass_config=PassConfig( + enable_qk_norm_rope_fusion=True, + enable_noop=True, + ), + ), + ) + + num_heads, num_kv_heads, head_dim = 16, 4, 128 + T = 5 + + with set_current_vllm_config(vllm_config): + model = QKNormRoPETestModel( + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + eps=eps, + is_neox=is_neox, + vllm_config=vllm_config, + dtype=dtype, + ) + + noop_pass = NoOpEliminationPass(vllm_config) + fusion_pass = QKNormRoPEFusionPass(vllm_config) + cleanup_pass = PostCleanupPass(vllm_config) + + backend = TestBackend(noop_pass, fusion_pass, cleanup_pass) + backend_baseline = TestBackend(noop_pass, cleanup_pass) + + qkv = torch.randn(T, model.q_size + 2 * model.kv_size) + pos = torch.arange(T, dtype=torch.long, device=qkv.device) + qkv_unfused = qkv.clone() + pos_unfused = pos.clone() + + torch._dynamo.mark_dynamic(qkv, 0) + torch._dynamo.mark_dynamic(pos, 0) + model_fused = torch.compile(model, backend=backend) + q_fused, k_fused, v_fused = model_fused(qkv, pos) + + torch._dynamo.mark_dynamic(qkv_unfused, 0) + torch._dynamo.mark_dynamic(pos_unfused, 0) + model_unfused = torch.compile(model, backend=backend_baseline) + q_unfused, k_unfused, v_unfused = model_unfused(qkv_unfused, pos_unfused) + + if dtype == torch.float16: + ATOL, RTOL = (2e-3, 2e-3) + else: + ATOL, RTOL = (1e-2, 1e-2) + + torch.testing.assert_close(q_unfused, q_fused, atol=ATOL, rtol=RTOL) + torch.testing.assert_close(k_unfused, k_fused, atol=ATOL, rtol=RTOL) + torch.testing.assert_close(v_unfused, v_fused, atol=ATOL, rtol=RTOL) + + assert fusion_pass.matched_count == 1 + + backend.check_before_ops(model.ops_in_model_before()) + backend.check_after_ops(model.ops_in_model_after()) diff --git a/tests/kernels/core/test_fused_qk_norm_rope.py b/tests/kernels/core/test_fused_qk_norm_rope.py new file mode 100644 index 0000000000000..88bb7691ec3bc --- /dev/null +++ b/tests/kernels/core/test_fused_qk_norm_rope.py @@ -0,0 +1,141 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch + +from tests.kernels.utils import opcheck +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding +from vllm.platforms import current_platform + +DTYPES = [torch.bfloat16, torch.float16] +IS_NEOX = [True, False] +EPS_VALUES = [1e-5, 1e-6] +SEEDS = [13] +CUDA_DEVICES = ["cuda:0"] + + +def _apply_qk_norm_rope( + qkv: torch.Tensor, + positions: torch.Tensor, + q_norm: RMSNorm, + k_norm: RMSNorm, + rope: RotaryEmbedding, + num_heads_q: int, + num_heads_kv: int, + head_dim: int, +) -> torch.Tensor: + q_size = num_heads_q * head_dim + kv_size = num_heads_kv * head_dim + + q, k, v = qkv.split([q_size, kv_size, kv_size], 
dim=-1) + + q_by_head = q.view(*q.shape[:-1], q.shape[-1] // head_dim, head_dim) + q_by_head = q_norm.forward_native(q_by_head) + q = q_by_head.view(q.shape) + + k_by_head = k.view(*k.shape[:-1], k.shape[-1] // head_dim, head_dim) + k_by_head = k_norm.forward_native(k_by_head) + k = k_by_head.view(k.shape) + + q, k = rope.forward_native(positions, q, k) + return torch.cat([q, k, v], dim=-1) + + +@pytest.mark.skipif( + not current_platform.is_cuda(), + reason="fused_qk_norm_rope custom op requires cuda platform", +) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("is_neox", IS_NEOX) +@pytest.mark.parametrize("eps", EPS_VALUES) +@pytest.mark.parametrize("seed", SEEDS) +@torch.inference_mode() +def test_fused_qk_norm_rope_matches_reference( + device: str, + dtype: torch.dtype, + is_neox: bool, + eps: float, + seed: int, +): + torch.set_default_device(device) + current_platform.seed_everything(seed) + num_heads, num_kv_heads, head_dim = 16, 4, 128 + num_tokens = 4 + + total_dim = (num_heads + 2 * num_kv_heads) * head_dim + qkv_base = torch.randn(num_tokens, total_dim, dtype=dtype, device=device) + qkv_fused = qkv_base.clone() + positions = torch.arange(num_tokens, dtype=torch.long, device=device) + + q_norm = RMSNorm(head_dim, eps=eps).to(device=device, dtype=dtype) + k_norm = RMSNorm(head_dim, eps=eps).to(device=device, dtype=dtype) + q_norm.weight.data.normal_(mean=1.0, std=0.1) + k_norm.weight.data.normal_(mean=1.0, std=0.1) + q_weight = q_norm.weight.data + k_weight = k_norm.weight.data + + rope = RotaryEmbedding( + head_size=head_dim, + rotary_dim=head_dim, + max_position_embeddings=4096, + base=10000.0, + is_neox_style=is_neox, + dtype=dtype, + ).to(device) + + ref_result = _apply_qk_norm_rope( + qkv=qkv_base, + positions=positions, + q_norm=q_norm, + k_norm=k_norm, + rope=rope, + num_heads_q=num_heads, + num_heads_kv=num_kv_heads, + head_dim=head_dim, + ) + + opcheck( + torch.ops._C.fused_qk_norm_rope, + ( + qkv_fused.clone(), + num_heads, + num_kv_heads, + num_kv_heads, + head_dim, + eps, + q_weight, + k_weight, + rope.cos_sin_cache, + is_neox, + positions.view(-1), + ), + ) + + torch.ops._C.fused_qk_norm_rope( + qkv_fused, + num_heads, + num_kv_heads, + num_kv_heads, + head_dim, + eps, + q_weight, + k_weight, + rope.cos_sin_cache, + is_neox, + positions.view(-1), + ) + + if dtype == torch.float16: + ATOL, RTOL = (2e-3, 2e-3) + else: + ATOL, RTOL = (1e-2, 1e-2) + + torch.testing.assert_close( + qkv_fused, + ref_result, + atol=ATOL, + rtol=RTOL, + ) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 36aab503dee70..136a3193efb5e 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -329,6 +329,7 @@ def rms_norm( out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, epsilon: float ) -> None: # TODO: Remove this contiguous call when the kernel is updated to support non-contiguous input + # If removed, also need to remove contiguous in MatcherRMSNorm input_contiguous = input.contiguous() torch.ops._C.rms_norm(out, input_contiguous, weight, epsilon) @@ -339,6 +340,34 @@ def fused_add_rms_norm( torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon) +def fused_qk_norm_rope( + qkv: torch.Tensor, + num_heads_q: int, + num_heads_k: int, + num_heads_v: int, + head_dim: int, + eps: float, + q_weight: torch.Tensor, + k_weight: torch.Tensor, + cos_sin_cache: torch.Tensor, + is_neox: bool, + position_ids: torch.Tensor, +) -> None: + torch.ops._C.fused_qk_norm_rope( + qkv, + num_heads_q, + 
num_heads_k, + num_heads_v, + head_dim, + eps, + q_weight, + k_weight, + cos_sin_cache, + is_neox, + position_ids, + ) + + def apply_repetition_penalties_torch( logits: torch.Tensor, prompt_mask: torch.Tensor, diff --git a/vllm/compilation/fix_functionalization.py b/vllm/compilation/fix_functionalization.py index 29462d9ff0e50..126ad35e527ae 100644 --- a/vllm/compilation/fix_functionalization.py +++ b/vllm/compilation/fix_functionalization.py @@ -132,6 +132,23 @@ class FixFunctionalizationPass(VllmInductorPass): "input_global_scale", ), ) + # Defunctionalize fused_qk_norm_rope to remove higher-order wrapper. + elif at_target == torch.ops._C.fused_qk_norm_rope.default: + mutated_args = {1: "qkv"} + args = ( + "qkv", + "num_heads_q", + "num_heads_k", + "num_heads_v", + "head_dim", + "eps", + "q_weight", + "k_weight", + "cos_sin_cache", + "is_neox", + "position_ids", + ) + self.defunctionalize(graph, node, mutated_args=mutated_args, args=args) else: continue # skip the count diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index 8f0ad2d69fbec..1d6e297b495eb 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -44,6 +44,10 @@ def empty_i32(*args, **kwargs): return torch.empty(*args, **kwargs, dtype=torch.int32, device="cuda") +def empty_i64(*args, **kwargs): + return torch.empty(*args, **kwargs, dtype=torch.int64, device="cuda") + + RMS_OP = torch.ops._C.rms_norm.default RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default diff --git a/vllm/compilation/matcher_utils.py b/vllm/compilation/matcher_utils.py index 383fe6033a6df..38eb4e5301a18 100644 --- a/vllm/compilation/matcher_utils.py +++ b/vllm/compilation/matcher_utils.py @@ -18,10 +18,13 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( kFp8StaticTensorSym, kNvfp4Quant, ) +from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding from vllm.platforms import current_platform RMS_OP = torch.ops._C.rms_norm.default RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default +ROTARY_OP = torch.ops._C.rotary_embedding.default +FLASHINFER_ROTARY_OP = torch.ops.vllm.flashinfer_rotary_embedding.default QUANT_OPS: dict[QuantKey, OpOverload] = { kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501 @@ -58,6 +61,9 @@ class MatcherCustomOp(ABC): def empty(self, *args, **kws): return torch.empty(*args, dtype=self.model_dtype, device=self.device, **kws) + def empty_int64(self, *args, **kws): + return torch.empty(*args, dtype=torch.int64, device=self.device, **kws) + def empty_f32(self, *args, **kws): return torch.empty(*args, dtype=torch.float32, device=self.device, **kws) @@ -66,6 +72,77 @@ class MatcherCustomOp(ABC): raise NotImplementedError +class MatcherRotaryEmbedding(MatcherCustomOp): + def __init__( + self, + is_neox: bool, + head_size: int, + num_heads: int, + num_kv_heads: int, + use_flashinfer: bool = False, + enabled: bool | None = None, + ) -> None: + if enabled is None: + enabled = RotaryEmbedding.enabled() + + super().__init__(enabled) + self.is_neox = is_neox + self.head_size = head_size + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.q_size = self.num_heads * self.head_size + self.kv_size = self.num_kv_heads * self.head_size + self.rotary_dim = head_size + if use_flashinfer: + self.rotary_op = FLASHINFER_ROTARY_OP + else: + self.rotary_op = ROTARY_OP + + def inputs(self) -> list[torch.Tensor]: + positions = self.empty_int64(5) + query = self.empty(5, self.q_size) + key = self.empty(5, self.kv_size) + cos_sin_cache = 
self.empty(4096, self.rotary_dim) + return [positions, query, key, cos_sin_cache] + + def forward_custom( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor | None, + cos_sin_cache: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + result = auto_functionalized( + self.rotary_op, + positions=positions, + query=query, + key=key, + head_size=self.head_size, + cos_sin_cache=cos_sin_cache, + is_neox=self.is_neox, + ) + query_out = result[1] + key_out = result[2] if len(result) > 2 else None + return query_out, key_out + + def forward_native( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor | None, + cos_sin_cache: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + return RotaryEmbedding.forward_static( + positions, + query, + key, + self.head_size, + self.rotary_dim, + cos_sin_cache, + self.is_neox, + ) + + class MatcherRMSNorm(MatcherCustomOp): def __init__(self, epsilon: float, enabled: bool | None = None): if enabled is None: @@ -85,10 +162,12 @@ class MatcherRMSNorm(MatcherCustomOp): weight: torch.Tensor, ) -> torch.Tensor: result = torch.empty_like(input) + # TODO: support non-contiguous input for RMSNorm and remove this + input_contiguous = input.contiguous() _, result = auto_functionalized( RMS_OP, result=result, - input=input, + input=input_contiguous, weight=weight, epsilon=self.epsilon, ) diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index dfda2adf1d3b0..0c2210d72ce07 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -17,6 +17,7 @@ if current_platform.is_cuda_alike(): from .activation_quant_fusion import ActivationQuantFusionPass from .fusion import RMSNormQuantFusionPass from .fusion_attn import AttnFusionPass + from .qk_norm_rope_fusion import QKNormRoPEFusionPass if current_platform.is_cuda(): from .collective_fusion import AllReduceFusionPass, AsyncTPPass @@ -109,6 +110,9 @@ class PostGradPassManager(CustomGraphPass): if self.pass_config.enable_attn_fusion: self.passes += [AttnFusionPass(config)] + if self.pass_config.enable_qk_norm_rope_fusion: + self.passes += [QKNormRoPEFusionPass(config)] + # needs a functional graph self.post_cleanup = PostCleanupPass(config) self.fix_functionalization = FixFunctionalizationPass(config) diff --git a/vllm/compilation/qk_norm_rope_fusion.py b/vllm/compilation/qk_norm_rope_fusion.py new file mode 100644 index 0000000000000..e3c399e079063 --- /dev/null +++ b/vllm/compilation/qk_norm_rope_fusion.py @@ -0,0 +1,238 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Callable + +import torch +import torch._inductor.pattern_matcher as pm +from torch import fx +from torch._higher_order_ops.auto_functionalize import auto_functionalized +from torch._inductor.pattern_matcher import PatternMatcherPass + +from vllm.attention import Attention +from vllm.config import VllmConfig, get_layers_from_vllm_config +from vllm.logger import init_logger +from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding + +from .fusion import empty_bf16, empty_fp32, empty_i64 +from .inductor_pass import enable_fake_mode +from .matcher_utils import MatcherRMSNorm, MatcherRotaryEmbedding +from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass + +logger = init_logger(__name__) + +FUSED_QK_ROPE_OP = torch.ops._C.fused_qk_norm_rope.default + + +class QkNormRopePattern: + """ + Match the unfused sequence in 
attention blocks and replace with the fused op. + + Unfused (conceptually): + q, k, v = split(qkv, [qsz, kvsz, kvsz], -1) + qh = reshape(q, [-1, num_heads, head_dim]) + kh = reshape(k, [-1, num_kv_heads, head_dim]) + qn = rms_norm(qh, q_weight, eps) + kn = rms_norm(kh, k_weight, eps) + qf = reshape(qn, [-1, num_heads * head_dim]) + kf = reshape(kn, [-1, num_kv_heads * head_dim]) + qf, kf = rotary_embedding(positions, qf, kf, head_dim, cos_sin_cache, is_neox) + return qf, kf, v + + Fused replacement: + fused_qk_norm_rope(qkv, num_heads, num_kv_heads, num_kv_heads, head_dim, + eps, q_weight, k_weight, cos_sin_cache, is_neox, + positions.view(-1)) + return split(qkv, [qsz, kvsz, kvsz], -1) + """ + + def __init__( + self, + head_dim: int, + num_heads: int, + num_kv_heads: int, + eps: float, + is_neox: bool, + rope_flashinfer: bool = False, + ) -> None: + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = head_dim + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.eps = eps + self.rmsnorm_matcher = MatcherRMSNorm(eps) + self.is_neox = is_neox + self.rope_flashinfer = rope_flashinfer + self.rope_matcher = MatcherRotaryEmbedding( + is_neox=is_neox, + head_size=self.head_dim, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + use_flashinfer=self.rope_flashinfer, + ) + + def get_inputs(self): + # Sample inputs to help pattern tracing + T = 5 + qkv = empty_bf16(T, self.q_size + 2 * self.kv_size) + positions = empty_i64(T) + q_weight = empty_bf16(1, self.head_dim) + k_weight = empty_bf16(1, self.head_dim) + if self.rope_flashinfer: + cos_sin_cache = empty_fp32(4096, self.head_dim) + else: + cos_sin_cache = empty_bf16(4096, self.head_dim) + return [ + qkv, + positions, + q_weight, + k_weight, + cos_sin_cache, + ] + + @staticmethod + def wrap_trace_fn(trace_fn, *process_fx_fns: Callable[[fx.GraphModule], None]): + def wrapped(*args, **kwargs): + gm = trace_fn(*args, **kwargs) + for process_fx in process_fx_fns: + process_fx(gm) + + return gm + + return wrapped + + @staticmethod + def fx_view_to_reshape(gm: torch.fx.GraphModule): + from torch._inductor.fx_passes.post_grad import view_to_reshape + + view_to_reshape(gm) + + def register(self, pm_pass: PatternMatcherPass): + def pattern( + qkv: torch.Tensor, + positions: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, + cos_sin_cache: torch.Tensor, + ): + # split qkv -> q,k,v + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + # Q path: view -> RMS -> view back to q.shape + q_by_head = q.view( + *q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim + ) + q_normed_by_head = self.rmsnorm_matcher(q_by_head, q_weight) + q_flat = q_normed_by_head.view(q.shape) + + # K path: view -> RMS -> view back to k.shape + k_by_head = k.view( + *k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim + ) + k_normed_by_head = self.rmsnorm_matcher(k_by_head, k_weight) + k_flat = k_normed_by_head.view(k.shape) + + # RoPE: apply to flattened q/k + q_rope, k_rope = self.rope_matcher(positions, q_flat, k_flat, cos_sin_cache) + return q_rope, k_rope, v + + def replacement( + qkv: torch.Tensor, + positions: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, + cos_sin_cache: torch.Tensor, + ): + # Run fused qk_norm_rope op + result = auto_functionalized( + FUSED_QK_ROPE_OP, + qkv=qkv, + num_heads_q=self.num_heads, + num_heads_k=self.num_kv_heads, + num_heads_v=self.num_kv_heads, + head_dim=self.head_dim, + eps=self.eps, + 
q_weight=q_weight, + k_weight=k_weight, + cos_sin_cache=cos_sin_cache, + is_neox=self.is_neox, + position_ids=positions.view(-1), + ) + result_qkv = result[1] + + # Split back to q,k,v and return + return result_qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + # NOTE: use fx_view_to_reshape to unify view/reshape to simplify + # pattern and increase matching opportunities + pm.register_replacement( + pattern, + replacement, + self.get_inputs(), + QkNormRopePattern.wrap_trace_fn( + pm.fwd_only, + QkNormRopePattern.fx_view_to_reshape, + ), + pm_pass, + ) + + +class QKNormRoPEFusionPass(VllmPatternMatcherPass): + """Fuse Q/K RMSNorm + RoPE into fused_qk_norm_rope when the custom op exists.""" + + @enable_fake_mode + def __init__(self, config: VllmConfig): + super().__init__(config) + self.patterns: PatternMatcherPass = PatternMatcherPass( + pass_name="qk_norm_rope_fusion_pass" + ) + + dtype = config.model_config.dtype + if dtype not in (torch.bfloat16, torch.float16): + logger.warning_once( + "QK Norm+RoPE fusion not enabled: unsupported dtype %s", dtype + ) + return + + # use one attn layer to get meta (such as head_dim) for QkNormRopePattern + attn_layers: dict[str, Attention] = get_layers_from_vllm_config( + config, Attention + ) + if len(attn_layers) == 0: + logger.warning_once( + "QK Norm+RoPE fusion enabled, but no Attention layers were discovered." + ) + return + layer = next(iter(attn_layers.values())) + + for epsilon in [1e-5, 1e-6]: + for neox in [True, False]: + if RotaryEmbedding.enabled(): + for rope_flashinfer in [False, True]: + QkNormRopePattern( + head_dim=layer.head_size, + num_heads=layer.num_heads, + num_kv_heads=layer.num_kv_heads, + eps=epsilon, + is_neox=neox, + rope_flashinfer=rope_flashinfer, + ).register(self.patterns) + else: + QkNormRopePattern( + head_dim=layer.head_size, + num_heads=layer.num_heads, + num_kv_heads=layer.num_kv_heads, + eps=epsilon, + is_neox=neox, + ).register(self.patterns) + + self.dump_patterns(config, self.patterns) + + @VllmInductorPass.time_and_log + def __call__(self, graph: fx.Graph) -> None: + self.matched_count = self.patterns.apply(graph) + logger.debug("Fused QK Norm+RoPE on %s sites", self.matched_count) + + def uuid(self): + return VllmInductorPass.hash_source(self, QkNormRopePattern) diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 92cf16f259fe7..9c9557df4e738 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -129,6 +129,8 @@ class PassConfig: 8: 1, # 1MB }, }, where key is the device capability""" + enable_qk_norm_rope_fusion: bool = False + """Whether to enable the fused Q/K RMSNorm + RoPE pass.""" # TODO(luka) better pass enabling system. @@ -182,6 +184,12 @@ class PassConfig: "Fusion enabled but reshape elimination disabled. " "Allreduce + rms norm + quant (fp8) fusion might not work" ) + if self.enable_qk_norm_rope_fusion and not current_platform.is_cuda(): + logger.warning_once( + "QK Norm + RoPE fusion enabled but the current platform is not " + "CUDA. The fusion will be disabled." + ) + self.enable_qk_norm_rope_fusion = False @config @@ -640,6 +648,11 @@ class CompilationConfig: if isinstance(self.pass_config, dict): self.pass_config = PassConfig(**self.pass_config) + if self.pass_config.enable_qk_norm_rope_fusion: + # TODO(zhuhaoran): support rope native forward match and remove this. 
+ # Linked issue: https://github.com/vllm-project/vllm/issues/28042 + self.custom_ops.append("+rotary_embedding") + if ( is_torch_equal_or_newer("2.9.0.dev") and "combo_kernels" not in self.inductor_compile_config diff --git a/vllm/model_executor/layers/rotary_embedding/base.py b/vllm/model_executor/layers/rotary_embedding/base.py index 2ef54e75df44e..ce4f40680b0a3 100644 --- a/vllm/model_executor/layers/rotary_embedding/base.py +++ b/vllm/model_executor/layers/rotary_embedding/base.py @@ -98,6 +98,39 @@ class RotaryEmbedding(RotaryEmbeddingBase): head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype ) + @staticmethod + def forward_static( + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor | None, + head_size: int, + rotary_dim: int, + cos_sin_cache: torch.Tensor, + is_neox_style: bool, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + """A PyTorch-native implementation of forward().""" + positions = positions.flatten() + num_tokens = positions.shape[0] + cos_sin = cos_sin_cache.index_select(0, positions) + cos, sin = cos_sin.chunk(2, dim=-1) + + query_shape = query.shape + query = query.view(num_tokens, -1, head_size) + query_rot = query[..., :rotary_dim] + query_pass = query[..., rotary_dim:] + query_rot = apply_rotary_emb_torch(query_rot, cos, sin, is_neox_style) + query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) + + # key may be None in some cases, e.g. cross-layer KV sharing + if key is not None: + key_shape = key.shape + key = key.view(num_tokens, -1, head_size) + key_rot = key[..., :rotary_dim] + key_pass = key[..., rotary_dim:] + key_rot = apply_rotary_emb_torch(key_rot, cos, sin, is_neox_style) + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + return query, key + def forward_native( self, positions: torch.Tensor, @@ -105,27 +138,15 @@ class RotaryEmbedding(RotaryEmbeddingBase): key: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: """A PyTorch-native implementation of forward().""" - positions = positions.flatten() - num_tokens = positions.shape[0] - cos_sin = self.cos_sin_cache.index_select(0, positions) - cos, sin = cos_sin.chunk(2, dim=-1) - - query_shape = query.shape - query = query.view(num_tokens, -1, self.head_size) - query_rot = query[..., : self.rotary_dim] - query_pass = query[..., self.rotary_dim :] - query_rot = apply_rotary_emb_torch(query_rot, cos, sin, self.is_neox_style) - query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) - - # key may be None in some cases, e.g. cross-layer KV sharing - if key is not None: - key_shape = key.shape - key = key.view(num_tokens, -1, self.head_size) - key_rot = key[..., : self.rotary_dim] - key_pass = key[..., self.rotary_dim :] - key_rot = apply_rotary_emb_torch(key_rot, cos, sin, self.is_neox_style) - key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) - return query, key + return self.forward_static( + positions, + query, + key, + self.head_size, + self.rotary_dim, + self.cos_sin_cache, + self.is_neox_style, + ) def forward_cuda( self,
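
Usage sketch (illustrative; mirrors tests/compile/test_qk_norm_rope_fusion.py above): the fusion is opt-in via the new PassConfig flag, and enabling it auto-appends "+rotary_embedding" to custom_ops as wired up in CompilationConfig. The snippet below is an assumed minimal configuration, not the only way to enable it.

    from vllm.config import CompilationConfig, CompilationMode, PassConfig, VllmConfig

    # Minimal opt-in configuration; real deployments also set model/cache configs.
    vllm_config = VllmConfig(
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            custom_ops=["+rms_norm"],
            pass_config=PassConfig(
                enable_qk_norm_rope_fusion=True,
                enable_noop=True,
            ),
        ),
    )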