diff --git a/csrc/quantization/fp8/common.cu b/csrc/quantization/fp8/common.cu
index eceb3a8ea05da..f3f9f669e00a4 100644
--- a/csrc/quantization/fp8/common.cu
+++ b/csrc/quantization/fp8/common.cu
@@ -39,8 +39,8 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
   fp8_type* __restrict__ token_output = &out[offset];
 
   // For vectorization, token_input and token_output pointers need to be
-  // aligned at 8-byte and 4-byte addresses respectively.
-  bool const can_vectorize = hidden_size % 4 == 0;
+  // aligned at 32-byte and 16-byte addresses respectively.
+  bool const can_vectorize = hidden_size % 16 == 0;
 
   float absmax_val = 0.0f;
   if (can_vectorize) {
@@ -48,24 +48,24 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
   } else {
     for (int i = tid; i < hidden_size; i += blockDim.x) {
       float const x = static_cast<float>(token_input[i]);
-      absmax_val = max(absmax_val, fabs(x));
+      absmax_val = fmaxf(absmax_val, fabsf(x));
     }
   }
 
-  using BlockReduce = cub::BlockReduce<float, 1024>;
+  using BlockReduce = cub::BlockReduce<float, 256>;
   __shared__ typename BlockReduce::TempStorage reduceStorage;
   float const block_absmax_val_maybe =
       BlockReduce(reduceStorage).Reduce(absmax_val, cub::Max{}, blockDim.x);
   __shared__ float token_scale;
   if (tid == 0) {
     if (scale_ub) {
-      token_scale = min(block_absmax_val_maybe, *scale_ub);
+      token_scale = fminf(block_absmax_val_maybe, *scale_ub);
     } else {
       token_scale = block_absmax_val_maybe;
     }
     // token scale computation
-    token_scale = max(token_scale / quant_type_max_v<fp8_type>,
-                      min_scaling_factor<fp8_type>::val());
+    token_scale = fmaxf(token_scale / quant_type_max_v<fp8_type>,
+                        min_scaling_factor<fp8_type>::val());
     scale[token_idx] = token_scale;
   }
   __syncthreads();
@@ -88,10 +88,11 @@ void static_scaled_fp8_quant(torch::Tensor& out,          // [..., d]
                              torch::Tensor const& input,  // [..., d]
                              torch::Tensor const& scale)  // [1]
 {
-  int64_t num_tokens = input.numel() / input.size(-1);
-  int64_t num_elems = input.numel();
-  dim3 grid(num_tokens);
-  dim3 block(1024);
+  int const block_size = 256;
+  int const num_tokens = input.numel() / input.size(-1);
+  int const num_elems = input.numel();
+  dim3 const grid(num_tokens);
+  dim3 const block(block_size);
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
@@ -110,10 +111,11 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out,          // [..., d]
                               torch::Tensor const& input,  // [..., d]
                               torch::Tensor& scale)        // [1]
 {
-  int64_t num_tokens = input.numel() / input.size(-1);
-  int64_t num_elems = input.numel();
-  dim3 grid(num_tokens);
-  dim3 block(1024);
+  int const block_size = 256;
+  int const num_tokens = input.numel() / input.size(-1);
+  int const num_elems = input.numel();
+  dim3 const grid(num_tokens);
+  dim3 const block(block_size);
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
@@ -141,8 +143,9 @@ void dynamic_per_token_scaled_fp8_quant(
   int const hidden_size = input.size(-1);
   int const num_tokens = input.numel() / hidden_size;
 
+  int const block_size = 256;
   dim3 const grid(num_tokens);
-  dim3 const block(std::min(hidden_size, 1024));
+  dim3 const block(std::min(hidden_size, block_size));
 
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
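
Side note on the 1024 -> 256 change in common.cu above: cub::BlockReduce takes the number of threads per block as its second template parameter, so it has to be updated together with the launch configuration, which is why the diff changes both in the same place. A minimal standalone sketch of that pairing (assumes only CUDA + CUB; the kernel and variable names are illustrative, not taken from vLLM):

    #include <cub/cub.cuh>
    #include <cstdio>

    constexpr int kBlockSize = 256;  // must match the dim3 block(...) used at launch

    __global__ void block_absmax(const float* __restrict__ in, int n,
                                 float* __restrict__ out) {
      // The template argument tells CUB how many threads participate in the reduction.
      using BlockReduce = cub::BlockReduce<float, kBlockSize>;
      __shared__ typename BlockReduce::TempStorage tmp;

      float local = 0.0f;
      for (int i = threadIdx.x; i < n; i += blockDim.x) {
        local = fmaxf(local, fabsf(in[i]));
      }
      float block_max = BlockReduce(tmp).Reduce(local, cub::Max{});
      if (threadIdx.x == 0) {
        *out = block_max;  // only thread 0 holds the valid block-wide result
      }
    }

    int main() {
      const int n = 4096;
      float *d_in, *d_out;
      cudaMalloc(&d_in, n * sizeof(float));
      cudaMalloc(&d_out, sizeof(float));
      cudaMemset(d_in, 0, n * sizeof(float));  // all zeros -> expected absmax 0
      block_absmax<<<1, kBlockSize>>>(d_in, n, d_out);

      float h_out = -1.0f;
      cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
      std::printf("absmax = %f\n", h_out);
      cudaFree(d_in);
      cudaFree(d_out);
      return 0;
    }
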
diff --git a/csrc/quantization/fp8/common.cuh b/csrc/quantization/fp8/common.cuh
index def8b31b27546..d36f94a8f10d6 100644
--- a/csrc/quantization/fp8/common.cuh
+++ b/csrc/quantization/fp8/common.cuh
@@ -46,7 +46,7 @@ __device__ __forceinline__ fp8_type scaled_fp8_conversion(float const val,
   }
 
   float r =
-      fmax(-quant_type_max_v<fp8_type>, fmin(x, quant_type_max_v<fp8_type>));
+      fmaxf(-quant_type_max_v<fp8_type>, fminf(x, quant_type_max_v<fp8_type>));
 #ifndef USE_ROCM
   return static_cast<fp8_type>(r);
 #else
@@ -65,7 +65,7 @@ template <typename scalar_t, typename fp8_type>
 __global__ void segmented_max_reduction(float* __restrict__ scale,
                                         const scalar_t* __restrict__ input,
                                         int64_t num_elems) {
-  __shared__ float cache[1024];
+  __shared__ float cache[256];
   int64_t i = blockDim.x * blockIdx.x + threadIdx.x;
 
   // First store maximum for all values processes by
@@ -73,7 +73,7 @@ __global__ void segmented_max_reduction(float* __restrict__ scale,
   scalar_t tmp = 0.0;
   while (i < num_elems) {
     float x = static_cast<float>(input[i]);
-    tmp = max(tmp, fabs(x));
+    tmp = fmaxf(tmp, fabsf(x));
     i += blockDim.x * gridDim.x;
   }
   cache[threadIdx.x] = tmp;
@@ -100,25 +100,27 @@ template <typename scalar_t>
 __device__ float thread_max_vec(scalar_t const* __restrict__ input,
                                 int64_t const num_elems, int const tid,
                                 int const step) {
+  constexpr size_t VEC_SIZE = 16;
+  using scalarxN_t = vec_n_t<scalar_t, VEC_SIZE>;
   // Vectorized input/output to better utilize memory bandwidth.
-  vec4_t<scalar_t> const* vectorized_in =
-      reinterpret_cast<vec4_t<scalar_t> const*>(input);
+  auto const* vectorized_in = reinterpret_cast<scalarxN_t const*>(input);
 
-  int64_t const num_vec_elems = num_elems >> 2;
+  // num_elems / VEC_SIZE (which is 16)
+  int64_t const num_vec_elems = num_elems >> 4;
   float absmax_val = 0.0f;
 
-#pragma unroll 4
+#pragma unroll
   for (int64_t i = tid; i < num_vec_elems; i += step) {
-    vec4_t<scalar_t> in_vec = vectorized_in[i];
-    absmax_val = max(absmax_val, fabs(in_vec.x));
-    absmax_val = max(absmax_val, fabs(in_vec.y));
-    absmax_val = max(absmax_val, fabs(in_vec.z));
-    absmax_val = max(absmax_val, fabs(in_vec.w));
+    scalarxN_t in_vec = vectorized_in[i];
+#pragma unroll
+    for (int j = 0; j < VEC_SIZE; ++j) {
+      absmax_val = fmaxf(absmax_val, fabsf(in_vec.val[j]));
+    }
   }
 
-  // Handle the remaining elements if num_elems is not divisible by 4
-  for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
-    absmax_val = max(absmax_val, fabs(input[i]));
+  // Handle the remaining elements if num_elems is not divisible by VEC_SIZE
+  for (int64_t i = num_vec_elems * VEC_SIZE + tid; i < num_elems; i += step) {
+    absmax_val = fmaxf(absmax_val, fabsf(input[i]));
   }
 
   return absmax_val;
@@ -130,31 +132,31 @@ __device__ void scaled_fp8_conversion_vec(fp8_type* __restrict__ out,
                                           float const scale,
                                           int64_t const num_elems,
                                           int const tid, int const step) {
-  using float8x4_t = q8x4_t<fp8_type>;
+  constexpr size_t VEC_SIZE = 16;
+  using scalarxN_t = vec_n_t<scalar_t, VEC_SIZE>;
+  using float8xN_t = q8_n_t<fp8_type, VEC_SIZE>;
   // Vectorized input/output to better utilize memory bandwidth.
-  auto const* vectorized_in = reinterpret_cast<vec4_t<scalar_t> const*>(input);
-  auto* vectorized_out = reinterpret_cast<float8x4_t*>(out);
+  auto const* vectorized_in = reinterpret_cast<scalarxN_t const*>(input);
+  auto* vectorized_out = reinterpret_cast<float8xN_t*>(out);
 
-  int64_t const num_vec_elems = num_elems >> 2;
+  // num_elems / VEC_SIZE (which is 16)
+  int64_t const num_vec_elems = num_elems >> 4;
 
-#pragma unroll 4
+#pragma unroll
   for (int64_t i = tid; i < num_vec_elems; i += step) {
-    vec4_t<scalar_t> in_vec = vectorized_in[i];
-    float8x4_t out_vec;
+    scalarxN_t in_vec = vectorized_in[i];
+    float8xN_t out_vec;
 
-    out_vec.x = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
-        static_cast<float>(in_vec.x), scale);
-    out_vec.y = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
-        static_cast<float>(in_vec.y), scale);
-    out_vec.z = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
-        static_cast<float>(in_vec.z), scale);
-    out_vec.w = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
-        static_cast<float>(in_vec.w), scale);
+#pragma unroll
+    for (int j = 0; j < VEC_SIZE; ++j) {
+      out_vec.val[j] = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
+          static_cast<float>(in_vec.val[j]), scale);
+    }
     vectorized_out[i] = out_vec;
   }
 
-  // Handle the remaining elements if num_elems is not divisible by 4
-  for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
+  // Handle the remaining elements if num_elems is not divisible by VEC_SIZE
+  for (int64_t i = num_vec_elems * VEC_SIZE + tid; i < num_elems; i += step) {
     out[i] = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
         static_cast<float>(input[i]), scale);
   }
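
For context on the VEC_SIZE = 16 choice in thread_max_vec / scaled_fp8_conversion_vec above, and the "aligned at 32-byte and 16-byte addresses" comment in common.cu: with 16-element packets, a 2-byte (half/bfloat16) row is read in 32-byte chunks and the 1-byte fp8 output is written in 16-byte chunks. A small standalone check, with the container redefined locally just so the snippet compiles on its own (it is not vLLM code):

    #include <cstddef>
    #include <cstdint>
    #include <cuda_fp16.h>

    // Local stand-in for vllm::vec_n_t, only to make the assertions self-contained.
    template <typename T, size_t N>
    struct __align__(N * sizeof(T)) vec_n_t {
      T val[N];
    };

    // 16 half-precision elements form one 32-byte, 32-byte-aligned packet...
    static_assert(sizeof(vec_n_t<__half, 16>) == 32, "fp16 packet is 32 bytes");
    static_assert(alignof(vec_n_t<__half, 16>) == 32, "needs 32-byte alignment");
    // ...and 16 one-byte (fp8/int8) elements form one 16-byte packet.
    static_assert(sizeof(vec_n_t<int8_t, 16>) == 16, "8-bit packet is 16 bytes");
    static_assert(alignof(vec_n_t<int8_t, 16>) == 16, "needs 16-byte alignment");

This is also why the kernel only takes the vectorized path when hidden_size % 16 == 0: every token row then spans whole packets, and the row pointers stay packet-aligned so the reinterpret_casts above remain valid.
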
diff --git a/csrc/quantization/fused_kernels/layernorm_utils.cuh b/csrc/quantization/fused_kernels/layernorm_utils.cuh
index e6d23cd24e178..3f188872d80d3 100644
--- a/csrc/quantization/fused_kernels/layernorm_utils.cuh
+++ b/csrc/quantization/fused_kernels/layernorm_utils.cuh
@@ -140,6 +140,7 @@ __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input,
   // sum of squares
   float ss = 0.0f;
 
+  const int VEC_SIZE = 4;
   int32_t const num_vec_elems = hidden_size >> 2;
 
 #pragma unroll 4
@@ -147,22 +148,23 @@ __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input,
     vec4_t<scalar_t> in = vec_input[i];
 
     vec4_t<float> x;
-    x.x = static_cast<float>(in.x);
-    x.y = static_cast<float>(in.y);
-    x.z = static_cast<float>(in.z);
-    x.w = static_cast<float>(in.w);
-    if constexpr (has_residual) {
-      vec4_t<scalar_t> r = vec_residual[i];
-      x.x += static_cast<float>(r.x);
-      x.y += static_cast<float>(r.y);
-      x.z += static_cast<float>(r.z);
-      x.w += static_cast<float>(r.w);
+#pragma unroll
+    for (int j = 0; j < VEC_SIZE; ++j) {
+      x.val[j] = static_cast<float>(in.val[j]);
     }
 
-    ss += x.x * x.x;
-    ss += x.y * x.y;
-    ss += x.z * x.z;
-    ss += x.w * x.w;
+    if constexpr (has_residual) {
+      vec4_t<scalar_t> r = vec_residual[i];
+#pragma unroll
+      for (int j = 0; j < VEC_SIZE; ++j) {
+        x.val[j] += static_cast<float>(r.val[j]);
+      }
+    }
+
+#pragma unroll
+    for (int j = 0; j < VEC_SIZE; ++j) {
+      ss += x.val[j] * x.val[j];
+    }
   }
 
   using BlockReduce = cub::BlockReduce<float, 1024>;
@@ -203,6 +205,7 @@ __device__ void compute_dynamic_per_token_scales(
   constexpr scalar_out_t qmax{quant_type_max_v<scalar_out_t>};
 
+  const int VEC_SIZE = 4;
   int32_t const num_vec_elems = hidden_size >> 2;
   float block_absmax_val_maybe = 0.0f;
 
@@ -212,26 +215,25 @@ __device__ void compute_dynamic_per_token_scales(
     vec4_t<scalar_t> const w = vec_weight[i];
 
     vec4_t<float> x;
-    x.x = static_cast<float>(in.x);
-    x.y = static_cast<float>(in.y);
-    x.z = static_cast<float>(in.z);
-    x.w = static_cast<float>(in.w);
-    if constexpr (has_residual) {
-      vec4_t<scalar_t> r = vec_residual[i];
-      x.x += static_cast<float>(r.x);
-      x.y += static_cast<float>(r.y);
-      x.z += static_cast<float>(r.z);
-      x.w += static_cast<float>(r.w);
+#pragma unroll
+    for (int j = 0; j < VEC_SIZE; ++j) {
+      x.val[j] = static_cast<float>(in.val[j]);
     }
 
-    block_absmax_val_maybe = fmaxf(
-        block_absmax_val_maybe, fabs(static_cast<scalar_t>(x.x * rms) * w.x));
-    block_absmax_val_maybe = fmaxf(
-        block_absmax_val_maybe, fabs(static_cast<scalar_t>(x.y * rms) * w.y));
-    block_absmax_val_maybe = fmaxf(
-        block_absmax_val_maybe, fabs(static_cast<scalar_t>(x.z * rms) * w.z));
-    block_absmax_val_maybe = fmaxf(
-        block_absmax_val_maybe, fabs(static_cast<scalar_t>(x.w * rms) * w.w));
+    if constexpr (has_residual) {
+      vec4_t<scalar_t> r = vec_residual[i];
+#pragma unroll
+      for (int j = 0; j < VEC_SIZE; ++j) {
+        x.val[j] += static_cast<float>(r.val[j]);
+      }
+    }
+
+#pragma unroll
+    for (int j = 0; j < VEC_SIZE; ++j) {
+      block_absmax_val_maybe =
+          fmaxf(block_absmax_val_maybe,
+                fabs(static_cast<scalar_t>(x.val[j] * rms) * w.val[j]));
+    }
   }
 
   using BlockReduce = cub::BlockReduce<float, 1024>;
@@ -282,6 +284,7 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output,
     vec_residual =
         reinterpret_cast<vec4_t<scalar_t>*>(&residual[token_offset]);
   }
 
+  const int VEC_SIZE = 4;
   int32_t const num_vec_elems = hidden_size >> 2;
 
   // TODO(luka/varun) extract into type-agnostic vectorized quant function to
@@ -292,33 +295,31 @@
     vec4_t<scalar_t> const w = vec_weight[i];
 
     vec4_t<float> x;
-    x.x = static_cast<float>(in.x);
-    x.y = static_cast<float>(in.y);
-    x.z = static_cast<float>(in.z);
-    x.w = static_cast<float>(in.w);
+#pragma unroll
+    for (int j = 0; j < VEC_SIZE; ++j) {
+      x.val[j] = static_cast<float>(in.val[j]);
+    }
+
     if constexpr (has_residual) {
       vec4_t<scalar_t> r = vec_residual[i];
-      x.x += static_cast<float>(r.x);
-      x.y += static_cast<float>(r.y);
-      x.z += static_cast<float>(r.z);
-      x.w += static_cast<float>(r.w);
-      // Update residual
-      r.x = static_cast<scalar_t>(x.x);
-      r.y = static_cast<scalar_t>(x.y);
-      r.z = static_cast<scalar_t>(x.z);
-      r.w = static_cast<scalar_t>(x.w);
+#pragma unroll
+      for (int j = 0; j < VEC_SIZE; ++j) {
+        x.val[j] += static_cast<float>(r.val[j]);
+      }
+// Update residual
+#pragma unroll
+      for (int j = 0; j < VEC_SIZE; ++j) {
+        r.val[j] = static_cast<scalar_t>(x.val[j]);
+      }
      vec_residual[i] = r;
    }

    q8x4_t<scalar_out_t> out;
-    out.x = ScaledQuant<scalar_out_t, is_scale_inverted>::quant_fn(
-        static_cast<scalar_t>(x.x * rms) * w.x, scale);
-    out.y = ScaledQuant<scalar_out_t, is_scale_inverted>::quant_fn(
-        static_cast<scalar_t>(x.y * rms) * w.y, scale);
-    out.z = ScaledQuant<scalar_out_t, is_scale_inverted>::quant_fn(
-        static_cast<scalar_t>(x.z * rms) * w.z, scale);
-    out.w = ScaledQuant<scalar_out_t, is_scale_inverted>::quant_fn(
-        static_cast<scalar_t>(x.w * rms) * w.w, scale);
+#pragma unroll
+    for (int j = 0; j < VEC_SIZE; ++j) {
+      out.val[j] = ScaledQuant<scalar_out_t, is_scale_inverted>::quant_fn(
+          static_cast<scalar_t>(x.val[j] * rms) * w.val[j], scale);
+    }
     vec_output[i] = out;
   }
 }
diff --git a/csrc/quantization/vectorization.cuh b/csrc/quantization/vectorization.cuh
index 866da10b5bc14..11d57a5fafe89 100644
--- a/csrc/quantization/vectorization.cuh
+++ b/csrc/quantization/vectorization.cuh
@@ -10,23 +10,22 @@
 namespace vllm {
 
 // Vectorization containers
-template <typename scalar_t>
-struct __align__(8) vec4_t {
-  scalar_t x;
-  scalar_t y;
-  scalar_t z;
-  scalar_t w;
+template <typename scalar_t, size_t vec_size>
+struct __align__(vec_size * sizeof(scalar_t)) vec_n_t {
+  scalar_t val[vec_size];
 };
 
-template <typename quant_type_t>
-struct __align__(4) q8x4_t {
+template <typename quant_type_t, size_t vec_size>
+struct __align__(vec_size * sizeof(quant_type_t)) q8_n_t {
   static_assert(std::is_same_v<quant_type_t, int8_t> ||
                 std::is_same_v<quant_type_t, c10::Float8_e4m3fn> ||
                 std::is_same_v<quant_type_t, c10::Float8_e4m3fnuz>);
-  quant_type_t x;
-  quant_type_t y;
-  quant_type_t z;
-  quant_type_t w;
+  quant_type_t val[vec_size];
 };
 
+template <typename scalar_t>
+using vec4_t = vec_n_t<scalar_t, 4>;
+template <typename quant_type_t>
+using q8x4_t = q8_n_t<quant_type_t, 4>;
+
 }  // namespace vllm
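
Taken together, vectorization.cuh replaces the fixed 4-element containers with width-parameterized ones, and the vec4_t / q8x4_t aliases keep existing 4-wide call sites compiling. A small self-contained sketch of the resulting API (the containers are copied locally without the q8 type restriction, and packet_sum / demo are hypothetical helpers, not part of the patch): member access moves from .x/.y/.z/.w to .val[j], and width-generic device code can take vec_size as a template parameter.

    #include <cstddef>

    // Local copies of the containers from vectorization.cuh, only so the
    // sketch compiles on its own.
    template <typename scalar_t, size_t vec_size>
    struct __align__(vec_size * sizeof(scalar_t)) vec_n_t {
      scalar_t val[vec_size];
    };
    template <typename scalar_t>
    using vec4_t = vec_n_t<scalar_t, 4>;  // existing 4-wide call sites keep working

    // Hypothetical width-generic helper: sums one packet of any vec_size.
    template <typename scalar_t, size_t vec_size>
    __device__ float packet_sum(vec_n_t<scalar_t, vec_size> const& v) {
      float acc = 0.0f;
    #pragma unroll
      for (int j = 0; j < static_cast<int>(vec_size); ++j) {
        acc += static_cast<float>(v.val[j]);  // was v.x + v.y + v.z + v.w
      }
      return acc;
    }

    __global__ void demo(float* out) {
      vec4_t<float> a = {{1.f, 2.f, 3.f, 4.f}};  // 4-wide alias
      vec_n_t<float, 16> b = {};                 // new 16-wide packet
      out[0] = packet_sum(a) + packet_sum(b);    // both go through the same helper
    }
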