mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-05 13:42:18 +08:00
[Build] Allow shipping PTX on a per-file basis (#18155)
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
This commit is contained in:
parent
8795eb9975
commit
c7852a6d9b
@ -301,7 +301,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
|||||||
# Only build Marlin kernels if we are building for at least some compatible archs.
|
# Only build Marlin kernels if we are building for at least some compatible archs.
|
||||||
# Keep building Marlin for 9.0 as there are some group sizes and shapes that
|
# Keep building Marlin for 9.0 as there are some group sizes and shapes that
|
||||||
# are not supported by Machete yet.
|
# are not supported by Machete yet.
|
||||||
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}")
|
# 9.0 for latest bf16 atomicAdd PTX
|
||||||
|
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;9.0+PTX" "${CUDA_ARCHS}")
|
||||||
if (MARLIN_ARCHS)
|
if (MARLIN_ARCHS)
|
||||||
|
|
||||||
#
|
#
|
||||||
@ -445,8 +446,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
|||||||
#
|
#
|
||||||
# For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x)
|
# For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x)
|
||||||
# kernels for the remaining archs that are not already built for 3x.
|
# kernels for the remaining archs that are not already built for 3x.
|
||||||
|
# (Build 8.9 for FP8)
|
||||||
cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS
|
cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS
|
||||||
"7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}")
|
"7.5;8.0;8.9+PTX" "${CUDA_ARCHS}")
|
||||||
# subtract out the archs that are already built for 3x
|
# subtract out the archs that are already built for 3x
|
||||||
list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS})
|
list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS})
|
||||||
if (SCALED_MM_2X_ARCHS)
|
if (SCALED_MM_2X_ARCHS)
|
||||||
@ -675,7 +677,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
|||||||
CUDA_ARCHS "${CUDA_ARCHS}")
|
CUDA_ARCHS "${CUDA_ARCHS}")
|
||||||
|
|
||||||
list(APPEND VLLM_MOE_EXT_SRC "${VLLM_MOE_WNA16_SRC}")
|
list(APPEND VLLM_MOE_EXT_SRC "${VLLM_MOE_WNA16_SRC}")
|
||||||
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}")
|
# 9.0 for latest bf16 atomicAdd PTX
|
||||||
|
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;9.0+PTX" "${CUDA_ARCHS}")
|
||||||
if (MARLIN_MOE_ARCHS)
|
if (MARLIN_MOE_ARCHS)
|
||||||
|
|
||||||
#
|
#
|
||||||
|
|||||||
@ -228,11 +228,26 @@ macro(set_gencode_flags_for_srcs)
|
|||||||
"${multiValueArgs}" ${ARGN} )
|
"${multiValueArgs}" ${ARGN} )
|
||||||
|
|
||||||
foreach(_ARCH ${arg_CUDA_ARCHS})
|
foreach(_ARCH ${arg_CUDA_ARCHS})
|
||||||
string(REPLACE "." "" _ARCH "${_ARCH}")
|
# handle +PTX suffix: generate both sm and ptx codes if requested
|
||||||
set_gencode_flag_for_srcs(
|
string(FIND "${_ARCH}" "+PTX" _HAS_PTX)
|
||||||
SRCS ${arg_SRCS}
|
if(NOT _HAS_PTX EQUAL -1)
|
||||||
ARCH "compute_${_ARCH}"
|
string(REPLACE "+PTX" "" _BASE_ARCH "${_ARCH}")
|
||||||
CODE "sm_${_ARCH}")
|
string(REPLACE "." "" _STRIPPED_ARCH "${_BASE_ARCH}")
|
||||||
|
set_gencode_flag_for_srcs(
|
||||||
|
SRCS ${arg_SRCS}
|
||||||
|
ARCH "compute_${_STRIPPED_ARCH}"
|
||||||
|
CODE "sm_${_STRIPPED_ARCH}")
|
||||||
|
set_gencode_flag_for_srcs(
|
||||||
|
SRCS ${arg_SRCS}
|
||||||
|
ARCH "compute_${_STRIPPED_ARCH}"
|
||||||
|
CODE "compute_${_STRIPPED_ARCH}")
|
||||||
|
else()
|
||||||
|
string(REPLACE "." "" _STRIPPED_ARCH "${_ARCH}")
|
||||||
|
set_gencode_flag_for_srcs(
|
||||||
|
SRCS ${arg_SRCS}
|
||||||
|
ARCH "compute_${_STRIPPED_ARCH}"
|
||||||
|
CODE "sm_${_STRIPPED_ARCH}")
|
||||||
|
endif()
|
||||||
endforeach()
|
endforeach()
|
||||||
|
|
||||||
if (${arg_BUILD_PTX_FOR_ARCH})
|
if (${arg_BUILD_PTX_FOR_ARCH})
|
||||||
@ -251,7 +266,10 @@ endmacro()
|
|||||||
#
|
#
|
||||||
# For the given `SRC_CUDA_ARCHS` list of gencode versions in the form
|
# For the given `SRC_CUDA_ARCHS` list of gencode versions in the form
|
||||||
# `<major>.<minor>[letter]` compute the "loose intersection" with the
|
# `<major>.<minor>[letter]` compute the "loose intersection" with the
|
||||||
# `TGT_CUDA_ARCHS` list of gencodes.
|
# `TGT_CUDA_ARCHS` list of gencodes. We also support the `+PTX` suffix in
|
||||||
|
# `SRC_CUDA_ARCHS` which indicates that the PTX code should be built when there
|
||||||
|
# is a CUDA_ARCH in `TGT_CUDA_ARCHS` that is equal to or larger than the
|
||||||
|
# architecture in `SRC_CUDA_ARCHS`.
|
||||||
# The loose intersection is defined as:
|
# The loose intersection is defined as:
|
||||||
# { max{ x \in tgt | x <= y } | y \in src, { x \in tgt | x <= y } != {} }
|
# { max{ x \in tgt | x <= y } | y \in src, { x \in tgt | x <= y } != {} }
|
||||||
# where `<=` is the version comparison operator.
|
# where `<=` is the version comparison operator.
|
||||||
@ -268,44 +286,63 @@ endmacro()
|
|||||||
# cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
|
# cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
|
||||||
# OUT_CUDA_ARCHS="8.0;8.6;9.0;9.0a"
|
# OUT_CUDA_ARCHS="8.0;8.6;9.0;9.0a"
|
||||||
#
|
#
|
||||||
|
# Example With PTX:
|
||||||
|
# SRC_CUDA_ARCHS="8.0+PTX"
|
||||||
|
# TGT_CUDA_ARCHS="9.0"
|
||||||
|
# cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
|
||||||
|
# OUT_CUDA_ARCHS="8.0+PTX"
|
||||||
|
#
|
||||||
function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
|
function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
|
||||||
list(REMOVE_DUPLICATES SRC_CUDA_ARCHS)
|
set(_SRC_CUDA_ARCHS "${SRC_CUDA_ARCHS}")
|
||||||
set(TGT_CUDA_ARCHS_ ${TGT_CUDA_ARCHS})
|
set(_TGT_CUDA_ARCHS ${TGT_CUDA_ARCHS})
|
||||||
|
|
||||||
|
# handle +PTX suffix: separate base arch for matching, record PTX requests
|
||||||
|
set(_PTX_ARCHS)
|
||||||
|
foreach(_arch ${_SRC_CUDA_ARCHS})
|
||||||
|
if(_arch MATCHES "\\+PTX$")
|
||||||
|
string(REPLACE "+PTX" "" _base "${_arch}")
|
||||||
|
list(APPEND _PTX_ARCHS "${_base}")
|
||||||
|
list(REMOVE_ITEM _SRC_CUDA_ARCHS "${_arch}")
|
||||||
|
list(APPEND _SRC_CUDA_ARCHS "${_base}")
|
||||||
|
endif()
|
||||||
|
endforeach()
|
||||||
|
list(REMOVE_DUPLICATES _PTX_ARCHS)
|
||||||
|
list(REMOVE_DUPLICATES _SRC_CUDA_ARCHS)
|
||||||
|
|
||||||
# if x.0a is in SRC_CUDA_ARCHS and x.0 is in CUDA_ARCHS then we should
|
# if x.0a is in SRC_CUDA_ARCHS and x.0 is in CUDA_ARCHS then we should
|
||||||
# remove x.0a from SRC_CUDA_ARCHS and add x.0a to _CUDA_ARCHS
|
# remove x.0a from SRC_CUDA_ARCHS and add x.0a to _CUDA_ARCHS
|
||||||
set(_CUDA_ARCHS)
|
set(_CUDA_ARCHS)
|
||||||
if ("9.0a" IN_LIST SRC_CUDA_ARCHS)
|
if ("9.0a" IN_LIST _SRC_CUDA_ARCHS)
|
||||||
list(REMOVE_ITEM SRC_CUDA_ARCHS "9.0a")
|
list(REMOVE_ITEM _SRC_CUDA_ARCHS "9.0a")
|
||||||
if ("9.0" IN_LIST TGT_CUDA_ARCHS_)
|
if ("9.0" IN_LIST TGT_CUDA_ARCHS)
|
||||||
list(REMOVE_ITEM TGT_CUDA_ARCHS_ "9.0")
|
list(REMOVE_ITEM _TGT_CUDA_ARCHS "9.0")
|
||||||
set(_CUDA_ARCHS "9.0a")
|
set(_CUDA_ARCHS "9.0a")
|
||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if ("10.0a" IN_LIST SRC_CUDA_ARCHS)
|
if ("10.0a" IN_LIST _SRC_CUDA_ARCHS)
|
||||||
list(REMOVE_ITEM SRC_CUDA_ARCHS "10.0a")
|
list(REMOVE_ITEM _SRC_CUDA_ARCHS "10.0a")
|
||||||
if ("10.0" IN_LIST TGT_CUDA_ARCHS)
|
if ("10.0" IN_LIST TGT_CUDA_ARCHS)
|
||||||
list(REMOVE_ITEM TGT_CUDA_ARCHS_ "10.0")
|
list(REMOVE_ITEM _TGT_CUDA_ARCHS "10.0")
|
||||||
set(_CUDA_ARCHS "10.0a")
|
set(_CUDA_ARCHS "10.0a")
|
||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
list(SORT SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
|
list(SORT _SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
|
||||||
|
|
||||||
# for each ARCH in TGT_CUDA_ARCHS find the highest arch in SRC_CUDA_ARCHS that
|
# for each ARCH in TGT_CUDA_ARCHS find the highest arch in SRC_CUDA_ARCHS that
|
||||||
# is less or equal to ARCH (but has the same major version since SASS binary
|
# is less or equal to ARCH (but has the same major version since SASS binary
|
||||||
# compatibility is only forward compatible within the same major version).
|
# compatibility is only forward compatible within the same major version).
|
||||||
foreach(_ARCH ${TGT_CUDA_ARCHS_})
|
foreach(_ARCH ${_TGT_CUDA_ARCHS})
|
||||||
set(_TMP_ARCH)
|
set(_TMP_ARCH)
|
||||||
# Extract the major version of the target arch
|
# Extract the major version of the target arch
|
||||||
string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" TGT_ARCH_MAJOR "${_ARCH}")
|
string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" TGT_ARCH_MAJOR "${_ARCH}")
|
||||||
foreach(_SRC_ARCH ${SRC_CUDA_ARCHS})
|
foreach(_SRC_ARCH ${_SRC_CUDA_ARCHS})
|
||||||
# Extract the major version of the source arch
|
# Extract the major version of the source arch
|
||||||
string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" SRC_ARCH_MAJOR "${_SRC_ARCH}")
|
string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" SRC_ARCH_MAJOR "${_SRC_ARCH}")
|
||||||
# Check major-version match AND version-less-or-equal
|
# Check version-less-or-equal, and allow PTX arches to match across majors
|
||||||
if (_SRC_ARCH VERSION_LESS_EQUAL _ARCH)
|
if (_SRC_ARCH VERSION_LESS_EQUAL _ARCH)
|
||||||
if (SRC_ARCH_MAJOR STREQUAL TGT_ARCH_MAJOR)
|
if (_SRC_ARCH IN_LIST _PTX_ARCHS OR SRC_ARCH_MAJOR STREQUAL TGT_ARCH_MAJOR)
|
||||||
set(_TMP_ARCH "${_SRC_ARCH}")
|
set(_TMP_ARCH "${_SRC_ARCH}")
|
||||||
endif()
|
endif()
|
||||||
else()
|
else()
|
||||||
@ -321,6 +358,18 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR
|
|||||||
endforeach()
|
endforeach()
|
||||||
|
|
||||||
list(REMOVE_DUPLICATES _CUDA_ARCHS)
|
list(REMOVE_DUPLICATES _CUDA_ARCHS)
|
||||||
|
|
||||||
|
# reapply +PTX suffix to architectures that requested PTX
|
||||||
|
set(_FINAL_ARCHS)
|
||||||
|
foreach(_arch ${_CUDA_ARCHS})
|
||||||
|
if(_arch IN_LIST _PTX_ARCHS)
|
||||||
|
list(APPEND _FINAL_ARCHS "${_arch}+PTX")
|
||||||
|
else()
|
||||||
|
list(APPEND _FINAL_ARCHS "${_arch}")
|
||||||
|
endif()
|
||||||
|
endforeach()
|
||||||
|
set(_CUDA_ARCHS ${_FINAL_ARCHS})
|
||||||
|
|
||||||
set(${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE)
|
set(${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE)
|
||||||
endfunction()
|
endfunction()
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user