[v1] [P/D] Adding LMCache KV connector for v1 (#16625)

Yihua Cheng 2025-04-25 22:03:38 -05:00 committed by GitHub
parent 68af5f6c5c
commit 5e83a7277f
12 changed files with 793 additions and 0 deletions


@@ -0,0 +1,56 @@
# LMCache Examples
This folder demonstrates how to use LMCache for disaggregated prefilling, CPU offloading and KV cache sharing.
## 1. Disaggregated Prefill in vLLM v1
This example demonstrates how to run LMCache with disaggregated prefill using NIXL on a single node.
### Prerequisites
- Install [LMCache](https://github.com/LMCache/LMCache)
- Install [NIXL](https://github.com/ai-dynamo/nixl)
- At least 2 GPUs
- A valid Hugging Face token (`HF_TOKEN`) with access to Llama 3.1 8B Instruct.
### Usage
Change into the `disagg_prefill_lmcache_v1` folder and run
```bash
bash disagg_example_nixl.sh
```
to run disaggregated prefill and benchmark the performance.
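Once all three servers are up (the proxy listens on port 9000 by default), you can also send a request through the proxy yourself. The snippet below is a minimal sketch, assuming the default port and model used by `disagg_example_nixl.sh`:

```python
# Minimal sketch of querying the proxy launched by disagg_example_nixl.sh.
# Assumes the proxy is reachable at localhost:9000 and forwards
# /v1/completions to the prefiller and decoder behind it.
import httpx

resp = httpx.post(
    "http://localhost:9000/v1/completions",
    json={
        "model": "meta-llama/Llama-3.1-8B-Instruct",
        "prompt": "The capital of France is",
        "max_tokens": 32,
        "temperature": 0,
    },
    timeout=120.0,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["text"])
```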
### Components
#### Server Scripts
- `disagg_prefill_lmcache_v1/disagg_vllm_launcher.sh` - Launches a single vLLM server in either the prefiller or the decoder role
- `disagg_prefill_lmcache_v1/disagg_proxy_server.py` - FastAPI proxy server that coordinates between the prefiller and the decoder
- `disagg_prefill_lmcache_v1/disagg_example_nixl.sh` - Main script that launches both vLLM servers and the proxy, then runs the benchmark
#### Configuration
- `disagg_prefill_lmcache_v1/configs/lmcache-prefiller-config.yaml` - Configuration for prefiller server
- `disagg_prefill_lmcache_v1/configs/lmcache-decoder-config.yaml` - Configuration for decoder server
#### Log Files
The main script generates several log files:
- `prefiller.log` - Logs from the prefill server
- `decoder.log` - Logs from the decode server
- `proxy.log` - Logs from the proxy server
## 2. CPU Offload Examples
- `cpu_offload_lmcache_v0.py` - CPU offloading implementation for vLLM v0
- `cpu_offload_lmcache_v1.py` - CPU offloading implementation for vLLM v1
## 3. KV Cache Sharing
The `kv_cache_sharing_lmcache_v1.py` example demonstrates how to share KV caches between vLLM v1 instances.
## 4. Disaggregated Prefill in vLLM v0
The `disaggregated_prefill_lmcache_v0.py` example shows how to run disaggregated prefill in vLLM v0.


@@ -0,0 +1,57 @@
# SPDX-License-Identifier: Apache-2.0
"""
This file demonstrates example usage of CPU offloading
with LMCache in vLLM v1.
Note that LMCache needs to be installed to run this example.
Learn more about LMCache at https://github.com/LMCache/LMCache.
"""
import os
from lmcache.experimental.cache_engine import LMCacheEngineBuilder
from lmcache.integration.vllm.utils import ENGINE_NAME
from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig
# LMCache-related environment variables
# Use experimental features in LMCache
os.environ["LMCACHE_USE_EXPERIMENTAL"] = "True"
# LMCache is set to use 256 tokens per chunk
os.environ["LMCACHE_CHUNK_SIZE"] = "256"
# Enable local CPU backend in LMCache
os.environ["LMCACHE_LOCAL_CPU"] = "True"
# Set local CPU memory limit to 5.0 GB
os.environ["LMCACHE_MAX_LOCAL_CPU_SIZE"] = "5.0"
# This example script runs two requests with a shared prefix.
shared_prompt = "Hello, how are you?" * 1000
first_prompt = [
shared_prompt + "Hello, my name is",
]
second_prompt = [
shared_prompt + "Tell me a very long story",
]
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)
ktc = KVTransferConfig.from_cli(
'{"kv_connector":"LMCacheConnectorV1", "kv_role":"kv_both"}')
# Set GPU memory utilization to 0.8 for an A40 GPU with 40GB
# memory. Reduce the value if your GPU has less memory.
# Note that LMCache is not compatible with chunked prefill for now.
llm = LLM(model="meta-llama/Meta-Llama-3.1-8B-Instruct",
kv_transfer_config=ktc,
max_model_len=8000,
gpu_memory_utilization=0.8)
# Should be able to see logs like the following:
# `LMCache INFO: Storing KV cache for 6006 out of 6006 tokens for request 0`
# This indicates that the KV cache has been stored in LMCache.
outputs = llm.generate(first_prompt, sampling_params)
for output in outputs:
generated_text = output.outputs[0].text
print(f"Generated text: {generated_text!r}")
# Clean up lmcache backend
LMCacheEngineBuilder.destroy(ENGINE_NAME)


@@ -0,0 +1,13 @@
local_cpu: False
max_local_cpu_size: 0
#local_disk:
max_local_disk_size: 0
remote_serde: NULL
enable_nixl: True
nixl_role: "receiver"
nixl_peer_host: "localhost"
nixl_peer_port: 55555
nixl_buffer_size: 1073741824 # 1GB
nixl_buffer_device: "cuda"
nixl_enable_gc: True


@@ -0,0 +1,13 @@
local_cpu: False
max_local_cpu_size: 0
#local_disk:
max_local_disk_size: 0
remote_serde: NULL
enable_nixl: True
nixl_role: "sender"
nixl_peer_host: "localhost"
nixl_peer_port: 55555
nixl_buffer_size: 1073741824 # 1GB
nixl_buffer_device: "cuda"
nixl_enable_gc: True


@@ -0,0 +1,136 @@
#!/bin/bash
echo "Warning: LMCache disaggregated prefill support for vLLM v1 is experimental and subject to change."
PIDS=()
# Switch to the directory of the current script
cd "$(dirname "${BASH_SOURCE[0]}")"
check_hf_token() {
if [ -z "$HF_TOKEN" ]; then
echo "HF_TOKEN is not set. Please set it to your Hugging Face token."
exit 1
fi
if [[ "$HF_TOKEN" != hf_* ]]; then
echo "HF_TOKEN is not a valid Hugging Face token. Please set it to your Hugging Face token."
exit 1
fi
echo "HF_TOKEN is set and valid."
}
check_num_gpus() {
# Check that at least 2 GPUs are available via nvidia-smi.
num_gpus=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)
if [ "$num_gpus" -lt 2 ]; then
echo "You need at least 2 GPUs to run disaggregated prefill."
exit 1
else
echo "Found $num_gpus GPUs."
fi
}
ensure_python_library_installed() {
echo "Checking if $1 is installed..."
python -c "import $1" > /dev/null 2>&1
if [ $? -ne 0 ]; then
if [ "$1" == "nixl" ]; then
echo "$1 is not installed. Please refer to https://github.com/ai-dynamo/nixl for installation."
else
echo "$1 is not installed. Please install it via pip install $1."
fi
exit 1
else
echo "$1 is installed."
fi
}
cleanup() {
echo "Stopping everything…"
trap - INT TERM # prevent re-entrancy
kill -- -$$ # negative PID == “this whole process-group”
wait # reap children so we don't leave zombies
exit 0
}
wait_for_server() {
local port=$1
local timeout_seconds=1200
local start_time=$(date +%s)
echo "Waiting for server on port $port..."
while true; do
if curl -s "localhost:${port}/v1/completions" > /dev/null; then
return 0
fi
local now=$(date +%s)
if (( now - start_time >= timeout_seconds )); then
echo "Timeout waiting for server"
return 1
fi
sleep 1
done
}
main() {
check_hf_token
check_num_gpus
ensure_python_library_installed lmcache
ensure_python_library_installed nixl
ensure_python_library_installed pandas
ensure_python_library_installed datasets
ensure_python_library_installed vllm
trap cleanup INT
trap cleanup USR1
trap cleanup TERM
echo "Launching prefiller, decoder and proxy..."
echo "Please check prefiller.log, decoder.log and proxy.log for logs."
bash disagg_vllm_launcher.sh prefiller \
> >(tee prefiller.log) 2>&1 &
prefiller_pid=$!
PIDS+=($prefiller_pid)
bash disagg_vllm_launcher.sh decoder \
> >(tee decoder.log) 2>&1 &
decoder_pid=$!
PIDS+=($decoder_pid)
python3 disagg_proxy_server.py \
--host localhost \
--port 9000 \
--prefiller-host localhost \
--prefiller-port 8100 \
--decoder-host localhost \
--decoder-port 8200 \
> >(tee proxy.log) 2>&1 &
proxy_pid=$!
PIDS+=($proxy_pid)
wait_for_server 8100
wait_for_server 8200
wait_for_server 9000
echo "All servers are up. Starting benchmark..."
# begin benchmark
cd ../../../benchmarks/
python benchmark_serving.py --port 9000 --seed $(date +%s) \
--model meta-llama/Llama-3.1-8B-Instruct \
--dataset-name random --random-input-len 7500 --random-output-len 200 \
--num-prompts 200 --burstiness 100 --request-rate 3.6 | tee benchmark.log
echo "Benchmarking done. Cleaning up..."
cleanup
}
main


@@ -0,0 +1,193 @@
# SPDX-License-Identifier: Apache-2.0
import argparse
import os
import time
from contextlib import asynccontextmanager
import httpx
import numpy as np
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
@asynccontextmanager
async def lifespan(app: FastAPI):
"""
Lifespan context manager to handle startup and shutdown events.
"""
# Startup: Initialize clients
prefiller_base_url = f'http://{global_args.prefiller_host}:{global_args.prefiller_port}/v1'
decoder_base_url = f'http://{global_args.decoder_host}:{global_args.decoder_port}/v1'
app.state.prefill_client = httpx.AsyncClient(timeout=None,
base_url=prefiller_base_url)
app.state.decode_client = httpx.AsyncClient(timeout=None,
base_url=decoder_base_url)
yield
# Shutdown: Close clients
await app.state.prefill_client.aclose()
await app.state.decode_client.aclose()
# Update FastAPI app initialization to use lifespan
app = FastAPI(lifespan=lifespan)
class StatsCalculator:
def __init__(self):
self._stats = []
self._last_log_time = time.time()
def add(self, value):
self._stats.append(value)
if time.time() - self._last_log_time > 5:
self._log_stats()
self._last_log_time = time.time()
def _log_stats(self):
# Print average, median, and 99th percentile
np_arr = np.array(self._stats)
output_str = f"\nNum requests: {len(self._stats)}" + \
"\nPrefill node TTFT stats:" + \
f"\n - Average (ms): {np.mean(np_arr)}" + \
f"\n - Median (ms): {np.median(np_arr)}" + \
f"\n - 99th Percentile (ms): {np.percentile(np_arr, 99)}\n"
print("===============================", output_str,
"===============================")
stats_calculator = StatsCalculator()
counter = 0
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--prefiller-host", type=str, default="localhost")
parser.add_argument("--prefiller-port", type=int, default=8100)
parser.add_argument("--decoder-host", type=str, default="localhost")
parser.add_argument("--decoder-port", type=int, default=8200)
args = parser.parse_args()
return args
# Initialize variables to hold the persistent clients
app.state.prefill_client = None
app.state.decode_client = None
async def send_request_to_service(client: httpx.AsyncClient, endpoint: str,
req_data: dict):
"""
Send a request to a service using a persistent client.
"""
req_data = req_data.copy()
req_data['max_tokens'] = 1
if 'max_completion_tokens' in req_data:
req_data['max_completion_tokens'] = 1
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
response = await client.post(endpoint, json=req_data, headers=headers)
response.raise_for_status()
return response
async def stream_service_response(client: httpx.AsyncClient, endpoint: str,
req_data: dict):
"""
Asynchronously stream the response from a service using a persistent client.
"""
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
async with client.stream("POST", endpoint, json=req_data,
headers=headers) as response:
response.raise_for_status()
async for chunk in response.aiter_bytes():
yield chunk
@app.post("/v1/completions")
async def handle_completions(request: Request):
global counter, stats_calculator
counter += 1
st = time.time()
try:
req_data = await request.json()
# Send request to prefill service, ignore the response
await send_request_to_service(app.state.prefill_client, "/completions",
req_data)
et = time.time()
stats_calculator.add(et - st)
# Stream response from decode service
async def generate_stream():
async for chunk in stream_service_response(app.state.decode_client,
"/completions",
req_data):
yield chunk
return StreamingResponse(generate_stream(),
media_type="application/json")
except Exception as e:
import sys
import traceback
exc_info = sys.exc_info()
print("Error occurred in disagg prefill proxy server"
" - completions endpoint")
print(e)
print("".join(traceback.format_exception(*exc_info)))
raise
@app.post("/v1/chat/completions")
async def handle_chat_completions(request: Request):
global counter, stats_calculator
counter += 1
st = time.time()
try:
req_data = await request.json()
# Send request to prefill service, ignore the response
await send_request_to_service(app.state.prefill_client,
"/chat/completions", req_data)
et = time.time()
stats_calculator.add(et - st)
# Stream response from decode service
async def generate_stream():
async for chunk in stream_service_response(app.state.decode_client,
"/chat/completions",
req_data):
yield chunk
return StreamingResponse(generate_stream(),
media_type="application/json")
except Exception as e:
import sys
import traceback
exc_info = sys.exc_info()
print("Error occurred in disagg prefill proxy server "
" - chat completions endpoint")
print(e)
print("".join(traceback.format_exception(*exc_info)))
raise
if __name__ == '__main__':
global global_args
global_args = parse_args()
import uvicorn
uvicorn.run(app, host=global_args.host, port=global_args.port)
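
For reference, a client can stream a chat completion through this proxy once the full stack from `disagg_example_nixl.sh` is running. The snippet below is a minimal sketch, assuming the proxy's default address (`localhost:9000`) and the example's default model:

```python
# Minimal sketch of streaming a chat completion through the proxy above.
# Assumes the proxy from disagg_example_nixl.sh is listening on localhost:9000.
import httpx

payload = {
    "model": "meta-llama/Llama-3.1-8B-Instruct",
    "messages": [{"role": "user", "content": "Write a haiku about KV caches."}],
    "max_tokens": 64,
    "stream": True,
}

with httpx.Client(timeout=None) as client:
    with client.stream("POST", "http://localhost:9000/v1/chat/completions",
                       json=payload) as response:
        response.raise_for_status()
        # Chunks are the OpenAI-style SSE bytes forwarded from the decoder.
        for chunk in response.iter_bytes():
            print(chunk.decode("utf-8", errors="replace"), end="")
```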


@@ -0,0 +1,59 @@
#!/bin/bash
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
if [[ $# -lt 1 ]]; then
echo "Usage: $0 <prefiller | decoder> [model]"
exit 1
fi
if [[ $# -eq 1 ]]; then
echo "Using default model: meta-llama/Llama-3.1-8B-Instruct"
MODEL="meta-llama/Llama-3.1-8B-Instruct"
else
echo "Using model: $2"
MODEL=$2
fi
if [[ $1 == "prefiller" ]]; then
# Prefiller listens on port 8100
prefill_config_file=$SCRIPT_DIR/configs/lmcache-prefiller-config.yaml
UCX_TLS=cuda_ipc,cuda_copy,tcp \
LMCACHE_CONFIG_FILE=$prefill_config_file \
LMCACHE_USE_EXPERIMENTAL=True \
VLLM_ENABLE_V1_MULTIPROCESSING=1 \
VLLM_WORKER_MULTIPROC_METHOD=spawn \
CUDA_VISIBLE_DEVICES=0 \
vllm serve $MODEL \
--port 8100 \
--disable-log-requests \
--enforce-eager \
--kv-transfer-config \
'{"kv_connector":"LMCacheConnectorV1","kv_role":"kv_producer","kv_connector_extra_config": {"discard_partial_chunks": false, "lmcache_rpc_port": "producer1"}}'
elif [[ $1 == "decoder" ]]; then
# Decoder listens on port 8200
decode_config_file=$SCRIPT_DIR/configs/lmcache-decoder-config.yaml
UCX_TLS=cuda_ipc,cuda_copy,tcp \
LMCACHE_CONFIG_FILE=$decode_config_file \
LMCACHE_USE_EXPERIMENTAL=True \
VLLM_ENABLE_V1_MULTIPROCESSING=1 \
VLLM_WORKER_MULTIPROC_METHOD=spawn \
CUDA_VISIBLE_DEVICES=1 \
vllm serve $MODEL \
--port 8200 \
--disable-log-requests \
--enforce-eager \
--kv-transfer-config \
'{"kv_connector":"LMCacheConnectorV1","kv_role":"kv_consumer","kv_connector_extra_config": {"discard_partial_chunks": false, "lmcache_rpc_port": "consumer1"}}'
else
echo "Invalid role: $1"
echo "Should be either prefill, decode"
exit 1
fi


@@ -0,0 +1,130 @@
# SPDX-License-Identifier: Apache-2.0
"""
This file demonstrates example usage of remote KV cache sharing
with LMCache.
It launches 2 vLLM instances and an additional LMCache server.
KV cache is transferred in the following manner:
(1) vLLM instance 1 -> LMCache server (KV cache store).
(2) LMCache server -> vLLM instance 2 (KV cache reuse/retrieve).
Note that LMCache needs to be installed to run this example.
Learn more about LMCache at https://github.com/LMCache/LMCache.
"""
import os
import subprocess
import time
from multiprocessing import Event, Process
from lmcache.experimental.cache_engine import LMCacheEngineBuilder
from lmcache.integration.vllm.utils import ENGINE_NAME
from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig
# LMCache-related environment variables
# The port to start LMCache server
port = 8100
# Use experimental features in LMCache
os.environ["LMCACHE_USE_EXPERIMENTAL"] = "True"
# LMCache is set to use 256 tokens per chunk
os.environ["LMCACHE_CHUNK_SIZE"] = "256"
# Disable local CPU backend in LMCache
os.environ["LMCACHE_LOCAL_CPU"] = "False"
# Set local CPU memory buffer limit to 5.0 GB
os.environ["LMCACHE_MAX_LOCAL_CPU_SIZE"] = "5.0"
# Set the remote URL for LMCache server
os.environ["LMCACHE_REMOTE_URL"] = f"lm://localhost:{port}"
# Set the serializer/deserializer between vllm and LMCache server
# `naive` indicates using raw bytes of the tensor without any compression
os.environ["LMCACHE_REMOTE_SERDE"] = "naive"
prompts = [
"Hello, how are you?" * 1000,
]
def run_store(store_done, prompts):
# We use GPU 0 for KV cache store process.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)
ktc = KVTransferConfig.from_cli(
'{"kv_connector":"LMCacheConnectorV1", "kv_role":"kv_both"}')
# Set GPU memory utilization to 0.8 for an A40 GPU with 40GB
# memory. Reduce the value if your GPU has less memory.
llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2",
kv_transfer_config=ktc,
max_model_len=8000,
gpu_memory_utilization=0.8,
enforce_eager=True)
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
generated_text = output.outputs[0].text
print(f"Generated text: {generated_text!r}")
print("KV cache store is finished.")
store_done.set()
# Clean up lmcache backend
LMCacheEngineBuilder.destroy(ENGINE_NAME)
def run_retrieve(store_done, prompts, timeout=1):
# We use GPU 1 for KV cache retrieve process.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)
ktc = KVTransferConfig.from_cli(
'{"kv_connector":"LMCacheConnectorV1", "kv_role":"kv_both"}')
# Set GPU memory utilization to 0.8 for an A40 GPU with 40GB
# of memory. Reduce the value if your GPU has less memory.
llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2",
kv_transfer_config=ktc,
max_model_len=8000,
gpu_memory_utilization=0.8,
enforce_eager=True)
print("Waiting for KV cache store to finish...")
store_done.wait()
time.sleep(timeout)
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
generated_text = output.outputs[0].text
print(f"Generated text: {generated_text!r}")
# Clean up lmcache backend
LMCacheEngineBuilder.destroy(ENGINE_NAME)
def run_lmcache_server(port):
server_proc = subprocess.Popen([
"python", "-m", "lmcache.experimental.server", "localhost",
str(port)
])
return server_proc
def main():
store_done = Event()
store_process = Process(target=run_store, args=(store_done, prompts))
retrieve_process = Process(target=run_retrieve, args=(store_done, prompts))
lmcache_server_process = run_lmcache_server(port)
# Start KV cache store process
store_process.start()
# Start KV cache retrieve process
retrieve_process.start()
# Clean up the processes
store_process.join()
retrieve_process.terminate()
lmcache_server_process.terminate()
lmcache_server_process.wait()
if __name__ == "__main__":
main()


@@ -100,3 +100,8 @@ KVConnectorFactory.register_connector(
"SharedStorageConnector",
"vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector",
"SharedStorageConnector")
KVConnectorFactory.register_connector(
"LMCacheConnectorV1",
"vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector",
"LMCacheConnectorV1")


@@ -0,0 +1,131 @@
# SPDX-License-Identifier: Apache-2.0
from typing import TYPE_CHECKING
import torch
from lmcache.integration.vllm.vllm_v1_adapter import LMCacheConnectorV1Impl
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext
from vllm.v1.request import Request
logger = init_logger(__name__)
class LMCacheConnectorV1(KVConnectorBase_V1):
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
super().__init__(vllm_config=vllm_config, role=role)
self._lmcache_engine = LMCacheConnectorV1Impl(vllm_config, role, self)
# ==============================
# Worker-side methods
# ==============================
def start_load_kv(self, forward_context: "ForwardContext",
**kwargs) -> None:
"""
Start loading the KV cache from the connector to vLLM's paged
KV buffer. This is called from the forward context before the
forward pass to enable async loading during model execution.
Args:
forward_context (ForwardContext): the forward context.
**kwargs: additional arguments for the load operation
Note:
The number of elements in kv_caches and layer_names should be
the same.
"""
self._lmcache_engine.start_load_kv(forward_context, **kwargs)
def wait_for_layer_load(self, layer_name: str) -> None:
"""
Block until the KV for a specific layer is loaded into vLLM's
paged buffer. This is called from within the attention layer to ensure
async copying from start_load_kv is complete.
This interface will be useful for layer-by-layer pipelining.
Args:
layer_name: the name of that layer
"""
self._lmcache_engine.wait_for_layer_load(layer_name)
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata", **kwargs) -> None:
"""
Start saving a layer of KV cache from vLLM's paged buffer
to the connector. This is called from within the attention layer to
enable async copying during execution.
Args:
layer_name (str): the name of the layer.
kv_layer (torch.Tensor): the paged KV buffer of the current
layer in vLLM.
attn_metadata (AttentionMetadata): the attention metadata.
**kwargs: additional arguments for the save operation.
"""
self._lmcache_engine.save_kv_layer(layer_name, kv_layer, attn_metadata,
**kwargs)
def wait_for_save(self):
"""
Block until all save operations are done. This is called
as the forward context exits to ensure that the async saving
from save_kv_layer is complete before finishing the forward pass.
This prevents the paged KV buffer from being overwritten before saving is done.
"""
self._lmcache_engine.wait_for_save()
# ==============================
# Scheduler-side methods
# ==============================
def get_num_new_matched_tokens(
self,
request: "Request",
num_computed_tokens: int,
) -> int:
"""
Get number of new tokens that can be loaded from the
external KV cache beyond the num_computed_tokens.
Args:
request (Request): the request object.
num_computed_tokens (int): the number of locally
computed tokens for this request
Returns:
the number of tokens that can be loaded from the
external KV cache beyond what is already computed.
"""
return self._lmcache_engine.get_num_new_matched_tokens(
request, num_computed_tokens)
def update_state_after_alloc(self, request: "Request",
num_external_tokens: int):
"""
Update KVConnector state after block allocation.
"""
self._lmcache_engine.update_state_after_alloc(request,
num_external_tokens)
def build_connector_meta(
self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
"""
Build the connector metadata for this step.
This function should NOT modify fields in the scheduler_output.
Also, calling this function will reset the state of the connector.
Args:
scheduler_output (SchedulerOutput): the scheduler output object.
"""
return self._lmcache_engine.build_connector_meta(scheduler_output)