From 7c65527918cd16286961a2a779e15743ca41ab0e Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 12 Nov 2024 08:57:14 -0800 Subject: [PATCH] [V1] Use pickle for serializing EngineCoreRequest & Add multimodal inputs to EngineCoreRequest (#10245) Signed-off-by: Woosuk Kwon --- vllm/v1/engine/__init__.py | 9 +++++++-- vllm/v1/engine/core.py | 3 ++- vllm/v1/engine/core_client.py | 3 ++- vllm/v1/engine/processor.py | 5 ++++- vllm/v1/serial_utils.py | 10 ++++++++++ 5 files changed, 25 insertions(+), 5 deletions(-) create mode 100644 vllm/v1/serial_utils.py diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index 8bc16651faf9..edfb8bd7c2fc 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -1,10 +1,11 @@ import enum from dataclasses import dataclass -from typing import List, Optional, Union +from typing import Any, Dict, List, Optional, Union import msgspec from vllm.lora.request import LoRARequest +from vllm.multimodal import MultiModalDataDict, MultiModalPlaceholderDict from vllm.sampling_params import RequestOutputKind, SamplingParams @@ -22,7 +23,8 @@ class DetokenizerRequest: include_stop_str_in_output: bool -class EngineCoreRequest(msgspec.Struct, omit_defaults=True): +@dataclass +class EngineCoreRequest: # NOTE: prompt and prompt_token_ids should be DecoderOnlyInput, # but this object is currently not playing well with msgspec @@ -33,6 +35,9 @@ class EngineCoreRequest(msgspec.Struct, omit_defaults=True): # always be tokenized? prompt: Optional[str] prompt_token_ids: List[int] + mm_data: Optional[MultiModalDataDict] + mm_placeholders: Optional[MultiModalPlaceholderDict] + mm_processor_kwargs: Optional[Dict[str, Any]] sampling_params: SamplingParams eos_token_id: Optional[int] arrival_time: float diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index f9d3473d0131..808c3936b6c3 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -19,6 +19,7 @@ from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs, EngineCoreRequest, EngineCoreRequestType) from vllm.v1.executor.gpu_executor import GPUExecutor from vllm.v1.request import Request, RequestStatus +from vllm.v1.serial_utils import PickleEncoder from vllm.version import __version__ as VLLM_VERSION logger = init_logger(__name__) @@ -315,7 +316,7 @@ class EngineCoreProc(EngineCore): """Input socket IO thread.""" # Msgpack serialization decoding. - decoder_add_req = msgpack.Decoder(EngineCoreRequest) + decoder_add_req = PickleEncoder() decoder_abort_req = msgpack.Decoder(list[str]) with self.make_socket(input_path, zmq.constants.PULL) as socket: diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index f9e4677fb8c5..09801e20e16c 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -11,6 +11,7 @@ from vllm.utils import get_open_zmq_ipc_path from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs, EngineCoreRequest, EngineCoreRequestType) from vllm.v1.engine.core import EngineCore, EngineCoreProc +from vllm.v1.serial_utils import PickleEncoder logger = init_logger(__name__) @@ -115,7 +116,7 @@ class MPClient(EngineCoreClient): **kwargs, ): # Serialization setup. - self.encoder = msgspec.msgpack.Encoder() + self.encoder = PickleEncoder() self.decoder = msgspec.msgpack.Decoder(EngineCoreOutputs) # ZMQ setup. diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index d92e62281038..5f13cbf2e403 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -91,7 +91,10 @@ class Processor: # Make Request for EngineCore. engine_core_request = EngineCoreRequest( request_id, processed_inputs.get("prompt"), - processed_inputs.get("prompt_token_ids"), sampling_params, + processed_inputs.get("prompt_token_ids"), + processed_inputs.get("multi_modal_data"), + processed_inputs.get("multi_modal_placeholders"), + processed_inputs.get("mm_processor_kwargs"), sampling_params, eos_token_id, arrival_time, lora_request) return detokenizer_request, engine_core_request diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py new file mode 100644 index 000000000000..b1cd5c11834f --- /dev/null +++ b/vllm/v1/serial_utils.py @@ -0,0 +1,10 @@ +import pickle + + +class PickleEncoder: + + def encode(self, obj): + return pickle.dumps(obj) + + def decode(self, data): + return pickle.loads(data)