From 5e83a7277f7892432375d3d41594ebfde086ca4e Mon Sep 17 00:00:00 2001
From: Yihua Cheng
Date: Fri, 25 Apr 2025 22:03:38 -0500
Subject: [PATCH] [v1] [P/D] Adding LMCache KV connector for v1 (#16625)

---
 examples/lmcache/README.md | 56 +++++
 .../cpu_offload_lmcache_v0.py} | 0
 examples/lmcache/cpu_offload_lmcache_v1.py | 57 ++++++
 .../disagg_prefill_lmcache_v0.py} | 0
 .../configs/lmcache-decoder-config.yaml | 13 ++
 .../configs/lmcache-prefiller-config.yaml | 13 ++
 .../disagg_example_nixl.sh | 136 ++++++++++++
 .../disagg_proxy_server.py | 193 ++++++++++++++++++
 .../disagg_vllm_launcher.sh | 59 ++++++
 .../lmcache/kv_cache_sharing_lmcache_v1.py | 130 ++++++++++++
 .../kv_transfer/kv_connector/factory.py | 5 +
 .../kv_connector/v1/lmcache_connector.py | 131 ++++++++++++
 12 files changed, 793 insertions(+)
 create mode 100644 examples/lmcache/README.md
 rename examples/{offline_inference/cpu_offload_lmcache.py => lmcache/cpu_offload_lmcache_v0.py} (100%)
 create mode 100644 examples/lmcache/cpu_offload_lmcache_v1.py
 rename examples/{offline_inference/disaggregated_prefill_lmcache.py => lmcache/disagg_prefill_lmcache_v0.py} (100%)
 create mode 100644 examples/lmcache/disagg_prefill_lmcache_v1/configs/lmcache-decoder-config.yaml
 create mode 100644 examples/lmcache/disagg_prefill_lmcache_v1/configs/lmcache-prefiller-config.yaml
 create mode 100644 examples/lmcache/disagg_prefill_lmcache_v1/disagg_example_nixl.sh
 create mode 100644 examples/lmcache/disagg_prefill_lmcache_v1/disagg_proxy_server.py
 create mode 100644 examples/lmcache/disagg_prefill_lmcache_v1/disagg_vllm_launcher.sh
 create mode 100644 examples/lmcache/kv_cache_sharing_lmcache_v1.py
 create mode 100644 vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py

diff --git a/examples/lmcache/README.md b/examples/lmcache/README.md
new file mode 100644
index 0000000000000..237d0aebd00d3
--- /dev/null
+++ b/examples/lmcache/README.md
@@ -0,0 +1,56 @@
+# LMCache Examples
+
+This folder demonstrates how to use LMCache for disaggregated prefilling, CPU offloading and KV cache sharing.
+
+## 1. Disaggregated Prefill in vLLM v1
+
+This example demonstrates how to run LMCache with disaggregated prefill using NIXL on a single node.
+
+### Prerequisites
+
+- Install [LMCache](https://github.com/LMCache/LMCache)
+- Install [NIXL](https://github.com/ai-dynamo/nixl)
+- At least 2 GPUs
+- Valid Hugging Face token (HF_TOKEN) for Llama 3.1 8B Instruct.
+
+### Usage
+
+Run
+`cd disagg_prefill_lmcache_v1`
+to enter the `disagg_prefill_lmcache_v1` folder, and then run
+
+```bash
+bash disagg_example_nixl.sh
+```
+
+to run disaggregated prefill and benchmark the performance.
+
+### Components
+
+#### Server Scripts
+- `disagg_prefill_lmcache_v1/disagg_vllm_launcher.sh` - Launches individual vLLM servers for prefill/decode, and also launches the proxy server.
+- `disagg_prefill_lmcache_v1/disagg_proxy_server.py` - FastAPI proxy server that coordinates between the prefiller and the decoder.
+- `disagg_prefill_lmcache_v1/disagg_example_nixl.sh` - Main script to run the example.
+
+#### Configuration
+- `disagg_prefill_lmcache_v1/configs/lmcache-prefiller-config.yaml` - Configuration for the prefiller server
+- `disagg_prefill_lmcache_v1/configs/lmcache-decoder-config.yaml` - Configuration for the decoder server
+
+#### Log Files
+The main script generates several log files:
+- `prefiller.log` - Logs from the prefill server
+- `decoder.log` - Logs from the decode server
+- `proxy.log` - Logs from the proxy server
+
+## 2. 
CPU Offload Examples
+
+- `cpu_offload_lmcache_v0.py` - CPU offloading implementation for vLLM v0
+- `cpu_offload_lmcache_v1.py` - CPU offloading implementation for vLLM v1
+
+## 3. KV Cache Sharing
+
+The `kv_cache_sharing_lmcache_v1.py` example demonstrates how to share KV caches between vLLM v1 instances.
+
+## 4. Disaggregated Prefill in vLLM v0
+
+The `disagg_prefill_lmcache_v0.py` example demonstrates how to run disaggregated prefill in vLLM v0.
diff --git a/examples/offline_inference/cpu_offload_lmcache.py b/examples/lmcache/cpu_offload_lmcache_v0.py
similarity index 100%
rename from examples/offline_inference/cpu_offload_lmcache.py
rename to examples/lmcache/cpu_offload_lmcache_v0.py
diff --git a/examples/lmcache/cpu_offload_lmcache_v1.py b/examples/lmcache/cpu_offload_lmcache_v1.py
new file mode 100644
index 0000000000000..f44075a36965f
--- /dev/null
+++ b/examples/lmcache/cpu_offload_lmcache_v1.py
@@ -0,0 +1,57 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+This file demonstrates the example usage of CPU offloading
+with LMCache in vLLM v1.
+
+Note that lmcache needs to be installed to run this example.
+Learn more about LMCache at https://github.com/LMCache/LMCache.
+"""
+import os
+
+from lmcache.experimental.cache_engine import LMCacheEngineBuilder
+from lmcache.integration.vllm.utils import ENGINE_NAME
+
+from vllm import LLM, SamplingParams
+from vllm.config import KVTransferConfig
+
+# LMCache-related environment variables
+# Use experimental features in LMCache
+os.environ["LMCACHE_USE_EXPERIMENTAL"] = "True"
+# LMCache is set to use 256 tokens per chunk
+os.environ["LMCACHE_CHUNK_SIZE"] = "256"
+# Enable local CPU backend in LMCache
+os.environ["LMCACHE_LOCAL_CPU"] = "True"
+# Set local CPU memory limit to 5.0 GB
+os.environ["LMCACHE_MAX_LOCAL_CPU_SIZE"] = "5.0"
+
+# This example script runs two requests with a shared prefix.
+shared_prompt = "Hello, how are you?" * 1000
+first_prompt = [
+    shared_prompt + "Hello, my name is",
+]
+second_prompt = [
+    shared_prompt + "Tell me a very long story",
+]
+
+sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)
+
+ktc = KVTransferConfig.from_cli(
+    '{"kv_connector":"LMCacheConnectorV1", "kv_role":"kv_both"}')
+# Set GPU memory utilization to 0.8 for an A40 GPU with 48GB
+# of memory. Reduce the value if your GPU has less memory.
+# Note that LMCache is not compatible with chunked prefill for now.
+llm = LLM(model="meta-llama/Meta-Llama-3.1-8B-Instruct",
+          kv_transfer_config=ktc,
+          max_model_len=8000,
+          gpu_memory_utilization=0.8)
+
+# Should be able to see logs like the following:
+# `LMCache INFO: Storing KV cache for 6006 out of 6006 tokens for request 0`
+# This indicates that the KV cache has been stored in LMCache.
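+# (`second_prompt` defined above shares the same long prefix; generating it
+# after this call should let LMCache reuse the stored prefix KV from CPU
+# memory instead of recomputing it.)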
+outputs = llm.generate(first_prompt, sampling_params) +for output in outputs: + generated_text = output.outputs[0].text + print(f"Generated text: {generated_text!r}") + +# Clean up lmcache backend +LMCacheEngineBuilder.destroy(ENGINE_NAME) diff --git a/examples/offline_inference/disaggregated_prefill_lmcache.py b/examples/lmcache/disagg_prefill_lmcache_v0.py similarity index 100% rename from examples/offline_inference/disaggregated_prefill_lmcache.py rename to examples/lmcache/disagg_prefill_lmcache_v0.py diff --git a/examples/lmcache/disagg_prefill_lmcache_v1/configs/lmcache-decoder-config.yaml b/examples/lmcache/disagg_prefill_lmcache_v1/configs/lmcache-decoder-config.yaml new file mode 100644 index 0000000000000..c3f5a0ae69c06 --- /dev/null +++ b/examples/lmcache/disagg_prefill_lmcache_v1/configs/lmcache-decoder-config.yaml @@ -0,0 +1,13 @@ +local_cpu: False +max_local_cpu_size: 0 +#local_disk: +max_local_disk_size: 0 +remote_serde: NULL + +enable_nixl: True +nixl_role: "receiver" +nixl_peer_host: "localhost" +nixl_peer_port: 55555 +nixl_buffer_size: 1073741824 # 1GB +nixl_buffer_device: "cuda" +nixl_enable_gc: True diff --git a/examples/lmcache/disagg_prefill_lmcache_v1/configs/lmcache-prefiller-config.yaml b/examples/lmcache/disagg_prefill_lmcache_v1/configs/lmcache-prefiller-config.yaml new file mode 100644 index 0000000000000..8b0e82958a64c --- /dev/null +++ b/examples/lmcache/disagg_prefill_lmcache_v1/configs/lmcache-prefiller-config.yaml @@ -0,0 +1,13 @@ +local_cpu: False +max_local_cpu_size: 0 +#local_disk: +max_local_disk_size: 0 +remote_serde: NULL + +enable_nixl: True +nixl_role: "sender" +nixl_peer_host: "localhost" +nixl_peer_port: 55555 +nixl_buffer_size: 1073741824 # 1GB +nixl_buffer_device: "cuda" +nixl_enable_gc: True diff --git a/examples/lmcache/disagg_prefill_lmcache_v1/disagg_example_nixl.sh b/examples/lmcache/disagg_prefill_lmcache_v1/disagg_example_nixl.sh new file mode 100644 index 0000000000000..df8a412935049 --- /dev/null +++ b/examples/lmcache/disagg_prefill_lmcache_v1/disagg_example_nixl.sh @@ -0,0 +1,136 @@ +#!/bin/bash + +echo "Warning: LMCache disaggregated prefill support for vLLM v1 is experimental and subject to change." + + +PIDS=() + +# Switch to the directory of the current script +cd "$(dirname "${BASH_SOURCE[0]}")" + +check_hf_token() { + if [ -z "$HF_TOKEN" ]; then + echo "HF_TOKEN is not set. Please set it to your Hugging Face token." + exit 1 + fi + if [[ "$HF_TOKEN" != hf_* ]]; then + echo "HF_TOKEN is not a valid Hugging Face token. Please set it to your Hugging Face token." + exit 1 + fi + echo "HF_TOKEN is set and valid." +} + +check_num_gpus() { + # can you check if the number of GPUs are >=2 via nvidia-smi? + num_gpus=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l) + if [ "$num_gpus" -lt 2 ]; then + echo "You need at least 2 GPUs to run disaggregated prefill." + exit 1 + else + echo "Found $num_gpus GPUs." + fi +} + +ensure_python_library_installed() { + echo "Checking if $1 is installed..." + python -c "import $1" > /dev/null 2>&1 + if [ $? -ne 0 ]; then + if [ "$1" == "nixl" ]; then + echo "$1 is not installed. Please refer to https://github.com/ai-dynamo/nixl for installation." + else + echo "$1 is not installed. Please install it via pip install $1." + fi + exit 1 + else + echo "$1 is installed." 
+ fi +} + +cleanup() { + echo "Stopping everything…" + trap - INT TERM # prevent re-entrancy + kill -- -$$ # negative PID == “this whole process-group” + wait # reap children so we don't leave zombies + exit 0 +} + +wait_for_server() { + local port=$1 + local timeout_seconds=1200 + local start_time=$(date +%s) + + echo "Waiting for server on port $port..." + + while true; do + if curl -s "localhost:${port}/v1/completions" > /dev/null; then + return 0 + fi + + local now=$(date +%s) + if (( now - start_time >= timeout_seconds )); then + echo "Timeout waiting for server" + return 1 + fi + + sleep 1 + done +} + + +main() { + check_hf_token + check_num_gpus + ensure_python_library_installed lmcache + ensure_python_library_installed nixl + ensure_python_library_installed pandas + ensure_python_library_installed datasets + ensure_python_library_installed vllm + + trap cleanup INT + trap cleanup USR1 + trap cleanup TERM + + echo "Launching prefiller, decoder and proxy..." + echo "Please check prefiller.log, decoder.log and proxy.log for logs." + + bash disagg_vllm_launcher.sh prefiller \ + > >(tee prefiller.log) 2>&1 & + prefiller_pid=$! + PIDS+=($prefiller_pid) + + bash disagg_vllm_launcher.sh decoder \ + > >(tee decoder.log) 2>&1 & + decoder_pid=$! + PIDS+=($decoder_pid) + + python3 disagg_proxy_server.py \ + --host localhost \ + --port 9000 \ + --prefiller-host localhost \ + --prefiller-port 8100 \ + --decoder-host localhost \ + --decoder-port 8200 \ + > >(tee proxy.log) 2>&1 & + proxy_pid=$! + PIDS+=($proxy_pid) + + wait_for_server 8100 + wait_for_server 8200 + wait_for_server 9000 + + echo "All servers are up. Starting benchmark..." + + # begin benchmark + cd ../../../benchmarks/ + python benchmark_serving.py --port 9000 --seed $(date +%s) \ + --model meta-llama/Llama-3.1-8B-Instruct \ + --dataset-name random --random-input-len 7500 --random-output-len 200 \ + --num-prompts 200 --burstiness 100 --request-rate 3.6 | tee benchmark.log + + echo "Benchmarking done. Cleaning up..." + + cleanup + +} + +main \ No newline at end of file diff --git a/examples/lmcache/disagg_prefill_lmcache_v1/disagg_proxy_server.py b/examples/lmcache/disagg_prefill_lmcache_v1/disagg_proxy_server.py new file mode 100644 index 0000000000000..8db93bc8931b2 --- /dev/null +++ b/examples/lmcache/disagg_prefill_lmcache_v1/disagg_proxy_server.py @@ -0,0 +1,193 @@ +# SPDX-License-Identifier: Apache-2.0 + +import argparse +import os +import time +from contextlib import asynccontextmanager + +import httpx +import numpy as np +from fastapi import FastAPI, Request +from fastapi.responses import StreamingResponse + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """ + Lifespan context manager to handle startup and shutdown events. 
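+
+    On startup, persistent httpx clients for the prefiller and decoder
+    /v1 endpoints are created; on shutdown they are closed.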
+    """
+    # Startup: Initialize clients
+    prefiller_base_url = f'http://{global_args.prefiller_host}:{global_args.prefiller_port}/v1'
+    decoder_base_url = f'http://{global_args.decoder_host}:{global_args.decoder_port}/v1'
+
+    app.state.prefill_client = httpx.AsyncClient(timeout=None,
+                                                 base_url=prefiller_base_url)
+    app.state.decode_client = httpx.AsyncClient(timeout=None,
+                                                base_url=decoder_base_url)
+
+    yield
+
+    # Shutdown: Close clients
+    await app.state.prefill_client.aclose()
+    await app.state.decode_client.aclose()
+
+
+# Update FastAPI app initialization to use lifespan
+app = FastAPI(lifespan=lifespan)
+
+
+class StatsCalculator:
+
+    def __init__(self):
+        self._stats = []
+        self._last_log_time = time.time()
+
+    def add(self, value):
+        self._stats.append(value)
+        if time.time() - self._last_log_time > 5:
+            self._log_stats()
+            self._last_log_time = time.time()
+
+    def _log_stats(self):
+        # Print average, median, and 99th percentile
+        # Values are recorded in seconds; convert to milliseconds for display.
+        np_arr = np.array(self._stats) * 1000
+        output_str = f"\nNum requests: {len(self._stats)}" + \
+            "\nPrefill node TTFT stats:" + \
+            f"\n - Average (ms): {np.mean(np_arr)}" + \
+            f"\n - Median (ms): {np.median(np_arr)}" + \
+            f"\n - 99th Percentile (ms): {np.percentile(np_arr, 99)}\n"
+        print("===============================", output_str,
+              "===============================")
+
+
+stats_calculator = StatsCalculator()
+counter = 0
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument("--port", type=int, default=8000)
+    parser.add_argument("--host", type=str, default="localhost")
+    parser.add_argument("--prefiller-host", type=str, default="localhost")
+    parser.add_argument("--prefiller-port", type=int, default=8100)
+    parser.add_argument("--decoder-host", type=str, default="localhost")
+    parser.add_argument("--decoder-port", type=int, default=8200)
+    args = parser.parse_args()
+    return args
+
+
+# Initialize variables to hold the persistent clients
+app.state.prefill_client = None
+app.state.decode_client = None
+
+
+async def send_request_to_service(client: httpx.AsyncClient, endpoint: str,
+                                  req_data: dict):
+    """
+    Send a request to a service using a persistent client.
+    """
+    req_data = req_data.copy()
+    req_data['max_tokens'] = 1
+    if 'max_completion_tokens' in req_data:
+        req_data['max_completion_tokens'] = 1
+
+    headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
+    response = await client.post(endpoint, json=req_data, headers=headers)
+    response.raise_for_status()
+    return response
+
+
+async def stream_service_response(client: httpx.AsyncClient, endpoint: str,
+                                  req_data: dict):
+    """
+    Asynchronously stream the response from a service using a persistent client.
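+
+    Chunks are yielded as raw bytes so the calling endpoint can forward
+    them to the client unchanged via StreamingResponse.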
+ """ + headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} + async with client.stream("POST", endpoint, json=req_data, + headers=headers) as response: + response.raise_for_status() + async for chunk in response.aiter_bytes(): + yield chunk + + +@app.post("/v1/completions") +async def handle_completions(request: Request): + global counter, stats_calculator + counter += 1 + + st = time.time() + try: + req_data = await request.json() + + # Send request to prefill service, ignore the response + await send_request_to_service(app.state.prefill_client, "/completions", + req_data) + + et = time.time() + stats_calculator.add(et - st) + + # Stream response from decode service + async def generate_stream(): + async for chunk in stream_service_response(app.state.decode_client, + "/completions", + req_data): + yield chunk + + return StreamingResponse(generate_stream(), + media_type="application/json") + + except Exception as e: + import sys + import traceback + exc_info = sys.exc_info() + print("Error occurred in disagg prefill proxy server" + " - completions endpoint") + print(e) + print("".join(traceback.format_exception(*exc_info))) + raise + + +@app.post("/v1/chat/completions") +async def handle_chat_completions(request: Request): + global counter, stats_calculator + counter += 1 + + st = time.time() + try: + req_data = await request.json() + + # Send request to prefill service, ignore the response + await send_request_to_service(app.state.prefill_client, + "/chat/completions", req_data) + + et = time.time() + stats_calculator.add(et - st) + + # Stream response from decode service + async def generate_stream(): + async for chunk in stream_service_response(app.state.decode_client, + "/chat/completions", + req_data): + yield chunk + + return StreamingResponse(generate_stream(), + media_type="application/json") + + except Exception as e: + import sys + import traceback + exc_info = sys.exc_info() + print("Error occurred in disagg prefill proxy server " + " - chat completions endpoint") + print(e) + print("".join(traceback.format_exception(*exc_info))) + raise + + +if __name__ == '__main__': + global global_args + global_args = parse_args() + + import uvicorn + uvicorn.run(app, host=global_args.host, port=global_args.port) diff --git a/examples/lmcache/disagg_prefill_lmcache_v1/disagg_vllm_launcher.sh b/examples/lmcache/disagg_prefill_lmcache_v1/disagg_vllm_launcher.sh new file mode 100644 index 0000000000000..831ef0bb574bf --- /dev/null +++ b/examples/lmcache/disagg_prefill_lmcache_v1/disagg_vllm_launcher.sh @@ -0,0 +1,59 @@ +#!/bin/bash + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +if [[ $# -lt 1 ]]; then + echo "Usage: $0 [model]" + exit 1 +fi + +if [[ $# -eq 1 ]]; then + echo "Using default model: meta-llama/Llama-3.1-8B-Instruct" + MODEL="meta-llama/Llama-3.1-8B-Instruct" +else + echo "Using model: $2" + MODEL=$2 +fi + + +if [[ $1 == "prefiller" ]]; then + # Prefiller listens on port 8100 + prefill_config_file=$SCRIPT_DIR/configs/lmcache-prefiller-config.yaml + + UCX_TLS=cuda_ipc,cuda_copy,tcp \ + LMCACHE_CONFIG_FILE=$prefill_config_file \ + LMCACHE_USE_EXPERIMENTAL=True \ + VLLM_ENABLE_V1_MULTIPROCESSING=1 \ + VLLM_WORKER_MULTIPROC_METHOD=spawn \ + CUDA_VISIBLE_DEVICES=0 \ + vllm serve $MODEL \ + --port 8100 \ + --disable-log-requests \ + --enforce-eager \ + --kv-transfer-config \ + '{"kv_connector":"LMCacheConnectorV1","kv_role":"kv_producer","kv_connector_extra_config": {"discard_partial_chunks": false, "lmcache_rpc_port": "producer1"}}' + + +elif [[ $1 
== "decoder" ]]; then + # Decoder listens on port 8200 + decode_config_file=$SCRIPT_DIR/configs/lmcache-decoder-config.yaml + + UCX_TLS=cuda_ipc,cuda_copy,tcp \ + LMCACHE_CONFIG_FILE=$decode_config_file \ + LMCACHE_USE_EXPERIMENTAL=True \ + VLLM_ENABLE_V1_MULTIPROCESSING=1 \ + VLLM_WORKER_MULTIPROC_METHOD=spawn \ + CUDA_VISIBLE_DEVICES=1 \ + vllm serve $MODEL \ + --port 8200 \ + --disable-log-requests \ + --enforce-eager \ + --kv-transfer-config \ + '{"kv_connector":"LMCacheConnectorV1","kv_role":"kv_consumer","kv_connector_extra_config": {"discard_partial_chunks": false, "lmcache_rpc_port": "consumer1"}}' + + +else + echo "Invalid role: $1" + echo "Should be either prefill, decode" + exit 1 +fi diff --git a/examples/lmcache/kv_cache_sharing_lmcache_v1.py b/examples/lmcache/kv_cache_sharing_lmcache_v1.py new file mode 100644 index 0000000000000..af1b4351dd54c --- /dev/null +++ b/examples/lmcache/kv_cache_sharing_lmcache_v1.py @@ -0,0 +1,130 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +This file demonstrates the example usage of remote KV cache sharing +with LMCache. +We will launch 2 vllm instances, and launch an additional LMCache server. +KV cache is transferred in the following manner: +(1) vLLM instance 1 -> LMCache server (KV cache store). +(2) LMCache server -> vLLM instance 2 (KV cache reuse/retrieve). + +Note that lmcache needs to be installed to run this example. +Learn more about LMCache in https://github.com/LMCache/LMCache. +""" +import os +import subprocess +import time +from multiprocessing import Event, Process + +from lmcache.experimental.cache_engine import LMCacheEngineBuilder +from lmcache.integration.vllm.utils import ENGINE_NAME + +from vllm import LLM, SamplingParams +from vllm.config import KVTransferConfig + +# LMCache-related environment variables +# The port to start LMCache server +port = 8100 +# Use experimental features in LMCache +os.environ["LMCACHE_USE_EXPERIMENTAL"] = "True" +# LMCache is set to use 256 tokens per chunk +os.environ["LMCACHE_CHUNK_SIZE"] = "256" +# Disable local CPU backend in LMCache +os.environ["LMCACHE_LOCAL_CPU"] = "False" +# Set local CPU memory buffer limit to 5.0 GB +os.environ["LMCACHE_MAX_LOCAL_CPU_SIZE"] = "5.0" +# Set the remote URL for LMCache server +os.environ["LMCACHE_REMOTE_URL"] = f"lm://localhost:{port}" +# Set the serializer/deserializer between vllm and LMCache server +# `naive` indicates using raw bytes of the tensor without any compression +os.environ["LMCACHE_REMOTE_SERDE"] = "naive" + +prompts = [ + "Hello, how are you?" * 1000, +] + + +def run_store(store_done, prompts): + # We use GPU 0 for KV cache store process. + os.environ["CUDA_VISIBLE_DEVICES"] = "0" + + sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10) + + ktc = KVTransferConfig.from_cli( + '{"kv_connector":"LMCacheConnectorV1", "kv_role":"kv_both"}') + # Set GPU memory utilization to 0.8 for an A40 GPU with 40GB + # memory. Reduce the value if your GPU has less memory. + llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2", + kv_transfer_config=ktc, + max_model_len=8000, + gpu_memory_utilization=0.8, + enforce_eager=True) + + outputs = llm.generate(prompts, sampling_params) + for output in outputs: + generated_text = output.outputs[0].text + print(f"Generated text: {generated_text!r}") + print("KV cache store is finished.") + store_done.set() + + # Clean up lmcache backend + LMCacheEngineBuilder.destroy(ENGINE_NAME) + + +def run_retrieve(store_done, prompts, timeout=1): + # We use GPU 1 for KV cache retrieve process. 
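+    # This process fetches the KV chunks that the store process pushed to
+    # the LMCache server configured via LMCACHE_REMOTE_URL above.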
+ os.environ["CUDA_VISIBLE_DEVICES"] = "1" + + sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10) + + ktc = KVTransferConfig.from_cli( + '{"kv_connector":"LMCacheConnectorV1", "kv_role":"kv_both"}') + # Set GPU memory utilization to 0.8 for an A40 GPU with 40GB + # of memory. Reduce the value if your GPU has less memory. + llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2", + kv_transfer_config=ktc, + max_model_len=8000, + gpu_memory_utilization=0.8, + enforce_eager=True) + + print("Waiting for KV cache store to finish...") + store_done.wait() + time.sleep(timeout) + + outputs = llm.generate(prompts, sampling_params) + for output in outputs: + generated_text = output.outputs[0].text + print(f"Generated text: {generated_text!r}") + + # Clean up lmcache backend + LMCacheEngineBuilder.destroy(ENGINE_NAME) + + +def run_lmcache_server(port): + server_proc = subprocess.Popen([ + "python", "-m", "lmcache.experimental.server", "localhost", + str(port) + ]) + return server_proc + + +def main(): + store_done = Event() + store_process = Process(target=run_store, args=(store_done, prompts)) + retrieve_process = Process(target=run_retrieve, args=(store_done, prompts)) + lmcache_server_process = run_lmcache_server(port) + + # Start KV cache store process + store_process.start() + + # Start KV cache retrieve process + retrieve_process.start() + + # Clean up the processes + store_process.join() + retrieve_process.terminate() + lmcache_server_process.terminate() + lmcache_server_process.wait() + + +if __name__ == "__main__": + main() diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py index 665ea2f5ba011..6532c101a4f6a 100644 --- a/vllm/distributed/kv_transfer/kv_connector/factory.py +++ b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -100,3 +100,8 @@ KVConnectorFactory.register_connector( "SharedStorageConnector", "vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector", "SharedStorageConnector") + +KVConnectorFactory.register_connector( + "LMCacheConnectorV1", + "vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector", + "LMCacheConnectorV1") diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py new file mode 100644 index 0000000000000..e07f185f0dd81 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py @@ -0,0 +1,131 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import TYPE_CHECKING + +import torch +from lmcache.integration.vllm.vllm_v1_adapter import LMCacheConnectorV1Impl + +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) +from vllm.logger import init_logger +from vllm.v1.core.sched.output import SchedulerOutput + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionMetadata + from vllm.forward_context import ForwardContext + from vllm.v1.request import Request + +logger = init_logger(__name__) + + +class LMCacheConnectorV1(KVConnectorBase_V1): + + def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): + super().__init__(vllm_config=vllm_config, role=role) + self._lmcache_engine = LMCacheConnectorV1Impl(vllm_config, role, self) + + # ============================== + # Worker-side methods + # ============================== + def start_load_kv(self, forward_context: "ForwardContext", + **kwargs) -> 
None:
+        """
+        Start loading the KV cache from the connector to vLLM's paged
+        KV buffer. This is called from the forward context before the
+        forward pass to enable async loading during model execution.
+
+        Args:
+            forward_context (ForwardContext): the forward context.
+            **kwargs: additional arguments for the load operation
+
+        Note:
+            The number of elements in kv_caches and layer_names should be
+            the same.
+
+        """
+        self._lmcache_engine.start_load_kv(forward_context, **kwargs)
+
+    def wait_for_layer_load(self, layer_name: str) -> None:
+        """
+        Block until the KV for a specific layer is loaded into vLLM's
+        paged buffer. This is called from within the attention layer to
+        ensure async copying from start_load_kv is complete.
+
+        This interface will be useful for layer-by-layer pipelining.
+
+        Args:
+            layer_name: the name of that layer
+        """
+        self._lmcache_engine.wait_for_layer_load(layer_name)
+
+    def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
+                      attn_metadata: "AttentionMetadata", **kwargs) -> None:
+        """
+        Start saving a layer of KV cache from vLLM's paged buffer
+        to the connector. This is called from within the attention layer
+        to enable async copying during execution.
+
+        Args:
+            layer_name (str): the name of the layer.
+            kv_layer (torch.Tensor): the paged KV buffer of the current
+                layer in vLLM.
+            attn_metadata (AttentionMetadata): the attention metadata.
+            **kwargs: additional arguments for the save operation.
+        """
+        self._lmcache_engine.save_kv_layer(layer_name, kv_layer, attn_metadata,
+                                           **kwargs)
+
+    def wait_for_save(self):
+        """
+        Block until all the save operations are done. This is called
+        as the forward context exits to ensure that the async saving
+        from save_kv_layer is complete before finishing the forward.
+
+        This prevents overwrites of the paged KV buffer before saving is done.
+        """
+        self._lmcache_engine.wait_for_save()
+
+    # ==============================
+    # Scheduler-side methods
+    # ==============================
+    def get_num_new_matched_tokens(
+        self,
+        request: "Request",
+        num_computed_tokens: int,
+    ) -> int:
+        """
+        Get the number of new tokens that can be loaded from the
+        external KV cache beyond the num_computed_tokens.
+
+        Args:
+            request (Request): the request object.
+            num_computed_tokens (int): the number of locally
+                computed tokens for this request
+
+        Returns:
+            the number of tokens that can be loaded from the
+            external KV cache beyond what is already computed.
+        """
+        return self._lmcache_engine.get_num_new_matched_tokens(
+            request, num_computed_tokens)
+
+    def update_state_after_alloc(self, request: "Request",
+                                 num_external_tokens: int):
+        """
+        Update KVConnector state after block allocation.
+        """
+        self._lmcache_engine.update_state_after_alloc(request,
+                                                      num_external_tokens)
+
+    def build_connector_meta(
+            self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
+        """
+        Build the connector metadata for this step.
+
+        This function should NOT modify fields in the scheduler_output.
+        Also, calling this function will reset the state of the connector.
+
+        Args:
+            scheduler_output (SchedulerOutput): the scheduler output object.
+        """
+        return self._lmcache_engine.build_connector_meta(scheduler_output)