mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-30 05:45:14 +08:00
[V1] Remove input cache client (#14864)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> Signed-off-by: Roger Wang <ywang@roblox.com> Co-authored-by: Roger Wang <ywang@roblox.com>
This commit is contained in:
parent
8d6cf89526
commit
b539222d4e
@ -379,6 +379,7 @@ class InputPreprocessor:
|
||||
multi_modal_data,
|
||||
mm_processor_kwargs,
|
||||
lora_request=lora_request,
|
||||
return_mm_hashes=return_mm_hashes,
|
||||
)
|
||||
|
||||
prompt_token_ids = self._tokenize_prompt(
|
||||
@ -401,6 +402,7 @@ class InputPreprocessor:
|
||||
prompt: SingletonPrompt,
|
||||
request_id: str,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
return_mm_hashes: bool = False,
|
||||
) -> SingletonInputs:
|
||||
"""Async version of :meth:`_extract_prompt_components`."""
|
||||
parsed = parse_singleton_prompt(prompt)
|
||||
@ -431,6 +433,7 @@ class InputPreprocessor:
|
||||
multi_modal_data,
|
||||
mm_processor_kwargs,
|
||||
lora_request=lora_request,
|
||||
return_mm_hashes=return_mm_hashes,
|
||||
)
|
||||
|
||||
return token_inputs(
|
||||
@ -452,6 +455,7 @@ class InputPreprocessor:
|
||||
multi_modal_data,
|
||||
mm_processor_kwargs,
|
||||
lora_request=lora_request,
|
||||
return_mm_hashes=return_mm_hashes,
|
||||
)
|
||||
|
||||
prompt_token_ids = await self._tokenize_prompt_async(
|
||||
@ -726,6 +730,7 @@ class InputPreprocessor:
|
||||
prompt,
|
||||
request_id=request_id,
|
||||
lora_request=lora_request,
|
||||
return_mm_hashes=return_mm_hashes,
|
||||
)
|
||||
|
||||
return self._build_decoder_only_llm_inputs(
|
||||
@ -746,6 +751,7 @@ class InputPreprocessor:
|
||||
prompt,
|
||||
request_id=request_id,
|
||||
lora_request=lora_request,
|
||||
return_mm_hashes=return_mm_hashes,
|
||||
)
|
||||
|
||||
return self._build_decoder_only_llm_inputs(
|
||||
|
||||
@ -52,7 +52,7 @@ class EngineCoreRequest(
|
||||
# Detokenizer, but set to None when it is added to EngineCoreClient.
|
||||
prompt: Optional[str]
|
||||
prompt_token_ids: list[int]
|
||||
mm_inputs: Optional[list[Optional[MultiModalKwargs]]]
|
||||
mm_inputs: Optional[list[MultiModalKwargs]]
|
||||
mm_hashes: Optional[list[str]]
|
||||
mm_placeholders: Optional[list[PlaceholderRange]]
|
||||
sampling_params: SamplingParams
|
||||
|
||||
@ -1,131 +1,30 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.envs import VLLM_MM_INPUT_CACHE_GIB
|
||||
from vllm.logger import init_logger
|
||||
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
|
||||
MultiModalKwargs, MultiModalRegistry)
|
||||
from vllm.multimodal import MultiModalKwargs
|
||||
from vllm.multimodal.processing import ProcessingCache
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# The idea of multimodal preprocessing caching is based on having a client and
|
||||
# a server, where the client executes in the frontend process (=P0) and the
|
||||
# server in the core process (=P1).
|
||||
#
|
||||
# -- Client:
|
||||
# - Apply legacy input_mapper (if one exists) to generate MultiModalKwargs.
|
||||
# - Perform caching of the generated MultiModalKwargs.
|
||||
# - This client can be deprecated once all mutimodal models migrate to use
|
||||
# merged preprocessor with built-in caching functionality.
|
||||
# - BaseMultiModalProcessor to process MultiModalData into MultiModalKwargs
|
||||
# with built-in caching functionality, with mm_hash as its identifier.
|
||||
#
|
||||
# -- Server:
|
||||
# - Perform caching of the received MultiModalKwargs.
|
||||
# - MMInputCacheServer to perform caching of the received MultiModalKwargs.
|
||||
#
|
||||
# The caching for both client and server is mirrored/similar, and this allows us
|
||||
# The caching for both client and server is mirrored, and this allows us
|
||||
# to avoid the serialization of "mm_inputs" (like pixel values) between
|
||||
# client (=P0) and server (=P1) processes.
|
||||
# client (=P0) and server (=P1) processes if the mm_hash is found in the client
|
||||
# cache.
|
||||
|
||||
# Both Client and Server must use the same cache size
|
||||
# (to perform mirrored caching). This cache size is set by the environment
|
||||
# variable VLLM_MM_INPUT_CACHE_GIB.
|
||||
|
||||
|
||||
# TODO(ywang96): Deprecate this class once all multimodal models migrate to use
|
||||
# merged preprocessor with built-in caching functionality.
|
||||
class MMInputCacheClient:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_config: ModelConfig,
|
||||
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
|
||||
):
|
||||
self.model_config = model_config
|
||||
self.mm_registry = mm_registry
|
||||
self.multi_modal_input_mapper = mm_registry.create_input_mapper(
|
||||
model_config)
|
||||
self.mm_registry.init_mm_limits_per_prompt(model_config)
|
||||
|
||||
# Init cache
|
||||
self.use_cache = not model_config.disable_mm_preprocessor_cache
|
||||
self.mm_cache = ProcessingCache.get_lru_cache(VLLM_MM_INPUT_CACHE_GIB,
|
||||
MultiModalKwargs)
|
||||
|
||||
# DEBUG: Set to None to disable
|
||||
self.mm_debug_cache_hit_ratio_steps = None
|
||||
self.mm_debug_cache_hits = 0
|
||||
self.mm_debug_cache_total = 0
|
||||
|
||||
def cache_hit_ratio(self, steps):
|
||||
total = self.mm_debug_cache_total
|
||||
|
||||
if total > 0 and total % steps == 0:
|
||||
logger.debug("MMInputMapper: cache_hit_ratio = %.2f ",
|
||||
self.mm_debug_cache_hits / total)
|
||||
|
||||
# NOTE: process_inputs only supports image inputs since all multimodal
|
||||
# models with other modalities have migrated to use merged preprocessor.
|
||||
def process_inputs(
|
||||
self,
|
||||
mm_data: MultiModalDataDict,
|
||||
mm_hashes: Optional[list[str]],
|
||||
mm_processor_kwargs: Optional[dict[str, Any]],
|
||||
precomputed_mm_inputs: Optional[list[MultiModalKwargs]],
|
||||
) -> list[Optional[MultiModalKwargs]]:
|
||||
if precomputed_mm_inputs is None:
|
||||
image_inputs = mm_data["image"]
|
||||
if not isinstance(image_inputs, list):
|
||||
image_inputs = [image_inputs]
|
||||
num_inputs = len(image_inputs)
|
||||
else:
|
||||
num_inputs = len(precomputed_mm_inputs)
|
||||
|
||||
# Sanity
|
||||
if self.use_cache:
|
||||
assert mm_hashes is not None
|
||||
assert num_inputs == len(mm_hashes)
|
||||
|
||||
# Process each image input separately, so that later we can schedule
|
||||
# them in a fine-grained manner.
|
||||
# Apply caching (if enabled) and reuse precomputed inputs (if provided)
|
||||
ret_inputs: list[Optional[MultiModalKwargs]] = []
|
||||
for input_id in range(num_inputs):
|
||||
if self.mm_debug_cache_hit_ratio_steps is not None:
|
||||
self.cache_hit_ratio(self.mm_debug_cache_hit_ratio_steps)
|
||||
|
||||
mm_input = None
|
||||
if self.use_cache:
|
||||
assert mm_hashes is not None
|
||||
mm_hash = mm_hashes[input_id]
|
||||
mm_input = self.mm_cache.get(mm_hash)
|
||||
|
||||
self.mm_debug_cache_total += 1
|
||||
if mm_input is None:
|
||||
if precomputed_mm_inputs is not None:
|
||||
# Reuse precomputed input (for merged preprocessor)
|
||||
mm_input = precomputed_mm_inputs[input_id]
|
||||
else:
|
||||
# Apply legacy input_mapper
|
||||
mm_input = self.multi_modal_input_mapper(
|
||||
{"image": [image_inputs[input_id]]},
|
||||
mm_processor_kwargs=mm_processor_kwargs,
|
||||
)
|
||||
|
||||
if self.use_cache:
|
||||
# Add to cache
|
||||
assert mm_hash is not None
|
||||
self.mm_cache[mm_hash] = mm_input
|
||||
else:
|
||||
self.mm_debug_cache_hits += 1
|
||||
mm_input = None # Avoids sending mm_input to Server
|
||||
|
||||
ret_inputs.append(mm_input)
|
||||
|
||||
return ret_inputs
|
||||
|
||||
|
||||
class MMInputCacheServer:
|
||||
|
||||
def __init__(self, model_config):
|
||||
@ -135,9 +34,9 @@ class MMInputCacheServer:
|
||||
|
||||
def get_and_update(
|
||||
self,
|
||||
mm_inputs: list[Optional[MultiModalKwargs]],
|
||||
mm_inputs: list[MultiModalKwargs],
|
||||
mm_hashes: list[str],
|
||||
) -> list[Optional[MultiModalKwargs]]:
|
||||
) -> list[MultiModalKwargs]:
|
||||
assert len(mm_inputs) == len(mm_hashes)
|
||||
|
||||
if not self.use_cache:
|
||||
@ -147,8 +46,7 @@ class MMInputCacheServer:
|
||||
for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
|
||||
assert mm_hash is not None
|
||||
if mm_input is None:
|
||||
mm_input = self.mm_cache.get(mm_hash)
|
||||
assert mm_input is not None
|
||||
mm_input = self.mm_cache[mm_hash]
|
||||
else:
|
||||
self.mm_cache[mm_hash] = mm_input
|
||||
|
||||
|
||||
@ -11,15 +11,15 @@ from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
|
||||
from vllm.inputs.parse import is_encoder_decoder_inputs
|
||||
from vllm.inputs.preprocess import InputPreprocessor
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalHasher,
|
||||
MultiModalKwargs, MultiModalRegistry)
|
||||
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
|
||||
MultiModalRegistry)
|
||||
from vllm.multimodal.inputs import PlaceholderRange
|
||||
from vllm.multimodal.utils import merge_and_sort_multimodal_metadata
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
from vllm.v1.engine.mm_input_cache import MMInputCacheClient
|
||||
from vllm.v1.structured_output.utils import validate_structured_output_request
|
||||
|
||||
|
||||
@ -45,11 +45,6 @@ class Processor:
|
||||
self.input_preprocessor = InputPreprocessor(self.model_config,
|
||||
self.tokenizer,
|
||||
mm_registry)
|
||||
self.input_processor = input_registry.create_input_processor(
|
||||
self.model_config)
|
||||
|
||||
# Multi-modal (huggingface) input mapper
|
||||
self.mm_input_cache_client = MMInputCacheClient(self.model_config)
|
||||
|
||||
# Multi-modal hasher (for images)
|
||||
self.use_hash = (
|
||||
@ -171,7 +166,7 @@ class Processor:
|
||||
# 2. For multimodal models with a merged preprocessor, preprocess
|
||||
# multimodal data and expand prompt token ids accordingly.
|
||||
# 3. Apply prompt adapter to prompt token ids if one exists.
|
||||
preprocessed_inputs = self.input_preprocessor.preprocess(
|
||||
processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess(
|
||||
prompt,
|
||||
request_id=request_id,
|
||||
lora_request=lora_request,
|
||||
@ -180,10 +175,6 @@ class Processor:
|
||||
)
|
||||
eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)
|
||||
|
||||
# Process prompt and prompt token ids.
|
||||
# Only applicable to multimodal models with legacy input processor.
|
||||
processed_inputs = self.input_processor(preprocessed_inputs)
|
||||
|
||||
self._validate_model_inputs(processed_inputs, lora_request)
|
||||
|
||||
if is_encoder_decoder_inputs(processed_inputs):
|
||||
@ -212,36 +203,22 @@ class Processor:
|
||||
self.tokenizer.get_lora_tokenizer(lora_request))
|
||||
|
||||
# Multimodal related.
|
||||
# Compute MM hashes (if enabled)
|
||||
mm_hashes = None
|
||||
if self.use_hash:
|
||||
# Use mm_hashes from processed inputs if the model has merged
|
||||
# input processor.
|
||||
if decoder_inputs.multi_modal_hashes:
|
||||
mm_hashes = decoder_inputs.multi_modal_hashes
|
||||
# Fallback to using MultiModalHasher directly.
|
||||
else:
|
||||
mm_hashes = MultiModalHasher.hash_prompt_mm_data(prompt)
|
||||
sorted_mm_inputs: Optional[list[MultiModalKwargs]] = None
|
||||
sorted_mm_positions: Optional[list[PlaceholderRange]] = None
|
||||
sorted_mm_hashes: Optional[list[str]] = None
|
||||
if (decoder_mm_inputs := decoder_inputs.multi_modal_data):
|
||||
assert isinstance(decoder_mm_inputs, MultiModalKwargs)
|
||||
|
||||
# For merged preprocessor, mm_data is already mm_inputs
|
||||
precomputed_mm_inputs: Optional[list[MultiModalKwargs]] = None
|
||||
decoder_mm_data = decoder_inputs.multi_modal_data
|
||||
if isinstance(decoder_mm_data, MultiModalKwargs):
|
||||
# The output of merged multi-modal processor (`decoder_mm_data`)
|
||||
# The output of merged multi-modal processor (`decoder_mm_inputs`)
|
||||
# contains the kwargs for all items from all modalities.
|
||||
# This code separates them so that there is one set of kwargs
|
||||
# per item per modality.
|
||||
precomputed_mm_inputs = [
|
||||
individual_mm_inputs = [
|
||||
MultiModalKwargs.from_items([item])
|
||||
for modality in decoder_mm_data.modalities
|
||||
for item in decoder_mm_data.get_items(modality)
|
||||
for modality in decoder_mm_inputs.modalities
|
||||
for item in decoder_mm_inputs.get_items(modality)
|
||||
]
|
||||
|
||||
mm_positions = decoder_inputs.multi_modal_placeholders
|
||||
|
||||
# Last-mile processing of multimodal metadata and inputs.
|
||||
if mm_positions:
|
||||
|
||||
# Merge and flatten multimodal placeholders, hashes and inputs
|
||||
# from dictionaries to lists, and sort them by each item's position
|
||||
# in the input sequence.
|
||||
@ -251,14 +228,13 @@ class Processor:
|
||||
sorted_mm_positions,
|
||||
sorted_mm_hashes,
|
||||
) = merge_and_sort_multimodal_metadata(
|
||||
mm_positions,
|
||||
mm_hashes,
|
||||
decoder_inputs.multi_modal_placeholders,
|
||||
decoder_inputs.multi_modal_hashes if self.use_hash else None,
|
||||
)
|
||||
|
||||
# NOTE: Sort multimodal inputs/kwargs ONLY IF there are multiple
|
||||
# modalities involved AND the model supports merged input processor.
|
||||
if len(sorted_modalities) > 1 and precomputed_mm_inputs:
|
||||
|
||||
# modalities involved.
|
||||
if len(sorted_modalities) > 1:
|
||||
modality_order_dict = {
|
||||
modality: order
|
||||
for order, modality in enumerate(sorted_modalities)
|
||||
@ -266,26 +242,16 @@ class Processor:
|
||||
|
||||
# Sanity check to make sure each multimodal input has only one
|
||||
# modality key.
|
||||
for mm_input in precomputed_mm_inputs:
|
||||
for mm_input in individual_mm_inputs:
|
||||
assert len(mm_input.modalities) == 1
|
||||
|
||||
# Sort MultiModalKwags to match sorted_mm_positions
|
||||
precomputed_mm_inputs = sorted(
|
||||
precomputed_mm_inputs,
|
||||
# Sort MultiModalKwargs to match sorted_mm_positions
|
||||
sorted_mm_inputs = sorted(
|
||||
individual_mm_inputs,
|
||||
key=lambda mm_input: modality_order_dict[list(
|
||||
mm_input.modalities)[0]])
|
||||
|
||||
# Apply mm input cache update and legacy input mapper if one exists.
|
||||
sorted_mm_inputs = self.mm_input_cache_client.process_inputs(
|
||||
mm_data=decoder_mm_data,
|
||||
mm_hashes=sorted_mm_hashes,
|
||||
mm_processor_kwargs=decoder_inputs.mm_processor_kwargs,
|
||||
precomputed_mm_inputs=precomputed_mm_inputs,
|
||||
)
|
||||
else:
|
||||
sorted_mm_inputs = None
|
||||
sorted_mm_hashes = None
|
||||
sorted_mm_positions = None
|
||||
else:
|
||||
sorted_mm_inputs = individual_mm_inputs
|
||||
|
||||
return EngineCoreRequest(
|
||||
request_id=request_id,
|
||||
|
||||
@ -29,7 +29,6 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
|
||||
is_pin_memory_available)
|
||||
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
|
||||
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
|
||||
from vllm.v1.engine.mm_input_cache import MMInputCacheClient
|
||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||
KVCacheSpec)
|
||||
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
|
||||
@ -133,14 +132,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.mm_registry = MULTIMODAL_REGISTRY
|
||||
self.uses_mrope = model_config.uses_mrope
|
||||
|
||||
if self.is_multimodal_model:
|
||||
# NOTE: Initialized client is only used for processing dummy
|
||||
# multimodal data into multimodal kwargs for GPU memory profiling.
|
||||
# Only applicable to multimodal models with legacy input mapper.
|
||||
self.mm_input_mapper_profiling = MMInputCacheClient(
|
||||
self.model_config)
|
||||
self.mm_input_mapper_profiling.use_cache = False
|
||||
|
||||
encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
|
||||
model_config=model_config,
|
||||
scheduler_config=scheduler_config,
|
||||
@ -1376,32 +1367,18 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
mm_registry=self.mm_registry,
|
||||
)
|
||||
dummy_mm_data = dummy_request_data.multi_modal_data
|
||||
if not isinstance(dummy_mm_data, MultiModalKwargs):
|
||||
# TODO: Delete this check once input mapper is fully removed.
|
||||
raise RuntimeError(
|
||||
"Legacy input mapper is not supported in V1")
|
||||
|
||||
# Dummy data definition in V0 may contain multiple multimodal items
|
||||
# Dummy data definition may contain multiple multimodal items
|
||||
# (e.g, multiple images) for a single request, therefore here we
|
||||
# always replicate first item by max_num_mm_items times since in V1
|
||||
# they are scheduled to be processed separately.
|
||||
|
||||
# Case when models have a merged processor, their dummy data is
|
||||
# already batched `MultiModalKwargs`, therefore we take the first
|
||||
# `MultiModalKwargsItem` from the desired modality to profile on.
|
||||
if isinstance(dummy_mm_data, MultiModalKwargs):
|
||||
dummy_mm_item = dummy_mm_data.get_item(
|
||||
modality=dummy_data_modality, item_index=0)
|
||||
dummy_mm_kwargs = MultiModalKwargs.from_items([dummy_mm_item])
|
||||
|
||||
# Case when models have dummy data explicitly defined as
|
||||
# `MultiModalDataDict`, so they need to be processed through input
|
||||
# mapper.
|
||||
# TODO (ywang96): deprecate this path once merged processor is
|
||||
# supported on all models.
|
||||
else:
|
||||
mm_kwargs_list = self.mm_input_mapper_profiling.process_inputs(
|
||||
mm_data=dummy_mm_data,
|
||||
mm_hashes=None,
|
||||
mm_processor_kwargs=None,
|
||||
precomputed_mm_inputs=None)
|
||||
dummy_mm_kwargs = mm_kwargs_list[0]
|
||||
dummy_mm_item = dummy_mm_data.get_item(
|
||||
modality=dummy_data_modality, item_index=0)
|
||||
dummy_mm_kwargs = MultiModalKwargs.from_items([dummy_mm_item])
|
||||
|
||||
batched_dummy_mm_inputs = MultiModalKwargs.batch(
|
||||
[dummy_mm_kwargs] * max_num_mm_items)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user