import torch

from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_world_size,
    get_tensor_model_parallel_group,
)


def tensor_model_parallel_all_reduce(input_):
    """All-reduce the input tensor across the tensor model parallel group.

    NOTE: This operation is applied in-place on the input tensor.
    """
    # Bypass the function if we are using only 1 GPU.
    if get_tensor_model_parallel_world_size() == 1:
        return input_
    # All-reduce across the tensor model parallel group.
    torch.distributed.all_reduce(input_,
                                 group=get_tensor_model_parallel_group())
    return input_
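
# A minimal usage sketch (an assumption for illustration, not part of the
# original module). It presumes torch.distributed and vLLM's tensor model
# parallel groups have already been initialized (e.g. via
# parallel_state.initialize_model_parallel), with `rank` the local rank:
#
#     partial = torch.ones(4, device="cuda") * rank
#     total = tensor_model_parallel_all_reduce(partial)
#     # `total` is the elementwise sum over all tensor parallel ranks, and
#     # `partial` holds the same values because the reduction is in-place.
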
def tensor_model_parallel_all_gather(input_, dim=-1):
    """All-gather the input tensor across the tensor model parallel group."""
    world_size = get_tensor_model_parallel_world_size()
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_
    assert -input_.dim() <= dim < input_.dim(), (
        f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
    if dim < 0:
        # Convert negative dim to positive.
        dim += input_.dim()
    input_size = input_.size()
    # Allocate output tensor.
    output_tensor = torch.empty((world_size, ) + input_size,
                                dtype=input_.dtype,
                                device=input_.device)
    # All-gather.
    torch.distributed.all_gather_into_tensor(
        output_tensor, input_, group=get_tensor_model_parallel_group())
    # Reshape: move the gathered world-size axis next to `dim` and merge the
    # two, so `dim` grows by a factor of `world_size`.
    output_tensor = output_tensor.movedim(0, dim)
    output_tensor = output_tensor.reshape(input_size[:dim] +
                                          (world_size * input_size[dim], ) +
                                          input_size[dim + 1:])
    return output_tensor
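
# A minimal usage sketch (an assumption for illustration, not part of the
# original module). With a tensor model parallel world size of 2 and a
# per-rank tensor of shape (8, 16), gathering along dim=-1 yields shape
# (8, 32): each rank's columns are concatenated along the last dimension.
#
#     local = torch.randn(8, 16, device="cuda")
#     full = tensor_model_parallel_all_gather(local, dim=-1)
#     assert full.shape == (8, 32)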