From e1d54022385ac52a3c3c6c6a3359d93f5c2944d5 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Sun, 17 Dec 2023 01:44:45 -0800
Subject: [PATCH] Fix all-reduce memory usage (#2151)

---
 vllm/worker/worker.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index e3babdc022a7..3e31737f2109 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -48,6 +48,14 @@ class Worker:
         self.gpu_cache = None
 
     def init_model(self, cupy_port: Optional[int] = None):
+        # torch.distributed.all_reduce does not free the input tensor until
+        # the synchronization point. This causes the memory usage to grow
+        # as the number of all_reduce calls increases. This env var disables
+        # this behavior.
+        # Related issue:
+        # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
+        os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
+
         # This env var set by Ray causes exceptions with graph building.
         os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
         # Env vars will be set by Ray.
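
For context, here is a minimal standalone sketch (not part of the patch, and not vLLM code) of the behavior the workaround targets: the environment variable must be set in the worker process before the NCCL process group runs its first collective, so PyTorch's NCCL backend stops holding all_reduce input tensors alive until the next synchronization point. The `run_worker` function, tensor sizes, and master address/port values below are hypothetical and only illustrate the ordering.

```python
# Hypothetical sketch: set TORCH_NCCL_AVOID_RECORD_STREAMS before the first
# all_reduce so input tensors can be freed by the caching allocator instead
# of being kept alive until a synchronization point.
import os

import torch
import torch.distributed as dist


def run_worker(rank: int, world_size: int) -> None:
    # Must be set before the NCCL communicator performs collectives.
    os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"

    # Placeholder rendezvous settings for a single-node example.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    for _ in range(1000):
        # Without the env var, each input tensor may remain allocated until a
        # synchronization point, so peak memory grows with the call count.
        x = torch.ones(1024, 1024, device="cuda")
        dist.all_reduce(x)

    torch.cuda.synchronize()
    dist.destroy_process_group()
```

The patch sets the variable in `Worker.init_model`, i.e. before distributed initialization, which achieves the same ordering inside each vLLM worker process.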