[Misc] refactor example - cpu_offload_lmcache (#17460)

Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>
Reid 2025-05-01 23:05:24 +08:00 committed by GitHub
parent 460a2b1100
commit 7423cf0a9b
3 changed files with 63 additions and 69 deletions

View File

@@ -44,8 +44,8 @@ The main script generates several log files:
## 2. CPU Offload Examples
- `cpu_offload_lmcache_v0.py` - CPU offloading implementation for vLLM v0
- `cpu_offload_lmcache_v1.py` - CPU offloading implementation for vLLM v1
- `python cpu_offload_lmcache.py -v v0` - CPU offloading implementation for vLLM v0
- `python cpu_offload_lmcache.py -v v1` - CPU offloading implementation for vLLM v1
## 3. KV Cache Sharing

View File

@@ -1,22 +1,37 @@
# SPDX-License-Identifier: Apache-2.0
"""
This file demonstrates the example usage of cpu offloading
with LMCache.
with LMCache in vLLM v1 or v0.
Usage:
Specify vLLM version
-v v0 : Use LMCacheConnector
model = mistralai/Mistral-7B-Instruct-v0.2
(Includes enable_chunked_prefill = True)
-v v1 : Use LMCacheConnectorV1 (default)
model = meta-llama/Meta-Llama-3.1-8B-Instruct
(Without enable_chunked_prefill)
Note that `lmcache` is needed to run this example.
Requirements: Linux, Python 3.10 or higher, CUDA 12.1.
To learn more about LMCache environment setup, please refer to:
https://docs.lmcache.ai/getting_started/installation.html
"""
import argparse
import contextlib
import os
import time
from dataclasses import asdict
from lmcache.experimental.cache_engine import LMCacheEngineBuilder
from lmcache.integration.vllm.utils import ENGINE_NAME
from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig
from vllm.engine.arg_utils import EngineArgs
def setup_environment_variables():
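The body of `setup_environment_variables()` is elided in this hunk. Judging from the deleted v1 script at the bottom of this diff, it presumably sets the LMCache environment variables along these lines (a sketch, not the file's exact contents):

```python
import os

# LMCache-related environment variables (values mirror the deleted v1 script
# shown further down; the real function body is not part of this hunk).
os.environ["LMCACHE_USE_EXPERIMENTAL"] = "True"   # use experimental LMCache features
os.environ["LMCACHE_CHUNK_SIZE"] = "256"          # 256 tokens per KV chunk
os.environ["LMCACHE_LOCAL_CPU"] = "True"          # enable the local CPU backend
os.environ["LMCACHE_MAX_LOCAL_CPU_SIZE"] = "5.0"  # cap CPU offload at 5.0 GB
```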
@@ -32,18 +47,32 @@ def setup_environment_variables():
@contextlib.contextmanager
def build_llm_with_lmcache():
ktc = KVTransferConfig.from_cli(
'{"kv_connector":"LMCacheConnector", "kv_role":"kv_both"}')
def build_llm_with_lmcache(lmcache_connector: str, model: str,
vllm_version: str):
ktc = KVTransferConfig(
kv_connector=lmcache_connector,
kv_role="kv_both",
)
# Set GPU memory utilization to 0.8 for an A40 GPU with 48GB
# memory. Reduce the value if your GPU has less memory.
# Note: LMCache supports chunked prefill (see vLLM#14505, LMCache#392).
llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2",
kv_transfer_config=ktc,
max_model_len=8000,
enable_chunked_prefill=True,
gpu_memory_utilization=0.8)
if vllm_version == "v0":
llm_args = EngineArgs(
model=model,
kv_transfer_config=ktc,
max_model_len=8000,
gpu_memory_utilization=0.8,
enable_chunked_prefill=True, # Only in v0
)
else:
llm_args = EngineArgs(
model=model,
kv_transfer_config=ktc,
max_model_len=8000,
gpu_memory_utilization=0.8,
)
llm = LLM(**asdict(llm_args))
try:
yield llm
finally:
@@ -57,6 +86,9 @@ def print_output(
sampling_params: SamplingParams,
req_str: str,
):
# Should be able to see logs like the following:
# `LMCache INFO: Storing KV cache for 6006 out of 6006 tokens for request 0`
# This indicates that the KV cache has been stored in LMCache.
start = time.time()
outputs = llm.generate(prompt, sampling_params)
print("-" * 50)
@@ -68,10 +100,29 @@ def print_output(
print("-" * 50)
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("-v",
"--version",
choices=["v0", "v1"],
default="v1",
help="Specify vLLM version (default: v1)")
return parser.parse_args()
def main():
args = parse_args()
if args.version == "v0":
lmcache_connector = "LMCacheConnector"
model = "mistralai/Mistral-7B-Instruct-v0.2"
else:
lmcache_connector = "LMCacheConnectorV1"
model = "meta-llama/Meta-Llama-3.1-8B-Instruct"
setup_environment_variables()
with build_llm_with_lmcache() as llm:
with build_llm_with_lmcache(lmcache_connector, model, args.version) as llm:
# This example script runs two requests with a shared prefix.
# Define the shared prompt and specific prompts
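The prompt definitions themselves fall outside this hunk. Based on the deleted v1 script below, they follow a shared-prefix pattern like this sketch (values taken from that script, not necessarily the refactored file verbatim):

```python
from vllm import SamplingParams

# Two requests that share a long common prefix; the second request can reuse
# the KV cache that LMCache offloaded to CPU while serving the first one.
shared_prompt = "Hello, how are you?" * 1000
first_prompt = [shared_prompt + "Hello, my name is"]
second_prompt = [shared_prompt + "Tell me a very long story"]

sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)
```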

View File

@@ -1,57 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
"""
This file demonstrates the example usage of cpu offloading
with LMCache in vLLM v1.
Note that lmcache needs to be installed to run this example.
Learn more about LMCache in https://github.com/LMCache/LMCache.
"""
import os
from lmcache.experimental.cache_engine import LMCacheEngineBuilder
from lmcache.integration.vllm.utils import ENGINE_NAME
from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig
# LMCache-related environment variables
# Use experimental features in LMCache
os.environ["LMCACHE_USE_EXPERIMENTAL"] = "True"
# LMCache is set to use 256 tokens per chunk
os.environ["LMCACHE_CHUNK_SIZE"] = "256"
# Enable local CPU backend in LMCache
os.environ["LMCACHE_LOCAL_CPU"] = "True"
# Set local CPU memory limit to 5.0 GB
os.environ["LMCACHE_MAX_LOCAL_CPU_SIZE"] = "5.0"
# This example script runs two requests with a shared prefix.
shared_prompt = "Hello, how are you?" * 1000
first_prompt = [
shared_prompt + "Hello, my name is",
]
second_prompt = [
shared_prompt + "Tell me a very long story",
]
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)
ktc = KVTransferConfig.from_cli(
'{"kv_connector":"LMCacheConnectorV1", "kv_role":"kv_both"}')
# Set GPU memory utilization to 0.8 for an A40 GPU with 40GB
# memory. Reduce the value if your GPU has less memory.
# Note that LMCache is not compatible with chunked prefill for now.
llm = LLM(model="meta-llama/Meta-Llama-3.1-8B-Instruct",
kv_transfer_config=ktc,
max_model_len=8000,
gpu_memory_utilization=0.8)
# Should be able to see logs like the following:
# `LMCache INFO: Storing KV cache for 6006 out of 6006 tokens for request 0`
# This indicates that the KV cache has been stored in LMCache.
outputs = llm.generate(first_prompt, sampling_params)
for output in outputs:
generated_text = output.outputs[0].text
print(f"Generated text: {generated_text!r}")
# Clean up lmcache backend
LMCacheEngineBuilder.destroy(ENGINE_NAME)
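
For comparison with the deleted standalone script above, here is a rough sketch of how the same v1 run looks when driven through the refactored example's helpers. The module import `cpu_offload_lmcache` and the exact `print_output` signature are assumptions inferred from the partial hunks, not confirmed by them:

```python
# Sketch only: assumes cpu_offload_lmcache.py is importable as a module and
# that print_output takes (llm, prompt, sampling_params, req_str), as the
# partial hunk above suggests.
from vllm import SamplingParams

from cpu_offload_lmcache import (build_llm_with_lmcache, print_output,
                                 setup_environment_variables)

setup_environment_variables()

# build_llm_with_lmcache is a context manager that yields an LLM and tears
# down the LMCache backend on exit (see the try/finally in the diff above).
with build_llm_with_lmcache("LMCacheConnectorV1",
                            "meta-llama/Meta-Llama-3.1-8B-Instruct",
                            "v1") as llm:
    shared_prompt = "Hello, how are you?" * 1000
    sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)

    # The first request warms the LMCache CPU backend; the second reuses it.
    print_output(llm, [shared_prompt + "Hello, my name is"],
                 sampling_params, "first")
    print_output(llm, [shared_prompt + "Tell me a very long story"],
                 sampling_params, "second")
```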