mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-07 19:47:06 +08:00
Signed-off-by: Nick Hill <nhill@redhat.com> Signed-off-by: Lucas Kabela <lucaskabela@meta.com> Signed-off-by: Max de Bayser <mbayser@br.ibm.com> Signed-off-by: Andrew Sansom <andrew@protopia.ai> Signed-off-by: Boyuan Feng <boyuan@meta.com> Signed-off-by: Boyuan Feng <fby.1994@gmail.com> Signed-off-by: boyuanfeng <boyuan@meta.com> Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Signed-off-by: JartX <sagformas@epdcenter.es> Signed-off-by: Chendi Xue <Chendi.Xue@intel.com> Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com> Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> Signed-off-by: Chen Zhang <zhangch99@outlook.com> Signed-off-by: Roger Wang <hey@rogerw.io> Signed-off-by: mgoin <mgoin64@gmail.com> Signed-off-by: wwl2755 <wangwenlong2755@gmail.com> Signed-off-by: Manoel Marques <manoel.marques@ibm.com> Signed-off-by: Manoel Marques <manoelmrqs@gmail.com> Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn> Signed-off-by: pengdrumli <pengdrumli@tencent.com> Signed-off-by: windsonsea <haifeng.yao@daocloud.io> Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai> Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu> Signed-off-by: Huamin Li <3ericli@gmail.com> Signed-off-by: simondanielsson <simon.danielsson99@hotmail.com> Signed-off-by: Rahul Tuli <rtuli@redhat.com> Signed-off-by: Yang <lymailforjob@gmail.com> Signed-off-by: Debolina Roy <debroy@redhat.com> Signed-off-by: David Chen <530634352@qq.com> Signed-off-by: wangzi <3220100013@zju.edu.cn> Signed-off-by: Eldar Kurtic <8884008+eldarkurtic@users.noreply.github.com> Signed-off-by: NickLucche <nlucches@redhat.com> Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com> Signed-off-by: Sara Kokkila Schumacher <saraks@ibm.com> Signed-off-by: Csrayz <jover@cmbchina.com> Signed-off-by: ivyilike <pww123@cmbchina.com> Signed-off-by: Burkhard Ringlein <ngl@zurich.ibm.com> Signed-off-by: Bowen Wang <abmfy@icloud.com> Signed-off-by: qqma <qqma@amazon.com> Signed-off-by: ElizaWszola 
<ewszola@redhat.com> Signed-off-by: Lu Fang <fanglu@fb.com> Signed-off-by: Zhuohan Li <zhuohan123@gmail.com> Signed-off-by: Luka Govedič <lgovedic@redhat.com> Signed-off-by: luka <lgovedic@redhat.com> Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Signed-off-by: Or Ozeri <oro@il.ibm.com> Signed-off-by: Johnny Yang <johnnyyang@google.com> Signed-off-by: Alec Solder <alecs@fb.com> Signed-off-by: Alec S <10566873+alecsolder@users.noreply.github.com> Signed-off-by: Russell Bryant <rbryant@redhat.com> Signed-off-by: Matthew Bonanni <mbonanni@redhat.com> Signed-off-by: Alexander Matveev <amatveev@redhat.com> Signed-off-by: yewentao256 <zhyanwentao@126.com> Signed-off-by: liuye.hj <liuye.hj@alibaba-inc.com> Signed-off-by: Kunshang Ji <kunshang.ji@intel.com> Signed-off-by: Lucia Fang <116399278+luccafong@users.noreply.github.com> Signed-off-by: Michael Goin <mgoin64@gmail.com> Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com> Signed-off-by: Ming Yang <minos.future@gmail.com> Signed-off-by: Zhikaiiii <1658973216@qq.com> Signed-off-by: Andreas Hartel <andreas.hartel@aleph-alpha.com> Signed-off-by: Jee Jee Li <pandaleefree@gmail.com> Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com> Signed-off-by: wuxibin <wuxibin@bytedance.com> Signed-off-by: youkaichao <youkaichao@gmail.com> Signed-off-by: Peter Pan <Peter.Pan@daocloud.io> Signed-off-by: Peter Pan <peter.pan@daocloud.io> Signed-off-by: Nicolò Lucchesi<nicolo.lucchesi@gmail.com> Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com> Signed-off-by: Sage Moore <sage@neuralmagic.com> Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com> Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com> Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com> Signed-off-by: Bill Nell <bnell@redhat.com> Signed-off-by: Shreeasish Kumar <shreeasish@rivosinc.com> Signed-off-by: Weida Hong <wdhongtw@google.com> Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> 
Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com> Signed-off-by: Hashem Hashemi <159079214+amd-hhashemi@users.noreply.github.com> Signed-off-by: Amir Samani <asamani@nvidia.com> Signed-off-by: ElizaWszola <elizaw.9289@gmail.com> Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com> Signed-off-by: ilmarkov <markovilya197@gmail.com> Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com> Signed-off-by: Jialin Ouyang <Jialin.Ouyang@gmail.com> Signed-off-by: rouchenzi <ruochenwen@gmail.com> Signed-off-by: rouchenzi <40842833+rouchenzi@users.noreply.github.com> Signed-off-by: Andrew Xia <axia@meta.com> Signed-off-by: Kourosh Hakhamaneshi <kourosh@anyscale.com> Signed-off-by: Corey Lowman <clowman1993@gmail.com> Signed-off-by: jpvillam <jpvillam@amd.com> Signed-off-by: dougbtv <dosmith@redhat.com> Signed-off-by: Chenxi Yang <cxyang@fb.com> Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Signed-off-by: ahao-anyscale <ahao@anyscale.com> Signed-off-by: Yan Lu <luyan@nvidia.com> Signed-off-by: baxingpiaochong <771405853@qq.com> Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> Signed-off-by: Nikhil Gupta <nikhil.gupta2@arm.com> Signed-off-by: Yong Hoon Shin <yhshin@meta.com> Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai> Signed-off-by: Benjamin Chislett <bchislett@nvidia.com> Signed-off-by: Ben Browning <bbrownin@redhat.com> Signed-off-by: Chengji Yao <chengjiyao@google.com> Signed-off-by: jiang1.li <jiang1.li@intel.com> Signed-off-by: Jackmin801 <ongjackm@gmail.com> Signed-off-by: Jonas M. Kübler <44084297+jmkuebler@users.noreply.github.com> Signed-off-by: taohui <taohui3@gmail.com> Signed-off-by: rongfu.leng <rongfu.leng@daocloud.io> Signed-off-by: Shu Wang <shuw@nvidia.com> Signed-off-by: Shu Wang. 
<shuw@nvidia.com> Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com> Signed-off-by: Duncan Moss <djm.moss@gmail.com> Signed-off-by: Shiyan Deng <dsy842974287@meta.com> Signed-off-by: Wei Wei <wwei6@meta.com> Signed-off-by: Saman Keon <samanamp@outlook.com> Signed-off-by: yangxurui <yangxurui@meituan.com> Signed-off-by: nicole-lihui <nicole.li@daocloud.io> Signed-off-by: courage17340 <courage17340@163.com> Signed-off-by: Jacob Kahn <jacobkahn1@gmail.com> Signed-off-by: Fadi Arafeh <fadi.arafeh@arm.com> Signed-off-by: Agata Dobrzyniewicz <adobrzyniewicz@habana.ai> Signed-off-by: zxw <1020938856@qq.com> Signed-off-by: wang.yuqi <noooop@126.com> Signed-off-by: Cyrus Leung <cyrus.tl.leung@gmail.com> Signed-off-by: chenlang <chen.lang5@zte.com.cn> Signed-off-by: Jonas Kuebler <kuebj@amazon.com> Signed-off-by: AlonKejzman <alonkeizman@gmail.com> Signed-off-by: Tao Hui <taohui3@gmail.com> Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com> Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com> Signed-off-by: Aleksandr Malyshev <maleksan@amd.com> Signed-off-by: Eugene Khvedchenia <ekhvedchenia@nvidia.com> Signed-off-by: Eugene Khvedchenya <ekhvedchenya@gmail.com> Signed-off-by: yiting.jiang <yiting.jiang@daocloud.io> Signed-off-by: xaguilar <Xavier.AguilarFruto@amd.com> Signed-off-by: Iceber Gu <caiwei95@hotmail.com> Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com> Signed-off-by: Icey <1790571317@qq.com> Signed-off-by: 许文卿 <xwq391974@alibaba-inc.com> Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com> Co-authored-by: Nick Hill <nhill@redhat.com> Co-authored-by: Lucas Kabela <lucasakabela@gmail.com> Co-authored-by: Maximilien de Bayser <mbayser@br.ibm.com> Co-authored-by: Andrew Sansom <andrew@protopia.ai> Co-authored-by: Boyuan Feng <boyuan@meta.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> 
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Co-authored-by: JartX <sagformas@epdcenter.es> Co-authored-by: Chendi.Xue <chendi.xue@intel.com> Co-authored-by: Chauncey <chaunceyjiang@gmail.com> Co-authored-by: xin.li <xin.li@daocloud.io> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk> Co-authored-by: Chen Zhang <zhangch99@outlook.com> Co-authored-by: Roger Wang <hey@rogerw.io> Co-authored-by: Michael Goin <mgoin64@gmail.com> Co-authored-by: Wenlong Wang <wangwenlong2755@gmail.com> Co-authored-by: Manoel Marques <manoelmrqs@gmail.com> Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn> Co-authored-by: lirong <56789630+lirong-lirong@users.noreply.github.com> Co-authored-by: Michael Yao <haifeng.yao@daocloud.io> Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu> Co-authored-by: Huamin Li <3ericli@gmail.com> Co-authored-by: Lu Fang <30275821+houseroad@users.noreply.github.com> Co-authored-by: Simon Danielsson <70206058+simondanielsson@users.noreply.github.com> Co-authored-by: Rahul Tuli <rtuli@redhat.com> Co-authored-by: Claude <noreply@anthropic.com> Co-authored-by: Yang Liu <127183760+KKSK-DON@users.noreply.github.com> Co-authored-by: Deboleina <debroy@redhat.com> Co-authored-by: yinz-aizip <yinz@aizip.ai> Co-authored-by: WeiQing Chen <40507679+david6666666@users.noreply.github.com> Co-authored-by: wangzi <3220100013@zju.edu.cn> Co-authored-by: Eldar Kurtić <8884008+eldarkurtic@users.noreply.github.com> Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com> Co-authored-by: Ye (Charlotte) Qi <yeq@meta.com> Co-authored-by: Yizhou <136800916+yiz-liu@users.noreply.github.com> Co-authored-by: Sara-KS <50249410+Sara-KS@users.noreply.github.com> Co-authored-by: Csrayz <jover@cmbchina.com> Co-authored-by: ivyilike <pww123@cmbchina.com> Co-authored-by: Burkhard Ringlein <ngl@zurich.ibm.com> Co-authored-by: Bowen Wang <abmfy@icloud.com> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com> Co-authored-by: Daisy-Ma-coder 
<daisy.ma.0117@gmail.com> Co-authored-by: qqma <qqma@amazon.com> Co-authored-by: ElizaWszola <ewszola@redhat.com> Co-authored-by: Lucia Fang <116399278+luccafong@users.noreply.github.com> Co-authored-by: Zhuohan Li <zhuohan123@gmail.com> Co-authored-by: Simon Mo <simon.mo@hey.com> Co-authored-by: Or Ozeri <oro@il.ibm.com> Co-authored-by: Johnny Yang <24908445+jcyang43@users.noreply.github.com> Co-authored-by: Chengji Yao <chengjiyao@google.com> Co-authored-by: Alec S <10566873+alecsolder@users.noreply.github.com> Co-authored-by: Alec Solder <alecs@fb.com> Co-authored-by: Russell Bryant <rbryant@redhat.com> Co-authored-by: Matthew Bonanni <mbonanni@redhat.com> Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com> Co-authored-by: Chris Bamford <chrisbam4d@gmail.com> Co-authored-by: Alexander Matveev <59768536+alexm-redhat@users.noreply.github.com> Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Co-authored-by: JJJYmmm <92386084+JJJYmmm@users.noreply.github.com> Co-authored-by: liuye.hj <liuye.hj@alibaba-inc.com> Co-authored-by: Kunshang Ji <kunshang.ji@intel.com> Co-authored-by: Lucia (Lu) Fang <fanglu@meta.com> Co-authored-by: Varun Sundar Rabindranath <varunsundar08@gmail.com> Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com> Co-authored-by: Ming Yang <yming@meta.com> Co-authored-by: Zhikaiiii <55917203+Zhikaiiii@users.noreply.github.com> Co-authored-by: Andreas Hartel <andreas@hartel.me> Co-authored-by: Jee Jee Li <pandaleefree@gmail.com> Co-authored-by: vllmellm <vllm.ellm@embeddedllm.com> Co-authored-by: Joel <wuxibin89@163.com> Co-authored-by: youkaichao <youkaichao@gmail.com> Co-authored-by: Mark McLoughlin <markmc@redhat.com> Co-authored-by: Peter Pan <peter.pan@daocloud.io> Co-authored-by: Nicolò Lucchesi <nicolo.lucchesi@gmail.com> Co-authored-by: Fanli Lin <fanli.lin@intel.com> Co-authored-by: Thomas Parnell <tpa@zurich.ibm.com> Co-authored-by: Lucas Wilkinson 
<LucasWilkinson@users.noreply.github.com> Co-authored-by: Sage Moore <sage@neuralmagic.com> Co-authored-by: yewentao256 <zhyanwentao@126.com> Co-authored-by: bnellnm <49004751+bnellnm@users.noreply.github.com> Co-authored-by: rivos-shreeasish <shreeasish@rivosinc.com> Co-authored-by: Chih-Chieh Yang <chih.chieh.yang@ibm.com> Co-authored-by: Weida Hong <wdhongtw@gmail.com> Co-authored-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Co-authored-by: Hashem Hashemi <159079214+amd-hhashemi@users.noreply.github.com> Co-authored-by: Amir Samani <samani@ualberta.ca> Co-authored-by: Luka Govedič <lgovedic@redhat.com> Co-authored-by: jiahanc <173873397+jiahanc@users.noreply.github.com> Co-authored-by: Ilya Markov <markovilya197@gmail.com> Co-authored-by: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com> Co-authored-by: Jialin Ouyang <Jialin.Ouyang@gmail.com> Co-authored-by: rouchenzi <40842833+rouchenzi@users.noreply.github.com> Co-authored-by: Andrew Xia <axia@meta.com> Co-authored-by: kourosh hakhamaneshi <31483498+kouroshHakha@users.noreply.github.com> Co-authored-by: Corey Lowman <clowman1993@gmail.com> Co-authored-by: Juan Villamizar <100237675+jpvillam-amd@users.noreply.github.com> Co-authored-by: jpvillam <jpvillam@amd.com> Co-authored-by: Doug Smith <dosmith@redhat.com> Co-authored-by: Chenxi Yang <cxyang@cs.utexas.edu> Co-authored-by: Chenxi Yang <cxyang@fb.com> Co-authored-by: ahao-anyscale <ahao@anyscale.com> Co-authored-by: 0xNullPath <luyanfcp@foxmail.com> Co-authored-by: baxingpiaochong <771405853@qq.com> Co-authored-by: Benjamin Chislett <bchislett@nvidia.com> Co-authored-by: Kyle Sayers <kylesayrs@gmail.com> Co-authored-by: Nikhil Gupta <nikhil.gupta2@arm.com> Co-authored-by: Yong Hoon Shin <48474650+sarckk@users.noreply.github.com> Co-authored-by: lhsjohn <huashuoli@tencent.com> Co-authored-by: Ben Browning <bbrownin@redhat.com> Co-authored-by: Li, Jiang <jiang1.li@intel.com> Co-authored-by: Jackmin801 
<56836461+Jackmin801@users.noreply.github.com> Co-authored-by: Jonas M. Kübler <44084297+jmkuebler@users.noreply.github.com> Co-authored-by: Tao Hui <taohui3@gmail.com> Co-authored-by: rongfu.leng <rongfu.leng@daocloud.io> Co-authored-by: Shu Wang <shuw@nvidia.com> Co-authored-by: Tyler Michael Smith <tlrmchlsmth@gmail.com> Co-authored-by: Duncan Moss <djm.moss@gmail.com> Co-authored-by: Shiyan Deng <dsy842974287@meta.com> Co-authored-by: Wei Wei <wwei6@meta.com> Co-authored-by: Saman A. Pour <samanamp@outlook.com> Co-authored-by: XuruiYang <530534756@qq.com> Co-authored-by: yangxurui <yangxurui@meituan.com> Co-authored-by: Nicole LiHui 🥜 <nicolelihui@outlook.com> Co-authored-by: courage17340 <courage17340@users.noreply.github.com> Co-authored-by: Jacob Kahn <jacobkahn1@gmail.com> Co-authored-by: Nicole LiHui 🥜 <nicole.li@daocloud.io> Co-authored-by: Fadi Arafeh <115173828+fadara01@users.noreply.github.com> Co-authored-by: Agata Dobrzyniewicz <160237065+adobrzyn@users.noreply.github.com> Co-authored-by: yyzxw <34639446+yyzxw@users.noreply.github.com> Co-authored-by: wang.yuqi <noooop@126.com> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com> Co-authored-by: chenlang <chen.lang5@zte.com.cn> Co-authored-by: chenlang <10346245@zte.com.cn> Co-authored-by: AlonKejzman <alonkeizman@gmail.com> Co-authored-by: tomeras91 <57313761+tomeras91@users.noreply.github.com> Co-authored-by: Aleksandr Malyshev <164964928+maleksan85@users.noreply.github.com> Co-authored-by: Aleksandr Malyshev <maleksan@amd.com> Co-authored-by: Doug Lehr <douglehr@amd.com> Co-authored-by: Eugene Khvedchenya <ekhvedchenya@gmail.com> Co-authored-by: yitingdc <59356937+yitingdc@users.noreply.github.com> Co-authored-by: xaguilar-amd <xavier.aguilarfruto@amd.com> Co-authored-by: Iceber Gu <caiwei95@hotmail.com> Co-authored-by: Tao He <linzhu.ht@alibaba-inc.com> Co-authored-by: Icey <1790571317@qq.com> Co-authored-by: Xu Wenqing <121550081+Xu-Wenqing@users.noreply.github.com> Co-authored-by: Chih-Chieh 
Yang <7364402+cyang49@users.noreply.github.com> Co-authored-by: RishiAstra <40644327+RishiAstra@users.noreply.github.com>
369 lines
15 KiB
Plaintext
369 lines
15 KiB
Plaintext
/*
|
|
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
#include <torch/all.h>
|
|
|
|
#include <cuda_runtime_api.h>
|
|
#include <cuda_runtime.h>
|
|
|
|
#include <ATen/cuda/CUDAContext.h>
|
|
#include <c10/cuda/CUDAGuard.h>
|
|
|
|
#include <cuda_fp8.h>
|
|
#include "dispatch_utils.h"
|
|
|
|
#include "nvfp4_utils.cuh"
|
|
#include "launch_bounds_utils.h"
|
|
|
|
namespace vllm {
|
|
|
|
// Use UE4M3 by default.
|
|
template <class Type, bool UE8M0_SF = false, bool SMALL_NUM_EXPERTS = false>
|
|
__global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
|
|
cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in,
|
|
float const* SFScale, uint32_t* out, uint32_t* SFout,
|
|
uint32_t* input_offset_by_experts,
|
|
uint32_t* output_scale_offset_by_experts, int n_experts,
|
|
bool low_latency) {
|
|
using PackedVec = PackedVec<Type>;
|
|
static constexpr int CVT_FP4_NUM_THREADS_PER_SF =
|
|
(CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
|
|
static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
|
|
"Vec size is not matched.");
|
|
|
|
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
|
int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD;
|
|
|
|
// Each global thread processes one element
|
|
for (int globalIdx = tid; globalIdx < numRows * colsPerRow;
|
|
globalIdx += gridDim.x * blockDim.x) {
|
|
// Calculate which row and column this global thread should process
|
|
int rowIdx = globalIdx / colsPerRow;
|
|
int colIdx = globalIdx % colsPerRow;
|
|
|
|
int64_t inOffset = rowIdx * colsPerRow + colIdx;
|
|
PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
|
|
// Get the output tensor offset.
|
|
// Same as inOffset because 8 elements are packed into one uint32_t.
|
|
int64_t outOffset = inOffset;
|
|
auto& out_pos = out[outOffset];
|
|
|
|
// Find index within the experts using different strategies based on expert
|
|
// count
|
|
int rowIdx_in_expert = 0;
|
|
int expert_idx = 0;
|
|
|
|
if constexpr (SMALL_NUM_EXPERTS) {
|
|
for (int i = 0; i < n_experts; i++) {
|
|
uint32_t current_offset = __ldca(&input_offset_by_experts[i]);
|
|
uint32_t next_offset = __ldca(&input_offset_by_experts[i + 1]);
|
|
if (rowIdx >= current_offset && rowIdx < next_offset) {
|
|
rowIdx_in_expert = rowIdx - current_offset;
|
|
expert_idx = i;
|
|
break;
|
|
}
|
|
}
|
|
} else {
|
|
// Load input offsets into registers first, then do the computation.
|
|
// Local array size set to 17 because of register limit.
|
|
uint32_t local_offsets[17];
|
|
for (int chunk_start = 0; chunk_start < n_experts; chunk_start += 16) {
|
|
*reinterpret_cast<int4*>(local_offsets) =
|
|
__ldca(reinterpret_cast<const int4*>(
|
|
&input_offset_by_experts[chunk_start]));
|
|
*reinterpret_cast<int4*>(local_offsets + 4) =
|
|
__ldca(reinterpret_cast<const int4*>(
|
|
&input_offset_by_experts[chunk_start + 4]));
|
|
*reinterpret_cast<int4*>(local_offsets + 8) =
|
|
__ldca(reinterpret_cast<const int4*>(
|
|
&input_offset_by_experts[chunk_start + 8]));
|
|
*reinterpret_cast<int4*>(local_offsets + 12) =
|
|
__ldca(reinterpret_cast<const int4*>(
|
|
&input_offset_by_experts[chunk_start + 12]));
|
|
local_offsets[16] = __ldca(&input_offset_by_experts[chunk_start + 16]);
|
|
|
|
// Check against the 16 loaded offsets
|
|
#pragma unroll
|
|
for (int i = 0; i < 16; i++) {
|
|
if (rowIdx >= local_offsets[i] && rowIdx < local_offsets[i + 1]) {
|
|
rowIdx_in_expert = rowIdx - local_offsets[i];
|
|
expert_idx = chunk_start + i;
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Get the global scaling factor, which will be applied to the SF.
|
|
// Note SFScale is the same as next GEMM's alpha, which is
|
|
// (448.f / (Alpha_A / 6.f)).
|
|
float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx];
|
|
|
|
int factor = CVT_FP4_SF_VEC_SIZE * 4;
|
|
// The actual output_scales dim is computed from the padded numCols.
|
|
int32_t numCols_padded = (numCols + factor - 1) / factor * factor;
|
|
int numCols_SFout = numCols_padded / CVT_FP4_SF_VEC_SIZE / 4;
|
|
uint32_t* SFout_in_expert =
|
|
SFout + output_scale_offset_by_experts[expert_idx] * numCols_SFout;
|
|
|
|
auto sf_out =
|
|
cvt_quant_to_fp4_get_sf_out_offset<uint32_t,
|
|
CVT_FP4_NUM_THREADS_PER_SF>(
|
|
rowIdx_in_expert, colIdx, numCols, SFout_in_expert);
|
|
|
|
out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
|
|
}
|
|
}
|
|
|
|
// Kernel for LARGE_M_TOPK = true (large m_topk optimized version)
|
|
template <class Type, bool UE8M0_SF = false, bool SMALL_NUM_EXPERTS = false>
|
|
__global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
|
|
cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in,
|
|
float const* SFScale, uint32_t* out, uint32_t* SFout,
|
|
uint32_t* input_offset_by_experts,
|
|
uint32_t* output_scale_offset_by_experts, int n_experts) {
|
|
using PackedVec = PackedVec<Type>;
|
|
static constexpr int CVT_FP4_NUM_THREADS_PER_SF =
|
|
(CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
|
|
static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
|
|
"Vec size is not matched.");
|
|
extern __shared__ uint32_t shared_input_offsets[];
|
|
|
|
// Load input offsets into shared memory.
|
|
// If n_experts is larger than 4, use vectorized int4 to save instructions.
|
|
// If n_experts is smaller than 4, read directly.
|
|
if constexpr (SMALL_NUM_EXPERTS) {
|
|
for (int i = threadIdx.x; i < n_experts + 1; i += blockDim.x) {
|
|
shared_input_offsets[i] = input_offset_by_experts[i];
|
|
}
|
|
} else {
|
|
for (int i = threadIdx.x * 4; i < n_experts; i += blockDim.x * 4) {
|
|
*reinterpret_cast<int4*>(&shared_input_offsets[i]) =
|
|
*reinterpret_cast<const int4*>(&input_offset_by_experts[i]);
|
|
}
|
|
if (threadIdx.x == 0) {
|
|
shared_input_offsets[n_experts] = input_offset_by_experts[n_experts];
|
|
}
|
|
}
|
|
|
|
__syncthreads();
|
|
|
|
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
|
int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD;
|
|
|
|
// Each global thread processes one element
|
|
for (int globalIdx = tid; globalIdx < numRows * colsPerRow;
|
|
globalIdx += gridDim.x * blockDim.x) {
|
|
// Calculate which row and column this global thread should process
|
|
int rowIdx = globalIdx / colsPerRow;
|
|
int colIdx = globalIdx % colsPerRow;
|
|
|
|
int64_t inOffset = rowIdx * colsPerRow + colIdx;
|
|
PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
|
|
int64_t outOffset = inOffset;
|
|
auto& out_pos = out[outOffset];
|
|
|
|
// Find expert using binary search for better performance with large m_topk
|
|
int rowIdx_in_expert = 0;
|
|
int expert_idx = 0;
|
|
|
|
// Binary search through experts using shared memory
|
|
int left = 0, right = n_experts - 1;
|
|
while (left <= right) {
|
|
int mid = (left + right) / 2;
|
|
// Get offsets: shared_input_offsets[i] corresponds to
|
|
// input_offset_by_experts[i]
|
|
uint32_t mid_offset = shared_input_offsets[mid];
|
|
uint32_t next_offset = shared_input_offsets[mid + 1];
|
|
|
|
if (rowIdx >= mid_offset && rowIdx < next_offset) {
|
|
rowIdx_in_expert = rowIdx - mid_offset;
|
|
expert_idx = mid;
|
|
break;
|
|
} else if (rowIdx < mid_offset) {
|
|
right = mid - 1;
|
|
} else {
|
|
left = mid + 1;
|
|
}
|
|
}
|
|
|
|
float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx];
|
|
|
|
int factor = CVT_FP4_SF_VEC_SIZE * 4;
|
|
int32_t numCols_padded = (numCols + factor - 1) / factor * factor;
|
|
int numCols_SFout = numCols_padded / CVT_FP4_SF_VEC_SIZE / 4;
|
|
uint32_t* SFout_in_expert =
|
|
SFout + output_scale_offset_by_experts[expert_idx] * numCols_SFout;
|
|
|
|
auto sf_out =
|
|
cvt_quant_to_fp4_get_sf_out_offset<uint32_t,
|
|
CVT_FP4_NUM_THREADS_PER_SF>(
|
|
rowIdx_in_expert, colIdx, numCols, SFout_in_expert);
|
|
|
|
out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
|
|
}
|
|
}
|
|
|
|
template <typename T>
|
|
void quant_impl(void* output, void* output_scale, void* input,
|
|
void* input_global_scale, void* input_offset_by_experts,
|
|
void* output_scale_offset_by_experts, int m_topk, int k,
|
|
int n_experts, cudaStream_t stream) {
|
|
// TODO: this multiProcessorCount should be cached.
|
|
int device;
|
|
cudaGetDevice(&device);
|
|
int multiProcessorCount;
|
|
cudaDeviceGetAttribute(&multiProcessorCount, cudaDevAttrMultiProcessorCount,
|
|
device);
|
|
|
|
// Grid, Block size.
|
|
// Each thread converts 8 values.
|
|
int const workSizePerRow = k / ELTS_PER_THREAD;
|
|
int const totalWorkSize = m_topk * workSizePerRow;
|
|
dim3 block(std::min(workSizePerRow, 512));
|
|
// Get number of blocks per SM
|
|
int const numBlocksPerSM =
|
|
vllm_runtime_blocks_per_sm(static_cast<int>(block.x));
|
|
dim3 grid(std::min(static_cast<int>((totalWorkSize + block.x - 1) / block.x),
|
|
multiProcessorCount * numBlocksPerSM));
|
|
while (grid.x <= multiProcessorCount && block.x > 64) {
|
|
grid.x *= 2;
|
|
block.x = (block.x + 1) / 2;
|
|
}
|
|
|
|
int const blockRepeat =
|
|
(totalWorkSize + block.x * grid.x - 1) / (block.x * grid.x);
|
|
if (blockRepeat > 1) {
|
|
size_t shared_mem_size = (n_experts + 1) * sizeof(uint32_t);
|
|
if (n_experts >= 4) {
|
|
cvt_fp16_to_fp4<T, false, false>
|
|
<<<grid, block, shared_mem_size, stream>>>(
|
|
m_topk, k, reinterpret_cast<T*>(input),
|
|
reinterpret_cast<float*>(input_global_scale),
|
|
reinterpret_cast<uint32_t*>(output),
|
|
reinterpret_cast<uint32_t*>(output_scale),
|
|
reinterpret_cast<uint32_t*>(input_offset_by_experts),
|
|
reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
|
|
n_experts);
|
|
} else {
|
|
cvt_fp16_to_fp4<T, false, true><<<grid, block, shared_mem_size, stream>>>(
|
|
m_topk, k, reinterpret_cast<T*>(input),
|
|
reinterpret_cast<float*>(input_global_scale),
|
|
reinterpret_cast<uint32_t*>(output),
|
|
reinterpret_cast<uint32_t*>(output_scale),
|
|
reinterpret_cast<uint32_t*>(input_offset_by_experts),
|
|
reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
|
|
n_experts);
|
|
}
|
|
} else {
|
|
if (n_experts >= 16) {
|
|
cvt_fp16_to_fp4<T, false, false><<<grid, block, 0, stream>>>(
|
|
m_topk, k, reinterpret_cast<T*>(input),
|
|
reinterpret_cast<float*>(input_global_scale),
|
|
reinterpret_cast<uint32_t*>(output),
|
|
reinterpret_cast<uint32_t*>(output_scale),
|
|
reinterpret_cast<uint32_t*>(input_offset_by_experts),
|
|
reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
|
|
n_experts, /* bool low_latency */ true);
|
|
} else {
|
|
cvt_fp16_to_fp4<T, false, true><<<grid, block, 0, stream>>>(
|
|
m_topk, k, reinterpret_cast<T*>(input),
|
|
reinterpret_cast<float*>(input_global_scale),
|
|
reinterpret_cast<uint32_t*>(output),
|
|
reinterpret_cast<uint32_t*>(output_scale),
|
|
reinterpret_cast<uint32_t*>(input_offset_by_experts),
|
|
reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
|
|
n_experts, /* bool low_latency */ true);
|
|
}
|
|
}
|
|
}
|
|
|
|
} // namespace vllm
|
|
|
|
/*Quantization entry for fp4 experts quantization*/
|
|
// Argument-validation helpers. `m` is passed to TORCH_CHECK together with a
// fixed suffix; TORCH_CHECK concatenates them without a separator.
#define CHECK_TH_CUDA(x, m) \
  TORCH_CHECK((x).is_cuda(), m, "must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x, m) \
  TORCH_CHECK((x).is_contiguous(), m, "must be contiguous")
// do { } while (0) makes the expansion a single statement. With a bare
// two-statement expansion, `if (c) CHECK_INPUT(t, m);` would guard only the
// CUDA check and always run the contiguity check.
#define CHECK_INPUT(x, m)     \
  do {                        \
    CHECK_TH_CUDA(x, m);      \
    CHECK_CONTIGUOUS(x, m);   \
  } while (0)
|
|
|
|
constexpr auto HALF = at::ScalarType::Half;
|
|
constexpr auto BF16 = at::ScalarType::BFloat16;
|
|
constexpr auto FLOAT = at::ScalarType::Float;
|
|
constexpr auto INT = at::ScalarType::Int;
|
|
constexpr auto UINT8 = at::ScalarType::Byte;
|
|
|
|
// Entry point: quantize expert-grouped half/bf16 activations to NVFP4
// (sm100a build).
//
// Arguments (all must be contiguous CUDA tensors):
//   output                         - uint8, [m_topk, k / 2]; two FP4 values
//                                    per byte.
//   output_scale                   - int32, 2-D with size(1) * 4 == k / 16
//                                    rounded up to a multiple of 4; four FP8
//                                    scale factors per int32.
//   input                          - float16 or bfloat16, [m_topk, k], with k
//                                    a multiple of 16.
//   input_global_scale             - float32, [n_experts].
//   input_offset_by_experts        - int32, [n_experts + 1]; per-expert
//                                    starting row offsets into `input`.
//   output_scale_offset_by_experts - int32, [n_experts + 1]; per-expert
//                                    starting offsets into `output_scale`.
// Throws c10::Error (via TORCH_CHECK) on invalid arguments. An empty input is
// a no-op.
void scaled_fp4_experts_quant_sm100a(
    torch::Tensor& output, torch::Tensor& output_scale,
    torch::Tensor const& input, torch::Tensor const& input_global_scale,
    torch::Tensor const& input_offset_by_experts,
    torch::Tensor const& output_scale_offset_by_experts) {
  // TORCH_CHECK concatenates its message arguments, so the tensor name and the
  // reason are passed separately to read "<name> must be ...".
  auto const check_input = [](torch::Tensor const& t, char const* name) {
    TORCH_CHECK(t.is_cuda(), name, " must be a CUDA tensor");
    TORCH_CHECK(t.is_contiguous(), name, " must be contiguous");
  };
  check_input(output, "output");
  check_input(output_scale, "output_scale");
  check_input(input, "input");
  check_input(input_global_scale, "input_global_scale");
  check_input(input_offset_by_experts, "input_offset_by_experts");
  check_input(output_scale_offset_by_experts,
              "output_scale_offset_by_experts");

  TORCH_CHECK(output.dim() == 2, "output must be 2-D");
  TORCH_CHECK(output_scale.dim() == 2, "output_scale must be 2-D");
  TORCH_CHECK(input.dim() == 2, "input must be 2-D");
  TORCH_CHECK(input_global_scale.dim() == 1,
              "input_global_scale must be 1-D");
  TORCH_CHECK(input_offset_by_experts.dim() == 1,
              "input_offset_by_experts must be 1-D");
  TORCH_CHECK(output_scale_offset_by_experts.dim() == 1,
              "output_scale_offset_by_experts must be 1-D");

  TORCH_CHECK(input.scalar_type() == HALF || input.scalar_type() == BF16,
              "input must be float16 or bfloat16");
  TORCH_CHECK(input_global_scale.scalar_type() == FLOAT,
              "input_global_scale must be float32");
  TORCH_CHECK(input_offset_by_experts.scalar_type() == INT,
              "input_offset_by_experts must be int32");
  TORCH_CHECK(output_scale_offset_by_experts.scalar_type() == INT,
              "output_scale_offset_by_experts must be int32");
  // output is uint8 (two nvfp4 values are packed into one uint8)
  // output_scale is int32 (four fp8 values are packed into one int32)
  TORCH_CHECK(output.scalar_type() == UINT8, "output must be uint8");
  TORCH_CHECK(output_scale.scalar_type() == INT, "output_scale must be int32");

  const int BLOCK_SIZE = 16;
  auto m_topk = input.size(0);
  auto k = input.size(1);
  TORCH_CHECK(k % BLOCK_SIZE == 0, "k must be a multiple of 16");
  auto n_experts = input_global_scale.size(0);
  TORCH_CHECK(input_offset_by_experts.size(0) == n_experts + 1,
              "input_offset_by_experts must have n_experts + 1 entries");
  TORCH_CHECK(output_scale_offset_by_experts.size(0) == n_experts + 1,
              "output_scale_offset_by_experts must have n_experts + 1 entries");
  TORCH_CHECK(output.size(0) == m_topk,
              "output must have the same number of rows as input");
  TORCH_CHECK(output.size(1) == k / 2, "output must have k / 2 columns");
  int scales_k = k / BLOCK_SIZE;
  // 4 means the swizzle requirement by nvidia nvfp4.
  int padded_k = (scales_k + (4 - 1)) / 4 * 4;
  // 4 means 4 fp8 values are packed into one int32
  TORCH_CHECK(output_scale.size(1) * 4 == padded_k,
              "output_scale has the wrong number of columns");

  // Nothing to quantize. Launching anyway would use a zero-sized grid or
  // block, which CUDA rejects as an invalid configuration.
  if (m_topk == 0 || k == 0) {
    return;
  }
  // The kernels read input_global_scale[expert_idx] for every row (expert 0
  // by default), which is out of bounds when there are no experts.
  TORCH_CHECK(n_experts > 0,
              "input_global_scale must have at least one expert when input "
              "is non-empty");

  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream =
      at::cuda::getCurrentCUDAStream(input.get_device());

  VLLM_DISPATCH_HALF_TYPES(
      input.scalar_type(), "nvfp4_experts_quant_kernel", [&] {
        using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
        vllm::quant_impl<cuda_type>(
            output.data_ptr(), output_scale.data_ptr(), input.data_ptr(),
            input_global_scale.data_ptr(), input_offset_by_experts.data_ptr(),
            output_scale_offset_by_experts.data_ptr(), m_topk, k, n_experts,
            stream);
      });
}
|