diff --git a/docs/assets/features/disagg_encoder/disagg_encoder_flow.png b/docs/assets/features/disagg_encoder/disagg_encoder_flow.png new file mode 100644 index 0000000000000..2951468c11d9a Binary files /dev/null and b/docs/assets/features/disagg_encoder/disagg_encoder_flow.png differ diff --git a/docs/features/disagg_encoder.md b/docs/features/disagg_encoder.md new file mode 100644 index 0000000000000..7d40af7069822 --- /dev/null +++ b/docs/features/disagg_encoder.md @@ -0,0 +1,75 @@ +# Disaggregated Encoder + +A **disaggregated encoder** runs the vision-encoder stage of a multimodal LLM in a process that is separate from the pre-fill / decoder stage. Deploying these two stages in independent vLLM instances brings three practical benefits: + +1. **Independent, fine-grained scaling** +2. **Lower time-to-first-token (TTFT)** +3. **Cross-process reuse and caching of encoder outputs** + +Design doc: + +--- + +## 1 Motivation + +### 1. Independent, fine-grained scaling + +* Vision encoders are lightweight, while language models are orders of magnitude larger. +* The language model can be parallelised without affecting the encoder fleet. +* Encoder nodes can be added or removed independently. + +### 2. Lower time-to-first-token (TTFT) + +* Language-only requests bypass the vision encoder entirely. +* Encoder output is injected only at required attention layers, shortening the pre-fill critical path. + +### 3. Cross-process reuse and caching + +* In-process encoders confine reuse to a single worker. +* A remote, shared cache lets any worker retrieve existing embeddings, eliminating redundant computation. + +--- + +## 2 Usage Example + +The current reference pathway is **SharedStorageConnector**. 
+ The ready-to-run scripts below show the workflow: + +1 Encoder instance + 1 PD instance: +`examples/online_serving/disaggregated_encoder/shared_storage_connector/disagg_encoder_example.sh` + +1 Encoder instance + 1 Prefill instance + 1 Decode instance: +`examples/online_serving/disaggregated_encoder/shared_storage_connector/disagg_epd_example.sh` + +--- + +## 3 Test Script + +Please refer to the directory `tests/v1/ec_connector`. + +## 4 Development + +Disaggregated encoding is implemented by running two parts: + +* **Encoder instance** – a vLLM instance that performs vision encoding. +* **Prefill/Decode (PD) instance(s)** – runs language pre-fill and decode. + * PD can run either as a single instance with `disagg_encoder_example.sh` (E->PD) or as disaggregated instances with `disagg_epd_example.sh` (E->P->D) + +A connector transfers encoder-cache (EC) embeddings from the encoder instance to the PD instance. +All related code is under `vllm/distributed/ec_transfer`. + +### Key abstractions + +* **ECConnector** – interface for retrieving EC caches produced by the encoder. + * *Scheduler role* – checks cache existence and schedules loads. + * *Worker role* – loads the embeddings into memory. + +Here is a figure illustrating the disaggregated encoder flow: + +![Disaggregated Encoder Flow](../assets/features/disagg_encoder/disagg_encoder_flow.png) + +For the PD disaggregation part, the Prefill instance receives the cache exactly as in the disaggregated encoder flow above. The Prefill instance executes 1 step (prefill -> 1 token output) and then transfers the KV cache to the Decode instance for the remaining execution. The KV transfer happens entirely after the execution of the PD instance. 
+ `docs/features/disagg_prefill.md` gives a brief overview of disaggregated prefill (v0). + +We created the example setup with the **NixlConnector** from `vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py` and referred to `tests/v1/kv_connector/nixl_integration/toy_proxy_server.py` to facilitate the KV transfer between P and D. diff --git a/examples/online_serving/disaggregated_encoder/README.md b/examples/online_serving/disaggregated_encoder/README.md new file mode 100644 index 0000000000000..5813a3cecf73b --- /dev/null +++ b/examples/online_serving/disaggregated_encoder/README.md @@ -0,0 +1,119 @@ +# Disaggregated Encoder + +These example scripts demonstrate the disaggregated encoder (EPD) features of vLLM. + +For a detailed explanation of the EPD features, please refer to the [Disaggregated Encoder Feature Documentation](../../../docs/features/disagg_encoder.md). + +## Files + +- `disagg_epd_proxy.py` - Proxy script that demonstrates the XeYpZd setup (X encode instances, Y prefill instances, Z decode instances). Currently stable for the 1e1p1d configuration. + +- `disagg_1e1p1d_example.sh` - Sets up the 1e1p1d configuration, runs the VisionArena benchmark, and processes a single request with a local image. + +- `disagg_1e1pd_example.sh` - Sets up the 1e1pd configuration, runs the VisionArena benchmark, and processes a single request with a local image. 
+ +### Custom Configuration + +```bash +# Use specific GPUs +GPU_E=0 GPU_PD=1 GPU_P=1 GPU_D=2 bash disagg_1e1p1d_example.sh + +# Use specific ports +ENDPOINT_PORT=10001 bash disagg_1e1p1d_example.sh + +# Use specific model +MODEL="Qwen/Qwen2.5-VL-3B-Instruct" bash disagg_1e1p1d_example.sh + +# Use specific storage path +EC_SHARED_STORAGE_PATH="/tmp/my_ec_cache" bash disagg_1e1p1d_example.sh +``` + +## Encoder Instances + +Encoder engines should be launched with the following flags: + +- `--enforce-eager` **(required)** – The current EPD implementation is only compatible with encoder instances running in this mode. + +- `--no-enable-prefix-caching` **(required)** – Encoder instances do not consume KV cache; prefix caching is disabled to avoid conflicts with other features. + +- `--max-num-batched-tokens=` **(default: 2048)** – This flag controls the token scheduling budget per decoding step and is irrelevant to encoder-only instances. **Set it to a very high value (effectively unlimited) to bypass scheduler limitations.** The actual token budget is managed by the encoder cache manager. + +## Local media inputs + +To support local image inputs (from your `MEDIA_PATH` directory), add the following flag to the encoder instance: + +```bash +--allowed-local-media-path $MEDIA_PATH +``` + +The vLLM instances and `disagg_epd_proxy` support local URIs with `{"url": "file://'"$MEDIA_PATH_FILENAME"'"}` as multimodal inputs. Each URI is passed unchanged from the `disagg_epd_proxy` to the encoder instance so that the encoder can load the media locally. + +## EC connector and KV transfer + +The `ECSharedStorageConnector` is used to store the encoder cache on local disk and facilitate transfer. 
To enable the encoder disaggregation feature, add the following configuration: + +```bash +# Add to encoder instance: +--ec-transfer-config '{ + "ec_connector": "ECSharedStorageConnector", + "ec_role": "ec_producer", + "ec_connector_extra_config": { + "shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'" + } +}' + +# Add to prefill/prefill+decode instance: +--ec-transfer-config '{ + "ec_connector": "ECSharedStorageConnector", + "ec_role": "ec_consumer", + "ec_connector_extra_config": { + "shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'" + } +}' +``` + +`$EC_SHARED_STORAGE_PATH` is the path where the EC connector temporarily stores the cache. + +If you enable a prefill instance (`--prefill-servers-urls` not disabled), you will need `--kv-transfer-config` to facilitate the PD disaggregation. Currently, we use the `NixlConnector` for this purpose. Refer to `tests/v1/kv_connector/nixl_integration` for more example code on PD disaggregation with Nixl. + +```bash +# Add to prefill instance: +--kv-transfer-config '{ + "kv_connector": "NixlConnector", + "kv_role": "kv_producer" +}' + +# Add to decode instance: +--kv-transfer-config '{ + "kv_connector": "NixlConnector", + "kv_role": "kv_consumer" +}' +``` + +## Proxy Instance Flags (`disagg_epd_proxy.py`) + +| Flag | Description | +|------|-------------| +| `--encode-servers-urls` | Comma-separated list of encoder endpoints. Every multimodal item extracted from the request is fanned out to one of these URLs in a round-robin fashion. | +| `--prefill-servers-urls` | Comma-separated list of prefill endpoints. Set to `disable`, `none`, or `""` to skip the dedicated prefill phase and run E+PD (encoder + combined prefill/decode). | +| `--decode-servers-urls` | Comma-separated list of decode endpoints. Non-stream and stream paths both round-robin over this list. | +| `--host`, `--port` | Bind address for the proxy itself (defaults: `0.0.0.0:8000`). 
| + Example usage: +For E + PD setup: + +```bash +$ python disagg_epd_proxy.py \ + --encode-servers-urls "http://e1:8001,http://e2:8002" \ + --prefill-servers-urls "disable" \ + --decode-servers-urls "http://pd1:8003,http://pd2:8004" +``` + +For E + P + D setup: + +```bash +$ python disagg_epd_proxy.py \ + --encode-servers-urls "http://e1:8001,http://e2:8001" \ + --prefill-servers-urls "http://p1:8003,http://p2:8004" \ + --decode-servers-urls "http://d1:8005,http://d2:8006" \ +``` diff --git a/examples/online_serving/disaggregated_encoder/disagg_1e1p1d_example.sh b/examples/online_serving/disaggregated_encoder/disagg_1e1p1d_example.sh new file mode 100644 index 0000000000000..57489df64f51e --- /dev/null +++ b/examples/online_serving/disaggregated_encoder/disagg_1e1p1d_example.sh @@ -0,0 +1,221 @@ +#!/bin/bash +set -euo pipefail + +declare -a PIDS=() + +############################################################################### +# Configuration -- override via env before running +############################################################################### +MODEL="${MODEL:-Qwen/Qwen2.5-VL-3B-Instruct}" +LOG_PATH="${LOG_PATH:-./logs}" +mkdir -p $LOG_PATH + +ENCODE_PORT="${ENCODE_PORT:-19534}" +PREFILL_PORT="${PREFILL_PORT:-19535}" +DECODE_PORT="${DECODE_PORT:-19536}" +PROXY_PORT="${PROXY_PORT:-10001}" + +GPU_E="${GPU_E:-2}" +GPU_P="${GPU_P:-2}" +GPU_D="${GPU_D:-3}" + +EC_SHARED_STORAGE_PATH="${EC_SHARED_STORAGE_PATH:-/tmp/ec_cache}" +TIMEOUT_SECONDS="${TIMEOUT_SECONDS:-12000}" # wait_for_server timeout + +NUM_PROMPTS="${NUM_PROMPTS:-100}" # number of prompts to send in benchmark + +export UCX_TLS=all +export UCX_NET_DEVICES=all + +############################################################################### +# Helpers +############################################################################### +# Find the git repository root directory +GIT_ROOT=$(git rev-parse --show-toplevel) + +START_TIME=$(date +"%Y%m%d_%H%M%S") 
+ENC_LOG=$LOG_PATH/encoder_${START_TIME}.log +P_LOG=$LOG_PATH/p_${START_TIME}.log +D_LOG=$LOG_PATH/d_${START_TIME}.log +PROXY_LOG=$LOG_PATH/proxy_${START_TIME}.log + +wait_for_server() { + local port=$1 + timeout "$TIMEOUT_SECONDS" bash -c " + until curl -s localhost:$port/v1/chat/completions > /dev/null; do + sleep 1 + done" && return 0 || return 1 +} + +# Cleanup function +cleanup() { + echo "Stopping everything…" + trap - INT TERM USR1 # prevent re-entrancy + + # Kill all tracked PIDs + for pid in "${PIDS[@]}"; do + if kill -0 "$pid" 2>/dev/null; then + echo "Killing process $pid" + kill "$pid" 2>/dev/null + fi + done + + # Wait a moment for graceful shutdown + sleep 2 + + # Force kill any remaining processes + for pid in "${PIDS[@]}"; do + if kill -0 "$pid" 2>/dev/null; then + echo "Force killing process $pid" + kill -9 "$pid" 2>/dev/null + fi + done + + # Kill the entire process group as backup + kill -- -$$ 2>/dev/null + + echo "All processes stopped." + exit 0 +} + +trap cleanup INT +trap cleanup USR1 +trap cleanup TERM + +# clear previous cache +echo "remove previous ec cache folder" +rm -rf $EC_SHARED_STORAGE_PATH + +echo "make ec cache folder" +mkdir -p $EC_SHARED_STORAGE_PATH + +############################################################################### +# Encoder worker +############################################################################### +CUDA_VISIBLE_DEVICES="$GPU_E" vllm serve "$MODEL" \ + --gpu-memory-utilization 0.01 \ + --port "$ENCODE_PORT" \ + --enforce-eager \ + --enable-request-id-headers \ + --no-enable-prefix-caching \ + --max-num-batched-tokens 114688 \ + --max-num-seqs 128 \ + --allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \ + --ec-transfer-config '{ + "ec_connector": "ECSharedStorageConnector", + "ec_role": "ec_producer", + "ec_connector_extra_config": { + "shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'" + } + }' \ + >"${ENC_LOG}" 2>&1 & + +PIDS+=($!) 
+ +############################################################################### +# Prefill worker +############################################################################### +CUDA_VISIBLE_DEVICES="$GPU_P" \ +UCX_NET_DEVICES=all \ +VLLM_NIXL_SIDE_CHANNEL_PORT=5559 \ +vllm serve "$MODEL" \ + --gpu-memory-utilization 0.7 \ + --port "$PREFILL_PORT" \ + --enforce-eager \ + --enable-request-id-headers \ + --max-num-seqs 128 \ + --allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \ + --ec-transfer-config '{ + "ec_connector": "ECSharedStorageConnector", + "ec_role": "ec_consumer", + "ec_connector_extra_config": { + "shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'" + } + }' \ + --kv-transfer-config '{ + "kv_connector": "NixlConnector", + "kv_role": "kv_producer" + }' \ + >"${P_LOG}" 2>&1 & + +PIDS+=($!) + +############################################################################### +# Decode worker +############################################################################### +CUDA_VISIBLE_DEVICES="$GPU_D" \ +UCX_NET_DEVICES=all \ +VLLM_NIXL_SIDE_CHANNEL_PORT=6000 \ +vllm serve "$MODEL" \ + --gpu-memory-utilization 0.7 \ + --port "$DECODE_PORT" \ + --enforce-eager \ + --enable-request-id-headers \ + --max-num-seqs 128 \ + --allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \ + --kv-transfer-config '{ + "kv_connector": "NixlConnector", + "kv_role": "kv_consumer" + }' \ + >"${D_LOG}" 2>&1 & + +PIDS+=($!) 
+ +# Wait for workers +wait_for_server $ENCODE_PORT +wait_for_server $PREFILL_PORT +wait_for_server $DECODE_PORT + +############################################################################### +# Proxy +############################################################################### +python disagg_epd_proxy.py \ + --host "0.0.0.0" \ + --port "$PROXY_PORT" \ + --encode-servers-urls "http://localhost:$ENCODE_PORT" \ + --prefill-servers-urls "http://localhost:$PREFILL_PORT" \ + --decode-servers-urls "http://localhost:$DECODE_PORT" \ + >"${PROXY_LOG}" 2>&1 & + +PIDS+=($!) + +wait_for_server $PROXY_PORT +echo "All services are up!" + +############################################################################### +# Benchmark +############################################################################### +echo "Running benchmark (stream)..." +vllm bench serve \ + --model $MODEL \ + --backend openai-chat \ + --endpoint /v1/chat/completions \ + --dataset-name hf \ + --dataset-path lmarena-ai/VisionArena-Chat \ + --seed 0 \ + --num-prompts $NUM_PROMPTS \ + --port $PROXY_PORT + +PIDS+=($!) + +############################################################################### +# Single request with local image +############################################################################### +echo "Running single request with local image (non-stream)..." +curl http://127.0.0.1:${PROXY_PORT}/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "'${MODEL}'", + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": [ + {"type": "image_url", "image_url": {"url": "file://'"${GIT_ROOT}"'/tests/v1/ec_connector/integration/hato.jpg"}}, + {"type": "text", "text": "What is in this image?"} + ]} + ] + }' + + +# cleanup +echo "cleanup..." 
+cleanup \ No newline at end of file diff --git a/examples/online_serving/disaggregated_encoder/disagg_1e1pd_example.sh b/examples/online_serving/disaggregated_encoder/disagg_1e1pd_example.sh new file mode 100644 index 0000000000000..6073e0580b11d --- /dev/null +++ b/examples/online_serving/disaggregated_encoder/disagg_1e1pd_example.sh @@ -0,0 +1,186 @@ +#!/bin/bash +set -euo pipefail + +declare -a PIDS=() + +############################################################################### +# Configuration -- override via env before running +############################################################################### +MODEL="${MODEL:-Qwen/Qwen2.5-VL-3B-Instruct}" +LOG_PATH="${LOG_PATH:-./logs}" +mkdir -p $LOG_PATH + +ENCODE_PORT="${ENCODE_PORT:-19534}" +PREFILL_DECODE_PORT="${PREFILL_DECODE_PORT:-19535}" +PROXY_PORT="${PROXY_PORT:-10001}" + +GPU_E="${GPU_E:-0}" +GPU_PD="${GPU_PD:-1}" + +EC_SHARED_STORAGE_PATH="${EC_SHARED_STORAGE_PATH:-/tmp/ec_cache}" +TIMEOUT_SECONDS="${TIMEOUT_SECONDS:-12000}" # wait_for_server timeout + +NUM_PROMPTS="${NUM_PROMPTS:-100}" # number of prompts to send in benchmark + +############################################################################### +# Helpers +############################################################################### +# Find the git repository root directory +GIT_ROOT=$(git rev-parse --show-toplevel) + +START_TIME=$(date +"%Y%m%d_%H%M%S") +ENC_LOG=$LOG_PATH/encoder_${START_TIME}.log +PD_LOG=$LOG_PATH/pd_${START_TIME}.log +PROXY_LOG=$LOG_PATH/proxy_${START_TIME}.log + +wait_for_server() { + local port=$1 + timeout "$TIMEOUT_SECONDS" bash -c " + until curl -s localhost:$port/v1/chat/completions > /dev/null; do + sleep 1 + done" && return 0 || return 1 +} + +# Cleanup function +cleanup() { + echo "Stopping everything…" + trap - INT TERM USR1 # prevent re-entrancy + + # Kill all tracked PIDs + for pid in "${PIDS[@]}"; do + if kill -0 "$pid" 2>/dev/null; then + echo "Killing process $pid" + kill "$pid" 2>/dev/null + 
fi + done + + # Wait a moment for graceful shutdown + sleep 2 + + # Force kill any remaining processes + for pid in "${PIDS[@]}"; do + if kill -0 "$pid" 2>/dev/null; then + echo "Force killing process $pid" + kill -9 "$pid" 2>/dev/null + fi + done + + # Kill the entire process group as backup + kill -- -$$ 2>/dev/null + + echo "All processes stopped." + exit 0 +} + +trap cleanup INT +trap cleanup USR1 +trap cleanup TERM + +# clear previous cache +echo "remove previous ec cache folder" +rm -rf $EC_SHARED_STORAGE_PATH + +echo "make ec cache folder" +mkdir -p $EC_SHARED_STORAGE_PATH + +############################################################################### +# Encoder worker +############################################################################### +CUDA_VISIBLE_DEVICES="$GPU_E" vllm serve "$MODEL" \ + --gpu-memory-utilization 0.01 \ + --port "$ENCODE_PORT" \ + --enforce-eager \ + --enable-request-id-headers \ + --no-enable-prefix-caching \ + --max-num-batched-tokens 114688 \ + --max-num-seqs 128 \ + --allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \ + --ec-transfer-config '{ + "ec_connector": "ECSharedStorageConnector", + "ec_role": "ec_producer", + "ec_connector_extra_config": { + "shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'" + } + }' \ + >"${ENC_LOG}" 2>&1 & + +PIDS+=($!) 
+ +############################################################################### +# Prefill+Decode worker +############################################################################### +CUDA_VISIBLE_DEVICES="$GPU_PD" vllm serve "$MODEL" \ + --gpu-memory-utilization 0.7 \ + --port "$PREFILL_DECODE_PORT" \ + --enforce-eager \ + --enable-request-id-headers \ + --max-num-seqs 128 \ + --allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \ + --ec-transfer-config '{ + "ec_connector": "ECSharedStorageConnector", + "ec_role": "ec_consumer", + "ec_connector_extra_config": { + "shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'" + } + }' \ + >"${PD_LOG}" 2>&1 & + +PIDS+=($!) + +# Wait for workers +wait_for_server $ENCODE_PORT +wait_for_server $PREFILL_DECODE_PORT + +############################################################################### +# Proxy +############################################################################### +python disagg_epd_proxy.py \ + --host "0.0.0.0" \ + --port "$PROXY_PORT" \ + --encode-servers-urls "http://localhost:$ENCODE_PORT" \ + --prefill-servers-urls "disable" \ + --decode-servers-urls "http://localhost:$PREFILL_DECODE_PORT" \ + >"${PROXY_LOG}" 2>&1 & + +PIDS+=($!) + +wait_for_server $PROXY_PORT +echo "All services are up!" + +############################################################################### +# Benchmark +############################################################################### +echo "Running benchmark (stream)..." +vllm bench serve \ + --model $MODEL \ + --backend openai-chat \ + --endpoint /v1/chat/completions \ + --dataset-name hf \ + --dataset-path lmarena-ai/VisionArena-Chat \ + --seed 0 \ + --num-prompts $NUM_PROMPTS \ + --port $PROXY_PORT + +PIDS+=($!) 
+ +############################################################################### +# Single request with local image +############################################################################### +echo "Running single request with local image (non-stream)..." +curl http://127.0.0.1:${PROXY_PORT}/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "'${MODEL}'", + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": [ + {"type": "image_url", "image_url": {"url": "file://'"${GIT_ROOT}"'/tests/v1/ec_connector/integration/hato.jpg"}}, + {"type": "text", "text": "What is in this image?"} + ]} + ] + }' + + +# cleanup +echo "cleanup..." +cleanup \ No newline at end of file diff --git a/examples/online_serving/disaggregated_encoder/disagg_epd_proxy.py b/examples/online_serving/disaggregated_encoder/disagg_epd_proxy.py new file mode 100644 index 0000000000000..b5f99683c2bf3 --- /dev/null +++ b/examples/online_serving/disaggregated_encoder/disagg_epd_proxy.py @@ -0,0 +1,606 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +disagg_encoder_proxy.py + +Proxy that routes OpenAI-compatible “/v1/chat/completions” requests to two +clusters: + • encode (multimodal feature extraction) + • decode (language-model inference) + +For MM input we: + 1. Extract *every* image/audio item. + 2. Fire N concurrent requests to the encoder cluster + (one request per item, with **all text removed**). + 3. Wait for all of them to succeed. + 4. Forward the *original* request to a decode server. 
+""" + +from __future__ import annotations + +import argparse +import asyncio +import logging +import os +import random +import uuid +from collections.abc import AsyncIterator + +import aiohttp +import uvicorn +from fastapi import FastAPI, HTTPException, Request +from fastapi.responses import JSONResponse, StreamingResponse + +############################################################################### +# FastAPI app & global state +############################################################################### + +logging.basicConfig( + level=logging.DEBUG, format="%(asctime)s %(levelname)s: %(message)s" +) +logger = logging.getLogger("proxy") + +app = FastAPI() +encode_session: aiohttp.ClientSession | None = None +prefill_session: aiohttp.ClientSession | None = None +decode_session: aiohttp.ClientSession | None = None + +############################################################################### +# Utils +############################################################################### + + +MM_TYPES = {"image_url", "audio_url", "input_audio"} + + +def extract_mm_items(request_data: dict) -> list[dict]: + """ + Return *all* image/audio items that appear anywhere in `messages`. + + Each returned dict looks like: + { "type": "image_url", "image_url": {...} } + """ + items: list[dict] = [] + for msg in request_data.get("messages", []): + content = msg.get("content") + if not isinstance(content, list): + continue + + for item in content: + if item.get("type") in MM_TYPES: + items.append(item) + return items + + +async def fanout_encoder_primer( + orig_request: dict, + e_urls: list[str], + req_id: str, +) -> None: + """ + 1. Build one request *per MM item* with all text removed. + 2. Send them concurrently to the encode cluster. + 3. Raise if any of them fails. 
+ """ + logger.info("[%s] Processing multimodal items...", req_id) + + mm_items = extract_mm_items(orig_request) + if not mm_items: + logger.info("[%s] No multimodal items, skipping encoder", req_id) + return # nothing to do + + logger.info("[%s] got %d multimodal items...", req_id, len(mm_items)) + + tasks = [] + + # Round-robin over encode servers to distribute load a bit + url_cycle = (e_urls[i % len(e_urls)] for i in range(len(mm_items))) + + for idx, (item, target_url) in enumerate(zip(mm_items, url_cycle)): + # Derive a *child* request id: :: + child_req_id = f"{req_id}:{idx}:{uuid.uuid4().hex[:6]}" + headers = {"x-request-id": child_req_id} + + encoder_req = { + # You *may* need to keep additional fields + "model": orig_request.get("model"), + "messages": [ + {"role": "user", "content": [item]}, + ], + # Only need 1 token so the server actually runs the encoder path + "max_tokens": 1, + "stream": False, + } + tasks.append( + encode_session.post( + f"{target_url}/v1/chat/completions", + json=encoder_req, + headers=headers, + ) + ) + + results = await asyncio.gather(*tasks, return_exceptions=True) + + # Fail fast if any sub-request failed + for idx, r in enumerate(results): + if isinstance(r, Exception): + logger.error( + "[%s] Encoder request #%d raised exception: %s", + req_id, + idx, + r, + exc_info=r, + ) + raise HTTPException( + status_code=502, detail=f"Encoder request failed: {str(r)}" + ) + if r.status != 200: + try: + detail = await r.text() + except Exception: + detail = "" + logger.error( + "[%s] Encoder request #%d returned status %s: %s", + req_id, + idx, + r.status, + detail, + ) + raise HTTPException( + status_code=r.status, + detail=f"Encoder request failed: {detail}", + ) + + logger.info( + "[%s] All %d encoder requests completed successfully", req_id, len(mm_items) + ) + + +async def maybe_prefill( + req_data: dict, + p_url: str, + req_id: str, +) -> dict: + """ + - Do prefill-only task if p_url exist; + - Return modified request data with kv 
transfer params (for nixl connector) + - Else, skip and return the original request data for decode + """ + if p_url: + logger.info("[%s] Processing through prefill: %s", req_id, p_url) + + prefill_response = await process_prefill_stage(req_data, p_url, req_id) + # for nixl connector to facilitate kv transfer... + prefill_response_json = await prefill_response.json() + kv_transfer_params = prefill_response_json.get("kv_transfer_params", {}) + if kv_transfer_params: + req_data["kv_transfer_params"] = kv_transfer_params + + return req_data + else: + return req_data + + +async def process_prefill_stage( + req_data: dict, + p_url: str, + req_id: str, +) -> dict: + """Process request through Prefill stage and return kv_transfer_params""" + logger.info("[%s] Sending prefill request to: %s", req_id, p_url) + + prefill_request = req_data.copy() + prefill_request["kv_transfer_params"] = { + "do_remote_decode": True, + "do_remote_prefill": False, + "remote_engine_id": None, + "remote_block_ids": None, + "remote_host": None, + "remote_port": None, + } + prefill_request["stream"] = False + prefill_request["max_tokens"] = 1 + if "max_completion_tokens" in prefill_request: + prefill_request["max_completion_tokens"] = 1 + if "stream_options" in prefill_request: + del prefill_request["stream_options"] + + headers = {"x-request-id": req_id} + try: + prefill_response = await prefill_session.post( + f"{p_url}/v1/chat/completions", json=prefill_request, headers=headers + ) + prefill_response.raise_for_status() + + if prefill_response.status != 200: + error_text = await prefill_response.text() + logger.error( + "[%s] Prefill request failed with status %d: %s", + req_id, + prefill_response.status, + error_text, + ) + raise HTTPException( + status_code=prefill_response.status, + detail={"error": "Prefill request failed", "message": error_text}, + ) + logger.info("[%s] Prefill request completed successfully", req_id) + + return prefill_response + + except Exception as e: + 
logger.error("Prefill processing failed: %s", str(e)) + raise HTTPException( + status_code=500, + detail={"error": "Prefill processing error", "message": str(e)}, + ) from e + + +############################################################################### +# Middleware for request/response logging +############################################################################### + + +@app.middleware("http") +async def log_requests(request: Request, call_next): + """Middleware to log all incoming requests and responses""" + req_id = request.headers.get("x-request-id", str(uuid.uuid4())) + + # Log incoming request + logger.info( + ">>> [%s] %s %s from %s", + req_id, + request.method, + request.url.path, + request.client.host if request.client else "unknown", + ) + + try: + # Process request + response = await call_next(request) + + # Log response + logger.info( + "<<< [%s] %s %s completed with status %d", + req_id, + request.method, + request.url.path, + response.status_code, + ) + + return response + except Exception as e: + # Log errors + logger.exception( + "!!! 
[%s] %s %s failed with error: %s", + req_id, + request.method, + request.url.path, + str(e), + ) + raise + + +############################################################################### +# FastAPI lifecycle +############################################################################### + + +@app.on_event("startup") +async def on_startup() -> None: + global encode_session, prefill_session, decode_session + timeout = aiohttp.ClientTimeout(total=100_000) + connector = aiohttp.TCPConnector(limit=0, force_close=False) + encode_session = aiohttp.ClientSession(timeout=timeout, connector=connector) + if app.state.p_urls: + # only setup if prefill instance(s) exist + prefill_session = aiohttp.ClientSession(timeout=timeout, connector=connector) + decode_session = aiohttp.ClientSession(timeout=timeout, connector=connector) + + +@app.on_event("shutdown") +async def on_shutdown() -> None: + global encode_session, prefill_session, decode_session + if encode_session: + await encode_session.close() + if prefill_session: + await prefill_session.close() + if decode_session: + await decode_session.close() + + +############################################################################### +# Core forwarding +############################################################################### + + +async def forward_non_stream( + req_data: dict, req_id: str, e_urls: list[str], p_url: str, d_url: str +) -> dict: + try: + # Step 1: Process through Encoder instance (if has MM input) + await fanout_encoder_primer(req_data, e_urls, req_id) + + # Step 2: Process through Prefill instance + req_data = await maybe_prefill(req_data, p_url, req_id) + + # Step 3: Process through Decode instance + logger.info("[%s] Forwarding to decode: %s", req_id, d_url) + headers = {"x-request-id": req_id} + + # Non-streaming response + async with decode_session.post( + f"{d_url}/v1/chat/completions", json=req_data, headers=headers + ) as resp: + resp.raise_for_status() + return await resp.json() + + except 
HTTPException: + raise + except Exception as e: + logger.exception("[%s] Error in forward_non_stream: %s", req_id, str(e)) + raise HTTPException(status_code=500, detail=f"Proxy error: {str(e)}") from e + + +async def forward_stream( + req_data: dict, req_id: str, e_urls: list[str], p_url: str, d_url: str +) -> AsyncIterator[str]: + try: + # Step 1: Process through Encoder instance (if has MM input) + await fanout_encoder_primer(req_data, e_urls, req_id) + + # Step 2: Process through Prefill instance + req_data = await maybe_prefill(req_data, p_url, req_id) + + # Step 3: Process through Decode instance + logger.info("[%s] Starting streaming from decode: %s", req_id, d_url) + headers = {"x-request-id": req_id} + + # Streaming response + async with decode_session.post( + f"{d_url}/v1/chat/completions", + json=req_data, + headers=headers, + ) as resp: + resp.raise_for_status() + async for chunk in resp.content.iter_chunked(1024): + if chunk: + yield chunk.decode("utf-8", errors="ignore") + + logger.info("[%s] Streaming completed", req_id) + + except HTTPException: + logger.exception("[%s] HTTPException in forward_stream", req_id) + raise + except Exception as e: + logger.exception("[%s] Error in forward_stream: %s", req_id, str(e)) + raise HTTPException( + status_code=500, detail=f"Proxy streaming error: {str(e)}" + ) from e + + +############################################################################### +# Public routes +############################################################################### + + +@app.post("/v1/chat/completions") +async def chat_completions(request: Request): + try: + req_data = await request.json() + req_id = request.headers.get("x-request-id", str(uuid.uuid4())) + + e_urls = app.state.e_urls # we want the full list for fan-out + p_url = random.choice(app.state.p_urls) if app.state.p_urls else None + d_url = random.choice(app.state.d_urls) + + is_streaming = req_data.get("stream", False) + + if is_streaming: + return StreamingResponse( + 
forward_stream(req_data, req_id, e_urls, p_url, d_url), + media_type="text/event-stream", + ) + result = await forward_non_stream(req_data, req_id, e_urls, p_url, d_url) + return JSONResponse(content=result) + + except HTTPException: + raise + except Exception as e: + logger.exception("Error in chat_completions endpoint: %s", str(e)) + raise HTTPException( + status_code=500, detail=f"Request processing error: {str(e)}" + ) from e + + +@app.get("/v1/models") +async def list_models(): + async with decode_session.get(f"{app.state.d_urls[0]}/v1/models") as resp: + resp.raise_for_status() + return await resp.json() + + +@app.get("/health") +async def health_check(): + async def healthy(urls): + if not urls: + return "empty" + for u in urls: + try: + async with encode_session.get(f"{u}/health") as resp: + resp.raise_for_status() + except Exception: + return "unhealthy" + return "healthy" + + e_status, p_status, d_status = await asyncio.gather( + healthy(app.state.e_urls), healthy(app.state.p_urls), healthy(app.state.d_urls) + ) + + overall_healthy = all( + status != "unhealthy" for status in (e_status, p_status, d_status) + ) + + status_code = 200 if overall_healthy else 503 + + return JSONResponse( + { + "proxy": "healthy", + "encode_cluster": e_status, + "prefill_cluster": p_status, + "decode_cluster": d_status, + }, + status_code=status_code, + ) + + +############################################################################### +# Simple profiler fan-out (unchanged except for sessions) +############################################################################### + + +async def _post_if_available( + session: aiohttp.ClientSession, + url: str, + payload: dict, + headers: dict, +) -> dict | None: + """ + POST `payload` to `url`. + + Returns + ------- + • The decoded JSON body on success (2xx) + • None if the endpoint does not exist (404) + • Raises for anything else. 
+ """ + try: + resp = await session.post(url, json=payload, headers=headers) + if resp.status == 404: # profiling disabled on that server + logger.warning("Profiling endpoint missing on %s", url) + return None + resp.raise_for_status() + return await resp.json(content_type=None) + except aiohttp.ClientResponseError as exc: + # Pass 404 through the branch above, re-raise everything else + if exc.status == 404: + logger.warning("Profiling endpoint missing on %s", url) + return None + raise + except Exception: + # Network errors etc.: propagate + raise + + +async def _profile_cmd(cmd: str, payload: dict, e_url: str, p_url: str, d_url: str): + """ + Fire & forget to both clusters, tolerate 404. + """ + headers = {"Authorization": f"Bearer {os.getenv('OPENAI_API_KEY', '')}"} + + encode_task = _post_if_available( + encode_session, f"{e_url}/{cmd}_profile", payload, headers + ) + prefill_task = ( + _post_if_available(prefill_session, f"{p_url}/{cmd}_profile", payload, headers) + if p_url is not None + else asyncio.sleep(0) + ) + decode_task = _post_if_available( + decode_session, f"{d_url}/{cmd}_profile", payload, headers + ) + + encode_res, prefill_res, decode_res = await asyncio.gather( + encode_task, prefill_task, decode_task + ) + + # If *all* clusters said “I don’t have that route”, surface an error + if encode_res is prefill_res is decode_res is None: + raise HTTPException( + status_code=503, + detail="Profiling endpoints are disabled on all clusters", + ) + + return { + "encode": encode_res, # may be None + "prefill": prefill_res, # may be None + "decode": decode_res, # may be None + } + + +@app.post("/start_profile") +async def start_profile(request: Request): + body = await request.json() + # TODO: handle multi urls properly + e_url = random.choice(app.state.e_urls) + p_url = random.choice(app.state.p_urls) if app.state.p_urls else None + d_url = random.choice(app.state.d_urls) + return await _profile_cmd("start", body, e_url, p_url, d_url) + + 
+@app.post("/stop_profile") +async def stop_profile(request: Request): + body = await request.json() + # TODO: handle multi urls properly + e_url = random.choice(app.state.e_urls) + p_url = random.choice(app.state.p_urls) if app.state.p_urls else None + d_url = random.choice(app.state.d_urls) + return await _profile_cmd("stop", body, e_url, p_url, d_url) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", default="0.0.0.0") + parser.add_argument("--port", type=int, default=8000) + parser.add_argument( + "--encode-servers-urls", + required=True, + help='Comma-separated encode URLs ("http://e1:8001,http://e2:8001")', + ) + parser.add_argument( + "--prefill-servers-urls", + required=True, + help=( + 'Comma-separated prefill URLs ("http://p1:8003,http://p2:8004") ', + 'to enable E->P->D, set "disable" or "none" to enable E->PD', + ), + ) + parser.add_argument( + "--decode-servers-urls", + required=True, + help='Comma-separated decode URLs ("http://d1:8005,http://d2:8006")', + ) + + args = parser.parse_args() + app.state.e_urls = [ + u.strip() for u in args.encode_servers_urls.split(",") if u.strip() + ] + app.state.d_urls = [ + u.strip() for u in args.decode_servers_urls.split(",") if u.strip() + ] + # handle prefill instances + if args.prefill_servers_urls.lower() in ("disable", "none", ""): + app.state.p_urls = [] + logger.info( + "Disaggregated prefill phase explicitly disabled by user. Running E + PD..." + ) + else: + app.state.p_urls = [ + u.strip() for u in args.prefill_servers_urls.split(",") if u.strip() + ] + logger.info("Disaggregated prefill phase is enabled. 
Running E + P + D...") + + logger.info("Proxy listening on %s:%s", args.host, args.port) + logger.info("Encode servers: %s", app.state.e_urls) + logger.info("Prefill instances %s", app.state.p_urls) + logger.info("Decode servers: %s", app.state.d_urls) + + uvicorn.run( + app, + host=args.host, + port=args.port, + log_level="info", + loop="uvloop", + access_log=True, + ) diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index 749cf7dc8397e..d5b829e79b8f7 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -8,6 +8,7 @@ import torch from vllm.config import ( CacheConfig, + ECTransferConfig, KVTransferConfig, ModelConfig, SchedulerConfig, @@ -20,6 +21,9 @@ from vllm.multimodal.inputs import ( PlaceholderRange, ) from vllm.sampling_params import SamplingParams, StructuredOutputsParams +from vllm.utils.hashing import sha256 +from vllm.v1.core.encoder_cache_manager import EncoderCacheManager +from vllm.v1.core.kv_cache_utils import get_request_block_hasher, init_none_hash from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput from vllm.v1.core.sched.scheduler import Scheduler from vllm.v1.kv_cache_interface import ( @@ -872,7 +876,10 @@ def _step_until_done( for _, num_scheduled_tokens in output.num_scheduled_tokens.items(): # We should be in the decode phase now. 
assert num_scheduled_tokens == 1 - assert len(output.kv_connector_metadata.requests) == 0 + if scheduler.connector is not None: + assert len(output.kv_connector_metadata.requests) == 0 + if scheduler.ec_connector is not None: + assert len(output.ec_connector_metadata.mm_datas) == 0 ecos = scheduler.update_from_output(output, model_runner_output)[0] all_done = True for eco in ecos.outputs: @@ -1066,7 +1073,10 @@ def test_external_prefix_cache_metrics(): assert external_stats.preempted_requests == 0 -def test_kv_connector_unable_to_allocate(): +@pytest.mark.parametrize( + "use_ec_connector, ec_role", [(False, None), (True, "ec_consumer")] +) +def test_kv_connector_unable_to_allocate(use_ec_connector, ec_role): """ Test whether scheduler with KVConnector is able to handle unable to allocate (run out of blocks in allocate_slots(). @@ -1080,6 +1090,9 @@ def test_kv_connector_unable_to_allocate(): use_kv_connector=True, block_size=BLOCK_SIZE, num_blocks=NUM_BLOCKS, + # encoder connector should not affect test results + use_ec_connector=use_ec_connector, + ec_role=ec_role, ) NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2 scheduler.connector.get_num_new_matched_tokens = Mock(name="method") @@ -1148,7 +1161,10 @@ def test_kv_connector_unable_to_allocate(): assert len(scheduler.waiting) == 0 -def test_kv_connector_handles_preemption(): +@pytest.mark.parametrize( + "use_ec_connector, ec_role", [(False, None), (True, "ec_consumer")] +) +def test_kv_connector_handles_preemption(use_ec_connector, ec_role): """ Test whether scheduler with KVConnector is able to handle unable to allocate (run out of blocks in allocate_slots(). 
@@ -1163,6 +1179,9 @@ def test_kv_connector_handles_preemption(): use_kv_connector=True, block_size=BLOCK_SIZE, num_blocks=NUM_BLOCKS, + # encoder connector should not affect test results + use_ec_connector=use_ec_connector, + ec_role=ec_role, ) NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE @@ -1379,6 +1398,8 @@ def create_scheduler_with_priority( block_size: int = 16, max_model_len: int | None = None, num_speculative_tokens: int | None = None, + use_ec_connector: bool = False, + ec_role: str | None = None, ) -> Scheduler: """Create scheduler with priority policy enabled. @@ -1439,12 +1460,23 @@ def create_scheduler_with_priority( model="ngram", num_speculative_tokens=num_speculative_tokens ) + ec_transfer_config = ( + ECTransferConfig( + ec_connector="ECSharedStorageConnector", + ec_role=ec_role, + ec_connector_extra_config={"shared_storage_path": "/tmp/ec_test"}, + ) + if use_ec_connector + else None + ) + vllm_config = VllmConfig( scheduler_config=scheduler_config, model_config=model_config, cache_config=cache_config, kv_transfer_config=kv_transfer_config, speculative_config=speculative_config, + ec_transfer_config=ec_transfer_config, ) kv_cache_config = KVCacheConfig( num_blocks=num_blocks, # A large number of blocks to hold all requests @@ -1465,16 +1497,23 @@ def create_scheduler_with_priority( ) +_none_hash_initialized = False + + def create_requests_with_priority( num_requests: int, priorities: list[int], arrival_times: list[float] | None = None, num_tokens: int = 10, + mm_hashes_list: list[list[str]] | None = None, mm_positions: list[list[PlaceholderRange]] | None = None, max_tokens: int = 16, stop_token_ids: list[int] | None = None, prompt_logprobs: int | None = None, starting_idx: int = 0, + same_prompt: bool = False, + block_size: int = 16, + req_ids: list[str] | None = None, ): """Create requests with specified priorities and arrival times.""" assert len(priorities) == num_requests @@ -1483,6 +1522,12 @@ def create_requests_with_priority( else: arrival_times = 
[float(i) for i in range(num_requests)] + global _none_hash_initialized + if not _none_hash_initialized: + init_none_hash(sha256) + _none_hash_initialized = True + + block_hasher = get_request_block_hasher(block_size, sha256) sampling_params = SamplingParams( ignore_eos=False, max_tokens=max_tokens, @@ -1490,29 +1535,70 @@ def create_requests_with_priority( prompt_logprobs=prompt_logprobs, ) requests = [] + + if mm_hashes_list is not None: + # NOTE: allow manual input; some mm items can have the same identifier + # no. of mm_hashes and mm_positions for each request should be identical + assert mm_positions is not None, ( + "mm_positions must be provided when mm_hashes_list is provided" + ) + assert len(mm_hashes_list) == len(mm_positions) == num_requests + assert [len(h) for h in mm_hashes_list] == [len(p) for p in mm_positions] + + # Since same identifier would imply they are identical encoder output + # Verify mm items with identical identifier are having mm_position.length + seen_hashes: dict[str, int] = {} + + if req_ids: + assert len(req_ids) == num_requests + else: + req_ids = [f"{i + starting_idx}" for i in range(num_requests)] + for i in range(num_requests): mm_features = [] - if mm_positions is not None: - mm_position = mm_positions[i] - for j, position in enumerate(mm_position): - identifier = f"hash{i}_{j}" - mm_feature = MultiModalFeatureSpec( - data=MultiModalKwargsItem.dummy("dummy_m"), - mm_position=position, - identifier=identifier, - modality="image", - ) - mm_features.append(mm_feature) + for j, position in enumerate( + mm_positions[i] if mm_positions is not None else [] + ): + if mm_hashes_list is not None: + identifier = mm_hashes_list[i][j] + + # Verify if position length is identical + position_length = position.length + if identifier in seen_hashes: + assert seen_hashes[identifier] == position_length, ( + f"mm_hash '{identifier}' has inconsistent position lengths: " + f"previously {seen_hashes[identifier]}, now {position_length} " + f"at 
request {i}, position {j}" + ) + else: + seen_hashes[identifier] = position_length + else: + # Unique dummy hash for each mm item + identifier = f"hash{i}_{j}" + mm_feature = MultiModalFeatureSpec( + data=MultiModalKwargsItem.dummy("dummy_m"), + mm_position=position, + identifier=identifier, + modality="image", + ) + mm_features.append(mm_feature) + + prompt_token_ids = ( + [starting_idx] * num_tokens + if same_prompt + else [i + starting_idx] * num_tokens + ) request = Request( - request_id=f"{i + starting_idx}", - prompt_token_ids=[i + starting_idx] * num_tokens, + request_id=req_ids[i], + prompt_token_ids=prompt_token_ids, sampling_params=sampling_params, pooling_params=None, mm_features=mm_features if mm_features else None, eos_token_id=EOS_TOKEN_ID, arrival_time=arrival_times[i], priority=priorities[i], + block_hasher=block_hasher, ) requests.append(request) return requests @@ -1999,7 +2085,12 @@ def test_schedule_skip_tokenizer_init_structured_output_request(): assert len(scheduler.waiting) == 1 -def test_priority_scheduling_preemption_and_resumption_when_out_of_kv(): +@pytest.mark.parametrize( + "use_ec_connector, ec_role", [(False, None), (True, "ec_consumer")] +) +def test_priority_scheduling_preemption_and_resumption_when_out_of_kv( + use_ec_connector, ec_role +): """Test that priority scheduling preempts lower priority requests when out of KV cache space.""" # Create scheduler with very limited memory to force preemption @@ -2009,6 +2100,9 @@ def test_priority_scheduling_preemption_and_resumption_when_out_of_kv(): num_blocks=5, # Can hold 64 tokens (first block is null) block_size=16, # Standard block size use_kv_connector=True, + # encoder connector should not affect test results + use_ec_connector=use_ec_connector, + ec_role=ec_role, ) # Create a request and schedule it @@ -2168,3 +2262,976 @@ def _validate_chunked_prefill_settings_for_encoder_decoder( assert scheduler_config.disable_chunked_mm_input is not expect_enabled if is_encoder_decoder and not 
expect_enabled: assert scheduler_config.long_prefill_token_threshold == 0 + + +# ============================================================================== +# EPD (Encoder-Prefill-Decode) Encoder-cache-specific tests start +# NOTE: In E->P->D disagg case, both KV and EC Connector works in P instance +# Unless specify, the existence of KV Connector should not affect any test results +# ============================================================================== + + +def _assert_right_encoder_cache_allocated( + scheduler: Scheduler, + hashes_to_check: list[str] | None = None, + requests: list[Request] | None = None, + expected_total_allocated: int | None = None, +): + """Check whether encoder cache is allocated correctly.""" + encoder_cache_manager = scheduler.encoder_cache_manager + + # Verify encoder cache manager exists + assert encoder_cache_manager is not None, "Encoder cache manager should exist" + + # Verify number of cache + if expected_total_allocated is not None: + assert len(encoder_cache_manager.cached) == expected_total_allocated + if expected_total_allocated == 0: + return + + # Verify each request with MM data is in cache + cached_hashes = set(encoder_cache_manager.cached.keys()) + + if hashes_to_check: + missed_hashes = set(hashes_to_check) - cached_hashes + assert not missed_hashes, ( + f"Miss hashes: {missed_hashes} " + f"Existing encoder cache: {encoder_cache_manager.cached}" + ) + + for req in requests if requests is not None else []: + if req.mm_features: + mm_hashes = [f.identifier for f in req.mm_features] + req_hashes = set(mm_hashes) # unique hashes set + missed_hashes = req_hashes - cached_hashes + assert not missed_hashes, ( + f"Miss hashes in cache for request {req.request_id}: {missed_hashes} " + f"Existing encoder cache: {encoder_cache_manager.cached}" + ) + + +def _assert_right_ec_connector_metadata( + output: SchedulerOutput, + mm_features_list: list[MultiModalFeatureSpec], +): + """Verify that ECConnector metadata EXACTLY 
matches the input MM data""" + # Get the connector metadata + metadata = output.ec_connector_metadata + + # Create lookup dictionaries for efficient access + metadata_dict = {mm_data.mm_hash: mm_data for mm_data in metadata.mm_datas} + + # Check all required identifiers exist in metadata; and no extra + # In ECSharedStorageConnector format + # NOTE: even having same identifier, the mm_features can be different + # since their mm_position can be in different offsets, etc + identifiers_dict = {f.identifier for f in mm_features_list} + assert set(metadata_dict.keys()) == identifiers_dict + + # Verify the info matches + for i, mm_feature in enumerate(mm_features_list): + identifier = mm_feature.identifier + assert metadata_dict[identifier].mm_hash == identifier + assert metadata_dict[identifier].num_token == mm_feature.mm_position.length + + +def _assert_right_encoder_inputs( + output: SchedulerOutput, + check_exist: bool | None = True, + requests: list[Request] | None = None, + expected_encoder_inputs: list[list[int]] | None = None, + expected_total_reqs: int | None = None, +): + """Verify that requests/mm_hashes should (not) be in scheduled encoder input + If check_exist is False, this function returns True + if requests are NOT in encoder inputs""" + + # Get the scheduled encoder inputs + # NOTE: scheduled_encoder_inputs is a dictionary with request id as key + scheduled_encoder_inputs = output.scheduled_encoder_inputs + + # Check if scheduled_encoder_inputs is empty as expected + if expected_total_reqs is not None: + assert len(scheduled_encoder_inputs) == expected_total_reqs + if expected_total_reqs == 0: + return + + # Number of expected encoder inputs should match number of requests + if expected_encoder_inputs: + assert check_exist and requests is not None # only support expect input exist + assert len(requests) == len(expected_encoder_inputs) + + # Check request (not) exist as expected + for i, request in enumerate(requests if requests is not None else []): + 
assert (request.request_id in scheduled_encoder_inputs) is check_exist, ( + f"Request {request.request_id} presence mismatch: expected {check_exist}, " + f"got {request.request_id in scheduled_encoder_inputs}" + ) + if expected_encoder_inputs: + scheduled_encoder_input = scheduled_encoder_inputs[request.request_id] + assert scheduled_encoder_input == expected_encoder_inputs[i] + + +def test_scheduler_no_ec_connector_by_default(): + """Test scheduler doesn't have EC connector by default.""" + scheduler = create_scheduler() + assert scheduler.ec_connector is None + + +@pytest.mark.parametrize("use_kv_connector", [False, True]) +def test_ec_connector_text_only_request(use_kv_connector): + """Test text-only requests don't allocate encoder cache.""" + scheduler = create_scheduler( + model="llava-hf/llava-1.5-7b-hf", + use_kv_connector=use_kv_connector, + use_ec_connector=True, + ec_role="ec_consumer", + ) + + NUM_PROMPT_TOKENS = 100 + + # Create text-only request (no mm_positions) + requests = create_requests( + num_requests=1, + num_tokens=NUM_PROMPT_TOKENS, + ) + assert not requests[0].mm_features # No MM data + + scheduler.add_request(requests[0]) + output = scheduler.schedule() + + # Should schedule + assert len(output.scheduled_new_reqs) == 1 + + # Scheduled tokens should equal prompt tokens exactly + scheduled = output.num_scheduled_tokens[requests[0].request_id] + assert scheduled == NUM_PROMPT_TOKENS, ( + f"Text-only should schedule {NUM_PROMPT_TOKENS}, got {scheduled}" + ) + + # Encoder cache should be empty + _assert_right_encoder_cache_allocated(scheduler, expected_total_allocated=0) + + # ECConnector should carry no metadata + _assert_right_ec_connector_metadata(output, mm_features_list=[]) + + # Scheduled encoder input should be empty; no mm to compute + _assert_right_encoder_inputs(output, expected_total_reqs=0) + + +@pytest.mark.parametrize("use_kv_connector", [False, True]) +def test_ec_connector_cache_hit_external_load(use_kv_connector): + """Test ec_consumer loads 
from external cache when hit. + A normal basic operation for EPD disaggregation""" + scheduler = create_scheduler( + model="llava-hf/llava-1.5-7b-hf", + enable_prefix_caching=True, + # kv connector should not affect test results + use_kv_connector=use_kv_connector, + use_ec_connector=True, + ec_role="ec_consumer", + ) + + # Create MM request + NUM_TOKENS = 200 # NOTE: includes mm tokens + NUM_ENCODER_TOKENS = 100 + mm_hashes_list = [["hash_test1"]] + mm_positions = [[PlaceholderRange(offset=0, length=NUM_ENCODER_TOKENS)]] + + request = create_requests( + num_requests=1, + num_tokens=NUM_TOKENS, + mm_hashes_list=mm_hashes_list, + mm_positions=mm_positions, + )[0] + + # Mock cache hit - encoder cache exists externally + scheduler.ec_connector.has_caches = Mock(return_value=[True]) + scheduler.ec_connector.update_state_after_alloc = Mock( + wraps=scheduler.ec_connector.update_state_after_alloc + ) + + scheduler.add_request(request) + output = scheduler.schedule() + + # Should schedule prompt tokens + scheduled_tokens = output.num_scheduled_tokens[request.request_id] + assert scheduled_tokens == NUM_TOKENS + + # Should have called update_state_after_alloc for external load + scheduler.ec_connector.update_state_after_alloc.assert_called_with(request, 0) + + # Encoder cache should contain mm items from request + _assert_right_encoder_cache_allocated(scheduler, requests=[request]) + + # ECConnector should carry metadata of request + _assert_right_ec_connector_metadata(output, mm_features_list=request.mm_features) + + # Scheduled encoder input should be empty; no mm to compute + _assert_right_encoder_inputs(output, expected_total_reqs=0) + + +@pytest.mark.parametrize("use_kv_connector", [False, True]) +def test_ec_connector_cache_miss_computes_locally(use_kv_connector): + """Test consumer can compute encoder locally when cache miss (fallback).""" + # encoder cache itself if it doesn't receive it from external storage + + scheduler = create_scheduler( + 
model="llava-hf/llava-1.5-7b-hf", + enable_prefix_caching=True, + use_kv_connector=use_kv_connector, + use_ec_connector=True, + ec_role="ec_consumer", + ) + + # Verify consumer role + assert scheduler.ec_connector is not None + assert not scheduler.ec_connector.is_producer + + # Create MM request + request_mm_missed = create_requests( + num_requests=1, + num_tokens=200, # Total (including 100 MM) + mm_positions=[[PlaceholderRange(offset=0, length=100)]], # 100 MM tokens + )[0] + + # Mock cache miss - encoder cache doesn't exist externally + scheduler.ec_connector.has_caches = Mock(return_value=[False]) + + scheduler.add_request(request_mm_missed) + output = scheduler.schedule() + + # SCHEDULER should decide to compute encoder locally (fallback) + assert len(output.scheduled_new_reqs) == 1 + + # Should schedule full prompt tokens + scheduled_tokens = output.num_scheduled_tokens[request_mm_missed.request_id] + assert scheduled_tokens == 200, ( + f"Expected 200 tokens on cache miss, got {scheduled_tokens}" + ) + + # Encoder cache should contain mm items from request + _assert_right_encoder_cache_allocated(scheduler, requests=[request_mm_missed]) + + # ECConnector should carry no metadata (missed cache) + _assert_right_ec_connector_metadata(output, mm_features_list=[]) + + # Scheduled encoder input contain mm for request_mm_missed + _assert_right_encoder_inputs( + output, + requests=[request_mm_missed], + expected_encoder_inputs=[[0]], # index 0 of the mm item + expected_total_reqs=1, + ) + + # Then MODEL_RUNNER will execute the encoder and cache the result + + +@pytest.mark.parametrize("use_kv_connector", [False, True]) +def test_ec_connector_with_partial_cache_hit_multi_round(use_kv_connector): + """Test consumer with partial cache hit (local & connector) with 2 requests.""" + scheduler = create_scheduler( + model="llava-hf/llava-1.5-7b-hf", + enable_prefix_caching=True, + use_kv_connector=use_kv_connector, + use_ec_connector=True, + ec_role="ec_consumer", + ) + + # 
Create MM request + NUM_TOKENS_1 = 300 # NOTE: includes mm tokens + NUM_ENCODER_TOKENS_1 = 50 + mm_hashes_list_1 = [["hash1_A", "hash1_B", "hash1_A", "hash1_F"]] + mm_positions_1 = [ + [ + PlaceholderRange(offset=0, length=NUM_ENCODER_TOKENS_1), + PlaceholderRange(offset=100, length=NUM_ENCODER_TOKENS_1), + PlaceholderRange(offset=200, length=NUM_ENCODER_TOKENS_1), + PlaceholderRange(offset=250, length=NUM_ENCODER_TOKENS_1), + ] + ] + + # Create request with 4 MM items, with 2 identical items + request1 = create_requests( + num_requests=1, + num_tokens=NUM_TOKENS_1, + mm_hashes_list=mm_hashes_list_1, + mm_positions=mm_positions_1, + max_tokens=1, # For simplicity + )[0] + + # Mock partial cache hit: 1st and 3rd missing, 2nd and 4th exist + scheduler.ec_connector.has_caches = Mock(return_value=[False, True, False, True]) + scheduler.ec_connector.update_state_after_alloc = Mock( + wraps=scheduler.ec_connector.update_state_after_alloc + ) + + scheduler.add_request(request1) + output = scheduler.schedule() + + # Should schedule all tokens + scheduled_tokens = output.num_scheduled_tokens[request1.request_id] + assert scheduled_tokens == NUM_TOKENS_1 + + # Encoder cache should contain all mm items from request + _assert_right_encoder_cache_allocated(scheduler, requests=[request1]) + + # Should have called update_state_after_alloc for external load + scheduler.ec_connector.update_state_after_alloc.assert_called() + scheduler.ec_connector.update_state_after_alloc.reset_mock() + + # ECConnector should carry metadata for 2nd and 4th mm item + _assert_right_ec_connector_metadata( + output, mm_features_list=[request1.mm_features[1], request1.mm_features[3]] + ) + + # Should schedule ONLY 1 encoder input (index 0), no repeat for identical items + _assert_right_encoder_inputs( + output, + requests=[request1], + expected_encoder_inputs=[[0]], # index 0 of the mm item ONLY + expected_total_reqs=1, + ) + + # Simulate model execution 1 step + model_output = ModelRunnerOutput( + 
req_ids=[request1.request_id], + req_id_to_index={request1.request_id: 0}, + sampled_token_ids=[[100]], + # spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[], + ) + scheduler.update_from_output(output, model_output) + + # request1 is finished after outputing 1 token + # Finish request + scheduler.finish_requests(request1.request_id, RequestStatus.FINISHED_LENGTH_CAPPED) + + # Create another request with 4 MM items + NUM_TOKENS_2 = 400 + NUM_ENCODER_TOKENS_2 = 50 + mm_hashes_list_2 = [["hash1_C", "hash1_D", "hash1_E", "hash1_A"]] + mm_positions_2 = [ + [ + PlaceholderRange(offset=0, length=NUM_ENCODER_TOKENS_2), + PlaceholderRange(offset=100, length=NUM_ENCODER_TOKENS_2), + PlaceholderRange(offset=200, length=NUM_ENCODER_TOKENS_2), + PlaceholderRange(offset=250, length=NUM_ENCODER_TOKENS_2), + ] + ] + + request2 = create_requests( + num_requests=1, + num_tokens=NUM_TOKENS_2, + mm_hashes_list=mm_hashes_list_2, + mm_positions=mm_positions_2, + max_tokens=1, # For simplicity + )[0] + + # Mock partial cache hit: only hash1_A and hash1_C exist in connector + scheduler.ec_connector.has_caches = Mock(return_value=[True, False, False, True]) + + scheduler.add_request(request2) + output = scheduler.schedule() + + # Check + # Should schedule all tokens + scheduled_tokens = output.num_scheduled_tokens[request2.request_id] + assert scheduled_tokens == 400 + + # Encoder cache should contain all mm items from request2 + _assert_right_encoder_cache_allocated(scheduler, requests=[request2]) + + # Should call update_state_after_alloc for hash1_C, ONLY + # hash1_A should not be loaded from connector + # since it's computed in last request & exist in local cache + # Order of getting encoder cache should be: local cache -> connector-> compute + scheduler.ec_connector.update_state_after_alloc.assert_called_with(request2, 0) + scheduler.ec_connector.update_state_after_alloc.assert_called_once() + + 
scheduler.ec_connector.update_state_after_alloc.reset_mock() + + # ECConnector should carry metadata for hash1_C only (index 0) + _assert_right_ec_connector_metadata( + output, mm_features_list=[request2.mm_features[0]] + ) + + # Should schedule 2 encoder inputs hash1_D and hash1_E (index 1, 2) + _assert_right_encoder_inputs( + output, + requests=[request2], + expected_encoder_inputs=[[1, 2]], + expected_total_reqs=1, + ) + + +@pytest.mark.parametrize("cache_exist", ["local", "connector_only", "no_where"]) +@pytest.mark.parametrize("use_kv_connector", [False, True]) +def test_ec_connector_schedule_multiple_requests(cache_exist, use_kv_connector): + scheduler = create_scheduler( + model="llava-hf/llava-1.5-7b-hf", + max_num_seqs=10, # allow multiple requests + max_num_batched_tokens=2048, + enable_prefix_caching=True, + use_kv_connector=use_kv_connector, + use_ec_connector=True, + ec_role="ec_consumer", + ) + mm_hashes_list = [[f"hash_{i}"] for i in range(10)] + mm_positions = [[PlaceholderRange(offset=i, length=100)] for i in range(10)] + requests = create_requests( + num_requests=10, + num_tokens=200, + mm_hashes_list=mm_hashes_list, + mm_positions=mm_positions, + ) + for request in requests: + scheduler.add_request(request) + + # Set up to test different encoder cache existence scenario after preemption + # Order of getting encoder cache should be: local cache -> connector-> compute + scheduler.ec_connector.update_state_after_alloc = Mock( + wraps=scheduler.ec_connector.update_state_after_alloc + ) + + if cache_exist == "local": + # Allocate cache to cache manager manually to mimic + for req in requests: + scheduler.encoder_cache_manager.allocate(req, 0) + else: + # Make sure local encoder cache empty + scheduler.encoder_cache_manager.cached = {} + + if cache_exist == "connector_only": + # Cache exist in ec_connector + scheduler.ec_connector.has_caches = Mock(return_value=[True]) + elif cache_exist == "no_where": + scheduler.ec_connector.has_caches = 
Mock(return_value=[False]) + + output = scheduler.schedule() + assert len(output.scheduled_new_reqs) == len(requests) + assert output.scheduled_cached_reqs.num_reqs == 0 + assert len(output.finished_req_ids) == 0 + for req_id, num_tokens in output.num_scheduled_tokens.items(): + assert num_tokens == len(requests[int(req_id)].prompt_token_ids) + + ## Encoder-cache-specific checks: + # mm_hashes of requests exist in cache after scheduling for all scenario + _assert_right_encoder_cache_allocated(scheduler, requests=requests) + + # Should only call update_state_after_alloc when loaded externally + if cache_exist == "connector_only": + scheduler.ec_connector.update_state_after_alloc.assert_called_with( + requests[-1], 0 + ) + + # Concat mm_features for the 10 requests together + mm_features_list = [feature for req in requests for feature in req.mm_features] + + # Check metadata should contain mm data for all 10 requests + _assert_right_ec_connector_metadata(output, mm_features_list=mm_features_list) + else: + scheduler.ec_connector.update_state_after_alloc.assert_not_called() + # ECConnector should carry no metadata + _assert_right_ec_connector_metadata(output, mm_features_list=[]) + + scheduler.ec_connector.update_state_after_alloc.reset_mock() + + # Should only schedule encoder input when cache is not found anywhere + if cache_exist == "no_where": + _assert_right_encoder_inputs( + output, + requests=requests, + expected_encoder_inputs=[[0] for _ in range(10)], + expected_total_reqs=10, + ) + else: + _assert_right_encoder_inputs(output, expected_total_reqs=0) + + +@pytest.mark.parametrize("use_kv_connector", [False, True]) +def test_ec_connector_unable_to_allocate(use_kv_connector): + """ + Test whether scheduler with ECConnector is able to handle + unable to allocate (run out of blocks). + """ + + # Setup Scheduler With Mock External Cache Hit. 
+ BLOCK_SIZE = 4 + NUM_BLOCKS = 10 + scheduler = create_scheduler( + model="llava-hf/llava-1.5-7b-hf", + enable_prefix_caching=True, + use_kv_connector=use_kv_connector, + block_size=BLOCK_SIZE, + num_blocks=NUM_BLOCKS, + use_ec_connector=True, + ec_role="ec_consumer", + ) + + # Mock ec_connector load external cache behavior + scheduler.ec_connector.has_caches = Mock(return_value=[True]) + scheduler.ec_connector.update_state_after_alloc = Mock( + wraps=scheduler.ec_connector.update_state_after_alloc + ) + + # Create two requests. The second request will not be able to + # allocate slots because it will not have enough blocks. + NUM_REQUESTS = 2 + NUM_TOKENS = (NUM_BLOCKS // 2 + 1) * BLOCK_SIZE + MAX_TOKENS = 2 + requests = create_requests( + num_requests=NUM_REQUESTS, + num_tokens=NUM_TOKENS, + mm_hashes_list=[["hash_1"], ["hash_2"]], + mm_positions=[ + [PlaceholderRange(offset=1, length=10)] for _ in range(NUM_REQUESTS) + ], + max_tokens=MAX_TOKENS, + block_size=BLOCK_SIZE, + ) + req_ids = [] + req_to_index = {} + for i, request in enumerate(requests): + scheduler.add_request(request) + req_ids.append(request.request_id) + req_to_index[request.request_id] = i + + # Setup MODEL_RUNNER_OUTPUT to be run in _step_until_done later + MODEL_RUNNER_OUTPUT = ModelRunnerOutput( + req_ids=req_ids, + req_id_to_index=req_to_index, + sampled_token_ids=[[1000]] * len(req_ids), + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[], + ) + + # Just one request should be running. 
+ output = scheduler.schedule() + scheduled_tokens = output.num_scheduled_tokens[scheduler.running[0].request_id] + assert scheduled_tokens == NUM_TOKENS + assert len(scheduler.running) == 1 + assert len(scheduler.waiting) == 1 + + # Should have called update_state_after_alloc for external load + scheduler.ec_connector.update_state_after_alloc.assert_called_with( + scheduler.running[0], 0 + ) + scheduler.ec_connector.update_state_after_alloc.reset_mock() + + # All memory should be freed, with one request waiting. + _step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT) + assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == NUM_BLOCKS - 1 + assert len(scheduler.running) == 0 + assert len(scheduler.waiting) == 1 + + # Just one request should be running. + output = scheduler.schedule() + scheduled_tokens = output.num_scheduled_tokens[scheduler.running[0].request_id] + assert scheduled_tokens == NUM_TOKENS + assert len(scheduler.running) == 1 + assert len(scheduler.waiting) == 0 + + # update_state_after_alloc should be called for loading external cache + scheduler.ec_connector.update_state_after_alloc.assert_called_with( + scheduler.running[0], 0 + ) + scheduler.ec_connector.update_state_after_alloc.reset_mock() + + # All memory should be freed, with no requests waiting / running. 
+ _step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT) + assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == NUM_BLOCKS - 1 + assert len(scheduler.running) == 0 + assert len(scheduler.waiting) == 0 + + +@pytest.mark.parametrize("cache_exist", ["local", "connector_only", "no_where"]) +@pytest.mark.parametrize("use_kv_connector", [False, True]) +def test_priority_scheduling_ec_connector_preemption_and_resumption( + cache_exist, use_kv_connector +): + """Test that priority scheduling preempts lower priority requests + when out of KV cache space.""" + # Create scheduler with very limited memory to force preemption + scheduler = create_scheduler_with_priority( + model="llava-hf/llava-1.5-7b-hf", + enable_prefix_caching=True, + max_num_seqs=2, # allow multiple requests + # kv connector should not effect test results + use_kv_connector=use_kv_connector, + num_blocks=15, # can hold 244 tokens with 14 blocks (first block is null) + block_size=16, # standard block size + use_ec_connector=True, + ec_role="ec_consumer", + ) + + # Mock cache hit: Both cache exist in connector (at E->PD initially) + scheduler.ec_connector.has_caches = Mock(return_value=[True]) + scheduler.ec_connector.update_state_after_alloc = Mock( + wraps=scheduler.ec_connector.update_state_after_alloc + ) + + # Create a request and schedule it (and to be preempted) + request_low = create_requests_with_priority( + num_requests=1, + priorities=[1], + arrival_times=[0.0], + num_tokens=94, + mm_hashes_list=[["hash_low"]], + # NOTE: this test only preempt the last block. 
+ # Setting mm_position at the last block can force to recompute encoding + mm_positions=[[PlaceholderRange(offset=82, length=10)]], + starting_idx=0, + )[0] + scheduler.add_request(request_low) + # 1st schedule + output = scheduler.schedule() + + assert len(output.scheduled_new_reqs) == 1 + scheduled_tokens = output.num_scheduled_tokens[request_low.request_id] + assert scheduled_tokens == 94 + assert len(scheduler.waiting) == 0 + assert len(scheduler.running) == 1 + + ## Encoder-cache-specific checks: + # Encoder cache should contain mm items from request + _assert_right_encoder_cache_allocated(scheduler, requests=[request_low]) + + # Verify update_state_after_alloc called (external load) + scheduler.ec_connector.update_state_after_alloc.assert_called_with(request_low, 0) + scheduler.ec_connector.update_state_after_alloc.reset_mock() + + # ECConnector should carry metadata of request + _assert_right_ec_connector_metadata( + output, mm_features_list=request_low.mm_features + ) + + # Scheduled encoder input should be empty; no mm to compute + _assert_right_encoder_inputs(output, expected_total_reqs=0) + + # Simulate model execution - 1st decode + model_output = ModelRunnerOutput( + req_ids=[request_low.request_id], + req_id_to_index={request_low.request_id: 0}, + sampled_token_ids=[[100]], + # spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[], + ) + scheduler.update_from_output(output, model_output) + + # Create a high priority request and schedule it + request_high = create_requests_with_priority( + num_requests=1, + priorities=[0], + arrival_times=[1.0], + num_tokens=128, + mm_hashes_list=[["hash_high"]], + mm_positions=[[PlaceholderRange(offset=1, length=10)]], + max_tokens=2, + starting_idx=1, + )[0] + scheduler.add_request(request_high) + # 2nd schedule + output = scheduler.schedule() + + # KV cache should be full at this point + assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == 0 + assert 
len(output.scheduled_new_reqs) == 1 + assert output.scheduled_cached_reqs.num_reqs == 1 + assert len(scheduler.waiting) == 0 + assert len(scheduler.running) == 2 + + ## Encoder-cache-specific checks: + # Encoder cache should contain mm items from request + _assert_right_encoder_cache_allocated(scheduler, requests=[request_high]) + + # Verify update_state_after_alloc called (external load) + scheduler.ec_connector.update_state_after_alloc.assert_called_with(request_high, 0) + scheduler.ec_connector.update_state_after_alloc.reset_mock() + + # ECConnector should carry metadata of request + _assert_right_ec_connector_metadata( + output, mm_features_list=request_high.mm_features + ) + + # Scheduled encoder input should be empty; no mm to compute + _assert_right_encoder_inputs(output, expected_total_reqs=0) + + # Simulate model execution - 2nd decode + requests = [request_low, request_high] + model_output = ModelRunnerOutput( + req_ids=[req.request_id for req in requests], + req_id_to_index={req.request_id: i for i, req in enumerate(requests)}, + sampled_token_ids=[[100] for _ in requests], + # spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[], + ) + scheduler.update_from_output(output, model_output) + + # 3rd schedule - - this should trigger preemption + # req_low needs 96 tokens = 6 blocks + # req_high needs 129 tokens = 9 blocks + # so doesn't fit in 14 blocks. 
+ output = scheduler.schedule() + + # Should have preempted req_low + assert len(output.scheduled_new_reqs) == 0 + assert output.scheduled_cached_reqs.num_reqs == 1 + assert output.scheduled_cached_reqs.req_ids[0] == request_high.request_id + assert scheduler.requests[request_low.request_id].status == RequestStatus.PREEMPTED + assert len(scheduler.waiting) == 1 + assert len(scheduler.running) == 1 + + ## Encoder-cache-specific checks: + # request_high is in decode phase now + # ECConnector should carry no metadata + _assert_right_ec_connector_metadata(output, mm_features_list=[]) + + # Scheduled encoder input should be empty; no mm to compute + _assert_right_encoder_inputs(output, expected_total_reqs=0) + + # Simulate model execution - 3rd decode, after req_low was preempted + requests = [request_low, request_high] + model_output = ModelRunnerOutput( + req_ids=[req.request_id for req in requests], + req_id_to_index={req.request_id: i for i, req in enumerate(requests)}, + sampled_token_ids=[[100], [100, 200]], + # spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[], + ) + # Finish the requests to make room for the preempted requests to resume + # req_high is finished after outputing 2 tokens + scheduler.update_from_output(output, model_output) + scheduler.finish_requests( + request_high.request_id, RequestStatus.FINISHED_LENGTH_CAPPED + ) + + # Set up to test different encoder cache exsistence scenario after preemption + # Order of getting encoder cache should be: local cache -> connector-> compute + # By default, the cache should still exist in local in this test case + if cache_exist != "local": + # Make local encoder cache empty + scheduler.encoder_cache_manager.cached = {} + + if cache_exist == "connector_only": + # Cache exist in ec_connector + scheduler.ec_connector.has_caches = Mock(return_value=[True]) + elif cache_exist == "no_where": + scheduler.ec_connector.has_caches = Mock(return_value=[False]) + + # 4th Schedule - this 
should trigger req_low resumption from waiting + output = scheduler.schedule() + scheduled_cached_reqs = output.scheduled_cached_reqs + resumed_from_preemption = scheduled_cached_reqs.resumed_from_preemption + + assert len(output.scheduled_new_reqs) == 0 + assert scheduled_cached_reqs.num_reqs == 1 + assert len(scheduler.waiting) == 0 + assert len(scheduler.running) == 1 + + # Preempted request resumed in scheduled_cached_reqs + assert len(resumed_from_preemption) == 1 + assert len(scheduled_cached_reqs.resumed_req_token_ids) == 1 + assert resumed_from_preemption[0] + assert scheduled_cached_reqs.req_ids[0] == request_low.request_id + assert scheduled_cached_reqs.resumed_req_token_ids[0] is not None + ## Resumed tokens include 94 prompt tokens and 2 decoded tokens + assert len(scheduled_cached_reqs.resumed_req_token_ids[0]) == 96 + assert scheduled_cached_reqs.resumed_req_token_ids[0][95] == 100 + assert scheduler.running[0].request_id == request_low.request_id + assert request_high.request_id in output.finished_req_ids + + ## Encoder-cache-specific checks: + # mm_hash of request_low exists in cache after scheduling for all scenario + _assert_right_encoder_cache_allocated(scheduler, requests=[request_low]) + + # Should only call update_state_after_alloc when loaded externally + if cache_exist == "connector_only": + scheduler.ec_connector.update_state_after_alloc.assert_called_with( + request_low, 0 + ) + _assert_right_ec_connector_metadata( + output, mm_features_list=request_low.mm_features + ) + else: + scheduler.ec_connector.update_state_after_alloc.assert_not_called() + # ECConnector should carry no metadata + _assert_right_ec_connector_metadata(output, mm_features_list=[]) + + scheduler.ec_connector.update_state_after_alloc.reset_mock() + + # Should only schedule encoder input when cache is not found anywhere + if cache_exist == "no_where": + _assert_right_encoder_inputs( + output, + requests=[request_low], + expected_encoder_inputs=[[0]], + 
expected_total_reqs=1, + ) + else: + _assert_right_encoder_inputs(output, expected_total_reqs=0) + + +@pytest.mark.parametrize("use_kv_connector", [False, True]) +def test_ec_connector_allocate_encoder_tokens_with_external_load(use_kv_connector): + """ + Scenario: + - Encoder cache size: 32 + - Request A: 1 feature (12 tokens) → NOT cached remotely. + - Request B: 3 features (3 x 10 tokens) → ALL cached remotely. + + Steps: + 1. Schedule Request A (locally uses 12 tokens). + 2. Schedule Request B (remote cache) - only schedule 1st and 2nd + 3. Free A's cache, then schedule B again (continuation) - schedule 3rd image + """ + scheduler = create_scheduler( + model="llava-hf/llava-1.5-7b-hf", + max_num_batched_tokens=1024, + enable_prefix_caching=True, + use_kv_connector=use_kv_connector, + block_size=16, + num_blocks=11, # Can hold 160 tokens (first block is null) + use_ec_connector=True, + ec_role="ec_consumer", + ) + + # Limit the number of available slots of EncoderCacheManager + scheduler.encoder_cache_manager = EncoderCacheManager(cache_size=32) + + # Create MM request1 + NUM_TOKENS_1 = 50  # NOTE: includes mm tokens + NUM_ENCODER_TOKENS_1 = 12 + mm_hashes_list_1 = [["hash1_1"]] + mm_positions_1 = [[PlaceholderRange(offset=0, length=NUM_ENCODER_TOKENS_1)]] + + request1 = create_requests( + num_requests=1, + num_tokens=NUM_TOKENS_1, + mm_hashes_list=mm_hashes_list_1, + mm_positions=mm_positions_1, + max_tokens=1,  # For simplicity + req_ids=["req1"], + )[0] + + # Create MM request2 with 3 MM items + NUM_TOKENS_2 = 40 + NUM_ENCODER_TOKENS_2 = 10 + mm_hashes_list_2 = [["hash2_1", "hash2_2", "hash2_3"]] + mm_positions_2 = [ + [ + PlaceholderRange(offset=0, length=NUM_ENCODER_TOKENS_2), + PlaceholderRange(offset=12, length=NUM_ENCODER_TOKENS_2), + PlaceholderRange(offset=24, length=NUM_ENCODER_TOKENS_2), + ] + ] + + request2 = create_requests( + num_requests=1, + num_tokens=NUM_TOKENS_2, + mm_hashes_list=mm_hashes_list_2, + mm_positions=mm_positions_2, + max_tokens=10, 
+ req_ids=["req2"], + )[0] + + # Mock cache hit: MM of request1 NOT cached remotely, request2 cached remotely + scheduler.ec_connector.has_caches = Mock( + side_effect=lambda req: [True, True, True] if req == request2 else [False] + ) + scheduler.ec_connector.update_state_after_alloc = Mock( + wraps=scheduler.ec_connector.update_state_after_alloc + ) + + scheduler.add_request(request1) + scheduler.add_request(request2) + output = scheduler.schedule() + + # Now, since encoder cache manager can only store 32 tokens + # It should allocated mm item hash1_1, hash2_1 and hash2_2 + scheduled_tokens = output.num_scheduled_tokens[request1.request_id] + assert scheduled_tokens == NUM_TOKENS_1 + assert scheduler.get_num_unfinished_requests() == 2 + + # Encoder cache should contain mm item from request1 + _assert_right_encoder_cache_allocated( + scheduler, hashes_to_check=["hash1_1", "hash2_1", "hash2_2"] + ) + + # request2's 2nd mm item is the last call of update_state_after_alloc + scheduler.ec_connector.update_state_after_alloc.assert_called_with(request2, 1) + scheduler.ec_connector.update_state_after_alloc.reset_mock() + + # ECConnector should carry metadata of hash2_1 and hash2_2 ONLY + _assert_right_ec_connector_metadata( + output, mm_features_list=[request2.mm_features[0], request2.mm_features[1]] + ) + + # Should schedule ONLY 1 encoder input + _assert_right_encoder_inputs( + output, + requests=[request1], + expected_encoder_inputs=[[0]], # index 0 of the mm item of request1 + expected_total_reqs=1, + ) + + # Simulate model execution 1 step + model_output = ModelRunnerOutput( + req_ids=[request1.request_id, request2.request_id], + req_id_to_index={request1.request_id: 0, request2.request_id: 1}, + sampled_token_ids=[[100], [121]], + # spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[], + ) + scheduler.update_from_output(output, model_output) + + # request1 is finished after outputing 1 token + # Finish request + 
scheduler.finish_requests(request1.request_id, RequestStatus.FINISHED_LENGTH_CAPPED) + assert scheduler.get_num_unfinished_requests() == 1 + + # Schedule again; Now request1's encoder cache should be freed + # -> hash2_3 can be scheduled and allocated + output = scheduler.schedule() + + # Check + # Should schedule all tokens + scheduled_tokens = output.num_scheduled_tokens[request2.request_id] + print(f"Hero: scheduled_tokens for req2: {scheduled_tokens}") + print(f"hero: num_scheduled_tokens 2: {output.num_scheduled_tokens}") + + # Encoder cache should contain all mm items from request2 + _assert_right_encoder_cache_allocated(scheduler, requests=[request2]) + + # request2's 3rd mm item is the ONLY call of update_state_after_alloc + scheduler.ec_connector.update_state_after_alloc.assert_called_with(request2, 2) + scheduler.ec_connector.update_state_after_alloc.assert_called_once() + + scheduler.ec_connector.update_state_after_alloc.reset_mock() + + # ECConnector should carry metadata for hash2_3 ONLY + _assert_right_ec_connector_metadata( + output, mm_features_list=[request2.mm_features[2]] + ) + + # Should schedule no encoder input + _assert_right_encoder_inputs( + output, + expected_total_reqs=0, + ) + + +# ============================================================================== +# EPD (Encoder-Prefill-Decode) Encoder-cache-specific tests end +# ============================================================================== diff --git a/tests/v1/core/utils.py b/tests/v1/core/utils.py index 6e739d6b0e77a..3692e633322e2 100644 --- a/tests/v1/core/utils.py +++ b/tests/v1/core/utils.py @@ -5,6 +5,7 @@ import torch from vllm.config import ( CacheConfig, + ECTransferConfig, KVTransferConfig, ModelConfig, SchedulerConfig, @@ -46,6 +47,8 @@ def create_scheduler( num_speculative_tokens: int | None = None, skip_tokenizer_init: bool = False, async_scheduling: bool = False, + use_ec_connector: bool = False, + ec_role: str | None = None, ) -> Scheduler | AsyncScheduler: 
"""Create scheduler under test. @@ -107,12 +110,23 @@ def create_scheduler( model="ngram", num_speculative_tokens=num_speculative_tokens ) + ec_transfer_config = ( + ECTransferConfig( + ec_connector="ECSharedStorageConnector", + ec_role=ec_role, + ec_connector_extra_config={"shared_storage_path": "/tmp/ec_test"}, + ) + if use_ec_connector + else None + ) + vllm_config = VllmConfig( scheduler_config=scheduler_config, model_config=model_config, cache_config=cache_config, kv_transfer_config=kv_transfer_config, speculative_config=speculative_config, + ec_transfer_config=ec_transfer_config, ) kv_cache_config = KVCacheConfig( num_blocks=num_blocks, # A large number of blocks to hold all requests @@ -140,12 +154,14 @@ _none_hash_initialized = False def create_requests( num_requests: int, num_tokens: int = 10, + mm_hashes_list: list[list[str]] | None = None, mm_positions: list[list[PlaceholderRange]] | None = None, max_tokens: int = 16, stop_token_ids: list[int] | None = None, prompt_logprobs: int | None = None, same_prompt: bool = False, block_size: int = 16, + req_ids: list[str] | None = None, ) -> list[Request]: global _none_hash_initialized if not _none_hash_initialized: @@ -160,25 +176,58 @@ def create_requests( prompt_logprobs=prompt_logprobs, ) requests = [] + + if mm_hashes_list is not None: + # NOTE: allow manual input; some mm items can have the same identifier + # no. 
of mm_hashes and mm_positions for each request should be identical + assert mm_positions is not None, ( + "mm_positions must be provided when mm_hashes_list is provided" + ) + assert len(mm_hashes_list) == len(mm_positions) == num_requests + assert [len(h) for h in mm_hashes_list] == [len(p) for p in mm_positions] + + # Since same identifier would imply they are identical encoder output + # Verify mm items with identical identifier are having mm_position.length + seen_hashes: dict[str, int] = {} + + if req_ids: + assert len(req_ids) == num_requests + else: + req_ids = [f"{i}" for i in range(num_requests)] + for i in range(num_requests): mm_features = [] - if mm_positions is not None: - mm_position = mm_positions[i] - for j, position in enumerate(mm_position): - # Dummy hash for each mm item should be unique - # since encoder cache tracks entries by hash + + for j, position in enumerate( + mm_positions[i] if mm_positions is not None else [] + ): + if mm_hashes_list is not None: + identifier = mm_hashes_list[i][j] + + # Verify if position length is identical + position_length = position.length + if identifier in seen_hashes: + assert seen_hashes[identifier] == position_length, ( + f"mm_hash '{identifier}' has inconsistent position lengths: " + f"previously {seen_hashes[identifier]}, now {position_length} " + f"at request {i}, position {j}" + ) + else: + seen_hashes[identifier] = position_length + else: + # Unique dummy hash for each mm item identifier = f"hash{i}_{j}" - mm_feature = MultiModalFeatureSpec( - data=MultiModalKwargsItem.dummy("dummy_m"), - mm_position=position, - identifier=identifier, - modality="image", - ) - mm_features.append(mm_feature) + mm_feature = MultiModalFeatureSpec( + data=MultiModalKwargsItem.dummy("dummy_m"), + mm_position=position, + identifier=identifier, + modality="image", + ) + mm_features.append(mm_feature) prompt_token_ids = [0] * num_tokens if same_prompt else [i] * num_tokens request = Request( - request_id=f"{i}", + 
request_id=req_ids[i], + prompt_token_ids=prompt_token_ids, + sampling_params=sampling_params, + pooling_params=None, diff --git a/tests/v1/ec_connector/integration/README.md b/tests/v1/ec_connector/integration/README.md new file mode 100644 index 0000000000000..30426e055ade8 --- /dev/null +++ b/tests/v1/ec_connector/integration/README.md @@ -0,0 +1,171 @@ +# EPD Correctness Test + +This test verifies that EPD (Encoder-Prefill-Decode) disaggregation produces identical outputs to a baseline single instance. + +## What It Tests + +- **Baseline**: Single vLLM instance serving a multimodal model +- **EPD (1E+1PD)**: 1 Encoder + 1 Prefill-Decode instance +- **Baseline (1P+1D)**: 1 Prefill + 1 Decode instance +- **EPD (1E+1P+1D)**: 1 Encoder + 1 Prefill + 1 Decode instance + +The test ensures that disaggregated encoding produces **identical** outputs to the baseline. + +Note that currently the PD disaggregation setup may give slightly different results from a single instance. Therefore, we need the result from 1P+1D as the baseline for 1E+1P+1D. + +Please refer to [Disaggregated Encoder Feature](../../../../docs/features/disagg_encoder.md) for a detailed explanation of the EPD features. + +## Files + +- `run_epd_correctness_test.sh` - Main test script (starts all instances and runs tests) +- `test_epd_correctness.py` - Python test script (compares outputs) + +## Usage + +### Multimodal Prompts (Default) + +```bash +cd vllm +./tests/v1/ec_connector/integration/run_epd_correctness_test.sh +``` + +This runs the test with actual multimodal (image) prompts. + +### Text-Only Prompts + +```bash +cd vllm +USE_MM_PROMPTS=0 ./tests/v1/ec_connector/integration/run_epd_correctness_test.sh +``` + +This runs a quick test with text-only prompts to verify the setup works. 
+ +### Custom Configuration + +```bash +# Use specific GPUs +GPU_E=0 GPU_PD=1 GPU_P=1 GPU_D=2 bash ./tests/v1/ec_connector/integration/run_epd_correctness_test.sh + +# Use specific ports +ENDPOINT_PORT=10001 bash ./tests/v1/ec_connector/integration/run_epd_correctness_test.sh + +# Use specific model +MODEL="Qwen/Qwen2.5-VL-3B-Instruct" bash ./tests/v1/ec_connector/integration/run_epd_correctness_test.sh + +# Use specific storage path +EC_SHARED_STORAGE_PATH="/tmp/my_ec_cache" bash ./tests/v1/ec_connector/integration/run_epd_correctness_test.sh +``` + +## How It Works + +### Step 1: Baseline + +1. Start single vLLM instance on GPU +2. Run test prompts (multimodal or text-only) +3. Save outputs to `.vllm_epd_baseline.txt` +4. Shutdown instance + +### Step 2: EPD (1E + 1PD) + +1. Clear encoder cache storage +2. Start instances and proxy +3. Run same test prompts +4. Assert outputs match baseline exactly +5. Shutdown instances + +### Step 3: EPD (1E + 1P + 1D) + +1. Clear encoder cache storage +2. Start instances and proxy +3. Run same test prompts +4. Assert outputs match baseline exactly +5. 
Shutdown instances + +## Test Scenarios + +### Multimodal Prompts (default) + +Tests encoder cache transfer: + +- Single image query +- Multiple images in one request +- Mixed image and text +- Image with detailed questions + +### Text-Only Prompts (`USE_MM_PROMPTS=0`) + +Quick sanity check: + +- Simple text queries +- Text-only explanations +- Verifies proxy routing works + +## Expected Behavior + +### ✅ Test Passes When + +- All disagg outputs match baseline outputs exactly +- No errors during instance startup +- Encoder cache is properly saved and loaded +- Proxy correctly routes requests + +### ❌ Test Fails When + +- Outputs differ between baseline and disagg +- Server startup fails +- Encoder cache not found (should fallback to local execution) +- Proxy routing errors + +## Notes + +- The test uses deterministic generation (`temperature=0.0`, `seed=42`) +- Encoder cache should enable exact output reproduction +- Test cleans up all instances and cache files after completion +- Safe to run multiple times (idempotent) +- We set up the PD disagg part with NixlConnector. Please read details about EPD in `examples/online_serving/disaggregated_encoder/README.md` + +## Requirements + +- Multiple GPUs (3 for 1E+1P+1D, 2 for 1E+1PD, 1 for baseline) + - 1E+1P+1D is runnable with 2 GPUs by assigning E and P to the same GPU. +- Multimodal model (e.g., Qwen2.5-VL-3B-Instruct) +- Internet access (for accessing vllm test images) + +## Debugging + +### Check Logs + +Logs and baseline output are saved in `/tmp/` by default. +Can be customized by changing the environment variables. 
+ +### Check Encoder Cache + +```bash +# Verify cache files are created +ls -la $EC_SHARED_STORAGE_PATH/ + +# Should see directories with mm_hash names +# Each containing encoder_cache.safetensors +``` + +### Manual Testing + +Run individual components: + +```bash +# Baseline only +python test_epd_correctness.py \ + --service_url http://localhost:8000 \ + --model_name Qwen/Qwen2.5-VL-3B-Instruct \ + --mode baseline \ + --baseline_file test_output.txt \ + --use_mm_prompts + +# Disagg only (requires baseline output file!) +python test_epd_correctness.py \ + --service_url http://localhost:8000 \ + --model_name Qwen/Qwen2.5-VL-3B-Instruct \ + --mode disagg \ + --baseline_file test_output.txt \ + --use_mm_prompts +``` diff --git a/tests/v1/ec_connector/integration/hato.jpg b/tests/v1/ec_connector/integration/hato.jpg new file mode 100644 index 0000000000000..9c7e390e7d7f6 Binary files /dev/null and b/tests/v1/ec_connector/integration/hato.jpg differ diff --git a/tests/v1/ec_connector/integration/run_epd_correctness_test.sh b/tests/v1/ec_connector/integration/run_epd_correctness_test.sh new file mode 100644 index 0000000000000..55dd39c0a957f --- /dev/null +++ b/tests/v1/ec_connector/integration/run_epd_correctness_test.sh @@ -0,0 +1,476 @@ +#!/bin/bash +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# +# EPD (Encoder-Prefill-Decode) Correctness Test +# +# This script tests that EPD disaggregation produces the same outputs as baseline. +# It runs: +# 1. Baseline: Single vLLM instance +# 2. EPD: 1E + 1PD setup +# 3. Baseline for (E + P + D): 1P + 1D vLLM instances disagg +# 4. 
EPD: 1E + 1P + 1D setup + +# For GPU usage + +# set -xe + +# Find the git repository root directory +GIT_ROOT=$(git rev-parse --show-toplevel) + +# Model to test +MODEL="${MODEL:-Qwen/Qwen2.5-VL-3B-Instruct}" + +# Set 1 to use multimodal prompts; else to use text-only +USE_MM_PROMPTS="${USE_MM_PROMPTS:-1}" +MM_FLAG="" +if [ $USE_MM_PROMPTS = "1" ]; then + MM_FLAG="--use_mm_prompts" +fi + +# GPU configuration +GPU_E="${GPU_E:-0}" +GPU_P="${GPU_P:-1}" +GPU_D="${GPU_D:-2}" +GPU_SINGLE="${GPU_SINGLE:-$GPU_P}" +GPU_PD="${GPU_PD:-$GPU_P}" + +# Port +ENCODE_PORT="${ENCODE_PORT:-19534}" +PREFILL_PORT="${PREFILL_PORT:-19535}" +DECODE_PORT="${DECODE_PORT:-19536}" +PREFILL_DECODE_PORT="${PREFILL_DECODE_PORT:-19537}" +ENDPOINT_PORT="${ENDPOINT_PORT:-10001}" + +# Storage path for encoder cache +EC_SHARED_STORAGE_PATH="${EC_SHARED_STORAGE_PATH:-/tmp/ec_cache_test}" +TIMEOUT_SECONDS="${TIMEOUT_SECONDS:-600}" + +# Output file for baseline comparison and logs +LOG_PATH="${LOG_PATH:-/tmp}" +BASELINE_FILE="${BASELINE_FILE:-/tmp/vllm_baseline.txt}" +BASELINE_PD_FILE="${BASELINE_PD_FILE:-/tmp/vllm_epd_baseline.txt}" + +mkdir -p $LOG_PATH + +# Trap the SIGINT signal (triggered by Ctrl+C) +trap 'kill $(jobs -pr)' SIGINT SIGTERM EXIT + +# Wait for server to be ready +wait_for_server() { + local port=$1 + timeout "$TIMEOUT_SECONDS" bash -c " + until curl -s localhost:${port}/v1/chat/completions > /dev/null; do + sleep 1 + done" && return 0 || return 1 +} + +# Cleanup function +cleanup_instances() { + echo "Cleaning up any running vLLM instances..." 
+ pkill -f "vllm serve" || true + pkill -f "disagg_epd_proxy.py" || true + sleep 2 +} + +# Function to run baseline (single instance) +run_baseline() { + echo "================================" + echo "Running BASELINE (single instance)" + echo "================================" + + cleanup_instances + rm -rf "$EC_SHARED_STORAGE_PATH" + + local PORT=$ENDPOINT_PORT + + # Start baseline instance + echo "Starting baseline instance on GPU $GPU_SINGLE, port $PORT" + CUDA_VISIBLE_DEVICES="$GPU_SINGLE" vllm serve "$MODEL" \ + --port $PORT \ + --enforce-eager \ + --gpu-memory-utilization 0.7 \ + --max-num-seqs 128 \ + --allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \ + > $LOG_PATH/baseline.log 2>&1 & + + local BASELINE_PID=$! + + # Wait for baseline to start + echo "Waiting for baseline instance to start..." + wait_for_server $PORT + + curl http://127.0.0.1:$PORT/v1/models + echo "" + + # Run test in baseline mode + echo "Running baseline..." + + python "${GIT_ROOT}/tests/v1/ec_connector/integration/test_epd_correctness.py" \ + --service_url "http://localhost:$PORT" \ + --model_name "$MODEL" \ + --mode baseline \ + --baseline_file "$BASELINE_FILE" \ + $MM_FLAG + + # Cleanup baseline + echo "Stopping baseline instance..." 
+ kill $BASELINE_PID 2>/dev/null || true + sleep 2 + cleanup_instances +} + +# Function to run EPD with 1E + 1PD +run_epd_1e_1pd() { + echo "================================" + echo "Running EPD (1E + 1PD)" + echo "================================" + + cleanup_instances + rm -rf "$EC_SHARED_STORAGE_PATH" + mkdir -p "$EC_SHARED_STORAGE_PATH" + + local ENCODE_PORT=$ENCODE_PORT + local PREFILL_DECODE_PORT=$PREFILL_DECODE_PORT + local PROXY_PORT=$ENDPOINT_PORT + + declare -a PIDS=() + + # Start encoder instance + echo "Starting encoder instance on GPU $GPU_E, port $ENCODE_PORT" + CUDA_VISIBLE_DEVICES="$GPU_E" vllm serve "$MODEL" \ + --port $ENCODE_PORT \ + --enforce-eager \ + --gpu-memory-utilization 0.01 \ + --enable-request-id-headers \ + --no-enable-prefix-caching \ + --max-num-batched-tokens 114688 \ + --max-num-seqs 128 \ + --allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \ + --ec-transfer-config '{ + "ec_connector": "ECSharedStorageConnector", + "ec_role": "ec_producer", + "ec_connector_extra_config": { + "shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'" + } + }' \ + > $LOG_PATH/1e1pd_encoder.log 2>&1 & + PIDS+=($!) + + # Start prefill+decode instance + echo "Starting PD instance on GPU $GPU_PD, port $PREFILL_DECODE_PORT" + CUDA_VISIBLE_DEVICES="$GPU_PD" vllm serve "$MODEL" \ + --port $PREFILL_DECODE_PORT \ + --enforce-eager \ + --gpu-memory-utilization 0.7 \ + --enable-request-id-headers \ + --max-num-seqs 128 \ + --allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \ + --ec-transfer-config '{ + "ec_connector": "ECSharedStorageConnector", + "ec_role": "ec_consumer", + "ec_connector_extra_config": { + "shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'" + } + }' \ + > $LOG_PATH/1e1pd_pd.log 2>&1 & + PIDS+=($!) + + # Wait for instances to start + echo "Waiting for encoder instance..." + wait_for_server $ENCODE_PORT + echo "Waiting for PD instance..." 
+ wait_for_server $PREFILL_DECODE_PORT + + # Start proxy + echo "Starting EPD proxy on port $PROXY_PORT" + python "${GIT_ROOT}/examples/online_serving/disaggregated_encoder/disagg_epd_proxy.py" \ + --host "0.0.0.0" \ + --port $PROXY_PORT \ + --encode-servers-urls "http://localhost:$ENCODE_PORT" \ + --prefill-servers-urls "disable" \ + --decode-servers-urls "http://localhost:$PREFILL_DECODE_PORT" \ + > $LOG_PATH/1e1pd_proxy.log 2>&1 & + PIDS+=($!) + + # Wait for proxy + echo "Waiting for proxy..." + wait_for_server $PROXY_PORT + + curl http://127.0.0.1:$PROXY_PORT/v1/models + curl http://127.0.0.1:$PROXY_PORT/health + echo "" + + echo "All EPD (1E+1PD) services are up!" + + # Run test in disagg mode + echo "Running EPD (1E+1PD) correctness test..." + + python "${GIT_ROOT}/tests/v1/ec_connector/integration/test_epd_correctness.py" \ + --service_url "http://localhost:$PROXY_PORT" \ + --model_name "$MODEL" \ + --mode disagg \ + --baseline_file "$BASELINE_FILE" \ + $MM_FLAG + + # Cleanup + echo "✓✓ 1E+1PD Correctness Test finished" + echo "Stopping EPD (1E+1PD) instances..." 
+ for pid in "${PIDS[@]}"; do + kill $pid 2>/dev/null || true + done + sleep 2 + cleanup_instances +} + +# Function to run baseline for 1E + 1P + 1D (PD disagg) +run_baseline_1p_1d() { + echo "================================" + echo "Running PD BASELINE (1P + 1D)" + echo "================================" + + cleanup_instances + rm -rf "$EC_SHARED_STORAGE_PATH" + mkdir -p "$EC_SHARED_STORAGE_PATH" + + local PREFILL_PORT=$PREFILL_PORT + local DECODE_PORT=$DECODE_PORT + local PROXY_PORT=$ENDPOINT_PORT + + declare -a PIDS=() + + # Start prefill instance + echo "Starting prefill instance on GPU $GPU_P, port $PREFILL_PORT" + CUDA_VISIBLE_DEVICES="$GPU_P" \ + VLLM_NIXL_SIDE_CHANNEL_PORT=5559 \ + vllm serve "$MODEL" \ + --port $PREFILL_PORT \ + --enforce-eager \ + --gpu-memory-utilization 0.7 \ + --enable-request-id-headers \ + --max-num-seqs 128 \ + --allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \ + --kv-transfer-config '{ + "kv_connector": "NixlConnector", + "kv_role": "kv_producer" + }' \ + > $LOG_PATH/1p1d_prefill.log 2>&1 & + PIDS+=($!) + + # Start decode instance + echo "Starting decode instance on GPU $GPU_D, port $DECODE_PORT" + CUDA_VISIBLE_DEVICES="$GPU_D" \ + VLLM_NIXL_SIDE_CHANNEL_PORT=6000 \ + vllm serve "$MODEL" \ + --port $DECODE_PORT \ + --enforce-eager \ + --gpu-memory-utilization 0.7 \ + --enable-request-id-headers \ + --max-num-seqs 128 \ + --allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \ + --kv-transfer-config '{ + "kv_connector": "NixlConnector", + "kv_role": "kv_consumer" + }' \ + > $LOG_PATH/1p1d_decode.log 2>&1 & + PIDS+=($!) + + # Wait for instances to start + echo "Waiting for prefill instance..." + wait_for_server $PREFILL_PORT + echo "Waiting for decode instance..." 
+ wait_for_server $DECODE_PORT + + # Start proxy + echo "Starting EPD proxy on port $PROXY_PORT" + python "${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py" \ + --host "0.0.0.0" \ + --port $PROXY_PORT \ + --prefiller-ports $PREFILL_PORT \ + --decoder-ports $DECODE_PORT \ + > $LOG_PATH/1p1d_proxy.log 2>&1 & + PIDS+=($!) + + # Wait for proxy + echo "Waiting for proxy..." + wait_for_server $PROXY_PORT + + curl http://127.0.0.1:$PROXY_PORT/healthcheck + echo "" + + echo "All PD (1P+1D) services are up!" + + # Run test in baseline mode + echo "Running PD disagg baseline..." + + python "${GIT_ROOT}/tests/v1/ec_connector/integration/test_epd_correctness.py" \ + --service_url "http://localhost:$PROXY_PORT" \ + --model_name "$MODEL" \ + --mode baseline_pd \ + --baseline_file "$BASELINE_PD_FILE" \ + $MM_FLAG + + # Cleanup + echo "Stopping PD (1P+1D) instances..." + for pid in "${PIDS[@]}"; do + kill $pid 2>/dev/null || true + done + sleep 2 + cleanup_instances +} + +# Function to run EPD with 1E + 1P + 1D +run_epd_1e_1p_1d() { + echo "================================" + echo "Running EPD (1E + 1P + 1D)" + echo "================================" + + cleanup_instances + rm -rf "$EC_SHARED_STORAGE_PATH" + mkdir -p "$EC_SHARED_STORAGE_PATH" + + local ENCODE_PORT=$ENCODE_PORT + local PREFILL_PORT=$PREFILL_PORT + local DECODE_PORT=$DECODE_PORT + local PROXY_PORT=$ENDPOINT_PORT + + declare -a PIDS=() + + # Start encoder instance + echo "Starting encoder instance on GPU $GPU_E, port $ENCODE_PORT" + CUDA_VISIBLE_DEVICES="$GPU_E" vllm serve "$MODEL" \ + --port $ENCODE_PORT \ + --enforce-eager \ + --gpu-memory-utilization 0.01 \ + --enable-request-id-headers \ + --no-enable-prefix-caching \ + --max-num-batched-tokens 114688 \ + --max-num-seqs 128 \ + --allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \ + --ec-transfer-config '{ + "ec_connector": "ECSharedStorageConnector", + "ec_role": "ec_producer", + "ec_connector_extra_config": { + 
"shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'" + } + }' \ + > $LOG_PATH/1e1p1d_encoder.log 2>&1 & + PIDS+=($!) + + # Start prefill instance + echo "Starting prefill instance on GPU $GPU_P, port $PREFILL_PORT" + CUDA_VISIBLE_DEVICES="$GPU_P" \ + VLLM_NIXL_SIDE_CHANNEL_PORT=5559 \ + vllm serve "$MODEL" \ + --port $PREFILL_PORT \ + --enforce-eager \ + --gpu-memory-utilization 0.7 \ + --enable-request-id-headers \ + --max-num-seqs 128 \ + --allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \ + --ec-transfer-config '{ + "ec_connector": "ECSharedStorageConnector", + "ec_role": "ec_consumer", + "ec_connector_extra_config": { + "shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'" + } + }' \ + --kv-transfer-config '{ + "kv_connector": "NixlConnector", + "kv_role": "kv_producer" + }' \ + > $LOG_PATH/1e1p1d_prefill.log 2>&1 & + PIDS+=($!) + + # Start decode instance + echo "Starting decode instance on GPU $GPU_D, port $DECODE_PORT" + CUDA_VISIBLE_DEVICES="$GPU_D" \ + VLLM_NIXL_SIDE_CHANNEL_PORT=6000 \ + vllm serve "$MODEL" \ + --port $DECODE_PORT \ + --enforce-eager \ + --gpu-memory-utilization 0.7 \ + --enable-request-id-headers \ + --max-num-seqs 128 \ + --allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \ + --kv-transfer-config '{ + "kv_connector": "NixlConnector", + "kv_role": "kv_consumer" + }' \ + > $LOG_PATH/1e1p1d_decode.log 2>&1 & + PIDS+=($!) + + # Wait for instances to start + echo "Waiting for encoder instance..." + wait_for_server $ENCODE_PORT + echo "Waiting for prefill instance..." + wait_for_server $PREFILL_PORT + echo "Waiting for decode instance..." 
+ wait_for_server $DECODE_PORT + + # Start proxy + echo "Starting EPD proxy on port $PROXY_PORT" + python "${GIT_ROOT}/examples/online_serving/disaggregated_encoder/disagg_epd_proxy.py" \ + --host "0.0.0.0" \ + --port $PROXY_PORT \ + --encode-servers-urls "http://localhost:$ENCODE_PORT" \ + --prefill-servers-urls "http://localhost:$PREFILL_PORT" \ + --decode-servers-urls "http://localhost:$DECODE_PORT" \ + > $LOG_PATH/1e1p1d_proxy.log 2>&1 & + PIDS+=($!) + + # Wait for proxy + echo "Waiting for proxy..." + wait_for_server $PROXY_PORT + + curl http://127.0.0.1:$PROXY_PORT/v1/models + curl http://127.0.0.1:$PROXY_PORT/health + echo "" + + echo "All EPD (1E+1P+1D) services are up!" + + # Run test in disagg mode + echo "Running EPD (1E+1P+1D) correctness test..." + + python "${GIT_ROOT}/tests/v1/ec_connector/integration/test_epd_correctness.py" \ + --service_url "http://localhost:$PROXY_PORT" \ + --model_name "$MODEL" \ + --mode disagg \ + --baseline_file "$BASELINE_PD_FILE" \ + $MM_FLAG + + # Cleanup + echo "✓✓ 1E+1P+1D Correctness Test finished" + echo "Stopping EPD (1E+1P+1D) instances..." + for pid in "${PIDS[@]}"; do + kill $pid 2>/dev/null || true + done + sleep 2 + cleanup_instances +} + +# Main execution +echo "================================" +echo "EPD Correctness Test Suite" +echo "Model: $MODEL" +echo "================================" + +# Step 1: Run baseline +run_baseline + +# Step 2: Test 1E + 1PD +run_epd_1e_1pd + +# Step 3: Test baseline 1P + 1D +run_baseline_1p_1d + +# Step 4: Test 1E + 1P + 1D +run_epd_1e_1p_1d + +# Cleanup output file +rm -f "$BASELINE_FILE" +rm -f "$BASELINE_PD_FILE" + +echo "================================" +echo "✓✓ All EPD correctness tests finished!" 
+echo "================================" diff --git a/tests/v1/ec_connector/integration/test_epd_correctness.py b/tests/v1/ec_connector/integration/test_epd_correctness.py new file mode 100644 index 0000000000000..69c4c58e349b9 --- /dev/null +++ b/tests/v1/ec_connector/integration/test_epd_correctness.py @@ -0,0 +1,305 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +EPD Correctness Test + +Tests that EPD (Encoder-Prefill-Decode) disaggregation produces the same +outputs as a baseline single instance. + +Usage: + # Baseline mode (saves outputs): + python test_epd_correctness.py \ + --service_url http://localhost:8000 \ + --model_name Qwen/Qwen2.5-VL-3B-Instruct \ + --mode baseline \ + --baseline_file .vllm_epd_baseline.txt + + # Disagg mode (compares outputs): + python test_epd_correctness.py \ + --service_url http://localhost:8000 \ + --model_name Qwen/Qwen2.5-VL-3B-Instruct \ + --mode disagg \ + --baseline_file .vllm_epd_baseline.txt +""" + +import argparse +import json +import os +import time + +import openai +import requests + +from vllm.assets.image import ImageAsset +from vllm.multimodal.utils import encode_image_base64 + +MAX_OUTPUT_LEN = 256 + +# Sample prompts with multimodal content +image_1 = ImageAsset("stop_sign").pil_image.resize((1280, 720)) +image_2 = ImageAsset("cherry_blossom").pil_image.resize((1280, 720)) + +image_local_path = f"{os.path.dirname(os.path.abspath(__file__))}/hato.jpg" + +SAMPLE_PROMPTS_MM: list[dict] = [ + { + "messages": [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": f"data:image;base64,{encode_image_base64(image_1)}" + }, + }, + {"type": "text", "text": "What's in this image?"}, + ], + } + ], + "description": "Single image query", + }, + { + "messages": [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": f"data:image;base64,{encode_image_base64(image_2)}" + }, + }, + { + 
"type": "image_url", + "image_url": {"url": f"file://{image_local_path}"}, + }, + {"type": "text", "text": "Describe these 2 images in detail."}, + ], + } + ], + "description": "2 images with detailed query", + }, +] + +# Text-only prompts for mixed testing +SAMPLE_PROMPTS_TEXT: list[dict] = [ + { + "messages": [{"role": "user", "content": "What is the capital of France?"}], + "description": "Simple text-only query", + }, + { + "messages": [ + {"role": "user", "content": "Explain quantum computing in simple terms."} + ], + "description": "Text-only explanation request", + }, +] + + +def check_vllm_server(url: str, timeout=5, retries=10) -> bool: + """Check if the vLLM server is ready. + + Args: + url: The URL to check (usually /health or /healthcheck endpoint) + timeout: Timeout in seconds for each request + retries: Number of retries if the server is not ready + + Returns: + True if the server is ready, False otherwise + """ + for attempt in range(retries): + try: + response = requests.get(url, timeout=timeout) + if response.status_code == 200: + print(f"Server is ready at {url}") + return True + else: + print( + f"Attempt {attempt + 1}/{retries}: Server returned " + f"status code {response.status_code}" + ) + except requests.exceptions.RequestException as e: + print(f"Attempt {attempt + 1}/{retries}: Error connecting: {e}") + time.sleep(2) # Wait before retrying + return False + + +def run_chat_completion( + base_url: str, + model_name: str, + messages: list, + max_tokens: int = MAX_OUTPUT_LEN, +) -> str: + """Run a chat completion request. 
+ + Args: + base_url: Base URL of the vLLM server + model_name: Name of the model + messages: Messages for chat completion + max_tokens: Maximum tokens to generate + + Returns: + Generated text content + """ + client = openai.OpenAI(api_key="EMPTY", base_url=base_url) + + completion = client.chat.completions.create( + model=model_name, + messages=messages, + max_tokens=max_tokens, + temperature=0.0, + seed=42, + ) + + return completion.choices[0].message.content + + +def main(): + """Main test function.""" + parser = argparse.ArgumentParser( + description="EPD correctness test - compare disagg vs baseline" + ) + + parser.add_argument( + "--service_url", + type=str, + required=True, + help="The vLLM service URL (e.g., http://localhost:8000)", + ) + + parser.add_argument( + "--model_name", + type=str, + required=True, + help="Model name", + ) + + parser.add_argument( + "--mode", + type=str, + default="baseline", + choices=["baseline", "baseline_pd", "disagg"], + help="Mode: baseline/baseline_pd (saves outputs) or disagg (compares outputs)", + ) + + parser.add_argument( + "--baseline_file", + type=str, + default=".vllm_epd_baseline.txt", + help="File to save/load baseline outputs", + ) + + parser.add_argument( + "--use_mm_prompts", + action="store_true", + help="Use multimodal prompts (default: use text-only for quick testing)", + ) + + args = parser.parse_args() + + print(f"Service URL: {args.service_url}") + print(f"Model: {args.model_name}") + print(f"Mode: {args.mode}") + print(f"Output file: {args.baseline_file}") + print(f"Use MM prompts: {args.use_mm_prompts}") + + # Determine health check endpoint + if args.mode == "baseline": + health_check_url = f"{args.service_url}/health" + elif args.mode == "baseline_pd": + # Nixl toy proxy use /healthcheck + health_check_url = f"{args.service_url}/healthcheck" + else: + # Disagg EPD proxy uses /health + health_check_url = f"{args.service_url}/health" + if not os.path.exists(args.baseline_file): + raise ValueError( + f"In 
disagg mode, the output file {args.baseline_file} from " + "baseline does not exist. Run baseline mode first." + ) + + # Check if server is ready + if not check_vllm_server(health_check_url): + raise RuntimeError(f"vLLM server at {args.service_url} is not ready!") + + # Select prompts to use + if args.use_mm_prompts: + test_prompts = SAMPLE_PROMPTS_MM + print("Using multimodal prompts") + else: + test_prompts = SAMPLE_PROMPTS_TEXT + print("Using text-only prompts for quick testing") + + # Run completions + service_url = f"{args.service_url}/v1" + output_strs = {} + + for i, prompt_data in enumerate(test_prompts): + print( + f"\nRunning prompt {i + 1}/{len(test_prompts)}: { + prompt_data['description'] + }" + ) + + output_str = run_chat_completion( + base_url=service_url, + model_name=args.model_name, + messages=prompt_data["messages"], + max_tokens=MAX_OUTPUT_LEN, + ) + + # Use description as key for comparison + key = prompt_data["description"] + output_strs[key] = output_str + print(f"Output: {output_str}") + + if args.mode in ("baseline", "baseline_pd"): + # Baseline mode: Save outputs + print(f"\nSaving baseline outputs to {args.baseline_file}") + try: + with open(args.baseline_file, "w") as json_file: + json.dump(output_strs, json_file, indent=4) + print("✅ Baseline outputs saved successfully") + except OSError as e: + print(f"Error writing to file: {e}") + raise + else: + # Disagg mode: Load and compare outputs + print(f"\nLoading baseline outputs from {args.baseline_file}") + baseline_outputs = None + try: + with open(args.baseline_file) as json_file: + baseline_outputs = json.load(json_file) + except OSError as e: + print(f"Error reading from file: {e}") + raise + + # Verify outputs match + print("\nComparing disagg outputs with baseline...") + assert isinstance(baseline_outputs, dict), "Baseline outputs should be a dict" + assert len(baseline_outputs) == len(output_strs), ( + f"Length mismatch: baseline has {len(baseline_outputs)}, " + f"disagg has 
{len(output_strs)}" + ) + + all_match = True + for key, baseline_output in baseline_outputs.items(): + assert key in output_strs, f"{key} not in disagg outputs" + + disagg_output = output_strs[key] + if baseline_output == disagg_output: + print(f"✅ {key}: MATCH") + else: + print(f"❌ {key}: MISMATCH") + print(f" Baseline: {baseline_output}") + print(f" Disagg: {disagg_output}") + all_match = False + + assert all_match, "❌❌Disagg outputs do not match baseline!❌❌" + if all_match: + print("\n✅ All outputs match! Test PASSED") + + +if __name__ == "__main__": + main() diff --git a/tests/v1/ec_connector/unit/test_ec_shared_storage_connector.py b/tests/v1/ec_connector/unit/test_ec_shared_storage_connector.py new file mode 100644 index 0000000000000..a58daa2628e21 --- /dev/null +++ b/tests/v1/ec_connector/unit/test_ec_shared_storage_connector.py @@ -0,0 +1,609 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Unit tests for ECSharedStorageConnector. 
+""" + +import os +from unittest.mock import Mock, patch + +import pytest +import safetensors +import torch + +from vllm.config import VllmConfig +from vllm.distributed.ec_transfer.ec_connector.base import ECConnectorRole +from vllm.distributed.ec_transfer.ec_connector.shared_storage_connector import ( + ECSharedStorageConnector, + ECSharedStorageConnectorMetadata, + MMMeta, +) +from vllm.multimodal.inputs import MultiModalFeatureSpec, PlaceholderRange +from vllm.v1.core.sched.output import SchedulerOutput + + +# ------------------ Mock Classes ------------------ # +class MockRequest: + def __init__(self, request_id, mm_hashes: list[str], token_counts: list[int]): + assert len(mm_hashes) == len(token_counts) + self.request_id = request_id + self._token_counts = token_counts + self.mm_features = [] + for i, mm_hash in enumerate(mm_hashes): + feature = MultiModalFeatureSpec( + data=None, + modality="image", + identifier=mm_hash, + mm_position=PlaceholderRange(offset=0, length=self._token_counts[i]), + ) + self.mm_features.append(feature) + + def get_num_encoder_tokens(self, input_id: int) -> int: + assert input_id < len(self._token_counts) + return self._token_counts[input_id] + + +@pytest.fixture +def temp_storage(tmp_path): + """Fixture providing temporary storage path.""" + return str(tmp_path) + + +@pytest.fixture +def mock_vllm_config_producer(temp_storage): + """Fixture providing mock VllmConfig for producer role.""" + config = Mock(spec=VllmConfig) + config.ec_transfer_config = Mock() + config.ec_transfer_config.get_from_extra_config = Mock(return_value=temp_storage) + config.ec_transfer_config.is_ec_producer = True + return config + + +@pytest.fixture +def mock_vllm_config_consumer(temp_storage): + """Fixture providing mock VllmConfig for consumer role.""" + config = Mock(spec=VllmConfig) + config.ec_transfer_config = Mock() + config.ec_transfer_config.get_from_extra_config = Mock(return_value=temp_storage) + config.ec_transfer_config.is_ec_producer = False + 
return config + + +@pytest.fixture +def mock_request_with_3_mm(): + """Fixture providing mock Request with 3 multimodal items.""" + request_id = "test_req_123" + mm_hashes = ["img_hash_1", "img_hash_2", "img_hash_3"] + token_counts = [100, 150, 200] + + request = MockRequest(request_id, mm_hashes, token_counts) + return request + + +# ------------------ Unit Tests ------------------ # +class TestECSharedStorageConnectorBasics: + """Test basic EC connector functionality.""" + + def test_initialization_producer(self, mock_vllm_config_producer, temp_storage): + """Test connector initializes correctly as producer.""" + connector = ECSharedStorageConnector( + vllm_config=mock_vllm_config_producer, + role=ECConnectorRole.SCHEDULER, + ) + + assert connector.role == ECConnectorRole.SCHEDULER + assert connector.is_producer + assert connector._storage_path == temp_storage + assert connector._mm_datas_need_loads == {} + + def test_initialization_consumer(self, mock_vllm_config_consumer, temp_storage): + """Test connector initializes correctly as consumer.""" + connector = ECSharedStorageConnector( + vllm_config=mock_vllm_config_consumer, + role=ECConnectorRole.WORKER, + ) + + assert connector.role == ECConnectorRole.WORKER + assert not connector.is_producer + assert connector._storage_path == temp_storage + + def test_role_assignment(self, mock_vllm_config_producer): + """Test role is correctly assigned.""" + scheduler_connector = ECSharedStorageConnector( + vllm_config=mock_vllm_config_producer, + role=ECConnectorRole.SCHEDULER, + ) + worker_connector = ECSharedStorageConnector( + vllm_config=mock_vllm_config_producer, + role=ECConnectorRole.WORKER, + ) + + assert scheduler_connector.role == ECConnectorRole.SCHEDULER + assert worker_connector.role == ECConnectorRole.WORKER + + +class TestCacheExistence: + """Test cache existence checking using has_caches() API.""" + + def test_has_caches_all_exist_3_items( + self, + mock_vllm_config_producer, + mock_vllm_config_consumer, + 
mock_request_with_3_mm, + ): + """Test has_caches returns True when all 3 caches exist.""" + # Test for producer first + producer = ECSharedStorageConnector( + vllm_config=mock_vllm_config_producer, + role=ECConnectorRole.SCHEDULER, + ) + + # Create cache files using save_caches (proper way) + encoder_cache: dict[str, torch.Tensor] = {} + + for mm_feature in mock_request_with_3_mm.mm_features: + mm_hash = mm_feature.identifier + encoder_cache[mm_hash] = torch.randn(10, 768) + producer.save_caches(encoder_cache, mm_hash) + + # Test using has_caches API + producer_result = producer.has_caches(mock_request_with_3_mm) + + # Assert + assert len(producer_result) == 3 + assert all(producer_result), f"Expected all True, got {producer_result}" + + # Also test consumer can check if cache exists + consumer = ECSharedStorageConnector( + vllm_config=mock_vllm_config_consumer, + role=ECConnectorRole.SCHEDULER, + ) + + # Test using has_caches API + consumer_result = consumer.has_caches(mock_request_with_3_mm) + + # Assert + assert len(consumer_result) == 3 + assert all(consumer_result), f"Expected all True, got {consumer_result}" + + def test_has_caches_none_exist( + self, mock_vllm_config_producer, mock_request_with_3_mm + ): + """Test has_caches returns False when no caches exist.""" + connector = ECSharedStorageConnector( + vllm_config=mock_vllm_config_producer, + role=ECConnectorRole.SCHEDULER, + ) + + # Test without creating any files + result = connector.has_caches(mock_request_with_3_mm) + + # Assert + assert len(result) == 3 + assert not any(result), f"Expected all False, got {result}" + + def test_has_caches_partial_exist( + self, mock_vllm_config_producer, mock_request_with_3_mm + ): + """Test has_caches with some caches existing (1 of 3).""" + connector = ECSharedStorageConnector( + vllm_config=mock_vllm_config_producer, + role=ECConnectorRole.SCHEDULER, + ) + + # Create only the second cache file + mm_hash_second = mock_request_with_3_mm.mm_features[1].identifier + 
encoder_cache = {mm_hash_second: torch.randn(10, 768)} + connector.save_caches(encoder_cache, mm_hash_second) + + # Test + result = connector.has_caches(mock_request_with_3_mm) + + # Assert + assert len(result) == 3 + assert not result[0] # First doesn't exist + assert result[1] # Second exists + assert not result[2] # Third doesn't exist + + +class TestStateManagement: + """Test connector state management.""" + + def test_update_state_after_alloc_3_items( + self, mock_vllm_config_producer, mock_request_with_3_mm + ): + """Test state update after allocation for 3 MM items.""" + connector = ECSharedStorageConnector( + vllm_config=mock_vllm_config_producer, + role=ECConnectorRole.SCHEDULER, + ) + + # Initial state should be empty + assert len(connector._mm_datas_need_loads) == 0 + + # Update state for all 3 items + for i in range(3): + connector.update_state_after_alloc(mock_request_with_3_mm, index=i) + + # Check state updated for all 3 + assert len(connector._mm_datas_need_loads) == 3 + assert "img_hash_1" in connector._mm_datas_need_loads + assert "img_hash_2" in connector._mm_datas_need_loads + assert "img_hash_3" in connector._mm_datas_need_loads + assert connector._mm_datas_need_loads["img_hash_1"] == 100 + assert connector._mm_datas_need_loads["img_hash_2"] == 150 + assert connector._mm_datas_need_loads["img_hash_3"] == 200 + + def test_build_connector_meta_3_items( + self, mock_vllm_config_producer, mock_request_with_3_mm + ): + """Test metadata building for 3 MM items.""" + connector = ECSharedStorageConnector( + vllm_config=mock_vllm_config_producer, + role=ECConnectorRole.SCHEDULER, + ) + + # Setup state for all 3 items + for i in range(3): + connector.update_state_after_alloc(mock_request_with_3_mm, index=i) + + # Build metadata + scheduler_output = Mock(spec=SchedulerOutput) + metadata = connector.build_connector_meta(scheduler_output) + + # Assert + assert isinstance(metadata, ECSharedStorageConnectorMetadata) + assert len(metadata.mm_datas) == 3 + 
assert metadata.mm_datas[0].mm_hash == "img_hash_1" + assert metadata.mm_datas[0].num_token == 100 + assert metadata.mm_datas[1].mm_hash == "img_hash_2" + assert metadata.mm_datas[1].num_token == 150 + assert metadata.mm_datas[2].mm_hash == "img_hash_3" + assert metadata.mm_datas[2].num_token == 200 + + # State should be cleared after building + assert len(connector._mm_datas_need_loads) == 0 + + def test_build_connector_meta_empty(self, mock_vllm_config_producer): + """Test metadata building with empty state.""" + connector = ECSharedStorageConnector( + vllm_config=mock_vllm_config_producer, + role=ECConnectorRole.SCHEDULER, + ) + + scheduler_output = Mock(spec=SchedulerOutput) + metadata = connector.build_connector_meta(scheduler_output) + + assert isinstance(metadata, ECSharedStorageConnectorMetadata) + assert len(metadata.mm_datas) == 0 + + def test_state_cleared_after_metadata_build( + self, mock_vllm_config_producer, mock_request_with_3_mm + ): + """Test that state is properly cleared after building metadata.""" + connector = ECSharedStorageConnector( + vllm_config=mock_vllm_config_producer, + role=ECConnectorRole.SCHEDULER, + ) + + # Add state + for i in range(3): + connector.update_state_after_alloc(mock_request_with_3_mm, index=i) + assert len(connector._mm_datas_need_loads) == 3 + + # Build metadata (should clear state) + scheduler_output = Mock(spec=SchedulerOutput) + connector.build_connector_meta(scheduler_output) + + # State should be empty + assert len(connector._mm_datas_need_loads) == 0 + + # Build again should return empty metadata + metadata2 = connector.build_connector_meta(scheduler_output) + assert len(metadata2.mm_datas) == 0 + + +class TestCacheSaving: + """Test encoder cache saving (producer only).""" + + def test_save_caches_producer_3_items( + self, mock_vllm_config_producer, mock_request_with_3_mm, temp_storage + ): + """Test cache saving as producer for 3 different MM items.""" + connector = ECSharedStorageConnector( + 
vllm_config=mock_vllm_config_producer, + role=ECConnectorRole.WORKER, + ) + + # Create and save 3 different caches + mm_hashes = [f.identifier for f in mock_request_with_3_mm.mm_features] + encoder_cache: dict[str, torch.Tensor] = {} + + for mm_hash in mm_hashes: + encoder_cache[mm_hash] = torch.randn(10, 768) + connector.save_caches(encoder_cache, mm_hash) + + # Verify all files exist using has_caches + result = connector.has_caches(mock_request_with_3_mm) + assert all(result), f"Not all caches were saved: {result}" + + # Verify each file's content + for mm_hash in mm_hashes: + filename = connector._generate_filename_debug(mm_hash) + loaded = safetensors.torch.load_file(filename) + assert "ec_cache" in loaded + assert torch.allclose(loaded["ec_cache"], encoder_cache[mm_hash].cpu()) + + def test_save_caches_consumer_skips(self, mock_vllm_config_consumer): + """Test cache saving is skipped for consumer.""" + connector = ECSharedStorageConnector( + vllm_config=mock_vllm_config_consumer, + role=ECConnectorRole.WORKER, + ) + + mm_hash = "test_hash_consumer" + encoder_cache = {mm_hash: torch.randn(10, 768)} + + # Save should not raise but also not create file + connector.save_caches(encoder_cache, mm_hash) + + # Verify file doesn't exist using has_caches + mock_request = MockRequest("req_consumer", [mm_hash], [10]) + result = connector.has_caches(mock_request) + assert not result[0], "Consumer should not save caches" + + +class TestCacheLoading: + """Test encoder cache loading (consumer).""" + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + def test_start_load_caches_consumer_3_items( + self, + mock_vllm_config_producer, + mock_vllm_config_consumer, + mock_request_with_3_mm, + temp_storage, + ): + """Test consumer loads 3 caches from storage.""" + # First, create producer to save caches + producer = ECSharedStorageConnector( + vllm_config=mock_vllm_config_producer, + role=ECConnectorRole.WORKER, + ) + + # Producer saves 3 caches + 
mm_hashes = [f.identifier for f in mock_request_with_3_mm.mm_features] + saved_caches = {} + for mm_hash in mm_hashes: + saved_caches[mm_hash] = torch.randn(10, 768) + producer.save_caches(saved_caches, mm_hash) + + # Now consumer loads + consumer = ECSharedStorageConnector( + vllm_config=mock_vllm_config_consumer, + role=ECConnectorRole.WORKER, + ) + + # Setup metadata for all 3 + metadata = ECSharedStorageConnectorMetadata() + for mm_hash in mm_hashes: + metadata.add_mm_data(MMMeta.make_meta(mm_hash, 100)) + consumer.bind_connector_metadata(metadata) + + # Load + encoder_cache: dict[str, torch.Tensor] = {} + consumer.start_load_caches(encoder_cache=encoder_cache) + + # Verify all 3 loaded + assert len(encoder_cache) == 3 + for mm_hash in mm_hashes: + assert mm_hash in encoder_cache, f"{mm_hash} missing in encoder_cache" + assert encoder_cache[mm_hash].is_cuda, ( + f"{mm_hash} cache is in {encoder_cache[mm_hash].device}" + ) + assert torch.allclose( + encoder_cache[mm_hash].cpu(), saved_caches[mm_hash] + ), f"{mm_hash} cache saved and loaded tesnor are not the same" + + def test_start_load_caches_skip_existing( + self, mock_vllm_config_producer, mock_vllm_config_consumer, temp_storage + ): + """Test cache loading skips already cached items.""" + # Setup: producer saves cache + producer = ECSharedStorageConnector( + vllm_config=mock_vllm_config_producer, + role=ECConnectorRole.WORKER, + ) + + mm_hash = "existing_hash" + saved_cache = torch.randn(10, 768) + producer.save_caches({mm_hash: saved_cache}, mm_hash) + + # Consumer setup + consumer = ECSharedStorageConnector( + vllm_config=mock_vllm_config_consumer, + role=ECConnectorRole.WORKER, + ) + + metadata = ECSharedStorageConnectorMetadata() + metadata.add_mm_data(MMMeta.make_meta(mm_hash, 100)) + consumer.bind_connector_metadata(metadata) + + # Pre-populate encoder_cache with different value + existing_cache = torch.randn(5, 512) + encoder_cache = {mm_hash: existing_cache} + + # Load (should skip since already 
exists) + with patch("safetensors.torch.load_file") as mock_load: + consumer.start_load_caches(encoder_cache=encoder_cache) + # Should not call load_file since cache exists + mock_load.assert_not_called() + + # Verify original cache unchanged + assert torch.equal(encoder_cache[mm_hash], existing_cache) + + def test_start_load_caches_empty_metadata(self, mock_vllm_config_consumer): + """Test loading with empty metadata does nothing.""" + consumer = ECSharedStorageConnector( + vllm_config=mock_vllm_config_consumer, + role=ECConnectorRole.WORKER, + ) + + # Setup empty metadata + metadata = ECSharedStorageConnectorMetadata() + consumer.bind_connector_metadata(metadata) + + # Load (should not raise) + encoder_cache: dict[str, torch.Tensor] = {} + consumer.start_load_caches(encoder_cache=encoder_cache) + + # Cache should remain empty + assert len(encoder_cache) == 0 + + +class TestFilenameGeneration: + """Test filename and path generation.""" + + def test_generate_foldername(self, mock_vllm_config_producer, temp_storage): + """Test folder name generation.""" + connector = ECSharedStorageConnector( + vllm_config=mock_vllm_config_producer, + role=ECConnectorRole.WORKER, + ) + + mm_hash = "test_folder_hash" + folder = connector._generate_foldername_debug(mm_hash) + + assert folder == os.path.join(temp_storage, mm_hash) + assert os.path.isdir(folder) # Should be created + + def test_generate_filename(self, mock_vllm_config_producer, temp_storage): + """Test filename generation.""" + connector = ECSharedStorageConnector( + vllm_config=mock_vllm_config_producer, + role=ECConnectorRole.WORKER, + ) + + mm_hash = "test_file_hash" + filename = connector._generate_filename_debug(mm_hash) + + expected = os.path.join(temp_storage, mm_hash, "encoder_cache.safetensors") + assert filename == expected + assert os.path.isdir(os.path.dirname(filename)) # Folder created + + def test_generate_filename_consistency(self, mock_vllm_config_producer): + """Test filename generation is 
consistent.""" + connector = ECSharedStorageConnector( + vllm_config=mock_vllm_config_producer, + role=ECConnectorRole.WORKER, + ) + + mm_hash = "consistency_hash" + filename1 = connector._generate_filename_debug(mm_hash) + filename2 = connector._generate_filename_debug(mm_hash) + + assert filename1 == filename2 + + +class TestMetadataBindingLifecycle: + """Test metadata binding and clearing lifecycle.""" + + def test_bind_connector_metadata(self, mock_vllm_config_consumer): + """Test binding connector metadata.""" + connector = ECSharedStorageConnector( + vllm_config=mock_vllm_config_consumer, + role=ECConnectorRole.WORKER, + ) + + metadata = ECSharedStorageConnectorMetadata() + metadata.add_mm_data(MMMeta.make_meta("hash_1", 100)) + + connector.bind_connector_metadata(metadata) + + assert connector._connector_metadata is metadata + + def test_clear_connector_metadata(self, mock_vllm_config_consumer): + """Test clearing connector metadata.""" + connector = ECSharedStorageConnector( + vllm_config=mock_vllm_config_consumer, + role=ECConnectorRole.WORKER, + ) + + metadata = ECSharedStorageConnectorMetadata() + connector.bind_connector_metadata(metadata) + + connector.clear_connector_metadata() + + assert connector._connector_metadata is None + + def test_get_connector_metadata(self, mock_vllm_config_consumer): + """Test getting connector metadata.""" + connector = ECSharedStorageConnector( + vllm_config=mock_vllm_config_consumer, + role=ECConnectorRole.WORKER, + ) + + metadata = ECSharedStorageConnectorMetadata() + connector.bind_connector_metadata(metadata) + + retrieved = connector._get_connector_metadata() + + assert retrieved is metadata + + def test_get_connector_metadata_not_set(self, mock_vllm_config_consumer): + """Test getting metadata when not set raises.""" + connector = ECSharedStorageConnector( + vllm_config=mock_vllm_config_consumer, + role=ECConnectorRole.WORKER, + ) + + with pytest.raises(AssertionError): + connector._get_connector_metadata() + + 
+class TestEdgeCases: + """Test edge cases and error handling.""" + + def test_save_empty_cache(self, mock_vllm_config_producer): + """Test saving empty tensor.""" + connector = ECSharedStorageConnector( + vllm_config=mock_vllm_config_producer, + role=ECConnectorRole.WORKER, + ) + + mm_hash = "empty_hash" + encoder_cache = {mm_hash: torch.empty(0)} + + # Should not raise + connector.save_caches(encoder_cache, mm_hash) + + def test_load_nonexistent_cache(self, mock_vllm_config_consumer): + """Test loading cache that doesn't exist raises error.""" + connector = ECSharedStorageConnector( + vllm_config=mock_vllm_config_consumer, + role=ECConnectorRole.WORKER, + ) + + metadata = ECSharedStorageConnectorMetadata() + metadata.add_mm_data(MMMeta.make_meta("nonexistent_hash", 100)) + connector.bind_connector_metadata(metadata) + + encoder_cache: dict[str, torch.Tensor] = {} + + # Should raise FileNotFoundError + with pytest.raises(FileNotFoundError): + connector.start_load_caches(encoder_cache=encoder_cache) + + def test_has_caches_empty_request(self, mock_vllm_config_producer): + """Test has_caches with request that has no MM data.""" + connector = ECSharedStorageConnector( + vllm_config=mock_vllm_config_producer, + role=ECConnectorRole.SCHEDULER, + ) + + mock_request = MockRequest("req_empty", [], []) + + result = connector.has_caches(mock_request) + + assert len(result) == 0 + assert result == [] diff --git a/tests/v1/engine/test_engine_core.py b/tests/v1/engine/test_engine_core.py index 84441aa7d28ca..4e852dca95eb0 100644 --- a/tests/v1/engine/test_engine_core.py +++ b/tests/v1/engine/test_engine_core.py @@ -10,6 +10,14 @@ import pytest from transformers import AutoTokenizer from vllm import SamplingParams +from vllm.config import ( + CacheConfig, + ECTransferConfig, + KVTransferConfig, + ModelConfig, + SchedulerConfig, + VllmConfig, +) from vllm.engine.arg_utils import EngineArgs from vllm.platforms import current_platform from vllm.utils.torch_utils import 
set_default_torch_num_threads @@ -450,3 +458,141 @@ def test_engine_core_invalid_request_id_type(): engine_core.add_request(*engine_core.preprocess_add_request(valid_request)) assert len(engine_core.scheduler.waiting) == 1 assert len(engine_core.scheduler.running) == 0 + + +@create_new_process_for_each_test() +@pytest.mark.parametrize( + ("ec_role", "gpu_memory_utilization", "enable_prefix_caching"), + [ + ("ec_producer", 0.01, False), + # NOTE: ec_producer never allows prefix caching + ("ec_consumer", 0.7, True), + ("ec_consumer", 0.7, False), + ], +) +@pytest.mark.parametrize("use_kv_connector", [False, True]) +def test_encoder_instance_zero_kv_cache( + ec_role: str, + gpu_memory_utilization: float, + enable_prefix_caching: bool, + use_kv_connector: bool, +): + """EPD (Encoder-Prefill-Decode) Encoder-cache-specific tests + + This test verifies encoder-only instance initializes with 0 KV cache blocks. + Under EPD disagg mode, Encoder instances (EC producer role) only execute + vision encoder, so they don't need KV cache for text generation. 
+ """ + # Form vllm config + scheduler_config = SchedulerConfig( + max_num_seqs=10, + max_num_batched_tokens=512, + max_model_len=512, + disable_hybrid_kv_cache_manager=True, + ) + model_config = ModelConfig( + model="llava-hf/llava-1.5-7b-hf", # Multimodal model + enforce_eager=True, + trust_remote_code=True, + dtype="float16", + seed=42, + ) + cache_config = CacheConfig( + block_size=16, + gpu_memory_utilization=gpu_memory_utilization, + swap_space=0, + cache_dtype="auto", + enable_prefix_caching=enable_prefix_caching, + ) + kv_transfer_config = ( + KVTransferConfig( + kv_connector="SharedStorageConnector", + kv_role="kv_both", + kv_connector_extra_config={"shared_storage_path": "local_storage"}, + ) + if use_kv_connector + else None + ) + ec_transfer_config = ECTransferConfig( + ec_connector="ECSharedStorageConnector", + ec_role=ec_role, + ec_connector_extra_config={"shared_storage_path": "/tmp/ec_test_encoder"}, + ) + + vllm_config = VllmConfig( + model_config=model_config, + cache_config=cache_config, + scheduler_config=scheduler_config, + kv_transfer_config=kv_transfer_config, + ec_transfer_config=ec_transfer_config, + ) + + executor_class = Executor.get_class(vllm_config) + print(f"executor_class: {executor_class}") + + with set_default_torch_num_threads(1): + engine_core = EngineCore( + vllm_config=vllm_config, executor_class=executor_class, log_stats=True + ) + + # Check encoder cache manager exists + assert engine_core.scheduler.encoder_cache_manager is not None, ( + "encoder_cache_manager should exist" + ) + + if ec_role == "ec_producer": + # Check 1: num_blocks should be 0 + # NOTE: num_blocks=1 as BlockPool always needs a null_block. 
+ kv_cache_config = engine_core.scheduler.kv_cache_manager.kv_cache_config + print(f"kv_cache_config: {kv_cache_config}") + assert kv_cache_config.num_blocks == 1, ( + f"ec_producer should only have 1 KV blocks, " + f"got {kv_cache_config.num_blocks}" + ) + + # Check 2: kv_cache_groups should be empty + assert len(kv_cache_config.kv_cache_groups) == 0, ( + f"ec_producer should have 0 KV cache groups, " + f"got {len(kv_cache_config.kv_cache_groups)}" + ) + + # Check 3: kv_cache_tensors should be empty + assert len(kv_cache_config.kv_cache_tensors) == 0, ( + f"Encoder instance should have 0 KV cache tensors, " + f"got {len(kv_cache_config.kv_cache_tensors)}" + ) + + # Check 4: Verify EC connector is initialized and is producer + assert engine_core.scheduler.ec_connector is not None, ( + "Encoder instance should have EC connector" + ) + assert engine_core.scheduler.ec_connector.is_producer, ( + "Encoder instance EC connector should be producer" + ) + + # Check 5: Verify chunked prefill is disabled + assert not vllm_config.scheduler_config.chunked_prefill_enabled, ( + "Encoder instance should disable chunked prefill (no KV cache)" + ) + + elif ec_role == "ec_consumer": + # Check 1: num_blocks should be > 1 + kv_cache_config = engine_core.scheduler.kv_cache_manager.kv_cache_config + print(f"kv_cache_config: {kv_cache_config}") + assert kv_cache_config.num_blocks > 1, ( + f"ec_consumer should have >1 KV blocks, got {kv_cache_config.num_blocks}" + ) + + # Check 2: kv_cache_groups should NOT be empty + assert len(kv_cache_config.kv_cache_groups) > 0, ( + f"ec_consumer should have KV cache groups, " + f"got {len(kv_cache_config.kv_cache_groups)}" + ) + + # Check 3: Verify EC connector is consumer + assert engine_core.scheduler.ec_connector is not None, ( + "Consumer instance should have EC connector" + ) + assert not engine_core.scheduler.ec_connector.is_producer, ( + "Consumer instance EC connector should be consumer" + ) diff --git a/vllm/config/__init__.py 
b/vllm/config/__init__.py index 7f1cc52024205..dd76a722106ef 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -9,6 +9,7 @@ from vllm.config.compilation import ( PassConfig, ) from vllm.config.device import DeviceConfig +from vllm.config.ec_transfer import ECTransferConfig from vllm.config.kv_events import KVEventsConfig from vllm.config.kv_transfer import KVTransferConfig from vllm.config.load import LoadConfig @@ -54,6 +55,8 @@ __all__ = [ "PassConfig", # From vllm.config.device "DeviceConfig", + # From vllm.config.ec_transfer + "ECTransferConfig", # From vllm.config.kv_events "KVEventsConfig", # From vllm.config.kv_transfer diff --git a/vllm/config/ec_transfer.py b/vllm/config/ec_transfer.py new file mode 100644 index 0000000000000..d95236f818abb --- /dev/null +++ b/vllm/config/ec_transfer.py @@ -0,0 +1,110 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import hashlib +import uuid +from dataclasses import field +from typing import Any, Literal, get_args + +from pydantic.dataclasses import dataclass + +from vllm.config.utils import config + +ECProducer = Literal["ec_producer"] +ECConsumer = Literal["ec_consumer"] +ECRole = Literal[ECProducer, ECConsumer] + + +@config +@dataclass +class ECTransferConfig: + """Configuration for distributed EC cache transfer.""" + + ec_connector: str | None = None + """The EC connector for vLLM to transmit EC caches between vLLM instances. + """ + + engine_id: str | None = None + """The engine id for EC transfers.""" + + ec_buffer_device: str | None = "cuda" + """The device used by ec connector to buffer the EC cache. + Currently only support 'cuda'.""" + + ec_buffer_size: float = 1e9 + """The buffer size for TorchDistributedConnector. Measured in number of + bytes. Recommended value: 1e9 (about 1GB).""" + + ec_role: ECRole | None = None + """Whether this vLLM instance produces, consumes EC cache, or both. 
Choices + are 'ec_producer', 'ec_consumer'.""" + + ec_rank: int | None = None + """The rank of this vLLM instance in the EC cache transfer. Typical value: + 0 for encoder, 1 for pd instance. + Currently only 1P1D is supported.""" + + ec_parallel_size: int = 1 + """The number of parallel instances for EC cache transfer. For + PyNcclConnector, this should be 2.""" + + ec_ip: str = "127.0.0.1" + """The EC connector ip, used to build distributed connection.""" + + ec_port: int = 14579 + """The EC connector port, used to build distributed connection.""" + + ec_connector_extra_config: dict[str, Any] = field(default_factory=dict) + """any extra config that the connector may need.""" + + ec_connector_module_path: str | None = None + """The Python module path to dynamically load the EC connector from. + Only supported in V1.""" + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # this config will not affect the computation graph. + factors: list[Any] = [] + hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() + return hash_str + + def __post_init__(self) -> None: + if self.engine_id is None: + self.engine_id = str(uuid.uuid4()) + + if self.ec_role is not None and self.ec_role not in get_args(ECRole): + raise ValueError( + f"Unsupported ec_role: {self.ec_role}. 
" + f"Supported roles are {get_args(ECRole)}" + ) + + if self.ec_connector is not None and self.ec_role is None: + raise ValueError( + "Please specify ec_role when ec_connector " + f"is set, supported roles are {get_args(ECRole)}" + ) + + @property + def is_ec_transfer_instance(self) -> bool: + return self.ec_connector is not None and self.ec_role in get_args(ECRole) + + @property + def is_ec_producer(self) -> bool: + return self.ec_connector is not None and self.ec_role in get_args(ECProducer) + + @property + def is_ec_consumer(self) -> bool: + return self.ec_connector is not None and self.ec_role in get_args(ECConsumer) + + def get_from_extra_config(self, key, default) -> Any: + return self.ec_connector_extra_config.get(key, default) diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index df9a1fd08af6f..60458b26944a5 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -28,6 +28,7 @@ from vllm.utils import random_uuid from .cache import CacheConfig from .compilation import CompilationConfig, CompilationMode, CUDAGraphMode from .device import DeviceConfig +from .ec_transfer import ECTransferConfig from .kv_events import KVEventsConfig from .kv_transfer import KVTransferConfig from .load import LoadConfig @@ -103,6 +104,8 @@ class VllmConfig: """The configurations for distributed KV cache transfer.""" kv_events_config: KVEventsConfig | None = None """The configurations for event publishing.""" + ec_transfer_config: ECTransferConfig | None = None + """The configurations for distributed EC cache transfer.""" # some opaque config, only used to provide additional information # for the hash computation, mainly used for testing, debugging or out of # tree config registration. 
@@ -183,6 +186,10 @@ class VllmConfig: vllm_factors.append(self.kv_transfer_config.compute_hash()) else: vllm_factors.append("None") + if self.ec_transfer_config: + vllm_factors.append(self.ec_transfer_config.compute_hash()) + else: + vllm_factors.append("None") if self.additional_config: if isinstance(additional_config := self.additional_config, dict): additional_config_hash = hashlib.md5( diff --git a/vllm/distributed/ec_transfer/__init__.py b/vllm/distributed/ec_transfer/__init__.py new file mode 100644 index 0000000000000..0decfd143e343 --- /dev/null +++ b/vllm/distributed/ec_transfer/__init__.py @@ -0,0 +1,14 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm.distributed.ec_transfer.ec_transfer_state import ( + ensure_ec_transfer_initialized, + get_ec_transfer, + has_ec_transfer, +) + +__all__ = [ + "get_ec_transfer", + "ensure_ec_transfer_initialized", + "has_ec_transfer", +] diff --git a/vllm/distributed/ec_transfer/ec_connector/__init__.py b/vllm/distributed/ec_transfer/ec_connector/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/distributed/ec_transfer/ec_connector/base.py b/vllm/distributed/ec_transfer/ec_connector/base.py new file mode 100644 index 0000000000000..2b7b14d89b8a1 --- /dev/null +++ b/vllm/distributed/ec_transfer/ec_connector/base.py @@ -0,0 +1,247 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +ECConnectorBase Class for Distributed Encoder Cache & +P2P Encoder cache communication in V1 + +The class provides the following primitives: + Scheduler-side: runs in the scheduler, binds metadata, which + is used by the worker-side to load/save Encoder cache. + check_caches_exist() - Check whether Encoder cache of requests exist + update_state_after_alloc() - update ECConnector state after + allocate. 
This will decide to load the cache or not + request_finished() - called when a request is finished, + free the cache with the requests + + Worker-side: runs in each worker, loads/saves Encoder Cache to/from + the Connector based on the metadata. + start_load_ec() - starts loading all ECs (maybe async) + wait_for_save() - blocks until all saves are done + + get_finished() - called with ids of finished requests, returns + ids of requests that have completed async sending/recving. +""" + +import enum +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any + +import torch + +from vllm.logger import init_logger +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.outputs import ECConnectorOutput + +if TYPE_CHECKING: + from vllm.config import VllmConfig + from vllm.v1.request import Request + +logger = init_logger(__name__) + + +class ECConnectorRole(enum.Enum): + # Connector running in the scheduler process + SCHEDULER = 0 + + # Connector running in the worker process + WORKER = 1 + + +class ECConnectorMetadata(ABC): # noqa: B024 + """ + Abstract Metadata used to communicate between the + Scheduler ECConnector and Worker ECConnector. + """ + + pass + + +class ECConnectorBase(ABC): + def __init__(self, vllm_config: "VllmConfig", role: ECConnectorRole): + self._connector_metadata: ECConnectorMetadata | None = None + self._vllm_config = vllm_config + self._role = role + if vllm_config.ec_transfer_config is not None: + self._is_producer = vllm_config.ec_transfer_config.is_ec_producer + else: + raise ValueError("ec_transfer_config must be set for ECConnectorBase") + + @property + def role(self) -> ECConnectorRole: + return self._role + + @property + def is_producer(self) -> bool: + return self._is_producer + + # ============================== + # Worker-side methods + # ============================== + + def bind_connector_metadata(self, connector_metadata: ECConnectorMetadata) -> None: + """Set the connector metadata from the scheduler. 
+ + This function should be called by the model runner every time + before the model execution. The metadata will be used for runtime + EC cache loading. + + Args: + connector_metadata (dict): the connector metadata. + """ + self._connector_metadata = connector_metadata + + def clear_connector_metadata(self) -> None: + """Clear the connector metadata. + + This function should be called by the model runner every time + after the model execution. + """ + self._connector_metadata = None + + def _get_connector_metadata(self) -> ECConnectorMetadata: + """Get the connector metadata. + + This function should only be called inside the connector. + + Returns: + ConnectorMetadata: the connector metadata. + """ + + # Should only be called while set to valid metadata. + assert self._connector_metadata is not None + return self._connector_metadata + + def register_caches( + self, + ec_caches: dict[str, torch.Tensor], + ): + """ + Initialize with the EC caches. + Args: + ec_caches: dictionary of encoder cache + """ + # TODO: Implement this later for P2P feature + return + + @abstractmethod + def start_load_caches( + self, encoder_cache: dict[str, torch.Tensor], **kwargs + ) -> None: + """ + Start loading the cache from the connector into vLLM's encoder cache. + + This method loads the encoder cache based on metadata provided by the scheduler. + It is called before `_gather_mm_embeddings` for the EC Connector. For EC, + the `encoder_cache` and `mm_hash` are stored in `kwargs`. + + Args: + encoder_cache (dict[str, torch.Tensor]): A dictionary mapping multimodal + data hashes (`mm_hash`) to encoder cache tensors. + kwargs (dict): Additional keyword arguments for the connector. + """ + pass + + @abstractmethod + def save_caches( + self, encoder_cache: dict[str, torch.Tensor], mm_hash: str, **kwargs + ) -> None: + """ + Save the encoder cache to the connector. + + This method saves the encoder cache from the worker's local storage + to shared storage or another external connector. 
+ + Args: + encoder_cache (dict[str, torch.Tensor]): A dictionary mapping multimodal + data hashes (`mm_hash`) to encoder cache tensors. + mm_hash (str): The hash of the multimodal data whose cache is being saved. + kwargs (dict): Additional keyword arguments for the connector. + """ + pass + + def get_finished( + self, finished_req_ids: set[str] + ) -> tuple[set[str] | None, set[str] | None]: + """ + Notifies worker-side connector ids of requests that have + finished generating tokens on the worker. + The scheduler process (via the Executors) will use this output + to track which workers are done. + + Returns: + ids of requests that have finished asynchronous transfer + (requests that previously returned True from request_finished()), + tuple of (sending/saving ids, recving/loading ids). + The finished saves/sends req ids must belong to a set provided in a + call to this method (this call or a prior one). + """ + return None, None + + # ============================== + # Scheduler-side methods + # ============================== + + @abstractmethod + def has_caches( + self, + request: "Request", + ) -> list[bool]: + """ + Check if encoder cache exists for each mm data of requests + + Args: + request (Request): the request object. + + Returns: + A list bool where ith value is True if cache exist for + ith mm_data of requests + """ + pass + + @abstractmethod + def update_state_after_alloc(self, request: "Request", index: int): + """ + Update ECConnector state to decide allocate cache for requests + + Args: + request (Request): the request object. + """ + pass + + @abstractmethod + def build_connector_meta( + self, scheduler_output: SchedulerOutput + ) -> ECConnectorMetadata: + """ + Build the connector metadata for this step. + + This function should NOT modify fields in the scheduler_output. + Also, calling this function will reset the state of the connector. + + Args: + scheduler_output (SchedulerOutput): the scheduler output object. 
+ """ + pass + + def update_connector_output(self, connector_output: ECConnectorOutput): + """ + Update ECConnector state from worker-side connectors output. + + Args: + connector_output (ECConnectorOutput): the worker-side + connectors output. + """ + return + + def request_finished( + self, request: "Request" + ) -> tuple[bool, dict[str, Any] | None]: + """ + Called when a request has finished, before its encoder cache is freed. + + Returns: + True if the request is being saved/sent asynchronously and cached + should not be freed until the request_id is returned from + get_finished(). + """ + return False, None diff --git a/vllm/distributed/ec_transfer/ec_connector/factory.py b/vllm/distributed/ec_transfer/ec_connector/factory.py new file mode 100644 index 0000000000000..bfdf51d775bda --- /dev/null +++ b/vllm/distributed/ec_transfer/ec_connector/factory.py @@ -0,0 +1,88 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import importlib +from collections.abc import Callable +from typing import TYPE_CHECKING + +# yapf: disable +from vllm.distributed.ec_transfer.ec_connector.base import ( + ECConnectorBase, + ECConnectorRole, +) +from vllm.logger import init_logger + +# yapf: enable + +if TYPE_CHECKING: + from vllm.config import ECTransferConfig, VllmConfig + +logger = init_logger(__name__) + + +class ECConnectorFactory: + _registry: dict[str, Callable[[], type[ECConnectorBase]]] = {} + + @classmethod + def register_connector(cls, name: str, module_path: str, class_name: str) -> None: + """Register a connector with a lazy-loading module and class name.""" + if name in cls._registry: + raise ValueError(f"Connector '{name}' is already registered.") + + def loader() -> type[ECConnectorBase]: + module = importlib.import_module(module_path) + return getattr(module, class_name) + + cls._registry[name] = loader + + @classmethod + def create_connector( + cls, + config: "VllmConfig", + role: ECConnectorRole, + 
) -> ECConnectorBase: + ec_transfer_config = config.ec_transfer_config + if ec_transfer_config is None: + raise ValueError("ec_transfer_config must be set to create a connector") + connector_cls = cls.get_connector_class(ec_transfer_config) + logger.info( + "Creating connector with name: %s and engine_id: %s", + connector_cls.__name__, + ec_transfer_config.engine_id, + ) + # Connector is explicitly separated into two roles. + # Scheduler connector: + # - Co-locate with scheduler process + # - Should only be used inside the Scheduler class + # Worker connector: + # - Co-locate with worker process + return connector_cls(config, role) + + @classmethod + def get_connector_class( + cls, ec_transfer_config: "ECTransferConfig" + ) -> type[ECConnectorBase]: + """Get the connector class by name.""" + connector_name = ec_transfer_config.ec_connector + if connector_name is None: + raise ValueError("EC connect must not be None") + elif connector_name in cls._registry: + connector_cls = cls._registry[connector_name]() + else: + connector_module_path = ec_transfer_config.ec_connector_module_path + if connector_module_path is None: + raise ValueError(f"Unsupported connector type: {connector_name}") + connector_module = importlib.import_module(connector_module_path) + connector_cls = getattr(connector_module, connector_name) + return connector_cls + + +# Register various connectors here. +# The registration should not be done in each individual file, as we want to +# only load the files corresponding to the current connector. 
+ +ECConnectorFactory.register_connector( + "ECSharedStorageConnector", + "vllm.distributed.ec_transfer.ec_connector.shared_storage_connector", + "ECSharedStorageConnector", +) diff --git a/vllm/distributed/ec_transfer/ec_connector/shared_storage_connector.py b/vllm/distributed/ec_transfer/ec_connector/shared_storage_connector.py new file mode 100644 index 0000000000000..c8388141dcc97 --- /dev/null +++ b/vllm/distributed/ec_transfer/ec_connector/shared_storage_connector.py @@ -0,0 +1,201 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import os +from dataclasses import dataclass +from typing import TYPE_CHECKING + +import safetensors + +from vllm.config import VllmConfig +from vllm.distributed.ec_transfer.ec_connector.base import ( + ECConnectorBase, + ECConnectorMetadata, + ECConnectorRole, +) +from vllm.logger import init_logger +from vllm.v1.core.sched.output import SchedulerOutput + +if TYPE_CHECKING: + from vllm.v1.request import Request + +logger = init_logger(__name__) + + +@dataclass +class MMMeta: + mm_hash: str + num_token: int + + @staticmethod + def make_meta(mm_hash, num_token) -> "MMMeta": + return MMMeta(mm_hash=mm_hash, num_token=num_token) + + +@dataclass +class ECSharedStorageConnectorMetadata(ECConnectorMetadata): + mm_datas: list[MMMeta] + + def __init__(self): + self.mm_datas = [] + + def add_mm_data(self, mm_data: MMMeta): + self.mm_datas.append(mm_data) + + +class ECSharedStorageConnector(ECConnectorBase): + # NOTE: This is Simple debug implementation of the EC connector. + # It save / load the EC cache to / from the disk. 
+ + def __init__(self, vllm_config: "VllmConfig", role: ECConnectorRole): + super().__init__(vllm_config=vllm_config, role=role) + # req_id -> index + self._mm_datas_need_loads: dict[str, int] = {} + transfer_config = vllm_config.ec_transfer_config + if transfer_config is not None: + self._storage_path = transfer_config.get_from_extra_config( + "shared_storage_path", "/tmp" + ) + logger.debug(transfer_config) + logger.debug("Shared storage path is %s", self._storage_path) + else: + raise ValueError("ec_transfer_config must be set for ECConnectorBase") + + def start_load_caches(self, encoder_cache, **kwargs) -> None: + """ + Start loading the cache from the connector into vLLM's encoder cache. + + This method loads the encoder cache based on metadata provided by the scheduler. + It is called before `_gather_mm_embeddings` for the EC Connector. For EC, + the `encoder_cache` and `mm_hash` are stored in `kwargs`. + + Args: + encoder_cache (dict[str, torch.Tensor]): A dictionary mapping multimodal + data hashes (`mm_hash`) to encoder cache tensors. + kwargs (dict): Additional keyword arguments for the connector. + """ + + # Get the metadata + metadata: ECConnectorMetadata = self._get_connector_metadata() + assert isinstance(metadata, ECSharedStorageConnectorMetadata) + assert encoder_cache is not None + if metadata is None: + logger.warning( + ( + "In connector.start_load_caches, ", + "but the connector metadata is None", + ) + ) + return + # Load the EC for each mm data + for mm_data in metadata.mm_datas: + if mm_data.mm_hash in encoder_cache: + continue + filename = self._generate_filename_debug(mm_data.mm_hash) + ec_cache = safetensors.torch.load_file(filename)["ec_cache"].cuda() + encoder_cache[mm_data.mm_hash] = ec_cache + logger.debug("Success load encoder cache for hash %s", mm_data.mm_hash) + + def save_caches(self, encoder_cache, mm_hash, **kwargs) -> None: + """ + Save the encoder cache to the connector. 
+ + This method saves the encoder cache from the worker's local storage + to shared storage or another external connector. + + Args: + encoder_cache (dict[str, torch.Tensor]): A dictionary mapping multimodal + data hashes (`mm_hash`) to encoder cache tensors. + mm_hash (str): The hash of the multimodal data whose cache is being saved. + kwargs (dict): Additional keyword arguments for the connector. + """ + # Return if it is PD Instance + if not self.is_producer: + return + filename = self._generate_filename_debug(mm_hash) + ec_cache = encoder_cache[mm_hash] + tensors = {"ec_cache": ec_cache.detach().cpu()} + safetensors.torch.save_file(tensors, filename) + logger.debug("Save cache successful for mm_hash %s", mm_hash) + + def has_caches( + self, + request: "Request", + ) -> list[bool]: + """ + Check if cache exist externally for each mm_data of request + + Args: + request (Request): the request object. + + Returns: + List of bool indicate that ith mm_data exist in cache or not + """ + result = [] + for feature in request.mm_features: + result.append(self._found_match_for_mm_data(feature.identifier)) + return result + + def update_state_after_alloc( + self, + request: "Request", + index: int, + ) -> None: + """ + Update ECConnector state after encoder cache allocation. + """ + mm_hash = request.mm_features[index].identifier + num_encoder_token = request.get_num_encoder_tokens(index) + # Insert mm_hash only if this block has not been recorded yet. + self._mm_datas_need_loads[mm_hash] = num_encoder_token + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput, + ) -> ECConnectorMetadata: + """Build the connector metadata for this step. + + This function should NOT modify any fields in the scheduler_output. + Also, calling this function will reset the state of the connector. + This only build for load mm_data only + Args: + scheduler_output (SchedulerOutput): the scheduler output object. 
+ """ + meta = ECSharedStorageConnectorMetadata() + for mm_hash, num_encoder_token in self._mm_datas_need_loads.items(): + meta.add_mm_data(MMMeta.make_meta(mm_hash, num_encoder_token)) + self._mm_datas_need_loads.clear() + return meta + + # ============================== + # Helper functions + # ============================== + + def _found_match_for_mm_data(self, mm_hash) -> bool: + """Check if the cache is hit for the request.""" + filename = self._generate_filename_debug(mm_hash) + return os.path.exists(filename) + + def _generate_foldername_debug( + self, + mm_hash: str, + create_folder: bool = True, # <- now defaults to True + ) -> str: + """ + Return the folder in which the cache for this mm_hash lives. + If `create_folder` is True (default) the directory is created + recursively the first time it is needed. + """ + foldername = os.path.join(self._storage_path, mm_hash) + if create_folder: + os.makedirs(foldername, exist_ok=True) + return foldername + + def _generate_filename_debug(self, mm_hash: str) -> str: + """ + Return the full path of the safetensors file for this mm_hash. + Ensures the parent directory exists because + `_generate_foldername_debug` is called with its default + (`create_folder=True`). 
+ """ + foldername = self._generate_foldername_debug(mm_hash) # <- folder auto-created + return os.path.join(foldername, "encoder_cache.safetensors") diff --git a/vllm/distributed/ec_transfer/ec_transfer_state.py b/vllm/distributed/ec_transfer/ec_transfer_state.py new file mode 100644 index 0000000000000..95f516129e0c3 --- /dev/null +++ b/vllm/distributed/ec_transfer/ec_transfer_state.py @@ -0,0 +1,46 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import TYPE_CHECKING + +from vllm import envs +from vllm.distributed.ec_transfer.ec_connector.base import ( + ECConnectorBase, + ECConnectorRole, +) +from vllm.distributed.ec_transfer.ec_connector.factory import ECConnectorFactory + +if TYPE_CHECKING: + from vllm.config import VllmConfig + +_EC_CONNECTOR_AGENT: ECConnectorBase | None = None + + +def get_ec_transfer() -> ECConnectorBase: + assert _EC_CONNECTOR_AGENT is not None, "disaggregated EC cache is not initialized" + return _EC_CONNECTOR_AGENT + + +def has_ec_transfer() -> bool: + return _EC_CONNECTOR_AGENT is not None + + +def ensure_ec_transfer_initialized(vllm_config: "VllmConfig") -> None: + """ + Initialize EC cache connector. 
+ """ + + global _EC_CONNECTOR_AGENT + + if vllm_config.ec_transfer_config is None: + return + + if ( + vllm_config.ec_transfer_config.is_ec_transfer_instance + and _EC_CONNECTOR_AGENT is None + ): + if envs.VLLM_USE_V1: + _EC_CONNECTOR_AGENT = ECConnectorFactory.create_connector( + config=vllm_config, role=ECConnectorRole.WORKER + ) + else: + raise ValueError("V0 is no longer supported") diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 0a82745bf55ab..13c7704f5bf3d 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -38,6 +38,7 @@ from vllm.config import ( CompilationConfig, ConfigType, DeviceConfig, + ECTransferConfig, EPLBConfig, KVEventsConfig, KVTransferConfig, @@ -527,6 +528,8 @@ class EngineArgs: kv_transfer_config: KVTransferConfig | None = None kv_events_config: KVEventsConfig | None = None + ec_transfer_config: ECTransferConfig | None = None + generation_config: str = ModelConfig.generation_config enable_sleep_mode: bool = ModelConfig.enable_sleep_mode override_generation_config: dict[str, Any] = get_field( @@ -1105,6 +1108,9 @@ class EngineArgs: "--kv-transfer-config", **vllm_kwargs["kv_transfer_config"] ) vllm_group.add_argument("--kv-events-config", **vllm_kwargs["kv_events_config"]) + vllm_group.add_argument( + "--ec-transfer-config", **vllm_kwargs["ec_transfer_config"] + ) vllm_group.add_argument( "--compilation-config", "-O", **vllm_kwargs["compilation_config"] ) @@ -1676,6 +1682,7 @@ class EngineArgs: compilation_config=self.compilation_config, kv_transfer_config=self.kv_transfer_config, kv_events_config=self.kv_events_config, + ec_transfer_config=self.ec_transfer_config, additional_config=self.additional_config, ) diff --git a/vllm/model_executor/warmup/kernel_warmup.py b/vllm/model_executor/warmup/kernel_warmup.py index 28792338f036f..95f5982bc8c7b 100644 --- a/vllm/model_executor/warmup/kernel_warmup.py +++ b/vllm/model_executor/warmup/kernel_warmup.py @@ -49,10 +49,18 @@ def kernel_warmup(worker: 
"Worker"): except NotImplementedError: return False - if not worker.model_runner.is_pooling_model and all( - _is_flashinfer_backend(group.backend) - for groups in worker.model_runner.attn_groups - for group in groups + # NOTE: we add check for empty attn_groups to avoid errors when + # deploying models such as E instances and encoder-only models. + # As for those models, worker.model_runner.attn_groups is empty. + # This change is made during EPD feature development. + if ( + not worker.model_runner.is_pooling_model + and worker.model_runner.attn_groups + and all( + _is_flashinfer_backend(group.backend) + for groups in worker.model_runner.attn_groups + for group in groups + ) ): logger.info("Warming up FlashInfer attention.") # Warmup with mixed batch containing both prefill and decode tokens diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py index 866136648bcba..20fdb3446404b 100644 --- a/vllm/v1/core/sched/output.py +++ b/vllm/v1/core/sched/output.py @@ -14,6 +14,7 @@ if TYPE_CHECKING: import numpy.typing as npt import torch + from vllm.distributed.ec_transfer.ec_connector.base import ECConnectorMetadata from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata from vllm.lora.request import LoRARequest from vllm.multimodal.inputs import MultiModalFeatureSpec @@ -21,6 +22,7 @@ if TYPE_CHECKING: from vllm.sampling_params import SamplingParams from vllm.v1.request import Request else: + ECConnectorMetadata = object KVConnectorMetadata = object LoRARequest = object MultiModalFeatureSpec = object @@ -188,6 +190,9 @@ class SchedulerOutput: # KV Cache Connector metadata. 
kv_connector_metadata: KVConnectorMetadata | None = None + # EC Cache Connector metadata + ec_connector_metadata: ECConnectorMetadata | None = None + @dataclass class GrammarOutput: diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 46dc1071b8395..8455746cd56d2 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -7,6 +7,11 @@ from collections.abc import Iterable from typing import Any from vllm.config import VllmConfig +from vllm.distributed.ec_transfer.ec_connector.base import ( + ECConnectorMetadata, + ECConnectorRole, +) +from vllm.distributed.ec_transfer.ec_connector.factory import ECConnectorFactory from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory from vllm.distributed.kv_transfer.kv_connector.v1 import ( @@ -14,6 +19,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1 import ( KVConnectorRole, SupportsHMA, ) +from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats from vllm.logger import init_logger from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry @@ -104,6 +110,11 @@ class Scheduler(SchedulerInterface): self.kv_events_config, self.parallel_config.data_parallel_rank, ) + self.ec_connector = None + if self.vllm_config.ec_transfer_config is not None: + self.ec_connector = ECConnectorFactory.create_connector( + config=self.vllm_config, role=ECConnectorRole.SCHEDULER + ) num_gpu_blocks = self.cache_config.num_gpu_blocks assert num_gpu_blocks is not None and num_gpu_blocks > 0 @@ -230,12 +241,14 @@ class Scheduler(SchedulerInterface): # Schedule encoder inputs. 
encoder_inputs_to_schedule = None + external_load_encoder_input: list[int] = [] new_encoder_compute_budget = encoder_compute_budget if request.has_encoder_inputs: ( encoder_inputs_to_schedule, num_new_tokens, new_encoder_compute_budget, + external_load_encoder_input, ) = self._try_schedule_encoder_inputs( request, request.num_computed_tokens, @@ -342,6 +355,11 @@ class Scheduler(SchedulerInterface): for i in encoder_inputs_to_schedule: self.encoder_cache_manager.allocate(request, i) encoder_compute_budget = new_encoder_compute_budget + if external_load_encoder_input: + for i in external_load_encoder_input: + self.encoder_cache_manager.allocate(request, i) + if self.ec_connector is not None: + self.ec_connector.update_state_after_alloc(request, i) # Record the LoRAs in scheduled_running_reqs scheduled_loras: set[int] = set() @@ -445,6 +463,7 @@ class Scheduler(SchedulerInterface): num_computed_tokens = request.num_computed_tokens encoder_inputs_to_schedule = None + external_load_encoder_input = [] new_encoder_compute_budget = encoder_compute_budget # KVTransfer: loading remote KV, do not allocate for new work. 
@@ -480,6 +499,7 @@ class Scheduler(SchedulerInterface): encoder_inputs_to_schedule, num_new_tokens, new_encoder_compute_budget, + external_load_encoder_input, ) = self._try_schedule_encoder_inputs( request, num_computed_tokens, @@ -583,7 +603,12 @@ class Scheduler(SchedulerInterface): for i in encoder_inputs_to_schedule: self.encoder_cache_manager.allocate(request, i) encoder_compute_budget = new_encoder_compute_budget - + # Allocate for external load encoder cache + if external_load_encoder_input: + for i in external_load_encoder_input: + self.encoder_cache_manager.allocate(request, i) + if self.ec_connector is not None: + self.ec_connector.update_state_after_alloc(request, i) # Put back any skipped requests at the head of the waiting queue if skipped_waiting_requests: self.waiting.prepend_requests(skipped_waiting_requests) @@ -591,6 +616,7 @@ class Scheduler(SchedulerInterface): # Check if the scheduling constraints are satisfied. total_num_scheduled_tokens = sum(num_scheduled_tokens.values()) assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens + assert token_budget >= 0 assert len(self.running) <= self.max_num_running_reqs # Since some requests in the RUNNING queue may not be scheduled in @@ -653,8 +679,18 @@ class Scheduler(SchedulerInterface): # 2. Wrap up all the KV cache load / save ops into an opaque object # 3. 
Clear the internal states of the connector
         if self.connector is not None:
-            meta = self.connector.build_connector_meta(scheduler_output)
+            meta: KVConnectorMetadata = self.connector.build_connector_meta(
+                scheduler_output
+            )
             scheduler_output.kv_connector_metadata = meta
+
+        # Build the connector meta for ECConnector
+        if self.ec_connector is not None:
+            ec_meta: ECConnectorMetadata = self.ec_connector.build_connector_meta(
+                scheduler_output
+            )
+            scheduler_output.ec_connector_metadata = ec_meta
+
         with record_function_or_nullcontext("schedule: update_after_schedule"):
             self._update_after_schedule(scheduler_output)
         return scheduler_output
@@ -755,7 +791,7 @@
         num_computed_tokens: int,
         num_new_tokens: int,
         encoder_compute_budget: int,
-    ) -> tuple[list[int], int, int]:
+    ) -> tuple[list[int], int, int, list[int]]:
         """
         Determine which encoder inputs need to be scheduled in the current step,
         and update `num_new_tokens` and encoder token budget accordingly.
@@ -765,6 +801,7 @@
           in this step, i.e., [num_computed_tokens, num_computed_tokens + num_new_tokens).
         - It is not already computed and stored in the encoder cache.
+        - It does not already exist in the remote encoder cache (via ECConnector).
         - There is sufficient encoder token budget to process it.
         - The encoder cache has space to store it.
@@ -776,12 +813,16 @@
           blocks and externally cached blocks (via KVConnector).
""" if num_new_tokens == 0 or not request.has_encoder_inputs: - return [], num_new_tokens, encoder_compute_budget + return [], num_new_tokens, encoder_compute_budget, [] encoder_inputs_to_schedule: list[int] = [] mm_features = request.mm_features assert mm_features is not None assert len(mm_features) > 0 + external_load_encoder_input = [] + # Check remote cache first + if self.ec_connector is not None: + remote_cache_has_item = self.ec_connector.has_caches(request) # NOTE: since scheduler operates on the request level (possibly with # multiple encoder inputs per request), we need to create temporary # trackers for accounting at the encoder input level. @@ -862,6 +903,12 @@ class Scheduler(SchedulerInterface): num_new_tokens = 0 break + if self.ec_connector is not None and remote_cache_has_item[i]: + mm_hashes_to_schedule.add(request.mm_features[i].identifier) + external_load_encoder_input.append(i) + num_tokens_to_schedule += num_encoder_tokens + continue + num_tokens_to_schedule += num_encoder_tokens encoder_compute_budget -= num_encoder_tokens mm_hashes_to_schedule.add(request.mm_features[i].identifier) @@ -871,6 +918,7 @@ class Scheduler(SchedulerInterface): encoder_inputs_to_schedule, num_new_tokens, encoder_compute_budget, + external_load_encoder_input, ) def get_grammar_bitmask( diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index 5f65e4ee0d1f3..e32d5bb608b1d 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -8,6 +8,8 @@ from typing import TYPE_CHECKING, NamedTuple import numpy as np import torch +from vllm.v1.core.sched.output import SchedulerOutput + if TYPE_CHECKING: from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats else: @@ -136,6 +138,13 @@ class KVConnectorOutput: ) +@dataclass +class ECConnectorOutput: + # [mm_hash] + finished_sending: set[str] | None = None + finished_recving: set[str] | None = None + + # ModelRunnerOutput is serialized and sent to the scheduler process. 
# This is expensive for torch.Tensor so prefer to use list instead.
 @dataclass
@@ -167,6 +176,8 @@ class ModelRunnerOutput:
 
     kv_connector_output: KVConnectorOutput | None = None
 
+    ec_connector_output: ECConnectorOutput | None = None
+
     # req_id -> num_nans_in_logits
     num_nans_in_logits: dict[str, int] | None = None
 
@@ -192,6 +203,41 @@ class DraftTokenIds:
     draft_token_ids: list[list[int]]
 
 
+def make_empty_encoder_model_runner_output(
+    scheduler_output: "SchedulerOutput",
+) -> ModelRunnerOutput:
+    """
+    Create a ModelRunnerOutput stub that contains the correct
+    per-request bookkeeping but no generated data yet.
+    """
+    if not scheduler_output.num_scheduled_tokens:
+        return EMPTY_MODEL_RUNNER_OUTPUT
+
+    # Convert to list so we get a deterministic, indexable sequence
+    req_ids: list[str] = list(scheduler_output.num_scheduled_tokens.keys())
+
+    # Give every request its own contiguous index
+    req_id_to_index: dict[str, int] = {rid: idx for idx, rid in enumerate(req_ids)}
+
+    # Placeholder sampled ids: a single dummy token (0) per request, since the
+    # encoder instance performs no real token generation
+    sampled_token_ids: list[list[int]] = [[0] for _ in req_ids]
+
+    # Pooler outputs are not available yet ⇒ use None placeholders
+    pooler_output: list[torch.Tensor | None] = [None for _ in req_ids]
+
+    return ModelRunnerOutput(
+        req_ids=req_ids,
+        req_id_to_index=req_id_to_index,
+        sampled_token_ids=sampled_token_ids,
+        logprobs=None,
+        prompt_logprobs_dict={},
+        pooler_output=pooler_output,
+        kv_connector_output=None,
+        ec_connector_output=None,
+        num_nans_in_logits=None,
+    )
+
+
 EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
     req_ids=[],
     req_id_to_index={},
diff --git a/vllm/v1/worker/ec_connector_model_runner_mixin.py b/vllm/v1/worker/ec_connector_model_runner_mixin.py
new file mode 100644
index 0000000000000..00bc909df2975
--- /dev/null
+++ b/vllm/v1/worker/ec_connector_model_runner_mixin.py
@@ -0,0 +1,87 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Define EC connector
functionality mixin for model runners.
+"""
+
+from collections.abc import Generator
+from contextlib import AbstractContextManager, contextmanager, nullcontext
+from typing import (
+    TYPE_CHECKING,  # noqa: UP035
+)
+
+import torch
+
+from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer
+from vllm.distributed.ec_transfer.ec_connector.base import ECConnectorBase
+from vllm.logger import init_logger
+from vllm.v1.outputs import ECConnectorOutput
+
+if TYPE_CHECKING:
+    from vllm.v1.core.sched.output import SchedulerOutput
+
+logger = init_logger(__name__)
+
+
+# Defined as an EC connector functionality mixin for ModelRunner (GPU, TPU)
+class ECConnectorModelRunnerMixin:
+    @staticmethod
+    def maybe_save_ec_to_connector(
+        encoder_cache: dict[str, torch.Tensor],
+        mm_hash: str,
+    ):
+        if not has_ec_transfer():
+            logger.debug("EC transfer is not initialized; skipping cache save")
+            return
+        connector = get_ec_transfer()
+        connector.save_caches(encoder_cache=encoder_cache, mm_hash=mm_hash)
+
+    @staticmethod
+    def get_finished_ec_transfers(
+        scheduler_output: "SchedulerOutput",
+    ) -> tuple[set[str] | None, set[str] | None]:
+        if has_ec_transfer():
+            return get_ec_transfer().get_finished(scheduler_output.finished_req_ids)
+        return None, None
+
+    @staticmethod
+    def maybe_get_ec_connector_output(
+        scheduler_output: "SchedulerOutput",
+        encoder_cache: dict[str, torch.Tensor],
+        **kwargs,
+    ) -> AbstractContextManager[ECConnectorOutput | None]:
+        return (
+            ECConnectorModelRunnerMixin._get_ec_connector_output(
+                scheduler_output, encoder_cache, **kwargs
+            )
+            if has_ec_transfer()
+            else nullcontext()
+        )
+
+    # This context manager must be used within an active forward context.
+    # It encapsulates the entire EC connector lifecycle within execute_model
+    @staticmethod
+    @contextmanager
+    def _get_ec_connector_output(
+        scheduler_output: "SchedulerOutput",
+        encoder_cache: dict[str, torch.Tensor],
+        **kwargs,
+    ) -> Generator[ECConnectorOutput, None, None]:
+        output = ECConnectorOutput()
+
+        ec_connector = get_ec_transfer()
+        assert isinstance(ec_connector, ECConnectorBase)
+        assert scheduler_output.ec_connector_metadata is not None
+        ec_connector.bind_connector_metadata(scheduler_output.ec_connector_metadata)
+
+        if not ec_connector.is_producer:
+            ec_connector.start_load_caches(encoder_cache, **kwargs)
+
+        try:
+            yield output
+        finally:
+            output.finished_sending, output.finished_recving = (
+                ec_connector.get_finished(scheduler_output.finished_req_ids)
+            )
+
+            ec_connector.clear_connector_metadata()
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index fbd3e5f313167..b14b6b1c3f52e 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -35,6 +35,7 @@ from vllm.config import (
     get_layers_from_vllm_config,
     update_config,
 )
+from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer
 from vllm.distributed.eplb.eplb_state import EplbState
 from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group
 from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks
@@ -114,12 +115,14 @@ from vllm.v1.outputs import (
     EMPTY_MODEL_RUNNER_OUTPUT,
     AsyncModelRunnerOutput,
     DraftTokenIds,
+    ECConnectorOutput,
     KVConnectorOutput,
     LogprobsLists,
     LogprobsTensors,
     ModelRunnerOutput,
     PoolerOutput,
     SamplerOutput,
+    make_empty_encoder_model_runner_output,
 )
 from vllm.v1.pool.metadata import PoolingMetadata
 from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs
@@ -134,6 +137,7 @@ from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer
 from vllm.v1.structured_output.utils import apply_grammar_bitmask
 from
vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext from vllm.v1.worker.dp_utils import coordinate_batch_across_dp +from vllm.v1.worker.ec_connector_model_runner_mixin import ECConnectorModelRunnerMixin from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin @@ -237,9 +241,12 @@ class ExecuteModelState(NamedTuple): sample_hidden_states: torch.Tensor aux_hidden_states: list[torch.Tensor] | None kv_connector_output: KVConnectorOutput | None + ec_connector_output: ECConnectorOutput | None -class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): +class GPUModelRunner( + LoRAModelRunnerMixin, KVConnectorModelRunnerMixin, ECConnectorModelRunnerMixin +): def __init__( self, vllm_config: VllmConfig, @@ -1873,6 +1880,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): output, is_embed=pos_info.is_embed, ) + logger.debug("Finish execute for mm hash %s", mm_hash) + self.maybe_save_ec_to_connector(self.encoder_cache, mm_hash) def _gather_mm_embeddings( self, @@ -2191,20 +2200,27 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): torch.Tensor, IntermediateTensors | None, dict[str, Any], + ECConnectorOutput | None, ]: num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens is_first_rank = get_pp_group().is_first_rank # _prepare_inputs may reorder the batch, so we must gather multi # modal outputs after that to ensure the correct order + ec_connector_output = None + if ( self.supports_mm_inputs and is_first_rank and not self.model_config.is_encoder_decoder ): # Run the multimodal encoder if any. 
- self._execute_mm_encoder(scheduler_output) - mm_embeds, is_mm_embed = self._gather_mm_embeddings(scheduler_output) + with self.maybe_get_ec_connector_output( + scheduler_output, + encoder_cache=self.encoder_cache, + ) as ec_connector_output: + self._execute_mm_encoder(scheduler_output) + mm_embeds, is_mm_embed = self._gather_mm_embeddings(scheduler_output) # NOTE(woosuk): To unify token ids and soft tokens (vision # embeddings), we always use embeddings (rather than token ids) @@ -2284,6 +2300,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): positions, intermediate_tensors, model_kwargs, + ec_connector_output, ) def _sample( @@ -2508,6 +2525,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Update persistent batch states. self._update_states(scheduler_output) + if has_ec_transfer() and get_ec_transfer().is_producer: + with self.maybe_get_ec_connector_output( + scheduler_output, + encoder_cache=self.encoder_cache, + ) as ec_connector_output: + self._execute_mm_encoder(scheduler_output) + return make_empty_encoder_model_runner_output(scheduler_output) + if not num_scheduled_tokens: if not has_kv_transfer_group(): # Return empty ModelRunnerOutput if no work to do. @@ -2583,6 +2608,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): positions, intermediate_tensors, model_kwargs, + ec_connector_output, ) = self._preprocess( scheduler_output, num_input_tokens, intermediate_tensors ) @@ -2699,6 +2725,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): sample_hidden_states, aux_hidden_states, kv_connector_output, + ec_connector_output, ) return None @@ -2720,6 +2747,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): sample_hidden_states, aux_hidden_states, kv_connector_output, + ec_connector_output, ) = self.execute_model_state # Clear ephemeral state. 
self.execute_model_state = None
@@ -2811,6 +2839,9 @@
             prompt_logprobs_dict=prompt_logprobs_dict,
             pooler_output=[],
             kv_connector_output=kv_connector_output,
+            ec_connector_output=ec_connector_output
+            if self.supports_mm_inputs
+            else None,
             num_nans_in_logits=num_nans_in_logits,
         )
 
@@ -4797,7 +4828,8 @@
             KVCacheSpec: A dictionary mapping layer names to their KV cache format.
                 Layers that do not need KV cache are not included.
         """
-
+        if has_ec_transfer() and get_ec_transfer().is_producer:
+            return {}
         kv_cache_spec: dict[str, KVCacheSpec] = {}
         attn_layers = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase)
         for layer_name, attn_module in attn_layers.items():
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index 19061fcffdf1a..2b9d8bb2f25e6 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -20,6 +20,7 @@ from vllm.distributed import (
     init_distributed_environment,
     set_custom_all_reduce,
 )
+from vllm.distributed.ec_transfer import ensure_ec_transfer_initialized
 from vllm.distributed.kv_transfer import (
     ensure_kv_transfer_initialized,
     get_kv_transfer_group,
@@ -887,3 +888,7 @@
         parallel_config.pipeline_parallel_size,
         parallel_config.decode_context_parallel_size,
     )
+
+    # Init the EC connector here, before KV cache init
+    # NOTE: We do not init KV caches for an encoder-only instance in EPD disagg mode
+    ensure_ec_transfer_initialized(vllm_config)