[V1] Clarify input processing and multimodal feature caching logic (#13211)

Author: Roger Wang (committed by GitHub)
Date:   2025-02-13 03:43:24 -08:00
Commit: fdcf64d3c6
Parent: 578087e56c
4 changed files with 46 additions and 28 deletions

vllm/v1/engine/core.py

@@ -20,7 +20,7 @@ from vllm.v1.core.kv_cache_utils import get_kv_cache_configs
 from vllm.v1.core.scheduler import Scheduler
 from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
                             EngineCoreRequestType)
-from vllm.v1.engine.mm_input_mapper import MMInputMapperServer
+from vllm.v1.engine.mm_input_cache import MMInputCacheServer
 from vllm.v1.executor.abstract import Executor
 from vllm.v1.request import Request, RequestStatus
 from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
@@ -65,7 +65,7 @@ class EngineCore:
             log_stats=self.log_stats,
         )
 
-        self.mm_input_mapper_server = MMInputMapperServer(
+        self.mm_input_cache_server = MMInputCacheServer(
             vllm_config.model_config)
 
     def _initialize_kv_caches(self,
@@ -102,13 +102,13 @@ class EngineCore:
         """Add request to the scheduler."""
         if request.mm_hashes is not None:
-            # Here, if hash exists for an image, then it will be fetched
-            # from the cache, else it will be added to the cache.
-            # Note that the cache here is mirrored with the client side of the
-            # MM mapper, so anything that has a hash must have a HIT cache
-            # entry here as well.
+            # Here, if hash exists for a multimodal input, then it will be
+            # fetched from the cache, else it will be added to the cache.
+            # Note that the cache here is mirrored with the client cache, so
+            # anything that has a hash must have a HIT cache entry here
+            # as well.
             assert request.mm_inputs is not None
-            request.mm_inputs = self.mm_input_mapper_server.process_inputs(
+            request.mm_inputs = self.mm_input_cache_server.get_and_update(
                 request.mm_inputs, request.mm_hashes)
 
         req = Request.from_engine_core_request(request)

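To make the mirrored-cache contract in add_request concrete, here is a minimal, self-contained sketch, not the vLLM implementation, of what the commit expects MMInputCacheServer.get_and_update to do: entries the client already cached arrive as None and are recovered from the server cache, while new entries arrive fully populated and are stored for later hits. The ToyMMInputCacheServer class and the plain dict standing in for vLLM's LRUCache are illustrative assumptions.

from typing import Dict, List, Optional

MultiModalKwargsLike = dict  # stand-in for vllm.multimodal.MultiModalKwargs


class ToyMMInputCacheServer:
    """Illustrative stand-in for the server-side cache in the core process (P1)."""

    def __init__(self, use_cache: bool = True):
        self.use_cache = use_cache
        # A plain dict stands in for the bounded LRUCache used in vLLM.
        self.mm_cache: Dict[str, MultiModalKwargsLike] = {}

    def get_and_update(
        self,
        mm_inputs: List[Optional[MultiModalKwargsLike]],
        mm_hashes: List[str],
    ) -> List[MultiModalKwargsLike]:
        assert len(mm_inputs) == len(mm_hashes)
        full_inputs = []
        for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
            if self.use_cache and mm_input is None:
                # Client cache hit: the payload was never serialized over to
                # P1, so the mirrored server cache must already hold it.
                mm_input = self.mm_cache[mm_hash]
            elif self.use_cache:
                # Client cache miss: remember the payload for future requests.
                self.mm_cache[mm_hash] = mm_input
            full_inputs.append(mm_input)
        return full_inputs
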
vllm/v1/engine/mm_input_cache.py (renamed from mm_input_mapper.py)

@@ -10,12 +10,18 @@ from vllm.utils import LRUCache
 logger = init_logger(__name__)
 
-# The idea of MM preprocessor caching is based on having a client and a server,
-# where the client executes in the frontend process (=P0) and the server in the
-# core process (=P1).
+# The idea of multimodal preprocessing caching is based on having a client and
+# a server, where the client executes in the frontend process (=P0) and the
+# server in the core process (=P1).
 #
-# -- Client: Executes the MM mapper and performs caching of the results.
-# -- Server: Performs caching of the results
+# -- Client:
+#   - Apply legacy input_mapper (if one exists) to generate MultiModalKwargs.
+#   - Perform caching of the generated MultiModalKwargs.
+#   - This client can be deprecated once all multimodal models migrate to use
+#     merged preprocessor with built-in caching functionality.
+#
+# -- Server:
+#   - Perform caching of the received MultiModalKwargs.
 #
 # The caching for both client and server is mirrored/similar, and this allows us
 # to avoid the serialization of "mm_inputs" (like pixel values) between
@@ -27,7 +33,9 @@ logger = init_logger(__name__)
 MM_CACHE_SIZE = 256
 
 
-class MMInputMapperClient:
+# TODO(ywang96): Deprecate this class once all multimodal models migrate to use
+# merged preprocessor with built-in caching functionality.
+class MMInputCacheClient:
 
     def __init__(
         self,
@@ -54,7 +62,8 @@ class MMInputMapperClient:
            logger.debug("MMInputMapper: cache_hit_ratio = %.2f ",
                         self.mm_cache_hits / self.mm_cache_total)
 
-    # TODO: Support modalities beyond image.
+    # NOTE: process_inputs only supports image inputs since all multimodal
+    # models with other modalities have migrated to use merged preprocessor.
     def process_inputs(
         self,
         mm_data: MultiModalDataDict,
@@ -95,7 +104,7 @@ class MMInputMapperClient:
                 # Reuse precomputed input (for merged preprocessor)
                 mm_input = precomputed_mm_inputs[input_id]
             else:
-                # Apply MM mapper
+                # Apply legacy input_mapper
                 mm_input = self.multi_modal_input_mapper(
                     {"image": [image_inputs[input_id]]},
                     mm_processor_kwargs=mm_processor_kwargs,
@@ -114,13 +123,13 @@ class MMInputMapperClient:
         return ret_inputs
 
 
-class MMInputMapperServer:
+class MMInputCacheServer:
 
     def __init__(self, model_config):
        self.use_cache = not model_config.disable_mm_preprocessor_cache
        self.mm_cache = LRUCache[str, MultiModalKwargs](MM_CACHE_SIZE)
 
-    def process_inputs(
+    def get_and_update(
         self,
         mm_inputs: List[Optional[MultiModalKwargs]],
         mm_hashes: List[str],

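The client half of the protocol described in the comment block above can be sketched the same way. This is a hedged illustration rather than vLLM code: the hashing and mapper calls are placeholders, and the point is only the hit/miss bookkeeping that lets P0 skip re-serializing large payloads (such as pixel values) that P1 already holds.

import hashlib
import pickle
from typing import Any, Callable, Dict, List, Optional, Tuple


class ToyMMInputCacheClient:
    """Illustrative stand-in for the frontend-side (P0) cache."""

    def __init__(self, use_cache: bool = True):
        self.use_cache = use_cache
        self.mm_cache: Dict[str, dict] = {}  # stand-in for the LRU cache

    @staticmethod
    def _hash(item: Any) -> str:
        # Placeholder for vLLM's multimodal hasher.
        return hashlib.sha256(pickle.dumps(item)).hexdigest()

    def process_inputs(
        self,
        items: List[Any],
        mapper: Callable[[Any], dict],
    ) -> Tuple[List[Optional[dict]], List[str]]:
        """Return (payloads, hashes); a None payload marks a cache hit whose
        kwargs the mirrored server cache already contains."""
        payloads: List[Optional[dict]] = []
        hashes: List[str] = []
        for item in items:
            mm_hash = self._hash(item)
            hashes.append(mm_hash)
            if self.use_cache and mm_hash in self.mm_cache:
                payloads.append(None)      # hit: nothing to serialize to P1
                continue
            kwargs = mapper(item)          # legacy input_mapper placeholder
            if self.use_cache:
                self.mm_cache[mm_hash] = kwargs
            payloads.append(kwargs)        # miss: send the payload once
        return payloads, hashes

Paired with the server sketch shown after core.py, a client hit becomes a dictionary lookup on the server, so the expensive kwargs cross the process boundary at most once per hash.
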
vllm/v1/engine/processor.py

@@ -17,7 +17,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
 from vllm.v1.engine import EngineCoreRequest
-from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
+from vllm.v1.engine.mm_input_cache import MMInputCacheClient
 
 
 class Processor:
@@ -46,7 +46,7 @@ class Processor:
             model_config)
 
         # Multi-modal (huggingface) input mapper
-        self.mm_input_mapper_client = MMInputMapperClient(model_config)
+        self.mm_input_cache_client = MMInputCacheClient(model_config)
 
         # Multi-modal hasher (for images)
         self.use_hash = (not model_config.disable_mm_preprocessor_cache) or \
@@ -106,17 +106,25 @@ class Processor:
         assert priority == 0, "vLLM V1 does not support priority at the moment."
         assert trace_headers is None, "vLLM V1 does not support tracing yet."
 
-        # Process inputs.
+        # Process inputs, which includes:
+        # 1. Tokenize text prompt, with LoRA request if one exists.
+        # 2. For multimodal models with a merged preprocessor, preprocess
+        #    multimodal data and expand prompt token ids accordingly.
+        # 3. Apply prompt adapter to prompt token ids if one exists.
         preprocessed_inputs = self.input_preprocessor.preprocess(
             prompt,
             request_id=request_id,
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request,
         )
-        processed_inputs = self.input_processor(preprocessed_inputs)
-        self._validate_model_inputs(processed_inputs)
+
         eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)
 
+        # Process prompt and prompt token ids.
+        # Only applicable to multimodal models with legacy input processor.
+        processed_inputs = self.input_processor(preprocessed_inputs)
+        self._validate_model_inputs(processed_inputs)
+
         if is_encoder_decoder_inputs(processed_inputs):
             decoder_inputs = SingletonInputsAdapter(
                 processed_inputs["decoder"])
@@ -200,8 +208,8 @@ class Processor:
                 key=lambda mm_input: modality_order_dict[list(
                     mm_input.modalities)[0]])
 
-        # Apply mm input cache update (and input mapper if necessary).
-        sorted_mm_inputs = self.mm_input_mapper_client.process_inputs(
+        # Apply mm input cache update and legacy input mapper if one exists.
+        sorted_mm_inputs = self.mm_input_cache_client.process_inputs(
             mm_data=decoder_mm_data,
             mm_hashes=sorted_mm_hashes,
             mm_processor_kwargs=decoder_inputs.mm_processor_kwargs,

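The three numbered steps added to the comment, plus the relocated legacy input processor and validation, can be read as a small pipeline. The following is a toy outline under stated assumptions, not the Processor code: every helper name here (process_request, legacy_process, validate, prompt_adapter_prefix) is a placeholder, and only the ordering mirrors the diff.

from typing import List, Optional


def legacy_process(token_ids: List[int]) -> List[int]:
    # No-op for models without a legacy input processor.
    return token_ids


def validate(token_ids: List[int]) -> None:
    assert token_ids, "prompt produced no tokens"


def process_request(prompt: str,
                    mm_data: Optional[dict] = None,
                    prompt_adapter_prefix: Optional[List[int]] = None) -> List[int]:
    # 1. Tokenize the text prompt (with the LoRA tokenizer if one exists).
    token_ids = [ord(ch) for ch in prompt]  # toy tokenizer

    # 2. For multimodal models with a merged preprocessor, multimodal data
    #    would be processed here and placeholder tokens spliced into token_ids.
    if mm_data is not None:
        token_ids = [0] * len(mm_data) + token_ids  # toy placeholder expansion

    # 3. Apply the prompt adapter to the token ids if one exists.
    if prompt_adapter_prefix is not None:
        token_ids = prompt_adapter_prefix + token_ids

    # Legacy input processor runs after preprocessing (only multimodal models
    # that have not migrated do real work here), followed by validation.
    token_ids = legacy_process(token_ids)
    validate(token_ids)
    return token_ids


print(process_request("hello", mm_data={"image": 1}, prompt_adapter_prefix=[7]))
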
vllm/v1/worker/gpu_model_runner.py

@@ -27,7 +27,7 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
 from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
                                                     FlashAttentionMetadata)
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
-from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
+from vllm.v1.engine.mm_input_cache import MMInputCacheClient
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
 from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput
@@ -95,9 +95,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         self.mm_registry = MULTIMODAL_REGISTRY
         self.uses_mrope = model_config.uses_mrope
 
-        # NOTE: Initialized input mapper is only used for processing dummy
+        # NOTE: Initialized client is only used for processing dummy
         # multimodal data into multimodal kwargs for GPU memory profiling.
-        self.mm_input_mapper_profiling = MMInputMapperClient(self.model_config)
+        # Only applicable to multimodal models with legacy input mapper.
+        self.mm_input_mapper_profiling = MMInputCacheClient(self.model_config)
         self.mm_input_mapper_profiling.use_cache = False
 
         encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
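
The profiling-only usage in this hunk boils down to constructing the client and turning its cache off, so memory profiling always pays the full cost of mapping dummy multimodal data instead of reusing kwargs from an earlier dummy batch. Below is a small sketch that assumes only the constructor and use_cache attribute visible in the diff; make_profiling_client is a hypothetical helper, not a vLLM API.

from vllm.v1.engine.mm_input_cache import MMInputCacheClient


def make_profiling_client(model_config):
    # Hypothetical helper mirroring what GPUModelRunner does inline.
    client = MMInputCacheClient(model_config)
    # Disable caching so profiling exercises the full input-mapper path and
    # measures worst-case GPU memory usage.
    client.use_cache = False
    return client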