diff --git a/examples/offline_inference/cpu_offload_lmcache.py b/examples/offline_inference/cpu_offload_lmcache.py
new file mode 100644
index 000000000000..8211629b24ec
--- /dev/null
+++ b/examples/offline_inference/cpu_offload_lmcache.py
@@ -0,0 +1,65 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+This file demonstrates the example usage of CPU offloading
+with LMCache.
+
+Note that `pip install lmcache` is needed to run this example.
+Learn more about LMCache at https://github.com/LMCache/LMCache.
+"""
+import os
+import time
+
+from lmcache.experimental.cache_engine import LMCacheEngineBuilder
+from lmcache.integration.vllm.utils import ENGINE_NAME
+
+from vllm import LLM, SamplingParams
+from vllm.config import KVTransferConfig
+
+# LMCache-related environment variables
+# Use experimental features in LMCache
+os.environ["LMCACHE_USE_EXPERIMENTAL"] = "True"
+# LMCache is set to use 256 tokens per chunk
+os.environ["LMCACHE_CHUNK_SIZE"] = "256"
+# Enable local CPU backend in LMCache
+os.environ["LMCACHE_LOCAL_CPU"] = "True"
+# Set local CPU memory limit to 5.0 GB
+os.environ["LMCACHE_MAX_LOCAL_CPU_SIZE"] = "5.0"
+
+# This example script runs two requests with a shared prefix.
+shared_prompt = "Hello, how are you?" * 1000
+first_prompt = [
+    shared_prompt + "Hello, my name is",
+]
+second_prompt = [
+    shared_prompt + "Tell me a very long story",
+]
+
+sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)
+
+ktc = KVTransferConfig.from_cli(
+    '{"kv_connector":"LMCacheConnector", "kv_role":"kv_both"}')
+# Set GPU memory utilization to 0.8 for an A40 GPU with 48GB
+# of memory. Reduce the value if your GPU has less memory.
+# Note that LMCache is not compatible with chunked prefill for now.
+llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2",
+          kv_transfer_config=ktc,
+          max_model_len=8000,
+          enable_chunked_prefill=False,
+          gpu_memory_utilization=0.8)
+
+outputs = llm.generate(first_prompt, sampling_params)
+for output in outputs:
+    generated_text = output.outputs[0].text
+    print(f"Generated text: {generated_text!r}")
+print("First request done.")
+
+time.sleep(1)
+
+outputs = llm.generate(second_prompt, sampling_params)
+for output in outputs:
+    generated_text = output.outputs[0].text
+    print(f"Generated text: {generated_text!r}")
+print("Second request done.")
+
+# Clean up the LMCache backend
+LMCacheEngineBuilder.destroy(ENGINE_NAME)
+""" +import os +import subprocess +import time +from multiprocessing import Event, Process + +from lmcache.experimental.cache_engine import LMCacheEngineBuilder +from lmcache.integration.vllm.utils import ENGINE_NAME + +from vllm import LLM, SamplingParams +from vllm.config import KVTransferConfig + +# LMCache-related environment variables +# The port to start LMCache server +port = 8100 +# Use experimental features in LMCache +os.environ["LMCACHE_USE_EXPERIMENTAL"] = "True" +# LMCache is set to use 256 tokens per chunk +os.environ["LMCACHE_CHUNK_SIZE"] = "256" +# Disable local CPU backend in LMCache +os.environ["LMCACHE_LOCAL_CPU"] = "False" +# Set local CPU memory buffer limit to 5.0 GB +os.environ["LMCACHE_MAX_LOCAL_CPU_SIZE"] = "5.0" +# Set the remote URL for LMCache server +os.environ["LMCACHE_REMOTE_URL"] = f"lm://localhost:{port}" +# Set the serializer/deserializer between vllm and LMCache server +# `naive` indicates using raw bytes of the tensor without any compression +os.environ["LMCACHE_REMOTE_SERDE"] = "naive" + + +def run_prefill(prefill_done, prompts): + # We use GPU 0 for prefill node. + os.environ["CUDA_VISIBLE_DEVICES"] = "0" + + sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1) + + ktc = KVTransferConfig.from_cli( + '{"kv_connector":"LMCacheConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}' + ) + # Set GPU memory utilization to 0.8 for an A40 GPU with 40GB + # memory. Reduce the value if your GPU has less memory. + llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2", + kv_transfer_config=ktc, + max_model_len=8000, + gpu_memory_utilization=0.8, + enforce_eager=True) + + #llm.generate(prompts, sampling_params) + outputs = llm.generate(prompts, sampling_params) + for output in outputs: + generated_text = output.outputs[0].text + print(f"Generated text: {generated_text!r}") + print("Prefill node is finished.") + prefill_done.set() + + # Clean up lmcache backend + LMCacheEngineBuilder.destroy(ENGINE_NAME) + + +def run_decode(prefill_done, prompts, timeout=1): + # We use GPU 1 for decode node. + os.environ["CUDA_VISIBLE_DEVICES"] = "1" + + sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10) + + ktc = KVTransferConfig.from_cli( + '{"kv_connector":"LMCacheConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}' + ) + # Set GPU memory utilization to 0.8 for an A40 GPU with 40GB + # of memory. Reduce the value if your GPU has less memory. + llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2", + kv_transfer_config=ktc, + max_model_len=8000, + gpu_memory_utilization=0.8, + enforce_eager=True) + + print("Waiting for prefill node to finish...") + prefill_done.wait() + time.sleep(timeout) + + outputs = llm.generate(prompts, sampling_params) + for output in outputs: + generated_text = output.outputs[0].text + print(f"Generated text: {generated_text!r}") + + # Clean up lmcache backend + LMCacheEngineBuilder.destroy(ENGINE_NAME) + + +def run_lmcache_server(port): + server_proc = subprocess.Popen([ + "python", "-m", "lmcache.experimental.server", "localhost", + str(port) + ]) + return server_proc + + +if __name__ == "__main__": + + prompts = [ + "Hello, how are you?" 
* 1000, + ] + + prefill_done = Event() + prefill_process = Process(target=run_prefill, args=(prefill_done, prompts)) + decode_process = Process(target=run_decode, args=(prefill_done, prompts)) + lmcache_server_process = run_lmcache_server(port) + + # Start prefill node + prefill_process.start() + + # Start decode node + decode_process.start() + + # Clean up the processes + decode_process.join() + prefill_process.terminate() + lmcache_server_process.terminate() + lmcache_server_process.wait() diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py index fe480533458b..7336c54ec8a3 100644 --- a/vllm/distributed/kv_transfer/kv_connector/factory.py +++ b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -48,3 +48,8 @@ KVConnectorFactory.register_connector( "MooncakeConnector", "vllm.distributed.kv_transfer.kv_connector.simple_connector", "SimpleConnector") + +KVConnectorFactory.register_connector( + "LMCacheConnector", + "vllm.distributed.kv_transfer.kv_connector.lmcache_connector", + "LMCacheConnector") diff --git a/vllm/distributed/kv_transfer/kv_connector/lmcache_connector.py b/vllm/distributed/kv_transfer/kv_connector/lmcache_connector.py new file mode 100644 index 000000000000..bf9117133af5 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/lmcache_connector.py @@ -0,0 +1,108 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +LMCache KV Cache Connector for Distributed Machine Learning Inference + +The LMCacheConnector can (1) transfer KV caches between prefill vLLM worker +(KV cache producer) and decode vLLM worker (KV cache consumer) using LMCache; +(2) offload and share KV caches. +""" + +from typing import TYPE_CHECKING, List, Tuple, Union + +import torch + +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase +from vllm.logger import init_logger +from vllm.sequence import IntermediateTensors + +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata + +logger = init_logger(__name__) + + +class LMCacheConnector(KVConnectorBase): + + def __init__( + self, + rank: int, + local_rank: int, + config: VllmConfig, + ): + + self.transfer_config = config.kv_transfer_config + self.vllm_config = config + + from lmcache.experimental.cache_engine import LMCacheEngineBuilder + from lmcache.integration.vllm.utils import ENGINE_NAME + from lmcache.integration.vllm.vllm_adapter import ( + RetrieveStatus, StoreStatus, init_lmcache_engine, + lmcache_retrieve_kv, lmcache_should_store, lmcache_store_kv) + logger.info("Initializing LMCacheConfig under kv_transfer_config %s", + self.transfer_config) + + # TODO (Jiayi): Find model_config, parallel_config, and cache_config + self.engine = init_lmcache_engine(config.model_config, + config.parallel_config, + config.cache_config) + self.lmcache_engine_name = ENGINE_NAME + self.lmcache_engine_builder = LMCacheEngineBuilder + + self.model_config = config.model_config + self.parallel_config = config.parallel_config + self.cache_config = config.cache_config + self.lmcache_retrieve_kv = lmcache_retrieve_kv + self.lmcache_store_kv = lmcache_store_kv + self.lmcache_should_store = lmcache_should_store + self.store_status = StoreStatus + self.retrieve_status = RetrieveStatus + + def recv_kv_caches_and_hidden_states( + self, model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: List[torch.Tensor] + ) -> Tuple[Union[torch.Tensor, IntermediateTensors], 
bool, + "ModelInputForGPUWithSamplingMetadata"]: + + hidden_or_intermediate_states = None + + # TODO (Jiayi): Need to support chunked prefill + retrieve_status = self.retrieve_status.PREFILL + + model_input, bypass_model_exec = self.lmcache_retrieve_kv( + model_executable, model_input, self.cache_config, kv_caches, + retrieve_status) + + return hidden_or_intermediate_states, bypass_model_exec, model_input + + def send_kv_caches_and_hidden_states( + self, + model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: List[torch.Tensor], + hidden_or_intermediate_states: Union[torch.Tensor, + IntermediateTensors], + ) -> None: + num_reqs = 0 + seq_group_list = model_input.sampling_metadata.seq_groups + assert seq_group_list is not None + for seq_group in seq_group_list: + seq_ids = seq_group.seq_ids + for seq_id in seq_ids: + num_reqs += 1 + + # TODO (Jiayi): Only normal prefill is supported for now + store_status = self.lmcache_should_store(model_input) + self.lmcache_store_kv( + self.model_config, + self.parallel_config, + self.cache_config, + model_executable, + model_input, + kv_caches, + store_status, + ) + + def close(self): + self.lmcache_engine_builder.destroy(self.lmcache_engine_name) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 83484cd73550..86166dd5bb83 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -962,8 +962,8 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None: return if all([ - vllm_config.kv_transfer_config.need_kv_parallel_group, _KV_TRANSFER - is None + vllm_config.kv_transfer_config.is_kv_transfer_instance, + _KV_TRANSFER is None ]): _KV_TRANSFER = kv_transfer.KVTransferAgent( rank=get_world_group().rank,