diff --git a/vllm/env_override.py b/vllm/env_override.py index b0a061d2c4ed..2bede4963f96 100644 --- a/vllm/env_override.py +++ b/vllm/env_override.py @@ -4,17 +4,23 @@ import os import torch +from vllm.logger import init_logger + +logger = init_logger(__name__) + # set some common config/environment variables that should be set # for all processes created by vllm and all processes # that interact with vllm workers. # they are executed whenever `import vllm` is called. -if not os.path.exists('/dev/nvidia-caps-imex-channels'): - # normally, we disable NCCL_CUMEM_ENABLE because it - # will cost 1~2 GiB GPU memory with cudagraph+allreduce, - # see https://github.com/NVIDIA/nccl/issues/1234 - # for more details. - # However, NCCL requires NCCL_CUMEM_ENABLE to work with +if 'NCCL_CUMEM_ENABLE' in os.environ: + logger.warning( + "NCCL_CUMEM_ENABLE is set to %s, skipping override. " + "This may increase memory overhead with cudagraph+allreduce: " + "https://github.com/NVIDIA/nccl/issues/1234", + os.environ['NCCL_CUMEM_ENABLE']) +elif not os.path.exists('/dev/nvidia-caps-imex-channels'): + # NCCL requires NCCL_CUMEM_ENABLE to work with # multi-node NVLink, typically on GB200-NVL72 systems. # The ultimate way to detect multi-node NVLink is to use # NVML APIs, which are too expensive to call here.