# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any

import torch
import torch.distributed as dist

import vllm.envs as envs
from vllm.distributed import get_dp_group, get_ep_group
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.utils.flashinfer import has_flashinfer_all2all
from vllm.utils.import_utils import has_deep_ep, has_pplx

from .base_device_communicator import All2AllManagerBase, Cache

if has_flashinfer_all2all():
    from flashinfer.comm import Mapping  # type: ignore[import-not-found]
    from flashinfer.comm.mnnvl import MnnvlConfig  # type: ignore[import-not-found]
    from flashinfer.comm.trtllm_alltoall import (
        MnnvlMoe,  # type: ignore[import-not-found]
    )

logger = init_logger(__name__)

class NaiveAll2AllManager(All2AllManagerBase):
    """
    A naive implementation of all2all communication.
    It uses broadcast (dispatch) and all-reduce (combine) under the
    hood, which is not efficient at all. Its main purpose is testing
    and debugging.
    """

    def __init__(self, cpu_group):
        super().__init__(cpu_group)

    def naive_multicast(
        self,
        x: torch.Tensor,
        cu_tokens_across_sp_cpu: torch.Tensor,
        is_sequence_parallel: bool,
    ) -> torch.Tensor:
        assert len(x.shape) == 2
        buffer = torch.empty(
            (cu_tokens_across_sp_cpu[-1], x.size(1)), device=x.device, dtype=x.dtype
        )

        rank = self.rank if is_sequence_parallel else self.dp_rank
        world_size = self.world_size if is_sequence_parallel else self.dp_world_size

        start = 0 if rank == 0 else cu_tokens_across_sp_cpu[rank - 1]
        end = cu_tokens_across_sp_cpu[rank]
        buffer[start:end, :].copy_(x)
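        # The loop below has each source rank broadcast its slice, so after it
        # completes every rank holds the concatenated tokens of all ranks.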
        for idx in range(world_size):
            start = 0 if idx == 0 else cu_tokens_across_sp_cpu[idx - 1]
            end = cu_tokens_across_sp_cpu[idx]
            get_ep_group().broadcast(buffer[start:end, :], idx)

        return buffer

    def dispatch(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        is_sequence_parallel: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        sp_size = self.tp_group.world_size if is_sequence_parallel else 1
        dp_metadata = get_forward_context().dp_metadata
        assert dp_metadata is not None
        cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)
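        # Replicate this rank's tokens and router logits to every rank; the
        # cumulative token counts tell each rank where its slice lives in the
        # gathered buffer.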
        hidden_states = self.naive_multicast(
            hidden_states, cu_tokens_across_sp_cpu, is_sequence_parallel
        )
        router_logits = self.naive_multicast(
            router_logits, cu_tokens_across_sp_cpu, is_sequence_parallel
        )
        return hidden_states, router_logits

    def combine(
        self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
    ) -> torch.Tensor:
        ep_rank = self.rank if is_sequence_parallel else self.dp_rank

        dp_metadata = get_forward_context().dp_metadata
        assert dp_metadata is not None
        sp_size = self.tp_group.world_size if is_sequence_parallel else 1
        cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)

        start = 0 if ep_rank == 0 else cu_tokens_across_sp_cpu[ep_rank - 1]
        end = cu_tokens_across_sp_cpu[ep_rank]
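        # All-reduce the full gathered buffer across the EP group, then keep
        # only the rows that belong to this rank.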
        all_hidden_states = get_ep_group().all_reduce(hidden_states)
        hidden_states = all_hidden_states[start:end, :]
        return hidden_states

    def destroy(self):
        pass

class AgRsAll2AllManager(All2AllManagerBase):
    """
    An implementation of all2all communication based on
    all-gather (dispatch) and reduce-scatter (combine).
    """

    def __init__(self, cpu_group):
        super().__init__(cpu_group)

    def dispatch(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        is_sequence_parallel: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Gather hidden_states and router_logits from all dp ranks.
        """
        dp_metadata = get_forward_context().dp_metadata
        assert dp_metadata is not None
        sizes = dp_metadata.get_chunk_sizes_across_dp_rank()
        assert sizes is not None

        dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
        assert sizes[dist_group.rank_in_group] == hidden_states.shape[0]
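        # Variable-length all-gather: ranks may hold different numbers of
        # tokens, so the per-rank sizes are passed explicitly.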
        hidden_states, router_logits = dist_group.all_gatherv(
            [hidden_states, router_logits],
            dim=0,
            sizes=sizes,
        )
        return hidden_states, router_logits

    def combine(
        self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
    ) -> torch.Tensor:
        """
        Reduce-scatter hidden_states across all dp ranks.
        """
        dp_metadata = get_forward_context().dp_metadata
        assert dp_metadata is not None
        sizes = dp_metadata.get_chunk_sizes_across_dp_rank()
        assert sizes is not None

        dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
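        # Sum the expert outputs across ranks and scatter back so each rank
        # keeps only its own chunk of tokens.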
        hidden_states = dist_group.reduce_scatterv(hidden_states, dim=0, sizes=sizes)
        return hidden_states

    def destroy(self):
        pass
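
# Illustrative usage sketch (not part of the original module): a MoE layer
# would typically call dispatch() before running its experts and combine()
# afterwards. `cpu_group` and `experts` below are hypothetical placeholders.
#
#   manager = AgRsAll2AllManager(cpu_group)
#   hidden_states, router_logits = manager.dispatch(hidden_states, router_logits)
#   expert_output = experts(hidden_states, router_logits)
#   hidden_states = manager.combine(expert_output)
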
class PPLXAll2AllManager(All2AllManagerBase):
    """
    All2All communication based on PPLX kernels.
    """

    def __init__(self, cpu_group):
        assert has_pplx(), (
            "pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md"
            " to install pplx_kernels."
        )
        super().__init__(cpu_group)

        if self.internode:
            # inter-node communication needs nvshmem,
            # intra-node communication uses p2p mapping directly
            from pplx_kernels.nvshmem import (  # type: ignore[import-not-found]
                nvshmem_alloc_empty_unique_id,
                nvshmem_get_unique_id,
                nvshmem_init,
            )

            logger.debug(
                "Initialize NVSHMEM for pplx_kernels: rank=%d, world size=%d",
                self.rank,
                self.world_size,
            )
            uid = (
                nvshmem_get_unique_id()
                if self.rank == 0
                else nvshmem_alloc_empty_unique_id()
            )
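            # Only rank 0 generates a real NVSHMEM unique id; the other ranks
            # allocate an empty placeholder and receive the value via the CPU
            # process-group broadcast below.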
            dist.broadcast(
                uid,
                src=dist.get_process_group_ranks(self.cpu_group)[0],
                group=self.cpu_group,
            )
            logger.debug("PPLX NVSHMEM UID = %s", uid)
            nvshmem_init(uid, self.rank, self.world_size)

        self.handle_cache = Cache()

    def get_handle(self, kwargs):
        import pplx_kernels as pplx  # type: ignore[import-not-found]
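        # Handles are cached per kwargs, so layers requesting the same
        # configuration reuse a single AllToAll object.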
        return self.handle_cache.get_or_create(
            kwargs,
            pplx.AllToAll.internode if self.internode else pplx.AllToAll.intranode,
        )

    def dispatch(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        is_sequence_parallel: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        raise NotImplementedError

    def combine(
        self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
    ) -> torch.Tensor:
        raise NotImplementedError

    def destroy(self):
        with self.handle_cache._lock:
            for _, handle in self.handle_cache._cache.items():
                handle.destroy()

        if self.internode:
            from pplx_kernels.nvshmem import (
                nvshmem_finalize,  # type: ignore[import-not-found]
            )

            logger.debug("PPLX NVSHMEM finalize")
            nvshmem_finalize()

class DeepEPAll2AllManagerBase(All2AllManagerBase):
    """
    Base class for All2All communication based on DeepEP kernels,
    shared by the High-Throughput and Low-Latency variants.
    """

    def __init__(self, cpu_group):
        assert has_deep_ep(), (
            "DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md"
            " to install DeepEP kernels."
        )  # noqa
        super().__init__(cpu_group)
        self.handle_cache = Cache()

        # This is the DeepEP default. Stick to it till we can establish
        # reasonable defaults based on profiling.
        self.num_sms = 20

    def get_handle(self, kwargs):
        raise NotImplementedError

    def dispatch(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        is_sequence_parallel: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        raise NotImplementedError

    def combine(
        self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
    ) -> torch.Tensor:
        raise NotImplementedError

    def destroy(self):
        pass

class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
    """
    All2All communication based on DeepEP High-Throughput kernels.
    """

    def __init__(self, cpu_group):
        super().__init__(cpu_group)

    def _make_all2all_kwargs(self) -> dict[Any, Any]:
        # Defaults for internode and intranode are taken from DeepEP tests.
        num_nvl_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024
        num_rdma_bytes = None
        num_qps_per_rank = None

        if self.internode and not envs.VLLM_DEEPEP_HIGH_THROUGHPUT_FORCE_INTRA_NODE:
            num_rdma_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024
            num_qps_per_rank = self.num_sms // 2
        else:
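            # Intra-node (or forced intra-node) communication stays on NVLink,
            # so no RDMA buffer is allocated.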
            num_rdma_bytes = 0
            num_qps_per_rank = 1

        assert num_rdma_bytes is not None
        assert num_qps_per_rank is not None
        return dict(
            group=self.cpu_group,
            num_nvl_bytes=num_nvl_bytes,
            num_rdma_bytes=num_rdma_bytes,
            low_latency_mode=False,
            num_qps_per_rank=num_qps_per_rank,
        )

    def get_handle(self, kwargs):
        assert len(kwargs) == 0, (
            "DeepEPHTAll2AllManager expects no arguments. All the required "
            "args are computed in the Manager itself."
        )

        import deep_ep  # type: ignore[import-not-found]

        buffer_kwargs = self._make_all2all_kwargs()
        logger.debug("DeepEP all2all args %s", buffer_kwargs)
        handle: deep_ep.Buffer = self.handle_cache.get_or_create(
            buffer_kwargs, deep_ep.Buffer
        )
        return handle

    def set_num_sms(self, num_sms: int):
        import deep_ep  # type: ignore[import-not-found]

        # Right now the buffers are sized only for what the kernels were
        # created with, so we can only reduce the number of SMs used,
        # not increase it.
        if num_sms > self.num_sms:
            num_sms = self.num_sms
        deep_ep.Buffer.set_num_sms(num_sms)

class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
    """
    All2All communication based on DeepEP Low-Latency kernels.
    """

    def __init__(self, cpu_group):
        super().__init__(cpu_group)

    def _make_all2all_kwargs(
        self,
        max_num_tokens_per_dp_rank: int,
        token_hidden_size: int,
        num_ep_ranks: int,
        num_global_experts: int,
        num_local_experts: int,
    ) -> dict[Any, Any]:
        """
        max_num_tokens_per_dp_rank: the maximum number of tokens a DP rank
            can dispatch; all ranks must hold the same value.
        token_hidden_size: the hidden dimension of each token.
        num_ep_ranks: the number of EP group ranks.
        num_global_experts: the number of experts in the model.
        num_local_experts: the number of experts on an EP rank.
        """
        import deep_ep  # type: ignore[import-not-found]

        # Defaults for internode and intranode are taken from DeepEP tests.
        num_nvl_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024
        num_qps_per_rank = num_local_experts
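        # Ask DeepEP for the RDMA buffer size required by low-latency dispatch
        # at this token count, hidden size, and expert layout.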
        num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
            num_max_dispatch_tokens_per_rank=max_num_tokens_per_dp_rank,
            hidden=token_hidden_size,
            num_ranks=num_ep_ranks,
            num_experts=num_global_experts,
        )

        assert num_rdma_bytes is not None
        return dict(
            group=self.cpu_group,
            num_nvl_bytes=num_nvl_bytes,
            num_rdma_bytes=num_rdma_bytes,
            low_latency_mode=True,
            num_qps_per_rank=num_qps_per_rank,
            allow_nvlink_for_low_latency_mode=True,
            allow_mnnvl=envs.VLLM_DEEPEP_LOW_LATENCY_USE_MNNVL,
        )

    def get_handle(self, kwargs):
        """
        The kwargs for DeepEPLLAll2AllManager are dictated by
        _make_all2all_kwargs.
        """
        import deep_ep  # type: ignore[import-not-found]

        buffer_kwargs = self._make_all2all_kwargs(**kwargs)
        logger.debug("DeepEP all2all args %s", buffer_kwargs)
        handle: deep_ep.Buffer = self.handle_cache.get_or_create(
            buffer_kwargs, deep_ep.Buffer
        )
        return handle

    # DeepEP LL uses RDMA, so no SMs are used for communication.
    def max_sms_used(self) -> int | None:
        return 0

class FlashInferAllToAllManager(All2AllManagerBase):
    """
    All2All communication based on flashinfer kernels.
    """

    # These type annotations can be removed once the work in
    # https://github.com/vllm-project/vllm/issues/26533 is done.
    rank: int
    world_size: int

    def __init__(self, cpu_group):
        assert has_flashinfer_all2all(), (
            "flashinfer all2all module not found. Please install/check flashinfer"
        )  # noqa
        super().__init__(cpu_group)
        logger.debug(
            "Initialize for flashinfer All2All rank=%d, world size=%d",
            self.rank,
            self.world_size,
        )
        self.initialized = False
        self.alltoall_info = None

    def initialize(
        self,
        world_size: int,
        rank: int,
        gpus_per_node: int,
    ):
        """Initialize workspace"""
        if self.initialized:
            return

        self.cleanup()
        logger.debug("making map: rank=%d, world size=%d", rank, world_size)
        self.mapping = Mapping(
            world_size,
            rank,
            gpus_per_node,
            tp_size=world_size,
        )

        from vllm.distributed.device_communicators.mnnvl_compat import (
            CustomCommunicator,
        )

        dp_config = MnnvlConfig(
            comm_backend=CustomCommunicator(get_dp_group().cpu_group),
            fabric_page_size=1 << 29,  # 512MB
            allocation_granularity=0,  # Auto-detect
        )
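        # Allocate the MNNVL-backed workspace tensors that back the actual
        # all2all transfers.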
        self.workspace_tensor = MnnvlMoe.get_moe_workspaces(self.mapping, dp_config)
        self.prepare_workspace_tensor = MnnvlMoe.get_moe_prepare_workspace(
            self.mapping, dp_config
        )

        self.world_size = world_size
        self.rank = rank
        self.gpus_per_node = gpus_per_node
        self.initialized = True

        logger.info(
            "FlashInfer All2All initialized for rank %s, size %s", rank, world_size
        )

    def ensure_alltoall_workspace_initialized(self):
        """Ensure workspace is initialized"""
        if not has_flashinfer_all2all():
            return False

        if self.world_size <= 1:
            return False
        if not self.initialized:
            self.initialize(
                world_size=self.world_size,
                rank=self.rank,
                gpus_per_node=torch.cuda.device_count(),
            )
        return self.initialized
    def get_handle(self, kwargs):
        return self

    def cleanup(self):
        """Clean up workspace"""
        if (
            self.initialized
            and self.workspace_tensor is not None
            and self.prepare_workspace_tensor is not None
        ):
            try:
                del self.workspace_tensor
                del self.prepare_workspace_tensor
            except Exception as e:
                logger.warning("Failed to cleanup FlashInfer workspace: %s", e)
            finally:
                self.workspace_tensor = None
                self.prepare_workspace_tensor = None
                self.mapping = None
                self.initialized = False