Mirror of https://git.datalinker.icu/vllm-project/vllm.git
[V1] Clarify input processing and multimodal feature caching logic (#13211)
This commit is contained in:
parent 578087e56c
commit fdcf64d3c6
vllm/v1/engine/core.py

@@ -20,7 +20,7 @@ from vllm.v1.core.kv_cache_utils import get_kv_cache_configs
 from vllm.v1.core.scheduler import Scheduler
 from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
                             EngineCoreRequestType)
-from vllm.v1.engine.mm_input_mapper import MMInputMapperServer
+from vllm.v1.engine.mm_input_cache import MMInputCacheServer
 from vllm.v1.executor.abstract import Executor
 from vllm.v1.request import Request, RequestStatus
 from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
@@ -65,7 +65,7 @@ class EngineCore:
             log_stats=self.log_stats,
         )

-        self.mm_input_mapper_server = MMInputMapperServer(
+        self.mm_input_cache_server = MMInputCacheServer(
             vllm_config.model_config)

     def _initialize_kv_caches(self,
@@ -102,13 +102,13 @@ class EngineCore:
         """Add request to the scheduler."""

         if request.mm_hashes is not None:
-            # Here, if hash exists for an image, then it will be fetched
-            # from the cache, else it will be added to the cache.
-            # Note that the cache here is mirrored with the client side of the
-            # MM mapper, so anything that has a hash must have a HIT cache
-            # entry here as well.
+            # Here, if hash exists for a multimodal input, then it will be
+            # fetched from the cache, else it will be added to the cache.
+            # Note that the cache here is mirrored with the client cache, so
+            # anything that has a hash must have a HIT cache entry here
+            # as well.
             assert request.mm_inputs is not None
-            request.mm_inputs = self.mm_input_mapper_server.process_inputs(
+            request.mm_inputs = self.mm_input_cache_server.get_and_update(
                 request.mm_inputs, request.mm_hashes)

         req = Request.from_engine_core_request(request)
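The comment above hinges on one invariant: the server cache is kept in lockstep with the client cache, so any input that arrives with a hash but without a payload must already be cached on this side. A minimal sketch of that get_and_update contract follows; the plain dict and the helper names are illustrative stand-ins, not vLLM's actual implementation.

from typing import Dict, List, Optional

# Toy stand-in for MultiModalKwargs: any serialized multimodal feature dict.
MMKwargs = Dict[str, object]


def get_and_update(
    cache: Dict[str, MMKwargs],
    mm_inputs: List[Optional[MMKwargs]],
    mm_hashes: List[str],
) -> List[MMKwargs]:
    """Fill None slots from the cache and store freshly received entries."""
    assert len(mm_inputs) == len(mm_hashes)
    full_inputs: List[MMKwargs] = []
    for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
        if mm_input is None:
            # The client hit its cache and skipped sending the payload, so
            # the mirrored server cache must already hold this hash.
            mm_input = cache[mm_hash]
        else:
            cache[mm_hash] = mm_input
        full_inputs.append(mm_input)
    return full_inputs


# Example: the second item was a client-side cache hit, so only its hash
# travelled over the wire.
server_cache: Dict[str, MMKwargs] = {"h2": {"pixel_values": "cached-tensor"}}
restored = get_and_update(
    server_cache,
    mm_inputs=[{"pixel_values": "fresh-tensor"}, None],
    mm_hashes=["h1", "h2"],
)
assert restored[1] == {"pixel_values": "cached-tensor"}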
vllm/v1/engine/mm_input_mapper.py → vllm/v1/engine/mm_input_cache.py

@@ -10,12 +10,18 @@ from vllm.utils import LRUCache

 logger = init_logger(__name__)

-# The idea of MM preprocessor caching is based on having a client and a server,
-# where the client executes in the frontend process (=P0) and the server in the
-# core process (=P1).
+# The idea of multimodal preprocessing caching is based on having a client and
+# a server, where the client executes in the frontend process (=P0) and the
+# server in the core process (=P1).
 #
-# -- Client: Executes the MM mapper and performs caching of the results.
-# -- Server: Performs caching of the results
+# -- Client:
+#  - Apply legacy input_mapper (if one exists) to generate MultiModalKwargs.
+#  - Perform caching of the generated MultiModalKwargs.
+#  - This client can be deprecated once all mutimodal models migrate to use
+#    merged preprocessor with built-in caching functionality.
+#
+# -- Server:
+#  - Perform caching of the received MultiModalKwargs.
 #
 # The caching for both client and server is mirrored/similar, and this allows us
 # to avoid the serialization of "mm_inputs" (like pixel values) between
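The comment block above is the design summary for this file: the client runs in the frontend process (P0), the server in the core process (P1), and both keep hash-keyed LRU caches in lockstep so that large mm_inputs only cross the process boundary once. Below is a minimal client-side sketch of that idea, assuming a toy LRU and toy kwargs in place of vllm.utils.LRUCache and MultiModalKwargs; it illustrates the protocol, not the actual implementation.

from collections import OrderedDict
from typing import Any, Optional

MM_CACHE_SIZE = 256


class ToyLRU:
    """Illustrative LRU cache keyed by content hash (stand-in for LRUCache)."""

    def __init__(self, capacity: int) -> None:
        self._data: "OrderedDict[str, Any]" = OrderedDict()
        self._capacity = capacity

    def get(self, key: str) -> Optional[Any]:
        if key not in self._data:
            return None
        self._data.move_to_end(key)
        return self._data[key]

    def put(self, key: str, value: Any) -> None:
        self._data[key] = value
        self._data.move_to_end(key)
        if len(self._data) > self._capacity:
            self._data.popitem(last=False)


class ToyCacheClient:
    """Frontend (P0) side: produces kwargs and decides what to ship to P1."""

    def __init__(self) -> None:
        self.cache = ToyLRU(MM_CACHE_SIZE)

    def process_input(self, mm_hash: str, kwargs: Any) -> Optional[Any]:
        # On a hit, ship None: the server's mirrored cache has the payload,
        # so the (large) kwargs never cross the process boundary again.
        if self.cache.get(mm_hash) is not None:
            return None
        self.cache.put(mm_hash, kwargs)
        return kwargs


client = ToyCacheClient()
first = client.process_input("h1", {"pixel_values": [0.1, 0.2]})
second = client.process_input("h1", {"pixel_values": [0.1, 0.2]})
assert first is not None and second is None  # second send is hash-only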
@@ -27,7 +33,9 @@ logger = init_logger(__name__)
 MM_CACHE_SIZE = 256


-class MMInputMapperClient:
+# TODO(ywang96): Deprecate this class once all multimodal models migrate to use
+# merged preprocessor with built-in caching functionality.
+class MMInputCacheClient:

     def __init__(
         self,
@@ -54,7 +62,8 @@ class MMInputMapperClient:
             logger.debug("MMInputMapper: cache_hit_ratio = %.2f ",
                          self.mm_cache_hits / self.mm_cache_total)

-    # TODO: Support modalities beyond image.
+    # NOTE: process_inputs only supports image inputs since all multimodal
+    # models with other modalities have migrated to use merged preprocessor.
     def process_inputs(
         self,
         mm_data: MultiModalDataDict,
@@ -95,7 +104,7 @@ class MMInputMapperClient:
                 # Reuse precomputed input (for merged preprocessor)
                 mm_input = precomputed_mm_inputs[input_id]
             else:
-                # Apply MM mapper
+                # Apply legacy input_mapper
                 mm_input = self.multi_modal_input_mapper(
                     {"image": [image_inputs[input_id]]},
                     mm_processor_kwargs=mm_processor_kwargs,
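The hunk above is the core of the client's per-input loop: inputs produced by a merged preprocessor arrive precomputed, while image inputs from models still on the legacy path go through the input mapper. A compressed, self-contained sketch of that branch follows; the function name and argument shapes are illustrative, not vLLM's signatures.

from typing import Any, Callable, List, Optional


def map_mm_inputs(
    image_inputs: List[Any],
    precomputed_mm_inputs: Optional[List[Any]],
    legacy_input_mapper: Callable[[dict], Any],
) -> List[Any]:
    """Illustrative sketch: produce one kwargs object per multimodal input."""
    num_inputs = (len(precomputed_mm_inputs)
                  if precomputed_mm_inputs is not None else len(image_inputs))
    outputs: List[Any] = []
    for input_id in range(num_inputs):
        if precomputed_mm_inputs is not None:
            # Merged preprocessor already produced MultiModalKwargs; reuse it.
            mm_input = precomputed_mm_inputs[input_id]
        else:
            # Legacy path: run the per-modality input mapper on one image.
            mm_input = legacy_input_mapper({"image": [image_inputs[input_id]]})
        outputs.append(mm_input)
    return outputs


# Example with a trivial mapper that just wraps the raw image.
mapped = map_mm_inputs(
    image_inputs=["raw-image-0", "raw-image-1"],
    precomputed_mm_inputs=None,
    legacy_input_mapper=lambda data: {"pixel_values": data["image"][0]},
)
assert mapped[0] == {"pixel_values": "raw-image-0"}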
@@ -114,13 +123,13 @@ class MMInputMapperClient:
         return ret_inputs


-class MMInputMapperServer:
+class MMInputCacheServer:

     def __init__(self, model_config):
         self.use_cache = not model_config.disable_mm_preprocessor_cache
         self.mm_cache = LRUCache[str, MultiModalKwargs](MM_CACHE_SIZE)

-    def process_inputs(
+    def get_and_update(
         self,
         mm_inputs: List[Optional[MultiModalKwargs]],
         mm_hashes: List[str],
vllm/v1/engine/processor.py

@@ -17,7 +17,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
 from vllm.v1.engine import EngineCoreRequest
-from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
+from vllm.v1.engine.mm_input_cache import MMInputCacheClient


 class Processor:
@@ -46,7 +46,7 @@ class Processor:
             model_config)

         # Multi-modal (huggingface) input mapper
-        self.mm_input_mapper_client = MMInputMapperClient(model_config)
+        self.mm_input_cache_client = MMInputCacheClient(model_config)

         # Multi-modal hasher (for images)
         self.use_hash = (not model_config.disable_mm_preprocessor_cache) or \
@@ -106,17 +106,25 @@ class Processor:
         assert priority == 0, "vLLM V1 does not support priority at the moment."
         assert trace_headers is None, "vLLM V1 does not support tracing yet."

-        # Process inputs.
+        # Process inputs, which includes:
+        # 1. Tokenize text prompt, with LoRA request if one exists.
+        # 2. For multimodal models with a merged preprocessor, preprocess
+        #    multimodal data and expand prompt token ids accordingly.
+        # 3. Apply prompt adapter to prompt token ids if one exists.
         preprocessed_inputs = self.input_preprocessor.preprocess(
             prompt,
             request_id=request_id,
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request,
         )
-        processed_inputs = self.input_processor(preprocessed_inputs)
-        self._validate_model_inputs(processed_inputs)
         eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)

+        # Process prompt and prompt token ids.
+        # Only applicable to multimodal models with legacy input processor.
+        processed_inputs = self.input_processor(preprocessed_inputs)
+
+        self._validate_model_inputs(processed_inputs)
+
         if is_encoder_decoder_inputs(processed_inputs):
             decoder_inputs = SingletonInputsAdapter(
                 processed_inputs["decoder"])
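Step 2 of the new comment, "preprocess multimodal data and expand prompt token ids accordingly", is the least obvious part of the pipeline, so here is an illustrative expansion of a single image placeholder into the number of positions the encoder output will occupy; the token ids and counts are invented for the example and are not tied to any particular model.

from typing import List


def expand_prompt_token_ids(
    prompt_token_ids: List[int],
    image_placeholder_id: int,
    tokens_per_image: int,
) -> List[int]:
    """Illustrative only: replace each single image placeholder with the
    number of positions the vision encoder output will actually fill in."""
    expanded: List[int] = []
    for token_id in prompt_token_ids:
        if token_id == image_placeholder_id:
            expanded.extend([image_placeholder_id] * tokens_per_image)
        else:
            expanded.append(token_id)
    return expanded


# "<image> describe this" with a 4-position image becomes 4 placeholders
# followed by the text tokens.
assert expand_prompt_token_ids([32000, 100, 200], 32000, 4) == \
    [32000, 32000, 32000, 32000, 100, 200]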
@@ -200,8 +208,8 @@ class Processor:
                 key=lambda mm_input: modality_order_dict[list(
                     mm_input.modalities)[0]])

-            # Apply mm input cache update (and input mapper if necessary).
-            sorted_mm_inputs = self.mm_input_mapper_client.process_inputs(
+            # Apply mm input cache update and legacy input mapper if one exists.
+            sorted_mm_inputs = self.mm_input_cache_client.process_inputs(
                 mm_data=decoder_mm_data,
                 mm_hashes=sorted_mm_hashes,
                 mm_processor_kwargs=decoder_inputs.mm_processor_kwargs,
vllm/v1/worker/gpu_model_runner.py

@@ -27,7 +27,7 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
 from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
                                                     FlashAttentionMetadata)
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
-from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
+from vllm.v1.engine.mm_input_cache import MMInputCacheClient
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
 from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput
@@ -95,9 +95,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         self.mm_registry = MULTIMODAL_REGISTRY
         self.uses_mrope = model_config.uses_mrope

-        # NOTE: Initialized input mapper is only used for processing dummy
+        # NOTE: Initialized client is only used for processing dummy
         # multimodal data into multimodal kwargs for GPU memory profiling.
-        self.mm_input_mapper_profiling = MMInputMapperClient(self.model_config)
+        # Only applicable to multimodal models with legacy input mapper.
+        self.mm_input_mapper_profiling = MMInputCacheClient(self.model_config)
         self.mm_input_mapper_profiling.use_cache = False

         encoder_compute_budget, encoder_cache_size = compute_encoder_budget(