// FP8 quantization kernels (static per-tensor, dynamic per-tensor and dynamic
// per-token) together with their host-side launchers.
//
// NOTE(review): the text this file was recovered from had lost everything
// between angle brackets (template parameter lists, template arguments,
// kernel launch configurations and the two system header names). They were
// reconstructed from how the identifiers are used below; confirm them against
// common.cuh / cub_helpers.h before merging.
#include "common.cuh"
#include "dispatch_utils.h"
#include "../../cub_helpers.h"
#include "../vectorization_utils.cuh"

// NOTE(review): header names reconstructed from the symbols used below
// (at::cuda::OptionalCUDAGuard, AT_CUDA_CHECK) - confirm.
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/Exceptions.h>

#include <algorithm>
#include <optional>

namespace vllm {

// Quantizes one row ("token") per block with a single, caller-provided scale.
//
// Launch layout: gridDim.x == number of rows, 1-D block. There is no
// row-index guard, so the grid must not be larger than the number of rows.
//
//   out / input      row `r` starts at `base + r * *_row_stride` (in elements)
//   scale            pointer to one float; the kernel multiplies by 1 / *scale
//   hidden_size      number of elements processed per row
template <typename scalar_t, typename fp8_type>
__global__ void scaled_fp8_quant_kernel_strided(
    fp8_type* __restrict__ out, const scalar_t* __restrict__ input,
    const float* __restrict__ scale, int hidden_size, int64_t in_row_stride,
    int64_t out_row_stride) {
  const int64_t token_idx = blockIdx.x;  // one token per block
  const int tid = threadIdx.x;

  const scalar_t* token_in = input + token_idx * in_row_stride;
  fp8_type* token_out = out + token_idx * out_row_stride;

  const float inv_scale = 1.0f / (*scale);

  // NOTE(review): the first template argument of scaled_fp8_conversion is
  // presumed to mean "scale is already inverted" - confirm in common.cuh.
  vectorize_with_alignment<16>(
      token_in, token_out, hidden_size, tid, blockDim.x,
      [=] __device__(fp8_type & dst, const scalar_t& src) {
        dst = scaled_fp8_conversion<true, fp8_type>(static_cast<float>(src),
                                                    inv_scale);
      });
}

// Computes max(|x|) over each row and folds it, divided by
// quant_type_max_v<fp8_type>, into the single global `*scale` via
// atomicMaxFloat.
//
// Launch layout: one block per row; blocks with blockIdx.x >= num_tokens exit
// immediately. Preconditions imposed by the shared-memory tree reduction:
//   * blockDim.x <= 256 (size of the shared `cache` array)
//   * blockDim.x is a power of two (the halving loop would otherwise skip
//     elements)
// `*scale` must be initialised by the caller before the launch.
template <typename scalar_t, typename fp8_type>
__global__ void segmented_max_reduction_strided(
    float* __restrict__ scale, const scalar_t* __restrict__ input,
    int hidden_size, int64_t in_row_stride, int64_t num_tokens) {
  __shared__ float cache[256];
  const int tid = threadIdx.x;
  int64_t token_idx = blockIdx.x;

  // one block per token. Guard in case gridDim.x > num_tokens.
  // The return is uniform across the block, so no thread is left waiting at
  // the __syncthreads() calls below.
  if (token_idx >= num_tokens) {
    return;
  }

  const scalar_t* row_ptr = input + token_idx * in_row_stride;

  // each thread scans elements of the row in a strided fashion.
  float thread_max = 0.0f;
  for (int e = tid; e < hidden_size; e += blockDim.x) {
    float v = fabsf(static_cast<float>(row_ptr[e]));
    thread_max = fmaxf(thread_max, v);
  }

  cache[tid] = thread_max;
  __syncthreads();

  // parallel reduction to find row max.
  for (int offset = blockDim.x / 2; offset > 0; offset >>= 1) {
    if (tid < offset) {
      cache[tid] = fmaxf(cache[tid], cache[tid + offset]);
    }
    __syncthreads();
  }

  // thread 0 updates global scale (per-tensor) atomically.
  if (tid == 0) {
    atomicMaxFloat(scale, cache[0] / quant_type_max_v<fp8_type>);
  }
}

// Quantizes one row per block using the scale produced by
// segmented_max_reduction_strided. Same launch layout and parameters as
// scaled_fp8_quant_kernel_strided.
template <typename scalar_t, typename fp8_type>
__global__ void scaled_fp8_quant_kernel_strided_dynamic(
    fp8_type* __restrict__ out, const scalar_t* __restrict__ input,
    const float* __restrict__ scale, int hidden_size, int64_t in_row_stride,
    int64_t out_row_stride) {
  const int64_t token_idx = blockIdx.x;
  const int tid = threadIdx.x;

  const scalar_t* token_in = input + token_idx * in_row_stride;
  fp8_type* token_out = out + token_idx * out_row_stride;

  const float reciprocal_scale = 1.0f / (*scale);

  // NOTE(review): template arguments reconstructed - see the note on
  // scaled_fp8_quant_kernel_strided.
  vectorize_with_alignment<16>(
      token_in, token_out, hidden_size, tid, blockDim.x,
      [=] __device__(fp8_type & dst, const scalar_t& src) {
        dst = scaled_fp8_conversion<true, fp8_type>(static_cast<float>(src),
                                                    reciprocal_scale);
      });
}

// Per-token dynamic quantization: each block computes the absmax of its row,
// derives that row's scale, stores it in scale[row] and quantizes the row
// with it.
//
// Launch layout: gridDim.x == number of rows (no row-index guard), 1-D block.
// NOTE(review): the block-size template argument of cub::BlockReduce was
// reconstructed as 256 (the largest block the host launcher uses) - confirm.
//
//   scale      one float written per row
//   scale_ub   optional (may be nullptr) pointer to an upper bound applied to
//              the row absmax before it is turned into a scale
template <typename scalar_t, typename fp8_type>
__global__ void dynamic_per_token_scaled_fp8_quant_kernel_strided(
    fp8_type* __restrict__ out, float* __restrict__ scale,
    const scalar_t* __restrict__ input, const float* __restrict__ scale_ub,
    int hidden_size, int64_t in_row_stride, int64_t out_row_stride) {
  const int64_t token_idx = blockIdx.x;
  const int tid = threadIdx.x;

  // Use int64 to avoid overflowing an int32 when calculating this offset
  // (token_idx is already int64_t, so the products are 64-bit).
  const int64_t in_offset = token_idx * in_row_stride;
  const int64_t out_offset = token_idx * out_row_stride;
  const scalar_t* token_in = input + in_offset;
  fp8_type* token_out = out + out_offset;

  // 1) per-token absmax
  float absmax_val = 0.f;
  vectorize_read_with_alignment<16>(
      token_in, hidden_size, tid, blockDim.x, [&] __device__(scalar_t v) {
        absmax_val = fmaxf(absmax_val, fabsf(static_cast<float>(v)));
      });

  using BlockReduce = cub::BlockReduce<float, 256>;
  __shared__ typename BlockReduce::TempStorage tmp;
  // Only the first blockDim.x inputs are valid: the launcher may use fewer
  // than 256 threads.
  const float block_max =
      BlockReduce(tmp).Reduce(absmax_val, CubMaxOp{}, blockDim.x);

  // The reduced value is only used by thread 0; it is published to the rest
  // of the block through shared memory.
  __shared__ float token_scale;
  if (tid == 0) {
    token_scale = scale_ub ? fminf(block_max, *scale_ub) : block_max;
    token_scale = fmaxf(token_scale / quant_type_max_v<fp8_type>,
                        min_scaling_factor<fp8_type>::val());
    scale[token_idx] = token_scale;
  }
  __syncthreads();

  // 2) quantize
  // NOTE(review): `false` presumed to mean "scale is not inverted", matching
  // the fact that token_scale (not its reciprocal) is passed - confirm.
  vectorize_with_alignment<16>(
      token_in, token_out, hidden_size, tid, blockDim.x,
      [=] __device__(fp8_type & dst, const scalar_t& src) {
        dst = scaled_fp8_conversion<false, fp8_type>(static_cast<float>(src),
                                                     token_scale);
      });
}

}  // namespace vllm

namespace {

// Element stride between consecutive rows of `t`. A 1-D tensor holds a single
// row, so the stride is only ever multiplied by row index 0; return 0 instead
// of letting stride(-2) throw.
inline int64_t row_stride_or_zero(torch::Tensor const& t) {
  return t.dim() > 1 ? t.stride(-2) : 0;
}

}  // namespace

// Quantizes `input` into `out` with a single pre-computed `scale`.
// The last dimension of both tensors must be contiguous. Empty input is a
// no-op.
void static_scaled_fp8_quant(torch::Tensor& out,          // [..., d]
                             torch::Tensor const& input,  // [..., d]
                             torch::Tensor const& scale)  // [1]
{
  TORCH_CHECK(input.stride(-1) == 1,
              "last dimension of input must be contiguous");
  TORCH_CHECK(out.stride(-1) == 1,
              "last dimension of output must be contiguous");

  // Nothing to quantize. This also keeps the division below away from
  // hidden_size == 0 and avoids launching a kernel with a zero-sized grid.
  if (input.numel() == 0) {
    return;
  }

  const int hidden_size = static_cast<int>(input.size(-1));
  const int num_tokens = static_cast<int>(input.numel() / hidden_size);
  const int block_size = 256;
  dim3 grid(num_tokens);
  dim3 block(block_size);

  const int64_t in_row_stride = row_stride_or_zero(input);
  const int64_t out_row_stride = row_stride_or_zero(out);

  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] {
        VLLM_DISPATCH_FP8_TYPES(
            out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] {
              vllm::scaled_fp8_quant_kernel_strided<scalar_t, fp8_t>
                  <<<grid, block, 0, stream>>>(
                      out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
                      scale.data_ptr<float>(), hidden_size, in_row_stride,
                      out_row_stride);
              // Launch-configuration errors are only visible through
              // cudaGetLastError().
              AT_CUDA_CHECK(cudaGetLastError());
            });
      });
}

// Computes a single scale from the absmax of `input`, stores it in `scale`
// and quantizes `input` into `out` with it.
// The last dimension of both tensors must be contiguous. For empty input,
// `scale` is zeroed and nothing else happens.
void dynamic_scaled_fp8_quant(torch::Tensor& out,          // [..., d]
                              torch::Tensor const& input,  // [..., d]
                              torch::Tensor& scale)        // [1]
{
  TORCH_CHECK(input.stride(-1) == 1,
              "last dimension of input must be contiguous");
  TORCH_CHECK(out.stride(-1) == 1,
              "last dimension of output must be contiguous");

  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // scale tensor should be initialised to <=0 before reduction
  // (an all-zero byte pattern is 0.0f).
  AT_CUDA_CHECK(
      cudaMemsetAsync(scale.data_ptr<float>(), 0, sizeof(float), stream));

  // Nothing to reduce or quantize. This also keeps the division below away
  // from hidden_size == 0 and avoids launching with a zero-sized grid.
  if (input.numel() == 0) {
    return;
  }

  const int hidden_size = static_cast<int>(input.size(-1));
  const int num_tokens = static_cast<int>(input.numel() / hidden_size);
  // Must stay a power of two <= 256: see segmented_max_reduction_strided.
  const int block_size = 256;
  dim3 grid(num_tokens);
  dim3 block(block_size);

  const int64_t in_row_stride = row_stride_or_zero(input);
  const int64_t out_row_stride = row_stride_or_zero(out);

  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] {
        VLLM_DISPATCH_FP8_TYPES(
            out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] {
              // Both kernels are enqueued on the same stream, so the
              // quantization pass observes the fully reduced scale.
              vllm::segmented_max_reduction_strided<scalar_t, fp8_t>
                  <<<grid, block, 0, stream>>>(
                      scale.data_ptr<float>(), input.data_ptr<scalar_t>(),
                      hidden_size, in_row_stride,
                      static_cast<int64_t>(num_tokens));
              AT_CUDA_CHECK(cudaGetLastError());

              vllm::scaled_fp8_quant_kernel_strided_dynamic<scalar_t, fp8_t>
                  <<<grid, block, 0, stream>>>(
                      out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
                      scale.data_ptr<float>(), hidden_size, in_row_stride,
                      out_row_stride);
              AT_CUDA_CHECK(cudaGetLastError());
            });
      });
}

// Quantizes `input` into `out` with one dynamically computed scale per row;
// the scales are written to `scales`. `scale_ub`, when present, bounds the
// per-row absmax from above before the scale is derived.
// The last dimension of `input` and `out` must be contiguous. Empty input is
// a no-op (`scales` is left untouched).
void dynamic_per_token_scaled_fp8_quant(
    torch::Tensor& out,          // [..., d]
    torch::Tensor const& input,  // [..., d]
    torch::Tensor& scales, std::optional<torch::Tensor> const& scale_ub) {
  TORCH_CHECK(input.stride(-1) == 1,
              "last dimension of input must be contiguous");
  TORCH_CHECK(out.stride(-1) == 1,
              "last dimension of output must be contiguous");

  // Nothing to quantize. This also keeps the division below away from
  // hidden_size == 0 and avoids launching with a zero-sized grid or block.
  if (input.numel() == 0) {
    return;
  }

  const int hidden_size = static_cast<int>(input.size(-1));
  const int num_tokens = static_cast<int>(input.numel() / hidden_size);
  const int block_size = 256;
  dim3 grid(num_tokens);
  // Never launch more threads than there are elements in a row.
  dim3 block(std::min(hidden_size, block_size));

  const int64_t in_row_stride = row_stride_or_zero(input);
  const int64_t out_row_stride = row_stride_or_zero(out);

  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(),
      "dynamic_per_token_scaled_fp8_quant_kernel_scalar_type", [&] {
        VLLM_DISPATCH_FP8_TYPES(
            out.scalar_type(),
            "dynamic_per_token_scaled_fp8_quant_kernel_fp8_type", [&] {
              vllm::dynamic_per_token_scaled_fp8_quant_kernel_strided<
                  scalar_t, fp8_t><<<grid, block, 0, stream>>>(
                  out.data_ptr<fp8_t>(), scales.data_ptr<float>(),
                  input.data_ptr<scalar_t>(),
                  scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
                  hidden_size, in_row_stride, out_row_stride);
              AT_CUDA_CHECK(cudaGetLastError());
            });
      });
}