[Kernel][Quantization] add w4a8 support for marlin kernel (#24722)
Signed-off-by: Jinzhen Lin <jinzhen.ljz@antgroup.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Michael Goin <mgoin@redhat.com>
This commit is contained in:
parent fa59fe417f
commit 1656ad3704

CMakeLists.txt (100 changed lines)
@@ -354,8 +354,17 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# Only build Marlin kernels if we are building for at least some compatible archs.
# Keep building Marlin for 9.0 as there are some group sizes and shapes that
# are not supported by Machete yet.
# 9.0 for latest bf16 atomicAdd PTX
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}")

# marlin arches for fp16 output
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0+PTX" "${CUDA_ARCHS}")
# marlin arches for bf16 output (we need 9.0 for bf16 atomicAdd PTX)
cuda_archs_loose_intersection(MARLIN_BF16_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}")
# marlin arches for fp8 input
# - sm80 doesn't support fp8 computation
# - sm90 and sm100 don't support QMMA.16832.F32.E4M3.E4M3 SAAS instruction
# so we only enable fp8 computation for SM89 (e.g. RTX 40x0) and 12.0 (e.g. RTX 50x0)
cuda_archs_loose_intersection(MARLIN_FP8_ARCHS "8.9;12.0" "${CUDA_ARCHS}")

if (MARLIN_ARCHS)

#
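Editorial note: the FP8 arch restriction in the comments above comes from the generator script later in this commit. As a hedged illustration (the function name supports_fp8_activation is mine, not part of the commit), the arch-string parsing it performs reduces to the following Python:

# Editorial sketch (not part of the commit): decide whether FP8-activation
# kernels can be generated for the requested CUDA archs. Only SM89 and SM120
# run the e4m3 MMA natively; SM90/SM100 emulate it with FP16 MMA, so they see
# no speedup and are excluded, matching the CMake comment above.
def supports_fp8_activation(cuda_archs_str: str) -> bool:
    for arch in cuda_archs_str.split(","):
        arch = arch[: arch.index(".") + 2].replace(".", "")
        if int(arch) in (89, 120):
            return True
    return False

assert supports_fp8_activation("8.0+PTX,8.9,9.0+PTX")
assert not supports_fp8_activation("8.0+PTX,9.0+PTX,10.0")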
@@ -365,16 +374,18 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
set(MARLIN_GEN_SCRIPT
${CMAKE_CURRENT_SOURCE_DIR}/csrc/quantization/gptq_marlin/generate_kernels.py)
file(MD5 ${MARLIN_GEN_SCRIPT} MARLIN_GEN_SCRIPT_HASH)
list(JOIN CUDA_ARCHS "," CUDA_ARCHS_STR)
set(MARLIN_GEN_SCRIPT_HASH_AND_ARCH "${MARLIN_GEN_SCRIPT_HASH}(ARCH:${CUDA_ARCHS_STR})")

message(STATUS "Marlin generation script hash: ${MARLIN_GEN_SCRIPT_HASH}")
message(STATUS "Last run Marlin generate script hash: $CACHE{MARLIN_GEN_SCRIPT_HASH}")
message(STATUS "Marlin generation script hash: ${MARLIN_GEN_SCRIPT_HASH_AND_ARCH}")
message(STATUS "Last run Marlin generate script hash: $CACHE{MARLIN_GEN_SCRIPT_HASH_AND_ARCH}")

if (NOT DEFINED CACHE{MARLIN_GEN_SCRIPT_HASH}
OR NOT $CACHE{MARLIN_GEN_SCRIPT_HASH} STREQUAL ${MARLIN_GEN_SCRIPT_HASH})
if (NOT DEFINED CACHE{MARLIN_GEN_SCRIPT_HASH_AND_ARCH}
OR NOT $CACHE{MARLIN_GEN_SCRIPT_HASH_AND_ARCH} STREQUAL ${MARLIN_GEN_SCRIPT_HASH_AND_ARCH})
execute_process(
COMMAND ${CMAKE_COMMAND} -E env
PYTHONPATH=$PYTHONPATH
${Python_EXECUTABLE} ${MARLIN_GEN_SCRIPT}
${Python_EXECUTABLE} ${MARLIN_GEN_SCRIPT} ${CUDA_ARCHS_STR}
RESULT_VARIABLE marlin_generation_result
OUTPUT_VARIABLE marlin_generation_result
OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/marlin_generation.log
@@ -387,15 +398,15 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"\nCheck the log for details: "
"${CMAKE_CURRENT_BINARY_DIR}/marlin_generation.log")
else()
set(MARLIN_GEN_SCRIPT_HASH ${MARLIN_GEN_SCRIPT_HASH}
CACHE STRING "Last run Marlin generate script hash" FORCE)
set(MARLIN_GEN_SCRIPT_HASH_AND_ARCH ${MARLIN_GEN_SCRIPT_HASH_AND_ARCH}
CACHE STRING "Last run Marlin generate script hash and arch" FORCE)
message(STATUS "Marlin generation completed successfully.")
endif()
else()
message(STATUS "Marlin generation script has not changed, skipping generation.")
endif()

file(GLOB MARLIN_TEMPLATE_KERNEL_SRC "csrc/quantization/gptq_marlin/kernel_*.cu")
file(GLOB MARLIN_TEMPLATE_KERNEL_SRC "csrc/quantization/gptq_marlin/sm80_kernel_*_float16.cu")
set_gencode_flags_for_srcs(
SRCS "${MARLIN_TEMPLATE_KERNEL_SRC}"
CUDA_ARCHS "${MARLIN_ARCHS}")
@@ -403,12 +414,34 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
set_source_files_properties(${MARLIN_TEMPLATE_KERNEL_SRC}
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
endif()

list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC})

file(GLOB MARLIN_TEMPLATE_BF16_KERNEL_SRC "csrc/quantization/gptq_marlin/sm80_kernel_*_bfloat16.cu")
set_gencode_flags_for_srcs(
SRCS "${MARLIN_TEMPLATE_BF16_KERNEL_SRC}"
CUDA_ARCHS "${MARLIN_BF16_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
set_source_files_properties(${MARLIN_TEMPLATE_BF16_KERNEL_SRC}
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
endif()
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_BF16_KERNEL_SRC})

if (MARLIN_FP8_ARCHS)
file(GLOB MARLIN_TEMPLATE_FP8_KERNEL_SRC "csrc/quantization/gptq_marlin/sm89_kernel_*.cu")
set_gencode_flags_for_srcs(
SRCS "${MARLIN_TEMPLATE_FP8_KERNEL_SRC}"
CUDA_ARCHS "${MARLIN_FP8_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
set_source_files_properties(${MARLIN_TEMPLATE_FP8_KERNEL_SRC}
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
endif()
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_FP8_KERNEL_SRC})
endif()

set(MARLIN_SRCS
"csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
"csrc/quantization/gptq_marlin/gptq_marlin.cu"
"csrc/quantization/gptq_marlin/marlin_int4_fp8_preprocess.cu"
"csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
"csrc/quantization/gptq_marlin/awq_marlin_repack.cu")
set_gencode_flags_for_srcs(
@@ -941,8 +974,15 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
CUDA_ARCHS "${CUDA_ARCHS}")

list(APPEND VLLM_MOE_EXT_SRC "${VLLM_MOE_WNA16_SRC}")
# 9.0 for latest bf16 atomicAdd PTX
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}")
# moe marlin arches
# note that we always set `use_atomic_add=False` for moe marlin now,
# so we don't need 9.0 for bf16 atomicAdd PTX
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0+PTX" "${CUDA_ARCHS}")
# moe marlin arches for fp8 input
# - sm80 doesn't support fp8 computation
# - sm90 and sm100 don't support QMMA.16832.F32.E4M3.E4M3 SAAS instruction
# so we only enable fp8 computation for SM89 (e.g. RTX 40x0) and 12.0 (e.g. RTX 50x0)
cuda_archs_loose_intersection(MARLIN_MOE_FP8_ARCHS "8.9;12.0" "${CUDA_ARCHS}")
if (MARLIN_MOE_ARCHS)

#
@@ -952,16 +992,18 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
set(MOE_MARLIN_GEN_SCRIPT
${CMAKE_CURRENT_SOURCE_DIR}/csrc/moe/marlin_moe_wna16/generate_kernels.py)
file(MD5 ${MOE_MARLIN_GEN_SCRIPT} MOE_MARLIN_GEN_SCRIPT_HASH)
list(JOIN CUDA_ARCHS "," CUDA_ARCHS_STR)
set(MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH "${MOE_MARLIN_GEN_SCRIPT_HASH}(ARCH:${CUDA_ARCHS_STR})")

message(STATUS "Marlin MOE generation script hash: ${MOE_MARLIN_GEN_SCRIPT_HASH}")
message(STATUS "Last run Marlin MOE generate script hash: $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH}")
message(STATUS "Marlin MOE generation script hash with arch: ${MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH}")
message(STATUS "Last run Marlin MOE generate script hash with arch: $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH}")

if (NOT DEFINED CACHE{MOE_MARLIN_GEN_SCRIPT_HASH}
OR NOT $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH} STREQUAL ${MOE_MARLIN_GEN_SCRIPT_HASH})
if (NOT DEFINED CACHE{MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH}
OR NOT $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH} STREQUAL ${MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH})
execute_process(
COMMAND ${CMAKE_COMMAND} -E env
PYTHONPATH=$PYTHONPATH
${Python_EXECUTABLE} ${MOE_MARLIN_GEN_SCRIPT}
${Python_EXECUTABLE} ${MOE_MARLIN_GEN_SCRIPT} ${CUDA_ARCHS_STR}
RESULT_VARIABLE moe_marlin_generation_result
OUTPUT_VARIABLE moe_marlin_generation_output
OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log
@@ -974,7 +1016,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"\nCheck the log for details: "
"${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log")
else()
set(MOE_MARLIN_GEN_SCRIPT_HASH ${MOE_MARLIN_GEN_SCRIPT_HASH}
set(MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH ${MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH}
CACHE STRING "Last run Marlin MOE generate script hash" FORCE)
message(STATUS "Marlin MOE generation completed successfully.")
endif()
@@ -982,16 +1024,28 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
message(STATUS "Marlin MOE generation script has not changed, skipping generation.")
endif()

file(GLOB MOE_WNAA16_MARLIN_SRC "csrc/moe/marlin_moe_wna16/*.cu")
file(GLOB MARLIN_MOE_SRC "csrc/moe/marlin_moe_wna16/sm80_kernel_*.cu")
list(APPEND MARLIN_MOE_SRC "csrc/moe/marlin_moe_wna16/ops.cu")
set_gencode_flags_for_srcs(
SRCS "${MOE_WNAA16_MARLIN_SRC}"
SRCS "${MARLIN_MOE_SRC}"
CUDA_ARCHS "${MARLIN_MOE_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
set_source_files_properties(${MOE_WNAA16_MARLIN_SRC}
set_source_files_properties(${MARLIN_MOE_SRC}
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
endif()
list(APPEND VLLM_MOE_EXT_SRC ${MARLIN_MOE_SRC})

list(APPEND VLLM_MOE_EXT_SRC ${MOE_WNAA16_MARLIN_SRC})
if (MARLIN_MOE_FP8_ARCHS)
file(GLOB MARLIN_MOE_FP8_SRC "csrc/moe/marlin_moe_wna16/sm89_kernel_*.cu")
set_gencode_flags_for_srcs(
SRCS "${MARLIN_MOE_FP8_SRC}"
CUDA_ARCHS "${MARLIN_MOE_FP8_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
set_source_files_properties(${MARLIN_MOE_FP8_SRC}
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
endif()
list(APPEND VLLM_MOE_EXT_SRC ${MARLIN_MOE_FP8_SRC})
endif()

message(STATUS "Building Marlin MOE kernels for archs: ${MARLIN_MOE_ARCHS}")
else()
@@ -237,6 +237,7 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable:
b_q_weight=w_q,
b_bias=None,
b_scales=w_s,
a_scales=None,
global_scale=None,
b_zeros=w_zp,
g_idx=g_idx,
@@ -263,7 +263,7 @@ def bench_run(

results.append(
benchmark.Timer(
stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)", # noqa: E501
stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, None, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
@@ -273,7 +273,7 @@ def bench_run(

results.append(
benchmark.Timer(
stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)", # noqa: E501
stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, None, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
csrc/moe/marlin_moe_wna16/.gitignore (vendored, 3 changed lines)
@@ -1 +1,2 @@
kernel_*.cu
sm*_kernel_*.cu
kernel_selector.h
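Editorial note: the widened .gitignore pattern tracks the new file-naming scheme used by the generator script shown below, where FP8-activation kernels get an sm89_ prefix and everything else an sm80_ prefix. A hedged, runnable sketch of that naming rule (kernel_filename is my name for it, not the script's):

# Editorial sketch (not part of the commit): mirror the filename logic of the
# generator script below, which the new sm*_kernel_*.cu ignore pattern matches.
def kernel_filename(a_type: str, b_type: str, c_type: str) -> str:
    prefix = "sm89" if a_type == "kFE4M3fn" else "sm80"
    return f"{prefix}_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu".lower()

print(kernel_filename("kFloat16", "kU4B8", "kFloat16"))   # sm80_kernel_float16_u4b8_float16.cu
print(kernel_filename("kFE4M3fn", "kU4B8", "kBFloat16"))  # sm89_kernel_fe4m3fn_u4b8_bfloat16.cu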
@@ -4,134 +4,282 @@ import glob
import itertools
import os
import subprocess
import sys

import jinja2

FILE_HEAD = """
// auto generated by generate.py
// clang-format off
ARCHS = []
SUPPORT_FP8 = False
for arch in sys.argv[1].split(","):
arch = arch[: arch.index(".") + 2].replace(".", "")
arch = int(arch)
# only SM89 and SM120 fully support
# mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32.
# SM90 and SM100 can use this PTX, but it's simulated
# with FP16 MMA, so it cannot achieve any acceleration.
if arch in [89, 120]:
SUPPORT_FP8 = True

FILE_HEAD_COMMENT = """
// auto generated by generate_kernels.py
// clang-format off
""".lstrip()

FILE_HEAD = (
FILE_HEAD_COMMENT
+ """
#include "kernel.h"
#include "marlin_template.h"

namespace MARLIN_NAMESPACE_NAME {
""".strip()
"""
)

TEMPLATE = (
"template __global__ void Marlin<"
"{{scalar_t}}, "
"{{w_type_id}}, "
"{{a_type_id}}, "
"{{b_type_id}}, "
"{{c_type_id}}, "
"{{s_type_id}}, "
"{{threads}}, "
"{{thread_m_blocks}}, "
"{{thread_n_blocks}}, "
"{{thread_k_blocks}}, "
"{{'true' if m_block_size_8 else 'false'}}, "
"{{m_block_size_8}}, "
"{{stages}}, "
"{{group_blocks}}, "
"{{'true' if is_zp_float else 'false'}}>"
"{{is_zp_float}}>"
"( MARLIN_KERNEL_PARAMS );"
)

# int8 with zero point case (vllm::kU8) is also supported,
# we don't add it to reduce wheel size.
SCALAR_TYPES = [
"vllm::kU4",
"vllm::kU4B8",
"vllm::kU8B128",
"vllm::kFE4M3fn",
"vllm::kFE2M1f",
]
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)]

THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
# group_blocks:
# = 0 : act order case
# = -1 : channelwise quantization
# > 0 : group_size=16*group_blocks
GROUP_BLOCKS = [0, -1, 1, 2, 4, 8]
DTYPES = ["fp16", "bf16"]
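Editorial note: the group_blocks encoding documented above also drives the C++ dispatch later in the diff, where group_blocks = group_size == -1 ? -1 : group_size / 16. A hedged Python restatement of that mapping (to_group_blocks is my name, not the script's):

# Editorial sketch of the group_blocks encoding:
#   -1 -> channelwise quantization, 0 -> act-order, k > 0 -> group_size = 16 * k
def to_group_blocks(group_size: int, has_act_order: bool = False) -> int:
    if has_act_order:
        return 0
    return -1 if group_size == -1 else group_size // 16

assert to_group_blocks(-1) == -1    # channelwise
assert to_group_blocks(128) == 8    # group_size 128
assert to_group_blocks(16) == 1     # NVFP4-style group_size 16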
QUANT_CONFIGS = [
# AWQ-INT4
{
"b_type": "kU4",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [-1, 2, 4, 8],
},
# GPTQ-INT4
{
"b_type": "kU4B8",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [-1, 0, 2, 4, 8],
},
# AWQ-INT8
{
"b_type": "kU8B128",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [-1, 0, 2, 4, 8],
},
# FP8
{
"b_type": "kFE4M3fn",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [-1, 8],
},
# NVFP4
{
"b_type": "kFE2M1f",
"s_type": "kFE4M3fn",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [1],
},
# MXFP4
{
"a_type": ["kBFloat16"],
"b_type": "kFE2M1f",
"s_type": "kFE8M0fnu",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [2],
},
# AWQ-INT4 with INT8 activation
{
"a_type": ["kS8"],
"b_type": "kU4",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": [1, 2, 3, 4],
"group_blocks": [-1, 2, 4, 8],
},
# GPTQ-INT4 with INT8 activation
{
"a_type": ["kS8"],
"b_type": "kU4B8",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": [1, 2, 3, 4],
"group_blocks": [-1, 2, 4, 8],
},
# GPTQ-INT4 with FP8 activation
{
"a_type": ["kFE4M3fn"],
"b_type": "kU4B8",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": [1, 2, 3, 4],
"group_blocks": [-1, 2, 4, 8],
},
# AWQ-INT4 with FP8 activation
{
"a_type": ["kFE4M3fn"],
"b_type": "kU4",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": [1, 2, 3, 4],
"group_blocks": [-1, 2, 4, 8],
},
# MXFP4 with FP8 activation
{
"a_type": ["kFE4M3fn"],
"b_type": "kFE2M1f",
"c_type": ["kBFloat16"],
"s_type": "kFE8M0fnu",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": [1, 2, 3, 4],
"group_blocks": [2],
},
]

def remove_old_kernels():
for filename in glob.glob(os.path.dirname(__file__) + "/kernel_*.cu"):
for filename in glob.glob(os.path.dirname(__file__) + "/*kernel_*.cu"):
subprocess.call(["rm", "-f", filename])

filename = os.path.dirname(__file__) + "/kernel_selector.h"
subprocess.call(["rm", "-f", filename])


def generate_new_kernels():
for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES):
result_dict = {}

for quant_config in QUANT_CONFIGS:
c_types = quant_config.get("c_type", ["kFloat16", "kBFloat16"])
a_types = quant_config.get("a_type", ["kFloat16", "kBFloat16"])
b_type = quant_config["b_type"]
all_group_blocks = quant_config["group_blocks"]
all_m_blocks = quant_config["thread_m_blocks"]
all_thread_configs = quant_config["thread_configs"]

for a_type, c_type in itertools.product(a_types, c_types):
if not SUPPORT_FP8 and a_type == "kFE4M3fn":
continue
if "16" in a_type and "16" in c_type and a_type != c_type:
continue
s_type = quant_config.get("s_type", c_type)
if (a_type, b_type, c_type) not in result_dict:
result_dict[(a_type, b_type, c_type)] = []

for group_blocks, m_blocks, thread_configs in itertools.product(
all_group_blocks, all_m_blocks, all_thread_configs
):
thread_k, thread_n, threads = thread_configs

if threads == 256:
# for small batch (m_blocks == 1),
# we only need (128, 128, 256)
# for large batch (m_blocks > 1),
# we only need (64, 256, 256)
if m_blocks <= 1 and (thread_k, thread_n) != (128, 128):
continue
if m_blocks > 1 and (thread_k, thread_n) != (64, 256):
continue

config = {
"threads": threads,
"s_type": s_type,
"thread_m_blocks": max(m_blocks, 1),
"thread_k_blocks": thread_k // 16,
"thread_n_blocks": thread_n // 16,
"m_block_size_8": "true" if m_blocks == 0.5 else "false",
"stages": "pipe_stages",
"group_blocks": group_blocks,
"is_zp_float": "false",
}

result_dict[(a_type, b_type, c_type)].append(config)

kernel_selector_str = FILE_HEAD_COMMENT

for (a_type, b_type, c_type), config_list in result_dict.items():
all_template_str_list = []

for group_blocks, m_blocks, thread_configs in itertools.product(
GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS
):
# act order case only support gptq-int4 and gptq-int8
if group_blocks == 0 and scalar_type not in [
"vllm::kU4B8",
"vllm::kU8B128",
]:
continue
if thread_configs[2] == 256:
# for small batch (m_blocks == 1), we only need (128, 128, 256)
# for large batch (m_blocks > 1), we only need (64, 256, 256)
if m_blocks <= 1 and thread_configs[0] != 128:
continue
if m_blocks > 1 and thread_configs[0] != 64:
continue

# we only support channelwise quantization and group_size == 128
# for fp8
if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]:
continue
# nvfp4 only supports group_size == 16
# mxfp4 only supports group_size == 32
if scalar_type == "vllm::kFE2M1f" and group_blocks not in [1, 2]:
continue
# other quantization methods don't support group_size = 16
if scalar_type != "vllm::kFE2M1f" and group_blocks == 1:
continue

k_blocks = thread_configs[0] // 16
n_blocks = thread_configs[1] // 16
threads = thread_configs[2]

c_dtype = "half" if dtype == "fp16" else "nv_bfloat16"

if scalar_type == "vllm::kFE2M1f" and group_blocks == 1:
s_type = "vllm::kFE4M3fn"
elif scalar_type == "vllm::kFE2M1f" and group_blocks == 2:
s_type = "vllm::kFE8M0fnu"
if dtype == "fp16":
# we cannot safely dequantize e8m0 to fp16, so skip this
continue
elif dtype == "fp16":
s_type = "vllm::kFloat16"
elif dtype == "bf16":
s_type = "vllm::kBFloat16"

for config in config_list:
s_type = config["s_type"]
template_str = jinja2.Template(TEMPLATE).render(
scalar_t=c_dtype,
w_type_id=scalar_type + ".id()",
s_type_id=s_type + ".id()",
threads=threads,
thread_m_blocks=max(m_blocks, 1),
thread_n_blocks=n_blocks,
thread_k_blocks=k_blocks,
m_block_size_8=m_blocks == 0.5,
stages="pipe_stages",
group_blocks=group_blocks,
is_zp_float=False,
a_type_id=f"vllm::{a_type}.id()",
b_type_id=f"vllm::{b_type}.id()",
c_type_id=f"vllm::{c_type}.id()",
s_type_id=f"vllm::{s_type}.id()",
**config,
)
all_template_str_list.append(template_str)

conditions = [
f"a_type == vllm::{a_type}",
f"b_type == vllm::{b_type}",
f"c_type == vllm::{c_type}",
f"s_type == vllm::{s_type}",
f"threads == {config['threads']}",
f"thread_m_blocks == {config['thread_m_blocks']}",
f"thread_n_blocks == {config['thread_n_blocks']}",
f"thread_k_blocks == {config['thread_k_blocks']}",
f"m_block_size_8 == {config['m_block_size_8']}",
f"group_blocks == {config['group_blocks']}",
f"is_zp_float == {config['is_zp_float']}",
]
conditions = " && ".join(conditions)

if kernel_selector_str == FILE_HEAD_COMMENT:
kernel_selector_str += f"if ({conditions})\n kernel = "
else:
kernel_selector_str += f"else if ({conditions})\n kernel = "

kernel_template2 = (
"Marlin<{{a_type_id}}, {{b_type_id}}, {{c_type_id}}, "
"{{s_type_id}}, {{threads}}, {{thread_m_blocks}}, "
"{{thread_n_blocks}}, {{thread_k_blocks}}, "
"{{m_block_size_8}}, {{stages}}, {{group_blocks}}, "
"{{is_zp_float}}>;"
)

all_template_str_list.append(template_str)
kernel_selector_str += (
jinja2.Template(kernel_template2).render(
a_type_id=f"vllm::{a_type}.id()",
b_type_id=f"vllm::{b_type}.id()",
c_type_id=f"vllm::{c_type}.id()",
s_type_id=f"vllm::{s_type}.id()",
**config,
)
+ "\n"
)

file_content = FILE_HEAD + "\n\n"
file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
filename = f"kernel_{dtype}_{scalar_type[6:].lower()}.cu"
if a_type == "kFE4M3fn":
filename = f"sm89_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
else:
filename = f"sm80_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"

filename = filename.lower()

with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
f.write(file_content)

if not SUPPORT_FP8 and kernel_selector_str != FILE_HEAD_COMMENT:
kernel_selector_str += (
"else if (a_type == vllm::kFE4M3fn)\n"
" TORCH_CHECK(false, "
'"marlin kernel with fp8 activation is not built.");'
)

with open(os.path.join(os.path.dirname(__file__), "kernel_selector.h"), "w") as f:
f.write(kernel_selector_str)


if __name__ == "__main__":
remove_old_kernels()
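Editorial note: each rendered TEMPLATE entry becomes one explicit Marlin instantiation in a generated .cu file, and each selector entry becomes one if/else-if branch in the generated kernel_selector.h. As a hedged, self-contained illustration of what one rendered instantiation looks like (parameter values are illustrative, not copied from a specific generated file):

# Editorial sketch (not part of the commit): render the new TEMPLATE shape with
# example values for a w4a8 kernel (FP8 activations, GPTQ-INT4 weights).
import jinja2

TEMPLATE = (
    "template __global__ void Marlin<"
    "{{a_type_id}}, {{b_type_id}}, {{c_type_id}}, {{s_type_id}}, "
    "{{threads}}, {{thread_m_blocks}}, {{thread_n_blocks}}, {{thread_k_blocks}}, "
    "{{m_block_size_8}}, {{stages}}, {{group_blocks}}, {{is_zp_float}}>"
    "( MARLIN_KERNEL_PARAMS );"
)

print(jinja2.Template(TEMPLATE).render(
    a_type_id="vllm::kFE4M3fn.id()",   # FP8 activations (the new w4a8 path)
    b_type_id="vllm::kU4B8.id()",      # GPTQ-INT4 weights
    c_type_id="vllm::kBFloat16.id()",  # bf16 output
    s_type_id="vllm::kBFloat16.id()",  # weight-scale type
    threads=256, thread_m_blocks=1, thread_n_blocks=8, thread_k_blocks=8,
    m_block_size_8="false", stages="pipe_stages", group_blocks=8,
    is_zp_float="false",
))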
@@ -11,8 +11,9 @@
const int4 *__restrict__ A, const int4 *__restrict__ B, \
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
const int4 *__restrict__ b_bias_ptr, \
const float *__restrict__ a_scales_ptr, \
const int4 *__restrict__ scales_ptr, \
const uint16_t *__restrict__ scale2_ptr, \
const uint16_t *__restrict__ global_scale_ptr, \
const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
const int32_t *__restrict__ sorted_token_ids_ptr, \
const int32_t *__restrict__ expert_ids_ptr, \
@@ -20,12 +21,13 @@
const float *__restrict__ topk_weights_ptr, int top_k, \
bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \
int prob_n, int prob_k, int *locks, bool has_bias, bool use_atomic_add, \
bool use_fp32_reduce, int max_shared_mem
bool use_fp32_reduce

namespace MARLIN_NAMESPACE_NAME {
template <typename scalar_t, // compute dtype, half or nv_float16
const vllm::ScalarTypeId w_type_id, // weight ScalarType id
const vllm::ScalarTypeId s_type_id, // weight scale ScalarType id
template <const vllm::ScalarTypeId a_type_id, // A ScalarType id
const vllm::ScalarTypeId b_type_id, // B ScalarType id
const vllm::ScalarTypeId c_type_id, // C ScalarType id
const vllm::ScalarTypeId s_type_id, // B_SCALE ScalarType id
const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the

File diff suppressed because it is too large.
@@ -37,39 +37,6 @@ __global__ void MarlinDefault(MARLIN_KERNEL_PARAMS){};

using MarlinFuncPtr = void (*)(MARLIN_KERNEL_PARAMS);

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800

template <int moe_block_size>
__global__ void permute_cols_kernel(
int4 const* __restrict__ a_int4_ptr, int const* __restrict__ perm_int_ptr,
int4* __restrict__ out_int4_ptr,
const int32_t* __restrict__ sorted_token_ids_ptr,
const int32_t* __restrict__ expert_ids_ptr,
const int32_t* __restrict__ num_tokens_past_padded_ptr, int size_m,
int size_k, int top_k) {};

} // namespace marlin

torch::Tensor moe_wna16_marlin_gemm(
torch::Tensor& a, std::optional<torch::Tensor> c_or_none,
torch::Tensor& b_q_weight,
std::optional<torch::Tensor> const& b_bias_or_none, torch::Tensor& b_scales,
std::optional<torch::Tensor> const& b_zeros_or_none,
std::optional<torch::Tensor> const& g_idx_or_none,
std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace,
torch::Tensor& sorted_token_ids, torch::Tensor& expert_ids,
torch::Tensor& num_tokens_past_padded, torch::Tensor& topk_weights,
int64_t moe_block_size, int64_t top_k, bool mul_topk_weights, bool is_ep,
vllm::ScalarTypeId const& b_q_type_id, int64_t size_m, int64_t size_n,
int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce,
bool is_zp_float) {
TORCH_CHECK_NOT_IMPLEMENTED(false,
"marlin_gemm(..) requires CUDA_ARCH >= 8.0");
return torch::empty({1, 1});
}

#else

// For a given "a" of size [M,K] performs a permutation of the K columns based
// on the given "perm" indices.
template <int moe_block_size>
@@ -207,7 +174,7 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8,
int thread_m_blocks, int prob_m, int prob_n,
int prob_k, int num_bits, int group_size,
bool has_act_order, bool is_k_full, int has_zp,
int is_zp_float) {
int is_zp_float, bool is_a_8bit) {
int pack_factor = 32 / num_bits;

// Get B size
@@ -217,8 +184,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8,

// shm size for block_sorted_ids/rd_block_sorted_ids/block_topk_weights
// both of them requires tb_m * 4 bytes (tb_m * int32 or tb_m * float32)
int sh_block_meta_size = tb_m * 4;
int sh_a_size = pipe_stages * (tb_m * tb_k) * 2;
int sh_block_meta_size = tb_m * 16;
int sh_a_size = pipe_stages * (tb_m * tb_k) * (is_a_8bit ? 1 : 2);
int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4;
int sh_red_size = tb_m * (tb_n + 8) * 2;
int sh_bias_size = tb_n * 2;
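Editorial note: is_a_8bit matters here because an 8-bit (FP8/INT8) activation tile needs half the shared memory of a 16-bit one. A hedged back-of-the-envelope version of just the buffers visible in this hunk (pipe_stages = 4 is an assumed value; the real function accounts for additional buffers such as scales and zero points):

# Editorial sketch (not part of the commit): rough shared-memory footprint, in
# bytes, of the buffers shown in this hunk.
def partial_shm_bytes(tb_m, tb_n, tb_k, num_bits, is_a_8bit, pipe_stages=4):
    pack_factor = 32 // num_bits
    sh_block_meta = tb_m * 16
    sh_a = pipe_stages * (tb_m * tb_k) * (1 if is_a_8bit else 2)
    sh_b = pipe_stages * (tb_k * tb_n // pack_factor) * 4
    sh_red = tb_m * (tb_n + 8) * 2
    sh_bias = tb_n * 2
    return sh_block_meta + sh_a + sh_b + sh_red + sh_bias

# e.g. a 16x256x64 tile with 4-bit weights and FP8 activations
print(partial_shm_bytes(tb_m=16, tb_n=256, tb_k=64, num_bits=4, is_a_8bit=True))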
@@ -250,7 +217,7 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8,
int thread_m_blocks, int prob_m, int prob_n, int prob_k,
int num_bits, int group_size, bool has_act_order,
bool is_k_full, int has_zp, int is_zp_float,
int max_shared_mem) {
int max_shared_mem, bool is_a_8bit) {
// Sanity
if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
th_config.num_threads == -1) {
@@ -273,188 +240,34 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8,
}

// Check that pipeline fits into cache
int cache_size = get_kernel_cache_size(
th_config, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k,
num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float);
return cache_size + 512 <= max_shared_mem;
int cache_size =
get_kernel_cache_size(th_config, m_block_size_8, thread_m_blocks, prob_m,
prob_n, prob_k, num_bits, group_size, has_act_order,
is_k_full, has_zp, is_zp_float, is_a_8bit);
return cache_size <= max_shared_mem;
}

#define _GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \
else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \
thread_n_blocks == THREAD_N_BLOCKS && \
thread_k_blocks == THREAD_K_BLOCKS && \
m_block_size_8 == M_BLOCK_SIZE_8 && \
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \
is_zp_float == IS_ZP_FLOAT) { \
constexpr auto S_TYPE = \
W_TYPE == vllm::kFE2M1f \
? (GROUP_BLOCKS == 1 ? vllm::kFE4M3fn : vllm::kFE8M0fnu) \
: (std::is_same<scalar_t, half>::value ? vllm::kFloat16 \
: vllm::kBFloat16); \
kernel = Marlin<scalar_t, W_TYPE.id(), S_TYPE.id(), NUM_THREADS, \
THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
M_BLOCK_SIZE_8, pipe_stages, GROUP_BLOCKS, IS_ZP_FLOAT>; \
}

// COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false)
// this is the most common cases
// BIGGROUP: cases for big group size (group_blocks in [-1, 8])
// FZP: cases for float-zero-point (is_zp_float = true)
// ACT: cases for act order case (group_blocks == 0)
// FP4: cases for nvfp4(e2m1) (group_blocks == 1)
#define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)

#define COMMON_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
\
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
\
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)

#define COMMON_GET_IF(W_TYPE) \
COMMON_GET_IF_M1(W_TYPE, 8, 8, 256) \
COMMON_GET_IF_M1(W_TYPE, 8, 4, 128) \
COMMON_GET_IF_M234(W_TYPE, 16, 4, 256) \
COMMON_GET_IF_M234(W_TYPE, 8, 4, 128)

#define BIGGROUP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)

#define BIGGROUP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)

#define BIGGROUP_GET_IF(W_TYPE) \
BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \
BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \
BIGGROUP_GET_IF_M234(W_TYPE, 16, 4, 256) \
BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128)

#define NVFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)

#define NVFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)

#define NVFP4_GET_IF(W_TYPE) \
NVFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \
NVFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \
NVFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
NVFP4_GET_IF_M234(W_TYPE, 8, 4, 128)

#define MXFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false)

#define MXFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false)

#define MXFP4_GET_IF(W_TYPE) \
MXFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \
MXFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \
MXFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
MXFP4_GET_IF_M234(W_TYPE, 8, 4, 128)

// We currently have 4-bit models only with group_blocks == 4
#define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, true) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true)

#define FZP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true)

#define FZP_GET_IF(W_TYPE) \
FZP_GET_IF_M1(W_TYPE, 8, 8, 256) \
FZP_GET_IF_M1(W_TYPE, 8, 4, 128) \
FZP_GET_IF_M234(W_TYPE, 16, 4, 256) \
FZP_GET_IF_M234(W_TYPE, 8, 4, 128)

// We currently have 4-bit models only with group_blocks == 4
#define ACT_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false)

#define ACT_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false)

#define ACT_GET_IF(W_TYPE) \
ACT_GET_IF_M1(W_TYPE, 8, 8, 256) \
ACT_GET_IF_M1(W_TYPE, 8, 4, 128) \
ACT_GET_IF_M234(W_TYPE, 16, 4, 256) \
ACT_GET_IF_M234(W_TYPE, 8, 4, 128)

template <typename scalar_t>
MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type,
int thread_m_blocks, int thread_n_blocks,
int thread_k_blocks, bool m_block_size_8,
bool has_act_order, bool has_zp,
int group_blocks, int num_threads,
bool is_zp_float) {
int num_bits = q_type.size_bits();
MarlinFuncPtr get_marlin_kernel(
const vllm::ScalarType a_type, const vllm::ScalarType b_type,
const vllm::ScalarType c_type, const vllm::ScalarType s_type,
int thread_m_blocks, int thread_n_blocks, int thread_k_blocks,
bool m_block_size_8, bool has_act_order, bool has_zp, int group_blocks,
int threads, bool is_zp_float) {
int num_bits = b_type.size_bits();
auto kernel = MarlinDefault;
if (false) {
}

COMMON_GET_IF(vllm::kU4)
COMMON_GET_IF(vllm::kU4B8)
COMMON_GET_IF(vllm::kU8B128)

NVFP4_GET_IF(vllm::kFE2M1f)

BIGGROUP_GET_IF(vllm::kFE4M3fn)

ACT_GET_IF(vllm::kU4B8)
ACT_GET_IF(vllm::kU8B128)
if (std::is_same<scalar_t, nv_bfloat16>::value) {
if (false) {
}
MXFP4_GET_IF(vllm::kFE2M1f)
}
#include "kernel_selector.h"

return kernel;
}

template <typename scalar_t>
exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
int prob_n, int prob_k, int thread_m_blocks,
bool m_block_size_8, int num_bits,
int group_size, bool has_act_order,
bool is_k_full, bool has_zp,
bool is_zp_float, int max_shared_mem) {
exec_config_t determine_exec_config(
const vllm::ScalarType& a_type, const vllm::ScalarType& b_type,
const vllm::ScalarType& c_type, const vllm::ScalarType& s_type, int prob_m,
int prob_n, int prob_k, int num_experts, int top_k, int thread_m_blocks,
bool m_block_size_8, int num_bits, int group_size, bool has_act_order,
bool is_k_full, bool has_zp, bool is_zp_float, int max_shared_mem, int sms,
bool is_a_8bit) {
exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}};
thread_config_t* thread_configs = thread_m_blocks > 1
? large_batch_thread_configs
@@ -471,73 +284,69 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,

if (!is_valid_config(th_config, m_block_size_8, thread_m_blocks, prob_m,
prob_n, prob_k, num_bits, group_size, has_act_order,
is_k_full, has_zp, is_zp_float, max_shared_mem)) {
is_k_full, has_zp, is_zp_float, max_shared_mem - 512,
is_a_8bit)) {
continue;
}

int cache_size = get_kernel_cache_size(
th_config, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k,
num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float);
num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float,
is_a_8bit);

int group_blocks = 0;
if (!has_act_order) {
group_blocks = group_size == -1 ? -1 : (group_size / 16);
}

auto kernel = get_marlin_kernel<scalar_t>(
q_type, thread_m_blocks, th_config.thread_n / 16,
th_config.thread_k / 16, m_block_size_8, has_act_order, has_zp,
group_blocks, th_config.num_threads, is_zp_float);
auto kernel =
get_marlin_kernel(a_type, b_type, c_type, s_type, thread_m_blocks,
th_config.thread_n / 16, th_config.thread_k / 16,
m_block_size_8, has_act_order, has_zp, group_blocks,
th_config.num_threads, is_zp_float);

if (kernel == MarlinDefault) continue;

if (thread_m_blocks > 1) {
exec_cfg = {1, th_config};
break;
} else {
cudaFuncAttributes attr;
cudaFuncGetAttributes(&attr, kernel);
int reg_size = max(attr.numRegs, 1) * th_config.num_threads * 4;
int allow_count = min(device_max_reg_size / reg_size,
max_shared_mem / (cache_size + 1024));
cudaFuncAttributes attr;
cudaFuncGetAttributes(&attr, kernel);
int reg_size = max(attr.numRegs, 1) * th_config.num_threads * 4;
int allow_count = min(device_max_reg_size / reg_size,
max_shared_mem / (cache_size + 1536));
if (thread_m_blocks == 1)
allow_count = max(min(allow_count, 4), 1);
if (allow_count > count) {
count = allow_count;
exec_cfg = {count, th_config};
};
else
allow_count = max(min(allow_count, 2), 1);

if (prob_n / th_config.thread_n * prob_m * top_k * 4 < sms * allow_count) {
allow_count =
max(prob_n / th_config.thread_n * prob_m * top_k * 4 / sms, 1);
}

if (allow_count > count) {
count = allow_count;
exec_cfg = {count, th_config};
};
}

return exec_cfg;
}

template <typename scalar_t>
void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
void* s, void* s2, void* zp, void* g_idx, void* perm,
void* a_tmp, void* sorted_token_ids, void* expert_ids,
void* num_tokens_past_padded, void* topk_weights,
int moe_block_size, int top_k, bool mul_topk_weights, bool is_ep,
int prob_m, int prob_n, int prob_k, void* workspace,
vllm::ScalarType const& q_type, bool has_bias,
bool has_act_order, bool is_k_full, bool has_zp, int num_groups,
int group_size, int dev, cudaStream_t stream, int thread_k,
int thread_n, int sms, bool use_atomic_add, bool use_fp32_reduce,
bool is_zp_float) {
void* a_s, void* b_s, void* g_s, void* zp, void* g_idx,
void* perm, void* a_tmp, void* sorted_token_ids,
void* expert_ids, void* num_tokens_past_padded,
void* topk_weights, int moe_block_size, int num_experts,
int top_k, bool mul_topk_weights, bool is_ep, int prob_m,
int prob_n, int prob_k, void* workspace,
vllm::ScalarType const& a_type, vllm::ScalarType const& b_type,
vllm::ScalarType const& c_type, vllm::ScalarType const& s_type,
bool has_bias, bool has_act_order, bool is_k_full, bool has_zp,
int num_groups, int group_size, int dev, cudaStream_t stream,
int thread_k, int thread_n, int sms, int blocks_per_sm,
bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) {
int thread_m_blocks = div_ceil(moe_block_size, 16);
bool m_block_size_8 = moe_block_size == 8;

if (has_zp) {
TORCH_CHECK(
q_type == vllm::kU4 || q_type == vllm::kU8,
"q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str());
} else {
TORCH_CHECK(
q_type == vllm::kU4B8 || q_type == vllm::kU8B128 ||
q_type == vllm::kFE4M3fn || q_type == vllm::kFE2M1f,
"q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when "
"has_zp = False. Got = ",
q_type.str());
}
bool is_a_8bit = a_type.size_bits() == 8;

TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
", ", prob_n, ", ", prob_k, "]");
@@ -563,14 +372,15 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
}
}

int num_bits = q_type.size_bits();
int num_bits = b_type.size_bits();
const int4* A_ptr = (const int4*)A;
const int4* B_ptr = (const int4*)B;
int4* C_ptr = (int4*)C;
int4* C_tmp_ptr = (int4*)C_tmp;
const int4* bias_ptr = (const int4*)b_bias;
const int4* s_ptr = (const int4*)s;
const uint16_t* s2_ptr = (const uint16_t*)s2;
const float* a_s_ptr = (const float*)a_s;
const int4* b_s_ptr = (const int4*)b_s;
const uint16_t* g_s_ptr = (const uint16_t*)g_s;
const int4* zp_ptr = (const int4*)zp;
const int* g_idx_ptr = (const int*)g_idx;
const int* perm_ptr = (const int*)perm;
@@ -618,22 +428,41 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
TORCH_CHECK(max_shared_mem > 0);

int major_capability, minor_capability;
cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,
dev);
cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
dev);
TORCH_CHECK(major_capability * 10 + minor_capability >= 80,
"marlin kernel only support Ampere or newer GPUs.");
if (a_type == vllm::kFE4M3fn) {
TORCH_CHECK(major_capability * 10 + minor_capability >= 89,
"FP8 only support Ada Lovelace or newer GPUs.");
TORCH_CHECK(
major_capability * 10 + minor_capability == 89 ||
major_capability * 10 + minor_capability == 120,
"Marlin W4A8-FP8 only support SM89 or SM120 device (It is slower than "
"Marlin W4A16 on other devices).");
}

// Set thread config
exec_config_t exec_cfg;
thread_config_t thread_tfg;
if (thread_k != -1 && thread_n != -1) {
thread_tfg = thread_config_t{thread_k, thread_n, default_threads};
exec_cfg = exec_config_t{1, thread_tfg};
thread_tfg = thread_config_t{thread_k, thread_n, thread_k * thread_n / 64};
if (blocks_per_sm == -1) blocks_per_sm = 1;
exec_cfg = exec_config_t{blocks_per_sm, thread_tfg};
TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n,
" is not divisible by thread_n = ", thread_n);
TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k,
" is not divisible by thread_k = ", thread_k);
} else {
// Auto config
exec_cfg = determine_exec_config<scalar_t>(
q_type, prob_m, prob_n, prob_k, thread_m_blocks, m_block_size_8,
num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float,
max_shared_mem);
exec_cfg = determine_exec_config(
a_type, b_type, c_type, s_type, prob_m, prob_n, prob_k, num_experts,
top_k, thread_m_blocks, m_block_size_8, num_bits, group_size,
has_act_order, is_k_full, has_zp, is_zp_float, max_shared_mem, sms,
is_a_8bit);
thread_tfg = exec_cfg.tb_cfg;
}
@@ -647,22 +476,29 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
int thread_k_blocks = thread_k / 16;
int thread_n_blocks = thread_n / 16;

TORCH_CHECK(
is_valid_config(thread_tfg, m_block_size_8, thread_m_blocks, prob_m,
prob_n, prob_k, num_bits, group_size, has_act_order,
is_k_full, has_zp, is_zp_float, max_shared_mem),
"Invalid thread config: thread_m_blocks = ", thread_m_blocks,
", thread_k = ", thread_tfg.thread_k,
", thread_n = ", thread_tfg.thread_n,
", num_threads = ", thread_tfg.num_threads, " for MKN = [", prob_m, ", ",
prob_k, ", ", prob_n, "] and num_bits = ", num_bits,
", group_size = ", group_size, ", has_act_order = ", has_act_order,
", is_k_full = ", is_k_full, ", has_zp = ", has_zp,
", is_zp_float = ", is_zp_float, ", max_shared_mem = ", max_shared_mem);
TORCH_CHECK(is_valid_config(thread_tfg, m_block_size_8, thread_m_blocks,
prob_m, prob_n, prob_k, num_bits, group_size,
has_act_order, is_k_full, has_zp, is_zp_float,
max_shared_mem, is_a_8bit),
"Invalid thread config: thread_m_blocks = ", thread_m_blocks,
", thread_k = ", thread_tfg.thread_k,
", thread_n = ", thread_tfg.thread_n,
", num_threads = ", thread_tfg.num_threads, " for MKN = [",
prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits,
", group_size = ", group_size,
", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full,
", has_zp = ", has_zp, ", is_zp_float = ", is_zp_float,
", max_shared_mem = ", max_shared_mem);

auto kernel = get_marlin_kernel<scalar_t>(
q_type, thread_m_blocks, thread_n_blocks, thread_k_blocks, m_block_size_8,
has_act_order, has_zp, group_blocks, num_threads, is_zp_float);
int sh_cache_size =
get_kernel_cache_size(thread_tfg, m_block_size_8, thread_m_blocks, prob_m,
prob_n, prob_k, num_bits, group_size, has_act_order,
is_k_full, has_zp, is_zp_float, is_a_8bit);

auto kernel = get_marlin_kernel(
a_type, b_type, c_type, s_type, thread_m_blocks, thread_n_blocks,
thread_k_blocks, m_block_size_8, has_act_order, has_zp, group_blocks,
num_threads, is_zp_float);

if (kernel == MarlinDefault) {
TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n,
@@ -679,19 +515,20 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
// avoid ">>>" being formatted to "> > >"
// clang-format off
kernel<<<blocks, num_threads, max_shared_mem, stream>>>(
A_ptr, B_ptr, C_ptr, C_tmp_ptr, bias_ptr, s_ptr, s2_ptr, zp_ptr, g_idx_ptr,
A_ptr, B_ptr, C_ptr, C_tmp_ptr, bias_ptr, a_s_ptr, b_s_ptr, g_s_ptr, zp_ptr, g_idx_ptr,
sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr,
topk_weights_ptr, top_k, mul_topk_weights, is_ep, num_groups, prob_m,
prob_n, prob_k, locks, has_bias, use_atomic_add, use_fp32_reduce, max_shared_mem);
prob_n, prob_k, locks, has_bias, use_atomic_add, use_fp32_reduce);
// clang-format on
}

} // namespace MARLIN_NAMESPACE_NAME

torch::Tensor moe_wna16_marlin_gemm(
torch::Tensor& a, std::optional<torch::Tensor> const& c_or_none,
torch::Tensor& a, std::optional<torch::Tensor> c_or_none,
torch::Tensor& b_q_weight,
std::optional<torch::Tensor> const& b_bias_or_none, torch::Tensor& b_scales,
std::optional<torch::Tensor> const& a_scales_or_none,
std::optional<torch::Tensor> const& global_scale_or_none,
std::optional<torch::Tensor> const& b_zeros_or_none,
std::optional<torch::Tensor> const& g_idx_or_none,
@@ -699,11 +536,70 @@ torch::Tensor moe_wna16_marlin_gemm(
torch::Tensor& sorted_token_ids, torch::Tensor& expert_ids,
torch::Tensor& num_tokens_past_padded, torch::Tensor& topk_weights,
int64_t moe_block_size, int64_t top_k, bool mul_topk_weights, bool is_ep,
vllm::ScalarTypeId const& b_q_type_id, int64_t size_m, int64_t size_n,
vllm::ScalarTypeId const& b_type_id, int64_t size_m, int64_t size_n,
int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce,
bool is_zp_float) {
vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id);
int pack_factor = 32 / b_q_type.size_bits();
bool is_zp_float, int64_t thread_k, int64_t thread_n,
int64_t blocks_per_sm) {
vllm::ScalarTypeId a_type_id, c_type_id, s_type_id;

auto c_dtype = a.dtype();
if (a.scalar_type() == at::ScalarType::Half) {
a_type_id = vllm::kFloat16.id();
c_type_id = vllm::kFloat16.id();
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
a_type_id = vllm::kBFloat16.id();
c_type_id = vllm::kBFloat16.id();
} else {
c_dtype = b_scales.dtype();
if (b_scales.scalar_type() == at::ScalarType::Half) {
c_type_id = vllm::kFloat16.id();
} else if (b_scales.scalar_type() == at::ScalarType::BFloat16) {
c_type_id = vllm::kBFloat16.id();
} else {
c_type_id = vllm::kBFloat16.id();

TORCH_CHECK(c_or_none.has_value(), "c must be passed for W4A8-FP4");
torch::Tensor c = c_or_none.value();
c_dtype = c.dtype();

if (c.scalar_type() == at::ScalarType::Half) {
c_type_id = vllm::kFloat16.id();
} else if (c.scalar_type() == at::ScalarType::BFloat16) {
c_type_id = vllm::kBFloat16.id();
} else {
TORCH_CHECK(false, "unsupported c dtype");
}
}

if (a.scalar_type() == at::ScalarType::Float8_e4m3fn) {
a_type_id = vllm::kFE4M3fn.id();
} else if (a.scalar_type() == at::ScalarType::Char) {
a_type_id = vllm::kS8.id();
} else {
TORCH_CHECK(false, "unsupported `a` scalar_type");
}
}

s_type_id = c_type_id;
if (b_type_id == vllm::kFE2M1f.id()) {
if (b_scales.scalar_type() == at::ScalarType::Float8_e4m3fn) {
s_type_id = vllm::kFE4M3fn.id();
} else if (b_scales.scalar_type() == at::ScalarType::Float8_e8m0fnu) {
s_type_id = vllm::kFE8M0fnu.id();
} else {
TORCH_CHECK(false,
"When b_type = float4_e2m1f, b_scale scalar type must be",
"float8_e4m3fn (for NVFP4) or float8_e8m0fnu (for MXFP4).");
}
}

vllm::ScalarType a_type = vllm::ScalarType::from_id(a_type_id);
vllm::ScalarType b_type = vllm::ScalarType::from_id(b_type_id);
vllm::ScalarType c_type = vllm::ScalarType::from_id(c_type_id);
vllm::ScalarType s_type = vllm::ScalarType::from_id(s_type_id);

int pack_factor = 32 / b_type.size_bits();
int num_experts = b_q_weight.size(0);

if (moe_block_size != 8) {
TORCH_CHECK(moe_block_size % 16 == 0,
@@ -745,19 +641,27 @@ torch::Tensor moe_wna16_marlin_gemm(
TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU");
TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");

// thread_k: `k` size of a thread_tile in `weights` (can usually be left as
// auto -1)
int thread_k = -1;
// thread_n: `n` size of a thread_tile in `weights` (can usually be left as
// auto -1)
int thread_n = -1;
torch::Tensor a_scales;
auto options = torch::TensorOptions().dtype(c_dtype).device(a.device());
auto options_fp32 =
torch::TensorOptions().dtype(at::kFloat).device(a.device());

if (a_scales_or_none.has_value()) {
a_scales = a_scales_or_none.value();
TORCH_CHECK(a_type.size_bits() == 8,
"a_scales can only be used for 8bit activation.");
} else {
a_scales = torch::empty({0}, options_fp32);
TORCH_CHECK(a_type.size_bits() != 8,
"the a_scales parameter must be passed for 8bit activation.");
}

// sms: number of SMs to use for the kernel
int sms = -1;
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device());

// Alloc buffers
const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
torch::Tensor c;
if (c_or_none.has_value()) {
c = c_or_none.value();
@@ -774,8 +678,6 @@ torch::Tensor moe_wna16_marlin_gemm(

// Alloc C tmp buffer that is going to be used for the global reduce
torch::Tensor c_tmp;
auto options_fp32 =
torch::TensorOptions().dtype(at::kFloat).device(a.device());
if (use_fp32_reduce && !use_atomic_add) {
// max num of threadblocks is sms * 4
long max_c_tmp_size = min(
@@ -846,11 +748,11 @@ torch::Tensor moe_wna16_marlin_gemm(
torch::Tensor global_scale;
if (global_scale_or_none.has_value()) {
global_scale = global_scale_or_none.value();
TORCH_CHECK(b_q_type == vllm::kFE2M1f && group_size == 16,
TORCH_CHECK(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn,
"global_scale can only be used for nvfp4 format.");
} else {
global_scale = torch::empty({0}, options);
TORCH_CHECK(!(b_q_type == vllm::kFE2M1f && group_size == 16),
TORCH_CHECK(!(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn),
"the global_scale parameter must be passed for nvfp4 format.");
}
@ -877,15 +779,15 @@ torch::Tensor moe_wna16_marlin_gemm(
|
||||
bool has_zp = b_zeros.size(-1) > 0;
|
||||
if (has_zp) {
|
||||
TORCH_CHECK(
|
||||
b_q_type == vllm::kU4 || b_q_type == vllm::kU8,
|
||||
"b_q_type must be u4 or u8 when has_zp = True. Got = ", b_q_type.str());
|
||||
b_type == vllm::kU4 || b_type == vllm::kU8,
|
||||
"b_type must be u4 or u8 when has_zp = True. Got = ", b_type.str());
|
||||
} else {
|
||||
TORCH_CHECK(b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128 ||
|
||||
b_q_type == vllm::kFE4M3fn || b_q_type == vllm::kFE2M1f,
|
||||
"b_q_type must be uint4b8, uint8b128, float8_e4m3fn or "
|
||||
"float4_e2m1f when "
|
||||
"has_zp = False. Got = ",
|
||||
b_q_type.str());
|
||||
TORCH_CHECK(b_type == vllm::kU4B8 || b_type == vllm::kU8B128 ||
|
||||
b_type == vllm::kS4 || b_type == vllm::kS8 ||
|
||||
b_type == vllm::kFE4M3fn || b_type == vllm::kFE2M1f,
|
||||
"b_type must be uint4b8, uint8b128, int4, int8, "
|
||||
"float8_e4m3fn or float4_e2m1f when has_zp = False. Got = ",
|
||||
b_type.str());
|
||||
}
|
||||
|
||||
if (has_zp && is_zp_float) {
|
||||
@ -929,71 +831,33 @@ torch::Tensor moe_wna16_marlin_gemm(
|
||||
" is below min_workspace_size = ", min_workspace_size);
|
||||
|
||||
int dev = a.get_device();
|
||||
if (a.scalar_type() == at::ScalarType::Half) {
|
||||
void* scales_ptr;
|
||||
if (b_q_type == vllm::kFE2M1f) {
|
||||
if (group_size == 16)
|
||||
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
|
||||
else if (group_size == 32)
|
||||
scales_ptr = b_scales.data_ptr<at::Float8_e8m0fnu>();
|
||||
else
|
||||
TORCH_CHECK(false,
|
||||
"float4_e2m1f only supports group_size == 16 (NVFP4) ",
|
||||
"and group_size == 32 (MXFP4)");
|
||||
} else {
|
||||
scales_ptr = b_scales.data_ptr<at::Half>();
|
||||
}
|
||||
|
||||
MARLIN_NAMESPACE_NAME::marlin_mm<half>(
|
||||
a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
|
||||
c_tmp.data_ptr<float>(), b_bias.data_ptr<at::Half>(), scales_ptr,
|
||||
global_scale.data_ptr<at::Half>(), b_zeros.data_ptr(), g_idx.data_ptr(),
|
||||
perm.data_ptr(), a_tmp.data_ptr<at::Half>(),
|
||||
sorted_token_ids.data_ptr(), expert_ids.data_ptr(),
|
||||
num_tokens_past_padded.data_ptr(), topk_weights.data_ptr(),
|
||||
moe_block_size, top_k, mul_topk_weights, is_ep, size_m, size_n, size_k,
|
||||
workspace.data_ptr(), b_q_type, has_bias, has_act_order, is_k_full,
|
||||
has_zp, num_groups, group_size, dev,
|
||||
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
|
||||
use_atomic_add, use_fp32_reduce, is_zp_float);
|
||||
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
|
||||
void* scales_ptr;
|
||||
if (b_q_type == vllm::kFE2M1f) {
|
||||
if (group_size == 16)
|
||||
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
|
||||
else if (group_size == 32)
|
||||
scales_ptr = b_scales.data_ptr<at::Float8_e8m0fnu>();
|
||||
else
|
||||
TORCH_CHECK(false,
|
||||
"float4_e2m1f only supports group_size == 16 (NVFP4) ",
|
||||
"and group_size == 32 (MXFP4)");
|
||||
} else {
|
||||
scales_ptr = b_scales.data_ptr<at::BFloat16>();
|
||||
}
|
||||
|
||||
MARLIN_NAMESPACE_NAME::marlin_mm<nv_bfloat16>(
|
||||
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
|
||||
c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(),
|
||||
b_bias.data_ptr<at::BFloat16>(), scales_ptr,
|
||||
global_scale.data_ptr<at::BFloat16>(), b_zeros.data_ptr(),
|
||||
g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(),
|
||||
sorted_token_ids.data_ptr(), expert_ids.data_ptr(),
|
||||
num_tokens_past_padded.data_ptr(), topk_weights.data_ptr(),
|
||||
moe_block_size, top_k, mul_topk_weights, is_ep, size_m, size_n, size_k,
|
||||
workspace.data_ptr(), b_q_type, has_bias, has_act_order, is_k_full,
|
||||
has_zp, num_groups, group_size, dev,
|
||||
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
|
||||
use_atomic_add, use_fp32_reduce, is_zp_float);
|
||||
} else {
|
||||
TORCH_CHECK(false,
|
||||
"moe_wna16_marlin_gemm only supports bfloat16 and float16");
|
||||
TORCH_CHECK(a_scales.scalar_type() == at::ScalarType::Float,
|
||||
"scalar type of a_scales must be float");
|
||||
TORCH_CHECK(global_scale.scalar_type() == c.scalar_type(),
|
||||
"scalar type of global_scale must be the same with c");
|
||||
if (a_type.size_bits() == 16) {
|
||||
TORCH_CHECK(
|
||||
a.scalar_type() == c.scalar_type(),
|
||||
"scalar type of a must be the same with c for 16 bit activation");
|
||||
}
|
||||
|
||||
MARLIN_NAMESPACE_NAME::marlin_mm(
|
||||
a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), c_tmp.data_ptr(),
|
||||
b_bias.data_ptr(), a_scales.data_ptr(), b_scales.data_ptr(),
|
||||
global_scale.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(),
|
||||
perm.data_ptr(), a_tmp.data_ptr(), sorted_token_ids.data_ptr(),
|
||||
expert_ids.data_ptr(), num_tokens_past_padded.data_ptr(),
|
||||
topk_weights.data_ptr(), moe_block_size, num_experts, top_k,
|
||||
mul_topk_weights, is_ep, size_m, size_n, size_k, workspace.data_ptr(),
|
||||
a_type, b_type, c_type, s_type, has_bias, has_act_order, is_k_full,
|
||||
has_zp, num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
|
||||
thread_k, thread_n, sms, blocks_per_sm, use_atomic_add, use_fp32_reduce,
|
||||
is_zp_float);
|
||||
|
||||
return c;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||
m.impl("moe_wna16_marlin_gemm", &moe_wna16_marlin_gemm);
|
||||
}
|
||||
|
||||
@ -63,16 +63,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
|
||||
m.def(
|
||||
"moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none,"
|
||||
"Tensor! b_q_weight, Tensor? b_bias_or_none,"
|
||||
"Tensor! b_scales, Tensor? global_scale, Tensor? "
|
||||
"Tensor! b_scales, Tensor? a_scales, Tensor? global_scale, Tensor? "
|
||||
"b_zeros_or_none,"
|
||||
"Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace,"
|
||||
"Tensor sorted_token_ids,"
|
||||
"Tensor! expert_ids, Tensor! num_tokens_past_padded,"
|
||||
"Tensor! topk_weights, int moe_block_size, int top_k, "
|
||||
"bool mul_topk_weights, bool is_ep, int b_q_type_id,"
|
||||
"bool mul_topk_weights, bool is_ep, int b_type_id,"
|
||||
"int size_m, int size_n, int size_k,"
|
||||
"bool is_full_k, bool use_atomic_add,"
|
||||
"bool use_fp32_reduce, bool is_zp_float) -> Tensor");
|
||||
"bool use_fp32_reduce, bool is_zp_float,"
|
||||
"int thread_k, int thread_n, int blocks_per_sm) -> Tensor");
|
||||
|
||||
m.def(
|
||||
"marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
|
||||
"Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "
|
||||
|
||||
@ -437,10 +437,10 @@ struct ComputeTile_W8A16_PerC_MtilexNtilex32_multistage_SM8x_SplitK {
|
||||
for (int n_idx = 0; n_idx < WARP_NITER; ++n_idx) {
|
||||
#pragma unroll
|
||||
for (int k_idx = 0; k_idx < 2; ++k_idx) {
|
||||
FType low16 =
|
||||
ScalarType<FType>::float2num(C_frag[m_idx][n_idx][k_idx * 2]);
|
||||
FType high16 =
|
||||
ScalarType<FType>::float2num(C_frag[m_idx][n_idx][k_idx * 2 + 1]);
|
||||
FType low16 = MarlinScalarType2<FType>::float2num(
|
||||
C_frag[m_idx][n_idx][k_idx * 2]);
|
||||
FType high16 = MarlinScalarType2<FType>::float2num(
|
||||
C_frag[m_idx][n_idx][k_idx * 2 + 1]);
|
||||
uint32_t tmp = (reinterpret_cast<uint32_t&>(low16) & 0xffff) |
|
||||
(reinterpret_cast<uint32_t&>(high16) << 16);
|
||||
int sts_offset =
|
||||
|
||||
@ -8,7 +8,7 @@
|
||||
#include <cuda_bf16.h>
|
||||
#include <iostream>
|
||||
#include "../gptq_marlin/marlin_dtypes.cuh"
|
||||
using marlin::ScalarType;
|
||||
using marlin::MarlinScalarType2;
|
||||
|
||||
namespace allspark {
|
||||
|
||||
@ -72,10 +72,10 @@ __global__ void f16_gemm_splitk_reduce_kernel(const FType* C_split, FType* C,
|
||||
|
||||
int n_mat = N_MATRIX > 0 ? N_MATRIX : (int)n_matrix;
|
||||
for (int i = 0; i < n_mat; ++i) {
|
||||
sum += ScalarType<FType>::num2float(C_split[idx + i * matrix_size]);
|
||||
sum += MarlinScalarType2<FType>::num2float(C_split[idx + i * matrix_size]);
|
||||
}
|
||||
|
||||
C[idx] = ScalarType<FType>::float2num(sum);
|
||||
C[idx] = MarlinScalarType2<FType>::float2num(sum);
|
||||
}
|
||||
|
||||
template <typename FType>
|
||||
|
||||
csrc/quantization/gptq_marlin/.gitignore (vendored)
@ -1 +1,2 @@
|
||||
kernel_*.cu
|
||||
sm*_kernel_*.cu
|
||||
kernel_selector.h
|
||||
|
||||
@ -4,14 +4,16 @@
|
||||
|
||||
namespace marlin {
|
||||
|
||||
template <int const num_threads, int const num_bits>
|
||||
template <int const num_threads, int const num_bits, bool is_a_8bit>
|
||||
__global__ void awq_marlin_repack_kernel(
|
||||
uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr,
|
||||
int size_k, int size_n) {
|
||||
constexpr int pack_factor = 32 / num_bits;
|
||||
|
||||
int k_tiles = size_k / tile_k_size;
|
||||
int n_tiles = size_n / tile_n_size;
|
||||
constexpr int target_tile_n_size = tile_n_size / (is_a_8bit ? 2 : 1);
|
||||
constexpr int target_tile_k_size = tile_k_size * (is_a_8bit ? 2 : 1);
|
||||
int k_tiles = size_k / target_tile_k_size;
|
||||
int n_tiles = size_n / target_tile_n_size;
|
||||
int block_k_tiles = div_ceil(k_tiles, gridDim.x);
|
||||
|
||||
auto start_k_tile = blockIdx.x * block_k_tiles;
|
||||
@ -33,10 +35,10 @@ __global__ void awq_marlin_repack_kernel(
|
||||
|
||||
extern __shared__ int4 sh[];
|
||||
|
||||
constexpr int tile_n_ints = tile_n_size / pack_factor;
|
||||
constexpr int tile_n_ints = target_tile_n_size / pack_factor;
|
||||
|
||||
constexpr int stage_n_threads = tile_n_ints / 4;
|
||||
constexpr int stage_k_threads = tile_k_size;
|
||||
constexpr int stage_k_threads = target_tile_k_size;
|
||||
constexpr int stage_size = stage_k_threads * stage_n_threads;
|
||||
|
||||
auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) {
|
||||
@ -45,7 +47,7 @@ __global__ void awq_marlin_repack_kernel(
|
||||
return;
|
||||
}
|
||||
|
||||
int first_n = n_tile_id * tile_n_size;
|
||||
int first_n = n_tile_id * target_tile_n_size;
|
||||
int first_n_packed = first_n / pack_factor;
|
||||
|
||||
int4* sh_ptr = sh + stage_size * pipe;
|
||||
@ -54,7 +56,7 @@ __global__ void awq_marlin_repack_kernel(
|
||||
auto k_id = threadIdx.x / stage_n_threads;
|
||||
auto n_id = threadIdx.x % stage_n_threads;
|
||||
|
||||
int first_k = k_tile_id * tile_k_size;
|
||||
int first_k = k_tile_id * target_tile_k_size;
|
||||
|
||||
cp_async4(&sh_ptr[k_id * stage_n_threads + n_id],
|
||||
reinterpret_cast<int4 const*>(
|
||||
@ -78,11 +80,11 @@ __global__ void awq_marlin_repack_kernel(
|
||||
}
|
||||
|
||||
int tc_col = th_id / 4;
|
||||
int tc_row = (th_id % 4) * 2;
|
||||
int tc_row = (th_id % 4) * (is_a_8bit ? 4 : 2);
|
||||
|
||||
constexpr int tc_offsets[4] = {0, 1, 8, 9};
|
||||
|
||||
int cur_n = warp_id * 16 + tc_col;
|
||||
int cur_n = (warp_id / (is_a_8bit ? 2 : 1)) * 16 + tc_col;
|
||||
int cur_n_packed = cur_n / pack_factor;
|
||||
int cur_n_pos = cur_n % pack_factor;
|
||||
|
||||
@ -105,23 +107,50 @@ __global__ void awq_marlin_repack_kernel(
|
||||
uint32_t vals[8];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
int cur_elem = tc_row + tc_offsets[i];
|
||||
if constexpr (is_a_8bit) {
|
||||
int cur_elem = tc_row + i;
|
||||
|
||||
int packed_src_0 = sh_stage_int_ptr[cur_n_packed + sh_stride * cur_elem];
|
||||
int packed_src_1 = sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) +
|
||||
sh_stride * cur_elem];
|
||||
int packed_src_0 =
|
||||
sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) * (warp_id % 2) +
|
||||
sh_stride * cur_elem];
|
||||
int packed_src_1 =
|
||||
sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) * (warp_id % 2) +
|
||||
sh_stride * (cur_elem + 16)];
|
||||
|
||||
vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask;
|
||||
vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask;
|
||||
vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask;
|
||||
vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask;
|
||||
} else {
|
||||
int cur_elem = tc_row + tc_offsets[i];
|
||||
|
||||
int packed_src_0 =
|
||||
sh_stage_int_ptr[cur_n_packed + sh_stride * cur_elem];
|
||||
int packed_src_1 = sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) +
|
||||
sh_stride * cur_elem];
|
||||
|
||||
vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask;
|
||||
vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask;
|
||||
}
|
||||
}
|
||||
|
||||
constexpr int tile_size = tile_k_size * tile_n_size / pack_factor;
|
||||
constexpr int tile_size =
|
||||
target_tile_k_size * target_tile_n_size / pack_factor;
|
||||
int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;
|
||||
|
||||
// Result of:
|
||||
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
|
||||
if constexpr (num_bits == 4) {
|
||||
constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
|
||||
if constexpr (!is_a_8bit && num_bits == 4) {
|
||||
int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
|
||||
|
||||
uint32_t res = 0;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++) {
|
||||
res |= vals[pack_idx[i]] << (i * 4);
|
||||
}
|
||||
|
||||
out_ptr[out_offset + th_id * 4 + warp_id] = res;
|
||||
|
||||
} else if constexpr (is_a_8bit && num_bits == 4) {
|
||||
int pack_idx[8] = {0, 4, 1, 5, 2, 6, 3, 7};
|
||||
|
||||
uint32_t res = 0;
|
||||
#pragma unroll
|
||||
@ -138,8 +167,9 @@ __global__ void awq_marlin_repack_kernel(
|
||||
uint32_t res2 = 0;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
res1 |= vals[pack_idx[i]] << (i * 8);
|
||||
res2 |= vals[4 + pack_idx[i]] << (i * 8);
|
||||
const int ii = is_a_8bit ? i : pack_idx[i];
|
||||
res1 |= vals[ii] << (i * 8);
|
||||
res2 |= vals[4 + ii] << (i * 8);
|
||||
}
|
||||
|
||||
out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1;
|
||||
@ -176,18 +206,21 @@ __global__ void awq_marlin_repack_kernel(
|
||||
|
||||
} // namespace marlin
|
||||
|
||||
#define CALL_IF(NUM_BITS) \
|
||||
else if (num_bits == NUM_BITS) { \
|
||||
cudaFuncSetAttribute( \
|
||||
marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS>, \
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
|
||||
marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS> \
|
||||
<<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \
|
||||
b_q_weight_ptr, out_ptr, size_k, size_n); \
|
||||
#define CALL_IF(NUM_BITS, IS_A_8BIT) \
|
||||
else if (num_bits == NUM_BITS && is_a_8bit == IS_A_8BIT) { \
|
||||
cudaFuncSetAttribute( \
|
||||
marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
|
||||
IS_A_8BIT>, \
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
|
||||
marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
|
||||
IS_A_8BIT> \
|
||||
<<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \
|
||||
b_q_weight_ptr, out_ptr, size_k, size_n); \
|
||||
}
|
||||
|
||||
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
|
||||
int64_t size_n, int64_t num_bits) {
|
||||
int64_t size_n, int64_t num_bits,
|
||||
bool is_a_8bit) {
|
||||
// Verify compatibility with marlin tile of 16x64
|
||||
TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k,
|
||||
" is not divisible by tile_k_size = ", marlin::tile_k_size);
|
||||
@ -238,10 +271,13 @@ torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
|
||||
|
||||
if (false) {
|
||||
}
|
||||
CALL_IF(4)
|
||||
CALL_IF(8)
|
||||
CALL_IF(4, false)
|
||||
CALL_IF(8, false)
|
||||
CALL_IF(4, true)
|
||||
CALL_IF(8, true)
|
||||
else {
|
||||
TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits);
|
||||
TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits,
|
||||
", is_a_8bit = ", is_a_8bit);
|
||||
}
|
||||
|
||||
return out;
|
||||
|
||||
@ -470,6 +470,50 @@ __device__ inline void dequant<nv_bfloat162, vllm::kFE2M1f.id(), false>(
|
||||
frag_b[0] = __hmul2(frag_b[0], bias_reg);
|
||||
}
|
||||
|
||||
template <>
__device__ inline void dequant<__nv_fp8x4_e4m3, vllm::kFE2M1f.id(), true>(
    int q, __nv_fp8x4_e4m3* frag_b) {
  // Constants for FP4 (E2M1) and FP8 (E4M3) formats
  constexpr int FP4_EXPONENT = 2, FP8_EXPONENT = 4;
  constexpr int RIGHT_SHIFT = FP8_EXPONENT - FP4_EXPONENT;
  constexpr int MASK = 0x70707070;

  // Extract and shift FP4 values to FP8 format
  int Out1 = (q & 0x80808080) | ((q & MASK) >> RIGHT_SHIFT);
  q <<= 4;
  int Out2 = (q & 0x80808080) | ((q & MASK) >> RIGHT_SHIFT);

  // Note1: reverse indexing is intentional because weights are permuted
  // Note2: when dequant to 8bit type, we write to `frag_b[2]` instead of
  // `frag_b[1]` to fit the layout of tensorcore
  frag_b[1] = *reinterpret_cast<const __nv_fp8x4_e4m3*>(&Out1);
  frag_b[0] = *reinterpret_cast<const __nv_fp8x4_e4m3*>(&Out2);
}
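For reference, a minimal Python sketch (not part of the diff) of the bit trick in the dequant<__nv_fp8x4_e4m3, kFE2M1f, true> specialization above: each packed FP4 (E2M1) nibble becomes an FP8 (E4M3) byte by keeping the sign bit and shifting the exponent/mantissa bits right by FP8_EXPONENT - FP4_EXPONENT = 2. The produced byte holds the FP4 value divided by 2**6; that constant factor is presumably absorbed by the scales applied later.

def fp4_to_fp8_bits(q: int) -> tuple[int, int]:
    # q packs 8 FP4 values in a 32-bit word, two nibbles per byte
    MASK, SIGN, SHIFT = 0x70707070, 0x80808080, 2
    out_hi = (q & SIGN) | ((q & MASK) >> SHIFT)   # built from the high nibbles
    q = (q << 4) & 0xFFFFFFFF
    out_lo = (q & SIGN) | ((q & MASK) >> SHIFT)   # built from the low nibbles
    return out_hi, out_lo

def fp4_value(nibble: int) -> float:
    s = -1.0 if nibble & 0x8 else 1.0
    e, m = (nibble >> 1) & 0x3, nibble & 0x1
    return s * (m * 0.5 if e == 0 else (1 + m * 0.5) * 2.0 ** (e - 1))

def fp8_e4m3_value(byte: int) -> float:
    s = -1.0 if byte & 0x80 else 1.0
    e, m = (byte >> 3) & 0xF, byte & 0x7
    return s * (m / 8 * 2.0 ** -6 if e == 0 else (1 + m / 8) * 2.0 ** (e - 7))

# every FP4 nibble maps to an FP8 byte whose value is exactly fp4 / 2**6
for nib in range(16):
    hi, _ = fp4_to_fp8_bits(nib << 4)   # place it in the high nibble of byte 0
    assert fp8_e4m3_value(hi & 0xFF) * 64 == fp4_value(nib)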
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<int32_t, vllm::kU4B8.id(), true>(
|
||||
int q, int32_t* frag_b) {
|
||||
constexpr int repeated_zp = 0x08080808;
|
||||
constexpr int MASK = 0x80808080;
|
||||
|
||||
frag_b[0] = ((q & 0x0F0F0F0F | MASK) - repeated_zp) ^ MASK;
|
||||
q >>= 4;
|
||||
frag_b[1] = ((q & 0x0F0F0F0F | MASK) - repeated_zp) ^ MASK;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<__nv_fp8x4_e4m3, vllm::kU4B8.id(), true>(
|
||||
int q, __nv_fp8x4_e4m3* frag_b) {
|
||||
int s = q & 0x08080808;
|
||||
int Out1 = ((q & 0x07070707) | (s << 4)) + (s >> 3);
|
||||
q >>= 4;
|
||||
s = q & 0x08080808;
|
||||
int Out2 = ((q & 0x07070707) | (s << 4)) + (s >> 3);
|
||||
|
||||
frag_b[0] = *reinterpret_cast<const __nv_fp8x4_e4m3*>(&Out1);
|
||||
frag_b[1] = *reinterpret_cast<const __nv_fp8x4_e4m3*>(&Out2);
|
||||
}
|
||||
|
||||
template <typename scalar_t2, vllm::ScalarTypeId s_type_id>
|
||||
__device__ inline void dequant_fp8_scales(int q, scalar_t2* frag_b);
|
||||
|
||||
@ -515,6 +559,49 @@ __device__ inline void dequant_fp8_scales<nv_bfloat162, vllm::kFE8M0fnu.id()>(
|
||||
// Note: reverse indexing is intentional because weights are permuted
|
||||
frag_b[1] = *reinterpret_cast<const nv_bfloat162*>(&Out1);
|
||||
frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2);
|
||||
};
|
||||
|
||||
// subtract zero point in quantized format and then dequant
template <typename scalar_t2, vllm::ScalarTypeId w_type_id,
          bool skip_flop = false>
__device__ inline void sub_zp_and_dequant(int q, scalar_t2* frag_b, int zp);

template <>
__device__ inline void sub_zp_and_dequant<int32_t, vllm::kU4.id(), true>(
    int q, int32_t* frag_b, int zp) {
  // INT4 with zp -> INT8
  // see https://github.com/vllm-project/vllm/pull/24722
  int repeated_zp = 0x01010101 * zp;
  int MASK = 0x80808080;

  frag_b[0] = ((q & 0x0F0F0F0F | MASK) - repeated_zp) ^ MASK;
  q >>= 4;
  frag_b[1] = ((q & 0x0F0F0F0F | MASK) - repeated_zp) ^ MASK;
}
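A minimal Python model (not from the diff) of the SIMD-within-a-register trick used above and, with the zero point fixed to 8, by the dequant<int32_t, kU4B8, true> specialization earlier in this file: eight 4-bit weights are widened to signed int8 bytes while the zero point is subtracted, and the OR/XOR with 0x80 per byte keeps each subtraction from borrowing into its neighbour.

def sub_zp_packed(q: int, zp: int) -> int:
    # q holds eight 4-bit values in a 32-bit word; this handles the low nibbles,
    # the kernel then shifts q right by 4 and repeats for the other half
    MASK = 0x80808080
    repeated_zp = 0x01010101 * zp
    return (((q & 0x0F0F0F0F) | MASK) - repeated_zp) ^ MASK

# check against a scalar reference for every nibble / zero-point combination
for zp in range(16):
    for v in range(16):
        byte = sub_zp_packed(v * 0x01010101, zp) & 0xFF
        signed = byte - 256 if byte >= 128 else byte
        assert signed == v - zp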
|
||||
|
||||
template <>
|
||||
__device__ inline void sub_zp_and_dequant<__nv_fp8x4_e4m3, vllm::kU4.id(),
|
||||
true>(int q, __nv_fp8x4_e4m3* frag_b,
|
||||
int zp) {
|
||||
// INT4 with zp -> FP8
|
||||
// see https://github.com/vllm-project/vllm/pull/24722
|
||||
uint32_t u_q = *reinterpret_cast<uint32_t*>(&q);
|
||||
uint32_t u_zp = *reinterpret_cast<uint32_t*>(&zp);
|
||||
uint32_t u_zp1 = u_zp + 1;
|
||||
uint32_t repeated_zp = 0x01010101 * u_zp;
|
||||
|
||||
uint32_t q0, s;
|
||||
q0 = (u_q & 0x0F0F0F0F) | 0x70707070;
|
||||
s = (q0 + repeated_zp) & 0x80808080;
|
||||
uint32_t Out1 = (q0 + (s >> 7) * u_zp1) & 0x0F0F0F0F | s;
|
||||
|
||||
u_q >>= 4;
|
||||
q0 = (u_q & 0x0F0F0F0F) | 0x70707070;
|
||||
s = (q0 + repeated_zp) & 0x80808080;
|
||||
uint32_t Out2 = (q0 + (s >> 7) * u_zp1) & 0x0F0F0F0F | s;
|
||||
|
||||
frag_b[0] = *reinterpret_cast<const __nv_fp8x4_e4m3*>(&Out1);
|
||||
frag_b[1] = *reinterpret_cast<const __nv_fp8x4_e4m3*>(&Out2);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@ -4,141 +4,292 @@ import glob
|
||||
import itertools
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
import jinja2
|
||||
|
||||
FILE_HEAD = """
|
||||
// auto generated by generate.py
|
||||
// clang-format off
|
||||
ARCHS = []
|
||||
SUPPORT_FP8 = False
|
||||
for arch in sys.argv[1].split(","):
    arch = arch[: arch.index(".") + 2].replace(".", "")
    arch = int(arch)
    # only SM89 and SM120 fully support
    # mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32.
    # SM90 and SM100 can use this PTX, but it's simulated
    # with FP16 MMA, so it cannot achieve any acceleration.
    if arch in [89, 120]:
        SUPPORT_FP8 = True
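A quick check (not part of the diff) of how the arch tokens above are normalized; the assumed input is the comma-separated arch list that the build passes as sys.argv[1], e.g. "8.0+PTX,8.9,12.0".

for token, expected in [("8.0+PTX", 80), ("8.9", 89), ("12.0", 120)]:
    assert int(token[: token.index(".") + 2].replace(".", "")) == expected
# with "8.9" or "12.0" present in the list, SUPPORT_FP8 ends up True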
|
||||
|
||||
FILE_HEAD_COMMENT = """
|
||||
// auto generated by generate_kernels.py
|
||||
// clang-format off
|
||||
""".lstrip()
|
||||
|
||||
FILE_HEAD = (
|
||||
FILE_HEAD_COMMENT
|
||||
+ """
|
||||
#include "kernel.h"
|
||||
#include "marlin_template.h"
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
""".strip()
|
||||
"""
|
||||
)
|
||||
|
||||
TEMPLATE = (
|
||||
"template __global__ void Marlin<"
|
||||
"{{scalar_t}}, "
|
||||
"{{w_type_id}}, "
|
||||
"{{a_type_id}}, "
|
||||
"{{b_type_id}}, "
|
||||
"{{c_type_id}}, "
|
||||
"{{s_type_id}}, "
|
||||
"{{threads}}, "
|
||||
"{{thread_m_blocks}}, "
|
||||
"{{thread_n_blocks}}, "
|
||||
"{{thread_k_blocks}}, "
|
||||
"{{'true' if m_block_size_8 else 'false'}}, "
|
||||
"{{m_block_size_8}}, "
|
||||
"{{stages}}, "
|
||||
"{{group_blocks}}, "
|
||||
"{{'true' if is_zp_float else 'false'}}>"
|
||||
"{{is_zp_float}}>"
|
||||
"( MARLIN_KERNEL_PARAMS );"
|
||||
)
|
||||
|
||||
# int8 with zero point case (vllm::kU8) is also supported,
|
||||
# we don't add it to reduce wheel size.
|
||||
SCALAR_TYPES = [
|
||||
"vllm::kU4",
|
||||
"vllm::kU4B8",
|
||||
"vllm::kU8B128",
|
||||
"vllm::kFE4M3fn",
|
||||
"vllm::kFE2M1f",
|
||||
]
|
||||
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128), (128, 64, 128)]
|
||||
|
||||
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
|
||||
# group_blocks:
|
||||
# = 0 : act order case
|
||||
# = -1 : channelwise quantization
|
||||
# > 0 : group_size=16*group_blocks
|
||||
GROUP_BLOCKS = [0, 1, -1, 2, 4, 8]
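A hypothetical helper (not in the script; the name is illustrative only) spelling out the group_blocks encoding documented above.

def to_group_blocks(group_size: int, has_act_order: bool) -> int:
    # 0 = act-order, -1 = channelwise, otherwise one block per 16 rows of k
    if has_act_order:
        return 0
    return -1 if group_size == -1 else group_size // 16

assert to_group_blocks(128, False) == 8
assert to_group_blocks(16, False) == 1   # NVFP4
assert to_group_blocks(32, False) == 2   # MXFP4
assert to_group_blocks(-1, False) == -1  # channelwise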
|
||||
DTYPES = ["fp16", "bf16"]
|
||||
|
||||
QUANT_CONFIGS = [
|
||||
# AWQ-INT4
|
||||
{
|
||||
"b_type": "kU4",
|
||||
"thread_configs": THREAD_CONFIGS,
|
||||
"thread_m_blocks": THREAD_M_BLOCKS,
|
||||
"group_blocks": [-1, 2, 4, 8],
|
||||
},
|
||||
# HQQ
|
||||
{
|
||||
"a_type": ["kFloat16"],
|
||||
"b_type": "kU4",
|
||||
"thread_configs": THREAD_CONFIGS,
|
||||
"thread_m_blocks": THREAD_M_BLOCKS,
|
||||
"group_blocks": [4],
|
||||
"is_zp_float": True,
|
||||
},
|
||||
# GPTQ-INT4
|
||||
{
|
||||
"b_type": "kU4B8",
|
||||
"thread_configs": THREAD_CONFIGS,
|
||||
"thread_m_blocks": THREAD_M_BLOCKS,
|
||||
"group_blocks": [-1, 0, 2, 4, 8],
|
||||
},
|
||||
# GPTQ-INT8
|
||||
{
|
||||
"b_type": "kU8B128",
|
||||
"thread_configs": THREAD_CONFIGS,
|
||||
"thread_m_blocks": THREAD_M_BLOCKS,
|
||||
"group_blocks": [-1, 0, 2, 4, 8],
|
||||
},
|
||||
# FP8
|
||||
{
|
||||
"b_type": "kFE4M3fn",
|
||||
"thread_configs": THREAD_CONFIGS,
|
||||
"thread_m_blocks": THREAD_M_BLOCKS,
|
||||
"group_blocks": [-1, 8],
|
||||
},
|
||||
# NVFP4
|
||||
{
|
||||
"b_type": "kFE2M1f",
|
||||
"s_type": "kFE4M3fn",
|
||||
"thread_configs": THREAD_CONFIGS,
|
||||
"thread_m_blocks": THREAD_M_BLOCKS,
|
||||
"group_blocks": [1],
|
||||
},
|
||||
# MXFP4
|
||||
{
|
||||
"a_type": ["kBFloat16"],
|
||||
"b_type": "kFE2M1f",
|
||||
"s_type": "kFE8M0fnu",
|
||||
"thread_configs": THREAD_CONFIGS,
|
||||
"thread_m_blocks": THREAD_M_BLOCKS,
|
||||
"group_blocks": [2],
|
||||
},
|
||||
# AWQ-INT4 with INT8 activation
|
||||
{
|
||||
"a_type": ["kS8"],
|
||||
"b_type": "kU4",
|
||||
"thread_configs": THREAD_CONFIGS,
|
||||
"thread_m_blocks": [1, 2, 3, 4],
|
||||
"group_blocks": [-1, 2, 4, 8],
|
||||
},
|
||||
# GPTQ-INT4 with INT8 activation
|
||||
{
|
||||
"a_type": ["kS8"],
|
||||
"b_type": "kU4B8",
|
||||
"thread_configs": THREAD_CONFIGS,
|
||||
"thread_m_blocks": [1, 2, 3, 4],
|
||||
"group_blocks": [-1, 2, 4, 8],
|
||||
},
|
||||
# GPTQ-INT4 with FP8 activation
|
||||
{
|
||||
"a_type": ["kFE4M3fn"],
|
||||
"b_type": "kU4B8",
|
||||
"thread_configs": THREAD_CONFIGS,
|
||||
"thread_m_blocks": [1, 2, 3, 4],
|
||||
"group_blocks": [-1, 2, 4, 8],
|
||||
},
|
||||
# AWQ-INT4 with FP8 activation
|
||||
{
|
||||
"a_type": ["kFE4M3fn"],
|
||||
"b_type": "kU4",
|
||||
"thread_configs": THREAD_CONFIGS,
|
||||
"thread_m_blocks": [1, 2, 3, 4],
|
||||
"group_blocks": [-1, 2, 4, 8],
|
||||
},
|
||||
# MXFP4 with FP8 activation
|
||||
{
|
||||
"a_type": ["kFE4M3fn"],
|
||||
"b_type": "kFE2M1f",
|
||||
"c_type": ["kBFloat16"],
|
||||
"s_type": "kFE8M0fnu",
|
||||
"thread_configs": THREAD_CONFIGS,
|
||||
"thread_m_blocks": [1, 2, 3, 4],
|
||||
"group_blocks": [2],
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def remove_old_kernels():
|
||||
for filename in glob.glob(os.path.dirname(__file__) + "/kernel_*.cu"):
|
||||
for filename in glob.glob(os.path.dirname(__file__) + "/*kernel_*.cu"):
|
||||
subprocess.call(["rm", "-f", filename])
|
||||
|
||||
filename = os.path.dirname(__file__) + "/kernel_selector.h"
|
||||
subprocess.call(["rm", "-f", filename])
|
||||
|
||||
|
||||
def generate_new_kernels():
|
||||
for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES):
|
||||
result_dict = {}
|
||||
|
||||
for quant_config in QUANT_CONFIGS:
|
||||
c_types = quant_config.get("c_type", ["kFloat16", "kBFloat16"])
|
||||
a_types = quant_config.get("a_type", ["kFloat16", "kBFloat16"])
|
||||
b_type = quant_config["b_type"]
|
||||
is_zp_float = quant_config.get("is_zp_float", False)
|
||||
all_group_blocks = quant_config["group_blocks"]
|
||||
all_m_blocks = quant_config["thread_m_blocks"]
|
||||
all_thread_configs = quant_config["thread_configs"]
|
||||
|
||||
for a_type, c_type in itertools.product(a_types, c_types):
|
||||
if not SUPPORT_FP8 and a_type == "kFE4M3fn":
|
||||
continue
|
||||
if "16" in a_type and "16" in c_type and a_type != c_type:
|
||||
continue
|
||||
s_type = quant_config.get("s_type", c_type)
|
||||
if (a_type, b_type, c_type) not in result_dict:
|
||||
result_dict[(a_type, b_type, c_type)] = []
|
||||
|
||||
for group_blocks, m_blocks, thread_configs in itertools.product(
|
||||
all_group_blocks, all_m_blocks, all_thread_configs
|
||||
):
|
||||
thread_k, thread_n, threads = thread_configs
|
||||
|
||||
if threads == 256:
|
||||
# for small batch (m_blocks == 1),
|
||||
# we only need (128, 128, 256)
|
||||
# for large batch (m_blocks > 1),
|
||||
# we only need (64, 256, 256)
|
||||
if m_blocks <= 1 and (thread_k, thread_n) != (128, 128):
|
||||
continue
|
||||
if m_blocks > 1 and (thread_k, thread_n) != (64, 256):
|
||||
continue
|
||||
|
||||
config = {
|
||||
"threads": threads,
|
||||
"s_type": s_type,
|
||||
"thread_m_blocks": max(m_blocks, 1),
|
||||
"thread_k_blocks": thread_k // 16,
|
||||
"thread_n_blocks": thread_n // 16,
|
||||
"m_block_size_8": "true" if m_blocks == 0.5 else "false",
|
||||
"stages": "pipe_stages",
|
||||
"group_blocks": group_blocks,
|
||||
"is_zp_float": "true" if is_zp_float else "false",
|
||||
}
|
||||
|
||||
result_dict[(a_type, b_type, c_type)].append(config)
|
||||
|
||||
kernel_selector_str = FILE_HEAD_COMMENT
|
||||
|
||||
for (a_type, b_type, c_type), config_list in result_dict.items():
|
||||
all_template_str_list = []
|
||||
for config in config_list:
|
||||
s_type = config["s_type"]
|
||||
template_str = jinja2.Template(TEMPLATE).render(
|
||||
a_type_id=f"vllm::{a_type}.id()",
|
||||
b_type_id=f"vllm::{b_type}.id()",
|
||||
c_type_id=f"vllm::{c_type}.id()",
|
||||
s_type_id=f"vllm::{s_type}.id()",
|
||||
**config,
|
||||
)
|
||||
all_template_str_list.append(template_str)
|
||||
|
||||
for group_blocks, m_blocks, thread_configs in itertools.product(
|
||||
GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS
|
||||
):
|
||||
# act order case only support gptq-int4 and gptq-int8
|
||||
if group_blocks == 0 and scalar_type not in [
|
||||
"vllm::kU4B8",
|
||||
"vllm::kU8B128",
|
||||
]:
|
||||
continue
|
||||
if thread_configs[2] == 256:
|
||||
# for small batch (m_blocks == 1), we only need (128, 128, 256)
|
||||
# for large batch (m_blocks > 1), we only need (64, 256, 256)
|
||||
if m_blocks <= 1 and thread_configs[0] != 128:
|
||||
continue
|
||||
if m_blocks > 1 and thread_configs[0] != 64:
|
||||
continue
|
||||
conditions = [
|
||||
f"a_type == vllm::{a_type}",
|
||||
f"b_type == vllm::{b_type}",
|
||||
f"c_type == vllm::{c_type}",
|
||||
f"s_type == vllm::{s_type}",
|
||||
f"threads == {config['threads']}",
|
||||
f"thread_m_blocks == {config['thread_m_blocks']}",
|
||||
f"thread_n_blocks == {config['thread_n_blocks']}",
|
||||
f"thread_k_blocks == {config['thread_k_blocks']}",
|
||||
f"m_block_size_8 == {config['m_block_size_8']}",
|
||||
f"group_blocks == {config['group_blocks']}",
|
||||
f"is_zp_float == {config['is_zp_float']}",
|
||||
]
|
||||
conditions = " && ".join(conditions)
|
||||
|
||||
# we only support channelwise quantization and group_size == 128
|
||||
# for fp8
|
||||
if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]:
|
||||
continue
|
||||
# nvfp4 only supports group_size == 16
|
||||
# mxfp4 only supports group_size == 32
|
||||
if scalar_type == "vllm::kFE2M1f" and group_blocks not in [1, 2]:
|
||||
continue
|
||||
# other quantization methods don't support group_size = 16
|
||||
if scalar_type != "vllm::kFE2M1f" and group_blocks == 1:
|
||||
continue
|
||||
if kernel_selector_str == FILE_HEAD_COMMENT:
|
||||
kernel_selector_str += f"if ({conditions})\n kernel = "
|
||||
else:
|
||||
kernel_selector_str += f"else if ({conditions})\n kernel = "
|
||||
|
||||
k_blocks = thread_configs[0] // 16
|
||||
n_blocks = thread_configs[1] // 16
|
||||
threads = thread_configs[2]
|
||||
kernel_template2 = (
|
||||
"Marlin<{{a_type_id}}, {{b_type_id}}, {{c_type_id}}, "
|
||||
"{{s_type_id}}, {{threads}}, {{thread_m_blocks}}, "
|
||||
"{{thread_n_blocks}}, {{thread_k_blocks}}, "
|
||||
"{{m_block_size_8}}, {{stages}}, {{group_blocks}}, "
|
||||
"{{is_zp_float}}>;"
|
||||
)
|
||||
|
||||
c_dtype = "half" if dtype == "fp16" else "nv_bfloat16"
|
||||
|
||||
is_zp_float_list = [False]
|
||||
if dtype == "fp16" and scalar_type == "vllm::kU4" and group_blocks == 4:
|
||||
# HQQ (is_zp_float = true) only supports
|
||||
# 4bit quantization and fp16
|
||||
is_zp_float_list.append(True)
|
||||
|
||||
if scalar_type == "vllm::kFE2M1f" and group_blocks == 1:
|
||||
s_type = "vllm::kFE4M3fn"
|
||||
elif scalar_type == "vllm::kFE2M1f" and group_blocks == 2:
|
||||
s_type = "vllm::kFE8M0fnu"
|
||||
if dtype == "fp16":
|
||||
# we cannot safely dequantize e8m0 to fp16, so skip this
|
||||
continue
|
||||
elif dtype == "fp16":
|
||||
s_type = "vllm::kFloat16"
|
||||
elif dtype == "bf16":
|
||||
s_type = "vllm::kBFloat16"
|
||||
|
||||
for is_zp_float in is_zp_float_list:
|
||||
template_str = jinja2.Template(TEMPLATE).render(
|
||||
scalar_t=c_dtype,
|
||||
w_type_id=scalar_type + ".id()",
|
||||
s_type_id=s_type + ".id()",
|
||||
threads=threads,
|
||||
thread_m_blocks=max(m_blocks, 1),
|
||||
thread_n_blocks=n_blocks,
|
||||
thread_k_blocks=k_blocks,
|
||||
m_block_size_8=m_blocks == 0.5,
|
||||
stages="pipe_stages",
|
||||
group_blocks=group_blocks,
|
||||
is_zp_float=is_zp_float,
|
||||
kernel_selector_str += (
|
||||
jinja2.Template(kernel_template2).render(
|
||||
a_type_id=f"vllm::{a_type}.id()",
|
||||
b_type_id=f"vllm::{b_type}.id()",
|
||||
c_type_id=f"vllm::{c_type}.id()",
|
||||
s_type_id=f"vllm::{s_type}.id()",
|
||||
**config,
|
||||
)
|
||||
|
||||
all_template_str_list.append(template_str)
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
file_content = FILE_HEAD + "\n\n"
|
||||
file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
|
||||
filename = f"kernel_{dtype}_{scalar_type[6:].lower()}.cu"
|
||||
if a_type == "kFE4M3fn":
|
||||
filename = f"sm89_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
|
||||
else:
|
||||
filename = f"sm80_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
|
||||
|
||||
filename = filename.lower()
|
||||
|
||||
with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
|
||||
f.write(file_content)
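A small sanity check (not part of the script; the helper name is made up) of the per-arch file naming above: fp8-activation kernels go into sm89_* files and everything else into sm80_* files, matching the sm*_kernel_*.cu pattern added to .gitignore.

def kernel_filename(a_type: str, b_type: str, c_type: str) -> str:
    prefix = "sm89" if a_type == "kFE4M3fn" else "sm80"
    return f"{prefix}_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu".lower()

assert kernel_filename("kFE4M3fn", "kU4B8", "kFloat16") == "sm89_kernel_fe4m3fn_u4b8_float16.cu"
assert kernel_filename("kS8", "kU4", "kBFloat16") == "sm80_kernel_s8_u4_bfloat16.cu"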
|
||||
|
||||
if not SUPPORT_FP8 and kernel_selector_str != FILE_HEAD_COMMENT:
|
||||
kernel_selector_str += (
|
||||
"else if (a_type == vllm::kFE4M3fn)\n"
|
||||
" TORCH_CHECK(false, "
|
||||
'"marlin kernel with fp8 activation is not built.");'
|
||||
)
|
||||
|
||||
with open(os.path.join(os.path.dirname(__file__), "kernel_selector.h"), "w") as f:
|
||||
f.write(kernel_selector_str)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
remove_old_kernels()
|
||||
|
||||
@ -53,7 +53,7 @@ torch::Tensor gptq_marlin_gemm(
|
||||
std::optional<torch::Tensor> const& b_zeros_or_none,
|
||||
std::optional<torch::Tensor> const& g_idx_or_none,
|
||||
std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace,
|
||||
vllm::ScalarTypeId const& b_q_type_id, int64_t size_m, int64_t size_n,
|
||||
vllm::ScalarTypeId const& b_type_id, int64_t size_m, int64_t size_n,
|
||||
int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce,
|
||||
bool is_zp_float) {
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false,
|
||||
@ -243,204 +243,29 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
|
||||
int cache_size = get_kernel_cache_size(
|
||||
th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size,
|
||||
has_act_order, is_k_full, has_zp, is_zp_float);
|
||||
return cache_size + 512 <= max_shared_mem;
|
||||
return cache_size <= max_shared_mem;
|
||||
}
|
||||
|
||||
#define _GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
|
||||
M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \
|
||||
else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \
|
||||
thread_n_blocks == THREAD_N_BLOCKS && \
|
||||
thread_k_blocks == THREAD_K_BLOCKS && \
|
||||
m_block_size_8 == M_BLOCK_SIZE_8 && \
|
||||
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \
|
||||
is_zp_float == IS_ZP_FLOAT) { \
|
||||
constexpr auto S_TYPE = \
|
||||
W_TYPE == vllm::kFE2M1f \
|
||||
? (GROUP_BLOCKS == 1 ? vllm::kFE4M3fn : vllm::kFE8M0fnu) \
|
||||
: (std::is_same<scalar_t, half>::value ? vllm::kFloat16 \
|
||||
: vllm::kBFloat16); \
|
||||
kernel = Marlin<scalar_t, W_TYPE.id(), S_TYPE.id(), NUM_THREADS, \
|
||||
THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
|
||||
M_BLOCK_SIZE_8, pipe_stages, GROUP_BLOCKS, IS_ZP_FLOAT>; \
|
||||
}
|
||||
|
||||
// COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false)
|
||||
// these are the most common cases
|
||||
// BIGGROUP: cases for big group size (group_blocks in [-1, 8])
|
||||
// FZP: cases for float-zero-point (is_zp_float = true)
|
||||
// ACT: cases for act order case (group_blocks == 0)
|
||||
// FP4: cases for nvfp4(e2m1) (group_blocks == 1)
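For orientation, a small reference table (assembled from the legend above and the macro bodies below; not part of the source) of which group_blocks values each dispatch family instantiates:

DISPATCH_FAMILIES = {
    "COMMON":   [-1, 2, 4, 8],  # used for kU4, kU4B8, kU8B128
    "BIGGROUP": [-1, 8],        # used for kFE4M3fn
    "NVFP4":    [1],            # kFE2M1f with fp8 scales
    "MXFP4":    [2],            # kFE2M1f with e8m0 scales
    "FZP":      [4],            # float zero point (HQQ), kU4 only
    "ACT":      [0],            # act-order, kU4B8 and kU8B128
}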
|
||||
#define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
|
||||
|
||||
#define COMMON_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
|
||||
\
|
||||
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
|
||||
\
|
||||
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
|
||||
|
||||
#define COMMON_GET_IF(W_TYPE) \
|
||||
COMMON_GET_IF_M1(W_TYPE, 8, 8, 256) \
|
||||
COMMON_GET_IF_M1(W_TYPE, 8, 4, 128) \
|
||||
COMMON_GET_IF_M1(W_TYPE, 4, 8, 128) \
|
||||
COMMON_GET_IF_M234(W_TYPE, 16, 4, 256) \
|
||||
COMMON_GET_IF_M234(W_TYPE, 8, 4, 128) \
|
||||
COMMON_GET_IF_M234(W_TYPE, 4, 8, 128)
|
||||
|
||||
#define BIGGROUP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
|
||||
|
||||
#define BIGGROUP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
|
||||
|
||||
#define BIGGROUP_GET_IF(W_TYPE) \
|
||||
BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \
|
||||
BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \
|
||||
BIGGROUP_GET_IF_M1(W_TYPE, 4, 8, 128) \
|
||||
BIGGROUP_GET_IF_M234(W_TYPE, 16, 4, 256) \
|
||||
BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128) \
|
||||
BIGGROUP_GET_IF_M234(W_TYPE, 4, 8, 128)
|
||||
|
||||
#define NVFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
|
||||
|
||||
#define NVFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
|
||||
|
||||
#define NVFP4_GET_IF(W_TYPE) \
|
||||
NVFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \
|
||||
NVFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \
|
||||
NVFP4_GET_IF_M1(W_TYPE, 4, 8, 128) \
|
||||
NVFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
|
||||
NVFP4_GET_IF_M234(W_TYPE, 8, 4, 128) \
|
||||
NVFP4_GET_IF_M234(W_TYPE, 4, 8, 128)
|
||||
|
||||
#define MXFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false)
|
||||
|
||||
#define MXFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false)
|
||||
|
||||
#define MXFP4_GET_IF(W_TYPE) \
|
||||
MXFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \
|
||||
MXFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \
|
||||
MXFP4_GET_IF_M1(W_TYPE, 4, 8, 128) \
|
||||
MXFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
|
||||
MXFP4_GET_IF_M234(W_TYPE, 8, 4, 128) \
|
||||
MXFP4_GET_IF_M234(W_TYPE, 4, 8, 128)
|
||||
|
||||
// We currently have 4-bit models only with group_blocks == 4
|
||||
#define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, true) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true)
|
||||
|
||||
#define FZP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \
|
||||
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \
|
||||
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true)
|
||||
|
||||
#define FZP_GET_IF(W_TYPE) \
|
||||
FZP_GET_IF_M1(W_TYPE, 8, 8, 256) \
|
||||
FZP_GET_IF_M1(W_TYPE, 8, 4, 128) \
|
||||
FZP_GET_IF_M1(W_TYPE, 4, 8, 128) \
|
||||
FZP_GET_IF_M234(W_TYPE, 16, 4, 256) \
|
||||
FZP_GET_IF_M234(W_TYPE, 8, 4, 128) \
|
||||
FZP_GET_IF_M234(W_TYPE, 4, 8, 128)
|
||||
|
||||
// ACT: cases for act order (group_blocks == 0)
|
||||
#define ACT_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false)
|
||||
|
||||
#define ACT_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false)
|
||||
|
||||
#define ACT_GET_IF(W_TYPE) \
|
||||
ACT_GET_IF_M1(W_TYPE, 8, 8, 256) \
|
||||
ACT_GET_IF_M1(W_TYPE, 8, 4, 128) \
|
||||
ACT_GET_IF_M1(W_TYPE, 4, 8, 128) \
|
||||
ACT_GET_IF_M234(W_TYPE, 16, 4, 256) \
|
||||
ACT_GET_IF_M234(W_TYPE, 8, 4, 128) \
|
||||
ACT_GET_IF_M234(W_TYPE, 4, 8, 128)
|
||||
|
||||
template <typename scalar_t>
|
||||
MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type,
|
||||
int thread_m_blocks, int thread_n_blocks,
|
||||
int thread_k_blocks, bool m_block_size_8,
|
||||
bool has_act_order, bool has_zp,
|
||||
int group_blocks, int num_threads,
|
||||
bool is_zp_float) {
|
||||
int num_bits = q_type.size_bits();
|
||||
MarlinFuncPtr get_marlin_kernel(
|
||||
const vllm::ScalarType a_type, const vllm::ScalarType b_type,
|
||||
const vllm::ScalarType c_type, const vllm::ScalarType s_type,
|
||||
int thread_m_blocks, int thread_n_blocks, int thread_k_blocks,
|
||||
bool m_block_size_8, bool has_act_order, bool has_zp, int group_blocks,
|
||||
int threads, bool is_zp_float) {
|
||||
int num_bits = b_type.size_bits();
|
||||
auto kernel = MarlinDefault;
|
||||
if (false) {
|
||||
}
|
||||
|
||||
COMMON_GET_IF(vllm::kU4)
|
||||
COMMON_GET_IF(vllm::kU4B8)
|
||||
COMMON_GET_IF(vllm::kU8B128)
|
||||
|
||||
NVFP4_GET_IF(vllm::kFE2M1f)
|
||||
|
||||
BIGGROUP_GET_IF(vllm::kFE4M3fn)
|
||||
|
||||
ACT_GET_IF(vllm::kU4B8)
|
||||
ACT_GET_IF(vllm::kU8B128)
|
||||
|
||||
if (std::is_same<scalar_t, half>::value) {
|
||||
if (false) {
|
||||
}
|
||||
FZP_GET_IF(vllm::kU4)
|
||||
}
|
||||
if (std::is_same<scalar_t, nv_bfloat16>::value) {
|
||||
if (false) {
|
||||
}
|
||||
MXFP4_GET_IF(vllm::kFE2M1f)
|
||||
}
|
||||
#include "kernel_selector.h"
|
||||
|
||||
return kernel;
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
|
||||
int prob_n, int prob_k, int thread_m_blocks,
|
||||
bool m_block_size_8, int num_bits,
|
||||
int group_size, bool has_act_order,
|
||||
bool is_k_full, bool has_zp,
|
||||
bool is_zp_float, int max_shared_mem,
|
||||
int sms) {
|
||||
exec_config_t determine_exec_config(
|
||||
const vllm::ScalarType& a_type, const vllm::ScalarType& b_type,
|
||||
const vllm::ScalarType& c_type, const vllm::ScalarType& s_type, int prob_m,
|
||||
int prob_n, int prob_k, int thread_m_blocks, bool m_block_size_8,
|
||||
int num_bits, int group_size, bool has_act_order, bool is_k_full,
|
||||
bool has_zp, bool is_zp_float, int max_shared_mem, int sms) {
|
||||
exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}};
|
||||
thread_config_t* thread_configs = thread_m_blocks > 1
|
||||
? large_batch_thread_configs
|
||||
@ -455,7 +280,7 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
|
||||
|
||||
if (!is_valid_config(th_config, thread_m_blocks, prob_m, prob_n, prob_k,
|
||||
num_bits, group_size, has_act_order, is_k_full, has_zp,
|
||||
is_zp_float, max_shared_mem)) {
|
||||
is_zp_float, max_shared_mem - 512)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
@ -468,10 +293,11 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
|
||||
group_blocks = group_size == -1 ? -1 : group_size / 16;
|
||||
}
|
||||
|
||||
auto kernel = get_marlin_kernel<scalar_t>(
|
||||
q_type, thread_m_blocks, th_config.thread_n / 16,
|
||||
th_config.thread_k / 16, m_block_size_8, has_act_order, has_zp,
|
||||
group_blocks, th_config.num_threads, is_zp_float);
|
||||
auto kernel =
|
||||
get_marlin_kernel(a_type, b_type, c_type, s_type, thread_m_blocks,
|
||||
th_config.thread_n / 16, th_config.thread_k / 16,
|
||||
m_block_size_8, has_act_order, has_zp, group_blocks,
|
||||
th_config.num_threads, is_zp_float);
|
||||
|
||||
if (kernel == MarlinDefault) continue;
|
||||
|
||||
@ -485,28 +311,16 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
|
||||
return exec_cfg;
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
void* s, void* s2, void* zp, void* g_idx, void* perm,
|
||||
void* a_tmp, int prob_m, int prob_n, int prob_k, int lda,
|
||||
void* workspace, vllm::ScalarType const& q_type, bool has_bias,
|
||||
void* a_s, void* b_s, void* g_s, void* zp, void* g_idx,
|
||||
void* perm, void* a_tmp, int prob_m, int prob_n, int prob_k,
|
||||
int lda, void* workspace, vllm::ScalarType const& a_type,
|
||||
vllm::ScalarType const& b_type, vllm::ScalarType const& c_type,
|
||||
vllm::ScalarType const& s_type, bool has_bias,
|
||||
bool has_act_order, bool is_k_full, bool has_zp, int num_groups,
|
||||
int group_size, int dev, cudaStream_t stream, int thread_k_init,
|
||||
int thread_n_init, int sms, bool use_atomic_add,
|
||||
bool use_fp32_reduce, bool is_zp_float) {
|
||||
if (has_zp) {
|
||||
TORCH_CHECK(
|
||||
q_type == vllm::kU4 || q_type == vllm::kU8,
|
||||
"q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str());
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
q_type == vllm::kU4B8 || q_type == vllm::kU8B128 ||
|
||||
q_type == vllm::kFE4M3fn || q_type == vllm::kFE2M1f,
|
||||
"q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when "
|
||||
"has_zp = False. Got = ",
|
||||
q_type.str());
|
||||
}
|
||||
|
||||
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
|
||||
", ", prob_n, ", ", prob_k, "]");
|
||||
|
||||
@ -531,19 +345,21 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
}
|
||||
}
|
||||
|
||||
int num_bits = q_type.size_bits();
|
||||
int num_bits = b_type.size_bits();
|
||||
const int4* A_ptr = (const int4*)A;
|
||||
const int4* B_ptr = (const int4*)B;
|
||||
int4* C_ptr = (int4*)C;
|
||||
int4* C_tmp_ptr = (int4*)C_tmp;
|
||||
|
||||
const int4* bias_ptr = (const int4*)b_bias;
|
||||
const int4* s_ptr = (const int4*)s;
|
||||
const uint16_t* s2_ptr = (const uint16_t*)s2;
|
||||
const float* a_s_ptr = (const float*)a_s;
|
||||
const int4* b_s_ptr = (const int4*)b_s;
|
||||
const uint16_t* g_s_ptr = (const uint16_t*)g_s;
|
||||
|
||||
const int4* zp_ptr = (const int4*)zp;
|
||||
const int* g_idx_ptr = (const int*)g_idx;
|
||||
const int* perm_ptr = (const int*)perm;
|
||||
int4* a_tmp_ptr = (int4*)a_tmp;
|
||||
|
||||
int* locks = (int*)workspace;
|
||||
|
||||
if (has_act_order) {
|
||||
@ -568,6 +384,21 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
|
||||
TORCH_CHECK(max_shared_mem > 0);
|
||||
|
||||
int major_capability, minor_capability;
|
||||
cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,
|
||||
dev);
|
||||
cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
|
||||
dev);
|
||||
TORCH_CHECK(major_capability * 10 + minor_capability >= 80,
|
||||
"marlin kernel only support Ampere or newer GPUs.");
|
||||
if (a_type == vllm::kFE4M3fn) {
|
||||
TORCH_CHECK(
|
||||
major_capability * 10 + minor_capability == 89 ||
|
||||
major_capability * 10 + minor_capability == 120,
|
||||
"Marlin W4A8-FP8 only support SM89 or SM120 device (It is slower than "
|
||||
"Marlin W4A16 on other devices).");
|
||||
}
|
||||
|
||||
int max_par = 16;
|
||||
if (prob_n <= 4096) max_par = 16 * 8;
|
||||
int max_shared_mem_new = max_shared_mem;
|
||||
@ -583,7 +414,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
int thread_n = thread_n_init;
|
||||
|
||||
int thread_m_blocks = min(div_ceil(prob_m_split, 16), max_thread_m_blocks);
|
||||
int m_block_size_8 = prob_m_split <= 8;
|
||||
int m_block_size_8 = prob_m_split <= 8 && a_type.size_bits() == 16;
|
||||
|
||||
// Set thread config
|
||||
exec_config_t exec_cfg;
|
||||
@ -597,11 +428,25 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
" is not divisible by thread_k = ", thread_k);
|
||||
} else {
|
||||
// Auto config
|
||||
exec_cfg = determine_exec_config<scalar_t>(
|
||||
q_type, prob_m_split, prob_n, prob_k, thread_m_blocks, m_block_size_8,
|
||||
num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float,
|
||||
max_shared_mem, sms);
|
||||
exec_cfg = determine_exec_config(
|
||||
a_type, b_type, c_type, s_type, prob_m_split, prob_n, prob_k,
|
||||
thread_m_blocks, m_block_size_8, num_bits, group_size, has_act_order,
|
||||
is_k_full, has_zp, is_zp_float, max_shared_mem, sms);
|
||||
thread_tfg = exec_cfg.tb_cfg;
|
||||
if (thread_tfg.thread_n != -1) {
|
||||
if (prob_n / thread_tfg.thread_n *
|
||||
div_ceil(prob_m_split, thread_m_blocks * 16) * 4 <=
|
||||
sms) {
|
||||
if (is_valid_config({128, 64, 128}, thread_m_blocks, prob_m_split,
|
||||
prob_n, prob_k, num_bits, group_size,
|
||||
has_act_order, is_k_full, has_zp, is_zp_float,
|
||||
max_shared_mem_new)) {
|
||||
thread_tfg = {128, 64, 128};
|
||||
exec_cfg = {1, thread_tfg};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (thread_tfg.thread_k == -1 && max_thread_m_blocks > 1) {
|
||||
max_thread_m_blocks--;
|
||||
continue;
|
||||
@ -632,10 +477,10 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
", has_zp = ", has_zp, ", is_zp_float = ", is_zp_float,
|
||||
", max_shared_mem_new = ", max_shared_mem_new);
|
||||
|
||||
auto kernel = get_marlin_kernel<scalar_t>(
|
||||
q_type, thread_m_blocks, thread_n_blocks, thread_k_blocks,
|
||||
m_block_size_8, has_act_order, has_zp, group_blocks, num_threads,
|
||||
is_zp_float);
|
||||
auto kernel = get_marlin_kernel(
|
||||
a_type, b_type, c_type, s_type, thread_m_blocks, thread_n_blocks,
|
||||
thread_k_blocks, m_block_size_8, has_act_order, has_zp, group_blocks,
|
||||
num_threads, is_zp_float);
|
||||
|
||||
if (kernel == MarlinDefault) {
|
||||
TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n,
|
||||
@ -657,13 +502,15 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
// avoid ">>>" being formatted to "> > >"
|
||||
// clang-format off
|
||||
kernel<<<blocks, num_threads, max_shared_mem_new, stream>>>(
|
||||
A_ptr, B_ptr, C_ptr, C_tmp_ptr, bias_ptr, s_ptr, s2_ptr, zp_ptr,
|
||||
A_ptr, B_ptr, C_ptr, C_tmp_ptr, bias_ptr, a_s_ptr, b_s_ptr, g_s_ptr, zp_ptr,
|
||||
g_idx_ptr, num_groups,
|
||||
prob_m_split, prob_n, prob_k, lda, locks, has_bias, part_use_atomic_add,
|
||||
use_fp32_reduce, max_shared_mem_new);
|
||||
// clang-format on
|
||||
|
||||
A_ptr += prob_m_split * (lda / 8);
|
||||
bool is_a_8bit = a_type.size_bits() == 8;
|
||||
A_ptr += prob_m_split * (lda / (is_a_8bit ? 16 : 8));
|
||||
a_s_ptr += prob_m_split;
|
||||
C_ptr += prob_m_split * (prob_n / 8);
|
||||
rest_m -= prob_m_split;
|
||||
}
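A quick Python check (not from the diff; the helper name is made up) of the int4-typed pointer advance above: A_ptr moves by rows * lda elements expressed in 16-byte int4 units, so the effective divisor is 16 for 1-byte (fp8/int8) activations and 8 for 2-byte (fp16/bf16) activations.

def a_ptr_advance_in_int4(rows: int, lda: int, elem_bytes: int) -> int:
    # number of 16-byte int4 units to skip for `rows` rows of A
    return rows * lda // (16 // elem_bytes)

assert a_ptr_advance_in_int4(4, 128, 1) == 4 * (128 // 16)  # 8-bit activations
assert a_ptr_advance_in_int4(4, 128, 2) == 4 * (128 // 8)   # 16-bit activations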
|
||||
@ -675,15 +522,73 @@ torch::Tensor gptq_marlin_gemm(
|
||||
torch::Tensor& a, std::optional<torch::Tensor> c_or_none,
|
||||
torch::Tensor& b_q_weight,
|
||||
std::optional<torch::Tensor> const& b_bias_or_none, torch::Tensor& b_scales,
|
||||
std::optional<torch::Tensor> const& a_scales_or_none,
|
||||
std::optional<torch::Tensor> const& global_scale_or_none,
|
||||
std::optional<torch::Tensor> const& b_zeros_or_none,
|
||||
std::optional<torch::Tensor> const& g_idx_or_none,
|
||||
std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace,
|
||||
vllm::ScalarTypeId const& b_q_type_id, int64_t size_m, int64_t size_n,
|
||||
vllm::ScalarTypeId const& b_type_id, int64_t size_m, int64_t size_n,
|
||||
int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce,
|
||||
bool is_zp_float) {
|
||||
vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id);
|
||||
int pack_factor = 32 / b_q_type.size_bits();
|
||||
vllm::ScalarTypeId a_type_id, c_type_id, s_type_id;
|
||||
|
||||
auto c_dtype = a.dtype();
|
||||
if (a.scalar_type() == at::ScalarType::Half) {
|
||||
a_type_id = vllm::kFloat16.id();
|
||||
c_type_id = vllm::kFloat16.id();
|
||||
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
|
||||
a_type_id = vllm::kBFloat16.id();
|
||||
c_type_id = vllm::kBFloat16.id();
|
||||
} else {
|
||||
c_dtype = b_scales.dtype();
|
||||
if (b_scales.scalar_type() == at::ScalarType::Half) {
|
||||
c_type_id = vllm::kFloat16.id();
|
||||
} else if (b_scales.scalar_type() == at::ScalarType::BFloat16) {
|
||||
c_type_id = vllm::kBFloat16.id();
|
||||
} else {
|
||||
c_type_id = vllm::kBFloat16.id();
|
||||
|
||||
TORCH_CHECK(c_or_none.has_value(), "c must be passed for W4A8-FP4");
|
||||
torch::Tensor c = c_or_none.value();
|
||||
c_dtype = c.dtype();
|
||||
|
||||
if (c.scalar_type() == at::ScalarType::Half) {
|
||||
c_type_id = vllm::kFloat16.id();
|
||||
} else if (c.scalar_type() == at::ScalarType::BFloat16) {
|
||||
c_type_id = vllm::kBFloat16.id();
|
||||
} else {
|
||||
TORCH_CHECK(false, "unsupported c dtype");
|
||||
}
|
||||
}
|
||||
|
||||
if (a.scalar_type() == at::ScalarType::Float8_e4m3fn) {
|
||||
a_type_id = vllm::kFE4M3fn.id();
|
||||
} else if (a.scalar_type() == at::ScalarType::Char) {
|
||||
a_type_id = vllm::kS8.id();
|
||||
} else {
|
||||
TORCH_CHECK(false, "unsupported `a` scalar_type");
|
||||
}
|
||||
}
|
||||
|
||||
s_type_id = c_type_id;
|
||||
if (b_type_id == vllm::kFE2M1f.id()) {
|
||||
if (b_scales.scalar_type() == at::ScalarType::Float8_e4m3fn) {
|
||||
s_type_id = vllm::kFE4M3fn.id();
|
||||
} else if (b_scales.scalar_type() == at::ScalarType::Float8_e8m0fnu) {
|
||||
s_type_id = vllm::kFE8M0fnu.id();
|
||||
} else {
|
||||
TORCH_CHECK(false,
|
||||
"When b_type = float4_e2m1f, b_scale scalar type must be",
|
||||
"float8_e4m3fn (for NVFP4) or float8_e8m0fnu (for MXFP4).");
|
||||
}
|
||||
}
|
||||
|
||||
vllm::ScalarType a_type = vllm::ScalarType::from_id(a_type_id);
|
||||
vllm::ScalarType b_type = vllm::ScalarType::from_id(b_type_id);
|
||||
vllm::ScalarType c_type = vllm::ScalarType::from_id(c_type_id);
|
||||
vllm::ScalarType s_type = vllm::ScalarType::from_id(s_type_id);
|
||||
|
||||
int pack_factor = 32 / b_type.size_bits();
|
||||
|
||||
  // Verify A
  TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0),
@ -721,6 +626,21 @@ torch::Tensor gptq_marlin_gemm(
  TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU");
  TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");

  torch::Tensor a_scales;
  auto options = torch::TensorOptions().dtype(c_dtype).device(a.device());
  auto options_fp32 =
      torch::TensorOptions().dtype(at::kFloat).device(a.device());

  if (a_scales_or_none.has_value()) {
    a_scales = a_scales_or_none.value();
    TORCH_CHECK(a_type.size_bits() == 8,
                "a_scales can only be used for 8bit activation.");
  } else {
    a_scales = torch::empty({0}, options_fp32);
    TORCH_CHECK(a_type.size_bits() != 8,
                "the a_scales parameter must be passed for 8bit activation.");
  }

// thread_k: `k` size of a thread_tile in `weights` (can usually be left as
|
||||
// auto -1)
|
||||
int thread_k = -1;
|
||||
@ -733,7 +653,6 @@ torch::Tensor gptq_marlin_gemm(
|
||||
|
||||
// Alloc buffers
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
|
||||
auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
|
||||
torch::Tensor c;
|
||||
if (c_or_none.has_value()) {
|
||||
c = c_or_none.value();
|
||||
@ -750,8 +669,6 @@ torch::Tensor gptq_marlin_gemm(
|
||||
|
||||
// Alloc C tmp buffer that is going to be used for the global reduce
|
||||
torch::Tensor c_tmp;
|
||||
auto options_fp32 =
|
||||
torch::TensorOptions().dtype(at::kFloat).device(a.device());
|
||||
if (use_fp32_reduce) {
|
||||
int max_m_block_size = (size_m + 16 - 1) / 16 * 16;
|
||||
max_m_block_size = min(max_m_block_size, 64);
|
||||
@ -821,11 +738,11 @@ torch::Tensor gptq_marlin_gemm(
|
||||
torch::Tensor global_scale;
|
||||
if (global_scale_or_none.has_value()) {
|
||||
global_scale = global_scale_or_none.value();
|
||||
TORCH_CHECK(b_q_type == vllm::kFE2M1f && group_size == 16,
|
||||
TORCH_CHECK(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn,
|
||||
"global_scale can only be used for nvfp4 format.");
|
||||
} else {
|
||||
global_scale = torch::empty({0}, options);
|
||||
TORCH_CHECK(!(b_q_type == vllm::kFE2M1f && group_size == 16),
|
||||
TORCH_CHECK(!(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn),
|
||||
"the global_scale parameter must be passed for nvfp4 format.");
|
||||
}
|
||||
|
||||
@ -852,15 +769,15 @@ torch::Tensor gptq_marlin_gemm(
|
||||
bool has_zp = b_zeros.size(-1) > 0;
|
||||
if (has_zp) {
|
||||
TORCH_CHECK(
|
||||
b_q_type == vllm::kU4 || b_q_type == vllm::kU8,
|
||||
"b_q_type must be u4 or u8 when has_zp = True. Got = ", b_q_type.str());
|
||||
b_type == vllm::kU4 || b_type == vllm::kU8,
|
||||
"b_type must be u4 or u8 when has_zp = True. Got = ", b_type.str());
|
||||
} else {
|
||||
TORCH_CHECK(b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128 ||
|
||||
b_q_type == vllm::kFE4M3fn || b_q_type == vllm::kFE2M1f,
|
||||
"b_q_type must be uint4b8, uint8b128, float8_e4m3fn or "
|
||||
"float4_e2m1f when "
|
||||
"has_zp = False. Got = ",
|
||||
b_q_type.str());
|
||||
TORCH_CHECK(b_type == vllm::kU4B8 || b_type == vllm::kU8B128 ||
|
||||
b_type == vllm::kS4 || b_type == vllm::kS8 ||
|
||||
b_type == vllm::kFE4M3fn || b_type == vllm::kFE2M1f,
|
||||
"b_type must be uint4b8, uint8b128, int4, int8, "
|
||||
"float8_e4m3fn or float4_e2m1f when has_zp = False. Got = ",
|
||||
b_type.str());
|
||||
}
|
||||
|
||||
if (has_zp && is_zp_float) {
|
||||
@ -902,59 +819,27 @@ torch::Tensor gptq_marlin_gemm(
|
||||
" is below min_workspace_size = ", min_workspace_size);
|
||||
|
||||
int dev = a.get_device();
|
||||
if (a.scalar_type() == at::ScalarType::Half) {
|
||||
void* scales_ptr;
|
||||
if (b_q_type == vllm::kFE2M1f) {
|
||||
if (group_size == 16)
|
||||
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
|
||||
else if (group_size == 32)
|
||||
scales_ptr = b_scales.data_ptr<at::Float8_e8m0fnu>();
|
||||
else
|
||||
TORCH_CHECK(false,
|
||||
"float4_e2m1f only supports group_size == 16 (NVFP4) ",
|
||||
"and group_size == 32 (MXFP4)");
|
||||
} else {
|
||||
scales_ptr = b_scales.data_ptr<at::Half>();
|
||||
}
|
||||
|
||||
marlin::marlin_mm<half>(
|
||||
a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
|
||||
c_tmp.data_ptr<float>(), b_bias.data_ptr<at::Half>(), scales_ptr,
|
||||
global_scale.data_ptr<at::Half>(), b_zeros.data_ptr(), g_idx.data_ptr(),
|
||||
perm.data_ptr(), a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,
|
||||
a.stride(0), workspace.data_ptr(), b_q_type, has_bias, has_act_order,
|
||||
is_k_full, has_zp, num_groups, group_size, dev,
|
||||
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
|
||||
use_atomic_add, use_fp32_reduce, is_zp_float);
|
||||
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
|
||||
void* scales_ptr;
|
||||
if (b_q_type == vllm::kFE2M1f) {
|
||||
if (group_size == 16)
|
||||
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
|
||||
else if (group_size == 32)
|
||||
scales_ptr = b_scales.data_ptr<at::Float8_e8m0fnu>();
|
||||
else
|
||||
TORCH_CHECK(false,
|
||||
"float4_e2m1f only supports group_size == 16 (NVFP4) ",
|
||||
"and group_size == 32 (MXFP4)");
|
||||
} else {
|
||||
scales_ptr = b_scales.data_ptr<at::BFloat16>();
|
||||
}
|
||||
|
||||
marlin::marlin_mm<nv_bfloat16>(
|
||||
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
|
||||
c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(),
|
||||
b_bias.data_ptr<at::BFloat16>(), scales_ptr,
|
||||
global_scale.data_ptr<at::BFloat16>(), b_zeros.data_ptr(),
|
||||
g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(),
|
||||
size_m, size_n, size_k, a.stride(0), workspace.data_ptr(), b_q_type,
|
||||
has_bias, has_act_order, is_k_full, has_zp, num_groups, group_size, dev,
|
||||
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
|
||||
use_atomic_add, use_fp32_reduce, is_zp_float);
|
||||
} else {
|
||||
TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16");
|
||||
TORCH_CHECK(a_scales.scalar_type() == at::ScalarType::Float,
|
||||
"scalar type of a_scales must be float");
|
||||
TORCH_CHECK(global_scale.scalar_type() == c.scalar_type(),
|
||||
"scalar type of global_scale must be the same with c");
|
||||
if (a_type.size_bits() == 16) {
|
||||
TORCH_CHECK(
|
||||
a.scalar_type() == c.scalar_type(),
|
||||
"scalar type of a must be the same with c for 16 bit activation");
|
||||
}
|
||||
|
||||
marlin::marlin_mm(
|
||||
a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), c_tmp.data_ptr(),
|
||||
b_bias.data_ptr(), a_scales.data_ptr(), b_scales.data_ptr(),
|
||||
global_scale.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(),
|
||||
perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, size_k, a.stride(0),
|
||||
workspace.data_ptr(), a_type, b_type, c_type, s_type, has_bias,
|
||||
has_act_order, is_k_full, has_zp, num_groups, group_size, dev,
|
||||
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
|
||||
use_atomic_add, use_fp32_reduce, is_zp_float);
|
||||
|
||||
return c;
|
||||
}
|
||||
|
||||
|
||||
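With the unified dispatch above, 8-bit activations (int8 or fp8) must arrive pre-quantized together with per-token `a_scales`, while 16-bit activations must leave `a_scales` unset. A minimal caller-side sketch, using the helpers that the tests in this commit import (`per_token_quant_int8`, `ops.scaled_fp8_quant`); shapes and dtypes are illustrative only:

```python
import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.int8_utils import (
    per_token_quant_int8,
)

a = torch.randn(16, 4096, dtype=torch.float16, device="cuda")

# W4A8-INT8: per-token symmetric int8 quantization with float32 scales.
a_int8, a_scales_int8 = per_token_quant_int8(a)

# W4A8-FP8: dynamic per-token fp8 quantization.
a_fp8, a_scales_fp8 = ops.scaled_fp8_quant(a, use_per_token_if_dynamic=True)

# The quantized activation and its scales are then handed to gptq_marlin_gemm
# as (a, a_scales); for fp16/bf16 activations a_scales stays None, matching
# the TORCH_CHECKs in the entry point above.
```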
@ -4,15 +4,18 @@
|
||||
|
||||
namespace marlin {
|
||||
|
||||
template <int const num_threads, int const num_bits, bool const has_perm>
|
||||
template <int const num_threads, int const num_bits, bool const has_perm,
|
||||
bool is_a_8bit>
|
||||
__global__ void gptq_marlin_repack_kernel(
|
||||
uint32_t const* __restrict__ b_q_weight_ptr,
|
||||
uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr,
|
||||
int size_k, int size_n) {
|
||||
constexpr int pack_factor = 32 / num_bits;
|
||||
|
||||
int k_tiles = size_k / tile_k_size;
|
||||
int n_tiles = size_n / tile_n_size;
|
||||
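  // With 8-bit activations (is_a_8bit) the repack tile is reshaped: it spans
  // half as many columns (tile_n_size / 2) and twice as many rows
  // (tile_k_size * 2), and perm_size / tile_ints / the staging thread counts
  // below are all derived from these target_tile_* sizes.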
constexpr int target_tile_n_size = tile_n_size / (is_a_8bit ? 2 : 1);
|
||||
constexpr int target_tile_k_size = tile_k_size * (is_a_8bit ? 2 : 1);
|
||||
int k_tiles = size_k / target_tile_k_size;
|
||||
int n_tiles = size_n / target_tile_n_size;
|
||||
int block_k_tiles = div_ceil(k_tiles, gridDim.x);
|
||||
|
||||
auto start_k_tile = blockIdx.x * block_k_tiles;
|
||||
@ -34,7 +37,7 @@ __global__ void gptq_marlin_repack_kernel(
|
||||
|
||||
extern __shared__ int4 sh[];
|
||||
|
||||
constexpr int perm_size = tile_k_size / 4;
|
||||
constexpr int perm_size = target_tile_k_size / 4;
|
||||
|
||||
int4* sh_perm_ptr = sh;
|
||||
int4* sh_pipe_ptr = sh_perm_ptr;
|
||||
@ -42,14 +45,14 @@ __global__ void gptq_marlin_repack_kernel(
|
||||
sh_pipe_ptr += perm_size;
|
||||
}
|
||||
|
||||
constexpr int tile_ints = tile_k_size / pack_factor;
|
||||
constexpr int tile_ints = target_tile_k_size / pack_factor;
|
||||
|
||||
constexpr int stage_n_threads = tile_n_size / 4;
|
||||
constexpr int stage_k_threads = has_perm ? tile_k_size : tile_ints;
|
||||
constexpr int stage_n_threads = target_tile_n_size / 4;
|
||||
constexpr int stage_k_threads = has_perm ? target_tile_k_size : tile_ints;
|
||||
constexpr int stage_size = stage_k_threads * stage_n_threads;
|
||||
|
||||
auto load_perm_to_shared = [&](int k_tile_id) {
|
||||
int first_k_int4 = (k_tile_id * tile_k_size) / 4;
|
||||
int first_k_int4 = (k_tile_id * target_tile_k_size) / 4;
|
||||
|
||||
int4 const* perm_int4_ptr = reinterpret_cast<int4 const*>(perm_ptr);
|
||||
|
||||
@ -65,7 +68,7 @@ __global__ void gptq_marlin_repack_kernel(
|
||||
return;
|
||||
}
|
||||
|
||||
int first_n = n_tile_id * tile_n_size;
|
||||
int first_n = n_tile_id * target_tile_n_size;
|
||||
|
||||
int4* sh_ptr = sh_pipe_ptr + stage_size * pipe;
|
||||
|
||||
@ -91,7 +94,7 @@ __global__ void gptq_marlin_repack_kernel(
|
||||
auto k_id = threadIdx.x / stage_n_threads;
|
||||
auto n_id = threadIdx.x % stage_n_threads;
|
||||
|
||||
int first_k = k_tile_id * tile_k_size;
|
||||
int first_k = k_tile_id * target_tile_k_size;
|
||||
int first_k_packed = first_k / pack_factor;
|
||||
|
||||
cp_async4(&sh_ptr[k_id * stage_n_threads + n_id],
|
||||
@ -117,13 +120,13 @@ __global__ void gptq_marlin_repack_kernel(
|
||||
}
|
||||
|
||||
int tc_col = th_id / 4;
|
||||
int tc_row = (th_id % 4) * 2;
|
||||
int tc_row = (th_id % 4) * (is_a_8bit ? 4 : 2);
|
||||
|
||||
constexpr int tc_offsets[4] = {0, 1, 8, 9};
|
||||
|
||||
int cur_n = warp_id * 16 + tc_col;
|
||||
int cur_n = (warp_id / (is_a_8bit ? 2 : 1)) * 16 + tc_col;
|
||||
|
||||
constexpr int sh_stride = 64;
|
||||
constexpr int sh_stride = target_tile_n_size;
|
||||
constexpr uint32_t mask = (1 << num_bits) - 1;
|
||||
|
||||
int4* sh_stage_ptr = sh_pipe_ptr + stage_size * pipe;
|
||||
@ -134,6 +137,7 @@ __global__ void gptq_marlin_repack_kernel(
|
||||
uint32_t vals[8];
|
||||
|
||||
if constexpr (has_perm) {
|
||||
static_assert(!is_a_8bit);
|
||||
for (int i = 0; i < 4; i++) {
|
||||
int k_idx = tc_row + tc_offsets[i];
|
||||
|
||||
@ -156,28 +160,49 @@ __global__ void gptq_marlin_repack_kernel(
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < tile_ints; i++) {
|
||||
b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i];
|
||||
b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i];
|
||||
if constexpr (is_a_8bit) {
|
||||
b1_vals[i] =
|
||||
sh_stage_int_ptr[cur_n + sh_stride * i + (warp_id % 2) * 8];
|
||||
} else {
|
||||
b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i];
|
||||
b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i];
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
int cur_elem = tc_row + tc_offsets[i];
|
||||
int cur_elem = tc_row + (is_a_8bit ? i : tc_offsets[i]);
|
||||
int cur_int = cur_elem / pack_factor;
|
||||
int cur_pos = cur_elem % pack_factor;
|
||||
|
||||
vals[i] = (b1_vals[cur_int] >> (cur_pos * num_bits)) & mask;
|
||||
vals[4 + i] = (b2_vals[cur_int] >> (cur_pos * num_bits)) & mask;
|
||||
if constexpr (is_a_8bit)
|
||||
vals[4 + i] =
|
||||
(b1_vals[cur_int + tile_ints / 2] >> (cur_pos * num_bits)) & mask;
|
||||
else
|
||||
vals[4 + i] = (b2_vals[cur_int] >> (cur_pos * num_bits)) & mask;
|
||||
}
|
||||
}
|
||||
|
||||
constexpr int tile_size = tile_k_size * tile_n_size / pack_factor;
|
||||
constexpr int tile_size =
|
||||
target_tile_k_size * target_tile_n_size / pack_factor;
|
||||
int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;
|
||||
|
||||
// Result of:
|
||||
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
|
||||
if constexpr (num_bits == 4) {
|
||||
constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
|
||||
if constexpr (!is_a_8bit && num_bits == 4) {
|
||||
int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
|
||||
|
||||
uint32_t res = 0;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++) {
|
||||
res |= vals[pack_idx[i]] << (i * 4);
|
||||
}
|
||||
|
||||
out_ptr[out_offset + th_id * 4 + warp_id] = res;
|
||||
|
||||
} else if constexpr (is_a_8bit && num_bits == 4) {
|
||||
int pack_idx[8] = {0, 4, 1, 5, 2, 6, 3, 7};
|
||||
|
||||
uint32_t res = 0;
|
||||
#pragma unroll
|
||||
@ -194,8 +219,9 @@ __global__ void gptq_marlin_repack_kernel(
|
||||
uint32_t res2 = 0;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
res1 |= vals[pack_idx[i]] << (i * 8);
|
||||
res2 |= vals[4 + pack_idx[i]] << (i * 8);
|
||||
const int ii = is_a_8bit ? i : pack_idx[i];
|
||||
res1 |= vals[ii] << (i * 8);
|
||||
res2 |= vals[4 + ii] << (i * 8);
|
||||
}
|
||||
|
||||
out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1;
|
||||
@ -236,21 +262,22 @@ __global__ void gptq_marlin_repack_kernel(
|
||||
|
||||
} // namespace marlin
|
||||
|
||||
#define CALL_IF(NUM_BITS, HAS_PERM) \
|
||||
else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \
|
||||
#define CALL_IF(NUM_BITS, HAS_PERM, IS_A_8BIT) \
|
||||
else if (num_bits == NUM_BITS && has_perm == HAS_PERM && \
|
||||
is_a_8bit == IS_A_8BIT) { \
|
||||
cudaFuncSetAttribute( \
|
||||
marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
|
||||
HAS_PERM>, \
|
||||
HAS_PERM, IS_A_8BIT>, \
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
|
||||
marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
|
||||
HAS_PERM> \
|
||||
HAS_PERM, IS_A_8BIT> \
|
||||
<<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \
|
||||
b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \
|
||||
}
|
||||
|
||||
torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
|
||||
int64_t size_k, int64_t size_n,
|
||||
int64_t num_bits) {
|
||||
int64_t num_bits, bool is_a_8bit) {
|
||||
// Verify compatibility with marlin tile of 16x64
|
||||
TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k,
|
||||
" is not divisible by tile_k_size = ", marlin::tile_k_size);
|
||||
@ -309,13 +336,17 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
|
||||
|
||||
if (false) {
|
||||
}
|
||||
CALL_IF(4, false)
|
||||
CALL_IF(4, true)
|
||||
CALL_IF(8, false)
|
||||
CALL_IF(8, true)
|
||||
CALL_IF(4, false, false)
|
||||
CALL_IF(4, true, false)
|
||||
CALL_IF(8, false, false)
|
||||
CALL_IF(8, true, false)
|
||||
|
||||
CALL_IF(4, false, true)
|
||||
CALL_IF(8, false, true)
|
||||
|
||||
else {
|
||||
TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits,
|
||||
", has_perm = ", has_perm);
|
||||
", has_perm = ", has_perm, ", is_a_8bit = ", is_a_8bit);
|
||||
}
|
||||
|
||||
return out;
|
||||
|
||||
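Both repack entry points now carry an `is_a_8bit` flag selecting the 8-bit-activation tile layout. A hedged sketch of the Python-side call, mirroring how the updated tests invoke it (weight contents and shapes here are placeholders):

```python
import torch

from vllm import _custom_ops as ops

size_k, size_n, num_bits = 512, 1024, 4
pack_factor = 32 // num_bits

# Placeholder GPTQ-packed weight: int32 words holding `pack_factor` 4-bit values.
q_w_gptq = torch.randint(
    0, 2**31 - 1, (size_k // pack_factor, size_n), dtype=torch.int32, device="cuda"
)
# No act_order in this sketch, so the permutation tensor is empty.
sort_indices = torch.empty(0, dtype=torch.int32, device="cuda")

# is_a_8bit=True requests the layout used by the fp8/int8-activation kernels.
marlin_q_w = ops.gptq_marlin_repack(
    q_w_gptq, sort_indices, size_k, size_n, num_bits, True
)
```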
@ -11,17 +11,19 @@
|
||||
const int4 *__restrict__ A, const int4 *__restrict__ B, \
|
||||
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
|
||||
const int4 *__restrict__ b_bias_ptr, \
|
||||
const float *__restrict__ a_scales_ptr, \
|
||||
const int4 *__restrict__ scales_ptr, \
|
||||
const uint16_t *__restrict__ scale2_ptr, \
|
||||
const uint16_t *__restrict__ global_scale_ptr, \
|
||||
const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
|
||||
int num_groups, int prob_m, int prob_n, int prob_k, int lda, int *locks, \
|
||||
bool has_bias, bool use_atomic_add, bool use_fp32_reduce, \
|
||||
int max_shared_mem
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
template <typename scalar_t, // compute dtype, half or nv_float16
|
||||
const vllm::ScalarTypeId w_type_id, // weight ScalarType id
|
||||
const vllm::ScalarTypeId s_type_id, // weight ScalarType id
|
||||
template <const vllm::ScalarTypeId a_type_id, // A ScalarType id
|
||||
const vllm::ScalarTypeId b_type_id, // B ScalarType id
|
||||
const vllm::ScalarTypeId c_type_id, // C ScalarType id
|
||||
const vllm::ScalarTypeId s_type_id, // B_SCALE ScalarType id
|
||||
const int threads, // number of threads in a threadblock
|
||||
const int thread_m_blocks, // number of 16x16 blocks in the m
|
||||
// dimension (batchsize) of the
|
||||
|
||||
@ -55,6 +55,45 @@ constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }
|
||||
// No support for async
|
||||
#else
|
||||
|
||||
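// Predicated cp.async.ca helpers: 4-, 8- and 16-byte global->shared copies
// that are issued only when `pred` is true (analogous to cp_async4_pred below).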
__device__ inline void cp_async1_ca_pred(void* smem_ptr, const void* glob_ptr,
|
||||
bool pred = true) {
|
||||
const int BYTES = 4;
|
||||
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" .reg .pred p;\n"
|
||||
" setp.ne.b32 p, %0, 0;\n"
|
||||
" @p cp.async.ca.shared.global [%1], [%2], %3;\n"
|
||||
"}\n" ::"r"((int)pred),
|
||||
"r"(smem), "l"(glob_ptr), "n"(BYTES));
|
||||
}
|
||||
|
||||
__device__ inline void cp_async2_ca_pred(void* smem_ptr, const void* glob_ptr,
|
||||
bool pred = true) {
|
||||
const int BYTES = 8;
|
||||
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" .reg .pred p;\n"
|
||||
" setp.ne.b32 p, %0, 0;\n"
|
||||
" @p cp.async.ca.shared.global [%1], [%2], %3;\n"
|
||||
"}\n" ::"r"((int)pred),
|
||||
"r"(smem), "l"(glob_ptr), "n"(BYTES));
|
||||
}
|
||||
|
||||
__device__ inline void cp_async4_ca_pred(void* smem_ptr, const void* glob_ptr,
|
||||
bool pred = true) {
|
||||
const int BYTES = 16;
|
||||
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" .reg .pred p;\n"
|
||||
" setp.ne.b32 p, %0, 0;\n"
|
||||
" @p cp.async.ca.shared.global [%1], [%2], %3;\n"
|
||||
"}\n" ::"r"((int)pred),
|
||||
"r"(smem), "l"(glob_ptr), "n"(BYTES));
|
||||
}
|
||||
|
||||
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
|
||||
bool pred = true) {
|
||||
const int BYTES = 16;
|
||||
|
||||
@ -2,8 +2,10 @@
|
||||
#ifndef _data_types_cuh
|
||||
#define _data_types_cuh
|
||||
#include "marlin.cuh"
|
||||
#include "core/scalar_type.hpp"
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp8.h>
|
||||
|
||||
#ifndef MARLIN_NAMESPACE_NAME
|
||||
#define MARLIN_NAMESPACE_NAME marlin
|
||||
@ -11,14 +13,16 @@
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
template <typename scalar_t>
|
||||
class ScalarType {};
|
||||
template <long scalar_type_id>
|
||||
class MarlinScalarType {};
|
||||
|
||||
template <>
|
||||
class ScalarType<half> {
|
||||
class MarlinScalarType<vllm::kFloat16.id()> {
|
||||
public:
|
||||
using scalar_t = half;
|
||||
using scalar_t2 = half2;
|
||||
using scalar_t4 = half2;
|
||||
using scalar_32bit_t = half2;
|
||||
|
||||
// Matrix fragments for tensor core instructions; their precise layout is
|
||||
// documented here:
|
||||
@ -27,6 +31,7 @@ class ScalarType<half> {
|
||||
using FragB = Vec<half2, 2>;
|
||||
using FragC = Vec<float, 4>;
|
||||
using FragS = Vec<half2, 1>;
|
||||
using FragS0 = Vec<__nv_fp8x2_e4m3, 1>;
|
||||
using FragZP = Vec<half2, 4>;
|
||||
|
||||
static __device__ float inline num2float(const half x) {
|
||||
@ -44,18 +49,25 @@ class ScalarType<half> {
|
||||
static __host__ __device__ half inline float2num(const float x) {
|
||||
return __float2half(x);
|
||||
}
|
||||
|
||||
static __host__ __device__ float2 inline num22float2(const half2 x) {
|
||||
return __half22float2(x);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
class ScalarType<nv_bfloat16> {
|
||||
class MarlinScalarType<vllm::kBFloat16.id()> {
|
||||
public:
|
||||
using scalar_t = nv_bfloat16;
|
||||
using scalar_t2 = nv_bfloat162;
|
||||
using scalar_t4 = nv_bfloat162;
|
||||
using scalar_32bit_t = nv_bfloat162;
|
||||
|
||||
using FragA = Vec<nv_bfloat162, 4>;
|
||||
using FragB = Vec<nv_bfloat162, 2>;
|
||||
using FragC = Vec<float, 4>;
|
||||
using FragS = Vec<nv_bfloat162, 1>;
|
||||
using FragS0 = Vec<__nv_fp8x2_e4m3, 1>;
|
||||
using FragZP = Vec<nv_bfloat162, 4>;
|
||||
|
||||
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
|
||||
@ -75,9 +87,63 @@ class ScalarType<nv_bfloat16> {
|
||||
static __host__ __device__ nv_bfloat16 inline float2num(const float x) {
|
||||
return __float2bfloat16(x);
|
||||
}
|
||||
|
||||
static __host__ __device__ float2 inline num22float2(const nv_bfloat162 x) {
|
||||
return __bfloat1622float2(x);
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
template <>
|
||||
class MarlinScalarType<vllm::kFE4M3fn.id()> {
|
||||
public:
|
||||
using scalar_t = __nv_fp8_e4m3;
|
||||
using scalar_t2 = __nv_fp8x2_e4m3;
|
||||
using scalar_t4 = __nv_fp8x4_e4m3;
|
||||
using scalar_32bit_t = __nv_fp8x4_e4m3;
|
||||
|
||||
using FragA = Vec<__nv_fp8x4_e4m3, 4>;
|
||||
using FragB = Vec<__nv_fp8x4_e4m3, 2>;
|
||||
using FragC = Vec<float, 4>;
|
||||
using FragZP = Vec<__nv_fp8x2_e4m3, 4>;
|
||||
|
||||
static __host__ __device__
|
||||
float2 inline num22float2(const __nv_fp8x2_e4m3 x) {
|
||||
return (float2)x;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
class MarlinScalarType<vllm::kS8.id()> {
|
||||
public:
|
||||
using scalar_t = int8_t;
|
||||
using scalar_t2 = int16_t;
|
||||
using scalar_t4 = int32_t;
|
||||
using scalar_32bit_t = int32_t;
|
||||
|
||||
using FragA = Vec<int32_t, 4>;
|
||||
using FragB = Vec<int32_t, 2>;
|
||||
using FragC = Vec<float, 4>;
|
||||
using FragZP = Vec<int16_t, 4>;
|
||||
};
|
||||
|
||||
template <typename scalar_t>
|
||||
class MarlinScalarType2 {};
|
||||
|
||||
template <>
|
||||
class MarlinScalarType2<half> : public MarlinScalarType<vllm::kFloat16.id()> {};
|
||||
|
||||
template <>
|
||||
class MarlinScalarType2<nv_bfloat16>
|
||||
: public MarlinScalarType<vllm::kBFloat16.id()> {};
|
||||
|
||||
template <>
|
||||
class MarlinScalarType2<__nv_fp8_e4m3>
|
||||
: public MarlinScalarType<vllm::kFE4M3fn.id()> {};
|
||||
|
||||
template <>
|
||||
class MarlinScalarType2<int8_t> : public MarlinScalarType<vllm::kS8.id()> {};
|
||||
|
||||
} // namespace MARLIN_NAMESPACE_NAME
|
||||
|
||||
#endif
|
||||
|
||||
106
csrc/quantization/gptq_marlin/marlin_int4_fp8_preprocess.cu
Normal file
@ -0,0 +1,106 @@
#include "marlin.cuh"

#include "core/registration.h"

// for only non-zp format (like gptq)
__global__ void marlin_int4_fp8_preprocess_kernel_without_zp(
    // qweight: (size_k * size_n // 8,)
    const int32_t* __restrict__ qweight,
    // output: same shape with qweight
    int32_t* __restrict__ output) {
  int32_t val = qweight[blockIdx.x * 32 + threadIdx.x];
  int32_t new_val = 0;

#pragma unroll
  for (int32_t i = 0; i < 8; i++) {
    int32_t single_val = val & 0xF;
    single_val = single_val >= 8 ? single_val - 8 : 15 - single_val;
    new_val |= single_val << (i * 4);
    val >>= 4;
  }

  output[blockIdx.x * 32 + threadIdx.x] = new_val;
}

// for awq format only (with zp and with awq weight layout)
__global__ void marlin_int4_fp8_preprocess_kernel_awq(
    // AWQ qweight: (size_k, size_n // 8)
    const int32_t* __restrict__ qweight,
    // output: same shape with qweight
    int32_t* __restrict__ output,
    // AWQ zeros: (size_k // group_size, size_n // 8)
    const int32_t* __restrict__ qzeros, int32_t size_n, int32_t size_k,
    int32_t group_size) {
  int32_t val =
      qweight[(blockIdx.x * 32 + threadIdx.x) * size_n / 8 + blockIdx.y];
  int32_t zero =
      qzeros[(blockIdx.x * 32 + threadIdx.x) / group_size * size_n / 8 +
             blockIdx.y];
  int32_t new_val = 0;

#pragma unroll
  for (int32_t i = 0; i < 8; i++) {
    int32_t single_val = val & 0xF;
    int32_t single_zero = zero & 0xF;

    single_val =
        single_val >= single_zero ? single_val - single_zero : 15 - single_val;
    new_val |= single_val << (i * 4);
    val >>= 4;
    zero >>= 4;
  }

  output[(blockIdx.x * 32 + threadIdx.x) * size_n / 8 + blockIdx.y] = new_val;
}

torch::Tensor marlin_int4_fp8_preprocess(
    torch::Tensor& qweight, std::optional<torch::Tensor> qzeros_or_none,
    bool inplace) {
  TORCH_CHECK(qweight.device().is_cuda(), "qweight is not on GPU");
  TORCH_CHECK(qweight.scalar_type() == at::ScalarType::Int,
              "qweight.dtype != torch.int32");

  const at::cuda::OptionalCUDAGuard device_guard(device_of(qweight));

  torch::Tensor output = inplace ? qweight : torch::empty_like(qweight);

  if (!qzeros_or_none.has_value()) {
    TORCH_CHECK(qweight.numel() * 8 % 256 == 0,
                "qweight.numel() * 8 % 256 != 0");

    int blocks = qweight.numel() * 8 / 256;
    marlin_int4_fp8_preprocess_kernel_without_zp<<<blocks, 32>>>(
        (const int32_t*)qweight.data_ptr(), (int32_t*)output.data_ptr());
  } else {
    int32_t size_k = qweight.size(0);
    int32_t size_n = qweight.size(1) * 8;
    torch::Tensor qzeros = qzeros_or_none.value();

    TORCH_CHECK(size_k % 32 == 0, "size_k % 32 != 0");
    TORCH_CHECK(qzeros.device().is_cuda(), "qzeros is not on GPU");
    TORCH_CHECK(qzeros.scalar_type() == at::ScalarType::Int,
                "qweight.dtype != torch.int32");
    TORCH_CHECK(device_of(qweight) == device_of(qzeros),
                "qzeros is not on the same device with qweight");

    int32_t group_size = qweight.size(0) / qzeros.size(0);
    TORCH_CHECK(qweight.size(1) == qzeros.size(1),
                "qweight.size(1) != qzeros.size(1)");
    TORCH_CHECK(qweight.size(0) % qzeros.size(0) == 0,
                "qweight.size(0) % qzeros.size(0) != 0");
    TORCH_CHECK(group_size % 8 == 0, "group_size % 8 != 0");

    dim3 blocks(size_k / 32, size_n / 8);
    marlin_int4_fp8_preprocess_kernel_awq<<<blocks, 32>>>(
        (const int32_t*)qweight.data_ptr(), (int32_t*)output.data_ptr(),
        (const int32_t*)qzeros.data_ptr(), size_n, size_k, group_size);
  }

  return output;
}

TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
  m.impl("marlin_int4_fp8_preprocess", &marlin_int4_fp8_preprocess);
}
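The two kernels above only remap packed 4-bit values; the pure-torch reference below (a sketch of what the new tests in this commit verify, not vLLM code) spells out the per-nibble arithmetic for both variants:

```python
import torch

def int4_fp8_preprocess_ref(qweight_unpacked: torch.Tensor,
                            qzeros_unpacked: torch.Tensor | None = None,
                            group_size: int = 128) -> torch.Tensor:
    """Reference for the remap applied to each unpacked 4-bit value in [0, 16)."""
    q = qweight_unpacked
    if qzeros_unpacked is None:
        # GPTQ-style (no zero point): v >= 8 -> v - 8, otherwise 15 - v.
        return torch.where(q >= 8, q - 8, 15 - q)
    # AWQ-style: subtract the per-group zero point, mapping negatives to 15 - v.
    z = qzeros_unpacked.repeat_interleave(group_size, dim=0)
    return torch.where(q >= z, q - z, 15 - q)

q = torch.randint(0, 16, (256, 256), dtype=torch.int32)
print(int4_fp8_preprocess_ref(q)[:2, :4])
```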
File diff suppressed because it is too large
@ -298,9 +298,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  // gptq_marlin Optimized Quantized GEMM for GPTQ.
  ops.def(
      "gptq_marlin_gemm(Tensor a, Tensor? c_or_none, Tensor b_q_weight, "
      "Tensor? b_bias_or_none,"
      "Tensor b_scales, Tensor? global_scale, Tensor? b_zeros_or_none, Tensor? "
      "g_idx_or_none, Tensor? perm_or_none, Tensor workspace, int b_q_type, "
      "Tensor? b_bias_or_none,Tensor b_scales, "
      "Tensor? a_scales, Tensor? global_scale, Tensor? b_zeros_or_none, "
      "Tensor? "
      "g_idx_or_none, Tensor? perm_or_none, Tensor workspace, int b_type_id, "
      "SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, "
      "bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) -> Tensor");
  // conditionally compiled so impl registration is in source file
@ -308,13 +309,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  // gptq_marlin repack from GPTQ.
  ops.def(
      "gptq_marlin_repack(Tensor b_q_weight, Tensor perm, "
      "SymInt size_k, SymInt size_n, int num_bits) -> Tensor");
      "SymInt size_k, SymInt size_n, int num_bits, bool is_a_8bit) -> Tensor");
  // conditionally compiled so impl registrations are in source file

  // awq_marlin repack from AWQ.
  ops.def(
      "awq_marlin_repack(Tensor b_q_weight, SymInt size_k, "
      "SymInt size_n, int num_bits) -> Tensor");
      "SymInt size_n, int num_bits, bool is_a_8bit) -> Tensor");
  // conditionally compiled so impl registrations are in source file

  // preprocess W-int4A-fp8 weight for marlin kernel
  ops.def(
      "marlin_int4_fp8_preprocess(Tensor qweight, "
      "Tensor? qzeros_or_none, bool inplace) -> Tensor");
  // conditionally compiled so impl registrations are in source file
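The `marlin_int4_fp8_preprocess` schema above maps to a thin Python wrapper that the new kernel tests call directly; a hedged usage sketch with placeholder shapes (eight packed 4-bit values per int32 word):

```python
import torch

from vllm import _custom_ops as ops

# Packed 4-bit GPTQ weight: (size_k, size_n // 8) int32 words, size_k = 2048.
qweight = torch.randint(0, 2**31 - 1, (2048, 256), dtype=torch.int32, device="cuda")

# GPTQ layout (no zero points).
remapped = ops.marlin_int4_fp8_preprocess(qweight)

# AWQ layout: per-group zero points packed the same way (group_size = 128 here).
qzeros = torch.randint(0, 2**31 - 1, (2048 // 128, 256), dtype=torch.int32, device="cuda")
remapped_awq = ops.marlin_int4_fp8_preprocess(qweight, qzeros)
```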
|
||||
|
||||
// CUTLASS w4a8 GEMM
|
||||
|
||||
@ -60,7 +60,7 @@ Modular kernels are supported by the following `FusedMoEMethodBase` classes.
|
||||
|
||||
- [`ModelOptFp8MoEMethod`][vllm.model_executor.layers.quantization.modelopt.ModelOptFp8MoEMethod]
|
||||
- [`Fp8MoEMethod`][vllm.model_executor.layers.quantization.fp8.Fp8MoEMethod]
|
||||
- [`CompressedTensorsW4A4Nvfp4MoeMethod`][vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.CompressedTensorsW4A4Nvfp4MoeMethod]
|
||||
- [`CompressedTensorsW4A4Nvfp4MoEMethod`][vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.CompressedTensorsW4A4Nvfp4MoEMethod]
|
||||
- [`CompressedTensorsW8A8Fp8MoEMethod`][vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.CompressedTensorsW8A8Fp8MoEMethod]
|
||||
- [`Mxfp4MoEMethod`][vllm.model_executor.layers.quantization.mxfp4.Mxfp4MoEMethod]
|
||||
- [`UnquantizedFusedMoEMethod`][vllm.model_executor.layers.fused_moe.layer.UnquantizedFusedMoEMethod]
|
||||
|
||||
@ -21,7 +21,7 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
|
||||
|
||||
import vllm.model_executor.layers.fused_moe # noqa
|
||||
from tests.kernels.moe.utils import fused_moe
|
||||
from tests.kernels.utils import opcheck, stack_and_dev, torch_moe
|
||||
from tests.kernels.utils import opcheck, stack_and_dev, torch_experts, torch_moe
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.distributed.parallel_state import init_distributed_environment
|
||||
@ -65,6 +65,64 @@ NUM_EXPERTS = [8, 64, 192]
|
||||
EP_SIZE = [1, 4]
|
||||
TOP_KS = [2, 6]
|
||||
|
||||
MOE_MARLIN_QUANT_TEST_CONFIGS = [
|
||||
# AWQ-INT4
|
||||
{"b_type": scalar_types.uint4, "group_blocks": [-1, 2, 4, 8]},
|
||||
# GPTQ-INT4
|
||||
{
|
||||
"b_type": scalar_types.uint4b8,
|
||||
"support_act_order": True,
|
||||
"group_blocks": [-1, 2, 4, 8],
|
||||
},
|
||||
# GPTQ-INT8
|
||||
{
|
||||
"b_type": scalar_types.uint8b128,
|
||||
"support_act_order": True,
|
||||
"group_blocks": [-1, 2, 4, 8],
|
||||
},
|
||||
# FP8
|
||||
{"b_type": scalar_types.float8_e4m3fn, "group_blocks": [-1, 8]},
|
||||
# NVFP4
|
||||
{"b_type": scalar_types.float4_e2m1f, "group_blocks": [1]},
|
||||
# MXFP4
|
||||
{
|
||||
"a_type": [scalar_types.bfloat16],
|
||||
"b_type": scalar_types.float4_e2m1f,
|
||||
"group_blocks": [2],
|
||||
},
|
||||
# AWQ-INT4 with INT8 activation
|
||||
{
|
||||
"a_type": [scalar_types.int8],
|
||||
"b_type": scalar_types.uint4,
|
||||
"group_blocks": [-1, 2, 4, 8],
|
||||
},
|
||||
# GPTQ-INT4 with INT8 activation
|
||||
{
|
||||
"a_type": [scalar_types.int8],
|
||||
"b_type": scalar_types.uint4b8,
|
||||
"group_blocks": [-1, 2, 4, 8],
|
||||
},
|
||||
# GPTQ-INT4 with FP8 activation
|
||||
{
|
||||
"a_type": [scalar_types.float8_e4m3fn],
|
||||
"b_type": scalar_types.uint4b8,
|
||||
"group_blocks": [-1, 2, 4, 8],
|
||||
},
|
||||
# AWQ-INT4 with FP8 activation
|
||||
{
|
||||
"a_type": [scalar_types.float8_e4m3fn],
|
||||
"b_type": scalar_types.uint4,
|
||||
"group_blocks": [-1, 2, 4, 8],
|
||||
},
|
||||
# MXFP4 with FP8 activation
|
||||
{
|
||||
"a_type": [scalar_types.float8_e4m3fn],
|
||||
"b_type": scalar_types.float4_e2m1f,
|
||||
"c_type": [scalar_types.bfloat16],
|
||||
"group_blocks": [2],
|
||||
},
|
||||
]
|
||||
|
||||
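In these configs `group_blocks` counts 16-element blocks, with non-positive values passed through unchanged; this matches the expression `group_size = group_blocks if group_blocks <= 0 else group_blocks * 16` used by the test helpers below. A tiny equivalent for reference:

```python
def group_blocks_to_group_size(group_blocks: int) -> int:
    # Same expression as in the updated tests.
    return group_blocks if group_blocks <= 0 else group_blocks * 16

assert group_blocks_to_group_size(-1) == -1   # single group over the whole k dim
assert group_blocks_to_group_size(1) == 16    # NVFP4
assert group_blocks_to_group_size(2) == 32    # MXFP4
assert group_blocks_to_group_size(8) == 128   # classic GPTQ/AWQ group size
```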
FUSED_MOE_MNK_FACTORS = [
|
||||
(1, 128, 128),
|
||||
(1, 2048, 128),
|
||||
@ -505,63 +563,74 @@ def marlin_moe_generate_valid_test_cases():
|
||||
m_list = [1, 123, 666]
|
||||
n_list = [128, 1024]
|
||||
k_list = [256, 2048]
|
||||
e_list = [4, 12]
|
||||
e_list = [5, 12]
|
||||
topk_list = [2, 3]
|
||||
ep_size_list = [1, 4]
|
||||
dtype_list = [torch.bfloat16]
|
||||
group_size_list = [-1, 32, 128]
|
||||
act_order_list = [True, False]
|
||||
quant_type_list = [
|
||||
scalar_types.float4_e2m1f,
|
||||
scalar_types.float8_e4m3fn,
|
||||
scalar_types.uint4,
|
||||
scalar_types.uint4b8,
|
||||
scalar_types.uint8b128,
|
||||
]
|
||||
is_k_full_list = [True, False]
|
||||
|
||||
all_combinations = itertools.product(
|
||||
MOE_MARLIN_QUANT_TEST_CONFIGS,
|
||||
m_list,
|
||||
n_list,
|
||||
k_list,
|
||||
e_list,
|
||||
topk_list,
|
||||
ep_size_list,
|
||||
dtype_list,
|
||||
group_size_list,
|
||||
act_order_list,
|
||||
quant_type_list,
|
||||
is_k_full_list,
|
||||
)
|
||||
|
||||
def is_invalid(
|
||||
m, n, k, e, topk, ep_size, dtype, group_size, act_order, quant_type, is_k_full
|
||||
a_type,
|
||||
b_type,
|
||||
c_type,
|
||||
group_blocks,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
e,
|
||||
topk,
|
||||
ep_size,
|
||||
act_order,
|
||||
is_k_full,
|
||||
):
|
||||
if quant_type == scalar_types.float8_e4m3fn and group_size not in [-1, 128]:
|
||||
return False
|
||||
if quant_type == scalar_types.float4_e2m1f:
|
||||
if group_size not in [16, 32]:
|
||||
return False
|
||||
if dtype == torch.float16 and group_size == 32:
|
||||
return False
|
||||
if quant_type != scalar_types.float4_e2m1f and group_size == 16:
|
||||
group_size = group_blocks if group_blocks <= 0 else group_blocks * 16
|
||||
if group_size > 0 and k % group_size != 0:
|
||||
return False
|
||||
|
||||
# Filter act_order
|
||||
if act_order:
|
||||
if group_size in (-1, k, n):
|
||||
return False
|
||||
if quant_type not in [scalar_types.uint4b8]:
|
||||
return False
|
||||
elif not is_k_full:
|
||||
if act_order and group_size in [-1, k, n]:
|
||||
return False
|
||||
if group_size in [k, n]:
|
||||
return False
|
||||
if not act_order and is_k_full:
|
||||
return False
|
||||
|
||||
return True
|
||||
return a_type.size_bits < 16 or a_type is c_type
|
||||
|
||||
cases = []
|
||||
for case in all_combinations:
|
||||
if is_invalid(*case):
|
||||
cases.append(case)
|
||||
quant_test_config, m, n, k, _, _, _, act_order, *_ = case
|
||||
if act_order and not quant_test_config.get("support_act_order", False):
|
||||
continue
|
||||
|
||||
f16_types = [scalar_types.float16]
|
||||
inner_combinations = itertools.product(
|
||||
quant_test_config.get("a_type", f16_types),
|
||||
[quant_test_config["b_type"]],
|
||||
quant_test_config.get("c_type", f16_types),
|
||||
quant_test_config["group_blocks"],
|
||||
)
|
||||
|
||||
for sub_case in inner_combinations:
|
||||
if (
|
||||
sub_case[0] == scalar_types.float8_e4m3fn
|
||||
and current_platform.get_device_capability() not in [89, 120]
|
||||
):
|
||||
continue
|
||||
args = sub_case + (m, n, k) + case[4:]
|
||||
if is_invalid(*args):
|
||||
cases.append(args)
|
||||
return cases
|
||||
|
||||
|
||||
@ -571,6 +640,7 @@ class MarlinMoEWeightData:
|
||||
qweight: torch.Tensor
|
||||
scales: torch.Tensor
|
||||
global_scale: torch.Tensor | None
|
||||
a_scales_factor: torch.Tensor | None
|
||||
g_idx: torch.Tensor | None
|
||||
zeros: torch.Tensor | None
|
||||
sort_indices: torch.Tensor | None
|
||||
@ -583,11 +653,20 @@ class MarlinMoEWeightData:
|
||||
group_size: int,
|
||||
act_order: bool | None = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
input_type: ScalarType = None,
|
||||
) -> "MarlinMoEWeightData":
|
||||
assert w.ndim == 3
|
||||
|
||||
has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
|
||||
k = w.shape[-1]
|
||||
|
||||
if input_type == scalar_types.int8:
|
||||
input_dtype = torch.int8
|
||||
elif input_type == scalar_types.float8_e4m3fn:
|
||||
input_dtype = torch.float8_e4m3fn
|
||||
else:
|
||||
input_dtype = w.dtype
|
||||
|
||||
w_ref_l: list[torch.Tensor] = []
|
||||
qweight_l: list[torch.Tensor] = []
|
||||
scales_l: list[torch.Tensor] = []
|
||||
@ -601,11 +680,13 @@ class MarlinMoEWeightData:
|
||||
if quant_type == scalar_types.float4_e2m1f:
|
||||
if group_size == 16:
|
||||
w_ref, qweight, scales, global_scale = (
|
||||
rand_marlin_weight_nvfp4_like(w[i], group_size)
|
||||
rand_marlin_weight_nvfp4_like(
|
||||
w[i], group_size, input_dtype=input_dtype
|
||||
)
|
||||
)
|
||||
else:
|
||||
w_ref, qweight, scales = rand_marlin_weight_mxfp4_like(
|
||||
w[i], group_size
|
||||
w[i], group_size, input_dtype=input_dtype
|
||||
)
|
||||
global_scale = None
|
||||
|
||||
@ -615,13 +696,18 @@ class MarlinMoEWeightData:
|
||||
if global_scale is not None:
|
||||
global_scale_l.append(global_scale)
|
||||
elif quant_type == scalar_types.float8_e4m3fn:
|
||||
w_ref, qweight, scales = marlin_quant_fp8_torch(w[i], group_size)
|
||||
w_ref, qweight, scales = marlin_quant_fp8_torch(
|
||||
w[i], group_size, input_dtype=input_dtype
|
||||
)
|
||||
w_ref_l.append(w_ref.T)
|
||||
qweight_l.append(qweight)
|
||||
scales_l.append(scales)
|
||||
elif has_zp:
|
||||
w_ref, qweight, scales, zeros = awq_marlin_quantize(
|
||||
w[i].transpose(1, 0), quant_type, group_size
|
||||
w[i].transpose(1, 0),
|
||||
quant_type,
|
||||
group_size,
|
||||
input_dtype=input_dtype,
|
||||
)
|
||||
|
||||
w_ref_l.append(w_ref.T)
|
||||
@ -631,7 +717,12 @@ class MarlinMoEWeightData:
|
||||
else:
|
||||
test_perm = torch.randperm(k)
|
||||
w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize(
|
||||
w[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
|
||||
w[i].transpose(1, 0),
|
||||
quant_type,
|
||||
group_size,
|
||||
act_order,
|
||||
test_perm,
|
||||
input_dtype=input_dtype,
|
||||
)
|
||||
|
||||
w_ref_l.append(w_ref.T)
|
||||
@ -652,11 +743,18 @@ class MarlinMoEWeightData:
|
||||
sort_indices = stack_and_dev(sort_indices_l) if sort_indices_l else None
|
||||
marlin_bias = stack_and_dev(bias_l) if bias_l else None
|
||||
|
||||
a_scales_factor = None
|
||||
if input_type == scalar_types.int8 and group_size != -1:
|
||||
a_scales_factor = 1 / 4096 * scales.max().float()
|
||||
scales = scales / scales.max() * 4096
|
||||
scales = scales.round().to(torch.int16).view(w.dtype)
|
||||
|
||||
return MarlinMoEWeightData(
|
||||
w_ref=w_ref,
|
||||
qweight=qweight,
|
||||
scales=scales,
|
||||
global_scale=global_scale,
|
||||
a_scales_factor=a_scales_factor,
|
||||
g_idx=g_idx,
|
||||
zeros=zeros,
|
||||
sort_indices=sort_indices,
|
||||
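For INT8 activations with grouped scales, the folding a few lines above renormalizes the weight scales to integer values of at most 4096 (stored in the int16 bit pattern of the weight dtype) and moves the removed magnitude into `a_scales_factor`, which is later passed to `fused_marlin_moe` as `input_global_scale1`/`input_global_scale2`. A small round-trip check of that arithmetic (a sketch, not vLLM code):

```python
import torch

scales = torch.rand(64, dtype=torch.float32) * 0.05 + 1e-3  # stand-in weight scales
a_scales_factor = 1 / 4096 * scales.max()
folded = (scales / scales.max() * 4096).round()             # integer-valued, <= 4096
restored = folded * a_scales_factor
# The rounding error is bounded by 0.5 * a_scales_factor.
assert (restored - scales).abs().max() <= 0.5 * a_scales_factor + 1e-6
```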
@ -666,28 +764,47 @@ class MarlinMoEWeightData:
|
||||
|
||||
@pytest.mark.flaky(reruns=2)
|
||||
@pytest.mark.parametrize(
|
||||
("m, n, k, e, topk, ep_size, dtype, group_size,act_order, quant_type, is_k_full"),
|
||||
(
|
||||
"a_type, b_type, c_type, group_blocks,"
|
||||
"m, n, k, e, topk, ep_size, act_order, is_k_full"
|
||||
),
|
||||
marlin_moe_generate_valid_test_cases(),
|
||||
)
|
||||
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
|
||||
def test_fused_marlin_moe(
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
e: int,
|
||||
topk: int,
|
||||
ep_size: int,
|
||||
dtype: torch.dtype,
|
||||
group_size: int,
|
||||
act_order: bool,
|
||||
quant_type: ScalarType,
|
||||
is_k_full: bool,
|
||||
a_type,
|
||||
b_type,
|
||||
c_type,
|
||||
group_blocks,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
e,
|
||||
topk,
|
||||
ep_size,
|
||||
act_order,
|
||||
is_k_full,
|
||||
):
|
||||
torch.cuda.manual_seed(0)
|
||||
torch.cuda.manual_seed(1)
|
||||
group_size = group_blocks if group_blocks <= 0 else group_blocks * 16
|
||||
|
||||
if c_type == scalar_types.float16:
|
||||
dtype = torch.float16
|
||||
elif c_type == scalar_types.bfloat16:
|
||||
dtype = torch.bfloat16
|
||||
else:
|
||||
raise RuntimeError("unsupported c_type")
|
||||
|
||||
if a_type == scalar_types.int8:
|
||||
a_dtype = torch.int8
|
||||
elif a_type == scalar_types.float8_e4m3fn:
|
||||
a_dtype = torch.float8_e4m3fn
|
||||
else:
|
||||
a_dtype = dtype
|
||||
|
||||
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
||||
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 20
|
||||
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 20
|
||||
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
|
||||
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
|
||||
|
||||
if ep_size > 1:
|
||||
local_e = e // ep_size
|
||||
@ -700,11 +817,19 @@ def test_fused_marlin_moe(
|
||||
e_map = None
|
||||
|
||||
w1_data = MarlinMoEWeightData.make(
|
||||
w=w1, quant_type=quant_type, group_size=group_size, act_order=act_order
|
||||
w=w1,
|
||||
quant_type=b_type,
|
||||
group_size=group_size,
|
||||
act_order=act_order,
|
||||
input_type=a_type,
|
||||
)
|
||||
|
||||
w2_data = MarlinMoEWeightData.make(
|
||||
w=w2, quant_type=quant_type, group_size=group_size, act_order=act_order
|
||||
w=w2,
|
||||
quant_type=b_type,
|
||||
group_size=group_size,
|
||||
act_order=act_order,
|
||||
input_type=a_type,
|
||||
)
|
||||
|
||||
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
||||
@ -712,8 +837,18 @@ def test_fused_marlin_moe(
|
||||
topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
torch_output = torch_moe(
|
||||
a, w1_data.w_ref, w2_data.w_ref, score, topk, expert_map=e_map
|
||||
score = torch.softmax(score, dim=-1, dtype=torch.float32)
|
||||
topk_weight, topk_ids = torch.topk(score, topk)
|
||||
torch_output = torch_experts(
|
||||
a,
|
||||
w1_data.w_ref,
|
||||
w2_data.w_ref,
|
||||
topk_weight=topk_weight,
|
||||
topk_ids=topk_ids,
|
||||
global_num_experts=e,
|
||||
expert_map=e_map,
|
||||
quant_dtype=a_dtype,
|
||||
per_act_token_quant=True,
|
||||
)
|
||||
|
||||
marlin_output = fused_marlin_moe(
|
||||
@ -733,15 +868,18 @@ def test_fused_marlin_moe(
|
||||
global_scale2=w2_data.global_scale,
|
||||
g_idx1=w1_data.g_idx,
|
||||
g_idx2=w2_data.g_idx,
|
||||
input_global_scale1=w1_data.a_scales_factor,
|
||||
input_global_scale2=w2_data.a_scales_factor,
|
||||
sort_indices1=w1_data.sort_indices,
|
||||
sort_indices2=w2_data.sort_indices,
|
||||
w1_zeros=w1_data.zeros,
|
||||
w2_zeros=w2_data.zeros,
|
||||
quant_type_id=quant_type.id,
|
||||
input_dtype=a_dtype,
|
||||
quant_type_id=b_type.id,
|
||||
is_k_full=is_k_full,
|
||||
)
|
||||
|
||||
torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0)
|
||||
torch.testing.assert_close(marlin_output, torch_output, atol=4e-2, rtol=0)
|
||||
|
||||
|
||||
@pytest.mark.flaky(reruns=2)
|
||||
|
||||
@ -5,6 +5,8 @@
|
||||
Run `pytest tests/kernels/quantization/test_marlin_gemm.py`.
|
||||
"""
|
||||
|
||||
import itertools
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
@ -17,8 +19,10 @@ from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
|
||||
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES,
|
||||
GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.int8_utils import (
|
||||
per_token_quant_int8,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
MARLIN_SUPPORTED_GROUP_SIZES,
|
||||
marlin_make_empty_g_idx,
|
||||
marlin_make_workspace_new,
|
||||
marlin_permute_bias,
|
||||
@ -26,7 +30,6 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
query_marlin_supported_quant_types,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
|
||||
FP4_MARLIN_SUPPORTED_GROUP_SIZES,
|
||||
rand_marlin_weight_mxfp4_like,
|
||||
rand_marlin_weight_nvfp4_like,
|
||||
)
|
||||
@ -50,6 +53,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
quantize_weights,
|
||||
sort_weights,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
ACT_ORDER_OPTS = [False, True]
|
||||
@ -65,6 +69,12 @@ MARLIN_24_N_CHUNKS = [512]
|
||||
|
||||
HQQ_SUPPORTED_GROUP_SIZES = [64]
|
||||
|
||||
MARLIN_REPACK_NK_FACTORS = [
|
||||
(4, 8),
|
||||
(7, 5),
|
||||
(13, 11),
|
||||
]
|
||||
|
||||
MNK_FACTORS = [
|
||||
(1, 1, 1),
|
||||
(1, 4, 8),
|
||||
@ -74,6 +84,64 @@ MNK_FACTORS = [
|
||||
|
||||
DTYPES = [torch.float16, torch.bfloat16]
|
||||
|
||||
DENSE_MARLIN_QUANT_TEST_CONFIGS = [
|
||||
# AWQ-INT4
|
||||
{"b_type": scalar_types.uint4, "group_blocks": [-1, 2, 4, 8]},
|
||||
# GPTQ-INT4
|
||||
{
|
||||
"b_type": scalar_types.uint4b8,
|
||||
"support_act_order": True,
|
||||
"group_blocks": [-1, 2, 4, 8],
|
||||
},
|
||||
# GPTQ-INT8
|
||||
{
|
||||
"b_type": scalar_types.uint8b128,
|
||||
"support_act_order": True,
|
||||
"group_blocks": [-1, 2, 4, 8],
|
||||
},
|
||||
# FP8
|
||||
{"b_type": scalar_types.float8_e4m3fn, "group_blocks": [-1, 8]},
|
||||
# NVFP4
|
||||
{"b_type": scalar_types.float4_e2m1f, "group_blocks": [1]},
|
||||
# MXFP4
|
||||
{
|
||||
"a_type": [scalar_types.bfloat16],
|
||||
"b_type": scalar_types.float4_e2m1f,
|
||||
"group_blocks": [2],
|
||||
},
|
||||
# AWQ-INT4 with INT8 activation
|
||||
{
|
||||
"a_type": [scalar_types.int8],
|
||||
"b_type": scalar_types.uint4,
|
||||
"group_blocks": [-1, 2, 4, 8],
|
||||
},
|
||||
# GPTQ-INT4 with INT8 activation
|
||||
{
|
||||
"a_type": [scalar_types.int8],
|
||||
"b_type": scalar_types.uint4b8,
|
||||
"group_blocks": [-1, 2, 4, 8],
|
||||
},
|
||||
# GPTQ-INT4 with FP8 activation
|
||||
{
|
||||
"a_type": [scalar_types.float8_e4m3fn],
|
||||
"b_type": scalar_types.uint4b8,
|
||||
"group_blocks": [-1, 2, 4, 8],
|
||||
},
|
||||
# AWQ-INT4 with FP8 activation
|
||||
{
|
||||
"a_type": [scalar_types.float8_e4m3fn],
|
||||
"b_type": scalar_types.uint4,
|
||||
"group_blocks": [-1, 2, 4, 8],
|
||||
},
|
||||
# MXFP4 with FP8 activation
|
||||
{
|
||||
"a_type": [scalar_types.float8_e4m3fn],
|
||||
"b_type": scalar_types.float4_e2m1f,
|
||||
"c_type": [scalar_types.bfloat16],
|
||||
"group_blocks": [2],
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def compute_max_diff(output, output_ref):
|
||||
return torch.mean(torch.abs(output - output_ref)) / torch.mean(
|
||||
@ -85,6 +153,58 @@ def rand_data(shape, dtype=torch.float16):
|
||||
return torch.randn(shape, dtype=dtype, device="cuda")
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not is_quant_method_supported("gptq_marlin"),
|
||||
reason="Marlin is not supported on this GPU type.",
|
||||
)
|
||||
def test_marlin_int4_fp8_preprocess_without_zp():
|
||||
qweight_unpacked = torch.randint(
|
||||
0, 16, size=(2048, 2048), dtype=torch.int32, device="cuda"
|
||||
)
|
||||
qweight_packed = qweight_unpacked[:, ::2] * 16 + qweight_unpacked[:, 1::2]
|
||||
qweight_packed = qweight_packed.to(torch.int8).view(torch.int32)
|
||||
|
||||
cuda_res = ops.marlin_int4_fp8_preprocess(qweight_packed)
|
||||
|
||||
torch_res = torch.where(
|
||||
qweight_unpacked >= 8, qweight_unpacked - 8, 15 - qweight_unpacked
|
||||
)
|
||||
torch_res = torch_res[:, ::2] * 16 + torch_res[:, 1::2]
|
||||
torch_res = torch_res.to(torch.int8).view(torch.int32)
|
||||
|
||||
assert (cuda_res == torch_res).all()
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not is_quant_method_supported("gptq_marlin"),
|
||||
reason="Marlin is not supported on this GPU type.",
|
||||
)
|
||||
def test_marlin_int4_fp8_preprocess_awq():
|
||||
group_size = 128
|
||||
|
||||
qweight_unpacked = torch.randint(
|
||||
0, 16, size=(2048, 2048), dtype=torch.int32, device="cuda"
|
||||
)
|
||||
qzeros_unpacked = torch.randint(
|
||||
0, 16, size=(2048 // group_size, 2048), dtype=torch.int32, device="cuda"
|
||||
)
|
||||
|
||||
qweight_packed = qweight_unpacked[:, ::2] * 16 + qweight_unpacked[:, 1::2]
|
||||
qweight_packed = qweight_packed.to(torch.int8).view(torch.int32)
|
||||
qzeros_packed = qzeros_unpacked[:, ::2] * 16 + qzeros_unpacked[:, 1::2]
|
||||
qzeros_packed = qzeros_packed.to(torch.int8).view(torch.int32)
|
||||
|
||||
cuda_res = ops.marlin_int4_fp8_preprocess(qweight_packed, qzeros_packed)
|
||||
|
||||
repeated_zp = qzeros_unpacked.repeat_interleave(group_size, 0)
|
||||
torch_res = qweight_unpacked - repeated_zp
|
||||
torch_res[torch_res < 0] = 15 - qweight_unpacked[torch_res < 0]
|
||||
torch_res = torch_res[:, ::2] * 16 + torch_res[:, 1::2]
|
||||
torch_res = torch_res.to(torch.int8).view(torch.int32)
|
||||
|
||||
assert (cuda_res == torch_res).all()
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not is_quant_method_supported("gptq_marlin"),
|
||||
reason="Marlin is not supported on this GPU type.",
|
||||
@ -92,16 +212,17 @@ def rand_data(shape, dtype=torch.float16):
|
||||
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
|
||||
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
|
||||
@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types(False, False))
|
||||
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
|
||||
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
|
||||
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
|
||||
@pytest.mark.parametrize("is_a_8bit", [True, False])
|
||||
@pytest.mark.parametrize("nk_factors", MARLIN_REPACK_NK_FACTORS)
|
||||
def test_gptq_marlin_repack(
|
||||
k_chunk, n_chunk, quant_type, group_size, act_order, mnk_factors
|
||||
k_chunk, n_chunk, quant_type, act_order, is_a_8bit, nk_factors
|
||||
):
|
||||
m_factor, n_factor, k_factor = mnk_factors
|
||||
n_factor, k_factor = nk_factors
|
||||
|
||||
size_k = k_chunk * k_factor
|
||||
size_n = n_chunk * n_factor
|
||||
group_size = 128
|
||||
|
||||
# Filter act_order
|
||||
if act_order:
|
||||
@ -109,6 +230,8 @@ def test_gptq_marlin_repack(
|
||||
return
|
||||
if group_size == size_k:
|
||||
return
|
||||
if is_a_8bit:
|
||||
return
|
||||
|
||||
# Normalize group_size
|
||||
if group_size == -1:
|
||||
@ -133,23 +256,19 @@ def test_gptq_marlin_repack(
|
||||
q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)
|
||||
|
||||
# Pack to Marlin format
|
||||
weight_perm = get_weight_perm(quant_type.size_bits)
|
||||
weight_perm = get_weight_perm(quant_type.size_bits, is_a_8bit)
|
||||
marlin_q_w_1 = marlin_weights(
|
||||
q_w, size_k, size_n, quant_type.size_bits, weight_perm
|
||||
q_w, size_k, size_n, quant_type.size_bits, weight_perm, is_a_8bit
|
||||
)
|
||||
|
||||
opcheck(
|
||||
torch.ops._C.gptq_marlin_repack,
|
||||
(q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits),
|
||||
(q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits, is_a_8bit),
|
||||
)
|
||||
|
||||
# Run Marlin repack GPU kernel
|
||||
marlin_q_w_2 = ops.gptq_marlin_repack(
|
||||
q_w_gptq,
|
||||
sort_indices,
|
||||
size_k,
|
||||
size_n,
|
||||
quant_type.size_bits,
|
||||
q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits, is_a_8bit
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
@ -163,18 +282,15 @@ def test_gptq_marlin_repack(
|
||||
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
|
||||
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
|
||||
@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types(True))
|
||||
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
|
||||
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
|
||||
def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, mnk_factors):
|
||||
m_factor, n_factor, k_factor = mnk_factors
|
||||
@pytest.mark.parametrize("is_a_8bit", [True, False])
|
||||
@pytest.mark.parametrize("nk_factors", MARLIN_REPACK_NK_FACTORS)
|
||||
def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, is_a_8bit, nk_factors):
|
||||
n_factor, k_factor = nk_factors
|
||||
|
||||
size_k = k_chunk * k_factor
|
||||
size_n = n_chunk * n_factor
|
||||
|
||||
# Normalize group_size
|
||||
if group_size == -1:
|
||||
group_size = size_k
|
||||
assert group_size <= size_k
|
||||
group_size = 128
|
||||
|
||||
# Create input
|
||||
b_weight = rand_data((size_k, size_n))
|
||||
@ -188,162 +304,221 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, mnk_factors
|
||||
q_w_awq = awq_pack(q_w, quant_type.size_bits, size_k, size_n)
|
||||
|
||||
# Pack to Marlin format
|
||||
weight_perm = get_weight_perm(quant_type.size_bits)
|
||||
weight_perm = get_weight_perm(quant_type.size_bits, is_a_8bit)
|
||||
marlin_q_w_1 = marlin_weights(
|
||||
q_w, size_k, size_n, quant_type.size_bits, weight_perm
|
||||
q_w, size_k, size_n, quant_type.size_bits, weight_perm, is_a_8bit
|
||||
)
|
||||
|
||||
opcheck(
|
||||
torch.ops._C.awq_marlin_repack, (q_w_awq, size_k, size_n, quant_type.size_bits)
|
||||
torch.ops._C.awq_marlin_repack,
|
||||
(q_w_awq, size_k, size_n, quant_type.size_bits, is_a_8bit),
|
||||
)
|
||||
|
||||
# Run Marlin repack GPU kernel
|
||||
marlin_q_w_2 = ops.awq_marlin_repack(
|
||||
q_w_awq,
|
||||
size_k,
|
||||
size_n,
|
||||
quant_type.size_bits,
|
||||
q_w_awq, size_k, size_n, quant_type.size_bits, is_a_8bit
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
torch.testing.assert_close(marlin_q_w_1, marlin_q_w_2)
|
||||
|
||||
|
||||
def marlin_generate_valid_test_cases():
|
||||
all_combinations = itertools.product(
|
||||
DENSE_MARLIN_QUANT_TEST_CONFIGS,
|
||||
MNK_FACTORS,
|
||||
MARLIN_N_CHUNKS,
|
||||
MARLIN_K_CHUNKS,
|
||||
ACT_ORDER_OPTS,
|
||||
K_FULL_OPTS,
|
||||
USE_ATOMIC_ADD_OPTS,
|
||||
USE_FP32_REDUCE_OPTS,
|
||||
)
|
||||
|
||||
def is_invalid(
|
||||
a_type,
|
||||
b_type,
|
||||
c_type,
|
||||
group_blocks,
|
||||
size_m,
|
||||
size_n,
|
||||
size_k,
|
||||
act_order,
|
||||
is_k_full,
|
||||
use_atomic_add,
|
||||
use_fp32_reduce,
|
||||
):
|
||||
if use_atomic_add:
|
||||
if use_fp32_reduce:
|
||||
return False
|
||||
if (
|
||||
c_type == scalar_types.bfloat16
|
||||
and torch.cuda.get_device_capability()[0] < 9
|
||||
):
|
||||
return False
|
||||
|
||||
group_size = group_blocks if group_blocks <= 0 else group_blocks * 16
|
||||
if group_size > 0 and size_k % group_size != 0:
|
||||
return False
|
||||
|
||||
if act_order and group_size in [-1, size_k]:
|
||||
return False
|
||||
if group_size == size_k:
|
||||
return False
|
||||
if not act_order and is_k_full:
|
||||
return False
|
||||
|
||||
return a_type.size_bits < 16 or a_type is c_type
|
||||
|
||||
cases = []
|
||||
for case in all_combinations:
|
||||
quant_test_config, mnk_factors, n_chunk, k_chunk, act_order, *_ = case
|
||||
size_m = mnk_factors[0]
|
||||
size_n = mnk_factors[1] * n_chunk
|
||||
size_k = mnk_factors[2] * k_chunk
|
||||
|
||||
if act_order and not quant_test_config.get("support_act_order", False):
|
||||
continue
|
||||
|
||||
f16_types = [scalar_types.float16, scalar_types.bfloat16]
|
||||
inner_combinations = itertools.product(
|
||||
quant_test_config.get("a_type", f16_types),
|
||||
[quant_test_config["b_type"]],
|
||||
quant_test_config.get("c_type", f16_types),
|
||||
quant_test_config["group_blocks"],
|
||||
)
|
||||
|
||||
for sub_case in inner_combinations:
|
||||
if (
|
||||
sub_case[0] == scalar_types.float8_e4m3fn
|
||||
and current_platform.get_device_capability() not in [89, 120]
|
||||
):
|
||||
continue
|
||||
args = sub_case + (size_m, size_n, size_k) + case[4:]
|
||||
if is_invalid(*args):
|
||||
cases.append(args)
|
||||
return cases
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not is_quant_method_supported("gptq_marlin"),
|
||||
reason="Marlin is not supported on this GPU type.",
|
||||
)
|
||||
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
|
||||
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
|
||||
@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types())
|
||||
@pytest.mark.parametrize(
|
||||
"group_size", set(MARLIN_SUPPORTED_GROUP_SIZES + FP4_MARLIN_SUPPORTED_GROUP_SIZES)
|
||||
(
|
||||
"a_type, b_type, c_type, group_blocks,"
|
||||
"size_m, size_n, size_k, act_order, is_k_full,"
|
||||
"use_atomic_add, use_fp32_reduce"
|
||||
),
|
||||
marlin_generate_valid_test_cases(),
|
||||
)
|
||||
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
|
||||
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
|
||||
@pytest.mark.parametrize("is_k_full", K_FULL_OPTS)
|
||||
@pytest.mark.parametrize("use_atomic_add", USE_ATOMIC_ADD_OPTS)
|
||||
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
def test_gptq_marlin_gemm(
|
||||
k_chunk,
|
||||
n_chunk,
|
||||
quant_type,
|
||||
group_size,
|
||||
mnk_factors,
|
||||
a_type,
|
||||
b_type,
|
||||
c_type,
|
||||
group_blocks,
|
||||
size_m,
|
||||
size_n,
|
||||
size_k,
|
||||
act_order,
|
||||
is_k_full,
|
||||
use_atomic_add,
|
||||
use_fp32_reduce,
|
||||
dtype,
|
||||
):
|
    m_factor, n_factor, k_factor = mnk_factors
    has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
    has_zp = b_type in [scalar_types.uint4, scalar_types.uint8]

    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor
    group_size = group_blocks if group_blocks <= 0 else group_blocks * 16

    if act_order:
        if group_size == -1:
            return
        if group_size == size_k:
            return
        if has_zp:
            return
    if c_type == scalar_types.float16:
        dtype = torch.float16
    elif c_type == scalar_types.bfloat16:
        dtype = torch.bfloat16
    else:
        raise RuntimeError("unsupported c_type")

    if size_k % group_size != 0:
        return
    if a_type == scalar_types.int8:
        a_dtype = torch.int8
    elif a_type == scalar_types.float8_e4m3fn:
        a_dtype = torch.float8_e4m3fn
    else:
        a_dtype = dtype

    a_input = rand_data((size_m, size_k), dtype)
    b_weight = rand_data((size_k, size_n), dtype)

    if quant_type == scalar_types.float4_e2m1f:
        if group_size not in [16, 32] or act_order:
            return
        if group_size == 32 and dtype == torch.float16:
            return
        a_input = rand_data((size_m, size_k), dtype=dtype)
        b_weight = rand_data((size_k, size_n), dtype=dtype)

    if b_type == scalar_types.float4_e2m1f:
        if group_size == 16:
            w_ref, marlin_q_w, marlin_s, marlin_s2 = rand_marlin_weight_nvfp4_like(
                b_weight.T, group_size
                b_weight.T, group_size, input_dtype=a_dtype
            )
        else:
            w_ref, marlin_q_w, marlin_s = rand_marlin_weight_mxfp4_like(
                b_weight.T, group_size
                b_weight.T, group_size, input_dtype=a_dtype
            )
            marlin_s2 = None

        g_idx = None
        sort_indices = None
        marlin_zp = None
    elif quant_type == scalar_types.float8_e4m3fn:
        if group_size not in [-1, 128]:
            return
        if act_order:
            return
        w_ref, marlin_q_w, marlin_s = marlin_quant_fp8_torch(b_weight.T, group_size)
    elif b_type == scalar_types.float8_e4m3fn:
        w_ref, marlin_q_w, marlin_s = marlin_quant_fp8_torch(
            b_weight.T, group_size, input_dtype=a_dtype
        )
        g_idx = None
        sort_indices = None
        marlin_zp = None
        marlin_s2 = None
    elif has_zp:
        if group_size == 16:
            return
        w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(
            b_weight, quant_type, group_size
            b_weight, b_type, group_size, input_dtype=a_dtype
        )
        g_idx = None
        sort_indices = None
        marlin_s2 = None
    else:
        if group_size == 16:
            return
        w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
            b_weight, quant_type, group_size, act_order
            b_weight, b_type, group_size, act_order, input_dtype=a_dtype
        )

        marlin_zp = None
        marlin_s2 = None

    workspace = marlin_make_workspace_new(w_ref.device)

    opcheck(
        torch.ops._C.gptq_marlin_gemm,
        (
            a_input,
            None,
            marlin_q_w,
            None,
            marlin_s,
            marlin_s2,
            marlin_zp,
            g_idx,
            sort_indices,
            workspace,
            quant_type.id,
            a_input.shape[0],
            b_weight.shape[1],
            a_input.shape[1],
            is_k_full,
            use_atomic_add,
            use_fp32_reduce,
            False,
        ),
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )
    if a_type == scalar_types.int8:
        a_input, a_scales = per_token_quant_int8(a_input)
        a_input_ref = a_input.to(a_scales.dtype) * a_scales.view(-1, 1)
        a_input_ref = a_input_ref.to(dtype)

        if group_size != -1:
            a_scales = a_scales / 4096 * marlin_s.max()
            a_scales = a_scales.float()
            marlin_s = marlin_s / marlin_s.max() * 4096
            marlin_s = marlin_s.round().to(torch.int16).view(dtype)
    elif a_type == scalar_types.float8_e4m3fn:
        a_input, a_scales = ops.scaled_fp8_quant(a_input, use_per_token_if_dynamic=True)
        a_input_ref = a_input.to(a_scales.dtype) * a_scales.view(-1, 1)
        a_input_ref = a_input_ref.to(dtype)
    else:
        assert a_type.size_bits == 16
        a_input_ref = a_input
        a_scales = None

    output = torch.empty((size_m, size_n), dtype=dtype, device=a_input.device)

    output = ops.gptq_marlin_gemm(
        a_input,
        None,
        output,
        marlin_q_w,
        None,
        marlin_s,
        a_scales,
        marlin_s2,
        marlin_zp,
        g_idx,
        sort_indices,
        workspace,
        quant_type,
        b_type,
        a_input.shape[0],
        b_weight.shape[1],
        a_input.shape[1],
@ -352,12 +527,9 @@ def test_gptq_marlin_gemm(
        use_fp32_reduce=use_fp32_reduce,
        is_zp_float=False,
    )
    output_ref = torch.matmul(a_input, w_ref)

    torch.cuda.synchronize()
    output_ref = torch.matmul(a_input_ref, w_ref)

    max_diff = compute_max_diff(output, output_ref)

    assert max_diff < 0.04

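In the 8-bit activation branches above, output_ref is computed from a dequantized copy of the quantized input (a_input_ref) so that the reference matmul sees the same rounding error as the kernel. A rough, self-contained sketch of that per-token int8 round trip, assuming a symmetric quantizer comparable to per_token_quant_int8 (the quantize_per_token helper below is illustrative, not the vLLM implementation):

    import torch

    def quantize_per_token(x: torch.Tensor):
        # one scale per row, symmetric int8
        scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
        x_q = torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)
        return x_q, scale

    x = torch.randn(4, 64)
    x_q, s = quantize_per_token(x)
    x_ref = x_q.to(s.dtype) * s   # mirrors a_input_ref above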
@ -507,6 +679,7 @@ def test_hqq_marlin_gemm(
|
||||
None,
|
||||
marlin_s,
|
||||
None,
|
||||
None,
|
||||
marlin_zp,
|
||||
g_idx,
|
||||
g_idx_sort_indices,
|
||||
@ -559,6 +732,7 @@ def test_marlin_gemm_subset_input():
|
||||
None,
|
||||
marlin_s,
|
||||
None,
|
||||
None,
|
||||
marlin_zp,
|
||||
g_idx,
|
||||
sort_indices,
|
||||
@ -607,6 +781,7 @@ def test_marlin_gemm_with_bias(size_m):
|
||||
marlin_bias,
|
||||
marlin_s,
|
||||
None,
|
||||
None,
|
||||
marlin_zp,
|
||||
g_idx,
|
||||
sort_indices,

@ -846,6 +846,13 @@ def torch_experts(
        or (expert_map is not None and global_num_experts == expert_map.shape[0])
    )

    if quant_dtype in [torch.float16, torch.bfloat16]:
        quant_dtype = None
    quant_input_only = quant_dtype is not None and w1_scale is None and w2_scale is None
    if quant_input_only:
        assert a1_scale is None and a2_scale is None
        assert per_act_token_quant

    M, K = a.shape
    topk = topk_ids.shape[1]

@ -863,6 +870,9 @@ def torch_experts(
        a, a1_scale, quant_dtype, per_act_token_quant, block_shape
    )

    if quant_input_only:
        a = (a.float() * a_scale.view(-1, 1)).to(w1.dtype)

    num_experts = w1.shape[0]

    topk_ids = topk_ids.view(-1)
@ -882,6 +892,14 @@ def torch_experts(
            out[mask] = tmp2 @ w2[i].transpose(0, 1)
            if b_bias2 is not None:
                out[mask] = out[mask] + b_bias2[i].view(1, -1).to(tmp1.dtype)
        elif quant_input_only:
            tmp1 = a[mask] @ w1[i].transpose(0, 1)
            tmp2 = SiluAndMul()(tmp1)
            tmp2, tmp2_scale = moe_kernel_quantize_input(
                tmp2, None, quant_dtype, per_act_token_quant
            )
            tmp2 = (tmp2.float() * tmp2_scale.view(-1, 1)).to(w2.dtype)
            out[mask] = tmp2 @ w2[i].transpose(0, 1)
        elif block_shape is not None:
            # block quantized
            assert (

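The quant_input_only path above models the w4a8-style case where only the activations carry a quant_dtype and no weight scales are given: each quantized activation is immediately dequantized back to the weight dtype before the plain matmul. A rough sketch of that fake-quantization round trip for fp8 activations, under the assumption of per-token scales and the e4m3 maximum of 448 (the fake_quant_fp8 helper is illustrative, not moe_kernel_quantize_input):

    import torch

    def fake_quant_fp8(x: torch.Tensor) -> torch.Tensor:
        # quantize to float8_e4m3fn with one scale per row, then dequantize,
        # so a reference matmul sees the same rounding error as the kernel path
        scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 448.0
        x_q = torch.clamp(x / scale, -448.0, 448.0).to(torch.float8_e4m3fn)
        return (x_q.float() * scale).to(x.dtype)

    a = torch.randn(8, 128)
    a_ref = fake_quant_fp8(a)   # analogous to the dequantized activations above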
@ -554,6 +554,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
        b_q_weight: torch.Tensor,
        b_bias: torch.Tensor | None,
        b_scales: torch.Tensor,
        a_scales: torch.Tensor | None,
        global_scale: torch.Tensor | None,
        b_zeros: torch.Tensor | None,
        g_idx: torch.Tensor | None,
@ -568,7 +569,10 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
        use_fp32_reduce: bool = False,
        is_zp_float: bool = False,
    ) -> torch.Tensor:
        return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
        dtype = a.dtype
        if dtype not in [torch.half, torch.bfloat16]:
            dtype = b_scales.dtype
        return torch.empty((size_m, size_n), device=a.device, dtype=dtype)

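The dtype fix above matters once 8-bit activations are allowed: the fake (meta) op can no longer assume the output dtype equals the input dtype, so it falls back to the dtype of the weight scales. A tiny runnable restatement of that rule (the helper name is illustrative):

    import torch

    def fake_output_dtype(a: torch.Tensor, b_scales: torch.Tensor) -> torch.dtype:
        # 16-bit activations keep their dtype; fp8/int8 activations take the
        # fp16/bf16 dtype of the weight scales, matching the fake op above
        return a.dtype if a.dtype in (torch.half, torch.bfloat16) else b_scales.dtype

    assert fake_output_dtype(
        torch.empty(0, dtype=torch.float8_e4m3fn), torch.empty(0, dtype=torch.half)
    ) == torch.half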
@register_fake("_C::awq_dequantize")
|
||||
def _awq_dequantize_fake(
|
||||
@ -1167,8 +1171,11 @@ def gptq_marlin_repack(
|
||||
size_k: int,
|
||||
size_n: int,
|
||||
num_bits: int,
|
||||
is_a_8bit: bool = False,
|
||||
) -> torch.Tensor:
|
||||
return torch.ops._C.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, num_bits)
|
||||
return torch.ops._C.gptq_marlin_repack(
|
||||
b_q_weight, perm, size_k, size_n, num_bits, is_a_8bit
|
||||
)
|
||||
|
||||
|
||||
if hasattr(torch.ops._C, "gptq_marlin_repack"):
|
||||
@ -1180,6 +1187,7 @@ if hasattr(torch.ops._C, "gptq_marlin_repack"):
|
||||
size_k: torch.SymInt,
|
||||
size_n: torch.SymInt,
|
||||
num_bits: int,
|
||||
is_a_8bit: bool = False,
|
||||
) -> torch.Tensor:
|
||||
pack_factor = 32 // num_bits
|
||||
marlin_tile_size = 16
|
||||
@ -1192,9 +1200,15 @@ if hasattr(torch.ops._C, "gptq_marlin_repack"):
|
||||
|
||||
# awq_marlin
|
||||
def awq_marlin_repack(
|
||||
b_q_weight: torch.Tensor, size_k: int, size_n: int, num_bits: int
|
||||
b_q_weight: torch.Tensor,
|
||||
size_k: int,
|
||||
size_n: int,
|
||||
num_bits: int,
|
||||
is_a_8bit: bool = False,
|
||||
) -> torch.Tensor:
|
||||
return torch.ops._C.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits)
|
||||
return torch.ops._C.awq_marlin_repack(
|
||||
b_q_weight, size_k, size_n, num_bits, is_a_8bit
|
||||
)
|
||||
|
||||
|
||||
if hasattr(torch.ops._C, "awq_marlin_repack"):
|
||||
@ -1205,6 +1219,7 @@ if hasattr(torch.ops._C, "awq_marlin_repack"):
|
||||
size_k: torch.SymInt,
|
||||
size_n: torch.SymInt,
|
||||
num_bits: int,
|
||||
is_a_8bit: bool = False,
|
||||
) -> torch.Tensor:
|
||||
pack_factor = 32 // num_bits
|
||||
marlin_tile_size = 16
|
||||
@ -1221,6 +1236,7 @@ def gptq_marlin_moe_repack(
|
||||
size_k: int,
|
||||
size_n: int,
|
||||
num_bits: int,
|
||||
is_a_8bit: bool = False,
|
||||
) -> torch.Tensor:
|
||||
num_experts = b_q_weight.shape[0]
|
||||
assert size_k % 16 == 0
|
||||
@ -1231,7 +1247,7 @@ def gptq_marlin_moe_repack(
|
||||
)
|
||||
for e in range(num_experts):
|
||||
output[e] = torch.ops._C.gptq_marlin_repack(
|
||||
b_q_weight[e], perm[e], size_k, size_n, num_bits
|
||||
b_q_weight[e], perm[e], size_k, size_n, num_bits, is_a_8bit
|
||||
)
|
||||
return output
|
||||
|
||||
@ -1242,6 +1258,7 @@ def awq_marlin_moe_repack(
|
||||
size_k: int,
|
||||
size_n: int,
|
||||
num_bits: int,
|
||||
is_a_8bit: bool = False,
|
||||
) -> torch.Tensor:
|
||||
num_experts = b_q_weight.shape[0]
|
||||
assert size_k % 16 == 0
|
||||
@ -1252,17 +1269,26 @@ def awq_marlin_moe_repack(
|
||||
)
|
||||
for e in range(num_experts):
|
||||
output[e] = torch.ops._C.awq_marlin_repack(
|
||||
b_q_weight[e], size_k, size_n, num_bits
|
||||
b_q_weight[e], size_k, size_n, num_bits, is_a_8bit
|
||||
)
|
||||
return output


def marlin_int4_fp8_preprocess(
    qweight: torch.Tensor,
    qzeros_or_none: torch.Tensor | None = None,
    inplace: bool = False,
):
    return torch.ops._C.marlin_int4_fp8_preprocess(qweight, qzeros_or_none, inplace)


def gptq_marlin_gemm(
|
||||
a: torch.Tensor,
|
||||
c: torch.Tensor | None,
|
||||
b_q_weight: torch.Tensor,
|
||||
b_bias: torch.Tensor | None,
|
||||
b_scales: torch.Tensor,
|
||||
a_scales: torch.Tensor | None,
|
||||
global_scale: torch.Tensor | None,
|
||||
b_zeros: torch.Tensor | None,
|
||||
g_idx: torch.Tensor | None,
|
||||
@ -1283,6 +1309,7 @@ def gptq_marlin_gemm(
|
||||
b_q_weight,
|
||||
b_bias,
|
||||
b_scales,
|
||||
a_scales,
|
||||
global_scale,
|
||||
b_zeros,
|
||||
g_idx,
|
||||
@ -1600,7 +1627,7 @@ def allspark_repack_weight(
|
||||
if use asymmetric quantization, has_zp = True.
|
||||
|
||||
Returns:
|
||||
tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] :
|
||||
tuple[torch.Tensor, torch.Tensor, torch.Tensor | None] :
|
||||
rearranged weight, scale, and optionally zero_point.
|
||||
"""
|
||||
K = qweight.shape[0]
|
||||
@ -1683,7 +1710,7 @@ def scaled_int8_quant(
|
||||
symmetric: Whether to use symmetric quantization (scale only, azp ignored).
|
||||
|
||||
Returns:
|
||||
tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp.
|
||||
tuple[torch.Tensor, torch.Tensor, torch.Tensor | None] : Output int8 tensor, scales, and optionally azp.
|
||||
"""
|
||||
output = torch.empty_like(input, dtype=torch.int8)
|
||||
if scale is not None:
|
||||
@ -2004,6 +2031,7 @@ def moe_wna16_marlin_gemm(
|
||||
b_qweight: torch.Tensor,
|
||||
b_bias: torch.Tensor | None,
|
||||
b_scales: torch.Tensor,
|
||||
a_scales: torch.Tensor | None,
|
||||
global_scale: torch.Tensor | None,
|
||||
b_qzeros: torch.Tensor | None,
|
||||
g_idx: torch.Tensor | None,
|
||||
@ -2025,6 +2053,9 @@ def moe_wna16_marlin_gemm(
|
||||
use_atomic_add: bool,
|
||||
use_fp32_reduce: bool,
|
||||
is_zp_float: bool,
|
||||
thread_k: int = -1,
|
||||
thread_n: int = -1,
|
||||
blocks_per_sm: int = -1,
|
||||
) -> torch.Tensor:
|
||||
return torch.ops._moe_C.moe_wna16_marlin_gemm(
|
||||
input,
|
||||
@ -2032,6 +2063,7 @@ def moe_wna16_marlin_gemm(
|
||||
b_qweight,
|
||||
b_bias,
|
||||
b_scales,
|
||||
a_scales,
|
||||
global_scale,
|
||||
b_qzeros,
|
||||
g_idx,
|
||||
@ -2053,6 +2085,9 @@ def moe_wna16_marlin_gemm(
|
||||
use_atomic_add,
|
||||
use_fp32_reduce,
|
||||
is_zp_float,
|
||||
thread_k,
|
||||
thread_n,
|
||||
blocks_per_sm,
|
||||
)
|
||||
|
||||
|
||||
@ -2088,7 +2123,10 @@ if hasattr(torch.ops, "_moe_C") and hasattr(torch.ops._moe_C, "marlin_gemm_moe")
|
||||
input: torch.Tensor,
|
||||
output: torch.Tensor | None,
|
||||
b_qweight: torch.Tensor,
|
||||
b_bias: torch.Tensor | None,
|
||||
b_scales: torch.Tensor,
|
||||
a_scales: torch.Tensor | None,
|
||||
global_scale: torch.Tensor | None,
|
||||
b_qzeros: torch.Tensor | None,
|
||||
g_idx: torch.Tensor | None,
|
||||
perm: torch.Tensor | None,
|
||||
@ -2109,7 +2147,7 @@ if hasattr(torch.ops, "_moe_C") and hasattr(torch.ops._moe_C, "marlin_gemm_moe")
|
||||
use_atomic_add: bool,
|
||||
use_fp32_reduce: bool,
|
||||
is_zp_float: bool,
|
||||
) -> torch.Tensor:
|
||||
):
|
||||
return torch.empty(
|
||||
(size_m * top_k, size_n), dtype=input.dtype, device=input.device
|
||||
)
|
||||
@ -2583,7 +2621,7 @@ def onednn_scaled_int8_quant(
|
||||
symmetric: Whether to use symmetric quantization (scale only, azp ignored).
|
||||
|
||||
Returns:
|
||||
tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp.
|
||||
tuple[torch.Tensor, torch.Tensor, torch.Tensor | None] : Output int8 tensor, scales, and optionally azp.
|
||||
"""
|
||||
output = torch.empty_like(input, dtype=torch.int8)
|
||||
token_num = input.numel() // input.shape[-1]
|
||||
|
||||
@ -145,6 +145,7 @@ if TYPE_CHECKING:
    VLLM_RANDOMIZE_DP_DUMMY_INPUTS: bool = False
    VLLM_RAY_DP_PACK_STRATEGY: Literal["strict", "fill", "span"] = "strict"
    VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
    VLLM_MARLIN_INPUT_DTYPE: Literal["int8", "fp8"] | None = None
    VLLM_MXFP4_USE_MARLIN: bool | None = None
    VLLM_V1_USE_OUTLINES_CACHE: bool = False
    VLLM_TPU_BUCKET_PADDING_GAP: int = 0
@ -1122,6 +1123,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
    "VLLM_MXFP4_USE_MARLIN": lambda: maybe_convert_bool(
        os.environ.get("VLLM_MXFP4_USE_MARLIN", None)
    ),
    # The activation dtype for marlin kernel
    "VLLM_MARLIN_INPUT_DTYPE": env_with_choices(
        "VLLM_MARLIN_INPUT_DTYPE", None, ["int8", "fp8"]
    ),
    # Whether to turn on the outlines cache for V1
    # This cache is unbounded and on disk, so it's not safe to use in
    # an environment with potentially malicious users.

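The new VLLM_MARLIN_INPUT_DTYPE environment variable selects the Marlin activation dtype ("int8" or "fp8"); leaving it unset keeps activations in the model's 16-bit dtype. A minimal sketch of how a deployment might set it and map the choice to a torch dtype; the dictionary below is illustrative, while the real resolution goes through get_marlin_input_dtype in marlin_utils:

    import os
    import torch

    os.environ["VLLM_MARLIN_INPUT_DTYPE"] = "fp8"   # or "int8"

    _CHOICE_TO_DTYPE = {"int8": torch.int8, "fp8": torch.float8_e4m3fn}
    input_dtype = _CHOICE_TO_DTYPE.get(os.environ.get("VLLM_MARLIN_INPUT_DTYPE", ""))
    # input_dtype is None when the variable is unset -> 16-bit activations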
@ -24,7 +24,7 @@ from vllm.model_executor.layers.fused_moe.utils import _resize_cache, disable_in
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
marlin_make_workspace_new,
|
||||
marlin_moe_intermediate_size,
|
||||
maybe_warn_marlin_atomic_add,
|
||||
marlin_quant_input,
|
||||
)
|
||||
from vllm.scalar_type import ScalarType, scalar_types
|
||||
|
||||
@ -65,6 +65,8 @@ def _fused_marlin_moe(
|
||||
activation_func: Callable[
|
||||
[str, torch.Tensor, torch.Tensor], None
|
||||
] = default_activation_func,
|
||||
input_global_scale1: torch.Tensor | None = None,
|
||||
input_global_scale2: torch.Tensor | None = None,
|
||||
global_scale1: torch.Tensor | None = None,
|
||||
global_scale2: torch.Tensor | None = None,
|
||||
g_idx1: torch.Tensor | None = None,
|
||||
@ -77,6 +79,7 @@ def _fused_marlin_moe(
|
||||
intermediate_cache13: torch.Tensor | None = None,
|
||||
intermediate_cache2: torch.Tensor | None = None,
|
||||
output: torch.Tensor | None = None,
|
||||
input_dtype: torch.dtype | None = None,
|
||||
is_k_full: bool = True,
|
||||
) -> torch.Tensor:
|
||||
assert hidden_states.ndim == 2
|
||||
@ -106,18 +109,22 @@ def _fused_marlin_moe(
|
||||
|
||||
intermediate_cache2 = _resize_cache(intermediate_cache2, (M * num_topk, N))
|
||||
|
||||
maybe_warn_marlin_atomic_add(hidden_states.device, hidden_states.dtype)
|
||||
use_atomic_add = (
|
||||
hidden_states.dtype == torch.half
|
||||
or torch.cuda.get_device_capability(hidden_states.device)[0] >= 9
|
||||
)
|
||||
a_scales1 = None
|
||||
gate_up_input = hidden_states
|
||||
if input_dtype == torch.int8:
|
||||
gate_up_input, a_scales1 = marlin_quant_input(hidden_states, input_dtype)
|
||||
if input_global_scale1 is not None:
|
||||
a_scales1 = a_scales1 * input_global_scale1
|
||||
elif input_dtype == torch.float8_e4m3fn:
|
||||
gate_up_input, a_scales1 = marlin_quant_input(hidden_states, input_dtype)
|
||||
|
||||
intermediate_cache1 = ops.moe_wna16_marlin_gemm(
|
||||
hidden_states,
|
||||
gate_up_input,
|
||||
intermediate_cache1,
|
||||
w1,
|
||||
bias1,
|
||||
w1_scale,
|
||||
a_scales1,
|
||||
global_scale1,
|
||||
w1_zeros,
|
||||
g_idx1,
|
||||
@ -136,7 +143,7 @@ def _fused_marlin_moe(
|
||||
size_n=2 * N,
|
||||
size_k=K,
|
||||
is_k_full=is_k_full,
|
||||
use_atomic_add=use_atomic_add,
|
||||
use_atomic_add=False,
|
||||
use_fp32_reduce=True,
|
||||
is_zp_float=False,
|
||||
)
|
||||
@ -151,12 +158,25 @@ def _fused_marlin_moe(
|
||||
if expert_map is not None:
|
||||
output.zero_()
|
||||
|
||||
a_scales2 = None
|
||||
if input_dtype == torch.int8:
|
||||
intermediate_cache2, a_scales2 = marlin_quant_input(
|
||||
intermediate_cache2, input_dtype
|
||||
)
|
||||
if input_global_scale2 is not None:
|
||||
a_scales2 = a_scales2 * input_global_scale2
|
||||
elif input_dtype == torch.float8_e4m3fn:
|
||||
intermediate_cache2, a_scales2 = marlin_quant_input(
|
||||
intermediate_cache2, input_dtype
|
||||
)
|
||||
|
||||
output = ops.moe_wna16_marlin_gemm(
|
||||
intermediate_cache2,
|
||||
output,
|
||||
w2,
|
||||
bias2,
|
||||
w2_scale,
|
||||
a_scales2,
|
||||
global_scale2,
|
||||
w2_zeros,
|
||||
g_idx2,
|
||||
@ -175,7 +195,7 @@ def _fused_marlin_moe(
|
||||
size_n=K,
|
||||
size_k=N,
|
||||
is_k_full=is_k_full,
|
||||
use_atomic_add=use_atomic_add,
|
||||
use_atomic_add=False,
|
||||
use_fp32_reduce=True,
|
||||
is_zp_float=False,
|
||||
)
|
||||
@ -203,6 +223,8 @@ def fused_marlin_moe(
|
||||
] = default_activation_func,
|
||||
moe_sum: Callable[[torch.Tensor, torch.Tensor], None] | None = None,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
input_global_scale1: torch.Tensor | None = None,
|
||||
input_global_scale2: torch.Tensor | None = None,
|
||||
global_scale1: torch.Tensor | None = None,
|
||||
global_scale2: torch.Tensor | None = None,
|
||||
g_idx1: torch.Tensor | None = None,
|
||||
@ -216,6 +238,7 @@ def fused_marlin_moe(
|
||||
intermediate_cache2: torch.Tensor | None = None,
|
||||
is_k_full: bool = True,
|
||||
output: torch.Tensor | None = None,
|
||||
input_dtype: torch.dtype | None = None,
|
||||
inplace: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
@ -287,6 +310,9 @@ def fused_marlin_moe(
|
||||
if M * topk / E / block_size_m < 0.9:
|
||||
break
|
||||
|
||||
if input_dtype is not None and input_dtype.itemsize == 1:
|
||||
block_size_m = max(block_size_m, 16)
|
||||
|
||||
if global_num_experts == -1:
|
||||
global_num_experts = E
|
||||
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
|
||||
@ -313,6 +339,8 @@ def fused_marlin_moe(
|
||||
num_tokens_post_padded=num_tokens_post_padded,
|
||||
activation=activation,
|
||||
activation_func=activation_func,
|
||||
input_global_scale1=input_global_scale1,
|
||||
input_global_scale2=input_global_scale2,
|
||||
global_scale1=global_scale1,
|
||||
global_scale2=global_scale2,
|
||||
g_idx1=g_idx1,
|
||||
@ -325,6 +353,7 @@ def fused_marlin_moe(
|
||||
intermediate_cache13=intermediate_cache13,
|
||||
intermediate_cache2=intermediate_cache2,
|
||||
output=None,
|
||||
input_dtype=input_dtype,
|
||||
is_k_full=is_k_full,
|
||||
).view(-1, topk, K)
|
||||
|
||||
|
||||
@ -266,7 +266,7 @@ class AutoRoundConfig(QuantizationConfig):
|
||||
from vllm.model_executor.layers.quantization.awq_marlin import (
|
||||
AWQMarlinConfig,
|
||||
AWQMarlinLinearMethod,
|
||||
AWQMoEMethod,
|
||||
AWQMarlinMoEMethod,
|
||||
)
|
||||
|
||||
quant_args_marlin = AWQMarlinConfig(
|
||||
@ -291,7 +291,7 @@ class AutoRoundConfig(QuantizationConfig):
|
||||
|
||||
if isinstance(layer, FusedMoE):
|
||||
if use_marlin:
|
||||
return AWQMoEMethod(quant_args_marlin, layer.moe_config)
|
||||
return AWQMarlinMoEMethod(quant_args_marlin, layer.moe)
|
||||
from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config
|
||||
|
||||
config = {
|
||||
|
||||
@ -106,7 +106,7 @@ class AWQConfig(QuantizationConfig):
|
||||
return AWQLinearMethod(self)
|
||||
elif isinstance(layer, FusedMoE):
|
||||
# Lazy import to avoid circular import.
|
||||
from .awq_marlin import AWQMarlinConfig, AWQMoEMethod
|
||||
from .awq_marlin import AWQMarlinConfig, AWQMarlinMoEMethod
|
||||
from .moe_wna16 import MoeWNA16Config
|
||||
from .utils.marlin_utils import check_moe_marlin_supports_layer
|
||||
|
||||
@ -136,7 +136,7 @@ class AWQConfig(QuantizationConfig):
|
||||
awq_marlin_config = AWQMarlinConfig.from_config(
|
||||
marlin_compatible_config_dict
|
||||
)
|
||||
return AWQMoEMethod(awq_marlin_config, layer.moe_config)
|
||||
return AWQMarlinMoEMethod(awq_marlin_config, layer.moe_config)
|
||||
return None
|
||||
|
||||
def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
|
||||
|
||||
@ -40,6 +40,8 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
check_marlin_supported,
|
||||
check_marlin_supports_layer,
|
||||
check_moe_marlin_supports_layer,
|
||||
get_marlin_input_dtype,
|
||||
marlin_act_int8_process_scales,
|
||||
marlin_make_empty_g_idx,
|
||||
marlin_make_workspace_new,
|
||||
marlin_moe_permute_scales,
|
||||
@ -69,7 +71,6 @@ class AWQMarlinConfig(QuantizationConfig):
|
||||
# num_bits -> type
|
||||
TYPE_MAP = {
|
||||
4: scalar_types.uint4,
|
||||
8: scalar_types.uint8,
|
||||
}
|
||||
|
||||
def __init__(
|
||||
@ -193,7 +194,9 @@ class AWQMarlinConfig(QuantizationConfig):
|
||||
return AWQConfig.from_config(self.full_config).get_quant_method(
|
||||
layer, prefix
|
||||
)
|
||||
return AWQMarlinLinearMethod(self)
|
||||
quant_method = AWQMarlinLinearMethod(self)
|
||||
quant_method.input_dtype = get_marlin_input_dtype(prefix)
|
||||
return quant_method
|
||||
elif isinstance(layer, FusedMoE):
|
||||
from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config
|
||||
|
||||
@ -211,7 +214,9 @@ class AWQMarlinConfig(QuantizationConfig):
|
||||
return MoeWNA16Config.from_config(self.full_config).get_quant_method(
|
||||
layer, prefix
|
||||
)
|
||||
return AWQMoEMethod(self, layer.moe_config)
|
||||
moe_quant_method = AWQMarlinMoEMethod(self, layer.moe_config)
|
||||
moe_quant_method.input_dtype = get_marlin_input_dtype(prefix)
|
||||
return moe_quant_method
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
@ -270,6 +275,8 @@ class AWQMarlinLinearMethod(LinearMethodBase):
|
||||
|
||||
def __init__(self, quant_config: AWQMarlinConfig) -> None:
|
||||
self.quant_config = quant_config
|
||||
self.quant_type = scalar_types.uint4
|
||||
self.input_dtype = None
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
@ -312,6 +319,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
|
||||
)
|
||||
|
||||
num_groups = input_size_per_partition // group_size
|
||||
layer.num_groups = num_groups
|
||||
|
||||
qzeros = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
@ -358,12 +366,19 @@ class AWQMarlinLinearMethod(LinearMethodBase):
|
||||
# Allocate marlin workspace
|
||||
layer.workspace = marlin_make_workspace_new(device)
|
||||
|
||||
is_a_8bit = self.input_dtype is not None and self.input_dtype.itemsize == 1
|
||||
|
||||
if self.input_dtype == torch.float8_e4m3fn:
|
||||
ops.marlin_int4_fp8_preprocess(layer.qweight, layer.qzeros, inplace=True)
|
||||
layer.scales.data = layer.scales.data * 512
|
||||
|
||||
# Repack weights from AWQ format to marlin format.
|
||||
marlin_qweight = ops.awq_marlin_repack(
|
||||
layer.qweight,
|
||||
size_k=layer.input_size_per_partition,
|
||||
size_n=layer.output_size_per_partition,
|
||||
num_bits=self.quant_config.quant_type.size_bits,
|
||||
is_a_8bit=is_a_8bit,
|
||||
)
|
||||
replace_parameter(layer, "qweight", marlin_qweight)
|
||||
|
||||
@ -373,7 +388,16 @@ class AWQMarlinLinearMethod(LinearMethodBase):
|
||||
size_k=layer.input_size_per_partition,
|
||||
size_n=layer.output_size_per_partition,
|
||||
group_size=self.quant_config.group_size,
|
||||
is_a_8bit=is_a_8bit,
|
||||
)
|
||||
if self.input_dtype == torch.int8 and layer.num_groups > 1:
|
||||
marlin_scales, input_global_scale = marlin_act_int8_process_scales(
|
||||
marlin_scales
|
||||
)
|
||||
layer.register_parameter(
|
||||
"input_global_scale", Parameter(input_global_scale, requires_grad=False)
|
||||
)
|
||||
|
||||
replace_parameter(layer, "scales", marlin_scales)
|
||||
|
||||
# Permute zero-points from AWQ format to marlin format.
|
||||
@ -382,6 +406,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
|
||||
size_k=layer.num_groups,
|
||||
size_n=layer.output_size_per_partition,
|
||||
num_bits=self.quant_config.quant_type.size_bits,
|
||||
is_a_8bit=is_a_8bit,
|
||||
)
|
||||
replace_parameter(layer, "qzeros", marlin_zp)
|
||||
|
||||
@ -409,11 +434,13 @@ class AWQMarlinLinearMethod(LinearMethodBase):
|
||||
quant_type=self.quant_config.quant_type,
|
||||
output_size_per_partition=layer.output_size_per_partition,
|
||||
input_size_per_partition=layer.input_size_per_partition,
|
||||
input_global_scale=getattr(layer, "input_global_scale", None),
|
||||
bias=bias,
|
||||
input_dtype=self.input_dtype,
|
||||
)
|
||||
|
||||
|
||||
class AWQMoEMethod(FusedMoEMethodBase):
|
||||
class AWQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
def __init__(
|
||||
self,
|
||||
quant_config: AWQMarlinConfig,
|
||||
@ -422,8 +449,9 @@ class AWQMoEMethod(FusedMoEMethodBase):
|
||||
super().__init__(moe)
|
||||
self.quant_config = quant_config
|
||||
if self.quant_config.weight_bits != 4:
|
||||
raise ValueError("AWQMoEMethod only supports 4bit now.")
|
||||
raise ValueError("AWQMarlinMoEMethod only supports 4bit now.")
|
||||
self.quant_type = scalar_types.uint4
|
||||
self.input_dtype = None
|
||||
self.use_marlin = True
|
||||
|
||||
def create_weights(
|
||||
@ -435,6 +463,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
layer.input_dtype = self.input_dtype
|
||||
extra_weight_attrs.update(
|
||||
{
|
||||
"is_transposed": True,
|
||||
@ -468,6 +497,8 @@ class AWQMoEMethod(FusedMoEMethodBase):
|
||||
|
||||
num_groups_w13 = hidden_size // self.quant_config.group_size
|
||||
num_groups_w2 = intermediate_size_per_partition // self.quant_config.group_size
|
||||
layer.num_groups_w13 = num_groups_w13
|
||||
layer.num_groups_w2 = num_groups_w2
|
||||
|
||||
# WEIGHT_SCALES
|
||||
# Allocate 2 scales for w1 and w3 respectively.
|
||||
@ -522,6 +553,21 @@ class AWQMoEMethod(FusedMoEMethodBase):
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
num_experts = layer.w13_qweight.shape[0]
|
||||
device = layer.w13_qweight.device
|
||||
is_a_8bit = self.input_dtype is not None and self.input_dtype.itemsize == 1
|
||||
|
||||
if self.input_dtype == torch.float8_e4m3fn:
|
||||
ops.marlin_int4_fp8_preprocess(
|
||||
layer.w13_qweight.view(-1, layer.w13_qweight.size(2)),
|
||||
layer.w13_qzeros.view(-1, layer.w13_qzeros.size(2)),
|
||||
inplace=True,
|
||||
)
|
||||
ops.marlin_int4_fp8_preprocess(
|
||||
layer.w2_qweight.view(-1, layer.w2_qweight.size(2)),
|
||||
layer.w2_qzeros.view(-1, layer.w2_qzeros.size(2)),
|
||||
inplace=True,
|
||||
)
|
||||
layer.w13_scales.data = layer.w13_scales.data * 512
|
||||
layer.w2_scales.data = layer.w2_scales.data * 512
|
||||
|
||||
layer.w13_g_idx_sort_indices = torch.nn.Parameter(
|
||||
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
|
||||
@ -538,6 +584,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
|
||||
size_k=layer.w13_qweight.shape[1],
|
||||
size_n=layer.w13_qweight.shape[2] * self.quant_config.pack_factor,
|
||||
num_bits=self.quant_config.weight_bits,
|
||||
is_a_8bit=is_a_8bit,
|
||||
)
|
||||
replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
|
||||
|
||||
@ -547,6 +594,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
|
||||
size_k=layer.w2_qweight.shape[1],
|
||||
size_n=layer.w2_qweight.shape[2] * self.quant_config.pack_factor,
|
||||
num_bits=self.quant_config.weight_bits,
|
||||
is_a_8bit=is_a_8bit,
|
||||
)
|
||||
replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
|
||||
|
||||
@ -556,7 +604,16 @@ class AWQMoEMethod(FusedMoEMethodBase):
|
||||
size_k=layer.intermediate_size_per_partition,
|
||||
size_n=layer.w13_scales.shape[2],
|
||||
group_size=self.quant_config.group_size,
|
||||
is_a_8bit=is_a_8bit,
|
||||
)
|
||||
if self.input_dtype == torch.int8 and layer.num_groups_w13 > 1:
|
||||
marlin_w13_scales, w13_input_global_scale = marlin_act_int8_process_scales(
|
||||
marlin_w13_scales
|
||||
)
|
||||
layer.register_parameter(
|
||||
"w13_input_global_scale",
|
||||
Parameter(w13_input_global_scale, requires_grad=False),
|
||||
)
|
||||
|
||||
replace_parameter(layer, "w13_scales", marlin_w13_scales)
|
||||
|
||||
@ -565,7 +622,17 @@ class AWQMoEMethod(FusedMoEMethodBase):
|
||||
size_k=layer.intermediate_size_per_partition,
|
||||
size_n=layer.w2_scales.shape[2],
|
||||
group_size=self.quant_config.group_size,
|
||||
is_a_8bit=is_a_8bit,
|
||||
)
|
||||
if self.input_dtype == torch.int8 and layer.num_groups_w2 > 1:
|
||||
marlin_w2_scales, w2_input_global_scale = marlin_act_int8_process_scales(
|
||||
marlin_w2_scales
|
||||
)
|
||||
layer.register_parameter(
|
||||
"w2_input_global_scale",
|
||||
Parameter(w2_input_global_scale, requires_grad=False),
|
||||
)
|
||||
|
||||
replace_parameter(layer, "w2_scales", marlin_w2_scales)
|
||||
|
||||
marlin_w13_zp = moe_awq_to_marlin_zero_points(
|
||||
@ -573,6 +640,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
|
||||
size_k=layer.w13_qzeros.shape[1],
|
||||
size_n=layer.w13_qzeros.shape[2] * self.quant_config.pack_factor,
|
||||
num_bits=self.quant_config.weight_bits,
|
||||
is_a_8bit=is_a_8bit,
|
||||
)
|
||||
replace_parameter(layer, "w13_qzeros", marlin_w13_zp)
|
||||
|
||||
@ -581,6 +649,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
|
||||
size_k=layer.w2_qzeros.shape[1],
|
||||
size_n=layer.w2_qzeros.shape[2] * self.quant_config.pack_factor,
|
||||
num_bits=self.quant_config.weight_bits,
|
||||
is_a_8bit=is_a_8bit,
|
||||
)
|
||||
replace_parameter(layer, "w2_qzeros", marlin_w2_zp)
|
||||
|
||||
@ -636,6 +705,8 @@ class AWQMoEMethod(FusedMoEMethodBase):
|
||||
router_logits,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
input_global_scale1=getattr(layer, "w13_input_global_scale", None),
|
||||
input_global_scale2=getattr(layer, "w2_input_global_scale", None),
|
||||
quant_type_id=self.quant_type.id,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
@ -643,4 +714,5 @@ class AWQMoEMethod(FusedMoEMethodBase):
|
||||
w1_zeros=layer.w13_qzeros,
|
||||
w2_zeros=layer.w2_qzeros,
|
||||
workspace=layer.workspace,
|
||||
input_dtype=self.input_dtype,
|
||||
)
|
||||
|
||||
@ -157,7 +157,9 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
if isinstance(layer, Attention):
|
||||
return CompressedTensorsKVCacheMethod(self)
|
||||
if isinstance(layer, FusedMoE):
|
||||
return CompressedTensorsMoEMethod.get_moe_method(self, layer, prefix)
|
||||
return CompressedTensorsMoEMethod.get_moe_method(
|
||||
self, layer, layer_name=prefix
|
||||
)
|
||||
return None
|
||||
|
||||
def _add_fused_moe_to_target_scheme_map(self):
|
||||
@ -547,6 +549,7 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
weight_quant: QuantizationArgs,
|
||||
input_quant: QuantizationArgs,
|
||||
format: str | None = None,
|
||||
layer_name: str | None = None,
|
||||
) -> "CompressedTensorsScheme":
|
||||
# use the per-layer format if defined, otherwise, use global format
|
||||
format = format if format is not None else self.quant_format
|
||||
@ -585,6 +588,7 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
symmetric=weight_quant.symmetric,
|
||||
group_size=weight_quant.group_size,
|
||||
actorder=weight_quant.actorder,
|
||||
layer_name=layer_name,
|
||||
)
|
||||
|
||||
act_quant_format = is_activation_quantization_format(format)
|
||||
@ -724,7 +728,10 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
else:
|
||||
# Find the quant_scheme
|
||||
scheme = self._get_scheme_from_parts( # type: ignore
|
||||
weight_quant=weight_quant, input_quant=input_quant, format=format
|
||||
weight_quant=weight_quant,
|
||||
input_quant=input_quant,
|
||||
format=format,
|
||||
layer_name=layer_name,
|
||||
)
|
||||
|
||||
# Raise error if device does not support the scheme
|
||||
|
||||
@ -64,6 +64,8 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
check_moe_marlin_supports_layer,
|
||||
get_marlin_input_dtype,
|
||||
marlin_act_int8_process_scales,
|
||||
marlin_make_workspace_new,
|
||||
marlin_moe_permute_scales,
|
||||
)
|
||||
@ -101,7 +103,7 @@ __all__ = [
|
||||
"CompressedTensorsW8A8Int8MoEMethod",
|
||||
"CompressedTensorsWNA16MarlinMoEMethod",
|
||||
"CompressedTensorsWNA16MoEMethod",
|
||||
"CompressedTensorsW4A4Nvfp4MoeMethod",
|
||||
"CompressedTensorsW4A4Nvfp4MoEMethod",
|
||||
"CompressedTensorsW4A8Int8MoEMethod",
|
||||
]
|
||||
|
||||
@ -111,13 +113,13 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
|
||||
def get_moe_method(
|
||||
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
|
||||
layer: torch.nn.Module,
|
||||
prefix: str,
|
||||
layer_name: str,
|
||||
) -> "CompressedTensorsMoEMethod":
|
||||
# FusedMoE was made by combining multiple Linears so need to
|
||||
# make sure quantization config for Linear can target it
|
||||
quant_config._add_fused_moe_to_target_scheme_map()
|
||||
unfused_names = [
|
||||
prefix + proj_name
|
||||
layer_name + proj_name
|
||||
for proj_name in [".0.gate_proj", ".0.up_proj", ".0.down_proj"]
|
||||
]
|
||||
# TODO: refactor this to use expert_mapping and check all layer numbers
|
||||
@ -158,32 +160,40 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
|
||||
"WNA16MoE is not supported with actorder=group/dynamic."
|
||||
)
|
||||
logger.info_once("Using CompressedTensorsWNA16MoEMethod")
|
||||
return CompressedTensorsWNA16MoEMethod(quant_config, layer.moe_config)
|
||||
return CompressedTensorsWNA16MoEMethod(
|
||||
quant_config, layer.moe_config, layer_name
|
||||
)
|
||||
else:
|
||||
logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod")
|
||||
return CompressedTensorsWNA16MarlinMoEMethod(
|
||||
quant_config, layer.moe_config
|
||||
quant_config, layer.moe_config, layer_name
|
||||
)
|
||||
elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant):
|
||||
return CompressedTensorsW4A4Nvfp4MoeMethod(layer.moe_config)
|
||||
return CompressedTensorsW4A4Nvfp4MoEMethod(layer.moe_config, layer_name)
|
||||
elif (
|
||||
quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant)
|
||||
or quant_config._is_fp8_w8a8_sm100(weight_quant, input_quant)
|
||||
or quant_config._is_fp8_w8a8(weight_quant, input_quant)
|
||||
):
|
||||
return CompressedTensorsW8A8Fp8MoEMethod(quant_config, layer.moe_config)
|
||||
return CompressedTensorsW8A8Fp8MoEMethod(
|
||||
quant_config, layer.moe_config, layer_name
|
||||
)
|
||||
elif quant_config._is_dynamic_token_w8a8(weight_quant, input_quant):
|
||||
return CompressedTensorsW8A8Int8MoEMethod(quant_config, layer.moe_config)
|
||||
return CompressedTensorsW8A8Int8MoEMethod(
|
||||
quant_config, layer.moe_config, layer_name
|
||||
)
|
||||
elif quant_config._is_dynamic_token_w4a8_int(weight_quant, input_quant):
|
||||
return CompressedTensorsW4A8Int8MoEMethod(quant_config, layer.moe_config)
|
||||
return CompressedTensorsW4A8Int8MoEMethod(
|
||||
quant_config, layer.moe_config, layer_name
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}"
|
||||
)
|
||||
|
||||
|
||||
class CompressedTensorsW4A4Nvfp4MoeMethod(CompressedTensorsMoEMethod):
|
||||
def __init__(self, moe: FusedMoEConfig):
|
||||
class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
|
||||
def __init__(self, moe: FusedMoEConfig, layer_name: str | None = None):
|
||||
from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import ( # noqa: E501
|
||||
detect_nvfp4_moe_support,
|
||||
)
|
||||
@ -194,17 +204,21 @@ class CompressedTensorsW4A4Nvfp4MoeMethod(CompressedTensorsMoEMethod):
|
||||
self.allow_flashinfer = _nvfp4.allow_flashinfer
|
||||
self.use_marlin = _nvfp4.use_marlin
|
||||
self.group_size = 16
|
||||
self.layer_name = layer_name
|
||||
self.marlin_input_dtype = (
|
||||
get_marlin_input_dtype(layer_name) if self.use_marlin else None
|
||||
)
|
||||
self.flashinfer_moe_backend = None
|
||||
if self.allow_flashinfer:
|
||||
self.flashinfer_moe_backend = get_flashinfer_moe_backend()
|
||||
logger.info_once(
|
||||
f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
|
||||
" for CompressedTensorsW4A4Nvfp4MoeMethod."
|
||||
" for CompressedTensorsW4A4Nvfp4MoEMethod."
|
||||
)
|
||||
elif self.use_marlin:
|
||||
logger.info_once("Using Marlin for CompressedTensorsW4A4Nvfp4MoeMethod.")
|
||||
logger.info_once("Using Marlin for CompressedTensorsW4A4Nvfp4MoEMethod.")
|
||||
else:
|
||||
logger.info_once("Using Cutlass for CompressedTensorsW4A4Nvfp4MoeMethod.")
|
||||
logger.info_once("Using Cutlass for CompressedTensorsW4A4Nvfp4MoEMethod.")
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
@ -354,7 +368,7 @@ class CompressedTensorsW4A4Nvfp4MoeMethod(CompressedTensorsMoEMethod):
|
||||
)
|
||||
|
||||
if self.use_marlin:
|
||||
prepare_moe_fp4_layer_for_marlin(layer)
|
||||
prepare_moe_fp4_layer_for_marlin(layer, input_dtype=self.marlin_input_dtype)
|
||||
return
|
||||
# w13
|
||||
if (
|
||||
@ -538,7 +552,7 @@ class CompressedTensorsW4A4Nvfp4MoeMethod(CompressedTensorsMoEMethod):
|
||||
):
|
||||
if enable_eplb:
|
||||
raise NotImplementedError(
|
||||
"EPLB not supported for `CompressedTensorsW4A4MoeMethod` yet."
|
||||
"EPLB not supported for `CompressedTensorsW4A4MoEMethod` yet."
|
||||
)
|
||||
|
||||
return flashinfer_trtllm_fp4_moe(
|
||||
@ -576,6 +590,7 @@ class CompressedTensorsW4A4Nvfp4MoeMethod(CompressedTensorsMoEMethod):
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
input_dtype=self.marlin_input_dtype,
|
||||
workspace=layer.workspace,
|
||||
)
|
||||
|
||||
@ -610,7 +625,7 @@ class CompressedTensorsW4A4Nvfp4MoeMethod(CompressedTensorsMoEMethod):
|
||||
assert expert_map is None, (
|
||||
"Expert Parallelism / expert_map "
|
||||
"is currently not supported for "
|
||||
"CompressedTensorsW4A4Nvfp4MoeMethod."
|
||||
"CompressedTensorsW4A4Nvfp4MoEMethod."
|
||||
)
|
||||
assert self.moe_quant_config is not None
|
||||
|
||||
@ -637,6 +652,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
self,
|
||||
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
|
||||
moe: FusedMoEConfig,
|
||||
layer_name: str | None = None,
|
||||
):
|
||||
super().__init__(moe)
|
||||
self.quant_config = quant_config
|
||||
@ -690,6 +706,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
or self.is_fp8_w8a8_sm100
|
||||
)
|
||||
self.disable_expert_map = False
|
||||
self.layer_name = layer_name
|
||||
self.marlin_input_dtype = (
|
||||
get_marlin_input_dtype(layer_name) if self.use_marlin else None
|
||||
)
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
@ -931,7 +951,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)
|
||||
|
||||
elif self.use_marlin:
|
||||
prepare_moe_fp8_layer_for_marlin(layer, False)
|
||||
prepare_moe_fp8_layer_for_marlin(
|
||||
layer, False, input_dtype=self.marlin_input_dtype
|
||||
)
|
||||
# Activations not quantized for marlin.
|
||||
del layer.w13_input_scale
|
||||
del layer.w2_input_scale
|
||||
@ -1144,6 +1166,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
input_dtype=self.marlin_input_dtype,
|
||||
workspace=layer.workspace,
|
||||
)
|
||||
|
||||
@ -1240,6 +1263,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
|
||||
self,
|
||||
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
|
||||
moe: FusedMoEConfig,
|
||||
layer_name: str | None = None,
|
||||
):
|
||||
super().__init__(moe)
|
||||
self.quant_config = quant_config
|
||||
@ -1392,6 +1416,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
|
||||
self,
|
||||
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
|
||||
moe: FusedMoEConfig,
|
||||
layer_name: str | None = None,
|
||||
):
|
||||
super().__init__(moe)
|
||||
self.quant_config = quant_config
|
||||
@ -1403,6 +1428,8 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
|
||||
self.strategy = config.strategy
|
||||
self.group_size = config.group_size
|
||||
self.actorder = config.actorder
|
||||
self.layer_name = layer_name
|
||||
self.marlin_input_dtype = get_marlin_input_dtype(layer_name)
|
||||
assert config.symmetric, "Only symmetric quantization is supported for MoE"
|
||||
|
||||
if not (
|
||||
@ -1477,6 +1504,9 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
|
||||
num_groups_w2 = w2_scales_size // self.group_size
|
||||
num_groups_w13 = hidden_size // self.group_size
|
||||
|
||||
layer.num_groups_w13 = num_groups_w13
|
||||
layer.num_groups_w2 = num_groups_w2
|
||||
|
||||
w13_scale = torch.nn.Parameter(
|
||||
torch.ones(
|
||||
num_experts,
|
||||
@ -1560,6 +1590,17 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
num_experts = layer.w13_weight_g_idx.shape[0]
|
||||
device = layer.w13_weight_g_idx.device
|
||||
is_a_8bit = (
|
||||
self.marlin_input_dtype is not None
|
||||
and self.marlin_input_dtype.itemsize == 1
|
||||
)
|
||||
|
||||
if self.marlin_input_dtype == torch.float8_e4m3fn:
|
||||
# NOTE: for non-zp quantization format only
|
||||
ops.marlin_int4_fp8_preprocess(layer.w13_weight_packed, inplace=True)
|
||||
ops.marlin_int4_fp8_preprocess(layer.w2_weight_packed, inplace=True)
|
||||
layer.w13_weight_scale.data = layer.w13_weight_scale.data * 512
|
||||
layer.w2_weight_scale.data = layer.w2_weight_scale.data * 512
|
||||
|
||||
# when running models with grouped act order,
|
||||
# resort to g_idx values provided in checkpoint
|
||||
@ -1610,31 +1651,54 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
|
||||
layer.w13_weight_packed.shape[1] * self.packed_factor,
|
||||
layer.w13_weight_packed.shape[2],
|
||||
self.num_bits,
|
||||
is_a_8bit=is_a_8bit,
|
||||
)
|
||||
replace_parameter(layer, "w13_weight_packed", marlin_w13_qweight)
|
||||
|
||||
marlin_w2_qweight = ops.gptq_marlin_moe_repack(
|
||||
layer.w2_weight_packed,
|
||||
layer.w2_g_idx_sort_indices,
|
||||
layer.w2_weight_packed.shape[1] * self.packed_factor,
|
||||
layer.w2_weight_packed.shape[2],
|
||||
self.num_bits,
|
||||
is_a_8bit=is_a_8bit,
|
||||
)
|
||||
replace_parameter(layer, "w2_weight_packed", marlin_w2_qweight)
|
||||
|
||||
# Repack scales
|
||||
marlin_w13_scales = marlin_moe_permute_scales(
|
||||
s=layer.w13_weight_scale,
|
||||
size_k=layer.w13_weight_packed.shape[2],
|
||||
size_n=layer.w13_weight_scale.shape[2],
|
||||
group_size=self.group_size,
|
||||
is_a_8bit=is_a_8bit,
|
||||
)
|
||||
if self.marlin_input_dtype == torch.int8 and layer.num_groups_w13 > 1:
|
||||
marlin_w13_scales, w13_input_global_scale = marlin_act_int8_process_scales(
|
||||
marlin_w13_scales
|
||||
)
|
||||
layer.register_parameter(
|
||||
"w13_input_global_scale",
|
||||
torch.nn.Parameter(w13_input_global_scale, requires_grad=False),
|
||||
)
|
||||
replace_parameter(layer, "w13_weight_scale", marlin_w13_scales)
|
||||
|
||||
marlin_w2_scales = marlin_moe_permute_scales(
|
||||
s=layer.w2_weight_scale,
|
||||
size_k=layer.w2_weight_scale.shape[1]
|
||||
* (self.group_size if self.group_size != -1 else self.packed_factor),
|
||||
size_n=layer.w2_weight_scale.shape[2],
|
||||
group_size=self.group_size,
|
||||
is_a_8bit=is_a_8bit,
|
||||
)
|
||||
if self.marlin_input_dtype == torch.int8 and layer.num_groups_w2 > 1:
|
||||
marlin_w2_scales, w2_input_global_scale = marlin_act_int8_process_scales(
|
||||
marlin_w2_scales
|
||||
)
|
||||
layer.register_parameter(
|
||||
"w2_input_global_scale",
|
||||
torch.nn.Parameter(w2_input_global_scale, requires_grad=False),
|
||||
)
|
||||
replace_parameter(layer, "w2_weight_scale", marlin_w2_scales)
|
||||
|
||||
layer.workspace = marlin_make_workspace_new(device, 4)
|
||||
@ -1729,6 +1793,8 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
|
||||
router_logits,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
input_global_scale1=getattr(layer, "w13_input_global_scale", None),
|
||||
input_global_scale2=getattr(layer, "w2_input_global_scale", None),
|
||||
quant_type_id=self.quant_type.id,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
@ -1738,6 +1804,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
|
||||
sort_indices1=layer.w13_g_idx_sort_indices,
|
||||
sort_indices2=layer.w2_g_idx_sort_indices,
|
||||
workspace=layer.workspace,
|
||||
input_dtype=self.marlin_input_dtype,
|
||||
is_k_full=self.is_k_full,
|
||||
)
|
||||
|
||||
@ -1747,6 +1814,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
||||
self,
|
||||
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
|
||||
moe: FusedMoEConfig,
|
||||
layer_name: str | None = None,
|
||||
):
|
||||
super().__init__(moe)
|
||||
self.quant_config = quant_config
|
||||
@ -1999,6 +2067,7 @@ class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod):
|
||||
self,
|
||||
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
|
||||
moe: FusedMoEConfig,
|
||||
layer_name: str | None = None,
|
||||
):
|
||||
super().__init__(moe)
|
||||
self.has_bias = self.moe.has_bias
|
||||
|
||||
@ -14,7 +14,11 @@ from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
|
||||
MPLinearLayerConfig,
|
||||
choose_mp_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision.marlin import (
|
||||
MarlinLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
get_marlin_input_dtype,
|
||||
marlin_repeat_scales_on_all_ranks,
|
||||
)
|
||||
from vllm.model_executor.parameter import (
|
||||
@ -45,12 +49,14 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
|
||||
group_size: int | None = None,
|
||||
symmetric: bool | None = True,
|
||||
actorder: ActivationOrdering | None = None,
|
||||
layer_name: str | None = None,
|
||||
):
|
||||
self.pack_factor = 32 // num_bits
|
||||
self.strategy = strategy
|
||||
self.symmetric = symmetric
|
||||
self.group_size = -1 if group_size is None else group_size
|
||||
self.has_g_idx = actorder == ActivationOrdering.GROUP
|
||||
self.layer_name = layer_name
|
||||
|
||||
if self.group_size == -1 and self.strategy != "channel":
|
||||
raise ValueError(
|
||||
@ -108,6 +114,11 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
|
||||
logger.info("Using %s for CompressedTensorsWNA16", kernel_type.__name__)
|
||||
self._kernel_backends_being_used.add(kernel_type.__name__)
|
||||
|
||||
if isinstance(kernel_type, MarlinLinearKernel):
|
||||
input_dtype = get_marlin_input_dtype(self.layer_name)
|
||||
if input_dtype is not None:
|
||||
mp_linear_kernel_config.act_type = input_dtype
|
||||
|
||||
# If group_size is -1, we are in channelwise case.
|
||||
group_size = self.group_size if self.group_size != -1 else input_size
|
||||
row_parallel = input_size != input_size_per_partition
|
||||
|
||||
@ -69,6 +69,9 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
process_fp8_weight_tensor_strategy,
|
||||
validate_fp8_block_shape,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
get_marlin_input_dtype,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
apply_fp8_marlin_linear,
|
||||
prepare_fp8_layer_for_marlin,
|
||||
@ -316,7 +319,9 @@ class Fp8Config(QuantizationConfig):
|
||||
fused_mapping=self.packed_modules_mapping,
|
||||
):
|
||||
return UnquantizedLinearMethod()
|
||||
return Fp8LinearMethod(self)
|
||||
quant_method = Fp8LinearMethod(self)
|
||||
quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
|
||||
return quant_method
|
||||
elif isinstance(layer, FusedMoE):
|
||||
if is_layer_skipped(
|
||||
prefix=prefix,
|
||||
@ -324,7 +329,9 @@ class Fp8Config(QuantizationConfig):
|
||||
fused_mapping=self.packed_modules_mapping,
|
||||
):
|
||||
return UnquantizedFusedMoEMethod(layer.moe_config)
|
||||
return Fp8MoEMethod(self, layer)
|
||||
moe_quant_method = Fp8MoEMethod(self, layer)
|
||||
moe_quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
|
||||
return moe_quant_method
|
||||
elif isinstance(layer, Attention):
|
||||
return Fp8KVCacheMethod(self)
|
||||
return None
|
||||
@ -375,6 +382,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
|
||||
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
|
||||
# kernel for fast weight-only FP8 quantization
|
||||
self.marlin_input_dtype = None
|
||||
self.use_marlin = (
|
||||
not current_platform.has_device_capability(89)
|
||||
or envs.VLLM_TEST_FORCE_FP8_MARLIN
|
||||
@ -552,7 +560,9 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
)
|
||||
|
||||
if self.use_marlin:
|
||||
prepare_fp8_layer_for_marlin(layer, size_k_first)
|
||||
prepare_fp8_layer_for_marlin(
|
||||
layer, size_k_first, input_dtype=self.marlin_input_dtype
|
||||
)
|
||||
# Activations not quantized for marlin.
|
||||
del layer.input_scale
|
||||
return
|
||||
@ -610,6 +620,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
workspace=layer.workspace,
|
||||
size_n=layer.output_size_per_partition,
|
||||
size_k=layer.input_size_per_partition,
|
||||
input_dtype=self.marlin_input_dtype,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
@ -657,6 +668,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
self.block_quant, layer.moe_parallel_config
|
||||
)
|
||||
|
||||
self.marlin_input_dtype = None
|
||||
self.use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN
|
||||
self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
|
||||
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
|
||||
@ -1031,7 +1043,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
layer.w13_weight.data = w13_weight.data
|
||||
|
||||
if self.use_marlin:
|
||||
prepare_moe_fp8_layer_for_marlin(layer, False)
|
||||
prepare_moe_fp8_layer_for_marlin(
|
||||
layer, False, input_dtype=self.marlin_input_dtype
|
||||
)
|
||||
# Activations not quantized for marlin.
|
||||
del layer.w13_input_scale
|
||||
del layer.w2_input_scale
|
||||
@ -1270,6 +1284,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
input_dtype=self.marlin_input_dtype,
|
||||
workspace=layer.workspace,
|
||||
)
|
||||
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
|
||||
|
||||
@ -41,6 +41,8 @@ from vllm.model_executor.layers.quantization.utils.gptq_utils import (
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
check_marlin_supported,
|
||||
check_moe_marlin_supports_layer,
|
||||
get_marlin_input_dtype,
|
||||
marlin_act_int8_process_scales,
|
||||
marlin_make_workspace_new,
|
||||
marlin_moe_permute_scales,
|
||||
marlin_permute_bias,
|
||||
@ -251,8 +253,21 @@ class GPTQMarlinConfig(QuantizationConfig):
|
||||
return MoeWNA16Config.from_config(self.full_config).get_quant_method(
|
||||
layer, prefix
|
||||
)
|
||||
return get_moe_quant_method(self, layer, prefix, GPTQMarlinMoEMethod)
|
||||
return get_linear_quant_method(self, layer, prefix, GPTQMarlinLinearMethod)
|
||||
moe_quant_method = get_moe_quant_method(
self, layer, prefix, GPTQMarlinMoEMethod
)
if moe_quant_method is None:
return None
moe_quant_method.input_dtype = get_marlin_input_dtype(prefix)
return moe_quant_method

quant_method = get_linear_quant_method(
self, layer, prefix, GPTQMarlinLinearMethod
)
if quant_method is None:
return None
quant_method.input_dtype = get_marlin_input_dtype(prefix)
return quant_method

@classmethod
def is_gptq_marlin_compatible(cls, quant_config: dict[str, Any]):

@ -319,6 +334,8 @@ class GPTQMarlinLinearMethod(LinearMethodBase):

def __init__(self, quant_config: GPTQMarlinConfig) -> None:
self.quant_config = quant_config
self.input_dtype = None
self.quant_type = self.quant_config.quant_type

# Verify supported on platform.
verify_marlin_supported(

@ -339,6 +356,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
output_size_per_partition = sum(output_partition_sizes)
is_row_parallel = input_size != input_size_per_partition
weight_loader = extra_weight_attrs.get("weight_loader")
input_dtype = self.input_dtype

mp_linear_kernel_config = MPLinearLayerConfig(
full_weight_shape=(input_size, output_size),

@ -347,7 +365,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
output_size_per_partition,
),
weight_type=self.quant_config.quant_type,
act_type=params_dtype,
act_type=params_dtype if input_dtype is None else input_dtype,
group_size=self.quant_config.group_size,
zero_points=False,
has_g_idx=self.quant_config.desc_act,

@ -482,6 +500,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
self.quant_type = scalar_types.uint8b128
else:
raise ValueError("GPTQMarlinMoEMethod only supports int4 and int8 now.")
self.input_dtype = None
self.use_marlin = True

def create_weights(

@ -493,6 +512,14 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
params_dtype: torch.dtype,
**extra_weight_attrs,
):
layer.input_dtype = self.input_dtype
is_a_8bit = self.input_dtype is not None and self.input_dtype.itemsize == 1

if is_a_8bit:
assert self.quant_type == scalar_types.uint4b8, (
"W8A8-INT8 is not supported by marlin kernel."
)

intermediate_size_full = extra_weight_attrs.pop("intermediate_size_full")

self.is_k_full = (not self.quant_config.desc_act) or (

@ -513,6 +540,9 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
scales_size2 = 1
strategy = FusedMoeWeightScaleSupported.CHANNEL.value

layer.num_groups_w13 = scales_size13
layer.num_groups_w2 = scales_size2

extra_weight_attrs.update({"quant_method": strategy, "is_transposed": True})
# Fused gate_up_proj (column parallel)
w13_qweight = torch.nn.Parameter(

@ -630,6 +660,19 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
layer.workspace = marlin_make_workspace_new(device, 4)

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
is_a_8bit = self.input_dtype is not None and self.input_dtype.itemsize == 1

if is_a_8bit:
assert self.quant_type == scalar_types.uint4b8, (
"W8A8-INT8 is not supported by marlin kernel."
)

if self.input_dtype == torch.float8_e4m3fn:
ops.marlin_int4_fp8_preprocess(layer.w13_qweight, inplace=True)
ops.marlin_int4_fp8_preprocess(layer.w2_qweight, inplace=True)
layer.w13_scales.data = layer.w13_scales.data * 512
layer.w2_scales.data = layer.w2_scales.data * 512

# Process act_order
if self.quant_config.desc_act:
# Get sorting based on g_idx

@ -678,6 +721,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
layer.w13_qweight.shape[1] * self.quant_config.pack_factor,
layer.w13_qweight.shape[2],
self.quant_config.quant_type.size_bits,
is_a_8bit=is_a_8bit,
)
replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
marlin_w2_qweight = ops.gptq_marlin_moe_repack(

@ -686,6 +730,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
layer.w2_qweight.shape[1] * self.quant_config.pack_factor,
layer.w2_qweight.shape[2],
self.quant_config.quant_type.size_bits,
is_a_8bit=is_a_8bit,
)
replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
# Repack scales

@ -694,7 +739,17 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
size_k=layer.intermediate_size_per_partition,
size_n=layer.w13_scales.shape[2],
group_size=self.quant_config.group_size,
is_a_8bit=is_a_8bit,
)
if self.input_dtype == torch.int8 and layer.num_groups_w13 > 1:
marlin_w13_scales, w13_input_global_scale = marlin_act_int8_process_scales(
marlin_w13_scales
)
layer.register_parameter(
"w13_input_global_scale",
torch.nn.Parameter(w13_input_global_scale, requires_grad=False),
)

replace_parameter(layer, "w13_scales", marlin_w13_scales)
marlin_w2_scales = marlin_moe_permute_scales(
s=layer.w2_scales,

@ -706,7 +761,17 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
),
size_n=layer.w2_scales.shape[2],
group_size=self.quant_config.group_size,
is_a_8bit=is_a_8bit,
)
if self.input_dtype == torch.int8 and layer.num_groups_w2 > 1:
marlin_w2_scales, w2_input_global_scale = marlin_act_int8_process_scales(
marlin_w2_scales
)
layer.register_parameter(
"w2_input_global_scale",
torch.nn.Parameter(w2_input_global_scale, requires_grad=False),
)

replace_parameter(layer, "w2_scales", marlin_w2_scales)

if hasattr(layer, "w13_bias") and layer.w13_bias is not None:

@ -761,6 +826,8 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
router_logits,
topk_weights,
topk_ids,
input_global_scale1=getattr(layer, "w13_input_global_scale", None),
input_global_scale2=getattr(layer, "w2_input_global_scale", None),
quant_type_id=self.quant_type.id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,

@ -771,4 +838,5 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
sort_indices2=layer.w2_g_idx_sort_indices,
workspace=layer.workspace,
is_k_full=self.is_k_full,
input_dtype=self.input_dtype,
)

@ -351,6 +351,7 @@ class HQQMarlinMethod(LinearMethodBase):
bias,
scales,
None,
None,
zeros,
layer.g_idx,
layer.g_idx_sort_indices,
@ -9,6 +9,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
MARLIN_SUPPORTED_GROUP_SIZES,
apply_gptq_marlin_linear,
check_marlin_supports_shape,
marlin_act_int8_process_scales,
marlin_is_k_full,
marlin_make_empty_g_idx,
marlin_make_workspace_new,

@ -21,6 +22,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
)
from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types

from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig

@ -65,6 +67,18 @@ class MarlinLinearKernel(MPLinearKernel):
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
device = getattr(layer, self.w_q_name).device
c = self.config
is_a_8bit = c.act_type is not None and c.act_type.itemsize == 1

if is_a_8bit:
assert c.weight_type == scalar_types.uint4b8, (
"W8A8 is not supported by marlin kernel."
)

if c.act_type == torch.float8_e4m3fn:
ops.marlin_int4_fp8_preprocess(getattr(layer, self.w_q_name), inplace=True)
getattr(layer, self.w_s_name).data = (
getattr(layer, self.w_s_name).data * 512
)

row_parallel = c.partition_weight_shape[0] != c.full_weight_shape[0]
self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel)

@ -88,6 +102,7 @@ class MarlinLinearKernel(MPLinearKernel):
size_k=c.partition_weight_shape[0],
size_n=c.partition_weight_shape[1],
num_bits=c.weight_type.size_bits,
is_a_8bit=is_a_8bit,
)
return x

@ -99,7 +114,22 @@ class MarlinLinearKernel(MPLinearKernel):
size_k=c.partition_weight_shape[0],
size_n=c.partition_weight_shape[1],
group_size=c.group_size,
is_a_8bit=is_a_8bit,
)

if c.group_size == -1:
num_groups = 1
else:
num_groups = c.partition_weight_shape[0] // c.group_size

if c.act_type == torch.int8 and num_groups > 1:
x.data, input_global_scale = marlin_act_int8_process_scales(x.data)
layer.register_parameter(
"input_global_scale",
torch.nn.Parameter(input_global_scale, requires_grad=False),
)
else:
layer.input_global_scale = None
return x

if c.has_g_idx:
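The hunk above derives the number of scale groups along K before deciding whether the INT8 activation path needs an extra per-layer input_global_scale. A minimal sketch of that group-count rule (the helper name below is hypothetical, not part of this diff):

def num_scale_groups(k_per_partition: int, group_size: int) -> int:
    # group_size == -1 means one group spanning the whole K dimension,
    # i.e. a single scale per output channel.
    return 1 if group_size == -1 else k_per_partition // group_size

print(num_scale_groups(4096, 128))  # 32 groups -> int8 activations register input_global_scale
print(num_scale_groups(4096, -1))   # 1 group  -> layer.input_global_scale stays None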
@ -129,6 +159,7 @@ class MarlinLinearKernel(MPLinearKernel):
size_k=grouped_k,
size_n=c.partition_weight_shape[1],
num_bits=c.weight_type.size_bits,
is_a_8bit=is_a_8bit,
),
)
else:

@ -150,6 +181,7 @@ class MarlinLinearKernel(MPLinearKernel):

# `process_weights_after_loading` will ensure w_zp and w_gidx are not
# None for marlin

return apply_gptq_marlin_linear(
input=x,
weight=w_q,

@ -162,5 +194,7 @@ class MarlinLinearKernel(MPLinearKernel):
input_size_per_partition=c.partition_weight_shape[0],
output_size_per_partition=c.partition_weight_shape[1],
is_k_full=self.is_k_full,
input_global_scale=getattr(layer, "input_global_scale", None),
bias=bias,
input_dtype=c.act_type,
)

@ -55,6 +55,9 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
select_cutlass_fp8_gemm_impl,
swap_w13_to_w31,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
get_marlin_input_dtype,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
apply_fp4_marlin_linear,
is_fp4_marlin_supported,

@ -170,9 +173,15 @@ class ModelOptQuantConfigBase(QuantizationConfig):

# now, the layer is quantized, handle it here
if isinstance(layer, LinearBase):
return self.LinearMethodCls(self)
quant_method = self.LinearMethodCls(self)
if getattr(quant_method, "backend", "") == "marlin":
quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
return quant_method
elif isinstance(layer, FusedMoE):
return self.FusedMoEMethodCls(quant_config=self, layer=layer)
quant_method = self.FusedMoEMethodCls(quant_config=self, layer=layer)
if getattr(quant_method, "backend", "") == "marlin":
quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
return quant_method

return None

@ -898,6 +907,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):

def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
self.quant_config = quant_config
self.marlin_input_dtype = None

self.backend = "none"
if envs.VLLM_NVFP4_GEMM_BACKEND is None:

@ -1065,6 +1075,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
size_n=layer.output_size_per_partition,
size_k=layer.input_size_per_partition,
bias=bias,
input_dtype=self.marlin_input_dtype,
)

output_dtype = x.dtype

@ -1124,6 +1135,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported
self.allow_flashinfer = _nvfp4.allow_flashinfer
self.use_marlin = _nvfp4.use_marlin
self.marlin_input_dtype = None
self.flashinfer_moe_backend = None
if self.allow_flashinfer:
self.flashinfer_moe_backend = get_flashinfer_moe_backend()

@ -1517,7 +1529,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
workspace=layer.workspace,
input_dtype=self.marlin_input_dtype,
)

elif self.allow_flashinfer:

@ -38,6 +38,9 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
get_marlin_input_dtype,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
prepare_moe_fp4_layer_for_marlin,
)

@ -205,7 +208,9 @@ class Mxfp4Config(QuantizationConfig):
if current_platform.is_xpu():
return IpexMxfp4MoEMethod(layer.moe_config)
else:
return Mxfp4MoEMethod(layer.moe_config)
quant_method = Mxfp4MoEMethod(layer.moe_config)
quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
return quant_method
elif isinstance(layer, Attention):
# TODO: Add support for MXFP4 Attention.
logger.debug_once(

@ -220,6 +225,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
def __init__(self, moe: FusedMoEConfig):
super().__init__(moe)
self.mxfp4_backend = get_mxfp4_backend(moe.is_lora_enabled)

self.marlin_input_dtype = None
self.use_marlin = self.mxfp4_backend == Mxfp4Backend.MARLIN
self.max_capture_size = (
get_current_vllm_config().compilation_config.max_cudagraph_capture_size

@ -385,7 +392,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):

def process_weights_after_loading(self, layer):
if self.mxfp4_backend == Mxfp4Backend.MARLIN:
prepare_moe_fp4_layer_for_marlin(layer)
prepare_moe_fp4_layer_for_marlin(layer, input_dtype=self.marlin_input_dtype)
elif (
self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16

@ -914,6 +921,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
global_num_experts=global_num_experts,
activation=activation,
expert_map=expert_map,
input_dtype=self.marlin_input_dtype,
)

assert _can_support_mxfp4(

@ -9,6 +9,11 @@ import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.int8_utils import (
per_token_quant_int8,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types

@ -286,10 +291,10 @@ def get_scale_perms():

def marlin_permute_scales(
s: torch.Tensor, size_k: int, size_n: int, group_size: int
s: torch.Tensor, size_k: int, size_n: int, group_size: int, is_a_8bit: bool = False
) -> torch.Tensor:
scale_perm, scale_perm_single = get_scale_perms()
if group_size < size_k and group_size != -1:
if group_size < size_k and group_size != -1 and not is_a_8bit:
s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
else:
s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]

@ -305,11 +310,15 @@ def marlin_permute_bias(s: torch.Tensor) -> torch.Tensor:
return s.reshape(*origin_shape).contiguous()

def marlin_act_int8_process_scales(s: torch.Tensor):
a_scales_scale_factor = 1 / 4096 * s.max().float()
s = s / s.max() * 4096
s = s.round().to(torch.int16).view(s.dtype)
return s, a_scales_scale_factor
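marlin_act_int8_process_scales above rescales the per-group weight scales into small integers and returns a single global factor that restores their magnitude. A standalone sketch of that round trip (the values are made up; the real function additionally bit-views the int16 result back into the original scale dtype):

import torch

s = torch.tensor([0.011, 0.004, 0.021, 0.007], dtype=torch.half)  # example group scales

global_factor = 1 / 4096 * s.max().float()        # returned as the "input global scale"
q = (s / s.max() * 4096).round().to(torch.int16)  # integer-valued scales, max value 4096

recovered = q.float() * global_factor             # approximately the original scales
print(torch.allclose(recovered, s.float(), rtol=1e-2))  # True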
def marlin_moe_permute_scales(
s: torch.Tensor,
size_k: int,
size_n: int,
group_size: int,
s: torch.Tensor, size_k: int, size_n: int, group_size: int, is_a_8bit: bool = False
):
num_experts = s.shape[0]
output = torch.empty(

@ -319,12 +328,12 @@ def marlin_moe_permute_scales(
)

for e in range(num_experts):
output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size)
output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size, is_a_8bit)
return output

def marlin_zero_points(
zp: torch.Tensor, size_k: int, size_n: int, num_bits: int
zp: torch.Tensor, size_k: int, size_n: int, num_bits: int, is_a_8bit: bool = False
) -> torch.Tensor:
# Permute zero-points in a similar way to scales, but do not use the
# "single" permutation, since zero-points are applied on every MMA

@ -339,7 +348,8 @@ def marlin_zero_points(
else:
raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))

zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel()
if not is_a_8bit:
zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel()
zp = zp.reshape((-1, size_n)).contiguous()
zp = pack_cols(zp, num_bits, size_k, size_n)

@ -347,7 +357,11 @@

def awq_to_marlin_zero_points(
q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
q_zp_packed: torch.Tensor,
size_k: int,
size_n: int,
num_bits: int,
is_a_8bit: bool = False,
) -> torch.Tensor:
# AWQ zero-points are quantized and packed on the column dim.
# In addition, the values are permuted based on dequantizer.

@ -366,12 +380,16 @@ def awq_to_marlin_zero_points(
q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel()
q_zp = q_zp.reshape((-1, size_n)).contiguous()

marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits)
marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits, is_a_8bit)
return marlin_zp

def moe_awq_to_marlin_zero_points(
q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
q_zp_packed: torch.Tensor,
size_k: int,
size_n: int,
num_bits: int,
is_a_8bit: bool = False,
):
num_experts = q_zp_packed.shape[0]
output = torch.empty(

@ -380,7 +398,9 @@ def moe_awq_to_marlin_zero_points(
dtype=q_zp_packed.dtype,
)
for e in range(num_experts):
output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, num_bits)
output[e] = awq_to_marlin_zero_points(
q_zp_packed[e], size_k, size_n, num_bits, is_a_8bit
)
return output

@ -432,6 +452,48 @@ def should_use_atomic_add_reduce(
return True

_quant_fp8_method: QuantFP8 | None = None

def get__quant_fp8_method() -> QuantFP8:
global _quant_fp8_method
if _quant_fp8_method is None:
_quant_fp8_method = QuantFP8(False, GroupShape.PER_TOKEN)
return _quant_fp8_method

def get_marlin_input_dtype(prefix):
if envs.VLLM_MARLIN_INPUT_DTYPE is None:
return
elif envs.VLLM_MARLIN_INPUT_DTYPE.lower() == "int8":
return torch.int8
elif envs.VLLM_MARLIN_INPUT_DTYPE.lower() == "fp8":
if not current_platform.is_device_capability(
89
) and not current_platform.is_device_capability(120):
raise ValueError(
"Marlin W4A8-FP8 only support SM89 or SM120 device "
"(It is slower than Marlin W4A16 on other devices). "
"You can consider using W4A8-INT8 instead"
"(set VLLM_MARLIN_INPUT_DTYPE=int8)."
)

_ = get__quant_fp8_method()
return torch.float8_e4m3fn
else:
return

def marlin_quant_input(x: torch.Tensor, quant_dtype: torch.dtype):
x = x.reshape(-1, x.shape[-1])
if quant_dtype == torch.int8:
return per_token_quant_int8(x)
elif quant_dtype == torch.float8_e4m3fn:
return get__quant_fp8_method()(x)
else:
raise ValueError(f"unsupported quant_dtype {quant_dtype}")
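For reference, a condensed sketch of the dtype selection implemented by get_marlin_input_dtype above (this reads the environment variable directly rather than through vllm.envs and omits the device-capability check, so it is an illustration only):

import os
import torch

def pick_marlin_input_dtype() -> torch.dtype | None:
    value = os.environ.get("VLLM_MARLIN_INPUT_DTYPE")
    if value is None:
        return None                 # default: keep 16-bit activations (W4A16)
    if value.lower() == "int8":
        return torch.int8           # W4A8-INT8, per-token symmetric quantization
    if value.lower() == "fp8":
        return torch.float8_e4m3fn  # W4A8-FP8, intended for SM89 / SM120 GPUs
    return None

In practice a deployment would opt into the new path by setting VLLM_MARLIN_INPUT_DTYPE=fp8 (or int8) in the serving environment; marlin_quant_input then quantizes activations per token right before the Marlin GEMM.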
def apply_gptq_marlin_linear(
input: torch.Tensor,
weight: torch.Tensor,

@ -444,8 +506,10 @@ def apply_gptq_marlin_linear(
output_size_per_partition: int,
input_size_per_partition: int,
is_k_full: bool,
input_global_scale: torch.Tensor | None = None,
bias: torch.Tensor | None = None,
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
input_dtype: torch.dtype | None = None,
) -> torch.Tensor:
reshaped_x = input.reshape(-1, input.shape[-1])
out_shape = input.shape[:-1] + (output_size_per_partition,)

@ -458,12 +522,27 @@ def apply_gptq_marlin_linear(
dtype=input.dtype,
)

a_scales = None
if input_dtype == torch.int8:
assert wtype == scalar_types.uint4b8, (
"W8A8-INT8 is not supported by marlin kernel."
)
reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
a_scales = a_scales * input_global_scale
elif input_dtype == torch.float8_e4m3fn:
assert wtype == scalar_types.uint4b8, (
"INT8 weight + FP8 activation is not supported."
)

reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)

output = ops.gptq_marlin_gemm(
reshaped_x,
None,
weight,
bias,
weight_scale,
a_scales,
None,
weight_zp,
g_idx,

@ -493,8 +572,10 @@ def apply_awq_marlin_linear(
quant_type: ScalarType,
output_size_per_partition: int,
input_size_per_partition: int,
input_global_scale: torch.Tensor | None = None,
bias: torch.Tensor | None = None,
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
input_dtype: torch.dtype | None = None,
) -> torch.Tensor:
reshaped_x = input.reshape(-1, input.shape[-1])
out_shape = input.shape[:-1] + (output_size_per_partition,)

@ -507,12 +588,20 @@ def apply_awq_marlin_linear(
dtype=input.dtype,
)

a_scales = None
if input_dtype == torch.int8:
reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
a_scales = a_scales * input_global_scale
elif input_dtype == torch.float8_e4m3fn:
reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)

output = ops.gptq_marlin_gemm(
reshaped_x,
None,
weight,
bias,
weight_scale,
a_scales,
None,
weight_zp,
g_idx,

@ -538,8 +627,10 @@ def apply_rtn_marlin_linear(
quant_type: ScalarType,
output_size_per_partition: int,
input_size_per_partition: int,
input_global_scale: torch.Tensor | None = None,
bias: torch.Tensor | None = None,
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
input_dtype: torch.dtype | None = None,
) -> torch.Tensor:
reshaped_x = input.reshape(-1, input.shape[-1])
out_shape = input.shape[:-1] + (output_size_per_partition,)

@ -552,12 +643,20 @@ def apply_rtn_marlin_linear(
dtype=input.dtype,
)

a_scales = None
if input_dtype == torch.int8:
reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
a_scales = a_scales * input_global_scale
elif input_dtype == torch.float8_e4m3fn:
reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)

output = ops.gptq_marlin_gemm(
reshaped_x,
None,
weight,
bias,
weight_scale,
a_scales,
None,
None,
None,

@ -11,6 +11,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_make_workspace_new,
marlin_permute_bias,
marlin_permute_scales,
marlin_quant_input,
should_use_atomic_add_reduce,
)
from vllm.platforms import current_platform

@ -37,12 +38,6 @@ def nvfp4_marlin_process_scales(marlin_scales):
# convert to half first, we would convert to fp8 later
marlin_scales = marlin_scales.to(torch.half)

# 8 is the number of scale number using by one thread
marlin_scales = marlin_scales.view(marlin_scales.size(0) // 2, 2, -1, 8)
marlin_scales = marlin_scales.permute(0, 2, 1, 3).reshape(
marlin_scales.size(0) * 2, -1
)

# fit the layout of fp8 dequantization
marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view(
marlin_scales.size(0), -1

@ -62,18 +57,20 @@ def nvfp4_marlin_process_scales(marlin_scales):
return marlin_scales

def mxfp4_marlin_process_scales(marlin_scales):
# 8 is the number of scale number using by one thread
marlin_scales = marlin_scales.view(marlin_scales.size(0) // 2, 2, -1, 8)
marlin_scales = marlin_scales.permute(0, 2, 1, 3).reshape(
marlin_scales.size(0) * 2, -1
)

def mxfp4_marlin_process_scales(marlin_scales, input_dtype=None):
# fit the layout of fp8 dequantization
marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view(
marlin_scales.size(0), -1
)
if input_dtype is None or input_dtype.itemsize == 2:
marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view(
marlin_scales.size(0), -1
)

marlin_scales = marlin_scales.to(torch.float8_e8m0fnu)
if input_dtype == torch.float8_e4m3fn:
marlin_scales = marlin_scales.view(torch.uint8)
assert marlin_scales.max() <= 249
# exponent_bias (fp4->fp8) = 2 ** 3 - 2 ** 1 = 6
marlin_scales = marlin_scales + 6
marlin_scales = marlin_scales.view(torch.float8_e8m0fnu)
return marlin_scales
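One reading of the new float8_e4m3fn branch in mxfp4_marlin_process_scales: the e8m0 block scales are shifted by the exponent-bias difference between the fp4 (e2m1) storage format and the fp8 (e4m3) format the kernel dequantizes into, which is what the "2 ** 3 - 2 ** 1 = 6" comment works out to. The arithmetic, spelled out (interpretation only, not taken from the diff):

# IEEE-style exponent bias is 2 ** (exponent_bits - 1) - 1
fp4_e2m1_bias = 2 ** (2 - 1) - 1   # = 1
fp8_e4m3_bias = 2 ** (4 - 1) - 1   # = 7
print(fp8_e4m3_bias - fp4_e2m1_bias)  # 6, the offset added to the e8m0 scale bytes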
@ -99,6 +96,7 @@ def apply_fp4_marlin_linear(
size_n: int,
size_k: int,
bias: torch.Tensor | None = None,
input_dtype: torch.dtype | None = None,
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
) -> torch.Tensor:
# For GPUs that lack FP4 hardware support, we can leverage the

@ -111,12 +109,24 @@ def apply_fp4_marlin_linear(
m=reshaped_x.size(0), n=size_n, k=size_k, device=input.device, dtype=input.dtype
)

inputs = reshaped_x
a_scales = None
is_nvfp4 = weight_scale_2 is not None
if input_dtype is not None and input_dtype.itemsize == 1:
if is_nvfp4:
raise RuntimeError("NVFP4 weight + INT8/FP8 activation is not supported.")
elif input_dtype != torch.float8_e4m3fn:
raise RuntimeError("MXFP4 weight + INT8 activation is not supported.")

inputs, a_scales = marlin_quant_input(inputs, torch.float8_e4m3fn)

output = ops.gptq_marlin_gemm(
a=reshaped_x,
a=inputs,
c=None,
b_q_weight=weight,
b_bias=bias,
b_scales=weight_scale,
a_scales=a_scales,
global_scale=weight_scale_2,
b_zeros=None,
g_idx=None,

@ -133,7 +143,9 @@ def apply_fp4_marlin_linear(
return output.reshape(out_shape)

def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
def prepare_fp4_layer_for_marlin(
layer: torch.nn.Module, input_dtype: torch.dtype | None = None
) -> None:
logger.warning_once(
"Your GPU does not have native support for FP4 computation but "
"FP4 quantization is being used. Weight-only FP4 compression will "

@ -160,12 +172,14 @@ def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
perm = torch.empty(0, dtype=torch.int, device=device)
qweight = layer.weight.view(torch.int32).T.contiguous()

is_a_8bit = input_dtype is not None and input_dtype.itemsize == 1
marlin_qweight = ops.gptq_marlin_repack(
b_q_weight=qweight,
perm=perm,
size_k=part_size_k,
size_n=part_size_n,
num_bits=4,
is_a_8bit=is_a_8bit,
)
layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False)

@ -178,7 +192,11 @@ def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:

weight_scale = weight_scale.to(param_dtype)
weight_scale = marlin_permute_scales(
s=weight_scale, size_k=part_size_k, size_n=part_size_n, group_size=group_size
s=weight_scale,
size_k=part_size_k,
size_n=part_size_n,
group_size=group_size,
is_a_8bit=is_a_8bit,
)

if is_nvfp4:

@ -189,7 +207,9 @@ def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
weight_scale_2 = nvfp4_marlin_process_global_scale(weight_scale_2)
layer.weight_scale_2 = torch.nn.Parameter(weight_scale_2, requires_grad=False)
else:
weight_scale = mxfp4_marlin_process_scales(weight_scale)
weight_scale = mxfp4_marlin_process_scales(
weight_scale, input_dtype=input_dtype
)
layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)

if hasattr(layer, "bias") and layer.bias is not None:

@ -200,7 +220,9 @@ def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
return

def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
def prepare_moe_fp4_layer_for_marlin(
layer: torch.nn.Module, input_dtype: torch.dtype | None = None
) -> None:
logger.warning_once(
"Your GPU does not have native support for FP4 computation but "
"FP4 quantization is being used. Weight-only FP4 compression will "

@ -220,6 +242,7 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
param_dtype = layer.params_dtype
layer.workspace = marlin_make_workspace_new(device, 4)
perm = torch.empty(0, dtype=torch.int, device=device)
is_a_8bit = input_dtype is not None and input_dtype.itemsize == 1

# WEIGHT
# Repack weights to marlin format

@ -237,7 +260,12 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
qweight = weight[i].view(torch.int32).T.contiguous()

marlin_qweight = ops.gptq_marlin_repack(
b_q_weight=qweight, perm=perm, size_k=size_k, size_n=size_n, num_bits=4
b_q_weight=qweight,
perm=perm,
size_k=size_k,
size_n=size_n,
num_bits=4,
is_a_8bit=is_a_8bit,
)
tensor_list.append(marlin_qweight)

@ -266,12 +294,18 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
scale = scales[i].T

marlin_scales = marlin_permute_scales(
s=scale, size_k=size_k, size_n=size_n, group_size=group_size
s=scale,
size_k=size_k,
size_n=size_n,
group_size=group_size,
is_a_8bit=is_a_8bit,
)
if is_nvfp4:
marlin_scales = nvfp4_marlin_process_scales(marlin_scales)
else:
marlin_scales = mxfp4_marlin_process_scales(marlin_scales)
marlin_scales = mxfp4_marlin_process_scales(
marlin_scales, input_dtype=input_dtype
)
tensor_list.append(marlin_scales)

scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)

@ -301,7 +335,10 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
setattr(layer, name, bias)

def rand_marlin_weight_nvfp4_like(weight, group_size):
def rand_marlin_weight_nvfp4_like(weight, group_size, input_dtype=None):
is_a_8bit = input_dtype is not None and input_dtype.itemsize == 1

assert not is_a_8bit, "NVFP4 weight + INT8/FP8 activation is not supported."
assert group_size > 0
size_n, size_k = weight.shape
device = weight.device

@ -337,10 +374,15 @@ def rand_marlin_weight_nvfp4_like(weight, group_size):
size_k=size_k,
size_n=size_n,
num_bits=4,
is_a_8bit=is_a_8bit,
)

marlin_scales = marlin_permute_scales(
s=scales.T.to(weight.dtype), size_k=size_k, size_n=size_n, group_size=group_size
s=scales.T.to(weight.dtype),
size_k=size_k,
size_n=size_n,
group_size=group_size,
is_a_8bit=is_a_8bit,
)
marlin_scales = nvfp4_marlin_process_scales(marlin_scales)

@ -349,14 +391,20 @@ def rand_marlin_weight_nvfp4_like(weight, group_size):
return weight_ref.T, marlin_qweight, marlin_scales, global_scale

def rand_marlin_weight_mxfp4_like(weight, group_size):
def rand_marlin_weight_mxfp4_like(weight, group_size, input_dtype=None):
is_a_8bit = input_dtype is not None and input_dtype.itemsize == 1
if is_a_8bit:
assert input_dtype == torch.float8_e4m3fn, (
"MXFP4 weight + INT8 activation is not supported."
)

assert group_size > 0
size_n, size_k = weight.shape
device = weight.device

scales = torch.randint(
100,
125,
110,
120,
(size_n, size_k // group_size),
dtype=torch.uint8,
device=weight.device,

@ -380,18 +428,25 @@ def rand_marlin_weight_mxfp4_like(weight, group_size):
).view(size_n, size_k)
weight_ref = weight_ref * scales.repeat_interleave(group_size, 1).to(weight.dtype)

perm = torch.empty(0, dtype=torch.int, device=device)
fp4_weight = fp4_weight.view(torch.int32).T.contiguous()
marlin_qweight = ops.gptq_marlin_repack(
b_q_weight=fp4_weight.view(torch.int32).T.contiguous(),
perm=torch.empty(0, dtype=torch.int, device=device),
b_q_weight=fp4_weight,
perm=perm,
size_k=size_k,
size_n=size_n,
num_bits=4,
is_a_8bit=is_a_8bit,
)

marlin_scales = marlin_permute_scales(
s=scales.T.to(weight.dtype), size_k=size_k, size_n=size_n, group_size=group_size
s=scales.T.to(weight.dtype),
size_k=size_k,
size_n=size_n,
group_size=group_size,
is_a_8bit=is_a_8bit,
)

marlin_scales = mxfp4_marlin_process_scales(marlin_scales)
marlin_scales = mxfp4_marlin_process_scales(marlin_scales, input_dtype=input_dtype)

return weight_ref.T, marlin_qweight, marlin_scales.to(torch.float8_e8m0fnu)
@ -11,6 +11,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_make_workspace_new,
marlin_permute_bias,
marlin_permute_scales,
marlin_quant_input,
should_use_atomic_add_reduce,
)
from vllm.platforms import current_platform

@ -45,6 +46,7 @@ def apply_fp8_marlin_linear(
size_n: int,
size_k: int,
bias: torch.Tensor | None,
input_dtype: torch.dtype | None = None,
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
) -> torch.Tensor:
# For GPUs that lack FP8 hardware support, we can leverage the

@ -57,12 +59,21 @@ def apply_fp8_marlin_linear(
m=reshaped_x.size(0), n=size_n, k=size_k, device=input.device, dtype=input.dtype
)

inputs = reshaped_x
a_scales = None
if input_dtype is not None and input_dtype.itemsize == 1:
if input_dtype != torch.float8_e4m3fn:
raise RuntimeError("FP8 weight + INT8 activation is not supported.")

inputs, a_scales = marlin_quant_input(inputs, torch.float8_e4m3fn)

output = ops.gptq_marlin_gemm(
a=reshaped_x,
c=None,
b_q_weight=weight,
b_bias=bias,
b_scales=weight_scale,
a_scales=a_scales,
global_scale=None,
b_zeros=None,
g_idx=None,

@ -80,7 +91,9 @@ def apply_fp8_marlin_linear(

def prepare_fp8_layer_for_marlin(
layer: torch.nn.Module, size_k_first: bool = True
layer: torch.nn.Module,
size_k_first: bool = True,
input_dtype: torch.dtype | None = None,
) -> None:
logger.warning_once(
"Your GPU does not have native support for FP8 computation but "

@ -162,7 +175,8 @@ def prepare_fp8_layer_for_marlin(
marlin_scales = marlin_permute_scales(
s=scales, size_k=part_size_k, size_n=part_size_n, group_size=group_size
)
marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales)
if input_dtype != torch.float8_e4m3fn:
marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales)
layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False)

if hasattr(layer, "bias") and layer.bias is not None:

@ -172,7 +186,9 @@ def prepare_fp8_layer_for_marlin(

def prepare_moe_fp8_layer_for_marlin(
layer: torch.nn.Module, size_k_first: bool = True
layer: torch.nn.Module,
size_k_first: bool = True,
input_dtype: torch.dtype | None = None,
) -> None:
logger.warning_once(
"Your GPU does not have native support for FP8 computation but "

@ -278,7 +294,8 @@ def prepare_moe_fp8_layer_for_marlin(
tensor_list.append(marlin_scales)

scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
scales = fp8_fused_exponent_bias_into_scales(scales)
if input_dtype != torch.float8_e4m3fn:
scales = fp8_fused_exponent_bias_into_scales(scales)
scales = torch.nn.Parameter(scales, requires_grad=False)

setattr(layer, name + "_weight_scale", scales)

@ -318,7 +335,11 @@ def pack_fp8_to_int32(
return int32_tensor.T.contiguous() if size_k_first else int32_tensor

def marlin_quant_fp8_torch(weight, group_size):
def marlin_quant_fp8_torch(weight, group_size, input_dtype=None):
is_a_8bit = input_dtype is not None and input_dtype.itemsize == 1
if is_a_8bit:
assert input_dtype == torch.float8_e4m3fn

size_n, size_k = weight.shape
device = weight.device

@ -334,16 +355,22 @@ def marlin_quant_fp8_torch(weight, group_size):
weight_ref = fp8_weight.to(weight.dtype) * repeated_scales

packed_weight = pack_fp8_to_int32(fp8_weight, False).T.contiguous()
perm = torch.empty(0, dtype=torch.int, device=device)
marlin_qweight = ops.gptq_marlin_repack(
b_q_weight=packed_weight,
perm=torch.empty(0, dtype=torch.int, device=device),
perm=perm,
size_k=size_k,
size_n=size_n,
num_bits=8,
is_a_8bit=is_a_8bit,
)

marlin_scales = marlin_permute_scales(
s=scales.T, size_k=size_k, size_n=size_n, group_size=group_size
s=scales.T,
size_k=size_k,
size_n=size_n,
group_size=group_size,
is_a_8bit=is_a_8bit,
)

marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales)

@ -5,7 +5,8 @@
import numpy as np
import torch

from vllm.scalar_type import ScalarType
from vllm import _custom_ops as ops
from vllm.scalar_type import ScalarType, scalar_types

from .marlin_utils import GPTQ_MARLIN_TILE, marlin_permute_scales, marlin_zero_points
from .quant_utils import (

@ -29,13 +30,19 @@ class MarlinWorkspace:
self.scratch = torch.zeros(max_workspace_size, dtype=torch.int, device="cuda")

def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE):
def marlin_permute_weights(
q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE, is_a_8bit=False
):
assert q_w.shape == (size_k, size_n)
assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}"
assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}"

# Permute weights to 16x64 marlin tiles
q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile))
if is_a_8bit:
# Permute weights to 32x32 marlin tiles
q_w = q_w.reshape((size_k // (tile * 2), tile * 2, size_n // tile, tile))
else:
# Permute weights to 16x64 marlin tiles
q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile))
q_w = q_w.permute((0, 2, 1, 3))
q_w = q_w.reshape((size_k // tile, size_n * tile))

@ -44,9 +51,9 @@ def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE):
return q_w

def marlin_weights(q_w, size_k, size_n, num_bits, perm):
def marlin_weights(q_w, size_k, size_n, num_bits, perm, is_a_8bit=False):
# Permute
q_w = marlin_permute_weights(q_w, size_k, size_n, perm)
q_w = marlin_permute_weights(q_w, size_k, size_n, perm, is_a_8bit=is_a_8bit)

# Pack
pack_factor = get_pack_factor(num_bits)

@ -63,28 +70,53 @@ def marlin_weights(q_w, size_k, size_n, num_bits, perm):
return q_packed

def get_weight_perm(num_bits: int):
def get_weight_perm(num_bits: int, is_a_8bit: bool = False):
perm_list: list[int] = []
for i in range(32):
perm1: list[int] = []
col = i // 4
for block in [0, 1]:
for row in [
2 * (i % 4),
2 * (i % 4) + 1,
2 * (i % 4 + 4),
2 * (i % 4 + 4) + 1,
]:
perm1.append(16 * row + col + 8 * block)
for j in range(4):
perm_list.extend([p + 256 * j for p in perm1])
if is_a_8bit:
for i in range(32):
perm1 = []
col = i // 4
for block in [0, 1]:
for row in [
4 * (i % 4),
4 * (i % 4) + 1,
4 * (i % 4) + 2,
4 * (i % 4) + 3,
4 * (i % 4 + 4),
4 * (i % 4 + 4) + 1,
4 * (i % 4 + 4) + 2,
4 * (i % 4 + 4) + 3,
]:
perm1.append(16 * row + col + 8 * block)
for j in range(2):
perm_list.extend([p + 512 * j for p in perm1])
else:
for i in range(32):
perm1 = []
col = i // 4
for block in [0, 1]:
for row in [
2 * (i % 4),
2 * (i % 4) + 1,
2 * (i % 4 + 4),
2 * (i % 4 + 4) + 1,
]:
perm1.append(16 * row + col + 8 * block)
for j in range(4):
perm_list.extend([p + 256 * j for p in perm1])

perm = np.array(perm_list)

if num_bits == 4:
interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
if is_a_8bit:  # noqa: SIM108
interleave = np.array([0, 4, 1, 5, 2, 6, 3, 7])
else:
interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
elif num_bits == 8:
interleave = np.array([0, 2, 1, 3])
if is_a_8bit:  # noqa: SIM108
interleave = np.array([0, 1, 2, 3])
else:
interleave = np.array([0, 2, 1, 3])
else:
raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
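A quick way to see what the new is_a_8bit branch of get_weight_perm changes: the dequant-order interleave of packed 4-bit values differs between the 16-bit and 8-bit activation paths. A tiny illustration using the two patterns from the code above:

import numpy as np

nibbles = np.arange(8)                                # eight packed 4-bit values
interleave_a16 = np.array([0, 2, 4, 6, 1, 3, 5, 7])   # fp16/bf16 activation path
interleave_a8 = np.array([0, 4, 1, 5, 2, 6, 3, 7])    # new W4A8 path

print(nibbles[interleave_a16])  # [0 2 4 6 1 3 5 7]
print(nibbles[interleave_a8])   # [0 4 1 5 2 6 3 7]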
@ -99,7 +131,10 @@ def marlin_quantize(
group_size: int,
act_order: bool,
test_perm: torch.Tensor | None = None,
input_dtype: torch.dtype | None = None,
):
is_a_8bit = input_dtype is not None and input_dtype.itemsize == 1

size_k, size_n = w.shape
num_bits = quant_type.size_bits

@ -120,9 +155,15 @@ def marlin_quantize(
q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)

# Reformat to marlin
weight_perm = get_weight_perm(num_bits)
marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm)
marlin_s = marlin_permute_scales(s, size_k, size_n, group_size)
weight_perm = get_weight_perm(num_bits, is_a_8bit)
marlin_q_w = marlin_weights(
q_w, size_k, size_n, num_bits, weight_perm, is_a_8bit=is_a_8bit
)
marlin_s = marlin_permute_scales(s, size_k, size_n, group_size, is_a_8bit=is_a_8bit)

if input_dtype == torch.float8_e4m3fn and quant_type == scalar_types.uint4b8:
ops.marlin_int4_fp8_preprocess(marlin_q_w, inplace=True)
marlin_s = marlin_s * 512

# Create result
res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm]

@ -132,7 +173,13 @@ def marlin_quantize(
return res_list

def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int):
def awq_marlin_quantize(
w: torch.Tensor,
quant_type: ScalarType,
group_size: int,
input_dtype: torch.dtype | None = None,
):
is_a_8bit = input_dtype is not None and input_dtype.itemsize == 1
size_k, size_n = w.shape

# Normalize group_size

@ -147,11 +194,22 @@ def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int
# Quantize with zp
w_ref, q_w, s, zp = quantize_weights(w, quant_type, group_size, zero_points=True)

if input_dtype == torch.float8_e4m3fn and quant_type == scalar_types.uint4:
repeated_zp = zp.repeat_interleave(group_size, 0)
q_w_old = q_w
q_w = q_w_old - repeated_zp
q_w[q_w < 0] = 15 - q_w_old[q_w < 0]
s = s * 512

# Reformat to marlin
weight_perm = get_weight_perm(quant_type.size_bits)
marlin_q_w = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, weight_perm)
marlin_s = marlin_permute_scales(s, size_k, size_n, group_size)
marlin_zp = marlin_zero_points(zp, num_groups, size_n, quant_type.size_bits)
weight_perm = get_weight_perm(quant_type.size_bits, is_a_8bit)
marlin_q_w = marlin_weights(
q_w, size_k, size_n, quant_type.size_bits, weight_perm, is_a_8bit=is_a_8bit
)
marlin_s = marlin_permute_scales(s, size_k, size_n, group_size, is_a_8bit=is_a_8bit)
marlin_zp = marlin_zero_points(
zp, num_groups, size_n, quant_type.size_bits, is_a_8bit=is_a_8bit
)

# Create result
res_list = [w_ref, marlin_q_w, marlin_s, marlin_zp]