[P/D] NIXL Integration (#17751)

Signed-off-by: ApostaC <yihua98@uchicago.edu>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
Signed-off-by: Robert Shaw <rshaw@neuralmagic.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Signed-off-by: Brent Salisbury <bsalisbu@redhat.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: ApostaC <yihua98@uchicago.edu>
Co-authored-by: Robert Shaw <rshaw@neuralmagic.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Tyler Michael Smith <tysmith@redhat.com>
Co-authored-by: Brent Salisbury <bsalisbu@redhat.com>
Robert Shaw committed on 2025-05-12 12:46:16 -04:00 (via GitHub)
parent 05a4324f8e
commit d19110204c
34 changed files with 2723 additions and 108 deletions


@ -214,6 +214,7 @@ steps:
- pytest -v -s v1/worker
- pytest -v -s v1/structured_output
- pytest -v -s v1/spec_decode
- pytest -v -s v1/kv_connector/unit
- pytest -v -s v1/test_serial_utils.py
- pytest -v -s v1/test_stats.py
- pytest -v -s v1/test_utils.py


@ -870,7 +870,7 @@ def test_kv_connector_basic():
NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2
scheduler.connector.get_num_new_matched_tokens = Mock(name="method")
scheduler.connector.get_num_new_matched_tokens.return_value = (
- NUM_MATCHED_NEW_TOKENS)
+ NUM_MATCHED_NEW_TOKENS, False)
######################################################
# FIRST SET OF REQUESTS - External Hit Only
@ -981,7 +981,7 @@ def test_kv_connector_unable_to_allocate():
NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2
scheduler.connector.get_num_new_matched_tokens = Mock(name="method")
scheduler.connector.get_num_new_matched_tokens.return_value = (
- NUM_MATCHED_NEW_TOKENS)
+ NUM_MATCHED_NEW_TOKENS, False)
# Create two requests. The second request will not be able to
# allocate slots because it will not have enough blocks.
@ -1060,7 +1060,7 @@ def test_kv_connector_handles_preemption():
NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE
scheduler.connector.get_num_new_matched_tokens = Mock(name="method")
scheduler.connector.get_num_new_matched_tokens.return_value = (
- NUM_MATCHED_NEW_TOKENS)
+ NUM_MATCHED_NEW_TOKENS, False)
# Create two requests.
# Both can be scheduled at first, but the second request


@ -0,0 +1,171 @@
#!/bin/bash
set -xe
# Models to run
MODELS=(
"Qwen/Qwen3-0.6B"
)
# Number of prefill and decode instances to create
NUM_PREFILL_INSTANCES=${NUM_PREFILL_INSTANCES:-1} # Default to 1
NUM_DECODE_INSTANCES=${NUM_DECODE_INSTANCES:-2} # Default to 2
# Find the git repository root directory
GIT_ROOT=$(git rev-parse --show-toplevel)
# Trap the SIGINT signal (triggered by Ctrl+C)
trap 'kill $(jobs -pr)' SIGINT SIGTERM EXIT
# Waits for vLLM to start.
wait_for_server() {
local port=$1
timeout 1200 bash -c "
until curl -s localhost:${port}/v1/completions > /dev/null; do
sleep 1
done" && return 0 || return 1
}
# Function to clean up previous instances
cleanup_instances() {
echo "Cleaning up any running vLLM instances..."
pkill -f "vllm serve" || true
sleep 2
}
# Helper to get model-specific arguments for deepseek
get_model_args() {
local model_name=$1
local extra_args=""
if [[ "$model_name" == "deepseek-ai/deepseek-vl2-tiny" ]]; then
extra_args="--hf_overrides '{\"architectures\": [\"DeepseekVLV2ForCausalLM\"]}' --trust-remote-code"
fi
echo "$extra_args"
}
# Function to run tests for a specific model
run_tests_for_model() {
local model_name=$1
echo "================================"
echo "Testing model: $model_name"
echo "================================"
# Get model-specific arguments
local model_args=$(get_model_args "$model_name")
# Arrays to store all hosts and ports
PREFILL_HOSTS=()
PREFILL_PORTS=()
DECODE_HOSTS=()
DECODE_PORTS=()
# Start prefill instances
for i in $(seq 0 $((NUM_PREFILL_INSTANCES-1))); do
# Calculate GPU ID - we'll distribute across available GPUs
GPU_ID=$((i % $(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)))
# Calculate port number (base port + instance number)
PORT=$((8100 + i))
# Calculate side channel port
SIDE_CHANNEL_PORT=$((5559 + i))
echo "Starting prefill instance $i on GPU $GPU_ID, port $PORT"
# Build the command with or without model-specific args
BASE_CMD="CUDA_VISIBLE_DEVICES=$GPU_ID VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT vllm serve $model_name \
--port $PORT \
--enforce-eager \
--disable-log-requests \
--gpu-memory-utilization 0.2 \
--kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\"}'"
if [ -n "$model_args" ]; then
FULL_CMD="$BASE_CMD $model_args"
else
FULL_CMD="$BASE_CMD"
fi
eval "$FULL_CMD &"
# Store host and port for proxy configuration
PREFILL_HOSTS+=("localhost")
PREFILL_PORTS+=($PORT)
done
# Start decode instances
for i in $(seq 0 $((NUM_DECODE_INSTANCES-1))); do
# Calculate GPU ID - we'll distribute across available GPUs, starting from after prefill GPUs
GPU_ID=$(((i + NUM_PREFILL_INSTANCES) % $(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)))
# Calculate port number (base port + instance number)
PORT=$((8200 + i))
# Calculate side channel port
SIDE_CHANNEL_PORT=$((5659 + i))
echo "Starting decode instance $i on GPU $GPU_ID, port $PORT"
# Build the command with or without model-specific args
BASE_CMD="CUDA_VISIBLE_DEVICES=$GPU_ID VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT vllm serve $model_name \
--port $PORT \
--enforce-eager \
--disable-log-requests \
--gpu-memory-utilization 0.2 \
--kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\"}'"
if [ -n "$model_args" ]; then
FULL_CMD="$BASE_CMD $model_args"
else
FULL_CMD="$BASE_CMD"
fi
eval "$FULL_CMD &"
# Store host and port for proxy configuration
DECODE_HOSTS+=("localhost")
DECODE_PORTS+=($PORT)
done
# Wait for all instances to start
for PORT in "${PREFILL_PORTS[@]}"; do
echo "Waiting for prefill instance on port $PORT to start..."
wait_for_server $PORT
done
for PORT in "${DECODE_PORTS[@]}"; do
echo "Waiting for decode instance on port $PORT to start..."
wait_for_server $PORT
done
# Build the command for the proxy server with all the hosts and ports
PROXY_CMD="python ${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py --port 8192"
# Add all prefill hosts and ports
PROXY_CMD+=" --prefiller-hosts ${PREFILL_HOSTS[@]}"
PROXY_CMD+=" --prefiller-ports ${PREFILL_PORTS[@]}"
# Add all decode hosts and ports
PROXY_CMD+=" --decoder-hosts ${DECODE_HOSTS[@]}"
PROXY_CMD+=" --decoder-ports ${DECODE_PORTS[@]}"
# Start the proxy server
echo "Starting proxy server with command: $PROXY_CMD"
$PROXY_CMD &
# Wait for the proxy to start
sleep 5
# Run lm eval for this model
echo "Running tests for $model_name"
TEST_MODEL=$model_name python -m pytest -s -x ${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/test_accuracy.py
# Clean up before running next model
cleanup_instances
sleep 3
}
# Run tests for each model
for model in "${MODELS[@]}"; do
run_tests_for_model "$model"
done
echo "All tests completed!"


@ -0,0 +1,123 @@
#!/bin/bash
set -xe
# Models to run
MODELS=(
"Qwen/Qwen3-0.6B"
)
# Find the git repository root directory
GIT_ROOT=$(git rev-parse --show-toplevel)
# Trap the SIGINT signal (triggered by Ctrl+C)
trap 'kill $(jobs -pr)' SIGINT SIGTERM EXIT
# Waits for vLLM to start.
wait_for_server() {
local port=$1
timeout 1200 bash -c "
until curl -s localhost:${port}/v1/completions > /dev/null; do
sleep 1
done" && return 0 || return 1
}
# Function to clean up previous instances
cleanup_instances() {
echo "Cleaning up any running vLLM instances..."
pkill -f "vllm serve" || true
sleep 2
}
# Helper to get model-specific arguments for deepseek
get_model_args() {
local model_name=$1
local extra_args=""
if [[ "$model_name" == "deepseek-ai/deepseek-vl2-tiny" ]]; then
extra_args="--hf_overrides '{\"architectures\": [\"DeepseekVLV2ForCausalLM\"]}' --trust-remote-code"
fi
echo "$extra_args"
}
# Function to run tests for a specific model
run_tests_for_model() {
local model_name=$1
echo "================================"
echo "Testing model: $model_name"
echo "================================"
# Get model-specific arguments
local model_args=$(get_model_args "$model_name")
# Start prefill instance
PREFILL_PORT=8001
BASE_CMD="CUDA_VISIBLE_DEVICES=0 VLLM_NIXL_SIDE_CHANNEL_PORT=5559 vllm serve $model_name \
--port $PREFILL_PORT \
--enforce-eager \
--disable-log-requests \
--gpu-memory-utilization 0.2 \
--kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\"}'"
if [ -n "$model_args" ]; then
FULL_CMD="$BASE_CMD $model_args"
else
FULL_CMD="$BASE_CMD"
fi
eval "$FULL_CMD &"
# Start decode instance
DECODE_PORT=8002
# Build the command with or without model-specific args
BASE_CMD="CUDA_VISIBLE_DEVICES=1 VLLM_NIXL_SIDE_CHANNEL_PORT=6000 vllm serve $model_name \
--port $DECODE_PORT \
--enforce-eager \
--disable-log-requests \
--gpu-memory-utilization 0.2 \
--kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\"}'"
if [ -n "$model_args" ]; then
FULL_CMD="$BASE_CMD $model_args"
else
FULL_CMD="$BASE_CMD"
fi
eval "$FULL_CMD &"
# Wait for all instances to start
echo "Waiting for prefill instance on port $PORT to start..."
wait_for_server $PREFILL_PORT
echo "Waiting for decode instance on port $PORT to start..."
wait_for_server $DECODE_PORT
# Build the command for the proxy server with all the hosts and ports
PROXY_PORT=8192
PROXY_CMD="python ${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py --port $PROXY_PORT"
PROXY_CMD+=" --prefiller-ports ${PREFILL_PORT}"
PROXY_CMD+=" --decoder-ports ${DECODE_PORT}"
# Start the proxy server
echo "Starting proxy server with command: $PROXY_CMD"
$PROXY_CMD &
# Wait for the proxy to start
sleep 5
# Run lm eval for this model
echo "Running tests for $model_name"
PREFILL_PORT=$PREFILL_PORT DECODE_PORT=$DECODE_PORT PROXY_PORT=$PROXY_PORT python -m pytest -s -v ${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/test_edge_cases.py
# Clean up before running next model
cleanup_instances
sleep 3
}
# Run tests for each model
for model in "${MODELS[@]}"; do
run_tests_for_model "$model"
done
echo "All tests completed!"


@ -0,0 +1,60 @@
# SPDX-License-Identifier: Apache-2.0
import os
import lm_eval
import openai
BASE_URL = "http://localhost:8192/v1"
NUM_CONCURRENT = 100
TASK = "gsm8k"
FILTER = "exact_match,strict-match"
RTOL = 0.03
# Model-specific expected values
EXPECTED_VALUES = {
"Qwen/Qwen3-0.6B": 0.41,
}
SIMPLE_PROMPT = "The best part about working on vLLM is that I got to meet so many people across various different organizations like UCB, Google, and Meta which means", # noqa: E501
# Get model name from environment variable
MODEL_NAME = os.environ.get("TEST_MODEL", "Qwen/Qwen3-0.6B")
def run_simple_prompt():
client = openai.OpenAI(api_key="EMPTY", base_url=BASE_URL)
completion = client.completions.create(model=MODEL_NAME,
prompt=SIMPLE_PROMPT)
print("-" * 50)
print(f"Completion results for {MODEL_NAME}:")
print(completion)
print("-" * 50)
def test_accuracy():
"""Run the end to end accuracy test."""
run_simple_prompt()
model_args = (f"model={MODEL_NAME},"
f"base_url={BASE_URL}/completions,"
f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False")
results = lm_eval.simple_evaluate(
model="local-completions",
model_args=model_args,
tasks=TASK,
)
measured_value = results["results"][TASK][FILTER]
expected_value = EXPECTED_VALUES.get(MODEL_NAME)
if expected_value is None:
print(f"Warning: No expected value found for {MODEL_NAME}. "
"Skipping accuracy check.")
print(f"Measured value: {measured_value}")
return
assert abs(measured_value - expected_value) < RTOL, (
f"Expected: {expected_value} | Measured: {measured_value}")


@ -0,0 +1,77 @@
# SPDX-License-Identifier: Apache-2.0
import os
import openai
PREFILL_PORT = os.getenv("PREFILL_PORT", None)
DECODE_PORT = os.getenv("DECODE_PORT", None)
PROXY_PORT = os.getenv("PROXY_PORT", None)
if PREFILL_PORT is None or DECODE_PORT is None or PROXY_PORT is None:
raise ValueError(
"Please set the PREFILL_PORT, DECODE_PORT, and PROXY_PORT.")
LONG_PROMPT = "Red Hat is the best company in the world to work for because it works on open source software, which means that all the contributions are delivered to the community. As a result, when working on projects like vLLM we are able to meet many amazing people from various organizations like AMD, Google, NVIDIA, " # noqa: E501
PROMPT = "Red Hat is the best company in the world to work for because it works on open source software, which means that all the contributions are delivered to the community. As a result," # noqa: E501
SHORT_PROMPT = "Red Hat is "
def test_edge_cases():
# Set the OpenAI API key and base URL
decode_client = openai.OpenAI(
api_key="MY_KEY",
base_url=f"http://localhost:{DECODE_PORT}/v1",
)
prefill_client = openai.OpenAI(
api_key="MY_KEY",
base_url=f"http://localhost:{PREFILL_PORT}/v1",
)
proxy_client = openai.OpenAI(
api_key="MY_KEY",
base_url=f"http://localhost:{PROXY_PORT}/v1",
)
# Get the list of models
models = decode_client.models.list()
MODEL = models.data[0].id
# (1) Check that we can handle a very short prompt,
# less than the length of the block size.
completion = proxy_client.completions.create(model=MODEL,
prompt=SHORT_PROMPT,
temperature=0)
proxy_response = completion.choices[0].text
completion = prefill_client.completions.create(model=MODEL,
prompt=SHORT_PROMPT,
temperature=0)
prefill_response = completion.choices[0].text
print(f"SMALL PROMPT: {proxy_response=}")
assert proxy_response == prefill_response
# (2) Check that we can handle a full prefix cache
# hit on the D worker but not on the P worker.
# (2a): prime the D worker.
completion = decode_client.completions.create(model=MODEL,
prompt=PROMPT,
temperature=0)
decode_response = completion.choices[0].text
# (2b): send via the P/D setup
completion = proxy_client.completions.create(model=MODEL,
prompt=PROMPT,
temperature=0)
proxy_response = completion.choices[0].text
print(f"FULL CACHE HIT: {proxy_response=}")
assert proxy_response == decode_response
# (3) Check that we can handle a partial prefix cache
# hit on the D worker.
completion = proxy_client.completions.create(model=MODEL,
prompt=LONG_PROMPT,
temperature=0)
proxy_response = completion.choices[0].text
completion = prefill_client.completions.create(model=MODEL,
prompt=LONG_PROMPT,
temperature=0)
prefill_response = completion.choices[0].text
print(f"PARTIAL CACHE HIT: {proxy_response=}")
assert proxy_response == prefill_response


@ -0,0 +1,260 @@
# SPDX-License-Identifier: Apache-2.0
import argparse
import itertools
import os
import uuid
from contextlib import asynccontextmanager
import httpx
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from vllm.logger import init_logger
logger = init_logger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
"""
Lifespan context manager to handle startup and shutdown events.
"""
# Startup: Initialize client pools for prefiller and decoder services
app.state.prefill_clients = []
app.state.decode_clients = []
# Create prefill clients
for i, (host, port) in enumerate(global_args.prefiller_instances):
prefiller_base_url = f'http://{host}:{port}/v1'
app.state.prefill_clients.append({
'client':
httpx.AsyncClient(timeout=None, base_url=prefiller_base_url),
'host':
host,
'port':
port,
'id':
i
})
# Create decode clients
for i, (host, port) in enumerate(global_args.decoder_instances):
decoder_base_url = f'http://{host}:{port}/v1'
app.state.decode_clients.append({
'client':
httpx.AsyncClient(timeout=None, base_url=decoder_base_url),
'host':
host,
'port':
port,
'id':
i
})
# Initialize round-robin iterators
app.state.prefill_iterator = itertools.cycle(
range(len(app.state.prefill_clients)))
app.state.decode_iterator = itertools.cycle(
range(len(app.state.decode_clients)))
print(f"Initialized {len(app.state.prefill_clients)} prefill clients "
f"and {len(app.state.decode_clients)} decode clients.")
yield
# Shutdown: Close all clients
for client_info in app.state.prefill_clients:
await client_info['client'].aclose()
for client_info in app.state.decode_clients:
await client_info['client'].aclose()
# Update FastAPI app initialization to use lifespan
app = FastAPI(lifespan=lifespan)
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--host", type=str, default="localhost")
# For prefiller instances
parser.add_argument("--prefiller-hosts",
"--prefiller-host",
type=str,
nargs="+",
default=["localhost"])
parser.add_argument("--prefiller-ports",
"--prefiller-port",
type=int,
nargs="+",
default=[8100])
# For decoder instances
parser.add_argument("--decoder-hosts",
"--decoder-host",
type=str,
nargs="+",
default=["localhost"])
parser.add_argument("--decoder-ports",
"--decoder-port",
type=int,
nargs="+",
default=[8200])
args = parser.parse_args()
# Validate and pair hosts with ports
if len(args.prefiller_hosts) != len(args.prefiller_ports):
raise ValueError(
"Number of prefiller hosts must match number of prefiller ports")
if len(args.decoder_hosts) != len(args.decoder_ports):
raise ValueError(
"Number of decoder hosts must match number of decoder ports")
# Create tuples of (host, port) for each service type
args.prefiller_instances = list(
zip(args.prefiller_hosts, args.prefiller_ports))
args.decoder_instances = list(zip(args.decoder_hosts, args.decoder_ports))
return args
def get_next_client(app, service_type: str):
"""
Get the next client in round-robin fashion.
Args:
app: The FastAPI app instance
service_type: Either 'prefill' or 'decode'
Returns:
The next client to use
"""
if service_type == 'prefill':
client_idx = next(app.state.prefill_iterator)
return app.state.prefill_clients[client_idx]
elif service_type == 'decode':
client_idx = next(app.state.decode_iterator)
return app.state.decode_clients[client_idx]
else:
raise ValueError(f"Unknown service type: {service_type}")
async def send_request_to_service(client_info: dict, endpoint: str,
req_data: dict, request_id: str):
"""
Send a request to a service using a client from the pool.
"""
req_data = req_data.copy()
req_data['kv_transfer_params'] = {
"do_remote_decode": True,
"do_remote_prefill": False,
"remote_engine_id": None,
"remote_block_ids": None,
"remote_host": None,
"remote_port": None
}
req_data["stream"] = False
req_data["max_tokens"] = 1
if "stream_options" in req_data:
del req_data["stream_options"]
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
"X-Request-Id": request_id
}
response = await client_info['client'].post(endpoint,
json=req_data,
headers=headers)
response.raise_for_status()
return response
async def stream_service_response(client_info: dict, endpoint: str,
req_data: dict, request_id: str):
"""
Asynchronously stream response from a service using a client from the pool.
"""
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
"X-Request-Id": request_id
}
async with client_info['client'].stream("POST",
endpoint,
json=req_data,
headers=headers) as response:
response.raise_for_status()
async for chunk in response.aiter_bytes():
yield chunk
@app.post("/v1/completions")
async def handle_completions(request: Request):
try:
req_data = await request.json()
request_id = str(uuid.uuid4())
# Get the next prefill client in round-robin fashion
prefill_client_info = get_next_client(request.app, 'prefill')
# Send request to prefill service
response = await send_request_to_service(prefill_client_info,
"/completions", req_data,
request_id)
# Extract the needed fields
response_json = response.json()
kv_transfer_params = response_json.get('kv_transfer_params', {})
if kv_transfer_params:
req_data["kv_transfer_params"] = kv_transfer_params
# Get the next decode client in round-robin fashion
decode_client_info = get_next_client(request.app, 'decode')
logger.debug("Using %s %s", prefill_client_info, decode_client_info)
# Stream response from decode service
async def generate_stream():
async for chunk in stream_service_response(decode_client_info,
"/completions",
req_data,
request_id=request_id):
yield chunk
return StreamingResponse(generate_stream(),
media_type="application/json")
except Exception as e:
import sys
import traceback
exc_info = sys.exc_info()
print("Error occurred in disagg prefill proxy server"
" - completions endpoint")
print(e)
print("".join(traceback.format_exception(*exc_info)))
raise
@app.get("/healthcheck")
async def healthcheck():
"""Simple endpoint to check if the server is running."""
return {
"status": "ok",
"prefill_instances": len(app.state.prefill_clients),
"decode_instances": len(app.state.decode_clients)
}
if __name__ == '__main__':
global global_args
global_args = parse_args()
import uvicorn
uvicorn.run(app, host=global_args.host, port=global_args.port)
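For reference, a minimal sketch of driving the proxy once it is up, assuming the defaults used by the accuracy script above (proxy on localhost:8192, instances serving Qwen/Qwen3-0.6B); any other endpoint or model works the same way:

# Sketch: send one completion through the P/D path.
# The proxy forwards the prompt to a prefiller (max_tokens=1),
# then streams the decode from a decoder.
import openai

client = openai.OpenAI(api_key="EMPTY",
                       base_url="http://localhost:8192/v1")
completion = client.completions.create(model="Qwen/Qwen3-0.6B",
                                       prompt="Red Hat is ",
                                       temperature=0)
print(completion.choices[0].text)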


@ -0,0 +1,73 @@
# SPDX-License-Identifier: Apache-2.0
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
NixlConnectorMetadata)
from .utils import create_request, create_scheduler, create_vllm_config
def test_basic_interface():
"""Unit test for basic NixlConnector interface functionality."""
vllm_config = create_vllm_config()
scheduler = create_scheduler(vllm_config)
# 2 Full Blocks and 1 Half Block.
BLOCK_SIZE = vllm_config.cache_config.block_size
NUM_EXTERNAL_FULL_BLOCKS = 2
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
request = create_request(request_id=1,
num_tokens=NUM_TOKENS,
do_remote_prefill=True)
request_id = request.request_id
scheduler.add_request(request)
# Remote Prefill, triggers NixlConnectorMetadata.
scheduler_output = scheduler.schedule()
kv_connector_metadata = scheduler_output.kv_connector_metadata
assert kv_connector_metadata is not None
assert isinstance(kv_connector_metadata, NixlConnectorMetadata)
assert len(kv_connector_metadata.requests) == 1
assert request_id in kv_connector_metadata.requests
req_meta = kv_connector_metadata.requests[request_id]
for block_id, block in zip(
req_meta.local_block_ids, scheduler.kv_cache_manager.
single_type_manager.req_to_blocks[request_id]):
assert block_id == block.block_id
def test_prompt_less_than_block_size():
"""
Test that we can handle the case where the prompt is shorter than a block.
In this case, the P worker will send empty remote_block_ids.
The D worker should not schedule an async read in this case,
since there is nothing to pull.
"""
vllm_config = create_vllm_config()
scheduler = create_scheduler(vllm_config)
# Half of a block.
BLOCK_SIZE = vllm_config.cache_config.block_size
NUM_TOKENS = int(BLOCK_SIZE * 0.5)
# Request will have 0 remote blocks.
request = create_request(request_id=1,
num_tokens=NUM_TOKENS,
do_remote_prefill=True,
num_remote_blocks=0)
scheduler.add_request(request)
scheduler_output = scheduler.schedule()
# This request should not have to read async.
kv_connector_metadata = scheduler_output.kv_connector_metadata
assert kv_connector_metadata is not None
assert isinstance(kv_connector_metadata, NixlConnectorMetadata)
assert len(kv_connector_metadata.requests) == 0
# This request should be scheduled regularly.
assert len(scheduler_output.scheduled_new_reqs) == 1


@ -0,0 +1,181 @@
# SPDX-License-Identifier: Apache-2.0
import copy
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT
from vllm.v1.request import FinishReason, RequestStatus
from .utils import (assert_scheduler_empty, create_model_runner_output,
create_request, create_scheduler, create_vllm_config)
def test_basic_lifecycle():
"""Test lifecycle of a Remote Decode request."""
vllm_config = create_vllm_config()
scheduler = create_scheduler(vllm_config)
# 2 Full Blocks and 1 Half Block.
BLOCK_SIZE = vllm_config.cache_config.block_size
NUM_EXTERNAL_FULL_BLOCKS = 2
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
request = create_request(request_id=1,
max_tokens=1,
num_tokens=NUM_TOKENS,
do_remote_decode=True)
scheduler.add_request(request)
request_id = request.request_id
# STEP (1): Prefill.
# (1a): schedule()
scheduler_output = scheduler.schedule()
assert len(scheduler.running) == 1
assert len(scheduler_output.scheduled_new_reqs) == 1
# (1b): execute_model()
model_runner_output = create_model_runner_output(reqs=[request])
# (1c): update_from_output()
engine_core_outputs = scheduler.update_from_output(scheduler_output,
model_runner_output)
# Ensure the request is finished after 1 token.
assert request.is_finished()
assert request.status == RequestStatus.FINISHED_LENGTH_CAPPED
output = engine_core_outputs.outputs[0]
assert output.finish_reason == FinishReason.LENGTH
assert output.kv_transfer_params is not None
# Request freed in Scheduler and in Persistent Batch ...
assert request_id in scheduler.finished_req_ids
assert len(scheduler.running) == 0
assert len(scheduler.waiting) == 0
# ... but blocks should not be freed.
blocks = scheduler.kv_cache_manager.single_type_manager.req_to_blocks[
request_id]
for block in blocks:
assert block.ref_cnt == 1
# STEP (2): Send Finished to PB.
# (2a): schedule() - pass finished request to PB.
scheduler_output = scheduler.schedule()
assert len(scheduler.running) == 0
assert len(scheduler_output.finished_req_ids) == 1
assert request_id in scheduler_output.finished_req_ids
assert len(scheduler_output.scheduled_new_reqs) == 0
assert len(scheduler_output.scheduled_cached_reqs) == 0
assert len(scheduler.finished_req_ids) == 0
# (2b): execute_model()
model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT
# (2c): update_from_output()
scheduler.update_from_output(scheduler_output, model_runner_output)
# STEP (3): Finished sending.
# (3a): schedule() - pass finished request to PB.
scheduler_output = scheduler.schedule()
assert len(scheduler.running) == 0
assert len(scheduler_output.finished_req_ids) == 0
assert len(scheduler_output.scheduled_new_reqs) == 0
assert len(scheduler_output.scheduled_cached_reqs) == 0
assert len(scheduler.finished_req_ids) == 0
# (3b): execute_model()
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
model_runner_output.finished_sending = [request_id]
# (3c): update_from_output()
scheduler.update_from_output(scheduler_output, model_runner_output)
# Confirm we do not have any memory leaks after req lifecycle.
assert_scheduler_empty(scheduler)
def test_short_prompt_lifecycle():
"""Test lifecycle of a Remote Decode request with short prompt."""
vllm_config = create_vllm_config()
scheduler = create_scheduler(vllm_config)
# Not enough tokens for full block.
NUM_TOKENS = vllm_config.cache_config.block_size // 2
request = create_request(request_id=1,
max_tokens=1,
num_tokens=NUM_TOKENS,
do_remote_decode=True)
scheduler.add_request(request)
# STEP (1): Prefill.
# (1a): schedule()
scheduler_output = scheduler.schedule()
assert len(scheduler.running) == 1
assert len(scheduler_output.scheduled_new_reqs) == 1
# (1b): execute_model()
model_runner_output = create_model_runner_output(reqs=[request])
# (1c): update_from_output()
# Since tokens < block_size, there will be no kv xfer.
# So this should be cleaned up immediately.
_ = scheduler.update_from_output(scheduler_output, model_runner_output)
# Confirm we do not have any memory leaks after req lifecycle.
# We need one more call to schedule() to clear data for persistent batch.
_ = scheduler.schedule()
assert_scheduler_empty(scheduler)
def test_prefix_cache_lifecycle():
"""Test that remote decode params still works with a prefix cache hit."""
vllm_config = create_vllm_config()
scheduler = create_scheduler(vllm_config)
# Prime the KVCache.
BLOCK_SIZE = vllm_config.cache_config.block_size
NUM_EXTERNAL_FULL_BLOCKS = 3
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
request_normal = create_request(request_id=1, num_tokens=NUM_TOKENS)
scheduler.add_request(request_normal)
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output(reqs=[request_normal],
use_eos=True)
scheduler.update_from_output(scheduler_output, model_runner_output)
scheduler.schedule()
scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT)
#####################
# Actual Test: confirm we send all blocks.
# Step (1): Send the KV Transfer.
NUM_EXTERNAL_FULL_BLOCKS -= 1
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
request_remote = create_request(request_id=1,
num_tokens=NUM_TOKENS,
do_remote_decode=True)
scheduler.add_request(request_remote)
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output(reqs=[request_remote])
eco = scheduler.update_from_output(scheduler_output, model_runner_output)
kv_transfer_params = eco.outputs[0].kv_transfer_params
# Ensure we send all block ids, even if there is a cache hit.
assert (len(
kv_transfer_params["remote_block_ids"]) == NUM_EXTERNAL_FULL_BLOCKS)
# STEP (2): Ensure it is freed.
scheduler_output = scheduler.schedule()
scheduler.schedule()
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
model_runner_output.finished_sending = [request_remote.request_id]
scheduler.update_from_output(scheduler_output, model_runner_output)
_ = scheduler.schedule()
assert_scheduler_empty(scheduler)


@ -0,0 +1,342 @@
# SPDX-License-Identifier: Apache-2.0
import copy
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT
from vllm.v1.request import FinishReason, RequestStatus
from .utils import (assert_scheduler_empty, create_model_runner_output,
create_request, create_scheduler, create_vllm_config)
def test_basic_lifecycle():
"""Test lifecycle of a remote prefill."""
vllm_config = create_vllm_config()
scheduler = create_scheduler(vllm_config)
# 2 Full Blocks and 1 Half Block.
BLOCK_SIZE = vllm_config.cache_config.block_size
NUM_EXTERNAL_FULL_BLOCKS = 2
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
START_FREE_BLOCK_QUEUE_SIZE = (
scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks)
request = create_request(request_id=1,
num_tokens=NUM_TOKENS,
do_remote_prefill=True)
scheduler.add_request(request)
request_id = request.request_id
# STEP (1):
# (1a): schedule()
scheduler_output = scheduler.schedule()
# Nothing running and empty scheduler output.
assert len(scheduler.running) == 0
assert len(scheduler_output.scheduled_new_reqs) == 0
assert len(scheduler_output.scheduled_cached_reqs) == 0
assert len(scheduler_output.num_scheduled_tokens) == 0
assert scheduler_output.total_num_scheduled_tokens == 0
# Req waiting for KVs with no computed/scheduled toks ...
assert len(scheduler.waiting) == 1
assert request in scheduler.waiting
assert (request.status == RequestStatus.WAITING_FOR_REMOTE_KVS)
assert (request.num_computed_tokens == 0)
# ... but should have (uncached) blocks allocated to it.
block_pool = scheduler.kv_cache_manager.block_pool
assert (block_pool.free_block_queue.num_free_blocks
< START_FREE_BLOCK_QUEUE_SIZE)
assert len(block_pool.cached_block_hash_to_block) == 0
blocks = scheduler.kv_cache_manager.single_type_manager.req_to_blocks[
request_id]
for block in blocks:
assert block._block_hash is None
# (1b): forward()
model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT
# (1c): update_from_output()
engine_core_outputs = scheduler.update_from_output(scheduler_output,
model_runner_output)
assert len(engine_core_outputs.outputs) == 0
# STEP (2):
# (2a): schedule(): nothing happens!
scheduler_output = scheduler.schedule()
assert len(scheduler.waiting) == 1
assert len(scheduler.running) == 0
# (2b): forward(): request finishes recv.
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
model_runner_output.finished_recving = [request_id]
# (2c): update_from_output():
engine_core_outputs = scheduler.update_from_output(scheduler_output,
model_runner_output)
assert len(scheduler.waiting) == 1
assert (request_id in scheduler.finished_recving_kv_req_ids)
# STEP (3):
# (3a): schedule(): this should actually schedule.
scheduler_output = scheduler.schedule()
assert len(scheduler.running) == 1
# Confirm the blocks are actually allocated.
num_hashed_blocks = 0
blocks = scheduler.kv_cache_manager.single_type_manager.req_to_blocks[
request_id]
for block in blocks:
assert block.ref_cnt == 1
num_hashed_blocks += (1 if block._block_hash is not None else 0)
assert num_hashed_blocks == NUM_EXTERNAL_FULL_BLOCKS
# Confirm the rest of the prompt is scheduled in this step.
scheduled_req = scheduler_output.scheduled_new_reqs[0]
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[request_id]
num_computed_tokens = scheduled_req.num_computed_tokens
total_prompt_tokens = len(scheduled_req.prompt_token_ids)
assert (num_scheduled_tokens == total_prompt_tokens - num_computed_tokens)
# (3b): execute_model()
model_runner_output = create_model_runner_output([request])
# (3c): update_from_output()
scheduler.update_from_output(scheduler_output, model_runner_output)
# Step (4): Hit EOS.
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output([request], use_eos=True)
engine_core_outputs = scheduler.update_from_output(scheduler_output,
model_runner_output)
scheduler.schedule()
outputs = engine_core_outputs.outputs
assert len(outputs) == 1
output = outputs[0]
assert output.finish_reason == FinishReason.STOP
assert_scheduler_empty(scheduler)
def test_interleaved_lifecycle():
"""Test Remote Prefills Work Well With Other Requests."""
vllm_config = create_vllm_config()
scheduler = create_scheduler(vllm_config)
# 2 Full Blocks and 1 Half Block.
BLOCK_SIZE = vllm_config.cache_config.block_size
NUM_EXTERNAL_FULL_BLOCKS = 2
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
request_remote = create_request(request_id=1,
num_tokens=NUM_TOKENS,
do_remote_prefill=True)
request_local_a = create_request(
request_id=2,
num_tokens=NUM_TOKENS,
)
request_local_b = create_request(
request_id=3,
num_tokens=NUM_TOKENS,
)
# STEP 1: Regular request is running.
scheduler.add_request(request_local_a)
scheduler_output = scheduler.schedule()
assert len(scheduler.running) == 1
model_runner_output = create_model_runner_output([request_local_a])
scheduler.update_from_output(scheduler_output, model_runner_output)
# STEP 2: Add a local and remote request.
scheduler.add_request(request_local_b)
scheduler.add_request(request_remote)
scheduler_output = scheduler.schedule()
assert len(scheduler.running) == 2
assert len(scheduler.waiting) == 1
assert len(scheduler_output.scheduled_new_reqs) == 1
assert len(scheduler_output.scheduled_cached_reqs) == 1
model_runner_output = create_model_runner_output(
[request_local_a, request_local_b])
scheduler.update_from_output(scheduler_output, model_runner_output)
# STEP 3: continue running, KVs not arrived yet.
scheduler_output = scheduler.schedule()
assert len(scheduler.running) == 2
assert len(scheduler.waiting) == 1
assert len(scheduler_output.scheduled_new_reqs) == 0
assert len(scheduler_output.scheduled_cached_reqs) == 2
model_runner_output = create_model_runner_output(
reqs=[request_local_a, request_local_b])
scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.running) == 2
assert len(scheduler.waiting) == 1
assert len(scheduler_output.scheduled_new_reqs) == 0
assert len(scheduler_output.scheduled_cached_reqs) == 2
# STEP 4: KVs arrive.
scheduler_output = scheduler.schedule()
assert len(scheduler.running) == 2
assert len(scheduler.waiting) == 1
assert len(scheduler_output.scheduled_new_reqs) == 0
assert len(scheduler_output.scheduled_cached_reqs) == 2
model_runner_output = create_model_runner_output(
[request_local_a, request_local_b],
finished_recving=[request_remote.request_id])
scheduler.update_from_output(scheduler_output, model_runner_output)
# STEP 5: RECVed KVs are sent to ModelRunner.
scheduler_output = scheduler.schedule()
assert len(scheduler.running) == 3
assert len(scheduler.waiting) == 0
assert len(scheduler_output.scheduled_new_reqs) == 1
assert len(scheduler_output.scheduled_cached_reqs) == 2
model_runner_output = create_model_runner_output(
[request_local_a, request_local_b, request_remote])
scheduler.update_from_output(scheduler_output, model_runner_output)
# STEP 6: Hit EOS and free.
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output(
[request_local_a, request_local_b, request_remote],
use_eos=True,
)
scheduler.update_from_output(scheduler_output, model_runner_output)
scheduler.schedule()
assert_scheduler_empty(scheduler)
def test_no_spurious_prefix_caching():
"""
With P/D, blocks can be allocated but uncomputed for
multiple engine steps. This test confirms that we do
not accidentally have cache hits against uncomputed
blocks.
"""
vllm_config = create_vllm_config()
scheduler = create_scheduler(vllm_config)
# 2 and a half full external blocks.
BLOCK_SIZE = vllm_config.cache_config.block_size
NUM_EXTERNAL_FULL_BLOCKS = 2
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
# Both of these requests have prompts like [1,1,1,1,1, ...]
request_remote = create_request(
request_id=1,
num_tokens=NUM_TOKENS,
do_remote_prefill=True,
use_all_1s_for_prompt_tokens=True,
)
request_local = create_request(
request_id=2,
num_tokens=NUM_TOKENS,
do_remote_prefill=False,
use_all_1s_for_prompt_tokens=True,
)
# Schedule the remote prefill request. This should not
# cause any blocks to be cached.
scheduler.add_request(request_remote)
scheduler_output = scheduler.schedule()
scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT)
assert len(scheduler.waiting) == 1
# Schedule the local prefill request. This should
# cause blocks to be cached, but separately from the
# (uncomputed) blocks of the remote request.
scheduler.add_request(request_local)
scheduler_output = scheduler.schedule()
assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 1
local_blocks = scheduler.kv_cache_manager.single_type_manager.req_to_blocks[
request_local.request_id]
remote_blocks = scheduler.kv_cache_manager.single_type_manager.req_to_blocks[ # noqa: E501
request_remote.request_id]
# Local should have cached blocks (but not all due to preallocate).
num_hashed_blocks = 0
for block in local_blocks:
assert block.ref_cnt == 1
num_hashed_blocks += (1 if block._block_hash is not None else 0)
assert num_hashed_blocks > 0
# Remote blocks should not be cached.
for block in remote_blocks:
assert block.ref_cnt == 1
assert block._block_hash is None
def test_full_block_prompt():
"""Test that we handle a prompt that is the full block size."""
vllm_config = create_vllm_config()
scheduler = create_scheduler(vllm_config)
# Exactly 2 full blocks (prompt length is a multiple of the block size).
BLOCK_SIZE = vllm_config.cache_config.block_size
NUM_EXTERNAL_FULL_BLOCKS = 2
NUM_TOKENS = int(BLOCK_SIZE * NUM_EXTERNAL_FULL_BLOCKS)
request = create_request(request_id=1,
num_tokens=NUM_TOKENS,
do_remote_prefill=True)
scheduler.add_request(request)
request_id = request.request_id
# STEP (1): Initialize a recv.
scheduler_output = scheduler.schedule()
# All blocks should be allocated.
num_blocks = len(scheduler.kv_cache_manager.single_type_manager.
req_to_blocks[request_id])
assert num_blocks == NUM_EXTERNAL_FULL_BLOCKS
model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT
scheduler.update_from_output(scheduler_output, model_runner_output)
# STEP (2): Recv.
scheduler_output = scheduler.schedule()
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
model_runner_output.finished_recving = [request_id]
scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.waiting) == 1
assert (request_id in scheduler.finished_recving_kv_req_ids)
# STEP (3): Run as usual.
scheduler_output = scheduler.schedule()
# We need to recompute the final token of the prompt to generate
# the first new token, so we should not have a new block.
num_blocks = len(scheduler.kv_cache_manager.single_type_manager.
req_to_blocks[request_id])
assert num_blocks == NUM_EXTERNAL_FULL_BLOCKS
assert (scheduler_output.scheduled_new_reqs[0].num_computed_tokens ==
NUM_TOKENS - 1)
assert (scheduler_output.num_scheduled_tokens[request_id] == 1)
model_runner_output = create_model_runner_output([request])
scheduler.update_from_output(scheduler_output, model_runner_output)
# STEP (4): Hit EOS.
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output([request], use_eos=True)
engine_core_outputs = scheduler.update_from_output(scheduler_output,
model_runner_output)
scheduler.schedule()
outputs = engine_core_outputs.outputs
assert len(outputs) == 1
output = outputs[0]
assert output.finish_reason == FinishReason.STOP
assert_scheduler_empty(scheduler)


@ -0,0 +1,190 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import torch
from vllm import SamplingParams
from vllm.config import (CacheConfig, DeviceConfig, KVTransferConfig,
ModelConfig, SchedulerConfig, VllmConfig)
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
NixlKVTransferParams)
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec)
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request
from vllm.v1.structured_output import StructuredOutputManager
EOS_TOKEN_ID = 50256
def assert_scheduler_empty(scheduler: Scheduler):
"""Confirm the scheduler is "empty" - i.e. no leaks."""
# Scheduler Metadata.
assert len(scheduler.requests) == 0
assert len(scheduler.waiting) == 0
assert len(scheduler.running) == 0
assert len(scheduler.finished_req_ids) == 0
assert len(scheduler.finished_recving_kv_req_ids) == 0
assert len(scheduler._cached_reqs_data) == 0
# EncoderCacheManager.
assert len(scheduler.encoder_cache_manager.freed) == 0
assert len(scheduler.encoder_cache_manager.cached) == 0
# KVCache Manager.
assert len(
scheduler.kv_cache_manager.single_type_manager.req_to_blocks) == 0
assert len(scheduler.kv_cache_manager.req_to_block_hashes) == 0
assert len(
scheduler.kv_cache_manager.single_type_manager.num_cached_block) == 0
num_free_blocks = (
scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks)
assert num_free_blocks == (
scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1)
# NOTE(rob): only the ref count on the blocks will be 0. The hash
# value, etc. will remain, since we lazily evict for prefix caching.
for block in scheduler.kv_cache_manager.block_pool.blocks:
assert block.ref_cnt == 0
def create_vllm_config(
model: str = "facebook/opt-125m",
max_num_seqs: int = 16,
max_num_batched_tokens: int = 64,
block_size: int = 16,
) -> VllmConfig:
"""Initialize VllmConfig For Testing."""
scheduler_config = SchedulerConfig(
max_num_seqs=max_num_seqs,
max_num_batched_tokens=max_num_batched_tokens,
max_model_len=max_num_batched_tokens,
)
model_config = ModelConfig(
model=model,
task="auto",
tokenizer=model,
tokenizer_mode="auto",
trust_remote_code=True,
dtype="float16",
seed=42,
)
# Cache config, optionally force APC
cache_config = CacheConfig(
block_size=block_size,
gpu_memory_utilization=0.9,
swap_space=0,
cache_dtype="auto",
enable_prefix_caching=True,
)
kv_transfer_config = KVTransferConfig(
kv_connector="NixlConnector",
kv_role="kv_both",
)
return VllmConfig(scheduler_config=scheduler_config,
model_config=model_config,
cache_config=cache_config,
kv_transfer_config=kv_transfer_config,
device_config=DeviceConfig("cpu"))
def create_scheduler(
vllm_config: VllmConfig,
num_blocks: int = 10000,
) -> Scheduler:
"""Initialize Scheduler For Testing."""
block_size = vllm_config.cache_config.block_size
kv_cache_config = KVCacheConfig(
num_blocks=num_blocks, # A large number of blocks to hold all requests
tensors={},
kv_cache_groups=[
KVCacheGroupSpec(['layer'],
FullAttentionSpec(block_size, 1, 1, torch.float32,
False))
],
)
vllm_config.cache_config.num_gpu_blocks = num_blocks
return Scheduler(
vllm_config=vllm_config,
kv_cache_config=kv_cache_config,
log_stats=True,
structured_output_manager=StructuredOutputManager(vllm_config),
)
def create_request(
request_id: int,
num_tokens: int = 10,
max_tokens: int = 16,
do_remote_decode: bool = False,
do_remote_prefill: bool = False,
use_all_1s_for_prompt_tokens: bool = False,
num_remote_blocks: int = 3,
) -> Request:
"""Make dummy request for testing."""
if do_remote_decode:
assert not do_remote_prefill
kv_transfer_params = NixlKVTransferParams(do_remote_prefill=False,
do_remote_decode=True)
elif do_remote_prefill:
kv_transfer_params = NixlKVTransferParams(
do_remote_prefill=True,
do_remote_decode=False,
remote_engine_id="my-engine-id",
remote_block_ids=list(range(num_remote_blocks)),
remote_host="my-host",
remote_port=1234)
else:
kv_transfer_params = None
max_tokens = 1 if do_remote_decode else max_tokens
sampling_params = SamplingParams(max_tokens=max_tokens)
if use_all_1s_for_prompt_tokens:
prompt_token_ids = [1] * num_tokens
else:
prompt_token_ids = [i * request_id for i in range(num_tokens)]
req = Request(
request_id=f"id-{request_id}",
prompt_token_ids=prompt_token_ids,
sampling_params=sampling_params,
multi_modal_inputs=None,
multi_modal_placeholders=None,
multi_modal_hashes=None,
eos_token_id=EOS_TOKEN_ID,
arrival_time=0,
)
req.kv_transfer_params = kv_transfer_params
return req
def create_model_runner_output(
reqs: list[Request],
finished_sending: Optional[list[str]] = None,
finished_recving: Optional[list[str]] = None,
use_eos: bool = False,
) -> ModelRunnerOutput:
"""Make dummy model runner output for testing."""
# Make request data.
req_ids = [req.request_id for req in reqs]
req_id_to_index = {req_id: idx for idx, req_id in enumerate(req_ids)}
# Make sampled tokens.
sampled_token = EOS_TOKEN_ID if use_eos else 0
sampled_token_ids = [[sampled_token] for _ in req_ids]
# Make output data structure.
return ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=req_id_to_index,
sampled_token_ids=sampled_token_ids,
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
finished_sending=finished_sending,
finished_recving=finished_recving,
)


@ -8,6 +8,7 @@ import inspect
import json
import re
import textwrap
import uuid
import warnings
from collections import Counter
from contextlib import contextmanager
@ -3438,6 +3439,9 @@ class KVTransferConfig:
"""The KV connector for vLLM to transmit KV caches between vLLM instances.
"""
engine_id: str = str(uuid.uuid4())
"""The engine id for KV transfers."""
kv_buffer_device: Optional[str] = "cuda"
"""The device used by kv connector to buffer the KV cache.
Currently only support 'cuda'."""
@ -3448,7 +3452,7 @@ class KVTransferConfig:
kv_role: Optional[KVRole] = None
"""Whether this vLLM instance produces, consumes KV cache, or both. Choices
- are 'kv_producer', 'kv_consumer', and 'both'.
+ are 'kv_producer', 'kv_consumer', and 'kv_both'.
kv_rank: Optional[int] = None
"""The rank of this vLLM instance in the KV cache transfer. Typical value:


@ -105,3 +105,8 @@ KVConnectorFactory.register_connector(
"LMCacheConnectorV1",
"vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector",
"LMCacheConnectorV1")
KVConnectorFactory.register_connector(
"NixlConnector",
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector",
"NixlConnector")


@ -1,8 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
- KVConnectorBase_V1, KVConnectorRole)
+ KVConnectorBase_V1, KVConnectorRole, KVTransferParams)
- __all__ = [
- "KVConnectorRole",
- "KVConnectorBase_V1",
- ]
+ __all__ = ["KVConnectorRole", "KVConnectorBase_V1", "KVTransferParams"]


@ -23,7 +23,7 @@ The class provides the following primitives:
import enum
from abc import ABC, abstractmethod
from dataclasses import dataclass
- from typing import TYPE_CHECKING
+ from typing import TYPE_CHECKING, Any, Optional
import torch
@ -34,6 +34,7 @@ if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.request import Request
logger = init_logger(__name__)
@ -47,12 +48,34 @@ class KVConnectorRole(enum.Enum):
WORKER = 1
class KVTransferParams:
"""
Abstract KVTransferParams used to send KVTransfer
parameters between instances of vLLM.
Specific instances of KVConnector customize this
method for serializing / deserializing msgs sent
via the HTTP protocol.
"""
@staticmethod
def from_raw_dict(
raw_dict: Optional[dict[str,
Any]]) -> Optional["KVTransferParams"]:
return None
@dataclass
class KVConnectorMetadata:
"""
Abstract Metadata used to communicate between the
Scheduler KVConnector and Worker KVConnector.
"""
pass
class KVConnectorBase_V1(ABC):
_KVTransferParams = KVTransferParams
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
logger.warning(
@ -66,6 +89,10 @@ class KVConnectorBase_V1(ABC):
def role(self) -> KVConnectorRole:
return self._role
# ==============================
# Worker-side methods
# ==============================
def bind_connector_metadata(
self, connector_metadata: KVConnectorMetadata) -> None:
"""Set the connector metadata from the scheduler.
@ -97,9 +124,15 @@ class KVConnectorBase_V1(ABC):
"""
return self._connector_metadata
# ==============================
# Worker-side methods
# ==============================
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
"""
Initialize with the KV caches. Useful for pre-registering the
KV Caches in the KVConnector (e.g. for NIXL).
Args:
kv_caches: dictionary of layer names to kv cache tensors.
"""
return
@abstractmethod
def start_load_kv(self, forward_context: "ForwardContext",
@ -162,15 +195,37 @@ class KVConnectorBase_V1(ABC):
"""
pass
def get_finished(
self, finished_req_ids: set[str]
) -> tuple[Optional[set[str]], Optional[set[str]]]:
"""
Notifies the worker-side connector of the ids of requests that
have finished generating tokens.
Returns:
ids of requests that have finished asynchronous transfer
(recving, sending).
The finished saves/sends req ids must belong to a set provided in a
call to this method (this call or a prior one).
"""
return None, None
# ==============================
# Scheduler-side methods
# ==============================
def set_kv_transfer_params(self, request: "Request"):
"""Parse raw KV Transfer params."""
assert request.kv_transfer_params is None
kv_transfer_params = self._KVTransferParams.from_raw_dict(
request.raw_kv_transfer_params)
request.kv_transfer_params = kv_transfer_params
@abstractmethod
def get_num_new_matched_tokens(
self,
request: "Request",
num_computed_tokens: int,
- ) -> int:
+ ) -> tuple[int, bool]:
"""
Get number of new tokens that can be loaded from the
external KV cache beyond the num_computed_tokens.
@ -181,13 +236,16 @@ class KVConnectorBase_V1(ABC):
computed tokens for this request
Returns:
- the number of tokens that can be loaded from the
- external KV cache beyond what is already computed.
+ * the number of tokens that can be loaded from the
+ external KV cache beyond what is already computed.
+ * true if external KV cache tokens will be loaded
+ asynchronously (between scheduler steps).
"""
pass
@abstractmethod
def update_state_after_alloc(self, request: "Request",
blocks: "KVCacheBlocks",
num_external_tokens: int):
"""
Update KVConnector state after block allocation.
@ -207,3 +265,20 @@ class KVConnectorBase_V1(ABC):
scheduler_output (SchedulerOutput): the scheduler output object.
"""
pass
def request_finished(
self,
request: "Request",
block_ids: list[int],
) -> tuple[bool, Optional[dict[str, Any]]]:
"""
Called when a request has finished, before its blocks are freed.
Returns:
True if the request is being saved/sent asynchronously and blocks
should not be freed until the request_id is returned from
get_finished().
Optional KVTransferParams to be included in the request outputs
returned by the engine.
"""
return False, None
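To make the updated contract concrete, here is a minimal sketch (not part of this diff; NoOpConnector is a hypothetical name) of a connector that opts out of both asynchronous loading and asynchronous saving:

# Sketch: a no-op connector against the updated V1 interface.
from typing import Any, Optional

from vllm.distributed.kv_transfer.kv_connector.v1.base import (
    KVConnectorBase_V1, KVConnectorMetadata)


class NoOpConnector(KVConnectorBase_V1):

    # Scheduler side: no external tokens, nothing loads asynchronously.
    def get_num_new_matched_tokens(self, request,
                                   num_computed_tokens) -> tuple[int, bool]:
        return 0, False

    def update_state_after_alloc(self, request, blocks,
                                 num_external_tokens) -> None:
        pass

    def build_connector_meta(self, scheduler_output) -> KVConnectorMetadata:
        return KVConnectorMetadata()

    # Returning False frees blocks immediately; None means no
    # kv_transfer_params are attached to the request outputs.
    def request_finished(self, request,
                         block_ids) -> tuple[bool, Optional[dict[str, Any]]]:
        return False, None

    # Worker side: nothing to load or save.
    def start_load_kv(self, forward_context, **kwargs) -> None:
        pass

    def wait_for_layer_load(self, layer_name: str) -> None:
        pass

    def save_kv_layer(self, layer_name, kv_layer, attn_metadata,
                      **kwargs) -> None:
        pass

    def wait_for_save(self) -> None:
        pass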


@ -13,6 +13,7 @@ from vllm.v1.core.sched.output import SchedulerOutput
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.request import Request
logger = init_logger(__name__)
@ -92,7 +93,7 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
self,
request: "Request",
num_computed_tokens: int,
- ) -> int:
+ ) -> tuple[int, bool]:
"""
Get number of new tokens that can be loaded from the
external KV cache beyond the num_computed_tokens.
@ -107,9 +108,10 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
external KV cache beyond what is already computed.
"""
return self._lmcache_engine.get_num_new_matched_tokens(
- request, num_computed_tokens)
+ request, num_computed_tokens), False
def update_state_after_alloc(self, request: "Request",
blocks: "KVCacheBlocks",
num_external_tokens: int):
"""
Update KVConnector state after block allocation.


@ -0,0 +1,805 @@
# SPDX-License-Identifier: Apache-2.0
import contextlib
import math
import threading
import time
import uuid
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Iterator, Optional
import msgspec
import torch
import zmq
from vllm import envs
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole, KVTransferParams)
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
get_tp_group)
from vllm.logger import init_logger
from vllm.utils import round_down
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.request import RequestStatus
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.request import Request
GET_META_MSG = b"get_meta_msg"
logger = init_logger(__name__)
# Lazy import nixl_wrapper to avoid loading nixl_bindings if nixl is not used
try:
from nixl._api import nixl_agent as NixlWrapper
logger.info("NIXL is available")
except ImportError:
logger.warning("NIXL is not available")
NixlWrapper = None
@dataclass
class NixlKVTransferParams(KVTransferParams):
def __init__(
self,
do_remote_prefill: bool,
do_remote_decode: bool,
remote_block_ids: Optional[list[int]] = None,
remote_host: Optional[str] = None,
remote_port: Optional[int] = None,
remote_engine_id: Optional[str] = None,
):
self.do_remote_prefill = do_remote_prefill
self.do_remote_decode = do_remote_decode
self.remote_block_ids = remote_block_ids
self.remote_host = remote_host
self.remote_port = remote_port
self.remote_engine_id = remote_engine_id
@staticmethod
def from_raw_dict(
raw_dict: Optional[dict[str,
Any]]) -> Optional["NixlKVTransferParams"]:
# If no raw transfer params passed, return None.
if raw_dict is None:
return None
# Validate the request is formatted properly.
if (("do_remote_prefill" not in raw_dict)
or ("do_remote_decode" not in raw_dict)
or ("remote_block_ids" not in raw_dict)
or ("remote_host" not in raw_dict)
or ("remote_port" not in raw_dict)
or ("remote_engine_id" not in raw_dict)):
logger.warning(
"Got invalid KVTransferParams: %s. This "
"request will not utilize KVTransfer", raw_dict)
return None
return NixlKVTransferParams(
do_remote_prefill=raw_dict["do_remote_prefill"],
do_remote_decode=raw_dict["do_remote_decode"],
remote_block_ids=raw_dict["remote_block_ids"],
remote_host=raw_dict["remote_host"],
remote_port=raw_dict["remote_port"],
remote_engine_id=raw_dict["remote_engine_id"],
)
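# Illustration (not part of this diff): the dict the toy proxy attaches
# to prefill requests round-trips through the validation above, e.g.
#
#   raw = {"do_remote_prefill": False, "do_remote_decode": True,
#          "remote_engine_id": None, "remote_block_ids": None,
#          "remote_host": None, "remote_port": None}
#   assert NixlKVTransferParams.from_raw_dict(raw) is not None
#   # Any missing key causes rejection:
#   assert NixlKVTransferParams.from_raw_dict({}) is None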
class NixlAgentMetadata(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
# required for @cached_property.
dict=True):
engine_id: str
agent_metadata: bytes
kv_caches_base_addr: list[int]
num_blocks: int
@dataclass
class ReqMeta:
local_block_ids: list[int]
remote_block_ids: list[int]
remote_host: str
remote_port: int
remote_engine_id: str
class NixlConnectorMetadata(KVConnectorMetadata):
def __init__(self):
self.requests: dict[str, ReqMeta] = {}
def add_new_req(
self,
request_id: str,
local_block_ids: list[int],
kv_transfer_params: NixlKVTransferParams,
):
assert request_id not in self.requests
assert kv_transfer_params.remote_block_ids is not None
assert kv_transfer_params.remote_engine_id is not None
assert kv_transfer_params.remote_host is not None
assert kv_transfer_params.remote_port is not None
self.requests[request_id] = ReqMeta(
local_block_ids=local_block_ids,
remote_block_ids=kv_transfer_params.remote_block_ids,
remote_engine_id=kv_transfer_params.remote_engine_id,
remote_host=kv_transfer_params.remote_host,
remote_port=kv_transfer_params.remote_port,
)
class NixlConnector(KVConnectorBase_V1):
_KVTransferParams: type[NixlKVTransferParams] = NixlKVTransferParams
def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
assert vllm_config.kv_transfer_config is not None
self.engine_id = vllm_config.kv_transfer_config.engine_id
if role == KVConnectorRole.SCHEDULER:
self.connector_scheduler: Optional[NixlConnectorScheduler] = \
NixlConnectorScheduler(vllm_config, str(self.engine_id))
self.connector_worker: Optional[NixlConnectorWorker] = None
elif role == KVConnectorRole.WORKER:
self.connector_scheduler = None
self.connector_worker = NixlConnectorWorker(str(self.engine_id))
############################################################
# Scheduler Side Methods
############################################################
def get_num_new_matched_tokens(
self, request: "Request",
num_computed_tokens: int) -> tuple[int, bool]:
assert self.connector_scheduler is not None
return self.connector_scheduler.get_num_new_matched_tokens(
request, num_computed_tokens)
def update_state_after_alloc(self, request: "Request",
blocks: "KVCacheBlocks",
num_external_tokens: int):
assert self.connector_scheduler is not None
return self.connector_scheduler.update_state_after_alloc(
request, blocks, num_external_tokens)
def build_connector_meta(
self,
scheduler_output: SchedulerOutput,
) -> KVConnectorMetadata:
assert self.connector_scheduler is not None
return self.connector_scheduler.build_connector_meta(scheduler_output)
def request_finished(
self,
request: "Request",
block_ids: list[int],
) -> tuple[bool, Optional[dict[str, Any]]]:
assert self.connector_scheduler is not None
return self.connector_scheduler.request_finished(request, block_ids)
############################################################
# Worker Side Methods
############################################################
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
assert self.connector_worker is not None
self.connector_worker.register_kv_caches(kv_caches)
def get_finished(self,
finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
"""Get the finished recving and sending requests."""
assert self.connector_worker is not None
return self.connector_worker.get_finished()
def start_load_kv(self, forward_context: "ForwardContext",
**kwargs) -> None:
assert self.connector_worker is not None
assert isinstance(self._connector_metadata, NixlConnectorMetadata)
self.connector_worker.start_load_kv(self._connector_metadata)
def wait_for_layer_load(self, layer_name: str) -> None:
"""NixlConnector does not do layerwise saving."""
pass
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata", **kwargs) -> None:
"""NixlConnector does not save explicitly."""
pass
def wait_for_save(self):
"""NixlConnector does not save explicitly."""
pass
class NixlConnectorScheduler:
"""Implementation of Scheduler side methods"""
def __init__(self, vllm_config: VllmConfig, engine_id: str):
self.vllm_config = vllm_config
self.block_size = vllm_config.cache_config.block_size
self.engine_id = engine_id
logger.info("Initializing NIXL Scheduler %s", engine_id)
# Requests that need to start recv.
# New requests are added by update_state_after_alloc in
# the scheduler. Used to make metadata passed to Worker.
self._reqs_need_recv: dict[str, tuple[Request, list[int]]] = {}
def get_num_new_matched_tokens(
self, request: "Request",
num_computed_tokens: int) -> tuple[int, bool]:
"""
For remote prefill, pull all prompt blocks from remote
asynchronously relative to engine execution.
Args:
request (Request): the request object.
num_computed_tokens (int): the number of locally
computed tokens for this request
Returns:
* the number of tokens that can be loaded from the
external KV cache beyond what is already computed.
* true if the external KV cache tokens will be loaded
asynchronously (between scheduler steps).
"""
# No KVTransfer for this request.
if request.kv_transfer_params is None:
return 0, False
assert isinstance(request.kv_transfer_params, NixlKVTransferParams)
# Remote prefill: get all prompt blocks from remote.
if request.kv_transfer_params.do_remote_prefill:
assert num_computed_tokens % self.block_size == 0
rounded_num_prompt_tokens = round_down(
len(request.prompt_token_ids), self.block_size)
count = max(rounded_num_prompt_tokens - num_computed_tokens, 0)
return count, count > 0
return 0, False
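# Worked example (hypothetical numbers): with block_size=16, a 40-token
# prompt and num_computed_tokens=16, round_down(40, 16) == 32, so this
# returns (32 - 16, True) == (16, True): 16 tokens are pulled from the
# remote asynchronously and the trailing 8 tokens of the partial block
# are recomputed locally.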
def update_state_after_alloc(self, request: "Request",
blocks: "KVCacheBlocks",
num_external_tokens: int):
if request.kv_transfer_params is None:
return
assert isinstance(request.kv_transfer_params, NixlKVTransferParams)
if request.kv_transfer_params.do_remote_prefill:
# NOTE(rob): if the prompt is shorter than block_size, there are
# no remote blocks, since the remote only sends fully computed
# blocks; skip recving for this request. num_external_tokens
# should be 0 if there are no remote blocks.
if request.kv_transfer_params.remote_block_ids:
# Get unhashed blocks to pull from remote.
self._reqs_need_recv[request.request_id] = (
request, blocks.get_unhashed_block_ids())
else:
assert num_external_tokens == 0
# Only trigger 1 KV transfer per request.
request.kv_transfer_params.do_remote_prefill = False
def build_connector_meta(
self,
scheduler_output: SchedulerOutput,
) -> KVConnectorMetadata:
meta = NixlConnectorMetadata()
# Loop through scheduled reqs and convert to ReqMeta.
for req_id, (req, block_ids) in self._reqs_need_recv.items():
assert isinstance(req.kv_transfer_params, NixlKVTransferParams)
meta.add_new_req(
request_id=req_id,
local_block_ids=block_ids,
kv_transfer_params=req.kv_transfer_params,
)
# Clear the list once workers start the transfers
self._reqs_need_recv.clear()
return meta
def request_finished(
self,
request: "Request",
block_ids: list[int],
) -> tuple[bool, Optional[dict[str, Any]]]:
"""
Once a request is finished, determine whether request blocks
should be freed now or will be sent asynchronously and freed later.
"""
if request.kv_transfer_params is None:
return False, None
assert isinstance(request.kv_transfer_params, NixlKVTransferParams)
if ((not request.kv_transfer_params.do_remote_decode)
or (request.status != RequestStatus.FINISHED_LENGTH_CAPPED)):
return False, None
# Get computed blocks.
all_full = request.num_computed_tokens % self.block_size == 0
computed_block_ids = (block_ids if all_full else block_ids[:-1])
# If prompt < block_size, no xfer so free blocks immediately.
delay_free_blocks = len(computed_block_ids) > 0
return delay_free_blocks, NixlKVTransferParams(
do_remote_prefill=True,
do_remote_decode=False,
remote_block_ids=computed_block_ids,
remote_engine_id=self.engine_id,
remote_host=envs.VLLM_NIXL_SIDE_CHANNEL_HOST,
remote_port=envs.VLLM_NIXL_SIDE_CHANNEL_PORT,
).__dict__
class NixlConnectorWorker:
"""Implementation of Worker side methods"""
def __init__(self, engine_id: str):
if NixlWrapper is None:
logger.error("NIXL is not available")
raise RuntimeError("NIXL is not available")
logger.info("Initializing NIXL wrapper")
logger.info("Initializing NIXL worker %s", engine_id)
# Agent.
self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None)
# Map of engine_id -> agent_name.
self._remote_agents: dict[str, str] = {}
# Metadata.
self.engine_id = engine_id
self.rank = get_tensor_model_parallel_rank()
self.world_size = get_tensor_model_parallel_world_size()
self.tp_group = get_tp_group()
# KV Caches and nixl tracking data.
self.kv_caches: dict[str, torch.Tensor] = {}
# Map of engine_id -> kv_caches_base_addr
self.kv_caches_base_addr: dict[str, list[int]] = {}
# Number of NIXL regions. Currently one region per cache
# (so 1 per layer for MLA, otherwise 2 per layer)
self.num_regions = 0
# nixl_prepped_dlist_handle (int).
self.src_xfer_side_handle: int = 0
# Map of engine_id -> nixl_prepped_dlist_handle (int)].
self.dst_xfer_side_handles: dict[str, int] = {}
# Map of engine_id -> num_blocks.
self.dst_num_blocks: dict[str, int] = {}
self._registered_descs: list[Any] = []
# In progress transfers.
# [req_id -> list[handle]]
self._recving_transfers: defaultdict[str, list[Any]] = defaultdict(
list[Any])
# Complete transfer tracker. Used by the rank 0 to track finished
# transactions on ranks 1 to N-1.
# [req_id -> count]
self._done_recving_count: defaultdict[str,
int] = defaultdict(lambda: 0)
self._done_sending_count: defaultdict[str,
int] = defaultdict(lambda: 0)
# Background thread for establishing new connections.
self._nixl_handshake_listener_t: Optional[threading.Thread] = None
@staticmethod
def _nixl_handshake_listener(metadata: NixlAgentMetadata,
ready_event: threading.Event, rank: int):
"""Background thread for getting new NIXL handshakes."""
# NOTE(rob): this is a simple implementation. We will move
# to a better approach like an ETCD server in the future.
# NOTE(rob): to support heterogeneous TP, we will have to
# move this into the scheduler rather than worker, since
# each rank needs the metadata of all other ranks (whereas
# in this setup, each rank only gets one other rank's meta).
encoder = msgspec.msgpack.Encoder()
encoded_data = encoder.encode(metadata)
size_in_bytes = len(encoded_data)
logger.debug("Size of encoded NixlAgentMetadata: %s bytes",
str(size_in_bytes))
# Listen for new requests for metadata.
host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST
# NOTE(rob): we need each rank to have a unique port. This
# hack keeps us moving. We will switch when moving to etcd
# or to a single ZMQ socket in the scheduler.
port = envs.VLLM_NIXL_SIDE_CHANNEL_PORT + rank
path = f"tcp://{host}:{port}"
logger.debug("Starting listening on path: %s", path)
with zmq_ctx(zmq.ROUTER, path) as sock:
ready_event.set()
while True:
identity, _, msg = sock.recv_multipart()
if msg != GET_META_MSG:
logger.warning(
"Connection listener got unexpected message %s", msg)
sock.send_multipart((identity, b"", encoded_data))
def _nixl_handshake(self, host: str, port: int):
"""Do a NIXL handshake with a remote instance."""
start_time = time.perf_counter()
# NOTE(rob): we need each rank to have a unique port. This is
# a hack to keep us moving. We will switch when moving to etcd
# or where we have a single ZMQ socket in the scheduler.
path = f"tcp://{host}:{port + self.rank}"
logger.debug("Querying metadata on path: %s", path)
with zmq_ctx(zmq.REQ, path) as sock:
# Send query for the request.
sock.send(GET_META_MSG)
metadata_bytes = sock.recv()
decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
metadata = decoder.decode(metadata_bytes)
got_metadata_time = time.perf_counter()
# Register Remote agent.
self.add_remote_agent(metadata)
setup_agent_time = time.perf_counter()
logger.debug("NIXL handshake: get metadata took: %s",
got_metadata_time - start_time)
logger.debug("NIXL handshake: add agent took: %s",
setup_agent_time - got_metadata_time)
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
"""Register the KV Cache data in nixl."""
_, first_kv_cache = next(iter(kv_caches.items()))
kv_elem_size = first_kv_cache.element_size()
# TODO(tms): Find a more robust way to detect and handle MLA
use_mla = len(first_kv_cache.shape) == 3
if use_mla:
# MLA case.
self.num_blocks = first_kv_cache.shape[0]
block_rank = 2 # [block_size, latent_dim]
block_shape = first_kv_cache.shape[-block_rank:]
else:
# [2 (k and v), num_blocks, ...]
self.num_blocks = first_kv_cache.shape[1]
block_rank = 3 # [block_size, kv_heads, head_dim]
block_shape = first_kv_cache.shape[-block_rank:]
# TODO(tms): self.block_len needs to be per-layer for sliding window,
# hybrid attn, etc
self.block_len = kv_elem_size * math.prod(block_shape)
logger.debug("Registering KV_Caches. use_mla: %s, shape %s", use_mla,
first_kv_cache.shape)
logger.debug("num_blocks: %s, block_shape: %s", self.num_blocks,
block_shape)
logger.debug("Per layer kv cache size: %s", first_kv_cache.shape)
self.dst_num_blocks[self.engine_id] = self.num_blocks
self.kv_caches = kv_caches
kv_caches_base_addr = []
caches_data = []
# Note(tms): I modified this from the original region setup code.
# K and V are now in different regions. Advantage is that we can
# elegantly support MLA and any cases where the K and V tensors
# are non-contiguous (it's not locally guaranteed that they will be)
# Disadvantage is that the encoded NixlAgentMetadata is now larger
# (roughly 8KB vs 5KB).
for cache_or_caches in kv_caches.values():
# Normalize to always be a list of caches
cache_list = [cache_or_caches] if use_mla else cache_or_caches
for cache in cache_list:
base_addr = cache.data_ptr()
region_len = self.num_blocks * self.block_len
caches_data.append((base_addr, region_len, self.rank, ""))
kv_caches_base_addr.append(base_addr)
self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr
self.num_regions = len(caches_data)
descs = self.nixl_wrapper.get_reg_descs(caches_data, "VRAM")
logger.debug("Registering descs: %s", caches_data)
self.nixl_wrapper.register_memory(descs)
logger.debug("Done registering descs")
self._registered_descs.append(descs)
# After KV Caches registered, listen for new connections.
metadata = NixlAgentMetadata(
engine_id=self.engine_id,
agent_metadata=self.nixl_wrapper.get_agent_metadata(),
kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
num_blocks=self.num_blocks,
)
ready_event = threading.Event()
self._nixl_handshake_listener_t = threading.Thread(
target=self._nixl_handshake_listener,
args=(metadata, ready_event, self.rank),
daemon=True,
name="nixl_handshake_listener")
self._nixl_handshake_listener_t.start()
ready_event.wait()
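# Shape sketch (assumed non-MLA layout): a cache of shape
# [2, num_blocks, block_size, kv_heads, head_dim] yields block_rank=3,
# block_shape=(block_size, kv_heads, head_dim) and
# block_len = kv_elem_size * block_size * kv_heads * head_dim bytes.
# Each layer's K and V tensors are registered as separate NIXL regions,
# so num_regions == 2 * num_layers (num_layers for MLA).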
def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata):
engine_id = nixl_agent_meta.engine_id
if engine_id in self._remote_agents:
return
self._remote_agents[engine_id] = self.nixl_wrapper.add_remote_agent(
nixl_agent_meta.agent_metadata)
self.kv_caches_base_addr[
engine_id] = nixl_agent_meta.kv_caches_base_addr
# Create src descs and xfer side handles.
blocks_data = []
for base_addr in self.kv_caches_base_addr[self.engine_id]:
for block_id in range(self.num_blocks):
block_offset = block_id * self.block_len
# (addr, len, device id)
blocks_data.append(
(base_addr + block_offset, self.block_len, self.rank))
logger.debug("Created %s blocks for src engine %s and rank %s",
len(blocks_data), self.engine_id, self.rank)
# Register with NIXL.
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist(
"NIXL_INIT_AGENT", descs)
# Create dst descs and xfer side handles.
self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks
blocks_data = []
for base_addr in self.kv_caches_base_addr[engine_id]:
for block_id in range(nixl_agent_meta.num_blocks):
block_offset = block_id * self.block_len
# (addr, len, device id)
blocks_data.append(
(base_addr + block_offset, self.block_len, self.rank))
logger.debug("Created %s blocks for dst engine %s and rank %s",
len(blocks_data), engine_id, self.rank)
# Register with NIXL.
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
self.dst_xfer_side_handles[
engine_id] = self.nixl_wrapper.prep_xfer_dlist(
self._remote_agents[engine_id], descs)
def get_finished(self) -> tuple[set[str], set[str]]:
"""
Get requests that are done sending or recving.
In a TP>1 setup, each rank exchanges KVs with its counterpart
ranks independently. get_finished() runs in each worker and
creates the done_sending and done_recving sets that Rank 0
sends to the scheduler via ModelRunnerOutput. To ensure
transfers are done before requests are marked finished,
Ranks 1 to N-1 notify Rank 0 once their transactions complete,
and Rank 0 returns the finished sets to the Scheduler only
once all ranks are done.
"""
done_sending = self._get_new_notifs()
done_recving = self._pop_done_transfers(self._recving_transfers)
if len(done_sending) > 0 or len(done_recving) > 0:
logger.debug(
"Rank %s, get_finished: %s requests done sending "
"and %s requests done recving", self.rank, len(done_sending),
len(done_recving))
if self.world_size == 1:
return done_sending, done_recving
# Rank 0: get finished from all other ranks.
if self.rank == 0:
for req_id in done_sending:
self._done_sending_count[req_id] += 1
for req_id in done_recving:
self._done_recving_count[req_id] += 1
# Keep track of how many other ranks have finished.
other_ranks_finished_ids: list[str] = []
for i in range(1, self.world_size):
other_ranks_finished_ids.extend(
self.tp_group.recv_object(src=i))
for req_id in other_ranks_finished_ids:
if (req_id in self._done_recving_count
or req_id in self._recving_transfers):
self._done_recving_count[req_id] += 1
else:
self._done_sending_count[req_id] += 1
# Return ids that finished on all ranks to the scheduler.
all_done_recving: set[str] = set()
for req_id in list(self._done_recving_count.keys()):
if self._done_recving_count[req_id] == self.world_size:
del self._done_recving_count[req_id]
all_done_recving.add(req_id)
all_done_sending: set[str] = set()
for req_id in list(self._done_sending_count.keys()):
if self._done_sending_count[req_id] == self.world_size:
del self._done_sending_count[req_id]
all_done_sending.add(req_id)
return all_done_sending, all_done_recving
# Ranks 1 to N-1: send finished ids to Rank 0.
else:
finished_req_ids = list(done_recving.union(done_sending))
self.tp_group.send_object(finished_req_ids, dst=0)
# Unused as only Rank 0 results are sent to scheduler.
return done_sending, done_recving
def _get_new_notifs(self) -> set[str]:
"""Get req_ids which got a remote xfer message."""
notified_req_ids: set[str] = set()
for req_ids in self.nixl_wrapper.get_new_notifs().values():
for req_id in req_ids:
assert req_id not in notified_req_ids
notified_req_ids.add(req_id.decode("utf-8"))
return notified_req_ids
def _pop_done_transfers(self, transfers: dict[str, list[int]]) -> set[str]:
"""
Pop completed xfers by checking for DONE state.
Args:
transfers: dict of req_id -> list[running_xfer]
Returns:
set of req_ids that have all done xfers
"""
done_req_ids: set[str] = set()
for req_id, handles in list(transfers.items()):
running_reqs = []
for handle in handles:
xfer_state = self.nixl_wrapper.check_xfer_state(handle)
if xfer_state == "DONE":
# TODO ptarasiewicz: why abort is throwing errors?
# self.nixl_wrapper.release_xfer_handle(handle)
continue
if xfer_state == "PROC":
running_reqs.append(handle)
else:
raise RuntimeError("Transfer failed with state %s",
xfer_state)
if len(running_reqs) == 0:
done_req_ids.add(req_id)
del transfers[req_id]
else:
transfers[req_id] = running_reqs
return done_req_ids
def start_load_kv(self, metadata: NixlConnectorMetadata):
"""
Start loading by triggering non-blocking nixl_xfer.
We check for these transfers to complete in each step().
"""
for req_id, meta in metadata.requests.items():
logger.debug(
"start_load_kv for request %s from remote engine %s. "
"Num local_block_ids: %s. Num remote_block_ids: %s. ", req_id,
meta.remote_engine_id, len(meta.local_block_ids),
len(meta.remote_block_ids))
self._read_blocks(
request_id=req_id,
dst_engine_id=meta.remote_engine_id,
local_block_ids=meta.local_block_ids,
remote_block_ids=meta.remote_block_ids,
remote_host=meta.remote_host,
remote_port=meta.remote_port,
)
def _read_blocks(
self,
local_block_ids: list[int],
remote_block_ids: list[int],
remote_host: str,
remote_port: int,
dst_engine_id: str,
request_id: str,
):
# NOTE(rob): this takes ~2s. We need to get this off the hotpath.
if dst_engine_id not in self._remote_agents:
self._nixl_handshake(remote_host, remote_port)
# NOTE(rob): having the staging blocks on the READER side is not
# going to work well, since we will have to rearrange tensors after
# we detect the transfer is complete (which means we cannot easily
# make the read transfer async). If we want to make "READ" happen
# cleanly, then we will need to have the staging blocks on the
# remote side.
# NOTE(rob): according to nvidia the staging blocks are used to
# saturate IB with heterogeneous TP sizes. We should remove the staging
# blocks until we are ready.
# Full prefix cache hit: do not need to read remote blocks,
# just notify P worker that we have the blocks we need.
num_local_blocks = len(local_block_ids)
if num_local_blocks == 0:
self.nixl_wrapper.send_notif(dst_engine_id,
notif_msg=request_id.encode("utf-8"))
return
# Partial prefix cache hit: just read uncomputed blocks.
num_remote_blocks = len(remote_block_ids)
assert num_local_blocks <= num_remote_blocks
if num_local_blocks < num_remote_blocks:
remote_block_ids = remote_block_ids[-num_local_blocks:]
# Get side handles.
local_xfer_side_handle = self.src_xfer_side_handle
remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id]
# Get descs ids.
remote_block_descs_ids = self._get_block_descs_ids(
dst_engine_id, remote_block_ids)
local_block_descs_ids = self._get_block_descs_ids(
self.engine_id, local_block_ids)
assert len(local_block_descs_ids) == len(remote_block_descs_ids)
# Prepare transfer with Nixl.
handle = self.nixl_wrapper.make_prepped_xfer(
"READ",
local_xfer_side_handle,
local_block_descs_ids,
remote_xfer_side_handle,
remote_block_descs_ids,
notif_msg=request_id.encode("utf-8"),
)
# Begin async xfer.
self.nixl_wrapper.transfer(handle)
# Use handle to check completion in future step().
self._recving_transfers[request_id].append(handle)
def _get_block_descs_ids(self, engine_id: str,
block_ids: list[int]) -> list[int]:
"""Get the descs ids for a set of block ids."""
# range(1) for MLA, range(2) otherwise.
region_ids = range(self.num_regions)
num_blocks = self.dst_num_blocks[engine_id]
# Compute the desc ids for each block.
descs_ids: list[int] = []
for reg_id in region_ids:
for block_id in block_ids:
descs_ids.append(reg_id * num_blocks + block_id)
return descs_ids
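# Worked example (hypothetical numbers): with num_regions=2 (separate
# K and V regions) and num_blocks=4 on the target engine, block_ids=[1, 3]
# map to descriptor ids [0*4+1, 0*4+3, 1*4+1, 1*4+3] == [1, 3, 5, 7],
# i.e. the same block offsets repeated region-major.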
@contextlib.contextmanager
def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]:
"""Context manager for a ZMQ socket"""
ctx: Optional[zmq.Context] = None
try:
ctx = zmq.Context() # type: ignore[attr-defined]
if socket_type == zmq.ROUTER:
socket = ctx.socket(zmq.ROUTER)
socket.bind(addr)
elif socket_type == zmq.REQ:
socket = ctx.socket(zmq.REQ)
socket.connect(addr)
else:
raise ValueError(f"Unexpected socket type: {socket_type}")
yield socket
finally:
if ctx is not None:
ctx.destroy(linger=0)
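# Minimal side-channel sketch (assumes a NIXL-enabled vLLM instance is
# already serving and listening on the default side channel port): query
# the agent metadata the same way _nixl_handshake() above does.
#
#   with zmq_ctx(zmq.REQ, "tcp://localhost:5557") as sock:
#       sock.send(GET_META_MSG)
#       meta = msgspec.msgpack.Decoder(NixlAgentMetadata).decode(sock.recv())
#       print(meta.engine_id, meta.num_blocks)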

View File

@ -17,6 +17,7 @@ from vllm.v1.core.sched.output import SchedulerOutput
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.request import Request
logger = init_logger(__name__)
@ -132,8 +133,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)
# Get the metadata
metadata: KVConnectorMetadata = \
self._get_connector_metadata()
metadata: KVConnectorMetadata = self._get_connector_metadata()
assert isinstance(metadata, SharedStorageConnectorMetadata)
if metadata is None:
@ -225,7 +225,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
self,
request: "Request",
num_computed_tokens: int,
) -> int:
) -> tuple[int, bool]:
"""
Get number of new tokens that can be loaded from the
external KV cache beyond the num_computed_tokens.
@ -239,7 +239,6 @@ class SharedStorageConnector(KVConnectorBase_V1):
the number of tokens that can be loaded from the
external KV cache beyond what is already computed.
"""
# NOTE: in this debug implementation, we assume that the prompt is
# cached_prompt + newly_generated_single_token
# Therefore, we use prompt_token_ids[:-1] to determine the folder name
@ -248,7 +247,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
# with the block granularity. And it expects the returned blocks and
# num_computed_tokens to also be aligned with the block granularity.
if not self._found_match_for_request(request):
return 0
return 0, False
logger.info("External Cache Hit!")
@ -257,9 +256,10 @@ class SharedStorageConnector(KVConnectorBase_V1):
num_tokens_to_check = align_to_block_size(
len(request.prompt_token_ids) - 1, self._block_size)
return num_tokens_to_check - num_computed_tokens
return num_tokens_to_check - num_computed_tokens, False
def update_state_after_alloc(self, request: "Request",
blocks: "KVCacheBlocks",
num_external_tokens: int):
"""
Update KVConnector state after block allocation.

View File

@ -403,6 +403,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
"access by 3rd parties, and long enough to be "
"unpredictable (e.g., 43 characters base64-encoded, corresponding "
"to 256 bit). Not supported by vLLM engine V0."))
kv_transfer_params: Optional[dict[str, Any]] = Field(
default=None,
description="KVTransfer parameters used for disaggregated serving.")
# doc: end-chat-completion-extra-params
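A hedged client-side sketch (the kv_transfer_params field name comes from this diff; the endpoint, model, and values are illustrative): the new extra param rides along in the request body, e.g. via the OpenAI client's extra_body, and the CompletionResponse additions below return kv_transfer_params back to the caller:

    from openai import OpenAI
    client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
    resp = client.completions.create(
        model="Qwen/Qwen3-0.6B",
        prompt="Hello",
        extra_body={"kv_transfer_params": {
            "do_remote_prefill": False, "do_remote_decode": True,
            "remote_block_ids": None, "remote_host": None,
            "remote_port": None, "remote_engine_id": None}})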
@ -540,7 +543,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
output_kind=RequestOutputKind.DELTA if self.stream \
else RequestOutputKind.FINAL_ONLY,
guided_decoding=guided_decoding,
logit_bias=self.logit_bias)
logit_bias=self.logit_bias,
extra_args=({"kv_transfer_params": self.kv_transfer_params}
if self.kv_transfer_params else None))
def _get_guided_json_from_tool(
self) -> Optional[Union[str, dict, BaseModel]]:
@ -848,6 +853,10 @@ class CompletionRequest(OpenAIBaseModel):
" as strings of the form 'token_id:{token_id}' so that tokens "
"that are not JSON-encodable can be identified."))
kv_transfer_params: Optional[dict[str, Any]] = Field(
default=None,
description="KVTransfer parameters used for disaggregated serving.")
# doc: end-completion-extra-params
# Default sampling parameters for completion requests
@ -973,7 +982,9 @@ class CompletionRequest(OpenAIBaseModel):
else RequestOutputKind.FINAL_ONLY,
guided_decoding=guided_decoding,
logit_bias=self.logit_bias,
allowed_token_ids=self.allowed_token_ids)
allowed_token_ids=self.allowed_token_ids,
extra_args=({"kv_transfer_params": self.kv_transfer_params}
if self.kv_transfer_params else None))
@model_validator(mode="before")
@classmethod
@ -1223,6 +1234,8 @@ class CompletionResponse(OpenAIBaseModel):
model: str
choices: list[CompletionResponseChoice]
usage: UsageInfo
kv_transfer_params: Optional[dict[str, Any]] = Field(
default=None, description="KVTransfer parameters.")
class CompletionResponseStreamChoice(OpenAIBaseModel):
@ -1412,6 +1425,8 @@ class ChatCompletionResponse(OpenAIBaseModel):
choices: list[ChatCompletionResponseChoice]
usage: UsageInfo
prompt_logprobs: Optional[list[Optional[dict[int, Logprob]]]] = None
kv_transfer_params: Optional[dict[str, Any]] = Field(
default=None, description="KVTransfer parameters.")
class DeltaMessage(OpenAIBaseModel):

View File

@ -1086,6 +1086,7 @@ class OpenAIServingChat(OpenAIServing):
choices=choices,
usage=usage,
prompt_logprobs=clamp_prompt_logprobs(final_res.prompt_logprobs),
kv_transfer_params=final_res.kv_transfer_params,
)
return response

View File

@ -482,7 +482,7 @@ class OpenAIServingCompletion(OpenAIServing):
model=model_name,
choices=choices,
usage=usage,
)
kv_transfer_params=final_res_batch[0].kv_transfer_params)
def _create_completion_logprobs(
self,

View File

@ -112,6 +112,8 @@ if TYPE_CHECKING:
VLLM_XGRAMMAR_CACHE_MB: int = 0
VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False
VLLM_NIXL_SIDE_CHANNEL_HOST: str = "localhost"
VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5557
def get_default_cache_root():
@ -747,6 +749,14 @@ environment_variables: dict[str, Callable[[], Any]] = {
# insecure method and it is needed for some reason.
"VLLM_ALLOW_INSECURE_SERIALIZATION":
lambda: bool(int(os.getenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "0"))),
# IP address used for NIXL handshake between remote agents.
"VLLM_NIXL_SIDE_CHANNEL_HOST":
lambda: os.getenv("VLLM_NIXL_SIDE_CHANNEL_HOST", "localhost"),
# Port used for NIXL handshake between remote agents.
"VLLM_NIXL_SIDE_CHANNEL_PORT":
lambda: int(os.getenv("VLLM_NIXL_SIDE_CHANNEL_PORT", "5557")),
}
# end-env-vars-definition
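A hedged usage note: each TP rank listens on VLLM_NIXL_SIDE_CHANNEL_PORT + rank, so co-located prefill and decode instances need disjoint port ranges. A minimal sketch (values hypothetical; must run before the engine reads these env vars):

    import os
    os.environ["VLLM_NIXL_SIDE_CHANNEL_HOST"] = "10.0.0.1"
    os.environ["VLLM_NIXL_SIDE_CHANNEL_PORT"] = "5600"  # rank i listens on 5600 + i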

View File

@ -11,10 +11,6 @@ import torch.distributed as dist
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group,
is_v1_kv_transfer_group)
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.logger import init_logger
if TYPE_CHECKING:
@ -106,16 +102,6 @@ def set_forward_context(attn_metadata: Any,
attn_metadata=attn_metadata,
dp_metadata=dp_metadata)
# KVConnector: trigger (possibly async) load before forward.
# Each attn layer will block until the reading is complete.
trigger_kv_transfer = (attn_metadata is not None
and has_kv_transfer_group()
and is_v1_kv_transfer_group())
if trigger_kv_transfer:
kv_connector = get_kv_transfer_group()
assert isinstance(kv_connector, KVConnectorBase_V1)
kv_connector.start_load_kv(_forward_context)
try:
yield
finally:
@ -152,11 +138,4 @@ def set_forward_context(attn_metadata: Any,
"(batchsize, count, median_time(ms)): %s"),
forward_stats)
# KVConnector: each attn layer triggers (possibly async) save.
# Ensure all those operations complete before forward() is done.
if trigger_kv_transfer:
kv_connector = get_kv_transfer_group()
assert isinstance(kv_connector, KVConnectorBase_V1)
kv_connector.wait_for_save()
_forward_context = prev_context

View File

@ -4,7 +4,7 @@ import time
from collections.abc import MutableSequence
from collections.abc import Sequence as GenericSequence
from dataclasses import dataclass
from typing import Generic, Optional, Union
from typing import Any, Generic, Optional, Union
import torch
from typing_extensions import TypeVar, deprecated
@ -103,6 +103,7 @@ class RequestOutput:
encoder_prompt_token_ids: The token IDs of the encoder prompt.
None if decoder-only.
num_cached_tokens: The number of tokens with prefix cache hit.
kv_transfer_params: The params for remote K/V transfer.
"""
def __init__(
@ -120,6 +121,7 @@ class RequestOutput:
num_cached_tokens: Optional[int] = None,
*,
multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None,
kv_transfer_params: Optional[dict[str, Any]] = None,
) -> None:
self.request_id = request_id
self.prompt = prompt
@ -133,11 +135,13 @@ class RequestOutput:
self.encoder_prompt = encoder_prompt
self.encoder_prompt_token_ids = encoder_prompt_token_ids
self.num_cached_tokens = num_cached_tokens
self.kv_transfer_params = kv_transfer_params
def add(self, next_output: "RequestOutput", aggregate: bool) -> None:
"""Merge subsequent RequestOutput into this one"""
self.finished |= next_output.finished
self.kv_transfer_params = next_output.kv_transfer_params
for next_completion in next_output.outputs:
for i, completion in enumerate(self.outputs):

View File

@ -36,6 +36,12 @@ class KVCacheBlocks:
"""Converts the KVCacheBlocks instance to a list of block IDs."""
return [block.block_id for block in self.blocks]
def get_unhashed_block_ids(self) -> list[int]:
"""Get block_ids of unhashed blocks from KVCacheBlocks instance."""
return [
block.block_id for block in self.blocks if block.block_hash is None
]
class KVCacheManager:
@ -116,6 +122,12 @@ class KVCacheManager:
- The number of computed tokens.
"""
# Request already has blocks from async load via KVConnector.
num_existing_blocks = len(
self.single_type_manager.req_to_blocks[request.request_id])
if num_existing_blocks > 0:
return KVCacheBlocks.create_empty(), request.num_computed_tokens
# Prefix caching is disabled or
# When the request requires prompt logprobs, we skip prefix caching.
if (not self.enable_caching
@ -173,6 +185,7 @@ class KVCacheManager:
num_new_tokens: int,
new_computed_blocks: Optional[KVCacheBlocks] = None,
num_lookahead_tokens: int = 0,
delay_cache_blocks: bool = False,
) -> Optional[KVCacheBlocks]:
"""Add slots for a request with new tokens to append.
@ -186,6 +199,9 @@ class KVCacheManager:
num_lookahead_tokens: The number of speculative tokens to allocate.
This is used by spec decode proposers with kv-cache such
as eagle.
delay_cache_blocks: Whether to skip caching the blocks. This is
used by P/D when allocating blocks used in a KV transfer
which will complete in a future step.
Blocks layout:
```
@ -255,7 +271,9 @@ class KVCacheManager:
new_blocks = self.single_type_manager.allocate_new_blocks(
request.request_id, num_tokens_need_slot)
if not self.enable_caching:
# P/D: delay caching blocks if we have to recv from
# remote. Update state for locally cached blocks.
if not self.enable_caching or delay_cache_blocks:
return KVCacheBlocks(new_blocks)
# Speculated tokens might be rejected in the future, so we do
@ -350,3 +368,16 @@ class KVCacheManager:
A list of KV cache events.
"""
return self.block_pool.take_events()
def get_block_ids(self, request_id: str) -> list[int]:
"""Get the block ids of a request."""
assert request_id in self.single_type_manager.req_to_blocks
return [
block.block_id
for block in self.single_type_manager.req_to_blocks[request_id]
]
def get_num_blocks(self, request_id: str) -> int:
"""Get the number of blocks."""
assert request_id in self.single_type_manager.req_to_blocks
return len(self.single_type_manager.req_to_blocks[request_id])

View File

@ -4,6 +4,7 @@ from collections.abc import Iterable
from typing import TYPE_CHECKING, Optional, Union
if TYPE_CHECKING:
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.engine import EngineCoreOutputs
from vllm.v1.metrics.stats import SchedulerStats
@ -137,3 +138,6 @@ class SchedulerInterface(ABC):
def shutdown(self) -> None:
"""Shutdown the scheduler."""
raise NotImplementedError
def get_kv_connector(self) -> Optional["KVConnectorBase_V1"]:
return None

View File

@ -5,13 +5,15 @@ from __future__ import annotations
import time
from collections import defaultdict, deque
from collections.abc import Iterable
from typing import Optional, Union
from typing import Any, Optional, Union
from vllm.config import VllmConfig
from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch
from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory)
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole
from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1,
KVConnectorRole,
KVTransferParams)
from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
@ -96,6 +98,9 @@ class Scheduler(SchedulerInterface):
# This is flushed at the end of each scheduling step.
self.finished_req_ids: set[str] = set()
# P/D: requests in process of recving KV transfers
self.finished_recving_kv_req_ids: set[str] = set()
# OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
# them at each scheduling step.
# Request id -> deque of CachedRequestData
@ -307,6 +312,16 @@ class Scheduler(SchedulerInterface):
request = self.waiting[0]
# P/D: skip request if still waiting for remote kvs.
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
is_ready = self._update_waiting_for_remote_kv(request)
if is_ready:
request.status = RequestStatus.WAITING
else:
self.waiting.popleft()
skipped_waiting_requests.appendleft(request)
continue
# Skip request if the structured output request is still waiting
# for FSM compilation.
if request.status == RequestStatus.WAITING_FOR_FSM:
@ -330,26 +345,33 @@ class Scheduler(SchedulerInterface):
continue
# Get already-cached tokens.
computed_blocks, num_computed_tokens = \
new_computed_blocks, num_computed_tokens = \
self.kv_cache_manager.get_computed_blocks(
request)
# Get externally-cached tokens if using a KVConnector.
num_external_tokens = (
0 if self.connector is None else
num_external_tokens, load_kv_async = (
(0, False) if self.connector is None else
self.connector.get_num_new_matched_tokens(
request, num_computed_tokens))
# Total computed tokens (local + external).
num_computed_tokens += num_external_tokens
encoder_inputs_to_schedule = None
new_encoder_budget = encoder_budget
# P/D: loading remote KV, do not allocate for new work.
if load_kv_async:
num_new_tokens = 0
# Number of tokens to be scheduled.
else:
# We use `request.num_tokens` instead of
# `request.num_prompt_tokens` to consider the resumed requests,
# which have output tokens.
# `request.num_prompt_tokens` to consider the resumed
# requests, which have output tokens.
num_new_tokens = request.num_tokens - num_computed_tokens
if (0 < self.scheduler_config.long_prefill_token_threshold <
num_new_tokens):
if (0 < self.scheduler_config.long_prefill_token_threshold
< num_new_tokens):
num_new_tokens = (
self.scheduler_config.long_prefill_token_threshold)
num_new_tokens = min(num_new_tokens, token_budget)
@ -358,21 +380,20 @@ class Scheduler(SchedulerInterface):
# Schedule encoder inputs.
if request.has_encoder_inputs:
(encoder_inputs_to_schedule, num_new_tokens,
new_encoder_budget) = self._try_schedule_encoder_inputs(
new_encoder_budget
) = self._try_schedule_encoder_inputs(
request, num_computed_tokens, num_new_tokens,
encoder_budget)
if num_new_tokens == 0:
# The request cannot be scheduled.
break
else:
encoder_inputs_to_schedule = None
new_encoder_budget = encoder_budget
new_blocks = self.kv_cache_manager.allocate_slots(
request,
num_new_tokens + num_external_tokens,
computed_blocks,
new_computed_blocks,
num_lookahead_tokens=self.num_lookahead_tokens,
delay_cache_blocks=load_kv_async,
)
if new_blocks is None:
# The request cannot be scheduled.
@ -384,10 +405,18 @@ class Scheduler(SchedulerInterface):
if self.connector is not None:
self.connector.update_state_after_alloc(
request,
new_computed_blocks + new_blocks,
num_external_tokens,
)
self.waiting.popleft()
if load_kv_async:
# If loading async, allocate memory and put request
# into the WAITING_FOR_REMOTE_KV state.
skipped_waiting_requests.appendleft(request)
request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
continue
if request.use_structured_output:
structured_output_request_ids[
request.request_id] = req_index
@ -407,7 +436,7 @@ class Scheduler(SchedulerInterface):
if self.lora_config and request.lora_request:
scheduled_loras.add(request.lora_request.lora_int_id)
req_to_new_block_ids[request.request_id] = (
computed_blocks + new_blocks).get_block_ids()
self.kv_cache_manager.get_block_ids(request.request_id))
num_scheduled_tokens[request.request_id] = num_new_tokens
token_budget -= num_new_tokens
request.status = RequestStatus.RUNNING
@ -698,6 +727,7 @@ class Scheduler(SchedulerInterface):
stopped = False
new_logprobs = None
new_token_ids = generated_token_ids
kv_transfer_params = None
# Append generated tokens and check for stop. Note that if
# a request is still being prefilled, we expect the model runner
@ -709,7 +739,7 @@ class Scheduler(SchedulerInterface):
# This must be called before we make the EngineCoreOutput.
stopped = check_stop(request, self.max_model_len)
if stopped:
self._free_request(request)
kv_transfer_params = self._free_request(request)
del new_token_ids[num_new:] # Trim new tokens if needed.
break
@ -739,7 +769,8 @@ class Scheduler(SchedulerInterface):
# Get prompt logprobs for this request.
prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
if new_token_ids:
if new_token_ids or kv_transfer_params:
# Add EngineCoreOutput for this Request.
outputs.append(
EngineCoreOutput(
@ -749,7 +780,10 @@ class Scheduler(SchedulerInterface):
new_logprobs=new_logprobs,
new_prompt_logprobs_tensors=prompt_logprobs_tensors,
stop_reason=request.stop_reason,
events=request.take_events()))
events=request.take_events(),
kv_transfer_params=kv_transfer_params,
))
else:
# Invariant: EngineCore returns no partial prefill outputs.
assert not prompt_logprobs_tensors
@ -757,6 +791,9 @@ class Scheduler(SchedulerInterface):
if not stopped:
new_running.append(request)
# P/D: update state for finished KV Transfers.
self._update_from_kv_xfer_finished(model_runner_output)
# Return the cached request data to the queue so they can be reused.
for req_data in scheduler_output.scheduled_cached_reqs:
# NOTE(rob): since we free stopped reqs above, adding stopped reqs
@ -811,15 +848,27 @@ class Scheduler(SchedulerInterface):
request.status = finished_status
self._free_request(request)
def _free_request(self, request: Request) -> None:
def _free_request(self, request: Request) -> Optional[dict[str, Any]]:
assert request.is_finished()
self.kv_cache_manager.free(request)
self.kv_cache_manager.free_block_hashes(request)
delay_free_blocks, kv_xfer_params = self._connector_finished(request)
self.encoder_cache_manager.free(request)
self._cached_reqs_data.pop(request.request_id, None)
del self.requests[request.request_id]
self.finished_req_ids.add(request.request_id)
if not delay_free_blocks:
self._free_blocks(request)
return kv_xfer_params
def _free_blocks(self, request: Request):
assert request.is_finished()
assert request.request_id not in self._cached_reqs_data
self.kv_cache_manager.free(request)
self.kv_cache_manager.free_block_hashes(request)
del self.requests[request.request_id]
def get_num_unfinished_requests(self) -> int:
return len(self.waiting) + len(self.running)
@ -863,3 +912,70 @@ class Scheduler(SchedulerInterface):
def shutdown(self) -> None:
if self.kv_event_publisher:
self.kv_event_publisher.shutdown()
########################################################################
# P/D Related Methods
########################################################################
def get_kv_connector(self) -> Optional[KVConnectorBase_V1]:
return self.connector
def _connector_finished(
self, request: Request) -> tuple[bool, Optional[KVTransferParams]]:
"""Invoke the KV connector request_finished() method if applicable."""
if self.connector is None:
return False, None
block_ids = self.kv_cache_manager.get_block_ids(request.request_id)
return self.connector.request_finished(request, block_ids)
def _update_waiting_for_remote_kv(self, request: Request) -> bool:
"""
P/D: check if the request_id is finished_recving.
The finished_recving_kv_req_ids set is populated
in the previous step's update_from_output() based
on the worker-side connector.
When the KV transfer is ready, we cache the blocks
and the request state is moved back to WAITING from
WAITING_FOR_REMOTE_KVS.
"""
if request.request_id not in self.finished_recving_kv_req_ids:
return False
# Now that the blocks are ready, actually cache them.
block_ids = self.kv_cache_manager.get_block_ids(request.request_id)
num_computed_tokens = len(block_ids) * self.block_size
if num_computed_tokens == request.num_tokens:
num_computed_tokens -= 1
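# (Leave at least one token uncomputed so the request still has work
# to schedule; assumed intent of this guard.)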
self.kv_cache_manager.single_type_manager.cache_blocks(
request,
self.kv_cache_manager.req_to_block_hashes[request.request_id],
num_computed_tokens,
)
# Update the request state for scheduling.
request.num_computed_tokens = num_computed_tokens
# Return that we are ready.
self.finished_recving_kv_req_ids.remove(request.request_id)
return True
def _update_from_kv_xfer_finished(self,
model_runner_output: ModelRunnerOutput):
"""
P/D: update the scheduler state based on the output.
The Worker side connectors add finished_recving and
finished_sending reqs to the output.
* if finished_sending: free the blocks
* if finished_recving: add to state so we can
schedule the request during the next step.
"""
# P/D: update recv and send status from last step.
for req_id in (model_runner_output.finished_recving or ()):
logger.debug("Finished recving KV transfer for request %s", req_id)
self.finished_recving_kv_req_ids.add(req_id)
for req_id in (model_runner_output.finished_sending or ()):
logger.debug("Finished sending KV transfer for request %s", req_id)
self._free_blocks(self.requests[req_id])

View File

@ -105,6 +105,7 @@ class EngineCoreOutput(
finish_reason: Optional[FinishReason] = None
stop_reason: Union[int, str, None] = None
events: Optional[list[EngineCoreEvent]] = None
kv_transfer_params: Optional[dict[str, Any]] = None
@property
def finished(self) -> bool:

View File

@ -182,6 +182,15 @@ class EngineCore:
# Start grammar compilation asynchronously
self.structured_output_manager.grammar_init(req)
if req.raw_kv_transfer_params is not None:
if (kv_connector := self.scheduler.get_kv_connector()):
# Parse raw KV transfer params via connector.
kv_connector.set_kv_transfer_params(req)
else:
logger.warning(
"Got KVTransferParams, but no KVConnector found. "
"Disabling KVTransfer for this request.")
self.scheduler.add_request(req)
def abort_requests(self, request_ids: list[str]):

View File

@ -3,7 +3,7 @@
import asyncio
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Optional, Union
from typing import Any, Optional, Union
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import RequestOutputKind
@ -146,6 +146,7 @@ class RequestState:
new_token_ids: list[int],
finish_reason: Optional[FinishReason],
stop_reason: Union[int, str, None],
kv_transfer_params: Optional[dict[str, Any]] = None,
) -> Optional[RequestOutput]:
finished = finish_reason is not None
@ -167,13 +168,15 @@ class RequestState:
if not outputs:
return None
return self._new_request_output(request_id, outputs, finished)
return self._new_request_output(request_id, outputs, finished,
kv_transfer_params)
def _new_request_output(
self,
request_id: str,
outputs: list[CompletionOutput],
finished: bool,
kv_transfer_params: Optional[dict[str, Any]] = None,
) -> RequestOutput:
if self.output_kind == RequestOutputKind.DELTA:
@ -189,6 +192,7 @@ class RequestState:
prompt_logprobs=prompt_logprobs,
outputs=outputs,
finished=finished,
kv_transfer_params=kv_transfer_params,
)
def _new_completion_output(
@ -335,6 +339,7 @@ class OutputProcessor:
new_token_ids = engine_core_output.new_token_ids
finish_reason = engine_core_output.finish_reason
stop_reason = engine_core_output.stop_reason
kv_transfer_params = engine_core_output.kv_transfer_params
req_state.is_prefilling = False
@ -350,7 +355,8 @@ class OutputProcessor:
# 4) Create and handle RequestOutput objects.
if request_output := req_state.make_request_output(
new_token_ids, finish_reason, stop_reason):
new_token_ids, finish_reason, stop_reason,
kv_transfer_params):
if req_state.queue is not None:
# AsyncLLM: put into queue for handling by generate().
req_state.queue.put(request_output)

View File

@ -100,12 +100,16 @@ class ModelRunnerOutput:
# [prompt_len]
prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]]
# [req_ids]
finished_sending: Optional[set[str]] = None
finished_recving: Optional[set[str]] = None
EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
req_ids=[],
EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[],
req_id_to_index={},
sampled_token_ids=[],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
)
finished_sending=None,
finished_recving=None)

View File

@ -1,8 +1,9 @@
# SPDX-License-Identifier: Apache-2.0
import enum
from typing import TYPE_CHECKING, Optional, Union
from typing import TYPE_CHECKING, Any, Optional, Union
from vllm.distributed.kv_transfer.kv_connector.v1 import KVTransferParams
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.utils import is_list_of
@ -61,6 +62,15 @@ class Request:
self.num_encoder_inputs = len(self.mm_inputs)
self.has_encoder_inputs = self.num_encoder_inputs > 0
# P/D: KV transfer parameters (raw and parsed).
raw_params = (None if sampling_params.extra_args is None
else sampling_params.extra_args.get(
"kv_transfer_params", None))
self.raw_kv_transfer_params: Optional[dict[str, Any]] = raw_params
# Each connector parses the raw dictionary and sets this
# attr the first time that the request is processed.
self.kv_transfer_params: Optional[KVTransferParams] = None
# Sanity check
assert len(self.mm_inputs) == len(self.mm_positions)
if self.mm_hashes:
@ -150,6 +160,7 @@ class RequestStatus(enum.IntEnum):
"""Status of a request."""
WAITING = enum.auto()
WAITING_FOR_FSM = enum.auto()
WAITING_FOR_REMOTE_KVS = enum.auto()
RUNNING = enum.auto()
PREEMPTED = enum.auto()
# Note: anything after PREEMPTED will be considered

View File

@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
import copy
import gc
import time
import weakref
@ -17,8 +18,9 @@ from vllm.config import (CompilationLevel, VllmConfig,
get_layers_from_vllm_config)
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group)
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.distributed.parallel_state import get_pp_group, graph_capture
from vllm.forward_context import set_forward_context
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.model_loader import get_model
@ -1065,16 +1067,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
scheduler_output: "SchedulerOutput",
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[ModelRunnerOutput, IntermediateTensors]:
# Update the KVConnector with the KVConnector metadata before forward().
if has_kv_transfer_group():
get_kv_transfer_group().bind_connector_metadata(
scheduler_output.kv_connector_metadata)
self._update_states(scheduler_output)
if not scheduler_output.total_num_scheduled_tokens:
if not has_kv_transfer_group():
# Return empty ModelRunnerOutput if there's no work to do.
return EMPTY_MODEL_RUNNER_OUTPUT
return self.kv_connector_no_forward(scheduler_output)
# Prepare the decoder inputs.
attn_metadata, logits_indices, spec_decode_metadata = (
self._prepare_inputs(scheduler_output))
@ -1150,17 +1151,23 @@ class GPUModelRunner(LoRAModelRunnerMixin):
with set_forward_context(attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens):
output = self.model(
self.maybe_setup_kv_connector(scheduler_output)
model_output = self.model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
self.maybe_wait_for_kv_save()
finished_sending, finished_recving = (
self.get_finished_kv_transfers(scheduler_output))
if self.use_aux_hidden_state_outputs:
hidden_states, aux_hidden_states = output
hidden_states, aux_hidden_states = model_output
else:
hidden_states = output
hidden_states = model_output
if not get_pp_group().is_last_rank:
# For mid-pipeline stages, return the hidden states.
@ -1341,8 +1348,56 @@ class GPUModelRunner(LoRAModelRunnerMixin):
spec_token_ids=spec_token_ids,
logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict,
finished_sending=finished_sending,
finished_recving=finished_recving,
)
def kv_connector_no_forward(
self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
# KV send/recv even if no work to do.
with set_forward_context(None, self.vllm_config):
self.maybe_setup_kv_connector(scheduler_output)
finished_sending, finished_recving = (
self.get_finished_kv_transfers(scheduler_output))
if not finished_sending and not finished_recving:
return EMPTY_MODEL_RUNNER_OUTPUT
output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
output.finished_sending = finished_sending
output.finished_recving = finished_recving
return output
@staticmethod
def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"):
# Update the KVConnector with the KVConnector metadata before forward().
if has_kv_transfer_group():
kv_connector = get_kv_transfer_group()
assert isinstance(kv_connector, KVConnectorBase_V1)
assert scheduler_output.kv_connector_metadata is not None
kv_connector.bind_connector_metadata(
scheduler_output.kv_connector_metadata)
# Background KV cache transfers happen here.
# These transfers are designed to be async and the requests
# involved may be disjoint from the running requests.
# Do this here to save a collective_rpc.
kv_connector.start_load_kv(get_forward_context())
@staticmethod
def maybe_wait_for_kv_save() -> None:
if has_kv_transfer_group():
get_kv_transfer_group().wait_for_save()
@staticmethod
def get_finished_kv_transfers(
scheduler_output: "SchedulerOutput",
) -> tuple[Optional[set[str]], Optional[set[str]]]:
if has_kv_transfer_group():
return get_kv_transfer_group().get_finished(
scheduler_output.finished_req_ids)
return None, None
def generate_draft_token_ids(
self,
sampled_token_ids: list[list[int]],
@ -1813,6 +1868,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.vllm_config.compilation_config.static_forward_context,
self.kv_caches)
if has_kv_transfer_group():
get_kv_transfer_group().register_kv_caches(kv_caches)
self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
weakref.proxy(self),
kv_cache_config.kv_cache_groups[0].kv_cache_spec,