diff --git a/CMakeLists.txt b/CMakeLists.txt index d88ba3aa66303..e09972fe71995 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -354,8 +354,17 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # Only build Marlin kernels if we are building for at least some compatible archs. # Keep building Marlin for 9.0 as there are some group sizes and shapes that # are not supported by Machete yet. - # 9.0 for latest bf16 atomicAdd PTX - cuda_archs_loose_intersection(MARLIN_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}") + + # marlin arches for fp16 output + cuda_archs_loose_intersection(MARLIN_ARCHS "8.0+PTX" "${CUDA_ARCHS}") + # marlin arches for bf16 output (we need 9.0 for bf16 atomicAdd PTX) + cuda_archs_loose_intersection(MARLIN_BF16_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}") + # marlin arches for fp8 input + # - sm80 doesn't support fp8 computation + # - sm90 and sm100 don't support QMMA.16832.F32.E4M3.E4M3 SAAS instruction + # so we only enable fp8 computation for SM89 (e.g. RTX 40x0) and 12.0 (e.g. RTX 50x0) + cuda_archs_loose_intersection(MARLIN_FP8_ARCHS "8.9;12.0" "${CUDA_ARCHS}") + if (MARLIN_ARCHS) # @@ -365,16 +374,18 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") set(MARLIN_GEN_SCRIPT ${CMAKE_CURRENT_SOURCE_DIR}/csrc/quantization/gptq_marlin/generate_kernels.py) file(MD5 ${MARLIN_GEN_SCRIPT} MARLIN_GEN_SCRIPT_HASH) + list(JOIN CUDA_ARCHS "," CUDA_ARCHS_STR) + set(MARLIN_GEN_SCRIPT_HASH_AND_ARCH "${MARLIN_GEN_SCRIPT_HASH}(ARCH:${CUDA_ARCHS_STR})") - message(STATUS "Marlin generation script hash: ${MARLIN_GEN_SCRIPT_HASH}") - message(STATUS "Last run Marlin generate script hash: $CACHE{MARLIN_GEN_SCRIPT_HASH}") + message(STATUS "Marlin generation script hash: ${MARLIN_GEN_SCRIPT_HASH_AND_ARCH}") + message(STATUS "Last run Marlin generate script hash: $CACHE{MARLIN_GEN_SCRIPT_HASH_AND_ARCH}") - if (NOT DEFINED CACHE{MARLIN_GEN_SCRIPT_HASH} - OR NOT $CACHE{MARLIN_GEN_SCRIPT_HASH} STREQUAL ${MARLIN_GEN_SCRIPT_HASH}) + if (NOT DEFINED CACHE{MARLIN_GEN_SCRIPT_HASH_AND_ARCH} + OR NOT $CACHE{MARLIN_GEN_SCRIPT_HASH_AND_ARCH} STREQUAL ${MARLIN_GEN_SCRIPT_HASH_AND_ARCH}) execute_process( COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=$PYTHONPATH - ${Python_EXECUTABLE} ${MARLIN_GEN_SCRIPT} + ${Python_EXECUTABLE} ${MARLIN_GEN_SCRIPT} ${CUDA_ARCHS_STR} RESULT_VARIABLE marlin_generation_result OUTPUT_VARIABLE marlin_generation_result OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/marlin_generation.log @@ -387,15 +398,15 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "\nCheck the log for details: " "${CMAKE_CURRENT_BINARY_DIR}/marlin_generation.log") else() - set(MARLIN_GEN_SCRIPT_HASH ${MARLIN_GEN_SCRIPT_HASH} - CACHE STRING "Last run Marlin generate script hash" FORCE) + set(MARLIN_GEN_SCRIPT_HASH_AND_ARCH ${MARLIN_GEN_SCRIPT_HASH_AND_ARCH} + CACHE STRING "Last run Marlin generate script hash and arch" FORCE) message(STATUS "Marlin generation completed successfully.") endif() else() message(STATUS "Marlin generation script has not changed, skipping generation.") endif() - file(GLOB MARLIN_TEMPLATE_KERNEL_SRC "csrc/quantization/gptq_marlin/kernel_*.cu") + file(GLOB MARLIN_TEMPLATE_KERNEL_SRC "csrc/quantization/gptq_marlin/sm80_kernel_*_float16.cu") set_gencode_flags_for_srcs( SRCS "${MARLIN_TEMPLATE_KERNEL_SRC}" CUDA_ARCHS "${MARLIN_ARCHS}") @@ -403,12 +414,34 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") set_source_files_properties(${MARLIN_TEMPLATE_KERNEL_SRC} PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false") endif() - list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC}) + file(GLOB MARLIN_TEMPLATE_BF16_KERNEL_SRC 
"csrc/quantization/gptq_marlin/sm80_kernel_*_bfloat16.cu") + set_gencode_flags_for_srcs( + SRCS "${MARLIN_TEMPLATE_BF16_KERNEL_SRC}" + CUDA_ARCHS "${MARLIN_BF16_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8) + set_source_files_properties(${MARLIN_TEMPLATE_BF16_KERNEL_SRC} + PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false") + endif() + list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_BF16_KERNEL_SRC}) + + if (MARLIN_FP8_ARCHS) + file(GLOB MARLIN_TEMPLATE_FP8_KERNEL_SRC "csrc/quantization/gptq_marlin/sm89_kernel_*.cu") + set_gencode_flags_for_srcs( + SRCS "${MARLIN_TEMPLATE_FP8_KERNEL_SRC}" + CUDA_ARCHS "${MARLIN_FP8_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8) + set_source_files_properties(${MARLIN_TEMPLATE_FP8_KERNEL_SRC} + PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false") + endif() + list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_FP8_KERNEL_SRC}) + endif() + set(MARLIN_SRCS "csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu" "csrc/quantization/gptq_marlin/gptq_marlin.cu" + "csrc/quantization/gptq_marlin/marlin_int4_fp8_preprocess.cu" "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu" "csrc/quantization/gptq_marlin/awq_marlin_repack.cu") set_gencode_flags_for_srcs( @@ -941,8 +974,15 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") CUDA_ARCHS "${CUDA_ARCHS}") list(APPEND VLLM_MOE_EXT_SRC "${VLLM_MOE_WNA16_SRC}") - # 9.0 for latest bf16 atomicAdd PTX - cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}") + # moe marlin arches + # note that we always set `use_atomic_add=False` for moe marlin now, + # so we don't need 9.0 for bf16 atomicAdd PTX + cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0+PTX" "${CUDA_ARCHS}") + # moe marlin arches for fp8 input + # - sm80 doesn't support fp8 computation + # - sm90 and sm100 don't support QMMA.16832.F32.E4M3.E4M3 SAAS instruction + # so we only enable fp8 computation for SM89 (e.g. RTX 40x0) and 12.0 (e.g. 
RTX 50x0) + cuda_archs_loose_intersection(MARLIN_MOE_FP8_ARCHS "8.9;12.0" "${CUDA_ARCHS}") if (MARLIN_MOE_ARCHS) # @@ -952,16 +992,18 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") set(MOE_MARLIN_GEN_SCRIPT ${CMAKE_CURRENT_SOURCE_DIR}/csrc/moe/marlin_moe_wna16/generate_kernels.py) file(MD5 ${MOE_MARLIN_GEN_SCRIPT} MOE_MARLIN_GEN_SCRIPT_HASH) + list(JOIN CUDA_ARCHS "," CUDA_ARCHS_STR) + set(MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH "${MOE_MARLIN_GEN_SCRIPT_HASH}(ARCH:${CUDA_ARCHS_STR})") - message(STATUS "Marlin MOE generation script hash: ${MOE_MARLIN_GEN_SCRIPT_HASH}") - message(STATUS "Last run Marlin MOE generate script hash: $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH}") + message(STATUS "Marlin MOE generation script hash with arch: ${MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH}") + message(STATUS "Last run Marlin MOE generate script hash with arch: $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH}") - if (NOT DEFINED CACHE{MOE_MARLIN_GEN_SCRIPT_HASH} - OR NOT $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH} STREQUAL ${MOE_MARLIN_GEN_SCRIPT_HASH}) + if (NOT DEFINED CACHE{MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH} + OR NOT $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH} STREQUAL ${MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH}) execute_process( COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=$PYTHONPATH - ${Python_EXECUTABLE} ${MOE_MARLIN_GEN_SCRIPT} + ${Python_EXECUTABLE} ${MOE_MARLIN_GEN_SCRIPT} ${CUDA_ARCHS_STR} RESULT_VARIABLE moe_marlin_generation_result OUTPUT_VARIABLE moe_marlin_generation_output OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log @@ -974,7 +1016,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "\nCheck the log for details: " "${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log") else() - set(MOE_MARLIN_GEN_SCRIPT_HASH ${MOE_MARLIN_GEN_SCRIPT_HASH} + set(MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH ${MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH} CACHE STRING "Last run Marlin MOE generate script hash" FORCE) message(STATUS "Marlin MOE generation completed successfully.") endif() @@ -982,16 +1024,28 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") message(STATUS "Marlin MOE generation script has not changed, skipping generation.") endif() - file(GLOB MOE_WNAA16_MARLIN_SRC "csrc/moe/marlin_moe_wna16/*.cu") + file(GLOB MARLIN_MOE_SRC "csrc/moe/marlin_moe_wna16/sm80_kernel_*.cu") + list(APPEND MARLIN_MOE_SRC "csrc/moe/marlin_moe_wna16/ops.cu") set_gencode_flags_for_srcs( - SRCS "${MOE_WNAA16_MARLIN_SRC}" + SRCS "${MARLIN_MOE_SRC}" CUDA_ARCHS "${MARLIN_MOE_ARCHS}") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8) - set_source_files_properties(${MOE_WNAA16_MARLIN_SRC} + set_source_files_properties(${MARLIN_MOE_SRC} PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false") endif() + list(APPEND VLLM_MOE_EXT_SRC ${MARLIN_MOE_SRC}) - list(APPEND VLLM_MOE_EXT_SRC ${MOE_WNAA16_MARLIN_SRC}) + if (MARLIN_MOE_FP8_ARCHS) + file(GLOB MARLIN_MOE_FP8_SRC "csrc/moe/marlin_moe_wna16/sm89_kernel_*.cu") + set_gencode_flags_for_srcs( + SRCS "${MARLIN_MOE_FP8_SRC}" + CUDA_ARCHS "${MARLIN_MOE_FP8_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8) + set_source_files_properties(${MARLIN_MOE_FP8_SRC} + PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false") + endif() + list(APPEND VLLM_MOE_EXT_SRC ${MARLIN_MOE_FP8_SRC}) + endif() message(STATUS "Building Marlin MOE kernels for archs: ${MARLIN_MOE_ARCHS}") else() diff --git a/benchmarks/kernels/benchmark_machete.py b/benchmarks/kernels/benchmark_machete.py index 8787724d77cfb..ac78c019a59e5 100644 --- a/benchmarks/kernels/benchmark_machete.py +++ b/benchmarks/kernels/benchmark_machete.py @@ -237,6 
+237,7 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable: b_q_weight=w_q, b_bias=None, b_scales=w_s, + a_scales=None, global_scale=None, b_zeros=w_zp, g_idx=g_idx, diff --git a/benchmarks/kernels/benchmark_marlin.py b/benchmarks/kernels/benchmark_marlin.py index 12ca9214b1f95..48d790aec9e07 100644 --- a/benchmarks/kernels/benchmark_marlin.py +++ b/benchmarks/kernels/benchmark_marlin.py @@ -263,7 +263,7 @@ def bench_run( results.append( benchmark.Timer( - stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)", # noqa: E501 + stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, None, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, @@ -273,7 +273,7 @@ def bench_run( results.append( benchmark.Timer( - stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)", # noqa: E501 + stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, None, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, diff --git a/csrc/moe/marlin_moe_wna16/.gitignore b/csrc/moe/marlin_moe_wna16/.gitignore index 77088552b85b4..ba805f9250ece 100644 --- a/csrc/moe/marlin_moe_wna16/.gitignore +++ b/csrc/moe/marlin_moe_wna16/.gitignore @@ -1 +1,2 @@ -kernel_*.cu \ No newline at end of file +sm*_kernel_*.cu +kernel_selector.h diff --git a/csrc/moe/marlin_moe_wna16/generate_kernels.py b/csrc/moe/marlin_moe_wna16/generate_kernels.py index be5b68cc53e6f..88f1055337fd5 100644 --- a/csrc/moe/marlin_moe_wna16/generate_kernels.py +++ b/csrc/moe/marlin_moe_wna16/generate_kernels.py @@ -4,134 +4,282 @@ import glob import itertools import os import subprocess +import sys import jinja2 -FILE_HEAD = """ -// auto generated by generate.py -// clang-format off +ARCHS = [] +SUPPORT_FP8 = False +for arch in sys.argv[1].split(","): + arch = arch[: arch.index(".") + 2].replace(".", "") + arch = int(arch) + # only SM89 and SM120 fully support + # mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32. + # SM90 and SM100 can use this PTX, but it’s simulated + # with FP16 MMA, so it cannot achieve any acceleration. + if arch in [89, 120]: + SUPPORT_FP8 = True +FILE_HEAD_COMMENT = """ +// auto generated by generate_kernels.py +// clang-format off +""".lstrip() + +FILE_HEAD = ( + FILE_HEAD_COMMENT + + """ #include "kernel.h" #include "marlin_template.h" namespace MARLIN_NAMESPACE_NAME { -""".strip() +""" +) TEMPLATE = ( "template __global__ void Marlin<" - "{{scalar_t}}, " - "{{w_type_id}}, " + "{{a_type_id}}, " + "{{b_type_id}}, " + "{{c_type_id}}, " "{{s_type_id}}, " "{{threads}}, " "{{thread_m_blocks}}, " "{{thread_n_blocks}}, " "{{thread_k_blocks}}, " - "{{'true' if m_block_size_8 else 'false'}}, " + "{{m_block_size_8}}, " "{{stages}}, " "{{group_blocks}}, " - "{{'true' if is_zp_float else 'false'}}>" + "{{is_zp_float}}>" "( MARLIN_KERNEL_PARAMS );" ) -# int8 with zero point case (vllm::kU8) is also supported, -# we don't add it to reduce wheel size. 
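(For reference, a minimal self-contained sketch of the argv[1] handling added at the top of this script: CMake now joins CUDA_ARCHS with commas and passes the result as the first argument, each entry is reduced to its major/minor digits, and the fp8-activation kernels are emitted only when 8.9 or 12.0 is present. `parse_archs` below is an illustrative helper, not a function defined in this diff.)

# Hedged sketch mirroring the SUPPORT_FP8 logic above; not part of the diff itself.
def parse_archs(arch_str: str) -> tuple[list[int], bool]:
    archs = []
    for arch in arch_str.split(","):
        # "8.9+PTX" -> 89, "12.0" -> 120: keep one char past the dot, then drop the dot
        archs.append(int(arch[: arch.index(".") + 2].replace(".", "")))
    # only SM89 / SM120 execute the e4m3 MMA natively; SM90 / SM100 emulate it with fp16
    support_fp8 = any(a in (89, 120) for a in archs)
    return archs, support_fp8

assert parse_archs("8.0+PTX,9.0") == ([80, 90], False)
assert parse_archs("8.0+PTX,8.9,12.0") == ([80, 89, 120], True)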
-SCALAR_TYPES = [ - "vllm::kU4", - "vllm::kU4B8", - "vllm::kU8B128", - "vllm::kFE4M3fn", - "vllm::kFE2M1f", -] THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)] THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4] -# group_blocks: -# = 0 : act order case -# = -1 : channelwise quantization -# > 0 : group_size=16*group_blocks -GROUP_BLOCKS = [0, -1, 1, 2, 4, 8] -DTYPES = ["fp16", "bf16"] + +QUANT_CONFIGS = [ + # AWQ-INT4 + { + "b_type": "kU4", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": THREAD_M_BLOCKS, + "group_blocks": [-1, 2, 4, 8], + }, + # GPTQ-INT4 + { + "b_type": "kU4B8", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": THREAD_M_BLOCKS, + "group_blocks": [-1, 0, 2, 4, 8], + }, + # AWQ-INT8 + { + "b_type": "kU8B128", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": THREAD_M_BLOCKS, + "group_blocks": [-1, 0, 2, 4, 8], + }, + # FP8 + { + "b_type": "kFE4M3fn", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": THREAD_M_BLOCKS, + "group_blocks": [-1, 8], + }, + # NVFP4 + { + "b_type": "kFE2M1f", + "s_type": "kFE4M3fn", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": THREAD_M_BLOCKS, + "group_blocks": [1], + }, + # MXFP4 + { + "a_type": ["kBFloat16"], + "b_type": "kFE2M1f", + "s_type": "kFE8M0fnu", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": THREAD_M_BLOCKS, + "group_blocks": [2], + }, + # AWQ-INT4 with INT8 activation + { + "a_type": ["kS8"], + "b_type": "kU4", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": [1, 2, 3, 4], + "group_blocks": [-1, 2, 4, 8], + }, + # GPTQ-INT4 with INT8 activation + { + "a_type": ["kS8"], + "b_type": "kU4B8", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": [1, 2, 3, 4], + "group_blocks": [-1, 2, 4, 8], + }, + # GPTQ-INT4 with FP8 activation + { + "a_type": ["kFE4M3fn"], + "b_type": "kU4B8", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": [1, 2, 3, 4], + "group_blocks": [-1, 2, 4, 8], + }, + # AWQ-INT4 with FP8 activation + { + "a_type": ["kFE4M3fn"], + "b_type": "kU4", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": [1, 2, 3, 4], + "group_blocks": [-1, 2, 4, 8], + }, + # MXFP4 with FP8 activation + { + "a_type": ["kFE4M3fn"], + "b_type": "kFE2M1f", + "c_type": ["kBFloat16"], + "s_type": "kFE8M0fnu", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": [1, 2, 3, 4], + "group_blocks": [2], + }, +] def remove_old_kernels(): - for filename in glob.glob(os.path.dirname(__file__) + "/kernel_*.cu"): + for filename in glob.glob(os.path.dirname(__file__) + "/*kernel_*.cu"): subprocess.call(["rm", "-f", filename]) + filename = os.path.dirname(__file__) + "/kernel_selector.h" + subprocess.call(["rm", "-f", filename]) + def generate_new_kernels(): - for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES): + result_dict = {} + + for quant_config in QUANT_CONFIGS: + c_types = quant_config.get("c_type", ["kFloat16", "kBFloat16"]) + a_types = quant_config.get("a_type", ["kFloat16", "kBFloat16"]) + b_type = quant_config["b_type"] + all_group_blocks = quant_config["group_blocks"] + all_m_blocks = quant_config["thread_m_blocks"] + all_thread_configs = quant_config["thread_configs"] + + for a_type, c_type in itertools.product(a_types, c_types): + if not SUPPORT_FP8 and a_type == "kFE4M3fn": + continue + if "16" in a_type and "16" in c_type and a_type != c_type: + continue + s_type = quant_config.get("s_type", c_type) + if (a_type, b_type, c_type) not in result_dict: + result_dict[(a_type, b_type, c_type)] = [] + + for group_blocks, m_blocks, thread_configs in 
itertools.product( + all_group_blocks, all_m_blocks, all_thread_configs + ): + thread_k, thread_n, threads = thread_configs + + if threads == 256: + # for small batch (m_blocks == 1), + # we only need (128, 128, 256) + # for large batch (m_blocks > 1), + # we only need (64, 256, 256) + if m_blocks <= 1 and (thread_k, thread_n) != (128, 128): + continue + if m_blocks > 1 and (thread_k, thread_n) != (64, 256): + continue + + config = { + "threads": threads, + "s_type": s_type, + "thread_m_blocks": max(m_blocks, 1), + "thread_k_blocks": thread_k // 16, + "thread_n_blocks": thread_n // 16, + "m_block_size_8": "true" if m_blocks == 0.5 else "false", + "stages": "pipe_stages", + "group_blocks": group_blocks, + "is_zp_float": "false", + } + + result_dict[(a_type, b_type, c_type)].append(config) + + kernel_selector_str = FILE_HEAD_COMMENT + + for (a_type, b_type, c_type), config_list in result_dict.items(): all_template_str_list = [] - - for group_blocks, m_blocks, thread_configs in itertools.product( - GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS - ): - # act order case only support gptq-int4 and gptq-int8 - if group_blocks == 0 and scalar_type not in [ - "vllm::kU4B8", - "vllm::kU8B128", - ]: - continue - if thread_configs[2] == 256: - # for small batch (m_blocks == 1), we only need (128, 128, 256) - # for large batch (m_blocks > 1), we only need (64, 256, 256) - if m_blocks <= 1 and thread_configs[0] != 128: - continue - if m_blocks > 1 and thread_configs[0] != 64: - continue - - # we only support channelwise quantization and group_size == 128 - # for fp8 - if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]: - continue - # nvfp4 only supports group_size == 16 - # mxfp4 only supports group_size == 32 - if scalar_type == "vllm::kFE2M1f" and group_blocks not in [1, 2]: - continue - # other quantization methods don't support group_size = 16 - if scalar_type != "vllm::kFE2M1f" and group_blocks == 1: - continue - - k_blocks = thread_configs[0] // 16 - n_blocks = thread_configs[1] // 16 - threads = thread_configs[2] - - c_dtype = "half" if dtype == "fp16" else "nv_bfloat16" - - if scalar_type == "vllm::kFE2M1f" and group_blocks == 1: - s_type = "vllm::kFE4M3fn" - elif scalar_type == "vllm::kFE2M1f" and group_blocks == 2: - s_type = "vllm::kFE8M0fnu" - if dtype == "fp16": - # we cannot safely dequantize e8m0 to fp16, so skip this - continue - elif dtype == "fp16": - s_type = "vllm::kFloat16" - elif dtype == "bf16": - s_type = "vllm::kBFloat16" - + for config in config_list: + s_type = config["s_type"] template_str = jinja2.Template(TEMPLATE).render( - scalar_t=c_dtype, - w_type_id=scalar_type + ".id()", - s_type_id=s_type + ".id()", - threads=threads, - thread_m_blocks=max(m_blocks, 1), - thread_n_blocks=n_blocks, - thread_k_blocks=k_blocks, - m_block_size_8=m_blocks == 0.5, - stages="pipe_stages", - group_blocks=group_blocks, - is_zp_float=False, + a_type_id=f"vllm::{a_type}.id()", + b_type_id=f"vllm::{b_type}.id()", + c_type_id=f"vllm::{c_type}.id()", + s_type_id=f"vllm::{s_type}.id()", + **config, + ) + all_template_str_list.append(template_str) + + conditions = [ + f"a_type == vllm::{a_type}", + f"b_type == vllm::{b_type}", + f"c_type == vllm::{c_type}", + f"s_type == vllm::{s_type}", + f"threads == {config['threads']}", + f"thread_m_blocks == {config['thread_m_blocks']}", + f"thread_n_blocks == {config['thread_n_blocks']}", + f"thread_k_blocks == {config['thread_k_blocks']}", + f"m_block_size_8 == {config['m_block_size_8']}", + f"group_blocks == {config['group_blocks']}", + 
f"is_zp_float == {config['is_zp_float']}", + ] + conditions = " && ".join(conditions) + + if kernel_selector_str == FILE_HEAD_COMMENT: + kernel_selector_str += f"if ({conditions})\n kernel = " + else: + kernel_selector_str += f"else if ({conditions})\n kernel = " + + kernel_template2 = ( + "Marlin<{{a_type_id}}, {{b_type_id}}, {{c_type_id}}, " + "{{s_type_id}}, {{threads}}, {{thread_m_blocks}}, " + "{{thread_n_blocks}}, {{thread_k_blocks}}, " + "{{m_block_size_8}}, {{stages}}, {{group_blocks}}, " + "{{is_zp_float}}>;" ) - all_template_str_list.append(template_str) + kernel_selector_str += ( + jinja2.Template(kernel_template2).render( + a_type_id=f"vllm::{a_type}.id()", + b_type_id=f"vllm::{b_type}.id()", + c_type_id=f"vllm::{c_type}.id()", + s_type_id=f"vllm::{s_type}.id()", + **config, + ) + + "\n" + ) file_content = FILE_HEAD + "\n\n" file_content += "\n\n".join(all_template_str_list) + "\n\n}\n" - filename = f"kernel_{dtype}_{scalar_type[6:].lower()}.cu" + if a_type == "kFE4M3fn": + filename = f"sm89_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu" + else: + filename = f"sm80_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu" + + filename = filename.lower() with open(os.path.join(os.path.dirname(__file__), filename), "w") as f: f.write(file_content) + if not SUPPORT_FP8 and kernel_selector_str != FILE_HEAD_COMMENT: + kernel_selector_str += ( + "else if (a_type == vllm::kFE4M3fn)\n" + " TORCH_CHECK(false, " + '"marlin kernel with fp8 activation is not built.");' + ) + + with open(os.path.join(os.path.dirname(__file__), "kernel_selector.h"), "w") as f: + f.write(kernel_selector_str) + if __name__ == "__main__": remove_old_kernels() diff --git a/csrc/moe/marlin_moe_wna16/kernel.h b/csrc/moe/marlin_moe_wna16/kernel.h index 6190f7ee21ece..57f5a17932d44 100644 --- a/csrc/moe/marlin_moe_wna16/kernel.h +++ b/csrc/moe/marlin_moe_wna16/kernel.h @@ -11,8 +11,9 @@ const int4 *__restrict__ A, const int4 *__restrict__ B, \ int4 *__restrict__ C, int4 *__restrict__ C_tmp, \ const int4 *__restrict__ b_bias_ptr, \ + const float *__restrict__ a_scales_ptr, \ const int4 *__restrict__ scales_ptr, \ - const uint16_t *__restrict__ scale2_ptr, \ + const uint16_t *__restrict__ global_scale_ptr, \ const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \ const int32_t *__restrict__ sorted_token_ids_ptr, \ const int32_t *__restrict__ expert_ids_ptr, \ @@ -20,12 +21,13 @@ const float *__restrict__ topk_weights_ptr, int top_k, \ bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \ int prob_n, int prob_k, int *locks, bool has_bias, bool use_atomic_add, \ - bool use_fp32_reduce, int max_shared_mem + bool use_fp32_reduce namespace MARLIN_NAMESPACE_NAME { -template shared // fetch pipeline - const int group_blocks, // number of consecutive 16x16 blocks - // with a separate quantization scale - const bool is_zp_float // is zero point of float16 type? + const bool has_act_order, // whether act_order is enabled + const int group_blocks, // number of consecutive 16x16 blocks + // with a separate quantization scale + const bool is_zp_float // is zero point of float16 type? 
> __global__ void Marlin( const int4* __restrict__ A, // fp16 input matrix of shape mxk @@ -76,8 +77,8 @@ __global__ void Marlin( int prob_k, // reduction dimension k int* locks, // extra global storage for barrier synchronization bool use_atomic_add, // whether to use atomic add to reduce - bool use_fp32_reduce, // whether to use fp32 global reduce - int max_shared_mem) {} + bool use_fp32_reduce // whether to use fp32 global reduce +) {} } // namespace MARLIN_NAMESPACE_NAME @@ -85,65 +86,148 @@ __global__ void Marlin( // m16n8k16 tensor core mma instruction with fp16 inputs and fp32 // output/accumulation. -template -__device__ inline void mma(const typename ScalarType::FragA& a_frag, - const typename ScalarType::FragB& frag_b, - typename ScalarType::FragC& frag_c) { +template +__device__ inline void mma( + const typename MarlinScalarType::FragA& a_frag, + const typename MarlinScalarType::FragB& frag_b, + typename MarlinScalarType::FragC& frag_c, int idx = 0) { const uint32_t* a = reinterpret_cast(&a_frag); const uint32_t* b = reinterpret_cast(&frag_b); - float* c = reinterpret_cast(&frag_c); - if constexpr (std::is_same::value) { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); - } else if constexpr (std::is_same::value) { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); - } else { - STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); + using scalar_t = typename MarlinScalarType::scalar_t; + if constexpr (k_size == 16) { + if constexpr (std::is_same::value) { + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 " + "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "f"(c[0]), + "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + int32_t* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite " + "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" + : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) + : "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "r"(c[0]), + "r"(c[1]), "r"(c[2]), "r"(c[3])); + } + } else if (k_size == 32) { + if constexpr (std::is_same::value) { + float* c = 
reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + int32_t* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3])); + } } } -template +template __device__ inline void mma_trans( - const typename ScalarType::FragA& a_frag, - const typename ScalarType::FragB& frag_b, - const typename ScalarType::FragB& frag_b2, - typename ScalarType::FragC& frag_c) { + const typename MarlinScalarType::FragA& a_frag, + const typename MarlinScalarType::FragB& frag_b, + const typename MarlinScalarType::FragB& frag_b2, + typename MarlinScalarType::FragC& frag_c) { const uint32_t* a = reinterpret_cast(&a_frag); const uint32_t* b = reinterpret_cast(&frag_b); const uint32_t* b2 = reinterpret_cast(&frag_b2); float* c = reinterpret_cast(&frag_c); - if constexpr (std::is_same::value) { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); - } else if constexpr (std::is_same::value) { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + using scalar_t = typename MarlinScalarType::scalar_t; + if constexpr (k_size == 16) { + if constexpr (std::is_same::value) { + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 " + "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[0]), "r"(b2[0]), "r"(a[0]), "f"(c[0]), "f"(c[1]), "f"(c[2]), + "f"(c[3])); + } else if constexpr (std::is_same::value) { + int32_t* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite " + "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" + : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) + : "r"(b[0]), "r"(b2[0]), 
"r"(a[0]), "r"(c[0]), "r"(c[1]), "r"(c[2]), + "r"(c[3])); + } } else { - STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); + if constexpr (std::is_same::value) { + float* c = reinterpret_cast(&frag_c); + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 1200 + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e4m3.e4m3.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + #else + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + #endif + } else if constexpr (std::is_same::value) { + int32_t* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) + : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), + "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3])); + } } } // Instruction for loading a full 16x16 matrix fragment of operand A from shared // memory, directly in tensor core layout. -template -__device__ inline void ldsm(typename ScalarType::FragA& frag_a, +template +__device__ inline void ldsm(typename MarlinScalarType::FragA& frag_a, const void* smem_ptr) { uint32_t* a = reinterpret_cast(&frag_a); uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); @@ -167,47 +251,54 @@ __device__ inline void ldsm(typename ScalarType::FragA& frag_a, // Multiply dequantized values by the corresponding quantization scale; used // only for grouped quantization. 
-template -__device__ inline void scale(typename ScalarType::FragB& frag_b, - typename ScalarType::FragS& frag_s, +template +__device__ inline void scale(typename MarlinScalarType::FragB& frag_b, + typename MarlinScalarType::FragS& frag_s, int i) { - using scalar_t2 = typename ScalarType::scalar_t2; - scalar_t2 s = - ScalarType::num2num2(reinterpret_cast(&frag_s)[i]); + using scalar_t = typename MarlinScalarType::scalar_t; + using scalar_t2 = typename MarlinScalarType::scalar_t2; + scalar_t2 s = MarlinScalarType::num2num2( + reinterpret_cast(&frag_s)[i]); frag_b[0] = __hmul2(frag_b[0], s); frag_b[1] = __hmul2(frag_b[1], s); } -template +template __device__ inline void scale_and_sub( - typename ScalarType::FragB& frag_b, scalar_t s, scalar_t zp) { - using scalar_t2 = typename ScalarType::scalar_t2; - scalar_t2 s2 = ScalarType::num2num2(s); - scalar_t2 zp2 = ScalarType::num2num2(zp); + typename MarlinScalarType::FragB& frag_b, + typename MarlinScalarType::scalar_t s, + typename MarlinScalarType::scalar_t zp) { + using scalar_t = typename MarlinScalarType::scalar_t; + using scalar_t2 = typename MarlinScalarType::scalar_t2; + scalar_t2 s2 = MarlinScalarType::num2num2(s); + scalar_t2 zp2 = MarlinScalarType::num2num2(zp); frag_b[0] = __hfma2(frag_b[0], s2, __hneg2(zp2)); frag_b[1] = __hfma2(frag_b[1], s2, __hneg2(zp2)); } -template -__device__ inline void sub_zp(typename ScalarType::FragB& frag_b, - typename ScalarType::scalar_t2& frag_zp, - int i) { - using scalar_t2 = typename ScalarType::scalar_t2; - scalar_t2 zp = - ScalarType::num2num2(reinterpret_cast(&frag_zp)[i]); +template +__device__ inline void sub_zp( + typename MarlinScalarType::FragB& frag_b, + typename MarlinScalarType::scalar_t2& frag_zp, int i) { + using scalar_t = typename MarlinScalarType::scalar_t; + using scalar_t2 = typename MarlinScalarType::scalar_t2; + scalar_t2 zp = MarlinScalarType::num2num2( + reinterpret_cast(&frag_zp)[i]); frag_b[0] = __hsub2(frag_b[0], zp); frag_b[1] = __hsub2(frag_b[1], zp); } // Same as above, but for act_order (each K is multiplied individually) -template -__device__ inline void scale4(typename ScalarType::FragB& frag_b, - typename ScalarType::FragS& frag_s_1, - typename ScalarType::FragS& frag_s_2, - typename ScalarType::FragS& frag_s_3, - typename ScalarType::FragS& frag_s_4, - int i) { - using scalar_t2 = typename ScalarType::scalar_t2; +template +__device__ inline void scale4( + typename MarlinScalarType::FragB& frag_b, + typename MarlinScalarType::FragS& frag_s_1, + typename MarlinScalarType::FragS& frag_s_2, + typename MarlinScalarType::FragS& frag_s_3, + typename MarlinScalarType::FragS& frag_s_4, int i) { + using scalar_t = typename MarlinScalarType::scalar_t; + using scalar_t2 = typename MarlinScalarType::scalar_t2; + scalar_t2 s_val_1_2; s_val_1_2.x = reinterpret_cast(&frag_s_1)[i]; s_val_1_2.y = reinterpret_cast(&frag_s_2)[i]; @@ -221,12 +312,13 @@ __device__ inline void scale4(typename ScalarType::FragB& frag_b, } // Given 2 floats multiply by 2 scales (halves) -template -__device__ inline void scale_float(float* c, - typename ScalarType::FragS& s) { +template +__device__ inline void scale_float( + float* c, typename MarlinScalarType::FragS& s) { + using scalar_t = typename MarlinScalarType::scalar_t; scalar_t* s_ptr = reinterpret_cast(&s); - c[0] = __fmul_rn(c[0], ScalarType::num2float(s_ptr[0])); - c[1] = __fmul_rn(c[1], ScalarType::num2float(s_ptr[1])); + c[0] = __fmul_rn(c[0], MarlinScalarType::num2float(s_ptr[0])); + c[1] = __fmul_rn(c[1], MarlinScalarType::num2float(s_ptr[1])); 
} // Wait until barrier reaches `count`, then lock for current threadblock. @@ -278,9 +370,10 @@ __device__ inline void wait_negative_and_add(int* lock) { __syncthreads(); } -template ; - using scalar_t2 = typename ScalarType::scalar_t2; - using FragA = typename ScalarType::FragA; - using FragB = typename ScalarType::FragB; - using FragC = typename ScalarType::FragC; - using FragS = typename ScalarType::FragS; - using FragZP = typename ScalarType::FragZP; + + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 890 + // FP8 computation is only supported for Ada Lovelace or newer architectures. + if constexpr (a_type_id == vllm::kFE4M3fn.id()) return; + #endif + + int num_tokens_past_padded = num_tokens_past_padded_ptr[0]; + constexpr int moe_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks); + + using Adtype = MarlinScalarType; + using Cdtype = MarlinScalarType; + + using scalar_t = typename MarlinScalarType::scalar_t; + using scalar_t2 = typename MarlinScalarType::scalar_t2; + using scalar_32bit_t = typename MarlinScalarType::scalar_32bit_t; + + using c_scalar_t = typename MarlinScalarType::scalar_t; + using c_scalar_t2 = typename MarlinScalarType::scalar_t2; + + using FragA = typename MarlinScalarType::FragA; + using FragB = typename MarlinScalarType::FragB; + using FragC = typename MarlinScalarType::FragC; + using FragS = typename MarlinScalarType::FragS; + using FragZP = typename MarlinScalarType::FragZP; extern __shared__ int4 sh[]; - static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id); + static constexpr auto a_type = vllm::ScalarType::from_id(a_type_id); + static constexpr auto b_type = vllm::ScalarType::from_id(b_type_id); + static constexpr auto c_type = vllm::ScalarType::from_id(c_type_id); static constexpr auto s_type = vllm::ScalarType::from_id(s_type_id); - if constexpr (w_type == vllm::kFE2M1f) { + if constexpr (b_type == vllm::kFE2M1f) { static_assert(s_type == vllm::kFE4M3fn && group_blocks == 1 || s_type == vllm::kFE8M0fnu && group_blocks == 2); } else if constexpr (std::is_same::value) { @@ -355,34 +472,37 @@ __global__ void Marlin( static_assert(s_type == vllm::kFloat16); } - constexpr bool has_zp = w_type == vllm::kU4 || w_type == vllm::kU8; - constexpr bool is_int_type = w_type == vllm::kU4 || w_type == vllm::kU8 || - w_type == vllm::kU4B8 || w_type == vllm::kU8B128; + constexpr bool is_a_8bit = a_type.size_bits() == 8; + if constexpr (!is_a_8bit) { + static_assert(std::is_same::value); + } + constexpr bool has_zp = b_type == vllm::kU4 || b_type == vllm::kU8; + constexpr bool is_int_type = b_type == vllm::kU4 || b_type == vllm::kU8 || + b_type == vllm::kS4 || b_type == vllm::kS8 || + b_type == vllm::kU4B8 || b_type == vllm::kU8B128; // see comments of dequant.h for more details constexpr bool dequant_skip_flop = - w_type == vllm::kFE4M3fn || - w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn || + is_a_8bit || b_type == vllm::kFE4M3fn || + b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn || has_zp && !is_zp_float && !std::is_same::value || - has_zp && !is_zp_float && !(w_type == vllm::kU8); + has_zp && !is_zp_float && !(b_type == vllm::kU8); - scalar_t2 global_scale; + c_scalar_t2 global_scale; constexpr bool has_act_order = group_blocks == 0; - constexpr int pack_factor = 32 / w_type.size_bits(); + constexpr int pack_factor = 32 / b_type.size_bits(); static_assert(thread_m_blocks == 1 || !m_block_size_8); - constexpr int moe_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks); const int group_size = (!has_act_order && group_blocks == -1) ? 
prob_k : prob_k / num_groups; const int scales_expert_stride = - prob_n * prob_k / group_size / (w_type == vllm::kFE2M1f ? 16 : 8); + prob_n * prob_k / group_size / (b_type == vllm::kFE2M1f ? 16 : 8); const int zp_expert_stride = is_zp_float ? prob_n * prob_k / group_size / 8 : prob_n * prob_k / group_size / (pack_factor * 4); const int b_bias_expert_stride = prob_n / 8; // parallel: num valid moe blocks - int num_tokens_past_padded = num_tokens_past_padded_ptr[0]; int parallel = num_tokens_past_padded / moe_block_size; int num_valid_blocks = parallel; if (is_ep) { @@ -395,7 +515,23 @@ __global__ void Marlin( int k_tiles = prob_k / 16 / thread_k_blocks; int n_tiles = prob_n / 16 / thread_n_blocks; - int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x); + + int global_mn_tiles = parallel * n_tiles; + int part2_mn_tiles = global_mn_tiles; + int part1_mn_iters = 0; + bool in_part2 = false; + + // we use DP + two-tile SK here + // part1: DP + // part2: two-tile SK + // see https://github.com/vllm-project/vllm/pull/24722 for more details + if (global_mn_tiles > gridDim.x) { + part2_mn_tiles = global_mn_tiles % gridDim.x; + if (part2_mn_tiles * 3 <= gridDim.x) part2_mn_tiles += gridDim.x; + part1_mn_iters = (global_mn_tiles - part2_mn_tiles) / gridDim.x; + } + + int iters = div_ceil(k_tiles * part2_mn_tiles, gridDim.x); if constexpr (!has_act_order && group_blocks != -1) { if (group_blocks >= thread_k_blocks) { @@ -407,14 +543,15 @@ __global__ void Marlin( } } - int slice_row = (iters * blockIdx.x) % k_tiles; - int slice_col_par = (iters * blockIdx.x) / k_tiles; - int slice_col = slice_col_par; - int slice_iters; // number of threadblock tiles in the current slice - int slice_count = - 0; // total number of active threadblocks in the current slice - int slice_idx; // index of threadblock in current slice; numbered bottom to - // top + int slice_row = 0; + int slice_col_par = blockIdx.x; + int slice_col; + int slice_iters = + k_tiles; // number of threadblock tiles in the current slice + // total number of active threadblocks in the current slice + int slice_count = 1; + // index of threadblock in current slice; numbered bottom to top + int slice_idx = 0; int par_id = 0; int block_id = -1; @@ -422,87 +559,89 @@ __global__ void Marlin( int old_expert_id = 0; int64_t B_expert_off = 0; - int4* sh_block_sorted_ids_int4 = sh; + float* sh_a_s = reinterpret_cast(sh); + int4* sh_block_sorted_ids_int4 = sh + (is_a_8bit ? 
(4 * thread_m_blocks) : 0); int4* sh_rd_block_sorted_ids_int4 = sh_block_sorted_ids_int4 + moe_block_size / 4; int4* sh_block_topk_weights_int4 = sh_rd_block_sorted_ids_int4 + moe_block_size / 4; // sh_block_topk_weights_int4 only need (moe_block_size / 4); // but we pad to align to 256 bytes - int4* sh_new = - sh_block_topk_weights_int4 + moe_block_size / 2 + moe_block_size; + int4* sh_new = sh_block_topk_weights_int4 + moe_block_size / 2; int32_t* sh_block_sorted_ids = reinterpret_cast(sh_block_sorted_ids_int4); int32_t* sh_rd_block_sorted_ids = reinterpret_cast(sh_rd_block_sorted_ids_int4); - scalar_t2* sh_block_topk_weights = - reinterpret_cast(sh_block_topk_weights_int4); + c_scalar_t2* sh_block_topk_weights = + reinterpret_cast(sh_block_topk_weights_int4); int32_t block_num_valid_tokens = 0; int32_t locks_off = 0; // We can easily implement parallel problem execution by just remapping // indices and advancing global pointers - if (slice_col_par >= n_tiles) { - slice_col = slice_col_par % n_tiles; - par_id = slice_col_par / n_tiles; - } - if (parallel * n_tiles >= gridDim.x) { - // when parallel * n_tiles >= sms + if (part2_mn_tiles >= gridDim.x) { + // when part2_mn_tiles >= sms // then there are at most $sms$ conflict tile blocks locks_off = blockIdx.x; } else { locks_off = (iters * blockIdx.x) / k_tiles - 1; } + int prob_m_top_k = prob_m * top_k; // read moe block data given block_id // block_sorted_ids / block_num_valid_tokens / block_topk_weights auto read_moe_block_data = [&](int block_id) { block_num_valid_tokens = moe_block_size; + + cp_async4_pred(sh_block_sorted_ids_int4 + threadIdx.x, + reinterpret_cast(sorted_token_ids_ptr) + + (block_id * moe_block_size / 4 + threadIdx.x), + threadIdx.x < moe_block_size / 4); + + cp_async_fence(); + cp_async_wait<0>(); + + __syncthreads(); + + if (threadIdx.x >= threads - 32) { + constexpr int size_per_thread = div_ceil(moe_block_size, 32); + int lane_id = threadIdx.x - (threads - 32); + + int local_count = 0; #pragma unroll - for (int i = 0; i < moe_block_size / 4; i++) { - int4 sorted_token_ids_int4 = reinterpret_cast( - sorted_token_ids_ptr)[block_id * moe_block_size / 4 + i]; - int* sorted_token_ids = reinterpret_cast(&sorted_token_ids_int4); - #pragma unroll - for (int j = 0; j < 4; j++) { - if (sorted_token_ids[j] >= prob_m * top_k) { - block_num_valid_tokens = i * 4 + j; - break; + for (int i = 0; i < size_per_thread; i++) { + int j = lane_id * size_per_thread + i; + if (j < moe_block_size) { + int idx = sh_block_sorted_ids[j]; + if (idx < prob_m_top_k) local_count++; } } - if (block_num_valid_tokens != moe_block_size) break; + + block_num_valid_tokens = __reduce_add_sync(0xffffffff, local_count); + + if (lane_id == 0) + reinterpret_cast(sh_new)[0] = block_num_valid_tokens; + } + + if (threadIdx.x < moe_block_size) { + int idx = sh_block_sorted_ids[threadIdx.x]; + sh_rd_block_sorted_ids[threadIdx.x] = idx / top_k; + + if (mul_topk_weights) { + idx = idx < prob_m_top_k ? 
idx : 0; + c_scalar_t2 topk_weight_val = + Cdtype::num2num2(Cdtype::float2num(topk_weights_ptr[idx])); + if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) { + topk_weight_val = __hmul2(topk_weight_val, global_scale); + } + sh_block_topk_weights[threadIdx.x] = topk_weight_val; + } } __syncthreads(); - int tid4 = threadIdx.x / 4; - if (threadIdx.x % 4 == 0 && threadIdx.x < block_num_valid_tokens) { - sh_block_sorted_ids_int4[tid4] = reinterpret_cast( - sorted_token_ids_ptr)[block_id * moe_block_size / 4 + tid4]; - #pragma unroll - for (int i = 0; i < 4; i++) - sh_rd_block_sorted_ids[tid4 * 4 + i] = - sh_block_sorted_ids[tid4 * 4 + i] / top_k; - - if (mul_topk_weights) { - #pragma unroll - for (int i = 0; i < 4; i++) { - int idx = tid4 * 4 + i; - if (idx < block_num_valid_tokens) { - if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) { - sh_block_topk_weights[idx] = - __hmul2(global_scale, - Dtype::num2num2(Dtype::float2num( - topk_weights_ptr[sh_block_sorted_ids[idx]]))); - } else { - sh_block_topk_weights[idx] = Dtype::num2num2( - Dtype::float2num(topk_weights_ptr[sh_block_sorted_ids[idx]])); - } - } - } - } - } + block_num_valid_tokens = reinterpret_cast(sh_new)[0]; __syncthreads(); }; @@ -513,9 +652,8 @@ __global__ void Marlin( old_expert_id = expert_id; if (num_invalid_blocks > 0) { - int skip_count = block_id == -1 ? par_id : 0; - block_id++; - for (int i = block_id; i < num_tokens_past_padded / moe_block_size; i++) { + int skip_count = par_id; + for (int i = 0; i < num_tokens_past_padded / moe_block_size; i++) { expert_id = expert_ids_ptr[i]; if (expert_id != -1) { if (skip_count == 0) { @@ -530,9 +668,9 @@ __global__ void Marlin( expert_id = expert_ids_ptr[block_id]; } - if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) { - uint16_t val = scale2_ptr[expert_id]; - global_scale = Dtype::num2num2(*reinterpret_cast(&val)); + if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) { + uint16_t val = global_scale_ptr[expert_id]; + global_scale = Cdtype::num2num2(*reinterpret_cast(&val)); } B_expert_off = expert_id * prob_n * prob_k / (pack_factor * 4); @@ -552,10 +690,11 @@ __global__ void Marlin( // Compute all information about the current slice which is required for // synchronization. 
- auto init_slice = [&](bool first_init = false) { + bool first_init = true; + auto init_part2_slice = [&]() { slice_iters = iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); - if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; + if (slice_iters < 0 || slice_col_par >= part2_mn_tiles) slice_iters = 0; if (slice_iters == 0) return; if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; slice_count = 1; @@ -573,7 +712,7 @@ __global__ void Marlin( if (col_off > 0) slice_idx--; } } - if (parallel * n_tiles >= gridDim.x) { + if (part2_mn_tiles >= gridDim.x) { if (slice_count > 1 && slice_idx == slice_count - 1) { locks_off++; } @@ -607,25 +746,61 @@ __global__ void Marlin( par_id++; update_next_moe_block_data(); } + if (is_a_8bit && (first_init || slice_col == 0)) { + __syncthreads(); + cp_async1_ca_pred(&sh_a_s[threadIdx.x], + &a_scales_ptr[sh_rd_block_sorted_ids[threadIdx.x]], + threadIdx.x < block_num_valid_tokens); + } }; - update_next_moe_block_data(); - init_slice(true); + auto init_part1_slice = [&]() { + if (part1_mn_iters) { + part1_mn_iters--; + par_id = slice_col_par / n_tiles; + slice_col = slice_col_par % n_tiles; + slice_iters = k_tiles; + update_next_moe_block_data(); + if (is_a_8bit) { + __syncthreads(); + cp_async1_ca_pred(&sh_a_s[threadIdx.x], + &a_scales_ptr[sh_rd_block_sorted_ids[threadIdx.x]], + threadIdx.x < block_num_valid_tokens); + } + } + }; + + auto init_slice = [&]() { + if (!in_part2 && !part1_mn_iters) { + in_part2 = true; + slice_col_par = (iters * blockIdx.x) / k_tiles; + slice_row = (iters * blockIdx.x) % k_tiles; + slice_col = (slice_col_par + global_mn_tiles - part2_mn_tiles) % n_tiles; + par_id = (slice_col_par + global_mn_tiles - part2_mn_tiles) / n_tiles; + update_next_moe_block_data(); + } + if (!in_part2) { + init_part1_slice(); + } else { + init_part2_slice(); + first_init = false; + } + }; + + init_slice(); // A sizes/strides // stride of the A matrix in global memory - int a_gl_stride = prob_k / 8; + int a_gl_stride = prob_k / (is_a_8bit ? 16 : 8); // stride of an A matrix tile in shared memory - constexpr int a_sh_stride = 16 * thread_k_blocks / 8; + constexpr int a_sh_stride = 16 * thread_k_blocks / (is_a_8bit ? 16 : 8); // delta between subsequent A tiles in global memory - constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; + constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / (is_a_8bit ? 16 : 8); // between subsequent accesses within a tile int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); // between shared memory writes constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); - // between shared memory tile reads - constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); // within a shared memory tile constexpr int a_sh_rd_delta_i = a_sh_stride * 16; // overall size of a tile @@ -634,24 +809,25 @@ __global__ void Marlin( constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta); // B sizes/strides - int b_gl_stride = 16 * prob_n / (pack_factor * 4); - constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; - constexpr int b_thread_vecs = w_type.size_bits() == 4 ? 1 : 2; + int b_gl_stride = 16 * prob_n / (pack_factor * (is_a_8bit ? 2 : 4)); + constexpr int b_sh_stride = + ((thread_n_blocks * 16) * 16 / pack_factor) / (is_a_8bit ? 2 : 4); + constexpr int b_thread_vecs = b_type.size_bits() == 4 ? 
1 : 2; constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; - int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; - int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); + int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks / (is_a_8bit ? 2 : 1); constexpr int b_sh_wr_delta = threads * b_thread_vecs; - constexpr int b_sh_rd_delta = threads * b_thread_vecs; - constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; + constexpr int b_sh_stage = + b_sh_stride * thread_k_blocks / (is_a_8bit ? 2 : 1); constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; // Scale sizes/strides without act_order - int s_gl_stride = prob_n / 8; - constexpr int s_sh_stride = 16 * thread_n_blocks / 8; + int s_gl_stride = prob_n / (b_type == vllm::kFE2M1f ? 16 : 8); + constexpr int s_sh_stride = + 16 * thread_n_blocks / (b_type == vllm::kFE2M1f ? 16 : 8); constexpr int s_tb_groups = !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks - ? thread_k_blocks / group_blocks / (w_type == vllm::kFE2M1f ? 2 : 1) + ? thread_k_blocks / group_blocks : 1; constexpr int s_sh_stage = s_tb_groups * s_sh_stride; int s_gl_rd_delta = s_gl_stride; @@ -664,7 +840,8 @@ __global__ void Marlin( constexpr int act_s_max_num_groups = 32; int act_s_col_stride = 1; int act_s_col_warp_stride = act_s_col_stride * 8; - int tb_n_warps = thread_n_blocks / 4; + + constexpr int tb_n_warps = thread_n_blocks / (is_a_8bit ? 2 : 4); int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; // Zero-points sizes/strides @@ -679,7 +856,6 @@ __global__ void Marlin( // Global A read index of current thread. int a_gl_rd_row = threadIdx.x / a_gl_rd_delta_o; int a_gl_rd_col = a_gl_rd_delta_o * slice_row + threadIdx.x % a_gl_rd_delta_o; - // Shared write index of current thread. int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); @@ -687,17 +863,22 @@ __global__ void Marlin( int a_sh_rd = a_sh_stride * ((threadIdx.x % 32) % (16 / (m_block_size_8 ? 2 : 1))) + (threadIdx.x % 32) / (16 / (m_block_size_8 ? 2 : 1)); - a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); + a_sh_rd += 2 * ((threadIdx.x / 32) / tb_n_warps) * b_sh_wr_iters; - int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + - (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; - b_gl_rd += b_sh_stride * slice_col; + int b_gl_rd; + if (threads <= b_sh_stride) { + b_gl_rd = threadIdx.x; + } else { + b_gl_rd = + b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride); + } + + b_gl_rd += B_expert_off + b_sh_stride * slice_col; b_gl_rd += b_gl_rd_delta_o * slice_row; - auto b_sh_wr = threadIdx.x * b_thread_vecs; auto b_sh_rd = threadIdx.x * b_thread_vecs; + b_sh_rd += b_sh_rd / b_sh_stride * (b_sh_stride * (b_sh_wr_iters - 1)); // For act_order - constexpr int k_iter_size = tb_k / b_sh_wr_iters; int slice_k_start = tb_k * slice_row; int slice_k_finish = slice_k_start + tb_k * slice_iters; int slice_k_start_shared_fetch = slice_k_start; @@ -708,58 +889,54 @@ __global__ void Marlin( if constexpr (!has_act_order) { if constexpr (group_blocks == -1) { s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - } else { - s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) / - (w_type == vllm::kFE2M1f ? 
2 : 1) + + } else if constexpr (group_blocks >= thread_k_blocks) { + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + s_sh_stride * slice_col + threadIdx.x; + } else { + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks + + threadIdx.x / s_sh_stride) + + s_sh_stride * slice_col + threadIdx.x % s_sh_stride; } } auto s_sh_wr = threadIdx.x; - bool s_sh_wr_pred = threadIdx.x < s_sh_stride; + bool s_sh_wr_pred = threadIdx.x < s_sh_stage; // Zero-points int zp_gl_rd; if constexpr (has_zp) { if constexpr (group_blocks == -1) { zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; - } else { + } else if constexpr (group_blocks >= thread_k_blocks) { zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + zp_sh_stride * slice_col + threadIdx.x; + } else { + zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks + + threadIdx.x / zp_sh_stride) + + zp_sh_stride * slice_col + threadIdx.x % zp_sh_stride; } } auto zp_sh_wr = threadIdx.x; - bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride; + bool zp_sh_wr_pred = zp_sh_stage > 0 && threadIdx.x < zp_sh_stage; // We use a different scale layout for grouped and column-wise quantization as // we scale a `half2` tile in column-major layout in the former and in // row-major in the latter case. int s_sh_rd; - if constexpr (group_blocks != -1 && w_type == vllm::kFE2M1f) { - auto warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; - int warp_row = warp_id / n_warps; - - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 4; - s_sh_rd = s_sh_rd * 2 + (warp_row / group_blocks) % 2; - + if constexpr (is_a_8bit) { + s_sh_rd = 4 * ((threadIdx.x / 32) % tb_n_warps) + (threadIdx.x % 4); } else if constexpr (group_blocks != -1) - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 4; + s_sh_rd = 8 * ((threadIdx.x / 32) % tb_n_warps) + (threadIdx.x % 32) / 4; else if constexpr (group_blocks == -1 && (m_block_size_8 || (has_zp && !dequant_skip_flop))) - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 8; + s_sh_rd = 8 * ((threadIdx.x / 32) % tb_n_warps) + (threadIdx.x % 32) / 8; else - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) % 4; + s_sh_rd = 8 * ((threadIdx.x / 32) % tb_n_warps) + (threadIdx.x % 32) % 4; int bias_sh_rd; if constexpr (m_block_size_8) { - bias_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 8; + bias_sh_rd = 8 * ((threadIdx.x / 32) % tb_n_warps) + (threadIdx.x % 32) / 8; } else { - bias_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + bias_sh_rd = (is_a_8bit ? 
4 : 8) * ((threadIdx.x / 32) % tb_n_warps) + (threadIdx.x % 32) % 4; } @@ -775,12 +952,16 @@ __global__ void Marlin( if constexpr (has_zp) { if constexpr (is_zp_float) { if constexpr (group_blocks != -1) { - zp_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 4; + zp_sh_rd = + 8 * ((threadIdx.x / 32) % tb_n_warps) + (threadIdx.x % 32) / 4; } + } else if (is_a_8bit) { + zp_sh_rd = num_ints_per_thread * num_col_threads * + ((threadIdx.x / 32) % tb_n_warps / 2) + + num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads); } else { zp_sh_rd = num_ints_per_thread * num_col_threads * - ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + ((threadIdx.x / 32) % tb_n_warps) + num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads); } } @@ -807,18 +988,13 @@ __global__ void Marlin( for (int i = 0; i < b_sh_wr_iters; i++) { #pragma unroll for (int j = 0; j < thread_m_blocks; j++) - a_sh_rd_trans[i][j] = - transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); + a_sh_rd_trans[i][j] = transform_a(2 * i + a_sh_rd_delta_i * j + a_sh_rd); } // Since B-accesses have non-constant stride they have to be computed at // runtime; we break dependencies between subsequent accesses with a tile by // maintining multiple pointers (we have enough registers), a tiny // optimization. - const int4* B_ptr[b_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; // Shared memory storage for global fetch pipelines. constexpr int sh_red_size = (2 * thread_n_blocks + 1) * 16 * thread_m_blocks; @@ -847,19 +1023,12 @@ __global__ void Marlin( static_assert(thread_m_blocks * 16 * thread_n_blocks * 16 / 8 <= stages * b_sh_stage); int4* sh_a = sh_s + sh_s_size; - constexpr int shm_size_used = moe_block_size + - stages * (g_idx_stage + zp_sh_stage) + - sh_s_size + sh_b_red_bias_size; - - // all remaining shared memory is used to cache A (input) - // sh_a_max_row is at least ` stages * 16 * thread_m_blocks ` - int sh_a_max_row = - ((max_shared_mem - 1024) / 16 - shm_size_used) / (thread_k_blocks * 2); // Register storage for double buffer of shared memory reads. FragA frag_a[2][thread_m_blocks]; I4 frag_b_quant[2][b_thread_vecs]; - FragC frag_c[thread_m_blocks][4][2]; + FragC frag_c[thread_m_blocks][is_a_8bit ? 2 : 4][2]; + FragC frag_c_tmp[thread_m_blocks][is_a_8bit ? 2 : 4][2]; FragS frag_s[2][4]; // No act-order FragS frag_bias[2][4]; FragS act_frag_s[2][4][4]; // For act-order @@ -867,6 +1036,24 @@ __global__ void Marlin( FragZP frag_zp; // Zero-points in fp16 FragZP frag_zpf[2]; // Zero-points in fp16 in HQQ + if constexpr (is_a_8bit && group_blocks != -1) { + #pragma unroll + for (int j = 0; j < 2; j++) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int g = 0; g < 4; g++) { + frag_c_tmp[i][j][0][g] = 0.0f; + } + + #pragma unroll + for (int g = 0; g < 4; g++) { + frag_c_tmp[i][j][1][g] = 0.0f; + } + } + } + } + // Zero accumulators. auto zero_accums = [&]() { #pragma unroll @@ -910,43 +1097,36 @@ __global__ void Marlin( } } }; - // Asynchronously fetch the next A, B and s tile from global to the next // shared memory pipeline location. 
- bool should_load_a = true; - int max_num_stage_groups = - ((sh_a_max_row - moe_block_size) / moe_block_size + 1) / stages; - max_num_stage_groups = max(max_num_stage_groups, 1); - auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true, - int pipe_a = 0) { + auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { if (pred) { - if (should_load_a) { - int4* sh_a_stage = sh_a + moe_block_size * a_sh_stride * pipe_a; + int4* sh_a_stage = sh_a + moe_block_size * a_sh_stride * pipe; #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) { - int row = a_gl_rd_delta_i / a_gl_stride * i + a_gl_rd_row; - int64_t sorted_row = 0; - if (!m_block_size_8 || row < 8) - sorted_row = sh_rd_block_sorted_ids[row]; - int64_t true_idx = - sorted_row * a_gl_stride + a_gl_rd_col + a_gl_rd_delta_o * a_off; - cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]], &A[true_idx], - row < block_num_valid_tokens); - } + for (int i = 0; i < a_sh_wr_iters; i++) { + int row = a_gl_rd_delta_i / a_gl_stride * i + a_gl_rd_row; + int64_t sorted_row = 0; + if (!m_block_size_8 || row < 8) + sorted_row = sh_rd_block_sorted_ids[row]; + int64_t true_idx = + sorted_row * a_gl_stride + a_gl_rd_col + a_gl_rd_delta_o * a_off; + cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]], &A[true_idx], + row < block_num_valid_tokens); } int4* sh_b_stage = sh_b + b_sh_stage * pipe; #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll - for (int j = 0; j < b_thread_vecs; j++) { - cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], - B_ptr[i] + j + B_expert_off); - } + for (int i = 0; i < (b_sh_wr_iters * b_thread_vecs); i++) { + constexpr int count = div_ceil(b_sh_stride, threads); + int b_gl_idx = + b_gl_rd + (i % count) * threads + + b_gl_stride * (i / count) * div_ceil(threads, b_sh_stride); - B_ptr[i] += b_gl_rd_delta_o; + cp_async4(&sh_b_stage[threads * i + threadIdx.x], &B[b_gl_idx]); } + b_gl_rd += b_gl_rd_delta_o; + if constexpr (has_act_order) { // Fetch g_idx thread-block portion int full_pipe = a_off; @@ -966,44 +1146,24 @@ __global__ void Marlin( if constexpr (group_blocks != -1) { int4* sh_s_stage = sh_s + s_sh_stage * pipe; - if constexpr (group_blocks >= thread_k_blocks) { - // Only fetch scales if this tile starts a new group - if (pipe % (group_blocks / thread_k_blocks) == 0) { - if (s_sh_wr_pred) { - cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); - } - s_gl_rd += s_gl_rd_delta; - } - } else { - for (int i = 0; i < s_tb_groups; i++) { - if (s_sh_wr_pred) { - cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], - &scales_ptr[s_gl_rd]); - } - s_gl_rd += s_gl_rd_delta; + // Only fetch scales if this tile starts a new group + if (pipe % div_ceil(group_blocks, thread_k_blocks) == 0) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); } + s_gl_rd += s_gl_rd_delta * s_tb_groups; } } if constexpr (has_zp && group_blocks != -1) { int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; - if constexpr (group_blocks >= thread_k_blocks) { - // Only fetch zero-points if this tile starts a new group - if (pipe % (group_blocks / thread_k_blocks) == 0) { - if (zp_sh_wr_pred) { - cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]); - } - zp_gl_rd += zp_gl_rd_delta; - } - } else { - for (int i = 0; i < zp_tb_groups; i++) { - if (zp_sh_wr_pred) { - cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr], - &zp_ptr[zp_gl_rd]); - } - zp_gl_rd += zp_gl_rd_delta; + // Only fetch zero points if this tile starts a new group + if (pipe % div_ceil(group_blocks, thread_k_blocks) == 0) { + if 
(zp_sh_wr_pred) { + cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]); } + zp_gl_rd += zp_gl_rd_delta * zp_tb_groups; } } } @@ -1037,18 +1197,18 @@ __global__ void Marlin( // Load the next sub-tile from the current location in the shared memory pipe // into the current register buffer. - auto fetch_to_registers = [&](int k, int pipe, int pipe_a = 0) { - int4* sh_a_stage = sh_a + moe_block_size * a_sh_stride * pipe_a; + auto fetch_to_registers = [&](int k, int pipe) { + int4* sh_a_stage = sh_a + moe_block_size * a_sh_stride * pipe; #pragma unroll for (int i = 0; i < thread_m_blocks; i++) - ldsm( + ldsm( frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); int4* sh_b_stage = sh_b + b_sh_stage * pipe; #pragma unroll for (int i = 0; i < b_thread_vecs; i++) { frag_b_quant[k % 2][i] = *reinterpret_cast( - &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); + &sh_b_stage[b_sh_stride * (k % b_sh_wr_iters) + b_sh_rd + i]); } }; @@ -1072,53 +1232,54 @@ __global__ void Marlin( auto fetch_scales_to_registers = [&](int k, int full_pipe) { int pipe = full_pipe % stages; + using IT1 = typename std::conditional_t; + using IT0 = typename std::conditional_t; + constexpr int group_blocks2 = div_ceil(group_blocks, is_a_8bit ? 2 : 1); if constexpr (!has_act_order) { // No act-order case if constexpr (group_blocks == -1) { // load only when starting a new slice - if (k == 0 && full_pipe == 0) { + if (k == 0 && full_pipe == 0 && dequant_skip_flop) { reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd]; reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; } } else if constexpr (group_blocks != -1) { if constexpr (group_blocks >= thread_k_blocks) { - if (k % b_sh_wr_iters == 0) { - int4* sh_s_stage = - sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); - reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; - } else { - reinterpret_cast(&frag_s[1])[0] = - reinterpret_cast(&frag_s[0])[0]; + constexpr int g = group_blocks / thread_k_blocks; + if (pipe % g == 0) { + if (k % b_sh_wr_iters == 0) { + int4* sh_s_stage = sh_s + s_sh_stage * (g * (pipe / g)); + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + } else { + reinterpret_cast(&frag_s[1])[0] = + reinterpret_cast(&frag_s[0])[0]; + } } - } else { + } else if (group_blocks2 < b_sh_wr_iters || k % b_sh_wr_iters == 0) { auto warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; + int warp_row = warp_id / tb_n_warps; - int warp_row = warp_id / n_warps; - - int cur_k = warp_row * 16; - cur_k += k_iter_size * (k % b_sh_wr_iters); - - int k_blocks = cur_k / 16; - int cur_group_id = - k_blocks / (group_blocks * (w_type == vllm::kFE2M1f ? 
2 : 1)); + int k_blocks = b_sh_wr_iters * warp_row + k % b_sh_wr_iters; + int cur_group_id = k_blocks / group_blocks2; int4* sh_s_stage = sh_s + s_sh_stage * pipe; - if constexpr (w_type_id != vllm::kFE2M1f.id()) { + if constexpr (b_type_id != vllm::kFE2M1f.id()) { reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; - } else if constexpr (group_blocks == 1 || thread_k_blocks > 4) { - reinterpret_cast(&frag_s[k % 2])[0] = - reinterpret_cast( - sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)]; } else { reinterpret_cast(&frag_s[k % 2])[0] = reinterpret_cast( - sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride) + - k % 2]; + sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)]; + } + } else if (group_blocks >= b_sh_wr_iters) { + if constexpr (b_type_id != vllm::kFE2M1f.id()) { + reinterpret_cast(&frag_s[1])[0] = + reinterpret_cast(&frag_s[0])[0]; + } else { + reinterpret_cast(&frag_s[1])[0] = + reinterpret_cast(&frag_s[0])[0]; } } } @@ -1139,18 +1300,15 @@ __global__ void Marlin( cur_k = 0; // Progress to current iteration - cur_k += k_iter_size * (k % b_sh_wr_iters); + cur_k += k % b_sh_wr_iters; // Determine "position" inside the thread-block (based on warp and // thread-id) auto warp_id = threadIdx.x / 32; - int n_warps = - thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N + int warp_row = warp_id / tb_n_warps; + int warp_col = warp_id % tb_n_warps; - int warp_row = warp_id / n_warps; - int warp_col = warp_id % n_warps; - - cur_k += warp_row * 16; + cur_k += warp_row * 16 * b_sh_wr_iters; auto th_id = threadIdx.x % 32; cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix @@ -1205,18 +1363,16 @@ __global__ void Marlin( if constexpr (group_blocks == -1) { // load only when starting a new slice - if (k == 0 && full_pipe == 0) { + if (k == 0 && full_pipe == 0 || is_a_8bit) { #pragma unroll for (int i = 0; i < num_ints_per_thread; i++) { frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp))[zp_sh_rd + i]; } } - } else if constexpr (group_blocks >= thread_k_blocks) { - if (k % b_sh_wr_iters == 0) { - int4* sh_zp_stage = - sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); + constexpr int g = group_blocks / thread_k_blocks; + if (pipe % g == 0 && k % b_sh_wr_iters == 0 || is_a_8bit) { + int4* sh_zp_stage = sh_zp + zp_sh_stage * (g * (pipe / g)); #pragma unroll for (int i = 0; i < num_ints_per_thread; i++) { frag_qzp[k % 2][i] = @@ -1225,21 +1381,11 @@ __global__ void Marlin( } } else { auto warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; - int warp_row = warp_id / n_warps; + int warp_row = warp_id / tb_n_warps; - int cur_k = warp_row * 16; - cur_k += k_iter_size * (k % b_sh_wr_iters); - - int k_blocks = cur_k / 16; - int cur_group_id = 0; - - // Suppress bogus and persistent divide-by-zero warning - #pragma nv_diagnostic push - #pragma nv_diag_suppress divide_by_zero - cur_group_id = k_blocks / group_blocks; - #pragma nv_diagnostic pop + int k_blocks = b_sh_wr_iters * warp_row + k % b_sh_wr_iters; + int cur_group_id = k_blocks / div_ceil(group_blocks, is_a_8bit ? 
2 : 1); int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; @@ -1258,29 +1404,18 @@ __global__ void Marlin( if constexpr (group_blocks != -1) { if constexpr (group_blocks >= thread_k_blocks) { - if (k % b_sh_wr_iters == 0) { - int4* sh_zp_stage = - sh_zp + - zp_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); + constexpr int g = group_blocks / thread_k_blocks; + if (pipe % g == 0 && k % b_sh_wr_iters == 0) { + int4* sh_zp_stage = sh_zp + zp_sh_stage * (g * (pipe / g)); reinterpret_cast(&frag_zpf[k % 2])[0] = sh_zp_stage[zp_sh_rd]; } - } else { + } else if (group_blocks < b_sh_wr_iters || k % b_sh_wr_iters == 0) { auto warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; - int warp_row = warp_id / n_warps; - - int cur_k = warp_row * 16; - cur_k += k_iter_size * (k % b_sh_wr_iters); - - int k_blocks = cur_k / 16; - // Suppress bogus and persistent divide-by-zero warning - #pragma nv_diagnostic push - #pragma nv_diag_suppress divide_by_zero + int warp_row = warp_id / tb_n_warps; + int k_blocks = b_sh_wr_iters * warp_row + k % b_sh_wr_iters; int cur_group_id = k_blocks / group_blocks; - #pragma nv_diagnostic pop int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; @@ -1291,33 +1426,46 @@ __global__ void Marlin( } }; - auto dequant_data = [&](int q, scalar_t2* frag_b_ptr) { - dequant(q, frag_b_ptr); + auto dequant_data = [&](int q, scalar_32bit_t* frag_b_ptr, int zp = 0) { + if constexpr (a_type.size_bits() != b_type.size_bits()) { + if constexpr (is_a_8bit && has_zp) { + sub_zp_and_dequant( + q, frag_b_ptr, zp); + } else { + dequant(q, frag_b_ptr); + } + } }; // Execute the actual tensor core matmul of a sub-tile. bool is_first_matmul_in_slice = true; - auto matmul = [&](int k) { + auto matmul = [&](int k, int pipe) { + if (is_a_8bit) return; int k2 = k % 2; + constexpr int g = + group_blocks > 0 ? 
div_ceil(group_blocks, thread_k_blocks) : 1; const bool is_new_zp = - ((group_blocks != -1) && (group_blocks < thread_k_blocks || k == 0)) || + (group_blocks == 0) || + ((group_blocks > 0) && (group_blocks < b_sh_wr_iters || k == 0)) && + (pipe % g == 0) || (group_blocks == -1 && is_first_matmul_in_slice); if constexpr (has_zp && !is_zp_float) { if (is_new_zp) { if constexpr (group_blocks == -1) is_first_matmul_in_slice = false; int zp_quant_0, zp_quant_1; - if constexpr (w_type.size_bits() == 4) { + if constexpr (b_type.size_bits() == 4) { zp_quant_0 = frag_qzp[k2][0]; zp_quant_1 = zp_quant_0 >> 8; } else { - static_assert(w_type.size_bits() == 8); + static_assert(b_type.size_bits() == 8); zp_quant_0 = frag_qzp[k2][0]; zp_quant_1 = frag_qzp[k2][1]; } - dequant_data(zp_quant_0, reinterpret_cast(&frag_zp)); - dequant_data(zp_quant_1, reinterpret_cast(&frag_zp) + 2); + dequant_data(zp_quant_0, reinterpret_cast(&frag_zp)); + dequant_data(zp_quant_1, + reinterpret_cast(&frag_zp) + 2); } } if constexpr (!dequant_skip_flop && has_zp && is_zp_float) { @@ -1327,14 +1475,14 @@ __global__ void Marlin( } } - if constexpr (w_type == vllm::kFE2M1f) { + if constexpr (b_type == vllm::kFE2M1f) { int s_quant_0 = reinterpret_cast(frag_s[k2])[0]; int s_quant_1 = reinterpret_cast(frag_s[k2])[1]; - dequant_fp8_scales( - s_quant_0, reinterpret_cast(&frag_s[k2])); - dequant_fp8_scales( - s_quant_1, reinterpret_cast(&frag_s[k2]) + 2); + dequant_fp8_scales( + s_quant_0, reinterpret_cast(&frag_s[k2])); + dequant_fp8_scales( + s_quant_1, reinterpret_cast(&frag_s[k2]) + 2); } // We have the m dimension as the inner loop in order to encourage overlapping @@ -1345,61 +1493,168 @@ __global__ void Marlin( FragB frag_b1; int b_quant_0, b_quant_1; - if constexpr (w_type_id == vllm::kFE2M1f.id()) { + if constexpr (b_type_id == vllm::kFE2M1f.id()) { b_quant_1 = frag_b_quant[k2][0][j]; b_quant_0 = b_quant_1 << 8; - } else if constexpr (w_type.size_bits() == 4) { + } else if constexpr (b_type.size_bits() == 4) { b_quant_0 = frag_b_quant[k2][0][j]; b_quant_1 = b_quant_0 >> 8; } else { - static_assert(w_type.size_bits() == 8); + static_assert(b_type.size_bits() == 8); int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k2]); b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; } - dequant_data(b_quant_0, reinterpret_cast(&frag_b0)); - dequant_data(b_quant_1, reinterpret_cast(&frag_b1)); + dequant_data(b_quant_0, reinterpret_cast(&frag_b0)); + dequant_data(b_quant_1, reinterpret_cast(&frag_b1)); - if constexpr (dequant_skip_flop && has_zp && !is_zp_float) { - sub_zp(frag_b0, frag_zp[j], 0); - sub_zp(frag_b1, frag_zp[j], 1); + if constexpr (dequant_skip_flop && has_zp && !is_zp_float && !is_a_8bit) { + sub_zp(frag_b0, frag_zp[j], 0); + sub_zp(frag_b1, frag_zp[j], 1); } // Apply scale to frag_b0 - if constexpr (has_act_order) { + if constexpr (has_act_order && !is_a_8bit) { static_assert(group_blocks != -1); - scale4(frag_b0, act_frag_s[k2][0][j], act_frag_s[k2][1][j], - act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0); - scale4(frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j], - act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1); + scale4(frag_b0, act_frag_s[k2][0][j], act_frag_s[k2][1][j], + act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0); + scale4(frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j], + act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1); } else if constexpr (!dequant_skip_flop && has_zp && !is_zp_float && - group_blocks == -1) { + group_blocks == -1 && !is_a_8bit) { int idx = (threadIdx.x / 
4) % 2; - scalar_t2 s2 = Dtype::nums2num2( + scalar_t2 s2 = Adtype::nums2num2( reinterpret_cast(&frag_s[j / 2][j % 2 * 2 + 0])[idx], reinterpret_cast(&frag_s[j / 2][j % 2 * 2 + 1])[idx]); if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], s2); - scale_and_sub(frag_b0, s2.x, frag_zp[j].x); - scale_and_sub(frag_b1, s2.y, frag_zp[j].y); - } else if constexpr (!dequant_skip_flop && has_zp && group_blocks != -1) { + scale_and_sub(frag_b0, s2.x, frag_zp[j].x); + scale_and_sub(frag_b1, s2.y, frag_zp[j].y); + } else if constexpr (!dequant_skip_flop && has_zp && group_blocks != -1 && + !is_a_8bit) { if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], *reinterpret_cast(&frag_s[k2][j])); - scale_and_sub(frag_b0, frag_s[k2][j][0].x, frag_zp[j].x); - scale_and_sub(frag_b1, frag_s[k2][j][0].y, frag_zp[j].y); - } else if constexpr (group_blocks != -1) { - scale(frag_b0, frag_s[k2][j], 0); - scale(frag_b1, frag_s[k2][j], 1); + scale_and_sub(frag_b0, frag_s[k2][j][0].x, frag_zp[j].x); + scale_and_sub(frag_b1, frag_s[k2][j][0].y, frag_zp[j].y); + } else if constexpr (group_blocks != -1 && !is_a_8bit) { + scale(frag_b0, frag_s[k2][j], 0); + scale(frag_b1, frag_s[k2][j], 1); } #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { if constexpr (m_block_size_8) { - mma_trans(frag_a[k2][i], frag_b0, frag_b1, frag_c[i][j][0]); + mma_trans(frag_a[k2][i], frag_b0, frag_b1, + frag_c[i][j][0]); } else { - mma(frag_a[k2][i], frag_b0, frag_c[i][j][0]); - mma(frag_a[k2][i], frag_b1, frag_c[i][j][1]); + mma(frag_a[k2][i], frag_b0, frag_c[i][j][0]); + mma(frag_a[k2][i], frag_b1, frag_c[i][j][1]); + } + } + } + }; + + auto matmul_a8 = [&](int k) { + int k2 = k % 2; + #pragma unroll + for (int j = 0; j < 2; j++) { + FragB frag_b[2]; + + if (is_a_8bit && b_type.size_bits() == 4 && !has_zp) { + dequant_data(frag_b_quant[k2][0][j * 2], + reinterpret_cast(&frag_b)); + dequant_data(frag_b_quant[k2][0][j * 2 + 1], + reinterpret_cast(&frag_b) + 2); + } else if (is_a_8bit && b_type.size_bits() == 4 && has_zp) { + int off = (threadIdx.x / 32) % 2 * 2 + j; + int zp = (frag_qzp[k2][0] >> (off * 8)) & 0xF; + dequant_data(frag_b_quant[k2][0][j * 2], + reinterpret_cast(&frag_b), zp); + zp = (frag_qzp[k2][0] >> (off * 8 + 4)) & 0xF; + dequant_data(frag_b_quant[k2][0][j * 2 + 1], + reinterpret_cast(&frag_b) + 2, zp); + } else { + reinterpret_cast(&frag_b)[0] = + reinterpret_cast(&frag_b_quant[k2][j])[0]; + reinterpret_cast(&frag_b)[1] = + reinterpret_cast(&frag_b_quant[k2][j])[1]; + } + + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + mma(frag_a[k2][i], frag_b[0], + (group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][0]); + mma(frag_a[k2][i], frag_b[1], + (group_blocks == -1 ? 
frag_c : frag_c_tmp)[i][j][1]); + } + + if constexpr (group_blocks != -1) { + if (group_blocks == 2 || k == 1) { + if constexpr (a_type == vllm::kS8) { + int2 s_vals[2]; + s_vals[0] = { + (int)reinterpret_cast(&frag_s[k2][j * 2][0])[0], + (int)reinterpret_cast(&frag_s[k2][j * 2][0])[1]}; + s_vals[1] = { + (int)reinterpret_cast(&frag_s[k2][j * 2 + 1][0])[0], + (int)reinterpret_cast(&frag_s[k2][j * 2 + 1][0])[1]}; + + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int g = 0; g < 4; g++) { + int scale = reinterpret_cast(&s_vals[0])[g % 2]; + *reinterpret_cast(&frag_c[i][j][0][g]) += + *reinterpret_cast(&frag_c_tmp[i][j][0][g]) * + scale; + frag_c_tmp[i][j][0][g] = 0.0f; + } + + #pragma unroll + for (int g = 0; g < 4; g++) { + int scale = reinterpret_cast(&s_vals[1])[g % 2]; + *reinterpret_cast(&frag_c[i][j][1][g]) += + *reinterpret_cast(&frag_c_tmp[i][j][1][g]) * + scale; + frag_c_tmp[i][j][1][g] = 0.0f; + } + } + } else { + float2 s_vals[2]; + if constexpr (s_type_id != vllm::kFE8M0fnu.id()) { + static_assert(a_type.size_bits() == 16 || + s_type.size_bits() == 16); + s_vals[0] = Cdtype::num22float2(frag_s[k2][j * 2][0]); + s_vals[1] = Cdtype::num22float2(frag_s[k2][j * 2 + 1][0]); + } else { + int32_t* s_vals_int = reinterpret_cast(&s_vals[0]); + int32_t s_vals_e8m0 = + *reinterpret_cast(&frag_s[k2][j][0]); + + s_vals_int[0] = (s_vals_e8m0 & 0xFF) << 23; + s_vals_int[1] = (s_vals_e8m0 & 0xFF00) << 15; + s_vals_int[2] = (s_vals_e8m0 & 0xFF0000) << 7; + s_vals_int[3] = (s_vals_e8m0 & 0xFF000000) >> 1; + } + + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int g = 0; g < 4; g++) { + float scale = reinterpret_cast(&s_vals[0])[g % 2]; + frag_c[i][j][0][g] += frag_c_tmp[i][j][0][g] * scale; + frag_c_tmp[i][j][0][g] = 0.0f; + } + + #pragma unroll + for (int g = 0; g < 4; g++) { + float scale = reinterpret_cast(&s_vals[1])[g % 2]; + frag_c[i][j][1][g] += frag_c_tmp[i][j][1][g] * scale; + frag_c_tmp[i][j][1][g] = 0.0f; + } + } + } } } } @@ -1413,7 +1668,8 @@ __global__ void Marlin( constexpr int red_off = threads / b_sh_stride_threads / 2; if (red_off >= 1) { auto red_idx = threadIdx.x / b_sh_stride_threads; - constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; + constexpr int red_sh_stride = + b_sh_stride_threads * (is_a_8bit ? 2 : 4) * 2; constexpr int red_sh_delta = b_sh_stride_threads; int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + (threadIdx.x % b_sh_stride_threads); @@ -1428,7 +1684,8 @@ __global__ void Marlin( for (int i = red_off; i > 0; i /= 2) { if (i <= red_idx && red_idx < 2 * i) { #pragma unroll - for (int j = 0; j < 4 * 2; j += (m_block_size_8 ? 2 : 1)) { + for (int j = 0; j < (is_a_8bit ? 2 : 4) * 2; + j += (m_block_size_8 ? 2 : 1)) { int red_sh_wr = red_sh_delta * j + (red_sh_rd - red_sh_stride * i); if (i < red_off) { @@ -1437,24 +1694,26 @@ __global__ void Marlin( float* c_wr = reinterpret_cast(&sh_red[red_sh_wr]); #pragma unroll for (int k = 0; k < 4; k++) - reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += + reinterpret_cast( + frag_c)[(is_a_8bit ? 2 : 4) * 2 * m_block + j][k] += c_rd[k] + c_wr[k]; } - sh_red[red_sh_wr] = - reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + sh_red[red_sh_wr] = reinterpret_cast( + &frag_c)[(is_a_8bit ? 2 : 4) * 2 * m_block + j]; } } __syncthreads(); } if (red_idx == 0) { #pragma unroll - for (int i = 0; i < 4 * 2; i += (m_block_size_8 ? 2 : 1)) { + for (int i = 0; i < (is_a_8bit ? 2 : 4) * 2; + i += (m_block_size_8 ? 
2 : 1)) { float* c_rd = reinterpret_cast(&sh_red[red_sh_delta * i + red_sh_rd]); #pragma unroll for (int j = 0; j < 4; j++) - reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += - c_rd[j]; + reinterpret_cast( + frag_c)[(is_a_8bit ? 2 : 4) * 2 * m_block + i][j] += c_rd[j]; } } __syncthreads(); @@ -1470,13 +1729,13 @@ __global__ void Marlin( // We are very careful here to reduce directly in the output buffer to // maximize L2 cache utilization in this step. To do this, we write out // results in FP16 (but still reduce with FP32 compute). - constexpr int active_threads = 32 * thread_n_blocks / 4; + constexpr int active_threads = 32 * tb_n_warps; bool is_th_active = threadIdx.x < active_threads; if (!is_th_active) { return; } - int c_gl_stride = prob_n / 8; + int c_gl_stride = prob_n / 8 * (is_a_8bit ? 2 : 1); int c_gl_wr_delta_o = 8 * c_gl_stride; int c_gl_wr_delta_i = 4 * (active_threads / 32); int c_gl_wr; @@ -1487,7 +1746,7 @@ __global__ void Marlin( } else { c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + 4 * (threadIdx.x / 32) + threadIdx.x % 4; - c_gl_wr += (2 * thread_n_blocks) * slice_col; + c_gl_wr += (2 * thread_n_blocks) * slice_col * (is_a_8bit ? 2 : 1); } constexpr int c_sh_wr_delta = active_threads; int c_sh_wr = threadIdx.x; @@ -1506,7 +1765,13 @@ __global__ void Marlin( if (c_idx / c_gl_stride < block_num_valid_tokens) { int64_t sorted_row = sh_block_sorted_ids[c_idx / c_gl_stride]; int64_t true_idx = sorted_row * c_gl_stride + c_idx % c_gl_stride; - sh_red[c_sh_wr + c_sh_wr_delta * i] = C[true_idx]; + if constexpr (is_a_8bit) { + int2* sh_red_int2 = reinterpret_cast(sh_red); + int2* c_int2 = reinterpret_cast(C); + sh_red_int2[c_sh_wr + c_sh_wr_delta * i] = c_int2[true_idx]; + } else { + sh_red[c_sh_wr + c_sh_wr_delta * i] = C[true_idx]; + } } } } @@ -1514,29 +1779,37 @@ __global__ void Marlin( #pragma unroll for (int i = 0; i < (m_block_size_8 ? 2 : thread_m_blocks * 4); i++) { if (!first) { - int4 c_red = sh_red[c_sh_wr + i * c_sh_wr_delta]; + c_scalar_t* c_red_f16; + if constexpr (is_a_8bit) { + int2 tmp = + reinterpret_cast(sh_red)[c_sh_wr + i * c_sh_wr_delta]; + c_red_f16 = reinterpret_cast(&tmp); + } else { + int4 tmp = sh_red[c_sh_wr + i * c_sh_wr_delta]; + c_red_f16 = reinterpret_cast(&tmp); + } #pragma unroll - for (int j = 0; j < 2 * 4; j++) { + for (int j = 0; j < 2 * (is_a_8bit ? 2 : 4); j++) { int delta = 0; if constexpr (m_block_size_8) { delta = j % 2 == 1 ? -2 : 0; } reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4) + delta] += - Dtype::num2float(reinterpret_cast(&c_red)[j]); + &frag_c)[(is_a_8bit ? 2 : 4) * 2 * 4 * (i / 4) + 4 * j + (i % 4) + + delta] += Cdtype::num2float(c_red_f16[j]); } } if (!last) { - int4 c; + c_scalar_t c_f16[is_a_8bit ? 4 : 8]; #pragma unroll - for (int j = 0; j < 2 * 4; j++) { + for (int j = 0; j < 2 * (is_a_8bit ? 2 : 4); j++) { int delta = 0; if constexpr (m_block_size_8) { delta = j % 2 == 1 ? -2 : 0; } - reinterpret_cast(&c)[j] = - Dtype::float2num(reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4) + delta]); + c_f16[j] = Cdtype::float2num(reinterpret_cast( + &frag_c)[(is_a_8bit ? 
2 : 4) * 2 * 4 * (i / 4) + 4 * j + (i % 4) + + delta]); } int c_idx; @@ -1549,7 +1822,12 @@ __global__ void Marlin( if (c_idx / c_gl_stride < block_num_valid_tokens) { int64_t sorted_row = sh_block_sorted_ids[c_idx / c_gl_stride]; int64_t true_idx = sorted_row * c_gl_stride + c_idx % c_gl_stride; - C[true_idx] = c; + if constexpr (is_a_8bit) { + int2* c_int2 = reinterpret_cast(C); + c_int2[true_idx] = *reinterpret_cast(c_f16); + } else { + C[true_idx] = *reinterpret_cast(c_f16); + } } } } @@ -1563,10 +1841,10 @@ __global__ void Marlin( constexpr int c_size = tb_m * tb_n * sizeof(float) / 16; - constexpr int active_threads = 32 * thread_n_blocks / 4; + constexpr int active_threads = 32 * tb_n_warps; bool is_th_active = threadIdx.x < active_threads; - constexpr int num_floats = thread_m_blocks * 4 * 2 * 4; + constexpr int num_floats = thread_m_blocks * (is_a_8bit ? 2 : 4) * 2 * 4; constexpr int th_size = num_floats * sizeof(float) / 16; int c_cur_offset = locks_off * c_size; @@ -1634,7 +1912,7 @@ __global__ void Marlin( } else { c_sh_wr = (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; - c_sh_wr += 32 * (threadIdx.x / 32); + c_sh_wr += (is_a_8bit ? 16 : 32) * (threadIdx.x / 32); } int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + @@ -1643,49 +1921,49 @@ __global__ void Marlin( // We first reorder in shared memory to guarantee the most efficient final // global write patterns auto write = [&](int idx, float c0, float c1, FragS& s, FragS& b_bias) { - scalar_t2 res = - Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1)); + c_scalar_t2 res = + Cdtype::nums2num2(Cdtype::float2num(c0), Cdtype::float2num(c1)); // For per-column quantization we finally apply the scale here (only for // 4-bit) - if constexpr (!has_act_order && group_blocks == -1 && - w_type.size_bits() == 4 && + if constexpr (!has_act_order && group_blocks == -1 && !is_a_8bit && + b_type.size_bits() == 4 && (has_zp && dequant_skip_flop || !has_zp)) { - scalar_t2 tmp_scale = s[0]; + c_scalar_t2 tmp_scale = s[0]; if constexpr (m_block_size_8) { - tmp_scale = Dtype::num2num2( + tmp_scale = Cdtype::num2num2( reinterpret_cast(&s[0])[(threadIdx.x % 8) / 4]); } res = __hmul2(res, tmp_scale); } - if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) { + if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) { if (!mul_topk_weights) { res = __hmul2(res, global_scale); } } if (has_bias && last) { - scalar_t2 tmp_bias = b_bias[0]; + c_scalar_t2 tmp_bias = b_bias[0]; if constexpr (m_block_size_8) { - tmp_bias = Dtype::num2num2( + tmp_bias = Cdtype::num2num2( reinterpret_cast(&b_bias[0])[(threadIdx.x % 8) / 4]); } res = __hadd2(res, tmp_bias); } if constexpr (m_block_size_8) { - ((scalar_t*)sh_red)[idx] = res.x; - ((scalar_t*)sh_red)[idx + 8 * c_sh_stride] = res.y; + ((c_scalar_t*)sh_red)[idx] = res.x; + ((c_scalar_t*)sh_red)[idx + 8 * c_sh_stride] = res.y; } else { - ((scalar_t2*)sh_red)[idx] = res; + ((c_scalar_t2*)sh_red)[idx] = res; } }; - if (threadIdx.x / 32 < thread_n_blocks / 4) { + if (threadIdx.x / 32 < tb_n_warps) { #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { #pragma unroll - for (int j = 0; j < 4; j++) { + for (int j = 0; j < (is_a_8bit ? 
2 : 4); j++) { if constexpr (m_block_size_8) { int wr = c_sh_wr + 16 * j; write(wr, frag_c[i][j][0][0], frag_c[i][j][0][1], @@ -1723,24 +2001,26 @@ __global__ void Marlin( if (row < block_num_valid_tokens) { int64_t sorted_row = sh_block_sorted_ids[row]; int64_t true_idx = sorted_row * c_gl_stride + c_gl_wr % c_gl_stride; - scalar_t2 topk_weight_score; + c_scalar_t2 topk_weight_score; if (mul_topk_weights) topk_weight_score = sh_block_topk_weights[row]; if (use_atomic_add && slice_count > 1 || mul_topk_weights) { - scalar_t2* C_half2 = reinterpret_cast(&C[true_idx]); - scalar_t2* sh_red_half2 = - reinterpret_cast(&sh_red[c_sh_rd]); + c_scalar_t2* C_half2 = reinterpret_cast(&C[true_idx]); + c_scalar_t2* sh_red_half2 = + reinterpret_cast(&sh_red[c_sh_rd]); + if (mul_topk_weights) { #pragma unroll - for (int a = 0; a < 4; a++) { - scalar_t2 res = sh_red_half2[a]; - if (mul_topk_weights) { - res = __hmul2(res, topk_weight_score); + for (int a = 0; a < 4; a++) { + sh_red_half2[a] = __hmul2(sh_red_half2[a], topk_weight_score); } + } - if (use_atomic_add && slice_count > 1) { - atomicAdd(&C_half2[a], res); - } else { - C_half2[a] = res; - }; + if (use_atomic_add && slice_count > 1) { + #pragma unroll + for (int a = 0; a < 4; a++) { + atomicAdd(&C_half2[a], sh_red_half2[a]); + } + } else { + C[true_idx] = *reinterpret_cast(sh_red_half2); } } else { C[true_idx] = sh_red[c_sh_rd]; @@ -1774,7 +2054,7 @@ __global__ void Marlin( } } } - fetch_to_shared(i, i, i < slice_iters, i); + fetch_to_shared(i, i, i < slice_iters); } zero_accums(); @@ -1799,73 +2079,100 @@ __global__ void Marlin( // have even length meaning that the next iteration will always start at // index 0. - for (int stage_group_id = 0; stage_group_id < max_num_stage_groups; - stage_group_id++) { #pragma unroll - for (int pipe = 0; pipe < stages;) { + for (int pipe = 0; pipe < stages;) { #pragma unroll - for (int k = 0; k < b_sh_wr_iters; k++) { - int idx = - (pipe >= stages && stage_group_id == max_num_stage_groups - 1) - ? (pipe - stages) - : (pipe + stage_group_id * stages); - fetch_to_registers(k + 1, pipe % stages, idx); - fetch_scales_to_registers(k + 1, pipe); - fetch_zp_to_registers(k + 1, pipe); - if (k == b_sh_wr_iters - 2) { - int idx = (pipe >= 1 && stage_group_id == max_num_stage_groups - 1) - ? 
(pipe - 1) - : (pipe + (stage_group_id + 1) * stages - 1); - fetch_to_shared((pipe + stages - 1) % stages, pipe, - slice_iters >= stages, idx); - pipe++; - wait_for_stage(); - init_same_group(pipe % stages); - } - matmul(k); + for (int k = 0; k < b_sh_wr_iters; k++) { + fetch_to_registers(k + 1, pipe % stages); + fetch_scales_to_registers(k + 1, pipe); + fetch_zp_to_registers(k + 1, pipe); + if (k == b_sh_wr_iters - 2) { + fetch_to_shared((pipe + stages - 1) % stages, pipe, + slice_iters >= stages); + pipe++; + wait_for_stage(); + init_same_group(pipe % stages); } - slice_iters--; - if (slice_iters == 0) { - break; - } - } - - a_gl_rd_col += a_gl_rd_delta_o * stages; - - if constexpr (has_act_order) { - slice_k_start += tb_k * stages; - - if (slice_k_start < prob_k) { - slice_k_start_shared_fetch += tb_k * stages; - int first_group_id = g_idx[slice_k_start]; - int last_g_idx = slice_k_start + stages * tb_k * 2; - if (last_g_idx >= prob_k) { - last_g_idx = prob_k - 1; - } - int last_group_id = g_idx[last_g_idx]; - if (last_group_id >= sh_first_group_id + sh_num_groups) { - fetch_act_order_scales_to_shared(false, first_group_id, - last_group_id); - __syncthreads(); - } + + if constexpr (!is_a_8bit) { + matmul(k, pipe - (k >= b_sh_wr_iters - 2 ? 1 : 0)); + } else { + static_assert(group_blocks != 0 && group_blocks != 1); + matmul_a8(k); } } + slice_iters--; if (slice_iters == 0) { break; } } + a_gl_rd_col += a_gl_rd_delta_o * stages; + + if constexpr (has_act_order) { + slice_k_start += tb_k * stages; + + if (slice_k_start < prob_k) { + slice_k_start_shared_fetch += tb_k * stages; + int first_group_id = g_idx[slice_k_start]; + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + int last_group_id = g_idx[last_g_idx]; + if (last_group_id >= sh_first_group_id + sh_num_groups) { + fetch_act_order_scales_to_shared(false, first_group_id, + last_group_id); + __syncthreads(); + } + } + } + // Process results and, if necessary, proceed to the next column slice. // While this pattern may not be the most readable, other ways of writing // the loop seemed to noticeably worse performance after compilation. 
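
The 8-bit-activation path that follows (and the matmul_a8 lambda above) accumulates each k-group in the frag_c_tmp registers, folds them into the frag_c accumulator with that group's weight scale, and applies the per-token activation scale only once in the epilogue. A minimal CPU reference of that accumulation order is sketched below; the vector shapes, symmetric int8 quantization, and float group scales are illustrative assumptions, not the kernel's exact fragment layout.

#include <cstdint>
#include <cstdio>
#include <vector>

// CPU reference for the group-wise accumulate-then-rescale pattern:
// integer partial sums per k-group (role of frag_c_tmp), folded into a float
// accumulator with the group's weight scale (role of frag_c), with the
// per-token activation scale applied once at the end (role of frag_a_s / sh_a_s).
float w4a8_dot(const std::vector<int8_t>& a,       // quantized activations, length K
               const std::vector<int8_t>& w,       // dequantized-to-int8 weights, length K
               const std::vector<float>& w_scale,  // one weight scale per group
               float a_scale,                      // per-token activation scale
               int group_size) {
  const int K = static_cast<int>(a.size());
  float acc = 0.0f;       // plays the role of frag_c
  int32_t acc_group = 0;  // plays the role of frag_c_tmp
  for (int k = 0; k < K; ++k) {
    acc_group += static_cast<int32_t>(a[k]) * static_cast<int32_t>(w[k]);
    if ((k + 1) % group_size == 0) {  // group boundary: rescale and fold
      acc += static_cast<float>(acc_group) * w_scale[k / group_size];
      acc_group = 0;
    }
  }
  return acc * a_scale;  // final per-token activation scale, as in the epilogue
}

int main() {
  std::vector<int8_t> a(128, 2), w(128, 3);
  std::vector<float> ws(128 / 32, 0.5f);
  std::printf("%f\n", w4a8_dot(a, w, ws, 0.25f, 32));  // 128 * 6 * 0.5 * 0.25 = 96
  return 0;
}

Deferring the scale multiply to the group boundary keeps the inner loop on integer (or FP8) tensor-core accumulation, which is why frag_c_tmp is zeroed again after every fold in the kernel above.
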
if (slice_iters == 0) { + if constexpr (is_a_8bit) { + float frag_a_s[2 * thread_m_blocks]; + + for (int i = 0; i < 2 * thread_m_blocks; i++) + frag_a_s[i] = sh_a_s[i * 8 + (threadIdx.x % 32) / 4]; + + #pragma unroll + for (int j = 0; j < 2; j++) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int g = 0; g < 4; g++) { + float c_val = frag_c[i][j][0][g]; + + if constexpr (a_type == vllm::kS8) { + c_val = __int2float_rn(*reinterpret_cast(&c_val)); + } + float s_val = frag_a_s[i * 2 + g / 2]; + frag_c[i][j][0][g] = c_val * s_val; + } + #pragma unroll + for (int g = 0; g < 4; g++) { + float c_val = frag_c[i][j][1][g]; + + if constexpr (a_type == vllm::kS8) { + c_val = __int2float_rn(*reinterpret_cast(&c_val)); + } + float s_val = frag_a_s[i * 2 + g / 2]; + frag_c[i][j][1][g] = c_val * s_val; + } + } + } + } + cp_async_wait<0>(); bool last = slice_idx == slice_count - 1; // For per-column scales, we only fetch them here in the final step before // write-out if constexpr (!has_act_order && group_blocks == -1 && (has_zp && dequant_skip_flop || !has_zp)) { - if (w_type.size_bits() == 8 || (last || use_atomic_add)) { + if (b_type.size_bits() == 8 || (last || use_atomic_add) || is_a_8bit) { if (s_sh_wr_pred) { cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); } @@ -1883,20 +2190,27 @@ __global__ void Marlin( } if constexpr (!has_act_order && group_blocks == -1 && - (has_zp && dequant_skip_flop || !has_zp)) { - if (w_type.size_bits() == 8 || (last || use_atomic_add)) { + (has_zp && dequant_skip_flop || !has_zp || is_a_8bit)) { + if constexpr (is_a_8bit) { cp_async_wait<0>(); __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) { + if (threadIdx.x / 32 < tb_n_warps) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + } + } else if (b_type.size_bits() == 8 || (last || use_atomic_add)) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < tb_n_warps) { reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; if constexpr (m_block_size_8) { int idx = (threadIdx.x / 4) % 2; - scalar_t2* frag_s_half2 = reinterpret_cast(frag_s); + c_scalar_t2* frag_s_half2 = + reinterpret_cast(frag_s); #pragma unroll for (int i = 0; i < 8; i++) { - frag_s_half2[i] = Dtype::num2num2( - reinterpret_cast(&frag_s_half2[i])[idx]); + frag_s_half2[i] = Cdtype::num2num2( + reinterpret_cast(&frag_s_half2[i])[idx]); } } } @@ -1906,26 +2220,48 @@ __global__ void Marlin( // For 8-bit channelwise, we apply the scale before the global reduction // that converts the fp32 results to fp16 (so that we avoid possible // overflow in fp16) - if constexpr (!has_act_order && group_blocks == -1 && - w_type.size_bits() == 8 && - (has_zp && dequant_skip_flop || !has_zp)) { - if (threadIdx.x / 32 < thread_n_blocks / 4) { + if constexpr (!has_act_order && group_blocks == -1 && is_a_8bit) { + #pragma unroll + for (int j = 0; j < 2; j++) { + float2 aa[2]; + aa[0] = Cdtype::num22float2(frag_s[0][j * 2][0]); + aa[1] = Cdtype::num22float2(frag_s[0][j * 2 + 1][0]); + + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int g = 0; g < 4; g++) { + float scale = reinterpret_cast(&aa[0])[g % 2]; + frag_c[i][j][0][g] *= scale; + } + + #pragma unroll + for (int g = 0; g < 4; g++) { + float scale = reinterpret_cast(&aa[1])[g % 2]; + frag_c[i][j][1][g] *= scale; + } + } + } + } else if (!has_act_order && group_blocks == -1 && + b_type.size_bits() == 8 && + (has_zp && dequant_skip_flop || !has_zp)) { + if (threadIdx.x / 32 
< tb_n_warps) { #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { #pragma unroll for (int j = 0; j < 4; j++) { - scale_float( + scale_float( reinterpret_cast(&frag_c[i][j][0][0]), frag_s[j / 2][2 * (j % 2) + 0]); - scale_float( + scale_float( reinterpret_cast(&frag_c[i][j][0][2]), frag_s[j / 2][2 * (j % 2) + (m_block_size_8 ? 1 : 0)]); if constexpr (!m_block_size_8) { - scale_float( + scale_float( reinterpret_cast(&frag_c[i][j][1][0]), frag_s[j / 2][2 * (j % 2) + 1]); - scale_float( + scale_float( reinterpret_cast(&frag_c[i][j][1][2]), frag_s[j / 2][2 * (j % 2) + 1]); } @@ -1949,7 +2285,8 @@ __global__ void Marlin( cp_async_wait<0>(); __syncthreads(); reinterpret_cast(&frag_bias)[0] = sh_bias[bias_sh_rd]; - reinterpret_cast(&frag_bias)[1] = sh_bias[bias_sh_rd + 4]; + if constexpr (!is_a_8bit) + reinterpret_cast(&frag_bias)[1] = sh_bias[bias_sh_rd + 4]; __syncthreads(); } @@ -1958,37 +2295,22 @@ __global__ void Marlin( if (last || use_atomic_add) // only the last block in a slice actually writes the result write_result(last); - int old_slice_row = slice_row; slice_row = 0; - slice_col_par++; - slice_col++; + if (!in_part2) { + slice_col_par += gridDim.x; + } else { + slice_col_par++; + slice_col++; + } is_first_matmul_in_slice = true; init_slice(); - // Should we load A matrix in next slice? - // `slice_col == 0`: when move to a new moe block - // `old_slice_row > 0`: - // when the last slice is not starting from k_index == 0 - // (only happen when it is the first slice of a threadblock) - // `prob_k > thread_k_blocks * 16 * stages * max_num_stage_groups`: - // when the required shared memory size is larger than - // the remaining shared memory - if (slice_col == 0 || old_slice_row || - prob_k > thread_k_blocks * 16 * stages * max_num_stage_groups) { - should_load_a = true; - } else { - should_load_a = false; - } - if (slice_iters) { - a_gl_rd_col = (threadIdx.x % a_gl_rd_delta_o); - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; - if (slice_col == 0) { - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; - } + a_gl_rd_col = + a_gl_rd_delta_o * slice_row + threadIdx.x % a_gl_rd_delta_o; + b_gl_rd = B_expert_off + b_gl_stride * (threadIdx.x / b_sh_stride) + + (threadIdx.x % b_sh_stride); + b_gl_rd += b_sh_stride * slice_col + b_gl_rd_delta_o * slice_row; bias_gl_rd = (thread_n_blocks * 16 / 8) * slice_col + threadIdx.x; // Update slice k/n for scales loading @@ -1998,8 +2320,26 @@ __global__ void Marlin( slice_k_start_shared_fetch = slice_k_start; slice_n_offset = act_s_col_tb_stride * slice_col; } else { - s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; + if constexpr (group_blocks == -1) { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; + } else if constexpr (group_blocks >= thread_k_blocks) { + s_gl_rd = + s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + s_sh_stride * slice_col + threadIdx.x; + zp_gl_rd = + zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + zp_sh_stride * slice_col + threadIdx.x; + } else { + s_gl_rd = + s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks + + threadIdx.x / s_sh_stride) + + s_sh_stride * slice_col + threadIdx.x % s_sh_stride; + zp_gl_rd = + zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks + + threadIdx.x / zp_sh_stride) + + zp_sh_stride * slice_col + threadIdx.x % zp_sh_stride; + } } 
start_pipes(); } diff --git a/csrc/moe/marlin_moe_wna16/ops.cu b/csrc/moe/marlin_moe_wna16/ops.cu index 601e2aa6f9913..27b6ffaa67176 100644 --- a/csrc/moe/marlin_moe_wna16/ops.cu +++ b/csrc/moe/marlin_moe_wna16/ops.cu @@ -37,39 +37,6 @@ __global__ void MarlinDefault(MARLIN_KERNEL_PARAMS){}; using MarlinFuncPtr = void (*)(MARLIN_KERNEL_PARAMS); -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - -template -__global__ void permute_cols_kernel( - int4 const* __restrict__ a_int4_ptr, int const* __restrict__ perm_int_ptr, - int4* __restrict__ out_int4_ptr, - const int32_t* __restrict__ sorted_token_ids_ptr, - const int32_t* __restrict__ expert_ids_ptr, - const int32_t* __restrict__ num_tokens_past_padded_ptr, int size_m, - int size_k, int top_k) {}; - -} // namespace marlin - -torch::Tensor moe_wna16_marlin_gemm( - torch::Tensor& a, std::optional c_or_none, - torch::Tensor& b_q_weight, - std::optional const& b_bias_or_none, torch::Tensor& b_scales, - std::optional const& b_zeros_or_none, - std::optional const& g_idx_or_none, - std::optional const& perm_or_none, torch::Tensor& workspace, - torch::Tensor& sorted_token_ids, torch::Tensor& expert_ids, - torch::Tensor& num_tokens_past_padded, torch::Tensor& topk_weights, - int64_t moe_block_size, int64_t top_k, bool mul_topk_weights, bool is_ep, - vllm::ScalarTypeId const& b_q_type_id, int64_t size_m, int64_t size_n, - int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce, - bool is_zp_float) { - TORCH_CHECK_NOT_IMPLEMENTED(false, - "marlin_gemm(..) requires CUDA_ARCH >= 8.0"); - return torch::empty({1, 1}); -} - -#else - // For a given "a" of size [M,K] performs a permutation of the K columns based // on the given "perm" indices. template @@ -207,7 +174,7 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8, int thread_m_blocks, int prob_m, int prob_n, int prob_k, int num_bits, int group_size, bool has_act_order, bool is_k_full, int has_zp, - int is_zp_float) { + int is_zp_float, bool is_a_8bit) { int pack_factor = 32 / num_bits; // Get B size @@ -217,8 +184,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8, // shm size for block_sorted_ids/rd_block_sorted_ids/block_topk_weights // both of them requires tb_m * 4 bytes (tb_m * int32 or tb_m * float32) - int sh_block_meta_size = tb_m * 4; - int sh_a_size = pipe_stages * (tb_m * tb_k) * 2; + int sh_block_meta_size = tb_m * 16; + int sh_a_size = pipe_stages * (tb_m * tb_k) * (is_a_8bit ? 
1 : 2); int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4; int sh_red_size = tb_m * (tb_n + 8) * 2; int sh_bias_size = tb_n * 2; @@ -250,7 +217,7 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8, int thread_m_blocks, int prob_m, int prob_n, int prob_k, int num_bits, int group_size, bool has_act_order, bool is_k_full, int has_zp, int is_zp_float, - int max_shared_mem) { + int max_shared_mem, bool is_a_8bit) { // Sanity if (th_config.thread_k == -1 || th_config.thread_n == -1 || th_config.num_threads == -1) { @@ -273,188 +240,34 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8, } // Check that pipeline fits into cache - int cache_size = get_kernel_cache_size( - th_config, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k, - num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float); - return cache_size + 512 <= max_shared_mem; + int cache_size = + get_kernel_cache_size(th_config, m_block_size_8, thread_m_blocks, prob_m, + prob_n, prob_k, num_bits, group_size, has_act_order, + is_k_full, has_zp, is_zp_float, is_a_8bit); + return cache_size <= max_shared_mem; } - #define _GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ - M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \ - else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ - thread_n_blocks == THREAD_N_BLOCKS && \ - thread_k_blocks == THREAD_K_BLOCKS && \ - m_block_size_8 == M_BLOCK_SIZE_8 && \ - group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \ - is_zp_float == IS_ZP_FLOAT) { \ - constexpr auto S_TYPE = \ - W_TYPE == vllm::kFE2M1f \ - ? (GROUP_BLOCKS == 1 ? vllm::kFE4M3fn : vllm::kFE8M0fnu) \ - : (std::is_same::value ? vllm::kFloat16 \ - : vllm::kBFloat16); \ - kernel = Marlin; \ - } - - // COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false) - // this is the most common cases - // BIGGROUP: cases for big group size (group_blocks in [-1, 8]) - // FZP: cases for float-zero-point (is_zp_float = true) - // ACT: cases for act order case (group_blocks == 0) - // FP4: cases for nvfp4(e2m1) (group_blocks == 1) - #define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) - - #define COMMON_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ - \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ - \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - 
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) - - #define COMMON_GET_IF(W_TYPE) \ - COMMON_GET_IF_M1(W_TYPE, 8, 8, 256) \ - COMMON_GET_IF_M1(W_TYPE, 8, 4, 128) \ - COMMON_GET_IF_M234(W_TYPE, 16, 4, 256) \ - COMMON_GET_IF_M234(W_TYPE, 8, 4, 128) - - #define BIGGROUP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) - - #define BIGGROUP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) - - #define BIGGROUP_GET_IF(W_TYPE) \ - BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \ - BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \ - BIGGROUP_GET_IF_M234(W_TYPE, 16, 4, 256) \ - BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128) - - #define NVFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) - - #define NVFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) - - #define NVFP4_GET_IF(W_TYPE) \ - NVFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \ - NVFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \ - NVFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \ - NVFP4_GET_IF_M234(W_TYPE, 8, 4, 128) - - #define MXFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) - - #define MXFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) - - #define MXFP4_GET_IF(W_TYPE) \ - MXFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \ - MXFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \ - MXFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \ - MXFP4_GET_IF_M234(W_TYPE, 8, 4, 128) - - // We currently have 4-bit models only with group_blocks == 4 - #define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, true) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) - - #define FZP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) - - #define FZP_GET_IF(W_TYPE) \ - FZP_GET_IF_M1(W_TYPE, 8, 8, 256) \ - FZP_GET_IF_M1(W_TYPE, 8, 
4, 128) \ - FZP_GET_IF_M234(W_TYPE, 16, 4, 256) \ - FZP_GET_IF_M234(W_TYPE, 8, 4, 128) - - // We currently have 4-bit models only with group_blocks == 4 - #define ACT_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) - - #define ACT_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) - - #define ACT_GET_IF(W_TYPE) \ - ACT_GET_IF_M1(W_TYPE, 8, 8, 256) \ - ACT_GET_IF_M1(W_TYPE, 8, 4, 128) \ - ACT_GET_IF_M234(W_TYPE, 16, 4, 256) \ - ACT_GET_IF_M234(W_TYPE, 8, 4, 128) - -template -MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type, - int thread_m_blocks, int thread_n_blocks, - int thread_k_blocks, bool m_block_size_8, - bool has_act_order, bool has_zp, - int group_blocks, int num_threads, - bool is_zp_float) { - int num_bits = q_type.size_bits(); +MarlinFuncPtr get_marlin_kernel( + const vllm::ScalarType a_type, const vllm::ScalarType b_type, + const vllm::ScalarType c_type, const vllm::ScalarType s_type, + int thread_m_blocks, int thread_n_blocks, int thread_k_blocks, + bool m_block_size_8, bool has_act_order, bool has_zp, int group_blocks, + int threads, bool is_zp_float) { + int num_bits = b_type.size_bits(); auto kernel = MarlinDefault; - if (false) { - } - COMMON_GET_IF(vllm::kU4) - COMMON_GET_IF(vllm::kU4B8) - COMMON_GET_IF(vllm::kU8B128) - - NVFP4_GET_IF(vllm::kFE2M1f) - - BIGGROUP_GET_IF(vllm::kFE4M3fn) - - ACT_GET_IF(vllm::kU4B8) - ACT_GET_IF(vllm::kU8B128) - if (std::is_same::value) { - if (false) { - } - MXFP4_GET_IF(vllm::kFE2M1f) - } +#include "kernel_selector.h" return kernel; } -template -exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m, - int prob_n, int prob_k, int thread_m_blocks, - bool m_block_size_8, int num_bits, - int group_size, bool has_act_order, - bool is_k_full, bool has_zp, - bool is_zp_float, int max_shared_mem) { +exec_config_t determine_exec_config( + const vllm::ScalarType& a_type, const vllm::ScalarType& b_type, + const vllm::ScalarType& c_type, const vllm::ScalarType& s_type, int prob_m, + int prob_n, int prob_k, int num_experts, int top_k, int thread_m_blocks, + bool m_block_size_8, int num_bits, int group_size, bool has_act_order, + bool is_k_full, bool has_zp, bool is_zp_float, int max_shared_mem, int sms, + bool is_a_8bit) { exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}}; thread_config_t* thread_configs = thread_m_blocks > 1 ? large_batch_thread_configs @@ -471,73 +284,69 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m, if (!is_valid_config(th_config, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size, has_act_order, - is_k_full, has_zp, is_zp_float, max_shared_mem)) { + is_k_full, has_zp, is_zp_float, max_shared_mem - 512, + is_a_8bit)) { continue; } int cache_size = get_kernel_cache_size( th_config, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k, - num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float); + num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float, + is_a_8bit); int group_blocks = 0; if (!has_act_order) { group_blocks = group_size == -1 ? 
-1 : (group_size / 16); } - auto kernel = get_marlin_kernel( - q_type, thread_m_blocks, th_config.thread_n / 16, - th_config.thread_k / 16, m_block_size_8, has_act_order, has_zp, - group_blocks, th_config.num_threads, is_zp_float); + auto kernel = + get_marlin_kernel(a_type, b_type, c_type, s_type, thread_m_blocks, + th_config.thread_n / 16, th_config.thread_k / 16, + m_block_size_8, has_act_order, has_zp, group_blocks, + th_config.num_threads, is_zp_float); if (kernel == MarlinDefault) continue; - if (thread_m_blocks > 1) { - exec_cfg = {1, th_config}; - break; - } else { - cudaFuncAttributes attr; - cudaFuncGetAttributes(&attr, kernel); - int reg_size = max(attr.numRegs, 1) * th_config.num_threads * 4; - int allow_count = min(device_max_reg_size / reg_size, - max_shared_mem / (cache_size + 1024)); + cudaFuncAttributes attr; + cudaFuncGetAttributes(&attr, kernel); + int reg_size = max(attr.numRegs, 1) * th_config.num_threads * 4; + int allow_count = min(device_max_reg_size / reg_size, + max_shared_mem / (cache_size + 1536)); + if (thread_m_blocks == 1) allow_count = max(min(allow_count, 4), 1); - if (allow_count > count) { - count = allow_count; - exec_cfg = {count, th_config}; - }; + else + allow_count = max(min(allow_count, 2), 1); + + if (prob_n / th_config.thread_n * prob_m * top_k * 4 < sms * allow_count) { + allow_count = + max(prob_n / th_config.thread_n * prob_m * top_k * 4 / sms, 1); } + + if (allow_count > count) { + count = allow_count; + exec_cfg = {count, th_config}; + }; } return exec_cfg; } -template void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, - void* s, void* s2, void* zp, void* g_idx, void* perm, - void* a_tmp, void* sorted_token_ids, void* expert_ids, - void* num_tokens_past_padded, void* topk_weights, - int moe_block_size, int top_k, bool mul_topk_weights, bool is_ep, - int prob_m, int prob_n, int prob_k, void* workspace, - vllm::ScalarType const& q_type, bool has_bias, - bool has_act_order, bool is_k_full, bool has_zp, int num_groups, - int group_size, int dev, cudaStream_t stream, int thread_k, - int thread_n, int sms, bool use_atomic_add, bool use_fp32_reduce, - bool is_zp_float) { + void* a_s, void* b_s, void* g_s, void* zp, void* g_idx, + void* perm, void* a_tmp, void* sorted_token_ids, + void* expert_ids, void* num_tokens_past_padded, + void* topk_weights, int moe_block_size, int num_experts, + int top_k, bool mul_topk_weights, bool is_ep, int prob_m, + int prob_n, int prob_k, void* workspace, + vllm::ScalarType const& a_type, vllm::ScalarType const& b_type, + vllm::ScalarType const& c_type, vllm::ScalarType const& s_type, + bool has_bias, bool has_act_order, bool is_k_full, bool has_zp, + int num_groups, int group_size, int dev, cudaStream_t stream, + int thread_k, int thread_n, int sms, int blocks_per_sm, + bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) { int thread_m_blocks = div_ceil(moe_block_size, 16); bool m_block_size_8 = moe_block_size == 8; - - if (has_zp) { - TORCH_CHECK( - q_type == vllm::kU4 || q_type == vllm::kU8, - "q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str()); - } else { - TORCH_CHECK( - q_type == vllm::kU4B8 || q_type == vllm::kU8B128 || - q_type == vllm::kFE4M3fn || q_type == vllm::kFE2M1f, - "q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when " - "has_zp = False. 
Got = ", - q_type.str()); - } + bool is_a_8bit = a_type.size_bits() == 8; TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, ", ", prob_n, ", ", prob_k, "]"); @@ -563,14 +372,15 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, } } - int num_bits = q_type.size_bits(); + int num_bits = b_type.size_bits(); const int4* A_ptr = (const int4*)A; const int4* B_ptr = (const int4*)B; int4* C_ptr = (int4*)C; int4* C_tmp_ptr = (int4*)C_tmp; const int4* bias_ptr = (const int4*)b_bias; - const int4* s_ptr = (const int4*)s; - const uint16_t* s2_ptr = (const uint16_t*)s2; + const float* a_s_ptr = (const float*)a_s; + const int4* b_s_ptr = (const int4*)b_s; + const uint16_t* g_s_ptr = (const uint16_t*)g_s; const int4* zp_ptr = (const int4*)zp; const int* g_idx_ptr = (const int*)g_idx; const int* perm_ptr = (const int*)perm; @@ -618,22 +428,41 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); TORCH_CHECK(max_shared_mem > 0); + int major_capability, minor_capability; + cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor, + dev); + cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor, + dev); + TORCH_CHECK(major_capability * 10 + minor_capability >= 80, + "marlin kernel only support Ampere or newer GPUs."); + if (a_type == vllm::kFE4M3fn) { + TORCH_CHECK(major_capability * 10 + minor_capability >= 89, + "FP8 only support Ada Lovelace or newer GPUs."); + TORCH_CHECK( + major_capability * 10 + minor_capability == 89 || + major_capability * 10 + minor_capability == 120, + "Marlin W4A8-FP8 only support SM89 or SM120 device (It is slower than " + "Marlin W4A16 on other devices)."); + } + // Set thread config exec_config_t exec_cfg; thread_config_t thread_tfg; if (thread_k != -1 && thread_n != -1) { - thread_tfg = thread_config_t{thread_k, thread_n, default_threads}; - exec_cfg = exec_config_t{1, thread_tfg}; + thread_tfg = thread_config_t{thread_k, thread_n, thread_k * thread_n / 64}; + if (blocks_per_sm == -1) blocks_per_sm = 1; + exec_cfg = exec_config_t{blocks_per_sm, thread_tfg}; TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n, " is not divisible by thread_n = ", thread_n); TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, " is not divisible by thread_k = ", thread_k); } else { // Auto config - exec_cfg = determine_exec_config( - q_type, prob_m, prob_n, prob_k, thread_m_blocks, m_block_size_8, - num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float, - max_shared_mem); + exec_cfg = determine_exec_config( + a_type, b_type, c_type, s_type, prob_m, prob_n, prob_k, num_experts, + top_k, thread_m_blocks, m_block_size_8, num_bits, group_size, + has_act_order, is_k_full, has_zp, is_zp_float, max_shared_mem, sms, + is_a_8bit); thread_tfg = exec_cfg.tb_cfg; } @@ -647,22 +476,29 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, int thread_k_blocks = thread_k / 16; int thread_n_blocks = thread_n / 16; - TORCH_CHECK( - is_valid_config(thread_tfg, m_block_size_8, thread_m_blocks, prob_m, - prob_n, prob_k, num_bits, group_size, has_act_order, - is_k_full, has_zp, is_zp_float, max_shared_mem), - "Invalid thread config: thread_m_blocks = ", thread_m_blocks, - ", thread_k = ", thread_tfg.thread_k, - ", thread_n = ", thread_tfg.thread_n, - ", num_threads = ", thread_tfg.num_threads, " for MKN = [", prob_m, ", ", - prob_k, ", ", prob_n, "] and num_bits = ", num_bits, - ", 
group_size = ", group_size, ", has_act_order = ", has_act_order, - ", is_k_full = ", is_k_full, ", has_zp = ", has_zp, - ", is_zp_float = ", is_zp_float, ", max_shared_mem = ", max_shared_mem); + TORCH_CHECK(is_valid_config(thread_tfg, m_block_size_8, thread_m_blocks, + prob_m, prob_n, prob_k, num_bits, group_size, + has_act_order, is_k_full, has_zp, is_zp_float, + max_shared_mem, is_a_8bit), + "Invalid thread config: thread_m_blocks = ", thread_m_blocks, + ", thread_k = ", thread_tfg.thread_k, + ", thread_n = ", thread_tfg.thread_n, + ", num_threads = ", thread_tfg.num_threads, " for MKN = [", + prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits, + ", group_size = ", group_size, + ", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full, + ", has_zp = ", has_zp, ", is_zp_float = ", is_zp_float, + ", max_shared_mem = ", max_shared_mem); - auto kernel = get_marlin_kernel( - q_type, thread_m_blocks, thread_n_blocks, thread_k_blocks, m_block_size_8, - has_act_order, has_zp, group_blocks, num_threads, is_zp_float); + int sh_cache_size = + get_kernel_cache_size(thread_tfg, m_block_size_8, thread_m_blocks, prob_m, + prob_n, prob_k, num_bits, group_size, has_act_order, + is_k_full, has_zp, is_zp_float, is_a_8bit); + + auto kernel = get_marlin_kernel( + a_type, b_type, c_type, s_type, thread_m_blocks, thread_n_blocks, + thread_k_blocks, m_block_size_8, has_act_order, has_zp, group_blocks, + num_threads, is_zp_float); if (kernel == MarlinDefault) { TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n, @@ -679,19 +515,20 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, // avoid ">>>" being formatted to "> > >" // clang-format off kernel<<>>( - A_ptr, B_ptr, C_ptr, C_tmp_ptr, bias_ptr, s_ptr, s2_ptr, zp_ptr, g_idx_ptr, + A_ptr, B_ptr, C_ptr, C_tmp_ptr, bias_ptr, a_s_ptr, b_s_ptr, g_s_ptr, zp_ptr, g_idx_ptr, sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr, topk_weights_ptr, top_k, mul_topk_weights, is_ep, num_groups, prob_m, - prob_n, prob_k, locks, has_bias, use_atomic_add, use_fp32_reduce, max_shared_mem); + prob_n, prob_k, locks, has_bias, use_atomic_add, use_fp32_reduce); // clang-format on } } // namespace MARLIN_NAMESPACE_NAME torch::Tensor moe_wna16_marlin_gemm( - torch::Tensor& a, std::optional const& c_or_none, + torch::Tensor& a, std::optional c_or_none, torch::Tensor& b_q_weight, std::optional const& b_bias_or_none, torch::Tensor& b_scales, + std::optional const& a_scales_or_none, std::optional const& global_scale_or_none, std::optional const& b_zeros_or_none, std::optional const& g_idx_or_none, @@ -699,11 +536,70 @@ torch::Tensor moe_wna16_marlin_gemm( torch::Tensor& sorted_token_ids, torch::Tensor& expert_ids, torch::Tensor& num_tokens_past_padded, torch::Tensor& topk_weights, int64_t moe_block_size, int64_t top_k, bool mul_topk_weights, bool is_ep, - vllm::ScalarTypeId const& b_q_type_id, int64_t size_m, int64_t size_n, + vllm::ScalarTypeId const& b_type_id, int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce, - bool is_zp_float) { - vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id); - int pack_factor = 32 / b_q_type.size_bits(); + bool is_zp_float, int64_t thread_k, int64_t thread_n, + int64_t blocks_per_sm) { + vllm::ScalarTypeId a_type_id, c_type_id, s_type_id; + + auto c_dtype = a.dtype(); + if (a.scalar_type() == at::ScalarType::Half) { + a_type_id = vllm::kFloat16.id(); + c_type_id = vllm::kFloat16.id(); + } else 
if (a.scalar_type() == at::ScalarType::BFloat16) { + a_type_id = vllm::kBFloat16.id(); + c_type_id = vllm::kBFloat16.id(); + } else { + c_dtype = b_scales.dtype(); + if (b_scales.scalar_type() == at::ScalarType::Half) { + c_type_id = vllm::kFloat16.id(); + } else if (b_scales.scalar_type() == at::ScalarType::BFloat16) { + c_type_id = vllm::kBFloat16.id(); + } else { + c_type_id = vllm::kBFloat16.id(); + + TORCH_CHECK(c_or_none.has_value(), "c must be passed for W4A8-FP4"); + torch::Tensor c = c_or_none.value(); + c_dtype = c.dtype(); + + if (c.scalar_type() == at::ScalarType::Half) { + c_type_id = vllm::kFloat16.id(); + } else if (c.scalar_type() == at::ScalarType::BFloat16) { + c_type_id = vllm::kBFloat16.id(); + } else { + TORCH_CHECK(false, "unsupported c dtype"); + } + } + + if (a.scalar_type() == at::ScalarType::Float8_e4m3fn) { + a_type_id = vllm::kFE4M3fn.id(); + } else if (a.scalar_type() == at::ScalarType::Char) { + a_type_id = vllm::kS8.id(); + } else { + TORCH_CHECK(false, "unsupported `a` scalar_type"); + } + } + + s_type_id = c_type_id; + if (b_type_id == vllm::kFE2M1f.id()) { + if (b_scales.scalar_type() == at::ScalarType::Float8_e4m3fn) { + s_type_id = vllm::kFE4M3fn.id(); + } else if (b_scales.scalar_type() == at::ScalarType::Float8_e8m0fnu) { + s_type_id = vllm::kFE8M0fnu.id(); + } else { + TORCH_CHECK(false, + "When b_type = float4_e2m1f, b_scale scalar type must be", + "float8_e4m3fn (for NVFP4) or float8_e8m0fnu (for MXFP4)."); + } + } + + vllm::ScalarType a_type = vllm::ScalarType::from_id(a_type_id); + vllm::ScalarType b_type = vllm::ScalarType::from_id(b_type_id); + vllm::ScalarType c_type = vllm::ScalarType::from_id(c_type_id); + vllm::ScalarType s_type = vllm::ScalarType::from_id(s_type_id); + + int pack_factor = 32 / b_type.size_bits(); + int num_experts = b_q_weight.size(0); if (moe_block_size != 8) { TORCH_CHECK(moe_block_size % 16 == 0, @@ -745,19 +641,27 @@ torch::Tensor moe_wna16_marlin_gemm( TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); - // thread_k: `k` size of a thread_tile in `weights` (can usually be left as - // auto -1) - int thread_k = -1; - // thread_n: `n` size of a thread_tile in `weights` (can usually be left as - // auto -1) - int thread_n = -1; + torch::Tensor a_scales; + auto options = torch::TensorOptions().dtype(c_dtype).device(a.device()); + auto options_fp32 = + torch::TensorOptions().dtype(at::kFloat).device(a.device()); + + if (a_scales_or_none.has_value()) { + a_scales = a_scales_or_none.value(); + TORCH_CHECK(a_type.size_bits() == 8, + "a_scales can only be used for 8bit activation."); + } else { + a_scales = torch::empty({0}, options_fp32); + TORCH_CHECK(a_type.size_bits() != 8, + "the a_scales parameter must be passed for 8bit activation."); + } + // sms: number of SMs to use for the kernel int sms = -1; cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device()); // Alloc buffers const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); - auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); torch::Tensor c; if (c_or_none.has_value()) { c = c_or_none.value(); @@ -774,8 +678,6 @@ torch::Tensor moe_wna16_marlin_gemm( // Alloc C tmp buffer that is going to be used for the global reduce torch::Tensor c_tmp; - auto options_fp32 = - torch::TensorOptions().dtype(at::kFloat).device(a.device()); if (use_fp32_reduce && !use_atomic_add) { // max num of threadblocks is sms * 4 long max_c_tmp_size = min( @@ -846,11 
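The dtype plumbing above, which derives a_type, c_type and s_type from the activation, b_scales and the optional output tensor, is dense; the torch-free sketch below mirrors just those branches. The enum and function names are invented for illustration and the error strings are abbreviated.

// Torch-free sketch of the a_type / c_type / s_type resolution performed at
// the top of moe_wna16_marlin_gemm: 16-bit activations fix the output type,
// 8-bit activations take it from b_scales or from an explicit c tensor, and
// FP4 weights take their group-scale type from b_scales.
#include <optional>
#include <stdexcept>

enum class DType { kF16, kBF16, kFP8_E4M3, kFP8_E8M0, kInt8 };
enum class WType { kInt4, kInt8, kFP8, kFP4 };

struct ResolvedTypes {
  DType a, c, s;
};

ResolvedTypes resolve(DType a_dtype, WType b_type, DType b_scales_dtype,
                      std::optional<DType> c_dtype) {
  ResolvedTypes r{DType::kF16, DType::kF16, DType::kF16};
  if (a_dtype == DType::kF16 || a_dtype == DType::kBF16) {
    // 16-bit activations: output type equals the activation type.
    r.a = a_dtype;
    r.c = a_dtype;
  } else {
    // 8-bit activations (fp8 / int8): output type comes from b_scales, or
    // from an explicitly passed c tensor when b_scales is itself 8-bit.
    if (b_scales_dtype == DType::kF16 || b_scales_dtype == DType::kBF16)
      r.c = b_scales_dtype;
    else if (c_dtype && (*c_dtype == DType::kF16 || *c_dtype == DType::kBF16))
      r.c = *c_dtype;
    else
      throw std::runtime_error("c must be passed and be f16/bf16");
    if (a_dtype == DType::kFP8_E4M3 || a_dtype == DType::kInt8)
      r.a = a_dtype;
    else
      throw std::runtime_error("unsupported activation dtype");
  }
  // Group-scale type matches the output type, except for FP4 weights, whose
  // scales are e4m3 (NVFP4) or e8m0 (MXFP4) depending on b_scales.
  r.s = r.c;
  if (b_type == WType::kFP4) {
    if (b_scales_dtype == DType::kFP8_E4M3)
      r.s = DType::kFP8_E4M3;
    else if (b_scales_dtype == DType::kFP8_E8M0)
      r.s = DType::kFP8_E8M0;
    else
      throw std::runtime_error("fp4 weights need e4m3 or e8m0 scales");
  }
  return r;
}

int main() {
  // Example: W4A8-FP8 (int4 weights, fp8 activations) with bf16 group scales
  // resolves to a bf16 output and a bf16 scale type.
  ResolvedTypes t = resolve(DType::kFP8_E4M3, WType::kInt4, DType::kBF16, {});
  return (t.c == DType::kBF16 && t.s == DType::kBF16) ? 0 : 1;
}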
+748,11 @@ torch::Tensor moe_wna16_marlin_gemm( torch::Tensor global_scale; if (global_scale_or_none.has_value()) { global_scale = global_scale_or_none.value(); - TORCH_CHECK(b_q_type == vllm::kFE2M1f && group_size == 16, + TORCH_CHECK(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn, "global_scale can only be used for nvfp4 format."); } else { global_scale = torch::empty({0}, options); - TORCH_CHECK(!(b_q_type == vllm::kFE2M1f && group_size == 16), + TORCH_CHECK(!(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn), "the global_scale parameter must be passed for nvfp4 format."); } @@ -877,15 +779,15 @@ torch::Tensor moe_wna16_marlin_gemm( bool has_zp = b_zeros.size(-1) > 0; if (has_zp) { TORCH_CHECK( - b_q_type == vllm::kU4 || b_q_type == vllm::kU8, - "b_q_type must be u4 or u8 when has_zp = True. Got = ", b_q_type.str()); + b_type == vllm::kU4 || b_type == vllm::kU8, + "b_type must be u4 or u8 when has_zp = True. Got = ", b_type.str()); } else { - TORCH_CHECK(b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128 || - b_q_type == vllm::kFE4M3fn || b_q_type == vllm::kFE2M1f, - "b_q_type must be uint4b8, uint8b128, float8_e4m3fn or " - "float4_e2m1f when " - "has_zp = False. Got = ", - b_q_type.str()); + TORCH_CHECK(b_type == vllm::kU4B8 || b_type == vllm::kU8B128 || + b_type == vllm::kS4 || b_type == vllm::kS8 || + b_type == vllm::kFE4M3fn || b_type == vllm::kFE2M1f, + "b_type must be uint4b8, uint8b128, int4, int8, " + "float8_e4m3fn or float4_e2m1f when has_zp = False. Got = ", + b_type.str()); } if (has_zp && is_zp_float) { @@ -929,71 +831,33 @@ torch::Tensor moe_wna16_marlin_gemm( " is below min_workspace_size = ", min_workspace_size); int dev = a.get_device(); - if (a.scalar_type() == at::ScalarType::Half) { - void* scales_ptr; - if (b_q_type == vllm::kFE2M1f) { - if (group_size == 16) - scales_ptr = b_scales.data_ptr(); - else if (group_size == 32) - scales_ptr = b_scales.data_ptr(); - else - TORCH_CHECK(false, - "float4_e2m1f only supports group_size == 16 (NVFP4) ", - "and group_size == 32 (MXFP4)"); - } else { - scales_ptr = b_scales.data_ptr(); - } - MARLIN_NAMESPACE_NAME::marlin_mm( - a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), - c_tmp.data_ptr(), b_bias.data_ptr(), scales_ptr, - global_scale.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), - perm.data_ptr(), a_tmp.data_ptr(), - sorted_token_ids.data_ptr(), expert_ids.data_ptr(), - num_tokens_past_padded.data_ptr(), topk_weights.data_ptr(), - moe_block_size, top_k, mul_topk_weights, is_ep, size_m, size_n, size_k, - workspace.data_ptr(), b_q_type, has_bias, has_act_order, is_k_full, - has_zp, num_groups, group_size, dev, - at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, - use_atomic_add, use_fp32_reduce, is_zp_float); - } else if (a.scalar_type() == at::ScalarType::BFloat16) { - void* scales_ptr; - if (b_q_type == vllm::kFE2M1f) { - if (group_size == 16) - scales_ptr = b_scales.data_ptr(); - else if (group_size == 32) - scales_ptr = b_scales.data_ptr(); - else - TORCH_CHECK(false, - "float4_e2m1f only supports group_size == 16 (NVFP4) ", - "and group_size == 32 (MXFP4)"); - } else { - scales_ptr = b_scales.data_ptr(); - } - - MARLIN_NAMESPACE_NAME::marlin_mm( - a.data_ptr(), b_q_weight.data_ptr(), - c.data_ptr(), c_tmp.data_ptr(), - b_bias.data_ptr(), scales_ptr, - global_scale.data_ptr(), b_zeros.data_ptr(), - g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), - sorted_token_ids.data_ptr(), expert_ids.data_ptr(), - num_tokens_past_padded.data_ptr(), topk_weights.data_ptr(), - moe_block_size, top_k, 
mul_topk_weights, is_ep, size_m, size_n, size_k, - workspace.data_ptr(), b_q_type, has_bias, has_act_order, is_k_full, - has_zp, num_groups, group_size, dev, - at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, - use_atomic_add, use_fp32_reduce, is_zp_float); - } else { - TORCH_CHECK(false, - "moe_wna16_marlin_gemm only supports bfloat16 and float16"); + TORCH_CHECK(a_scales.scalar_type() == at::ScalarType::Float, + "scalar type of a_scales must be float"); + TORCH_CHECK(global_scale.scalar_type() == c.scalar_type(), + "scalar type of global_scale must be the same with c"); + if (a_type.size_bits() == 16) { + TORCH_CHECK( + a.scalar_type() == c.scalar_type(), + "scalar type of a must be the same with c for 16 bit activation"); } + MARLIN_NAMESPACE_NAME::marlin_mm( + a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), c_tmp.data_ptr(), + b_bias.data_ptr(), a_scales.data_ptr(), b_scales.data_ptr(), + global_scale.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), + perm.data_ptr(), a_tmp.data_ptr(), sorted_token_ids.data_ptr(), + expert_ids.data_ptr(), num_tokens_past_padded.data_ptr(), + topk_weights.data_ptr(), moe_block_size, num_experts, top_k, + mul_topk_weights, is_ep, size_m, size_n, size_k, workspace.data_ptr(), + a_type, b_type, c_type, s_type, has_bias, has_act_order, is_k_full, + has_zp, num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), + thread_k, thread_n, sms, blocks_per_sm, use_atomic_add, use_fp32_reduce, + is_zp_float); + return c; } -#endif - TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { m.impl("moe_wna16_marlin_gemm", &moe_wna16_marlin_gemm); } diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index bd95ade40a083..e0a8280722f3c 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -63,16 +63,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { m.def( "moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none," "Tensor! b_q_weight, Tensor? b_bias_or_none," - "Tensor! b_scales, Tensor? global_scale, Tensor? " + "Tensor! b_scales, Tensor? a_scales, Tensor? global_scale, Tensor? " "b_zeros_or_none," "Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace," "Tensor sorted_token_ids," "Tensor! expert_ids, Tensor! num_tokens_past_padded," "Tensor! topk_weights, int moe_block_size, int top_k, " - "bool mul_topk_weights, bool is_ep, int b_q_type_id," + "bool mul_topk_weights, bool is_ep, int b_type_id," "int size_m, int size_n, int size_k," "bool is_full_k, bool use_atomic_add," - "bool use_fp32_reduce, bool is_zp_float) -> Tensor"); + "bool use_fp32_reduce, bool is_zp_float," + "int thread_k, int thread_n, int blocks_per_sm) -> Tensor"); + m.def( "marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, " "Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! 
" diff --git a/csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu b/csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu index 03bd5964a7fc4..e306ff02605b9 100644 --- a/csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu +++ b/csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu @@ -437,10 +437,10 @@ struct ComputeTile_W8A16_PerC_MtilexNtilex32_multistage_SM8x_SplitK { for (int n_idx = 0; n_idx < WARP_NITER; ++n_idx) { #pragma unroll for (int k_idx = 0; k_idx < 2; ++k_idx) { - FType low16 = - ScalarType::float2num(C_frag[m_idx][n_idx][k_idx * 2]); - FType high16 = - ScalarType::float2num(C_frag[m_idx][n_idx][k_idx * 2 + 1]); + FType low16 = MarlinScalarType2::float2num( + C_frag[m_idx][n_idx][k_idx * 2]); + FType high16 = MarlinScalarType2::float2num( + C_frag[m_idx][n_idx][k_idx * 2 + 1]); uint32_t tmp = (reinterpret_cast(low16) & 0xffff) | (reinterpret_cast(high16) << 16); int sts_offset = diff --git a/csrc/quantization/gptq_allspark/allspark_utils.cuh b/csrc/quantization/gptq_allspark/allspark_utils.cuh index 831413016538e..14a61ad8fd880 100644 --- a/csrc/quantization/gptq_allspark/allspark_utils.cuh +++ b/csrc/quantization/gptq_allspark/allspark_utils.cuh @@ -8,7 +8,7 @@ #include #include #include "../gptq_marlin/marlin_dtypes.cuh" -using marlin::ScalarType; +using marlin::MarlinScalarType2; namespace allspark { @@ -72,10 +72,10 @@ __global__ void f16_gemm_splitk_reduce_kernel(const FType* C_split, FType* C, int n_mat = N_MATRIX > 0 ? N_MATRIX : (int)n_matrix; for (int i = 0; i < n_mat; ++i) { - sum += ScalarType::num2float(C_split[idx + i * matrix_size]); + sum += MarlinScalarType2::num2float(C_split[idx + i * matrix_size]); } - C[idx] = ScalarType::float2num(sum); + C[idx] = MarlinScalarType2::float2num(sum); } template diff --git a/csrc/quantization/gptq_marlin/.gitignore b/csrc/quantization/gptq_marlin/.gitignore index 77088552b85b4..ba805f9250ece 100644 --- a/csrc/quantization/gptq_marlin/.gitignore +++ b/csrc/quantization/gptq_marlin/.gitignore @@ -1 +1,2 @@ -kernel_*.cu \ No newline at end of file +sm*_kernel_*.cu +kernel_selector.h diff --git a/csrc/quantization/gptq_marlin/awq_marlin_repack.cu b/csrc/quantization/gptq_marlin/awq_marlin_repack.cu index e607107b3e77c..307bae6738ecf 100644 --- a/csrc/quantization/gptq_marlin/awq_marlin_repack.cu +++ b/csrc/quantization/gptq_marlin/awq_marlin_repack.cu @@ -4,14 +4,16 @@ namespace marlin { -template +template __global__ void awq_marlin_repack_kernel( uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr, int size_k, int size_n) { constexpr int pack_factor = 32 / num_bits; - int k_tiles = size_k / tile_k_size; - int n_tiles = size_n / tile_n_size; + constexpr int target_tile_n_size = tile_n_size / (is_a_8bit ? 2 : 1); + constexpr int target_tile_k_size = tile_k_size * (is_a_8bit ? 
2 : 1); + int k_tiles = size_k / target_tile_k_size; + int n_tiles = size_n / target_tile_n_size; int block_k_tiles = div_ceil(k_tiles, gridDim.x); auto start_k_tile = blockIdx.x * block_k_tiles; @@ -33,10 +35,10 @@ __global__ void awq_marlin_repack_kernel( extern __shared__ int4 sh[]; - constexpr int tile_n_ints = tile_n_size / pack_factor; + constexpr int tile_n_ints = target_tile_n_size / pack_factor; constexpr int stage_n_threads = tile_n_ints / 4; - constexpr int stage_k_threads = tile_k_size; + constexpr int stage_k_threads = target_tile_k_size; constexpr int stage_size = stage_k_threads * stage_n_threads; auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) { @@ -45,7 +47,7 @@ __global__ void awq_marlin_repack_kernel( return; } - int first_n = n_tile_id * tile_n_size; + int first_n = n_tile_id * target_tile_n_size; int first_n_packed = first_n / pack_factor; int4* sh_ptr = sh + stage_size * pipe; @@ -54,7 +56,7 @@ __global__ void awq_marlin_repack_kernel( auto k_id = threadIdx.x / stage_n_threads; auto n_id = threadIdx.x % stage_n_threads; - int first_k = k_tile_id * tile_k_size; + int first_k = k_tile_id * target_tile_k_size; cp_async4(&sh_ptr[k_id * stage_n_threads + n_id], reinterpret_cast( @@ -78,11 +80,11 @@ __global__ void awq_marlin_repack_kernel( } int tc_col = th_id / 4; - int tc_row = (th_id % 4) * 2; + int tc_row = (th_id % 4) * (is_a_8bit ? 4 : 2); constexpr int tc_offsets[4] = {0, 1, 8, 9}; - int cur_n = warp_id * 16 + tc_col; + int cur_n = (warp_id / (is_a_8bit ? 2 : 1)) * 16 + tc_col; int cur_n_packed = cur_n / pack_factor; int cur_n_pos = cur_n % pack_factor; @@ -105,23 +107,50 @@ __global__ void awq_marlin_repack_kernel( uint32_t vals[8]; #pragma unroll for (int i = 0; i < 4; i++) { - int cur_elem = tc_row + tc_offsets[i]; + if constexpr (is_a_8bit) { + int cur_elem = tc_row + i; - int packed_src_0 = sh_stage_int_ptr[cur_n_packed + sh_stride * cur_elem]; - int packed_src_1 = sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) + - sh_stride * cur_elem]; + int packed_src_0 = + sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) * (warp_id % 2) + + sh_stride * cur_elem]; + int packed_src_1 = + sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) * (warp_id % 2) + + sh_stride * (cur_elem + 16)]; - vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask; - vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask; + vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask; + vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask; + } else { + int cur_elem = tc_row + tc_offsets[i]; + + int packed_src_0 = + sh_stage_int_ptr[cur_n_packed + sh_stride * cur_elem]; + int packed_src_1 = sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) + + sh_stride * cur_elem]; + + vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask; + vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask; + } } - constexpr int tile_size = tile_k_size * tile_n_size / pack_factor; + constexpr int tile_size = + target_tile_k_size * target_tile_n_size / pack_factor; int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size; // Result of: // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h - if constexpr (num_bits == 4) { - constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + if constexpr (!is_a_8bit && num_bits == 4) { + int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + + uint32_t res = 0; +#pragma unroll + for (int i 
= 0; i < 8; i++) { + res |= vals[pack_idx[i]] << (i * 4); + } + + out_ptr[out_offset + th_id * 4 + warp_id] = res; + + } else if constexpr (is_a_8bit && num_bits == 4) { + int pack_idx[8] = {0, 4, 1, 5, 2, 6, 3, 7}; uint32_t res = 0; #pragma unroll @@ -138,8 +167,9 @@ __global__ void awq_marlin_repack_kernel( uint32_t res2 = 0; #pragma unroll for (int i = 0; i < 4; i++) { - res1 |= vals[pack_idx[i]] << (i * 8); - res2 |= vals[4 + pack_idx[i]] << (i * 8); + const int ii = is_a_8bit ? i : pack_idx[i]; + res1 |= vals[ii] << (i * 8); + res2 |= vals[4 + ii] << (i * 8); } out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1; @@ -176,18 +206,21 @@ __global__ void awq_marlin_repack_kernel( } // namespace marlin -#define CALL_IF(NUM_BITS) \ - else if (num_bits == NUM_BITS) { \ - cudaFuncSetAttribute( \ - marlin::awq_marlin_repack_kernel, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - marlin::awq_marlin_repack_kernel \ - <<>>( \ - b_q_weight_ptr, out_ptr, size_k, size_n); \ +#define CALL_IF(NUM_BITS, IS_A_8BIT) \ + else if (num_bits == NUM_BITS && is_a_8bit == IS_A_8BIT) { \ + cudaFuncSetAttribute( \ + marlin::awq_marlin_repack_kernel, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + marlin::awq_marlin_repack_kernel \ + <<>>( \ + b_q_weight_ptr, out_ptr, size_k, size_n); \ } torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, - int64_t size_n, int64_t num_bits) { + int64_t size_n, int64_t num_bits, + bool is_a_8bit) { // Verify compatibility with marlin tile of 16x64 TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k, " is not divisible by tile_k_size = ", marlin::tile_k_size); @@ -238,10 +271,13 @@ torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, if (false) { } - CALL_IF(4) - CALL_IF(8) + CALL_IF(4, false) + CALL_IF(8, false) + CALL_IF(4, true) + CALL_IF(8, true) else { - TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits); + TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits, + ", is_a_8bit = ", is_a_8bit); } return out; diff --git a/csrc/quantization/gptq_marlin/dequant.h b/csrc/quantization/gptq_marlin/dequant.h index e8b0c302b2021..26b8d40368aa9 100644 --- a/csrc/quantization/gptq_marlin/dequant.h +++ b/csrc/quantization/gptq_marlin/dequant.h @@ -470,6 +470,50 @@ __device__ inline void dequant( frag_b[0] = __hmul2(frag_b[0], bias_reg); } +template <> +__device__ inline void dequant<__nv_fp8x4_e4m3, vllm::kFE2M1f.id(), true>( + int q, __nv_fp8x4_e4m3* frag_b) { + // Constants for FP4 (E2M1) and FP16 formats + constexpr int FP4_EXPONENT = 2, FP8_EXPONENT = 4; + constexpr int RIGHT_SHIFT = FP8_EXPONENT - FP4_EXPONENT; + constexpr int MASK = 0x70707070; + + // Extract and shift FP4 values to FP16 format + int Out1 = (q & 0x80808080) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 4; + int Out2 = (q & 0x80808080) | ((q & MASK) >> RIGHT_SHIFT); + + // Note1: reverse indexing is intentional because weights are permuted + // Note2: when dequant to 8bit type, we write to `frag_b[2]` instead of + // `frag_b[1]` to fit the layout of tensorcore + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +} + +template <> +__device__ inline void dequant( + int q, int32_t* frag_b) { + constexpr int repeated_zp = 0x08080808; + constexpr int MASK = 0x80808080; + + frag_b[0] = ((q & 0x0F0F0F0F | MASK) - repeated_zp) ^ MASK; + q >>= 4; + frag_b[1] = ((q & 0x0F0F0F0F | MASK) - repeated_zp) ^ MASK; +} + +template <> +__device__ inline void 
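The two pack_idx tables are easiest to compare by running them. The snippet below packs the identity sequence 0..7 with each order so the resulting nibble layouts can be inspected; it only visualizes the permutations quoted from the kernel above and makes no claim about which MMA fragment consumes which nibble.

// Pack eight 4-bit values into one 32-bit word with a permutation table,
// exactly like the num_bits == 4 paths in awq_marlin_repack_kernel.
#include <cstdint>
#include <cstdio>

uint32_t pack4(const uint32_t vals[8], const int pack_idx[8]) {
  uint32_t res = 0;
  for (int i = 0; i < 8; i++) res |= (vals[pack_idx[i]] & 0xF) << (i * 4);
  return res;
}

int main() {
  uint32_t vals[8] = {0, 1, 2, 3, 4, 5, 6, 7};        // logical element index
  const int fp16_order[8] = {0, 2, 4, 6, 1, 3, 5, 7};  // 16-bit activation path
  const int a8_order[8]   = {0, 4, 1, 5, 2, 6, 3, 7};  // 8-bit activation path
  // Reading each hex digit right to left shows which logical element landed
  // in which nibble position.
  printf("16-bit activation layout: 0x%08x\n", (unsigned)pack4(vals, fp16_order));
  printf(" 8-bit activation layout: 0x%08x\n", (unsigned)pack4(vals, a8_order));
  return 0;
}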
dequant<__nv_fp8x4_e4m3, vllm::kU4B8.id(), true>( + int q, __nv_fp8x4_e4m3* frag_b) { + int s = q & 0x08080808; + int Out1 = ((q & 0x07070707) | (s << 4)) + (s >> 3); + q >>= 4; + s = q & 0x08080808; + int Out2 = ((q & 0x07070707) | (s << 4)) + (s >> 3); + + frag_b[0] = *reinterpret_cast(&Out1); + frag_b[1] = *reinterpret_cast(&Out2); +} + template __device__ inline void dequant_fp8_scales(int q, scalar_t2* frag_b); @@ -515,6 +559,49 @@ __device__ inline void dequant_fp8_scales( // Note: reverse indexing is intentional because weights are permuted frag_b[1] = *reinterpret_cast(&Out1); frag_b[0] = *reinterpret_cast(&Out2); +}; + +// subtract zero point in quanted format and then dequant +template +__device__ inline void sub_zp_and_dequant(int q, scalar_t2* frag_b, int zp); + +template <> +__device__ inline void sub_zp_and_dequant( + int q, int32_t* frag_b, int zp) { + // INT4 with zp -> INT8 + // see https://github.com/vllm-project/vllm/pull/24722 + int repeated_zp = 0x01010101 * zp; + int MASK = 0x80808080; + + frag_b[0] = ((q & 0x0F0F0F0F | MASK) - repeated_zp) ^ MASK; + q >>= 4; + frag_b[1] = ((q & 0x0F0F0F0F | MASK) - repeated_zp) ^ MASK; +} + +template <> +__device__ inline void sub_zp_and_dequant<__nv_fp8x4_e4m3, vllm::kU4.id(), + true>(int q, __nv_fp8x4_e4m3* frag_b, + int zp) { + // INT4 with zp -> FP8 + // see https://github.com/vllm-project/vllm/pull/24722 + uint32_t u_q = *reinterpret_cast(&q); + uint32_t u_zp = *reinterpret_cast(&zp); + uint32_t u_zp1 = u_zp + 1; + uint32_t repeated_zp = 0x01010101 * u_zp; + + uint32_t q0, s; + q0 = (u_q & 0x0F0F0F0F) | 0x70707070; + s = (q0 + repeated_zp) & 0x80808080; + uint32_t Out1 = (q0 + (s >> 7) * u_zp1) & 0x0F0F0F0F | s; + + u_q >>= 4; + q0 = (u_q & 0x0F0F0F0F) | 0x70707070; + s = (q0 + repeated_zp) & 0x80808080; + uint32_t Out2 = (q0 + (s >> 7) * u_zp1) & 0x0F0F0F0F | s; + + frag_b[0] = *reinterpret_cast(&Out1); + frag_b[1] = *reinterpret_cast(&Out2); } #endif diff --git a/csrc/quantization/gptq_marlin/generate_kernels.py b/csrc/quantization/gptq_marlin/generate_kernels.py index 42d3b456096ee..27ef7271ba41c 100644 --- a/csrc/quantization/gptq_marlin/generate_kernels.py +++ b/csrc/quantization/gptq_marlin/generate_kernels.py @@ -4,141 +4,292 @@ import glob import itertools import os import subprocess +import sys import jinja2 -FILE_HEAD = """ -// auto generated by generate.py -// clang-format off +ARCHS = [] +SUPPORT_FP8 = False +for arch in sys.argv[1].split(","): + arch = arch[: arch.index(".") + 2].replace(".", "") + arch = int(arch) + # only SM89 and SM120 fully support + # mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32. + # SM90 and SM100 can use this PTX, but it’s simulated + # with FP16 MMA, so it cannot achieve any acceleration. 
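The branchless zero-point handling used by the new int4-to-int8 paths (dequant<int32_t, vllm::kU4B8.id(), true> and sub_zp_and_dequant<int32_t, vllm::kU4.id(), true> above) relies on the subtraction never borrowing across byte lanes. The small host program below checks that property exhaustively against a scalar reference; it is a standalone sanity check around the trick quoted from the kernel, not part of the kernels themselves.

// Exhaustive host-side check of the borrow-free trick
//   ((q & 0x0F0F0F0F | 0x80808080) - 0x01010101 * zp) ^ 0x80808080
// used above: every byte lane must end up holding int8(nibble - zp).
// The fixed-bias kU4B8 path corresponds to zp == 8.
#include <cstdint>
#include <cstdio>

int main() {
  for (uint32_t zp = 0; zp < 16; ++zp) {
    uint32_t repeated_zp = 0x01010101u * zp;
    for (uint32_t b0 = 0; b0 < 16; ++b0)
      for (uint32_t b1 = 0; b1 < 16; ++b1)
        for (uint32_t b2 = 0; b2 < 16; ++b2)
          for (uint32_t b3 = 0; b3 < 16; ++b3) {
            uint32_t q = b0 | (b1 << 8) | (b2 << 16) | (b3 << 24);
            uint32_t out =
                (((q & 0x0F0F0F0Fu) | 0x80808080u) - repeated_zp) ^ 0x80808080u;
            const uint32_t nib[4] = {b0, b1, b2, b3};
            for (int lane = 0; lane < 4; ++lane) {
              int8_t got = static_cast<int8_t>((out >> (8 * lane)) & 0xFFu);
              int expected = static_cast<int>(nib[lane]) - static_cast<int>(zp);
              if (got != expected) {
                printf("mismatch: zp=%u nibble=%u got=%d want=%d\n", zp,
                       nib[lane], got, expected);
                return 1;
              }
            }
          }
  }
  printf("int4 zero-point subtraction trick holds for all 2^20 cases\n");
  return 0;
}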
+ if arch in [89, 120]: + SUPPORT_FP8 = True +FILE_HEAD_COMMENT = """ +// auto generated by generate_kernels.py +// clang-format off +""".lstrip() + +FILE_HEAD = ( + FILE_HEAD_COMMENT + + """ #include "kernel.h" #include "marlin_template.h" namespace MARLIN_NAMESPACE_NAME { -""".strip() +""" +) TEMPLATE = ( "template __global__ void Marlin<" - "{{scalar_t}}, " - "{{w_type_id}}, " + "{{a_type_id}}, " + "{{b_type_id}}, " + "{{c_type_id}}, " "{{s_type_id}}, " "{{threads}}, " "{{thread_m_blocks}}, " "{{thread_n_blocks}}, " "{{thread_k_blocks}}, " - "{{'true' if m_block_size_8 else 'false'}}, " + "{{m_block_size_8}}, " "{{stages}}, " "{{group_blocks}}, " - "{{'true' if is_zp_float else 'false'}}>" + "{{is_zp_float}}>" "( MARLIN_KERNEL_PARAMS );" ) -# int8 with zero point case (vllm::kU8) is also supported, -# we don't add it to reduce wheel size. -SCALAR_TYPES = [ - "vllm::kU4", - "vllm::kU4B8", - "vllm::kU8B128", - "vllm::kFE4M3fn", - "vllm::kFE2M1f", -] THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128), (128, 64, 128)] THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4] -# group_blocks: -# = 0 : act order case -# = -1 : channelwise quantization -# > 0 : group_size=16*group_blocks -GROUP_BLOCKS = [0, 1, -1, 2, 4, 8] -DTYPES = ["fp16", "bf16"] + +QUANT_CONFIGS = [ + # AWQ-INT4 + { + "b_type": "kU4", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": THREAD_M_BLOCKS, + "group_blocks": [-1, 2, 4, 8], + }, + # HQQ + { + "a_type": ["kFloat16"], + "b_type": "kU4", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": THREAD_M_BLOCKS, + "group_blocks": [4], + "is_zp_float": True, + }, + # GPTQ-INT4 + { + "b_type": "kU4B8", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": THREAD_M_BLOCKS, + "group_blocks": [-1, 0, 2, 4, 8], + }, + # GPTQ-INT8 + { + "b_type": "kU8B128", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": THREAD_M_BLOCKS, + "group_blocks": [-1, 0, 2, 4, 8], + }, + # FP8 + { + "b_type": "kFE4M3fn", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": THREAD_M_BLOCKS, + "group_blocks": [-1, 8], + }, + # NVFP4 + { + "b_type": "kFE2M1f", + "s_type": "kFE4M3fn", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": THREAD_M_BLOCKS, + "group_blocks": [1], + }, + # MXFP4 + { + "a_type": ["kBFloat16"], + "b_type": "kFE2M1f", + "s_type": "kFE8M0fnu", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": THREAD_M_BLOCKS, + "group_blocks": [2], + }, + # AWQ-INT4 with INT8 activation + { + "a_type": ["kS8"], + "b_type": "kU4", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": [1, 2, 3, 4], + "group_blocks": [-1, 2, 4, 8], + }, + # GPTQ-INT4 with INT8 activation + { + "a_type": ["kS8"], + "b_type": "kU4B8", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": [1, 2, 3, 4], + "group_blocks": [-1, 2, 4, 8], + }, + # GPTQ-INT4 with FP8 activation + { + "a_type": ["kFE4M3fn"], + "b_type": "kU4B8", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": [1, 2, 3, 4], + "group_blocks": [-1, 2, 4, 8], + }, + # AWQ-INT4 with FP8 activation + { + "a_type": ["kFE4M3fn"], + "b_type": "kU4", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": [1, 2, 3, 4], + "group_blocks": [-1, 2, 4, 8], + }, + # MXFP4 with FP8 activation + { + "a_type": ["kFE4M3fn"], + "b_type": "kFE2M1f", + "c_type": ["kBFloat16"], + "s_type": "kFE8M0fnu", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": [1, 2, 3, 4], + "group_blocks": [2], + }, +] def remove_old_kernels(): - for filename in glob.glob(os.path.dirname(__file__) + "/kernel_*.cu"): + for filename in 
glob.glob(os.path.dirname(__file__) + "/*kernel_*.cu"): subprocess.call(["rm", "-f", filename]) + filename = os.path.dirname(__file__) + "/kernel_selector.h" + subprocess.call(["rm", "-f", filename]) + def generate_new_kernels(): - for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES): + result_dict = {} + + for quant_config in QUANT_CONFIGS: + c_types = quant_config.get("c_type", ["kFloat16", "kBFloat16"]) + a_types = quant_config.get("a_type", ["kFloat16", "kBFloat16"]) + b_type = quant_config["b_type"] + is_zp_float = quant_config.get("is_zp_float", False) + all_group_blocks = quant_config["group_blocks"] + all_m_blocks = quant_config["thread_m_blocks"] + all_thread_configs = quant_config["thread_configs"] + + for a_type, c_type in itertools.product(a_types, c_types): + if not SUPPORT_FP8 and a_type == "kFE4M3fn": + continue + if "16" in a_type and "16" in c_type and a_type != c_type: + continue + s_type = quant_config.get("s_type", c_type) + if (a_type, b_type, c_type) not in result_dict: + result_dict[(a_type, b_type, c_type)] = [] + + for group_blocks, m_blocks, thread_configs in itertools.product( + all_group_blocks, all_m_blocks, all_thread_configs + ): + thread_k, thread_n, threads = thread_configs + + if threads == 256: + # for small batch (m_blocks == 1), + # we only need (128, 128, 256) + # for large batch (m_blocks > 1), + # we only need (64, 256, 256) + if m_blocks <= 1 and (thread_k, thread_n) != (128, 128): + continue + if m_blocks > 1 and (thread_k, thread_n) != (64, 256): + continue + + config = { + "threads": threads, + "s_type": s_type, + "thread_m_blocks": max(m_blocks, 1), + "thread_k_blocks": thread_k // 16, + "thread_n_blocks": thread_n // 16, + "m_block_size_8": "true" if m_blocks == 0.5 else "false", + "stages": "pipe_stages", + "group_blocks": group_blocks, + "is_zp_float": "true" if is_zp_float else "false", + } + + result_dict[(a_type, b_type, c_type)].append(config) + + kernel_selector_str = FILE_HEAD_COMMENT + + for (a_type, b_type, c_type), config_list in result_dict.items(): all_template_str_list = [] + for config in config_list: + s_type = config["s_type"] + template_str = jinja2.Template(TEMPLATE).render( + a_type_id=f"vllm::{a_type}.id()", + b_type_id=f"vllm::{b_type}.id()", + c_type_id=f"vllm::{c_type}.id()", + s_type_id=f"vllm::{s_type}.id()", + **config, + ) + all_template_str_list.append(template_str) - for group_blocks, m_blocks, thread_configs in itertools.product( - GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS - ): - # act order case only support gptq-int4 and gptq-int8 - if group_blocks == 0 and scalar_type not in [ - "vllm::kU4B8", - "vllm::kU8B128", - ]: - continue - if thread_configs[2] == 256: - # for small batch (m_blocks == 1), we only need (128, 128, 256) - # for large batch (m_blocks > 1), we only need (64, 256, 256) - if m_blocks <= 1 and thread_configs[0] != 128: - continue - if m_blocks > 1 and thread_configs[0] != 64: - continue + conditions = [ + f"a_type == vllm::{a_type}", + f"b_type == vllm::{b_type}", + f"c_type == vllm::{c_type}", + f"s_type == vllm::{s_type}", + f"threads == {config['threads']}", + f"thread_m_blocks == {config['thread_m_blocks']}", + f"thread_n_blocks == {config['thread_n_blocks']}", + f"thread_k_blocks == {config['thread_k_blocks']}", + f"m_block_size_8 == {config['m_block_size_8']}", + f"group_blocks == {config['group_blocks']}", + f"is_zp_float == {config['is_zp_float']}", + ] + conditions = " && ".join(conditions) - # we only support channelwise quantization and group_size == 128 - # for fp8 - 
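The generator now also writes kernel_selector.h, a flat else-if chain that get_marlin_kernel simply includes. The toy program below imitates that shape with stub kernels so the generated pattern is easier to picture; it is not the generated file. The real arms compare the full (a_type, b_type, c_type, s_type, threads, tile blocks, m_block_size_8, group_blocks, is_zp_float) tuple and assign actual Marlin template instantiations, and a missing match leaves the kernel at MarlinDefault.

// Toy model of the generated kernel_selector.h dispatch chain.
#include <cstdio>

using KernelFn = void (*)();

// Stub standing in for the Marlin<...> kernel template.
template <int THREADS, int THREAD_N_BLOCKS, int GROUP_BLOCKS>
void MarlinStub() {
  printf("Marlin<threads=%d, n_blocks=%d, group_blocks=%d>\n", THREADS,
         THREAD_N_BLOCKS, GROUP_BLOCKS);
}

// get_marlin_kernel now reduces to initializing `kernel` and including the
// generated chain; this hand-written chain imitates two of its arms.
KernelFn get_kernel(int threads, int thread_n_blocks, int group_blocks) {
  KernelFn kernel = nullptr;
  if (threads == 256 && thread_n_blocks == 8 && group_blocks == -1)
    kernel = MarlinStub<256, 8, -1>;
  else if (threads == 128 && thread_n_blocks == 4 && group_blocks == 8)
    kernel = MarlinStub<128, 4, 8>;
  return kernel;  // nullptr plays the role of MarlinDefault
}

int main() {
  if (KernelFn k = get_kernel(256, 8, -1)) k();
  if (get_kernel(64, 16, 0) == nullptr)
    printf("no kernel built for this configuration\n");
  return 0;
}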
if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]: - continue - # nvfp4 only supports group_size == 16 - # mxfp4 only supports group_size == 32 - if scalar_type == "vllm::kFE2M1f" and group_blocks not in [1, 2]: - continue - # other quantization methods don't support group_size = 16 - if scalar_type != "vllm::kFE2M1f" and group_blocks == 1: - continue + if kernel_selector_str == FILE_HEAD_COMMENT: + kernel_selector_str += f"if ({conditions})\n kernel = " + else: + kernel_selector_str += f"else if ({conditions})\n kernel = " - k_blocks = thread_configs[0] // 16 - n_blocks = thread_configs[1] // 16 - threads = thread_configs[2] + kernel_template2 = ( + "Marlin<{{a_type_id}}, {{b_type_id}}, {{c_type_id}}, " + "{{s_type_id}}, {{threads}}, {{thread_m_blocks}}, " + "{{thread_n_blocks}}, {{thread_k_blocks}}, " + "{{m_block_size_8}}, {{stages}}, {{group_blocks}}, " + "{{is_zp_float}}>;" + ) - c_dtype = "half" if dtype == "fp16" else "nv_bfloat16" - - is_zp_float_list = [False] - if dtype == "fp16" and scalar_type == "vllm::kU4" and group_blocks == 4: - # HQQ (is_zp_float = true) only supports - # 4bit quantization and fp16 - is_zp_float_list.append(True) - - if scalar_type == "vllm::kFE2M1f" and group_blocks == 1: - s_type = "vllm::kFE4M3fn" - elif scalar_type == "vllm::kFE2M1f" and group_blocks == 2: - s_type = "vllm::kFE8M0fnu" - if dtype == "fp16": - # we cannot safely dequantize e8m0 to fp16, so skip this - continue - elif dtype == "fp16": - s_type = "vllm::kFloat16" - elif dtype == "bf16": - s_type = "vllm::kBFloat16" - - for is_zp_float in is_zp_float_list: - template_str = jinja2.Template(TEMPLATE).render( - scalar_t=c_dtype, - w_type_id=scalar_type + ".id()", - s_type_id=s_type + ".id()", - threads=threads, - thread_m_blocks=max(m_blocks, 1), - thread_n_blocks=n_blocks, - thread_k_blocks=k_blocks, - m_block_size_8=m_blocks == 0.5, - stages="pipe_stages", - group_blocks=group_blocks, - is_zp_float=is_zp_float, + kernel_selector_str += ( + jinja2.Template(kernel_template2).render( + a_type_id=f"vllm::{a_type}.id()", + b_type_id=f"vllm::{b_type}.id()", + c_type_id=f"vllm::{c_type}.id()", + s_type_id=f"vllm::{s_type}.id()", + **config, ) - - all_template_str_list.append(template_str) + + "\n" + ) file_content = FILE_HEAD + "\n\n" file_content += "\n\n".join(all_template_str_list) + "\n\n}\n" - filename = f"kernel_{dtype}_{scalar_type[6:].lower()}.cu" + if a_type == "kFE4M3fn": + filename = f"sm89_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu" + else: + filename = f"sm80_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu" + + filename = filename.lower() with open(os.path.join(os.path.dirname(__file__), filename), "w") as f: f.write(file_content) + if not SUPPORT_FP8 and kernel_selector_str != FILE_HEAD_COMMENT: + kernel_selector_str += ( + "else if (a_type == vllm::kFE4M3fn)\n" + " TORCH_CHECK(false, " + '"marlin kernel with fp8 activation is not built.");' + ) + + with open(os.path.join(os.path.dirname(__file__), "kernel_selector.h"), "w") as f: + f.write(kernel_selector_str) + if __name__ == "__main__": remove_old_kernels() diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu index cc30abcf00800..28ff06559a98a 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu @@ -53,7 +53,7 @@ torch::Tensor gptq_marlin_gemm( std::optional const& b_zeros_or_none, std::optional const& g_idx_or_none, std::optional const& perm_or_none, torch::Tensor& workspace, - vllm::ScalarTypeId 
const& b_q_type_id, int64_t size_m, int64_t size_n, + vllm::ScalarTypeId const& b_type_id, int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) { TORCH_CHECK_NOT_IMPLEMENTED(false, @@ -243,204 +243,29 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks, int cache_size = get_kernel_cache_size( th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float); - return cache_size + 512 <= max_shared_mem; + return cache_size <= max_shared_mem; } - #define _GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ - M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \ - else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ - thread_n_blocks == THREAD_N_BLOCKS && \ - thread_k_blocks == THREAD_K_BLOCKS && \ - m_block_size_8 == M_BLOCK_SIZE_8 && \ - group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \ - is_zp_float == IS_ZP_FLOAT) { \ - constexpr auto S_TYPE = \ - W_TYPE == vllm::kFE2M1f \ - ? (GROUP_BLOCKS == 1 ? vllm::kFE4M3fn : vllm::kFE8M0fnu) \ - : (std::is_same::value ? vllm::kFloat16 \ - : vllm::kBFloat16); \ - kernel = Marlin; \ - } - - // COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false) - // this is the most common cases - // BIGGROUP: cases for big group size (group_blocks in [-1, 8]) - // FZP: cases for float-zero-point (is_zp_float = true) - // ACT: cases for act order case (group_blocks == 0) - // FP4: cases for nvfp4(e2m1) (group_blocks == 1) - #define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) - - #define COMMON_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ - \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ - \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) - - #define COMMON_GET_IF(W_TYPE) \ - COMMON_GET_IF_M1(W_TYPE, 8, 8, 256) \ - COMMON_GET_IF_M1(W_TYPE, 8, 4, 128) \ - COMMON_GET_IF_M1(W_TYPE, 4, 8, 128) \ - COMMON_GET_IF_M234(W_TYPE, 16, 4, 256) \ - COMMON_GET_IF_M234(W_TYPE, 8, 4, 128) \ - COMMON_GET_IF_M234(W_TYPE, 4, 8, 128) - - #define BIGGROUP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) 
\ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) - - #define BIGGROUP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) - - #define BIGGROUP_GET_IF(W_TYPE) \ - BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \ - BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \ - BIGGROUP_GET_IF_M1(W_TYPE, 4, 8, 128) \ - BIGGROUP_GET_IF_M234(W_TYPE, 16, 4, 256) \ - BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128) \ - BIGGROUP_GET_IF_M234(W_TYPE, 4, 8, 128) - - #define NVFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) - - #define NVFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) - - #define NVFP4_GET_IF(W_TYPE) \ - NVFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \ - NVFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \ - NVFP4_GET_IF_M1(W_TYPE, 4, 8, 128) \ - NVFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \ - NVFP4_GET_IF_M234(W_TYPE, 8, 4, 128) \ - NVFP4_GET_IF_M234(W_TYPE, 4, 8, 128) - - #define MXFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) - - #define MXFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) - - #define MXFP4_GET_IF(W_TYPE) \ - MXFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \ - MXFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \ - MXFP4_GET_IF_M1(W_TYPE, 4, 8, 128) \ - MXFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \ - MXFP4_GET_IF_M234(W_TYPE, 8, 4, 128) \ - MXFP4_GET_IF_M234(W_TYPE, 4, 8, 128) - - // We currently have 4-bit models only with group_blocks == 4 - #define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, true) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) - - #define FZP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) - - #define FZP_GET_IF(W_TYPE) \ - FZP_GET_IF_M1(W_TYPE, 8, 8, 256) \ - FZP_GET_IF_M1(W_TYPE, 8, 4, 128) \ - FZP_GET_IF_M1(W_TYPE, 4, 8, 128) \ - FZP_GET_IF_M234(W_TYPE, 16, 4, 256) \ - FZP_GET_IF_M234(W_TYPE, 8, 4, 128) \ - FZP_GET_IF_M234(W_TYPE, 4, 8, 128) - - // We currently have 4-bit models only with group_blocks == 4 - #define 
ACT_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) - - #define ACT_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) - - #define ACT_GET_IF(W_TYPE) \ - ACT_GET_IF_M1(W_TYPE, 8, 8, 256) \ - ACT_GET_IF_M1(W_TYPE, 8, 4, 128) \ - ACT_GET_IF_M1(W_TYPE, 4, 8, 128) \ - ACT_GET_IF_M234(W_TYPE, 16, 4, 256) \ - ACT_GET_IF_M234(W_TYPE, 8, 4, 128) \ - ACT_GET_IF_M234(W_TYPE, 4, 8, 128) - -template -MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type, - int thread_m_blocks, int thread_n_blocks, - int thread_k_blocks, bool m_block_size_8, - bool has_act_order, bool has_zp, - int group_blocks, int num_threads, - bool is_zp_float) { - int num_bits = q_type.size_bits(); +MarlinFuncPtr get_marlin_kernel( + const vllm::ScalarType a_type, const vllm::ScalarType b_type, + const vllm::ScalarType c_type, const vllm::ScalarType s_type, + int thread_m_blocks, int thread_n_blocks, int thread_k_blocks, + bool m_block_size_8, bool has_act_order, bool has_zp, int group_blocks, + int threads, bool is_zp_float) { + int num_bits = b_type.size_bits(); auto kernel = MarlinDefault; - if (false) { - } - COMMON_GET_IF(vllm::kU4) - COMMON_GET_IF(vllm::kU4B8) - COMMON_GET_IF(vllm::kU8B128) - - NVFP4_GET_IF(vllm::kFE2M1f) - - BIGGROUP_GET_IF(vllm::kFE4M3fn) - - ACT_GET_IF(vllm::kU4B8) - ACT_GET_IF(vllm::kU8B128) - - if (std::is_same::value) { - if (false) { - } - FZP_GET_IF(vllm::kU4) - } - if (std::is_same::value) { - if (false) { - } - MXFP4_GET_IF(vllm::kFE2M1f) - } + #include "kernel_selector.h" return kernel; } -template -exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m, - int prob_n, int prob_k, int thread_m_blocks, - bool m_block_size_8, int num_bits, - int group_size, bool has_act_order, - bool is_k_full, bool has_zp, - bool is_zp_float, int max_shared_mem, - int sms) { +exec_config_t determine_exec_config( + const vllm::ScalarType& a_type, const vllm::ScalarType& b_type, + const vllm::ScalarType& c_type, const vllm::ScalarType& s_type, int prob_m, + int prob_n, int prob_k, int thread_m_blocks, bool m_block_size_8, + int num_bits, int group_size, bool has_act_order, bool is_k_full, + bool has_zp, bool is_zp_float, int max_shared_mem, int sms) { exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}}; thread_config_t* thread_configs = thread_m_blocks > 1 ? large_batch_thread_configs @@ -455,7 +280,7 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m, if (!is_valid_config(th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size, has_act_order, is_k_full, has_zp, - is_zp_float, max_shared_mem)) { + is_zp_float, max_shared_mem - 512)) { continue; } @@ -468,10 +293,11 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m, group_blocks = group_size == -1 ? 
-1 : group_size / 16; } - auto kernel = get_marlin_kernel( - q_type, thread_m_blocks, th_config.thread_n / 16, - th_config.thread_k / 16, m_block_size_8, has_act_order, has_zp, - group_blocks, th_config.num_threads, is_zp_float); + auto kernel = + get_marlin_kernel(a_type, b_type, c_type, s_type, thread_m_blocks, + th_config.thread_n / 16, th_config.thread_k / 16, + m_block_size_8, has_act_order, has_zp, group_blocks, + th_config.num_threads, is_zp_float); if (kernel == MarlinDefault) continue; @@ -485,28 +311,16 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m, return exec_cfg; } -template void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, - void* s, void* s2, void* zp, void* g_idx, void* perm, - void* a_tmp, int prob_m, int prob_n, int prob_k, int lda, - void* workspace, vllm::ScalarType const& q_type, bool has_bias, + void* a_s, void* b_s, void* g_s, void* zp, void* g_idx, + void* perm, void* a_tmp, int prob_m, int prob_n, int prob_k, + int lda, void* workspace, vllm::ScalarType const& a_type, + vllm::ScalarType const& b_type, vllm::ScalarType const& c_type, + vllm::ScalarType const& s_type, bool has_bias, bool has_act_order, bool is_k_full, bool has_zp, int num_groups, int group_size, int dev, cudaStream_t stream, int thread_k_init, int thread_n_init, int sms, bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) { - if (has_zp) { - TORCH_CHECK( - q_type == vllm::kU4 || q_type == vllm::kU8, - "q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str()); - } else { - TORCH_CHECK( - q_type == vllm::kU4B8 || q_type == vllm::kU8B128 || - q_type == vllm::kFE4M3fn || q_type == vllm::kFE2M1f, - "q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when " - "has_zp = False. 
Got = ", - q_type.str()); - } - TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, ", ", prob_n, ", ", prob_k, "]"); @@ -531,19 +345,21 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, } } - int num_bits = q_type.size_bits(); + int num_bits = b_type.size_bits(); const int4* A_ptr = (const int4*)A; const int4* B_ptr = (const int4*)B; int4* C_ptr = (int4*)C; int4* C_tmp_ptr = (int4*)C_tmp; + const int4* bias_ptr = (const int4*)b_bias; - const int4* s_ptr = (const int4*)s; - const uint16_t* s2_ptr = (const uint16_t*)s2; + const float* a_s_ptr = (const float*)a_s; + const int4* b_s_ptr = (const int4*)b_s; + const uint16_t* g_s_ptr = (const uint16_t*)g_s; + const int4* zp_ptr = (const int4*)zp; const int* g_idx_ptr = (const int*)g_idx; const int* perm_ptr = (const int*)perm; int4* a_tmp_ptr = (int4*)a_tmp; - int* locks = (int*)workspace; if (has_act_order) { @@ -568,6 +384,21 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); TORCH_CHECK(max_shared_mem > 0); + int major_capability, minor_capability; + cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor, + dev); + cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor, + dev); + TORCH_CHECK(major_capability * 10 + minor_capability >= 80, + "marlin kernel only support Ampere or newer GPUs."); + if (a_type == vllm::kFE4M3fn) { + TORCH_CHECK( + major_capability * 10 + minor_capability == 89 || + major_capability * 10 + minor_capability == 120, + "Marlin W4A8-FP8 only support SM89 or SM120 device (It is slower than " + "Marlin W4A16 on other devices)."); + } + int max_par = 16; if (prob_n <= 4096) max_par = 16 * 8; int max_shared_mem_new = max_shared_mem; @@ -583,7 +414,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, int thread_n = thread_n_init; int thread_m_blocks = min(div_ceil(prob_m_split, 16), max_thread_m_blocks); - int m_block_size_8 = prob_m_split <= 8; + int m_block_size_8 = prob_m_split <= 8 && a_type.size_bits() == 16; // Set thread config exec_config_t exec_cfg; @@ -597,11 +428,25 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, " is not divisible by thread_k = ", thread_k); } else { // Auto config - exec_cfg = determine_exec_config( - q_type, prob_m_split, prob_n, prob_k, thread_m_blocks, m_block_size_8, - num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float, - max_shared_mem, sms); + exec_cfg = determine_exec_config( + a_type, b_type, c_type, s_type, prob_m_split, prob_n, prob_k, + thread_m_blocks, m_block_size_8, num_bits, group_size, has_act_order, + is_k_full, has_zp, is_zp_float, max_shared_mem, sms); thread_tfg = exec_cfg.tb_cfg; + if (thread_tfg.thread_n != -1) { + if (prob_n / thread_tfg.thread_n * + div_ceil(prob_m_split, thread_m_blocks * 16) * 4 <= + sms) { + if (is_valid_config({128, 64, 128}, thread_m_blocks, prob_m_split, + prob_n, prob_k, num_bits, group_size, + has_act_order, is_k_full, has_zp, is_zp_float, + max_shared_mem_new)) { + thread_tfg = {128, 64, 128}; + exec_cfg = {1, thread_tfg}; + } + } + } + if (thread_tfg.thread_k == -1 && max_thread_m_blocks > 1) { max_thread_m_blocks--; continue; @@ -632,10 +477,10 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, ", has_zp = ", has_zp, ", is_zp_float = ", is_zp_float, ", max_shared_mem_new = ", max_shared_mem_new); - auto kernel = get_marlin_kernel( - q_type, 
thread_m_blocks, thread_n_blocks, thread_k_blocks, - m_block_size_8, has_act_order, has_zp, group_blocks, num_threads, - is_zp_float); + auto kernel = get_marlin_kernel( + a_type, b_type, c_type, s_type, thread_m_blocks, thread_n_blocks, + thread_k_blocks, m_block_size_8, has_act_order, has_zp, group_blocks, + num_threads, is_zp_float); if (kernel == MarlinDefault) { TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n, @@ -657,13 +502,15 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, // avoid ">>>" being formatted to "> > >" // clang-format off kernel<<>>( - A_ptr, B_ptr, C_ptr, C_tmp_ptr, bias_ptr, s_ptr, s2_ptr, zp_ptr, + A_ptr, B_ptr, C_ptr, C_tmp_ptr, bias_ptr, a_s_ptr, b_s_ptr, g_s_ptr, zp_ptr, g_idx_ptr, num_groups, prob_m_split, prob_n, prob_k, lda, locks, has_bias, part_use_atomic_add, use_fp32_reduce, max_shared_mem_new); // clang-format on - A_ptr += prob_m_split * (lda / 8); + bool is_a_8bit = a_type.size_bits() == 8; + A_ptr += prob_m_split * (lda / (is_a_8bit ? 16 : 8)); + a_s_ptr += prob_m_split; C_ptr += prob_m_split * (prob_n / 8); rest_m -= prob_m_split; } @@ -675,15 +522,73 @@ torch::Tensor gptq_marlin_gemm( torch::Tensor& a, std::optional c_or_none, torch::Tensor& b_q_weight, std::optional const& b_bias_or_none, torch::Tensor& b_scales, + std::optional const& a_scales_or_none, std::optional const& global_scale_or_none, std::optional const& b_zeros_or_none, std::optional const& g_idx_or_none, std::optional const& perm_or_none, torch::Tensor& workspace, - vllm::ScalarTypeId const& b_q_type_id, int64_t size_m, int64_t size_n, + vllm::ScalarTypeId const& b_type_id, int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) { - vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id); - int pack_factor = 32 / b_q_type.size_bits(); + vllm::ScalarTypeId a_type_id, c_type_id, s_type_id; + + auto c_dtype = a.dtype(); + if (a.scalar_type() == at::ScalarType::Half) { + a_type_id = vllm::kFloat16.id(); + c_type_id = vllm::kFloat16.id(); + } else if (a.scalar_type() == at::ScalarType::BFloat16) { + a_type_id = vllm::kBFloat16.id(); + c_type_id = vllm::kBFloat16.id(); + } else { + c_dtype = b_scales.dtype(); + if (b_scales.scalar_type() == at::ScalarType::Half) { + c_type_id = vllm::kFloat16.id(); + } else if (b_scales.scalar_type() == at::ScalarType::BFloat16) { + c_type_id = vllm::kBFloat16.id(); + } else { + c_type_id = vllm::kBFloat16.id(); + + TORCH_CHECK(c_or_none.has_value(), "c must be passed for W4A8-FP4"); + torch::Tensor c = c_or_none.value(); + c_dtype = c.dtype(); + + if (c.scalar_type() == at::ScalarType::Half) { + c_type_id = vllm::kFloat16.id(); + } else if (c.scalar_type() == at::ScalarType::BFloat16) { + c_type_id = vllm::kBFloat16.id(); + } else { + TORCH_CHECK(false, "unsupported c dtype"); + } + } + + if (a.scalar_type() == at::ScalarType::Float8_e4m3fn) { + a_type_id = vllm::kFE4M3fn.id(); + } else if (a.scalar_type() == at::ScalarType::Char) { + a_type_id = vllm::kS8.id(); + } else { + TORCH_CHECK(false, "unsupported `a` scalar_type"); + } + } + + s_type_id = c_type_id; + if (b_type_id == vllm::kFE2M1f.id()) { + if (b_scales.scalar_type() == at::ScalarType::Float8_e4m3fn) { + s_type_id = vllm::kFE4M3fn.id(); + } else if (b_scales.scalar_type() == at::ScalarType::Float8_e8m0fnu) { + s_type_id = vllm::kFE8M0fnu.id(); + } else { + TORCH_CHECK(false, + "When b_type = float4_e2m1f, b_scale scalar type must be", + "float8_e4m3fn 
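The per-split pointer advance above encodes the element width: one int4 (16 bytes) holds 16 fp8/int8 activations but only 8 half/bfloat16 activations, hence lda / 16 versus lda / 8, while a_s advances one float per row. A quick arithmetic check with made-up sizes:

// Sanity check for the per-split advance of the int4-typed A pointer in
// marlin_mm: stepping by rows must land on the same byte offset as
// rows * lda * sizeof(element), for both 16-bit and 8-bit activations.
#include <cassert>
#include <cstdio>

int main() {
  const int lda = 4096;          // elements per row (made-up)
  const int prob_m_split = 48;   // rows consumed by this split (made-up)

  // 16-bit activation: 8 elements per int4.
  long long int4_step_16 = (long long)prob_m_split * (lda / 8);
  assert(int4_step_16 * 16 == (long long)prob_m_split * lda * 2);

  // 8-bit activation: 16 elements per int4.
  long long int4_step_8 = (long long)prob_m_split * (lda / 16);
  assert(int4_step_8 * 16 == (long long)prob_m_split * lda * 1);

  printf("A advances by %lld int4s (fp16/bf16) vs %lld int4s (fp8/int8)\n",
         int4_step_16, int4_step_8);
  return 0;
}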
(for NVFP4) or float8_e8m0fnu (for MXFP4)."); + } + } + + vllm::ScalarType a_type = vllm::ScalarType::from_id(a_type_id); + vllm::ScalarType b_type = vllm::ScalarType::from_id(b_type_id); + vllm::ScalarType c_type = vllm::ScalarType::from_id(c_type_id); + vllm::ScalarType s_type = vllm::ScalarType::from_id(s_type_id); + + int pack_factor = 32 / b_type.size_bits(); // Verify A TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0), @@ -721,6 +626,21 @@ torch::Tensor gptq_marlin_gemm( TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); + torch::Tensor a_scales; + auto options = torch::TensorOptions().dtype(c_dtype).device(a.device()); + auto options_fp32 = + torch::TensorOptions().dtype(at::kFloat).device(a.device()); + + if (a_scales_or_none.has_value()) { + a_scales = a_scales_or_none.value(); + TORCH_CHECK(a_type.size_bits() == 8, + "a_scales can only be used for 8bit activation."); + } else { + a_scales = torch::empty({0}, options_fp32); + TORCH_CHECK(a_type.size_bits() != 8, + "the a_scales parameter must be passed for 8bit activation."); + } + // thread_k: `k` size of a thread_tile in `weights` (can usually be left as // auto -1) int thread_k = -1; @@ -733,7 +653,6 @@ torch::Tensor gptq_marlin_gemm( // Alloc buffers const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); - auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); torch::Tensor c; if (c_or_none.has_value()) { c = c_or_none.value(); @@ -750,8 +669,6 @@ torch::Tensor gptq_marlin_gemm( // Alloc C tmp buffer that is going to be used for the global reduce torch::Tensor c_tmp; - auto options_fp32 = - torch::TensorOptions().dtype(at::kFloat).device(a.device()); if (use_fp32_reduce) { int max_m_block_size = (size_m + 16 - 1) / 16 * 16; max_m_block_size = min(max_m_block_size, 64); @@ -821,11 +738,11 @@ torch::Tensor gptq_marlin_gemm( torch::Tensor global_scale; if (global_scale_or_none.has_value()) { global_scale = global_scale_or_none.value(); - TORCH_CHECK(b_q_type == vllm::kFE2M1f && group_size == 16, + TORCH_CHECK(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn, "global_scale can only be used for nvfp4 format."); } else { global_scale = torch::empty({0}, options); - TORCH_CHECK(!(b_q_type == vllm::kFE2M1f && group_size == 16), + TORCH_CHECK(!(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn), "the global_scale parameter must be passed for nvfp4 format."); } @@ -852,15 +769,15 @@ torch::Tensor gptq_marlin_gemm( bool has_zp = b_zeros.size(-1) > 0; if (has_zp) { TORCH_CHECK( - b_q_type == vllm::kU4 || b_q_type == vllm::kU8, - "b_q_type must be u4 or u8 when has_zp = True. Got = ", b_q_type.str()); + b_type == vllm::kU4 || b_type == vllm::kU8, + "b_type must be u4 or u8 when has_zp = True. Got = ", b_type.str()); } else { - TORCH_CHECK(b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128 || - b_q_type == vllm::kFE4M3fn || b_q_type == vllm::kFE2M1f, - "b_q_type must be uint4b8, uint8b128, float8_e4m3fn or " - "float4_e2m1f when " - "has_zp = False. Got = ", - b_q_type.str()); + TORCH_CHECK(b_type == vllm::kU4B8 || b_type == vllm::kU8B128 || + b_type == vllm::kS4 || b_type == vllm::kS8 || + b_type == vllm::kFE4M3fn || b_type == vllm::kFE2M1f, + "b_type must be uint4b8, uint8b128, int4, int8, " + "float8_e4m3fn or float4_e2m1f when has_zp = False. 
Got = ", + b_type.str()); } if (has_zp && is_zp_float) { @@ -902,59 +819,27 @@ torch::Tensor gptq_marlin_gemm( " is below min_workspace_size = ", min_workspace_size); int dev = a.get_device(); - if (a.scalar_type() == at::ScalarType::Half) { - void* scales_ptr; - if (b_q_type == vllm::kFE2M1f) { - if (group_size == 16) - scales_ptr = b_scales.data_ptr(); - else if (group_size == 32) - scales_ptr = b_scales.data_ptr(); - else - TORCH_CHECK(false, - "float4_e2m1f only supports group_size == 16 (NVFP4) ", - "and group_size == 32 (MXFP4)"); - } else { - scales_ptr = b_scales.data_ptr(); - } - marlin::marlin_mm( - a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), - c_tmp.data_ptr(), b_bias.data_ptr(), scales_ptr, - global_scale.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), - perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, size_k, - a.stride(0), workspace.data_ptr(), b_q_type, has_bias, has_act_order, - is_k_full, has_zp, num_groups, group_size, dev, - at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, - use_atomic_add, use_fp32_reduce, is_zp_float); - } else if (a.scalar_type() == at::ScalarType::BFloat16) { - void* scales_ptr; - if (b_q_type == vllm::kFE2M1f) { - if (group_size == 16) - scales_ptr = b_scales.data_ptr(); - else if (group_size == 32) - scales_ptr = b_scales.data_ptr(); - else - TORCH_CHECK(false, - "float4_e2m1f only supports group_size == 16 (NVFP4) ", - "and group_size == 32 (MXFP4)"); - } else { - scales_ptr = b_scales.data_ptr(); - } - - marlin::marlin_mm( - a.data_ptr(), b_q_weight.data_ptr(), - c.data_ptr(), c_tmp.data_ptr(), - b_bias.data_ptr(), scales_ptr, - global_scale.data_ptr(), b_zeros.data_ptr(), - g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), - size_m, size_n, size_k, a.stride(0), workspace.data_ptr(), b_q_type, - has_bias, has_act_order, is_k_full, has_zp, num_groups, group_size, dev, - at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, - use_atomic_add, use_fp32_reduce, is_zp_float); - } else { - TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16"); + TORCH_CHECK(a_scales.scalar_type() == at::ScalarType::Float, + "scalar type of a_scales must be float"); + TORCH_CHECK(global_scale.scalar_type() == c.scalar_type(), + "scalar type of global_scale must be the same with c"); + if (a_type.size_bits() == 16) { + TORCH_CHECK( + a.scalar_type() == c.scalar_type(), + "scalar type of a must be the same with c for 16 bit activation"); } + marlin::marlin_mm( + a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), c_tmp.data_ptr(), + b_bias.data_ptr(), a_scales.data_ptr(), b_scales.data_ptr(), + global_scale.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), + perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, size_k, a.stride(0), + workspace.data_ptr(), a_type, b_type, c_type, s_type, has_bias, + has_act_order, is_k_full, has_zp, num_groups, group_size, dev, + at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, + use_atomic_add, use_fp32_reduce, is_zp_float); + return c; } diff --git a/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu b/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu index ad80d51ece94e..796e6c5359da1 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu @@ -4,15 +4,18 @@ namespace marlin { -template +template __global__ void gptq_marlin_repack_kernel( uint32_t const* __restrict__ b_q_weight_ptr, uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr, int size_k, int size_n) { constexpr int pack_factor = 32 
/ num_bits; - int k_tiles = size_k / tile_k_size; - int n_tiles = size_n / tile_n_size; + constexpr int target_tile_n_size = tile_n_size / (is_a_8bit ? 2 : 1); + constexpr int target_tile_k_size = tile_k_size * (is_a_8bit ? 2 : 1); + int k_tiles = size_k / target_tile_k_size; + int n_tiles = size_n / target_tile_n_size; int block_k_tiles = div_ceil(k_tiles, gridDim.x); auto start_k_tile = blockIdx.x * block_k_tiles; @@ -34,7 +37,7 @@ __global__ void gptq_marlin_repack_kernel( extern __shared__ int4 sh[]; - constexpr int perm_size = tile_k_size / 4; + constexpr int perm_size = target_tile_k_size / 4; int4* sh_perm_ptr = sh; int4* sh_pipe_ptr = sh_perm_ptr; @@ -42,14 +45,14 @@ __global__ void gptq_marlin_repack_kernel( sh_pipe_ptr += perm_size; } - constexpr int tile_ints = tile_k_size / pack_factor; + constexpr int tile_ints = target_tile_k_size / pack_factor; - constexpr int stage_n_threads = tile_n_size / 4; - constexpr int stage_k_threads = has_perm ? tile_k_size : tile_ints; + constexpr int stage_n_threads = target_tile_n_size / 4; + constexpr int stage_k_threads = has_perm ? target_tile_k_size : tile_ints; constexpr int stage_size = stage_k_threads * stage_n_threads; auto load_perm_to_shared = [&](int k_tile_id) { - int first_k_int4 = (k_tile_id * tile_k_size) / 4; + int first_k_int4 = (k_tile_id * target_tile_k_size) / 4; int4 const* perm_int4_ptr = reinterpret_cast(perm_ptr); @@ -65,7 +68,7 @@ __global__ void gptq_marlin_repack_kernel( return; } - int first_n = n_tile_id * tile_n_size; + int first_n = n_tile_id * target_tile_n_size; int4* sh_ptr = sh_pipe_ptr + stage_size * pipe; @@ -91,7 +94,7 @@ __global__ void gptq_marlin_repack_kernel( auto k_id = threadIdx.x / stage_n_threads; auto n_id = threadIdx.x % stage_n_threads; - int first_k = k_tile_id * tile_k_size; + int first_k = k_tile_id * target_tile_k_size; int first_k_packed = first_k / pack_factor; cp_async4(&sh_ptr[k_id * stage_n_threads + n_id], @@ -117,13 +120,13 @@ __global__ void gptq_marlin_repack_kernel( } int tc_col = th_id / 4; - int tc_row = (th_id % 4) * 2; + int tc_row = (th_id % 4) * (is_a_8bit ? 4 : 2); constexpr int tc_offsets[4] = {0, 1, 8, 9}; - int cur_n = warp_id * 16 + tc_col; + int cur_n = (warp_id / (is_a_8bit ? 2 : 1)) * 16 + tc_col; - constexpr int sh_stride = 64; + constexpr int sh_stride = target_tile_n_size; constexpr uint32_t mask = (1 << num_bits) - 1; int4* sh_stage_ptr = sh_pipe_ptr + stage_size * pipe; @@ -134,6 +137,7 @@ __global__ void gptq_marlin_repack_kernel( uint32_t vals[8]; if constexpr (has_perm) { + static_assert(!is_a_8bit); for (int i = 0; i < 4; i++) { int k_idx = tc_row + tc_offsets[i]; @@ -156,28 +160,49 @@ __global__ void gptq_marlin_repack_kernel( #pragma unroll for (int i = 0; i < tile_ints; i++) { - b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i]; - b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i]; + if constexpr (is_a_8bit) { + b1_vals[i] = + sh_stage_int_ptr[cur_n + sh_stride * i + (warp_id % 2) * 8]; + } else { + b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i]; + b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i]; + } } #pragma unroll for (int i = 0; i < 4; i++) { - int cur_elem = tc_row + tc_offsets[i]; + int cur_elem = tc_row + (is_a_8bit ? 
i : tc_offsets[i]); int cur_int = cur_elem / pack_factor; int cur_pos = cur_elem % pack_factor; vals[i] = (b1_vals[cur_int] >> (cur_pos * num_bits)) & mask; - vals[4 + i] = (b2_vals[cur_int] >> (cur_pos * num_bits)) & mask; + if constexpr (is_a_8bit) + vals[4 + i] = + (b1_vals[cur_int + tile_ints / 2] >> (cur_pos * num_bits)) & mask; + else + vals[4 + i] = (b2_vals[cur_int] >> (cur_pos * num_bits)) & mask; } } - constexpr int tile_size = tile_k_size * tile_n_size / pack_factor; + constexpr int tile_size = + target_tile_k_size * target_tile_n_size / pack_factor; int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size; // Result of: // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h - if constexpr (num_bits == 4) { - constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + if constexpr (!is_a_8bit && num_bits == 4) { + int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + + uint32_t res = 0; +#pragma unroll + for (int i = 0; i < 8; i++) { + res |= vals[pack_idx[i]] << (i * 4); + } + + out_ptr[out_offset + th_id * 4 + warp_id] = res; + + } else if constexpr (is_a_8bit && num_bits == 4) { + int pack_idx[8] = {0, 4, 1, 5, 2, 6, 3, 7}; uint32_t res = 0; #pragma unroll @@ -194,8 +219,9 @@ __global__ void gptq_marlin_repack_kernel( uint32_t res2 = 0; #pragma unroll for (int i = 0; i < 4; i++) { - res1 |= vals[pack_idx[i]] << (i * 8); - res2 |= vals[4 + pack_idx[i]] << (i * 8); + const int ii = is_a_8bit ? i : pack_idx[i]; + res1 |= vals[ii] << (i * 8); + res2 |= vals[4 + ii] << (i * 8); } out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1; @@ -236,21 +262,22 @@ __global__ void gptq_marlin_repack_kernel( } // namespace marlin -#define CALL_IF(NUM_BITS, HAS_PERM) \ - else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \ +#define CALL_IF(NUM_BITS, HAS_PERM, IS_A_8BIT) \ + else if (num_bits == NUM_BITS && has_perm == HAS_PERM && \ + is_a_8bit == IS_A_8BIT) { \ cudaFuncSetAttribute( \ marlin::gptq_marlin_repack_kernel, \ + HAS_PERM, IS_A_8BIT>, \ cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ marlin::gptq_marlin_repack_kernel \ + HAS_PERM, IS_A_8BIT> \ <<>>( \ b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \ } torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_k, int64_t size_n, - int64_t num_bits) { + int64_t num_bits, bool is_a_8bit) { // Verify compatibility with marlin tile of 16x64 TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k, " is not divisible by tile_k_size = ", marlin::tile_k_size); @@ -309,13 +336,17 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, if (false) { } - CALL_IF(4, false) - CALL_IF(4, true) - CALL_IF(8, false) - CALL_IF(8, true) + CALL_IF(4, false, false) + CALL_IF(4, true, false) + CALL_IF(8, false, false) + CALL_IF(8, true, false) + + CALL_IF(4, false, true) + CALL_IF(8, false, true) + else { TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits, - ", has_perm = ", has_perm); + ", has_perm = ", has_perm, ", is_a_8bit = ", is_a_8bit); } return out; diff --git a/csrc/quantization/gptq_marlin/kernel.h b/csrc/quantization/gptq_marlin/kernel.h index bb454f6aff22a..b3b79c8aec452 100644 --- a/csrc/quantization/gptq_marlin/kernel.h +++ b/csrc/quantization/gptq_marlin/kernel.h @@ -11,17 +11,19 @@ const int4 *__restrict__ A, const int4 *__restrict__ B, \ int4 *__restrict__ C, int4 *__restrict__ C_tmp, \ const int4 *__restrict__ 
b_bias_ptr, \ + const float *__restrict__ a_scales_ptr, \ const int4 *__restrict__ scales_ptr, \ - const uint16_t *__restrict__ scale2_ptr, \ + const uint16_t *__restrict__ global_scale_ptr, \ const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \ int num_groups, int prob_m, int prob_n, int prob_k, int lda, int *locks, \ bool has_bias, bool use_atomic_add, bool use_fp32_reduce, \ int max_shared_mem namespace MARLIN_NAMESPACE_NAME { -template (__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.ca.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES)); +} + +__device__ inline void cp_async2_ca_pred(void* smem_ptr, const void* glob_ptr, + bool pred = true) { + const int BYTES = 8; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.ca.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES)); +} + +__device__ inline void cp_async4_ca_pred(void* smem_ptr, const void* glob_ptr, + bool pred = true) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.ca.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES)); +} + __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool pred = true) { const int BYTES = 16; diff --git a/csrc/quantization/gptq_marlin/marlin_dtypes.cuh b/csrc/quantization/gptq_marlin/marlin_dtypes.cuh index cc16054814342..a4807a6887f81 100644 --- a/csrc/quantization/gptq_marlin/marlin_dtypes.cuh +++ b/csrc/quantization/gptq_marlin/marlin_dtypes.cuh @@ -2,8 +2,10 @@ #ifndef _data_types_cuh #define _data_types_cuh #include "marlin.cuh" +#include "core/scalar_type.hpp" #include #include +#include #ifndef MARLIN_NAMESPACE_NAME #define MARLIN_NAMESPACE_NAME marlin @@ -11,14 +13,16 @@ namespace MARLIN_NAMESPACE_NAME { -template -class ScalarType {}; +template +class MarlinScalarType {}; template <> -class ScalarType { +class MarlinScalarType { public: using scalar_t = half; using scalar_t2 = half2; + using scalar_t4 = half2; + using scalar_32bit_t = half2; // Matrix fragments for tensor core instructions; their precise layout is // documented here: @@ -27,6 +31,7 @@ class ScalarType { using FragB = Vec; using FragC = Vec; using FragS = Vec; + using FragS0 = Vec<__nv_fp8x2_e4m3, 1>; using FragZP = Vec; static __device__ float inline num2float(const half x) { @@ -44,18 +49,25 @@ class ScalarType { static __host__ __device__ half inline float2num(const float x) { return __float2half(x); } + + static __host__ __device__ float2 inline num22float2(const half2 x) { + return __half22float2(x); + } }; template <> -class ScalarType { +class MarlinScalarType { public: using scalar_t = nv_bfloat16; using scalar_t2 = nv_bfloat162; + using scalar_t4 = nv_bfloat162; + using scalar_32bit_t = nv_bfloat162; using FragA = Vec; using FragB = Vec; using FragC = Vec; using FragS = Vec; + using FragS0 = Vec<__nv_fp8x2_e4m3, 1>; using FragZP = Vec; #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 @@ -75,9 +87,63 @@ class ScalarType { static __host__ __device__ nv_bfloat16 inline float2num(const float x) { return __float2bfloat16(x); } + + static __host__ __device__ float2 inline num22float2(const nv_bfloat162 x) { + return 
__bfloat1622float2(x);
+  }

 #endif
 };
+template <>
+class MarlinScalarType {
+ public:
+  using scalar_t = __nv_fp8_e4m3;
+  using scalar_t2 = __nv_fp8x2_e4m3;
+  using scalar_t4 = __nv_fp8x4_e4m3;
+  using scalar_32bit_t = __nv_fp8x4_e4m3;
+
+  using FragA = Vec<__nv_fp8x4_e4m3, 4>;
+  using FragB = Vec<__nv_fp8x4_e4m3, 2>;
+  using FragC = Vec;
+  using FragZP = Vec<__nv_fp8x2_e4m3, 4>;
+
+  static __host__ __device__
+      float2 inline num22float2(const __nv_fp8x2_e4m3 x) {
+    return (float2)x;
+  }
+};
+
+template <>
+class MarlinScalarType {
+ public:
+  using scalar_t = int8_t;
+  using scalar_t2 = int16_t;
+  using scalar_t4 = int32_t;
+  using scalar_32bit_t = int32_t;
+
+  using FragA = Vec;
+  using FragB = Vec;
+  using FragC = Vec;
+  using FragZP = Vec;
+};
+
+template
+class MarlinScalarType2 {};
+
+template <>
+class MarlinScalarType2 : public MarlinScalarType {};
+
+template <>
+class MarlinScalarType2
+    : public MarlinScalarType {};
+
+template <>
+class MarlinScalarType2<__nv_fp8_e4m3>
+    : public MarlinScalarType {};
+
+template <>
+class MarlinScalarType2 : public MarlinScalarType {};
+
 }  // namespace MARLIN_NAMESPACE_NAME

 #endif
diff --git a/csrc/quantization/gptq_marlin/marlin_int4_fp8_preprocess.cu b/csrc/quantization/gptq_marlin/marlin_int4_fp8_preprocess.cu
new file mode 100644
index 0000000000000..7d4c97fb57ed4
--- /dev/null
+++ b/csrc/quantization/gptq_marlin/marlin_int4_fp8_preprocess.cu
@@ -0,0 +1,106 @@
+
+
+#include "marlin.cuh"
+
+#include "core/registration.h"
+
+// for only non-zp format (like gptq)
+__global__ void marlin_int4_fp8_preprocess_kernel_without_zp(
+    // qweight: (size_k * size_n // 8,)
+    const int32_t* __restrict__ qweight,
+    // output: same shape with qweight
+    int32_t* __restrict__ output) {
+  int32_t val = qweight[blockIdx.x * 32 + threadIdx.x];
+  int32_t new_val = 0;
+
+#pragma unroll
+  for (int32_t i = 0; i < 8; i++) {
+    int32_t single_val = val & 0xF;
+    single_val = single_val >= 8 ? single_val - 8 : 15 - single_val;
+    new_val |= single_val << (i * 4);
+    val >>= 4;
+  }
+
+  output[blockIdx.x * 32 + threadIdx.x] = new_val;
+}
+
+// for awq format only (with zp and with awq weight layout)
+__global__ void marlin_int4_fp8_preprocess_kernel_awq(
+    // AWQ qweight: (size_k, size_n // 8)
+    const int32_t* __restrict__ qweight,
+    // output: same shape with qweight
+    int32_t* __restrict__ output,
+    // AWQ zeros: (size_k // group_size, size_n // 8)
+    const int32_t* __restrict__ qzeros, int32_t size_n, int32_t size_k,
+    int32_t group_size) {
+  int32_t val =
+      qweight[(blockIdx.x * 32 + threadIdx.x) * size_n / 8 + blockIdx.y];
+  int32_t zero =
+      qzeros[(blockIdx.x * 32 + threadIdx.x) / group_size * size_n / 8 +
+             blockIdx.y];
+  int32_t new_val = 0;
+
+#pragma unroll
+  for (int32_t i = 0; i < 8; i++) {
+    int32_t single_val = val & 0xF;
+    int32_t single_zero = zero & 0xF;
+
+    single_val =
+        single_val >= single_zero ? single_val - single_zero : 15 - single_val;
+    new_val |= single_val << (i * 4);
+    val >>= 4;
+    zero >>= 4;
+  }
+
+  output[(blockIdx.x * 32 + threadIdx.x) * size_n / 8 + blockIdx.y] = new_val;
+}
+
+torch::Tensor marlin_int4_fp8_preprocess(
+    torch::Tensor& qweight, std::optional qzeros_or_none,
+    bool inplace) {
+  TORCH_CHECK(qweight.device().is_cuda(), "qweight is not on GPU");
+  TORCH_CHECK(qweight.scalar_type() == at::ScalarType::Int,
+              "qweight.dtype != torch.int32");
+
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(qweight));
+
+  torch::Tensor output = inplace ?
qweight : torch::empty_like(qweight); + + if (!qzeros_or_none.has_value()) { + TORCH_CHECK(qweight.numel() * 8 % 256 == 0, + "qweight.numel() * 8 % 256 != 0"); + + int blocks = qweight.numel() * 8 / 256; + marlin_int4_fp8_preprocess_kernel_without_zp<<>>( + (const int32_t*)qweight.data_ptr(), (int32_t*)output.data_ptr()); + } else { + int32_t size_k = qweight.size(0); + int32_t size_n = qweight.size(1) * 8; + torch::Tensor qzeros = qzeros_or_none.value(); + + TORCH_CHECK(size_k % 32 == 0, "size_k % 32 != 0"); + TORCH_CHECK(qzeros.device().is_cuda(), "qzeros is not on GPU"); + TORCH_CHECK(qzeros.scalar_type() == at::ScalarType::Int, + "qweight.dtype != torch.int32"); + TORCH_CHECK(device_of(qweight) == device_of(qzeros), + "qzeros is not on the same device with qweight"); + + int32_t group_size = qweight.size(0) / qzeros.size(0); + TORCH_CHECK(qweight.size(1) == qzeros.size(1), + "qweight.size(1) != qzeros.size(1)"); + TORCH_CHECK(qweight.size(0) % qzeros.size(0) == 0, + "qweight.size(0) % qzeros.size(0) != 0"); + TORCH_CHECK(group_size % 8 == 0, "group_size % 8 != 0"); + + dim3 blocks(size_k / 32, size_n / 8); + marlin_int4_fp8_preprocess_kernel_awq<<>>( + (const int32_t*)qweight.data_ptr(), (int32_t*)output.data_ptr(), + (const int32_t*)qzeros.data_ptr(), size_n, size_k, group_size); + } + + return output; +} + +TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { + m.impl("marlin_int4_fp8_preprocess", &marlin_int4_fp8_preprocess); +} diff --git a/csrc/quantization/gptq_marlin/marlin_template.h b/csrc/quantization/gptq_marlin/marlin_template.h index bfb0a3668f527..22bb71e482ce8 100644 --- a/csrc/quantization/gptq_marlin/marlin_template.h +++ b/csrc/quantization/gptq_marlin/marlin_template.h @@ -38,7 +38,7 @@ namespace MARLIN_NAMESPACE_NAME { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 template -__device__ inline void mma(const typename ScalarType::FragA& a_frag, - const typename ScalarType::FragB& frag_b, - typename ScalarType::FragC& frag_c) { +template +__device__ inline void mma( + const typename MarlinScalarType::FragA& a_frag, + const typename MarlinScalarType::FragB& frag_b, + typename MarlinScalarType::FragC& frag_c, int idx = 0) { const uint32_t* a = reinterpret_cast(&a_frag); const uint32_t* b = reinterpret_cast(&frag_b); - float* c = reinterpret_cast(&frag_c); - if constexpr (std::is_same::value) { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); - } else if constexpr (std::is_same::value) { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); - } else { - STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); + using scalar_t = typename MarlinScalarType::scalar_t; + if constexpr (k_size == 16) { + if constexpr (std::is_same::value) { + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if 
constexpr (std::is_same::value) { + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 " + "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "f"(c[0]), + "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + int32_t* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite " + "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" + : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) + : "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "r"(c[0]), + "r"(c[1]), "r"(c[2]), "r"(c[3])); + } + } else if (k_size == 32) { + if constexpr (std::is_same::value) { + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + int32_t* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3])); + } } } -template +template __device__ inline void mma_trans( - const typename ScalarType::FragA& a_frag, - const typename ScalarType::FragB& frag_b, - const typename ScalarType::FragB& frag_b2, - typename ScalarType::FragC& frag_c) { + const typename MarlinScalarType::FragA& a_frag, + const typename MarlinScalarType::FragB& frag_b, + const typename MarlinScalarType::FragB& frag_b2, + typename MarlinScalarType::FragC& frag_c) { const uint32_t* a = reinterpret_cast(&a_frag); const uint32_t* b = reinterpret_cast(&frag_b); const uint32_t* b2 = reinterpret_cast(&frag_b2); float* c = reinterpret_cast(&frag_c); - if constexpr (std::is_same::value) { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); - } else if constexpr (std::is_same::value) { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + using scalar_t = typename MarlinScalarType::scalar_t; + if constexpr (k_size == 16) { + if constexpr (std::is_same::value) { + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : 
"=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 " + "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[0]), "r"(b2[0]), "r"(a[0]), "f"(c[0]), "f"(c[1]), "f"(c[2]), + "f"(c[3])); + } else if constexpr (std::is_same::value) { + int32_t* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite " + "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" + : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) + : "r"(b[0]), "r"(b2[0]), "r"(a[0]), "r"(c[0]), "r"(c[1]), "r"(c[2]), + "r"(c[3])); + } } else { - STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); + if constexpr (std::is_same::value) { + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + int32_t* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) + : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), + "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3])); + } } } // Instruction for loading a full 16x16 matrix fragment of operand A from shared // memory, directly in tensor core layout. -template -__device__ inline void ldsm(typename ScalarType::FragA& frag_a, +template +__device__ inline void ldsm(typename MarlinScalarType::FragA& frag_a, const void* smem_ptr) { uint32_t* a = reinterpret_cast(&frag_a); uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); @@ -159,47 +233,54 @@ __device__ inline void ldsm(typename ScalarType::FragA& frag_a, // Multiply dequantized values by the corresponding quantization scale; used // only for grouped quantization. 
-template -__device__ inline void scale(typename ScalarType::FragB& frag_b, - typename ScalarType::FragS& frag_s, +template +__device__ inline void scale(typename MarlinScalarType::FragB& frag_b, + typename MarlinScalarType::FragS& frag_s, int i) { - using scalar_t2 = typename ScalarType::scalar_t2; - scalar_t2 s = - ScalarType::num2num2(reinterpret_cast(&frag_s)[i]); + using scalar_t = typename MarlinScalarType::scalar_t; + using scalar_t2 = typename MarlinScalarType::scalar_t2; + scalar_t2 s = MarlinScalarType::num2num2( + reinterpret_cast(&frag_s)[i]); frag_b[0] = __hmul2(frag_b[0], s); frag_b[1] = __hmul2(frag_b[1], s); } -template +template __device__ inline void scale_and_sub( - typename ScalarType::FragB& frag_b, scalar_t s, scalar_t zp) { - using scalar_t2 = typename ScalarType::scalar_t2; - scalar_t2 s2 = ScalarType::num2num2(s); - scalar_t2 zp2 = ScalarType::num2num2(zp); + typename MarlinScalarType::FragB& frag_b, + typename MarlinScalarType::scalar_t s, + typename MarlinScalarType::scalar_t zp) { + using scalar_t = typename MarlinScalarType::scalar_t; + using scalar_t2 = typename MarlinScalarType::scalar_t2; + scalar_t2 s2 = MarlinScalarType::num2num2(s); + scalar_t2 zp2 = MarlinScalarType::num2num2(zp); frag_b[0] = __hfma2(frag_b[0], s2, __hneg2(zp2)); frag_b[1] = __hfma2(frag_b[1], s2, __hneg2(zp2)); } -template -__device__ inline void sub_zp(typename ScalarType::FragB& frag_b, - typename ScalarType::scalar_t2& frag_zp, - int i) { - using scalar_t2 = typename ScalarType::scalar_t2; - scalar_t2 zp = - ScalarType::num2num2(reinterpret_cast(&frag_zp)[i]); +template +__device__ inline void sub_zp( + typename MarlinScalarType::FragB& frag_b, + typename MarlinScalarType::scalar_t2& frag_zp, int i) { + using scalar_t = typename MarlinScalarType::scalar_t; + using scalar_t2 = typename MarlinScalarType::scalar_t2; + scalar_t2 zp = MarlinScalarType::num2num2( + reinterpret_cast(&frag_zp)[i]); frag_b[0] = __hsub2(frag_b[0], zp); frag_b[1] = __hsub2(frag_b[1], zp); } // Same as above, but for act_order (each K is multiplied individually) -template -__device__ inline void scale4(typename ScalarType::FragB& frag_b, - typename ScalarType::FragS& frag_s_1, - typename ScalarType::FragS& frag_s_2, - typename ScalarType::FragS& frag_s_3, - typename ScalarType::FragS& frag_s_4, - int i) { - using scalar_t2 = typename ScalarType::scalar_t2; +template +__device__ inline void scale4( + typename MarlinScalarType::FragB& frag_b, + typename MarlinScalarType::FragS& frag_s_1, + typename MarlinScalarType::FragS& frag_s_2, + typename MarlinScalarType::FragS& frag_s_3, + typename MarlinScalarType::FragS& frag_s_4, int i) { + using scalar_t = typename MarlinScalarType::scalar_t; + using scalar_t2 = typename MarlinScalarType::scalar_t2; + scalar_t2 s_val_1_2; s_val_1_2.x = reinterpret_cast(&frag_s_1)[i]; s_val_1_2.y = reinterpret_cast(&frag_s_2)[i]; @@ -213,12 +294,13 @@ __device__ inline void scale4(typename ScalarType::FragB& frag_b, } // Given 2 floats multiply by 2 scales (halves) -template -__device__ inline void scale_float(float* c, - typename ScalarType::FragS& s) { +template +__device__ inline void scale_float( + float* c, typename MarlinScalarType::FragS& s) { + using scalar_t = typename MarlinScalarType::scalar_t; scalar_t* s_ptr = reinterpret_cast(&s); - c[0] = __fmul_rn(c[0], ScalarType::num2float(s_ptr[0])); - c[1] = __fmul_rn(c[1], ScalarType::num2float(s_ptr[1])); + c[0] = __fmul_rn(c[0], MarlinScalarType::num2float(s_ptr[0])); + c[1] = __fmul_rn(c[1], MarlinScalarType::num2float(s_ptr[1])); 
} // Wait until barrier reaches `count`, then lock for current threadblock. @@ -270,9 +352,10 @@ __device__ inline void wait_negative_and_add(int* lock) { __syncthreads(); } -template __global__ void Marlin( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce) + const int4* __restrict__ A0, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C0, // fp16 output buffer of shape mxn + int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce) const int4* __restrict__ b_bias_ptr, - const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape - // (k/groupsize)xn - const uint16_t* __restrict__ scale2_ptr, // fp16 global scale (for nvfp4 - // only) - const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape - // (k/groupsize)x(n/pack_factor) - const int* __restrict__ g_idx, // int32 group indices of shape k + // float scales of input matrix, only used when is_a_8bit == true. + // shape (m,) + const float* __restrict__ a_scales_ptr, + // fp16 quantization scales. shape (k/groupsize, n) + const int4* __restrict__ scales_ptr, + // fp16 global scale (for nvfp4// only) + const uint16_t* __restrict__ global_scale_ptr, + // 4bit packed zero-points of shape + // (k/groupsize, n/pack_factor) + const int4* __restrict__ zp_ptr, + // int32 group indices of shape k + const int* __restrict__ g_idx, int num_groups, // number of scale groups per output channel int prob_m, // batch dimension m int prob_n, // output dimension n @@ -321,17 +409,35 @@ __global__ void Marlin( // ensures good utilization of all SMs for many kinds of shape and GPU // configurations, while requiring as few slow global cross-threadblock // reductions as possible. - using Dtype = ScalarType; - using scalar_t2 = typename ScalarType::scalar_t2; - using FragA = typename ScalarType::FragA; - using FragB = typename ScalarType::FragB; - using FragC = typename ScalarType::FragC; - using FragS = typename ScalarType::FragS; - using FragZP = typename ScalarType::FragZP; - static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id); + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 890 + // FP8 computation is only supported for Ada Lovelace or newer architectures. 
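// Illustrative sketch (not part of the patch): the guard that follows keeps
// FP8 instantiations of the Marlin kernel from emitting tensor-core code on
// targets older than SM89, where the e4m3 QMMA instructions do not exist.
// The same pattern shown in isolation; the kernel and parameter names here
// are hypothetical.
template <bool kIsFp8Input>
__global__ void arch_guard_demo(const float* in, float* out, int n) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 890
  // On pre-Ada devices the FP8 variant returns immediately, leaving the rest
  // of the body as dead code, so the template can still be instantiated for
  // every architecture the build targets.
  if constexpr (kIsFp8Input) return;
#endif
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = in[i];
}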
+ if constexpr (a_type_id == vllm::kFE4M3fn.id()) return; + #endif + + using Adtype = MarlinScalarType; + using Cdtype = MarlinScalarType; + const int4* A = A0; + int4* C = C0; + + using scalar_t = typename MarlinScalarType::scalar_t; + using scalar_t2 = typename MarlinScalarType::scalar_t2; + using scalar_32bit_t = typename MarlinScalarType::scalar_32bit_t; + + using c_scalar_t = typename MarlinScalarType::scalar_t; + using c_scalar_t2 = typename MarlinScalarType::scalar_t2; + + using FragA = typename MarlinScalarType::FragA; + using FragB = typename MarlinScalarType::FragB; + using FragC = typename MarlinScalarType::FragC; + using FragS = typename MarlinScalarType::FragS; + using FragZP = typename MarlinScalarType::FragZP; + + static constexpr auto a_type = vllm::ScalarType::from_id(a_type_id); + static constexpr auto b_type = vllm::ScalarType::from_id(b_type_id); + static constexpr auto c_type = vllm::ScalarType::from_id(c_type_id); static constexpr auto s_type = vllm::ScalarType::from_id(s_type_id); - if constexpr (w_type == vllm::kFE2M1f) { + if constexpr (b_type == vllm::kFE2M1f) { static_assert(s_type == vllm::kFE4M3fn && group_blocks == 1 || s_type == vllm::kFE8M0fnu && group_blocks == 2); } else if constexpr (std::is_same::value) { @@ -340,27 +446,35 @@ __global__ void Marlin( static_assert(s_type == vllm::kFloat16); } - constexpr bool has_zp = w_type == vllm::kU4 || w_type == vllm::kU8; - constexpr bool is_int_type = w_type == vllm::kU4 || w_type == vllm::kU8 || - w_type == vllm::kU4B8 || w_type == vllm::kU8B128; + constexpr bool is_a_8bit = a_type.size_bits() == 8; + if constexpr (!is_a_8bit) { + static_assert(std::is_same::value); + } + constexpr bool has_zp = b_type == vllm::kU4 || b_type == vllm::kU8; + constexpr bool is_int_type = b_type == vllm::kU4 || b_type == vllm::kU8 || + b_type == vllm::kS4 || b_type == vllm::kS8 || + b_type == vllm::kU4B8 || b_type == vllm::kU8B128; // see comments of dequant.h for more details constexpr bool dequant_skip_flop = - w_type == vllm::kFE4M3fn || - w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn || + is_a_8bit || b_type == vllm::kFE4M3fn || + b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn || has_zp && !is_zp_float && !std::is_same::value || - has_zp && !is_zp_float && !(w_type == vllm::kU8); + has_zp && !is_zp_float && !(b_type == vllm::kU8); - scalar_t2 global_scale; - if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) { - // NVFP4 format requires global scale - uint16_t val = scale2_ptr[0]; - global_scale = Dtype::num2num2(*reinterpret_cast(&val)); + c_scalar_t2 global_scale; + + if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) { + uint16_t val = global_scale_ptr[0]; + global_scale = Cdtype::num2num2(*reinterpret_cast(&val)); } constexpr bool has_act_order = group_blocks == 0; constexpr int m_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks); - constexpr int pack_factor = 32 / w_type.size_bits(); + extern __shared__ int4 sh[]; + float* sh_a_s = reinterpret_cast(sh); + int4* sh_new = sh + (is_a_8bit ? 
(4 * thread_m_blocks) : 0); + constexpr int pack_factor = 32 / b_type.size_bits(); static_assert(thread_m_blocks == 1 || !m_block_size_8); // For larger GEMMs we run multiple batchsize 64 versions in parallel for a @@ -373,7 +487,19 @@ __global__ void Marlin( int k_tiles = prob_k / 16 / thread_k_blocks; int n_tiles = prob_n / 16 / thread_n_blocks; - int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x); + + int global_mn_tiles = parallel * n_tiles; + int part2_mn_tiles = global_mn_tiles; + int part1_mn_iters = 0; + bool in_part2 = false; + + if (global_mn_tiles > gridDim.x) { + part2_mn_tiles = global_mn_tiles % gridDim.x; + if (part2_mn_tiles * 3 <= gridDim.x) part2_mn_tiles += gridDim.x; + part1_mn_iters = (global_mn_tiles - part2_mn_tiles) / gridDim.x; + } + + int iters = div_ceil(k_tiles * part2_mn_tiles, gridDim.x); if constexpr (!has_act_order && group_blocks != -1) { if (group_blocks >= thread_k_blocks) { @@ -385,28 +511,21 @@ __global__ void Marlin( } } - int slice_row = (iters * blockIdx.x) % k_tiles; - int slice_col_par = (iters * blockIdx.x) / k_tiles; - int slice_col = slice_col_par; - int slice_iters; // number of threadblock tiles in the current slice - int slice_count = - 0; // total number of active threadblocks in the current slice - int slice_idx; // index of threadblock in current slice; numbered bottom to - // top + int slice_row = 0; + int slice_col_par = blockIdx.x; + int slice_col; + int slice_iters = + k_tiles; // number of threadblock tiles in the current slice + // total number of active threadblocks in the current slice + int slice_count = 1; + // index of threadblock in current slice; numbered bottom to top + int slice_idx = 0; int par_id = 0; int locks_off = 0; - // We can easily implement parallel problem execution by just remapping - // indices and advancing global pointers - if (slice_col_par >= n_tiles) { - A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * lda / 8; - C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; - slice_col = slice_col_par % n_tiles; - par_id = slice_col_par / n_tiles; - } - if (parallel * n_tiles >= gridDim.x) { - // when parallel * n_tiles >= sms + if (part2_mn_tiles >= gridDim.x) { + // when part2_mn_tiles >= sms // then there are at most $sms$ conflict tile blocks locks_off = blockIdx.x; } else { @@ -415,10 +534,11 @@ __global__ void Marlin( // Compute all information about the current slice which is required for // synchronization. - auto init_slice = [&](bool first_init = false) { + bool first_init = true; + auto init_part2_slice = [&]() { slice_iters = iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); - if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; + if (slice_iters < 0 || slice_col_par >= part2_mn_tiles) slice_iters = 0; if (slice_iters == 0) return; if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; slice_count = 1; @@ -436,7 +556,7 @@ __global__ void Marlin( if (col_off > 0) slice_idx--; } } - if (parallel * n_tiles >= gridDim.x) { + if (part2_mn_tiles >= gridDim.x) { if (slice_count > 1 && slice_idx == slice_count - 1) { locks_off++; } @@ -466,28 +586,68 @@ __global__ void Marlin( } if (slice_col == n_tiles) { - A += 16 * thread_m_blocks * lda / 8; + A += 16 * thread_m_blocks * lda / (is_a_8bit ? 
16 : 8); C += 16 * thread_m_blocks * prob_n / 8; slice_col = 0; par_id++; } + if (is_a_8bit && (first_init || slice_col == 0)) { + __syncthreads(); + int a_s_gl_rd = par_id * 16 * thread_m_blocks + threadIdx.x; + cp_async1_ca_pred(&sh_a_s[threadIdx.x], &a_scales_ptr[a_s_gl_rd], + threadIdx.x < prob_m); + } }; - init_slice(true); + + auto init_part1_slice = [&]() { + if (part1_mn_iters) { + part1_mn_iters--; + par_id = slice_col_par / n_tiles; + slice_col = slice_col_par % n_tiles; + slice_iters = k_tiles; + A = A0 + 16 * thread_m_blocks / (is_a_8bit ? 16 : 8) * par_id * lda; + C = C0 + 16 * thread_m_blocks / 8 * par_id * prob_n; + if (is_a_8bit) { + __syncthreads(); + int a_s_gl_rd = par_id * 16 * thread_m_blocks + threadIdx.x; + cp_async1_ca_pred(&sh_a_s[threadIdx.x], &a_scales_ptr[a_s_gl_rd], + threadIdx.x < prob_m); + } + } + }; + + auto init_slice = [&]() { + if (!in_part2 && !part1_mn_iters) { + in_part2 = true; + slice_col_par = (iters * blockIdx.x) / k_tiles; + slice_row = (iters * blockIdx.x) % k_tiles; + slice_col = (slice_col_par + global_mn_tiles - part2_mn_tiles) % n_tiles; + par_id = (slice_col_par + global_mn_tiles - part2_mn_tiles) / n_tiles; + A = A0 + 16 * thread_m_blocks / (is_a_8bit ? 16 : 8) * par_id * lda; + C = C0 + 16 * thread_m_blocks / 8 * par_id * prob_n; + } + if (!in_part2) { + init_part1_slice(); + } else { + init_part2_slice(); + first_init = false; + } + }; + + init_slice(); // A sizes/strides // stride of the A matrix in global memory - int a_gl_stride = lda / 8; + int a_gl_stride = lda / (is_a_8bit ? 16 : 8); // stride of an A matrix tile in shared memory - constexpr int a_sh_stride = 16 * thread_k_blocks / 8; + constexpr int a_sh_stride = 16 * thread_k_blocks / (is_a_8bit ? 16 : 8); // delta between subsequent A tiles in global memory - constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; + constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / (is_a_8bit ? 16 : 8); // between subsequent accesses within a tile int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); // between shared memory writes constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); - // between shared memory tile reads - constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); // within a shared memory tile constexpr int a_sh_rd_delta_i = a_sh_stride * 16; // overall size of a tile @@ -496,24 +656,25 @@ __global__ void Marlin( constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta); // B sizes/strides - int b_gl_stride = 16 * prob_n / (pack_factor * 4); - constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; - constexpr int b_thread_vecs = w_type.size_bits() == 4 ? 1 : 2; + int b_gl_stride = 16 * prob_n / (pack_factor * (is_a_8bit ? 2 : 4)); + constexpr int b_sh_stride = + ((thread_n_blocks * 16) * 16 / pack_factor) / (is_a_8bit ? 2 : 4); + constexpr int b_thread_vecs = b_type.size_bits() == 4 ? 1 : 2; constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; - int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; - int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); + int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks / (is_a_8bit ? 2 : 1); constexpr int b_sh_wr_delta = threads * b_thread_vecs; - constexpr int b_sh_rd_delta = threads * b_thread_vecs; - constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; + constexpr int b_sh_stage = + b_sh_stride * thread_k_blocks / (is_a_8bit ? 
2 : 1); constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; // Scale sizes/strides without act_order - int s_gl_stride = prob_n / 8; - constexpr int s_sh_stride = 16 * thread_n_blocks / 8; + int s_gl_stride = prob_n / (b_type == vllm::kFE2M1f ? 16 : 8); + constexpr int s_sh_stride = + 16 * thread_n_blocks / (b_type == vllm::kFE2M1f ? 16 : 8); constexpr int s_tb_groups = !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks - ? thread_k_blocks / group_blocks / (w_type == vllm::kFE2M1f ? 2 : 1) + ? thread_k_blocks / group_blocks : 1; constexpr int s_sh_stage = s_tb_groups * s_sh_stride; int s_gl_rd_delta = s_gl_stride; @@ -527,7 +688,7 @@ __global__ void Marlin( int act_s_col_stride = 1; int act_s_col_warp_stride = act_s_col_stride * 8; - int tb_n_warps = thread_n_blocks / 4; + constexpr int tb_n_warps = thread_n_blocks / (is_a_8bit ? 2 : 4); int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; // Zero-points sizes/strides @@ -550,17 +711,22 @@ __global__ void Marlin( int a_sh_rd = a_sh_stride * ((threadIdx.x % 32) % (16 / (m_block_size_8 ? 2 : 1))) + (threadIdx.x % 32) / (16 / (m_block_size_8 ? 2 : 1)); - a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); + a_sh_rd += 2 * ((threadIdx.x / 32) / tb_n_warps) * b_sh_wr_iters; + + int b_gl_rd; + if (threads <= b_sh_stride) { + b_gl_rd = threadIdx.x; + } else { + b_gl_rd = + b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride); + } - int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + - (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; b_gl_rd += b_sh_stride * slice_col; b_gl_rd += b_gl_rd_delta_o * slice_row; - auto b_sh_wr = threadIdx.x * b_thread_vecs; auto b_sh_rd = threadIdx.x * b_thread_vecs; + b_sh_rd += b_sh_rd / b_sh_stride * (b_sh_stride * (b_sh_wr_iters - 1)); // For act_order - constexpr int k_iter_size = tb_k / b_sh_wr_iters; int slice_k_start = tb_k * slice_row; int slice_k_finish = slice_k_start + tb_k * slice_iters; int slice_k_start_shared_fetch = slice_k_start; @@ -571,58 +737,54 @@ __global__ void Marlin( if constexpr (!has_act_order) { if constexpr (group_blocks == -1) { s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - } else { - s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) / - (w_type == vllm::kFE2M1f ? 
2 : 1) + + } else if constexpr (group_blocks >= thread_k_blocks) { + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + s_sh_stride * slice_col + threadIdx.x; + } else { + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks + + threadIdx.x / s_sh_stride) + + s_sh_stride * slice_col + threadIdx.x % s_sh_stride; } } auto s_sh_wr = threadIdx.x; - bool s_sh_wr_pred = threadIdx.x < s_sh_stride; + bool s_sh_wr_pred = threadIdx.x < s_sh_stage; // Zero-points int zp_gl_rd; if constexpr (has_zp) { if constexpr (group_blocks == -1) { zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; - } else { + } else if constexpr (group_blocks >= thread_k_blocks) { zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + zp_sh_stride * slice_col + threadIdx.x; + } else { + zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks + + threadIdx.x / zp_sh_stride) + + zp_sh_stride * slice_col + threadIdx.x % zp_sh_stride; } } auto zp_sh_wr = threadIdx.x; - bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride; + bool zp_sh_wr_pred = zp_sh_stage > 0 && threadIdx.x < zp_sh_stage; // We use a different scale layout for grouped and column-wise quantization as // we scale a `half2` tile in column-major layout in the former and in // row-major in the latter case. int s_sh_rd; - if constexpr (group_blocks != -1 && w_type == vllm::kFE2M1f) { - auto warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; - int warp_row = warp_id / n_warps; - - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 4; - s_sh_rd = s_sh_rd * 2 + (warp_row / group_blocks) % 2; - + if constexpr (is_a_8bit) { + s_sh_rd = 4 * ((threadIdx.x / 32) % tb_n_warps) + (threadIdx.x % 4); } else if constexpr (group_blocks != -1) - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 4; + s_sh_rd = 8 * ((threadIdx.x / 32) % tb_n_warps) + (threadIdx.x % 32) / 4; else if constexpr (group_blocks == -1 && (m_block_size_8 || (has_zp && !dequant_skip_flop))) - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 8; + s_sh_rd = 8 * ((threadIdx.x / 32) % tb_n_warps) + (threadIdx.x % 32) / 8; else - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) % 4; + s_sh_rd = 8 * ((threadIdx.x / 32) % tb_n_warps) + (threadIdx.x % 32) % 4; int bias_sh_rd; if constexpr (m_block_size_8) { - bias_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 8; + bias_sh_rd = 8 * ((threadIdx.x / 32) % tb_n_warps) + (threadIdx.x % 32) / 8; } else { - bias_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + bias_sh_rd = (is_a_8bit ? 
4 : 8) * ((threadIdx.x / 32) % tb_n_warps) + (threadIdx.x % 32) % 4; } @@ -638,12 +800,16 @@ __global__ void Marlin( if constexpr (has_zp) { if constexpr (is_zp_float) { if constexpr (group_blocks != -1) { - zp_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 4; + zp_sh_rd = + 8 * ((threadIdx.x / 32) % tb_n_warps) + (threadIdx.x % 32) / 4; } + } else if (is_a_8bit) { + zp_sh_rd = num_ints_per_thread * num_col_threads * + ((threadIdx.x / 32) % tb_n_warps / 2) + + num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads); } else { zp_sh_rd = num_ints_per_thread * num_col_threads * - ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + ((threadIdx.x / 32) % tb_n_warps) + num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads); } } @@ -678,26 +844,19 @@ __global__ void Marlin( for (int i = 0; i < b_sh_wr_iters; i++) { #pragma unroll for (int j = 0; j < thread_m_blocks; j++) - a_sh_rd_trans[i][j] = - transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); + a_sh_rd_trans[i][j] = transform_a(2 * i + a_sh_rd_delta_i * j + a_sh_rd); } // Since B-accesses have non-constant stride they have to be computed at // runtime; we break dependencies between subsequent accesses with a tile by // maintining multiple pointers (we have enough registers), a tiny // optimization. - const int4* B_ptr[b_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; - extern __shared__ int4 sh[]; // Shared memory storage for global fetch pipelines. constexpr int sh_red_size = (2 * thread_n_blocks + 1) * 16 * thread_m_blocks; constexpr int sh_b_size = stages * b_sh_stage; - int4* sh_b = sh; - int4* sh_red = sh; - + int4* sh_b = sh_new; + int4* sh_red = sh_new; constexpr int sh_size_b_red_min = (sh_red_size < sh_b_size ? sh_red_size : sh_b_size); constexpr int sh_size_b_red_max = @@ -708,8 +867,8 @@ __global__ void Marlin( ? sh_size_b_red_max : (sh_size_b_red_min + sh_bias_size); - int4* sh_bias = sh + sh_size_b_red_min; - int4* sh_g_idx = sh + sh_b_red_bias_size; + int4* sh_bias = sh_new + sh_size_b_red_min; + int4* sh_g_idx = sh_new + sh_b_red_bias_size; int4* sh_zp = sh_g_idx + (stages * g_idx_stage); constexpr int sh_s_size = has_act_order ? (act_s_max_num_groups * s_sh_stride) : (stages * s_sh_stage); @@ -723,7 +882,8 @@ __global__ void Marlin( // Register storage for double buffer of shared memory reads. FragA frag_a[2][thread_m_blocks]; I4 frag_b_quant[2][b_thread_vecs]; - FragC frag_c[thread_m_blocks][4][2]; + FragC frag_c[thread_m_blocks][is_a_8bit ? 2 : 4][2]; + FragC frag_c_tmp[thread_m_blocks][is_a_8bit ? 2 : 4][2]; FragS frag_s[2][4]; // No act-order FragS frag_bias[2][4]; FragS act_frag_s[2][4][4]; // For act-order @@ -731,6 +891,24 @@ __global__ void Marlin( FragZP frag_zp; // Zero-points in fp16 FragZP frag_zpf[2]; // Zero-points in fp16 in HQQ + if constexpr (is_a_8bit) { + #pragma unroll + for (int j = 0; j < 2; j++) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int g = 0; g < 4; g++) { + frag_c_tmp[i][j][0][g] = 0.0f; + } + + #pragma unroll + for (int g = 0; g < 4; g++) { + frag_c_tmp[i][j][1][g] = 0.0f; + } + } + } + } + // Zero accumulators. 
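// Illustrative sketch (not part of the patch): the two-phase schedule set up
// in init_slice() above. MN tiles that divide evenly across the grid are
// handled first, one full-K sweep per block with no cross-block reduction
// (slice_count stays 1); the remainder falls back to the striped schedule,
// padded by one extra grid-sized round when it would otherwise cover roughly
// a third of the SMs or less. Values below are arbitrary example numbers and
// div_ceil is written out inline.
#include <cstdio>

int main() {
  int global_mn_tiles = 100, grid_dim_x = 32, k_tiles = 8;  // example values
  int part2_mn_tiles = global_mn_tiles;
  int part1_mn_iters = 0;
  if (global_mn_tiles > grid_dim_x) {
    part2_mn_tiles = global_mn_tiles % grid_dim_x;  // 100 % 32 = 4
    if (part2_mn_tiles * 3 <= grid_dim_x) part2_mn_tiles += grid_dim_x;  // 36
    part1_mn_iters = (global_mn_tiles - part2_mn_tiles) / grid_dim_x;    // 2
  }
  // K-tile iterations per block during the striped (part 2) phase.
  int iters = (k_tiles * part2_mn_tiles + grid_dim_x - 1) / grid_dim_x;  // 9
  std::printf("part1 full-K slices per block: %d, striped tiles: %d, "
              "striped iters per block: %d\n",
              part1_mn_iters, part2_mn_tiles, iters);
  return 0;
}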
auto zero_accums = [&]() { #pragma unroll @@ -788,15 +966,17 @@ __global__ void Marlin( } int4* sh_b_stage = sh_b + b_sh_stage * pipe; #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll - for (int j = 0; j < b_thread_vecs; j++) { - cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); - } + for (int i = 0; i < (b_sh_wr_iters * b_thread_vecs); i++) { + constexpr int count = div_ceil(b_sh_stride, threads); + int b_gl_idx = + b_gl_rd + (i % count) * threads + + b_gl_stride * (i / count) * div_ceil(threads, b_sh_stride); - B_ptr[i] += b_gl_rd_delta_o; + cp_async4(&sh_b_stage[threads * i + threadIdx.x], &B[b_gl_idx]); } + b_gl_rd += b_gl_rd_delta_o; + if constexpr (has_act_order) { // Fetch g_idx thread-block portion int full_pipe = a_off; @@ -816,44 +996,24 @@ __global__ void Marlin( if constexpr (group_blocks != -1) { int4* sh_s_stage = sh_s + s_sh_stage * pipe; - if constexpr (group_blocks >= thread_k_blocks) { - // Only fetch scales if this tile starts a new group - if (pipe % (group_blocks / thread_k_blocks) == 0) { - if (s_sh_wr_pred) { - cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); - } - s_gl_rd += s_gl_rd_delta; - } - } else { - for (int i = 0; i < s_tb_groups; i++) { - if (s_sh_wr_pred) { - cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], - &scales_ptr[s_gl_rd]); - } - s_gl_rd += s_gl_rd_delta; + // Only fetch scales if this tile starts a new group + if (pipe % div_ceil(group_blocks, thread_k_blocks) == 0) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); } + s_gl_rd += s_gl_rd_delta * s_tb_groups; } } if constexpr (has_zp && group_blocks != -1) { int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; - if constexpr (group_blocks >= thread_k_blocks) { - // Only fetch zero-points if this tile starts a new group - if (pipe % (group_blocks / thread_k_blocks) == 0) { - if (zp_sh_wr_pred) { - cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]); - } - zp_gl_rd += zp_gl_rd_delta; - } - } else { - for (int i = 0; i < zp_tb_groups; i++) { - if (zp_sh_wr_pred) { - cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr], - &zp_ptr[zp_gl_rd]); - } - zp_gl_rd += zp_gl_rd_delta; + // Only fetch zero points if this tile starts a new group + if (pipe % div_ceil(group_blocks, thread_k_blocks) == 0) { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]); } + zp_gl_rd += zp_gl_rd_delta * zp_tb_groups; } } } @@ -891,14 +1051,14 @@ __global__ void Marlin( int4* sh_a_stage = sh_a + a_sh_stage * pipe; #pragma unroll for (int i = 0; i < thread_m_blocks; i++) - ldsm( + ldsm( frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); int4* sh_b_stage = sh_b + b_sh_stage * pipe; #pragma unroll for (int i = 0; i < b_thread_vecs; i++) { frag_b_quant[k % 2][i] = *reinterpret_cast( - &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); + &sh_b_stage[b_sh_stride * (k % b_sh_wr_iters) + b_sh_rd + i]); } }; @@ -922,53 +1082,54 @@ __global__ void Marlin( auto fetch_scales_to_registers = [&](int k, int full_pipe) { int pipe = full_pipe % stages; + using IT1 = typename std::conditional_t; + using IT0 = typename std::conditional_t; + constexpr int group_blocks2 = div_ceil(group_blocks, is_a_8bit ? 
2 : 1); if constexpr (!has_act_order) { // No act-order case if constexpr (group_blocks == -1) { // load only when starting a new slice - if (k == 0 && full_pipe == 0) { + if (k == 0 && full_pipe == 0 && dequant_skip_flop) { reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd]; reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; } } else if constexpr (group_blocks != -1) { if constexpr (group_blocks >= thread_k_blocks) { - if (k % b_sh_wr_iters == 0) { - int4* sh_s_stage = - sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); - reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; - } else { - reinterpret_cast(&frag_s[1])[0] = - reinterpret_cast(&frag_s[0])[0]; + constexpr int g = group_blocks / thread_k_blocks; + if (pipe % g == 0) { + if (k % b_sh_wr_iters == 0) { + int4* sh_s_stage = sh_s + s_sh_stage * (g * (pipe / g)); + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + } else { + reinterpret_cast(&frag_s[1])[0] = + reinterpret_cast(&frag_s[0])[0]; + } } - } else { + } else if (group_blocks2 < b_sh_wr_iters || k % b_sh_wr_iters == 0) { auto warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; + int warp_row = warp_id / tb_n_warps; - int warp_row = warp_id / n_warps; - - int cur_k = warp_row * 16; - cur_k += k_iter_size * (k % b_sh_wr_iters); - - int k_blocks = cur_k / 16; - int cur_group_id = - k_blocks / (group_blocks * (w_type == vllm::kFE2M1f ? 2 : 1)); + int k_blocks = b_sh_wr_iters * warp_row + k % b_sh_wr_iters; + int cur_group_id = k_blocks / group_blocks2; int4* sh_s_stage = sh_s + s_sh_stage * pipe; - if constexpr (w_type_id != vllm::kFE2M1f.id()) { + if constexpr (b_type_id != vllm::kFE2M1f.id()) { reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; - } else if constexpr (group_blocks == 1 || thread_k_blocks > 4) { - reinterpret_cast(&frag_s[k % 2])[0] = - reinterpret_cast( - sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)]; } else { reinterpret_cast(&frag_s[k % 2])[0] = reinterpret_cast( - sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride) + - k % 2]; + sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)]; + } + } else if (group_blocks >= b_sh_wr_iters) { + if constexpr (b_type_id != vllm::kFE2M1f.id()) { + reinterpret_cast(&frag_s[1])[0] = + reinterpret_cast(&frag_s[0])[0]; + } else { + reinterpret_cast(&frag_s[1])[0] = + reinterpret_cast(&frag_s[0])[0]; } } } @@ -989,18 +1150,15 @@ __global__ void Marlin( cur_k = 0; // Progress to current iteration - cur_k += k_iter_size * (k % b_sh_wr_iters); + cur_k += k % b_sh_wr_iters; // Determine "position" inside the thread-block (based on warp and // thread-id) auto warp_id = threadIdx.x / 32; - int n_warps = - thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N + int warp_row = warp_id / tb_n_warps; + int warp_col = warp_id % tb_n_warps; - int warp_row = warp_id / n_warps; - int warp_col = warp_id % n_warps; - - cur_k += warp_row * 16; + cur_k += warp_row * 16 * b_sh_wr_iters; auto th_id = threadIdx.x % 32; cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix @@ -1055,18 +1213,16 @@ __global__ void Marlin( if constexpr (group_blocks == -1) { // load only when starting a new slice - if (k == 0 && full_pipe == 0) { + if (k == 0 && full_pipe == 0 || is_a_8bit) { #pragma unroll for (int i = 0; i < num_ints_per_thread; i++) { frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp))[zp_sh_rd + i]; } } - } else if constexpr (group_blocks >= thread_k_blocks) { - if (k % b_sh_wr_iters == 0) { - 
int4* sh_zp_stage = - sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); + constexpr int g = group_blocks / thread_k_blocks; + if (pipe % g == 0 && k % b_sh_wr_iters == 0 || is_a_8bit) { + int4* sh_zp_stage = sh_zp + zp_sh_stage * (g * (pipe / g)); #pragma unroll for (int i = 0; i < num_ints_per_thread; i++) { frag_qzp[k % 2][i] = @@ -1075,21 +1231,11 @@ __global__ void Marlin( } } else { auto warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; - int warp_row = warp_id / n_warps; + int warp_row = warp_id / tb_n_warps; - int cur_k = warp_row * 16; - cur_k += k_iter_size * (k % b_sh_wr_iters); - - int k_blocks = cur_k / 16; - int cur_group_id = 0; - - // Suppress bogus and persistent divide-by-zero warning - #pragma nv_diagnostic push - #pragma nv_diag_suppress divide_by_zero - cur_group_id = k_blocks / group_blocks; - #pragma nv_diagnostic pop + int k_blocks = b_sh_wr_iters * warp_row + k % b_sh_wr_iters; + int cur_group_id = k_blocks / div_ceil(group_blocks, is_a_8bit ? 2 : 1); int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; @@ -1108,29 +1254,18 @@ __global__ void Marlin( if constexpr (group_blocks != -1) { if constexpr (group_blocks >= thread_k_blocks) { - if (k % b_sh_wr_iters == 0) { - int4* sh_zp_stage = - sh_zp + - zp_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); + constexpr int g = group_blocks / thread_k_blocks; + if (pipe % g == 0 && k % b_sh_wr_iters == 0) { + int4* sh_zp_stage = sh_zp + zp_sh_stage * (g * (pipe / g)); reinterpret_cast(&frag_zpf[k % 2])[0] = sh_zp_stage[zp_sh_rd]; } - } else { + } else if (group_blocks < b_sh_wr_iters || k % b_sh_wr_iters == 0) { auto warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; - int warp_row = warp_id / n_warps; - - int cur_k = warp_row * 16; - cur_k += k_iter_size * (k % b_sh_wr_iters); - - int k_blocks = cur_k / 16; - // Suppress bogus and persistent divide-by-zero warning - #pragma nv_diagnostic push - #pragma nv_diag_suppress divide_by_zero + int warp_row = warp_id / tb_n_warps; + int k_blocks = b_sh_wr_iters * warp_row + k % b_sh_wr_iters; int cur_group_id = k_blocks / group_blocks; - #pragma nv_diagnostic pop int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; @@ -1141,33 +1276,46 @@ __global__ void Marlin( } }; - auto dequant_data = [&](int q, scalar_t2* frag_b_ptr) { - dequant(q, frag_b_ptr); + auto dequant_data = [&](int q, scalar_32bit_t* frag_b_ptr, int zp = 0) { + if constexpr (a_type.size_bits() != b_type.size_bits()) { + if constexpr (is_a_8bit && has_zp) { + sub_zp_and_dequant( + q, frag_b_ptr, zp); + } else { + dequant(q, frag_b_ptr); + } + } }; // Execute the actual tensor core matmul of a sub-tile. bool is_first_matmul_in_slice = true; - auto matmul = [&](int k) { + auto matmul = [&](int k, int pipe) { + if (is_a_8bit) return; int k2 = k % 2; + constexpr int g = + group_blocks > 0 ? 
div_ceil(group_blocks, thread_k_blocks) : 1; const bool is_new_zp = - ((group_blocks != -1) && (group_blocks < thread_k_blocks || k == 0)) || + (group_blocks == 0) || + ((group_blocks > 0) && (group_blocks < b_sh_wr_iters || k == 0)) && + (pipe % g == 0) || (group_blocks == -1 && is_first_matmul_in_slice); if constexpr (has_zp && !is_zp_float) { if (is_new_zp) { if constexpr (group_blocks == -1) is_first_matmul_in_slice = false; int zp_quant_0, zp_quant_1; - if constexpr (w_type.size_bits() == 4) { + if constexpr (b_type.size_bits() == 4) { zp_quant_0 = frag_qzp[k2][0]; zp_quant_1 = zp_quant_0 >> 8; } else { - static_assert(w_type.size_bits() == 8); + static_assert(b_type.size_bits() == 8); zp_quant_0 = frag_qzp[k2][0]; zp_quant_1 = frag_qzp[k2][1]; } - dequant_data(zp_quant_0, reinterpret_cast(&frag_zp)); - dequant_data(zp_quant_1, reinterpret_cast(&frag_zp) + 2); + dequant_data(zp_quant_0, reinterpret_cast(&frag_zp)); + dequant_data(zp_quant_1, + reinterpret_cast(&frag_zp) + 2); } } if constexpr (!dequant_skip_flop && has_zp && is_zp_float) { @@ -1177,14 +1325,14 @@ __global__ void Marlin( } } - if constexpr (w_type == vllm::kFE2M1f) { + if constexpr (b_type == vllm::kFE2M1f) { int s_quant_0 = reinterpret_cast(frag_s[k2])[0]; int s_quant_1 = reinterpret_cast(frag_s[k2])[1]; - dequant_fp8_scales( - s_quant_0, reinterpret_cast(&frag_s[k2])); - dequant_fp8_scales( - s_quant_1, reinterpret_cast(&frag_s[k2]) + 2); + dequant_fp8_scales( + s_quant_0, reinterpret_cast(&frag_s[k2])); + dequant_fp8_scales( + s_quant_1, reinterpret_cast(&frag_s[k2]) + 2); } // We have the m dimension as the inner loop in order to encourage overlapping @@ -1195,61 +1343,168 @@ __global__ void Marlin( FragB frag_b1; int b_quant_0, b_quant_1; - if constexpr (w_type_id == vllm::kFE2M1f.id()) { + if constexpr (b_type_id == vllm::kFE2M1f.id()) { b_quant_1 = frag_b_quant[k2][0][j]; b_quant_0 = b_quant_1 << 8; - } else if constexpr (w_type.size_bits() == 4) { + } else if constexpr (b_type.size_bits() == 4) { b_quant_0 = frag_b_quant[k2][0][j]; b_quant_1 = b_quant_0 >> 8; } else { - static_assert(w_type.size_bits() == 8); + static_assert(b_type.size_bits() == 8); int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k2]); b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; } - dequant_data(b_quant_0, reinterpret_cast(&frag_b0)); - dequant_data(b_quant_1, reinterpret_cast(&frag_b1)); + dequant_data(b_quant_0, reinterpret_cast(&frag_b0)); + dequant_data(b_quant_1, reinterpret_cast(&frag_b1)); - if constexpr (dequant_skip_flop && has_zp && !is_zp_float) { - sub_zp(frag_b0, frag_zp[j], 0); - sub_zp(frag_b1, frag_zp[j], 1); + if constexpr (dequant_skip_flop && has_zp && !is_zp_float && !is_a_8bit) { + sub_zp(frag_b0, frag_zp[j], 0); + sub_zp(frag_b1, frag_zp[j], 1); } // Apply scale to frag_b0 - if constexpr (has_act_order) { + if constexpr (has_act_order && !is_a_8bit) { static_assert(group_blocks != -1); - scale4(frag_b0, act_frag_s[k2][0][j], act_frag_s[k2][1][j], - act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0); - scale4(frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j], - act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1); + scale4(frag_b0, act_frag_s[k2][0][j], act_frag_s[k2][1][j], + act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0); + scale4(frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j], + act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1); } else if constexpr (!dequant_skip_flop && has_zp && !is_zp_float && - group_blocks == -1) { + group_blocks == -1 && !is_a_8bit) { int idx = (threadIdx.x / 
4) % 2; - scalar_t2 s2 = Dtype::nums2num2( + scalar_t2 s2 = Adtype::nums2num2( reinterpret_cast(&frag_s[j / 2][j % 2 * 2 + 0])[idx], reinterpret_cast(&frag_s[j / 2][j % 2 * 2 + 1])[idx]); if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], s2); - scale_and_sub(frag_b0, s2.x, frag_zp[j].x); - scale_and_sub(frag_b1, s2.y, frag_zp[j].y); - } else if constexpr (!dequant_skip_flop && has_zp && group_blocks != -1) { + scale_and_sub(frag_b0, s2.x, frag_zp[j].x); + scale_and_sub(frag_b1, s2.y, frag_zp[j].y); + } else if constexpr (!dequant_skip_flop && has_zp && group_blocks != -1 && + !is_a_8bit) { if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], *reinterpret_cast(&frag_s[k2][j])); - scale_and_sub(frag_b0, frag_s[k2][j][0].x, frag_zp[j].x); - scale_and_sub(frag_b1, frag_s[k2][j][0].y, frag_zp[j].y); - } else if constexpr (group_blocks != -1) { - scale(frag_b0, frag_s[k2][j], 0); - scale(frag_b1, frag_s[k2][j], 1); + scale_and_sub(frag_b0, frag_s[k2][j][0].x, frag_zp[j].x); + scale_and_sub(frag_b1, frag_s[k2][j][0].y, frag_zp[j].y); + } else if constexpr (group_blocks != -1 && !is_a_8bit) { + scale(frag_b0, frag_s[k2][j], 0); + scale(frag_b1, frag_s[k2][j], 1); } #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { if constexpr (m_block_size_8) { - mma_trans(frag_a[k2][i], frag_b0, frag_b1, frag_c[i][j][0]); + mma_trans(frag_a[k2][i], frag_b0, frag_b1, + frag_c[i][j][0]); } else { - mma(frag_a[k2][i], frag_b0, frag_c[i][j][0]); - mma(frag_a[k2][i], frag_b1, frag_c[i][j][1]); + mma(frag_a[k2][i], frag_b0, frag_c[i][j][0]); + mma(frag_a[k2][i], frag_b1, frag_c[i][j][1]); + } + } + } + }; + + auto matmul_a8 = [&](int k) { + int k2 = k % 2; + #pragma unroll + for (int j = 0; j < 2; j++) { + FragB frag_b[2]; + + if (is_a_8bit && b_type.size_bits() == 4 && !has_zp) { + dequant_data(frag_b_quant[k2][0][j * 2], + reinterpret_cast(&frag_b)); + dequant_data(frag_b_quant[k2][0][j * 2 + 1], + reinterpret_cast(&frag_b) + 2); + } else if (is_a_8bit && b_type.size_bits() == 4 && has_zp) { + int off = (threadIdx.x / 32) % 2 * 2 + j; + int zp = (frag_qzp[k2][0] >> (off * 8)) & 0xF; + dequant_data(frag_b_quant[k2][0][j * 2], + reinterpret_cast(&frag_b), zp); + zp = (frag_qzp[k2][0] >> (off * 8 + 4)) & 0xF; + dequant_data(frag_b_quant[k2][0][j * 2 + 1], + reinterpret_cast(&frag_b) + 2, zp); + } else { + reinterpret_cast(&frag_b)[0] = + reinterpret_cast(&frag_b_quant[k2][j])[0]; + reinterpret_cast(&frag_b)[1] = + reinterpret_cast(&frag_b_quant[k2][j])[1]; + } + + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + mma(frag_a[k2][i], frag_b[0], + (group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][0]); + mma(frag_a[k2][i], frag_b[1], + (group_blocks == -1 ? 
frag_c : frag_c_tmp)[i][j][1]); + } + + if constexpr (group_blocks != -1) { + if (group_blocks == 2 || k == 1) { + if constexpr (a_type == vllm::kS8) { + int2 s_vals[2]; + s_vals[0] = { + (int)reinterpret_cast(&frag_s[k2][j * 2][0])[0], + (int)reinterpret_cast(&frag_s[k2][j * 2][0])[1]}; + s_vals[1] = { + (int)reinterpret_cast(&frag_s[k2][j * 2 + 1][0])[0], + (int)reinterpret_cast(&frag_s[k2][j * 2 + 1][0])[1]}; + + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int g = 0; g < 4; g++) { + int scale = reinterpret_cast(&s_vals[0])[g % 2]; + *reinterpret_cast(&frag_c[i][j][0][g]) += + *reinterpret_cast(&frag_c_tmp[i][j][0][g]) * + scale; + frag_c_tmp[i][j][0][g] = 0.0f; + } + + #pragma unroll + for (int g = 0; g < 4; g++) { + int scale = reinterpret_cast(&s_vals[1])[g % 2]; + *reinterpret_cast(&frag_c[i][j][1][g]) += + *reinterpret_cast(&frag_c_tmp[i][j][1][g]) * + scale; + frag_c_tmp[i][j][1][g] = 0.0f; + } + } + } else { + float2 s_vals[2]; + if constexpr (s_type_id != vllm::kFE8M0fnu.id()) { + static_assert(a_type.size_bits() == 16 || + s_type.size_bits() == 16); + s_vals[0] = Cdtype::num22float2(frag_s[k2][j * 2][0]); + s_vals[1] = Cdtype::num22float2(frag_s[k2][j * 2 + 1][0]); + } else { + int32_t* s_vals_int = reinterpret_cast(&s_vals[0]); + int32_t s_vals_e8m0 = + *reinterpret_cast(&frag_s[k2][j][0]); + + s_vals_int[0] = (s_vals_e8m0 & 0xFF) << 23; + s_vals_int[1] = (s_vals_e8m0 & 0xFF00) << 15; + s_vals_int[2] = (s_vals_e8m0 & 0xFF0000) << 7; + s_vals_int[3] = (s_vals_e8m0 & 0xFF000000) >> 1; + } + + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int g = 0; g < 4; g++) { + float scale = reinterpret_cast(&s_vals[0])[g % 2]; + frag_c[i][j][0][g] += frag_c_tmp[i][j][0][g] * scale; + frag_c_tmp[i][j][0][g] = 0.0f; + } + + #pragma unroll + for (int g = 0; g < 4; g++) { + float scale = reinterpret_cast(&s_vals[1])[g % 2]; + frag_c[i][j][1][g] += frag_c_tmp[i][j][1][g] * scale; + frag_c_tmp[i][j][1][g] = 0.0f; + } + } + } } } } @@ -1263,7 +1518,8 @@ __global__ void Marlin( constexpr int red_off = threads / b_sh_stride_threads / 2; if (red_off >= 1) { auto red_idx = threadIdx.x / b_sh_stride_threads; - constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; + constexpr int red_sh_stride = + b_sh_stride_threads * (is_a_8bit ? 2 : 4) * 2; constexpr int red_sh_delta = b_sh_stride_threads; int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + (threadIdx.x % b_sh_stride_threads); @@ -1278,7 +1534,8 @@ __global__ void Marlin( for (int i = red_off; i > 0; i /= 2) { if (i <= red_idx && red_idx < 2 * i) { #pragma unroll - for (int j = 0; j < 4 * 2; j += (m_block_size_8 ? 2 : 1)) { + for (int j = 0; j < (is_a_8bit ? 2 : 4) * 2; + j += (m_block_size_8 ? 2 : 1)) { int red_sh_wr = red_sh_delta * j + (red_sh_rd - red_sh_stride * i); if (i < red_off) { @@ -1287,24 +1544,26 @@ __global__ void Marlin( float* c_wr = reinterpret_cast(&sh_red[red_sh_wr]); #pragma unroll for (int k = 0; k < 4; k++) - reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += + reinterpret_cast( + frag_c)[(is_a_8bit ? 2 : 4) * 2 * m_block + j][k] += c_rd[k] + c_wr[k]; } - sh_red[red_sh_wr] = - reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + sh_red[red_sh_wr] = reinterpret_cast( + &frag_c)[(is_a_8bit ? 2 : 4) * 2 * m_block + j]; } } __syncthreads(); } if (red_idx == 0) { #pragma unroll - for (int i = 0; i < 4 * 2; i += (m_block_size_8 ? 2 : 1)) { + for (int i = 0; i < (is_a_8bit ? 2 : 4) * 2; + i += (m_block_size_8 ? 
2 : 1)) { float* c_rd = reinterpret_cast(&sh_red[red_sh_delta * i + red_sh_rd]); #pragma unroll for (int j = 0; j < 4; j++) - reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += - c_rd[j]; + reinterpret_cast( + frag_c)[(is_a_8bit ? 2 : 4) * 2 * m_block + i][j] += c_rd[j]; } } __syncthreads(); @@ -1320,10 +1579,10 @@ __global__ void Marlin( // We are very careful here to reduce directly in the output buffer to // maximize L2 cache utilization in this step. To do this, we write out // results in FP16 (but still reduce with FP32 compute). - constexpr int active_threads = 32 * thread_n_blocks / 4; + constexpr int active_threads = 32 * tb_n_warps; if (threadIdx.x < active_threads) { int c_gl_stride = prob_n / 8; - int c_gl_wr_delta_o = 8 * c_gl_stride; + int c_gl_wr_delta_o = 8 * c_gl_stride * (is_a_8bit ? 2 : 1); int c_gl_wr_delta_i = 4 * (active_threads / 32); int c_gl_wr; if constexpr (m_block_size_8) { @@ -1331,9 +1590,9 @@ __global__ void Marlin( 4 * (threadIdx.x / 32) + (threadIdx.x % 32) / 8; c_gl_wr += (2 * thread_n_blocks) * slice_col; } else { - c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + + c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) * (is_a_8bit ? 2 : 1) + 4 * (threadIdx.x / 32) + threadIdx.x % 4; - c_gl_wr += (2 * thread_n_blocks) * slice_col; + c_gl_wr += (2 * thread_n_blocks) * slice_col * (is_a_8bit ? 2 : 1); } constexpr int c_sh_wr_delta = active_threads; auto c_sh_wr = threadIdx.x; @@ -1351,6 +1610,14 @@ __global__ void Marlin( &C[c_gl_wr + i * c_gl_stride + (threadIdx.x % 8) / 4 * c_gl_wr_delta_i], (threadIdx.x % 4) * 2 + i < prob_m); + } else if constexpr (is_a_8bit) { + int2* sh_red_int2 = reinterpret_cast(sh_red); + int2* c_int2 = reinterpret_cast(C); + cp_async2_ca_pred( + &sh_red_int2[c_sh_wr + c_sh_wr_delta * i], + &c_int2[c_gl_wr + c_gl_wr_delta_o * (i / 2) + + c_gl_wr_delta_i * (i % 2)], + i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); } else { cp_async4_pred( &sh_red[c_sh_wr + c_sh_wr_delta * i], @@ -1370,36 +1637,51 @@ __global__ void Marlin( (m_block_size_8) && ((threadIdx.x % 4) * 2 + i < prob_m); if (mask) { if (!first) { - int4 c_red = sh_red[c_sh_wr + i * c_sh_wr_delta]; + c_scalar_t* c_red_f16; + if constexpr (is_a_8bit) { + int2 tmp = + reinterpret_cast(sh_red)[c_sh_wr + i * c_sh_wr_delta]; + c_red_f16 = reinterpret_cast(&tmp); + } else { + int4 tmp = sh_red[c_sh_wr + i * c_sh_wr_delta]; + c_red_f16 = reinterpret_cast(&tmp); + } #pragma unroll - for (int j = 0; j < 2 * 4; j++) { + for (int j = 0; j < 2 * (is_a_8bit ? 2 : 4); j++) { int delta = 0; if constexpr (m_block_size_8) { delta = j % 2 == 1 ? -2 : 0; } reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4) + delta] += - Dtype::num2float(reinterpret_cast(&c_red)[j]); + &frag_c)[(is_a_8bit ? 2 : 4) * 2 * 4 * (i / 4) + 4 * j + + (i % 4) + delta] += Cdtype::num2float(c_red_f16[j]); } } if (!last) { - int4 c; + c_scalar_t c_f16[is_a_8bit ? 4 : 8]; #pragma unroll - for (int j = 0; j < 2 * 4; j++) { + for (int j = 0; j < 2 * (is_a_8bit ? 2 : 4); j++) { int delta = 0; if constexpr (m_block_size_8) { delta = j % 2 == 1 ? -2 : 0; } - reinterpret_cast(&c)[j] = - Dtype::float2num(reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4) + delta]); + c_f16[j] = Cdtype::float2num(reinterpret_cast( + &frag_c)[(is_a_8bit ? 
2 : 4) * 2 * 4 * (i / 4) + 4 * j + + (i % 4) + delta]); } - if constexpr (m_block_size_8) + if constexpr (m_block_size_8) { C[c_gl_wr + i * c_gl_stride + - (threadIdx.x % 8) / 4 * c_gl_wr_delta_i] = c; - else + (threadIdx.x % 8) / 4 * c_gl_wr_delta_i] = + *reinterpret_cast(c_f16); + } else if constexpr (is_a_8bit) { + int2* c_int2 = reinterpret_cast(C); + c_int2[c_gl_wr + c_gl_wr_delta_o * (i / 2) + + c_gl_wr_delta_i * (i % 2)] = + *reinterpret_cast(c_f16); + } else { C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + - c_gl_wr_delta_i * (i % 2)] = c; + c_gl_wr_delta_i * (i % 2)] = *reinterpret_cast(c_f16); + } } } } @@ -1414,10 +1696,10 @@ __global__ void Marlin( constexpr int c_size = tb_m * tb_n * sizeof(float) / 16; - constexpr int active_threads = 32 * thread_n_blocks / 4; + constexpr int active_threads = 32 * tb_n_warps; bool is_th_active = threadIdx.x < active_threads; - constexpr int num_floats = thread_m_blocks * 4 * 2 * 4; + constexpr int num_floats = thread_m_blocks * (is_a_8bit ? 2 : 4) * 2 * 4; constexpr int th_size = num_floats * sizeof(float) / 16; int c_cur_offset = locks_off * c_size; @@ -1471,7 +1753,7 @@ __global__ void Marlin( } else { c_sh_wr = (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; - c_sh_wr += 32 * (threadIdx.x / 32); + c_sh_wr += (is_a_8bit ? 16 : 32) * (threadIdx.x / 32); } int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + @@ -1481,47 +1763,47 @@ __global__ void Marlin( // We first reorder in shared memory to guarantee the most efficient final // global write patterns auto write = [&](int idx, float c0, float c1, FragS& s, FragS& b_bias) { - scalar_t2 res = - Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1)); + c_scalar_t2 res = + Cdtype::nums2num2(Cdtype::float2num(c0), Cdtype::float2num(c1)); // For per-column quantization we finally apply the scale here (only for // 4-bit) - if constexpr (!has_act_order && group_blocks == -1 && - w_type.size_bits() == 4 && + if constexpr (!has_act_order && group_blocks == -1 && !is_a_8bit && + b_type.size_bits() == 4 && (has_zp && dequant_skip_flop || !has_zp)) { - scalar_t2 tmp_scale = s[0]; + c_scalar_t2 tmp_scale = s[0]; if constexpr (m_block_size_8) { - tmp_scale = Dtype::num2num2( + tmp_scale = Cdtype::num2num2( reinterpret_cast(&s[0])[(threadIdx.x % 8) / 4]); } res = __hmul2(res, tmp_scale); } - if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) { + if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) { res = __hmul2(res, global_scale); } if (has_bias && last) { - scalar_t2 tmp_bias = b_bias[0]; + c_scalar_t2 tmp_bias = b_bias[0]; if constexpr (m_block_size_8) { - tmp_bias = Dtype::num2num2( + tmp_bias = Cdtype::num2num2( reinterpret_cast(&b_bias[0])[(threadIdx.x % 8) / 4]); } res = __hadd2(res, tmp_bias); } if constexpr (m_block_size_8) { - ((scalar_t*)sh_red)[idx] = res.x; - ((scalar_t*)sh_red)[idx + 8 * c_sh_stride] = res.y; + ((c_scalar_t*)sh_red)[idx] = res.x; + ((c_scalar_t*)sh_red)[idx + 8 * c_sh_stride] = res.y; } else { - ((scalar_t2*)sh_red)[idx] = res; + ((c_scalar_t2*)sh_red)[idx] = res; } }; - if (threadIdx.x / 32 < thread_n_blocks / 4) { + if (threadIdx.x / 32 < tb_n_warps) { #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { #pragma unroll - for (int j = 0; j < 4; j++) { + for (int j = 0; j < (is_a_8bit ? 
2 : 4); j++) { if constexpr (m_block_size_8) { int wr = c_sh_wr + 16 * j; write(wr, frag_c[i][j][0][0], frag_c[i][j][0][1], @@ -1557,9 +1839,9 @@ __global__ void Marlin( i++) { if (c_gl_wr < c_gl_wr_end) { if (use_atomic_add && slice_count > 1) { - scalar_t2* C_half2 = reinterpret_cast(&C[c_gl_wr]); - scalar_t2* sh_red_half2 = - reinterpret_cast(&sh_red[c_sh_rd]); + c_scalar_t2* C_half2 = reinterpret_cast(&C[c_gl_wr]); + c_scalar_t2* sh_red_half2 = + reinterpret_cast(&sh_red[c_sh_rd]); #pragma unroll for (int a = 0; a < 4; a++) { atomicAdd(&C_half2[a], sh_red_half2[a]); @@ -1635,7 +1917,13 @@ __global__ void Marlin( wait_for_stage(); init_same_group(pipe % stages); } - matmul(k); + + if constexpr (!is_a_8bit) { + matmul(k, pipe - (k >= b_sh_wr_iters - 2 ? 1 : 0)); + } else { + static_assert(group_blocks != 0 && group_blocks != 1); + matmul_a8(k); + } } slice_iters--; if (slice_iters == 0) { @@ -1668,13 +1956,47 @@ __global__ void Marlin( // While this pattern may not be the most readable, other ways of writing // the loop seemed to noticeably worse performance after compilation. if (slice_iters == 0) { + if constexpr (is_a_8bit) { + float frag_a_s[2 * thread_m_blocks]; + + for (int i = 0; i < 2 * thread_m_blocks; i++) + frag_a_s[i] = sh_a_s[i * 8 + (threadIdx.x % 32) / 4]; + + #pragma unroll + for (int j = 0; j < 2; j++) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int g = 0; g < 4; g++) { + float c_val = frag_c[i][j][0][g]; + + if constexpr (a_type == vllm::kS8) { + c_val = __int2float_rn(*reinterpret_cast(&c_val)); + } + float s_val = frag_a_s[i * 2 + g / 2]; + frag_c[i][j][0][g] = c_val * s_val; + } + #pragma unroll + for (int g = 0; g < 4; g++) { + float c_val = frag_c[i][j][1][g]; + + if constexpr (a_type == vllm::kS8) { + c_val = __int2float_rn(*reinterpret_cast(&c_val)); + } + float s_val = frag_a_s[i * 2 + g / 2]; + frag_c[i][j][1][g] = c_val * s_val; + } + } + } + } + cp_async_wait<0>(); bool last = slice_idx == slice_count - 1; // For per-column scales, we only fetch them here in the final step before // write-out if constexpr (!has_act_order && group_blocks == -1 && (has_zp && dequant_skip_flop || !has_zp)) { - if (w_type.size_bits() == 8 || (last || use_atomic_add)) { + if (b_type.size_bits() == 8 || (last || use_atomic_add) || is_a_8bit) { if (s_sh_wr_pred) { cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); } @@ -1692,20 +2014,27 @@ __global__ void Marlin( } if constexpr (!has_act_order && group_blocks == -1 && - (has_zp && dequant_skip_flop || !has_zp)) { - if (w_type.size_bits() == 8 || (last || use_atomic_add)) { + (has_zp && dequant_skip_flop || !has_zp || is_a_8bit)) { + if constexpr (is_a_8bit) { cp_async_wait<0>(); __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) { + if (threadIdx.x / 32 < tb_n_warps) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + } + } else if (b_type.size_bits() == 8 || (last || use_atomic_add)) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < tb_n_warps) { reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; if constexpr (m_block_size_8) { int idx = (threadIdx.x / 4) % 2; - scalar_t2* frag_s_half2 = reinterpret_cast(frag_s); + c_scalar_t2* frag_s_half2 = + reinterpret_cast(frag_s); #pragma unroll for (int i = 0; i < 8; i++) { - frag_s_half2[i] = Dtype::num2num2( - reinterpret_cast(&frag_s_half2[i])[idx]); + frag_s_half2[i] = Cdtype::num2num2( + reinterpret_cast(&frag_s_half2[i])[idx]); } } } @@ -1715,26 +2044,48 @@ 
__global__ void Marlin( // For 8-bit channelwise, we apply the scale before the global reduction // that converts the fp32 results to fp16 (so that we avoid possible // overflow in fp16) - if constexpr (!has_act_order && group_blocks == -1 && - w_type.size_bits() == 8 && - (has_zp && dequant_skip_flop || !has_zp)) { - if (threadIdx.x / 32 < thread_n_blocks / 4) { + if constexpr (!has_act_order && group_blocks == -1 && is_a_8bit) { + #pragma unroll + for (int j = 0; j < 2; j++) { + float2 aa[2]; + aa[0] = Cdtype::num22float2(frag_s[0][j * 2][0]); + aa[1] = Cdtype::num22float2(frag_s[0][j * 2 + 1][0]); + + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int g = 0; g < 4; g++) { + float scale = reinterpret_cast(&aa[0])[g % 2]; + frag_c[i][j][0][g] *= scale; + } + + #pragma unroll + for (int g = 0; g < 4; g++) { + float scale = reinterpret_cast(&aa[1])[g % 2]; + frag_c[i][j][1][g] *= scale; + } + } + } + } else if (!has_act_order && group_blocks == -1 && + b_type.size_bits() == 8 && + (has_zp && dequant_skip_flop || !has_zp)) { + if (threadIdx.x / 32 < tb_n_warps) { #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { #pragma unroll for (int j = 0; j < 4; j++) { - scale_float( + scale_float( reinterpret_cast(&frag_c[i][j][0][0]), frag_s[j / 2][2 * (j % 2) + 0]); - scale_float( + scale_float( reinterpret_cast(&frag_c[i][j][0][2]), frag_s[j / 2][2 * (j % 2) + (m_block_size_8 ? 1 : 0)]); if constexpr (!m_block_size_8) { - scale_float( + scale_float( reinterpret_cast(&frag_c[i][j][1][0]), frag_s[j / 2][2 * (j % 2) + 1]); - scale_float( + scale_float( reinterpret_cast(&frag_c[i][j][1][2]), frag_s[j / 2][2 * (j % 2) + 1]); } @@ -1758,7 +2109,8 @@ __global__ void Marlin( cp_async_wait<0>(); __syncthreads(); reinterpret_cast(&frag_bias)[0] = sh_bias[bias_sh_rd]; - reinterpret_cast(&frag_bias)[1] = sh_bias[bias_sh_rd + 4]; + if constexpr (!is_a_8bit) + reinterpret_cast(&frag_bias)[1] = sh_bias[bias_sh_rd + 4]; __syncthreads(); } @@ -1768,21 +2120,22 @@ __global__ void Marlin( // only the last block in a slice actually writes the result write_result(last); slice_row = 0; - slice_col_par++; - slice_col++; + if (!in_part2) { + slice_col_par += gridDim.x; + } else { + slice_col_par++; + slice_col++; + } is_first_matmul_in_slice = true; init_slice(); if (slice_iters) { a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; - if (slice_col == 0) { - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; - } + a_gl_rd += a_gl_rd_delta_o * slice_row; + b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride) + + (threadIdx.x % b_sh_stride); + b_gl_rd += b_sh_stride * slice_col + b_gl_rd_delta_o * slice_row; bias_gl_rd = (thread_n_blocks * 16 / 8) * slice_col + threadIdx.x; // Update slice k/n for scales loading @@ -1791,12 +2144,28 @@ __global__ void Marlin( slice_k_finish = slice_k_start + tb_k * slice_iters; slice_k_start_shared_fetch = slice_k_start; slice_n_offset = act_s_col_tb_stride * slice_col; - } else { - s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; + if constexpr (group_blocks == -1) { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; + } else if constexpr (group_blocks >= thread_k_blocks) { + s_gl_rd = + s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + 
s_sh_stride * slice_col + threadIdx.x; + zp_gl_rd = + zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + zp_sh_stride * slice_col + threadIdx.x; + } else { + s_gl_rd = + s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks + + threadIdx.x / s_sh_stride) + + s_sh_stride * slice_col + threadIdx.x % s_sh_stride; + zp_gl_rd = + zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks + + threadIdx.x / zp_sh_stride) + + zp_sh_stride * slice_col + threadIdx.x % zp_sh_stride; + } } - start_pipes(); } } diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index e9c96bb8b56cf..914227838558a 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -298,9 +298,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // gptq_marlin Optimized Quantized GEMM for GPTQ. ops.def( "gptq_marlin_gemm(Tensor a, Tensor? c_or_none, Tensor b_q_weight, " - "Tensor? b_bias_or_none," - "Tensor b_scales, Tensor? global_scale, Tensor? b_zeros_or_none, Tensor? " - "g_idx_or_none, Tensor? perm_or_none, Tensor workspace, int b_q_type, " + "Tensor? b_bias_or_none,Tensor b_scales, " + "Tensor? a_scales, Tensor? global_scale, Tensor? b_zeros_or_none, " + "Tensor? " + "g_idx_or_none, Tensor? perm_or_none, Tensor workspace, int b_type_id, " "SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, " "bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) -> Tensor"); // conditionally compiled so impl registration is in source file @@ -308,13 +309,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // gptq_marlin repack from GPTQ. ops.def( "gptq_marlin_repack(Tensor b_q_weight, Tensor perm, " - "SymInt size_k, SymInt size_n, int num_bits) -> Tensor"); + "SymInt size_k, SymInt size_n, int num_bits, bool is_a_8bit) -> Tensor"); // conditionally compiled so impl registrations are in source file // awq_marlin repack from AWQ. ops.def( "awq_marlin_repack(Tensor b_q_weight, SymInt size_k, " - "SymInt size_n, int num_bits) -> Tensor"); + "SymInt size_n, int num_bits, bool is_a_8bit) -> Tensor"); + // conditionally compiled so impl registrations are in source file + + // preprocess W-int4A-fp8 weight for marlin kernel + ops.def( + "marlin_int4_fp8_preprocess(Tensor qweight, " + "Tensor? qzeros_or_none, bool inplace) -> Tensor"); // conditionally compiled so impl registrations are in source file // CUTLASS w4a8 GEMM diff --git a/docs/design/moe_kernel_features.md b/docs/design/moe_kernel_features.md index e54a9e2bc5e77..44aaa65218cc4 100644 --- a/docs/design/moe_kernel_features.md +++ b/docs/design/moe_kernel_features.md @@ -60,7 +60,7 @@ Modular kernels are supported by the following `FusedMoEMethodBase` classes. 
- [`ModelOptFp8MoEMethod`][vllm.model_executor.layers.quantization.modelopt.ModelOptFp8MoEMethod] - [`Fp8MoEMethod`][vllm.model_executor.layers.quantization.fp8.Fp8MoEMethod] -- [`CompressedTensorsW4A4Nvfp4MoeMethod`][vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.CompressedTensorsW4A4Nvfp4MoeMethod] +- [`CompressedTensorsW4A4Nvfp4MoEMethod`][vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.CompressedTensorsW4A4Nvfp4MoEMethod] - [`CompressedTensorsW8A8Fp8MoEMethod`][vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.CompressedTensorsW8A8Fp8MoEMethod] - [`Mxfp4MoEMethod`][vllm.model_executor.layers.quantization.mxfp4.Mxfp4MoEMethod] - [`UnquantizedFusedMoEMethod`][vllm.model_executor.layers.fused_moe.layer.UnquantizedFusedMoEMethod] diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index 0550c2d9e2125..bacf6f37f2b08 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -21,7 +21,7 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock import vllm.model_executor.layers.fused_moe # noqa from tests.kernels.moe.utils import fused_moe -from tests.kernels.utils import opcheck, stack_and_dev, torch_moe +from tests.kernels.utils import opcheck, stack_and_dev, torch_experts, torch_moe from vllm._aiter_ops import rocm_aiter_ops from vllm.config import VllmConfig, set_current_vllm_config from vllm.distributed.parallel_state import init_distributed_environment @@ -65,6 +65,64 @@ NUM_EXPERTS = [8, 64, 192] EP_SIZE = [1, 4] TOP_KS = [2, 6] +MOE_MARLIN_QUANT_TEST_CONFIGS = [ + # AWQ-INT4 + {"b_type": scalar_types.uint4, "group_blocks": [-1, 2, 4, 8]}, + # GPTQ-INT4 + { + "b_type": scalar_types.uint4b8, + "support_act_order": True, + "group_blocks": [-1, 2, 4, 8], + }, + # GPTQ-INT8 + { + "b_type": scalar_types.uint8b128, + "support_act_order": True, + "group_blocks": [-1, 2, 4, 8], + }, + # FP8 + {"b_type": scalar_types.float8_e4m3fn, "group_blocks": [-1, 8]}, + # NVFP4 + {"b_type": scalar_types.float4_e2m1f, "group_blocks": [1]}, + # MXFP4 + { + "a_type": [scalar_types.bfloat16], + "b_type": scalar_types.float4_e2m1f, + "group_blocks": [2], + }, + # AWQ-INT4 with INT8 activation + { + "a_type": [scalar_types.int8], + "b_type": scalar_types.uint4, + "group_blocks": [-1, 2, 4, 8], + }, + # GPTQ-INT4 with INT8 activation + { + "a_type": [scalar_types.int8], + "b_type": scalar_types.uint4b8, + "group_blocks": [-1, 2, 4, 8], + }, + # GPTQ-INT4 with FP8 activation + { + "a_type": [scalar_types.float8_e4m3fn], + "b_type": scalar_types.uint4b8, + "group_blocks": [-1, 2, 4, 8], + }, + # AWQ-INT4 with FP8 activation + { + "a_type": [scalar_types.float8_e4m3fn], + "b_type": scalar_types.uint4, + "group_blocks": [-1, 2, 4, 8], + }, + # MXFP4 with FP8 activation + { + "a_type": [scalar_types.float8_e4m3fn], + "b_type": scalar_types.float4_e2m1f, + "c_type": [scalar_types.bfloat16], + "group_blocks": [2], + }, +] + FUSED_MOE_MNK_FACTORS = [ (1, 128, 128), (1, 2048, 128), @@ -505,63 +563,74 @@ def marlin_moe_generate_valid_test_cases(): m_list = [1, 123, 666] n_list = [128, 1024] k_list = [256, 2048] - e_list = [4, 12] + e_list = [5, 12] topk_list = [2, 3] ep_size_list = [1, 4] - dtype_list = [torch.bfloat16] - group_size_list = [-1, 32, 128] act_order_list = [True, False] - quant_type_list = [ - scalar_types.float4_e2m1f, - scalar_types.float8_e4m3fn, - scalar_types.uint4, - scalar_types.uint4b8, - scalar_types.uint8b128, - ] is_k_full_list = 
[True, False] all_combinations = itertools.product( + MOE_MARLIN_QUANT_TEST_CONFIGS, m_list, n_list, k_list, e_list, topk_list, ep_size_list, - dtype_list, - group_size_list, act_order_list, - quant_type_list, is_k_full_list, ) def is_invalid( - m, n, k, e, topk, ep_size, dtype, group_size, act_order, quant_type, is_k_full + a_type, + b_type, + c_type, + group_blocks, + m, + n, + k, + e, + topk, + ep_size, + act_order, + is_k_full, ): - if quant_type == scalar_types.float8_e4m3fn and group_size not in [-1, 128]: - return False - if quant_type == scalar_types.float4_e2m1f: - if group_size not in [16, 32]: - return False - if dtype == torch.float16 and group_size == 32: - return False - if quant_type != scalar_types.float4_e2m1f and group_size == 16: + group_size = group_blocks if group_blocks <= 0 else group_blocks * 16 + if group_size > 0 and k % group_size != 0: return False - # Filter act_order - if act_order: - if group_size in (-1, k, n): - return False - if quant_type not in [scalar_types.uint4b8]: - return False - elif not is_k_full: + if act_order and group_size in [-1, k, n]: + return False + if group_size in [k, n]: + return False + if not act_order and is_k_full: return False - return True + return a_type.size_bits < 16 or a_type is c_type cases = [] for case in all_combinations: - if is_invalid(*case): - cases.append(case) + quant_test_config, m, n, k, _, _, _, act_order, *_ = case + if act_order and not quant_test_config.get("support_act_order", False): + continue + + f16_types = [scalar_types.float16] + inner_combinations = itertools.product( + quant_test_config.get("a_type", f16_types), + [quant_test_config["b_type"]], + quant_test_config.get("c_type", f16_types), + quant_test_config["group_blocks"], + ) + + for sub_case in inner_combinations: + if ( + sub_case[0] == scalar_types.float8_e4m3fn + and current_platform.get_device_capability() not in [89, 120] + ): + continue + args = sub_case + (m, n, k) + case[4:] + if is_invalid(*args): + cases.append(args) return cases @@ -571,6 +640,7 @@ class MarlinMoEWeightData: qweight: torch.Tensor scales: torch.Tensor global_scale: torch.Tensor | None + a_scales_factor: torch.Tensor | None g_idx: torch.Tensor | None zeros: torch.Tensor | None sort_indices: torch.Tensor | None @@ -583,11 +653,20 @@ class MarlinMoEWeightData: group_size: int, act_order: bool | None = None, bias: torch.Tensor | None = None, + input_type: ScalarType = None, ) -> "MarlinMoEWeightData": assert w.ndim == 3 + has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8] k = w.shape[-1] + if input_type == scalar_types.int8: + input_dtype = torch.int8 + elif input_type == scalar_types.float8_e4m3fn: + input_dtype = torch.float8_e4m3fn + else: + input_dtype = w.dtype + w_ref_l: list[torch.Tensor] = [] qweight_l: list[torch.Tensor] = [] scales_l: list[torch.Tensor] = [] @@ -601,11 +680,13 @@ class MarlinMoEWeightData: if quant_type == scalar_types.float4_e2m1f: if group_size == 16: w_ref, qweight, scales, global_scale = ( - rand_marlin_weight_nvfp4_like(w[i], group_size) + rand_marlin_weight_nvfp4_like( + w[i], group_size, input_dtype=input_dtype + ) ) else: w_ref, qweight, scales = rand_marlin_weight_mxfp4_like( - w[i], group_size + w[i], group_size, input_dtype=input_dtype ) global_scale = None @@ -615,13 +696,18 @@ class MarlinMoEWeightData: if global_scale is not None: global_scale_l.append(global_scale) elif quant_type == scalar_types.float8_e4m3fn: - w_ref, qweight, scales = marlin_quant_fp8_torch(w[i], group_size) + w_ref, qweight, scales = 
marlin_quant_fp8_torch( + w[i], group_size, input_dtype=input_dtype + ) w_ref_l.append(w_ref.T) qweight_l.append(qweight) scales_l.append(scales) elif has_zp: w_ref, qweight, scales, zeros = awq_marlin_quantize( - w[i].transpose(1, 0), quant_type, group_size + w[i].transpose(1, 0), + quant_type, + group_size, + input_dtype=input_dtype, ) w_ref_l.append(w_ref.T) @@ -631,7 +717,12 @@ class MarlinMoEWeightData: else: test_perm = torch.randperm(k) w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize( - w[i].transpose(1, 0), quant_type, group_size, act_order, test_perm + w[i].transpose(1, 0), + quant_type, + group_size, + act_order, + test_perm, + input_dtype=input_dtype, ) w_ref_l.append(w_ref.T) @@ -652,11 +743,18 @@ class MarlinMoEWeightData: sort_indices = stack_and_dev(sort_indices_l) if sort_indices_l else None marlin_bias = stack_and_dev(bias_l) if bias_l else None + a_scales_factor = None + if input_type == scalar_types.int8 and group_size != -1: + a_scales_factor = 1 / 4096 * scales.max().float() + scales = scales / scales.max() * 4096 + scales = scales.round().to(torch.int16).view(w.dtype) + return MarlinMoEWeightData( w_ref=w_ref, qweight=qweight, scales=scales, global_scale=global_scale, + a_scales_factor=a_scales_factor, g_idx=g_idx, zeros=zeros, sort_indices=sort_indices, @@ -666,28 +764,47 @@ class MarlinMoEWeightData: @pytest.mark.flaky(reruns=2) @pytest.mark.parametrize( - ("m, n, k, e, topk, ep_size, dtype, group_size,act_order, quant_type, is_k_full"), + ( + "a_type, b_type, c_type, group_blocks," + "m, n, k, e, topk, ep_size, act_order, is_k_full" + ), marlin_moe_generate_valid_test_cases(), ) @pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm") def test_fused_marlin_moe( - m: int, - n: int, - k: int, - e: int, - topk: int, - ep_size: int, - dtype: torch.dtype, - group_size: int, - act_order: bool, - quant_type: ScalarType, - is_k_full: bool, + a_type, + b_type, + c_type, + group_blocks, + m, + n, + k, + e, + topk, + ep_size, + act_order, + is_k_full, ): - torch.cuda.manual_seed(0) + torch.cuda.manual_seed(1) + group_size = group_blocks if group_blocks <= 0 else group_blocks * 16 + + if c_type == scalar_types.float16: + dtype = torch.float16 + elif c_type == scalar_types.bfloat16: + dtype = torch.bfloat16 + else: + raise RuntimeError("unsupported c_type") + + if a_type == scalar_types.int8: + a_dtype = torch.int8 + elif a_type == scalar_types.float8_e4m3fn: + a_dtype = torch.float8_e4m3fn + else: + a_dtype = dtype a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 - w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 20 - w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 20 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 if ep_size > 1: local_e = e // ep_size @@ -700,11 +817,19 @@ def test_fused_marlin_moe( e_map = None w1_data = MarlinMoEWeightData.make( - w=w1, quant_type=quant_type, group_size=group_size, act_order=act_order + w=w1, + quant_type=b_type, + group_size=group_size, + act_order=act_order, + input_type=a_type, ) w2_data = MarlinMoEWeightData.make( - w=w2, quant_type=quant_type, group_size=group_size, act_order=act_order + w=w2, + quant_type=b_type, + group_size=group_size, + act_order=act_order, + input_type=a_type, ) score = torch.randn((m, e), device="cuda", dtype=dtype) @@ -712,8 +837,18 @@ def test_fused_marlin_moe( topk_weights, topk_ids, _ = fused_topk(a, score, topk, False) with set_current_vllm_config(vllm_config): - 
torch_output = torch_moe( - a, w1_data.w_ref, w2_data.w_ref, score, topk, expert_map=e_map + score = torch.softmax(score, dim=-1, dtype=torch.float32) + topk_weight, topk_ids = torch.topk(score, topk) + torch_output = torch_experts( + a, + w1_data.w_ref, + w2_data.w_ref, + topk_weight=topk_weight, + topk_ids=topk_ids, + global_num_experts=e, + expert_map=e_map, + quant_dtype=a_dtype, + per_act_token_quant=True, ) marlin_output = fused_marlin_moe( @@ -733,15 +868,18 @@ def test_fused_marlin_moe( global_scale2=w2_data.global_scale, g_idx1=w1_data.g_idx, g_idx2=w2_data.g_idx, + input_global_scale1=w1_data.a_scales_factor, + input_global_scale2=w2_data.a_scales_factor, sort_indices1=w1_data.sort_indices, sort_indices2=w2_data.sort_indices, w1_zeros=w1_data.zeros, w2_zeros=w2_data.zeros, - quant_type_id=quant_type.id, + input_dtype=a_dtype, + quant_type_id=b_type.id, is_k_full=is_k_full, ) - torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0) + torch.testing.assert_close(marlin_output, torch_output, atol=4e-2, rtol=0) @pytest.mark.flaky(reruns=2) diff --git a/tests/kernels/quantization/test_marlin_gemm.py b/tests/kernels/quantization/test_marlin_gemm.py index 0833115fcf301..59516db1b115d 100644 --- a/tests/kernels/quantization/test_marlin_gemm.py +++ b/tests/kernels/quantization/test_marlin_gemm.py @@ -5,6 +5,8 @@ Run `pytest tests/kernels/quantization/test_marlin_gemm.py`. """ +import itertools + import pytest import torch @@ -17,8 +19,10 @@ from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES, ) +from vllm.model_executor.layers.quantization.utils.int8_utils import ( + per_token_quant_int8, +) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - MARLIN_SUPPORTED_GROUP_SIZES, marlin_make_empty_g_idx, marlin_make_workspace_new, marlin_permute_bias, @@ -26,7 +30,6 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import ( query_marlin_supported_quant_types, ) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( - FP4_MARLIN_SUPPORTED_GROUP_SIZES, rand_marlin_weight_mxfp4_like, rand_marlin_weight_nvfp4_like, ) @@ -50,6 +53,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( quantize_weights, sort_weights, ) +from vllm.platforms import current_platform from vllm.scalar_type import scalar_types ACT_ORDER_OPTS = [False, True] @@ -65,6 +69,12 @@ MARLIN_24_N_CHUNKS = [512] HQQ_SUPPORTED_GROUP_SIZES = [64] +MARLIN_REPACK_NK_FACTORS = [ + (4, 8), + (7, 5), + (13, 11), +] + MNK_FACTORS = [ (1, 1, 1), (1, 4, 8), @@ -74,6 +84,64 @@ MNK_FACTORS = [ DTYPES = [torch.float16, torch.bfloat16] +DENSE_MARLIN_QUANT_TEST_CONFIGS = [ + # AWQ-INT4 + {"b_type": scalar_types.uint4, "group_blocks": [-1, 2, 4, 8]}, + # GPTQ-INT4 + { + "b_type": scalar_types.uint4b8, + "support_act_order": True, + "group_blocks": [-1, 2, 4, 8], + }, + # GPTQ-INT8 + { + "b_type": scalar_types.uint8b128, + "support_act_order": True, + "group_blocks": [-1, 2, 4, 8], + }, + # FP8 + {"b_type": scalar_types.float8_e4m3fn, "group_blocks": [-1, 8]}, + # NVFP4 + {"b_type": scalar_types.float4_e2m1f, "group_blocks": [1]}, + # MXFP4 + { + "a_type": [scalar_types.bfloat16], + "b_type": scalar_types.float4_e2m1f, + "group_blocks": [2], + }, + # AWQ-INT4 with INT8 activation + { + "a_type": [scalar_types.int8], + "b_type": scalar_types.uint4, + "group_blocks": [-1, 2, 4, 8], + }, + # GPTQ-INT4 with INT8 activation + { + "a_type": [scalar_types.int8], + 
"b_type": scalar_types.uint4b8, + "group_blocks": [-1, 2, 4, 8], + }, + # GPTQ-INT4 with FP8 activation + { + "a_type": [scalar_types.float8_e4m3fn], + "b_type": scalar_types.uint4b8, + "group_blocks": [-1, 2, 4, 8], + }, + # AWQ-INT4 with FP8 activation + { + "a_type": [scalar_types.float8_e4m3fn], + "b_type": scalar_types.uint4, + "group_blocks": [-1, 2, 4, 8], + }, + # MXFP4 with FP8 activation + { + "a_type": [scalar_types.float8_e4m3fn], + "b_type": scalar_types.float4_e2m1f, + "c_type": [scalar_types.bfloat16], + "group_blocks": [2], + }, +] + def compute_max_diff(output, output_ref): return torch.mean(torch.abs(output - output_ref)) / torch.mean( @@ -85,6 +153,58 @@ def rand_data(shape, dtype=torch.float16): return torch.randn(shape, dtype=dtype, device="cuda") +@pytest.mark.skipif( + not is_quant_method_supported("gptq_marlin"), + reason="Marlin is not supported on this GPU type.", +) +def test_marlin_int4_fp8_preprocess_without_zp(): + qweight_unpacked = torch.randint( + 0, 16, size=(2048, 2048), dtype=torch.int32, device="cuda" + ) + qweight_packed = qweight_unpacked[:, ::2] * 16 + qweight_unpacked[:, 1::2] + qweight_packed = qweight_packed.to(torch.int8).view(torch.int32) + + cuda_res = ops.marlin_int4_fp8_preprocess(qweight_packed) + + torch_res = torch.where( + qweight_unpacked >= 8, qweight_unpacked - 8, 15 - qweight_unpacked + ) + torch_res = torch_res[:, ::2] * 16 + torch_res[:, 1::2] + torch_res = torch_res.to(torch.int8).view(torch.int32) + + assert (cuda_res == torch_res).all() + + +@pytest.mark.skipif( + not is_quant_method_supported("gptq_marlin"), + reason="Marlin is not supported on this GPU type.", +) +def test_marlin_int4_fp8_preprocess_awq(): + group_size = 128 + + qweight_unpacked = torch.randint( + 0, 16, size=(2048, 2048), dtype=torch.int32, device="cuda" + ) + qzeros_unpacked = torch.randint( + 0, 16, size=(2048 // group_size, 2048), dtype=torch.int32, device="cuda" + ) + + qweight_packed = qweight_unpacked[:, ::2] * 16 + qweight_unpacked[:, 1::2] + qweight_packed = qweight_packed.to(torch.int8).view(torch.int32) + qzeros_packed = qzeros_unpacked[:, ::2] * 16 + qzeros_unpacked[:, 1::2] + qzeros_packed = qzeros_packed.to(torch.int8).view(torch.int32) + + cuda_res = ops.marlin_int4_fp8_preprocess(qweight_packed, qzeros_packed) + + repeated_zp = qzeros_unpacked.repeat_interleave(group_size, 0) + torch_res = qweight_unpacked - repeated_zp + torch_res[torch_res < 0] = 15 - qweight_unpacked[torch_res < 0] + torch_res = torch_res[:, ::2] * 16 + torch_res[:, 1::2] + torch_res = torch_res.to(torch.int8).view(torch.int32) + + assert (cuda_res == torch_res).all() + + @pytest.mark.skipif( not is_quant_method_supported("gptq_marlin"), reason="Marlin is not supported on this GPU type.", @@ -92,16 +212,17 @@ def rand_data(shape, dtype=torch.float16): @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) @pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) @pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types(False, False)) -@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES) @pytest.mark.parametrize("act_order", ACT_ORDER_OPTS) -@pytest.mark.parametrize("mnk_factors", MNK_FACTORS) +@pytest.mark.parametrize("is_a_8bit", [True, False]) +@pytest.mark.parametrize("nk_factors", MARLIN_REPACK_NK_FACTORS) def test_gptq_marlin_repack( - k_chunk, n_chunk, quant_type, group_size, act_order, mnk_factors + k_chunk, n_chunk, quant_type, act_order, is_a_8bit, nk_factors ): - m_factor, n_factor, k_factor = mnk_factors + n_factor, k_factor = nk_factors 
size_k = k_chunk * k_factor size_n = n_chunk * n_factor + group_size = 128 # Filter act_order if act_order: @@ -109,6 +230,8 @@ def test_gptq_marlin_repack( return if group_size == size_k: return + if is_a_8bit: + return # Normalize group_size if group_size == -1: @@ -133,23 +256,19 @@ def test_gptq_marlin_repack( q_w, g_idx, sort_indices = sort_weights(q_w, g_idx) # Pack to Marlin format - weight_perm = get_weight_perm(quant_type.size_bits) + weight_perm = get_weight_perm(quant_type.size_bits, is_a_8bit) marlin_q_w_1 = marlin_weights( - q_w, size_k, size_n, quant_type.size_bits, weight_perm + q_w, size_k, size_n, quant_type.size_bits, weight_perm, is_a_8bit ) opcheck( torch.ops._C.gptq_marlin_repack, - (q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits), + (q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits, is_a_8bit), ) # Run Marlin repack GPU kernel marlin_q_w_2 = ops.gptq_marlin_repack( - q_w_gptq, - sort_indices, - size_k, - size_n, - quant_type.size_bits, + q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits, is_a_8bit ) torch.cuda.synchronize() @@ -163,18 +282,15 @@ def test_gptq_marlin_repack( @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) @pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) @pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types(True)) -@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES) -@pytest.mark.parametrize("mnk_factors", MNK_FACTORS) -def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, mnk_factors): - m_factor, n_factor, k_factor = mnk_factors +@pytest.mark.parametrize("is_a_8bit", [True, False]) +@pytest.mark.parametrize("nk_factors", MARLIN_REPACK_NK_FACTORS) +def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, is_a_8bit, nk_factors): + n_factor, k_factor = nk_factors size_k = k_chunk * k_factor size_n = n_chunk * n_factor - # Normalize group_size - if group_size == -1: - group_size = size_k - assert group_size <= size_k + group_size = 128 # Create input b_weight = rand_data((size_k, size_n)) @@ -188,162 +304,221 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, mnk_factors q_w_awq = awq_pack(q_w, quant_type.size_bits, size_k, size_n) # Pack to Marlin format - weight_perm = get_weight_perm(quant_type.size_bits) + weight_perm = get_weight_perm(quant_type.size_bits, is_a_8bit) marlin_q_w_1 = marlin_weights( - q_w, size_k, size_n, quant_type.size_bits, weight_perm + q_w, size_k, size_n, quant_type.size_bits, weight_perm, is_a_8bit ) opcheck( - torch.ops._C.awq_marlin_repack, (q_w_awq, size_k, size_n, quant_type.size_bits) + torch.ops._C.awq_marlin_repack, + (q_w_awq, size_k, size_n, quant_type.size_bits, is_a_8bit), ) # Run Marlin repack GPU kernel marlin_q_w_2 = ops.awq_marlin_repack( - q_w_awq, - size_k, - size_n, - quant_type.size_bits, + q_w_awq, size_k, size_n, quant_type.size_bits, is_a_8bit ) torch.cuda.synchronize() torch.testing.assert_close(marlin_q_w_1, marlin_q_w_2) +def marlin_generate_valid_test_cases(): + all_combinations = itertools.product( + DENSE_MARLIN_QUANT_TEST_CONFIGS, + MNK_FACTORS, + MARLIN_N_CHUNKS, + MARLIN_K_CHUNKS, + ACT_ORDER_OPTS, + K_FULL_OPTS, + USE_ATOMIC_ADD_OPTS, + USE_FP32_REDUCE_OPTS, + ) + + def is_invalid( + a_type, + b_type, + c_type, + group_blocks, + size_m, + size_n, + size_k, + act_order, + is_k_full, + use_atomic_add, + use_fp32_reduce, + ): + if use_atomic_add: + if use_fp32_reduce: + return False + if ( + c_type == scalar_types.bfloat16 + and torch.cuda.get_device_capability()[0] < 9 + ): + return 
False + + group_size = group_blocks if group_blocks <= 0 else group_blocks * 16 + if group_size > 0 and size_k % group_size != 0: + return False + + if act_order and group_size in [-1, size_k]: + return False + if group_size == size_k: + return False + if not act_order and is_k_full: + return False + + return a_type.size_bits < 16 or a_type is c_type + + cases = [] + for case in all_combinations: + quant_test_config, mnk_factors, n_chunk, k_chunk, act_order, *_ = case + size_m = mnk_factors[0] + size_n = mnk_factors[1] * n_chunk + size_k = mnk_factors[2] * k_chunk + + if act_order and not quant_test_config.get("support_act_order", False): + continue + + f16_types = [scalar_types.float16, scalar_types.bfloat16] + inner_combinations = itertools.product( + quant_test_config.get("a_type", f16_types), + [quant_test_config["b_type"]], + quant_test_config.get("c_type", f16_types), + quant_test_config["group_blocks"], + ) + + for sub_case in inner_combinations: + if ( + sub_case[0] == scalar_types.float8_e4m3fn + and current_platform.get_device_capability() not in [89, 120] + ): + continue + args = sub_case + (size_m, size_n, size_k) + case[4:] + if is_invalid(*args): + cases.append(args) + return cases + + @pytest.mark.skipif( not is_quant_method_supported("gptq_marlin"), reason="Marlin is not supported on this GPU type.", ) -@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) -@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) -@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types()) @pytest.mark.parametrize( - "group_size", set(MARLIN_SUPPORTED_GROUP_SIZES + FP4_MARLIN_SUPPORTED_GROUP_SIZES) + ( + "a_type, b_type, c_type, group_blocks," + "size_m, size_n, size_k, act_order, is_k_full," + "use_atomic_add, use_fp32_reduce" + ), + marlin_generate_valid_test_cases(), ) -@pytest.mark.parametrize("mnk_factors", MNK_FACTORS) -@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS) -@pytest.mark.parametrize("is_k_full", K_FULL_OPTS) -@pytest.mark.parametrize("use_atomic_add", USE_ATOMIC_ADD_OPTS) -@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS) -@pytest.mark.parametrize("dtype", DTYPES) def test_gptq_marlin_gemm( - k_chunk, - n_chunk, - quant_type, - group_size, - mnk_factors, + a_type, + b_type, + c_type, + group_blocks, + size_m, + size_n, + size_k, act_order, is_k_full, use_atomic_add, use_fp32_reduce, - dtype, ): - m_factor, n_factor, k_factor = mnk_factors - has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8] + has_zp = b_type in [scalar_types.uint4, scalar_types.uint8] - size_m = m_factor - size_k = k_chunk * k_factor - size_n = n_chunk * n_factor + group_size = group_blocks if group_blocks <= 0 else group_blocks * 16 - if act_order: - if group_size == -1: - return - if group_size == size_k: - return - if has_zp: - return + if c_type == scalar_types.float16: + dtype = torch.float16 + elif c_type == scalar_types.bfloat16: + dtype = torch.bfloat16 + else: + raise RuntimeError("unsupported c_type") - if size_k % group_size != 0: - return + if a_type == scalar_types.int8: + a_dtype = torch.int8 + elif a_type == scalar_types.float8_e4m3fn: + a_dtype = torch.float8_e4m3fn + else: + a_dtype = dtype - a_input = rand_data((size_m, size_k), dtype) - b_weight = rand_data((size_k, size_n), dtype) - - if quant_type == scalar_types.float4_e2m1f: - if group_size not in [16, 32] or act_order: - return - if group_size == 32 and dtype == torch.float16: - return + a_input = rand_data((size_m, size_k), dtype=dtype) + b_weight = rand_data((size_k, size_n), 
dtype=dtype) + if b_type == scalar_types.float4_e2m1f: if group_size == 16: w_ref, marlin_q_w, marlin_s, marlin_s2 = rand_marlin_weight_nvfp4_like( - b_weight.T, group_size + b_weight.T, group_size, input_dtype=a_dtype ) else: w_ref, marlin_q_w, marlin_s = rand_marlin_weight_mxfp4_like( - b_weight.T, group_size + b_weight.T, group_size, input_dtype=a_dtype ) marlin_s2 = None g_idx = None sort_indices = None marlin_zp = None - elif quant_type == scalar_types.float8_e4m3fn: - if group_size not in [-1, 128]: - return - if act_order: - return - w_ref, marlin_q_w, marlin_s = marlin_quant_fp8_torch(b_weight.T, group_size) + elif b_type == scalar_types.float8_e4m3fn: + w_ref, marlin_q_w, marlin_s = marlin_quant_fp8_torch( + b_weight.T, group_size, input_dtype=a_dtype + ) g_idx = None sort_indices = None marlin_zp = None marlin_s2 = None elif has_zp: - if group_size == 16: - return w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize( - b_weight, quant_type, group_size + b_weight, b_type, group_size, input_dtype=a_dtype ) g_idx = None sort_indices = None marlin_s2 = None else: - if group_size == 16: - return w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize( - b_weight, quant_type, group_size, act_order + b_weight, b_type, group_size, act_order, input_dtype=a_dtype ) + marlin_zp = None marlin_s2 = None workspace = marlin_make_workspace_new(w_ref.device) - opcheck( - torch.ops._C.gptq_marlin_gemm, - ( - a_input, - None, - marlin_q_w, - None, - marlin_s, - marlin_s2, - marlin_zp, - g_idx, - sort_indices, - workspace, - quant_type.id, - a_input.shape[0], - b_weight.shape[1], - a_input.shape[1], - is_k_full, - use_atomic_add, - use_fp32_reduce, - False, - ), - test_utils=DEFAULT_OPCHECK_TEST_UTILS, - ) + if a_type == scalar_types.int8: + a_input, a_scales = per_token_quant_int8(a_input) + a_input_ref = a_input.to(a_scales.dtype) * a_scales.view(-1, 1) + a_input_ref = a_input_ref.to(dtype) + + if group_size != -1: + a_scales = a_scales / 4096 * marlin_s.max() + a_scales = a_scales.float() + marlin_s = marlin_s / marlin_s.max() * 4096 + marlin_s = marlin_s.round().to(torch.int16).view(dtype) + elif a_type == scalar_types.float8_e4m3fn: + a_input, a_scales = ops.scaled_fp8_quant(a_input, use_per_token_if_dynamic=True) + a_input_ref = a_input.to(a_scales.dtype) * a_scales.view(-1, 1) + a_input_ref = a_input_ref.to(dtype) + else: + assert a_type.size_bits == 16 + a_input_ref = a_input + a_scales = None + + output = torch.empty((size_m, size_n), dtype=dtype, device=a_input.device) output = ops.gptq_marlin_gemm( a_input, - None, + output, marlin_q_w, None, marlin_s, + a_scales, marlin_s2, marlin_zp, g_idx, sort_indices, workspace, - quant_type, + b_type, a_input.shape[0], b_weight.shape[1], a_input.shape[1], @@ -352,12 +527,9 @@ def test_gptq_marlin_gemm( use_fp32_reduce=use_fp32_reduce, is_zp_float=False, ) - output_ref = torch.matmul(a_input, w_ref) - - torch.cuda.synchronize() + output_ref = torch.matmul(a_input_ref, w_ref) max_diff = compute_max_diff(output, output_ref) - assert max_diff < 0.04 @@ -507,6 +679,7 @@ def test_hqq_marlin_gemm( None, marlin_s, None, + None, marlin_zp, g_idx, g_idx_sort_indices, @@ -559,6 +732,7 @@ def test_marlin_gemm_subset_input(): None, marlin_s, None, + None, marlin_zp, g_idx, sort_indices, @@ -607,6 +781,7 @@ def test_marlin_gemm_with_bias(size_m): marlin_bias, marlin_s, None, + None, marlin_zp, g_idx, sort_indices, diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 75e82f9314e74..98646442391fe 100644 --- 
a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -846,6 +846,13 @@ def torch_experts( or (expert_map is not None and global_num_experts == expert_map.shape[0]) ) + if quant_dtype in [torch.float16, torch.bfloat16]: + quant_dtype = None + quant_input_only = quant_dtype is not None and w1_scale is None and w2_scale is None + if quant_input_only: + assert a1_scale is None and a2_scale is None + assert per_act_token_quant + M, K = a.shape topk = topk_ids.shape[1] @@ -863,6 +870,9 @@ def torch_experts( a, a1_scale, quant_dtype, per_act_token_quant, block_shape ) + if quant_input_only: + a = (a.float() * a_scale.view(-1, 1)).to(w1.dtype) + num_experts = w1.shape[0] topk_ids = topk_ids.view(-1) @@ -882,6 +892,14 @@ def torch_experts( out[mask] = tmp2 @ w2[i].transpose(0, 1) if b_bias2 is not None: out[mask] = out[mask] + b_bias2[i].view(1, -1).to(tmp1.dtype) + elif quant_input_only: + tmp1 = a[mask] @ w1[i].transpose(0, 1) + tmp2 = SiluAndMul()(tmp1) + tmp2, tmp2_scale = moe_kernel_quantize_input( + tmp2, None, quant_dtype, per_act_token_quant + ) + tmp2 = (tmp2.float() * tmp2_scale.view(-1, 1)).to(w2.dtype) + out[mask] = tmp2 @ w2[i].transpose(0, 1) elif block_shape is not None: # block quantized assert ( diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 4a1bcc761f994..e60158898685a 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -554,6 +554,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"): b_q_weight: torch.Tensor, b_bias: torch.Tensor | None, b_scales: torch.Tensor, + a_scales: torch.Tensor | None, global_scale: torch.Tensor | None, b_zeros: torch.Tensor | None, g_idx: torch.Tensor | None, @@ -568,7 +569,10 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"): use_fp32_reduce: bool = False, is_zp_float: bool = False, ) -> torch.Tensor: - return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) + dtype = a.dtype + if dtype not in [torch.half, torch.bfloat16]: + dtype = b_scales.dtype + return torch.empty((size_m, size_n), device=a.device, dtype=dtype) @register_fake("_C::awq_dequantize") def _awq_dequantize_fake( @@ -1167,8 +1171,11 @@ def gptq_marlin_repack( size_k: int, size_n: int, num_bits: int, + is_a_8bit: bool = False, ) -> torch.Tensor: - return torch.ops._C.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, num_bits) + return torch.ops._C.gptq_marlin_repack( + b_q_weight, perm, size_k, size_n, num_bits, is_a_8bit + ) if hasattr(torch.ops._C, "gptq_marlin_repack"): @@ -1180,6 +1187,7 @@ if hasattr(torch.ops._C, "gptq_marlin_repack"): size_k: torch.SymInt, size_n: torch.SymInt, num_bits: int, + is_a_8bit: bool = False, ) -> torch.Tensor: pack_factor = 32 // num_bits marlin_tile_size = 16 @@ -1192,9 +1200,15 @@ if hasattr(torch.ops._C, "gptq_marlin_repack"): # awq_marlin def awq_marlin_repack( - b_q_weight: torch.Tensor, size_k: int, size_n: int, num_bits: int + b_q_weight: torch.Tensor, + size_k: int, + size_n: int, + num_bits: int, + is_a_8bit: bool = False, ) -> torch.Tensor: - return torch.ops._C.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits) + return torch.ops._C.awq_marlin_repack( + b_q_weight, size_k, size_n, num_bits, is_a_8bit + ) if hasattr(torch.ops._C, "awq_marlin_repack"): @@ -1205,6 +1219,7 @@ if hasattr(torch.ops._C, "awq_marlin_repack"): size_k: torch.SymInt, size_n: torch.SymInt, num_bits: int, + is_a_8bit: bool = False, ) -> torch.Tensor: pack_factor = 32 // num_bits marlin_tile_size = 16 @@ -1221,6 +1236,7 @@ def gptq_marlin_moe_repack( size_k: int, size_n: int, num_bits: int, + is_a_8bit: bool = False, ) -> 
torch.Tensor: num_experts = b_q_weight.shape[0] assert size_k % 16 == 0 @@ -1231,7 +1247,7 @@ def gptq_marlin_moe_repack( ) for e in range(num_experts): output[e] = torch.ops._C.gptq_marlin_repack( - b_q_weight[e], perm[e], size_k, size_n, num_bits + b_q_weight[e], perm[e], size_k, size_n, num_bits, is_a_8bit ) return output @@ -1242,6 +1258,7 @@ def awq_marlin_moe_repack( size_k: int, size_n: int, num_bits: int, + is_a_8bit: bool = False, ) -> torch.Tensor: num_experts = b_q_weight.shape[0] assert size_k % 16 == 0 @@ -1252,17 +1269,26 @@ def awq_marlin_moe_repack( ) for e in range(num_experts): output[e] = torch.ops._C.awq_marlin_repack( - b_q_weight[e], size_k, size_n, num_bits + b_q_weight[e], size_k, size_n, num_bits, is_a_8bit ) return output +def marlin_int4_fp8_preprocess( + qweight: torch.Tensor, + qzeros_or_none: torch.Tensor | None = None, + inplace: bool = False, +): + return torch.ops._C.marlin_int4_fp8_preprocess(qweight, qzeros_or_none, inplace) + + def gptq_marlin_gemm( a: torch.Tensor, c: torch.Tensor | None, b_q_weight: torch.Tensor, b_bias: torch.Tensor | None, b_scales: torch.Tensor, + a_scales: torch.Tensor | None, global_scale: torch.Tensor | None, b_zeros: torch.Tensor | None, g_idx: torch.Tensor | None, @@ -1283,6 +1309,7 @@ def gptq_marlin_gemm( b_q_weight, b_bias, b_scales, + a_scales, global_scale, b_zeros, g_idx, @@ -1600,7 +1627,7 @@ def allspark_repack_weight( if use asymmetric quantization, has_zp = True. Returns: - tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : + tuple[torch.Tensor, torch.Tensor, torch.Tensor | None] : rearranged weight, scale, and optionally zero_point. """ K = qweight.shape[0] @@ -1683,7 +1710,7 @@ def scaled_int8_quant( symmetric: Whether to use symmetric quantization (scale only, azp ignored). Returns: - tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. + tuple[torch.Tensor, torch.Tensor, torch.Tensor | None] : Output int8 tensor, scales, and optionally azp. 
""" output = torch.empty_like(input, dtype=torch.int8) if scale is not None: @@ -2004,6 +2031,7 @@ def moe_wna16_marlin_gemm( b_qweight: torch.Tensor, b_bias: torch.Tensor | None, b_scales: torch.Tensor, + a_scales: torch.Tensor | None, global_scale: torch.Tensor | None, b_qzeros: torch.Tensor | None, g_idx: torch.Tensor | None, @@ -2025,6 +2053,9 @@ def moe_wna16_marlin_gemm( use_atomic_add: bool, use_fp32_reduce: bool, is_zp_float: bool, + thread_k: int = -1, + thread_n: int = -1, + blocks_per_sm: int = -1, ) -> torch.Tensor: return torch.ops._moe_C.moe_wna16_marlin_gemm( input, @@ -2032,6 +2063,7 @@ def moe_wna16_marlin_gemm( b_qweight, b_bias, b_scales, + a_scales, global_scale, b_qzeros, g_idx, @@ -2053,6 +2085,9 @@ def moe_wna16_marlin_gemm( use_atomic_add, use_fp32_reduce, is_zp_float, + thread_k, + thread_n, + blocks_per_sm, ) @@ -2088,7 +2123,10 @@ if hasattr(torch.ops, "_moe_C") and hasattr(torch.ops._moe_C, "marlin_gemm_moe") input: torch.Tensor, output: torch.Tensor | None, b_qweight: torch.Tensor, + b_bias: torch.Tensor | None, b_scales: torch.Tensor, + a_scales: torch.Tensor | None, + global_scale: torch.Tensor | None, b_qzeros: torch.Tensor | None, g_idx: torch.Tensor | None, perm: torch.Tensor | None, @@ -2109,7 +2147,7 @@ if hasattr(torch.ops, "_moe_C") and hasattr(torch.ops._moe_C, "marlin_gemm_moe") use_atomic_add: bool, use_fp32_reduce: bool, is_zp_float: bool, - ) -> torch.Tensor: + ): return torch.empty( (size_m * top_k, size_n), dtype=input.dtype, device=input.device ) @@ -2583,7 +2621,7 @@ def onednn_scaled_int8_quant( symmetric: Whether to use symmetric quantization (scale only, azp ignored). Returns: - tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. + tuple[torch.Tensor, torch.Tensor, torch.Tensor | None] : Output int8 tensor, scales, and optionally azp. """ output = torch.empty_like(input, dtype=torch.int8) token_num = input.numel() // input.shape[-1] diff --git a/vllm/envs.py b/vllm/envs.py index 2ac457419a722..8ad62e1b8f508 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -145,6 +145,7 @@ if TYPE_CHECKING: VLLM_RANDOMIZE_DP_DUMMY_INPUTS: bool = False VLLM_RAY_DP_PACK_STRATEGY: Literal["strict", "fill", "span"] = "strict" VLLM_MARLIN_USE_ATOMIC_ADD: bool = False + VLLM_MARLIN_INPUT_DTYPE: Literal["int8", "fp8"] | None = None VLLM_MXFP4_USE_MARLIN: bool | None = None VLLM_V1_USE_OUTLINES_CACHE: bool = False VLLM_TPU_BUCKET_PADDING_GAP: int = 0 @@ -1122,6 +1123,10 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_MXFP4_USE_MARLIN": lambda: maybe_convert_bool( os.environ.get("VLLM_MXFP4_USE_MARLIN", None) ), + # The activation dtype for marlin kernel + "VLLM_MARLIN_INPUT_DTYPE": env_with_choices( + "VLLM_MARLIN_INPUT_DTYPE", None, ["int8", "fp8"] + ), # Whether to turn on the outlines cache for V1 # This cache is unbounded and on disk, so it's not safe to use in # an environment with potentially malicious users. 
diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index 0b0f59f673182..9c377db720132 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -24,7 +24,7 @@ from vllm.model_executor.layers.fused_moe.utils import _resize_cache, disable_in from vllm.model_executor.layers.quantization.utils.marlin_utils import ( marlin_make_workspace_new, marlin_moe_intermediate_size, - maybe_warn_marlin_atomic_add, + marlin_quant_input, ) from vllm.scalar_type import ScalarType, scalar_types @@ -65,6 +65,8 @@ def _fused_marlin_moe( activation_func: Callable[ [str, torch.Tensor, torch.Tensor], None ] = default_activation_func, + input_global_scale1: torch.Tensor | None = None, + input_global_scale2: torch.Tensor | None = None, global_scale1: torch.Tensor | None = None, global_scale2: torch.Tensor | None = None, g_idx1: torch.Tensor | None = None, @@ -77,6 +79,7 @@ def _fused_marlin_moe( intermediate_cache13: torch.Tensor | None = None, intermediate_cache2: torch.Tensor | None = None, output: torch.Tensor | None = None, + input_dtype: torch.dtype | None = None, is_k_full: bool = True, ) -> torch.Tensor: assert hidden_states.ndim == 2 @@ -106,18 +109,22 @@ def _fused_marlin_moe( intermediate_cache2 = _resize_cache(intermediate_cache2, (M * num_topk, N)) - maybe_warn_marlin_atomic_add(hidden_states.device, hidden_states.dtype) - use_atomic_add = ( - hidden_states.dtype == torch.half - or torch.cuda.get_device_capability(hidden_states.device)[0] >= 9 - ) + a_scales1 = None + gate_up_input = hidden_states + if input_dtype == torch.int8: + gate_up_input, a_scales1 = marlin_quant_input(hidden_states, input_dtype) + if input_global_scale1 is not None: + a_scales1 = a_scales1 * input_global_scale1 + elif input_dtype == torch.float8_e4m3fn: + gate_up_input, a_scales1 = marlin_quant_input(hidden_states, input_dtype) intermediate_cache1 = ops.moe_wna16_marlin_gemm( - hidden_states, + gate_up_input, intermediate_cache1, w1, bias1, w1_scale, + a_scales1, global_scale1, w1_zeros, g_idx1, @@ -136,7 +143,7 @@ def _fused_marlin_moe( size_n=2 * N, size_k=K, is_k_full=is_k_full, - use_atomic_add=use_atomic_add, + use_atomic_add=False, use_fp32_reduce=True, is_zp_float=False, ) @@ -151,12 +158,25 @@ def _fused_marlin_moe( if expert_map is not None: output.zero_() + a_scales2 = None + if input_dtype == torch.int8: + intermediate_cache2, a_scales2 = marlin_quant_input( + intermediate_cache2, input_dtype + ) + if input_global_scale2 is not None: + a_scales2 = a_scales2 * input_global_scale2 + elif input_dtype == torch.float8_e4m3fn: + intermediate_cache2, a_scales2 = marlin_quant_input( + intermediate_cache2, input_dtype + ) + output = ops.moe_wna16_marlin_gemm( intermediate_cache2, output, w2, bias2, w2_scale, + a_scales2, global_scale2, w2_zeros, g_idx2, @@ -175,7 +195,7 @@ def _fused_marlin_moe( size_n=K, size_k=N, is_k_full=is_k_full, - use_atomic_add=use_atomic_add, + use_atomic_add=False, use_fp32_reduce=True, is_zp_float=False, ) @@ -203,6 +223,8 @@ def fused_marlin_moe( ] = default_activation_func, moe_sum: Callable[[torch.Tensor, torch.Tensor], None] | None = None, expert_map: torch.Tensor | None = None, + input_global_scale1: torch.Tensor | None = None, + input_global_scale2: torch.Tensor | None = None, global_scale1: torch.Tensor | None = None, global_scale2: torch.Tensor | None = None, g_idx1: torch.Tensor | None = None, @@ -216,6 +238,7 @@ def fused_marlin_moe( 
intermediate_cache2: torch.Tensor | None = None, is_k_full: bool = True, output: torch.Tensor | None = None, + input_dtype: torch.dtype | None = None, inplace: bool = False, ) -> torch.Tensor: """ @@ -287,6 +310,9 @@ def fused_marlin_moe( if M * topk / E / block_size_m < 0.9: break + if input_dtype is not None and input_dtype.itemsize == 1: + block_size_m = max(block_size_m, 16) + if global_num_experts == -1: global_num_experts = E sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( @@ -313,6 +339,8 @@ def fused_marlin_moe( num_tokens_post_padded=num_tokens_post_padded, activation=activation, activation_func=activation_func, + input_global_scale1=input_global_scale1, + input_global_scale2=input_global_scale2, global_scale1=global_scale1, global_scale2=global_scale2, g_idx1=g_idx1, @@ -325,6 +353,7 @@ def fused_marlin_moe( intermediate_cache13=intermediate_cache13, intermediate_cache2=intermediate_cache2, output=None, + input_dtype=input_dtype, is_k_full=is_k_full, ).view(-1, topk, K) diff --git a/vllm/model_executor/layers/quantization/auto_round.py b/vllm/model_executor/layers/quantization/auto_round.py index f1943d4611877..95e4382c89d7a 100644 --- a/vllm/model_executor/layers/quantization/auto_round.py +++ b/vllm/model_executor/layers/quantization/auto_round.py @@ -266,7 +266,7 @@ class AutoRoundConfig(QuantizationConfig): from vllm.model_executor.layers.quantization.awq_marlin import ( AWQMarlinConfig, AWQMarlinLinearMethod, - AWQMoEMethod, + AWQMarlinMoEMethod, ) quant_args_marlin = AWQMarlinConfig( @@ -291,7 +291,7 @@ class AutoRoundConfig(QuantizationConfig): if isinstance(layer, FusedMoE): if use_marlin: - return AWQMoEMethod(quant_args_marlin, layer.moe_config) + return AWQMarlinMoEMethod(quant_args_marlin, layer.moe) from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config config = { diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index 0cf8b69f9f6ba..ab68c5dca52c0 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -106,7 +106,7 @@ class AWQConfig(QuantizationConfig): return AWQLinearMethod(self) elif isinstance(layer, FusedMoE): # Lazy import to avoid circular import. 
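
Summarizing the _fused_marlin_moe() changes above: when input_dtype is int8 or fp8, activations are re-quantized per token immediately before each of the two grouped GEMMs, the per-token scales are passed through the new a_scales argument, and use_atomic_add is now always False. A rough sketch of the per-GEMM activation path, using the marlin_quant_input() helper this PR adds to marlin_utils.py (the wrapper function name below is illustrative):

import torch
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    marlin_quant_input,
)

def quantize_act_for_marlin_gemm(
    x: torch.Tensor,
    input_dtype: torch.dtype | None,
    input_global_scale: torch.Tensor | None = None,
):
    # 16-bit activations: pass through, no a_scales needed.
    if input_dtype is None:
        return x, None
    # Per-token quantization to int8 or fp8.
    x_q, a_scales = marlin_quant_input(x, input_dtype)
    if input_dtype == torch.int8 and input_global_scale is not None:
        # int8 path: fold the layer's input_global_scale (produced by
        # marlin_act_int8_process_scales) back into the per-token scales.
        a_scales = a_scales * input_global_scale
    return x_q, a_scales
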
- from .awq_marlin import AWQMarlinConfig, AWQMoEMethod + from .awq_marlin import AWQMarlinConfig, AWQMarlinMoEMethod from .moe_wna16 import MoeWNA16Config from .utils.marlin_utils import check_moe_marlin_supports_layer @@ -136,7 +136,7 @@ class AWQConfig(QuantizationConfig): awq_marlin_config = AWQMarlinConfig.from_config( marlin_compatible_config_dict ) - return AWQMoEMethod(awq_marlin_config, layer.moe_config) + return AWQMarlinMoEMethod(awq_marlin_config, layer.moe_config) return None def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"): diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index 66945e2d2a7c8..d463e181fd2db 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -40,6 +40,8 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import ( check_marlin_supported, check_marlin_supports_layer, check_moe_marlin_supports_layer, + get_marlin_input_dtype, + marlin_act_int8_process_scales, marlin_make_empty_g_idx, marlin_make_workspace_new, marlin_moe_permute_scales, @@ -69,7 +71,6 @@ class AWQMarlinConfig(QuantizationConfig): # num_bits -> type TYPE_MAP = { 4: scalar_types.uint4, - 8: scalar_types.uint8, } def __init__( @@ -193,7 +194,9 @@ class AWQMarlinConfig(QuantizationConfig): return AWQConfig.from_config(self.full_config).get_quant_method( layer, prefix ) - return AWQMarlinLinearMethod(self) + quant_method = AWQMarlinLinearMethod(self) + quant_method.input_dtype = get_marlin_input_dtype(prefix) + return quant_method elif isinstance(layer, FusedMoE): from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config @@ -211,7 +214,9 @@ class AWQMarlinConfig(QuantizationConfig): return MoeWNA16Config.from_config(self.full_config).get_quant_method( layer, prefix ) - return AWQMoEMethod(self, layer.moe_config) + moe_quant_method = AWQMarlinMoEMethod(self, layer.moe_config) + moe_quant_method.input_dtype = get_marlin_input_dtype(prefix) + return moe_quant_method return None @classmethod @@ -270,6 +275,8 @@ class AWQMarlinLinearMethod(LinearMethodBase): def __init__(self, quant_config: AWQMarlinConfig) -> None: self.quant_config = quant_config + self.quant_type = scalar_types.uint4 + self.input_dtype = None def create_weights( self, @@ -312,6 +319,7 @@ class AWQMarlinLinearMethod(LinearMethodBase): ) num_groups = input_size_per_partition // group_size + layer.num_groups = num_groups qzeros = PackedvLLMParameter( data=torch.empty( @@ -358,12 +366,19 @@ class AWQMarlinLinearMethod(LinearMethodBase): # Allocate marlin workspace layer.workspace = marlin_make_workspace_new(device) + is_a_8bit = self.input_dtype is not None and self.input_dtype.itemsize == 1 + + if self.input_dtype == torch.float8_e4m3fn: + ops.marlin_int4_fp8_preprocess(layer.qweight, layer.qzeros, inplace=True) + layer.scales.data = layer.scales.data * 512 + # Repack weights from AWQ format to marlin format. 
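
For the W4A8-FP8 path, the weights are rewritten once at load time: the packed int4 qweight (and qzeros, where present) goes through the new marlin_int4_fp8_preprocess op in place, and the group scales are multiplied by 512 to compensate, exactly as in process_weights_after_loading above. A hedged sketch of that step, assuming the packed AWQ tensors of a single linear layer:

import torch
from vllm import _custom_ops as ops

def preprocess_awq_int4_for_fp8_act(
    qweight: torch.Tensor,   # packed int4 weights (AWQ layout)
    qzeros: torch.Tensor,    # packed zero-points
    scales: torch.Tensor,    # per-group scales
) -> None:
    # Rewrite packed weights/zero-points in place so the kernel can pair them
    # with fp8 activations; the factor of 512 mirrors what this PR applies and
    # is not re-derived here.
    ops.marlin_int4_fp8_preprocess(qweight, qzeros, inplace=True)
    scales.mul_(512)
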
marlin_qweight = ops.awq_marlin_repack( layer.qweight, size_k=layer.input_size_per_partition, size_n=layer.output_size_per_partition, num_bits=self.quant_config.quant_type.size_bits, + is_a_8bit=is_a_8bit, ) replace_parameter(layer, "qweight", marlin_qweight) @@ -373,7 +388,16 @@ class AWQMarlinLinearMethod(LinearMethodBase): size_k=layer.input_size_per_partition, size_n=layer.output_size_per_partition, group_size=self.quant_config.group_size, + is_a_8bit=is_a_8bit, ) + if self.input_dtype == torch.int8 and layer.num_groups > 1: + marlin_scales, input_global_scale = marlin_act_int8_process_scales( + marlin_scales + ) + layer.register_parameter( + "input_global_scale", Parameter(input_global_scale, requires_grad=False) + ) + replace_parameter(layer, "scales", marlin_scales) # Permute zero-points from AWQ format to marlin format. @@ -382,6 +406,7 @@ class AWQMarlinLinearMethod(LinearMethodBase): size_k=layer.num_groups, size_n=layer.output_size_per_partition, num_bits=self.quant_config.quant_type.size_bits, + is_a_8bit=is_a_8bit, ) replace_parameter(layer, "qzeros", marlin_zp) @@ -409,11 +434,13 @@ class AWQMarlinLinearMethod(LinearMethodBase): quant_type=self.quant_config.quant_type, output_size_per_partition=layer.output_size_per_partition, input_size_per_partition=layer.input_size_per_partition, + input_global_scale=getattr(layer, "input_global_scale", None), bias=bias, + input_dtype=self.input_dtype, ) -class AWQMoEMethod(FusedMoEMethodBase): +class AWQMarlinMoEMethod(FusedMoEMethodBase): def __init__( self, quant_config: AWQMarlinConfig, @@ -422,8 +449,9 @@ class AWQMoEMethod(FusedMoEMethodBase): super().__init__(moe) self.quant_config = quant_config if self.quant_config.weight_bits != 4: - raise ValueError("AWQMoEMethod only supports 4bit now.") + raise ValueError("AWQMarlinMoEMethod only supports 4bit now.") self.quant_type = scalar_types.uint4 + self.input_dtype = None self.use_marlin = True def create_weights( @@ -435,6 +463,7 @@ class AWQMoEMethod(FusedMoEMethodBase): params_dtype: torch.dtype, **extra_weight_attrs, ): + layer.input_dtype = self.input_dtype extra_weight_attrs.update( { "is_transposed": True, @@ -468,6 +497,8 @@ class AWQMoEMethod(FusedMoEMethodBase): num_groups_w13 = hidden_size // self.quant_config.group_size num_groups_w2 = intermediate_size_per_partition // self.quant_config.group_size + layer.num_groups_w13 = num_groups_w13 + layer.num_groups_w2 = num_groups_w2 # WEIGHT_SCALES # Allocate 2 scales for w1 and w3 respectively. 
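
The int8 activation path also reworks the group scales: marlin_act_int8_process_scales() (added to marlin_utils.py later in this diff) renormalizes them to roughly [0, 4096], rounds to int16, and returns a single input_global_scale that is multiplied back into the per-token activation scales at run time (the real helper bit-casts the rounded int16 values back to the scale dtype for storage). A small numerical sanity sketch of that round trip:

import torch

# Per-group weight scales and per-token activation scales (shapes illustrative).
s = torch.rand(8, 16, dtype=torch.float16) + 0.01
a = torch.rand(4, dtype=torch.float32) + 0.01

# Same math as marlin_act_int8_process_scales (int16 bit-cast omitted here).
input_global_scale = s.max().float() / 4096
s_packed = (s / s.max() * 4096).round().to(torch.int16)

# What the kernel effectively computes per token i, group g, column n:
#   (a[i] * input_global_scale) * s_packed[g, n]  ~=  a[i] * s[g, n]
recovered = a.view(-1, 1, 1) * input_global_scale * s_packed.float()
reference = a.view(-1, 1, 1) * s.float()
assert torch.allclose(recovered, reference, rtol=2e-2)
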
@@ -522,6 +553,21 @@ class AWQMoEMethod(FusedMoEMethodBase): def process_weights_after_loading(self, layer: torch.nn.Module) -> None: num_experts = layer.w13_qweight.shape[0] device = layer.w13_qweight.device + is_a_8bit = self.input_dtype is not None and self.input_dtype.itemsize == 1 + + if self.input_dtype == torch.float8_e4m3fn: + ops.marlin_int4_fp8_preprocess( + layer.w13_qweight.view(-1, layer.w13_qweight.size(2)), + layer.w13_qzeros.view(-1, layer.w13_qzeros.size(2)), + inplace=True, + ) + ops.marlin_int4_fp8_preprocess( + layer.w2_qweight.view(-1, layer.w2_qweight.size(2)), + layer.w2_qzeros.view(-1, layer.w2_qzeros.size(2)), + inplace=True, + ) + layer.w13_scales.data = layer.w13_scales.data * 512 + layer.w2_scales.data = layer.w2_scales.data * 512 layer.w13_g_idx_sort_indices = torch.nn.Parameter( torch.empty((num_experts, 0), dtype=torch.int32, device=device), @@ -538,6 +584,7 @@ class AWQMoEMethod(FusedMoEMethodBase): size_k=layer.w13_qweight.shape[1], size_n=layer.w13_qweight.shape[2] * self.quant_config.pack_factor, num_bits=self.quant_config.weight_bits, + is_a_8bit=is_a_8bit, ) replace_parameter(layer, "w13_qweight", marlin_w13_qweight) @@ -547,6 +594,7 @@ class AWQMoEMethod(FusedMoEMethodBase): size_k=layer.w2_qweight.shape[1], size_n=layer.w2_qweight.shape[2] * self.quant_config.pack_factor, num_bits=self.quant_config.weight_bits, + is_a_8bit=is_a_8bit, ) replace_parameter(layer, "w2_qweight", marlin_w2_qweight) @@ -556,7 +604,16 @@ class AWQMoEMethod(FusedMoEMethodBase): size_k=layer.intermediate_size_per_partition, size_n=layer.w13_scales.shape[2], group_size=self.quant_config.group_size, + is_a_8bit=is_a_8bit, ) + if self.input_dtype == torch.int8 and layer.num_groups_w13 > 1: + marlin_w13_scales, w13_input_global_scale = marlin_act_int8_process_scales( + marlin_w13_scales + ) + layer.register_parameter( + "w13_input_global_scale", + Parameter(w13_input_global_scale, requires_grad=False), + ) replace_parameter(layer, "w13_scales", marlin_w13_scales) @@ -565,7 +622,17 @@ class AWQMoEMethod(FusedMoEMethodBase): size_k=layer.intermediate_size_per_partition, size_n=layer.w2_scales.shape[2], group_size=self.quant_config.group_size, + is_a_8bit=is_a_8bit, ) + if self.input_dtype == torch.int8 and layer.num_groups_w2 > 1: + marlin_w2_scales, w2_input_global_scale = marlin_act_int8_process_scales( + marlin_w2_scales + ) + layer.register_parameter( + "w2_input_global_scale", + Parameter(w2_input_global_scale, requires_grad=False), + ) + replace_parameter(layer, "w2_scales", marlin_w2_scales) marlin_w13_zp = moe_awq_to_marlin_zero_points( @@ -573,6 +640,7 @@ class AWQMoEMethod(FusedMoEMethodBase): size_k=layer.w13_qzeros.shape[1], size_n=layer.w13_qzeros.shape[2] * self.quant_config.pack_factor, num_bits=self.quant_config.weight_bits, + is_a_8bit=is_a_8bit, ) replace_parameter(layer, "w13_qzeros", marlin_w13_zp) @@ -581,6 +649,7 @@ class AWQMoEMethod(FusedMoEMethodBase): size_k=layer.w2_qzeros.shape[1], size_n=layer.w2_qzeros.shape[2] * self.quant_config.pack_factor, num_bits=self.quant_config.weight_bits, + is_a_8bit=is_a_8bit, ) replace_parameter(layer, "w2_qzeros", marlin_w2_zp) @@ -636,6 +705,8 @@ class AWQMoEMethod(FusedMoEMethodBase): router_logits, topk_weights, topk_ids, + input_global_scale1=getattr(layer, "w13_input_global_scale", None), + input_global_scale2=getattr(layer, "w2_input_global_scale", None), quant_type_id=self.quant_type.id, apply_router_weight_on_input=apply_router_weight_on_input, global_num_experts=global_num_experts, @@ -643,4 +714,5 @@ class 
AWQMoEMethod(FusedMoEMethodBase): w1_zeros=layer.w13_qzeros, w2_zeros=layer.w2_qzeros, workspace=layer.workspace, + input_dtype=self.input_dtype, ) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index f9d8f5883680b..02086c3c0052d 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -157,7 +157,9 @@ class CompressedTensorsConfig(QuantizationConfig): if isinstance(layer, Attention): return CompressedTensorsKVCacheMethod(self) if isinstance(layer, FusedMoE): - return CompressedTensorsMoEMethod.get_moe_method(self, layer, prefix) + return CompressedTensorsMoEMethod.get_moe_method( + self, layer, layer_name=prefix + ) return None def _add_fused_moe_to_target_scheme_map(self): @@ -547,6 +549,7 @@ class CompressedTensorsConfig(QuantizationConfig): weight_quant: QuantizationArgs, input_quant: QuantizationArgs, format: str | None = None, + layer_name: str | None = None, ) -> "CompressedTensorsScheme": # use the per-layer format if defined, otherwise, use global format format = format if format is not None else self.quant_format @@ -585,6 +588,7 @@ class CompressedTensorsConfig(QuantizationConfig): symmetric=weight_quant.symmetric, group_size=weight_quant.group_size, actorder=weight_quant.actorder, + layer_name=layer_name, ) act_quant_format = is_activation_quantization_format(format) @@ -724,7 +728,10 @@ class CompressedTensorsConfig(QuantizationConfig): else: # Find the quant_scheme scheme = self._get_scheme_from_parts( # type: ignore - weight_quant=weight_quant, input_quant=input_quant, format=format + weight_quant=weight_quant, + input_quant=input_quant, + format=format, + layer_name=layer_name, ) # Raise error if device does not support the scheme diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index c7dfd1787cc8f..80ee443d4dd6a 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -64,6 +64,8 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import ( ) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( check_moe_marlin_supports_layer, + get_marlin_input_dtype, + marlin_act_int8_process_scales, marlin_make_workspace_new, marlin_moe_permute_scales, ) @@ -101,7 +103,7 @@ __all__ = [ "CompressedTensorsW8A8Int8MoEMethod", "CompressedTensorsWNA16MarlinMoEMethod", "CompressedTensorsWNA16MoEMethod", - "CompressedTensorsW4A4Nvfp4MoeMethod", + "CompressedTensorsW4A4Nvfp4MoEMethod", "CompressedTensorsW4A8Int8MoEMethod", ] @@ -111,13 +113,13 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase): def get_moe_method( quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 layer: torch.nn.Module, - prefix: str, + layer_name: str, ) -> "CompressedTensorsMoEMethod": # FusedMoE was made by combining multiple Linears so need to # make sure quantization config for Linear can target it quant_config._add_fused_moe_to_target_scheme_map() unfused_names = [ - prefix + proj_name + layer_name + proj_name for proj_name in [".0.gate_proj", ".0.up_proj", ".0.down_proj"] ] # TODO: refactor this to use expert_mapping and check all layer numbers @@ 
-158,32 +160,40 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase): "WNA16MoE is not supported with actorder=group/dynamic." ) logger.info_once("Using CompressedTensorsWNA16MoEMethod") - return CompressedTensorsWNA16MoEMethod(quant_config, layer.moe_config) + return CompressedTensorsWNA16MoEMethod( + quant_config, layer.moe_config, layer_name + ) else: logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod") return CompressedTensorsWNA16MarlinMoEMethod( - quant_config, layer.moe_config + quant_config, layer.moe_config, layer_name ) elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant): - return CompressedTensorsW4A4Nvfp4MoeMethod(layer.moe_config) + return CompressedTensorsW4A4Nvfp4MoEMethod(layer.moe_config, layer_name) elif ( quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant) or quant_config._is_fp8_w8a8_sm100(weight_quant, input_quant) or quant_config._is_fp8_w8a8(weight_quant, input_quant) ): - return CompressedTensorsW8A8Fp8MoEMethod(quant_config, layer.moe_config) + return CompressedTensorsW8A8Fp8MoEMethod( + quant_config, layer.moe_config, layer_name + ) elif quant_config._is_dynamic_token_w8a8(weight_quant, input_quant): - return CompressedTensorsW8A8Int8MoEMethod(quant_config, layer.moe_config) + return CompressedTensorsW8A8Int8MoEMethod( + quant_config, layer.moe_config, layer_name + ) elif quant_config._is_dynamic_token_w4a8_int(weight_quant, input_quant): - return CompressedTensorsW4A8Int8MoEMethod(quant_config, layer.moe_config) + return CompressedTensorsW4A8Int8MoEMethod( + quant_config, layer.moe_config, layer_name + ) else: raise RuntimeError( f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}" ) -class CompressedTensorsW4A4Nvfp4MoeMethod(CompressedTensorsMoEMethod): - def __init__(self, moe: FusedMoEConfig): +class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): + def __init__(self, moe: FusedMoEConfig, layer_name: str | None = None): from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import ( # noqa: E501 detect_nvfp4_moe_support, ) @@ -194,17 +204,21 @@ class CompressedTensorsW4A4Nvfp4MoeMethod(CompressedTensorsMoEMethod): self.allow_flashinfer = _nvfp4.allow_flashinfer self.use_marlin = _nvfp4.use_marlin self.group_size = 16 + self.layer_name = layer_name + self.marlin_input_dtype = ( + get_marlin_input_dtype(layer_name) if self.use_marlin else None + ) self.flashinfer_moe_backend = None if self.allow_flashinfer: self.flashinfer_moe_backend = get_flashinfer_moe_backend() logger.info_once( f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels" - " for CompressedTensorsW4A4Nvfp4MoeMethod." + " for CompressedTensorsW4A4Nvfp4MoEMethod." ) elif self.use_marlin: - logger.info_once("Using Marlin for CompressedTensorsW4A4Nvfp4MoeMethod.") + logger.info_once("Using Marlin for CompressedTensorsW4A4Nvfp4MoEMethod.") else: - logger.info_once("Using Cutlass for CompressedTensorsW4A4Nvfp4MoeMethod.") + logger.info_once("Using Cutlass for CompressedTensorsW4A4Nvfp4MoEMethod.") def create_weights( self, @@ -354,7 +368,7 @@ class CompressedTensorsW4A4Nvfp4MoeMethod(CompressedTensorsMoEMethod): ) if self.use_marlin: - prepare_moe_fp4_layer_for_marlin(layer) + prepare_moe_fp4_layer_for_marlin(layer, input_dtype=self.marlin_input_dtype) return # w13 if ( @@ -538,7 +552,7 @@ class CompressedTensorsW4A4Nvfp4MoeMethod(CompressedTensorsMoEMethod): ): if enable_eplb: raise NotImplementedError( - "EPLB not supported for `CompressedTensorsW4A4MoeMethod` yet." 
+ "EPLB not supported for `CompressedTensorsW4A4MoEMethod` yet." ) return flashinfer_trtllm_fp4_moe( @@ -576,6 +590,7 @@ class CompressedTensorsW4A4Nvfp4MoeMethod(CompressedTensorsMoEMethod): apply_router_weight_on_input=apply_router_weight_on_input, global_num_experts=global_num_experts, expert_map=expert_map, + input_dtype=self.marlin_input_dtype, workspace=layer.workspace, ) @@ -610,7 +625,7 @@ class CompressedTensorsW4A4Nvfp4MoeMethod(CompressedTensorsMoEMethod): assert expert_map is None, ( "Expert Parallelism / expert_map " "is currently not supported for " - "CompressedTensorsW4A4Nvfp4MoeMethod." + "CompressedTensorsW4A4Nvfp4MoEMethod." ) assert self.moe_quant_config is not None @@ -637,6 +652,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): self, quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 moe: FusedMoEConfig, + layer_name: str | None = None, ): super().__init__(moe) self.quant_config = quant_config @@ -690,6 +706,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): or self.is_fp8_w8a8_sm100 ) self.disable_expert_map = False + self.layer_name = layer_name + self.marlin_input_dtype = ( + get_marlin_input_dtype(layer_name) if self.use_marlin else None + ) def create_weights( self, @@ -931,7 +951,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False) elif self.use_marlin: - prepare_moe_fp8_layer_for_marlin(layer, False) + prepare_moe_fp8_layer_for_marlin( + layer, False, input_dtype=self.marlin_input_dtype + ) # Activations not quantized for marlin. del layer.w13_input_scale del layer.w2_input_scale @@ -1144,6 +1166,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): apply_router_weight_on_input=apply_router_weight_on_input, global_num_experts=global_num_experts, expert_map=expert_map, + input_dtype=self.marlin_input_dtype, workspace=layer.workspace, ) @@ -1240,6 +1263,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): self, quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 moe: FusedMoEConfig, + layer_name: str | None = None, ): super().__init__(moe) self.quant_config = quant_config @@ -1392,6 +1416,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): self, quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 moe: FusedMoEConfig, + layer_name: str | None = None, ): super().__init__(moe) self.quant_config = quant_config @@ -1403,6 +1428,8 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): self.strategy = config.strategy self.group_size = config.group_size self.actorder = config.actorder + self.layer_name = layer_name + self.marlin_input_dtype = get_marlin_input_dtype(layer_name) assert config.symmetric, "Only symmetric quantization is supported for MoE" if not ( @@ -1477,6 +1504,9 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): num_groups_w2 = w2_scales_size // self.group_size num_groups_w13 = hidden_size // self.group_size + layer.num_groups_w13 = num_groups_w13 + layer.num_groups_w2 = num_groups_w2 + w13_scale = torch.nn.Parameter( torch.ones( num_experts, @@ -1560,6 +1590,17 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): def process_weights_after_loading(self, layer: torch.nn.Module) -> None: num_experts = layer.w13_weight_g_idx.shape[0] device = layer.w13_weight_g_idx.device + is_a_8bit = ( + self.marlin_input_dtype is not None + and 
self.marlin_input_dtype.itemsize == 1 + ) + + if self.marlin_input_dtype == torch.float8_e4m3fn: + # NOTE: for non-zp quantization format only + ops.marlin_int4_fp8_preprocess(layer.w13_weight_packed, inplace=True) + ops.marlin_int4_fp8_preprocess(layer.w2_weight_packed, inplace=True) + layer.w13_weight_scale.data = layer.w13_weight_scale.data * 512 + layer.w2_weight_scale.data = layer.w2_weight_scale.data * 512 # when running models with grouped act order, # resort to g_idx values provided in checkpoint @@ -1610,31 +1651,54 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): layer.w13_weight_packed.shape[1] * self.packed_factor, layer.w13_weight_packed.shape[2], self.num_bits, + is_a_8bit=is_a_8bit, ) replace_parameter(layer, "w13_weight_packed", marlin_w13_qweight) + marlin_w2_qweight = ops.gptq_marlin_moe_repack( layer.w2_weight_packed, layer.w2_g_idx_sort_indices, layer.w2_weight_packed.shape[1] * self.packed_factor, layer.w2_weight_packed.shape[2], self.num_bits, + is_a_8bit=is_a_8bit, ) replace_parameter(layer, "w2_weight_packed", marlin_w2_qweight) + # Repack scales marlin_w13_scales = marlin_moe_permute_scales( s=layer.w13_weight_scale, size_k=layer.w13_weight_packed.shape[2], size_n=layer.w13_weight_scale.shape[2], group_size=self.group_size, + is_a_8bit=is_a_8bit, ) + if self.marlin_input_dtype == torch.int8 and layer.num_groups_w13 > 1: + marlin_w13_scales, w13_input_global_scale = marlin_act_int8_process_scales( + marlin_w13_scales + ) + layer.register_parameter( + "w13_input_global_scale", + torch.nn.Parameter(w13_input_global_scale, requires_grad=False), + ) replace_parameter(layer, "w13_weight_scale", marlin_w13_scales) + marlin_w2_scales = marlin_moe_permute_scales( s=layer.w2_weight_scale, size_k=layer.w2_weight_scale.shape[1] * (self.group_size if self.group_size != -1 else self.packed_factor), size_n=layer.w2_weight_scale.shape[2], group_size=self.group_size, + is_a_8bit=is_a_8bit, ) + if self.marlin_input_dtype == torch.int8 and layer.num_groups_w2 > 1: + marlin_w2_scales, w2_input_global_scale = marlin_act_int8_process_scales( + marlin_w2_scales + ) + layer.register_parameter( + "w2_input_global_scale", + torch.nn.Parameter(w2_input_global_scale, requires_grad=False), + ) replace_parameter(layer, "w2_weight_scale", marlin_w2_scales) layer.workspace = marlin_make_workspace_new(device, 4) @@ -1729,6 +1793,8 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): router_logits, topk_weights, topk_ids, + input_global_scale1=getattr(layer, "w13_input_global_scale", None), + input_global_scale2=getattr(layer, "w2_input_global_scale", None), quant_type_id=self.quant_type.id, apply_router_weight_on_input=apply_router_weight_on_input, global_num_experts=global_num_experts, @@ -1738,6 +1804,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): sort_indices1=layer.w13_g_idx_sort_indices, sort_indices2=layer.w2_g_idx_sort_indices, workspace=layer.workspace, + input_dtype=self.marlin_input_dtype, is_k_full=self.is_k_full, ) @@ -1747,6 +1814,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): self, quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 moe: FusedMoEConfig, + layer_name: str | None = None, ): super().__init__(moe) self.quant_config = quant_config @@ -1999,6 +2067,7 @@ class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod): self, quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 moe: FusedMoEConfig, + layer_name: str | None = None, ): 
super().__init__(moe) self.has_bias = self.moe.has_bias diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py index 2267395fe67d3..7f4dad70287bd 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py @@ -14,7 +14,11 @@ from vllm.model_executor.layers.quantization.kernels.mixed_precision import ( MPLinearLayerConfig, choose_mp_linear_kernel, ) +from vllm.model_executor.layers.quantization.kernels.mixed_precision.marlin import ( + MarlinLinearKernel, +) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + get_marlin_input_dtype, marlin_repeat_scales_on_all_ranks, ) from vllm.model_executor.parameter import ( @@ -45,12 +49,14 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): group_size: int | None = None, symmetric: bool | None = True, actorder: ActivationOrdering | None = None, + layer_name: str | None = None, ): self.pack_factor = 32 // num_bits self.strategy = strategy self.symmetric = symmetric self.group_size = -1 if group_size is None else group_size self.has_g_idx = actorder == ActivationOrdering.GROUP + self.layer_name = layer_name if self.group_size == -1 and self.strategy != "channel": raise ValueError( @@ -108,6 +114,11 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): logger.info("Using %s for CompressedTensorsWNA16", kernel_type.__name__) self._kernel_backends_being_used.add(kernel_type.__name__) + if isinstance(kernel_type, MarlinLinearKernel): + input_dtype = get_marlin_input_dtype(self.layer_name) + if input_dtype is not None: + mp_linear_kernel_config.act_type = input_dtype + # If group_size is -1, we are in channelwise case. 
group_size = self.group_size if self.group_size != -1 else input_size row_parallel = input_size != input_size_per_partition diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 7dfc8a9c36c3e..48223c9f103ea 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -69,6 +69,9 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import ( process_fp8_weight_tensor_strategy, validate_fp8_block_shape, ) +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + get_marlin_input_dtype, +) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin, @@ -316,7 +319,9 @@ class Fp8Config(QuantizationConfig): fused_mapping=self.packed_modules_mapping, ): return UnquantizedLinearMethod() - return Fp8LinearMethod(self) + quant_method = Fp8LinearMethod(self) + quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix) + return quant_method elif isinstance(layer, FusedMoE): if is_layer_skipped( prefix=prefix, @@ -324,7 +329,9 @@ class Fp8Config(QuantizationConfig): fused_mapping=self.packed_modules_mapping, ): return UnquantizedFusedMoEMethod(layer.moe_config) - return Fp8MoEMethod(self, layer) + moe_quant_method = Fp8MoEMethod(self, layer) + moe_quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix) + return moe_quant_method elif isinstance(layer, Attention): return Fp8KVCacheMethod(self) return None @@ -375,6 +382,7 @@ class Fp8LinearMethod(LinearMethodBase): # For GPUs that lack FP8 hardware support, we can leverage the Marlin # kernel for fast weight-only FP8 quantization + self.marlin_input_dtype = None self.use_marlin = ( not current_platform.has_device_capability(89) or envs.VLLM_TEST_FORCE_FP8_MARLIN @@ -552,7 +560,9 @@ class Fp8LinearMethod(LinearMethodBase): ) if self.use_marlin: - prepare_fp8_layer_for_marlin(layer, size_k_first) + prepare_fp8_layer_for_marlin( + layer, size_k_first, input_dtype=self.marlin_input_dtype + ) # Activations not quantized for marlin. del layer.input_scale return @@ -610,6 +620,7 @@ class Fp8LinearMethod(LinearMethodBase): workspace=layer.workspace, size_n=layer.output_size_per_partition, size_k=layer.input_size_per_partition, + input_dtype=self.marlin_input_dtype, bias=bias, ) @@ -657,6 +668,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): self.block_quant, layer.moe_parallel_config ) + self.marlin_input_dtype = None self.use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN self.flashinfer_moe_backend: FlashinferMoeBackend | None = None if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM: @@ -1031,7 +1043,9 @@ class Fp8MoEMethod(FusedMoEMethodBase): layer.w13_weight.data = w13_weight.data if self.use_marlin: - prepare_moe_fp8_layer_for_marlin(layer, False) + prepare_moe_fp8_layer_for_marlin( + layer, False, input_dtype=self.marlin_input_dtype + ) # Activations not quantized for marlin. 
del layer.w13_input_scale del layer.w2_input_scale @@ -1270,6 +1284,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): apply_router_weight_on_input=apply_router_weight_on_input, global_num_experts=global_num_experts, expert_map=expert_map, + input_dtype=self.marlin_input_dtype, workspace=layer.workspace, ) elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 77b15db373a3a..56034e11329dc 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -41,6 +41,8 @@ from vllm.model_executor.layers.quantization.utils.gptq_utils import ( from vllm.model_executor.layers.quantization.utils.marlin_utils import ( check_marlin_supported, check_moe_marlin_supports_layer, + get_marlin_input_dtype, + marlin_act_int8_process_scales, marlin_make_workspace_new, marlin_moe_permute_scales, marlin_permute_bias, @@ -251,8 +253,21 @@ class GPTQMarlinConfig(QuantizationConfig): return MoeWNA16Config.from_config(self.full_config).get_quant_method( layer, prefix ) - return get_moe_quant_method(self, layer, prefix, GPTQMarlinMoEMethod) - return get_linear_quant_method(self, layer, prefix, GPTQMarlinLinearMethod) + moe_quant_method = get_moe_quant_method( + self, layer, prefix, GPTQMarlinMoEMethod + ) + if moe_quant_method is None: + return None + moe_quant_method.input_dtype = get_marlin_input_dtype(prefix) + return moe_quant_method + + quant_method = get_linear_quant_method( + self, layer, prefix, GPTQMarlinLinearMethod + ) + if quant_method is None: + return None + quant_method.input_dtype = get_marlin_input_dtype(prefix) + return quant_method @classmethod def is_gptq_marlin_compatible(cls, quant_config: dict[str, Any]): @@ -319,6 +334,8 @@ class GPTQMarlinLinearMethod(LinearMethodBase): def __init__(self, quant_config: GPTQMarlinConfig) -> None: self.quant_config = quant_config + self.input_dtype = None + self.quant_type = self.quant_config.quant_type # Verify supported on platform. verify_marlin_supported( @@ -339,6 +356,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase): output_size_per_partition = sum(output_partition_sizes) is_row_parallel = input_size != input_size_per_partition weight_loader = extra_weight_attrs.get("weight_loader") + input_dtype = self.input_dtype mp_linear_kernel_config = MPLinearLayerConfig( full_weight_shape=(input_size, output_size), @@ -347,7 +365,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase): output_size_per_partition, ), weight_type=self.quant_config.quant_type, - act_type=params_dtype, + act_type=params_dtype if input_dtype is None else input_dtype, group_size=self.quant_config.group_size, zero_points=False, has_g_idx=self.quant_config.desc_act, @@ -482,6 +500,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): self.quant_type = scalar_types.uint8b128 else: raise ValueError("GPTQMarlinMoEMethod only supports int4 and int8 now.") + self.input_dtype = None self.use_marlin = True def create_weights( @@ -493,6 +512,14 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): params_dtype: torch.dtype, **extra_weight_attrs, ): + layer.input_dtype = self.input_dtype + is_a_8bit = self.input_dtype is not None and self.input_dtype.itemsize == 1 + + if is_a_8bit: + assert self.quant_type == scalar_types.uint4b8, ( + "W8A8-INT8 is not supported by marlin kernel." 
+ ) + intermediate_size_full = extra_weight_attrs.pop("intermediate_size_full") self.is_k_full = (not self.quant_config.desc_act) or ( @@ -513,6 +540,9 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): scales_size2 = 1 strategy = FusedMoeWeightScaleSupported.CHANNEL.value + layer.num_groups_w13 = scales_size13 + layer.num_groups_w2 = scales_size2 + extra_weight_attrs.update({"quant_method": strategy, "is_transposed": True}) # Fused gate_up_proj (column parallel) w13_qweight = torch.nn.Parameter( @@ -630,6 +660,19 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): layer.workspace = marlin_make_workspace_new(device, 4) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + is_a_8bit = self.input_dtype is not None and self.input_dtype.itemsize == 1 + + if is_a_8bit: + assert self.quant_type == scalar_types.uint4b8, ( + "W8A8-INT8 is not supported by marlin kernel." + ) + + if self.input_dtype == torch.float8_e4m3fn: + ops.marlin_int4_fp8_preprocess(layer.w13_qweight, inplace=True) + ops.marlin_int4_fp8_preprocess(layer.w2_qweight, inplace=True) + layer.w13_scales.data = layer.w13_scales.data * 512 + layer.w2_scales.data = layer.w2_scales.data * 512 + # Process act_order if self.quant_config.desc_act: # Get sorting based on g_idx @@ -678,6 +721,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): layer.w13_qweight.shape[1] * self.quant_config.pack_factor, layer.w13_qweight.shape[2], self.quant_config.quant_type.size_bits, + is_a_8bit=is_a_8bit, ) replace_parameter(layer, "w13_qweight", marlin_w13_qweight) marlin_w2_qweight = ops.gptq_marlin_moe_repack( @@ -686,6 +730,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): layer.w2_qweight.shape[1] * self.quant_config.pack_factor, layer.w2_qweight.shape[2], self.quant_config.quant_type.size_bits, + is_a_8bit=is_a_8bit, ) replace_parameter(layer, "w2_qweight", marlin_w2_qweight) # Repack scales @@ -694,7 +739,17 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): size_k=layer.intermediate_size_per_partition, size_n=layer.w13_scales.shape[2], group_size=self.quant_config.group_size, + is_a_8bit=is_a_8bit, ) + if self.input_dtype == torch.int8 and layer.num_groups_w13 > 1: + marlin_w13_scales, w13_input_global_scale = marlin_act_int8_process_scales( + marlin_w13_scales + ) + layer.register_parameter( + "w13_input_global_scale", + torch.nn.Parameter(w13_input_global_scale, requires_grad=False), + ) + replace_parameter(layer, "w13_scales", marlin_w13_scales) marlin_w2_scales = marlin_moe_permute_scales( s=layer.w2_scales, @@ -706,7 +761,17 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ), size_n=layer.w2_scales.shape[2], group_size=self.quant_config.group_size, + is_a_8bit=is_a_8bit, ) + if self.input_dtype == torch.int8 and layer.num_groups_w2 > 1: + marlin_w2_scales, w2_input_global_scale = marlin_act_int8_process_scales( + marlin_w2_scales + ) + layer.register_parameter( + "w2_input_global_scale", + torch.nn.Parameter(w2_input_global_scale, requires_grad=False), + ) + replace_parameter(layer, "w2_scales", marlin_w2_scales) if hasattr(layer, "w13_bias") and layer.w13_bias is not None: @@ -761,6 +826,8 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): router_logits, topk_weights, topk_ids, + input_global_scale1=getattr(layer, "w13_input_global_scale", None), + input_global_scale2=getattr(layer, "w2_input_global_scale", None), quant_type_id=self.quant_type.id, apply_router_weight_on_input=apply_router_weight_on_input, global_num_experts=global_num_experts, @@ -771,4 +838,5 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): 
sort_indices2=layer.w2_g_idx_sort_indices, workspace=layer.workspace, is_k_full=self.is_k_full, + input_dtype=self.input_dtype, ) diff --git a/vllm/model_executor/layers/quantization/hqq_marlin.py b/vllm/model_executor/layers/quantization/hqq_marlin.py index 5fb67c35378be..fad8cb10fa8ac 100644 --- a/vllm/model_executor/layers/quantization/hqq_marlin.py +++ b/vllm/model_executor/layers/quantization/hqq_marlin.py @@ -351,6 +351,7 @@ class HQQMarlinMethod(LinearMethodBase): bias, scales, None, + None, zeros, layer.g_idx, layer.g_idx_sort_indices, diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py index ac21286eeffac..faaa45b861de7 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py @@ -9,6 +9,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import ( MARLIN_SUPPORTED_GROUP_SIZES, apply_gptq_marlin_linear, check_marlin_supports_shape, + marlin_act_int8_process_scales, marlin_is_k_full, marlin_make_empty_g_idx, marlin_make_workspace_new, @@ -21,6 +22,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import ( ) from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_ from vllm.platforms import current_platform +from vllm.scalar_type import scalar_types from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig @@ -65,6 +67,18 @@ class MarlinLinearKernel(MPLinearKernel): def process_weights_after_loading(self, layer: torch.nn.Module) -> None: device = getattr(layer, self.w_q_name).device c = self.config + is_a_8bit = c.act_type is not None and c.act_type.itemsize == 1 + + if is_a_8bit: + assert c.weight_type == scalar_types.uint4b8, ( + "W8A8 is not supported by marlin kernel." 
+ ) + + if c.act_type == torch.float8_e4m3fn: + ops.marlin_int4_fp8_preprocess(getattr(layer, self.w_q_name), inplace=True) + getattr(layer, self.w_s_name).data = ( + getattr(layer, self.w_s_name).data * 512 + ) row_parallel = c.partition_weight_shape[0] != c.full_weight_shape[0] self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel) @@ -88,6 +102,7 @@ class MarlinLinearKernel(MPLinearKernel): size_k=c.partition_weight_shape[0], size_n=c.partition_weight_shape[1], num_bits=c.weight_type.size_bits, + is_a_8bit=is_a_8bit, ) return x @@ -99,7 +114,22 @@ class MarlinLinearKernel(MPLinearKernel): size_k=c.partition_weight_shape[0], size_n=c.partition_weight_shape[1], group_size=c.group_size, + is_a_8bit=is_a_8bit, ) + + if c.group_size == -1: + num_groups = 1 + else: + num_groups = c.partition_weight_shape[0] // c.group_size + + if c.act_type == torch.int8 and num_groups > 1: + x.data, input_global_scale = marlin_act_int8_process_scales(x.data) + layer.register_parameter( + "input_global_scale", + torch.nn.Parameter(input_global_scale, requires_grad=False), + ) + else: + layer.input_global_scale = None return x if c.has_g_idx: @@ -129,6 +159,7 @@ class MarlinLinearKernel(MPLinearKernel): size_k=grouped_k, size_n=c.partition_weight_shape[1], num_bits=c.weight_type.size_bits, + is_a_8bit=is_a_8bit, ), ) else: @@ -150,6 +181,7 @@ class MarlinLinearKernel(MPLinearKernel): # `process_weights_after_loading` will ensure w_zp and w_gidx are not # None for marlin + return apply_gptq_marlin_linear( input=x, weight=w_q, @@ -162,5 +194,7 @@ class MarlinLinearKernel(MPLinearKernel): input_size_per_partition=c.partition_weight_shape[0], output_size_per_partition=c.partition_weight_shape[1], is_k_full=self.is_k_full, + input_global_scale=getattr(layer, "input_global_scale", None), bias=bias, + input_dtype=c.act_type, ) diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 80f8e3a03e7cf..709c86175477a 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -55,6 +55,9 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( select_cutlass_fp8_gemm_impl, swap_w13_to_w31, ) +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + get_marlin_input_dtype, +) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( apply_fp4_marlin_linear, is_fp4_marlin_supported, @@ -170,9 +173,15 @@ class ModelOptQuantConfigBase(QuantizationConfig): # now, the layer is quantized, handle it here if isinstance(layer, LinearBase): - return self.LinearMethodCls(self) + quant_method = self.LinearMethodCls(self) + if getattr(quant_method, "backend", "") == "marlin": + quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix) + return quant_method elif isinstance(layer, FusedMoE): - return self.FusedMoEMethodCls(quant_config=self, layer=layer) + quant_method = self.FusedMoEMethodCls(quant_config=self, layer=layer) + if getattr(quant_method, "backend", "") == "marlin": + quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix) + return quant_method return None @@ -898,6 +907,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase): def __init__(self, quant_config: ModelOptNvFp4Config) -> None: self.quant_config = quant_config + self.marlin_input_dtype = None self.backend = "none" if envs.VLLM_NVFP4_GEMM_BACKEND is None: @@ -1065,6 +1075,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase): 
size_n=layer.output_size_per_partition, size_k=layer.input_size_per_partition, bias=bias, + input_dtype=self.marlin_input_dtype, ) output_dtype = x.dtype @@ -1124,6 +1135,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported self.allow_flashinfer = _nvfp4.allow_flashinfer self.use_marlin = _nvfp4.use_marlin + self.marlin_input_dtype = None self.flashinfer_moe_backend = None if self.allow_flashinfer: self.flashinfer_moe_backend = get_flashinfer_moe_backend() @@ -1517,7 +1529,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): apply_router_weight_on_input=apply_router_weight_on_input, global_num_experts=global_num_experts, expert_map=expert_map, - workspace=layer.workspace, + input_dtype=self.marlin_input_dtype, ) elif self.allow_flashinfer: diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index bc241ac692e23..d271e56e08568 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -38,6 +38,9 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, ) +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + get_marlin_input_dtype, +) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( prepare_moe_fp4_layer_for_marlin, ) @@ -205,7 +208,9 @@ class Mxfp4Config(QuantizationConfig): if current_platform.is_xpu(): return IpexMxfp4MoEMethod(layer.moe_config) else: - return Mxfp4MoEMethod(layer.moe_config) + quant_method = Mxfp4MoEMethod(layer.moe_config) + quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix) + return quant_method elif isinstance(layer, Attention): # TODO: Add support for MXFP4 Attention. 
logger.debug_once( @@ -220,6 +225,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): def __init__(self, moe: FusedMoEConfig): super().__init__(moe) self.mxfp4_backend = get_mxfp4_backend(moe.is_lora_enabled) + + self.marlin_input_dtype = None self.use_marlin = self.mxfp4_backend == Mxfp4Backend.MARLIN self.max_capture_size = ( get_current_vllm_config().compilation_config.max_cudagraph_capture_size @@ -385,7 +392,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): def process_weights_after_loading(self, layer): if self.mxfp4_backend == Mxfp4Backend.MARLIN: - prepare_moe_fp4_layer_for_marlin(layer) + prepare_moe_fp4_layer_for_marlin(layer, input_dtype=self.marlin_input_dtype) elif ( self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16 @@ -914,6 +921,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): global_num_experts=global_num_experts, activation=activation, expert_map=expert_map, + input_dtype=self.marlin_input_dtype, ) assert _can_support_mxfp4( diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index 071fb4ba16867..14337ee1d7bee 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -9,6 +9,11 @@ import vllm.envs as envs from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.linear import LinearBase +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 +from vllm.model_executor.layers.quantization.utils.int8_utils import ( + per_token_quant_int8, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.platforms import current_platform from vllm.scalar_type import ScalarType, scalar_types @@ -286,10 +291,10 @@ def get_scale_perms(): def marlin_permute_scales( - s: torch.Tensor, size_k: int, size_n: int, group_size: int + s: torch.Tensor, size_k: int, size_n: int, group_size: int, is_a_8bit: bool = False ) -> torch.Tensor: scale_perm, scale_perm_single = get_scale_perms() - if group_size < size_k and group_size != -1: + if group_size < size_k and group_size != -1 and not is_a_8bit: s = s.reshape((-1, len(scale_perm)))[:, scale_perm] else: s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] @@ -305,11 +310,15 @@ def marlin_permute_bias(s: torch.Tensor) -> torch.Tensor: return s.reshape(*origin_shape).contiguous() +def marlin_act_int8_process_scales(s: torch.Tensor): + a_scales_scale_factor = 1 / 4096 * s.max().float() + s = s / s.max() * 4096 + s = s.round().to(torch.int16).view(s.dtype) + return s, a_scales_scale_factor + + def marlin_moe_permute_scales( - s: torch.Tensor, - size_k: int, - size_n: int, - group_size: int, + s: torch.Tensor, size_k: int, size_n: int, group_size: int, is_a_8bit: bool = False ): num_experts = s.shape[0] output = torch.empty( @@ -319,12 +328,12 @@ def marlin_moe_permute_scales( ) for e in range(num_experts): - output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size) + output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size, is_a_8bit) return output def marlin_zero_points( - zp: torch.Tensor, size_k: int, size_n: int, num_bits: int + zp: torch.Tensor, size_k: int, size_n: int, num_bits: int, is_a_8bit: bool = False ) -> torch.Tensor: # Permute zero-points in a similar way to scales, but do not use the # "single" permutation, since zero-points are applied on every MMA @@ -339,7 +348,8 @@ 
def marlin_zero_points( else: raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) - zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel() + if not is_a_8bit: + zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel() zp = zp.reshape((-1, size_n)).contiguous() zp = pack_cols(zp, num_bits, size_k, size_n) @@ -347,7 +357,11 @@ def marlin_zero_points( def awq_to_marlin_zero_points( - q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int + q_zp_packed: torch.Tensor, + size_k: int, + size_n: int, + num_bits: int, + is_a_8bit: bool = False, ) -> torch.Tensor: # AWQ zero-points are quantized and packed on the column dim. # In addition, the values are permuted based on dequantizer. @@ -366,12 +380,16 @@ def awq_to_marlin_zero_points( q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel() q_zp = q_zp.reshape((-1, size_n)).contiguous() - marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits) + marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits, is_a_8bit) return marlin_zp def moe_awq_to_marlin_zero_points( - q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int + q_zp_packed: torch.Tensor, + size_k: int, + size_n: int, + num_bits: int, + is_a_8bit: bool = False, ): num_experts = q_zp_packed.shape[0] output = torch.empty( @@ -380,7 +398,9 @@ def moe_awq_to_marlin_zero_points( dtype=q_zp_packed.dtype, ) for e in range(num_experts): - output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, num_bits) + output[e] = awq_to_marlin_zero_points( + q_zp_packed[e], size_k, size_n, num_bits, is_a_8bit + ) return output @@ -432,6 +452,48 @@ def should_use_atomic_add_reduce( return True +_quant_fp8_method: QuantFP8 | None = None + + +def get__quant_fp8_method() -> QuantFP8: + global _quant_fp8_method + if _quant_fp8_method is None: + _quant_fp8_method = QuantFP8(False, GroupShape.PER_TOKEN) + return _quant_fp8_method + + +def get_marlin_input_dtype(prefix): + if envs.VLLM_MARLIN_INPUT_DTYPE is None: + return + elif envs.VLLM_MARLIN_INPUT_DTYPE.lower() == "int8": + return torch.int8 + elif envs.VLLM_MARLIN_INPUT_DTYPE.lower() == "fp8": + if not current_platform.is_device_capability( + 89 + ) and not current_platform.is_device_capability(120): + raise ValueError( + "Marlin W4A8-FP8 only support SM89 or SM120 device " + "(It is slower than Marlin W4A16 on other devices). " + "You can consider using W4A8-INT8 instead" + "(set VLLM_MARLIN_INPUT_DTYPE=int8)." 
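The int8 branch above quantizes activations through per_token_quant_int8. As a reading aid, here is a hedged reference of what per-token symmetric int8 quantization computes; this sketch is illustrative and not claimed to be bit-identical to vLLM's implementation.

import torch

def per_token_int8_ref(x: torch.Tensor):
    # one scale per row (token); symmetric, no zero point
    amax = x.abs().amax(dim=-1, keepdim=True).float().clamp(min=1e-8)
    scales = amax / 127.0
    x_q = (x.float() / scales).round().clamp(-127, 127).to(torch.int8)
    return x_q, scales

x = torch.randn(4, 64, dtype=torch.bfloat16)
x_q, scales = per_token_int8_ref(x)
print(x_q.dtype, x_q.shape, scales.shape)               # int8, (4, 64), (4, 1)
print((x_q.float() * scales - x.float()).abs().max())   # bounded by scale / 2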
+ ) + + _ = get__quant_fp8_method() + return torch.float8_e4m3fn + else: + return + + +def marlin_quant_input(x: torch.Tensor, quant_dtype: torch.dtype): + x = x.reshape(-1, x.shape[-1]) + if quant_dtype == torch.int8: + return per_token_quant_int8(x) + elif quant_dtype == torch.float8_e4m3fn: + return get__quant_fp8_method()(x) + else: + raise ValueError(f"unsupported quant_dtype {quant_dtype}") + + def apply_gptq_marlin_linear( input: torch.Tensor, weight: torch.Tensor, @@ -444,8 +506,10 @@ def apply_gptq_marlin_linear( output_size_per_partition: int, input_size_per_partition: int, is_k_full: bool, + input_global_scale: torch.Tensor | None = None, bias: torch.Tensor | None = None, use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, + input_dtype: torch.dtype | None = None, ) -> torch.Tensor: reshaped_x = input.reshape(-1, input.shape[-1]) out_shape = input.shape[:-1] + (output_size_per_partition,) @@ -458,12 +522,27 @@ def apply_gptq_marlin_linear( dtype=input.dtype, ) + a_scales = None + if input_dtype == torch.int8: + assert wtype == scalar_types.uint4b8, ( + "W8A8-INT8 is not supported by marlin kernel." + ) + reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype) + a_scales = a_scales * input_global_scale + elif input_dtype == torch.float8_e4m3fn: + assert wtype == scalar_types.uint4b8, ( + "INT8 weight + FP8 activation is not supported." + ) + + reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype) + output = ops.gptq_marlin_gemm( reshaped_x, None, weight, bias, weight_scale, + a_scales, None, weight_zp, g_idx, @@ -493,8 +572,10 @@ def apply_awq_marlin_linear( quant_type: ScalarType, output_size_per_partition: int, input_size_per_partition: int, + input_global_scale: torch.Tensor | None = None, bias: torch.Tensor | None = None, use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, + input_dtype: torch.dtype | None = None, ) -> torch.Tensor: reshaped_x = input.reshape(-1, input.shape[-1]) out_shape = input.shape[:-1] + (output_size_per_partition,) @@ -507,12 +588,20 @@ def apply_awq_marlin_linear( dtype=input.dtype, ) + a_scales = None + if input_dtype == torch.int8: + reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype) + a_scales = a_scales * input_global_scale + elif input_dtype == torch.float8_e4m3fn: + reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype) + output = ops.gptq_marlin_gemm( reshaped_x, None, weight, bias, weight_scale, + a_scales, None, weight_zp, g_idx, @@ -538,8 +627,10 @@ def apply_rtn_marlin_linear( quant_type: ScalarType, output_size_per_partition: int, input_size_per_partition: int, + input_global_scale: torch.Tensor | None = None, bias: torch.Tensor | None = None, use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, + input_dtype: torch.dtype | None = None, ) -> torch.Tensor: reshaped_x = input.reshape(-1, input.shape[-1]) out_shape = input.shape[:-1] + (output_size_per_partition,) @@ -552,12 +643,20 @@ def apply_rtn_marlin_linear( dtype=input.dtype, ) + a_scales = None + if input_dtype == torch.int8: + reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype) + a_scales = a_scales * input_global_scale + elif input_dtype == torch.float8_e4m3fn: + reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype) + output = ops.gptq_marlin_gemm( reshaped_x, None, weight, bias, weight_scale, + a_scales, None, None, None, diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py index 842fb9b62267a..b94d5bbf36540 100644 --- 
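Note on the new a_scales argument passed to ops.gptq_marlin_gemm in the three apply_* wrappers above: once the activations have been quantized per token, the per-token scales (already multiplied by input_global_scale on the int8 path) must rescale the accumulators back to the activation's original range. The following is a pure-torch emulation of that bookkeeping; the shapes and the dense stand-in weight are assumptions made for the example.

import torch

A = torch.randn(4, 64)            # activations
W = torch.randn(64, 32)           # stand-in for the dequantized 4-bit weight

# per-token quantization, as marlin_quant_input does for input_dtype=int8
a_scales = A.abs().amax(dim=-1, keepdim=True) / 127.0
q_a = (A / a_scales).round().clamp(-127, 127).to(torch.int8)

# the wrappers fold input_global_scale into a_scales before the kernel call;
# here we use 1.0 so the comparison against the fp32 GEMM is direct
a_scales = a_scales * 1.0

out = (q_a.float() @ W) * a_scales     # integer-domain matmul + epilogue rescale
print((out - A @ W).abs().max())       # small; only int8 rounding error remains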
a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py @@ -11,6 +11,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import ( marlin_make_workspace_new, marlin_permute_bias, marlin_permute_scales, + marlin_quant_input, should_use_atomic_add_reduce, ) from vllm.platforms import current_platform @@ -37,12 +38,6 @@ def nvfp4_marlin_process_scales(marlin_scales): # convert to half first, we would convert to fp8 later marlin_scales = marlin_scales.to(torch.half) - # 8 is the number of scale number using by one thread - marlin_scales = marlin_scales.view(marlin_scales.size(0) // 2, 2, -1, 8) - marlin_scales = marlin_scales.permute(0, 2, 1, 3).reshape( - marlin_scales.size(0) * 2, -1 - ) - # fit the layout of fp8 dequantization marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view( marlin_scales.size(0), -1 @@ -62,18 +57,20 @@ def nvfp4_marlin_process_scales(marlin_scales): return marlin_scales -def mxfp4_marlin_process_scales(marlin_scales): - # 8 is the number of scale number using by one thread - marlin_scales = marlin_scales.view(marlin_scales.size(0) // 2, 2, -1, 8) - marlin_scales = marlin_scales.permute(0, 2, 1, 3).reshape( - marlin_scales.size(0) * 2, -1 - ) - +def mxfp4_marlin_process_scales(marlin_scales, input_dtype=None): # fit the layout of fp8 dequantization - marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view( - marlin_scales.size(0), -1 - ) + if input_dtype is None or input_dtype.itemsize == 2: + marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view( + marlin_scales.size(0), -1 + ) + marlin_scales = marlin_scales.to(torch.float8_e8m0fnu) + if input_dtype == torch.float8_e4m3fn: + marlin_scales = marlin_scales.view(torch.uint8) + assert marlin_scales.max() <= 249 + # exponent_bias (fp4->fp8) = 2 ** 3 - 2 ** 1 = 6 + marlin_scales = marlin_scales + 6 + marlin_scales = marlin_scales.view(torch.float8_e8m0fnu) return marlin_scales @@ -99,6 +96,7 @@ def apply_fp4_marlin_linear( size_n: int, size_k: int, bias: torch.Tensor | None = None, + input_dtype: torch.dtype | None = None, use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, ) -> torch.Tensor: # For GPUs that lack FP4 hardware support, we can leverage the @@ -111,12 +109,24 @@ def apply_fp4_marlin_linear( m=reshaped_x.size(0), n=size_n, k=size_k, device=input.device, dtype=input.dtype ) + inputs = reshaped_x + a_scales = None + is_nvfp4 = weight_scale_2 is not None + if input_dtype is not None and input_dtype.itemsize == 1: + if is_nvfp4: + raise RuntimeError("NVFP4 weight + INT8/FP8 activation is not supported.") + elif input_dtype != torch.float8_e4m3fn: + raise RuntimeError("MXFP4 weight + INT8 activation is not supported.") + + inputs, a_scales = marlin_quant_input(inputs, torch.float8_e4m3fn) + output = ops.gptq_marlin_gemm( - a=reshaped_x, + a=inputs, c=None, b_q_weight=weight, b_bias=bias, b_scales=weight_scale, + a_scales=a_scales, global_scale=weight_scale_2, b_zeros=None, g_idx=None, @@ -133,7 +143,9 @@ def apply_fp4_marlin_linear( return output.reshape(out_shape) -def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: +def prepare_fp4_layer_for_marlin( + layer: torch.nn.Module, input_dtype: torch.dtype | None = None +) -> None: logger.warning_once( "Your GPU does not have native support for FP4 computation but " "FP4 quantization is being used. 
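Note on the `marlin_scales + 6` adjustment in mxfp4_marlin_process_scales above: e8m0 scales store only an exponent byte, so bumping the byte by 6 multiplies every scale by 2**6, which is the exponent-bias gap between e2m1 (fp4) and e4m3 (fp8) that the hunk's comment refers to; the `max() <= 249` assert guarantees the bump cannot overflow the byte. A tiny demonstration (assumes a PyTorch build exposing torch.float8_e8m0fnu, which the patch itself already requires):

import torch

# powers of two convert to e8m0 exactly, so the ratio below is exact
s = torch.tensor([0.5, 1.0, 4.0]).to(torch.float8_e8m0fnu)
bumped = (s.view(torch.uint8) + 6).view(torch.float8_e8m0fnu)
print(bumped.float() / s.float())      # tensor([64., 64., 64.]) == 2 ** 6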
Weight-only FP4 compression will " @@ -160,12 +172,14 @@ def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: perm = torch.empty(0, dtype=torch.int, device=device) qweight = layer.weight.view(torch.int32).T.contiguous() + is_a_8bit = input_dtype is not None and input_dtype.itemsize == 1 marlin_qweight = ops.gptq_marlin_repack( b_q_weight=qweight, perm=perm, size_k=part_size_k, size_n=part_size_n, num_bits=4, + is_a_8bit=is_a_8bit, ) layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) @@ -178,7 +192,11 @@ def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: weight_scale = weight_scale.to(param_dtype) weight_scale = marlin_permute_scales( - s=weight_scale, size_k=part_size_k, size_n=part_size_n, group_size=group_size + s=weight_scale, + size_k=part_size_k, + size_n=part_size_n, + group_size=group_size, + is_a_8bit=is_a_8bit, ) if is_nvfp4: @@ -189,7 +207,9 @@ def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: weight_scale_2 = nvfp4_marlin_process_global_scale(weight_scale_2) layer.weight_scale_2 = torch.nn.Parameter(weight_scale_2, requires_grad=False) else: - weight_scale = mxfp4_marlin_process_scales(weight_scale) + weight_scale = mxfp4_marlin_process_scales( + weight_scale, input_dtype=input_dtype + ) layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) if hasattr(layer, "bias") and layer.bias is not None: @@ -200,7 +220,9 @@ def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: return -def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: +def prepare_moe_fp4_layer_for_marlin( + layer: torch.nn.Module, input_dtype: torch.dtype | None = None +) -> None: logger.warning_once( "Your GPU does not have native support for FP4 computation but " "FP4 quantization is being used. 
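The same is_a_8bit predicate is re-derived in every prepare_*_for_marlin function above: any 1-byte activation dtype switches the repack and scale layouts to the W4A8 variants. A one-liner showing which dtypes it selects (purely illustrative):

import torch

def is_a_8bit(input_dtype: torch.dtype | None) -> bool:
    # the check used throughout this patch
    return input_dtype is not None and input_dtype.itemsize == 1

for dt in (None, torch.float16, torch.bfloat16, torch.int8, torch.float8_e4m3fn):
    print(dt, is_a_8bit(dt))   # only int8 and float8_e4m3fn return True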
Weight-only FP4 compression will " @@ -220,6 +242,7 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: param_dtype = layer.params_dtype layer.workspace = marlin_make_workspace_new(device, 4) perm = torch.empty(0, dtype=torch.int, device=device) + is_a_8bit = input_dtype is not None and input_dtype.itemsize == 1 # WEIGHT # Repack weights to marlin format @@ -237,7 +260,12 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: qweight = weight[i].view(torch.int32).T.contiguous() marlin_qweight = ops.gptq_marlin_repack( - b_q_weight=qweight, perm=perm, size_k=size_k, size_n=size_n, num_bits=4 + b_q_weight=qweight, + perm=perm, + size_k=size_k, + size_n=size_n, + num_bits=4, + is_a_8bit=is_a_8bit, ) tensor_list.append(marlin_qweight) @@ -266,12 +294,18 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: scale = scales[i].T marlin_scales = marlin_permute_scales( - s=scale, size_k=size_k, size_n=size_n, group_size=group_size + s=scale, + size_k=size_k, + size_n=size_n, + group_size=group_size, + is_a_8bit=is_a_8bit, ) if is_nvfp4: marlin_scales = nvfp4_marlin_process_scales(marlin_scales) else: - marlin_scales = mxfp4_marlin_process_scales(marlin_scales) + marlin_scales = mxfp4_marlin_process_scales( + marlin_scales, input_dtype=input_dtype + ) tensor_list.append(marlin_scales) scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) @@ -301,7 +335,10 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: setattr(layer, name, bias) -def rand_marlin_weight_nvfp4_like(weight, group_size): +def rand_marlin_weight_nvfp4_like(weight, group_size, input_dtype=None): + is_a_8bit = input_dtype is not None and input_dtype.itemsize == 1 + + assert not is_a_8bit, "NVFP4 weight + INT8/FP8 activation is not supported." assert group_size > 0 size_n, size_k = weight.shape device = weight.device @@ -337,10 +374,15 @@ def rand_marlin_weight_nvfp4_like(weight, group_size): size_k=size_k, size_n=size_n, num_bits=4, + is_a_8bit=is_a_8bit, ) marlin_scales = marlin_permute_scales( - s=scales.T.to(weight.dtype), size_k=size_k, size_n=size_n, group_size=group_size + s=scales.T.to(weight.dtype), + size_k=size_k, + size_n=size_n, + group_size=group_size, + is_a_8bit=is_a_8bit, ) marlin_scales = nvfp4_marlin_process_scales(marlin_scales) @@ -349,14 +391,20 @@ def rand_marlin_weight_nvfp4_like(weight, group_size): return weight_ref.T, marlin_qweight, marlin_scales, global_scale -def rand_marlin_weight_mxfp4_like(weight, group_size): +def rand_marlin_weight_mxfp4_like(weight, group_size, input_dtype=None): + is_a_8bit = input_dtype is not None and input_dtype.itemsize == 1 + if is_a_8bit: + assert input_dtype == torch.float8_e4m3fn, ( + "MXFP4 weight + INT8 activation is not supported." 
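Reading the asserts and RuntimeErrors across these hunks together gives the weight/activation combinations the Marlin path accepts once an 8-bit activation dtype is requested (16-bit activations are unaffected). The summary below is a reading aid derived from those error messages, not an exhaustive support matrix; the fp8-weight rule comes from the marlin_utils_fp8.py hunk further down.

import torch

# weight format -> 8-bit activation dtypes accepted by the patched Marlin path
SUPPORTED_8BIT_ACT = {
    "int4 (GPTQ uint4b8 / AWQ uint4)": {torch.int8, torch.float8_e4m3fn},
    "mxfp4": {torch.float8_e4m3fn},   # "MXFP4 weight + INT8 activation is not supported."
    "nvfp4": set(),                   # "NVFP4 weight + INT8/FP8 activation is not supported."
    "fp8": {torch.float8_e4m3fn},     # "FP8 weight + INT8 activation is not supported."
}
for wfmt, acts in SUPPORTED_8BIT_ACT.items():
    print(f"{wfmt}: {sorted(str(a) for a in acts) or ['16-bit activations only']}")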
+ ) + assert group_size > 0 size_n, size_k = weight.shape device = weight.device scales = torch.randint( - 100, - 125, + 110, + 120, (size_n, size_k // group_size), dtype=torch.uint8, device=weight.device, @@ -380,18 +428,25 @@ def rand_marlin_weight_mxfp4_like(weight, group_size): ).view(size_n, size_k) weight_ref = weight_ref * scales.repeat_interleave(group_size, 1).to(weight.dtype) + perm = torch.empty(0, dtype=torch.int, device=device) + fp4_weight = fp4_weight.view(torch.int32).T.contiguous() marlin_qweight = ops.gptq_marlin_repack( - b_q_weight=fp4_weight.view(torch.int32).T.contiguous(), - perm=torch.empty(0, dtype=torch.int, device=device), + b_q_weight=fp4_weight, + perm=perm, size_k=size_k, size_n=size_n, num_bits=4, + is_a_8bit=is_a_8bit, ) marlin_scales = marlin_permute_scales( - s=scales.T.to(weight.dtype), size_k=size_k, size_n=size_n, group_size=group_size + s=scales.T.to(weight.dtype), + size_k=size_k, + size_n=size_n, + group_size=group_size, + is_a_8bit=is_a_8bit, ) - marlin_scales = mxfp4_marlin_process_scales(marlin_scales) + marlin_scales = mxfp4_marlin_process_scales(marlin_scales, input_dtype=input_dtype) return weight_ref.T, marlin_qweight, marlin_scales.to(torch.float8_e8m0fnu) diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py index 8c96848a85397..e6b4f567caea4 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py @@ -11,6 +11,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import ( marlin_make_workspace_new, marlin_permute_bias, marlin_permute_scales, + marlin_quant_input, should_use_atomic_add_reduce, ) from vllm.platforms import current_platform @@ -45,6 +46,7 @@ def apply_fp8_marlin_linear( size_n: int, size_k: int, bias: torch.Tensor | None, + input_dtype: torch.dtype | None = None, use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, ) -> torch.Tensor: # For GPUs that lack FP8 hardware support, we can leverage the @@ -57,12 +59,21 @@ def apply_fp8_marlin_linear( m=reshaped_x.size(0), n=size_n, k=size_k, device=input.device, dtype=input.dtype ) + inputs = reshaped_x + a_scales = None + if input_dtype is not None and input_dtype.itemsize == 1: + if input_dtype != torch.float8_e4m3fn: + raise RuntimeError("FP8 weight + INT8 activation is not supported.") + + inputs, a_scales = marlin_quant_input(inputs, torch.float8_e4m3fn) + output = ops.gptq_marlin_gemm( a=reshaped_x, c=None, b_q_weight=weight, b_bias=bias, b_scales=weight_scale, + a_scales=a_scales, global_scale=None, b_zeros=None, g_idx=None, @@ -80,7 +91,9 @@ def apply_fp8_marlin_linear( def prepare_fp8_layer_for_marlin( - layer: torch.nn.Module, size_k_first: bool = True + layer: torch.nn.Module, + size_k_first: bool = True, + input_dtype: torch.dtype | None = None, ) -> None: logger.warning_once( "Your GPU does not have native support for FP8 computation but " @@ -162,7 +175,8 @@ def prepare_fp8_layer_for_marlin( marlin_scales = marlin_permute_scales( s=scales, size_k=part_size_k, size_n=part_size_n, group_size=group_size ) - marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales) + if input_dtype != torch.float8_e4m3fn: + marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales) layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False) if hasattr(layer, "bias") and layer.bias is not None: @@ -172,7 +186,9 @@ def prepare_fp8_layer_for_marlin( def 
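In the rand_marlin_weight_mxfp4_like test helper touched above, the reference weight is rebuilt by expanding one scale per group of group_size columns to the full k width with repeat_interleave before multiplying it into the dequantized values. A minimal shape illustration (sizes are made up):

import torch

size_n, size_k, group_size = 4, 128, 32
scales = torch.rand(size_n, size_k // group_size)   # one scale per (row, group)
dequant = torch.randn(size_n, size_k)               # stand-in dequantized values

# each group scale is repeated across its group_size columns
weight_ref = dequant * scales.repeat_interleave(group_size, 1)
print(scales.shape, weight_ref.shape)               # (4, 4) -> (4, 128)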
prepare_moe_fp8_layer_for_marlin( - layer: torch.nn.Module, size_k_first: bool = True + layer: torch.nn.Module, + size_k_first: bool = True, + input_dtype: torch.dtype | None = None, ) -> None: logger.warning_once( "Your GPU does not have native support for FP8 computation but " @@ -278,7 +294,8 @@ def prepare_moe_fp8_layer_for_marlin( tensor_list.append(marlin_scales) scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) - scales = fp8_fused_exponent_bias_into_scales(scales) + if input_dtype != torch.float8_e4m3fn: + scales = fp8_fused_exponent_bias_into_scales(scales) scales = torch.nn.Parameter(scales, requires_grad=False) setattr(layer, name + "_weight_scale", scales) @@ -318,7 +335,11 @@ def pack_fp8_to_int32( return int32_tensor.T.contiguous() if size_k_first else int32_tensor -def marlin_quant_fp8_torch(weight, group_size): +def marlin_quant_fp8_torch(weight, group_size, input_dtype=None): + is_a_8bit = input_dtype is not None and input_dtype.itemsize == 1 + if is_a_8bit: + assert input_dtype == torch.float8_e4m3fn + size_n, size_k = weight.shape device = weight.device @@ -334,16 +355,22 @@ def marlin_quant_fp8_torch(weight, group_size): weight_ref = fp8_weight.to(weight.dtype) * repeated_scales packed_weight = pack_fp8_to_int32(fp8_weight, False).T.contiguous() + perm = torch.empty(0, dtype=torch.int, device=device) marlin_qweight = ops.gptq_marlin_repack( b_q_weight=packed_weight, - perm=torch.empty(0, dtype=torch.int, device=device), + perm=perm, size_k=size_k, size_n=size_n, num_bits=8, + is_a_8bit=is_a_8bit, ) marlin_scales = marlin_permute_scales( - s=scales.T, size_k=size_k, size_n=size_n, group_size=group_size + s=scales.T, + size_k=size_k, + size_n=size_n, + group_size=group_size, + is_a_8bit=is_a_8bit, ) marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales) diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py index 89756c45ef556..9162afe03da90 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py @@ -5,7 +5,8 @@ import numpy as np import torch -from vllm.scalar_type import ScalarType +from vllm import _custom_ops as ops +from vllm.scalar_type import ScalarType, scalar_types from .marlin_utils import GPTQ_MARLIN_TILE, marlin_permute_scales, marlin_zero_points from .quant_utils import ( @@ -29,13 +30,19 @@ class MarlinWorkspace: self.scratch = torch.zeros(max_workspace_size, dtype=torch.int, device="cuda") -def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE): +def marlin_permute_weights( + q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE, is_a_8bit=False +): assert q_w.shape == (size_k, size_n) assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}" assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}" - # Permute weights to 16x64 marlin tiles - q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile)) + if is_a_8bit: + # Permute weights to 32x32 marlin tiles + q_w = q_w.reshape((size_k // (tile * 2), tile * 2, size_n // tile, tile)) + else: + # Permute weights to 16x64 marlin tiles + q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile)) q_w = q_w.permute((0, 2, 1, 3)) q_w = q_w.reshape((size_k // tile, size_n * tile)) @@ -44,9 +51,9 @@ def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE): return q_w -def marlin_weights(q_w, size_k, size_n, num_bits, perm): +def marlin_weights(q_w, 
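Note on the marlin_permute_weights change above: when activations are 8-bit, the k-extent of each reshaped block doubles from tile to tile * 2 before the same permute/flatten, so the packed weights land in a different element order even though the flattened shape is unchanged. A small demonstration of just that reshape step (the perm indexing applied afterwards in the real helper is omitted):

import torch

size_k, size_n, tile = 64, 64, 16
q_w = torch.arange(size_k * size_n).reshape(size_k, size_n)

# default path: blocks of `tile` rows of k
a = q_w.reshape(size_k // tile, tile, size_n // tile, tile)
a = a.permute(0, 2, 1, 3).reshape(size_k // tile, size_n * tile)

# is_a_8bit path: blocks of `tile * 2` rows of k
b = q_w.reshape(size_k // (tile * 2), tile * 2, size_n // tile, tile)
b = b.permute(0, 2, 1, 3).reshape(size_k // tile, size_n * tile)

print(a.shape, b.shape, torch.equal(a, b))   # same shape, different ordering -> False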
size_k, size_n, num_bits, perm, is_a_8bit=False): # Permute - q_w = marlin_permute_weights(q_w, size_k, size_n, perm) + q_w = marlin_permute_weights(q_w, size_k, size_n, perm, is_a_8bit=is_a_8bit) # Pack pack_factor = get_pack_factor(num_bits) @@ -63,28 +70,53 @@ def marlin_weights(q_w, size_k, size_n, num_bits, perm): return q_packed -def get_weight_perm(num_bits: int): +def get_weight_perm(num_bits: int, is_a_8bit: bool = False): perm_list: list[int] = [] - for i in range(32): - perm1: list[int] = [] - col = i // 4 - for block in [0, 1]: - for row in [ - 2 * (i % 4), - 2 * (i % 4) + 1, - 2 * (i % 4 + 4), - 2 * (i % 4 + 4) + 1, - ]: - perm1.append(16 * row + col + 8 * block) - for j in range(4): - perm_list.extend([p + 256 * j for p in perm1]) + if is_a_8bit: + for i in range(32): + perm1 = [] + col = i // 4 + for block in [0, 1]: + for row in [ + 4 * (i % 4), + 4 * (i % 4) + 1, + 4 * (i % 4) + 2, + 4 * (i % 4) + 3, + 4 * (i % 4 + 4), + 4 * (i % 4 + 4) + 1, + 4 * (i % 4 + 4) + 2, + 4 * (i % 4 + 4) + 3, + ]: + perm1.append(16 * row + col + 8 * block) + for j in range(2): + perm_list.extend([p + 512 * j for p in perm1]) + else: + for i in range(32): + perm1 = [] + col = i // 4 + for block in [0, 1]: + for row in [ + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, + ]: + perm1.append(16 * row + col + 8 * block) + for j in range(4): + perm_list.extend([p + 256 * j for p in perm1]) perm = np.array(perm_list) if num_bits == 4: - interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7]) + if is_a_8bit: # noqa: SIM108 + interleave = np.array([0, 4, 1, 5, 2, 6, 3, 7]) + else: + interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7]) elif num_bits == 8: - interleave = np.array([0, 2, 1, 3]) + if is_a_8bit: # noqa: SIM108 + interleave = np.array([0, 1, 2, 3]) + else: + interleave = np.array([0, 2, 1, 3]) else: raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) @@ -99,7 +131,10 @@ def marlin_quantize( group_size: int, act_order: bool, test_perm: torch.Tensor | None = None, + input_dtype: torch.dtype | None = None, ): + is_a_8bit = input_dtype is not None and input_dtype.itemsize == 1 + size_k, size_n = w.shape num_bits = quant_type.size_bits @@ -120,9 +155,15 @@ def marlin_quantize( q_w, g_idx, sort_indices = sort_weights(q_w, g_idx) # Reformat to marlin - weight_perm = get_weight_perm(num_bits) - marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm) - marlin_s = marlin_permute_scales(s, size_k, size_n, group_size) + weight_perm = get_weight_perm(num_bits, is_a_8bit) + marlin_q_w = marlin_weights( + q_w, size_k, size_n, num_bits, weight_perm, is_a_8bit=is_a_8bit + ) + marlin_s = marlin_permute_scales(s, size_k, size_n, group_size, is_a_8bit=is_a_8bit) + + if input_dtype == torch.float8_e4m3fn and quant_type == scalar_types.uint4b8: + ops.marlin_int4_fp8_preprocess(marlin_q_w, inplace=True) + marlin_s = marlin_s * 512 # Create result res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm] @@ -132,7 +173,13 @@ def marlin_quantize( return res_list -def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int): +def awq_marlin_quantize( + w: torch.Tensor, + quant_type: ScalarType, + group_size: int, + input_dtype: torch.dtype | None = None, +): + is_a_8bit = input_dtype is not None and input_dtype.itemsize == 1 size_k, size_n = w.shape # Normalize group_size @@ -147,11 +194,22 @@ def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int # Quantize with zp w_ref, q_w, s, zp = quantize_weights(w, 
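The is_a_8bit branch of get_weight_perm above builds the lane-to-weight mapping for the W4A8 tile layout. Copying that branch verbatim, the quick check below just confirms it really is a permutation of the 1024 weight slots it indexes (the num_bits-dependent interleave applied afterwards in the original helper is skipped here):

import numpy as np

def get_weight_perm_8bit_act() -> np.ndarray:
    # verbatim from the new is_a_8bit branch in marlin_utils_test.py
    perm_list: list[int] = []
    for i in range(32):
        perm1: list[int] = []
        col = i // 4
        for block in [0, 1]:
            for row in [
                4 * (i % 4), 4 * (i % 4) + 1, 4 * (i % 4) + 2, 4 * (i % 4) + 3,
                4 * (i % 4 + 4), 4 * (i % 4 + 4) + 1,
                4 * (i % 4 + 4) + 2, 4 * (i % 4 + 4) + 3,
            ]:
                perm1.append(16 * row + col + 8 * block)
        for j in range(2):
            perm_list.extend([p + 512 * j for p in perm1])
    return np.array(perm_list)

perm = get_weight_perm_8bit_act()
print(perm.size, np.array_equal(np.sort(perm), np.arange(1024)))   # 1024 True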
quant_type, group_size, zero_points=True)
+    if input_dtype == torch.float8_e4m3fn and quant_type == scalar_types.uint4:
+        repeated_zp = zp.repeat_interleave(group_size, 0)
+        q_w_old = q_w
+        q_w = q_w_old - repeated_zp
+        q_w[q_w < 0] = 15 - q_w_old[q_w < 0]
+        s = s * 512
+
     # Reformat to marlin
-    weight_perm = get_weight_perm(quant_type.size_bits)
-    marlin_q_w = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, weight_perm)
-    marlin_s = marlin_permute_scales(s, size_k, size_n, group_size)
-    marlin_zp = marlin_zero_points(zp, num_groups, size_n, quant_type.size_bits)
+    weight_perm = get_weight_perm(quant_type.size_bits, is_a_8bit)
+    marlin_q_w = marlin_weights(
+        q_w, size_k, size_n, quant_type.size_bits, weight_perm, is_a_8bit=is_a_8bit
+    )
+    marlin_s = marlin_permute_scales(s, size_k, size_n, group_size, is_a_8bit=is_a_8bit)
+    marlin_zp = marlin_zero_points(
+        zp, num_groups, size_n, quant_type.size_bits, is_a_8bit=is_a_8bit
+    )
     # Create result
     res_list = [w_ref, marlin_q_w, marlin_s, marlin_zp]
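For completeness, the whole W4A8 path is opt-in through the environment variable parsed by get_marlin_input_dtype earlier in this diff. A hypothetical usage sketch follows; the checkpoint path is a placeholder, "int8" carries no extra architecture gate in this patch, and "fp8" is restricted to SM89/SM120.

import os

# select the 8-bit Marlin activation path before constructing the engine
os.environ["VLLM_MARLIN_INPUT_DTYPE"] = "int8"   # or "fp8" on SM89 / SM120

from vllm import LLM  # noqa: E402

# placeholder checkpoint: any int4 GPTQ/AWQ model served through Marlin
llm = LLM(model="path/to/int4-gptq-checkpoint")
print(llm.generate("Hello")[0].outputs[0].text)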