Refactor pplx init logic to make it modular (prepare for deepep) (#18200)

Signed-off-by: youkaichao <youkaichao@gmail.com>
youkaichao 2025-05-23 23:43:43 +08:00 committed by GitHub
parent 022d8abe29
commit 6a7988c55b
16 changed files with 300 additions and 287 deletions

View File

@@ -1,44 +1,24 @@
# SPDX-License-Identifier: Apache-2.0
import importlib.util
from typing import TYPE_CHECKING
import torch
import torch.distributed as dist
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from .base_device_communicator import All2AllManagerBase, Cache
logger = init_logger(__name__)
if TYPE_CHECKING:
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
else:
FusedMoE = None
class All2AllBase:
def __init__(self, cpu_group, model):
self.cpu_group = cpu_group
# compute some common properties
from vllm.distributed.parallel_state import (get_dp_group,
get_ep_group,
get_tp_group,
in_the_same_node_as)
# all2all lives in the ep group, which is merged from the dp and tp groups
self.dp_group = get_dp_group()
self.tp_group = get_tp_group()
self.ep_group = get_ep_group()
self.dp_rank = self.dp_group.rank_in_group
self.dp_world_size = self.dp_group.world_size
# all2all communication often has separate implementations for
# intra-node and inter-node communication
self.intranode = in_the_same_node_as(cpu_group, source_rank=0)
self.internode = not self.intranode
def dispatch(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
raise NotImplementedError
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
raise NotImplementedError
def destroy(self):
pass
class NaiveAll2All(All2AllBase):
class NaiveAll2AllManager(All2AllManagerBase):
"""
A naive implementation of all2all communication.
It uses all-reduce under the hood, which is not
@@ -46,8 +26,8 @@ class NaiveAll2All(All2AllBase):
debugging.
"""
def __init__(self, cpu_group, model):
super().__init__(cpu_group, model)
def __init__(self, cpu_group):
super().__init__(cpu_group)
def naive_multicast(self, x: torch.Tensor,
cu_tokens_across_dp_cpu: torch.Tensor):
@@ -91,3 +71,56 @@ class NaiveAll2All(All2AllBase):
def destroy(self):
pass
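For context on how naive_multicast behaves: each rank copies its tokens into its own slice of a zero-filled buffer sized for all ranks' tokens, then an all-reduce sums the buffers so every rank ends up with every rank's tokens. A single-process sketch of that idea (shapes and token counts are made up, and a plain Python sum stands in for the all-reduce):

import torch

# two "ranks", each owning its own tokens; cu holds cumulative token offsets
tokens = [torch.ones(2, 4), 2 * torch.ones(3, 4)]
cu = torch.tensor([0, 2, 5])
buffers = []
for rank, x in enumerate(tokens):
    buf = torch.zeros(int(cu[-1]), 4)
    buf[cu[rank]:cu[rank + 1]] = x  # each rank fills only its own slice
    buffers.append(buf)
gathered = sum(buffers)  # the all-reduce would produce this on every rank
assert torch.equal(gathered[0:2], tokens[0])
assert torch.equal(gathered[2:5], tokens[1])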
class PPLXAll2AllManager(All2AllManagerBase):
"""
All2All communication based on PPLX kernels.
"""
def __init__(self, cpu_group):
has_pplx = importlib.util.find_spec("pplx_kernels") is not None
assert has_pplx, "pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install pplx_kernels." # noqa
super().__init__(cpu_group)
if self.internode:
# inter-node communication needs nvshmem,
# intra-node communication uses p2p mapping directly
from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
nvshmem_get_unique_id,
nvshmem_init)
logger.debug(
"Initialize NVSHMEM for pplx_kernels: "
"rank=%d, world size=%d", self.rank, self.world_size)
uid = nvshmem_get_unique_id(
) if self.rank == 0 else nvshmem_alloc_empty_unique_id()
dist.broadcast(uid,
src=dist.get_process_group_ranks(self.cpu_group)[0],
group=self.cpu_group)
logger.debug("PPLX NVSHMEM UID = %s", uid)
nvshmem_init(uid, self.rank, self.world_size)
self.handle_cache = Cache()
def get_handle(self, kwargs):
import pplx_kernels as pplx
return self.handle_cache.get_or_create(
kwargs, pplx.AllToAll.internode
if self.internode else pplx.AllToAll.intranode)
def dispatch(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
raise NotImplementedError
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
raise NotImplementedError
def destroy(self):
with self.handle_cache._lock:
for _, handle in self.handle_cache._cache.items():
handle.destroy()
if self.internode:
from pplx_kernels.nvshmem import nvshmem_finalize
logger.debug("PPLX NVSHMEM finalize")
nvshmem_finalize()
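The NVSHMEM bootstrap above follows a standard rendezvous pattern: rank 0 produces the unique id, the id is broadcast over the CPU process group, and every rank then calls nvshmem_init with the same id. A single-process sketch of that handshake using a stand-in tensor id (no pplx_kernels or GPU required):

import torch
import torch.distributed as dist

dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500",
                        rank=0, world_size=1)
# stands in for nvshmem_get_unique_id() on rank 0 and
# nvshmem_alloc_empty_unique_id() on the other ranks
uid = torch.arange(16, dtype=torch.uint8)
dist.broadcast(uid, src=0)  # after this, every rank holds the same id
dist.destroy_process_group()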

View File

@@ -1,11 +1,76 @@
# SPDX-License-Identifier: Apache-2.0
import threading
from typing import Optional
from weakref import WeakValueDictionary
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
class Cache:
def __init__(self):
self._cache: WeakValueDictionary = WeakValueDictionary()
self._lock = threading.RLock() # Reentrant lock for thread safety
def get_or_create(self, kwargs, func):
# Create a hashable key from the kwargs
key = tuple(sorted((k, v) for k, v in kwargs.items()))
with self._lock:
instance = self._cache.get(key)
if instance is None:
instance = func(**kwargs)
self._cache[key] = instance
return instance
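Cache keys on the sorted kwargs items, so callers that pass the same config (in any key order) share one instance, and the WeakValueDictionary lets an instance be collected once no caller holds a reference. A small usage sketch with a hypothetical handle class (FakeHandle is not part of vLLM):

class FakeHandle:
    def __init__(self, hidden_dim, num_experts):
        self.hidden_dim = hidden_dim
        self.num_experts = num_experts

cache = Cache()
a = cache.get_or_create({"hidden_dim": 1024, "num_experts": 8}, FakeHandle)
b = cache.get_or_create({"num_experts": 8, "hidden_dim": 1024}, FakeHandle)
assert a is b  # same sorted key, same cached instance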
class All2AllManagerBase:
def __init__(self, cpu_group):
self.cpu_group = cpu_group
# compute some common properties
from vllm.distributed.parallel_state import (get_dp_group,
get_tp_group,
in_the_same_node_as)
# all2all lives in the ep group, which is merged from the dp and tp groups
self.dp_group = get_dp_group()
self.tp_group = get_tp_group()
# no self.ep_group since the ep group is still under construction
# when we create this object
self.dp_rank = self.dp_group.rank_in_group
self.dp_world_size = self.dp_group.world_size
self.rank = dist.get_rank(cpu_group)
self.world_size = dist.get_world_size(cpu_group)
# all2all communication often has separate implementations for
# intra-node and inter-node communication
self.intranode = in_the_same_node_as(cpu_group, source_rank=0)
self.internode = not self.intranode
def get_handle(self, kwargs):
# get a handle for the all2all communication,
# based on the kwargs.
# different layers can have different configs,
# e.g. one layer has hidden size 1024, another has 2048.
# usually the underlying implementation caches the handle
# and reuses it for the same config.
raise NotImplementedError
def dispatch(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
raise NotImplementedError
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
raise NotImplementedError
def destroy(self):
pass
class DeviceCommunicatorBase:
"""
Base class for device-specific communicator.
@@ -31,6 +96,18 @@ class DeviceCommunicatorBase:
self.rank_in_group = dist.get_group_rank(self.cpu_group,
self.global_rank)
use_ep = False
from vllm.config import get_current_vllm_config
config = get_current_vllm_config()
if config is not None:
# as long as we use data parallel (coupled data parallel
# where all data parallel ranks execute forward together),
# we initialize the all2all manager used in expert parallel.
use_ep = config.parallel_config.data_parallel_size > 1
self.use_all2all = "ep" in unique_name and use_ep
self.all2all_manager: Optional[All2AllManagerBase] = None
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
dist.all_reduce(input_, group=self.device_group)
return input_
@@ -154,9 +231,17 @@
model: torch.nn.Module) -> None:
"""
Prepare the communication buffer for the model.
This is a no-op in the base class.
"""
pass
if not self.use_all2all:
return
moe_modules = [
module for module in model.modules()
if module.__class__.__name__ == "FusedMoE"
]
for module in moe_modules:
module.quant_method.init_prepare_finalize(module.moe_config,
module.quant_config)
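Note that the scan above matches MoE layers by class name rather than by importing FusedMoE, presumably to avoid a circular import (the TYPE_CHECKING guard in all2all.py points the same way). A standalone illustration of the scan with a stand-in module that only shares the name:

import torch

class FusedMoE(torch.nn.Module):  # stand-in, not vLLM's FusedMoE
    pass

model = torch.nn.Sequential(torch.nn.Linear(8, 8), FusedMoE())
moe_modules = [m for m in model.modules()
               if m.__class__.__name__ == "FusedMoE"]
assert len(moe_modules) == 1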
def dispatch(
self, hidden_states: torch.Tensor,

View File

@@ -6,10 +6,12 @@ import torch
from torch.distributed import ProcessGroup
import vllm.envs as envs
from vllm.logger import init_logger
from .all2all import All2AllBase
from .base_device_communicator import DeviceCommunicatorBase
logger = init_logger(__name__)
class CudaCommunicator(DeviceCommunicatorBase):
@@ -31,8 +33,6 @@ class CudaCommunicator(DeviceCommunicatorBase):
use_pynccl = "ep" not in unique_name
self.use_pynccl = use_pynccl
self.use_all2all = "ep" in unique_name
self.all2all_impl: Optional[All2AllBase] = None
self.use_custom_allreduce = use_custom_allreduce
# lazy import to avoid documentation build error
@@ -56,6 +56,19 @@ class CudaCommunicator(DeviceCommunicatorBase):
device=self.device,
)
if self.use_all2all:
all2all_backend = envs.VLLM_ALL2ALL_BACKEND
if all2all_backend == "naive":
from .all2all import NaiveAll2AllManager
self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
logger.info("Using naive all2all manager.")
elif all2all_backend == "pplx":
from .all2all import PPLXAll2AllManager
self.all2all_manager = PPLXAll2AllManager(self.cpu_group)
logger.info("Using PPLX all2all manager.")
else:
raise ValueError(f"Unknown all2all backend: {all2all_backend}")
def all_reduce(self, input_):
# always try custom allreduce first,
# and then pynccl.
@@ -136,31 +149,19 @@ class CudaCommunicator(DeviceCommunicatorBase):
self.pynccl_comm = None
if self.ca_comm is not None:
self.ca_comm = None
if self.all2all_impl is not None:
self.all2all_impl.destroy()
self.all2all_impl = None
def prepare_communication_buffer_for_model(self,
model: torch.nn.Module) -> None:
"""
Prepare the communication buffer for the model.
"""
if not self.use_all2all:
return
all2all_backend = envs.VLLM_ALL2ALL_BACKEND
if all2all_backend == "naive":
from .all2all import NaiveAll2All
self.all2all_impl = NaiveAll2All(self.cpu_group, model)
if self.all2all_manager is not None:
self.all2all_manager.destroy()
self.all2all_manager = None
def dispatch(
self, hidden_states: torch.Tensor,
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
assert self.all2all_impl is not None
hidden_states, router_logits = self.all2all_impl.dispatch(
assert self.all2all_manager is not None
hidden_states, router_logits = self.all2all_manager.dispatch(
hidden_states, router_logits)
return hidden_states, router_logits
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
assert self.all2all_impl is not None
hidden_states = self.all2all_impl.combine(hidden_states)
assert self.all2all_manager is not None
hidden_states = self.all2all_manager.combine(hidden_states)
return hidden_states

View File

@@ -23,7 +23,6 @@ If you only need to use the distributed environment without model/pipeline
"""
import contextlib
import gc
import importlib.util
import pickle
import weakref
from collections import namedtuple
@@ -43,7 +42,7 @@ from vllm.distributed.device_communicators.base_device_communicator import (
from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger
from vllm.utils import (direct_register_custom_op, resolve_obj_by_qualname,
run_once, supports_custom_op)
supports_custom_op)
@dataclass
@@ -791,10 +790,14 @@ class GroupCoordinator:
if self.device_communicator is not None:
return self.device_communicator.dispatch(hidden_states,
router_logits)
else:
return hidden_states, router_logits
def combine(self, hidden_states) -> torch.Tensor:
if self.device_communicator is not None:
return self.device_communicator.combine(hidden_states)
else:
return hidden_states
_WORLD: Optional[GroupCoordinator] = None
@@ -959,49 +962,9 @@ def init_distributed_environment(
"world group already initialized with a different world size")
PPLX_DID_INIT: bool = False
@run_once
def pplx_init(rank, world_size):
has_pplx = importlib.util.find_spec("pplx_kernels") is not None
if has_pplx and world_size > 1:
from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
nvshmem_get_unique_id, nvshmem_init)
try:
global PPLX_DID_INIT
logger.debug(
"Initialize NVSHMEM for PPLX kernels: rank=%d, "
"world size=%d", rank, world_size)
uid = nvshmem_get_unique_id(
) if rank == 0 else nvshmem_alloc_empty_unique_id()
uid_gpu = uid.cuda()
get_world_group().broadcast(uid_gpu, src=0)
uid = uid_gpu.to(device='cpu')
logger.debug("PPLX NVSHMEM UID = %s", uid)
nvshmem_init(uid, rank, world_size)
PPLX_DID_INIT = True
except Exception as ex:
logger.error("Failed to initialize NVSHMEM for PPLX: %s", ex)
@run_once
def pplx_finalize():
global PPLX_DID_INIT
if PPLX_DID_INIT:
from pplx_kernels.nvshmem import nvshmem_finalize
logger.debug("PPLX NVSHMEM finalize")
from vllm.model_executor.layers.fused_moe.layer import (
_all_to_all_cache)
_all_to_all_cache.destroy()
nvshmem_finalize()
def initialize_model_parallel(
tensor_model_parallel_size: int = 1,
pipeline_model_parallel_size: int = 1,
enable_expert_parallel: bool = False,
backend: Optional[str] = None,
) -> None:
"""
@@ -1104,14 +1067,10 @@ def initialize_model_parallel(
_DP.rank_in_group, _PP.rank_in_group, _TP.rank_in_group,
_EP.rank_in_group)
if enable_expert_parallel:
pplx_init(rank, world_size)
def ensure_model_parallel_initialized(
tensor_model_parallel_size: int,
pipeline_model_parallel_size: int,
enable_expert_parallel: bool = False,
backend: Optional[str] = None,
) -> None:
"""Helper to initialize model parallel groups if they are not initialized,
@@ -1122,8 +1081,7 @@ def ensure_model_parallel_initialized(
get_world_group().device_group)
if not model_parallel_is_initialized():
initialize_model_parallel(tensor_model_parallel_size,
pipeline_model_parallel_size,
enable_expert_parallel, backend)
pipeline_model_parallel_size, backend)
return
assert (
@@ -1202,8 +1160,6 @@ def destroy_model_parallel():
"""Set the groups to none and destroy them."""
global _TP
pplx_finalize()
if _TP:
_TP.destroy()
_TP = None

View File

@@ -809,6 +809,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: int(os.getenv("VLLM_NIXL_SIDE_CHANNEL_PORT", "5557")),
# all2all backend for vllm's expert parallel communication
# Available options:
# - "naive": naive all2all implementation using all-reduce
# - "pplx": use pplx kernels
"VLLM_ALL2ALL_BACKEND":
lambda: os.getenv("VLLM_ALL2ALL_BACKEND", "naive"),
}

View File

@@ -1,12 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
import importlib
import threading
from abc import abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Callable, Optional
from weakref import WeakValueDictionary
import torch
import torch.nn.functional as F
@@ -73,7 +71,8 @@ class FusedMoEParallelConfig:
@property
def use_pplx_kernels(self):
return self.dp_size > 1 and self.use_ep and has_pplx
return self.dp_size > 1 and self.use_ep and \
envs.VLLM_ALL2ALL_BACKEND == "pplx"
@staticmethod
def make(tp_size_: int, dp_size_: int,
@@ -196,6 +195,8 @@ class MoEConfig:
# TODO: add more quantization params, blocked, per-token, etc.
block_size: int = 128
max_num_tokens: int = MOE_DP_CHUNK_SIZE
@property
def tp_size(self):
return self.moe_parallel_config.tp_size
@@ -244,13 +245,59 @@ class FusedMoEMethodBase(QuantizeMethodBase):
params_dtype: torch.dtype, **extra_weight_attrs):
raise NotImplementedError
def set_prepare_finalize(
self,
dp_size: int,
world_size: int,
prepare_finalize: FusedMoEPrepareAndFinalize,
) -> bool:
return False
def init_prepare_finalize(self, moe: MoEConfig,
quant_config: Optional[QuantizationConfig]):
all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None
prepare_finalize = None
if moe.use_pplx_kernels:
all_to_all_args = dict(
max_num_tokens=moe.max_num_tokens,
num_experts=moe.num_experts,
experts_per_token=moe.experts_per_token, # topk
rank=all2all_manager.rank,
world_size=all2all_manager.world_size,
# dp_size actually means tp_size, bug in pplx kernels
dp_size=all2all_manager.tp_group.world_size,
hidden_dim=moe.hidden_dim,
hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize,
# For blocked per token: set to
# ceil_div(hidden_dim, block_size) * sizeof(float32)
# For per-token: set to sizeof(float32)
hidden_dim_scale_bytes=(0 if moe.in_dtype.itemsize != 1 else (
(moe.hidden_dim + moe.block_size - 1) // moe.block_size *
torch.float32.itemsize)),
group_name=all2all_manager.cpu_group.group_name,
)
handle = all2all_manager.get_handle(all_to_all_args)
prepare_finalize = PplxPrepareAndFinalize(
handle,
max_num_tokens=moe.max_num_tokens,
world_size=all2all_manager.world_size,
rank=all2all_manager.rank,
# dp_size actually means tp_size, bug in pplx kernels
dp_size=all2all_manager.tp_group.world_size,
quant_dtype=moe.in_dtype,
)
if prepare_finalize is not None:
experts = self.select_gemm_impl(prepare_finalize)
self.fused_experts = FusedMoEModularKernel(
prepare_finalize,
experts,
)
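hidden_dim_scale_bytes above is nonzero only for 1-byte activation dtypes (itemsize == 1, i.e. fp8-style), where blocked per-token quantization carries one float32 scale per block of the hidden dimension. A worked check of the ceil-div arithmetic, with hypothetical sizes:

import torch

hidden_dim, block_size = 4096, 128  # hypothetical sizes
scale_bytes = ((hidden_dim + block_size - 1) // block_size
               * torch.float32.itemsize)
assert scale_bytes == 32 * 4  # 32 blocks, one 4-byte float32 scale per block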
def select_gemm_impl(
self, prepare_finalize: Optional[FusedMoEPrepareAndFinalize]
) -> FusedMoEPermuteExpertsUnpermute:
# based on the all2all implementation, select the appropriate
# gemm implementation
raise NotImplementedError(
"Subclass must select appropriate gemm implementation"
" based on the prepare_finalize")
@abstractmethod
def apply(
@@ -274,53 +321,13 @@ class FusedMoEMethodBase(QuantizeMethodBase):
raise NotImplementedError
class AllToAllCache:
def __init__(self):
self._cache: WeakValueDictionary = WeakValueDictionary()
self._lock = threading.RLock() # Reentrant lock for thread safety
def destroy(self):
with self._lock:
# TODO: can we do del self._cache?
for _, a2a in self._cache.items():
a2a.destroy()
def get_or_create(self, **kwargs):
assert has_pplx
import pplx_kernels as pplx
# Create a hashable key from the kwargs
key = tuple(sorted((k, v) for k, v in kwargs.items()))
with self._lock:
instance = self._cache.get(key)
if instance is None:
# TODO (varun): Add support to switch to intranode
# when all communications are within the same
# node.
logger.debug("Create AllToAll %s", kwargs)
instance = pplx.AllToAll.internode(**kwargs)
self._cache[key] = instance
return instance
# Global singleton
_all_to_all_cache = AllToAllCache()
# Factory function as a cleaner interface
def get_all_to_all(**kwargs):
return _all_to_all_cache.get_or_create(**kwargs)
@CustomOp.register("unquantized_fused_moe")
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
"""MoE method without quantization."""
def __init__(self, moe: MoEConfig):
super().__init__()
self.fused_experts = fused_experts
self.fused_experts = fused_experts # type: ignore
self.moe = moe
self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
@@ -330,6 +337,42 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
else:
self.rocm_aiter_fused_experts = None # type: ignore
def select_gemm_impl(
self, prepare_finalize: Optional[FusedMoEPrepareAndFinalize]):
assert self.fused_experts == fused_experts
all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None
experts: Optional[FusedMoEPermuteExpertsUnpermute] = None
if isinstance(prepare_finalize,
(BatchedPrepareAndFinalize, PplxPrepareAndFinalize)):
logger.debug("BatchedTritonExperts %s", self.moe)
experts = BatchedTritonExperts(
max_num_tokens=MOE_DP_CHUNK_SIZE,
world_size=all2all_manager.world_size,
# dp_size actually means tp_size, bug in pplx kernels
dp_size=all2all_manager.tp_group.world_size,
use_fp8_w8a8=False,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
block_shape=None,
)
else:
logger.debug("TritonExperts %s", self.moe)
experts = TritonExperts(
use_fp8_w8a8=False,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
block_shape=None,
per_channel_quant=False,
)
return experts
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
@@ -429,47 +472,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input)
def set_prepare_finalize(
self,
dp_size: int,
world_size: int,
prepare_finalize: FusedMoEPrepareAndFinalize,
) -> bool:
assert self.fused_experts == fused_experts
experts: Optional[FusedMoEPermuteExpertsUnpermute] = None
if isinstance(prepare_finalize,
(BatchedPrepareAndFinalize, PplxPrepareAndFinalize)):
logger.debug("BatchedTritonExperts %s", self.moe)
experts = BatchedTritonExperts(
max_num_tokens=MOE_DP_CHUNK_SIZE,
world_size=world_size,
dp_size=dp_size,
use_fp8_w8a8=False,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
block_shape=None,
)
else:
logger.debug("TritonExperts %s", self.moe)
experts = TritonExperts(
use_fp8_w8a8=False,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
block_shape=None,
per_channel_quant=False,
)
self.fused_experts = FusedMoEModularKernel(
prepare_finalize,
experts,
)
return True
def forward_cuda(
self,
layer: torch.nn.Module,
@@ -679,45 +681,6 @@ def determine_expert_map(
return (local_num_experts, expert_map)
def _construct_prepare_finalize(
moe: MoEConfig, quant_config: Optional[QuantizationConfig]
) -> Optional[FusedMoEPrepareAndFinalize]:
max_num_tokens = MOE_DP_CHUNK_SIZE
world_size = moe.ep_size
dp_size = moe.ep_size // moe.dp_size # dp_size actually means TP.
rank = moe.ep_rank
if moe.use_pplx_kernels:
logger.debug("using PplxPrepareAndFinalize")
all_to_all = get_all_to_all(
max_num_tokens=max_num_tokens,
num_experts=moe.num_experts,
experts_per_token=moe.experts_per_token, # topk
rank=rank,
world_size=world_size,
dp_size=dp_size,
hidden_dim=moe.hidden_dim,
hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize,
# For blocked per token: set to
# ceil_div(hidden_dim, block_size) * sizeof(float32)
# For per-token: set to sizeof(float32)
hidden_dim_scale_bytes=(0 if moe.in_dtype.itemsize != 1 else
((moe.hidden_dim + moe.block_size - 1) //
moe.block_size * torch.float32.itemsize)))
return PplxPrepareAndFinalize(
all_to_all,
max_num_tokens=max_num_tokens,
world_size=world_size,
rank=rank,
dp_size=dp_size,
quant_dtype=moe.in_dtype,
)
return None
class FusedMoE(torch.nn.Module):
"""FusedMoE layer for MoE models.
@@ -831,7 +794,10 @@ class FusedMoE(torch.nn.Module):
moe_parallel_config=self.moe_parallel_config,
# TODO (bnell): this needs to be fixed for quantized types.
in_dtype=params_dtype,
max_num_tokens=MOE_DP_CHUNK_SIZE,
)
self.moe_config = moe
self.quant_config = quant_config
# Note: get_quant_method will look at the layer's local_num_experts
# for heuristic purposes, so it must be initialized first.
@@ -839,25 +805,13 @@
if quant_config is None:
quant_method = UnquantizedFusedMoEMethod(moe)
prepare_finalize = _construct_prepare_finalize(moe, quant_config)
else:
quant_method = quant_config.get_quant_method(self, prefix)
# No pplx for quantized types yet.
prepare_finalize = None
assert quant_method is not None
assert isinstance(quant_method, FusedMoEMethodBase)
self.quant_method = quant_method
if prepare_finalize is not None:
world_size = moe.ep_size
dp_size = int(moe.ep_size // moe.dp_size)
success = self.quant_method.set_prepare_finalize(
dp_size, world_size, prepare_finalize)
if not success:
logger.warning("DP+EP not supported for %s.",
type(self.quant_method))
moe_quant_params = {
"num_experts": self.local_num_experts,
"hidden_size": hidden_size,

View File

@@ -9,7 +9,6 @@ from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input)
# Note use: layer.get_all_to_all() to get an AllToAll instance
# The max_num_tokens, world_size and dp_size must be the same
# as the ones used to create the AllToAll.
class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):

View File

@@ -10,7 +10,6 @@ from torch.nn import Module
from torch.nn.parameter import Parameter
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
@@ -461,7 +460,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
logger.warning_once(
"DeepGemm not supported on the current platform.")
self.fused_experts = functools.partial(
self.fused_experts = functools.partial( # type: ignore
fused_experts,
block_shape=self.quant_config.weight_block_size,
allow_deep_gemm=self.allow_deep_gemm)
@@ -791,17 +790,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
del layer.w13_input_scale
del layer.w2_input_scale
def set_prepare_finalize(
self,
dp_size: int,
world_size: int,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
) -> bool:
def select_gemm_impl(self, prepare_finalize):
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts)
if self.use_marlin or self.rocm_aiter_moe_enabled:
return False
assert not self.use_marlin and not self.rocm_aiter_moe_enabled, (
"Marlin and ROCm AITER are not supported with all2all yet.")
experts = TritonOrDeepGemmExperts(
use_fp8_w8a8=True,
@@ -809,12 +803,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
allow_deep_gemm=self.allow_deep_gemm,
)
self.fused_experts = mk.FusedMoEModularKernel(
prepare_finalize,
experts,
)
return True
return experts
def apply(
self,

View File

@@ -158,6 +158,7 @@ class CudaPlatformBase(Platform):
"currently not supported with CUDA Graphs.")
vllm_config.model_config.enforce_eager = True
compilation_config.use_cudagraph = False
# FIXME: inductor breaks cudagraph (from @bnell)
compilation_config.use_inductor = False
@classmethod

View File

@@ -348,8 +348,7 @@ def init_worker_distributed_environment(
distributed_init_method, local_rank)
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size,
parallel_config.enable_expert_parallel)
parallel_config.pipeline_parallel_size)
ensure_kv_transfer_initialized(vllm_config)

View File

@@ -265,8 +265,7 @@ def init_tpu_worker_distributed_environment(
backend="gloo",
)
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size,
parallel_config.enable_expert_parallel)
parallel_config.pipeline_parallel_size)
try:

View File

@@ -390,8 +390,7 @@ class CPUWorker(LocalOrDistributedWorkerBase):
ensure_model_parallel_initialized(
parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size,
parallel_config.enable_expert_parallel)
parallel_config.pipeline_parallel_size)
def get_cache_block_size_bytes(self) -> int:
"""Return the size in bytes of a single KV cache block.

View File

@@ -415,8 +415,7 @@ def init_worker_distributed_environment(
backend='hccl')
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size,
parallel_config.enable_expert_parallel)
parallel_config.pipeline_parallel_size)
if torch.distributed.is_initialized():
torch_world_size = torch.distributed.get_world_size()
@@ -442,8 +441,7 @@
torch.distributed.all_reduce(dummy_tensor_hpu)
assert dummy_tensor_hpu.item() == parallel_config.world_size
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size,
parallel_config.enable_expert_parallel)
parallel_config.pipeline_parallel_size)
def raise_if_cache_size_invalid(num_gpu_blocks, block_size, max_model_len,

View File

@@ -76,8 +76,7 @@ class TPUWorker(LoRANotSupportedWorkerBase, LocalOrDistributedWorkerBase):
)
ensure_model_parallel_initialized(
self.parallel_config.tensor_parallel_size,
self.parallel_config.pipeline_parallel_size,
self.parallel_config.enable_expert_parallel)
self.parallel_config.pipeline_parallel_size)
# Device initialization should happen after initializing the distributed
# runtime.

View File

@@ -529,8 +529,7 @@ def init_worker_distributed_environment(
init_distributed_environment(parallel_config.world_size, rank,
distributed_init_method, local_rank)
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size,
parallel_config.enable_expert_parallel)
parallel_config.pipeline_parallel_size)
ensure_kv_transfer_initialized(vllm_config)

View File

@@ -175,8 +175,7 @@ class XPUWorker(LoRANotSupportedWorkerBase, Worker):
ensure_model_parallel_initialized(
parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size,
parallel_config.enable_expert_parallel)
parallel_config.pipeline_parallel_size)
# global all_reduce needed for overall oneccl warm up
torch.distributed.all_reduce(torch.zeros(1).xpu())