[P/D] NIXL Integration (#17751)
Signed-off-by: ApostaC <yihua98@uchicago.edu>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
Signed-off-by: Robert Shaw <rshaw@neuralmagic.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Signed-off-by: Brent Salisbury <bsalisbu@redhat.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: ApostaC <yihua98@uchicago.edu>
Co-authored-by: Robert Shaw <rshaw@neuralmagic.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Tyler Michael Smith <tysmith@redhat.com>
Co-authored-by: Brent Salisbury <bsalisbu@redhat.com>

parent 05a4324f8e
commit d19110204c
@@ -214,6 +214,7 @@ steps:
 - pytest -v -s v1/worker
 - pytest -v -s v1/structured_output
 - pytest -v -s v1/spec_decode
+- pytest -v -s v1/kv_connector/unit
 - pytest -v -s v1/test_serial_utils.py
 - pytest -v -s v1/test_stats.py
 - pytest -v -s v1/test_utils.py
@@ -870,7 +870,7 @@ def test_kv_connector_basic():
     NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2
     scheduler.connector.get_num_new_matched_tokens = Mock(name="method")
     scheduler.connector.get_num_new_matched_tokens.return_value = (
-        NUM_MATCHED_NEW_TOKENS)
+        NUM_MATCHED_NEW_TOKENS, False)

     ######################################################
     # FIRST SET OF REQUESTS - External Hit Only
@@ -981,7 +981,7 @@ def test_kv_connector_unable_to_allocate():
     NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2
     scheduler.connector.get_num_new_matched_tokens = Mock(name="method")
     scheduler.connector.get_num_new_matched_tokens.return_value = (
-        NUM_MATCHED_NEW_TOKENS)
+        NUM_MATCHED_NEW_TOKENS, False)

     # Create two requests. The second request will not be able to
     # allocate slots because it will not have enough blocks.
@@ -1060,7 +1060,7 @@ def test_kv_connector_handles_preemption():
     NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE
     scheduler.connector.get_num_new_matched_tokens = Mock(name="method")
     scheduler.connector.get_num_new_matched_tokens.return_value = (
-        NUM_MATCHED_NEW_TOKENS)
+        NUM_MATCHED_NEW_TOKENS, False)

     # Create two requests.
     # Both can be scheduled at first, but the second request
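In these hunks the mocked get_num_new_matched_tokens now returns a two-element tuple instead of a bare count. A minimal sketch of a connector method that would satisfy the updated call site, assuming the second element is a flag for whether the matched KVs are loaded asynchronously (the flag name load_kv_async is illustrative, not taken from this diff):

# Illustrative sketch only: the tuple shape mirrors the mocked return value
# ((NUM_MATCHED_NEW_TOKENS, False)) above; parameter and flag names are assumptions.
from typing import Any


class ToyConnector:

    def get_num_new_matched_tokens(
            self, request: Any, num_computed_tokens: int) -> tuple[int, bool]:
        # No externally cached tokens, and no asynchronous KV load pending.
        load_kv_async = False
        return 0, load_kv_async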
tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh (new executable file, 171 lines)
@@ -0,0 +1,171 @@
#!/bin/bash
set -xe

# Models to run
MODELS=(
  "Qwen/Qwen3-0.6B"
)

# Number of prefill and decode instances to create
NUM_PREFILL_INSTANCES=${NUM_PREFILL_INSTANCES:-1} # Default to 1
NUM_DECODE_INSTANCES=${NUM_DECODE_INSTANCES:-2} # Default to 2

# Find the git repository root directory
GIT_ROOT=$(git rev-parse --show-toplevel)

# Trap the SIGINT signal (triggered by Ctrl+C)
trap 'kill $(jobs -pr)' SIGINT SIGTERM EXIT

# Waits for vLLM to start.
wait_for_server() {
  local port=$1
  timeout 1200 bash -c "
    until curl -s localhost:${port}/v1/completions > /dev/null; do
      sleep 1
    done" && return 0 || return 1
}

# Function to clean up previous instances
cleanup_instances() {
  echo "Cleaning up any running vLLM instances..."
  pkill -f "vllm serve" || true
  sleep 2
}

# Helper to get model-specific arguments for deepseek
get_model_args() {
  local model_name=$1
  local extra_args=""

  if [[ "$model_name" == "deepseek-ai/deepseek-vl2-tiny" ]]; then
    extra_args="--hf_overrides '{\"architectures\": [\"DeepseekVLV2ForCausalLM\"]}' --trust-remote-code"
  fi

  echo "$extra_args"
}


# Function to run tests for a specific model
run_tests_for_model() {
  local model_name=$1
  echo "================================"
  echo "Testing model: $model_name"
  echo "================================"

  # Get model-specific arguments
  local model_args=$(get_model_args "$model_name")

  # Arrays to store all hosts and ports
  PREFILL_HOSTS=()
  PREFILL_PORTS=()
  DECODE_HOSTS=()
  DECODE_PORTS=()

  # Start prefill instances
  for i in $(seq 0 $((NUM_PREFILL_INSTANCES-1))); do
    # Calculate GPU ID - we'll distribute across available GPUs
    GPU_ID=$((i % $(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)))
    # Calculate port number (base port + instance number)
    PORT=$((8100 + i))
    # Calculate side channel port
    SIDE_CHANNEL_PORT=$((5559 + i))

    echo "Starting prefill instance $i on GPU $GPU_ID, port $PORT"

    # Build the command with or without model-specific args
    BASE_CMD="CUDA_VISIBLE_DEVICES=$GPU_ID VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT vllm serve $model_name \
      --port $PORT \
      --enforce-eager \
      --disable-log-requests \
      --gpu-memory-utilization 0.2 \
      --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\"}'"

    if [ -n "$model_args" ]; then
      FULL_CMD="$BASE_CMD $model_args"
    else
      FULL_CMD="$BASE_CMD"
    fi

    eval "$FULL_CMD &"

    # Store host and port for proxy configuration
    PREFILL_HOSTS+=("localhost")
    PREFILL_PORTS+=($PORT)
  done

  # Start decode instances
  for i in $(seq 0 $((NUM_DECODE_INSTANCES-1))); do
    # Calculate GPU ID - we'll distribute across available GPUs, starting from after prefill GPUs
    GPU_ID=$(((i + NUM_PREFILL_INSTANCES) % $(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)))
    # Calculate port number (base port + instance number)
    PORT=$((8200 + i))
    # Calculate side channel port
    SIDE_CHANNEL_PORT=$((5659 + i))

    echo "Starting decode instance $i on GPU $GPU_ID, port $PORT"

    # Build the command with or without model-specific args
    BASE_CMD="CUDA_VISIBLE_DEVICES=$GPU_ID VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT vllm serve $model_name \
      --port $PORT \
      --enforce-eager \
      --disable-log-requests \
      --gpu-memory-utilization 0.2 \
      --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\"}'"

    if [ -n "$model_args" ]; then
      FULL_CMD="$BASE_CMD $model_args"
    else
      FULL_CMD="$BASE_CMD"
    fi

    eval "$FULL_CMD &"

    # Store host and port for proxy configuration
    DECODE_HOSTS+=("localhost")
    DECODE_PORTS+=($PORT)
  done

  # Wait for all instances to start
  for PORT in "${PREFILL_PORTS[@]}"; do
    echo "Waiting for prefill instance on port $PORT to start..."
    wait_for_server $PORT
  done

  for PORT in "${DECODE_PORTS[@]}"; do
    echo "Waiting for decode instance on port $PORT to start..."
    wait_for_server $PORT
  done

  # Build the command for the proxy server with all the hosts and ports
  PROXY_CMD="python ${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py --port 8192"

  # Add all prefill hosts and ports
  PROXY_CMD+=" --prefiller-hosts ${PREFILL_HOSTS[@]}"
  PROXY_CMD+=" --prefiller-ports ${PREFILL_PORTS[@]}"

  # Add all decode hosts and ports
  PROXY_CMD+=" --decoder-hosts ${DECODE_HOSTS[@]}"
  PROXY_CMD+=" --decoder-ports ${DECODE_PORTS[@]}"

  # Start the proxy server
  echo "Starting proxy server with command: $PROXY_CMD"
  $PROXY_CMD &

  # Wait for the proxy to start
  sleep 5

  # Run lm eval for this model
  echo "Running tests for $model_name"
  TEST_MODEL=$model_name python -m pytest -s -x ${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/test_accuracy.py

  # Clean up before running next model
  cleanup_instances
  sleep 3
}

# Run tests for each model
for model in "${MODELS[@]}"; do
  run_tests_for_model "$model"
done

echo "All tests completed!"
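Each server above is launched with --kv-transfer-config, a JSON blob selecting the NixlConnector and letting the instance act as both KV sender and receiver. A small sketch of the programmatic equivalent, assuming KVTransferConfig (imported in the unit-test utils later in this diff) accepts these keys as keyword arguments; the field names simply mirror the CLI JSON and are not verified here:

# Sketch: programmatic counterpart of the --kv-transfer-config JSON above.
# Assumption: KVTransferConfig exposes kv_connector and kv_role fields.
from vllm.config import KVTransferConfig

nixl_kv_config = KVTransferConfig(
    kv_connector="NixlConnector",  # same connector name as the CLI flag
    kv_role="kv_both",             # instance both sends and receives KV caches
)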
tests/v1/kv_connector/nixl_integration/run_edge_case_test.sh (new file, 123 lines)
@@ -0,0 +1,123 @@
#!/bin/bash
set -xe

# Models to run
MODELS=(
  "Qwen/Qwen3-0.6B"
)

# Find the git repository root directory
GIT_ROOT=$(git rev-parse --show-toplevel)

# Trap the SIGINT signal (triggered by Ctrl+C)
trap 'kill $(jobs -pr)' SIGINT SIGTERM EXIT

# Waits for vLLM to start.
wait_for_server() {
  local port=$1
  timeout 1200 bash -c "
    until curl -s localhost:${port}/v1/completions > /dev/null; do
      sleep 1
    done" && return 0 || return 1
}

# Function to clean up previous instances
cleanup_instances() {
  echo "Cleaning up any running vLLM instances..."
  pkill -f "vllm serve" || true
  sleep 2
}

# Helper to get model-specific arguments for deepseek
get_model_args() {
  local model_name=$1
  local extra_args=""

  if [[ "$model_name" == "deepseek-ai/deepseek-vl2-tiny" ]]; then
    extra_args="--hf_overrides '{\"architectures\": [\"DeepseekVLV2ForCausalLM\"]}' --trust-remote-code"
  fi

  echo "$extra_args"
}


# Function to run tests for a specific model
run_tests_for_model() {
  local model_name=$1
  echo "================================"
  echo "Testing model: $model_name"
  echo "================================"

  # Get model-specific arguments
  local model_args=$(get_model_args "$model_name")

  # Start prefill instance
  PREFILL_PORT=8001

  BASE_CMD="CUDA_VISIBLE_DEVICES=0 VLLM_NIXL_SIDE_CHANNEL_PORT=5559 vllm serve $model_name \
    --port $PREFILL_PORT \
    --enforce-eager \
    --disable-log-requests \
    --gpu-memory-utilization 0.2 \
    --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\"}'"

  if [ -n "$model_args" ]; then
    FULL_CMD="$BASE_CMD $model_args"
  else
    FULL_CMD="$BASE_CMD"
  fi

  eval "$FULL_CMD &"

  # Start decode instance
  DECODE_PORT=8002

  # Build the command with or without model-specific args
  BASE_CMD="CUDA_VISIBLE_DEVICES=1 VLLM_NIXL_SIDE_CHANNEL_PORT=6000 vllm serve $model_name \
    --port $DECODE_PORT \
    --enforce-eager \
    --disable-log-requests \
    --gpu-memory-utilization 0.2 \
    --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\"}'"

  if [ -n "$model_args" ]; then
    FULL_CMD="$BASE_CMD $model_args"
  else
    FULL_CMD="$BASE_CMD"
  fi

  eval "$FULL_CMD &"

  # Wait for all instances to start
  echo "Waiting for prefill instance on port $PREFILL_PORT to start..."
  wait_for_server $PREFILL_PORT
  echo "Waiting for decode instance on port $DECODE_PORT to start..."
  wait_for_server $DECODE_PORT

  # Build the command for the proxy server with all the hosts and ports
  PROXY_PORT=8192
  PROXY_CMD="python ${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py --port $PROXY_PORT"
  PROXY_CMD+=" --prefiller-ports ${PREFILL_PORT}"
  PROXY_CMD+=" --decoder-ports ${DECODE_PORT}"
  # Start the proxy server
  echo "Starting proxy server with command: $PROXY_CMD"
  $PROXY_CMD &

  # Wait for the proxy to start
  sleep 5

  # Run the edge-case tests for this model
  echo "Running tests for $model_name"
  PREFILL_PORT=$PREFILL_PORT DECODE_PORT=$DECODE_PORT PROXY_PORT=$PROXY_PORT python -m pytest -s -v ${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/test_edge_cases.py

  # Clean up before running next model
  cleanup_instances
  sleep 3
}

# Run tests for each model
for model in "${MODELS[@]}"; do
  run_tests_for_model "$model"
done

echo "All tests completed!"
tests/v1/kv_connector/nixl_integration/test_accuracy.py (new file, 60 lines)
@@ -0,0 +1,60 @@
# SPDX-License-Identifier: Apache-2.0
import os

import lm_eval
import openai

BASE_URL = "http://localhost:8192/v1"
NUM_CONCURRENT = 100
TASK = "gsm8k"
FILTER = "exact_match,strict-match"
RTOL = 0.03

# Model-specific expected values
EXPECTED_VALUES = {
    "Qwen/Qwen3-0.6B": 0.41,
}

SIMPLE_PROMPT = "The best part about working on vLLM is that I got to meet so many people across various different organizations like UCB, Google, and Meta which means",  # noqa: E501

# Get model name from environment variable
MODEL_NAME = os.environ.get("TEST_MODEL", "Qwen/Qwen3-0.6B")


def run_simple_prompt():
    client = openai.OpenAI(api_key="EMPTY", base_url=BASE_URL)
    completion = client.completions.create(model=MODEL_NAME,
                                           prompt=SIMPLE_PROMPT)

    print("-" * 50)
    print(f"Completion results for {MODEL_NAME}:")
    print(completion)
    print("-" * 50)


def test_accuracy():
    """Run the end to end accuracy test."""
    run_simple_prompt()

    model_args = (f"model={MODEL_NAME},"
                  f"base_url={BASE_URL}/completions,"
                  f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False")

    results = lm_eval.simple_evaluate(
        model="local-completions",
        model_args=model_args,
        tasks=TASK,
    )

    measured_value = results["results"][TASK][FILTER]
    expected_value = EXPECTED_VALUES.get(MODEL_NAME)

    if expected_value is None:
        print(f"Warning: No expected value found for {MODEL_NAME}. "
              "Skipping accuracy check.")
        print(f"Measured value: {measured_value}")
        return

    assert (measured_value - RTOL < expected_value
            and measured_value + RTOL > expected_value
            ), f"Expected: {expected_value} | Measured: {measured_value}"
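The test above reads the model name from the TEST_MODEL environment variable and talks to whatever proxy is listening on localhost:8192. A hedged sketch of invoking it directly against an already-running P/D deployment, without going through run_accuracy_test.sh (the path and model below are illustrative):

# Sketch: run the accuracy test standalone against a live proxy on :8192.
import os

import pytest

os.environ["TEST_MODEL"] = "Qwen/Qwen3-0.6B"  # model served by the instances
raise SystemExit(
    pytest.main(["-s", "-x",
                 "tests/v1/kv_connector/nixl_integration/test_accuracy.py"]))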
tests/v1/kv_connector/nixl_integration/test_edge_cases.py (new file, 77 lines)
@@ -0,0 +1,77 @@
# SPDX-License-Identifier: Apache-2.0
import os

import openai

PREFILL_PORT = os.getenv("PREFILL_PORT", None)
DECODE_PORT = os.getenv("DECODE_PORT", None)
PROXY_PORT = os.getenv("PROXY_PORT", None)

if PREFILL_PORT is None or DECODE_PORT is None or PROXY_PORT is None:
    raise ValueError(
        "Please set the PREFILL_PORT, DECODE_PORT, and PROXY_PORT.")

LONG_PROMPT = "Red Hat is the best company in the world to work for because it works on open source software, which means that all the contributions are delivered to the community. As a result, when working on projects like vLLM we are able to meet many amazing people from various organizations like AMD, Google, NVIDIA, "  # noqa: E501
PROMPT = "Red Hat is the best company in the world to work for because it works on open source software, which means that all the contributions are delivered to the community. As a result,"  # noqa: E501
SHORT_PROMPT = "Red Hat is "


def test_edge_cases():
    # Set the OpenAI API key and base URL
    decode_client = openai.OpenAI(
        api_key="MY_KEY",
        base_url=f"http://localhost:{DECODE_PORT}/v1",
    )
    prefill_client = openai.OpenAI(
        api_key="MY_KEY",
        base_url=f"http://localhost:{PREFILL_PORT}/v1",
    )
    proxy_client = openai.OpenAI(
        api_key="MY_KEY",
        base_url=f"http://localhost:{PROXY_PORT}/v1",
    )

    # Get the list of models
    models = decode_client.models.list()
    MODEL = models.data[0].id

    # (1) Check that we can handle a very short prompt,
    # less than the length of the block size.
    completion = proxy_client.completions.create(model=MODEL,
                                                 prompt=SHORT_PROMPT,
                                                 temperature=0)
    proxy_response = completion.choices[0].text
    completion = prefill_client.completions.create(model=MODEL,
                                                   prompt=SHORT_PROMPT,
                                                   temperature=0)
    prefill_response = completion.choices[0].text
    print(f"SMALL PROMPT: {proxy_response=}")
    assert proxy_response == prefill_response

    # (2) Check that we can handle a full prefix cache
    # hit on the D worker but not on the P worker.
    # (2a): prime the D worker.
    completion = decode_client.completions.create(model=MODEL,
                                                  prompt=PROMPT,
                                                  temperature=0)
    decode_response = completion.choices[0].text
    # (2b): send via the P/D setup
    completion = proxy_client.completions.create(model=MODEL,
                                                 prompt=PROMPT,
                                                 temperature=0)
    proxy_response = completion.choices[0].text
    print(f"FULL CACHE HIT: {proxy_response=}")
    assert proxy_response == decode_response

    # (3) Check that we can handle a partial prefix cache
    # hit on the D worker.
    completion = proxy_client.completions.create(model=MODEL,
                                                 prompt=LONG_PROMPT,
                                                 temperature=0)
    proxy_response = completion.choices[0].text
    completion = prefill_client.completions.create(model=MODEL,
                                                   prompt=LONG_PROMPT,
                                                   temperature=0)
    prefill_response = completion.choices[0].text
    print(f"PARTIAL CACHE HIT: {proxy_response=}")
    assert proxy_response == prefill_response
tests/v1/kv_connector/nixl_integration/toy_proxy_server.py (new file, 260 lines)
@@ -0,0 +1,260 @@
# SPDX-License-Identifier: Apache-2.0

import argparse
import itertools
import os
import uuid
from contextlib import asynccontextmanager

import httpx
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse

from vllm.logger import init_logger

logger = init_logger(__name__)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Lifespan context manager to handle startup and shutdown events.
    """
    # Startup: Initialize client pools for prefiller and decoder services
    app.state.prefill_clients = []
    app.state.decode_clients = []

    # Create prefill clients
    for i, (host, port) in enumerate(global_args.prefiller_instances):
        prefiller_base_url = f'http://{host}:{port}/v1'
        app.state.prefill_clients.append({
            'client': httpx.AsyncClient(timeout=None,
                                        base_url=prefiller_base_url),
            'host': host,
            'port': port,
            'id': i,
        })

    # Create decode clients
    for i, (host, port) in enumerate(global_args.decoder_instances):
        decoder_base_url = f'http://{host}:{port}/v1'
        app.state.decode_clients.append({
            'client': httpx.AsyncClient(timeout=None,
                                        base_url=decoder_base_url),
            'host': host,
            'port': port,
            'id': i,
        })

    # Initialize round-robin iterators
    app.state.prefill_iterator = itertools.cycle(
        range(len(app.state.prefill_clients)))
    app.state.decode_iterator = itertools.cycle(
        range(len(app.state.decode_clients)))

    print(f"Initialized {len(app.state.prefill_clients)} prefill clients "
          f"and {len(app.state.decode_clients)} decode clients.")

    yield

    # Shutdown: Close all clients
    for client_info in app.state.prefill_clients:
        await client_info['client'].aclose()

    for client_info in app.state.decode_clients:
        await client_info['client'].aclose()


# Update FastAPI app initialization to use lifespan
app = FastAPI(lifespan=lifespan)


def parse_args():
    parser = argparse.ArgumentParser()

    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--host", type=str, default="localhost")

    # For prefiller instances
    parser.add_argument("--prefiller-hosts",
                        "--prefiller-host",
                        type=str,
                        nargs="+",
                        default=["localhost"])
    parser.add_argument("--prefiller-ports",
                        "--prefiller-port",
                        type=int,
                        nargs="+",
                        default=[8100])

    # For decoder instances
    parser.add_argument("--decoder-hosts",
                        "--decoder-host",
                        type=str,
                        nargs="+",
                        default=["localhost"])
    parser.add_argument("--decoder-ports",
                        "--decoder-port",
                        type=int,
                        nargs="+",
                        default=[8200])

    args = parser.parse_args()

    # Validate and pair hosts with ports
    if len(args.prefiller_hosts) != len(args.prefiller_ports):
        raise ValueError(
            "Number of prefiller hosts must match number of prefiller ports")

    if len(args.decoder_hosts) != len(args.decoder_ports):
        raise ValueError(
            "Number of decoder hosts must match number of decoder ports")

    # Create tuples of (host, port) for each service type
    args.prefiller_instances = list(
        zip(args.prefiller_hosts, args.prefiller_ports))
    args.decoder_instances = list(zip(args.decoder_hosts, args.decoder_ports))

    return args


def get_next_client(app, service_type: str):
    """
    Get the next client in round-robin fashion.

    Args:
        app: The FastAPI app instance
        service_type: Either 'prefill' or 'decode'

    Returns:
        The next client to use
    """
    if service_type == 'prefill':
        client_idx = next(app.state.prefill_iterator)
        return app.state.prefill_clients[client_idx]
    elif service_type == 'decode':
        client_idx = next(app.state.decode_iterator)
        return app.state.decode_clients[client_idx]
    else:
        raise ValueError(f"Unknown service type: {service_type}")


async def send_request_to_service(client_info: dict, endpoint: str,
                                  req_data: dict, request_id: str):
    """
    Send a request to a service using a client from the pool.
    """
    req_data = req_data.copy()
    req_data['kv_transfer_params'] = {
        "do_remote_decode": True,
        "do_remote_prefill": False,
        "remote_engine_id": None,
        "remote_block_ids": None,
        "remote_host": None,
        "remote_port": None
    }
    req_data["stream"] = False
    req_data["max_tokens"] = 1
    if "stream_options" in req_data:
        del req_data["stream_options"]
    headers = {
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        "X-Request-Id": request_id
    }

    response = await client_info['client'].post(endpoint,
                                                json=req_data,
                                                headers=headers)
    response.raise_for_status()

    return response


async def stream_service_response(client_info: dict, endpoint: str,
                                  req_data: dict, request_id: str):
    """
    Asynchronously stream response from a service using a client from the pool.
    """
    headers = {
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        "X-Request-Id": request_id
    }

    async with client_info['client'].stream("POST",
                                            endpoint,
                                            json=req_data,
                                            headers=headers) as response:
        response.raise_for_status()
        async for chunk in response.aiter_bytes():
            yield chunk


@app.post("/v1/completions")
async def handle_completions(request: Request):
    try:
        req_data = await request.json()
        request_id = str(uuid.uuid4())

        # Get the next prefill client in round-robin fashion
        prefill_client_info = get_next_client(request.app, 'prefill')

        # Send request to prefill service
        response = await send_request_to_service(prefill_client_info,
                                                 "/completions", req_data,
                                                 request_id)

        # Extract the needed fields
        response_json = response.json()
        kv_transfer_params = response_json.get('kv_transfer_params', {})
        if kv_transfer_params:
            req_data["kv_transfer_params"] = kv_transfer_params

        # Get the next decode client in round-robin fashion
        decode_client_info = get_next_client(request.app, 'decode')

        logger.debug("Using %s %s", prefill_client_info, decode_client_info)

        # Stream response from decode service
        async def generate_stream():
            async for chunk in stream_service_response(decode_client_info,
                                                       "/completions",
                                                       req_data,
                                                       request_id=request_id):
                yield chunk

        return StreamingResponse(generate_stream(),
                                 media_type="application/json")

    except Exception as e:
        import sys
        import traceback
        exc_info = sys.exc_info()
        print("Error occurred in disagg prefill proxy server"
              " - completions endpoint")
        print(e)
        print("".join(traceback.format_exception(*exc_info)))
        raise


@app.get("/healthcheck")
async def healthcheck():
    """Simple endpoint to check if the server is running."""
    return {
        "status": "ok",
        "prefill_instances": len(app.state.prefill_clients),
        "decode_instances": len(app.state.decode_clients)
    }


if __name__ == '__main__':
    global global_args
    global_args = parse_args()

    import uvicorn
    uvicorn.run(app, host=global_args.host, port=global_args.port)
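The proxy above exposes a single /v1/completions endpoint: it first sends the request to a prefill instance with max_tokens=1, copies the returned kv_transfer_params into the request, and then streams the real completion from a decode instance. A small client sketch exercising it, assuming the proxy started by run_accuracy_test.sh is listening on port 8192 and the servers are serving the model from that script (both assumptions, not part of this file):

# Sketch: send one request through the toy proxy on localhost:8192.
import openai

client = openai.OpenAI(api_key="EMPTY", base_url="http://localhost:8192/v1")
completion = client.completions.create(
    model="Qwen/Qwen3-0.6B",  # whatever model the P/D instances are serving
    prompt="Disaggregated prefill moves KV caches from P to D workers, so",
    max_tokens=32,
    temperature=0,
)
print(completion.choices[0].text)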
tests/v1/kv_connector/unit/__init__.py (new empty file, 0 lines)
tests/v1/kv_connector/unit/test_nixl_connector.py (new file, 73 lines)
@@ -0,0 +1,73 @@
# SPDX-License-Identifier: Apache-2.0

from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
    NixlConnectorMetadata)

from .utils import create_request, create_scheduler, create_vllm_config


def test_basic_interface():
    """Unit test for basic NixlConnector interface functionality."""

    vllm_config = create_vllm_config()
    scheduler = create_scheduler(vllm_config)

    # 2 Full Blocks and 1 Half Block.
    BLOCK_SIZE = vllm_config.cache_config.block_size
    NUM_EXTERNAL_FULL_BLOCKS = 2
    NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))

    request = create_request(request_id=1,
                             num_tokens=NUM_TOKENS,
                             do_remote_prefill=True)
    request_id = request.request_id

    scheduler.add_request(request)

    # Remote Prefill, triggers NixlConnectorMetadata.
    scheduler_output = scheduler.schedule()
    kv_connector_metadata = scheduler_output.kv_connector_metadata
    assert kv_connector_metadata is not None
    assert isinstance(kv_connector_metadata, NixlConnectorMetadata)

    assert len(kv_connector_metadata.requests) == 1
    assert request_id in kv_connector_metadata.requests
    req_meta = kv_connector_metadata.requests[request_id]

    for block_id, block in zip(
            req_meta.local_block_ids, scheduler.kv_cache_manager.
            single_type_manager.req_to_blocks[request_id]):
        assert block_id == block.block_id


def test_prompt_less_than_block_size():
    """
    Test that we can handle the case where the prompt is < block size.

    In this case, the P worker will send empty remote_block_ids.
    The D worker should not schedule an async read in this case,
    since there is nothing to pull.
    """
    vllm_config = create_vllm_config()
    scheduler = create_scheduler(vllm_config)

    # Half of a block.
    BLOCK_SIZE = vllm_config.cache_config.block_size
    NUM_TOKENS = int(BLOCK_SIZE * 0.5)

    # Request will have 0 remote blocks.
    request = create_request(request_id=1,
                             num_tokens=NUM_TOKENS,
                             do_remote_prefill=True,
                             num_remote_blocks=0)
    scheduler.add_request(request)
    scheduler_output = scheduler.schedule()

    # This request should not have to read async.
    kv_connector_metadata = scheduler_output.kv_connector_metadata
    assert kv_connector_metadata is not None
    assert isinstance(kv_connector_metadata, NixlConnectorMetadata)
    assert len(kv_connector_metadata.requests) == 0

    # This request should be scheduled regularly.
    assert len(scheduler_output.scheduled_new_reqs) == 1
tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py (new file, 181 lines)
@@ -0,0 +1,181 @@
# SPDX-License-Identifier: Apache-2.0
import copy

from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT
from vllm.v1.request import FinishReason, RequestStatus

from .utils import (assert_scheduler_empty, create_model_runner_output,
                    create_request, create_scheduler, create_vllm_config)


def test_basic_lifecycle():
    """Test lifecycle of a Remote Decode request."""

    vllm_config = create_vllm_config()
    scheduler = create_scheduler(vllm_config)

    # 2 Full Blocks and 1 Half Block.
    BLOCK_SIZE = vllm_config.cache_config.block_size
    NUM_EXTERNAL_FULL_BLOCKS = 2
    NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))

    request = create_request(request_id=1,
                             max_tokens=1,
                             num_tokens=NUM_TOKENS,
                             do_remote_decode=True)

    scheduler.add_request(request)
    request_id = request.request_id

    # STEP (1): Prefill.
    # (1a): schedule()
    scheduler_output = scheduler.schedule()
    assert len(scheduler.running) == 1
    assert len(scheduler_output.scheduled_new_reqs) == 1

    # (1b): execute_model()
    model_runner_output = create_model_runner_output(reqs=[request])

    # (1c): update_from_output()
    engine_core_outputs = scheduler.update_from_output(scheduler_output,
                                                       model_runner_output)

    # Ensure the request is finished after 1 token.
    assert request.is_finished()
    assert request.status == RequestStatus.FINISHED_LENGTH_CAPPED
    output = engine_core_outputs.outputs[0]
    assert output.finish_reason == FinishReason.LENGTH
    assert output.kv_transfer_params is not None

    # Request freed in Scheduler and in Persistent Batch ...
    assert request_id in scheduler.finished_req_ids
    assert len(scheduler.running) == 0
    assert len(scheduler.waiting) == 0

    # ... but blocks should not be freed.
    blocks = scheduler.kv_cache_manager.single_type_manager.req_to_blocks[
        request_id]
    for block in blocks:
        assert block.ref_cnt == 1

    # STEP (2): Send Finished to PB.
    # (2a): schedule() - pass finished request to PB.
    scheduler_output = scheduler.schedule()
    assert len(scheduler.running) == 0
    assert len(scheduler_output.finished_req_ids) == 1
    assert request_id in scheduler_output.finished_req_ids
    assert len(scheduler_output.scheduled_new_reqs) == 0
    assert len(scheduler_output.scheduled_cached_reqs) == 0
    assert len(scheduler.finished_req_ids) == 0

    # (2b): execute_model()
    model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT

    # (2c): update_from_output()
    scheduler.update_from_output(scheduler_output, model_runner_output)

    # STEP (3): Finished sending.
    # (3a): schedule() - pass finished request to PB.
    scheduler_output = scheduler.schedule()
    assert len(scheduler.running) == 0
    assert len(scheduler_output.finished_req_ids) == 0
    assert len(scheduler_output.scheduled_new_reqs) == 0
    assert len(scheduler_output.scheduled_cached_reqs) == 0
    assert len(scheduler.finished_req_ids) == 0

    # (3b): execute_model()
    model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
    model_runner_output.finished_sending = [request_id]

    # (3c): update_from_output()
    scheduler.update_from_output(scheduler_output, model_runner_output)

    # Confirm we do not have any memory leaks after req lifecycle.
    assert_scheduler_empty(scheduler)


def test_short_prompt_lifecycle():
    """Test lifecycle of a Remote Decode request with short prompt."""

    vllm_config = create_vllm_config()
    scheduler = create_scheduler(vllm_config)

    # Not enough tokens for full block.
    NUM_TOKENS = vllm_config.cache_config.block_size // 2
    request = create_request(request_id=1,
                             max_tokens=1,
                             num_tokens=NUM_TOKENS,
                             do_remote_decode=True)

    scheduler.add_request(request)

    # STEP (1): Prefill.
    # (1a): schedule()
    scheduler_output = scheduler.schedule()
    assert len(scheduler.running) == 1
    assert len(scheduler_output.scheduled_new_reqs) == 1

    # (1b): execute_model()
    model_runner_output = create_model_runner_output(reqs=[request])

    # (1c): update_from_output()
    # Since tokens < block_size, there will be no kv xfer.
    # So this should be cleaned up immediately.
    _ = scheduler.update_from_output(scheduler_output, model_runner_output)

    # Confirm we do not have any memory leaks after req lifecycle.
    # We need one more call to schedule() to clear data for persistent batch.
    _ = scheduler.schedule()
    assert_scheduler_empty(scheduler)


def test_prefix_cache_lifecycle():
    """Test that remote decode params still work with a prefix cache hit."""

    vllm_config = create_vllm_config()
    scheduler = create_scheduler(vllm_config)

    # Prime the KVCache.
    BLOCK_SIZE = vllm_config.cache_config.block_size
    NUM_EXTERNAL_FULL_BLOCKS = 3
    NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))

    request_normal = create_request(request_id=1, num_tokens=NUM_TOKENS)

    scheduler.add_request(request_normal)
    scheduler_output = scheduler.schedule()
    model_runner_output = create_model_runner_output(reqs=[request_normal],
                                                     use_eos=True)
    scheduler.update_from_output(scheduler_output, model_runner_output)
    scheduler.schedule()
    scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT)

    #####################
    # Actual Test: confirm we send all blocks.

    # Step (1): Send the KV Transfer.
    NUM_EXTERNAL_FULL_BLOCKS -= 1
    NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))

    request_remote = create_request(request_id=1,
                                    num_tokens=NUM_TOKENS,
                                    do_remote_decode=True)

    scheduler.add_request(request_remote)
    scheduler_output = scheduler.schedule()
    model_runner_output = create_model_runner_output(reqs=[request_remote])
    eco = scheduler.update_from_output(scheduler_output, model_runner_output)
    kv_transfer_params = eco.outputs[0].kv_transfer_params

    # Ensure we send all block ids, even if there is a cache hit.
    assert (len(
        kv_transfer_params["remote_block_ids"]) == NUM_EXTERNAL_FULL_BLOCKS)

    # STEP (2): Ensure it is freed.
    scheduler_output = scheduler.schedule()
    scheduler.schedule()
    model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
    model_runner_output.finished_sending = [request_remote.request_id]
    scheduler.update_from_output(scheduler_output, model_runner_output)
    _ = scheduler.schedule()
    assert_scheduler_empty(scheduler)
tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py (new file, 342 lines)
@@ -0,0 +1,342 @@
# SPDX-License-Identifier: Apache-2.0
import copy

from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT
from vllm.v1.request import FinishReason, RequestStatus

from .utils import (assert_scheduler_empty, create_model_runner_output,
                    create_request, create_scheduler, create_vllm_config)


def test_basic_lifecycle():
    """Test lifecycle of a remote prefill."""

    vllm_config = create_vllm_config()
    scheduler = create_scheduler(vllm_config)

    # 2 Full Blocks and 1 Half Block.
    BLOCK_SIZE = vllm_config.cache_config.block_size
    NUM_EXTERNAL_FULL_BLOCKS = 2
    NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
    START_FREE_BLOCK_QUEUE_SIZE = (
        scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks)

    request = create_request(request_id=1,
                             num_tokens=NUM_TOKENS,
                             do_remote_prefill=True)

    scheduler.add_request(request)
    request_id = request.request_id

    # STEP (1):
    # (1a): schedule()
    scheduler_output = scheduler.schedule()

    # Nothing running and empty scheduler output.
    assert len(scheduler.running) == 0
    assert len(scheduler_output.scheduled_new_reqs) == 0
    assert len(scheduler_output.scheduled_cached_reqs) == 0
    assert len(scheduler_output.num_scheduled_tokens) == 0
    assert scheduler_output.total_num_scheduled_tokens == 0

    # Req waiting for KVs with no computed/scheduled toks ...
    assert len(scheduler.waiting) == 1
    assert request in scheduler.waiting
    assert (request.status == RequestStatus.WAITING_FOR_REMOTE_KVS)
    assert (request.num_computed_tokens == 0)

    # ... but should have (uncached) blocks allocated to it.
    block_pool = scheduler.kv_cache_manager.block_pool
    assert (block_pool.free_block_queue.num_free_blocks
            < START_FREE_BLOCK_QUEUE_SIZE)
    assert len(block_pool.cached_block_hash_to_block) == 0
    blocks = scheduler.kv_cache_manager.single_type_manager.req_to_blocks[
        request_id]
    for block in blocks:
        assert block._block_hash is None

    # (1b): forward()
    model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT

    # (1c): update_from_output()
    engine_core_outputs = scheduler.update_from_output(scheduler_output,
                                                       model_runner_output)
    assert len(engine_core_outputs.outputs) == 0

    # STEP (2):
    # (2a): schedule(): nothing happens!
    scheduler_output = scheduler.schedule()
    assert len(scheduler.waiting) == 1
    assert len(scheduler.running) == 0

    # (2b): forward(): request finishes recv.
    model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
    model_runner_output.finished_recving = [request_id]

    # (2c): update_from_output():
    engine_core_outputs = scheduler.update_from_output(scheduler_output,
                                                       model_runner_output)
    assert len(scheduler.waiting) == 1
    assert (request_id in scheduler.finished_recving_kv_req_ids)

    # STEP (3):
    # (3a): schedule(): this should actually schedule.
    scheduler_output = scheduler.schedule()
    assert len(scheduler.running) == 1

    # Confirm the blocks are actually allocated.
    num_hashed_blocks = 0
    blocks = scheduler.kv_cache_manager.single_type_manager.req_to_blocks[
        request_id]
    for block in blocks:
        assert block.ref_cnt == 1
        num_hashed_blocks += (1 if block._block_hash is not None else 0)
    assert num_hashed_blocks == NUM_EXTERNAL_FULL_BLOCKS

    # Confirm the rest of the prompt is scheduled in this step.
    scheduled_req = scheduler_output.scheduled_new_reqs[0]
    num_scheduled_tokens = scheduler_output.num_scheduled_tokens[request_id]
    num_computed_tokens = scheduled_req.num_computed_tokens
    total_prompt_tokens = len(scheduled_req.prompt_token_ids)
    assert (num_scheduled_tokens == total_prompt_tokens - num_computed_tokens)

    # (3b): execute_model()
    model_runner_output = create_model_runner_output([request])
    # (3c): update_from_output()
    scheduler.update_from_output(scheduler_output, model_runner_output)

    # Step (4): Hit EOS.
    scheduler_output = scheduler.schedule()
    model_runner_output = create_model_runner_output([request], use_eos=True)
    engine_core_outputs = scheduler.update_from_output(scheduler_output,
                                                       model_runner_output)
    scheduler.schedule()

    outputs = engine_core_outputs.outputs
    assert len(outputs) == 1
    output = outputs[0]
    assert output.finish_reason == FinishReason.STOP
    assert_scheduler_empty(scheduler)


def test_interleaved_lifecycle():
    """Test that remote prefills work well with other requests."""

    vllm_config = create_vllm_config()
    scheduler = create_scheduler(vllm_config)

    # 2 Full Blocks and 1 Half Block.
    BLOCK_SIZE = vllm_config.cache_config.block_size
    NUM_EXTERNAL_FULL_BLOCKS = 2
    NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))

    request_remote = create_request(request_id=1,
                                    num_tokens=NUM_TOKENS,
                                    do_remote_prefill=True)
    request_local_a = create_request(
        request_id=2,
        num_tokens=NUM_TOKENS,
    )
    request_local_b = create_request(
        request_id=3,
        num_tokens=NUM_TOKENS,
    )

    # STEP 1: Regular request is running.
    scheduler.add_request(request_local_a)
    scheduler_output = scheduler.schedule()
    assert len(scheduler.running) == 1

    model_runner_output = create_model_runner_output([request_local_a])
    scheduler.update_from_output(scheduler_output, model_runner_output)

    # STEP 2: Add a local and remote request.
    scheduler.add_request(request_local_b)
    scheduler.add_request(request_remote)
    scheduler_output = scheduler.schedule()
    assert len(scheduler.running) == 2
    assert len(scheduler.waiting) == 1
    assert len(scheduler_output.scheduled_new_reqs) == 1
    assert len(scheduler_output.scheduled_cached_reqs) == 1

    model_runner_output = create_model_runner_output(
        [request_local_a, request_local_b])
    scheduler.update_from_output(scheduler_output, model_runner_output)

    # STEP 3: continue running, KVs not arrived yet.
    scheduler_output = scheduler.schedule()
    assert len(scheduler.running) == 2
    assert len(scheduler.waiting) == 1
    assert len(scheduler_output.scheduled_new_reqs) == 0
    assert len(scheduler_output.scheduled_cached_reqs) == 2

    model_runner_output = create_model_runner_output(
        reqs=[request_local_a, request_local_b])
    scheduler.update_from_output(scheduler_output, model_runner_output)
    assert len(scheduler.running) == 2
    assert len(scheduler.waiting) == 1
    assert len(scheduler_output.scheduled_new_reqs) == 0
    assert len(scheduler_output.scheduled_cached_reqs) == 2

    # STEP 4: KVs arrive.
    scheduler_output = scheduler.schedule()
    assert len(scheduler.running) == 2
    assert len(scheduler.waiting) == 1
    assert len(scheduler_output.scheduled_new_reqs) == 0
    assert len(scheduler_output.scheduled_cached_reqs) == 2

    model_runner_output = create_model_runner_output(
        [request_local_a, request_local_b],
        finished_recving=[request_remote.request_id])
    scheduler.update_from_output(scheduler_output, model_runner_output)

    # STEP 5: RECVed KVs are sent to ModelRunner.
    scheduler_output = scheduler.schedule()
    assert len(scheduler.running) == 3
    assert len(scheduler.waiting) == 0
    assert len(scheduler_output.scheduled_new_reqs) == 1
    assert len(scheduler_output.scheduled_cached_reqs) == 2

    model_runner_output = create_model_runner_output(
        [request_local_a, request_local_b, request_remote])
    scheduler.update_from_output(scheduler_output, model_runner_output)

    # STEP 6: Hit EOS and free.
    scheduler_output = scheduler.schedule()
    model_runner_output = create_model_runner_output(
        [request_local_a, request_local_b, request_remote],
        use_eos=True,
    )
    scheduler.update_from_output(scheduler_output, model_runner_output)
    scheduler.schedule()
    assert_scheduler_empty(scheduler)


def test_no_spurious_prefix_caching():
    """
    With P/D, blocks can be allocated but uncomputed for
    multiple engine steps. This test confirms that we do
    not accidentally have cache hits against uncomputed
    blocks.
    """

    vllm_config = create_vllm_config()
    scheduler = create_scheduler(vllm_config)

    # 2 and a half full external blocks.
    BLOCK_SIZE = vllm_config.cache_config.block_size
    NUM_EXTERNAL_FULL_BLOCKS = 2
    NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))

    # Both of these requests have prompts like [1,1,1,1,1, ...]
    request_remote = create_request(
        request_id=1,
        num_tokens=NUM_TOKENS,
        do_remote_prefill=True,
        use_all_1s_for_prompt_tokens=True,
    )

    request_local = create_request(
        request_id=2,
        num_tokens=NUM_TOKENS,
        do_remote_prefill=False,
        use_all_1s_for_prompt_tokens=True,
    )

    # Schedule the remote prefill request. This should not
    # cause any blocks to be cached.
    scheduler.add_request(request_remote)
    scheduler_output = scheduler.schedule()
    scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT)
    assert len(scheduler.waiting) == 1

    # Schedule the local prefill request. This should
    # cause blocks to be cached, but separately from
    # the remote request's uncomputed blocks.
    scheduler.add_request(request_local)
    scheduler_output = scheduler.schedule()
    assert len(scheduler.running) == 1
    assert len(scheduler.waiting) == 1

    local_blocks = scheduler.kv_cache_manager.single_type_manager.req_to_blocks[
        request_local.request_id]
    remote_blocks = scheduler.kv_cache_manager.single_type_manager.req_to_blocks[  # noqa: E501
        request_remote.request_id]

    # Local should have cached blocks (but not all due to preallocate).
    num_hashed_blocks = 0
    for block in local_blocks:
        assert block.ref_cnt == 1
        num_hashed_blocks += (1 if block._block_hash is not None else 0)
    assert num_hashed_blocks > 0

    # Remote blocks should not be cached.
    for block in remote_blocks:
        assert block.ref_cnt == 1
        assert block._block_hash is None


def test_full_block_prompt():
    """Test that we handle a prompt that is the full block size."""

    vllm_config = create_vllm_config()
    scheduler = create_scheduler(vllm_config)

    # 2 Full Blocks.
    BLOCK_SIZE = vllm_config.cache_config.block_size
    NUM_EXTERNAL_FULL_BLOCKS = 2
    NUM_TOKENS = int(BLOCK_SIZE * NUM_EXTERNAL_FULL_BLOCKS)

    request = create_request(request_id=1,
                             num_tokens=NUM_TOKENS,
                             do_remote_prefill=True)

    scheduler.add_request(request)
    request_id = request.request_id

    # STEP (1): Initialize a recv.
    scheduler_output = scheduler.schedule()
    # All blocks should be allocated.
    num_blocks = len(scheduler.kv_cache_manager.single_type_manager.
                     req_to_blocks[request_id])
    assert num_blocks == NUM_EXTERNAL_FULL_BLOCKS
    model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT
    scheduler.update_from_output(scheduler_output, model_runner_output)

    # STEP (2): Recv.
    scheduler_output = scheduler.schedule()
    model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
    model_runner_output.finished_recving = [request_id]
    scheduler.update_from_output(scheduler_output, model_runner_output)
    assert len(scheduler.waiting) == 1
    assert (request_id in scheduler.finished_recving_kv_req_ids)

    # STEP (3): Run as usual.
    scheduler_output = scheduler.schedule()

    # We need to recompute the final token of the prompt to generate
    # the first new token, so we should not have a new block.
    num_blocks = len(scheduler.kv_cache_manager.single_type_manager.
                     req_to_blocks[request_id])
    assert num_blocks == NUM_EXTERNAL_FULL_BLOCKS
    assert (scheduler_output.scheduled_new_reqs[0].num_computed_tokens ==
            NUM_TOKENS - 1)
    assert (scheduler_output.num_scheduled_tokens[request_id] == 1)

    model_runner_output = create_model_runner_output([request])
    scheduler.update_from_output(scheduler_output, model_runner_output)

    # Step (4): Hit EOS.
    scheduler_output = scheduler.schedule()
    model_runner_output = create_model_runner_output([request], use_eos=True)
    engine_core_outputs = scheduler.update_from_output(scheduler_output,
                                                       model_runner_output)
    scheduler.schedule()

    outputs = engine_core_outputs.outputs
    assert len(outputs) == 1
    output = outputs[0]
    assert output.finish_reason == FinishReason.STOP
    assert_scheduler_empty(scheduler)
tests/v1/kv_connector/unit/utils.py (new file, 190 lines)
@@ -0,0 +1,190 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional

import torch

from vllm import SamplingParams
from vllm.config import (CacheConfig, DeviceConfig, KVTransferConfig,
                         ModelConfig, SchedulerConfig, VllmConfig)
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
    NixlKVTransferParams)
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                        KVCacheGroupSpec)
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request
from vllm.v1.structured_output import StructuredOutputManager

EOS_TOKEN_ID = 50256


def assert_scheduler_empty(scheduler: Scheduler):
    """Confirm the scheduler is "empty" - i.e. no leaks."""
    # Scheduler Metadata.
    assert len(scheduler.requests) == 0
    assert len(scheduler.waiting) == 0
    assert len(scheduler.running) == 0
    assert len(scheduler.finished_req_ids) == 0
    assert len(scheduler.finished_recving_kv_req_ids) == 0
    assert len(scheduler._cached_reqs_data) == 0

    # EncoderCacheManager.
    assert len(scheduler.encoder_cache_manager.freed) == 0
    assert len(scheduler.encoder_cache_manager.cached) == 0

    # KVCache Manager.
    assert len(
        scheduler.kv_cache_manager.single_type_manager.req_to_blocks) == 0
    assert len(scheduler.kv_cache_manager.req_to_block_hashes) == 0
    assert len(
        scheduler.kv_cache_manager.single_type_manager.num_cached_block) == 0
    num_free_blocks = (
        scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks)
    assert num_free_blocks == (
        scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1)

    # NOTE(rob): just the ref count on blocks will be 0. The hash
    # value, etc will remain since we lazily evict for prefix cache.
    for block in scheduler.kv_cache_manager.block_pool.blocks:
        assert block.ref_cnt == 0


def create_vllm_config(
    model: str = "facebook/opt-125m",
    max_num_seqs: int = 16,
    max_num_batched_tokens: int = 64,
    block_size: int = 16,
) -> VllmConfig:
    """Initialize VllmConfig For Testing."""
    scheduler_config = SchedulerConfig(
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=max_num_batched_tokens,
        max_model_len=max_num_batched_tokens,
    )
    model_config = ModelConfig(
        model=model,
        task="auto",
        tokenizer=model,
        tokenizer_mode="auto",
        trust_remote_code=True,
        dtype="float16",
        seed=42,
    )
    # Cache config, optionally force APC
    cache_config = CacheConfig(
        block_size=block_size,
        gpu_memory_utilization=0.9,
        swap_space=0,
        cache_dtype="auto",
        enable_prefix_caching=True,
    )
    kv_transfer_config = KVTransferConfig(
        kv_connector="NixlConnector",
        kv_role="kv_both",
    )
    return VllmConfig(scheduler_config=scheduler_config,
                      model_config=model_config,
                      cache_config=cache_config,
                      kv_transfer_config=kv_transfer_config,
                      device_config=DeviceConfig("cpu"))


def create_scheduler(
    vllm_config: VllmConfig,
    num_blocks: int = 10000,
) -> Scheduler:
    """Initialize Scheduler For Testing."""
    block_size = vllm_config.cache_config.block_size
    kv_cache_config = KVCacheConfig(
        num_blocks=num_blocks,  # A large number of blocks to hold all requests
        tensors={},
        kv_cache_groups=[
            KVCacheGroupSpec(['layer'],
                             FullAttentionSpec(block_size, 1, 1, torch.float32,
                                               False))
        ],
    )
    vllm_config.cache_config.num_gpu_blocks = num_blocks
    return Scheduler(
        vllm_config=vllm_config,
        kv_cache_config=kv_cache_config,
        log_stats=True,
        structured_output_manager=StructuredOutputManager(vllm_config),
    )


def create_request(
    request_id: int,
    num_tokens: int = 10,
    max_tokens: int = 16,
    do_remote_decode: bool = False,
    do_remote_prefill: bool = False,
    use_all_1s_for_prompt_tokens: bool = False,
    num_remote_blocks: int = 3,
) -> Request:
    """Make dummy request for testing."""

    if do_remote_decode:
        assert not do_remote_prefill
        kv_transfer_params = NixlKVTransferParams(do_remote_prefill=False,
                                                  do_remote_decode=True)
    elif do_remote_prefill:
        kv_transfer_params = NixlKVTransferParams(
            do_remote_prefill=True,
            do_remote_decode=False,
            remote_engine_id="my-engine-id",
            remote_block_ids=list(range(num_remote_blocks)),
            remote_host="my-host",
            remote_port=1234)
    else:
        kv_transfer_params = None

    max_tokens = 1 if do_remote_decode else max_tokens
    sampling_params = SamplingParams(max_tokens=max_tokens)

    if use_all_1s_for_prompt_tokens:
        prompt_token_ids = [1] * num_tokens
    else:
        prompt_token_ids = [i * request_id for i in range(num_tokens)]

    req = Request(
        request_id=f"id-{request_id}",
        prompt_token_ids=prompt_token_ids,
        sampling_params=sampling_params,
        multi_modal_inputs=None,
        multi_modal_placeholders=None,
        multi_modal_hashes=None,
        eos_token_id=EOS_TOKEN_ID,
        arrival_time=0,
    )
    req.kv_transfer_params = kv_transfer_params
    return req


def create_model_runner_output(
    reqs: list[Request],
    finished_sending: Optional[list[str]] = None,
    finished_recving: Optional[list[str]] = None,
    use_eos: bool = False,
) -> ModelRunnerOutput:
    """Make dummy model runner output for testing."""

    # Make request data.
    req_ids = [req.request_id for req in reqs]
    req_id_to_index = {req_id: idx for idx, req_id in enumerate(req_ids)}

    # Make sampled tokens.
    sampled_token = EOS_TOKEN_ID if use_eos else 0
    sampled_token_ids = [[sampled_token] for _ in req_ids]

    # Make output data structure.
    return ModelRunnerOutput(
        req_ids=req_ids,
        req_id_to_index=req_id_to_index,
        sampled_token_ids=sampled_token_ids,
        spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
        finished_sending=finished_sending,
        finished_recving=finished_recving,
    )
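Taken together, the helpers above are meant to compose roughly as follows; this is a hedged sketch of the intended flow (a plain local request, not a test taken from this change):

# Illustrative sketch only: how the utils above combine in a unit test.
def _example_flow():
    vllm_config = create_vllm_config()
    scheduler = create_scheduler(vllm_config)

    request = create_request(request_id=1, num_tokens=32)
    scheduler.add_request(request)

    scheduler_output = scheduler.schedule()
    model_runner_output = create_model_runner_output([request], use_eos=True)
    scheduler.update_from_output(scheduler_output, model_runner_output)
    scheduler.schedule()
    assert_scheduler_empty(scheduler)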
@@ -8,6 +8,7 @@ import inspect
 import json
 import re
 import textwrap
+import uuid
 import warnings
 from collections import Counter
 from contextlib import contextmanager
@@ -3438,6 +3439,9 @@ class KVTransferConfig:
     """The KV connector for vLLM to transmit KV caches between vLLM instances.
     """
 
+    engine_id: str = str(uuid.uuid4())
+    """The engine id for KV transfers."""
+
     kv_buffer_device: Optional[str] = "cuda"
     """The device used by kv connector to buffer the KV cache.
     Currently only support 'cuda'."""
@@ -3448,7 +3452,7 @@ class KVTransferConfig:
 
     kv_role: Optional[KVRole] = None
     """Whether this vLLM instance produces, consumes KV cache, or both. Choices
-    are 'kv_producer', 'kv_consumer', and 'both'."""
+    are 'kv_producer', 'kv_consumer', and 'kv_both'."""
 
     kv_rank: Optional[int] = None
     """The rank of this vLLM instance in the KV cache transfer. Typical value:
@@ -105,3 +105,8 @@ KVConnectorFactory.register_connector(
     "LMCacheConnectorV1",
     "vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector",
     "LMCacheConnectorV1")
+
+KVConnectorFactory.register_connector(
+    "NixlConnector",
+    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector",
+    "NixlConnector")
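With the connector registered under the name "NixlConnector", it can be selected through the transfer config by name, mirroring the test utilities in this change:

# Sketch: select the newly registered connector by name.
from vllm.config import KVTransferConfig

kv_transfer_config = KVTransferConfig(kv_connector="NixlConnector",
                                      kv_role="kv_both")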
@@ -1,8 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
 from vllm.distributed.kv_transfer.kv_connector.v1.base import (
-    KVConnectorBase_V1, KVConnectorRole)
+    KVConnectorBase_V1, KVConnectorRole, KVTransferParams)
 
-__all__ = [
-    "KVConnectorRole",
-    "KVConnectorBase_V1",
-]
+__all__ = ["KVConnectorRole", "KVConnectorBase_V1", "KVTransferParams"]
@@ -23,7 +23,7 @@ The class provides the following primitives:
 import enum
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Optional
 
 import torch
 
@@ -34,6 +34,7 @@ if TYPE_CHECKING:
     from vllm.attention.backends.abstract import AttentionMetadata
     from vllm.config import VllmConfig
     from vllm.forward_context import ForwardContext
+    from vllm.v1.core.kv_cache_manager import KVCacheBlocks
     from vllm.v1.request import Request
 
 logger = init_logger(__name__)
@@ -47,12 +48,34 @@ class KVConnectorRole(enum.Enum):
     WORKER = 1
 
 
+class KVTransferParams:
+    """
+    Abstract KVTransferParams used to send KVTransfer
+    parameters between instances of vLLM.
+
+    Specific instances of KVConnector customize this
+    method for serializing / deserializing msgs sent
+    via the HTTP protocol.
+    """
+
+    @staticmethod
+    def from_raw_dict(
+            raw_dict: Optional[dict[str,
+                                    Any]]) -> Optional["KVTransferParams"]:
+        return None
+
+
 @dataclass
 class KVConnectorMetadata:
+    """
+    Abstract Metadata used to communicate between the
+    Scheduler KVConnector and Worker KVConnector.
+    """
     pass
 
 
 class KVConnectorBase_V1(ABC):
+    _KVTransferParams = KVTransferParams
 
     def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
         logger.warning(
@@ -66,6 +89,10 @@ class KVConnectorBase_V1(ABC):
     def role(self) -> KVConnectorRole:
         return self._role
 
+    # ==============================
+    # Worker-side methods
+    # ==============================
+
     def bind_connector_metadata(
             self, connector_metadata: KVConnectorMetadata) -> None:
         """Set the connector metadata from the scheduler.
@@ -97,9 +124,15 @@ class KVConnectorBase_V1(ABC):
         """
         return self._connector_metadata
 
-    # ==============================
-    # Worker-side methods
-    # ==============================
+    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
+        """
+        Initialize with the KV caches. Useful for pre-registering the
+        KV Caches in the KVConnector (e.g. for NIXL).
+
+        Args: kv_caches:
+            dictionary of layer names, kv cache
+        """
+        return
 
     @abstractmethod
     def start_load_kv(self, forward_context: "ForwardContext",
@@ -162,15 +195,37 @@ class KVConnectorBase_V1(ABC):
         """
         pass
 
+    def get_finished(
+        self, finished_req_ids: set[str]
+    ) -> tuple[Optional[set[str]], Optional[set[str]]]:
+        """
+        Notifies worker-side connector ids of requests that have
+        finished generating tokens.
+
+        Returns:
+            ids of requests that have finished asynchronous (recving, sending).
+            The finished saves/sends req ids must belong to a set provided in a
+            call to this method (this call or a prior one).
+        """
+        return None, None
+
     # ==============================
     # Scheduler-side methods
     # ==============================
 
+    def set_kv_transfer_params(self, request: "Request"):
+        """Parse raw KV Transfer params."""
+        assert request.kv_transfer_params is None
+        kv_transfer_params = self._KVTransferParams.from_raw_dict(
+            request.raw_kv_transfer_params)
+        request.kv_transfer_params = kv_transfer_params
+
     @abstractmethod
     def get_num_new_matched_tokens(
         self,
         request: "Request",
         num_computed_tokens: int,
-    ) -> int:
+    ) -> tuple[int, bool]:
         """
         Get number of new tokens that can be loaded from the
         external KV cache beyond the num_computed_tokens.
@@ -181,13 +236,16 @@ class KVConnectorBase_V1(ABC):
             computed tokens for this request
 
         Returns:
-            the number of tokens that can be loaded from the
-            external KV cache beyond what is already computed.
+            * the number of tokens that can be loaded from the
+              external KV cache beyond what is already computed.
+            * true if external KV cache tokens will be loaded
+              asynchronously (between scheduler steps).
         """
         pass
 
     @abstractmethod
     def update_state_after_alloc(self, request: "Request",
+                                 blocks: "KVCacheBlocks",
                                  num_external_tokens: int):
         """
         Update KVConnector state after block allocation.
@@ -207,3 +265,20 @@ class KVConnectorBase_V1(ABC):
             scheduler_output (SchedulerOutput): the scheduler output object.
         """
         pass
+
+    def request_finished(
+        self,
+        request: "Request",
+        block_ids: list[int],
+    ) -> tuple[bool, Optional[dict[str, Any]]]:
+        """
+        Called when a request has finished, before its blocks are freed.
+
+        Returns:
+            True if the request is being saved/sent asynchronously and blocks
+            should not be freed until the request_id is returned from
+            get_finished().
+            Optional KVTransferParams to be included in the request outputs
+            returned by the engine.
+        """
+        return False, None
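For connector authors, a partial scheduler-side skeleton against the updated interface might look as follows; this is a sketch only, "MyConnector" is hypothetical and the worker-side abstract methods are omitted for brevity:

# Sketch of a connector adopting the new tuple-returning scheduler API.
from typing import Any, Optional

from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorBase_V1


class MyConnector(KVConnectorBase_V1):

    def get_num_new_matched_tokens(self, request,
                                   num_computed_tokens) -> tuple[int, bool]:
        # No externally cached tokens; any loading would be synchronous.
        return 0, False

    def update_state_after_alloc(self, request, blocks, num_external_tokens):
        pass

    def request_finished(self, request,
                         block_ids) -> tuple[bool, Optional[dict[str, Any]]]:
        # Free blocks immediately; nothing is sent asynchronously.
        return False, None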
@@ -13,6 +13,7 @@ from vllm.v1.core.sched.output import SchedulerOutput
 if TYPE_CHECKING:
     from vllm.attention.backends.abstract import AttentionMetadata
     from vllm.forward_context import ForwardContext
+    from vllm.v1.core.kv_cache_manager import KVCacheBlocks
     from vllm.v1.request import Request
 
 logger = init_logger(__name__)
@@ -92,7 +93,7 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
         self,
         request: "Request",
         num_computed_tokens: int,
-    ) -> int:
+    ) -> tuple[int, bool]:
         """
         Get number of new tokens that can be loaded from the
         external KV cache beyond the num_computed_tokens.
@@ -107,9 +108,10 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
             external KV cache beyond what is already computed.
         """
         return self._lmcache_engine.get_num_new_matched_tokens(
-            request, num_computed_tokens)
+            request, num_computed_tokens), False
 
     def update_state_after_alloc(self, request: "Request",
+                                 blocks: "KVCacheBlocks",
                                  num_external_tokens: int):
         """
         Update KVConnector state after block allocation.
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py (new file, 805 lines)
@@ -0,0 +1,805 @@
# SPDX-License-Identifier: Apache-2.0
import contextlib
import math
import threading
import time
import uuid
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Iterator

import msgspec
import torch
import zmq
from typing_extensions import Optional

from vllm import envs
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
    KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole, KVTransferParams)
from vllm.distributed.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
    get_tp_group)
from vllm.logger import init_logger
from vllm.utils import round_down
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.request import RequestStatus

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionMetadata
    from vllm.forward_context import ForwardContext
    from vllm.v1.core.kv_cache_manager import KVCacheBlocks
    from vllm.v1.request import Request

GET_META_MSG = b"get_meta_msg"

logger = init_logger(__name__)

# Lazy import nixl_wrapper to avoid loading nixl_bindings if nixl is not used
try:
    from nixl._api import nixl_agent as NixlWrapper
    logger.info("NIXL is available")
except ImportError:
    logger.warning("NIXL is not available")
    NixlWrapper = None


@dataclass
class NixlKVTransferParams(KVTransferParams):

    def __init__(
        self,
        do_remote_prefill: bool,
        do_remote_decode: bool,
        remote_block_ids: Optional[list[int]] = None,
        remote_host: Optional[str] = None,
        remote_port: Optional[int] = None,
        remote_engine_id: Optional[str] = None,
    ):
        self.do_remote_prefill = do_remote_prefill
        self.do_remote_decode = do_remote_decode
        self.remote_block_ids = remote_block_ids
        self.remote_host = remote_host
        self.remote_port = remote_port
        self.remote_engine_id = remote_engine_id

    @staticmethod
    def from_raw_dict(
            raw_dict: Optional[dict[str,
                                    Any]]) -> Optional["NixlKVTransferParams"]:

        # If no raw transfer params passed, return None.
        if raw_dict is None:
            return None

        # Validate the request is formatted properly.
        if (("do_remote_prefill" not in raw_dict)
                or ("do_remote_decode" not in raw_dict)
                or ("remote_block_ids" not in raw_dict)
                or ("remote_host" not in raw_dict)
                or ("remote_port" not in raw_dict)
                or ("remote_engine_id" not in raw_dict)):
            logger.warning(
                "Got invalid KVTransferParams: %s. This "
                "request will not utilize KVTransfer", raw_dict)
            return None

        return NixlKVTransferParams(
            do_remote_prefill=raw_dict["do_remote_prefill"],
            do_remote_decode=raw_dict["do_remote_decode"],
            remote_block_ids=raw_dict["remote_block_ids"],
            remote_host=raw_dict["remote_host"],
            remote_port=raw_dict["remote_port"],
            remote_engine_id=raw_dict["remote_engine_id"],
        )


class NixlAgentMetadata(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        # required for @cached_property.
        dict=True):
    engine_id: str
    agent_metadata: bytes
    kv_caches_base_addr: list[int]
    num_blocks: int


@dataclass
class ReqMeta:
    local_block_ids: list[int]
    remote_block_ids: list[int]
    remote_host: str
    remote_port: int
    remote_engine_id: str


class NixlConnectorMetadata(KVConnectorMetadata):

    def __init__(self):
        self.requests: dict[str, ReqMeta] = {}

    def add_new_req(
        self,
        request_id: str,
        local_block_ids: list[int],
        kv_transfer_params: NixlKVTransferParams,
    ):
        assert request_id not in self.requests
        assert kv_transfer_params.remote_block_ids is not None
        assert kv_transfer_params.remote_engine_id is not None
        assert kv_transfer_params.remote_host is not None
        assert kv_transfer_params.remote_port is not None

        self.requests[request_id] = ReqMeta(
            local_block_ids=local_block_ids,
            remote_block_ids=kv_transfer_params.remote_block_ids,
            remote_engine_id=kv_transfer_params.remote_engine_id,
            remote_host=kv_transfer_params.remote_host,
            remote_port=kv_transfer_params.remote_port,
        )


class NixlConnector(KVConnectorBase_V1):
    _KVTransferParams: type[NixlKVTransferParams] = NixlKVTransferParams

    def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
        assert vllm_config.kv_transfer_config is not None
        self.engine_id = vllm_config.kv_transfer_config.engine_id

        if role == KVConnectorRole.SCHEDULER:
            self.connector_scheduler: Optional[NixlConnectorScheduler] = \
                NixlConnectorScheduler(vllm_config, str(self.engine_id))
            self.connector_worker: Optional[NixlConnectorWorker] = None
        elif role == KVConnectorRole.WORKER:
            self.connector_scheduler = None
            self.connector_worker = NixlConnectorWorker(str(self.engine_id))

    ############################################################
    # Scheduler Side Methods
    ############################################################

    def get_num_new_matched_tokens(
            self, request: "Request",
            num_computed_tokens: int) -> tuple[int, bool]:
        assert self.connector_scheduler is not None
        return self.connector_scheduler.get_num_new_matched_tokens(
            request, num_computed_tokens)

    def update_state_after_alloc(self, request: "Request",
                                 blocks: "KVCacheBlocks",
                                 num_external_tokens: int):
        assert self.connector_scheduler is not None
        return self.connector_scheduler.update_state_after_alloc(
            request, blocks, num_external_tokens)

    def build_connector_meta(
        self,
        scheduler_output: SchedulerOutput,
    ) -> KVConnectorMetadata:
        assert self.connector_scheduler is not None
        return self.connector_scheduler.build_connector_meta(scheduler_output)

    def request_finished(
        self,
        request: "Request",
        block_ids: list[int],
    ) -> tuple[bool, Optional[dict[str, Any]]]:
        assert self.connector_scheduler is not None
        return self.connector_scheduler.request_finished(request, block_ids)

    ############################################################
    # Worker Side Methods
    ############################################################
    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        assert self.connector_worker is not None
        self.connector_worker.register_kv_caches(kv_caches)

    def get_finished(self,
                     finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
        """Get the finished recving and sending requests."""
        assert self.connector_worker is not None
        return self.connector_worker.get_finished()

    def start_load_kv(self, forward_context: "ForwardContext",
                      **kwargs) -> None:
        assert self.connector_worker is not None
        assert isinstance(self._connector_metadata, NixlConnectorMetadata)
        self.connector_worker.start_load_kv(self._connector_metadata)

    def wait_for_layer_load(self, layer_name: str) -> None:
        """NixlConnector does not do layerwise saving."""
        pass

    def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
                      attn_metadata: "AttentionMetadata", **kwargs) -> None:
        """NixlConnector does not save explicitly."""
        pass

    def wait_for_save(self):
        """NixlConnector does not save explicitly."""
        pass


class NixlConnectorScheduler:
    """Implementation of Scheduler side methods"""

    def __init__(self, vllm_config: VllmConfig, engine_id: str):
        self.vllm_config = vllm_config
        self.block_size = vllm_config.cache_config.block_size
        self.engine_id = engine_id
        logger.info("Initializing NIXL Scheduler %s", engine_id)

        # Requests that need to start recv.
        # New requests are added by update_state_after_alloc in
        # the scheduler. Used to make metadata passed to Worker.
        self._reqs_need_recv: dict[str, tuple["Request", list[int]]] = {}

    def get_num_new_matched_tokens(
            self, request: "Request",
            num_computed_tokens: int) -> tuple[int, bool]:
        """
        For remote prefill, pull all prompt blocks from remote
        asynchronously relative to engine execution.

        Args:
            request (Request): the request object.
            num_computed_tokens (int): the number of locally
                computed tokens for this request
        Returns:
            * the number of tokens that can be loaded from the
              external KV cache beyond what is already computed.
            * true if the external KV cache tokens will be loaded
              asynchronously (between scheduler steps).
        """

        # No KVTransfer for this request.
        if request.kv_transfer_params is None:
            return 0, False
        assert isinstance(request.kv_transfer_params, NixlKVTransferParams)

        # Remote prefill: get all prompt blocks from remote.
        if request.kv_transfer_params.do_remote_prefill:
            assert num_computed_tokens % self.block_size == 0
            rounded_num_prompt_tokens = round_down(
                len(request.prompt_token_ids), self.block_size)
            count = max(rounded_num_prompt_tokens - num_computed_tokens, 0)
            return count, count > 0

        return 0, False

    def update_state_after_alloc(self, request: "Request",
                                 blocks: "KVCacheBlocks",
                                 num_external_tokens: int):
        if request.kv_transfer_params is None:
            return

        assert isinstance(request.kv_transfer_params, NixlKVTransferParams)
        if request.kv_transfer_params.do_remote_prefill:
            # NOTE(rob): if prompt < block_size, no remote blocks
            # since the remote only sends fully computed blocks, so
            # skip recving for this request. num_external_tokens
            # should be 0 if there are no remote blocks.
            if request.kv_transfer_params.remote_block_ids:
                # Get unhashed blocks to pull from remote.
                self._reqs_need_recv[request.request_id] = (
                    request, blocks.get_unhashed_block_ids())
            else:
                assert num_external_tokens == 0
            # Only trigger 1 KV transfer per request.
            request.kv_transfer_params.do_remote_prefill = False

    def build_connector_meta(
        self,
        scheduler_output: SchedulerOutput,
    ) -> KVConnectorMetadata:
        meta = NixlConnectorMetadata()

        # Loop through scheduled reqs and convert to ReqMeta.
        for req_id, (req, block_ids) in self._reqs_need_recv.items():
            assert isinstance(req.kv_transfer_params, NixlKVTransferParams)
            meta.add_new_req(
                request_id=req_id,
                local_block_ids=block_ids,
                kv_transfer_params=req.kv_transfer_params,
            )

        # Clear the list once workers start the transfers
        self._reqs_need_recv.clear()

        return meta

    def request_finished(
        self,
        request: "Request",
        block_ids: list[int],
    ) -> tuple[bool, Optional[dict[str, Any]]]:
        """
        Once a request is finished, determine whether request blocks
        should be freed now or will be sent asynchronously and freed later.
        """

        if request.kv_transfer_params is None:
            return False, None
        assert isinstance(request.kv_transfer_params, NixlKVTransferParams)

        if ((not request.kv_transfer_params.do_remote_decode)
                or (request.status != RequestStatus.FINISHED_LENGTH_CAPPED)):
            return False, None

        # Get computed blocks.
        all_full = request.num_computed_tokens % self.block_size == 0
        computed_block_ids = (block_ids if all_full else block_ids[:-1])

        # If prompt < block_size, no xfer so free blocks immediately.
        delay_free_blocks = len(computed_block_ids) > 0

        return delay_free_blocks, NixlKVTransferParams(
            do_remote_prefill=True,
            do_remote_decode=False,
            remote_block_ids=computed_block_ids,
            remote_engine_id=self.engine_id,
            remote_host=envs.VLLM_NIXL_SIDE_CHANNEL_HOST,
            remote_port=envs.VLLM_NIXL_SIDE_CHANNEL_PORT,
        ).__dict__
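For reference, the raw dict that from_raw_dict above validates (and that request_finished returns via __dict__) carries these six keys; the concrete values below are illustrative only:

# Hedged example payload; host, port, block ids and engine id are made-up values.
raw_kv_transfer_params = {
    "do_remote_prefill": True,
    "do_remote_decode": False,
    "remote_block_ids": [0, 1, 2],
    "remote_engine_id": "prefill-engine-0",
    "remote_host": "10.0.0.1",
    "remote_port": 5557,
}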
|
||||||
|
|
||||||
|
|
||||||
|
class NixlConnectorWorker:
|
||||||
|
"""Implementation of Worker side methods"""
|
||||||
|
|
||||||
|
def __init__(self, engine_id: str):
|
||||||
|
if NixlWrapper is None:
|
||||||
|
logger.error("NIXL is not available")
|
||||||
|
raise RuntimeError("NIXL is not available")
|
||||||
|
logger.info("Initializing NIXL wrapper")
|
||||||
|
logger.info("Initializing NIXL worker %s", engine_id)
|
||||||
|
|
||||||
|
# Agent.
|
||||||
|
self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None)
|
||||||
|
# Map of engine_id -> agent_name.
|
||||||
|
self._remote_agents: dict[str, str] = {}
|
||||||
|
|
||||||
|
# Metadata.
|
||||||
|
self.engine_id = engine_id
|
||||||
|
self.rank = get_tensor_model_parallel_rank()
|
||||||
|
self.world_size = get_tensor_model_parallel_world_size()
|
||||||
|
self.tp_group = get_tp_group()
|
||||||
|
|
||||||
|
# KV Caches and nixl tracking data.
|
||||||
|
self.kv_caches: dict[str, torch.Tensor] = {}
|
||||||
|
|
||||||
|
# Map of engine_id -> kv_caches_base_addr
|
||||||
|
self.kv_caches_base_addr: dict[str, list[int]] = {}
|
||||||
|
|
||||||
|
# Number of NIXL regions. Currently one region per cache
|
||||||
|
# (so 1 per layer for MLA, otherwise 2 per layer)
|
||||||
|
self.num_regions = 0
|
||||||
|
|
||||||
|
# nixl_prepped_dlist_handle (int).
|
||||||
|
self.src_xfer_side_handle: int = 0
|
||||||
|
# Map of engine_id -> nixl_prepped_dlist_handle (int)].
|
||||||
|
self.dst_xfer_side_handles: dict[str, int] = {}
|
||||||
|
|
||||||
|
# Map of engine_id -> num_blocks.
|
||||||
|
self.dst_num_blocks: dict[str, int] = {}
|
||||||
|
self._registered_descs: list[Any] = []
|
||||||
|
|
||||||
|
# In progress transfers.
|
||||||
|
# [req_id -> list[handle]]
|
||||||
|
self._recving_transfers: defaultdict[str, list[Any]] = defaultdict(
|
||||||
|
list[Any])
|
||||||
|
|
||||||
|
# Complete transfer tracker. Used by the rank 0 to track finished
|
||||||
|
# transactions on ranks 1 to N-1.
|
||||||
|
# [req_id -> count]
|
||||||
|
self._done_recving_count: defaultdict[str,
|
||||||
|
int] = defaultdict(lambda: 0)
|
||||||
|
self._done_sending_count: defaultdict[str,
|
||||||
|
int] = defaultdict(lambda: 0)
|
||||||
|
|
||||||
|
# Background thread for establishing new connections.
|
||||||
|
self._nixl_handshake_listener_t: Optional[threading.Thread] = None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _nixl_handshake_listener(metadata: NixlAgentMetadata,
|
||||||
|
ready_event: threading.Event, rank: int):
|
||||||
|
"""Background thread for getting new NIXL handshakes."""
|
||||||
|
# NOTE(rob): this is a simple implementation. We will move
|
||||||
|
# to a better approach like an ETCD server in the future.
|
||||||
|
|
||||||
|
# NOTE(rob): to support heterogeneous TP, we will have to
|
||||||
|
# move this into the scheduler rather than worker, since
|
||||||
|
# each rank needs the metadata of all other ranks (whereas
|
||||||
|
# in this setup, each rank only gets one other rank's meta.
|
||||||
|
|
||||||
|
encoder = msgspec.msgpack.Encoder()
|
||||||
|
encoded_data = encoder.encode(metadata)
|
||||||
|
size_in_bytes = len(encoded_data)
|
||||||
|
logger.debug("Size of encoded NixlAgentMetadata: %s bytes",
|
||||||
|
str(size_in_bytes))
|
||||||
|
|
||||||
|
# Listen for new requests for metadata.
|
||||||
|
host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST
|
||||||
|
# NOTE(rob): we need each rank to have a unique port. This
|
||||||
|
# hack to keeps us moving. We will switch when moving to etcd
|
||||||
|
# or where we have a single ZMQ socket in the scheduler.
|
||||||
|
port = envs.VLLM_NIXL_SIDE_CHANNEL_PORT + rank
|
||||||
|
path = f"tcp://{host}:{port}"
|
||||||
|
logger.debug("Starting listening on path: %s", path)
|
||||||
|
with zmq_ctx(zmq.ROUTER, path) as sock:
|
||||||
|
ready_event.set()
|
||||||
|
while True:
|
||||||
|
identity, _, msg = sock.recv_multipart()
|
||||||
|
if msg != GET_META_MSG:
|
||||||
|
logger.warning(
|
||||||
|
"Connection listener got unexpected message %s", msg)
|
||||||
|
sock.send_multipart((identity, b"", encoded_data))
|
||||||
|
|
||||||
|
def _nixl_handshake(self, host: str, port: int):
|
||||||
|
"""Do a NIXL handshake with a remote instance."""
|
||||||
|
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
# NOTE(rob): we need each rank to have a unique port. This is
|
||||||
|
# a hack to keep us moving. We will switch when moving to etcd
|
||||||
|
# or where we have a single ZMQ socket in the scheduler.
|
||||||
|
path = f"tcp://{host}:{port + self.rank}"
|
||||||
|
logger.debug("Querying metadata on path: %s", path)
|
||||||
|
with zmq_ctx(zmq.REQ, path) as sock:
|
||||||
|
# Send query for the request.
|
||||||
|
sock.send(GET_META_MSG)
|
||||||
|
metadata_bytes = sock.recv()
|
||||||
|
decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
|
||||||
|
metadata = decoder.decode(metadata_bytes)
|
||||||
|
got_metadata_time = time.perf_counter()
|
||||||
|
|
||||||
|
# Register Remote agent.
|
||||||
|
self.add_remote_agent(metadata)
|
||||||
|
setup_agent_time = time.perf_counter()
|
||||||
|
|
||||||
|
logger.debug("NIXL handshake: get metadata took: %s",
|
||||||
|
got_metadata_time - start_time)
|
||||||
|
logger.debug("NIXL handshake: add agent took: %s",
|
||||||
|
setup_agent_time - got_metadata_time)
|
||||||
|
|
||||||
|
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||||
|
"""Register the KV Cache data in nixl."""
|
||||||
|
|
||||||
|
_, first_kv_cache = next(iter(kv_caches.items()))
|
||||||
|
kv_elem_size = first_kv_cache.element_size()
|
||||||
|
|
||||||
|
# TODO(tms): Find a more robust way to detect and handle MLA
|
||||||
|
use_mla = len(first_kv_cache.shape) == 3
|
||||||
|
if use_mla:
|
||||||
|
# MLA case.
|
||||||
|
self.num_blocks = first_kv_cache.shape[0]
|
||||||
|
block_rank = 2 # [block_size, latent_dim]
|
||||||
|
block_shape = first_kv_cache.shape[-block_rank:]
|
||||||
|
else:
|
||||||
|
# [2 (k and v), num_blocks, ...]
|
||||||
|
self.num_blocks = first_kv_cache.shape[1]
|
||||||
|
block_rank = 3 # [block_size, kv_heads, head_dim]
|
||||||
|
block_shape = first_kv_cache.shape[-block_rank:]
|
||||||
|
|
||||||
|
# TODO(tms): self.block_len needs to be per-layer for sliding window,
|
||||||
|
# hybrid attn, etc
|
||||||
|
self.block_len = kv_elem_size * math.prod(block_shape)
|
||||||
|
|
||||||
|
logger.debug("Registering KV_Caches. use_mla: %s, shape %s", use_mla,
|
||||||
|
first_kv_cache.shape)
|
||||||
|
logger.debug("num_blocks: %s, block_shape: %s", self.num_blocks,
|
||||||
|
block_shape)
|
||||||
|
logger.debug("Per layer kv cache size: %s", first_kv_cache.shape)
|
||||||
|
self.dst_num_blocks[self.engine_id] = self.num_blocks
|
||||||
|
self.kv_caches = kv_caches
|
||||||
|
kv_caches_base_addr = []
|
||||||
|
caches_data = []
|
||||||
|
|
||||||
|
# Note(tms): I modified this from the original region setup code.
|
||||||
|
# K and V are now in different regions. Advantage is that we can
|
||||||
|
# elegantly support MLA and any cases where the K and V tensors
|
||||||
|
# are non-contiguous (it's not locally guaranteed that they will be)
|
||||||
|
# Disadvantage is that the encoded NixlAgentMetadata is now larger
|
||||||
|
# (roughly 8KB vs 5KB).
|
||||||
|
for cache_or_caches in kv_caches.values():
|
||||||
|
# Normalize to always be a list of caches
|
||||||
|
cache_list = [cache_or_caches] if use_mla else cache_or_caches
|
||||||
|
for cache in cache_list:
|
||||||
|
base_addr = cache.data_ptr()
|
||||||
|
region_len = self.num_blocks * self.block_len
|
||||||
|
caches_data.append((base_addr, region_len, self.rank, ""))
|
||||||
|
kv_caches_base_addr.append(base_addr)
|
||||||
|
self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr
|
||||||
|
self.num_regions = len(caches_data)
|
||||||
|
|
||||||
|
descs = self.nixl_wrapper.get_reg_descs(caches_data, "VRAM")
|
||||||
|
logger.debug("Registering descs: %s", caches_data)
|
||||||
|
self.nixl_wrapper.register_memory(descs)
|
||||||
|
logger.debug("Done registering descs")
|
||||||
|
|
||||||
|
self._registered_descs.append(descs)
|
||||||
|
|
||||||
|
# After KV Caches registered, listen for new connections.
|
||||||
|
metadata = NixlAgentMetadata(
|
||||||
|
engine_id=self.engine_id,
|
||||||
|
agent_metadata=self.nixl_wrapper.get_agent_metadata(),
|
||||||
|
kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
|
||||||
|
num_blocks=self.num_blocks,
|
||||||
|
)
|
||||||
|
ready_event = threading.Event()
|
||||||
|
self._nixl_handshake_listener_t = threading.Thread(
|
||||||
|
target=self._nixl_handshake_listener,
|
||||||
|
args=(metadata, ready_event, self.rank),
|
||||||
|
daemon=True,
|
||||||
|
name="nixl_handshake_listener")
|
||||||
|
self._nixl_handshake_listener_t.start()
|
||||||
|
ready_event.wait()
|
||||||
|
|
||||||
|
def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata):
|
||||||
|
engine_id = nixl_agent_meta.engine_id
|
||||||
|
if engine_id in self._remote_agents:
|
||||||
|
return
|
||||||
|
|
||||||
|
self._remote_agents[engine_id] = self.nixl_wrapper.add_remote_agent(
|
||||||
|
nixl_agent_meta.agent_metadata)
|
||||||
|
self.kv_caches_base_addr[
|
||||||
|
engine_id] = nixl_agent_meta.kv_caches_base_addr
|
||||||
|
|
||||||
|
# Create src descs and xfer side handles.
|
||||||
|
blocks_data = []
|
||||||
|
for base_addr in self.kv_caches_base_addr[self.engine_id]:
|
||||||
|
for block_id in range(self.num_blocks):
|
||||||
|
block_offset = block_id * self.block_len
|
||||||
|
# (addr, len, device id)
|
||||||
|
blocks_data.append(
|
||||||
|
(base_addr + block_offset, self.block_len, self.rank))
|
||||||
|
logger.debug("Created %s blocks for src engine %s and rank %s",
|
||||||
|
len(blocks_data), self.engine_id, self.rank)
|
||||||
|
|
||||||
|
# Register with NIXL.
|
||||||
|
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
|
||||||
|
self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist(
|
||||||
|
"NIXL_INIT_AGENT", descs)
|
||||||
|
|
||||||
|
# Create dst descs and xfer side handles.
|
||||||
|
self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks
|
||||||
|
blocks_data = []
|
||||||
|
for base_addr in self.kv_caches_base_addr[engine_id]:
|
||||||
|
for block_id in range(nixl_agent_meta.num_blocks):
|
||||||
|
block_offset = block_id * self.block_len
|
||||||
|
# (addr, len, device id)
|
||||||
|
blocks_data.append(
|
||||||
|
(base_addr + block_offset, self.block_len, self.rank))
|
||||||
|
logger.debug("Created %s blocks for dst engine %s and rank %s",
|
||||||
|
len(blocks_data), engine_id, self.rank)
|
||||||
|
|
||||||
|
# Register with NIXL.
|
||||||
|
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
|
||||||
|
self.dst_xfer_side_handles[
|
||||||
|
engine_id] = self.nixl_wrapper.prep_xfer_dlist(
|
||||||
|
self._remote_agents[engine_id], descs)
|
||||||
|
|
||||||
|
def get_finished(self) -> tuple[set[str], set[str]]:
|
||||||
|
"""
|
||||||
|
Get requests that are done sending or recving.
|
||||||
|
|
||||||
|
In TP>1 setup, each rank exchanges KVs with its counterpart
|
||||||
|
ranks independently. get_finished() runs in a worker creates
|
||||||
|
the done_sending and done_recving sets that are sent to the
|
||||||
|
scheduler via ModelRunnerOutput by Rank 0. To ensure trnxs
|
||||||
|
are done before adding to finished, Ranks 1 to N-1 communicate
|
||||||
|
to Rank 0 once their transaction is done + Rank 0 returns
|
||||||
|
finished sets to Scheduler only once all ranks are done.
|
||||||
|
"""
|
||||||
|
done_sending = self._get_new_notifs()
|
||||||
|
done_recving = self._pop_done_transfers(self._recving_transfers)
|
||||||
|
if len(done_sending) > 0 or len(done_recving) > 0:
|
||||||
|
logger.debug(
|
||||||
|
"Rank %s, get_finished: %s requests done sending "
|
||||||
|
"and %s requests done recving", self.rank, len(done_sending),
|
||||||
|
len(done_recving))
|
||||||
|
|
||||||
|
if self.world_size == 1:
|
||||||
|
return done_sending, done_recving
|
||||||
|
|
||||||
|
# Rank 0: get finished from all other ranks.
|
||||||
|
if self.rank == 0:
|
||||||
|
for req_id in done_sending:
|
||||||
|
self._done_sending_count[req_id] += 1
|
||||||
|
for req_id in done_recving:
|
||||||
|
self._done_recving_count[req_id] += 1
|
||||||
|
|
||||||
|
# Keep track of how many other ranks have finished.
|
||||||
|
other_ranks_finished_ids: list[str] = []
|
||||||
|
for i in range(1, self.world_size):
|
||||||
|
other_ranks_finished_ids.extend(
|
||||||
|
self.tp_group.recv_object(src=i))
|
||||||
|
for req_id in other_ranks_finished_ids:
|
||||||
|
if (req_id in self._done_recving_count
|
||||||
|
or req_id in self._recving_transfers):
|
||||||
|
self._done_recving_count[req_id] += 1
|
||||||
|
else:
|
||||||
|
self._done_sending_count[req_id] += 1
|
||||||
|
|
||||||
|
# Return ids that finished on all ranks to the scheduler.
|
||||||
|
all_done_recving: set[str] = set()
|
||||||
|
for req_id in list(self._done_recving_count.keys()):
|
||||||
|
if self._done_recving_count[req_id] == self.world_size:
|
||||||
|
del self._done_recving_count[req_id]
|
||||||
|
all_done_recving.add(req_id)
|
||||||
|
|
||||||
|
all_done_sending: set[str] = set()
|
||||||
|
for req_id in list(self._done_sending_count.keys()):
|
||||||
|
if self._done_sending_count[req_id] == self.world_size:
|
||||||
|
del self._done_sending_count[req_id]
|
||||||
|
all_done_sending.add(req_id)
|
||||||
|
|
||||||
|
return all_done_sending, all_done_recving
|
||||||
|
|
||||||
|
# Ranks 1 to N-1: send finished ids to Rank 0.
|
||||||
|
else:
|
||||||
|
finished_req_ids = list(done_recving.union(done_sending))
|
||||||
|
self.tp_group.send_object(finished_req_ids, dst=0)
|
||||||
|
|
||||||
|
# Unused as only Rank 0 results are sent to scheduler.
|
||||||
|
return done_sending, done_recving
|
||||||
|
|
||||||
|
def _get_new_notifs(self) -> set[str]:
|
||||||
|
"""Get req_ids which got a remote xfer message."""
|
||||||
|
|
||||||
|
notified_req_ids: set[str] = set()
|
||||||
|
for req_ids in self.nixl_wrapper.get_new_notifs().values():
|
||||||
|
for req_id in req_ids:
|
||||||
|
assert req_id not in notified_req_ids
|
||||||
|
notified_req_ids.add(req_id.decode("utf-8"))
|
||||||
|
return notified_req_ids
|
||||||
|
|
||||||
|
def _pop_done_transfers(self, transfers: dict[str, list[int]]) -> set[str]:
|
||||||
|
"""
|
||||||
|
Pop completed xfers by checking for DONE state.
|
||||||
|
Args:
|
||||||
|
transfers: dict of req_id -> list[running_xfer]
|
||||||
|
Returns:
|
||||||
|
set of req_ids that have all done xfers
|
||||||
|
"""
|
||||||
|
done_req_ids: set[str] = set()
|
||||||
|
for req_id, handles in list(transfers.items()):
|
||||||
|
running_reqs = []
|
||||||
|
for handle in handles:
|
||||||
|
xfer_state = self.nixl_wrapper.check_xfer_state(handle)
|
||||||
|
if xfer_state == "DONE":
|
||||||
|
# TODO ptarasiewicz: why abort is throwing errors?
|
||||||
|
# self.nixl_wrapper.release_xfer_handle(handle)
|
||||||
|
continue
|
||||||
|
if xfer_state == "PROC":
|
||||||
|
running_reqs.append(handle)
|
||||||
|
else:
|
||||||
|
raise RuntimeError("Transfer failed with state %s",
|
||||||
|
xfer_state)
|
||||||
|
if len(running_reqs) == 0:
|
||||||
|
done_req_ids.add(req_id)
|
||||||
|
del transfers[req_id]
|
||||||
|
else:
|
||||||
|
transfers[req_id] = running_reqs
|
||||||
|
return done_req_ids
|
||||||
|
|
||||||
|
def start_load_kv(self, metadata: NixlConnectorMetadata):
|
||||||
|
"""
|
||||||
|
Start loading by triggering non-blocking nixl_xfer.
|
||||||
|
We check for these trnxs to complete in each step().
|
||||||
|
"""
|
||||||
|
for req_id, meta in metadata.requests.items():
|
||||||
|
logger.debug(
|
||||||
|
"start_load_kv for request %s from remote engine %s. "
|
||||||
|
"Num local_block_ids: %s. Num remote_block_ids: %s. ", req_id,
|
||||||
|
meta.remote_engine_id, len(meta.local_block_ids),
|
||||||
|
len(meta.remote_block_ids))
|
||||||
|
self._read_blocks(
|
||||||
|
request_id=req_id,
|
||||||
|
dst_engine_id=meta.remote_engine_id,
|
||||||
|
local_block_ids=meta.local_block_ids,
|
||||||
|
remote_block_ids=meta.remote_block_ids,
|
||||||
|
remote_host=meta.remote_host,
|
||||||
|
remote_port=meta.remote_port,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _read_blocks(
|
||||||
|
self,
|
||||||
|
local_block_ids: list[int],
|
||||||
|
remote_block_ids: list[int],
|
||||||
|
remote_host: str,
|
||||||
|
remote_port: int,
|
||||||
|
dst_engine_id: str,
|
||||||
|
request_id: str,
|
||||||
|
):
|
||||||
|
# NOTE(rob): this takes ~2s. We need to get this off the hotpath.
|
||||||
|
if dst_engine_id not in self._remote_agents:
|
||||||
|
self._nixl_handshake(remote_host, remote_port)
|
||||||
|
|
||||||
|
# NOTE(rob): having the staging blocks be on the READER side is
|
||||||
|
# not going to work well (since we will have to call rearrange tensors).
|
||||||
|
# after we detect the txn is complete (which means we cannot make the
|
||||||
|
# read trxn async easily). If we want to make "READ" happen cleanly,
|
||||||
|
# then we will need to have the staging blocks on the remote side.
|
||||||
|
|
||||||
|
# NOTE(rob): according to nvidia the staging blocks are used to
|
||||||
|
# saturate IB with heterogeneous TP sizes. We should remove the staging
|
||||||
|
# blocks until we are ready.
|
||||||
|
|
||||||
|
# Full prefix cache hit: do not need to read remote blocks,
|
||||||
|
# just notify P worker that we have the blocks we need.
|
||||||
|
num_local_blocks = len(local_block_ids)
|
||||||
|
if num_local_blocks == 0:
|
||||||
|
self.nixl_wrapper.send_notif(dst_engine_id,
|
||||||
|
notif_msg=request_id.encode("utf-8"))
|
||||||
|
return
|
||||||
|
|
||||||
|
# Partial prefix cache hit: just read uncomputed blocks.
|
||||||
|
num_remote_blocks = len(remote_block_ids)
|
||||||
|
assert num_local_blocks <= num_remote_blocks
|
||||||
|
if num_local_blocks < num_remote_blocks:
|
||||||
|
remote_block_ids = remote_block_ids[-num_local_blocks:]
|
||||||
|
|
||||||
|
# Get side handles.
|
||||||
|
local_xfer_side_handle = self.src_xfer_side_handle
|
||||||
|
remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id]
|
||||||
|
|
||||||
|
# Get descs ids.
|
||||||
|
remote_block_descs_ids = self._get_block_descs_ids(
|
||||||
|
dst_engine_id, remote_block_ids)
|
||||||
|
local_block_descs_ids = self._get_block_descs_ids(
|
||||||
|
self.engine_id, local_block_ids)
|
||||||
|
assert len(local_block_descs_ids) == len(remote_block_descs_ids)
|
||||||
|
|
||||||
|
# Prepare transfer with Nixl.
|
||||||
|
handle = self.nixl_wrapper.make_prepped_xfer(
|
||||||
|
"READ",
|
||||||
|
local_xfer_side_handle,
|
||||||
|
local_block_descs_ids,
|
||||||
|
remote_xfer_side_handle,
|
||||||
|
remote_block_descs_ids,
|
||||||
|
notif_msg=request_id.encode("utf-8"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Begin async xfer.
|
||||||
|
self.nixl_wrapper.transfer(handle)
|
||||||
|
|
||||||
|
# Use handle to check completion in future step().
|
||||||
|
self._recving_transfers[request_id].append(handle)
|
||||||
|
|
||||||
|
def _get_block_descs_ids(self, engine_id: str,
|
||||||
|
block_ids: list[int]) -> list[int]:
|
||||||
|
"""Get the descs ids for a set of block ids."""
|
||||||
|
|
||||||
|
# range(1) for MLA, range(2) otherwise.
|
||||||
|
region_ids = range(self.num_regions)
|
||||||
|
num_blocks = self.dst_num_blocks[engine_id]
|
||||||
|
|
||||||
|
# Compute the desc ids for each block.
|
||||||
|
descs_ids: list[int] = []
|
||||||
|
for reg_id in region_ids:
|
||||||
|
for block_id in block_ids:
|
||||||
|
descs_ids.append(reg_id * num_blocks + block_id)
|
||||||
|
return descs_ids
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]:
|
||||||
|
"""Context manager for a ZMQ socket"""
|
||||||
|
|
||||||
|
ctx: Optional[zmq.Context] = None
|
||||||
|
try:
|
||||||
|
ctx = zmq.Context() # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
if socket_type == zmq.ROUTER:
|
||||||
|
socket = ctx.socket(zmq.ROUTER)
|
||||||
|
socket.bind(addr)
|
||||||
|
elif socket_type == zmq.REQ:
|
||||||
|
socket = ctx.socket(zmq.REQ)
|
||||||
|
socket.connect(addr)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unexpected socket type: {socket_type}")
|
||||||
|
|
||||||
|
yield socket
|
||||||
|
finally:
|
||||||
|
if ctx is not None:
|
||||||
|
ctx.destroy(linger=0)
|
||||||
@@ -17,6 +17,7 @@ from vllm.v1.core.sched.output import SchedulerOutput
 if TYPE_CHECKING:
     from vllm.attention.backends.abstract import AttentionMetadata
     from vllm.forward_context import ForwardContext
+    from vllm.v1.core.kv_cache_manager import KVCacheBlocks
     from vllm.v1.request import Request
 
 logger = init_logger(__name__)
@@ -132,8 +133,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
             dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)
 
         # Get the metadata
-        metadata: KVConnectorMetadata = \
-            self._get_connector_metadata()
+        metadata: KVConnectorMetadata = self._get_connector_metadata()
         assert isinstance(metadata, SharedStorageConnectorMetadata)
 
         if metadata is None:
@@ -225,7 +225,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
         self,
         request: "Request",
         num_computed_tokens: int,
-    ) -> int:
+    ) -> tuple[int, bool]:
         """
         Get number of new tokens that can be loaded from the
         external KV cache beyond the num_computed_tokens.
@@ -239,7 +239,6 @@ class SharedStorageConnector(KVConnectorBase_V1):
             the number of tokens that can be loaded from the
             external KV cache beyond what is already computed.
         """
-
         # NOTE: in this debug implementation, we assume that the prompt is
         # cached_prompt + newly_generated_single_token
         # Therefore, we use prompt_token_ids[:-1] to determine the folder name
@@ -248,7 +247,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
         # with the block granularity. And it expects the returned blocks and
         # num_computed_tokens to also be aligned with the block granularity.
         if not self._found_match_for_request(request):
-            return 0
+            return 0, False
 
         logger.info("External Cache Hit!")
 
@@ -257,9 +256,10 @@ class SharedStorageConnector(KVConnectorBase_V1):
         num_tokens_to_check = align_to_block_size(
             len(request.prompt_token_ids) - 1, self._block_size)
 
-        return num_tokens_to_check - num_computed_tokens
+        return num_tokens_to_check - num_computed_tokens, False
 
     def update_state_after_alloc(self, request: "Request",
+                                 blocks: "KVCacheBlocks",
                                  num_external_tokens: int):
         """
         Update KVConnector state after block allocation.
@@ -403,6 +403,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
             "access by 3rd parties, and long enough to be "
             "unpredictable (e.g., 43 characters base64-encoded, corresponding "
             "to 256 bit). Not supported by vLLM engine V0."))
+    kv_transfer_params: Optional[dict[str, Any]] = Field(
+        default=None,
+        description="KVTransfer parameters used for disaggregated serving.")

     # doc: end-chat-completion-extra-params

@@ -540,7 +543,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
             output_kind=RequestOutputKind.DELTA if self.stream \
                 else RequestOutputKind.FINAL_ONLY,
             guided_decoding=guided_decoding,
-            logit_bias=self.logit_bias)
+            logit_bias=self.logit_bias,
+            extra_args=({"kv_transfer_params": self.kv_transfer_params}
+                        if self.kv_transfer_params else None))

     def _get_guided_json_from_tool(
             self) -> Optional[Union[str, dict, BaseModel]]:
@@ -848,6 +853,10 @@ class CompletionRequest(OpenAIBaseModel):
             " as strings of the form 'token_id:{token_id}' so that tokens "
             "that are not JSON-encodable can be identified."))

+    kv_transfer_params: Optional[dict[str, Any]] = Field(
+        default=None,
+        description="KVTransfer parameters used for disaggregated serving.")
+
     # doc: end-completion-extra-params

     # Default sampling parameters for completion requests
@@ -973,7 +982,9 @@ class CompletionRequest(OpenAIBaseModel):
                 else RequestOutputKind.FINAL_ONLY,
             guided_decoding=guided_decoding,
             logit_bias=self.logit_bias,
-            allowed_token_ids=self.allowed_token_ids)
+            allowed_token_ids=self.allowed_token_ids,
+            extra_args=({"kv_transfer_params": self.kv_transfer_params}
+                        if self.kv_transfer_params else None))

     @model_validator(mode="before")
     @classmethod
@@ -1223,6 +1234,8 @@ class CompletionResponse(OpenAIBaseModel):
     model: str
     choices: list[CompletionResponseChoice]
     usage: UsageInfo
+    kv_transfer_params: Optional[dict[str, Any]] = Field(
+        default=None, description="KVTransfer parameters.")


 class CompletionResponseStreamChoice(OpenAIBaseModel):
@@ -1412,6 +1425,8 @@ class ChatCompletionResponse(OpenAIBaseModel):
     choices: list[ChatCompletionResponseChoice]
     usage: UsageInfo
     prompt_logprobs: Optional[list[Optional[dict[int, Logprob]]]] = None
+    kv_transfer_params: Optional[dict[str, Any]] = Field(
+        default=None, description="KVTransfer parameters.")


 class DeltaMessage(OpenAIBaseModel):
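Note (illustrative sketch, not part of the commit): with the new `kv_transfer_params` request field and the `extra_args` plumbing above, a client can attach connector-specific KV-transfer parameters to a completion request and read them back from the response. Endpoint, model name, and the dict keys below are placeholders; the actual keys are defined by the connector in use.

import json
import urllib.request

payload = {
    "model": "my-model",                      # placeholder model name
    "prompt": "The capital of France is",
    "max_tokens": 16,
    # Opaque, connector-defined dict; these keys are hypothetical examples.
    "kv_transfer_params": {"do_remote_prefill": True},
}
req = urllib.request.Request(
    "http://localhost:8000/v1/completions",   # placeholder endpoint
    data=json.dumps(payload).encode(),
    headers={"Content-Type": "application/json"})
with urllib.request.urlopen(req) as resp:
    body = json.load(resp)
# CompletionResponse now carries kv_transfer_params back to the caller.
print(body.get("kv_transfer_params"))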
@@ -1086,6 +1086,7 @@ class OpenAIServingChat(OpenAIServing):
             choices=choices,
             usage=usage,
             prompt_logprobs=clamp_prompt_logprobs(final_res.prompt_logprobs),
+            kv_transfer_params=final_res.kv_transfer_params,
         )

         return response
@@ -482,7 +482,7 @@ class OpenAIServingCompletion(OpenAIServing):
             model=model_name,
             choices=choices,
             usage=usage,
-        )
+            kv_transfer_params=final_res_batch[0].kv_transfer_params)

     def _create_completion_logprobs(
         self,
vllm/envs.py
@@ -112,6 +112,8 @@ if TYPE_CHECKING:
     VLLM_XGRAMMAR_CACHE_MB: int = 0
     VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
     VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False
+    VLLM_NIXL_SIDE_CHANNEL_HOST: str = "localhost"
+    VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5557


 def get_default_cache_root():
@@ -747,6 +749,14 @@ environment_variables: dict[str, Callable[[], Any]] = {
     # insecure method and it is needed for some reason.
     "VLLM_ALLOW_INSECURE_SERIALIZATION":
     lambda: bool(int(os.getenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "0"))),
+
+    # IP address used for NIXL handshake between remote agents.
+    "VLLM_NIXL_SIDE_CHANNEL_HOST":
+    lambda: os.getenv("VLLM_NIXL_SIDE_CHANNEL_HOST", "localhost"),
+
+    # Port used for NIXL handshake between remote agents.
+    "VLLM_NIXL_SIDE_CHANNEL_PORT":
+    lambda: int(os.getenv("VLLM_NIXL_SIDE_CHANNEL_PORT", "5557")),
 }

 # end-env-vars-definition
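Note (illustrative sketch, not part of the commit): the two new NIXL side-channel settings are plain environment variables; each P/D instance would typically override the defaults (localhost:5557) before startup. The host and port below are placeholders.

import os

# Placeholders: point each instance's NIXL handshake side channel somewhere reachable.
os.environ.setdefault("VLLM_NIXL_SIDE_CHANNEL_HOST", "10.0.0.1")
os.environ.setdefault("VLLM_NIXL_SIDE_CHANNEL_PORT", "5600")

# Mirrors the lookup lambdas registered above.
host = os.getenv("VLLM_NIXL_SIDE_CHANNEL_HOST", "localhost")
port = int(os.getenv("VLLM_NIXL_SIDE_CHANNEL_PORT", "5557"))
print(f"NIXL side channel at {host}:{port}")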
@@ -11,10 +11,6 @@ import torch.distributed as dist

 import vllm.envs as envs
 from vllm.config import VllmConfig
-from vllm.distributed.kv_transfer import (get_kv_transfer_group,
-                                          has_kv_transfer_group,
-                                          is_v1_kv_transfer_group)
-from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
 from vllm.logger import init_logger

 if TYPE_CHECKING:
@@ -106,16 +102,6 @@ def set_forward_context(attn_metadata: Any,
         attn_metadata=attn_metadata,
         dp_metadata=dp_metadata)

-    # KVConnector: trigger (possibly async) load before forward.
-    # Each attn layer will block until the reading is complete.
-    trigger_kv_transfer = (attn_metadata is not None
-                           and has_kv_transfer_group()
-                           and is_v1_kv_transfer_group())
-    if trigger_kv_transfer:
-        kv_connector = get_kv_transfer_group()
-        assert isinstance(kv_connector, KVConnectorBase_V1)
-        kv_connector.start_load_kv(_forward_context)
-
     try:
         yield
     finally:
@@ -152,11 +138,4 @@ def set_forward_context(attn_metadata: Any,
                     "(batchsize, count, median_time(ms)): %s"),
                     forward_stats)

-    # KVConnector: each attn layer triggers (possibly async) save.
-    # Ensure all those operations complete before forward() is done.
-    if trigger_kv_transfer:
-        kv_connector = get_kv_transfer_group()
-        assert isinstance(kv_connector, KVConnectorBase_V1)
-        kv_connector.wait_for_save()
-
     _forward_context = prev_context
@@ -4,7 +4,7 @@ import time
 from collections.abc import MutableSequence
 from collections.abc import Sequence as GenericSequence
 from dataclasses import dataclass
-from typing import Generic, Optional, Union
+from typing import Any, Generic, Optional, Union

 import torch
 from typing_extensions import TypeVar, deprecated
@@ -103,6 +103,7 @@ class RequestOutput:
         encoder_prompt_token_ids: The token IDs of the encoder prompt.
                                   None if decoder-only.
         num_cached_tokens: The number of tokens with prefix cache hit.
+        kv_transfer_params: The params for remote K/V transfer.
     """

     def __init__(
@@ -120,6 +121,7 @@ class RequestOutput:
         num_cached_tokens: Optional[int] = None,
         *,
         multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None,
+        kv_transfer_params: Optional[dict[str, Any]] = None,
     ) -> None:
         self.request_id = request_id
         self.prompt = prompt
@@ -133,11 +135,13 @@ class RequestOutput:
         self.encoder_prompt = encoder_prompt
         self.encoder_prompt_token_ids = encoder_prompt_token_ids
         self.num_cached_tokens = num_cached_tokens
+        self.kv_transfer_params = kv_transfer_params

     def add(self, next_output: "RequestOutput", aggregate: bool) -> None:
         """Merge subsequent RequestOutput into this one"""

         self.finished |= next_output.finished
+        self.kv_transfer_params = next_output.kv_transfer_params

         for next_completion in next_output.outputs:
             for i, completion in enumerate(self.outputs):
@@ -36,6 +36,12 @@ class KVCacheBlocks:
         """Converts the KVCacheBlocks instance to a list of block IDs."""
         return [block.block_id for block in self.blocks]

+    def get_unhashed_block_ids(self) -> list[int]:
+        """Get block_ids of unhashed blocks from KVCacheBlocks instance."""
+        return [
+            block.block_id for block in self.blocks if block.block_hash is None
+        ]
+

 class KVCacheManager:

@@ -116,6 +122,12 @@ class KVCacheManager:
             - The number of computed tokens.
         """

+        # Request already has blocks from async load via KVConnector.
+        num_existing_blocks = len(
+            self.single_type_manager.req_to_blocks[request.request_id])
+        if num_existing_blocks > 0:
+            return KVCacheBlocks.create_empty(), request.num_computed_tokens
+
         # Prefix caching is disabled or
         # When the request requires prompt logprobs, we skip prefix caching.
         if (not self.enable_caching
@@ -173,6 +185,7 @@ class KVCacheManager:
         num_new_tokens: int,
         new_computed_blocks: Optional[KVCacheBlocks] = None,
         num_lookahead_tokens: int = 0,
+        delay_cache_blocks: bool = False,
     ) -> Optional[KVCacheBlocks]:
         """Add slots for a request with new tokens to append.

@@ -186,6 +199,9 @@ class KVCacheManager:
             num_lookahead_tokens: The number of speculative tokens to allocate.
                 This is used by spec decode proposers with kv-cache such
                 as eagle.
+            delay_cache_blocks: Whether to skip caching the blocks. This is
+                used by P/D when allocating blocks used in a KV transfer
+                which will complete in a future step.

         Blocks layout:
         ```
@@ -255,7 +271,9 @@ class KVCacheManager:
         new_blocks = self.single_type_manager.allocate_new_blocks(
             request.request_id, num_tokens_need_slot)

-        if not self.enable_caching:
+        # P/D: delay caching blocks if we have to recv from
+        # remote. Update state for locally cached blocks.
+        if not self.enable_caching or delay_cache_blocks:
             return KVCacheBlocks(new_blocks)

         # Speculated tokens might be rejected in the future, so we does
@@ -350,3 +368,16 @@ class KVCacheManager:
             A list of KV cache events.
         """
         return self.block_pool.take_events()
+
+    def get_block_ids(self, request_id: str) -> list[int]:
+        """Get the block ids of a request."""
+        assert request_id in self.single_type_manager.req_to_blocks
+        return [
+            block.block_id
+            for block in self.single_type_manager.req_to_blocks[request_id]
+        ]
+
+    def get_num_blocks(self, request_id: str):
+        """Get the number of blocks."""
+        assert request_id in self.single_type_manager.req_to_blocks
+        return len(self.single_type_manager.req_to_blocks[request_id])
@@ -4,6 +4,7 @@ from collections.abc import Iterable
 from typing import TYPE_CHECKING, Optional, Union

 if TYPE_CHECKING:
+    from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
     from vllm.v1.core.sched.output import SchedulerOutput
     from vllm.v1.engine import EngineCoreOutputs
     from vllm.v1.metrics.stats import SchedulerStats
@@ -137,3 +138,6 @@ class SchedulerInterface(ABC):
     def shutdown(self) -> None:
         """Shutdown the scheduler."""
         raise NotImplementedError
+
+    def get_kv_connector(self) -> Optional["KVConnectorBase_V1"]:
+        return None
@@ -5,13 +5,15 @@ from __future__ import annotations
 import time
 from collections import defaultdict, deque
 from collections.abc import Iterable
-from typing import Optional, Union
+from typing import Any, Optional, Union

 from vllm.config import VllmConfig
 from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch
 from vllm.distributed.kv_transfer.kv_connector.factory import (
     KVConnectorFactory)
-from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole
+from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1,
+                                                          KVConnectorRole,
+                                                          KVTransferParams)
 from vllm.logger import init_logger
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
 from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
@@ -96,6 +98,9 @@ class Scheduler(SchedulerInterface):
         # This is flushed at the end of each scheduling step.
         self.finished_req_ids: set[str] = set()

+        # P/D: requests in process of recving KV transfers
+        self.finished_recving_kv_req_ids: set[str] = set()
+
         # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
         # them at each scheduling step.
         # Request id -> deque of CachedRequestData
@@ -307,6 +312,16 @@ class Scheduler(SchedulerInterface):

             request = self.waiting[0]

+            # P/D: skip request if still waiting for remote kvs.
+            if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
+                is_ready = self._update_waiting_for_remote_kv(request)
+                if is_ready:
+                    request.status = RequestStatus.WAITING
+                else:
+                    self.waiting.popleft()
+                    skipped_waiting_requests.appendleft(request)
+                    continue
+
             # Skip request if the structured output request is still waiting
             # for FSM compilation.
             if request.status == RequestStatus.WAITING_FOR_FSM:
@@ -330,49 +345,55 @@ class Scheduler(SchedulerInterface):
                     continue

             # Get already-cached tokens.
-            computed_blocks, num_computed_tokens = \
+            new_computed_blocks, num_computed_tokens = \
                 self.kv_cache_manager.get_computed_blocks(
                     request)

             # Get externally-cached tokens if using a KVConnector.
-            num_external_tokens = (
-                0 if self.connector is None else
+            num_external_tokens, load_kv_async = (
+                (0, False) if self.connector is None else
                 self.connector.get_num_new_matched_tokens(
                     request, num_computed_tokens))

             # Total computed tokens (local + external).
             num_computed_tokens += num_external_tokens

-            # Number of tokens to be scheduled.
-            # We use `request.num_tokens` instead of
-            # `request.num_prompt_tokens` to consider the resumed requests,
-            # which have output tokens.
-            num_new_tokens = request.num_tokens - num_computed_tokens
-            if (0 < self.scheduler_config.long_prefill_token_threshold <
-                    num_new_tokens):
-                num_new_tokens = (
-                    self.scheduler_config.long_prefill_token_threshold)
-            num_new_tokens = min(num_new_tokens, token_budget)
-            assert num_new_tokens > 0
+            encoder_inputs_to_schedule = None
+            new_encoder_budget = encoder_budget

-            # Schedule encoder inputs.
-            if request.has_encoder_inputs:
-                (encoder_inputs_to_schedule, num_new_tokens,
-                 new_encoder_budget) = self._try_schedule_encoder_inputs(
-                     request, num_computed_tokens, num_new_tokens,
-                     encoder_budget)
-                if num_new_tokens == 0:
-                    # The request cannot be scheduled.
-                    break
+            # P/D: loading remote KV, do not allocate for new work.
+            if load_kv_async:
+                num_new_tokens = 0
+            # Number of tokens to be scheduled.
             else:
-                encoder_inputs_to_schedule = None
-                new_encoder_budget = encoder_budget
+                # We use `request.num_tokens` instead of
+                # `request.num_prompt_tokens` to consider the resumed
+                # requests, which have output tokens.
+                num_new_tokens = request.num_tokens - num_computed_tokens
+                if (0 < self.scheduler_config.long_prefill_token_threshold
+                        < num_new_tokens):
+                    num_new_tokens = (
+                        self.scheduler_config.long_prefill_token_threshold)
+                num_new_tokens = min(num_new_tokens, token_budget)
+                assert num_new_tokens > 0
+
+                # Schedule encoder inputs.
+                if request.has_encoder_inputs:
+                    (encoder_inputs_to_schedule, num_new_tokens,
+                     new_encoder_budget
+                     ) = self._try_schedule_encoder_inputs(
+                         request, num_computed_tokens, num_new_tokens,
+                         encoder_budget)
+                    if num_new_tokens == 0:
+                        # The request cannot be scheduled.
+                        break

             new_blocks = self.kv_cache_manager.allocate_slots(
                 request,
                 num_new_tokens + num_external_tokens,
-                computed_blocks,
+                new_computed_blocks,
                 num_lookahead_tokens=self.num_lookahead_tokens,
+                delay_cache_blocks=load_kv_async,
             )
             if new_blocks is None:
                 # The request cannot be scheduled.
@@ -384,10 +405,18 @@ class Scheduler(SchedulerInterface):
             if self.connector is not None:
                 self.connector.update_state_after_alloc(
                     request,
+                    new_computed_blocks + new_blocks,
                     num_external_tokens,
                 )

             self.waiting.popleft()
+            if load_kv_async:
+                # If loading async, allocate memory and put request
+                # into the WAITING_FOR_REMOTE_KV state.
+                skipped_waiting_requests.appendleft(request)
+                request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
+                continue
+
             if request.use_structured_output:
                 structured_output_request_ids[
                     request.request_id] = req_index
@@ -407,7 +436,7 @@ class Scheduler(SchedulerInterface):
             if self.lora_config and request.lora_request:
                 scheduled_loras.add(request.lora_request.lora_int_id)
             req_to_new_block_ids[request.request_id] = (
-                computed_blocks + new_blocks).get_block_ids()
+                self.kv_cache_manager.get_block_ids(request.request_id))
             num_scheduled_tokens[request.request_id] = num_new_tokens
             token_budget -= num_new_tokens
             request.status = RequestStatus.RUNNING
@@ -698,6 +727,7 @@ class Scheduler(SchedulerInterface):
             stopped = False
             new_logprobs = None
             new_token_ids = generated_token_ids
+            kv_transfer_params = None

             # Append generated tokens and check for stop. Note that if
             # a request is still being prefilled, we expect the model runner
@@ -709,7 +739,7 @@ class Scheduler(SchedulerInterface):
                 # This must be called before we make the EngineCoreOutput.
                 stopped = check_stop(request, self.max_model_len)
                 if stopped:
-                    self._free_request(request)
+                    kv_transfer_params = self._free_request(request)
                     del new_token_ids[num_new:]  # Trim new tokens if needed.
                     break

@@ -739,7 +769,8 @@ class Scheduler(SchedulerInterface):

             # Get prompt logprobs for this request.
             prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
-            if new_token_ids:
+            if new_token_ids or kv_transfer_params:
+
                 # Add EngineCoreOutput for this Request.
                 outputs.append(
                     EngineCoreOutput(
@@ -749,7 +780,10 @@ class Scheduler(SchedulerInterface):
                         new_logprobs=new_logprobs,
                         new_prompt_logprobs_tensors=prompt_logprobs_tensors,
                         stop_reason=request.stop_reason,
-                        events=request.take_events()))
+                        events=request.take_events(),
+                        kv_transfer_params=kv_transfer_params,
+                    ))
+
             else:
                 # Invariant: EngineCore returns no partial prefill outputs.
                 assert not prompt_logprobs_tensors
@@ -757,6 +791,9 @@ class Scheduler(SchedulerInterface):
             if not stopped:
                 new_running.append(request)

+        # P/D: update state for finished KV Transfers.
+        self._update_from_kv_xfer_finished(model_runner_output)
+
         # Return the cached request data to the queue so they can be reused.
         for req_data in scheduler_output.scheduled_cached_reqs:
             # NOTE(rob): since we free stopped reqs above, adding stopped reqs
@@ -811,15 +848,27 @@ class Scheduler(SchedulerInterface):
                 request.status = finished_status
                 self._free_request(request)

-    def _free_request(self, request: Request) -> None:
+    def _free_request(self, request: Request) -> Optional[dict[str, Any]]:
+
         assert request.is_finished()
-        self.kv_cache_manager.free(request)
-        self.kv_cache_manager.free_block_hashes(request)
+        delay_free_blocks, kv_xfer_params = self._connector_finished(request)
         self.encoder_cache_manager.free(request)
         self._cached_reqs_data.pop(request.request_id, None)
-        del self.requests[request.request_id]
         self.finished_req_ids.add(request.request_id)

+        if not delay_free_blocks:
+            self._free_blocks(request)
+
+        return kv_xfer_params
+
+    def _free_blocks(self, request: Request):
+        assert request.is_finished()
+        assert request.request_id not in self._cached_reqs_data
+        self.kv_cache_manager.free(request)
+        self.kv_cache_manager.free_block_hashes(request)
+        del self.requests[request.request_id]
+
     def get_num_unfinished_requests(self) -> int:
         return len(self.waiting) + len(self.running)

@@ -863,3 +912,70 @@ class Scheduler(SchedulerInterface):
     def shutdown(self) -> None:
         if self.kv_event_publisher:
             self.kv_event_publisher.shutdown()
+
+    ########################################################################
+    # P/D Related Methods
+    ########################################################################
+
+    def get_kv_connector(self) -> Optional[KVConnectorBase_V1]:
+        return self.connector
+
+    def _connector_finished(
+            self, request: Request) -> tuple[bool, Optional[KVTransferParams]]:
+        """Invoke the KV connector request_finished() method if applicable."""
+        if self.connector is None:
+            return False, None
+        block_ids = self.kv_cache_manager.get_block_ids(request.request_id)
+        return self.connector.request_finished(request, block_ids)
+
+    def _update_waiting_for_remote_kv(self, request: Request) -> bool:
+        """
+        P/D: check if the request_id is finished_recving.
+
+        The finished_recving_kv_req_ids list is populated
+        on the previous steps()'s update_from_output based
+        on the worker side connector.
+
+        When the kv transfer is ready, we cache the blocks
+        and the request state will be moved back to WAITING from
+        WAITING_FOR_REMOTE_KV.
+        """
+        if request.request_id not in self.finished_recving_kv_req_ids:
+            return False
+
+        # Now that the blocks are ready, actually cache them.
+        block_ids = self.kv_cache_manager.get_block_ids(request.request_id)
+        num_computed_tokens = len(block_ids) * self.block_size
+        if num_computed_tokens == request.num_tokens:
+            num_computed_tokens -= 1
+        self.kv_cache_manager.single_type_manager.cache_blocks(
+            request,
+            self.kv_cache_manager.req_to_block_hashes[request.request_id],
+            num_computed_tokens,
+        )
+
+        # Update the request state for scheduling.
+        request.num_computed_tokens = num_computed_tokens
+
+        # Return that we are ready.
+        self.finished_recving_kv_req_ids.remove(request.request_id)
+        return True
+
+    def _update_from_kv_xfer_finished(self,
+                                      model_runner_output: ModelRunnerOutput):
+        """
+        P/D: update the scheduler state based on the output.
+
+        The Worker side connectors add finished_recving and
+        finished_sending reqs to the output.
+        * if finished_sending: free the blocks
+        # if finished_recving: add to state so we can
+            scheduler the request during the next step.
+        """
+        # P/D: update recv and send status from last step.
+        for req_id in (model_runner_output.finished_recving or ()):
+            logger.debug("Finished recving KV transfer for request %s", req_id)
+            self.finished_recving_kv_req_ids.add(req_id)
+        for req_id in (model_runner_output.finished_sending or ()):
+            logger.debug("Finished sending KV transfer for request %s", req_id)
+            self._free_blocks(self.requests[req_id])
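Note (illustrative sketch, not part of the commit): the block-to-token arithmetic in `_update_waiting_for_remote_kv` above derives the computed-token count from whole received blocks and backs off by one token when that would cover the entire request, so the last token is still produced locally. With toy numbers:

block_size = 16                    # toy value
received_block_ids = [0, 1, 2]     # blocks loaded from the remote prefill worker
num_request_tokens = 48            # request.num_tokens in this toy case

num_computed_tokens = len(received_block_ids) * block_size   # 48
if num_computed_tokens == num_request_tokens:
    num_computed_tokens -= 1       # leave one token for the local forward pass
print(num_computed_tokens)         # 47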
@@ -105,6 +105,7 @@ class EngineCoreOutput(
     finish_reason: Optional[FinishReason] = None
     stop_reason: Union[int, str, None] = None
     events: Optional[list[EngineCoreEvent]] = None
+    kv_transfer_params: Optional[dict[str, Any]] = None

     @property
     def finished(self) -> bool:
@@ -182,6 +182,15 @@ class EngineCore:
         # Start grammar compilation asynchronously
         self.structured_output_manager.grammar_init(req)

+        if req.raw_kv_transfer_params is not None:
+            if (kv_connector := self.scheduler.get_kv_connector()):
+                # Parse raw KV transfer params via connector.
+                kv_connector.set_kv_transfer_params(req)
+            else:
+                logger.warning(
+                    "Got KVTransferParams, but no KVConnector found. "
+                    "Disabling KVTransfer for this request.")
+
         self.scheduler.add_request(req)

     def abort_requests(self, request_ids: list[str]):
@@ -3,7 +3,7 @@
 import asyncio
 from collections.abc import Iterable
 from dataclasses import dataclass
-from typing import Optional, Union
+from typing import Any, Optional, Union

 from vllm.outputs import CompletionOutput, RequestOutput
 from vllm.sampling_params import RequestOutputKind
@@ -146,6 +146,7 @@ class RequestState:
         new_token_ids: list[int],
         finish_reason: Optional[FinishReason],
         stop_reason: Union[int, str, None],
+        kv_transfer_params: Optional[dict[str, Any]] = None,
     ) -> Optional[RequestOutput]:

         finished = finish_reason is not None
@@ -167,13 +168,15 @@ class RequestState:
         if not outputs:
             return None

-        return self._new_request_output(request_id, outputs, finished)
+        return self._new_request_output(request_id, outputs, finished,
+                                        kv_transfer_params)

     def _new_request_output(
         self,
         request_id: str,
         outputs: list[CompletionOutput],
         finished: bool,
+        kv_transfer_params: Optional[dict[str, Any]] = None,
     ) -> RequestOutput:

         if self.output_kind == RequestOutputKind.DELTA:
@@ -189,6 +192,7 @@ class RequestState:
             prompt_logprobs=prompt_logprobs,
             outputs=outputs,
             finished=finished,
+            kv_transfer_params=kv_transfer_params,
         )

     def _new_completion_output(
@@ -335,6 +339,7 @@ class OutputProcessor:
             new_token_ids = engine_core_output.new_token_ids
             finish_reason = engine_core_output.finish_reason
             stop_reason = engine_core_output.stop_reason
+            kv_transfer_params = engine_core_output.kv_transfer_params

             req_state.is_prefilling = False

@@ -350,7 +355,8 @@ class OutputProcessor:

             # 4) Create and handle RequestOutput objects.
             if request_output := req_state.make_request_output(
-                    new_token_ids, finish_reason, stop_reason):
+                    new_token_ids, finish_reason, stop_reason,
+                    kv_transfer_params):
                 if req_state.queue is not None:
                     # AsyncLLM: put into queue for handling by generate().
                     req_state.queue.put(request_output)
@@ -100,12 +100,16 @@ class ModelRunnerOutput:
     # [prompt_len]
     prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]]

+    # [req_ids]
+    finished_sending: Optional[set[str]] = None
+    finished_recving: Optional[set[str]] = None
+

-EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
-    req_ids=[],
-    req_id_to_index={},
-    sampled_token_ids=[],
-    spec_token_ids=None,
-    logprobs=None,
-    prompt_logprobs_dict={},
-)
+EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[],
+                                              req_id_to_index={},
+                                              sampled_token_ids=[],
+                                              spec_token_ids=None,
+                                              logprobs=None,
+                                              prompt_logprobs_dict={},
+                                              finished_sending=None,
+                                              finished_recving=None)
@@ -1,8 +1,9 @@
 # SPDX-License-Identifier: Apache-2.0

 import enum
-from typing import TYPE_CHECKING, Optional, Union
+from typing import TYPE_CHECKING, Any, Optional, Union

+from vllm.distributed.kv_transfer.kv_connector.v1 import KVTransferParams
 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
 from vllm.sampling_params import SamplingParams
 from vllm.utils import is_list_of
@@ -61,6 +62,15 @@ class Request:
         self.num_encoder_inputs = len(self.mm_inputs)
         self.has_encoder_inputs = self.num_encoder_inputs > 0

+        # P/D: KV transfer parameters (raw and parsed).
+        raw_params = (None if sampling_params.extra_args is None
+                      else sampling_params.extra_args.get(
+                          "kv_transfer_params", None))
+        self.raw_kv_transfer_params: Optional[dict[str, Any]] = raw_params
+        # Each connector parses the raw dictionary and sets this
+        # attr the first time that the request is processed.
+        self.kv_transfer_params: Optional[KVTransferParams] = None
+
         # Sanity check
         assert len(self.mm_inputs) == len(self.mm_positions)
         if self.mm_hashes:
@@ -150,6 +160,7 @@ class RequestStatus(enum.IntEnum):
     """Status of a request."""
     WAITING = enum.auto()
     WAITING_FOR_FSM = enum.auto()
+    WAITING_FOR_REMOTE_KVS = enum.auto()
     RUNNING = enum.auto()
     PREEMPTED = enum.auto()
     # Note: anything after PREEMPTED will be considered
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0

+import copy
 import gc
 import time
 import weakref
@@ -17,8 +18,9 @@ from vllm.config import (CompilationLevel, VllmConfig,
                          get_layers_from_vllm_config)
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group)
+from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
 from vllm.distributed.parallel_state import get_pp_group, graph_capture
-from vllm.forward_context import set_forward_context
+from vllm.forward_context import get_forward_context, set_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
 from vllm.model_executor.model_loader import get_model
@@ -1065,15 +1067,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         scheduler_output: "SchedulerOutput",
         intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> Union[ModelRunnerOutput, IntermediateTensors]:
-        # Update KVConnector with the KVConnector metadata forward().
-        if has_kv_transfer_group():
-            get_kv_transfer_group().bind_connector_metadata(
-                scheduler_output.kv_connector_metadata)
-
         self._update_states(scheduler_output)
         if not scheduler_output.total_num_scheduled_tokens:
-            # Return empty ModelRunnerOutput if there's no work to do.
-            return EMPTY_MODEL_RUNNER_OUTPUT
+            if not has_kv_transfer_group():
+                # Return empty ModelRunnerOutput if there's no work to do.
+                return EMPTY_MODEL_RUNNER_OUTPUT
+
+            return self.kv_connector_no_forward(scheduler_output)

         # Prepare the decoder inputs.
         attn_metadata, logits_indices, spec_decode_metadata = (
@@ -1150,17 +1151,23 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         with set_forward_context(attn_metadata,
                                  self.vllm_config,
                                  num_tokens=num_input_tokens):
-            output = self.model(
+            self.maybe_setup_kv_connector(scheduler_output)
+
+            model_output = self.model(
                 input_ids=input_ids,
                 positions=positions,
                 intermediate_tensors=intermediate_tensors,
                 inputs_embeds=inputs_embeds,
             )
+
+            self.maybe_wait_for_kv_save()
+            finished_sending, finished_recving = (
+                self.get_finished_kv_transfers(scheduler_output))

         if self.use_aux_hidden_state_outputs:
-            hidden_states, aux_hidden_states = output
+            hidden_states, aux_hidden_states = model_output
         else:
-            hidden_states = output
+            hidden_states = model_output

         if not get_pp_group().is_last_rank:
             # For mid-pipeline stages, return the hidden states.
@@ -1341,8 +1348,56 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             spec_token_ids=spec_token_ids,
             logprobs=logprobs_lists,
             prompt_logprobs_dict=prompt_logprobs_dict,
+            finished_sending=finished_sending,
+            finished_recving=finished_recving,
         )

+    def kv_connector_no_forward(
+            self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
+        # KV send/recv even if no work to do.
+        with set_forward_context(None, self.vllm_config):
+            self.maybe_setup_kv_connector(scheduler_output)
+            finished_sending, finished_recving = (
+                self.get_finished_kv_transfers(scheduler_output))
+
+        if not finished_sending and not finished_recving:
+            return EMPTY_MODEL_RUNNER_OUTPUT
+
+        output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
+        output.finished_sending = finished_sending
+        output.finished_recving = finished_recving
+        return output
+
+    @staticmethod
+    def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"):
+        # Update KVConnector with the KVConnector metadata forward().
+        if has_kv_transfer_group():
+            kv_connector = get_kv_transfer_group()
+            assert isinstance(kv_connector, KVConnectorBase_V1)
+            assert scheduler_output.kv_connector_metadata is not None
+            kv_connector.bind_connector_metadata(
+                scheduler_output.kv_connector_metadata)
+
+            # Background KV cache transfers happen here.
+            # These transfers are designed to be async and the requests
+            # involved may be disjoint from the running requests.
+            # Do this here to save a collective_rpc.
+            kv_connector.start_load_kv(get_forward_context())
+
+    @staticmethod
+    def maybe_wait_for_kv_save() -> None:
+        if has_kv_transfer_group():
+            get_kv_transfer_group().wait_for_save()
+
+    @staticmethod
+    def get_finished_kv_transfers(
+        scheduler_output: "SchedulerOutput",
+    ) -> tuple[Optional[set[str]], Optional[set[str]]]:
+        if has_kv_transfer_group():
+            return get_kv_transfer_group().get_finished(
+                scheduler_output.finished_req_ids)
+        return None, None
+
     def generate_draft_token_ids(
         self,
         sampled_token_ids: list[list[int]],
@@ -1813,6 +1868,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             self.vllm_config.compilation_config.static_forward_context,
             self.kv_caches)

+        if has_kv_transfer_group():
+            get_kv_transfer_group().register_kv_caches(kv_caches)
+
         self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
             weakref.proxy(self),
             kv_cache_config.kv_cache_groups[0].kv_cache_spec,