mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-08 11:55:43 +08:00
[CI/Build] Per file CUDA Archs (improve wheel size and dev build times) (#8845)
This commit is contained in:
parent
3dbb215b38
commit
aeb37c2a72
222
CMakeLists.txt
222
CMakeLists.txt
@ -143,6 +143,19 @@ else()
|
|||||||
message(FATAL_ERROR "Can't find CUDA or HIP installation.")
|
message(FATAL_ERROR "Can't find CUDA or HIP installation.")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
|
||||||
|
#
|
||||||
|
# For cuda we want to be able to control which architectures we compile for on
|
||||||
|
# a per-file basis in order to cut down on compile time. So here we extract
|
||||||
|
# the set of architectures we want to compile for and remove them from the
|
||||||
|
# CMAKE_CUDA_FLAGS so that they are not applied globally.
|
||||||
|
#
|
||||||
|
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||||
|
clear_cuda_arches(CUDA_ARCH_FLAGS)
|
||||||
|
extract_unique_cuda_archs_ascending(CUDA_ARCHS "${CUDA_ARCH_FLAGS}")
|
||||||
|
message(STATUS "CUDA target architectures: ${CUDA_ARCHS}")
|
||||||
|
endif()
|
||||||
|
|
||||||
#
|
#
|
||||||
# Override the GPU architectures detected by cmake/torch and filter them by
|
# Override the GPU architectures detected by cmake/torch and filter them by
|
||||||
# the supported versions for the current language.
|
# the supported versions for the current language.
|
||||||
@ -223,30 +236,89 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
|||||||
"csrc/mamba/causal_conv1d/causal_conv1d.cu"
|
"csrc/mamba/causal_conv1d/causal_conv1d.cu"
|
||||||
"csrc/quantization/aqlm/gemm_kernels.cu"
|
"csrc/quantization/aqlm/gemm_kernels.cu"
|
||||||
"csrc/quantization/awq/gemm_kernels.cu"
|
"csrc/quantization/awq/gemm_kernels.cu"
|
||||||
"csrc/quantization/marlin/dense/marlin_cuda_kernel.cu"
|
|
||||||
"csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
|
|
||||||
"csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu"
|
|
||||||
"csrc/quantization/gptq_marlin/gptq_marlin.cu"
|
|
||||||
"csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
|
|
||||||
"csrc/quantization/gptq_marlin/awq_marlin_repack.cu"
|
|
||||||
"csrc/quantization/gguf/gguf_kernel.cu"
|
"csrc/quantization/gguf/gguf_kernel.cu"
|
||||||
"csrc/quantization/fp8/fp8_marlin.cu"
|
|
||||||
"csrc/custom_all_reduce.cu"
|
"csrc/custom_all_reduce.cu"
|
||||||
"csrc/permute_cols.cu"
|
"csrc/permute_cols.cu"
|
||||||
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
|
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu")
|
||||||
"csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu"
|
|
||||||
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu")
|
set_gencode_flags_for_srcs(
|
||||||
|
SRCS "${VLLM_EXT_SRC}"
|
||||||
|
CUDA_ARCHS "${CUDA_ARCHS}")
|
||||||
|
|
||||||
|
# Only build Marlin kernels if we are building for at least some compatible archs.
# Keep building Marlin for 9.0 as there are some group sizes and shapes that
# are not supported by Machete yet.
#
# `CUDA_ARCHS` is passed quoted so that the whole list arrives as the single
# target-archs argument (and so an empty list is still a valid call).
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.6;8.9;9.0" "${CUDA_ARCHS}")
if (MARLIN_ARCHS)
  set(MARLIN_SRCS
    "csrc/quantization/fp8/fp8_marlin.cu"
    "csrc/quantization/marlin/dense/marlin_cuda_kernel.cu"
    "csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
    "csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu"
    "csrc/quantization/gptq_marlin/gptq_marlin.cu"
    "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
    "csrc/quantization/gptq_marlin/awq_marlin_repack.cu")
  # Compile the Marlin sources only for the archs they support.
  set_gencode_flags_for_srcs(
    SRCS "${MARLIN_SRCS}"
    CUDA_ARCHS "${MARLIN_ARCHS}")
  list(APPEND VLLM_EXT_SRC "${MARLIN_SRCS}")
  message(STATUS "Building Marlin kernels for archs: ${MARLIN_ARCHS}")
else()
  # message() concatenates its arguments without a separator, so the leading
  # space on the second fragment is required.
  message(STATUS "Not building Marlin kernels as no compatible archs found"
                 " in CUDA target architectures")
endif()
|
||||||
|
|
||||||
#
|
#
|
||||||
# The CUTLASS kernels for Hopper require sm90a to be enabled.
|
# The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
|
||||||
# This is done via the below gencode option, BUT that creates kernels for both sm90 and sm90a.
|
# CUDA 12.0 or later (and only work on Hopper, 9.0/9.0a for now).
|
||||||
# That adds an extra 17MB to compiled binary, so instead we selectively enable it.
|
cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0;9.0a" "${CUDA_ARCHS}")
|
||||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0)
|
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
|
||||||
set_source_files_properties(
|
set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu")
|
||||||
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu"
|
set_gencode_flags_for_srcs(
|
||||||
PROPERTIES
|
SRCS "${SRCS}"
|
||||||
COMPILE_FLAGS
|
CUDA_ARCHS "${SCALED_MM_3X_ARCHS}")
|
||||||
"-gencode arch=compute_90a,code=sm_90a")
|
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
||||||
|
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_C3X=1")
|
||||||
|
message(STATUS "Building scaled_mm_c3x for archs: ${SCALED_MM_3X_ARCHS}")
|
||||||
|
else()
|
||||||
|
# Report why the c3x kernels are skipped. This must happen *before*
# SCALED_MM_3X_ARCHS is cleared: a non-empty list here means compatible archs
# were requested and only the CUDA compiler version ruled the kernels out.
if (NOT "${CMAKE_CUDA_COMPILER_VERSION}" VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
  message(STATUS "Not building scaled_mm_c3x as CUDA Compiler version is "
                 "not >= 12.0, we recommend upgrading to CUDA 12.0 or "
                 "later if you intend on running FP8 quantized models on "
                 "Hopper.")
else()
  message(STATUS "Not building scaled_mm_c3x as no compatible archs found "
                 "in CUDA target architectures")
endif()

# clear SCALED_MM_3X_ARCHS so the scaled_mm_c2x kernels know we didn't
# build any 3x kernels
set(SCALED_MM_3X_ARCHS)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
#
|
||||||
|
# For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x)
|
||||||
|
# kernels for the remaining archs that are not already built for 3x.
|
||||||
|
cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS
|
||||||
|
"7.5;8.0;8.6;8.9;9.0;9.0a" "${CUDA_ARCHS}")
|
||||||
|
# subtract out the archs that are already built for 3x
|
||||||
|
list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS})
|
||||||
|
if (SCALED_MM_2X_ARCHS)
|
||||||
|
set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu")
|
||||||
|
set_gencode_flags_for_srcs(
|
||||||
|
SRCS "${SRCS}"
|
||||||
|
CUDA_ARCHS "${SCALED_MM_2X_ARCHS}")
|
||||||
|
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
||||||
|
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_C2X=1")
|
||||||
|
message(STATUS "Building scaled_mm_c2x for archs: ${SCALED_MM_2X_ARCHS}")
|
||||||
|
else()
|
||||||
|
if (SCALED_MM_3X_ARCHS)
|
||||||
|
message(STATUS "Not building scaled_mm_c2x as all archs are already built"
|
||||||
|
" for and covered by scaled_mm_c3x")
|
||||||
|
else()
|
||||||
|
message(STATUS "Not building scaled_mm_c2x as no compatible archs found "
|
||||||
|
"in CUDA target architectures")
|
||||||
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
|
||||||
@ -254,47 +326,72 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
|||||||
# Machete kernels
|
# Machete kernels
|
||||||
|
|
||||||
# The machete kernels only work on hopper and require CUDA 12.0 or later.
|
# The machete kernels only work on hopper and require CUDA 12.0 or later.
|
||||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0)
|
# Only build Machete kernels if we are building for something compatible with sm90a
|
||||||
|
cuda_archs_loose_intersection(MACHETE_ARCHS "9.0a" "${CUDA_ARCHS}")
|
||||||
|
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND MACHETE_ARCHS)
|
||||||
#
|
#
|
||||||
# For the Machete kernels we automatically generate sources for various
|
# For the Machete kernels we automatically generate sources for various
|
||||||
# preselected input type pairs and schedules.
|
# preselected input type pairs and schedules.
|
||||||
# Generate sources:
|
# Generate sources:
|
||||||
execute_process(
|
set(MACHETE_GEN_SCRIPT
|
||||||
COMMAND ${CMAKE_COMMAND} -E env
|
${CMAKE_CURRENT_SOURCE_DIR}/csrc/quantization/machete/generate.py)
|
||||||
PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/csrc/cutlass_extensions/:${CUTLASS_DIR}/python/:${VLLM_PYTHON_PATH}:$PYTHONPATH
|
file(MD5 ${MACHETE_GEN_SCRIPT} MACHETE_GEN_SCRIPT_HASH)
|
||||||
${Python_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/csrc/quantization/machete/generate.py
|
|
||||||
RESULT_VARIABLE machete_generation_result
|
|
||||||
OUTPUT_VARIABLE machete_generation_output
|
|
||||||
OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log
|
|
||||||
ERROR_FILE ${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log
|
|
||||||
)
|
|
||||||
|
|
||||||
if (NOT machete_generation_result EQUAL 0)
|
message(STATUS "Machete generation script hash: ${MACHETE_GEN_SCRIPT_HASH}")
|
||||||
message(FATAL_ERROR "Machete generation failed."
|
message(STATUS "Last run machete generate script hash: $CACHE{MACHETE_GEN_SCRIPT_HASH}")
|
||||||
" Result: \"${machete_generation_result}\""
|
|
||||||
"\nCheck the log for details: "
|
if (NOT DEFINED CACHE{MACHETE_GEN_SCRIPT_HASH}
|
||||||
"${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log")
|
OR NOT $CACHE{MACHETE_GEN_SCRIPT_HASH} STREQUAL ${MACHETE_GEN_SCRIPT_HASH})
|
||||||
|
execute_process(
|
||||||
|
COMMAND ${CMAKE_COMMAND} -E env
|
||||||
|
PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/csrc/cutlass_extensions/:${CUTLASS_DIR}/python/:${VLLM_PYTHON_PATH}:$PYTHONPATH
|
||||||
|
${Python_EXECUTABLE} ${MACHETE_GEN_SCRIPT}
|
||||||
|
RESULT_VARIABLE machete_generation_result
|
||||||
|
OUTPUT_VARIABLE machete_generation_output
|
||||||
|
OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log
|
||||||
|
ERROR_FILE ${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log
|
||||||
|
)
|
||||||
|
|
||||||
|
if (NOT machete_generation_result EQUAL 0)
|
||||||
|
message(FATAL_ERROR "Machete generation failed."
|
||||||
|
" Result: \"${machete_generation_result}\""
|
||||||
|
"\nCheck the log for details: "
|
||||||
|
"${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log")
|
||||||
|
else()
|
||||||
|
set(MACHETE_GEN_SCRIPT_HASH ${MACHETE_GEN_SCRIPT_HASH}
|
||||||
|
CACHE STRING "Last run machete generate script hash" FORCE)
|
||||||
|
message(STATUS "Machete generation completed successfully.")
|
||||||
|
endif()
|
||||||
else()
|
else()
|
||||||
message(STATUS "Machete generation completed successfully.")
|
message(STATUS "Machete generation script has not changed, skipping generation.")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
# Add machete generated sources
|
# Add machete generated sources
|
||||||
file(GLOB MACHETE_GEN_SOURCES "csrc/quantization/machete/generated/*.cu")
|
file(GLOB MACHETE_GEN_SOURCES "csrc/quantization/machete/generated/*.cu")
|
||||||
list(APPEND VLLM_EXT_SRC ${MACHETE_GEN_SOURCES})
|
list(APPEND VLLM_EXT_SRC ${MACHETE_GEN_SOURCES})
|
||||||
message(STATUS "Machete generated sources: ${MACHETE_GEN_SOURCES}")
|
|
||||||
|
|
||||||
set_source_files_properties(
|
# forward compatible
|
||||||
${MACHETE_GEN_SOURCES}
|
set_gencode_flags_for_srcs(
|
||||||
PROPERTIES
|
SRCS "${MACHETE_GEN_SOURCES}"
|
||||||
COMPILE_FLAGS
|
CUDA_ARCHS "${MACHETE_ARCHS}")
|
||||||
"-gencode arch=compute_90a,code=sm_90a")
|
|
||||||
|
list(APPEND VLLM_EXT_SRC
|
||||||
|
csrc/quantization/machete/machete_pytorch.cu)
|
||||||
|
|
||||||
|
message(STATUS "Building Machete kernels for archs: ${MACHETE_ARCHS}")
|
||||||
|
else()
|
||||||
|
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0
|
||||||
|
AND MACHETE_ARCHS)
|
||||||
|
message(STATUS "Not building Machete kernels as CUDA Compiler version is "
|
||||||
|
"not >= 12.0, we recommend upgrading to CUDA 12.0 or "
|
||||||
|
"later if you intend on running w4a16 quantized models on "
|
||||||
|
"Hopper.")
|
||||||
|
else()
|
||||||
|
message(STATUS "Not building Machete kernels as no compatible archs "
|
||||||
|
"found in CUDA target architectures")
|
||||||
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
# if CUDA endif
|
||||||
# Add pytorch binding for machete (add on even CUDA < 12.0 so that we can
|
|
||||||
# raise an error if the user that this was built with an incompatible
|
|
||||||
# CUDA version)
|
|
||||||
list(APPEND VLLM_EXT_SRC
|
|
||||||
csrc/quantization/machete/machete_pytorch.cu)
|
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
message(STATUS "Enabling C extension.")
|
message(STATUS "Enabling C extension.")
|
||||||
@ -323,14 +420,31 @@ set(VLLM_MOE_EXT_SRC
|
|||||||
"csrc/moe/torch_bindings.cpp"
|
"csrc/moe/torch_bindings.cpp"
|
||||||
"csrc/moe/topk_softmax_kernels.cu")
|
"csrc/moe/topk_softmax_kernels.cu")
|
||||||
|
|
||||||
|
set_gencode_flags_for_srcs(
|
||||||
|
SRCS "${VLLM_MOE_EXT_SRC}"
|
||||||
|
CUDA_ARCHS "${CUDA_ARCHS}")
|
||||||
|
|
||||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||||
list(APPEND VLLM_MOE_EXT_SRC
|
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.6;8.9;9.0" "${CUDA_ARCHS}")
|
||||||
"csrc/moe/marlin_kernels/marlin_moe_kernel.h"
|
if (MARLIN_MOE_ARCHS)
|
||||||
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h"
|
set(MARLIN_MOE_SRC
|
||||||
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu"
|
"csrc/moe/marlin_kernels/marlin_moe_kernel.h"
|
||||||
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h"
|
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h"
|
||||||
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu"
|
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu"
|
||||||
"csrc/moe/marlin_moe_ops.cu")
|
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h"
|
||||||
|
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu"
|
||||||
|
"csrc/moe/marlin_moe_ops.cu")
|
||||||
|
|
||||||
|
set_gencode_flags_for_srcs(
|
||||||
|
SRCS "${MARLIN_MOE_SRC}"
|
||||||
|
CUDA_ARCHS "${MARLIN_MOE_ARCHS}")
|
||||||
|
|
||||||
|
list(APPEND VLLM_MOE_EXT_SRC "${MARLIN_MOE_SRC}")
|
||||||
|
message(STATUS "Building Marlin MOE kernels for archs: ${MARLIN_MOE_ARCHS}")
|
||||||
|
else()
|
||||||
|
# message() joins its arguments with no separator, hence the leading space.
message(STATUS "Not building Marlin MOE kernels as no compatible archs found"
               " in CUDA target architectures")
|
||||||
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
message(STATUS "Enabling moe extension.")
|
message(STATUS "Enabling moe extension.")
|
||||||
|
|||||||
@ -133,10 +133,181 @@ macro(string_to_ver OUT_VER IN_STR)
|
|||||||
string(REGEX REPLACE "\([0-9]+\)\([0-9]\)" "\\1.\\2" ${OUT_VER} ${IN_STR})
|
string(REGEX REPLACE "\([0-9]+\)\([0-9]\)" "\\1.\\2" ${OUT_VER} ${IN_STR})
|
||||||
endmacro()
|
endmacro()
|
||||||
|
|
||||||
|
#
# Clear all `-gencode` flags from `CMAKE_CUDA_FLAGS` and store them, as a list,
# in the variable whose name is passed as `CUDA_ARCH_FLAGS`.
#
# Example:
#   CMAKE_CUDA_FLAGS="-Wall -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75"
#   clear_cuda_arches(CUDA_ARCH_FLAGS)
#   CUDA_ARCH_FLAGS="-gencode arch=compute_70,code=sm_70;-gencode arch=compute_75,code=sm_75"
#   CMAKE_CUDA_FLAGS="-Wall " (the non-gencode flags are left as they were)
#
# Note: this is a macro because it rewrites `CMAKE_CUDA_FLAGS` in the caller's
# scope.
#
macro(clear_cuda_arches CUDA_ARCH_FLAGS)
  # Extract all `-gencode` flags from `CMAKE_CUDA_FLAGS`.
  # `${CUDA_ARCH_FLAGS}` expands to the output variable name given by the
  # caller; the input is quoted so that an empty `CMAKE_CUDA_FLAGS` is still a
  # valid (empty) input string rather than a missing argument.
  string(REGEX MATCHALL "-gencode arch=[^ ]+" ${CUDA_ARCH_FLAGS}
    "${CMAKE_CUDA_FLAGS}")

  # Remove all `-gencode` flags from `CMAKE_CUDA_FLAGS` so that they are not
  # applied globally; the extracted flags are meant to be re-applied on a
  # per-source basis (see `set_gencode_flags_for_srcs`).
  string(REGEX REPLACE "-gencode arch=[^ ]+ *" "" CMAKE_CUDA_FLAGS
    "${CMAKE_CUDA_FLAGS}")
endmacro()
|
||||||
|
|
||||||
|
#
# Extract the unique CUDA architectures from a list of `-gencode` flags.
# The `compute_<major><minor>[a]` part of each flag is converted to the form
# `<major>.<minor>[a]` (via `string_to_ver`); the results are de-duplicated,
# sorted in ascending (natural) order and stored in `OUT_ARCHES`.
#
# Arguments:
#   OUT_ARCHES: name of the variable (in the caller's scope) to set
#   CUDA_ARCH_FLAGS: the list of `-gencode` flags itself (not a variable name)
#
# Example:
#   CUDA_ARCH_FLAGS="-gencode arch=compute_75,code=sm_75;...;-gencode arch=compute_90a,code=sm_90a"
#   extract_unique_cuda_archs_ascending(OUT_ARCHES "${CUDA_ARCH_FLAGS}")
#   OUT_ARCHES="7.5;...;9.0a"
#
function(extract_unique_cuda_archs_ascending OUT_ARCHES CUDA_ARCH_FLAGS)
  set(_CUDA_ARCHES)
  foreach(_ARCH ${CUDA_ARCH_FLAGS})
    # Note: if the regex matches then `CMAKE_MATCH_1` holds the first group,
    # i.e. the bare `<major><minor>[a]` code.
    string(REGEX MATCH "arch=compute_\([0-9]+a?\)" _COMPUTE "${_ARCH}")
    if (NOT _COMPUTE)
      # Fail with a readable message instead of calling `string_to_ver` with a
      # missing argument.
      message(FATAL_ERROR
        "Could not determine virtual architecture from: ${_ARCH}.")
    endif()
    set(_COMPUTE "${CMAKE_MATCH_1}")

    string_to_ver(_COMPUTE_VER ${_COMPUTE})
    list(APPEND _CUDA_ARCHES ${_COMPUTE_VER})
  endforeach()

  # Nothing to de-duplicate or sort when no flags were given.
  if (_CUDA_ARCHES)
    list(REMOVE_DUPLICATES _CUDA_ARCHES)
    list(SORT _CUDA_ARCHES COMPARE NATURAL ORDER ASCENDING)
  endif()
  set(${OUT_ARCHES} ${_CUDA_ARCHES} PARENT_SCOPE)
endfunction()
|
||||||
|
|
||||||
|
#
# Append one `-gencode arch=<ARCH>,code=<CODE>` pair to the `COMPILE_OPTIONS`
# source property of the given files. The option is wrapped in a
# `COMPILE_LANGUAGE:CUDA` generator expression, so it only takes effect when
# the file is compiled as CUDA.
#
# Keyword arguments:
#   SRCS: the source files to update
#   ARCH: value placed after `arch=` (e.g. `compute_75`)
#   CODE: value placed after `code=` (e.g. `sm_75`)
#
# Example:
#   set_gencode_flag_for_srcs(
#     SRCS "foo.cu"
#     ARCH "compute_75"
#     CODE "sm_75")
#   appends "-gencode arch=compute_75,code=sm_75" to the compile options of
#   `foo.cu`.
#
macro(set_gencode_flag_for_srcs)
  set(multiValueArgs SRCS)
  set(oneValueArgs ARCH CODE)
  set(options)
  cmake_parse_arguments(
    arg
    "${options}"
    "${oneValueArgs}"
    "${multiValueArgs}"
    ${ARGN} )

  # A two-element list: the `-gencode` switch and its `arch=...,code=...` value.
  set(_FLAG "-gencode;arch=${arg_ARCH},code=${arg_CODE}")

  # The generator expression is quoted so the `;` stays inside it.
  set_property(SOURCE ${arg_SRCS}
    APPEND
    PROPERTY COMPILE_OPTIONS "$<$<COMPILE_LANGUAGE:CUDA>:${_FLAG}>")

  message(DEBUG "Setting gencode flag for ${arg_SRCS}: ${_FLAG}")
endmacro()
|
||||||
|
|
||||||
|
#
# For a list of source files set the `-gencode` flags in the files specific
#  compile options (specifically for the CUDA language).
#
# arguments are:
#  SRCS: list of source files
#  CUDA_ARCHS: list of CUDA architectures in the form `<major>.<minor>[letter]`
#  BUILD_PTX_FOR_ARCH: optional architecture in the same form. If given, PTX
#    (`code=compute_<arch>`) is additionally built for `BUILD_PTX_FOR_ARCH`
#    when the highest architecture in CUDA_ARCHS is greater than or equal to
#    `BUILD_PTX_FOR_ARCH`.
#
# Note: this is a macro, so the `arg_*` variables it parses live in the
# caller's scope.
#
macro(set_gencode_flags_for_srcs)
  set(options)
  set(oneValueArgs BUILD_PTX_FOR_ARCH)
  set(multiValueArgs SRCS CUDA_ARCHS)
  cmake_parse_arguments(arg "${options}" "${oneValueArgs}"
                        "${multiValueArgs}" ${ARGN} )

  # One native-code (`sm_XY`) gencode pair per requested architecture.
  foreach(_ARCH ${arg_CUDA_ARCHS})
    # `8.6` -> `86`, `9.0a` -> `90a`
    string(REPLACE "." "" _ARCH "${_ARCH}")
    set_gencode_flag_for_srcs(
      SRCS ${arg_SRCS}
      ARCH "compute_${_ARCH}"
      CODE "sm_${_ARCH}")
  endforeach()

  # Test the variable itself rather than its expanded value: an expanded value
  # such as `9.0a` would be looked up as a variable name and evaluate to false.
  # Also skip when there are no architectures, since there is no "highest"
  # architecture to compare against.
  if (arg_BUILD_PTX_FOR_ARCH AND arg_CUDA_ARCHS)
    list(SORT arg_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
    list(GET arg_CUDA_ARCHS -1 _HIGHEST_ARCH)
    if (_HIGHEST_ARCH VERSION_GREATER_EQUAL "${arg_BUILD_PTX_FOR_ARCH}")
      string(REPLACE "." "" _PTX_ARCH "${arg_BUILD_PTX_FOR_ARCH}")
      # `code=compute_XY` asks for PTX rather than native code.
      set_gencode_flag_for_srcs(
        SRCS ${arg_SRCS}
        ARCH "compute_${_PTX_ARCH}"
        CODE "compute_${_PTX_ARCH}")
    endif()
  endif()
endmacro()
|
||||||
|
|
||||||
|
#
# For the given `SRC_CUDA_ARCHS` list of gencode versions in the form
#  `<major>.<minor>[letter]` compute the "loose intersection" with the
#  `TGT_CUDA_ARCHS` list of gencodes.
# The loose intersection is defined as:
#   { max{ x \in src | x <= y } | y \in tgt, { x \in src | x <= y } != {} }
#  where `<=` is the version comparison operator.
# In other words, for each version in `TGT_CUDA_ARCHS` find the highest version
#  in `SRC_CUDA_ARCHS` that is less or equal to the version in `TGT_CUDA_ARCHS`.
# We have special handling for 9.0a, if 9.0a is in `SRC_CUDA_ARCHS` and 9.0 is
#  in `TGT_CUDA_ARCHS` then we should remove 9.0a from `SRC_CUDA_ARCHS` and add
#  9.0a to the result.
# The result is stored in `OUT_CUDA_ARCHS`.
#
# Arguments:
#   OUT_CUDA_ARCHS: name of the variable (in the caller's scope) to set
#   SRC_CUDA_ARCHS: `;`-list of architectures the sources support
#   TGT_CUDA_ARCHS: `;`-list of architectures being built for; any further
#     positional arguments are treated as more target architectures, so an
#     unquoted `${LIST}` works as well as a quoted one
#
# Example:
#   SRC_CUDA_ARCHS="7.5;8.0;8.6;9.0;9.0a"
#   TGT_CUDA_ARCHS="8.0;8.9;9.0"
#   cuda_archs_loose_intersection(OUT_CUDA_ARCHS
#     "${SRC_CUDA_ARCHS}" "${TGT_CUDA_ARCHS}")
#   OUT_CUDA_ARCHS="9.0a;8.0;8.6;9.0"
#
function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
  # Gather the target archs from the named parameter plus any extra positional
  # arguments, instead of relying on a variable in the caller's scope.
  set(_TGT_CUDA_ARCHS ${TGT_CUDA_ARCHS} ${ARGN})

  list(REMOVE_DUPLICATES SRC_CUDA_ARCHS)

  # if 9.0a is in SRC_CUDA_ARCHS and 9.0 is in the target archs then we should
  # remove 9.0a from SRC_CUDA_ARCHS and add 9.0a to _CUDA_ARCHS
  set(_CUDA_ARCHS)
  if ("9.0a" IN_LIST SRC_CUDA_ARCHS)
    list(REMOVE_ITEM SRC_CUDA_ARCHS "9.0a")
    if ("9.0" IN_LIST _TGT_CUDA_ARCHS)
      set(_CUDA_ARCHS "9.0a")
    endif()
  endif()

  # Ascending order is what allows the inner loop below to stop at the first
  # source arch that is too new.
  list(SORT SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)

  # for each ARCH in the target archs find the highest arch in SRC_CUDA_ARCHS
  # that is less or equal to ARCH
  foreach(_ARCH ${_TGT_CUDA_ARCHS})
    set(_TMP_ARCH)
    foreach(_SRC_ARCH ${SRC_CUDA_ARCHS})
      if (_SRC_ARCH VERSION_LESS_EQUAL _ARCH)
        set(_TMP_ARCH ${_SRC_ARCH})
      else()
        break()
      endif()
    endforeach()
    if (_TMP_ARCH)
      list(APPEND _CUDA_ARCHS ${_TMP_ARCH})
    endif()
  endforeach()

  # Nothing to de-duplicate when no compatible arch was found.
  if (_CUDA_ARCHS)
    list(REMOVE_DUPLICATES _CUDA_ARCHS)
  endif()
  set(${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE)
endfunction()
|
||||||
|
|
||||||
#
|
#
|
||||||
# Override the GPU architectures detected by cmake/torch and filter them by
|
# Override the GPU architectures detected by cmake/torch and filter them by
|
||||||
# `GPU_SUPPORTED_ARCHES`. Sets the final set of architectures in
|
# `GPU_SUPPORTED_ARCHES`. Sets the final set of architectures in
|
||||||
# `GPU_ARCHES`.
|
# `GPU_ARCHES`. This only applies to the HIP language since for CUDA we set
|
||||||
|
# the architectures on a per file basis.
|
||||||
#
|
#
|
||||||
# Note: this is defined as a macro since it updates `CMAKE_CUDA_FLAGS`.
|
# Note: this is defined as a macro since it updates `CMAKE_CUDA_FLAGS`.
|
||||||
#
|
#
|
||||||
@ -174,109 +345,7 @@ macro(override_gpu_arches GPU_ARCHES GPU_LANG GPU_SUPPORTED_ARCHES)
|
|||||||
"None of the detected ROCm architectures: ${HIP_ARCHITECTURES} is"
|
"None of the detected ROCm architectures: ${HIP_ARCHITECTURES} is"
|
||||||
" supported. Supported ROCm architectures are: ${_GPU_SUPPORTED_ARCHES_LIST}.")
|
" supported. Supported ROCm architectures are: ${_GPU_SUPPORTED_ARCHES_LIST}.")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
elseif(${GPU_LANG} STREQUAL "CUDA")
|
|
||||||
#
|
|
||||||
# Setup/process CUDA arch flags.
|
|
||||||
#
|
|
||||||
# The torch cmake setup hardcodes the detected architecture flags in
|
|
||||||
# `CMAKE_CUDA_FLAGS`. Since `CMAKE_CUDA_FLAGS` is a "global" variable, it
|
|
||||||
# can't modified on a per-target basis.
|
|
||||||
# So, all the `-gencode` flags need to be extracted and removed from
|
|
||||||
# `CMAKE_CUDA_FLAGS` for processing so they can be passed by another method.
|
|
||||||
# Since it's not possible to use `target_compiler_options` for adding target
|
|
||||||
# specific `-gencode` arguments, the target's `CUDA_ARCHITECTURES` property
|
|
||||||
# must be used instead. This requires repackaging the architecture flags
|
|
||||||
# into a format that cmake expects for `CUDA_ARCHITECTURES`.
|
|
||||||
#
|
|
||||||
# This is a bit fragile in that it depends on torch using `-gencode` as opposed
|
|
||||||
# to one of the other nvcc options to specify architectures.
|
|
||||||
#
|
|
||||||
# Note: torch uses the `TORCH_CUDA_ARCH_LIST` environment variable to override
|
|
||||||
# detected architectures.
|
|
||||||
#
|
|
||||||
message(DEBUG "initial CMAKE_CUDA_FLAGS: ${CMAKE_CUDA_FLAGS}")
|
|
||||||
|
|
||||||
# Extract all `-gencode` flags from `CMAKE_CUDA_FLAGS`
|
|
||||||
string(REGEX MATCHALL "-gencode arch=[^ ]+" _CUDA_ARCH_FLAGS
|
|
||||||
${CMAKE_CUDA_FLAGS})
|
|
||||||
|
|
||||||
# Remove all `-gencode` flags from `CMAKE_CUDA_FLAGS` since they will be modified
|
|
||||||
# and passed back via the `CUDA_ARCHITECTURES` property.
|
|
||||||
string(REGEX REPLACE "-gencode arch=[^ ]+ *" "" CMAKE_CUDA_FLAGS
|
|
||||||
${CMAKE_CUDA_FLAGS})
|
|
||||||
|
|
||||||
# If this error is triggered, it might mean that torch has changed how it sets
|
|
||||||
# up nvcc architecture code generation flags.
|
|
||||||
if (NOT _CUDA_ARCH_FLAGS)
|
|
||||||
message(FATAL_ERROR
|
|
||||||
"Could not find any architecture related code generation flags in "
|
|
||||||
"CMAKE_CUDA_FLAGS. (${CMAKE_CUDA_FLAGS})")
|
|
||||||
endif()
|
|
||||||
|
|
||||||
message(DEBUG "final CMAKE_CUDA_FLAGS: ${CMAKE_CUDA_FLAGS}")
|
|
||||||
message(DEBUG "arch flags: ${_CUDA_ARCH_FLAGS}")
|
|
||||||
|
|
||||||
# Initialize the architecture lists to empty.
|
|
||||||
set(${GPU_ARCHES})
|
|
||||||
|
|
||||||
# Process each `gencode` flag.
|
|
||||||
foreach(_ARCH ${_CUDA_ARCH_FLAGS})
|
|
||||||
# For each flag, extract the version number and whether it refers to PTX
|
|
||||||
# or native code.
|
|
||||||
# Note: if a regex matches then `CMAKE_MATCH_1` holds the binding
|
|
||||||
# for that match.
|
|
||||||
|
|
||||||
string(REGEX MATCH "arch=compute_\([0-9]+a?\)" _COMPUTE ${_ARCH})
|
|
||||||
if (_COMPUTE)
|
|
||||||
set(_COMPUTE ${CMAKE_MATCH_1})
|
|
||||||
endif()
|
|
||||||
|
|
||||||
string(REGEX MATCH "code=sm_\([0-9]+a?\)" _SM ${_ARCH})
|
|
||||||
if (_SM)
|
|
||||||
set(_SM ${CMAKE_MATCH_1})
|
|
||||||
endif()
|
|
||||||
|
|
||||||
string(REGEX MATCH "code=compute_\([0-9]+a?\)" _CODE ${_ARCH})
|
|
||||||
if (_CODE)
|
|
||||||
set(_CODE ${CMAKE_MATCH_1})
|
|
||||||
endif()
|
|
||||||
|
|
||||||
# Make sure the virtual architecture can be matched.
|
|
||||||
if (NOT _COMPUTE)
|
|
||||||
message(FATAL_ERROR
|
|
||||||
"Could not determine virtual architecture from: ${_ARCH}.")
|
|
||||||
endif()
|
|
||||||
|
|
||||||
# One of sm_ or compute_ must exist.
|
|
||||||
if ((NOT _SM) AND (NOT _CODE))
|
|
||||||
message(FATAL_ERROR
|
|
||||||
"Could not determine a codegen architecture from: ${_ARCH}.")
|
|
||||||
endif()
|
|
||||||
|
|
||||||
if (_SM)
|
|
||||||
# -real suffix let CMake to only generate elf code for the kernels.
|
|
||||||
# we want this, otherwise the added ptx (default) will increase binary size.
|
|
||||||
set(_VIRT "-real")
|
|
||||||
set(_CODE_ARCH ${_SM})
|
|
||||||
else()
|
|
||||||
# -virtual suffix let CMake to generate ptx code for the kernels.
|
|
||||||
set(_VIRT "-virtual")
|
|
||||||
set(_CODE_ARCH ${_CODE})
|
|
||||||
endif()
|
|
||||||
|
|
||||||
# Check if the current version is in the supported arch list.
|
|
||||||
string_to_ver(_CODE_VER ${_CODE_ARCH})
|
|
||||||
if (NOT _CODE_VER IN_LIST _GPU_SUPPORTED_ARCHES_LIST)
|
|
||||||
message(STATUS "discarding unsupported CUDA arch ${_VER}.")
|
|
||||||
continue()
|
|
||||||
endif()
|
|
||||||
|
|
||||||
# Add it to the arch list.
|
|
||||||
list(APPEND ${GPU_ARCHES} "${_CODE_ARCH}${_VIRT}")
|
|
||||||
endforeach()
|
|
||||||
endif()
|
endif()
|
||||||
message(STATUS "${GPU_LANG} target arches: ${${GPU_ARCHES}}")
|
|
||||||
endmacro()
|
endmacro()
|
||||||
|
|
||||||
#
|
#
|
||||||
|
|||||||
@ -12,6 +12,11 @@
|
|||||||
// could be a macro instead of a literal token.
|
// could be a macro instead of a literal token.
|
||||||
#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)
|
#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)
|
||||||
|
|
||||||
|
// A version of the TORCH_LIBRARY_IMPL macro that expands the NAME, i.e. so NAME
|
||||||
|
// could be a macro instead of a literal token.
|
||||||
|
#define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) \
|
||||||
|
TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE)
|
||||||
|
|
||||||
// REGISTER_EXTENSION allows the shared library to be loaded and initialized
|
// REGISTER_EXTENSION allows the shared library to be loaded and initialized
|
||||||
// via python's import statement.
|
// via python's import statement.
|
||||||
#define REGISTER_EXTENSION(NAME) \
|
#define REGISTER_EXTENSION(NAME) \
|
||||||
|
|||||||
@ -27,6 +27,7 @@
|
|||||||
|
|
||||||
#include "core/exception.hpp"
|
#include "core/exception.hpp"
|
||||||
#include "core/scalar_type.hpp"
|
#include "core/scalar_type.hpp"
|
||||||
|
#include "core/registration.h"
|
||||||
#include "marlin_kernels/marlin_moe_kernel_ku4b8.h"
|
#include "marlin_kernels/marlin_moe_kernel_ku4b8.h"
|
||||||
#include "marlin_kernels/marlin_moe_kernel_ku8b128.h"
|
#include "marlin_kernels/marlin_moe_kernel_ku8b128.h"
|
||||||
|
|
||||||
@ -552,3 +553,7 @@ torch::Tensor marlin_gemm_moe(
|
|||||||
thread_n, sms, max_par, replicate_input, apply_weights);
|
thread_n, sms, max_par, replicate_input, apply_weights);
|
||||||
return c;
|
return c;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||||
|
m.impl("marlin_gemm_moe", &marlin_gemm_moe);
|
||||||
|
}
|
||||||
|
|||||||
@ -1,15 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include <torch/all.h>
|
|
||||||
|
|
||||||
#include "core/scalar_type.hpp"
|
|
||||||
|
|
||||||
torch::Tensor marlin_gemm_moe(
|
|
||||||
const torch::Tensor& a, const torch::Tensor& b_q_weights,
|
|
||||||
const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights,
|
|
||||||
const torch::Tensor& topk_ids, const torch::Tensor& b_scales,
|
|
||||||
const torch::Tensor& g_idx, const torch::Tensor& perm,
|
|
||||||
torch::Tensor& workspace, vllm::ScalarTypeTorchPtr const& b_q_type,
|
|
||||||
int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full,
|
|
||||||
int64_t num_experts, int64_t topk, int64_t moe_block_size,
|
|
||||||
bool replicate_input, bool apply_weights);
|
|
||||||
@ -1,6 +1,5 @@
|
|||||||
#include "core/registration.h"
|
#include "core/registration.h"
|
||||||
#include "moe_ops.h"
|
#include "moe_ops.h"
|
||||||
#include "marlin_moe_ops.h"
|
|
||||||
|
|
||||||
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
|
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
|
||||||
// Apply topk softmax to the gating outputs.
|
// Apply topk softmax to the gating outputs.
|
||||||
@ -18,7 +17,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
|
|||||||
"int size_n, int size_k, bool is_k_full, int num_experts, int topk, "
|
"int size_n, int size_k, bool is_k_full, int num_experts, int topk, "
|
||||||
"int moe_block_size, bool replicate_input, bool apply_weights)"
|
"int moe_block_size, bool replicate_input, bool apply_weights)"
|
||||||
" -> Tensor");
|
" -> Tensor");
|
||||||
m.impl("marlin_gemm_moe", torch::kCUDA, &marlin_gemm_moe);
|
// conditionally compiled so impl registration is in source file
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
68
csrc/ops.h
68
csrc/ops.h
@ -90,63 +90,8 @@ torch::Tensor awq_dequantize(torch::Tensor _kernel,
|
|||||||
torch::Tensor _zeros, int64_t split_k_iters,
|
torch::Tensor _zeros, int64_t split_k_iters,
|
||||||
int64_t thx, int64_t thy);
|
int64_t thx, int64_t thy);
|
||||||
|
|
||||||
torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
|
||||||
torch::Tensor& b_scales, torch::Tensor& workspace,
|
|
||||||
int64_t size_m, int64_t size_n, int64_t size_k);
|
|
||||||
|
|
||||||
namespace machete {
|
|
||||||
|
|
||||||
std::vector<std::string> supported_schedules(
|
|
||||||
vllm::ScalarTypeTorchPtr const& btype);
|
|
||||||
|
|
||||||
torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
|
|
||||||
vllm::ScalarTypeTorchPtr const& btype,
|
|
||||||
c10::optional<torch::Tensor> const& scales,
|
|
||||||
c10::optional<torch::Tensor> const& zeros,
|
|
||||||
c10::optional<int64_t> group_size,
|
|
||||||
c10::optional<torch::Tensor> const& C,
|
|
||||||
c10::optional<double> alpha, c10::optional<double> beta,
|
|
||||||
c10::optional<std::string> schedule);
|
|
||||||
|
|
||||||
torch::Tensor prepack_B(torch::Tensor const& B,
|
|
||||||
vllm::ScalarTypeTorchPtr const& btype);
|
|
||||||
|
|
||||||
}; // namespace machete
|
|
||||||
|
|
||||||
torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm);
|
torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm);
|
||||||
|
|
||||||
torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
|
||||||
torch::Tensor& b_meta,
|
|
||||||
torch::Tensor& b_scales,
|
|
||||||
torch::Tensor& workspace,
|
|
||||||
vllm::ScalarTypeTorchPtr const& b_q_type,
|
|
||||||
int64_t size_m, int64_t size_n,
|
|
||||||
int64_t size_k);
|
|
||||||
|
|
||||||
torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
|
||||||
torch::Tensor& b_scales, torch::Tensor& b_zeros,
|
|
||||||
torch::Tensor& g_idx, torch::Tensor& perm,
|
|
||||||
torch::Tensor& workspace,
|
|
||||||
vllm::ScalarTypeTorchPtr const& b_q_type,
|
|
||||||
int64_t size_m, int64_t size_n, int64_t size_k,
|
|
||||||
bool is_k_full, bool has_zp,
|
|
||||||
bool use_fp32_reduce);
|
|
||||||
|
|
||||||
torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
|
|
||||||
int64_t size_k, int64_t size_n,
|
|
||||||
int64_t num_bits);
|
|
||||||
|
|
||||||
torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight,
|
|
||||||
torch::Tensor& perm, c10::SymInt size_k,
|
|
||||||
c10::SymInt size_n, int64_t num_bits);
|
|
||||||
|
|
||||||
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
|
|
||||||
int64_t size_n, int64_t num_bits);
|
|
||||||
|
|
||||||
torch::Tensor awq_marlin_repack_meta(torch::Tensor& b_q_weight,
|
|
||||||
c10::SymInt size_k, c10::SymInt size_n,
|
|
||||||
int64_t num_bits);
|
|
||||||
|
|
||||||
torch::Tensor ggml_dequantize(torch::Tensor W, int64_t type, int64_t m,
|
torch::Tensor ggml_dequantize(torch::Tensor W, int64_t type, int64_t m,
|
||||||
int64_t n);
|
int64_t n);
|
||||||
|
|
||||||
@ -156,11 +101,6 @@ torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X,
|
|||||||
torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type,
|
torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type,
|
||||||
int64_t row);
|
int64_t row);
|
||||||
|
|
||||||
torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
|
||||||
torch::Tensor& b_scales, torch::Tensor& workspace,
|
|
||||||
int64_t num_bits, int64_t size_m, int64_t size_n,
|
|
||||||
int64_t size_k);
|
|
||||||
|
|
||||||
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
|
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
|
||||||
|
|
||||||
void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
|
void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
|
||||||
@ -175,14 +115,6 @@ void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
|
|||||||
torch::Tensor const& azp_adj,
|
torch::Tensor const& azp_adj,
|
||||||
c10::optional<torch::Tensor> const& azp,
|
c10::optional<torch::Tensor> const& azp,
|
||||||
c10::optional<torch::Tensor> const& bias);
|
c10::optional<torch::Tensor> const& bias);
|
||||||
|
|
||||||
torch::Tensor marlin_qqq_gemm(torch::Tensor const& a,
|
|
||||||
torch::Tensor const& b_q_weight,
|
|
||||||
torch::Tensor const& s_tok,
|
|
||||||
torch::Tensor const& s_ch,
|
|
||||||
torch::Tensor const& s_group,
|
|
||||||
torch::Tensor& workspace, int64_t size_m,
|
|
||||||
int64_t size_n, int64_t size_k);
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
|
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
|
||||||
|
|||||||
@ -21,7 +21,7 @@ void cutlass_scaled_mm_sm89(torch::Tensor& c, torch::Tensor const& a,
|
|||||||
torch::Tensor const& b_scales,
|
torch::Tensor const& b_scales,
|
||||||
c10::optional<torch::Tensor> const& bias);
|
c10::optional<torch::Tensor> const& bias);
|
||||||
|
|
||||||
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
|
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
|
||||||
void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
|
void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
|
||||||
torch::Tensor const& b,
|
torch::Tensor const& b,
|
||||||
torch::Tensor const& a_scales,
|
torch::Tensor const& a_scales,
|
||||||
@ -114,26 +114,39 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
|
|||||||
|
|
||||||
at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
|
at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
|
||||||
int32_t version_num = get_sm_version_num();
|
int32_t version_num = get_sm_version_num();
|
||||||
if (version_num >= 90) {
|
// Hopper
|
||||||
// Hopper
|
|
||||||
|
|
||||||
// Guard against compilation issues for sm90 kernels
|
// Guard against compilation issues for sm90 kernels
|
||||||
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
|
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
|
||||||
|
if (version_num >= 90) {
|
||||||
cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
|
cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
|
||||||
#else
|
return;
|
||||||
cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales, bias);
|
}
|
||||||
#endif
|
#endif
|
||||||
} else if (version_num == 89) {
|
|
||||||
|
#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
|
||||||
|
if (version_num == 89) {
|
||||||
// Ada Lovelace
|
// Ada Lovelace
|
||||||
cutlass_scaled_mm_sm89(c, a, b, a_scales, b_scales, bias);
|
cutlass_scaled_mm_sm89(c, a, b, a_scales, b_scales, bias);
|
||||||
} else if (version_num >= 80) {
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (version_num >= 80) {
|
||||||
// Ampere
|
// Ampere
|
||||||
cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales, bias);
|
cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales, bias);
|
||||||
} else {
|
return;
|
||||||
// Turing
|
|
||||||
TORCH_CHECK(version_num >= 75);
|
|
||||||
cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales, bias);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Turing
|
||||||
|
TORCH_CHECK(version_num >= 75);
|
||||||
|
cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales, bias);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||||
|
false,
|
||||||
|
"No compiled cutlass_scaled_mm for a compute capability less than "
|
||||||
|
"CUDA device capability: ",
|
||||||
|
version_num);
|
||||||
}
|
}
|
||||||
|
|
||||||
void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
|
void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
|
||||||
@ -174,25 +187,38 @@ void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
|
|||||||
"currently bias dtype must match output dtype ", c.dtype());
|
"currently bias dtype must match output dtype ", c.dtype());
|
||||||
|
|
||||||
at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
|
at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
|
||||||
int32_t version_num = get_sm_version_num();
|
|
||||||
if (version_num >= 90) {
|
|
||||||
// Hopper
|
|
||||||
|
|
||||||
// Guard against compilation issues for sm90 kernels
|
int32_t version_num = get_sm_version_num();
|
||||||
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
|
|
||||||
|
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
|
||||||
|
if (version_num >= 90) {
|
||||||
cutlass_scaled_mm_azp_sm90(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
|
cutlass_scaled_mm_azp_sm90(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
|
||||||
#else
|
return;
|
||||||
cutlass_scaled_mm_azp_sm80(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
|
}
|
||||||
#endif
|
#endif
|
||||||
} else if (version_num == 89) {
|
|
||||||
|
#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
|
||||||
|
if (version_num == 89) {
|
||||||
// Ada Lovelace
|
// Ada Lovelace
|
||||||
cutlass_scaled_mm_azp_sm89(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
|
cutlass_scaled_mm_azp_sm89(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
|
||||||
} else if (version_num >= 80) {
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (version_num >= 80) {
|
||||||
// Ampere
|
// Ampere
|
||||||
cutlass_scaled_mm_azp_sm80(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
|
cutlass_scaled_mm_azp_sm80(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
|
||||||
} else {
|
return;
|
||||||
// Turing
|
|
||||||
TORCH_CHECK(version_num >= 75);
|
|
||||||
cutlass_scaled_mm_azp_sm75(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Turing
|
||||||
|
TORCH_CHECK(version_num >= 75);
|
||||||
|
cutlass_scaled_mm_azp_sm75(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
|
||||||
|
return;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||||
|
false,
|
||||||
|
"No compiled cutlass_scaled_mm_azp for a compute capability less than "
|
||||||
|
"CUDA device capability: ",
|
||||||
|
version_num);
|
||||||
}
|
}
|
||||||
@ -22,6 +22,8 @@
|
|||||||
#include "../gptq_marlin/marlin.cuh"
|
#include "../gptq_marlin/marlin.cuh"
|
||||||
#include "../gptq_marlin/marlin_dtypes.cuh"
|
#include "../gptq_marlin/marlin_dtypes.cuh"
|
||||||
|
|
||||||
|
#include "core/registration.h"
|
||||||
|
|
||||||
using namespace marlin;
|
using namespace marlin;
|
||||||
|
|
||||||
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
|
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
|
||||||
@ -1303,3 +1305,7 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
|||||||
}
|
}
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||||
|
m.impl("fp8_marlin_gemm", &fp8_marlin_gemm);
|
||||||
|
}
|
||||||
@ -1,25 +1,6 @@
|
|||||||
#include "marlin.cuh"
|
#include "marlin.cuh"
|
||||||
|
|
||||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
#include "core/registration.h"
|
||||||
|
|
||||||
namespace marlin {
|
|
||||||
|
|
||||||
template <int const num_threads, int const num_bits, bool const has_perm>
|
|
||||||
__global__ void awq_marlin_repack_kernel(
|
|
||||||
uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr,
|
|
||||||
int size_k, int size_n) {}
|
|
||||||
|
|
||||||
} // namespace marlin
|
|
||||||
|
|
||||||
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
|
|
||||||
int64_t size_k, int64_t size_n,
|
|
||||||
int64_t num_bits) {
|
|
||||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
|
||||||
false, "marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0");
|
|
||||||
return torch::empty({1, 1});
|
|
||||||
}
|
|
||||||
|
|
||||||
#else
|
|
||||||
|
|
||||||
namespace marlin {
|
namespace marlin {
|
||||||
|
|
||||||
@ -122,7 +103,7 @@ __global__ void awq_marlin_repack_kernel(
|
|||||||
}
|
}
|
||||||
|
|
||||||
uint32_t vals[8];
|
uint32_t vals[8];
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < 4; i++) {
|
for (int i = 0; i < 4; i++) {
|
||||||
int cur_elem = tc_row + tc_offsets[i];
|
int cur_elem = tc_row + tc_offsets[i];
|
||||||
|
|
||||||
@ -143,7 +124,7 @@ __global__ void awq_marlin_repack_kernel(
|
|||||||
constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
|
constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
|
||||||
|
|
||||||
uint32_t res = 0;
|
uint32_t res = 0;
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < 8; i++) {
|
for (int i = 0; i < 8; i++) {
|
||||||
res |= vals[pack_idx[i]] << (i * 4);
|
res |= vals[pack_idx[i]] << (i * 4);
|
||||||
}
|
}
|
||||||
@ -155,7 +136,7 @@ __global__ void awq_marlin_repack_kernel(
|
|||||||
|
|
||||||
uint32_t res1 = 0;
|
uint32_t res1 = 0;
|
||||||
uint32_t res2 = 0;
|
uint32_t res2 = 0;
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < 4; i++) {
|
for (int i = 0; i < 4; i++) {
|
||||||
res1 |= vals[pack_idx[i]] << (i * 8);
|
res1 |= vals[pack_idx[i]] << (i * 8);
|
||||||
res2 |= vals[4 + pack_idx[i]] << (i * 8);
|
res2 |= vals[4 + pack_idx[i]] << (i * 8);
|
||||||
@ -167,21 +148,21 @@ __global__ void awq_marlin_repack_kernel(
|
|||||||
};
|
};
|
||||||
|
|
||||||
auto start_pipes = [&](int k_tile_id, int n_tile_id) {
|
auto start_pipes = [&](int k_tile_id, int n_tile_id) {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int pipe = 0; pipe < repack_stages - 1; pipe++) {
|
for (int pipe = 0; pipe < repack_stages - 1; pipe++) {
|
||||||
fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe);
|
fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe);
|
||||||
}
|
}
|
||||||
|
|
||||||
wait_for_stage();
|
wait_for_stage();
|
||||||
};
|
};
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) {
|
for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) {
|
||||||
int n_tile_id = 0;
|
int n_tile_id = 0;
|
||||||
|
|
||||||
start_pipes(k_tile_id, n_tile_id);
|
start_pipes(k_tile_id, n_tile_id);
|
||||||
|
|
||||||
while (n_tile_id < n_tiles) {
|
while (n_tile_id < n_tiles) {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int pipe = 0; pipe < repack_stages; pipe++) {
|
for (int pipe = 0; pipe < repack_stages; pipe++) {
|
||||||
fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id,
|
fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id,
|
||||||
n_tile_id + pipe + repack_stages - 1);
|
n_tile_id + pipe + repack_stages - 1);
|
||||||
@ -195,15 +176,15 @@ __global__ void awq_marlin_repack_kernel(
|
|||||||
|
|
||||||
} // namespace marlin
|
} // namespace marlin
|
||||||
|
|
||||||
#define CALL_IF(NUM_BITS) \
|
#define CALL_IF(NUM_BITS) \
|
||||||
else if (num_bits == NUM_BITS) { \
|
else if (num_bits == NUM_BITS) { \
|
||||||
cudaFuncSetAttribute( \
|
cudaFuncSetAttribute( \
|
||||||
marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS>, \
|
marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS>, \
|
||||||
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
|
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
|
||||||
marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS> \
|
marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS> \
|
||||||
<<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \
|
<<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \
|
||||||
b_q_weight_ptr, out_ptr, size_k, size_n); \
|
b_q_weight_ptr, out_ptr, size_k, size_n); \
|
||||||
}
|
}
|
||||||
|
|
||||||
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
|
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
|
||||||
int64_t size_n, int64_t num_bits) {
|
int64_t size_n, int64_t num_bits) {
|
||||||
@ -266,8 +247,6 @@ torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
|
|||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif
|
|
||||||
|
|
||||||
torch::Tensor awq_marlin_repack_meta(torch::Tensor& b_q_weight,
|
torch::Tensor awq_marlin_repack_meta(torch::Tensor& b_q_weight,
|
||||||
c10::SymInt size_k, c10::SymInt size_n,
|
c10::SymInt size_k, c10::SymInt size_n,
|
||||||
int64_t num_bits) {
|
int64_t num_bits) {
|
||||||
@ -279,3 +258,11 @@ torch::Tensor awq_marlin_repack_meta(torch::Tensor& b_q_weight,
|
|||||||
{size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
|
{size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
|
||||||
options);
|
options);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||||
|
m.impl("awq_marlin_repack", &awq_marlin_repack);
|
||||||
|
}
|
||||||
|
|
||||||
|
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, Meta, m) {
|
||||||
|
m.impl("awq_marlin_repack", &awq_marlin_repack_meta);
|
||||||
|
}
|
||||||
@ -23,6 +23,8 @@
|
|||||||
#include "marlin_dtypes.cuh"
|
#include "marlin_dtypes.cuh"
|
||||||
#include "core/scalar_type.hpp"
|
#include "core/scalar_type.hpp"
|
||||||
|
|
||||||
|
#include "core/registration.h"
|
||||||
|
|
||||||
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
|
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
|
||||||
static_assert(std::is_same<scalar_t, half>::value || \
|
static_assert(std::is_same<scalar_t, half>::value || \
|
||||||
std::is_same<scalar_t, nv_bfloat16>::value, \
|
std::is_same<scalar_t, nv_bfloat16>::value, \
|
||||||
@ -2297,3 +2299,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
|||||||
}
|
}
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||||
|
m.impl("gptq_marlin_gemm", &gptq_marlin_gemm);
|
||||||
|
}
|
||||||
@ -1,26 +1,6 @@
|
|||||||
#include "marlin.cuh"
|
#include "marlin.cuh"
|
||||||
|
|
||||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
#include "core/registration.h"
|
||||||
|
|
||||||
namespace marlin {
|
|
||||||
|
|
||||||
template <int const num_threads, int const num_bits, bool const has_perm>
|
|
||||||
__global__ void gptq_marlin_repack_kernel(
|
|
||||||
uint32_t const* __restrict__ b_q_weight_ptr,
|
|
||||||
uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr,
|
|
||||||
int size_k, int size_n) {}
|
|
||||||
|
|
||||||
} // namespace marlin
|
|
||||||
|
|
||||||
torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
|
|
||||||
int64_t size_k, int64_t size_n,
|
|
||||||
int64_t num_bits) {
|
|
||||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
|
||||||
false, "marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0");
|
|
||||||
return torch::empty({1, 1});
|
|
||||||
}
|
|
||||||
|
|
||||||
#else
|
|
||||||
|
|
||||||
namespace marlin {
|
namespace marlin {
|
||||||
|
|
||||||
@ -174,13 +154,13 @@ __global__ void gptq_marlin_repack_kernel(
|
|||||||
uint32_t b1_vals[tile_ints];
|
uint32_t b1_vals[tile_ints];
|
||||||
uint32_t b2_vals[tile_ints];
|
uint32_t b2_vals[tile_ints];
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < tile_ints; i++) {
|
for (int i = 0; i < tile_ints; i++) {
|
||||||
b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i];
|
b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i];
|
||||||
b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i];
|
b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i];
|
||||||
}
|
}
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < 4; i++) {
|
for (int i = 0; i < 4; i++) {
|
||||||
int cur_elem = tc_row + tc_offsets[i];
|
int cur_elem = tc_row + tc_offsets[i];
|
||||||
int cur_int = cur_elem / pack_factor;
|
int cur_int = cur_elem / pack_factor;
|
||||||
@ -200,7 +180,7 @@ __global__ void gptq_marlin_repack_kernel(
|
|||||||
constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
|
constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
|
||||||
|
|
||||||
uint32_t res = 0;
|
uint32_t res = 0;
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < 8; i++) {
|
for (int i = 0; i < 8; i++) {
|
||||||
res |= vals[pack_idx[i]] << (i * 4);
|
res |= vals[pack_idx[i]] << (i * 4);
|
||||||
}
|
}
|
||||||
@ -212,7 +192,7 @@ __global__ void gptq_marlin_repack_kernel(
|
|||||||
|
|
||||||
uint32_t res1 = 0;
|
uint32_t res1 = 0;
|
||||||
uint32_t res2 = 0;
|
uint32_t res2 = 0;
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < 4; i++) {
|
for (int i = 0; i < 4; i++) {
|
||||||
res1 |= vals[pack_idx[i]] << (i * 8);
|
res1 |= vals[pack_idx[i]] << (i * 8);
|
||||||
res2 |= vals[4 + pack_idx[i]] << (i * 8);
|
res2 |= vals[4 + pack_idx[i]] << (i * 8);
|
||||||
@ -224,14 +204,14 @@ __global__ void gptq_marlin_repack_kernel(
|
|||||||
};
|
};
|
||||||
|
|
||||||
auto start_pipes = [&](int k_tile_id, int n_tile_id) {
|
auto start_pipes = [&](int k_tile_id, int n_tile_id) {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int pipe = 0; pipe < repack_stages - 1; pipe++) {
|
for (int pipe = 0; pipe < repack_stages - 1; pipe++) {
|
||||||
fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe);
|
fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe);
|
||||||
}
|
}
|
||||||
|
|
||||||
wait_for_stage();
|
wait_for_stage();
|
||||||
};
|
};
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) {
|
for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) {
|
||||||
int n_tile_id = 0;
|
int n_tile_id = 0;
|
||||||
|
|
||||||
@ -242,7 +222,7 @@ __global__ void gptq_marlin_repack_kernel(
|
|||||||
start_pipes(k_tile_id, n_tile_id);
|
start_pipes(k_tile_id, n_tile_id);
|
||||||
|
|
||||||
while (n_tile_id < n_tiles) {
|
while (n_tile_id < n_tiles) {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int pipe = 0; pipe < repack_stages; pipe++) {
|
for (int pipe = 0; pipe < repack_stages; pipe++) {
|
||||||
fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id,
|
fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id,
|
||||||
n_tile_id + pipe + repack_stages - 1);
|
n_tile_id + pipe + repack_stages - 1);
|
||||||
@ -256,17 +236,17 @@ __global__ void gptq_marlin_repack_kernel(
|
|||||||
|
|
||||||
} // namespace marlin
|
} // namespace marlin
|
||||||
|
|
||||||
#define CALL_IF(NUM_BITS, HAS_PERM) \
|
#define CALL_IF(NUM_BITS, HAS_PERM) \
|
||||||
else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \
|
else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \
|
||||||
cudaFuncSetAttribute( \
|
cudaFuncSetAttribute( \
|
||||||
marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
|
marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
|
||||||
HAS_PERM>, \
|
HAS_PERM>, \
|
||||||
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
|
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
|
||||||
marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
|
marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
|
||||||
HAS_PERM> \
|
HAS_PERM> \
|
||||||
<<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \
|
<<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \
|
||||||
b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \
|
b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \
|
||||||
}
|
}
|
||||||
|
|
||||||
torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
|
torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
|
||||||
int64_t size_k, int64_t size_n,
|
int64_t size_k, int64_t size_n,
|
||||||
@ -341,8 +321,6 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
|
|||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif
|
|
||||||
|
|
||||||
torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight,
|
torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight,
|
||||||
torch::Tensor& perm, c10::SymInt size_k,
|
torch::Tensor& perm, c10::SymInt size_k,
|
||||||
c10::SymInt size_n, int64_t num_bits) {
|
c10::SymInt size_n, int64_t num_bits) {
|
||||||
@ -354,3 +332,11 @@ torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight,
|
|||||||
{size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
|
{size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
|
||||||
options);
|
options);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||||
|
m.impl("gptq_marlin_repack", &gptq_marlin_repack);
|
||||||
|
}
|
||||||
|
|
||||||
|
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, Meta, m) {
|
||||||
|
m.impl("gptq_marlin_repack", &gptq_marlin_repack_meta);
|
||||||
|
}
|
||||||
@ -284,7 +284,7 @@ mm_impl_template = create_template(IMPL_TEMPLATE)
|
|||||||
prepack_dispatch_template = create_template(PREPACK_TEMPLATE)
|
prepack_dispatch_template = create_template(PREPACK_TEMPLATE)
|
||||||
|
|
||||||
|
|
||||||
def create_sources(impl_config: ImplConfig, num_impl_files=2):
|
def create_sources(impl_config: ImplConfig, num_impl_files=1):
|
||||||
sources = []
|
sources = []
|
||||||
|
|
||||||
type_name = generate_type_signature(impl_config.type_config)
|
type_name = generate_type_signature(impl_config.type_config)
|
||||||
|
|||||||
@ -34,10 +34,9 @@ static __global__ void prepack_B_kernel(BInTensor B_in,
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename PrepackedLayoutB, typename InLayout>
|
template <typename PrepackedLayoutB, typename InLayout>
|
||||||
static void prepack_B(cudaStream_t stream,
|
static void prepack_B_template(
|
||||||
typename PrepackedLayoutB::ElementB const* B_in_ptr,
|
cudaStream_t stream, typename PrepackedLayoutB::ElementB const* B_in_ptr,
|
||||||
InLayout B_layout,
|
InLayout B_layout, typename PrepackedLayoutB::ElementB* B_out_ptr) {
|
||||||
typename PrepackedLayoutB::ElementB* B_out_ptr) {
|
|
||||||
using TileShapeNKL =
|
using TileShapeNKL =
|
||||||
decltype(append(typename PrepackedLayoutB::PPBlockShape_NK{}, _1{}));
|
decltype(append(typename PrepackedLayoutB::PPBlockShape_NK{}, _1{}));
|
||||||
auto ilvd_NKbNbKL_to_offset =
|
auto ilvd_NKbNbKL_to_offset =
|
||||||
|
|||||||
@ -55,8 +55,8 @@ torch::Tensor prepack_impl(torch::Tensor const B) {
|
|||||||
// Allocate output
|
// Allocate output
|
||||||
torch::Tensor D = torch::empty_like(B, {}, at::MemoryFormat::Contiguous);
|
torch::Tensor D = torch::empty_like(B, {}, at::MemoryFormat::Contiguous);
|
||||||
|
|
||||||
prepack_B<PrepackedLayoutB>(stream, B_ptr, layout_Bt,
|
prepack_B_template<PrepackedLayoutB>(
|
||||||
static_cast<ElementB*>(D.mutable_data_ptr()));
|
stream, B_ptr, layout_Bt, static_cast<ElementB*>(D.mutable_data_ptr()));
|
||||||
|
|
||||||
return D;
|
return D;
|
||||||
};
|
};
|
||||||
|
|||||||
@ -2,6 +2,8 @@
|
|||||||
#include "machete_prepack_launcher.cuh"
|
#include "machete_prepack_launcher.cuh"
|
||||||
#include "core/scalar_type.hpp"
|
#include "core/scalar_type.hpp"
|
||||||
|
|
||||||
|
#include "core/registration.h"
|
||||||
|
|
||||||
namespace machete {
|
namespace machete {
|
||||||
|
|
||||||
using namespace vllm;
|
using namespace vllm;
|
||||||
@ -78,14 +80,16 @@ torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
|
|||||||
}
|
}
|
||||||
|
|
||||||
torch::Tensor prepack_B(torch::Tensor const& B,
|
torch::Tensor prepack_B(torch::Tensor const& B,
|
||||||
ScalarTypeTorchPtr const& btype) {
|
vllm::ScalarTypeTorchPtr const& btype) {
|
||||||
#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12
|
|
||||||
return scalar_type_dispatch(*btype, [&](auto BType) {
|
return scalar_type_dispatch(*btype, [&](auto BType) {
|
||||||
return PrepackBDispatcher<half_t, decltype(BType), half_t>::dispatch(B);
|
return PrepackBDispatcher<half_t, decltype(BType), half_t>::dispatch(B);
|
||||||
});
|
});
|
||||||
#else
|
}
|
||||||
TORCH_CHECK(false, "Machete requires CUDA 12.0 or later");
|
|
||||||
#endif
|
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||||
|
m.impl("machete_prepack_B", &prepack_B);
|
||||||
|
m.impl("machete_gemm", &gemm);
|
||||||
|
m.impl("machete_supported_schedules", &supported_schedules);
|
||||||
}
|
}
|
||||||
|
|
||||||
}; // namespace machete
|
}; // namespace machete
|
||||||
|
|||||||
@ -26,6 +26,7 @@
|
|||||||
#include <iostream>
|
#include <iostream>
|
||||||
|
|
||||||
#include "common/base.h"
|
#include "common/base.h"
|
||||||
|
#include "core/registration.h"
|
||||||
|
|
||||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||||
#include "common/mem.h"
|
#include "common/mem.h"
|
||||||
@ -1066,3 +1067,7 @@ torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
|||||||
|
|
||||||
return c;
|
return c;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||||
|
m.impl("marlin_gemm", &marlin_gemm);
|
||||||
|
}
|
||||||
|
|||||||
@ -30,6 +30,7 @@
|
|||||||
#include <iostream>
|
#include <iostream>
|
||||||
|
|
||||||
#include "../dense/common/base.h"
|
#include "../dense/common/base.h"
|
||||||
|
#include "core/registration.h"
|
||||||
|
|
||||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||||
#include "../dense/common/mem.h"
|
#include "../dense/common/mem.h"
|
||||||
@ -1241,3 +1242,7 @@ torch::Tensor marlin_qqq_gemm(torch::Tensor const& a,
|
|||||||
|
|
||||||
return d;
|
return d;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||||
|
m.impl("marlin_qqq_gemm", &marlin_qqq_gemm);
|
||||||
|
}
|
||||||
|
|||||||
@ -28,6 +28,7 @@
|
|||||||
|
|
||||||
#include "common/base.h"
|
#include "common/base.h"
|
||||||
#include "core/scalar_type.hpp"
|
#include "core/scalar_type.hpp"
|
||||||
|
#include "core/registration.h"
|
||||||
|
|
||||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||||
|
|
||||||
@ -1134,3 +1135,7 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
|||||||
|
|
||||||
return c;
|
return c;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||||
|
m.impl("gptq_marlin_24_gemm", &gptq_marlin_24_gemm);
|
||||||
|
}
|
||||||
|
|||||||
@ -167,7 +167,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
ops.def(
|
ops.def(
|
||||||
"marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
|
"marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
|
||||||
"Tensor! workspace, int size_m, int size_n, int size_k) -> Tensor");
|
"Tensor! workspace, int size_m, int size_n, int size_k) -> Tensor");
|
||||||
ops.impl("marlin_gemm", torch::kCUDA, &marlin_gemm);
|
// conditionally compiled so impl in source file
|
||||||
|
|
||||||
// Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ.
|
// Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ.
|
||||||
ops.def(
|
ops.def(
|
||||||
@ -175,22 +175,24 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
"Tensor b_scales, Tensor workspace, "
|
"Tensor b_scales, Tensor workspace, "
|
||||||
"__torch__.torch.classes._core_C.ScalarType b_q_type, "
|
"__torch__.torch.classes._core_C.ScalarType b_q_type, "
|
||||||
"int size_m, int size_n, int size_k) -> Tensor");
|
"int size_m, int size_n, int size_k) -> Tensor");
|
||||||
ops.impl("gptq_marlin_24_gemm", torch::kCUDA, &gptq_marlin_24_gemm);
|
// conditionally compiled so impl in source file
|
||||||
|
|
||||||
// Machete (Dense) Optimized Mixed Precision GEMM for Hopper.
|
// Machete (Dense) Optimized Mixed Precision GEMM for Hopper.
|
||||||
ops.def("machete_supported_schedules", &machete::supported_schedules);
|
ops.def(
|
||||||
|
"machete_supported_schedules("
|
||||||
|
" __torch__.torch.classes._core_C.ScalarType btype"
|
||||||
|
") -> str[]");
|
||||||
ops.def(
|
ops.def(
|
||||||
"machete_gemm(Tensor A, Tensor B,"
|
"machete_gemm(Tensor A, Tensor B,"
|
||||||
" __torch__.torch.classes._core_C.ScalarType btype,"
|
" __torch__.torch.classes._core_C.ScalarType btype,"
|
||||||
" Tensor? scales, Tensor? zeros, int? group_size,"
|
" Tensor? scales, Tensor? zeros, int? group_size,"
|
||||||
" Tensor? C, float? alpha, float? beta, str? schedule)"
|
" Tensor? C, float? alpha, float? beta, str? schedule)"
|
||||||
"-> Tensor");
|
"-> Tensor");
|
||||||
ops.impl("machete_gemm", torch::kCUDA, &machete::gemm);
|
|
||||||
ops.def(
|
ops.def(
|
||||||
"machete_prepack_B(Tensor B,"
|
"machete_prepack_B(Tensor B,"
|
||||||
" __torch__.torch.classes._core_C.ScalarType btype)"
|
" __torch__.torch.classes._core_C.ScalarType btype)"
|
||||||
"-> Tensor");
|
"-> Tensor");
|
||||||
ops.impl("machete_prepack_B", torch::kCUDA, &machete::prepack_B);
|
// conditionally compiled so impl registration is in source file
|
||||||
|
|
||||||
ops.def("permute_cols(Tensor A, Tensor perm) -> Tensor");
|
ops.def("permute_cols(Tensor A, Tensor perm) -> Tensor");
|
||||||
ops.impl("permute_cols", torch::kCUDA, &permute_cols);
|
ops.impl("permute_cols", torch::kCUDA, &permute_cols);
|
||||||
@ -202,21 +204,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
"__torch__.torch.classes._core_C.ScalarType b_q_type, "
|
"__torch__.torch.classes._core_C.ScalarType b_q_type, "
|
||||||
"int size_m, int size_n, int size_k, bool is_k_full, "
|
"int size_m, int size_n, int size_k, bool is_k_full, "
|
||||||
"bool has_zp, bool use_fp32_reduce) -> Tensor");
|
"bool has_zp, bool use_fp32_reduce) -> Tensor");
|
||||||
ops.impl("gptq_marlin_gemm", torch::kCUDA, &gptq_marlin_gemm);
|
// conditionally compiled so impl registration is in source file
|
||||||
|
|
||||||
// gptq_marlin repack from GPTQ.
|
// gptq_marlin repack from GPTQ.
|
||||||
ops.def(
|
ops.def(
|
||||||
"gptq_marlin_repack(Tensor b_q_weight, Tensor perm, "
|
"gptq_marlin_repack(Tensor b_q_weight, Tensor perm, "
|
||||||
"SymInt size_k, SymInt size_n, int num_bits) -> Tensor");
|
"SymInt size_k, SymInt size_n, int num_bits) -> Tensor");
|
||||||
ops.impl("gptq_marlin_repack", torch::kCUDA, &gptq_marlin_repack);
|
// conditionally compiled so impl registrations are in source file
|
||||||
ops.impl("gptq_marlin_repack", torch::kMeta, &gptq_marlin_repack_meta);
|
|
||||||
|
|
||||||
// awq_marlin repack from AWQ.
|
// awq_marlin repack from AWQ.
|
||||||
ops.def(
|
ops.def(
|
||||||
"awq_marlin_repack(Tensor b_q_weight, SymInt size_k, "
|
"awq_marlin_repack(Tensor b_q_weight, SymInt size_k, "
|
||||||
"SymInt size_n, int num_bits) -> Tensor");
|
"SymInt size_n, int num_bits) -> Tensor");
|
||||||
ops.impl("awq_marlin_repack", torch::kCUDA, &awq_marlin_repack);
|
// conditionally compiled so impl registrations are in source file
|
||||||
ops.impl("awq_marlin_repack", torch::kMeta, &awq_marlin_repack_meta);
|
|
||||||
|
|
||||||
// Dequantization for GGML.
|
// Dequantization for GGML.
|
||||||
ops.def("ggml_dequantize(Tensor W, int type, int m, int n) -> Tensor");
|
ops.def("ggml_dequantize(Tensor W, int type, int m, int n) -> Tensor");
|
||||||
@ -237,7 +237,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
"fp8_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
|
"fp8_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
|
||||||
"Tensor! workspace, int num_bits, int size_m, int size_n, "
|
"Tensor! workspace, int num_bits, int size_m, int size_n, "
|
||||||
"int size_k) -> Tensor");
|
"int size_k) -> Tensor");
|
||||||
ops.impl("fp8_marlin_gemm", torch::kCUDA, &fp8_marlin_gemm);
|
// conditionally compiled so impl registration is in source file
|
||||||
|
|
||||||
// marlin_qqq_gemm for QQQ.
|
// marlin_qqq_gemm for QQQ.
|
||||||
ops.def(
|
ops.def(
|
||||||
@ -245,7 +245,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
"Tensor s_tok, Tensor s_ch, Tensor s_group, "
|
"Tensor s_tok, Tensor s_ch, Tensor s_group, "
|
||||||
"Tensor! workspace, int size_m, int size_n, "
|
"Tensor! workspace, int size_m, int size_n, "
|
||||||
"int size_k) -> Tensor");
|
"int size_k) -> Tensor");
|
||||||
ops.impl("marlin_qqq_gemm", torch::kCUDA, &marlin_qqq_gemm);
|
// conditionally compiled so impl registration is in source file
|
||||||
|
|
||||||
// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
|
// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
|
||||||
// quantization, as well as bias
|
// quantization, as well as bias
|
||||||
|
|||||||
311
tools/report_build_time_ninja.py
Normal file
311
tools/report_build_time_ninja.py
Normal file
@ -0,0 +1,311 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright (c) 2018 The Chromium Authors. All rights reserved.
|
||||||
|
# Use of this source code is governed by a BSD-style license that can be
|
||||||
|
# found in the LICENSE file.
|
||||||
|
|
||||||
|
# Modified version of: https://chromium.googlesource.com/chromium/tools/depot_tools.git/+/refs/heads/main/post_build_ninja_summary.py
|
||||||
|
"""Summarize the last ninja build, invoked with ninja's -C syntax.
|
||||||
|
|
||||||
|
> python3 tools/report_build_time_ninja.py -C build/..
|
||||||
|
|
||||||
|
Typical output looks like this:
|
||||||
|
```
|
||||||
|
Longest build steps for .cpp.o:
|
||||||
|
1.0 weighted s to build ...torch_bindings.cpp.o (12.4 s elapsed time)
|
||||||
|
2.0 weighted s to build ..._attn_c.dir/csrc... (23.5 s elapsed time)
|
||||||
|
2.6 weighted s to build ...torch_bindings.cpp.o (31.5 s elapsed time)
|
||||||
|
3.2 weighted s to build ...torch_bindings.cpp.o (38.5 s elapsed time)
|
||||||
|
Longest build steps for .so (linking):
|
||||||
|
0.1 weighted s to build _core_C.abi3.so (0.7 s elapsed time)
|
||||||
|
0.1 weighted s to build _moe_C.abi3.so (1.0 s elapsed time)
|
||||||
|
0.5 weighted s to build ...flash_attn_c.abi3.so (1.1 s elapsed time)
|
||||||
|
6.2 weighted s to build _C.abi3.so (6.2 s elapsed time)
|
||||||
|
Longest build steps for .cu.o:
|
||||||
|
15.3 weighted s to build ...machete_mm_... (183.5 s elapsed time)
|
||||||
|
15.3 weighted s to build ...machete_mm_... (183.5 s elapsed time)
|
||||||
|
15.3 weighted s to build ...machete_mm_... (183.6 s elapsed time)
|
||||||
|
15.3 weighted s to build ...machete_mm_... (183.7 s elapsed time)
|
||||||
|
15.5 weighted s to build ...machete_mm_... (185.6 s elapsed time)
|
||||||
|
15.5 weighted s to build ...machete_mm_... (185.9 s elapsed time)
|
||||||
|
15.5 weighted s to build ...machete_mm_... (186.2 s elapsed time)
|
||||||
|
37.4 weighted s to build ...scaled_mm_c3x.cu... (449.0 s elapsed time)
|
||||||
|
43.9 weighted s to build ...scaled_mm_c2x.cu... (527.4 s elapsed time)
|
||||||
|
344.8 weighted s to build ...attention_...cu.o (1087.2 s elapsed time)
|
||||||
|
1110.0 s weighted time (10120.4 s elapsed time sum, 9.1x parallelism)
|
||||||
|
134 build steps completed, average of 0.12/s
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import errno
|
||||||
|
import fnmatch
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
# The number of long build times to report:
|
||||||
|
long_count = 10
|
||||||
|
# The number of long times by extension to report
|
||||||
|
long_ext_count = 10
|
||||||
|
|
||||||
|
|
||||||
|
class Target:
    """One build step read from a .ninja_log file.

    Holds the step's start/end times (seconds, as floats), the list of
    output names attributed to the step, and a weighted duration that is
    filled in later by whoever analyses the set of steps.
    """

    def __init__(self, start, end):
        """Creates a target object from start/end times given in seconds
        as floats."""
        self.start = start
        self.end = end
        # Output names for this step; appended to by the owner of this object.
        self.targets = []
        self.weighted_duration = 0.0

    def Duration(self):
        """Returns the task duration in seconds as a float."""
        return self.end - self.start

    def SetWeightedDuration(self, weighted_duration):
        """Sets the weighted duration, in seconds, passed in as a float."""
        self.weighted_duration = weighted_duration

    def WeightedDuration(self):
        """Returns the task's weighted duration in seconds as a float.

        Weighted_duration takes the elapsed time of the task and divides it
        by how many other tasks were running at the same time. Thus, it
        represents the approximate impact of this task on the total build
        time, with serialized or serializing steps typically ending up with
        much longer weighted durations.
        weighted_duration should always be the same or shorter than duration.

        Raises:
            AssertionError: if the stored weighted duration exceeds the plain
                duration by more than a small floating-point tolerance.
        """
        # Allow for modest floating-point errors
        epsilon = 0.000002
        if (self.weighted_duration > self.Duration() + epsilon):
            # Print the offending values before the assert fires so the
            # failure is diagnosable.
            print('%s > %s?' % (self.weighted_duration, self.Duration()))
        assert (self.weighted_duration <= self.Duration() + epsilon)
        return self.weighted_duration

    def DescribeTargets(self):
        """Returns a printable string that summarizes the targets."""
        # Some build steps generate dozens of outputs - handle them sanely.
        # The max_length was chosen so that it can fit most of the long
        # single-target names, while minimizing word wrapping.
        result = ', '.join(self.targets)
        max_length = 65
        if len(result) > max_length:
            result = result[:max_length] + '...'
        return result
|
||||||
|
|
||||||
|
|
||||||
|
# Copied with some modifications from ninjatracing
|
||||||
|
def ReadTargets(log, show_all):
    """Reads the build steps recorded in an open .ninja_log file.

    Args:
        log: an open text file object positioned at the start of the log;
            the first line must be the '# ninja log v5' header.
        show_all: when false, records belonging to earlier builds in the
            same log are discarded so only the most recent build remains;
            when true, nothing is discarded.

    Returns:
        A list of Target objects, one per distinct command hash. The list
        is NOT sorted; callers that need an ordering must sort it.

    Raises:
        AssertionError: if the header line is not the expected v5 header.
    """
    header = log.readline()
    assert header == '# ninja log v5\n', \
        'unrecognized ninja log version %r' % header
    targets_dict = {}
    last_end_seen = 0.0
    for line in log:
        parts = line.strip().split('\t')
        if len(parts) != 5:
            # If ninja.exe is rudely halted then the .ninja_log file may be
            # corrupt. Silently continue.
            continue
        start, end, _, name, cmdhash = parts  # Ignore restat.
        # Convert from integral milliseconds to float seconds.
        start = int(start) / 1000.0
        end = int(end) / 1000.0
        if not show_all and end < last_end_seen:
            # An earlier time stamp means that this step is the first in a
            # new build, possibly an incremental build. Throw away the
            # previous data so that this new build will be displayed
            # independently. This has to be done by comparing end times
            # because records are written to the .ninja_log file when
            # commands complete, so end times are guaranteed to be in order,
            # but start times are not.
            targets_dict = {}
        target = None
        if cmdhash in targets_dict:
            target = targets_dict[cmdhash]
            if not show_all and (target.start != start or target.end != end):
                # If several builds in a row just run one or two build steps
                # then the end times may not go backwards so the last build
                # may not be detected as such. However in many cases there
                # will be a build step repeated in the two builds and the
                # changed start/stop points for that command, identified by
                # the hash, can be used to detect and reset the target
                # dictionary.
                targets_dict = {}
                target = None
        if not target:
            targets_dict[cmdhash] = target = Target(start, end)
        last_end_seen = end
        # Several log lines may share one command hash (a step with multiple
        # outputs); they are folded into a single Target.
        target.targets.append(name)
    return list(targets_dict.values())
|
||||||
|
|
||||||
|
|
||||||
|
def GetExtension(target, extra_patterns):
    """Return the file extension that best represents a target.

    For targets that generate multiple outputs it is important to return a
    consistent 'canonical' extension. Ultimately the goal is to group build
    steps by type.

    Args:
        target: object whose `targets` attribute is a list of output names.
        extra_patterns: optional semicolon separated string of fnmatch
            patterns; the first pattern found anywhere in an output name is
            returned verbatim as the grouping key. May be None or empty.

    Returns:
        A string grouping key: a matched extra pattern, a (possibly double)
        extension such as '.cu.o', or one of the synthetic labels used below.
        '(no extension found)' is returned when the target has no outputs.
    """
    # Default so that a target with an empty output list does not hit an
    # unbound local at the final return.
    extension = '(no extension found)'
    for output in target.targets:
        if extra_patterns:
            for fn_pattern in extra_patterns.split(';'):
                if fnmatch.fnmatch(output, '*' + fn_pattern + '*'):
                    return fn_pattern
        # Not a true extension, but a good grouping.
        if output.endswith('type_mappings'):
            extension = 'type_mappings'
            break

        # Capture two extensions if present. For example: file.javac.jar
        # should be distinguished from file.interface.jar.
        root, ext1 = os.path.splitext(output)
        _, ext2 = os.path.splitext(root)
        extension = ext2 + ext1  # Preserve the order in the file name.

        if len(extension) == 0:
            extension = '(no extension found)'

        if ext1 in ['.pdb', '.dll', '.exe']:
            extension = 'PEFile (linking)'
            # Make sure that .dll and .exe are grouped together and that the
            # .dll.lib files don't cause these to be listed as libraries
            break
        if ext1 in ['.so', '.TOC']:
            extension = '.so (linking)'
            # Attempt to identify linking, avoid identifying as '.TOC'
            break
        # Make sure .obj files don't get categorized as mojo files
        if ext1 in ['.obj', '.o']:
            break
        # Jars are the canonical output of java targets.
        if ext1 == '.jar':
            break
        # Normalize all mojo related outputs to 'mojo'.
        if output.count('.mojom') > 0:
            extension = 'mojo'
            break
    return extension
|
||||||
|
|
||||||
|
|
||||||
|
def SummarizeEntries(entries, extra_step_types):
    """Print a summary of the passed in list of Target objects.

    Computes each target's weighted duration (stored on the target via
    SetWeightedDuration), then prints the longest steps grouped by the key
    returned from GetExtension, followed by overall totals.

    Args:
        entries: list of Target objects; start/end times are in seconds.
        extra_step_types: optional semicolon separated fnmatch patterns,
            forwarded unchanged to GetExtension. May be None.

    Returns:
        None. All results are printed. An empty `entries` list prints a
        short notice and returns without further work.
    """
    if not entries:
        # Nothing to analyse; the event scan below needs at least one event.
        print(' No build steps to summarize.')
        return

    # Create a list that is in order by time stamp and has entries for the
    # beginning and ending of each build step (one time stamp may have
    # multiple entries due to multiple steps starting/stopping at exactly the
    # same time). Iterate through this list, keeping track of which tasks are
    # running at all times. At each time step calculate a running total for
    # weighted time so that when each task ends its own weighted time can
    # easily be calculated.
    task_start_stop_times = []

    earliest = -1
    latest = 0
    total_cpu_time = 0
    for target in entries:
        if earliest < 0 or target.start < earliest:
            earliest = target.start
        if target.end > latest:
            latest = target.end
        total_cpu_time += target.Duration()
        task_start_stop_times.append((target.start, 'start', target))
        task_start_stop_times.append((target.end, 'stop', target))
    length = latest - earliest
    weighted_total = 0.0

    # Sort by the time/type records and ignore |target|
    task_start_stop_times.sort(key=lambda times: times[:2])
    # Now we have all task start/stop times sorted by when they happen. If a
    # task starts and stops on the same time stamp then the start will come
    # first because of the alphabet, which is important for making this work
    # correctly.
    # Track the tasks which are currently running.
    running_tasks = {}
    # Record the time we have processed up to so we know how to calculate
    # time deltas.
    last_time = task_start_stop_times[0][0]
    # Track the accumulated weighted time so that it can efficiently be added
    # to individual tasks.
    last_weighted_time = 0.0
    # Scan all start/stop events.
    for event in task_start_stop_times:
        time, action_name, target = event
        # Accumulate weighted time up to now.
        num_running = len(running_tasks)
        if num_running > 0:
            # Update the total weighted time up to this moment.
            last_weighted_time += (time - last_time) / float(num_running)
        if action_name == 'start':
            # Record the total weighted task time when this task starts.
            running_tasks[target] = last_weighted_time
        if action_name == 'stop':
            # Record the change in the total weighted task time while this
            # task ran.
            weighted_duration = last_weighted_time - running_tasks[target]
            target.SetWeightedDuration(weighted_duration)
            weighted_total += weighted_duration
            del running_tasks[target]
        last_time = time
    assert (len(running_tasks) == 0)

    # Warn if the sum of weighted times is off by more than half a second.
    # All times here are in seconds, so half a second is 0.5.
    if abs(length - weighted_total) > 0.5:
        print('Warning: Possible corrupt ninja log, results may be '
              'untrustworthy. Length = %.3f, weighted total = %.3f' %
              (length, weighted_total))

    entries_by_ext = defaultdict(list)
    for target in entries:
        extension = GetExtension(target, extra_step_types)
        entries_by_ext[extension].append(target)

    for key, values in entries_by_ext.items():
        print(' Longest build steps for %s:' % key)
        values.sort(key=lambda x: x.WeightedDuration())
        # Sorted ascending, so the tail holds the longest steps.
        for target in values[-long_count:]:
            print(' %8.1f weighted s to build %s (%.1f s elapsed time)' %
                  (target.WeightedDuration(), target.DescribeTargets(),
                   target.Duration()))

    # A build whose steps all start and end on the same time stamp has zero
    # length; report 0 for the ratios instead of dividing by zero.
    parallelism = total_cpu_time * 1.0 / length if length > 0 else 0.0
    steps_per_second = len(entries) / (length) if length > 0 else 0.0
    print(' %.1f s weighted time (%.1f s elapsed time sum, %1.1fx '
          'parallelism)' %
          (length, total_cpu_time, parallelism))
    print(' %d build steps completed, average of %1.2f/s' %
          (len(entries), steps_per_second))
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Command-line entry point: summarize the last build in a .ninja_log.

    Recognised arguments (unknown arguments are ignored):
        -C DIR            build directory containing '.ninja_log'
        -s/--step-types   semicolon separated fnmatch patterns for grouping
        --log-file PATH   explicit log file; takes precedence over -C

    Returns:
        errno.ENOENT when the log file cannot be opened, otherwise None.
    """
    log_file = '.ninja_log'
    parser = argparse.ArgumentParser()
    parser.add_argument('-C', dest='build_directory', help='Build directory.')
    parser.add_argument(
        '-s',
        '--step-types',
        help='semicolon separated fnmatch patterns for build-step grouping')
    parser.add_argument('--log-file',
                        help="specific ninja log file to analyze.")
    args, _extra_args = parser.parse_known_args()
    if args.build_directory:
        log_file = os.path.join(args.build_directory, log_file)
    if args.log_file:
        log_file = args.log_file
    if args.step_types:
        # Make room for the extra build types.
        global long_ext_count
        long_ext_count += len(args.step_types.split(';'))

    try:
        with open(log_file, 'r') as log:
            entries = ReadTargets(log, False)
            if entries:
                SummarizeEntries(entries, args.step_types)
            else:
                # A log with only a header (or only corrupt lines) yields no
                # steps; summarizing an empty list would fail.
                print('No build steps found in %r, no build summary created.'
                      % log_file)
    except IOError:
        print('Log file %r not found, no build summary created.' % log_file)
        return errno.ENOENT


if __name__ == '__main__':
    sys.exit(main())
|
||||||
@ -32,6 +32,15 @@ def hint_on_error(fn):
|
|||||||
def wrapper(*args, **kwargs):
|
def wrapper(*args, **kwargs):
|
||||||
try:
|
try:
|
||||||
return fn(*args, **kwargs)
|
return fn(*args, **kwargs)
|
||||||
|
|
||||||
|
except NotImplementedError as e:
|
||||||
|
msg = (
|
||||||
|
"Error in calling custom op %s: %s\n"
|
||||||
|
"Not implemented or built, mostly likely because the current current device "
|
||||||
|
"does not support this kernel (less likely TORCH_CUDA_ARCH_LIST was set "
|
||||||
|
"incorrectly while building)")
|
||||||
|
logger.error(msg, fn.__name__, e)
|
||||||
|
raise NotImplementedError(msg % (fn.__name__, e)) from e
|
||||||
except AttributeError as e:
|
except AttributeError as e:
|
||||||
msg = (
|
msg = (
|
||||||
"Error in calling custom op %s: %s\n"
|
"Error in calling custom op %s: %s\n"
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user