Jinzhen Lin 33c63e9547
[Kernel] [Quantization] Add MXFP4 and bias support for marlin kernel (#22428)
2025-08-14 11:23:22 -07:00

/*
Fast Dequantization (Converting INT4/INT8/FP4/FP8 to FP16/BF16)
The process of fast dequantization can be summarized as a combination
of bitwise operations and floating-point computations:
weight =>(bit_op / bitwise operations)=>
f16_value =>(flop / floating-point computation)=>
dequantized_weight
Since the dequantized weights typically require subtracting the zero point and
applying a scale factor, the floating-point computation step can be fused with
the zero-point subtraction and scaling operations.
The following describes the parts that must be modified to fuse the zero-point
subtraction and scaling into the dequantization.
## INT4 => FP16/BF16 or INT8 => FP16
The floating-point computation is `__hsub2`
If there are integer zero points:
flop(bit_op(weight)) - flop(bit_op(zp))
= sub(bit_op(weight), bias) - sub(bit_op(zp), bias)
= bit_op(weight) - bit_op(zp)
so no additional modification is needed.
If there are float zero points:
flop(bit_op(weight)) - fzp
= sub(bit_op(weight), bias) - fzp
= bit_op(weight) - (fzp + bias)
where `fzp + bias` can be computed at weight loading. But this may cause
accuracy issues, so we should not use it in most cases.
If there are no zero points:
scale(flop(bit_op(weight)))
= scale(sub(bit_op(weight), bias))
= scale(bit_op(weight)) - scale(bias)
= fma(bit_op(weight), scale_factor, -scale(bias))
where `-scale(bias)` can be cached. But this may cause accuracy issues,
so we should not use it in most cases.
## INT8 => BF16
INT8 => BF16 is a special case, it use byte_perm instead of flop.
We cannot fused byte_perm with scaling.
## FP4/FP8 => FP16/BF16
scale(flop(bit_op(weight)))
= scale(mul(bit_op(weight), multiplier))
= mul(bit_op(weight), scale_factor * multiplier)
where `scale_factor * multiplier` can be computed at weight loading.
*/
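// A worked instance of the scheme above (illustration only): for INT4 => FP16
// with a zero point of 8, `bit_op` ORs each nibble v into the mantissa of a
// half whose exponent bits are 0x6400 (= 1024.0), so bit_op(weight) = 1024 + v.
// The `flop` is then a single __hsub2 against 0x6408 (= 1032.0):
//   flop(bit_op(weight)) = (1024 + v) - 1032 = v - 8,
// i.e. the 1024 bias and the zero point 8 are folded into one constant.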
#include "marlin_dtypes.cuh"
namespace MARLIN_NAMESPACE_NAME {
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
// Lookup-table based 3-input logical operation; explicitly used for
// dequantization as the compiler does not seem to automatically recognize it in
// all cases.
template <int lut>
__device__ inline int lop3(int a, int b, int c) {
int res;
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(res)
: "r"(a), "r"(b), "r"(c), "n"(lut));
return res;
}
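// With a = 0xF0, b = 0xCC, c = 0xAA as the canonical truth-table constants,
// the immediate lut for `(a & b) | c` is (0xf0 & 0xcc) | 0xaa = 0xEA, which is
// exactly the template argument used by the dequant specializations below.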
// Constructs destination register by taking bytes from 2 sources (based on
// mask)
template <int start_byte, int mask>
__device__ inline uint32_t prmt(uint32_t a) {
uint32_t res;
asm volatile("prmt.b32 %0, %1, %2, %3;\n"
: "=r"(res)
: "r"(a), "n"(start_byte), "n"(mask));
return res;
}
template <typename scalar_t2, vllm::ScalarTypeId w_type_id,
bool skip_flop = false>
__device__ inline void dequant(int q, scalar_t2* frag_b);
//
// Efficiently dequantize 4bit values packed in an int32 value into a full
// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below,
// with some small changes:
// - FP16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287
// - BF16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385
//
template <>
__device__ inline void dequant<half2, vllm::kU4B8.id(), true>(int q,
half2* frag_b) {
const int MASK = 0x000f000f;
const int EX = 0x64006400;
// Guarantee that the `(a & b) | c` operations are LOP3s.
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
q >>= 4;
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
frag_b[0] = *reinterpret_cast<half2*>(&lo);
frag_b[1] = *reinterpret_cast<half2*>(&hi);
}
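// How this works: `(q & 0x000f000f) | 0x64006400` builds in each 16-bit lane a
// half with exponent bits 0x6400 (= 1024.0) and the nibble v in the low
// mantissa bits, i.e. the value 1024 + v. With skip_flop == true the 1024
// offset is left in place: it cancels when a zero point encoded the same way
// is subtracted, or is otherwise fused into later arithmetic as described in
// the header comment.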
template <>
__device__ inline void dequant<half2, vllm::kU4B8.id(), false>(int q,
half2* frag_b) {
const int LO = 0x000f000f;
const int HI = 0x00f000f0;
const int EX = 0x64006400;
// Guarantee that the `(a & b) | c` operations are LOP3s.
// clang-format off
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
// clang-format on
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
// directly into `SUB` and `ADD`.
const int SUB = 0x64086408;
const int MUL = 0x2c002c00;
const int ADD = 0xd480d480;
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
*reinterpret_cast<const half2*>(&SUB));
frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
*reinterpret_cast<const half2*>(&MUL),
*reinterpret_cast<const half2*>(&ADD));
}
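// Constants viewed as halves: SUB = 0x6408 = 1032.0, MUL = 0x2c00 = 1/16 and
// ADD = 0xd480 = -72.0. The low nibbles become (1024 + v) - 1032 = v - 8 and
// the high nibbles become (1024 + 16*v) / 16 - 72 = v - 8, so both fragments
// hold the value with the +8 zero point removed.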
template <>
__device__ inline void dequant<half2, vllm::kU4.id(), true>(int q,
half2* frag_b) {
dequant<half2, vllm::kU4B8.id(), true>(q, frag_b);
}
template <>
__device__ inline void dequant<half2, vllm::kU4.id(), false>(int q,
half2* frag_b) {
const int LO = 0x000f000f;
const int HI = 0x00f000f0;
const int EX = 0x64006400;
// Guarantee that the `(a & b) | c` operations are LOP3s.
// clang-format off
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
// clang-format on
// Unsigned int4 outputs: unlike the kU4B8 case above, no `-8` zero point is
// fused here; `SUB` and `ADD` only remove the exponent offsets.
const int SUB = 0x64006400;
const int MUL = 0x2c002c00;
const int ADD = 0xd400d400;
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
*reinterpret_cast<const half2*>(&SUB));
frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
*reinterpret_cast<const half2*>(&MUL),
*reinterpret_cast<const half2*>(&ADD));
}
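// Same layout without a zero point: SUB = 0x6400 = 1024.0 and ADD = 0xd400 =
// -64.0, so the low nibbles yield (1024 + v) - 1024 = v and the high nibbles
// yield (1024 + 16*v) / 16 - 64 = v.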
template <>
__device__ inline void dequant<nv_bfloat162, vllm::kU4B8.id(), true>(
int q, nv_bfloat162* frag_b) {
static constexpr uint32_t MASK = 0x000f000f;
static constexpr uint32_t EX = 0x43004300;
// Guarantee that the `(a & b) | c` operations are LOP3s.
// clang-format off
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
q >>= 4;
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
// clang-format on
frag_b[0] = *reinterpret_cast<nv_bfloat162*>(&lo);
frag_b[1] = *reinterpret_cast<nv_bfloat162*>(&hi);
}
template <>
__device__ inline void dequant<nv_bfloat162, vllm::kU4B8.id(), false>(
int q, nv_bfloat162* frag_b) {
dequant<nv_bfloat162, vllm::kU4B8.id(), true>(q, frag_b);
static constexpr uint32_t SUB = 0x43084308;
frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast<const nv_bfloat162*>(&SUB));
frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast<const nv_bfloat162*>(&SUB));
}
template <>
__device__ inline void dequant<nv_bfloat162, vllm::kU4.id(), true>(
int q, nv_bfloat162* frag_b) {
dequant<nv_bfloat162, vllm::kU4B8.id(), true>(q, frag_b);
}
template <>
__device__ inline void dequant<nv_bfloat162, vllm::kU4.id(), false>(
int q, nv_bfloat162* frag_b) {
dequant<nv_bfloat162, vllm::kU4.id(), true>(q, frag_b);
static constexpr uint32_t SUB = 0x43004300;
frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast<const nv_bfloat162*>(&SUB));
frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast<const nv_bfloat162*>(&SUB));
}
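// BF16 analogue of the FP16 trick above: EX = 0x4300 is 128.0 in bf16, so each
// lane holds 128 + v after the LOP3. Subtracting SUB = 0x4308 (= 136.0) gives
// v - 8 for kU4B8, while SUB = 0x4300 (= 128.0) gives the raw value v for kU4.
// Both high and low nibbles go through the same mask (via q >>= 4 in the
// skip_flop variant), so no FMA is needed here.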
//
// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or
// bf16 Reference:
// - FP16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85
// - BF16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175
//
template <>
__device__ inline void dequant<half2, vllm::kU8B128.id(), true>(int q,
half2* frag_b) {
static constexpr uint32_t mask_for_elt_01 = 0x5250;
static constexpr uint32_t mask_for_elt_23 = 0x5351;
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
frag_b[0] = *reinterpret_cast<half2*>(&lo);
frag_b[1] = *reinterpret_cast<half2*>(&hi);
}
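// The prmt selector 0x5250 picks bytes 0 and 2 of q and pairs each with a 0x64
// byte from start_byte_for_fp16; 0x5351 does the same for bytes 1 and 3. Every
// 16-bit lane therefore reads 0x64XX, i.e. the half value 1024 + byte.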
template <>
__device__ inline void dequant<half2, vllm::kU8B128.id(), false>(
int q, half2* frag_b) {
dequant<half2, vllm::kU8B128.id(), true>(q, frag_b);
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
frag_b[0] = __hsub2(frag_b[0],
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
frag_b[1] = __hsub2(frag_b[1],
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
}
template <>
__device__ inline void dequant<half2, vllm::kU8.id(), true>(int q,
half2* frag_b) {
dequant<half2, vllm::kU8B128.id(), true>(q, frag_b);
}
template <>
__device__ inline void dequant<half2, vllm::kU8.id(), false>(int q,
half2* frag_b) {
dequant<half2, vllm::kU8.id(), true>(q, frag_b);
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400;
frag_b[0] = __hsub2(frag_b[0],
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
frag_b[1] = __hsub2(frag_b[1],
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
}
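// The magic numbers above, read as halves: 0x6480 = 1152.0 = 1024 + 128
// (kU8B128, result = byte - 128) and 0x6400 = 1024.0 (kU8, result = the raw
// byte value).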
template <>
__device__ inline void dequant<nv_bfloat162, vllm::kU8B128.id(), false>(
int q, nv_bfloat162* frag_b) {
float fp32_intermediates[4];
uint32_t* fp32_intermediates_casted =
reinterpret_cast<uint32_t*>(fp32_intermediates);
static constexpr uint32_t fp32_base = 0x4B000000;
fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);
fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);
fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);
fp32_intermediates[0] -= 8388736.f;
fp32_intermediates[1] -= 8388736.f;
fp32_intermediates[2] -= 8388736.f;
fp32_intermediates[3] -= 8388736.f;
uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(frag_b);
bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0],
fp32_intermediates_casted[1], 0x7632);
bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2],
fp32_intermediates_casted[3], 0x7632);
}
template <>
__device__ inline void dequant<nv_bfloat162, vllm::kU8.id(), false>(
int q, nv_bfloat162* frag_b) {
float fp32_intermediates[4];
uint32_t* fp32_intermediates_casted =
reinterpret_cast<uint32_t*>(fp32_intermediates);
static constexpr uint32_t fp32_base = 0x4B000000;
fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);
fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);
fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);
fp32_intermediates[0] -= 8388608.f;
fp32_intermediates[1] -= 8388608.f;
fp32_intermediates[2] -= 8388608.f;
fp32_intermediates[3] -= 8388608.f;
uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(frag_b);
bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0],
fp32_intermediates_casted[1], 0x7632);
bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2],
fp32_intermediates_casted[3], 0x7632);
}
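// Both BF16 int8 paths use the 2^23 trick: fp32_base = 0x4B000000 is
// 8388608.0f, so byte_perm-ing a byte into the low bits produces exactly
// 8388608 + byte. Subtracting 8388736 (= 2^23 + 128) or 8388608 (= 2^23)
// leaves byte - 128 or byte, and the final __byte_perm packs the upper 16 bits
// of each float, i.e. the bf16 results obtained by truncation.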
template <>
__device__ inline void dequant<half2, vllm::kFE4M3fn.id(), true>(
int q, half2* frag_b) {
// Constants for FP8 (E4M3) and FP16 formats
constexpr int FP8_EXPONENT = 4, FP16_EXPONENT = 5;
constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP8_EXPONENT;
constexpr int MASK = 0x7F007F00;
// Extract and shift FP8 values to FP16 format
int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
q <<= 8;
int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
// Note: reverse indexing is intentional because weights are permuted
frag_b[1] = *reinterpret_cast<const half2*>(&Out1);
frag_b[0] = *reinterpret_cast<const half2*>(&Out2);
}
template <>
__device__ inline void dequant<half2, vllm::kFE4M3fn.id(), false>(
int q, half2* frag_b) {
dequant<half2, vllm::kFE4M3fn.id(), true>(q, frag_b);
// Constants for FP8 (E4M3) and FP16 formats
constexpr int FP8_EXPONENT = 4, FP16_EXPONENT = 5;
// Construct and apply exponent bias
constexpr int BIAS_OFFSET =
(1 << (FP16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1));
const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET));
// Convert to half2 and apply bias
frag_b[1] = __hmul2(frag_b[1], bias_reg);
frag_b[0] = __hmul2(frag_b[0], bias_reg);
}
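// After the 1-bit right shift in the skip_flop variant each half equals the
// FP8 value scaled by 2^-8 (E4M3 bias 7 vs. FP16 bias 15), so
// BIAS_OFFSET = (1 << 4) - (1 << 3) = 8 and bias_reg = 2^8 = 256.0 restores
// the original magnitude.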
template <>
__device__ inline void dequant<nv_bfloat162, vllm::kFE4M3fn.id(), true>(
int q, nv_bfloat162* frag_b) {
// Constants for FP8 (E4M3) and BF16 formats
constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8;
constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT;
constexpr int MASK = 0x7F007F00;
// Extract and shift FP8 values to BF16 format
int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
q <<= 8;
int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
// Note: reverse indexing is intentional because weights are permuted
frag_b[1] = *reinterpret_cast<const nv_bfloat162*>(&Out1);
frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2);
}
template <>
__device__ inline void dequant<nv_bfloat162, vllm::kFE4M3fn.id(), false>(
int q, nv_bfloat162* frag_b) {
dequant<nv_bfloat162, vllm::kFE4M3fn.id(), true>(q, frag_b);
// Constants for FP8 (E4M3) and BF16 formats
constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8;
// Construct and apply exponent bias
constexpr int BIAS_OFFSET =
(1 << (BF16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1));
// Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent
// position
constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23;
const nv_bfloat162 bias_reg =
__float2bfloat162_rn(*reinterpret_cast<const float*>(&BIAS));
// Convert to bfloat162 and apply bias
frag_b[1] = __hmul2(frag_b[1], bias_reg);
frag_b[0] = __hmul2(frag_b[0], bias_reg);
}
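// Same idea for BF16: the 4-bit shift leaves each value scaled by 2^-120
// (E4M3 bias 7 vs. BF16 bias 127), and BIAS = (120 + 127) << 23 is the float
// 2^120 used to build bias_reg.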
template <>
__device__ inline void dequant<half2, vllm::kFE2M1f.id(), true>(int q,
half2* frag_b) {
// Constants for FP4 (E2M1) and FP16 formats
constexpr int FP4_EXPONENT = 2, FP16_EXPONENT = 5;
constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP4_EXPONENT;
constexpr int MASK = 0x70007000;
// Extract and shift FP4 values to FP16 format
int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
q <<= 4;
int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
// Note: reverse indexing is intentional because weights are permuted
frag_b[1] = *reinterpret_cast<const half2*>(&Out1);
frag_b[0] = *reinterpret_cast<const half2*>(&Out2);
}
template <>
__device__ inline void dequant<half2, vllm::kFE2M1f.id(), false>(
int q, half2* frag_b) {
dequant<half2, vllm::kFE2M1f.id(), true>(q, frag_b);
// Constants for FP4 (E2M1) and FP16 formats
constexpr int FP4_EXPONENT = 2, FP16_EXPONENT = 5;
// Construct and apply exponent bias
constexpr int BIAS_OFFSET =
(1 << (FP16_EXPONENT - 1)) - (1 << (FP4_EXPONENT - 1));
const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET));
// Convert to half2 and apply bias
frag_b[1] = __hmul2(frag_b[1], bias_reg);
frag_b[0] = __hmul2(frag_b[0], bias_reg);
}
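// For E2M1 the 3-bit shift leaves each half scaled by 2^-14 (FP4 bias 1 vs.
// FP16 bias 15); BIAS_OFFSET = (1 << 4) - (1 << 1) = 14, so
// bias_reg = 2^14 = 16384.0.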
template <>
__device__ inline void dequant<nv_bfloat162, vllm::kFE2M1f.id(), true>(
int q, nv_bfloat162* frag_b) {
// Constants for FP4 (E2M1) and BF16 formats
constexpr int FP4_EXPONENT = 2, BF16_EXPONENT = 8;
constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP4_EXPONENT;
constexpr int MASK = 0x70007000;
// Extract and shift FP4 values to BF16 format
int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
q <<= 4;
int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
// Note: reverse indexing is intentional because weights are permuted
frag_b[1] = *reinterpret_cast<const nv_bfloat162*>(&Out1);
frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2);
}
template <>
__device__ inline void dequant<nv_bfloat162, vllm::kFE2M1f.id(), false>(
int q, nv_bfloat162* frag_b) {
dequant<nv_bfloat162, vllm::kFE2M1f.id(), true>(q, frag_b);
// Constants for FP4 (E2M1) and BF16 formats
constexpr int FP4_EXPONENT = 2, BF16_EXPONENT = 8;
// Construct and apply exponent bias
constexpr int BIAS_OFFSET =
(1 << (BF16_EXPONENT - 1)) - (1 << (FP4_EXPONENT - 1));
// Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent
// position
constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23;
const nv_bfloat162 bias_reg =
__float2bfloat162_rn(*reinterpret_cast<const float*>(&BIAS));
// Convert to bfloat162 and apply bias
frag_b[1] = __hmul2(frag_b[1], bias_reg);
frag_b[0] = __hmul2(frag_b[0], bias_reg);
}
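// BF16 variant: the 6-bit shift leaves each value scaled by 2^-126 (FP4 bias 1
// vs. BF16 bias 127), and BIAS = (126 + 127) << 23 encodes the float 2^126.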
template <typename scalar_t2, vllm::ScalarTypeId s_type_id>
__device__ inline void dequant_fp8_scales(int q, scalar_t2* frag_b);
template <>
__device__ inline void dequant_fp8_scales<half2, vllm::kFE4M3fn.id()>(
int q, half2* frag_b) {
int Out1 = (q & 0xFF00FF00) >> 1;
;
q <<= 8;
int Out2 = (q & 0xFF00FF00) >> 1;
// Note: reverse indexing is intentional because weights are permuted
frag_b[1] = *reinterpret_cast<const half2*>(&Out1);
frag_b[0] = *reinterpret_cast<const half2*>(&Out2);
}
template <>
__device__ inline void dequant_fp8_scales<nv_bfloat162, vllm::kFE4M3fn.id()>(
int q, nv_bfloat162* frag_b) {
constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8;
constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT;
constexpr int MASK = 0x7F007F00;
// Extract and shift FP8 values to BF16 format
int Out1 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT);
q <<= 8;
int Out2 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT);
// Note: reverse indexing is intentional because weights are permuted
frag_b[1] = *reinterpret_cast<const nv_bfloat162*>(&Out1);
frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2);
}
template <>
__device__ inline void dequant_fp8_scales<nv_bfloat162, vllm::kFE8M0fnu.id()>(
int q, nv_bfloat162* frag_b) {
// In this conversion, 2 ** -127 in FP8E8M0 would become 0 in BF16,
// but we assume that such an extreme value does not occur in real models.
int Out1 = (q & 0xFF00FF00) >> 1;
q <<= 7;
int Out2 = q & 0x7F807F80;
// Note: reverse indexing is intentional because weights are permuted
frag_b[1] = *reinterpret_cast<const nv_bfloat162*>(&Out1);
frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2);
}
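// E8M0 shares BF16's exponent bias of 127, so the 8 exponent bits only need to
// land in the BF16 exponent field (bits 14..7): the high byte of each lane is
// shifted right by 1 and the low byte left by 7, with no extra scaling needed.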
#endif
} // namespace MARLIN_NAMESPACE_NAME