mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-16 11:37:12 +08:00
[Kernel][Quantization][MoE] add marlin kernel support for turing (sm75) (#29901)
Signed-off-by: Jinzhen Lin <jinzhen.ljz@antgroup.com> Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
parent
eaa82a709a
commit
ce96857fdd
109
CMakeLists.txt
109
CMakeLists.txt
@ -357,6 +357,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
|
||||
# marlin arches for fp16 output
|
||||
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0+PTX" "${CUDA_ARCHS}")
|
||||
# marlin has limited support for turing
|
||||
cuda_archs_loose_intersection(MARLIN_SM75_ARCHS "7.5" "${CUDA_ARCHS}")
|
||||
# marlin arches for bf16 output (we need 9.0 for bf16 atomicAdd PTX)
|
||||
cuda_archs_loose_intersection(MARLIN_BF16_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}")
|
||||
# marlin arches for fp8 input
|
||||
@ -364,8 +366,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
# - sm90 and sm100 don't support QMMA.16832.F32.E4M3.E4M3 SAAS instruction
|
||||
# so we only enable fp8 computation for SM89 (e.g. RTX 40x0) and 12.0 (e.g. RTX 50x0)
|
||||
cuda_archs_loose_intersection(MARLIN_FP8_ARCHS "8.9;12.0" "${CUDA_ARCHS}")
|
||||
# marlin arches for other files
|
||||
cuda_archs_loose_intersection(MARLIN_OTHER_ARCHS "7.5;8.0+PTX" "${CUDA_ARCHS}")
|
||||
|
||||
if (MARLIN_ARCHS)
|
||||
if (MARLIN_OTHER_ARCHS)
|
||||
|
||||
#
|
||||
# For the Marlin kernels we automatically generate sources for various
|
||||
@ -406,25 +410,39 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
message(STATUS "Marlin generation script has not changed, skipping generation.")
|
||||
endif()
|
||||
|
||||
file(GLOB MARLIN_TEMPLATE_KERNEL_SRC "csrc/quantization/gptq_marlin/sm80_kernel_*_float16.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MARLIN_TEMPLATE_KERNEL_SRC}"
|
||||
CUDA_ARCHS "${MARLIN_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
|
||||
set_source_files_properties(${MARLIN_TEMPLATE_KERNEL_SRC}
|
||||
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
|
||||
endif()
|
||||
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC})
|
||||
if (MARLIN_ARCHS)
|
||||
file(GLOB MARLIN_TEMPLATE_KERNEL_SRC "csrc/quantization/gptq_marlin/sm80_kernel_*_float16.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MARLIN_TEMPLATE_KERNEL_SRC}"
|
||||
CUDA_ARCHS "${MARLIN_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
|
||||
set_source_files_properties(${MARLIN_TEMPLATE_KERNEL_SRC}
|
||||
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
|
||||
endif()
|
||||
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC})
|
||||
|
||||
file(GLOB MARLIN_TEMPLATE_BF16_KERNEL_SRC "csrc/quantization/gptq_marlin/sm80_kernel_*_bfloat16.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MARLIN_TEMPLATE_BF16_KERNEL_SRC}"
|
||||
CUDA_ARCHS "${MARLIN_BF16_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
|
||||
set_source_files_properties(${MARLIN_TEMPLATE_BF16_KERNEL_SRC}
|
||||
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
|
||||
file(GLOB MARLIN_TEMPLATE_BF16_KERNEL_SRC "csrc/quantization/gptq_marlin/sm80_kernel_*_bfloat16.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MARLIN_TEMPLATE_BF16_KERNEL_SRC}"
|
||||
CUDA_ARCHS "${MARLIN_BF16_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
|
||||
set_source_files_properties(${MARLIN_TEMPLATE_BF16_KERNEL_SRC}
|
||||
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
|
||||
endif()
|
||||
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_BF16_KERNEL_SRC})
|
||||
endif()
|
||||
|
||||
if (MARLIN_SM75_ARCHS)
|
||||
file(GLOB MARLIN_TEMPLATE_SM75_KERNEL_SRC "csrc/quantization/gptq_marlin/sm75_kernel_*.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MARLIN_TEMPLATE_SM75_KERNEL_SRC}"
|
||||
CUDA_ARCHS "${MARLIN_SM75_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
|
||||
set_source_files_properties(${MARLIN_TEMPLATE_SM75_KERNEL_SRC}
|
||||
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
|
||||
endif()
|
||||
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_SM75_KERNEL_SRC})
|
||||
endif()
|
||||
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_BF16_KERNEL_SRC})
|
||||
|
||||
if (MARLIN_FP8_ARCHS)
|
||||
file(GLOB MARLIN_TEMPLATE_FP8_KERNEL_SRC "csrc/quantization/gptq_marlin/sm89_kernel_*.cu")
|
||||
@ -446,14 +464,14 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
"csrc/quantization/gptq_marlin/awq_marlin_repack.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MARLIN_SRCS}"
|
||||
CUDA_ARCHS "${MARLIN_ARCHS}")
|
||||
CUDA_ARCHS "${MARLIN_OTHER_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
|
||||
set_source_files_properties("csrc/quantization/gptq_marlin/gptq_marlin.cu"
|
||||
set_source_files_properties(${MARLIN_SRCS}
|
||||
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
|
||||
endif()
|
||||
list(APPEND VLLM_EXT_SRC "${MARLIN_SRCS}")
|
||||
|
||||
message(STATUS "Building Marlin kernels for archs: ${MARLIN_ARCHS}")
|
||||
message(STATUS "Building Marlin kernels for archs: ${MARLIN_OTHER_ARCHS}")
|
||||
else()
|
||||
message(STATUS "Not building Marlin kernels as no compatible archs found"
|
||||
" in CUDA target architectures")
|
||||
@ -980,12 +998,16 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
# note that we always set `use_atomic_add=False` for moe marlin now,
|
||||
# so we don't need 9.0 for bf16 atomicAdd PTX
|
||||
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0+PTX" "${CUDA_ARCHS}")
|
||||
# moe marlin has limited support for turing
|
||||
cuda_archs_loose_intersection(MARLIN_MOE_SM75_ARCHS "7.5" "${CUDA_ARCHS}")
|
||||
# moe marlin arches for fp8 input
|
||||
# - sm80 doesn't support fp8 computation
|
||||
# - sm90 and sm100 don't support QMMA.16832.F32.E4M3.E4M3 SAAS instruction
|
||||
# so we only enable fp8 computation for SM89 (e.g. RTX 40x0) and 12.0 (e.g. RTX 50x0)
|
||||
cuda_archs_loose_intersection(MARLIN_MOE_FP8_ARCHS "8.9;12.0" "${CUDA_ARCHS}")
|
||||
if (MARLIN_MOE_ARCHS)
|
||||
# moe marlin arches for other files
|
||||
cuda_archs_loose_intersection(MARLIN_MOE_OTHER_ARCHS "7.5;8.0+PTX" "${CUDA_ARCHS}")
|
||||
if (MARLIN_MOE_OTHER_ARCHS)
|
||||
|
||||
#
|
||||
# For the Marlin MOE kernels we automatically generate sources for various
|
||||
@ -1026,16 +1048,29 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
message(STATUS "Marlin MOE generation script has not changed, skipping generation.")
|
||||
endif()
|
||||
|
||||
file(GLOB MARLIN_MOE_SRC "csrc/moe/marlin_moe_wna16/sm80_kernel_*.cu")
|
||||
list(APPEND MARLIN_MOE_SRC "csrc/moe/marlin_moe_wna16/ops.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MARLIN_MOE_SRC}"
|
||||
CUDA_ARCHS "${MARLIN_MOE_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
|
||||
set_source_files_properties(${MARLIN_MOE_SRC}
|
||||
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
|
||||
if (MARLIN_MOE_ARCHS)
|
||||
file(GLOB MARLIN_MOE_SRC "csrc/moe/marlin_moe_wna16/sm80_kernel_*.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MARLIN_MOE_SRC}"
|
||||
CUDA_ARCHS "${MARLIN_MOE_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
|
||||
set_source_files_properties(${MARLIN_MOE_SRC}
|
||||
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
|
||||
endif()
|
||||
list(APPEND VLLM_MOE_EXT_SRC ${MARLIN_MOE_SRC})
|
||||
endif()
|
||||
|
||||
if (MARLIN_MOE_SM75_ARCHS)
|
||||
file(GLOB MARLIN_MOE_SM75_SRC "csrc/moe/marlin_moe_wna16/sm75_kernel_*.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MARLIN_MOE_SM75_SRC}"
|
||||
CUDA_ARCHS "${MARLIN_MOE_SM75_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
|
||||
set_source_files_properties(${MARLIN_MOE_SM75_SRC}
|
||||
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
|
||||
endif()
|
||||
list(APPEND VLLM_MOE_EXT_SRC ${MARLIN_MOE_SM75_SRC})
|
||||
endif()
|
||||
list(APPEND VLLM_MOE_EXT_SRC ${MARLIN_MOE_SRC})
|
||||
|
||||
if (MARLIN_MOE_FP8_ARCHS)
|
||||
file(GLOB MARLIN_MOE_FP8_SRC "csrc/moe/marlin_moe_wna16/sm89_kernel_*.cu")
|
||||
@ -1049,7 +1084,17 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
list(APPEND VLLM_MOE_EXT_SRC ${MARLIN_MOE_FP8_SRC})
|
||||
endif()
|
||||
|
||||
message(STATUS "Building Marlin MOE kernels for archs: ${MARLIN_MOE_ARCHS}")
|
||||
set(MARLIN_MOE_OTHER_SRC "csrc/moe/marlin_moe_wna16/ops.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MARLIN_MOE_OTHER_SRC}"
|
||||
CUDA_ARCHS "${MARLIN_MOE_OTHER_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
|
||||
set_source_files_properties(${MARLIN_MOE_OTHER_SRC}
|
||||
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
|
||||
endif()
|
||||
list(APPEND VLLM_MOE_EXT_SRC "${MARLIN_MOE_OTHER_SRC}")
|
||||
|
||||
message(STATUS "Building Marlin MOE kernels for archs: ${MARLIN_MOE_OTHER_ARCHS}")
|
||||
else()
|
||||
message(STATUS "Not building Marlin MOE kernels as no compatible archs found"
|
||||
" in CUDA target architectures")
|
||||
|
||||
1
csrc/moe/marlin_moe_wna16/.gitignore
vendored
1
csrc/moe/marlin_moe_wna16/.gitignore
vendored
@ -1,2 +1,3 @@
|
||||
sm*_kernel_*.cu
|
||||
kernel_selector.h
|
||||
kernel_*.cu
|
||||
|
||||
@ -10,6 +10,8 @@ import jinja2
|
||||
|
||||
ARCHS = []
|
||||
SUPPORT_FP8 = False
|
||||
SUPPORT_SM75 = False
|
||||
SUPPORT_SM80 = False
|
||||
for arch in sys.argv[1].split(","):
|
||||
arch = arch[: arch.index(".") + 2].replace(".", "")
|
||||
arch = int(arch)
|
||||
@ -19,6 +21,10 @@ for arch in sys.argv[1].split(","):
|
||||
# with FP16 MMA, so it cannot achieve any acceleration.
|
||||
if arch in [89, 120]:
|
||||
SUPPORT_FP8 = True
|
||||
if arch >= 80:
|
||||
SUPPORT_SM80 = True
|
||||
if arch == 75:
|
||||
SUPPORT_SM75 = True
|
||||
|
||||
FILE_HEAD_COMMENT = """
|
||||
// auto generated by generate_kernels.py
|
||||
@ -157,6 +163,7 @@ def remove_old_kernels():
|
||||
|
||||
def generate_new_kernels():
|
||||
result_dict = {}
|
||||
sm_75_result_dict = {}
|
||||
|
||||
for quant_config in QUANT_CONFIGS:
|
||||
c_types = quant_config.get("c_type", ["kFloat16", "kBFloat16"])
|
||||
@ -174,6 +181,8 @@ def generate_new_kernels():
|
||||
s_type = quant_config.get("s_type", c_type)
|
||||
if (a_type, b_type, c_type) not in result_dict:
|
||||
result_dict[(a_type, b_type, c_type)] = []
|
||||
if a_type in ["kFloat16", "kS8"] and c_type == "kFloat16":
|
||||
sm_75_result_dict[(a_type, b_type, c_type)] = []
|
||||
|
||||
for group_blocks, m_blocks, thread_configs in itertools.product(
|
||||
all_group_blocks, all_m_blocks, all_thread_configs
|
||||
@ -197,78 +206,89 @@ def generate_new_kernels():
|
||||
"thread_k_blocks": thread_k // 16,
|
||||
"thread_n_blocks": thread_n // 16,
|
||||
"m_block_size_8": "true" if m_blocks == 0.5 else "false",
|
||||
"stages": "pipe_stages",
|
||||
"stages": 4,
|
||||
"group_blocks": group_blocks,
|
||||
"is_zp_float": "false",
|
||||
}
|
||||
|
||||
result_dict[(a_type, b_type, c_type)].append(config)
|
||||
if SUPPORT_SM80:
|
||||
result_dict[(a_type, b_type, c_type)].append(config)
|
||||
if (a_type, b_type, c_type) in sm_75_result_dict and SUPPORT_SM75:
|
||||
config_sm75 = config.copy()
|
||||
config_sm75["stages"] = 2
|
||||
sm_75_result_dict[(a_type, b_type, c_type)].append(config_sm75)
|
||||
|
||||
kernel_selector_str = FILE_HEAD_COMMENT
|
||||
|
||||
for (a_type, b_type, c_type), config_list in result_dict.items():
|
||||
all_template_str_list = []
|
||||
for config in config_list:
|
||||
s_type = config["s_type"]
|
||||
template_str = jinja2.Template(TEMPLATE).render(
|
||||
a_type_id=f"vllm::{a_type}.id()",
|
||||
b_type_id=f"vllm::{b_type}.id()",
|
||||
c_type_id=f"vllm::{c_type}.id()",
|
||||
s_type_id=f"vllm::{s_type}.id()",
|
||||
**config,
|
||||
)
|
||||
all_template_str_list.append(template_str)
|
||||
|
||||
conditions = [
|
||||
f"a_type == vllm::{a_type}",
|
||||
f"b_type == vllm::{b_type}",
|
||||
f"c_type == vllm::{c_type}",
|
||||
f"s_type == vllm::{s_type}",
|
||||
f"threads == {config['threads']}",
|
||||
f"thread_m_blocks == {config['thread_m_blocks']}",
|
||||
f"thread_n_blocks == {config['thread_n_blocks']}",
|
||||
f"thread_k_blocks == {config['thread_k_blocks']}",
|
||||
f"m_block_size_8 == {config['m_block_size_8']}",
|
||||
f"group_blocks == {config['group_blocks']}",
|
||||
f"is_zp_float == {config['is_zp_float']}",
|
||||
]
|
||||
conditions = " && ".join(conditions)
|
||||
|
||||
if kernel_selector_str == FILE_HEAD_COMMENT:
|
||||
kernel_selector_str += f"if ({conditions})\n kernel = "
|
||||
else:
|
||||
kernel_selector_str += f"else if ({conditions})\n kernel = "
|
||||
|
||||
kernel_template2 = (
|
||||
"Marlin<{{a_type_id}}, {{b_type_id}}, {{c_type_id}}, "
|
||||
"{{s_type_id}}, {{threads}}, {{thread_m_blocks}}, "
|
||||
"{{thread_n_blocks}}, {{thread_k_blocks}}, "
|
||||
"{{m_block_size_8}}, {{stages}}, {{group_blocks}}, "
|
||||
"{{is_zp_float}}>;"
|
||||
)
|
||||
|
||||
kernel_selector_str += (
|
||||
jinja2.Template(kernel_template2).render(
|
||||
for result_dict_tmp in [result_dict, sm_75_result_dict]:
|
||||
for (a_type, b_type, c_type), config_list in result_dict_tmp.items():
|
||||
all_template_str_list = []
|
||||
if not config_list:
|
||||
continue
|
||||
for config in config_list:
|
||||
s_type = config["s_type"]
|
||||
template_str = jinja2.Template(TEMPLATE).render(
|
||||
a_type_id=f"vllm::{a_type}.id()",
|
||||
b_type_id=f"vllm::{b_type}.id()",
|
||||
c_type_id=f"vllm::{c_type}.id()",
|
||||
s_type_id=f"vllm::{s_type}.id()",
|
||||
**config,
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
all_template_str_list.append(template_str)
|
||||
|
||||
file_content = FILE_HEAD + "\n\n"
|
||||
file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
|
||||
if a_type == "kFE4M3fn":
|
||||
filename = f"sm89_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
|
||||
else:
|
||||
filename = f"sm80_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
|
||||
conditions = [
|
||||
f"a_type == vllm::{a_type}",
|
||||
f"b_type == vllm::{b_type}",
|
||||
f"c_type == vllm::{c_type}",
|
||||
f"s_type == vllm::{s_type}",
|
||||
f"threads == {config['threads']}",
|
||||
f"thread_m_blocks == {config['thread_m_blocks']}",
|
||||
f"thread_n_blocks == {config['thread_n_blocks']}",
|
||||
f"thread_k_blocks == {config['thread_k_blocks']}",
|
||||
f"m_block_size_8 == {config['m_block_size_8']}",
|
||||
f"stages == {config['stages']}",
|
||||
f"group_blocks == {config['group_blocks']}",
|
||||
f"is_zp_float == {config['is_zp_float']}",
|
||||
]
|
||||
conditions = " && ".join(conditions)
|
||||
|
||||
filename = filename.lower()
|
||||
if kernel_selector_str == FILE_HEAD_COMMENT:
|
||||
kernel_selector_str += f"if ({conditions})\n kernel = "
|
||||
else:
|
||||
kernel_selector_str += f"else if ({conditions})\n kernel = "
|
||||
|
||||
with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
|
||||
f.write(file_content)
|
||||
kernel_template2 = (
|
||||
"Marlin<{{a_type_id}}, {{b_type_id}}, {{c_type_id}}, "
|
||||
"{{s_type_id}}, {{threads}}, {{thread_m_blocks}}, "
|
||||
"{{thread_n_blocks}}, {{thread_k_blocks}}, "
|
||||
"{{m_block_size_8}}, {{stages}}, {{group_blocks}}, "
|
||||
"{{is_zp_float}}>;"
|
||||
)
|
||||
|
||||
kernel_selector_str += (
|
||||
jinja2.Template(kernel_template2).render(
|
||||
a_type_id=f"vllm::{a_type}.id()",
|
||||
b_type_id=f"vllm::{b_type}.id()",
|
||||
c_type_id=f"vllm::{c_type}.id()",
|
||||
s_type_id=f"vllm::{s_type}.id()",
|
||||
**config,
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
file_content = FILE_HEAD + "\n\n"
|
||||
file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
|
||||
if a_type == "kFE4M3fn":
|
||||
filename = f"sm89_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
|
||||
elif result_dict_tmp is sm_75_result_dict:
|
||||
filename = f"sm75_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
|
||||
else:
|
||||
filename = f"sm80_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
|
||||
|
||||
filename = filename.lower()
|
||||
|
||||
with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
|
||||
f.write(file_content)
|
||||
|
||||
if not SUPPORT_FP8 and kernel_selector_str != FILE_HEAD_COMMENT:
|
||||
kernel_selector_str += (
|
||||
|
||||
@ -26,6 +26,7 @@
|
||||
#include "quantization/gptq_marlin/marlin.cuh"
|
||||
#include "quantization/gptq_marlin/marlin_dtypes.cuh"
|
||||
#include "quantization/gptq_marlin/dequant.h"
|
||||
#include "quantization/gptq_marlin/marlin_mma.h"
|
||||
#include "core/scalar_type.hpp"
|
||||
|
||||
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
|
||||
@ -35,7 +36,7 @@
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
|
||||
|
||||
template <typename scalar_t, // compute dtype, half or nv_float16
|
||||
const vllm::ScalarTypeId b_type_id, // weight MarlinScalarType id
|
||||
@ -84,146 +85,6 @@ __global__ void Marlin(
|
||||
|
||||
#else
|
||||
|
||||
// m16n8k16 tensor core mma instruction with fp16 inputs and fp32
|
||||
// output/accumulation.
|
||||
template <vllm::ScalarTypeId type_id, int k_size = 16>
|
||||
__device__ inline void mma(
|
||||
const typename MarlinScalarType<type_id>::FragA& a_frag,
|
||||
const typename MarlinScalarType<type_id>::FragB& frag_b,
|
||||
typename MarlinScalarType<type_id>::FragC& frag_c, int idx = 0) {
|
||||
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
|
||||
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
|
||||
using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
|
||||
if constexpr (k_size == 16) {
|
||||
if constexpr (std::is_same<scalar_t, half>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "f"(c[0]),
|
||||
"f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
|
||||
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "r"(c[0]),
|
||||
"r"(c[1]), "r"(c[2]), "r"(c[3]));
|
||||
}
|
||||
} else if (k_size == 32) {
|
||||
if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
|
||||
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
||||
"r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <vllm::ScalarTypeId type_id, int k_size = 16>
|
||||
__device__ inline void mma_trans(
|
||||
const typename MarlinScalarType<type_id>::FragA& a_frag,
|
||||
const typename MarlinScalarType<type_id>::FragB& frag_b,
|
||||
const typename MarlinScalarType<type_id>::FragB& frag_b2,
|
||||
typename MarlinScalarType<type_id>::FragC& frag_c) {
|
||||
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
|
||||
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
|
||||
const uint32_t* b2 = reinterpret_cast<const uint32_t*>(&frag_b2);
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
|
||||
if constexpr (k_size == 16) {
|
||||
if constexpr (std::is_same<scalar_t, half>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(a[0]), "f"(c[0]), "f"(c[1]), "f"(c[2]),
|
||||
"f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
|
||||
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(a[0]), "r"(c[0]), "r"(c[1]), "r"(c[2]),
|
||||
"r"(c[3]));
|
||||
}
|
||||
} else {
|
||||
if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 1200
|
||||
asm volatile(
|
||||
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
#else
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
#endif
|
||||
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
|
||||
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Instruction for loading a full 16x16 matrix fragment of operand A from shared
|
||||
// memory, directly in tensor core layout.
|
||||
template <int count, vllm::ScalarTypeId type_id>
|
||||
@ -439,9 +300,20 @@ __global__ void Marlin(
|
||||
if constexpr (a_type_id == vllm::kFE4M3fn.id()) return;
|
||||
#endif
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||
// Turing TensorCore only supports fp16 and int8
|
||||
if constexpr (a_type_id != vllm::kFloat16.id() && a_type_id != vllm::kS8.id())
|
||||
return;
|
||||
#endif
|
||||
|
||||
int num_tokens_past_padded = num_tokens_past_padded_ptr[0];
|
||||
constexpr int moe_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks);
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||
constexpr bool use_fp16_accum = a_type_id == vllm::kFloat16.id();
|
||||
#else
|
||||
constexpr bool use_fp16_accum = false;
|
||||
#endif
|
||||
using Adtype = MarlinScalarType<a_type_id>;
|
||||
using Cdtype = MarlinScalarType<c_type_id>;
|
||||
|
||||
@ -618,7 +490,22 @@ __global__ void Marlin(
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||
|
||||
if constexpr (moe_block_size >= 16)
|
||||
local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 16);
|
||||
if constexpr (moe_block_size >= 8)
|
||||
local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 8);
|
||||
if constexpr (moe_block_size >= 4)
|
||||
local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 4);
|
||||
if constexpr (moe_block_size >= 2)
|
||||
local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 2);
|
||||
|
||||
local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 1);
|
||||
block_num_valid_tokens = local_count;
|
||||
#else
|
||||
block_num_valid_tokens = __reduce_add_sync(0xffffffff, local_count);
|
||||
#endif
|
||||
|
||||
if (lane_id == 0)
|
||||
reinterpret_cast<int*>(sh_new)[0] = block_num_valid_tokens;
|
||||
@ -1018,10 +905,6 @@ __global__ void Marlin(
|
||||
constexpr int sh_s_size = has_act_order ? (act_s_max_num_groups * s_sh_stride)
|
||||
: (stages * s_sh_stage);
|
||||
int4* sh_s = sh_zp + (stages * zp_sh_stage);
|
||||
// shared memory reused by reduction should be smaller than
|
||||
// shared memory used by weight.
|
||||
static_assert(thread_m_blocks * 16 * thread_n_blocks * 16 / 8 <=
|
||||
stages * b_sh_stage);
|
||||
int4* sh_a = sh_s + sh_s_size;
|
||||
|
||||
// Register storage for double buffer of shared memory reads.
|
||||
@ -1545,11 +1428,13 @@ __global__ void Marlin(
|
||||
#pragma unroll
|
||||
for (int i = 0; i < thread_m_blocks; i++) {
|
||||
if constexpr (m_block_size_8) {
|
||||
mma_trans<a_type_id>(frag_a[k2][i], frag_b0, frag_b1,
|
||||
frag_c[i][j][0]);
|
||||
mma_trans<a_type_id, use_fp16_accum>(frag_a[k2][i], frag_b0, frag_b1,
|
||||
frag_c[i][j][0]);
|
||||
} else {
|
||||
mma<a_type_id>(frag_a[k2][i], frag_b0, frag_c[i][j][0]);
|
||||
mma<a_type_id>(frag_a[k2][i], frag_b1, frag_c[i][j][1]);
|
||||
mma<a_type_id, use_fp16_accum>(frag_a[k2][i], frag_b0,
|
||||
frag_c[i][j][0]);
|
||||
mma<a_type_id, use_fp16_accum>(frag_a[k2][i], frag_b1,
|
||||
frag_c[i][j][1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1583,10 +1468,12 @@ __global__ void Marlin(
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < thread_m_blocks; i++) {
|
||||
mma<a_type_id, 32>(frag_a[k2][i], frag_b[0],
|
||||
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][0]);
|
||||
mma<a_type_id, 32>(frag_a[k2][i], frag_b[1],
|
||||
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][1]);
|
||||
mma<a_type_id, false, 32>(
|
||||
frag_a[k2][i], frag_b[0],
|
||||
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][0]);
|
||||
mma<a_type_id, false, 32>(
|
||||
frag_a[k2][i], frag_b[1],
|
||||
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][1]);
|
||||
}
|
||||
|
||||
if constexpr (group_blocks != -1) {
|
||||
@ -2132,6 +2019,21 @@ __global__ void Marlin(
|
||||
// While this pattern may not be the most readable, other ways of writing
|
||||
// the loop seemed to noticeably worse performance after compilation.
|
||||
if (slice_iters == 0) {
|
||||
// convert fp16 accum to fp32 for reduction
|
||||
if constexpr (use_fp16_accum) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < (thread_m_blocks * (is_a_8bit ? 2 : 4) * 2); i++) {
|
||||
float* frag_c_part_float = reinterpret_cast<float*>(frag_c) + i * 4;
|
||||
scalar_t* frag_c_part_half =
|
||||
reinterpret_cast<scalar_t*>(frag_c_part_float);
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 3; i >= 0; i--) {
|
||||
frag_c_part_float[i] = Cdtype::num2float(frag_c_part_half[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (is_a_8bit) {
|
||||
float frag_a_s[2 * thread_m_blocks];
|
||||
|
||||
|
||||
@ -142,7 +142,7 @@ typedef struct {
|
||||
|
||||
int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
|
||||
int prob_n, int prob_k, int num_bits, int group_size,
|
||||
bool has_act_order, bool is_k_full) {
|
||||
bool has_act_order, bool is_k_full, int stages) {
|
||||
bool cache_scales_chunk = has_act_order && !is_k_full;
|
||||
|
||||
int tb_n = th_config.thread_n;
|
||||
@ -160,13 +160,13 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
|
||||
|
||||
if (cache_scales_chunk) {
|
||||
int load_groups =
|
||||
tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K
|
||||
tb_groups * stages * 2; // Chunk size is 2x pipeline over dim K
|
||||
load_groups = max(load_groups, 32); // We load at least 32 scale groups
|
||||
return load_groups * tb_n * 2;
|
||||
} else {
|
||||
int tb_scales = tb_groups * tb_n * 2;
|
||||
|
||||
return tb_scales * pipe_stages;
|
||||
return tb_scales * stages;
|
||||
}
|
||||
}
|
||||
|
||||
@ -174,7 +174,7 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8,
|
||||
int thread_m_blocks, int prob_m, int prob_n,
|
||||
int prob_k, int num_bits, int group_size,
|
||||
bool has_act_order, bool is_k_full, int has_zp,
|
||||
int is_zp_float, bool is_a_8bit) {
|
||||
int is_zp_float, bool is_a_8bit, int stages) {
|
||||
int pack_factor = 32 / num_bits;
|
||||
|
||||
// Get B size
|
||||
@ -185,8 +185,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8,
|
||||
// shm size for block_sorted_ids/rd_block_sorted_ids/block_topk_weights
|
||||
// both of them requires tb_m * 4 bytes (tb_m * int32 or tb_m * float32)
|
||||
int sh_block_meta_size = tb_m * 16;
|
||||
int sh_a_size = pipe_stages * (tb_m * tb_k) * (is_a_8bit ? 1 : 2);
|
||||
int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4;
|
||||
int sh_a_size = stages * (tb_m * tb_k) * (is_a_8bit ? 1 : 2);
|
||||
int sh_b_size = stages * (tb_k * tb_n / pack_factor) * 4;
|
||||
int sh_red_size = tb_m * (tb_n + 8) * 2;
|
||||
int sh_bias_size = tb_n * 2;
|
||||
int tmp_size =
|
||||
@ -195,8 +195,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8,
|
||||
|
||||
int sh_s_size =
|
||||
get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
|
||||
group_size, has_act_order, is_k_full);
|
||||
int sh_g_idx_size = has_act_order && !is_k_full ? pipe_stages * tb_k / 4 : 0;
|
||||
group_size, has_act_order, is_k_full, stages);
|
||||
int sh_g_idx_size = has_act_order && !is_k_full ? stages * tb_k / 4 : 0;
|
||||
int sh_zp_size = 0;
|
||||
if (has_zp) {
|
||||
if (is_zp_float)
|
||||
@ -217,7 +217,7 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8,
|
||||
int thread_m_blocks, int prob_m, int prob_n, int prob_k,
|
||||
int num_bits, int group_size, bool has_act_order,
|
||||
bool is_k_full, int has_zp, int is_zp_float,
|
||||
int max_shared_mem, bool is_a_8bit) {
|
||||
bool is_a_8bit, int stages, int max_shared_mem) {
|
||||
// Sanity
|
||||
if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
|
||||
th_config.num_threads == -1) {
|
||||
@ -243,7 +243,7 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8,
|
||||
int cache_size =
|
||||
get_kernel_cache_size(th_config, m_block_size_8, thread_m_blocks, prob_m,
|
||||
prob_n, prob_k, num_bits, group_size, has_act_order,
|
||||
is_k_full, has_zp, is_zp_float, is_a_8bit);
|
||||
is_k_full, has_zp, is_zp_float, is_a_8bit, stages);
|
||||
return cache_size <= max_shared_mem;
|
||||
}
|
||||
|
||||
@ -252,7 +252,7 @@ MarlinFuncPtr get_marlin_kernel(
|
||||
const vllm::ScalarType c_type, const vllm::ScalarType s_type,
|
||||
int thread_m_blocks, int thread_n_blocks, int thread_k_blocks,
|
||||
bool m_block_size_8, bool has_act_order, bool has_zp, int group_blocks,
|
||||
int threads, bool is_zp_float) {
|
||||
int threads, bool is_zp_float, int stages) {
|
||||
int num_bits = b_type.size_bits();
|
||||
auto kernel = MarlinDefault;
|
||||
|
||||
@ -266,8 +266,8 @@ exec_config_t determine_exec_config(
|
||||
const vllm::ScalarType& c_type, const vllm::ScalarType& s_type, int prob_m,
|
||||
int prob_n, int prob_k, int num_experts, int top_k, int thread_m_blocks,
|
||||
bool m_block_size_8, int num_bits, int group_size, bool has_act_order,
|
||||
bool is_k_full, bool has_zp, bool is_zp_float, int max_shared_mem, int sms,
|
||||
bool is_a_8bit) {
|
||||
bool is_k_full, bool has_zp, bool is_zp_float, bool is_a_8bit, int stages,
|
||||
int max_shared_mem, int sms) {
|
||||
exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}};
|
||||
thread_config_t* thread_configs = thread_m_blocks > 1
|
||||
? large_batch_thread_configs
|
||||
@ -284,15 +284,15 @@ exec_config_t determine_exec_config(
|
||||
|
||||
if (!is_valid_config(th_config, m_block_size_8, thread_m_blocks, prob_m,
|
||||
prob_n, prob_k, num_bits, group_size, has_act_order,
|
||||
is_k_full, has_zp, is_zp_float, max_shared_mem - 512,
|
||||
is_a_8bit)) {
|
||||
is_k_full, has_zp, is_zp_float, is_a_8bit, stages,
|
||||
max_shared_mem - 512)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
int cache_size = get_kernel_cache_size(
|
||||
th_config, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k,
|
||||
num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float,
|
||||
is_a_8bit);
|
||||
is_a_8bit, stages);
|
||||
|
||||
int group_blocks = 0;
|
||||
if (!has_act_order) {
|
||||
@ -303,7 +303,7 @@ exec_config_t determine_exec_config(
|
||||
get_marlin_kernel(a_type, b_type, c_type, s_type, thread_m_blocks,
|
||||
th_config.thread_n / 16, th_config.thread_k / 16,
|
||||
m_block_size_8, has_act_order, has_zp, group_blocks,
|
||||
th_config.num_threads, is_zp_float);
|
||||
th_config.num_threads, is_zp_float, stages);
|
||||
|
||||
if (kernel == MarlinDefault) continue;
|
||||
|
||||
@ -433,8 +433,14 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
dev);
|
||||
cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
|
||||
dev);
|
||||
TORCH_CHECK(major_capability * 10 + minor_capability >= 80,
|
||||
"marlin kernel only support Ampere or newer GPUs.");
|
||||
TORCH_CHECK(major_capability * 10 + minor_capability >= 75,
|
||||
"marlin kernel only support Turing or newer GPUs.");
|
||||
int stages = 4;
|
||||
if (major_capability == 7 && minor_capability == 5) {
|
||||
stages = 2;
|
||||
TORCH_CHECK(a_type == vllm::kFloat16 || a_type == vllm::kS8,
|
||||
"Turing only support FP16 or INT8 activation.");
|
||||
}
|
||||
if (a_type == vllm::kFE4M3fn) {
|
||||
TORCH_CHECK(major_capability * 10 + minor_capability >= 89,
|
||||
"FP8 only support Ada Lovelace or newer GPUs.");
|
||||
@ -461,8 +467,8 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
exec_cfg = determine_exec_config(
|
||||
a_type, b_type, c_type, s_type, prob_m, prob_n, prob_k, num_experts,
|
||||
top_k, thread_m_blocks, m_block_size_8, num_bits, group_size,
|
||||
has_act_order, is_k_full, has_zp, is_zp_float, max_shared_mem, sms,
|
||||
is_a_8bit);
|
||||
has_act_order, is_k_full, has_zp, is_zp_float, is_a_8bit, stages,
|
||||
max_shared_mem, sms);
|
||||
thread_tfg = exec_cfg.tb_cfg;
|
||||
}
|
||||
|
||||
@ -479,7 +485,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
TORCH_CHECK(is_valid_config(thread_tfg, m_block_size_8, thread_m_blocks,
|
||||
prob_m, prob_n, prob_k, num_bits, group_size,
|
||||
has_act_order, is_k_full, has_zp, is_zp_float,
|
||||
max_shared_mem, is_a_8bit),
|
||||
is_a_8bit, stages, max_shared_mem),
|
||||
"Invalid thread config: thread_m_blocks = ", thread_m_blocks,
|
||||
", thread_k = ", thread_tfg.thread_k,
|
||||
", thread_n = ", thread_tfg.thread_n,
|
||||
@ -493,12 +499,12 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
int sh_cache_size =
|
||||
get_kernel_cache_size(thread_tfg, m_block_size_8, thread_m_blocks, prob_m,
|
||||
prob_n, prob_k, num_bits, group_size, has_act_order,
|
||||
is_k_full, has_zp, is_zp_float, is_a_8bit);
|
||||
is_k_full, has_zp, is_zp_float, is_a_8bit, stages);
|
||||
|
||||
auto kernel = get_marlin_kernel(
|
||||
a_type, b_type, c_type, s_type, thread_m_blocks, thread_n_blocks,
|
||||
thread_k_blocks, m_block_size_8, has_act_order, has_zp, group_blocks,
|
||||
num_threads, is_zp_float);
|
||||
num_threads, is_zp_float, stages);
|
||||
|
||||
if (kernel == MarlinDefault) {
|
||||
TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n,
|
||||
|
||||
1
csrc/quantization/gptq_marlin/.gitignore
vendored
1
csrc/quantization/gptq_marlin/.gitignore
vendored
@ -1,2 +1,3 @@
|
||||
sm*_kernel_*.cu
|
||||
kernel_selector.h
|
||||
kernel_*.cu
|
||||
|
||||
@ -67,7 +67,7 @@ where `scale_factor * multiplier` can be computed at weight loading.
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
|
||||
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 750
|
||||
// Lookup-table based 3-input logical operation; explicitly used for
|
||||
// dequantization as the compiler does not seem to automatically recognize it in
|
||||
// all cases.
|
||||
|
||||
@ -10,6 +10,8 @@ import jinja2
|
||||
|
||||
ARCHS = []
|
||||
SUPPORT_FP8 = False
|
||||
SUPPORT_SM75 = False
|
||||
SUPPORT_SM80 = False
|
||||
for arch in sys.argv[1].split(","):
|
||||
arch = arch[: arch.index(".") + 2].replace(".", "")
|
||||
arch = int(arch)
|
||||
@ -19,6 +21,10 @@ for arch in sys.argv[1].split(","):
|
||||
# with FP16 MMA, so it cannot achieve any acceleration.
|
||||
if arch in [89, 120]:
|
||||
SUPPORT_FP8 = True
|
||||
if arch >= 80:
|
||||
SUPPORT_SM80 = True
|
||||
if arch == 75:
|
||||
SUPPORT_SM75 = True
|
||||
|
||||
FILE_HEAD_COMMENT = """
|
||||
// auto generated by generate_kernels.py
|
||||
@ -166,6 +172,7 @@ def remove_old_kernels():
|
||||
|
||||
def generate_new_kernels():
|
||||
result_dict = {}
|
||||
sm_75_result_dict = {}
|
||||
|
||||
for quant_config in QUANT_CONFIGS:
|
||||
c_types = quant_config.get("c_type", ["kFloat16", "kBFloat16"])
|
||||
@ -184,6 +191,8 @@ def generate_new_kernels():
|
||||
s_type = quant_config.get("s_type", c_type)
|
||||
if (a_type, b_type, c_type) not in result_dict:
|
||||
result_dict[(a_type, b_type, c_type)] = []
|
||||
if a_type in ["kFloat16", "kS8"] and c_type == "kFloat16":
|
||||
sm_75_result_dict[(a_type, b_type, c_type)] = []
|
||||
|
||||
for group_blocks, m_blocks, thread_configs in itertools.product(
|
||||
all_group_blocks, all_m_blocks, all_thread_configs
|
||||
@ -207,78 +216,89 @@ def generate_new_kernels():
|
||||
"thread_k_blocks": thread_k // 16,
|
||||
"thread_n_blocks": thread_n // 16,
|
||||
"m_block_size_8": "true" if m_blocks == 0.5 else "false",
|
||||
"stages": "pipe_stages",
|
||||
"stages": 4,
|
||||
"group_blocks": group_blocks,
|
||||
"is_zp_float": "true" if is_zp_float else "false",
|
||||
}
|
||||
|
||||
result_dict[(a_type, b_type, c_type)].append(config)
|
||||
if SUPPORT_SM80:
|
||||
result_dict[(a_type, b_type, c_type)].append(config)
|
||||
if (a_type, b_type, c_type) in sm_75_result_dict and SUPPORT_SM75:
|
||||
config_sm75 = config.copy()
|
||||
config_sm75["stages"] = 2
|
||||
sm_75_result_dict[(a_type, b_type, c_type)].append(config_sm75)
|
||||
|
||||
kernel_selector_str = FILE_HEAD_COMMENT
|
||||
|
||||
for (a_type, b_type, c_type), config_list in result_dict.items():
|
||||
all_template_str_list = []
|
||||
for config in config_list:
|
||||
s_type = config["s_type"]
|
||||
template_str = jinja2.Template(TEMPLATE).render(
|
||||
a_type_id=f"vllm::{a_type}.id()",
|
||||
b_type_id=f"vllm::{b_type}.id()",
|
||||
c_type_id=f"vllm::{c_type}.id()",
|
||||
s_type_id=f"vllm::{s_type}.id()",
|
||||
**config,
|
||||
)
|
||||
all_template_str_list.append(template_str)
|
||||
|
||||
conditions = [
|
||||
f"a_type == vllm::{a_type}",
|
||||
f"b_type == vllm::{b_type}",
|
||||
f"c_type == vllm::{c_type}",
|
||||
f"s_type == vllm::{s_type}",
|
||||
f"threads == {config['threads']}",
|
||||
f"thread_m_blocks == {config['thread_m_blocks']}",
|
||||
f"thread_n_blocks == {config['thread_n_blocks']}",
|
||||
f"thread_k_blocks == {config['thread_k_blocks']}",
|
||||
f"m_block_size_8 == {config['m_block_size_8']}",
|
||||
f"group_blocks == {config['group_blocks']}",
|
||||
f"is_zp_float == {config['is_zp_float']}",
|
||||
]
|
||||
conditions = " && ".join(conditions)
|
||||
|
||||
if kernel_selector_str == FILE_HEAD_COMMENT:
|
||||
kernel_selector_str += f"if ({conditions})\n kernel = "
|
||||
else:
|
||||
kernel_selector_str += f"else if ({conditions})\n kernel = "
|
||||
|
||||
kernel_template2 = (
|
||||
"Marlin<{{a_type_id}}, {{b_type_id}}, {{c_type_id}}, "
|
||||
"{{s_type_id}}, {{threads}}, {{thread_m_blocks}}, "
|
||||
"{{thread_n_blocks}}, {{thread_k_blocks}}, "
|
||||
"{{m_block_size_8}}, {{stages}}, {{group_blocks}}, "
|
||||
"{{is_zp_float}}>;"
|
||||
)
|
||||
|
||||
kernel_selector_str += (
|
||||
jinja2.Template(kernel_template2).render(
|
||||
for result_dict_tmp in [result_dict, sm_75_result_dict]:
|
||||
for (a_type, b_type, c_type), config_list in result_dict_tmp.items():
|
||||
all_template_str_list = []
|
||||
if not config_list:
|
||||
continue
|
||||
for config in config_list:
|
||||
s_type = config["s_type"]
|
||||
template_str = jinja2.Template(TEMPLATE).render(
|
||||
a_type_id=f"vllm::{a_type}.id()",
|
||||
b_type_id=f"vllm::{b_type}.id()",
|
||||
c_type_id=f"vllm::{c_type}.id()",
|
||||
s_type_id=f"vllm::{s_type}.id()",
|
||||
**config,
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
all_template_str_list.append(template_str)
|
||||
|
||||
file_content = FILE_HEAD + "\n\n"
|
||||
file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
|
||||
if a_type == "kFE4M3fn":
|
||||
filename = f"sm89_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
|
||||
else:
|
||||
filename = f"sm80_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
|
||||
conditions = [
|
||||
f"a_type == vllm::{a_type}",
|
||||
f"b_type == vllm::{b_type}",
|
||||
f"c_type == vllm::{c_type}",
|
||||
f"s_type == vllm::{s_type}",
|
||||
f"threads == {config['threads']}",
|
||||
f"thread_m_blocks == {config['thread_m_blocks']}",
|
||||
f"thread_n_blocks == {config['thread_n_blocks']}",
|
||||
f"thread_k_blocks == {config['thread_k_blocks']}",
|
||||
f"m_block_size_8 == {config['m_block_size_8']}",
|
||||
f"stages == {config['stages']}",
|
||||
f"group_blocks == {config['group_blocks']}",
|
||||
f"is_zp_float == {config['is_zp_float']}",
|
||||
]
|
||||
conditions = " && ".join(conditions)
|
||||
|
||||
filename = filename.lower()
|
||||
if kernel_selector_str == FILE_HEAD_COMMENT:
|
||||
kernel_selector_str += f"if ({conditions})\n kernel = "
|
||||
else:
|
||||
kernel_selector_str += f"else if ({conditions})\n kernel = "
|
||||
|
||||
with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
|
||||
f.write(file_content)
|
||||
kernel_template2 = (
|
||||
"Marlin<{{a_type_id}}, {{b_type_id}}, {{c_type_id}}, "
|
||||
"{{s_type_id}}, {{threads}}, {{thread_m_blocks}}, "
|
||||
"{{thread_n_blocks}}, {{thread_k_blocks}}, "
|
||||
"{{m_block_size_8}}, {{stages}}, {{group_blocks}}, "
|
||||
"{{is_zp_float}}>;"
|
||||
)
|
||||
|
||||
kernel_selector_str += (
|
||||
jinja2.Template(kernel_template2).render(
|
||||
a_type_id=f"vllm::{a_type}.id()",
|
||||
b_type_id=f"vllm::{b_type}.id()",
|
||||
c_type_id=f"vllm::{c_type}.id()",
|
||||
s_type_id=f"vllm::{s_type}.id()",
|
||||
**config,
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
file_content = FILE_HEAD + "\n\n"
|
||||
file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
|
||||
if a_type == "kFE4M3fn":
|
||||
filename = f"sm89_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
|
||||
elif result_dict_tmp is sm_75_result_dict:
|
||||
filename = f"sm75_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
|
||||
else:
|
||||
filename = f"sm80_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
|
||||
|
||||
filename = filename.lower()
|
||||
|
||||
with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
|
||||
f.write(file_content)
|
||||
|
||||
if not SUPPORT_FP8 and kernel_selector_str != FILE_HEAD_COMMENT:
|
||||
kernel_selector_str += (
|
||||
|
||||
@ -37,7 +37,7 @@ __global__ void MarlinDefault(MARLIN_KERNEL_PARAMS){};
|
||||
|
||||
using MarlinFuncPtr = void (*)(MARLIN_KERNEL_PARAMS);
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
|
||||
|
||||
__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
|
||||
int const* __restrict__ perm_int_ptr,
|
||||
@ -148,7 +148,7 @@ typedef struct {
|
||||
|
||||
int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
|
||||
int prob_n, int prob_k, int num_bits, int group_size,
|
||||
bool has_act_order, bool is_k_full) {
|
||||
bool has_act_order, bool is_k_full, int stages) {
|
||||
bool cache_scales_chunk = has_act_order && !is_k_full;
|
||||
|
||||
int tb_n = th_config.thread_n;
|
||||
@ -166,28 +166,29 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
|
||||
|
||||
if (cache_scales_chunk) {
|
||||
int load_groups =
|
||||
tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K
|
||||
tb_groups * stages * 2; // Chunk size is 2x pipeline over dim K
|
||||
load_groups = max(load_groups, 32); // We load at least 32 scale groups
|
||||
return load_groups * tb_n * 2;
|
||||
} else {
|
||||
int tb_scales = tb_groups * tb_n * 2;
|
||||
|
||||
return tb_scales * pipe_stages;
|
||||
return tb_scales * stages;
|
||||
}
|
||||
}
|
||||
|
||||
int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks,
|
||||
int prob_m, int prob_n, int prob_k, int num_bits,
|
||||
int group_size, bool has_act_order, bool is_k_full,
|
||||
int has_zp, int is_zp_float) {
|
||||
int has_zp, bool is_zp_float, bool is_a_8bit,
|
||||
int stages) {
|
||||
int pack_factor = 32 / num_bits;
|
||||
|
||||
// Get B size
|
||||
int tb_k = th_config.thread_k;
|
||||
int tb_n = th_config.thread_n;
|
||||
int tb_m = thread_m_blocks * 16;
|
||||
int sh_a_size = pipe_stages * (tb_m * tb_k) * 2;
|
||||
int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4;
|
||||
int sh_a_size = stages * (tb_m * tb_k) * (is_a_8bit ? 1 : 2);
|
||||
int sh_b_size = stages * (tb_k * tb_n / pack_factor) * 4;
|
||||
int sh_red_size = tb_m * (tb_n + 8) * 2;
|
||||
int sh_bias_size = tb_n * 2;
|
||||
int tmp_size =
|
||||
@ -196,8 +197,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks,
|
||||
|
||||
int sh_s_size =
|
||||
get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
|
||||
group_size, has_act_order, is_k_full);
|
||||
int sh_g_idx_size = has_act_order && !is_k_full ? pipe_stages * tb_k / 4 : 0;
|
||||
group_size, has_act_order, is_k_full, stages);
|
||||
int sh_g_idx_size = has_act_order && !is_k_full ? stages * tb_k / 4 : 0;
|
||||
int sh_zp_size = 0;
|
||||
if (has_zp) {
|
||||
if (is_zp_float)
|
||||
@ -217,7 +218,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks,
|
||||
bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
|
||||
int prob_m, int prob_n, int prob_k, int num_bits,
|
||||
int group_size, bool has_act_order, bool is_k_full,
|
||||
int has_zp, int is_zp_float, int max_shared_mem) {
|
||||
int has_zp, bool is_zp_float, bool is_a_8bit, int stages,
|
||||
int max_shared_mem) {
|
||||
// Sanity
|
||||
if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
|
||||
th_config.num_threads == -1) {
|
||||
@ -242,7 +244,7 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
|
||||
// Check that pipeline fits into cache
|
||||
int cache_size = get_kernel_cache_size(
|
||||
th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size,
|
||||
has_act_order, is_k_full, has_zp, is_zp_float);
|
||||
has_act_order, is_k_full, has_zp, is_zp_float, is_a_8bit, stages);
|
||||
return cache_size <= max_shared_mem;
|
||||
}
|
||||
|
||||
@ -251,7 +253,7 @@ MarlinFuncPtr get_marlin_kernel(
|
||||
const vllm::ScalarType c_type, const vllm::ScalarType s_type,
|
||||
int thread_m_blocks, int thread_n_blocks, int thread_k_blocks,
|
||||
bool m_block_size_8, bool has_act_order, bool has_zp, int group_blocks,
|
||||
int threads, bool is_zp_float) {
|
||||
int threads, bool is_zp_float, int stages) {
|
||||
int num_bits = b_type.size_bits();
|
||||
auto kernel = MarlinDefault;
|
||||
|
||||
@ -265,7 +267,8 @@ exec_config_t determine_exec_config(
|
||||
const vllm::ScalarType& c_type, const vllm::ScalarType& s_type, int prob_m,
|
||||
int prob_n, int prob_k, int thread_m_blocks, bool m_block_size_8,
|
||||
int num_bits, int group_size, bool has_act_order, bool is_k_full,
|
||||
bool has_zp, bool is_zp_float, int max_shared_mem, int sms) {
|
||||
bool has_zp, bool is_zp_float, int is_a_8bit, int stages,
|
||||
int max_shared_mem, int sms) {
|
||||
exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}};
|
||||
thread_config_t* thread_configs = thread_m_blocks > 1
|
||||
? large_batch_thread_configs
|
||||
@ -280,13 +283,15 @@ exec_config_t determine_exec_config(
|
||||
|
||||
if (!is_valid_config(th_config, thread_m_blocks, prob_m, prob_n, prob_k,
|
||||
num_bits, group_size, has_act_order, is_k_full, has_zp,
|
||||
is_zp_float, max_shared_mem - 512)) {
|
||||
is_zp_float, is_a_8bit, stages,
|
||||
max_shared_mem - 512)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
int cache_size = get_kernel_cache_size(
|
||||
th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits,
|
||||
group_size, has_act_order, is_k_full, has_zp, is_zp_float);
|
||||
int cache_size = get_kernel_cache_size(th_config, thread_m_blocks, prob_m,
|
||||
prob_n, prob_k, num_bits, group_size,
|
||||
has_act_order, is_k_full, has_zp,
|
||||
is_zp_float, is_a_8bit, stages);
|
||||
|
||||
int group_blocks = 0;
|
||||
if (!has_act_order) {
|
||||
@ -297,14 +302,10 @@ exec_config_t determine_exec_config(
|
||||
get_marlin_kernel(a_type, b_type, c_type, s_type, thread_m_blocks,
|
||||
th_config.thread_n / 16, th_config.thread_k / 16,
|
||||
m_block_size_8, has_act_order, has_zp, group_blocks,
|
||||
th_config.num_threads, is_zp_float);
|
||||
th_config.num_threads, is_zp_float, stages);
|
||||
|
||||
if (kernel == MarlinDefault) continue;
|
||||
|
||||
// int m_tiles = div_ceil(prob_m, thread_m_blocks * 16);
|
||||
// int n_tiles = prob_n / th_config.thread_n;
|
||||
// int k_tiles = prob_k / th_config.thread_k;
|
||||
|
||||
return {1, th_config};
|
||||
}
|
||||
|
||||
@ -321,6 +322,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
int group_size, int dev, cudaStream_t stream, int thread_k_init,
|
||||
int thread_n_init, int sms, bool use_atomic_add,
|
||||
bool use_fp32_reduce, bool is_zp_float) {
|
||||
bool is_a_8bit = a_type.size_bits() == 8;
|
||||
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
|
||||
", ", prob_n, ", ", prob_k, "]");
|
||||
|
||||
@ -389,8 +391,14 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
dev);
|
||||
cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
|
||||
dev);
|
||||
TORCH_CHECK(major_capability * 10 + minor_capability >= 80,
|
||||
"marlin kernel only support Ampere or newer GPUs.");
|
||||
TORCH_CHECK(major_capability * 10 + minor_capability >= 75,
|
||||
"marlin kernel only support Turing or newer GPUs.");
|
||||
int stages = 4;
|
||||
if (major_capability == 7 && minor_capability == 5) {
|
||||
stages = 2;
|
||||
TORCH_CHECK(a_type == vllm::kFloat16 || a_type == vllm::kS8,
|
||||
"Turing only support FP16 or INT8 activation.");
|
||||
}
|
||||
if (a_type == vllm::kFE4M3fn) {
|
||||
TORCH_CHECK(
|
||||
major_capability * 10 + minor_capability == 89 ||
|
||||
@ -431,7 +439,8 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
exec_cfg = determine_exec_config(
|
||||
a_type, b_type, c_type, s_type, prob_m_split, prob_n, prob_k,
|
||||
thread_m_blocks, m_block_size_8, num_bits, group_size, has_act_order,
|
||||
is_k_full, has_zp, is_zp_float, max_shared_mem, sms);
|
||||
is_k_full, has_zp, is_zp_float, is_a_8bit, stages, max_shared_mem,
|
||||
sms);
|
||||
thread_tfg = exec_cfg.tb_cfg;
|
||||
if (thread_tfg.thread_n != -1) {
|
||||
if (prob_n / thread_tfg.thread_n *
|
||||
@ -440,7 +449,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
if (is_valid_config({128, 64, 128}, thread_m_blocks, prob_m_split,
|
||||
prob_n, prob_k, num_bits, group_size,
|
||||
has_act_order, is_k_full, has_zp, is_zp_float,
|
||||
max_shared_mem_new)) {
|
||||
is_a_8bit, stages, max_shared_mem_new)) {
|
||||
thread_tfg = {128, 64, 128};
|
||||
exec_cfg = {1, thread_tfg};
|
||||
}
|
||||
@ -466,7 +475,8 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
TORCH_CHECK(
|
||||
is_valid_config(thread_tfg, thread_m_blocks, prob_m_split, prob_n,
|
||||
prob_k, num_bits, group_size, has_act_order, is_k_full,
|
||||
has_zp, is_zp_float, max_shared_mem_new),
|
||||
has_zp, is_zp_float, is_a_8bit, stages,
|
||||
max_shared_mem_new),
|
||||
"Invalid thread config: thread_m_blocks = ", thread_m_blocks,
|
||||
", thread_k = ", thread_tfg.thread_k,
|
||||
", thread_n = ", thread_tfg.thread_n,
|
||||
@ -475,12 +485,12 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
", prob_m_split = ", prob_m_split, ", group_size = ", group_size,
|
||||
", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full,
|
||||
", has_zp = ", has_zp, ", is_zp_float = ", is_zp_float,
|
||||
", max_shared_mem_new = ", max_shared_mem_new);
|
||||
", stages = ", stages, ", max_shared_mem_new = ", max_shared_mem_new);
|
||||
|
||||
auto kernel = get_marlin_kernel(
|
||||
a_type, b_type, c_type, s_type, thread_m_blocks, thread_n_blocks,
|
||||
thread_k_blocks, m_block_size_8, has_act_order, has_zp, group_blocks,
|
||||
num_threads, is_zp_float);
|
||||
num_threads, is_zp_float, stages);
|
||||
|
||||
if (kernel == MarlinDefault) {
|
||||
TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n,
|
||||
|
||||
@ -1,17 +1,19 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/all.h>
|
||||
#ifndef _marlin_cuh
|
||||
#define _marlin_cuh
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <iostream>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <iostream>
|
||||
|
||||
#ifndef MARLIN_NAMESPACE_NAME
|
||||
#define MARLIN_NAMESPACE_NAME marlin
|
||||
#endif
|
||||
#ifndef MARLIN_NAMESPACE_NAME
|
||||
#define MARLIN_NAMESPACE_NAME marlin
|
||||
#endif
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
@ -51,9 +53,51 @@ using I4 = Vec<int, 4>;
|
||||
|
||||
constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
// No support for async
|
||||
#else
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
|
||||
__device__ inline void cp_async1_ca_pred(void* smem_ptr, const void* glob_ptr,
|
||||
bool pred = true) {
|
||||
if (pred) {
|
||||
reinterpret_cast<int32_t*>(smem_ptr)[0] =
|
||||
reinterpret_cast<const int32_t*>(glob_ptr)[0];
|
||||
}
|
||||
}
|
||||
|
||||
__device__ inline void cp_async2_ca_pred(void* smem_ptr, const void* glob_ptr,
|
||||
bool pred = true) {
|
||||
if (pred) {
|
||||
reinterpret_cast<int64_t*>(smem_ptr)[0] =
|
||||
reinterpret_cast<const int64_t*>(glob_ptr)[0];
|
||||
}
|
||||
}
|
||||
|
||||
__device__ inline void cp_async4_ca_pred(void* smem_ptr, const void* glob_ptr,
|
||||
bool pred = true) {
|
||||
if (pred) {
|
||||
reinterpret_cast<int4*>(smem_ptr)[0] =
|
||||
reinterpret_cast<const int4*>(glob_ptr)[0];
|
||||
}
|
||||
}
|
||||
|
||||
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
|
||||
bool pred = true) {
|
||||
if (pred) {
|
||||
reinterpret_cast<int4*>(smem_ptr)[0] =
|
||||
reinterpret_cast<const int4*>(glob_ptr)[0];
|
||||
}
|
||||
}
|
||||
|
||||
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
|
||||
reinterpret_cast<int4*>(smem_ptr)[0] =
|
||||
reinterpret_cast<const int4*>(glob_ptr)[0];
|
||||
}
|
||||
|
||||
__device__ inline void cp_async_fence() {}
|
||||
|
||||
template <int n>
|
||||
__device__ inline void cp_async_wait() {}
|
||||
|
||||
#else
|
||||
|
||||
__device__ inline void cp_async1_ca_pred(void* smem_ptr, const void* glob_ptr,
|
||||
bool pred = true) {
|
||||
@ -126,6 +170,8 @@ __device__ inline void cp_async_wait() {
|
||||
asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
|
||||
}
|
||||
|
||||
#endif
|
||||
#endif
|
||||
|
||||
} // namespace MARLIN_NAMESPACE_NAME
|
||||
|
||||
#endif
|
||||
269
csrc/quantization/gptq_marlin/marlin_mma.h
Normal file
269
csrc/quantization/gptq_marlin/marlin_mma.h
Normal file
@ -0,0 +1,269 @@
|
||||
|
||||
#include "marlin_dtypes.cuh"
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
// m16n8k16 tensor core mma instruction with fp16 inputs and fp32
|
||||
// output/accumulation.
|
||||
template <vllm::ScalarTypeId type_id, bool use_fp16_accum, int k_size = 16>
|
||||
__device__ inline void mma(
|
||||
const typename MarlinScalarType<type_id>::FragA& a_frag,
|
||||
const typename MarlinScalarType<type_id>::FragB& frag_b,
|
||||
typename MarlinScalarType<type_id>::FragC& frag_c, int idx = 0) {
|
||||
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
|
||||
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
|
||||
using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
|
||||
if constexpr (!std::is_same<scalar_t, half>::value || k_size != 16) {
|
||||
static_assert(!use_fp16_accum);
|
||||
}
|
||||
|
||||
if constexpr (k_size == 16) {
|
||||
if constexpr (std::is_same<scalar_t, half>::value && !use_fp16_accum) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(b[0]), "f"(c[0]), "f"(c[1]), "f"(c[2]),
|
||||
"f"(c[3]));
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[2]), "r"(a[3]), "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]),
|
||||
"f"(c[3]));
|
||||
#else
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
#endif
|
||||
} else if constexpr (std::is_same<scalar_t, half>::value &&
|
||||
use_fp16_accum) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||
uint32_t* c = reinterpret_cast<uint32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
|
||||
"{%0,%1}, {%2,%3}, {%4}, {%5,%6};\n"
|
||||
: "=r"(c[0]), "=r"(c[1])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(b[0]), "r"(c[0]), "r"(c[1]));
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
|
||||
"{%0,%1}, {%2,%3}, {%4}, {%5,%6};\n"
|
||||
: "=r"(c[0]), "=r"(c[1])
|
||||
: "r"(a[2]), "r"(a[3]), "r"(b[1]), "r"(c[0]), "r"(c[1]));
|
||||
#else
|
||||
uint32_t* c = reinterpret_cast<uint32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
|
||||
"{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n"
|
||||
: "=r"(c[0]), "=r"(c[1])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
||||
"r"(c[0]), "r"(c[1]));
|
||||
#endif
|
||||
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "f"(c[0]),
|
||||
"f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
|
||||
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "r"(c[0]),
|
||||
"r"(c[1]), "r"(c[2]), "r"(c[3]));
|
||||
}
|
||||
} else if (k_size == 32) {
|
||||
if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
|
||||
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1}, {%2}, {%3}, {%4,%5};\n"
|
||||
: "=r"(c[0]), "=r"(c[1])
|
||||
: "r"(a[0]), "r"(b[0]), "r"(c[0]), "r"(c[1]));
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1}, {%2}, {%3}, {%4,%5};\n"
|
||||
: "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(a[1]), "r"(b[0]), "r"(c[2]), "r"(c[3]));
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1}, {%2}, {%3}, {%4,%5};\n"
|
||||
: "=r"(c[0]), "=r"(c[1])
|
||||
: "r"(a[2]), "r"(b[1]), "r"(c[0]), "r"(c[1]));
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1}, {%2}, {%3}, {%4,%5};\n"
|
||||
: "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(a[3]), "r"(b[1]), "r"(c[2]), "r"(c[3]));
|
||||
#else
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
||||
"r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <vllm::ScalarTypeId type_id, bool use_fp16_accum, int k_size = 16>
|
||||
__device__ inline void mma_trans(
|
||||
const typename MarlinScalarType<type_id>::FragA& a_frag,
|
||||
const typename MarlinScalarType<type_id>::FragB& frag_b,
|
||||
const typename MarlinScalarType<type_id>::FragB& frag_b2,
|
||||
typename MarlinScalarType<type_id>::FragC& frag_c) {
|
||||
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
|
||||
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
|
||||
const uint32_t* b2 = reinterpret_cast<const uint32_t*>(&frag_b2);
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
|
||||
if constexpr (!std::is_same<scalar_t, half>::value || k_size != 16) {
|
||||
static_assert(!use_fp16_accum);
|
||||
}
|
||||
|
||||
if constexpr (k_size == 16) {
|
||||
if constexpr (std::is_same<scalar_t, half>::value && !use_fp16_accum) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(a[0]), "f"(c[0]), "f"(c[1]), "f"(c[2]),
|
||||
"f"(c[3]));
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[1]), "r"(b2[1]), "r"(a[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]),
|
||||
"f"(c[3]));
|
||||
#else
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
#endif
|
||||
} else if constexpr (std::is_same<scalar_t, half>::value &&
|
||||
use_fp16_accum) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||
uint32_t* c = reinterpret_cast<uint32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
|
||||
"{%0,%1}, {%2,%3}, {%4}, {%5,%6};\n"
|
||||
: "=r"(c[0]), "=r"(c[1])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(a[0]), "r"(c[0]), "r"(c[1]));
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
|
||||
"{%0,%1}, {%2,%3}, {%4}, {%5,%6};\n"
|
||||
: "=r"(c[0]), "=r"(c[1])
|
||||
: "r"(b[1]), "r"(b2[1]), "r"(a[1]), "r"(c[0]), "r"(c[1]));
|
||||
#else
|
||||
uint32_t* c = reinterpret_cast<uint32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
|
||||
"{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n"
|
||||
: "=r"(c[0]), "=r"(c[1])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"r"(c[0]), "r"(c[1]));
|
||||
#endif
|
||||
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(a[0]), "f"(c[0]), "f"(c[1]), "f"(c[2]),
|
||||
"f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
|
||||
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(a[0]), "r"(c[0]), "r"(c[1]), "r"(c[2]),
|
||||
"r"(c[3]));
|
||||
}
|
||||
} else {
|
||||
if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
|
||||
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1}, {%2}, {%3}, {%4,%5};\n"
|
||||
: "=r"(c[0]), "=r"(c[1])
|
||||
: "r"(b[0]), "r"(a[0]), "r"(c[0]), "r"(c[1]));
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1}, {%2}, {%3}, {%4,%5};\n"
|
||||
: "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(b2[1]), "r"(a[0]), "r"(c[2]), "r"(c[3]));
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1}, {%2}, {%3}, {%4,%5};\n"
|
||||
: "=r"(c[0]), "=r"(c[1])
|
||||
: "r"(b[0]), "r"(a[1]), "r"(c[0]), "r"(c[1]));
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1}, {%2}, {%3}, {%4,%5};\n"
|
||||
: "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(b2[1]), "r"(a[1]), "r"(c[2]), "r"(c[3]));
|
||||
#else
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace MARLIN_NAMESPACE_NAME
|
||||
@ -26,6 +26,7 @@
|
||||
#include "marlin.cuh"
|
||||
#include "marlin_dtypes.cuh"
|
||||
#include "dequant.h"
|
||||
#include "marlin_mma.h"
|
||||
#include "core/scalar_type.hpp"
|
||||
|
||||
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
|
||||
@ -35,7 +36,7 @@
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
|
||||
|
||||
template <typename scalar_t, // compute dtype, half or nv_float16
|
||||
const vllm::ScalarTypeId b_type_id, // weight MarlinScalarType id
|
||||
@ -75,137 +76,6 @@ __global__ void Marlin(
|
||||
|
||||
#else
|
||||
|
||||
// m16n8k16 tensor core mma instruction with fp16 inputs and fp32
|
||||
// output/accumulation.
|
||||
template <vllm::ScalarTypeId type_id, int k_size = 16>
|
||||
__device__ inline void mma(
|
||||
const typename MarlinScalarType<type_id>::FragA& a_frag,
|
||||
const typename MarlinScalarType<type_id>::FragB& frag_b,
|
||||
typename MarlinScalarType<type_id>::FragC& frag_c, int idx = 0) {
|
||||
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
|
||||
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
|
||||
using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
|
||||
if constexpr (k_size == 16) {
|
||||
if constexpr (std::is_same<scalar_t, half>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "f"(c[0]),
|
||||
"f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
|
||||
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "r"(c[0]),
|
||||
"r"(c[1]), "r"(c[2]), "r"(c[3]));
|
||||
}
|
||||
} else if (k_size == 32) {
|
||||
if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
|
||||
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
||||
"r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <vllm::ScalarTypeId type_id, int k_size = 16>
|
||||
__device__ inline void mma_trans(
|
||||
const typename MarlinScalarType<type_id>::FragA& a_frag,
|
||||
const typename MarlinScalarType<type_id>::FragB& frag_b,
|
||||
const typename MarlinScalarType<type_id>::FragB& frag_b2,
|
||||
typename MarlinScalarType<type_id>::FragC& frag_c) {
|
||||
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
|
||||
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
|
||||
const uint32_t* b2 = reinterpret_cast<const uint32_t*>(&frag_b2);
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
|
||||
if constexpr (k_size == 16) {
|
||||
if constexpr (std::is_same<scalar_t, half>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(a[0]), "f"(c[0]), "f"(c[1]), "f"(c[2]),
|
||||
"f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
|
||||
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
|
||||
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(a[0]), "r"(c[0]), "r"(c[1]), "r"(c[2]),
|
||||
"r"(c[3]));
|
||||
}
|
||||
} else {
|
||||
if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
||||
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
|
||||
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite "
|
||||
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
||||
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
|
||||
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
|
||||
"r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Instruction for loading a full 16x16 matrix fragment of operand A from shared
|
||||
// memory, directly in tensor core layout.
|
||||
template <int count, vllm::ScalarTypeId type_id>
|
||||
@ -415,6 +285,17 @@ __global__ void Marlin(
|
||||
if constexpr (a_type_id == vllm::kFE4M3fn.id()) return;
|
||||
#endif
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||
// Turing TensorCore only supports fp16 and int8
|
||||
if constexpr (a_type_id != vllm::kFloat16.id() && a_type_id != vllm::kS8.id())
|
||||
return;
|
||||
#endif
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||
constexpr bool use_fp16_accum = a_type_id == vllm::kFloat16.id();
|
||||
#else
|
||||
constexpr bool use_fp16_accum = false;
|
||||
#endif
|
||||
using Adtype = MarlinScalarType<a_type_id>;
|
||||
using Cdtype = MarlinScalarType<c_type_id>;
|
||||
const int4* A = A0;
|
||||
@ -873,10 +754,6 @@ __global__ void Marlin(
|
||||
constexpr int sh_s_size = has_act_order ? (act_s_max_num_groups * s_sh_stride)
|
||||
: (stages * s_sh_stage);
|
||||
int4* sh_s = sh_zp + (stages * zp_sh_stage);
|
||||
// shared memory reused by reduction should be smaller than
|
||||
// shared memory used by weight.
|
||||
static_assert(thread_m_blocks * 16 * thread_n_blocks * 16 / 8 <=
|
||||
stages * b_sh_stage);
|
||||
int4* sh_a = sh_s + sh_s_size;
|
||||
|
||||
// Register storage for double buffer of shared memory reads.
|
||||
@ -1395,11 +1272,13 @@ __global__ void Marlin(
|
||||
#pragma unroll
|
||||
for (int i = 0; i < thread_m_blocks; i++) {
|
||||
if constexpr (m_block_size_8) {
|
||||
mma_trans<a_type_id>(frag_a[k2][i], frag_b0, frag_b1,
|
||||
frag_c[i][j][0]);
|
||||
mma_trans<a_type_id, use_fp16_accum>(frag_a[k2][i], frag_b0, frag_b1,
|
||||
frag_c[i][j][0]);
|
||||
} else {
|
||||
mma<a_type_id>(frag_a[k2][i], frag_b0, frag_c[i][j][0]);
|
||||
mma<a_type_id>(frag_a[k2][i], frag_b1, frag_c[i][j][1]);
|
||||
mma<a_type_id, use_fp16_accum>(frag_a[k2][i], frag_b0,
|
||||
frag_c[i][j][0]);
|
||||
mma<a_type_id, use_fp16_accum>(frag_a[k2][i], frag_b1,
|
||||
frag_c[i][j][1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1433,10 +1312,12 @@ __global__ void Marlin(
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < thread_m_blocks; i++) {
|
||||
mma<a_type_id, 32>(frag_a[k2][i], frag_b[0],
|
||||
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][0]);
|
||||
mma<a_type_id, 32>(frag_a[k2][i], frag_b[1],
|
||||
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][1]);
|
||||
mma<a_type_id, false, 32>(
|
||||
frag_a[k2][i], frag_b[0],
|
||||
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][0]);
|
||||
mma<a_type_id, false, 32>(
|
||||
frag_a[k2][i], frag_b[1],
|
||||
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][1]);
|
||||
}
|
||||
|
||||
if constexpr (group_blocks != -1) {
|
||||
@ -1956,6 +1837,21 @@ __global__ void Marlin(
|
||||
// While this pattern may not be the most readable, other ways of writing
|
||||
// the loop seemed to noticeably worse performance after compilation.
|
||||
if (slice_iters == 0) {
|
||||
// convert fp16 accum to fp32 for reduction
|
||||
if constexpr (use_fp16_accum) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < (thread_m_blocks * (is_a_8bit ? 2 : 4) * 2); i++) {
|
||||
float* frag_c_part_float = reinterpret_cast<float*>(frag_c) + i * 4;
|
||||
scalar_t* frag_c_part_half =
|
||||
reinterpret_cast<scalar_t*>(frag_c_part_float);
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 3; i >= 0; i--) {
|
||||
frag_c_part_float[i] = Cdtype::num2float(frag_c_part_half[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (is_a_8bit) {
|
||||
float frag_a_s[2 * thread_m_blocks];
|
||||
|
||||
|
||||
@ -121,7 +121,7 @@ class AWQMarlinConfig(QuantizationConfig):
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 80
|
||||
return 75
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
|
||||
@ -253,7 +253,7 @@ class Fp8Config(QuantizationConfig):
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 80
|
||||
return 75
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
|
||||
@ -181,7 +181,7 @@ class GPTQMarlinConfig(QuantizationConfig):
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 80
|
||||
return 75
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
|
||||
@ -871,7 +871,7 @@ class ModelOptNvFp4Config(ModelOptQuantConfigBase):
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 80
|
||||
return 75
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user