Update PyTorch to 2.9.0+cu129 (#24994)

Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Authored by Huy Do on 2025-10-21 14:20:18 -07:00; committed by GitHub
parent 250fb1b8ea
commit becb7de40b
16 changed files with 68 additions and 67 deletions

View File

@@ -172,6 +172,8 @@ steps:
 - tests/v1/engine/test_engine_core_client.py
 - tests/distributed/test_symm_mem_allreduce.py
 commands:
+# https://github.com/NVIDIA/nccl/issues/1838
+- export NCCL_CUMEM_HOST_ENABLE=0
 # test with torchrun tp=2 and external_dp=2
 - torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
 # test with torchrun tp=2 and pp=2
@@ -349,7 +351,8 @@ steps:
 - python3 offline_inference/basic/embed.py
 - python3 offline_inference/basic/score.py
 - python3 offline_inference/spec_decode.py --test --method eagle --num_spec_tokens 3 --dataset-name hf --dataset-path philschmid/mt-bench --num-prompts 80 --temp 0 --top-p 1.0 --top-k -1 --tp 1 --enable-chunked-prefill --max-model-len 2048
-- python3 offline_inference/spec_decode.py --test --method eagle3 --num_spec_tokens 3 --dataset-name hf --dataset-path philschmid/mt-bench --num-prompts 80 --temp 0 --top-p 1.0 --top-k -1 --tp 1 --enable-chunked-prefill --max-model-len 2048
+# https://github.com/vllm-project/vllm/pull/26682 uses slightly more memory in PyTorch 2.9+ causing this test to OOM in 1xL4 GPU
+- python3 offline_inference/spec_decode.py --test --method eagle3 --num_spec_tokens 3 --dataset-name hf --dataset-path philschmid/mt-bench --num-prompts 80 --temp 0 --top-p 1.0 --top-k -1 --tp 1 --enable-chunked-prefill --max-model-len 1536
 - label: Platform Tests (CUDA) # 4min
 timeout_in_minutes: 15
@@ -534,7 +537,7 @@ steps:
 # https://github.com/pytorch/ao/issues/2919, we'll have to skip new torchao tests for now
 # we can only upgrade after this is resolved
 # TODO(jerryzh168): resolve the above comment
-- uv pip install --system torchao==0.13.0
+- uv pip install --system torchao==0.13.0 --index-url https://download.pytorch.org/whl/cu129
 - VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization/ --ignore quantization/test_blackwell_moe.py
 - label: LM Eval Small Models # 53min
@@ -975,6 +978,8 @@ steps:
 - tests/v1/shutdown
 - tests/v1/worker/test_worker_memory_snapshot.py
 commands:
+# https://github.com/NVIDIA/nccl/issues/1838
+- export NCCL_CUMEM_HOST_ENABLE=0
 - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_async_llm_dp.py
 - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_external_lb_dp.py
 - DP_SIZE=2 pytest -v -s v1/entrypoints/openai/test_multi_api_servers.py
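For reference, the workaround added in the hunks above can also be applied when running these tests by hand. This is a minimal sketch, assuming a machine with 4 local GPUs; the test path and the environment variable are taken directly from the pipeline entries above.

```bash
# Disable NCCL's cuMem host allocations, which hit
# https://github.com/NVIDIA/nccl/issues/1838 with the NCCL bundled in torch 2.9 wheels.
export NCCL_CUMEM_HOST_ENABLE=0

# Same invocation as the pipeline entry above (4 local GPUs assumed).
torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
```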

View File

@ -38,7 +38,7 @@ repos:
rev: 0.9.1 rev: 0.9.1
hooks: hooks:
- id: pip-compile - id: pip-compile
args: [requirements/test.in, -o, requirements/test.txt, --index-strategy, unsafe-best-match, --torch-backend, cu128, --python-platform, x86_64-manylinux_2_28] args: [requirements/test.in, -o, requirements/test.txt, --index-strategy, unsafe-best-match, --torch-backend, cu129, --python-platform, x86_64-manylinux_2_28]
files: ^requirements/test\.(in|txt)$ files: ^requirements/test\.(in|txt)$
- repo: local - repo: local
hooks: hooks:
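The pip-compile hook now resolves against the cu129 torch backend. If the lock file needs to be regenerated outside pre-commit, the equivalent command mirrors the hook arguments above (a sketch, assuming `uv` is installed and the working directory is the repository root):

```bash
# Regenerate requirements/test.txt against the cu129 torch index,
# mirroring the pre-commit hook's args.
uv pip compile requirements/test.in -o requirements/test.txt \
  --index-strategy unsafe-best-match \
  --torch-backend cu129 \
  --python-platform x86_64-manylinux_2_28
```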

View File

@@ -49,8 +49,8 @@ set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1
 # requirements.txt files and should be kept consistent. The ROCm torch
 # versions are derived from docker/Dockerfile.rocm
 #
-set(TORCH_SUPPORTED_VERSION_CUDA "2.8.0")
-set(TORCH_SUPPORTED_VERSION_ROCM "2.8.0")
+set(TORCH_SUPPORTED_VERSION_CUDA "2.9.0")
+set(TORCH_SUPPORTED_VERSION_ROCM "2.9.0")
 #
 # Try to find python package with an executable that exactly matches
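These CMake variables must agree with the torch that is actually installed in the build environment and with the pins in requirements/. A quick sanity check (a sketch; the exact build string may vary):

```bash
# The installed torch should report 2.9.0 with a cu129 build,
# matching TORCH_SUPPORTED_VERSION_CUDA above.
python3 -c "import torch; print(torch.__version__, torch.version.cuda)"
# Expected output along the lines of: 2.9.0+cu129 12.9
```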

View File

@@ -5,7 +5,7 @@
 # docs/contributing/dockerfile/dockerfile.md and
 # docs/assets/contributing/dockerfile-stages-dependency.png
-ARG CUDA_VERSION=12.8.1
+ARG CUDA_VERSION=12.9.1
 ARG PYTHON_VERSION=3.12
 # By parameterizing the base images, we allow third-party to use their own
@@ -275,6 +275,7 @@ WORKDIR /vllm-workspace
 ENV DEBIAN_FRONTEND=noninteractive
 ARG TARGETPLATFORM
+# TODO (huydhn): There is no prebuilt gdrcopy package on 12.9 at the moment
 ARG GDRCOPY_CUDA_VERSION=12.8
 # Keep in line with FINAL_BASE_IMAGE
 ARG GDRCOPY_OS_VERSION=Ubuntu22_04
@@ -360,6 +361,13 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist
 && uv pip install --system dist/*.whl --verbose \
 --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
+# TODO (huydhn): Remove this once xformers is released for 2.9.0
+RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH'
+  . /etc/environment
+  export TORCH_CUDA_ARCH_LIST='7.5 8.0+PTX 9.0a'
+  uv pip install --system --no-build-isolation "git+https://github.com/facebookresearch/xformers@v0.0.32.post2"
+BASH
 # Install FlashInfer pre-compiled kernel cache and binaries
 # https://docs.flashinfer.ai/installation.html
 RUN --mount=type=cache,target=/root/.cache/uv \
@@ -426,6 +434,7 @@ ARG PYTHON_VERSION
 ARG PIP_INDEX_URL UV_INDEX_URL
 ARG PIP_EXTRA_INDEX_URL UV_EXTRA_INDEX_URL
+ARG PYTORCH_CUDA_INDEX_BASE_URL
 # This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out
 # Reference: https://github.com/astral-sh/uv/pull/1694
@@ -438,7 +447,8 @@ ENV UV_LINK_MODE=copy
 RUN --mount=type=cache,target=/root/.cache/uv \
 CUDA_MAJOR="${CUDA_VERSION%%.*}"; \
 if [ "$CUDA_MAJOR" -ge 12 ]; then \
-uv pip install --system -r requirements/dev.txt; \
+uv pip install --system -r requirements/dev.txt \
+--extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.'); \
 fi
 # install development dependencies (for testing)
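The `cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')` expression used in the RUN steps above derives the PyTorch wheel index suffix from the CUDA version. A standalone sketch of the mapping (the default value of `PYTORCH_CUDA_INDEX_BASE_URL` shown here is an assumption; in the Dockerfile it is a build ARG):

```bash
# How the Dockerfile maps CUDA_VERSION=12.9.1 to the cu129 wheel index.
CUDA_VERSION=12.9.1
PYTORCH_CUDA_INDEX_BASE_URL=https://download.pytorch.org/whl  # assumed default

suffix="cu$(echo "$CUDA_VERSION" | cut -d. -f1,2 | tr -d '.')"
echo "$suffix"                                   # -> cu129
echo "${PYTORCH_CUDA_INDEX_BASE_URL}/${suffix}"  # -> https://download.pytorch.org/whl/cu129
```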

View File

@@ -199,9 +199,13 @@ FROM base AS vllm-test-deps
 WORKDIR /workspace/vllm
+# TODO: Update to 2.9.0 when there is a new build for intel_extension_for_pytorch for that version
 RUN --mount=type=bind,src=requirements/test.in,target=requirements/test.in \
 cp requirements/test.in requirements/cpu-test.in && \
 sed -i '/mamba_ssm/d' requirements/cpu-test.in && \
+sed -i 's/^torch==.*/torch==2.8.0/g' requirements/cpu-test.in && \
+sed -i 's/torchaudio.*/torchaudio/g' requirements/cpu-test.in && \
+sed -i 's/torchvision.*/torchvision/g' requirements/cpu-test.in && \
 uv pip compile requirements/cpu-test.in -o requirements/cpu-test.txt --index-strategy unsafe-best-match --torch-backend cpu
 RUN --mount=type=cache,target=/root/.cache/uv \
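The CPU test image keeps torch pinned at 2.8.0 until intel_extension_for_pytorch ships a 2.9.0-compatible build; the added sed lines rewrite the copied test requirements before compiling them. The same rewrite can be reproduced outside Docker (a sketch using only the commands shown in the hunk above):

```bash
# Reproduce the CPU test-requirements rewrite locally:
# keep torch at 2.8.0, drop mamba_ssm, and unpin torchaudio/torchvision so the
# resolver picks CPU wheels that match the pinned torch.
cp requirements/test.in requirements/cpu-test.in
sed -i '/mamba_ssm/d' requirements/cpu-test.in
sed -i 's/^torch==.*/torch==2.8.0/g' requirements/cpu-test.in
sed -i 's/torchaudio.*/torchaudio/g' requirements/cpu-test.in
sed -i 's/torchvision.*/torchvision/g' requirements/cpu-test.in
uv pip compile requirements/cpu-test.in -o requirements/cpu-test.txt \
  --index-strategy unsafe-best-match --torch-backend cpu
```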

Binary file changed (not shown); 119 KiB before and after.

View File

@@ -87,7 +87,7 @@ is ineffective.
 While ongoing efforts like <https://github.com/vllm-project/vllm/issues/17419>
 address the long build time at its source, the current workaround is to set `VLLM_CI_BRANCH`
-to a custom branch provided by @khluu (`VLLM_CI_BRANCH=khluu/use_postmerge_q`)
+to a custom branch provided by @khluu (`VLLM_CI_BRANCH=khluu/long_build`)
 when manually triggering a build on Buildkite. This branch accomplishes two things:
 1. Increase the timeout limit to 10 hours so that the build doesn't time out.
@@ -100,35 +100,17 @@ to warm it up so that future builds are faster.
 ## Update dependencies
-Several vLLM dependencies, such as FlashInfer, also depend on PyTorch and need
+Several vLLM dependencies like xFormers depend on PyTorch and need
 to be updated accordingly. Rather than waiting for all of them to publish new
 releases (which would take too much time), they can be built from
 source to unblock the update process.
-### FlashInfer
-Here is how to build and install it from source with `torch2.7.0+cu128` in vLLM [Dockerfile](https://github.com/vllm-project/vllm/blob/27bebcd89792d5c4b08af7a65095759526f2f9e1/docker/Dockerfile#L259-L271):
-```bash
-export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0 10.0+PTX'
-export FLASHINFER_ENABLE_SM90=1
-uv pip install --system \
-  --no-build-isolation "git+https://github.com/flashinfer-ai/flashinfer@v0.2.6.post1"
-```
-One caveat is that building FlashInfer from source adds approximately 30
-minutes to the vLLM build time. Therefore, it's preferable to cache the wheel in a
-public location for immediate installation, such as [this FlashInfer wheel link](https://download.pytorch.org/whl/cu128/flashinfer/flashinfer_python-0.2.6.post1%2Bcu128torch2.7-cp39-abi3-linux_x86_64.whl). For future releases, contact the PyTorch release
-team if you want to get the package published there.
 ### xFormers
-Similar to FlashInfer, here is how to build and install xFormers from source:
 ```bash
-export TORCH_CUDA_ARCH_LIST='7.0 7.5 8.0 8.9 9.0 10.0+PTX'
+export TORCH_CUDA_ARCH_LIST='7.5 8.0+PTX 9.0a'
 MAX_JOBS=16 uv pip install --system \
-  --no-build-isolation "git+https://github.com/facebookresearch/xformers@v0.0.30"
+  --no-build-isolation "git+https://github.com/facebookresearch/xformers@v0.0.32.post2"
 ```
 ## Update all the different vLLM platforms
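After building xFormers from source against torch 2.9, it can be worth confirming that the resulting package imports cleanly against the installed torch. A minimal check (a sketch; version strings will vary with the exact build):

```bash
# Post-install check: the source-built xFormers should load against the
# torch it was compiled for, without ABI import errors.
python3 -c "import torch, xformers; print(torch.__version__, xformers.__version__)"
```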

View File

@@ -6,7 +6,7 @@ requires = [
 "packaging>=24.2",
 "setuptools>=77.0.3,<80.0.0",
 "setuptools-scm>=8.0",
-"torch == 2.8.0",
+"torch == 2.9.0",
 "wheel",
 "jinja2",
 ]

View File

@ -4,7 +4,7 @@ ninja
packaging>=24.2 packaging>=24.2
setuptools>=77.0.3,<80.0.0 setuptools>=77.0.3,<80.0.0
setuptools-scm>=8 setuptools-scm>=8
torch==2.8.0 torch==2.9.0
wheel wheel
jinja2>=3.1.6 jinja2>=3.1.6
regex regex

View File

@@ -5,11 +5,11 @@ numba == 0.61.2 # Required for N-gram speculative decoding
 # Dependencies for NVIDIA GPUs
 ray[cgraph]>=2.48.0 # Ray Compiled Graph, required for pipeline parallelism in V1.
-torch==2.8.0
+torch==2.9.0
-torchaudio==2.8.0
+torchaudio==2.9.0
 # These must be updated alongside torch
-torchvision==0.23.0 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
+torchvision==0.24.0 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
 # https://github.com/facebookresearch/xformers/releases/tag/v0.0.32.post1
-xformers==0.0.32.post1; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch >= 2.8
+# xformers==0.0.32.post1; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch >= 2.8
 # FlashInfer should be updated together with the Dockerfile
 flashinfer-python==0.4.1
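With xformers commented out here (it is built from source in the Dockerfile until a 2.9-compatible release exists), a local GPU environment matching these pins could be assembled roughly as follows. This is a sketch, not the project's documented install path; the extra index mirrors the cu129 URL used elsewhere in this commit:

```bash
# Install the pinned CUDA stack against the cu129 PyTorch wheel index.
uv pip install --system \
  torch==2.9.0 torchaudio==2.9.0 torchvision==0.24.0 flashinfer-python==0.4.1 \
  --extra-index-url https://download.pytorch.org/whl/cu129
```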

View File

@ -1,12 +1,12 @@
# Common dependencies # Common dependencies
-r common.txt -r common.txt
--extra-index-url https://download.pytorch.org/whl/rocm6.3 --extra-index-url https://download.pytorch.org/whl/rocm6.4
torch==2.8.0 torch==2.9.0
torchvision==0.23.0 torchvision==0.24.0
torchaudio==2.8.0 torchaudio==2.9.0
triton==3.3.0 triton==3.5.0
cmake>=3.26.1,<4 cmake>=3.26.1,<4
packaging>=24.2 packaging>=24.2
setuptools>=77.0.3,<80.0.0 setuptools>=77.0.3,<80.0.0

View File

@@ -24,9 +24,9 @@ soundfile # required for audio tests
 jiwer # required for audio tests
 tblib # for pickling test exceptions
 timm >=1.0.17 # required for internvl and gemma3n-mm test
-torch==2.8.0
+torch==2.9.0
-torchaudio==2.8.0
+torchaudio==2.9.0
-torchvision==0.23.0
+torchvision==0.24.0
 transformers_stream_generator # required for qwen-vl test
 matplotlib # required for qwen-vl test
 mistral_common[image,audio] >= 1.8.5 # required for voxtral test

View File

@ -1,5 +1,5 @@
# This file was autogenerated by uv via the following command: # This file was autogenerated by uv via the following command:
# uv pip compile requirements/test.in -o requirements/test.txt --index-strategy unsafe-best-match --torch-backend cu128 --python-platform x86_64-manylinux_2_28 # uv pip compile requirements/test.in -o requirements/test.txt --index-strategy unsafe-best-match --torch-backend cu129 --python-platform x86_64-manylinux_2_28
absl-py==2.1.0 absl-py==2.1.0
# via rouge-score # via rouge-score
accelerate==1.0.1 accelerate==1.0.1
@ -573,42 +573,44 @@ numpy==1.26.4
# tritonclient # tritonclient
# vocos # vocos
# xarray # xarray
nvidia-cublas-cu12==12.8.4.1 nvidia-cublas-cu12==12.9.1.4
# via # via
# nvidia-cudnn-cu12 # nvidia-cudnn-cu12
# nvidia-cusolver-cu12 # nvidia-cusolver-cu12
# torch # torch
nvidia-cuda-cupti-cu12==12.8.90 nvidia-cuda-cupti-cu12==12.9.79
# via torch # via torch
nvidia-cuda-nvrtc-cu12==12.8.93 nvidia-cuda-nvrtc-cu12==12.9.86
# via torch # via torch
nvidia-cuda-runtime-cu12==12.8.90 nvidia-cuda-runtime-cu12==12.9.79
# via torch # via torch
nvidia-cudnn-cu12==9.10.2.21 nvidia-cudnn-cu12==9.10.2.21
# via torch # via torch
nvidia-cufft-cu12==11.3.3.83 nvidia-cufft-cu12==11.4.1.4
# via torch # via torch
nvidia-cufile-cu12==1.13.1.3 nvidia-cufile-cu12==1.14.1.1
# via torch # via torch
nvidia-curand-cu12==10.3.9.90 nvidia-curand-cu12==10.3.10.19
# via torch # via torch
nvidia-cusolver-cu12==11.7.3.90 nvidia-cusolver-cu12==11.7.5.82
# via torch # via torch
nvidia-cusparse-cu12==12.5.8.93 nvidia-cusparse-cu12==12.5.10.65
# via # via
# nvidia-cusolver-cu12 # nvidia-cusolver-cu12
# torch # torch
nvidia-cusparselt-cu12==0.7.1 nvidia-cusparselt-cu12==0.7.1
# via torch # via torch
nvidia-nccl-cu12==2.27.3 nvidia-nccl-cu12==2.27.5
# via torch # via torch
nvidia-nvjitlink-cu12==12.8.93 nvidia-nvjitlink-cu12==12.9.86
# via # via
# nvidia-cufft-cu12 # nvidia-cufft-cu12
# nvidia-cusolver-cu12 # nvidia-cusolver-cu12
# nvidia-cusparse-cu12 # nvidia-cusparse-cu12
# torch # torch
nvidia-nvtx-cu12==12.8.90 nvidia-nvshmem-cu12==3.3.20
# via torch
nvidia-nvtx-cu12==12.9.79
# via torch # via torch
omegaconf==2.3.0 omegaconf==2.3.0
# via # via
@ -1017,7 +1019,6 @@ setuptools==77.0.3
# lightning-utilities # lightning-utilities
# pytablewriter # pytablewriter
# torch # torch
# triton
shapely==2.1.1 shapely==2.1.1
# via # via
# geopandas # geopandas
@ -1122,7 +1123,7 @@ tomli==2.2.1
# via schemathesis # via schemathesis
tomli-w==1.2.0 tomli-w==1.2.0
# via schemathesis # via schemathesis
torch==2.8.0+cu128 torch==2.9.0+cu129
# via # via
# -r requirements/test.in # -r requirements/test.in
# accelerate # accelerate
@ -1151,7 +1152,7 @@ torch==2.8.0+cu128
# torchvision # torchvision
# vector-quantize-pytorch # vector-quantize-pytorch
# vocos # vocos
torchaudio==2.8.0+cu128 torchaudio==2.9.0+cu129
# via # via
# -r requirements/test.in # -r requirements/test.in
# encodec # encodec
@ -1164,7 +1165,7 @@ torchmetrics==1.7.4
# pytorch-lightning # pytorch-lightning
# terratorch # terratorch
# torchgeo # torchgeo
torchvision==0.23.0+cu128 torchvision==0.24.0+cu129
# via # via
# -r requirements/test.in # -r requirements/test.in
# lightly # lightly
@ -1205,7 +1206,7 @@ transformers==4.56.2
# transformers-stream-generator # transformers-stream-generator
transformers-stream-generator==0.0.5 transformers-stream-generator==0.0.5
# via -r requirements/test.in # via -r requirements/test.in
triton==3.4.0 triton==3.5.0
# via torch # via torch
tritonclient==2.51.0 tritonclient==2.51.0
# via # via

View File

@ -109,7 +109,7 @@ def test_chatglm3_lora_tp4_fully_sharded_loras(chatglm3_lora_files):
tensor_parallel_size=4, tensor_parallel_size=4,
trust_remote_code=True, trust_remote_code=True,
fully_sharded_loras=True, fully_sharded_loras=True,
gpu_memory_utilization=0.85, gpu_memory_utilization=0.8,
compilation_config=vllm.config.CompilationConfig( # Avoid OOM compilation_config=vllm.config.CompilationConfig( # Avoid OOM
cudagraph_specialize_lora=False, cudagraph_specialize_lora=False,
), ),

View File

@ -3,6 +3,7 @@
from typing import Any from typing import Any
import pytest import pytest
import torch._dynamo.config as dynamo_config
from vllm import SamplingParams from vllm import SamplingParams
@ -12,6 +13,7 @@ from ...models.utils import check_outputs_equal
MODEL = "Qwen/Qwen3-0.6B" MODEL = "Qwen/Qwen3-0.6B"
@dynamo_config.patch(cache_size_limit=16)
def test_preempt_and_async_scheduling_e2e(monkeypatch: pytest.MonkeyPatch): def test_preempt_and_async_scheduling_e2e(monkeypatch: pytest.MonkeyPatch):
"""Test consistency of combos of async scheduling, preemption, """Test consistency of combos of async scheduling, preemption,
uni/multiproc executor, and various sampling parameters.""" uni/multiproc executor, and various sampling parameters."""
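The new decorator raises torch._dynamo's recompile cache limit for this test, since the combinations it exercises trigger more recompilations than the default allows. To inspect the default on a given torch build (a sketch; the default value differs across torch releases):

```bash
# Print the current torch._dynamo recompile cache limit; the test above
# patches it to 16 for its duration via @dynamo_config.patch(cache_size_limit=16).
python3 -c "import torch._dynamo.config as c; print(c.cache_size_limit)"
```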

View File

@@ -7,18 +7,15 @@ set -euo pipefail
 # Requires: curl, apt-get, root privileges
 if [[ $(id -u) -ne 0 ]]; then
 echo "Must be run as root" >&2
 exit 1
 fi
 if [[ $# -ne 3 ]]; then
 echo "Usage: $0 <GDRCOPY_OS_VERSION> <GDRCOPY_CUDA_VERSION> <uuarch(x64|aarch64)>" >&2
 exit 1
 fi
 OS_VER="$1"
 CUDA_VER="$2"
 UUARCH_RAW="$3"
 # Normalize/validate arch
 case "${UUARCH_RAW,,}" in
 aarch64|arm64)
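Given the Dockerfile ARGs above (GDRCOPY_OS_VERSION=Ubuntu22_04, GDRCOPY_CUDA_VERSION=12.8, the latter still on 12.8 because no prebuilt gdrcopy package exists for CUDA 12.9 yet), an invocation of this installer would look roughly like the following. The script filename is not shown in this diff, so the name used here is a placeholder:

```bash
# Hypothetical invocation (script name is a placeholder); arguments follow the
# usage string above: <GDRCOPY_OS_VERSION> <GDRCOPY_CUDA_VERSION> <x64|aarch64>.
sudo ./install_gdrcopy.sh Ubuntu22_04 12.8 x64
```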