Elastic Expert Parallel Initial Support (#20775)

Signed-off-by: Rui Qiao <ruisearch42@gmail.com>
Authored by Rui Qiao on 2025-07-18 17:46:09 -07:00; committed by GitHub
parent 5782581acf
commit 217937221b
24 changed files with 1659 additions and 68 deletions

View File

@ -0,0 +1,57 @@
#!/bin/bash
MODEL_NAME="deepseek-ai/DeepSeek-V2-Lite"
LOCAL_MODEL_PATH="/models/models--deepseek-ai--DeepSeek-V2-Lite/snapshots/604d5664dddd88a0433dbae533b7fe9472482de0"
HOST="localhost"
PORT=8006
NUM_PROMPTS=20
REQUEST_RATE=5
# Parse command line arguments
while [[ $# -gt 0 ]]; do
case $1 in
--model)
MODEL_NAME="$2"
shift 2
;;
--local-model)
MODEL_NAME=$LOCAL_MODEL_PATH
shift
;;
--host)
HOST="$2"
shift 2
;;
--port)
PORT="$2"
shift 2
;;
--num-prompts)
NUM_PROMPTS="$2"
shift 2
;;
--request-rate)
REQUEST_RATE="$2"
shift 2
;;
-h|--help)
echo "Usage: $0 [OPTIONS]"
echo "Options:"
echo " --model MODEL_NAME Set model name or path (default: deepseek-ai/DeepSeek-V2-Lite)"
echo " --local-model Use local model path (convenience option)"
exit 0
;;
*)
echo "Unknown option: $1"
echo "Use -h or --help for usage information"
exit 1
;;
esac
done
vllm bench serve \
--model $MODEL_NAME \
--host $HOST \
--port $PORT \
--num-prompts $NUM_PROMPTS \
--request-rate $REQUEST_RATE

View File

@ -0,0 +1,53 @@
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import json
import sys
import requests
def scale(host, port, new_dp_size):
url = f"http://{host}:{port}/scale_elastic_ep"
payload = {"new_data_parallel_size": new_dp_size}
headers = {"Content-Type": "application/json"}
print(f"Sending scale request to {url}")
print(f"Payload: {json.dumps(payload, indent=2)}")
try:
response = requests.post(url, json=payload, headers=headers, timeout=300)
print(f"Status Code: {response.status_code}")
print(f"Response: {response.text}")
if response.status_code == 200:
print("Scale up/down request successful!")
return True
else:
print("Scale up/down request failed!")
return False
except requests.exceptions.RequestException as e:
print(f"Request failed: {e}")
return False
def main():
parser = argparse.ArgumentParser(description="Test scale up/down functionality")
parser.add_argument("--host", default="localhost", help="API server host")
parser.add_argument("--port", type=int, default=8006, help="API server port")
parser.add_argument(
"--new-dp-size", type=int, default=2, help="New data parallel size"
)
args = parser.parse_args()
success = scale(args.host, args.port, args.new_dp_size)
sys.exit(0 if success else 1)
if __name__ == "__main__":
main()
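A hedged usage sketch of the endpoint this script targets, assuming the default host and port used by the scripts in this PR; the POST blocks until draining and reconfiguration finish, and drain_timeout is optional (the server defaults to 120 seconds):

import requests

resp = requests.post(
    "http://localhost:8006/scale_elastic_ep",
    json={"new_data_parallel_size": 2, "drain_timeout": 300},
    timeout=600,
)
print(resp.status_code, resp.json())
# 200 -> {"message": "Scaled to 2 data parallel engines"}
# 408 -> request drain timed out; 503 -> another scale is already in flight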

View File

@ -0,0 +1,72 @@
#!/bin/bash
HOST="0.0.0.0"
PORT=8006
DATA_PARALLEL_SIZE=4
REDUNDANT_EXPERTS=0
LOCAL_MODEL_PATH="/models/models--deepseek-ai--DeepSeek-V2-Lite/snapshots/604d5664dddd88a0433dbae533b7fe9472482de0"
MODEL_NAME="deepseek-ai/DeepSeek-V2-Lite"
while [[ $# -gt 0 ]]; do
case $1 in
--dp)
DATA_PARALLEL_SIZE="$2"
shift 2
;;
--re)
REDUNDANT_EXPERTS="$2"
shift 2
;;
--host)
HOST="$2"
shift 2
;;
--port)
PORT="$2"
shift 2
;;
--model)
MODEL_NAME="$2"
shift 2
;;
--local-model)
MODEL_NAME=$LOCAL_MODEL_PATH
shift
;;
-h|--help)
echo "Usage: $0 [OPTIONS]"
echo "Options:"
echo " --dp SIZE Set data parallel size (default: 4)"
echo " --re SIZE Set redundant experts (default: 0)"
echo " --host HOST Set host address (default: 0.0.0.0)"
echo " --port PORT Set port number (default: 8006)"
echo " --model MODEL_NAME Set model name or path"
echo " -h, --help Show this help message"
exit 0
;;
*)
echo "Unknown option: $1"
echo "Use -h or --help for usage information"
exit 1
;;
esac
done
echo "Starting vLLM server for $MODEL_NAME with data parallel size: $DATA_PARALLEL_SIZE and redundant experts: $REDUNDANT_EXPERTS"
export RAY_DEDUP_LOGS=0
export VLLM_USE_V1=1
export VLLM_ALL2ALL_BACKEND="pplx"
export VLLM_USE_DEEP_GEMM=1
vllm serve $MODEL_NAME \
--data-parallel-size $DATA_PARALLEL_SIZE \
--data-parallel-size-local $DATA_PARALLEL_SIZE \
--data-parallel-backend ray \
--enforce-eager \
--enable-expert-parallel \
--enable-eplb \
--num-redundant-experts $REDUNDANT_EXPERTS \
--trust-remote-code \
--host $HOST \
--port $PORT

View File

@ -0,0 +1,92 @@
From 18c0599c2f07ec965132efa25961dc8179c2dda3 Mon Sep 17 00:00:00 2001
From: Yongji Wu <wuyongji317@gmail.com>
Date: Tue, 20 May 2025 13:41:12 -0700
Subject: [PATCH] fix reinit issues due to states not cleaned up
fix double free
---
src/host/init/init.cu | 10 ++++++++++
.../internal/host/nvshmemi_mem_transport.hpp | 15 +++++++++++++++
src/modules/bootstrap/uid/bootstrap_uid.cpp | 5 +++++
3 files changed, 30 insertions(+)
diff --git a/src/host/init/init.cu b/src/host/init/init.cu
index b1c5dbf..1fecb4b 100644
--- a/src/host/init/init.cu
+++ b/src/host/init/init.cu
@@ -43,6 +43,8 @@
#include "internal/host/nvshmemi_types.h"
#include "internal/host/shared_memory.h"
#include "internal/host/nvshmemi_symmetric_heap.hpp"
+// eep-dev
+#include "internal/host/nvshmemi_mem_transport.hpp"
extern __constant__ nvshmemi_device_host_state_t nvshmemi_device_state_d;
static std::map<void *, int> registered_device_states;
@@ -1293,6 +1295,14 @@ void nvshmemid_hostlib_finalize(void *device_ctx, void *transport_device_ctx) {
/* Multi-init Multi-fini*/
nvshmemi_state = NULL;
nvshmemi_device_state.nvshmemi_is_nvshmem_initialized = 0;
+
+ // eep-dev
+ nvshmemi_mem_p2p_transport::destroy_instance();
+ nvshmemi_mem_remote_transport::destroy_instance();
+ free(nvshmemi_default_session);
+ nvshmemi_default_session = nullptr;
+ nvshmemi_device_state.nvshmemi_is_nvshmem_bootstrapped = false;
+
nvshmemi_is_device_state_ready = false;
} else
nvshmemi_boot_handle.barrier(&nvshmemi_boot_handle);
diff --git a/src/include/internal/host/nvshmemi_mem_transport.hpp b/src/include/internal/host/nvshmemi_mem_transport.hpp
index 2495844..e4f408a 100644
--- a/src/include/internal/host/nvshmemi_mem_transport.hpp
+++ b/src/include/internal/host/nvshmemi_mem_transport.hpp
@@ -36,6 +36,13 @@ class nvshmemi_mem_p2p_transport final {
return p2p_objref_;
}
}
+ // eep-dev
+ static void destroy_instance(void) {
+ if (p2p_objref_ != nullptr) {
+ delete p2p_objref_;
+ p2p_objref_ = nullptr;
+ }
+ }
void print_mem_handle(int pe_id, int transport_idx, nvshmemi_symmetric_heap &obj);
@@ -87,6 +94,14 @@ class nvshmemi_mem_remote_transport final {
}
}
+ // eep-dev
+ static void destroy_instance(void) {
+ if (remote_objref_ != nullptr) {
+ delete remote_objref_;
+ remote_objref_ = nullptr;
+ }
+ }
+
int gather_mem_handles(nvshmemi_symmetric_heap &obj, uint64_t heap_offset, size_t size);
/* On-demand registration and release of memory */
int register_mem_handle(nvshmem_mem_handle_t *local_handles, int transport_idx,
diff --git a/src/modules/bootstrap/uid/bootstrap_uid.cpp b/src/modules/bootstrap/uid/bootstrap_uid.cpp
index a1fa748..788fa96 100644
--- a/src/modules/bootstrap/uid/bootstrap_uid.cpp
+++ b/src/modules/bootstrap/uid/bootstrap_uid.cpp
@@ -630,6 +630,11 @@ int nvshmemi_bootstrap_plugin_pre_init(bootstrap_handle_t* handle, const int abi
// Discover the network for bootstrap, if not done previously.
// This code needs to be stateful to be able to be called multiple times by the caller
BOOTSTRAP_CHECK(bootstrap_net_init());
+ // eep-dev
+ if (handle->pre_init_ops != nullptr) {
+ BOOTSTRAP_PTR_FREE(handle->pre_init_ops);
+ handle->pre_init_ops = nullptr;
+ }
if (handle->pre_init_ops == nullptr) {
BOOTSTRAP_CALLOC(&handle->pre_init_ops, 1);
handle->pre_init_ops->get_unique_id = bootstrap_get_unique_id;
--
2.43.0

View File

@ -0,0 +1,86 @@
#!/bin/bash
set -ex
# Default workspace directory
WORKSPACE=$(pwd)/eep_kernels_workspace
INSTALL_NVSHMEM=true
# Parse command line arguments
while getopts "w:n" opt; do
case $opt in
w)
WORKSPACE="$OPTARG"
;;
n)
INSTALL_NVSHMEM=false
;;
\?)
echo "Invalid option: -$OPTARG" >&2
exit 1
;;
esac
done
if [ ! -d "$WORKSPACE" ]; then
mkdir -p $WORKSPACE
fi
# install dependencies if not installed
pip3 install cmake torch ninja
# build nvshmem
pushd $WORKSPACE
# Download and patch the NVSHMEM source unless -n was given (the build below runs either way)
if [ "$INSTALL_NVSHMEM" = true ]; then
mkdir -p nvshmem_src
wget https://developer.download.nvidia.com/compute/redist/nvshmem/3.2.5/source/nvshmem_src_3.2.5-1.txz
tar -xvf nvshmem_src_3.2.5-1.txz -C nvshmem_src --strip-components=1
pushd nvshmem_src
wget https://github.com/deepseek-ai/DeepEP/raw/main/third-party/nvshmem.patch
git init
git apply -vvv nvshmem.patch
git apply --reject --whitespace=fix ../../eep_nvshmem.patch
else
pushd nvshmem_src
fi
# assume CUDA_HOME is set correctly
if [ -z "$CUDA_HOME" ]; then
echo "CUDA_HOME is not set, please set it to your CUDA installation directory."
exit 1
fi
# disable all features except IBGDA
export NVSHMEM_IBGDA_SUPPORT=1
export NVSHMEM_SHMEM_SUPPORT=0
export NVSHMEM_UCX_SUPPORT=0
export NVSHMEM_USE_NCCL=0
export NVSHMEM_PMIX_SUPPORT=0
export NVSHMEM_TIMEOUT_DEVICE_POLLING=0
export NVSHMEM_USE_GDRCOPY=0
export NVSHMEM_IBRC_SUPPORT=0
export NVSHMEM_BUILD_TESTS=0
export NVSHMEM_BUILD_EXAMPLES=0
export NVSHMEM_MPI_SUPPORT=0
export NVSHMEM_BUILD_HYDRA_LAUNCHER=0
export NVSHMEM_BUILD_TXZ_PACKAGE=0
cmake -G Ninja -S . -B $WORKSPACE/nvshmem_build/ -DCMAKE_INSTALL_PREFIX=$WORKSPACE/nvshmem_install
cmake --build $WORKSPACE/nvshmem_build/ --target install
popd
export CMAKE_PREFIX_PATH=$WORKSPACE/nvshmem_install:$CMAKE_PREFIX_PATH
# build and install pplx, require pytorch installed
pushd $WORKSPACE
git clone https://github.com/ppl-ai/pplx-kernels
cd pplx-kernels
# see https://github.com/pypa/pip/issues/9955#issuecomment-838065925
# PIP_NO_BUILD_ISOLATION=0 disables build isolation
PIP_NO_BUILD_ISOLATION=0 TORCH_CUDA_ARCH_LIST=9.0a+PTX pip install . --no-deps -v

View File

@ -2008,6 +2008,19 @@ class ParallelConfig:
aggregated_has_unfinished = bool(tensor.item())
return aggregated_has_unfinished
@staticmethod
def sync_kv_cache_memory_size(dp_group: "ProcessGroup",
kv_cache_memory: int) -> int:
if kv_cache_memory == -1:
kv_cache_memory = torch.iinfo(torch.int64).max
tensor = torch.tensor([kv_cache_memory],
dtype=torch.int64,
device="cpu")
# we cannot use broadcast for stateless dp group since it depends
# on global rank
torch.distributed.all_reduce(tensor, op=ReduceOp.MIN, group=dp_group)
return tensor.item()
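# Hedged illustration (not part of this patch): the MIN all-reduce above
# makes every DP rank adopt the smallest KV-cache budget reported by any
# rank; passing -1 ("nothing measured locally") maps to int64 max so it
# never lowers the agreed value. Single-process sketch of the semantics:
#   reports = [8 << 30, 6 << 30, -1]                  # bytes per rank
#   int64_max = torch.iinfo(torch.int64).max
#   agreed = min(int64_max if r == -1 else r for r in reports)
#   assert agreed == 6 << 30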
def compute_hash(self):
"""
Provide a hash that uniquely identifies all the configs

View File

@ -29,12 +29,15 @@ physical experts.
import time
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Optional, Union
import torch
from torch.distributed import all_gather, all_reduce
from torch.distributed import ProcessGroup, all_gather, all_reduce
from vllm.config import ParallelConfig
from vllm.distributed.parallel_state import get_ep_group, get_node_count
from vllm.distributed.parallel_state import (get_ep_group, get_node_count,
in_the_same_node_as)
from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import MixtureOfExperts
@ -172,6 +175,9 @@ class EplbState:
model: MixtureOfExperts,
device: torch.device,
parallel_config: ParallelConfig,
global_expert_load: Optional[torch.Tensor] = None,
old_global_expert_indices: Optional[torch.Tensor] = None,
rank_mapping: Optional[dict[int, int]] = None,
) -> "EplbState":
"""
Build the initial EPLB state.
@ -185,8 +191,16 @@ class EplbState:
physical_to_logical_map_list,
device=device,
)
# Assuming 8 GPUs per node, this supports up to
# (1023 + 1) / 8 = 128 nodes for now.
# TODO(rui): make this configurable
MAX_EXPERT_REDUNDANCY = 1023
assert model.num_redundant_experts <= MAX_EXPERT_REDUNDANCY, (
f"num_redundant_experts {model.num_redundant_experts} "
f"must be less than or equal to {MAX_EXPERT_REDUNDANCY}")
max_slots_per_logical_expert = MAX_EXPERT_REDUNDANCY + 1
logical_to_physical_map = torch.full(
(model.num_logical_experts, model.num_redundant_experts + 1),
(model.num_logical_experts, max_slots_per_logical_expert),
-1,
device=device,
)
@ -235,11 +249,63 @@ class EplbState:
expert_rearrangement_step = max(
0, eplb_step_interval - eplb_step_interval // 4)
if global_expert_load is not None:
ep_group = get_ep_group().device_group
assert global_expert_load.shape == (model.num_moe_layers,
model.num_logical_experts)
assert global_expert_load.dtype == torch.int64
num_replicas = model.num_physical_experts
num_groups = model.num_expert_groups
num_nodes = get_node_count()
num_gpus = ep_group.size()
if num_gpus % num_nodes != 0:
num_nodes = 1
logger.warning_once(
f"num_gpus % num_nodes != 0, "
"not using hierarchical rearrangement algorithm.\n"
f"{num_gpus=}, {num_nodes=}")
# Get new expert mappings
(
new_physical_to_logical_map,
new_logical_to_physical_map,
new_logical_replica_count,
) = (rebalance_experts(
global_expert_load,
num_replicas,
num_groups,
num_nodes,
num_gpus,
))
max_physical_slots = new_logical_to_physical_map.shape[-1]
assert max_physical_slots <= logical_to_physical_map.shape[-1]
new_logical_to_physical_map = torch.nn.functional.pad(
new_logical_to_physical_map,
(0, logical_to_physical_map.shape[-1] - max_physical_slots),
value=-1,
)
physical_to_logical_map = new_physical_to_logical_map.to(device)
logical_to_physical_map.copy_(new_logical_to_physical_map)
logical_replica_count.copy_(new_logical_replica_count)
model.set_eplb_state(
expert_load_pass,
logical_to_physical_map,
logical_replica_count,
)
if global_expert_load is not None:
rearrange_expert_weights_inplace(
old_global_expert_indices,
new_physical_to_logical_map,
model.expert_weights,
ep_group,
False,
rank_mapping,
)
expert_rearrangement_step = 0
return cls(
physical_to_logical_map,
@ -337,7 +403,10 @@ class EplbState:
def rearrange(self,
model: MixtureOfExperts,
is_profile: bool = False) -> None:
is_profile: bool = False,
execute_shuffle: bool = True,
global_expert_load: Optional[torch.Tensor] = None,
rank_mapping: Optional[dict[int, int]] = None) -> None:
"""
Rearrange the experts according to the current load.
"""
@ -353,42 +422,79 @@ class EplbState:
logger.info("Rearranging experts %s...",
"(profile)" if is_profile else "")
# This mapping is only used here, so we do not store it in the state
physical_expert_start = ep_rank * model.num_local_physical_experts
physical_expert_end = (physical_expert_start +
model.num_local_physical_experts)
# (num_moe_layers, num_local_physical_experts)
local_physical_to_logical_map = self.physical_to_logical_map[
:,
physical_expert_start:physical_expert_end,
]
if global_expert_load is None:
# This mapping is only used here, so we do not store it in the state
physical_expert_start = ep_rank * model.num_local_physical_experts
physical_expert_end = (physical_expert_start +
model.num_local_physical_experts)
# (num_moe_layers, num_local_physical_experts)
local_physical_to_logical_map = self.physical_to_logical_map[
:,
physical_expert_start:physical_expert_end,
]
# Map the local physical expert load to global logical experts
logical_expert_load_window = torch.zeros(
self.expert_load_window_size,
model.num_moe_layers,
model.num_logical_experts,
dtype=self.expert_load_window.dtype,
device=self.expert_load_window.device,
)
logical_expert_load_window.scatter_add_(
dim=-1,
index=local_physical_to_logical_map.unsqueeze(0).expand_as(
self.expert_load_window).long(),
src=self.expert_load_window,
)
# Map the local physical expert load to global logical experts
logical_expert_load_window = torch.zeros(
self.expert_load_window_size,
model.num_moe_layers,
model.num_logical_experts,
dtype=self.expert_load_window.dtype,
device=self.expert_load_window.device,
)
logical_expert_load_window.scatter_add_(
dim=-1,
index=local_physical_to_logical_map.unsqueeze(0).expand_as(
self.expert_load_window).long(),
src=self.expert_load_window,
)
# Perform all-reduce to get the expert load across all ranks
global_expert_load_window = logical_expert_load_window.sum(dim=0)
all_reduce(global_expert_load_window, group=ep_group)
if not execute_shuffle:
metadata = torch.tensor(
[
model.num_moe_layers, model.num_logical_experts,
self.physical_to_logical_map.shape[1]
],
dtype=torch.int32,
device="cpu",
)
torch.distributed.broadcast(metadata,
group=get_ep_group().cpu_group,
group_src=0)
# Perform all-reduce to get the expert load across all ranks
global_expert_load_window = logical_expert_load_window.sum(dim=0)
all_reduce(global_expert_load_window, group=ep_group)
if not execute_shuffle:
# (num_moe_layers, old_num_physical_experts)
old_global_expert_indices = self.physical_to_logical_map
torch.distributed.broadcast(old_global_expert_indices,
group=ep_group,
group_src=0)
return global_expert_load_window
else:
assert execute_shuffle
global_expert_load_window = global_expert_load
# TODO(bowen): Treat differently for prefill and decode nodes
num_replicas = model.num_physical_experts
num_groups = model.num_expert_groups
num_nodes = get_node_count()
num_gpus = ep_group.size()
if rank_mapping is not None and len(rank_mapping) == ep_group.size():
# NOTE(yongji): scale down, we need to rebalance the experts on
# remaining GPUs, transfer the experts while we haven't shutdown
# the GPUs to be released.
cpu_group = get_ep_group().cpu_group
num_nodes = _node_count_with_rank_mapping(cpu_group, rank_mapping)
num_gpus = sum(new_rank != -1
for new_rank in rank_mapping.values())
num_replicas = num_replicas // ep_group.size(
) * num_gpus # handle num replicas change
else:
num_nodes = get_node_count()
num_gpus = ep_group.size()
if num_gpus % num_nodes != 0:
num_nodes = 1
logger.warning_once(
f"num_gpus % num_nodes != 0, "
"not using hierarchical rearrangement algorithm.\n"
@ -414,10 +520,24 @@ class EplbState:
model.expert_weights,
ep_group,
is_profile,
rank_mapping,
)
if not is_profile:
self.physical_to_logical_map.copy_(new_physical_to_logical_map)
if self.physical_to_logical_map.shape[
1] != new_physical_to_logical_map.shape[1]:
self.physical_to_logical_map = new_physical_to_logical_map.to(
self.physical_to_logical_map.device)
else:
self.physical_to_logical_map.copy_(new_physical_to_logical_map)
max_physical_slots = new_logical_to_physical_map.shape[-1]
assert max_physical_slots <= self.logical_to_physical_map.shape[-1]
new_logical_to_physical_map = torch.nn.functional.pad(
new_logical_to_physical_map,
(0,
self.logical_to_physical_map.shape[-1] - max_physical_slots),
value=-1,
)
self.logical_to_physical_map.copy_(new_logical_to_physical_map)
self.logical_replica_count.copy_(new_logical_replica_count)
@ -430,3 +550,69 @@ class EplbState:
" (profile) " if is_profile else " ",
time_end - time_start,
)
@staticmethod
def recv_state() -> tuple[torch.Tensor, torch.Tensor]:
"""
Receive the expert load and old placement from the master rank.
"""
ep_group = get_ep_group()
metadata = torch.empty(3, dtype=torch.int32, device="cpu")
torch.distributed.broadcast(metadata,
group=ep_group.cpu_group,
group_src=0)
num_moe_layers, num_logical_experts, num_old_physical_experts = (
metadata.tolist())
global_expert_load = torch.zeros(
(num_moe_layers, num_logical_experts),
dtype=torch.int64,
device=ep_group.device,
)
all_reduce(global_expert_load, group=ep_group.device_group)
old_global_expert_indices = torch.empty(
(num_moe_layers, num_old_physical_experts),
dtype=torch.int64,
device=ep_group.device,
)
torch.distributed.broadcast(old_global_expert_indices,
group=ep_group.device_group,
group_src=0)
return global_expert_load, old_global_expert_indices
def _node_count_with_rank_mapping(
pg: Union[ProcessGroup, StatelessProcessGroup],
rank_mapping: dict[int, int],
) -> int:
if isinstance(pg, ProcessGroup):
world_size = torch.distributed.get_world_size(group=pg)
else:
world_size = pg.world_size
if world_size == 1:
return 1
# Build node assignment map
node_assignment = [0] * world_size # rank -> node_id
next_node_id = 0
for current_rank in range(world_size):
if node_assignment[current_rank] != 0:
continue # Already assigned to a node
assert current_rank in rank_mapping
if rank_mapping[current_rank] == -1:
continue # Pending shutdown
# Assign current rank to a new node
next_node_id += 1
node_assignment[current_rank] = next_node_id
# Find all ranks on the same node as current_rank
same_node_flags = in_the_same_node_as(pg, current_rank)
for other_rank, is_same_node in enumerate(same_node_flags):
if is_same_node and node_assignment[other_rank] == 0:
node_assignment[other_rank] = next_node_id
return next_node_id
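A hedged example of the rank_mapping convention assumed above: on a scale-down from four DP ranks to two, the released ranks map to -1, so only surviving ranks count toward num_gpus and only nodes that keep at least one surviving rank are counted.

rank_mapping = {0: 0, 1: 1, 2: -1, 3: -1}   # ranks 2 and 3 will shut down
num_gpus = sum(new_rank != -1 for new_rank in rank_mapping.values())
assert num_gpus == 2
# _node_count_with_rank_mapping(pg, rank_mapping) would then report 1 if
# ranks 0 and 1 share a node, mirroring the scale-down branch in rearrange().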

View File

@ -8,6 +8,7 @@ This involves the exchange of expert weights between GPUs.
from collections.abc import Iterable, MutableSequence, Sequence
from functools import partial
from typing import Optional
import torch
from torch.distributed import (P2POp, ProcessGroup, all_gather,
@ -127,6 +128,8 @@ def shuffle_layer(
dst_global = local2global(dst)
if is_received_locally[dst]:
continue
if old_indices[src_global] == -1 or new_indices[dst_global] == -1:
continue
if old_indices[src_global] == new_indices[dst_global]:
is_received_locally[dst] = True
for weight, buffer in zip(expert_weights,
@ -139,6 +142,8 @@ def shuffle_layer(
experts_send_loc: dict[int, int] = {}
for src in range(num_local_experts):
expert = old_indices[local2global(src)]
if expert == -1:
continue
if expert in experts_send_loc:
continue
experts_send_loc[expert] = src
@ -181,6 +186,8 @@ def shuffle_layer(
if is_received_locally[dst]:
continue
expert = new_indices[local2global(dst)]
if expert == -1:
continue
if expert in experts_recv_loc:
continue
experts_recv_loc[expert] = dst
@ -227,6 +234,8 @@ def shuffle_layer(
weight[dst].copy_(buffer[dst])
else:
expert = new_indices[local2global(dst)]
if expert == -1:
continue
src = experts_recv_loc[expert]
for weight, buffer in zip(expert_weights, expert_weights_buffer):
weight[dst].copy_(buffer[src])
@ -238,6 +247,7 @@ def rearrange_expert_weights_inplace(
expert_weights: Sequence[Iterable[torch.Tensor]],
ep_group: ProcessGroup,
is_profile: bool = False,
rank_mapping: Optional[dict[int, int]] = None,
) -> None:
"""
Rearranges the expert weights in place according to the new expert indices.
@ -256,7 +266,28 @@ def rearrange_expert_weights_inplace(
is_profile (bool): If `True`, do not perform any actual weight copy.
This is used during profile run, where we only perform dummy
communications to reserve enough memory for the buffers.
rank_mapping: A dictionary mapping old rank to new rank.
"""
if rank_mapping is not None:
if len(rank_mapping) == ep_group.size():
# scale down
new_global_expert_indices = \
_map_new_expert_indices_with_rank_mapping(
new_global_expert_indices,
rank_mapping,
)
else:
# scale up
old_global_expert_indices = \
_map_old_expert_indices_with_rank_mapping(
old_global_expert_indices,
rank_mapping,
ep_group.size(),
)
assert old_global_expert_indices.shape[
1] == new_global_expert_indices.shape[1]
num_moe_layers, num_physical_experts = old_global_expert_indices.shape
assert len(expert_weights) == num_moe_layers
@ -304,4 +335,90 @@ def rearrange_expert_weights_inplace(
)
def _map_old_expert_indices_with_rank_mapping(
old_global_expert_indices: torch.Tensor,
rank_mapping: dict[int, int],
new_ep_size: int,
) -> torch.Tensor:
"""
Map the old global expert indices to the new global expert indices.
Args:
old_global_expert_indices:
Shape (num_layers, old_ep_size * num_local_physical_experts).
rank_mapping: Mapping from old rank to new rank.
new_ep_size: New expert parallelism size.
Returns:
Mapped expert indices with shape
(num_layers, new_ep_size * num_local_physical_experts).
"""
num_layers, old_num_physical_experts = old_global_expert_indices.shape
assert rank_mapping, "Rank mapping is required"
# Get sizes from parameters and rank_mapping
old_ep_size = len(rank_mapping)
num_local_physical_experts = old_num_physical_experts // old_ep_size
new_num_physical_experts = new_ep_size * num_local_physical_experts
# Create mapped tensor with new shape, initialized to -1
mapped_expert_indices = torch.full(
(num_layers, new_num_physical_experts),
fill_value=-1,
dtype=old_global_expert_indices.dtype,
device=old_global_expert_indices.device,
)
# Handle rank mapping (scale up/down with rank changes)
for old_rank in range(old_ep_size):
new_rank = rank_mapping.get(old_rank)
if new_rank is not None and new_rank >= 0 and new_rank < new_ep_size:
# This old rank exists in the new configuration
old_start_idx = old_rank * num_local_physical_experts
old_end_idx = (old_rank + 1) * num_local_physical_experts
new_start_idx = new_rank * num_local_physical_experts
new_end_idx = (new_rank + 1) * num_local_physical_experts
mapped_expert_indices[:, new_start_idx:new_end_idx] = \
old_global_expert_indices[:, old_start_idx:old_end_idx]
# If new_rank is None or >= new_ep_size, the experts remain -1
# (scale down case)
return mapped_expert_indices
def _map_new_expert_indices_with_rank_mapping(
new_global_expert_indices: torch.Tensor,
rank_mapping: dict[int, int],
) -> torch.Tensor:
num_layers, new_num_physical_experts = new_global_expert_indices.shape
assert rank_mapping, "Rank mapping is required"
# Get sizes from parameters and rank_mapping
old_ep_size = len(rank_mapping)
new_ep_size = sum(new_rank != -1 for new_rank in rank_mapping.values())
num_local_physical_experts = new_num_physical_experts // new_ep_size
old_num_physical_experts = old_ep_size * num_local_physical_experts
mapped_expert_indices = torch.full(
(num_layers, old_num_physical_experts),
fill_value=-1,
dtype=new_global_expert_indices.dtype,
device=new_global_expert_indices.device,
)
for old_rank in range(old_ep_size):
new_rank = rank_mapping[old_rank]
if new_rank >= 0 and new_rank < new_ep_size:
old_start_idx = old_rank * num_local_physical_experts
old_end_idx = (old_rank + 1) * num_local_physical_experts
new_start_idx = new_rank * num_local_physical_experts
new_end_idx = (new_rank + 1) * num_local_physical_experts
mapped_expert_indices[:, old_start_idx:old_end_idx] = \
new_global_expert_indices[:, new_start_idx:new_end_idx]
return mapped_expert_indices
__all__ = ["rearrange_expert_weights_inplace"]

View File

@ -324,3 +324,9 @@ class EngineClient(ABC):
async def add_lora(self, lora_request: LoRARequest) -> None:
"""Load a new LoRA adapter into the engine for future requests."""
...
async def scale_elastic_ep(self,
new_data_parallel_size: int,
drain_timeout: int = 300) -> None:
"""Scale the engine"""
raise NotImplementedError

View File

@ -1018,6 +1018,73 @@ if envs.VLLM_SERVER_DEV_MODE:
return JSONResponse(content={"is_sleeping": is_sleeping})
@router.post("/scale_elastic_ep",
dependencies=[Depends(validate_json_request)],
responses={
HTTPStatus.OK.value: {
"model": dict
},
HTTPStatus.BAD_REQUEST.value: {
"model": ErrorResponse
},
HTTPStatus.REQUEST_TIMEOUT.value: {
"model": ErrorResponse
},
HTTPStatus.INTERNAL_SERVER_ERROR.value: {
"model": ErrorResponse
},
})
async def scale_elastic_ep(raw_request: Request):
try:
body = await raw_request.json()
except json.JSONDecodeError as e:
raise HTTPException(status_code=400,
detail="Invalid JSON format") from e # noqa: B904
new_data_parallel_size = body.get("new_data_parallel_size")
drain_timeout = body.get("drain_timeout", 120) # Default 2 minutes
if new_data_parallel_size is None:
raise HTTPException(status_code=400,
detail="new_data_parallel_size is required")
if not isinstance(new_data_parallel_size,
int) or new_data_parallel_size <= 0:
raise HTTPException(
status_code=400,
detail="new_data_parallel_size must be a positive integer")
if not isinstance(drain_timeout, int) or drain_timeout <= 0:
raise HTTPException(status_code=400,
detail="drain_timeout must be a positive integer")
# Set scaling flag to prevent new requests
global _scaling_elastic_ep
_scaling_elastic_ep = True
client = engine_client(raw_request)
try:
await client.scale_elastic_ep(new_data_parallel_size, drain_timeout)
return JSONResponse({
"message":
f"Scaled to {new_data_parallel_size} "
"data parallel engines",
})
except TimeoutError as e:
raise HTTPException(status_code=408,
detail="Scale failed due to request drain timeout "
f"after {drain_timeout} seconds") from e
except Exception as e:
logger.error("Scale failed: %s", e)
raise HTTPException(status_code=500, detail="Scale failed") from e
finally:
_scaling_elastic_ep = False
@router.post("/is_scaling_elastic_ep")
async def is_scaling_elastic_ep(raw_request: Request):
return JSONResponse({"is_scaling_elastic_ep": _scaling_elastic_ep})
# TODO: RequestType = TypeForm[BaseModel] when recognized by type checkers
# (requires typing_extensions >= 4.13)
RequestType = Any
@ -1216,6 +1283,41 @@ class XRequestIdMiddleware:
return self.app(scope, receive, send_with_request_id)
# Global variable to track scaling state
_scaling_elastic_ep = False
class ScalingMiddleware:
"""
Middleware that checks if the model is currently scaling and
returns a 503 Service Unavailable response if it is.
This middleware applies to all HTTP requests and prevents
processing when the model is in a scaling state.
"""
def __init__(self, app: ASGIApp) -> None:
self.app = app
def __call__(self, scope: Scope, receive: Receive,
send: Send) -> Awaitable[None]:
if scope["type"] != "http":
return self.app(scope, receive, send)
# Check global scaling state
global _scaling_elastic_ep
if _scaling_elastic_ep:
# Return 503 Service Unavailable response
response = JSONResponse(content={
"error":
"The model is currently scaling. Please try again later."
},
status_code=503)
return response(scope, receive, send)
return self.app(scope, receive, send)
def _extract_content_from_chunk(chunk_data: dict) -> str:
"""Extract content from a streaming response chunk."""
try:
@ -1404,6 +1506,9 @@ def build_app(args: Namespace) -> FastAPI:
if args.enable_request_id_headers:
app.add_middleware(XRequestIdMiddleware)
# Add scaling middleware to check for scaling state
app.add_middleware(ScalingMiddleware)
if envs.VLLM_DEBUG_LOG_API_SERVER_RESPONSE:
logger.warning("CAUTION: Enabling log response in the API Server. "
"This can include sensitive information and should be "

View File

@ -12,6 +12,7 @@ from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
run_method)
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.worker.worker_base import WorkerWrapperBase
logger = init_logger(__name__)
@ -62,6 +63,14 @@ class UniProcExecutor(ExecutorBase):
# it's running.
return
def reinitialize_distributed(
self, reconfig_request: ReconfigureDistributedRequest) -> None:
self.driver_worker.reinitialize_distributed(reconfig_request)
if reconfig_request.new_data_parallel_rank == \
ReconfigureRankType.SHUTDOWN_CURRENT_RANK:
self.shutdown()
return
UniProcExecutorAsync = UniProcExecutor

View File

@ -265,9 +265,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
prepare_finalize: FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig,
) -> FusedMoEPermuteExpertsUnpermute:
assert self.fused_experts == fused_experts
if (prepare_finalize.activation_format ==
FusedMoEActivationFormat.BatchedExperts):
logger.debug("BatchedTritonExperts %s", self.moe)
@ -375,8 +372,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `UnquantizedFusedMoEMethod` yet.")
assert expert_load_view is not None
assert logical_to_physical_map is not None
assert logical_replica_count is not None
assert isinstance(layer, FusedMoE)
return self.forward(
x=x,
@ -393,7 +392,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input)
apply_router_weight_on_input=apply_router_weight_on_input,
enable_eplb=enable_eplb,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
)
def forward_cuda(
self,
@ -412,6 +416,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
topk_weights, topk_ids = FusedMoE.select_experts(
@ -425,7 +433,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
indices_type=self.topk_indices_dtype,
enable_eplb=enable_eplb,
expert_map=expert_map,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count)
if self.rocm_aiter_moe_enabled:
return self.rocm_aiter_fused_experts(
@ -730,7 +743,8 @@ class FusedMoE(torch.nn.Module):
if self.enable_eplb:
from vllm.model_executor.layers.quantization.fp8 import (
Fp8MoEMethod)
if not isinstance(quant_method, Fp8MoEMethod):
if not isinstance(quant_method,
(Fp8MoEMethod, UnquantizedFusedMoEMethod)):
# TODO: Add support for additional quantization methods.
# The implementation for other quantization methods does not
# contain essential differences, but the current quant API
@ -821,6 +835,15 @@ class FusedMoE(torch.nn.Module):
def use_flashinfer_cutlass_kernels(self):
return self.moe_parallel_config.use_flashinfer_cutlass_kernels
def update_expert_map(self):
# ep_size and ep_rank should already be updated
assert self.expert_map is not None
with self.expert_map.device:
self.local_num_experts, self.expert_map = determine_expert_map(
ep_size=self.ep_size,
ep_rank=self.ep_rank,
global_num_experts=self.global_num_experts)
def _load_per_tensor_weight_scale(self, shard_id: str,
param: torch.nn.Parameter,
loaded_weight: torch.Tensor,

View File

@ -776,6 +776,24 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
logical_replica_count=logical_replica_count,
)
def update_physical_experts_metadata(
self,
num_physical_experts: int,
num_local_physical_experts: int,
) -> None:
assert self.num_local_physical_experts == num_local_physical_experts
self.num_physical_experts = num_physical_experts
self.num_local_physical_experts = num_local_physical_experts
self.num_redundant_experts = (num_physical_experts -
self.num_logical_experts)
for layer in self.model.layers:
if isinstance(layer.mlp, DeepseekV2MoE):
moe = layer.mlp
moe.n_local_physical_experts = num_local_physical_experts
moe.n_physical_experts = num_physical_experts
moe.n_redundant_experts = self.num_redundant_experts
moe.experts.update_expert_map()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
@ -931,9 +949,8 @@ class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
def get_spec_layer_idx_from_weight_name(config: PretrainedConfig,
weight_name: str) -> Optional[int]:
if hasattr(config,
"num_nextn_predict_layers") and (config.num_nextn_predict_layers
> 0):
if (hasattr(config, "num_nextn_predict_layers")
and config.num_nextn_predict_layers > 0):
layer_idx = config.num_hidden_layers
for i in range(config.num_nextn_predict_layers):
if weight_name.startswith(f"model.layers.{layer_idx+i}."):

View File

@ -543,6 +543,13 @@ class MixtureOfExperts(Protocol):
"""
...
def update_physical_experts_metadata(
self,
num_physical_experts: int,
num_local_physical_experts: int,
) -> None:
...
def is_mixture_of_experts(model: object) -> TypeIs[MixtureOfExperts]:
return isinstance(model, MixtureOfExperts)

View File

@ -177,3 +177,19 @@ class EngineCoreRequestType(enum.Enum):
UTILITY = b'\x03'
# Sentinel used within EngineCoreProc.
EXECUTOR_FAILED = b'\x04'
class ReconfigureDistributedRequest(msgspec.Struct):
new_data_parallel_size: int
new_data_parallel_rank: int
new_data_parallel_rank_local: int
new_data_parallel_master_ip: str
new_data_parallel_master_port: int
class ReconfigureRankType(enum.IntEnum):
"""
Rank type for reconfiguring distributed request.
"""
KEEP_CURRENT_RANK = -1
SHUTDOWN_CURRENT_RANK = -2
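As a hedged illustration of how these types fit together, this is roughly the message a client would send to a rank being released during a scale-down (the master IP and port here are placeholders):

from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType

req = ReconfigureDistributedRequest(
    new_data_parallel_size=2,
    new_data_parallel_rank=ReconfigureRankType.SHUTDOWN_CURRENT_RANK,
    new_data_parallel_rank_local=ReconfigureRankType.KEEP_CURRENT_RANK,
    new_data_parallel_master_ip="10.0.0.1",       # placeholder
    new_data_parallel_master_port=13345,          # placeholder
)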

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import time
from collections.abc import AsyncGenerator, Mapping
from copy import copy
from typing import Any, Optional, Union
@ -608,6 +609,63 @@ class AsyncLLM(EngineClient):
return await self.engine_core.collective_rpc_async(
method, timeout, args, kwargs)
async def wait_for_requests_to_drain(self, drain_timeout: int = 300):
"""Wait for all requests to be drained."""
start_time = time.time()
while time.time() - start_time < drain_timeout:
if not self.engine_core.dp_engines_running():
logger.info("Engines are idle, requests have been drained")
return
logger.info(
"Engines are still running, waiting for requests to drain...")
await asyncio.sleep(1) # Wait 1 second before checking again
raise TimeoutError(f"Timeout reached after {drain_timeout} seconds "
"waiting for requests to drain.")
async def scale_elastic_ep(self,
new_data_parallel_size: int,
drain_timeout: int = 300):
"""
Scale up or down the data parallel size by adding or removing
engine cores.
Args:
new_data_parallel_size: The new number of data parallel workers
drain_timeout:
Maximum time to wait for requests to drain (seconds)
"""
old_data_parallel_size = \
self.vllm_config.parallel_config.data_parallel_size
if old_data_parallel_size == new_data_parallel_size:
logger.info("Data parallel size is already %s, skipping scale",
new_data_parallel_size)
return
logger.info(
"Waiting for requests to drain before "
"scaling up to %s engines...", new_data_parallel_size)
await self.wait_for_requests_to_drain(drain_timeout)
logger.info(
"Requests have been drained, proceeding with scale "
"to %s engines", new_data_parallel_size)
await self.engine_core.scale_elastic_ep(new_data_parallel_size)
self.vllm_config.parallel_config.data_parallel_size = \
new_data_parallel_size
# recreate stat loggers
if new_data_parallel_size > old_data_parallel_size:
stat_loggers: list[list[StatLoggerBase]] = setup_default_loggers(
vllm_config=self.vllm_config,
log_stats=self.log_stats,
engine_num=new_data_parallel_size,
custom_stat_loggers=None,
)
num_new_engines = len(stat_loggers) - len(self.stat_loggers)
self.stat_loggers.extend(stat_loggers[-num_new_engines:])
else:
for _ in range(old_data_parallel_size - new_data_parallel_size):
self.stat_loggers.pop()
@property
def is_running(self) -> bool:
# Is None before the loop is started.

View File

@ -200,11 +200,41 @@ class CoordinatorProc:
# Ignore subscription messages.
continue
decoded = msgspec.msgpack.decode(buffer)
if isinstance(decoded, (list, tuple)) and len(
decoded) == 2 and decoded[0] == "SCALE_ELASTIC_EP":
# Handle scale up notification
new_engine_count = decoded[1]
current_count = len(self.engines)
if new_engine_count > current_count:
for _ in range(new_engine_count - current_count):
self.engines.append(EngineState())
# NOTE(yongji): newly started engines begin at current_wave = 0.
# If the existing engines just finished a wave but engines_running
# has not yet been updated here in CoordinatorProc, a request routed
# to a newly started engine may fail to wake the existing engines
# whenever 0 < request.wave < the existing engines' current_wave.
# Resetting engines_running below avoids that, since 0 is the wave
# number the new engines start from.
self.engines_running = False
logger.info(
"DPCoordinator scaled up from %s to %s "
"engines", current_count, new_engine_count)
else:
self.engines = self.engines[:new_engine_count]
logger.info(
"DPCoordinator scaled down from %s to %s "
"engines", current_count, new_engine_count)
continue # Skip normal engine notification processing
# We received a message on the front-end XPUB socket,
# from an API server sending a new request while the
# engines are paused, so that we can wake the other
# engines.
engine_to_exclude, wave = msgspec.msgpack.decode(buffer)
engine_to_exclude, wave = decoded
if not self.engines_running:
if wave < self.current_wave:
# If the wave number is stale, ensure the message

View File

@ -32,7 +32,9 @@ from vllm.v1.core.sched.interface import SchedulerInterface
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
EngineCoreRequestType, UtilityOutput)
EngineCoreRequestType,
ReconfigureDistributedRequest, ReconfigureRankType,
UtilityOutput)
from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
from vllm.v1.engine.utils import EngineHandshakeMetadata, EngineZmqAddresses
from vllm.v1.executor.abstract import Executor
@ -77,6 +79,8 @@ class EngineCore:
self.model_executor.register_failure_callback(
executor_fail_callback)
self.available_gpu_memory_for_kv_cache = -1
# Setup KV Caches and update CacheConfig after profiling.
num_gpu_blocks, num_cpu_blocks, kv_cache_config = \
self._initialize_kv_caches(vllm_config)
@ -137,12 +141,23 @@ class EngineCore:
# Get all kv cache needed by the model
kv_cache_specs = self.model_executor.get_kv_cache_specs()
# Profiles the peak memory usage of the model to determine how much
# memory can be allocated for kv cache.
has_kv_cache = any(kv_cache_spec for kv_cache_spec in kv_cache_specs)
if has_kv_cache:
available_gpu_memory = \
self.model_executor.determine_available_memory()
if os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1":
dp_group = getattr(self, "dp_group", None)
assert dp_group is not None
self.available_gpu_memory_for_kv_cache = \
ParallelConfig.sync_kv_cache_memory_size(dp_group, -1)
available_gpu_memory = [
self.available_gpu_memory_for_kv_cache
] * len(kv_cache_specs)
else:
# Profiles the peak memory usage of the model to determine how
# much memory can be allocated for kv cache.
available_gpu_memory = (
self.model_executor.determine_available_memory())
self.available_gpu_memory_for_kv_cache = \
available_gpu_memory[0]
else:
# Attention free models don't need memory for kv cache
available_gpu_memory = [0] * len(kv_cache_specs)
@ -989,6 +1004,50 @@ class DPEngineCoreProc(EngineCoreProc):
return ParallelConfig.has_unfinished_dp(self.dp_group,
local_unfinished)
def reinitialize_distributed(
self, reconfig_request: ReconfigureDistributedRequest) -> None:
stateless_destroy_torch_distributed_process_group(self.dp_group)
self.shutdown()
parallel_config = self.vllm_config.parallel_config
old_dp_size = parallel_config.data_parallel_size
parallel_config.data_parallel_size = \
reconfig_request.new_data_parallel_size
if reconfig_request.new_data_parallel_rank != -1:
parallel_config.data_parallel_rank = \
reconfig_request.new_data_parallel_rank
# local rank specifies device visibility, it should not be changed
assert reconfig_request.new_data_parallel_rank_local == \
ReconfigureRankType.KEEP_CURRENT_RANK
parallel_config.data_parallel_master_ip = \
reconfig_request.new_data_parallel_master_ip
parallel_config.data_parallel_master_port = \
reconfig_request.new_data_parallel_master_port
if reconfig_request.new_data_parallel_rank != -2:
self.dp_rank = parallel_config.data_parallel_rank
self.dp_group = parallel_config.stateless_init_dp_group()
reconfig_request.new_data_parallel_master_port = \
parallel_config.data_parallel_master_port
self.model_executor.reinitialize_distributed(reconfig_request)
if reconfig_request.new_data_parallel_size > old_dp_size:
assert self.available_gpu_memory_for_kv_cache > 0
# pass available_gpu_memory_for_kv_cache from existing
# engine-cores to new engine-cores so they can directly
# use it in _initialize_kv_caches() rather than profiling.
ParallelConfig.sync_kv_cache_memory_size(
self.dp_group, self.available_gpu_memory_for_kv_cache)
# NOTE(yongji): newly joined workers require a dummy_run even
# when CUDA graphs are not used
self.model_executor.collective_rpc("compile_or_warm_up_model")
if reconfig_request.new_data_parallel_rank == \
ReconfigureRankType.SHUTDOWN_CURRENT_RANK:
self.shutdown()
logger.info("DPEngineCoreProc %s shutdown", self.dp_rank)
else:
logger.info("Distributed environment reinitialized for DP rank %s",
self.dp_rank)
class DPEngineCoreActor(DPEngineCoreProc):
"""

View File

@ -21,9 +21,11 @@ import zmq.asyncio
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.utils import get_open_zmq_inproc_path, make_zmq_socket
from vllm.utils import get_open_port, get_open_zmq_inproc_path, make_zmq_socket
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
EngineCoreRequestType, UtilityOutput)
EngineCoreRequestType,
ReconfigureDistributedRequest, ReconfigureRankType,
UtilityOutput)
from vllm.v1.engine.coordinator import DPCoordinator
from vllm.v1.engine.core import EngineCore, EngineCoreProc
from vllm.v1.engine.exceptions import EngineDeadError
@ -162,6 +164,9 @@ class EngineCoreClient(ABC):
running state."""
raise NotImplementedError
async def scale_elastic_ep(self, new_data_parallel_size: int) -> None:
raise NotImplementedError
async def get_output_async(self) -> EngineCoreOutputs:
raise NotImplementedError
@ -910,14 +915,30 @@ class DPAsyncMPClient(AsyncMPClient):
events = await poller.poll()
if not self.engines_running and len(events) == 2 or (
events[0][0] == first_req_rcv_socket):
# Send a message to notify the coordinator that
# Check if this is a regular request notification or
# scale up notification
buf = first_req_rcv_socket.recv(
flags=zmq.NOBLOCK).result()
decoded = msgspec.msgpack.decode(buf)
if isinstance(
decoded,
(list, tuple)) and len(decoded) == 2 and decoded[
0] == "SCALE_ELASTIC_EP":
# Extract new engine count from the decoded message
new_engine_count = decoded[1]
# Send scale up notification to coordinator
scale_msg = msgspec.msgpack.encode(
("SCALE_ELASTIC_EP", new_engine_count))
await socket.send(scale_msg)
continue
# we're sending a request while the engines are
# paused, so that it can wake the others up
# (to run dummy EP loop).
assert decoded[0] == "FIRST_REQ"
target_eng_index = decoded[1]
self.engines_running = True
buf = first_req_rcv_socket.recv(
flags=zmq.NOBLOCK).result()
target_eng_index = int.from_bytes(buf, "little")
msg = msgspec.msgpack.encode(
(target_eng_index, self.current_wave))
await socket.send(msg)
@ -953,7 +974,8 @@ class DPAsyncMPClient(AsyncMPClient):
chosen_engine)
if not self.engines_running:
# Notify coordinator that we're sending a request
await self.first_req_send_socket.send(chosen_engine)
req_msg = msgspec.msgpack.encode(("FIRST_REQ", chosen_engine))
await self.first_req_send_socket.send(req_msg)
await to_await
@ -1047,3 +1069,156 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
engine: EngineIdentity) -> None:
await self._send_input(EngineCoreRequestType.ABORT, request_ids,
engine)
async def _send_reconfig_message(
self, reconfig_request: ReconfigureDistributedRequest,
engine: EngineIdentity) -> asyncio.Future:
"""Send reconfiguration message and return the result future without
waiting for completion."""
call_id = uuid.uuid1().int >> 64
future = asyncio.get_running_loop().create_future()
self.utility_results[call_id] = future
message = (EngineCoreRequestType.UTILITY.value, *self.encoder.encode(
(self.client_index, call_id, "reinitialize_distributed",
(reconfig_request, ))))
await self._send_input_message(message, engine, reconfig_request)
self._ensure_output_queue_task()
return future
async def scale_elastic_ep(self, new_data_parallel_size: int) -> None:
"""Scale elastic EP data parallel size"""
cur_data_parallel_size = len(self.core_engines)
assert new_data_parallel_size != cur_data_parallel_size, (
f"new_data_parallel_size {new_data_parallel_size} must be "
f"different from cur_data_parallel_size {cur_data_parallel_size}")
assert self.vllm_config.parallel_config.data_parallel_backend == \
"ray", ("Only ray DP backend supports scaling elastic EP")
scale_up = new_data_parallel_size > cur_data_parallel_size
if scale_up:
await self._scale_up_elastic_ep(cur_data_parallel_size,
new_data_parallel_size)
else:
await self._scale_down_elastic_ep(cur_data_parallel_size,
new_data_parallel_size)
async def _scale_up_elastic_ep(self, cur_data_parallel_size: int,
new_data_parallel_size: int) -> None:
"""Scale up the data parallel size by creating new engine cores
and reconfiguring existing ones."""
cur_data_parallel_size = len(self.core_engines)
# Phase 1: Send reconfigure messages to all existing engines and wait
# for them to be sent
reconfig_futures = []
self.vllm_config.parallel_config.data_parallel_master_port = \
get_open_port()
for engine in self.core_engines:
reconfig_request = ReconfigureDistributedRequest(
new_data_parallel_size=new_data_parallel_size,
new_data_parallel_rank=ReconfigureRankType.KEEP_CURRENT_RANK,
new_data_parallel_rank_local=\
ReconfigureRankType.KEEP_CURRENT_RANK,
new_data_parallel_master_ip=self.vllm_config.parallel_config.
data_parallel_master_ip,
new_data_parallel_master_port=self.vllm_config.parallel_config.
data_parallel_master_port)
future = await self._send_reconfig_message(reconfig_request,
engine)
reconfig_futures.append(future)
logger.info("All reconfigure messages sent, starting engine creation")
# Phase 2: Create new engines now that reconfig messages have been sent
# self.resources.engine_manager is guaranteed to be
# CoreEngineActorManager for RayDPClient
assert isinstance(self.resources.engine_manager,
CoreEngineActorManager)
self.resources.engine_manager.scale_up_elastic_ep(
self.vllm_config, new_data_parallel_size)
# Create new CoreEngine objects for the new engines
new_engine_identities = set()
for i in range(cur_data_parallel_size, new_data_parallel_size):
new_engine = i.to_bytes(2, "little")
self.core_engines.append(new_engine)
new_engine_identities.add(new_engine)
# Wait for ready messages from new engines on the input socket
sync_input_socket = zmq.Socket.shadow(self.input_socket)
while new_engine_identities:
if not sync_input_socket.poll(timeout=600_000):
raise TimeoutError(
"Timed out waiting for new engines to send initial "
"message on input socket.")
identity, _ = sync_input_socket.recv_multipart()
new_engine_identities.discard(identity)
# Phase 3: Wait for all existing engines to complete reconfiguration
logger.info("Waiting for existing engines to complete reconfiguration")
await asyncio.gather(*reconfig_futures)
# Notify coordinator about scale up through existing
# stats_update_task connection
self._ensure_stats_update_task()
scale_up_marker = msgspec.msgpack.encode(
("SCALE_ELASTIC_EP", new_data_parallel_size))
await self.first_req_send_socket.send(scale_up_marker)
# Update the parallel config
self.vllm_config.parallel_config.data_parallel_size = \
new_data_parallel_size
logger.info(
"[Elastic EP] Scale up completed, new data parallel size: %s",
new_data_parallel_size)
async def _scale_down_elastic_ep(self, cur_data_parallel_size: int,
new_data_parallel_size: int) -> None:
"""Scale down the data parallel size by shutting down and
reconfiguring existing engine cores."""
cur_data_parallel_size = len(self.core_engines)
self.vllm_config.parallel_config.data_parallel_master_port = \
get_open_port()
reconfig_futures = []
for cur_dp_rank, engine in enumerate(self.core_engines):
reconfig_request = ReconfigureDistributedRequest(
new_data_parallel_size=new_data_parallel_size,
new_data_parallel_rank=ReconfigureRankType.KEEP_CURRENT_RANK,
new_data_parallel_rank_local=\
ReconfigureRankType.KEEP_CURRENT_RANK,
new_data_parallel_master_ip=self.vllm_config.parallel_config.
data_parallel_master_ip,
new_data_parallel_master_port=self.vllm_config.parallel_config.
data_parallel_master_port)
if cur_dp_rank >= new_data_parallel_size:
reconfig_request.new_data_parallel_rank = \
ReconfigureRankType.SHUTDOWN_CURRENT_RANK
future = await self._send_reconfig_message(reconfig_request,
engine)
reconfig_futures.append(future)
for _ in range(new_data_parallel_size, cur_data_parallel_size):
self.core_engines.pop()
await asyncio.gather(*reconfig_futures)
assert isinstance(self.resources.engine_manager,
CoreEngineActorManager)
self.resources.engine_manager.scale_down_elastic_ep(
cur_data_parallel_size, new_data_parallel_size)
self._ensure_stats_update_task()
scale_down_marker = msgspec.msgpack.encode(
("SCALE_ELASTIC_EP", new_data_parallel_size))
await self.first_req_send_socket.send(scale_down_marker)
self.vllm_config.parallel_config.data_parallel_size = \
new_data_parallel_size
logger.info(
"[Elastic EP] Scale down completed, new data parallel size: %s",
new_data_parallel_size)
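A hedged illustration of the notification wire format used above: the first-request and scale markers are both msgpack-encoded 2-tuples distinguished by their first element (msgpack decodes them back as lists):

import msgspec

first_req = msgspec.msgpack.encode(("FIRST_REQ", 1))
scale_msg = msgspec.msgpack.encode(("SCALE_ELASTIC_EP", 4))
assert msgspec.msgpack.decode(first_req) == ["FIRST_REQ", 1]
assert msgspec.msgpack.decode(scale_msg) == ["SCALE_ELASTIC_EP", 4]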

View File

@ -174,16 +174,21 @@ class CoreEngineActorManager:
self.local_engine_actors: list[ray.ActorHandle] = []
self.remote_engine_actors: list[ray.ActorHandle] = []
env_vars_list = get_env_vars_to_copy(destination="DPEngineCoreActor")
self.env_vars_dict = {
name: os.environ[name]
for name in env_vars_list if name in os.environ
}
runtime_env = RuntimeEnv(env_vars=self.env_vars_dict)
self.addresses = addresses
self.executor_class = executor_class
self.log_stats = log_stats
dp_size = vllm_config.parallel_config.data_parallel_size
local_engine_count = \
vllm_config.parallel_config.data_parallel_size_local
world_size = vllm_config.parallel_config.world_size
env_vars_set = get_env_vars_to_copy(destination="DPEngineCoreActor")
env_vars_dict = {
name: os.environ[name]
for name in env_vars_set if name in os.environ
}
runtime_env = RuntimeEnv(env_vars=env_vars_dict)
if ray.is_initialized():
logger.info(
@ -208,6 +213,7 @@ class CoreEngineActorManager:
assert len(placement_groups) == dp_size, (
"Number of placement groups must match data parallel size")
self.placement_group_is_local = []
refs = []
for index in range(dp_size):
local_index = local_dp_ranks[index]
@ -231,6 +237,7 @@ class CoreEngineActorManager:
self.local_engine_actors.append(actor)
else:
self.remote_engine_actors.append(actor)
self.placement_group_is_local.append(local_client)
refs.append(actor.wait_for_init.remote())
ray.get(refs)
@ -242,6 +249,9 @@ class CoreEngineActorManager:
def create_dp_placement_groups(
vllm_config: VllmConfig
) -> tuple[list["PlacementGroup"], list[int]]:
"""
Create placement groups for data parallel.
"""
import ray
from ray._private.state import available_resources_per_node
@ -250,10 +260,11 @@ class CoreEngineActorManager:
logger.info("Creating placement groups for data parallel")
dp_master_ip = \
vllm_config.parallel_config.data_parallel_master_ip
dp_size = vllm_config.parallel_config.data_parallel_size
num_pg_to_create = vllm_config.parallel_config.data_parallel_size
local_engine_count = \
vllm_config.parallel_config.data_parallel_size_local
nodes = list_nodes()
nodes = sorted(list_nodes(),
key=lambda node: node.node_ip != dp_master_ip)
assert nodes[0].node_ip == dp_master_ip, (
@ -293,7 +304,7 @@ class CoreEngineActorManager:
local_dp_ranks.append(i)
else:
for i in range(available_engine_count):
if len(placement_groups) == dp_size:
if len(placement_groups) == num_pg_to_create:
break
bundles = [{"GPU": 1.0}] * world_size + [{"CPU": 1.0}]
pg = ray.util.placement_group(
@ -305,6 +316,204 @@ class CoreEngineActorManager:
local_dp_ranks.append(i)
return placement_groups, local_dp_ranks
@staticmethod
def add_dp_placement_groups(
old_vllm_config: VllmConfig, new_data_parallel_size: int
) -> tuple[list["PlacementGroup"], list[int]]:
"""
Add placement groups for new data parallel size.
"""
import ray
from ray._private.state import (available_resources_per_node,
total_resources_per_node)
from ray.util.state import list_nodes
old_dp_size = old_vllm_config.parallel_config.data_parallel_size
num_pg_to_create = new_data_parallel_size - old_dp_size
if num_pg_to_create <= 0:
return [], []
dp_master_ip = old_vllm_config.parallel_config.data_parallel_master_ip
world_size = old_vllm_config.parallel_config.world_size
nodes = list_nodes()
nodes = sorted(nodes, key=lambda node: node.node_ip != dp_master_ip)
assert nodes[0].node_ip == dp_master_ip, (
"The first node must be the head node")
assert len(nodes) == 1 or nodes[1].node_ip != dp_master_ip, (
"There can only be one head node")
available_resources = available_resources_per_node()
total_resources = total_resources_per_node()
placement_groups = []
local_dp_ranks = []
num_pg_created = 0
for node in nodes:
if num_pg_created >= num_pg_to_create:
break
node_ip = node.node_ip
node_id = node.node_id
available_gpus = int(available_resources[node_id]["GPU"])
# Get total GPUs on this node from the node's resources
# Ray stores node resources with node ID as key
total_gpus = int(total_resources[node_id]["GPU"])
# Calculate used GPUs and used engines on this node
used_gpus = max(0, total_gpus - available_gpus)
used_engines_on_node = used_gpus // world_size
# Calculate how many new engines this node can accommodate
available_engine_count = available_gpus // world_size
# Create placement groups for new engines on this node
for i in range(available_engine_count):
if num_pg_created >= num_pg_to_create:
break
rank = old_dp_size + num_pg_created
# Create bundles with node constraint for master node
if node_ip == dp_master_ip:
bundles = [{
"GPU": 1.0,
"node:" + dp_master_ip: 0.001
}] * world_size + [{
"CPU": 1.0
}]
else:
bundles = [{"GPU": 1.0}] * world_size + [{"CPU": 1.0}]
pg = ray.util.placement_group(
name=f"dp_rank_{rank}",
strategy="STRICT_PACK",
bundles=bundles,
)
placement_groups.append(pg)
# Local rank starts from the number of engines already used
# on this node
local_rank = used_engines_on_node + i
local_dp_ranks.append(local_rank)
num_pg_created += 1
return placement_groups, local_dp_ranks
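# A small worked example (illustrative values, assumed) of the per-node
# arithmetic above: the node's free GPUs bound how many new engines fit,
# and each new engine's local rank continues after the engines the node
# is already running.
world_size = 2            # GPUs per engine (assumed)
total_gpus = 8            # total GPUs Ray reports for the node (assumed)
available_gpus = 4        # free GPUs on the node (assumed)
used_engines_on_node = max(0, total_gpus - available_gpus) // world_size   # 2
available_engine_count = available_gpus // world_size                      # 2
new_local_ranks = [used_engines_on_node + i
                   for i in range(available_engine_count)]                 # [2, 3]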
def scale_up_elastic_ep(self, cur_vllm_config: VllmConfig,
new_data_parallel_size: int) -> None:
import copy
import ray
from ray.runtime_env import RuntimeEnv
from ray.util.scheduling_strategies import (
PlacementGroupSchedulingStrategy)
from vllm.v1.engine.core import DPEngineCoreActor
cur_data_parallel_size = len(self.local_engine_actors) + \
len(self.remote_engine_actors)
assert new_data_parallel_size > cur_data_parallel_size, (
f"New data parallel size {new_data_parallel_size} must be greater "
f"than current data parallel size {cur_data_parallel_size} "
"for scale up")
placement_groups, local_dp_ranks = \
self.add_dp_placement_groups(
cur_vllm_config, new_data_parallel_size)
world_size = cur_vllm_config.parallel_config.world_size
dp_master_ip = cur_vllm_config.parallel_config.data_parallel_master_ip
new_local_engines = 0
runtime_env = RuntimeEnv(env_vars=self.env_vars_dict
| {"VLLM_ELASTIC_EP_SCALE_UP_LAUNCH": "1"})
for i, (pg,
local_rank) in enumerate(zip(placement_groups,
local_dp_ranks)):
rank = cur_data_parallel_size + i
dp_vllm_config = copy.deepcopy(cur_vllm_config)
dp_vllm_config.parallel_config.data_parallel_size = \
new_data_parallel_size
dp_vllm_config.parallel_config.placement_group = pg
# Check if this placement group is on the head node
local_client = any(
bundle.get("node:" + dp_master_ip, 0) > 0
for bundle in pg.bundle_specs)
if local_client:
new_local_engines += 1
# Update data_parallel_size_local
dp_vllm_config.parallel_config.data_parallel_size_local = (
cur_vllm_config.parallel_config.data_parallel_size_local +
new_local_engines)
actor = ray.remote(DPEngineCoreActor).options(
scheduling_strategy=PlacementGroupSchedulingStrategy(
placement_group=pg,
placement_group_bundle_index=world_size,
),
runtime_env=runtime_env).remote(
vllm_config=dp_vllm_config,
executor_class=self.executor_class,
log_stats=self.log_stats,
local_client=local_client,
addresses=self.addresses,
dp_rank=rank,
local_dp_rank=local_rank)
if local_client:
self.local_engine_actors.append(actor)
else:
self.remote_engine_actors.append(actor)
self.created_placement_groups.append(pg)
self.placement_group_is_local.append(local_client)
ray.get([
actor.wait_for_init.remote()
for actor in (self.local_engine_actors[-new_local_engines:]
if new_local_engines > 0 else []) +
self.remote_engine_actors[-(len(placement_groups) -
new_local_engines):]
])
actors = (self.local_engine_actors[-new_local_engines:]
if new_local_engines > 0 else []) + \
self.remote_engine_actors[-(len(placement_groups) -
new_local_engines):]
for actor in actors:
self.run_refs.append(actor.run.remote())
cur_vllm_config.parallel_config.data_parallel_size = \
new_data_parallel_size
# Update cur_vllm_config with the new data_parallel_size_local if any
# new local engines were added
if new_local_engines > 0:
cur_vllm_config.parallel_config.data_parallel_size_local += \
new_local_engines
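# A hedged sketch (dictionary contents assumed) of the runtime-env merge
# used above: newly launched actors inherit the manager's captured
# environment variables plus the flag that routes their workers onto the
# elastic scale-up model-loading path.
captured_env = {"VLLM_USE_V1": "1"}   # assumed example of self.env_vars_dict
launch_env = captured_env | {"VLLM_ELASTIC_EP_SCALE_UP_LAUNCH": "1"}
# -> {'VLLM_USE_V1': '1', 'VLLM_ELASTIC_EP_SCALE_UP_LAUNCH': '1'}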
def scale_down_elastic_ep(self, cur_data_parallel_size: int,
new_data_parallel_size: int) -> None:
import ray
assert cur_data_parallel_size > new_data_parallel_size, (
f"cur_data_parallel_size {cur_data_parallel_size} must be greater "
f"than new_data_parallel_size {new_data_parallel_size} "
"for scale down")
for _ in range(cur_data_parallel_size - new_data_parallel_size):
pg = self.created_placement_groups.pop()
is_local = self.placement_group_is_local.pop()
if is_local:
self.local_engine_actors.pop()
else:
self.remote_engine_actors.pop()
ray.util.remove_placement_group(pg)
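# A hedged note in code form (sizes assumed): scale-down pops the most
# recently created engines, so shrinking from 4 to 2 removes the placement
# groups for DP ranks 3 and 2, in that order.
cur_data_parallel_size, new_data_parallel_size = 4, 2
ranks_removed = list(range(cur_data_parallel_size - 1,
                           new_data_parallel_size - 1, -1))   # [3, 2]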
def get_run_refs(self):
return self.run_refs

View File

@@ -6,6 +6,7 @@ from typing import Union
from vllm.executor.ray_distributed_executor import ( # noqa
RayDistributedExecutor as RayDistributedExecutorV0)
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.executor.abstract import Executor
from vllm.v1.outputs import ModelRunnerOutput
@@ -62,3 +63,11 @@ class RayDistributedExecutor(RayDistributedExecutorV0, Executor):
# When PP is used, we return a FutureWrapper immediately so that
# the scheduler can yield to the next batch.
return FutureWrapper(refs[0])
def reinitialize_distributed(
self, reconfig_request: ReconfigureDistributedRequest) -> None:
self._run_workers("reinitialize_distributed", reconfig_request)
if reconfig_request.new_data_parallel_rank == \
ReconfigureRankType.SHUTDOWN_CURRENT_RANK:
self.shutdown()
return
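# A hedged usage sketch, assuming ReconfigureDistributedRequest exposes
# exactly the fields read by the worker changes further below. A rank that
# survives a resize keeps its rank; a rank removed during scale-down
# receives SHUTDOWN_CURRENT_RANK, which makes this executor shut down after
# its workers have reconfigured.
from vllm.v1.engine import (ReconfigureDistributedRequest,
                            ReconfigureRankType)

request = ReconfigureDistributedRequest(
    new_data_parallel_size=2,
    new_data_parallel_rank=ReconfigureRankType.KEEP_CURRENT_RANK,
    new_data_parallel_rank_local=ReconfigureRankType.KEEP_CURRENT_RANK,
    new_data_parallel_master_ip="10.0.0.1",     # assumed address
    new_data_parallel_master_port=13345)        # assumed port
# executor.reinitialize_distributed(request)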

View File

@@ -49,7 +49,7 @@ class CPUModelRunner(GPUModelRunner):
if k.endswith("_cpu") and isinstance(v, torch.Tensor):
replace_tensor(self.input_batch.block_table, k, k[:-4])
def load_model(self) -> None:
def load_model(self, eep_scale_up: bool = False) -> None:
logger.info("Starting to load model %s...", self.model_config.model)
self.model = get_model(vllm_config=self.vllm_config)

View File

@@ -1745,8 +1745,40 @@ class GPUModelRunner(LoRAModelRunnerMixin):
new_config = update_config(config, config_overrides)
setattr(self, config_name, new_config)
def load_model(self) -> None:
def load_model(self, eep_scale_up: bool = False) -> None:
"""
Args:
eep_scale_up: whether this model load is part of an elastic EP scale-up.
"""
logger.info("Starting to load model %s...", self.model_config.model)
if eep_scale_up:
from vllm.distributed.parallel_state import get_ep_group
num_local_physical_experts = torch.empty(1,
dtype=torch.int32,
device="cpu")
torch.distributed.broadcast(num_local_physical_experts,
group=get_ep_group().cpu_group,
group_src=0)
num_local_physical_experts = int(num_local_physical_experts.item())
new_ep_size = get_ep_group().world_size
global_expert_load, old_global_expert_indices = (
EplbState.recv_state())
num_logical_experts = global_expert_load.shape[1]
self.parallel_config.num_redundant_experts = (
num_local_physical_experts * new_ep_size - num_logical_experts)
assert old_global_expert_indices.shape[
1] % num_local_physical_experts == 0
old_ep_size = old_global_expert_indices.shape[
1] // num_local_physical_experts
rank_mapping = {
old_ep_rank: old_ep_rank
for old_ep_rank in range(old_ep_size)
}
else:
global_expert_load = None
old_global_expert_indices = None
rank_mapping = None
with DeviceMemoryProfiler() as m: # noqa: SIM117
time_before_load = time.perf_counter()
model_loader = get_model_loader(self.load_config)
@@ -1788,6 +1820,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.model,
self.device,
self.parallel_config,
global_expert_load,
old_global_expert_indices,
rank_mapping,
)
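# Illustrative check (all values assumed) of the redundant-expert count
# derived in load_model during an elastic-EP scale-up: physical expert
# slots across the new EP world minus the model's logical experts.
num_local_physical_experts = 16   # broadcast from EP rank 0 (assumed)
new_ep_size = 8                   # EP world size after scale-up (assumed)
num_logical_experts = 64          # columns of global_expert_load (assumed)
num_redundant_experts = (num_local_physical_experts * new_ep_size
                         - num_logical_experts)                      # 64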
def save_tensorized_model(

View File

@@ -26,6 +26,7 @@ from vllm.platforms import current_platform
from vllm.pooling_params import PoolingTask
from vllm.sequence import IntermediateTensors
from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
from vllm.v1.utils import report_usage_stats
@@ -191,8 +192,9 @@ class Worker(WorkerBase):
else:
from contextlib import nullcontext
context = nullcontext()
eep_scale_up = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1"
with context:
self.model_runner.load_model()
self.model_runner.load_model(eep_scale_up=eep_scale_up)
def update_config(self, overrides: dict[str, Any]) -> None:
self.model_runner.update_config(overrides)
@@ -384,6 +386,161 @@ class Worker(WorkerBase):
# worker will always be healthy as long as it's running.
return
def _eplb_before_scale_down(self, old_ep_size: int,
new_ep_size: int) -> None:
from vllm.distributed.parallel_state import get_ep_group
if get_ep_group().rank == 0:
logger.info("[Elastic EP] Starting expert resharding "
"before scaling down...")
rank_mapping = {
old_ep_rank: old_ep_rank if old_ep_rank < new_ep_size else -1
for old_ep_rank in range(old_ep_size)
}
assert self.model_runner.eplb_state is not None
self.model_runner.eplb_state.rearrange(self.model_runner.model,
execute_shuffle=True,
global_expert_load=None,
rank_mapping=rank_mapping)
torch.cuda.synchronize()
if get_ep_group().rank == 0:
logger.info("[Elastic EP] Expert resharding completed!")
def _eplb_after_scale_up(
self, old_ep_size: int, new_ep_size: int,
global_expert_load: Optional[torch.Tensor]) -> None:
from vllm.distributed.parallel_state import get_ep_group
if get_ep_group().rank == 0:
logger.info("[Elastic EP] Starting expert resharding "
"after scaling up...")
rank_mapping = {
old_ep_rank: old_ep_rank
for old_ep_rank in range(old_ep_size)
}
assert self.model_runner.eplb_state is not None
self.model_runner.eplb_state.rearrange(
self.model_runner.model,
execute_shuffle=True,
global_expert_load=global_expert_load,
rank_mapping=rank_mapping)
if get_ep_group().rank == 0:
logger.info("[Elastic EP] Expert resharding completed!")
def _reconfigure_parallel_config(
self, reconfig_request: ReconfigureDistributedRequest) -> None:
"""
Update parallel config with provided reconfig_request
"""
parallel_config = self.vllm_config.parallel_config
parallel_config.data_parallel_size = \
reconfig_request.new_data_parallel_size
if reconfig_request.new_data_parallel_rank != \
ReconfigureRankType.KEEP_CURRENT_RANK:
parallel_config.data_parallel_rank = \
reconfig_request.new_data_parallel_rank
if reconfig_request.new_data_parallel_rank_local != \
ReconfigureRankType.KEEP_CURRENT_RANK:
parallel_config.data_parallel_rank_local = \
reconfig_request.new_data_parallel_rank_local
parallel_config.data_parallel_master_ip = \
reconfig_request.new_data_parallel_master_ip
parallel_config.data_parallel_master_port = \
reconfig_request.new_data_parallel_master_port
def _reconfigure_moe(self, old_ep_size: int,
new_ep_size: int) -> Optional[torch.Tensor]:
"""
Reconfigure MoE modules for the new expert parallel world size.
Return the global expert load if new_ep_size > old_ep_size,
otherwise None
"""
from vllm.distributed.parallel_state import (
get_dp_group, get_ep_group, prepare_communication_buffer_for_model)
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoEParallelConfig)
parallel_config = self.vllm_config.parallel_config
moe_modules = [
module for module in self.model_runner.model.modules()
if module.__class__.__name__ == "FusedMoE"
]
num_local_experts = moe_modules[0].moe_config.num_local_experts
assert all(module.moe_config.num_local_experts == num_local_experts
for module in moe_modules), (
"All MoE modules must have the same number of experts")
for module in moe_modules:
module.moe_config.num_experts = num_local_experts * new_ep_size
module.global_num_experts = module.moe_config.num_experts
module.moe_parallel_config = FusedMoEParallelConfig.make(
tp_size_=get_tp_group().world_size,
dp_size_=get_dp_group().world_size,
vllm_parallel_config=parallel_config,
)
module.moe_config.moe_parallel_config = module.moe_parallel_config
if new_ep_size < old_ep_size:
num_local_physical_experts = num_local_experts
assert self.model_runner.eplb_state is not None
new_physical_experts = \
self.model_runner.eplb_state.physical_to_logical_map.shape[1]
parallel_config.num_redundant_experts = (
new_physical_experts -
self.model_runner.eplb_state.logical_replica_count.shape[1])
global_expert_load = None
else:
num_local_physical_experts = torch.tensor([num_local_experts],
dtype=torch.int32,
device="cpu")
torch.distributed.broadcast(num_local_physical_experts,
group=get_ep_group().cpu_group,
group_src=0)
num_local_physical_experts = num_local_physical_experts.item()
new_physical_experts = num_local_physical_experts * new_ep_size
assert self.model_runner.eplb_state is not None
global_expert_load = self.model_runner.eplb_state.rearrange(
self.model_runner.model, execute_shuffle=False)
parallel_config.num_redundant_experts = (
new_physical_experts - global_expert_load.shape[1])
prepare_communication_buffer_for_model(self.model_runner.model)
self.model_runner.model.update_physical_experts_metadata(
num_physical_experts=new_physical_experts,
num_local_physical_experts=num_local_physical_experts)
return global_expert_load
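# Illustrative arithmetic (values assumed) for the per-layer update above:
# every FusedMoE layer's global expert count is rescaled from its per-rank
# expert count and the new EP world size.
num_local_experts = 8
old_ep_size, new_ep_size = 2, 4
old_global_num_experts = num_local_experts * old_ep_size    # 16
new_global_num_experts = num_local_experts * new_ep_size    # 32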
def reinitialize_distributed(
self, reconfig_request: ReconfigureDistributedRequest) -> None:
from vllm.config import set_current_vllm_config
from vllm.distributed.parallel_state import (
cleanup_dist_env_and_memory, get_ep_group)
old_ep_size = get_ep_group().world_size
old_ep_rank = get_ep_group().rank
new_ep_size = reconfig_request.new_data_parallel_size * get_tp_group(
).world_size * get_pp_group().world_size
if new_ep_size < old_ep_size:
self._eplb_before_scale_down(old_ep_size, new_ep_size)
cleanup_dist_env_and_memory()
if reconfig_request.new_data_parallel_rank == \
ReconfigureRankType.SHUTDOWN_CURRENT_RANK:
assert old_ep_rank >= new_ep_size
# This rank is being removed; skip re-initialization (the executor
# shuts it down)
return
self._reconfigure_parallel_config(reconfig_request)
with set_current_vllm_config(self.vllm_config):
init_worker_distributed_environment(self.vllm_config, self.rank,
self.distributed_init_method,
self.local_rank)
global_expert_load = self._reconfigure_moe(old_ep_size, new_ep_size)
if new_ep_size > old_ep_size:
assert global_expert_load is not None
self._eplb_after_scale_up(old_ep_size, new_ep_size,
global_expert_load)
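# A small worked example (parallel sizes assumed) of the EP world-size
# computation above: expert parallelism spans the data, tensor, and
# pipeline parallel dimensions.
new_data_parallel_size, tp_size, pp_size = 4, 2, 1
new_ep_size = new_data_parallel_size * tp_size * pp_size    # 8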
def save_sharded_state(
self,
path: str,