mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-09 23:01:06 +08:00
[Build/CI] Fix CUDA 11.8 build (#17679)
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com> Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com> Signed-off-by: Tyler Michael Smith <tysmith@redhat.com> Co-authored-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
This commit is contained in:
parent
f8d2cc5f55
commit
6e588da0f4
@ -30,7 +30,11 @@ set(ignoreMe "${VLLM_PYTHON_PATH}")
|
|||||||
set(PYTHON_SUPPORTED_VERSIONS "3.9" "3.10" "3.11" "3.12")
|
set(PYTHON_SUPPORTED_VERSIONS "3.9" "3.10" "3.11" "3.12")
|
||||||
|
|
||||||
# Supported NVIDIA architectures.
|
# Supported NVIDIA architectures.
|
||||||
set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0")
|
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL)
|
||||||
|
set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0")
|
||||||
|
else()
|
||||||
|
set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0")
|
||||||
|
endif()
|
||||||
|
|
||||||
# Supported AMD GPU architectures.
|
# Supported AMD GPU architectures.
|
||||||
set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1101;gfx1200;gfx1201")
|
set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1101;gfx1200;gfx1201")
|
||||||
|
|||||||
@ -29,3 +29,5 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
|
|||||||
int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
|
int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
|
||||||
int64_t BLOCK_SIZE_K, int64_t bit);
|
int64_t BLOCK_SIZE_K, int64_t bit);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
bool moe_permute_unpermute_supported();
|
||||||
@ -5,6 +5,9 @@
|
|||||||
#include "permute_unpermute_kernels/dispatch.h"
|
#include "permute_unpermute_kernels/dispatch.h"
|
||||||
#include "core/registration.h"
|
#include "core/registration.h"
|
||||||
|
|
||||||
|
// moe_permute kernels require at least CUDA 12.0
|
||||||
|
#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)
|
||||||
|
|
||||||
void moe_permute(
|
void moe_permute(
|
||||||
const torch::Tensor& input, // [n_token, hidden]
|
const torch::Tensor& input, // [n_token, hidden]
|
||||||
const torch::Tensor& topk_weights, //[n_token, topk]
|
const torch::Tensor& topk_weights, //[n_token, topk]
|
||||||
@ -127,6 +130,44 @@ void moe_unpermute(
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#else
|
||||||
|
|
||||||
|
void moe_permute(const torch::Tensor& input, const torch::Tensor& topk_weights,
|
||||||
|
torch::Tensor& topk_ids,
|
||||||
|
const torch::Tensor& token_expert_indicies,
|
||||||
|
const std::optional<torch::Tensor>& expert_map,
|
||||||
|
int64_t n_expert, int64_t n_local_expert, int64_t topk,
|
||||||
|
const std::optional<int64_t>& align_block_size,
|
||||||
|
torch::Tensor& permuted_input,
|
||||||
|
torch::Tensor& expert_first_token_offset,
|
||||||
|
torch::Tensor& src_row_id2dst_row_id_map,
|
||||||
|
torch::Tensor& m_indices) {
|
||||||
|
TORCH_CHECK(false, "moe_unpermute is not supported on CUDA < 12.0");
|
||||||
|
}
|
||||||
|
|
||||||
|
void moe_unpermute(const torch::Tensor& input,
|
||||||
|
const torch::Tensor& topk_weights, torch::Tensor& topk_ids,
|
||||||
|
const torch::Tensor& token_expert_indicies,
|
||||||
|
const std::optional<torch::Tensor>& expert_map,
|
||||||
|
int64_t n_expert, int64_t n_local_expert, int64_t topk,
|
||||||
|
const std::optional<int64_t>& align_block_size,
|
||||||
|
torch::Tensor& permuted_input,
|
||||||
|
torch::Tensor& expert_first_token_offset,
|
||||||
|
torch::Tensor& src_row_id2dst_row_id_map,
|
||||||
|
torch::Tensor& m_indices) {
|
||||||
|
TORCH_CHECK(false, "moe_unpermute is not supported on CUDA < 12.0");
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
bool moe_permute_unpermute_supported() {
|
||||||
|
#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)
|
||||||
|
return true;
|
||||||
|
#else
|
||||||
|
return false;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||||
m.impl("moe_permute", &moe_permute);
|
m.impl("moe_permute", &moe_permute);
|
||||||
m.impl("moe_unpermute", &moe_unpermute);
|
m.impl("moe_unpermute", &moe_unpermute);
|
||||||
|
|||||||
@ -1,6 +1,9 @@
|
|||||||
|
|
||||||
#include "moe_permute_unpermute_kernel.h"
|
#include "moe_permute_unpermute_kernel.h"
|
||||||
|
|
||||||
|
// moe_permute kernels require at least CUDA 12.0
|
||||||
|
#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)
|
||||||
|
|
||||||
// CubKeyValueSorter definition begin
|
// CubKeyValueSorter definition begin
|
||||||
CubKeyValueSorter::CubKeyValueSorter()
|
CubKeyValueSorter::CubKeyValueSorter()
|
||||||
: num_experts_(0), num_bits_(sizeof(int) * 8) {}
|
: num_experts_(0), num_bits_(sizeof(int) * 8) {}
|
||||||
@ -131,9 +134,6 @@ __global__ void preprocessTopkIdKernel(int* topk_id_ptr, int size,
|
|||||||
int num_experts) {
|
int num_experts) {
|
||||||
auto tidx = threadIdx.x;
|
auto tidx = threadIdx.x;
|
||||||
auto bidx = blockIdx.x;
|
auto bidx = blockIdx.x;
|
||||||
auto lidx = tidx & 31;
|
|
||||||
auto widx = tidx >> 5;
|
|
||||||
auto warp_count = (blockDim.x + 31) >> 5;
|
|
||||||
auto offset = bidx * blockDim.x;
|
auto offset = bidx * blockDim.x;
|
||||||
auto bound = min(offset + blockDim.x, size);
|
auto bound = min(offset + blockDim.x, size);
|
||||||
extern __shared__ int smem_expert_map[];
|
extern __shared__ int smem_expert_map[];
|
||||||
@ -227,3 +227,5 @@ void getMIndices(int64_t* expert_first_token_offset,
|
|||||||
num_local_expert, align_block_size);
|
num_local_expert, align_block_size);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|||||||
@ -77,7 +77,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
|
|||||||
"Tensor topk_ids,Tensor src_row_id2dst_row_id_map, Tensor "
|
"Tensor topk_ids,Tensor src_row_id2dst_row_id_map, Tensor "
|
||||||
"expert_first_token_offset, int n_expert, int n_local_expert,int "
|
"expert_first_token_offset, int n_expert, int n_local_expert,int "
|
||||||
"topk, Tensor! hidden_states)->()");
|
"topk, Tensor! hidden_states)->()");
|
||||||
// conditionally compiled so impl registration is in source file
|
|
||||||
|
m.def("moe_permute_unpermute_supported() -> bool");
|
||||||
|
m.impl("moe_permute_unpermute_supported", &moe_permute_unpermute_supported);
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|||||||
@ -123,7 +123,7 @@ bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool cutlass_group_gemm_supported(int64_t cuda_device_capability) {
|
bool cutlass_group_gemm_supported(int64_t cuda_device_capability) {
|
||||||
// CUTLASS groped FP8 kernels need at least CUDA 12.3
|
// CUTLASS grouped FP8 kernels need at least CUDA 12.3
|
||||||
// and SM90 (Hopper)
|
// and SM90 (Hopper)
|
||||||
|
|
||||||
#if defined CUDA_VERSION
|
#if defined CUDA_VERSION
|
||||||
|
|||||||
@ -263,8 +263,11 @@ if [ "$TARGETPLATFORM" != "linux/arm64" ]; then \
|
|||||||
export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0 10.0+PTX'; \
|
export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0 10.0+PTX'; \
|
||||||
else \
|
else \
|
||||||
export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0+PTX'; \
|
export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0+PTX'; \
|
||||||
fi && \
|
fi; \
|
||||||
export FLASHINFER_ENABLE_AOT=1; \
|
CUDA_MAJOR="${CUDA_VERSION%%.*}"; \
|
||||||
|
if [ "$CUDA_MAJOR" -lt 12 ]; then \
|
||||||
|
export FLASHINFER_ENABLE_SM90=0; \
|
||||||
|
fi; \
|
||||||
uv pip install --system --no-build-isolation "git+https://github.com/flashinfer-ai/flashinfer@21ea1d2545f74782b91eb8c08fd503ac4c0743fc" ; \
|
uv pip install --system --no-build-isolation "git+https://github.com/flashinfer-ai/flashinfer@21ea1d2545f74782b91eb8c08fd503ac4c0743fc" ; \
|
||||||
fi
|
fi
|
||||||
COPY examples examples
|
COPY examples examples
|
||||||
@ -275,7 +278,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
|||||||
. /etc/environment && \
|
. /etc/environment && \
|
||||||
uv pip list
|
uv pip list
|
||||||
|
|
||||||
# Although we build Flashinfer with AOT mode, there's still
|
# Even when we build Flashinfer with AOT mode, there's still
|
||||||
# some issues w.r.t. JIT compilation. Therefore we need to
|
# some issues w.r.t. JIT compilation. Therefore we need to
|
||||||
# install build dependencies for JIT compilation.
|
# install build dependencies for JIT compilation.
|
||||||
# TODO: Remove this once FlashInfer AOT wheel is fixed
|
# TODO: Remove this once FlashInfer AOT wheel is fixed
|
||||||
@ -304,7 +307,10 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
|||||||
|
|
||||||
# install development dependencies (for testing)
|
# install development dependencies (for testing)
|
||||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||||
uv pip install --system -r requirements/dev.txt
|
CUDA_MAJOR="${CUDA_VERSION%%.*}"; \
|
||||||
|
if [ "$CUDA_MAJOR" -ge 12 ]; then \
|
||||||
|
uv pip install --system -r requirements/dev.txt; \
|
||||||
|
fi
|
||||||
|
|
||||||
# install development dependencies (for testing)
|
# install development dependencies (for testing)
|
||||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||||
|
|||||||
@ -13,7 +13,7 @@ import torch
|
|||||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
||||||
from vllm.model_executor.layers.fused_moe.layer import determine_expert_map
|
from vllm.model_executor.layers.fused_moe.layer import determine_expert_map
|
||||||
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
|
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
|
||||||
moe_permute, moe_unpermute)
|
moe_permute, moe_permute_unpermute_supported, moe_unpermute)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
NUM_EXPERTS = [16, 64]
|
NUM_EXPERTS = [16, 64]
|
||||||
@ -167,6 +167,8 @@ def torch_unpermute(permuted_hidden_states: torch.Tensor,
|
|||||||
def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int,
|
def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int,
|
||||||
n_expert: int, ep_size: int, dtype: torch.dtype,
|
n_expert: int, ep_size: int, dtype: torch.dtype,
|
||||||
align_block_size: Optional[int]):
|
align_block_size: Optional[int]):
|
||||||
|
if not moe_permute_unpermute_supported():
|
||||||
|
pytest.skip("moe_permute_unpermute is not supported on this platform.")
|
||||||
fill_invalid_expert = 0
|
fill_invalid_expert = 0
|
||||||
ep_rank = np.random.randint(0, ep_size)
|
ep_rank = np.random.randint(0, ep_size)
|
||||||
expert_map = None
|
expert_map = None
|
||||||
|
|||||||
@ -182,3 +182,7 @@ def moe_unpermute(
|
|||||||
expert_first_token_offset, n_expert,
|
expert_first_token_offset, n_expert,
|
||||||
n_local_expert, topk, hidden_states)
|
n_local_expert, topk, hidden_states)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
def moe_permute_unpermute_supported():
|
||||||
|
return torch.ops._moe_C.moe_permute_unpermute_supported()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user