mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-13 00:37:03 +08:00
Refactor pplx init logic to make it modular (prepare for deepep) (#18200)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
parent
022d8abe29
commit
6a7988c55b
@ -1,44 +1,24 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import importlib.util
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .base_device_communicator import All2AllManagerBase, Cache
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
|
||||
else:
|
||||
FusedMoE = None
|
||||
|
||||
|
||||
class All2AllBase:
|
||||
|
||||
def __init__(self, cpu_group, model):
|
||||
self.cpu_group = cpu_group
|
||||
|
||||
# compute some common properties
|
||||
from vllm.distributed.parallel_state import (get_dp_group,
|
||||
get_ep_group,
|
||||
get_tp_group,
|
||||
in_the_same_node_as)
|
||||
|
||||
# all2all lives in ep group, which is merged from dp and tp group
|
||||
self.dp_group = get_dp_group()
|
||||
self.tp_group = get_tp_group()
|
||||
self.ep_group = get_ep_group()
|
||||
self.dp_rank = self.dp_group.rank_in_group
|
||||
self.dp_world_size = self.dp_group.world_size
|
||||
|
||||
# all2all communication often has separate implementations for
|
||||
# intra-node and inter-node communication
|
||||
self.intranode = in_the_same_node_as(cpu_group, source_rank=0)
|
||||
self.internode = not self.intranode
|
||||
|
||||
def dispatch(self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor):
|
||||
raise NotImplementedError
|
||||
|
||||
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def destroy(self):
|
||||
pass
|
||||
|
||||
|
||||
class NaiveAll2All(All2AllBase):
|
||||
class NaiveAll2AllManager(All2AllManagerBase):
|
||||
"""
|
||||
A naive implementation of all2all communication.
|
||||
It uses all-reduce under the hood, which is not
|
||||
@ -46,8 +26,8 @@ class NaiveAll2All(All2AllBase):
|
||||
debugging.
|
||||
"""
|
||||
|
||||
def __init__(self, cpu_group, model):
|
||||
super().__init__(cpu_group, model)
|
||||
def __init__(self, cpu_group):
|
||||
super().__init__(cpu_group)
|
||||
|
||||
def naive_multicast(self, x: torch.Tensor,
|
||||
cu_tokens_across_dp_cpu: torch.Tensor):
|
||||
@ -91,3 +71,56 @@ class NaiveAll2All(All2AllBase):
|
||||
|
||||
def destroy(self):
|
||||
pass
|
||||
|
||||
|
||||
class PPLXAll2AllManager(All2AllManagerBase):
|
||||
"""
|
||||
All2All communication based on PPLX kernels.
|
||||
"""
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
has_pplx = importlib.util.find_spec("pplx_kernels") is not None
|
||||
assert has_pplx, "pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install pplx_kernels." # noqa
|
||||
super().__init__(cpu_group)
|
||||
|
||||
if self.internode:
|
||||
# inter-node communication needs nvshmem,
|
||||
# intra-node communication uses p2p mapping directly
|
||||
from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
|
||||
nvshmem_get_unique_id,
|
||||
nvshmem_init)
|
||||
logger.debug(
|
||||
"Initialize NVSHMEM for pplx_kernels: "
|
||||
"rank=%d, world size=%d", self.rank, self.world_size)
|
||||
uid = nvshmem_get_unique_id(
|
||||
) if self.rank == 0 else nvshmem_alloc_empty_unique_id()
|
||||
dist.broadcast(uid,
|
||||
src=dist.get_process_group_ranks(self.cpu_group)[0],
|
||||
group=self.cpu_group)
|
||||
logger.debug("PPLX NVSHMEM UID = %s", uid)
|
||||
nvshmem_init(uid, self.rank, self.world_size)
|
||||
|
||||
self.handle_cache = Cache()
|
||||
|
||||
def get_handle(self, kwargs):
|
||||
import pplx_kernels as pplx
|
||||
return self.handle_cache.get_or_create(
|
||||
kwargs, pplx.AllToAll.internode
|
||||
if self.internode else pplx.AllToAll.intranode)
|
||||
|
||||
def dispatch(self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor):
|
||||
raise NotImplementedError
|
||||
|
||||
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def destroy(self):
|
||||
with self.handle_cache._lock:
|
||||
for _, handle in self.handle_cache._cache.items():
|
||||
handle.destroy()
|
||||
|
||||
if self.internode:
|
||||
from pplx_kernels.nvshmem import nvshmem_finalize
|
||||
logger.debug("PPLX NVSHMEM finalize")
|
||||
nvshmem_finalize()
|
||||
|
||||
@ -1,11 +1,76 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import threading
|
||||
from typing import Optional
|
||||
from weakref import WeakValueDictionary
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
|
||||
class Cache:
|
||||
|
||||
def __init__(self):
|
||||
self._cache: WeakValueDictionary = WeakValueDictionary()
|
||||
self._lock = threading.RLock() # Reentrant lock for thread safety
|
||||
|
||||
def get_or_create(self, kwargs, func):
|
||||
# Create a hashable key from the kwargs
|
||||
key = tuple(sorted((k, v) for k, v in kwargs.items()))
|
||||
|
||||
with self._lock:
|
||||
instance = self._cache.get(key)
|
||||
if instance is None:
|
||||
instance = func(**kwargs)
|
||||
self._cache[key] = instance
|
||||
return instance
|
||||
|
||||
|
||||
class All2AllManagerBase:
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
self.cpu_group = cpu_group
|
||||
|
||||
# compute some common properties
|
||||
from vllm.distributed.parallel_state import (get_dp_group,
|
||||
get_tp_group,
|
||||
in_the_same_node_as)
|
||||
|
||||
# all2all lives in ep group, which is merged from dp and tp group
|
||||
self.dp_group = get_dp_group()
|
||||
self.tp_group = get_tp_group()
|
||||
# no self.ep_group since self.ep_group is still in construction
|
||||
# when we create this object
|
||||
self.dp_rank = self.dp_group.rank_in_group
|
||||
self.dp_world_size = self.dp_group.world_size
|
||||
self.rank = dist.get_rank(cpu_group)
|
||||
self.world_size = dist.get_world_size(cpu_group)
|
||||
|
||||
# all2all communication often has separate implementations for
|
||||
# intra-node and inter-node communication
|
||||
self.intranode = in_the_same_node_as(cpu_group, source_rank=0)
|
||||
self.internode = not self.intranode
|
||||
|
||||
def get_handle(self, kwargs):
|
||||
# get a handle for the all2all communication,
|
||||
# based on the kwargs.
|
||||
# different layers can have different configs,
|
||||
# e.g. one layer has hidden size 1024, another has 2048.
|
||||
# usually the underlying implementation caches the handle
|
||||
# and reuse it for the same config.
|
||||
raise NotImplementedError
|
||||
|
||||
def dispatch(self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor):
|
||||
raise NotImplementedError
|
||||
|
||||
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def destroy(self):
|
||||
pass
|
||||
|
||||
|
||||
class DeviceCommunicatorBase:
|
||||
"""
|
||||
Base class for device-specific communicator.
|
||||
@ -31,6 +96,18 @@ class DeviceCommunicatorBase:
|
||||
self.rank_in_group = dist.get_group_rank(self.cpu_group,
|
||||
self.global_rank)
|
||||
|
||||
use_ep = False
|
||||
from vllm.config import get_current_vllm_config
|
||||
config = get_current_vllm_config()
|
||||
if config is not None:
|
||||
# as long as we use data parallel (coupled data parallel
|
||||
# where all data parallel ranks execute forward together),
|
||||
# we initialize the all2all manager used in expert parallel.
|
||||
use_ep = config.parallel_config.data_parallel_size > 1
|
||||
|
||||
self.use_all2all = "ep" in unique_name and use_ep
|
||||
self.all2all_manager: Optional[All2AllManagerBase] = None
|
||||
|
||||
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
|
||||
dist.all_reduce(input_, group=self.device_group)
|
||||
return input_
|
||||
@ -154,9 +231,17 @@ class DeviceCommunicatorBase:
|
||||
model: torch.nn.Module) -> None:
|
||||
"""
|
||||
Prepare the communication buffer for the model.
|
||||
This is a no-op in the base class.
|
||||
"""
|
||||
pass
|
||||
if not self.use_all2all:
|
||||
return
|
||||
|
||||
moe_modules = [
|
||||
module for module in model.modules()
|
||||
if module.__class__.__name__ == "FusedMoE"
|
||||
]
|
||||
for module in moe_modules:
|
||||
module.quant_method.init_prepare_finalize(module.moe_config,
|
||||
module.quant_config)
|
||||
|
||||
def dispatch(
|
||||
self, hidden_states: torch.Tensor,
|
||||
|
||||
@ -6,10 +6,12 @@ import torch
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .all2all import All2AllBase
|
||||
from .base_device_communicator import DeviceCommunicatorBase
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class CudaCommunicator(DeviceCommunicatorBase):
|
||||
|
||||
@ -31,8 +33,6 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
use_pynccl = "ep" not in unique_name
|
||||
|
||||
self.use_pynccl = use_pynccl
|
||||
self.use_all2all = "ep" in unique_name
|
||||
self.all2all_impl: Optional[All2AllBase] = None
|
||||
self.use_custom_allreduce = use_custom_allreduce
|
||||
|
||||
# lazy import to avoid documentation build error
|
||||
@ -56,6 +56,19 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
if self.use_all2all:
|
||||
all2all_backend = envs.VLLM_ALL2ALL_BACKEND
|
||||
if all2all_backend == "naive":
|
||||
from .all2all import NaiveAll2AllManager
|
||||
self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
|
||||
logger.info("Using naive all2all manager.")
|
||||
elif all2all_backend == "pplx":
|
||||
from .all2all import PPLXAll2AllManager
|
||||
self.all2all_manager = PPLXAll2AllManager(self.cpu_group)
|
||||
logger.info("Using PPLX all2all manager.")
|
||||
else:
|
||||
raise ValueError(f"Unknown all2all backend: {all2all_backend}")
|
||||
|
||||
def all_reduce(self, input_):
|
||||
# always try custom allreduce first,
|
||||
# and then pynccl.
|
||||
@ -136,31 +149,19 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
self.pynccl_comm = None
|
||||
if self.ca_comm is not None:
|
||||
self.ca_comm = None
|
||||
if self.all2all_impl is not None:
|
||||
self.all2all_impl.destroy()
|
||||
self.all2all_impl = None
|
||||
|
||||
def prepare_communication_buffer_for_model(self,
|
||||
model: torch.nn.Module) -> None:
|
||||
"""
|
||||
Prepare the communication buffer for the model.
|
||||
"""
|
||||
if not self.use_all2all:
|
||||
return
|
||||
all2all_backend = envs.VLLM_ALL2ALL_BACKEND
|
||||
if all2all_backend == "naive":
|
||||
from .all2all import NaiveAll2All
|
||||
self.all2all_impl = NaiveAll2All(self.cpu_group, model)
|
||||
if self.all2all_manager is not None:
|
||||
self.all2all_manager.destroy()
|
||||
self.all2all_manager = None
|
||||
|
||||
def dispatch(
|
||||
self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
assert self.all2all_impl is not None
|
||||
hidden_states, router_logits = self.all2all_impl.dispatch(
|
||||
assert self.all2all_manager is not None
|
||||
hidden_states, router_logits = self.all2all_manager.dispatch(
|
||||
hidden_states, router_logits)
|
||||
return hidden_states, router_logits
|
||||
|
||||
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
assert self.all2all_impl is not None
|
||||
hidden_states = self.all2all_impl.combine(hidden_states)
|
||||
assert self.all2all_manager is not None
|
||||
hidden_states = self.all2all_manager.combine(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
@ -23,7 +23,6 @@ If you only need to use the distributed environment without model/pipeline
|
||||
"""
|
||||
import contextlib
|
||||
import gc
|
||||
import importlib.util
|
||||
import pickle
|
||||
import weakref
|
||||
from collections import namedtuple
|
||||
@ -43,7 +42,7 @@ from vllm.distributed.device_communicators.base_device_communicator import (
|
||||
from vllm.distributed.utils import StatelessProcessGroup
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import (direct_register_custom_op, resolve_obj_by_qualname,
|
||||
run_once, supports_custom_op)
|
||||
supports_custom_op)
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -791,10 +790,14 @@ class GroupCoordinator:
|
||||
if self.device_communicator is not None:
|
||||
return self.device_communicator.dispatch(hidden_states,
|
||||
router_logits)
|
||||
else:
|
||||
return hidden_states, router_logits
|
||||
|
||||
def combine(self, hidden_states) -> torch.Tensor:
|
||||
if self.device_communicator is not None:
|
||||
return self.device_communicator.combine(hidden_states)
|
||||
else:
|
||||
return hidden_states
|
||||
|
||||
|
||||
_WORLD: Optional[GroupCoordinator] = None
|
||||
@ -959,49 +962,9 @@ def init_distributed_environment(
|
||||
"world group already initialized with a different world size")
|
||||
|
||||
|
||||
PPLX_DID_INIT: bool = False
|
||||
|
||||
|
||||
@run_once
|
||||
def pplx_init(rank, world_size):
|
||||
has_pplx = importlib.util.find_spec("pplx_kernels") is not None
|
||||
|
||||
if has_pplx and world_size > 1:
|
||||
from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
|
||||
nvshmem_get_unique_id, nvshmem_init)
|
||||
try:
|
||||
global PPLX_DID_INIT
|
||||
logger.debug(
|
||||
"Initialize NVSHMEM for PPLX kernels: rank=%d, "
|
||||
"world size=%d", rank, world_size)
|
||||
uid = nvshmem_get_unique_id(
|
||||
) if rank == 0 else nvshmem_alloc_empty_unique_id()
|
||||
uid_gpu = uid.cuda()
|
||||
get_world_group().broadcast(uid_gpu, src=0)
|
||||
uid = uid_gpu.to(device='cpu')
|
||||
logger.debug("PPLX NVSHMEM UID = %s", uid)
|
||||
nvshmem_init(uid, rank, world_size)
|
||||
PPLX_DID_INIT = True
|
||||
except Exception as ex:
|
||||
logger.error("Failed to initialize NVSHMEM for PPLX: %s", ex)
|
||||
|
||||
|
||||
@run_once
|
||||
def pplx_finalize():
|
||||
global PPLX_DID_INIT
|
||||
if PPLX_DID_INIT:
|
||||
from pplx_kernels.nvshmem import nvshmem_finalize
|
||||
logger.debug("PPLX NVSHMEM finalize")
|
||||
from vllm.model_executor.layers.fused_moe.layer import (
|
||||
_all_to_all_cache)
|
||||
_all_to_all_cache.destroy()
|
||||
nvshmem_finalize()
|
||||
|
||||
|
||||
def initialize_model_parallel(
|
||||
tensor_model_parallel_size: int = 1,
|
||||
pipeline_model_parallel_size: int = 1,
|
||||
enable_expert_parallel: bool = False,
|
||||
backend: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
@ -1104,14 +1067,10 @@ def initialize_model_parallel(
|
||||
_DP.rank_in_group, _PP.rank_in_group, _TP.rank_in_group,
|
||||
_EP.rank_in_group)
|
||||
|
||||
if enable_expert_parallel:
|
||||
pplx_init(rank, world_size)
|
||||
|
||||
|
||||
def ensure_model_parallel_initialized(
|
||||
tensor_model_parallel_size: int,
|
||||
pipeline_model_parallel_size: int,
|
||||
enable_expert_parallel: bool = False,
|
||||
backend: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Helper to initialize model parallel groups if they are not initialized,
|
||||
@ -1122,8 +1081,7 @@ def ensure_model_parallel_initialized(
|
||||
get_world_group().device_group)
|
||||
if not model_parallel_is_initialized():
|
||||
initialize_model_parallel(tensor_model_parallel_size,
|
||||
pipeline_model_parallel_size,
|
||||
enable_expert_parallel, backend)
|
||||
pipeline_model_parallel_size, backend)
|
||||
return
|
||||
|
||||
assert (
|
||||
@ -1202,8 +1160,6 @@ def destroy_model_parallel():
|
||||
"""Set the groups to none and destroy them."""
|
||||
global _TP
|
||||
|
||||
pplx_finalize()
|
||||
|
||||
if _TP:
|
||||
_TP.destroy()
|
||||
_TP = None
|
||||
|
||||
@ -809,6 +809,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
lambda: int(os.getenv("VLLM_NIXL_SIDE_CHANNEL_PORT", "5557")),
|
||||
|
||||
# all2all backend for vllm's expert parallel communication
|
||||
# Available options:
|
||||
# - "naive": naive all2all implementation using all-reduce
|
||||
# - "pplx": use pplx kernels
|
||||
"VLLM_ALL2ALL_BACKEND":
|
||||
lambda: os.getenv("VLLM_ALL2ALL_BACKEND", "naive"),
|
||||
}
|
||||
|
||||
@ -1,12 +1,10 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import importlib
|
||||
import threading
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Callable, Optional
|
||||
from weakref import WeakValueDictionary
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@ -73,7 +71,8 @@ class FusedMoEParallelConfig:
|
||||
|
||||
@property
|
||||
def use_pplx_kernels(self):
|
||||
return self.dp_size > 1 and self.use_ep and has_pplx
|
||||
return self.dp_size > 1 and self.use_ep and \
|
||||
envs.VLLM_ALL2ALL_BACKEND == "pplx"
|
||||
|
||||
@staticmethod
|
||||
def make(tp_size_: int, dp_size_: int,
|
||||
@ -196,6 +195,8 @@ class MoEConfig:
|
||||
# TODO: add more quantization params, blocked, per-token, etc.
|
||||
block_size: int = 128
|
||||
|
||||
max_num_tokens: int = MOE_DP_CHUNK_SIZE
|
||||
|
||||
@property
|
||||
def tp_size(self):
|
||||
return self.moe_parallel_config.tp_size
|
||||
@ -244,13 +245,59 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
params_dtype: torch.dtype, **extra_weight_attrs):
|
||||
raise NotImplementedError
|
||||
|
||||
def set_prepare_finalize(
|
||||
self,
|
||||
dp_size: int,
|
||||
world_size: int,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||
) -> bool:
|
||||
return False
|
||||
def init_prepare_finalize(self, moe: MoEConfig,
|
||||
quant_config: Optional[QuantizationConfig]):
|
||||
all2all_manager = get_ep_group().device_communicator.all2all_manager
|
||||
assert all2all_manager is not None
|
||||
|
||||
prepare_finalize = None
|
||||
if moe.use_pplx_kernels:
|
||||
all_to_all_args = dict(
|
||||
max_num_tokens=moe.max_num_tokens,
|
||||
num_experts=moe.num_experts,
|
||||
experts_per_token=moe.experts_per_token, # topk
|
||||
rank=all2all_manager.rank,
|
||||
world_size=all2all_manager.world_size,
|
||||
# dp_size actually means tp_size, bug in pplx kernels
|
||||
dp_size=all2all_manager.tp_group.world_size,
|
||||
hidden_dim=moe.hidden_dim,
|
||||
hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize,
|
||||
# For blocked per token: set to
|
||||
# ceil_div(hidden_dim, block_size) * sizeof(float32)
|
||||
# For per-token: set to sizeof(float32)
|
||||
hidden_dim_scale_bytes=(0 if moe.in_dtype.itemsize != 1 else (
|
||||
(moe.hidden_dim + moe.block_size - 1) // moe.block_size *
|
||||
torch.float32.itemsize)),
|
||||
group_name=all2all_manager.cpu_group.group_name,
|
||||
)
|
||||
|
||||
handle = all2all_manager.get_handle(all_to_all_args)
|
||||
|
||||
prepare_finalize = PplxPrepareAndFinalize(
|
||||
handle,
|
||||
max_num_tokens=moe.max_num_tokens,
|
||||
world_size=all2all_manager.world_size,
|
||||
rank=all2all_manager.rank,
|
||||
# dp_size actually means tp_size, bug in pplx kernels
|
||||
dp_size=all2all_manager.tp_group.world_size,
|
||||
quant_dtype=moe.in_dtype,
|
||||
)
|
||||
|
||||
if prepare_finalize is not None:
|
||||
experts = self.select_gemm_impl(prepare_finalize)
|
||||
self.fused_experts = FusedMoEModularKernel(
|
||||
prepare_finalize,
|
||||
experts,
|
||||
)
|
||||
|
||||
def select_gemm_impl(
|
||||
self, prepare_finalize: Optional[FusedMoEPrepareAndFinalize]
|
||||
) -> FusedMoEPermuteExpertsUnpermute:
|
||||
# based on the all2all implementation, select the appropriate
|
||||
# gemm implementation
|
||||
raise NotImplementedError(
|
||||
"Subclass must select appropriate gemm implementation"
|
||||
" based on the prepare_finalize")
|
||||
|
||||
@abstractmethod
|
||||
def apply(
|
||||
@ -274,53 +321,13 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class AllToAllCache:
|
||||
|
||||
def __init__(self):
|
||||
self._cache: WeakValueDictionary = WeakValueDictionary()
|
||||
self._lock = threading.RLock() # Reentrant lock for thread safety
|
||||
|
||||
def destroy(self):
|
||||
with self._lock:
|
||||
# TODO: can we do del self._cache?
|
||||
for _, a2a in self._cache.items():
|
||||
a2a.destroy()
|
||||
|
||||
def get_or_create(self, **kwargs):
|
||||
assert has_pplx
|
||||
import pplx_kernels as pplx
|
||||
|
||||
# Create a hashable key from the kwargs
|
||||
key = tuple(sorted((k, v) for k, v in kwargs.items()))
|
||||
|
||||
with self._lock:
|
||||
instance = self._cache.get(key)
|
||||
if instance is None:
|
||||
# TODO (varun): Add support to switch to intranode
|
||||
# when all communications are within the same
|
||||
# node.
|
||||
logger.debug("Create AllToAll %s", kwargs)
|
||||
instance = pplx.AllToAll.internode(**kwargs)
|
||||
self._cache[key] = instance
|
||||
return instance
|
||||
|
||||
|
||||
# Global singleton
|
||||
_all_to_all_cache = AllToAllCache()
|
||||
|
||||
|
||||
# Factory function as a cleaner interface
|
||||
def get_all_to_all(**kwargs):
|
||||
return _all_to_all_cache.get_or_create(**kwargs)
|
||||
|
||||
|
||||
@CustomOp.register("unquantized_fused_moe")
|
||||
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
"""MoE method without quantization."""
|
||||
|
||||
def __init__(self, moe: MoEConfig):
|
||||
super().__init__()
|
||||
self.fused_experts = fused_experts
|
||||
self.fused_experts = fused_experts # type: ignore
|
||||
self.moe = moe
|
||||
|
||||
self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
|
||||
@ -330,6 +337,42 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
else:
|
||||
self.rocm_aiter_fused_experts = None # type: ignore
|
||||
|
||||
def select_gemm_impl(
|
||||
self, prepare_finalize: Optional[FusedMoEPrepareAndFinalize]):
|
||||
|
||||
assert self.fused_experts == fused_experts
|
||||
|
||||
all2all_manager = get_ep_group().device_communicator.all2all_manager
|
||||
assert all2all_manager is not None
|
||||
|
||||
experts: Optional[FusedMoEPermuteExpertsUnpermute] = None
|
||||
|
||||
if isinstance(prepare_finalize,
|
||||
(BatchedPrepareAndFinalize, PplxPrepareAndFinalize)):
|
||||
logger.debug("BatchedTritonExperts %s", self.moe)
|
||||
experts = BatchedTritonExperts(
|
||||
max_num_tokens=MOE_DP_CHUNK_SIZE,
|
||||
world_size=all2all_manager.world_size,
|
||||
# dp_size actually means tp_size, bug in pplx kernels
|
||||
dp_size=all2all_manager.tp_group.world_size,
|
||||
use_fp8_w8a8=False,
|
||||
use_int8_w8a8=False,
|
||||
use_int8_w8a16=False,
|
||||
use_int4_w4a16=False,
|
||||
block_shape=None,
|
||||
)
|
||||
else:
|
||||
logger.debug("TritonExperts %s", self.moe)
|
||||
experts = TritonExperts(
|
||||
use_fp8_w8a8=False,
|
||||
use_int8_w8a8=False,
|
||||
use_int8_w8a16=False,
|
||||
use_int4_w4a16=False,
|
||||
block_shape=None,
|
||||
per_channel_quant=False,
|
||||
)
|
||||
return experts
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
||||
hidden_size: int, intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype, **extra_weight_attrs):
|
||||
@ -429,47 +472,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input)
|
||||
|
||||
def set_prepare_finalize(
|
||||
self,
|
||||
dp_size: int,
|
||||
world_size: int,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||
) -> bool:
|
||||
assert self.fused_experts == fused_experts
|
||||
|
||||
experts: Optional[FusedMoEPermuteExpertsUnpermute] = None
|
||||
|
||||
if isinstance(prepare_finalize,
|
||||
(BatchedPrepareAndFinalize, PplxPrepareAndFinalize)):
|
||||
logger.debug("BatchedTritonExperts %s", self.moe)
|
||||
experts = BatchedTritonExperts(
|
||||
max_num_tokens=MOE_DP_CHUNK_SIZE,
|
||||
world_size=world_size,
|
||||
dp_size=dp_size,
|
||||
use_fp8_w8a8=False,
|
||||
use_int8_w8a8=False,
|
||||
use_int8_w8a16=False,
|
||||
use_int4_w4a16=False,
|
||||
block_shape=None,
|
||||
)
|
||||
else:
|
||||
logger.debug("TritonExperts %s", self.moe)
|
||||
experts = TritonExperts(
|
||||
use_fp8_w8a8=False,
|
||||
use_int8_w8a8=False,
|
||||
use_int8_w8a16=False,
|
||||
use_int4_w4a16=False,
|
||||
block_shape=None,
|
||||
per_channel_quant=False,
|
||||
)
|
||||
|
||||
self.fused_experts = FusedMoEModularKernel(
|
||||
prepare_finalize,
|
||||
experts,
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
@ -679,45 +681,6 @@ def determine_expert_map(
|
||||
return (local_num_experts, expert_map)
|
||||
|
||||
|
||||
def _construct_prepare_finalize(
|
||||
moe: MoEConfig, quant_config: Optional[QuantizationConfig]
|
||||
) -> Optional[FusedMoEPrepareAndFinalize]:
|
||||
max_num_tokens = MOE_DP_CHUNK_SIZE
|
||||
world_size = moe.ep_size
|
||||
dp_size = moe.ep_size // moe.dp_size # dp_size actually means TP.
|
||||
rank = moe.ep_rank
|
||||
|
||||
if moe.use_pplx_kernels:
|
||||
logger.debug("using PplxPrepareAndFinalize")
|
||||
|
||||
all_to_all = get_all_to_all(
|
||||
max_num_tokens=max_num_tokens,
|
||||
num_experts=moe.num_experts,
|
||||
experts_per_token=moe.experts_per_token, # topk
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
dp_size=dp_size,
|
||||
hidden_dim=moe.hidden_dim,
|
||||
hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize,
|
||||
# For blocked per token: set to
|
||||
# ceil_div(hidden_dim, block_size) * sizeof(float32)
|
||||
# For per-token: set to sizeof(float32)
|
||||
hidden_dim_scale_bytes=(0 if moe.in_dtype.itemsize != 1 else
|
||||
((moe.hidden_dim + moe.block_size - 1) //
|
||||
moe.block_size * torch.float32.itemsize)))
|
||||
|
||||
return PplxPrepareAndFinalize(
|
||||
all_to_all,
|
||||
max_num_tokens=max_num_tokens,
|
||||
world_size=world_size,
|
||||
rank=rank,
|
||||
dp_size=dp_size,
|
||||
quant_dtype=moe.in_dtype,
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class FusedMoE(torch.nn.Module):
|
||||
"""FusedMoE layer for MoE models.
|
||||
|
||||
@ -831,7 +794,10 @@ class FusedMoE(torch.nn.Module):
|
||||
moe_parallel_config=self.moe_parallel_config,
|
||||
# TODO (bnell): this needs to be fixed for quantized types.
|
||||
in_dtype=params_dtype,
|
||||
max_num_tokens=MOE_DP_CHUNK_SIZE,
|
||||
)
|
||||
self.moe_config = moe
|
||||
self.quant_config = quant_config
|
||||
|
||||
# Note: get_quant_method will look at the layer's local_num_experts
|
||||
# for heuristic purposes, so it must be initialized first.
|
||||
@ -839,25 +805,13 @@ class FusedMoE(torch.nn.Module):
|
||||
|
||||
if quant_config is None:
|
||||
quant_method = UnquantizedFusedMoEMethod(moe)
|
||||
prepare_finalize = _construct_prepare_finalize(moe, quant_config)
|
||||
else:
|
||||
quant_method = quant_config.get_quant_method(self, prefix)
|
||||
# No pplx for quantized types yet.
|
||||
prepare_finalize = None
|
||||
|
||||
assert quant_method is not None
|
||||
assert isinstance(quant_method, FusedMoEMethodBase)
|
||||
self.quant_method = quant_method
|
||||
|
||||
if prepare_finalize is not None:
|
||||
world_size = moe.ep_size
|
||||
dp_size = int(moe.ep_size // moe.dp_size)
|
||||
success = self.quant_method.set_prepare_finalize(
|
||||
dp_size, world_size, prepare_finalize)
|
||||
if not success:
|
||||
logger.warning("DP+EP not supported for %s.",
|
||||
type(self.quant_method))
|
||||
|
||||
moe_quant_params = {
|
||||
"num_experts": self.local_num_experts,
|
||||
"hidden_size": hidden_size,
|
||||
|
||||
@ -9,7 +9,6 @@ from vllm.model_executor.layers.fused_moe.utils import (
|
||||
moe_kernel_quantize_input)
|
||||
|
||||
|
||||
# Note use: layer.get_all_to_all() to get an AllToAll instance
|
||||
# The max_num_tokens, world_size and dp_size must be the same
|
||||
# as the ones used to create the AllToAll.
|
||||
class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
|
||||
@ -10,7 +10,6 @@ from torch.nn import Module
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
import vllm.envs as envs
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.logger import init_logger
|
||||
@ -461,7 +460,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
logger.warning_once(
|
||||
"DeepGemm not supported on the current platform.")
|
||||
|
||||
self.fused_experts = functools.partial(
|
||||
self.fused_experts = functools.partial( # type: ignore
|
||||
fused_experts,
|
||||
block_shape=self.quant_config.weight_block_size,
|
||||
allow_deep_gemm=self.allow_deep_gemm)
|
||||
@ -791,17 +790,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
del layer.w13_input_scale
|
||||
del layer.w2_input_scale
|
||||
|
||||
def set_prepare_finalize(
|
||||
self,
|
||||
dp_size: int,
|
||||
world_size: int,
|
||||
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
|
||||
) -> bool:
|
||||
def select_gemm_impl(self, prepare_finalize):
|
||||
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
|
||||
TritonOrDeepGemmExperts)
|
||||
|
||||
if self.use_marlin or self.rocm_aiter_moe_enabled:
|
||||
return False
|
||||
assert not self.use_marlin and not self.rocm_aiter_moe_enabled, (
|
||||
"Marlin and ROCm AITER are not supported with all2all yet.")
|
||||
|
||||
experts = TritonOrDeepGemmExperts(
|
||||
use_fp8_w8a8=True,
|
||||
@ -809,12 +803,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
allow_deep_gemm=self.allow_deep_gemm,
|
||||
)
|
||||
|
||||
self.fused_experts = mk.FusedMoEModularKernel(
|
||||
prepare_finalize,
|
||||
experts,
|
||||
)
|
||||
|
||||
return True
|
||||
return experts
|
||||
|
||||
def apply(
|
||||
self,
|
||||
|
||||
@ -158,6 +158,7 @@ class CudaPlatformBase(Platform):
|
||||
"currently not supported with CUDA Graphs.")
|
||||
vllm_config.model_config.enforce_eager = True
|
||||
compilation_config.use_cudagraph = False
|
||||
# FIXME: inductor breaks cudagraph (from @bnell)
|
||||
compilation_config.use_inductor = False
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -348,8 +348,7 @@ def init_worker_distributed_environment(
|
||||
distributed_init_method, local_rank)
|
||||
|
||||
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
|
||||
parallel_config.pipeline_parallel_size,
|
||||
parallel_config.enable_expert_parallel)
|
||||
parallel_config.pipeline_parallel_size)
|
||||
|
||||
ensure_kv_transfer_initialized(vllm_config)
|
||||
|
||||
|
||||
@ -265,8 +265,7 @@ def init_tpu_worker_distributed_environment(
|
||||
backend="gloo",
|
||||
)
|
||||
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
|
||||
parallel_config.pipeline_parallel_size,
|
||||
parallel_config.enable_expert_parallel)
|
||||
parallel_config.pipeline_parallel_size)
|
||||
|
||||
|
||||
try:
|
||||
|
||||
@ -390,8 +390,7 @@ class CPUWorker(LocalOrDistributedWorkerBase):
|
||||
|
||||
ensure_model_parallel_initialized(
|
||||
parallel_config.tensor_parallel_size,
|
||||
parallel_config.pipeline_parallel_size,
|
||||
parallel_config.enable_expert_parallel)
|
||||
parallel_config.pipeline_parallel_size)
|
||||
|
||||
def get_cache_block_size_bytes(self) -> int:
|
||||
"""Return the size in bytes of a single KV cache block.
|
||||
|
||||
@ -415,8 +415,7 @@ def init_worker_distributed_environment(
|
||||
backend='hccl')
|
||||
|
||||
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
|
||||
parallel_config.pipeline_parallel_size,
|
||||
parallel_config.enable_expert_parallel)
|
||||
parallel_config.pipeline_parallel_size)
|
||||
|
||||
if torch.distributed.is_initialized():
|
||||
torch_world_size = torch.distributed.get_world_size()
|
||||
@ -442,8 +441,7 @@ def init_worker_distributed_environment(
|
||||
torch.distributed.all_reduce(dummy_tensor_hpu)
|
||||
assert dummy_tensor_hpu.item() == parallel_config.world_size
|
||||
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
|
||||
parallel_config.pipeline_parallel_size,
|
||||
parallel_config.enable_expert_parallel)
|
||||
parallel_config.pipeline_parallel_size)
|
||||
|
||||
|
||||
def raise_if_cache_size_invalid(num_gpu_blocks, block_size, max_model_len,
|
||||
|
||||
@ -76,8 +76,7 @@ class TPUWorker(LoRANotSupportedWorkerBase, LocalOrDistributedWorkerBase):
|
||||
)
|
||||
ensure_model_parallel_initialized(
|
||||
self.parallel_config.tensor_parallel_size,
|
||||
self.parallel_config.pipeline_parallel_size,
|
||||
self.parallel_config.enable_expert_parallel)
|
||||
self.parallel_config.pipeline_parallel_size)
|
||||
|
||||
# Device initialization should happen after initializing the distributed
|
||||
# runtime.
|
||||
|
||||
@ -529,8 +529,7 @@ def init_worker_distributed_environment(
|
||||
init_distributed_environment(parallel_config.world_size, rank,
|
||||
distributed_init_method, local_rank)
|
||||
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
|
||||
parallel_config.pipeline_parallel_size,
|
||||
parallel_config.enable_expert_parallel)
|
||||
parallel_config.pipeline_parallel_size)
|
||||
|
||||
ensure_kv_transfer_initialized(vllm_config)
|
||||
|
||||
|
||||
@ -175,8 +175,7 @@ class XPUWorker(LoRANotSupportedWorkerBase, Worker):
|
||||
|
||||
ensure_model_parallel_initialized(
|
||||
parallel_config.tensor_parallel_size,
|
||||
parallel_config.pipeline_parallel_size,
|
||||
parallel_config.enable_expert_parallel)
|
||||
parallel_config.pipeline_parallel_size)
|
||||
# global all_reduce needed for overall oneccl warm up
|
||||
torch.distributed.all_reduce(torch.zeros(1).xpu())
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user