[Kernel][Quantization][MoE] add marlin kernel support for turing (sm75) (#29901)

Signed-off-by: Jinzhen Lin <jinzhen.ljz@antgroup.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
Jinzhen Lin 2025-12-17 06:35:28 +08:00 committed by GitHub
parent eaa82a709a
commit ce96857fdd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 729 additions and 513 deletions

View File

@ -357,6 +357,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# marlin arches for fp16 output
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0+PTX" "${CUDA_ARCHS}")
# marlin has limited support for turing
cuda_archs_loose_intersection(MARLIN_SM75_ARCHS "7.5" "${CUDA_ARCHS}")
# marlin arches for bf16 output (we need 9.0 for bf16 atomicAdd PTX)
cuda_archs_loose_intersection(MARLIN_BF16_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}")
# marlin arches for fp8 input
@ -364,8 +366,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# - sm90 and sm100 don't support QMMA.16832.F32.E4M3.E4M3 SAAS instruction
# so we only enable fp8 computation for SM89 (e.g. RTX 40x0) and 12.0 (e.g. RTX 50x0)
cuda_archs_loose_intersection(MARLIN_FP8_ARCHS "8.9;12.0" "${CUDA_ARCHS}")
# marlin arches for other files
cuda_archs_loose_intersection(MARLIN_OTHER_ARCHS "7.5;8.0+PTX" "${CUDA_ARCHS}")
if (MARLIN_ARCHS)
if (MARLIN_OTHER_ARCHS)
#
# For the Marlin kernels we automatically generate sources for various
@ -406,6 +410,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
message(STATUS "Marlin generation script has not changed, skipping generation.")
endif()
if (MARLIN_ARCHS)
file(GLOB MARLIN_TEMPLATE_KERNEL_SRC "csrc/quantization/gptq_marlin/sm80_kernel_*_float16.cu")
set_gencode_flags_for_srcs(
SRCS "${MARLIN_TEMPLATE_KERNEL_SRC}"
@ -425,6 +430,19 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
endif()
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_BF16_KERNEL_SRC})
endif()
if (MARLIN_SM75_ARCHS)
file(GLOB MARLIN_TEMPLATE_SM75_KERNEL_SRC "csrc/quantization/gptq_marlin/sm75_kernel_*.cu")
set_gencode_flags_for_srcs(
SRCS "${MARLIN_TEMPLATE_SM75_KERNEL_SRC}"
CUDA_ARCHS "${MARLIN_SM75_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
set_source_files_properties(${MARLIN_TEMPLATE_SM75_KERNEL_SRC}
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
endif()
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_SM75_KERNEL_SRC})
endif()
if (MARLIN_FP8_ARCHS)
file(GLOB MARLIN_TEMPLATE_FP8_KERNEL_SRC "csrc/quantization/gptq_marlin/sm89_kernel_*.cu")
@ -446,14 +464,14 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/quantization/gptq_marlin/awq_marlin_repack.cu")
set_gencode_flags_for_srcs(
SRCS "${MARLIN_SRCS}"
CUDA_ARCHS "${MARLIN_ARCHS}")
CUDA_ARCHS "${MARLIN_OTHER_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
set_source_files_properties("csrc/quantization/gptq_marlin/gptq_marlin.cu"
set_source_files_properties(${MARLIN_SRCS}
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
endif()
list(APPEND VLLM_EXT_SRC "${MARLIN_SRCS}")
message(STATUS "Building Marlin kernels for archs: ${MARLIN_ARCHS}")
message(STATUS "Building Marlin kernels for archs: ${MARLIN_OTHER_ARCHS}")
else()
message(STATUS "Not building Marlin kernels as no compatible archs found"
" in CUDA target architectures")
@ -980,12 +998,16 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# note that we always set `use_atomic_add=False` for moe marlin now,
# so we don't need 9.0 for bf16 atomicAdd PTX
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0+PTX" "${CUDA_ARCHS}")
# moe marlin has limited support for turing
cuda_archs_loose_intersection(MARLIN_MOE_SM75_ARCHS "7.5" "${CUDA_ARCHS}")
# moe marlin arches for fp8 input
# - sm80 doesn't support fp8 computation
# - sm90 and sm100 don't support QMMA.16832.F32.E4M3.E4M3 SAAS instruction
# so we only enable fp8 computation for SM89 (e.g. RTX 40x0) and 12.0 (e.g. RTX 50x0)
cuda_archs_loose_intersection(MARLIN_MOE_FP8_ARCHS "8.9;12.0" "${CUDA_ARCHS}")
if (MARLIN_MOE_ARCHS)
# moe marlin arches for other files
cuda_archs_loose_intersection(MARLIN_MOE_OTHER_ARCHS "7.5;8.0+PTX" "${CUDA_ARCHS}")
if (MARLIN_MOE_OTHER_ARCHS)
#
# For the Marlin MOE kernels we automatically generate sources for various
@ -1026,8 +1048,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
message(STATUS "Marlin MOE generation script has not changed, skipping generation.")
endif()
if (MARLIN_MOE_ARCHS)
file(GLOB MARLIN_MOE_SRC "csrc/moe/marlin_moe_wna16/sm80_kernel_*.cu")
list(APPEND MARLIN_MOE_SRC "csrc/moe/marlin_moe_wna16/ops.cu")
set_gencode_flags_for_srcs(
SRCS "${MARLIN_MOE_SRC}"
CUDA_ARCHS "${MARLIN_MOE_ARCHS}")
@ -1036,6 +1058,19 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
endif()
list(APPEND VLLM_MOE_EXT_SRC ${MARLIN_MOE_SRC})
endif()
if (MARLIN_MOE_SM75_ARCHS)
file(GLOB MARLIN_MOE_SM75_SRC "csrc/moe/marlin_moe_wna16/sm75_kernel_*.cu")
set_gencode_flags_for_srcs(
SRCS "${MARLIN_MOE_SM75_SRC}"
CUDA_ARCHS "${MARLIN_MOE_SM75_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
set_source_files_properties(${MARLIN_MOE_SM75_SRC}
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
endif()
list(APPEND VLLM_MOE_EXT_SRC ${MARLIN_MOE_SM75_SRC})
endif()
if (MARLIN_MOE_FP8_ARCHS)
file(GLOB MARLIN_MOE_FP8_SRC "csrc/moe/marlin_moe_wna16/sm89_kernel_*.cu")
@ -1049,7 +1084,17 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_MOE_EXT_SRC ${MARLIN_MOE_FP8_SRC})
endif()
message(STATUS "Building Marlin MOE kernels for archs: ${MARLIN_MOE_ARCHS}")
set(MARLIN_MOE_OTHER_SRC "csrc/moe/marlin_moe_wna16/ops.cu")
set_gencode_flags_for_srcs(
SRCS "${MARLIN_MOE_OTHER_SRC}"
CUDA_ARCHS "${MARLIN_MOE_OTHER_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
set_source_files_properties(${MARLIN_MOE_OTHER_SRC}
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
endif()
list(APPEND VLLM_MOE_EXT_SRC "${MARLIN_MOE_OTHER_SRC}")
message(STATUS "Building Marlin MOE kernels for archs: ${MARLIN_MOE_OTHER_ARCHS}")
else()
message(STATUS "Not building Marlin MOE kernels as no compatible archs found"
" in CUDA target architectures")

View File

@ -1,2 +1,3 @@
sm*_kernel_*.cu
kernel_selector.h
kernel_*.cu

View File

@ -10,6 +10,8 @@ import jinja2
ARCHS = []
SUPPORT_FP8 = False
SUPPORT_SM75 = False
SUPPORT_SM80 = False
for arch in sys.argv[1].split(","):
arch = arch[: arch.index(".") + 2].replace(".", "")
arch = int(arch)
@ -19,6 +21,10 @@ for arch in sys.argv[1].split(","):
# with FP16 MMA, so it cannot achieve any acceleration.
if arch in [89, 120]:
SUPPORT_FP8 = True
if arch >= 80:
SUPPORT_SM80 = True
if arch == 75:
SUPPORT_SM75 = True
FILE_HEAD_COMMENT = """
// auto generated by generate_kernels.py
@ -157,6 +163,7 @@ def remove_old_kernels():
def generate_new_kernels():
result_dict = {}
sm_75_result_dict = {}
for quant_config in QUANT_CONFIGS:
c_types = quant_config.get("c_type", ["kFloat16", "kBFloat16"])
@ -174,6 +181,8 @@ def generate_new_kernels():
s_type = quant_config.get("s_type", c_type)
if (a_type, b_type, c_type) not in result_dict:
result_dict[(a_type, b_type, c_type)] = []
if a_type in ["kFloat16", "kS8"] and c_type == "kFloat16":
sm_75_result_dict[(a_type, b_type, c_type)] = []
for group_blocks, m_blocks, thread_configs in itertools.product(
all_group_blocks, all_m_blocks, all_thread_configs
@ -197,17 +206,25 @@ def generate_new_kernels():
"thread_k_blocks": thread_k // 16,
"thread_n_blocks": thread_n // 16,
"m_block_size_8": "true" if m_blocks == 0.5 else "false",
"stages": "pipe_stages",
"stages": 4,
"group_blocks": group_blocks,
"is_zp_float": "false",
}
if SUPPORT_SM80:
result_dict[(a_type, b_type, c_type)].append(config)
if (a_type, b_type, c_type) in sm_75_result_dict and SUPPORT_SM75:
config_sm75 = config.copy()
config_sm75["stages"] = 2
sm_75_result_dict[(a_type, b_type, c_type)].append(config_sm75)
kernel_selector_str = FILE_HEAD_COMMENT
for (a_type, b_type, c_type), config_list in result_dict.items():
for result_dict_tmp in [result_dict, sm_75_result_dict]:
for (a_type, b_type, c_type), config_list in result_dict_tmp.items():
all_template_str_list = []
if not config_list:
continue
for config in config_list:
s_type = config["s_type"]
template_str = jinja2.Template(TEMPLATE).render(
@ -229,6 +246,7 @@ def generate_new_kernels():
f"thread_n_blocks == {config['thread_n_blocks']}",
f"thread_k_blocks == {config['thread_k_blocks']}",
f"m_block_size_8 == {config['m_block_size_8']}",
f"stages == {config['stages']}",
f"group_blocks == {config['group_blocks']}",
f"is_zp_float == {config['is_zp_float']}",
]
@ -262,6 +280,8 @@ def generate_new_kernels():
file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
if a_type == "kFE4M3fn":
filename = f"sm89_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
elif result_dict_tmp is sm_75_result_dict:
filename = f"sm75_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
else:
filename = f"sm80_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"

View File

@ -26,6 +26,7 @@
#include "quantization/gptq_marlin/marlin.cuh"
#include "quantization/gptq_marlin/marlin_dtypes.cuh"
#include "quantization/gptq_marlin/dequant.h"
#include "quantization/gptq_marlin/marlin_mma.h"
#include "core/scalar_type.hpp"
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
@ -35,7 +36,7 @@
namespace MARLIN_NAMESPACE_NAME {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
template <typename scalar_t, // compute dtype, half or nv_float16
const vllm::ScalarTypeId b_type_id, // weight MarlinScalarType id
@ -84,146 +85,6 @@ __global__ void Marlin(
#else
// m16n8k16 tensor core mma instruction with fp16 inputs and fp32
// output/accumulation.
template <vllm::ScalarTypeId type_id, int k_size = 16>
__device__ inline void mma(
const typename MarlinScalarType<type_id>::FragA& a_frag,
const typename MarlinScalarType<type_id>::FragB& frag_b,
typename MarlinScalarType<type_id>::FragC& frag_c, int idx = 0) {
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
if constexpr (k_size == 16) {
if constexpr (std::is_same<scalar_t, half>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "f"(c[0]),
"f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
: "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "r"(c[0]),
"r"(c[1]), "r"(c[2]), "r"(c[3]));
}
} else if (k_size == 32) {
if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
"r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
}
}
}
template <vllm::ScalarTypeId type_id, int k_size = 16>
__device__ inline void mma_trans(
const typename MarlinScalarType<type_id>::FragA& a_frag,
const typename MarlinScalarType<type_id>::FragB& frag_b,
const typename MarlinScalarType<type_id>::FragB& frag_b2,
typename MarlinScalarType<type_id>::FragC& frag_c) {
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
const uint32_t* b2 = reinterpret_cast<const uint32_t*>(&frag_b2);
float* c = reinterpret_cast<float*>(&frag_c);
using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
if constexpr (k_size == 16) {
if constexpr (std::is_same<scalar_t, half>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(a[0]), "f"(c[0]), "f"(c[1]), "f"(c[2]),
"f"(c[3]));
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(a[0]), "r"(c[0]), "r"(c[1]), "r"(c[2]),
"r"(c[3]));
}
} else {
if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 1200
asm volatile(
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
#else
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
#endif
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
"r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
}
}
}
// Instruction for loading a full 16x16 matrix fragment of operand A from shared
// memory, directly in tensor core layout.
template <int count, vllm::ScalarTypeId type_id>
@ -439,9 +300,20 @@ __global__ void Marlin(
if constexpr (a_type_id == vllm::kFE4M3fn.id()) return;
#endif
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
// Turing TensorCore only supports fp16 and int8
if constexpr (a_type_id != vllm::kFloat16.id() && a_type_id != vllm::kS8.id())
return;
#endif
int num_tokens_past_padded = num_tokens_past_padded_ptr[0];
constexpr int moe_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
constexpr bool use_fp16_accum = a_type_id == vllm::kFloat16.id();
#else
constexpr bool use_fp16_accum = false;
#endif
using Adtype = MarlinScalarType<a_type_id>;
using Cdtype = MarlinScalarType<c_type_id>;
@ -618,7 +490,22 @@ __global__ void Marlin(
}
}
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
if constexpr (moe_block_size >= 16)
local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 16);
if constexpr (moe_block_size >= 8)
local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 8);
if constexpr (moe_block_size >= 4)
local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 4);
if constexpr (moe_block_size >= 2)
local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 2);
local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 1);
block_num_valid_tokens = local_count;
#else
block_num_valid_tokens = __reduce_add_sync(0xffffffff, local_count);
#endif
if (lane_id == 0)
reinterpret_cast<int*>(sh_new)[0] = block_num_valid_tokens;
@ -1018,10 +905,6 @@ __global__ void Marlin(
constexpr int sh_s_size = has_act_order ? (act_s_max_num_groups * s_sh_stride)
: (stages * s_sh_stage);
int4* sh_s = sh_zp + (stages * zp_sh_stage);
// shared memory reused by reduction should be smaller than
// shared memory used by weight.
static_assert(thread_m_blocks * 16 * thread_n_blocks * 16 / 8 <=
stages * b_sh_stage);
int4* sh_a = sh_s + sh_s_size;
// Register storage for double buffer of shared memory reads.
@ -1545,11 +1428,13 @@ __global__ void Marlin(
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
if constexpr (m_block_size_8) {
mma_trans<a_type_id>(frag_a[k2][i], frag_b0, frag_b1,
mma_trans<a_type_id, use_fp16_accum>(frag_a[k2][i], frag_b0, frag_b1,
frag_c[i][j][0]);
} else {
mma<a_type_id>(frag_a[k2][i], frag_b0, frag_c[i][j][0]);
mma<a_type_id>(frag_a[k2][i], frag_b1, frag_c[i][j][1]);
mma<a_type_id, use_fp16_accum>(frag_a[k2][i], frag_b0,
frag_c[i][j][0]);
mma<a_type_id, use_fp16_accum>(frag_a[k2][i], frag_b1,
frag_c[i][j][1]);
}
}
}
@ -1583,9 +1468,11 @@ __global__ void Marlin(
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
mma<a_type_id, 32>(frag_a[k2][i], frag_b[0],
mma<a_type_id, false, 32>(
frag_a[k2][i], frag_b[0],
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][0]);
mma<a_type_id, 32>(frag_a[k2][i], frag_b[1],
mma<a_type_id, false, 32>(
frag_a[k2][i], frag_b[1],
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][1]);
}
@ -2132,6 +2019,21 @@ __global__ void Marlin(
// While this pattern may not be the most readable, other ways of writing
// the loop seemed to noticeably worse performance after compilation.
if (slice_iters == 0) {
// convert fp16 accum to fp32 for reduction
if constexpr (use_fp16_accum) {
#pragma unroll
for (int i = 0; i < (thread_m_blocks * (is_a_8bit ? 2 : 4) * 2); i++) {
float* frag_c_part_float = reinterpret_cast<float*>(frag_c) + i * 4;
scalar_t* frag_c_part_half =
reinterpret_cast<scalar_t*>(frag_c_part_float);
#pragma unroll
for (int i = 3; i >= 0; i--) {
frag_c_part_float[i] = Cdtype::num2float(frag_c_part_half[i]);
}
}
}
if constexpr (is_a_8bit) {
float frag_a_s[2 * thread_m_blocks];

View File

@ -142,7 +142,7 @@ typedef struct {
int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
int prob_n, int prob_k, int num_bits, int group_size,
bool has_act_order, bool is_k_full) {
bool has_act_order, bool is_k_full, int stages) {
bool cache_scales_chunk = has_act_order && !is_k_full;
int tb_n = th_config.thread_n;
@ -160,13 +160,13 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
if (cache_scales_chunk) {
int load_groups =
tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K
tb_groups * stages * 2; // Chunk size is 2x pipeline over dim K
load_groups = max(load_groups, 32); // We load at least 32 scale groups
return load_groups * tb_n * 2;
} else {
int tb_scales = tb_groups * tb_n * 2;
return tb_scales * pipe_stages;
return tb_scales * stages;
}
}
@ -174,7 +174,7 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8,
int thread_m_blocks, int prob_m, int prob_n,
int prob_k, int num_bits, int group_size,
bool has_act_order, bool is_k_full, int has_zp,
int is_zp_float, bool is_a_8bit) {
int is_zp_float, bool is_a_8bit, int stages) {
int pack_factor = 32 / num_bits;
// Get B size
@ -185,8 +185,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8,
// shm size for block_sorted_ids/rd_block_sorted_ids/block_topk_weights
// both of them requires tb_m * 4 bytes (tb_m * int32 or tb_m * float32)
int sh_block_meta_size = tb_m * 16;
int sh_a_size = pipe_stages * (tb_m * tb_k) * (is_a_8bit ? 1 : 2);
int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4;
int sh_a_size = stages * (tb_m * tb_k) * (is_a_8bit ? 1 : 2);
int sh_b_size = stages * (tb_k * tb_n / pack_factor) * 4;
int sh_red_size = tb_m * (tb_n + 8) * 2;
int sh_bias_size = tb_n * 2;
int tmp_size =
@ -195,8 +195,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8,
int sh_s_size =
get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
group_size, has_act_order, is_k_full);
int sh_g_idx_size = has_act_order && !is_k_full ? pipe_stages * tb_k / 4 : 0;
group_size, has_act_order, is_k_full, stages);
int sh_g_idx_size = has_act_order && !is_k_full ? stages * tb_k / 4 : 0;
int sh_zp_size = 0;
if (has_zp) {
if (is_zp_float)
@ -217,7 +217,7 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8,
int thread_m_blocks, int prob_m, int prob_n, int prob_k,
int num_bits, int group_size, bool has_act_order,
bool is_k_full, int has_zp, int is_zp_float,
int max_shared_mem, bool is_a_8bit) {
bool is_a_8bit, int stages, int max_shared_mem) {
// Sanity
if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
th_config.num_threads == -1) {
@ -243,7 +243,7 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8,
int cache_size =
get_kernel_cache_size(th_config, m_block_size_8, thread_m_blocks, prob_m,
prob_n, prob_k, num_bits, group_size, has_act_order,
is_k_full, has_zp, is_zp_float, is_a_8bit);
is_k_full, has_zp, is_zp_float, is_a_8bit, stages);
return cache_size <= max_shared_mem;
}
@ -252,7 +252,7 @@ MarlinFuncPtr get_marlin_kernel(
const vllm::ScalarType c_type, const vllm::ScalarType s_type,
int thread_m_blocks, int thread_n_blocks, int thread_k_blocks,
bool m_block_size_8, bool has_act_order, bool has_zp, int group_blocks,
int threads, bool is_zp_float) {
int threads, bool is_zp_float, int stages) {
int num_bits = b_type.size_bits();
auto kernel = MarlinDefault;
@ -266,8 +266,8 @@ exec_config_t determine_exec_config(
const vllm::ScalarType& c_type, const vllm::ScalarType& s_type, int prob_m,
int prob_n, int prob_k, int num_experts, int top_k, int thread_m_blocks,
bool m_block_size_8, int num_bits, int group_size, bool has_act_order,
bool is_k_full, bool has_zp, bool is_zp_float, int max_shared_mem, int sms,
bool is_a_8bit) {
bool is_k_full, bool has_zp, bool is_zp_float, bool is_a_8bit, int stages,
int max_shared_mem, int sms) {
exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}};
thread_config_t* thread_configs = thread_m_blocks > 1
? large_batch_thread_configs
@ -284,15 +284,15 @@ exec_config_t determine_exec_config(
if (!is_valid_config(th_config, m_block_size_8, thread_m_blocks, prob_m,
prob_n, prob_k, num_bits, group_size, has_act_order,
is_k_full, has_zp, is_zp_float, max_shared_mem - 512,
is_a_8bit)) {
is_k_full, has_zp, is_zp_float, is_a_8bit, stages,
max_shared_mem - 512)) {
continue;
}
int cache_size = get_kernel_cache_size(
th_config, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k,
num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float,
is_a_8bit);
is_a_8bit, stages);
int group_blocks = 0;
if (!has_act_order) {
@ -303,7 +303,7 @@ exec_config_t determine_exec_config(
get_marlin_kernel(a_type, b_type, c_type, s_type, thread_m_blocks,
th_config.thread_n / 16, th_config.thread_k / 16,
m_block_size_8, has_act_order, has_zp, group_blocks,
th_config.num_threads, is_zp_float);
th_config.num_threads, is_zp_float, stages);
if (kernel == MarlinDefault) continue;
@ -433,8 +433,14 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
dev);
cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
dev);
TORCH_CHECK(major_capability * 10 + minor_capability >= 80,
"marlin kernel only support Ampere or newer GPUs.");
TORCH_CHECK(major_capability * 10 + minor_capability >= 75,
"marlin kernel only support Turing or newer GPUs.");
int stages = 4;
if (major_capability == 7 && minor_capability == 5) {
stages = 2;
TORCH_CHECK(a_type == vllm::kFloat16 || a_type == vllm::kS8,
"Turing only support FP16 or INT8 activation.");
}
if (a_type == vllm::kFE4M3fn) {
TORCH_CHECK(major_capability * 10 + minor_capability >= 89,
"FP8 only support Ada Lovelace or newer GPUs.");
@ -461,8 +467,8 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
exec_cfg = determine_exec_config(
a_type, b_type, c_type, s_type, prob_m, prob_n, prob_k, num_experts,
top_k, thread_m_blocks, m_block_size_8, num_bits, group_size,
has_act_order, is_k_full, has_zp, is_zp_float, max_shared_mem, sms,
is_a_8bit);
has_act_order, is_k_full, has_zp, is_zp_float, is_a_8bit, stages,
max_shared_mem, sms);
thread_tfg = exec_cfg.tb_cfg;
}
@ -479,7 +485,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
TORCH_CHECK(is_valid_config(thread_tfg, m_block_size_8, thread_m_blocks,
prob_m, prob_n, prob_k, num_bits, group_size,
has_act_order, is_k_full, has_zp, is_zp_float,
max_shared_mem, is_a_8bit),
is_a_8bit, stages, max_shared_mem),
"Invalid thread config: thread_m_blocks = ", thread_m_blocks,
", thread_k = ", thread_tfg.thread_k,
", thread_n = ", thread_tfg.thread_n,
@ -493,12 +499,12 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
int sh_cache_size =
get_kernel_cache_size(thread_tfg, m_block_size_8, thread_m_blocks, prob_m,
prob_n, prob_k, num_bits, group_size, has_act_order,
is_k_full, has_zp, is_zp_float, is_a_8bit);
is_k_full, has_zp, is_zp_float, is_a_8bit, stages);
auto kernel = get_marlin_kernel(
a_type, b_type, c_type, s_type, thread_m_blocks, thread_n_blocks,
thread_k_blocks, m_block_size_8, has_act_order, has_zp, group_blocks,
num_threads, is_zp_float);
num_threads, is_zp_float, stages);
if (kernel == MarlinDefault) {
TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n,

View File

@ -1,2 +1,3 @@
sm*_kernel_*.cu
kernel_selector.h
kernel_*.cu

View File

@ -67,7 +67,7 @@ where `scale_factor * multiplier` can be computed at weight loading.
namespace MARLIN_NAMESPACE_NAME {
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 750
// Lookup-table based 3-input logical operation; explicitly used for
// dequantization as the compiler does not seem to automatically recognize it in
// all cases.

View File

@ -10,6 +10,8 @@ import jinja2
ARCHS = []
SUPPORT_FP8 = False
SUPPORT_SM75 = False
SUPPORT_SM80 = False
for arch in sys.argv[1].split(","):
arch = arch[: arch.index(".") + 2].replace(".", "")
arch = int(arch)
@ -19,6 +21,10 @@ for arch in sys.argv[1].split(","):
# with FP16 MMA, so it cannot achieve any acceleration.
if arch in [89, 120]:
SUPPORT_FP8 = True
if arch >= 80:
SUPPORT_SM80 = True
if arch == 75:
SUPPORT_SM75 = True
FILE_HEAD_COMMENT = """
// auto generated by generate_kernels.py
@ -166,6 +172,7 @@ def remove_old_kernels():
def generate_new_kernels():
result_dict = {}
sm_75_result_dict = {}
for quant_config in QUANT_CONFIGS:
c_types = quant_config.get("c_type", ["kFloat16", "kBFloat16"])
@ -184,6 +191,8 @@ def generate_new_kernels():
s_type = quant_config.get("s_type", c_type)
if (a_type, b_type, c_type) not in result_dict:
result_dict[(a_type, b_type, c_type)] = []
if a_type in ["kFloat16", "kS8"] and c_type == "kFloat16":
sm_75_result_dict[(a_type, b_type, c_type)] = []
for group_blocks, m_blocks, thread_configs in itertools.product(
all_group_blocks, all_m_blocks, all_thread_configs
@ -207,17 +216,25 @@ def generate_new_kernels():
"thread_k_blocks": thread_k // 16,
"thread_n_blocks": thread_n // 16,
"m_block_size_8": "true" if m_blocks == 0.5 else "false",
"stages": "pipe_stages",
"stages": 4,
"group_blocks": group_blocks,
"is_zp_float": "true" if is_zp_float else "false",
}
if SUPPORT_SM80:
result_dict[(a_type, b_type, c_type)].append(config)
if (a_type, b_type, c_type) in sm_75_result_dict and SUPPORT_SM75:
config_sm75 = config.copy()
config_sm75["stages"] = 2
sm_75_result_dict[(a_type, b_type, c_type)].append(config_sm75)
kernel_selector_str = FILE_HEAD_COMMENT
for (a_type, b_type, c_type), config_list in result_dict.items():
for result_dict_tmp in [result_dict, sm_75_result_dict]:
for (a_type, b_type, c_type), config_list in result_dict_tmp.items():
all_template_str_list = []
if not config_list:
continue
for config in config_list:
s_type = config["s_type"]
template_str = jinja2.Template(TEMPLATE).render(
@ -239,6 +256,7 @@ def generate_new_kernels():
f"thread_n_blocks == {config['thread_n_blocks']}",
f"thread_k_blocks == {config['thread_k_blocks']}",
f"m_block_size_8 == {config['m_block_size_8']}",
f"stages == {config['stages']}",
f"group_blocks == {config['group_blocks']}",
f"is_zp_float == {config['is_zp_float']}",
]
@ -272,6 +290,8 @@ def generate_new_kernels():
file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
if a_type == "kFE4M3fn":
filename = f"sm89_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
elif result_dict_tmp is sm_75_result_dict:
filename = f"sm75_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
else:
filename = f"sm80_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"

View File

@ -37,7 +37,7 @@ __global__ void MarlinDefault(MARLIN_KERNEL_PARAMS){};
using MarlinFuncPtr = void (*)(MARLIN_KERNEL_PARAMS);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
int const* __restrict__ perm_int_ptr,
@ -148,7 +148,7 @@ typedef struct {
int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
int prob_n, int prob_k, int num_bits, int group_size,
bool has_act_order, bool is_k_full) {
bool has_act_order, bool is_k_full, int stages) {
bool cache_scales_chunk = has_act_order && !is_k_full;
int tb_n = th_config.thread_n;
@ -166,28 +166,29 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
if (cache_scales_chunk) {
int load_groups =
tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K
tb_groups * stages * 2; // Chunk size is 2x pipeline over dim K
load_groups = max(load_groups, 32); // We load at least 32 scale groups
return load_groups * tb_n * 2;
} else {
int tb_scales = tb_groups * tb_n * 2;
return tb_scales * pipe_stages;
return tb_scales * stages;
}
}
int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks,
int prob_m, int prob_n, int prob_k, int num_bits,
int group_size, bool has_act_order, bool is_k_full,
int has_zp, int is_zp_float) {
int has_zp, bool is_zp_float, bool is_a_8bit,
int stages) {
int pack_factor = 32 / num_bits;
// Get B size
int tb_k = th_config.thread_k;
int tb_n = th_config.thread_n;
int tb_m = thread_m_blocks * 16;
int sh_a_size = pipe_stages * (tb_m * tb_k) * 2;
int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4;
int sh_a_size = stages * (tb_m * tb_k) * (is_a_8bit ? 1 : 2);
int sh_b_size = stages * (tb_k * tb_n / pack_factor) * 4;
int sh_red_size = tb_m * (tb_n + 8) * 2;
int sh_bias_size = tb_n * 2;
int tmp_size =
@ -196,8 +197,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks,
int sh_s_size =
get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
group_size, has_act_order, is_k_full);
int sh_g_idx_size = has_act_order && !is_k_full ? pipe_stages * tb_k / 4 : 0;
group_size, has_act_order, is_k_full, stages);
int sh_g_idx_size = has_act_order && !is_k_full ? stages * tb_k / 4 : 0;
int sh_zp_size = 0;
if (has_zp) {
if (is_zp_float)
@ -217,7 +218,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks,
bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
int prob_m, int prob_n, int prob_k, int num_bits,
int group_size, bool has_act_order, bool is_k_full,
int has_zp, int is_zp_float, int max_shared_mem) {
int has_zp, bool is_zp_float, bool is_a_8bit, int stages,
int max_shared_mem) {
// Sanity
if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
th_config.num_threads == -1) {
@ -242,7 +244,7 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
// Check that pipeline fits into cache
int cache_size = get_kernel_cache_size(
th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size,
has_act_order, is_k_full, has_zp, is_zp_float);
has_act_order, is_k_full, has_zp, is_zp_float, is_a_8bit, stages);
return cache_size <= max_shared_mem;
}
@ -251,7 +253,7 @@ MarlinFuncPtr get_marlin_kernel(
const vllm::ScalarType c_type, const vllm::ScalarType s_type,
int thread_m_blocks, int thread_n_blocks, int thread_k_blocks,
bool m_block_size_8, bool has_act_order, bool has_zp, int group_blocks,
int threads, bool is_zp_float) {
int threads, bool is_zp_float, int stages) {
int num_bits = b_type.size_bits();
auto kernel = MarlinDefault;
@ -265,7 +267,8 @@ exec_config_t determine_exec_config(
const vllm::ScalarType& c_type, const vllm::ScalarType& s_type, int prob_m,
int prob_n, int prob_k, int thread_m_blocks, bool m_block_size_8,
int num_bits, int group_size, bool has_act_order, bool is_k_full,
bool has_zp, bool is_zp_float, int max_shared_mem, int sms) {
bool has_zp, bool is_zp_float, int is_a_8bit, int stages,
int max_shared_mem, int sms) {
exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}};
thread_config_t* thread_configs = thread_m_blocks > 1
? large_batch_thread_configs
@ -280,13 +283,15 @@ exec_config_t determine_exec_config(
if (!is_valid_config(th_config, thread_m_blocks, prob_m, prob_n, prob_k,
num_bits, group_size, has_act_order, is_k_full, has_zp,
is_zp_float, max_shared_mem - 512)) {
is_zp_float, is_a_8bit, stages,
max_shared_mem - 512)) {
continue;
}
int cache_size = get_kernel_cache_size(
th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits,
group_size, has_act_order, is_k_full, has_zp, is_zp_float);
int cache_size = get_kernel_cache_size(th_config, thread_m_blocks, prob_m,
prob_n, prob_k, num_bits, group_size,
has_act_order, is_k_full, has_zp,
is_zp_float, is_a_8bit, stages);
int group_blocks = 0;
if (!has_act_order) {
@ -297,14 +302,10 @@ exec_config_t determine_exec_config(
get_marlin_kernel(a_type, b_type, c_type, s_type, thread_m_blocks,
th_config.thread_n / 16, th_config.thread_k / 16,
m_block_size_8, has_act_order, has_zp, group_blocks,
th_config.num_threads, is_zp_float);
th_config.num_threads, is_zp_float, stages);
if (kernel == MarlinDefault) continue;
// int m_tiles = div_ceil(prob_m, thread_m_blocks * 16);
// int n_tiles = prob_n / th_config.thread_n;
// int k_tiles = prob_k / th_config.thread_k;
return {1, th_config};
}
@ -321,6 +322,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
int group_size, int dev, cudaStream_t stream, int thread_k_init,
int thread_n_init, int sms, bool use_atomic_add,
bool use_fp32_reduce, bool is_zp_float) {
bool is_a_8bit = a_type.size_bits() == 8;
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
", ", prob_n, ", ", prob_k, "]");
@ -389,8 +391,14 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
dev);
cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
dev);
TORCH_CHECK(major_capability * 10 + minor_capability >= 80,
"marlin kernel only support Ampere or newer GPUs.");
TORCH_CHECK(major_capability * 10 + minor_capability >= 75,
"marlin kernel only support Turing or newer GPUs.");
int stages = 4;
if (major_capability == 7 && minor_capability == 5) {
stages = 2;
TORCH_CHECK(a_type == vllm::kFloat16 || a_type == vllm::kS8,
"Turing only support FP16 or INT8 activation.");
}
if (a_type == vllm::kFE4M3fn) {
TORCH_CHECK(
major_capability * 10 + minor_capability == 89 ||
@ -431,7 +439,8 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
exec_cfg = determine_exec_config(
a_type, b_type, c_type, s_type, prob_m_split, prob_n, prob_k,
thread_m_blocks, m_block_size_8, num_bits, group_size, has_act_order,
is_k_full, has_zp, is_zp_float, max_shared_mem, sms);
is_k_full, has_zp, is_zp_float, is_a_8bit, stages, max_shared_mem,
sms);
thread_tfg = exec_cfg.tb_cfg;
if (thread_tfg.thread_n != -1) {
if (prob_n / thread_tfg.thread_n *
@ -440,7 +449,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
if (is_valid_config({128, 64, 128}, thread_m_blocks, prob_m_split,
prob_n, prob_k, num_bits, group_size,
has_act_order, is_k_full, has_zp, is_zp_float,
max_shared_mem_new)) {
is_a_8bit, stages, max_shared_mem_new)) {
thread_tfg = {128, 64, 128};
exec_cfg = {1, thread_tfg};
}
@ -466,7 +475,8 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
TORCH_CHECK(
is_valid_config(thread_tfg, thread_m_blocks, prob_m_split, prob_n,
prob_k, num_bits, group_size, has_act_order, is_k_full,
has_zp, is_zp_float, max_shared_mem_new),
has_zp, is_zp_float, is_a_8bit, stages,
max_shared_mem_new),
"Invalid thread config: thread_m_blocks = ", thread_m_blocks,
", thread_k = ", thread_tfg.thread_k,
", thread_n = ", thread_tfg.thread_n,
@ -475,12 +485,12 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
", prob_m_split = ", prob_m_split, ", group_size = ", group_size,
", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full,
", has_zp = ", has_zp, ", is_zp_float = ", is_zp_float,
", max_shared_mem_new = ", max_shared_mem_new);
", stages = ", stages, ", max_shared_mem_new = ", max_shared_mem_new);
auto kernel = get_marlin_kernel(
a_type, b_type, c_type, s_type, thread_m_blocks, thread_n_blocks,
thread_k_blocks, m_block_size_8, has_act_order, has_zp, group_blocks,
num_threads, is_zp_float);
num_threads, is_zp_float, stages);
if (kernel == MarlinDefault) {
TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n,

View File

@ -1,17 +1,19 @@
#pragma once
#include <torch/all.h>
#ifndef _marlin_cuh
#define _marlin_cuh
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <iostream>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <iostream>
#ifndef MARLIN_NAMESPACE_NAME
#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin
#endif
#endif
namespace MARLIN_NAMESPACE_NAME {
@ -51,9 +53,51 @@ using I4 = Vec<int, 4>;
constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
// No support for async
#else
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
__device__ inline void cp_async1_ca_pred(void* smem_ptr, const void* glob_ptr,
bool pred = true) {
if (pred) {
reinterpret_cast<int32_t*>(smem_ptr)[0] =
reinterpret_cast<const int32_t*>(glob_ptr)[0];
}
}
__device__ inline void cp_async2_ca_pred(void* smem_ptr, const void* glob_ptr,
bool pred = true) {
if (pred) {
reinterpret_cast<int64_t*>(smem_ptr)[0] =
reinterpret_cast<const int64_t*>(glob_ptr)[0];
}
}
__device__ inline void cp_async4_ca_pred(void* smem_ptr, const void* glob_ptr,
bool pred = true) {
if (pred) {
reinterpret_cast<int4*>(smem_ptr)[0] =
reinterpret_cast<const int4*>(glob_ptr)[0];
}
}
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
bool pred = true) {
if (pred) {
reinterpret_cast<int4*>(smem_ptr)[0] =
reinterpret_cast<const int4*>(glob_ptr)[0];
}
}
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
reinterpret_cast<int4*>(smem_ptr)[0] =
reinterpret_cast<const int4*>(glob_ptr)[0];
}
__device__ inline void cp_async_fence() {}
template <int n>
__device__ inline void cp_async_wait() {}
#else
__device__ inline void cp_async1_ca_pred(void* smem_ptr, const void* glob_ptr,
bool pred = true) {
@ -126,6 +170,8 @@ __device__ inline void cp_async_wait() {
asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
}
#endif
#endif
} // namespace MARLIN_NAMESPACE_NAME
#endif

View File

@ -0,0 +1,269 @@
#include "marlin_dtypes.cuh"
namespace MARLIN_NAMESPACE_NAME {
// m16n8k16 tensor core mma instruction with fp16 inputs and fp32
// output/accumulation.
template <vllm::ScalarTypeId type_id, bool use_fp16_accum, int k_size = 16>
__device__ inline void mma(
const typename MarlinScalarType<type_id>::FragA& a_frag,
const typename MarlinScalarType<type_id>::FragB& frag_b,
typename MarlinScalarType<type_id>::FragC& frag_c, int idx = 0) {
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
if constexpr (!std::is_same<scalar_t, half>::value || k_size != 16) {
static_assert(!use_fp16_accum);
}
if constexpr (k_size == 16) {
if constexpr (std::is_same<scalar_t, half>::value && !use_fp16_accum) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(b[0]), "f"(c[0]), "f"(c[1]), "f"(c[2]),
"f"(c[3]));
asm volatile(
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[2]), "r"(a[3]), "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]),
"f"(c[3]));
#else
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
#endif
} else if constexpr (std::is_same<scalar_t, half>::value &&
use_fp16_accum) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
uint32_t* c = reinterpret_cast<uint32_t*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
"{%0,%1}, {%2,%3}, {%4}, {%5,%6};\n"
: "=r"(c[0]), "=r"(c[1])
: "r"(a[0]), "r"(a[1]), "r"(b[0]), "r"(c[0]), "r"(c[1]));
asm volatile(
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
"{%0,%1}, {%2,%3}, {%4}, {%5,%6};\n"
: "=r"(c[0]), "=r"(c[1])
: "r"(a[2]), "r"(a[3]), "r"(b[1]), "r"(c[0]), "r"(c[1]));
#else
uint32_t* c = reinterpret_cast<uint32_t*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
"{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n"
: "=r"(c[0]), "=r"(c[1])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
"r"(c[0]), "r"(c[1]));
#endif
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "f"(c[0]),
"f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
: "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "r"(c[0]),
"r"(c[1]), "r"(c[2]), "r"(c[3]));
}
} else if (k_size == 32) {
if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
asm volatile(
"mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1}, {%2}, {%3}, {%4,%5};\n"
: "=r"(c[0]), "=r"(c[1])
: "r"(a[0]), "r"(b[0]), "r"(c[0]), "r"(c[1]));
asm volatile(
"mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1}, {%2}, {%3}, {%4,%5};\n"
: "=r"(c[2]), "=r"(c[3])
: "r"(a[1]), "r"(b[0]), "r"(c[2]), "r"(c[3]));
asm volatile(
"mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1}, {%2}, {%3}, {%4,%5};\n"
: "=r"(c[0]), "=r"(c[1])
: "r"(a[2]), "r"(b[1]), "r"(c[0]), "r"(c[1]));
asm volatile(
"mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1}, {%2}, {%3}, {%4,%5};\n"
: "=r"(c[2]), "=r"(c[3])
: "r"(a[3]), "r"(b[1]), "r"(c[2]), "r"(c[3]));
#else
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
"r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
#endif
}
}
}
template <vllm::ScalarTypeId type_id, bool use_fp16_accum, int k_size = 16>
__device__ inline void mma_trans(
const typename MarlinScalarType<type_id>::FragA& a_frag,
const typename MarlinScalarType<type_id>::FragB& frag_b,
const typename MarlinScalarType<type_id>::FragB& frag_b2,
typename MarlinScalarType<type_id>::FragC& frag_c) {
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
const uint32_t* b2 = reinterpret_cast<const uint32_t*>(&frag_b2);
float* c = reinterpret_cast<float*>(&frag_c);
using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
if constexpr (!std::is_same<scalar_t, half>::value || k_size != 16) {
static_assert(!use_fp16_accum);
}
if constexpr (k_size == 16) {
if constexpr (std::is_same<scalar_t, half>::value && !use_fp16_accum) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(a[0]), "f"(c[0]), "f"(c[1]), "f"(c[2]),
"f"(c[3]));
asm volatile(
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(b[1]), "r"(b2[1]), "r"(a[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]),
"f"(c[3]));
#else
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
#endif
} else if constexpr (std::is_same<scalar_t, half>::value &&
use_fp16_accum) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
uint32_t* c = reinterpret_cast<uint32_t*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
"{%0,%1}, {%2,%3}, {%4}, {%5,%6};\n"
: "=r"(c[0]), "=r"(c[1])
: "r"(b[0]), "r"(b2[0]), "r"(a[0]), "r"(c[0]), "r"(c[1]));
asm volatile(
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
"{%0,%1}, {%2,%3}, {%4}, {%5,%6};\n"
: "=r"(c[0]), "=r"(c[1])
: "r"(b[1]), "r"(b2[1]), "r"(a[1]), "r"(c[0]), "r"(c[1]));
#else
uint32_t* c = reinterpret_cast<uint32_t*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
"{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n"
: "=r"(c[0]), "=r"(c[1])
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
"r"(c[0]), "r"(c[1]));
#endif
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(a[0]), "f"(c[0]), "f"(c[1]), "f"(c[2]),
"f"(c[3]));
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(a[0]), "r"(c[0]), "r"(c[1]), "r"(c[2]),
"r"(c[3]));
}
} else {
if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
asm volatile(
"mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1}, {%2}, {%3}, {%4,%5};\n"
: "=r"(c[0]), "=r"(c[1])
: "r"(b[0]), "r"(a[0]), "r"(c[0]), "r"(c[1]));
asm volatile(
"mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1}, {%2}, {%3}, {%4,%5};\n"
: "=r"(c[2]), "=r"(c[3])
: "r"(b2[1]), "r"(a[0]), "r"(c[2]), "r"(c[3]));
asm volatile(
"mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1}, {%2}, {%3}, {%4,%5};\n"
: "=r"(c[0]), "=r"(c[1])
: "r"(b[0]), "r"(a[1]), "r"(c[0]), "r"(c[1]));
asm volatile(
"mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1}, {%2}, {%3}, {%4,%5};\n"
: "=r"(c[2]), "=r"(c[3])
: "r"(b2[1]), "r"(a[1]), "r"(c[2]), "r"(c[3]));
#else
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
"r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
#endif
}
}
}
} // namespace MARLIN_NAMESPACE_NAME

View File

@ -26,6 +26,7 @@
#include "marlin.cuh"
#include "marlin_dtypes.cuh"
#include "dequant.h"
#include "marlin_mma.h"
#include "core/scalar_type.hpp"
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
@ -35,7 +36,7 @@
namespace MARLIN_NAMESPACE_NAME {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
template <typename scalar_t, // compute dtype, half or nv_float16
const vllm::ScalarTypeId b_type_id, // weight MarlinScalarType id
@ -75,137 +76,6 @@ __global__ void Marlin(
#else
// m16n8k16 tensor core mma instruction with fp16 inputs and fp32
// output/accumulation.
template <vllm::ScalarTypeId type_id, int k_size = 16>
__device__ inline void mma(
const typename MarlinScalarType<type_id>::FragA& a_frag,
const typename MarlinScalarType<type_id>::FragB& frag_b,
typename MarlinScalarType<type_id>::FragC& frag_c, int idx = 0) {
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
if constexpr (k_size == 16) {
if constexpr (std::is_same<scalar_t, half>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "f"(c[0]),
"f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
: "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "r"(c[0]),
"r"(c[1]), "r"(c[2]), "r"(c[3]));
}
} else if (k_size == 32) {
if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
"r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
}
}
}
template <vllm::ScalarTypeId type_id, int k_size = 16>
__device__ inline void mma_trans(
const typename MarlinScalarType<type_id>::FragA& a_frag,
const typename MarlinScalarType<type_id>::FragB& frag_b,
const typename MarlinScalarType<type_id>::FragB& frag_b2,
typename MarlinScalarType<type_id>::FragC& frag_c) {
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
const uint32_t* b2 = reinterpret_cast<const uint32_t*>(&frag_b2);
float* c = reinterpret_cast<float*>(&frag_c);
using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
if constexpr (k_size == 16) {
if constexpr (std::is_same<scalar_t, half>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(a[0]), "f"(c[0]), "f"(c[1]), "f"(c[2]),
"f"(c[3]));
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(a[0]), "r"(c[0]), "r"(c[1]), "r"(c[2]),
"r"(c[3]));
}
} else {
if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
"r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
}
}
}
// Instruction for loading a full 16x16 matrix fragment of operand A from shared
// memory, directly in tensor core layout.
template <int count, vllm::ScalarTypeId type_id>
@ -415,6 +285,17 @@ __global__ void Marlin(
if constexpr (a_type_id == vllm::kFE4M3fn.id()) return;
#endif
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
// Turing TensorCore only supports fp16 and int8
if constexpr (a_type_id != vllm::kFloat16.id() && a_type_id != vllm::kS8.id())
return;
#endif
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
constexpr bool use_fp16_accum = a_type_id == vllm::kFloat16.id();
#else
constexpr bool use_fp16_accum = false;
#endif
using Adtype = MarlinScalarType<a_type_id>;
using Cdtype = MarlinScalarType<c_type_id>;
const int4* A = A0;
@ -873,10 +754,6 @@ __global__ void Marlin(
constexpr int sh_s_size = has_act_order ? (act_s_max_num_groups * s_sh_stride)
: (stages * s_sh_stage);
int4* sh_s = sh_zp + (stages * zp_sh_stage);
// shared memory reused by reduction should be smaller than
// shared memory used by weight.
static_assert(thread_m_blocks * 16 * thread_n_blocks * 16 / 8 <=
stages * b_sh_stage);
int4* sh_a = sh_s + sh_s_size;
// Register storage for double buffer of shared memory reads.
@ -1395,11 +1272,13 @@ __global__ void Marlin(
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
if constexpr (m_block_size_8) {
mma_trans<a_type_id>(frag_a[k2][i], frag_b0, frag_b1,
mma_trans<a_type_id, use_fp16_accum>(frag_a[k2][i], frag_b0, frag_b1,
frag_c[i][j][0]);
} else {
mma<a_type_id>(frag_a[k2][i], frag_b0, frag_c[i][j][0]);
mma<a_type_id>(frag_a[k2][i], frag_b1, frag_c[i][j][1]);
mma<a_type_id, use_fp16_accum>(frag_a[k2][i], frag_b0,
frag_c[i][j][0]);
mma<a_type_id, use_fp16_accum>(frag_a[k2][i], frag_b1,
frag_c[i][j][1]);
}
}
}
@ -1433,9 +1312,11 @@ __global__ void Marlin(
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
mma<a_type_id, 32>(frag_a[k2][i], frag_b[0],
mma<a_type_id, false, 32>(
frag_a[k2][i], frag_b[0],
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][0]);
mma<a_type_id, 32>(frag_a[k2][i], frag_b[1],
mma<a_type_id, false, 32>(
frag_a[k2][i], frag_b[1],
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][1]);
}
@ -1956,6 +1837,21 @@ __global__ void Marlin(
// While this pattern may not be the most readable, other ways of writing
// the loop seemed to noticeably worse performance after compilation.
if (slice_iters == 0) {
// convert fp16 accum to fp32 for reduction
if constexpr (use_fp16_accum) {
#pragma unroll
for (int i = 0; i < (thread_m_blocks * (is_a_8bit ? 2 : 4) * 2); i++) {
float* frag_c_part_float = reinterpret_cast<float*>(frag_c) + i * 4;
scalar_t* frag_c_part_half =
reinterpret_cast<scalar_t*>(frag_c_part_float);
#pragma unroll
for (int i = 3; i >= 0; i--) {
frag_c_part_float[i] = Cdtype::num2float(frag_c_part_half[i]);
}
}
}
if constexpr (is_a_8bit) {
float frag_a_s[2 * thread_m_blocks];

View File

@ -121,7 +121,7 @@ class AWQMarlinConfig(QuantizationConfig):
@classmethod
def get_min_capability(cls) -> int:
return 80
return 75
@classmethod
def get_config_filenames(cls) -> list[str]:

View File

@ -253,7 +253,7 @@ class Fp8Config(QuantizationConfig):
@classmethod
def get_min_capability(cls) -> int:
return 80
return 75
@classmethod
def get_config_filenames(cls) -> list[str]:

View File

@ -181,7 +181,7 @@ class GPTQMarlinConfig(QuantizationConfig):
@classmethod
def get_min_capability(cls) -> int:
return 80
return 75
@classmethod
def get_config_filenames(cls) -> list[str]:

View File

@ -871,7 +871,7 @@ class ModelOptNvFp4Config(ModelOptQuantConfigBase):
@classmethod
def get_min_capability(cls) -> int:
return 80
return 75
@classmethod
def override_quantization_method(