mirror of https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-16 15:47:22 +08:00

Elastic Expert Parallel Initial Support (#20775)
Signed-off-by: Rui Qiao <ruisearch42@gmail.com>

parent 5782581acf
commit 217937221b
57	examples/online_serving/elastic_ep/bench.sh	Normal file
@@ -0,0 +1,57 @@
#!/bin/bash

MODEL_NAME="deepseek-ai/DeepSeek-V2-Lite"
LOCAL_MODEL_PATH="/models/models--deepseek-ai--DeepSeek-V2-Lite/snapshots/604d5664dddd88a0433dbae533b7fe9472482de0"
HOST="localhost"
PORT=8006
NUM_PROMPTS=20
REQUEST_RATE=5

# Parse command line arguments
while [[ $# -gt 0 ]]; do
    case $1 in
        --model)
            MODEL_NAME="$2"
            shift 2
            ;;
        --local-model)
            MODEL_NAME=$LOCAL_MODEL_PATH
            shift
            ;;
        --host)
            HOST="$2"
            shift 2
            ;;
        --port)
            PORT="$2"
            shift 2
            ;;
        --num-prompts)
            NUM_PROMPTS="$2"
            shift 2
            ;;
        --request-rate)
            REQUEST_RATE="$2"
            shift 2
            ;;
        -h|--help)
            echo "Usage: $0 [OPTIONS]"
            echo "Options:"
            echo "  --model MODEL_NAME  Set model name or path (default: deepseek-ai/DeepSeek-V2-Lite)"
            echo "  --local-model       Use local model path (convenience option)"
            exit 0
            ;;
        *)
            echo "Unknown option: $1"
            echo "Use -h or --help for usage information"
            exit 1
            ;;
    esac
done

vllm bench serve \
    --model $MODEL_NAME \
    --host $HOST \
    --port $PORT \
    --num-prompts $NUM_PROMPTS \
    --request-rate $REQUEST_RATE
53	examples/online_serving/elastic_ep/scale.py	Normal file
@@ -0,0 +1,53 @@
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import argparse
import json
import sys

import requests


def scale(host, port, new_dp_size):
    url = f"http://{host}:{port}/scale_elastic_ep"
    payload = {"new_data_parallel_size": new_dp_size}
    headers = {"Content-Type": "application/json"}

    print(f"Sending scale request to {url}")
    print(f"Payload: {json.dumps(payload, indent=2)}")

    try:
        response = requests.post(url, json=payload, headers=headers, timeout=300)

        print(f"Status Code: {response.status_code}")
        print(f"Response: {response.text}")

        if response.status_code == 200:
            print("Scale up/down request successful!")
            return True
        else:
            print("Scale up/down request failed!")
            return False

    except requests.exceptions.RequestException as e:
        print(f"Request failed: {e}")
        return False


def main():
    parser = argparse.ArgumentParser(description="Test scale up/down functionality")
    parser.add_argument("--host", default="localhost", help="API server host")
    parser.add_argument("--port", type=int, default=8006, help="API server port")
    parser.add_argument(
        "--new-dp-size", type=int, default=2, help="New data parallel size"
    )

    args = parser.parse_args()

    success = scale(args.host, args.port, args.new_dp_size)
    sys.exit(0 if success else 1)


if __name__ == "__main__":
    main()
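Note: as a companion to scale.py, a client can also check whether a scale operation is still in flight via the /is_scaling_elastic_ep endpoint that this commit adds to the API server further down. A minimal polling sketch (illustrative, not part of the commit; host/port defaults are assumptions):

import time

import requests


def wait_until_scaled(host="localhost", port=8006, poll_interval=2.0):
    # /is_scaling_elastic_ep returns {"is_scaling_elastic_ep": <bool>}
    # reflecting the server's global scaling flag.
    url = f"http://{host}:{port}/is_scaling_elastic_ep"
    while requests.post(url).json()["is_scaling_elastic_ep"]:
        time.sleep(poll_interval)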
72	examples/online_serving/elastic_ep/serve_deepseek_v2.sh	Normal file
@@ -0,0 +1,72 @@
#!/bin/bash

HOST="0.0.0.0"
PORT=8006
DATA_PARALLEL_SIZE=4
REDUNDANT_EXPERTS=0
LOCAL_MODEL_PATH="/models/models--deepseek-ai--DeepSeek-V2-Lite/snapshots/604d5664dddd88a0433dbae533b7fe9472482de0"
MODEL_NAME="deepseek-ai/DeepSeek-V2-Lite"

while [[ $# -gt 0 ]]; do
    case $1 in
        --dp)
            DATA_PARALLEL_SIZE="$2"
            shift 2
            ;;
        --re)
            REDUNDANT_EXPERTS="$2"
            shift 2
            ;;
        --host)
            HOST="$2"
            shift 2
            ;;
        --port)
            PORT="$2"
            shift 2
            ;;
        --model)
            MODEL_NAME="$2"
            shift 2
            ;;
        --local-model)
            MODEL_NAME=$LOCAL_MODEL_PATH
            shift
            ;;
        -h|--help)
            echo "Usage: $0 [OPTIONS]"
            echo "Options:"
            echo "  --dp SIZE           Set data parallel size (default: 4)"
            echo "  --re SIZE           Set redundant experts (default: 0)"
            echo "  --host HOST         Set host address (default: 0.0.0.0)"
            echo "  --port PORT         Set port number (default: 8006)"
            echo "  --model MODEL_NAME  Set model name or path"
            echo "  -h, --help          Show this help message"
            exit 0
            ;;
        *)
            echo "Unknown option: $1"
            echo "Use -h or --help for usage information"
            exit 1
            ;;
    esac
done

echo "Starting vLLM server for $MODEL_NAME with data parallel size: $DATA_PARALLEL_SIZE and redundant experts: $REDUNDANT_EXPERTS"

export RAY_DEDUP_LOGS=0
export VLLM_USE_V1=1
export VLLM_ALL2ALL_BACKEND="pplx"
export VLLM_USE_DEEP_GEMM=1

vllm serve $MODEL_NAME \
    --data-parallel-size $DATA_PARALLEL_SIZE \
    --data-parallel-size-local $DATA_PARALLEL_SIZE \
    --data-parallel-backend ray \
    --enforce-eager \
    --enable-expert-parallel \
    --enable-eplb \
    --num-redundant-experts $REDUNDANT_EXPERTS \
    --trust-remote-code \
    --host $HOST \
    --port $PORT
92	tools/ep_kernels/elastic_ep/eep_nvshmem.patch	Normal file
@@ -0,0 +1,92 @@
From 18c0599c2f07ec965132efa25961dc8179c2dda3 Mon Sep 17 00:00:00 2001
From: Yongji Wu <wuyongji317@gmail.com>
Date: Tue, 20 May 2025 13:41:12 -0700
Subject: [PATCH] fix reinit issues due to states not cleaned up

fix double free
---
 src/host/init/init.cu                              | 10 ++++++++++
 .../internal/host/nvshmemi_mem_transport.hpp       | 15 +++++++++++++++
 src/modules/bootstrap/uid/bootstrap_uid.cpp        |  5 +++++
 3 files changed, 30 insertions(+)

diff --git a/src/host/init/init.cu b/src/host/init/init.cu
index b1c5dbf..1fecb4b 100644
--- a/src/host/init/init.cu
+++ b/src/host/init/init.cu
@@ -43,6 +43,8 @@
 #include "internal/host/nvshmemi_types.h"
 #include "internal/host/shared_memory.h"
 #include "internal/host/nvshmemi_symmetric_heap.hpp"
+// eep-dev
+#include "internal/host/nvshmemi_mem_transport.hpp"
 
 extern __constant__ nvshmemi_device_host_state_t nvshmemi_device_state_d;
 static std::map<void *, int> registered_device_states;
@@ -1293,6 +1295,14 @@ void nvshmemid_hostlib_finalize(void *device_ctx, void *transport_device_ctx) {
         /* Multi-init Multi-fini*/
         nvshmemi_state = NULL;
         nvshmemi_device_state.nvshmemi_is_nvshmem_initialized = 0;
+
+        // eep-dev
+        nvshmemi_mem_p2p_transport::destroy_instance();
+        nvshmemi_mem_remote_transport::destroy_instance();
+        free(nvshmemi_default_session);
+        nvshmemi_default_session = nullptr;
+        nvshmemi_device_state.nvshmemi_is_nvshmem_bootstrapped = false;
+
         nvshmemi_is_device_state_ready = false;
     } else
         nvshmemi_boot_handle.barrier(&nvshmemi_boot_handle);
diff --git a/src/include/internal/host/nvshmemi_mem_transport.hpp b/src/include/internal/host/nvshmemi_mem_transport.hpp
index 2495844..e4f408a 100644
--- a/src/include/internal/host/nvshmemi_mem_transport.hpp
+++ b/src/include/internal/host/nvshmemi_mem_transport.hpp
@@ -36,6 +36,13 @@ class nvshmemi_mem_p2p_transport final {
             return p2p_objref_;
         }
     }
+    // eep-dev
+    static void destroy_instance(void) {
+        if (p2p_objref_ != nullptr) {
+            delete p2p_objref_;
+            p2p_objref_ = nullptr;
+        }
+    }
 
     void print_mem_handle(int pe_id, int transport_idx, nvshmemi_symmetric_heap &obj);
 
@@ -87,6 +94,14 @@ class nvshmemi_mem_remote_transport final {
         }
     }
 
+    // eep-dev
+    static void destroy_instance(void) {
+        if (remote_objref_ != nullptr) {
+            delete remote_objref_;
+            remote_objref_ = nullptr;
+        }
+    }
+
     int gather_mem_handles(nvshmemi_symmetric_heap &obj, uint64_t heap_offset, size_t size);
     /* On-demand registration and release of memory */
     int register_mem_handle(nvshmem_mem_handle_t *local_handles, int transport_idx,
diff --git a/src/modules/bootstrap/uid/bootstrap_uid.cpp b/src/modules/bootstrap/uid/bootstrap_uid.cpp
index a1fa748..788fa96 100644
--- a/src/modules/bootstrap/uid/bootstrap_uid.cpp
+++ b/src/modules/bootstrap/uid/bootstrap_uid.cpp
@@ -630,6 +630,11 @@ int nvshmemi_bootstrap_plugin_pre_init(bootstrap_handle_t* handle, const int abi
     // Discover the network for bootstrap, if not done previously.
     // This code needs to be stateful to be able to be called multiple times by the caller
     BOOTSTRAP_CHECK(bootstrap_net_init());
+    // eep-dev
+    if (handle->pre_init_ops != nullptr) {
+        BOOTSTRAP_PTR_FREE(handle->pre_init_ops);
+        handle->pre_init_ops = nullptr;
+    }
     if (handle->pre_init_ops == nullptr) {
         BOOTSTRAP_CALLOC(&handle->pre_init_ops, 1);
         handle->pre_init_ops->get_unique_id = bootstrap_get_unique_id;
-- 
2.43.0
86	tools/ep_kernels/elastic_ep/install_eep_libraries.sh	Normal file
@@ -0,0 +1,86 @@
#!/bin/bash

set -ex

# Default workspace directory
WORKSPACE=$(pwd)/eep_kernels_workspace
INSTALL_NVSHMEM=true

# Parse command line arguments
while getopts "w:n" opt; do
    case $opt in
        w)
            WORKSPACE="$OPTARG"
            ;;
        n)
            INSTALL_NVSHMEM=false
            ;;
        \?)
            echo "Invalid option: -$OPTARG" >&2
            exit 1
            ;;
    esac
done

if [ ! -d "$WORKSPACE" ]; then
    mkdir -p $WORKSPACE
fi


# install dependencies if not installed
pip3 install cmake torch ninja

# build nvshmem
pushd $WORKSPACE
# Reset NVSHMEM build if requested
if [ "$INSTALL_NVSHMEM" = true ]; then
    mkdir -p nvshmem_src
    wget https://developer.download.nvidia.com/compute/redist/nvshmem/3.2.5/source/nvshmem_src_3.2.5-1.txz
    tar -xvf nvshmem_src_3.2.5-1.txz -C nvshmem_src --strip-components=1
    pushd nvshmem_src
    wget https://github.com/deepseek-ai/DeepEP/raw/main/third-party/nvshmem.patch
    git init
    git apply -vvv nvshmem.patch
    git apply --reject --whitespace=fix ../../eep_nvshmem.patch
else
    pushd nvshmem_src
fi

# assume CUDA_HOME is set correctly
if [ -z "$CUDA_HOME" ]; then
    echo "CUDA_HOME is not set, please set it to your CUDA installation directory."
    exit 1
fi

# disable all features except IBGDA
export NVSHMEM_IBGDA_SUPPORT=1

export NVSHMEM_SHMEM_SUPPORT=0
export NVSHMEM_UCX_SUPPORT=0
export NVSHMEM_USE_NCCL=0
export NVSHMEM_PMIX_SUPPORT=0
export NVSHMEM_TIMEOUT_DEVICE_POLLING=0
export NVSHMEM_USE_GDRCOPY=0
export NVSHMEM_IBRC_SUPPORT=0
export NVSHMEM_BUILD_TESTS=0
export NVSHMEM_BUILD_EXAMPLES=0
export NVSHMEM_MPI_SUPPORT=0
export NVSHMEM_BUILD_HYDRA_LAUNCHER=0
export NVSHMEM_BUILD_TXZ_PACKAGE=0
export NVSHMEM_TIMEOUT_DEVICE_POLLING=0

cmake -G Ninja -S . -B $WORKSPACE/nvshmem_build/ -DCMAKE_INSTALL_PREFIX=$WORKSPACE/nvshmem_install
cmake --build $WORKSPACE/nvshmem_build/ --target install

popd

export CMAKE_PREFIX_PATH=$WORKSPACE/nvshmem_install:$CMAKE_PREFIX_PATH

# build and install pplx, require pytorch installed
pushd $WORKSPACE
git clone https://github.com/ppl-ai/pplx-kernels
cd pplx-kernels
# see https://github.com/pypa/pip/issues/9955#issuecomment-838065925
# PIP_NO_BUILD_ISOLATION=0 disables build isolation
PIP_NO_BUILD_ISOLATION=0 TORCH_CUDA_ARCH_LIST=9.0a+PTX pip install . --no-deps -v
@@ -2008,6 +2008,19 @@ class ParallelConfig:
         aggregated_has_unfinished = bool(tensor.item())
         return aggregated_has_unfinished

+    @staticmethod
+    def sync_kv_cache_memory_size(dp_group: "ProcessGroup",
+                                  kv_cache_memory: int) -> int:
+        if kv_cache_memory == -1:
+            kv_cache_memory = torch.iinfo(torch.int64).max
+        tensor = torch.tensor([kv_cache_memory],
+                              dtype=torch.int64,
+                              device="cpu")
+        # we cannot use broadcast for stateless dp group since it depends
+        # on global rank
+        torch.distributed.all_reduce(tensor, op=ReduceOp.MIN, group=dp_group)
+        return tensor.item()
+
     def compute_hash(self):
         """
         Provide a hash that uniquely identifies all the configs
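Note: the MIN all-reduce above doubles as a broadcast. A rank that does not yet know the KV-cache budget passes -1, which is promoted to the int64 maximum, so the reduction yields the smallest real measurement contributed by the existing engines. A single-process sketch of that sentinel logic (an illustration, not code from this commit):

import torch

def sync_kv_cache_memory_size_local(contributions: list[int]) -> int:
    # stand-in for the all_reduce(op=ReduceOp.MIN) over the DP group
    return min(torch.iinfo(torch.int64).max if v == -1 else v
               for v in contributions)

# a freshly launched engine (-1) adopts the 8 GiB budget of its peers
assert sync_kv_cache_memory_size_local([-1, 8 << 30, 8 << 30]) == 8 << 30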
@@ -29,12 +29,15 @@ physical experts.
 import time
 from collections.abc import Sequence
 from dataclasses import dataclass
 from typing import Optional, Union

 import torch
-from torch.distributed import all_gather, all_reduce
+from torch.distributed import ProcessGroup, all_gather, all_reduce

 from vllm.config import ParallelConfig
-from vllm.distributed.parallel_state import get_ep_group, get_node_count
+from vllm.distributed.parallel_state import (get_ep_group, get_node_count,
+                                             in_the_same_node_as)
+from vllm.distributed.utils import StatelessProcessGroup
 from vllm.logger import init_logger
 from vllm.model_executor.models.interfaces import MixtureOfExperts
@@ -172,6 +175,9 @@ class EplbState:
         model: MixtureOfExperts,
         device: torch.device,
         parallel_config: ParallelConfig,
+        global_expert_load: Optional[torch.Tensor] = None,
+        old_global_expert_indices: Optional[torch.Tensor] = None,
+        rank_mapping: Optional[dict[int, int]] = None,
     ) -> "EplbState":
         """
         Build the initial EPLB state.
@@ -185,8 +191,16 @@ class EplbState:
             physical_to_logical_map_list,
             device=device,
         )
+        # Assuming 8 GPUs per node, this supports up to
+        # (1023 + 1) / 8 = 128 nodes for now.
+        # TODO(rui): make this configurable
+        MAX_EXPERT_REDUNDANCY = 1023
+        assert model.num_redundant_experts <= MAX_EXPERT_REDUNDANCY, (
+            f"num_redundant_experts {model.num_redundant_experts} "
+            f"must be less than or equal to {MAX_EXPERT_REDUNDANCY}")
+        max_slots_per_logical_expert = MAX_EXPERT_REDUNDANCY + 1
         logical_to_physical_map = torch.full(
-            (model.num_logical_experts, model.num_redundant_experts + 1),
+            (model.num_logical_experts, max_slots_per_logical_expert),
             -1,
             device=device,
         )
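Note: the 1023 cap above is a fixed bound on replica slots per logical expert rather than a per-model computation; the arithmetic in the comment checks out directly (illustrative only):

# each logical expert gets at most 1023 redundant copies plus the original,
# i.e. 1024 physical slots; at the assumed 8 GPUs per node, one replica per
# GPU covers 1024 / 8 = 128 nodes
MAX_EXPERT_REDUNDANCY = 1023
GPUS_PER_NODE = 8
assert (MAX_EXPERT_REDUNDANCY + 1) // GPUS_PER_NODE == 128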
@@ -235,11 +249,63 @@ class EplbState:
         expert_rearrangement_step = max(
             0, eplb_step_interval - eplb_step_interval // 4)

+        if global_expert_load is not None:
+            ep_group = get_ep_group().device_group
+            assert global_expert_load.shape == (model.num_moe_layers,
+                                                model.num_logical_experts)
+            assert global_expert_load.dtype == torch.int64
+
+            num_replicas = model.num_physical_experts
+            num_groups = model.num_expert_groups
+            num_nodes = get_node_count()
+            num_gpus = ep_group.size()
+
+            if num_gpus % num_nodes != 0:
+                num_nodes = 1
+                logger.warning_once(
+                    f"num_gpus % num_nodes != 0, "
+                    "not using hierarchical rearrangement algorithm.\n"
+                    f"{num_gpus=}, {num_nodes=}")
+
+            # Get new expert mappings
+            (
+                new_physical_to_logical_map,
+                new_logical_to_physical_map,
+                new_logical_replica_count,
+            ) = (rebalance_experts(
+                global_expert_load,
+                num_replicas,
+                num_groups,
+                num_nodes,
+                num_gpus,
+            ))
+
+            max_physical_slots = new_logical_to_physical_map.shape[-1]
+            assert max_physical_slots <= logical_to_physical_map.shape[-1]
+            new_logical_to_physical_map = torch.nn.functional.pad(
+                new_logical_to_physical_map,
+                (0, logical_to_physical_map.shape[-1] - max_physical_slots),
+                value=-1,
+            )
+            physical_to_logical_map = new_physical_to_logical_map.to(device)
+            logical_to_physical_map.copy_(new_logical_to_physical_map)
+            logical_replica_count.copy_(new_logical_replica_count)
+
         model.set_eplb_state(
             expert_load_pass,
             logical_to_physical_map,
             logical_replica_count,
         )
+        if global_expert_load is not None:
+            rearrange_expert_weights_inplace(
+                old_global_expert_indices,
+                new_physical_to_logical_map,
+                model.expert_weights,
+                ep_group,
+                False,
+                rank_mapping,
+            )
+            expert_rearrangement_step = 0

         return cls(
             physical_to_logical_map,
@@ -337,7 +403,10 @@ class EplbState:

     def rearrange(self,
                   model: MixtureOfExperts,
-                  is_profile: bool = False) -> None:
+                  is_profile: bool = False,
+                  execute_shuffle: bool = True,
+                  global_expert_load: Optional[torch.Tensor] = None,
+                  rank_mapping: Optional[dict[int, int]] = None) -> None:
         """
         Rearrange the experts according to the current load.
         """
@@ -353,42 +422,79 @@ class EplbState:
         logger.info("Rearranging experts %s...",
                     "(profile)" if is_profile else "")

-        # This mapping is only used here, so we do not store it in the state
-        physical_expert_start = ep_rank * model.num_local_physical_experts
-        physical_expert_end = (physical_expert_start +
-                               model.num_local_physical_experts)
-        # (num_moe_layers, num_local_physical_experts)
-        local_physical_to_logical_map = self.physical_to_logical_map[
-            :,
-            physical_expert_start:physical_expert_end,
-        ]
+        if global_expert_load is None:
+            # This mapping is only used here, so we do not store it in the state
+            physical_expert_start = ep_rank * model.num_local_physical_experts
+            physical_expert_end = (physical_expert_start +
+                                   model.num_local_physical_experts)
+            # (num_moe_layers, num_local_physical_experts)
+            local_physical_to_logical_map = self.physical_to_logical_map[
+                :,
+                physical_expert_start:physical_expert_end,
+            ]

-        # Map the local physical expert load to global logical experts
-        logical_expert_load_window = torch.zeros(
-            self.expert_load_window_size,
-            model.num_moe_layers,
-            model.num_logical_experts,
-            dtype=self.expert_load_window.dtype,
-            device=self.expert_load_window.device,
-        )
-        logical_expert_load_window.scatter_add_(
-            dim=-1,
-            index=local_physical_to_logical_map.unsqueeze(0).expand_as(
-                self.expert_load_window).long(),
-            src=self.expert_load_window,
-        )
+            # Map the local physical expert load to global logical experts
+            logical_expert_load_window = torch.zeros(
+                self.expert_load_window_size,
+                model.num_moe_layers,
+                model.num_logical_experts,
+                dtype=self.expert_load_window.dtype,
+                device=self.expert_load_window.device,
+            )
+            logical_expert_load_window.scatter_add_(
+                dim=-1,
+                index=local_physical_to_logical_map.unsqueeze(0).expand_as(
+                    self.expert_load_window).long(),
+                src=self.expert_load_window,
+            )

-        # Perform all-reduce to get the expert load across all ranks
-        global_expert_load_window = logical_expert_load_window.sum(dim=0)
-        all_reduce(global_expert_load_window, group=ep_group)
+            if not execute_shuffle:
+                metadata = torch.tensor(
+                    [
+                        model.num_moe_layers, model.num_logical_experts,
+                        self.physical_to_logical_map.shape[1]
+                    ],
+                    dtype=torch.int32,
+                    device="cpu",
+                )
+                torch.distributed.broadcast(metadata,
+                                            group=get_ep_group().cpu_group,
+                                            group_src=0)
+
+            # Perform all-reduce to get the expert load across all ranks
+            global_expert_load_window = logical_expert_load_window.sum(dim=0)
+            all_reduce(global_expert_load_window, group=ep_group)
+
+            if not execute_shuffle:
+                # (num_moe_layers, old_num_physical_experts)
+                old_global_expert_indices = self.physical_to_logical_map
+                torch.distributed.broadcast(old_global_expert_indices,
+                                            group=ep_group,
+                                            group_src=0)
+                return global_expert_load_window
+        else:
+            assert execute_shuffle
+            global_expert_load_window = global_expert_load

         # TODO(bowen): Treat differently for prefill and decode nodes
         num_replicas = model.num_physical_experts
         num_groups = model.num_expert_groups
-        num_nodes = get_node_count()
-        num_gpus = ep_group.size()
+        if rank_mapping is not None and len(rank_mapping) == ep_group.size():
+            # NOTE(yongji): scale down, we need to rebalance the experts on
+            # remaining GPUs, transfer the experts while we haven't shutdown
+            # the GPUs to be released.
+            cpu_group = get_ep_group().cpu_group
+            num_nodes = _node_count_with_rank_mapping(cpu_group, rank_mapping)
+            num_gpus = sum(new_rank != -1
+                           for new_rank in rank_mapping.values())
+            num_replicas = num_replicas // ep_group.size(
+            ) * num_gpus  # handle num replicas change
+        else:
+            num_nodes = get_node_count()
+            num_gpus = ep_group.size()

         if num_gpus % num_nodes != 0:
             self.num_nodes = 1
             logger.warning_once(
                 f"num_gpus % num_nodes != 0, "
                 "not using hierarchical rearrangement algorithm.\n"
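Note: the scale-down branch above shrinks num_replicas in proportion to the surviving GPUs. A worked example with hypothetical sizes (illustrative, not from the commit):

# 64 physical experts across ep_size=4 (16 per GPU); one rank is mapped to
# -1 (pending shutdown), so the 3 surviving GPUs keep 16 experts each
rank_mapping = {0: 0, 1: 1, 2: 2, 3: -1}
old_ep_size = len(rank_mapping)                          # 4
num_gpus = sum(r != -1 for r in rank_mapping.values())   # 3
num_replicas = 64 // old_ep_size * num_gpus
assert num_replicas == 48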
@@ -414,10 +520,24 @@ class EplbState:
             model.expert_weights,
             ep_group,
             is_profile,
+            rank_mapping,
         )

         if not is_profile:
-            self.physical_to_logical_map.copy_(new_physical_to_logical_map)
+            if self.physical_to_logical_map.shape[
+                    1] != new_physical_to_logical_map.shape[1]:
+                self.physical_to_logical_map = new_physical_to_logical_map.to(
+                    self.physical_to_logical_map.device)
+            else:
+                self.physical_to_logical_map.copy_(new_physical_to_logical_map)
+            max_physical_slots = new_logical_to_physical_map.shape[-1]
+            assert max_physical_slots <= self.logical_to_physical_map.shape[-1]
+            new_logical_to_physical_map = torch.nn.functional.pad(
+                new_logical_to_physical_map,
+                (0,
+                 self.logical_to_physical_map.shape[-1] - max_physical_slots),
+                value=-1,
+            )
             self.logical_to_physical_map.copy_(new_logical_to_physical_map)
             self.logical_replica_count.copy_(new_logical_replica_count)
@@ -430,3 +550,69 @@ class EplbState:
             " (profile) " if is_profile else " ",
             time_end - time_start,
         )
+
+    @staticmethod
+    def recv_state() -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Receive the expert load and old placement from the master rank.
+        """
+        ep_group = get_ep_group()
+        metadata = torch.empty(3, dtype=torch.int32, device="cpu")
+        torch.distributed.broadcast(metadata,
+                                    group=ep_group.cpu_group,
+                                    group_src=0)
+        num_moe_layers, num_logical_experts, num_old_physical_experts = (
+            metadata.tolist())
+        global_expert_load = torch.zeros(
+            (num_moe_layers, num_logical_experts),
+            dtype=torch.int64,
+            device=ep_group.device,
+        )
+        all_reduce(global_expert_load, group=ep_group.device_group)
+        old_global_expert_indices = torch.empty(
+            (num_moe_layers, num_old_physical_experts),
+            dtype=torch.int64,
+            device=ep_group.device,
+        )
+        torch.distributed.broadcast(old_global_expert_indices,
+                                    group=ep_group.device_group,
+                                    group_src=0)
+
+        return global_expert_load, old_global_expert_indices
+
+
+def _node_count_with_rank_mapping(
+    pg: Union[ProcessGroup, StatelessProcessGroup],
+    rank_mapping: dict[int, int],
+) -> int:
+    if isinstance(pg, ProcessGroup):
+        world_size = torch.distributed.get_world_size(group=pg)
+    else:
+        world_size = pg.world_size
+
+    if world_size == 1:
+        return 1
+
+    # Build node assignment map
+    node_assignment = [0] * world_size  # rank -> node_id
+    next_node_id = 0
+
+    for current_rank in range(world_size):
+        if node_assignment[current_rank] != 0:
+            continue  # Already assigned to a node
+
+        assert current_rank in rank_mapping
+        if rank_mapping[current_rank] == -1:
+            continue  # Pending shutdown
+
+        # Assign current rank to a new node
+        next_node_id += 1
+        node_assignment[current_rank] = next_node_id
+
+        # Find all ranks on the same node as current_rank
+        same_node_flags = in_the_same_node_as(pg, current_rank)
+        for other_rank, is_same_node in enumerate(same_node_flags):
+            if is_same_node and node_assignment[other_rank] == 0:
+                node_assignment[other_rank] = next_node_id
+
+    return next_node_id
@@ -8,6 +8,7 @@ This involves the exchange of expert weights between GPUs.

 from collections.abc import Iterable, MutableSequence, Sequence
 from functools import partial
+from typing import Optional

 import torch
 from torch.distributed import (P2POp, ProcessGroup, all_gather,
@@ -127,6 +128,8 @@ def shuffle_layer(
         dst_global = local2global(dst)
         if is_received_locally[dst]:
             continue
+        if old_indices[src_global] == -1 or new_indices[dst_global] == -1:
+            continue
         if old_indices[src_global] == new_indices[dst_global]:
             is_received_locally[dst] = True
             for weight, buffer in zip(expert_weights,
@@ -139,6 +142,8 @@ def shuffle_layer(
     experts_send_loc: dict[int, int] = {}
     for src in range(num_local_experts):
         expert = old_indices[local2global(src)]
+        if expert == -1:
+            continue
         if expert in experts_send_loc:
             continue
         experts_send_loc[expert] = src
@@ -181,6 +186,8 @@ def shuffle_layer(
         if is_received_locally[dst]:
             continue
         expert = new_indices[local2global(dst)]
+        if expert == -1:
+            continue
         if expert in experts_recv_loc:
             continue
         experts_recv_loc[expert] = dst
@@ -227,6 +234,8 @@ def shuffle_layer(
                 weight[dst].copy_(buffer[dst])
         else:
             expert = new_indices[local2global(dst)]
+            if expert == -1:
+                continue
             src = experts_recv_loc[expert]
             for weight, buffer in zip(expert_weights, expert_weights_buffer):
                 weight[dst].copy_(buffer[src])
@@ -238,6 +247,7 @@ def rearrange_expert_weights_inplace(
     expert_weights: Sequence[Iterable[torch.Tensor]],
     ep_group: ProcessGroup,
     is_profile: bool = False,
+    rank_mapping: Optional[dict[int, int]] = None,
 ) -> None:
     """
     Rearranges the expert weights in place according to the new expert indices.
@@ -256,7 +266,28 @@ def rearrange_expert_weights_inplace(
         is_profile (bool): If `True`, do not perform any actual weight copy.
             This is used during profile run, where we only perform dummy
             communications to reserve enough memory for the buffers.
+        rank_mapping: A dictionary mapping old rank to new rank.
     """
+    if rank_mapping is not None:
+        if len(rank_mapping) == ep_group.size():
+            # scale down
+            new_global_expert_indices = \
+                _map_new_expert_indices_with_rank_mapping(
+                    new_global_expert_indices,
+                    rank_mapping,
+                )
+        else:
+            # scale up
+            old_global_expert_indices = \
+                _map_old_expert_indices_with_rank_mapping(
+                    old_global_expert_indices,
+                    rank_mapping,
+                    ep_group.size(),
+                )
+
+    assert old_global_expert_indices.shape[
+        1] == new_global_expert_indices.shape[1]
+
     num_moe_layers, num_physical_experts = old_global_expert_indices.shape
     assert len(expert_weights) == num_moe_layers
@@ -304,4 +335,90 @@ def rearrange_expert_weights_inplace(
         )


+def _map_old_expert_indices_with_rank_mapping(
+    old_global_expert_indices: torch.Tensor,
+    rank_mapping: dict[int, int],
+    new_ep_size: int,
+) -> torch.Tensor:
+    """
+    Map the old global expert indices to the new global expert indices.
+
+    Args:
+        old_global_expert_indices:
+            Shape (num_layers, old_ep_size * num_local_physical_experts).
+        rank_mapping: Mapping from old rank to new rank.
+        new_ep_size: New expert parallelism size.
+
+    Returns:
+        Mapped expert indices with shape
+        (num_layers, new_ep_size * num_local_physical_experts).
+    """
+    num_layers, old_num_physical_experts = old_global_expert_indices.shape
+    assert rank_mapping, "Rank mapping is required"
+
+    # Get sizes from parameters and rank_mapping
+    old_ep_size = len(rank_mapping)
+    num_local_physical_experts = old_num_physical_experts // old_ep_size
+    new_num_physical_experts = new_ep_size * num_local_physical_experts
+
+    # Create mapped tensor with new shape, initialized to -1
+    mapped_expert_indices = torch.full(
+        (num_layers, new_num_physical_experts),
+        fill_value=-1,
+        dtype=old_global_expert_indices.dtype,
+        device=old_global_expert_indices.device,
+    )
+
+    # Handle rank mapping (scale up/down with rank changes)
+    for old_rank in range(old_ep_size):
+        new_rank = rank_mapping.get(old_rank)
+        if new_rank is not None and new_rank >= 0 and new_rank < new_ep_size:
+            # This old rank exists in the new configuration
+            old_start_idx = old_rank * num_local_physical_experts
+            old_end_idx = (old_rank + 1) * num_local_physical_experts
+            new_start_idx = new_rank * num_local_physical_experts
+            new_end_idx = (new_rank + 1) * num_local_physical_experts
+
+            mapped_expert_indices[:, new_start_idx:new_end_idx] = \
+                old_global_expert_indices[:, old_start_idx:old_end_idx]
+        # If new_rank is None or >= new_ep_size, the experts remain -1
+        # (scale down case)
+
+    return mapped_expert_indices
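Note: to make the shape bookkeeping concrete, here is a small worked example of the scale-up direction (illustrative, re-implementing the loop above inline; all sizes are hypothetical). Growing from 2 to 3 EP ranks with 2 physical experts per rank leaves the new rank's slots at -1 until rebalancing fills them:

import torch

old = torch.tensor([[0, 1, 2, 3]])   # 1 layer, old_ep_size=2, 2 experts/rank
rank_mapping = {0: 0, 1: 1}          # both old ranks keep their positions
mapped = torch.full((1, 6), -1, dtype=old.dtype)  # new_ep_size=3
for old_rank, new_rank in rank_mapping.items():
    mapped[:, new_rank * 2:(new_rank + 1) * 2] = \
        old[:, old_rank * 2:(old_rank + 1) * 2]
assert mapped.tolist() == [[0, 1, 2, 3, -1, -1]]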
+
+def _map_new_expert_indices_with_rank_mapping(
+    new_global_expert_indices: torch.Tensor,
+    rank_mapping: dict[int, int],
+) -> torch.Tensor:
+    num_layers, new_num_physical_experts = new_global_expert_indices.shape
+    assert rank_mapping, "Rank mapping is required"
+
+    # Get sizes from parameters and rank_mapping
+    old_ep_size = len(rank_mapping)
+    new_ep_size = sum(new_rank != -1 for new_rank in rank_mapping.values())
+    num_local_physical_experts = new_num_physical_experts // new_ep_size
+    old_num_physical_experts = old_ep_size * num_local_physical_experts
+
+    mapped_expert_indices = torch.full(
+        (num_layers, old_num_physical_experts),
+        fill_value=-1,
+        dtype=new_global_expert_indices.dtype,
+        device=new_global_expert_indices.device,
+    )
+
+    for old_rank in range(old_ep_size):
+        new_rank = rank_mapping[old_rank]
+        if new_rank >= 0 and new_rank < new_ep_size:
+            old_start_idx = old_rank * num_local_physical_experts
+            old_end_idx = (old_rank + 1) * num_local_physical_experts
+            new_start_idx = new_rank * num_local_physical_experts
+            new_end_idx = (new_rank + 1) * num_local_physical_experts
+
+            mapped_expert_indices[:, old_start_idx:old_end_idx] = \
+                new_global_expert_indices[:, new_start_idx:new_end_idx]
+
+    return mapped_expert_indices
+
+
 __all__ = ["rearrange_expert_weights_inplace"]
@@ -324,3 +324,9 @@ class EngineClient(ABC):
     async def add_lora(self, lora_request: LoRARequest) -> None:
         """Load a new LoRA adapter into the engine for future requests."""
         ...
+
+    async def scale_elastic_ep(self,
+                               new_data_parallel_size: int,
+                               drain_timeout: int = 300) -> None:
+        """Scale the engine"""
+        raise NotImplementedError
@@ -1018,6 +1018,73 @@ if envs.VLLM_SERVER_DEV_MODE:
         return JSONResponse(content={"is_sleeping": is_sleeping})


+    @router.post("/scale_elastic_ep",
+                 dependencies=[Depends(validate_json_request)],
+                 responses={
+                     HTTPStatus.OK.value: {
+                         "model": dict
+                     },
+                     HTTPStatus.BAD_REQUEST.value: {
+                         "model": ErrorResponse
+                     },
+                     HTTPStatus.REQUEST_TIMEOUT.value: {
+                         "model": ErrorResponse
+                     },
+                     HTTPStatus.INTERNAL_SERVER_ERROR.value: {
+                         "model": ErrorResponse
+                     },
+                 })
+    async def scale_elastic_ep(raw_request: Request):
+        try:
+            body = await raw_request.json()
+        except json.JSONDecodeError as e:
+            raise HTTPException(status_code=400,
+                                detail="Invalid JSON format") from e  # noqa: B904
+
+        new_data_parallel_size = body.get("new_data_parallel_size")
+        drain_timeout = body.get("drain_timeout", 120)  # Default 2 minutes
+
+        if new_data_parallel_size is None:
+            raise HTTPException(status_code=400,
+                                detail="new_data_parallel_size is required")
+
+        if not isinstance(new_data_parallel_size,
+                          int) or new_data_parallel_size <= 0:
+            raise HTTPException(
+                status_code=400,
+                detail="new_data_parallel_size must be a positive integer")
+
+        if not isinstance(drain_timeout, int) or drain_timeout <= 0:
+            raise HTTPException(status_code=400,
+                                detail="drain_timeout must be a positive integer")
+
+        # Set scaling flag to prevent new requests
+        global _scaling_elastic_ep
+        _scaling_elastic_ep = True
+        client = engine_client(raw_request)
+        try:
+            await client.scale_elastic_ep(new_data_parallel_size, drain_timeout)
+            return JSONResponse({
+                "message":
+                f"Scaled to {new_data_parallel_size} "
+                "data parallel engines",
+            })
+        except TimeoutError as e:
+            raise HTTPException(status_code=408,
+                                detail="Scale failed due to request drain timeout "
+                                f"after {drain_timeout} seconds") from e
+        except Exception as e:
+            logger.error("Scale failed: %s", e)
+            raise HTTPException(status_code=500, detail="Scale failed") from e
+        finally:
+            _scaling_elastic_ep = False
+
+
+    @router.post("/is_scaling_elastic_ep")
+    async def is_scaling_elastic_ep(raw_request: Request):
+        return JSONResponse({"is_scaling_elastic_ep": _scaling_elastic_ep})
+
+
 # TODO: RequestType = TypeForm[BaseModel] when recognized by type checkers
 # (requires typing_extensions >= 4.13)
 RequestType = Any

@@ -1216,6 +1283,41 @@ class XRequestIdMiddleware:
             return self.app(scope, receive, send_with_request_id)


+# Global variable to track scaling state
+_scaling_elastic_ep = False
+
+
+class ScalingMiddleware:
+    """
+    Middleware that checks if the model is currently scaling and
+    returns a 503 Service Unavailable response if it is.
+
+    This middleware applies to all HTTP requests and prevents
+    processing when the model is in a scaling state.
+    """
+
+    def __init__(self, app: ASGIApp) -> None:
+        self.app = app
+
+    def __call__(self, scope: Scope, receive: Receive,
+                 send: Send) -> Awaitable[None]:
+        if scope["type"] != "http":
+            return self.app(scope, receive, send)
+
+        # Check global scaling state
+        global _scaling_elastic_ep
+        if _scaling_elastic_ep:
+            # Return 503 Service Unavailable response
+            response = JSONResponse(content={
+                "error":
+                "The model is currently scaling. Please try again later."
+            },
+                                    status_code=503)
+            return response(scope, receive, send)
+
+        return self.app(scope, receive, send)
+
+
 def _extract_content_from_chunk(chunk_data: dict) -> str:
     """Extract content from a streaming response chunk."""
     try:
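Note: because ScalingMiddleware answers every HTTP request with 503 while a scale operation is in flight, clients can treat that status as retryable. A sketch of such a client-side helper (an assumption about client policy, not part of the commit):

import time

import requests


def post_with_scaling_retry(url, json_body, retries=10, backoff=2.0):
    # Retry while the server reports 503 from ScalingMiddleware.
    for _ in range(retries):
        resp = requests.post(url, json=json_body)
        if resp.status_code != 503:
            return resp
        time.sleep(backoff)
    raise RuntimeError("server still scaling after retries")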
@@ -1404,6 +1506,9 @@ def build_app(args: Namespace) -> FastAPI:
     if args.enable_request_id_headers:
         app.add_middleware(XRequestIdMiddleware)

+    # Add scaling middleware to check for scaling state
+    app.add_middleware(ScalingMiddleware)
+
     if envs.VLLM_DEBUG_LOG_API_SERVER_RESPONSE:
         logger.warning("CAUTION: Enabling log response in the API Server. "
                        "This can include sensitive information and should be "
@@ -12,6 +12,7 @@ from vllm.executor.executor_base import ExecutorBase
 from vllm.logger import init_logger
 from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
                         run_method)
+from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
 from vllm.worker.worker_base import WorkerWrapperBase

 logger = init_logger(__name__)
@@ -62,6 +63,14 @@ class UniProcExecutor(ExecutorBase):
         # it's running.
         return

+    def reinitialize_distributed(
+            self, reconfig_request: ReconfigureDistributedRequest) -> None:
+        self.driver_worker.reinitialize_distributed(reconfig_request)
+        if reconfig_request.new_data_parallel_rank == \
+            ReconfigureRankType.SHUTDOWN_CURRENT_RANK:
+            self.shutdown()
+        return
+

 UniProcExecutorAsync = UniProcExecutor
@@ -265,9 +265,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         prepare_finalize: FusedMoEPrepareAndFinalize,
         moe: FusedMoEConfig,
     ) -> FusedMoEPermuteExpertsUnpermute:
-
-        assert self.fused_experts == fused_experts
-
         if (prepare_finalize.activation_format ==
                 FusedMoEActivationFormat.BatchedExperts):
             logger.debug("BatchedTritonExperts %s", self.moe)
@@ -375,8 +372,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         logical_replica_count: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if enable_eplb:
-            raise NotImplementedError(
-                "EPLB not supported for `UnquantizedFusedMoEMethod` yet.")
+            assert expert_load_view is not None
+            assert logical_to_physical_map is not None
+            assert logical_replica_count is not None
+            assert isinstance(layer, FusedMoE)

         return self.forward(
             x=x,
@@ -393,7 +392,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             scoring_func=scoring_func,
             e_score_correction_bias=e_score_correction_bias,
             activation=activation,
-            apply_router_weight_on_input=apply_router_weight_on_input)
+            apply_router_weight_on_input=apply_router_weight_on_input,
+            enable_eplb=enable_eplb,
+            expert_load_view=expert_load_view,
+            logical_to_physical_map=logical_to_physical_map,
+            logical_replica_count=logical_replica_count,
+        )

     def forward_cuda(
         self,
@@ -412,6 +416,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         e_score_correction_bias: Optional[torch.Tensor] = None,
         apply_router_weight_on_input: bool = False,
         activation: str = "silu",
+        enable_eplb: bool = False,
+        expert_load_view: Optional[torch.Tensor] = None,
+        logical_to_physical_map: Optional[torch.Tensor] = None,
+        logical_replica_count: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:

         topk_weights, topk_ids = FusedMoE.select_experts(
@@ -425,7 +433,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
             e_score_correction_bias=e_score_correction_bias,
-            indices_type=self.topk_indices_dtype)
+            indices_type=self.topk_indices_dtype,
+            enable_eplb=enable_eplb,
+            expert_map=expert_map,
+            expert_load_view=expert_load_view,
+            logical_to_physical_map=logical_to_physical_map,
+            logical_replica_count=logical_replica_count)

         if self.rocm_aiter_moe_enabled:
             return self.rocm_aiter_fused_experts(
@@ -730,7 +743,8 @@ class FusedMoE(torch.nn.Module):
         if self.enable_eplb:
             from vllm.model_executor.layers.quantization.fp8 import (
                 Fp8MoEMethod)
-            if not isinstance(quant_method, Fp8MoEMethod):
+            if not isinstance(quant_method,
+                              (Fp8MoEMethod, UnquantizedFusedMoEMethod)):
                 # TODO: Add support for additional quantization methods.
                 # The implementation for other quantization methods does not
                 # contain essential differences, but the current quant API
@@ -821,6 +835,15 @@ class FusedMoE(torch.nn.Module):
     def use_flashinfer_cutlass_kernels(self):
         return self.moe_parallel_config.use_flashinfer_cutlass_kernels

+    def update_expert_map(self):
+        # ep_size and ep_rank should already be updated
+        assert self.expert_map is not None
+        with self.expert_map.device:
+            self.local_num_experts, self.expert_map = determine_expert_map(
+                ep_size=self.ep_size,
+                ep_rank=self.ep_rank,
+                global_num_experts=self.global_num_experts)
+
     def _load_per_tensor_weight_scale(self, shard_id: str,
                                       param: torch.nn.Parameter,
                                       loaded_weight: torch.Tensor,
@@ -776,6 +776,24 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
             logical_replica_count=logical_replica_count,
         )

+    def update_physical_experts_metadata(
+        self,
+        num_physical_experts: int,
+        num_local_physical_experts: int,
+    ) -> None:
+        assert self.num_local_physical_experts == num_local_physical_experts
+        self.num_physical_experts = num_physical_experts
+        self.num_local_physical_experts = num_local_physical_experts
+        self.num_redundant_experts = (num_physical_experts -
+                                      self.num_logical_experts)
+        for layer in self.model.layers:
+            if isinstance(layer.mlp, DeepseekV2MoE):
+                moe = layer.mlp
+                moe.n_local_physical_experts = num_local_physical_experts
+                moe.n_physical_experts = num_physical_experts
+                moe.n_redundant_experts = self.num_redundant_experts
+                moe.experts.update_expert_map()
+
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
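Note: the update above re-derives redundancy from the new physical expert count (num_redundant = num_physical - num_logical). A worked example with hypothetical sizes (illustrative, not from the commit):

# hypothetical model: 64 logical experts, 16 local physical experts per GPU
num_logical_experts = 64
num_local_physical_experts = 16
# after scaling down from 5 to 4 GPUs:
num_physical_experts = 4 * num_local_physical_experts   # 64
num_redundant_experts = num_physical_experts - num_logical_experts
assert num_redundant_experts == 0   # was 80 - 64 = 16 before the scale-down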
@@ -931,9 +949,8 @@ class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):

 def get_spec_layer_idx_from_weight_name(config: PretrainedConfig,
                                         weight_name: str) -> Optional[int]:
-    if hasattr(config,
-               "num_nextn_predict_layers") and (config.num_nextn_predict_layers
-                                                > 0):
+    if (hasattr(config, "num_nextn_predict_layers")
+            and config.num_nextn_predict_layers > 0):
         layer_idx = config.num_hidden_layers
         for i in range(config.num_nextn_predict_layers):
             if weight_name.startswith(f"model.layers.{layer_idx+i}."):
@@ -543,6 +543,13 @@ class MixtureOfExperts(Protocol):
         """
         ...

+    def update_physical_experts_metadata(
+        self,
+        num_physical_experts: int,
+        num_local_physical_experts: int,
+    ) -> None:
+        ...
+

 def is_mixture_of_experts(model: object) -> TypeIs[MixtureOfExperts]:
     return isinstance(model, MixtureOfExperts)
@@ -177,3 +177,19 @@ class EngineCoreRequestType(enum.Enum):
     UTILITY = b'\x03'
     # Sentinel used within EngineCoreProc.
     EXECUTOR_FAILED = b'\x04'
+
+
+class ReconfigureDistributedRequest(msgspec.Struct):
+    new_data_parallel_size: int
+    new_data_parallel_rank: int
+    new_data_parallel_rank_local: int
+    new_data_parallel_master_ip: str
+    new_data_parallel_master_port: int
+
+
+class ReconfigureRankType(enum.IntEnum):
+    """
+    Rank type for reconfiguring distributed request.
+    """
+    KEEP_CURRENT_RANK = -1
+    SHUTDOWN_CURRENT_RANK = -2
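Note: a scale-up message for an existing engine would be built roughly as follows, mirroring how _scale_up_elastic_ep fills these fields later in this diff (the IP and port values are placeholders, not from the commit):

req = ReconfigureDistributedRequest(
    new_data_parallel_size=4,
    new_data_parallel_rank=ReconfigureRankType.KEEP_CURRENT_RANK,
    new_data_parallel_rank_local=ReconfigureRankType.KEEP_CURRENT_RANK,
    new_data_parallel_master_ip="10.0.0.1",   # placeholder address
    new_data_parallel_master_port=29500,      # placeholder port
)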
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import asyncio
+import time
 from collections.abc import AsyncGenerator, Mapping
 from copy import copy
 from typing import Any, Optional, Union
@@ -608,6 +609,63 @@ class AsyncLLM(EngineClient):
         return await self.engine_core.collective_rpc_async(
             method, timeout, args, kwargs)

+    async def wait_for_requests_to_drain(self, drain_timeout: int = 300):
+        """Wait for all requests to be drained."""
+        start_time = time.time()
+        while time.time() - start_time < drain_timeout:
+            if not self.engine_core.dp_engines_running():
+                logger.info("Engines are idle, requests have been drained")
+                return
+
+            logger.info(
+                "Engines are still running, waiting for requests to drain...")
+            await asyncio.sleep(1)  # Wait 1 second before checking again
+
+        raise TimeoutError(f"Timeout reached after {drain_timeout} seconds "
+                           "waiting for requests to drain.")
+
+    async def scale_elastic_ep(self,
+                               new_data_parallel_size: int,
+                               drain_timeout: int = 300):
+        """
+        Scale up or down the data parallel size by adding or removing
+        engine cores.
+        Args:
+            new_data_parallel_size: The new number of data parallel workers
+            drain_timeout:
+                Maximum time to wait for requests to drain (seconds)
+        """
+        old_data_parallel_size = \
+            self.vllm_config.parallel_config.data_parallel_size
+        if old_data_parallel_size == new_data_parallel_size:
+            logger.info("Data parallel size is already %s, skipping scale",
+                        new_data_parallel_size)
+            return
+        logger.info(
+            "Waiting for requests to drain before "
+            "scaling up to %s engines...", new_data_parallel_size)
+        await self.wait_for_requests_to_drain(drain_timeout)
+        logger.info(
+            "Requests have been drained, proceeding with scale "
+            "to %s engines", new_data_parallel_size)
+        await self.engine_core.scale_elastic_ep(new_data_parallel_size)
+        self.vllm_config.parallel_config.data_parallel_size = \
+            new_data_parallel_size
+
+        # recreate stat loggers
+        if new_data_parallel_size > old_data_parallel_size:
+            stat_loggers: list[list[StatLoggerBase]] = setup_default_loggers(
+                vllm_config=self.vllm_config,
+                log_stats=self.log_stats,
+                engine_num=new_data_parallel_size,
+                custom_stat_loggers=None,
+            )
+            num_new_engines = len(stat_loggers) - len(self.stat_loggers)
+            self.stat_loggers.extend(stat_loggers[-num_new_engines:])
+        else:
+            for _ in range(old_data_parallel_size - new_data_parallel_size):
+                self.stat_loggers.pop()
+
     @property
     def is_running(self) -> bool:
         # Is None before the loop is started.
@@ -200,11 +200,41 @@ class CoordinatorProc:
                     # Ignore subscription messages.
                     continue

+                decoded = msgspec.msgpack.decode(buffer)
+                if isinstance(decoded, (list, tuple)) and len(
+                        decoded) == 2 and decoded[0] == "SCALE_ELASTIC_EP":
+                    # Handle scale up notification
+                    new_engine_count = decoded[1]
+                    current_count = len(self.engines)
+                    if new_engine_count > current_count:
+                        for _ in range(new_engine_count - current_count):
+                            self.engines.append(EngineState())
+                        # NOTE(yongji): handle the case
+                        # where newly started engines have current_wave = 0
+                        # if existing engines just finished a wave
+                        # and engine_running isn't updated yet at
+                        # CoordinatorProc requests routed to newly started
+                        # engines may not wake up existing engines, as long
+                        # as 0 < request.wave < existing engines'
+                        # current_wave
+                        # we note that 0 is the wave number for the new
+                        # engine
+                        self.engines_running = False
+                        logger.info(
+                            "DPCoordinator scaled up from %s to %s "
+                            "engines", current_count, new_engine_count)
+                    else:
+                        self.engines = self.engines[:new_engine_count]
+                        logger.info(
+                            "DPCoordinator scaled down from %s to %s "
+                            "engines", current_count, new_engine_count)
+                    continue  # Skip normal engine notification processing
+
                 # We received a message on the front-end XPUB socket,
                 # from an API server sending a new request while the
                 # engines are paused, so that we can wake the other
                 # engines.
-                engine_to_exclude, wave = msgspec.msgpack.decode(buffer)
+                engine_to_exclude, wave = decoded
                 if not self.engines_running:
                     if wave < self.current_wave:
                         # If the wave number is stale, ensure the message
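Note: the coordinator distinguishes the two message kinds on this socket purely by the first element of the decoded msgpack tuple. A standalone illustration (not commit code):

import msgspec

scale_msg = msgspec.msgpack.encode(("SCALE_ELASTIC_EP", 4))
decoded = msgspec.msgpack.decode(scale_msg)
# msgpack arrays decode as lists, hence the isinstance check
# on (list, tuple) in the coordinator above
assert decoded == ["SCALE_ELASTIC_EP", 4]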
@@ -32,7 +32,9 @@ from vllm.v1.core.sched.interface import SchedulerInterface
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
 from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
-                            EngineCoreRequestType, UtilityOutput)
+                            EngineCoreRequestType,
+                            ReconfigureDistributedRequest, ReconfigureRankType,
+                            UtilityOutput)
 from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
 from vllm.v1.engine.utils import EngineHandshakeMetadata, EngineZmqAddresses
 from vllm.v1.executor.abstract import Executor
@@ -77,6 +79,8 @@ class EngineCore:
             self.model_executor.register_failure_callback(
                 executor_fail_callback)

+        self.available_gpu_memory_for_kv_cache = -1
+
         # Setup KV Caches and update CacheConfig after profiling.
         num_gpu_blocks, num_cpu_blocks, kv_cache_config = \
             self._initialize_kv_caches(vllm_config)
@@ -137,12 +141,23 @@ class EngineCore:
         # Get all kv cache needed by the model
         kv_cache_specs = self.model_executor.get_kv_cache_specs()

-        # Profiles the peak memory usage of the model to determine how much
-        # memory can be allocated for kv cache.
         has_kv_cache = any(kv_cache_spec for kv_cache_spec in kv_cache_specs)
         if has_kv_cache:
-            available_gpu_memory = \
-                self.model_executor.determine_available_memory()
+            if os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1":
+                dp_group = getattr(self, "dp_group", None)
+                assert dp_group is not None
+                self.available_gpu_memory_for_kv_cache = \
+                    ParallelConfig.sync_kv_cache_memory_size(dp_group, -1)
+                available_gpu_memory = [
+                    self.available_gpu_memory_for_kv_cache
+                ] * len(kv_cache_specs)
+            else:
+                # Profiles the peak memory usage of the model to determine how
+                # much memory can be allocated for kv cache.
+                available_gpu_memory = (
+                    self.model_executor.determine_available_memory())
+                self.available_gpu_memory_for_kv_cache = \
+                    available_gpu_memory[0]
         else:
             # Attention free models don't need memory for kv cache
             available_gpu_memory = [0] * len(kv_cache_specs)
@@ -989,6 +1004,50 @@ class DPEngineCoreProc(EngineCoreProc):
         return ParallelConfig.has_unfinished_dp(self.dp_group,
                                                 local_unfinished)

+    def reinitialize_distributed(
+            self, reconfig_request: ReconfigureDistributedRequest) -> None:
+        stateless_destroy_torch_distributed_process_group(self.dp_group)
+        self.shutdown()
+
+        parallel_config = self.vllm_config.parallel_config
+        old_dp_size = parallel_config.data_parallel_size
+        parallel_config.data_parallel_size = \
+            reconfig_request.new_data_parallel_size
+        if reconfig_request.new_data_parallel_rank != -1:
+            parallel_config.data_parallel_rank = \
+                reconfig_request.new_data_parallel_rank
+        # local rank specifies device visibility, it should not be changed
+        assert reconfig_request.new_data_parallel_rank_local == \
+            ReconfigureRankType.KEEP_CURRENT_RANK
+        parallel_config.data_parallel_master_ip = \
+            reconfig_request.new_data_parallel_master_ip
+        parallel_config.data_parallel_master_port = \
+            reconfig_request.new_data_parallel_master_port
+        if reconfig_request.new_data_parallel_rank != -2:
+            self.dp_rank = parallel_config.data_parallel_rank
+            self.dp_group = parallel_config.stateless_init_dp_group()
+        reconfig_request.new_data_parallel_master_port = \
+            parallel_config.data_parallel_master_port
+
+        self.model_executor.reinitialize_distributed(reconfig_request)
+        if reconfig_request.new_data_parallel_size > old_dp_size:
+            assert self.available_gpu_memory_for_kv_cache > 0
+            # pass available_gpu_memory_for_kv_cache from existing
+            # engine-cores to new engine-cores so they can directly
+            # use it in _initialize_kv_caches() rather than profiling.
+            ParallelConfig.sync_kv_cache_memory_size(
+                self.dp_group, self.available_gpu_memory_for_kv_cache)
+            # NOTE(yongji): newly joined workers require dummy_run even
+            # CUDA graph is not used
+            self.model_executor.collective_rpc("compile_or_warm_up_model")
+        if reconfig_request.new_data_parallel_rank == \
+            ReconfigureRankType.SHUTDOWN_CURRENT_RANK:
+            self.shutdown()
+            logger.info("DPEngineCoreProc %s shutdown", self.dp_rank)
+        else:
+            logger.info("Distributed environment reinitialized for DP rank %s",
+                        self.dp_rank)
+

 class DPEngineCoreActor(DPEngineCoreProc):
     """
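Note: the literal -1 and -2 checks in reinitialize_distributed correspond to the ReconfigureRankType sentinels defined earlier in this diff, which is worth keeping in mind when reading the branches (illustrative check):

assert ReconfigureRankType.KEEP_CURRENT_RANK == -1       # keep this rank
assert ReconfigureRankType.SHUTDOWN_CURRENT_RANK == -2   # rank is retiring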
@@ -21,9 +21,11 @@ import zmq.asyncio
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
-from vllm.utils import get_open_zmq_inproc_path, make_zmq_socket
+from vllm.utils import get_open_port, get_open_zmq_inproc_path, make_zmq_socket
 from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
-                            EngineCoreRequestType, UtilityOutput)
+                            EngineCoreRequestType,
+                            ReconfigureDistributedRequest, ReconfigureRankType,
+                            UtilityOutput)
 from vllm.v1.engine.coordinator import DPCoordinator
 from vllm.v1.engine.core import EngineCore, EngineCoreProc
 from vllm.v1.engine.exceptions import EngineDeadError
@@ -162,6 +164,9 @@ class EngineCoreClient(ABC):
         running state."""
         raise NotImplementedError

+    async def scale_elastic_ep(self, new_data_parallel_size: int) -> None:
+        raise NotImplementedError
+
     async def get_output_async(self) -> EngineCoreOutputs:
         raise NotImplementedError
@@ -910,14 +915,30 @@ class DPAsyncMPClient(AsyncMPClient):
                events = await poller.poll()
                if not self.engines_running and len(events) == 2 or (
                        events[0][0] == first_req_rcv_socket):
-                    # Send a message to notify the coordinator that
+                    # Check if this is a regular request notification or
+                    # scale up notification
+                    buf = first_req_rcv_socket.recv(
+                        flags=zmq.NOBLOCK).result()
+
+                    decoded = msgspec.msgpack.decode(buf)
+                    if isinstance(
+                            decoded,
+                            (list, tuple)) and len(decoded) == 2 and decoded[
+                                0] == "SCALE_ELASTIC_EP":
+                        # Extract new engine count from the decoded message
+                        new_engine_count = decoded[1]
+                        # Send scale up notification to coordinator
+                        scale_msg = msgspec.msgpack.encode(
+                            ("SCALE_ELASTIC_EP", new_engine_count))
+                        await socket.send(scale_msg)
+                        continue
+
                    # we're sending a request while the engines are
                    # paused, so that it can wake the others up
                    # (to run dummy EP loop).
+                    assert decoded[0] == "FIRST_REQ"
+                    target_eng_index = decoded[1]
                    self.engines_running = True
-                    buf = first_req_rcv_socket.recv(
-                        flags=zmq.NOBLOCK).result()
-                    target_eng_index = int.from_bytes(buf, "little")
                    msg = msgspec.msgpack.encode(
                        (target_eng_index, self.current_wave))
                    await socket.send(msg)
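Both notification kinds now travel over the same socket as msgpack-encoded tagged tuples, so the receiver can branch on the first element. A self-contained sketch of the framing (the payload values are made up):

import msgspec

first_req = msgspec.msgpack.encode(("FIRST_REQ", 3))
scale = msgspec.msgpack.encode(("SCALE_ELASTIC_EP", 8))

tag, payload = msgspec.msgpack.decode(scale)
assert tag == "SCALE_ELASTIC_EP" and payload == 8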
@@ -953,7 +974,8 @@ class DPAsyncMPClient(AsyncMPClient):
                                         chosen_engine)
        if not self.engines_running:
            # Notify coordinator that we're sending a request
-            await self.first_req_send_socket.send(chosen_engine)
+            req_msg = msgspec.msgpack.encode(("FIRST_REQ", chosen_engine))
+            await self.first_req_send_socket.send(req_msg)

        await to_await

@@ -1047,3 +1069,156 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
                             engine: EngineIdentity) -> None:
        await self._send_input(EngineCoreRequestType.ABORT, request_ids,
                               engine)
+
+    async def _send_reconfig_message(
+            self, reconfig_request: ReconfigureDistributedRequest,
+            engine: EngineIdentity) -> asyncio.Future:
+        """Send reconfiguration message and return the result future without
+        waiting for completion."""
+        call_id = uuid.uuid1().int >> 64
+        future = asyncio.get_running_loop().create_future()
+        self.utility_results[call_id] = future
+        message = (EngineCoreRequestType.UTILITY.value, *self.encoder.encode(
+            (self.client_index, call_id, "reinitialize_distributed",
+             (reconfig_request, ))))
+        await self._send_input_message(message, engine, reconfig_request)
+        self._ensure_output_queue_task()
+        return future
+
+    async def scale_elastic_ep(self, new_data_parallel_size: int) -> None:
+        """Scale elastic EP data parallel size"""
+        cur_data_parallel_size = len(self.core_engines)
+
+        assert new_data_parallel_size != cur_data_parallel_size, (
+            f"new_data_parallel_size {new_data_parallel_size} must be "
+            f"different from cur_data_parallel_size {cur_data_parallel_size}")
+
+        assert self.vllm_config.parallel_config.data_parallel_backend == \
+            "ray", ("Only ray DP backend supports scaling elastic EP")
+
+        scale_up = new_data_parallel_size > cur_data_parallel_size
+
+        if scale_up:
+            await self._scale_up_elastic_ep(cur_data_parallel_size,
+                                            new_data_parallel_size)
+        else:
+            await self._scale_down_elastic_ep(cur_data_parallel_size,
+                                              new_data_parallel_size)
+
+    async def _scale_up_elastic_ep(self, cur_data_parallel_size: int,
+                                   new_data_parallel_size: int) -> None:
+        """Scale up the data parallel size by creating new engine cores
+        and reconfiguring existing ones."""
+        cur_data_parallel_size = len(self.core_engines)
+
+        # Phase 1: Send reconfigure messages to all existing engines and wait
+        # for them to be sent
+        reconfig_futures = []
+        self.vllm_config.parallel_config.data_parallel_master_port = \
+            get_open_port()
+        for engine in self.core_engines:
+            reconfig_request = ReconfigureDistributedRequest(
+                new_data_parallel_size=new_data_parallel_size,
+                new_data_parallel_rank=ReconfigureRankType.KEEP_CURRENT_RANK,
+                new_data_parallel_rank_local=\
+                    ReconfigureRankType.KEEP_CURRENT_RANK,
+                new_data_parallel_master_ip=self.vllm_config.parallel_config.
+                data_parallel_master_ip,
+                new_data_parallel_master_port=self.vllm_config.parallel_config.
+                data_parallel_master_port)
+            future = await self._send_reconfig_message(reconfig_request,
+                                                       engine)
+            reconfig_futures.append(future)
+
+        logger.info("All reconfigure messages sent, starting engine creation")
+
+        # Phase 2: Create new engines now that reconfig messages have been sent
+        # self.resources.engine_manager is guaranteed to be
+        # CoreEngineActorManager for RayDPClient
+        assert isinstance(self.resources.engine_manager,
+                          CoreEngineActorManager)
+        self.resources.engine_manager.scale_up_elastic_ep(
+            self.vllm_config, new_data_parallel_size)
+
+        # Create new CoreEngine objects for the new engines
+        new_engine_identities = set()
+        for i in range(cur_data_parallel_size, new_data_parallel_size):
+            new_engine = i.to_bytes(2, "little")
+            self.core_engines.append(new_engine)
+            new_engine_identities.add(new_engine)
+
+        # Wait for ready messages from new engines on the input socket
+        sync_input_socket = zmq.Socket.shadow(self.input_socket)
+        while new_engine_identities:
+            if not sync_input_socket.poll(timeout=600_000):
+                raise TimeoutError(
+                    "Timed out waiting for new engines to send initial "
+                    "message on input socket.")
+            identity, _ = sync_input_socket.recv_multipart()
+            new_engine_identities.discard(identity)
+
+        # Phase 3: Wait for all existing engines to complete reconfiguration
+        logger.info("Waiting for existing engines to complete reconfiguration")
+        await asyncio.gather(*reconfig_futures)
+
+        # Notify coordinator about scale up through existing
+        # stats_update_task connection
+        self._ensure_stats_update_task()
+        scale_up_marker = msgspec.msgpack.encode(
+            ("SCALE_ELASTIC_EP", new_data_parallel_size))
+        await self.first_req_send_socket.send(scale_up_marker)
+
+        # Update the parallel config
+        self.vllm_config.parallel_config.data_parallel_size = \
+            new_data_parallel_size
+        logger.info(
+            "[Elastic EP] Scale up completed, new data parallel size: %s",
+            new_data_parallel_size)
+
+    async def _scale_down_elastic_ep(self, cur_data_parallel_size: int,
+                                     new_data_parallel_size: int) -> None:
+        """Scale down the data parallel size by shutting down and
+        reconfiguring existing engine cores."""
+        cur_data_parallel_size = len(self.core_engines)
+
+        self.vllm_config.parallel_config.data_parallel_master_port = \
+            get_open_port()
+
+        reconfig_futures = []
+        for cur_dp_rank, engine in enumerate(self.core_engines):
+            reconfig_request = ReconfigureDistributedRequest(
+                new_data_parallel_size=new_data_parallel_size,
+                new_data_parallel_rank=ReconfigureRankType.KEEP_CURRENT_RANK,
+                new_data_parallel_rank_local=\
+                    ReconfigureRankType.KEEP_CURRENT_RANK,
+                new_data_parallel_master_ip=self.vllm_config.parallel_config.
+                data_parallel_master_ip,
+                new_data_parallel_master_port=self.vllm_config.parallel_config.
+                data_parallel_master_port)
+            if cur_dp_rank >= new_data_parallel_size:
+                reconfig_request.new_data_parallel_rank = \
+                    ReconfigureRankType.SHUTDOWN_CURRENT_RANK
+            future = await self._send_reconfig_message(reconfig_request,
+                                                       engine)
+            reconfig_futures.append(future)
+
+        for _ in range(new_data_parallel_size, cur_data_parallel_size):
+            self.core_engines.pop()
+
+        await asyncio.gather(*reconfig_futures)
+
+        assert isinstance(self.resources.engine_manager,
+                          CoreEngineActorManager)
+        self.resources.engine_manager.scale_down_elastic_ep(
+            cur_data_parallel_size, new_data_parallel_size)
+
+        self._ensure_stats_update_task()
+        scale_down_marker = msgspec.msgpack.encode(
+            ("SCALE_ELASTIC_EP", new_data_parallel_size))
+        await self.first_req_send_socket.send(scale_down_marker)
+
+        self.vllm_config.parallel_config.data_parallel_size = \
+            new_data_parallel_size
+        logger.info(
+            "[Elastic EP] Scale down completed, new data parallel size: %s",
+            new_data_parallel_size)
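_send_reconfig_message above reuses the client's utility-call plumbing: each request is keyed by a call id whose future is resolved later by the output-processing task, which is what lets Phase 1 fire messages to every engine before anything blocks on a reply. A minimal sketch of that pattern, with hypothetical names:

import asyncio
import uuid

utility_results: dict[int, asyncio.Future] = {}

def register_utility_call() -> tuple[int, asyncio.Future]:
    call_id = uuid.uuid1().int >> 64  # same id scheme as above
    future = asyncio.get_running_loop().create_future()
    utility_results[call_id] = future
    return call_id, future

# The output reader later completes it:
#     utility_results.pop(call_id).set_result(result)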
@@ -174,16 +174,21 @@ class CoreEngineActorManager:
        self.local_engine_actors: list[ray.ActorHandle] = []
        self.remote_engine_actors: list[ray.ActorHandle] = []
+
+        env_vars_list = get_env_vars_to_copy(destination="DPEngineCoreActor")
+        self.env_vars_dict = {
+            name: os.environ[name]
+            for name in env_vars_list if name in os.environ
+        }
+        runtime_env = RuntimeEnv(env_vars=self.env_vars_dict)
+
        self.addresses = addresses
        self.executor_class = executor_class
        self.log_stats = log_stats
        dp_size = vllm_config.parallel_config.data_parallel_size
        local_engine_count = \
            vllm_config.parallel_config.data_parallel_size_local
        world_size = vllm_config.parallel_config.world_size
-        env_vars_set = get_env_vars_to_copy(destination="DPEngineCoreActor")
-        env_vars_dict = {
-            name: os.environ[name]
-            for name in env_vars_set if name in os.environ
-        }
-        runtime_env = RuntimeEnv(env_vars=env_vars_dict)

        if ray.is_initialized():
            logger.info(
@@ -208,6 +213,7 @@ class CoreEngineActorManager:
        assert len(placement_groups) == dp_size, (
            "Number of placement groups must match data parallel size")

+        self.placement_group_is_local = []
        refs = []
        for index in range(dp_size):
            local_index = local_dp_ranks[index]
@@ -231,6 +237,7 @@ class CoreEngineActorManager:
                self.local_engine_actors.append(actor)
            else:
                self.remote_engine_actors.append(actor)
+            self.placement_group_is_local.append(local_client)
            refs.append(actor.wait_for_init.remote())

        ray.get(refs)
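The new placement_group_is_local list stays index-aligned with created_placement_groups and the actor lists, so scale-down (scale_down_elastic_ep further below) can simply pop from the tails. A toy illustration with hypothetical values:

created_placement_groups = ["pg0", "pg1", "pg2"]
placement_group_is_local = [True, False, False]
local_engine_actors, remote_engine_actors = ["a0"], ["a1", "a2"]

pg = created_placement_groups.pop()
if placement_group_is_local.pop():
    local_engine_actors.pop()
else:
    remote_engine_actors.pop()
# pg can then be released via ray.util.remove_placement_group(pg).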
@@ -242,6 +249,9 @@ class CoreEngineActorManager:
    def create_dp_placement_groups(
            vllm_config: VllmConfig
    ) -> tuple[list["PlacementGroup"], list[int]]:
+        """
+        Create placement groups for data parallel.
+        """

        import ray
        from ray._private.state import available_resources_per_node
@@ -250,10 +260,11 @@ class CoreEngineActorManager:
        logger.info("Creating placement groups for data parallel")
        dp_master_ip = \
            vllm_config.parallel_config.data_parallel_master_ip
-        dp_size = vllm_config.parallel_config.data_parallel_size
+        num_pg_to_create = vllm_config.parallel_config.data_parallel_size
        local_engine_count = \
            vllm_config.parallel_config.data_parallel_size_local

-        nodes = list_nodes()
+        nodes = sorted(list_nodes(),
+                       key=lambda node: node.node_ip != dp_master_ip)
        assert nodes[0].node_ip == dp_master_ip, (
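The sorted() call relies on False ordering before True, which moves the head node to index 0 without a custom comparator. For example:

ips = ["10.0.0.2", "10.0.0.1", "10.0.0.3"]
head = "10.0.0.1"
assert sorted(ips, key=lambda ip: ip != head)[0] == head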
@@ -293,7 +304,7 @@ class CoreEngineActorManager:
                    local_dp_ranks.append(i)
            else:
                for i in range(available_engine_count):
-                    if len(placement_groups) == dp_size:
+                    if len(placement_groups) == num_pg_to_create:
                        break
                    bundles = [{"GPU": 1.0}] * world_size + [{"CPU": 1.0}]
                    pg = ray.util.placement_group(
@@ -305,6 +316,204 @@ class CoreEngineActorManager:
                local_dp_ranks.append(i)
        return placement_groups, local_dp_ranks

+    @staticmethod
+    def add_dp_placement_groups(
+            old_vllm_config: VllmConfig, new_data_parallel_size: int
+    ) -> tuple[list["PlacementGroup"], list[int]]:
+        """
+        Add placement groups for new data parallel size.
+        """
+        import ray
+        from ray._private.state import (available_resources_per_node,
+                                        total_resources_per_node)
+        from ray.util.state import list_nodes
+
+        old_dp_size = old_vllm_config.parallel_config.data_parallel_size
+        num_pg_to_create = new_data_parallel_size - old_dp_size
+
+        if num_pg_to_create <= 0:
+            return [], []
+
+        dp_master_ip = old_vllm_config.parallel_config.data_parallel_master_ip
+        world_size = old_vllm_config.parallel_config.world_size
+
+        nodes = list_nodes()
+        nodes = sorted(nodes, key=lambda node: node.node_ip != dp_master_ip)
+        assert nodes[0].node_ip == dp_master_ip, (
+            "The first node must be the head node")
+        assert len(nodes) == 1 or nodes[1].node_ip != dp_master_ip, (
+            "There can only be one head node")
+
+        available_resources = available_resources_per_node()
+        total_resources = total_resources_per_node()
+
+        placement_groups = []
+        local_dp_ranks = []
+        num_pg_created = 0
+
+        for node in nodes:
+            if num_pg_created >= num_pg_to_create:
+                break
+
+            node_ip = node.node_ip
+            node_id = node.node_id
+            available_gpus = int(available_resources[node_id]["GPU"])
+
+            # Get total GPUs on this node from the node's resources
+            # Ray stores node resources with node ID as key
+            total_gpus = int(total_resources[node_id]["GPU"])
+
+            # Calculate used GPUs and used engines on this node
+            used_gpus = max(0, total_gpus - available_gpus)
+            used_engines_on_node = used_gpus // world_size
+
+            # Calculate how many new engines this node can accommodate
+            available_engine_count = available_gpus // world_size
+
+            # Create placement groups for new engines on this node
+            for i in range(available_engine_count):
+                if num_pg_created >= num_pg_to_create:
+                    break
+
+                rank = old_dp_size + num_pg_created
+
+                # Create bundles with node constraint for master node
+                if node_ip == dp_master_ip:
+                    bundles = [{
+                        "GPU": 1.0,
+                        "node:" + dp_master_ip: 0.001
+                    }] * world_size + [{
+                        "CPU": 1.0
+                    }]
+                else:
+                    bundles = [{"GPU": 1.0}] * world_size + [{"CPU": 1.0}]
+
+                pg = ray.util.placement_group(
+                    name=f"dp_rank_{rank}",
+                    strategy="STRICT_PACK",
+                    bundles=bundles,
+                )
+                placement_groups.append(pg)
+
+                # Local rank starts from the number of engines already used
+                # on this node
+                local_rank = used_engines_on_node + i
+                local_dp_ranks.append(local_rank)
+                num_pg_created += 1
+
+        return placement_groups, local_dp_ranks
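The per-node accounting above derives everything from GPU counts: engines already placed and room for new ones are both integer divisions by world_size. With illustrative numbers (an 8-GPU node, 4 GPUs per engine, 4 GPUs free):

total_gpus, available_gpus, world_size = 8, 4, 4
used_gpus = max(0, total_gpus - available_gpus)        # 4
used_engines_on_node = used_gpus // world_size         # 1 engine already here
available_engine_count = available_gpus // world_size  # room for 1 more
# A new engine placed here gets local rank used_engines_on_node + i.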
+    def scale_up_elastic_ep(self, cur_vllm_config: VllmConfig,
+                            new_data_parallel_size: int) -> None:
+        import copy
+
+        import ray
+        from ray.runtime_env import RuntimeEnv
+        from ray.util.scheduling_strategies import (
+            PlacementGroupSchedulingStrategy)
+
+        from vllm.v1.engine.core import DPEngineCoreActor
+
+        cur_data_parallel_size = len(self.local_engine_actors) + \
+            len(self.remote_engine_actors)
+
+        assert new_data_parallel_size > cur_data_parallel_size, (
+            f"New data parallel size {new_data_parallel_size} must be greater "
+            f"than current data parallel size {cur_data_parallel_size} "
+            "for scale up")
+
+        placement_groups, local_dp_ranks = \
+            self.add_dp_placement_groups(
+                cur_vllm_config, new_data_parallel_size)
+
+        world_size = cur_vllm_config.parallel_config.world_size
+        dp_master_ip = cur_vllm_config.parallel_config.data_parallel_master_ip
+        new_local_engines = 0
+
+        runtime_env = RuntimeEnv(env_vars=self.env_vars_dict
+                                 | {"VLLM_ELASTIC_EP_SCALE_UP_LAUNCH": "1"})
+        for i, (pg,
+                local_rank) in enumerate(zip(placement_groups,
+                                             local_dp_ranks)):
+            rank = cur_data_parallel_size + i
+            dp_vllm_config = copy.deepcopy(cur_vllm_config)
+            dp_vllm_config.parallel_config.data_parallel_size = \
+                new_data_parallel_size
+            dp_vllm_config.parallel_config.placement_group = pg
+
+            # Check if this placement group is on the head node
+            local_client = any(
+                bundle.get("node:" + dp_master_ip, 0) > 0
+                for bundle in pg.bundle_specs)
+
+            if local_client:
+                new_local_engines += 1
+                # Update data_parallel_size_local
+                dp_vllm_config.parallel_config.data_parallel_size_local = (
+                    cur_vllm_config.parallel_config.data_parallel_size_local +
+                    new_local_engines)
+
+            actor = ray.remote(DPEngineCoreActor).options(
+                scheduling_strategy=PlacementGroupSchedulingStrategy(
+                    placement_group=pg,
+                    placement_group_bundle_index=world_size,
+                ),
+                runtime_env=runtime_env).remote(
+                    vllm_config=dp_vllm_config,
+                    executor_class=self.executor_class,
+                    log_stats=self.log_stats,
+                    local_client=local_client,
+                    addresses=self.addresses,
+                    dp_rank=rank,
+                    local_dp_rank=local_rank)
+
+            if local_client:
+                self.local_engine_actors.append(actor)
+            else:
+                self.remote_engine_actors.append(actor)
+            self.created_placement_groups.append(pg)
+            self.placement_group_is_local.append(local_client)
+
+        ray.get([
+            actor.wait_for_init.remote()
+            for actor in (self.local_engine_actors[-new_local_engines:]
+                          if new_local_engines > 0 else []) +
+            self.remote_engine_actors[-(len(placement_groups) -
+                                        new_local_engines):]
+        ])
+
+        actors = (self.local_engine_actors[-new_local_engines:]
+                  if new_local_engines > 0 else []) + \
+            self.remote_engine_actors[-(len(placement_groups) -
+                                        new_local_engines):]
+
+        for actor in actors:
+            self.run_refs.append(actor.run.remote())
+
+        cur_vllm_config.parallel_config.data_parallel_size = \
+            new_data_parallel_size
+        # Update cur_vllm_config with the new data_parallel_size_local if
+        # any new local engines were added
+        if new_local_engines > 0:
+            cur_vllm_config.parallel_config.data_parallel_size_local += \
+                new_local_engines
+
+    def scale_down_elastic_ep(self, cur_data_parallel_size: int,
+                              new_data_parallel_size: int) -> None:
+        import ray
+        assert cur_data_parallel_size > new_data_parallel_size, (
+            f"cur_data_parallel_size {cur_data_parallel_size} must be greater "
+            f"than new_data_parallel_size {new_data_parallel_size} "
+            "for scale down")
+        for _ in range(cur_data_parallel_size - new_data_parallel_size):
+            pg = self.created_placement_groups.pop()
+            is_local = self.placement_group_is_local.pop()
+            if is_local:
+                self.local_engine_actors.pop()
+            else:
+                self.remote_engine_actors.pop()
+            ray.util.remove_placement_group(pg)
+
+    def get_run_refs(self):
+        return self.run_refs
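Scale-up actors are launched with one extra environment variable layered over the copied parent environment; the `|` above is plain Python 3.9+ dict merging. A sketch with a made-up inherited variable:

env_vars_dict = {"HF_HOME": "/models"}  # copied from the parent process
launch_env = env_vars_dict | {"VLLM_ELASTIC_EP_SCALE_UP_LAUNCH": "1"}
# Workers read this flag in load_model() (see the Worker hunk below) to
# receive EPLB state instead of profiling from scratch.
assert launch_env["VLLM_ELASTIC_EP_SCALE_UP_LAUNCH"] == "1"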
@@ -6,6 +6,7 @@ from typing import Union

from vllm.executor.ray_distributed_executor import (  # noqa
    RayDistributedExecutor as RayDistributedExecutorV0)
+from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.executor.abstract import Executor
from vllm.v1.outputs import ModelRunnerOutput
|
||||
# When PP is used, we return a FutureWrapper immediately so that
|
||||
# the scheduler can yield to the next batch.
|
||||
return FutureWrapper(refs[0])
|
||||
|
||||
def reinitialize_distributed(
|
||||
self, reconfig_request: ReconfigureDistributedRequest) -> None:
|
||||
self._run_workers("reinitialize_distributed", reconfig_request)
|
||||
if reconfig_request.new_data_parallel_rank == \
|
||||
ReconfigureRankType.SHUTDOWN_CURRENT_RANK:
|
||||
self.shutdown()
|
||||
return
|
||||
|
||||
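The rank fields double as control flags via sentinel values, which is why the executor compares against ReconfigureRankType before deciding to shut down. A hypothetical mirror of the enum (the exact values are an assumption here; real ranks are non-negative, so negative sentinels cannot collide with them):

import enum

class ReconfigureRankTypeSketch(enum.IntEnum):
    KEEP_CURRENT_RANK = -1       # leave the rank untouched
    SHUTDOWN_CURRENT_RANK = -2   # this rank exits after reconfiguration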
@@ -49,7 +49,7 @@ class CPUModelRunner(GPUModelRunner):
            if k.endswith("_cpu") and isinstance(v, torch.Tensor):
                replace_tensor(self.input_batch.block_table, k, k[:-4])

-    def load_model(self) -> None:
+    def load_model(self, eep_scale_up: bool = False) -> None:
        logger.info("Starting to load model %s...", self.model_config.model)
        self.model = get_model(vllm_config=self.vllm_config)

@@ -1745,8 +1745,40 @@ class GPUModelRunner(LoRAModelRunnerMixin):
            new_config = update_config(config, config_overrides)
            setattr(self, config_name, new_config)

-    def load_model(self) -> None:
+    def load_model(self, eep_scale_up: bool = False) -> None:
+        """
+        Args:
+            eep_scale_up: whether the model loading is for elastic EP
+                scale-up.
+        """
        logger.info("Starting to load model %s...", self.model_config.model)
+        if eep_scale_up:
+            from vllm.distributed.parallel_state import get_ep_group
+            num_local_physical_experts = torch.empty(1,
+                                                     dtype=torch.int32,
+                                                     device="cpu")
+            torch.distributed.broadcast(num_local_physical_experts,
+                                        group=get_ep_group().cpu_group,
+                                        group_src=0)
+            num_local_physical_experts = int(num_local_physical_experts.item())
+            new_ep_size = get_ep_group().world_size
+            global_expert_load, old_global_expert_indices = (
+                EplbState.recv_state())
+            num_logical_experts = global_expert_load.shape[1]
+            self.parallel_config.num_redundant_experts = (
+                num_local_physical_experts * new_ep_size - num_logical_experts)
+            assert old_global_expert_indices.shape[
+                1] % num_local_physical_experts == 0
+            old_ep_size = old_global_expert_indices.shape[
+                1] // num_local_physical_experts
+            rank_mapping = {
+                old_ep_rank: old_ep_rank
+                for old_ep_rank in range(old_ep_size)
+            }
+        else:
+            global_expert_load = None
+            old_global_expert_indices = None
+            rank_mapping = None

        with DeviceMemoryProfiler() as m:  # noqa: SIM117
            time_before_load = time.perf_counter()
            model_loader = get_model_loader(self.load_config)
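The num_redundant_experts computation above is simple bookkeeping: total physical expert slots across the new EP group minus the model's logical experts. With illustrative numbers:

num_local_physical_experts = 9   # slots per rank, broadcast from rank 0
new_ep_size = 8
num_logical_experts = 64         # from global_expert_load.shape[1]
num_redundant_experts = (num_local_physical_experts * new_ep_size
                         - num_logical_experts)
assert num_redundant_experts == 8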
@@ -1788,6 +1820,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                self.model,
                self.device,
                self.parallel_config,
+                global_expert_load,
+                old_global_expert_indices,
+                rank_mapping,
            )

    def save_tensorized_model(
@@ -26,6 +26,7 @@ from vllm.platforms import current_platform
from vllm.pooling_params import PoolingTask
from vllm.sequence import IntermediateTensors
from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling
+from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
from vllm.v1.utils import report_usage_stats
@@ -191,8 +192,9 @@ class Worker(WorkerBase):
        else:
            from contextlib import nullcontext
            context = nullcontext()
+        eep_scale_up = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1"
        with context:
-            self.model_runner.load_model()
+            self.model_runner.load_model(eep_scale_up=eep_scale_up)

    def update_config(self, overrides: dict[str, Any]) -> None:
        self.model_runner.update_config(overrides)
@@ -384,6 +386,161 @@ class Worker(WorkerBase):
        # worker will always be healthy as long as it's running.
        return

+    def _eplb_before_scale_down(self, old_ep_size: int,
+                                new_ep_size: int) -> None:
+        from vllm.distributed.parallel_state import get_ep_group
+        if get_ep_group().rank == 0:
+            logger.info("[Elastic EP] Starting expert resharding "
+                        "before scaling down...")
+        rank_mapping = {
+            old_ep_rank: old_ep_rank if old_ep_rank < new_ep_size else -1
+            for old_ep_rank in range(old_ep_size)
+        }
+        assert self.model_runner.eplb_state is not None
+        self.model_runner.eplb_state.rearrange(self.model_runner.model,
+                                               execute_shuffle=True,
+                                               global_expert_load=None,
+                                               rank_mapping=rank_mapping)
+        torch.cuda.synchronize()
+        if get_ep_group().rank == 0:
+            logger.info("[Elastic EP] Expert resharding completed!")
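The rank_mapping dict is the whole contract between the worker and EPLB here: surviving ranks map to themselves and departing ranks map to -1, so experts can be shuffled off the doomed ranks before the group shrinks. A standalone sketch:

def make_rank_mapping(old_ep_size: int, new_ep_size: int) -> dict[int, int]:
    # Ranks below new_ep_size survive; the rest are marked -1 for removal.
    return {
        old_ep_rank: old_ep_rank if old_ep_rank < new_ep_size else -1
        for old_ep_rank in range(old_ep_size)
    }

assert make_rank_mapping(4, 2) == {0: 0, 1: 1, 2: -1, 3: -1}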
+
+    def _eplb_after_scale_up(
+            self, old_ep_size: int, new_ep_size: int,
+            global_expert_load: Optional[torch.Tensor]) -> None:
+        from vllm.distributed.parallel_state import get_ep_group
+        if get_ep_group().rank == 0:
+            logger.info("[Elastic EP] Starting expert resharding "
+                        "after scaling up...")
+        rank_mapping = {
+            old_ep_rank: old_ep_rank
+            for old_ep_rank in range(old_ep_size)
+        }
+        assert self.model_runner.eplb_state is not None
+        self.model_runner.eplb_state.rearrange(
+            self.model_runner.model,
+            execute_shuffle=True,
+            global_expert_load=global_expert_load,
+            rank_mapping=rank_mapping)
+        if get_ep_group().rank == 0:
+            logger.info("[Elastic EP] Expert resharding completed!")
+
+    def _reconfigure_parallel_config(
+            self, reconfig_request: ReconfigureDistributedRequest) -> None:
+        """
+        Update the parallel config with the provided reconfig_request.
+        """
+        parallel_config = self.vllm_config.parallel_config
+        parallel_config.data_parallel_size = \
+            reconfig_request.new_data_parallel_size
+        if reconfig_request.new_data_parallel_rank != \
+                ReconfigureRankType.KEEP_CURRENT_RANK:
+            parallel_config.data_parallel_rank = \
+                reconfig_request.new_data_parallel_rank
+        if reconfig_request.new_data_parallel_rank_local != \
+                ReconfigureRankType.KEEP_CURRENT_RANK:
+            parallel_config.data_parallel_rank_local = \
+                reconfig_request.new_data_parallel_rank_local
+        parallel_config.data_parallel_master_ip = \
+            reconfig_request.new_data_parallel_master_ip
+        parallel_config.data_parallel_master_port = \
+            reconfig_request.new_data_parallel_master_port
+
+    def _reconfigure_moe(self, old_ep_size: int,
+                         new_ep_size: int) -> Optional[torch.Tensor]:
+        """
+        Reconfigure the MoE modules for the new EP size.
+
+        Returns the global expert load if new_ep_size > old_ep_size,
+        otherwise None.
+        """
+        from vllm.distributed.parallel_state import (
+            get_dp_group, get_ep_group, prepare_communication_buffer_for_model)
+        from vllm.model_executor.layers.fused_moe.layer import (
+            FusedMoEParallelConfig)
+
+        parallel_config = self.vllm_config.parallel_config
+        moe_modules = [
+            module for module in self.model_runner.model.modules()
+            if module.__class__.__name__ == "FusedMoE"
+        ]
+        num_local_experts = moe_modules[0].moe_config.num_local_experts
+        assert all(module.moe_config.num_local_experts == num_local_experts
+                   for module in moe_modules), (
+                       "All MoE modules must have the same number of experts")
+        for module in moe_modules:
+            module.moe_config.num_experts = num_local_experts * new_ep_size
+            module.global_num_experts = module.moe_config.num_experts
+            module.moe_parallel_config = FusedMoEParallelConfig.make(
+                tp_size_=get_tp_group().world_size,
+                dp_size_=get_dp_group().world_size,
+                vllm_parallel_config=parallel_config,
+            )
+            module.moe_config.moe_parallel_config = module.moe_parallel_config
+        if new_ep_size < old_ep_size:
+            num_local_physical_experts = num_local_experts
+            assert self.model_runner.eplb_state is not None
+            new_physical_experts = \
+                self.model_runner.eplb_state.physical_to_logical_map.shape[1]
+            parallel_config.num_redundant_experts = (
+                new_physical_experts -
+                self.model_runner.eplb_state.logical_replica_count.shape[1])
+            global_expert_load = None
+        else:
+            num_local_physical_experts = torch.tensor([num_local_experts],
+                                                      dtype=torch.int32,
+                                                      device="cpu")
+            torch.distributed.broadcast(num_local_physical_experts,
+                                        group=get_ep_group().cpu_group,
+                                        group_src=0)
+            num_local_physical_experts = num_local_physical_experts.item()
+            new_physical_experts = num_local_physical_experts * new_ep_size
+            assert self.model_runner.eplb_state is not None
+            global_expert_load = self.model_runner.eplb_state.rearrange(
+                self.model_runner.model, execute_shuffle=False)
+            parallel_config.num_redundant_experts = (
+                new_physical_experts - global_expert_load.shape[1])
+        prepare_communication_buffer_for_model(self.model_runner.model)
+        self.model_runner.model.update_physical_experts_metadata(
+            num_physical_experts=new_physical_experts,
+            num_local_physical_experts=num_local_physical_experts)
+        return global_expert_load
+
+    def reinitialize_distributed(
+            self, reconfig_request: ReconfigureDistributedRequest) -> None:
+        from vllm.config import set_current_vllm_config
+        from vllm.distributed.parallel_state import (
+            cleanup_dist_env_and_memory, get_ep_group)
+
+        old_ep_size = get_ep_group().world_size
+        old_ep_rank = get_ep_group().rank
+        new_ep_size = reconfig_request.new_data_parallel_size * get_tp_group(
+        ).world_size * get_pp_group().world_size
+        if new_ep_size < old_ep_size:
+            self._eplb_before_scale_down(old_ep_size, new_ep_size)
+
+        cleanup_dist_env_and_memory()
+
+        if reconfig_request.new_data_parallel_rank == \
+                ReconfigureRankType.SHUTDOWN_CURRENT_RANK:
+            assert old_ep_rank >= new_ep_size
+            # This rank is being shut down; nothing left to reinitialize.
+            return
+
+        self._reconfigure_parallel_config(reconfig_request)
+
+        with set_current_vllm_config(self.vllm_config):
+            init_worker_distributed_environment(self.vllm_config, self.rank,
+                                                self.distributed_init_method,
+                                                self.local_rank)
+
+        global_expert_load = self._reconfigure_moe(old_ep_size, new_ep_size)
+
+        if new_ep_size > old_ep_size:
+            assert global_expert_load is not None
+            self._eplb_after_scale_up(old_ep_size, new_ep_size,
+                                      global_expert_load)
+
    def save_sharded_state(
        self,
        path: str,