fix pynccl reduce_scatter (#23648)

Co-authored-by: hongchao <hongchao@msh.team>
This commit is contained in:
yzds 2025-08-27 09:21:11 +08:00 committed by GitHub
parent 6891205b16
commit c7c80af084
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -152,7 +152,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
dtype=input_tensor.dtype,
device=input_tensor.device)
pynccl_comm.reduce_scatter(output, input_)
pynccl_comm.reduce_scatter(output, input_tensor)
# Reshape before returning
return output.movedim(0, dim).contiguous()
@ -186,9 +186,9 @@ class CudaCommunicator(DeviceCommunicatorBase):
device=input_tensor.device)
if sizes is not None:
pynccl_comm.reduce_scatterv(output, input_, sizes=sizes)
pynccl_comm.reduce_scatterv(output, input_tensor, sizes=sizes)
else:
pynccl_comm.reduce_scatter(output, input_)
pynccl_comm.reduce_scatter(output, input_tensor)
# Reshape before returning
return output.movedim(0, dim).contiguous()