mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-15 23:15:20 +08:00
[Bugfix] Fixes for FlashInfer's TORCH_CUDA_ARCH_LIST (#20136)
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com> Signed-off-by: Tyler Michael Smith <tysmith@redhat.com>
This commit is contained in:
parent
3dd359147d
commit
bdb84e26b0
@ -374,23 +374,44 @@ ARG FLASHINFER_CUDA128_INDEX_URL="https://download.pytorch.org/whl/cu128/flashin
|
|||||||
ARG FLASHINFER_CUDA128_WHEEL="flashinfer_python-0.2.6.post1%2Bcu128torch2.7-cp39-abi3-linux_x86_64.whl"
|
ARG FLASHINFER_CUDA128_WHEEL="flashinfer_python-0.2.6.post1%2Bcu128torch2.7-cp39-abi3-linux_x86_64.whl"
|
||||||
ARG FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git"
|
ARG FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git"
|
||||||
ARG FLASHINFER_GIT_REF="v0.2.6.post1"
|
ARG FLASHINFER_GIT_REF="v0.2.6.post1"
|
||||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH'
|
||||||
. /etc/environment && \
|
. /etc/environment
|
||||||
if [ "$TARGETPLATFORM" != "linux/arm64" ]; then \
|
if [ "$TARGETPLATFORM" != "linux/arm64" ]; then
|
||||||
# FlashInfer already has a wheel for PyTorch 2.7.0 and CUDA 12.8. This is enough for CI use
|
# FlashInfer already has a wheel for PyTorch 2.7.0 and CUDA 12.8. This is enough for CI use
|
||||||
if [[ "$CUDA_VERSION" == 12.8* ]]; then \
|
if [[ "$CUDA_VERSION" == 12.8* ]]; then
|
||||||
uv pip install --system ${FLASHINFER_CUDA128_INDEX_URL}/${FLASHINFER_CUDA128_WHEEL} ; \
|
uv pip install --system ${FLASHINFER_CUDA128_INDEX_URL}/${FLASHINFER_CUDA128_WHEEL}
|
||||||
else \
|
else
|
||||||
export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0a 10.0a 12.0' && \
|
export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0a 10.0a 12.0'
|
||||||
git clone ${FLASHINFER_GIT_REPO} --single-branch --branch ${FLASHINFER_GIT_REF} --recursive && \
|
git clone ${FLASHINFER_GIT_REPO} --single-branch --branch ${FLASHINFER_GIT_REF} --recursive
|
||||||
# Needed to build AOT kernels
|
# Needed to build AOT kernels
|
||||||
(cd flashinfer && \
|
(cd flashinfer && \
|
||||||
python3 -m flashinfer.aot && \
|
python3 -m flashinfer.aot && \
|
||||||
uv pip install --system --no-build-isolation . \
|
uv pip install --system --no-build-isolation . \
|
||||||
) && \
|
)
|
||||||
rm -rf flashinfer; \
|
rm -rf flashinfer
|
||||||
|
|
||||||
|
# Default arches (skipping 10.0a and 12.0 since these need 12.8)
|
||||||
|
# TODO: Update this to allow setting TORCH_CUDA_ARCH_LIST as a build arg.
|
||||||
|
TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a"
|
||||||
|
if [[ "${CUDA_VERSION}" == 11.* ]]; then
|
||||||
|
TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9"
|
||||||
|
fi
|
||||||
|
echo "🏗️ Building FlashInfer for arches: ${TORCH_CUDA_ARCH_LIST}"
|
||||||
|
|
||||||
|
git clone --depth 1 --recursive --shallow-submodules \
|
||||||
|
--branch v0.2.6.post1 \
|
||||||
|
https://github.com/flashinfer-ai/flashinfer.git flashinfer
|
||||||
|
|
||||||
|
pushd flashinfer
|
||||||
|
python3 -m flashinfer.aot
|
||||||
|
TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST}" \
|
||||||
|
uv pip install --system --no-build-isolation .
|
||||||
|
popd
|
||||||
|
|
||||||
|
rm -rf flashinfer
|
||||||
fi \
|
fi \
|
||||||
fi
|
fi
|
||||||
|
BASH
|
||||||
COPY examples examples
|
COPY examples examples
|
||||||
COPY benchmarks benchmarks
|
COPY benchmarks benchmarks
|
||||||
COPY ./vllm/collect_env.py .
|
COPY ./vllm/collect_env.py .
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user