[V1] VLM preprocessor hashing (#11020)

Signed-off-by: Roger Wang <ywang@roblox.com>
Signed-off-by: Alexander Matveev <alexm@neuralmagic.com>
Co-authored-by: Michael Goin <michael@neuralmagic.com>
Co-authored-by: Roger Wang <ywang@roblox.com>
Authored by Alexander Matveev on 2024-12-11 19:55:30 -05:00; committed via GitHub
parent 452a723bf2
commit 4e11683368
11 changed files with 332 additions and 48 deletions

View File

@@ -5,6 +5,8 @@ the correct prompt format on vision language models for text generation.
 For most models, the prompt format should follow corresponding examples
 on HuggingFace model repository.
 """
+import random
 from transformers import AutoTokenizer
 from vllm import LLM, SamplingParams
@@ -23,7 +25,9 @@ def run_llava(question: str, modality: str):
     prompt = f"USER: <image>\n{question}\nASSISTANT:"
-    llm = LLM(model="llava-hf/llava-1.5-7b-hf", max_model_len=4096)
+    llm = LLM(model="llava-hf/llava-1.5-7b-hf",
+              max_model_len=4096,
+              mm_cache_preprocessor=args.mm_cache_preprocessor)
     stop_token_ids = None
     return llm, prompt, stop_token_ids
@@ -33,7 +37,9 @@ def run_llava_next(question: str, modality: str):
     assert modality == "image"
     prompt = f"[INST] <image>\n{question} [/INST]"
-    llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf", max_model_len=8192)
+    llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf",
+              max_model_len=8192,
+              mm_cache_preprocessor=args.mm_cache_preprocessor)
     stop_token_ids = None
     return llm, prompt, stop_token_ids
@@ -44,7 +50,9 @@ def run_llava_next_video(question: str, modality: str):
     assert modality == "video"
     prompt = f"USER: <video>\n{question} ASSISTANT:"
-    llm = LLM(model="llava-hf/LLaVA-NeXT-Video-7B-hf", max_model_len=8192)
+    llm = LLM(model="llava-hf/LLaVA-NeXT-Video-7B-hf",
+              max_model_len=8192,
+              mm_cache_preprocessor=args.mm_cache_preprocessor)
     stop_token_ids = None
     return llm, prompt, stop_token_ids
@@ -61,7 +69,8 @@ def run_llava_onevision(question: str, modality: str):
         <|im_start|>assistant\n"
     llm = LLM(model="llava-hf/llava-onevision-qwen2-7b-ov-hf",
-              max_model_len=16384)
+              max_model_len=16384,
+              mm_cache_preprocessor=args.mm_cache_preprocessor)
     stop_token_ids = None
     return llm, prompt, stop_token_ids
@@ -71,7 +80,10 @@ def run_fuyu(question: str, modality: str):
     assert modality == "image"
     prompt = f"{question}\n"
-    llm = LLM(model="adept/fuyu-8b", max_model_len=2048, max_num_seqs=2)
+    llm = LLM(model="adept/fuyu-8b",
+              max_model_len=2048,
+              max_num_seqs=2,
+              mm_cache_preprocessor=args.mm_cache_preprocessor)
     stop_token_ids = None
     return llm, prompt, stop_token_ids
@@ -107,6 +119,7 @@ def run_phi3v(question: str, modality: str):
         max_num_seqs=2,
         # Note - mm_processor_kwargs can also be passed to generate/chat calls
         mm_processor_kwargs={"num_crops": 16},
+        mm_cache_preprocessor=args.mm_cache_preprocessor,
     )
     stop_token_ids = None
     return llm, prompt, stop_token_ids
@@ -118,7 +131,8 @@ def run_paligemma(question: str, modality: str):
     # PaliGemma has special prompt format for VQA
     prompt = "caption en"
-    llm = LLM(model="google/paligemma-3b-mix-224")
+    llm = LLM(model="google/paligemma-3b-mix-224",
+              mm_cache_preprocessor=args.mm_cache_preprocessor)
     stop_token_ids = None
     return llm, prompt, stop_token_ids
@@ -128,7 +142,9 @@ def run_chameleon(question: str, modality: str):
     assert modality == "image"
     prompt = f"{question}<image>"
-    llm = LLM(model="facebook/chameleon-7b", max_model_len=4096)
+    llm = LLM(model="facebook/chameleon-7b",
+              max_model_len=4096,
+              mm_cache_preprocessor=args.mm_cache_preprocessor)
     stop_token_ids = None
     return llm, prompt, stop_token_ids
@@ -154,6 +170,7 @@ def run_minicpmv(question: str, modality: str):
         max_model_len=4096,
         max_num_seqs=2,
         trust_remote_code=True,
+        mm_cache_preprocessor=args.mm_cache_preprocessor,
     )
     # NOTE The stop_token_ids are different for various versions of MiniCPM-V
     # 2.0
@@ -186,6 +203,7 @@ def run_h2ovl(question: str, modality: str):
         model=model_name,
         trust_remote_code=True,
         max_model_len=8192,
+        mm_cache_preprocessor=args.mm_cache_preprocessor,
     )
     tokenizer = AutoTokenizer.from_pretrained(model_name,
@@ -211,6 +229,7 @@ def run_internvl(question: str, modality: str):
         model=model_name,
         trust_remote_code=True,
         max_model_len=4096,
+        mm_cache_preprocessor=args.mm_cache_preprocessor,
     )
     tokenizer = AutoTokenizer.from_pretrained(model_name,
@@ -241,6 +260,7 @@ def run_nvlm_d(question: str, modality: str):
         trust_remote_code=True,
         max_model_len=4096,
         tensor_parallel_size=4,
+        mm_cache_preprocessor=args.mm_cache_preprocessor,
     )
     tokenizer = AutoTokenizer.from_pretrained(model_name,
@@ -260,7 +280,8 @@ def run_blip2(question: str, modality: str):
     # BLIP-2 prompt format is inaccurate on HuggingFace model repository.
     # See https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262 #noqa
     prompt = f"Question: {question} Answer:"
-    llm = LLM(model="Salesforce/blip2-opt-2.7b")
+    llm = LLM(model="Salesforce/blip2-opt-2.7b",
+              mm_cache_preprocessor=args.mm_cache_preprocessor)
     stop_token_ids = None
     return llm, prompt, stop_token_ids
@@ -274,6 +295,7 @@ def run_qwen_vl(question: str, modality: str):
         trust_remote_code=True,
         max_model_len=1024,
         max_num_seqs=2,
+        mm_cache_preprocessor=args.mm_cache_preprocessor,
     )
     prompt = f"{question}Picture 1: <img></img>\n"
@@ -296,6 +318,7 @@ def run_qwen2_vl(question: str, modality: str):
             "min_pixels": 28 * 28,
             "max_pixels": 1280 * 28 * 28,
         },
+        mm_cache_preprocessor=args.mm_cache_preprocessor,
     )
     prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
@@ -315,6 +338,7 @@ def run_pixtral_hf(question: str, modality: str):
     llm = LLM(
         model=model_name,
         max_model_len=8192,
+        mm_cache_preprocessor=args.mm_cache_preprocessor,
     )
     prompt = f"<s>[INST]{question}\n[IMG][/INST]"
@@ -338,6 +362,7 @@ def run_mllama(question: str, modality: str):
         max_model_len=4096,
         max_num_seqs=16,
         enforce_eager=True,
+        mm_cache_preprocessor=args.mm_cache_preprocessor,
     )
     prompt = f"<|image|><|begin_of_text|>{question}"
@@ -355,6 +380,7 @@ def run_molmo(question, modality):
         model=model_name,
         trust_remote_code=True,
         dtype="bfloat16",
+        mm_cache_preprocessor=args.mm_cache_preprocessor,
     )
     prompt = question
@@ -371,7 +397,8 @@ def run_glm4v(question: str, modality: str):
               max_model_len=2048,
               max_num_seqs=2,
               trust_remote_code=True,
-              enforce_eager=True)
+              enforce_eager=True,
+              mm_cache_preprocessor=args.mm_cache_preprocessor)
     prompt = question
     stop_token_ids = [151329, 151336, 151338]
     return llm, prompt, stop_token_ids
@@ -394,6 +421,7 @@ def run_idefics3(question: str, modality: str):
                 "longest_edge": 3 * 364
             },
         },
+        mm_cache_preprocessor=args.mm_cache_preprocessor,
     )
     prompt = (
         f"<|begin_of_text|>User:<image>{question}<end_of_utterance>\nAssistant:"
@@ -410,7 +438,8 @@ def run_aria(question: str, modality: str):
     llm = LLM(model=model_name,
               tokenizer_mode="slow",
               trust_remote_code=True,
-              dtype="bfloat16")
+              dtype="bfloat16",
+              mm_cache_preprocessor=args.mm_cache_preprocessor)
     prompt = (f"<|im_start|>user\n<fim_prefix><|img|><fim_suffix>\n{question}"
               "<|im_end|>\n<|im_start|>assistant\n")
@@ -430,6 +459,7 @@ def run_mantis(question: str, modality: str):
         model="TIGER-Lab/Mantis-8B-siglip-llama3",
         max_model_len=4096,
         hf_overrides={"architectures": ["MantisForConditionalGeneration"]},
+        mm_cache_preprocessor=args.mm_cache_preprocessor,
     )
     stop_token_ids = [128009]
     return llm, prompt, stop_token_ids
@@ -494,6 +524,35 @@ def get_multi_modal_input(args):
     raise ValueError(msg)
+
+def apply_image_repeat(image_repeat_prob, num_prompts, data, prompt, modality):
+    """Repeats images with provided probability of "image_repeat_prob".
+    Used to simulate hit/miss for the MM preprocessor cache.
+    """
+    assert (image_repeat_prob <= 1.0 and image_repeat_prob >= 0)
+    no_yes = [0, 1]
+    probs = [1.0 - image_repeat_prob, image_repeat_prob]
+
+    inputs = []
+    cur_image = data
+    for i in range(num_prompts):
+        if image_repeat_prob is not None:
+            res = random.choices(no_yes, probs)[0]
+            if res == 0:
+                # No repeat => Modify one pixel
+                cur_image = cur_image.copy()
+                new_val = (i // 256 // 256, i // 256, i % 256)
+                cur_image.putpixel((0, 0), new_val)
+
+        inputs.append({
+            "prompt": prompt,
+            "multi_modal_data": {
+                modality: cur_image
+            }
+        })
+
+    return inputs
+
 def main(args):
     model = args.model_type
     if model not in model_example_map:
@@ -524,14 +583,29 @@ def main(args):
     else:
         # Batch inference
-        inputs = [{
-            "prompt": prompt,
-            "multi_modal_data": {
-                modality: data
-            },
-        } for _ in range(args.num_prompts)]
+        if args.image_repeat_prob is not None:
+            # Repeat images with specified probability of "image_repeat_prob"
+            inputs = apply_image_repeat(args.image_repeat_prob,
+                                        args.num_prompts, data, prompt,
+                                        modality)
+        else:
+            # Use the same image for all prompts
+            inputs = [{
+                "prompt": prompt,
+                "multi_modal_data": {
+                    modality: data
+                },
+            } for _ in range(args.num_prompts)]

-    outputs = llm.generate(inputs, sampling_params=sampling_params)
+    if args.time_generate:
+        import time
+        start_time = time.time()
+        outputs = llm.generate(inputs, sampling_params=sampling_params)
+        elapsed_time = time.time() - start_time
+        print("-- generate time = {}".format(elapsed_time))
+    else:
+        outputs = llm.generate(inputs, sampling_params=sampling_params)

     for o in outputs:
         generated_text = o.outputs[0].text
@@ -561,5 +635,23 @@ if __name__ == "__main__":
                         type=int,
                         default=16,
                         help='Number of frames to extract from the video.')
+
+    parser.add_argument(
+        '--image-repeat-prob',
+        type=float,
+        default=None,
+        help='Simulates the hit-ratio for multi-modal preprocessor cache'
+        ' (if enabled)')
+
+    parser.add_argument(
+        '--mm-cache-preprocessor',
+        action='store_true',
+        help='If True, enable caching of multi-modal preprocessor/mapper.')
+
+    parser.add_argument(
+        '--time-generate',
+        action='store_true',
+        help='If True, then print the total generate() call time')
+
     args = parser.parse_args()
     main(args)
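Taken together, --mm-cache-preprocessor turns the cache on, --image-repeat-prob controls how often consecutive prompts reuse the exact same image (and therefore hit the cache), and --time-generate reports wall-clock time for the generate() call. Below is a rough programmatic sketch of the 100%-hit case; the checkpoint is one already used in this example script, while the local image path is an illustrative assumption.

# Sketch only: exercises the new mm_cache_preprocessor flag from Python.
from PIL import Image
from vllm import LLM, SamplingParams

llm = LLM(model="llava-hf/llava-1.5-7b-hf",
          max_model_len=4096,
          mm_cache_preprocessor=True)  # same switch as --mm-cache-preprocessor

image = Image.open("example.jpg").convert("RGB")  # hypothetical local image
prompt = "USER: <image>\nWhat is shown in this image?\nASSISTANT:"

# Reusing the identical PIL image for every prompt is the 100%-repeat case:
# every request after the first should hit the preprocessor cache.
inputs = [{"prompt": prompt, "multi_modal_data": {"image": image}}
          for _ in range(8)]
outputs = llm.generate(inputs, SamplingParams(max_tokens=32))
for o in outputs:
    print(o.outputs[0].text)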

View File

@@ -3,6 +3,7 @@ sentencepiece  # Required for LLaMA tokenizer.
 numpy < 2.0.0
 requests >= 2.26.0
 tqdm
+blake3
 py-cpuinfo
 transformers >= 4.45.2  # Required for Llama 3.2 and Qwen2-VL.
 tokenizers >= 0.19.1  # Required for Llama 3.

View File

@@ -28,6 +28,7 @@ def make_request() -> EngineCoreRequest:
         prompt=PROMPT,
         prompt_token_ids=PROMPT_TOKENS,
         mm_inputs=None,
+        mm_hashes=None,
         mm_placeholders=None,
         sampling_params=SamplingParams(),
         eos_token_id=None,

View File

@@ -30,6 +30,7 @@ def make_request(params: SamplingParams) -> EngineCoreRequest:
         prompt=PROMPT,
         prompt_token_ids=PROMPT_TOKENS,
         mm_inputs=None,
+        mm_hashes=None,
         mm_placeholders=None,
         sampling_params=params,
         eos_token_id=None,

View File

@@ -147,6 +147,9 @@ class ModelConfig:
             HuggingFace config.
         mm_processor_kwargs: Arguments to be forwarded to the model's processor
             for multi-modal data, e.g., image processor.
+        mm_cache_preprocessor: If true, then enables caching of the multi-modal
+            preprocessor/mapper. Otherwise, the mapper executes each time, and
+            for better performance consider enabling frontend process.
         override_neuron_config: Initialize non default neuron config or
             override default neuron config that are specific to Neuron devices,
             this argument will be used to configure the neuron config that
@@ -185,6 +188,7 @@ class ModelConfig:
        config_format: ConfigFormat = ConfigFormat.AUTO,
        hf_overrides: Optional[HfOverrides] = None,
        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
+       mm_cache_preprocessor: bool = False,
        override_neuron_config: Optional[Dict[str, Any]] = None,
        override_pooler_config: Optional["PoolerConfig"] = None) -> None:
         self.model = model
@@ -251,6 +255,7 @@ class ModelConfig:
         self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
         self.use_async_output_proc = use_async_output_proc
         self.mm_processor_kwargs = mm_processor_kwargs
+        self.mm_cache_preprocessor = mm_cache_preprocessor

         # Set enforce_eager to False if the value is unset.
         if self.enforce_eager is None:
@@ -2686,9 +2691,10 @@ class VllmConfig:
             f"enable_prefix_caching={self.cache_config.enable_prefix_caching}, "
             f"chunked_prefill_enabled={self.scheduler_config.chunked_prefill_enabled}, "  # noqa
             f"use_async_output_proc={self.model_config.use_async_output_proc}, "
+            f"mm_cache_preprocessor={self.model_config.mm_cache_preprocessor!r}, "  # noqa
             f"mm_processor_kwargs={self.model_config.mm_processor_kwargs}, "
-            f"pooler_config={self.model_config.pooler_config!r},"
-            f" compilation_config={self.compilation_config!r}")
+            f"pooler_config={self.model_config.pooler_config!r}, "
+            f"compilation_config={self.compilation_config!r}")

 _current_vllm_config: Optional[VllmConfig] = None

View File

@@ -143,6 +143,7 @@ class EngineArgs:
     tokenizer_pool_extra_config: Optional[Dict[str, Any]] = None
     limit_mm_per_prompt: Optional[Mapping[str, int]] = None
     mm_processor_kwargs: Optional[Dict[str, Any]] = None
+    mm_cache_preprocessor: bool = False
     enable_lora: bool = False
     enable_lora_bias: bool = False
     max_loras: int = 1
@@ -593,6 +594,12 @@ class EngineArgs:
             type=json.loads,
             help=('Overrides for the multimodal input mapping/processing, '
                   'e.g., image processor. For example: {"num_crops": 4}.'))
+        parser.add_argument(
+            '--mm-cache-preprocessor',
+            action='store_true',
+            help='If true, then enables caching of the multi-modal '
+            'preprocessor/mapper. Otherwise, the mapper executes each time'
+            ', and for better performance consider enabling frontend process.')

         # LoRA related configs
         parser.add_argument('--enable-lora',
@@ -965,6 +972,7 @@ class EngineArgs:
             use_async_output_proc=not self.disable_async_output_proc,
             config_format=self.config_format,
             mm_processor_kwargs=self.mm_processor_kwargs,
+            mm_cache_preprocessor=self.mm_cache_preprocessor,
             override_neuron_config=self.override_neuron_config,
             override_pooler_config=self.override_pooler_config,
         )
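The same flag is also a dataclass field on EngineArgs, so server-style entrypoints pick it up from the CLI automatically. A minimal sketch of threading it through to ModelConfig; the model name is illustrative, and it assumes EngineArgs.create_engine_config() in this vLLM version returns a VllmConfig (building the config fetches the HF model config):

# Sketch: mm_cache_preprocessor flows from EngineArgs into ModelConfig.
from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(model="llava-hf/llava-1.5-7b-hf",
                         mm_cache_preprocessor=True)
vllm_config = engine_args.create_engine_config()
assert vllm_config.model_config.mm_cache_preprocessor is True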

View File

@@ -35,7 +35,8 @@ class EngineCoreRequest:
     # always be tokenized?
     prompt: Optional[str]
     prompt_token_ids: List[int]
-    mm_inputs: Optional[List[MultiModalKwargs]]
+    mm_inputs: Optional[List[Optional[MultiModalKwargs]]]
+    mm_hashes: Optional[List[Optional[str]]]
     mm_placeholders: Optional[MultiModalPlaceholderDict]
     sampling_params: SamplingParams
     eos_token_id: Optional[int]

View File

@@ -18,7 +18,7 @@ from vllm.v1.core.scheduler import Scheduler
 from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
                             EngineCoreProfile, EngineCoreRequest,
                             EngineCoreRequestType)
-from vllm.v1.engine.mm_input_mapper import MMInputMapper
+from vllm.v1.engine.mm_input_mapper import MMInputMapperServer
 from vllm.v1.executor.abstract import Executor
 from vllm.v1.request import Request, RequestStatus
 from vllm.v1.serial_utils import PickleEncoder
@@ -55,9 +55,6 @@ class EngineCore:
         vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
         vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks

-        # Set up multimodal input mapper (e.g., convert PIL images to tensors).
-        self.mm_input_mapper = MMInputMapper(vllm_config.model_config)
-
         # Setup scheduler.
         self.scheduler = Scheduler(vllm_config.scheduler_config,
                                    vllm_config.cache_config,
@@ -65,6 +62,8 @@ class EngineCore:
         self._last_logging_time = time.time()

+        self.mm_input_mapper_server = MMInputMapperServer()
+
     def _initialize_kv_caches(self,
                               cache_config: CacheConfig) -> Tuple[int, int]:
         start = time.time()
@@ -88,7 +87,18 @@ class EngineCore:
     def add_request(self, request: EngineCoreRequest):
         """Add request to the scheduler."""

+        if request.mm_hashes is not None:
+            # Here, if hash exists for an image, then it will be fetched
+            # from the cache, else it will be added to the cache.
+            # Note that the cache here is mirrored with the client side of the
+            # MM mapper, so anything that has a hash must have a HIT cache
+            # entry here as well.
+            request.mm_inputs = self.mm_input_mapper_server.process_inputs(
+                request.mm_inputs, request.mm_hashes)
+
         req = Request.from_engine_core_request(request)
         self.scheduler.add_request(req)

     def abort_requests(self, request_ids: List[str]):
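The comment above is the server half of the mirrored-cache protocol: a request row that arrives with a hash but no payload must already be cached here, because the client only drops the payload after it has shipped it once. A toy sketch of that fill step follows; a plain dict stands in for the LRUDictCache added by this change, and the payload is a placeholder.

def server_process_inputs(cache, mm_inputs, mm_hashes):
    # Mirrors MMInputMapperServer.process_inputs in spirit, minus the LRU bound.
    full = []
    for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
        if mm_input is None:
            # Client sent only the hash: the payload was cached on a prior request.
            mm_input = cache[mm_hash]
        else:
            # First sighting of this hash: remember the payload.
            cache[mm_hash] = mm_input
        full.append(mm_input)
    return full

cache = {}
first = server_process_inputs(cache, [{"pixel_values": "..."}], ["abc123"])
repeat = server_process_inputs(cache, [None], ["abc123"])
assert repeat == first  # the repeated request is reconstructed from the cache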

View File

@@ -1,11 +1,35 @@
 from typing import Any, Dict, List, Optional
+import PIL
+from blake3 import blake3
 from vllm.config import ModelConfig
+from vllm.inputs import PromptType
+from vllm.logger import init_logger
 from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
                              MultiModalKwargs, MultiModalRegistry)
+from vllm.v1.utils import LRUDictCache
+
+logger = init_logger(__name__)
+
+# The idea of MM preprocessor caching is based on having a client and a server,
+# where the client executes in the frontend process (=P0) and the server in the
+# core process (=P1).
+#
+# -- Client: Executes the MM mapper and performs caching of the results.
+# -- Server: Performs caching of the results
+#
+# The caching for both client and server is mirrored/similar, and this allows us
+# to avoid the serialization of "mm_inputs" (like pixel values) between
+# client (=P0) and server (=P1) processes.
+
+# Both Client and Server must use the same cache size
+# (to perform mirrored caching)
+# TODO: Tune the MM cache size
+MM_CACHE_SIZE = 256
+
-class MMInputMapper:
+class MMInputMapperClient:

     def __init__(
         self,
@@ -18,23 +42,131 @@ class MMInputMapper:
             model_config)
         self.mm_registry.init_mm_limits_per_prompt(model_config)

+        self.mm_cache = LRUDictCache(MM_CACHE_SIZE)
+
+        # DEBUG: Set to None to disable
+        self.mm_debug_cache_hit_ratio_steps = None
+        self.mm_cache_hits = 0
+        self.mm_cache_total = 0
+
+    def cache_hit_ratio(self, steps) -> float:
+        if self.mm_cache_total > 0 and self.mm_cache_total % steps == 0:
+            logger.debug("MMInputMapper: cache_hit_ratio = %.2f ",
+                         self.mm_cache_hits / self.mm_cache_total)
+
     def process_inputs(
         self,
         mm_data: MultiModalDataDict,
+        mm_hashes: Optional[List[str]],
         mm_processor_kwargs: Optional[Dict[str, Any]],
+        precomputed_mm_inputs: Optional[List[MultiModalKwargs]],
     ) -> List[MultiModalKwargs]:
-        image_inputs = mm_data["image"]
-        if not isinstance(image_inputs, list):
-            image_inputs = [image_inputs]
-
-        # Process each image input separately so that later we can schedule
-        # them in a fine-grained manner.
-        mm_inputs: List[MultiModalKwargs] = []
-        num_images = len(image_inputs)
-        for i in range(num_images):
-            mm_input = self.multi_modal_input_mapper(
-                {"image": image_inputs[i]},
-                mm_processor_kwargs=mm_processor_kwargs,
-            )
-            mm_inputs.append(mm_input)
-        return mm_inputs
+        if precomputed_mm_inputs is None:
+            image_inputs = mm_data["image"]
+            if not isinstance(image_inputs, list):
+                image_inputs = [image_inputs]
+            num_inputs = len(image_inputs)
+        else:
+            num_inputs = len(precomputed_mm_inputs)
+
+        # Check if hash is enabled
+        use_hash = mm_hashes is not None
+        if use_hash:
+            assert num_inputs == len(
+                mm_hashes), "num_inputs = {} len(mm_hashes) = {}".format(
+                    num_inputs, len(mm_hashes))
+
+        # Process each image input separately, so that later we can schedule
+        # them in a fine-grained manner.
+        # Apply caching (if enabled) and reuse precomputed inputs (if provided)
+        ret_hashes = [] if use_hash else None
+        ret_inputs: List[MultiModalKwargs] = []
+        for input_id in range(num_inputs):
+            if self.mm_debug_cache_hit_ratio_steps is not None:
+                self.cache_hit_ratio(self.mm_debug_cache_hit_ratio_steps)
+
+            mm_hash = None
+            mm_input = None
+            if use_hash:
+                mm_hash = mm_hashes[input_id]
+                mm_input = self.mm_cache.get(mm_hash)
+
+            self.mm_cache_total += 1
+            if mm_input is None:
+                if precomputed_mm_inputs is not None:
+                    # Reuse precomputed input (for merged preprocessor)
+                    mm_input = precomputed_mm_inputs[input_id]
+                else:
+                    # Apply MM mapper
+                    mm_input = self.multi_modal_input_mapper(
+                        {"image": [image_inputs[input_id]]},
+                        mm_processor_kwargs=mm_processor_kwargs,
+                    )
+                if use_hash:
+                    # Add to cache
+                    self.mm_cache.put(mm_hash, mm_input)
+            else:
+                self.mm_cache_hits += 1
+                mm_input = None  # Avoids sending mm_input to Server
+
+            if use_hash:
+                ret_hashes.append(mm_hash)
+            ret_inputs.append(mm_input)
+
+        return ret_inputs, ret_hashes
+
+
+class MMInputMapperServer:
+
+    def __init__(self, ):
+        self.mm_cache = LRUDictCache(MM_CACHE_SIZE)
+
+    def process_inputs(
+        self,
+        mm_inputs: List[Optional[MultiModalKwargs]],
+        mm_hashes: List[Optional[str]],
+    ) -> List[MultiModalKwargs]:
+        assert len(mm_inputs) == len(mm_hashes)
+
+        full_mm_inputs = []
+        for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
+            if mm_input is None:
+                mm_input = self.mm_cache.get(mm_hash)
+                assert mm_input is not None
+            else:
+                self.mm_cache.put(mm_hash, mm_input)
+
+            full_mm_inputs.append(mm_input)
+
+        return full_mm_inputs
+
+
+class MMHasher:
+
+    def __init__(self):
+        pass
+
+    def hash(self, prompt: PromptType) -> Optional[List[str]]:
+        if "multi_modal_data" not in prompt:
+            return None
+
+        mm_data = prompt["multi_modal_data"]
+        image_inputs = mm_data["image"]
+        if not isinstance(image_inputs, list):
+            image_inputs = [image_inputs]
+        assert len(image_inputs) > 0
+
+        ret = []
+        for image in image_inputs:
+            assert isinstance(image, PIL.Image.Image)
+
+            # Convert image to bytes
+            bytes = image.tobytes()
+
+            # Hash image bytes
+            hasher = blake3()
+            hasher.update(bytes)
+            ret.append(hasher.hexdigest())
+        return ret
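MMHasher keys the cache on the raw pixel buffer: two prompts carrying byte-identical images produce the same blake3 digest and therefore share one cache entry, while any pixel change yields a new key (which is exactly what apply_image_repeat in the example script exploits). A small standalone illustration; the solid-color test images are made up for the example:

from blake3 import blake3
from PIL import Image

def image_hash(image: Image.Image) -> str:
    # Same recipe as MMHasher.hash: blake3 over image.tobytes().
    hasher = blake3()
    hasher.update(image.tobytes())
    return hasher.hexdigest()

a = Image.new("RGB", (64, 64), color=(255, 0, 0))
b = Image.new("RGB", (64, 64), color=(255, 0, 0))
c = Image.new("RGB", (64, 64), color=(254, 0, 0))  # one channel off by one

assert image_hash(a) == image_hash(b)  # identical pixels -> same cache key
assert image_hash(a) != image_hash(c)  # any change -> new cache key

As in the diff, only the pixel buffer feeds the digest; image size and mode are not mixed in, so images must be byte-identical to share a key.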

View File

@@ -15,7 +15,7 @@ from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.config import try_get_generation_config
 from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
 from vllm.v1.engine import DetokenizerRequest, EngineCoreRequest
-from vllm.v1.engine.mm_input_mapper import MMInputMapper
+from vllm.v1.engine.mm_input_mapper import MMHasher, MMInputMapperClient

 class Processor:
@@ -42,7 +42,11 @@ class Processor:
             model_config)

         # Multi-modal (huggingface) input mapper
-        self.mm_input_mapper = MMInputMapper(model_config)
+        self.mm_input_mapper_client = MMInputMapperClient(model_config)
+
+        # Multi-modal hasher (for images)
+        self.mm_hasher = MMHasher(
+        ) if model_config.mm_cache_preprocessor else None

         # TODO: run in an ThreadpoolExecutor or BackgroundProcess.
         # This ideally should releases the GIL, so we should not block the
@@ -71,6 +75,11 @@ class Processor:
         assert priority == 0, "vLLM V1 does not support priority at the moment."
         assert trace_headers is None, "vLLM V1 does not support tracing yet."

+        # Compute MM hashes (if enabled)
+        mm_hashes = None
+        if self.mm_hasher is not None:
+            mm_hashes = self.mm_hasher.hash(prompt)
+
         # Process inputs.
         preprocessed_inputs = self.input_preprocessor.preprocess(
             prompt,
@@ -101,16 +110,17 @@ class Processor:
         sampling_params.update_from_generation_config(
             self.generation_config_fields, eos_token_id)

-        # Preprocess multi-modal data
-        if len(decoder_inputs.multi_modal_data) == 0:
-            mm_inputs = None
-        elif isinstance(decoder_inputs.multi_modal_data, MultiModalKwargs):
-            mm_inputs = [decoder_inputs.multi_modal_data]
-        else:
-            mm_inputs = self.mm_input_mapper.process_inputs(
-                decoder_inputs.multi_modal_data,
-                decoder_inputs.mm_processor_kwargs,
-            )
+        # For merged preprocessor, mm_data is already mm_inputs
+        precomputed_mm_inputs = None
+        if isinstance(decoder_inputs.multi_modal_data, MultiModalKwargs):
+            precomputed_mm_inputs = [decoder_inputs.multi_modal_data]
+
+        # Apply MM mapper
+        mm_inputs = None
+        if len(decoder_inputs.multi_modal_data) > 0:
+            mm_inputs, mm_hashes = self.mm_input_mapper_client.process_inputs(
+                decoder_inputs.multi_modal_data, mm_hashes,
+                decoder_inputs.mm_processor_kwargs, precomputed_mm_inputs)

         # Make Request for Detokenizer.
         detokenizer_request = DetokenizerRequest(
@@ -130,6 +140,7 @@ class Processor:
             decoder_inputs.prompt,
             decoder_inputs.prompt_token_ids,
             mm_inputs,
+            mm_hashes,
             decoder_inputs.multi_modal_placeholders,
             sampling_params,
             eos_token_id,
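This is the client half of the protocol sketched earlier for EngineCore: the frontend hashes the images, runs the HF mapper only on cache misses, and replaces cached payloads with None so pixel tensors are not re-serialized to the core process. A toy sketch, with dicts and strings standing in for LRUDictCache and MultiModalKwargs:

def client_process_inputs(cache, images, hashes, mapper):
    # Rough shape of MMInputMapperClient.process_inputs with hashing enabled.
    out_inputs, out_hashes = [], []
    for image, mm_hash in zip(images, hashes):
        mm_input = cache.get(mm_hash)
        if mm_input is None:
            mm_input = mapper(image)       # miss: run the preprocessor once
            cache[mm_hash] = mm_input
        else:
            mm_input = None                # hit: ship only the hash
        out_inputs.append(mm_input)
        out_hashes.append(mm_hash)
    return out_inputs, out_hashes

mapper = lambda img: {"pixel_values": f"mapped({img})"}  # placeholder mapper
client_cache = {}
print(client_process_inputs(client_cache, ["imgA"], ["h1"], mapper))
# ([{'pixel_values': 'mapped(imgA)'}], ['h1'])  first sighting ships the payload
print(client_process_inputs(client_cache, ["imgA"], ["h1"], mapper))
# ([None], ['h1'])                              repeat ships only the hash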

View File

@@ -1,3 +1,4 @@
+from collections import OrderedDict
 from contextlib import contextmanager
 from typing import Any, Generic, Iterator, List, TypeVar, overload
@@ -93,3 +94,23 @@ def make_zmq_socket(path: str, type: Any) -> Iterator[zmq.Socket]:
     finally:
         ctx.destroy(linger=0)
+
+
+class LRUDictCache:
+
+    def __init__(self, size: int):
+        self.cache = OrderedDict()
+        self.size = size
+
+    def get(self, key, default=None):
+        if key not in self.cache:
+            return default
+
+        self.cache.move_to_end(key)
+        return self.cache[key]
+
+    def put(self, key, value):
+        self.cache[key] = value
+        self.cache.move_to_end(key)
+        if len(self.cache) > self.size:
+            self.cache.popitem(last=False)
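LRUDictCache is a thin OrderedDict-based LRU shared by the client and server mappers, so both sides evict in lockstep. A quick behavioral check (the class is reproduced verbatim so the snippet runs on its own; keys and values are placeholders): the most recently touched entry survives, the least recently used one is evicted once the size bound is exceeded.

from collections import OrderedDict

class LRUDictCache:
    def __init__(self, size: int):
        self.cache = OrderedDict()
        self.size = size

    def get(self, key, default=None):
        if key not in self.cache:
            return default
        self.cache.move_to_end(key)
        return self.cache[key]

    def put(self, key, value):
        self.cache[key] = value
        self.cache.move_to_end(key)
        if len(self.cache) > self.size:
            self.cache.popitem(last=False)

cache = LRUDictCache(size=2)
cache.put("h1", "input-1")
cache.put("h2", "input-2")
cache.get("h1")              # touch h1 so h2 becomes least recently used
cache.put("h3", "input-3")   # exceeds size=2, evicts h2
assert cache.get("h1") == "input-1"
assert cache.get("h2") is None
assert cache.get("h3") == "input-3"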