diff --git a/docker/Dockerfile.rocm_base b/docker/Dockerfile.rocm_base index 4973b57f76563..87fcb18b1c037 100644 --- a/docker/Dockerfile.rocm_base +++ b/docker/Dockerfile.rocm_base @@ -65,8 +65,6 @@ ARG PYTORCH_BRANCH ARG PYTORCH_VISION_BRANCH ARG PYTORCH_REPO ARG PYTORCH_VISION_REPO -ARG FA_BRANCH -ARG FA_REPO RUN git clone ${PYTORCH_REPO} pytorch RUN cd pytorch && git checkout ${PYTORCH_BRANCH} && \ pip install -r requirements.txt && git submodule update --init --recursive \ @@ -77,14 +75,20 @@ RUN git clone ${PYTORCH_VISION_REPO} vision RUN cd vision && git checkout ${PYTORCH_VISION_BRANCH} \ && python3 setup.py bdist_wheel --dist-dir=dist \ && pip install dist/*.whl +RUN mkdir -p /app/install && cp /app/pytorch/dist/*.whl /app/install \ + && cp /app/vision/dist/*.whl /app/install + +FROM base AS build_fa +ARG FA_BRANCH +ARG FA_REPO +RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \ + pip install /install/*.whl RUN git clone ${FA_REPO} RUN cd flash-attention \ && git checkout ${FA_BRANCH} \ && git submodule update --init \ && GPU_ARCHS=$(echo ${PYTORCH_ROCM_ARCH} | sed -e 's/;gfx1[0-9]\{3\}//g') python3 setup.py bdist_wheel --dist-dir=dist -RUN mkdir -p /app/install && cp /app/pytorch/dist/*.whl /app/install \ - && cp /app/vision/dist/*.whl /app/install \ - && cp /app/flash-attention/dist/*.whl /app/install +RUN mkdir -p /app/install && cp /app/flash-attention/dist/*.whl /app/install FROM base AS build_aiter ARG AITER_BRANCH @@ -103,6 +107,8 @@ FROM base AS debs RUN mkdir /app/debs RUN --mount=type=bind,from=build_triton,src=/app/install/,target=/install \ cp /install/*.whl /app/debs +RUN --mount=type=bind,from=build_fa,src=/app/install/,target=/install \ + cp /install/*.whl /app/debs RUN --mount=type=bind,from=build_amdsmi,src=/app/install/,target=/install \ cp /install/*.whl /app/debs RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \ @@ -111,13 +117,7 @@ RUN --mount=type=bind,from=build_aiter,src=/app/install/,target=/install \ cp /install/*.whl /app/debs FROM base AS final -RUN --mount=type=bind,from=build_triton,src=/app/install/,target=/install \ - pip install /install/*.whl -RUN --mount=type=bind,from=build_amdsmi,src=/app/install/,target=/install \ - pip install /install/*.whl -RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \ - pip install /install/*.whl -RUN --mount=type=bind,from=build_aiter,src=/app/install/,target=/install \ +RUN --mount=type=bind,from=debs,src=/app/debs,target=/install \ pip install /install/*.whl ARG BASE_IMAGE