mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-06 16:04:42 +08:00
[v1] [P/D] Adding LMCache KV connector for v1 (#16625)
This commit is contained in:
parent
68af5f6c5c
commit
5e83a7277f
56
examples/lmcache/README.md
Normal file
56
examples/lmcache/README.md
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
# LMCache Examples
|
||||||
|
|
||||||
|
This folder demonstrates how to use LMCache for disaggregated prefilling, CPU offloading and KV cache sharing.
|
||||||
|
|
||||||
|
## 1. Disaggregated Prefill in vLLM v1
|
||||||
|
|
||||||
|
This example demonstrates how to run LMCache with disaggregated prefill using NIXL on a single node.
|
||||||
|
|
||||||
|
### Prerequisites
|
||||||
|
|
||||||
|
- Install [LMCache](https://github.com/ai-dynamo/lmcache)
|
||||||
|
- Install [NIXL](https://github.com/ai-dynamo/nixl)
|
||||||
|
- At least 2 GPUs
|
||||||
|
- Valid Hugging Face token (HF_TOKEN) for Llama 3.1 8B Instruct.
|
||||||
|
|
||||||
|
### Usage
|
||||||
|
|
||||||
|
Run
|
||||||
|
`cd disagg_prefill_lmcache_v1`
|
||||||
|
to get into `disagg_prefill_lmcache_v1` folder, and then run
|
||||||
|
|
||||||
|
```bash
|
||||||
|
bash disagg_example_nixl.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
to run disaggregated prefill and benchmark the performance.
|
||||||
|
|
||||||
|
### Components
|
||||||
|
|
||||||
|
#### Server Scripts
|
||||||
|
- `disagg_prefill_lmcache_v1/disagg_vllm_launcher.sh` - Launches individual vLLM servers for prefill/decode, and also launches the proxy server.
|
||||||
|
- `disagg_prefill_lmcache_v1/disagg_proxy_server.py` - FastAPI proxy server that coordinates between prefiller and decoder
|
||||||
|
- `disagg_prefill_lmcache_v1/disagg_example_nixl.sh` - Main script to run the example
|
||||||
|
|
||||||
|
#### Configuration
|
||||||
|
- `disagg_prefill_lmcache_v1/configs/lmcache-prefiller-config.yaml` - Configuration for prefiller server
|
||||||
|
- `disagg_prefill_lmcache_v1/configs/lmcache-decoder-config.yaml` - Configuration for decoder server
|
||||||
|
|
||||||
|
#### Log Files
|
||||||
|
The main script generates several log files:
|
||||||
|
- `prefiller.log` - Logs from the prefill server
|
||||||
|
- `decoder.log` - Logs from the decode server
|
||||||
|
- `proxy.log` - Logs from the proxy server
|
||||||
|
|
||||||
|
## 2. CPU Offload Examples
|
||||||
|
|
||||||
|
- `cpu_offload_lmcache_v0.py` - CPU offloading implementation for vLLM v0
|
||||||
|
- `cpu_offload_lmcache_v1.py` - CPU offloading implementation for vLLM v1
|
||||||
|
|
||||||
|
## 3. KV Cache Sharing
|
||||||
|
|
||||||
|
The `kv_cache_sharing_lmcache_v1.py` example demonstrates how to share KV caches between vLLM v1 instances.
|
||||||
|
|
||||||
|
## 4. Disaggregated Prefill in vLLM v0
|
||||||
|
|
||||||
|
The `disaggregated_prefill_lmcache_v0.py` provides an example of how to run disaggregated prefill in vLLM v0.
|
||||||
57
examples/lmcache/cpu_offload_lmcache_v1.py
Normal file
57
examples/lmcache/cpu_offload_lmcache_v1.py
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
"""
|
||||||
|
This file demonstrates the example usage of cpu offloading
|
||||||
|
with LMCache in vLLM v1.
|
||||||
|
|
||||||
|
Note that lmcache needs to be installed to run this example.
|
||||||
|
Learn more about LMCache in https://github.com/LMCache/LMCache.
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
|
||||||
|
from lmcache.experimental.cache_engine import LMCacheEngineBuilder
|
||||||
|
from lmcache.integration.vllm.utils import ENGINE_NAME
|
||||||
|
|
||||||
|
from vllm import LLM, SamplingParams
|
||||||
|
from vllm.config import KVTransferConfig
|
||||||
|
|
||||||
|
# LMCache-related environment variables
|
||||||
|
# Use experimental features in LMCache
|
||||||
|
os.environ["LMCACHE_USE_EXPERIMENTAL"] = "True"
|
||||||
|
# LMCache is set to use 256 tokens per chunk
|
||||||
|
os.environ["LMCACHE_CHUNK_SIZE"] = "256"
|
||||||
|
# Enable local CPU backend in LMCache
|
||||||
|
os.environ["LMCACHE_LOCAL_CPU"] = "True"
|
||||||
|
# Set local CPU memory limit to 5.0 GB
|
||||||
|
os.environ["LMCACHE_MAX_LOCAL_CPU_SIZE"] = "5.0"
|
||||||
|
|
||||||
|
# This example script runs two requests with a shared prefix.
|
||||||
|
shared_prompt = "Hello, how are you?" * 1000
|
||||||
|
first_prompt = [
|
||||||
|
shared_prompt + "Hello, my name is",
|
||||||
|
]
|
||||||
|
second_prompt = [
|
||||||
|
shared_prompt + "Tell me a very long story",
|
||||||
|
]
|
||||||
|
|
||||||
|
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)
|
||||||
|
|
||||||
|
ktc = KVTransferConfig.from_cli(
|
||||||
|
'{"kv_connector":"LMCacheConnectorV1", "kv_role":"kv_both"}')
|
||||||
|
# Set GPU memory utilization to 0.8 for an A40 GPU with 40GB
|
||||||
|
# memory. Reduce the value if your GPU has less memory.
|
||||||
|
# Note that LMCache is not compatible with chunked prefill for now.
|
||||||
|
llm = LLM(model="meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||||
|
kv_transfer_config=ktc,
|
||||||
|
max_model_len=8000,
|
||||||
|
gpu_memory_utilization=0.8)
|
||||||
|
|
||||||
|
# Should be able to see logs like the following:
|
||||||
|
# `LMCache INFO: Storing KV cache for 6006 out of 6006 tokens for request 0`
|
||||||
|
# This indicates that the KV cache has been stored in LMCache.
|
||||||
|
outputs = llm.generate(first_prompt, sampling_params)
|
||||||
|
for output in outputs:
|
||||||
|
generated_text = output.outputs[0].text
|
||||||
|
print(f"Generated text: {generated_text!r}")
|
||||||
|
|
||||||
|
# Clean up lmcache backend
|
||||||
|
LMCacheEngineBuilder.destroy(ENGINE_NAME)
|
||||||
@ -0,0 +1,13 @@
|
|||||||
|
local_cpu: False
|
||||||
|
max_local_cpu_size: 0
|
||||||
|
#local_disk:
|
||||||
|
max_local_disk_size: 0
|
||||||
|
remote_serde: NULL
|
||||||
|
|
||||||
|
enable_nixl: True
|
||||||
|
nixl_role: "receiver"
|
||||||
|
nixl_peer_host: "localhost"
|
||||||
|
nixl_peer_port: 55555
|
||||||
|
nixl_buffer_size: 1073741824 # 1GB
|
||||||
|
nixl_buffer_device: "cuda"
|
||||||
|
nixl_enable_gc: True
|
||||||
@ -0,0 +1,13 @@
|
|||||||
|
local_cpu: False
|
||||||
|
max_local_cpu_size: 0
|
||||||
|
#local_disk:
|
||||||
|
max_local_disk_size: 0
|
||||||
|
remote_serde: NULL
|
||||||
|
|
||||||
|
enable_nixl: True
|
||||||
|
nixl_role: "sender"
|
||||||
|
nixl_peer_host: "localhost"
|
||||||
|
nixl_peer_port: 55555
|
||||||
|
nixl_buffer_size: 1073741824 # 1GB
|
||||||
|
nixl_buffer_device: "cuda"
|
||||||
|
nixl_enable_gc: True
|
||||||
@ -0,0 +1,136 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
echo "Warning: LMCache disaggregated prefill support for vLLM v1 is experimental and subject to change."
|
||||||
|
|
||||||
|
|
||||||
|
PIDS=()
|
||||||
|
|
||||||
|
# Switch to the directory of the current script
|
||||||
|
cd "$(dirname "${BASH_SOURCE[0]}")"
|
||||||
|
|
||||||
|
check_hf_token() {
|
||||||
|
if [ -z "$HF_TOKEN" ]; then
|
||||||
|
echo "HF_TOKEN is not set. Please set it to your Hugging Face token."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
if [[ "$HF_TOKEN" != hf_* ]]; then
|
||||||
|
echo "HF_TOKEN is not a valid Hugging Face token. Please set it to your Hugging Face token."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
echo "HF_TOKEN is set and valid."
|
||||||
|
}
|
||||||
|
|
||||||
|
check_num_gpus() {
|
||||||
|
# can you check if the number of GPUs are >=2 via nvidia-smi?
|
||||||
|
num_gpus=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)
|
||||||
|
if [ "$num_gpus" -lt 2 ]; then
|
||||||
|
echo "You need at least 2 GPUs to run disaggregated prefill."
|
||||||
|
exit 1
|
||||||
|
else
|
||||||
|
echo "Found $num_gpus GPUs."
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
ensure_python_library_installed() {
|
||||||
|
echo "Checking if $1 is installed..."
|
||||||
|
python -c "import $1" > /dev/null 2>&1
|
||||||
|
if [ $? -ne 0 ]; then
|
||||||
|
if [ "$1" == "nixl" ]; then
|
||||||
|
echo "$1 is not installed. Please refer to https://github.com/ai-dynamo/nixl for installation."
|
||||||
|
else
|
||||||
|
echo "$1 is not installed. Please install it via pip install $1."
|
||||||
|
fi
|
||||||
|
exit 1
|
||||||
|
else
|
||||||
|
echo "$1 is installed."
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
cleanup() {
|
||||||
|
echo "Stopping everything…"
|
||||||
|
trap - INT TERM # prevent re-entrancy
|
||||||
|
kill -- -$$ # negative PID == “this whole process-group”
|
||||||
|
wait # reap children so we don't leave zombies
|
||||||
|
exit 0
|
||||||
|
}
|
||||||
|
|
||||||
|
wait_for_server() {
|
||||||
|
local port=$1
|
||||||
|
local timeout_seconds=1200
|
||||||
|
local start_time=$(date +%s)
|
||||||
|
|
||||||
|
echo "Waiting for server on port $port..."
|
||||||
|
|
||||||
|
while true; do
|
||||||
|
if curl -s "localhost:${port}/v1/completions" > /dev/null; then
|
||||||
|
return 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
local now=$(date +%s)
|
||||||
|
if (( now - start_time >= timeout_seconds )); then
|
||||||
|
echo "Timeout waiting for server"
|
||||||
|
return 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
sleep 1
|
||||||
|
done
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
main() {
|
||||||
|
check_hf_token
|
||||||
|
check_num_gpus
|
||||||
|
ensure_python_library_installed lmcache
|
||||||
|
ensure_python_library_installed nixl
|
||||||
|
ensure_python_library_installed pandas
|
||||||
|
ensure_python_library_installed datasets
|
||||||
|
ensure_python_library_installed vllm
|
||||||
|
|
||||||
|
trap cleanup INT
|
||||||
|
trap cleanup USR1
|
||||||
|
trap cleanup TERM
|
||||||
|
|
||||||
|
echo "Launching prefiller, decoder and proxy..."
|
||||||
|
echo "Please check prefiller.log, decoder.log and proxy.log for logs."
|
||||||
|
|
||||||
|
bash disagg_vllm_launcher.sh prefiller \
|
||||||
|
> >(tee prefiller.log) 2>&1 &
|
||||||
|
prefiller_pid=$!
|
||||||
|
PIDS+=($prefiller_pid)
|
||||||
|
|
||||||
|
bash disagg_vllm_launcher.sh decoder \
|
||||||
|
> >(tee decoder.log) 2>&1 &
|
||||||
|
decoder_pid=$!
|
||||||
|
PIDS+=($decoder_pid)
|
||||||
|
|
||||||
|
python3 disagg_proxy_server.py \
|
||||||
|
--host localhost \
|
||||||
|
--port 9000 \
|
||||||
|
--prefiller-host localhost \
|
||||||
|
--prefiller-port 8100 \
|
||||||
|
--decoder-host localhost \
|
||||||
|
--decoder-port 8200 \
|
||||||
|
> >(tee proxy.log) 2>&1 &
|
||||||
|
proxy_pid=$!
|
||||||
|
PIDS+=($proxy_pid)
|
||||||
|
|
||||||
|
wait_for_server 8100
|
||||||
|
wait_for_server 8200
|
||||||
|
wait_for_server 9000
|
||||||
|
|
||||||
|
echo "All servers are up. Starting benchmark..."
|
||||||
|
|
||||||
|
# begin benchmark
|
||||||
|
cd ../../../benchmarks/
|
||||||
|
python benchmark_serving.py --port 9000 --seed $(date +%s) \
|
||||||
|
--model meta-llama/Llama-3.1-8B-Instruct \
|
||||||
|
--dataset-name random --random-input-len 7500 --random-output-len 200 \
|
||||||
|
--num-prompts 200 --burstiness 100 --request-rate 3.6 | tee benchmark.log
|
||||||
|
|
||||||
|
echo "Benchmarking done. Cleaning up..."
|
||||||
|
|
||||||
|
cleanup
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
main
|
||||||
@ -0,0 +1,193 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import numpy as np
|
||||||
|
from fastapi import FastAPI, Request
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
"""
|
||||||
|
Lifespan context manager to handle startup and shutdown events.
|
||||||
|
"""
|
||||||
|
# Startup: Initialize clients
|
||||||
|
prefiller_base_url = f'http://{global_args.prefiller_host}:{global_args.prefiller_port}/v1'
|
||||||
|
decoder_base_url = f'http://{global_args.decoder_host}:{global_args.decoder_port}/v1'
|
||||||
|
|
||||||
|
app.state.prefill_client = httpx.AsyncClient(timeout=None,
|
||||||
|
base_url=prefiller_base_url)
|
||||||
|
app.state.decode_client = httpx.AsyncClient(timeout=None,
|
||||||
|
base_url=decoder_base_url)
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
# Shutdown: Close clients
|
||||||
|
await app.state.prefill_client.aclose()
|
||||||
|
await app.state.decode_client.aclose()
|
||||||
|
|
||||||
|
|
||||||
|
# Update FastAPI app initialization to use lifespan
|
||||||
|
app = FastAPI(lifespan=lifespan)
|
||||||
|
|
||||||
|
|
||||||
|
class StatsCalculator:
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._stats = []
|
||||||
|
self._last_log_time = time.time()
|
||||||
|
|
||||||
|
def add(self, value):
|
||||||
|
self._stats.append(value)
|
||||||
|
if time.time() - self._last_log_time > 5:
|
||||||
|
self._log_stats()
|
||||||
|
self._last_log_time = time.time()
|
||||||
|
|
||||||
|
def _log_stats(self):
|
||||||
|
# Print average, median, and 99th percentile
|
||||||
|
np_arr = np.array(self._stats)
|
||||||
|
output_str = f"\nNum requests: {len(self._stats)}" + \
|
||||||
|
"\nPrefill node TTFT stats:" + \
|
||||||
|
f"\n - Average (ms): {np.mean(np_arr)}" + \
|
||||||
|
f"\n - Median (ms): {np.median(np_arr)}" + \
|
||||||
|
f"\n - 99th Percentile (ms): {np.percentile(np_arr, 99)}\n"
|
||||||
|
print("===============================", output_str,
|
||||||
|
"===============================")
|
||||||
|
|
||||||
|
|
||||||
|
stats_calculator = StatsCalculator()
|
||||||
|
counter = 0
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
|
parser.add_argument("--port", type=int, default=8000)
|
||||||
|
parser.add_argument("--host", type=str, default="localhost")
|
||||||
|
parser.add_argument("--prefiller-host", type=str, default="localhost")
|
||||||
|
parser.add_argument("--prefiller-port", type=int, default=8100)
|
||||||
|
parser.add_argument("--decoder-host", type=str, default="localhost")
|
||||||
|
parser.add_argument("--decoder-port", type=int, default=8200)
|
||||||
|
args = parser.parse_args()
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
# Initialize variables to hold the persistent clients
|
||||||
|
app.state.prefill_client = None
|
||||||
|
app.state.decode_client = None
|
||||||
|
|
||||||
|
|
||||||
|
async def send_request_to_service(client: httpx.AsyncClient, endpoint: str,
|
||||||
|
req_data: dict):
|
||||||
|
"""
|
||||||
|
Send a request to a service using a persistent client.
|
||||||
|
"""
|
||||||
|
req_data = req_data.copy()
|
||||||
|
req_data['max_tokens'] = 1
|
||||||
|
if 'max_completion_tokens' in req_data:
|
||||||
|
req_data['max_completion_tokens'] = 1
|
||||||
|
|
||||||
|
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
|
||||||
|
response = await client.post(endpoint, json=req_data, headers=headers)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
async def stream_service_response(client: httpx.AsyncClient, endpoint: str,
|
||||||
|
req_data: dict):
|
||||||
|
"""
|
||||||
|
Asynchronously stream the response from a service using a persistent client.
|
||||||
|
"""
|
||||||
|
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
|
||||||
|
async with client.stream("POST", endpoint, json=req_data,
|
||||||
|
headers=headers) as response:
|
||||||
|
response.raise_for_status()
|
||||||
|
async for chunk in response.aiter_bytes():
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/v1/completions")
|
||||||
|
async def handle_completions(request: Request):
|
||||||
|
global counter, stats_calculator
|
||||||
|
counter += 1
|
||||||
|
|
||||||
|
st = time.time()
|
||||||
|
try:
|
||||||
|
req_data = await request.json()
|
||||||
|
|
||||||
|
# Send request to prefill service, ignore the response
|
||||||
|
await send_request_to_service(app.state.prefill_client, "/completions",
|
||||||
|
req_data)
|
||||||
|
|
||||||
|
et = time.time()
|
||||||
|
stats_calculator.add(et - st)
|
||||||
|
|
||||||
|
# Stream response from decode service
|
||||||
|
async def generate_stream():
|
||||||
|
async for chunk in stream_service_response(app.state.decode_client,
|
||||||
|
"/completions",
|
||||||
|
req_data):
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
return StreamingResponse(generate_stream(),
|
||||||
|
media_type="application/json")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
import sys
|
||||||
|
import traceback
|
||||||
|
exc_info = sys.exc_info()
|
||||||
|
print("Error occurred in disagg prefill proxy server"
|
||||||
|
" - completions endpoint")
|
||||||
|
print(e)
|
||||||
|
print("".join(traceback.format_exception(*exc_info)))
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/v1/chat/completions")
|
||||||
|
async def handle_chat_completions(request: Request):
|
||||||
|
global counter, stats_calculator
|
||||||
|
counter += 1
|
||||||
|
|
||||||
|
st = time.time()
|
||||||
|
try:
|
||||||
|
req_data = await request.json()
|
||||||
|
|
||||||
|
# Send request to prefill service, ignore the response
|
||||||
|
await send_request_to_service(app.state.prefill_client,
|
||||||
|
"/chat/completions", req_data)
|
||||||
|
|
||||||
|
et = time.time()
|
||||||
|
stats_calculator.add(et - st)
|
||||||
|
|
||||||
|
# Stream response from decode service
|
||||||
|
async def generate_stream():
|
||||||
|
async for chunk in stream_service_response(app.state.decode_client,
|
||||||
|
"/chat/completions",
|
||||||
|
req_data):
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
return StreamingResponse(generate_stream(),
|
||||||
|
media_type="application/json")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
import sys
|
||||||
|
import traceback
|
||||||
|
exc_info = sys.exc_info()
|
||||||
|
print("Error occurred in disagg prefill proxy server "
|
||||||
|
" - chat completions endpoint")
|
||||||
|
print(e)
|
||||||
|
print("".join(traceback.format_exception(*exc_info)))
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
global global_args
|
||||||
|
global_args = parse_args()
|
||||||
|
|
||||||
|
import uvicorn
|
||||||
|
uvicorn.run(app, host=global_args.host, port=global_args.port)
|
||||||
@ -0,0 +1,59 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||||
|
|
||||||
|
if [[ $# -lt 1 ]]; then
|
||||||
|
echo "Usage: $0 <prefiller | decoder> [model]"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [[ $# -eq 1 ]]; then
|
||||||
|
echo "Using default model: meta-llama/Llama-3.1-8B-Instruct"
|
||||||
|
MODEL="meta-llama/Llama-3.1-8B-Instruct"
|
||||||
|
else
|
||||||
|
echo "Using model: $2"
|
||||||
|
MODEL=$2
|
||||||
|
fi
|
||||||
|
|
||||||
|
|
||||||
|
if [[ $1 == "prefiller" ]]; then
|
||||||
|
# Prefiller listens on port 8100
|
||||||
|
prefill_config_file=$SCRIPT_DIR/configs/lmcache-prefiller-config.yaml
|
||||||
|
|
||||||
|
UCX_TLS=cuda_ipc,cuda_copy,tcp \
|
||||||
|
LMCACHE_CONFIG_FILE=$prefill_config_file \
|
||||||
|
LMCACHE_USE_EXPERIMENTAL=True \
|
||||||
|
VLLM_ENABLE_V1_MULTIPROCESSING=1 \
|
||||||
|
VLLM_WORKER_MULTIPROC_METHOD=spawn \
|
||||||
|
CUDA_VISIBLE_DEVICES=0 \
|
||||||
|
vllm serve $MODEL \
|
||||||
|
--port 8100 \
|
||||||
|
--disable-log-requests \
|
||||||
|
--enforce-eager \
|
||||||
|
--kv-transfer-config \
|
||||||
|
'{"kv_connector":"LMCacheConnectorV1","kv_role":"kv_producer","kv_connector_extra_config": {"discard_partial_chunks": false, "lmcache_rpc_port": "producer1"}}'
|
||||||
|
|
||||||
|
|
||||||
|
elif [[ $1 == "decoder" ]]; then
|
||||||
|
# Decoder listens on port 8200
|
||||||
|
decode_config_file=$SCRIPT_DIR/configs/lmcache-decoder-config.yaml
|
||||||
|
|
||||||
|
UCX_TLS=cuda_ipc,cuda_copy,tcp \
|
||||||
|
LMCACHE_CONFIG_FILE=$decode_config_file \
|
||||||
|
LMCACHE_USE_EXPERIMENTAL=True \
|
||||||
|
VLLM_ENABLE_V1_MULTIPROCESSING=1 \
|
||||||
|
VLLM_WORKER_MULTIPROC_METHOD=spawn \
|
||||||
|
CUDA_VISIBLE_DEVICES=1 \
|
||||||
|
vllm serve $MODEL \
|
||||||
|
--port 8200 \
|
||||||
|
--disable-log-requests \
|
||||||
|
--enforce-eager \
|
||||||
|
--kv-transfer-config \
|
||||||
|
'{"kv_connector":"LMCacheConnectorV1","kv_role":"kv_consumer","kv_connector_extra_config": {"discard_partial_chunks": false, "lmcache_rpc_port": "consumer1"}}'
|
||||||
|
|
||||||
|
|
||||||
|
else
|
||||||
|
echo "Invalid role: $1"
|
||||||
|
echo "Should be either prefill, decode"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
130
examples/lmcache/kv_cache_sharing_lmcache_v1.py
Normal file
130
examples/lmcache/kv_cache_sharing_lmcache_v1.py
Normal file
@ -0,0 +1,130 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
"""
|
||||||
|
This file demonstrates the example usage of remote KV cache sharing
|
||||||
|
with LMCache.
|
||||||
|
We will launch 2 vllm instances, and launch an additional LMCache server.
|
||||||
|
KV cache is transferred in the following manner:
|
||||||
|
(1) vLLM instance 1 -> LMCache server (KV cache store).
|
||||||
|
(2) LMCache server -> vLLM instance 2 (KV cache reuse/retrieve).
|
||||||
|
|
||||||
|
Note that lmcache needs to be installed to run this example.
|
||||||
|
Learn more about LMCache in https://github.com/LMCache/LMCache.
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
import time
|
||||||
|
from multiprocessing import Event, Process
|
||||||
|
|
||||||
|
from lmcache.experimental.cache_engine import LMCacheEngineBuilder
|
||||||
|
from lmcache.integration.vllm.utils import ENGINE_NAME
|
||||||
|
|
||||||
|
from vllm import LLM, SamplingParams
|
||||||
|
from vllm.config import KVTransferConfig
|
||||||
|
|
||||||
|
# LMCache-related environment variables
|
||||||
|
# The port to start LMCache server
|
||||||
|
port = 8100
|
||||||
|
# Use experimental features in LMCache
|
||||||
|
os.environ["LMCACHE_USE_EXPERIMENTAL"] = "True"
|
||||||
|
# LMCache is set to use 256 tokens per chunk
|
||||||
|
os.environ["LMCACHE_CHUNK_SIZE"] = "256"
|
||||||
|
# Disable local CPU backend in LMCache
|
||||||
|
os.environ["LMCACHE_LOCAL_CPU"] = "False"
|
||||||
|
# Set local CPU memory buffer limit to 5.0 GB
|
||||||
|
os.environ["LMCACHE_MAX_LOCAL_CPU_SIZE"] = "5.0"
|
||||||
|
# Set the remote URL for LMCache server
|
||||||
|
os.environ["LMCACHE_REMOTE_URL"] = f"lm://localhost:{port}"
|
||||||
|
# Set the serializer/deserializer between vllm and LMCache server
|
||||||
|
# `naive` indicates using raw bytes of the tensor without any compression
|
||||||
|
os.environ["LMCACHE_REMOTE_SERDE"] = "naive"
|
||||||
|
|
||||||
|
prompts = [
|
||||||
|
"Hello, how are you?" * 1000,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def run_store(store_done, prompts):
|
||||||
|
# We use GPU 0 for KV cache store process.
|
||||||
|
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
||||||
|
|
||||||
|
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)
|
||||||
|
|
||||||
|
ktc = KVTransferConfig.from_cli(
|
||||||
|
'{"kv_connector":"LMCacheConnectorV1", "kv_role":"kv_both"}')
|
||||||
|
# Set GPU memory utilization to 0.8 for an A40 GPU with 40GB
|
||||||
|
# memory. Reduce the value if your GPU has less memory.
|
||||||
|
llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2",
|
||||||
|
kv_transfer_config=ktc,
|
||||||
|
max_model_len=8000,
|
||||||
|
gpu_memory_utilization=0.8,
|
||||||
|
enforce_eager=True)
|
||||||
|
|
||||||
|
outputs = llm.generate(prompts, sampling_params)
|
||||||
|
for output in outputs:
|
||||||
|
generated_text = output.outputs[0].text
|
||||||
|
print(f"Generated text: {generated_text!r}")
|
||||||
|
print("KV cache store is finished.")
|
||||||
|
store_done.set()
|
||||||
|
|
||||||
|
# Clean up lmcache backend
|
||||||
|
LMCacheEngineBuilder.destroy(ENGINE_NAME)
|
||||||
|
|
||||||
|
|
||||||
|
def run_retrieve(store_done, prompts, timeout=1):
|
||||||
|
# We use GPU 1 for KV cache retrieve process.
|
||||||
|
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
|
||||||
|
|
||||||
|
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)
|
||||||
|
|
||||||
|
ktc = KVTransferConfig.from_cli(
|
||||||
|
'{"kv_connector":"LMCacheConnectorV1", "kv_role":"kv_both"}')
|
||||||
|
# Set GPU memory utilization to 0.8 for an A40 GPU with 40GB
|
||||||
|
# of memory. Reduce the value if your GPU has less memory.
|
||||||
|
llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2",
|
||||||
|
kv_transfer_config=ktc,
|
||||||
|
max_model_len=8000,
|
||||||
|
gpu_memory_utilization=0.8,
|
||||||
|
enforce_eager=True)
|
||||||
|
|
||||||
|
print("Waiting for KV cache store to finish...")
|
||||||
|
store_done.wait()
|
||||||
|
time.sleep(timeout)
|
||||||
|
|
||||||
|
outputs = llm.generate(prompts, sampling_params)
|
||||||
|
for output in outputs:
|
||||||
|
generated_text = output.outputs[0].text
|
||||||
|
print(f"Generated text: {generated_text!r}")
|
||||||
|
|
||||||
|
# Clean up lmcache backend
|
||||||
|
LMCacheEngineBuilder.destroy(ENGINE_NAME)
|
||||||
|
|
||||||
|
|
||||||
|
def run_lmcache_server(port):
|
||||||
|
server_proc = subprocess.Popen([
|
||||||
|
"python", "-m", "lmcache.experimental.server", "localhost",
|
||||||
|
str(port)
|
||||||
|
])
|
||||||
|
return server_proc
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
store_done = Event()
|
||||||
|
store_process = Process(target=run_store, args=(store_done, prompts))
|
||||||
|
retrieve_process = Process(target=run_retrieve, args=(store_done, prompts))
|
||||||
|
lmcache_server_process = run_lmcache_server(port)
|
||||||
|
|
||||||
|
# Start KV cache store process
|
||||||
|
store_process.start()
|
||||||
|
|
||||||
|
# Start KV cache retrieve process
|
||||||
|
retrieve_process.start()
|
||||||
|
|
||||||
|
# Clean up the processes
|
||||||
|
store_process.join()
|
||||||
|
retrieve_process.terminate()
|
||||||
|
lmcache_server_process.terminate()
|
||||||
|
lmcache_server_process.wait()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@ -100,3 +100,8 @@ KVConnectorFactory.register_connector(
|
|||||||
"SharedStorageConnector",
|
"SharedStorageConnector",
|
||||||
"vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector",
|
"vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector",
|
||||||
"SharedStorageConnector")
|
"SharedStorageConnector")
|
||||||
|
|
||||||
|
KVConnectorFactory.register_connector(
|
||||||
|
"LMCacheConnectorV1",
|
||||||
|
"vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector",
|
||||||
|
"LMCacheConnectorV1")
|
||||||
|
|||||||
@ -0,0 +1,131 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from lmcache.integration.vllm.vllm_v1_adapter import LMCacheConnectorV1Impl
|
||||||
|
|
||||||
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||||
|
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.v1.core.sched.output import SchedulerOutput
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from vllm.attention.backends.abstract import AttentionMetadata
|
||||||
|
from vllm.forward_context import ForwardContext
|
||||||
|
from vllm.v1.request import Request
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class LMCacheConnectorV1(KVConnectorBase_V1):
|
||||||
|
|
||||||
|
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
|
||||||
|
super().__init__(vllm_config=vllm_config, role=role)
|
||||||
|
self._lmcache_engine = LMCacheConnectorV1Impl(vllm_config, role, self)
|
||||||
|
|
||||||
|
# ==============================
|
||||||
|
# Worker-side methods
|
||||||
|
# ==============================
|
||||||
|
def start_load_kv(self, forward_context: "ForwardContext",
|
||||||
|
**kwargs) -> None:
|
||||||
|
"""
|
||||||
|
Start loading the KV cache from the connector to vLLM's paged
|
||||||
|
KV buffer. This is called from the forward context before the
|
||||||
|
forward pass to enable async loading during model execution.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
forward_context (ForwardContext): the forward context.
|
||||||
|
**kwargs: additional arguments for the load operation
|
||||||
|
|
||||||
|
Note:
|
||||||
|
The number of elements in kv_caches and layer_names should be
|
||||||
|
the same.
|
||||||
|
|
||||||
|
"""
|
||||||
|
self._lmcache_engine.start_load_kv(forward_context, **kwargs)
|
||||||
|
|
||||||
|
def wait_for_layer_load(self, layer_name: str) -> None:
|
||||||
|
"""
|
||||||
|
Block until the KV for a specific layer is loaded into vLLM's
|
||||||
|
paged buffer. This is called from within attention layer to ensure
|
||||||
|
async copying from start_load_kv is complete.
|
||||||
|
|
||||||
|
This interface will be useful for layer-by-layer pipelining.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
layer_name: the name of that layer
|
||||||
|
"""
|
||||||
|
self._lmcache_engine.wait_for_layer_load(layer_name)
|
||||||
|
|
||||||
|
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
|
||||||
|
attn_metadata: "AttentionMetadata", **kwargs) -> None:
|
||||||
|
"""
|
||||||
|
Start saving the a layer of KV cache from vLLM's paged buffer
|
||||||
|
to the connector. This is called from within attention layer to
|
||||||
|
enable async copying during execution.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
layer_name (str): the name of the layer.
|
||||||
|
kv_layer (torch.Tensor): the paged KV buffer of the current
|
||||||
|
layer in vLLM.
|
||||||
|
attn_metadata (AttentionMetadata): the attention metadata.
|
||||||
|
**kwargs: additional arguments for the save operation.
|
||||||
|
"""
|
||||||
|
self._lmcache_engine.save_kv_layer(layer_name, kv_layer, attn_metadata,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
def wait_for_save(self):
|
||||||
|
"""
|
||||||
|
Block until all the save operations is done. This is called
|
||||||
|
as the forward context exits to ensure that the async saving
|
||||||
|
from save_kv_layer is complete before finishing the forward.
|
||||||
|
|
||||||
|
This prevents overwrites of paged KV buffer before saving done.
|
||||||
|
"""
|
||||||
|
self._lmcache_engine.wait_for_save()
|
||||||
|
|
||||||
|
# ==============================
|
||||||
|
# Scheduler-side methods
|
||||||
|
# ==============================
|
||||||
|
def get_num_new_matched_tokens(
|
||||||
|
self,
|
||||||
|
request: "Request",
|
||||||
|
num_computed_tokens: int,
|
||||||
|
) -> int:
|
||||||
|
"""
|
||||||
|
Get number of new tokens that can be loaded from the
|
||||||
|
external KV cache beyond the num_computed_tokens.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request (Request): the request object.
|
||||||
|
num_computed_tokens (int): the number of locally
|
||||||
|
computed tokens for this request
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
the number of tokens that can be loaded from the
|
||||||
|
external KV cache beyond what is already computed.
|
||||||
|
"""
|
||||||
|
return self._lmcache_engine.get_num_new_matched_tokens(
|
||||||
|
request, num_computed_tokens)
|
||||||
|
|
||||||
|
def update_state_after_alloc(self, request: "Request",
|
||||||
|
num_external_tokens: int):
|
||||||
|
"""
|
||||||
|
Update KVConnector state after block allocation.
|
||||||
|
"""
|
||||||
|
self._lmcache_engine.update_state_after_alloc(request,
|
||||||
|
num_external_tokens)
|
||||||
|
|
||||||
|
def build_connector_meta(
|
||||||
|
self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
|
||||||
|
"""
|
||||||
|
Build the connector metadata for this step.
|
||||||
|
|
||||||
|
This function should NOT modify fields in the scheduler_output.
|
||||||
|
Also, calling this function will reset the state of the connector.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scheduler_output (SchedulerOutput): the scheduler output object.
|
||||||
|
"""
|
||||||
|
return self._lmcache_engine.build_connector_meta(scheduler_output)
|
||||||
Loading…
x
Reference in New Issue
Block a user