mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-16 11:37:12 +08:00
[Core] Encoder separation for Encode-Prefill-Decode Disaggregation (#25233)
Signed-off-by: n00909098 <nguyen.kha.long@huawei.com>
Signed-off-by: knlnguyen1802 <knlnguyen1802@gmail.com>
Signed-off-by: herotai214 <herotai214@gmail.com>
Signed-off-by: Khuong Le <khuong.le.manh@huawei.com>
Signed-off-by: Khuong Le <lemanhkhuong2611@gmail.com>
Co-authored-by: n00909098 <nguyen.kha.long@huawei.com>
Co-authored-by: knlnguyen1802 <knlnguyen1802@gmail.com>
Co-authored-by: herotai214 <herotai214@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Khuong Le <khuong.le.manh@huawei.com>
Co-authored-by: Khuong Le <lemanhkhuong2611@gmail.com>
This commit is contained in: parent cbb799e314, commit 4ccffe561f
BIN docs/assets/features/disagg_encoder/disagg_encoder_flow.png (new file; binary file not shown; 84 KiB)
docs/features/disagg_encoder.md (new file, 75 lines)
@@ -0,0 +1,75 @@

# Disaggregated Encoder

A **disaggregated encoder** runs the vision-encoder stage of a multimodal LLM in a process that is separate from the pre-fill/decode stage. Deploying these two stages in independent vLLM instances brings three practical benefits:

1. **Independent, fine-grained scaling**
2. **Lower time-to-first-token (TTFT)**
3. **Cross-process reuse and caching of encoder outputs**

Design doc: <https://docs.google.com/document/d/1aed8KtC6XkXtdoV87pWT0a8OJlZ-CpnuLLzmR8l9BAE>

---

## 1 Motivation

### 1. Independent, fine-grained scaling

* Vision encoders are lightweight, while language models are orders of magnitude larger.
* The language model can be parallelised without affecting the encoder fleet.
* Encoder nodes can be added or removed independently.

### 2. Lower time-to-first-token (TTFT)

* Language-only requests bypass the vision encoder entirely.
* Encoder output is injected only at required attention layers, shortening the pre-fill critical path.

### 3. Cross-process reuse and caching

* In-process encoders confine reuse to a single worker.
* A remote, shared cache lets any worker retrieve existing embeddings, eliminating redundant computation.

---

## 2 Usage Example

The current reference pathway is **SharedStorageConnector**.
The ready-to-run scripts below show the workflow:

1 Encoder instance + 1 PD instance:
`examples/online_serving/disaggregated_encoder/shared_storage_connector/disagg_encoder_example.sh`

1 Encoder instance + 1 Prefill instance + 1 Decode instance:
`examples/online_serving/disaggregated_encoder/shared_storage_connector/disagg_epd_example.sh`

---

## 3 Test Script

Please refer to the directory `tests/v1/ec_connector`.

## 4 Development

Disaggregated encoding is implemented by running two parts:

* **Encoder instance** – a vLLM instance that performs vision encoding.
* **Prefill/Decode (PD) instance(s)** – run language pre-fill and decode.
  * PD can run either as a single combined instance with `disagg_encoder_example.sh` (E->PD) or as disaggregated instances with `disagg_epd_example.sh` (E->P->D).

A connector transfers encoder-cache (EC) embeddings from the encoder instance to the PD instance.
All related code is under `vllm/distributed/ec_transfer`.

### Key abstractions

* **ECConnector** – interface for retrieving EC caches produced by the encoder.
  * *Scheduler role* – checks cache existence and schedules loads.
  * *Worker role* – loads the embeddings into memory.

The figure below illustrates the disaggregated-encoder flow:



For the PD-disaggregation part, the Prefill instance receives the encoder cache exactly as in the disaggregated-encoder flow above. The Prefill instance executes one step (prefill -> one output token) and then transfers the KV cache to the Decode instance for the remaining execution. The KV transfer happens entirely after the Prefill instance finishes executing.

`docs/features/disagg_prefill.md` gives a brief overview of disaggregated prefill (v0).

The example setup uses the **NixlConnector** from `vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py` and follows `tests/v1/kv_connector/nixl_integration/toy_proxy_server.py` to facilitate KV transfer between P and D.
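
The prefill-then-transfer handoff described above is driven by a `kv_transfer_params` field on each request. A rough sketch of the two request shapes, with illustrative helper names (the field values mirror the proxy script included later in this commit):

```python
def build_prefill_request(req_data: dict) -> dict:
    """Prefill-only request: generate 1 token and ask for remote decode."""
    prefill_request = dict(req_data)
    prefill_request["kv_transfer_params"] = {
        "do_remote_decode": True,
        "do_remote_prefill": False,
        "remote_engine_id": None,
        "remote_block_ids": None,
        "remote_host": None,
        "remote_port": None,
    }
    prefill_request["stream"] = False
    prefill_request["max_tokens"] = 1
    return prefill_request


def build_decode_request(req_data: dict, prefill_response: dict) -> dict:
    """Decode request: echo back the KV params the prefill response returned."""
    decode_request = dict(req_data)
    kv = prefill_response.get("kv_transfer_params", {})
    if kv:
        decode_request["kv_transfer_params"] = kv
    return decode_request
```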
examples/online_serving/disaggregated_encoder/README.md (new file, 119 lines)

@@ -0,0 +1,119 @@

# Disaggregated Encoder

These example scripts demonstrate the disaggregated encoder (EPD) feature of vLLM.

For a detailed explanation of the EPD feature, please refer to the [Disaggregated Encoder Feature Documentation](../../../docs/features/disagg_encoder.md).

## Files

- `disagg_epd_proxy.py` - Proxy script that demonstrates the XeYpZd setup (X encode instances, Y prefill instances, Z decode instances). Currently stable for the 1e1p1d configuration.

- `disagg_1e1p1d_example.sh` - Sets up the 1e1p1d configuration, runs the VisionArena benchmark, and processes a single request with a local image.

- `disagg_1e1pd_example.sh` - Sets up the 1e1pd configuration, runs the VisionArena benchmark, and processes a single request with a local image.

### Custom Configuration

```bash
# Use specific GPUs
GPU_E=0 GPU_PD=1 GPU_P=1 GPU_D=2 bash disagg_1e1p1d_example.sh

# Use specific ports
PROXY_PORT=10001 bash disagg_1e1p1d_example.sh

# Use specific model
MODEL="Qwen/Qwen2.5-VL-3B-Instruct" bash disagg_1e1p1d_example.sh

# Use specific storage path
EC_SHARED_STORAGE_PATH="/tmp/my_ec_cache" bash disagg_1e1p1d_example.sh
```

## Encoder Instances

Encoder engines should be launched with the following flags:

- `--enforce-eager` **(required)** – The current EPD implementation is only compatible with encoder instances running in this mode.

- `--no-enable-prefix-caching` **(required)** – Encoder instances do not consume KV cache; prefix caching is disabled to avoid conflicts with other features.

- `--max-num-batched-tokens=<large value>` **(default: 2048)** – This flag controls the token scheduling budget per decoding step and is irrelevant to encoder-only instances. **Set it to a very high value (effectively unlimited) to bypass scheduler limitations.** The actual token budget is managed by the encoder cache manager.

## Local media inputs

To support local image inputs (from your `MEDIA_PATH` directory), add the following flag to the encoder instance:

```bash
--allowed-local-media-path $MEDIA_PATH
```

The vLLM instances and `disagg_epd_proxy` support local URIs such as `{"url": "file://'"$MEDIA_PATH_FILENAME"'"}` as multimodal inputs. Each URI is passed unchanged from the `disagg_epd_proxy` to the encoder instance so that the encoder can load the media locally.
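
For example, a request body with a local image URI might be built like this (the model name and paths are illustrative):

```python
import json

media_path = "/data/media"  # your MEDIA_PATH (illustrative)
filename = "hato.jpg"       # a file under that directory

payload = {
    "model": "Qwen/Qwen2.5-VL-3B-Instruct",
    "messages": [
        {"role": "user", "content": [
            # Local URI: passed through the proxy unchanged, loaded by the encoder.
            {"type": "image_url",
             "image_url": {"url": f"file://{media_path}/{filename}"}},
            {"type": "text", "text": "What is in this image?"},
        ]},
    ],
}
body = json.dumps(payload)
```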

## EC connector and KV transfer

The `ECSharedStorageConnector` is used to store the encoder cache on local disk and facilitate transfer. To enable the encoder disaggregation feature, add the following configuration:

```bash
# Add to encoder instance:
--ec-transfer-config '{
    "ec_connector": "ECSharedStorageConnector",
    "ec_role": "ec_producer",
    "ec_connector_extra_config": {
        "shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'"
    }
}'

# Add to prefill/prefill+decode instance:
--ec-transfer-config '{
    "ec_connector": "ECSharedStorageConnector",
    "ec_role": "ec_consumer",
    "ec_connector_extra_config": {
        "shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'"
    }
}'
```

`$EC_SHARED_STORAGE_PATH` is the path where the EC connector temporarily stores the cache.

If you enable a prefill instance (i.e. `--prefill-servers-urls` is not disabled), you also need `--kv-transfer-config` to facilitate the PD disaggregation. Currently, the `NixlConnector` is used for this purpose. Refer to `tests/v1/kv_connector/nixl_integration` for more example code on PD disaggregation with Nixl.

```bash
# Add to prefill instance:
--kv-transfer-config '{
    "kv_connector": "NixlConnector",
    "kv_role": "kv_producer"
}'

# Add to decode instance:
--kv-transfer-config '{
    "kv_connector": "NixlConnector",
    "kv_role": "kv_consumer"
}'
```

## Proxy Instance Flags (`disagg_epd_proxy.py`)

| Flag | Description |
|------|-------------|
| `--encode-servers-urls` | Comma-separated list of encoder endpoints. Every multimodal item extracted from the request is fanned out to one of these URLs in a round-robin fashion. |
| `--prefill-servers-urls` | Comma-separated list of prefill endpoints. Set to `disable`, `none`, or `""` to skip the dedicated prefill phase and run E+PD (encoder + combined prefill/decode). |
| `--decode-servers-urls` | Comma-separated list of decode endpoints. Non-stream and stream paths both round-robin over this list. |
| `--host`, `--port` | Bind address for the proxy itself (defaults: `0.0.0.0:8000`). |

Example usage:

For E + PD setup:

```bash
$ python disagg_epd_proxy.py \
    --encode-servers-urls "http://e1:8001,http://e2:8002" \
    --prefill-servers-urls "disable" \
    --decode-servers-urls "http://pd1:8003,http://pd2:8004"
```

For E + P + D setup:

```bash
$ python disagg_epd_proxy.py \
    --encode-servers-urls "http://e1:8001,http://e2:8001" \
    --prefill-servers-urls "http://p1:8003,http://p2:8004" \
    --decode-servers-urls "http://d1:8005,http://d2:8006"
```
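
The round-robin fan-out over encoder endpoints can be sketched as follows (an illustrative helper mirroring the URL cycle the proxy builds internally):

```python
def round_robin_targets(urls: list[str], n_items: int) -> list[str]:
    """Assign each of n_items multimodal items an encoder URL, round-robin."""
    return [urls[i % len(urls)] for i in range(n_items)]
```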
@@ -0,0 +1,221 @@
|
||||
#!/bin/bash
|
||||
set -euo pipefail
|
||||
|
||||
declare -a PIDS=()
|
||||
|
||||
###############################################################################
|
||||
# Configuration -- override via env before running
|
||||
###############################################################################
|
||||
MODEL="${MODEL:-Qwen/Qwen2.5-VL-3B-Instruct}"
|
||||
LOG_PATH="${LOG_PATH:-./logs}"
|
||||
mkdir -p "$LOG_PATH"
|
||||
|
||||
ENCODE_PORT="${ENCODE_PORT:-19534}"
|
||||
PREFILL_PORT="${PREFILL_PORT:-19535}"
|
||||
DECODE_PORT="${DECODE_PORT:-19536}"
|
||||
PROXY_PORT="${PROXY_PORT:-10001}"
|
||||
|
||||
GPU_E="${GPU_E:-2}"
|
||||
GPU_P="${GPU_P:-2}"
|
||||
GPU_D="${GPU_D:-3}"
|
||||
|
||||
EC_SHARED_STORAGE_PATH="${EC_SHARED_STORAGE_PATH:-/tmp/ec_cache}"
|
||||
TIMEOUT_SECONDS="${TIMEOUT_SECONDS:-12000}" # wait_for_server timeout
|
||||
|
||||
NUM_PROMPTS="${NUM_PROMPTS:-100}" # number of prompts to send in benchmark
|
||||
|
||||
export UCX_TLS=all
|
||||
export UCX_NET_DEVICES=all
|
||||
|
||||
###############################################################################
|
||||
# Helpers
|
||||
###############################################################################
|
||||
# Find the git repository root directory
|
||||
GIT_ROOT=$(git rev-parse --show-toplevel)
|
||||
|
||||
START_TIME=$(date +"%Y%m%d_%H%M%S")
|
||||
ENC_LOG=$LOG_PATH/encoder_${START_TIME}.log
|
||||
P_LOG=$LOG_PATH/p_${START_TIME}.log
|
||||
D_LOG=$LOG_PATH/d_${START_TIME}.log
|
||||
PROXY_LOG=$LOG_PATH/proxy_${START_TIME}.log
|
||||
|
||||
wait_for_server() {
|
||||
local port=$1
|
||||
timeout "$TIMEOUT_SECONDS" bash -c "
|
||||
until curl -s localhost:$port/v1/chat/completions > /dev/null; do
|
||||
sleep 1
|
||||
done" && return 0 || return 1
|
||||
}
|
||||
|
||||
# Cleanup function
|
||||
cleanup() {
|
||||
echo "Stopping everything…"
|
||||
trap - INT TERM USR1 # prevent re-entrancy
|
||||
|
||||
# Kill all tracked PIDs
|
||||
for pid in "${PIDS[@]}"; do
|
||||
if kill -0 "$pid" 2>/dev/null; then
|
||||
echo "Killing process $pid"
|
||||
kill "$pid" 2>/dev/null
|
||||
fi
|
||||
done
|
||||
|
||||
# Wait a moment for graceful shutdown
|
||||
sleep 2
|
||||
|
||||
# Force kill any remaining processes
|
||||
for pid in "${PIDS[@]}"; do
|
||||
if kill -0 "$pid" 2>/dev/null; then
|
||||
echo "Force killing process $pid"
|
||||
kill -9 "$pid" 2>/dev/null
|
||||
fi
|
||||
done
|
||||
|
||||
# Kill the entire process group as backup
|
||||
kill -- -$$ 2>/dev/null
|
||||
|
||||
echo "All processes stopped."
|
||||
exit 0
|
||||
}
|
||||
|
||||
trap cleanup INT
|
||||
trap cleanup USR1
|
||||
trap cleanup TERM
|
||||
|
||||
# clear previous cache
|
||||
echo "remove previous ec cache folder"
|
||||
rm -rf "$EC_SHARED_STORAGE_PATH"
|
||||
|
||||
echo "make ec cache folder"
|
||||
mkdir -p "$EC_SHARED_STORAGE_PATH"
|
||||
|
||||
###############################################################################
|
||||
# Encoder worker
|
||||
###############################################################################
|
||||
CUDA_VISIBLE_DEVICES="$GPU_E" vllm serve "$MODEL" \
|
||||
--gpu-memory-utilization 0.01 \
|
||||
--port "$ENCODE_PORT" \
|
||||
--enforce-eager \
|
||||
--enable-request-id-headers \
|
||||
--no-enable-prefix-caching \
|
||||
--max-num-batched-tokens 114688 \
|
||||
--max-num-seqs 128 \
|
||||
--allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \
|
||||
--ec-transfer-config '{
|
||||
"ec_connector": "ECSharedStorageConnector",
|
||||
"ec_role": "ec_producer",
|
||||
"ec_connector_extra_config": {
|
||||
"shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'"
|
||||
}
|
||||
}' \
|
||||
>"${ENC_LOG}" 2>&1 &
|
||||
|
||||
PIDS+=($!)
|
||||
|
||||
###############################################################################
|
||||
# Prefill worker
|
||||
###############################################################################
|
||||
CUDA_VISIBLE_DEVICES="$GPU_P" \
|
||||
UCX_NET_DEVICES=all \
|
||||
VLLM_NIXL_SIDE_CHANNEL_PORT=5559 \
|
||||
vllm serve "$MODEL" \
|
||||
--gpu-memory-utilization 0.7 \
|
||||
--port "$PREFILL_PORT" \
|
||||
--enforce-eager \
|
||||
--enable-request-id-headers \
|
||||
--max-num-seqs 128 \
|
||||
--allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \
|
||||
--ec-transfer-config '{
|
||||
"ec_connector": "ECSharedStorageConnector",
|
||||
"ec_role": "ec_consumer",
|
||||
"ec_connector_extra_config": {
|
||||
"shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'"
|
||||
}
|
||||
}' \
|
||||
--kv-transfer-config '{
|
||||
"kv_connector": "NixlConnector",
|
||||
"kv_role": "kv_producer"
|
||||
}' \
|
||||
>"${P_LOG}" 2>&1 &
|
||||
|
||||
PIDS+=($!)
|
||||
|
||||
###############################################################################
|
||||
# Decode worker
|
||||
###############################################################################
|
||||
CUDA_VISIBLE_DEVICES="$GPU_D" \
|
||||
UCX_NET_DEVICES=all \
|
||||
VLLM_NIXL_SIDE_CHANNEL_PORT=6000 \
|
||||
vllm serve "$MODEL" \
|
||||
--gpu-memory-utilization 0.7 \
|
||||
--port "$DECODE_PORT" \
|
||||
--enforce-eager \
|
||||
--enable-request-id-headers \
|
||||
--max-num-seqs 128 \
|
||||
--allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \
|
||||
--kv-transfer-config '{
|
||||
"kv_connector": "NixlConnector",
|
||||
"kv_role": "kv_consumer"
|
||||
}' \
|
||||
>"${D_LOG}" 2>&1 &
|
||||
|
||||
PIDS+=($!)
|
||||
|
||||
# Wait for workers
|
||||
wait_for_server $ENCODE_PORT
|
||||
wait_for_server $PREFILL_PORT
|
||||
wait_for_server $DECODE_PORT
|
||||
|
||||
###############################################################################
|
||||
# Proxy
|
||||
###############################################################################
|
||||
python disagg_epd_proxy.py \
|
||||
--host "0.0.0.0" \
|
||||
--port "$PROXY_PORT" \
|
||||
--encode-servers-urls "http://localhost:$ENCODE_PORT" \
|
||||
--prefill-servers-urls "http://localhost:$PREFILL_PORT" \
|
||||
--decode-servers-urls "http://localhost:$DECODE_PORT" \
|
||||
>"${PROXY_LOG}" 2>&1 &
|
||||
|
||||
PIDS+=($!)
|
||||
|
||||
wait_for_server $PROXY_PORT
|
||||
echo "All services are up!"
|
||||
|
||||
###############################################################################
|
||||
# Benchmark
|
||||
###############################################################################
|
||||
echo "Running benchmark (stream)..."
|
||||
vllm bench serve \
|
||||
--model $MODEL \
|
||||
--backend openai-chat \
|
||||
--endpoint /v1/chat/completions \
|
||||
--dataset-name hf \
|
||||
--dataset-path lmarena-ai/VisionArena-Chat \
|
||||
--seed 0 \
|
||||
--num-prompts $NUM_PROMPTS \
|
||||
--port $PROXY_PORT
|
||||
|
||||
|
||||
|
||||
###############################################################################
|
||||
# Single request with local image
|
||||
###############################################################################
|
||||
echo "Running single request with local image (non-stream)..."
|
||||
curl http://127.0.0.1:${PROXY_PORT}/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "'${MODEL}'",
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": [
|
||||
{"type": "image_url", "image_url": {"url": "file://'"${GIT_ROOT}"'/tests/v1/ec_connector/integration/hato.jpg"}},
|
||||
{"type": "text", "text": "What is in this image?"}
|
||||
]}
|
||||
]
|
||||
}'
|
||||
|
||||
|
||||
# cleanup
|
||||
echo "cleanup..."
|
||||
cleanup
|
||||
@@ -0,0 +1,186 @@
|
||||
#!/bin/bash
|
||||
set -euo pipefail
|
||||
|
||||
declare -a PIDS=()
|
||||
|
||||
###############################################################################
|
||||
# Configuration -- override via env before running
|
||||
###############################################################################
|
||||
MODEL="${MODEL:-Qwen/Qwen2.5-VL-3B-Instruct}"
|
||||
LOG_PATH="${LOG_PATH:-./logs}"
|
||||
mkdir -p "$LOG_PATH"
|
||||
|
||||
ENCODE_PORT="${ENCODE_PORT:-19534}"
|
||||
PREFILL_DECODE_PORT="${PREFILL_DECODE_PORT:-19535}"
|
||||
PROXY_PORT="${PROXY_PORT:-10001}"
|
||||
|
||||
GPU_E="${GPU_E:-0}"
|
||||
GPU_PD="${GPU_PD:-1}"
|
||||
|
||||
EC_SHARED_STORAGE_PATH="${EC_SHARED_STORAGE_PATH:-/tmp/ec_cache}"
|
||||
TIMEOUT_SECONDS="${TIMEOUT_SECONDS:-12000}" # wait_for_server timeout
|
||||
|
||||
NUM_PROMPTS="${NUM_PROMPTS:-100}" # number of prompts to send in benchmark
|
||||
|
||||
###############################################################################
|
||||
# Helpers
|
||||
###############################################################################
|
||||
# Find the git repository root directory
|
||||
GIT_ROOT=$(git rev-parse --show-toplevel)
|
||||
|
||||
START_TIME=$(date +"%Y%m%d_%H%M%S")
|
||||
ENC_LOG=$LOG_PATH/encoder_${START_TIME}.log
|
||||
PD_LOG=$LOG_PATH/pd_${START_TIME}.log
|
||||
PROXY_LOG=$LOG_PATH/proxy_${START_TIME}.log
|
||||
|
||||
wait_for_server() {
|
||||
local port=$1
|
||||
timeout "$TIMEOUT_SECONDS" bash -c "
|
||||
until curl -s localhost:$port/v1/chat/completions > /dev/null; do
|
||||
sleep 1
|
||||
done" && return 0 || return 1
|
||||
}
|
||||
|
||||
# Cleanup function
|
||||
cleanup() {
|
||||
echo "Stopping everything…"
|
||||
trap - INT TERM USR1 # prevent re-entrancy
|
||||
|
||||
# Kill all tracked PIDs
|
||||
for pid in "${PIDS[@]}"; do
|
||||
if kill -0 "$pid" 2>/dev/null; then
|
||||
echo "Killing process $pid"
|
||||
kill "$pid" 2>/dev/null
|
||||
fi
|
||||
done
|
||||
|
||||
# Wait a moment for graceful shutdown
|
||||
sleep 2
|
||||
|
||||
# Force kill any remaining processes
|
||||
for pid in "${PIDS[@]}"; do
|
||||
if kill -0 "$pid" 2>/dev/null; then
|
||||
echo "Force killing process $pid"
|
||||
kill -9 "$pid" 2>/dev/null
|
||||
fi
|
||||
done
|
||||
|
||||
# Kill the entire process group as backup
|
||||
kill -- -$$ 2>/dev/null
|
||||
|
||||
echo "All processes stopped."
|
||||
exit 0
|
||||
}
|
||||
|
||||
trap cleanup INT
|
||||
trap cleanup USR1
|
||||
trap cleanup TERM
|
||||
|
||||
# clear previous cache
|
||||
echo "remove previous ec cache folder"
|
||||
rm -rf "$EC_SHARED_STORAGE_PATH"
|
||||
|
||||
echo "make ec cache folder"
|
||||
mkdir -p "$EC_SHARED_STORAGE_PATH"
|
||||
|
||||
###############################################################################
|
||||
# Encoder worker
|
||||
###############################################################################
|
||||
CUDA_VISIBLE_DEVICES="$GPU_E" vllm serve "$MODEL" \
|
||||
--gpu-memory-utilization 0.01 \
|
||||
--port "$ENCODE_PORT" \
|
||||
--enforce-eager \
|
||||
--enable-request-id-headers \
|
||||
--no-enable-prefix-caching \
|
||||
--max-num-batched-tokens 114688 \
|
||||
--max-num-seqs 128 \
|
||||
--allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \
|
||||
--ec-transfer-config '{
|
||||
"ec_connector": "ECSharedStorageConnector",
|
||||
"ec_role": "ec_producer",
|
||||
"ec_connector_extra_config": {
|
||||
"shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'"
|
||||
}
|
||||
}' \
|
||||
>"${ENC_LOG}" 2>&1 &
|
||||
|
||||
PIDS+=($!)
|
||||
|
||||
###############################################################################
|
||||
# Prefill+Decode worker
|
||||
###############################################################################
|
||||
CUDA_VISIBLE_DEVICES="$GPU_PD" vllm serve "$MODEL" \
|
||||
--gpu-memory-utilization 0.7 \
|
||||
--port "$PREFILL_DECODE_PORT" \
|
||||
--enforce-eager \
|
||||
--enable-request-id-headers \
|
||||
--max-num-seqs 128 \
|
||||
--allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \
|
||||
--ec-transfer-config '{
|
||||
"ec_connector": "ECSharedStorageConnector",
|
||||
"ec_role": "ec_consumer",
|
||||
"ec_connector_extra_config": {
|
||||
"shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'"
|
||||
}
|
||||
}' \
|
||||
>"${PD_LOG}" 2>&1 &
|
||||
|
||||
PIDS+=($!)
|
||||
|
||||
# Wait for workers
|
||||
wait_for_server $ENCODE_PORT
|
||||
wait_for_server $PREFILL_DECODE_PORT
|
||||
|
||||
###############################################################################
|
||||
# Proxy
|
||||
###############################################################################
|
||||
python disagg_epd_proxy.py \
|
||||
--host "0.0.0.0" \
|
||||
--port "$PROXY_PORT" \
|
||||
--encode-servers-urls "http://localhost:$ENCODE_PORT" \
|
||||
--prefill-servers-urls "disable" \
|
||||
--decode-servers-urls "http://localhost:$PREFILL_DECODE_PORT" \
|
||||
>"${PROXY_LOG}" 2>&1 &
|
||||
|
||||
PIDS+=($!)
|
||||
|
||||
wait_for_server $PROXY_PORT
|
||||
echo "All services are up!"
|
||||
|
||||
###############################################################################
|
||||
# Benchmark
|
||||
###############################################################################
|
||||
echo "Running benchmark (stream)..."
|
||||
vllm bench serve \
|
||||
--model $MODEL \
|
||||
--backend openai-chat \
|
||||
--endpoint /v1/chat/completions \
|
||||
--dataset-name hf \
|
||||
--dataset-path lmarena-ai/VisionArena-Chat \
|
||||
--seed 0 \
|
||||
--num-prompts $NUM_PROMPTS \
|
||||
--port $PROXY_PORT
|
||||
|
||||
|
||||
|
||||
###############################################################################
|
||||
# Single request with local image
|
||||
###############################################################################
|
||||
echo "Running single request with local image (non-stream)..."
|
||||
curl http://127.0.0.1:${PROXY_PORT}/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "'${MODEL}'",
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": [
|
||||
{"type": "image_url", "image_url": {"url": "file://'"${GIT_ROOT}"'/tests/v1/ec_connector/integration/hato.jpg"}},
|
||||
{"type": "text", "text": "What is in this image?"}
|
||||
]}
|
||||
]
|
||||
}'
|
||||
|
||||
|
||||
# cleanup
|
||||
echo "cleanup..."
|
||||
cleanup
|
||||
@@ -0,0 +1,606 @@
|
||||
#!/usr/bin/env python3
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
disagg_encoder_proxy.py
|
||||
|
||||
Proxy that routes OpenAI-compatible "/v1/chat/completions" requests to two
|
||||
clusters:
|
||||
• encode (multimodal feature extraction)
|
||||
• decode (language-model inference)
|
||||
|
||||
For MM input we:
|
||||
1. Extract *every* image/audio item.
|
||||
2. Fire N concurrent requests to the encoder cluster
|
||||
(one request per item, with **all text removed**).
|
||||
3. Wait for all of them to succeed.
|
||||
4. Forward the *original* request to a decode server.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import uuid
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
import aiohttp
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
|
||||
###############################################################################
|
||||
# FastAPI app & global state
|
||||
###############################################################################
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG, format="%(asctime)s %(levelname)s: %(message)s"
|
||||
)
|
||||
logger = logging.getLogger("proxy")
|
||||
|
||||
app = FastAPI()
|
||||
encode_session: aiohttp.ClientSession | None = None
|
||||
prefill_session: aiohttp.ClientSession | None = None
|
||||
decode_session: aiohttp.ClientSession | None = None
|
||||
|
||||
###############################################################################
|
||||
# Utils
|
||||
###############################################################################
|
||||
|
||||
|
||||
MM_TYPES = {"image_url", "audio_url", "input_audio"}
|
||||
|
||||
|
||||
def extract_mm_items(request_data: dict) -> list[dict]:
|
||||
"""
|
||||
Return *all* image/audio items that appear anywhere in `messages`.
|
||||
|
||||
Each returned dict looks like:
|
||||
{ "type": "image_url", "image_url": {...} }
|
||||
"""
|
||||
items: list[dict] = []
|
||||
for msg in request_data.get("messages", []):
|
||||
content = msg.get("content")
|
||||
if not isinstance(content, list):
|
||||
continue
|
||||
|
||||
for item in content:
|
||||
if item.get("type") in MM_TYPES:
|
||||
items.append(item)
|
||||
return items
|
||||
|
||||
|
||||
async def fanout_encoder_primer(
|
||||
orig_request: dict,
|
||||
e_urls: list[str],
|
||||
req_id: str,
|
||||
) -> None:
|
||||
"""
|
||||
1. Build one request *per MM item* with all text removed.
|
||||
2. Send them concurrently to the encode cluster.
|
||||
3. Raise if any of them fails.
|
||||
"""
|
||||
logger.info("[%s] Processing multimodal items...", req_id)
|
||||
|
||||
mm_items = extract_mm_items(orig_request)
|
||||
if not mm_items:
|
||||
logger.info("[%s] No multimodal items, skipping encoder", req_id)
|
||||
return # nothing to do
|
||||
|
||||
logger.info("[%s] got %d multimodal items...", req_id, len(mm_items))
|
||||
|
||||
tasks = []
|
||||
|
||||
# Round-robin over encode servers to distribute load a bit
|
||||
url_cycle = (e_urls[i % len(e_urls)] for i in range(len(mm_items)))
|
||||
|
||||
for idx, (item, target_url) in enumerate(zip(mm_items, url_cycle)):
|
||||
# Derive a *child* request id: <parent>:<index>:<random-short>
|
||||
child_req_id = f"{req_id}:{idx}:{uuid.uuid4().hex[:6]}"
|
||||
headers = {"x-request-id": child_req_id}
|
||||
|
||||
encoder_req = {
|
||||
# You *may* need to keep additional fields
|
||||
"model": orig_request.get("model"),
|
||||
"messages": [
|
||||
{"role": "user", "content": [item]},
|
||||
],
|
||||
# Only need 1 token so the server actually runs the encoder path
|
||||
"max_tokens": 1,
|
||||
"stream": False,
|
||||
}
|
||||
tasks.append(
|
||||
encode_session.post(
|
||||
f"{target_url}/v1/chat/completions",
|
||||
json=encoder_req,
|
||||
headers=headers,
|
||||
)
|
||||
)
|
||||
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Fail fast if any sub-request failed
|
||||
for idx, r in enumerate(results):
|
||||
if isinstance(r, Exception):
|
||||
logger.error(
|
||||
"[%s] Encoder request #%d raised exception: %s",
|
||||
req_id,
|
||||
idx,
|
||||
r,
|
||||
exc_info=r,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=502, detail=f"Encoder request failed: {str(r)}"
|
||||
)
|
||||
if r.status != 200:
|
||||
try:
|
||||
detail = await r.text()
|
||||
except Exception:
|
||||
detail = "<unable to read body>"
|
||||
logger.error(
|
||||
"[%s] Encoder request #%d returned status %s: %s",
|
||||
req_id,
|
||||
idx,
|
||||
r.status,
|
||||
detail,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=r.status,
|
||||
detail=f"Encoder request failed: {detail}",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"[%s] All %d encoder requests completed successfully", req_id, len(mm_items)
|
||||
)
|
||||
|
||||
|
||||
async def maybe_prefill(
|
||||
req_data: dict,
|
||||
p_url: str,
|
||||
req_id: str,
|
||||
) -> dict:
|
||||
"""
|
||||
- Run the prefill-only stage if p_url exists;
|
||||
- Return modified request data with kv transfer params (for nixl connector)
|
||||
- Else, skip and return the original request data for decode
|
||||
"""
|
||||
if p_url:
|
||||
logger.info("[%s] Processing through prefill: %s", req_id, p_url)
|
||||
|
||||
prefill_response = await process_prefill_stage(req_data, p_url, req_id)
|
||||
# for nixl connector to facilitate kv transfer...
|
||||
prefill_response_json = await prefill_response.json()
|
||||
kv_transfer_params = prefill_response_json.get("kv_transfer_params", {})
|
||||
if kv_transfer_params:
|
||||
req_data["kv_transfer_params"] = kv_transfer_params
|
||||
|
||||
return req_data
|
||||
else:
|
||||
return req_data
|
||||
|
||||
|
||||
async def process_prefill_stage(
|
||||
req_data: dict,
|
||||
p_url: str,
|
||||
req_id: str,
|
||||
) -> aiohttp.ClientResponse:
|
||||
"""Process the request through the prefill stage and return the prefill response"""
|
||||
logger.info("[%s] Sending prefill request to: %s", req_id, p_url)
|
||||
|
||||
prefill_request = req_data.copy()
|
||||
prefill_request["kv_transfer_params"] = {
|
||||
"do_remote_decode": True,
|
||||
"do_remote_prefill": False,
|
||||
"remote_engine_id": None,
|
||||
"remote_block_ids": None,
|
||||
"remote_host": None,
|
||||
"remote_port": None,
|
||||
}
|
||||
prefill_request["stream"] = False
|
||||
prefill_request["max_tokens"] = 1
|
||||
if "max_completion_tokens" in prefill_request:
|
||||
prefill_request["max_completion_tokens"] = 1
|
||||
if "stream_options" in prefill_request:
|
||||
del prefill_request["stream_options"]
|
||||
|
||||
headers = {"x-request-id": req_id}
|
||||
try:
|
||||
prefill_response = await prefill_session.post(
|
||||
f"{p_url}/v1/chat/completions", json=prefill_request, headers=headers
|
||||
)
|
||||
|
||||
|
||||
if prefill_response.status != 200:
|
||||
error_text = await prefill_response.text()
|
||||
logger.error(
|
||||
"[%s] Prefill request failed with status %d: %s",
|
||||
req_id,
|
||||
prefill_response.status,
|
||||
error_text,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=prefill_response.status,
|
||||
detail={"error": "Prefill request failed", "message": error_text},
|
||||
)
|
||||
logger.info("[%s] Prefill request completed successfully", req_id)
|
||||
|
||||
return prefill_response
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Prefill processing failed: %s", str(e))
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={"error": "Prefill processing error", "message": str(e)},
|
||||
) from e
|
||||
|
||||
|
||||
###############################################################################
# Middleware for request/response logging
###############################################################################


@app.middleware("http")
async def log_requests(request: Request, call_next):
    """Middleware to log all incoming requests and responses"""
    req_id = request.headers.get("x-request-id", str(uuid.uuid4()))

    # Log incoming request
    logger.info(
        ">>> [%s] %s %s from %s",
        req_id,
        request.method,
        request.url.path,
        request.client.host if request.client else "unknown",
    )

    try:
        # Process request
        response = await call_next(request)

        # Log response
        logger.info(
            "<<< [%s] %s %s completed with status %d",
            req_id,
            request.method,
            request.url.path,
            response.status_code,
        )

        return response
    except Exception as e:
        # Log errors
        logger.exception(
            "!!! [%s] %s %s failed with error: %s",
            req_id,
            request.method,
            request.url.path,
            str(e),
        )
        raise

###############################################################################
# FastAPI lifecycle
###############################################################################


@app.on_event("startup")
async def on_startup() -> None:
    global encode_session, prefill_session, decode_session
    timeout = aiohttp.ClientTimeout(total=100_000)
    connector = aiohttp.TCPConnector(limit=0, force_close=False)
    encode_session = aiohttp.ClientSession(timeout=timeout, connector=connector)
    if app.state.p_urls:
        # only set up if prefill instance(s) exist
        prefill_session = aiohttp.ClientSession(timeout=timeout, connector=connector)
    decode_session = aiohttp.ClientSession(timeout=timeout, connector=connector)


@app.on_event("shutdown")
async def on_shutdown() -> None:
    global encode_session, prefill_session, decode_session
    if encode_session:
        await encode_session.close()
    if prefill_session:
        await prefill_session.close()
    if decode_session:
        await decode_session.close()


###############################################################################
# Core forwarding
###############################################################################

async def forward_non_stream(
    req_data: dict, req_id: str, e_urls: list[str], p_url: str, d_url: str
) -> dict:
    try:
        # Step 1: Process through Encoder instance (if there is MM input)
        await fanout_encoder_primer(req_data, e_urls, req_id)

        # Step 2: Process through Prefill instance
        req_data = await maybe_prefill(req_data, p_url, req_id)

        # Step 3: Process through Decode instance
        logger.info("[%s] Forwarding to decode: %s", req_id, d_url)
        headers = {"x-request-id": req_id}

        # Non-streaming response
        async with decode_session.post(
            f"{d_url}/v1/chat/completions", json=req_data, headers=headers
        ) as resp:
            resp.raise_for_status()
            return await resp.json()

    except HTTPException:
        raise
    except Exception as e:
        logger.exception("[%s] Error in forward_non_stream: %s", req_id, str(e))
        raise HTTPException(status_code=500, detail=f"Proxy error: {str(e)}") from e


async def forward_stream(
    req_data: dict, req_id: str, e_urls: list[str], p_url: str, d_url: str
) -> AsyncIterator[str]:
    try:
        # Step 1: Process through Encoder instance (if there is MM input)
        await fanout_encoder_primer(req_data, e_urls, req_id)

        # Step 2: Process through Prefill instance
        req_data = await maybe_prefill(req_data, p_url, req_id)

        # Step 3: Process through Decode instance
        logger.info("[%s] Starting streaming from decode: %s", req_id, d_url)
        headers = {"x-request-id": req_id}

        # Streaming response
        async with decode_session.post(
            f"{d_url}/v1/chat/completions",
            json=req_data,
            headers=headers,
        ) as resp:
            resp.raise_for_status()
            async for chunk in resp.content.iter_chunked(1024):
                if chunk:
                    yield chunk.decode("utf-8", errors="ignore")

        logger.info("[%s] Streaming completed", req_id)

    except HTTPException:
        logger.exception("[%s] HTTPException in forward_stream", req_id)
        raise
    except Exception as e:
        logger.exception("[%s] Error in forward_stream: %s", req_id, str(e))
        raise HTTPException(
            status_code=500, detail=f"Proxy streaming error: {str(e)}"
        ) from e

###############################################################################
# Public routes
###############################################################################


@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
    try:
        req_data = await request.json()
        req_id = request.headers.get("x-request-id", str(uuid.uuid4()))

        e_urls = app.state.e_urls  # we want the full list for fan-out
        p_url = random.choice(app.state.p_urls) if app.state.p_urls else None
        d_url = random.choice(app.state.d_urls)

        is_streaming = req_data.get("stream", False)

        if is_streaming:
            return StreamingResponse(
                forward_stream(req_data, req_id, e_urls, p_url, d_url),
                media_type="text/event-stream",
            )
        result = await forward_non_stream(req_data, req_id, e_urls, p_url, d_url)
        return JSONResponse(content=result)

    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Error in chat_completions endpoint: %s", str(e))
        raise HTTPException(
            status_code=500, detail=f"Request processing error: {str(e)}"
        ) from e


@app.get("/v1/models")
async def list_models():
    async with decode_session.get(f"{app.state.d_urls[0]}/v1/models") as resp:
        resp.raise_for_status()
        return await resp.json()


@app.get("/health")
async def health_check():
    async def healthy(urls):
        if not urls:
            return "empty"
        for u in urls:
            try:
                async with encode_session.get(f"{u}/health") as resp:
                    resp.raise_for_status()
            except Exception:
                return "unhealthy"
        return "healthy"

    e_status, p_status, d_status = await asyncio.gather(
        healthy(app.state.e_urls), healthy(app.state.p_urls), healthy(app.state.d_urls)
    )

    overall_healthy = all(
        status != "unhealthy" for status in (e_status, p_status, d_status)
    )

    status_code = 200 if overall_healthy else 503

    return JSONResponse(
        {
            "proxy": "healthy",
            "encode_cluster": e_status,
            "prefill_cluster": p_status,
            "decode_cluster": d_status,
        },
        status_code=status_code,
    )

###############################################################################
# Simple profiler fan-out (unchanged except for sessions)
###############################################################################


async def _post_if_available(
    session: aiohttp.ClientSession,
    url: str,
    payload: dict,
    headers: dict,
) -> dict | None:
    """
    POST `payload` to `url`.

    Returns
    -------
    • The decoded JSON body on success (2xx)
    • None if the endpoint does not exist (404)
    • Raises for anything else.
    """
    try:
        resp = await session.post(url, json=payload, headers=headers)
        if resp.status == 404:  # profiling disabled on that server
            logger.warning("Profiling endpoint missing on %s", url)
            return None
        resp.raise_for_status()
        return await resp.json(content_type=None)
    except aiohttp.ClientResponseError as exc:
        # Pass 404 through the branch above, re-raise everything else
        if exc.status == 404:
            logger.warning("Profiling endpoint missing on %s", url)
            return None
        raise
    except Exception:
        # Network errors etc.: propagate
        raise


async def _profile_cmd(cmd: str, payload: dict, e_url: str, p_url: str, d_url: str):
    """
    Fire-and-forget fan-out to all clusters, tolerating 404s.
    """
    headers = {"Authorization": f"Bearer {os.getenv('OPENAI_API_KEY', '')}"}

    encode_task = _post_if_available(
        encode_session, f"{e_url}/{cmd}_profile", payload, headers
    )
    prefill_task = (
        _post_if_available(prefill_session, f"{p_url}/{cmd}_profile", payload, headers)
        if p_url is not None
        else asyncio.sleep(0)
    )
    decode_task = _post_if_available(
        decode_session, f"{d_url}/{cmd}_profile", payload, headers
    )

    encode_res, prefill_res, decode_res = await asyncio.gather(
        encode_task, prefill_task, decode_task
    )

    # If *all* clusters said "I don't have that route", surface an error
    if encode_res is prefill_res is decode_res is None:
        raise HTTPException(
            status_code=503,
            detail="Profiling endpoints are disabled on all clusters",
        )

    return {
        "encode": encode_res,  # may be None
        "prefill": prefill_res,  # may be None
        "decode": decode_res,  # may be None
    }

@app.post("/start_profile")
async def start_profile(request: Request):
    body = await request.json()
    # TODO: handle multiple urls properly
    e_url = random.choice(app.state.e_urls)
    p_url = random.choice(app.state.p_urls) if app.state.p_urls else None
    d_url = random.choice(app.state.d_urls)
    return await _profile_cmd("start", body, e_url, p_url, d_url)


@app.post("/stop_profile")
async def stop_profile(request: Request):
    body = await request.json()
    # TODO: handle multiple urls properly
    e_url = random.choice(app.state.e_urls)
    p_url = random.choice(app.state.p_urls) if app.state.p_urls else None
    d_url = random.choice(app.state.d_urls)
    return await _profile_cmd("stop", body, e_url, p_url, d_url)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument(
        "--encode-servers-urls",
        required=True,
        help='Comma-separated encode URLs ("http://e1:8001,http://e2:8001")',
    )
    parser.add_argument(
        "--prefill-servers-urls",
        required=True,
        help=(
            'Comma-separated prefill URLs ("http://p1:8003,http://p2:8004") '
            'to enable E->P->D; set "disable" or "none" to enable E->PD'
        ),
    )
    parser.add_argument(
        "--decode-servers-urls",
        required=True,
        help='Comma-separated decode URLs ("http://d1:8005,http://d2:8006")',
    )

    args = parser.parse_args()
    app.state.e_urls = [
        u.strip() for u in args.encode_servers_urls.split(",") if u.strip()
    ]
    app.state.d_urls = [
        u.strip() for u in args.decode_servers_urls.split(",") if u.strip()
    ]
    # handle prefill instances
    if args.prefill_servers_urls.lower() in ("disable", "none", ""):
        app.state.p_urls = []
        logger.info(
            "Disaggregated prefill phase explicitly disabled by user. Running E + PD..."
        )
    else:
        app.state.p_urls = [
            u.strip() for u in args.prefill_servers_urls.split(",") if u.strip()
        ]
        logger.info("Disaggregated prefill phase is enabled. Running E + P + D...")

    logger.info("Proxy listening on %s:%s", args.host, args.port)
    logger.info("Encode servers: %s", app.state.e_urls)
    logger.info("Prefill instances: %s", app.state.p_urls)
    logger.info("Decode servers: %s", app.state.d_urls)

    uvicorn.run(
        app,
        host=args.host,
        port=args.port,
        log_level="info",
        loop="uvloop",
        access_log=True,
    )
(File diff suppressed because it is too large.)
@@ -5,6 +5,7 @@ import torch

from vllm.config import (
    CacheConfig,
    ECTransferConfig,
    KVTransferConfig,
    ModelConfig,
    SchedulerConfig,
@@ -46,6 +47,8 @@ def create_scheduler(
    num_speculative_tokens: int | None = None,
    skip_tokenizer_init: bool = False,
    async_scheduling: bool = False,
    use_ec_connector: bool = False,
    ec_role: str | None = None,
) -> Scheduler | AsyncScheduler:
    """Create scheduler under test.

@@ -107,12 +110,23 @@ def create_scheduler(
        model="ngram", num_speculative_tokens=num_speculative_tokens
    )

    ec_transfer_config = (
        ECTransferConfig(
            ec_connector="ECSharedStorageConnector",
            ec_role=ec_role,
            ec_connector_extra_config={"shared_storage_path": "/tmp/ec_test"},
        )
        if use_ec_connector
        else None
    )

    vllm_config = VllmConfig(
        scheduler_config=scheduler_config,
        model_config=model_config,
        cache_config=cache_config,
        kv_transfer_config=kv_transfer_config,
        speculative_config=speculative_config,
        ec_transfer_config=ec_transfer_config,
    )
    kv_cache_config = KVCacheConfig(
        num_blocks=num_blocks,  # A large number of blocks to hold all requests
@@ -140,12 +154,14 @@ _none_hash_initialized = False
def create_requests(
    num_requests: int,
    num_tokens: int = 10,
    mm_hashes_list: list[list[str]] | None = None,
    mm_positions: list[list[PlaceholderRange]] | None = None,
    max_tokens: int = 16,
    stop_token_ids: list[int] | None = None,
    prompt_logprobs: int | None = None,
    same_prompt: bool = False,
    block_size: int = 16,
    req_ids: list[str] | None = None,
) -> list[Request]:
    global _none_hash_initialized
    if not _none_hash_initialized:
@@ -160,25 +176,58 @@ def create_requests(
        prompt_logprobs=prompt_logprobs,
    )
    requests = []

    if mm_hashes_list is not None:
        # NOTE: allow manual input; some mm items can have the same identifier.
        # The number of mm_hashes and mm_positions for each request should be
        # identical.
        assert mm_positions is not None, (
            "mm_positions must be provided when mm_hashes_list is provided"
        )
        assert len(mm_hashes_list) == len(mm_positions) == num_requests
        assert [len(h) for h in mm_hashes_list] == [len(p) for p in mm_positions]

        # Since the same identifier implies identical encoder output, verify
        # that mm items with an identical identifier have the same
        # mm_position.length.
        seen_hashes: dict[str, int] = {}

    if req_ids:
        assert len(req_ids) == num_requests
    else:
        req_ids = [f"{i}" for i in range(num_requests)]

    for i in range(num_requests):
        mm_features = []
        for j, position in enumerate(
            mm_positions[i] if mm_positions is not None else []
        ):
            if mm_hashes_list is not None:
                identifier = mm_hashes_list[i][j]

                # Verify that the position length is identical
                position_length = position.length
                if identifier in seen_hashes:
                    assert seen_hashes[identifier] == position_length, (
                        f"mm_hash '{identifier}' has inconsistent position lengths: "
                        f"previously {seen_hashes[identifier]}, now {position_length} "
                        f"at request {i}, position {j}"
                    )
                else:
                    seen_hashes[identifier] = position_length
            else:
                # Dummy hash for each mm item should be unique,
                # since the encoder cache tracks entries by hash
                identifier = f"hash{i}_{j}"
            mm_feature = MultiModalFeatureSpec(
                data=MultiModalKwargsItem.dummy("dummy_m"),
                mm_position=position,
                identifier=identifier,
                modality="image",
            )
            mm_features.append(mm_feature)

        prompt_token_ids = [0] * num_tokens if same_prompt else [i] * num_tokens
        request = Request(
            request_id=req_ids[i],
            prompt_token_ids=prompt_token_ids,
            sampling_params=sampling_params,
            pooling_params=None,
171	tests/v1/ec_connector/integration/README.md	Normal file
@@ -0,0 +1,171 @@
# EPD Correctness Test

This test verifies that EPD (Encoder-Prefill-Decode) disaggregation produces identical outputs to a baseline single instance.

## What It Tests

- **Baseline**: Single vLLM instance serving a multimodal model
- **EPD (1E+1PD)**: 1 Encoder + 1 Prefill-Decode instance
- **Baseline (1P+1D)**: 1 Prefill + 1 Decode instance
- **EPD (1E+1P+1D)**: 1 Encoder + 1 Prefill + 1 Decode instance

The test ensures that disaggregated encoding produces **identical** outputs to the baseline.

Note that a PD-disaggregated setup may currently give slightly different results from a single instance, so the 1P+1D output serves as the baseline for the 1E+1P+1D comparison.

Please refer to [Disaggregated Encoder Feature](../../../docs/features/disagg_encoder.md) for a detailed explanation of the EPD feature.

## Files

- `run_epd_correctness_test.sh` - Main test script (starts all instances and runs tests)
- `test_epd_correctness.py` - Python test script (compares outputs)

## Usage

### Multimodal Prompts (Default)

```bash
cd vllm
./tests/v1/ec_connector/integration/run_epd_correctness_test.sh
```

This runs the test with actual multimodal (image) prompts.

### Text-Only Prompts

```bash
cd vllm
USE_MM_PROMPTS=0 ./tests/v1/ec_connector/integration/run_epd_correctness_test.sh
```

This runs a quick test with text-only prompts to verify the setup works.

### Custom Configuration

```bash
# Use specific GPUs
GPU_E=0 GPU_PD=1 GPU_P=1 GPU_D=2 bash ./tests/v1/ec_connector/integration/run_epd_correctness_test.sh

# Use a specific port
ENDPOINT_PORT=10001 bash ./tests/v1/ec_connector/integration/run_epd_correctness_test.sh

# Use a specific model
MODEL="Qwen/Qwen2.5-VL-3B-Instruct" bash ./tests/v1/ec_connector/integration/run_epd_correctness_test.sh

# Use a specific storage path
EC_SHARED_STORAGE_PATH="/tmp/my_ec_cache" bash ./tests/v1/ec_connector/integration/run_epd_correctness_test.sh
```

## How It Works

### Step 1: Baseline

1. Start a single vLLM instance on one GPU
2. Run test prompts (multimodal or text-only)
3. Save outputs to `.vllm_epd_baseline.txt`
4. Shut down the instance

### Step 2: EPD (1E + 1PD)

1. Clear encoder cache storage
2. Start instances and proxy
3. Run the same test prompts
4. Assert outputs match the baseline exactly
5. Shut down the instances

### Step 3: EPD (1E + 1P + 1D)

1. Clear encoder cache storage
2. Start instances and proxy
3. Run the same test prompts
4. Assert outputs match the baseline exactly
5. Shut down the instances

## Test Scenarios

### Multimodal Prompts (`USE_MM_PROMPTS=1`, default)

Tests encoder cache transfer:

- Single image query
- Multiple images in one request
- Mixed image and text
- Image with detailed questions

### Text-Only Prompts (`USE_MM_PROMPTS=0`)

Quick sanity check:

- Simple text queries
- Text-only explanations
- Verifies proxy routing works
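In every scenario the pass/fail decision reduces to a byte-exact comparison of deterministic completions against the saved baseline file. A minimal sketch of that comparison (a hypothetical helper, not the actual `test_epd_correctness.py` code):

```python
def compare_outputs(baseline_path: str, outputs: list[str]) -> None:
    """Assert each generated output matches the saved baseline exactly."""
    with open(baseline_path) as f:
        baseline = [line.rstrip("\n") for line in f]
    # One baseline line per prompt; any drift fails loudly.
    assert len(baseline) == len(outputs), "prompt count mismatch"
    for i, (expected, got) in enumerate(zip(baseline, outputs)):
        assert expected == got, f"prompt {i}: {got!r} != {expected!r}"
```

With `temperature=0.0` and a fixed seed, any divergence in the encoder-cache path surfaces as a hard assertion failure rather than a fuzzy similarity score.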

## Expected Behavior

### ✅ Test Passes When

- All disagg outputs match baseline outputs exactly
- No errors occur during instance startup
- The encoder cache is properly saved and loaded
- The proxy correctly routes requests

### ❌ Test Fails When

- Outputs differ between baseline and disagg
- Server startup fails
- The encoder cache is not found (should fall back to local execution)
- Proxy routing errors occur

## Notes

- The test uses deterministic generation (`temperature=0.0`, `seed=42`)
- The encoder cache should enable exact output reproduction
- The test cleans up all instances and cache files after completion
- Safe to run multiple times (idempotent)
- We set up the PD-disagg part with NixlConnector. Please read the details about EPD in `examples/online_serving/disaggregated_encoder/README.md`
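For manual poking, a request that exercises the full E→P→D path is just an ordinary OpenAI-style chat completion with an image part plus the deterministic sampling settings above; the proxy requires no EPD-specific fields. A sketch (the model name and image path are illustrative placeholders):

```python
import json

# Ordinary OpenAI-style chat payload; the proxy transparently fans it out
# to the encoder first, then routes it through prefill and decode.
payload = {
    "model": "Qwen/Qwen2.5-VL-3B-Instruct",
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "image_url",
                 "image_url": {"url": "file:///path/to/hato.jpg"}},
                {"type": "text", "text": "Describe this image."},
            ],
        }
    ],
    # Deterministic settings so outputs are byte-comparable.
    "temperature": 0.0,
    "seed": 42,
    "max_tokens": 64,
    "stream": False,
}
print(json.dumps(payload, indent=2))
```

POSTing this body to the proxy's `/v1/chat/completions` endpoint should yield the same completion whether the backend is a single instance, 1E+1PD, or 1E+1P+1D.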

## Requirements

- Multiple GPUs (3 for 1E+1P+1D, 2 for 1E+1PD, 1 for baseline)
  - 1E+1P+1D can also run on 2 GPUs by assigning E and P to the same GPU
- A multimodal model (e.g., Qwen2.5-VL-3B-Instruct)
- Internet access (for fetching the vLLM test images)

## Debugging

### Check Logs

Logs and the baseline output are saved in `/tmp/` by default.
This can be customized by changing the environment variables.

### Check Encoder Cache

```bash
# Verify cache files are created
ls -la $EC_SHARED_STORAGE_PATH/

# You should see directories named after mm_hash values,
# each containing encoder_cache.safetensors
```

### Manual Testing

Run individual components:

```bash
# Baseline only
python test_epd_correctness.py \
    --service_url http://localhost:8000 \
    --model_name Qwen/Qwen2.5-VL-3B-Instruct \
    --mode baseline \
    --baseline_file test_output.txt \
    --use_mm_prompts

# Disagg only (requires the baseline output file!)
python test_epd_correctness.py \
    --service_url http://localhost:8000 \
    --model_name Qwen/Qwen2.5-VL-3B-Instruct \
    --mode disagg \
    --baseline_file test_output.txt \
    --use_mm_prompts
```
BIN	tests/v1/ec_connector/integration/hato.jpg	Normal file (binary file not shown; 821 KiB)
476	tests/v1/ec_connector/integration/run_epd_correctness_test.sh	Normal file
@@ -0,0 +1,476 @@
#!/bin/bash
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
#
|
||||
# EPD (Encoder-Prefill-Decode) Correctness Test
|
||||
#
|
||||
# This script tests that EPD disaggregation produces the same outputs as baseline.
|
||||
# It runs:
|
||||
# 1. Baseline: Single vLLM instance
|
||||
# 2. EPD: 1E + 1PD setup
|
||||
# 3. Baseline for (E + P + D): 1P + 1D vLLM instances disagg
|
||||
# 4. EPD: 1E + 1P + 1D setup
|
||||
|
||||
# For GPU usage
|
||||
|
||||
# set -xe
|
||||
|
||||
# Find the git repository root directory
|
||||
GIT_ROOT=$(git rev-parse --show-toplevel)
|
||||
|
||||
# Model to test
|
||||
MODEL="${MODEL:-Qwen/Qwen2.5-VL-3B-Instruct}"
|
||||
|
||||
# Set 1 to use multimodal prompts; else to use text-only
|
||||
USE_MM_PROMPTS="${USE_MM_PROMPTS:-1}"
|
||||
MM_FLAG=""
|
||||
if [ $USE_MM_PROMPTS = "1" ]; then
|
||||
MM_FLAG="--use_mm_prompts"
|
||||
fi
|
||||
|
||||
# GPU configuration
|
||||
GPU_E="${GPU_E:-0}"
|
||||
GPU_P="${GPU_P:-1}"
|
||||
GPU_D="${GPU_D:-2}"
|
||||
GPU_SINGLE="${GPU_SINGLE:-$GPU_P}"
|
||||
GPU_PD="${GPU_PD:-$GPU_P}"
|
||||
|
||||
# Port
|
||||
ENCODE_PORT="${ENCODE_PORT:-19534}"
|
||||
PREFILL_PORT="${PREFILL_PORT:-19535}"
|
||||
DECODE_PORT="${DECODE_PORT:-19536}"
|
||||
PREFILL_DECODE_PORT="${PREFILL_DECODE_PORT:-19537}"
|
||||
ENDPOINT_PORT="${ENDPOINT_PORT:-10001}"
|
||||
|
||||
# Storage path for encoder cache
|
||||
EC_SHARED_STORAGE_PATH="${EC_SHARED_STORAGE_PATH:-/tmp/ec_cache_test}"
|
||||
TIMEOUT_SECONDS="${TIMEOUT_SECONDS:-600}"
|
||||
|
||||
# Output file for baseline comparison and logs
|
||||
LOG_PATH="${LOG_PATH:-/tmp}"
|
||||
BASELINE_FILE="${BASELINE_FILE:-/tmp/vllm_baseline.txt}"
|
||||
BASELINE_PD_FILE="${BASELINE_PD_FILE:-/tmp/vllm_epd_baseline.txt}"
|
||||
|
||||
mkdir -p $LOG_PATH
|
||||
|
||||
# Trap the SIGINT signal (triggered by Ctrl+C)
|
||||
trap 'kill $(jobs -pr)' SIGINT SIGTERM EXIT
|
||||
|
||||
# Wait for server to be ready
|
||||
wait_for_server() {
|
||||
local port=$1
|
||||
timeout "$TIMEOUT_SECONDS" bash -c "
|
||||
until curl -s localhost:${port}/v1/chat/completions > /dev/null; do
|
||||
sleep 1
|
||||
done" && return 0 || return 1
|
||||
}
|
||||
|
||||
# Cleanup function
|
||||
cleanup_instances() {
|
||||
echo "Cleaning up any running vLLM instances..."
|
||||
pkill -f "vllm serve" || true
|
||||
pkill -f "disagg_epd_proxy.py" || true
|
||||
sleep 2
|
||||
}
|
||||
|
||||
# Function to run baseline (single instance)
|
||||
run_baseline() {
|
||||
echo "================================"
|
||||
echo "Running BASELINE (single instance)"
|
||||
echo "================================"
|
||||
|
||||
cleanup_instances
|
||||
rm -rf "$EC_SHARED_STORAGE_PATH"
|
||||
|
||||
local PORT=$ENDPOINT_PORT
|
||||
|
||||
# Start baseline instance
|
||||
echo "Starting baseline instance on GPU $GPU_SINGLE, port $PORT"
|
||||
CUDA_VISIBLE_DEVICES="$GPU_SINGLE" vllm serve "$MODEL" \
|
||||
--port $PORT \
|
||||
--enforce-eager \
|
||||
--gpu-memory-utilization 0.7 \
|
||||
--max-num-seqs 128 \
|
||||
--allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \
|
||||
> $LOG_PATH/baseline.log 2>&1 &
|
||||
|
||||
local BASELINE_PID=$!
|
||||
|
||||
# Wait for baseline to start
|
||||
echo "Waiting for baseline instance to start..."
|
||||
wait_for_server $PORT
|
||||
|
||||
curl http://127.0.0.1:$PORT/v1/models
|
||||
echo ""
|
||||
|
||||
# Run test in baseline mode
|
||||
echo "Running baseline..."
|
||||
|
||||
python "${GIT_ROOT}/tests/v1/ec_connector/integration/test_epd_correctness.py" \
|
||||
--service_url "http://localhost:$PORT" \
|
||||
--model_name "$MODEL" \
|
||||
--mode baseline \
|
||||
--baseline_file "$BASELINE_FILE" \
|
||||
$MM_FLAG
|
||||
|
||||
# Cleanup baseline
|
||||
echo "Stopping baseline instance..."
|
||||
kill $BASELINE_PID 2>/dev/null || true
|
||||
sleep 2
|
||||
cleanup_instances
|
||||
}
|
||||
|
||||
# Function to run EPD with 1E + 1PD
|
||||
run_epd_1e_1pd() {
|
||||
echo "================================"
|
||||
echo "Running EPD (1E + 1PD)"
|
||||
echo "================================"
|
||||
|
||||
cleanup_instances
|
||||
rm -rf "$EC_SHARED_STORAGE_PATH"
|
||||
mkdir -p "$EC_SHARED_STORAGE_PATH"
|
||||
|
||||
local ENCODE_PORT=$ENCODE_PORT
|
||||
local PREFILL_DECODE_PORT=$PREFILL_DECODE_PORT
|
||||
local PROXY_PORT=$ENDPOINT_PORT
|
||||
|
||||
declare -a PIDS=()
|
||||
|
||||
# Start encoder instance
|
||||
echo "Starting encoder instance on GPU $GPU_E, port $ENCODE_PORT"
|
||||
CUDA_VISIBLE_DEVICES="$GPU_E" vllm serve "$MODEL" \
|
||||
--port $ENCODE_PORT \
|
||||
--enforce-eager \
|
||||
--gpu-memory-utilization 0.01 \
|
||||
--enable-request-id-headers \
|
||||
--no-enable-prefix-caching \
|
||||
--max-num-batched-tokens 114688 \
|
||||
--max-num-seqs 128 \
|
||||
--allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \
|
||||
--ec-transfer-config '{
|
||||
"ec_connector": "ECSharedStorageConnector",
|
||||
"ec_role": "ec_producer",
|
||||
"ec_connector_extra_config": {
|
||||
"shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'"
|
||||
}
|
||||
}' \
|
||||
> $LOG_PATH/1e1pd_encoder.log 2>&1 &
|
||||
PIDS+=($!)
|
||||
|
||||
# Start prefill+decode instance
|
||||
echo "Starting PD instance on GPU $GPU_PD, port $PREFILL_DECODE_PORT"
|
||||
CUDA_VISIBLE_DEVICES="$GPU_PD" vllm serve "$MODEL" \
|
||||
--port $PREFILL_DECODE_PORT \
|
||||
--enforce-eager \
|
||||
--gpu-memory-utilization 0.7 \
|
||||
--enable-request-id-headers \
|
||||
--max-num-seqs 128 \
|
||||
--allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \
|
||||
--ec-transfer-config '{
|
||||
"ec_connector": "ECSharedStorageConnector",
|
||||
"ec_role": "ec_consumer",
|
||||
"ec_connector_extra_config": {
|
||||
"shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'"
|
||||
}
|
||||
}' \
|
||||
> $LOG_PATH/1e1pd_pd.log 2>&1 &
|
||||
PIDS+=($!)
|
||||
|
||||
# Wait for instances to start
|
||||
echo "Waiting for encoder instance..."
|
||||
wait_for_server $ENCODE_PORT
|
||||
echo "Waiting for PD instance..."
|
||||
wait_for_server $PREFILL_DECODE_PORT
|
||||
|
||||
# Start proxy
|
||||
echo "Starting EPD proxy on port $PROXY_PORT"
|
||||
python "${GIT_ROOT}/examples/online_serving/disaggregated_encoder/disagg_epd_proxy.py" \
|
||||
--host "0.0.0.0" \
|
||||
--port $PROXY_PORT \
|
||||
--encode-servers-urls "http://localhost:$ENCODE_PORT" \
|
||||
--prefill-servers-urls "disable" \
|
||||
--decode-servers-urls "http://localhost:$PREFILL_DECODE_PORT" \
|
||||
> $LOG_PATH/1e1pd_proxy.log 2>&1 &
|
||||
PIDS+=($!)
|
||||
|
||||
# Wait for proxy
|
||||
echo "Waiting for proxy..."
|
||||
wait_for_server $PROXY_PORT
|
||||
|
||||
curl http://127.0.0.1:$PROXY_PORT/v1/models
|
||||
curl http://127.0.0.1:$PROXY_PORT/health
|
||||
echo ""
|
||||
|
||||
echo "All EPD (1E+1PD) services are up!"
|
||||
|
||||
# Run test in disagg mode
|
||||
echo "Running EPD (1E+1PD) correctness test..."
|
||||
|
||||
python "${GIT_ROOT}/tests/v1/ec_connector/integration/test_epd_correctness.py" \
|
||||
--service_url "http://localhost:$PROXY_PORT" \
|
||||
--model_name "$MODEL" \
|
||||
--mode disagg \
|
||||
--baseline_file "$BASELINE_FILE" \
|
||||
$MM_FLAG
|
||||
|
||||
# Cleanup
|
||||
echo "✓✓ 1E+1PD Correctness Test finished"
|
||||
echo "Stopping EPD (1E+1PD) instances..."
|
||||
for pid in "${PIDS[@]}"; do
|
||||
kill $pid 2>/dev/null || true
|
||||
done
|
||||
sleep 2
|
||||
cleanup_instances
|
||||
}
|
||||
|
||||
# Function to run baseline for 1E + 1P + 1D (PD disagg)
run_baseline_1p_1d() {
    echo "================================"
    echo "Running PD BASELINE (1P + 1D)"
    echo "================================"

    cleanup_instances
    rm -rf "$EC_SHARED_STORAGE_PATH"
    mkdir -p "$EC_SHARED_STORAGE_PATH"

    local PREFILL_PORT=$PREFILL_PORT
    local DECODE_PORT=$DECODE_PORT
    local PROXY_PORT=$ENDPOINT_PORT

    declare -a PIDS=()

    # Start prefill instance
    echo "Starting prefill instance on GPU $GPU_P, port $PREFILL_PORT"
    CUDA_VISIBLE_DEVICES="$GPU_P" \
    VLLM_NIXL_SIDE_CHANNEL_PORT=5559 \
    vllm serve "$MODEL" \
        --port $PREFILL_PORT \
        --enforce-eager \
        --gpu-memory-utilization 0.7 \
        --enable-request-id-headers \
        --max-num-seqs 128 \
        --allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \
        --kv-transfer-config '{
          "kv_connector": "NixlConnector",
          "kv_role": "kv_producer"
        }' \
        > $LOG_PATH/1p1d_prefill.log 2>&1 &
    PIDS+=($!)

    # Start decode instance
    echo "Starting decode instance on GPU $GPU_D, port $DECODE_PORT"
    CUDA_VISIBLE_DEVICES="$GPU_D" \
    VLLM_NIXL_SIDE_CHANNEL_PORT=6000 \
    vllm serve "$MODEL" \
        --port $DECODE_PORT \
        --enforce-eager \
        --gpu-memory-utilization 0.7 \
        --enable-request-id-headers \
        --max-num-seqs 128 \
        --allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \
        --kv-transfer-config '{
          "kv_connector": "NixlConnector",
          "kv_role": "kv_consumer"
        }' \
        > $LOG_PATH/1p1d_decode.log 2>&1 &
    PIDS+=($!)

    # Wait for instances to start
    echo "Waiting for prefill instance..."
    wait_for_server $PREFILL_PORT
    echo "Waiting for decode instance..."
    wait_for_server $DECODE_PORT

    # Start proxy
    echo "Starting EPD proxy on port $PROXY_PORT"
    python "${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py" \
        --host "0.0.0.0" \
        --port $PROXY_PORT \
        --prefiller-ports $PREFILL_PORT \
        --decoder-ports $DECODE_PORT \
        > $LOG_PATH/1p1d_proxy.log 2>&1 &
    PIDS+=($!)

    # Wait for proxy
    echo "Waiting for proxy..."
    wait_for_server $PROXY_PORT

    curl http://127.0.0.1:$PROXY_PORT/healthcheck
    echo ""

    echo "All PD (1P+1D) services are up!"

    # Run test in baseline mode
    echo "Running PD disagg baseline..."

    python "${GIT_ROOT}/tests/v1/ec_connector/integration/test_epd_correctness.py" \
        --service_url "http://localhost:$PROXY_PORT" \
        --model_name "$MODEL" \
        --mode baseline_pd \
        --baseline_file "$BASELINE_PD_FILE" \
        $MM_FLAG

    # Cleanup
    echo "Stopping PD (1P+1D) instances..."
    for pid in "${PIDS[@]}"; do
        kill $pid 2>/dev/null || true
    done
    sleep 2
    cleanup_instances
}
# Function to run EPD with 1E + 1P + 1D
run_epd_1e_1p_1d() {
    echo "================================"
    echo "Running EPD (1E + 1P + 1D)"
    echo "================================"

    cleanup_instances
    rm -rf "$EC_SHARED_STORAGE_PATH"
    mkdir -p "$EC_SHARED_STORAGE_PATH"

    local ENCODE_PORT=$ENCODE_PORT
    local PREFILL_PORT=$PREFILL_PORT
    local DECODE_PORT=$DECODE_PORT
    local PROXY_PORT=$ENDPOINT_PORT

    declare -a PIDS=()

    # Start encoder instance
    echo "Starting encoder instance on GPU $GPU_E, port $ENCODE_PORT"
    CUDA_VISIBLE_DEVICES="$GPU_E" vllm serve "$MODEL" \
        --port $ENCODE_PORT \
        --enforce-eager \
        --gpu-memory-utilization 0.01 \
        --enable-request-id-headers \
        --no-enable-prefix-caching \
        --max-num-batched-tokens 114688 \
        --max-num-seqs 128 \
        --allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \
        --ec-transfer-config '{
          "ec_connector": "ECSharedStorageConnector",
          "ec_role": "ec_producer",
          "ec_connector_extra_config": {
            "shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'"
          }
        }' \
        > $LOG_PATH/1e1p1d_encoder.log 2>&1 &
    PIDS+=($!)

    # Start prefill instance
    echo "Starting prefill instance on GPU $GPU_P, port $PREFILL_PORT"
    CUDA_VISIBLE_DEVICES="$GPU_P" \
    VLLM_NIXL_SIDE_CHANNEL_PORT=5559 \
    vllm serve "$MODEL" \
        --port $PREFILL_PORT \
        --enforce-eager \
        --gpu-memory-utilization 0.7 \
        --enable-request-id-headers \
        --max-num-seqs 128 \
        --allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \
        --ec-transfer-config '{
          "ec_connector": "ECSharedStorageConnector",
          "ec_role": "ec_consumer",
          "ec_connector_extra_config": {
            "shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'"
          }
        }' \
        --kv-transfer-config '{
          "kv_connector": "NixlConnector",
          "kv_role": "kv_producer"
        }' \
        > $LOG_PATH/1e1p1d_prefill.log 2>&1 &
    PIDS+=($!)

    # Start decode instance
    echo "Starting decode instance on GPU $GPU_D, port $DECODE_PORT"
    CUDA_VISIBLE_DEVICES="$GPU_D" \
    VLLM_NIXL_SIDE_CHANNEL_PORT=6000 \
    vllm serve "$MODEL" \
        --port $DECODE_PORT \
        --enforce-eager \
        --gpu-memory-utilization 0.7 \
        --enable-request-id-headers \
        --max-num-seqs 128 \
        --allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \
        --kv-transfer-config '{
          "kv_connector": "NixlConnector",
          "kv_role": "kv_consumer"
        }' \
        > $LOG_PATH/1e1p1d_decode.log 2>&1 &
    PIDS+=($!)

    # Wait for instances to start
    echo "Waiting for encoder instance..."
    wait_for_server $ENCODE_PORT
    echo "Waiting for prefill instance..."
    wait_for_server $PREFILL_PORT
    echo "Waiting for decode instance..."
    wait_for_server $DECODE_PORT

    # Start proxy
    echo "Starting EPD proxy on port $PROXY_PORT"
    python "${GIT_ROOT}/examples/online_serving/disaggregated_encoder/disagg_epd_proxy.py" \
        --host "0.0.0.0" \
        --port $PROXY_PORT \
        --encode-servers-urls "http://localhost:$ENCODE_PORT" \
        --prefill-servers-urls "http://localhost:$PREFILL_PORT" \
        --decode-servers-urls "http://localhost:$DECODE_PORT" \
        > $LOG_PATH/1e1p1d_proxy.log 2>&1 &
    PIDS+=($!)

    # Wait for proxy
    echo "Waiting for proxy..."
    wait_for_server $PROXY_PORT

    curl http://127.0.0.1:$PROXY_PORT/v1/models
    curl http://127.0.0.1:$PROXY_PORT/health
    echo ""

    echo "All EPD (1E+1P+1D) services are up!"

    # Run test in disagg mode
    echo "Running EPD (1E+1P+1D) correctness test..."

    python "${GIT_ROOT}/tests/v1/ec_connector/integration/test_epd_correctness.py" \
        --service_url "http://localhost:$PROXY_PORT" \
        --model_name "$MODEL" \
        --mode disagg \
        --baseline_file "$BASELINE_PD_FILE" \
        $MM_FLAG

    # Cleanup
    echo "✓✓ 1E+1P+1D Correctness Test finished"
    echo "Stopping EPD (1E+1P+1D) instances..."
    for pid in "${PIDS[@]}"; do
        kill $pid 2>/dev/null || true
    done
    sleep 2
    cleanup_instances
}
# Main execution
echo "================================"
echo "EPD Correctness Test Suite"
echo "Model: $MODEL"
echo "================================"

# Step 1: Run baseline
run_baseline

# Step 2: Test 1E + 1PD
run_epd_1e_1pd

# Step 3: Test baseline 1P + 1D
run_baseline_1p_1d

# Step 4: Test 1E + 1P + 1D
run_epd_1e_1p_1d

# Cleanup output file
rm -f "$BASELINE_FILE"
rm -f "$BASELINE_PD_FILE"

echo "================================"
echo "✓✓ All EPD correctness tests finished!"
echo "================================"
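The proxy started by these scripts exposes an OpenAI-compatible endpoint, so the requests the correctness test sends can also be reproduced by hand. The sketch below builds the same shape of multimodal chat payload the test uses; the `build_mm_request` helper, the port, and the model name are illustrative assumptions, not part of this PR.

```python
import base64
import json

# Assumed deployment details; adjust to your own proxy and model.
PROXY_URL = "http://localhost:10001/v1/chat/completions"
MODEL = "Qwen/Qwen2.5-VL-3B-Instruct"


def build_mm_request(image_bytes: bytes, question: str) -> dict:
    """Build an OpenAI-style chat payload with one inline base64 image.

    Mirrors the structure of SAMPLE_PROMPTS_MM in test_epd_correctness.py:
    an image_url part followed by a text part, with greedy sampling so the
    disaggregated output can be compared token-for-token against a baseline.
    """
    b64 = base64.b64encode(image_bytes).decode("utf-8")
    return {
        "model": MODEL,
        "max_tokens": 256,
        "temperature": 0.0,
        "messages": [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image;base64,{b64}"},
                    },
                    {"type": "text", "text": question},
                ],
            }
        ],
    }


if __name__ == "__main__":
    payload = build_mm_request(b"\x89PNG...", "What's in this image?")
    # POST this JSON to PROXY_URL (e.g. with requests or curl) to query the proxy.
    print(json.dumps(payload)[:80])
```

With `temperature=0.0` the response is deterministic per engine build, which is what allows the test to compare disaggregated outputs against a saved baseline by exact string match.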
305
tests/v1/ec_connector/integration/test_epd_correctness.py
Normal file
@ -0,0 +1,305 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
EPD Correctness Test

Tests that EPD (Encoder-Prefill-Decode) disaggregation produces the same
outputs as a baseline single instance.

Usage:
    # Baseline mode (saves outputs):
    python test_epd_correctness.py \
        --service_url http://localhost:8000 \
        --model_name Qwen/Qwen2.5-VL-3B-Instruct \
        --mode baseline \
        --baseline_file .vllm_epd_baseline.txt

    # Disagg mode (compares outputs):
    python test_epd_correctness.py \
        --service_url http://localhost:8000 \
        --model_name Qwen/Qwen2.5-VL-3B-Instruct \
        --mode disagg \
        --baseline_file .vllm_epd_baseline.txt
"""

import argparse
import json
import os
import time

import openai
import requests

from vllm.assets.image import ImageAsset
from vllm.multimodal.utils import encode_image_base64

MAX_OUTPUT_LEN = 256

# Sample prompts with multimodal content
image_1 = ImageAsset("stop_sign").pil_image.resize((1280, 720))
image_2 = ImageAsset("cherry_blossom").pil_image.resize((1280, 720))

image_local_path = f"{os.path.dirname(os.path.abspath(__file__))}/hato.jpg"

SAMPLE_PROMPTS_MM: list[dict] = [
    {
        "messages": [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image;base64,{encode_image_base64(image_1)}"
                        },
                    },
                    {"type": "text", "text": "What's in this image?"},
                ],
            }
        ],
        "description": "Single image query",
    },
    {
        "messages": [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image;base64,{encode_image_base64(image_2)}"
                        },
                    },
                    {
                        "type": "image_url",
                        "image_url": {"url": f"file://{image_local_path}"},
                    },
                    {"type": "text", "text": "Describe these 2 images in detail."},
                ],
            }
        ],
        "description": "2 images with detailed query",
    },
]

# Text-only prompts for mixed testing
SAMPLE_PROMPTS_TEXT: list[dict] = [
    {
        "messages": [{"role": "user", "content": "What is the capital of France?"}],
        "description": "Simple text-only query",
    },
    {
        "messages": [
            {"role": "user", "content": "Explain quantum computing in simple terms."}
        ],
        "description": "Text-only explanation request",
    },
]


def check_vllm_server(url: str, timeout=5, retries=10) -> bool:
    """Check if the vLLM server is ready.

    Args:
        url: The URL to check (usually /health or /healthcheck endpoint)
        timeout: Timeout in seconds for each request
        retries: Number of retries if the server is not ready

    Returns:
        True if the server is ready, False otherwise
    """
    for attempt in range(retries):
        try:
            response = requests.get(url, timeout=timeout)
            if response.status_code == 200:
                print(f"Server is ready at {url}")
                return True
            else:
                print(
                    f"Attempt {attempt + 1}/{retries}: Server returned "
                    f"status code {response.status_code}"
                )
        except requests.exceptions.RequestException as e:
            print(f"Attempt {attempt + 1}/{retries}: Error connecting: {e}")
        time.sleep(2)  # Wait before retrying
    return False


def run_chat_completion(
    base_url: str,
    model_name: str,
    messages: list,
    max_tokens: int = MAX_OUTPUT_LEN,
) -> str:
    """Run a chat completion request.

    Args:
        base_url: Base URL of the vLLM server
        model_name: Name of the model
        messages: Messages for chat completion
        max_tokens: Maximum tokens to generate

    Returns:
        Generated text content
    """
    client = openai.OpenAI(api_key="EMPTY", base_url=base_url)

    completion = client.chat.completions.create(
        model=model_name,
        messages=messages,
        max_tokens=max_tokens,
        temperature=0.0,
        seed=42,
    )

    return completion.choices[0].message.content


def main():
    """Main test function."""
    parser = argparse.ArgumentParser(
        description="EPD correctness test - compare disagg vs baseline"
    )

    parser.add_argument(
        "--service_url",
        type=str,
        required=True,
        help="The vLLM service URL (e.g., http://localhost:8000)",
    )

    parser.add_argument(
        "--model_name",
        type=str,
        required=True,
        help="Model name",
    )

    parser.add_argument(
        "--mode",
        type=str,
        default="baseline",
        choices=["baseline", "baseline_pd", "disagg"],
        help="Mode: baseline/baseline_pd (saves outputs) or disagg (compares outputs)",
    )

    parser.add_argument(
        "--baseline_file",
        type=str,
        default=".vllm_epd_baseline.txt",
        help="File to save/load baseline outputs",
    )

    parser.add_argument(
        "--use_mm_prompts",
        action="store_true",
        help="Use multimodal prompts (default: use text-only for quick testing)",
    )

    args = parser.parse_args()

    print(f"Service URL: {args.service_url}")
    print(f"Model: {args.model_name}")
    print(f"Mode: {args.mode}")
    print(f"Output file: {args.baseline_file}")
    print(f"Use MM prompts: {args.use_mm_prompts}")

    # Determine health check endpoint
    if args.mode == "baseline":
        health_check_url = f"{args.service_url}/health"
    elif args.mode == "baseline_pd":
        # Nixl toy proxy uses /healthcheck
        health_check_url = f"{args.service_url}/healthcheck"
    else:
        # Disagg EPD proxy uses /health
        health_check_url = f"{args.service_url}/health"
        if not os.path.exists(args.baseline_file):
            raise ValueError(
                f"In disagg mode, the output file {args.baseline_file} from "
                "baseline does not exist. Run baseline mode first."
            )

    # Check if server is ready
    if not check_vllm_server(health_check_url):
        raise RuntimeError(f"vLLM server at {args.service_url} is not ready!")

    # Select prompts to use
    if args.use_mm_prompts:
        test_prompts = SAMPLE_PROMPTS_MM
        print("Using multimodal prompts")
    else:
        test_prompts = SAMPLE_PROMPTS_TEXT
        print("Using text-only prompts for quick testing")

    # Run completions
    service_url = f"{args.service_url}/v1"
    output_strs = {}

    for i, prompt_data in enumerate(test_prompts):
        print(
            f"\nRunning prompt {i + 1}/{len(test_prompts)}: "
            f"{prompt_data['description']}"
        )

        output_str = run_chat_completion(
            base_url=service_url,
            model_name=args.model_name,
            messages=prompt_data["messages"],
            max_tokens=MAX_OUTPUT_LEN,
        )

        # Use description as key for comparison
        key = prompt_data["description"]
        output_strs[key] = output_str
        print(f"Output: {output_str}")

    if args.mode in ("baseline", "baseline_pd"):
        # Baseline mode: Save outputs
        print(f"\nSaving baseline outputs to {args.baseline_file}")
        try:
            with open(args.baseline_file, "w") as json_file:
                json.dump(output_strs, json_file, indent=4)
            print("✅ Baseline outputs saved successfully")
        except OSError as e:
            print(f"Error writing to file: {e}")
            raise
    else:
        # Disagg mode: Load and compare outputs
        print(f"\nLoading baseline outputs from {args.baseline_file}")
        baseline_outputs = None
        try:
            with open(args.baseline_file) as json_file:
                baseline_outputs = json.load(json_file)
        except OSError as e:
            print(f"Error reading from file: {e}")
            raise

        # Verify outputs match
        print("\nComparing disagg outputs with baseline...")
        assert isinstance(baseline_outputs, dict), "Baseline outputs should be a dict"
        assert len(baseline_outputs) == len(output_strs), (
            f"Length mismatch: baseline has {len(baseline_outputs)}, "
            f"disagg has {len(output_strs)}"
        )

        all_match = True
        for key, baseline_output in baseline_outputs.items():
            assert key in output_strs, f"{key} not in disagg outputs"

            disagg_output = output_strs[key]
            if baseline_output == disagg_output:
                print(f"✅ {key}: MATCH")
            else:
                print(f"❌ {key}: MISMATCH")
                print(f"  Baseline: {baseline_output}")
                print(f"  Disagg:   {disagg_output}")
                all_match = False

        assert all_match, "❌❌Disagg outputs do not match baseline!❌❌"
        print("\n✅ All outputs match! Test PASSED")


if __name__ == "__main__":
    main()
609
tests/v1/ec_connector/unit/test_ec_shared_storage_connector.py
Normal file
@ -0,0 +1,609 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Unit tests for ECSharedStorageConnector.
"""

import os
from unittest.mock import Mock, patch

import pytest
import safetensors
import torch

from vllm.config import VllmConfig
from vllm.distributed.ec_transfer.ec_connector.base import ECConnectorRole
from vllm.distributed.ec_transfer.ec_connector.shared_storage_connector import (
    ECSharedStorageConnector,
    ECSharedStorageConnectorMetadata,
    MMMeta,
)
from vllm.multimodal.inputs import MultiModalFeatureSpec, PlaceholderRange
from vllm.v1.core.sched.output import SchedulerOutput


# ------------------ Mock Classes ------------------ #
class MockRequest:
    def __init__(self, request_id, mm_hashes: list[str], token_counts: list[int]):
        assert len(mm_hashes) == len(token_counts)
        self.request_id = request_id
        self._token_counts = token_counts
        self.mm_features = []
        for i, mm_hash in enumerate(mm_hashes):
            feature = MultiModalFeatureSpec(
                data=None,
                modality="image",
                identifier=mm_hash,
                mm_position=PlaceholderRange(offset=0, length=self._token_counts[i]),
            )
            self.mm_features.append(feature)

    def get_num_encoder_tokens(self, input_id: int) -> int:
        assert input_id < len(self._token_counts)
        return self._token_counts[input_id]


@pytest.fixture
def temp_storage(tmp_path):
    """Fixture providing temporary storage path."""
    return str(tmp_path)


@pytest.fixture
def mock_vllm_config_producer(temp_storage):
    """Fixture providing mock VllmConfig for producer role."""
    config = Mock(spec=VllmConfig)
    config.ec_transfer_config = Mock()
    config.ec_transfer_config.get_from_extra_config = Mock(return_value=temp_storage)
    config.ec_transfer_config.is_ec_producer = True
    return config


@pytest.fixture
def mock_vllm_config_consumer(temp_storage):
    """Fixture providing mock VllmConfig for consumer role."""
    config = Mock(spec=VllmConfig)
    config.ec_transfer_config = Mock()
    config.ec_transfer_config.get_from_extra_config = Mock(return_value=temp_storage)
    config.ec_transfer_config.is_ec_producer = False
    return config


@pytest.fixture
def mock_request_with_3_mm():
    """Fixture providing mock Request with 3 multimodal items."""
    request_id = "test_req_123"
    mm_hashes = ["img_hash_1", "img_hash_2", "img_hash_3"]
    token_counts = [100, 150, 200]

    request = MockRequest(request_id, mm_hashes, token_counts)
    return request


# ------------------ Unit Tests ------------------ #
class TestECSharedStorageConnectorBasics:
    """Test basic EC connector functionality."""

    def test_initialization_producer(self, mock_vllm_config_producer, temp_storage):
        """Test connector initializes correctly as producer."""
        connector = ECSharedStorageConnector(
            vllm_config=mock_vllm_config_producer,
            role=ECConnectorRole.SCHEDULER,
        )

        assert connector.role == ECConnectorRole.SCHEDULER
        assert connector.is_producer
        assert connector._storage_path == temp_storage
        assert connector._mm_datas_need_loads == {}

    def test_initialization_consumer(self, mock_vllm_config_consumer, temp_storage):
        """Test connector initializes correctly as consumer."""
        connector = ECSharedStorageConnector(
            vllm_config=mock_vllm_config_consumer,
            role=ECConnectorRole.WORKER,
        )

        assert connector.role == ECConnectorRole.WORKER
        assert not connector.is_producer
        assert connector._storage_path == temp_storage

    def test_role_assignment(self, mock_vllm_config_producer):
        """Test role is correctly assigned."""
        scheduler_connector = ECSharedStorageConnector(
            vllm_config=mock_vllm_config_producer,
            role=ECConnectorRole.SCHEDULER,
        )
        worker_connector = ECSharedStorageConnector(
            vllm_config=mock_vllm_config_producer,
            role=ECConnectorRole.WORKER,
        )

        assert scheduler_connector.role == ECConnectorRole.SCHEDULER
        assert worker_connector.role == ECConnectorRole.WORKER


class TestCacheExistence:
    """Test cache existence checking using has_caches() API."""

    def test_has_caches_all_exist_3_items(
        self,
        mock_vllm_config_producer,
        mock_vllm_config_consumer,
        mock_request_with_3_mm,
    ):
        """Test has_caches returns True when all 3 caches exist."""
        # Test for producer first
        producer = ECSharedStorageConnector(
            vllm_config=mock_vllm_config_producer,
            role=ECConnectorRole.SCHEDULER,
        )

        # Create cache files using save_caches (proper way)
        encoder_cache: dict[str, torch.Tensor] = {}

        for mm_feature in mock_request_with_3_mm.mm_features:
            mm_hash = mm_feature.identifier
            encoder_cache[mm_hash] = torch.randn(10, 768)
            producer.save_caches(encoder_cache, mm_hash)

        # Test using has_caches API
        producer_result = producer.has_caches(mock_request_with_3_mm)

        # Assert
        assert len(producer_result) == 3
        assert all(producer_result), f"Expected all True, got {producer_result}"

        # Also test consumer can check if cache exists
        consumer = ECSharedStorageConnector(
            vllm_config=mock_vllm_config_consumer,
            role=ECConnectorRole.SCHEDULER,
        )

        # Test using has_caches API
        consumer_result = consumer.has_caches(mock_request_with_3_mm)

        # Assert
        assert len(consumer_result) == 3
        assert all(consumer_result), f"Expected all True, got {consumer_result}"

    def test_has_caches_none_exist(
        self, mock_vllm_config_producer, mock_request_with_3_mm
    ):
        """Test has_caches returns False when no caches exist."""
        connector = ECSharedStorageConnector(
            vllm_config=mock_vllm_config_producer,
            role=ECConnectorRole.SCHEDULER,
        )

        # Test without creating any files
        result = connector.has_caches(mock_request_with_3_mm)

        # Assert
        assert len(result) == 3
        assert not any(result), f"Expected all False, got {result}"

    def test_has_caches_partial_exist(
        self, mock_vllm_config_producer, mock_request_with_3_mm
    ):
        """Test has_caches with some caches existing (1 of 3)."""
        connector = ECSharedStorageConnector(
            vllm_config=mock_vllm_config_producer,
            role=ECConnectorRole.SCHEDULER,
        )

        # Create only the second cache file
        mm_hash_second = mock_request_with_3_mm.mm_features[1].identifier
        encoder_cache = {mm_hash_second: torch.randn(10, 768)}
        connector.save_caches(encoder_cache, mm_hash_second)

        # Test
        result = connector.has_caches(mock_request_with_3_mm)

        # Assert
        assert len(result) == 3
        assert not result[0]  # First doesn't exist
        assert result[1]  # Second exists
        assert not result[2]  # Third doesn't exist


class TestStateManagement:
    """Test connector state management."""

    def test_update_state_after_alloc_3_items(
        self, mock_vllm_config_producer, mock_request_with_3_mm
    ):
        """Test state update after allocation for 3 MM items."""
        connector = ECSharedStorageConnector(
            vllm_config=mock_vllm_config_producer,
            role=ECConnectorRole.SCHEDULER,
        )

        # Initial state should be empty
        assert len(connector._mm_datas_need_loads) == 0

        # Update state for all 3 items
        for i in range(3):
            connector.update_state_after_alloc(mock_request_with_3_mm, index=i)

        # Check state updated for all 3
        assert len(connector._mm_datas_need_loads) == 3
        assert "img_hash_1" in connector._mm_datas_need_loads
        assert "img_hash_2" in connector._mm_datas_need_loads
        assert "img_hash_3" in connector._mm_datas_need_loads
        assert connector._mm_datas_need_loads["img_hash_1"] == 100
        assert connector._mm_datas_need_loads["img_hash_2"] == 150
        assert connector._mm_datas_need_loads["img_hash_3"] == 200

    def test_build_connector_meta_3_items(
        self, mock_vllm_config_producer, mock_request_with_3_mm
    ):
        """Test metadata building for 3 MM items."""
        connector = ECSharedStorageConnector(
            vllm_config=mock_vllm_config_producer,
            role=ECConnectorRole.SCHEDULER,
        )

        # Setup state for all 3 items
        for i in range(3):
            connector.update_state_after_alloc(mock_request_with_3_mm, index=i)

        # Build metadata
        scheduler_output = Mock(spec=SchedulerOutput)
        metadata = connector.build_connector_meta(scheduler_output)

        # Assert
        assert isinstance(metadata, ECSharedStorageConnectorMetadata)
        assert len(metadata.mm_datas) == 3
        assert metadata.mm_datas[0].mm_hash == "img_hash_1"
        assert metadata.mm_datas[0].num_token == 100
        assert metadata.mm_datas[1].mm_hash == "img_hash_2"
        assert metadata.mm_datas[1].num_token == 150
        assert metadata.mm_datas[2].mm_hash == "img_hash_3"
        assert metadata.mm_datas[2].num_token == 200

        # State should be cleared after building
        assert len(connector._mm_datas_need_loads) == 0

    def test_build_connector_meta_empty(self, mock_vllm_config_producer):
        """Test metadata building with empty state."""
        connector = ECSharedStorageConnector(
            vllm_config=mock_vllm_config_producer,
            role=ECConnectorRole.SCHEDULER,
        )

        scheduler_output = Mock(spec=SchedulerOutput)
        metadata = connector.build_connector_meta(scheduler_output)

        assert isinstance(metadata, ECSharedStorageConnectorMetadata)
        assert len(metadata.mm_datas) == 0

    def test_state_cleared_after_metadata_build(
        self, mock_vllm_config_producer, mock_request_with_3_mm
    ):
        """Test that state is properly cleared after building metadata."""
        connector = ECSharedStorageConnector(
            vllm_config=mock_vllm_config_producer,
            role=ECConnectorRole.SCHEDULER,
        )

        # Add state
        for i in range(3):
            connector.update_state_after_alloc(mock_request_with_3_mm, index=i)
        assert len(connector._mm_datas_need_loads) == 3

        # Build metadata (should clear state)
        scheduler_output = Mock(spec=SchedulerOutput)
        connector.build_connector_meta(scheduler_output)

        # State should be empty
        assert len(connector._mm_datas_need_loads) == 0

        # Build again should return empty metadata
        metadata2 = connector.build_connector_meta(scheduler_output)
        assert len(metadata2.mm_datas) == 0


class TestCacheSaving:
    """Test encoder cache saving (producer only)."""

    def test_save_caches_producer_3_items(
        self, mock_vllm_config_producer, mock_request_with_3_mm, temp_storage
    ):
        """Test cache saving as producer for 3 different MM items."""
        connector = ECSharedStorageConnector(
            vllm_config=mock_vllm_config_producer,
            role=ECConnectorRole.WORKER,
        )

        # Create and save 3 different caches
        mm_hashes = [f.identifier for f in mock_request_with_3_mm.mm_features]
        encoder_cache: dict[str, torch.Tensor] = {}

        for mm_hash in mm_hashes:
            encoder_cache[mm_hash] = torch.randn(10, 768)
            connector.save_caches(encoder_cache, mm_hash)

        # Verify all files exist using has_caches
        result = connector.has_caches(mock_request_with_3_mm)
        assert all(result), f"Not all caches were saved: {result}"

        # Verify each file's content
        for mm_hash in mm_hashes:
            filename = connector._generate_filename_debug(mm_hash)
            loaded = safetensors.torch.load_file(filename)
            assert "ec_cache" in loaded
            assert torch.allclose(loaded["ec_cache"], encoder_cache[mm_hash].cpu())

    def test_save_caches_consumer_skips(self, mock_vllm_config_consumer):
        """Test cache saving is skipped for consumer."""
        connector = ECSharedStorageConnector(
            vllm_config=mock_vllm_config_consumer,
            role=ECConnectorRole.WORKER,
        )

        mm_hash = "test_hash_consumer"
        encoder_cache = {mm_hash: torch.randn(10, 768)}

        # Save should not raise but also not create file
        connector.save_caches(encoder_cache, mm_hash)

        # Verify file doesn't exist using has_caches
        mock_request = MockRequest("req_consumer", [mm_hash], [10])
        result = connector.has_caches(mock_request)
        assert not result[0], "Consumer should not save caches"


class TestCacheLoading:
    """Test encoder cache loading (consumer)."""

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
    def test_start_load_caches_consumer_3_items(
        self,
        mock_vllm_config_producer,
        mock_vllm_config_consumer,
        mock_request_with_3_mm,
        temp_storage,
    ):
        """Test consumer loads 3 caches from storage."""
        # First, create producer to save caches
        producer = ECSharedStorageConnector(
            vllm_config=mock_vllm_config_producer,
            role=ECConnectorRole.WORKER,
        )

        # Producer saves 3 caches
        mm_hashes = [f.identifier for f in mock_request_with_3_mm.mm_features]
        saved_caches = {}
        for mm_hash in mm_hashes:
            saved_caches[mm_hash] = torch.randn(10, 768)
            producer.save_caches(saved_caches, mm_hash)

        # Now consumer loads
        consumer = ECSharedStorageConnector(
            vllm_config=mock_vllm_config_consumer,
            role=ECConnectorRole.WORKER,
        )

        # Setup metadata for all 3
        metadata = ECSharedStorageConnectorMetadata()
        for mm_hash in mm_hashes:
            metadata.add_mm_data(MMMeta.make_meta(mm_hash, 100))
        consumer.bind_connector_metadata(metadata)

        # Load
        encoder_cache: dict[str, torch.Tensor] = {}
        consumer.start_load_caches(encoder_cache=encoder_cache)

        # Verify all 3 loaded
        assert len(encoder_cache) == 3
        for mm_hash in mm_hashes:
            assert mm_hash in encoder_cache, f"{mm_hash} missing in encoder_cache"
            assert encoder_cache[mm_hash].is_cuda, (
                f"{mm_hash} cache is in {encoder_cache[mm_hash].device}"
            )
            assert torch.allclose(
                encoder_cache[mm_hash].cpu(), saved_caches[mm_hash]
            ), f"{mm_hash} cache saved and loaded tensors are not the same"

    def test_start_load_caches_skip_existing(
        self, mock_vllm_config_producer, mock_vllm_config_consumer, temp_storage
    ):
        """Test cache loading skips already cached items."""
        # Setup: producer saves cache
        producer = ECSharedStorageConnector(
            vllm_config=mock_vllm_config_producer,
            role=ECConnectorRole.WORKER,
        )

        mm_hash = "existing_hash"
        saved_cache = torch.randn(10, 768)
        producer.save_caches({mm_hash: saved_cache}, mm_hash)

        # Consumer setup
        consumer = ECSharedStorageConnector(
            vllm_config=mock_vllm_config_consumer,
            role=ECConnectorRole.WORKER,
        )

        metadata = ECSharedStorageConnectorMetadata()
        metadata.add_mm_data(MMMeta.make_meta(mm_hash, 100))
        consumer.bind_connector_metadata(metadata)

        # Pre-populate encoder_cache with different value
        existing_cache = torch.randn(5, 512)
|
||||
encoder_cache = {mm_hash: existing_cache}
|
||||
|
||||
# Load (should skip since already exists)
|
||||
with patch("safetensors.torch.load_file") as mock_load:
|
||||
consumer.start_load_caches(encoder_cache=encoder_cache)
|
||||
# Should not call load_file since cache exists
|
||||
mock_load.assert_not_called()
|
||||
|
||||
# Verify original cache unchanged
|
||||
assert torch.equal(encoder_cache[mm_hash], existing_cache)
|
||||
|
||||
def test_start_load_caches_empty_metadata(self, mock_vllm_config_consumer):
|
||||
"""Test loading with empty metadata does nothing."""
|
||||
consumer = ECSharedStorageConnector(
|
||||
vllm_config=mock_vllm_config_consumer,
|
||||
role=ECConnectorRole.WORKER,
|
||||
)
|
||||
|
||||
# Setup empty metadata
|
||||
metadata = ECSharedStorageConnectorMetadata()
|
||||
consumer.bind_connector_metadata(metadata)
|
||||
|
||||
# Load (should not raise)
|
||||
encoder_cache: dict[str, torch.Tensor] = {}
|
||||
consumer.start_load_caches(encoder_cache=encoder_cache)
|
||||
|
||||
# Cache should remain empty
|
||||
assert len(encoder_cache) == 0
|
||||
|
||||
|
||||
class TestFilenameGeneration:
|
||||
"""Test filename and path generation."""
|
||||
|
||||
def test_generate_foldername(self, mock_vllm_config_producer, temp_storage):
|
||||
"""Test folder name generation."""
|
||||
connector = ECSharedStorageConnector(
|
||||
vllm_config=mock_vllm_config_producer,
|
||||
role=ECConnectorRole.WORKER,
|
||||
)
|
||||
|
||||
mm_hash = "test_folder_hash"
|
||||
folder = connector._generate_foldername_debug(mm_hash)
|
||||
|
||||
assert folder == os.path.join(temp_storage, mm_hash)
|
||||
assert os.path.isdir(folder) # Should be created
|
||||
|
||||
def test_generate_filename(self, mock_vllm_config_producer, temp_storage):
|
||||
"""Test filename generation."""
|
||||
connector = ECSharedStorageConnector(
|
||||
vllm_config=mock_vllm_config_producer,
|
||||
role=ECConnectorRole.WORKER,
|
||||
)
|
||||
|
||||
mm_hash = "test_file_hash"
|
||||
filename = connector._generate_filename_debug(mm_hash)
|
||||
|
||||
expected = os.path.join(temp_storage, mm_hash, "encoder_cache.safetensors")
|
||||
assert filename == expected
|
||||
assert os.path.isdir(os.path.dirname(filename)) # Folder created
|
||||
|
||||
def test_generate_filename_consistency(self, mock_vllm_config_producer):
|
||||
"""Test filename generation is consistent."""
|
||||
connector = ECSharedStorageConnector(
|
||||
vllm_config=mock_vllm_config_producer,
|
||||
role=ECConnectorRole.WORKER,
|
||||
)
|
||||
|
||||
mm_hash = "consistency_hash"
|
||||
filename1 = connector._generate_filename_debug(mm_hash)
|
||||
filename2 = connector._generate_filename_debug(mm_hash)
|
||||
|
||||
assert filename1 == filename2
|
||||
|
||||
|
||||
class TestMetadataBindingLifecycle:
|
||||
"""Test metadata binding and clearing lifecycle."""
|
||||
|
||||
def test_bind_connector_metadata(self, mock_vllm_config_consumer):
|
||||
"""Test binding connector metadata."""
|
||||
connector = ECSharedStorageConnector(
|
||||
vllm_config=mock_vllm_config_consumer,
|
||||
role=ECConnectorRole.WORKER,
|
||||
)
|
||||
|
||||
metadata = ECSharedStorageConnectorMetadata()
|
||||
metadata.add_mm_data(MMMeta.make_meta("hash_1", 100))
|
||||
|
||||
connector.bind_connector_metadata(metadata)
|
||||
|
||||
assert connector._connector_metadata is metadata
|
||||
|
||||
def test_clear_connector_metadata(self, mock_vllm_config_consumer):
|
||||
"""Test clearing connector metadata."""
|
||||
connector = ECSharedStorageConnector(
|
||||
vllm_config=mock_vllm_config_consumer,
|
||||
role=ECConnectorRole.WORKER,
|
||||
)
|
||||
|
||||
metadata = ECSharedStorageConnectorMetadata()
|
||||
connector.bind_connector_metadata(metadata)
|
||||
|
||||
connector.clear_connector_metadata()
|
||||
|
||||
assert connector._connector_metadata is None
|
||||
|
||||
def test_get_connector_metadata(self, mock_vllm_config_consumer):
|
||||
"""Test getting connector metadata."""
|
||||
connector = ECSharedStorageConnector(
|
||||
vllm_config=mock_vllm_config_consumer,
|
||||
role=ECConnectorRole.WORKER,
|
||||
)
|
||||
|
||||
metadata = ECSharedStorageConnectorMetadata()
|
||||
connector.bind_connector_metadata(metadata)
|
||||
|
||||
retrieved = connector._get_connector_metadata()
|
||||
|
||||
assert retrieved is metadata
|
||||
|
||||
def test_get_connector_metadata_not_set(self, mock_vllm_config_consumer):
|
||||
"""Test getting metadata when not set raises."""
|
||||
connector = ECSharedStorageConnector(
|
||||
vllm_config=mock_vllm_config_consumer,
|
||||
role=ECConnectorRole.WORKER,
|
||||
)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
connector._get_connector_metadata()
|
||||
|
||||
|
||||
class TestEdgeCases:
|
||||
"""Test edge cases and error handling."""
|
||||
|
||||
def test_save_empty_cache(self, mock_vllm_config_producer):
|
||||
"""Test saving empty tensor."""
|
||||
connector = ECSharedStorageConnector(
|
||||
vllm_config=mock_vllm_config_producer,
|
||||
role=ECConnectorRole.WORKER,
|
||||
)
|
||||
|
||||
mm_hash = "empty_hash"
|
||||
encoder_cache = {mm_hash: torch.empty(0)}
|
||||
|
||||
# Should not raise
|
||||
connector.save_caches(encoder_cache, mm_hash)
|
||||
|
||||
def test_load_nonexistent_cache(self, mock_vllm_config_consumer):
|
||||
"""Test loading cache that doesn't exist raises error."""
|
||||
connector = ECSharedStorageConnector(
|
||||
vllm_config=mock_vllm_config_consumer,
|
||||
role=ECConnectorRole.WORKER,
|
||||
)
|
||||
|
||||
metadata = ECSharedStorageConnectorMetadata()
|
||||
metadata.add_mm_data(MMMeta.make_meta("nonexistent_hash", 100))
|
||||
connector.bind_connector_metadata(metadata)
|
||||
|
||||
encoder_cache: dict[str, torch.Tensor] = {}
|
||||
|
||||
# Should raise FileNotFoundError
|
||||
with pytest.raises(FileNotFoundError):
|
||||
connector.start_load_caches(encoder_cache=encoder_cache)
|
||||
|
||||
def test_has_caches_empty_request(self, mock_vllm_config_producer):
|
||||
"""Test has_caches with request that has no MM data."""
|
||||
connector = ECSharedStorageConnector(
|
||||
vllm_config=mock_vllm_config_producer,
|
||||
role=ECConnectorRole.SCHEDULER,
|
||||
)
|
||||
|
||||
mock_request = MockRequest("req_empty", [], [])
|
||||
|
||||
result = connector.has_caches(mock_request)
|
||||
|
||||
assert len(result) == 0
|
||||
assert result == []
|
||||
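The edge-case tests above pin down the `has_caches` contract: one boolean per multimodal item, in request order, and an empty list for a request with no multimodal data. A minimal in-memory sketch of that contract (the `InMemoryECStore` name is illustrative, not part of the PR; the real connector checks files on shared storage):

```python
class InMemoryECStore:
    """Hypothetical stand-in for a shared-storage encoder-cache backend."""

    def __init__(self) -> None:
        self._cache: dict[str, list[float]] = {}

    def save(self, mm_hash: str, embedding: list[float]) -> None:
        self._cache[mm_hash] = embedding

    def has_caches(self, mm_hashes: list[str]) -> list[bool]:
        # One flag per multimodal item, preserving request order.
        return [h in self._cache for h in mm_hashes]


store = InMemoryECStore()
store.save("img_a", [0.1, 0.2])
print(store.has_caches(["img_a", "img_b"]))  # [True, False]
print(store.has_caches([]))  # [] for a request with no MM data
```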
@@ -10,6 +10,14 @@ import pytest
from transformers import AutoTokenizer

from vllm import SamplingParams
from vllm.config import (
    CacheConfig,
    ECTransferConfig,
    KVTransferConfig,
    ModelConfig,
    SchedulerConfig,
    VllmConfig,
)
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_default_torch_num_threads
@@ -450,3 +458,141 @@ def test_engine_core_invalid_request_id_type():
    engine_core.add_request(*engine_core.preprocess_add_request(valid_request))
    assert len(engine_core.scheduler.waiting) == 1
    assert len(engine_core.scheduler.running) == 0


@create_new_process_for_each_test()
@pytest.mark.parametrize(
    ("ec_role", "gpu_memory_utilization", "enable_prefix_caching"),
    [
        ("ec_producer", 0.01, False),
        # NOTE: ec_producer never allows prefix caching
        ("ec_consumer", 0.7, True),
        ("ec_consumer", 0.7, False),
    ],
)
@pytest.mark.parametrize("use_kv_connector", [False, True])
def test_encoder_instance_zero_kv_cache(
    ec_role: str,
    gpu_memory_utilization: float,
    enable_prefix_caching: bool,
    use_kv_connector: bool,
):
    """EPD (Encoder-Prefill-Decode) encoder-cache-specific test.

    This test verifies that an encoder-only instance initializes with 0 KV cache
    blocks. Under EPD disagg mode, encoder instances (EC producer role) only
    execute the vision encoder, so they don't need KV cache for text generation.
    """
    # Form the vLLM config
    scheduler_config = SchedulerConfig(
        max_num_seqs=10,
        max_num_batched_tokens=512,
        max_model_len=512,
        disable_hybrid_kv_cache_manager=True,
    )
    model_config = ModelConfig(
        model="llava-hf/llava-1.5-7b-hf",  # Multimodal model
        enforce_eager=True,
        trust_remote_code=True,
        dtype="float16",
        seed=42,
    )
    cache_config = CacheConfig(
        block_size=16,
        gpu_memory_utilization=gpu_memory_utilization,
        swap_space=0,
        cache_dtype="auto",
        enable_prefix_caching=enable_prefix_caching,
    )
    kv_transfer_config = (
        KVTransferConfig(
            kv_connector="SharedStorageConnector",
            kv_role="kv_both",
            kv_connector_extra_config={"shared_storage_path": "local_storage"},
        )
        if use_kv_connector
        else None
    )
    ec_transfer_config = ECTransferConfig(
        ec_connector="ECSharedStorageConnector",
        ec_role=ec_role,
        ec_connector_extra_config={"shared_storage_path": "/tmp/ec_test_encoder"},
    )

    vllm_config = VllmConfig(
        model_config=model_config,
        cache_config=cache_config,
        scheduler_config=scheduler_config,
        kv_transfer_config=kv_transfer_config,
        ec_transfer_config=ec_transfer_config,
    )

    executor_class = Executor.get_class(vllm_config)
    print(f"executor_class: {executor_class}")

    with set_default_torch_num_threads(1):
        engine_core = EngineCore(
            vllm_config=vllm_config, executor_class=executor_class, log_stats=True
        )

    # Check the encoder cache manager exists
    assert engine_core.scheduler.encoder_cache_manager is not None, (
        "encoder_cache_manager should exist"
    )

    if ec_role == "ec_producer":
        # Check 1: num_blocks should be 1
        # NOTE: num_blocks=1 as BlockPool always needs a null_block.
        kv_cache_config = engine_core.scheduler.kv_cache_manager.kv_cache_config
        print(f"kv_cache_config: {kv_cache_config}")
        assert kv_cache_config.num_blocks == 1, (
            f"ec_producer should only have 1 KV block, "
            f"got {kv_cache_config.num_blocks}"
        )

        # Check 2: kv_cache_groups should be empty
        assert len(kv_cache_config.kv_cache_groups) == 0, (
            f"ec_producer should have 0 KV cache groups, "
            f"got {len(kv_cache_config.kv_cache_groups)}"
        )

        # Check 3: kv_cache_tensors should be empty
        assert len(kv_cache_config.kv_cache_tensors) == 0, (
            f"Encoder instance should have 0 KV cache tensors, "
            f"got {len(kv_cache_config.kv_cache_tensors)}"
        )

        # Check 4: Verify the EC connector is initialized and is a producer
        assert engine_core.scheduler.ec_connector is not None, (
            "Encoder instance should have an EC connector"
        )
        assert engine_core.scheduler.ec_connector.is_producer, (
            "Encoder instance EC connector should be a producer"
        )

        # Check 5: Verify chunked prefill is disabled
        assert not vllm_config.scheduler_config.chunked_prefill_enabled, (
            "Encoder instance should disable chunked prefill (no KV cache)"
        )

    elif ec_role == "ec_consumer":
        # Check 1: num_blocks should be > 1
        kv_cache_config = engine_core.scheduler.kv_cache_manager.kv_cache_config
        print(f"kv_cache_config: {kv_cache_config}")
        assert kv_cache_config.num_blocks > 1, (
            f"ec_consumer should have >1 KV blocks, got {kv_cache_config.num_blocks}"
        )

        # Check 2: kv_cache_groups should NOT be empty
        assert len(kv_cache_config.kv_cache_groups) > 0, (
            f"ec_consumer should have KV cache groups, "
            f"got {len(kv_cache_config.kv_cache_groups)}"
        )

        # Check 3: Verify the EC connector is a consumer
        assert engine_core.scheduler.ec_connector is not None, (
            "Consumer instance should have an EC connector"
        )
        assert not engine_core.scheduler.ec_connector.is_producer, (
            "Consumer instance EC connector should be a consumer"
        )
@@ -9,6 +9,7 @@ from vllm.config.compilation import (
    PassConfig,
)
from vllm.config.device import DeviceConfig
from vllm.config.ec_transfer import ECTransferConfig
from vllm.config.kv_events import KVEventsConfig
from vllm.config.kv_transfer import KVTransferConfig
from vllm.config.load import LoadConfig
@@ -54,6 +55,8 @@ __all__ = [
    "PassConfig",
    # From vllm.config.device
    "DeviceConfig",
    # From vllm.config.ec_transfer
    "ECTransferConfig",
    # From vllm.config.kv_events
    "KVEventsConfig",
    # From vllm.config.kv_transfer
vllm/config/ec_transfer.py (new file, 110 lines)
@@ -0,0 +1,110 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib
import uuid
from dataclasses import field
from typing import Any, Literal, get_args

from pydantic.dataclasses import dataclass

from vllm.config.utils import config

ECProducer = Literal["ec_producer"]
ECConsumer = Literal["ec_consumer"]
ECRole = Literal[ECProducer, ECConsumer]


@config
@dataclass
class ECTransferConfig:
    """Configuration for distributed EC cache transfer."""

    ec_connector: str | None = None
    """The EC connector for vLLM to transmit EC caches between vLLM instances.
    """

    engine_id: str | None = None
    """The engine id for EC transfers."""

    ec_buffer_device: str | None = "cuda"
    """The device used by the EC connector to buffer the EC cache.
    Currently only 'cuda' is supported."""

    ec_buffer_size: float = 1e9
    """The buffer size for TorchDistributedConnector. Measured in number of
    bytes. Recommended value: 1e9 (about 1GB)."""

    ec_role: ECRole | None = None
    """Whether this vLLM instance produces or consumes the EC cache. Choices
    are 'ec_producer' and 'ec_consumer'."""

    ec_rank: int | None = None
    """The rank of this vLLM instance in the EC cache transfer. Typical value:
    0 for the encoder, 1 for the PD instance.
    Currently only 1P1D is supported."""

    ec_parallel_size: int = 1
    """The number of parallel instances for EC cache transfer. For
    PyNcclConnector, this should be 2."""

    ec_ip: str = "127.0.0.1"
    """The EC connector IP, used to build the distributed connection."""

    ec_port: int = 14579
    """The EC connector port, used to build the distributed connection."""

    ec_connector_extra_config: dict[str, Any] = field(default_factory=dict)
    """Any extra config that the connector may need."""

    ec_connector_module_path: str | None = None
    """The Python module path to dynamically load the EC connector from.
    Only supported in V1."""

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        # No factors to consider:
        # this config does not affect the computation graph.
        factors: list[Any] = []
        hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
        return hash_str

    def __post_init__(self) -> None:
        if self.engine_id is None:
            self.engine_id = str(uuid.uuid4())

        if self.ec_role is not None and self.ec_role not in get_args(ECRole):
            raise ValueError(
                f"Unsupported ec_role: {self.ec_role}. "
                f"Supported roles are {get_args(ECRole)}"
            )

        if self.ec_connector is not None and self.ec_role is None:
            raise ValueError(
                "Please specify ec_role when ec_connector "
                f"is set, supported roles are {get_args(ECRole)}"
            )

    @property
    def is_ec_transfer_instance(self) -> bool:
        return self.ec_connector is not None and self.ec_role in get_args(ECRole)

    @property
    def is_ec_producer(self) -> bool:
        return self.ec_connector is not None and self.ec_role in get_args(ECProducer)

    @property
    def is_ec_consumer(self) -> bool:
        return self.ec_connector is not None and self.ec_role in get_args(ECConsumer)

    def get_from_extra_config(self, key, default) -> Any:
        return self.ec_connector_extra_config.get(key, default)
@@ -28,6 +28,7 @@ from vllm.utils import random_uuid
from .cache import CacheConfig
from .compilation import CompilationConfig, CompilationMode, CUDAGraphMode
from .device import DeviceConfig
from .ec_transfer import ECTransferConfig
from .kv_events import KVEventsConfig
from .kv_transfer import KVTransferConfig
from .load import LoadConfig
@@ -103,6 +104,8 @@ class VllmConfig:
    """The configurations for distributed KV cache transfer."""
    kv_events_config: KVEventsConfig | None = None
    """The configurations for event publishing."""
    ec_transfer_config: ECTransferConfig | None = None
    """The configurations for distributed EC cache transfer."""
    # some opaque config, only used to provide additional information
    # for the hash computation, mainly used for testing, debugging or out of
    # tree config registration.
@@ -183,6 +186,10 @@ class VllmConfig:
            vllm_factors.append(self.kv_transfer_config.compute_hash())
        else:
            vllm_factors.append("None")
        if self.ec_transfer_config:
            vllm_factors.append(self.ec_transfer_config.compute_hash())
        else:
            vllm_factors.append("None")
        if self.additional_config:
            if isinstance(additional_config := self.additional_config, dict):
                additional_config_hash = hashlib.md5(
vllm/distributed/ec_transfer/__init__.py (new file, 14 lines)
@@ -0,0 +1,14 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm.distributed.ec_transfer.ec_transfer_state import (
    ensure_ec_transfer_initialized,
    get_ec_transfer,
    has_ec_transfer,
)

__all__ = [
    "get_ec_transfer",
    "ensure_ec_transfer_initialized",
    "has_ec_transfer",
]
vllm/distributed/ec_transfer/ec_connector/base.py (new file, 247 lines)
@@ -0,0 +1,247 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
ECConnectorBase class for distributed encoder cache &
P2P encoder cache communication in V1.

The class provides the following primitives:
    Scheduler-side: runs in the scheduler, binds metadata, which
    is used by the worker-side to load/save the encoder cache.
        has_caches() - check whether encoder caches exist for a request
        update_state_after_alloc() - update ECConnector state after
            allocation; this decides whether to load the cache or not
        request_finished() - called when a request is finished,
            frees the cache associated with the request

    Worker-side: runs in each worker, loads/saves the encoder cache to/from
    the connector based on the metadata.
        start_load_caches() - starts loading all encoder caches (maybe async)
        save_caches() - saves the encoder cache for a multimodal item

        get_finished() - called with ids of finished requests, returns
            ids of requests that have completed async sending/recving.
"""

import enum
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any

import torch

from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import ECConnectorOutput

if TYPE_CHECKING:
    from vllm.config import VllmConfig
    from vllm.v1.request import Request

logger = init_logger(__name__)


class ECConnectorRole(enum.Enum):
    # Connector running in the scheduler process
    SCHEDULER = 0

    # Connector running in the worker process
    WORKER = 1


class ECConnectorMetadata(ABC):  # noqa: B024
    """
    Abstract metadata used to communicate between the
    scheduler ECConnector and the worker ECConnector.
    """

    pass


class ECConnectorBase(ABC):
    def __init__(self, vllm_config: "VllmConfig", role: ECConnectorRole):
        self._connector_metadata: ECConnectorMetadata | None = None
        self._vllm_config = vllm_config
        self._role = role
        if vllm_config.ec_transfer_config is not None:
            self._is_producer = vllm_config.ec_transfer_config.is_ec_producer
        else:
            raise ValueError("ec_transfer_config must be set for ECConnectorBase")

    @property
    def role(self) -> ECConnectorRole:
        return self._role

    @property
    def is_producer(self) -> bool:
        return self._is_producer

    # ==============================
    # Worker-side methods
    # ==============================

    def bind_connector_metadata(self, connector_metadata: ECConnectorMetadata) -> None:
        """Set the connector metadata from the scheduler.

        This function should be called by the model runner every time
        before the model execution. The metadata will be used for runtime
        EC cache loading.

        Args:
            connector_metadata (ECConnectorMetadata): the connector metadata.
        """
        self._connector_metadata = connector_metadata

    def clear_connector_metadata(self) -> None:
        """Clear the connector metadata.

        This function should be called by the model runner every time
        after the model execution.
        """
        self._connector_metadata = None

    def _get_connector_metadata(self) -> ECConnectorMetadata:
        """Get the connector metadata.

        This function should only be called inside the connector.

        Returns:
            ECConnectorMetadata: the connector metadata.
        """

        # Should only be called while set to valid metadata.
        assert self._connector_metadata is not None
        return self._connector_metadata

    def register_caches(
        self,
        ec_caches: dict[str, torch.Tensor],
    ):
        """
        Initialize with the EC caches.
        Args:
            ec_caches: dictionary of encoder caches
        """
        # TODO: Implement this later for the P2P feature
        return

    @abstractmethod
    def start_load_caches(
        self, encoder_cache: dict[str, torch.Tensor], **kwargs
    ) -> None:
        """
        Start loading the cache from the connector into vLLM's encoder cache.

        This method loads the encoder cache based on metadata provided by the
        scheduler. It is called before `_gather_mm_embeddings` for the EC
        connector. For EC, the `encoder_cache` and `mm_hash` are stored in
        `kwargs`.

        Args:
            encoder_cache (dict[str, torch.Tensor]): A dictionary mapping multimodal
                data hashes (`mm_hash`) to encoder cache tensors.
            kwargs (dict): Additional keyword arguments for the connector.
        """
        pass

    @abstractmethod
    def save_caches(
        self, encoder_cache: dict[str, torch.Tensor], mm_hash: str, **kwargs
    ) -> None:
        """
        Save the encoder cache to the connector.

        This method saves the encoder cache from the worker's local storage
        to shared storage or another external connector.

        Args:
            encoder_cache (dict[str, torch.Tensor]): A dictionary mapping multimodal
                data hashes (`mm_hash`) to encoder cache tensors.
            mm_hash (str): The hash of the multimodal data whose cache is being saved.
            kwargs (dict): Additional keyword arguments for the connector.
        """
        pass

    def get_finished(
        self, finished_req_ids: set[str]
    ) -> tuple[set[str] | None, set[str] | None]:
        """
        Notifies the worker-side connector of the ids of requests that have
        finished generating tokens on the worker.
        The scheduler process (via the Executors) will use this output
        to track which workers are done.

        Returns:
            ids of requests that have finished asynchronous transfer
            (requests that previously returned True from request_finished()),
            as a tuple of (sending/saving ids, recving/loading ids).
            The finished saves/sends req ids must belong to a set provided in a
            call to this method (this call or a prior one).
        """
        return None, None

    # ==============================
    # Scheduler-side methods
    # ==============================

    @abstractmethod
    def has_caches(
        self,
        request: "Request",
    ) -> list[bool]:
        """
        Check whether an encoder cache exists for each mm data item of the request.

        Args:
            request (Request): the request object.

        Returns:
            A list of bools where the i-th value is True if a cache exists for
            the i-th mm_data item of the request.
        """
        pass

    @abstractmethod
    def update_state_after_alloc(self, request: "Request", index: int):
        """
        Update ECConnector state after cache allocation; this decides whether
        to load the cache for the request.

        Args:
            request (Request): the request object.
        """
        pass

    @abstractmethod
    def build_connector_meta(
        self, scheduler_output: SchedulerOutput
    ) -> ECConnectorMetadata:
        """
        Build the connector metadata for this step.

        This function should NOT modify fields in the scheduler_output.
        Also, calling this function will reset the state of the connector.

        Args:
            scheduler_output (SchedulerOutput): the scheduler output object.
        """
        pass

    def update_connector_output(self, connector_output: ECConnectorOutput):
        """
        Update ECConnector state from the worker-side connectors' output.

        Args:
            connector_output (ECConnectorOutput): the worker-side
                connectors' output.
        """
        return

    def request_finished(
        self, request: "Request"
    ) -> tuple[bool, dict[str, Any] | None]:
        """
        Called when a request has finished, before its encoder cache is freed.

        Returns:
            True if the request is being saved/sent asynchronously and the cache
            should not be freed until the request_id is returned from
            get_finished().
        """
        return False, None
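The split above (scheduler side decides, worker side moves tensors) can be illustrated with a stripped-down, dependency-free version of the interface. Names here (`ToyECConnector`, `DictConnector`, `bind`) are illustrative only; the real base class additionally carries roles, config, and async bookkeeping:

```python
from abc import ABC, abstractmethod


class ToyECConnector(ABC):
    """Dependency-free caricature of ECConnectorBase."""

    @abstractmethod
    def has_caches(self, mm_hashes: list[str]) -> list[bool]: ...

    @abstractmethod
    def save_caches(self, cache: dict[str, bytes], mm_hash: str) -> None: ...

    @abstractmethod
    def start_load_caches(self, encoder_cache: dict[str, bytes]) -> None: ...


class DictConnector(ToyECConnector):
    """Shared-storage stand-in backed by a plain dict."""

    def __init__(self) -> None:
        self._storage: dict[str, bytes] = {}
        self._pending: list[str] = []  # plays the role of bound metadata

    def has_caches(self, mm_hashes: list[str]) -> list[bool]:
        return [h in self._storage for h in mm_hashes]

    def save_caches(self, cache: dict[str, bytes], mm_hash: str) -> None:
        self._storage[mm_hash] = cache[mm_hash]

    def bind(self, mm_hashes: list[str]) -> None:
        # Scheduler side would build and bind this as connector metadata.
        self._pending = mm_hashes

    def start_load_caches(self, encoder_cache: dict[str, bytes]) -> None:
        for h in self._pending:
            if h not in encoder_cache:  # skip items already resident
                encoder_cache[h] = self._storage[h]


conn = DictConnector()
conn.save_caches({"img": b"\x01\x02"}, "img")
conn.bind(["img"])
cache: dict[str, bytes] = {}
conn.start_load_caches(cache)
print(cache["img"])  # b'\x01\x02'
```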
vllm/distributed/ec_transfer/ec_connector/factory.py (new file, 88 lines)
@ -0,0 +1,88 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import importlib
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
# yapf: disable
from vllm.distributed.ec_transfer.ec_connector.base import (
    ECConnectorBase,
    ECConnectorRole,
)
from vllm.logger import init_logger

# yapf: enable

if TYPE_CHECKING:
    from vllm.config import ECTransferConfig, VllmConfig

logger = init_logger(__name__)


class ECConnectorFactory:
    _registry: dict[str, Callable[[], type[ECConnectorBase]]] = {}

    @classmethod
    def register_connector(cls, name: str, module_path: str, class_name: str) -> None:
        """Register a connector with a lazy-loading module and class name."""
        if name in cls._registry:
            raise ValueError(f"Connector '{name}' is already registered.")

        def loader() -> type[ECConnectorBase]:
            module = importlib.import_module(module_path)
            return getattr(module, class_name)

        cls._registry[name] = loader

    @classmethod
    def create_connector(
        cls,
        config: "VllmConfig",
        role: ECConnectorRole,
    ) -> ECConnectorBase:
        ec_transfer_config = config.ec_transfer_config
        if ec_transfer_config is None:
            raise ValueError("ec_transfer_config must be set to create a connector")
        connector_cls = cls.get_connector_class(ec_transfer_config)
        logger.info(
            "Creating connector with name: %s and engine_id: %s",
            connector_cls.__name__,
            ec_transfer_config.engine_id,
        )
        # The connector is explicitly separated into two roles.
        # Scheduler connector:
        #   - Co-located with the scheduler process
        #   - Should only be used inside the Scheduler class
        # Worker connector:
        #   - Co-located with the worker process
        return connector_cls(config, role)

    @classmethod
    def get_connector_class(
        cls, ec_transfer_config: "ECTransferConfig"
    ) -> type[ECConnectorBase]:
        """Get the connector class by name."""
        connector_name = ec_transfer_config.ec_connector
        if connector_name is None:
            raise ValueError("EC connector must not be None")
        elif connector_name in cls._registry:
            connector_cls = cls._registry[connector_name]()
        else:
            connector_module_path = ec_transfer_config.ec_connector_module_path
            if connector_module_path is None:
                raise ValueError(f"Unsupported connector type: {connector_name}")
            connector_module = importlib.import_module(connector_module_path)
            connector_cls = getattr(connector_module, connector_name)
        return connector_cls


# Register various connectors here.
# The registration should not be done in each individual file, as we want to
# load only the files corresponding to the current connector.

ECConnectorFactory.register_connector(
    "ECSharedStorageConnector",
    "vllm.distributed.ec_transfer.ec_connector.shared_storage_connector",
    "ECSharedStorageConnector",
)
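The factory above instantiates connector classes lazily: registration stores only a loader closure, and the module is imported the first time the connector is actually created. A minimal standalone sketch of the same pattern (the registry and function names here are illustrative, not vLLM APIs; `collections.OrderedDict` stands in for a connector class):

```python
import importlib
from collections.abc import Callable

# Registry maps a name to a zero-argument loader that imports on demand.
_registry: dict[str, Callable[[], type]] = {}


def register(name: str, module_path: str, class_name: str) -> None:
    if name in _registry:
        raise ValueError(f"'{name}' is already registered.")

    def loader() -> type:
        # The module is imported only when the loader is first called,
        # so registering many backends does not import all of them.
        module = importlib.import_module(module_path)
        return getattr(module, class_name)

    _registry[name] = loader


def create(name: str, *args, **kwargs):
    if name not in _registry:
        raise ValueError(f"Unsupported type: {name}")
    cls = _registry[name]()  # import happens here
    return cls(*args, **kwargs)


# Demo with a stdlib class standing in for a connector implementation.
register("ordered", "collections", "OrderedDict")
d = create("ordered", [("a", 1)])
print(type(d).__name__)  # OrderedDict
```

The same shape lets vLLM register every known connector at import time while paying the import cost only for the one selected by `--ec-transfer-config`.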
@@ -0,0 +1,201 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from dataclasses import dataclass
from typing import TYPE_CHECKING

import safetensors

from vllm.config import VllmConfig
from vllm.distributed.ec_transfer.ec_connector.base import (
    ECConnectorBase,
    ECConnectorMetadata,
    ECConnectorRole,
)
from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput

if TYPE_CHECKING:
    from vllm.v1.request import Request

logger = init_logger(__name__)


@dataclass
class MMMeta:
    mm_hash: str
    num_token: int

    @staticmethod
    def make_meta(mm_hash, num_token) -> "MMMeta":
        return MMMeta(mm_hash=mm_hash, num_token=num_token)


@dataclass
class ECSharedStorageConnectorMetadata(ECConnectorMetadata):
    mm_datas: list[MMMeta]

    def __init__(self):
        self.mm_datas = []

    def add_mm_data(self, mm_data: MMMeta):
        self.mm_datas.append(mm_data)


class ECSharedStorageConnector(ECConnectorBase):
    # NOTE: This is a simple debug implementation of the EC connector.
    # It saves / loads the EC cache to / from disk.

    def __init__(self, vllm_config: "VllmConfig", role: ECConnectorRole):
        super().__init__(vllm_config=vllm_config, role=role)
        # req_id -> index
        self._mm_datas_need_loads: dict[str, int] = {}
        transfer_config = vllm_config.ec_transfer_config
        if transfer_config is not None:
            self._storage_path = transfer_config.get_from_extra_config(
                "shared_storage_path", "/tmp"
            )
            logger.debug(transfer_config)
            logger.debug("Shared storage path is %s", self._storage_path)
        else:
            raise ValueError("ec_transfer_config must be set for ECConnectorBase")

    def start_load_caches(self, encoder_cache, **kwargs) -> None:
        """
        Start loading the cache from the connector into vLLM's encoder cache.

        This method loads the encoder cache based on metadata provided by the
        scheduler. It is called before `_gather_mm_embeddings` for the EC
        connector. For EC, the `encoder_cache` and `mm_hash` are stored in
        `kwargs`.

        Args:
            encoder_cache (dict[str, torch.Tensor]): A dictionary mapping
                multimodal data hashes (`mm_hash`) to encoder cache tensors.
            kwargs (dict): Additional keyword arguments for the connector.
        """
        # Get the metadata
        metadata: ECConnectorMetadata = self._get_connector_metadata()
        assert isinstance(metadata, ECSharedStorageConnectorMetadata)
        assert encoder_cache is not None
        if metadata is None:
            logger.warning(
                "In connector.start_load_caches, "
                "but the connector metadata is None"
            )
            return
        # Load the EC for each mm data
        for mm_data in metadata.mm_datas:
            if mm_data.mm_hash in encoder_cache:
                continue
            filename = self._generate_filename_debug(mm_data.mm_hash)
            ec_cache = safetensors.torch.load_file(filename)["ec_cache"].cuda()
            encoder_cache[mm_data.mm_hash] = ec_cache
            logger.debug("Successfully loaded encoder cache for hash %s", mm_data.mm_hash)

    def save_caches(self, encoder_cache, mm_hash, **kwargs) -> None:
        """
        Save the encoder cache to the connector.

        This method saves the encoder cache from the worker's local storage
        to shared storage or another external connector.

        Args:
            encoder_cache (dict[str, torch.Tensor]): A dictionary mapping
                multimodal data hashes (`mm_hash`) to encoder cache tensors.
            mm_hash (str): The hash of the multimodal data whose cache is
                being saved.
            kwargs (dict): Additional keyword arguments for the connector.
        """
        # Return if it is a PD instance
        if not self.is_producer:
            return
        filename = self._generate_filename_debug(mm_hash)
        ec_cache = encoder_cache[mm_hash]
        tensors = {"ec_cache": ec_cache.detach().cpu()}
        safetensors.torch.save_file(tensors, filename)
        logger.debug("Save cache successful for mm_hash %s", mm_hash)

    def has_caches(
        self,
        request: "Request",
    ) -> list[bool]:
        """
        Check whether a cache exists externally for each mm_data of a request.

        Args:
            request (Request): the request object.

        Returns:
            A list of bools indicating whether the i-th mm_data exists in the
            cache.
        """
        result = []
        for feature in request.mm_features:
            result.append(self._found_match_for_mm_data(feature.identifier))
        return result

    def update_state_after_alloc(
        self,
        request: "Request",
        index: int,
    ) -> None:
        """
        Update ECConnector state after encoder cache allocation.
        """
        mm_hash = request.mm_features[index].identifier
        num_encoder_token = request.get_num_encoder_tokens(index)
        # Insert mm_hash only if this block has not been recorded yet.
        self._mm_datas_need_loads[mm_hash] = num_encoder_token

    def build_connector_meta(
        self,
        scheduler_output: SchedulerOutput,
    ) -> ECConnectorMetadata:
        """Build the connector metadata for this step.

        This function should NOT modify any fields in the scheduler_output.
        Also, calling this function resets the state of the connector.
        This only builds metadata for mm_data loads.

        Args:
            scheduler_output (SchedulerOutput): the scheduler output object.
        """
        meta = ECSharedStorageConnectorMetadata()
        for mm_hash, num_encoder_token in self._mm_datas_need_loads.items():
            meta.add_mm_data(MMMeta.make_meta(mm_hash, num_encoder_token))
        self._mm_datas_need_loads.clear()
        return meta

    # ==============================
    # Helper functions
    # ==============================

    def _found_match_for_mm_data(self, mm_hash) -> bool:
        """Check if the cache is hit for the request."""
        filename = self._generate_filename_debug(mm_hash)
        return os.path.exists(filename)

    def _generate_foldername_debug(
        self,
        mm_hash: str,
        create_folder: bool = True,
    ) -> str:
        """
        Return the folder in which the cache for this mm_hash lives.
        If `create_folder` is True (the default), the directory is created
        recursively the first time it is needed.
        """
        foldername = os.path.join(self._storage_path, mm_hash)
        if create_folder:
            os.makedirs(foldername, exist_ok=True)
        return foldername

    def _generate_filename_debug(self, mm_hash: str) -> str:
        """
        Return the full path of the safetensors file for this mm_hash.
        Ensures the parent directory exists because
        `_generate_foldername_debug` is called with its default
        (`create_folder=True`).
        """
        foldername = self._generate_foldername_debug(mm_hash)  # folder auto-created
        return os.path.join(foldername, "encoder_cache.safetensors")
46
vllm/distributed/ec_transfer/ec_transfer_state.py
Normal file
@@ -0,0 +1,46 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING

from vllm import envs
from vllm.distributed.ec_transfer.ec_connector.base import (
    ECConnectorBase,
    ECConnectorRole,
)
from vllm.distributed.ec_transfer.ec_connector.factory import ECConnectorFactory

if TYPE_CHECKING:
    from vllm.config import VllmConfig

_EC_CONNECTOR_AGENT: ECConnectorBase | None = None


def get_ec_transfer() -> ECConnectorBase:
    assert _EC_CONNECTOR_AGENT is not None, "disaggregated EC cache is not initialized"
    return _EC_CONNECTOR_AGENT


def has_ec_transfer() -> bool:
    return _EC_CONNECTOR_AGENT is not None


def ensure_ec_transfer_initialized(vllm_config: "VllmConfig") -> None:
    """
    Initialize the EC cache connector.
    """

    global _EC_CONNECTOR_AGENT

    if vllm_config.ec_transfer_config is None:
        return

    if (
        vllm_config.ec_transfer_config.is_ec_transfer_instance
        and _EC_CONNECTOR_AGENT is None
    ):
        if envs.VLLM_USE_V1:
            _EC_CONNECTOR_AGENT = ECConnectorFactory.create_connector(
                config=vllm_config, role=ECConnectorRole.WORKER
            )
        else:
            raise ValueError("V0 is no longer supported")
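The module keeps a single process-wide connector behind `ensure_*` / `has_*` / `get_*` accessors, mirroring the existing KV-transfer state module. The same idiom in isolation (all names here are hypothetical stand-ins):

```python
from typing import Optional


class Agent:
    """Stand-in for the per-process connector object."""

    def __init__(self, config: str):
        self.config = config


_AGENT: Optional[Agent] = None


def ensure_initialized(config: Optional[str]) -> None:
    # Idempotent: only the first successful call creates the singleton.
    global _AGENT
    if config is None:
        return
    if _AGENT is None:
        _AGENT = Agent(config)


def has_agent() -> bool:
    return _AGENT is not None


def get_agent() -> Agent:
    assert _AGENT is not None, "agent is not initialized"
    return _AGENT


ensure_initialized(None)      # no config: stays uninitialized
assert not has_agent()
ensure_initialized("cfg")     # first real call wins
ensure_initialized("other")   # no-op afterwards
print(get_agent().config)     # cfg
```

Keeping the singleton behind accessors lets call sites such as the model runner do cheap `has_ec_transfer()` checks without threading a connector handle through every signature.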
@@ -38,6 +38,7 @@ from vllm.config import (
     CompilationConfig,
     ConfigType,
     DeviceConfig,
+    ECTransferConfig,
     EPLBConfig,
     KVEventsConfig,
     KVTransferConfig,
@@ -527,6 +528,8 @@ class EngineArgs:
     kv_transfer_config: KVTransferConfig | None = None
     kv_events_config: KVEventsConfig | None = None
 
+    ec_transfer_config: ECTransferConfig | None = None
+
     generation_config: str = ModelConfig.generation_config
     enable_sleep_mode: bool = ModelConfig.enable_sleep_mode
     override_generation_config: dict[str, Any] = get_field(
@@ -1105,6 +1108,9 @@ class EngineArgs:
         "--kv-transfer-config", **vllm_kwargs["kv_transfer_config"]
     )
     vllm_group.add_argument("--kv-events-config", **vllm_kwargs["kv_events_config"])
+    vllm_group.add_argument(
+        "--ec-transfer-config", **vllm_kwargs["ec_transfer_config"]
+    )
     vllm_group.add_argument(
         "--compilation-config", "-O", **vllm_kwargs["compilation_config"]
     )
@@ -1676,6 +1682,7 @@ class EngineArgs:
         compilation_config=self.compilation_config,
         kv_transfer_config=self.kv_transfer_config,
         kv_events_config=self.kv_events_config,
+        ec_transfer_config=self.ec_transfer_config,
         additional_config=self.additional_config,
     )
@@ -49,10 +49,18 @@ def kernel_warmup(worker: "Worker"):
     except NotImplementedError:
         return False
 
-    if not worker.model_runner.is_pooling_model and all(
-        _is_flashinfer_backend(group.backend)
-        for groups in worker.model_runner.attn_groups
-        for group in groups
+    # NOTE: we add a check for empty attn_groups to avoid errors when
+    # deploying models such as E instances and encoder-only models,
+    # for which worker.model_runner.attn_groups is empty.
+    # This change is made during EPD feature development.
+    if (
+        not worker.model_runner.is_pooling_model
+        and worker.model_runner.attn_groups
+        and all(
+            _is_flashinfer_backend(group.backend)
+            for groups in worker.model_runner.attn_groups
+            for group in groups
+        )
     ):
         logger.info("Warming up FlashInfer attention.")
         # Warmup with mixed batch containing both prefill and decode tokens
@@ -14,6 +14,7 @@ if TYPE_CHECKING:
     import numpy.typing as npt
     import torch
 
+    from vllm.distributed.ec_transfer.ec_connector.base import ECConnectorMetadata
     from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata
     from vllm.lora.request import LoRARequest
     from vllm.multimodal.inputs import MultiModalFeatureSpec
@@ -21,6 +22,7 @@ if TYPE_CHECKING:
     from vllm.sampling_params import SamplingParams
     from vllm.v1.request import Request
 else:
+    ECConnectorMetadata = object
     KVConnectorMetadata = object
     LoRARequest = object
     MultiModalFeatureSpec = object
@@ -188,6 +190,9 @@ class SchedulerOutput:
     # KV Cache Connector metadata.
     kv_connector_metadata: KVConnectorMetadata | None = None
 
+    # EC Cache Connector metadata
+    ec_connector_metadata: ECConnectorMetadata | None = None
+
 
 @dataclass
 class GrammarOutput:
@@ -7,6 +7,11 @@ from collections.abc import Iterable
 from typing import Any
 
 from vllm.config import VllmConfig
+from vllm.distributed.ec_transfer.ec_connector.base import (
+    ECConnectorMetadata,
+    ECConnectorRole,
+)
+from vllm.distributed.ec_transfer.ec_connector.factory import ECConnectorFactory
 from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch
 from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
 from vllm.distributed.kv_transfer.kv_connector.v1 import (
@@ -14,6 +19,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1 import (
     KVConnectorRole,
     SupportsHMA,
 )
+from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata
 from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
 from vllm.logger import init_logger
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
@@ -104,6 +110,11 @@ class Scheduler(SchedulerInterface):
             self.kv_events_config,
             self.parallel_config.data_parallel_rank,
         )
+        self.ec_connector = None
+        if self.vllm_config.ec_transfer_config is not None:
+            self.ec_connector = ECConnectorFactory.create_connector(
+                config=self.vllm_config, role=ECConnectorRole.SCHEDULER
+            )
 
         num_gpu_blocks = self.cache_config.num_gpu_blocks
         assert num_gpu_blocks is not None and num_gpu_blocks > 0
@@ -230,12 +241,14 @@ class Scheduler(SchedulerInterface):
 
             # Schedule encoder inputs.
             encoder_inputs_to_schedule = None
+            external_load_encoder_input: list[int] = []
             new_encoder_compute_budget = encoder_compute_budget
             if request.has_encoder_inputs:
                 (
                     encoder_inputs_to_schedule,
                     num_new_tokens,
                     new_encoder_compute_budget,
+                    external_load_encoder_input,
                 ) = self._try_schedule_encoder_inputs(
                     request,
                     request.num_computed_tokens,
@@ -342,6 +355,11 @@
                 for i in encoder_inputs_to_schedule:
                     self.encoder_cache_manager.allocate(request, i)
                 encoder_compute_budget = new_encoder_compute_budget
+            if external_load_encoder_input:
+                for i in external_load_encoder_input:
+                    self.encoder_cache_manager.allocate(request, i)
+                    if self.ec_connector is not None:
+                        self.ec_connector.update_state_after_alloc(request, i)
 
         # Record the LoRAs in scheduled_running_reqs
         scheduled_loras: set[int] = set()
@@ -445,6 +463,7 @@
             num_computed_tokens = request.num_computed_tokens
 
             encoder_inputs_to_schedule = None
+            external_load_encoder_input = []
             new_encoder_compute_budget = encoder_compute_budget
 
             # KVTransfer: loading remote KV, do not allocate for new work.
@@ -480,6 +499,7 @@
                     encoder_inputs_to_schedule,
                     num_new_tokens,
                     new_encoder_compute_budget,
+                    external_load_encoder_input,
                 ) = self._try_schedule_encoder_inputs(
                     request,
                     num_computed_tokens,
@@ -583,7 +603,12 @@
                 for i in encoder_inputs_to_schedule:
                     self.encoder_cache_manager.allocate(request, i)
                 encoder_compute_budget = new_encoder_compute_budget
 
+            # Allocate for externally loaded encoder caches
+            if external_load_encoder_input:
+                for i in external_load_encoder_input:
+                    self.encoder_cache_manager.allocate(request, i)
+                    if self.ec_connector is not None:
+                        self.ec_connector.update_state_after_alloc(request, i)
         # Put back any skipped requests at the head of the waiting queue
         if skipped_waiting_requests:
             self.waiting.prepend_requests(skipped_waiting_requests)
@@ -591,6 +616,7 @@
         # Check if the scheduling constraints are satisfied.
         total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
         assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
+
         assert token_budget >= 0
         assert len(self.running) <= self.max_num_running_reqs
         # Since some requests in the RUNNING queue may not be scheduled in
@@ -653,8 +679,18 @@
         # 2. Wrap up all the KV cache load / save ops into an opaque object
         # 3. Clear the internal states of the connector
         if self.connector is not None:
-            meta = self.connector.build_connector_meta(scheduler_output)
+            meta: KVConnectorMetadata = self.connector.build_connector_meta(
+                scheduler_output
+            )
             scheduler_output.kv_connector_metadata = meta
 
+        # Build the connector meta for the ECConnector
+        if self.ec_connector is not None:
+            ec_meta: ECConnectorMetadata = self.ec_connector.build_connector_meta(
+                scheduler_output
+            )
+            scheduler_output.ec_connector_metadata = ec_meta
+
         with record_function_or_nullcontext("schedule: update_after_schedule"):
             self._update_after_schedule(scheduler_output)
         return scheduler_output
@@ -755,7 +791,7 @@
         num_computed_tokens: int,
         num_new_tokens: int,
         encoder_compute_budget: int,
-    ) -> tuple[list[int], int, int]:
+    ) -> tuple[list[int], int, int, list[int]]:
         """
         Determine which encoder inputs need to be scheduled in the current step,
         and update `num_new_tokens` and encoder token budget accordingly.
@@ -765,6 +801,7 @@
           in this step, i.e.,
           [num_computed_tokens, num_computed_tokens + num_new_tokens).
         - It is not already computed and stored in the encoder cache.
+        - It does not already exist in the remote encoder cache (via ECConnector).
         - There is sufficient encoder token budget to process it.
         - The encoder cache has space to store it.
 
@@ -776,12 +813,16 @@
             blocks and externally cached blocks (via KVConnector).
         """
         if num_new_tokens == 0 or not request.has_encoder_inputs:
-            return [], num_new_tokens, encoder_compute_budget
+            return [], num_new_tokens, encoder_compute_budget, []
         encoder_inputs_to_schedule: list[int] = []
         mm_features = request.mm_features
         assert mm_features is not None
         assert len(mm_features) > 0
+        external_load_encoder_input = []
+
+        # Check the remote cache first
+        if self.ec_connector is not None:
+            remote_cache_has_item = self.ec_connector.has_caches(request)
         # NOTE: since scheduler operates on the request level (possibly with
         # multiple encoder inputs per request), we need to create temporary
         # trackers for accounting at the encoder input level.
@@ -862,6 +903,12 @@
                 num_new_tokens = 0
                 break
 
+            if self.ec_connector is not None and remote_cache_has_item[i]:
+                mm_hashes_to_schedule.add(request.mm_features[i].identifier)
+                external_load_encoder_input.append(i)
+                num_tokens_to_schedule += num_encoder_tokens
+                continue
+
             num_tokens_to_schedule += num_encoder_tokens
             encoder_compute_budget -= num_encoder_tokens
             mm_hashes_to_schedule.add(request.mm_features[i].identifier)
@@ -871,6 +918,7 @@
             encoder_inputs_to_schedule,
             num_new_tokens,
             encoder_compute_budget,
+            external_load_encoder_input,
         )
 
     def get_grammar_bitmask(
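The accounting in `_try_schedule_encoder_inputs` reduces to one rule: a remote-cache hit consumes encoder-cache space but no compute budget, while a local miss consumes both. A simplified, self-contained sketch of that split (hypothetical numbers, no vLLM types):

```python
def plan_encoder_inputs(num_tokens_per_input, remote_hit, compute_budget, cache_space):
    """Split encoder inputs into 'load from remote cache' vs 'compute locally'."""
    to_compute, to_load = [], []
    scheduled_tokens = 0
    for i, n in enumerate(num_tokens_per_input):
        if scheduled_tokens + n > cache_space:
            break  # no room left in the encoder cache
        if remote_hit[i]:
            # Remote hit: occupies cache space but costs no compute budget.
            to_load.append(i)
            scheduled_tokens += n
            continue
        if n > compute_budget:
            break  # not enough encoder compute budget this step
        to_compute.append(i)
        compute_budget -= n
        scheduled_tokens += n
    return to_compute, to_load, compute_budget


# Three encoder inputs of 100/200/300 tokens; the second is cached remotely.
plan = plan_encoder_inputs([100, 200, 300], [False, True, False], 350, 600)
print(plan)  # ([0], [1], 250): input 2 does not fit the remaining budget
```

This is why a remote hit can unblock scheduling: input 1 above rides along for free, and only the genuinely uncached inputs compete for the encoder compute budget.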
@@ -8,6 +8,8 @@ from typing import TYPE_CHECKING, NamedTuple
 import numpy as np
 import torch
 
+from vllm.v1.core.sched.output import SchedulerOutput
+
 if TYPE_CHECKING:
     from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
 else:
@@ -136,6 +138,13 @@
     )
 
 
+@dataclass
+class ECConnectorOutput:
+    # [mm_hash]
+    finished_sending: set[str] | None = None
+    finished_recving: set[str] | None = None
+
+
 # ModelRunnerOutput is serialized and sent to the scheduler process.
 # This is expensive for torch.Tensor so prefer to use list instead.
 @dataclass
@@ -167,6 +176,8 @@
 
     kv_connector_output: KVConnectorOutput | None = None
 
+    ec_connector_output: ECConnectorOutput | None = None
+
     # req_id -> num_nans_in_logits
     num_nans_in_logits: dict[str, int] | None = None
 
@@ -192,6 +203,41 @@
     draft_token_ids: list[list[int]]
 
 
+def make_empty_encoder_model_runner_output(
+    scheduler_output: "SchedulerOutput",
+) -> ModelRunnerOutput:
+    """
+    Create a ModelRunnerOutput stub that contains the correct
+    per-request bookkeeping but no generated data yet.
+    """
+    if not scheduler_output.num_scheduled_tokens:
+        return EMPTY_MODEL_RUNNER_OUTPUT
+
+    # Convert to a list so we get a deterministic, indexable sequence
+    req_ids: list[str] = list(scheduler_output.num_scheduled_tokens.keys())
+
+    # Give every request its own contiguous index
+    req_id_to_index: dict[str, int] = {rid: idx for idx, rid in enumerate(req_ids)}
+
+    # No tokens generated yet ⇒ one placeholder token list per request
+    sampled_token_ids: list[list[int]] = [[0] for _ in req_ids]
+
+    # Pooler outputs are not available yet ⇒ use None placeholders
+    pooler_output: list[torch.Tensor | None] = [None for _ in req_ids]
+
+    return ModelRunnerOutput(
+        req_ids=req_ids,
+        req_id_to_index=req_id_to_index,
+        sampled_token_ids=sampled_token_ids,
+        logprobs=None,
+        prompt_logprobs_dict={},
+        pooler_output=pooler_output,
+        kv_connector_output=None,
+        ec_connector_output=None,
+        num_nans_in_logits=None,
+    )
+
+
 EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
     req_ids=[],
     req_id_to_index={},
87
vllm/v1/worker/ec_connector_model_runner_mixin.py
Normal file
@@ -0,0 +1,87 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Define EC connector functionality mixin for model runners.
"""

from collections.abc import Generator
from contextlib import AbstractContextManager, contextmanager, nullcontext
from typing import (
    TYPE_CHECKING,  # noqa: UP035
)

import torch

from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer
from vllm.distributed.ec_transfer.ec_connector.base import ECConnectorBase
from vllm.logger import init_logger
from vllm.v1.outputs import ECConnectorOutput

if TYPE_CHECKING:
    from vllm.v1.core.sched.output import SchedulerOutput

logger = init_logger(__name__)


# Defined as an EC connector functionality mixin for ModelRunner (GPU, TPU)
class ECConnectorModelRunnerMixin:
    @staticmethod
    def maybe_save_ec_to_connector(
        encoder_cache: dict[str, torch.Tensor],
        mm_hash: str,
    ):
        if not has_ec_transfer():
            logger.debug("No EC transfer group is initialized; skipping save")
            return
        connector = get_ec_transfer()
        connector.save_caches(encoder_cache=encoder_cache, mm_hash=mm_hash)

    @staticmethod
    def get_finished_ec_transfers(
        scheduler_output: "SchedulerOutput",
    ) -> tuple[set[str] | None, set[str] | None]:
        if has_ec_transfer():
            return get_ec_transfer().get_finished(scheduler_output.finished_req_ids)
        return None, None

    @staticmethod
    def maybe_get_ec_connector_output(
        scheduler_output: "SchedulerOutput",
        encoder_cache: dict[str, torch.Tensor],
        **kwargs,
    ) -> AbstractContextManager[ECConnectorOutput | None]:
        return (
            ECConnectorModelRunnerMixin._get_ec_connector_output(
                scheduler_output, encoder_cache, **kwargs
            )
            if has_ec_transfer()
            else nullcontext()
        )

    # This context manager must be used within an active forward context.
    # It encapsulates the entire EC connector lifecycle within execute_model.
    @staticmethod
    @contextmanager
    def _get_ec_connector_output(
        scheduler_output: "SchedulerOutput",
        encoder_cache: dict[str, torch.Tensor],
        **kwargs,
    ) -> Generator[ECConnectorOutput, None, None]:
        output = ECConnectorOutput()

        ec_connector = get_ec_transfer()
        assert isinstance(ec_connector, ECConnectorBase)
        assert scheduler_output.ec_connector_metadata is not None
        ec_connector.bind_connector_metadata(scheduler_output.ec_connector_metadata)

        if not ec_connector.is_producer:
            ec_connector.start_load_caches(encoder_cache, **kwargs)

        try:
            yield output
        finally:
            output.finished_sending, output.finished_recving = (
                ec_connector.get_finished(scheduler_output.finished_req_ids)
            )

            ec_connector.clear_connector_metadata()
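The mixin wraps the connector lifecycle (bind metadata, optionally start loads, run the body, then collect finished transfers and clear state) in a context manager, so the model runner cannot forget the cleanup step even when the forward pass raises. The same shape in isolation (all names hypothetical):

```python
from contextlib import contextmanager
from dataclasses import dataclass, field


@dataclass
class Output:
    finished: list[str] = field(default_factory=list)


events: list[str] = []


@contextmanager
def connector_lifecycle(metadata: str):
    # Setup always runs before the body.
    events.append(f"bind:{metadata}")
    output = Output()
    try:
        yield output  # the model forward pass runs here
    finally:
        # Teardown runs even if the body raises: collect results, clear state.
        output.finished.append("xfer-0")
        events.append("clear")


with connector_lifecycle("step-1") as out:
    events.append("forward")

print(events)        # ['bind:step-1', 'forward', 'clear']
print(out.finished)  # ['xfer-0']
```

Note the `yield`-then-`finally` split: results become visible on the yielded object only after the `with` block exits, which is exactly how `ECConnectorOutput.finished_sending` / `finished_recving` are populated above.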
@ -35,6 +35,7 @@ from vllm.config import (
|
||||
get_layers_from_vllm_config,
|
||||
update_config,
|
||||
)
|
||||
from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer
|
||||
from vllm.distributed.eplb.eplb_state import EplbState
|
||||
from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group
|
||||
from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks
|
||||
@ -114,12 +115,14 @@ from vllm.v1.outputs import (
|
||||
EMPTY_MODEL_RUNNER_OUTPUT,
|
||||
AsyncModelRunnerOutput,
|
||||
DraftTokenIds,
|
||||
ECConnectorOutput,
|
||||
KVConnectorOutput,
|
||||
LogprobsLists,
|
||||
LogprobsTensors,
|
||||
ModelRunnerOutput,
|
||||
PoolerOutput,
|
||||
SamplerOutput,
|
||||
make_empty_encoder_model_runner_output,
|
||||
)
|
||||
from vllm.v1.pool.metadata import PoolingMetadata
|
||||
from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs
|
||||
@ -134,6 +137,7 @@ from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer
|
||||
from vllm.v1.structured_output.utils import apply_grammar_bitmask
|
||||
from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext
|
||||
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
|
||||
from vllm.v1.worker.ec_connector_model_runner_mixin import ECConnectorModelRunnerMixin
|
||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||
from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper
|
||||
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
|
||||
@ -237,9 +241,12 @@ class ExecuteModelState(NamedTuple):
|
||||
sample_hidden_states: torch.Tensor
|
||||
aux_hidden_states: list[torch.Tensor] | None
|
||||
kv_connector_output: KVConnectorOutput | None
|
||||
ec_connector_output: ECConnectorOutput | None
|
||||
|
||||
|
||||
class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
class GPUModelRunner(
|
||||
LoRAModelRunnerMixin, KVConnectorModelRunnerMixin, ECConnectorModelRunnerMixin
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
@ -1873,6 +1880,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
output,
|
||||
is_embed=pos_info.is_embed,
|
||||
)
|
||||
logger.debug("Finish execute for mm hash %s", mm_hash)
|
||||
self.maybe_save_ec_to_connector(self.encoder_cache, mm_hash)
|
||||
|
||||
def _gather_mm_embeddings(
|
||||
self,
|
||||
@ -2191,20 +2200,27 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
torch.Tensor,
|
||||
IntermediateTensors | None,
|
||||
dict[str, Any],
|
||||
ECConnectorOutput | None,
|
||||
]:
|
||||
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||
is_first_rank = get_pp_group().is_first_rank
|
||||
|
||||
# _prepare_inputs may reorder the batch, so we must gather multi
|
||||
# modal outputs after that to ensure the correct order
|
||||
ec_connector_output = None
|
||||
|
||||
if (
|
||||
self.supports_mm_inputs
|
||||
and is_first_rank
|
||||
and not self.model_config.is_encoder_decoder
|
||||
):
|
||||
# Run the multimodal encoder if any.
|
||||
self._execute_mm_encoder(scheduler_output)
|
||||
mm_embeds, is_mm_embed = self._gather_mm_embeddings(scheduler_output)
|
||||
with self.maybe_get_ec_connector_output(
|
||||
scheduler_output,
|
||||
encoder_cache=self.encoder_cache,
|
||||
) as ec_connector_output:
|
||||
self._execute_mm_encoder(scheduler_output)
|
||||
mm_embeds, is_mm_embed = self._gather_mm_embeddings(scheduler_output)
|
||||
|
||||
# NOTE(woosuk): To unify token ids and soft tokens (vision
|
||||
# embeddings), we always use embeddings (rather than token ids)
|
||||
@ -2284,6 +2300,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
positions,
|
||||
intermediate_tensors,
|
||||
model_kwargs,
|
||||
ec_connector_output,
|
||||
)
|
||||
|
||||
def _sample(
|
||||
@ -2508,6 +2525,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
# Update persistent batch states.
|
||||
self._update_states(scheduler_output)
|
||||
|
||||
if has_ec_transfer() and get_ec_transfer().is_producer:
|
||||
with self.maybe_get_ec_connector_output(
|
||||
scheduler_output,
|
||||
encoder_cache=self.encoder_cache,
|
||||
) as ec_connector_output:
|
||||
self._execute_mm_encoder(scheduler_output)
|
||||
return make_empty_encoder_model_runner_output(scheduler_output)
|
||||
|
||||
if not num_scheduled_tokens:
|
||||
if not has_kv_transfer_group():
|
||||
# Return empty ModelRunnerOutput if no work to do.
|
||||
```diff
@@ -2583,6 +2608,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
     positions,
     intermediate_tensors,
     model_kwargs,
+    ec_connector_output,
 ) = self._preprocess(
     scheduler_output, num_input_tokens, intermediate_tensors
 )
```
```diff
@@ -2699,6 +2725,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
     sample_hidden_states,
     aux_hidden_states,
     kv_connector_output,
+    ec_connector_output,
 )
 return None

```
```diff
@@ -2720,6 +2747,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
     sample_hidden_states,
     aux_hidden_states,
     kv_connector_output,
+    ec_connector_output,
 ) = self.execute_model_state
 # Clear ephemeral state.
 self.execute_model_state = None
```
```diff
@@ -2811,6 +2839,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
     prompt_logprobs_dict=prompt_logprobs_dict,
     pooler_output=[],
     kv_connector_output=kv_connector_output,
+    ec_connector_output=ec_connector_output
+    if self.supports_mm_inputs
+    else None,
     num_nans_in_logits=num_nans_in_logits,
 )

```
```diff
@@ -4797,7 +4828,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         KVCacheSpec: A dictionary mapping layer names to their KV cache
             format. Layers that do not need KV cache are not included.
     """
-
+    if has_ec_transfer() and get_ec_transfer().is_producer:
+        return {}
     kv_cache_spec: dict[str, KVCacheSpec] = {}
     attn_layers = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase)
     for layer_name, attn_module in attn_layers.items():
```
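Returning `{}` from `get_kv_cache_spec` is how a producer opts out of KV-cache memory entirely: with no layers reporting a cache requirement, the engine has nothing to allocate. A toy sketch of that effect, assuming a spec that maps layer names to per-block byte sizes (a simplification of vLLM's real `KVCacheSpec`):

```python
def get_kv_cache_spec(is_ec_producer: bool, attn_layers: dict[str, int]) -> dict[str, int]:
    # Encoder-only (producer) instances skip KV cache entirely.
    if is_ec_producer:
        return {}
    # Otherwise report a per-layer page size (illustrative values).
    return dict(attn_layers)


def total_kv_bytes(spec: dict[str, int], num_blocks: int) -> int:
    # Total KV memory the engine would reserve for this spec.
    return sum(spec.values()) * num_blocks


layers = {"layer.0.attn": 4096, "layer.1.attn": 4096}
print(total_kv_bytes(get_kv_cache_spec(True, layers), 1000))   # → 0
print(total_kv_bytes(get_kv_cache_spec(False, layers), 1000))  # → 8192000
```

Freeing the producer from KV allocation is what lets the (small) vision encoder be scaled on cheap, low-memory replicas independently of the language-model fleet.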
```diff
@@ -20,6 +20,7 @@ from vllm.distributed import (
     init_distributed_environment,
     set_custom_all_reduce,
 )
+from vllm.distributed.ec_transfer import ensure_ec_transfer_initialized
 from vllm.distributed.kv_transfer import (
     ensure_kv_transfer_initialized,
     get_kv_transfer_group,
```
```diff
@@ -887,3 +888,7 @@ def init_worker_distributed_environment(
         parallel_config.pipeline_parallel_size,
         parallel_config.decode_context_parallel_size,
     )
+
+    # Init ec connector here before KV cache init
+    # NOTE: We do not init KV caches for Encoder-only instance in EPD disagg mode
+    ensure_ec_transfer_initialized(vllm_config)
```
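`ensure_ec_transfer_initialized` follows the same idempotent `ensure_*` pattern as `ensure_kv_transfer_initialized`: calling it more than once keeps the first instance, and `has_ec_transfer()` / `get_ec_transfer()` gate the rest of the code on whether EPD disaggregation is configured. A minimal sketch of that pattern — the config shape and the `ECTransfer` class here are hypothetical, not vLLM's actual API:

```python
from typing import Optional

_EC_TRANSFER: Optional["ECTransfer"] = None


class ECTransfer:
    def __init__(self, config: dict):
        # Whether this instance is the encoder-only producer side.
        self.is_producer = config.get("role") == "encoder"


def ensure_ec_transfer_initialized(config: dict) -> None:
    # Idempotent: repeated calls keep the first instance; a no-op when
    # no EC transfer is configured (sketch, not vLLM's real signature).
    global _EC_TRANSFER
    if _EC_TRANSFER is None and config.get("ec_transfer_config") is not None:
        _EC_TRANSFER = ECTransfer(config)


def has_ec_transfer() -> bool:
    return _EC_TRANSFER is not None


def get_ec_transfer() -> ECTransfer:
    assert _EC_TRANSFER is not None, "EC transfer not initialized"
    return _EC_TRANSFER


ensure_ec_transfer_initialized({"ec_transfer_config": {}, "role": "encoder"})
ensure_ec_transfer_initialized({"ec_transfer_config": {}, "role": "decoder"})  # no-op
print(has_ec_transfer(), get_ec_transfer().is_producer)  # → True True
```

Initializing the EC connector before KV-cache setup matters because, per the hunk above, an encoder-only instance must already know it is a producer when `get_kv_cache_spec` decides to skip KV allocation.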