[KVConnector] Migrate the LMCache integration code to be vLLM native (#25542)

Signed-off-by: ApostaC <yihua98@uchicago.edu>
Yihua Cheng 2025-10-24 17:23:53 -07:00 committed by GitHub
parent 269c4db0a4
commit 83f478bb19
4 changed files with 1637 additions and 2 deletions

View File

@@ -3,7 +3,9 @@
from typing import TYPE_CHECKING, Any
import torch
-from lmcache.integration.vllm.vllm_v1_adapter import LMCacheConnectorV1Impl
+from lmcache.integration.vllm.vllm_v1_adapter import (
+    LMCacheConnectorV1Impl as LMCacheConnectorLatestImpl,
+)
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
@@ -11,6 +13,9 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
    KVConnectorMetadata,
    KVConnectorRole,
)
+from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_integration import (
+    vllm_v1_adapter as _adapter,
+)
from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput
@@ -26,7 +31,18 @@ logger = init_logger(__name__)
class LMCacheConnectorV1(KVConnectorBase_V1):
    def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
        super().__init__(vllm_config=vllm_config, role=role)
-        self._lmcache_engine = LMCacheConnectorV1Impl(vllm_config, role, self)
+        assert vllm_config.kv_transfer_config is not None
+        use_native = vllm_config.kv_transfer_config.get_from_extra_config(
+            "use_native", False
+        )
+        if use_native:
+            logger.info("Initializing native LMCache connector")
+            cls = _adapter.LMCacheConnectorV1Impl
+        else:
+            logger.info("Initializing latest dev LMCache connector")
+            cls = LMCacheConnectorLatestImpl
+        self._lmcache_engine = cls(vllm_config, role, self)

    # ==============================
    # Worker-side methods
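
For reference, the `use_native` flag above is read from the connector's extra config. A minimal usage sketch (not part of this commit; the model name and exact config field spellings are assumptions) of how a deployment might opt into the in-tree integration:

from vllm import LLM
from vllm.config import KVTransferConfig

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # hypothetical model choice
    kv_transfer_config=KVTransferConfig(
        kv_connector="LMCacheConnectorV1",
        kv_role="kv_both",
        # Consumed by get_from_extra_config("use_native", False) in __init__.
        kv_connector_extra_config={"use_native": True},
    ),
)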

View File

@@ -0,0 +1,2 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

View File

@@ -0,0 +1,221 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Standard
import os
import threading
from typing import TYPE_CHECKING, Union

# Third Party
import torch
from lmcache.config import LMCacheEngineConfig as Config
from lmcache.logging import init_logger
from lmcache.v1.config import LMCacheEngineConfig as V1Config

if TYPE_CHECKING:
    from vllm.config import ModelConfig
    from vllm.multimodal.inputs import PlaceholderRange
    from vllm.v1.core.sched.output import NewRequestData
    from vllm.v1.request import Request

logger = init_logger(__name__)

ENGINE_NAME = "vllm-instance"

# Thread-safe singleton storage
_config_instance: Config | V1Config | None = None
_config_lock = threading.Lock()

def is_false(value: str) -> bool:
    """Check if the given string value is equivalent to 'false'."""
    return value.lower() in ("false", "0", "no", "n", "off")

def lmcache_get_or_create_config() -> Config | V1Config:
    """Get the LMCache configuration from the environment variable
    `LMCACHE_CONFIG_FILE`. If the environment variable is not set, this
    function will return the default configuration.

    This function is thread-safe and implements the singleton pattern,
    ensuring the configuration is loaded only once.
    """
    global _config_instance
    # Double-checked locking for thread-safe singleton
    if _config_instance is None:
        with _config_lock:
            if _config_instance is None:  # Check again within lock
                if is_false(os.getenv("LMCACHE_USE_EXPERIMENTAL", "True")):
                    logger.warning(
                        "Detected LMCACHE_USE_EXPERIMENTAL is set to False. "
                        "Using legacy configuration is deprecated and will "
                        "be removed soon! Please set LMCACHE_USE_EXPERIMENTAL "
                        "to True."
                    )
                    LMCacheEngineConfig = Config  # type: ignore[assignment]
                else:
                    LMCacheEngineConfig = V1Config  # type: ignore[assignment]

                if "LMCACHE_CONFIG_FILE" not in os.environ:
                    logger.warning(
                        "No LMCache configuration file is set. Trying to read"
                        " configurations from the environment variables."
                    )
                    logger.warning(
                        "You can set the configuration file through "
                        "the environment variable: LMCACHE_CONFIG_FILE"
                    )
                    _config_instance = LMCacheEngineConfig.from_env()
                else:
                    config_file = os.environ["LMCACHE_CONFIG_FILE"]
                    logger.info("Loading LMCache config file %s", config_file)
                    _config_instance = LMCacheEngineConfig.from_file(config_file)

                # Update config from environment variables
                _config_instance.update_config_from_env()

    return _config_instance

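Because of the double-checked locking above, the configuration is parsed once and every later call returns the cached instance. A brief usage sketch (the YAML path below is a placeholder, not something introduced by this commit):

import os

os.environ["LMCACHE_CONFIG_FILE"] = "/path/to/lmcache.yaml"  # placeholder path
config = lmcache_get_or_create_config()
# Subsequent calls are cheap and return the very same object:
assert lmcache_get_or_create_config() is config
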
def hex_hash_to_int16(s: str) -> int:
    """
    Convert a hex hash string to a 16-bit integer.
    """
    return int(s, 16) & 0xFFFF

def apply_mm_hashes_to_token_ids(
    token_ids: torch.Tensor,
    mm_hashes: list[str],
    mm_positions: list["PlaceholderRange"],
) -> torch.Tensor:
    """
    Overwrite token_ids in-place for multimodal placeholders using
    efficient slice assignments.
    """
    n = token_ids.size(0)
    for hash_str, placeholder in zip(mm_hashes, mm_positions):
        start, length = placeholder.offset, placeholder.length
        if start >= n:
            continue
        end = min(start + length, n)
        token_ids[start:end] = hex_hash_to_int16(hash_str)
    return token_ids

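To make the in-place overwrite concrete, a small sketch (the namedtuple here is an illustrative stand-in for vllm.multimodal.inputs.PlaceholderRange, which exposes `offset` and `length`):

from collections import namedtuple

import torch

PlaceholderRange = namedtuple("PlaceholderRange", ["offset", "length"])

token_ids = torch.arange(10)
apply_mm_hashes_to_token_ids(
    token_ids,
    mm_hashes=["deadbeef"],  # hex_hash_to_int16("deadbeef") == 0xBEEF == 48879
    mm_positions=[PlaceholderRange(offset=2, length=4)],
)
# token_ids is now tensor([0, 1, 48879, 48879, 48879, 48879, 6, 7, 8, 9])
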
def mla_enabled(model_config: "ModelConfig") -> bool:
    return (
        hasattr(model_config, "use_mla")
        and isinstance(model_config.use_mla, bool)
        and model_config.use_mla
    )

def create_lmcache_metadata(
    vllm_config=None, model_config=None, parallel_config=None, cache_config=None
):
    """
    Create LMCacheEngineMetadata from vLLM configuration.

    This function extracts common metadata creation logic that was duplicated
    across multiple files.

    Args:
        vllm_config (VllmConfig): vLLM configuration object containing model,
            parallel, and cache configs (alternative to the individual config
            parameters)
        model_config (ModelConfig): Model configuration (alternative to
            vllm_config)
        parallel_config (ParallelConfig): Parallel configuration (alternative
            to vllm_config)
        cache_config (CacheConfig): Cache configuration (alternative to
            vllm_config)

    Returns:
        tuple[LMCacheEngineMetadata, Config | V1Config]: The constructed
            metadata and the resolved LMCache configuration.
    """
    # Third Party
    from lmcache.config import LMCacheEngineMetadata

    # First Party
    from vllm.utils import get_kv_cache_torch_dtype

    config = lmcache_get_or_create_config()

    # Support both vllm_config object and individual config parameters
    if vllm_config is not None:
        model_cfg = vllm_config.model_config
        parallel_cfg = vllm_config.parallel_config
        cache_cfg = vllm_config.cache_config
    else:
        if model_config is None or parallel_config is None or cache_config is None:
            raise ValueError(
                "Either vllm_config must be provided, or all of "
                "model_config, parallel_config, and cache_config must be provided."
            )
        model_cfg = model_config
        parallel_cfg = parallel_config
        cache_cfg = cache_config

    # Get KV cache dtype
    kv_dtype = get_kv_cache_torch_dtype(cache_cfg.cache_dtype, model_cfg.dtype)

    # Check if MLA is enabled
    use_mla = mla_enabled(model_cfg)

    # Construct KV shape (for memory pool)
    num_layer = model_cfg.get_num_layers(parallel_cfg)
    chunk_size = config.chunk_size
    num_kv_head = model_cfg.get_num_kv_heads(parallel_cfg)
    head_size = model_cfg.get_head_size()
    kv_shape = (num_layer, 1 if use_mla else 2, chunk_size, num_kv_head, head_size)

    # Create metadata
    metadata = LMCacheEngineMetadata(
        model_cfg.model,
        parallel_cfg.world_size,
        parallel_cfg.rank,
        "vllm",
        kv_dtype,
        kv_shape,
        use_mla,
    )

    return metadata, config

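As a sanity check on the kv_shape construction, a quick sketch with hypothetical (illustrative) model dimensions:

num_layer, num_kv_head, head_size, chunk_size = 32, 8, 128, 256  # assumed values
use_mla = False
kv_shape = (num_layer, 1 if use_mla else 2, chunk_size, num_kv_head, head_size)
assert kv_shape == (32, 2, 256, 8, 128)
# With MLA enabled, the K/V dimension collapses to 1: (32, 1, 256, 8, 128)
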
def extract_mm_features(
    request: Union["Request", "NewRequestData"], modify: bool = False
) -> tuple[list[str], list["PlaceholderRange"]]:
    """
    Normalize multimodal information from a Request into parallel lists.

    This helper reads either:
      1) `request.mm_features` (objects each exposing `.identifier` and
         `.mm_position`), or
      2) legacy fields `request.mm_hashes` and `request.mm_positions`.

    It returns two equally sized lists: the multimodal hash identifiers and
    their corresponding positions. If the request contains no multimodal info,
    it returns `([], [])`.

    Args:
        request (Request): The source object.
        modify (bool):
            Controls copy semantics for the legacy-path return values.
            - If True and legacy fields are used, shallow copies are returned
              so the caller can mutate the lists without affecting `request`.
            - If False, the original legacy sequences are returned as-is
              (zero-copy); treat them as read-only.

    Returns:
        tuple[list[str], list[PlaceholderRange]]: (`mm_hashes`, `mm_positions`).
        May be `([], [])` when no multimodal data is present.
    """
    if getattr(request, "mm_features", None):
        mm_hashes, mm_positions = zip(
            *((f.identifier, f.mm_position) for f in request.mm_features)
        )
        return (list(mm_hashes), list(mm_positions))
    elif getattr(request, "mm_hashes", None):
        if modify:
            return (
                request.mm_hashes.copy(),  # type: ignore
                request.mm_positions.copy(),  # type: ignore
            )
        else:
            return (request.mm_hashes, request.mm_positions)  # type: ignore
    else:
        return ([], [])
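
A short usage sketch of both code paths (SimpleNamespace objects stand in for real vLLM Request instances; the field values are illustrative):

from types import SimpleNamespace

# New-style request exposing mm_features:
feat = SimpleNamespace(identifier="deadbeef", mm_position=(2, 4))
req = SimpleNamespace(mm_features=[feat])
assert extract_mm_features(req) == (["deadbeef"], [(2, 4)])

# Legacy-style request; modify=True returns shallow copies safe to mutate:
legacy = SimpleNamespace(mm_features=None, mm_hashes=["deadbeef"], mm_positions=[(2, 4)])
hashes, positions = extract_mm_features(legacy, modify=True)
hashes.append("cafe")  # does not touch legacy.mm_hashes
assert legacy.mm_hashes == ["deadbeef"]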