diff --git a/docker/Dockerfile b/docker/Dockerfile
index a71b052bfca25..d1009fb4fb18b 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -374,23 +374,40 @@ ARG FLASHINFER_CUDA128_INDEX_URL="https://download.pytorch.org/whl/cu128/flashin
 ARG FLASHINFER_CUDA128_WHEEL="flashinfer_python-0.2.6.post1%2Bcu128torch2.7-cp39-abi3-linux_x86_64.whl"
 ARG FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git"
 ARG FLASHINFER_GIT_REF="v0.2.6.post1"
-RUN --mount=type=cache,target=/root/.cache/uv \
-. /etc/environment && \
-if [ "$TARGETPLATFORM" != "linux/arm64" ]; then \
-    # FlashInfer already has a wheel for PyTorch 2.7.0 and CUDA 12.8. This is enough for CI use
-    if [[ "$CUDA_VERSION" == 12.8* ]]; then \
-        uv pip install --system ${FLASHINFER_CUDA128_INDEX_URL}/${FLASHINFER_CUDA128_WHEEL} ; \
-    else \
-        export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0a 10.0a 12.0' && \
-        git clone ${FLASHINFER_GIT_REPO} --single-branch --branch ${FLASHINFER_GIT_REF} --recursive && \
-        # Needed to build AOT kernels
-        (cd flashinfer && \
-            python3 -m flashinfer.aot && \
-            uv pip install --system --no-build-isolation . \
-        ) && \
-        rm -rf flashinfer; \
-    fi \
-fi
+RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH'
+  # Fail the image build on any error inside this heredoc script.
+  set -euo pipefail
+  . /etc/environment
+  if [ "$TARGETPLATFORM" != "linux/arm64" ]; then
+    # FlashInfer already has a wheel for PyTorch 2.7.0 and CUDA 12.8. This is enough for CI use
+    if [[ "$CUDA_VERSION" == 12.8* ]]; then
+      uv pip install --system ${FLASHINFER_CUDA128_INDEX_URL}/${FLASHINFER_CUDA128_WHEEL}
+    else
+      # No prebuilt wheel matches: build the AOT kernels from source.
+      # Default arches (skipping 10.0a and 12.0 since these need 12.8)
+      # TODO: Update this to allow setting TORCH_CUDA_ARCH_LIST as a build arg.
+      TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a"
+      if [[ "${CUDA_VERSION}" == 11.* ]]; then
+        TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9"
+      fi
+      # Export so both the AOT kernel build and the wheel build see it.
+      export TORCH_CUDA_ARCH_LIST
+      echo "🏗️ Building FlashInfer for arches: ${TORCH_CUDA_ARCH_LIST}"
+
+      git clone --depth 1 --recursive --shallow-submodules \
+        --branch ${FLASHINFER_GIT_REF} \
+        ${FLASHINFER_GIT_REPO} flashinfer
+
+      pushd flashinfer
+      # Needed to build AOT kernels
+      python3 -m flashinfer.aot
+      uv pip install --system --no-build-isolation .
+      popd
+
+      rm -rf flashinfer
+    fi
+  fi
+BASH
 COPY examples examples
 COPY benchmarks benchmarks
 COPY ./vllm/collect_env.py .