[V1] Use msgpack for core request serialization (#12918)

Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Nick Hill 2025-02-09 19:35:56 -08:00 committed by GitHub
parent aa0ca5ebb7
commit 67c4637ccf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 62 additions and 95 deletions

View File

@ -1,20 +1,17 @@
# SPDX-License-Identifier: Apache-2.0
import enum
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional, Union
from typing import List, Optional, Union
import msgspec
from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.inputs import PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import LogprobsLists, LogprobsTensors
if TYPE_CHECKING:
from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.inputs import PlaceholderRange
from vllm.sampling_params import SamplingParams
# These are possible values of RequestOutput.finish_reason,
# so form part of the external API.
FINISH_REASON_STRINGS = ("stop", "length", "abort")
@ -39,8 +36,11 @@ class FinishReason(enum.IntEnum):
return FINISH_REASON_STRINGS[self.value]
@dataclass
class EngineCoreRequest:
class EngineCoreRequest(
msgspec.Struct,
array_like=True, # type: ignore[call-arg]
omit_defaults=True, # type: ignore[call-arg]
gc=False): # type: ignore[call-arg]
# NOTE: prompt and prompt_token_ids should be DecoderOnlyInput,
# but this object is currently not playing well with msgspec
@ -51,13 +51,13 @@ class EngineCoreRequest:
# Detokenizer, but set to None when it is added to EngineCoreClient.
prompt: Optional[str]
prompt_token_ids: List[int]
mm_inputs: Optional[List[Optional["MultiModalKwargs"]]]
mm_inputs: Optional[List[Optional[MultiModalKwargs]]]
mm_hashes: Optional[List[str]]
mm_placeholders: Optional[List["PlaceholderRange"]]
sampling_params: "SamplingParams"
mm_placeholders: Optional[List[PlaceholderRange]]
sampling_params: SamplingParams
eos_token_id: Optional[int]
arrival_time: float
lora_request: Optional["LoRARequest"]
lora_request: Optional[LoRARequest]
class EngineCoreOutput(
@ -94,16 +94,6 @@ class EngineCoreOutputs(
scheduler_stats: SchedulerStats
@dataclass
class EngineCoreProfile:
is_start: bool
@dataclass
class EngineCoreResetPrefixCache:
pass
class EngineCoreRequestType(enum.Enum):
"""
Request types defined as hex byte strings, so it can be sent over sockets
@ -113,7 +103,3 @@ class EngineCoreRequestType(enum.Enum):
ABORT = b'\x01'
PROFILE = b'\x02'
RESET_PREFIX_CACHE = b'\x03'
EngineCoreRequestUnion = Union[EngineCoreRequest, EngineCoreProfile,
EngineCoreResetPrefixCache, List[str]]

View File

@ -1,12 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
import pickle
import queue
import signal
import threading
import time
from multiprocessing.connection import Connection
from typing import List, Tuple, Type
from typing import Any, List, Tuple, Type
import psutil
import zmq
@ -19,13 +18,12 @@ from vllm.transformers_utils.config import (
from vllm.utils import get_exception_traceback, zmq_socket_ctx
from vllm.v1.core.kv_cache_utils import get_kv_cache_config
from vllm.v1.core.scheduler import Scheduler
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreProfile,
EngineCoreRequest, EngineCoreRequestType,
EngineCoreRequestUnion, EngineCoreResetPrefixCache)
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
EngineCoreRequestType)
from vllm.v1.engine.mm_input_mapper import MMInputMapperServer
from vllm.v1.executor.abstract import Executor
from vllm.v1.request import Request, RequestStatus
from vllm.v1.serial_utils import MsgpackEncoder, PickleEncoder
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__)
@ -161,7 +159,8 @@ class EngineCoreProc(EngineCore):
# and to overlap some serialization/deserialization with the
# model forward pass.
# Threads handle Socket <-> Queues and core_busy_loop uses Queue.
self.input_queue: queue.Queue[EngineCoreRequestUnion] = queue.Queue()
self.input_queue: queue.Queue[Tuple[EngineCoreRequestType,
Any]] = queue.Queue()
self.output_queue: queue.Queue[EngineCoreOutputs] = queue.Queue()
threading.Thread(target=self.process_input_socket,
args=(input_path, ),
@ -223,7 +222,7 @@ class EngineCoreProc(EngineCore):
while True:
try:
req = self.input_queue.get(timeout=POLLING_TIMEOUT_S)
self._handle_client_request(req)
self._handle_client_request(*req)
break
except queue.Empty:
logger.debug("EngineCore busy loop waiting.")
@ -233,10 +232,10 @@ class EngineCoreProc(EngineCore):
except BaseException:
raise
# 2) Handle any new client requests (Abort or Add).
# 2) Handle any new client requests.
while not self.input_queue.empty():
req = self.input_queue.get_nowait()
self._handle_client_request(req)
self._handle_client_request(*req)
# 3) Step the engine core.
outputs = self.step()
@ -244,48 +243,40 @@ class EngineCoreProc(EngineCore):
# 5) Put EngineCoreOutputs into the output queue.
self.output_queue.put_nowait(outputs)
def _handle_client_request(self, request_type: EngineCoreRequestType,
                           request: Any) -> None:
    """Dispatch one decoded client request to the matching engine operation.

    Args:
        request_type: Tag received alongside the payload; selects the
            operation to run.
        request: Decoded payload. It is forwarded to ``add_request`` for
            ADD, to ``abort_requests`` for ABORT and to
            ``model_executor.profile`` for PROFILE. It is not used for
            RESET_PREFIX_CACHE.
    """
    # NOTE: a request_type matching none of the branches is ignored.
    if request_type == EngineCoreRequestType.ADD:
        self.add_request(request)
    elif request_type == EngineCoreRequestType.ABORT:
        self.abort_requests(request)
    elif request_type == EngineCoreRequestType.RESET_PREFIX_CACHE:
        self.reset_prefix_cache()
    elif request_type == EngineCoreRequestType.PROFILE:
        self.model_executor.profile(request)
def process_input_socket(self, input_path: str):
    """Input socket IO thread.

    Receives ``(request type, request data)`` multipart messages from a
    PULL socket opened on ``input_path``, decodes the data frame with
    msgpack and puts ``(request_type, request)`` tuples on
    ``self.input_queue``. The loop has no exit condition of its own.

    Args:
        input_path: Address passed to ``zmq_socket_ctx`` for the PULL
            socket.
    """
    # Msgpack serialization decoding.
    # ADD payloads are decoded into EngineCoreRequest; every other request
    # type goes through the decoder constructed without a target type.
    add_request_decoder = MsgpackDecoder(EngineCoreRequest)
    generic_decoder = MsgpackDecoder()

    with zmq_socket_ctx(input_path, zmq.constants.PULL) as socket:
        while True:
            # (RequestType, RequestData)
            type_frame, data_frame = socket.recv_multipart(copy=False)
            # Enum lookup by value: raises ValueError if the type frame
            # does not match any EngineCoreRequestType member.
            request_type = EngineCoreRequestType(bytes(type_frame.buffer))

            # Deserialize the request data.
            if request_type == EngineCoreRequestType.ADD:
                decoder = add_request_decoder
            else:
                decoder = generic_decoder
            request = decoder.decode(data_frame.buffer)

            # Push to input queue for core busy loop.
            self.input_queue.put_nowait((request_type, request))
def process_output_socket(self, output_path: str):
"""Output socket IO thread."""

View File

@ -5,7 +5,7 @@ import os
import signal
import weakref
from abc import ABC, abstractmethod
from typing import List, Optional, Type
from typing import Any, List, Optional, Type
import zmq
import zmq.asyncio
@ -14,12 +14,11 @@ from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils import (get_open_zmq_ipc_path, kill_process_tree,
make_zmq_socket)
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreProfile,
EngineCoreRequest, EngineCoreRequestType,
EngineCoreRequestUnion, EngineCoreResetPrefixCache)
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
EngineCoreRequestType)
from vllm.v1.engine.core import EngineCore, EngineCoreProc
from vllm.v1.executor.abstract import Executor
from vllm.v1.serial_utils import MsgpackDecoder, PickleEncoder
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
from vllm.v1.utils import BackgroundProcHandle
logger = init_logger(__name__)
@ -161,7 +160,7 @@ class MPClient(EngineCoreClient):
signal.signal(signal.SIGUSR1, sigusr1_handler)
# Serialization setup.
self.encoder = PickleEncoder()
self.encoder = MsgpackEncoder()
self.decoder = MsgpackDecoder(EngineCoreOutputs)
# ZMQ setup.
@ -220,7 +219,7 @@ class SyncMPClient(MPClient):
return self.decoder.decode(frame.buffer)
def _send_input(self, request_type: EngineCoreRequestType,
request: EngineCoreRequestUnion) -> None:
request: Any) -> None:
# (RequestType, SerializedRequest)
msg = (request_type.value, self.encoder.encode(request))
@ -237,12 +236,10 @@ class SyncMPClient(MPClient):
self._send_input(EngineCoreRequestType.ABORT, request_ids)
def profile(self, is_start: bool = True) -> None:
    """Send a PROFILE request whose payload is the bare ``is_start`` flag."""
    self._send_input(EngineCoreRequestType.PROFILE, is_start)
def reset_prefix_cache(self) -> None:
    """Send a RESET_PREFIX_CACHE request; the payload is ``None``."""
    self._send_input(EngineCoreRequestType.RESET_PREFIX_CACHE, None)
class AsyncMPClient(MPClient):
@ -277,7 +274,7 @@ class AsyncMPClient(MPClient):
return self.decoder.decode(await self.outputs_queue.get())
async def _send_input(self, request_type: EngineCoreRequestType,
                      request: Any) -> None:
    """Encode ``request`` and send it with its type tag as one message.

    Args:
        request_type: Its ``value`` is sent as the first frame.
        request: Object passed to ``self.encoder.encode``; the result is
            sent as the second frame.
    """
    # (RequestType, SerializedRequest)
    msg = (request_type.value, self.encoder.encode(request))
    await self.input_socket.send_multipart(msg, copy=False)
@ -293,9 +290,7 @@ class AsyncMPClient(MPClient):
await self._send_input(EngineCoreRequestType.ABORT, request_ids)
async def profile_async(self, is_start: bool = True) -> None:
    """Send a PROFILE request whose payload is the bare ``is_start`` flag."""
    await self._send_input(EngineCoreRequestType.PROFILE, is_start)
async def reset_prefix_cache_async(self) -> None:
    """Send a RESET_PREFIX_CACHE request; the payload is ``None``."""
    await self._send_input(EngineCoreRequestType.RESET_PREFIX_CACHE, None)

View File

@ -1,21 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
import pickle
from typing import Any
from typing import Any, Optional
import torch
from msgspec import msgpack
CUSTOM_TYPE_CODE_PICKLE = 1
class PickleEncoder:
def encode(self, obj: Any):
return pickle.dumps(obj)
def decode(self, data: Any):
return pickle.loads(data)
CUSTOM_TYPE_TENSOR = 1
CUSTOM_TYPE_PICKLE = 2
class MsgpackEncoder:
@ -34,8 +26,9 @@ class MsgpackEncoder:
class MsgpackDecoder:
    """Decoder with custom torch tensor serialization."""

    def __init__(self, t: Optional[Any] = None):
        """Build the underlying ``msgpack.Decoder``.

        Args:
            t: Target type to decode into. When ``None`` no type argument
                is passed, so ``msgpack.Decoder`` uses its own default.
        """
        args = () if t is None else (t, )
        self.decoder = msgpack.Decoder(*args, ext_hook=custom_ext_hook)

    def decode(self, obj: Any):
        """Decode ``obj`` with the wrapped decoder and return the result."""
        return self.decoder.decode(obj)
@ -46,13 +39,15 @@ def custom_enc_hook(obj: Any) -> Any:
# NOTE(rob): it is fastest to use numpy + pickle
# when serializing torch tensors.
# https://gist.github.com/tlrmchlsmth/8067f1b24a82b6e2f90450e7764fa103 # noqa: E501
return msgpack.Ext(CUSTOM_TYPE_CODE_PICKLE, pickle.dumps(obj.numpy()))
return msgpack.Ext(CUSTOM_TYPE_TENSOR, pickle.dumps(obj.numpy()))
raise NotImplementedError(f"Objects of type {type(obj)} are not supported")
return msgpack.Ext(CUSTOM_TYPE_PICKLE, pickle.dumps(obj))
def custom_ext_hook(code: int, data: memoryview) -> Any:
    """Rebuild a Python object from a msgpack extension payload.

    Args:
        code: Extension type code attached to the payload.
        data: Pickled bytes carried by the extension.

    Returns:
        For ``CUSTOM_TYPE_TENSOR``, the unpickled numpy array wrapped with
        ``torch.from_numpy``; for ``CUSTOM_TYPE_PICKLE``, the unpickled
        object itself.

    Raises:
        NotImplementedError: If ``code`` is neither of the two known codes.
    """
    # NOTE(review): pickle.loads can execute arbitrary code while loading,
    # so this hook is only safe for payloads produced by a trusted peer --
    # confirm the transport is never exposed to untrusted senders.
    if code == CUSTOM_TYPE_TENSOR:
        return torch.from_numpy(pickle.loads(data))
    if code == CUSTOM_TYPE_PICKLE:
        return pickle.loads(data)

    raise NotImplementedError(f"Extension type code {code} is not supported")