diff --git a/CMakeLists.txt b/CMakeLists.txt
index 5ebea1c42e9a..66967b655a1a 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -269,8 +269,8 @@ set(VLLM_EXT_SRC
   "csrc/sampler.cu"
   "csrc/cuda_view.cu"
   "csrc/quantization/gptq/q_gemm.cu"
-  "csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
-  "csrc/quantization/fp8/common.cu"
+  "csrc/quantization/w8a8/int8/scaled_quant.cu"
+  "csrc/quantization/w8a8/fp8/common.cu"
   "csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
   "csrc/quantization/gguf/gguf_kernel.cu"
   "csrc/quantization/activation_kernels.cu"
@@ -314,12 +314,13 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   list(APPEND VLLM_EXT_SRC
     "csrc/quantization/awq/gemm_kernels.cu"
     "csrc/permute_cols.cu"
-    "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
+    "csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu"
     "csrc/quantization/fp4/nvfp4_quant_entry.cu"
     "csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu"
     "csrc/sparse/cutlass/sparse_scaled_mm_entry.cu"
     "csrc/cutlass_extensions/common.cpp"
-    "csrc/quantization/fp8/per_token_group_quant.cu")
+    "csrc/quantization/w8a8/fp8/per_token_group_quant.cu"
+    "csrc/quantization/w8a8/int8/per_token_group_quant.cu")

   set_gencode_flags_for_srcs(
     SRCS "${VLLM_EXT_SRC}"
@@ -423,11 +424,11 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}")
   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND SCALED_MM_ARCHS)
     set(SRCS
-      "csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu"
-      "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu"
-      "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu"
-      "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_azp_sm90_int8.cu"
-      "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu")
+      "csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm90.cu"
+      "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_fp8.cu"
+      "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_int8.cu"
+      "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_azp_sm90_int8.cu"
+      "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8.cu")
     set_gencode_flags_for_srcs(
       SRCS "${SRCS}"
       CUDA_ARCHS "${SCALED_MM_ARCHS}")
@@ -458,9 +459,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   endif()
   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
     set(SRCS
-      "csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm120.cu"
-      "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm120_fp8.cu"
-      "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8.cu"
+      "csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm120.cu"
+      "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8.cu"
+      "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8.cu"
     )
     set_gencode_flags_for_srcs(
       SRCS "${SRCS}"
@@ -492,9 +493,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   endif()
   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
     set(SRCS
-      "csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu"
-      "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu"
-      "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu"
+      "csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm100.cu"
+      "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8.cu"
+      "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm100_fp8.cu"
     )
     set_gencode_flags_for_srcs(
       SRCS "${SRCS}"
@@ -525,7 +526,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     # subtract out the archs that are already built for 3x
     list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS})
     if (SCALED_MM_2X_ARCHS)
-      set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu")
+      set(SRCS "csrc/quantization/w8a8/cutlass/scaled_mm_c2x.cu")
       set_gencode_flags_for_srcs(
         SRCS "${SRCS}"
         CUDA_ARCHS "${SCALED_MM_2X_ARCHS}")
@@ -648,7 +649,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   # if it's possible to compile MoE kernels that use its output.
   cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a" "${CUDA_ARCHS}")
   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
-    set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm90.cu")
+    set(SRCS "csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm90.cu")
     set_gencode_flags_for_srcs(
       SRCS "${SRCS}"
       CUDA_ARCHS "${SCALED_MM_ARCHS}")
@@ -672,7 +673,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a" "${CUDA_ARCHS}")
   endif()
   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
-    set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm100.cu")
+    set(SRCS "csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm100.cu")
     set_gencode_flags_for_srcs(
       SRCS "${SRCS}"
       CUDA_ARCHS "${SCALED_MM_ARCHS}")
@@ -697,7 +698,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
    cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0a;10.1a;10.3a;12.0a;12.1a" "${CUDA_ARCHS}")
   endif()
   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS)
-    set(SRCS "csrc/quantization/cutlass_w8a8/moe/moe_data.cu")
+    set(SRCS "csrc/quantization/w8a8/cutlass/moe/moe_data.cu")
     set_gencode_flags_for_srcs(
       SRCS "${SRCS}"
       CUDA_ARCHS "${CUTLASS_MOE_DATA_ARCHS}")
@@ -720,7 +721,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
    cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a;12.0a;12.1a" "${CUDA_ARCHS}")
   endif()
   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
-    set(SRCS "csrc/quantization/cutlass_w8a8/moe/blockwise_scaled_group_mm_sm100.cu")
+    set(SRCS "csrc/quantization/w8a8/cutlass/moe/blockwise_scaled_group_mm_sm100.cu")
     set_gencode_flags_for_srcs(
       SRCS "${SRCS}"
       CUDA_ARCHS "${SCALED_MM_ARCHS}")
diff --git a/csrc/attention/attention_kernels.cuh b/csrc/attention/attention_kernels.cuh
index 57382c1ddc65..052ff168cec4 100644
--- a/csrc/attention/attention_kernels.cuh
+++ b/csrc/attention/attention_kernels.cuh
@@ -28,10 +28,10 @@
 #ifdef USE_ROCM
   #include
-  #include "../quantization/fp8/amd/quant_utils.cuh"
+  #include "../quantization/w8a8/fp8/amd/quant_utils.cuh"
 typedef __hip_bfloat16 __nv_bfloat16;
 #else
-  #include "../quantization/fp8/nvidia/quant_utils.cuh"
+  #include "../quantization/w8a8/fp8/nvidia/quant_utils.cuh"
 #endif

 #define MAX(a, b) ((a) > (b) ? (a) : (b))
diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu
index f4b116c94f19..0aa0dc14c748 100644
--- a/csrc/cache_kernels.cu
+++ b/csrc/cache_kernels.cu
@@ -9,9 +9,9 @@
 #include "quantization/vectorization_utils.cuh"

 #ifdef USE_ROCM
-  #include "quantization/fp8/amd/quant_utils.cuh"
+  #include "quantization/w8a8/fp8/amd/quant_utils.cuh"
 #else
-  #include "quantization/fp8/nvidia/quant_utils.cuh"
+  #include "quantization/w8a8/fp8/nvidia/quant_utils.cuh"
 #endif

 #include
diff --git a/csrc/cub_helpers.h b/csrc/cub_helpers.h
index 470a63a22cab..18e4e343ad8b 100644
--- a/csrc/cub_helpers.h
+++ b/csrc/cub_helpers.h
@@ -12,6 +12,7 @@ using CubMaxOp = cub::Max;
   #endif  // CUB_VERSION
 #else
   #include <hipcub/hipcub.hpp>
-using CubAddOp = cub::Sum;
-using CubMaxOp = cub::Max;
+namespace cub = hipcub;
+using CubAddOp = hipcub::Sum;
+using CubMaxOp = hipcub::Max;
 #endif  // USE_ROCM
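Context for reviewers: the new `namespace cub = hipcub;` alias lets shared kernel code keep spelling `cub::...` on ROCm, while `CubAddOp`/`CubMaxOp` remain usable as reduction functors on both backends. A minimal sketch of how these helpers are typically consumed, assuming a CUB-style block reduction (the kernel below is illustrative, not part of this diff):

```cpp
#include <cfloat>
#include "cub_helpers.h"  // CubAddOp / CubMaxOp; cub:: resolves to hipcub on ROCm

// Hypothetical kernel: per-block max of n floats, showing CubMaxOp in use.
__global__ void block_max_demo(const float* in, float* out, int n) {
  using BlockReduce = cub::BlockReduce<float, 256>;
  __shared__ typename BlockReduce::TempStorage tmp;
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  float v = (i < n) ? in[i] : -FLT_MAX;
  float block_max = BlockReduce(tmp).Reduce(v, CubMaxOp());
  if (threadIdx.x == 0) out[blockIdx.x] = block_max;
}
```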
diff --git a/csrc/layernorm_quant_kernels.cu b/csrc/layernorm_quant_kernels.cu
index 58c3d9c0981a..0fc462194fcd 100644
--- a/csrc/layernorm_quant_kernels.cu
+++ b/csrc/layernorm_quant_kernels.cu
@@ -6,7 +6,7 @@
  */

 #include "type_convert.cuh"
-#include "quantization/fp8/common.cuh"
+#include "quantization/w8a8/fp8/common.cuh"
 #include "dispatch_utils.h"
 #include "cub_helpers.h"
 #include "core/batch_invariant.hpp"
diff --git a/csrc/quantization/activation_kernels.cu b/csrc/quantization/activation_kernels.cu
index b94cc9ce5086..55da79a12d89 100644
--- a/csrc/quantization/activation_kernels.cu
+++ b/csrc/quantization/activation_kernels.cu
@@ -7,7 +7,7 @@
 #include "../cuda_compat.h"
 #include "dispatch_utils.h"
-#include "quantization/fp8/common.cuh"
+#include "quantization/w8a8/fp8/common.cuh"

 #include
diff --git a/csrc/quantization/fused_kernels/quant_conversions.cuh b/csrc/quantization/fused_kernels/quant_conversions.cuh
index 4e6118e52e8d..2b1eb1d568e4 100644
--- a/csrc/quantization/fused_kernels/quant_conversions.cuh
+++ b/csrc/quantization/fused_kernels/quant_conversions.cuh
@@ -6,7 +6,7 @@
 #include "quantization/vectorization.cuh"

 // TODO(luka/varun):refactor common.cuh to use this file instead
-#include "quantization/fp8/common.cuh"
+#include "quantization/w8a8/fp8/common.cuh"

 namespace vllm {
diff --git a/csrc/quantization/cutlass_w8a8/Epilogues.md b/csrc/quantization/w8a8/cutlass/Epilogues.md
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/Epilogues.md
rename to csrc/quantization/w8a8/cutlass/Epilogues.md
diff --git a/csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh b/csrc/quantization/w8a8/cutlass/c3x/cutlass_gemm_caller.cuh
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh
rename to csrc/quantization/w8a8/cutlass/c3x/cutlass_gemm_caller.cuh
diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm.cuh
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh
rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm.cuh
diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_azp_sm90_int8.cu b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_azp_sm90_int8.cu
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm_azp_sm90_int8.cu
rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm_azp_sm90_int8.cu
diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm100_fp8.cu
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu
rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm100_fp8.cu
diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh
rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh
diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8.cu b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8.cu
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8.cu
rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8.cu
diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh
rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh
diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8.cu
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu
rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8.cu
diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh
rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh
diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_helper.hpp
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp
rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm_helper.hpp
diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_kernels.hpp
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp
rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm_kernels.hpp
diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8.cu
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu
rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8.cu
diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8_dispatch.cuh
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh
rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8_dispatch.cuh
diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm120_fp8.cu b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8.cu
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm120_fp8.cu
rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8.cu
diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm120_fp8_dispatch.cuh b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8_dispatch.cuh
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm120_fp8_dispatch.cuh
rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8_dispatch.cuh
diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_fp8.cu
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu
rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_fp8.cu
diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_fp8_dispatch.cuh
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh
rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_fp8_dispatch.cuh
diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_int8.cu
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu
rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_int8.cu
diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8_dispatch.cuh b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_int8_dispatch.cuh
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8_dispatch.cuh
rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_int8_dispatch.cuh
diff --git a/csrc/quantization/cutlass_w8a8/moe/blockwise_scaled_group_mm_sm100.cu b/csrc/quantization/w8a8/cutlass/moe/blockwise_scaled_group_mm_sm100.cu
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/moe/blockwise_scaled_group_mm_sm100.cu
rename to csrc/quantization/w8a8/cutlass/moe/blockwise_scaled_group_mm_sm100.cu
diff --git a/csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh b/csrc/quantization/w8a8/cutlass/moe/get_group_starts.cuh
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh
rename to csrc/quantization/w8a8/cutlass/moe/get_group_starts.cuh
diff --git a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh b/csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x.cuh
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh
rename to csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x.cuh
diff --git a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm100.cu b/csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm100.cu
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm100.cu
rename to csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm100.cu
diff --git a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm90.cu b/csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm90.cu
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm90.cu
rename to csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm90.cu
diff --git a/csrc/quantization/cutlass_w8a8/moe/moe_data.cu b/csrc/quantization/w8a8/cutlass/moe/moe_data.cu
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/moe/moe_data.cu
rename to csrc/quantization/w8a8/cutlass/moe/moe_data.cu
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu b/csrc/quantization/w8a8/cutlass/scaled_mm_c2x.cu
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
rename to csrc/quantization/w8a8/cutlass/scaled_mm_c2x.cu
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh b/csrc/quantization/w8a8/cutlass/scaled_mm_c2x.cuh
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh
rename to csrc/quantization/w8a8/cutlass/scaled_mm_c2x.cuh
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm75_dispatch.cuh b/csrc/quantization/w8a8/cutlass/scaled_mm_c2x_sm75_dispatch.cuh
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm75_dispatch.cuh
rename to csrc/quantization/w8a8/cutlass/scaled_mm_c2x_sm75_dispatch.cuh
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm80_dispatch.cuh b/csrc/quantization/w8a8/cutlass/scaled_mm_c2x_sm80_dispatch.cuh
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm80_dispatch.cuh
rename to csrc/quantization/w8a8/cutlass/scaled_mm_c2x_sm80_dispatch.cuh
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh b/csrc/quantization/w8a8/cutlass/scaled_mm_c2x_sm89_fp8_dispatch.cuh
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh
rename to csrc/quantization/w8a8/cutlass/scaled_mm_c2x_sm89_fp8_dispatch.cuh
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh b/csrc/quantization/w8a8/cutlass/scaled_mm_c2x_sm89_int8_dispatch.cuh
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh
rename to csrc/quantization/w8a8/cutlass/scaled_mm_c2x_sm89_int8_dispatch.cuh
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu b/csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm100.cu
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu
rename to csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm100.cu
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm120.cu b/csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm120.cu
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm120.cu
rename to csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm120.cu
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu b/csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm90.cu
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu
rename to csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm90.cu
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
rename to csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu
diff --git a/csrc/quantization/fp8/amd/quant_utils.cuh b/csrc/quantization/w8a8/fp8/amd/quant_utils.cuh
similarity index 99%
rename from csrc/quantization/fp8/amd/quant_utils.cuh
rename to csrc/quantization/w8a8/fp8/amd/quant_utils.cuh
index e51a4e14e518..81f5cb83f3e1 100644
--- a/csrc/quantization/fp8/amd/quant_utils.cuh
+++ b/csrc/quantization/w8a8/fp8/amd/quant_utils.cuh
@@ -5,7 +5,7 @@
 #include
 #include

-#include "../../../attention/attention_dtypes.h"
+#include "../../../../attention/attention_dtypes.h"

 namespace vllm {
 #ifdef USE_ROCM
diff --git a/csrc/quantization/fp8/common.cu b/csrc/quantization/w8a8/fp8/common.cu
similarity index 99%
rename from csrc/quantization/fp8/common.cu
rename to csrc/quantization/w8a8/fp8/common.cu
index 45d6d5082ce4..7a822fb8fb8a 100644
--- a/csrc/quantization/fp8/common.cu
+++ b/csrc/quantization/w8a8/fp8/common.cu
@@ -1,7 +1,7 @@
 #include "common.cuh"
 #include "dispatch_utils.h"
-#include "../../cub_helpers.h"
-#include "../vectorization_utils.cuh"
+#include "cub_helpers.h"
+#include "quantization/vectorization_utils.cuh"

 #include
 #include
diff --git a/csrc/quantization/fp8/common.cuh b/csrc/quantization/w8a8/fp8/common.cuh
similarity index 100%
rename from csrc/quantization/fp8/common.cuh
rename to csrc/quantization/w8a8/fp8/common.cuh
diff --git a/csrc/quantization/fp8/nvidia/quant_utils.cuh b/csrc/quantization/w8a8/fp8/nvidia/quant_utils.cuh
similarity index 99%
rename from csrc/quantization/fp8/nvidia/quant_utils.cuh
rename to csrc/quantization/w8a8/fp8/nvidia/quant_utils.cuh
index 5361a8b1a598..421e8092474b 100644
--- a/csrc/quantization/fp8/nvidia/quant_utils.cuh
+++ b/csrc/quantization/w8a8/fp8/nvidia/quant_utils.cuh
@@ -1,6 +1,6 @@
 #pragma once

-#include "../../../attention/attention_dtypes.h"
+#include "../../../../attention/attention_dtypes.h"
 #include
 #include
 #include
diff --git a/csrc/quantization/fp8/per_token_group_quant.cu b/csrc/quantization/w8a8/fp8/per_token_group_quant.cu
similarity index 98%
rename from csrc/quantization/fp8/per_token_group_quant.cu
rename to csrc/quantization/w8a8/fp8/per_token_group_quant.cu
index 91d489fdef86..e3ab0676b254 100644
--- a/csrc/quantization/fp8/per_token_group_quant.cu
+++ b/csrc/quantization/w8a8/fp8/per_token_group_quant.cu
@@ -1,6 +1,6 @@
 #include

-#include "../per_token_group_quant_8bit.h"
+#include "quantization/w8a8/per_token_group_quant_8bit.h"

 #include
@@ -8,9 +8,9 @@
 #include

-#include "../vectorization.cuh"
-#include "../vectorization_utils.cuh"
-#include "../../dispatch_utils.h"
+#include "quantization/vectorization.cuh"
+#include "quantization/vectorization_utils.cuh"
+#include "dispatch_utils.h"

 __device__ __forceinline__ float GroupReduceMax(float val) {
   unsigned mask = threadIdx.x % 32 >= 16 ? 0xffff0000 : 0x0000ffff;
@@ -212,4 +212,4 @@ void per_token_group_quant_fp8(const torch::Tensor& input,
                                double fp8_max, bool scale_ue8m0) {
   per_token_group_quant_8bit(input, output_q, output_s, group_size, eps,
                              fp8_min, fp8_max, scale_ue8m0);
-}
+}
\ No newline at end of file
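The moved FP8 file keeps `GroupReduceMax`, whose first line (visible in the hunk above) picks a half-warp sync mask. For readers unfamiliar with the pattern, a hedged sketch of how such a 16-lane max reduction is typically completed — only the mask line appears in this diff; the shuffle body below is an assumption, not the file's actual code:

```cpp
// Illustrative 16-lane max reduction: each half of a warp reduces
// independently, so the sync mask covers only the half this thread is in.
__device__ __forceinline__ float group_reduce_max_sketch(float val) {
  unsigned mask = threadIdx.x % 32 >= 16 ? 0xffff0000 : 0x0000ffff;
  // Assumed body: butterfly max over the 16 lanes of this half-warp.
  val = fmaxf(val, __shfl_xor_sync(mask, val, 8));
  val = fmaxf(val, __shfl_xor_sync(mask, val, 4));
  val = fmaxf(val, __shfl_xor_sync(mask, val, 2));
  val = fmaxf(val, __shfl_xor_sync(mask, val, 1));
  return val;
}
```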
diff --git a/csrc/quantization/w8a8/int8/per_token_group_quant.cu b/csrc/quantization/w8a8/int8/per_token_group_quant.cu
new file mode 100644
index 000000000000..9d808a176f53
--- /dev/null
+++ b/csrc/quantization/w8a8/int8/per_token_group_quant.cu
@@ -0,0 +1,12 @@
+#include
+#include
+
+#include "quantization/w8a8/per_token_group_quant_8bit.h"
+
+void per_token_group_quant_int8(const torch::Tensor& input,
+                                torch::Tensor& output_q,
+                                torch::Tensor& output_s, int64_t group_size,
+                                double eps, double int8_min, double int8_max) {
+  per_token_group_quant_8bit(input, output_q, output_s, group_size, eps,
+                             int8_min, int8_max);
+}
\ No newline at end of file
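The new translation unit only forwards to the shared 8-bit helper; the INT8 and FP8 entry points now differ solely in the min/max they pass. As a reference for the semantics behind `per_token_group_quant_8bit`, here is a hedged host-side sketch — the function name, and the exact `eps` handling, are assumptions for illustration, not code from this PR:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// For each contiguous group of `group_size` values: derive a scale from the
// group's absmax, then round-to-nearest and saturate. INT8 uses [-128, 127];
// the FP8 entry point passes the fp8 min/max instead.
void per_group_quant_sketch(const std::vector<float>& x, int64_t group_size,
                            float eps, float q_min, float q_max,
                            std::vector<int8_t>& q, std::vector<float>& s) {
  for (size_t g = 0; g * group_size < x.size(); ++g) {
    float absmax = 0.f;
    for (int64_t i = 0; i < group_size; ++i)
      absmax = std::max(absmax, std::fabs(x[g * group_size + i]));
    float scale = std::max(absmax, eps) / q_max;  // eps guards absmax == 0
    s.push_back(scale);
    for (int64_t i = 0; i < group_size; ++i) {
      float v = std::nearbyint(x[g * group_size + i] / scale);
      q.push_back(static_cast<int8_t>(std::clamp(v, q_min, q_max)));
    }
  }
}
```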
diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/w8a8/int8/scaled_quant.cu
similarity index 94%
rename from csrc/quantization/compressed_tensors/int8_quant_kernels.cu
rename to csrc/quantization/w8a8/int8/scaled_quant.cu
index bcfde9fbcbbe..7fe9e96bfb01 100644
--- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu
+++ b/csrc/quantization/w8a8/int8/scaled_quant.cu
@@ -1,15 +1,11 @@
 #include
 #include

-#ifndef USE_ROCM
-  #include "../per_token_group_quant_8bit.h"
-#endif
-
 #include

-#include "../../cub_helpers.h"
-#include "../../dispatch_utils.h"
-#include "../vectorization_utils.cuh"
+#include "dispatch_utils.h"
+#include "quantization/vectorization_utils.cuh"
+#include "cub_helpers.h"

 static inline __device__ int8_t float_to_int8_rn(float x) {
 #ifdef USE_ROCM
@@ -25,7 +21,6 @@ static inline __device__ int8_t float_to_int8_rn(float x) {
   float dst = std::nearbyint(x);

   // saturate
-  // See https://github.com/pytorch/pytorch/issues/127666
   // See https://github.com/llvm/llvm-project/issues/95183
   // hip-clang std::clamp __glibcxx_assert_fail host function when building on
@@ -84,7 +79,6 @@ static inline __device__ int8_t int32_to_int8(int32_t x) {
       static_cast<int32_t>(std::numeric_limits<int8_t>::max());

   // saturate
-  // See https://github.com/pytorch/pytorch/issues/127666
   // See https://github.com/llvm/llvm-project/issues/95183
   // hip-clang std::clamp __glibcxx_assert_fail host function when building on
@@ -176,7 +170,6 @@ __global__ void dynamic_scaled_int8_quant_kernel(
   float inv_s = (absmax == 0.f) ? 0.f : 127.f / absmax;

-  // 2. quantize
   vectorize_with_alignment<16>(
       row_in, row_out, hidden_size, tid, stride,
       [=] __device__(int8_t& dst, const scalar_t& src) {
@@ -194,7 +187,6 @@ struct MinMax {
   __host__ __device__ explicit MinMax(float v) : min(v), max(v) {}

-  // add a value to the MinMax
   __host__ __device__ MinMax& operator+=(float v) {
     min = fminf(min, v);
     max = fmaxf(max, v);
@@ -228,7 +220,6 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel(
   const scalar_t* row_in = input + token_idx * hidden_size;
   int8_t* row_out = output + token_idx * hidden_size;

-  // 1. calculate min & max
   MinMax thread_mm;
   vectorize_read_with_alignment<16>(row_in, hidden_size, tid, stride,
                                     [&] __device__(const scalar_t& src) {
@@ -261,7 +252,6 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel(
   const float inv_s = 1.f / scale_sh;
   const azp_t azp = azp_sh;

-  // 2. quantize
   vectorize_with_alignment<16>(
       row_in, row_out, hidden_size, tid, stride,
       [=] __device__(int8_t& dst, const scalar_t& src) {
@@ -332,14 +322,4 @@ void dynamic_scaled_int8_quant(
           hidden_size);
     }
   });
-}
-
-#ifndef USE_ROCM
-void per_token_group_quant_int8(const torch::Tensor& input,
-                                torch::Tensor& output_q,
-                                torch::Tensor& output_s, int64_t group_size,
-                                double eps, double int8_min, double int8_max) {
-  per_token_group_quant_8bit(input, output_q, output_s, group_size, eps,
-                             int8_min, int8_max);
-}
-#endif
+}
\ No newline at end of file
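For context on the kernels this file now owns: the symmetric path visible in the hunks scales each token by `inv_s = 127 / absmax`, while the AZP path collects a per-token min/max and quantizes with a zero point. A hedged scalar sketch of both — the symmetric formula is taken from the diff, while the AZP scale and zero-point formulas are assumptions, since only the `inv_s`/`azp` application is visible here:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

// Symmetric: matches the diff's inv_s = (absmax == 0.f) ? 0.f : 127.f / absmax.
inline int8_t sym_int8_sketch(float x, float absmax) {
  float inv_s = (absmax == 0.f) ? 0.f : 127.f / absmax;
  return static_cast<int8_t>(
      std::clamp(std::nearbyint(x * inv_s), -128.f, 127.f));
}

// Asymmetric (AZP): assumed formulation mapping [mn, mx] onto [-128, 127]
// (requires mx > mn); the kernel derives a per-token scale_sh/azp_sh pair
// and then applies inv_s and azp exactly as in the visible lambda.
inline int8_t azp_int8_sketch(float x, float mn, float mx) {
  float scale = (mx - mn) / 255.f;
  float azp = std::nearbyint(-128.f - mn / scale);
  float q = std::nearbyint(x / scale) + azp;
  return static_cast<int8_t>(std::clamp(q, -128.f, 127.f));
}
```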
diff --git a/csrc/quantization/per_token_group_quant_8bit.h b/csrc/quantization/w8a8/per_token_group_quant_8bit.h
similarity index 84%
rename from csrc/quantization/per_token_group_quant_8bit.h
rename to csrc/quantization/w8a8/per_token_group_quant_8bit.h
index 537b61bc4303..25d4ecd1131a 100644
--- a/csrc/quantization/per_token_group_quant_8bit.h
+++ b/csrc/quantization/w8a8/per_token_group_quant_8bit.h
@@ -1,7 +1,6 @@
 #pragma once
 #include

-// TODO(wentao): refactor the folder to 8bit, then includes fp8 and int8 folders
 // 8-bit per-token-group quantization helper used by both FP8 and INT8
 void per_token_group_quant_8bit(const torch::Tensor& input,
                                 torch::Tensor& output_q,
diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu
index df3208a120f1..a339c5641bb4 100644
--- a/csrc/rocm/attention.cu
+++ b/csrc/rocm/attention.cu
@@ -23,7 +23,7 @@
 #include
 #include "../attention/dtype_fp8.cuh"
-#include "../quantization/fp8/amd/quant_utils.cuh"
+#include "../quantization/w8a8/fp8/amd/quant_utils.cuh"

 // ROCm 6.2 compatibility: map OCP fp8 types to FNUZ variants if OCP is absent
 #if !defined(HIP_FP8_TYPE_OCP)
diff --git a/csrc/rocm/skinny_gemms.cu b/csrc/rocm/skinny_gemms.cu
index e4600350d3ea..2ef579a1b753 100644
--- a/csrc/rocm/skinny_gemms.cu
+++ b/csrc/rocm/skinny_gemms.cu
@@ -11,7 +11,7 @@
 #include "../cuda_compat.h"
 #include "dispatch_utils.h"
-#include "quantization/fp8/common.cuh"
+#include "quantization/w8a8/fp8/common.cuh"

 #if defined(__HIPCC__) && \
     (defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__))
diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py
index f1dafdf14c7a..13dbd55c32df 100644
--- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py
+++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py
@@ -89,8 +89,8 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
         # azp_adj is the AZP adjustment term, used to account for weights.
         # It does not depend on scales or azp, so it is the same for
         # static and dynamic quantization.
-        # For more details, see csrc/quantization/cutlass_w8a8/Epilogues.md
-        # https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md
+        # For more details, see csrc/quantization/w8a8/cutlass/Epilogues.md
+        # https://github.com/vllm-project/vllm/blob/main/csrc/quantization/w8a8/cutlass/Epilogues.md
         if not self.config.input_symmetric:
             weight = getattr(layer, self.w_q_name)
             azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.int32)
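The updated comment leans on one identity: with asymmetric activations q[k] = r[k] + azp (r being the zero-point-free part), an int32 GEMM column satisfies sum_k q[k]*W[k][j] = sum_k r[k]*W[k][j] + azp * sum_k W[k][j], so the correction `azp_adj = weight.sum(dim=0)` depends only on the weights, regardless of scales or whether azp is static or dynamic. A small self-contained check of that identity on toy values (not vLLM code):

```cpp
#include <cstdint>
#include <cstdio>

// Toy check of the AZP identity behind azp_adj: with q[k] = r[k] + azp,
// sum_k q[k]*w[k] == sum_k r[k]*w[k] + azp * sum_k w[k], and the column sum
// of the weights (what weight.sum(dim=0) computes) is azp-independent.
int main() {
  const int K = 4;
  int32_t r[K] = {3, -7, 2, 5};   // azp-free quantized activations
  int32_t w[K] = {1, -2, 4, 0};   // one weight column
  int32_t azp = 11;

  int32_t lhs = 0, rw = 0, azp_adj = 0;
  for (int k = 0; k < K; ++k) {
    lhs += (r[k] + azp) * w[k];
    rw += r[k] * w[k];
    azp_adj += w[k];              // column sum of the weights
  }
  std::printf("%d == %d\n", lhs, rw + azp * azp_adj);  // identity holds
  return 0;
}
```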