#include "common.cuh"
|
|
#include "dispatch_utils.h"
|
|
#include "../../cub_helpers.h"
|
|
#include "../vectorization_utils.cuh"
|
|
#include <c10/cuda/CUDAGuard.h>
|
|
#include <ATen/cuda/Exceptions.h>
|
|
|
|
namespace vllm {

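// Quantize one token (row) per block using a fixed, precomputed per-tensor
// scale. Rows may be non-contiguous, hence the explicit row strides.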
template <typename scalar_t, typename fp8_type>
__global__ void scaled_fp8_quant_kernel_strided(
    fp8_type* __restrict__ out, const scalar_t* __restrict__ input,
    const float* __restrict__ scale, int hidden_size, int64_t in_row_stride,
    int64_t out_row_stride) {
  const int64_t token_idx = blockIdx.x;  // one token per block
  const int tid = threadIdx.x;

  const scalar_t* token_in = input + token_idx * in_row_stride;
  fp8_type* token_out = out + token_idx * out_row_stride;

  const float inv_scale = 1.0f / (*scale);

  vectorize_with_alignment<16>(
      token_in, token_out, hidden_size, tid, blockDim.x,
      [=] __device__(fp8_type & dst, const scalar_t& src) {
        dst = scaled_fp8_conversion<true, fp8_type>(static_cast<float>(src),
                                                    inv_scale);
      });
}

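// Per-tensor absmax reduction over strided rows: each block reduces one
// token's row in shared memory, then thread 0 folds the block maximum into
// the global scale with an atomic max.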
template <typename scalar_t, typename fp8_type>
__global__ void segmented_max_reduction_strided(
    float* __restrict__ scale, const scalar_t* __restrict__ input,
    int hidden_size, int64_t in_row_stride, int64_t num_tokens) {
  __shared__ float cache[256];
  const int tid = threadIdx.x;
  int64_t token_idx = blockIdx.x;

  // one block per token. Guard in case gridDim.x > num_tokens.
  if (token_idx >= num_tokens) {
    return;
  }

  const scalar_t* row_ptr = input + token_idx * in_row_stride;

  // each thread scans elements of the row in a strided fashion.
  float thread_max = 0.0f;
  for (int e = tid; e < hidden_size; e += blockDim.x) {
    float v = fabsf(static_cast<float>(row_ptr[e]));
    thread_max = fmaxf(thread_max, v);
  }

  cache[tid] = thread_max;
  __syncthreads();

  // parallel reduction to find row max.
  for (int offset = blockDim.x / 2; offset > 0; offset >>= 1) {
    if (tid < offset) {
      cache[tid] = fmaxf(cache[tid], cache[tid + offset]);
    }
    __syncthreads();
  }

  // thread 0 updates global scale (per-tensor) atomically.
  if (tid == 0) {
    atomicMaxFloat(scale, cache[0] / quant_type_max_v<fp8_type>);
  }
}

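// Identical to scaled_fp8_quant_kernel_strided, except that the scale it
// reads was produced on-device by segmented_max_reduction_strided.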
template <typename scalar_t, typename fp8_type>
__global__ void scaled_fp8_quant_kernel_strided_dynamic(
    fp8_type* __restrict__ out, const scalar_t* __restrict__ input,
    const float* __restrict__ scale, int hidden_size, int64_t in_row_stride,
    int64_t out_row_stride) {
  const int64_t token_idx = blockIdx.x;
  const int tid = threadIdx.x;

  const scalar_t* token_in = input + token_idx * in_row_stride;
  fp8_type* token_out = out + token_idx * out_row_stride;

  const float reciprocal_scale = 1.0f / (*scale);
  vectorize_with_alignment<16>(
      token_in, token_out, hidden_size, tid, blockDim.x,
      [=] __device__(fp8_type & dst, const scalar_t& src) {
        dst = scaled_fp8_conversion<true, fp8_type>(static_cast<float>(src),
                                                    reciprocal_scale);
      });
}

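// Per-token dynamic quantization: each block computes its row's absmax,
// derives a per-token scale (optionally clamped by scale_ub), stores it in
// `scale`, and then quantizes the row. The host side launches this with at
// most 256 threads per block (see cub::BlockReduce<float, 256> below).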
template <typename scalar_t, typename fp8_type>
__global__ void dynamic_per_token_scaled_fp8_quant_kernel_strided(
    fp8_type* __restrict__ out, float* __restrict__ scale,
    const scalar_t* __restrict__ input, const float* __restrict__ scale_ub,
    int hidden_size, int64_t in_row_stride, int64_t out_row_stride) {
  const int64_t token_idx = blockIdx.x;
  const int tid = threadIdx.x;

  // Use int64 to avoid overflowing an int32 when calculating this offset
  int64_t in_offset = static_cast<int64_t>(token_idx) * in_row_stride;
  int64_t out_offset = static_cast<int64_t>(token_idx) * out_row_stride;
  const scalar_t* token_in = input + in_offset;
  fp8_type* token_out = out + out_offset;

  // 1) per-token absmax
  float absmax_val = 0.f;
  vectorize_read_with_alignment<16>(
      token_in, hidden_size, tid, blockDim.x, [&] __device__(scalar_t v) {
        absmax_val = fmaxf(absmax_val, fabsf(static_cast<float>(v)));
      });

  using BlockReduce = cub::BlockReduce<float, 256>;
  __shared__ typename BlockReduce::TempStorage tmp;
  const float block_max =
      BlockReduce(tmp).Reduce(absmax_val, CubMaxOp{}, blockDim.x);

  __shared__ float token_scale;
  if (tid == 0) {
    token_scale = scale_ub ? fminf(block_max, *scale_ub) : block_max;
    token_scale = fmaxf(token_scale / quant_type_max_v<fp8_type>,
                        min_scaling_factor<fp8_type>::val());
    scale[token_idx] = token_scale;
  }
  __syncthreads();

  // 2) quantize
  vectorize_with_alignment<16>(
      token_in, token_out, hidden_size, tid, blockDim.x,
      [=] __device__(fp8_type & dst, const scalar_t& src) {
        dst = scaled_fp8_conversion<false, fp8_type>(static_cast<float>(src),
                                                     token_scale);
      });
}

}  // namespace vllm

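// Quantize `input` into `out` with a single precomputed per-tensor scale.
// Only the last dimension must be contiguous; the row strides are forwarded
// to the kernel.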
void static_scaled_fp8_quant(torch::Tensor& out,          // [..., d]
                             torch::Tensor const& input,  // [..., d]
                             torch::Tensor const& scale)  // [1]
{
  TORCH_CHECK(input.stride(-1) == 1,
              "last dimension of input must be contiguous");
  TORCH_CHECK(out.stride(-1) == 1,
              "last dimension of output must be contiguous");

  const int hidden_size = input.size(-1);
  const int num_tokens = input.numel() / hidden_size;
  const int block_size = 256;
  dim3 grid(num_tokens);
  dim3 block(block_size);

  const int64_t in_row_stride = input.stride(-2);
  const int64_t out_row_stride = out.stride(-2);

  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] {
        VLLM_DISPATCH_FP8_TYPES(
            out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] {
              vllm::scaled_fp8_quant_kernel_strided<scalar_t, fp8_t>
                  <<<grid, block, 0, stream>>>(
                      out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
                      scale.data_ptr<float>(), hidden_size, in_row_stride,
                      out_row_stride);
            });
      });
}

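// Compute a per-tensor scale on-device via an absmax reduction, then
// quantize with it. Both kernels are launched on the same stream, so the
// reduction result is ordered before the quantization kernel.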
void dynamic_scaled_fp8_quant(torch::Tensor& out,          // [..., d]
                              torch::Tensor const& input,  // [..., d]
                              torch::Tensor& scale)        // [1]
{
  TORCH_CHECK(input.stride(-1) == 1,
              "last dimension of input must be contiguous");
  TORCH_CHECK(out.stride(-1) == 1,
              "last dimension of output must be contiguous");

  const int hidden_size = input.size(-1);
  const int num_tokens = input.numel() / hidden_size;
  const int block_size = 256;
  dim3 grid(num_tokens);
  dim3 block(block_size);

  const int64_t in_row_stride = input.stride(-2);
  const int64_t out_row_stride = out.stride(-2);

  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // zero-initialise the scale (all-zero bits == 0.0f) so the atomic max
  // reduction starts from a value no larger than any absmax.
  AT_CUDA_CHECK(
      cudaMemsetAsync(scale.data_ptr<float>(), 0, sizeof(float), stream));

  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] {
        VLLM_DISPATCH_FP8_TYPES(
            out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] {
              vllm::segmented_max_reduction_strided<scalar_t, fp8_t>
                  <<<grid, block, 0, stream>>>(
                      scale.data_ptr<float>(), input.data_ptr<scalar_t>(),
                      hidden_size, in_row_stride,
                      static_cast<int64_t>(num_tokens));

              vllm::scaled_fp8_quant_kernel_strided_dynamic<scalar_t, fp8_t>
                  <<<grid, block, 0, stream>>>(
                      out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
                      scale.data_ptr<float>(), hidden_size, in_row_stride,
                      out_row_stride);
            });
      });
}

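// Per-token dynamic quantization: writes one scale per token into `scales`,
// optionally clamping each scale by `scale_ub`.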
void dynamic_per_token_scaled_fp8_quant(
    torch::Tensor& out,          // [..., d]
    torch::Tensor const& input,  // [..., d]
    torch::Tensor& scales, std::optional<at::Tensor> const& scale_ub) {
  TORCH_CHECK(input.stride(-1) == 1,
              "last dimension of input must be contiguous");
  TORCH_CHECK(out.stride(-1) == 1,
              "last dimension of output must be contiguous");

  const int hidden_size = input.size(-1);
  const int num_tokens = input.numel() / hidden_size;
  const int block_size = 256;
  dim3 grid(num_tokens);
  dim3 block(std::min(hidden_size, block_size));

  const int64_t in_row_stride = input.stride(-2);
  const int64_t out_row_stride = out.stride(-2);

  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(),
      "dynamic_per_token_scaled_fp8_quant_kernel_scalar_type", [&] {
        VLLM_DISPATCH_FP8_TYPES(
            out.scalar_type(),
            "dynamic_per_token_scaled_fp8_quant_kernel_fp8_type", [&] {
              vllm::dynamic_per_token_scaled_fp8_quant_kernel_strided<
                  scalar_t, fp8_t><<<grid, block, 0, stream>>>(
                  out.data_ptr<fp8_t>(), scales.data_ptr<float>(),
                  input.data_ptr<scalar_t>(),
                  scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
                  hidden_size, in_row_stride, out_row_stride);
            });
      });
}