mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-07 19:47:06 +08:00
Signed-off-by: Nick Hill <nhill@redhat.com> Signed-off-by: Lucas Kabela <lucaskabela@meta.com> Signed-off-by: Max de Bayser <mbayser@br.ibm.com> Signed-off-by: Andrew Sansom <andrew@protopia.ai> Signed-off-by: Boyuan Feng <boyuan@meta.com> Signed-off-by: Boyuan Feng <fby.1994@gmail.com> Signed-off-by: boyuanfeng <boyuan@meta.com> Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Signed-off-by: JartX <sagformas@epdcenter.es> Signed-off-by: Chendi Xue <Chendi.Xue@intel.com> Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com> Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> Signed-off-by: Chen Zhang <zhangch99@outlook.com> Signed-off-by: Roger Wang <hey@rogerw.io> Signed-off-by: mgoin <mgoin64@gmail.com> Signed-off-by: wwl2755 <wangwenlong2755@gmail.com> Signed-off-by: Manoel Marques <manoel.marques@ibm.com> Signed-off-by: Manoel Marques <manoelmrqs@gmail.com> Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn> Signed-off-by: pengdrumli <pengdrumli@tencent.com> Signed-off-by: windsonsea <haifeng.yao@daocloud.io> Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai> Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu> Signed-off-by: Huamin Li <3ericli@gmail.com> Signed-off-by: simondanielsson <simon.danielsson99@hotmail.com> Signed-off-by: Rahul Tuli <rtuli@redhat.com> Signed-off-by: Yang <lymailforjob@gmail.com> Signed-off-by: Debolina Roy <debroy@redhat.com> Signed-off-by: David Chen <530634352@qq.com> Signed-off-by: wangzi <3220100013@zju.edu.cn> Signed-off-by: Eldar Kurtic <8884008+eldarkurtic@users.noreply.github.com> Signed-off-by: NickLucche <nlucches@redhat.com> Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com> Signed-off-by: Sara Kokkila Schumacher <saraks@ibm.com> Signed-off-by: Csrayz <jover@cmbchina.com> Signed-off-by: ivyilike <pww123@cmbchina.com> Signed-off-by: Burkhard Ringlein <ngl@zurich.ibm.com> Signed-off-by: Bowen Wang <abmfy@icloud.com> Signed-off-by: qqma <qqma@amazon.com> Signed-off-by: ElizaWszola 
<ewszola@redhat.com> Signed-off-by: Lu Fang <fanglu@fb.com> Signed-off-by: Zhuohan Li <zhuohan123@gmail.com> Signed-off-by: Luka Govedič <lgovedic@redhat.com> Signed-off-by: luka <lgovedic@redhat.com> Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Signed-off-by: Or Ozeri <oro@il.ibm.com> Signed-off-by: Johnny Yang <johnnyyang@google.com> Signed-off-by: Alec Solder <alecs@fb.com> Signed-off-by: Alec S <10566873+alecsolder@users.noreply.github.com> Signed-off-by: Russell Bryant <rbryant@redhat.com> Signed-off-by: Matthew Bonanni <mbonanni@redhat.com> Signed-off-by: Alexander Matveev <amatveev@redhat.com> Signed-off-by: yewentao256 <zhyanwentao@126.com> Signed-off-by: liuye.hj <liuye.hj@alibaba-inc.com> Signed-off-by: Kunshang Ji <kunshang.ji@intel.com> Signed-off-by: Lucia Fang <116399278+luccafong@users.noreply.github.com> Signed-off-by: Michael Goin <mgoin64@gmail.com> Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com> Signed-off-by: Ming Yang <minos.future@gmail.com> Signed-off-by: Zhikaiiii <1658973216@qq.com> Signed-off-by: Andreas Hartel <andreas.hartel@aleph-alpha.com> Signed-off-by: Jee Jee Li <pandaleefree@gmail.com> Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com> Signed-off-by: wuxibin <wuxibin@bytedance.com> Signed-off-by: youkaichao <youkaichao@gmail.com> Signed-off-by: Peter Pan <Peter.Pan@daocloud.io> Signed-off-by: Peter Pan <peter.pan@daocloud.io> Signed-off-by: Nicolò Lucchesi<nicolo.lucchesi@gmail.com> Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com> Signed-off-by: Sage Moore <sage@neuralmagic.com> Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com> Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com> Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com> Signed-off-by: Bill Nell <bnell@redhat.com> Signed-off-by: Shreeasish Kumar <shreeasish@rivosinc.com> Signed-off-by: Weida Hong <wdhongtw@google.com> Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> 
Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com> Signed-off-by: Hashem Hashemi <159079214+amd-hhashemi@users.noreply.github.com> Signed-off-by: Amir Samani <asamani@nvidia.com> Signed-off-by: ElizaWszola <elizaw.9289@gmail.com> Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com> Signed-off-by: ilmarkov <markovilya197@gmail.com> Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com> Signed-off-by: Jialin Ouyang <Jialin.Ouyang@gmail.com> Signed-off-by: rouchenzi <ruochenwen@gmail.com> Signed-off-by: rouchenzi <40842833+rouchenzi@users.noreply.github.com> Signed-off-by: Andrew Xia <axia@meta.com> Signed-off-by: Kourosh Hakhamaneshi <kourosh@anyscale.com> Signed-off-by: Corey Lowman <clowman1993@gmail.com> Signed-off-by: jpvillam <jpvillam@amd.com> Signed-off-by: dougbtv <dosmith@redhat.com> Signed-off-by: Chenxi Yang <cxyang@fb.com> Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Signed-off-by: ahao-anyscale <ahao@anyscale.com> Signed-off-by: Yan Lu <luyan@nvidia.com> Signed-off-by: baxingpiaochong <771405853@qq.com> Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> Signed-off-by: Nikhil Gupta <nikhil.gupta2@arm.com> Signed-off-by: Yong Hoon Shin <yhshin@meta.com> Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai> Signed-off-by: Benjamin Chislett <bchislett@nvidia.com> Signed-off-by: Ben Browning <bbrownin@redhat.com> Signed-off-by: Chengji Yao <chengjiyao@google.com> Signed-off-by: jiang1.li <jiang1.li@intel.com> Signed-off-by: Jackmin801 <ongjackm@gmail.com> Signed-off-by: Jonas M. Kübler <44084297+jmkuebler@users.noreply.github.com> Signed-off-by: taohui <taohui3@gmail.com> Signed-off-by: rongfu.leng <rongfu.leng@daocloud.io> Signed-off-by: Shu Wang <shuw@nvidia.com> Signed-off-by: Shu Wang. 
<shuw@nvidia.com> Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com> Signed-off-by: Duncan Moss <djm.moss@gmail.com> Signed-off-by: Shiyan Deng <dsy842974287@meta.com> Signed-off-by: Wei Wei <wwei6@meta.com> Signed-off-by: Saman Keon <samanamp@outlook.com> Signed-off-by: yangxurui <yangxurui@meituan.com> Signed-off-by: nicole-lihui <nicole.li@daocloud.io> Signed-off-by: courage17340 <courage17340@163.com> Signed-off-by: Jacob Kahn <jacobkahn1@gmail.com> Signed-off-by: Fadi Arafeh <fadi.arafeh@arm.com> Signed-off-by: Agata Dobrzyniewicz <adobrzyniewicz@habana.ai> Signed-off-by: zxw <1020938856@qq.com> Signed-off-by: wang.yuqi <noooop@126.com> Signed-off-by: Cyrus Leung <cyrus.tl.leung@gmail.com> Signed-off-by: chenlang <chen.lang5@zte.com.cn> Signed-off-by: Jonas Kuebler <kuebj@amazon.com> Signed-off-by: AlonKejzman <alonkeizman@gmail.com> Signed-off-by: Tao Hui <taohui3@gmail.com> Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com> Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com> Signed-off-by: Aleksandr Malyshev <maleksan@amd.com> Signed-off-by: Eugene Khvedchenia <ekhvedchenia@nvidia.com> Signed-off-by: Eugene Khvedchenya <ekhvedchenya@gmail.com> Signed-off-by: yiting.jiang <yiting.jiang@daocloud.io> Signed-off-by: xaguilar <Xavier.AguilarFruto@amd.com> Signed-off-by: Iceber Gu <caiwei95@hotmail.com> Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com> Signed-off-by: Icey <1790571317@qq.com> Signed-off-by: 许文卿 <xwq391974@alibaba-inc.com> Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com> Co-authored-by: Nick Hill <nhill@redhat.com> Co-authored-by: Lucas Kabela <lucasakabela@gmail.com> Co-authored-by: Maximilien de Bayser <mbayser@br.ibm.com> Co-authored-by: Andrew Sansom <andrew@protopia.ai> Co-authored-by: Boyuan Feng <boyuan@meta.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> 
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Co-authored-by: JartX <sagformas@epdcenter.es> Co-authored-by: Chendi.Xue <chendi.xue@intel.com> Co-authored-by: Chauncey <chaunceyjiang@gmail.com> Co-authored-by: xin.li <xin.li@daocloud.io> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk> Co-authored-by: Chen Zhang <zhangch99@outlook.com> Co-authored-by: Roger Wang <hey@rogerw.io> Co-authored-by: Michael Goin <mgoin64@gmail.com> Co-authored-by: Wenlong Wang <wangwenlong2755@gmail.com> Co-authored-by: Manoel Marques <manoelmrqs@gmail.com> Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn> Co-authored-by: lirong <56789630+lirong-lirong@users.noreply.github.com> Co-authored-by: Michael Yao <haifeng.yao@daocloud.io> Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu> Co-authored-by: Huamin Li <3ericli@gmail.com> Co-authored-by: Lu Fang <30275821+houseroad@users.noreply.github.com> Co-authored-by: Simon Danielsson <70206058+simondanielsson@users.noreply.github.com> Co-authored-by: Rahul Tuli <rtuli@redhat.com> Co-authored-by: Claude <noreply@anthropic.com> Co-authored-by: Yang Liu <127183760+KKSK-DON@users.noreply.github.com> Co-authored-by: Deboleina <debroy@redhat.com> Co-authored-by: yinz-aizip <yinz@aizip.ai> Co-authored-by: WeiQing Chen <40507679+david6666666@users.noreply.github.com> Co-authored-by: wangzi <3220100013@zju.edu.cn> Co-authored-by: Eldar Kurtić <8884008+eldarkurtic@users.noreply.github.com> Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com> Co-authored-by: Ye (Charlotte) Qi <yeq@meta.com> Co-authored-by: Yizhou <136800916+yiz-liu@users.noreply.github.com> Co-authored-by: Sara-KS <50249410+Sara-KS@users.noreply.github.com> Co-authored-by: Csrayz <jover@cmbchina.com> Co-authored-by: ivyilike <pww123@cmbchina.com> Co-authored-by: Burkhard Ringlein <ngl@zurich.ibm.com> Co-authored-by: Bowen Wang <abmfy@icloud.com> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com> Co-authored-by: Daisy-Ma-coder 
<daisy.ma.0117@gmail.com> Co-authored-by: qqma <qqma@amazon.com> Co-authored-by: ElizaWszola <ewszola@redhat.com> Co-authored-by: Lucia Fang <116399278+luccafong@users.noreply.github.com> Co-authored-by: Zhuohan Li <zhuohan123@gmail.com> Co-authored-by: Simon Mo <simon.mo@hey.com> Co-authored-by: Or Ozeri <oro@il.ibm.com> Co-authored-by: Johnny Yang <24908445+jcyang43@users.noreply.github.com> Co-authored-by: Chengji Yao <chengjiyao@google.com> Co-authored-by: Alec S <10566873+alecsolder@users.noreply.github.com> Co-authored-by: Alec Solder <alecs@fb.com> Co-authored-by: Russell Bryant <rbryant@redhat.com> Co-authored-by: Matthew Bonanni <mbonanni@redhat.com> Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com> Co-authored-by: Chris Bamford <chrisbam4d@gmail.com> Co-authored-by: Alexander Matveev <59768536+alexm-redhat@users.noreply.github.com> Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Co-authored-by: JJJYmmm <92386084+JJJYmmm@users.noreply.github.com> Co-authored-by: liuye.hj <liuye.hj@alibaba-inc.com> Co-authored-by: Kunshang Ji <kunshang.ji@intel.com> Co-authored-by: Lucia (Lu) Fang <fanglu@meta.com> Co-authored-by: Varun Sundar Rabindranath <varunsundar08@gmail.com> Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com> Co-authored-by: Ming Yang <yming@meta.com> Co-authored-by: Zhikaiiii <55917203+Zhikaiiii@users.noreply.github.com> Co-authored-by: Andreas Hartel <andreas@hartel.me> Co-authored-by: Jee Jee Li <pandaleefree@gmail.com> Co-authored-by: vllmellm <vllm.ellm@embeddedllm.com> Co-authored-by: Joel <wuxibin89@163.com> Co-authored-by: youkaichao <youkaichao@gmail.com> Co-authored-by: Mark McLoughlin <markmc@redhat.com> Co-authored-by: Peter Pan <peter.pan@daocloud.io> Co-authored-by: Nicolò Lucchesi <nicolo.lucchesi@gmail.com> Co-authored-by: Fanli Lin <fanli.lin@intel.com> Co-authored-by: Thomas Parnell <tpa@zurich.ibm.com> Co-authored-by: Lucas Wilkinson 
<LucasWilkinson@users.noreply.github.com> Co-authored-by: Sage Moore <sage@neuralmagic.com> Co-authored-by: yewentao256 <zhyanwentao@126.com> Co-authored-by: bnellnm <49004751+bnellnm@users.noreply.github.com> Co-authored-by: rivos-shreeasish <shreeasish@rivosinc.com> Co-authored-by: Chih-Chieh Yang <chih.chieh.yang@ibm.com> Co-authored-by: Weida Hong <wdhongtw@gmail.com> Co-authored-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Co-authored-by: Hashem Hashemi <159079214+amd-hhashemi@users.noreply.github.com> Co-authored-by: Amir Samani <samani@ualberta.ca> Co-authored-by: Luka Govedič <lgovedic@redhat.com> Co-authored-by: jiahanc <173873397+jiahanc@users.noreply.github.com> Co-authored-by: Ilya Markov <markovilya197@gmail.com> Co-authored-by: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com> Co-authored-by: Jialin Ouyang <Jialin.Ouyang@gmail.com> Co-authored-by: rouchenzi <40842833+rouchenzi@users.noreply.github.com> Co-authored-by: Andrew Xia <axia@meta.com> Co-authored-by: kourosh hakhamaneshi <31483498+kouroshHakha@users.noreply.github.com> Co-authored-by: Corey Lowman <clowman1993@gmail.com> Co-authored-by: Juan Villamizar <100237675+jpvillam-amd@users.noreply.github.com> Co-authored-by: jpvillam <jpvillam@amd.com> Co-authored-by: Doug Smith <dosmith@redhat.com> Co-authored-by: Chenxi Yang <cxyang@cs.utexas.edu> Co-authored-by: Chenxi Yang <cxyang@fb.com> Co-authored-by: ahao-anyscale <ahao@anyscale.com> Co-authored-by: 0xNullPath <luyanfcp@foxmail.com> Co-authored-by: baxingpiaochong <771405853@qq.com> Co-authored-by: Benjamin Chislett <bchislett@nvidia.com> Co-authored-by: Kyle Sayers <kylesayrs@gmail.com> Co-authored-by: Nikhil Gupta <nikhil.gupta2@arm.com> Co-authored-by: Yong Hoon Shin <48474650+sarckk@users.noreply.github.com> Co-authored-by: lhsjohn <huashuoli@tencent.com> Co-authored-by: Ben Browning <bbrownin@redhat.com> Co-authored-by: Li, Jiang <jiang1.li@intel.com> Co-authored-by: Jackmin801 
<56836461+Jackmin801@users.noreply.github.com> Co-authored-by: Jonas M. Kübler <44084297+jmkuebler@users.noreply.github.com> Co-authored-by: Tao Hui <taohui3@gmail.com> Co-authored-by: rongfu.leng <rongfu.leng@daocloud.io> Co-authored-by: Shu Wang <shuw@nvidia.com> Co-authored-by: Tyler Michael Smith <tlrmchlsmth@gmail.com> Co-authored-by: Duncan Moss <djm.moss@gmail.com> Co-authored-by: Shiyan Deng <dsy842974287@meta.com> Co-authored-by: Wei Wei <wwei6@meta.com> Co-authored-by: Saman A. Pour <samanamp@outlook.com> Co-authored-by: XuruiYang <530534756@qq.com> Co-authored-by: yangxurui <yangxurui@meituan.com> Co-authored-by: Nicole LiHui 🥜 <nicolelihui@outlook.com> Co-authored-by: courage17340 <courage17340@users.noreply.github.com> Co-authored-by: Jacob Kahn <jacobkahn1@gmail.com> Co-authored-by: Nicole LiHui 🥜 <nicole.li@daocloud.io> Co-authored-by: Fadi Arafeh <115173828+fadara01@users.noreply.github.com> Co-authored-by: Agata Dobrzyniewicz <160237065+adobrzyn@users.noreply.github.com> Co-authored-by: yyzxw <34639446+yyzxw@users.noreply.github.com> Co-authored-by: wang.yuqi <noooop@126.com> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com> Co-authored-by: chenlang <chen.lang5@zte.com.cn> Co-authored-by: chenlang <10346245@zte.com.cn> Co-authored-by: AlonKejzman <alonkeizman@gmail.com> Co-authored-by: tomeras91 <57313761+tomeras91@users.noreply.github.com> Co-authored-by: Aleksandr Malyshev <164964928+maleksan85@users.noreply.github.com> Co-authored-by: Aleksandr Malyshev <maleksan@amd.com> Co-authored-by: Doug Lehr <douglehr@amd.com> Co-authored-by: Eugene Khvedchenya <ekhvedchenya@gmail.com> Co-authored-by: yitingdc <59356937+yitingdc@users.noreply.github.com> Co-authored-by: xaguilar-amd <xavier.aguilarfruto@amd.com> Co-authored-by: Iceber Gu <caiwei95@hotmail.com> Co-authored-by: Tao He <linzhu.ht@alibaba-inc.com> Co-authored-by: Icey <1790571317@qq.com> Co-authored-by: Xu Wenqing <121550081+Xu-Wenqing@users.noreply.github.com> Co-authored-by: Chih-Chieh 
Yang <7364402+cyang49@users.noreply.github.com> Co-authored-by: RishiAstra <40644327+RishiAstra@users.noreply.github.com>
248 lines
9.3 KiB
Plaintext
248 lines
9.3 KiB
Plaintext
#include "common.cuh"
|
|
#include "dispatch_utils.h"
|
|
#include "../../cub_helpers.h"
|
|
#include "../vectorization_utils.cuh"
|
|
#include <c10/cuda/CUDAGuard.h>
|
|
#include <ATen/cuda/Exceptions.h>
|
|
|
|
namespace vllm {
|
|
|
|
// Quantizes one row of `input` to FP8 per thread block using a single scale
// value read from `scale[0]`.
//
//   out            destination rows, `out_row_stride` elements apart
//   input          source rows, `in_row_stride` elements apart
//   scale          pointer to the scale; only the first element is read
//   hidden_size    number of elements processed in each row
//
// blockIdx.x selects the row; the threads of the block share the row's
// elements through vectorize_with_alignment.
template <typename scalar_t, typename fp8_type>
__global__ void scaled_fp8_quant_kernel_strided(
    fp8_type* __restrict__ out, const scalar_t* __restrict__ input,
    const float* __restrict__ scale, int hidden_size, int64_t in_row_stride,
    int64_t out_row_stride) {
  const int lane = threadIdx.x;
  const int64_t row = blockIdx.x;  // one row (token) per block

  // Row offsets are formed in 64-bit arithmetic (row and strides are int64_t).
  const scalar_t* src_row = input + row * in_row_stride;
  fp8_type* dst_row = out + row * out_row_stride;

  // The reciprocal is computed once per thread and handed to the conversion
  // helper. NOTE(review): the `true` template flag presumably tells
  // scaled_fp8_conversion that the scale is already inverted — confirm
  // against its definition.
  const float scale_recip = 1.0f / scale[0];

  vectorize_with_alignment<16>(
      src_row, dst_row, hidden_size, lane, blockDim.x,
      [=] __device__(fp8_type & quantized, const scalar_t& value) {
        quantized = scaled_fp8_conversion<true, fp8_type>(
            static_cast<float>(value), scale_recip);
      });
}
|
|
|
|
// Computes the absolute maximum of one row of `input` per thread block and
// folds `row_max / quant_type_max_v<fp8_type>` into the single value at
// `scale` with atomicMaxFloat.
//
//   scale          accumulator updated atomically by thread 0 of each block
//   input          source rows, `in_row_stride` elements apart
//   hidden_size    number of elements scanned in each row
//   num_tokens     number of valid rows; blocks beyond it exit immediately
//
// The shared scratch array holds 256 floats and is indexed by threadIdx.x, so
// the kernel must not be launched with more than 256 threads per block. The
// halving reduction below only visits every slot when blockDim.x is a power
// of two.
template <typename scalar_t, typename fp8_type>
__global__ void segmented_max_reduction_strided(
    float* __restrict__ scale, const scalar_t* __restrict__ input,
    int hidden_size, int64_t in_row_stride, int64_t num_tokens) {
  __shared__ float partial_max[256];
  const int lane = threadIdx.x;
  const int64_t row = blockIdx.x;

  // One block per row; the whole block leaves together if the grid is larger
  // than the number of rows.
  if (row >= num_tokens) {
    return;
  }

  const scalar_t* src_row = input + row * in_row_stride;

  // Every thread walks the row with a step of blockDim.x and keeps its own
  // running absolute maximum.
  float local_max = 0.0f;
  for (int col = lane; col < hidden_size; col += blockDim.x) {
    local_max = fmaxf(local_max, fabsf(static_cast<float>(src_row[col])));
  }

  partial_max[lane] = local_max;
  __syncthreads();

  // Tree reduction in shared memory: each round folds the upper half of the
  // active range onto the lower half.
  int stride = blockDim.x / 2;
  while (stride > 0) {
    if (lane < stride) {
      partial_max[lane] = fmaxf(partial_max[lane], partial_max[lane + stride]);
    }
    __syncthreads();
    stride >>= 1;
  }

  // Thread 0 holds the row maximum and merges the resulting candidate scale
  // into the shared (per-tensor) scale.
  if (lane == 0) {
    atomicMaxFloat(scale, partial_max[0] / quant_type_max_v<fp8_type>);
  }
}
|
|
|
|
// Quantizes one row of `input` to FP8 per thread block using the scale stored
// at `scale[0]`. Within this file it is launched right after
// segmented_max_reduction_strided on the same stream, which is what fills in
// that scale.
//
//   out            destination rows, `out_row_stride` elements apart
//   input          source rows, `in_row_stride` elements apart
//   scale          pointer to the scale; only the first element is read
//   hidden_size    number of elements processed in each row
template <typename scalar_t, typename fp8_type>
__global__ void scaled_fp8_quant_kernel_strided_dynamic(
    fp8_type* __restrict__ out, const scalar_t* __restrict__ input,
    const float* __restrict__ scale, int hidden_size, int64_t in_row_stride,
    int64_t out_row_stride) {
  const int lane = threadIdx.x;
  const int64_t row = blockIdx.x;  // one row (token) per block

  const scalar_t* src_row = input + row * in_row_stride;
  fp8_type* dst_row = out + row * out_row_stride;

  // Invert the scale once per thread. NOTE(review): the `true` template flag
  // presumably marks the scale as already inverted for
  // scaled_fp8_conversion — confirm against its definition.
  const float inverse_scale = 1.0f / scale[0];

  vectorize_with_alignment<16>(
      src_row, dst_row, hidden_size, lane, blockDim.x,
      [=] __device__(fp8_type & quantized, const scalar_t& value) {
        quantized = scaled_fp8_conversion<true, fp8_type>(
            static_cast<float>(value), inverse_scale);
      });
}
|
|
|
|
// Per-row dynamic FP8 quantization: each thread block finds the absolute
// maximum of its row, derives a scale from it, writes that scale to
// `scale[row]`, and then quantizes the row with it.
//
//   out            destination rows, `out_row_stride` elements apart
//   scale          one output scale per row, indexed by blockIdx.x
//   input          source rows, `in_row_stride` elements apart
//   scale_ub       optional upper bound applied to the row maximum; may be
//                  null, in which case no bound is applied
//   hidden_size    number of elements processed in each row
//
// The block reduction is instantiated for 256 threads; blockDim.x is passed
// to Reduce() as its valid-item count.
template <typename scalar_t, typename fp8_type>
__global__ void dynamic_per_token_scaled_fp8_quant_kernel_strided(
    fp8_type* __restrict__ out, float* __restrict__ scale,
    const scalar_t* __restrict__ input, const float* __restrict__ scale_ub,
    int hidden_size, int64_t in_row_stride, int64_t out_row_stride) {
  const int64_t row = blockIdx.x;
  const int lane = threadIdx.x;

  // Offsets are computed in 64 bits so that row * stride cannot overflow a
  // 32-bit integer.
  const int64_t src_offset = static_cast<int64_t>(row) * in_row_stride;
  const int64_t dst_offset = static_cast<int64_t>(row) * out_row_stride;
  const scalar_t* src_row = input + src_offset;
  fp8_type* dst_row = out + dst_offset;

  // Step 1: each thread accumulates the absolute maximum of the elements it
  // reads, then the block combines the per-thread values.
  float local_absmax = 0.f;
  vectorize_read_with_alignment<16>(
      src_row, hidden_size, lane, blockDim.x, [&] __device__(scalar_t value) {
        local_absmax = fmaxf(local_absmax, fabsf(static_cast<float>(value)));
      });

  using BlockReduce = cub::BlockReduce<float, 256>;
  __shared__ typename BlockReduce::TempStorage reduce_storage;
  const float row_absmax =
      BlockReduce(reduce_storage).Reduce(local_absmax, CubMaxOp{}, blockDim.x);

  // Thread 0 turns the row maximum into the row's scale, publishes it to the
  // rest of the block through shared memory, and records it in the output.
  __shared__ float row_scale;
  if (lane == 0) {
    float bounded_absmax = row_absmax;
    if (scale_ub) {
      bounded_absmax = fminf(row_absmax, *scale_ub);
    }
    row_scale = fmaxf(bounded_absmax / quant_type_max_v<fp8_type>,
                      min_scaling_factor<fp8_type>::val());
    scale[row] = row_scale;
  }
  __syncthreads();  // row_scale must be visible before any thread reads it

  // Step 2: quantize the row with the freshly computed scale. NOTE(review):
  // the `false` template flag presumably means the scale is passed
  // un-inverted to scaled_fp8_conversion — confirm against its definition.
  vectorize_with_alignment<16>(
      src_row, dst_row, hidden_size, lane, blockDim.x,
      [=] __device__(fp8_type & quantized, const scalar_t& value) {
        quantized = scaled_fp8_conversion<false, fp8_type>(
            static_cast<float>(value), row_scale);
      });
}
|
|
|
|
} // namespace vllm
|
|
|
|
// Quantizes `input` to FP8 into `out` using a caller-provided scale.
//
//   out    [..., d]  FP8 destination; its last dimension must be contiguous
//   input  [..., d]  floating-point source; its last dimension must be
//                    contiguous
//   scale  [1]       float scale, read on the device by the kernel
//
// One 256-thread block is launched per row of `d` elements on the current
// CUDA stream, with the device of `input` made current for the duration of
// the call. Rows are located through stride(-2) of each tensor. An empty
// `input` is a no-op.
void static_scaled_fp8_quant(torch::Tensor& out,          // [..., d]
                             torch::Tensor const& input,  // [..., d]
                             torch::Tensor const& scale)  // [1]
{
  TORCH_CHECK(input.stride(-1) == 1,
              "last dimension of input must be contiguous");
  TORCH_CHECK(out.stride(-1) == 1,
              "last dimension of output must be contiguous");

  const int hidden_size = input.size(-1);
  const int64_t in_row_stride = input.stride(-2);
  const int64_t out_row_stride = out.stride(-2);

  // Nothing to quantize. Returning here also avoids dividing by a zero
  // hidden_size below and launching a kernel with a zero-sized grid, which
  // CUDA rejects as an invalid configuration.
  if (input.numel() == 0) {
    return;
  }

  const int num_tokens = input.numel() / hidden_size;
  const int block_size = 256;
  dim3 grid(num_tokens);
  dim3 block(block_size);

  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] {
        VLLM_DISPATCH_FP8_TYPES(
            out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] {
              vllm::scaled_fp8_quant_kernel_strided<scalar_t, fp8_t>
                  <<<grid, block, 0, stream>>>(
                      out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
                      scale.data_ptr<float>(), hidden_size, in_row_stride,
                      out_row_stride);
            });
      });
}
|
|
|
|
// Quantizes `input` to FP8 into `out` with a per-tensor scale that is
// computed on the device from the data and written to `scale`.
//
//   out    [..., d]  FP8 destination; its last dimension must be contiguous
//   input  [..., d]  floating-point source; its last dimension must be
//                    contiguous
//   scale  [1]       float output; zeroed here, then filled by the reduction
//
// Two kernels are enqueued on the current CUDA stream: a max-reduction that
// produces the scale, followed by the quantization that consumes it. Each
// uses one 256-thread block per row of `d` elements. For an empty `input`
// the scale is still zeroed but no kernel is launched.
void dynamic_scaled_fp8_quant(torch::Tensor& out,          // [..., d]
                              torch::Tensor const& input,  // [..., d]
                              torch::Tensor& scale)        // [1]
{
  TORCH_CHECK(input.stride(-1) == 1,
              "last dimension of input must be contiguous");
  TORCH_CHECK(out.stride(-1) == 1,
              "last dimension of output must be contiguous");

  const int hidden_size = input.size(-1);
  const int64_t in_row_stride = input.stride(-2);
  const int64_t out_row_stride = out.stride(-2);

  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // scale tensor should be initialised to <=0 before reduction
  AT_CUDA_CHECK(
      cudaMemsetAsync(scale.data_ptr<float>(), 0, sizeof(float), stream));

  // Nothing to reduce or quantize. Returning here also avoids dividing by a
  // zero hidden_size below and launching kernels with a zero-sized grid,
  // which CUDA rejects as an invalid configuration.
  if (input.numel() == 0) {
    return;
  }

  const int num_tokens = input.numel() / hidden_size;
  const int block_size = 256;
  dim3 grid(num_tokens);
  dim3 block(block_size);

  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] {
        VLLM_DISPATCH_FP8_TYPES(
            out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] {
              // Pass 1: fold every row's maximum into the single scale.
              vllm::segmented_max_reduction_strided<scalar_t, fp8_t>
                  <<<grid, block, 0, stream>>>(
                      scale.data_ptr<float>(), input.data_ptr<scalar_t>(),
                      hidden_size, in_row_stride,
                      static_cast<int64_t>(num_tokens));

              // Pass 2: quantize with the scale produced by pass 1. Both
              // launches go to the same stream, so they run in order.
              vllm::scaled_fp8_quant_kernel_strided_dynamic<scalar_t, fp8_t>
                  <<<grid, block, 0, stream>>>(
                      out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
                      scale.data_ptr<float>(), hidden_size, in_row_stride,
                      out_row_stride);
            });
      });
}
|
|
|
|
// Quantizes `input` to FP8 into `out` with one dynamically computed scale per
// row, written to `scales`.
//
//   out       [..., d]  FP8 destination; its last dimension must be
//                       contiguous
//   input     [..., d]  floating-point source; its last dimension must be
//                       contiguous
//   scales              float output, one entry written per row
//   scale_ub            optional float tensor; when present its data pointer
//                       is handed to the kernel as an upper bound, otherwise
//                       nullptr is passed
//
// One block per row of `d` elements is launched on the current CUDA stream,
// with min(d, 256) threads per block. An empty `input` is a no-op.
void dynamic_per_token_scaled_fp8_quant(
    torch::Tensor& out,          // [..., d]
    torch::Tensor const& input,  // [..., d]
    torch::Tensor& scales, std::optional<at::Tensor> const& scale_ub) {
  TORCH_CHECK(input.stride(-1) == 1,
              "last dimension of input must be contiguous");
  TORCH_CHECK(out.stride(-1) == 1,
              "last dimension of output must be contiguous");

  const int hidden_size = input.size(-1);
  const int64_t in_row_stride = input.stride(-2);
  const int64_t out_row_stride = out.stride(-2);

  // Nothing to quantize. Returning here also avoids dividing by a zero
  // hidden_size below and launching a kernel with a zero-sized grid or
  // block, which CUDA rejects as an invalid configuration.
  if (input.numel() == 0) {
    return;
  }

  const int num_tokens = input.numel() / hidden_size;
  const int block_size = 256;
  dim3 grid(num_tokens);
  dim3 block(std::min(hidden_size, block_size));

  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(),
      "dynamic_per_token_scaled_fp8_quant_kernel_scalar_type", [&] {
        VLLM_DISPATCH_FP8_TYPES(
            out.scalar_type(),
            "dynamic_per_token_scaled_fp8_quant_kernel_fp8_type", [&] {
              vllm::dynamic_per_token_scaled_fp8_quant_kernel_strided<
                  scalar_t, fp8_t><<<grid, block, 0, stream>>>(
                  out.data_ptr<fp8_t>(), scales.data_ptr<float>(),
                  input.data_ptr<scalar_t>(),
                  scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
                  hidden_size, in_row_stride, out_row_stride);
            });
      });
}
|