mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-02 18:44:31 +08:00
fix pynccl reduce_scatter (#23648)
Co-authored-by: hongchao <hongchao@msh.team>
This commit is contained in:
parent
6891205b16
commit
c7c80af084
@ -152,7 +152,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
|||||||
dtype=input_tensor.dtype,
|
dtype=input_tensor.dtype,
|
||||||
device=input_tensor.device)
|
device=input_tensor.device)
|
||||||
|
|
||||||
pynccl_comm.reduce_scatter(output, input_)
|
pynccl_comm.reduce_scatter(output, input_tensor)
|
||||||
|
|
||||||
# Reshape before returning
|
# Reshape before returning
|
||||||
return output.movedim(0, dim).contiguous()
|
return output.movedim(0, dim).contiguous()
|
||||||
@ -186,9 +186,9 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
|||||||
device=input_tensor.device)
|
device=input_tensor.device)
|
||||||
|
|
||||||
if sizes is not None:
|
if sizes is not None:
|
||||||
pynccl_comm.reduce_scatterv(output, input_, sizes=sizes)
|
pynccl_comm.reduce_scatterv(output, input_tensor, sizes=sizes)
|
||||||
else:
|
else:
|
||||||
pynccl_comm.reduce_scatter(output, input_)
|
pynccl_comm.reduce_scatter(output, input_tensor)
|
||||||
|
|
||||||
# Reshape before returning
|
# Reshape before returning
|
||||||
return output.movedim(0, dim).contiguous()
|
return output.movedim(0, dim).contiguous()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user