[misc] tune some env vars for GB200 (#16992)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao 2025-04-23 10:59:48 +08:00 committed by GitHub
parent 6bc1e30ef9
commit e1cf90e099
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -8,8 +8,21 @@ import torch
# that interact with vllm workers.
# they are executed whenever `import vllm` is called.
# see https://github.com/NVIDIA/nccl/issues/1234
os.environ['NCCL_CUMEM_ENABLE'] = '0'
if not os.path.exists('/dev/nvidia-caps-imex-channels'):
# normally, we disable NCCL_CUMEM_ENABLE because it
# will cost 1~2 GiB GPU memory with cudagraph+allreduce,
# see https://github.com/NVIDIA/nccl/issues/1234
# for more details.
# However, NCCL requires NCCL_CUMEM_ENABLE to work with
# multi-node NVLink, typically on GB200-NVL72 systems.
# The ultimate way to detect multi-node NVLink is to use
# NVML APIs, which are too expensive to call here.
# As an approximation, we check the existence of
# /dev/nvidia-caps-imex-channels, used by
# multi-node NVLink to communicate across nodes.
# This will still cost some GPU memory, but it is worthwhile
# because we can get very fast cross-node bandwidth with NVLink.
os.environ['NCCL_CUMEM_ENABLE'] = '0'
# see https://github.com/vllm-project/vllm/pull/15951
# it avoids unintentional cuda initialization from torch.cuda.is_available()