Fix cuda_archs_loose_intersection when handling sm_*a (#20207)

Signed-off-by: Huy Do <huydhn@gmail.com>
This commit is contained in:
Huy Do 2025-06-29 16:52:34 -07:00 committed by GitHub
parent 6f2f53a82d
commit 6c9837a761
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 26 additions and 21 deletions

View File

@ -562,7 +562,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"if you intend on running FP8 quantized MoE models on Hopper.") "if you intend on running FP8 quantized MoE models on Hopper.")
else() else()
message(STATUS "Not building grouped_mm_c3x as no compatible archs found " message(STATUS "Not building grouped_mm_c3x as no compatible archs found "
"in CUDA target architectures") "in CUDA target architectures.")
endif() endif()
endif() endif()
@ -574,7 +574,17 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
SRCS "${SRCS}" SRCS "${SRCS}"
CUDA_ARCHS "${CUTLASS_MOE_DATA_ARCHS}") CUDA_ARCHS "${CUTLASS_MOE_DATA_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}") list(APPEND VLLM_EXT_SRC "${SRCS}")
endif() message(STATUS "Building moe_data for archs: ${CUTLASS_MOE_DATA_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS)
message(STATUS "Not building moe_data as CUDA Compiler version is "
"not >= 12.3, we recommend upgrading to CUDA 12.3 or later "
"if you intend on running FP8 quantized MoE models on Hopper or Blackwell.")
else()
message(STATUS "Not building moe_data as no compatible archs found "
"in CUDA target architectures.")
endif()
endif()
# #
# Machete kernels # Machete kernels

View File

@ -265,8 +265,8 @@ macro(set_gencode_flags_for_srcs)
endmacro() endmacro()
# #
# For the given `SRC_CUDA_ARCHS` list of gencode versions in the form # For the given `SRC_CUDA_ARCHS` list of gencode versions in the form
# `<major>.<minor>[letter]` compute the "loose intersection" with the # `<major>.<minor>[letter]` compute the "loose intersection" with the
# `TGT_CUDA_ARCHS` list of gencodes. We also support the `+PTX` suffix in # `TGT_CUDA_ARCHS` list of gencodes. We also support the `+PTX` suffix in
# `SRC_CUDA_ARCHS` which indicates that the PTX code should be built when there # `SRC_CUDA_ARCHS` which indicates that the PTX code should be built when there
# is a CUDA_ARCH in `TGT_CUDA_ARCHS` that is equal to or larger than the # is a CUDA_ARCH in `TGT_CUDA_ARCHS` that is equal to or larger than the
@ -278,7 +278,7 @@ endmacro()
# in `SRC_CUDA_ARCHS` that is less or equal to the version in `TGT_CUDA_ARCHS`. # in `SRC_CUDA_ARCHS` that is less or equal to the version in `TGT_CUDA_ARCHS`.
# We have special handling for x.0a, if x.0a is in `SRC_CUDA_ARCHS` and x.0 is # We have special handling for x.0a, if x.0a is in `SRC_CUDA_ARCHS` and x.0 is
# in `TGT_CUDA_ARCHS` then we should remove x.0a from `SRC_CUDA_ARCHS` and add # in `TGT_CUDA_ARCHS` then we should remove x.0a from `SRC_CUDA_ARCHS` and add
# x.0a to the result (and remove x.0 from TGT_CUDA_ARCHS). # x.0a to the result (and remove x.0 from TGT_CUDA_ARCHS).
# The result is stored in `OUT_CUDA_ARCHS`. # The result is stored in `OUT_CUDA_ARCHS`.
# #
# Example: # Example:
@ -313,21 +313,16 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR
# if x.0a is in SRC_CUDA_ARCHS and x.0 is in CUDA_ARCHS then we should # if x.0a is in SRC_CUDA_ARCHS and x.0 is in CUDA_ARCHS then we should
# remove x.0a from SRC_CUDA_ARCHS and add x.0a to _CUDA_ARCHS # remove x.0a from SRC_CUDA_ARCHS and add x.0a to _CUDA_ARCHS
set(_CUDA_ARCHS) set(_CUDA_ARCHS)
if ("9.0a" IN_LIST _SRC_CUDA_ARCHS) foreach(_arch ${_SRC_CUDA_ARCHS})
list(REMOVE_ITEM _SRC_CUDA_ARCHS "9.0a") if(_arch MATCHES "\\a$")
if ("9.0" IN_LIST TGT_CUDA_ARCHS) list(REMOVE_ITEM _SRC_CUDA_ARCHS "${_arch}")
list(REMOVE_ITEM _TGT_CUDA_ARCHS "9.0") string(REPLACE "a" "" _base "${_arch}")
set(_CUDA_ARCHS "9.0a") if ("${_base}" IN_LIST TGT_CUDA_ARCHS)
list(REMOVE_ITEM _TGT_CUDA_ARCHS "${_base}")
list(APPEND _CUDA_ARCHS "${_arch}")
endif()
endif() endif()
endif() endforeach()
if ("10.0a" IN_LIST _SRC_CUDA_ARCHS)
list(REMOVE_ITEM _SRC_CUDA_ARCHS "10.0a")
if ("10.0" IN_LIST TGT_CUDA_ARCHS)
list(REMOVE_ITEM _TGT_CUDA_ARCHS "10.0")
set(_CUDA_ARCHS "10.0a")
endif()
endif()
list(SORT _SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING) list(SORT _SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
@ -359,7 +354,7 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR
endforeach() endforeach()
list(REMOVE_DUPLICATES _CUDA_ARCHS) list(REMOVE_DUPLICATES _CUDA_ARCHS)
# reapply +PTX suffix to architectures that requested PTX # reapply +PTX suffix to architectures that requested PTX
set(_FINAL_ARCHS) set(_FINAL_ARCHS)
foreach(_arch ${_CUDA_ARCHS}) foreach(_arch ${_CUDA_ARCHS})
@ -370,7 +365,7 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR
endif() endif()
endforeach() endforeach()
set(_CUDA_ARCHS ${_FINAL_ARCHS}) set(_CUDA_ARCHS ${_FINAL_ARCHS})
set(${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE) set(${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE)
endfunction() endfunction()