mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-16 13:57:12 +08:00
[CI/Build] Per file CUDA Archs (improve wheel size and dev build times) (#8845)
This commit is contained in:
parent
3dbb215b38
commit
aeb37c2a72
222
CMakeLists.txt
222
CMakeLists.txt
@ -143,6 +143,19 @@ else()
|
||||
message(FATAL_ERROR "Can't find CUDA or HIP installation.")
|
||||
endif()
|
||||
|
||||
|
||||
#
|
||||
# For cuda we want to be able to control which architectures we compile for on
|
||||
# a per-file basis in order to cut down on compile time. So here we extract
|
||||
# the set of architectures we want to compile for and remove them from the
|
||||
# CMAKE_CUDA_FLAGS so that they are not applied globally.
|
||||
#
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
clear_cuda_arches(CUDA_ARCH_FLAGS)
|
||||
extract_unique_cuda_archs_ascending(CUDA_ARCHS "${CUDA_ARCH_FLAGS}")
|
||||
message(STATUS "CUDA target architectures: ${CUDA_ARCHS}")
|
||||
endif()
|
||||
|
||||
#
|
||||
# Override the GPU architectures detected by cmake/torch and filter them by
|
||||
# the supported versions for the current language.
|
||||
@ -223,30 +236,89 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
"csrc/mamba/causal_conv1d/causal_conv1d.cu"
|
||||
"csrc/quantization/aqlm/gemm_kernels.cu"
|
||||
"csrc/quantization/awq/gemm_kernels.cu"
|
||||
"csrc/quantization/marlin/dense/marlin_cuda_kernel.cu"
|
||||
"csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
|
||||
"csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu"
|
||||
"csrc/quantization/gptq_marlin/gptq_marlin.cu"
|
||||
"csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
|
||||
"csrc/quantization/gptq_marlin/awq_marlin_repack.cu"
|
||||
"csrc/quantization/gguf/gguf_kernel.cu"
|
||||
"csrc/quantization/fp8/fp8_marlin.cu"
|
||||
"csrc/custom_all_reduce.cu"
|
||||
"csrc/permute_cols.cu"
|
||||
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
|
||||
"csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu"
|
||||
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu")
|
||||
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu")
|
||||
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${VLLM_EXT_SRC}"
|
||||
CUDA_ARCHS "${CUDA_ARCHS}")
|
||||
|
||||
# Only build Marlin kernels if we are building for at least some compatible archs.
|
||||
# Keep building Marlin for 9.0 as there are some group sizes and shapes that
|
||||
# are not supported by Machete yet.
|
||||
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.6;8.9;9.0" ${CUDA_ARCHS})
|
||||
if (MARLIN_ARCHS)
|
||||
set(MARLIN_SRCS
|
||||
"csrc/quantization/fp8/fp8_marlin.cu"
|
||||
"csrc/quantization/marlin/dense/marlin_cuda_kernel.cu"
|
||||
"csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
|
||||
"csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu"
|
||||
"csrc/quantization/gptq_marlin/gptq_marlin.cu"
|
||||
"csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
|
||||
"csrc/quantization/gptq_marlin/awq_marlin_repack.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MARLIN_SRCS}"
|
||||
CUDA_ARCHS "${MARLIN_ARCHS}")
|
||||
list(APPEND VLLM_EXT_SRC "${MARLIN_SRCS}")
|
||||
message(STATUS "Building Marlin kernels for archs: ${MARLIN_ARCHS}")
|
||||
else()
|
||||
message(STATUS "Not building Marlin kernels as no compatible archs found"
|
||||
"in CUDA target architectures")
|
||||
endif()
|
||||
|
||||
#
|
||||
# The CUTLASS kernels for Hopper require sm90a to be enabled.
|
||||
# This is done via the below gencode option, BUT that creates kernels for both sm90 and sm90a.
|
||||
# That adds an extra 17MB to compiled binary, so instead we selectively enable it.
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0)
|
||||
set_source_files_properties(
|
||||
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu"
|
||||
PROPERTIES
|
||||
COMPILE_FLAGS
|
||||
"-gencode arch=compute_90a,code=sm_90a")
|
||||
# The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
|
||||
# CUDA 12.0 or later (and only work on Hopper, 9.0/9.0a for now).
|
||||
cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0;9.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
|
||||
set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${SCALED_MM_3X_ARCHS}")
|
||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_C3X=1")
|
||||
message(STATUS "Building scaled_mm_c3x for archs: ${SCALED_MM_3X_ARCHS}")
|
||||
else()
|
||||
# clear SCALED_MM_3X_ARCHS so the scaled_mm_c2x kernels know we didn't
|
||||
# build any 3x kernels
|
||||
set(SCALED_MM_3X_ARCHS)
|
||||
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
|
||||
message(STATUS "Not building scaled_mm_c3x as CUDA Compiler version is "
|
||||
"not >= 12.0, we recommend upgrading to CUDA 12.0 or "
|
||||
"later if you intend on running FP8 quantized models on "
|
||||
"Hopper.")
|
||||
else()
|
||||
message(STATUS "Not building scaled_mm_c3x as no compatible archs found "
|
||||
"in CUDA target architectures")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
#
|
||||
# For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x)
|
||||
# kernels for the remaining archs that are not already built for 3x.
|
||||
cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS
|
||||
"7.5;8.0;8.6;8.9;9.0;9.0a" "${CUDA_ARCHS}")
|
||||
# subtract out the archs that are already built for 3x
|
||||
list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS})
|
||||
if (SCALED_MM_2X_ARCHS)
|
||||
set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${SCALED_MM_2X_ARCHS}")
|
||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_C2X=1")
|
||||
message(STATUS "Building scaled_mm_c2x for archs: ${SCALED_MM_2X_ARCHS}")
|
||||
else()
|
||||
if (SCALED_MM_3X_ARCHS)
|
||||
message(STATUS "Not building scaled_mm_c2x as all archs are already built"
|
||||
" for and covered by scaled_mm_c3x")
|
||||
else()
|
||||
message(STATUS "Not building scaled_mm_c2x as no compatible archs found "
|
||||
"in CUDA target architectures")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
|
||||
@ -254,47 +326,72 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
# Machete kernels
|
||||
|
||||
# The machete kernels only work on hopper and require CUDA 12.0 or later.
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0)
|
||||
# Only build Machete kernels if we are building for something compatible with sm90a
|
||||
cuda_archs_loose_intersection(MACHETE_ARCHS "9.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND MACHETE_ARCHS)
|
||||
#
|
||||
# For the Machete kernels we automatically generate sources for various
|
||||
# preselected input type pairs and schedules.
|
||||
# Generate sources:
|
||||
execute_process(
|
||||
COMMAND ${CMAKE_COMMAND} -E env
|
||||
PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/csrc/cutlass_extensions/:${CUTLASS_DIR}/python/:${VLLM_PYTHON_PATH}:$PYTHONPATH
|
||||
${Python_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/csrc/quantization/machete/generate.py
|
||||
RESULT_VARIABLE machete_generation_result
|
||||
OUTPUT_VARIABLE machete_generation_output
|
||||
OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log
|
||||
ERROR_FILE ${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log
|
||||
)
|
||||
set(MACHETE_GEN_SCRIPT
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/csrc/quantization/machete/generate.py)
|
||||
file(MD5 ${MACHETE_GEN_SCRIPT} MACHETE_GEN_SCRIPT_HASH)
|
||||
|
||||
if (NOT machete_generation_result EQUAL 0)
|
||||
message(FATAL_ERROR "Machete generation failed."
|
||||
" Result: \"${machete_generation_result}\""
|
||||
"\nCheck the log for details: "
|
||||
"${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log")
|
||||
message(STATUS "Machete generation script hash: ${MACHETE_GEN_SCRIPT_HASH}")
|
||||
message(STATUS "Last run machete generate script hash: $CACHE{MACHETE_GEN_SCRIPT_HASH}")
|
||||
|
||||
if (NOT DEFINED CACHE{MACHETE_GEN_SCRIPT_HASH}
|
||||
OR NOT $CACHE{MACHETE_GEN_SCRIPT_HASH} STREQUAL ${MACHETE_GEN_SCRIPT_HASH})
|
||||
execute_process(
|
||||
COMMAND ${CMAKE_COMMAND} -E env
|
||||
PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/csrc/cutlass_extensions/:${CUTLASS_DIR}/python/:${VLLM_PYTHON_PATH}:$PYTHONPATH
|
||||
${Python_EXECUTABLE} ${MACHETE_GEN_SCRIPT}
|
||||
RESULT_VARIABLE machete_generation_result
|
||||
OUTPUT_VARIABLE machete_generation_output
|
||||
OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log
|
||||
ERROR_FILE ${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log
|
||||
)
|
||||
|
||||
if (NOT machete_generation_result EQUAL 0)
|
||||
message(FATAL_ERROR "Machete generation failed."
|
||||
" Result: \"${machete_generation_result}\""
|
||||
"\nCheck the log for details: "
|
||||
"${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log")
|
||||
else()
|
||||
set(MACHETE_GEN_SCRIPT_HASH ${MACHETE_GEN_SCRIPT_HASH}
|
||||
CACHE STRING "Last run machete generate script hash" FORCE)
|
||||
message(STATUS "Machete generation completed successfully.")
|
||||
endif()
|
||||
else()
|
||||
message(STATUS "Machete generation completed successfully.")
|
||||
message(STATUS "Machete generation script has not changed, skipping generation.")
|
||||
endif()
|
||||
|
||||
# Add machete generated sources
|
||||
file(GLOB MACHETE_GEN_SOURCES "csrc/quantization/machete/generated/*.cu")
|
||||
list(APPEND VLLM_EXT_SRC ${MACHETE_GEN_SOURCES})
|
||||
message(STATUS "Machete generated sources: ${MACHETE_GEN_SOURCES}")
|
||||
|
||||
set_source_files_properties(
|
||||
${MACHETE_GEN_SOURCES}
|
||||
PROPERTIES
|
||||
COMPILE_FLAGS
|
||||
"-gencode arch=compute_90a,code=sm_90a")
|
||||
# forward compatible
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MACHETE_GEN_SOURCES}"
|
||||
CUDA_ARCHS "${MACHETE_ARCHS}")
|
||||
|
||||
list(APPEND VLLM_EXT_SRC
|
||||
csrc/quantization/machete/machete_pytorch.cu)
|
||||
|
||||
message(STATUS "Building Machete kernels for archs: ${MACHETE_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0
|
||||
AND MACHETE_ARCHS)
|
||||
message(STATUS "Not building Machete kernels as CUDA Compiler version is "
|
||||
"not >= 12.0, we recommend upgrading to CUDA 12.0 or "
|
||||
"later if you intend on running w4a16 quantized models on "
|
||||
"Hopper.")
|
||||
else()
|
||||
message(STATUS "Not building Machete kernels as no compatible archs "
|
||||
"found in CUDA target architectures")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# Add pytorch binding for machete (add on even CUDA < 12.0 so that we can
|
||||
# raise an error if the user that this was built with an incompatible
|
||||
# CUDA version)
|
||||
list(APPEND VLLM_EXT_SRC
|
||||
csrc/quantization/machete/machete_pytorch.cu)
|
||||
# if CUDA endif
|
||||
endif()
|
||||
|
||||
message(STATUS "Enabling C extension.")
|
||||
@ -323,14 +420,31 @@ set(VLLM_MOE_EXT_SRC
|
||||
"csrc/moe/torch_bindings.cpp"
|
||||
"csrc/moe/topk_softmax_kernels.cu")
|
||||
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${VLLM_MOE_EXT_SRC}"
|
||||
CUDA_ARCHS "${CUDA_ARCHS}")
|
||||
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
list(APPEND VLLM_MOE_EXT_SRC
|
||||
"csrc/moe/marlin_kernels/marlin_moe_kernel.h"
|
||||
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h"
|
||||
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu"
|
||||
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h"
|
||||
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu"
|
||||
"csrc/moe/marlin_moe_ops.cu")
|
||||
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.6;8.9;9.0" "${CUDA_ARCHS}")
|
||||
if (MARLIN_MOE_ARCHS)
|
||||
set(MARLIN_MOE_SRC
|
||||
"csrc/moe/marlin_kernels/marlin_moe_kernel.h"
|
||||
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h"
|
||||
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu"
|
||||
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h"
|
||||
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu"
|
||||
"csrc/moe/marlin_moe_ops.cu")
|
||||
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MARLIN_MOE_SRC}"
|
||||
CUDA_ARCHS "${MARLIN_MOE_ARCHS}")
|
||||
|
||||
list(APPEND VLLM_MOE_EXT_SRC "${MARLIN_MOE_SRC}")
|
||||
message(STATUS "Building Marlin MOE kernels for archs: ${MARLIN_MOE_ARCHS}")
|
||||
else()
|
||||
message(STATUS "Not building Marlin MOE kernels as no compatible archs found"
|
||||
"in CUDA target architectures")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
message(STATUS "Enabling moe extension.")
|
||||
|
||||
@ -133,10 +133,181 @@ macro(string_to_ver OUT_VER IN_STR)
|
||||
string(REGEX REPLACE "\([0-9]+\)\([0-9]\)" "\\1.\\2" ${OUT_VER} ${IN_STR})
|
||||
endmacro()
|
||||
|
||||
#
|
||||
# Clear all `-gencode` flags from `CMAKE_CUDA_FLAGS` and store them in
|
||||
# `CUDA_ARCH_FLAGS`.
|
||||
#
|
||||
# Example:
|
||||
# CMAKE_CUDA_FLAGS="-Wall -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75"
|
||||
# clear_cuda_arches(CUDA_ARCH_FLAGS)
|
||||
# CUDA_ARCH_FLAGS="-gencode arch=compute_70,code=sm_70;-gencode arch=compute_75,code=sm_75"
|
||||
# CMAKE_CUDA_FLAGS="-Wall"
|
||||
#
|
||||
macro(clear_cuda_arches CUDA_ARCH_FLAGS)
  # Moves every `-gencode arch=...` flag out of `CMAKE_CUDA_FLAGS` and into the
  # variable whose *name* is passed as `CUDA_ARCH_FLAGS`.
  #
  # This is a macro on purpose: it has to rewrite `CMAKE_CUDA_FLAGS` in the
  # caller's scope. Inside a macro `${CUDA_ARCH_FLAGS}` is replaced textually by
  # the argument, so using it as the output variable stores the result under
  # the name the caller asked for (not under a hard-coded name).

  # Collect all `-gencode` flags. The input is quoted so that an empty
  # `CMAKE_CUDA_FLAGS` yields an empty result instead of a
  # "needs at least N arguments" error from `string(REGEX ...)`.
  string(REGEX MATCHALL "-gencode arch=[^ ]+" ${CUDA_ARCH_FLAGS}
    "${CMAKE_CUDA_FLAGS}")

  # Strip the same flags (and any spaces that follow them) from the global
  # flags so they are no longer applied to every CUDA source file.
  string(REGEX REPLACE "-gencode arch=[^ ]+ *" "" CMAKE_CUDA_FLAGS
    "${CMAKE_CUDA_FLAGS}")
endmacro()
|
||||
|
||||
#
|
||||
# Extract unique CUDA architectures from a list of compute capabilities codes in
|
||||
# the form `<major><minor>[<letter>]`, convert them to the form
|
||||
# `<major>.<minor>[<letter>]`, dedupes them and then sorts them in ascending order and
|
||||
# stores them in `OUT_ARCHES`.
|
||||
#
|
||||
# Example:
|
||||
# CUDA_ARCH_FLAGS="-gencode arch=compute_75,code=sm_75;...;-gencode arch=compute_90a,code=sm_90a"
|
||||
# extract_unique_cuda_archs_ascending(OUT_ARCHES "${CUDA_ARCH_FLAGS}")
|
||||
# OUT_ARCHES="7.5;...;9.0a"
|
||||
function(extract_unique_cuda_archs_ascending OUT_ARCHES CUDA_ARCH_FLAGS)
  # Turns a list of `-gencode arch=compute_NN[a],code=...` flags into a
  # de-duplicated, ascending list of dotted versions and stores it in the
  # caller's variable named by `OUT_ARCHES`.
  #
  #   OUT_ARCHES      name of the output variable (set in the parent scope)
  #   CUDA_ARCH_FLAGS the list of gencode flags itself (its value, not its name)
  set(_CUDA_ARCHES)
  foreach(_ARCH ${CUDA_ARCH_FLAGS})
    # CMAKE_MATCH_1 holds the digits (plus an optional trailing `a`) after
    # `compute_` when the pattern matches.
    string(REGEX MATCH "arch=compute_\([0-9]+a?\)" _COMPUTE "${_ARCH}")
    if (NOT _COMPUTE)
      # Fail with a readable message instead of letting string_to_ver be
      # invoked with a missing argument.
      message(FATAL_ERROR
        "Could not determine virtual architecture from: ${_ARCH}.")
    endif()
    set(_COMPUTE ${CMAKE_MATCH_1})

    # e.g. "75" -> "7.5", "90a" -> "9.0a"
    string_to_ver(_COMPUTE_VER ${_COMPUTE})
    list(APPEND _CUDA_ARCHES ${_COMPUTE_VER})
  endforeach()

  list(REMOVE_DUPLICATES _CUDA_ARCHES)
  list(SORT _CUDA_ARCHES COMPARE NATURAL ORDER ASCENDING)
  set(${OUT_ARCHES} ${_CUDA_ARCHES} PARENT_SCOPE)
endfunction()
|
||||
|
||||
#
|
||||
# For a specific file set the `-gencode` flag in compile options conditionally
|
||||
# for the CUDA language.
|
||||
#
|
||||
# Example:
|
||||
# set_gencode_flag_for_srcs(
|
||||
# SRCS "foo.cu"
|
||||
# ARCH "compute_75"
|
||||
# CODE "sm_75")
|
||||
# adds: "-gencode arch=compute_75,code=sm_75" to the compile options for
|
||||
# `foo.cu` (only for the CUDA language).
|
||||
#
|
||||
function(set_gencode_flag_for_srcs)
  # Appends a single `-gencode arch=<ARCH>,code=<CODE>` option to the
  # COMPILE_OPTIONS source property of every file in SRCS, wrapped in a
  # generator expression so it only applies when compiling CUDA.
  #
  #   SRCS  list of source files to tag
  #   ARCH  virtual architecture, e.g. "compute_75"
  #   CODE  code target, e.g. "sm_75" (or "compute_75" for PTX)
  #
  # Defined as a function (not a macro) so the parsed `arg_*` variables and the
  # helper lists below stay local and cannot overwrite same-named variables in
  # the caller.
  set(options)
  set(oneValueArgs ARCH CODE)
  set(multiValueArgs SRCS)
  cmake_parse_arguments(arg "${options}" "${oneValueArgs}"
                        "${multiValueArgs}" ${ARGN} )

  # Nothing to tag.
  if (NOT arg_SRCS)
    return()
  endif()

  # Intentionally a two-element list: `-gencode` and its value must reach the
  # compiler as two separate arguments.
  set(_FLAG -gencode arch=${arg_ARCH},code=${arg_CODE})
  set_property(
    SOURCE ${arg_SRCS}
    APPEND PROPERTY
    COMPILE_OPTIONS "$<$<COMPILE_LANGUAGE:CUDA>:${_FLAG}>"
  )

  message(DEBUG "Setting gencode flag for ${arg_SRCS}: ${_FLAG}")
endfunction()
|
||||
|
||||
#
|
||||
# For a list of source files set the `-gencode` flags in the files specific
|
||||
# compile options (specifically for the CUDA language).
|
||||
#
|
||||
# arguments are:
|
||||
# SRCS: list of source files
|
||||
# CUDA_ARCHS: list of CUDA architectures in the form `<major>.<minor>[letter]`
|
||||
# BUILD_PTX_FOR_ARCH: if set to true, then the PTX code will be built
|
||||
# for architecture `BUILD_PTX_FOR_ARCH` if there is a CUDA_ARCH in CUDA_ARCHS
|
||||
# that is greater than or equal to BUILD_PTX_FOR_ARCH.
|
||||
#
|
||||
function(set_gencode_flags_for_srcs)
  # For every architecture in CUDA_ARCHS, adds the matching
  # `-gencode arch=compute_XY,code=sm_XY` option to each file in SRCS.
  # When BUILD_PTX_FOR_ARCH is given and the highest requested architecture is
  # greater than or equal to it, a PTX target
  # (`arch=compute_XY,code=compute_XY`) for BUILD_PTX_FOR_ARCH is added too.
  #
  # Defined as a function with its own argument prefix so that nothing leaks
  # into the caller and the per-flag helper, which parses its own `arg_*`
  # variables, cannot interfere with the values parsed here.
  set(options)
  set(oneValueArgs BUILD_PTX_FOR_ARCH)
  set(multiValueArgs SRCS CUDA_ARCHS)
  cmake_parse_arguments(gfs "${options}" "${oneValueArgs}"
                        "${multiValueArgs}" ${ARGN} )

  foreach(_ARCH ${gfs_CUDA_ARCHS})
    # "8.6" -> "86", "9.0a" -> "90a"
    string(REPLACE "." "" _ARCH_NODOT "${_ARCH}")
    set_gencode_flag_for_srcs(
      SRCS ${gfs_SRCS}
      ARCH "compute_${_ARCH_NODOT}"
      CODE "sm_${_ARCH_NODOT}")
  endforeach()

  # `list(GET ... -1)` is an error on an empty list, so only look for the
  # highest architecture when at least one was requested.
  list(LENGTH gfs_CUDA_ARCHS _NUM_ARCHS)
  if (gfs_BUILD_PTX_FOR_ARCH AND _NUM_ARCHS GREATER 0)
    list(SORT gfs_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
    list(GET gfs_CUDA_ARCHS -1 _HIGHEST_ARCH)
    if (_HIGHEST_ARCH VERSION_GREATER_EQUAL "${gfs_BUILD_PTX_FOR_ARCH}")
      string(REPLACE "." "" _PTX_ARCH "${gfs_BUILD_PTX_FOR_ARCH}")
      set_gencode_flag_for_srcs(
        SRCS ${gfs_SRCS}
        ARCH "compute_${_PTX_ARCH}"
        CODE "compute_${_PTX_ARCH}")
    endif()
  endif()
endfunction()
|
||||
|
||||
#
|
||||
# For the given `SRC_CUDA_ARCHS` list of gencode versions in the form
|
||||
# `<major>.<minor>[letter]` compute the "loose intersection" with the
|
||||
# `TGT_CUDA_ARCHS` list of gencodes.
|
||||
# The loose intersection is defined as:
|
||||
# { max{ x \in src | x <= y } | y \in tgt, { x \in src | x <= y } != {} }
|
||||
# where `<=` is the version comparison operator.
|
||||
# In other words, for each version in `TGT_CUDA_ARCHS` find the highest version
|
||||
# in `SRC_CUDA_ARCHS` that is less or equal to the version in `TGT_CUDA_ARCHS`.
|
||||
# We have special handling for 9.0a, if 9.0a is in `SRC_CUDA_ARCHS` and 9.0 is
|
||||
# in `TGT_CUDA_ARCHS` then we should remove 9.0a from `SRC_CUDA_ARCHS` and add
|
||||
# 9.0a to the result.
|
||||
# The result is stored in `OUT_CUDA_ARCHS`.
|
||||
#
|
||||
# Example:
|
||||
# SRC_CUDA_ARCHS="7.5;8.0;8.6;9.0;9.0a"
|
||||
# TGT_CUDA_ARCHS="8.0;8.9;9.0"
|
||||
# cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
|
||||
# OUT_CUDA_ARCHS="8.0;8.6;9.0;9.0a"
|
||||
#
|
||||
function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
  # Computes, for every architecture in the target list, the highest
  # architecture in SRC_CUDA_ARCHS that is less than or equal to it, and stores
  # the de-duplicated result in the caller's variable named by OUT_CUDA_ARCHS.
  #
  #   OUT_CUDA_ARCHS  name of the output variable (set in the parent scope)
  #   SRC_CUDA_ARCHS  architectures a kernel supports, e.g. "8.0;8.6;9.0"
  #   TGT_CUDA_ARCHS  architectures being built for
  #
  # Callers that pass the target list unquoted spread it over several
  # arguments; everything after the first element then arrives in ARGN, so it
  # is folded back into the target list here.
  set(_TGT_CUDA_ARCHS ${TGT_CUDA_ARCHS} ${ARGN})

  list(REMOVE_DUPLICATES SRC_CUDA_ARCHS)

  # Special case: if 9.0a is supported and 9.0 is targeted, take 9.0a out of
  # the supported list and put it straight into the result.
  set(_CUDA_ARCHS)
  if ("9.0a" IN_LIST SRC_CUDA_ARCHS)
    list(REMOVE_ITEM SRC_CUDA_ARCHS "9.0a")
    if ("9.0" IN_LIST _TGT_CUDA_ARCHS)
      set(_CUDA_ARCHS "9.0a")
    endif()
  endif()

  # Ascending order lets the inner loop stop at the first supported
  # architecture that exceeds the target.
  list(SORT SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)

  # Iterate over the *argument*, not over a `CUDA_ARCHS` variable that merely
  # happens to exist in the caller's scope.
  foreach(_ARCH ${_TGT_CUDA_ARCHS})
    set(_TMP_ARCH)
    foreach(_SRC_ARCH ${SRC_CUDA_ARCHS})
      if (_SRC_ARCH VERSION_LESS_EQUAL _ARCH)
        set(_TMP_ARCH ${_SRC_ARCH})
      else()
        break()
      endif()
    endforeach()
    if (_TMP_ARCH)
      list(APPEND _CUDA_ARCHS ${_TMP_ARCH})
    endif()
  endforeach()

  list(REMOVE_DUPLICATES _CUDA_ARCHS)
  set(${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE)
endfunction()
|
||||
|
||||
#
|
||||
# Override the GPU architectures detected by cmake/torch and filter them by
|
||||
# `GPU_SUPPORTED_ARCHES`. Sets the final set of architectures in
|
||||
# `GPU_ARCHES`.
|
||||
# `GPU_ARCHES`. This only applies to the HIP language since for CUDA we set
|
||||
# the architectures on a per file basis.
|
||||
#
|
||||
# Note: this is defined as a macro since it updates `CMAKE_CUDA_FLAGS`.
|
||||
#
|
||||
@ -174,109 +345,7 @@ macro(override_gpu_arches GPU_ARCHES GPU_LANG GPU_SUPPORTED_ARCHES)
|
||||
"None of the detected ROCm architectures: ${HIP_ARCHITECTURES} is"
|
||||
" supported. Supported ROCm architectures are: ${_GPU_SUPPORTED_ARCHES_LIST}.")
|
||||
endif()
|
||||
|
||||
elseif(${GPU_LANG} STREQUAL "CUDA")
|
||||
#
|
||||
# Setup/process CUDA arch flags.
|
||||
#
|
||||
# The torch cmake setup hardcodes the detected architecture flags in
|
||||
# `CMAKE_CUDA_FLAGS`. Since `CMAKE_CUDA_FLAGS` is a "global" variable, it
|
||||
# can't modified on a per-target basis.
|
||||
# So, all the `-gencode` flags need to be extracted and removed from
|
||||
# `CMAKE_CUDA_FLAGS` for processing so they can be passed by another method.
|
||||
# Since it's not possible to use `target_compiler_options` for adding target
|
||||
# specific `-gencode` arguments, the target's `CUDA_ARCHITECTURES` property
|
||||
# must be used instead. This requires repackaging the architecture flags
|
||||
# into a format that cmake expects for `CUDA_ARCHITECTURES`.
|
||||
#
|
||||
# This is a bit fragile in that it depends on torch using `-gencode` as opposed
|
||||
# to one of the other nvcc options to specify architectures.
|
||||
#
|
||||
# Note: torch uses the `TORCH_CUDA_ARCH_LIST` environment variable to override
|
||||
# detected architectures.
|
||||
#
|
||||
message(DEBUG "initial CMAKE_CUDA_FLAGS: ${CMAKE_CUDA_FLAGS}")
|
||||
|
||||
# Extract all `-gencode` flags from `CMAKE_CUDA_FLAGS`
|
||||
string(REGEX MATCHALL "-gencode arch=[^ ]+" _CUDA_ARCH_FLAGS
|
||||
${CMAKE_CUDA_FLAGS})
|
||||
|
||||
# Remove all `-gencode` flags from `CMAKE_CUDA_FLAGS` since they will be modified
|
||||
# and passed back via the `CUDA_ARCHITECTURES` property.
|
||||
string(REGEX REPLACE "-gencode arch=[^ ]+ *" "" CMAKE_CUDA_FLAGS
|
||||
${CMAKE_CUDA_FLAGS})
|
||||
|
||||
# If this error is triggered, it might mean that torch has changed how it sets
|
||||
# up nvcc architecture code generation flags.
|
||||
if (NOT _CUDA_ARCH_FLAGS)
|
||||
message(FATAL_ERROR
|
||||
"Could not find any architecture related code generation flags in "
|
||||
"CMAKE_CUDA_FLAGS. (${CMAKE_CUDA_FLAGS})")
|
||||
endif()
|
||||
|
||||
message(DEBUG "final CMAKE_CUDA_FLAGS: ${CMAKE_CUDA_FLAGS}")
|
||||
message(DEBUG "arch flags: ${_CUDA_ARCH_FLAGS}")
|
||||
|
||||
# Initialize the architecture lists to empty.
|
||||
set(${GPU_ARCHES})
|
||||
|
||||
# Process each `gencode` flag.
|
||||
foreach(_ARCH ${_CUDA_ARCH_FLAGS})
|
||||
# For each flag, extract the version number and whether it refers to PTX
|
||||
# or native code.
|
||||
# Note: if a regex matches then `CMAKE_MATCH_1` holds the binding
|
||||
# for that match.
|
||||
|
||||
string(REGEX MATCH "arch=compute_\([0-9]+a?\)" _COMPUTE ${_ARCH})
|
||||
if (_COMPUTE)
|
||||
set(_COMPUTE ${CMAKE_MATCH_1})
|
||||
endif()
|
||||
|
||||
string(REGEX MATCH "code=sm_\([0-9]+a?\)" _SM ${_ARCH})
|
||||
if (_SM)
|
||||
set(_SM ${CMAKE_MATCH_1})
|
||||
endif()
|
||||
|
||||
string(REGEX MATCH "code=compute_\([0-9]+a?\)" _CODE ${_ARCH})
|
||||
if (_CODE)
|
||||
set(_CODE ${CMAKE_MATCH_1})
|
||||
endif()
|
||||
|
||||
# Make sure the virtual architecture can be matched.
|
||||
if (NOT _COMPUTE)
|
||||
message(FATAL_ERROR
|
||||
"Could not determine virtual architecture from: ${_ARCH}.")
|
||||
endif()
|
||||
|
||||
# One of sm_ or compute_ must exist.
|
||||
if ((NOT _SM) AND (NOT _CODE))
|
||||
message(FATAL_ERROR
|
||||
"Could not determine a codegen architecture from: ${_ARCH}.")
|
||||
endif()
|
||||
|
||||
if (_SM)
|
||||
# -real suffix let CMake to only generate elf code for the kernels.
|
||||
# we want this, otherwise the added ptx (default) will increase binary size.
|
||||
set(_VIRT "-real")
|
||||
set(_CODE_ARCH ${_SM})
|
||||
else()
|
||||
# -virtual suffix let CMake to generate ptx code for the kernels.
|
||||
set(_VIRT "-virtual")
|
||||
set(_CODE_ARCH ${_CODE})
|
||||
endif()
|
||||
|
||||
# Check if the current version is in the supported arch list.
|
||||
string_to_ver(_CODE_VER ${_CODE_ARCH})
|
||||
if (NOT _CODE_VER IN_LIST _GPU_SUPPORTED_ARCHES_LIST)
|
||||
message(STATUS "discarding unsupported CUDA arch ${_VER}.")
|
||||
continue()
|
||||
endif()
|
||||
|
||||
# Add it to the arch list.
|
||||
list(APPEND ${GPU_ARCHES} "${_CODE_ARCH}${_VIRT}")
|
||||
endforeach()
|
||||
endif()
|
||||
message(STATUS "${GPU_LANG} target arches: ${${GPU_ARCHES}}")
|
||||
endmacro()
|
||||
|
||||
#
|
||||
|
||||
@ -12,6 +12,11 @@
|
||||
// could be a macro instead of a literal token.
|
||||
#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)
|
||||
|
||||
// A version of the TORCH_LIBRARY_IMPL macro that expands the NAME, i.e. so NAME
|
||||
// could be a macro instead of a literal token.
|
||||
#define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) \
|
||||
TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE)
|
||||
|
||||
// REGISTER_EXTENSION allows the shared library to be loaded and initialized
|
||||
// via python's import statement.
|
||||
#define REGISTER_EXTENSION(NAME) \
|
||||
|
||||
@ -27,6 +27,7 @@
|
||||
|
||||
#include "core/exception.hpp"
|
||||
#include "core/scalar_type.hpp"
|
||||
#include "core/registration.h"
|
||||
#include "marlin_kernels/marlin_moe_kernel_ku4b8.h"
|
||||
#include "marlin_kernels/marlin_moe_kernel_ku8b128.h"
|
||||
|
||||
@ -552,3 +553,7 @@ torch::Tensor marlin_gemm_moe(
|
||||
thread_n, sms, max_par, replicate_input, apply_weights);
|
||||
return c;
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||
m.impl("marlin_gemm_moe", &marlin_gemm_moe);
|
||||
}
|
||||
|
||||
@ -1,15 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/all.h>
|
||||
|
||||
#include "core/scalar_type.hpp"
|
||||
|
||||
torch::Tensor marlin_gemm_moe(
|
||||
const torch::Tensor& a, const torch::Tensor& b_q_weights,
|
||||
const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights,
|
||||
const torch::Tensor& topk_ids, const torch::Tensor& b_scales,
|
||||
const torch::Tensor& g_idx, const torch::Tensor& perm,
|
||||
torch::Tensor& workspace, vllm::ScalarTypeTorchPtr const& b_q_type,
|
||||
int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full,
|
||||
int64_t num_experts, int64_t topk, int64_t moe_block_size,
|
||||
bool replicate_input, bool apply_weights);
|
||||
@ -1,6 +1,5 @@
|
||||
#include "core/registration.h"
|
||||
#include "moe_ops.h"
|
||||
#include "marlin_moe_ops.h"
|
||||
|
||||
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
|
||||
// Apply topk softmax to the gating outputs.
|
||||
@ -18,7 +17,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
|
||||
"int size_n, int size_k, bool is_k_full, int num_experts, int topk, "
|
||||
"int moe_block_size, bool replicate_input, bool apply_weights)"
|
||||
" -> Tensor");
|
||||
m.impl("marlin_gemm_moe", torch::kCUDA, &marlin_gemm_moe);
|
||||
// conditionally compiled so impl registration is in source file
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
68
csrc/ops.h
68
csrc/ops.h
@ -90,63 +90,8 @@ torch::Tensor awq_dequantize(torch::Tensor _kernel,
|
||||
torch::Tensor _zeros, int64_t split_k_iters,
|
||||
int64_t thx, int64_t thy);
|
||||
|
||||
torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||
torch::Tensor& b_scales, torch::Tensor& workspace,
|
||||
int64_t size_m, int64_t size_n, int64_t size_k);
|
||||
|
||||
namespace machete {
|
||||
|
||||
std::vector<std::string> supported_schedules(
|
||||
vllm::ScalarTypeTorchPtr const& btype);
|
||||
|
||||
torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
|
||||
vllm::ScalarTypeTorchPtr const& btype,
|
||||
c10::optional<torch::Tensor> const& scales,
|
||||
c10::optional<torch::Tensor> const& zeros,
|
||||
c10::optional<int64_t> group_size,
|
||||
c10::optional<torch::Tensor> const& C,
|
||||
c10::optional<double> alpha, c10::optional<double> beta,
|
||||
c10::optional<std::string> schedule);
|
||||
|
||||
torch::Tensor prepack_B(torch::Tensor const& B,
|
||||
vllm::ScalarTypeTorchPtr const& btype);
|
||||
|
||||
}; // namespace machete
|
||||
|
||||
torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm);
|
||||
|
||||
torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||
torch::Tensor& b_meta,
|
||||
torch::Tensor& b_scales,
|
||||
torch::Tensor& workspace,
|
||||
vllm::ScalarTypeTorchPtr const& b_q_type,
|
||||
int64_t size_m, int64_t size_n,
|
||||
int64_t size_k);
|
||||
|
||||
torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||
torch::Tensor& b_scales, torch::Tensor& b_zeros,
|
||||
torch::Tensor& g_idx, torch::Tensor& perm,
|
||||
torch::Tensor& workspace,
|
||||
vllm::ScalarTypeTorchPtr const& b_q_type,
|
||||
int64_t size_m, int64_t size_n, int64_t size_k,
|
||||
bool is_k_full, bool has_zp,
|
||||
bool use_fp32_reduce);
|
||||
|
||||
torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
|
||||
int64_t size_k, int64_t size_n,
|
||||
int64_t num_bits);
|
||||
|
||||
torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight,
|
||||
torch::Tensor& perm, c10::SymInt size_k,
|
||||
c10::SymInt size_n, int64_t num_bits);
|
||||
|
||||
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
|
||||
int64_t size_n, int64_t num_bits);
|
||||
|
||||
torch::Tensor awq_marlin_repack_meta(torch::Tensor& b_q_weight,
|
||||
c10::SymInt size_k, c10::SymInt size_n,
|
||||
int64_t num_bits);
|
||||
|
||||
torch::Tensor ggml_dequantize(torch::Tensor W, int64_t type, int64_t m,
|
||||
int64_t n);
|
||||
|
||||
@ -156,11 +101,6 @@ torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X,
|
||||
torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type,
|
||||
int64_t row);
|
||||
|
||||
torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||
torch::Tensor& b_scales, torch::Tensor& workspace,
|
||||
int64_t num_bits, int64_t size_m, int64_t size_n,
|
||||
int64_t size_k);
|
||||
|
||||
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
|
||||
|
||||
void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
|
||||
@ -175,14 +115,6 @@ void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& azp_adj,
|
||||
c10::optional<torch::Tensor> const& azp,
|
||||
c10::optional<torch::Tensor> const& bias);
|
||||
|
||||
torch::Tensor marlin_qqq_gemm(torch::Tensor const& a,
|
||||
torch::Tensor const& b_q_weight,
|
||||
torch::Tensor const& s_tok,
|
||||
torch::Tensor const& s_ch,
|
||||
torch::Tensor const& s_group,
|
||||
torch::Tensor& workspace, int64_t size_m,
|
||||
int64_t size_n, int64_t size_k);
|
||||
#endif
|
||||
|
||||
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
|
||||
|
||||
@ -21,7 +21,7 @@ void cutlass_scaled_mm_sm89(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b_scales,
|
||||
c10::optional<torch::Tensor> const& bias);
|
||||
|
||||
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
|
||||
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
|
||||
void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
@ -114,26 +114,39 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
|
||||
|
||||
at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
|
||||
int32_t version_num = get_sm_version_num();
|
||||
if (version_num >= 90) {
|
||||
// Hopper
|
||||
// Hopper
|
||||
|
||||
// Guard against compilation issues for sm90 kernels
|
||||
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
|
||||
// Guard against compilation issues for sm90 kernels
|
||||
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
|
||||
if (version_num >= 90) {
|
||||
cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
|
||||
#else
|
||||
cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales, bias);
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
} else if (version_num == 89) {
|
||||
|
||||
#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
|
||||
if (version_num == 89) {
|
||||
// Ada Lovelace
|
||||
cutlass_scaled_mm_sm89(c, a, b, a_scales, b_scales, bias);
|
||||
} else if (version_num >= 80) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (version_num >= 80) {
|
||||
// Ampere
|
||||
cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales, bias);
|
||||
} else {
|
||||
// Turing
|
||||
TORCH_CHECK(version_num >= 75);
|
||||
cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales, bias);
|
||||
return;
|
||||
}
|
||||
|
||||
// Turing
|
||||
TORCH_CHECK(version_num >= 75);
|
||||
cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales, bias);
|
||||
#endif
|
||||
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false,
|
||||
"No compiled cutlass_scaled_mm for a compute capability less than "
|
||||
"CUDA device capability: ",
|
||||
version_num);
|
||||
}
|
||||
|
||||
void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
|
||||
@ -174,25 +187,38 @@ void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
|
||||
"currently bias dtype must match output dtype ", c.dtype());
|
||||
|
||||
at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
|
||||
int32_t version_num = get_sm_version_num();
|
||||
if (version_num >= 90) {
|
||||
// Hopper
|
||||
|
||||
// Guard against compilation issues for sm90 kernels
|
||||
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
|
||||
int32_t version_num = get_sm_version_num();
|
||||
|
||||
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
|
||||
if (version_num >= 90) {
|
||||
cutlass_scaled_mm_azp_sm90(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
|
||||
#else
|
||||
cutlass_scaled_mm_azp_sm80(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
} else if (version_num == 89) {
|
||||
|
||||
#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
|
||||
if (version_num == 89) {
|
||||
// Ada Lovelace
|
||||
cutlass_scaled_mm_azp_sm89(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
|
||||
} else if (version_num >= 80) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (version_num >= 80) {
|
||||
// Ampere
|
||||
cutlass_scaled_mm_azp_sm80(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
|
||||
} else {
|
||||
// Turing
|
||||
TORCH_CHECK(version_num >= 75);
|
||||
cutlass_scaled_mm_azp_sm75(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
|
||||
return;
|
||||
}
|
||||
|
||||
// Turing
|
||||
TORCH_CHECK(version_num >= 75);
|
||||
cutlass_scaled_mm_azp_sm75(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
|
||||
return;
|
||||
#endif
|
||||
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false,
|
||||
"No compiled cutlass_scaled_mm_azp for a compute capability less than "
|
||||
"CUDA device capability: ",
|
||||
version_num);
|
||||
}
|
||||
@ -22,6 +22,8 @@
|
||||
#include "../gptq_marlin/marlin.cuh"
|
||||
#include "../gptq_marlin/marlin_dtypes.cuh"
|
||||
|
||||
#include "core/registration.h"
|
||||
|
||||
using namespace marlin;
|
||||
|
||||
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
|
||||
@ -1303,3 +1305,7 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||
m.impl("fp8_marlin_gemm", &fp8_marlin_gemm);
|
||||
}
|
||||
@ -1,25 +1,6 @@
|
||||
#include "marlin.cuh"
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
|
||||
namespace marlin {
|
||||
|
||||
template <int const num_threads, int const num_bits, bool const has_perm>
|
||||
__global__ void awq_marlin_repack_kernel(
|
||||
uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr,
|
||||
int size_k, int size_n) {}
|
||||
|
||||
} // namespace marlin
|
||||
|
||||
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
|
||||
int64_t size_k, int64_t size_n,
|
||||
int64_t num_bits) {
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false, "marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0");
|
||||
return torch::empty({1, 1});
|
||||
}
|
||||
|
||||
#else
|
||||
#include "core/registration.h"
|
||||
|
||||
namespace marlin {
|
||||
|
||||
@ -122,7 +103,7 @@ __global__ void awq_marlin_repack_kernel(
|
||||
}
|
||||
|
||||
uint32_t vals[8];
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
int cur_elem = tc_row + tc_offsets[i];
|
||||
|
||||
@ -143,7 +124,7 @@ __global__ void awq_marlin_repack_kernel(
|
||||
constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
|
||||
|
||||
uint32_t res = 0;
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++) {
|
||||
res |= vals[pack_idx[i]] << (i * 4);
|
||||
}
|
||||
@ -155,7 +136,7 @@ __global__ void awq_marlin_repack_kernel(
|
||||
|
||||
uint32_t res1 = 0;
|
||||
uint32_t res2 = 0;
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
res1 |= vals[pack_idx[i]] << (i * 8);
|
||||
res2 |= vals[4 + pack_idx[i]] << (i * 8);
|
||||
@ -167,21 +148,21 @@ __global__ void awq_marlin_repack_kernel(
|
||||
};
|
||||
|
||||
auto start_pipes = [&](int k_tile_id, int n_tile_id) {
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int pipe = 0; pipe < repack_stages - 1; pipe++) {
|
||||
fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe);
|
||||
}
|
||||
|
||||
wait_for_stage();
|
||||
};
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) {
|
||||
int n_tile_id = 0;
|
||||
|
||||
start_pipes(k_tile_id, n_tile_id);
|
||||
|
||||
while (n_tile_id < n_tiles) {
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int pipe = 0; pipe < repack_stages; pipe++) {
|
||||
fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id,
|
||||
n_tile_id + pipe + repack_stages - 1);
|
||||
@ -195,15 +176,15 @@ __global__ void awq_marlin_repack_kernel(
|
||||
|
||||
} // namespace marlin
|
||||
|
||||
#define CALL_IF(NUM_BITS) \
|
||||
else if (num_bits == NUM_BITS) { \
|
||||
cudaFuncSetAttribute( \
|
||||
marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS>, \
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
|
||||
marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS> \
|
||||
<<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \
|
||||
b_q_weight_ptr, out_ptr, size_k, size_n); \
|
||||
}
|
||||
#define CALL_IF(NUM_BITS) \
|
||||
else if (num_bits == NUM_BITS) { \
|
||||
cudaFuncSetAttribute( \
|
||||
marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS>, \
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
|
||||
marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS> \
|
||||
<<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \
|
||||
b_q_weight_ptr, out_ptr, size_k, size_n); \
|
||||
}
|
||||
|
||||
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
|
||||
int64_t size_n, int64_t num_bits) {
|
||||
@ -266,8 +247,6 @@ torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
|
||||
return out;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
torch::Tensor awq_marlin_repack_meta(torch::Tensor& b_q_weight,
|
||||
c10::SymInt size_k, c10::SymInt size_n,
|
||||
int64_t num_bits) {
|
||||
@ -279,3 +258,11 @@ torch::Tensor awq_marlin_repack_meta(torch::Tensor& b_q_weight,
|
||||
{size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
|
||||
options);
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||
m.impl("awq_marlin_repack", &awq_marlin_repack);
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, Meta, m) {
|
||||
m.impl("awq_marlin_repack", &awq_marlin_repack_meta);
|
||||
}
|
||||
@ -23,6 +23,8 @@
|
||||
#include "marlin_dtypes.cuh"
|
||||
#include "core/scalar_type.hpp"
|
||||
|
||||
#include "core/registration.h"
|
||||
|
||||
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
|
||||
static_assert(std::is_same<scalar_t, half>::value || \
|
||||
std::is_same<scalar_t, nv_bfloat16>::value, \
|
||||
@ -2297,3 +2299,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||
m.impl("gptq_marlin_gemm", &gptq_marlin_gemm);
|
||||
}
|
||||
@ -1,26 +1,6 @@
|
||||
#include "marlin.cuh"
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
|
||||
namespace marlin {
|
||||
|
||||
template <int const num_threads, int const num_bits, bool const has_perm>
|
||||
__global__ void gptq_marlin_repack_kernel(
|
||||
uint32_t const* __restrict__ b_q_weight_ptr,
|
||||
uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr,
|
||||
int size_k, int size_n) {}
|
||||
|
||||
} // namespace marlin
|
||||
|
||||
torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
|
||||
int64_t size_k, int64_t size_n,
|
||||
int64_t num_bits) {
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false, "marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0");
|
||||
return torch::empty({1, 1});
|
||||
}
|
||||
|
||||
#else
|
||||
#include "core/registration.h"
|
||||
|
||||
namespace marlin {
|
||||
|
||||
@ -174,13 +154,13 @@ __global__ void gptq_marlin_repack_kernel(
|
||||
uint32_t b1_vals[tile_ints];
|
||||
uint32_t b2_vals[tile_ints];
|
||||
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int i = 0; i < tile_ints; i++) {
|
||||
b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i];
|
||||
b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
int cur_elem = tc_row + tc_offsets[i];
|
||||
int cur_int = cur_elem / pack_factor;
|
||||
@ -200,7 +180,7 @@ __global__ void gptq_marlin_repack_kernel(
|
||||
constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
|
||||
|
||||
uint32_t res = 0;
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++) {
|
||||
res |= vals[pack_idx[i]] << (i * 4);
|
||||
}
|
||||
@ -212,7 +192,7 @@ __global__ void gptq_marlin_repack_kernel(
|
||||
|
||||
uint32_t res1 = 0;
|
||||
uint32_t res2 = 0;
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
res1 |= vals[pack_idx[i]] << (i * 8);
|
||||
res2 |= vals[4 + pack_idx[i]] << (i * 8);
|
||||
@ -224,14 +204,14 @@ __global__ void gptq_marlin_repack_kernel(
|
||||
};
|
||||
|
||||
auto start_pipes = [&](int k_tile_id, int n_tile_id) {
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int pipe = 0; pipe < repack_stages - 1; pipe++) {
|
||||
fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe);
|
||||
}
|
||||
|
||||
wait_for_stage();
|
||||
};
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) {
|
||||
int n_tile_id = 0;
|
||||
|
||||
@ -242,7 +222,7 @@ __global__ void gptq_marlin_repack_kernel(
|
||||
start_pipes(k_tile_id, n_tile_id);
|
||||
|
||||
while (n_tile_id < n_tiles) {
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int pipe = 0; pipe < repack_stages; pipe++) {
|
||||
fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id,
|
||||
n_tile_id + pipe + repack_stages - 1);
|
||||
@ -256,17 +236,17 @@ __global__ void gptq_marlin_repack_kernel(
|
||||
|
||||
} // namespace marlin
|
||||
|
||||
#define CALL_IF(NUM_BITS, HAS_PERM) \
|
||||
else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \
|
||||
cudaFuncSetAttribute( \
|
||||
marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
|
||||
HAS_PERM>, \
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
|
||||
marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
|
||||
HAS_PERM> \
|
||||
<<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \
|
||||
b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \
|
||||
}
|
||||
#define CALL_IF(NUM_BITS, HAS_PERM) \
|
||||
else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \
|
||||
cudaFuncSetAttribute( \
|
||||
marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
|
||||
HAS_PERM>, \
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
|
||||
marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
|
||||
HAS_PERM> \
|
||||
<<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \
|
||||
b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \
|
||||
}
|
||||
|
||||
torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
|
||||
int64_t size_k, int64_t size_n,
|
||||
@ -341,8 +321,6 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
|
||||
return out;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight,
|
||||
torch::Tensor& perm, c10::SymInt size_k,
|
||||
c10::SymInt size_n, int64_t num_bits) {
|
||||
@ -354,3 +332,11 @@ torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight,
|
||||
{size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
|
||||
options);
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||
m.impl("gptq_marlin_repack", &gptq_marlin_repack);
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, Meta, m) {
|
||||
m.impl("gptq_marlin_repack", &gptq_marlin_repack_meta);
|
||||
}
|
||||
@ -284,7 +284,7 @@ mm_impl_template = create_template(IMPL_TEMPLATE)
|
||||
prepack_dispatch_template = create_template(PREPACK_TEMPLATE)
|
||||
|
||||
|
||||
def create_sources(impl_config: ImplConfig, num_impl_files=2):
|
||||
def create_sources(impl_config: ImplConfig, num_impl_files=1):
|
||||
sources = []
|
||||
|
||||
type_name = generate_type_signature(impl_config.type_config)
|
||||
|
||||
@ -34,10 +34,9 @@ static __global__ void prepack_B_kernel(BInTensor B_in,
|
||||
}
|
||||
|
||||
template <typename PrepackedLayoutB, typename InLayout>
|
||||
static void prepack_B(cudaStream_t stream,
|
||||
typename PrepackedLayoutB::ElementB const* B_in_ptr,
|
||||
InLayout B_layout,
|
||||
typename PrepackedLayoutB::ElementB* B_out_ptr) {
|
||||
static void prepack_B_template(
|
||||
cudaStream_t stream, typename PrepackedLayoutB::ElementB const* B_in_ptr,
|
||||
InLayout B_layout, typename PrepackedLayoutB::ElementB* B_out_ptr) {
|
||||
using TileShapeNKL =
|
||||
decltype(append(typename PrepackedLayoutB::PPBlockShape_NK{}, _1{}));
|
||||
auto ilvd_NKbNbKL_to_offset =
|
||||
|
||||
@ -55,8 +55,8 @@ torch::Tensor prepack_impl(torch::Tensor const B) {
|
||||
// Allocate output
|
||||
torch::Tensor D = torch::empty_like(B, {}, at::MemoryFormat::Contiguous);
|
||||
|
||||
prepack_B<PrepackedLayoutB>(stream, B_ptr, layout_Bt,
|
||||
static_cast<ElementB*>(D.mutable_data_ptr()));
|
||||
prepack_B_template<PrepackedLayoutB>(
|
||||
stream, B_ptr, layout_Bt, static_cast<ElementB*>(D.mutable_data_ptr()));
|
||||
|
||||
return D;
|
||||
};
|
||||
|
||||
@ -2,6 +2,8 @@
|
||||
#include "machete_prepack_launcher.cuh"
|
||||
#include "core/scalar_type.hpp"
|
||||
|
||||
#include "core/registration.h"
|
||||
|
||||
namespace machete {
|
||||
|
||||
using namespace vllm;
|
||||
@ -78,14 +80,16 @@ torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
|
||||
}
|
||||
|
||||
torch::Tensor prepack_B(torch::Tensor const& B,
|
||||
ScalarTypeTorchPtr const& btype) {
|
||||
#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12
|
||||
vllm::ScalarTypeTorchPtr const& btype) {
|
||||
return scalar_type_dispatch(*btype, [&](auto BType) {
|
||||
return PrepackBDispatcher<half_t, decltype(BType), half_t>::dispatch(B);
|
||||
});
|
||||
#else
|
||||
TORCH_CHECK(false, "Machete requires CUDA 12.0 or later");
|
||||
#endif
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||
m.impl("machete_prepack_B", &prepack_B);
|
||||
m.impl("machete_gemm", &gemm);
|
||||
m.impl("machete_supported_schedules", &supported_schedules);
|
||||
}
|
||||
|
||||
}; // namespace machete
|
||||
|
||||
@ -26,6 +26,7 @@
|
||||
#include <iostream>
|
||||
|
||||
#include "common/base.h"
|
||||
#include "core/registration.h"
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
#include "common/mem.h"
|
||||
@ -1066,3 +1067,7 @@ torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||
|
||||
return c;
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||
m.impl("marlin_gemm", &marlin_gemm);
|
||||
}
|
||||
|
||||
@ -30,6 +30,7 @@
|
||||
#include <iostream>
|
||||
|
||||
#include "../dense/common/base.h"
|
||||
#include "core/registration.h"
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
#include "../dense/common/mem.h"
|
||||
@ -1241,3 +1242,7 @@ torch::Tensor marlin_qqq_gemm(torch::Tensor const& a,
|
||||
|
||||
return d;
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||
m.impl("marlin_qqq_gemm", &marlin_qqq_gemm);
|
||||
}
|
||||
|
||||
@ -28,6 +28,7 @@
|
||||
|
||||
#include "common/base.h"
|
||||
#include "core/scalar_type.hpp"
|
||||
#include "core/registration.h"
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
|
||||
@ -1134,3 +1135,7 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||
|
||||
return c;
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||
m.impl("gptq_marlin_24_gemm", &gptq_marlin_24_gemm);
|
||||
}
|
||||
|
||||
@ -167,7 +167,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
ops.def(
|
||||
"marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
|
||||
"Tensor! workspace, int size_m, int size_n, int size_k) -> Tensor");
|
||||
ops.impl("marlin_gemm", torch::kCUDA, &marlin_gemm);
|
||||
// conditionally compiled so impl in source file
|
||||
|
||||
// Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ.
|
||||
ops.def(
|
||||
@ -175,22 +175,24 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
"Tensor b_scales, Tensor workspace, "
|
||||
"__torch__.torch.classes._core_C.ScalarType b_q_type, "
|
||||
"int size_m, int size_n, int size_k) -> Tensor");
|
||||
ops.impl("gptq_marlin_24_gemm", torch::kCUDA, &gptq_marlin_24_gemm);
|
||||
// conditionally compiled so impl in source file
|
||||
|
||||
// Machete (Dense) Optimized Mixed Precision GEMM for Hopper.
|
||||
ops.def("machete_supported_schedules", &machete::supported_schedules);
|
||||
ops.def(
|
||||
"machete_supported_schedules("
|
||||
" __torch__.torch.classes._core_C.ScalarType btype"
|
||||
") -> str[]");
|
||||
ops.def(
|
||||
"machete_gemm(Tensor A, Tensor B,"
|
||||
" __torch__.torch.classes._core_C.ScalarType btype,"
|
||||
" Tensor? scales, Tensor? zeros, int? group_size,"
|
||||
" Tensor? C, float? alpha, float? beta, str? schedule)"
|
||||
"-> Tensor");
|
||||
ops.impl("machete_gemm", torch::kCUDA, &machete::gemm);
|
||||
ops.def(
|
||||
"machete_prepack_B(Tensor B,"
|
||||
" __torch__.torch.classes._core_C.ScalarType btype)"
|
||||
"-> Tensor");
|
||||
ops.impl("machete_prepack_B", torch::kCUDA, &machete::prepack_B);
|
||||
// conditionally compiled so impl registration is in source file
|
||||
|
||||
ops.def("permute_cols(Tensor A, Tensor perm) -> Tensor");
|
||||
ops.impl("permute_cols", torch::kCUDA, &permute_cols);
|
||||
@ -202,21 +204,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
"__torch__.torch.classes._core_C.ScalarType b_q_type, "
|
||||
"int size_m, int size_n, int size_k, bool is_k_full, "
|
||||
"bool has_zp, bool use_fp32_reduce) -> Tensor");
|
||||
ops.impl("gptq_marlin_gemm", torch::kCUDA, &gptq_marlin_gemm);
|
||||
// conditionally compiled so impl registration is in source file
|
||||
|
||||
// gptq_marlin repack from GPTQ.
|
||||
ops.def(
|
||||
"gptq_marlin_repack(Tensor b_q_weight, Tensor perm, "
|
||||
"SymInt size_k, SymInt size_n, int num_bits) -> Tensor");
|
||||
ops.impl("gptq_marlin_repack", torch::kCUDA, &gptq_marlin_repack);
|
||||
ops.impl("gptq_marlin_repack", torch::kMeta, &gptq_marlin_repack_meta);
|
||||
// conditionally compiled so impl registrations are in source file
|
||||
|
||||
// awq_marlin repack from AWQ.
|
||||
ops.def(
|
||||
"awq_marlin_repack(Tensor b_q_weight, SymInt size_k, "
|
||||
"SymInt size_n, int num_bits) -> Tensor");
|
||||
ops.impl("awq_marlin_repack", torch::kCUDA, &awq_marlin_repack);
|
||||
ops.impl("awq_marlin_repack", torch::kMeta, &awq_marlin_repack_meta);
|
||||
// conditionally compiled so impl registrations are in source file
|
||||
|
||||
// Dequantization for GGML.
|
||||
ops.def("ggml_dequantize(Tensor W, int type, int m, int n) -> Tensor");
|
||||
@ -237,7 +237,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
"fp8_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
|
||||
"Tensor! workspace, int num_bits, int size_m, int size_n, "
|
||||
"int size_k) -> Tensor");
|
||||
ops.impl("fp8_marlin_gemm", torch::kCUDA, &fp8_marlin_gemm);
|
||||
// conditionally compiled so impl registration is in source file
|
||||
|
||||
// marlin_qqq_gemm for QQQ.
|
||||
ops.def(
|
||||
@ -245,7 +245,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
"Tensor s_tok, Tensor s_ch, Tensor s_group, "
|
||||
"Tensor! workspace, int size_m, int size_n, "
|
||||
"int size_k) -> Tensor");
|
||||
ops.impl("marlin_qqq_gemm", torch::kCUDA, &marlin_qqq_gemm);
|
||||
// conditionally compiled so impl registration is in source file
|
||||
|
||||
// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
|
||||
// quantization, as well as bias
|
||||
|
||||
311
tools/report_build_time_ninja.py
Normal file
311
tools/report_build_time_ninja.py
Normal file
@ -0,0 +1,311 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) 2018 The Chromium Authors. All rights reserved.
|
||||
# Use of this source code is governed by a BSD-style license that can be
|
||||
# found in the LICENSE file.
|
||||
|
||||
# Modified version of: https://chromium.googlesource.com/chromium/tools/depot_tools.git/+/refs/heads/main/post_build_ninja_summary.py
|
||||
"""Summarize the last ninja build, invoked with ninja's -C syntax.
|
||||
|
||||
> python3 tools/report_build_time_ninja.py -C build/..
|
||||
|
||||
Typical output looks like this:
|
||||
```
|
||||
Longest build steps for .cpp.o:
|
||||
1.0 weighted s to build ...torch_bindings.cpp.o (12.4 s elapsed time)
|
||||
2.0 weighted s to build ..._attn_c.dir/csrc... (23.5 s elapsed time)
|
||||
2.6 weighted s to build ...torch_bindings.cpp.o (31.5 s elapsed time)
|
||||
3.2 weighted s to build ...torch_bindings.cpp.o (38.5 s elapsed time)
|
||||
Longest build steps for .so (linking):
|
||||
0.1 weighted s to build _core_C.abi3.so (0.7 s elapsed time)
|
||||
0.1 weighted s to build _moe_C.abi3.so (1.0 s elapsed time)
|
||||
0.5 weighted s to build ...flash_attn_c.abi3.so (1.1 s elapsed time)
|
||||
6.2 weighted s to build _C.abi3.so (6.2 s elapsed time)
|
||||
Longest build steps for .cu.o:
|
||||
15.3 weighted s to build ...machete_mm_... (183.5 s elapsed time)
|
||||
15.3 weighted s to build ...machete_mm_... (183.5 s elapsed time)
|
||||
15.3 weighted s to build ...machete_mm_... (183.6 s elapsed time)
|
||||
15.3 weighted s to build ...machete_mm_... (183.7 s elapsed time)
|
||||
15.5 weighted s to build ...machete_mm_... (185.6 s elapsed time)
|
||||
15.5 weighted s to build ...machete_mm_... (185.9 s elapsed time)
|
||||
15.5 weighted s to build ...machete_mm_... (186.2 s elapsed time)
|
||||
37.4 weighted s to build ...scaled_mm_c3x.cu... (449.0 s elapsed time)
|
||||
43.9 weighted s to build ...scaled_mm_c2x.cu... (527.4 s elapsed time)
|
||||
344.8 weighted s to build ...attention_...cu.o (1087.2 s elapsed time)
|
||||
1110.0 s weighted time (10120.4 s elapsed time sum, 9.1x parallelism)
|
||||
134 build steps completed, average of 0.12/s
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import errno
|
||||
import fnmatch
|
||||
import os
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
|
||||
# The number of long build times to report:
|
||||
long_count = 10
|
||||
# The number of long times by extension to report
|
||||
long_ext_count = 10
|
||||
|
||||
|
||||
class Target:
    """Represents a single build step read from a .ninja_log file.

    A step may produce several outputs; their names are collected in
    ``targets`` by whoever owns the object.
    """

    def __init__(self, start, end):
        """Creates a target object from its start/end times.

        Args:
            start: Step start time in seconds, as a float.
            end: Step end time in seconds, as a float.
        """
        self.start = start
        self.end = end
        # A list of output names, appended to by the owner of this object.
        self.targets = []
        self.weighted_duration = 0.0

    def Duration(self):
        """Returns the task duration in seconds as a float."""
        return self.end - self.start

    def SetWeightedDuration(self, weighted_duration):
        """Sets the weighted duration, in seconds, passed in as a float."""
        self.weighted_duration = weighted_duration

    def WeightedDuration(self):
        """Returns the task's weighted duration in seconds as a float.

        Weighted_duration takes the elapsed time of the task and divides it
        by how many other tasks were running at the same time. Thus, it
        represents the approximate impact of this task on the total build
        time, with serialized or serializing steps typically ending up with
        much longer weighted durations.
        weighted_duration should always be the same or shorter than duration.

        Raises:
            AssertionError: if the stored weighted duration exceeds the plain
                duration by more than a small floating-point tolerance.
        """
        # Allow for modest floating-point errors
        epsilon = 0.000002
        if (self.weighted_duration > self.Duration() + epsilon):
            # Print the offending values before the assertion below fires.
            print('%s > %s?' % (self.weighted_duration, self.Duration()))
        assert (self.weighted_duration <= self.Duration() + epsilon)
        return self.weighted_duration

    def DescribeTargets(self):
        """Returns a printable string that summarizes the targets."""
        # Some build steps generate dozens of outputs - handle them sanely.
        # The max_length was chosen so that it can fit most of the long
        # single-target names, while minimizing word wrapping.
        result = ', '.join(self.targets)
        max_length = 65
        if len(result) > max_length:
            result = result[:max_length] + '...'
        return result
|
||||
|
||||
|
||||
# Copied with some modifications from ninjatracing
|
||||
def ReadTargets(log, show_all):
    """Reads all targets from the open .ninja_log file object |log|.

    Args:
        log: An iterable text file object positioned at the start of a
            .ninja_log file.
        show_all: When false, only the records of the most recent build are
            kept; when true, records from earlier builds in the same log are
            retained as well.

    Returns:
        A list of Target objects, one per distinct command hash, in the order
        the commands were first seen (the list is not sorted).

    Raises:
        AssertionError: if the first line is not the v5 ninja log header.
    """
    header = log.readline()
    assert header == '# ninja log v5\n', \
        'unrecognized ninja log version %r' % header
    targets_dict = {}
    last_end_seen = 0.0
    for line in log:
        parts = line.strip().split('\t')
        if len(parts) != 5:
            # If ninja.exe is rudely halted then the .ninja_log file may be
            # corrupt. Silently continue.
            continue
        start, end, _, name, cmdhash = parts  # Ignore restat.
        try:
            # Convert from integral milliseconds to float seconds.
            start = int(start) / 1000.0
            end = int(end) / 1000.0
        except ValueError:
            # A truncated or garbled record can still have five fields; treat
            # it like any other corrupt line instead of aborting the report.
            continue
        if not show_all and end < last_end_seen:
            # An earlier time stamp means that this step is the first in a new
            # build, possibly an incremental build. Throw away the previous
            # data so that this new build will be displayed independently.
            # This has to be done by comparing end times because records are
            # written to the .ninja_log file when commands complete, so end
            # times are guaranteed to be in order, but start times are not.
            targets_dict = {}
        target = None
        if cmdhash in targets_dict:
            target = targets_dict[cmdhash]
            if not show_all and (target.start != start or target.end != end):
                # If several builds in a row just run one or two build steps
                # then the end times may not go backwards so the last build may
                # not be detected as such. However in many cases there will be a
                # build step repeated in the two builds and the changed
                # start/stop points for that command, identified by the hash,
                # can be used to detect and reset the target dictionary.
                targets_dict = {}
                target = None
        if target is None:
            targets_dict[cmdhash] = target = Target(start, end)
        last_end_seen = end
        # Multiple outputs of one command share a hash and a Target.
        target.targets.append(name)
    return list(targets_dict.values())
|
||||
|
||||
|
||||
def GetExtension(target, extra_patterns):
    """Return the file extension that best represents a target.

    For targets that generate multiple outputs it is important to return a
    consistent 'canonical' extension. Ultimately the goal is to group build
    steps by type.

    Args:
        target: An object whose ``targets`` attribute is a list of output
            file names.
        extra_patterns: Optional ';'-separated string of substrings. If any
            output contains one of them, that substring is returned as the
            grouping key. May be None or empty.

    Returns:
        A string naming the group: a matched extra pattern, a one- or
        two-part extension such as '.cu.o', or a descriptive label such as
        '.so (linking)' or '(no extension found)'.
    """
    # Split once, outside the loop. Empty entries (e.g. from a trailing ';')
    # are dropped because '*' + '' + '*' would match every output.
    patterns = []
    if extra_patterns:
        patterns = [p for p in extra_patterns.split(';') if p]

    # Defined up front so a target with no outputs still yields a label.
    extension = '(no extension found)'
    for output in target.targets:
        for fn_pattern in patterns:
            if fnmatch.fnmatch(output, '*' + fn_pattern + '*'):
                return fn_pattern
        # Not a true extension, but a good grouping.
        if output.endswith('type_mappings'):
            extension = 'type_mappings'
            break

        # Capture two extensions if present. For example: file.javac.jar should
        # be distinguished from file.interface.jar.
        root, ext1 = os.path.splitext(output)
        _, ext2 = os.path.splitext(root)
        extension = ext2 + ext1  # Preserve the order in the file name.

        if len(extension) == 0:
            extension = '(no extension found)'

        if ext1 in ['.pdb', '.dll', '.exe']:
            extension = 'PEFile (linking)'
            # Make sure that .dll and .exe are grouped together and that the
            # .dll.lib files don't cause these to be listed as libraries
            break
        if ext1 in ['.so', '.TOC']:
            extension = '.so (linking)'
            # Attempt to identify linking, avoid identifying as '.TOC'
            break
        # Make sure .obj files don't get categorized as mojo files
        if ext1 in ['.obj', '.o']:
            break
        # Jars are the canonical output of java targets.
        if ext1 == '.jar':
            break
        # Normalize all mojo related outputs to 'mojo'.
        if output.count('.mojom') > 0:
            extension = 'mojo'
            break
    return extension
|
||||
|
||||
|
||||
def SummarizeEntries(entries, extra_step_types):
    """Print a summary of the passed in list of Target objects.

    Every target is assigned a "weighted duration": for each interval between
    two consecutive start/stop events, the elapsed time is divided evenly
    among the tasks running during that interval. The longest steps (by
    weighted duration) are then printed per extension group, followed by
    overall totals.

    Args:
        entries: list of Target objects. This function reads ``start`` and
            ``end`` and calls ``Duration()``, ``SetWeightedDuration()``,
            ``WeightedDuration()`` and ``DescribeTargets()`` on them.
        extra_step_types: optional semicolon separated pattern string passed
            through to GetExtension() for grouping. May be None.

    Returns:
        None. All results are written to stdout.
    """
    if not entries:
        # Without any build step there is no first event to seed the scan
        # below (indexing the empty event list would raise IndexError).
        print('No build steps found, no build summary created.')
        return

    # Create a list that is in order by time stamp and has entries for the
    # beginning and ending of each build step (one time stamp may have
    # multiple entries due to multiple steps starting/stopping at exactly the
    # same time). Iterate through this list, keeping track of which tasks are
    # running at all times. At each time step calculate a running total for
    # weighted time so that when each task ends its own weighted time can
    # easily be calculated.
    task_start_stop_times = []

    earliest = -1
    latest = 0
    total_cpu_time = 0
    for target in entries:
        if earliest < 0 or target.start < earliest:
            earliest = target.start
        if target.end > latest:
            latest = target.end
        total_cpu_time += target.Duration()
        task_start_stop_times.append((target.start, 'start', target))
        task_start_stop_times.append((target.end, 'stop', target))
    length = latest - earliest
    weighted_total = 0.0

    # Sort by the time/type records and ignore |target|.
    task_start_stop_times.sort(key=lambda times: times[:2])
    # Now we have all task start/stop times sorted by when they happen. If a
    # task starts and stops on the same time stamp then the start will come
    # first because of the alphabet, which is important for making this work
    # correctly.
    # Track the tasks which are currently running.
    running_tasks = {}
    # Record the time we have processed up to so we know how to calculate
    # time deltas.
    last_time = task_start_stop_times[0][0]
    # Track the accumulated weighted time so that it can efficiently be added
    # to individual tasks.
    last_weighted_time = 0.0
    # Scan all start/stop events.
    for event in task_start_stop_times:
        time, action_name, target = event
        # Accumulate weighted time up to now.
        num_running = len(running_tasks)
        if num_running > 0:
            # Update the total weighted time up to this moment.
            last_weighted_time += (time - last_time) / float(num_running)
        if action_name == 'start':
            # Record the total weighted task time when this task starts.
            running_tasks[target] = last_weighted_time
        if action_name == 'stop':
            # Record the change in the total weighted task time while this
            # task ran.
            weighted_duration = last_weighted_time - running_tasks[target]
            target.SetWeightedDuration(weighted_duration)
            weighted_total += weighted_duration
            del running_tasks[target]
        last_time = time
    assert (len(running_tasks) == 0)

    # Warn if the sum of weighted times is off by more than half a second.
    # NOTE(review): the summary lines below label `length` as seconds, in
    # which case 500 is far larger than half a second -- confirm the unit of
    # Target.start/end before tightening this threshold.
    if abs(length - weighted_total) > 500:
        print('Warning: Possible corrupt ninja log, results may be '
              'untrustworthy. Length = %.3f, weighted total = %.3f' %
              (length, weighted_total))

    entries_by_ext = defaultdict(list)
    for target in entries:
        extension = GetExtension(target, extra_step_types)
        entries_by_ext[extension].append(target)

    for key, values in entries_by_ext.items():
        print(' Longest build steps for %s:' % key)
        values.sort(key=lambda x: x.WeightedDuration())
        for target in values[-long_count:]:
            print(' %8.1f weighted s to build %s (%.1f s elapsed time)' %
                  (target.WeightedDuration(), target.DescribeTargets(),
                   target.Duration()))

    # A build whose steps all share one time stamp has zero length; report
    # zero rates instead of raising ZeroDivisionError.
    if length > 0:
        parallelism = total_cpu_time * 1.0 / length
        steps_per_second = len(entries) / (length)
    else:
        parallelism = 0.0
        steps_per_second = 0.0

    print(' %.1f s weighted time (%.1f s elapsed time sum, %1.1fx '
          'parallelism)' % (length, total_cpu_time, parallelism))
    print(' %d build steps completed, average of %1.2f/s' %
          (len(entries), steps_per_second))
|
||||
|
||||
|
||||
def main():
    """Parse the command line, then read and summarize a ninja log.

    The log path is taken from --log-file when given, otherwise it is
    '.ninja_log' inside the -C build directory, otherwise '.ninja_log' in the
    current directory. Unrecognized command line arguments are ignored.

    Returns:
        errno.ENOENT when an IOError is raised while opening or processing
        the log; None otherwise.
    """
    default_log_name = '.ninja_log'

    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('-C',
                            dest='build_directory',
                            help='Build directory.')
    arg_parser.add_argument(
        '-s',
        '--step-types',
        help='semicolon separated fnmatch patterns for build-step grouping')
    arg_parser.add_argument('--log-file',
                            help="specific ninja log file to analyze.")
    options, _ignored = arg_parser.parse_known_args()

    # An explicit --log-file wins over the build directory, which in turn
    # wins over the default location.
    if options.log_file:
        ninja_log_path = options.log_file
    elif options.build_directory:
        ninja_log_path = os.path.join(options.build_directory,
                                      default_log_name)
    else:
        ninja_log_path = default_log_name

    if options.step_types:
        # Make room for the extra build types.
        global long_ext_count
        long_ext_count += len(options.step_types.split(';'))

    try:
        with open(ninja_log_path, 'r') as log_stream:
            parsed_entries = ReadTargets(log_stream, False)
            SummarizeEntries(parsed_entries, options.step_types)
    except IOError:
        print('Log file %r not found, no build summary created.' %
              ninja_log_path)
        return errno.ENOENT
|
||||
|
||||
|
||||
# Script entry point: run main() only when executed directly and use its
# return value as the process exit status.
if __name__ == '__main__':
    sys.exit(main())
|
||||
@ -32,6 +32,15 @@ def hint_on_error(fn):
|
||||
def wrapper(*args, **kwargs):
|
||||
try:
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
except NotImplementedError as e:
|
||||
msg = (
|
||||
"Error in calling custom op %s: %s\n"
|
||||
"Not implemented or built, mostly likely because the current current device "
|
||||
"does not support this kernel (less likely TORCH_CUDA_ARCH_LIST was set "
|
||||
"incorrectly while building)")
|
||||
logger.error(msg, fn.__name__, e)
|
||||
raise NotImplementedError(msg % (fn.__name__, e)) from e
|
||||
except AttributeError as e:
|
||||
msg = (
|
||||
"Error in calling custom op %s: %s\n"
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user