[Kernel] Replaced blockReduce[...] functions with cub::BlockReduce (#7233)
Co-authored-by: Michael Goin <michael@neuralmagic.com>
parent 9984605412 · commit 7937009a7e
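In short, every hand-rolled blockReduceSum/blockReduceMax call becomes a cub::BlockReduce sized for the maximum block width (1024) and told at run time, via the num_valid argument, how many threads hold valid input. The sketch below is mine, not vLLM code; only the cub calls mirror the diff. It shows why a single <float, 1024> instantiation can serve launches of any smaller block size:

// Minimal sketch of the pattern this commit adopts (illustrative, not vLLM source).
// Build: nvcc -o block_sum block_sum.cu
#include <cub/cub.cuh>
#include <cstdio>

template <int MAX_BLOCK_SIZE>
__global__ void block_sum_kernel(const float* in, float* out, int n) {
  // Each thread accumulates a strided partial sum, as rms_norm_kernel does
  // with its per-thread variance.
  float partial = 0.f;
  for (int i = threadIdx.x; i < n; i += blockDim.x) partial += in[i];

  // One instantiation covers every launch up to MAX_BLOCK_SIZE threads; the
  // third argument tells cub how many threads carry valid input, which is why
  // the old num_tokens-based choice between blockReduceSum<float, 1024> and
  // blockReduceSum<float, 256> disappears in this diff.
  using BlockReduce = cub::BlockReduce<float, MAX_BLOCK_SIZE>;
  __shared__ typename BlockReduce::TempStorage temp;
  float total = BlockReduce(temp).Reduce(partial, cub::Sum{}, blockDim.x);

  if (threadIdx.x == 0) *out = total;  // aggregate is valid only on thread 0
}

int main() {
  const int n = 1 << 16;
  float *in, *out;
  cudaMallocManaged(&in, n * sizeof(float));
  cudaMallocManaged(&out, sizeof(float));
  for (int i = 0; i < n; ++i) in[i] = 1.f;

  // Same instantiation, two different runtime block sizes; both are correct.
  block_sum_kernel<1024><<<1, 1024>>>(in, out, n);
  cudaDeviceSynchronize();
  printf("block=1024: sum = %.0f (expected %d)\n", *out, n);

  block_sum_kernel<1024><<<1, 256>>>(in, out, n);
  cudaDeviceSynchronize();
  printf("block=256:  sum = %.0f (expected %d)\n", *out, n);

  cudaFree(in);
  cudaFree(out);
  return 0;
}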
@@ -4,8 +4,8 @@ tasks:
 - name: "gsm8k"
   metrics:
   - name: "exact_match,strict-match"
-    value: 0.409
+    value: 0.419
   - name: "exact_match,flexible-extract"
-    value: 0.406
+    value: 0.416
 limit: 1000
 num_fewshot: 5
benchmarks/kernels/benchmark_layernorm.py (new file, 89 lines)
@@ -0,0 +1,89 @@
import random
import time

import torch

from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser


@torch.inference_mode()
def main(num_tokens: int,
         hidden_size: int,
         add_residual: bool,
         dtype: torch.dtype,
         seed: int = 0,
         do_profile: bool = False,
         num_warmup_iters: int = 5,
         num_iters: int = 100) -> None:
    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device("cuda")

    layer = RMSNorm(hidden_size).to(dtype=dtype)
    layer.weight.data.normal_(mean=1.0, std=0.1)
    scale = 1 / (2 * hidden_size)
    x = torch.randn(num_tokens, hidden_size, dtype=dtype)
    x *= scale
    residual = torch.randn_like(x) * scale if add_residual else None

    def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
        torch.cuda.synchronize()
        if profile:
            torch.cuda.cudart().cudaProfilerStart()
        start_time = time.perf_counter()

        for _ in range(num_iters):
            layer(x, residual)
        torch.cuda.synchronize()

        end_time = time.perf_counter()
        if profile:
            torch.cuda.cudart().cudaProfilerStop()
        return (end_time - start_time) / num_iters

    # Warmup.
    print("Warming up...")
    run_benchmark = run_cuda_benchmark
    run_benchmark(num_iters=num_warmup_iters, profile=False)

    # Benchmark.
    if do_profile:
        latency = run_benchmark(num_iters=1, profile=True)
    else:
        latency = run_benchmark(num_iters=num_iters, profile=False)
    print(f"Kernel running time: {latency * 1000000:.3f} us")


if __name__ == '__main__':
    parser = FlexibleArgumentParser(
        description="Benchmark the layernorm kernel.")
    parser.add_argument("--num-tokens", type=int, default=4096)
    parser.add_argument("--hidden-size", type=int, default=8192)
    parser.add_argument("--add-residual", action="store_true")
    parser.add_argument("--dtype",
                        type=str,
                        choices=["half", "bfloat16", "float"],
                        default="half")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--profile", action="store_true")
    parser.add_argument("--num-warmup-iters", type=int, default=5)
    parser.add_argument("--num-iters",
                        type=int,
                        default=100,
                        help="Number of benchmark iterations. "
                        "If --profile is set, this number is ignored")

    args = parser.parse_args()
    print(args)

    main(num_tokens=args.num_tokens,
         hidden_size=args.hidden_size,
         add_residual=args.add_residual,
         dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
         seed=args.seed,
         do_profile=args.profile,
         num_warmup_iters=args.num_warmup_iters,
         num_iters=args.num_iters)
benchmarks/kernels/benchmark_quant.py (new file, 103 lines)
@@ -0,0 +1,103 @@
import random
import time

import torch

from vllm import _custom_ops as ops
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser


@torch.inference_mode()
def main(num_tokens: int,
         hidden_size: int,
         static_scale: bool,
         quant_dtype: torch.dtype,
         dtype: torch.dtype,
         seed: int = 0,
         do_profile: bool = False,
         num_warmup_iters: int = 5,
         num_iters: int = 100) -> None:
    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device("cuda")

    x = torch.randn(num_tokens, hidden_size, dtype=dtype)
    scale = torch.randn(1, 1, dtype=torch.float32) if static_scale else None

    def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
        torch.cuda.synchronize()
        if profile:
            torch.cuda.cudart().cudaProfilerStart()
        start_time = time.perf_counter()

        for _ in range(num_iters):
            if quant_dtype == torch.int8:
                ops.scaled_int8_quant(x, scale)
            else:
                ops.scaled_fp8_quant(x, scale)
        torch.cuda.synchronize()

        end_time = time.perf_counter()
        if profile:
            torch.cuda.cudart().cudaProfilerStop()
        return (end_time - start_time) / num_iters

    # Warmup.
    print("Warming up...")
    run_benchmark = run_cuda_benchmark
    run_benchmark(num_iters=num_warmup_iters, profile=False)

    # Benchmark.
    if do_profile:
        latency = run_benchmark(num_iters=1, profile=True)
    else:
        latency = run_benchmark(num_iters=num_iters, profile=False)
    print(f"Kernel running time: {latency * 1000000:.3f} us")


if __name__ == '__main__':

    def to_torch_dtype(dt):
        if dt == "int8":
            return torch.int8
        if dt == "fp8":
            return torch.float8_e4m3fn
        raise ValueError(f"Unsupported dtype: {dt}")

    parser = FlexibleArgumentParser(
        description="Benchmark the quantization (fp8 or int8) kernel.")
    parser.add_argument("--num-tokens", type=int, default=4096)
    parser.add_argument("--hidden-size", type=int, default=8192)
    parser.add_argument("--static-scale", action="store_true")
    parser.add_argument("--quant-dtype",
                        type=str,
                        choices=["fp8", "int8"],
                        default="int8")
    parser.add_argument("--dtype",
                        type=str,
                        choices=["half", "bfloat16", "float"],
                        default="half")

    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--profile", action="store_true")
    parser.add_argument("--num-warmup-iters", type=int, default=5)
    parser.add_argument("--num-iters",
                        type=int,
                        default=100,
                        help="Number of benchmark iterations. "
                        "If --profile is set, this number is ignored")

    args = parser.parse_args()
    print(args)

    main(num_tokens=args.num_tokens,
         hidden_size=args.hidden_size,
         static_scale=args.static_scale,
         quant_dtype=to_torch_dtype(args.quant_dtype),
         dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
         seed=args.seed,
         do_profile=args.profile,
         num_warmup_iters=args.num_warmup_iters,
         num_iters=args.num_iters)
@@ -3,13 +3,16 @@
 #include <c10/cuda/CUDAGuard.h>

 #include "dispatch_utils.h"
-#include "reduction_utils.cuh"
 #ifndef USE_ROCM
   #include <cuda_bf16.h>
   #include <cuda_fp16.h>
+  #include <cub/util_type.cuh>
+  #include <cub/cub.cuh>
 #else
   #include <hip/hip_bf16.h>
   #include <hip/hip_fp16.h>
+  #include <hipcub/util_type.hpp>
+  #include <hipcub/hipcub.hpp>

 using __nv_bfloat16 = __hip_bfloat16;
 using __nv_bfloat162 = __hip_bfloat162;
@@ -31,7 +34,11 @@ __global__ void rms_norm_kernel(
     const float x = (float)input[blockIdx.x * hidden_size + idx];
     variance += x * x;
   }
-  variance = blockReduceSum<float>(variance);
+
+  using BlockReduce = cub::BlockReduce<float, 1024>;
+  __shared__ typename BlockReduce::TempStorage reduceStore;
+  variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
+
   if (threadIdx.x == 0) {
     s_variance = rsqrtf(variance / hidden_size + epsilon);
   }
@@ -228,12 +235,11 @@ fused_add_rms_norm_kernel(
     variance += temp.sum_squares();
     residual_v[id] = temp;
   }
-  /* Keep the following if-else block in sync with the
-     calculation of max_block_size in fused_add_rms_norm */
-  if (num_tokens < 256) {
-    variance = blockReduceSum<float, 1024>(variance);
-  } else
-    variance = blockReduceSum<float, 256>(variance);
+
+  using BlockReduce = cub::BlockReduce<float, 1024>;
+  __shared__ typename BlockReduce::TempStorage reduceStore;
+  variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
+
   if (threadIdx.x == 0) {
     s_variance = rsqrtf(variance / hidden_size + epsilon);
   }
@@ -268,12 +274,11 @@ fused_add_rms_norm_kernel(
     variance += x * x;
     residual[blockIdx.x * hidden_size + idx] = z;
   }
-  /* Keep the following if-else block in sync with the
-     calculation of max_block_size in fused_add_rms_norm */
-  if (num_tokens < 256) {
-    variance = blockReduceSum<float, 1024>(variance);
-  } else
-    variance = blockReduceSum<float, 256>(variance);
+
+  using BlockReduce = cub::BlockReduce<float, 1024>;
+  __shared__ typename BlockReduce::TempStorage reduceStore;
+  variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
+
   if (threadIdx.x == 0) {
     s_variance = rsqrtf(variance / hidden_size + epsilon);
   }
@@ -3,7 +3,14 @@
 #include <cmath>

 #include "../../dispatch_utils.h"
-#include "../../reduction_utils.cuh"
+
+#ifndef USE_ROCM
+  #include <cub/util_type.cuh>
+  #include <cub/cub.cuh>
+#else
+  #include <hipcub/util_type.hpp>
+  #include <hipcub/hipcub.hpp>
+#endif

 static inline __device__ int8_t float_to_int8_rn(float x) {
 #ifdef USE_ROCM
@@ -55,7 +62,10 @@ __global__ void dynamic_scaled_int8_quant_kernel(
     absmax_val = val > absmax_val ? val : absmax_val;
   }

-  float const block_absmax_val_maybe = blockReduceMax(absmax_val);
+  using BlockReduce = cub::BlockReduce<float, 1024>;
+  __shared__ typename BlockReduce::TempStorage reduceStorage;
+  float const block_absmax_val_maybe =
+      BlockReduce(reduceStorage).Reduce(absmax_val, cub::Max{}, blockDim.x);
   __shared__ float block_absmax_val;
   if (tid == 0) {
     block_absmax_val = block_absmax_val_maybe;
@@ -7,7 +7,13 @@
 #include "cuda_compat.h"
 #include "dispatch_utils.h"

-#include "../../reduction_utils.cuh"
+#ifndef USE_ROCM
+  #include <cub/util_type.cuh>
+  #include <cub/cub.cuh>
+#else
+  #include <hipcub/util_type.hpp>
+  #include <hipcub/hipcub.hpp>
+#endif

 #ifndef USE_ROCM
 using FP8_TYPE = c10::Float8_e4m3fn;
@@ -215,7 +221,10 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
     }
   }

-  float const block_absmax_val_maybe = blockReduceMax(absmax_val);
+  using BlockReduce = cub::BlockReduce<float, 1024>;
+  __shared__ typename BlockReduce::TempStorage reduceStorage;
+  float const block_absmax_val_maybe =
+      BlockReduce(reduceStorage).Reduce(absmax_val, cub::Max{}, blockDim.x);
   __shared__ float token_scale;
   if (tid == 0) {
     if (scale_ub) {
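The two quantization hunks above share a second idiom worth noting: BlockReduce returns the aggregate only to thread 0 (hence the "_maybe" naming), so the kernels stage it through a __shared__ slot before every thread uses it. A minimal sketch of that absmax-and-broadcast pattern, mine rather than the vLLM source:

// Sketch of the absmax pattern in the quant kernels above (illustrative).
#include <cub/cub.cuh>
#include <cstdio>

__global__ void block_absmax_kernel(const float* in, float* out, int n) {
  // Per-thread absolute maximum over a strided slice of the row.
  float absmax_val = 0.f;
  for (int i = threadIdx.x; i < n; i += blockDim.x)
    absmax_val = fmaxf(absmax_val, fabsf(in[i]));

  // cub::Max{} replaces the old blockReduceMax; the result is only
  // meaningful on thread 0.
  using BlockReduce = cub::BlockReduce<float, 1024>;
  __shared__ typename BlockReduce::TempStorage temp;
  float block_absmax_maybe =
      BlockReduce(temp).Reduce(absmax_val, cub::Max{}, blockDim.x);

  // Broadcast through shared memory so every thread can scale with it,
  // mirroring the __shared__ block_absmax_val / token_scale staging above.
  __shared__ float block_absmax;
  if (threadIdx.x == 0) block_absmax = block_absmax_maybe;
  __syncthreads();

  if (threadIdx.x == 0) *out = block_absmax;
}

int main() {
  const int n = 4096;
  float *in, *out;
  cudaMallocManaged(&in, n * sizeof(float));
  cudaMallocManaged(&out, sizeof(float));
  for (int i = 0; i < n; ++i) in[i] = (i == 1234) ? -42.f : 1.f;
  block_absmax_kernel<<<1, 256>>>(in, out, n);
  cudaDeviceSynchronize();
  printf("absmax = %.0f (expected 42)\n", *out);
  cudaFree(in);
  cudaFree(out);
  return 0;
}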
reduction_utils.cuh (deleted file, 95 lines)
@@ -1,95 +0,0 @@
/*
 * Adapted from
 * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/reduce_kernel_utils.cuh
 * Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include "cuda_compat.h"

namespace vllm {

namespace detail {

template <typename T>
__inline__ __device__ T _max(T a, T b) {
  return max(a, b);
}

template <typename T>
__inline__ __device__ T _sum(T a, T b) {
  return a + b;
}

}  // namespace detail

template <typename T>
using ReduceFnType = T (*)(T, T);

// Helper function to return the next largest power of 2
static constexpr int _nextPow2(unsigned int num) {
  if (num <= 1) return num;
  return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
}

template <typename T, int numLanes = WARP_SIZE>
__inline__ __device__ T warpReduce(T val, ReduceFnType<T> fn) {
  static_assert(numLanes > 0 && (numLanes & (numLanes - 1)) == 0,
                "numLanes is not a positive power of 2!");
  static_assert(numLanes <= WARP_SIZE);
#pragma unroll
  for (int mask = numLanes >> 1; mask > 0; mask >>= 1)
    val = fn(val, VLLM_SHFL_XOR_SYNC(val, mask));

  return val;
}

template <typename T, int maxBlockSize = 1024>
__inline__ __device__ T blockReduce(T val, ReduceFnType<T> fn) {
  static_assert(maxBlockSize <= 1024);
  if constexpr (maxBlockSize > WARP_SIZE) {
    val = warpReduce<T>(val, fn);
    // Calculates max number of lanes that need to participate in the last
    // warpReduce
    constexpr int maxActiveLanes = (maxBlockSize + WARP_SIZE - 1) / WARP_SIZE;
    static __shared__ T shared[maxActiveLanes];
    int lane = threadIdx.x % WARP_SIZE;
    int wid = threadIdx.x / WARP_SIZE;
    if (lane == 0) shared[wid] = val;

    __syncthreads();

    val = (threadIdx.x < blockDim.x / float(WARP_SIZE)) ? shared[lane]
                                                        : (T)(0.0f);
    val = warpReduce<T, _nextPow2(maxActiveLanes)>(val, fn);
  } else {
    // A single warpReduce is equal to blockReduce
    val = warpReduce<T, _nextPow2(maxBlockSize)>(val, fn);
  }
  return val;
}

template <typename T, int maxBlockSize = 1024>
__inline__ __device__ T blockReduceMax(T val) {
  return blockReduce<T, maxBlockSize>(val, detail::_max<T>);
}

template <typename T, int maxBlockSize = 1024>
__inline__ __device__ T blockReduceSum(T val) {
  return blockReduce<T, maxBlockSize>(val, detail::_sum<T>);
}

}  // namespace vllm
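For reference, what the deleted header did by hand: warpReduce ran a butterfly reduction with XOR shuffles (VLLM_SHFL_XOR_SYNC wraps __shfl_xor_sync on CUDA builds), blockReduce then combined per-warp results through static shared memory, and _nextPow2 rounded lane counts up to a power of two (e.g. _nextPow2(24) = 32). cub::BlockReduce implements the same scheme internally, plus the partially-valid-tile handling. A minimal sketch of the shuffle butterfly, mine rather than the source, assuming a full 32-lane warp:

// Sketch of the butterfly reduction the deleted warpReduce performed.
#include <cstdio>

__global__ void warp_sum_kernel(float* out) {
  float val = 1.f;  // every lane contributes 1, so the warp sum is 32
  // Each step halves the distance: lanes exchange values with their
  // XOR-partner, so after log2(32) = 5 steps every lane holds the full sum.
  for (int mask = 16; mask > 0; mask >>= 1)
    val += __shfl_xor_sync(0xffffffff, val, mask);
  if (threadIdx.x == 0) *out = val;
}

int main() {
  float* out;
  cudaMallocManaged(&out, sizeof(float));
  warp_sum_kernel<<<1, 32>>>(out);
  cudaDeviceSynchronize();
  printf("warp sum = %.0f (expected 32)\n", *out);
  cudaFree(out);
  return 0;
}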
@@ -83,7 +83,7 @@ def test_models(
                          for m in E4M3_KV_MODELS])
 # Due to low-precision numerical divergence, we only test logprob of 4 tokens
 @pytest.mark.parametrize("max_tokens", [4])
-@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16])
+@pytest.mark.parametrize("chunked_prefill_token_size", [4, 16])
 @pytest.mark.parametrize("enforce_eager", [False, True])
 # NOTE: Increasing this in this suite will fail CI because we currently cannot
 # reset distributed env properly. Use a value > 1 just when you test.