mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-31 07:47:03 +08:00
[Kernel] some optimizations for dense marlin and moe marlin (#16850)
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
This commit is contained in:
parent
f62cad6431
commit
1d0c9d6b2d
@ -301,8 +301,52 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
|||||||
# are not supported by Machete yet.
|
# are not supported by Machete yet.
|
||||||
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}")
|
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}")
|
||||||
if (MARLIN_ARCHS)
|
if (MARLIN_ARCHS)
|
||||||
|
|
||||||
|
#
|
||||||
|
# For the Marlin kernels we automatically generate sources for various
|
||||||
|
# preselected input type pairs and schedules.
|
||||||
|
# Generate sources:
|
||||||
|
set(MARLIN_GEN_SCRIPT
    "${CMAKE_CURRENT_SOURCE_DIR}/csrc/quantization/gptq_marlin/generate_kernels.py")

# Hash the generator script and compare it with the hash stored in the CMake
# cache by the last successful run, so generation is skipped when the script
# has not changed. The normal variable holds the current hash; $CACHE{...}
# reads the previously recorded one.
file(MD5 "${MARLIN_GEN_SCRIPT}" MARLIN_GEN_SCRIPT_HASH)

message(STATUS "Marlin generation script hash: ${MARLIN_GEN_SCRIPT_HASH}")
message(STATUS "Last run Marlin generate script hash: $CACHE{MARLIN_GEN_SCRIPT_HASH}")

# Both operands are quoted so that neither is re-interpreted as a variable
# name by if() and so an empty cache entry cannot produce a malformed
# condition.
if (NOT DEFINED CACHE{MARLIN_GEN_SCRIPT_HASH}
    OR NOT "$CACHE{MARLIN_GEN_SCRIPT_HASH}" STREQUAL "${MARLIN_GEN_SCRIPT_HASH}")
  # execute_process() does not go through a shell, so a bare `$PYTHONPATH`
  # would reach the child as that literal string; $ENV{PYTHONPATH} forwards
  # the configure-time value instead.
  #
  # stdout and stderr are both redirected to the log file. No OUTPUT_VARIABLE
  # is requested: combining it with OUTPUT_FILE has unspecified precedence, and
  # it must never share a name with RESULT_VARIABLE, which is checked below.
  execute_process(
    COMMAND ${CMAKE_COMMAND} -E env
    "PYTHONPATH=$ENV{PYTHONPATH}"
    "${Python_EXECUTABLE}" "${MARLIN_GEN_SCRIPT}"
    RESULT_VARIABLE marlin_generation_result
    OUTPUT_FILE "${CMAKE_CURRENT_BINARY_DIR}/marlin_generation.log"
    ERROR_FILE "${CMAKE_CURRENT_BINARY_DIR}/marlin_generation.log"
  )

  if (NOT marlin_generation_result EQUAL 0)
    message(FATAL_ERROR "Marlin generation failed."
                        " Result: \"${marlin_generation_result}\""
                        "\nCheck the log for details: "
                        "${CMAKE_CURRENT_BINARY_DIR}/marlin_generation.log")
  else()
    # Record the hash only after a successful run so that a failed generation
    # is retried on the next configure. FORCE is required to overwrite the
    # previously cached hash.
    set(MARLIN_GEN_SCRIPT_HASH ${MARLIN_GEN_SCRIPT_HASH}
        CACHE STRING "Last run Marlin generate script hash" FORCE)
    message(STATUS "Marlin generation completed successfully.")
  endif()
else()
  message(STATUS "Marlin generation script has not changed, skipping generation.")
endif()

# Collect the generated kernel sources (kernel_*.cu next to the generator
# script), give them the Marlin gencode flags and add them to the extension.
file(GLOB MARLIN_TEMPLATE_KERNEL_SRC
     "${CMAKE_CURRENT_SOURCE_DIR}/csrc/quantization/gptq_marlin/kernel_*.cu")
set_gencode_flags_for_srcs(
  SRCS "${MARLIN_TEMPLATE_KERNEL_SRC}"
  CUDA_ARCHS "${MARLIN_ARCHS}")

list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC})
|
||||||
|
|
||||||
set(MARLIN_SRCS
|
set(MARLIN_SRCS
|
||||||
"csrc/quantization/fp8/fp8_marlin.cu"
|
|
||||||
"csrc/quantization/marlin/dense/marlin_cuda_kernel.cu"
|
"csrc/quantization/marlin/dense/marlin_cuda_kernel.cu"
|
||||||
"csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
|
"csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
|
||||||
"csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu"
|
"csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu"
|
||||||
@ -644,7 +688,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
|||||||
OR NOT $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH} STREQUAL ${MOE_MARLIN_GEN_SCRIPT_HASH})
|
OR NOT $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH} STREQUAL ${MOE_MARLIN_GEN_SCRIPT_HASH})
|
||||||
execute_process(
|
execute_process(
|
||||||
COMMAND ${CMAKE_COMMAND} -E env
|
COMMAND ${CMAKE_COMMAND} -E env
|
||||||
PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/csrc/cutlass_extensions/:${CUTLASS_DIR}/python/:${VLLM_PYTHON_PATH}:$PYTHONPATH
|
PYTHONPATH=$PYTHONPATH
|
||||||
${Python_EXECUTABLE} ${MOE_MARLIN_GEN_SCRIPT}
|
${Python_EXECUTABLE} ${MOE_MARLIN_GEN_SCRIPT}
|
||||||
RESULT_VARIABLE moe_marlin_generation_result
|
RESULT_VARIABLE moe_marlin_generation_result
|
||||||
OUTPUT_VARIABLE moe_marlin_generation_output
|
OUTPUT_VARIABLE moe_marlin_generation_output
|
||||||
|
|||||||
1
csrc/moe/marlin_moe_wna16/.gitignore
vendored
Normal file
1
csrc/moe/marlin_moe_wna16/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
|||||||
|
kernel_*.cu
|
||||||
@ -25,15 +25,13 @@ TEMPLATE = ("template __global__ void Marlin<"
|
|||||||
"{{thread_k_blocks}}, "
|
"{{thread_k_blocks}}, "
|
||||||
"{{'true' if m_block_size_8 else 'false'}}, "
|
"{{'true' if m_block_size_8 else 'false'}}, "
|
||||||
"{{stages}}, "
|
"{{stages}}, "
|
||||||
"{{'true' if has_act_order else 'false'}}, "
|
|
||||||
"{{'true' if has_zp else 'false'}}, "
|
|
||||||
"{{group_blocks}}, "
|
"{{group_blocks}}, "
|
||||||
"{{'true' if is_zp_float else 'false'}}>"
|
"{{'true' if is_zp_float else 'false'}}>"
|
||||||
"( MARLIN_KERNEL_PARAMS );")
|
"( MARLIN_KERNEL_PARAMS );")
|
||||||
|
|
||||||
# int8 with zero point case (vllm::kU8) is also supported,
|
# int8 with zero point case (vllm::kU8) is also supported,
|
||||||
# we don't add it to reduce wheel size.
|
# we don't add it to reduce wheel size.
|
||||||
SCALAR_TYPES = ["vllm::kU4", "vllm::kU4B8", "vllm::kU8B128"]
|
SCALAR_TYPES = ["vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn"]
|
||||||
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)]
|
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)]
|
||||||
|
|
||||||
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
|
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
|
||||||
@ -52,21 +50,29 @@ def remove_old_kernels():
|
|||||||
|
|
||||||
def generate_new_kernels():
|
def generate_new_kernels():
|
||||||
for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES):
|
for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES):
|
||||||
has_zp = "B" not in scalar_type
|
|
||||||
all_template_str_list = []
|
all_template_str_list = []
|
||||||
|
|
||||||
for group_blocks, m_blocks, thread_configs in itertools.product(
|
for group_blocks, m_blocks, thread_configs in itertools.product(
|
||||||
GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS):
|
GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS):
|
||||||
|
|
||||||
has_act_order = group_blocks == 0
|
# the act_order case is only supported for gptq-int4 and gptq-int8
|
||||||
if has_zp and has_act_order:
|
if group_blocks == 0 and scalar_type not in [
|
||||||
|
"vllm::kU4B8", "vllm::kU8B128"
|
||||||
|
]:
|
||||||
continue
|
continue
|
||||||
if thread_configs[2] == 256:
|
if thread_configs[2] == 256:
|
||||||
|
# for small batch (m_blocks == 1), we only need (128, 128, 256)
|
||||||
|
# for large batch (m_blocks > 1), we only need (64, 256, 256)
|
||||||
if m_blocks <= 1 and thread_configs[0] != 128:
|
if m_blocks <= 1 and thread_configs[0] != 128:
|
||||||
continue
|
continue
|
||||||
if m_blocks > 1 and thread_configs[0] != 64:
|
if m_blocks > 1 and thread_configs[0] != 64:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# we only support channelwise quantization and group_size == 128
|
||||||
|
# for fp8
|
||||||
|
if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]:
|
||||||
|
continue
|
||||||
|
|
||||||
k_blocks = thread_configs[0] // 16
|
k_blocks = thread_configs[0] // 16
|
||||||
n_blocks = thread_configs[1] // 16
|
n_blocks = thread_configs[1] // 16
|
||||||
threads = thread_configs[2]
|
threads = thread_configs[2]
|
||||||
@ -82,8 +88,6 @@ def generate_new_kernels():
|
|||||||
thread_k_blocks=k_blocks,
|
thread_k_blocks=k_blocks,
|
||||||
m_block_size_8=m_blocks == 0.5,
|
m_block_size_8=m_blocks == 0.5,
|
||||||
stages="pipe_stages",
|
stages="pipe_stages",
|
||||||
has_act_order=has_act_order,
|
|
||||||
has_zp=has_zp,
|
|
||||||
group_blocks=group_blocks,
|
group_blocks=group_blocks,
|
||||||
is_zp_float=False,
|
is_zp_float=False,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -18,7 +18,7 @@
|
|||||||
const float *__restrict__ topk_weights_ptr, int top_k, \
|
const float *__restrict__ topk_weights_ptr, int top_k, \
|
||||||
bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \
|
bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \
|
||||||
int prob_n, int prob_k, int *locks, bool use_atomic_add, \
|
int prob_n, int prob_k, int *locks, bool use_atomic_add, \
|
||||||
bool use_fp32_reduce
|
bool use_fp32_reduce, int max_shared_mem
|
||||||
|
|
||||||
namespace MARLIN_NAMESPACE_NAME {
|
namespace MARLIN_NAMESPACE_NAME {
|
||||||
template <typename scalar_t, // compute dtype, half or nv_float16
|
template <typename scalar_t, // compute dtype, half or nv_float16
|
||||||
@ -33,11 +33,9 @@ template <typename scalar_t, // compute dtype, half or nv_float16
|
|||||||
// only works when thread_m_blocks == 1
|
// only works when thread_m_blocks == 1
|
||||||
const int stages, // number of stages for the async global->shared
|
const int stages, // number of stages for the async global->shared
|
||||||
// fetch pipeline
|
// fetch pipeline
|
||||||
const bool has_act_order, // whether act_order is enabled
|
const int group_blocks, // number of consecutive 16x16 blocks
|
||||||
const bool has_zp, // whether zero-points are enabled
|
// with a separate quantization scale
|
||||||
const int group_blocks, // number of consecutive 16x16 blocks
|
const bool is_zp_float // is zero point of float16 type?
|
||||||
// with a separate quantization scale
|
|
||||||
const bool is_zp_float // is zero point of float16 type?
|
|
||||||
>
|
>
|
||||||
__global__ void Marlin(MARLIN_KERNEL_PARAMS);
|
__global__ void Marlin(MARLIN_KERNEL_PARAMS);
|
||||||
|
|
||||||
|
|||||||
@ -25,6 +25,7 @@
|
|||||||
|
|
||||||
#include "quantization/gptq_marlin/marlin.cuh"
|
#include "quantization/gptq_marlin/marlin.cuh"
|
||||||
#include "quantization/gptq_marlin/marlin_dtypes.cuh"
|
#include "quantization/gptq_marlin/marlin_dtypes.cuh"
|
||||||
|
#include "quantization/gptq_marlin/dequant.h"
|
||||||
#include "core/scalar_type.hpp"
|
#include "core/scalar_type.hpp"
|
||||||
|
|
||||||
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
|
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
|
||||||
@ -48,11 +49,9 @@ template <typename scalar_t, // compute dtype, half or nv_float16
|
|||||||
// only works when thread_m_blocks == 1
|
// only works when thread_m_blocks == 1
|
||||||
const int stages, // number of stages for the async global->shared
|
const int stages, // number of stages for the async global->shared
|
||||||
// fetch pipeline
|
// fetch pipeline
|
||||||
const bool has_act_order, // whether act_order is enabled
|
const int group_blocks, // number of consecutive 16x16 blocks
|
||||||
const bool has_zp, // whether zero-points are enabled
|
// with a separate quantization scale
|
||||||
const int group_blocks, // number of consecutive 16x16 blocks
|
const bool is_zp_float // is zero point of float16 type?
|
||||||
// with a separate quantization scale
|
|
||||||
const bool is_zp_float // is zero point of float16 type?
|
|
||||||
>
|
>
|
||||||
__global__ void Marlin(
|
__global__ void Marlin(
|
||||||
const int4* __restrict__ A, // fp16 input matrix of shape mxk
|
const int4* __restrict__ A, // fp16 input matrix of shape mxk
|
||||||
@ -77,8 +76,8 @@ __global__ void Marlin(
|
|||||||
int prob_k, // reduction dimension k
|
int prob_k, // reduction dimension k
|
||||||
int* locks, // extra global storage for barrier synchronization
|
int* locks, // extra global storage for barrier synchronization
|
||||||
bool use_atomic_add, // whether to use atomic add to reduce
|
bool use_atomic_add, // whether to use atomic add to reduce
|
||||||
bool use_fp32_reduce // whether to use fp32 global reduce
|
bool use_fp32_reduce, // whether to use fp32 global reduce
|
||||||
) {}
|
int max_shared_mem) {}
|
||||||
|
|
||||||
} // namespace MARLIN_NAMESPACE_NAME
|
} // namespace MARLIN_NAMESPACE_NAME
|
||||||
|
|
||||||
@ -166,144 +165,6 @@ __device__ inline void ldsm(typename ScalarType<scalar_t>::FragA& frag_a,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Lookup-table based 3-input logical operation; explicitly used for
|
|
||||||
// dequantization as the compiler does not seem to automatically recognize it in
|
|
||||||
// all cases.
|
|
||||||
template <int lut>
|
|
||||||
__device__ inline int lop3(int a, int b, int c) {
|
|
||||||
int res;
|
|
||||||
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
|
|
||||||
: "=r"(res)
|
|
||||||
: "r"(a), "r"(b), "r"(c), "n"(lut));
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Constructs destination register by taking bytes from 2 sources (based on
|
|
||||||
// mask)
|
|
||||||
template <int start_byte, int mask>
|
|
||||||
__device__ inline uint32_t prmt(uint32_t a) {
|
|
||||||
uint32_t res;
|
|
||||||
asm volatile("prmt.b32 %0, %1, %2, %3;\n"
|
|
||||||
: "=r"(res)
|
|
||||||
: "r"(a), "n"(start_byte), "n"(mask));
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename scalar_t, int bit>
|
|
||||||
__device__ inline typename ScalarType<scalar_t>::FragB dequant(
|
|
||||||
int q, typename ScalarType<scalar_t>::FragB& frag_b);
|
|
||||||
|
|
||||||
//
|
|
||||||
// Efficiently dequantize 4bit values packed in an int32 value into a full
|
|
||||||
// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below,
|
|
||||||
// with some small changes:
|
|
||||||
// - FP16:
|
|
||||||
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287
|
|
||||||
// - BF16:
|
|
||||||
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385
|
|
||||||
//
|
|
||||||
template <>
|
|
||||||
__device__ inline typename ScalarType<half>::FragB dequant<half, 4>(
|
|
||||||
int q, typename ScalarType<half>::FragB& frag_b) {
|
|
||||||
const int LO = 0x000f000f;
|
|
||||||
const int HI = 0x00f000f0;
|
|
||||||
const int EX = 0x64006400;
|
|
||||||
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
|
||||||
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
|
|
||||||
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
|
|
||||||
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
|
|
||||||
// directly into `SUB` and `ADD`.
|
|
||||||
const int SUB = 0x64086408;
|
|
||||||
const int MUL = 0x2c002c00;
|
|
||||||
const int ADD = 0xd480d480;
|
|
||||||
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
|
|
||||||
*reinterpret_cast<const half2*>(&SUB));
|
|
||||||
frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
|
|
||||||
*reinterpret_cast<const half2*>(&MUL),
|
|
||||||
*reinterpret_cast<const half2*>(&ADD));
|
|
||||||
return frag_b;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
__device__ inline typename ScalarType<nv_bfloat16>::FragB
|
|
||||||
dequant<nv_bfloat16, 4>(int q,
|
|
||||||
typename ScalarType<nv_bfloat16>::FragB& frag_b) {
|
|
||||||
static constexpr uint32_t MASK = 0x000f000f;
|
|
||||||
static constexpr uint32_t EX = 0x43004300;
|
|
||||||
|
|
||||||
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
|
||||||
|
|
||||||
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
|
|
||||||
q >>= 4;
|
|
||||||
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
|
|
||||||
|
|
||||||
static constexpr uint32_t MUL = 0x3F803F80;
|
|
||||||
static constexpr uint32_t ADD = 0xC308C308;
|
|
||||||
|
|
||||||
frag_b[0] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo),
|
|
||||||
*reinterpret_cast<const nv_bfloat162*>(&MUL),
|
|
||||||
*reinterpret_cast<const nv_bfloat162*>(&ADD));
|
|
||||||
frag_b[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi),
|
|
||||||
*reinterpret_cast<const nv_bfloat162*>(&MUL),
|
|
||||||
*reinterpret_cast<const nv_bfloat162*>(&ADD));
|
|
||||||
return frag_b;
|
|
||||||
}
|
|
||||||
|
|
||||||
//
|
|
||||||
// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or
|
|
||||||
// bf16 Reference:
|
|
||||||
// - FP16:
|
|
||||||
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85
|
|
||||||
// - BF16:
|
|
||||||
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175
|
|
||||||
//
|
|
||||||
template <>
|
|
||||||
__device__ inline typename ScalarType<half>::FragB dequant<half, 8>(
|
|
||||||
int q, typename ScalarType<half>::FragB& frag_b) {
|
|
||||||
static constexpr uint32_t mask_for_elt_01 = 0x5250;
|
|
||||||
static constexpr uint32_t mask_for_elt_23 = 0x5351;
|
|
||||||
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
|
|
||||||
|
|
||||||
uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
|
|
||||||
uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
|
|
||||||
|
|
||||||
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
|
|
||||||
|
|
||||||
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
|
|
||||||
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
|
|
||||||
frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
|
|
||||||
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
|
|
||||||
return frag_b;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
__device__ inline typename ScalarType<nv_bfloat16>::FragB
|
|
||||||
dequant<nv_bfloat16, 8>(int q,
|
|
||||||
typename ScalarType<nv_bfloat16>::FragB& frag_b) {
|
|
||||||
float fp32_intermediates[4];
|
|
||||||
uint32_t* fp32_intermediates_casted =
|
|
||||||
reinterpret_cast<uint32_t*>(fp32_intermediates);
|
|
||||||
|
|
||||||
static constexpr uint32_t fp32_base = 0x4B000000;
|
|
||||||
fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
|
|
||||||
fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);
|
|
||||||
fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);
|
|
||||||
fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);
|
|
||||||
|
|
||||||
fp32_intermediates[0] -= 8388736.f;
|
|
||||||
fp32_intermediates[1] -= 8388736.f;
|
|
||||||
fp32_intermediates[2] -= 8388736.f;
|
|
||||||
fp32_intermediates[3] -= 8388736.f;
|
|
||||||
|
|
||||||
uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(&frag_b);
|
|
||||||
bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0],
|
|
||||||
fp32_intermediates_casted[1], 0x7632);
|
|
||||||
bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2],
|
|
||||||
fp32_intermediates_casted[3], 0x7632);
|
|
||||||
|
|
||||||
return frag_b;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Multiply dequantized values by the corresponding quantization scale; used
|
// Multiply dequantized values by the corresponding quantization scale; used
|
||||||
// only for grouped quantization.
|
// only for grouped quantization.
|
||||||
template <typename scalar_t>
|
template <typename scalar_t>
|
||||||
@ -429,11 +290,9 @@ template <typename scalar_t, // compute dtype, half or nv_float16
|
|||||||
// only works when thread_m_blocks == 1
|
// only works when thread_m_blocks == 1
|
||||||
const int stages, // number of stages for the async global->shared
|
const int stages, // number of stages for the async global->shared
|
||||||
// fetch pipeline
|
// fetch pipeline
|
||||||
const bool has_act_order, // whether act_order is enabled
|
const int group_blocks, // number of consecutive 16x16 blocks
|
||||||
const bool has_zp, // whether zero-points are enabled
|
// with a separate quantization scale
|
||||||
const int group_blocks, // number of consecutive 16x16 blocks
|
const bool is_zp_float // is zero point of float16 type?
|
||||||
// with a separate quantization scale
|
|
||||||
const bool is_zp_float // is zero point of float16 type?
|
|
||||||
>
|
>
|
||||||
__global__ void Marlin(
|
__global__ void Marlin(
|
||||||
const int4* __restrict__ A, // fp16 input matrix of shape mxk
|
const int4* __restrict__ A, // fp16 input matrix of shape mxk
|
||||||
@ -458,8 +317,8 @@ __global__ void Marlin(
|
|||||||
int prob_k, // reduction dimension k
|
int prob_k, // reduction dimension k
|
||||||
int* locks, // extra global storage for barrier synchronization
|
int* locks, // extra global storage for barrier synchronization
|
||||||
bool use_atomic_add, // whether to use atomic add to reduce
|
bool use_atomic_add, // whether to use atomic add to reduce
|
||||||
bool use_fp32_reduce // whether to use fp32 global reduce
|
bool use_fp32_reduce, // whether to use fp32 global reduce
|
||||||
) {
|
int max_shared_mem) {
|
||||||
// Each threadblock processes one "stripe" of the B matrix with (roughly) the
|
// Each threadblock processes one "stripe" of the B matrix with (roughly) the
|
||||||
// same size, which might involve multiple column "slices" (of width 16 *
|
// same size, which might involve multiple column "slices" (of width 16 *
|
||||||
// `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM
|
// `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM
|
||||||
@ -481,6 +340,8 @@ __global__ void Marlin(
|
|||||||
|
|
||||||
extern __shared__ int4 sh[];
|
extern __shared__ int4 sh[];
|
||||||
static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id);
|
static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id);
|
||||||
|
constexpr bool has_zp = w_type == vllm::kU4 || w_type == vllm::kU8;
|
||||||
|
constexpr bool has_act_order = group_blocks == 0;
|
||||||
|
|
||||||
constexpr int pack_factor = 32 / w_type.size_bits();
|
constexpr int pack_factor = 32 / w_type.size_bits();
|
||||||
static_assert(thread_m_blocks == 1 || !m_block_size_8);
|
static_assert(thread_m_blocks == 1 || !m_block_size_8);
|
||||||
@ -534,13 +395,20 @@ __global__ void Marlin(
|
|||||||
int64_t B_expert_off = 0;
|
int64_t B_expert_off = 0;
|
||||||
|
|
||||||
int4* sh_block_sorted_ids_int4 = sh;
|
int4* sh_block_sorted_ids_int4 = sh;
|
||||||
|
int4* sh_rd_block_sorted_ids_int4 =
|
||||||
|
sh_block_sorted_ids_int4 + moe_block_size / 4;
|
||||||
|
int4* sh_block_topk_weights_int4 =
|
||||||
|
sh_rd_block_sorted_ids_int4 + moe_block_size / 4;
|
||||||
|
// sh_block_topk_weights_int4 only needs (moe_block_size / 4);
|
||||||
|
// but we pad to align to 256 bytes
|
||||||
|
int4* sh_new =
|
||||||
|
sh_block_topk_weights_int4 + moe_block_size / 2 + moe_block_size;
|
||||||
int32_t* sh_block_sorted_ids =
|
int32_t* sh_block_sorted_ids =
|
||||||
reinterpret_cast<int*>(sh_block_sorted_ids_int4);
|
reinterpret_cast<int*>(sh_block_sorted_ids_int4);
|
||||||
int4* sh_block_topk_weights_int4 =
|
int32_t* sh_rd_block_sorted_ids =
|
||||||
sh_block_sorted_ids_int4 + moe_block_size / 4;
|
reinterpret_cast<int*>(sh_rd_block_sorted_ids_int4);
|
||||||
scalar_t2* sh_block_topk_weights =
|
scalar_t2* sh_block_topk_weights =
|
||||||
reinterpret_cast<scalar_t2*>(sh_block_topk_weights_int4);
|
reinterpret_cast<scalar_t2*>(sh_block_topk_weights_int4);
|
||||||
int4* sh_new = sh_block_topk_weights_int4 + moe_block_size / 4;
|
|
||||||
|
|
||||||
int32_t block_num_valid_tokens = 0;
|
int32_t block_num_valid_tokens = 0;
|
||||||
int32_t locks_off = 0;
|
int32_t locks_off = 0;
|
||||||
@ -584,6 +452,11 @@ __global__ void Marlin(
|
|||||||
sh_block_sorted_ids_int4[tid4] = reinterpret_cast<const int4*>(
|
sh_block_sorted_ids_int4[tid4] = reinterpret_cast<const int4*>(
|
||||||
sorted_token_ids_ptr)[block_id * moe_block_size / 4 + tid4];
|
sorted_token_ids_ptr)[block_id * moe_block_size / 4 + tid4];
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < 4; i++)
|
||||||
|
sh_rd_block_sorted_ids[tid4 * 4 + i] =
|
||||||
|
sh_block_sorted_ids[tid4 * 4 + i] / top_k;
|
||||||
|
|
||||||
if (mul_topk_weights) {
|
if (mul_topk_weights) {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < 4; i++) {
|
for (int i = 0; i < 4; i++) {
|
||||||
@ -743,6 +616,7 @@ __global__ void Marlin(
|
|||||||
constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0;
|
constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0;
|
||||||
// constexpr int act_s_row_stride = 1;
|
// constexpr int act_s_row_stride = 1;
|
||||||
// int act_s_col_stride = act_s_row_stride * num_groups;
|
// int act_s_col_stride = act_s_row_stride * num_groups;
|
||||||
|
constexpr int act_s_max_num_groups = 32;
|
||||||
int act_s_col_stride = 1;
|
int act_s_col_stride = 1;
|
||||||
int act_s_col_warp_stride = act_s_col_stride * 8;
|
int act_s_col_warp_stride = act_s_col_stride * 8;
|
||||||
int tb_n_warps = thread_n_blocks / 4;
|
int tb_n_warps = thread_n_blocks / 4;
|
||||||
@ -758,9 +632,9 @@ __global__ void Marlin(
|
|||||||
int zp_gl_rd_delta = zp_gl_stride;
|
int zp_gl_rd_delta = zp_gl_stride;
|
||||||
|
|
||||||
// Global A read index of current thread.
|
// Global A read index of current thread.
|
||||||
int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
|
int a_gl_rd_row = threadIdx.x / a_gl_rd_delta_o;
|
||||||
(threadIdx.x % a_gl_rd_delta_o);
|
int a_gl_rd_col = a_gl_rd_delta_o * slice_row + threadIdx.x % a_gl_rd_delta_o;
|
||||||
a_gl_rd += a_gl_rd_delta_o * slice_row;
|
|
||||||
// Shared write index of current thread.
|
// Shared write index of current thread.
|
||||||
int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) +
|
int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) +
|
||||||
(threadIdx.x % a_gl_rd_delta_o);
|
(threadIdx.x % a_gl_rd_delta_o);
|
||||||
@ -774,8 +648,8 @@ __global__ void Marlin(
|
|||||||
(threadIdx.x % b_sh_stride_threads) * b_thread_vecs;
|
(threadIdx.x % b_sh_stride_threads) * b_thread_vecs;
|
||||||
b_gl_rd += b_sh_stride * slice_col;
|
b_gl_rd += b_sh_stride * slice_col;
|
||||||
b_gl_rd += b_gl_rd_delta_o * slice_row;
|
b_gl_rd += b_gl_rd_delta_o * slice_row;
|
||||||
int b_sh_wr = threadIdx.x * b_thread_vecs;
|
auto b_sh_wr = threadIdx.x * b_thread_vecs;
|
||||||
int b_sh_rd = threadIdx.x * b_thread_vecs;
|
auto b_sh_rd = threadIdx.x * b_thread_vecs;
|
||||||
|
|
||||||
// For act_order
|
// For act_order
|
||||||
constexpr int k_iter_size = tb_k / b_sh_wr_iters;
|
constexpr int k_iter_size = tb_k / b_sh_wr_iters;
|
||||||
@ -794,7 +668,7 @@ __global__ void Marlin(
|
|||||||
s_sh_stride * slice_col + threadIdx.x;
|
s_sh_stride * slice_col + threadIdx.x;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
int s_sh_wr = threadIdx.x;
|
auto s_sh_wr = threadIdx.x;
|
||||||
bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
|
bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
|
||||||
|
|
||||||
// Zero-points
|
// Zero-points
|
||||||
@ -807,7 +681,7 @@ __global__ void Marlin(
|
|||||||
zp_sh_stride * slice_col + threadIdx.x;
|
zp_sh_stride * slice_col + threadIdx.x;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
int zp_sh_wr = threadIdx.x;
|
auto zp_sh_wr = threadIdx.x;
|
||||||
bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride;
|
bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride;
|
||||||
|
|
||||||
// We use a different scale layout for grouped and column-wise quantization as
|
// We use a different scale layout for grouped and column-wise quantization as
|
||||||
@ -851,7 +725,7 @@ __global__ void Marlin(
|
|||||||
// each warp must also write a consecutive memory segment?
|
// each warp must also write a consecutive memory segment?
|
||||||
auto transform_a = [&](int i) {
|
auto transform_a = [&](int i) {
|
||||||
int row = i / a_gl_rd_delta_o;
|
int row = i / a_gl_rd_delta_o;
|
||||||
return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row;
|
return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ (row % 8);
|
||||||
};
|
};
|
||||||
// Since the computation of this remapping is non-trivial and, due to our main
|
// Since the computation of this remapping is non-trivial and, due to our main
|
||||||
// loop unrolls, all shared memory accesses are static, we simply precompute
|
// loop unrolls, all shared memory accesses are static, we simply precompute
|
||||||
@ -879,12 +753,28 @@ __global__ void Marlin(
|
|||||||
B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;
|
B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;
|
||||||
|
|
||||||
// Shared memory storage for global fetch pipelines.
|
// Shared memory storage for global fetch pipelines.
|
||||||
int4* sh_a = sh_new;
|
constexpr int sh_red_size = (2 * thread_n_blocks + 1) * 16 * thread_m_blocks;
|
||||||
int4* sh_b = sh_a + (stages * a_sh_stage);
|
constexpr int sh_b_size = stages * b_sh_stage;
|
||||||
int4* sh_g_idx = sh_b + (stages * b_sh_stage);
|
int4* sh_b = sh_new;
|
||||||
|
int4* sh_red = sh_new;
|
||||||
|
int4* sh_g_idx = sh_b + (sh_red_size > sh_b_size ? sh_red_size : sh_b_size);
|
||||||
int4* sh_zp = sh_g_idx + (stages * g_idx_stage);
|
int4* sh_zp = sh_g_idx + (stages * g_idx_stage);
|
||||||
|
constexpr int sh_s_size = has_act_order ? (act_s_max_num_groups * s_sh_stride)
|
||||||
|
: (stages * s_sh_stage);
|
||||||
int4* sh_s = sh_zp + (stages * zp_sh_stage);
|
int4* sh_s = sh_zp + (stages * zp_sh_stage);
|
||||||
int4* sh_red = sh_b;
|
// shared memory reused by reduction should be smaller than
|
||||||
|
// shared memory used by weight.
|
||||||
|
static_assert(thread_m_blocks * 16 * thread_n_blocks * 16 / 8 <=
|
||||||
|
stages * b_sh_stage);
|
||||||
|
int4* sh_a = sh_s + sh_s_size;
|
||||||
|
constexpr int shm_size_used =
|
||||||
|
moe_block_size + stages * (g_idx_stage + zp_sh_stage) + sh_s_size +
|
||||||
|
(sh_red_size > sh_b_size ? sh_red_size : sh_b_size);
|
||||||
|
|
||||||
|
// all remaining shared memory is used to cache A (input)
|
||||||
|
// sh_a_max_row is at least ` stages * 16 * thread_m_blocks `
|
||||||
|
int sh_a_max_row =
|
||||||
|
((max_shared_mem - 1024) / 16 - shm_size_used) / (thread_k_blocks * 2);
|
||||||
|
|
||||||
// Register storage for double buffer of shared memory reads.
|
// Register storage for double buffer of shared memory reads.
|
||||||
FragA frag_a[2][thread_m_blocks];
|
FragA frag_a[2][thread_m_blocks];
|
||||||
@ -905,15 +795,14 @@ __global__ void Marlin(
|
|||||||
|
|
||||||
int sh_first_group_id = -1;
|
int sh_first_group_id = -1;
|
||||||
int sh_num_groups = -1;
|
int sh_num_groups = -1;
|
||||||
constexpr int sh_max_num_groups = 32;
|
|
||||||
|
|
||||||
auto fetch_act_order_scales_to_shared = [&](bool is_async, int first_group_id,
|
auto fetch_act_order_scales_to_shared = [&](bool is_async, int first_group_id,
|
||||||
int last_group_id) {
|
int last_group_id) {
|
||||||
sh_first_group_id = first_group_id;
|
sh_first_group_id = first_group_id;
|
||||||
sh_num_groups = last_group_id - first_group_id + 1;
|
sh_num_groups = last_group_id - first_group_id + 1;
|
||||||
|
|
||||||
if (sh_num_groups < sh_max_num_groups) {
|
if (sh_num_groups < act_s_max_num_groups) {
|
||||||
sh_num_groups = sh_max_num_groups;
|
sh_num_groups = act_s_max_num_groups;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (sh_first_group_id + sh_num_groups > num_groups) {
|
if (sh_first_group_id + sh_num_groups > num_groups) {
|
||||||
@ -940,27 +829,31 @@ __global__ void Marlin(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Asynchronously fetch the next A, B and s tile from global to the next
|
// Asynchronously fetch the next A, B and s tile from global to the next
|
||||||
// shared memory pipeline location.
|
// shared memory pipeline location.
|
||||||
int a_remaining_load_count_in_slice = stages;
|
bool should_load_a = true;
|
||||||
auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) {
|
int max_num_stage_groups =
|
||||||
|
((sh_a_max_row - moe_block_size) / moe_block_size + 1) / stages;
|
||||||
|
max_num_stage_groups = max(max_num_stage_groups, 1);
|
||||||
|
auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true,
|
||||||
|
int pipe_a = 0) {
|
||||||
if (pred) {
|
if (pred) {
|
||||||
int4* sh_a_stage = sh_a + a_sh_stage * pipe;
|
if (should_load_a) {
|
||||||
if (prob_k > thread_k_blocks * 16 * stages || slice_col == 0 ||
|
int4* sh_a_stage = sh_a + moe_block_size * a_sh_stride * pipe_a;
|
||||||
a_remaining_load_count_in_slice > 0) {
|
|
||||||
a_remaining_load_count_in_slice--;
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < a_sh_wr_iters; i++) {
|
for (int i = 0; i < a_sh_wr_iters; i++) {
|
||||||
int a_idx = a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off;
|
int row = a_gl_rd_delta_i / a_gl_stride * i + a_gl_rd_row;
|
||||||
int row = a_idx / a_gl_stride;
|
|
||||||
int64_t sorted_row = 0;
|
int64_t sorted_row = 0;
|
||||||
if (!m_block_size_8 || row < 8)
|
if (!m_block_size_8 || row < 8)
|
||||||
sorted_row = sh_block_sorted_ids[row] / top_k;
|
sorted_row = sh_rd_block_sorted_ids[row];
|
||||||
int64_t true_idx = sorted_row * a_gl_stride + a_idx % a_gl_stride;
|
int64_t true_idx =
|
||||||
|
sorted_row * a_gl_stride + a_gl_rd_col + a_gl_rd_delta_o * a_off;
|
||||||
cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]], &A[true_idx],
|
cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]], &A[true_idx],
|
||||||
row < block_num_valid_tokens);
|
row < block_num_valid_tokens);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int4* sh_b_stage = sh_b + b_sh_stage * pipe;
|
int4* sh_b_stage = sh_b + b_sh_stage * pipe;
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < b_sh_wr_iters; i++) {
|
for (int i = 0; i < b_sh_wr_iters; i++) {
|
||||||
@ -1063,8 +956,8 @@ __global__ void Marlin(
|
|||||||
|
|
||||||
// Load the next sub-tile from the current location in the shared memory pipe
|
// Load the next sub-tile from the current location in the shared memory pipe
|
||||||
// into the current register buffer.
|
// into the current register buffer.
|
||||||
auto fetch_to_registers = [&](int k, int pipe) {
|
auto fetch_to_registers = [&](int k, int pipe, int pipe_a = 0) {
|
||||||
int4* sh_a_stage = sh_a + a_sh_stage * pipe;
|
int4* sh_a_stage = sh_a + moe_block_size * a_sh_stride * pipe_a;
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < thread_m_blocks; i++)
|
for (int i = 0; i < thread_m_blocks; i++)
|
||||||
ldsm<m_block_size_8 ? 2 : 4, scalar_t>(
|
ldsm<m_block_size_8 ? 2 : 4, scalar_t>(
|
||||||
@ -1109,12 +1002,17 @@ __global__ void Marlin(
|
|||||||
}
|
}
|
||||||
} else if constexpr (group_blocks != -1) {
|
} else if constexpr (group_blocks != -1) {
|
||||||
if constexpr (group_blocks >= thread_k_blocks) {
|
if constexpr (group_blocks >= thread_k_blocks) {
|
||||||
int4* sh_s_stage =
|
if (k % b_sh_wr_iters == 0) {
|
||||||
sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
|
int4* sh_s_stage =
|
||||||
(pipe / (group_blocks / thread_k_blocks)));
|
sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
|
||||||
reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
|
(pipe / (group_blocks / thread_k_blocks)));
|
||||||
|
reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
|
||||||
|
} else {
|
||||||
|
reinterpret_cast<int4*>(&frag_s[1])[0] =
|
||||||
|
reinterpret_cast<int4*>(&frag_s[0])[0];
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
int warp_id = threadIdx.x / 32;
|
auto warp_id = threadIdx.x / 32;
|
||||||
int n_warps = thread_n_blocks / 4;
|
int n_warps = thread_n_blocks / 4;
|
||||||
|
|
||||||
int warp_row = warp_id / n_warps;
|
int warp_row = warp_id / n_warps;
|
||||||
@ -1152,7 +1050,7 @@ __global__ void Marlin(
|
|||||||
|
|
||||||
// Determine "position" inside the thread-block (based on warp and
|
// Determine "position" inside the thread-block (based on warp and
|
||||||
// thread-id)
|
// thread-id)
|
||||||
int warp_id = threadIdx.x / 32;
|
auto warp_id = threadIdx.x / 32;
|
||||||
int n_warps =
|
int n_warps =
|
||||||
thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N
|
thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N
|
||||||
|
|
||||||
@ -1161,7 +1059,7 @@ __global__ void Marlin(
|
|||||||
|
|
||||||
cur_k += warp_row * 16;
|
cur_k += warp_row * 16;
|
||||||
|
|
||||||
int th_id = threadIdx.x % 32;
|
auto th_id = threadIdx.x % 32;
|
||||||
cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix
|
cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix
|
||||||
|
|
||||||
int s_col_shift =
|
int s_col_shift =
|
||||||
@ -1222,15 +1120,18 @@ __global__ void Marlin(
|
|||||||
}
|
}
|
||||||
|
|
||||||
} else if constexpr (group_blocks >= thread_k_blocks) {
|
} else if constexpr (group_blocks >= thread_k_blocks) {
|
||||||
int4* sh_zp_stage =
|
if (k % b_sh_wr_iters == 0) {
|
||||||
sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) *
|
int4* sh_zp_stage =
|
||||||
(pipe / (group_blocks / thread_k_blocks)));
|
sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) *
|
||||||
for (int i = 0; i < num_ints_per_thread; i++) {
|
(pipe / (group_blocks / thread_k_blocks)));
|
||||||
frag_qzp[k % 2][i] =
|
#pragma unroll
|
||||||
(reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
|
for (int i = 0; i < num_ints_per_thread; i++) {
|
||||||
|
frag_qzp[k % 2][i] =
|
||||||
|
(reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
int warp_id = threadIdx.x / 32;
|
auto warp_id = threadIdx.x / 32;
|
||||||
int n_warps = thread_n_blocks / 4;
|
int n_warps = thread_n_blocks / 4;
|
||||||
|
|
||||||
int warp_row = warp_id / n_warps;
|
int warp_row = warp_id / n_warps;
|
||||||
@ -1251,6 +1152,7 @@ __global__ void Marlin(
|
|||||||
|
|
||||||
sh_zp_stage += cur_group_id * zp_sh_stride;
|
sh_zp_stage += cur_group_id * zp_sh_stride;
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
for (int i = 0; i < num_ints_per_thread; i++) {
|
for (int i = 0; i < num_ints_per_thread; i++) {
|
||||||
frag_qzp[k % 2][i] =
|
frag_qzp[k % 2][i] =
|
||||||
(reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
|
(reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
|
||||||
@ -1263,12 +1165,16 @@ __global__ void Marlin(
|
|||||||
|
|
||||||
if constexpr (group_blocks != -1) {
|
if constexpr (group_blocks != -1) {
|
||||||
if constexpr (group_blocks >= thread_k_blocks) {
|
if constexpr (group_blocks >= thread_k_blocks) {
|
||||||
int4* sh_zp_stage =
|
if (k % b_sh_wr_iters == 0) {
|
||||||
sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) *
|
int4* sh_zp_stage =
|
||||||
(pipe / (group_blocks / thread_k_blocks)));
|
sh_zp +
|
||||||
reinterpret_cast<int4*>(&frag_zpf[k % 2])[0] = sh_zp_stage[zp_sh_rd];
|
zp_sh_stage * ((group_blocks / thread_k_blocks) *
|
||||||
|
(pipe / (group_blocks / thread_k_blocks)));
|
||||||
|
reinterpret_cast<int4*>(&frag_zpf[k % 2])[0] =
|
||||||
|
sh_zp_stage[zp_sh_rd];
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
int warp_id = threadIdx.x / 32;
|
auto warp_id = threadIdx.x / 32;
|
||||||
int n_warps = thread_n_blocks / 4;
|
int n_warps = thread_n_blocks / 4;
|
||||||
|
|
||||||
int warp_row = warp_id / n_warps;
|
int warp_row = warp_id / n_warps;
|
||||||
@ -1292,6 +1198,25 @@ __global__ void Marlin(
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
auto dequant_data = [&](int q, scalar_t2* frag_b_ptr) {
|
||||||
|
if constexpr (has_zp && is_zp_float || !has_zp) {
|
||||||
|
dequant<scalar_t2, w_type_id>(q, frag_b_ptr);
|
||||||
|
} else {
|
||||||
|
static_assert(has_zp && !is_zp_float);
|
||||||
|
static_assert(w_type_id == vllm::kU4.id() || w_type_id == vllm::kU8.id());
|
||||||
|
// If (has_zp && !is_zp_float),
|
||||||
|
// we use not-zp version `dequant` function
|
||||||
|
// to improve numerical accuracy.
|
||||||
|
// Since both weight and zero point are dequanted using this logic,
|
||||||
|
// the final dequanted weight would be correct.
|
||||||
|
if constexpr (w_type_id == vllm::kU4.id()) {
|
||||||
|
dequant<scalar_t2, vllm::kU4B8.id()>(q, frag_b_ptr);
|
||||||
|
} else if constexpr (w_type_id == vllm::kU8.id()) {
|
||||||
|
dequant<scalar_t2, vllm::kU8B128.id()>(q, frag_b_ptr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// Execute the actual tensor core matmul of a sub-tile.
|
// Execute the actual tensor core matmul of a sub-tile.
|
||||||
bool is_first_matmul_in_slice = true;
|
bool is_first_matmul_in_slice = true;
|
||||||
auto matmul = [&](int k) {
|
auto matmul = [&](int k) {
|
||||||
@ -1315,15 +1240,17 @@ __global__ void Marlin(
|
|||||||
zp_quant_1 = frag_qzp[k2][1];
|
zp_quant_1 = frag_qzp[k2][1];
|
||||||
}
|
}
|
||||||
|
|
||||||
dequant<scalar_t, w_type.size_bits()>(zp_quant_0, frag_zp_0);
|
dequant_data(zp_quant_0, reinterpret_cast<scalar_t2*>(&frag_zp));
|
||||||
dequant<scalar_t, w_type.size_bits()>(zp_quant_1, frag_zp_1);
|
dequant_data(zp_quant_1, reinterpret_cast<scalar_t2*>(&frag_zp) + 2);
|
||||||
|
|
||||||
frag_zp[0] = frag_zp_0[0];
|
|
||||||
frag_zp[1] = frag_zp_0[1];
|
|
||||||
frag_zp[2] = frag_zp_1[0];
|
|
||||||
frag_zp[3] = frag_zp_1[1];
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if constexpr (has_zp && is_zp_float) {
|
||||||
|
if (is_new_zp) {
|
||||||
|
reinterpret_cast<int4*>(&frag_zp)[0] =
|
||||||
|
reinterpret_cast<int4*>(&frag_zpf[k2])[0];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// We have the m dimension as the inner loop in order to encourage overlapping
|
// We have the m dimension as the inner loop in order to encourage overlapping
|
||||||
// dequantization and matmul operations.
|
// dequantization and matmul operations.
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
@ -1342,8 +1269,8 @@ __global__ void Marlin(
|
|||||||
b_quant_1 = frag_b_quant_ptr[j * 2 + 1];
|
b_quant_1 = frag_b_quant_ptr[j * 2 + 1];
|
||||||
}
|
}
|
||||||
|
|
||||||
dequant<scalar_t, w_type.size_bits()>(b_quant_0, frag_b0);
|
dequant_data(b_quant_0, reinterpret_cast<scalar_t2*>(&frag_b0));
|
||||||
dequant<scalar_t, w_type.size_bits()>(b_quant_1, frag_b1);
|
dequant_data(b_quant_1, reinterpret_cast<scalar_t2*>(&frag_b1));
|
||||||
|
|
||||||
// Apply scale to frag_b0
|
// Apply scale to frag_b0
|
||||||
if constexpr (has_act_order) {
|
if constexpr (has_act_order) {
|
||||||
@ -1351,8 +1278,7 @@ __global__ void Marlin(
|
|||||||
scale4<scalar_t>(frag_b0, act_frag_s[k2][0][j], act_frag_s[k2][1][j],
|
scale4<scalar_t>(frag_b0, act_frag_s[k2][0][j], act_frag_s[k2][1][j],
|
||||||
act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0);
|
act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0);
|
||||||
scale4<scalar_t>(frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j],
|
scale4<scalar_t>(frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j],
|
||||||
act_frag_s[k][2][j], act_frag_s[k2][3][j], 1);
|
act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1);
|
||||||
|
|
||||||
} else if constexpr (has_zp && !is_zp_float && group_blocks == -1) {
|
} else if constexpr (has_zp && !is_zp_float && group_blocks == -1) {
|
||||||
int idx = (threadIdx.x / 4) % 2;
|
int idx = (threadIdx.x / 4) % 2;
|
||||||
scalar_t2 s2 = Dtype::nums2num2(
|
scalar_t2 s2 = Dtype::nums2num2(
|
||||||
@ -1361,18 +1287,12 @@ __global__ void Marlin(
|
|||||||
if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], s2);
|
if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], s2);
|
||||||
scale_and_sub<scalar_t>(frag_b0, s2.x, frag_zp[j].x);
|
scale_and_sub<scalar_t>(frag_b0, s2.x, frag_zp[j].x);
|
||||||
scale_and_sub<scalar_t>(frag_b1, s2.y, frag_zp[j].y);
|
scale_and_sub<scalar_t>(frag_b1, s2.y, frag_zp[j].y);
|
||||||
} else if constexpr (has_zp && !is_zp_float && group_blocks != -1) {
|
} else if constexpr (has_zp && group_blocks != -1) {
|
||||||
if (is_new_zp)
|
if (is_new_zp)
|
||||||
frag_zp[j] = __hmul2(frag_zp[j],
|
frag_zp[j] = __hmul2(frag_zp[j],
|
||||||
*reinterpret_cast<scalar_t2*>(&frag_s[k2][j]));
|
*reinterpret_cast<scalar_t2*>(&frag_s[k2][j]));
|
||||||
scale_and_sub<scalar_t>(frag_b0, frag_s[k % 2][j][0].x, frag_zp[j].x);
|
scale_and_sub<scalar_t>(frag_b0, frag_s[k2][j][0].x, frag_zp[j].x);
|
||||||
scale_and_sub<scalar_t>(frag_b1, frag_s[k % 2][j][0].y, frag_zp[j].y);
|
scale_and_sub<scalar_t>(frag_b1, frag_s[k2][j][0].y, frag_zp[j].y);
|
||||||
} else if constexpr (has_zp && is_zp_float && group_blocks != -1) {
|
|
||||||
if (is_new_zp)
|
|
||||||
frag_zpf[k2][j] = __hmul2(
|
|
||||||
frag_zpf[k2][j], *reinterpret_cast<scalar_t2*>(&frag_s[k2][j]));
|
|
||||||
scale_and_sub<scalar_t>(frag_b0, frag_s[k2][j].x, frag_zpf[k2][j].x);
|
|
||||||
scale_and_sub<scalar_t>(frag_b1, frag_s[k2][j].y, frag_zpf[k2][j].y);
|
|
||||||
} else if constexpr (group_blocks != -1) {
|
} else if constexpr (group_blocks != -1) {
|
||||||
scale<scalar_t>(frag_b0, frag_s[k2][j], 0);
|
scale<scalar_t>(frag_b0, frag_s[k2][j], 0);
|
||||||
scale<scalar_t>(frag_b1, frag_s[k2][j], 1);
|
scale<scalar_t>(frag_b1, frag_s[k2][j], 1);
|
||||||
@ -1397,7 +1317,7 @@ __global__ void Marlin(
|
|||||||
auto thread_block_reduce = [&]() {
|
auto thread_block_reduce = [&]() {
|
||||||
constexpr int red_off = threads / b_sh_stride_threads / 2;
|
constexpr int red_off = threads / b_sh_stride_threads / 2;
|
||||||
if (red_off >= 1) {
|
if (red_off >= 1) {
|
||||||
int red_idx = threadIdx.x / b_sh_stride_threads;
|
auto red_idx = threadIdx.x / b_sh_stride_threads;
|
||||||
constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2;
|
constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2;
|
||||||
constexpr int red_sh_delta = b_sh_stride_threads;
|
constexpr int red_sh_delta = b_sh_stride_threads;
|
||||||
int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) +
|
int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) +
|
||||||
@ -1731,7 +1651,7 @@ __global__ void Marlin(
|
|||||||
fetch_col_scale_to_shared();
|
fetch_col_scale_to_shared();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
fetch_to_shared(i, i, i < slice_iters);
|
fetch_to_shared(i, i, i < slice_iters, i);
|
||||||
}
|
}
|
||||||
|
|
||||||
zero_accums();
|
zero_accums();
|
||||||
@ -1740,8 +1660,10 @@ __global__ void Marlin(
|
|||||||
fetch_to_registers(0, 0);
|
fetch_to_registers(0, 0);
|
||||||
fetch_scales_to_registers(0, 0);
|
fetch_scales_to_registers(0, 0);
|
||||||
fetch_zp_to_registers(0, 0);
|
fetch_zp_to_registers(0, 0);
|
||||||
a_gl_rd += a_gl_rd_delta_o * (stages - 1);
|
a_gl_rd_col += a_gl_rd_delta_o * (stages - 1);
|
||||||
slice_k_start_shared_fetch += tb_k * (stages - 1);
|
if constexpr (has_act_order) {
|
||||||
|
slice_k_start_shared_fetch += tb_k * (stages - 1);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
if (slice_iters) {
|
if (slice_iters) {
|
||||||
start_pipes();
|
start_pipes();
|
||||||
@ -1754,45 +1676,58 @@ __global__ void Marlin(
|
|||||||
// have even length meaning that the next iteration will always start at
|
// have even length meaning that the next iteration will always start at
|
||||||
// index 0.
|
// index 0.
|
||||||
|
|
||||||
|
for (int stage_group_id = 0; stage_group_id < max_num_stage_groups;
|
||||||
|
stage_group_id++) {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int pipe = 0; pipe < stages;) {
|
for (int pipe = 0; pipe < stages;) {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int k = 0; k < b_sh_wr_iters; k++) {
|
for (int k = 0; k < b_sh_wr_iters; k++) {
|
||||||
fetch_to_registers(k + 1, pipe % stages);
|
int idx =
|
||||||
fetch_scales_to_registers(k + 1, pipe);
|
(pipe >= stages && stage_group_id == max_num_stage_groups - 1)
|
||||||
fetch_zp_to_registers(k + 1, pipe);
|
? (pipe - stages)
|
||||||
if (k == b_sh_wr_iters - 2) {
|
: (pipe + stage_group_id * stages);
|
||||||
fetch_to_shared((pipe + stages - 1) % stages, pipe,
|
fetch_to_registers(k + 1, pipe % stages, idx);
|
||||||
slice_iters >= stages);
|
fetch_scales_to_registers(k + 1, pipe);
|
||||||
pipe++;
|
fetch_zp_to_registers(k + 1, pipe);
|
||||||
wait_for_stage();
|
if (k == b_sh_wr_iters - 2) {
|
||||||
init_same_group(pipe % stages);
|
int idx = (pipe >= 1 && stage_group_id == max_num_stage_groups - 1)
|
||||||
|
? (pipe - 1)
|
||||||
|
: (pipe + (stage_group_id + 1) * stages - 1);
|
||||||
|
fetch_to_shared((pipe + stages - 1) % stages, pipe,
|
||||||
|
slice_iters >= stages, idx);
|
||||||
|
pipe++;
|
||||||
|
wait_for_stage();
|
||||||
|
init_same_group(pipe % stages);
|
||||||
|
}
|
||||||
|
matmul(k);
|
||||||
|
}
|
||||||
|
slice_iters--;
|
||||||
|
if (slice_iters == 0) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
a_gl_rd_col += a_gl_rd_delta_o * stages;
|
||||||
|
|
||||||
|
if constexpr (has_act_order) {
|
||||||
|
slice_k_start += tb_k * stages;
|
||||||
|
slice_k_start_shared_fetch += tb_k * stages;
|
||||||
|
int first_group_id = g_idx[slice_k_start];
|
||||||
|
int last_g_idx = slice_k_start + stages * tb_k * 2;
|
||||||
|
if (last_g_idx >= prob_k) {
|
||||||
|
last_g_idx = prob_k - 1;
|
||||||
|
}
|
||||||
|
int last_group_id = g_idx[last_g_idx];
|
||||||
|
if (last_group_id >= sh_first_group_id + sh_num_groups) {
|
||||||
|
fetch_act_order_scales_to_shared(false, first_group_id,
|
||||||
|
last_group_id);
|
||||||
|
__syncthreads();
|
||||||
}
|
}
|
||||||
matmul(k);
|
|
||||||
}
|
}
|
||||||
slice_iters--;
|
|
||||||
if (slice_iters == 0) {
|
if (slice_iters == 0) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
a_remaining_load_count_in_slice = 0;
|
|
||||||
|
|
||||||
a_gl_rd += a_gl_rd_delta_o * stages;
|
|
||||||
slice_k_start += tb_k * stages;
|
|
||||||
slice_k_start_shared_fetch += tb_k * stages;
|
|
||||||
|
|
||||||
if constexpr (has_act_order) {
|
|
||||||
int first_group_id = g_idx[slice_k_start];
|
|
||||||
int last_g_idx = slice_k_start + stages * tb_k * 2;
|
|
||||||
if (last_g_idx >= prob_k) {
|
|
||||||
last_g_idx = prob_k - 1;
|
|
||||||
}
|
|
||||||
int last_group_id = g_idx[last_g_idx];
|
|
||||||
if (last_group_id >= sh_first_group_id + sh_num_groups) {
|
|
||||||
fetch_act_order_scales_to_shared(false, first_group_id, last_group_id);
|
|
||||||
__syncthreads();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Process results and, if necessary, proceed to the next column slice.
|
// Process results and, if necessary, proceed to the next column slice.
|
||||||
// While this pattern may not be the most readable, other ways of writing
|
// While this pattern may not be the most readable, other ways of writing
|
||||||
@ -1877,15 +1812,30 @@ __global__ void Marlin(
|
|||||||
if (last || use_atomic_add)
|
if (last || use_atomic_add)
|
||||||
// only the last block in a slice actually writes the result
|
// only the last block in a slice actually writes the result
|
||||||
write_result();
|
write_result();
|
||||||
if (slice_row) a_remaining_load_count_in_slice = stages;
|
int old_slice_row = slice_row;
|
||||||
slice_row = 0;
|
slice_row = 0;
|
||||||
slice_col_par++;
|
slice_col_par++;
|
||||||
slice_col++;
|
slice_col++;
|
||||||
is_first_matmul_in_slice = true;
|
is_first_matmul_in_slice = true;
|
||||||
init_slice();
|
init_slice();
|
||||||
|
|
||||||
|
// Should we load A matrix in next slice?
|
||||||
|
// `slice_col == 0`: when move to a new moe block
|
||||||
|
// `old_slice_row > 0`:
|
||||||
|
// when the last slice is not starting from k_index == 0
|
||||||
|
// (only happen when it is the first slice of a threadblock)
|
||||||
|
// `prob_k > thread_k_blocks * 16 * stages * max_num_stage_groups`:
|
||||||
|
// when the required shared memory size is larger than
|
||||||
|
// the remaining shared memory
|
||||||
|
if (slice_col == 0 || old_slice_row ||
|
||||||
|
prob_k > thread_k_blocks * 16 * stages * max_num_stage_groups) {
|
||||||
|
should_load_a = true;
|
||||||
|
} else {
|
||||||
|
should_load_a = false;
|
||||||
|
}
|
||||||
|
|
||||||
if (slice_iters) {
|
if (slice_iters) {
|
||||||
a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
|
a_gl_rd_col = (threadIdx.x % a_gl_rd_delta_o);
|
||||||
(threadIdx.x % a_gl_rd_delta_o);
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < b_sh_wr_iters; i++)
|
for (int i = 0; i < b_sh_wr_iters; i++)
|
||||||
B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
|
B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
|
||||||
@ -1900,12 +1850,10 @@ __global__ void Marlin(
|
|||||||
slice_k_finish = slice_k_start + tb_k * slice_iters;
|
slice_k_finish = slice_k_start + tb_k * slice_iters;
|
||||||
slice_k_start_shared_fetch = slice_k_start;
|
slice_k_start_shared_fetch = slice_k_start;
|
||||||
slice_n_offset = act_s_col_tb_stride * slice_col;
|
slice_n_offset = act_s_col_tb_stride * slice_col;
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
|
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
|
||||||
zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x;
|
zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x;
|
||||||
}
|
}
|
||||||
|
|
||||||
start_pipes();
|
start_pipes();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -116,7 +116,7 @@ __global__ void permute_cols_kernel(
|
|||||||
int base_k = 0;
|
int base_k = 0;
|
||||||
|
|
||||||
for (int i = 0; i < iters; i++) {
|
for (int i = 0; i < iters; i++) {
|
||||||
int cur_k = base_k + threadIdx.x;
|
auto cur_k = base_k + threadIdx.x;
|
||||||
int src_pos = perm_int_ptr[cur_k];
|
int src_pos = perm_int_ptr[cur_k];
|
||||||
|
|
||||||
out_half[cur_k] = a_row_half[src_pos];
|
out_half[cur_k] = a_row_half[src_pos];
|
||||||
@ -126,7 +126,7 @@ __global__ void permute_cols_kernel(
|
|||||||
|
|
||||||
if (rest) {
|
if (rest) {
|
||||||
if (threadIdx.x < rest) {
|
if (threadIdx.x < rest) {
|
||||||
int cur_k = base_k + threadIdx.x;
|
auto cur_k = base_k + threadIdx.x;
|
||||||
int src_pos = perm_int_ptr[cur_k];
|
int src_pos = perm_int_ptr[cur_k];
|
||||||
|
|
||||||
out_half[cur_k] = a_row_half[src_pos];
|
out_half[cur_k] = a_row_half[src_pos];
|
||||||
@ -195,7 +195,6 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
|
|||||||
tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K
|
tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K
|
||||||
load_groups = max(load_groups, 32); // We load at least 32 scale groups
|
load_groups = max(load_groups, 32); // We load at least 32 scale groups
|
||||||
return load_groups * tb_n * 2;
|
return load_groups * tb_n * 2;
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
int tb_scales = tb_groups * tb_n * 2;
|
int tb_scales = tb_groups * tb_n * 2;
|
||||||
|
|
||||||
@ -203,22 +202,24 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks,
|
int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8,
|
||||||
int prob_m, int prob_n, int prob_k, int num_bits,
|
int thread_m_blocks, int prob_m, int prob_n,
|
||||||
int group_size, bool has_act_order, bool is_k_full,
|
int prob_k, int num_bits, int group_size,
|
||||||
int has_zp, int is_zp_float) {
|
bool has_act_order, bool is_k_full, int has_zp,
|
||||||
|
int is_zp_float) {
|
||||||
int pack_factor = 32 / num_bits;
|
int pack_factor = 32 / num_bits;
|
||||||
|
|
||||||
// Get B size
|
// Get B size
|
||||||
int tb_k = th_config.thread_k;
|
int tb_k = th_config.thread_k;
|
||||||
int tb_n = th_config.thread_n;
|
int tb_n = th_config.thread_n;
|
||||||
int tb_m = thread_m_blocks * 16;
|
int tb_m = thread_m_blocks * (m_block_size_8 ? 8 : 16);
|
||||||
|
|
||||||
// shm size for block_sorted_ids/block_topk_weights
|
// shm size for block_sorted_ids/rd_block_sorted_ids/block_topk_weights
|
||||||
// both of them requires tb_m * 4 bytes (tb_m * int32 or tb_m * float32)
|
// both of them requires tb_m * 4 bytes (tb_m * int32 or tb_m * float32)
|
||||||
int sh_block_meta_size = tb_m * 4 * 2;
|
int sh_block_meta_size = tb_m * 4;
|
||||||
int sh_a_size = pipe_stages * (tb_m * tb_k) * 2;
|
int sh_a_size = pipe_stages * (tb_m * tb_k) * 2;
|
||||||
int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4;
|
int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4;
|
||||||
|
int sh_red_size = tb_m * (tb_n + 8) * 2;
|
||||||
int sh_s_size =
|
int sh_s_size =
|
||||||
get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
|
get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
|
||||||
group_size, has_act_order, is_k_full);
|
group_size, has_act_order, is_k_full);
|
||||||
@ -233,16 +234,17 @@ int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks,
|
|||||||
sh_zp_size = sh_s_size / 2;
|
sh_zp_size = sh_s_size / 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
int total_size = sh_a_size + sh_b_size + sh_s_size + sh_zp_size +
|
int total_size = max(sh_b_size, sh_red_size) + sh_a_size + sh_s_size +
|
||||||
sh_g_idx_size + sh_block_meta_size;
|
sh_zp_size + sh_g_idx_size + sh_block_meta_size;
|
||||||
|
|
||||||
return total_size;
|
return total_size;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
|
bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8,
|
||||||
int prob_m, int prob_n, int prob_k, int num_bits,
|
int thread_m_blocks, int prob_m, int prob_n, int prob_k,
|
||||||
int group_size, bool has_act_order, bool is_k_full,
|
int num_bits, int group_size, bool has_act_order,
|
||||||
int has_zp, int is_zp_float, int max_shared_mem) {
|
bool is_k_full, int has_zp, int is_zp_float,
|
||||||
|
int max_shared_mem) {
|
||||||
// Sanity
|
// Sanity
|
||||||
if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
|
if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
|
||||||
th_config.num_threads == -1) {
|
th_config.num_threads == -1) {
|
||||||
@ -266,143 +268,113 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
|
|||||||
|
|
||||||
// Check that pipeline fits into cache
|
// Check that pipeline fits into cache
|
||||||
int cache_size = get_kernel_cache_size(
|
int cache_size = get_kernel_cache_size(
|
||||||
th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size,
|
th_config, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k,
|
||||||
has_act_order, is_k_full, has_zp, is_zp_float);
|
num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float);
|
||||||
return cache_size <= max_shared_mem;
|
return cache_size <= max_shared_mem;
|
||||||
}
|
}
|
||||||
|
|
||||||
#define __GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
|
#define _GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
|
||||||
M_BLOCK_SIZE_8, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \
|
M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \
|
||||||
NUM_THREADS, IS_ZP_FLOAT) \
|
else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \
|
||||||
else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \
|
thread_n_blocks == THREAD_N_BLOCKS && \
|
||||||
thread_n_blocks == THREAD_N_BLOCKS && \
|
thread_k_blocks == THREAD_K_BLOCKS && \
|
||||||
thread_k_blocks == THREAD_K_BLOCKS && \
|
m_block_size_8 == M_BLOCK_SIZE_8 && \
|
||||||
m_block_size_8 == M_BLOCK_SIZE_8 && \
|
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \
|
||||||
has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \
|
is_zp_float == IS_ZP_FLOAT) { \
|
||||||
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \
|
kernel = Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
|
||||||
is_zp_float == IS_ZP_FLOAT) { \
|
THREAD_N_BLOCKS, THREAD_K_BLOCKS, M_BLOCK_SIZE_8, \
|
||||||
kernel = Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
|
pipe_stages, GROUP_BLOCKS, IS_ZP_FLOAT>; \
|
||||||
THREAD_N_BLOCKS, THREAD_K_BLOCKS, M_BLOCK_SIZE_8, \
|
|
||||||
pipe_stages, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \
|
|
||||||
IS_ZP_FLOAT>; \
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#define GPTQ_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
// COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false)
|
||||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, true, false, 0, NUM_THREADS, \
|
// this is the most common cases
|
||||||
false) \
|
// BIGGROUP: cases for big group size (group_blocks in [-1, 8])
|
||||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, false, 0, \
|
// FZP: cases for float-zero-point (is_zp_float = true)
|
||||||
NUM_THREADS, false) \
|
// ACT: cases for act order case (group_blocks == 0)
|
||||||
\
|
#define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, -1, \
|
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \
|
||||||
NUM_THREADS, false) \
|
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \
|
||||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 2, \
|
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, false) \
|
||||||
NUM_THREADS, false) \
|
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \
|
||||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 4, \
|
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
|
||||||
NUM_THREADS, false) \
|
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
|
||||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 8, \
|
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
|
||||||
NUM_THREADS, false) \
|
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
|
||||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, -1, \
|
|
||||||
NUM_THREADS, false) \
|
|
||||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 2, \
|
|
||||||
NUM_THREADS, false) \
|
|
||||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 4, \
|
|
||||||
NUM_THREADS, false) \
|
|
||||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 8, \
|
|
||||||
NUM_THREADS, false)
|
|
||||||
|
|
||||||
#define GPTQ_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
#define COMMON_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||||
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, false, 0, \
|
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
|
||||||
NUM_THREADS, false) \
|
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
|
||||||
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, false, 0, \
|
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
|
||||||
NUM_THREADS, false) \
|
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
|
||||||
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, false, 0, \
|
\
|
||||||
NUM_THREADS, false) \
|
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
|
||||||
\
|
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
|
||||||
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, -1, \
|
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
|
||||||
NUM_THREADS, false) \
|
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
|
||||||
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 2, \
|
\
|
||||||
NUM_THREADS, false) \
|
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
|
||||||
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 4, \
|
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
|
||||||
NUM_THREADS, false) \
|
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
|
||||||
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 8, \
|
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
|
||||||
NUM_THREADS, false) \
|
|
||||||
\
|
|
||||||
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, -1, \
|
|
||||||
NUM_THREADS, false) \
|
|
||||||
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 2, \
|
|
||||||
NUM_THREADS, false) \
|
|
||||||
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 4, \
|
|
||||||
NUM_THREADS, false) \
|
|
||||||
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 8, \
|
|
||||||
NUM_THREADS, false) \
|
|
||||||
\
|
|
||||||
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, -1, \
|
|
||||||
NUM_THREADS, false) \
|
|
||||||
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 2, \
|
|
||||||
NUM_THREADS, false) \
|
|
||||||
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 4, \
|
|
||||||
NUM_THREADS, false) \
|
|
||||||
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 8, \
|
|
||||||
NUM_THREADS, false)
|
|
||||||
|
|
||||||
#define AWQ_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
#define COMMON_GET_IF(W_TYPE) \
|
||||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, -1, \
|
COMMON_GET_IF_M1(W_TYPE, 8, 8, 256) \
|
||||||
NUM_THREADS, false) \
|
COMMON_GET_IF_M1(W_TYPE, 8, 4, 128) \
|
||||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 2, NUM_THREADS, \
|
COMMON_GET_IF_M234(W_TYPE, 16, 4, 256) \
|
||||||
false) \
|
COMMON_GET_IF_M234(W_TYPE, 8, 4, 128)
|
||||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 4, NUM_THREADS, \
|
|
||||||
false) \
|
|
||||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 8, NUM_THREADS, \
|
|
||||||
false) \
|
|
||||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, -1, \
|
|
||||||
NUM_THREADS, false) \
|
|
||||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 2, \
|
|
||||||
NUM_THREADS, false) \
|
|
||||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
|
|
||||||
NUM_THREADS, false) \
|
|
||||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 8, \
|
|
||||||
NUM_THREADS, false)
|
|
||||||
|
|
||||||
#define AWQ_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
#define BIGGROUP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||||
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, -1, \
|
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \
|
||||||
NUM_THREADS, false) \
|
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \
|
||||||
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 2, \
|
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
|
||||||
NUM_THREADS, false) \
|
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
|
||||||
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
|
|
||||||
NUM_THREADS, false) \
|
#define BIGGROUP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||||
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 8, \
|
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
|
||||||
NUM_THREADS, false) \
|
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
|
||||||
\
|
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
|
||||||
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, -1, \
|
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
|
||||||
NUM_THREADS, false) \
|
\
|
||||||
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 2, \
|
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
|
||||||
NUM_THREADS, false) \
|
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
|
||||||
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
|
|
||||||
NUM_THREADS, false) \
|
#define BIGGROUP_GET_IF(W_TYPE) \
|
||||||
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 8, \
|
BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \
|
||||||
NUM_THREADS, false) \
|
BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \
|
||||||
\
|
BIGGROUP_GET_IF_M234(W_TYPE, 16, 4, 256) \
|
||||||
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, -1, \
|
BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128)
|
||||||
NUM_THREADS, false) \
|
|
||||||
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 2, \
|
|
||||||
NUM_THREADS, false) \
|
|
||||||
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
|
|
||||||
NUM_THREADS, false) \
|
|
||||||
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 8, \
|
|
||||||
NUM_THREADS, false)
|
|
||||||
|
|
||||||
// We currently have 4-bit models only with group_blocks == 4
|
// We currently have 4-bit models only with group_blocks == 4
|
||||||
#define HQQ_GET_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
#define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 4, NUM_THREADS, \
|
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, true) \
|
||||||
true) \
|
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true)
|
||||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
|
|
||||||
NUM_THREADS, true) \
|
#define FZP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||||
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
|
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \
|
||||||
NUM_THREADS, true) \
|
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \
|
||||||
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
|
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true)
|
||||||
NUM_THREADS, true) \
|
|
||||||
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
|
#define FZP_GET_IF(W_TYPE) \
|
||||||
NUM_THREADS, true)
|
FZP_GET_IF_M1(W_TYPE, 8, 8, 256) \
|
||||||
|
FZP_GET_IF_M1(W_TYPE, 8, 4, 128) \
|
||||||
|
FZP_GET_IF_M234(W_TYPE, 16, 4, 256) \
|
||||||
|
FZP_GET_IF_M234(W_TYPE, 8, 4, 128)
|
||||||
|
|
||||||
|
// We currently have 4-bit models only with group_blocks == 4
|
||||||
|
#define ACT_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||||
|
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS, false) \
|
||||||
|
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false)
|
||||||
|
|
||||||
|
#define ACT_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||||
|
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \
|
||||||
|
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \
|
||||||
|
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false)
|
||||||
|
|
||||||
|
#define ACT_GET_IF(W_TYPE) \
|
||||||
|
ACT_GET_IF_M1(W_TYPE, 8, 8, 256) \
|
||||||
|
ACT_GET_IF_M1(W_TYPE, 8, 4, 128) \
|
||||||
|
ACT_GET_IF_M234(W_TYPE, 16, 4, 256) \
|
||||||
|
ACT_GET_IF_M234(W_TYPE, 8, 4, 128)
|
||||||
|
|
||||||
template <typename scalar_t>
|
template <typename scalar_t>
|
||||||
MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type,
|
MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type,
|
||||||
@ -415,23 +387,15 @@ MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type,
|
|||||||
auto kernel = MarlinDefault;
|
auto kernel = MarlinDefault;
|
||||||
if (false) {
|
if (false) {
|
||||||
}
|
}
|
||||||
GPTQ_GET_IF_M1(vllm::kU4B8, 8, 8, 256)
|
|
||||||
GPTQ_GET_IF_M1(vllm::kU4B8, 8, 4, 128)
|
|
||||||
|
|
||||||
GPTQ_GET_IF_M234(vllm::kU4B8, 16, 4, 256)
|
COMMON_GET_IF(vllm::kU4)
|
||||||
GPTQ_GET_IF_M234(vllm::kU4B8, 8, 4, 128)
|
COMMON_GET_IF(vllm::kU4B8)
|
||||||
|
COMMON_GET_IF(vllm::kU8B128)
|
||||||
|
|
||||||
GPTQ_GET_IF_M1(vllm::kU8B128, 8, 8, 256)
|
BIGGROUP_GET_IF(vllm::kFE4M3fn)
|
||||||
GPTQ_GET_IF_M1(vllm::kU8B128, 8, 4, 128)
|
|
||||||
|
|
||||||
GPTQ_GET_IF_M234(vllm::kU8B128, 16, 4, 256)
|
ACT_GET_IF(vllm::kU4B8)
|
||||||
GPTQ_GET_IF_M234(vllm::kU8B128, 8, 4, 128)
|
ACT_GET_IF(vllm::kU8B128)
|
||||||
|
|
||||||
AWQ_GET_IF_M1(vllm::kU4, 8, 8, 256)
|
|
||||||
AWQ_GET_IF_M1(vllm::kU4, 8, 4, 128)
|
|
||||||
|
|
||||||
AWQ_GET_IF_M234(vllm::kU4, 16, 4, 256)
|
|
||||||
AWQ_GET_IF_M234(vllm::kU4, 8, 4, 128)
|
|
||||||
|
|
||||||
return kernel;
|
return kernel;
|
||||||
}
|
}
|
||||||
@ -457,19 +421,19 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
|
|||||||
for (int i = 0; i < thread_configs_size; i++) {
|
for (int i = 0; i < thread_configs_size; i++) {
|
||||||
thread_config_t th_config = thread_configs[i];
|
thread_config_t th_config = thread_configs[i];
|
||||||
|
|
||||||
if (!is_valid_config(th_config, thread_m_blocks, prob_m, prob_n, prob_k,
|
if (!is_valid_config(th_config, m_block_size_8, thread_m_blocks, prob_m,
|
||||||
num_bits, group_size, has_act_order, is_k_full, has_zp,
|
prob_n, prob_k, num_bits, group_size, has_act_order,
|
||||||
is_zp_float, max_shared_mem)) {
|
is_k_full, has_zp, is_zp_float, max_shared_mem)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
int cache_size = get_kernel_cache_size(
|
int cache_size = get_kernel_cache_size(
|
||||||
th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits,
|
th_config, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k,
|
||||||
group_size, has_act_order, is_k_full, has_zp, is_zp_float);
|
num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float);
|
||||||
|
|
||||||
int group_blocks = 0;
|
int group_blocks = 0;
|
||||||
if (!has_act_order) {
|
if (!has_act_order) {
|
||||||
group_blocks = group_size == -1 ? -1 : group_size / 16;
|
group_blocks = group_size == -1 ? -1 : (group_size / 16);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto kernel = get_marlin_kernel<scalar_t>(
|
auto kernel = get_marlin_kernel<scalar_t>(
|
||||||
@ -515,14 +479,14 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
|
|||||||
bool m_block_size_8 = moe_block_size == 8;
|
bool m_block_size_8 = moe_block_size == 8;
|
||||||
|
|
||||||
if (has_zp) {
|
if (has_zp) {
|
||||||
TORCH_CHECK(
|
TORCH_CHECK(q_type == vllm::kU4,
|
||||||
q_type == vllm::kU4 || q_type == vllm::kU8,
|
"q_type must be u4 when has_zp = True. Got = ", q_type.str());
|
||||||
"q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str());
|
|
||||||
} else {
|
} else {
|
||||||
TORCH_CHECK(
|
TORCH_CHECK(q_type == vllm::kU4B8 || q_type == vllm::kU8B128 ||
|
||||||
q_type == vllm::kU4B8 || q_type == vllm::kU8B128,
|
q_type == vllm::kFE4M3fn,
|
||||||
"q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ",
|
"q_type must be uint4b8, uint8b128 or fp8e4m3 when has_zp = "
|
||||||
q_type.str());
|
"False. Got = ",
|
||||||
|
q_type.str());
|
||||||
}
|
}
|
||||||
|
|
||||||
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
|
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
|
||||||
@ -631,18 +595,18 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
|
|||||||
int thread_k_blocks = thread_k / 16;
|
int thread_k_blocks = thread_k / 16;
|
||||||
int thread_n_blocks = thread_n / 16;
|
int thread_n_blocks = thread_n / 16;
|
||||||
|
|
||||||
TORCH_CHECK(is_valid_config(thread_tfg, thread_m_blocks, prob_m, prob_n,
|
TORCH_CHECK(
|
||||||
prob_k, num_bits, group_size, has_act_order,
|
is_valid_config(thread_tfg, m_block_size_8, thread_m_blocks, prob_m,
|
||||||
is_k_full, has_zp, is_zp_float, max_shared_mem),
|
prob_n, prob_k, num_bits, group_size, has_act_order,
|
||||||
"Invalid thread config: thread_m_blocks = ", thread_m_blocks,
|
is_k_full, has_zp, is_zp_float, max_shared_mem),
|
||||||
", thread_k = ", thread_tfg.thread_k,
|
"Invalid thread config: thread_m_blocks = ", thread_m_blocks,
|
||||||
", thread_n = ", thread_tfg.thread_n,
|
", thread_k = ", thread_tfg.thread_k,
|
||||||
", num_threads = ", thread_tfg.num_threads, " for MKN = [",
|
", thread_n = ", thread_tfg.thread_n,
|
||||||
prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits,
|
", num_threads = ", thread_tfg.num_threads, " for MKN = [", prob_m, ", ",
|
||||||
", group_size = ", group_size,
|
prob_k, ", ", prob_n, "] and num_bits = ", num_bits,
|
||||||
", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full,
|
", group_size = ", group_size, ", has_act_order = ", has_act_order,
|
||||||
", has_zp = ", has_zp, ", is_zp_float = ", is_zp_float,
|
", is_k_full = ", is_k_full, ", has_zp = ", has_zp,
|
||||||
", max_shared_mem = ", max_shared_mem);
|
", is_zp_float = ", is_zp_float, ", max_shared_mem = ", max_shared_mem);
|
||||||
|
|
||||||
auto kernel = get_marlin_kernel<scalar_t>(
|
auto kernel = get_marlin_kernel<scalar_t>(
|
||||||
q_type, thread_m_blocks, thread_n_blocks, thread_k_blocks, m_block_size_8,
|
q_type, thread_m_blocks, thread_n_blocks, thread_k_blocks, m_block_size_8,
|
||||||
@ -666,7 +630,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
|
|||||||
A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr,
|
A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr,
|
||||||
sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr,
|
sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr,
|
||||||
topk_weights_ptr, top_k, mul_topk_weights, is_ep, num_groups, prob_m,
|
topk_weights_ptr, top_k, mul_topk_weights, is_ep, num_groups, prob_m,
|
||||||
prob_n, prob_k, locks, use_atomic_add, use_fp32_reduce);
|
prob_n, prob_k, locks, use_atomic_add, use_fp32_reduce, max_shared_mem);
|
||||||
// clang-format on
|
// clang-format on
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -841,10 +805,11 @@ torch::Tensor moe_wna16_marlin_gemm(
|
|||||||
b_q_type == vllm::kU4,
|
b_q_type == vllm::kU4,
|
||||||
"b_q_type must be u4 when has_zp = True. Got = ", b_q_type.str());
|
"b_q_type must be u4 when has_zp = True. Got = ", b_q_type.str());
|
||||||
} else {
|
} else {
|
||||||
TORCH_CHECK(
|
TORCH_CHECK(b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128 ||
|
||||||
b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128,
|
b_q_type == vllm::kFE4M3fn,
|
||||||
"b_q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ",
|
"b_q_type must be uint4b8, uint8b128 or fp8e4m3 when has_zp = "
|
||||||
b_q_type.str());
|
"False. Got = ",
|
||||||
|
b_q_type.str());
|
||||||
}
|
}
|
||||||
|
|
||||||
if (has_zp && is_zp_float) {
|
if (has_zp && is_zp_float) {
|
||||||
|
|||||||
1
csrc/quantization/gptq_marlin/.gitignore
vendored
Normal file
1
csrc/quantization/gptq_marlin/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
|||||||
|
kernel_*.cu
|
||||||
291
csrc/quantization/gptq_marlin/dequant.h
Normal file
291
csrc/quantization/gptq_marlin/dequant.h
Normal file
@ -0,0 +1,291 @@
|
|||||||
|
|
||||||
|
#include "marlin_dtypes.cuh"
|
||||||
|
|
||||||
|
namespace MARLIN_NAMESPACE_NAME {
|
||||||
|
|
||||||
|
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
|
||||||
|
// Lookup-table based 3-input logical operation; explicitly used for
|
||||||
|
// dequantization as the compiler does not seem to automatically recognize it in
|
||||||
|
// all cases.
|
||||||
|
template <int lut>
|
||||||
|
__device__ inline int lop3(int a, int b, int c) {
|
||||||
|
int res;
|
||||||
|
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
|
||||||
|
: "=r"(res)
|
||||||
|
: "r"(a), "r"(b), "r"(c), "n"(lut));
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Constructs destination register by taking bytes from 2 sources (based on
|
||||||
|
// mask)
|
||||||
|
template <int start_byte, int mask>
|
||||||
|
__device__ inline uint32_t prmt(uint32_t a) {
|
||||||
|
uint32_t res;
|
||||||
|
asm volatile("prmt.b32 %0, %1, %2, %3;\n"
|
||||||
|
: "=r"(res)
|
||||||
|
: "r"(a), "n"(start_byte), "n"(mask));
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename scalar_t2, vllm::ScalarTypeId w_type_id>
|
||||||
|
__device__ inline void dequant(int q, scalar_t2* frag_b);
|
||||||
|
|
||||||
|
//
|
||||||
|
// Efficiently dequantize 4bit values packed in an int32 value into a full
|
||||||
|
// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below,
|
||||||
|
// with some small changes:
|
||||||
|
// - FP16:
|
||||||
|
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287
|
||||||
|
// - BF16:
|
||||||
|
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385
|
||||||
|
//
|
||||||
|
template <>
|
||||||
|
__device__ inline void dequant<half2, vllm::kU4B8.id()>(int q, half2* frag_b) {
|
||||||
|
const int LO = 0x000f000f;
|
||||||
|
const int HI = 0x00f000f0;
|
||||||
|
const int EX = 0x64006400;
|
||||||
|
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
||||||
|
// clang-format off
|
||||||
|
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
|
||||||
|
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
|
||||||
|
// clang-format on
|
||||||
|
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
|
||||||
|
// directly into `SUB` and `ADD`.
|
||||||
|
const int SUB = 0x64086408;
|
||||||
|
const int MUL = 0x2c002c00;
|
||||||
|
const int ADD = 0xd480d480;
|
||||||
|
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
|
||||||
|
*reinterpret_cast<const half2*>(&SUB));
|
||||||
|
frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
|
||||||
|
*reinterpret_cast<const half2*>(&MUL),
|
||||||
|
*reinterpret_cast<const half2*>(&ADD));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__device__ inline void dequant<half2, vllm::kU4.id()>(int q, half2* frag_b) {
|
||||||
|
const int LO = 0x000f000f;
|
||||||
|
const int HI = 0x00f000f0;
|
||||||
|
const int EX = 0x64006400;
|
||||||
|
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
||||||
|
// clang-format off
|
||||||
|
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
|
||||||
|
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
|
||||||
|
// clang-format on
|
||||||
|
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
|
||||||
|
// directly into `SUB` and `ADD`.
|
||||||
|
const int SUB = 0x64006400;
|
||||||
|
const int MUL = 0x2c002c00;
|
||||||
|
const int ADD = 0xd400d400;
|
||||||
|
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
|
||||||
|
*reinterpret_cast<const half2*>(&SUB));
|
||||||
|
frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
|
||||||
|
*reinterpret_cast<const half2*>(&MUL),
|
||||||
|
*reinterpret_cast<const half2*>(&ADD));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__device__ inline void dequant<nv_bfloat162, vllm::kU4B8.id()>(
|
||||||
|
int q, nv_bfloat162* frag_b) {
|
||||||
|
static constexpr uint32_t MASK = 0x000f000f;
|
||||||
|
static constexpr uint32_t EX = 0x43004300;
|
||||||
|
|
||||||
|
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
||||||
|
// clang-format off
|
||||||
|
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
|
||||||
|
q >>= 4;
|
||||||
|
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
|
||||||
|
// clang-format on
|
||||||
|
|
||||||
|
static constexpr uint32_t MUL = 0x3F803F80;
|
||||||
|
static constexpr uint32_t ADD = 0xC308C308;
|
||||||
|
|
||||||
|
frag_b[0] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo),
|
||||||
|
*reinterpret_cast<const nv_bfloat162*>(&MUL),
|
||||||
|
*reinterpret_cast<const nv_bfloat162*>(&ADD));
|
||||||
|
frag_b[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi),
|
||||||
|
*reinterpret_cast<const nv_bfloat162*>(&MUL),
|
||||||
|
*reinterpret_cast<const nv_bfloat162*>(&ADD));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__device__ inline void dequant<nv_bfloat162, vllm::kU4.id()>(
|
||||||
|
int q, nv_bfloat162* frag_b) {
|
||||||
|
static constexpr uint32_t MASK = 0x000f000f;
|
||||||
|
static constexpr uint32_t EX = 0x43004300;
|
||||||
|
|
||||||
|
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
||||||
|
// clang-format off
|
||||||
|
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
|
||||||
|
q >>= 4;
|
||||||
|
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
|
||||||
|
// clang-format on
|
||||||
|
|
||||||
|
static constexpr uint32_t MUL = 0x3F803F80;
|
||||||
|
static constexpr uint32_t ADD = 0xC300C300;
|
||||||
|
|
||||||
|
frag_b[0] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo),
|
||||||
|
*reinterpret_cast<const nv_bfloat162*>(&MUL),
|
||||||
|
*reinterpret_cast<const nv_bfloat162*>(&ADD));
|
||||||
|
frag_b[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi),
|
||||||
|
*reinterpret_cast<const nv_bfloat162*>(&MUL),
|
||||||
|
*reinterpret_cast<const nv_bfloat162*>(&ADD));
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// Fast Int8ToFp16/Int8ToBf16: efficiently dequantize 8-bit int values to fp16
|
||||||
|
// or bf16. References:
|
||||||
|
// - FP16:
|
||||||
|
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85
|
||||||
|
// - BF16:
|
||||||
|
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175
|
||||||
|
//
|
||||||
|
template <>
|
||||||
|
__device__ inline void dequant<half2, vllm::kU8B128.id()>(int q,
|
||||||
|
half2* frag_b) {
|
||||||
|
static constexpr uint32_t mask_for_elt_01 = 0x5250;
|
||||||
|
static constexpr uint32_t mask_for_elt_23 = 0x5351;
|
||||||
|
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
|
||||||
|
|
||||||
|
uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
|
||||||
|
uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
|
||||||
|
|
||||||
|
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
|
||||||
|
|
||||||
|
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
|
||||||
|
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
|
||||||
|
frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
|
||||||
|
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__device__ inline void dequant<half2, vllm::kU8.id()>(int q, half2* frag_b) {
|
||||||
|
static constexpr uint32_t mask_for_elt_01 = 0x5250;
|
||||||
|
static constexpr uint32_t mask_for_elt_23 = 0x5351;
|
||||||
|
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
|
||||||
|
|
||||||
|
uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
|
||||||
|
uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
|
||||||
|
|
||||||
|
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400;
|
||||||
|
|
||||||
|
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
|
||||||
|
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
|
||||||
|
frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
|
||||||
|
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__device__ inline void dequant<nv_bfloat162, vllm::kU8B128.id()>(
|
||||||
|
int q, nv_bfloat162* frag_b) {
|
||||||
|
float fp32_intermediates[4];
|
||||||
|
uint32_t* fp32_intermediates_casted =
|
||||||
|
reinterpret_cast<uint32_t*>(fp32_intermediates);
|
||||||
|
|
||||||
|
static constexpr uint32_t fp32_base = 0x4B000000;
|
||||||
|
fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
|
||||||
|
fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);
|
||||||
|
fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);
|
||||||
|
fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);
|
||||||
|
|
||||||
|
fp32_intermediates[0] -= 8388736.f;
|
||||||
|
fp32_intermediates[1] -= 8388736.f;
|
||||||
|
fp32_intermediates[2] -= 8388736.f;
|
||||||
|
fp32_intermediates[3] -= 8388736.f;
|
||||||
|
|
||||||
|
uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(frag_b);
|
||||||
|
bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0],
|
||||||
|
fp32_intermediates_casted[1], 0x7632);
|
||||||
|
bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2],
|
||||||
|
fp32_intermediates_casted[3], 0x7632);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__device__ inline void dequant<nv_bfloat162, vllm::kU8.id()>(
|
||||||
|
int q, nv_bfloat162* frag_b) {
|
||||||
|
float fp32_intermediates[4];
|
||||||
|
uint32_t* fp32_intermediates_casted =
|
||||||
|
reinterpret_cast<uint32_t*>(fp32_intermediates);
|
||||||
|
|
||||||
|
static constexpr uint32_t fp32_base = 0x4B000000;
|
||||||
|
fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
|
||||||
|
fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);
|
||||||
|
fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);
|
||||||
|
fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);
|
||||||
|
|
||||||
|
fp32_intermediates[0] -= 8388608.f;
|
||||||
|
fp32_intermediates[1] -= 8388608.f;
|
||||||
|
fp32_intermediates[2] -= 8388608.f;
|
||||||
|
fp32_intermediates[3] -= 8388608.f;
|
||||||
|
|
||||||
|
uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(frag_b);
|
||||||
|
bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0],
|
||||||
|
fp32_intermediates_casted[1], 0x7632);
|
||||||
|
bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2],
|
||||||
|
fp32_intermediates_casted[3], 0x7632);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__device__ inline void dequant<half2, vllm::kFE4M3fn.id()>(int q,
|
||||||
|
half2* frag_b) {
|
||||||
|
// Constants for FP8 (E4M3) and FP16 formats
|
||||||
|
constexpr int FP8_EXPONENT = 4, FP8_MANTISSA = 3, FP16_EXPONENT = 5;
|
||||||
|
constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP8_EXPONENT;
|
||||||
|
|
||||||
|
// Calculate MASK for extracting mantissa and exponent
|
||||||
|
constexpr int MASK1 = 0x80000000;
|
||||||
|
constexpr int MASK2 = MASK1 >> (FP8_EXPONENT + FP8_MANTISSA);
|
||||||
|
constexpr int MASK3 = MASK2 & 0x7fffffff;
|
||||||
|
constexpr int MASK = MASK3 | (MASK3 >> 16);
|
||||||
|
// Final MASK value: 0x7F007F00
|
||||||
|
|
||||||
|
// Extract and shift FP8 values to FP16 format
|
||||||
|
int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
|
||||||
|
int Out2 = ((q << 8) & 0x80008000) | (((q << 8) & MASK) >> RIGHT_SHIFT);
|
||||||
|
|
||||||
|
// Construct and apply exponent bias
|
||||||
|
constexpr int BIAS_OFFSET =
|
||||||
|
(1 << (FP16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1));
|
||||||
|
const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET));
|
||||||
|
|
||||||
|
// Convert to half2 and apply bias
|
||||||
|
// Note: reverse indexing is intentional because weights are permuted
|
||||||
|
frag_b[1] = __hmul2(*reinterpret_cast<const half2*>(&Out1), bias_reg);
|
||||||
|
frag_b[0] = __hmul2(*reinterpret_cast<const half2*>(&Out2), bias_reg);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__device__ inline void dequant<nv_bfloat162, vllm::kFE4M3fn.id()>(
|
||||||
|
int q, nv_bfloat162* frag_b) {
|
||||||
|
// Constants for FP8 (E4M3) and BF16 formats
|
||||||
|
constexpr int FP8_EXPONENT = 4, FP8_MANTISSA = 3, BF16_EXPONENT = 8;
|
||||||
|
constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT;
|
||||||
|
|
||||||
|
// Calculate MASK for extracting mantissa and exponent
|
||||||
|
constexpr int MASK1 = 0x80000000;
|
||||||
|
constexpr int MASK2 = MASK1 >> (FP8_EXPONENT + FP8_MANTISSA);
|
||||||
|
constexpr int MASK3 = MASK2 & 0x7fffffff;
|
||||||
|
constexpr int MASK = MASK3 | (MASK3 >> 16);
|
||||||
|
// Final MASK value: 0x7F007F00
|
||||||
|
|
||||||
|
// Extract and shift FP8 values to BF16 format
|
||||||
|
int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
|
||||||
|
int Out2 = ((q << 8) & 0x80008000) | (((q << 8) & MASK) >> RIGHT_SHIFT);
|
||||||
|
|
||||||
|
// Construct and apply exponent bias
|
||||||
|
constexpr int BIAS_OFFSET =
|
||||||
|
(1 << (BF16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1));
|
||||||
|
// Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent
|
||||||
|
// position
|
||||||
|
constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23;
|
||||||
|
const nv_bfloat162 bias_reg =
|
||||||
|
__float2bfloat162_rn(*reinterpret_cast<const float*>(&BIAS));
|
||||||
|
|
||||||
|
// Convert to bfloat162 and apply bias
|
||||||
|
// Note: reverse indexing is intentional because weights are permuted
|
||||||
|
frag_b[1] = __hmul2(*reinterpret_cast<const nv_bfloat162*>(&Out1), bias_reg);
|
||||||
|
frag_b[0] = __hmul2(*reinterpret_cast<const nv_bfloat162*>(&Out2), bias_reg);
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
} // namespace MARLIN_NAMESPACE_NAME
|
||||||
116
csrc/quantization/gptq_marlin/generate_kernels.py
Normal file
116
csrc/quantization/gptq_marlin/generate_kernels.py
Normal file
@ -0,0 +1,116 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
import glob
|
||||||
|
import itertools
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
|
||||||
|
import jinja2
|
||||||
|
|
||||||
|
# Banner and preamble written at the top of every generated kernel_*.cu file.
# The namespace opened here is closed by generate_new_kernels().
FILE_HEAD = """
// auto generated by generate_kernels.py
// clang-format off

#include "kernel.h"
#include "marlin_template.h"

namespace MARLIN_NAMESPACE_NAME {
""".strip()

# Jinja2 template for one explicit instantiation of the Marlin kernel; the
# template arguments are emitted in the order of the placeholders below.
TEMPLATE = ("template __global__ void Marlin<"
            "{{scalar_t}}, "
            "{{w_type_id}}, "
            "{{threads}}, "
            "{{thread_m_blocks}}, "
            "{{thread_n_blocks}}, "
            "{{thread_k_blocks}}, "
            "{{'true' if m_block_size_8 else 'false'}}, "
            "{{stages}}, "
            "{{group_blocks}}, "
            "{{'true' if is_zp_float else 'false'}}>"
            "( MARLIN_KERNEL_PARAMS );")

# int8 with zero point case (vllm::kU8) is also supported,
# we don't add it to reduce wheel size.
SCALAR_TYPES = ["vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn"]
# Each entry is (thread_k, thread_n, num_threads); thread_k and thread_n are
# divided by 16 to obtain the k/n block counts.
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128),
                  (128, 64, 128)]

# 0.5 selects the m_block_size_8 variant, which is instantiated with
# thread_m_blocks == 1.
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
# group_blocks:
#   = 0 : act order case
#   = -1 : channelwise quantization
#   > 0 : group_size=16*group_blocks
GROUP_BLOCKS = [0, -1, 2, 4, 8]
# Compute dtypes; "fp16" maps to `half`, anything else to `nv_bfloat16`.
DTYPES = ["fp16", "bf16"]
|
||||||
|
|
||||||
|
|
||||||
|
def remove_old_kernels():
    """Delete every previously generated ``kernel_*.cu`` next to this script.

    The directory is resolved to an absolute path first, so a ``__file__``
    without a directory component can never turn the pattern into
    ``/kernel_*.cu``. Files are removed with ``os.remove`` rather than by
    spawning ``rm``, which keeps the script usable where no ``rm`` binary
    exists.

    A file that has already disappeared is ignored (same as ``rm -f``); any
    other ``OSError`` propagates, so a stale kernel that cannot be deleted is
    reported instead of being left in place silently.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    for filename in glob.glob(os.path.join(script_dir, "kernel_*.cu")):
        try:
            os.remove(filename)
        except FileNotFoundError:
            # Removed concurrently between glob() and os.remove(); fine.
            pass
|
||||||
|
|
||||||
|
|
||||||
|
def generate_new_kernels():
    """Write one ``kernel_<dtype>_<scalar type>.cu`` per weight/compute type.

    For every combination of ``SCALAR_TYPES`` and ``DTYPES`` this collects the
    explicit ``Marlin<...>`` instantiations for all supported
    (group_blocks, m_blocks, thread config) combinations and writes them,
    preceded by ``FILE_HEAD`` and followed by the closing namespace brace,
    into the directory that contains this script.
    """
    # The template text never changes, so compile it once rather than once
    # per emitted instantiation.
    kernel_template = jinja2.Template(TEMPLATE)

    for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES):
        all_template_str_list = []

        for group_blocks, m_blocks, thread_configs in itertools.product(
                GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS):

            # act order case only support gptq-int4 and gptq-int8
            if group_blocks == 0 and scalar_type not in [
                    "vllm::kU4B8", "vllm::kU8B128"
            ]:
                continue
            if thread_configs[2] == 256:
                # for small batch (m_blocks == 1), we only need (128, 128, 256)
                # for large batch (m_blocks > 1), we only need (64, 256, 256)
                if m_blocks <= 1 and thread_configs[0] != 128:
                    continue
                if m_blocks > 1 and thread_configs[0] != 64:
                    continue

            # we only support channelwise quantization and group_size == 128
            # for fp8
            if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]:
                continue

            k_blocks = thread_configs[0] // 16
            n_blocks = thread_configs[1] // 16
            threads = thread_configs[2]

            c_dtype = "half" if dtype == "fp16" else "nv_bfloat16"

            is_zp_float_list = [False]
            if dtype == "fp16" and scalar_type == "vllm::kU4" and \
                    group_blocks == 4:
                # HQQ (is_zp_float = true) only supports
                # 4bit quantization and fp16
                is_zp_float_list.append(True)

            for is_zp_float in is_zp_float_list:
                template_str = kernel_template.render(
                    scalar_t=c_dtype,
                    w_type_id=scalar_type + ".id()",
                    threads=threads,
                    # m_blocks == 0.5 is the m_block_size_8 variant, which is
                    # instantiated with thread_m_blocks == 1.
                    thread_m_blocks=max(m_blocks, 1),
                    thread_n_blocks=n_blocks,
                    thread_k_blocks=k_blocks,
                    m_block_size_8=m_blocks == 0.5,
                    # Emitted verbatim as the identifier `pipe_stages`.
                    stages="pipe_stages",
                    group_blocks=group_blocks,
                    is_zp_float=is_zp_float,
                )

                all_template_str_list.append(template_str)

        file_content = FILE_HEAD + "\n\n"
        # The trailing "}" closes the namespace opened in FILE_HEAD.
        file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
        # scalar_type[6:] drops the leading "vllm::" prefix.
        filename = f"kernel_{dtype}_{scalar_type[6:].lower()}.cu"

        with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
            f.write(file_content)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Script entry point: clear the previously generated kernel sources
    # first, then write the current set.
    remove_old_kernels()
    generate_new_kernels()
|
||||||
File diff suppressed because it is too large
Load Diff
37
csrc/quantization/gptq_marlin/kernel.h
Normal file
37
csrc/quantization/gptq_marlin/kernel.h
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
|
||||||
|
#ifndef MARLIN_NAMESPACE_NAME
|
||||||
|
#define MARLIN_NAMESPACE_NAME marlin
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include "marlin.cuh"
|
||||||
|
#include "marlin_dtypes.cuh"
|
||||||
|
#include "core/scalar_type.hpp"
|
||||||
|
|
||||||
|
#define MARLIN_KERNEL_PARAMS \
|
||||||
|
const int4 *__restrict__ A, const int4 *__restrict__ B, \
|
||||||
|
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
|
||||||
|
const int4 *__restrict__ scales_ptr, const int4 *__restrict__ zp_ptr, \
|
||||||
|
const int *__restrict__ g_idx, int num_groups, int prob_m, int prob_n, \
|
||||||
|
int prob_k, int lda, int *locks, bool use_atomic_add, \
|
||||||
|
bool use_fp32_reduce, int max_shared_mem
|
||||||
|
|
||||||
|
namespace MARLIN_NAMESPACE_NAME {
|
||||||
|
template <typename scalar_t, // compute dtype, half or nv_bfloat16
|
||||||
|
const vllm::ScalarTypeId w_type_id, // weight ScalarType id
|
||||||
|
const int threads, // number of threads in a threadblock
|
||||||
|
const int thread_m_blocks, // number of 16x16 blocks in the m
|
||||||
|
// dimension (batchsize) of the
|
||||||
|
// threadblock
|
||||||
|
const int thread_n_blocks, // same for n dimension (output)
|
||||||
|
const int thread_k_blocks, // same for k dimension (reduction)
|
||||||
|
const bool m_block_size_8, // whether m_block_size == 8
|
||||||
|
// only works when thread_m_blocks == 1
|
||||||
|
const int stages, // number of stages for the async global->shared
|
||||||
|
// fetch pipeline
|
||||||
|
const int group_blocks, // number of consecutive 16x16 blocks
|
||||||
|
// with a separate quantization scale
|
||||||
|
const bool is_zp_float // is zero point of float16 type?
|
||||||
|
>
|
||||||
|
__global__ void Marlin(MARLIN_KERNEL_PARAMS);
|
||||||
|
|
||||||
|
}
|
||||||
1678
csrc/quantization/gptq_marlin/marlin_template.h
Normal file
1678
csrc/quantization/gptq_marlin/marlin_template.h
Normal file
File diff suppressed because it is too large
Load Diff
@ -291,12 +291,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
|
|
||||||
// gptq_marlin Optimized Quantized GEMM for GPTQ.
|
// gptq_marlin Optimized Quantized GEMM for GPTQ.
|
||||||
ops.def(
|
ops.def(
|
||||||
"gptq_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
|
"gptq_marlin_gemm(Tensor a, Tensor? c_or_none, Tensor b_q_weight, "
|
||||||
"Tensor b_zeros, Tensor g_idx, Tensor perm, Tensor workspace, "
|
"Tensor b_scales, Tensor? b_zeros_or_none, Tensor? g_idx_or_none, "
|
||||||
"int b_q_type, "
|
"Tensor? perm_or_none, Tensor workspace, int b_q_type, "
|
||||||
"SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, "
|
"SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, "
|
||||||
"bool has_zp, bool use_atomic_add, bool use_fp32_reduce, "
|
"bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) -> Tensor",
|
||||||
"bool is_zp_float) -> Tensor",
|
|
||||||
{stride_tag});
|
{stride_tag});
|
||||||
// conditionally compiled so impl registration is in source file
|
// conditionally compiled so impl registration is in source file
|
||||||
|
|
||||||
@ -341,14 +340,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
ops.def("ggml_moe_get_block_size", &ggml_moe_get_block_size);
|
ops.def("ggml_moe_get_block_size", &ggml_moe_get_block_size);
|
||||||
|
|
||||||
#ifndef USE_ROCM
|
#ifndef USE_ROCM
|
||||||
// fp8_marlin Optimized Quantized GEMM for FP8 weight-only.
|
|
||||||
ops.def(
|
|
||||||
"fp8_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
|
|
||||||
"Tensor! workspace, int num_bits, SymInt size_m, SymInt size_n, "
|
|
||||||
"SymInt size_k) -> Tensor",
|
|
||||||
{stride_tag});
|
|
||||||
// conditionally compiled so impl registration is in source file
|
|
||||||
|
|
||||||
// marlin_qqq_gemm for QQQ.
|
// marlin_qqq_gemm for QQQ.
|
||||||
ops.def(
|
ops.def(
|
||||||
"marlin_qqq_gemm(Tensor a, Tensor b_q_weight, "
|
"marlin_qqq_gemm(Tensor a, Tensor b_q_weight, "
|
||||||
|
|||||||
@ -11,19 +11,20 @@ from transformers import MixtralConfig
|
|||||||
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
|
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
|
||||||
|
|
||||||
import vllm.model_executor.layers.fused_moe # noqa
|
import vllm.model_executor.layers.fused_moe # noqa
|
||||||
from tests.kernels.utils import (opcheck, stack_and_dev, torch_moe,
|
from tests.kernels.utils import opcheck, stack_and_dev, torch_moe
|
||||||
torch_moe_single)
|
|
||||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
||||||
from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
|
from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
|
||||||
fused_moe as iterative_moe)
|
fused_moe as iterative_moe)
|
||||||
|
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||||
|
marlin_quant_fp8_torch)
|
||||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
|
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
|
||||||
awq_marlin_quantize, marlin_quantize)
|
awq_marlin_quantize, marlin_quantize)
|
||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
quantize_weights)
|
quantize_weights)
|
||||||
from vllm.model_executor.models.mixtral import MixtralMoE
|
from vllm.model_executor.models.mixtral import MixtralMoE
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.scalar_type import scalar_types
|
from vllm.scalar_type import ScalarType, scalar_types
|
||||||
|
|
||||||
NUM_EXPERTS = [8, 64]
|
NUM_EXPERTS = [8, 64]
|
||||||
EP_SIZE = [1, 4]
|
EP_SIZE = [1, 4]
|
||||||
@ -285,7 +286,7 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool,
|
|||||||
atol=mixtral_moe_tol[dtype])
|
atol=mixtral_moe_tol[dtype])
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("m", [1, 33, 123])
|
@pytest.mark.parametrize("m", [1, 123, 666])
|
||||||
@pytest.mark.parametrize("n", [128, 1024])
|
@pytest.mark.parametrize("n", [128, 1024])
|
||||||
@pytest.mark.parametrize("k", [256, 2048])
|
@pytest.mark.parametrize("k", [256, 2048])
|
||||||
@pytest.mark.parametrize("e", [4, 12])
|
@pytest.mark.parametrize("e", [4, 12])
|
||||||
@ -294,8 +295,10 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool,
|
|||||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||||
@pytest.mark.parametrize("group_size", [-1, 32, 128])
|
@pytest.mark.parametrize("group_size", [-1, 32, 128])
|
||||||
@pytest.mark.parametrize("act_order", [True, False])
|
@pytest.mark.parametrize("act_order", [True, False])
|
||||||
@pytest.mark.parametrize("num_bits", [4, 8])
|
@pytest.mark.parametrize("quant_type", [
|
||||||
@pytest.mark.parametrize("has_zp", [True, False])
|
scalar_types.uint4, scalar_types.uint8b128, scalar_types.uint4b8,
|
||||||
|
scalar_types.float8_e4m3fn
|
||||||
|
])
|
||||||
@pytest.mark.parametrize("is_k_full", [True, False])
|
@pytest.mark.parametrize("is_k_full", [True, False])
|
||||||
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
|
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
|
||||||
def test_fused_marlin_moe(
|
def test_fused_marlin_moe(
|
||||||
@ -308,14 +311,22 @@ def test_fused_marlin_moe(
|
|||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
group_size: int,
|
group_size: int,
|
||||||
act_order: bool,
|
act_order: bool,
|
||||||
num_bits: int,
|
quant_type: ScalarType,
|
||||||
has_zp: bool,
|
|
||||||
is_k_full: bool,
|
is_k_full: bool,
|
||||||
):
|
):
|
||||||
current_platform.seed_everything(7)
|
torch.cuda.manual_seed(0)
|
||||||
|
has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
|
||||||
|
|
||||||
|
if quant_type == scalar_types.float8_e4m3fn:
|
||||||
|
if group_size not in [-1, 128]:
|
||||||
|
return
|
||||||
|
if act_order:
|
||||||
|
return
|
||||||
|
|
||||||
# Filter act_order
|
# Filter act_order
|
||||||
if act_order:
|
if act_order:
|
||||||
|
if quant_type == scalar_types.float8_e4m3fn:
|
||||||
|
return
|
||||||
if group_size == -1:
|
if group_size == -1:
|
||||||
return
|
return
|
||||||
if group_size in (k, n):
|
if group_size in (k, n):
|
||||||
@ -326,17 +337,9 @@ def test_fused_marlin_moe(
|
|||||||
if not is_k_full:
|
if not is_k_full:
|
||||||
return
|
return
|
||||||
|
|
||||||
if has_zp:
|
|
||||||
# we don't build kernel for int8 with zero
|
|
||||||
if num_bits == 8:
|
|
||||||
return
|
|
||||||
quant_type = scalar_types.uint4 if num_bits == 4 else scalar_types.uint8
|
|
||||||
else:
|
|
||||||
quant_type = scalar_types.uint4b8 \
|
|
||||||
if num_bits == 4 else scalar_types.uint8b128
|
|
||||||
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
||||||
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
|
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 20
|
||||||
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
|
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 20
|
||||||
|
|
||||||
if ep_size > 1:
|
if ep_size > 1:
|
||||||
local_e = e // ep_size
|
local_e = e // ep_size
|
||||||
@ -364,17 +367,23 @@ def test_fused_marlin_moe(
|
|||||||
qweight1_l.append(qweight1)
|
qweight1_l.append(qweight1)
|
||||||
scales1_l.append(scales1)
|
scales1_l.append(scales1)
|
||||||
zeros1_l.append(zeros1)
|
zeros1_l.append(zeros1)
|
||||||
else:
|
elif quant_type != scalar_types.float8_e4m3fn:
|
||||||
test_perm = torch.randperm(k)
|
test_perm = torch.randperm(k)
|
||||||
quant_res = marlin_quantize(w1[i].transpose(1, 0), quant_type,
|
w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = \
|
||||||
group_size, act_order, test_perm)
|
marlin_quantize(w1[i].transpose(1, 0), quant_type,
|
||||||
w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = quant_res
|
group_size, act_order, test_perm)
|
||||||
|
|
||||||
w_ref1_l.append(w_ref1.T)
|
w_ref1_l.append(w_ref1.T)
|
||||||
qweight1_l.append(qweight1)
|
qweight1_l.append(qweight1)
|
||||||
scales1_l.append(scales1)
|
scales1_l.append(scales1)
|
||||||
g_idx1_l.append(g_idx1)
|
g_idx1_l.append(g_idx1)
|
||||||
sort_indices1_l.append(sort_indices1)
|
sort_indices1_l.append(sort_indices1)
|
||||||
|
else:
|
||||||
|
w_ref1, qweight1, scales1 = marlin_quant_fp8_torch(
|
||||||
|
w1[i], group_size)
|
||||||
|
w_ref1_l.append(w_ref1.T)
|
||||||
|
qweight1_l.append(qweight1)
|
||||||
|
scales1_l.append(scales1)
|
||||||
|
|
||||||
w_ref1 = stack_and_dev(w_ref1_l)
|
w_ref1 = stack_and_dev(w_ref1_l)
|
||||||
qweight1 = stack_and_dev(qweight1_l).contiguous()
|
qweight1 = stack_and_dev(qweight1_l).contiguous()
|
||||||
@ -399,17 +408,23 @@ def test_fused_marlin_moe(
|
|||||||
qweight2_l.append(qweight2)
|
qweight2_l.append(qweight2)
|
||||||
scales2_l.append(scales2)
|
scales2_l.append(scales2)
|
||||||
zeros2_l.append(zeros2)
|
zeros2_l.append(zeros2)
|
||||||
else:
|
elif quant_type != scalar_types.float8_e4m3fn:
|
||||||
test_perm = torch.randperm(n)
|
test_perm = torch.randperm(n)
|
||||||
quant_res = marlin_quantize(w2[i].transpose(1, 0), quant_type,
|
w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = \
|
||||||
group_size, act_order, test_perm)
|
marlin_quantize(w2[i].transpose(1, 0), quant_type,
|
||||||
w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = quant_res
|
group_size, act_order, test_perm)
|
||||||
|
|
||||||
w_ref2_l.append(w_ref2.T)
|
w_ref2_l.append(w_ref2.T)
|
||||||
qweight2_l.append(qweight2)
|
qweight2_l.append(qweight2)
|
||||||
scales2_l.append(scales2)
|
scales2_l.append(scales2)
|
||||||
g_idx2_l.append(g_idx2)
|
g_idx2_l.append(g_idx2)
|
||||||
sort_indices2_l.append(sort_indices2)
|
sort_indices2_l.append(sort_indices2)
|
||||||
|
else:
|
||||||
|
w_ref2, qweight2, scales2 = marlin_quant_fp8_torch(
|
||||||
|
w2[i], group_size)
|
||||||
|
w_ref2_l.append(w_ref2.T)
|
||||||
|
qweight2_l.append(qweight2)
|
||||||
|
scales2_l.append(scales2)
|
||||||
|
|
||||||
w_ref2 = stack_and_dev(w_ref2_l)
|
w_ref2 = stack_and_dev(w_ref2_l)
|
||||||
qweight2 = stack_and_dev(qweight2_l).contiguous()
|
qweight2 = stack_and_dev(qweight2_l).contiguous()
|
||||||
@ -442,102 +457,10 @@ def test_fused_marlin_moe(
|
|||||||
sort_indices2=sort_indices2,
|
sort_indices2=sort_indices2,
|
||||||
w1_zeros=zeros1,
|
w1_zeros=zeros1,
|
||||||
w2_zeros=zeros2,
|
w2_zeros=zeros2,
|
||||||
num_bits=num_bits,
|
quant_type_id=quant_type.id,
|
||||||
is_k_full=is_k_full)
|
is_k_full=is_k_full)
|
||||||
|
|
||||||
torch.testing.assert_close(marlin_output, torch_output, atol=2e-2, rtol=0)
|
torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip("This test is here for the sake of debugging, "
|
|
||||||
"don't run it in automated tests.")
|
|
||||||
@pytest.mark.parametrize("m", [1, 33, 123])
|
|
||||||
@pytest.mark.parametrize("n", [128, 1024])
|
|
||||||
@pytest.mark.parametrize("k", [256, 2048])
|
|
||||||
@pytest.mark.parametrize("e", [4, 12])
|
|
||||||
@pytest.mark.parametrize("topk", [2, 3])
|
|
||||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
|
||||||
@pytest.mark.parametrize("group_size", [-1, 32, 128])
|
|
||||||
@pytest.mark.parametrize("act_order", [True, False])
|
|
||||||
@pytest.mark.parametrize("num_bits", [4, 8])
|
|
||||||
@pytest.mark.parametrize("has_zp", [True, False])
|
|
||||||
@pytest.mark.parametrize("is_k_full", [True, False])
|
|
||||||
def test_single_marlin_moe_multiply(m: int, n: int, k: int, e: int, topk: int,
|
|
||||||
dtype: torch.dtype, group_size: int,
|
|
||||||
act_order: bool, num_bits: int,
|
|
||||||
has_zp: bool, is_k_full: bool):
|
|
||||||
# Filter act_order
|
|
||||||
if act_order:
|
|
||||||
if group_size == -1:
|
|
||||||
return
|
|
||||||
if group_size in (k, n):
|
|
||||||
return
|
|
||||||
if has_zp:
|
|
||||||
return
|
|
||||||
else:
|
|
||||||
if not is_k_full:
|
|
||||||
return
|
|
||||||
|
|
||||||
if has_zp:
|
|
||||||
quant_type = scalar_types.uint4 if num_bits == 4 else scalar_types.uint8
|
|
||||||
else:
|
|
||||||
quant_type = scalar_types.uint4b8 \
|
|
||||||
if num_bits == 4 else scalar_types.uint8b128
|
|
||||||
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
|
||||||
w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10
|
|
||||||
|
|
||||||
w_ref_l = []
|
|
||||||
qweight_l = []
|
|
||||||
scales_l = []
|
|
||||||
zeros_l = []
|
|
||||||
g_idx_l = []
|
|
||||||
sort_indices_l = []
|
|
||||||
|
|
||||||
for i in range(w.shape[0]):
|
|
||||||
if has_zp:
|
|
||||||
w_ref, qweight, scales, zeros = awq_marlin_quantize(
|
|
||||||
w[i].transpose(1, 0), quant_type, group_size)
|
|
||||||
|
|
||||||
w_ref_l.append(w_ref.T)
|
|
||||||
qweight_l.append(qweight)
|
|
||||||
scales_l.append(scales)
|
|
||||||
zeros_l.append(zeros)
|
|
||||||
else:
|
|
||||||
test_perm = torch.randperm(k)
|
|
||||||
w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize(
|
|
||||||
w[i].transpose(1, 0), quant_type, group_size, act_order,
|
|
||||||
test_perm)
|
|
||||||
|
|
||||||
w_ref_l.append(w_ref.T)
|
|
||||||
qweight_l.append(qweight)
|
|
||||||
scales_l.append(scales)
|
|
||||||
g_idx_l.append(g_idx)
|
|
||||||
sort_indices_l.append(sort_indices)
|
|
||||||
|
|
||||||
w_ref = stack_and_dev(w_ref_l)
|
|
||||||
qweight = stack_and_dev(qweight_l).contiguous()
|
|
||||||
scales = stack_and_dev(scales_l)
|
|
||||||
g_idx = stack_and_dev(g_idx_l) if g_idx_l else None
|
|
||||||
zeros = stack_and_dev(zeros_l) if zeros_l else None
|
|
||||||
sort_indices = stack_and_dev(sort_indices_l) if sort_indices_l else None
|
|
||||||
|
|
||||||
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
|
||||||
marlin_output = torch.ops.vllm.single_marlin_moe(
|
|
||||||
a,
|
|
||||||
qweight,
|
|
||||||
scales,
|
|
||||||
score,
|
|
||||||
topk,
|
|
||||||
renormalize=False,
|
|
||||||
g_idx=g_idx,
|
|
||||||
sort_indices=sort_indices,
|
|
||||||
w_zeros=zeros,
|
|
||||||
num_bits=num_bits,
|
|
||||||
is_k_full=is_k_full,
|
|
||||||
)
|
|
||||||
|
|
||||||
torch_output = torch_moe_single(a, w_ref, score, topk)
|
|
||||||
|
|
||||||
torch.testing.assert_close(marlin_output, torch_output, atol=2e-2, rtol=0)
|
|
||||||
|
|
||||||
|
|
||||||
def test_moe_align_block_size_opcheck():
|
def test_moe_align_block_size_opcheck():
|
||||||
|
|||||||
@ -1,164 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
"""Test AWQ with fused MoE Marlin kernels.
|
|
||||||
|
|
||||||
Run `pytest tests/kernels/test_awq_marlin.py`.
|
|
||||||
"""
|
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
|
|
||||||
import vllm.model_executor.layers.fused_moe # noqa
|
|
||||||
from tests.kernels.utils import (compute_max_diff, stack_and_dev, torch_moe,
|
|
||||||
torch_moe_single)
|
|
||||||
from vllm import _custom_ops as ops
|
|
||||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
|
||||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
|
|
||||||
awq_marlin_quantize)
|
|
||||||
from vllm.scalar_type import scalar_types
|
|
||||||
|
|
||||||
NUM_EXPERTS = [8, 64]
|
|
||||||
TOP_KS = [2, 6]
|
|
||||||
GROUP_SIZES = [-1, 32, 128]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("m", [1, 33, 64, 222])
|
|
||||||
@pytest.mark.parametrize("n", [128, 2048])
|
|
||||||
@pytest.mark.parametrize("k", [128, 1024])
|
|
||||||
@pytest.mark.parametrize("e", NUM_EXPERTS)
|
|
||||||
@pytest.mark.parametrize("topk", TOP_KS)
|
|
||||||
@pytest.mark.parametrize("group_size", GROUP_SIZES)
|
|
||||||
@pytest.mark.skipif(not (ops.supports_moe_ops
|
|
||||||
and hasattr(torch.ops._moe_C, "marlin_gemm_moe")),
|
|
||||||
reason="Marlin is not supported on this GPU type.")
|
|
||||||
def test_fused_marlin_moe_awq(
|
|
||||||
m: int,
|
|
||||||
n: int,
|
|
||||||
k: int,
|
|
||||||
e: int,
|
|
||||||
topk: int,
|
|
||||||
group_size: int,
|
|
||||||
):
|
|
||||||
torch.manual_seed(7)
|
|
||||||
|
|
||||||
num_bits = 4
|
|
||||||
quant_type = scalar_types.uint4
|
|
||||||
dtype = torch.float16
|
|
||||||
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
|
||||||
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
|
|
||||||
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
|
|
||||||
|
|
||||||
w_ref1_l = []
|
|
||||||
qweights1_l = []
|
|
||||||
scales1_l = []
|
|
||||||
zp1_l = []
|
|
||||||
|
|
||||||
for i in range(w1.shape[0]):
|
|
||||||
w_ref1, qweight1, scales1, zp1 = awq_marlin_quantize(
|
|
||||||
w1[i].transpose(1, 0), quant_type, group_size)
|
|
||||||
w_ref1_l.append(w_ref1)
|
|
||||||
qweights1_l.append(qweight1)
|
|
||||||
scales1_l.append(scales1)
|
|
||||||
zp1_l.append(zp1)
|
|
||||||
|
|
||||||
w_ref1 = stack_and_dev(w_ref1_l)
|
|
||||||
qweight1 = stack_and_dev(qweights1_l).contiguous()
|
|
||||||
scales1 = stack_and_dev(scales1_l)
|
|
||||||
zp1 = stack_and_dev(zp1_l)
|
|
||||||
|
|
||||||
w_ref2_l = []
|
|
||||||
qweights2_l = []
|
|
||||||
scales2_l = []
|
|
||||||
zp2_l = []
|
|
||||||
|
|
||||||
for i in range(w2.shape[0]):
|
|
||||||
w_ref2, qweight2, scales2, zp2 = awq_marlin_quantize(
|
|
||||||
w2[i].transpose(1, 0), quant_type, group_size)
|
|
||||||
w_ref2_l.append(w_ref2)
|
|
||||||
qweights2_l.append(qweight2)
|
|
||||||
scales2_l.append(scales2)
|
|
||||||
zp2_l.append(zp2)
|
|
||||||
|
|
||||||
w_ref2 = stack_and_dev(w_ref2_l)
|
|
||||||
qweight2 = stack_and_dev(qweights2_l).contiguous()
|
|
||||||
scales2 = stack_and_dev(scales2_l)
|
|
||||||
zp2 = stack_and_dev(zp2_l)
|
|
||||||
|
|
||||||
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
|
||||||
|
|
||||||
topk_weights, topk_ids, token_expert_indices = fused_topk(
|
|
||||||
a, score, topk, False)
|
|
||||||
marlin_output = torch.ops.vllm.fused_marlin_moe(
|
|
||||||
a,
|
|
||||||
qweight1,
|
|
||||||
qweight2,
|
|
||||||
scales1,
|
|
||||||
scales2,
|
|
||||||
score,
|
|
||||||
topk_weights,
|
|
||||||
topk_ids,
|
|
||||||
w1_zeros=zp1,
|
|
||||||
w2_zeros=zp2,
|
|
||||||
num_bits=num_bits,
|
|
||||||
)
|
|
||||||
|
|
||||||
torch_output = torch_moe(a, w_ref1.transpose(1, 2), w_ref2.transpose(1, 2),
|
|
||||||
score, topk, None)
|
|
||||||
|
|
||||||
assert compute_max_diff(marlin_output, torch_output) < 4e-2
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip("This test is here for the sake of debugging, "
|
|
||||||
"don't run it in automated tests.")
|
|
||||||
@pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
|
|
||||||
@pytest.mark.parametrize("n", [128, 2048, 256, 1024])
|
|
||||||
@pytest.mark.parametrize("k", [128, 1024, 512])
|
|
||||||
@pytest.mark.parametrize("e", [8, 64])
|
|
||||||
@pytest.mark.parametrize("topk", [2, 6])
|
|
||||||
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
|
|
||||||
def test_single_marlin_moe_multiply_awq(
|
|
||||||
m: int,
|
|
||||||
n: int,
|
|
||||||
k: int,
|
|
||||||
e: int,
|
|
||||||
topk: int,
|
|
||||||
group_size: int,
|
|
||||||
):
|
|
||||||
torch.manual_seed(7)
|
|
||||||
|
|
||||||
num_bits = 4
|
|
||||||
quant_type = scalar_types.uint4
|
|
||||||
dtype = torch.float16
|
|
||||||
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
|
||||||
w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10
|
|
||||||
|
|
||||||
w_ref_l = []
|
|
||||||
qweights_l = []
|
|
||||||
scales_l = []
|
|
||||||
zp_l = []
|
|
||||||
|
|
||||||
for i in range(w.shape[0]):
|
|
||||||
w_ref, qweight, scales, zp = awq_marlin_quantize(
|
|
||||||
w[i].transpose(1, 0), quant_type, group_size)
|
|
||||||
w_ref_l.append(w_ref)
|
|
||||||
qweights_l.append(qweight)
|
|
||||||
scales_l.append(scales)
|
|
||||||
zp_l.append(zp)
|
|
||||||
|
|
||||||
w_ref = stack_and_dev(w_ref_l)
|
|
||||||
qweight = stack_and_dev(qweights_l).contiguous()
|
|
||||||
scales = stack_and_dev(scales_l).contiguous()
|
|
||||||
zp = stack_and_dev(zp_l).contiguous()
|
|
||||||
|
|
||||||
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
|
||||||
|
|
||||||
marlin_output = torch.ops.vllm.single_marlin_moe(a,
|
|
||||||
qweight,
|
|
||||||
scales,
|
|
||||||
score,
|
|
||||||
topk,
|
|
||||||
renormalize=False,
|
|
||||||
w_zeros=zp,
|
|
||||||
num_bits=num_bits)
|
|
||||||
|
|
||||||
torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk)
|
|
||||||
|
|
||||||
assert compute_max_diff(marlin_output, torch_output) < 1e-2
|
|
||||||
@ -18,9 +18,10 @@ from vllm.model_executor.layers.quantization.qqq import (
|
|||||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||||
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
|
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
|
||||||
MARLIN_SUPPORTED_GROUP_SIZES, marlin_make_empty_g_idx,
|
MARLIN_SUPPORTED_GROUP_SIZES, marlin_make_empty_g_idx,
|
||||||
marlin_permute_scales, query_marlin_supported_quant_types)
|
marlin_make_workspace_new, marlin_permute_scales,
|
||||||
|
query_marlin_supported_quant_types)
|
||||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||||
pack_fp8_to_int32)
|
marlin_quant_fp8_torch)
|
||||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
|
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
|
||||||
MarlinWorkspace, awq_marlin_quantize, get_weight_perm, marlin_quantize,
|
MarlinWorkspace, awq_marlin_quantize, get_weight_perm, marlin_quantize,
|
||||||
marlin_weights)
|
marlin_weights)
|
||||||
@ -73,7 +74,7 @@ def rand_data(shape, dtype=torch.float16):
|
|||||||
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
|
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
|
||||||
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
|
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
|
||||||
@pytest.mark.parametrize("quant_type",
|
@pytest.mark.parametrize("quant_type",
|
||||||
query_marlin_supported_quant_types(False))
|
query_marlin_supported_quant_types(False, False))
|
||||||
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
|
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
|
||||||
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
|
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
|
||||||
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
|
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
|
||||||
@ -138,7 +139,7 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
|
|||||||
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
|
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
|
||||||
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
|
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
|
||||||
@pytest.mark.parametrize("quant_type",
|
@pytest.mark.parametrize("quant_type",
|
||||||
query_marlin_supported_quant_types(False))
|
query_marlin_supported_quant_types(True))
|
||||||
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
|
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
|
||||||
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
|
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
|
||||||
def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
|
def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
|
||||||
@ -220,38 +221,50 @@ def test_gptq_marlin_gemm(
|
|||||||
if group_size == size_k:
|
if group_size == size_k:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if size_k % group_size != 0:
|
||||||
|
return
|
||||||
|
|
||||||
a_input = rand_data((size_m, size_k))
|
a_input = rand_data((size_m, size_k))
|
||||||
b_weight = rand_data((size_k, size_n))
|
b_weight = rand_data((size_k, size_n))
|
||||||
|
|
||||||
w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
|
if quant_type == scalar_types.float8_e4m3fn:
|
||||||
b_weight, quant_type, group_size, act_order)
|
if group_size not in [-1, 128]:
|
||||||
|
return
|
||||||
|
if act_order:
|
||||||
|
return
|
||||||
|
w_ref, marlin_q_w, marlin_s = marlin_quant_fp8_torch(
|
||||||
|
b_weight.T, group_size)
|
||||||
|
g_idx = None
|
||||||
|
sort_indices = None
|
||||||
|
else:
|
||||||
|
w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
|
||||||
|
b_weight, quant_type, group_size, act_order)
|
||||||
|
|
||||||
marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
|
marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
|
||||||
|
|
||||||
workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
|
workspace = marlin_make_workspace_new(w_ref.device)
|
||||||
GPTQ_MARLIN_MAX_PARALLEL)
|
|
||||||
|
|
||||||
opcheck(torch.ops._C.gptq_marlin_gemm,
|
opcheck(
|
||||||
(a_input, marlin_q_w, marlin_s, marlin_zp, g_idx, sort_indices,
|
torch.ops._C.gptq_marlin_gemm,
|
||||||
workspace.scratch, quant_type.id, a_input.shape[0],
|
(a_input, None, marlin_q_w, marlin_s, marlin_zp, g_idx, sort_indices,
|
||||||
b_weight.shape[1], a_input.shape[1], is_k_full, False,
|
workspace, quant_type.id, a_input.shape[0], b_weight.shape[1],
|
||||||
use_atomic_add, use_fp32_reduce, False),
|
a_input.shape[1], is_k_full, use_atomic_add, use_fp32_reduce, False),
|
||||||
test_utils=DEFAULT_OPCHECK_TEST_UTILS)
|
test_utils=DEFAULT_OPCHECK_TEST_UTILS)
|
||||||
|
|
||||||
output = ops.gptq_marlin_gemm(
|
output = ops.gptq_marlin_gemm(
|
||||||
a_input,
|
a_input,
|
||||||
|
None,
|
||||||
marlin_q_w,
|
marlin_q_w,
|
||||||
marlin_s,
|
marlin_s,
|
||||||
marlin_zp,
|
marlin_zp,
|
||||||
g_idx,
|
g_idx,
|
||||||
sort_indices,
|
sort_indices,
|
||||||
workspace.scratch,
|
workspace,
|
||||||
quant_type,
|
quant_type,
|
||||||
a_input.shape[0],
|
a_input.shape[0],
|
||||||
b_weight.shape[1],
|
b_weight.shape[1],
|
||||||
a_input.shape[1],
|
a_input.shape[1],
|
||||||
is_k_full=is_k_full,
|
is_k_full=is_k_full,
|
||||||
has_zp=False,
|
|
||||||
use_atomic_add=use_atomic_add,
|
use_atomic_add=use_atomic_add,
|
||||||
use_fp32_reduce=use_fp32_reduce,
|
use_fp32_reduce=use_fp32_reduce,
|
||||||
is_zp_float=False,
|
is_zp_float=False,
|
||||||
@ -326,80 +339,6 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size,
|
|||||||
assert max_diff < 0.04
|
assert max_diff < 0.04
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not is_quant_method_supported("fp8"),
|
|
||||||
reason="Marlin is not supported on this GPU type.")
|
|
||||||
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
|
|
||||||
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
|
|
||||||
@pytest.mark.parametrize("num_bits", [8])
|
|
||||||
@pytest.mark.parametrize("group_size", [-1])
|
|
||||||
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
|
|
||||||
@pytest.mark.parametrize("dtype", DTYPES)
|
|
||||||
def test_fp8_marlin_gemm(
|
|
||||||
k_chunk,
|
|
||||||
n_chunk,
|
|
||||||
num_bits,
|
|
||||||
group_size,
|
|
||||||
mnk_factors,
|
|
||||||
dtype,
|
|
||||||
):
|
|
||||||
m_factor, n_factor, k_factor = mnk_factors
|
|
||||||
|
|
||||||
size_m = m_factor
|
|
||||||
size_k = k_chunk * k_factor
|
|
||||||
size_n = n_chunk * n_factor
|
|
||||||
|
|
||||||
a_input = rand_data((size_m, size_k), dtype=dtype)
|
|
||||||
b_weight = rand_data((size_k, size_n), dtype=dtype)
|
|
||||||
|
|
||||||
# WEIGHTS
|
|
||||||
fp8_weight, weight_scale = ops.scaled_fp8_quant(b_weight, scale=None)
|
|
||||||
# Repack weights to gptq format (packed int32 elements)
|
|
||||||
packed_gptq_qweight = pack_fp8_to_int32(fp8_weight)
|
|
||||||
# Repack weights to marlin format
|
|
||||||
marlin_qweight = ops.gptq_marlin_repack(
|
|
||||||
b_q_weight=packed_gptq_qweight,
|
|
||||||
perm=torch.empty(0, dtype=torch.int, device="cuda"),
|
|
||||||
size_k=size_k,
|
|
||||||
size_n=size_n,
|
|
||||||
num_bits=8,
|
|
||||||
)
|
|
||||||
|
|
||||||
# WEIGHT SCALES
|
|
||||||
# Currently Marlin doesn't support per-tensor scales, so we
|
|
||||||
# expand it to channelwise
|
|
||||||
scales = weight_scale.repeat(1, size_n).to(a_input.dtype).to("cuda")
|
|
||||||
# Permute scales
|
|
||||||
marlin_scales = marlin_permute_scales(s=scales,
|
|
||||||
size_k=size_k,
|
|
||||||
size_n=size_n,
|
|
||||||
group_size=-1)
|
|
||||||
|
|
||||||
workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
|
|
||||||
GPTQ_MARLIN_MAX_PARALLEL)
|
|
||||||
|
|
||||||
opcheck(torch.ops._C.fp8_marlin_gemm,
|
|
||||||
(a_input, marlin_qweight, marlin_scales, workspace.scratch,
|
|
||||||
num_bits, a_input.shape[0], b_weight.shape[1], a_input.shape[1]))
|
|
||||||
|
|
||||||
output = ops.fp8_marlin_gemm(
|
|
||||||
a=a_input,
|
|
||||||
b_q_weight=marlin_qweight,
|
|
||||||
b_scales=marlin_scales,
|
|
||||||
workspace=workspace.scratch,
|
|
||||||
num_bits=num_bits,
|
|
||||||
size_m=a_input.shape[0],
|
|
||||||
size_n=b_weight.shape[1],
|
|
||||||
size_k=a_input.shape[1],
|
|
||||||
)
|
|
||||||
output_ref = torch.matmul(a_input, b_weight)
|
|
||||||
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
|
|
||||||
max_diff = compute_max_diff(output, output_ref)
|
|
||||||
|
|
||||||
assert max_diff < 0.04
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
|
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
|
||||||
reason="Marlin is not supported on this GPU type.")
|
reason="Marlin is not supported on this GPU type.")
|
||||||
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
|
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
|
||||||
@ -432,25 +371,23 @@ def test_awq_marlin_gemm(
|
|||||||
g_idx = torch.empty(0, dtype=torch.int, device=marlin_q_w.device)
|
g_idx = torch.empty(0, dtype=torch.int, device=marlin_q_w.device)
|
||||||
sort_indices = torch.empty(0, dtype=torch.int, device=marlin_q_w.device)
|
sort_indices = torch.empty(0, dtype=torch.int, device=marlin_q_w.device)
|
||||||
is_k_full = True
|
is_k_full = True
|
||||||
has_zp = True
|
|
||||||
|
|
||||||
workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
|
workspace = marlin_make_workspace_new(a_input.device)
|
||||||
GPTQ_MARLIN_MAX_PARALLEL)
|
|
||||||
|
|
||||||
output = ops.gptq_marlin_gemm(
|
output = ops.gptq_marlin_gemm(
|
||||||
a_input,
|
a_input,
|
||||||
|
None,
|
||||||
marlin_q_w,
|
marlin_q_w,
|
||||||
marlin_s,
|
marlin_s,
|
||||||
marlin_zp,
|
marlin_zp,
|
||||||
g_idx,
|
g_idx,
|
||||||
sort_indices,
|
sort_indices,
|
||||||
workspace.scratch,
|
workspace,
|
||||||
quant_type,
|
quant_type,
|
||||||
a_input.shape[0],
|
a_input.shape[0],
|
||||||
b_weight.shape[1],
|
b_weight.shape[1],
|
||||||
a_input.shape[1],
|
a_input.shape[1],
|
||||||
is_k_full=is_k_full,
|
is_k_full=is_k_full,
|
||||||
has_zp=has_zp,
|
|
||||||
use_fp32_reduce=use_fp32_reduce,
|
use_fp32_reduce=use_fp32_reduce,
|
||||||
is_zp_float=False,
|
is_zp_float=False,
|
||||||
)
|
)
|
||||||
@ -508,23 +445,22 @@ def test_hqq_marlin_gemm(
|
|||||||
g_idx = marlin_make_empty_g_idx(dev)
|
g_idx = marlin_make_empty_g_idx(dev)
|
||||||
g_idx_sort_indices = marlin_make_empty_g_idx(dev)
|
g_idx_sort_indices = marlin_make_empty_g_idx(dev)
|
||||||
|
|
||||||
workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
|
workspace = marlin_make_workspace_new(b_weight.device)
|
||||||
GPTQ_MARLIN_MAX_PARALLEL)
|
|
||||||
|
|
||||||
output = ops.gptq_marlin_gemm(
|
output = ops.gptq_marlin_gemm(
|
||||||
a_input,
|
a_input,
|
||||||
|
None,
|
||||||
marlin_w_q,
|
marlin_w_q,
|
||||||
marlin_s,
|
marlin_s,
|
||||||
marlin_zp,
|
marlin_zp,
|
||||||
g_idx,
|
g_idx,
|
||||||
g_idx_sort_indices,
|
g_idx_sort_indices,
|
||||||
workspace.scratch,
|
workspace,
|
||||||
quant_type,
|
quant_type,
|
||||||
a_input.shape[0],
|
a_input.shape[0],
|
||||||
b_weight.shape[0],
|
b_weight.shape[0],
|
||||||
a_input.shape[1],
|
a_input.shape[1],
|
||||||
is_k_full=True,
|
is_k_full=True,
|
||||||
has_zp=True,
|
|
||||||
use_fp32_reduce=use_fp32_reduce,
|
use_fp32_reduce=use_fp32_reduce,
|
||||||
is_zp_float=True,
|
is_zp_float=True,
|
||||||
)
|
)
|
||||||
@ -621,23 +557,22 @@ def test_marlin_gemm_subset_input():
|
|||||||
b_weight, quant_type, group_size, False)
|
b_weight, quant_type, group_size, False)
|
||||||
|
|
||||||
marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
|
marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
|
||||||
workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
|
workspace = marlin_make_workspace_new(a_input.device)
|
||||||
GPTQ_MARLIN_MAX_PARALLEL)
|
|
||||||
|
|
||||||
output = ops.gptq_marlin_gemm(
|
output = ops.gptq_marlin_gemm(
|
||||||
a_input,
|
a_input,
|
||||||
|
None,
|
||||||
marlin_q_w,
|
marlin_q_w,
|
||||||
marlin_s,
|
marlin_s,
|
||||||
marlin_zp,
|
marlin_zp,
|
||||||
g_idx,
|
g_idx,
|
||||||
sort_indices,
|
sort_indices,
|
||||||
workspace.scratch,
|
workspace,
|
||||||
quant_type,
|
quant_type,
|
||||||
a_input.shape[0],
|
a_input.shape[0],
|
||||||
b_weight.shape[1],
|
b_weight.shape[1],
|
||||||
a_input.shape[1],
|
a_input.shape[1],
|
||||||
is_k_full=True,
|
is_k_full=True,
|
||||||
has_zp=False,
|
|
||||||
use_atomic_add=False,
|
use_atomic_add=False,
|
||||||
use_fp32_reduce=True,
|
use_fp32_reduce=True,
|
||||||
is_zp_float=False,
|
is_zp_float=False,
|
||||||
|
|||||||
@ -325,18 +325,18 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
|
|||||||
|
|
||||||
@register_fake("_C::gptq_marlin_gemm")
|
@register_fake("_C::gptq_marlin_gemm")
|
||||||
def _gptq_marlin_gemm_fake(a: torch.Tensor,
|
def _gptq_marlin_gemm_fake(a: torch.Tensor,
|
||||||
|
c: Optional[torch.Tensor],
|
||||||
b_q_weight: torch.Tensor,
|
b_q_weight: torch.Tensor,
|
||||||
b_scales: torch.Tensor,
|
b_scales: torch.Tensor,
|
||||||
b_zeros: torch.Tensor,
|
b_zeros: Optional[torch.Tensor],
|
||||||
g_idx: torch.Tensor,
|
g_idx: Optional[torch.Tensor],
|
||||||
perm: torch.Tensor,
|
perm: Optional[torch.Tensor],
|
||||||
workspace: torch.Tensor,
|
workspace: torch.Tensor,
|
||||||
b_q_type: ScalarType,
|
b_q_type_id: int,
|
||||||
size_m: torch.SymInt,
|
size_m: torch.SymInt,
|
||||||
size_n: torch.SymInt,
|
size_n: torch.SymInt,
|
||||||
size_k: torch.SymInt,
|
size_k: torch.SymInt,
|
||||||
is_k_full: bool,
|
is_k_full: bool = True,
|
||||||
has_zp: bool = False,
|
|
||||||
use_atomic_add: bool = False,
|
use_atomic_add: bool = False,
|
||||||
use_fp32_reduce: bool = False,
|
use_fp32_reduce: bool = False,
|
||||||
is_zp_float: bool = False) -> torch.Tensor:
|
is_zp_float: bool = False) -> torch.Tensor:
|
||||||
@ -407,14 +407,6 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
|
|||||||
dtype=codebooks.dtype,
|
dtype=codebooks.dtype,
|
||||||
device=codebooks.device)
|
device=codebooks.device)
|
||||||
|
|
||||||
@register_fake("_C::fp8_marlin_gemm")
|
|
||||||
def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
|
|
||||||
b_scales: torch.Tensor, workspace: torch.Tensor,
|
|
||||||
num_bits: int, size_m: torch.SymInt,
|
|
||||||
size_n: torch.SymInt,
|
|
||||||
size_k: torch.SymInt) -> torch.Tensor:
|
|
||||||
return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device)
|
|
||||||
|
|
||||||
@register_fake("_C::machete_mm")
|
@register_fake("_C::machete_mm")
|
||||||
def machete_mm_fake(
|
def machete_mm_fake(
|
||||||
a: torch.Tensor,
|
a: torch.Tensor,
|
||||||
@ -815,35 +807,26 @@ def awq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
|
|||||||
|
|
||||||
|
|
||||||
def gptq_marlin_gemm(a: torch.Tensor,
|
def gptq_marlin_gemm(a: torch.Tensor,
|
||||||
|
c: Optional[torch.Tensor],
|
||||||
b_q_weight: torch.Tensor,
|
b_q_weight: torch.Tensor,
|
||||||
b_scales: torch.Tensor,
|
b_scales: torch.Tensor,
|
||||||
b_zeros: torch.Tensor,
|
b_zeros: Optional[torch.Tensor],
|
||||||
g_idx: torch.Tensor,
|
g_idx: Optional[torch.Tensor],
|
||||||
perm: torch.Tensor,
|
perm: Optional[torch.Tensor],
|
||||||
workspace: torch.Tensor,
|
workspace: torch.Tensor,
|
||||||
b_q_type: ScalarType,
|
b_q_type: ScalarType,
|
||||||
size_m: int,
|
size_m: int,
|
||||||
size_n: int,
|
size_n: int,
|
||||||
size_k: int,
|
size_k: int,
|
||||||
is_k_full: bool,
|
is_k_full: bool = True,
|
||||||
has_zp: bool = False,
|
|
||||||
use_atomic_add: bool = False,
|
use_atomic_add: bool = False,
|
||||||
use_fp32_reduce: bool = False,
|
use_fp32_reduce: bool = False,
|
||||||
is_zp_float: bool = False) -> torch.Tensor:
|
is_zp_float: bool = False) -> torch.Tensor:
|
||||||
return torch.ops._C.gptq_marlin_gemm(a, b_q_weight, b_scales, b_zeros,
|
return torch.ops._C.gptq_marlin_gemm(a, c, b_q_weight, b_scales, b_zeros,
|
||||||
g_idx, perm, workspace, b_q_type.id,
|
g_idx, perm, workspace, b_q_type.id,
|
||||||
size_m, size_n, size_k, is_k_full,
|
size_m, size_n, size_k, is_k_full,
|
||||||
has_zp, use_atomic_add,
|
use_atomic_add, use_fp32_reduce,
|
||||||
use_fp32_reduce, is_zp_float)
|
is_zp_float)
|
||||||
|
|
||||||
|
|
||||||
# fp8 marlin
|
|
||||||
def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
|
|
||||||
b_scales: torch.Tensor, workspace: torch.Tensor,
|
|
||||||
num_bits: int, size_m: int, size_n: int,
|
|
||||||
size_k: int) -> torch.Tensor:
|
|
||||||
return torch.ops._C.fp8_marlin_gemm(a, b_q_weight, b_scales, workspace,
|
|
||||||
num_bits, size_m, size_n, size_k)
|
|
||||||
|
|
||||||
|
|
||||||
# machete
|
# machete
|
||||||
|
|||||||
@ -7,163 +7,13 @@ import torch
|
|||||||
|
|
||||||
import vllm._custom_ops as ops
|
import vllm._custom_ops as ops
|
||||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||||
fused_topk, moe_align_block_size, try_get_optimal_moe_config)
|
moe_align_block_size, try_get_optimal_moe_config)
|
||||||
from vllm.scalar_type import scalar_types
|
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||||
|
marlin_make_workspace_new, maybe_warn_marlin_atomic_add)
|
||||||
|
from vllm.scalar_type import ScalarType, scalar_types
|
||||||
from vllm.utils import direct_register_custom_op
|
from vllm.utils import direct_register_custom_op
|
||||||
|
|
||||||
|
|
||||||
def get_scalar_type(num_bits: int, has_zp: bool):
|
|
||||||
if has_zp:
|
|
||||||
return scalar_types.uint4 if num_bits == 4 else scalar_types.uint8
|
|
||||||
else:
|
|
||||||
return scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128
|
|
||||||
|
|
||||||
|
|
||||||
def single_marlin_moe(
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
w: torch.Tensor,
|
|
||||||
scales: torch.Tensor,
|
|
||||||
gating_output: torch.Tensor,
|
|
||||||
topk: int,
|
|
||||||
renormalize: bool,
|
|
||||||
global_num_experts: int = -1,
|
|
||||||
expert_map: Optional[torch.Tensor] = None,
|
|
||||||
g_idx: Optional[torch.Tensor] = None,
|
|
||||||
sort_indices: Optional[torch.Tensor] = None,
|
|
||||||
w_zeros: Optional[torch.Tensor] = None,
|
|
||||||
workspace: Optional[torch.Tensor] = None,
|
|
||||||
num_bits: int = 8,
|
|
||||||
is_k_full: bool = True,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
This function computes the multiplication of hidden_states with expert
|
|
||||||
weights used in Marlin MoE, using weights w and top-k gating mechanism.
|
|
||||||
Its purpose is testing and debugging the fused MoE kernel.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
- hidden_states (torch.Tensor): The input tensor to the Marlin Mul.
|
|
||||||
- w (torch.Tensor): The set of expert weights.
|
|
||||||
- scales (torch.Tensor): The quantization scales.
|
|
||||||
- gating_output (torch.Tensor): The output of the gating operation
|
|
||||||
(before softmax).
|
|
||||||
- g_idx (Optional[torch.Tensor]): Optional act_order indices.
|
|
||||||
- sort_indices (Optional[torch.Tensor]): Optional act_order input
|
|
||||||
permutation.
|
|
||||||
- topk (int): The number of top-k experts to select.
|
|
||||||
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
|
|
||||||
- w_zeros (Optional[torch.Tensor]): Optional zero points to be used for w.
|
|
||||||
- num_bits (bool): The number of bits in expert weights quantization.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
- torch.Tensor: The output tensor after applying the MoE layer.
|
|
||||||
"""
|
|
||||||
# Check constraints.
|
|
||||||
assert hidden_states.shape[0] == gating_output.shape[0], (
|
|
||||||
"Number of tokens mismatch")
|
|
||||||
assert hidden_states.shape[1] == w.shape[1] * 16, "Hidden size mismatch"
|
|
||||||
assert gating_output.shape[1] == w.shape[0], "Number of experts mismatch"
|
|
||||||
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
|
|
||||||
assert w.is_contiguous(), "Expert weights must be contiguous"
|
|
||||||
assert hidden_states.dtype in [torch.float16, torch.bfloat16]
|
|
||||||
assert num_bits in [4, 8]
|
|
||||||
|
|
||||||
M, K = hidden_states.shape
|
|
||||||
E = w.shape[0]
|
|
||||||
N = w.shape[2] // (num_bits // 2)
|
|
||||||
|
|
||||||
topk_weights, topk_ids, token_expert_indices = fused_topk(
|
|
||||||
hidden_states, gating_output, topk, renormalize)
|
|
||||||
|
|
||||||
# This might not be an optimal config for a single MMM
|
|
||||||
get_config_func = functools.partial(try_get_optimal_moe_config,
|
|
||||||
w.shape,
|
|
||||||
w.shape,
|
|
||||||
topk_ids.shape[1],
|
|
||||||
None,
|
|
||||||
is_marlin=True)
|
|
||||||
config = get_config_func(M)
|
|
||||||
|
|
||||||
block_size_m = config['BLOCK_SIZE_M']
|
|
||||||
|
|
||||||
if global_num_experts == -1:
|
|
||||||
global_num_experts = E
|
|
||||||
sorted_token_ids, expert_ids, num_tokens_post_padded = \
|
|
||||||
moe_align_block_size(topk_ids, block_size_m, E, expert_map)
|
|
||||||
|
|
||||||
if workspace is None:
|
|
||||||
max_workspace_size = (max(2 * N, K) // 64) * \
|
|
||||||
(sorted_token_ids.size(0) // block_size_m)
|
|
||||||
device = hidden_states.device
|
|
||||||
sms = torch.cuda.get_device_properties(device).multi_processor_count
|
|
||||||
max_workspace_size = min(max_workspace_size, sms)
|
|
||||||
workspace = torch.zeros(max_workspace_size,
|
|
||||||
dtype=torch.int,
|
|
||||||
device=device,
|
|
||||||
requires_grad=False)
|
|
||||||
|
|
||||||
scalar_type = get_scalar_type(num_bits, w_zeros is not None)
|
|
||||||
intermediate_cache = torch.empty(
|
|
||||||
(M * topk_ids.shape[1], N),
|
|
||||||
device=hidden_states.device,
|
|
||||||
dtype=hidden_states.dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
ops.moe_wna16_marlin_gemm(hidden_states,
|
|
||||||
intermediate_cache,
|
|
||||||
w,
|
|
||||||
scales,
|
|
||||||
w_zeros,
|
|
||||||
g_idx,
|
|
||||||
sort_indices,
|
|
||||||
workspace,
|
|
||||||
sorted_token_ids,
|
|
||||||
expert_ids,
|
|
||||||
num_tokens_post_padded,
|
|
||||||
topk_weights,
|
|
||||||
moe_block_size=block_size_m,
|
|
||||||
top_k=topk,
|
|
||||||
mul_topk_weights=False,
|
|
||||||
is_ep=expert_map is not None,
|
|
||||||
b_q_type=scalar_type,
|
|
||||||
size_m=M,
|
|
||||||
size_n=N,
|
|
||||||
size_k=K,
|
|
||||||
is_k_full=is_k_full,
|
|
||||||
use_atomic_add=False,
|
|
||||||
use_fp32_reduce=True,
|
|
||||||
is_zp_float=False)
|
|
||||||
intermediate_cache = intermediate_cache.view(-1, topk, N)
|
|
||||||
|
|
||||||
return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)
|
|
||||||
|
|
||||||
|
|
||||||
def single_marlin_moe_fake(
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
w: torch.Tensor,
|
|
||||||
scales: torch.Tensor,
|
|
||||||
gating_output: torch.Tensor,
|
|
||||||
topk: int,
|
|
||||||
renormalize: bool,
|
|
||||||
global_num_experts: int = -1,
|
|
||||||
expert_map: Optional[torch.Tensor] = None,
|
|
||||||
g_idx: Optional[torch.Tensor] = None,
|
|
||||||
sort_indices: Optional[torch.Tensor] = None,
|
|
||||||
w_zeros: Optional[torch.Tensor] = None,
|
|
||||||
workspace: Optional[torch.Tensor] = None,
|
|
||||||
num_bits: int = 8,
|
|
||||||
is_k_full: bool = True,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
return torch.empty_like(hidden_states)
|
|
||||||
|
|
||||||
|
|
||||||
direct_register_custom_op(
|
|
||||||
op_name="single_marlin_moe",
|
|
||||||
op_func=single_marlin_moe,
|
|
||||||
mutates_args=[],
|
|
||||||
fake_impl=single_marlin_moe_fake,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def fused_marlin_moe(hidden_states: torch.Tensor,
|
def fused_marlin_moe(hidden_states: torch.Tensor,
|
||||||
w1: torch.Tensor,
|
w1: torch.Tensor,
|
||||||
w2: torch.Tensor,
|
w2: torch.Tensor,
|
||||||
@ -172,6 +22,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
|
|||||||
gating_output: torch.Tensor,
|
gating_output: torch.Tensor,
|
||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
|
quant_type_id: int,
|
||||||
global_num_experts: int = -1,
|
global_num_experts: int = -1,
|
||||||
expert_map: Optional[torch.Tensor] = None,
|
expert_map: Optional[torch.Tensor] = None,
|
||||||
g_idx1: Optional[torch.Tensor] = None,
|
g_idx1: Optional[torch.Tensor] = None,
|
||||||
@ -181,7 +32,6 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
|
|||||||
w1_zeros: Optional[torch.Tensor] = None,
|
w1_zeros: Optional[torch.Tensor] = None,
|
||||||
w2_zeros: Optional[torch.Tensor] = None,
|
w2_zeros: Optional[torch.Tensor] = None,
|
||||||
workspace: Optional[torch.Tensor] = None,
|
workspace: Optional[torch.Tensor] = None,
|
||||||
num_bits: int = 8,
|
|
||||||
is_k_full: bool = True,
|
is_k_full: bool = True,
|
||||||
inplace: bool = False) -> torch.Tensor:
|
inplace: bool = False) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
@ -211,6 +61,15 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
|
|||||||
Returns:
|
Returns:
|
||||||
- torch.Tensor: The output tensor after applying the MoE layer.
|
- torch.Tensor: The output tensor after applying the MoE layer.
|
||||||
"""
|
"""
|
||||||
|
quant_type = ScalarType.from_id(quant_type_id)
|
||||||
|
assert quant_type in [
|
||||||
|
scalar_types.uint4, scalar_types.uint8b128, scalar_types.uint4b8,
|
||||||
|
scalar_types.float8_e4m3fn
|
||||||
|
]
|
||||||
|
|
||||||
|
int4_scalar_types = [scalar_types.uint4, scalar_types.uint4b8]
|
||||||
|
num_bits = 4 if quant_type in int4_scalar_types else 8
|
||||||
|
|
||||||
# Check constraints.
|
# Check constraints.
|
||||||
assert hidden_states.shape[0] == gating_output.shape[
|
assert hidden_states.shape[0] == gating_output.shape[
|
||||||
0], "Number of tokens mismatch"
|
0], "Number of tokens mismatch"
|
||||||
@ -248,18 +107,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
|
|||||||
expert_map)
|
expert_map)
|
||||||
|
|
||||||
if workspace is None:
|
if workspace is None:
|
||||||
max_workspace_size = (max(2 * N, K) // 64) * \
|
workspace = marlin_make_workspace_new(hidden_states.device, 4)
|
||||||
(sorted_token_ids.size(0) // block_size_m)
|
|
||||||
device = hidden_states.device
|
|
||||||
sms = torch.cuda.get_device_properties(device).multi_processor_count
|
|
||||||
max_workspace_size = min(max_workspace_size, sms * 4)
|
|
||||||
workspace = torch.zeros(max_workspace_size,
|
|
||||||
dtype=torch.int,
|
|
||||||
device=device,
|
|
||||||
requires_grad=False)
|
|
||||||
|
|
||||||
scalar_type1 = get_scalar_type(num_bits, w1_zeros is not None)
|
|
||||||
scalar_type2 = get_scalar_type(num_bits, w2_zeros is not None)
|
|
||||||
|
|
||||||
intermediate_cache2 = torch.empty(
|
intermediate_cache2 = torch.empty(
|
||||||
(M * topk_ids.shape[1], N),
|
(M * topk_ids.shape[1], N),
|
||||||
@ -276,6 +124,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
|
|||||||
intermediate_cache3 = intermediate_cache13[:M * topk_ids.shape[1] * K]
|
intermediate_cache3 = intermediate_cache13[:M * topk_ids.shape[1] * K]
|
||||||
intermediate_cache3 = intermediate_cache3.view(-1, K)
|
intermediate_cache3 = intermediate_cache3.view(-1, K)
|
||||||
|
|
||||||
|
maybe_warn_marlin_atomic_add(hidden_states.device, hidden_states.dtype)
|
||||||
use_atomic_add = hidden_states.dtype == torch.half or \
|
use_atomic_add = hidden_states.dtype == torch.half or \
|
||||||
torch.cuda.get_device_capability(hidden_states.device)[0] >= 9
|
torch.cuda.get_device_capability(hidden_states.device)[0] >= 9
|
||||||
|
|
||||||
@ -296,7 +145,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
|
|||||||
top_k=topk,
|
top_k=topk,
|
||||||
mul_topk_weights=False,
|
mul_topk_weights=False,
|
||||||
is_ep=expert_map is not None,
|
is_ep=expert_map is not None,
|
||||||
b_q_type=scalar_type1,
|
b_q_type=quant_type,
|
||||||
size_m=M,
|
size_m=M,
|
||||||
size_n=2 * N,
|
size_n=2 * N,
|
||||||
size_k=K,
|
size_k=K,
|
||||||
@ -328,7 +177,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
|
|||||||
top_k=1,
|
top_k=1,
|
||||||
mul_topk_weights=True,
|
mul_topk_weights=True,
|
||||||
is_ep=expert_map is not None,
|
is_ep=expert_map is not None,
|
||||||
b_q_type=scalar_type2,
|
b_q_type=quant_type,
|
||||||
size_m=M * topk,
|
size_m=M * topk,
|
||||||
size_n=K,
|
size_n=K,
|
||||||
size_k=N,
|
size_k=N,
|
||||||
@ -351,6 +200,7 @@ def fused_marlin_moe_fake(hidden_states: torch.Tensor,
|
|||||||
gating_output: torch.Tensor,
|
gating_output: torch.Tensor,
|
||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
|
quant_type_id: int,
|
||||||
global_num_experts: int = -1,
|
global_num_experts: int = -1,
|
||||||
expert_map: Optional[torch.Tensor] = None,
|
expert_map: Optional[torch.Tensor] = None,
|
||||||
g_idx1: Optional[torch.Tensor] = None,
|
g_idx1: Optional[torch.Tensor] = None,
|
||||||
@ -360,7 +210,6 @@ def fused_marlin_moe_fake(hidden_states: torch.Tensor,
|
|||||||
w1_zeros: Optional[torch.Tensor] = None,
|
w1_zeros: Optional[torch.Tensor] = None,
|
||||||
w2_zeros: Optional[torch.Tensor] = None,
|
w2_zeros: Optional[torch.Tensor] = None,
|
||||||
workspace: Optional[torch.Tensor] = None,
|
workspace: Optional[torch.Tensor] = None,
|
||||||
num_bits: int = 8,
|
|
||||||
is_k_full: bool = True,
|
is_k_full: bool = True,
|
||||||
inplace: bool = False) -> torch.Tensor:
|
inplace: bool = False) -> torch.Tensor:
|
||||||
return torch.empty_like(hidden_states)
|
return torch.empty_like(hidden_states)
|
||||||
|
|||||||
@ -22,9 +22,10 @@ from vllm.model_executor.layers.quantization.utils import replace_parameter
|
|||||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||||
apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported,
|
apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported,
|
||||||
check_marlin_supports_layer, check_moe_marlin_supports_layer,
|
check_marlin_supports_layer, check_moe_marlin_supports_layer,
|
||||||
marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales,
|
marlin_make_empty_g_idx, marlin_make_workspace_new,
|
||||||
marlin_permute_scales, moe_awq_to_marlin_zero_points,
|
marlin_moe_permute_scales, marlin_permute_scales,
|
||||||
verify_marlin_supported, verify_marlin_supports_shape)
|
moe_awq_to_marlin_zero_points, verify_marlin_supported,
|
||||||
|
verify_marlin_supports_shape)
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||||
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
|
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
|
||||||
PackedvLLMParameter)
|
PackedvLLMParameter)
|
||||||
@ -267,8 +268,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
|
|||||||
requires_grad=False)
|
requires_grad=False)
|
||||||
|
|
||||||
# Allocate marlin workspace
|
# Allocate marlin workspace
|
||||||
layer.workspace = marlin_make_workspace(
|
layer.workspace = marlin_make_workspace_new(device)
|
||||||
layer.output_size_per_partition, device)
|
|
||||||
|
|
||||||
# Repack weights from AWQ format to marlin format.
|
# Repack weights from AWQ format to marlin format.
|
||||||
marlin_qweight = ops.awq_marlin_repack(
|
marlin_qweight = ops.awq_marlin_repack(
|
||||||
@ -322,6 +322,9 @@ class AWQMoEMethod(FusedMoEMethodBase):
|
|||||||
|
|
||||||
def __init__(self, quant_config: AWQMarlinConfig):
|
def __init__(self, quant_config: AWQMarlinConfig):
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
|
if self.quant_config.weight_bits != 4:
|
||||||
|
raise ValueError("AWQMoEMethod only supports 4bit now.")
|
||||||
|
self.quant_type = scalar_types.uint4
|
||||||
|
|
||||||
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
||||||
hidden_size: int, intermediate_size_per_partition: int,
|
hidden_size: int, intermediate_size_per_partition: int,
|
||||||
@ -396,11 +399,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
|
|||||||
set_weight_attrs(w2_qzeros, extra_weight_attrs)
|
set_weight_attrs(w2_qzeros, extra_weight_attrs)
|
||||||
|
|
||||||
device = layer.w13_qweight.device
|
device = layer.w13_qweight.device
|
||||||
sms = torch.cuda.get_device_properties(device).multi_processor_count
|
layer.workspace = marlin_make_workspace_new(device, 4)
|
||||||
layer.workspace = torch.zeros((sms * 4, ),
|
|
||||||
dtype=torch.int,
|
|
||||||
device=device,
|
|
||||||
requires_grad=False)
|
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
num_experts = layer.w13_qweight.shape[0]
|
num_experts = layer.w13_qweight.shape[0]
|
||||||
@ -511,10 +510,9 @@ class AWQMoEMethod(FusedMoEMethodBase):
|
|||||||
router_logits,
|
router_logits,
|
||||||
topk_weights,
|
topk_weights,
|
||||||
topk_ids,
|
topk_ids,
|
||||||
|
quant_type_id=self.quant_type.id,
|
||||||
global_num_experts=global_num_experts,
|
global_num_experts=global_num_experts,
|
||||||
expert_map=expert_map,
|
expert_map=expert_map,
|
||||||
w1_zeros=layer.w13_qzeros,
|
w1_zeros=layer.w13_qzeros,
|
||||||
w2_zeros=layer.w2_qzeros,
|
w2_zeros=layer.w2_qzeros,
|
||||||
workspace=layer.workspace,
|
workspace=layer.workspace)
|
||||||
num_bits=self.quant_config.weight_bits,
|
|
||||||
)
|
|
||||||
|
|||||||
@ -55,7 +55,7 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
|
|||||||
# required by torch.compile to be torch.nn.Parameter
|
# required by torch.compile to be torch.nn.Parameter
|
||||||
layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
|
layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
|
||||||
requires_grad=False)
|
requires_grad=False)
|
||||||
prepare_fp8_layer_for_marlin(layer, strategy="channel")
|
prepare_fp8_layer_for_marlin(layer)
|
||||||
|
|
||||||
def create_weights(self, layer: torch.nn.Module, input_size: int,
|
def create_weights(self, layer: torch.nn.Module, input_size: int,
|
||||||
output_partition_sizes: List[int],
|
output_partition_sizes: List[int],
|
||||||
@ -68,6 +68,7 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
|
|||||||
layer.input_size_per_partition = input_size_per_partition
|
layer.input_size_per_partition = input_size_per_partition
|
||||||
layer.output_size_per_partition = output_size_per_partition
|
layer.output_size_per_partition = output_size_per_partition
|
||||||
layer.orig_dtype = params_dtype
|
layer.orig_dtype = params_dtype
|
||||||
|
layer.weight_block_size = None
|
||||||
|
|
||||||
# WEIGHT
|
# WEIGHT
|
||||||
weight = ModelWeightParameter(data=torch.empty(
|
weight = ModelWeightParameter(data=torch.empty(
|
||||||
|
|||||||
@ -21,19 +21,21 @@ from vllm.model_executor.layers.quantization.base_config import (
|
|||||||
QuantizationConfig, QuantizeMethodBase)
|
QuantizationConfig, QuantizeMethodBase)
|
||||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||||
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
|
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin,
|
||||||
|
prepare_moe_fp8_layer_for_marlin)
|
||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
is_layer_skipped)
|
is_layer_skipped)
|
||||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||||
Fp8LinearOp, all_close_1d, convert_to_channelwise,
|
Fp8LinearOp, all_close_1d, cutlass_block_fp8_supported,
|
||||||
cutlass_block_fp8_supported, cutlass_fp8_supported,
|
cutlass_fp8_supported, maybe_create_device_identity,
|
||||||
maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz,
|
normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
|
||||||
per_tensor_dequantize, requantize_with_max_scale)
|
requantize_with_max_scale)
|
||||||
from vllm.model_executor.parameter import (BlockQuantScaleParameter,
|
from vllm.model_executor.parameter import (BlockQuantScaleParameter,
|
||||||
ModelWeightParameter,
|
ModelWeightParameter,
|
||||||
PerTensorScaleParameter)
|
PerTensorScaleParameter)
|
||||||
from vllm.model_executor.utils import set_weight_attrs
|
from vllm.model_executor.utils import set_weight_attrs
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.scalar_type import scalar_types
|
||||||
|
|
||||||
ACTIVATION_SCHEMES = ["static", "dynamic"]
|
ACTIVATION_SCHEMES = ["static", "dynamic"]
|
||||||
|
|
||||||
@ -181,10 +183,6 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
self.use_marlin = False
|
self.use_marlin = False
|
||||||
|
|
||||||
self.block_quant = self.quant_config.weight_block_size is not None
|
self.block_quant = self.quant_config.weight_block_size is not None
|
||||||
if self.block_quant:
|
|
||||||
# Marlin doesn't support block-wise fp8
|
|
||||||
self.use_marlin = False
|
|
||||||
|
|
||||||
self.fp8_linear = Fp8LinearOp(
|
self.fp8_linear = Fp8LinearOp(
|
||||||
# Default to using per_token quantization if cutlass is supported
|
# Default to using per_token quantization if cutlass is supported
|
||||||
use_per_token_if_dynamic=cutlass_fp8_supported())
|
use_per_token_if_dynamic=cutlass_fp8_supported())
|
||||||
@ -203,10 +201,16 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
|
|
||||||
output_size_per_partition = sum(output_partition_sizes)
|
output_size_per_partition = sum(output_partition_sizes)
|
||||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||||
|
layer.logical_widths = output_partition_sizes
|
||||||
|
layer.input_size_per_partition = input_size_per_partition
|
||||||
|
layer.output_size_per_partition = output_size_per_partition
|
||||||
|
layer.orig_dtype = params_dtype
|
||||||
|
layer.weight_block_size = None
|
||||||
|
|
||||||
if self.block_quant:
|
if self.block_quant:
|
||||||
tp_size = get_tensor_model_parallel_world_size()
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
assert self.quant_config.weight_block_size is not None
|
assert self.quant_config.weight_block_size is not None
|
||||||
|
layer.weight_block_size = self.quant_config.weight_block_size
|
||||||
block_n, block_k = (
|
block_n, block_k = (
|
||||||
self.quant_config.weight_block_size[0],
|
self.quant_config.weight_block_size[0],
|
||||||
self.quant_config.weight_block_size[1],
|
self.quant_config.weight_block_size[1],
|
||||||
@ -229,12 +233,6 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
f"{output_partition_size} is not divisible by "
|
f"{output_partition_size} is not divisible by "
|
||||||
f"weight quantization block_n = {block_n}.")
|
f"weight quantization block_n = {block_n}.")
|
||||||
|
|
||||||
layer.logical_widths = output_partition_sizes
|
|
||||||
|
|
||||||
layer.input_size_per_partition = input_size_per_partition
|
|
||||||
layer.output_size_per_partition = output_size_per_partition
|
|
||||||
layer.orig_dtype = params_dtype
|
|
||||||
|
|
||||||
# WEIGHT
|
# WEIGHT
|
||||||
weight_dtype = (torch.float8_e4m3fn
|
weight_dtype = (torch.float8_e4m3fn
|
||||||
if self.quant_config.is_checkpoint_fp8_serialized else
|
if self.quant_config.is_checkpoint_fp8_serialized else
|
||||||
@ -303,9 +301,11 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
return weight
|
return weight
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer: Module) -> None:
|
def process_weights_after_loading(self, layer: Module) -> None:
|
||||||
|
size_k_first = True
|
||||||
# TODO(rob): refactor block quant into separate class.
|
# TODO(rob): refactor block quant into separate class.
|
||||||
if self.block_quant:
|
if self.block_quant:
|
||||||
assert self.quant_config.activation_scheme == "dynamic"
|
assert self.quant_config.activation_scheme == "dynamic"
|
||||||
|
size_k_first = False
|
||||||
if current_platform.is_fp8_fnuz():
|
if current_platform.is_fp8_fnuz():
|
||||||
weight, weight_scale_inv, _ = \
|
weight, weight_scale_inv, _ = \
|
||||||
normalize_e4m3fn_to_e4m3fnuz(
|
normalize_e4m3fn_to_e4m3fnuz(
|
||||||
@ -321,21 +321,12 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
layer.weight = Parameter(weight, requires_grad=False)
|
layer.weight = Parameter(weight, requires_grad=False)
|
||||||
layer.weight_scale_inv = Parameter(weight_scale_inv,
|
layer.weight_scale_inv = Parameter(weight_scale_inv,
|
||||||
requires_grad=False)
|
requires_grad=False)
|
||||||
return
|
|
||||||
|
|
||||||
# If checkpoint not serialized fp8, quantize the weights.
|
# If checkpoint not serialized fp8, quantize the weights.
|
||||||
if not self.quant_config.is_checkpoint_fp8_serialized:
|
elif not self.quant_config.is_checkpoint_fp8_serialized:
|
||||||
qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
|
qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
|
||||||
scale=None)
|
scale=None)
|
||||||
|
|
||||||
# If using marlin (w8a16), kernel uses channelwise weights,
|
|
||||||
# so extend the weight scales to be channelwise.
|
|
||||||
if self.use_marlin:
|
|
||||||
assert weight_scale.numel() == 1
|
|
||||||
weight_scale = convert_to_channelwise(
|
|
||||||
weight_scale.expand(len(layer.logical_widths)),
|
|
||||||
layer.logical_widths)
|
|
||||||
|
|
||||||
# Update the layer with the new values.
|
# Update the layer with the new values.
|
||||||
layer.weight = Parameter(qweight.t(), requires_grad=False)
|
layer.weight = Parameter(qweight.t(), requires_grad=False)
|
||||||
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
|
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
|
||||||
@ -349,20 +340,14 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
if self.quant_config.activation_scheme == "static":
|
if self.quant_config.activation_scheme == "static":
|
||||||
layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
|
layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
|
||||||
requires_grad=False)
|
requires_grad=False)
|
||||||
# If using marlin (w8a16), kernel uses channelwise weights,
|
|
||||||
# so extend the weight scales to be channelwise.
|
weight = layer.weight
|
||||||
if self.use_marlin:
|
weight_scale = layer.weight_scale
|
||||||
weight = layer.weight
|
|
||||||
weight_scale = convert_to_channelwise(layer.weight_scale,
|
|
||||||
layer.logical_widths)
|
|
||||||
|
|
||||||
# If using w8a8, torch._scaled_mm needs per tensor, so
|
# If using w8a8, torch._scaled_mm needs per tensor, so
|
||||||
# requantize the logical shards as a single weight.
|
# requantize the logical shards as a single weight.
|
||||||
else:
|
if not self.use_marlin:
|
||||||
# Dequant -> Quant with max scale so we can run per tensor.
|
# Dequant -> Quant with max scale so we can run per tensor.
|
||||||
weight = layer.weight
|
|
||||||
weight_scale = layer.weight_scale
|
|
||||||
|
|
||||||
if current_platform.is_fp8_fnuz():
|
if current_platform.is_fp8_fnuz():
|
||||||
weight, weight_scale, input_scale = \
|
weight, weight_scale, input_scale = \
|
||||||
normalize_e4m3fn_to_e4m3fnuz(
|
normalize_e4m3fn_to_e4m3fnuz(
|
||||||
@ -388,7 +373,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
requires_grad=False)
|
requires_grad=False)
|
||||||
|
|
||||||
if self.use_marlin:
|
if self.use_marlin:
|
||||||
prepare_fp8_layer_for_marlin(layer)
|
prepare_fp8_layer_for_marlin(layer, size_k_first)
|
||||||
# Activations not quantized for marlin.
|
# Activations not quantized for marlin.
|
||||||
del layer.input_scale
|
del layer.input_scale
|
||||||
|
|
||||||
@ -444,6 +429,14 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.block_quant = self.quant_config.weight_block_size is not None
|
self.block_quant = self.quant_config.weight_block_size is not None
|
||||||
|
|
||||||
|
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
|
||||||
|
# kernel for fast weight-only FP8 quantization
|
||||||
|
self.use_marlin = (not current_platform.has_device_capability(89)
|
||||||
|
or envs.VLLM_TEST_FORCE_FP8_MARLIN)
|
||||||
|
# Disable marlin for rocm
|
||||||
|
if current_platform.is_rocm():
|
||||||
|
self.use_marlin = False
|
||||||
|
|
||||||
# Check for DeepGemm support.
|
# Check for DeepGemm support.
|
||||||
self.allow_deep_gemm = False
|
self.allow_deep_gemm = False
|
||||||
if envs.VLLM_USE_DEEP_GEMM:
|
if envs.VLLM_USE_DEEP_GEMM:
|
||||||
@ -461,10 +454,17 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
intermediate_size_per_partition: int,
|
intermediate_size_per_partition: int,
|
||||||
params_dtype: torch.dtype, **extra_weight_attrs):
|
params_dtype: torch.dtype, **extra_weight_attrs):
|
||||||
|
|
||||||
|
layer.intermediate_size_per_partition = intermediate_size_per_partition
|
||||||
|
layer.hidden_size = hidden_size
|
||||||
|
layer.num_experts = num_experts
|
||||||
|
layer.orig_dtype = params_dtype
|
||||||
|
layer.weight_block_size = None
|
||||||
|
|
||||||
if self.quant_config.is_checkpoint_fp8_serialized:
|
if self.quant_config.is_checkpoint_fp8_serialized:
|
||||||
params_dtype = torch.float8_e4m3fn
|
params_dtype = torch.float8_e4m3fn
|
||||||
if self.block_quant:
|
if self.block_quant:
|
||||||
assert self.quant_config.weight_block_size is not None
|
assert self.quant_config.weight_block_size is not None
|
||||||
|
layer.weight_block_size = self.quant_config.weight_block_size
|
||||||
tp_size = get_tensor_model_parallel_world_size()
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
block_n, block_k = (
|
block_n, block_k = (
|
||||||
self.quant_config.weight_block_size[0],
|
self.quant_config.weight_block_size[0],
|
||||||
@ -630,10 +630,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
layer.w2_weight_scale_inv = \
|
layer.w2_weight_scale_inv = \
|
||||||
dg.get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv).contiguous()
|
dg.get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv).contiguous()
|
||||||
|
|
||||||
return
|
|
||||||
|
|
||||||
# If checkpoint is fp16, quantize in place.
|
# If checkpoint is fp16, quantize in place.
|
||||||
if not self.quant_config.is_checkpoint_fp8_serialized:
|
elif not self.quant_config.is_checkpoint_fp8_serialized:
|
||||||
fp8_dtype = current_platform.fp8_dtype()
|
fp8_dtype = current_platform.fp8_dtype()
|
||||||
w13_weight = torch.empty_like(layer.w13_weight.data,
|
w13_weight = torch.empty_like(layer.w13_weight.data,
|
||||||
dtype=fp8_dtype)
|
dtype=fp8_dtype)
|
||||||
@ -677,8 +675,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
requires_grad=False)
|
requires_grad=False)
|
||||||
layer.w2_weight = torch.nn.Parameter(shuffled_w2,
|
layer.w2_weight = torch.nn.Parameter(shuffled_w2,
|
||||||
requires_grad=False)
|
requires_grad=False)
|
||||||
return
|
|
||||||
|
|
||||||
# If checkpoint is fp8, we need to handle that the
|
# If checkpoint is fp8, we need to handle that the
|
||||||
# MoE kernels require single activation scale and single weight
|
# MoE kernels require single activation scale and single weight
|
||||||
# scale for w13 per expert.
|
# scale for w13 per expert.
|
||||||
@ -766,7 +762,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
|
|
||||||
layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
|
layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
|
||||||
requires_grad=False)
|
requires_grad=False)
|
||||||
return
|
|
||||||
|
if self.use_marlin:
|
||||||
|
prepare_moe_fp8_layer_for_marlin(layer, False)
|
||||||
|
# Activations not quantized for marlin.
|
||||||
|
del layer.w13_input_scale
|
||||||
|
del layer.w2_input_scale
|
||||||
|
|
||||||
def apply(
|
def apply(
|
||||||
self,
|
self,
|
||||||
@ -801,6 +802,20 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
e_score_correction_bias=e_score_correction_bias,
|
e_score_correction_bias=e_score_correction_bias,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.use_marlin:
|
||||||
|
return torch.ops.vllm.fused_marlin_moe(
|
||||||
|
x,
|
||||||
|
layer.w13_weight,
|
||||||
|
layer.w2_weight,
|
||||||
|
layer.w13_weight_scale,
|
||||||
|
layer.w2_weight_scale,
|
||||||
|
router_logits,
|
||||||
|
topk_weights,
|
||||||
|
topk_ids,
|
||||||
|
quant_type_id=scalar_types.float8_e4m3fn.id,
|
||||||
|
global_num_experts=global_num_experts,
|
||||||
|
expert_map=expert_map)
|
||||||
|
|
||||||
return fused_experts(
|
return fused_experts(
|
||||||
x,
|
x,
|
||||||
layer.w13_weight,
|
layer.w13_weight,
|
||||||
|
|||||||
@ -21,8 +21,8 @@ from vllm.model_executor.layers.quantization.utils.gptq_utils import (
|
|||||||
get_linear_quant_method)
|
get_linear_quant_method)
|
||||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||||
check_marlin_supported, check_moe_marlin_supports_layer,
|
check_marlin_supported, check_moe_marlin_supports_layer,
|
||||||
marlin_moe_permute_scales, marlin_repeat_scales_on_all_ranks,
|
marlin_make_workspace_new, marlin_moe_permute_scales,
|
||||||
verify_marlin_supported)
|
marlin_repeat_scales_on_all_ranks, verify_marlin_supported)
|
||||||
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
|
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
|
||||||
GroupQuantScaleParameter,
|
GroupQuantScaleParameter,
|
||||||
PackedColumnParameter,
|
PackedColumnParameter,
|
||||||
@ -350,6 +350,13 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
|||||||
|
|
||||||
def __init__(self, quant_config: GPTQMarlinConfig) -> None:
|
def __init__(self, quant_config: GPTQMarlinConfig) -> None:
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
|
if self.quant_config.quant_type.size_bits == 4:
|
||||||
|
self.quant_type = scalar_types.uint4b8
|
||||||
|
elif self.quant_config.quant_type.size_bits == 8:
|
||||||
|
self.quant_type = scalar_types.uint8b128
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"GPTQMarlinMoEMethod only supports int4 and int8 now.")
|
||||||
|
|
||||||
def create_weights(
|
def create_weights(
|
||||||
self,
|
self,
|
||||||
@ -498,11 +505,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
|||||||
set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)
|
set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)
|
||||||
|
|
||||||
device = layer.w13_qweight.device
|
device = layer.w13_qweight.device
|
||||||
sms = torch.cuda.get_device_properties(device).multi_processor_count
|
layer.workspace = marlin_make_workspace_new(device, 4)
|
||||||
layer.workspace = torch.zeros((sms * 4, ),
|
|
||||||
dtype=torch.int,
|
|
||||||
device=device,
|
|
||||||
requires_grad=False)
|
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
|
|
||||||
@ -633,12 +636,12 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
|||||||
router_logits,
|
router_logits,
|
||||||
topk_weights,
|
topk_weights,
|
||||||
topk_ids,
|
topk_ids,
|
||||||
|
quant_type_id=self.quant_type.id,
|
||||||
global_num_experts=global_num_experts,
|
global_num_experts=global_num_experts,
|
||||||
expert_map=expert_map,
|
expert_map=expert_map,
|
||||||
g_idx1=layer.w13_g_idx,
|
g_idx1=layer.w13_g_idx,
|
||||||
g_idx2=layer.w2_g_idx,
|
g_idx2=layer.w2_g_idx,
|
||||||
sort_indices1=layer.w13_g_idx_sort_indices,
|
sort_indices1=layer.w13_g_idx_sort_indices,
|
||||||
sort_indices2=layer.w2_g_idx_sort_indices,
|
sort_indices2=layer.w2_g_idx_sort_indices,
|
||||||
num_bits=self.quant_config.quant_type.size_bits,
|
|
||||||
workspace=layer.workspace,
|
workspace=layer.workspace,
|
||||||
is_k_full=self.is_k_full)
|
is_k_full=self.is_k_full)
|
||||||
|
|||||||
@ -8,7 +8,7 @@ from vllm import _custom_ops as ops
|
|||||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||||
MARLIN_SUPPORTED_GROUP_SIZES, apply_gptq_marlin_linear,
|
MARLIN_SUPPORTED_GROUP_SIZES, apply_gptq_marlin_linear,
|
||||||
check_marlin_supports_shape, marlin_is_k_full, marlin_make_empty_g_idx,
|
check_marlin_supports_shape, marlin_is_k_full, marlin_make_empty_g_idx,
|
||||||
marlin_make_workspace, marlin_permute_scales, marlin_sort_g_idx,
|
marlin_make_workspace_new, marlin_permute_scales, marlin_sort_g_idx,
|
||||||
marlin_zero_points, query_marlin_supported_quant_types, unpack_cols)
|
marlin_zero_points, query_marlin_supported_quant_types, unpack_cols)
|
||||||
from vllm.model_executor.parameter import (BasevLLMParameter,
|
from vllm.model_executor.parameter import (BasevLLMParameter,
|
||||||
permute_param_layout_)
|
permute_param_layout_)
|
||||||
@ -53,8 +53,7 @@ class MarlinLinearKernel(MPLinearKernel):
|
|||||||
self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel)
|
self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel)
|
||||||
|
|
||||||
# Allocate marlin workspace.
|
# Allocate marlin workspace.
|
||||||
self.workspace = marlin_make_workspace(c.partition_weight_shape[1],
|
self.workspace = marlin_make_workspace_new(device)
|
||||||
device)
|
|
||||||
|
|
||||||
# Default names since marlin requires empty parameters for these,
|
# Default names since marlin requires empty parameters for these,
|
||||||
# TODO: remove this requirement from marlin (allow optional tensors)
|
# TODO: remove this requirement from marlin (allow optional tensors)
|
||||||
@ -127,6 +126,5 @@ class MarlinLinearKernel(MPLinearKernel):
|
|||||||
wtype=c.weight_type,
|
wtype=c.weight_type,
|
||||||
input_size_per_partition=c.partition_weight_shape[0],
|
input_size_per_partition=c.partition_weight_shape[0],
|
||||||
output_size_per_partition=c.partition_weight_shape[1],
|
output_size_per_partition=c.partition_weight_shape[1],
|
||||||
has_zp=self.config.zero_points,
|
|
||||||
is_k_full=self.is_k_full,
|
is_k_full=self.is_k_full,
|
||||||
bias=bias)
|
bias=bias)
|
||||||
|
|||||||
@ -7,12 +7,15 @@ import torch
|
|||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.linear import LinearBase
|
from vllm.model_executor.layers.linear import LinearBase
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.scalar_type import ScalarType, scalar_types
|
from vllm.scalar_type import ScalarType, scalar_types
|
||||||
|
|
||||||
from .quant_utils import pack_cols, unpack_cols
|
from .quant_utils import pack_cols, unpack_cols
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
GPTQ_MARLIN_TILE = 16
|
GPTQ_MARLIN_TILE = 16
|
||||||
GPTQ_MARLIN_MIN_THREAD_N = 64
|
GPTQ_MARLIN_MIN_THREAD_N = 64
|
||||||
GPTQ_MARLIN_MIN_THREAD_K = 128
|
GPTQ_MARLIN_MIN_THREAD_K = 128
|
||||||
@ -29,9 +32,11 @@ USE_FP32_REDUCE_DEFAULT = True
|
|||||||
# For binary size and compile time, we don't support the same types for with and
|
# For binary size and compile time, we don't support the same types for with and
|
||||||
# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ.
|
# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ.
|
||||||
# TODO: we may want to move this into the C++ so its closer to the actual impl
|
# TODO: we may want to move this into the C++ so its closer to the actual impl
|
||||||
def query_marlin_supported_quant_types(has_zp: bool,
|
def query_marlin_supported_quant_types(
|
||||||
device_capability: Optional[int] = None
|
has_zp: bool,
|
||||||
):
|
include_fp_type: bool = True,
|
||||||
|
device_capability: Optional[int] = None,
|
||||||
|
):
|
||||||
if device_capability is None:
|
if device_capability is None:
|
||||||
capability_tuple = current_platform.get_device_capability()
|
capability_tuple = current_platform.get_device_capability()
|
||||||
device_capability = (-1 if capability_tuple is None else
|
device_capability = (-1 if capability_tuple is None else
|
||||||
@ -42,12 +47,13 @@ def query_marlin_supported_quant_types(has_zp: bool,
|
|||||||
|
|
||||||
if has_zp:
|
if has_zp:
|
||||||
# AWQ style, unsigned + runtime zero-point
|
# AWQ style, unsigned + runtime zero-point
|
||||||
return [scalar_types.uint4, scalar_types.uint8]
|
return [scalar_types.uint4]
|
||||||
else:
|
else:
|
||||||
# GPTQ style, unsigned + symmetric bias
|
# GPTQ style, unsigned + symmetric bias
|
||||||
# TODO: once fp8_marlin is merged into "gptq_marlin" we should be able
|
res = [scalar_types.uint4b8, scalar_types.uint8b128]
|
||||||
# to add `scalar_types.float8_e4m3fn` here
|
if include_fp_type:
|
||||||
return [scalar_types.uint4b8, scalar_types.uint8b128]
|
res += [scalar_types.float8_e4m3fn]
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
def _check_marlin_supported(
|
def _check_marlin_supported(
|
||||||
@ -62,7 +68,7 @@ def _check_marlin_supported(
|
|||||||
capability_tuple.to_int())
|
capability_tuple.to_int())
|
||||||
|
|
||||||
supported_types = query_marlin_supported_quant_types(
|
supported_types = query_marlin_supported_quant_types(
|
||||||
has_zp, device_capability)
|
has_zp, True, device_capability)
|
||||||
|
|
||||||
if quant_type not in supported_types:
|
if quant_type not in supported_types:
|
||||||
return (False, f"Marlin does not support weight_bits = {quant_type}. "
|
return (False, f"Marlin does not support weight_bits = {quant_type}. "
|
||||||
@ -175,6 +181,17 @@ def marlin_make_workspace(output_size_per_partition: int,
|
|||||||
requires_grad=False)
|
requires_grad=False)
|
||||||
|
|
||||||
|
|
||||||
|
def marlin_make_workspace_new(device: torch.device,
|
||||||
|
max_blocks_per_sm: int = 1) -> torch.Tensor:
|
||||||
|
# In the new marlin kernel, we use the num of threadblocks as workspace
|
||||||
|
# size. The num of threadblocks is is sms_count * max_blocks_per_sm.
|
||||||
|
sms = torch.cuda.get_device_properties(device).multi_processor_count
|
||||||
|
return torch.zeros(sms * max_blocks_per_sm,
|
||||||
|
dtype=torch.int,
|
||||||
|
device=device,
|
||||||
|
requires_grad=False)
|
||||||
|
|
||||||
|
|
||||||
def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool:
|
def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool:
|
||||||
return (not act_order) or (act_order and not is_row_parallel)
|
return (not act_order) or (act_order and not is_row_parallel)
|
||||||
|
|
||||||
@ -304,21 +321,50 @@ def moe_awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int,
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def maybe_warn_marlin_atomic_add(device, dtype):
|
||||||
|
if torch.compiler.is_dynamo_compiling():
|
||||||
|
return
|
||||||
|
device_capability = torch.cuda.get_device_capability(device)
|
||||||
|
if device_capability[0] < 9 and dtype == torch.bfloat16:
|
||||||
|
logger.info_once(
|
||||||
|
"You are running Marlin kernel with bf16 on GPUs before SM90. "
|
||||||
|
"You can consider change to fp16 to achieve better performance "
|
||||||
|
"if possible.")
|
||||||
|
|
||||||
|
|
||||||
|
def maybe_warn_marlin_atomic_add_env():
|
||||||
|
if torch.compiler.is_dynamo_compiling():
|
||||||
|
return
|
||||||
|
if envs.VLLM_MARLIN_USE_ATOMIC_ADD:
|
||||||
|
return
|
||||||
|
logger.info_once(
|
||||||
|
"Marlin kernel can achieve better performance for small size_n "
|
||||||
|
"with experimental use_atomic_add feature. "
|
||||||
|
"You can consider set environment variable "
|
||||||
|
"VLLM_MARLIN_USE_ATOMIC_ADD to 1 if possible.")
|
||||||
|
|
||||||
|
|
||||||
def should_use_atomic_add_reduce(m: int, n: int, k: int, device: torch.device,
|
def should_use_atomic_add_reduce(m: int, n: int, k: int, device: torch.device,
|
||||||
dtype: torch.dtype) -> bool:
|
dtype: torch.dtype) -> bool:
|
||||||
|
|
||||||
|
# the performance of atomicAdd is better than global reduce
|
||||||
|
# only when m*n is small and k is large
|
||||||
|
if n >= 2048 or k < 2048 or device.type != "cuda":
|
||||||
|
return False
|
||||||
|
|
||||||
# disable atomicAdd reduce by default,
|
# disable atomicAdd reduce by default,
|
||||||
# one can enable it with VLLM_MARLIN_USE_ATOMIC_ADD=1
|
# one can enable it with VLLM_MARLIN_USE_ATOMIC_ADD=1
|
||||||
if not envs.VLLM_MARLIN_USE_ATOMIC_ADD or device.type != "cuda":
|
if not envs.VLLM_MARLIN_USE_ATOMIC_ADD:
|
||||||
|
maybe_warn_marlin_atomic_add_env()
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# sm8x doesn't support atomicAdd + bfloat16 natively
|
# sm8x doesn't support atomicAdd + bfloat16 natively
|
||||||
device_capability = torch.cuda.get_device_capability(device)
|
device_capability = torch.cuda.get_device_capability(device)
|
||||||
if device_capability[0] < 9 and dtype == torch.bfloat16:
|
if device_capability[0] < 9 and dtype == torch.bfloat16:
|
||||||
|
maybe_warn_marlin_atomic_add(device, dtype)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# the performance of atomicAdd is better than global reduce
|
return True
|
||||||
# only when m*n is small and k is large
|
|
||||||
return n < 2048 and k >= 2048
|
|
||||||
|
|
||||||
|
|
||||||
def apply_gptq_marlin_linear(
|
def apply_gptq_marlin_linear(
|
||||||
@ -332,7 +378,6 @@ def apply_gptq_marlin_linear(
|
|||||||
wtype: ScalarType,
|
wtype: ScalarType,
|
||||||
output_size_per_partition: int,
|
output_size_per_partition: int,
|
||||||
input_size_per_partition: int,
|
input_size_per_partition: int,
|
||||||
has_zp: bool,
|
|
||||||
is_k_full: bool,
|
is_k_full: bool,
|
||||||
bias: Optional[torch.Tensor] = None,
|
bias: Optional[torch.Tensor] = None,
|
||||||
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
|
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
|
||||||
@ -346,6 +391,7 @@ def apply_gptq_marlin_linear(
|
|||||||
dtype=input.dtype)
|
dtype=input.dtype)
|
||||||
|
|
||||||
output = ops.gptq_marlin_gemm(reshaped_x,
|
output = ops.gptq_marlin_gemm(reshaped_x,
|
||||||
|
None,
|
||||||
weight,
|
weight,
|
||||||
weight_scale,
|
weight_scale,
|
||||||
weight_zp,
|
weight_zp,
|
||||||
@ -358,7 +404,6 @@ def apply_gptq_marlin_linear(
|
|||||||
size_k=input_size_per_partition,
|
size_k=input_size_per_partition,
|
||||||
is_k_full=is_k_full,
|
is_k_full=is_k_full,
|
||||||
use_atomic_add=use_atomic_add,
|
use_atomic_add=use_atomic_add,
|
||||||
has_zp=has_zp,
|
|
||||||
use_fp32_reduce=use_fp32_reduce,
|
use_fp32_reduce=use_fp32_reduce,
|
||||||
is_zp_float=False)
|
is_zp_float=False)
|
||||||
|
|
||||||
@ -391,6 +436,7 @@ def apply_awq_marlin_linear(
|
|||||||
dtype=input.dtype)
|
dtype=input.dtype)
|
||||||
|
|
||||||
output = ops.gptq_marlin_gemm(reshaped_x,
|
output = ops.gptq_marlin_gemm(reshaped_x,
|
||||||
|
None,
|
||||||
weight,
|
weight,
|
||||||
weight_scale,
|
weight_scale,
|
||||||
weight_zp,
|
weight_zp,
|
||||||
@ -401,8 +447,6 @@ def apply_awq_marlin_linear(
|
|||||||
size_m=reshaped_x.shape[0],
|
size_m=reshaped_x.shape[0],
|
||||||
size_n=output_size_per_partition,
|
size_n=output_size_per_partition,
|
||||||
size_k=input_size_per_partition,
|
size_k=input_size_per_partition,
|
||||||
is_k_full=True,
|
|
||||||
has_zp=True,
|
|
||||||
use_atomic_add=use_atomic_add,
|
use_atomic_add=use_atomic_add,
|
||||||
use_fp32_reduce=use_fp32_reduce,
|
use_fp32_reduce=use_fp32_reduce,
|
||||||
is_zp_float=False)
|
is_zp_float=False)
|
||||||
|
|||||||
@ -6,9 +6,11 @@ import torch
|
|||||||
|
|
||||||
import vllm._custom_ops as ops
|
import vllm._custom_ops as ops
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||||
|
USE_FP32_REDUCE_DEFAULT, marlin_make_workspace_new, marlin_permute_scales,
|
||||||
|
should_use_atomic_add_reduce)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.scalar_type import scalar_types
|
||||||
from .marlin_utils import marlin_make_workspace, marlin_permute_scales
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -18,30 +20,40 @@ def is_fp8_marlin_supported():
|
|||||||
|
|
||||||
|
|
||||||
def apply_fp8_marlin_linear(
|
def apply_fp8_marlin_linear(
|
||||||
input: torch.Tensor,
|
input: torch.Tensor,
|
||||||
weight: torch.Tensor,
|
weight: torch.Tensor,
|
||||||
weight_scale: torch.Tensor,
|
weight_scale: torch.Tensor,
|
||||||
workspace: torch.Tensor,
|
workspace: torch.Tensor,
|
||||||
size_n: int,
|
size_n: int,
|
||||||
size_k: int,
|
size_k: int,
|
||||||
bias: Optional[torch.Tensor],
|
bias: Optional[torch.Tensor],
|
||||||
) -> torch.Tensor:
|
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
|
||||||
# For GPUs that lack FP8 hardware support, we can leverage the
|
# For GPUs that lack FP8 hardware support, we can leverage the
|
||||||
# Marlin kernel for fast weight-only FP8 quantization
|
# Marlin kernel for fast weight-only FP8 quantization
|
||||||
|
|
||||||
reshaped_x = input.reshape(-1, input.shape[-1])
|
reshaped_x = input.reshape(-1, input.shape[-1])
|
||||||
out_shape = input.shape[:-1] + (size_n, )
|
out_shape = input.shape[:-1] + (size_n, )
|
||||||
|
|
||||||
output = ops.fp8_marlin_gemm(
|
use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0),
|
||||||
a=reshaped_x,
|
n=size_n,
|
||||||
b_q_weight=weight,
|
k=size_k,
|
||||||
b_scales=weight_scale,
|
device=input.device,
|
||||||
workspace=workspace,
|
dtype=input.dtype)
|
||||||
num_bits=8,
|
|
||||||
size_m=reshaped_x.shape[0],
|
output = ops.gptq_marlin_gemm(a=reshaped_x,
|
||||||
size_n=size_n,
|
c=None,
|
||||||
size_k=size_k,
|
b_q_weight=weight,
|
||||||
)
|
b_scales=weight_scale,
|
||||||
|
b_zeros=None,
|
||||||
|
g_idx=None,
|
||||||
|
perm=None,
|
||||||
|
workspace=workspace,
|
||||||
|
b_q_type=scalar_types.float8_e4m3fn,
|
||||||
|
size_m=reshaped_x.size(0),
|
||||||
|
size_n=size_n,
|
||||||
|
size_k=size_k,
|
||||||
|
use_atomic_add=use_atomic_add,
|
||||||
|
use_fp32_reduce=use_fp32_reduce)
|
||||||
|
|
||||||
if bias is not None:
|
if bias is not None:
|
||||||
output.add_(bias) # In-place add
|
output.add_(bias) # In-place add
|
||||||
@ -50,7 +62,7 @@ def apply_fp8_marlin_linear(
|
|||||||
|
|
||||||
|
|
||||||
def prepare_fp8_layer_for_marlin(layer: torch.nn.Module,
|
def prepare_fp8_layer_for_marlin(layer: torch.nn.Module,
|
||||||
strategy: str = "tensor") -> None:
|
size_k_first: bool = True) -> None:
|
||||||
logger.warning_once(
|
logger.warning_once(
|
||||||
"Your GPU does not have native support for FP8 computation but "
|
"Your GPU does not have native support for FP8 computation but "
|
||||||
"FP8 quantization is being used. Weight-only FP8 compression will "
|
"FP8 quantization is being used. Weight-only FP8 compression will "
|
||||||
@ -60,51 +72,234 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module,
|
|||||||
part_size_n = layer.output_size_per_partition
|
part_size_n = layer.output_size_per_partition
|
||||||
part_size_k = layer.input_size_per_partition
|
part_size_k = layer.input_size_per_partition
|
||||||
|
|
||||||
|
if size_k_first:
|
||||||
|
assert layer.weight.shape == (part_size_k, part_size_n)
|
||||||
|
else:
|
||||||
|
assert layer.weight.shape == (part_size_n, part_size_k)
|
||||||
|
|
||||||
device = layer.weight.device
|
device = layer.weight.device
|
||||||
|
|
||||||
# WORKSPACE
|
# WORKSPACE
|
||||||
layer.workspace = marlin_make_workspace(part_size_n, device)
|
layer.workspace = marlin_make_workspace_new(device)
|
||||||
|
|
||||||
# WEIGHT
|
# WEIGHT
|
||||||
# Repack weights to marlin format
|
# Repack weights to marlin format
|
||||||
marlin_qweight = ops.gptq_marlin_repack(b_q_weight=pack_fp8_to_int32(
|
perm = torch.empty(0, dtype=torch.int, device=device)
|
||||||
layer.weight),
|
qweight = pack_fp8_to_int32(layer.weight, size_k_first)
|
||||||
perm=torch.empty(0,
|
if not size_k_first:
|
||||||
dtype=torch.int,
|
qweight = qweight.T.contiguous()
|
||||||
device=device),
|
|
||||||
|
marlin_qweight = ops.gptq_marlin_repack(b_q_weight=qweight,
|
||||||
|
perm=perm,
|
||||||
size_k=part_size_k,
|
size_k=part_size_k,
|
||||||
size_n=part_size_n,
|
size_n=part_size_n,
|
||||||
num_bits=8)
|
num_bits=8)
|
||||||
layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False)
|
layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False)
|
||||||
|
|
||||||
# WEIGHT SCALES
|
# WEIGHT SCALES
|
||||||
scales = layer.weight_scale.to(layer.orig_dtype)
|
|
||||||
# Permute scales
|
# Permute scales
|
||||||
|
if "weight_scale" in dir(layer):
|
||||||
|
scales = layer.weight_scale.to(layer.orig_dtype)
|
||||||
|
elif "weight_scale_inv" in dir(layer):
|
||||||
|
scales = layer.weight_scale_inv.to(layer.orig_dtype)
|
||||||
|
del layer.weight_scale_inv
|
||||||
|
|
||||||
|
if layer.weight_block_size is None:
|
||||||
|
group_size = -1
|
||||||
|
else:
|
||||||
|
group_size = layer.weight_block_size[1]
|
||||||
|
|
||||||
|
# marlin kernel only support channel-wise and group-wise quantization
|
||||||
|
# we need to convert the scales
|
||||||
|
if layer.weight_block_size is None:
|
||||||
|
if scales.nelement() == 1:
|
||||||
|
# tensor-wise quantization -> channel-wise quantization
|
||||||
|
# (1, 1) =>(repeat)=> (1, size_n)
|
||||||
|
scales = scales.view(1, 1).repeat_interleave(part_size_n, 1)
|
||||||
|
elif scales.nelement() > 1 and scales.nelement() != part_size_n:
|
||||||
|
assert part_size_n % scales.nelement() == 0
|
||||||
|
s_size = scales.nelement()
|
||||||
|
# tensor-wise quantization (for gate-up proj)
|
||||||
|
# -> channel-wise quantization
|
||||||
|
# (1, s_size) =>(repeat)=> (1, size_n)
|
||||||
|
scales = scales.view(1, s_size)
|
||||||
|
scales = scales.repeat_interleave(part_size_n // s_size, 1)
|
||||||
|
else:
|
||||||
|
# channel-wise quantization
|
||||||
|
# (1, size_n)
|
||||||
|
scales = scales.view(1, part_size_n)
|
||||||
|
else:
|
||||||
|
# block-wise quantization -> group-wise quantization
|
||||||
|
# (size_k // block_size[1], ceil(size_n / block_size[0]))
|
||||||
|
# =>(repeat)=> (size_k // block_size[1], size_n)
|
||||||
|
block_n = layer.weight_block_size[0]
|
||||||
|
scales = scales.T.repeat_interleave(block_n, 1)
|
||||||
|
# size_n may not divisible by block_size[0]
|
||||||
|
scales = scales[:, :part_size_n]
|
||||||
|
|
||||||
marlin_scales = marlin_permute_scales(s=scales,
|
marlin_scales = marlin_permute_scales(s=scales,
|
||||||
size_k=part_size_k,
|
size_k=part_size_k,
|
||||||
size_n=part_size_n,
|
size_n=part_size_n,
|
||||||
group_size=-1)
|
group_size=group_size)
|
||||||
layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False)
|
layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False)
|
||||||
|
|
||||||
|
|
||||||
def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor:
|
def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module,
|
||||||
|
size_k_first: bool = True) -> None:
|
||||||
|
logger.warning_once(
|
||||||
|
"Your GPU does not have native support for FP8 computation but "
|
||||||
|
"FP8 quantization is being used. Weight-only FP8 compression will "
|
||||||
|
"be used leveraging the Marlin kernel. This may degrade "
|
||||||
|
"performance for compute-heavy workloads.")
|
||||||
|
|
||||||
|
e = layer.num_experts
|
||||||
|
k = layer.hidden_size
|
||||||
|
n = layer.intermediate_size_per_partition
|
||||||
|
|
||||||
|
# WORKSPACE
|
||||||
|
device = layer.w13_weight.device
|
||||||
|
layer.workspace = marlin_make_workspace_new(device, 4)
|
||||||
|
perm = torch.empty(0, dtype=torch.int, device=device)
|
||||||
|
|
||||||
|
# WEIGHT
|
||||||
|
# Repack weights to marlin format
|
||||||
|
for name in ["w13_weight", "w2_weight"]:
|
||||||
|
weight = getattr(layer, name)
|
||||||
|
tensor_list = []
|
||||||
|
if "w13" in name:
|
||||||
|
size_n, size_k = n * 2, k
|
||||||
|
else:
|
||||||
|
size_n, size_k = k, n
|
||||||
|
|
||||||
|
if size_k_first:
|
||||||
|
assert weight.shape == (e, size_k, size_n)
|
||||||
|
else:
|
||||||
|
assert weight.shape == (e, size_n, size_k)
|
||||||
|
|
||||||
|
for i in range(e):
|
||||||
|
qweight = pack_fp8_to_int32(weight[i], size_k_first)
|
||||||
|
if not size_k_first:
|
||||||
|
qweight = qweight.T.contiguous()
|
||||||
|
|
||||||
|
marlin_qweight = ops.gptq_marlin_repack(b_q_weight=qweight,
|
||||||
|
perm=perm,
|
||||||
|
size_k=size_k,
|
||||||
|
size_n=size_n,
|
||||||
|
num_bits=8)
|
||||||
|
tensor_list.append(marlin_qweight)
|
||||||
|
|
||||||
|
weight = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
|
||||||
|
weight = torch.nn.Parameter(weight, requires_grad=False)
|
||||||
|
|
||||||
|
setattr(layer, name, weight)
|
||||||
|
|
||||||
|
# WEIGHT SCALES
|
||||||
|
# Permute scales
|
||||||
|
if layer.weight_block_size is None:
|
||||||
|
group_size = -1
|
||||||
|
else:
|
||||||
|
group_size = layer.weight_block_size[1]
|
||||||
|
|
||||||
|
for name in ["w13", "w2"]:
|
||||||
|
if name + "_weight_scale" in dir(layer):
|
||||||
|
new_name = name + "_weight_scale"
|
||||||
|
scales = getattr(layer, new_name).to(layer.orig_dtype)
|
||||||
|
delattr(layer, new_name)
|
||||||
|
elif name + "_weight_scale_inv" in dir(layer):
|
||||||
|
new_name = name + "_weight_scale_inv"
|
||||||
|
scales = getattr(layer, new_name).to(layer.orig_dtype)
|
||||||
|
delattr(layer, new_name)
|
||||||
|
|
||||||
|
tensor_list = []
|
||||||
|
if "w13" in name:
|
||||||
|
size_n, size_k = n * 2, k
|
||||||
|
else:
|
||||||
|
size_n, size_k = k, n
|
||||||
|
|
||||||
|
# marlin kernel only support channel-wise and group-wise quantization
|
||||||
|
# we need to convert the scales
|
||||||
|
if layer.weight_block_size is None:
|
||||||
|
if scales.nelement() == e:
|
||||||
|
# tensor-wise quantization -> channel-wise quantization
|
||||||
|
# (e, 1, 1) =>(repeat)=> (e, 1, size_n)
|
||||||
|
scales = scales.view(e, 1, 1).repeat_interleave(size_n, 2)
|
||||||
|
elif scales.nelement() > e and scales.nelement() != e * size_n:
|
||||||
|
assert (e * size_n) % scales.nelement() == 0
|
||||||
|
s_size = scales.nelement() // e
|
||||||
|
# tensor-wise quantization (for gate-up proj)
|
||||||
|
# -> channel-wise quantization
|
||||||
|
# (e, 1, s_size) =>(repeat)=> (e, 1, size_n)
|
||||||
|
scales = scales.view(e, 1, s_size)
|
||||||
|
scales = scales.repeat_interleave(size_n // s_size, 2)
|
||||||
|
else:
|
||||||
|
# channel-wise quantization
|
||||||
|
# (e, 1, size_n)
|
||||||
|
scales = scales.view(e, 1, size_n)
|
||||||
|
else:
|
||||||
|
# block-wise quantization -> group-wise quantization
|
||||||
|
# (e, size_k // block_size[1], ceil(size_n / block_size[0]))
|
||||||
|
# =>(repeat)=> (e, size_k // block_size[1], size_n)
|
||||||
|
block_n = layer.weight_block_size[0]
|
||||||
|
scales = scales.permute(0, 2, 1).repeat_interleave(block_n, 2)
|
||||||
|
# size_n may not divisible by block_size[0]
|
||||||
|
scales = scales[..., :size_n].contiguous()
|
||||||
|
|
||||||
|
for i in range(e):
|
||||||
|
marlin_scales = marlin_permute_scales(s=scales[i],
|
||||||
|
size_k=size_k,
|
||||||
|
size_n=size_n,
|
||||||
|
group_size=group_size)
|
||||||
|
tensor_list.append(marlin_scales)
|
||||||
|
|
||||||
|
scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
|
||||||
|
scales = torch.nn.Parameter(scales, requires_grad=False)
|
||||||
|
|
||||||
|
setattr(layer, name + "_weight_scale", scales)
|
||||||
|
|
||||||
|
|
||||||
|
def pack_fp8_to_int32(fp8_tensor: torch.Tensor,
|
||||||
|
size_k_first: bool = True) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Repack FP8 weights to gptq format (packed int32 elements)
|
Repack FP8 weights to gptq format (packed int32 elements)
|
||||||
"""
|
"""
|
||||||
assert fp8_tensor.dtype == torch.float8_e4m3fn
|
assert fp8_tensor.dtype == torch.float8_e4m3fn
|
||||||
assert fp8_tensor.shape[0] % 4 == 0
|
assert fp8_tensor.ndim == 2
|
||||||
|
|
||||||
# Reshape to prepare for packing
|
fp8_tensor = fp8_tensor.T if size_k_first else fp8_tensor
|
||||||
reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:])
|
fp8_tensor = fp8_tensor.contiguous()
|
||||||
|
# fp8_tensor is contiguous and have shape (N, K) now
|
||||||
|
# with `.view(torch.int32)`, it become (N, K // 4)
|
||||||
|
int32_tensor = fp8_tensor.view(torch.int32)
|
||||||
|
return int32_tensor.T.contiguous() if size_k_first else int32_tensor
|
||||||
|
|
||||||
# Convert fp8 to uint8 (byte) representation
|
|
||||||
byte_tensor = reshaped.view(torch.uint8)
|
|
||||||
|
|
||||||
# Pack 4 uint8 values into one int32
|
def marlin_quant_fp8_torch(weight, group_size):
|
||||||
packed = (byte_tensor[:, 0].to(torch.int32) |
|
size_n, size_k = weight.shape
|
||||||
(byte_tensor[:, 1].to(torch.int32) << 8) |
|
device = weight.device
|
||||||
(byte_tensor[:, 2].to(torch.int32) << 16) |
|
|
||||||
(byte_tensor[:, 3].to(torch.int32) << 24))
|
|
||||||
|
|
||||||
return packed.view(fp8_tensor.shape[0] // 4,
|
if group_size != -1:
|
||||||
*fp8_tensor.shape[1:]).contiguous()
|
scales = weight.view(size_n, -1, group_size).abs().max(-1)[0] / 448
|
||||||
|
repeated_scales = scales.repeat_interleave(group_size, 1)
|
||||||
|
fp8_weight = (weight / repeated_scales).to(torch.float8_e4m3fn)
|
||||||
|
weight_ref = fp8_weight.to(weight.dtype) * repeated_scales
|
||||||
|
else:
|
||||||
|
scales = weight.view(size_n, 1, group_size).abs().max(-1)[0] / 448
|
||||||
|
repeated_scales = scales.repeat_interleave(size_k, 1)
|
||||||
|
fp8_weight = (weight / repeated_scales).to(torch.float8_e4m3fn)
|
||||||
|
weight_ref = fp8_weight.to(weight.dtype) * repeated_scales
|
||||||
|
|
||||||
|
packed_weight = pack_fp8_to_int32(fp8_weight, False).T.contiguous()
|
||||||
|
marlin_qweight = ops.gptq_marlin_repack(
|
||||||
|
b_q_weight=packed_weight,
|
||||||
|
perm=torch.empty(0, dtype=torch.int, device=device),
|
||||||
|
size_k=size_k,
|
||||||
|
size_n=size_n,
|
||||||
|
num_bits=8,
|
||||||
|
)
|
||||||
|
|
||||||
|
marlin_scales = marlin_permute_scales(s=scales.T,
|
||||||
|
size_k=size_k,
|
||||||
|
size_n=size_n,
|
||||||
|
group_size=group_size)
|
||||||
|
|
||||||
|
return weight_ref.T, marlin_qweight, marlin_scales
|
||||||
|
|||||||
@ -6,6 +6,8 @@ from dataclasses import dataclass
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
_SCALAR_TYPES_ID_MAP = {}
|
||||||
|
|
||||||
|
|
||||||
# Mirrors enum in `core/scalar_type.hpp`
|
# Mirrors enum in `core/scalar_type.hpp`
|
||||||
class NanRepr(Enum):
|
class NanRepr(Enum):
|
||||||
@ -158,6 +160,8 @@ class ScalarType:
|
|||||||
assert offset <= 64, \
|
assert offset <= 64, \
|
||||||
f"ScalarType fields too big {offset} to fit into an int64"
|
f"ScalarType fields too big {offset} to fit into an int64"
|
||||||
|
|
||||||
|
_SCALAR_TYPES_ID_MAP[val] = self
|
||||||
|
|
||||||
return val
|
return val
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -295,6 +299,13 @@ class ScalarType:
|
|||||||
ret.id # noqa B018: make sure the id is cached
|
ret.id # noqa B018: make sure the id is cached
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_id(cls, scalar_type_id: int):
|
||||||
|
if scalar_type_id not in _SCALAR_TYPES_ID_MAP:
|
||||||
|
raise ValueError(
|
||||||
|
f"scalar_type_id {scalar_type_id} doesn't exists.")
|
||||||
|
return _SCALAR_TYPES_ID_MAP[scalar_type_id]
|
||||||
|
|
||||||
|
|
||||||
# naming generally follows: https://github.com/jax-ml/ml_dtypes
|
# naming generally follows: https://github.com/jax-ml/ml_dtypes
|
||||||
# for floating point types (leading f) the scheme is:
|
# for floating point types (leading f) the scheme is:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user