diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py
index 0ea8de2f36f4b..eef3f9f75f9f1 100644
--- a/vllm/distributed/device_communicators/cuda_communicator.py
+++ b/vllm/distributed/device_communicators/cuda_communicator.py
@@ -152,7 +152,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
                              dtype=input_tensor.dtype,
                              device=input_tensor.device)
 
-        pynccl_comm.reduce_scatter(output, input_)
+        pynccl_comm.reduce_scatter(output, input_tensor)
 
         # Reshape before returning
         return output.movedim(0, dim).contiguous()
@@ -186,9 +186,9 @@ class CudaCommunicator(DeviceCommunicatorBase):
                              device=input_tensor.device)
 
         if sizes is not None:
-            pynccl_comm.reduce_scatterv(output, input_, sizes=sizes)
+            pynccl_comm.reduce_scatterv(output, input_tensor, sizes=sizes)
         else:
-            pynccl_comm.reduce_scatter(output, input_)
+            pynccl_comm.reduce_scatter(output, input_tensor)
 
         # Reshape before returning
         return output.movedim(0, dim).contiguous()