mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-22 13:26:13 +08:00
298 lines
10 KiB
Python
298 lines
10 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
import threading
|
|
from weakref import WeakValueDictionary
|
|
|
|
import torch
|
|
import torch.distributed as dist
|
|
from torch.distributed import ProcessGroup
|
|
|
|
|
|
class Cache:
    """
    Thread-safe memo of objects keyed by the keyword arguments used to
    build them.

    Entries live in a ``WeakValueDictionary``: an entry is dropped
    automatically once no other reference to the created object remains,
    so the cache never keeps an object alive on its own.
    """

    def __init__(self):
        self._cache: WeakValueDictionary = WeakValueDictionary()
        # Re-entrant so that a factory invoked while the lock is held may
        # itself call back into this cache from the same thread.
        self._lock = threading.RLock()

    def get_or_create(self, kwargs, func):
        """
        Return the cached object for ``kwargs``, creating it on a miss.

        Args:
            kwargs: keyword arguments describing the object; every value must
                be hashable because the sorted ``(name, value)`` pairs form
                the cache key, so argument order does not matter.
            func: factory called as ``func(**kwargs)`` when no live entry
                exists for the key. The returned object must support weak
                references.

        Returns:
            The existing live object for this key, or the newly created one.
        """
        lookup_key = tuple(sorted(kwargs.items()))

        # Lookup and creation happen under one lock so concurrent callers
        # with the same key share a single instance.
        with self._lock:
            cached = self._cache.get(lookup_key)
            if cached is not None:
                return cached
            fresh = func(**kwargs)
            self._cache[lookup_key] = fresh
            return fresh
|
|
|
|
|
|
class All2AllManagerBase:
|
|
rank: int
|
|
world_size: int
|
|
|
|
def __init__(self, cpu_group):
|
|
self.cpu_group = cpu_group
|
|
|
|
# compute some common properties
|
|
from vllm.distributed.parallel_state import (
|
|
get_dp_group,
|
|
get_tp_group,
|
|
in_the_same_node_as,
|
|
)
|
|
|
|
# all2all lives in ep group, which is merged from dp and tp group
|
|
self.dp_group = get_dp_group()
|
|
self.tp_group = get_tp_group()
|
|
|
|
# no self.ep_group since self.ep_group is still in construction
|
|
# when we create this object
|
|
self.dp_rank = self.dp_group.rank_in_group
|
|
self.dp_world_size = self.dp_group.world_size
|
|
self.rank = dist.get_rank(cpu_group)
|
|
self.world_size = dist.get_world_size(cpu_group)
|
|
|
|
# all2all communication often has separate implementations for
|
|
# intra-node and inter-node communication
|
|
self.internode = not all(in_the_same_node_as(cpu_group, source_rank=0))
|
|
|
|
def get_handle(self, kwargs):
|
|
# get a handle for the all2all communication,
|
|
# based on the kwargs.
|
|
# different layers can have different configs,
|
|
# e.g. one layer has hidden size 1024, another has 2048.
|
|
# usually the underlying implementation caches the handle
|
|
# and reuse it for the same config.
|
|
raise NotImplementedError
|
|
|
|
def dispatch(
|
|
self,
|
|
hidden_states: torch.Tensor,
|
|
router_logits: torch.Tensor,
|
|
is_sequence_parallel: bool = False,
|
|
):
|
|
raise NotImplementedError
|
|
|
|
def set_num_sms(self, num_sms: int):
|
|
pass
|
|
|
|
def max_sms_used(self) -> int | None:
|
|
return None # None means it could use the whole GPU
|
|
|
|
def combine(self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False):
|
|
raise NotImplementedError
|
|
|
|
def destroy(self):
|
|
pass
|
|
|
|
|
|
class DeviceCommunicatorBase:
|
|
"""
|
|
Base class for device-specific communicator.
|
|
It can use the `cpu_group` to initialize the communicator.
|
|
If the device has PyTorch integration (PyTorch can recognize its
|
|
communication backend), the `device_group` will also be given.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
cpu_group: ProcessGroup,
|
|
device: torch.device | None = None,
|
|
device_group: ProcessGroup | None = None,
|
|
unique_name: str = "",
|
|
):
|
|
self.device = device or torch.device("cpu")
|
|
self.cpu_group = cpu_group
|
|
self.device_group = device_group
|
|
self.unique_name = unique_name
|
|
self.rank = dist.get_rank(cpu_group)
|
|
self.world_size = dist.get_world_size(cpu_group)
|
|
self.ranks = dist.get_process_group_ranks(cpu_group)
|
|
self.global_rank = dist.get_rank()
|
|
self.global_world_size = dist.get_world_size()
|
|
self.rank_in_group = dist.get_group_rank(self.cpu_group, self.global_rank)
|
|
|
|
use_ep = False
|
|
all2all_backend = None
|
|
from vllm.config import get_current_vllm_config
|
|
|
|
config = get_current_vllm_config()
|
|
if config is not None:
|
|
# as long as we use data parallel (coupled data parallel
|
|
# where all data parallel ranks execute forward together),
|
|
# we initialize the all2all manager used in expert parallel.
|
|
use_ep = config.parallel_config.data_parallel_size > 1
|
|
all2all_backend = config.parallel_config.all2all_backend
|
|
|
|
self.is_ep_communicator = "ep" in unique_name
|
|
self.use_all2all = self.is_ep_communicator and use_ep
|
|
self.all2all_backend = all2all_backend
|
|
self.all2all_manager: All2AllManagerBase | None = None
|
|
|
|
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
|
|
dist.all_reduce(input_, group=self.device_group)
|
|
return input_
|
|
|
|
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
|
|
if dim < 0:
|
|
# Convert negative dim to positive.
|
|
dim += input_.dim()
|
|
input_size = input_.size()
|
|
# NOTE: we have to use concat-style all-gather here,
|
|
# stack-style all-gather has compatibility issues with
|
|
# torch.compile . see https://github.com/pytorch/pytorch/issues/138795
|
|
output_size = (input_size[0] * self.world_size,) + input_size[1:]
|
|
# Allocate output tensor.
|
|
output_tensor = torch.empty(
|
|
output_size, dtype=input_.dtype, device=input_.device
|
|
)
|
|
# All-gather.
|
|
dist.all_gather_into_tensor(output_tensor, input_, group=self.device_group)
|
|
# Reshape
|
|
output_tensor = output_tensor.reshape((self.world_size,) + input_size)
|
|
output_tensor = output_tensor.movedim(0, dim)
|
|
output_tensor = output_tensor.reshape(
|
|
input_size[:dim]
|
|
+ (self.world_size * input_size[dim],)
|
|
+ input_size[dim + 1 :]
|
|
)
|
|
return output_tensor
|
|
|
|
def all_gatherv(
|
|
self,
|
|
input_: torch.Tensor | list[torch.Tensor],
|
|
dim: int = 0,
|
|
sizes: list[int] | None = None,
|
|
) -> torch.Tensor | list[torch.Tensor]:
|
|
raise NotImplementedError
|
|
|
|
def reduce_scatter(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
|
|
world_size = self.world_size
|
|
# Bypass the function if we are using only 1 GPU.
|
|
if world_size == 1:
|
|
return input_
|
|
assert -input_.dim() <= dim < input_.dim(), (
|
|
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
|
|
)
|
|
|
|
if dim < 0:
|
|
# Convert negative dim to positive.
|
|
dim += input_.dim()
|
|
|
|
# Note: This will produce an incorrect answer if we don't make
|
|
# the input_tensor contiguous. Possible bug in reduce_scatter_tensor?
|
|
input_tensor = input_.movedim(0, dim).contiguous()
|
|
|
|
assert input_tensor.shape[0] % world_size == 0
|
|
chunk_size = input_tensor.shape[0] // world_size
|
|
output_shape = (chunk_size,) + input_tensor.shape[1:]
|
|
|
|
output_tensor = torch.empty(
|
|
output_shape, dtype=input_tensor.dtype, device=input_tensor.device
|
|
)
|
|
|
|
# Perform reduce-scatter operation
|
|
torch.distributed.reduce_scatter_tensor(
|
|
output_tensor, input_tensor, group=self.device_group
|
|
)
|
|
|
|
# Reshape before returning
|
|
return output_tensor.movedim(0, dim).contiguous()
|
|
|
|
def reduce_scatterv(
|
|
self, input_: torch.Tensor, dim: int = -1, sizes: list[int] | None = None
|
|
) -> torch.Tensor:
|
|
raise NotImplementedError
|
|
|
|
def gather(
|
|
self, input_: torch.Tensor, dst: int = 0, dim: int = -1
|
|
) -> torch.Tensor | None:
|
|
"""
|
|
NOTE: We assume that the input tensor is on the same device across
|
|
all the ranks.
|
|
NOTE: `dst` is the local rank of the destination rank.
|
|
"""
|
|
world_size = self.world_size
|
|
assert -input_.dim() <= dim < input_.dim(), (
|
|
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
|
|
)
|
|
if dim < 0:
|
|
# Convert negative dim to positive.
|
|
dim += input_.dim()
|
|
|
|
# Allocate output tensor.
|
|
if self.rank_in_group == dst:
|
|
gather_list = [torch.empty_like(input_) for _ in range(world_size)]
|
|
else:
|
|
gather_list = None
|
|
# Gather.
|
|
torch.distributed.gather(
|
|
input_, gather_list, dst=self.ranks[dst], group=self.device_group
|
|
)
|
|
if self.rank_in_group == dst:
|
|
output_tensor = torch.cat(gather_list, dim=dim)
|
|
else:
|
|
output_tensor = None
|
|
return output_tensor
|
|
|
|
def send(self, tensor: torch.Tensor, dst: int | None = None) -> None:
|
|
"""Sends a tensor to the destination rank in a blocking way"""
|
|
"""NOTE: `dst` is the local rank of the destination rank."""
|
|
if dst is None:
|
|
dst = (self.rank_in_group + 1) % self.world_size
|
|
torch.distributed.send(tensor, self.ranks[dst], self.device_group)
|
|
|
|
def recv(
|
|
self, size: torch.Size, dtype: torch.dtype, src: int | None = None
|
|
) -> torch.Tensor:
|
|
"""Receives a tensor from the source rank."""
|
|
"""NOTE: `src` is the local rank of the source rank."""
|
|
if src is None:
|
|
src = (self.rank_in_group - 1) % self.world_size
|
|
|
|
tensor = torch.empty(size, dtype=dtype, device=self.device)
|
|
torch.distributed.recv(tensor, self.ranks[src], self.device_group)
|
|
return tensor
|
|
|
|
def destroy(self):
|
|
pass
|
|
|
|
def prepare_communication_buffer_for_model(self, model: torch.nn.Module) -> None:
|
|
"""
|
|
Prepare the communication buffer for the model.
|
|
"""
|
|
if not self.is_ep_communicator:
|
|
return
|
|
|
|
moe_modules = [
|
|
module
|
|
for module in model.modules()
|
|
# TODO(bnell): Should use isinstance but can't. Maybe search for
|
|
# presence of quant_method.maybe_init_modular_kernel?
|
|
if (
|
|
module.__class__.__name__ == "FusedMoE"
|
|
or module.__class__.__name__ == "SharedFusedMoE"
|
|
)
|
|
]
|
|
for module in moe_modules:
|
|
module.maybe_init_modular_kernel()
|
|
|
|
def dispatch(
|
|
self,
|
|
hidden_states: torch.Tensor,
|
|
router_logits: torch.Tensor,
|
|
is_sequence_parallel: bool = False,
|
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
"""
|
|
Dispatch the hidden states and router logits to the appropriate device.
|
|
This is a no-op in the base class.
|
|
"""
|
|
return hidden_states, router_logits
|
|
|
|
def combine(
|
|
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
|
|
) -> torch.Tensor:
|
|
"""
|
|
Combine the hidden states and router logits from the appropriate device.
|
|
This is a no-op in the base class.
|
|
"""
|
|
return hidden_states
|