# vllm/vllm/v1/engine/mm_input_mapper.py

from typing import Any, Dict, List, Optional

import PIL.Image
from blake3 import blake3

from vllm.config import ModelConfig
from vllm.inputs import PromptType
from vllm.logger import init_logger
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
                             MultiModalKwargs, MultiModalRegistry)
from vllm.v1.utils import LRUDictCache

logger = init_logger(__name__)

# The idea of MM preprocessor caching is based on having a client and a server,
# where the client executes in the frontend process (=P0) and the server in the
# core process (=P1).
#
# -- Client: Executes the MM mapper and performs caching of the results.
#
# -- Server: Performs caching of the results.
#
# The caching for both client and server is mirrored, which is what allows us
# to avoid serializing "mm_inputs" (like pixel values) between the client (=P0)
# and server (=P1) processes: on a hit, only the hash travels over the wire.
#
# Both client and server must use the same cache size (otherwise the mirrored
# caches fall out of sync and the server may miss an entry the client assumed
# it had).

# TODO: Tune the MM cache size
MM_CACHE_SIZE = 256
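
# A minimal sketch of the intended round trip (hypothetical wiring; the real
# plumbing lives in the P0 frontend and P1 core engine processes):
#
#     client = MMInputMapperClient(model_config)   # runs in P0
#     server = MMInputMapperServer(model_config)   # runs in P1
#     mm_hashes = MMHasher().hash_prompt_mm_data(prompt)
#     wire_inputs = client.process_inputs(mm_data, mm_hashes, None, None)
#     mm_inputs = server.process_inputs(wire_inputs, mm_hashes)
#
# On a cache hit the client puts None on the wire and the server restores the
# kwargs from its own mirrored cache, so pixel data for a given image is
# serialized at most once while its entry stays resident in the LRU cache.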


class MMInputMapperClient:

    def __init__(
        self,
        model_config: ModelConfig,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
    ):
        self.model_config = model_config
        self.mm_registry = mm_registry
        self.multi_modal_input_mapper = mm_registry.create_input_mapper(
            model_config)
        self.mm_registry.init_mm_limits_per_prompt(model_config)

        # Init cache
        self.use_cache = not model_config.disable_mm_preprocessor_cache
        self.mm_cache = LRUDictCache[str, MultiModalKwargs](MM_CACHE_SIZE)

        # DEBUG: Set to None to disable
        self.mm_debug_cache_hit_ratio_steps = None
        self.mm_cache_hits = 0
        self.mm_cache_total = 0

    def cache_hit_ratio(self, steps):
        if self.mm_cache_total > 0 and self.mm_cache_total % steps == 0:
            logger.debug("MMInputMapper: cache_hit_ratio = %.2f",
                         self.mm_cache_hits / self.mm_cache_total)
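
    # Debugging aid, not a public API: to sample the hit ratio during
    # development, set the interval in __init__, e.g.
    #
    #     self.mm_debug_cache_hit_ratio_steps = 32  # log every 32 lookups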

    # TODO: Support modalities beyond image.
    def process_inputs(
        self,
        mm_data: MultiModalDataDict,
        mm_hashes: Optional[List[str]],
        mm_processor_kwargs: Optional[Dict[str, Any]],
        precomputed_mm_inputs: Optional[List[MultiModalKwargs]],
    ) -> List[Optional[MultiModalKwargs]]:
        if precomputed_mm_inputs is None:
            image_inputs = mm_data["image"]
            if not isinstance(image_inputs, list):
                image_inputs = [image_inputs]
            num_inputs = len(image_inputs)
        else:
            num_inputs = len(precomputed_mm_inputs)

        # Sanity-check the hashes when caching is enabled.
        if self.use_cache:
            assert mm_hashes is not None
            assert num_inputs == len(mm_hashes)

        # Process each image input separately, so that later we can schedule
        # them in a fine-grained manner.
        # Apply caching (if enabled) and reuse precomputed inputs (if provided).
        ret_inputs: List[Optional[MultiModalKwargs]] = []
        for input_id in range(num_inputs):
            if self.mm_debug_cache_hit_ratio_steps is not None:
                self.cache_hit_ratio(self.mm_debug_cache_hit_ratio_steps)

            mm_hash = None
            mm_input = None
            if self.use_cache:
                assert mm_hashes is not None
                mm_hash = mm_hashes[input_id]
                mm_input = self.mm_cache.get(mm_hash)
                self.mm_cache_total += 1

            if mm_input is None:
                if precomputed_mm_inputs is not None:
                    # Reuse precomputed input (for merged preprocessor).
                    mm_input = precomputed_mm_inputs[input_id]
                else:
                    # Apply the MM mapper.
                    mm_input = self.multi_modal_input_mapper(
                        {"image": [image_inputs[input_id]]},
                        mm_processor_kwargs=mm_processor_kwargs,
                    )

                if self.use_cache:
                    # Add to cache.
                    assert mm_hash is not None
                    self.mm_cache.put(mm_hash, mm_input)
            else:
                self.mm_cache_hits += 1
                mm_input = None  # Avoids sending mm_input to the server.

            ret_inputs.append(mm_input)

        return ret_inputs
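
    # Illustration (hypothetical values) of what the client returns when
    # caching is enabled:
    #
    #     [<MultiModalKwargs>, None, <MultiModalKwargs>]
    #
    # The None marks a cache hit for the second input; MMInputMapperServer
    # below is expected to fill it back in from its mirrored cache.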


class MMInputMapperServer:

    def __init__(self, model_config: ModelConfig):
        self.use_cache = not model_config.disable_mm_preprocessor_cache
        self.mm_cache = LRUDictCache[str, MultiModalKwargs](MM_CACHE_SIZE)

    def process_inputs(
        self,
        mm_inputs: List[Optional[MultiModalKwargs]],
        mm_hashes: List[str],
    ) -> List[MultiModalKwargs]:
        assert len(mm_inputs) == len(mm_hashes)

        if not self.use_cache:
            return mm_inputs

        full_mm_inputs = []
        for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
            assert mm_hash is not None
            if mm_input is None:
                # The client had a cache hit: restore from the mirrored cache.
                mm_input = self.mm_cache.get(mm_hash)
                assert mm_input is not None
            else:
                self.mm_cache.put(mm_hash, mm_input)

            full_mm_inputs.append(mm_input)

        return full_mm_inputs


class MMHasher:

    def __init__(self):
        pass

    def hash_dummy_mm_data(
            self,
            mm_data: Optional[MultiModalDataDict]) -> Optional[List[str]]:
        """Hash user-defined dummy multimodal data used for profiling."""
        if mm_data is None:
            return None

        image_inputs = mm_data["image"]

        # This is a temporary workaround for models (e.g., Molmo) that
        # process multimodal data in the input processor (therefore
        # image_inputs is MultiModalKwargs instead of the raw input format).
        # `raw_mm_data` with the original input format is expected
        # in this case.
        if isinstance(image_inputs, dict):
            assert "raw_mm_data" in image_inputs and isinstance(
                image_inputs["raw_mm_data"], PIL.Image.Image)
            image_inputs = image_inputs.pop("raw_mm_data")

        return self.hash_images(image_inputs)
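
    # For such models the dummy data is expected to carry the raw image
    # alongside the precomputed kwargs (hypothetical shape, mirroring the
    # workaround above):
    #
    #     mm_data = {"image": {"raw_mm_data": <PIL.Image.Image>, ...}}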

    def hash_prompt_mm_data(self, prompt: PromptType) -> Optional[List[str]]:
        """Hash multimodal data in the user input prompt if they exist."""
        if "multi_modal_data" not in prompt:
            return None

        mm_data = prompt["multi_modal_data"]
        image_inputs = mm_data["image"]

        return self.hash_images(image_inputs)

    def hash_images(self, image_inputs) -> List[str]:
        """Hash PIL image objects to hex digest strings."""
        if not isinstance(image_inputs, list):
            image_inputs = [image_inputs]
        assert len(image_inputs) > 0

        ret = []
        for image in image_inputs:
            assert isinstance(image, PIL.Image.Image)

            # Convert the image to bytes (named to avoid shadowing the
            # `bytes` builtin).
            image_bytes = image.tobytes()

            # Hash the image bytes with BLAKE3.
            hasher = blake3()
            hasher.update(image_bytes)
            ret.append(hasher.hexdigest())

        return ret
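

# A minimal, runnable smoke test for MMHasher (illustrative only; it relies
# on nothing beyond the Pillow and blake3 imports already used above).
# Identical pixel data must hash identically, which is what keeps the client
# and server caches keyed consistently.
if __name__ == "__main__":
    mm_hasher = MMHasher()
    image_a = PIL.Image.new("RGB", (8, 8), color=(255, 0, 0))
    image_b = PIL.Image.new("RGB", (8, 8), color=(255, 0, 0))
    digest_a, digest_b = mm_hasher.hash_images([image_a, image_b])
    assert digest_a == digest_b  # same pixels -> same blake3 digest
    print("blake3 digest:", digest_a)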