mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-10 10:57:03 +08:00
94 lines
3.3 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
import torch
|
|
|
|
from vllm.forward_context import get_forward_context
|
|
|
|
|
|
class All2AllBase:
    """Abstract base for expert-parallel all2all communication backends.

    Caches the dp/tp/ep process groups and this rank's position within the
    data-parallel group, and records whether all ranks of ``cpu_group``
    share a single node, since concrete backends often pick different
    code paths for intra-node vs. inter-node transfers.
    """

    def __init__(self, cpu_group, model):
        self.cpu_group = cpu_group

        # Imported lazily here (rather than at module scope) so that this
        # module does not depend on parallel_state at import time.
        from vllm.distributed.parallel_state import (get_dp_group,
                                                     get_ep_group,
                                                     get_tp_group,
                                                     in_the_same_node_as)

        # all2all lives in the ep group, which is merged from the dp and
        # tp groups.
        dp = get_dp_group()
        self.dp_group = dp
        self.tp_group = get_tp_group()
        self.ep_group = get_ep_group()
        self.dp_rank, self.dp_world_size = dp.rank_in_group, dp.world_size

        # all2all communication often has separate implementations for
        # intra-node and inter-node communication.
        # NOTE(review): this assumes in_the_same_node_as() yields a
        # truthy/falsy answer to "do all ranks share a node" -- if it
        # returns a per-rank list, `not` would misbehave; confirm upstream.
        same_node = in_the_same_node_as(cpu_group, source_rank=0)
        self.intranode = same_node
        self.internode = not same_node

    def dispatch(self, hidden_states: torch.Tensor,
                 router_logits: torch.Tensor):
        """Scatter/gather tokens and logits across ranks; subclass hook."""
        raise NotImplementedError

    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Recombine expert outputs for this rank; subclass hook."""
        raise NotImplementedError

    def destroy(self):
        """Release any backend resources; no-op by default."""
        pass
class NaiveAll2All(All2AllBase):
    """All-reduce-backed all2all implementation.

    Every rank materializes the full cross-DP token buffer through
    per-rank broadcasts (dispatch) and an all-reduce (combine). This is
    correct but not efficient at all; its purpose is testing and
    debugging.
    """

    def __init__(self, cpu_group, model):
        super().__init__(cpu_group, model)

    def naive_multicast(self, x: torch.Tensor,
                        cu_tokens_across_dp_cpu: torch.Tensor):
        """Replicate every DP rank's rows of ``x`` onto all ranks.

        ``cu_tokens_across_dp_cpu`` holds cumulative token counts per DP
        rank; the returned buffer concatenates each rank's rows in rank
        order, totalling ``cu_tokens_across_dp_cpu[-1]`` rows.
        """
        assert (x.ndim == 2)
        total_tokens = cu_tokens_across_dp_cpu[-1]
        buffer = torch.empty((total_tokens, x.size(1)),
                             dtype=x.dtype,
                             device=x.device)

        def row_range(rank):
            # Row span owned by `rank` in the concatenated buffer.
            lo = cu_tokens_across_dp_cpu[rank - 1] if rank > 0 else 0
            hi = cu_tokens_across_dp_cpu[rank]
            return lo, hi

        # Place our own tokens, then let every rank (ourselves included)
        # broadcast its slice so all ranks end up with the full buffer.
        lo, hi = row_range(self.dp_rank)
        buffer[lo:hi, :].copy_(x)
        for src_rank in range(self.dp_world_size):
            lo, hi = row_range(src_rank)
            self.dp_group.broadcast(buffer[lo:hi, :], src_rank)

        return buffer

    def dispatch(self, hidden_states: torch.Tensor,
                 router_logits: torch.Tensor):
        """Gather hidden states and router logits from every DP rank."""
        cu_tokens = get_forward_context(
        ).dp_metadata.cu_tokens_across_dp_cpu

        return (self.naive_multicast(hidden_states, cu_tokens),
                self.naive_multicast(router_logits, cu_tokens))

    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """All-reduce across DP, then keep only this rank's own rows."""
        cu_tokens = get_forward_context(
        ).dp_metadata.cu_tokens_across_dp_cpu
        lo = cu_tokens[self.dp_rank - 1] if self.dp_rank > 0 else 0
        hi = cu_tokens[self.dp_rank]

        reduced = self.dp_group.all_reduce(hidden_states)
        return reduced[lo:hi, :]

    def destroy(self):
        """Nothing to release for the naive backend."""
        pass