mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-25 00:05:46 +08:00
fix pynccl reduce_scatter (#23648)
Co-authored-by: hongchao <hongchao@msh.team>
This commit is contained in:
parent
6891205b16
commit
c7c80af084
@ -152,7 +152,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
dtype=input_tensor.dtype,
|
||||
device=input_tensor.device)
|
||||
|
||||
pynccl_comm.reduce_scatter(output, input_)
|
||||
pynccl_comm.reduce_scatter(output, input_tensor)
|
||||
|
||||
# Reshape before returning
|
||||
return output.movedim(0, dim).contiguous()
|
||||
@ -186,9 +186,9 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
device=input_tensor.device)
|
||||
|
||||
if sizes is not None:
|
||||
pynccl_comm.reduce_scatterv(output, input_, sizes=sizes)
|
||||
pynccl_comm.reduce_scatterv(output, input_tensor, sizes=sizes)
|
||||
else:
|
||||
pynccl_comm.reduce_scatter(output, input_)
|
||||
pynccl_comm.reduce_scatter(output, input_tensor)
|
||||
|
||||
# Reshape before returning
|
||||
return output.movedim(0, dim).contiguous()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user