[V1] Use msgpack for core request serialization (#12918)

Signed-off-by: Nick Hill <nhill@redhat.com>
Author: Nick Hill, committed by GitHub
Date:   2025-02-09 19:35:56 -08:00
parent aa0ca5ebb7
commit 67c4637ccf
4 changed files with 62 additions and 95 deletions

File 1 of 4

@@ -1,20 +1,17 @@
 # SPDX-License-Identifier: Apache-2.0
 import enum
-from dataclasses import dataclass
-from typing import TYPE_CHECKING, List, Optional, Union
+from typing import List, Optional, Union
 
 import msgspec
 
+from vllm.lora.request import LoRARequest
+from vllm.multimodal import MultiModalKwargs
+from vllm.multimodal.inputs import PlaceholderRange
+from vllm.sampling_params import SamplingParams
 from vllm.v1.metrics.stats import SchedulerStats
 from vllm.v1.outputs import LogprobsLists, LogprobsTensors
 
-if TYPE_CHECKING:
-    from vllm.lora.request import LoRARequest
-    from vllm.multimodal import MultiModalKwargs
-    from vllm.multimodal.inputs import PlaceholderRange
-    from vllm.sampling_params import SamplingParams
 
 # These are possible values of RequestOutput.finish_reason,
 # so form part of the external API.
 FINISH_REASON_STRINGS = ("stop", "length", "abort")
@@ -39,8 +36,11 @@ class FinishReason(enum.IntEnum):
         return FINISH_REASON_STRINGS[self.value]
 
 
-@dataclass
-class EngineCoreRequest:
+class EngineCoreRequest(
+        msgspec.Struct,
+        array_like=True,  # type: ignore[call-arg]
+        omit_defaults=True,  # type: ignore[call-arg]
+        gc=False):  # type: ignore[call-arg]
 
     # NOTE: prompt and prompt_token_ids should be DecoderOnlyInput,
     # but this object is currently not playing well with msgspec
@@ -51,13 +51,13 @@ class EngineCoreRequest:
     # Detokenizer, but set to None when it is added to EngineCoreClient.
     prompt: Optional[str]
    prompt_token_ids: List[int]
-    mm_inputs: Optional[List[Optional["MultiModalKwargs"]]]
+    mm_inputs: Optional[List[Optional[MultiModalKwargs]]]
     mm_hashes: Optional[List[str]]
-    mm_placeholders: Optional[List["PlaceholderRange"]]
-    sampling_params: "SamplingParams"
+    mm_placeholders: Optional[List[PlaceholderRange]]
+    sampling_params: SamplingParams
     eos_token_id: Optional[int]
     arrival_time: float
-    lora_request: Optional["LoRARequest"]
+    lora_request: Optional[LoRARequest]
 
 
 class EngineCoreOutput(
@@ -94,16 +94,6 @@ class EngineCoreOutputs(
     scheduler_stats: SchedulerStats
 
 
-@dataclass
-class EngineCoreProfile:
-    is_start: bool
-
-
-@dataclass
-class EngineCoreResetPrefixCache:
-    pass
-
-
 class EngineCoreRequestType(enum.Enum):
     """
     Request types defined as hex byte strings, so it can be sent over sockets
@@ -113,7 +103,3 @@ class EngineCoreRequestType(enum.Enum):
     ABORT = b'\x01'
     PROFILE = b'\x02'
     RESET_PREFIX_CACHE = b'\x03'
-
-
-EngineCoreRequestUnion = Union[EngineCoreRequest, EngineCoreProfile,
-                               EngineCoreResetPrefixCache, List[str]]
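
Note on the change above: EngineCoreRequest switches from a @dataclass to a msgspec.Struct declared with array_like=True, omit_defaults=True and gc=False, so it can be serialized straight to msgpack. A minimal, self-contained sketch of what those options buy (the Point struct below is a made-up stand-in, not part of this diff):

import msgspec
from msgspec import msgpack


class Point(msgspec.Struct, array_like=True, omit_defaults=True, gc=False):
    # gc=False skips cyclic-GC tracking of instances, a small per-object win.
    x: int
    y: int
    label: str = ""


encoder = msgpack.Encoder()
decoder = msgpack.Decoder(Point)

data = encoder.encode(Point(1, 2))
# array_like=True encodes the fields positionally (as [1, 2]) rather than as
# a {"x": ..., "y": ...} map, and omit_defaults drops trailing default values,
# which keeps the wire format compact.
assert decoder.decode(data) == Point(1, 2)

Because array_like structs are encoded positionally, field order in EngineCoreRequest now matters for wire compatibility between the client and engine processes.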

File 2 of 4

@@ -1,12 +1,11 @@
 # SPDX-License-Identifier: Apache-2.0
-import pickle
 import queue
 import signal
 import threading
 import time
 from multiprocessing.connection import Connection
-from typing import List, Tuple, Type
+from typing import Any, List, Tuple, Type
 
 import psutil
 import zmq
@@ -19,13 +18,12 @@ from vllm.transformers_utils.config import (
 from vllm.utils import get_exception_traceback, zmq_socket_ctx
 from vllm.v1.core.kv_cache_utils import get_kv_cache_config
 from vllm.v1.core.scheduler import Scheduler
-from vllm.v1.engine import (EngineCoreOutputs, EngineCoreProfile,
-                            EngineCoreRequest, EngineCoreRequestType,
-                            EngineCoreRequestUnion, EngineCoreResetPrefixCache)
+from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
+                            EngineCoreRequestType)
 from vllm.v1.engine.mm_input_mapper import MMInputMapperServer
 from vllm.v1.executor.abstract import Executor
 from vllm.v1.request import Request, RequestStatus
-from vllm.v1.serial_utils import MsgpackEncoder, PickleEncoder
+from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
 from vllm.version import __version__ as VLLM_VERSION
 
 logger = init_logger(__name__)
@@ -161,7 +159,8 @@ class EngineCoreProc(EngineCore):
        # and to overlap some serialization/deserialization with the
        # model forward pass.
        # Threads handle Socket <-> Queues and core_busy_loop uses Queue.
-        self.input_queue: queue.Queue[EngineCoreRequestUnion] = queue.Queue()
+        self.input_queue: queue.Queue[Tuple[EngineCoreRequestType,
+                                            Any]] = queue.Queue()
         self.output_queue: queue.Queue[EngineCoreOutputs] = queue.Queue()
         threading.Thread(target=self.process_input_socket,
                          args=(input_path, ),
@@ -223,7 +222,7 @@ class EngineCoreProc(EngineCore):
                 while True:
                     try:
                         req = self.input_queue.get(timeout=POLLING_TIMEOUT_S)
-                        self._handle_client_request(req)
+                        self._handle_client_request(*req)
                         break
                     except queue.Empty:
                         logger.debug("EngineCore busy loop waiting.")
@@ -233,10 +232,10 @@
                     except BaseException:
                         raise
 
-            # 2) Handle any new client requests (Abort or Add).
+            # 2) Handle any new client requests.
             while not self.input_queue.empty():
                 req = self.input_queue.get_nowait()
-                self._handle_client_request(req)
+                self._handle_client_request(*req)
 
             # 3) Step the engine core.
             outputs = self.step()
@@ -244,48 +243,40 @@
             # 5) Put EngineCoreOutputs into the output queue.
             self.output_queue.put_nowait(outputs)
 
-    def _handle_client_request(self, request: EngineCoreRequestUnion) -> None:
-        """Handle EngineCoreRequest or EngineCoreABORT from Client."""
+    def _handle_client_request(self, request_type: EngineCoreRequestType,
+                               request: Any) -> None:
+        """Dispatch request from client."""
 
-        if isinstance(request, EngineCoreRequest):
+        if request_type == EngineCoreRequestType.ADD:
             self.add_request(request)
-        elif isinstance(request, EngineCoreProfile):
-            self.model_executor.profile(request.is_start)
-        elif isinstance(request, EngineCoreResetPrefixCache):
-            self.reset_prefix_cache()
-        else:
-            # TODO: make an EngineCoreAbort wrapper
-            assert isinstance(request, list)
+        elif request_type == EngineCoreRequestType.ABORT:
             self.abort_requests(request)
+        elif request_type == EngineCoreRequestType.RESET_PREFIX_CACHE:
+            self.reset_prefix_cache()
+        elif request_type == EngineCoreRequestType.PROFILE:
+            self.model_executor.profile(request)
 
     def process_input_socket(self, input_path: str):
         """Input socket IO thread."""
 
         # Msgpack serialization decoding.
-        decoder_add_req = PickleEncoder()
-        decoder_abort_req = PickleEncoder()
+        add_request_decoder = MsgpackDecoder(EngineCoreRequest)
+        generic_decoder = MsgpackDecoder()
 
         with zmq_socket_ctx(input_path, zmq.constants.PULL) as socket:
             while True:
                 # (RequestType, RequestData)
                 type_frame, data_frame = socket.recv_multipart(copy=False)
-                request_type = type_frame.buffer
-                request_data = data_frame.buffer
+                request_type = EngineCoreRequestType(bytes(type_frame.buffer))
 
                 # Deserialize the request data.
-                if request_type == EngineCoreRequestType.ADD.value:
-                    request = decoder_add_req.decode(request_data)
-                elif request_type == EngineCoreRequestType.ABORT.value:
-                    request = decoder_abort_req.decode(request_data)
-                elif request_type in (
-                        EngineCoreRequestType.PROFILE.value,
-                        EngineCoreRequestType.RESET_PREFIX_CACHE.value):
-                    request = pickle.loads(request_data)
-                else:
-                    raise ValueError(f"Unknown RequestType: {request_type}")
+                decoder = add_request_decoder if (
+                    request_type
+                    == EngineCoreRequestType.ADD) else generic_decoder
+                request = decoder.decode(data_frame.buffer)
 
                 # Push to input queue for core busy loop.
-                self.input_queue.put_nowait(request)
+                self.input_queue.put_nowait((request_type, request))
 
     def process_output_socket(self, output_path: str):
         """Output socket IO thread."""

File 3 of 4

@@ -5,7 +5,7 @@ import os
 import signal
 import weakref
 from abc import ABC, abstractmethod
-from typing import List, Optional, Type
+from typing import Any, List, Optional, Type
 
 import zmq
 import zmq.asyncio
@@ -14,12 +14,11 @@ from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.utils import (get_open_zmq_ipc_path, kill_process_tree,
                         make_zmq_socket)
-from vllm.v1.engine import (EngineCoreOutputs, EngineCoreProfile,
-                            EngineCoreRequest, EngineCoreRequestType,
-                            EngineCoreRequestUnion, EngineCoreResetPrefixCache)
+from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
+                            EngineCoreRequestType)
 from vllm.v1.engine.core import EngineCore, EngineCoreProc
 from vllm.v1.executor.abstract import Executor
-from vllm.v1.serial_utils import MsgpackDecoder, PickleEncoder
+from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
 from vllm.v1.utils import BackgroundProcHandle
 
 logger = init_logger(__name__)
@@ -161,7 +160,7 @@ class MPClient(EngineCoreClient):
         signal.signal(signal.SIGUSR1, sigusr1_handler)
 
         # Serialization setup.
-        self.encoder = PickleEncoder()
+        self.encoder = MsgpackEncoder()
         self.decoder = MsgpackDecoder(EngineCoreOutputs)
 
         # ZMQ setup.
@@ -220,7 +219,7 @@ class SyncMPClient(MPClient):
         return self.decoder.decode(frame.buffer)
 
     def _send_input(self, request_type: EngineCoreRequestType,
-                    request: EngineCoreRequestUnion) -> None:
+                    request: Any) -> None:
 
         # (RequestType, SerializedRequest)
         msg = (request_type.value, self.encoder.encode(request))
@@ -237,12 +236,10 @@
         self._send_input(EngineCoreRequestType.ABORT, request_ids)
 
     def profile(self, is_start: bool = True) -> None:
-        self._send_input(EngineCoreRequestType.PROFILE,
-                         EngineCoreProfile(is_start))
+        self._send_input(EngineCoreRequestType.PROFILE, is_start)
 
     def reset_prefix_cache(self) -> None:
-        self._send_input(EngineCoreRequestType.RESET_PREFIX_CACHE,
-                         EngineCoreResetPrefixCache())
+        self._send_input(EngineCoreRequestType.RESET_PREFIX_CACHE, None)
 
 
 class AsyncMPClient(MPClient):
@@ -277,7 +274,7 @@ class AsyncMPClient(MPClient):
         return self.decoder.decode(await self.outputs_queue.get())
 
     async def _send_input(self, request_type: EngineCoreRequestType,
-                          request: EngineCoreRequestUnion) -> None:
+                          request: Any) -> None:
 
         msg = (request_type.value, self.encoder.encode(request))
         await self.input_socket.send_multipart(msg, copy=False)
@@ -293,9 +290,7 @@
         await self._send_input(EngineCoreRequestType.ABORT, request_ids)
 
     async def profile_async(self, is_start: bool = True) -> None:
-        await self._send_input(EngineCoreRequestType.PROFILE,
-                               EngineCoreProfile(is_start))
+        await self._send_input(EngineCoreRequestType.PROFILE, is_start)
 
     async def reset_prefix_cache_async(self) -> None:
-        await self._send_input(EngineCoreRequestType.RESET_PREFIX_CACHE,
-                               EngineCoreResetPrefixCache())
+        await self._send_input(EngineCoreRequestType.RESET_PREFIX_CACHE, None)
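
Note on the client-side changes above: with the request-type byte identifying the operation, the EngineCoreProfile and EngineCoreResetPrefixCache wrappers are unnecessary; profile() now sends the bare bool and reset_prefix_cache() sends None. A quick sketch (assuming only that msgspec is installed) of why no wrapper class is needed: plain payloads round-trip through msgpack unchanged.

from msgspec import msgpack

enc = msgpack.Encoder()
dec = msgpack.Decoder()

# bool for PROFILE, None for RESET_PREFIX_CACHE, a list of ids for ABORT.
for payload in (True, None, ["req-1", "req-2"]):
    assert dec.decode(enc.encode(payload)) == payload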

File 4 of 4

@@ -1,21 +1,13 @@
 # SPDX-License-Identifier: Apache-2.0
 import pickle
-from typing import Any
+from typing import Any, Optional
 
 import torch
 from msgspec import msgpack
 
-CUSTOM_TYPE_CODE_PICKLE = 1
+CUSTOM_TYPE_TENSOR = 1
+CUSTOM_TYPE_PICKLE = 2
 
 
-class PickleEncoder:
-
-    def encode(self, obj: Any):
-        return pickle.dumps(obj)
-
-    def decode(self, data: Any):
-        return pickle.loads(data)
-
-
 class MsgpackEncoder:
@@ -34,8 +26,9 @@ class MsgpackEncoder:
 class MsgpackDecoder:
     """Decoder with custom torch tensor serialization."""
 
-    def __init__(self, t: Any):
-        self.decoder = msgpack.Decoder(t, ext_hook=custom_ext_hook)
+    def __init__(self, t: Optional[Any] = None):
+        args = () if t is None else (t, )
+        self.decoder = msgpack.Decoder(*args, ext_hook=custom_ext_hook)
 
     def decode(self, obj: Any):
         return self.decoder.decode(obj)
@@ -46,13 +39,15 @@ def custom_enc_hook(obj: Any) -> Any:
         # NOTE(rob): it is fastest to use numpy + pickle
         # when serializing torch tensors.
         # https://gist.github.com/tlrmchlsmth/8067f1b24a82b6e2f90450e7764fa103 # noqa: E501
-        return msgpack.Ext(CUSTOM_TYPE_CODE_PICKLE, pickle.dumps(obj.numpy()))
+        return msgpack.Ext(CUSTOM_TYPE_TENSOR, pickle.dumps(obj.numpy()))
 
-    raise NotImplementedError(f"Objects of type {type(obj)} are not supported")
+    return msgpack.Ext(CUSTOM_TYPE_PICKLE, pickle.dumps(obj))
 
 
 def custom_ext_hook(code: int, data: memoryview) -> Any:
-    if code == CUSTOM_TYPE_CODE_PICKLE:
+    if code == CUSTOM_TYPE_TENSOR:
         return torch.from_numpy(pickle.loads(data))
+    if code == CUSTOM_TYPE_PICKLE:
+        return pickle.loads(data)
 
     raise NotImplementedError(f"Extension type code {code} is not supported")