diff --git a/CMakeLists.txt b/CMakeLists.txt
index 180b896a7abac..ba3991372fabe 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -257,8 +257,8 @@ set(VLLM_EXT_SRC
   "csrc/sampler.cu"
   "csrc/cuda_view.cu"
   "csrc/quantization/gptq/q_gemm.cu"
-  "csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
-  "csrc/quantization/fp8/common.cu"
+  "csrc/quantization/w8a8/int8/scaled_quant.cu"
+  "csrc/quantization/w8a8/fp8/common.cu"
   "csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
   "csrc/quantization/gguf/gguf_kernel.cu"
   "csrc/quantization/activation_kernels.cu"
@@ -302,13 +302,14 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   list(APPEND VLLM_EXT_SRC
     "csrc/quantization/awq/gemm_kernels.cu"
     "csrc/permute_cols.cu"
-    "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
+    "csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu"
     "csrc/quantization/fp4/nvfp4_quant_entry.cu"
     "csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu"
     "csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu"
     "csrc/sparse/cutlass/sparse_scaled_mm_entry.cu"
     "csrc/cutlass_extensions/common.cpp"
-    "csrc/quantization/fp8/per_token_group_quant.cu")
+    "csrc/quantization/w8a8/fp8/per_token_group_quant.cu"
+    "csrc/quantization/w8a8/int8/per_token_group_quant.cu")
 
   set_gencode_flags_for_srcs(
     SRCS "${VLLM_EXT_SRC}"
@@ -412,11 +413,11 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}")
   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND SCALED_MM_ARCHS)
     set(SRCS
-      "csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu"
-      "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu"
-      "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu"
-      "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_azp_sm90_int8.cu"
-      "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu")
+      "csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm90.cu"
+      "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_fp8.cu"
+      "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_int8.cu"
+      "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_azp_sm90_int8.cu"
+      "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8.cu")
     set_gencode_flags_for_srcs(
       SRCS "${SRCS}"
       CUDA_ARCHS "${SCALED_MM_ARCHS}")
@@ -443,9 +444,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0;12.0a" "${CUDA_ARCHS}")
   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
     set(SRCS
-      "csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm120.cu"
-      "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm120_fp8.cu"
-      "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8.cu"
+      "csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm120.cu"
+      "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8.cu"
+      "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8.cu"
     )
     set_gencode_flags_for_srcs(
       SRCS "${SRCS}"
@@ -473,9 +474,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a" "${CUDA_ARCHS}")
   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
     set(SRCS
-      "csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu"
-      "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu"
-      "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu"
+      "csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm100.cu"
+      "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8.cu"
+      "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm100_fp8.cu"
     )
     set_gencode_flags_for_srcs(
      SRCS "${SRCS}"
@@ -506,7 +507,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   # subtract out the archs that are already built for 3x
   list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS})
   if (SCALED_MM_2X_ARCHS)
-    set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu")
+    set(SRCS "csrc/quantization/w8a8/cutlass/scaled_mm_c2x.cu")
     set_gencode_flags_for_srcs(
       SRCS "${SRCS}"
       CUDA_ARCHS "${SCALED_MM_2X_ARCHS}")
@@ -617,7 +618,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   # if it's possible to compile MoE kernels that use its output.
   cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a" "${CUDA_ARCHS}")
   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
-    set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm90.cu")
+    set(SRCS "csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm90.cu")
     set_gencode_flags_for_srcs(
       SRCS "${SRCS}"
       CUDA_ARCHS "${SCALED_MM_ARCHS}")
@@ -637,7 +638,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
 
   cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a" "${CUDA_ARCHS}")
   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
-    set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm100.cu")
+    set(SRCS "csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm100.cu")
     set_gencode_flags_for_srcs(
       SRCS "${SRCS}"
       CUDA_ARCHS "${SCALED_MM_ARCHS}")
@@ -658,7 +659,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   # moe_data.cu is used by all CUTLASS MoE kernels.
   cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0a" "${CUDA_ARCHS}")
   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS)
-    set(SRCS "csrc/quantization/cutlass_w8a8/moe/moe_data.cu")
+    set(SRCS "csrc/quantization/w8a8/cutlass/moe/moe_data.cu")
     set_gencode_flags_for_srcs(
       SRCS "${SRCS}"
       CUDA_ARCHS "${CUTLASS_MOE_DATA_ARCHS}")
@@ -677,7 +678,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
 
   cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a" "${CUDA_ARCHS}")
   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
-    set(SRCS "csrc/quantization/cutlass_w8a8/moe/blockwise_scaled_group_mm_sm100.cu")
+    set(SRCS "csrc/quantization/w8a8/cutlass/moe/blockwise_scaled_group_mm_sm100.cu")
     set_gencode_flags_for_srcs(
       SRCS "${SRCS}"
       CUDA_ARCHS "${SCALED_MM_ARCHS}")
diff --git a/csrc/attention/attention_kernels.cuh b/csrc/attention/attention_kernels.cuh
index 57382c1ddc65b..052ff168cec4f 100644
--- a/csrc/attention/attention_kernels.cuh
+++ b/csrc/attention/attention_kernels.cuh
@@ -28,10 +28,10 @@
 
 #ifdef USE_ROCM
   #include <hip/hip_bf16.h>
-  #include "../quantization/fp8/amd/quant_utils.cuh"
+  #include "../quantization/w8a8/fp8/amd/quant_utils.cuh"
 typedef __hip_bfloat16 __nv_bfloat16;
 #else
-  #include "../quantization/fp8/nvidia/quant_utils.cuh"
+  #include "../quantization/w8a8/fp8/nvidia/quant_utils.cuh"
 #endif
 
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu
index 80b4c47c55476..03db59ec9d0ee 100644
--- a/csrc/cache_kernels.cu
+++ b/csrc/cache_kernels.cu
@@ -9,9 +9,9 @@
 #include "quantization/vectorization_utils.cuh"
 
 #ifdef USE_ROCM
-  #include "quantization/fp8/amd/quant_utils.cuh"
+  #include "quantization/w8a8/fp8/amd/quant_utils.cuh"
 #else
-  #include "quantization/fp8/nvidia/quant_utils.cuh"
+  #include "quantization/w8a8/fp8/nvidia/quant_utils.cuh"
 #endif
 
 #include <algorithm>
diff --git a/csrc/layernorm_quant_kernels.cu b/csrc/layernorm_quant_kernels.cu
index be134089bd6d4..c8c55fb6d21fc 100644
--- a/csrc/layernorm_quant_kernels.cu
+++ b/csrc/layernorm_quant_kernels.cu
@@ -6,7 +6,7 @@
  */
 
 #include "type_convert.cuh"
-#include "quantization/fp8/common.cuh"
+#include "quantization/w8a8/fp8/common.cuh"
 #include "dispatch_utils.h"
 #include "cub_helpers.h"
diff --git a/csrc/quantization/activation_kernels.cu b/csrc/quantization/activation_kernels.cu
index 9aa1411b4a25c..b1d31907e20bf 100644
--- a/csrc/quantization/activation_kernels.cu
+++ b/csrc/quantization/activation_kernels.cu
@@ -7,7 +7,7 @@
 #include "../cuda_compat.h"
 #include "dispatch_utils.h"
 
-#include "quantization/fp8/common.cuh"
+#include "quantization/w8a8/fp8/common.cuh"
 
 #include
diff --git a/csrc/quantization/fused_kernels/quant_conversions.cuh b/csrc/quantization/fused_kernels/quant_conversions.cuh
index 4e6118e52e8d6..2b1eb1d568e4e 100644
--- a/csrc/quantization/fused_kernels/quant_conversions.cuh
+++ b/csrc/quantization/fused_kernels/quant_conversions.cuh
@@ -6,7 +6,7 @@
 #include "quantization/vectorization.cuh"
 
 // TODO(luka/varun):refactor common.cuh to use this file instead
-#include "quantization/fp8/common.cuh"
+#include "quantization/w8a8/fp8/common.cuh"
 
 namespace vllm {
diff --git a/csrc/quantization/cutlass_w8a8/Epilogues.md b/csrc/quantization/w8a8/cutlass/Epilogues.md
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/Epilogues.md
rename to csrc/quantization/w8a8/cutlass/Epilogues.md
diff --git a/csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh b/csrc/quantization/w8a8/cutlass/c3x/cutlass_gemm_caller.cuh
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh
rename to csrc/quantization/w8a8/cutlass/c3x/cutlass_gemm_caller.cuh
diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm.cuh
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh
rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm.cuh
diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_azp_sm90_int8.cu b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_azp_sm90_int8.cu
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm_azp_sm90_int8.cu
rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm_azp_sm90_int8.cu
diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm100_fp8.cu
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu
rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm100_fp8.cu
diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh
rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh
diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8.cu b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8.cu
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8.cu
rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8.cu
diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh
rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh
diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8.cu
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu
rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8.cu
diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh
rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh
diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_helper.hpp
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp
rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm_helper.hpp
diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_kernels.hpp
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp
rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm_kernels.hpp
diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8.cu
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu
rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8.cu
diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8_dispatch.cuh
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh
rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8_dispatch.cuh
diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm120_fp8.cu b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8.cu
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm120_fp8.cu
rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8.cu
diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm120_fp8_dispatch.cuh b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8_dispatch.cuh
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm120_fp8_dispatch.cuh
rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8_dispatch.cuh
diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_fp8.cu
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu
rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_fp8.cu
diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_fp8_dispatch.cuh
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh
rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_fp8_dispatch.cuh
diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_int8.cu
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu
rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_int8.cu
diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8_dispatch.cuh b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_int8_dispatch.cuh
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8_dispatch.cuh
rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_int8_dispatch.cuh
diff --git a/csrc/quantization/cutlass_w8a8/moe/blockwise_scaled_group_mm_sm100.cu b/csrc/quantization/w8a8/cutlass/moe/blockwise_scaled_group_mm_sm100.cu
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/moe/blockwise_scaled_group_mm_sm100.cu
rename to csrc/quantization/w8a8/cutlass/moe/blockwise_scaled_group_mm_sm100.cu
diff --git a/csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh b/csrc/quantization/w8a8/cutlass/moe/get_group_starts.cuh
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh
rename to csrc/quantization/w8a8/cutlass/moe/get_group_starts.cuh
diff --git a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh b/csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x.cuh
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh
rename to csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x.cuh
diff --git a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm100.cu b/csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm100.cu
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm100.cu
rename to csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm100.cu
diff --git a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm90.cu b/csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm90.cu
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm90.cu
rename to csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm90.cu
diff --git a/csrc/quantization/cutlass_w8a8/moe/moe_data.cu b/csrc/quantization/w8a8/cutlass/moe/moe_data.cu
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/moe/moe_data.cu
rename to csrc/quantization/w8a8/cutlass/moe/moe_data.cu
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu b/csrc/quantization/w8a8/cutlass/scaled_mm_c2x.cu
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
rename to csrc/quantization/w8a8/cutlass/scaled_mm_c2x.cu
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh b/csrc/quantization/w8a8/cutlass/scaled_mm_c2x.cuh
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh
rename to csrc/quantization/w8a8/cutlass/scaled_mm_c2x.cuh
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm75_dispatch.cuh b/csrc/quantization/w8a8/cutlass/scaled_mm_c2x_sm75_dispatch.cuh
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm75_dispatch.cuh
rename to csrc/quantization/w8a8/cutlass/scaled_mm_c2x_sm75_dispatch.cuh
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm80_dispatch.cuh b/csrc/quantization/w8a8/cutlass/scaled_mm_c2x_sm80_dispatch.cuh
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm80_dispatch.cuh
rename to csrc/quantization/w8a8/cutlass/scaled_mm_c2x_sm80_dispatch.cuh
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh b/csrc/quantization/w8a8/cutlass/scaled_mm_c2x_sm89_fp8_dispatch.cuh
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh
rename to csrc/quantization/w8a8/cutlass/scaled_mm_c2x_sm89_fp8_dispatch.cuh
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh b/csrc/quantization/w8a8/cutlass/scaled_mm_c2x_sm89_int8_dispatch.cuh
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh
rename to csrc/quantization/w8a8/cutlass/scaled_mm_c2x_sm89_int8_dispatch.cuh
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu b/csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm100.cu
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu
rename to csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm100.cu
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm120.cu b/csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm120.cu
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm120.cu
rename to csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm120.cu
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu b/csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm90.cu
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu
rename to csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm90.cu
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu
similarity index 100%
rename from csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
rename to csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu
diff --git a/csrc/quantization/fp8/amd/quant_utils.cuh b/csrc/quantization/w8a8/fp8/amd/quant_utils.cuh
similarity index 99%
rename from csrc/quantization/fp8/amd/quant_utils.cuh
rename to csrc/quantization/w8a8/fp8/amd/quant_utils.cuh
index e51a4e14e518f..81f5cb83f3e18 100644
--- a/csrc/quantization/fp8/amd/quant_utils.cuh
+++ b/csrc/quantization/w8a8/fp8/amd/quant_utils.cuh
@@ -5,7 +5,7 @@
 #include
 #include
 
-#include "../../../attention/attention_dtypes.h"
+#include "../../../../attention/attention_dtypes.h"
 
 namespace vllm {
 #ifdef USE_ROCM
diff --git a/csrc/quantization/fp8/common.cu b/csrc/quantization/w8a8/fp8/common.cu
similarity index 99%
rename from csrc/quantization/fp8/common.cu
rename to csrc/quantization/w8a8/fp8/common.cu
index 45d6d5082ce49..622d559d839c5 100644
--- a/csrc/quantization/fp8/common.cu
+++ b/csrc/quantization/w8a8/fp8/common.cu
@@ -1,7 +1,7 @@
 #include "common.cuh"
 #include "dispatch_utils.h"
-#include "../../cub_helpers.h"
-#include "../vectorization_utils.cuh"
+#include "quantization/vectorization_utils.cuh"
+#include "cub_helpers.h"
 
 #include
 #include
diff --git a/csrc/quantization/fp8/common.cuh b/csrc/quantization/w8a8/fp8/common.cuh
similarity index 100%
rename from csrc/quantization/fp8/common.cuh
rename to csrc/quantization/w8a8/fp8/common.cuh
diff --git a/csrc/quantization/fp8/nvidia/quant_utils.cuh b/csrc/quantization/w8a8/fp8/nvidia/quant_utils.cuh
similarity index 99%
rename from csrc/quantization/fp8/nvidia/quant_utils.cuh
rename to csrc/quantization/w8a8/fp8/nvidia/quant_utils.cuh
index 5b9c2df8468cb..60e023f9a2c71 100644
--- a/csrc/quantization/fp8/nvidia/quant_utils.cuh
+++ b/csrc/quantization/w8a8/fp8/nvidia/quant_utils.cuh
@@ -1,6 +1,6 @@
 #pragma once
 
-#include "../../../attention/attention_dtypes.h"
+#include "../../../../attention/attention_dtypes.h"
 #include <assert.h>
 #include <float.h>
 #include <type_traits>
diff --git a/csrc/quantization/fp8/per_token_group_quant.cu b/csrc/quantization/w8a8/fp8/per_token_group_quant.cu
similarity index 98%
rename from csrc/quantization/fp8/per_token_group_quant.cu
rename to csrc/quantization/w8a8/fp8/per_token_group_quant.cu
index f5b40e35b6e5a..52326687f6197 100644
--- a/csrc/quantization/fp8/per_token_group_quant.cu
+++ b/csrc/quantization/w8a8/fp8/per_token_group_quant.cu
@@ -1,6 +1,6 @@
 #include <ATen/cuda/CUDAContext.h>
 
-#include "../per_token_group_quant_8bit.h"
+#include "quantization/w8a8/per_token_group_quant_8bit.h"
 
 #include <cmath>
 
@@ -8,9 +8,9 @@
 #include <cuda_fp8.h>
 
-#include "../vectorization.cuh"
-#include "../vectorization_utils.cuh"
-#include "../../dispatch_utils.h"
+#include "quantization/vectorization.cuh"
+#include "quantization/vectorization_utils.cuh"
+#include "dispatch_utils.h"
 
 __device__ __forceinline__ float GroupReduceMax(float val, const int tid) {
   unsigned mask = 0xffff;
@@ -212,4 +212,4 @@ void per_token_group_quant_fp8(const torch::Tensor& input,
                                double fp8_max, bool scale_ue8m0) {
   per_token_group_quant_8bit(input, output_q, output_s, group_size, eps,
                              fp8_min, fp8_max, scale_ue8m0);
-}
+}
\ No newline at end of file
diff --git a/csrc/quantization/w8a8/int8/per_token_group_quant.cu b/csrc/quantization/w8a8/int8/per_token_group_quant.cu
new file mode 100644
index 0000000000000..9d808a176f538
--- /dev/null
+++ b/csrc/quantization/w8a8/int8/per_token_group_quant.cu
@@ -0,0 +1,12 @@
+#include <ATen/cuda/CUDAContext.h>
+#include <torch/all.h>
+
+#include "quantization/w8a8/per_token_group_quant_8bit.h"
+
+void per_token_group_quant_int8(const torch::Tensor& input,
+                                torch::Tensor& output_q,
+                                torch::Tensor& output_s, int64_t group_size,
+                                double eps, double int8_min, double int8_max) {
+  per_token_group_quant_8bit(input, output_q, output_s, group_size, eps,
+                             int8_min, int8_max);
+}
\ No newline at end of file
diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/w8a8/int8/scaled_quant.cu
similarity index 94%
rename from csrc/quantization/compressed_tensors/int8_quant_kernels.cu
rename to csrc/quantization/w8a8/int8/scaled_quant.cu
index bcfde9fbcbbef..7fe9e96bfb017 100644
--- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu
+++ b/csrc/quantization/w8a8/int8/scaled_quant.cu
@@ -1,15 +1,11 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/all.h>
 
-#ifndef USE_ROCM
-  #include "../per_token_group_quant_8bit.h"
-#endif
-
 #include <cmath>
 
-#include "../../cub_helpers.h"
-#include "../../dispatch_utils.h"
-#include "../vectorization_utils.cuh"
+#include "dispatch_utils.h"
+#include "quantization/vectorization_utils.cuh"
+#include "cub_helpers.h"
 
 static inline __device__ int8_t float_to_int8_rn(float x) {
 #ifdef USE_ROCM
@@ -25,7 +21,6 @@ static inline __device__ int8_t float_to_int8_rn(float x) {
   float dst = std::nearbyint(x);
 
   // saturate
-
   // See https://github.com/pytorch/pytorch/issues/127666
   // See https://github.com/llvm/llvm-project/issues/95183
   // hip-clang std::clamp __glibcxx_assert_fail host function when building on
@@ -84,7 +79,6 @@ static inline __device__ int8_t int32_to_int8(int32_t x) {
       static_cast<int32_t>(std::numeric_limits<int8_t>::max());
 
   // saturate
-
   // See https://github.com/pytorch/pytorch/issues/127666
   // See https://github.com/llvm/llvm-project/issues/95183
   // hip-clang std::clamp __glibcxx_assert_fail host function when building on
@@ -176,7 +170,6 @@ __global__ void dynamic_scaled_int8_quant_kernel(
 
   float inv_s = (absmax == 0.f) ? 0.f : 127.f / absmax;
 
-  // 2. quantize
   vectorize_with_alignment<16>(
       row_in, row_out, hidden_size, tid, stride,
       [=] __device__(int8_t& dst, const scalar_t& src) {
@@ -194,7 +187,6 @@ struct MinMax {
 
   __host__ __device__ explicit MinMax(float v) : min(v), max(v) {}
 
-  // add a value to the MinMax
   __host__ __device__ MinMax& operator+=(float v) {
     min = fminf(min, v);
     max = fmaxf(max, v);
@@ -228,7 +220,6 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel(
   const scalar_t* row_in = input + token_idx * hidden_size;
   int8_t* row_out = output + token_idx * hidden_size;
 
-  // 1. calculate min & max
   MinMax thread_mm;
   vectorize_read_with_alignment<16>(row_in, hidden_size, tid, stride,
                                     [&] __device__(const scalar_t& src) {
@@ -261,7 +252,6 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel(
   const float inv_s = 1.f / scale_sh;
   const azp_t azp = azp_sh;
 
-  // 2. quantize
   vectorize_with_alignment<16>(
       row_in, row_out, hidden_size, tid, stride,
       [=] __device__(int8_t& dst, const scalar_t& src) {
@@ -332,14 +322,4 @@ void dynamic_scaled_int8_quant(
           hidden_size);
     }
   });
-}
-
-#ifndef USE_ROCM
-void per_token_group_quant_int8(const torch::Tensor& input,
-                                torch::Tensor& output_q,
-                                torch::Tensor& output_s, int64_t group_size,
-                                double eps, double int8_min, double int8_max) {
-  per_token_group_quant_8bit(input, output_q, output_s, group_size, eps,
-                             int8_min, int8_max);
-}
-#endif
+}
\ No newline at end of file
diff --git a/csrc/quantization/per_token_group_quant_8bit.h b/csrc/quantization/w8a8/per_token_group_quant_8bit.h
similarity index 84%
rename from csrc/quantization/per_token_group_quant_8bit.h
rename to csrc/quantization/w8a8/per_token_group_quant_8bit.h
index 537b61bc4303f..25d4ecd1131a1 100644
--- a/csrc/quantization/per_token_group_quant_8bit.h
+++ b/csrc/quantization/w8a8/per_token_group_quant_8bit.h
@@ -1,7 +1,6 @@
 #pragma once
 #include <torch/all.h>
 
-// TODO(wentao): refactor the folder to 8bit, then includes fp8 and int8 folders
 // 8-bit per-token-group quantization helper used by both FP8 and INT8
 void per_token_group_quant_8bit(const torch::Tensor& input,
                                 torch::Tensor& output_q,
diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu
index dac9df6048f2a..254633e9ac040 100644
--- a/csrc/rocm/attention.cu
+++ b/csrc/rocm/attention.cu
@@ -23,7 +23,7 @@
 #include
 
 #include "../attention/dtype_fp8.cuh"
-#include "../quantization/fp8/amd/quant_utils.cuh"
+#include "../quantization/w8a8/fp8/amd/quant_utils.cuh"
 
 #if defined(__HIPCC__) && \
     (defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__))
diff --git a/csrc/rocm/skinny_gemms.cu b/csrc/rocm/skinny_gemms.cu
index eb47139208c91..b8a1b439758c1 100644
--- a/csrc/rocm/skinny_gemms.cu
+++ b/csrc/rocm/skinny_gemms.cu
@@ -11,7 +11,7 @@
 #include "../cuda_compat.h"
 #include "dispatch_utils.h"
 
-#include "quantization/fp8/common.cuh"
+#include "quantization/w8a8/fp8/common.cuh"
 
 #if defined(__HIPCC__) && \
     (defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__))
diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py
index 2f982f96b0d04..321084d9754c0 100644
--- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py
+++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py
@@ -88,8 +88,8 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
         # azp_adj is the AZP adjustment term, used to account for weights.
         # It does not depend on scales or azp, so it is the same for
         # static and dynamic quantization.
-        # For more details, see csrc/quantization/cutlass_w8a8/Epilogues.md
-        # https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md
+        # For more details, see csrc/quantization/w8a8/cutlass/Epilogues.md
+        # https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/w8a8/cutlass/Epilogues.md
         if not self.config.input_symmetric:
             weight = getattr(layer, self.w_q_name)
             azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.int32)