mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 00:05:38 +08:00
[Build] Add OpenAI triton_kernels (#28788)
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com> Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
This commit is contained in:
parent
49ef847aa8
commit
9912b8ccb8
3
.gitignore
vendored
3
.gitignore
vendored
@ -4,6 +4,9 @@
|
|||||||
# vllm-flash-attn built from source
|
# vllm-flash-attn built from source
|
||||||
vllm/vllm_flash_attn/*
|
vllm/vllm_flash_attn/*
|
||||||
|
|
||||||
|
# OpenAI triton kernels copied from source
|
||||||
|
vllm/third_party/triton_kernels/*
|
||||||
|
|
||||||
# triton jit
|
# triton jit
|
||||||
.triton
|
.triton
|
||||||
|
|
||||||
|
|||||||
@ -1030,6 +1030,11 @@ if(VLLM_GPU_LANG STREQUAL "HIP")
|
|||||||
WITH_SOABI)
|
WITH_SOABI)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
# For CUDA and HIP builds also fetch and install the triton_kernels external package (Python sources only).
|
||||||
|
if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
|
||||||
|
include(cmake/external_projects/triton_kernels.cmake)
|
||||||
|
endif()
|
||||||
|
|
||||||
# For CUDA we also build and ship some external projects.
|
# For CUDA we also build and ship some external projects.
|
||||||
if (VLLM_GPU_LANG STREQUAL "CUDA")
|
if (VLLM_GPU_LANG STREQUAL "CUDA")
|
||||||
include(cmake/external_projects/flashmla.cmake)
|
include(cmake/external_projects/flashmla.cmake)
|
||||||
|
|||||||
53
cmake/external_projects/triton_kernels.cmake
Normal file
53
cmake/external_projects/triton_kernels.cmake
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
# Install OpenAI triton_kernels from https://github.com/triton-lang/triton/tree/main/python/triton_kernels
#
# Nothing is compiled by this file: it locates the triton_kernels Python
# sources (either a local checkout or a pinned Triton git tag) and installs
# the *.py files under vllm/third_party/triton_kernels.

# The FetchContent_* commands are used below. include() is idempotent, so this
# is harmless when the including file has already loaded the module.
include(FetchContent)

set(DEFAULT_TRITON_KERNELS_TAG "v3.5.0")

# Set TRITON_KERNELS_SRC_DIR for use with local development with vLLM. We expect TRITON_KERNELS_SRC_DIR to
# be directly set to the triton_kernels python directory.
if (DEFINED ENV{TRITON_KERNELS_SRC_DIR})
  message(STATUS "[triton_kernels] Fetch from $ENV{TRITON_KERNELS_SRC_DIR}")
  FetchContent_Declare(
          triton_kernels
          # Quoted so a path containing spaces or semicolons stays one argument.
          SOURCE_DIR "$ENV{TRITON_KERNELS_SRC_DIR}"
  )

else()
  set(TRITON_GIT "https://github.com/triton-lang/triton.git")
  message (STATUS "[triton_kernels] Fetch from ${TRITON_GIT}:${DEFAULT_TRITON_KERNELS_TAG}")
  FetchContent_Declare(
          triton_kernels
          # TODO (varun) : Fetch just the triton_kernels directory from Triton
          # Use the same variable that is logged above so the two cannot drift.
          GIT_REPOSITORY "${TRITON_GIT}"
          GIT_TAG "${DEFAULT_TRITON_KERNELS_TAG}"
          GIT_PROGRESS TRUE
          # NOTE(review): presumably chosen so that Triton's top-level
          # CMakeLists.txt is not added to this build -- confirm.
          SOURCE_SUBDIR python/triton_kernels/triton_kernels
  )
endif()

# Fetch content
FetchContent_MakeAvailable(triton_kernels)

if (NOT triton_kernels_SOURCE_DIR)
  message (FATAL_ERROR "[triton_kernels] Cannot resolve triton_kernels_SOURCE_DIR")
endif()

# With a local checkout the source dir already is the python package; with a
# git fetch the package lives inside the Triton repository tree.
if (DEFINED ENV{TRITON_KERNELS_SRC_DIR})
  set(TRITON_KERNELS_PYTHON_DIR "${triton_kernels_SOURCE_DIR}/")
else()
  set(TRITON_KERNELS_PYTHON_DIR "${triton_kernels_SOURCE_DIR}/python/triton_kernels/triton_kernels/")
endif()

message (STATUS "[triton_kernels] triton_kernels is available at ${TRITON_KERNELS_PYTHON_DIR}")

# Empty target: gives the build a `triton_kernels` target name even though
# there is nothing to compile.
add_custom_target(triton_kernels)

# Ensure the vllm/third_party directory exists before installation.
# Tagged with the same component as the copy below so that a component-scoped
# install (`cmake --install ... --component triton_kernels`) runs both rules.
install(CODE "file(MAKE_DIRECTORY \"\${CMAKE_INSTALL_PREFIX}/vllm/third_party/triton_kernels\")"
        COMPONENT triton_kernels)

## Copy .py files to install directory.
# The trailing slash on the source directory makes install() copy its
# contents rather than the directory itself.
install(DIRECTORY
        "${TRITON_KERNELS_PYTHON_DIR}"
        DESTINATION
        vllm/third_party/triton_kernels/
        COMPONENT triton_kernels
        FILES_MATCHING PATTERN "*.py")
|
||||||
17
setup.py
17
setup.py
@ -299,6 +299,20 @@ class cmake_build_ext(build_ext):
|
|||||||
os.makedirs(os.path.dirname(dst_file), exist_ok=True)
|
os.makedirs(os.path.dirname(dst_file), exist_ok=True)
|
||||||
self.copy_file(file, dst_file)
|
self.copy_file(file, dst_file)
|
||||||
|
|
||||||
|
if _is_cuda() or _is_hip():
|
||||||
|
# copy vllm/third_party/triton_kernels/**/*.py from self.build_lib
|
||||||
|
# to current directory so that they can be included in the editable
|
||||||
|
# build
|
||||||
|
print(
|
||||||
|
f"Copying {self.build_lib}/vllm/third_party/triton_kernels "
|
||||||
|
"to vllm/third_party/triton_kernels"
|
||||||
|
)
|
||||||
|
shutil.copytree(
|
||||||
|
f"{self.build_lib}/vllm/third_party/triton_kernels",
|
||||||
|
"vllm/third_party/triton_kernels",
|
||||||
|
dirs_exist_ok=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class precompiled_build_ext(build_ext):
|
class precompiled_build_ext(build_ext):
|
||||||
"""Disables extension building when using precompiled binaries."""
|
"""Disables extension building when using precompiled binaries."""
|
||||||
@ -633,6 +647,9 @@ ext_modules = []
|
|||||||
if _is_cuda() or _is_hip():
|
if _is_cuda() or _is_hip():
|
||||||
ext_modules.append(CMakeExtension(name="vllm._moe_C"))
|
ext_modules.append(CMakeExtension(name="vllm._moe_C"))
|
||||||
ext_modules.append(CMakeExtension(name="vllm.cumem_allocator"))
|
ext_modules.append(CMakeExtension(name="vllm.cumem_allocator"))
|
||||||
|
# Optional since nothing is compiled here (no .so file is produced). This is just
|
||||||
|
# copying the relevant .py files from the source repository.
|
||||||
|
ext_modules.append(CMakeExtension(name="vllm.triton_kernels", optional=True))
|
||||||
|
|
||||||
if _is_hip():
|
if _is_hip():
|
||||||
ext_modules.append(CMakeExtension(name="vllm._rocm_C"))
|
ext_modules.append(CMakeExtension(name="vllm._rocm_C"))
|
||||||
|
|||||||
@ -8,6 +8,7 @@ import torch
|
|||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.triton_utils import triton
|
from vllm.triton_utils import triton
|
||||||
|
from vllm.utils.import_utils import has_triton_kernels
|
||||||
from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer
|
from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@ -15,6 +16,7 @@ logger = init_logger(__name__)
|
|||||||
|
|
||||||
def _swizzle_mxfp4(quant_tensor, scale, num_warps):
|
def _swizzle_mxfp4(quant_tensor, scale, num_warps):
|
||||||
"""weight swizzle for mxfp4 moe, used for OAI mxfp4 kernel"""
|
"""weight swizzle for mxfp4 moe, used for OAI mxfp4 kernel"""
|
||||||
|
assert has_triton_kernels()
|
||||||
import triton_kernels.matmul_ogs_details.opt_flags as opt_flags
|
import triton_kernels.matmul_ogs_details.opt_flags as opt_flags
|
||||||
from triton_kernels.numerics import InFlexData
|
from triton_kernels.numerics import InFlexData
|
||||||
from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
|
from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
|
||||||
|
|||||||
@ -18,6 +18,10 @@ from typing import Any
|
|||||||
import regex as re
|
import regex as re
|
||||||
from typing_extensions import Never
|
from typing_extensions import Never
|
||||||
|
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
# TODO: This function can be removed if transformer_modules classes are
|
# TODO: This function can be removed if transformer_modules classes are
|
||||||
# serialized by value when communicating between processes
|
# serialized by value when communicating between processes
|
||||||
@ -62,6 +66,35 @@ def import_pynvml():
|
|||||||
return pynvml
|
return pynvml
|
||||||
|
|
||||||
|
|
||||||
|
@cache
def import_triton_kernels():
    """
    Make `triton_kernels` importable, preferring an installed copy.

    A `triton_kernels` package found on the regular import path wins. If
    there is none, the copy vendored at `vllm.third_party.triton_kernels` is
    imported and registered in `sys.modules` under the name
    `triton_kernels`. If neither is present, an informational message is
    logged and nothing is imported.
    """
    if _has_module("triton_kernels"):
        import triton_kernels as tk_module

        use_vendored = False
    elif _has_module("vllm.third_party.triton_kernels"):
        import vllm.third_party.triton_kernels as tk_module

        use_vendored = True
    else:
        logger.info_once(
            "triton_kernels unavailable in this build. "
            "Please consider installing triton_kernels from "
            "https://github.com/triton-lang/triton/tree/main/python/triton_kernels"
        )
        return

    logger.debug_once(
        f"Loading module triton_kernels from {tk_module.__file__}.",
        scope="local",
    )
    if use_vendored:
        # Alias the vendored package so `import triton_kernels` resolves to it.
        sys.modules["triton_kernels"] = tk_module
|
||||||
|
|
||||||
|
|
||||||
def import_from_path(module_name: str, file_path: str | os.PathLike):
|
def import_from_path(module_name: str, file_path: str | os.PathLike):
|
||||||
"""
|
"""
|
||||||
Import a Python file according to its file path.
|
Import a Python file according to its file path.
|
||||||
@ -397,7 +430,12 @@ def has_deep_gemm() -> bool:
|
|||||||
|
|
||||||
def has_triton_kernels() -> bool:
    """Whether the optional `triton_kernels` package is available.

    Checks the regular import path first and then the copy vendored under
    `vllm.third_party`. When either is found, `import_triton_kernels()` is
    called before returning.
    """
    for candidate in ("triton_kernels", "vllm.third_party.triton_kernels"):
        if _has_module(candidate):
            import_triton_kernels()
            return True
    return False
|
||||||
|
|
||||||
|
|
||||||
def has_tilelang() -> bool:
|
def has_tilelang() -> bool:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user