diff --git a/.buildkite/check-wheel-size.py b/.buildkite/check-wheel-size.py index 90a5e54736cf3..41d9e682572a6 100644 --- a/.buildkite/check-wheel-size.py +++ b/.buildkite/check-wheel-size.py @@ -1,7 +1,7 @@ import os import zipfile -MAX_SIZE_MB = 100 +MAX_SIZE_MB = 150 def print_top_10_largest_files(zip_file): diff --git a/CMakeLists.txt b/CMakeLists.txt index 47629f036fb09..1c7dfe0c048b0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -167,7 +167,7 @@ set(VLLM_EXT_SRC "csrc/layernorm_kernels.cu" "csrc/quantization/squeezellm/quant_cuda_kernel.cu" "csrc/quantization/gptq/q_gemm.cu" - "csrc/quantization/fp8/fp8_cuda_kernels.cu" + "csrc/quantization/fp8/common.cu" "csrc/cuda_utils_kernels.cu" "csrc/moe_align_block_size_kernels.cu" "csrc/pybind.cpp") diff --git a/cmake/utils.cmake b/cmake/utils.cmake index 7c71673e36f29..00c81e4d00ad8 100644 --- a/cmake/utils.cmake +++ b/cmake/utils.cmake @@ -99,7 +99,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG) "Failed to determine torch nvcc compiler flags") if (CUDA_VERSION VERSION_GREATER_EQUAL 11.8) - list(APPEND GPU_FLAGS "-DENABLE_FP8_E5M2") + list(APPEND GPU_FLAGS "-DENABLE_FP8") endif() if (CUDA_VERSION VERSION_GREATER_EQUAL 12.0) list(REMOVE_ITEM GPU_FLAGS @@ -119,7 +119,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG) list(APPEND GPU_FLAGS "-DUSE_ROCM" - "-DENABLE_FP8_E4M3" + "-DENABLE_FP8" "-U__HIP_NO_HALF_CONVERSIONS__" "-U__HIP_NO_HALF_OPERATORS__" "-fno-gpu-rdc") diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index 8b1b5e098015f..41b337dd91d36 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -19,21 +19,17 @@ #include #include #include +#include #include "attention_dtypes.h" #include "attention_utils.cuh" -#if defined(ENABLE_FP8_E5M2) -#include "../quantization/fp8_e5m2_kvcache/quant_utils.cuh" -#elif defined(ENABLE_FP8_E4M3) -#include "../quantization/fp8/amd_detail/quant_utils.cuh" -#endif - -#include - #ifdef USE_ROCM #include + #include "../quantization/fp8/amd/quant_utils.cuh" typedef __hip_bfloat16 __nv_bfloat16; +#else + #include "../quantization/fp8/nvidia/quant_utils.cuh" #endif #ifndef USE_ROCM @@ -92,7 +88,7 @@ template< int HEAD_SIZE, int BLOCK_SIZE, int NUM_THREADS, - bool IS_FP8_KV_CACHE, + vllm::Fp8KVCacheDataType KV_DTYPE, int PARTITION_SIZE = 0> // Zero means no partitioning. __device__ void paged_attention_kernel( float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] @@ -157,9 +153,7 @@ __device__ void paged_attention_kernel( constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1); using K_vec = typename Vec::Type; using Q_vec = typename Vec::Type; -#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3) using Quant_vec = typename Vec::Type; -#endif constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE; constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE; @@ -223,21 +217,14 @@ __device__ void paged_attention_kernel( const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE; const int offset1 = (vec_idx * VEC_SIZE) / x; const int offset2 = (vec_idx * VEC_SIZE) % x; - if constexpr (IS_FP8_KV_CACHE) { -#if defined(ENABLE_FP8_E5M2) - Quant_vec k_vec_quant = *reinterpret_cast(k_ptr + offset1 * BLOCK_SIZE * x + offset2); - // Vector conversion from Quant_vec to K_vec. 
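The hunks above and below replace the boolean IS_FP8_KV_CACHE template parameter and its ENABLE_FP8_E5M2 / ENABLE_FP8_E4M3 preprocessor branches with a single vllm::Fp8KVCacheDataType non-type template parameter, so each kernel picks the raw-load or fp8-dequantize path with if constexpr. A minimal, self-contained sketch of that pattern; load_kv and the body of scaled_convert are illustrative stand-ins, not the real kernel code:

#include <cstdint>
#include <cstdio>

enum class Fp8KVCacheDataType { kAuto = 0, kFp8E4M3 = 1, kFp8E5M2 = 2 };

// Stand-in for fp8::scaled_convert: dequantize one cached byte with a scale.
// The real headers dispatch to __nv_cvt_* (NVIDIA) or hip_float8 (AMD) intrinsics.
template <typename Tout, Fp8KVCacheDataType kv_dt>
Tout scaled_convert(uint8_t x, float scale) {
  return static_cast<Tout>(x) * scale;
}

// The kernel-side pattern: one template parameter, no #ifdef ladder.
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType KV_DTYPE>
scalar_t load_kv(const cache_t* ptr, float kv_scale) {
  if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
    return *ptr;  // the cache already holds scalar_t
  } else {
    return scaled_convert<scalar_t, KV_DTYPE>(*ptr, kv_scale);  // fp8 byte -> scalar_t
  }
}

int main() {
  float plain = 2.0f;
  uint8_t quant = 16;
  std::printf("%f\n", load_kv<float, float, Fp8KVCacheDataType::kAuto>(&plain, 1.0f));
  std::printf("%f\n", load_kv<float, uint8_t, Fp8KVCacheDataType::kFp8E4M3>(&quant, 0.5f));
  return 0;
}

In the patch the per-format details move into fp8::scaled_convert in the platform-specific quant_utils.cuh headers, so the kernel body no longer needs the #if / assert(false) fallbacks.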
- k_vecs[j] = fp8_e5m2_unscaled::vec_conversion(k_vec_quant); -#elif defined(ENABLE_FP8_E4M3) - Quant_vec k_vec_quant = *reinterpret_cast(k_ptr + offset1 * BLOCK_SIZE * x + offset2); - // Vector conversion from Quant_vec to K_vec. Use scaled_vec_conversion to convert FP8_E4M3 quantized k - // cache vec to k vec in higher precision (FP16, BFloat16, etc.) - k_vecs[j] = fp8_e4m3::scaled_vec_conversion(k_vec_quant, kv_scale); -#else - assert(false); -#endif - } else { + + if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) { k_vecs[j] = *reinterpret_cast(k_ptr + offset1 * BLOCK_SIZE * x + offset2); + } else { + // Vector conversion from Quant_vec to K_vec. + Quant_vec k_vec_quant = *reinterpret_cast( + k_ptr + offset1 * BLOCK_SIZE * x + offset2); + k_vecs[j] = fp8::scaled_convert(k_vec_quant, kv_scale); } } @@ -312,9 +299,7 @@ __device__ void paged_attention_kernel( constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE); using V_vec = typename Vec::Type; using L_vec = typename Vec::Type; -#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3) using V_quant_vec = typename Vec::Type; -#endif using Float_L_vec = typename FloatVec::Type; constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; @@ -348,21 +333,13 @@ __device__ void paged_attention_kernel( if (row_idx < HEAD_SIZE) { const int offset = row_idx * BLOCK_SIZE + physical_block_offset; V_vec v_vec; - if constexpr (IS_FP8_KV_CACHE) { -#if defined(ENABLE_FP8_E5M2) + + if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) { + v_vec = *reinterpret_cast(v_ptr + offset); + } else { V_quant_vec v_quant_vec = *reinterpret_cast(v_ptr + offset); // Vector conversion from V_quant_vec to V_vec. - v_vec = fp8_e5m2_unscaled::vec_conversion(v_quant_vec); -#elif defined(ENABLE_FP8_E4M3) - V_quant_vec v_quant_vec = *reinterpret_cast(v_ptr + offset); - // Vector conversion from V_quant_vec to V_vec. Use scaled_vec_conversion to convert - // FP8_E4M3 quantized v cache vec to v vec in higher precision (FP16, BFloat16, etc.) 
- v_vec = fp8_e4m3::scaled_vec_conversion(v_quant_vec, kv_scale); -#else - assert(false); -#endif - } else { - v_vec = *reinterpret_cast(v_ptr + offset); + v_vec = fp8::scaled_convert(v_quant_vec, kv_scale); } if (block_idx == num_seq_blocks - 1) { // NOTE(woosuk): When v_vec contains the tokens that are out of the context, @@ -448,7 +425,7 @@ template< int HEAD_SIZE, int BLOCK_SIZE, int NUM_THREADS, - bool IS_FP8_KV_CACHE> + vllm::Fp8KVCacheDataType KV_DTYPE> __global__ void paged_attention_v1_kernel( scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] @@ -464,7 +441,7 @@ __global__ void paged_attention_v1_kernel( const int kv_block_stride, const int kv_head_stride, const float kv_scale) { - paged_attention_kernel( + paged_attention_kernel( /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_scale); @@ -477,7 +454,7 @@ template< int HEAD_SIZE, int BLOCK_SIZE, int NUM_THREADS, - bool IS_FP8_KV_CACHE, + vllm::Fp8KVCacheDataType KV_DTYPE, int PARTITION_SIZE> __global__ void paged_attention_v2_kernel( float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] @@ -496,7 +473,7 @@ __global__ void paged_attention_v2_kernel( const int kv_block_stride, const int kv_head_stride, const float kv_scale) { - paged_attention_kernel( + paged_attention_kernel( exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_scale); @@ -606,9 +583,9 @@ __global__ void paged_attention_v2_reduce_kernel( #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \ ((void*)vllm::paged_attention_v1_kernel), shared_mem_size); \ + KV_DTYPE>), shared_mem_size); \ vllm::paged_attention_v1_kernel<<>>( \ + KV_DTYPE><<>>( \ out_ptr, \ query_ptr, \ key_cache_ptr, \ @@ -629,7 +606,7 @@ template< typename T, typename CACHE_T, int BLOCK_SIZE, - bool IS_FP8_KV_CACHE, + vllm::Fp8KVCacheDataType KV_DTYPE, int NUM_THREADS = 128> void paged_attention_v1_launcher( torch::Tensor& out, @@ -706,36 +683,36 @@ void paged_attention_v1_launcher( } } -#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ - paged_attention_v1_launcher( \ - out, \ - query, \ - key_cache, \ - value_cache, \ - num_kv_heads, \ - scale, \ - block_tables, \ - seq_lens, \ - max_seq_len, \ - alibi_slopes, \ +#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE) \ + paged_attention_v1_launcher( \ + out, \ + query, \ + key_cache, \ + value_cache, \ + num_kv_heads, \ + scale, \ + block_tables, \ + seq_lens, \ + max_seq_len, \ + alibi_slopes, \ kv_scale); // NOTE(woosuk): To reduce the compilation time, we omitted block sizes // 1, 2, 4, 64, 128, 256. 
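The NOTE above is about keeping the number of compiled kernel variants in check. Rough bookkeeping of what the new dispatch instantiates per kernel version, under the NVIDIA dispatch shown later in this patch (illustrative constant names; the head-size switch inside the launcher multiplies this further):

// Illustrative count of launcher instantiations produced by the dispatch macros.
constexpr int kSrcDtypes  = 3;  // float, uint16_t (half), __nv_bfloat16
constexpr int kCacheKinds = 3;  // kAuto, kFp8E4M3, kFp8E5M2 (ROCm builds compile 2)
constexpr int kBlockSizes = 3;  // 8, 16, 32
constexpr int kVariants   = kSrcDtypes * kCacheKinds * kBlockSizes;
static_assert(kVariants == 27, "3 dtypes x 3 cache kinds x 3 block sizes");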
-#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_KV_CACHE) \ - switch (block_size) { \ - case 8: \ - CALL_V1_LAUNCHER(T, CACHE_T, 8, IS_FP8_KV_CACHE); \ - break; \ - case 16: \ - CALL_V1_LAUNCHER(T, CACHE_T, 16, IS_FP8_KV_CACHE); \ - break; \ - case 32: \ - CALL_V1_LAUNCHER(T, CACHE_T, 32, IS_FP8_KV_CACHE); \ - break; \ - default: \ - TORCH_CHECK(false, "Unsupported block size: ", block_size); \ - break; \ +#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \ + switch (block_size) { \ + case 8: \ + CALL_V1_LAUNCHER(T, CACHE_T, 8, KV_DTYPE); \ + break; \ + case 16: \ + CALL_V1_LAUNCHER(T, CACHE_T, 16, KV_DTYPE); \ + break; \ + case 32: \ + CALL_V1_LAUNCHER(T, CACHE_T, 32, KV_DTYPE); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ } void paged_attention_v1( @@ -752,65 +729,44 @@ void paged_attention_v1( const c10::optional& alibi_slopes, const std::string& kv_cache_dtype, float kv_scale) { - if (kv_cache_dtype == "auto") { - if (query.dtype() == at::ScalarType::Float) { - CALL_V1_LAUNCHER_BLOCK_SIZE(float, float, false); - } else if (query.dtype() == at::ScalarType::Half) { - CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false); - } else if (query.dtype() == at::ScalarType::BFloat16) { - CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false); - } else { - TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); - } - } else if (kv_cache_dtype == "fp8") { - if (query.dtype() == at::ScalarType::Float) { - CALL_V1_LAUNCHER_BLOCK_SIZE(float, uint8_t, true); - } else if (query.dtype() == at::ScalarType::Half) { - CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true); - } else if (query.dtype() == at::ScalarType::BFloat16) { - CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true); - } else { - TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); - } - } else { - TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype); - } + + DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, CALL_V1_LAUNCHER_BLOCK_SIZE) } -#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \ - vllm::paged_attention_v2_kernel \ - <<>>( \ - exp_sums_ptr, \ - max_logits_ptr, \ - tmp_out_ptr, \ - query_ptr, \ - key_cache_ptr, \ - value_cache_ptr, \ - num_kv_heads, \ - scale, \ - block_tables_ptr, \ - seq_lens_ptr, \ - max_num_blocks_per_seq, \ - alibi_slopes_ptr, \ - q_stride, \ - kv_block_stride, \ - kv_head_stride, \ - kv_scale); \ - vllm::paged_attention_v2_reduce_kernel \ - <<>>( \ - out_ptr, \ - exp_sums_ptr, \ - max_logits_ptr, \ - tmp_out_ptr, \ - seq_lens_ptr, \ +#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \ + vllm::paged_attention_v2_kernel \ + <<>>( \ + exp_sums_ptr, \ + max_logits_ptr, \ + tmp_out_ptr, \ + query_ptr, \ + key_cache_ptr, \ + value_cache_ptr, \ + num_kv_heads, \ + scale, \ + block_tables_ptr, \ + seq_lens_ptr, \ + max_num_blocks_per_seq, \ + alibi_slopes_ptr, \ + q_stride, \ + kv_block_stride, \ + kv_head_stride, \ + kv_scale); \ + vllm::paged_attention_v2_reduce_kernel \ + <<>>( \ + out_ptr, \ + exp_sums_ptr, \ + max_logits_ptr, \ + tmp_out_ptr, \ + seq_lens_ptr, \ max_num_partitions); template< typename T, typename CACHE_T, int BLOCK_SIZE, - bool IS_FP8_KV_CACHE, + vllm::Fp8KVCacheDataType KV_DTYPE, int NUM_THREADS = 128, int PARTITION_SIZE = 512> void paged_attention_v2_launcher( @@ -897,39 +853,39 @@ void paged_attention_v2_launcher( } } -#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ - paged_attention_v2_launcher( \ - out, \ - exp_sums, \ - max_logits, \ - tmp_out, \ - 
query, \ - key_cache, \ - value_cache, \ - num_kv_heads, \ - scale, \ - block_tables, \ - seq_lens, \ - max_seq_len, \ - alibi_slopes, \ +#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE) \ + paged_attention_v2_launcher( \ + out, \ + exp_sums, \ + max_logits, \ + tmp_out, \ + query, \ + key_cache, \ + value_cache, \ + num_kv_heads, \ + scale, \ + block_tables, \ + seq_lens, \ + max_seq_len, \ + alibi_slopes, \ kv_scale); // NOTE(woosuk): To reduce the compilation time, we omitted block sizes // 1, 2, 4, 64, 128, 256. -#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_KV_CACHE) \ - switch (block_size) { \ - case 8: \ - CALL_V2_LAUNCHER(T, CACHE_T, 8, IS_FP8_KV_CACHE); \ - break; \ - case 16: \ - CALL_V2_LAUNCHER(T, CACHE_T, 16, IS_FP8_KV_CACHE); \ - break; \ - case 32: \ - CALL_V2_LAUNCHER(T, CACHE_T, 32, IS_FP8_KV_CACHE); \ - break; \ - default: \ - TORCH_CHECK(false, "Unsupported block size: ", block_size); \ - break; \ +#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \ + switch (block_size) { \ + case 8: \ + CALL_V2_LAUNCHER(T, CACHE_T, 8, KV_DTYPE); \ + break; \ + case 16: \ + CALL_V2_LAUNCHER(T, CACHE_T, 16, KV_DTYPE); \ + break; \ + case 32: \ + CALL_V2_LAUNCHER(T, CACHE_T, 32, KV_DTYPE); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ } void paged_attention_v2( @@ -949,29 +905,7 @@ void paged_attention_v2( const c10::optional& alibi_slopes, const std::string& kv_cache_dtype, float kv_scale) { - if (kv_cache_dtype == "auto") { - if (query.dtype() == at::ScalarType::Float) { - CALL_V2_LAUNCHER_BLOCK_SIZE(float, float, false); - } else if (query.dtype() == at::ScalarType::Half) { - CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false); - } else if (query.dtype() == at::ScalarType::BFloat16) { - CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false); - } else { - TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); - } - } else if (kv_cache_dtype == "fp8") { - if (query.dtype() == at::ScalarType::Float) { - CALL_V2_LAUNCHER_BLOCK_SIZE(float, uint8_t, true); - } else if (query.dtype() == at::ScalarType::Half) { - CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true); - } else if (query.dtype() == at::ScalarType::BFloat16) { - CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true); - } else { - TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); - } - } else { - TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype); - } + DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, CALL_V2_LAUNCHER_BLOCK_SIZE) } #undef WARP_SIZE diff --git a/csrc/attention/dtype_fp8.cuh b/csrc/attention/dtype_fp8.cuh index d11dee91ebe87..2b32ce372a64f 100644 --- a/csrc/attention/dtype_fp8.cuh +++ b/csrc/attention/dtype_fp8.cuh @@ -3,14 +3,21 @@ #include "attention_generic.cuh" #include -#ifdef ENABLE_FP8_E5M2 +#ifdef ENABLE_FP8 +#ifndef USE_ROCM #include -#endif +#endif // USE_ROCM +#endif // ENABLE_FP8 namespace vllm { -#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3) -// fp8 vector types for quantization of kv cache +enum class Fp8KVCacheDataType { + kAuto = 0, + kFp8E4M3 = 1, + kFp8E5M2 = 2, +}; + +// fp8 vector types for quantization of kv cache template<> struct Vec { using Type = uint8_t; @@ -30,6 +37,5 @@ template<> struct Vec { using Type = uint2; }; -#endif // ENABLE_FP8_E5M2 } // namespace vllm diff --git a/csrc/cache.h b/csrc/cache.h index 212a3bf3ddc1c..8c176c452425e 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -34,5 +34,7 @@ void reshape_and_cache_flash( // 
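In dtype_fp8.cuh above, the Fp8KVCacheDataType enum and the packed-fp8 Vec specializations are now compiled unconditionally instead of sitting behind ENABLE_FP8_E5M2. The specializations map a count of packed fp8 bytes to a word type of the same width; a host-side sketch of that mapping, where uint2_like is a stand-in for CUDA's uint2 and the counts follow the usual 1/2/4/8 packing implied by the uint8_t/uint16_t/uint32_t/uint2 storage types:

#include <cstdint>

template <typename T, int N> struct Vec;                // primary template, as in attention_generic.cuh
struct uint2_like { uint32_t x, y; };                   // stand-in for CUDA's 8-byte uint2
template <> struct Vec<uint8_t, 1> { using Type = uint8_t;    };  // 1 fp8 value
template <> struct Vec<uint8_t, 2> { using Type = uint16_t;   };  // 2 packed fp8 values
template <> struct Vec<uint8_t, 4> { using Type = uint32_t;   };  // 4 packed fp8 values
template <> struct Vec<uint8_t, 8> { using Type = uint2_like; };  // 8 packed fp8 values

static_assert(sizeof(Vec<uint8_t, 4>::Type) == 4, "4 fp8 bytes travel in one 32-bit word");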
Just for unittest void convert_fp8( + torch::Tensor& dst_cache, torch::Tensor& src_cache, - torch::Tensor& dst_cache); + const float scale, + const std::string& kv_cache_dtype); diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 76db96f099c69..e5b74da6ad068 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -4,10 +4,11 @@ #include "cuda_compat.h" #include "dispatch_utils.h" -#if defined(ENABLE_FP8_E5M2) -#include "quantization/fp8_e5m2_kvcache/quant_utils.cuh" -#elif defined(ENABLE_FP8_E4M3) -#include "quantization/fp8/amd_detail/quant_utils.cuh" + +#ifdef USE_ROCM +#include "quantization/fp8/amd/quant_utils.cuh" +#else +#include "quantization/fp8/nvidia/quant_utils.cuh" #endif #include @@ -149,7 +150,7 @@ void copy_blocks( namespace vllm { -template +template __global__ void reshape_and_cache_kernel( const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] @@ -194,19 +195,12 @@ __global__ void reshape_and_cache_kernel( + block_offset; scalar_t tgt_key = key[src_key_idx]; scalar_t tgt_value = value[src_value_idx]; - if constexpr (is_fp8_kv_cache) { -#if defined(ENABLE_FP8_E5M2) - key_cache[tgt_key_idx] = fp8_e5m2_unscaled::vec_conversion(tgt_key); - value_cache[tgt_value_idx] = fp8_e5m2_unscaled::vec_conversion(tgt_value); -#elif defined(ENABLE_FP8_E4M3) - key_cache[tgt_key_idx] = fp8_e4m3::scaled_vec_conversion(tgt_key, kv_scale); - value_cache[tgt_value_idx] = fp8_e4m3::scaled_vec_conversion(tgt_value, kv_scale); -#else - assert(false); -#endif - } else { + if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) { key_cache[tgt_key_idx] = tgt_key; value_cache[tgt_value_idx] = tgt_value; + } else { + key_cache[tgt_key_idx] = fp8::scaled_convert(tgt_key, kv_scale); + value_cache[tgt_value_idx] = fp8::scaled_convert(tgt_value, kv_scale); } } } @@ -248,19 +242,22 @@ __global__ void reshape_and_cache_flash_kernel( } } // namespace vllm -#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_KV_CACHE) \ - vllm::reshape_and_cache_kernel<<>>( \ - reinterpret_cast(key.data_ptr()), \ - reinterpret_cast(value.data_ptr()), \ - reinterpret_cast(key_cache.data_ptr()), \ - reinterpret_cast(value_cache.data_ptr()), \ - slot_mapping.data_ptr(), \ - key_stride, \ - value_stride, \ - num_heads, \ - head_size, \ - block_size, \ - x, \ +// KV_T is the stored data type of kv-cache. +// CACHE_T is the data type of key and value tensors. +// KV_DTYPE is the real data type of kv-cache. 
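The kv_scale that these macros forward follows the convention spelled out in the new nvidia quant_utils.cuh header further down: values are divided by the scale when quantized into the cache (reshape_and_cache) and multiplied by it when read back (the attention kernels and convert_fp8). A toy round trip with stand-in conversion functions, just to make the direction of the scale explicit:

#include <cmath>
#include <cstdint>
#include <cstdio>

// Stand-ins for fp8::scaled_convert in each direction; the real code uses fp8 casts,
// not integer rounding.
uint8_t toy_quantize(float hp, float scale) {      // what reshape_and_cache does on write
  return static_cast<uint8_t>(std::lround(hp / scale));
}
float toy_dequantize(uint8_t q, float scale) {     // what paged_attention does on read
  return static_cast<float>(q) * scale;
}

int main() {
  const float kv_scale = 0.25f;
  float k_elem = 3.5f;
  uint8_t cached = toy_quantize(k_elem, kv_scale);
  float restored = toy_dequantize(cached, kv_scale);
  std::printf("cached=%d restored=%f\n", static_cast<int>(cached), restored);  // cached=14 restored=3.500000
  return 0;
}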
+#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \ + vllm::reshape_and_cache_kernel<<>>( \ + reinterpret_cast(key.data_ptr()), \ + reinterpret_cast(value.data_ptr()), \ + reinterpret_cast(key_cache.data_ptr()), \ + reinterpret_cast(value_cache.data_ptr()), \ + slot_mapping.data_ptr(), \ + key_stride, \ + value_stride, \ + num_heads, \ + head_size, \ + block_size, \ + x, \ kv_scale); void reshape_and_cache( @@ -285,25 +282,8 @@ void reshape_and_cache( dim3 block(std::min(num_heads * head_size, 512)); const at::cuda::OptionalCUDAGuard device_guard(device_of(key)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - if (kv_cache_dtype == "auto") { - if (key.dtype() == at::ScalarType::Float) { - CALL_RESHAPE_AND_CACHE(float, float, false); - } else if (key.dtype() == at::ScalarType::Half) { - CALL_RESHAPE_AND_CACHE(uint16_t, uint16_t, false); - } else if (key.dtype() == at::ScalarType::BFloat16) { - CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, false); - } - } else if (kv_cache_dtype == "fp8") { - if (key.dtype() == at::ScalarType::Float) { - CALL_RESHAPE_AND_CACHE(float, uint8_t, true); - } else if (key.dtype() == at::ScalarType::Half) { - CALL_RESHAPE_AND_CACHE(uint16_t, uint8_t, true); - } else if (key.dtype() == at::ScalarType::BFloat16) { - CALL_RESHAPE_AND_CACHE(__nv_bfloat16, uint8_t, true); - } - } else { - TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype); - } + + DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype, CALL_RESHAPE_AND_CACHE) } void reshape_and_cache_flash( @@ -353,35 +333,34 @@ void reshape_and_cache_flash( namespace vllm { -template +template __global__ void convert_fp8_kernel( const Tin* __restrict__ src_cache, Tout* __restrict__ dst_cache, + const float kv_scale, const int64_t block_stride) { const int64_t block_idx = blockIdx.x; for (int i = threadIdx.x; i < block_stride; i += blockDim.x) { int64_t idx = block_idx * block_stride + i; -#if defined(ENABLE_FP8_E5M2) - dst_cache[idx] = fp8_e5m2_unscaled::vec_conversion(src_cache[idx]); -#elif defined(ENABLE_FP8_E4M3) - dst_cache[idx] = fp8_e4m3::vec_conversion(src_cache[idx]); -#else - assert(false); -#endif + dst_cache[idx] = fp8::scaled_convert(src_cache[idx], kv_scale); } } } // namespace vllm -#define CALL_CONVERT_FP8(Tout, Tin) \ - vllm::convert_fp8_kernel<<>>( \ - reinterpret_cast(src_cache.data_ptr()), \ - reinterpret_cast(dst_cache.data_ptr()), \ +#define CALL_CONVERT_FP8(Tout, Tin, KV_DTYPE) \ + vllm::convert_fp8_kernel<<>>( \ + reinterpret_cast(src_cache.data_ptr()), \ + reinterpret_cast(dst_cache.data_ptr()), \ + kv_scale, \ block_stride); +// Only for testing. 
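The "Only for testing" helper below now takes the destination tensor first and carries the scale plus the kv-cache dtype string, matching the updated declaration in csrc/cache.h. A hedged usage sketch from the C++ side, assuming the op is linked into the translation unit; the tensor names and the example() wrapper are made up:

#include <string>
#include <torch/torch.h>

// Declaration as added to csrc/cache.h in this patch.
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
                 const float scale, const std::string& kv_cache_dtype);

void example(torch::Tensor& fp16_cache) {
  // fp16 -> fp8: the conversion divides by the scale before the cast.
  torch::Tensor fp8_cache = torch::empty_like(fp16_cache, torch::kUInt8);
  convert_fp8(fp8_cache, fp16_cache, /*scale=*/1.0f, /*kv_cache_dtype=*/"fp8");
  // fp8 -> fp16: multiplies by the same scale on the way back.
  torch::Tensor restored = torch::empty_like(fp16_cache);
  convert_fp8(restored, fp8_cache, /*scale=*/1.0f, /*kv_cache_dtype=*/"fp8");
}

The direction is picked from the tensor dtypes (whichever side is uint8 is the fp8 side), which is why the test updates later in the patch only swap the argument order.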
void convert_fp8( + torch::Tensor& dst_cache, torch::Tensor& src_cache, - torch::Tensor& dst_cache) + const float kv_scale, + const std::string& kv_cache_dtype) { torch::Device src_device = src_cache.device(); torch::Device dst_device = dst_cache.device(); @@ -399,17 +378,35 @@ void convert_fp8( dim3 block(std::min(block_stride, int64_t(512))); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - if (src_cache.dtype() == at::ScalarType::Float) { - CALL_CONVERT_FP8(uint8_t, float); - } else if (src_cache.dtype() == at::ScalarType::Half) { - CALL_CONVERT_FP8(uint8_t, uint16_t); - } else if (src_cache.dtype() == at::ScalarType::BFloat16) { - CALL_CONVERT_FP8(uint8_t, __nv_bfloat16); - } else if (dst_cache.dtype() == at::ScalarType::Float) { - CALL_CONVERT_FP8(float, uint8_t); - } else if (dst_cache.dtype() == at::ScalarType::Half) { - CALL_CONVERT_FP8(uint16_t, uint8_t); - } else if (dst_cache.dtype() == at::ScalarType::BFloat16) { - CALL_CONVERT_FP8(__nv_bfloat16, uint8_t); + if (kv_cache_dtype == "auto") { + if (src_cache.dtype() == at::ScalarType::Float) { + CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kAuto); + } else if (src_cache.dtype() == at::ScalarType::Half) { + CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); + } else if (src_cache.dtype() == at::ScalarType::BFloat16) { + CALL_CONVERT_FP8(uint8_t, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); + } else if (dst_cache.dtype() == at::ScalarType::Float) { + CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kAuto); + } else if (dst_cache.dtype() == at::ScalarType::Half) { + CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kAuto); + } else if (dst_cache.dtype() == at::ScalarType::BFloat16) { + CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kAuto); + } + } else if (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3") { + if (src_cache.dtype() == at::ScalarType::Float) { + CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kFp8E4M3); + } else if (src_cache.dtype() == at::ScalarType::Half) { + CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kFp8E4M3); + } else if (src_cache.dtype() == at::ScalarType::BFloat16) { + CALL_CONVERT_FP8(uint8_t, __nv_bfloat16, vllm::Fp8KVCacheDataType::kFp8E4M3); + } else if (dst_cache.dtype() == at::ScalarType::Float) { + CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); + } else if (dst_cache.dtype() == at::ScalarType::Half) { + CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); + } else if (dst_cache.dtype() == at::ScalarType::BFloat16) { + CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); + } + } else { + TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype); } } diff --git a/csrc/quantization/fp8/amd_detail/hip_float8.h b/csrc/quantization/fp8/amd/hip_float8.h similarity index 100% rename from csrc/quantization/fp8/amd_detail/hip_float8.h rename to csrc/quantization/fp8/amd/hip_float8.h diff --git a/csrc/quantization/fp8/amd_detail/hip_float8_impl.h b/csrc/quantization/fp8/amd/hip_float8_impl.h similarity index 100% rename from csrc/quantization/fp8/amd_detail/hip_float8_impl.h rename to csrc/quantization/fp8/amd/hip_float8_impl.h diff --git a/csrc/quantization/fp8/amd_detail/quant_utils.cuh b/csrc/quantization/fp8/amd/quant_utils.cuh similarity index 81% rename from csrc/quantization/fp8/amd_detail/quant_utils.cuh rename to csrc/quantization/fp8/amd/quant_utils.cuh index 894160972d9f4..df0329f79d361 100644 --- 
a/csrc/quantization/fp8/amd_detail/quant_utils.cuh +++ b/csrc/quantization/fp8/amd/quant_utils.cuh @@ -5,12 +5,17 @@ #include #include +#include "../../../attention/dtype_fp8.cuh" #include "../../../attention/dtype_float32.cuh" #include "../../../attention/dtype_bfloat16.cuh" namespace vllm { -namespace fp8_e4m3 { +#ifdef USE_ROCM + +namespace fp8 { +#ifdef ENABLE_FP8 + template __inline__ __device__ Tout vec_conversion(const Tin& x) { @@ -512,6 +517,58 @@ __inline__ __device__ float4 scaled_vec_conversion(const uint3 float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); return res; } +#endif // ENABLE_FP8 +template +__inline__ __device__ Tout convert(const Tin &x) { +#ifdef ENABLE_FP8 + if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) { + return vec_conversion(x); + } +#endif + assert(false); } + +template +__inline__ __device__ Tout scaled_convert(const Tin &x, const float scale) { +#ifdef ENABLE_FP8 + if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) { + return scaled_vec_conversion(x, scale); + } +#endif + assert(false); +} + +// The following macro is used to dispatch the conversion function based on the +// data type of the key and value cache. The FN is a macro that calls a function +// with template. +#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \ + if (KV_DTYPE == "auto") { \ + if (SRC_DTYPE == at::ScalarType::Float) { \ + FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \ + } else if (SRC_DTYPE == at::ScalarType::Half) { \ + FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); \ + } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ + FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); \ + } else { \ + TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \ + } \ + } else { \ + if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \ + if (SRC_DTYPE == at::ScalarType::Float) { \ + FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else if (SRC_DTYPE == at::ScalarType::Half) { \ + FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ + FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else { \ + TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \ + } \ + } else { \ + TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \ + } \ + } + +} // fp8 +#endif // USE_ROCM } // namespace vllm diff --git a/csrc/quantization/fp8/fp8_cuda_kernels.cu b/csrc/quantization/fp8/common.cu similarity index 100% rename from csrc/quantization/fp8/fp8_cuda_kernels.cu rename to csrc/quantization/fp8/common.cu diff --git a/csrc/quantization/fp8/nvidia/quant_utils.cuh b/csrc/quantization/fp8/nvidia/quant_utils.cuh new file mode 100644 index 0000000000000..4eeacf7a6f9d9 --- /dev/null +++ b/csrc/quantization/fp8/nvidia/quant_utils.cuh @@ -0,0 +1,568 @@ +#pragma once + +#include "../../../attention/attention_dtypes.h" +#include +#include +#include +#include + +namespace vllm { +#ifndef USE_ROCM + +namespace fp8 { +#ifdef ENABLE_FP8 + +#if 0 // Disable the following code to reduce the binary size. 
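DISPATCH_BY_KV_CACHE_DTYPE above turns the (torch dtype, kv-cache dtype string) pair into template arguments and hands them to whatever FN macro the caller supplies, e.g. CALL_V1_LAUNCHER_BLOCK_SIZE in attention_kernels.cu. A very small host-only mock of how the two macro layers compose; it keeps only the float source path and a printf launcher, so it is a sketch of the shape, not the real dispatch:

#include <cstdint>
#include <cstdio>
#include <string>

enum class Fp8KVCacheDataType { kAuto, kFp8E4M3, kFp8E5M2 };

template <typename T, typename CACHE_T, int BLOCK_SIZE, Fp8KVCacheDataType KV_DTYPE>
void paged_attention_v1_launcher() {  // placeholder body
  std::printf("launch: block_size=%d kv_dtype=%d\n", BLOCK_SIZE, static_cast<int>(KV_DTYPE));
}

#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE) \
  paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE>();

#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE)        \
  switch (block_size) {                                          \
    case 8:  CALL_V1_LAUNCHER(T, CACHE_T, 8,  KV_DTYPE); break;  \
    case 16: CALL_V1_LAUNCHER(T, CACHE_T, 16, KV_DTYPE); break;  \
    case 32: CALL_V1_LAUNCHER(T, CACHE_T, 32, KV_DTYPE); break;  \
    default: break;                                              \
  }

// Simplified: the real macro also branches on the torch dtype (Half, BFloat16, ...).
#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN)  \
  if (KV_DTYPE == "auto") {                                  \
    FN(float, float, Fp8KVCacheDataType::kAuto);             \
  } else if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") {  \
    FN(float, uint8_t, Fp8KVCacheDataType::kFp8E4M3);        \
  } else if (KV_DTYPE == "fp8_e5m2") {                       \
    FN(float, uint8_t, Fp8KVCacheDataType::kFp8E5M2);        \
  }

int main() {
  int block_size = 16;
  std::string kv_cache_dtype = "fp8";
  DISPATCH_BY_KV_CACHE_DTYPE(0, kv_cache_dtype, CALL_V1_LAUNCHER_BLOCK_SIZE)
  return 0;
}

Keeping the dispatch as a macro rather than a function is what lets the chosen types and the block size reach the launcher as template arguments.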
+template +__inline__ __device__ Tout +vec_conversion(const Tin &x, const __nv_fp8_interpretation_t fp8_type) { + return x; +} + +// fp8 -> half +template <> +__inline__ __device__ uint16_t vec_conversion( + const uint8_t &a, const __nv_fp8_interpretation_t fp8_type) { + __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type); + return res.x; +} + +// fp8x2 -> half2 +template <> +__inline__ __device__ uint32_t vec_conversion( + const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) { + union { + uint16_t u16[2]; + uint32_t u32; + } tmp; + __half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, fp8_type); + tmp.u16[0] = res.x; + tmp.u16[1] = res.y; + return tmp.u32; +} + +// fp8x4 -> half2x2 +template <> +__inline__ __device__ uint2 vec_conversion( + const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) { + union { + uint2 u32x2; + uint32_t u32[2]; + } tmp; + tmp.u32[0] = vec_conversion((uint16_t)a, fp8_type); + tmp.u32[1] = + vec_conversion((uint16_t)(a >> 16U), fp8_type); + return tmp.u32x2; +} + +// fp8x8 -> half2x4 +template <> +__inline__ __device__ uint4 vec_conversion( + const uint2 &a, const __nv_fp8_interpretation_t fp8_type) { + union { + uint4 u64x2; + uint2 u64[2]; + } tmp; + tmp.u64[0] = vec_conversion(a.x, fp8_type); + tmp.u64[1] = vec_conversion(a.y, fp8_type); + return tmp.u64x2; +} + +// fp8 -> __nv_bfloat16 +template <> +__inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>( + const uint8_t &a, const __nv_fp8_interpretation_t fp8_type) { + // Note there is no direct convert function from fp8 to bf16. + // fp8 -> half + __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type); + // half -> float -> bf16 + float tmp = half_to_float(res.x); + return __float2bfloat16(tmp); +} + +// fp8x2 -> __nv_bfloat162 +template <> +__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>( + const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) { + __nv_bfloat162 res; + res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, fp8_type); + res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), fp8_type); + return res; +} + +// fp8x4 -> bf16_4_t +template <> +__inline__ __device__ bf16_4_t vec_conversion( + const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) { + bf16_4_t res; + res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, fp8_type); + res.y = + vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U), fp8_type); + return res; +} + +// fp8x8 -> bf16_8_t +template <> +__inline__ __device__ bf16_8_t vec_conversion( + const uint2 &a, const __nv_fp8_interpretation_t fp8_type) { + bf16_4_t tmp1, tmp2; + tmp1 = vec_conversion(a.x, fp8_type); + tmp2 = vec_conversion(a.y, fp8_type); + bf16_8_t res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; +} + +// fp8 -> float +template <> +__inline__ __device__ float +vec_conversion(const uint8_t &a, + const __nv_fp8_interpretation_t fp8_type) { + // fp8 -> half + uint16_t tmp = vec_conversion(a, fp8_type); + // half -> float + return half_to_float(tmp); +} + +// fp8x2 -> float2 +template <> +__inline__ __device__ float2 vec_conversion( + const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) { + // fp8x2 -> half2 + uint32_t tmp = vec_conversion(a, fp8_type); + // half2 -> float2 + return half2_to_float2(tmp); +} + +// fp8x4 -> float4 +template <> +__inline__ __device__ Float4_ vec_conversion( + const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) { + Float4_ res; + res.x = vec_conversion((uint16_t)a, fp8_type); + 
res.y = vec_conversion((uint16_t)(a >> 16U), fp8_type); + return res; +} + +// fp8x8 -> float8 +template <> +__inline__ __device__ Float8_ vec_conversion( + const uint2 &a, const __nv_fp8_interpretation_t fp8_type) { + Float4_ tmp1, tmp2; + tmp1 = vec_conversion(a.x, fp8_type); + tmp2 = vec_conversion(a.y, fp8_type); + Float8_ res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; +} + +// half -> fp8 +template <> +__inline__ __device__ uint8_t vec_conversion( + const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) { + __half_raw tmp; + tmp.x = a; + __nv_fp8_storage_t res = + __nv_cvt_halfraw_to_fp8(tmp, __NV_SATFINITE, fp8_type); + return (uint8_t)res; +} + +// bf16 -> fp8 +template <> +__inline__ __device__ uint8_t vec_conversion( + const __nv_bfloat16 &a, const __nv_fp8_interpretation_t fp8_type) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else + __nv_fp8_storage_t res = __nv_cvt_bfloat16raw_to_fp8( + __nv_bfloat16_raw(a), __NV_SATFINITE, fp8_type); + return (uint8_t)res; +#endif +} + +// float -> fp8 +template <> +__inline__ __device__ uint8_t vec_conversion( + const float &a, const __nv_fp8_interpretation_t fp8_type) { + __nv_fp8_storage_t res = __nv_cvt_float_to_fp8(a, __NV_SATFINITE, fp8_type); + return (uint8_t)res; +} + +// fp8x4 -> float4 +template <> +__inline__ __device__ float4 vec_conversion( + const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) { + Float4_ tmp = vec_conversion(a, fp8_type); + float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); + return res; +} + +template <> +__inline__ __device__ uint32_t vec_conversion( + const float2 &a, const __nv_fp8_interpretation_t fp8_type) { + union { + half2 float16; + uint32_t uint32; + }; + + float16 = __float22half2_rn(a); + return uint32; +} + +template <> +__inline__ __device__ uint2 vec_conversion( + const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) { + uint2 b; + float2 val; + val.x = a.x.x; + val.y = a.x.y; + b.x = vec_conversion(val, fp8_type); + + val.x = a.y.x; + val.y = a.y.y; + b.y = vec_conversion(val, fp8_type); + + return b; +} + +template <> +__inline__ __device__ float4 vec_conversion( + const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) { + float4 b; + b.x = a.x.x; + b.y = a.x.y; + b.z = a.y.x; + b.w = a.y.y; + return b; +} + +template <> +__inline__ __device__ uint4 vec_conversion( + const Float8_ &a, const __nv_fp8_interpretation_t fp8_type) { + uint4 b; + b.x = vec_conversion(a.x, fp8_type); + b.y = vec_conversion(a.y, fp8_type); + b.z = vec_conversion(a.z, fp8_type); + b.w = vec_conversion(a.w, fp8_type); + return b; +} + +template <> +__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>( + const float2 &a, const __nv_fp8_interpretation_t fp8_type) { + __nv_bfloat162 b; + from_float(b, a); + return b; +} + +template <> +__inline__ __device__ bf16_4_t vec_conversion( + const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) { + bf16_4_t b; + from_float(b, a); + return b; +} + +template <> +__inline__ __device__ bf16_8_t vec_conversion( + const Float8_ &a, const __nv_fp8_interpretation_t fp8_type) { + bf16_8_t b; + from_float(b, a); + return b; +} +#endif + +/* Scaled and vectorized conversions, for data exchange between high and low + precision domains Convention of the scale in API, e.g: FP8_data = + Quantization( High_Precision_data / scale ) s.t. 
Quantize(HP / scale) => FP8 + Dequant(FP8) * scale => HP + */ + +template +__inline__ __device__ Tout scaled_vec_conversion( + const Tin &x, const float scale, const __nv_fp8_interpretation_t fp8_type) { + return x; +} + +// fp8 -> half +template <> +__inline__ __device__ uint16_t scaled_vec_conversion( + const uint8_t &a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + __half_raw tmp = __nv_cvt_fp8_to_halfraw(a, fp8_type); + return float_to_half(half_to_float(tmp.x) * scale); +} + +// fp8x2 -> half2 +template <> +__inline__ __device__ uint32_t scaled_vec_conversion( + const uint16_t &a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + union { + uint16_t u16[2]; + uint32_t u32; + } tmp; + __half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, fp8_type); + tmp.u16[0] = float_to_half(half_to_float(res.x) * scale); + tmp.u16[1] = float_to_half(half_to_float(res.y) * scale); + return tmp.u32; +} + +// fp8x4 -> half2x2 +template <> +__inline__ __device__ uint2 scaled_vec_conversion( + const uint32_t &a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + union { + uint2 u32x2; + uint32_t u32[2]; + } tmp; + tmp.u32[0] = + scaled_vec_conversion((uint16_t)a, scale, fp8_type); + tmp.u32[1] = scaled_vec_conversion((uint16_t)(a >> 16U), + scale, fp8_type); + return tmp.u32x2; +} + +// fp8x8 -> half2x4 +template <> +__inline__ __device__ uint4 +scaled_vec_conversion(const uint2 &a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + union { + uint4 u64x2; + uint2 u64[2]; + } tmp; + tmp.u64[0] = scaled_vec_conversion(a.x, scale, fp8_type); + tmp.u64[1] = scaled_vec_conversion(a.y, scale, fp8_type); + return tmp.u64x2; +} + +// fp8 -> __nv_bfloat16 +template <> +__inline__ __device__ __nv_bfloat16 +scaled_vec_conversion<__nv_bfloat16, uint8_t>( + const uint8_t &a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + // Note there is no direct convert function from fp8 to bf16. 
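The packed variants above peel wider words into the single-value conversion by shifting: a fp8x2 word is split into its low and high byte, a fp8x4 word into two 16-bit halves, and so on. A host-side sketch of the peeling pattern, with a placeholder for the actual fp8-to-float step:

#include <cstdint>
#include <cstdio>

// Placeholder for the real per-byte conversion (__nv_cvt_fp8_to_halfraw plus the scale).
float convert_one(uint8_t byte, float scale) { return static_cast<float>(byte) * scale; }

// fp8x2 -> two floats, mirroring the (uint8_t)a / (uint8_t)(a >> 8U) split in the header.
void convert_two(uint16_t a, float scale, float out[2]) {
  out[0] = convert_one(static_cast<uint8_t>(a), scale);
  out[1] = convert_one(static_cast<uint8_t>(a >> 8U), scale);
}

// fp8x4 -> four floats, via the two 16-bit halves (uint16_t)a and (uint16_t)(a >> 16U).
void convert_four(uint32_t a, float scale, float out[4]) {
  convert_two(static_cast<uint16_t>(a), scale, out);
  convert_two(static_cast<uint16_t>(a >> 16U), scale, out + 2);
}

int main() {
  float out[4];
  convert_four(0x04030201u, 1.0f, out);
  std::printf("%g %g %g %g\n", out[0], out[1], out[2], out[3]);  // 1 2 3 4
  return 0;
}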
+ // fp8 -> half + __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type); + // half -> float -> bf16 + float tmp = half_to_float(res.x); + return __float2bfloat16(tmp * scale); +} + +// fp8x2 -> __nv_bfloat162 +template <> +__inline__ __device__ __nv_bfloat162 +scaled_vec_conversion<__nv_bfloat162, uint16_t>( + const uint16_t &a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + __nv_bfloat162 res; + res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale, + fp8_type); + res.y = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), + scale, fp8_type); + return res; +} + +// fp8x4 -> bf16_4_t +template <> +__inline__ __device__ bf16_4_t scaled_vec_conversion( + const uint32_t &a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + bf16_4_t res; + res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale, + fp8_type); + res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U), + scale, fp8_type); + return res; +} + +// fp8x8 -> bf16_8_t +template <> +__inline__ __device__ bf16_8_t scaled_vec_conversion( + const uint2 &a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + bf16_4_t tmp1, tmp2; + tmp1 = scaled_vec_conversion(a.x, scale, fp8_type); + tmp2 = scaled_vec_conversion(a.y, scale, fp8_type); + bf16_8_t res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; +} + +// fp8 -> float +template <> +__inline__ __device__ float scaled_vec_conversion( + const uint8_t &a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + + // fp8 -> half + __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type); + uint16_t tmp = res.x; + + // half -> float + return half_to_float(tmp) * scale; +} + +// fp8x2 -> float2 +template <> +__inline__ __device__ float2 scaled_vec_conversion( + const uint16_t &a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + // fp8x2 -> half2 + uint32_t tmp = scaled_vec_conversion(a, scale, fp8_type); + // half2 -> float2 + return half2_to_float2(tmp); +} + +// fp8x4 -> float4 +template <> +__inline__ __device__ Float4_ scaled_vec_conversion( + const uint32_t &a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + Float4_ res; + res.x = scaled_vec_conversion((uint16_t)a, scale, fp8_type); + res.y = scaled_vec_conversion((uint16_t)(a >> 16U), scale, + fp8_type); + return res; +} + +// fp8x8 -> float8 +template <> +__inline__ __device__ Float8_ scaled_vec_conversion( + const uint2 &a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + Float4_ tmp1, tmp2; + tmp1 = scaled_vec_conversion(a.x, scale, fp8_type); + tmp2 = scaled_vec_conversion(a.y, scale, fp8_type); + Float8_ res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; +} + +// half -> fp8 +template <> +__inline__ __device__ uint8_t scaled_vec_conversion( + const uint16_t &a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + __nv_fp8_storage_t res = + __nv_cvt_float_to_fp8(half_to_float(a) / scale, __NV_SATFINITE, fp8_type); + return (uint8_t)res; +} + +// bf16 -> fp8 +template <> +__inline__ __device__ uint8_t scaled_vec_conversion( + const __nv_bfloat16 &a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else + __nv_fp8_storage_t res = __nv_cvt_float_to_fp8(__bfloat162float(a) / scale, + __NV_SATFINITE, fp8_type); + return (uint8_t)res; +#endif +} + +// float -> fp8 +template <> 
+__inline__ __device__ uint8_t scaled_vec_conversion( + const float &a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + __nv_fp8_storage_t res = + __nv_cvt_float_to_fp8(a / scale, __NV_SATFINITE, fp8_type); + return (uint8_t)res; +} + +// fp8x4 -> float4 +template <> +__inline__ __device__ float4 scaled_vec_conversion( + const uint32_t &a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + Float4_ tmp = scaled_vec_conversion(a, scale, fp8_type); + float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); + return res; +} +#endif // ENABLE_FP8 + +template +__inline__ __device__ Tout convert(const Tin &x) { +#if 0 // Disable the following code to reduce the binary size. + if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) { + return vec_conversion(x, __NV_E4M3); + } else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) { + return vec_conversion(x, __NV_E5M2); + } +#endif + assert(false); +} + +template +__inline__ __device__ Tout scaled_convert(const Tin &x, const float scale) { +#ifdef ENABLE_FP8 + if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) { + return scaled_vec_conversion(x, scale, __NV_E4M3); + } else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) { + return scaled_vec_conversion(x, scale, __NV_E5M2); + } +#endif + assert(false); +} + +// The following macro is used to dispatch the conversion function based on the +// data type of the key and value cache. The FN is a macro that calls a function +// with template. +#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \ + if (KV_DTYPE == "auto") { \ + if (SRC_DTYPE == at::ScalarType::Float) { \ + FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \ + } else if (SRC_DTYPE == at::ScalarType::Half) { \ + FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); \ + } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ + FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); \ + } else { \ + TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \ + } \ + } else { \ + if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \ + if (SRC_DTYPE == at::ScalarType::Float) { \ + FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else if (SRC_DTYPE == at::ScalarType::Half) { \ + FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ + FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else { \ + TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \ + } \ + } else if (KV_DTYPE == "fp8_e5m2") { \ + if (SRC_DTYPE == at::ScalarType::Float) { \ + FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \ + } else if (SRC_DTYPE == at::ScalarType::Half) { \ + FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \ + } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ + FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \ + } else { \ + TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \ + } \ + } else { \ + TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \ + } \ + } + +} // namespace fp8 +#endif // not USE_ROCM +} // namespace vllm diff --git a/csrc/quantization/fp8_e5m2_kvcache/quant_utils.cuh b/csrc/quantization/fp8_e5m2_kvcache/quant_utils.cuh deleted file mode 100644 index 9bcab25db03cf..0000000000000 --- a/csrc/quantization/fp8_e5m2_kvcache/quant_utils.cuh +++ /dev/null @@ -1,277 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include "../../attention/attention_dtypes.h" 
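On the quantize side these specializations divide by the scale and then do a saturating-finite cast (__NV_SATFINITE), so the scale is what keeps values inside fp8's narrow range (E4M3 saturates at 448, E5M2 at 57344). The patch itself only plumbs kv_scale through; how a caller might pick a scale is policy outside this diff, but a worked example makes the arithmetic concrete (everything below is illustrative):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  // Pick kv_scale so that max|x| / kv_scale lands at E4M3's largest finite value (448).
  std::vector<float> values = {-900.0f, 13.5f, 448.0f, 700.0f};
  float max_abs = 0.0f;
  for (float v : values) max_abs = std::max(max_abs, std::fabs(v));
  const float e4m3_max = 448.0f;
  const float kv_scale = max_abs / e4m3_max;  // 900 / 448 ~= 2.01
  for (float v : values) {
    float to_cast = v / kv_scale;             // what would be handed to the saturating fp8 cast
    std::printf("%7.1f -> %7.2f\n", v, to_cast);
  }
  return 0;
}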
-#include "../../attention/dtype_float32.cuh" -#include "../../attention/dtype_float16.cuh" -#include "../../attention/dtype_bfloat16.cuh" - - -namespace vllm { -#ifdef ENABLE_FP8_E5M2 -namespace fp8_e5m2_unscaled { - -template -__inline__ __device__ Tout vec_conversion(const Tin& x) -{ - return x; -} - -// fp8 -> half -template<> -__inline__ __device__ uint16_t vec_conversion(const uint8_t& a) -{ - __half_raw res = __nv_cvt_fp8_to_halfraw(a, __NV_E5M2); - return res.x; -} - -// fp8x2 -> half2 -template<> -__inline__ __device__ uint32_t vec_conversion(const uint16_t& a) -{ - union { - uint16_t u16[2]; - uint32_t u32; - } tmp; - __half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, __NV_E5M2); - tmp.u16[0] = res.x; - tmp.u16[1] = res.y; - return tmp.u32; -} - -// fp8x4 -> half2x2 -template<> -__inline__ __device__ uint2 vec_conversion(const uint32_t& a) -{ - union { - uint2 u32x2; - uint32_t u32[2]; - } tmp; - tmp.u32[0] = vec_conversion((uint16_t)a); - tmp.u32[1] = vec_conversion((uint16_t)(a >> 16U)); - return tmp.u32x2; -} - -// fp8x8 -> half2x4 -template<> -__inline__ __device__ uint4 vec_conversion(const uint2& a) -{ - union { - uint4 u64x2; - uint2 u64[2]; - } tmp; - tmp.u64[0] = vec_conversion(a.x); - tmp.u64[1] = vec_conversion(a.y); - return tmp.u64x2; -} - -// fp8 -> __nv_bfloat16 -template<> -__inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) -{ - // Note there is no direct convert function from fp8 to bf16. - // fp8 -> half - __half_raw res = __nv_cvt_fp8_to_halfraw(a, __NV_E5M2); - // half -> float -> bf16 - float tmp = half_to_float(res.x); - return __float2bfloat16(tmp); -} - -// fp8x2 -> __nv_bfloat162 -template<> -__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a) -{ - __nv_bfloat162 res; - res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a); - res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U)); - return res; -} - -// fp8x4 -> bf16_4_t -template<> -__inline__ __device__ bf16_4_t vec_conversion(const uint32_t& a) -{ - bf16_4_t res; - res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a); - res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U)); - return res; -} - -// fp8x8 -> bf16_8_t -template<> -__inline__ __device__ bf16_8_t vec_conversion(const uint2& a) -{ - bf16_4_t tmp1, tmp2; - tmp1 = vec_conversion(a.x); - tmp2 = vec_conversion(a.y); - bf16_8_t res; - res.x = tmp1.x; - res.y = tmp1.y; - res.z = tmp2.x; - res.w = tmp2.y; - return res; -} - -// fp8 -> float -template<> -__inline__ __device__ float vec_conversion(const uint8_t& a) -{ - // fp8 -> half - uint16_t tmp = vec_conversion(a); - // half -> float - return half_to_float(tmp); -} - -// fp8x2 -> float2 -template<> -__inline__ __device__ float2 vec_conversion(const uint16_t& a) -{ - // fp8x2 -> half2 - uint32_t tmp = vec_conversion(a); - // half2 -> float2 - return half2_to_float2(tmp); -} - -// fp8x4 -> float4 -template<> -__inline__ __device__ Float4_ vec_conversion(const uint32_t& a) -{ - Float4_ res; - res.x = vec_conversion((uint16_t)a); - res.y = vec_conversion((uint16_t)(a >> 16U)); - return res; -} - -// fp8x8 -> float8 -template<> -__inline__ __device__ Float8_ vec_conversion(const uint2& a) -{ - Float4_ tmp1, tmp2; - tmp1 = vec_conversion(a.x); - tmp2 = vec_conversion(a.y); - Float8_ res; - res.x = tmp1.x; - res.y = tmp1.y; - res.z = tmp2.x; - res.w = tmp2.y; - return res; -} - - -// half -> fp8 -template<> -__inline__ __device__ uint8_t vec_conversion(const uint16_t& a) -{ - 
__half_raw tmp; - tmp.x = a; - __nv_fp8_storage_t res = __nv_cvt_halfraw_to_fp8(tmp, __NV_SATFINITE, __NV_E5M2); - return (uint8_t)res; -} - -// bf16 -> fp8 -template<> -__inline__ __device__ uint8_t vec_conversion(const __nv_bfloat16& a) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - assert(false); -#else - __nv_fp8_storage_t res = __nv_cvt_bfloat16raw_to_fp8(__nv_bfloat16_raw(a), __NV_SATFINITE, __NV_E5M2); - return (uint8_t)res; -#endif -} - -// float -> fp8 -template<> -__inline__ __device__ uint8_t vec_conversion(const float& a) -{ - __nv_fp8_storage_t res = __nv_cvt_float_to_fp8(a, __NV_SATFINITE, __NV_E5M2); - return (uint8_t)res; -} - -// fp8x4 -> float4 -template<> -__inline__ __device__ float4 vec_conversion(const uint32_t& a) -{ - Float4_ tmp = vec_conversion(a); - float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); - return res; -} - - -template<> -__inline__ __device__ uint32_t vec_conversion(const float2& a) -{ - union { - half2 float16; - uint32_t uint32; - }; - - float16 = __float22half2_rn(a); - return uint32; -} - -template<> -__inline__ __device__ uint2 vec_conversion(const Float4_& a) -{ - uint2 b; - float2 val; - val.x = a.x.x; - val.y = a.x.y; - b.x = vec_conversion(val); - - val.x = a.y.x; - val.y = a.y.y; - b.y = vec_conversion(val); - - return b; -} - -template<> -__inline__ __device__ float4 vec_conversion(const Float4_& a) -{ - float4 b; - b.x = a.x.x; - b.y = a.x.y; - b.z = a.y.x; - b.w = a.y.y; - return b; -} - -template<> -__inline__ __device__ uint4 vec_conversion(const Float8_& a) -{ - uint4 b; - b.x = vec_conversion(a.x); - b.y = vec_conversion(a.y); - b.z = vec_conversion(a.z); - b.w = vec_conversion(a.w); - return b; -} - -template<> -__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(const float2 &a) { - __nv_bfloat162 b; - from_float(b, a); - return b; -} - -template<> -__inline__ __device__ bf16_4_t vec_conversion(const Float4_ &a) { - bf16_4_t b; - from_float(b, a); - return b; -} - -template<> -__inline__ __device__ bf16_8_t vec_conversion(const Float8_ &a) { - bf16_8_t b; - from_float(b, a); - return b; -} - -} // namespace fp8_e5m2_unscaled -#endif // ENABLE_FP8_E5M2 -} // namespace vllm diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 84539205e0ae3..28496f187d466 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -236,14 +236,14 @@ def test_paged_attention( dequantized_key_cache = torch.empty(size=key_cache_shape, dtype=dtype, device=device) - ops.convert_fp8(key_cache, dequantized_key_cache) + ops.convert_fp8(dequantized_key_cache, key_cache) key_cache = dequantized_key_cache value_cache_shape = value_cache.shape dequantized_value_cache = torch.empty(size=value_cache_shape, dtype=dtype, device=device) - ops.convert_fp8(value_cache, dequantized_value_cache) + ops.convert_fp8(dequantized_value_cache, value_cache) value_cache = dequantized_value_cache ref_output = torch.empty_like(query) diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index 4cae15c79c489..9f0cb60dc16e2 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -5,8 +5,6 @@ import pytest import torch from vllm import _custom_ops as ops -from vllm._C import cache_ops -from vllm.utils import is_hip COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')] DTYPES = [torch.half, torch.bfloat16, torch.float] @@ -25,6 +23,8 @@ SEEDS = [0] CUDA_DEVICES = [ f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 
else 2) ] + +# We assume fp8 is always enabled for testing. KV_CACHE_DTYPE = ["auto", "fp8"] @@ -124,8 +124,6 @@ def test_reshape_and_cache( device: str, kv_cache_dtype: str, ) -> None: - if not is_hip() and kv_cache_dtype == "fp8": - pytest.skip() # This test is not tuned for e5m2 cuda precision random.seed(seed) torch.random.manual_seed(seed) if torch.cuda.is_available(): @@ -149,9 +147,9 @@ def test_reshape_and_cache( # Clone the KV caches. if kv_cache_dtype == "fp8": cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16) - ops.convert_fp8(key_cache, cloned_key_cache) + ops.convert_fp8(cloned_key_cache, key_cache) cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16) - ops.convert_fp8(value_cache, cloned_value_cache) + ops.convert_fp8(cloned_value_cache, value_cache) else: cloned_key_cache = key_cache.clone() cloned_value_cache = value_cache.clone() @@ -165,9 +163,9 @@ def test_reshape_and_cache( if kv_cache_dtype == "fp8": result_key_cache = torch.empty_like(key_cache, dtype=torch.float16) - ops.convert_fp8(key_cache, result_key_cache) + ops.convert_fp8(result_key_cache, key_cache) result_value_cache = torch.empty_like(value_cache, dtype=torch.float16) - ops.convert_fp8(value_cache, result_value_cache) + ops.convert_fp8(result_value_cache, value_cache) # Run the reference implementation. reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape) @@ -255,8 +253,8 @@ def test_reshape_and_cache_flash( cloned_value_cache = value_cache.clone() # Call the reshape_and_cache kernel. - cache_ops.reshape_and_cache_flash(key, value, key_cache, value_cache, - slot_mapping, kv_cache_dtype) + ops.reshape_and_cache_flash(key, value, key_cache, value_cache, + slot_mapping, kv_cache_dtype) # Run the reference implementation. 
block_indicies = torch.div(slot_mapping, block_size, rounding_mode='floor') @@ -299,8 +297,6 @@ def test_swap_blocks( ) -> None: if kv_cache_dtype == "fp8" and "cpu" in direction: pytest.skip() - if not is_hip() and kv_cache_dtype == "fp8": - pytest.skip() # This test is not tuned for e5m2 cuda precision random.seed(seed) torch.random.manual_seed(seed) if torch.cuda.is_available(): @@ -348,7 +344,6 @@ def test_swap_blocks( dist_value_caches[0][dst].cpu()) -@pytest.mark.skipif(not is_hip(), reason="FP8 conversion test requires e4m3") @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("block_size", BLOCK_SIZES) @@ -357,7 +352,7 @@ def test_swap_blocks( @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) @torch.inference_mode() -def test_fp8_conversion( +def test_fp8_e4m3_conversion( num_heads: int, head_size: int, block_size: int, @@ -377,9 +372,9 @@ def test_fp8_conversion( cache.uniform_(low, high) cache_fp8 = torch.empty_like(cache, dtype=torch.uint8) - ops.convert_fp8(cache, cache_fp8) + ops.convert_fp8(cache_fp8, cache) converted_cache = torch.empty_like(cache) - ops.convert_fp8(cache_fp8, converted_cache) + ops.convert_fp8(converted_cache, cache_fp8) assert torch.allclose(cache, converted_cache, atol=0.001, rtol=0.1) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 829c47003ad0e..35a9f6329fc42 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -270,8 +270,11 @@ def swap_blocks(src: torch.Tensor, dst: torch.Tensor, vllm_cache_ops.swap_blocks(src, dst, block_mapping) -def convert_fp8(output: torch.Tensor, input: torch.Tensor) -> None: - vllm_cache_ops.convert_fp8(output, input) +def convert_fp8(output: torch.Tensor, + input: torch.Tensor, + scale: float = 1.0, + kv_dtype: str = "fp8") -> None: + vllm_cache_ops.convert_fp8(output, input, scale, kv_dtype) #TODO: cuda_utils, custom_ar diff --git a/vllm/utils.py b/vllm/utils.py index 6479a8dab320a..f0e71f5e99b64 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -329,7 +329,7 @@ def _generate_random_fp8( from vllm import _custom_ops as ops tensor_tmp = torch.empty_like(tensor, dtype=torch.float16) tensor_tmp.uniform_(low, high) - ops.convert_fp8(tensor_tmp, tensor) + ops.convert_fp8(tensor, tensor_tmp) del tensor_tmp
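One behavioural point made visible by the test changes above: the plain "fp8" string now maps to E4M3 on both back ends, and an explicit "fp8_e5m2" is needed for E5M2 (only the NVIDIA dispatch accepts it; convert_fp8 rejects it). The string-to-format mapping used by the NVIDIA DISPATCH_BY_KV_CACHE_DTYPE macro, written out as an ordinary function for reference; the helper name is made up, and the real code stays macro-based so the result can be used as a template argument:

#include <stdexcept>
#include <string>

enum class Fp8KVCacheDataType { kAuto, kFp8E4M3, kFp8E5M2 };

Fp8KVCacheDataType parse_kv_cache_dtype(const std::string& s) {
  if (s == "auto")                   return Fp8KVCacheDataType::kAuto;
  if (s == "fp8" || s == "fp8_e4m3") return Fp8KVCacheDataType::kFp8E4M3;
  if (s == "fp8_e5m2")               return Fp8KVCacheDataType::kFp8E5M2;  // NVIDIA dispatch only
  throw std::invalid_argument("Unsupported data type of kv cache: " + s);
}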