[Kernel][Quantization] add w4a8 support for marlin kernel (#24722)

Signed-off-by: Jinzhen Lin <jinzhen.ljz@antgroup.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Michael Goin <mgoin@redhat.com>
This commit is contained in:
Jinzhen Lin 2025-11-29 23:19:33 +08:00 committed by GitHub
parent fa59fe417f
commit 1656ad3704
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
46 changed files with 4371 additions and 2240 deletions

View File

@ -354,8 +354,17 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# Only build Marlin kernels if we are building for at least some compatible archs. # Only build Marlin kernels if we are building for at least some compatible archs.
# Keep building Marlin for 9.0 as there are some group sizes and shapes that # Keep building Marlin for 9.0 as there are some group sizes and shapes that
# are not supported by Machete yet. # are not supported by Machete yet.
# 9.0 for latest bf16 atomicAdd PTX
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}") # marlin arches for fp16 output
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0+PTX" "${CUDA_ARCHS}")
# marlin arches for bf16 output (we need 9.0 for bf16 atomicAdd PTX)
cuda_archs_loose_intersection(MARLIN_BF16_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}")
# marlin arches for fp8 input
# - sm80 doesn't support fp8 computation
# - sm90 and sm100 don't support QMMA.16832.F32.E4M3.E4M3 SAAS instruction
# so we only enable fp8 computation for SM89 (e.g. RTX 40x0) and 12.0 (e.g. RTX 50x0)
cuda_archs_loose_intersection(MARLIN_FP8_ARCHS "8.9;12.0" "${CUDA_ARCHS}")
if (MARLIN_ARCHS) if (MARLIN_ARCHS)
# #
@ -365,16 +374,18 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
set(MARLIN_GEN_SCRIPT set(MARLIN_GEN_SCRIPT
${CMAKE_CURRENT_SOURCE_DIR}/csrc/quantization/gptq_marlin/generate_kernels.py) ${CMAKE_CURRENT_SOURCE_DIR}/csrc/quantization/gptq_marlin/generate_kernels.py)
file(MD5 ${MARLIN_GEN_SCRIPT} MARLIN_GEN_SCRIPT_HASH) file(MD5 ${MARLIN_GEN_SCRIPT} MARLIN_GEN_SCRIPT_HASH)
list(JOIN CUDA_ARCHS "," CUDA_ARCHS_STR)
set(MARLIN_GEN_SCRIPT_HASH_AND_ARCH "${MARLIN_GEN_SCRIPT_HASH}(ARCH:${CUDA_ARCHS_STR})")
message(STATUS "Marlin generation script hash: ${MARLIN_GEN_SCRIPT_HASH}") message(STATUS "Marlin generation script hash: ${MARLIN_GEN_SCRIPT_HASH_AND_ARCH}")
message(STATUS "Last run Marlin generate script hash: $CACHE{MARLIN_GEN_SCRIPT_HASH}") message(STATUS "Last run Marlin generate script hash: $CACHE{MARLIN_GEN_SCRIPT_HASH_AND_ARCH}")
if (NOT DEFINED CACHE{MARLIN_GEN_SCRIPT_HASH} if (NOT DEFINED CACHE{MARLIN_GEN_SCRIPT_HASH_AND_ARCH}
OR NOT $CACHE{MARLIN_GEN_SCRIPT_HASH} STREQUAL ${MARLIN_GEN_SCRIPT_HASH}) OR NOT $CACHE{MARLIN_GEN_SCRIPT_HASH_AND_ARCH} STREQUAL ${MARLIN_GEN_SCRIPT_HASH_AND_ARCH})
execute_process( execute_process(
COMMAND ${CMAKE_COMMAND} -E env COMMAND ${CMAKE_COMMAND} -E env
PYTHONPATH=$PYTHONPATH PYTHONPATH=$PYTHONPATH
${Python_EXECUTABLE} ${MARLIN_GEN_SCRIPT} ${Python_EXECUTABLE} ${MARLIN_GEN_SCRIPT} ${CUDA_ARCHS_STR}
RESULT_VARIABLE marlin_generation_result RESULT_VARIABLE marlin_generation_result
OUTPUT_VARIABLE marlin_generation_result OUTPUT_VARIABLE marlin_generation_result
OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/marlin_generation.log OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/marlin_generation.log
@ -387,15 +398,15 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"\nCheck the log for details: " "\nCheck the log for details: "
"${CMAKE_CURRENT_BINARY_DIR}/marlin_generation.log") "${CMAKE_CURRENT_BINARY_DIR}/marlin_generation.log")
else() else()
set(MARLIN_GEN_SCRIPT_HASH ${MARLIN_GEN_SCRIPT_HASH} set(MARLIN_GEN_SCRIPT_HASH_AND_ARCH ${MARLIN_GEN_SCRIPT_HASH_AND_ARCH}
CACHE STRING "Last run Marlin generate script hash" FORCE) CACHE STRING "Last run Marlin generate script hash and arch" FORCE)
message(STATUS "Marlin generation completed successfully.") message(STATUS "Marlin generation completed successfully.")
endif() endif()
else() else()
message(STATUS "Marlin generation script has not changed, skipping generation.") message(STATUS "Marlin generation script has not changed, skipping generation.")
endif() endif()
file(GLOB MARLIN_TEMPLATE_KERNEL_SRC "csrc/quantization/gptq_marlin/kernel_*.cu") file(GLOB MARLIN_TEMPLATE_KERNEL_SRC "csrc/quantization/gptq_marlin/sm80_kernel_*_float16.cu")
set_gencode_flags_for_srcs( set_gencode_flags_for_srcs(
SRCS "${MARLIN_TEMPLATE_KERNEL_SRC}" SRCS "${MARLIN_TEMPLATE_KERNEL_SRC}"
CUDA_ARCHS "${MARLIN_ARCHS}") CUDA_ARCHS "${MARLIN_ARCHS}")
@ -403,12 +414,34 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
set_source_files_properties(${MARLIN_TEMPLATE_KERNEL_SRC} set_source_files_properties(${MARLIN_TEMPLATE_KERNEL_SRC}
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false") PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
endif() endif()
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC}) list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC})
file(GLOB MARLIN_TEMPLATE_BF16_KERNEL_SRC "csrc/quantization/gptq_marlin/sm80_kernel_*_bfloat16.cu")
set_gencode_flags_for_srcs(
SRCS "${MARLIN_TEMPLATE_BF16_KERNEL_SRC}"
CUDA_ARCHS "${MARLIN_BF16_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
set_source_files_properties(${MARLIN_TEMPLATE_BF16_KERNEL_SRC}
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
endif()
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_BF16_KERNEL_SRC})
if (MARLIN_FP8_ARCHS)
file(GLOB MARLIN_TEMPLATE_FP8_KERNEL_SRC "csrc/quantization/gptq_marlin/sm89_kernel_*.cu")
set_gencode_flags_for_srcs(
SRCS "${MARLIN_TEMPLATE_FP8_KERNEL_SRC}"
CUDA_ARCHS "${MARLIN_FP8_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
set_source_files_properties(${MARLIN_TEMPLATE_FP8_KERNEL_SRC}
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
endif()
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_FP8_KERNEL_SRC})
endif()
set(MARLIN_SRCS set(MARLIN_SRCS
"csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu" "csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
"csrc/quantization/gptq_marlin/gptq_marlin.cu" "csrc/quantization/gptq_marlin/gptq_marlin.cu"
"csrc/quantization/gptq_marlin/marlin_int4_fp8_preprocess.cu"
"csrc/quantization/gptq_marlin/gptq_marlin_repack.cu" "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
"csrc/quantization/gptq_marlin/awq_marlin_repack.cu") "csrc/quantization/gptq_marlin/awq_marlin_repack.cu")
set_gencode_flags_for_srcs( set_gencode_flags_for_srcs(
@ -941,8 +974,15 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
CUDA_ARCHS "${CUDA_ARCHS}") CUDA_ARCHS "${CUDA_ARCHS}")
list(APPEND VLLM_MOE_EXT_SRC "${VLLM_MOE_WNA16_SRC}") list(APPEND VLLM_MOE_EXT_SRC "${VLLM_MOE_WNA16_SRC}")
# 9.0 for latest bf16 atomicAdd PTX # moe marlin arches
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}") # note that we always set `use_atomic_add=False` for moe marlin now,
# so we don't need 9.0 for bf16 atomicAdd PTX
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0+PTX" "${CUDA_ARCHS}")
# moe marlin arches for fp8 input
# - sm80 doesn't support fp8 computation
# - sm90 and sm100 don't support QMMA.16832.F32.E4M3.E4M3 SAAS instruction
# so we only enable fp8 computation for SM89 (e.g. RTX 40x0) and 12.0 (e.g. RTX 50x0)
cuda_archs_loose_intersection(MARLIN_MOE_FP8_ARCHS "8.9;12.0" "${CUDA_ARCHS}")
if (MARLIN_MOE_ARCHS) if (MARLIN_MOE_ARCHS)
# #
@ -952,16 +992,18 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
set(MOE_MARLIN_GEN_SCRIPT set(MOE_MARLIN_GEN_SCRIPT
${CMAKE_CURRENT_SOURCE_DIR}/csrc/moe/marlin_moe_wna16/generate_kernels.py) ${CMAKE_CURRENT_SOURCE_DIR}/csrc/moe/marlin_moe_wna16/generate_kernels.py)
file(MD5 ${MOE_MARLIN_GEN_SCRIPT} MOE_MARLIN_GEN_SCRIPT_HASH) file(MD5 ${MOE_MARLIN_GEN_SCRIPT} MOE_MARLIN_GEN_SCRIPT_HASH)
list(JOIN CUDA_ARCHS "," CUDA_ARCHS_STR)
set(MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH "${MOE_MARLIN_GEN_SCRIPT_HASH}(ARCH:${CUDA_ARCHS_STR})")
message(STATUS "Marlin MOE generation script hash: ${MOE_MARLIN_GEN_SCRIPT_HASH}") message(STATUS "Marlin MOE generation script hash with arch: ${MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH}")
message(STATUS "Last run Marlin MOE generate script hash: $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH}") message(STATUS "Last run Marlin MOE generate script hash with arch: $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH}")
if (NOT DEFINED CACHE{MOE_MARLIN_GEN_SCRIPT_HASH} if (NOT DEFINED CACHE{MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH}
OR NOT $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH} STREQUAL ${MOE_MARLIN_GEN_SCRIPT_HASH}) OR NOT $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH} STREQUAL ${MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH})
execute_process( execute_process(
COMMAND ${CMAKE_COMMAND} -E env COMMAND ${CMAKE_COMMAND} -E env
PYTHONPATH=$PYTHONPATH PYTHONPATH=$PYTHONPATH
${Python_EXECUTABLE} ${MOE_MARLIN_GEN_SCRIPT} ${Python_EXECUTABLE} ${MOE_MARLIN_GEN_SCRIPT} ${CUDA_ARCHS_STR}
RESULT_VARIABLE moe_marlin_generation_result RESULT_VARIABLE moe_marlin_generation_result
OUTPUT_VARIABLE moe_marlin_generation_output OUTPUT_VARIABLE moe_marlin_generation_output
OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log
@ -974,7 +1016,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"\nCheck the log for details: " "\nCheck the log for details: "
"${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log") "${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log")
else() else()
set(MOE_MARLIN_GEN_SCRIPT_HASH ${MOE_MARLIN_GEN_SCRIPT_HASH} set(MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH ${MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH}
CACHE STRING "Last run Marlin MOE generate script hash" FORCE) CACHE STRING "Last run Marlin MOE generate script hash" FORCE)
message(STATUS "Marlin MOE generation completed successfully.") message(STATUS "Marlin MOE generation completed successfully.")
endif() endif()
@ -982,16 +1024,28 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
message(STATUS "Marlin MOE generation script has not changed, skipping generation.") message(STATUS "Marlin MOE generation script has not changed, skipping generation.")
endif() endif()
file(GLOB MOE_WNAA16_MARLIN_SRC "csrc/moe/marlin_moe_wna16/*.cu") file(GLOB MARLIN_MOE_SRC "csrc/moe/marlin_moe_wna16/sm80_kernel_*.cu")
list(APPEND MARLIN_MOE_SRC "csrc/moe/marlin_moe_wna16/ops.cu")
set_gencode_flags_for_srcs( set_gencode_flags_for_srcs(
SRCS "${MOE_WNAA16_MARLIN_SRC}" SRCS "${MARLIN_MOE_SRC}"
CUDA_ARCHS "${MARLIN_MOE_ARCHS}") CUDA_ARCHS "${MARLIN_MOE_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8) if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
set_source_files_properties(${MOE_WNAA16_MARLIN_SRC} set_source_files_properties(${MARLIN_MOE_SRC}
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false") PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
endif() endif()
list(APPEND VLLM_MOE_EXT_SRC ${MARLIN_MOE_SRC})
list(APPEND VLLM_MOE_EXT_SRC ${MOE_WNAA16_MARLIN_SRC}) if (MARLIN_MOE_FP8_ARCHS)
file(GLOB MARLIN_MOE_FP8_SRC "csrc/moe/marlin_moe_wna16/sm89_kernel_*.cu")
set_gencode_flags_for_srcs(
SRCS "${MARLIN_MOE_FP8_SRC}"
CUDA_ARCHS "${MARLIN_MOE_FP8_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
set_source_files_properties(${MARLIN_MOE_FP8_SRC}
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
endif()
list(APPEND VLLM_MOE_EXT_SRC ${MARLIN_MOE_FP8_SRC})
endif()
message(STATUS "Building Marlin MOE kernels for archs: ${MARLIN_MOE_ARCHS}") message(STATUS "Building Marlin MOE kernels for archs: ${MARLIN_MOE_ARCHS}")
else() else()

View File

@ -237,6 +237,7 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable:
b_q_weight=w_q, b_q_weight=w_q,
b_bias=None, b_bias=None,
b_scales=w_s, b_scales=w_s,
a_scales=None,
global_scale=None, global_scale=None,
b_zeros=w_zp, b_zeros=w_zp,
g_idx=g_idx, g_idx=g_idx,

View File

@ -263,7 +263,7 @@ def bench_run(
results.append( results.append(
benchmark.Timer( benchmark.Timer(
stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)", # noqa: E501 stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, None, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)", # noqa: E501
globals=globals, globals=globals,
label=label, label=label,
sub_label=sub_label, sub_label=sub_label,
@ -273,7 +273,7 @@ def bench_run(
results.append( results.append(
benchmark.Timer( benchmark.Timer(
stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)", # noqa: E501 stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, None, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)", # noqa: E501
globals=globals, globals=globals,
label=label, label=label,
sub_label=sub_label, sub_label=sub_label,

View File

@ -1 +1,2 @@
kernel_*.cu sm*_kernel_*.cu
kernel_selector.h

View File

@ -4,134 +4,282 @@ import glob
import itertools import itertools
import os import os
import subprocess import subprocess
import sys
import jinja2 import jinja2
FILE_HEAD = """ ARCHS = []
// auto generated by generate.py SUPPORT_FP8 = False
// clang-format off for arch in sys.argv[1].split(","):
arch = arch[: arch.index(".") + 2].replace(".", "")
arch = int(arch)
# only SM89 and SM120 fully support
# mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32.
# SM90 and SM100 can use this PTX, but its simulated
# with FP16 MMA, so it cannot achieve any acceleration.
if arch in [89, 120]:
SUPPORT_FP8 = True
FILE_HEAD_COMMENT = """
// auto generated by generate_kernels.py
// clang-format off
""".lstrip()
FILE_HEAD = (
FILE_HEAD_COMMENT
+ """
#include "kernel.h" #include "kernel.h"
#include "marlin_template.h" #include "marlin_template.h"
namespace MARLIN_NAMESPACE_NAME { namespace MARLIN_NAMESPACE_NAME {
""".strip() """
)
TEMPLATE = ( TEMPLATE = (
"template __global__ void Marlin<" "template __global__ void Marlin<"
"{{scalar_t}}, " "{{a_type_id}}, "
"{{w_type_id}}, " "{{b_type_id}}, "
"{{c_type_id}}, "
"{{s_type_id}}, " "{{s_type_id}}, "
"{{threads}}, " "{{threads}}, "
"{{thread_m_blocks}}, " "{{thread_m_blocks}}, "
"{{thread_n_blocks}}, " "{{thread_n_blocks}}, "
"{{thread_k_blocks}}, " "{{thread_k_blocks}}, "
"{{'true' if m_block_size_8 else 'false'}}, " "{{m_block_size_8}}, "
"{{stages}}, " "{{stages}}, "
"{{group_blocks}}, " "{{group_blocks}}, "
"{{'true' if is_zp_float else 'false'}}>" "{{is_zp_float}}>"
"( MARLIN_KERNEL_PARAMS );" "( MARLIN_KERNEL_PARAMS );"
) )
# int8 with zero point case (vllm::kU8) is also supported,
# we don't add it to reduce wheel size.
SCALAR_TYPES = [
"vllm::kU4",
"vllm::kU4B8",
"vllm::kU8B128",
"vllm::kFE4M3fn",
"vllm::kFE2M1f",
]
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)] THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)]
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4] THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
# group_blocks:
# = 0 : act order case QUANT_CONFIGS = [
# = -1 : channelwise quantization # AWQ-INT4
# > 0 : group_size=16*group_blocks {
GROUP_BLOCKS = [0, -1, 1, 2, 4, 8] "b_type": "kU4",
DTYPES = ["fp16", "bf16"] "thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [-1, 2, 4, 8],
},
# GPTQ-INT4
{
"b_type": "kU4B8",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [-1, 0, 2, 4, 8],
},
# AWQ-INT8
{
"b_type": "kU8B128",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [-1, 0, 2, 4, 8],
},
# FP8
{
"b_type": "kFE4M3fn",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [-1, 8],
},
# NVFP4
{
"b_type": "kFE2M1f",
"s_type": "kFE4M3fn",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [1],
},
# MXFP4
{
"a_type": ["kBFloat16"],
"b_type": "kFE2M1f",
"s_type": "kFE8M0fnu",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [2],
},
# AWQ-INT4 with INT8 activation
{
"a_type": ["kS8"],
"b_type": "kU4",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": [1, 2, 3, 4],
"group_blocks": [-1, 2, 4, 8],
},
# GPTQ-INT4 with INT8 activation
{
"a_type": ["kS8"],
"b_type": "kU4B8",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": [1, 2, 3, 4],
"group_blocks": [-1, 2, 4, 8],
},
# GPTQ-INT4 with FP8 activation
{
"a_type": ["kFE4M3fn"],
"b_type": "kU4B8",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": [1, 2, 3, 4],
"group_blocks": [-1, 2, 4, 8],
},
# AWQ-INT4 with FP8 activation
{
"a_type": ["kFE4M3fn"],
"b_type": "kU4",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": [1, 2, 3, 4],
"group_blocks": [-1, 2, 4, 8],
},
# MXFP4 with FP8 activation
{
"a_type": ["kFE4M3fn"],
"b_type": "kFE2M1f",
"c_type": ["kBFloat16"],
"s_type": "kFE8M0fnu",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": [1, 2, 3, 4],
"group_blocks": [2],
},
]
def remove_old_kernels(): def remove_old_kernels():
for filename in glob.glob(os.path.dirname(__file__) + "/kernel_*.cu"): for filename in glob.glob(os.path.dirname(__file__) + "/*kernel_*.cu"):
subprocess.call(["rm", "-f", filename]) subprocess.call(["rm", "-f", filename])
filename = os.path.dirname(__file__) + "/kernel_selector.h"
subprocess.call(["rm", "-f", filename])
def generate_new_kernels(): def generate_new_kernels():
for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES): result_dict = {}
for quant_config in QUANT_CONFIGS:
c_types = quant_config.get("c_type", ["kFloat16", "kBFloat16"])
a_types = quant_config.get("a_type", ["kFloat16", "kBFloat16"])
b_type = quant_config["b_type"]
all_group_blocks = quant_config["group_blocks"]
all_m_blocks = quant_config["thread_m_blocks"]
all_thread_configs = quant_config["thread_configs"]
for a_type, c_type in itertools.product(a_types, c_types):
if not SUPPORT_FP8 and a_type == "kFE4M3fn":
continue
if "16" in a_type and "16" in c_type and a_type != c_type:
continue
s_type = quant_config.get("s_type", c_type)
if (a_type, b_type, c_type) not in result_dict:
result_dict[(a_type, b_type, c_type)] = []
for group_blocks, m_blocks, thread_configs in itertools.product(
all_group_blocks, all_m_blocks, all_thread_configs
):
thread_k, thread_n, threads = thread_configs
if threads == 256:
# for small batch (m_blocks == 1),
# we only need (128, 128, 256)
# for large batch (m_blocks > 1),
# we only need (64, 256, 256)
if m_blocks <= 1 and (thread_k, thread_n) != (128, 128):
continue
if m_blocks > 1 and (thread_k, thread_n) != (64, 256):
continue
config = {
"threads": threads,
"s_type": s_type,
"thread_m_blocks": max(m_blocks, 1),
"thread_k_blocks": thread_k // 16,
"thread_n_blocks": thread_n // 16,
"m_block_size_8": "true" if m_blocks == 0.5 else "false",
"stages": "pipe_stages",
"group_blocks": group_blocks,
"is_zp_float": "false",
}
result_dict[(a_type, b_type, c_type)].append(config)
kernel_selector_str = FILE_HEAD_COMMENT
for (a_type, b_type, c_type), config_list in result_dict.items():
all_template_str_list = [] all_template_str_list = []
for config in config_list:
for group_blocks, m_blocks, thread_configs in itertools.product( s_type = config["s_type"]
GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS
):
# act order case only support gptq-int4 and gptq-int8
if group_blocks == 0 and scalar_type not in [
"vllm::kU4B8",
"vllm::kU8B128",
]:
continue
if thread_configs[2] == 256:
# for small batch (m_blocks == 1), we only need (128, 128, 256)
# for large batch (m_blocks > 1), we only need (64, 256, 256)
if m_blocks <= 1 and thread_configs[0] != 128:
continue
if m_blocks > 1 and thread_configs[0] != 64:
continue
# we only support channelwise quantization and group_size == 128
# for fp8
if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]:
continue
# nvfp4 only supports group_size == 16
# mxfp4 only supports group_size == 32
if scalar_type == "vllm::kFE2M1f" and group_blocks not in [1, 2]:
continue
# other quantization methods don't support group_size = 16
if scalar_type != "vllm::kFE2M1f" and group_blocks == 1:
continue
k_blocks = thread_configs[0] // 16
n_blocks = thread_configs[1] // 16
threads = thread_configs[2]
c_dtype = "half" if dtype == "fp16" else "nv_bfloat16"
if scalar_type == "vllm::kFE2M1f" and group_blocks == 1:
s_type = "vllm::kFE4M3fn"
elif scalar_type == "vllm::kFE2M1f" and group_blocks == 2:
s_type = "vllm::kFE8M0fnu"
if dtype == "fp16":
# we cannot safely dequantize e8m0 to fp16, so skip this
continue
elif dtype == "fp16":
s_type = "vllm::kFloat16"
elif dtype == "bf16":
s_type = "vllm::kBFloat16"
template_str = jinja2.Template(TEMPLATE).render( template_str = jinja2.Template(TEMPLATE).render(
scalar_t=c_dtype, a_type_id=f"vllm::{a_type}.id()",
w_type_id=scalar_type + ".id()", b_type_id=f"vllm::{b_type}.id()",
s_type_id=s_type + ".id()", c_type_id=f"vllm::{c_type}.id()",
threads=threads, s_type_id=f"vllm::{s_type}.id()",
thread_m_blocks=max(m_blocks, 1), **config,
thread_n_blocks=n_blocks, )
thread_k_blocks=k_blocks, all_template_str_list.append(template_str)
m_block_size_8=m_blocks == 0.5,
stages="pipe_stages", conditions = [
group_blocks=group_blocks, f"a_type == vllm::{a_type}",
is_zp_float=False, f"b_type == vllm::{b_type}",
f"c_type == vllm::{c_type}",
f"s_type == vllm::{s_type}",
f"threads == {config['threads']}",
f"thread_m_blocks == {config['thread_m_blocks']}",
f"thread_n_blocks == {config['thread_n_blocks']}",
f"thread_k_blocks == {config['thread_k_blocks']}",
f"m_block_size_8 == {config['m_block_size_8']}",
f"group_blocks == {config['group_blocks']}",
f"is_zp_float == {config['is_zp_float']}",
]
conditions = " && ".join(conditions)
if kernel_selector_str == FILE_HEAD_COMMENT:
kernel_selector_str += f"if ({conditions})\n kernel = "
else:
kernel_selector_str += f"else if ({conditions})\n kernel = "
kernel_template2 = (
"Marlin<{{a_type_id}}, {{b_type_id}}, {{c_type_id}}, "
"{{s_type_id}}, {{threads}}, {{thread_m_blocks}}, "
"{{thread_n_blocks}}, {{thread_k_blocks}}, "
"{{m_block_size_8}}, {{stages}}, {{group_blocks}}, "
"{{is_zp_float}}>;"
) )
all_template_str_list.append(template_str) kernel_selector_str += (
jinja2.Template(kernel_template2).render(
a_type_id=f"vllm::{a_type}.id()",
b_type_id=f"vllm::{b_type}.id()",
c_type_id=f"vllm::{c_type}.id()",
s_type_id=f"vllm::{s_type}.id()",
**config,
)
+ "\n"
)
file_content = FILE_HEAD + "\n\n" file_content = FILE_HEAD + "\n\n"
file_content += "\n\n".join(all_template_str_list) + "\n\n}\n" file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
filename = f"kernel_{dtype}_{scalar_type[6:].lower()}.cu" if a_type == "kFE4M3fn":
filename = f"sm89_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
else:
filename = f"sm80_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
filename = filename.lower()
with open(os.path.join(os.path.dirname(__file__), filename), "w") as f: with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
f.write(file_content) f.write(file_content)
if not SUPPORT_FP8 and kernel_selector_str != FILE_HEAD_COMMENT:
kernel_selector_str += (
"else if (a_type == vllm::kFE4M3fn)\n"
" TORCH_CHECK(false, "
'"marlin kernel with fp8 activation is not built.");'
)
with open(os.path.join(os.path.dirname(__file__), "kernel_selector.h"), "w") as f:
f.write(kernel_selector_str)
if __name__ == "__main__": if __name__ == "__main__":
remove_old_kernels() remove_old_kernels()

View File

@ -11,8 +11,9 @@
const int4 *__restrict__ A, const int4 *__restrict__ B, \ const int4 *__restrict__ A, const int4 *__restrict__ B, \
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \ int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
const int4 *__restrict__ b_bias_ptr, \ const int4 *__restrict__ b_bias_ptr, \
const float *__restrict__ a_scales_ptr, \
const int4 *__restrict__ scales_ptr, \ const int4 *__restrict__ scales_ptr, \
const uint16_t *__restrict__ scale2_ptr, \ const uint16_t *__restrict__ global_scale_ptr, \
const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \ const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
const int32_t *__restrict__ sorted_token_ids_ptr, \ const int32_t *__restrict__ sorted_token_ids_ptr, \
const int32_t *__restrict__ expert_ids_ptr, \ const int32_t *__restrict__ expert_ids_ptr, \
@ -20,12 +21,13 @@
const float *__restrict__ topk_weights_ptr, int top_k, \ const float *__restrict__ topk_weights_ptr, int top_k, \
bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \ bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \
int prob_n, int prob_k, int *locks, bool has_bias, bool use_atomic_add, \ int prob_n, int prob_k, int *locks, bool has_bias, bool use_atomic_add, \
bool use_fp32_reduce, int max_shared_mem bool use_fp32_reduce
namespace MARLIN_NAMESPACE_NAME { namespace MARLIN_NAMESPACE_NAME {
template <typename scalar_t, // compute dtype, half or nv_float16 template <const vllm::ScalarTypeId a_type_id, // A ScalarType id
const vllm::ScalarTypeId w_type_id, // weight ScalarType id const vllm::ScalarTypeId b_type_id, // B ScalarType id
const vllm::ScalarTypeId s_type_id, // weight scale ScalarType id const vllm::ScalarTypeId c_type_id, // C ScalarType id
const vllm::ScalarTypeId s_type_id, // B_SCALE ScalarType id
const int threads, // number of threads in a threadblock const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the // dimension (batchsize) of the

File diff suppressed because it is too large Load Diff

View File

@ -37,39 +37,6 @@ __global__ void MarlinDefault(MARLIN_KERNEL_PARAMS){};
using MarlinFuncPtr = void (*)(MARLIN_KERNEL_PARAMS); using MarlinFuncPtr = void (*)(MARLIN_KERNEL_PARAMS);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
template <int moe_block_size>
__global__ void permute_cols_kernel(
int4 const* __restrict__ a_int4_ptr, int const* __restrict__ perm_int_ptr,
int4* __restrict__ out_int4_ptr,
const int32_t* __restrict__ sorted_token_ids_ptr,
const int32_t* __restrict__ expert_ids_ptr,
const int32_t* __restrict__ num_tokens_past_padded_ptr, int size_m,
int size_k, int top_k) {};
} // namespace marlin
torch::Tensor moe_wna16_marlin_gemm(
torch::Tensor& a, std::optional<torch::Tensor> c_or_none,
torch::Tensor& b_q_weight,
std::optional<torch::Tensor> const& b_bias_or_none, torch::Tensor& b_scales,
std::optional<torch::Tensor> const& b_zeros_or_none,
std::optional<torch::Tensor> const& g_idx_or_none,
std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace,
torch::Tensor& sorted_token_ids, torch::Tensor& expert_ids,
torch::Tensor& num_tokens_past_padded, torch::Tensor& topk_weights,
int64_t moe_block_size, int64_t top_k, bool mul_topk_weights, bool is_ep,
vllm::ScalarTypeId const& b_q_type_id, int64_t size_m, int64_t size_n,
int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce,
bool is_zp_float) {
TORCH_CHECK_NOT_IMPLEMENTED(false,
"marlin_gemm(..) requires CUDA_ARCH >= 8.0");
return torch::empty({1, 1});
}
#else
// For a given "a" of size [M,K] performs a permutation of the K columns based // For a given "a" of size [M,K] performs a permutation of the K columns based
// on the given "perm" indices. // on the given "perm" indices.
template <int moe_block_size> template <int moe_block_size>
@ -207,7 +174,7 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8,
int thread_m_blocks, int prob_m, int prob_n, int thread_m_blocks, int prob_m, int prob_n,
int prob_k, int num_bits, int group_size, int prob_k, int num_bits, int group_size,
bool has_act_order, bool is_k_full, int has_zp, bool has_act_order, bool is_k_full, int has_zp,
int is_zp_float) { int is_zp_float, bool is_a_8bit) {
int pack_factor = 32 / num_bits; int pack_factor = 32 / num_bits;
// Get B size // Get B size
@ -217,8 +184,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8,
// shm size for block_sorted_ids/rd_block_sorted_ids/block_topk_weights // shm size for block_sorted_ids/rd_block_sorted_ids/block_topk_weights
// both of them requires tb_m * 4 bytes (tb_m * int32 or tb_m * float32) // both of them requires tb_m * 4 bytes (tb_m * int32 or tb_m * float32)
int sh_block_meta_size = tb_m * 4; int sh_block_meta_size = tb_m * 16;
int sh_a_size = pipe_stages * (tb_m * tb_k) * 2; int sh_a_size = pipe_stages * (tb_m * tb_k) * (is_a_8bit ? 1 : 2);
int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4; int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4;
int sh_red_size = tb_m * (tb_n + 8) * 2; int sh_red_size = tb_m * (tb_n + 8) * 2;
int sh_bias_size = tb_n * 2; int sh_bias_size = tb_n * 2;
@ -250,7 +217,7 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8,
int thread_m_blocks, int prob_m, int prob_n, int prob_k, int thread_m_blocks, int prob_m, int prob_n, int prob_k,
int num_bits, int group_size, bool has_act_order, int num_bits, int group_size, bool has_act_order,
bool is_k_full, int has_zp, int is_zp_float, bool is_k_full, int has_zp, int is_zp_float,
int max_shared_mem) { int max_shared_mem, bool is_a_8bit) {
// Sanity // Sanity
if (th_config.thread_k == -1 || th_config.thread_n == -1 || if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
th_config.num_threads == -1) { th_config.num_threads == -1) {
@ -273,188 +240,34 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8,
} }
// Check that pipeline fits into cache // Check that pipeline fits into cache
int cache_size = get_kernel_cache_size( int cache_size =
th_config, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k, get_kernel_cache_size(th_config, m_block_size_8, thread_m_blocks, prob_m,
num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float); prob_n, prob_k, num_bits, group_size, has_act_order,
return cache_size + 512 <= max_shared_mem; is_k_full, has_zp, is_zp_float, is_a_8bit);
return cache_size <= max_shared_mem;
} }
#define _GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ MarlinFuncPtr get_marlin_kernel(
M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \ const vllm::ScalarType a_type, const vllm::ScalarType b_type,
else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ const vllm::ScalarType c_type, const vllm::ScalarType s_type,
thread_n_blocks == THREAD_N_BLOCKS && \ int thread_m_blocks, int thread_n_blocks, int thread_k_blocks,
thread_k_blocks == THREAD_K_BLOCKS && \ bool m_block_size_8, bool has_act_order, bool has_zp, int group_blocks,
m_block_size_8 == M_BLOCK_SIZE_8 && \ int threads, bool is_zp_float) {
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \ int num_bits = b_type.size_bits();
is_zp_float == IS_ZP_FLOAT) { \
constexpr auto S_TYPE = \
W_TYPE == vllm::kFE2M1f \
? (GROUP_BLOCKS == 1 ? vllm::kFE4M3fn : vllm::kFE8M0fnu) \
: (std::is_same<scalar_t, half>::value ? vllm::kFloat16 \
: vllm::kBFloat16); \
kernel = Marlin<scalar_t, W_TYPE.id(), S_TYPE.id(), NUM_THREADS, \
THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
M_BLOCK_SIZE_8, pipe_stages, GROUP_BLOCKS, IS_ZP_FLOAT>; \
}
// COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false)
// this is the most common cases
// BIGGROUP: cases for big group size (group_blocks in [-1, 8])
// FZP: cases for float-zero-point (is_zp_float = true)
// ACT: cases for act order case (group_blocks == 0)
// FP4: cases for nvfp4(e2m1) (group_blocks == 1)
#define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
#define COMMON_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
\
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
\
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
#define COMMON_GET_IF(W_TYPE) \
COMMON_GET_IF_M1(W_TYPE, 8, 8, 256) \
COMMON_GET_IF_M1(W_TYPE, 8, 4, 128) \
COMMON_GET_IF_M234(W_TYPE, 16, 4, 256) \
COMMON_GET_IF_M234(W_TYPE, 8, 4, 128)
#define BIGGROUP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
#define BIGGROUP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
#define BIGGROUP_GET_IF(W_TYPE) \
BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \
BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \
BIGGROUP_GET_IF_M234(W_TYPE, 16, 4, 256) \
BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128)
#define NVFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
#define NVFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
#define NVFP4_GET_IF(W_TYPE) \
NVFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \
NVFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \
NVFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
NVFP4_GET_IF_M234(W_TYPE, 8, 4, 128)
#define MXFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false)
#define MXFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false)
#define MXFP4_GET_IF(W_TYPE) \
MXFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \
MXFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \
MXFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
MXFP4_GET_IF_M234(W_TYPE, 8, 4, 128)
// We currently have 4-bit models only with group_blocks == 4
#define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, true) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true)
#define FZP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true)
#define FZP_GET_IF(W_TYPE) \
FZP_GET_IF_M1(W_TYPE, 8, 8, 256) \
FZP_GET_IF_M1(W_TYPE, 8, 4, 128) \
FZP_GET_IF_M234(W_TYPE, 16, 4, 256) \
FZP_GET_IF_M234(W_TYPE, 8, 4, 128)
// We currently have 4-bit models only with group_blocks == 4
#define ACT_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false)
#define ACT_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false)
#define ACT_GET_IF(W_TYPE) \
ACT_GET_IF_M1(W_TYPE, 8, 8, 256) \
ACT_GET_IF_M1(W_TYPE, 8, 4, 128) \
ACT_GET_IF_M234(W_TYPE, 16, 4, 256) \
ACT_GET_IF_M234(W_TYPE, 8, 4, 128)
template <typename scalar_t>
MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type,
int thread_m_blocks, int thread_n_blocks,
int thread_k_blocks, bool m_block_size_8,
bool has_act_order, bool has_zp,
int group_blocks, int num_threads,
bool is_zp_float) {
int num_bits = q_type.size_bits();
auto kernel = MarlinDefault; auto kernel = MarlinDefault;
if (false) {
}
COMMON_GET_IF(vllm::kU4) #include "kernel_selector.h"
COMMON_GET_IF(vllm::kU4B8)
COMMON_GET_IF(vllm::kU8B128)
NVFP4_GET_IF(vllm::kFE2M1f)
BIGGROUP_GET_IF(vllm::kFE4M3fn)
ACT_GET_IF(vllm::kU4B8)
ACT_GET_IF(vllm::kU8B128)
if (std::is_same<scalar_t, nv_bfloat16>::value) {
if (false) {
}
MXFP4_GET_IF(vllm::kFE2M1f)
}
return kernel; return kernel;
} }
template <typename scalar_t> exec_config_t determine_exec_config(
exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m, const vllm::ScalarType& a_type, const vllm::ScalarType& b_type,
int prob_n, int prob_k, int thread_m_blocks, const vllm::ScalarType& c_type, const vllm::ScalarType& s_type, int prob_m,
bool m_block_size_8, int num_bits, int prob_n, int prob_k, int num_experts, int top_k, int thread_m_blocks,
int group_size, bool has_act_order, bool m_block_size_8, int num_bits, int group_size, bool has_act_order,
bool is_k_full, bool has_zp, bool is_k_full, bool has_zp, bool is_zp_float, int max_shared_mem, int sms,
bool is_zp_float, int max_shared_mem) { bool is_a_8bit) {
exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}}; exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}};
thread_config_t* thread_configs = thread_m_blocks > 1 thread_config_t* thread_configs = thread_m_blocks > 1
? large_batch_thread_configs ? large_batch_thread_configs
@ -471,73 +284,69 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
if (!is_valid_config(th_config, m_block_size_8, thread_m_blocks, prob_m, if (!is_valid_config(th_config, m_block_size_8, thread_m_blocks, prob_m,
prob_n, prob_k, num_bits, group_size, has_act_order, prob_n, prob_k, num_bits, group_size, has_act_order,
is_k_full, has_zp, is_zp_float, max_shared_mem)) { is_k_full, has_zp, is_zp_float, max_shared_mem - 512,
is_a_8bit)) {
continue; continue;
} }
int cache_size = get_kernel_cache_size( int cache_size = get_kernel_cache_size(
th_config, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k, th_config, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k,
num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float); num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float,
is_a_8bit);
int group_blocks = 0; int group_blocks = 0;
if (!has_act_order) { if (!has_act_order) {
group_blocks = group_size == -1 ? -1 : (group_size / 16); group_blocks = group_size == -1 ? -1 : (group_size / 16);
} }
auto kernel = get_marlin_kernel<scalar_t>( auto kernel =
q_type, thread_m_blocks, th_config.thread_n / 16, get_marlin_kernel(a_type, b_type, c_type, s_type, thread_m_blocks,
th_config.thread_k / 16, m_block_size_8, has_act_order, has_zp, th_config.thread_n / 16, th_config.thread_k / 16,
group_blocks, th_config.num_threads, is_zp_float); m_block_size_8, has_act_order, has_zp, group_blocks,
th_config.num_threads, is_zp_float);
if (kernel == MarlinDefault) continue; if (kernel == MarlinDefault) continue;
if (thread_m_blocks > 1) { cudaFuncAttributes attr;
exec_cfg = {1, th_config}; cudaFuncGetAttributes(&attr, kernel);
break; int reg_size = max(attr.numRegs, 1) * th_config.num_threads * 4;
} else { int allow_count = min(device_max_reg_size / reg_size,
cudaFuncAttributes attr; max_shared_mem / (cache_size + 1536));
cudaFuncGetAttributes(&attr, kernel); if (thread_m_blocks == 1)
int reg_size = max(attr.numRegs, 1) * th_config.num_threads * 4;
int allow_count = min(device_max_reg_size / reg_size,
max_shared_mem / (cache_size + 1024));
allow_count = max(min(allow_count, 4), 1); allow_count = max(min(allow_count, 4), 1);
if (allow_count > count) { else
count = allow_count; allow_count = max(min(allow_count, 2), 1);
exec_cfg = {count, th_config};
}; if (prob_n / th_config.thread_n * prob_m * top_k * 4 < sms * allow_count) {
allow_count =
max(prob_n / th_config.thread_n * prob_m * top_k * 4 / sms, 1);
} }
if (allow_count > count) {
count = allow_count;
exec_cfg = {count, th_config};
};
} }
return exec_cfg; return exec_cfg;
} }
template <typename scalar_t>
void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
void* s, void* s2, void* zp, void* g_idx, void* perm, void* a_s, void* b_s, void* g_s, void* zp, void* g_idx,
void* a_tmp, void* sorted_token_ids, void* expert_ids, void* perm, void* a_tmp, void* sorted_token_ids,
void* num_tokens_past_padded, void* topk_weights, void* expert_ids, void* num_tokens_past_padded,
int moe_block_size, int top_k, bool mul_topk_weights, bool is_ep, void* topk_weights, int moe_block_size, int num_experts,
int prob_m, int prob_n, int prob_k, void* workspace, int top_k, bool mul_topk_weights, bool is_ep, int prob_m,
vllm::ScalarType const& q_type, bool has_bias, int prob_n, int prob_k, void* workspace,
bool has_act_order, bool is_k_full, bool has_zp, int num_groups, vllm::ScalarType const& a_type, vllm::ScalarType const& b_type,
int group_size, int dev, cudaStream_t stream, int thread_k, vllm::ScalarType const& c_type, vllm::ScalarType const& s_type,
int thread_n, int sms, bool use_atomic_add, bool use_fp32_reduce, bool has_bias, bool has_act_order, bool is_k_full, bool has_zp,
bool is_zp_float) { int num_groups, int group_size, int dev, cudaStream_t stream,
int thread_k, int thread_n, int sms, int blocks_per_sm,
bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) {
int thread_m_blocks = div_ceil(moe_block_size, 16); int thread_m_blocks = div_ceil(moe_block_size, 16);
bool m_block_size_8 = moe_block_size == 8; bool m_block_size_8 = moe_block_size == 8;
bool is_a_8bit = a_type.size_bits() == 8;
if (has_zp) {
TORCH_CHECK(
q_type == vllm::kU4 || q_type == vllm::kU8,
"q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str());
} else {
TORCH_CHECK(
q_type == vllm::kU4B8 || q_type == vllm::kU8B128 ||
q_type == vllm::kFE4M3fn || q_type == vllm::kFE2M1f,
"q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when "
"has_zp = False. Got = ",
q_type.str());
}
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
", ", prob_n, ", ", prob_k, "]"); ", ", prob_n, ", ", prob_k, "]");
@ -563,14 +372,15 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
} }
} }
int num_bits = q_type.size_bits(); int num_bits = b_type.size_bits();
const int4* A_ptr = (const int4*)A; const int4* A_ptr = (const int4*)A;
const int4* B_ptr = (const int4*)B; const int4* B_ptr = (const int4*)B;
int4* C_ptr = (int4*)C; int4* C_ptr = (int4*)C;
int4* C_tmp_ptr = (int4*)C_tmp; int4* C_tmp_ptr = (int4*)C_tmp;
const int4* bias_ptr = (const int4*)b_bias; const int4* bias_ptr = (const int4*)b_bias;
const int4* s_ptr = (const int4*)s; const float* a_s_ptr = (const float*)a_s;
const uint16_t* s2_ptr = (const uint16_t*)s2; const int4* b_s_ptr = (const int4*)b_s;
const uint16_t* g_s_ptr = (const uint16_t*)g_s;
const int4* zp_ptr = (const int4*)zp; const int4* zp_ptr = (const int4*)zp;
const int* g_idx_ptr = (const int*)g_idx; const int* g_idx_ptr = (const int*)g_idx;
const int* perm_ptr = (const int*)perm; const int* perm_ptr = (const int*)perm;
@ -618,22 +428,41 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
TORCH_CHECK(max_shared_mem > 0); TORCH_CHECK(max_shared_mem > 0);
int major_capability, minor_capability;
cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,
dev);
cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
dev);
TORCH_CHECK(major_capability * 10 + minor_capability >= 80,
"marlin kernel only support Ampere or newer GPUs.");
if (a_type == vllm::kFE4M3fn) {
TORCH_CHECK(major_capability * 10 + minor_capability >= 89,
"FP8 only support Ada Lovelace or newer GPUs.");
TORCH_CHECK(
major_capability * 10 + minor_capability == 89 ||
major_capability * 10 + minor_capability == 120,
"Marlin W4A8-FP8 only support SM89 or SM120 device (It is slower than "
"Marlin W4A16 on other devices).");
}
// Set thread config // Set thread config
exec_config_t exec_cfg; exec_config_t exec_cfg;
thread_config_t thread_tfg; thread_config_t thread_tfg;
if (thread_k != -1 && thread_n != -1) { if (thread_k != -1 && thread_n != -1) {
thread_tfg = thread_config_t{thread_k, thread_n, default_threads}; thread_tfg = thread_config_t{thread_k, thread_n, thread_k * thread_n / 64};
exec_cfg = exec_config_t{1, thread_tfg}; if (blocks_per_sm == -1) blocks_per_sm = 1;
exec_cfg = exec_config_t{blocks_per_sm, thread_tfg};
TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n, TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n,
" is not divisible by thread_n = ", thread_n); " is not divisible by thread_n = ", thread_n);
TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k,
" is not divisible by thread_k = ", thread_k); " is not divisible by thread_k = ", thread_k);
} else { } else {
// Auto config // Auto config
exec_cfg = determine_exec_config<scalar_t>( exec_cfg = determine_exec_config(
q_type, prob_m, prob_n, prob_k, thread_m_blocks, m_block_size_8, a_type, b_type, c_type, s_type, prob_m, prob_n, prob_k, num_experts,
num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float, top_k, thread_m_blocks, m_block_size_8, num_bits, group_size,
max_shared_mem); has_act_order, is_k_full, has_zp, is_zp_float, max_shared_mem, sms,
is_a_8bit);
thread_tfg = exec_cfg.tb_cfg; thread_tfg = exec_cfg.tb_cfg;
} }
@ -647,22 +476,29 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
int thread_k_blocks = thread_k / 16; int thread_k_blocks = thread_k / 16;
int thread_n_blocks = thread_n / 16; int thread_n_blocks = thread_n / 16;
TORCH_CHECK( TORCH_CHECK(is_valid_config(thread_tfg, m_block_size_8, thread_m_blocks,
is_valid_config(thread_tfg, m_block_size_8, thread_m_blocks, prob_m, prob_m, prob_n, prob_k, num_bits, group_size,
prob_n, prob_k, num_bits, group_size, has_act_order, has_act_order, is_k_full, has_zp, is_zp_float,
is_k_full, has_zp, is_zp_float, max_shared_mem), max_shared_mem, is_a_8bit),
"Invalid thread config: thread_m_blocks = ", thread_m_blocks, "Invalid thread config: thread_m_blocks = ", thread_m_blocks,
", thread_k = ", thread_tfg.thread_k, ", thread_k = ", thread_tfg.thread_k,
", thread_n = ", thread_tfg.thread_n, ", thread_n = ", thread_tfg.thread_n,
", num_threads = ", thread_tfg.num_threads, " for MKN = [", prob_m, ", ", ", num_threads = ", thread_tfg.num_threads, " for MKN = [",
prob_k, ", ", prob_n, "] and num_bits = ", num_bits, prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits,
", group_size = ", group_size, ", has_act_order = ", has_act_order, ", group_size = ", group_size,
", is_k_full = ", is_k_full, ", has_zp = ", has_zp, ", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full,
", is_zp_float = ", is_zp_float, ", max_shared_mem = ", max_shared_mem); ", has_zp = ", has_zp, ", is_zp_float = ", is_zp_float,
", max_shared_mem = ", max_shared_mem);
auto kernel = get_marlin_kernel<scalar_t>( int sh_cache_size =
q_type, thread_m_blocks, thread_n_blocks, thread_k_blocks, m_block_size_8, get_kernel_cache_size(thread_tfg, m_block_size_8, thread_m_blocks, prob_m,
has_act_order, has_zp, group_blocks, num_threads, is_zp_float); prob_n, prob_k, num_bits, group_size, has_act_order,
is_k_full, has_zp, is_zp_float, is_a_8bit);
auto kernel = get_marlin_kernel(
a_type, b_type, c_type, s_type, thread_m_blocks, thread_n_blocks,
thread_k_blocks, m_block_size_8, has_act_order, has_zp, group_blocks,
num_threads, is_zp_float);
if (kernel == MarlinDefault) { if (kernel == MarlinDefault) {
TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n, TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n,
@ -679,19 +515,20 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
// avoid ">>>" being formatted to "> > >" // avoid ">>>" being formatted to "> > >"
// clang-format off // clang-format off
kernel<<<blocks, num_threads, max_shared_mem, stream>>>( kernel<<<blocks, num_threads, max_shared_mem, stream>>>(
A_ptr, B_ptr, C_ptr, C_tmp_ptr, bias_ptr, s_ptr, s2_ptr, zp_ptr, g_idx_ptr, A_ptr, B_ptr, C_ptr, C_tmp_ptr, bias_ptr, a_s_ptr, b_s_ptr, g_s_ptr, zp_ptr, g_idx_ptr,
sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr, sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr,
topk_weights_ptr, top_k, mul_topk_weights, is_ep, num_groups, prob_m, topk_weights_ptr, top_k, mul_topk_weights, is_ep, num_groups, prob_m,
prob_n, prob_k, locks, has_bias, use_atomic_add, use_fp32_reduce, max_shared_mem); prob_n, prob_k, locks, has_bias, use_atomic_add, use_fp32_reduce);
// clang-format on // clang-format on
} }
} // namespace MARLIN_NAMESPACE_NAME } // namespace MARLIN_NAMESPACE_NAME
torch::Tensor moe_wna16_marlin_gemm( torch::Tensor moe_wna16_marlin_gemm(
torch::Tensor& a, std::optional<torch::Tensor> const& c_or_none, torch::Tensor& a, std::optional<torch::Tensor> c_or_none,
torch::Tensor& b_q_weight, torch::Tensor& b_q_weight,
std::optional<torch::Tensor> const& b_bias_or_none, torch::Tensor& b_scales, std::optional<torch::Tensor> const& b_bias_or_none, torch::Tensor& b_scales,
std::optional<torch::Tensor> const& a_scales_or_none,
std::optional<torch::Tensor> const& global_scale_or_none, std::optional<torch::Tensor> const& global_scale_or_none,
std::optional<torch::Tensor> const& b_zeros_or_none, std::optional<torch::Tensor> const& b_zeros_or_none,
std::optional<torch::Tensor> const& g_idx_or_none, std::optional<torch::Tensor> const& g_idx_or_none,
@ -699,11 +536,70 @@ torch::Tensor moe_wna16_marlin_gemm(
torch::Tensor& sorted_token_ids, torch::Tensor& expert_ids, torch::Tensor& sorted_token_ids, torch::Tensor& expert_ids,
torch::Tensor& num_tokens_past_padded, torch::Tensor& topk_weights, torch::Tensor& num_tokens_past_padded, torch::Tensor& topk_weights,
int64_t moe_block_size, int64_t top_k, bool mul_topk_weights, bool is_ep, int64_t moe_block_size, int64_t top_k, bool mul_topk_weights, bool is_ep,
vllm::ScalarTypeId const& b_q_type_id, int64_t size_m, int64_t size_n, vllm::ScalarTypeId const& b_type_id, int64_t size_m, int64_t size_n,
int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce, int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce,
bool is_zp_float) { bool is_zp_float, int64_t thread_k, int64_t thread_n,
vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id); int64_t blocks_per_sm) {
int pack_factor = 32 / b_q_type.size_bits(); vllm::ScalarTypeId a_type_id, c_type_id, s_type_id;
auto c_dtype = a.dtype();
if (a.scalar_type() == at::ScalarType::Half) {
a_type_id = vllm::kFloat16.id();
c_type_id = vllm::kFloat16.id();
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
a_type_id = vllm::kBFloat16.id();
c_type_id = vllm::kBFloat16.id();
} else {
c_dtype = b_scales.dtype();
if (b_scales.scalar_type() == at::ScalarType::Half) {
c_type_id = vllm::kFloat16.id();
} else if (b_scales.scalar_type() == at::ScalarType::BFloat16) {
c_type_id = vllm::kBFloat16.id();
} else {
c_type_id = vllm::kBFloat16.id();
TORCH_CHECK(c_or_none.has_value(), "c must be passed for W4A8-FP4");
torch::Tensor c = c_or_none.value();
c_dtype = c.dtype();
if (c.scalar_type() == at::ScalarType::Half) {
c_type_id = vllm::kFloat16.id();
} else if (c.scalar_type() == at::ScalarType::BFloat16) {
c_type_id = vllm::kBFloat16.id();
} else {
TORCH_CHECK(false, "unsupported c dtype");
}
}
if (a.scalar_type() == at::ScalarType::Float8_e4m3fn) {
a_type_id = vllm::kFE4M3fn.id();
} else if (a.scalar_type() == at::ScalarType::Char) {
a_type_id = vllm::kS8.id();
} else {
TORCH_CHECK(false, "unsupported `a` scalar_type");
}
}
s_type_id = c_type_id;
if (b_type_id == vllm::kFE2M1f.id()) {
if (b_scales.scalar_type() == at::ScalarType::Float8_e4m3fn) {
s_type_id = vllm::kFE4M3fn.id();
} else if (b_scales.scalar_type() == at::ScalarType::Float8_e8m0fnu) {
s_type_id = vllm::kFE8M0fnu.id();
} else {
TORCH_CHECK(false,
"When b_type = float4_e2m1f, b_scale scalar type must be",
"float8_e4m3fn (for NVFP4) or float8_e8m0fnu (for MXFP4).");
}
}
vllm::ScalarType a_type = vllm::ScalarType::from_id(a_type_id);
vllm::ScalarType b_type = vllm::ScalarType::from_id(b_type_id);
vllm::ScalarType c_type = vllm::ScalarType::from_id(c_type_id);
vllm::ScalarType s_type = vllm::ScalarType::from_id(s_type_id);
int pack_factor = 32 / b_type.size_bits();
int num_experts = b_q_weight.size(0);
if (moe_block_size != 8) { if (moe_block_size != 8) {
TORCH_CHECK(moe_block_size % 16 == 0, TORCH_CHECK(moe_block_size % 16 == 0,
@ -745,19 +641,27 @@ torch::Tensor moe_wna16_marlin_gemm(
TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU");
TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");
// thread_k: `k` size of a thread_tile in `weights` (can usually be left as torch::Tensor a_scales;
// auto -1) auto options = torch::TensorOptions().dtype(c_dtype).device(a.device());
int thread_k = -1; auto options_fp32 =
// thread_n: `n` size of a thread_tile in `weights` (can usually be left as torch::TensorOptions().dtype(at::kFloat).device(a.device());
// auto -1)
int thread_n = -1; if (a_scales_or_none.has_value()) {
a_scales = a_scales_or_none.value();
TORCH_CHECK(a_type.size_bits() == 8,
"a_scales can only be used for 8bit activation.");
} else {
a_scales = torch::empty({0}, options_fp32);
TORCH_CHECK(a_type.size_bits() != 8,
"the a_scales parameter must be passed for 8bit activation.");
}
// sms: number of SMs to use for the kernel // sms: number of SMs to use for the kernel
int sms = -1; int sms = -1;
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device()); cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device());
// Alloc buffers // Alloc buffers
const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
torch::Tensor c; torch::Tensor c;
if (c_or_none.has_value()) { if (c_or_none.has_value()) {
c = c_or_none.value(); c = c_or_none.value();
@ -774,8 +678,6 @@ torch::Tensor moe_wna16_marlin_gemm(
// Alloc C tmp buffer that is going to be used for the global reduce // Alloc C tmp buffer that is going to be used for the global reduce
torch::Tensor c_tmp; torch::Tensor c_tmp;
auto options_fp32 =
torch::TensorOptions().dtype(at::kFloat).device(a.device());
if (use_fp32_reduce && !use_atomic_add) { if (use_fp32_reduce && !use_atomic_add) {
// max num of threadblocks is sms * 4 // max num of threadblocks is sms * 4
long max_c_tmp_size = min( long max_c_tmp_size = min(
@ -846,11 +748,11 @@ torch::Tensor moe_wna16_marlin_gemm(
torch::Tensor global_scale; torch::Tensor global_scale;
if (global_scale_or_none.has_value()) { if (global_scale_or_none.has_value()) {
global_scale = global_scale_or_none.value(); global_scale = global_scale_or_none.value();
TORCH_CHECK(b_q_type == vllm::kFE2M1f && group_size == 16, TORCH_CHECK(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn,
"global_scale can only be used for nvfp4 format."); "global_scale can only be used for nvfp4 format.");
} else { } else {
global_scale = torch::empty({0}, options); global_scale = torch::empty({0}, options);
TORCH_CHECK(!(b_q_type == vllm::kFE2M1f && group_size == 16), TORCH_CHECK(!(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn),
"the global_scale parameter must be passed for nvfp4 format."); "the global_scale parameter must be passed for nvfp4 format.");
} }
@ -877,15 +779,15 @@ torch::Tensor moe_wna16_marlin_gemm(
bool has_zp = b_zeros.size(-1) > 0; bool has_zp = b_zeros.size(-1) > 0;
if (has_zp) { if (has_zp) {
TORCH_CHECK( TORCH_CHECK(
b_q_type == vllm::kU4 || b_q_type == vllm::kU8, b_type == vllm::kU4 || b_type == vllm::kU8,
"b_q_type must be u4 or u8 when has_zp = True. Got = ", b_q_type.str()); "b_type must be u4 or u8 when has_zp = True. Got = ", b_type.str());
} else { } else {
TORCH_CHECK(b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128 || TORCH_CHECK(b_type == vllm::kU4B8 || b_type == vllm::kU8B128 ||
b_q_type == vllm::kFE4M3fn || b_q_type == vllm::kFE2M1f, b_type == vllm::kS4 || b_type == vllm::kS8 ||
"b_q_type must be uint4b8, uint8b128, float8_e4m3fn or " b_type == vllm::kFE4M3fn || b_type == vllm::kFE2M1f,
"float4_e2m1f when " "b_type must be uint4b8, uint8b128, int4, int8, "
"has_zp = False. Got = ", "float8_e4m3fn or float4_e2m1f when has_zp = False. Got = ",
b_q_type.str()); b_type.str());
} }
if (has_zp && is_zp_float) { if (has_zp && is_zp_float) {
@ -929,71 +831,33 @@ torch::Tensor moe_wna16_marlin_gemm(
" is below min_workspace_size = ", min_workspace_size); " is below min_workspace_size = ", min_workspace_size);
int dev = a.get_device(); int dev = a.get_device();
if (a.scalar_type() == at::ScalarType::Half) {
void* scales_ptr;
if (b_q_type == vllm::kFE2M1f) {
if (group_size == 16)
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
else if (group_size == 32)
scales_ptr = b_scales.data_ptr<at::Float8_e8m0fnu>();
else
TORCH_CHECK(false,
"float4_e2m1f only supports group_size == 16 (NVFP4) ",
"and group_size == 32 (MXFP4)");
} else {
scales_ptr = b_scales.data_ptr<at::Half>();
}
MARLIN_NAMESPACE_NAME::marlin_mm<half>( TORCH_CHECK(a_scales.scalar_type() == at::ScalarType::Float,
a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(), "scalar type of a_scales must be float");
c_tmp.data_ptr<float>(), b_bias.data_ptr<at::Half>(), scales_ptr, TORCH_CHECK(global_scale.scalar_type() == c.scalar_type(),
global_scale.data_ptr<at::Half>(), b_zeros.data_ptr(), g_idx.data_ptr(), "scalar type of global_scale must be the same with c");
perm.data_ptr(), a_tmp.data_ptr<at::Half>(), if (a_type.size_bits() == 16) {
sorted_token_ids.data_ptr(), expert_ids.data_ptr(), TORCH_CHECK(
num_tokens_past_padded.data_ptr(), topk_weights.data_ptr(), a.scalar_type() == c.scalar_type(),
moe_block_size, top_k, mul_topk_weights, is_ep, size_m, size_n, size_k, "scalar type of a must be the same with c for 16 bit activation");
workspace.data_ptr(), b_q_type, has_bias, has_act_order, is_k_full,
has_zp, num_groups, group_size, dev,
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
use_atomic_add, use_fp32_reduce, is_zp_float);
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
void* scales_ptr;
if (b_q_type == vllm::kFE2M1f) {
if (group_size == 16)
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
else if (group_size == 32)
scales_ptr = b_scales.data_ptr<at::Float8_e8m0fnu>();
else
TORCH_CHECK(false,
"float4_e2m1f only supports group_size == 16 (NVFP4) ",
"and group_size == 32 (MXFP4)");
} else {
scales_ptr = b_scales.data_ptr<at::BFloat16>();
}
MARLIN_NAMESPACE_NAME::marlin_mm<nv_bfloat16>(
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(),
b_bias.data_ptr<at::BFloat16>(), scales_ptr,
global_scale.data_ptr<at::BFloat16>(), b_zeros.data_ptr(),
g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(),
sorted_token_ids.data_ptr(), expert_ids.data_ptr(),
num_tokens_past_padded.data_ptr(), topk_weights.data_ptr(),
moe_block_size, top_k, mul_topk_weights, is_ep, size_m, size_n, size_k,
workspace.data_ptr(), b_q_type, has_bias, has_act_order, is_k_full,
has_zp, num_groups, group_size, dev,
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
use_atomic_add, use_fp32_reduce, is_zp_float);
} else {
TORCH_CHECK(false,
"moe_wna16_marlin_gemm only supports bfloat16 and float16");
} }
MARLIN_NAMESPACE_NAME::marlin_mm(
a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), c_tmp.data_ptr(),
b_bias.data_ptr(), a_scales.data_ptr(), b_scales.data_ptr(),
global_scale.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(),
perm.data_ptr(), a_tmp.data_ptr(), sorted_token_ids.data_ptr(),
expert_ids.data_ptr(), num_tokens_past_padded.data_ptr(),
topk_weights.data_ptr(), moe_block_size, num_experts, top_k,
mul_topk_weights, is_ep, size_m, size_n, size_k, workspace.data_ptr(),
a_type, b_type, c_type, s_type, has_bias, has_act_order, is_k_full,
has_zp, num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
thread_k, thread_n, sms, blocks_per_sm, use_atomic_add, use_fp32_reduce,
is_zp_float);
return c; return c;
} }
#endif
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("moe_wna16_marlin_gemm", &moe_wna16_marlin_gemm); m.impl("moe_wna16_marlin_gemm", &moe_wna16_marlin_gemm);
} }

View File

@ -63,16 +63,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
m.def( m.def(
"moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none," "moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none,"
"Tensor! b_q_weight, Tensor? b_bias_or_none," "Tensor! b_q_weight, Tensor? b_bias_or_none,"
"Tensor! b_scales, Tensor? global_scale, Tensor? " "Tensor! b_scales, Tensor? a_scales, Tensor? global_scale, Tensor? "
"b_zeros_or_none," "b_zeros_or_none,"
"Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace," "Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace,"
"Tensor sorted_token_ids," "Tensor sorted_token_ids,"
"Tensor! expert_ids, Tensor! num_tokens_past_padded," "Tensor! expert_ids, Tensor! num_tokens_past_padded,"
"Tensor! topk_weights, int moe_block_size, int top_k, " "Tensor! topk_weights, int moe_block_size, int top_k, "
"bool mul_topk_weights, bool is_ep, int b_q_type_id," "bool mul_topk_weights, bool is_ep, int b_type_id,"
"int size_m, int size_n, int size_k," "int size_m, int size_n, int size_k,"
"bool is_full_k, bool use_atomic_add," "bool is_full_k, bool use_atomic_add,"
"bool use_fp32_reduce, bool is_zp_float) -> Tensor"); "bool use_fp32_reduce, bool is_zp_float,"
"int thread_k, int thread_n, int blocks_per_sm) -> Tensor");
m.def( m.def(
"marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, " "marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
"Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! " "Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "

View File

@ -437,10 +437,10 @@ struct ComputeTile_W8A16_PerC_MtilexNtilex32_multistage_SM8x_SplitK {
for (int n_idx = 0; n_idx < WARP_NITER; ++n_idx) { for (int n_idx = 0; n_idx < WARP_NITER; ++n_idx) {
#pragma unroll #pragma unroll
for (int k_idx = 0; k_idx < 2; ++k_idx) { for (int k_idx = 0; k_idx < 2; ++k_idx) {
FType low16 = FType low16 = MarlinScalarType2<FType>::float2num(
ScalarType<FType>::float2num(C_frag[m_idx][n_idx][k_idx * 2]); C_frag[m_idx][n_idx][k_idx * 2]);
FType high16 = FType high16 = MarlinScalarType2<FType>::float2num(
ScalarType<FType>::float2num(C_frag[m_idx][n_idx][k_idx * 2 + 1]); C_frag[m_idx][n_idx][k_idx * 2 + 1]);
uint32_t tmp = (reinterpret_cast<uint32_t&>(low16) & 0xffff) | uint32_t tmp = (reinterpret_cast<uint32_t&>(low16) & 0xffff) |
(reinterpret_cast<uint32_t&>(high16) << 16); (reinterpret_cast<uint32_t&>(high16) << 16);
int sts_offset = int sts_offset =

View File

@ -8,7 +8,7 @@
#include <cuda_bf16.h> #include <cuda_bf16.h>
#include <iostream> #include <iostream>
#include "../gptq_marlin/marlin_dtypes.cuh" #include "../gptq_marlin/marlin_dtypes.cuh"
using marlin::ScalarType; using marlin::MarlinScalarType2;
namespace allspark { namespace allspark {
@ -72,10 +72,10 @@ __global__ void f16_gemm_splitk_reduce_kernel(const FType* C_split, FType* C,
int n_mat = N_MATRIX > 0 ? N_MATRIX : (int)n_matrix; int n_mat = N_MATRIX > 0 ? N_MATRIX : (int)n_matrix;
for (int i = 0; i < n_mat; ++i) { for (int i = 0; i < n_mat; ++i) {
sum += ScalarType<FType>::num2float(C_split[idx + i * matrix_size]); sum += MarlinScalarType2<FType>::num2float(C_split[idx + i * matrix_size]);
} }
C[idx] = ScalarType<FType>::float2num(sum); C[idx] = MarlinScalarType2<FType>::float2num(sum);
} }
template <typename FType> template <typename FType>

View File

@ -1 +1,2 @@
kernel_*.cu sm*_kernel_*.cu
kernel_selector.h

View File

@ -4,14 +4,16 @@
namespace marlin { namespace marlin {
template <int const num_threads, int const num_bits> template <int const num_threads, int const num_bits, bool is_a_8bit>
__global__ void awq_marlin_repack_kernel( __global__ void awq_marlin_repack_kernel(
uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr, uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr,
int size_k, int size_n) { int size_k, int size_n) {
constexpr int pack_factor = 32 / num_bits; constexpr int pack_factor = 32 / num_bits;
int k_tiles = size_k / tile_k_size; constexpr int target_tile_n_size = tile_n_size / (is_a_8bit ? 2 : 1);
int n_tiles = size_n / tile_n_size; constexpr int target_tile_k_size = tile_k_size * (is_a_8bit ? 2 : 1);
int k_tiles = size_k / target_tile_k_size;
int n_tiles = size_n / target_tile_n_size;
int block_k_tiles = div_ceil(k_tiles, gridDim.x); int block_k_tiles = div_ceil(k_tiles, gridDim.x);
auto start_k_tile = blockIdx.x * block_k_tiles; auto start_k_tile = blockIdx.x * block_k_tiles;
@ -33,10 +35,10 @@ __global__ void awq_marlin_repack_kernel(
extern __shared__ int4 sh[]; extern __shared__ int4 sh[];
constexpr int tile_n_ints = tile_n_size / pack_factor; constexpr int tile_n_ints = target_tile_n_size / pack_factor;
constexpr int stage_n_threads = tile_n_ints / 4; constexpr int stage_n_threads = tile_n_ints / 4;
constexpr int stage_k_threads = tile_k_size; constexpr int stage_k_threads = target_tile_k_size;
constexpr int stage_size = stage_k_threads * stage_n_threads; constexpr int stage_size = stage_k_threads * stage_n_threads;
auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) { auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) {
@ -45,7 +47,7 @@ __global__ void awq_marlin_repack_kernel(
return; return;
} }
int first_n = n_tile_id * tile_n_size; int first_n = n_tile_id * target_tile_n_size;
int first_n_packed = first_n / pack_factor; int first_n_packed = first_n / pack_factor;
int4* sh_ptr = sh + stage_size * pipe; int4* sh_ptr = sh + stage_size * pipe;
@ -54,7 +56,7 @@ __global__ void awq_marlin_repack_kernel(
auto k_id = threadIdx.x / stage_n_threads; auto k_id = threadIdx.x / stage_n_threads;
auto n_id = threadIdx.x % stage_n_threads; auto n_id = threadIdx.x % stage_n_threads;
int first_k = k_tile_id * tile_k_size; int first_k = k_tile_id * target_tile_k_size;
cp_async4(&sh_ptr[k_id * stage_n_threads + n_id], cp_async4(&sh_ptr[k_id * stage_n_threads + n_id],
reinterpret_cast<int4 const*>( reinterpret_cast<int4 const*>(
@ -78,11 +80,11 @@ __global__ void awq_marlin_repack_kernel(
} }
int tc_col = th_id / 4; int tc_col = th_id / 4;
int tc_row = (th_id % 4) * 2; int tc_row = (th_id % 4) * (is_a_8bit ? 4 : 2);
constexpr int tc_offsets[4] = {0, 1, 8, 9}; constexpr int tc_offsets[4] = {0, 1, 8, 9};
int cur_n = warp_id * 16 + tc_col; int cur_n = (warp_id / (is_a_8bit ? 2 : 1)) * 16 + tc_col;
int cur_n_packed = cur_n / pack_factor; int cur_n_packed = cur_n / pack_factor;
int cur_n_pos = cur_n % pack_factor; int cur_n_pos = cur_n % pack_factor;
@ -105,23 +107,50 @@ __global__ void awq_marlin_repack_kernel(
uint32_t vals[8]; uint32_t vals[8];
#pragma unroll #pragma unroll
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
int cur_elem = tc_row + tc_offsets[i]; if constexpr (is_a_8bit) {
int cur_elem = tc_row + i;
int packed_src_0 = sh_stage_int_ptr[cur_n_packed + sh_stride * cur_elem]; int packed_src_0 =
int packed_src_1 = sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) + sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) * (warp_id % 2) +
sh_stride * cur_elem]; sh_stride * cur_elem];
int packed_src_1 =
sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) * (warp_id % 2) +
sh_stride * (cur_elem + 16)];
vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask; vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask;
vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask; vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask;
} else {
int cur_elem = tc_row + tc_offsets[i];
int packed_src_0 =
sh_stage_int_ptr[cur_n_packed + sh_stride * cur_elem];
int packed_src_1 = sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) +
sh_stride * cur_elem];
vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask;
vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask;
}
} }
constexpr int tile_size = tile_k_size * tile_n_size / pack_factor; constexpr int tile_size =
target_tile_k_size * target_tile_n_size / pack_factor;
int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size; int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;
// Result of: // Result of:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
if constexpr (num_bits == 4) { if constexpr (!is_a_8bit && num_bits == 4) {
constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
uint32_t res = 0;
#pragma unroll
for (int i = 0; i < 8; i++) {
res |= vals[pack_idx[i]] << (i * 4);
}
out_ptr[out_offset + th_id * 4 + warp_id] = res;
} else if constexpr (is_a_8bit && num_bits == 4) {
int pack_idx[8] = {0, 4, 1, 5, 2, 6, 3, 7};
uint32_t res = 0; uint32_t res = 0;
#pragma unroll #pragma unroll
@ -138,8 +167,9 @@ __global__ void awq_marlin_repack_kernel(
uint32_t res2 = 0; uint32_t res2 = 0;
#pragma unroll #pragma unroll
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
res1 |= vals[pack_idx[i]] << (i * 8); const int ii = is_a_8bit ? i : pack_idx[i];
res2 |= vals[4 + pack_idx[i]] << (i * 8); res1 |= vals[ii] << (i * 8);
res2 |= vals[4 + ii] << (i * 8);
} }
out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1; out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1;
@ -176,18 +206,21 @@ __global__ void awq_marlin_repack_kernel(
} // namespace marlin } // namespace marlin
#define CALL_IF(NUM_BITS) \ #define CALL_IF(NUM_BITS, IS_A_8BIT) \
else if (num_bits == NUM_BITS) { \ else if (num_bits == NUM_BITS && is_a_8bit == IS_A_8BIT) { \
cudaFuncSetAttribute( \ cudaFuncSetAttribute( \
marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS>, \ marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ IS_A_8BIT>, \
marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS> \ cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
<<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \ marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
b_q_weight_ptr, out_ptr, size_k, size_n); \ IS_A_8BIT> \
<<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \
b_q_weight_ptr, out_ptr, size_k, size_n); \
} }
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
int64_t size_n, int64_t num_bits) { int64_t size_n, int64_t num_bits,
bool is_a_8bit) {
// Verify compatibility with marlin tile of 16x64 // Verify compatibility with marlin tile of 16x64
TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k, TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k,
" is not divisible by tile_k_size = ", marlin::tile_k_size); " is not divisible by tile_k_size = ", marlin::tile_k_size);
@ -238,10 +271,13 @@ torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
if (false) { if (false) {
} }
CALL_IF(4) CALL_IF(4, false)
CALL_IF(8) CALL_IF(8, false)
CALL_IF(4, true)
CALL_IF(8, true)
else { else {
TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits); TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits,
", is_a_8bit = ", is_a_8bit);
} }
return out; return out;

View File

@ -470,6 +470,50 @@ __device__ inline void dequant<nv_bfloat162, vllm::kFE2M1f.id(), false>(
frag_b[0] = __hmul2(frag_b[0], bias_reg); frag_b[0] = __hmul2(frag_b[0], bias_reg);
} }
template <>
__device__ inline void dequant<__nv_fp8x4_e4m3, vllm::kFE2M1f.id(), true>(
int q, __nv_fp8x4_e4m3* frag_b) {
// Constants for FP4 (E2M1) and FP16 formats
constexpr int FP4_EXPONENT = 2, FP8_EXPONENT = 4;
constexpr int RIGHT_SHIFT = FP8_EXPONENT - FP4_EXPONENT;
constexpr int MASK = 0x70707070;
// Extract and shift FP4 values to FP16 format
int Out1 = (q & 0x80808080) | ((q & MASK) >> RIGHT_SHIFT);
q <<= 4;
int Out2 = (q & 0x80808080) | ((q & MASK) >> RIGHT_SHIFT);
// Note1: reverse indexing is intentional because weights are permuted
// Note2: when dequant to 8bit type, we write to `frag_b[2]` instead of
// `frag_b[1]` to fit the layout of tensorcore
frag_b[1] = *reinterpret_cast<const __nv_fp8x4_e4m3*>(&Out1);
frag_b[0] = *reinterpret_cast<const __nv_fp8x4_e4m3*>(&Out2);
}
template <>
__device__ inline void dequant<int32_t, vllm::kU4B8.id(), true>(
int q, int32_t* frag_b) {
constexpr int repeated_zp = 0x08080808;
constexpr int MASK = 0x80808080;
frag_b[0] = ((q & 0x0F0F0F0F | MASK) - repeated_zp) ^ MASK;
q >>= 4;
frag_b[1] = ((q & 0x0F0F0F0F | MASK) - repeated_zp) ^ MASK;
}
template <>
__device__ inline void dequant<__nv_fp8x4_e4m3, vllm::kU4B8.id(), true>(
int q, __nv_fp8x4_e4m3* frag_b) {
int s = q & 0x08080808;
int Out1 = ((q & 0x07070707) | (s << 4)) + (s >> 3);
q >>= 4;
s = q & 0x08080808;
int Out2 = ((q & 0x07070707) | (s << 4)) + (s >> 3);
frag_b[0] = *reinterpret_cast<const __nv_fp8x4_e4m3*>(&Out1);
frag_b[1] = *reinterpret_cast<const __nv_fp8x4_e4m3*>(&Out2);
}
template <typename scalar_t2, vllm::ScalarTypeId s_type_id> template <typename scalar_t2, vllm::ScalarTypeId s_type_id>
__device__ inline void dequant_fp8_scales(int q, scalar_t2* frag_b); __device__ inline void dequant_fp8_scales(int q, scalar_t2* frag_b);
@ -515,6 +559,49 @@ __device__ inline void dequant_fp8_scales<nv_bfloat162, vllm::kFE8M0fnu.id()>(
// Note: reverse indexing is intentional because weights are permuted // Note: reverse indexing is intentional because weights are permuted
frag_b[1] = *reinterpret_cast<const nv_bfloat162*>(&Out1); frag_b[1] = *reinterpret_cast<const nv_bfloat162*>(&Out1);
frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2); frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2);
};
// subtract zero point in quanted format and then dequant
template <typename scalar_t2, vllm::ScalarTypeId w_type_id,
bool skip_flop = false>
__device__ inline void sub_zp_and_dequant(int q, scalar_t2* frag_b, int zp);
template <>
__device__ inline void sub_zp_and_dequant<int32_t, vllm::kU4.id(), true>(
int q, int32_t* frag_b, int zp) {
// INT4 with zp -> INT8
// see https://github.com/vllm-project/vllm/pull/24722
int repeated_zp = 0x01010101 * zp;
int MASK = 0x80808080;
frag_b[0] = ((q & 0x0F0F0F0F | MASK) - repeated_zp) ^ MASK;
q >>= 4;
frag_b[1] = ((q & 0x0F0F0F0F | MASK) - repeated_zp) ^ MASK;
}
template <>
__device__ inline void sub_zp_and_dequant<__nv_fp8x4_e4m3, vllm::kU4.id(),
true>(int q, __nv_fp8x4_e4m3* frag_b,
int zp) {
// INT4 with zp -> FP8
// see https://github.com/vllm-project/vllm/pull/24722
uint32_t u_q = *reinterpret_cast<uint32_t*>(&q);
uint32_t u_zp = *reinterpret_cast<uint32_t*>(&zp);
uint32_t u_zp1 = u_zp + 1;
uint32_t repeated_zp = 0x01010101 * u_zp;
uint32_t q0, s;
q0 = (u_q & 0x0F0F0F0F) | 0x70707070;
s = (q0 + repeated_zp) & 0x80808080;
uint32_t Out1 = (q0 + (s >> 7) * u_zp1) & 0x0F0F0F0F | s;
u_q >>= 4;
q0 = (u_q & 0x0F0F0F0F) | 0x70707070;
s = (q0 + repeated_zp) & 0x80808080;
uint32_t Out2 = (q0 + (s >> 7) * u_zp1) & 0x0F0F0F0F | s;
frag_b[0] = *reinterpret_cast<const __nv_fp8x4_e4m3*>(&Out1);
frag_b[1] = *reinterpret_cast<const __nv_fp8x4_e4m3*>(&Out2);
} }
#endif #endif

View File

@ -4,141 +4,292 @@ import glob
import itertools import itertools
import os import os
import subprocess import subprocess
import sys
import jinja2 import jinja2
FILE_HEAD = """ ARCHS = []
// auto generated by generate.py SUPPORT_FP8 = False
// clang-format off for arch in sys.argv[1].split(","):
arch = arch[: arch.index(".") + 2].replace(".", "")
arch = int(arch)
# only SM89 and SM120 fully support
# mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32.
# SM90 and SM100 can use this PTX, but its simulated
# with FP16 MMA, so it cannot achieve any acceleration.
if arch in [89, 120]:
SUPPORT_FP8 = True
FILE_HEAD_COMMENT = """
// auto generated by generate_kernels.py
// clang-format off
""".lstrip()
FILE_HEAD = (
FILE_HEAD_COMMENT
+ """
#include "kernel.h" #include "kernel.h"
#include "marlin_template.h" #include "marlin_template.h"
namespace MARLIN_NAMESPACE_NAME { namespace MARLIN_NAMESPACE_NAME {
""".strip() """
)
TEMPLATE = ( TEMPLATE = (
"template __global__ void Marlin<" "template __global__ void Marlin<"
"{{scalar_t}}, " "{{a_type_id}}, "
"{{w_type_id}}, " "{{b_type_id}}, "
"{{c_type_id}}, "
"{{s_type_id}}, " "{{s_type_id}}, "
"{{threads}}, " "{{threads}}, "
"{{thread_m_blocks}}, " "{{thread_m_blocks}}, "
"{{thread_n_blocks}}, " "{{thread_n_blocks}}, "
"{{thread_k_blocks}}, " "{{thread_k_blocks}}, "
"{{'true' if m_block_size_8 else 'false'}}, " "{{m_block_size_8}}, "
"{{stages}}, " "{{stages}}, "
"{{group_blocks}}, " "{{group_blocks}}, "
"{{'true' if is_zp_float else 'false'}}>" "{{is_zp_float}}>"
"( MARLIN_KERNEL_PARAMS );" "( MARLIN_KERNEL_PARAMS );"
) )
# int8 with zero point case (vllm::kU8) is also supported,
# we don't add it to reduce wheel size.
SCALAR_TYPES = [
"vllm::kU4",
"vllm::kU4B8",
"vllm::kU8B128",
"vllm::kFE4M3fn",
"vllm::kFE2M1f",
]
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128), (128, 64, 128)] THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128), (128, 64, 128)]
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4] THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
# group_blocks:
# = 0 : act order case QUANT_CONFIGS = [
# = -1 : channelwise quantization # AWQ-INT4
# > 0 : group_size=16*group_blocks {
GROUP_BLOCKS = [0, 1, -1, 2, 4, 8] "b_type": "kU4",
DTYPES = ["fp16", "bf16"] "thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [-1, 2, 4, 8],
},
# HQQ
{
"a_type": ["kFloat16"],
"b_type": "kU4",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [4],
"is_zp_float": True,
},
# GPTQ-INT4
{
"b_type": "kU4B8",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [-1, 0, 2, 4, 8],
},
# GPTQ-INT8
{
"b_type": "kU8B128",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [-1, 0, 2, 4, 8],
},
# FP8
{
"b_type": "kFE4M3fn",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [-1, 8],
},
# NVFP4
{
"b_type": "kFE2M1f",
"s_type": "kFE4M3fn",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [1],
},
# MXFP4
{
"a_type": ["kBFloat16"],
"b_type": "kFE2M1f",
"s_type": "kFE8M0fnu",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [2],
},
# AWQ-INT4 with INT8 activation
{
"a_type": ["kS8"],
"b_type": "kU4",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": [1, 2, 3, 4],
"group_blocks": [-1, 2, 4, 8],
},
# GPTQ-INT4 with INT8 activation
{
"a_type": ["kS8"],
"b_type": "kU4B8",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": [1, 2, 3, 4],
"group_blocks": [-1, 2, 4, 8],
},
# GPTQ-INT4 with FP8 activation
{
"a_type": ["kFE4M3fn"],
"b_type": "kU4B8",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": [1, 2, 3, 4],
"group_blocks": [-1, 2, 4, 8],
},
# AWQ-INT4 with FP8 activation
{
"a_type": ["kFE4M3fn"],
"b_type": "kU4",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": [1, 2, 3, 4],
"group_blocks": [-1, 2, 4, 8],
},
# MXFP4 with FP8 activation
{
"a_type": ["kFE4M3fn"],
"b_type": "kFE2M1f",
"c_type": ["kBFloat16"],
"s_type": "kFE8M0fnu",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": [1, 2, 3, 4],
"group_blocks": [2],
},
]
def remove_old_kernels(): def remove_old_kernels():
for filename in glob.glob(os.path.dirname(__file__) + "/kernel_*.cu"): for filename in glob.glob(os.path.dirname(__file__) + "/*kernel_*.cu"):
subprocess.call(["rm", "-f", filename]) subprocess.call(["rm", "-f", filename])
filename = os.path.dirname(__file__) + "/kernel_selector.h"
subprocess.call(["rm", "-f", filename])
def generate_new_kernels(): def generate_new_kernels():
for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES): result_dict = {}
for quant_config in QUANT_CONFIGS:
c_types = quant_config.get("c_type", ["kFloat16", "kBFloat16"])
a_types = quant_config.get("a_type", ["kFloat16", "kBFloat16"])
b_type = quant_config["b_type"]
is_zp_float = quant_config.get("is_zp_float", False)
all_group_blocks = quant_config["group_blocks"]
all_m_blocks = quant_config["thread_m_blocks"]
all_thread_configs = quant_config["thread_configs"]
for a_type, c_type in itertools.product(a_types, c_types):
if not SUPPORT_FP8 and a_type == "kFE4M3fn":
continue
if "16" in a_type and "16" in c_type and a_type != c_type:
continue
s_type = quant_config.get("s_type", c_type)
if (a_type, b_type, c_type) not in result_dict:
result_dict[(a_type, b_type, c_type)] = []
for group_blocks, m_blocks, thread_configs in itertools.product(
all_group_blocks, all_m_blocks, all_thread_configs
):
thread_k, thread_n, threads = thread_configs
if threads == 256:
# for small batch (m_blocks == 1),
# we only need (128, 128, 256)
# for large batch (m_blocks > 1),
# we only need (64, 256, 256)
if m_blocks <= 1 and (thread_k, thread_n) != (128, 128):
continue
if m_blocks > 1 and (thread_k, thread_n) != (64, 256):
continue
config = {
"threads": threads,
"s_type": s_type,
"thread_m_blocks": max(m_blocks, 1),
"thread_k_blocks": thread_k // 16,
"thread_n_blocks": thread_n // 16,
"m_block_size_8": "true" if m_blocks == 0.5 else "false",
"stages": "pipe_stages",
"group_blocks": group_blocks,
"is_zp_float": "true" if is_zp_float else "false",
}
result_dict[(a_type, b_type, c_type)].append(config)
kernel_selector_str = FILE_HEAD_COMMENT
for (a_type, b_type, c_type), config_list in result_dict.items():
all_template_str_list = [] all_template_str_list = []
for config in config_list:
s_type = config["s_type"]
template_str = jinja2.Template(TEMPLATE).render(
a_type_id=f"vllm::{a_type}.id()",
b_type_id=f"vllm::{b_type}.id()",
c_type_id=f"vllm::{c_type}.id()",
s_type_id=f"vllm::{s_type}.id()",
**config,
)
all_template_str_list.append(template_str)
for group_blocks, m_blocks, thread_configs in itertools.product( conditions = [
GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS f"a_type == vllm::{a_type}",
): f"b_type == vllm::{b_type}",
# act order case only support gptq-int4 and gptq-int8 f"c_type == vllm::{c_type}",
if group_blocks == 0 and scalar_type not in [ f"s_type == vllm::{s_type}",
"vllm::kU4B8", f"threads == {config['threads']}",
"vllm::kU8B128", f"thread_m_blocks == {config['thread_m_blocks']}",
]: f"thread_n_blocks == {config['thread_n_blocks']}",
continue f"thread_k_blocks == {config['thread_k_blocks']}",
if thread_configs[2] == 256: f"m_block_size_8 == {config['m_block_size_8']}",
# for small batch (m_blocks == 1), we only need (128, 128, 256) f"group_blocks == {config['group_blocks']}",
# for large batch (m_blocks > 1), we only need (64, 256, 256) f"is_zp_float == {config['is_zp_float']}",
if m_blocks <= 1 and thread_configs[0] != 128: ]
continue conditions = " && ".join(conditions)
if m_blocks > 1 and thread_configs[0] != 64:
continue
# we only support channelwise quantization and group_size == 128 if kernel_selector_str == FILE_HEAD_COMMENT:
# for fp8 kernel_selector_str += f"if ({conditions})\n kernel = "
if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]: else:
continue kernel_selector_str += f"else if ({conditions})\n kernel = "
# nvfp4 only supports group_size == 16
# mxfp4 only supports group_size == 32
if scalar_type == "vllm::kFE2M1f" and group_blocks not in [1, 2]:
continue
# other quantization methods don't support group_size = 16
if scalar_type != "vllm::kFE2M1f" and group_blocks == 1:
continue
k_blocks = thread_configs[0] // 16 kernel_template2 = (
n_blocks = thread_configs[1] // 16 "Marlin<{{a_type_id}}, {{b_type_id}}, {{c_type_id}}, "
threads = thread_configs[2] "{{s_type_id}}, {{threads}}, {{thread_m_blocks}}, "
"{{thread_n_blocks}}, {{thread_k_blocks}}, "
"{{m_block_size_8}}, {{stages}}, {{group_blocks}}, "
"{{is_zp_float}}>;"
)
c_dtype = "half" if dtype == "fp16" else "nv_bfloat16" kernel_selector_str += (
jinja2.Template(kernel_template2).render(
is_zp_float_list = [False] a_type_id=f"vllm::{a_type}.id()",
if dtype == "fp16" and scalar_type == "vllm::kU4" and group_blocks == 4: b_type_id=f"vllm::{b_type}.id()",
# HQQ (is_zp_float = true) only supports c_type_id=f"vllm::{c_type}.id()",
# 4bit quantization and fp16 s_type_id=f"vllm::{s_type}.id()",
is_zp_float_list.append(True) **config,
if scalar_type == "vllm::kFE2M1f" and group_blocks == 1:
s_type = "vllm::kFE4M3fn"
elif scalar_type == "vllm::kFE2M1f" and group_blocks == 2:
s_type = "vllm::kFE8M0fnu"
if dtype == "fp16":
# we cannot safely dequantize e8m0 to fp16, so skip this
continue
elif dtype == "fp16":
s_type = "vllm::kFloat16"
elif dtype == "bf16":
s_type = "vllm::kBFloat16"
for is_zp_float in is_zp_float_list:
template_str = jinja2.Template(TEMPLATE).render(
scalar_t=c_dtype,
w_type_id=scalar_type + ".id()",
s_type_id=s_type + ".id()",
threads=threads,
thread_m_blocks=max(m_blocks, 1),
thread_n_blocks=n_blocks,
thread_k_blocks=k_blocks,
m_block_size_8=m_blocks == 0.5,
stages="pipe_stages",
group_blocks=group_blocks,
is_zp_float=is_zp_float,
) )
+ "\n"
all_template_str_list.append(template_str) )
file_content = FILE_HEAD + "\n\n" file_content = FILE_HEAD + "\n\n"
file_content += "\n\n".join(all_template_str_list) + "\n\n}\n" file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
filename = f"kernel_{dtype}_{scalar_type[6:].lower()}.cu" if a_type == "kFE4M3fn":
filename = f"sm89_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
else:
filename = f"sm80_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
filename = filename.lower()
with open(os.path.join(os.path.dirname(__file__), filename), "w") as f: with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
f.write(file_content) f.write(file_content)
if not SUPPORT_FP8 and kernel_selector_str != FILE_HEAD_COMMENT:
kernel_selector_str += (
"else if (a_type == vllm::kFE4M3fn)\n"
" TORCH_CHECK(false, "
'"marlin kernel with fp8 activation is not built.");'
)
with open(os.path.join(os.path.dirname(__file__), "kernel_selector.h"), "w") as f:
f.write(kernel_selector_str)
if __name__ == "__main__": if __name__ == "__main__":
remove_old_kernels() remove_old_kernels()

View File

@ -53,7 +53,7 @@ torch::Tensor gptq_marlin_gemm(
std::optional<torch::Tensor> const& b_zeros_or_none, std::optional<torch::Tensor> const& b_zeros_or_none,
std::optional<torch::Tensor> const& g_idx_or_none, std::optional<torch::Tensor> const& g_idx_or_none,
std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace, std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace,
vllm::ScalarTypeId const& b_q_type_id, int64_t size_m, int64_t size_n, vllm::ScalarTypeId const& b_type_id, int64_t size_m, int64_t size_n,
int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce, int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce,
bool is_zp_float) { bool is_zp_float) {
TORCH_CHECK_NOT_IMPLEMENTED(false, TORCH_CHECK_NOT_IMPLEMENTED(false,
@ -243,204 +243,29 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
int cache_size = get_kernel_cache_size( int cache_size = get_kernel_cache_size(
th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size, th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size,
has_act_order, is_k_full, has_zp, is_zp_float); has_act_order, is_k_full, has_zp, is_zp_float);
return cache_size + 512 <= max_shared_mem; return cache_size <= max_shared_mem;
} }
#define _GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ MarlinFuncPtr get_marlin_kernel(
M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \ const vllm::ScalarType a_type, const vllm::ScalarType b_type,
else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ const vllm::ScalarType c_type, const vllm::ScalarType s_type,
thread_n_blocks == THREAD_N_BLOCKS && \ int thread_m_blocks, int thread_n_blocks, int thread_k_blocks,
thread_k_blocks == THREAD_K_BLOCKS && \ bool m_block_size_8, bool has_act_order, bool has_zp, int group_blocks,
m_block_size_8 == M_BLOCK_SIZE_8 && \ int threads, bool is_zp_float) {
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \ int num_bits = b_type.size_bits();
is_zp_float == IS_ZP_FLOAT) { \
constexpr auto S_TYPE = \
W_TYPE == vllm::kFE2M1f \
? (GROUP_BLOCKS == 1 ? vllm::kFE4M3fn : vllm::kFE8M0fnu) \
: (std::is_same<scalar_t, half>::value ? vllm::kFloat16 \
: vllm::kBFloat16); \
kernel = Marlin<scalar_t, W_TYPE.id(), S_TYPE.id(), NUM_THREADS, \
THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
M_BLOCK_SIZE_8, pipe_stages, GROUP_BLOCKS, IS_ZP_FLOAT>; \
}
// COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false)
// this is the most common cases
// BIGGROUP: cases for big group size (group_blocks in [-1, 8])
// FZP: cases for float-zero-point (is_zp_float = true)
// ACT: cases for act order case (group_blocks == 0)
// FP4: cases for nvfp4(e2m1) (group_blocks == 1)
#define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
#define COMMON_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
\
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
\
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
#define COMMON_GET_IF(W_TYPE) \
COMMON_GET_IF_M1(W_TYPE, 8, 8, 256) \
COMMON_GET_IF_M1(W_TYPE, 8, 4, 128) \
COMMON_GET_IF_M1(W_TYPE, 4, 8, 128) \
COMMON_GET_IF_M234(W_TYPE, 16, 4, 256) \
COMMON_GET_IF_M234(W_TYPE, 8, 4, 128) \
COMMON_GET_IF_M234(W_TYPE, 4, 8, 128)
#define BIGGROUP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
#define BIGGROUP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
#define BIGGROUP_GET_IF(W_TYPE) \
BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \
BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \
BIGGROUP_GET_IF_M1(W_TYPE, 4, 8, 128) \
BIGGROUP_GET_IF_M234(W_TYPE, 16, 4, 256) \
BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128) \
BIGGROUP_GET_IF_M234(W_TYPE, 4, 8, 128)
#define NVFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
#define NVFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
#define NVFP4_GET_IF(W_TYPE) \
NVFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \
NVFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \
NVFP4_GET_IF_M1(W_TYPE, 4, 8, 128) \
NVFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
NVFP4_GET_IF_M234(W_TYPE, 8, 4, 128) \
NVFP4_GET_IF_M234(W_TYPE, 4, 8, 128)
#define MXFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false)
#define MXFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false)
#define MXFP4_GET_IF(W_TYPE) \
MXFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \
MXFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \
MXFP4_GET_IF_M1(W_TYPE, 4, 8, 128) \
MXFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
MXFP4_GET_IF_M234(W_TYPE, 8, 4, 128) \
MXFP4_GET_IF_M234(W_TYPE, 4, 8, 128)
// We currently have 4-bit models only with group_blocks == 4
#define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, true) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true)
#define FZP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true)
#define FZP_GET_IF(W_TYPE) \
FZP_GET_IF_M1(W_TYPE, 8, 8, 256) \
FZP_GET_IF_M1(W_TYPE, 8, 4, 128) \
FZP_GET_IF_M1(W_TYPE, 4, 8, 128) \
FZP_GET_IF_M234(W_TYPE, 16, 4, 256) \
FZP_GET_IF_M234(W_TYPE, 8, 4, 128) \
FZP_GET_IF_M234(W_TYPE, 4, 8, 128)
// We currently have 4-bit models only with group_blocks == 4
#define ACT_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false)
#define ACT_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false)
#define ACT_GET_IF(W_TYPE) \
ACT_GET_IF_M1(W_TYPE, 8, 8, 256) \
ACT_GET_IF_M1(W_TYPE, 8, 4, 128) \
ACT_GET_IF_M1(W_TYPE, 4, 8, 128) \
ACT_GET_IF_M234(W_TYPE, 16, 4, 256) \
ACT_GET_IF_M234(W_TYPE, 8, 4, 128) \
ACT_GET_IF_M234(W_TYPE, 4, 8, 128)
template <typename scalar_t>
MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type,
int thread_m_blocks, int thread_n_blocks,
int thread_k_blocks, bool m_block_size_8,
bool has_act_order, bool has_zp,
int group_blocks, int num_threads,
bool is_zp_float) {
int num_bits = q_type.size_bits();
auto kernel = MarlinDefault; auto kernel = MarlinDefault;
if (false) {
}
COMMON_GET_IF(vllm::kU4) #include "kernel_selector.h"
COMMON_GET_IF(vllm::kU4B8)
COMMON_GET_IF(vllm::kU8B128)
NVFP4_GET_IF(vllm::kFE2M1f)
BIGGROUP_GET_IF(vllm::kFE4M3fn)
ACT_GET_IF(vllm::kU4B8)
ACT_GET_IF(vllm::kU8B128)
if (std::is_same<scalar_t, half>::value) {
if (false) {
}
FZP_GET_IF(vllm::kU4)
}
if (std::is_same<scalar_t, nv_bfloat16>::value) {
if (false) {
}
MXFP4_GET_IF(vllm::kFE2M1f)
}
return kernel; return kernel;
} }
template <typename scalar_t> exec_config_t determine_exec_config(
exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m, const vllm::ScalarType& a_type, const vllm::ScalarType& b_type,
int prob_n, int prob_k, int thread_m_blocks, const vllm::ScalarType& c_type, const vllm::ScalarType& s_type, int prob_m,
bool m_block_size_8, int num_bits, int prob_n, int prob_k, int thread_m_blocks, bool m_block_size_8,
int group_size, bool has_act_order, int num_bits, int group_size, bool has_act_order, bool is_k_full,
bool is_k_full, bool has_zp, bool has_zp, bool is_zp_float, int max_shared_mem, int sms) {
bool is_zp_float, int max_shared_mem,
int sms) {
exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}}; exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}};
thread_config_t* thread_configs = thread_m_blocks > 1 thread_config_t* thread_configs = thread_m_blocks > 1
? large_batch_thread_configs ? large_batch_thread_configs
@ -455,7 +280,7 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
if (!is_valid_config(th_config, thread_m_blocks, prob_m, prob_n, prob_k, if (!is_valid_config(th_config, thread_m_blocks, prob_m, prob_n, prob_k,
num_bits, group_size, has_act_order, is_k_full, has_zp, num_bits, group_size, has_act_order, is_k_full, has_zp,
is_zp_float, max_shared_mem)) { is_zp_float, max_shared_mem - 512)) {
continue; continue;
} }
@ -468,10 +293,11 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
group_blocks = group_size == -1 ? -1 : group_size / 16; group_blocks = group_size == -1 ? -1 : group_size / 16;
} }
auto kernel = get_marlin_kernel<scalar_t>( auto kernel =
q_type, thread_m_blocks, th_config.thread_n / 16, get_marlin_kernel(a_type, b_type, c_type, s_type, thread_m_blocks,
th_config.thread_k / 16, m_block_size_8, has_act_order, has_zp, th_config.thread_n / 16, th_config.thread_k / 16,
group_blocks, th_config.num_threads, is_zp_float); m_block_size_8, has_act_order, has_zp, group_blocks,
th_config.num_threads, is_zp_float);
if (kernel == MarlinDefault) continue; if (kernel == MarlinDefault) continue;
@ -485,28 +311,16 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
return exec_cfg; return exec_cfg;
} }
template <typename scalar_t>
void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
void* s, void* s2, void* zp, void* g_idx, void* perm, void* a_s, void* b_s, void* g_s, void* zp, void* g_idx,
void* a_tmp, int prob_m, int prob_n, int prob_k, int lda, void* perm, void* a_tmp, int prob_m, int prob_n, int prob_k,
void* workspace, vllm::ScalarType const& q_type, bool has_bias, int lda, void* workspace, vllm::ScalarType const& a_type,
vllm::ScalarType const& b_type, vllm::ScalarType const& c_type,
vllm::ScalarType const& s_type, bool has_bias,
bool has_act_order, bool is_k_full, bool has_zp, int num_groups, bool has_act_order, bool is_k_full, bool has_zp, int num_groups,
int group_size, int dev, cudaStream_t stream, int thread_k_init, int group_size, int dev, cudaStream_t stream, int thread_k_init,
int thread_n_init, int sms, bool use_atomic_add, int thread_n_init, int sms, bool use_atomic_add,
bool use_fp32_reduce, bool is_zp_float) { bool use_fp32_reduce, bool is_zp_float) {
if (has_zp) {
TORCH_CHECK(
q_type == vllm::kU4 || q_type == vllm::kU8,
"q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str());
} else {
TORCH_CHECK(
q_type == vllm::kU4B8 || q_type == vllm::kU8B128 ||
q_type == vllm::kFE4M3fn || q_type == vllm::kFE2M1f,
"q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when "
"has_zp = False. Got = ",
q_type.str());
}
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
", ", prob_n, ", ", prob_k, "]"); ", ", prob_n, ", ", prob_k, "]");
@ -531,19 +345,21 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
} }
} }
int num_bits = q_type.size_bits(); int num_bits = b_type.size_bits();
const int4* A_ptr = (const int4*)A; const int4* A_ptr = (const int4*)A;
const int4* B_ptr = (const int4*)B; const int4* B_ptr = (const int4*)B;
int4* C_ptr = (int4*)C; int4* C_ptr = (int4*)C;
int4* C_tmp_ptr = (int4*)C_tmp; int4* C_tmp_ptr = (int4*)C_tmp;
const int4* bias_ptr = (const int4*)b_bias; const int4* bias_ptr = (const int4*)b_bias;
const int4* s_ptr = (const int4*)s; const float* a_s_ptr = (const float*)a_s;
const uint16_t* s2_ptr = (const uint16_t*)s2; const int4* b_s_ptr = (const int4*)b_s;
const uint16_t* g_s_ptr = (const uint16_t*)g_s;
const int4* zp_ptr = (const int4*)zp; const int4* zp_ptr = (const int4*)zp;
const int* g_idx_ptr = (const int*)g_idx; const int* g_idx_ptr = (const int*)g_idx;
const int* perm_ptr = (const int*)perm; const int* perm_ptr = (const int*)perm;
int4* a_tmp_ptr = (int4*)a_tmp; int4* a_tmp_ptr = (int4*)a_tmp;
int* locks = (int*)workspace; int* locks = (int*)workspace;
if (has_act_order) { if (has_act_order) {
@ -568,6 +384,21 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
TORCH_CHECK(max_shared_mem > 0); TORCH_CHECK(max_shared_mem > 0);
int major_capability, minor_capability;
cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,
dev);
cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
dev);
TORCH_CHECK(major_capability * 10 + minor_capability >= 80,
"marlin kernel only support Ampere or newer GPUs.");
if (a_type == vllm::kFE4M3fn) {
TORCH_CHECK(
major_capability * 10 + minor_capability == 89 ||
major_capability * 10 + minor_capability == 120,
"Marlin W4A8-FP8 only support SM89 or SM120 device (It is slower than "
"Marlin W4A16 on other devices).");
}
int max_par = 16; int max_par = 16;
if (prob_n <= 4096) max_par = 16 * 8; if (prob_n <= 4096) max_par = 16 * 8;
int max_shared_mem_new = max_shared_mem; int max_shared_mem_new = max_shared_mem;
@ -583,7 +414,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
int thread_n = thread_n_init; int thread_n = thread_n_init;
int thread_m_blocks = min(div_ceil(prob_m_split, 16), max_thread_m_blocks); int thread_m_blocks = min(div_ceil(prob_m_split, 16), max_thread_m_blocks);
int m_block_size_8 = prob_m_split <= 8; int m_block_size_8 = prob_m_split <= 8 && a_type.size_bits() == 16;
// Set thread config // Set thread config
exec_config_t exec_cfg; exec_config_t exec_cfg;
@ -597,11 +428,25 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
" is not divisible by thread_k = ", thread_k); " is not divisible by thread_k = ", thread_k);
} else { } else {
// Auto config // Auto config
exec_cfg = determine_exec_config<scalar_t>( exec_cfg = determine_exec_config(
q_type, prob_m_split, prob_n, prob_k, thread_m_blocks, m_block_size_8, a_type, b_type, c_type, s_type, prob_m_split, prob_n, prob_k,
num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float, thread_m_blocks, m_block_size_8, num_bits, group_size, has_act_order,
max_shared_mem, sms); is_k_full, has_zp, is_zp_float, max_shared_mem, sms);
thread_tfg = exec_cfg.tb_cfg; thread_tfg = exec_cfg.tb_cfg;
if (thread_tfg.thread_n != -1) {
if (prob_n / thread_tfg.thread_n *
div_ceil(prob_m_split, thread_m_blocks * 16) * 4 <=
sms) {
if (is_valid_config({128, 64, 128}, thread_m_blocks, prob_m_split,
prob_n, prob_k, num_bits, group_size,
has_act_order, is_k_full, has_zp, is_zp_float,
max_shared_mem_new)) {
thread_tfg = {128, 64, 128};
exec_cfg = {1, thread_tfg};
}
}
}
if (thread_tfg.thread_k == -1 && max_thread_m_blocks > 1) { if (thread_tfg.thread_k == -1 && max_thread_m_blocks > 1) {
max_thread_m_blocks--; max_thread_m_blocks--;
continue; continue;
@ -632,10 +477,10 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
", has_zp = ", has_zp, ", is_zp_float = ", is_zp_float, ", has_zp = ", has_zp, ", is_zp_float = ", is_zp_float,
", max_shared_mem_new = ", max_shared_mem_new); ", max_shared_mem_new = ", max_shared_mem_new);
auto kernel = get_marlin_kernel<scalar_t>( auto kernel = get_marlin_kernel(
q_type, thread_m_blocks, thread_n_blocks, thread_k_blocks, a_type, b_type, c_type, s_type, thread_m_blocks, thread_n_blocks,
m_block_size_8, has_act_order, has_zp, group_blocks, num_threads, thread_k_blocks, m_block_size_8, has_act_order, has_zp, group_blocks,
is_zp_float); num_threads, is_zp_float);
if (kernel == MarlinDefault) { if (kernel == MarlinDefault) {
TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n, TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n,
@ -657,13 +502,15 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
// avoid ">>>" being formatted to "> > >" // avoid ">>>" being formatted to "> > >"
// clang-format off // clang-format off
kernel<<<blocks, num_threads, max_shared_mem_new, stream>>>( kernel<<<blocks, num_threads, max_shared_mem_new, stream>>>(
A_ptr, B_ptr, C_ptr, C_tmp_ptr, bias_ptr, s_ptr, s2_ptr, zp_ptr, A_ptr, B_ptr, C_ptr, C_tmp_ptr, bias_ptr, a_s_ptr, b_s_ptr, g_s_ptr, zp_ptr,
g_idx_ptr, num_groups, g_idx_ptr, num_groups,
prob_m_split, prob_n, prob_k, lda, locks, has_bias, part_use_atomic_add, prob_m_split, prob_n, prob_k, lda, locks, has_bias, part_use_atomic_add,
use_fp32_reduce, max_shared_mem_new); use_fp32_reduce, max_shared_mem_new);
// clang-format on // clang-format on
A_ptr += prob_m_split * (lda / 8); bool is_a_8bit = a_type.size_bits() == 8;
A_ptr += prob_m_split * (lda / (is_a_8bit ? 16 : 8));
a_s_ptr += prob_m_split;
C_ptr += prob_m_split * (prob_n / 8); C_ptr += prob_m_split * (prob_n / 8);
rest_m -= prob_m_split; rest_m -= prob_m_split;
} }
@ -675,15 +522,73 @@ torch::Tensor gptq_marlin_gemm(
torch::Tensor& a, std::optional<torch::Tensor> c_or_none, torch::Tensor& a, std::optional<torch::Tensor> c_or_none,
torch::Tensor& b_q_weight, torch::Tensor& b_q_weight,
std::optional<torch::Tensor> const& b_bias_or_none, torch::Tensor& b_scales, std::optional<torch::Tensor> const& b_bias_or_none, torch::Tensor& b_scales,
std::optional<torch::Tensor> const& a_scales_or_none,
std::optional<torch::Tensor> const& global_scale_or_none, std::optional<torch::Tensor> const& global_scale_or_none,
std::optional<torch::Tensor> const& b_zeros_or_none, std::optional<torch::Tensor> const& b_zeros_or_none,
std::optional<torch::Tensor> const& g_idx_or_none, std::optional<torch::Tensor> const& g_idx_or_none,
std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace, std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace,
vllm::ScalarTypeId const& b_q_type_id, int64_t size_m, int64_t size_n, vllm::ScalarTypeId const& b_type_id, int64_t size_m, int64_t size_n,
int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce, int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce,
bool is_zp_float) { bool is_zp_float) {
vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id); vllm::ScalarTypeId a_type_id, c_type_id, s_type_id;
int pack_factor = 32 / b_q_type.size_bits();
auto c_dtype = a.dtype();
if (a.scalar_type() == at::ScalarType::Half) {
a_type_id = vllm::kFloat16.id();
c_type_id = vllm::kFloat16.id();
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
a_type_id = vllm::kBFloat16.id();
c_type_id = vllm::kBFloat16.id();
} else {
c_dtype = b_scales.dtype();
if (b_scales.scalar_type() == at::ScalarType::Half) {
c_type_id = vllm::kFloat16.id();
} else if (b_scales.scalar_type() == at::ScalarType::BFloat16) {
c_type_id = vllm::kBFloat16.id();
} else {
c_type_id = vllm::kBFloat16.id();
TORCH_CHECK(c_or_none.has_value(), "c must be passed for W4A8-FP4");
torch::Tensor c = c_or_none.value();
c_dtype = c.dtype();
if (c.scalar_type() == at::ScalarType::Half) {
c_type_id = vllm::kFloat16.id();
} else if (c.scalar_type() == at::ScalarType::BFloat16) {
c_type_id = vllm::kBFloat16.id();
} else {
TORCH_CHECK(false, "unsupported c dtype");
}
}
if (a.scalar_type() == at::ScalarType::Float8_e4m3fn) {
a_type_id = vllm::kFE4M3fn.id();
} else if (a.scalar_type() == at::ScalarType::Char) {
a_type_id = vllm::kS8.id();
} else {
TORCH_CHECK(false, "unsupported `a` scalar_type");
}
}
s_type_id = c_type_id;
if (b_type_id == vllm::kFE2M1f.id()) {
if (b_scales.scalar_type() == at::ScalarType::Float8_e4m3fn) {
s_type_id = vllm::kFE4M3fn.id();
} else if (b_scales.scalar_type() == at::ScalarType::Float8_e8m0fnu) {
s_type_id = vllm::kFE8M0fnu.id();
} else {
TORCH_CHECK(false,
"When b_type = float4_e2m1f, b_scale scalar type must be",
"float8_e4m3fn (for NVFP4) or float8_e8m0fnu (for MXFP4).");
}
}
vllm::ScalarType a_type = vllm::ScalarType::from_id(a_type_id);
vllm::ScalarType b_type = vllm::ScalarType::from_id(b_type_id);
vllm::ScalarType c_type = vllm::ScalarType::from_id(c_type_id);
vllm::ScalarType s_type = vllm::ScalarType::from_id(s_type_id);
int pack_factor = 32 / b_type.size_bits();
// Verify A // Verify A
TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0), TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0),
@ -721,6 +626,21 @@ torch::Tensor gptq_marlin_gemm(
TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU");
TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");
torch::Tensor a_scales;
auto options = torch::TensorOptions().dtype(c_dtype).device(a.device());
auto options_fp32 =
torch::TensorOptions().dtype(at::kFloat).device(a.device());
if (a_scales_or_none.has_value()) {
a_scales = a_scales_or_none.value();
TORCH_CHECK(a_type.size_bits() == 8,
"a_scales can only be used for 8bit activation.");
} else {
a_scales = torch::empty({0}, options_fp32);
TORCH_CHECK(a_type.size_bits() != 8,
"the a_scales parameter must be passed for 8bit activation.");
}
// thread_k: `k` size of a thread_tile in `weights` (can usually be left as // thread_k: `k` size of a thread_tile in `weights` (can usually be left as
// auto -1) // auto -1)
int thread_k = -1; int thread_k = -1;
@ -733,7 +653,6 @@ torch::Tensor gptq_marlin_gemm(
// Alloc buffers // Alloc buffers
const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
torch::Tensor c; torch::Tensor c;
if (c_or_none.has_value()) { if (c_or_none.has_value()) {
c = c_or_none.value(); c = c_or_none.value();
@ -750,8 +669,6 @@ torch::Tensor gptq_marlin_gemm(
// Alloc C tmp buffer that is going to be used for the global reduce // Alloc C tmp buffer that is going to be used for the global reduce
torch::Tensor c_tmp; torch::Tensor c_tmp;
auto options_fp32 =
torch::TensorOptions().dtype(at::kFloat).device(a.device());
if (use_fp32_reduce) { if (use_fp32_reduce) {
int max_m_block_size = (size_m + 16 - 1) / 16 * 16; int max_m_block_size = (size_m + 16 - 1) / 16 * 16;
max_m_block_size = min(max_m_block_size, 64); max_m_block_size = min(max_m_block_size, 64);
@ -821,11 +738,11 @@ torch::Tensor gptq_marlin_gemm(
torch::Tensor global_scale; torch::Tensor global_scale;
if (global_scale_or_none.has_value()) { if (global_scale_or_none.has_value()) {
global_scale = global_scale_or_none.value(); global_scale = global_scale_or_none.value();
TORCH_CHECK(b_q_type == vllm::kFE2M1f && group_size == 16, TORCH_CHECK(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn,
"global_scale can only be used for nvfp4 format."); "global_scale can only be used for nvfp4 format.");
} else { } else {
global_scale = torch::empty({0}, options); global_scale = torch::empty({0}, options);
TORCH_CHECK(!(b_q_type == vllm::kFE2M1f && group_size == 16), TORCH_CHECK(!(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn),
"the global_scale parameter must be passed for nvfp4 format."); "the global_scale parameter must be passed for nvfp4 format.");
} }
@ -852,15 +769,15 @@ torch::Tensor gptq_marlin_gemm(
bool has_zp = b_zeros.size(-1) > 0; bool has_zp = b_zeros.size(-1) > 0;
if (has_zp) { if (has_zp) {
TORCH_CHECK( TORCH_CHECK(
b_q_type == vllm::kU4 || b_q_type == vllm::kU8, b_type == vllm::kU4 || b_type == vllm::kU8,
"b_q_type must be u4 or u8 when has_zp = True. Got = ", b_q_type.str()); "b_type must be u4 or u8 when has_zp = True. Got = ", b_type.str());
} else { } else {
TORCH_CHECK(b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128 || TORCH_CHECK(b_type == vllm::kU4B8 || b_type == vllm::kU8B128 ||
b_q_type == vllm::kFE4M3fn || b_q_type == vllm::kFE2M1f, b_type == vllm::kS4 || b_type == vllm::kS8 ||
"b_q_type must be uint4b8, uint8b128, float8_e4m3fn or " b_type == vllm::kFE4M3fn || b_type == vllm::kFE2M1f,
"float4_e2m1f when " "b_type must be uint4b8, uint8b128, int4, int8, "
"has_zp = False. Got = ", "float8_e4m3fn or float4_e2m1f when has_zp = False. Got = ",
b_q_type.str()); b_type.str());
} }
if (has_zp && is_zp_float) { if (has_zp && is_zp_float) {
@ -902,59 +819,27 @@ torch::Tensor gptq_marlin_gemm(
" is below min_workspace_size = ", min_workspace_size); " is below min_workspace_size = ", min_workspace_size);
int dev = a.get_device(); int dev = a.get_device();
if (a.scalar_type() == at::ScalarType::Half) {
void* scales_ptr;
if (b_q_type == vllm::kFE2M1f) {
if (group_size == 16)
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
else if (group_size == 32)
scales_ptr = b_scales.data_ptr<at::Float8_e8m0fnu>();
else
TORCH_CHECK(false,
"float4_e2m1f only supports group_size == 16 (NVFP4) ",
"and group_size == 32 (MXFP4)");
} else {
scales_ptr = b_scales.data_ptr<at::Half>();
}
marlin::marlin_mm<half>( TORCH_CHECK(a_scales.scalar_type() == at::ScalarType::Float,
a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(), "scalar type of a_scales must be float");
c_tmp.data_ptr<float>(), b_bias.data_ptr<at::Half>(), scales_ptr, TORCH_CHECK(global_scale.scalar_type() == c.scalar_type(),
global_scale.data_ptr<at::Half>(), b_zeros.data_ptr(), g_idx.data_ptr(), "scalar type of global_scale must be the same with c");
perm.data_ptr(), a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k, if (a_type.size_bits() == 16) {
a.stride(0), workspace.data_ptr(), b_q_type, has_bias, has_act_order, TORCH_CHECK(
is_k_full, has_zp, num_groups, group_size, dev, a.scalar_type() == c.scalar_type(),
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, "scalar type of a must be the same with c for 16 bit activation");
use_atomic_add, use_fp32_reduce, is_zp_float);
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
void* scales_ptr;
if (b_q_type == vllm::kFE2M1f) {
if (group_size == 16)
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
else if (group_size == 32)
scales_ptr = b_scales.data_ptr<at::Float8_e8m0fnu>();
else
TORCH_CHECK(false,
"float4_e2m1f only supports group_size == 16 (NVFP4) ",
"and group_size == 32 (MXFP4)");
} else {
scales_ptr = b_scales.data_ptr<at::BFloat16>();
}
marlin::marlin_mm<nv_bfloat16>(
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(),
b_bias.data_ptr<at::BFloat16>(), scales_ptr,
global_scale.data_ptr<at::BFloat16>(), b_zeros.data_ptr(),
g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(),
size_m, size_n, size_k, a.stride(0), workspace.data_ptr(), b_q_type,
has_bias, has_act_order, is_k_full, has_zp, num_groups, group_size, dev,
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
use_atomic_add, use_fp32_reduce, is_zp_float);
} else {
TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16");
} }
marlin::marlin_mm(
a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), c_tmp.data_ptr(),
b_bias.data_ptr(), a_scales.data_ptr(), b_scales.data_ptr(),
global_scale.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(),
perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, size_k, a.stride(0),
workspace.data_ptr(), a_type, b_type, c_type, s_type, has_bias,
has_act_order, is_k_full, has_zp, num_groups, group_size, dev,
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
use_atomic_add, use_fp32_reduce, is_zp_float);
return c; return c;
} }

View File

@ -4,15 +4,18 @@
namespace marlin { namespace marlin {
template <int const num_threads, int const num_bits, bool const has_perm> template <int const num_threads, int const num_bits, bool const has_perm,
bool is_a_8bit>
__global__ void gptq_marlin_repack_kernel( __global__ void gptq_marlin_repack_kernel(
uint32_t const* __restrict__ b_q_weight_ptr, uint32_t const* __restrict__ b_q_weight_ptr,
uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr, uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr,
int size_k, int size_n) { int size_k, int size_n) {
constexpr int pack_factor = 32 / num_bits; constexpr int pack_factor = 32 / num_bits;
int k_tiles = size_k / tile_k_size; constexpr int target_tile_n_size = tile_n_size / (is_a_8bit ? 2 : 1);
int n_tiles = size_n / tile_n_size; constexpr int target_tile_k_size = tile_k_size * (is_a_8bit ? 2 : 1);
int k_tiles = size_k / target_tile_k_size;
int n_tiles = size_n / target_tile_n_size;
int block_k_tiles = div_ceil(k_tiles, gridDim.x); int block_k_tiles = div_ceil(k_tiles, gridDim.x);
auto start_k_tile = blockIdx.x * block_k_tiles; auto start_k_tile = blockIdx.x * block_k_tiles;
@ -34,7 +37,7 @@ __global__ void gptq_marlin_repack_kernel(
extern __shared__ int4 sh[]; extern __shared__ int4 sh[];
constexpr int perm_size = tile_k_size / 4; constexpr int perm_size = target_tile_k_size / 4;
int4* sh_perm_ptr = sh; int4* sh_perm_ptr = sh;
int4* sh_pipe_ptr = sh_perm_ptr; int4* sh_pipe_ptr = sh_perm_ptr;
@ -42,14 +45,14 @@ __global__ void gptq_marlin_repack_kernel(
sh_pipe_ptr += perm_size; sh_pipe_ptr += perm_size;
} }
constexpr int tile_ints = tile_k_size / pack_factor; constexpr int tile_ints = target_tile_k_size / pack_factor;
constexpr int stage_n_threads = tile_n_size / 4; constexpr int stage_n_threads = target_tile_n_size / 4;
constexpr int stage_k_threads = has_perm ? tile_k_size : tile_ints; constexpr int stage_k_threads = has_perm ? target_tile_k_size : tile_ints;
constexpr int stage_size = stage_k_threads * stage_n_threads; constexpr int stage_size = stage_k_threads * stage_n_threads;
auto load_perm_to_shared = [&](int k_tile_id) { auto load_perm_to_shared = [&](int k_tile_id) {
int first_k_int4 = (k_tile_id * tile_k_size) / 4; int first_k_int4 = (k_tile_id * target_tile_k_size) / 4;
int4 const* perm_int4_ptr = reinterpret_cast<int4 const*>(perm_ptr); int4 const* perm_int4_ptr = reinterpret_cast<int4 const*>(perm_ptr);
@ -65,7 +68,7 @@ __global__ void gptq_marlin_repack_kernel(
return; return;
} }
int first_n = n_tile_id * tile_n_size; int first_n = n_tile_id * target_tile_n_size;
int4* sh_ptr = sh_pipe_ptr + stage_size * pipe; int4* sh_ptr = sh_pipe_ptr + stage_size * pipe;
@ -91,7 +94,7 @@ __global__ void gptq_marlin_repack_kernel(
auto k_id = threadIdx.x / stage_n_threads; auto k_id = threadIdx.x / stage_n_threads;
auto n_id = threadIdx.x % stage_n_threads; auto n_id = threadIdx.x % stage_n_threads;
int first_k = k_tile_id * tile_k_size; int first_k = k_tile_id * target_tile_k_size;
int first_k_packed = first_k / pack_factor; int first_k_packed = first_k / pack_factor;
cp_async4(&sh_ptr[k_id * stage_n_threads + n_id], cp_async4(&sh_ptr[k_id * stage_n_threads + n_id],
@ -117,13 +120,13 @@ __global__ void gptq_marlin_repack_kernel(
} }
int tc_col = th_id / 4; int tc_col = th_id / 4;
int tc_row = (th_id % 4) * 2; int tc_row = (th_id % 4) * (is_a_8bit ? 4 : 2);
constexpr int tc_offsets[4] = {0, 1, 8, 9}; constexpr int tc_offsets[4] = {0, 1, 8, 9};
int cur_n = warp_id * 16 + tc_col; int cur_n = (warp_id / (is_a_8bit ? 2 : 1)) * 16 + tc_col;
constexpr int sh_stride = 64; constexpr int sh_stride = target_tile_n_size;
constexpr uint32_t mask = (1 << num_bits) - 1; constexpr uint32_t mask = (1 << num_bits) - 1;
int4* sh_stage_ptr = sh_pipe_ptr + stage_size * pipe; int4* sh_stage_ptr = sh_pipe_ptr + stage_size * pipe;
@ -134,6 +137,7 @@ __global__ void gptq_marlin_repack_kernel(
uint32_t vals[8]; uint32_t vals[8];
if constexpr (has_perm) { if constexpr (has_perm) {
static_assert(!is_a_8bit);
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
int k_idx = tc_row + tc_offsets[i]; int k_idx = tc_row + tc_offsets[i];
@ -156,28 +160,49 @@ __global__ void gptq_marlin_repack_kernel(
#pragma unroll #pragma unroll
for (int i = 0; i < tile_ints; i++) { for (int i = 0; i < tile_ints; i++) {
b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i]; if constexpr (is_a_8bit) {
b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i]; b1_vals[i] =
sh_stage_int_ptr[cur_n + sh_stride * i + (warp_id % 2) * 8];
} else {
b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i];
b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i];
}
} }
#pragma unroll #pragma unroll
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
int cur_elem = tc_row + tc_offsets[i]; int cur_elem = tc_row + (is_a_8bit ? i : tc_offsets[i]);
int cur_int = cur_elem / pack_factor; int cur_int = cur_elem / pack_factor;
int cur_pos = cur_elem % pack_factor; int cur_pos = cur_elem % pack_factor;
vals[i] = (b1_vals[cur_int] >> (cur_pos * num_bits)) & mask; vals[i] = (b1_vals[cur_int] >> (cur_pos * num_bits)) & mask;
vals[4 + i] = (b2_vals[cur_int] >> (cur_pos * num_bits)) & mask; if constexpr (is_a_8bit)
vals[4 + i] =
(b1_vals[cur_int + tile_ints / 2] >> (cur_pos * num_bits)) & mask;
else
vals[4 + i] = (b2_vals[cur_int] >> (cur_pos * num_bits)) & mask;
} }
} }
constexpr int tile_size = tile_k_size * tile_n_size / pack_factor; constexpr int tile_size =
target_tile_k_size * target_tile_n_size / pack_factor;
int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size; int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;
// Result of: // Result of:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
if constexpr (num_bits == 4) { if constexpr (!is_a_8bit && num_bits == 4) {
constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
uint32_t res = 0;
#pragma unroll
for (int i = 0; i < 8; i++) {
res |= vals[pack_idx[i]] << (i * 4);
}
out_ptr[out_offset + th_id * 4 + warp_id] = res;
} else if constexpr (is_a_8bit && num_bits == 4) {
int pack_idx[8] = {0, 4, 1, 5, 2, 6, 3, 7};
uint32_t res = 0; uint32_t res = 0;
#pragma unroll #pragma unroll
@ -194,8 +219,9 @@ __global__ void gptq_marlin_repack_kernel(
uint32_t res2 = 0; uint32_t res2 = 0;
#pragma unroll #pragma unroll
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
res1 |= vals[pack_idx[i]] << (i * 8); const int ii = is_a_8bit ? i : pack_idx[i];
res2 |= vals[4 + pack_idx[i]] << (i * 8); res1 |= vals[ii] << (i * 8);
res2 |= vals[4 + ii] << (i * 8);
} }
out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1; out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1;
@ -236,21 +262,22 @@ __global__ void gptq_marlin_repack_kernel(
} // namespace marlin } // namespace marlin
#define CALL_IF(NUM_BITS, HAS_PERM) \ #define CALL_IF(NUM_BITS, HAS_PERM, IS_A_8BIT) \
else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \ else if (num_bits == NUM_BITS && has_perm == HAS_PERM && \
is_a_8bit == IS_A_8BIT) { \
cudaFuncSetAttribute( \ cudaFuncSetAttribute( \
marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \ marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
HAS_PERM>, \ HAS_PERM, IS_A_8BIT>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \ marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
HAS_PERM> \ HAS_PERM, IS_A_8BIT> \
<<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \ <<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \
b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \ b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \
} }
torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
int64_t size_k, int64_t size_n, int64_t size_k, int64_t size_n,
int64_t num_bits) { int64_t num_bits, bool is_a_8bit) {
// Verify compatibility with marlin tile of 16x64 // Verify compatibility with marlin tile of 16x64
TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k, TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k,
" is not divisible by tile_k_size = ", marlin::tile_k_size); " is not divisible by tile_k_size = ", marlin::tile_k_size);
@ -309,13 +336,17 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
if (false) { if (false) {
} }
CALL_IF(4, false) CALL_IF(4, false, false)
CALL_IF(4, true) CALL_IF(4, true, false)
CALL_IF(8, false) CALL_IF(8, false, false)
CALL_IF(8, true) CALL_IF(8, true, false)
CALL_IF(4, false, true)
CALL_IF(8, false, true)
else { else {
TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits, TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits,
", has_perm = ", has_perm); ", has_perm = ", has_perm, ", is_a_8bit = ", is_a_8bit);
} }
return out; return out;

View File

@ -11,17 +11,19 @@
const int4 *__restrict__ A, const int4 *__restrict__ B, \ const int4 *__restrict__ A, const int4 *__restrict__ B, \
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \ int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
const int4 *__restrict__ b_bias_ptr, \ const int4 *__restrict__ b_bias_ptr, \
const float *__restrict__ a_scales_ptr, \
const int4 *__restrict__ scales_ptr, \ const int4 *__restrict__ scales_ptr, \
const uint16_t *__restrict__ scale2_ptr, \ const uint16_t *__restrict__ global_scale_ptr, \
const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \ const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
int num_groups, int prob_m, int prob_n, int prob_k, int lda, int *locks, \ int num_groups, int prob_m, int prob_n, int prob_k, int lda, int *locks, \
bool has_bias, bool use_atomic_add, bool use_fp32_reduce, \ bool has_bias, bool use_atomic_add, bool use_fp32_reduce, \
int max_shared_mem int max_shared_mem
namespace MARLIN_NAMESPACE_NAME { namespace MARLIN_NAMESPACE_NAME {
template <typename scalar_t, // compute dtype, half or nv_float16 template <const vllm::ScalarTypeId a_type_id, // A ScalarType id
const vllm::ScalarTypeId w_type_id, // weight ScalarType id const vllm::ScalarTypeId b_type_id, // B ScalarType id
const vllm::ScalarTypeId s_type_id, // weight ScalarType id const vllm::ScalarTypeId c_type_id, // C ScalarType id
const vllm::ScalarTypeId s_type_id, // B_SCALE ScalarType id
const int threads, // number of threads in a threadblock const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the // dimension (batchsize) of the

View File

@ -55,6 +55,45 @@ constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }
// No support for async // No support for async
#else #else
__device__ inline void cp_async1_ca_pred(void* smem_ptr, const void* glob_ptr,
bool pred = true) {
const int BYTES = 4;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %0, 0;\n"
" @p cp.async.ca.shared.global [%1], [%2], %3;\n"
"}\n" ::"r"((int)pred),
"r"(smem), "l"(glob_ptr), "n"(BYTES));
}
__device__ inline void cp_async2_ca_pred(void* smem_ptr, const void* glob_ptr,
bool pred = true) {
const int BYTES = 8;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %0, 0;\n"
" @p cp.async.ca.shared.global [%1], [%2], %3;\n"
"}\n" ::"r"((int)pred),
"r"(smem), "l"(glob_ptr), "n"(BYTES));
}
__device__ inline void cp_async4_ca_pred(void* smem_ptr, const void* glob_ptr,
bool pred = true) {
const int BYTES = 16;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %0, 0;\n"
" @p cp.async.ca.shared.global [%1], [%2], %3;\n"
"}\n" ::"r"((int)pred),
"r"(smem), "l"(glob_ptr), "n"(BYTES));
}
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
bool pred = true) { bool pred = true) {
const int BYTES = 16; const int BYTES = 16;

View File

@ -2,8 +2,10 @@
#ifndef _data_types_cuh #ifndef _data_types_cuh
#define _data_types_cuh #define _data_types_cuh
#include "marlin.cuh" #include "marlin.cuh"
#include "core/scalar_type.hpp"
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <cuda_bf16.h> #include <cuda_bf16.h>
#include <cuda_fp8.h>
#ifndef MARLIN_NAMESPACE_NAME #ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin #define MARLIN_NAMESPACE_NAME marlin
@ -11,14 +13,16 @@
namespace MARLIN_NAMESPACE_NAME { namespace MARLIN_NAMESPACE_NAME {
template <typename scalar_t> template <long scalar_type_id>
class ScalarType {}; class MarlinScalarType {};
template <> template <>
class ScalarType<half> { class MarlinScalarType<vllm::kFloat16.id()> {
public: public:
using scalar_t = half; using scalar_t = half;
using scalar_t2 = half2; using scalar_t2 = half2;
using scalar_t4 = half2;
using scalar_32bit_t = half2;
// Matrix fragments for tensor core instructions; their precise layout is // Matrix fragments for tensor core instructions; their precise layout is
// documented here: // documented here:
@ -27,6 +31,7 @@ class ScalarType<half> {
using FragB = Vec<half2, 2>; using FragB = Vec<half2, 2>;
using FragC = Vec<float, 4>; using FragC = Vec<float, 4>;
using FragS = Vec<half2, 1>; using FragS = Vec<half2, 1>;
using FragS0 = Vec<__nv_fp8x2_e4m3, 1>;
using FragZP = Vec<half2, 4>; using FragZP = Vec<half2, 4>;
static __device__ float inline num2float(const half x) { static __device__ float inline num2float(const half x) {
@ -44,18 +49,25 @@ class ScalarType<half> {
static __host__ __device__ half inline float2num(const float x) { static __host__ __device__ half inline float2num(const float x) {
return __float2half(x); return __float2half(x);
} }
static __host__ __device__ float2 inline num22float2(const half2 x) {
return __half22float2(x);
}
}; };
template <> template <>
class ScalarType<nv_bfloat16> { class MarlinScalarType<vllm::kBFloat16.id()> {
public: public:
using scalar_t = nv_bfloat16; using scalar_t = nv_bfloat16;
using scalar_t2 = nv_bfloat162; using scalar_t2 = nv_bfloat162;
using scalar_t4 = nv_bfloat162;
using scalar_32bit_t = nv_bfloat162;
using FragA = Vec<nv_bfloat162, 4>; using FragA = Vec<nv_bfloat162, 4>;
using FragB = Vec<nv_bfloat162, 2>; using FragB = Vec<nv_bfloat162, 2>;
using FragC = Vec<float, 4>; using FragC = Vec<float, 4>;
using FragS = Vec<nv_bfloat162, 1>; using FragS = Vec<nv_bfloat162, 1>;
using FragS0 = Vec<__nv_fp8x2_e4m3, 1>;
using FragZP = Vec<nv_bfloat162, 4>; using FragZP = Vec<nv_bfloat162, 4>;
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
@ -75,9 +87,63 @@ class ScalarType<nv_bfloat16> {
static __host__ __device__ nv_bfloat16 inline float2num(const float x) { static __host__ __device__ nv_bfloat16 inline float2num(const float x) {
return __float2bfloat16(x); return __float2bfloat16(x);
} }
static __host__ __device__ float2 inline num22float2(const nv_bfloat162 x) {
return __bfloat1622float2(x);
}
#endif #endif
}; };
template <>
class MarlinScalarType<vllm::kFE4M3fn.id()> {
public:
using scalar_t = __nv_fp8_e4m3;
using scalar_t2 = __nv_fp8x2_e4m3;
using scalar_t4 = __nv_fp8x4_e4m3;
using scalar_32bit_t = __nv_fp8x4_e4m3;
using FragA = Vec<__nv_fp8x4_e4m3, 4>;
using FragB = Vec<__nv_fp8x4_e4m3, 2>;
using FragC = Vec<float, 4>;
using FragZP = Vec<__nv_fp8x2_e4m3, 4>;
static __host__ __device__
float2 inline num22float2(const __nv_fp8x2_e4m3 x) {
return (float2)x;
}
};
template <>
class MarlinScalarType<vllm::kS8.id()> {
public:
using scalar_t = int8_t;
using scalar_t2 = int16_t;
using scalar_t4 = int32_t;
using scalar_32bit_t = int32_t;
using FragA = Vec<int32_t, 4>;
using FragB = Vec<int32_t, 2>;
using FragC = Vec<float, 4>;
using FragZP = Vec<int16_t, 4>;
};
template <typename scalar_t>
class MarlinScalarType2 {};
template <>
class MarlinScalarType2<half> : public MarlinScalarType<vllm::kFloat16.id()> {};
template <>
class MarlinScalarType2<nv_bfloat16>
: public MarlinScalarType<vllm::kBFloat16.id()> {};
template <>
class MarlinScalarType2<__nv_fp8_e4m3>
: public MarlinScalarType<vllm::kFE4M3fn.id()> {};
template <>
class MarlinScalarType2<int8_t> : public MarlinScalarType<vllm::kS8.id()> {};
} // namespace MARLIN_NAMESPACE_NAME } // namespace MARLIN_NAMESPACE_NAME
#endif #endif

View File

@ -0,0 +1,106 @@
#include "marlin.cuh"
#include "core/registration.h"
// for only non-zp format (like gptq)
__global__ void marlin_int4_fp8_preprocess_kernel_without_zp(
// qweight: (size_k * size_n // 8,)
const int32_t* __restrict__ qweight,
// output: same shape with qweight
int32_t* __restrict__ output) {
int32_t val = qweight[blockIdx.x * 32 + threadIdx.x];
int32_t new_val = 0;
#pragma unroll
for (int32_t i = 0; i < 8; i++) {
int32_t single_val = val & 0xF;
single_val = single_val >= 8 ? single_val - 8 : 15 - single_val;
new_val |= single_val << (i * 4);
val >>= 4;
}
output[blockIdx.x * 32 + threadIdx.x] = new_val;
}
// for awq format only (with zp and with awq weight layout)
__global__ void marlin_int4_fp8_preprocess_kernel_awq(
// AWQ qweight: (size_k, size_n // 8)
const int32_t* __restrict__ qweight,
// output: same shape with qweight
int32_t* __restrict__ output,
// AWQ zeros: (size_k // group_size, size_n // 8)
const int32_t* __restrict__ qzeros, int32_t size_n, int32_t size_k,
int32_t group_size) {
int32_t val =
qweight[(blockIdx.x * 32 + threadIdx.x) * size_n / 8 + blockIdx.y];
int32_t zero =
qzeros[(blockIdx.x * 32 + threadIdx.x) / group_size * size_n / 8 +
blockIdx.y];
int32_t new_val = 0;
#pragma unroll
for (int32_t i = 0; i < 8; i++) {
int32_t single_val = val & 0xF;
int32_t single_zero = zero & 0xF;
single_val =
single_val >= single_zero ? single_val - single_zero : 15 - single_val;
new_val |= single_val << (i * 4);
val >>= 4;
zero >>= 4;
}
output[(blockIdx.x * 32 + threadIdx.x) * size_n / 8 + blockIdx.y] = new_val;
}
torch::Tensor marlin_int4_fp8_preprocess(
torch::Tensor& qweight, std::optional<torch::Tensor> qzeros_or_none,
bool inplace) {
TORCH_CHECK(qweight.device().is_cuda(), "qweight is not on GPU");
TORCH_CHECK(qweight.scalar_type() == at::ScalarType::Int,
"qweight.dtype != torch.int32");
const at::cuda::OptionalCUDAGuard device_guard(device_of(qweight));
torch::Tensor output = inplace ? qweight : torch::empty_like(qweight);
if (!qzeros_or_none.has_value()) {
TORCH_CHECK(qweight.numel() * 8 % 256 == 0,
"qweight.numel() * 8 % 256 != 0");
int blocks = qweight.numel() * 8 / 256;
marlin_int4_fp8_preprocess_kernel_without_zp<<<blocks, 32>>>(
(const int32_t*)qweight.data_ptr(), (int32_t*)output.data_ptr());
} else {
int32_t size_k = qweight.size(0);
int32_t size_n = qweight.size(1) * 8;
torch::Tensor qzeros = qzeros_or_none.value();
TORCH_CHECK(size_k % 32 == 0, "size_k % 32 != 0");
TORCH_CHECK(qzeros.device().is_cuda(), "qzeros is not on GPU");
TORCH_CHECK(qzeros.scalar_type() == at::ScalarType::Int,
"qweight.dtype != torch.int32");
TORCH_CHECK(device_of(qweight) == device_of(qzeros),
"qzeros is not on the same device with qweight");
int32_t group_size = qweight.size(0) / qzeros.size(0);
TORCH_CHECK(qweight.size(1) == qzeros.size(1),
"qweight.size(1) != qzeros.size(1)");
TORCH_CHECK(qweight.size(0) % qzeros.size(0) == 0,
"qweight.size(0) % qzeros.size(0) != 0");
TORCH_CHECK(group_size % 8 == 0, "group_size % 8 != 0");
dim3 blocks(size_k / 32, size_n / 8);
marlin_int4_fp8_preprocess_kernel_awq<<<blocks, 32>>>(
(const int32_t*)qweight.data_ptr(), (int32_t*)output.data_ptr(),
(const int32_t*)qzeros.data_ptr(), size_n, size_k, group_size);
}
return output;
}
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("marlin_int4_fp8_preprocess", &marlin_int4_fp8_preprocess);
}

File diff suppressed because it is too large Load Diff

View File

@ -298,9 +298,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// gptq_marlin Optimized Quantized GEMM for GPTQ. // gptq_marlin Optimized Quantized GEMM for GPTQ.
ops.def( ops.def(
"gptq_marlin_gemm(Tensor a, Tensor? c_or_none, Tensor b_q_weight, " "gptq_marlin_gemm(Tensor a, Tensor? c_or_none, Tensor b_q_weight, "
"Tensor? b_bias_or_none," "Tensor? b_bias_or_none,Tensor b_scales, "
"Tensor b_scales, Tensor? global_scale, Tensor? b_zeros_or_none, Tensor? " "Tensor? a_scales, Tensor? global_scale, Tensor? b_zeros_or_none, "
"g_idx_or_none, Tensor? perm_or_none, Tensor workspace, int b_q_type, " "Tensor? "
"g_idx_or_none, Tensor? perm_or_none, Tensor workspace, int b_type_id, "
"SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, " "SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, "
"bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) -> Tensor"); "bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) -> Tensor");
// conditionally compiled so impl registration is in source file // conditionally compiled so impl registration is in source file
@ -308,13 +309,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// gptq_marlin repack from GPTQ. // gptq_marlin repack from GPTQ.
ops.def( ops.def(
"gptq_marlin_repack(Tensor b_q_weight, Tensor perm, " "gptq_marlin_repack(Tensor b_q_weight, Tensor perm, "
"SymInt size_k, SymInt size_n, int num_bits) -> Tensor"); "SymInt size_k, SymInt size_n, int num_bits, bool is_a_8bit) -> Tensor");
// conditionally compiled so impl registrations are in source file // conditionally compiled so impl registrations are in source file
// awq_marlin repack from AWQ. // awq_marlin repack from AWQ.
ops.def( ops.def(
"awq_marlin_repack(Tensor b_q_weight, SymInt size_k, " "awq_marlin_repack(Tensor b_q_weight, SymInt size_k, "
"SymInt size_n, int num_bits) -> Tensor"); "SymInt size_n, int num_bits, bool is_a_8bit) -> Tensor");
// conditionally compiled so impl registrations are in source file
// preprocess W-int4A-fp8 weight for marlin kernel
ops.def(
"marlin_int4_fp8_preprocess(Tensor qweight, "
"Tensor? qzeros_or_none, bool inplace) -> Tensor");
// conditionally compiled so impl registrations are in source file // conditionally compiled so impl registrations are in source file
// CUTLASS w4a8 GEMM // CUTLASS w4a8 GEMM

View File

@ -60,7 +60,7 @@ Modular kernels are supported by the following `FusedMoEMethodBase` classes.
- [`ModelOptFp8MoEMethod`][vllm.model_executor.layers.quantization.modelopt.ModelOptFp8MoEMethod] - [`ModelOptFp8MoEMethod`][vllm.model_executor.layers.quantization.modelopt.ModelOptFp8MoEMethod]
- [`Fp8MoEMethod`][vllm.model_executor.layers.quantization.fp8.Fp8MoEMethod] - [`Fp8MoEMethod`][vllm.model_executor.layers.quantization.fp8.Fp8MoEMethod]
- [`CompressedTensorsW4A4Nvfp4MoeMethod`][vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.CompressedTensorsW4A4Nvfp4MoeMethod] - [`CompressedTensorsW4A4Nvfp4MoEMethod`][vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.CompressedTensorsW4A4Nvfp4MoEMethod]
- [`CompressedTensorsW8A8Fp8MoEMethod`][vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.CompressedTensorsW8A8Fp8MoEMethod] - [`CompressedTensorsW8A8Fp8MoEMethod`][vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.CompressedTensorsW8A8Fp8MoEMethod]
- [`Mxfp4MoEMethod`][vllm.model_executor.layers.quantization.mxfp4.Mxfp4MoEMethod] - [`Mxfp4MoEMethod`][vllm.model_executor.layers.quantization.mxfp4.Mxfp4MoEMethod]
- [`UnquantizedFusedMoEMethod`][vllm.model_executor.layers.fused_moe.layer.UnquantizedFusedMoEMethod] - [`UnquantizedFusedMoEMethod`][vllm.model_executor.layers.fused_moe.layer.UnquantizedFusedMoEMethod]

View File

@ -21,7 +21,7 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
import vllm.model_executor.layers.fused_moe # noqa import vllm.model_executor.layers.fused_moe # noqa
from tests.kernels.moe.utils import fused_moe from tests.kernels.moe.utils import fused_moe
from tests.kernels.utils import opcheck, stack_and_dev, torch_moe from tests.kernels.utils import opcheck, stack_and_dev, torch_experts, torch_moe
from vllm._aiter_ops import rocm_aiter_ops from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import VllmConfig, set_current_vllm_config from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed.parallel_state import init_distributed_environment from vllm.distributed.parallel_state import init_distributed_environment
@ -65,6 +65,64 @@ NUM_EXPERTS = [8, 64, 192]
EP_SIZE = [1, 4] EP_SIZE = [1, 4]
TOP_KS = [2, 6] TOP_KS = [2, 6]
MOE_MARLIN_QUANT_TEST_CONFIGS = [
# AWQ-INT4
{"b_type": scalar_types.uint4, "group_blocks": [-1, 2, 4, 8]},
# GPTQ-INT4
{
"b_type": scalar_types.uint4b8,
"support_act_order": True,
"group_blocks": [-1, 2, 4, 8],
},
# GPTQ-INT8
{
"b_type": scalar_types.uint8b128,
"support_act_order": True,
"group_blocks": [-1, 2, 4, 8],
},
# FP8
{"b_type": scalar_types.float8_e4m3fn, "group_blocks": [-1, 8]},
# NVFP4
{"b_type": scalar_types.float4_e2m1f, "group_blocks": [1]},
# MXFP4
{
"a_type": [scalar_types.bfloat16],
"b_type": scalar_types.float4_e2m1f,
"group_blocks": [2],
},
# AWQ-INT4 with INT8 activation
{
"a_type": [scalar_types.int8],
"b_type": scalar_types.uint4,
"group_blocks": [-1, 2, 4, 8],
},
# GPTQ-INT4 with INT8 activation
{
"a_type": [scalar_types.int8],
"b_type": scalar_types.uint4b8,
"group_blocks": [-1, 2, 4, 8],
},
# GPTQ-INT4 with FP8 activation
{
"a_type": [scalar_types.float8_e4m3fn],
"b_type": scalar_types.uint4b8,
"group_blocks": [-1, 2, 4, 8],
},
# AWQ-INT4 with FP8 activation
{
"a_type": [scalar_types.float8_e4m3fn],
"b_type": scalar_types.uint4,
"group_blocks": [-1, 2, 4, 8],
},
# MXFP4 with FP8 activation
{
"a_type": [scalar_types.float8_e4m3fn],
"b_type": scalar_types.float4_e2m1f,
"c_type": [scalar_types.bfloat16],
"group_blocks": [2],
},
]
FUSED_MOE_MNK_FACTORS = [ FUSED_MOE_MNK_FACTORS = [
(1, 128, 128), (1, 128, 128),
(1, 2048, 128), (1, 2048, 128),
@ -505,63 +563,74 @@ def marlin_moe_generate_valid_test_cases():
m_list = [1, 123, 666] m_list = [1, 123, 666]
n_list = [128, 1024] n_list = [128, 1024]
k_list = [256, 2048] k_list = [256, 2048]
e_list = [4, 12] e_list = [5, 12]
topk_list = [2, 3] topk_list = [2, 3]
ep_size_list = [1, 4] ep_size_list = [1, 4]
dtype_list = [torch.bfloat16]
group_size_list = [-1, 32, 128]
act_order_list = [True, False] act_order_list = [True, False]
quant_type_list = [
scalar_types.float4_e2m1f,
scalar_types.float8_e4m3fn,
scalar_types.uint4,
scalar_types.uint4b8,
scalar_types.uint8b128,
]
is_k_full_list = [True, False] is_k_full_list = [True, False]
all_combinations = itertools.product( all_combinations = itertools.product(
MOE_MARLIN_QUANT_TEST_CONFIGS,
m_list, m_list,
n_list, n_list,
k_list, k_list,
e_list, e_list,
topk_list, topk_list,
ep_size_list, ep_size_list,
dtype_list,
group_size_list,
act_order_list, act_order_list,
quant_type_list,
is_k_full_list, is_k_full_list,
) )
def is_invalid( def is_invalid(
m, n, k, e, topk, ep_size, dtype, group_size, act_order, quant_type, is_k_full a_type,
b_type,
c_type,
group_blocks,
m,
n,
k,
e,
topk,
ep_size,
act_order,
is_k_full,
): ):
if quant_type == scalar_types.float8_e4m3fn and group_size not in [-1, 128]: group_size = group_blocks if group_blocks <= 0 else group_blocks * 16
return False if group_size > 0 and k % group_size != 0:
if quant_type == scalar_types.float4_e2m1f:
if group_size not in [16, 32]:
return False
if dtype == torch.float16 and group_size == 32:
return False
if quant_type != scalar_types.float4_e2m1f and group_size == 16:
return False return False
# Filter act_order if act_order and group_size in [-1, k, n]:
if act_order: return False
if group_size in (-1, k, n): if group_size in [k, n]:
return False return False
if quant_type not in [scalar_types.uint4b8]: if not act_order and is_k_full:
return False
elif not is_k_full:
return False return False
return True return a_type.size_bits < 16 or a_type is c_type
cases = [] cases = []
for case in all_combinations: for case in all_combinations:
if is_invalid(*case): quant_test_config, m, n, k, _, _, _, act_order, *_ = case
cases.append(case) if act_order and not quant_test_config.get("support_act_order", False):
continue
f16_types = [scalar_types.float16]
inner_combinations = itertools.product(
quant_test_config.get("a_type", f16_types),
[quant_test_config["b_type"]],
quant_test_config.get("c_type", f16_types),
quant_test_config["group_blocks"],
)
for sub_case in inner_combinations:
if (
sub_case[0] == scalar_types.float8_e4m3fn
and current_platform.get_device_capability() not in [89, 120]
):
continue
args = sub_case + (m, n, k) + case[4:]
if is_invalid(*args):
cases.append(args)
return cases return cases
@ -571,6 +640,7 @@ class MarlinMoEWeightData:
qweight: torch.Tensor qweight: torch.Tensor
scales: torch.Tensor scales: torch.Tensor
global_scale: torch.Tensor | None global_scale: torch.Tensor | None
a_scales_factor: torch.Tensor | None
g_idx: torch.Tensor | None g_idx: torch.Tensor | None
zeros: torch.Tensor | None zeros: torch.Tensor | None
sort_indices: torch.Tensor | None sort_indices: torch.Tensor | None
@ -583,11 +653,20 @@ class MarlinMoEWeightData:
group_size: int, group_size: int,
act_order: bool | None = None, act_order: bool | None = None,
bias: torch.Tensor | None = None, bias: torch.Tensor | None = None,
input_type: ScalarType = None,
) -> "MarlinMoEWeightData": ) -> "MarlinMoEWeightData":
assert w.ndim == 3 assert w.ndim == 3
has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8] has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
k = w.shape[-1] k = w.shape[-1]
if input_type == scalar_types.int8:
input_dtype = torch.int8
elif input_type == scalar_types.float8_e4m3fn:
input_dtype = torch.float8_e4m3fn
else:
input_dtype = w.dtype
w_ref_l: list[torch.Tensor] = [] w_ref_l: list[torch.Tensor] = []
qweight_l: list[torch.Tensor] = [] qweight_l: list[torch.Tensor] = []
scales_l: list[torch.Tensor] = [] scales_l: list[torch.Tensor] = []
@ -601,11 +680,13 @@ class MarlinMoEWeightData:
if quant_type == scalar_types.float4_e2m1f: if quant_type == scalar_types.float4_e2m1f:
if group_size == 16: if group_size == 16:
w_ref, qweight, scales, global_scale = ( w_ref, qweight, scales, global_scale = (
rand_marlin_weight_nvfp4_like(w[i], group_size) rand_marlin_weight_nvfp4_like(
w[i], group_size, input_dtype=input_dtype
)
) )
else: else:
w_ref, qweight, scales = rand_marlin_weight_mxfp4_like( w_ref, qweight, scales = rand_marlin_weight_mxfp4_like(
w[i], group_size w[i], group_size, input_dtype=input_dtype
) )
global_scale = None global_scale = None
@ -615,13 +696,18 @@ class MarlinMoEWeightData:
if global_scale is not None: if global_scale is not None:
global_scale_l.append(global_scale) global_scale_l.append(global_scale)
elif quant_type == scalar_types.float8_e4m3fn: elif quant_type == scalar_types.float8_e4m3fn:
w_ref, qweight, scales = marlin_quant_fp8_torch(w[i], group_size) w_ref, qweight, scales = marlin_quant_fp8_torch(
w[i], group_size, input_dtype=input_dtype
)
w_ref_l.append(w_ref.T) w_ref_l.append(w_ref.T)
qweight_l.append(qweight) qweight_l.append(qweight)
scales_l.append(scales) scales_l.append(scales)
elif has_zp: elif has_zp:
w_ref, qweight, scales, zeros = awq_marlin_quantize( w_ref, qweight, scales, zeros = awq_marlin_quantize(
w[i].transpose(1, 0), quant_type, group_size w[i].transpose(1, 0),
quant_type,
group_size,
input_dtype=input_dtype,
) )
w_ref_l.append(w_ref.T) w_ref_l.append(w_ref.T)
@ -631,7 +717,12 @@ class MarlinMoEWeightData:
else: else:
test_perm = torch.randperm(k) test_perm = torch.randperm(k)
w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize( w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize(
w[i].transpose(1, 0), quant_type, group_size, act_order, test_perm w[i].transpose(1, 0),
quant_type,
group_size,
act_order,
test_perm,
input_dtype=input_dtype,
) )
w_ref_l.append(w_ref.T) w_ref_l.append(w_ref.T)
@ -652,11 +743,18 @@ class MarlinMoEWeightData:
sort_indices = stack_and_dev(sort_indices_l) if sort_indices_l else None sort_indices = stack_and_dev(sort_indices_l) if sort_indices_l else None
marlin_bias = stack_and_dev(bias_l) if bias_l else None marlin_bias = stack_and_dev(bias_l) if bias_l else None
a_scales_factor = None
if input_type == scalar_types.int8 and group_size != -1:
a_scales_factor = 1 / 4096 * scales.max().float()
scales = scales / scales.max() * 4096
scales = scales.round().to(torch.int16).view(w.dtype)
return MarlinMoEWeightData( return MarlinMoEWeightData(
w_ref=w_ref, w_ref=w_ref,
qweight=qweight, qweight=qweight,
scales=scales, scales=scales,
global_scale=global_scale, global_scale=global_scale,
a_scales_factor=a_scales_factor,
g_idx=g_idx, g_idx=g_idx,
zeros=zeros, zeros=zeros,
sort_indices=sort_indices, sort_indices=sort_indices,
@ -666,28 +764,47 @@ class MarlinMoEWeightData:
@pytest.mark.flaky(reruns=2) @pytest.mark.flaky(reruns=2)
@pytest.mark.parametrize( @pytest.mark.parametrize(
("m, n, k, e, topk, ep_size, dtype, group_size,act_order, quant_type, is_k_full"), (
"a_type, b_type, c_type, group_blocks,"
"m, n, k, e, topk, ep_size, act_order, is_k_full"
),
marlin_moe_generate_valid_test_cases(), marlin_moe_generate_valid_test_cases(),
) )
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm") @pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_fused_marlin_moe( def test_fused_marlin_moe(
m: int, a_type,
n: int, b_type,
k: int, c_type,
e: int, group_blocks,
topk: int, m,
ep_size: int, n,
dtype: torch.dtype, k,
group_size: int, e,
act_order: bool, topk,
quant_type: ScalarType, ep_size,
is_k_full: bool, act_order,
is_k_full,
): ):
torch.cuda.manual_seed(0) torch.cuda.manual_seed(1)
group_size = group_blocks if group_blocks <= 0 else group_blocks * 16
if c_type == scalar_types.float16:
dtype = torch.float16
elif c_type == scalar_types.bfloat16:
dtype = torch.bfloat16
else:
raise RuntimeError("unsupported c_type")
if a_type == scalar_types.int8:
a_dtype = torch.int8
elif a_type == scalar_types.float8_e4m3fn:
a_dtype = torch.float8_e4m3fn
else:
a_dtype = dtype
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 20 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 20 w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
if ep_size > 1: if ep_size > 1:
local_e = e // ep_size local_e = e // ep_size
@ -700,11 +817,19 @@ def test_fused_marlin_moe(
e_map = None e_map = None
w1_data = MarlinMoEWeightData.make( w1_data = MarlinMoEWeightData.make(
w=w1, quant_type=quant_type, group_size=group_size, act_order=act_order w=w1,
quant_type=b_type,
group_size=group_size,
act_order=act_order,
input_type=a_type,
) )
w2_data = MarlinMoEWeightData.make( w2_data = MarlinMoEWeightData.make(
w=w2, quant_type=quant_type, group_size=group_size, act_order=act_order w=w2,
quant_type=b_type,
group_size=group_size,
act_order=act_order,
input_type=a_type,
) )
score = torch.randn((m, e), device="cuda", dtype=dtype) score = torch.randn((m, e), device="cuda", dtype=dtype)
@ -712,8 +837,18 @@ def test_fused_marlin_moe(
topk_weights, topk_ids, _ = fused_topk(a, score, topk, False) topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
torch_output = torch_moe( score = torch.softmax(score, dim=-1, dtype=torch.float32)
a, w1_data.w_ref, w2_data.w_ref, score, topk, expert_map=e_map topk_weight, topk_ids = torch.topk(score, topk)
torch_output = torch_experts(
a,
w1_data.w_ref,
w2_data.w_ref,
topk_weight=topk_weight,
topk_ids=topk_ids,
global_num_experts=e,
expert_map=e_map,
quant_dtype=a_dtype,
per_act_token_quant=True,
) )
marlin_output = fused_marlin_moe( marlin_output = fused_marlin_moe(
@ -733,15 +868,18 @@ def test_fused_marlin_moe(
global_scale2=w2_data.global_scale, global_scale2=w2_data.global_scale,
g_idx1=w1_data.g_idx, g_idx1=w1_data.g_idx,
g_idx2=w2_data.g_idx, g_idx2=w2_data.g_idx,
input_global_scale1=w1_data.a_scales_factor,
input_global_scale2=w2_data.a_scales_factor,
sort_indices1=w1_data.sort_indices, sort_indices1=w1_data.sort_indices,
sort_indices2=w2_data.sort_indices, sort_indices2=w2_data.sort_indices,
w1_zeros=w1_data.zeros, w1_zeros=w1_data.zeros,
w2_zeros=w2_data.zeros, w2_zeros=w2_data.zeros,
quant_type_id=quant_type.id, input_dtype=a_dtype,
quant_type_id=b_type.id,
is_k_full=is_k_full, is_k_full=is_k_full,
) )
torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0) torch.testing.assert_close(marlin_output, torch_output, atol=4e-2, rtol=0)
@pytest.mark.flaky(reruns=2) @pytest.mark.flaky(reruns=2)

View File

@ -5,6 +5,8 @@
Run `pytest tests/kernels/quantization/test_marlin_gemm.py`. Run `pytest tests/kernels/quantization/test_marlin_gemm.py`.
""" """
import itertools
import pytest import pytest
import torch import torch
@ -17,8 +19,10 @@ from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES,
GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES, GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES,
) )
from vllm.model_executor.layers.quantization.utils.int8_utils import (
per_token_quant_int8,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import ( from vllm.model_executor.layers.quantization.utils.marlin_utils import (
MARLIN_SUPPORTED_GROUP_SIZES,
marlin_make_empty_g_idx, marlin_make_empty_g_idx,
marlin_make_workspace_new, marlin_make_workspace_new,
marlin_permute_bias, marlin_permute_bias,
@ -26,7 +30,6 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
query_marlin_supported_quant_types, query_marlin_supported_quant_types,
) )
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
FP4_MARLIN_SUPPORTED_GROUP_SIZES,
rand_marlin_weight_mxfp4_like, rand_marlin_weight_mxfp4_like,
rand_marlin_weight_nvfp4_like, rand_marlin_weight_nvfp4_like,
) )
@ -50,6 +53,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
quantize_weights, quantize_weights,
sort_weights, sort_weights,
) )
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types from vllm.scalar_type import scalar_types
ACT_ORDER_OPTS = [False, True] ACT_ORDER_OPTS = [False, True]
@ -65,6 +69,12 @@ MARLIN_24_N_CHUNKS = [512]
HQQ_SUPPORTED_GROUP_SIZES = [64] HQQ_SUPPORTED_GROUP_SIZES = [64]
MARLIN_REPACK_NK_FACTORS = [
(4, 8),
(7, 5),
(13, 11),
]
MNK_FACTORS = [ MNK_FACTORS = [
(1, 1, 1), (1, 1, 1),
(1, 4, 8), (1, 4, 8),
@ -74,6 +84,64 @@ MNK_FACTORS = [
DTYPES = [torch.float16, torch.bfloat16] DTYPES = [torch.float16, torch.bfloat16]
DENSE_MARLIN_QUANT_TEST_CONFIGS = [
# AWQ-INT4
{"b_type": scalar_types.uint4, "group_blocks": [-1, 2, 4, 8]},
# GPTQ-INT4
{
"b_type": scalar_types.uint4b8,
"support_act_order": True,
"group_blocks": [-1, 2, 4, 8],
},
# GPTQ-INT8
{
"b_type": scalar_types.uint8b128,
"support_act_order": True,
"group_blocks": [-1, 2, 4, 8],
},
# FP8
{"b_type": scalar_types.float8_e4m3fn, "group_blocks": [-1, 8]},
# NVFP4
{"b_type": scalar_types.float4_e2m1f, "group_blocks": [1]},
# MXFP4
{
"a_type": [scalar_types.bfloat16],
"b_type": scalar_types.float4_e2m1f,
"group_blocks": [2],
},
# AWQ-INT4 with INT8 activation
{
"a_type": [scalar_types.int8],
"b_type": scalar_types.uint4,
"group_blocks": [-1, 2, 4, 8],
},
# GPTQ-INT4 with INT8 activation
{
"a_type": [scalar_types.int8],
"b_type": scalar_types.uint4b8,
"group_blocks": [-1, 2, 4, 8],
},
# GPTQ-INT4 with FP8 activation
{
"a_type": [scalar_types.float8_e4m3fn],
"b_type": scalar_types.uint4b8,
"group_blocks": [-1, 2, 4, 8],
},
# AWQ-INT4 with FP8 activation
{
"a_type": [scalar_types.float8_e4m3fn],
"b_type": scalar_types.uint4,
"group_blocks": [-1, 2, 4, 8],
},
# MXFP4 with FP8 activation
{
"a_type": [scalar_types.float8_e4m3fn],
"b_type": scalar_types.float4_e2m1f,
"c_type": [scalar_types.bfloat16],
"group_blocks": [2],
},
]
def compute_max_diff(output, output_ref): def compute_max_diff(output, output_ref):
return torch.mean(torch.abs(output - output_ref)) / torch.mean( return torch.mean(torch.abs(output - output_ref)) / torch.mean(
@ -85,6 +153,58 @@ def rand_data(shape, dtype=torch.float16):
return torch.randn(shape, dtype=dtype, device="cuda") return torch.randn(shape, dtype=dtype, device="cuda")
@pytest.mark.skipif(
not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.",
)
def test_marlin_int4_fp8_preprocess_without_zp():
qweight_unpacked = torch.randint(
0, 16, size=(2048, 2048), dtype=torch.int32, device="cuda"
)
qweight_packed = qweight_unpacked[:, ::2] * 16 + qweight_unpacked[:, 1::2]
qweight_packed = qweight_packed.to(torch.int8).view(torch.int32)
cuda_res = ops.marlin_int4_fp8_preprocess(qweight_packed)
torch_res = torch.where(
qweight_unpacked >= 8, qweight_unpacked - 8, 15 - qweight_unpacked
)
torch_res = torch_res[:, ::2] * 16 + torch_res[:, 1::2]
torch_res = torch_res.to(torch.int8).view(torch.int32)
assert (cuda_res == torch_res).all()
@pytest.mark.skipif(
not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.",
)
def test_marlin_int4_fp8_preprocess_awq():
group_size = 128
qweight_unpacked = torch.randint(
0, 16, size=(2048, 2048), dtype=torch.int32, device="cuda"
)
qzeros_unpacked = torch.randint(
0, 16, size=(2048 // group_size, 2048), dtype=torch.int32, device="cuda"
)
qweight_packed = qweight_unpacked[:, ::2] * 16 + qweight_unpacked[:, 1::2]
qweight_packed = qweight_packed.to(torch.int8).view(torch.int32)
qzeros_packed = qzeros_unpacked[:, ::2] * 16 + qzeros_unpacked[:, 1::2]
qzeros_packed = qzeros_packed.to(torch.int8).view(torch.int32)
cuda_res = ops.marlin_int4_fp8_preprocess(qweight_packed, qzeros_packed)
repeated_zp = qzeros_unpacked.repeat_interleave(group_size, 0)
torch_res = qweight_unpacked - repeated_zp
torch_res[torch_res < 0] = 15 - qweight_unpacked[torch_res < 0]
torch_res = torch_res[:, ::2] * 16 + torch_res[:, 1::2]
torch_res = torch_res.to(torch.int8).view(torch.int32)
assert (cuda_res == torch_res).all()
@pytest.mark.skipif( @pytest.mark.skipif(
not is_quant_method_supported("gptq_marlin"), not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.", reason="Marlin is not supported on this GPU type.",
@ -92,16 +212,17 @@ def rand_data(shape, dtype=torch.float16):
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) @pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types(False, False)) @pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types(False, False))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS) @pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS) @pytest.mark.parametrize("is_a_8bit", [True, False])
@pytest.mark.parametrize("nk_factors", MARLIN_REPACK_NK_FACTORS)
def test_gptq_marlin_repack( def test_gptq_marlin_repack(
k_chunk, n_chunk, quant_type, group_size, act_order, mnk_factors k_chunk, n_chunk, quant_type, act_order, is_a_8bit, nk_factors
): ):
m_factor, n_factor, k_factor = mnk_factors n_factor, k_factor = nk_factors
size_k = k_chunk * k_factor size_k = k_chunk * k_factor
size_n = n_chunk * n_factor size_n = n_chunk * n_factor
group_size = 128
# Filter act_order # Filter act_order
if act_order: if act_order:
@ -109,6 +230,8 @@ def test_gptq_marlin_repack(
return return
if group_size == size_k: if group_size == size_k:
return return
if is_a_8bit:
return
# Normalize group_size # Normalize group_size
if group_size == -1: if group_size == -1:
@ -133,23 +256,19 @@ def test_gptq_marlin_repack(
q_w, g_idx, sort_indices = sort_weights(q_w, g_idx) q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)
# Pack to Marlin format # Pack to Marlin format
weight_perm = get_weight_perm(quant_type.size_bits) weight_perm = get_weight_perm(quant_type.size_bits, is_a_8bit)
marlin_q_w_1 = marlin_weights( marlin_q_w_1 = marlin_weights(
q_w, size_k, size_n, quant_type.size_bits, weight_perm q_w, size_k, size_n, quant_type.size_bits, weight_perm, is_a_8bit
) )
opcheck( opcheck(
torch.ops._C.gptq_marlin_repack, torch.ops._C.gptq_marlin_repack,
(q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits), (q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits, is_a_8bit),
) )
# Run Marlin repack GPU kernel # Run Marlin repack GPU kernel
marlin_q_w_2 = ops.gptq_marlin_repack( marlin_q_w_2 = ops.gptq_marlin_repack(
q_w_gptq, q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits, is_a_8bit
sort_indices,
size_k,
size_n,
quant_type.size_bits,
) )
torch.cuda.synchronize() torch.cuda.synchronize()
@ -163,18 +282,15 @@ def test_gptq_marlin_repack(
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) @pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types(True)) @pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types(True))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES) @pytest.mark.parametrize("is_a_8bit", [True, False])
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS) @pytest.mark.parametrize("nk_factors", MARLIN_REPACK_NK_FACTORS)
def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, mnk_factors): def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, is_a_8bit, nk_factors):
m_factor, n_factor, k_factor = mnk_factors n_factor, k_factor = nk_factors
size_k = k_chunk * k_factor size_k = k_chunk * k_factor
size_n = n_chunk * n_factor size_n = n_chunk * n_factor
# Normalize group_size group_size = 128
if group_size == -1:
group_size = size_k
assert group_size <= size_k
# Create input # Create input
b_weight = rand_data((size_k, size_n)) b_weight = rand_data((size_k, size_n))
@ -188,162 +304,221 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, mnk_factors
q_w_awq = awq_pack(q_w, quant_type.size_bits, size_k, size_n) q_w_awq = awq_pack(q_w, quant_type.size_bits, size_k, size_n)
# Pack to Marlin format # Pack to Marlin format
weight_perm = get_weight_perm(quant_type.size_bits) weight_perm = get_weight_perm(quant_type.size_bits, is_a_8bit)
marlin_q_w_1 = marlin_weights( marlin_q_w_1 = marlin_weights(
q_w, size_k, size_n, quant_type.size_bits, weight_perm q_w, size_k, size_n, quant_type.size_bits, weight_perm, is_a_8bit
) )
opcheck( opcheck(
torch.ops._C.awq_marlin_repack, (q_w_awq, size_k, size_n, quant_type.size_bits) torch.ops._C.awq_marlin_repack,
(q_w_awq, size_k, size_n, quant_type.size_bits, is_a_8bit),
) )
# Run Marlin repack GPU kernel # Run Marlin repack GPU kernel
marlin_q_w_2 = ops.awq_marlin_repack( marlin_q_w_2 = ops.awq_marlin_repack(
q_w_awq, q_w_awq, size_k, size_n, quant_type.size_bits, is_a_8bit
size_k,
size_n,
quant_type.size_bits,
) )
torch.cuda.synchronize() torch.cuda.synchronize()
torch.testing.assert_close(marlin_q_w_1, marlin_q_w_2) torch.testing.assert_close(marlin_q_w_1, marlin_q_w_2)
def marlin_generate_valid_test_cases():
all_combinations = itertools.product(
DENSE_MARLIN_QUANT_TEST_CONFIGS,
MNK_FACTORS,
MARLIN_N_CHUNKS,
MARLIN_K_CHUNKS,
ACT_ORDER_OPTS,
K_FULL_OPTS,
USE_ATOMIC_ADD_OPTS,
USE_FP32_REDUCE_OPTS,
)
def is_invalid(
a_type,
b_type,
c_type,
group_blocks,
size_m,
size_n,
size_k,
act_order,
is_k_full,
use_atomic_add,
use_fp32_reduce,
):
if use_atomic_add:
if use_fp32_reduce:
return False
if (
c_type == scalar_types.bfloat16
and torch.cuda.get_device_capability()[0] < 9
):
return False
group_size = group_blocks if group_blocks <= 0 else group_blocks * 16
if group_size > 0 and size_k % group_size != 0:
return False
if act_order and group_size in [-1, size_k]:
return False
if group_size == size_k:
return False
if not act_order and is_k_full:
return False
return a_type.size_bits < 16 or a_type is c_type
cases = []
for case in all_combinations:
quant_test_config, mnk_factors, n_chunk, k_chunk, act_order, *_ = case
size_m = mnk_factors[0]
size_n = mnk_factors[1] * n_chunk
size_k = mnk_factors[2] * k_chunk
if act_order and not quant_test_config.get("support_act_order", False):
continue
f16_types = [scalar_types.float16, scalar_types.bfloat16]
inner_combinations = itertools.product(
quant_test_config.get("a_type", f16_types),
[quant_test_config["b_type"]],
quant_test_config.get("c_type", f16_types),
quant_test_config["group_blocks"],
)
for sub_case in inner_combinations:
if (
sub_case[0] == scalar_types.float8_e4m3fn
and current_platform.get_device_capability() not in [89, 120]
):
continue
args = sub_case + (size_m, size_n, size_k) + case[4:]
if is_invalid(*args):
cases.append(args)
return cases
@pytest.mark.skipif( @pytest.mark.skipif(
not is_quant_method_supported("gptq_marlin"), not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.", reason="Marlin is not supported on this GPU type.",
) )
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types())
@pytest.mark.parametrize( @pytest.mark.parametrize(
"group_size", set(MARLIN_SUPPORTED_GROUP_SIZES + FP4_MARLIN_SUPPORTED_GROUP_SIZES) (
"a_type, b_type, c_type, group_blocks,"
"size_m, size_n, size_k, act_order, is_k_full,"
"use_atomic_add, use_fp32_reduce"
),
marlin_generate_valid_test_cases(),
) )
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@pytest.mark.parametrize("is_k_full", K_FULL_OPTS)
@pytest.mark.parametrize("use_atomic_add", USE_ATOMIC_ADD_OPTS)
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
@pytest.mark.parametrize("dtype", DTYPES)
def test_gptq_marlin_gemm( def test_gptq_marlin_gemm(
k_chunk, a_type,
n_chunk, b_type,
quant_type, c_type,
group_size, group_blocks,
mnk_factors, size_m,
size_n,
size_k,
act_order, act_order,
is_k_full, is_k_full,
use_atomic_add, use_atomic_add,
use_fp32_reduce, use_fp32_reduce,
dtype,
): ):
m_factor, n_factor, k_factor = mnk_factors has_zp = b_type in [scalar_types.uint4, scalar_types.uint8]
has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
size_m = m_factor group_size = group_blocks if group_blocks <= 0 else group_blocks * 16
size_k = k_chunk * k_factor
size_n = n_chunk * n_factor
if act_order: if c_type == scalar_types.float16:
if group_size == -1: dtype = torch.float16
return elif c_type == scalar_types.bfloat16:
if group_size == size_k: dtype = torch.bfloat16
return else:
if has_zp: raise RuntimeError("unsupported c_type")
return
if size_k % group_size != 0: if a_type == scalar_types.int8:
return a_dtype = torch.int8
elif a_type == scalar_types.float8_e4m3fn:
a_dtype = torch.float8_e4m3fn
else:
a_dtype = dtype
a_input = rand_data((size_m, size_k), dtype) a_input = rand_data((size_m, size_k), dtype=dtype)
b_weight = rand_data((size_k, size_n), dtype) b_weight = rand_data((size_k, size_n), dtype=dtype)
if quant_type == scalar_types.float4_e2m1f:
if group_size not in [16, 32] or act_order:
return
if group_size == 32 and dtype == torch.float16:
return
if b_type == scalar_types.float4_e2m1f:
if group_size == 16: if group_size == 16:
w_ref, marlin_q_w, marlin_s, marlin_s2 = rand_marlin_weight_nvfp4_like( w_ref, marlin_q_w, marlin_s, marlin_s2 = rand_marlin_weight_nvfp4_like(
b_weight.T, group_size b_weight.T, group_size, input_dtype=a_dtype
) )
else: else:
w_ref, marlin_q_w, marlin_s = rand_marlin_weight_mxfp4_like( w_ref, marlin_q_w, marlin_s = rand_marlin_weight_mxfp4_like(
b_weight.T, group_size b_weight.T, group_size, input_dtype=a_dtype
) )
marlin_s2 = None marlin_s2 = None
g_idx = None g_idx = None
sort_indices = None sort_indices = None
marlin_zp = None marlin_zp = None
elif quant_type == scalar_types.float8_e4m3fn: elif b_type == scalar_types.float8_e4m3fn:
if group_size not in [-1, 128]: w_ref, marlin_q_w, marlin_s = marlin_quant_fp8_torch(
return b_weight.T, group_size, input_dtype=a_dtype
if act_order: )
return
w_ref, marlin_q_w, marlin_s = marlin_quant_fp8_torch(b_weight.T, group_size)
g_idx = None g_idx = None
sort_indices = None sort_indices = None
marlin_zp = None marlin_zp = None
marlin_s2 = None marlin_s2 = None
elif has_zp: elif has_zp:
if group_size == 16:
return
w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize( w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(
b_weight, quant_type, group_size b_weight, b_type, group_size, input_dtype=a_dtype
) )
g_idx = None g_idx = None
sort_indices = None sort_indices = None
marlin_s2 = None marlin_s2 = None
else: else:
if group_size == 16:
return
w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize( w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
b_weight, quant_type, group_size, act_order b_weight, b_type, group_size, act_order, input_dtype=a_dtype
) )
marlin_zp = None marlin_zp = None
marlin_s2 = None marlin_s2 = None
workspace = marlin_make_workspace_new(w_ref.device) workspace = marlin_make_workspace_new(w_ref.device)
opcheck( if a_type == scalar_types.int8:
torch.ops._C.gptq_marlin_gemm, a_input, a_scales = per_token_quant_int8(a_input)
( a_input_ref = a_input.to(a_scales.dtype) * a_scales.view(-1, 1)
a_input, a_input_ref = a_input_ref.to(dtype)
None,
marlin_q_w, if group_size != -1:
None, a_scales = a_scales / 4096 * marlin_s.max()
marlin_s, a_scales = a_scales.float()
marlin_s2, marlin_s = marlin_s / marlin_s.max() * 4096
marlin_zp, marlin_s = marlin_s.round().to(torch.int16).view(dtype)
g_idx, elif a_type == scalar_types.float8_e4m3fn:
sort_indices, a_input, a_scales = ops.scaled_fp8_quant(a_input, use_per_token_if_dynamic=True)
workspace, a_input_ref = a_input.to(a_scales.dtype) * a_scales.view(-1, 1)
quant_type.id, a_input_ref = a_input_ref.to(dtype)
a_input.shape[0], else:
b_weight.shape[1], assert a_type.size_bits == 16
a_input.shape[1], a_input_ref = a_input
is_k_full, a_scales = None
use_atomic_add,
use_fp32_reduce, output = torch.empty((size_m, size_n), dtype=dtype, device=a_input.device)
False,
),
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
)
output = ops.gptq_marlin_gemm( output = ops.gptq_marlin_gemm(
a_input, a_input,
None, output,
marlin_q_w, marlin_q_w,
None, None,
marlin_s, marlin_s,
a_scales,
marlin_s2, marlin_s2,
marlin_zp, marlin_zp,
g_idx, g_idx,
sort_indices, sort_indices,
workspace, workspace,
quant_type, b_type,
a_input.shape[0], a_input.shape[0],
b_weight.shape[1], b_weight.shape[1],
a_input.shape[1], a_input.shape[1],
@ -352,12 +527,9 @@ def test_gptq_marlin_gemm(
use_fp32_reduce=use_fp32_reduce, use_fp32_reduce=use_fp32_reduce,
is_zp_float=False, is_zp_float=False,
) )
output_ref = torch.matmul(a_input, w_ref) output_ref = torch.matmul(a_input_ref, w_ref)
torch.cuda.synchronize()
max_diff = compute_max_diff(output, output_ref) max_diff = compute_max_diff(output, output_ref)
assert max_diff < 0.04 assert max_diff < 0.04
@ -507,6 +679,7 @@ def test_hqq_marlin_gemm(
None, None,
marlin_s, marlin_s,
None, None,
None,
marlin_zp, marlin_zp,
g_idx, g_idx,
g_idx_sort_indices, g_idx_sort_indices,
@ -559,6 +732,7 @@ def test_marlin_gemm_subset_input():
None, None,
marlin_s, marlin_s,
None, None,
None,
marlin_zp, marlin_zp,
g_idx, g_idx,
sort_indices, sort_indices,
@ -607,6 +781,7 @@ def test_marlin_gemm_with_bias(size_m):
marlin_bias, marlin_bias,
marlin_s, marlin_s,
None, None,
None,
marlin_zp, marlin_zp,
g_idx, g_idx,
sort_indices, sort_indices,

View File

@ -846,6 +846,13 @@ def torch_experts(
or (expert_map is not None and global_num_experts == expert_map.shape[0]) or (expert_map is not None and global_num_experts == expert_map.shape[0])
) )
if quant_dtype in [torch.float16, torch.bfloat16]:
quant_dtype = None
quant_input_only = quant_dtype is not None and w1_scale is None and w2_scale is None
if quant_input_only:
assert a1_scale is None and a2_scale is None
assert per_act_token_quant
M, K = a.shape M, K = a.shape
topk = topk_ids.shape[1] topk = topk_ids.shape[1]
@ -863,6 +870,9 @@ def torch_experts(
a, a1_scale, quant_dtype, per_act_token_quant, block_shape a, a1_scale, quant_dtype, per_act_token_quant, block_shape
) )
if quant_input_only:
a = (a.float() * a_scale.view(-1, 1)).to(w1.dtype)
num_experts = w1.shape[0] num_experts = w1.shape[0]
topk_ids = topk_ids.view(-1) topk_ids = topk_ids.view(-1)
@ -882,6 +892,14 @@ def torch_experts(
out[mask] = tmp2 @ w2[i].transpose(0, 1) out[mask] = tmp2 @ w2[i].transpose(0, 1)
if b_bias2 is not None: if b_bias2 is not None:
out[mask] = out[mask] + b_bias2[i].view(1, -1).to(tmp1.dtype) out[mask] = out[mask] + b_bias2[i].view(1, -1).to(tmp1.dtype)
elif quant_input_only:
tmp1 = a[mask] @ w1[i].transpose(0, 1)
tmp2 = SiluAndMul()(tmp1)
tmp2, tmp2_scale = moe_kernel_quantize_input(
tmp2, None, quant_dtype, per_act_token_quant
)
tmp2 = (tmp2.float() * tmp2_scale.view(-1, 1)).to(w2.dtype)
out[mask] = tmp2 @ w2[i].transpose(0, 1)
elif block_shape is not None: elif block_shape is not None:
# block quantized # block quantized
assert ( assert (

View File

@ -554,6 +554,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
b_q_weight: torch.Tensor, b_q_weight: torch.Tensor,
b_bias: torch.Tensor | None, b_bias: torch.Tensor | None,
b_scales: torch.Tensor, b_scales: torch.Tensor,
a_scales: torch.Tensor | None,
global_scale: torch.Tensor | None, global_scale: torch.Tensor | None,
b_zeros: torch.Tensor | None, b_zeros: torch.Tensor | None,
g_idx: torch.Tensor | None, g_idx: torch.Tensor | None,
@ -568,7 +569,10 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
use_fp32_reduce: bool = False, use_fp32_reduce: bool = False,
is_zp_float: bool = False, is_zp_float: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) dtype = a.dtype
if dtype not in [torch.half, torch.bfloat16]:
dtype = b_scales.dtype
return torch.empty((size_m, size_n), device=a.device, dtype=dtype)
@register_fake("_C::awq_dequantize") @register_fake("_C::awq_dequantize")
def _awq_dequantize_fake( def _awq_dequantize_fake(
@ -1167,8 +1171,11 @@ def gptq_marlin_repack(
size_k: int, size_k: int,
size_n: int, size_n: int,
num_bits: int, num_bits: int,
is_a_8bit: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
return torch.ops._C.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, num_bits) return torch.ops._C.gptq_marlin_repack(
b_q_weight, perm, size_k, size_n, num_bits, is_a_8bit
)
if hasattr(torch.ops._C, "gptq_marlin_repack"): if hasattr(torch.ops._C, "gptq_marlin_repack"):
@ -1180,6 +1187,7 @@ if hasattr(torch.ops._C, "gptq_marlin_repack"):
size_k: torch.SymInt, size_k: torch.SymInt,
size_n: torch.SymInt, size_n: torch.SymInt,
num_bits: int, num_bits: int,
is_a_8bit: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
pack_factor = 32 // num_bits pack_factor = 32 // num_bits
marlin_tile_size = 16 marlin_tile_size = 16
@ -1192,9 +1200,15 @@ if hasattr(torch.ops._C, "gptq_marlin_repack"):
# awq_marlin # awq_marlin
def awq_marlin_repack( def awq_marlin_repack(
b_q_weight: torch.Tensor, size_k: int, size_n: int, num_bits: int b_q_weight: torch.Tensor,
size_k: int,
size_n: int,
num_bits: int,
is_a_8bit: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
return torch.ops._C.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits) return torch.ops._C.awq_marlin_repack(
b_q_weight, size_k, size_n, num_bits, is_a_8bit
)
if hasattr(torch.ops._C, "awq_marlin_repack"): if hasattr(torch.ops._C, "awq_marlin_repack"):
@ -1205,6 +1219,7 @@ if hasattr(torch.ops._C, "awq_marlin_repack"):
size_k: torch.SymInt, size_k: torch.SymInt,
size_n: torch.SymInt, size_n: torch.SymInt,
num_bits: int, num_bits: int,
is_a_8bit: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
pack_factor = 32 // num_bits pack_factor = 32 // num_bits
marlin_tile_size = 16 marlin_tile_size = 16
@ -1221,6 +1236,7 @@ def gptq_marlin_moe_repack(
size_k: int, size_k: int,
size_n: int, size_n: int,
num_bits: int, num_bits: int,
is_a_8bit: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
num_experts = b_q_weight.shape[0] num_experts = b_q_weight.shape[0]
assert size_k % 16 == 0 assert size_k % 16 == 0
@ -1231,7 +1247,7 @@ def gptq_marlin_moe_repack(
) )
for e in range(num_experts): for e in range(num_experts):
output[e] = torch.ops._C.gptq_marlin_repack( output[e] = torch.ops._C.gptq_marlin_repack(
b_q_weight[e], perm[e], size_k, size_n, num_bits b_q_weight[e], perm[e], size_k, size_n, num_bits, is_a_8bit
) )
return output return output
@ -1242,6 +1258,7 @@ def awq_marlin_moe_repack(
size_k: int, size_k: int,
size_n: int, size_n: int,
num_bits: int, num_bits: int,
is_a_8bit: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
num_experts = b_q_weight.shape[0] num_experts = b_q_weight.shape[0]
assert size_k % 16 == 0 assert size_k % 16 == 0
@ -1252,17 +1269,26 @@ def awq_marlin_moe_repack(
) )
for e in range(num_experts): for e in range(num_experts):
output[e] = torch.ops._C.awq_marlin_repack( output[e] = torch.ops._C.awq_marlin_repack(
b_q_weight[e], size_k, size_n, num_bits b_q_weight[e], size_k, size_n, num_bits, is_a_8bit
) )
return output return output
def marlin_int4_fp8_preprocess(
qweight: torch.Tensor,
qzeros_or_none: torch.Tensor | None = None,
inplace: bool = False,
):
return torch.ops._C.marlin_int4_fp8_preprocess(qweight, qzeros_or_none, inplace)
def gptq_marlin_gemm( def gptq_marlin_gemm(
a: torch.Tensor, a: torch.Tensor,
c: torch.Tensor | None, c: torch.Tensor | None,
b_q_weight: torch.Tensor, b_q_weight: torch.Tensor,
b_bias: torch.Tensor | None, b_bias: torch.Tensor | None,
b_scales: torch.Tensor, b_scales: torch.Tensor,
a_scales: torch.Tensor | None,
global_scale: torch.Tensor | None, global_scale: torch.Tensor | None,
b_zeros: torch.Tensor | None, b_zeros: torch.Tensor | None,
g_idx: torch.Tensor | None, g_idx: torch.Tensor | None,
@ -1283,6 +1309,7 @@ def gptq_marlin_gemm(
b_q_weight, b_q_weight,
b_bias, b_bias,
b_scales, b_scales,
a_scales,
global_scale, global_scale,
b_zeros, b_zeros,
g_idx, g_idx,
@ -1600,7 +1627,7 @@ def allspark_repack_weight(
if use asymmetric quantization, has_zp = True. if use asymmetric quantization, has_zp = True.
Returns: Returns:
tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : tuple[torch.Tensor, torch.Tensor, torch.Tensor | None] :
rearranged weight, scale, and optionally zero_point. rearranged weight, scale, and optionally zero_point.
""" """
K = qweight.shape[0] K = qweight.shape[0]
@ -1683,7 +1710,7 @@ def scaled_int8_quant(
symmetric: Whether to use symmetric quantization (scale only, azp ignored). symmetric: Whether to use symmetric quantization (scale only, azp ignored).
Returns: Returns:
tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. tuple[torch.Tensor, torch.Tensor, torch.Tensor | None] : Output int8 tensor, scales, and optionally azp.
""" """
output = torch.empty_like(input, dtype=torch.int8) output = torch.empty_like(input, dtype=torch.int8)
if scale is not None: if scale is not None:
@ -2004,6 +2031,7 @@ def moe_wna16_marlin_gemm(
b_qweight: torch.Tensor, b_qweight: torch.Tensor,
b_bias: torch.Tensor | None, b_bias: torch.Tensor | None,
b_scales: torch.Tensor, b_scales: torch.Tensor,
a_scales: torch.Tensor | None,
global_scale: torch.Tensor | None, global_scale: torch.Tensor | None,
b_qzeros: torch.Tensor | None, b_qzeros: torch.Tensor | None,
g_idx: torch.Tensor | None, g_idx: torch.Tensor | None,
@ -2025,6 +2053,9 @@ def moe_wna16_marlin_gemm(
use_atomic_add: bool, use_atomic_add: bool,
use_fp32_reduce: bool, use_fp32_reduce: bool,
is_zp_float: bool, is_zp_float: bool,
thread_k: int = -1,
thread_n: int = -1,
blocks_per_sm: int = -1,
) -> torch.Tensor: ) -> torch.Tensor:
return torch.ops._moe_C.moe_wna16_marlin_gemm( return torch.ops._moe_C.moe_wna16_marlin_gemm(
input, input,
@ -2032,6 +2063,7 @@ def moe_wna16_marlin_gemm(
b_qweight, b_qweight,
b_bias, b_bias,
b_scales, b_scales,
a_scales,
global_scale, global_scale,
b_qzeros, b_qzeros,
g_idx, g_idx,
@ -2053,6 +2085,9 @@ def moe_wna16_marlin_gemm(
use_atomic_add, use_atomic_add,
use_fp32_reduce, use_fp32_reduce,
is_zp_float, is_zp_float,
thread_k,
thread_n,
blocks_per_sm,
) )
@ -2088,7 +2123,10 @@ if hasattr(torch.ops, "_moe_C") and hasattr(torch.ops._moe_C, "marlin_gemm_moe")
input: torch.Tensor, input: torch.Tensor,
output: torch.Tensor | None, output: torch.Tensor | None,
b_qweight: torch.Tensor, b_qweight: torch.Tensor,
b_bias: torch.Tensor | None,
b_scales: torch.Tensor, b_scales: torch.Tensor,
a_scales: torch.Tensor | None,
global_scale: torch.Tensor | None,
b_qzeros: torch.Tensor | None, b_qzeros: torch.Tensor | None,
g_idx: torch.Tensor | None, g_idx: torch.Tensor | None,
perm: torch.Tensor | None, perm: torch.Tensor | None,
@ -2109,7 +2147,7 @@ if hasattr(torch.ops, "_moe_C") and hasattr(torch.ops._moe_C, "marlin_gemm_moe")
use_atomic_add: bool, use_atomic_add: bool,
use_fp32_reduce: bool, use_fp32_reduce: bool,
is_zp_float: bool, is_zp_float: bool,
) -> torch.Tensor: ):
return torch.empty( return torch.empty(
(size_m * top_k, size_n), dtype=input.dtype, device=input.device (size_m * top_k, size_n), dtype=input.dtype, device=input.device
) )
@ -2583,7 +2621,7 @@ def onednn_scaled_int8_quant(
symmetric: Whether to use symmetric quantization (scale only, azp ignored). symmetric: Whether to use symmetric quantization (scale only, azp ignored).
Returns: Returns:
tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. tuple[torch.Tensor, torch.Tensor, torch.Tensor | None] : Output int8 tensor, scales, and optionally azp.
""" """
output = torch.empty_like(input, dtype=torch.int8) output = torch.empty_like(input, dtype=torch.int8)
token_num = input.numel() // input.shape[-1] token_num = input.numel() // input.shape[-1]

View File

@ -145,6 +145,7 @@ if TYPE_CHECKING:
VLLM_RANDOMIZE_DP_DUMMY_INPUTS: bool = False VLLM_RANDOMIZE_DP_DUMMY_INPUTS: bool = False
VLLM_RAY_DP_PACK_STRATEGY: Literal["strict", "fill", "span"] = "strict" VLLM_RAY_DP_PACK_STRATEGY: Literal["strict", "fill", "span"] = "strict"
VLLM_MARLIN_USE_ATOMIC_ADD: bool = False VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
VLLM_MARLIN_INPUT_DTYPE: Literal["int8", "fp8"] | None = None
VLLM_MXFP4_USE_MARLIN: bool | None = None VLLM_MXFP4_USE_MARLIN: bool | None = None
VLLM_V1_USE_OUTLINES_CACHE: bool = False VLLM_V1_USE_OUTLINES_CACHE: bool = False
VLLM_TPU_BUCKET_PADDING_GAP: int = 0 VLLM_TPU_BUCKET_PADDING_GAP: int = 0
@ -1122,6 +1123,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_MXFP4_USE_MARLIN": lambda: maybe_convert_bool( "VLLM_MXFP4_USE_MARLIN": lambda: maybe_convert_bool(
os.environ.get("VLLM_MXFP4_USE_MARLIN", None) os.environ.get("VLLM_MXFP4_USE_MARLIN", None)
), ),
# The activation dtype for marlin kernel
"VLLM_MARLIN_INPUT_DTYPE": env_with_choices(
"VLLM_MARLIN_INPUT_DTYPE", None, ["int8", "fp8"]
),
# Whether to turn on the outlines cache for V1 # Whether to turn on the outlines cache for V1
# This cache is unbounded and on disk, so it's not safe to use in # This cache is unbounded and on disk, so it's not safe to use in
# an environment with potentially malicious users. # an environment with potentially malicious users.

View File

@ -24,7 +24,7 @@ from vllm.model_executor.layers.fused_moe.utils import _resize_cache, disable_in
from vllm.model_executor.layers.quantization.utils.marlin_utils import ( from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_make_workspace_new, marlin_make_workspace_new,
marlin_moe_intermediate_size, marlin_moe_intermediate_size,
maybe_warn_marlin_atomic_add, marlin_quant_input,
) )
from vllm.scalar_type import ScalarType, scalar_types from vllm.scalar_type import ScalarType, scalar_types
@ -65,6 +65,8 @@ def _fused_marlin_moe(
activation_func: Callable[ activation_func: Callable[
[str, torch.Tensor, torch.Tensor], None [str, torch.Tensor, torch.Tensor], None
] = default_activation_func, ] = default_activation_func,
input_global_scale1: torch.Tensor | None = None,
input_global_scale2: torch.Tensor | None = None,
global_scale1: torch.Tensor | None = None, global_scale1: torch.Tensor | None = None,
global_scale2: torch.Tensor | None = None, global_scale2: torch.Tensor | None = None,
g_idx1: torch.Tensor | None = None, g_idx1: torch.Tensor | None = None,
@ -77,6 +79,7 @@ def _fused_marlin_moe(
intermediate_cache13: torch.Tensor | None = None, intermediate_cache13: torch.Tensor | None = None,
intermediate_cache2: torch.Tensor | None = None, intermediate_cache2: torch.Tensor | None = None,
output: torch.Tensor | None = None, output: torch.Tensor | None = None,
input_dtype: torch.dtype | None = None,
is_k_full: bool = True, is_k_full: bool = True,
) -> torch.Tensor: ) -> torch.Tensor:
assert hidden_states.ndim == 2 assert hidden_states.ndim == 2
@ -106,18 +109,22 @@ def _fused_marlin_moe(
intermediate_cache2 = _resize_cache(intermediate_cache2, (M * num_topk, N)) intermediate_cache2 = _resize_cache(intermediate_cache2, (M * num_topk, N))
maybe_warn_marlin_atomic_add(hidden_states.device, hidden_states.dtype) a_scales1 = None
use_atomic_add = ( gate_up_input = hidden_states
hidden_states.dtype == torch.half if input_dtype == torch.int8:
or torch.cuda.get_device_capability(hidden_states.device)[0] >= 9 gate_up_input, a_scales1 = marlin_quant_input(hidden_states, input_dtype)
) if input_global_scale1 is not None:
a_scales1 = a_scales1 * input_global_scale1
elif input_dtype == torch.float8_e4m3fn:
gate_up_input, a_scales1 = marlin_quant_input(hidden_states, input_dtype)
intermediate_cache1 = ops.moe_wna16_marlin_gemm( intermediate_cache1 = ops.moe_wna16_marlin_gemm(
hidden_states, gate_up_input,
intermediate_cache1, intermediate_cache1,
w1, w1,
bias1, bias1,
w1_scale, w1_scale,
a_scales1,
global_scale1, global_scale1,
w1_zeros, w1_zeros,
g_idx1, g_idx1,
@ -136,7 +143,7 @@ def _fused_marlin_moe(
size_n=2 * N, size_n=2 * N,
size_k=K, size_k=K,
is_k_full=is_k_full, is_k_full=is_k_full,
use_atomic_add=use_atomic_add, use_atomic_add=False,
use_fp32_reduce=True, use_fp32_reduce=True,
is_zp_float=False, is_zp_float=False,
) )
@ -151,12 +158,25 @@ def _fused_marlin_moe(
if expert_map is not None: if expert_map is not None:
output.zero_() output.zero_()
a_scales2 = None
if input_dtype == torch.int8:
intermediate_cache2, a_scales2 = marlin_quant_input(
intermediate_cache2, input_dtype
)
if input_global_scale2 is not None:
a_scales2 = a_scales2 * input_global_scale2
elif input_dtype == torch.float8_e4m3fn:
intermediate_cache2, a_scales2 = marlin_quant_input(
intermediate_cache2, input_dtype
)
output = ops.moe_wna16_marlin_gemm( output = ops.moe_wna16_marlin_gemm(
intermediate_cache2, intermediate_cache2,
output, output,
w2, w2,
bias2, bias2,
w2_scale, w2_scale,
a_scales2,
global_scale2, global_scale2,
w2_zeros, w2_zeros,
g_idx2, g_idx2,
@ -175,7 +195,7 @@ def _fused_marlin_moe(
size_n=K, size_n=K,
size_k=N, size_k=N,
is_k_full=is_k_full, is_k_full=is_k_full,
use_atomic_add=use_atomic_add, use_atomic_add=False,
use_fp32_reduce=True, use_fp32_reduce=True,
is_zp_float=False, is_zp_float=False,
) )
@ -203,6 +223,8 @@ def fused_marlin_moe(
] = default_activation_func, ] = default_activation_func,
moe_sum: Callable[[torch.Tensor, torch.Tensor], None] | None = None, moe_sum: Callable[[torch.Tensor, torch.Tensor], None] | None = None,
expert_map: torch.Tensor | None = None, expert_map: torch.Tensor | None = None,
input_global_scale1: torch.Tensor | None = None,
input_global_scale2: torch.Tensor | None = None,
global_scale1: torch.Tensor | None = None, global_scale1: torch.Tensor | None = None,
global_scale2: torch.Tensor | None = None, global_scale2: torch.Tensor | None = None,
g_idx1: torch.Tensor | None = None, g_idx1: torch.Tensor | None = None,
@ -216,6 +238,7 @@ def fused_marlin_moe(
intermediate_cache2: torch.Tensor | None = None, intermediate_cache2: torch.Tensor | None = None,
is_k_full: bool = True, is_k_full: bool = True,
output: torch.Tensor | None = None, output: torch.Tensor | None = None,
input_dtype: torch.dtype | None = None,
inplace: bool = False, inplace: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
@ -287,6 +310,9 @@ def fused_marlin_moe(
if M * topk / E / block_size_m < 0.9: if M * topk / E / block_size_m < 0.9:
break break
if input_dtype is not None and input_dtype.itemsize == 1:
block_size_m = max(block_size_m, 16)
if global_num_experts == -1: if global_num_experts == -1:
global_num_experts = E global_num_experts = E
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
@ -313,6 +339,8 @@ def fused_marlin_moe(
num_tokens_post_padded=num_tokens_post_padded, num_tokens_post_padded=num_tokens_post_padded,
activation=activation, activation=activation,
activation_func=activation_func, activation_func=activation_func,
input_global_scale1=input_global_scale1,
input_global_scale2=input_global_scale2,
global_scale1=global_scale1, global_scale1=global_scale1,
global_scale2=global_scale2, global_scale2=global_scale2,
g_idx1=g_idx1, g_idx1=g_idx1,
@ -325,6 +353,7 @@ def fused_marlin_moe(
intermediate_cache13=intermediate_cache13, intermediate_cache13=intermediate_cache13,
intermediate_cache2=intermediate_cache2, intermediate_cache2=intermediate_cache2,
output=None, output=None,
input_dtype=input_dtype,
is_k_full=is_k_full, is_k_full=is_k_full,
).view(-1, topk, K) ).view(-1, topk, K)

View File

@ -266,7 +266,7 @@ class AutoRoundConfig(QuantizationConfig):
from vllm.model_executor.layers.quantization.awq_marlin import ( from vllm.model_executor.layers.quantization.awq_marlin import (
AWQMarlinConfig, AWQMarlinConfig,
AWQMarlinLinearMethod, AWQMarlinLinearMethod,
AWQMoEMethod, AWQMarlinMoEMethod,
) )
quant_args_marlin = AWQMarlinConfig( quant_args_marlin = AWQMarlinConfig(
@ -291,7 +291,7 @@ class AutoRoundConfig(QuantizationConfig):
if isinstance(layer, FusedMoE): if isinstance(layer, FusedMoE):
if use_marlin: if use_marlin:
return AWQMoEMethod(quant_args_marlin, layer.moe_config) return AWQMarlinMoEMethod(quant_args_marlin, layer.moe)
from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config
config = { config = {

View File

@ -106,7 +106,7 @@ class AWQConfig(QuantizationConfig):
return AWQLinearMethod(self) return AWQLinearMethod(self)
elif isinstance(layer, FusedMoE): elif isinstance(layer, FusedMoE):
# Lazy import to avoid circular import. # Lazy import to avoid circular import.
from .awq_marlin import AWQMarlinConfig, AWQMoEMethod from .awq_marlin import AWQMarlinConfig, AWQMarlinMoEMethod
from .moe_wna16 import MoeWNA16Config from .moe_wna16 import MoeWNA16Config
from .utils.marlin_utils import check_moe_marlin_supports_layer from .utils.marlin_utils import check_moe_marlin_supports_layer
@ -136,7 +136,7 @@ class AWQConfig(QuantizationConfig):
awq_marlin_config = AWQMarlinConfig.from_config( awq_marlin_config = AWQMarlinConfig.from_config(
marlin_compatible_config_dict marlin_compatible_config_dict
) )
return AWQMoEMethod(awq_marlin_config, layer.moe_config) return AWQMarlinMoEMethod(awq_marlin_config, layer.moe_config)
return None return None
def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"): def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):

View File

@ -40,6 +40,8 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
check_marlin_supported, check_marlin_supported,
check_marlin_supports_layer, check_marlin_supports_layer,
check_moe_marlin_supports_layer, check_moe_marlin_supports_layer,
get_marlin_input_dtype,
marlin_act_int8_process_scales,
marlin_make_empty_g_idx, marlin_make_empty_g_idx,
marlin_make_workspace_new, marlin_make_workspace_new,
marlin_moe_permute_scales, marlin_moe_permute_scales,
@ -69,7 +71,6 @@ class AWQMarlinConfig(QuantizationConfig):
# num_bits -> type # num_bits -> type
TYPE_MAP = { TYPE_MAP = {
4: scalar_types.uint4, 4: scalar_types.uint4,
8: scalar_types.uint8,
} }
def __init__( def __init__(
@ -193,7 +194,9 @@ class AWQMarlinConfig(QuantizationConfig):
return AWQConfig.from_config(self.full_config).get_quant_method( return AWQConfig.from_config(self.full_config).get_quant_method(
layer, prefix layer, prefix
) )
return AWQMarlinLinearMethod(self) quant_method = AWQMarlinLinearMethod(self)
quant_method.input_dtype = get_marlin_input_dtype(prefix)
return quant_method
elif isinstance(layer, FusedMoE): elif isinstance(layer, FusedMoE):
from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config
@ -211,7 +214,9 @@ class AWQMarlinConfig(QuantizationConfig):
return MoeWNA16Config.from_config(self.full_config).get_quant_method( return MoeWNA16Config.from_config(self.full_config).get_quant_method(
layer, prefix layer, prefix
) )
return AWQMoEMethod(self, layer.moe_config) moe_quant_method = AWQMarlinMoEMethod(self, layer.moe_config)
moe_quant_method.input_dtype = get_marlin_input_dtype(prefix)
return moe_quant_method
return None return None
@classmethod @classmethod
@ -270,6 +275,8 @@ class AWQMarlinLinearMethod(LinearMethodBase):
def __init__(self, quant_config: AWQMarlinConfig) -> None: def __init__(self, quant_config: AWQMarlinConfig) -> None:
self.quant_config = quant_config self.quant_config = quant_config
self.quant_type = scalar_types.uint4
self.input_dtype = None
def create_weights( def create_weights(
self, self,
@ -312,6 +319,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
) )
num_groups = input_size_per_partition // group_size num_groups = input_size_per_partition // group_size
layer.num_groups = num_groups
qzeros = PackedvLLMParameter( qzeros = PackedvLLMParameter(
data=torch.empty( data=torch.empty(
@ -358,12 +366,19 @@ class AWQMarlinLinearMethod(LinearMethodBase):
# Allocate marlin workspace # Allocate marlin workspace
layer.workspace = marlin_make_workspace_new(device) layer.workspace = marlin_make_workspace_new(device)
is_a_8bit = self.input_dtype is not None and self.input_dtype.itemsize == 1
if self.input_dtype == torch.float8_e4m3fn:
ops.marlin_int4_fp8_preprocess(layer.qweight, layer.qzeros, inplace=True)
layer.scales.data = layer.scales.data * 512
# Repack weights from AWQ format to marlin format. # Repack weights from AWQ format to marlin format.
marlin_qweight = ops.awq_marlin_repack( marlin_qweight = ops.awq_marlin_repack(
layer.qweight, layer.qweight,
size_k=layer.input_size_per_partition, size_k=layer.input_size_per_partition,
size_n=layer.output_size_per_partition, size_n=layer.output_size_per_partition,
num_bits=self.quant_config.quant_type.size_bits, num_bits=self.quant_config.quant_type.size_bits,
is_a_8bit=is_a_8bit,
) )
replace_parameter(layer, "qweight", marlin_qweight) replace_parameter(layer, "qweight", marlin_qweight)
@ -373,7 +388,16 @@ class AWQMarlinLinearMethod(LinearMethodBase):
size_k=layer.input_size_per_partition, size_k=layer.input_size_per_partition,
size_n=layer.output_size_per_partition, size_n=layer.output_size_per_partition,
group_size=self.quant_config.group_size, group_size=self.quant_config.group_size,
is_a_8bit=is_a_8bit,
) )
if self.input_dtype == torch.int8 and layer.num_groups > 1:
marlin_scales, input_global_scale = marlin_act_int8_process_scales(
marlin_scales
)
layer.register_parameter(
"input_global_scale", Parameter(input_global_scale, requires_grad=False)
)
replace_parameter(layer, "scales", marlin_scales) replace_parameter(layer, "scales", marlin_scales)
# Permute zero-points from AWQ format to marlin format. # Permute zero-points from AWQ format to marlin format.
@ -382,6 +406,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
size_k=layer.num_groups, size_k=layer.num_groups,
size_n=layer.output_size_per_partition, size_n=layer.output_size_per_partition,
num_bits=self.quant_config.quant_type.size_bits, num_bits=self.quant_config.quant_type.size_bits,
is_a_8bit=is_a_8bit,
) )
replace_parameter(layer, "qzeros", marlin_zp) replace_parameter(layer, "qzeros", marlin_zp)
@ -409,11 +434,13 @@ class AWQMarlinLinearMethod(LinearMethodBase):
quant_type=self.quant_config.quant_type, quant_type=self.quant_config.quant_type,
output_size_per_partition=layer.output_size_per_partition, output_size_per_partition=layer.output_size_per_partition,
input_size_per_partition=layer.input_size_per_partition, input_size_per_partition=layer.input_size_per_partition,
input_global_scale=getattr(layer, "input_global_scale", None),
bias=bias, bias=bias,
input_dtype=self.input_dtype,
) )
class AWQMoEMethod(FusedMoEMethodBase): class AWQMarlinMoEMethod(FusedMoEMethodBase):
def __init__( def __init__(
self, self,
quant_config: AWQMarlinConfig, quant_config: AWQMarlinConfig,
@ -422,8 +449,9 @@ class AWQMoEMethod(FusedMoEMethodBase):
super().__init__(moe) super().__init__(moe)
self.quant_config = quant_config self.quant_config = quant_config
if self.quant_config.weight_bits != 4: if self.quant_config.weight_bits != 4:
raise ValueError("AWQMoEMethod only supports 4bit now.") raise ValueError("AWQMarlinMoEMethod only supports 4bit now.")
self.quant_type = scalar_types.uint4 self.quant_type = scalar_types.uint4
self.input_dtype = None
self.use_marlin = True self.use_marlin = True
def create_weights( def create_weights(
@ -435,6 +463,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
params_dtype: torch.dtype, params_dtype: torch.dtype,
**extra_weight_attrs, **extra_weight_attrs,
): ):
layer.input_dtype = self.input_dtype
extra_weight_attrs.update( extra_weight_attrs.update(
{ {
"is_transposed": True, "is_transposed": True,
@ -468,6 +497,8 @@ class AWQMoEMethod(FusedMoEMethodBase):
num_groups_w13 = hidden_size // self.quant_config.group_size num_groups_w13 = hidden_size // self.quant_config.group_size
num_groups_w2 = intermediate_size_per_partition // self.quant_config.group_size num_groups_w2 = intermediate_size_per_partition // self.quant_config.group_size
layer.num_groups_w13 = num_groups_w13
layer.num_groups_w2 = num_groups_w2
# WEIGHT_SCALES # WEIGHT_SCALES
# Allocate 2 scales for w1 and w3 respectively. # Allocate 2 scales for w1 and w3 respectively.
@ -522,6 +553,21 @@ class AWQMoEMethod(FusedMoEMethodBase):
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
num_experts = layer.w13_qweight.shape[0] num_experts = layer.w13_qweight.shape[0]
device = layer.w13_qweight.device device = layer.w13_qweight.device
is_a_8bit = self.input_dtype is not None and self.input_dtype.itemsize == 1
if self.input_dtype == torch.float8_e4m3fn:
ops.marlin_int4_fp8_preprocess(
layer.w13_qweight.view(-1, layer.w13_qweight.size(2)),
layer.w13_qzeros.view(-1, layer.w13_qzeros.size(2)),
inplace=True,
)
ops.marlin_int4_fp8_preprocess(
layer.w2_qweight.view(-1, layer.w2_qweight.size(2)),
layer.w2_qzeros.view(-1, layer.w2_qzeros.size(2)),
inplace=True,
)
layer.w13_scales.data = layer.w13_scales.data * 512
layer.w2_scales.data = layer.w2_scales.data * 512
layer.w13_g_idx_sort_indices = torch.nn.Parameter( layer.w13_g_idx_sort_indices = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32, device=device), torch.empty((num_experts, 0), dtype=torch.int32, device=device),
@ -538,6 +584,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
size_k=layer.w13_qweight.shape[1], size_k=layer.w13_qweight.shape[1],
size_n=layer.w13_qweight.shape[2] * self.quant_config.pack_factor, size_n=layer.w13_qweight.shape[2] * self.quant_config.pack_factor,
num_bits=self.quant_config.weight_bits, num_bits=self.quant_config.weight_bits,
is_a_8bit=is_a_8bit,
) )
replace_parameter(layer, "w13_qweight", marlin_w13_qweight) replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
@ -547,6 +594,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
size_k=layer.w2_qweight.shape[1], size_k=layer.w2_qweight.shape[1],
size_n=layer.w2_qweight.shape[2] * self.quant_config.pack_factor, size_n=layer.w2_qweight.shape[2] * self.quant_config.pack_factor,
num_bits=self.quant_config.weight_bits, num_bits=self.quant_config.weight_bits,
is_a_8bit=is_a_8bit,
) )
replace_parameter(layer, "w2_qweight", marlin_w2_qweight) replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
@ -556,7 +604,16 @@ class AWQMoEMethod(FusedMoEMethodBase):
size_k=layer.intermediate_size_per_partition, size_k=layer.intermediate_size_per_partition,
size_n=layer.w13_scales.shape[2], size_n=layer.w13_scales.shape[2],
group_size=self.quant_config.group_size, group_size=self.quant_config.group_size,
is_a_8bit=is_a_8bit,
) )
if self.input_dtype == torch.int8 and layer.num_groups_w13 > 1:
marlin_w13_scales, w13_input_global_scale = marlin_act_int8_process_scales(
marlin_w13_scales
)
layer.register_parameter(
"w13_input_global_scale",
Parameter(w13_input_global_scale, requires_grad=False),
)
replace_parameter(layer, "w13_scales", marlin_w13_scales) replace_parameter(layer, "w13_scales", marlin_w13_scales)
@ -565,7 +622,17 @@ class AWQMoEMethod(FusedMoEMethodBase):
size_k=layer.intermediate_size_per_partition, size_k=layer.intermediate_size_per_partition,
size_n=layer.w2_scales.shape[2], size_n=layer.w2_scales.shape[2],
group_size=self.quant_config.group_size, group_size=self.quant_config.group_size,
is_a_8bit=is_a_8bit,
) )
if self.input_dtype == torch.int8 and layer.num_groups_w2 > 1:
marlin_w2_scales, w2_input_global_scale = marlin_act_int8_process_scales(
marlin_w2_scales
)
layer.register_parameter(
"w2_input_global_scale",
Parameter(w2_input_global_scale, requires_grad=False),
)
replace_parameter(layer, "w2_scales", marlin_w2_scales) replace_parameter(layer, "w2_scales", marlin_w2_scales)
marlin_w13_zp = moe_awq_to_marlin_zero_points( marlin_w13_zp = moe_awq_to_marlin_zero_points(
@ -573,6 +640,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
size_k=layer.w13_qzeros.shape[1], size_k=layer.w13_qzeros.shape[1],
size_n=layer.w13_qzeros.shape[2] * self.quant_config.pack_factor, size_n=layer.w13_qzeros.shape[2] * self.quant_config.pack_factor,
num_bits=self.quant_config.weight_bits, num_bits=self.quant_config.weight_bits,
is_a_8bit=is_a_8bit,
) )
replace_parameter(layer, "w13_qzeros", marlin_w13_zp) replace_parameter(layer, "w13_qzeros", marlin_w13_zp)
@ -581,6 +649,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
size_k=layer.w2_qzeros.shape[1], size_k=layer.w2_qzeros.shape[1],
size_n=layer.w2_qzeros.shape[2] * self.quant_config.pack_factor, size_n=layer.w2_qzeros.shape[2] * self.quant_config.pack_factor,
num_bits=self.quant_config.weight_bits, num_bits=self.quant_config.weight_bits,
is_a_8bit=is_a_8bit,
) )
replace_parameter(layer, "w2_qzeros", marlin_w2_zp) replace_parameter(layer, "w2_qzeros", marlin_w2_zp)
@ -636,6 +705,8 @@ class AWQMoEMethod(FusedMoEMethodBase):
router_logits, router_logits,
topk_weights, topk_weights,
topk_ids, topk_ids,
input_global_scale1=getattr(layer, "w13_input_global_scale", None),
input_global_scale2=getattr(layer, "w2_input_global_scale", None),
quant_type_id=self.quant_type.id, quant_type_id=self.quant_type.id,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
@ -643,4 +714,5 @@ class AWQMoEMethod(FusedMoEMethodBase):
w1_zeros=layer.w13_qzeros, w1_zeros=layer.w13_qzeros,
w2_zeros=layer.w2_qzeros, w2_zeros=layer.w2_qzeros,
workspace=layer.workspace, workspace=layer.workspace,
input_dtype=self.input_dtype,
) )

View File

@ -157,7 +157,9 @@ class CompressedTensorsConfig(QuantizationConfig):
if isinstance(layer, Attention): if isinstance(layer, Attention):
return CompressedTensorsKVCacheMethod(self) return CompressedTensorsKVCacheMethod(self)
if isinstance(layer, FusedMoE): if isinstance(layer, FusedMoE):
return CompressedTensorsMoEMethod.get_moe_method(self, layer, prefix) return CompressedTensorsMoEMethod.get_moe_method(
self, layer, layer_name=prefix
)
return None return None
def _add_fused_moe_to_target_scheme_map(self): def _add_fused_moe_to_target_scheme_map(self):
@ -547,6 +549,7 @@ class CompressedTensorsConfig(QuantizationConfig):
weight_quant: QuantizationArgs, weight_quant: QuantizationArgs,
input_quant: QuantizationArgs, input_quant: QuantizationArgs,
format: str | None = None, format: str | None = None,
layer_name: str | None = None,
) -> "CompressedTensorsScheme": ) -> "CompressedTensorsScheme":
# use the per-layer format if defined, otherwise, use global format # use the per-layer format if defined, otherwise, use global format
format = format if format is not None else self.quant_format format = format if format is not None else self.quant_format
@ -585,6 +588,7 @@ class CompressedTensorsConfig(QuantizationConfig):
symmetric=weight_quant.symmetric, symmetric=weight_quant.symmetric,
group_size=weight_quant.group_size, group_size=weight_quant.group_size,
actorder=weight_quant.actorder, actorder=weight_quant.actorder,
layer_name=layer_name,
) )
act_quant_format = is_activation_quantization_format(format) act_quant_format = is_activation_quantization_format(format)
@ -724,7 +728,10 @@ class CompressedTensorsConfig(QuantizationConfig):
else: else:
# Find the quant_scheme # Find the quant_scheme
scheme = self._get_scheme_from_parts( # type: ignore scheme = self._get_scheme_from_parts( # type: ignore
weight_quant=weight_quant, input_quant=input_quant, format=format weight_quant=weight_quant,
input_quant=input_quant,
format=format,
layer_name=layer_name,
) )
# Raise error if device does not support the scheme # Raise error if device does not support the scheme

View File

@ -64,6 +64,8 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
) )
from vllm.model_executor.layers.quantization.utils.marlin_utils import ( from vllm.model_executor.layers.quantization.utils.marlin_utils import (
check_moe_marlin_supports_layer, check_moe_marlin_supports_layer,
get_marlin_input_dtype,
marlin_act_int8_process_scales,
marlin_make_workspace_new, marlin_make_workspace_new,
marlin_moe_permute_scales, marlin_moe_permute_scales,
) )
@ -101,7 +103,7 @@ __all__ = [
"CompressedTensorsW8A8Int8MoEMethod", "CompressedTensorsW8A8Int8MoEMethod",
"CompressedTensorsWNA16MarlinMoEMethod", "CompressedTensorsWNA16MarlinMoEMethod",
"CompressedTensorsWNA16MoEMethod", "CompressedTensorsWNA16MoEMethod",
"CompressedTensorsW4A4Nvfp4MoeMethod", "CompressedTensorsW4A4Nvfp4MoEMethod",
"CompressedTensorsW4A8Int8MoEMethod", "CompressedTensorsW4A8Int8MoEMethod",
] ]
@ -111,13 +113,13 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
def get_moe_method( def get_moe_method(
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
layer: torch.nn.Module, layer: torch.nn.Module,
prefix: str, layer_name: str,
) -> "CompressedTensorsMoEMethod": ) -> "CompressedTensorsMoEMethod":
# FusedMoE was made by combining multiple Linears so need to # FusedMoE was made by combining multiple Linears so need to
# make sure quantization config for Linear can target it # make sure quantization config for Linear can target it
quant_config._add_fused_moe_to_target_scheme_map() quant_config._add_fused_moe_to_target_scheme_map()
unfused_names = [ unfused_names = [
prefix + proj_name layer_name + proj_name
for proj_name in [".0.gate_proj", ".0.up_proj", ".0.down_proj"] for proj_name in [".0.gate_proj", ".0.up_proj", ".0.down_proj"]
] ]
# TODO: refactor this to use expert_mapping and check all layer numbers # TODO: refactor this to use expert_mapping and check all layer numbers
@ -158,32 +160,40 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
"WNA16MoE is not supported with actorder=group/dynamic." "WNA16MoE is not supported with actorder=group/dynamic."
) )
logger.info_once("Using CompressedTensorsWNA16MoEMethod") logger.info_once("Using CompressedTensorsWNA16MoEMethod")
return CompressedTensorsWNA16MoEMethod(quant_config, layer.moe_config) return CompressedTensorsWNA16MoEMethod(
quant_config, layer.moe_config, layer_name
)
else: else:
logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod") logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod")
return CompressedTensorsWNA16MarlinMoEMethod( return CompressedTensorsWNA16MarlinMoEMethod(
quant_config, layer.moe_config quant_config, layer.moe_config, layer_name
) )
elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant): elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant):
return CompressedTensorsW4A4Nvfp4MoeMethod(layer.moe_config) return CompressedTensorsW4A4Nvfp4MoEMethod(layer.moe_config, layer_name)
elif ( elif (
quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant) quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant)
or quant_config._is_fp8_w8a8_sm100(weight_quant, input_quant) or quant_config._is_fp8_w8a8_sm100(weight_quant, input_quant)
or quant_config._is_fp8_w8a8(weight_quant, input_quant) or quant_config._is_fp8_w8a8(weight_quant, input_quant)
): ):
return CompressedTensorsW8A8Fp8MoEMethod(quant_config, layer.moe_config) return CompressedTensorsW8A8Fp8MoEMethod(
quant_config, layer.moe_config, layer_name
)
elif quant_config._is_dynamic_token_w8a8(weight_quant, input_quant): elif quant_config._is_dynamic_token_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8Int8MoEMethod(quant_config, layer.moe_config) return CompressedTensorsW8A8Int8MoEMethod(
quant_config, layer.moe_config, layer_name
)
elif quant_config._is_dynamic_token_w4a8_int(weight_quant, input_quant): elif quant_config._is_dynamic_token_w4a8_int(weight_quant, input_quant):
return CompressedTensorsW4A8Int8MoEMethod(quant_config, layer.moe_config) return CompressedTensorsW4A8Int8MoEMethod(
quant_config, layer.moe_config, layer_name
)
else: else:
raise RuntimeError( raise RuntimeError(
f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}" f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}"
) )
class CompressedTensorsW4A4Nvfp4MoeMethod(CompressedTensorsMoEMethod): class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
def __init__(self, moe: FusedMoEConfig): def __init__(self, moe: FusedMoEConfig, layer_name: str | None = None):
from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import ( # noqa: E501 from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import ( # noqa: E501
detect_nvfp4_moe_support, detect_nvfp4_moe_support,
) )
@ -194,17 +204,21 @@ class CompressedTensorsW4A4Nvfp4MoeMethod(CompressedTensorsMoEMethod):
self.allow_flashinfer = _nvfp4.allow_flashinfer self.allow_flashinfer = _nvfp4.allow_flashinfer
self.use_marlin = _nvfp4.use_marlin self.use_marlin = _nvfp4.use_marlin
self.group_size = 16 self.group_size = 16
self.layer_name = layer_name
self.marlin_input_dtype = (
get_marlin_input_dtype(layer_name) if self.use_marlin else None
)
self.flashinfer_moe_backend = None self.flashinfer_moe_backend = None
if self.allow_flashinfer: if self.allow_flashinfer:
self.flashinfer_moe_backend = get_flashinfer_moe_backend() self.flashinfer_moe_backend = get_flashinfer_moe_backend()
logger.info_once( logger.info_once(
f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels" f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
" for CompressedTensorsW4A4Nvfp4MoeMethod." " for CompressedTensorsW4A4Nvfp4MoEMethod."
) )
elif self.use_marlin: elif self.use_marlin:
logger.info_once("Using Marlin for CompressedTensorsW4A4Nvfp4MoeMethod.") logger.info_once("Using Marlin for CompressedTensorsW4A4Nvfp4MoEMethod.")
else: else:
logger.info_once("Using Cutlass for CompressedTensorsW4A4Nvfp4MoeMethod.") logger.info_once("Using Cutlass for CompressedTensorsW4A4Nvfp4MoEMethod.")
def create_weights( def create_weights(
self, self,
@ -354,7 +368,7 @@ class CompressedTensorsW4A4Nvfp4MoeMethod(CompressedTensorsMoEMethod):
) )
if self.use_marlin: if self.use_marlin:
prepare_moe_fp4_layer_for_marlin(layer) prepare_moe_fp4_layer_for_marlin(layer, input_dtype=self.marlin_input_dtype)
return return
# w13 # w13
if ( if (
@ -538,7 +552,7 @@ class CompressedTensorsW4A4Nvfp4MoeMethod(CompressedTensorsMoEMethod):
): ):
if enable_eplb: if enable_eplb:
raise NotImplementedError( raise NotImplementedError(
"EPLB not supported for `CompressedTensorsW4A4MoeMethod` yet." "EPLB not supported for `CompressedTensorsW4A4MoEMethod` yet."
) )
return flashinfer_trtllm_fp4_moe( return flashinfer_trtllm_fp4_moe(
@ -576,6 +590,7 @@ class CompressedTensorsW4A4Nvfp4MoeMethod(CompressedTensorsMoEMethod):
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map, expert_map=expert_map,
input_dtype=self.marlin_input_dtype,
workspace=layer.workspace, workspace=layer.workspace,
) )
@ -610,7 +625,7 @@ class CompressedTensorsW4A4Nvfp4MoeMethod(CompressedTensorsMoEMethod):
assert expert_map is None, ( assert expert_map is None, (
"Expert Parallelism / expert_map " "Expert Parallelism / expert_map "
"is currently not supported for " "is currently not supported for "
"CompressedTensorsW4A4Nvfp4MoeMethod." "CompressedTensorsW4A4Nvfp4MoEMethod."
) )
assert self.moe_quant_config is not None assert self.moe_quant_config is not None
@ -637,6 +652,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
self, self,
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
moe: FusedMoEConfig, moe: FusedMoEConfig,
layer_name: str | None = None,
): ):
super().__init__(moe) super().__init__(moe)
self.quant_config = quant_config self.quant_config = quant_config
@ -690,6 +706,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
or self.is_fp8_w8a8_sm100 or self.is_fp8_w8a8_sm100
) )
self.disable_expert_map = False self.disable_expert_map = False
self.layer_name = layer_name
self.marlin_input_dtype = (
get_marlin_input_dtype(layer_name) if self.use_marlin else None
)
def create_weights( def create_weights(
self, self,
@ -931,7 +951,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False) layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)
elif self.use_marlin: elif self.use_marlin:
prepare_moe_fp8_layer_for_marlin(layer, False) prepare_moe_fp8_layer_for_marlin(
layer, False, input_dtype=self.marlin_input_dtype
)
# Activations not quantized for marlin. # Activations not quantized for marlin.
del layer.w13_input_scale del layer.w13_input_scale
del layer.w2_input_scale del layer.w2_input_scale
@ -1144,6 +1166,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map, expert_map=expert_map,
input_dtype=self.marlin_input_dtype,
workspace=layer.workspace, workspace=layer.workspace,
) )
@ -1240,6 +1263,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
self, self,
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
moe: FusedMoEConfig, moe: FusedMoEConfig,
layer_name: str | None = None,
): ):
super().__init__(moe) super().__init__(moe)
self.quant_config = quant_config self.quant_config = quant_config
@ -1392,6 +1416,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
self, self,
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
moe: FusedMoEConfig, moe: FusedMoEConfig,
layer_name: str | None = None,
): ):
super().__init__(moe) super().__init__(moe)
self.quant_config = quant_config self.quant_config = quant_config
@ -1403,6 +1428,8 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
self.strategy = config.strategy self.strategy = config.strategy
self.group_size = config.group_size self.group_size = config.group_size
self.actorder = config.actorder self.actorder = config.actorder
self.layer_name = layer_name
self.marlin_input_dtype = get_marlin_input_dtype(layer_name)
assert config.symmetric, "Only symmetric quantization is supported for MoE" assert config.symmetric, "Only symmetric quantization is supported for MoE"
if not ( if not (
@ -1477,6 +1504,9 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
num_groups_w2 = w2_scales_size // self.group_size num_groups_w2 = w2_scales_size // self.group_size
num_groups_w13 = hidden_size // self.group_size num_groups_w13 = hidden_size // self.group_size
layer.num_groups_w13 = num_groups_w13
layer.num_groups_w2 = num_groups_w2
w13_scale = torch.nn.Parameter( w13_scale = torch.nn.Parameter(
torch.ones( torch.ones(
num_experts, num_experts,
@ -1560,6 +1590,17 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
num_experts = layer.w13_weight_g_idx.shape[0] num_experts = layer.w13_weight_g_idx.shape[0]
device = layer.w13_weight_g_idx.device device = layer.w13_weight_g_idx.device
is_a_8bit = (
self.marlin_input_dtype is not None
and self.marlin_input_dtype.itemsize == 1
)
if self.marlin_input_dtype == torch.float8_e4m3fn:
# NOTE: for non-zp quantization format only
ops.marlin_int4_fp8_preprocess(layer.w13_weight_packed, inplace=True)
ops.marlin_int4_fp8_preprocess(layer.w2_weight_packed, inplace=True)
layer.w13_weight_scale.data = layer.w13_weight_scale.data * 512
layer.w2_weight_scale.data = layer.w2_weight_scale.data * 512
# when running models with grouped act order, # when running models with grouped act order,
# resort to g_idx values provided in checkpoint # resort to g_idx values provided in checkpoint
@ -1610,31 +1651,54 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
layer.w13_weight_packed.shape[1] * self.packed_factor, layer.w13_weight_packed.shape[1] * self.packed_factor,
layer.w13_weight_packed.shape[2], layer.w13_weight_packed.shape[2],
self.num_bits, self.num_bits,
is_a_8bit=is_a_8bit,
) )
replace_parameter(layer, "w13_weight_packed", marlin_w13_qweight) replace_parameter(layer, "w13_weight_packed", marlin_w13_qweight)
marlin_w2_qweight = ops.gptq_marlin_moe_repack( marlin_w2_qweight = ops.gptq_marlin_moe_repack(
layer.w2_weight_packed, layer.w2_weight_packed,
layer.w2_g_idx_sort_indices, layer.w2_g_idx_sort_indices,
layer.w2_weight_packed.shape[1] * self.packed_factor, layer.w2_weight_packed.shape[1] * self.packed_factor,
layer.w2_weight_packed.shape[2], layer.w2_weight_packed.shape[2],
self.num_bits, self.num_bits,
is_a_8bit=is_a_8bit,
) )
replace_parameter(layer, "w2_weight_packed", marlin_w2_qweight) replace_parameter(layer, "w2_weight_packed", marlin_w2_qweight)
# Repack scales # Repack scales
marlin_w13_scales = marlin_moe_permute_scales( marlin_w13_scales = marlin_moe_permute_scales(
s=layer.w13_weight_scale, s=layer.w13_weight_scale,
size_k=layer.w13_weight_packed.shape[2], size_k=layer.w13_weight_packed.shape[2],
size_n=layer.w13_weight_scale.shape[2], size_n=layer.w13_weight_scale.shape[2],
group_size=self.group_size, group_size=self.group_size,
is_a_8bit=is_a_8bit,
) )
if self.marlin_input_dtype == torch.int8 and layer.num_groups_w13 > 1:
marlin_w13_scales, w13_input_global_scale = marlin_act_int8_process_scales(
marlin_w13_scales
)
layer.register_parameter(
"w13_input_global_scale",
torch.nn.Parameter(w13_input_global_scale, requires_grad=False),
)
replace_parameter(layer, "w13_weight_scale", marlin_w13_scales) replace_parameter(layer, "w13_weight_scale", marlin_w13_scales)
marlin_w2_scales = marlin_moe_permute_scales( marlin_w2_scales = marlin_moe_permute_scales(
s=layer.w2_weight_scale, s=layer.w2_weight_scale,
size_k=layer.w2_weight_scale.shape[1] size_k=layer.w2_weight_scale.shape[1]
* (self.group_size if self.group_size != -1 else self.packed_factor), * (self.group_size if self.group_size != -1 else self.packed_factor),
size_n=layer.w2_weight_scale.shape[2], size_n=layer.w2_weight_scale.shape[2],
group_size=self.group_size, group_size=self.group_size,
is_a_8bit=is_a_8bit,
) )
if self.marlin_input_dtype == torch.int8 and layer.num_groups_w2 > 1:
marlin_w2_scales, w2_input_global_scale = marlin_act_int8_process_scales(
marlin_w2_scales
)
layer.register_parameter(
"w2_input_global_scale",
torch.nn.Parameter(w2_input_global_scale, requires_grad=False),
)
replace_parameter(layer, "w2_weight_scale", marlin_w2_scales) replace_parameter(layer, "w2_weight_scale", marlin_w2_scales)
layer.workspace = marlin_make_workspace_new(device, 4) layer.workspace = marlin_make_workspace_new(device, 4)
@ -1729,6 +1793,8 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
router_logits, router_logits,
topk_weights, topk_weights,
topk_ids, topk_ids,
input_global_scale1=getattr(layer, "w13_input_global_scale", None),
input_global_scale2=getattr(layer, "w2_input_global_scale", None),
quant_type_id=self.quant_type.id, quant_type_id=self.quant_type.id,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
@ -1738,6 +1804,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
sort_indices1=layer.w13_g_idx_sort_indices, sort_indices1=layer.w13_g_idx_sort_indices,
sort_indices2=layer.w2_g_idx_sort_indices, sort_indices2=layer.w2_g_idx_sort_indices,
workspace=layer.workspace, workspace=layer.workspace,
input_dtype=self.marlin_input_dtype,
is_k_full=self.is_k_full, is_k_full=self.is_k_full,
) )
@ -1747,6 +1814,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
self, self,
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
moe: FusedMoEConfig, moe: FusedMoEConfig,
layer_name: str | None = None,
): ):
super().__init__(moe) super().__init__(moe)
self.quant_config = quant_config self.quant_config = quant_config
@ -1999,6 +2067,7 @@ class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod):
self, self,
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
moe: FusedMoEConfig, moe: FusedMoEConfig,
layer_name: str | None = None,
): ):
super().__init__(moe) super().__init__(moe)
self.has_bias = self.moe.has_bias self.has_bias = self.moe.has_bias

View File

@ -14,7 +14,11 @@ from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
MPLinearLayerConfig, MPLinearLayerConfig,
choose_mp_linear_kernel, choose_mp_linear_kernel,
) )
from vllm.model_executor.layers.quantization.kernels.mixed_precision.marlin import (
MarlinLinearKernel,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import ( from vllm.model_executor.layers.quantization.utils.marlin_utils import (
get_marlin_input_dtype,
marlin_repeat_scales_on_all_ranks, marlin_repeat_scales_on_all_ranks,
) )
from vllm.model_executor.parameter import ( from vllm.model_executor.parameter import (
@ -45,12 +49,14 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
group_size: int | None = None, group_size: int | None = None,
symmetric: bool | None = True, symmetric: bool | None = True,
actorder: ActivationOrdering | None = None, actorder: ActivationOrdering | None = None,
layer_name: str | None = None,
): ):
self.pack_factor = 32 // num_bits self.pack_factor = 32 // num_bits
self.strategy = strategy self.strategy = strategy
self.symmetric = symmetric self.symmetric = symmetric
self.group_size = -1 if group_size is None else group_size self.group_size = -1 if group_size is None else group_size
self.has_g_idx = actorder == ActivationOrdering.GROUP self.has_g_idx = actorder == ActivationOrdering.GROUP
self.layer_name = layer_name
if self.group_size == -1 and self.strategy != "channel": if self.group_size == -1 and self.strategy != "channel":
raise ValueError( raise ValueError(
@ -108,6 +114,11 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
logger.info("Using %s for CompressedTensorsWNA16", kernel_type.__name__) logger.info("Using %s for CompressedTensorsWNA16", kernel_type.__name__)
self._kernel_backends_being_used.add(kernel_type.__name__) self._kernel_backends_being_used.add(kernel_type.__name__)
if isinstance(kernel_type, MarlinLinearKernel):
input_dtype = get_marlin_input_dtype(self.layer_name)
if input_dtype is not None:
mp_linear_kernel_config.act_type = input_dtype
# If group_size is -1, we are in channelwise case. # If group_size is -1, we are in channelwise case.
group_size = self.group_size if self.group_size != -1 else input_size group_size = self.group_size if self.group_size != -1 else input_size
row_parallel = input_size != input_size_per_partition row_parallel = input_size != input_size_per_partition

View File

@ -69,6 +69,9 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
process_fp8_weight_tensor_strategy, process_fp8_weight_tensor_strategy,
validate_fp8_block_shape, validate_fp8_block_shape,
) )
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
get_marlin_input_dtype,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear, apply_fp8_marlin_linear,
prepare_fp8_layer_for_marlin, prepare_fp8_layer_for_marlin,
@ -316,7 +319,9 @@ class Fp8Config(QuantizationConfig):
fused_mapping=self.packed_modules_mapping, fused_mapping=self.packed_modules_mapping,
): ):
return UnquantizedLinearMethod() return UnquantizedLinearMethod()
return Fp8LinearMethod(self) quant_method = Fp8LinearMethod(self)
quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
return quant_method
elif isinstance(layer, FusedMoE): elif isinstance(layer, FusedMoE):
if is_layer_skipped( if is_layer_skipped(
prefix=prefix, prefix=prefix,
@ -324,7 +329,9 @@ class Fp8Config(QuantizationConfig):
fused_mapping=self.packed_modules_mapping, fused_mapping=self.packed_modules_mapping,
): ):
return UnquantizedFusedMoEMethod(layer.moe_config) return UnquantizedFusedMoEMethod(layer.moe_config)
return Fp8MoEMethod(self, layer) moe_quant_method = Fp8MoEMethod(self, layer)
moe_quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
return moe_quant_method
elif isinstance(layer, Attention): elif isinstance(layer, Attention):
return Fp8KVCacheMethod(self) return Fp8KVCacheMethod(self)
return None return None
@ -375,6 +382,7 @@ class Fp8LinearMethod(LinearMethodBase):
# For GPUs that lack FP8 hardware support, we can leverage the Marlin # For GPUs that lack FP8 hardware support, we can leverage the Marlin
# kernel for fast weight-only FP8 quantization # kernel for fast weight-only FP8 quantization
self.marlin_input_dtype = None
self.use_marlin = ( self.use_marlin = (
not current_platform.has_device_capability(89) not current_platform.has_device_capability(89)
or envs.VLLM_TEST_FORCE_FP8_MARLIN or envs.VLLM_TEST_FORCE_FP8_MARLIN
@ -552,7 +560,9 @@ class Fp8LinearMethod(LinearMethodBase):
) )
if self.use_marlin: if self.use_marlin:
prepare_fp8_layer_for_marlin(layer, size_k_first) prepare_fp8_layer_for_marlin(
layer, size_k_first, input_dtype=self.marlin_input_dtype
)
# Activations not quantized for marlin. # Activations not quantized for marlin.
del layer.input_scale del layer.input_scale
return return
@ -610,6 +620,7 @@ class Fp8LinearMethod(LinearMethodBase):
workspace=layer.workspace, workspace=layer.workspace,
size_n=layer.output_size_per_partition, size_n=layer.output_size_per_partition,
size_k=layer.input_size_per_partition, size_k=layer.input_size_per_partition,
input_dtype=self.marlin_input_dtype,
bias=bias, bias=bias,
) )
@ -657,6 +668,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self.block_quant, layer.moe_parallel_config self.block_quant, layer.moe_parallel_config
) )
self.marlin_input_dtype = None
self.use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN self.use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN
self.flashinfer_moe_backend: FlashinferMoeBackend | None = None self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM: if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
@ -1031,7 +1043,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer.w13_weight.data = w13_weight.data layer.w13_weight.data = w13_weight.data
if self.use_marlin: if self.use_marlin:
prepare_moe_fp8_layer_for_marlin(layer, False) prepare_moe_fp8_layer_for_marlin(
layer, False, input_dtype=self.marlin_input_dtype
)
# Activations not quantized for marlin. # Activations not quantized for marlin.
del layer.w13_input_scale del layer.w13_input_scale
del layer.w2_input_scale del layer.w2_input_scale
@ -1270,6 +1284,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map, expert_map=expert_map,
input_dtype=self.marlin_input_dtype,
workspace=layer.workspace, workspace=layer.workspace,
) )
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:

View File

@ -41,6 +41,8 @@ from vllm.model_executor.layers.quantization.utils.gptq_utils import (
from vllm.model_executor.layers.quantization.utils.marlin_utils import ( from vllm.model_executor.layers.quantization.utils.marlin_utils import (
check_marlin_supported, check_marlin_supported,
check_moe_marlin_supports_layer, check_moe_marlin_supports_layer,
get_marlin_input_dtype,
marlin_act_int8_process_scales,
marlin_make_workspace_new, marlin_make_workspace_new,
marlin_moe_permute_scales, marlin_moe_permute_scales,
marlin_permute_bias, marlin_permute_bias,
@ -251,8 +253,21 @@ class GPTQMarlinConfig(QuantizationConfig):
return MoeWNA16Config.from_config(self.full_config).get_quant_method( return MoeWNA16Config.from_config(self.full_config).get_quant_method(
layer, prefix layer, prefix
) )
return get_moe_quant_method(self, layer, prefix, GPTQMarlinMoEMethod) moe_quant_method = get_moe_quant_method(
return get_linear_quant_method(self, layer, prefix, GPTQMarlinLinearMethod) self, layer, prefix, GPTQMarlinMoEMethod
)
if moe_quant_method is None:
return None
moe_quant_method.input_dtype = get_marlin_input_dtype(prefix)
return moe_quant_method
quant_method = get_linear_quant_method(
self, layer, prefix, GPTQMarlinLinearMethod
)
if quant_method is None:
return None
quant_method.input_dtype = get_marlin_input_dtype(prefix)
return quant_method
@classmethod @classmethod
def is_gptq_marlin_compatible(cls, quant_config: dict[str, Any]): def is_gptq_marlin_compatible(cls, quant_config: dict[str, Any]):
@ -319,6 +334,8 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
def __init__(self, quant_config: GPTQMarlinConfig) -> None: def __init__(self, quant_config: GPTQMarlinConfig) -> None:
self.quant_config = quant_config self.quant_config = quant_config
self.input_dtype = None
self.quant_type = self.quant_config.quant_type
# Verify supported on platform. # Verify supported on platform.
verify_marlin_supported( verify_marlin_supported(
@ -339,6 +356,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
output_size_per_partition = sum(output_partition_sizes) output_size_per_partition = sum(output_partition_sizes)
is_row_parallel = input_size != input_size_per_partition is_row_parallel = input_size != input_size_per_partition
weight_loader = extra_weight_attrs.get("weight_loader") weight_loader = extra_weight_attrs.get("weight_loader")
input_dtype = self.input_dtype
mp_linear_kernel_config = MPLinearLayerConfig( mp_linear_kernel_config = MPLinearLayerConfig(
full_weight_shape=(input_size, output_size), full_weight_shape=(input_size, output_size),
@ -347,7 +365,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
output_size_per_partition, output_size_per_partition,
), ),
weight_type=self.quant_config.quant_type, weight_type=self.quant_config.quant_type,
act_type=params_dtype, act_type=params_dtype if input_dtype is None else input_dtype,
group_size=self.quant_config.group_size, group_size=self.quant_config.group_size,
zero_points=False, zero_points=False,
has_g_idx=self.quant_config.desc_act, has_g_idx=self.quant_config.desc_act,
@ -482,6 +500,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
self.quant_type = scalar_types.uint8b128 self.quant_type = scalar_types.uint8b128
else: else:
raise ValueError("GPTQMarlinMoEMethod only supports int4 and int8 now.") raise ValueError("GPTQMarlinMoEMethod only supports int4 and int8 now.")
self.input_dtype = None
self.use_marlin = True self.use_marlin = True
def create_weights( def create_weights(
@ -493,6 +512,14 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
params_dtype: torch.dtype, params_dtype: torch.dtype,
**extra_weight_attrs, **extra_weight_attrs,
): ):
layer.input_dtype = self.input_dtype
is_a_8bit = self.input_dtype is not None and self.input_dtype.itemsize == 1
if is_a_8bit:
assert self.quant_type == scalar_types.uint4b8, (
"W8A8-INT8 is not supported by marlin kernel."
)
intermediate_size_full = extra_weight_attrs.pop("intermediate_size_full") intermediate_size_full = extra_weight_attrs.pop("intermediate_size_full")
self.is_k_full = (not self.quant_config.desc_act) or ( self.is_k_full = (not self.quant_config.desc_act) or (
@ -513,6 +540,9 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
scales_size2 = 1 scales_size2 = 1
strategy = FusedMoeWeightScaleSupported.CHANNEL.value strategy = FusedMoeWeightScaleSupported.CHANNEL.value
layer.num_groups_w13 = scales_size13
layer.num_groups_w2 = scales_size2
extra_weight_attrs.update({"quant_method": strategy, "is_transposed": True}) extra_weight_attrs.update({"quant_method": strategy, "is_transposed": True})
# Fused gate_up_proj (column parallel) # Fused gate_up_proj (column parallel)
w13_qweight = torch.nn.Parameter( w13_qweight = torch.nn.Parameter(
@ -630,6 +660,19 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
layer.workspace = marlin_make_workspace_new(device, 4) layer.workspace = marlin_make_workspace_new(device, 4)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
is_a_8bit = self.input_dtype is not None and self.input_dtype.itemsize == 1
if is_a_8bit:
assert self.quant_type == scalar_types.uint4b8, (
"W8A8-INT8 is not supported by marlin kernel."
)
if self.input_dtype == torch.float8_e4m3fn:
ops.marlin_int4_fp8_preprocess(layer.w13_qweight, inplace=True)
ops.marlin_int4_fp8_preprocess(layer.w2_qweight, inplace=True)
layer.w13_scales.data = layer.w13_scales.data * 512
layer.w2_scales.data = layer.w2_scales.data * 512
# Process act_order # Process act_order
if self.quant_config.desc_act: if self.quant_config.desc_act:
# Get sorting based on g_idx # Get sorting based on g_idx
@ -678,6 +721,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
layer.w13_qweight.shape[1] * self.quant_config.pack_factor, layer.w13_qweight.shape[1] * self.quant_config.pack_factor,
layer.w13_qweight.shape[2], layer.w13_qweight.shape[2],
self.quant_config.quant_type.size_bits, self.quant_config.quant_type.size_bits,
is_a_8bit=is_a_8bit,
) )
replace_parameter(layer, "w13_qweight", marlin_w13_qweight) replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
marlin_w2_qweight = ops.gptq_marlin_moe_repack( marlin_w2_qweight = ops.gptq_marlin_moe_repack(
@ -686,6 +730,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
layer.w2_qweight.shape[1] * self.quant_config.pack_factor, layer.w2_qweight.shape[1] * self.quant_config.pack_factor,
layer.w2_qweight.shape[2], layer.w2_qweight.shape[2],
self.quant_config.quant_type.size_bits, self.quant_config.quant_type.size_bits,
is_a_8bit=is_a_8bit,
) )
replace_parameter(layer, "w2_qweight", marlin_w2_qweight) replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
# Repack scales # Repack scales
@ -694,7 +739,17 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
size_k=layer.intermediate_size_per_partition, size_k=layer.intermediate_size_per_partition,
size_n=layer.w13_scales.shape[2], size_n=layer.w13_scales.shape[2],
group_size=self.quant_config.group_size, group_size=self.quant_config.group_size,
is_a_8bit=is_a_8bit,
) )
if self.input_dtype == torch.int8 and layer.num_groups_w13 > 1:
marlin_w13_scales, w13_input_global_scale = marlin_act_int8_process_scales(
marlin_w13_scales
)
layer.register_parameter(
"w13_input_global_scale",
torch.nn.Parameter(w13_input_global_scale, requires_grad=False),
)
replace_parameter(layer, "w13_scales", marlin_w13_scales) replace_parameter(layer, "w13_scales", marlin_w13_scales)
marlin_w2_scales = marlin_moe_permute_scales( marlin_w2_scales = marlin_moe_permute_scales(
s=layer.w2_scales, s=layer.w2_scales,
@ -706,7 +761,17 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
), ),
size_n=layer.w2_scales.shape[2], size_n=layer.w2_scales.shape[2],
group_size=self.quant_config.group_size, group_size=self.quant_config.group_size,
is_a_8bit=is_a_8bit,
) )
if self.input_dtype == torch.int8 and layer.num_groups_w2 > 1:
marlin_w2_scales, w2_input_global_scale = marlin_act_int8_process_scales(
marlin_w2_scales
)
layer.register_parameter(
"w2_input_global_scale",
torch.nn.Parameter(w2_input_global_scale, requires_grad=False),
)
replace_parameter(layer, "w2_scales", marlin_w2_scales) replace_parameter(layer, "w2_scales", marlin_w2_scales)
if hasattr(layer, "w13_bias") and layer.w13_bias is not None: if hasattr(layer, "w13_bias") and layer.w13_bias is not None:
@ -761,6 +826,8 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
router_logits, router_logits,
topk_weights, topk_weights,
topk_ids, topk_ids,
input_global_scale1=getattr(layer, "w13_input_global_scale", None),
input_global_scale2=getattr(layer, "w2_input_global_scale", None),
quant_type_id=self.quant_type.id, quant_type_id=self.quant_type.id,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
@ -771,4 +838,5 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
sort_indices2=layer.w2_g_idx_sort_indices, sort_indices2=layer.w2_g_idx_sort_indices,
workspace=layer.workspace, workspace=layer.workspace,
is_k_full=self.is_k_full, is_k_full=self.is_k_full,
input_dtype=self.input_dtype,
) )

View File

@ -351,6 +351,7 @@ class HQQMarlinMethod(LinearMethodBase):
bias, bias,
scales, scales,
None, None,
None,
zeros, zeros,
layer.g_idx, layer.g_idx,
layer.g_idx_sort_indices, layer.g_idx_sort_indices,

View File

@ -9,6 +9,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
MARLIN_SUPPORTED_GROUP_SIZES, MARLIN_SUPPORTED_GROUP_SIZES,
apply_gptq_marlin_linear, apply_gptq_marlin_linear,
check_marlin_supports_shape, check_marlin_supports_shape,
marlin_act_int8_process_scales,
marlin_is_k_full, marlin_is_k_full,
marlin_make_empty_g_idx, marlin_make_empty_g_idx,
marlin_make_workspace_new, marlin_make_workspace_new,
@ -21,6 +22,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
) )
from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_ from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
@ -65,6 +67,18 @@ class MarlinLinearKernel(MPLinearKernel):
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
device = getattr(layer, self.w_q_name).device device = getattr(layer, self.w_q_name).device
c = self.config c = self.config
is_a_8bit = c.act_type is not None and c.act_type.itemsize == 1
if is_a_8bit:
assert c.weight_type == scalar_types.uint4b8, (
"W8A8 is not supported by marlin kernel."
)
if c.act_type == torch.float8_e4m3fn:
ops.marlin_int4_fp8_preprocess(getattr(layer, self.w_q_name), inplace=True)
getattr(layer, self.w_s_name).data = (
getattr(layer, self.w_s_name).data * 512
)
row_parallel = c.partition_weight_shape[0] != c.full_weight_shape[0] row_parallel = c.partition_weight_shape[0] != c.full_weight_shape[0]
self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel) self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel)
@ -88,6 +102,7 @@ class MarlinLinearKernel(MPLinearKernel):
size_k=c.partition_weight_shape[0], size_k=c.partition_weight_shape[0],
size_n=c.partition_weight_shape[1], size_n=c.partition_weight_shape[1],
num_bits=c.weight_type.size_bits, num_bits=c.weight_type.size_bits,
is_a_8bit=is_a_8bit,
) )
return x return x
@ -99,7 +114,22 @@ class MarlinLinearKernel(MPLinearKernel):
size_k=c.partition_weight_shape[0], size_k=c.partition_weight_shape[0],
size_n=c.partition_weight_shape[1], size_n=c.partition_weight_shape[1],
group_size=c.group_size, group_size=c.group_size,
is_a_8bit=is_a_8bit,
) )
if c.group_size == -1:
num_groups = 1
else:
num_groups = c.partition_weight_shape[0] // c.group_size
if c.act_type == torch.int8 and num_groups > 1:
x.data, input_global_scale = marlin_act_int8_process_scales(x.data)
layer.register_parameter(
"input_global_scale",
torch.nn.Parameter(input_global_scale, requires_grad=False),
)
else:
layer.input_global_scale = None
return x return x
if c.has_g_idx: if c.has_g_idx:
@ -129,6 +159,7 @@ class MarlinLinearKernel(MPLinearKernel):
size_k=grouped_k, size_k=grouped_k,
size_n=c.partition_weight_shape[1], size_n=c.partition_weight_shape[1],
num_bits=c.weight_type.size_bits, num_bits=c.weight_type.size_bits,
is_a_8bit=is_a_8bit,
), ),
) )
else: else:
@ -150,6 +181,7 @@ class MarlinLinearKernel(MPLinearKernel):
# `process_weights_after_loading` will ensure w_zp and w_gidx are not # `process_weights_after_loading` will ensure w_zp and w_gidx are not
# None for marlin # None for marlin
return apply_gptq_marlin_linear( return apply_gptq_marlin_linear(
input=x, input=x,
weight=w_q, weight=w_q,
@ -162,5 +194,7 @@ class MarlinLinearKernel(MPLinearKernel):
input_size_per_partition=c.partition_weight_shape[0], input_size_per_partition=c.partition_weight_shape[0],
output_size_per_partition=c.partition_weight_shape[1], output_size_per_partition=c.partition_weight_shape[1],
is_k_full=self.is_k_full, is_k_full=self.is_k_full,
input_global_scale=getattr(layer, "input_global_scale", None),
bias=bias, bias=bias,
input_dtype=c.act_type,
) )

View File

@ -55,6 +55,9 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
select_cutlass_fp8_gemm_impl, select_cutlass_fp8_gemm_impl,
swap_w13_to_w31, swap_w13_to_w31,
) )
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
get_marlin_input_dtype,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
apply_fp4_marlin_linear, apply_fp4_marlin_linear,
is_fp4_marlin_supported, is_fp4_marlin_supported,
@ -170,9 +173,15 @@ class ModelOptQuantConfigBase(QuantizationConfig):
# now, the layer is quantized, handle it here # now, the layer is quantized, handle it here
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
return self.LinearMethodCls(self) quant_method = self.LinearMethodCls(self)
if getattr(quant_method, "backend", "") == "marlin":
quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
return quant_method
elif isinstance(layer, FusedMoE): elif isinstance(layer, FusedMoE):
return self.FusedMoEMethodCls(quant_config=self, layer=layer) quant_method = self.FusedMoEMethodCls(quant_config=self, layer=layer)
if getattr(quant_method, "backend", "") == "marlin":
quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
return quant_method
return None return None
@ -898,6 +907,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
def __init__(self, quant_config: ModelOptNvFp4Config) -> None: def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
self.quant_config = quant_config self.quant_config = quant_config
self.marlin_input_dtype = None
self.backend = "none" self.backend = "none"
if envs.VLLM_NVFP4_GEMM_BACKEND is None: if envs.VLLM_NVFP4_GEMM_BACKEND is None:
@ -1065,6 +1075,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
size_n=layer.output_size_per_partition, size_n=layer.output_size_per_partition,
size_k=layer.input_size_per_partition, size_k=layer.input_size_per_partition,
bias=bias, bias=bias,
input_dtype=self.marlin_input_dtype,
) )
output_dtype = x.dtype output_dtype = x.dtype
@ -1124,6 +1135,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported
self.allow_flashinfer = _nvfp4.allow_flashinfer self.allow_flashinfer = _nvfp4.allow_flashinfer
self.use_marlin = _nvfp4.use_marlin self.use_marlin = _nvfp4.use_marlin
self.marlin_input_dtype = None
self.flashinfer_moe_backend = None self.flashinfer_moe_backend = None
if self.allow_flashinfer: if self.allow_flashinfer:
self.flashinfer_moe_backend = get_flashinfer_moe_backend() self.flashinfer_moe_backend = get_flashinfer_moe_backend()
@ -1517,7 +1529,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map, expert_map=expert_map,
workspace=layer.workspace, input_dtype=self.marlin_input_dtype,
) )
elif self.allow_flashinfer: elif self.allow_flashinfer:

View File

@ -38,6 +38,9 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizationConfig,
QuantizeMethodBase, QuantizeMethodBase,
) )
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
get_marlin_input_dtype,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
prepare_moe_fp4_layer_for_marlin, prepare_moe_fp4_layer_for_marlin,
) )
@ -205,7 +208,9 @@ class Mxfp4Config(QuantizationConfig):
if current_platform.is_xpu(): if current_platform.is_xpu():
return IpexMxfp4MoEMethod(layer.moe_config) return IpexMxfp4MoEMethod(layer.moe_config)
else: else:
return Mxfp4MoEMethod(layer.moe_config) quant_method = Mxfp4MoEMethod(layer.moe_config)
quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
return quant_method
elif isinstance(layer, Attention): elif isinstance(layer, Attention):
# TODO: Add support for MXFP4 Attention. # TODO: Add support for MXFP4 Attention.
logger.debug_once( logger.debug_once(
@ -220,6 +225,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
def __init__(self, moe: FusedMoEConfig): def __init__(self, moe: FusedMoEConfig):
super().__init__(moe) super().__init__(moe)
self.mxfp4_backend = get_mxfp4_backend(moe.is_lora_enabled) self.mxfp4_backend = get_mxfp4_backend(moe.is_lora_enabled)
self.marlin_input_dtype = None
self.use_marlin = self.mxfp4_backend == Mxfp4Backend.MARLIN self.use_marlin = self.mxfp4_backend == Mxfp4Backend.MARLIN
self.max_capture_size = ( self.max_capture_size = (
get_current_vllm_config().compilation_config.max_cudagraph_capture_size get_current_vllm_config().compilation_config.max_cudagraph_capture_size
@ -385,7 +392,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
def process_weights_after_loading(self, layer): def process_weights_after_loading(self, layer):
if self.mxfp4_backend == Mxfp4Backend.MARLIN: if self.mxfp4_backend == Mxfp4Backend.MARLIN:
prepare_moe_fp4_layer_for_marlin(layer) prepare_moe_fp4_layer_for_marlin(layer, input_dtype=self.marlin_input_dtype)
elif ( elif (
self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16 or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16
@ -914,6 +921,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
activation=activation, activation=activation,
expert_map=expert_map, expert_map=expert_map,
input_dtype=self.marlin_input_dtype,
) )
assert _can_support_mxfp4( assert _can_support_mxfp4(

View File

@ -9,6 +9,11 @@ import vllm.envs as envs
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.int8_utils import (
per_token_quant_int8,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types from vllm.scalar_type import ScalarType, scalar_types
@ -286,10 +291,10 @@ def get_scale_perms():
def marlin_permute_scales( def marlin_permute_scales(
s: torch.Tensor, size_k: int, size_n: int, group_size: int s: torch.Tensor, size_k: int, size_n: int, group_size: int, is_a_8bit: bool = False
) -> torch.Tensor: ) -> torch.Tensor:
scale_perm, scale_perm_single = get_scale_perms() scale_perm, scale_perm_single = get_scale_perms()
if group_size < size_k and group_size != -1: if group_size < size_k and group_size != -1 and not is_a_8bit:
s = s.reshape((-1, len(scale_perm)))[:, scale_perm] s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
else: else:
s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
@ -305,11 +310,15 @@ def marlin_permute_bias(s: torch.Tensor) -> torch.Tensor:
return s.reshape(*origin_shape).contiguous() return s.reshape(*origin_shape).contiguous()
def marlin_act_int8_process_scales(s: torch.Tensor):
a_scales_scale_factor = 1 / 4096 * s.max().float()
s = s / s.max() * 4096
s = s.round().to(torch.int16).view(s.dtype)
return s, a_scales_scale_factor
def marlin_moe_permute_scales( def marlin_moe_permute_scales(
s: torch.Tensor, s: torch.Tensor, size_k: int, size_n: int, group_size: int, is_a_8bit: bool = False
size_k: int,
size_n: int,
group_size: int,
): ):
num_experts = s.shape[0] num_experts = s.shape[0]
output = torch.empty( output = torch.empty(
@ -319,12 +328,12 @@ def marlin_moe_permute_scales(
) )
for e in range(num_experts): for e in range(num_experts):
output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size) output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size, is_a_8bit)
return output return output
def marlin_zero_points( def marlin_zero_points(
zp: torch.Tensor, size_k: int, size_n: int, num_bits: int zp: torch.Tensor, size_k: int, size_n: int, num_bits: int, is_a_8bit: bool = False
) -> torch.Tensor: ) -> torch.Tensor:
# Permute zero-points in a similar way to scales, but do not use the # Permute zero-points in a similar way to scales, but do not use the
# "single" permutation, since zero-points are applied on every MMA # "single" permutation, since zero-points are applied on every MMA
@ -339,7 +348,8 @@ def marlin_zero_points(
else: else:
raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel() if not is_a_8bit:
zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel()
zp = zp.reshape((-1, size_n)).contiguous() zp = zp.reshape((-1, size_n)).contiguous()
zp = pack_cols(zp, num_bits, size_k, size_n) zp = pack_cols(zp, num_bits, size_k, size_n)
@ -347,7 +357,11 @@ def marlin_zero_points(
def awq_to_marlin_zero_points( def awq_to_marlin_zero_points(
q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int q_zp_packed: torch.Tensor,
size_k: int,
size_n: int,
num_bits: int,
is_a_8bit: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
# AWQ zero-points are quantized and packed on the column dim. # AWQ zero-points are quantized and packed on the column dim.
# In addition, the values are permuted based on dequantizer. # In addition, the values are permuted based on dequantizer.
@ -366,12 +380,16 @@ def awq_to_marlin_zero_points(
q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel() q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel()
q_zp = q_zp.reshape((-1, size_n)).contiguous() q_zp = q_zp.reshape((-1, size_n)).contiguous()
marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits) marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits, is_a_8bit)
return marlin_zp return marlin_zp
def moe_awq_to_marlin_zero_points( def moe_awq_to_marlin_zero_points(
q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int q_zp_packed: torch.Tensor,
size_k: int,
size_n: int,
num_bits: int,
is_a_8bit: bool = False,
): ):
num_experts = q_zp_packed.shape[0] num_experts = q_zp_packed.shape[0]
output = torch.empty( output = torch.empty(
@ -380,7 +398,9 @@ def moe_awq_to_marlin_zero_points(
dtype=q_zp_packed.dtype, dtype=q_zp_packed.dtype,
) )
for e in range(num_experts): for e in range(num_experts):
output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, num_bits) output[e] = awq_to_marlin_zero_points(
q_zp_packed[e], size_k, size_n, num_bits, is_a_8bit
)
return output return output
@ -432,6 +452,48 @@ def should_use_atomic_add_reduce(
return True return True
_quant_fp8_method: QuantFP8 | None = None
def get__quant_fp8_method() -> QuantFP8:
global _quant_fp8_method
if _quant_fp8_method is None:
_quant_fp8_method = QuantFP8(False, GroupShape.PER_TOKEN)
return _quant_fp8_method
def get_marlin_input_dtype(prefix):
if envs.VLLM_MARLIN_INPUT_DTYPE is None:
return
elif envs.VLLM_MARLIN_INPUT_DTYPE.lower() == "int8":
return torch.int8
elif envs.VLLM_MARLIN_INPUT_DTYPE.lower() == "fp8":
if not current_platform.is_device_capability(
89
) and not current_platform.is_device_capability(120):
raise ValueError(
"Marlin W4A8-FP8 only support SM89 or SM120 device "
"(It is slower than Marlin W4A16 on other devices). "
"You can consider using W4A8-INT8 instead"
"(set VLLM_MARLIN_INPUT_DTYPE=int8)."
)
_ = get__quant_fp8_method()
return torch.float8_e4m3fn
else:
return
def marlin_quant_input(x: torch.Tensor, quant_dtype: torch.dtype):
x = x.reshape(-1, x.shape[-1])
if quant_dtype == torch.int8:
return per_token_quant_int8(x)
elif quant_dtype == torch.float8_e4m3fn:
return get__quant_fp8_method()(x)
else:
raise ValueError(f"unsupported quant_dtype {quant_dtype}")
def apply_gptq_marlin_linear( def apply_gptq_marlin_linear(
input: torch.Tensor, input: torch.Tensor,
weight: torch.Tensor, weight: torch.Tensor,
@ -444,8 +506,10 @@ def apply_gptq_marlin_linear(
output_size_per_partition: int, output_size_per_partition: int,
input_size_per_partition: int, input_size_per_partition: int,
is_k_full: bool, is_k_full: bool,
input_global_scale: torch.Tensor | None = None,
bias: torch.Tensor | None = None, bias: torch.Tensor | None = None,
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
input_dtype: torch.dtype | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
reshaped_x = input.reshape(-1, input.shape[-1]) reshaped_x = input.reshape(-1, input.shape[-1])
out_shape = input.shape[:-1] + (output_size_per_partition,) out_shape = input.shape[:-1] + (output_size_per_partition,)
@ -458,12 +522,27 @@ def apply_gptq_marlin_linear(
dtype=input.dtype, dtype=input.dtype,
) )
a_scales = None
if input_dtype == torch.int8:
assert wtype == scalar_types.uint4b8, (
"W8A8-INT8 is not supported by marlin kernel."
)
reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
a_scales = a_scales * input_global_scale
elif input_dtype == torch.float8_e4m3fn:
assert wtype == scalar_types.uint4b8, (
"INT8 weight + FP8 activation is not supported."
)
reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
output = ops.gptq_marlin_gemm( output = ops.gptq_marlin_gemm(
reshaped_x, reshaped_x,
None, None,
weight, weight,
bias, bias,
weight_scale, weight_scale,
a_scales,
None, None,
weight_zp, weight_zp,
g_idx, g_idx,
@ -493,8 +572,10 @@ def apply_awq_marlin_linear(
quant_type: ScalarType, quant_type: ScalarType,
output_size_per_partition: int, output_size_per_partition: int,
input_size_per_partition: int, input_size_per_partition: int,
input_global_scale: torch.Tensor | None = None,
bias: torch.Tensor | None = None, bias: torch.Tensor | None = None,
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
input_dtype: torch.dtype | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
reshaped_x = input.reshape(-1, input.shape[-1]) reshaped_x = input.reshape(-1, input.shape[-1])
out_shape = input.shape[:-1] + (output_size_per_partition,) out_shape = input.shape[:-1] + (output_size_per_partition,)
@ -507,12 +588,20 @@ def apply_awq_marlin_linear(
dtype=input.dtype, dtype=input.dtype,
) )
a_scales = None
if input_dtype == torch.int8:
reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
a_scales = a_scales * input_global_scale
elif input_dtype == torch.float8_e4m3fn:
reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
output = ops.gptq_marlin_gemm( output = ops.gptq_marlin_gemm(
reshaped_x, reshaped_x,
None, None,
weight, weight,
bias, bias,
weight_scale, weight_scale,
a_scales,
None, None,
weight_zp, weight_zp,
g_idx, g_idx,
@ -538,8 +627,10 @@ def apply_rtn_marlin_linear(
quant_type: ScalarType, quant_type: ScalarType,
output_size_per_partition: int, output_size_per_partition: int,
input_size_per_partition: int, input_size_per_partition: int,
input_global_scale: torch.Tensor | None = None,
bias: torch.Tensor | None = None, bias: torch.Tensor | None = None,
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
input_dtype: torch.dtype | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
reshaped_x = input.reshape(-1, input.shape[-1]) reshaped_x = input.reshape(-1, input.shape[-1])
out_shape = input.shape[:-1] + (output_size_per_partition,) out_shape = input.shape[:-1] + (output_size_per_partition,)
@ -552,12 +643,20 @@ def apply_rtn_marlin_linear(
dtype=input.dtype, dtype=input.dtype,
) )
a_scales = None
if input_dtype == torch.int8:
reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
a_scales = a_scales * input_global_scale
elif input_dtype == torch.float8_e4m3fn:
reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
output = ops.gptq_marlin_gemm( output = ops.gptq_marlin_gemm(
reshaped_x, reshaped_x,
None, None,
weight, weight,
bias, bias,
weight_scale, weight_scale,
a_scales,
None, None,
None, None,
None, None,

View File

@ -11,6 +11,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_make_workspace_new, marlin_make_workspace_new,
marlin_permute_bias, marlin_permute_bias,
marlin_permute_scales, marlin_permute_scales,
marlin_quant_input,
should_use_atomic_add_reduce, should_use_atomic_add_reduce,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
@ -37,12 +38,6 @@ def nvfp4_marlin_process_scales(marlin_scales):
# convert to half first, we would convert to fp8 later # convert to half first, we would convert to fp8 later
marlin_scales = marlin_scales.to(torch.half) marlin_scales = marlin_scales.to(torch.half)
# 8 is the number of scale number using by one thread
marlin_scales = marlin_scales.view(marlin_scales.size(0) // 2, 2, -1, 8)
marlin_scales = marlin_scales.permute(0, 2, 1, 3).reshape(
marlin_scales.size(0) * 2, -1
)
# fit the layout of fp8 dequantization # fit the layout of fp8 dequantization
marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view( marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view(
marlin_scales.size(0), -1 marlin_scales.size(0), -1
@ -62,18 +57,20 @@ def nvfp4_marlin_process_scales(marlin_scales):
return marlin_scales return marlin_scales
def mxfp4_marlin_process_scales(marlin_scales): def mxfp4_marlin_process_scales(marlin_scales, input_dtype=None):
# 8 is the number of scale number using by one thread
marlin_scales = marlin_scales.view(marlin_scales.size(0) // 2, 2, -1, 8)
marlin_scales = marlin_scales.permute(0, 2, 1, 3).reshape(
marlin_scales.size(0) * 2, -1
)
# fit the layout of fp8 dequantization # fit the layout of fp8 dequantization
marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view( if input_dtype is None or input_dtype.itemsize == 2:
marlin_scales.size(0), -1 marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view(
) marlin_scales.size(0), -1
)
marlin_scales = marlin_scales.to(torch.float8_e8m0fnu) marlin_scales = marlin_scales.to(torch.float8_e8m0fnu)
if input_dtype == torch.float8_e4m3fn:
marlin_scales = marlin_scales.view(torch.uint8)
assert marlin_scales.max() <= 249
# exponent_bias (fp4->fp8) = 2 ** 3 - 2 ** 1 = 6
marlin_scales = marlin_scales + 6
marlin_scales = marlin_scales.view(torch.float8_e8m0fnu)
return marlin_scales return marlin_scales
@ -99,6 +96,7 @@ def apply_fp4_marlin_linear(
size_n: int, size_n: int,
size_k: int, size_k: int,
bias: torch.Tensor | None = None, bias: torch.Tensor | None = None,
input_dtype: torch.dtype | None = None,
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
) -> torch.Tensor: ) -> torch.Tensor:
# For GPUs that lack FP4 hardware support, we can leverage the # For GPUs that lack FP4 hardware support, we can leverage the
@ -111,12 +109,24 @@ def apply_fp4_marlin_linear(
m=reshaped_x.size(0), n=size_n, k=size_k, device=input.device, dtype=input.dtype m=reshaped_x.size(0), n=size_n, k=size_k, device=input.device, dtype=input.dtype
) )
inputs = reshaped_x
a_scales = None
is_nvfp4 = weight_scale_2 is not None
if input_dtype is not None and input_dtype.itemsize == 1:
if is_nvfp4:
raise RuntimeError("NVFP4 weight + INT8/FP8 activation is not supported.")
elif input_dtype != torch.float8_e4m3fn:
raise RuntimeError("MXFP4 weight + INT8 activation is not supported.")
inputs, a_scales = marlin_quant_input(inputs, torch.float8_e4m3fn)
output = ops.gptq_marlin_gemm( output = ops.gptq_marlin_gemm(
a=reshaped_x, a=inputs,
c=None, c=None,
b_q_weight=weight, b_q_weight=weight,
b_bias=bias, b_bias=bias,
b_scales=weight_scale, b_scales=weight_scale,
a_scales=a_scales,
global_scale=weight_scale_2, global_scale=weight_scale_2,
b_zeros=None, b_zeros=None,
g_idx=None, g_idx=None,
@ -133,7 +143,9 @@ def apply_fp4_marlin_linear(
return output.reshape(out_shape) return output.reshape(out_shape)
def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: def prepare_fp4_layer_for_marlin(
layer: torch.nn.Module, input_dtype: torch.dtype | None = None
) -> None:
logger.warning_once( logger.warning_once(
"Your GPU does not have native support for FP4 computation but " "Your GPU does not have native support for FP4 computation but "
"FP4 quantization is being used. Weight-only FP4 compression will " "FP4 quantization is being used. Weight-only FP4 compression will "
@ -160,12 +172,14 @@ def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
perm = torch.empty(0, dtype=torch.int, device=device) perm = torch.empty(0, dtype=torch.int, device=device)
qweight = layer.weight.view(torch.int32).T.contiguous() qweight = layer.weight.view(torch.int32).T.contiguous()
is_a_8bit = input_dtype is not None and input_dtype.itemsize == 1
marlin_qweight = ops.gptq_marlin_repack( marlin_qweight = ops.gptq_marlin_repack(
b_q_weight=qweight, b_q_weight=qweight,
perm=perm, perm=perm,
size_k=part_size_k, size_k=part_size_k,
size_n=part_size_n, size_n=part_size_n,
num_bits=4, num_bits=4,
is_a_8bit=is_a_8bit,
) )
layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False)
@ -178,7 +192,11 @@ def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
weight_scale = weight_scale.to(param_dtype) weight_scale = weight_scale.to(param_dtype)
weight_scale = marlin_permute_scales( weight_scale = marlin_permute_scales(
s=weight_scale, size_k=part_size_k, size_n=part_size_n, group_size=group_size s=weight_scale,
size_k=part_size_k,
size_n=part_size_n,
group_size=group_size,
is_a_8bit=is_a_8bit,
) )
if is_nvfp4: if is_nvfp4:
@ -189,7 +207,9 @@ def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
weight_scale_2 = nvfp4_marlin_process_global_scale(weight_scale_2) weight_scale_2 = nvfp4_marlin_process_global_scale(weight_scale_2)
layer.weight_scale_2 = torch.nn.Parameter(weight_scale_2, requires_grad=False) layer.weight_scale_2 = torch.nn.Parameter(weight_scale_2, requires_grad=False)
else: else:
weight_scale = mxfp4_marlin_process_scales(weight_scale) weight_scale = mxfp4_marlin_process_scales(
weight_scale, input_dtype=input_dtype
)
layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
if hasattr(layer, "bias") and layer.bias is not None: if hasattr(layer, "bias") and layer.bias is not None:
@ -200,7 +220,9 @@ def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
return return
def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: def prepare_moe_fp4_layer_for_marlin(
layer: torch.nn.Module, input_dtype: torch.dtype | None = None
) -> None:
logger.warning_once( logger.warning_once(
"Your GPU does not have native support for FP4 computation but " "Your GPU does not have native support for FP4 computation but "
"FP4 quantization is being used. Weight-only FP4 compression will " "FP4 quantization is being used. Weight-only FP4 compression will "
@ -220,6 +242,7 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
param_dtype = layer.params_dtype param_dtype = layer.params_dtype
layer.workspace = marlin_make_workspace_new(device, 4) layer.workspace = marlin_make_workspace_new(device, 4)
perm = torch.empty(0, dtype=torch.int, device=device) perm = torch.empty(0, dtype=torch.int, device=device)
is_a_8bit = input_dtype is not None and input_dtype.itemsize == 1
# WEIGHT # WEIGHT
# Repack weights to marlin format # Repack weights to marlin format
@ -237,7 +260,12 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
qweight = weight[i].view(torch.int32).T.contiguous() qweight = weight[i].view(torch.int32).T.contiguous()
marlin_qweight = ops.gptq_marlin_repack( marlin_qweight = ops.gptq_marlin_repack(
b_q_weight=qweight, perm=perm, size_k=size_k, size_n=size_n, num_bits=4 b_q_weight=qweight,
perm=perm,
size_k=size_k,
size_n=size_n,
num_bits=4,
is_a_8bit=is_a_8bit,
) )
tensor_list.append(marlin_qweight) tensor_list.append(marlin_qweight)
@ -266,12 +294,18 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
scale = scales[i].T scale = scales[i].T
marlin_scales = marlin_permute_scales( marlin_scales = marlin_permute_scales(
s=scale, size_k=size_k, size_n=size_n, group_size=group_size s=scale,
size_k=size_k,
size_n=size_n,
group_size=group_size,
is_a_8bit=is_a_8bit,
) )
if is_nvfp4: if is_nvfp4:
marlin_scales = nvfp4_marlin_process_scales(marlin_scales) marlin_scales = nvfp4_marlin_process_scales(marlin_scales)
else: else:
marlin_scales = mxfp4_marlin_process_scales(marlin_scales) marlin_scales = mxfp4_marlin_process_scales(
marlin_scales, input_dtype=input_dtype
)
tensor_list.append(marlin_scales) tensor_list.append(marlin_scales)
scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
@ -301,7 +335,10 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
setattr(layer, name, bias) setattr(layer, name, bias)
def rand_marlin_weight_nvfp4_like(weight, group_size): def rand_marlin_weight_nvfp4_like(weight, group_size, input_dtype=None):
is_a_8bit = input_dtype is not None and input_dtype.itemsize == 1
assert not is_a_8bit, "NVFP4 weight + INT8/FP8 activation is not supported."
assert group_size > 0 assert group_size > 0
size_n, size_k = weight.shape size_n, size_k = weight.shape
device = weight.device device = weight.device
@ -337,10 +374,15 @@ def rand_marlin_weight_nvfp4_like(weight, group_size):
size_k=size_k, size_k=size_k,
size_n=size_n, size_n=size_n,
num_bits=4, num_bits=4,
is_a_8bit=is_a_8bit,
) )
marlin_scales = marlin_permute_scales( marlin_scales = marlin_permute_scales(
s=scales.T.to(weight.dtype), size_k=size_k, size_n=size_n, group_size=group_size s=scales.T.to(weight.dtype),
size_k=size_k,
size_n=size_n,
group_size=group_size,
is_a_8bit=is_a_8bit,
) )
marlin_scales = nvfp4_marlin_process_scales(marlin_scales) marlin_scales = nvfp4_marlin_process_scales(marlin_scales)
@ -349,14 +391,20 @@ def rand_marlin_weight_nvfp4_like(weight, group_size):
return weight_ref.T, marlin_qweight, marlin_scales, global_scale return weight_ref.T, marlin_qweight, marlin_scales, global_scale
def rand_marlin_weight_mxfp4_like(weight, group_size): def rand_marlin_weight_mxfp4_like(weight, group_size, input_dtype=None):
is_a_8bit = input_dtype is not None and input_dtype.itemsize == 1
if is_a_8bit:
assert input_dtype == torch.float8_e4m3fn, (
"MXFP4 weight + INT8 activation is not supported."
)
assert group_size > 0 assert group_size > 0
size_n, size_k = weight.shape size_n, size_k = weight.shape
device = weight.device device = weight.device
scales = torch.randint( scales = torch.randint(
100, 110,
125, 120,
(size_n, size_k // group_size), (size_n, size_k // group_size),
dtype=torch.uint8, dtype=torch.uint8,
device=weight.device, device=weight.device,
@ -380,18 +428,25 @@ def rand_marlin_weight_mxfp4_like(weight, group_size):
).view(size_n, size_k) ).view(size_n, size_k)
weight_ref = weight_ref * scales.repeat_interleave(group_size, 1).to(weight.dtype) weight_ref = weight_ref * scales.repeat_interleave(group_size, 1).to(weight.dtype)
perm = torch.empty(0, dtype=torch.int, device=device)
fp4_weight = fp4_weight.view(torch.int32).T.contiguous()
marlin_qweight = ops.gptq_marlin_repack( marlin_qweight = ops.gptq_marlin_repack(
b_q_weight=fp4_weight.view(torch.int32).T.contiguous(), b_q_weight=fp4_weight,
perm=torch.empty(0, dtype=torch.int, device=device), perm=perm,
size_k=size_k, size_k=size_k,
size_n=size_n, size_n=size_n,
num_bits=4, num_bits=4,
is_a_8bit=is_a_8bit,
) )
marlin_scales = marlin_permute_scales( marlin_scales = marlin_permute_scales(
s=scales.T.to(weight.dtype), size_k=size_k, size_n=size_n, group_size=group_size s=scales.T.to(weight.dtype),
size_k=size_k,
size_n=size_n,
group_size=group_size,
is_a_8bit=is_a_8bit,
) )
marlin_scales = mxfp4_marlin_process_scales(marlin_scales) marlin_scales = mxfp4_marlin_process_scales(marlin_scales, input_dtype=input_dtype)
return weight_ref.T, marlin_qweight, marlin_scales.to(torch.float8_e8m0fnu) return weight_ref.T, marlin_qweight, marlin_scales.to(torch.float8_e8m0fnu)

View File

@ -11,6 +11,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_make_workspace_new, marlin_make_workspace_new,
marlin_permute_bias, marlin_permute_bias,
marlin_permute_scales, marlin_permute_scales,
marlin_quant_input,
should_use_atomic_add_reduce, should_use_atomic_add_reduce,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
@ -45,6 +46,7 @@ def apply_fp8_marlin_linear(
size_n: int, size_n: int,
size_k: int, size_k: int,
bias: torch.Tensor | None, bias: torch.Tensor | None,
input_dtype: torch.dtype | None = None,
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
) -> torch.Tensor: ) -> torch.Tensor:
# For GPUs that lack FP8 hardware support, we can leverage the # For GPUs that lack FP8 hardware support, we can leverage the
@ -57,12 +59,21 @@ def apply_fp8_marlin_linear(
m=reshaped_x.size(0), n=size_n, k=size_k, device=input.device, dtype=input.dtype m=reshaped_x.size(0), n=size_n, k=size_k, device=input.device, dtype=input.dtype
) )
inputs = reshaped_x
a_scales = None
if input_dtype is not None and input_dtype.itemsize == 1:
if input_dtype != torch.float8_e4m3fn:
raise RuntimeError("FP8 weight + INT8 activation is not supported.")
inputs, a_scales = marlin_quant_input(inputs, torch.float8_e4m3fn)
output = ops.gptq_marlin_gemm( output = ops.gptq_marlin_gemm(
a=reshaped_x, a=reshaped_x,
c=None, c=None,
b_q_weight=weight, b_q_weight=weight,
b_bias=bias, b_bias=bias,
b_scales=weight_scale, b_scales=weight_scale,
a_scales=a_scales,
global_scale=None, global_scale=None,
b_zeros=None, b_zeros=None,
g_idx=None, g_idx=None,
@ -80,7 +91,9 @@ def apply_fp8_marlin_linear(
def prepare_fp8_layer_for_marlin( def prepare_fp8_layer_for_marlin(
layer: torch.nn.Module, size_k_first: bool = True layer: torch.nn.Module,
size_k_first: bool = True,
input_dtype: torch.dtype | None = None,
) -> None: ) -> None:
logger.warning_once( logger.warning_once(
"Your GPU does not have native support for FP8 computation but " "Your GPU does not have native support for FP8 computation but "
@ -162,7 +175,8 @@ def prepare_fp8_layer_for_marlin(
marlin_scales = marlin_permute_scales( marlin_scales = marlin_permute_scales(
s=scales, size_k=part_size_k, size_n=part_size_n, group_size=group_size s=scales, size_k=part_size_k, size_n=part_size_n, group_size=group_size
) )
marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales) if input_dtype != torch.float8_e4m3fn:
marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales)
layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False) layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False)
if hasattr(layer, "bias") and layer.bias is not None: if hasattr(layer, "bias") and layer.bias is not None:
@ -172,7 +186,9 @@ def prepare_fp8_layer_for_marlin(
def prepare_moe_fp8_layer_for_marlin( def prepare_moe_fp8_layer_for_marlin(
layer: torch.nn.Module, size_k_first: bool = True layer: torch.nn.Module,
size_k_first: bool = True,
input_dtype: torch.dtype | None = None,
) -> None: ) -> None:
logger.warning_once( logger.warning_once(
"Your GPU does not have native support for FP8 computation but " "Your GPU does not have native support for FP8 computation but "
@ -278,7 +294,8 @@ def prepare_moe_fp8_layer_for_marlin(
tensor_list.append(marlin_scales) tensor_list.append(marlin_scales)
scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
scales = fp8_fused_exponent_bias_into_scales(scales) if input_dtype != torch.float8_e4m3fn:
scales = fp8_fused_exponent_bias_into_scales(scales)
scales = torch.nn.Parameter(scales, requires_grad=False) scales = torch.nn.Parameter(scales, requires_grad=False)
setattr(layer, name + "_weight_scale", scales) setattr(layer, name + "_weight_scale", scales)
@ -318,7 +335,11 @@ def pack_fp8_to_int32(
return int32_tensor.T.contiguous() if size_k_first else int32_tensor return int32_tensor.T.contiguous() if size_k_first else int32_tensor
def marlin_quant_fp8_torch(weight, group_size): def marlin_quant_fp8_torch(weight, group_size, input_dtype=None):
is_a_8bit = input_dtype is not None and input_dtype.itemsize == 1
if is_a_8bit:
assert input_dtype == torch.float8_e4m3fn
size_n, size_k = weight.shape size_n, size_k = weight.shape
device = weight.device device = weight.device
@ -334,16 +355,22 @@ def marlin_quant_fp8_torch(weight, group_size):
weight_ref = fp8_weight.to(weight.dtype) * repeated_scales weight_ref = fp8_weight.to(weight.dtype) * repeated_scales
packed_weight = pack_fp8_to_int32(fp8_weight, False).T.contiguous() packed_weight = pack_fp8_to_int32(fp8_weight, False).T.contiguous()
perm = torch.empty(0, dtype=torch.int, device=device)
marlin_qweight = ops.gptq_marlin_repack( marlin_qweight = ops.gptq_marlin_repack(
b_q_weight=packed_weight, b_q_weight=packed_weight,
perm=torch.empty(0, dtype=torch.int, device=device), perm=perm,
size_k=size_k, size_k=size_k,
size_n=size_n, size_n=size_n,
num_bits=8, num_bits=8,
is_a_8bit=is_a_8bit,
) )
marlin_scales = marlin_permute_scales( marlin_scales = marlin_permute_scales(
s=scales.T, size_k=size_k, size_n=size_n, group_size=group_size s=scales.T,
size_k=size_k,
size_n=size_n,
group_size=group_size,
is_a_8bit=is_a_8bit,
) )
marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales) marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales)

View File

@ -5,7 +5,8 @@
import numpy as np import numpy as np
import torch import torch
from vllm.scalar_type import ScalarType from vllm import _custom_ops as ops
from vllm.scalar_type import ScalarType, scalar_types
from .marlin_utils import GPTQ_MARLIN_TILE, marlin_permute_scales, marlin_zero_points from .marlin_utils import GPTQ_MARLIN_TILE, marlin_permute_scales, marlin_zero_points
from .quant_utils import ( from .quant_utils import (
@ -29,13 +30,19 @@ class MarlinWorkspace:
self.scratch = torch.zeros(max_workspace_size, dtype=torch.int, device="cuda") self.scratch = torch.zeros(max_workspace_size, dtype=torch.int, device="cuda")
def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE): def marlin_permute_weights(
q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE, is_a_8bit=False
):
assert q_w.shape == (size_k, size_n) assert q_w.shape == (size_k, size_n)
assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}" assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}"
assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}" assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}"
# Permute weights to 16x64 marlin tiles if is_a_8bit:
q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile)) # Permute weights to 32x32 marlin tiles
q_w = q_w.reshape((size_k // (tile * 2), tile * 2, size_n // tile, tile))
else:
# Permute weights to 16x64 marlin tiles
q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile))
q_w = q_w.permute((0, 2, 1, 3)) q_w = q_w.permute((0, 2, 1, 3))
q_w = q_w.reshape((size_k // tile, size_n * tile)) q_w = q_w.reshape((size_k // tile, size_n * tile))
@ -44,9 +51,9 @@ def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE):
return q_w return q_w
def marlin_weights(q_w, size_k, size_n, num_bits, perm): def marlin_weights(q_w, size_k, size_n, num_bits, perm, is_a_8bit=False):
# Permute # Permute
q_w = marlin_permute_weights(q_w, size_k, size_n, perm) q_w = marlin_permute_weights(q_w, size_k, size_n, perm, is_a_8bit=is_a_8bit)
# Pack # Pack
pack_factor = get_pack_factor(num_bits) pack_factor = get_pack_factor(num_bits)
@ -63,28 +70,53 @@ def marlin_weights(q_w, size_k, size_n, num_bits, perm):
return q_packed return q_packed
def get_weight_perm(num_bits: int): def get_weight_perm(num_bits: int, is_a_8bit: bool = False):
perm_list: list[int] = [] perm_list: list[int] = []
for i in range(32): if is_a_8bit:
perm1: list[int] = [] for i in range(32):
col = i // 4 perm1 = []
for block in [0, 1]: col = i // 4
for row in [ for block in [0, 1]:
2 * (i % 4), for row in [
2 * (i % 4) + 1, 4 * (i % 4),
2 * (i % 4 + 4), 4 * (i % 4) + 1,
2 * (i % 4 + 4) + 1, 4 * (i % 4) + 2,
]: 4 * (i % 4) + 3,
perm1.append(16 * row + col + 8 * block) 4 * (i % 4 + 4),
for j in range(4): 4 * (i % 4 + 4) + 1,
perm_list.extend([p + 256 * j for p in perm1]) 4 * (i % 4 + 4) + 2,
4 * (i % 4 + 4) + 3,
]:
perm1.append(16 * row + col + 8 * block)
for j in range(2):
perm_list.extend([p + 512 * j for p in perm1])
else:
for i in range(32):
perm1 = []
col = i // 4
for block in [0, 1]:
for row in [
2 * (i % 4),
2 * (i % 4) + 1,
2 * (i % 4 + 4),
2 * (i % 4 + 4) + 1,
]:
perm1.append(16 * row + col + 8 * block)
for j in range(4):
perm_list.extend([p + 256 * j for p in perm1])
perm = np.array(perm_list) perm = np.array(perm_list)
if num_bits == 4: if num_bits == 4:
interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7]) if is_a_8bit: # noqa: SIM108
interleave = np.array([0, 4, 1, 5, 2, 6, 3, 7])
else:
interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
elif num_bits == 8: elif num_bits == 8:
interleave = np.array([0, 2, 1, 3]) if is_a_8bit: # noqa: SIM108
interleave = np.array([0, 1, 2, 3])
else:
interleave = np.array([0, 2, 1, 3])
else: else:
raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
@ -99,7 +131,10 @@ def marlin_quantize(
group_size: int, group_size: int,
act_order: bool, act_order: bool,
test_perm: torch.Tensor | None = None, test_perm: torch.Tensor | None = None,
input_dtype: torch.dtype | None = None,
): ):
is_a_8bit = input_dtype is not None and input_dtype.itemsize == 1
size_k, size_n = w.shape size_k, size_n = w.shape
num_bits = quant_type.size_bits num_bits = quant_type.size_bits
@ -120,9 +155,15 @@ def marlin_quantize(
q_w, g_idx, sort_indices = sort_weights(q_w, g_idx) q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)
# Reformat to marlin # Reformat to marlin
weight_perm = get_weight_perm(num_bits) weight_perm = get_weight_perm(num_bits, is_a_8bit)
marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm) marlin_q_w = marlin_weights(
marlin_s = marlin_permute_scales(s, size_k, size_n, group_size) q_w, size_k, size_n, num_bits, weight_perm, is_a_8bit=is_a_8bit
)
marlin_s = marlin_permute_scales(s, size_k, size_n, group_size, is_a_8bit=is_a_8bit)
if input_dtype == torch.float8_e4m3fn and quant_type == scalar_types.uint4b8:
ops.marlin_int4_fp8_preprocess(marlin_q_w, inplace=True)
marlin_s = marlin_s * 512
# Create result # Create result
res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm] res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm]
@ -132,7 +173,13 @@ def marlin_quantize(
return res_list return res_list
def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int): def awq_marlin_quantize(
w: torch.Tensor,
quant_type: ScalarType,
group_size: int,
input_dtype: torch.dtype | None = None,
):
is_a_8bit = input_dtype is not None and input_dtype.itemsize == 1
size_k, size_n = w.shape size_k, size_n = w.shape
# Normalize group_size # Normalize group_size
@ -147,11 +194,22 @@ def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int
# Quantize with zp # Quantize with zp
w_ref, q_w, s, zp = quantize_weights(w, quant_type, group_size, zero_points=True) w_ref, q_w, s, zp = quantize_weights(w, quant_type, group_size, zero_points=True)
if input_dtype == torch.float8_e4m3fn and quant_type == scalar_types.uint4:
repeated_zp = zp.repeat_interleave(group_size, 0)
q_w_old = q_w
q_w = q_w_old - repeated_zp
q_w[q_w < 0] = 15 - q_w_old[q_w < 0]
s = s * 512
# Reformat to marlin # Reformat to marlin
weight_perm = get_weight_perm(quant_type.size_bits) weight_perm = get_weight_perm(quant_type.size_bits, is_a_8bit)
marlin_q_w = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, weight_perm) marlin_q_w = marlin_weights(
marlin_s = marlin_permute_scales(s, size_k, size_n, group_size) q_w, size_k, size_n, quant_type.size_bits, weight_perm, is_a_8bit=is_a_8bit
marlin_zp = marlin_zero_points(zp, num_groups, size_n, quant_type.size_bits) )
marlin_s = marlin_permute_scales(s, size_k, size_n, group_size, is_a_8bit=is_a_8bit)
marlin_zp = marlin_zero_points(
zp, num_groups, size_n, quant_type.size_bits, is_a_8bit=is_a_8bit
)
# Create result # Create result
res_list = [w_ref, marlin_q_w, marlin_s, marlin_zp] res_list = [w_ref, marlin_q_w, marlin_s, marlin_zp]