mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-16 09:56:09 +08:00
[V1] Use msgpack for core request serialization (#12918)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
parent
aa0ca5ebb7
commit
67c4637ccf
@ -1,20 +1,17 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import enum
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, List, Optional, Union
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import msgspec
|
||||
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal import MultiModalKwargs
|
||||
from vllm.multimodal.inputs import PlaceholderRange
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.metrics.stats import SchedulerStats
|
||||
from vllm.v1.outputs import LogprobsLists, LogprobsTensors
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal import MultiModalKwargs
|
||||
from vllm.multimodal.inputs import PlaceholderRange
|
||||
from vllm.sampling_params import SamplingParams
|
||||
|
||||
# These are possible values of RequestOutput.finish_reason,
|
||||
# so form part of the external API.
|
||||
FINISH_REASON_STRINGS = ("stop", "length", "abort")
|
||||
@ -39,8 +36,11 @@ class FinishReason(enum.IntEnum):
|
||||
return FINISH_REASON_STRINGS[self.value]
|
||||
|
||||
|
||||
@dataclass
|
||||
class EngineCoreRequest:
|
||||
class EngineCoreRequest(
|
||||
msgspec.Struct,
|
||||
array_like=True, # type: ignore[call-arg]
|
||||
omit_defaults=True, # type: ignore[call-arg]
|
||||
gc=False): # type: ignore[call-arg]
|
||||
|
||||
# NOTE: prompt and prompt_token_ids should be DecoderOnlyInput,
|
||||
# but this object is currently not playing well with msgspec
|
||||
@ -51,13 +51,13 @@ class EngineCoreRequest:
|
||||
# Detokenizer, but set to None when it is added to EngineCoreClient.
|
||||
prompt: Optional[str]
|
||||
prompt_token_ids: List[int]
|
||||
mm_inputs: Optional[List[Optional["MultiModalKwargs"]]]
|
||||
mm_inputs: Optional[List[Optional[MultiModalKwargs]]]
|
||||
mm_hashes: Optional[List[str]]
|
||||
mm_placeholders: Optional[List["PlaceholderRange"]]
|
||||
sampling_params: "SamplingParams"
|
||||
mm_placeholders: Optional[List[PlaceholderRange]]
|
||||
sampling_params: SamplingParams
|
||||
eos_token_id: Optional[int]
|
||||
arrival_time: float
|
||||
lora_request: Optional["LoRARequest"]
|
||||
lora_request: Optional[LoRARequest]
|
||||
|
||||
|
||||
class EngineCoreOutput(
|
||||
@ -94,16 +94,6 @@ class EngineCoreOutputs(
|
||||
scheduler_stats: SchedulerStats
|
||||
|
||||
|
||||
@dataclass
|
||||
class EngineCoreProfile:
|
||||
is_start: bool
|
||||
|
||||
|
||||
@dataclass
|
||||
class EngineCoreResetPrefixCache:
|
||||
pass
|
||||
|
||||
|
||||
class EngineCoreRequestType(enum.Enum):
|
||||
"""
|
||||
Request types defined as hex byte strings, so it can be sent over sockets
|
||||
@ -113,7 +103,3 @@ class EngineCoreRequestType(enum.Enum):
|
||||
ABORT = b'\x01'
|
||||
PROFILE = b'\x02'
|
||||
RESET_PREFIX_CACHE = b'\x03'
|
||||
|
||||
|
||||
EngineCoreRequestUnion = Union[EngineCoreRequest, EngineCoreProfile,
|
||||
EngineCoreResetPrefixCache, List[str]]
|
||||
|
||||
@ -1,12 +1,11 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import pickle
|
||||
import queue
|
||||
import signal
|
||||
import threading
|
||||
import time
|
||||
from multiprocessing.connection import Connection
|
||||
from typing import List, Tuple, Type
|
||||
from typing import Any, List, Tuple, Type
|
||||
|
||||
import psutil
|
||||
import zmq
|
||||
@ -19,13 +18,12 @@ from vllm.transformers_utils.config import (
|
||||
from vllm.utils import get_exception_traceback, zmq_socket_ctx
|
||||
from vllm.v1.core.kv_cache_utils import get_kv_cache_config
|
||||
from vllm.v1.core.scheduler import Scheduler
|
||||
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreProfile,
|
||||
EngineCoreRequest, EngineCoreRequestType,
|
||||
EngineCoreRequestUnion, EngineCoreResetPrefixCache)
|
||||
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
|
||||
EngineCoreRequestType)
|
||||
from vllm.v1.engine.mm_input_mapper import MMInputMapperServer
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
from vllm.v1.serial_utils import MsgpackEncoder, PickleEncoder
|
||||
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -161,7 +159,8 @@ class EngineCoreProc(EngineCore):
|
||||
# and to overlap some serialization/deserialization with the
|
||||
# model forward pass.
|
||||
# Threads handle Socket <-> Queues and core_busy_loop uses Queue.
|
||||
self.input_queue: queue.Queue[EngineCoreRequestUnion] = queue.Queue()
|
||||
self.input_queue: queue.Queue[Tuple[EngineCoreRequestType,
|
||||
Any]] = queue.Queue()
|
||||
self.output_queue: queue.Queue[EngineCoreOutputs] = queue.Queue()
|
||||
threading.Thread(target=self.process_input_socket,
|
||||
args=(input_path, ),
|
||||
@ -223,7 +222,7 @@ class EngineCoreProc(EngineCore):
|
||||
while True:
|
||||
try:
|
||||
req = self.input_queue.get(timeout=POLLING_TIMEOUT_S)
|
||||
self._handle_client_request(req)
|
||||
self._handle_client_request(*req)
|
||||
break
|
||||
except queue.Empty:
|
||||
logger.debug("EngineCore busy loop waiting.")
|
||||
@ -233,10 +232,10 @@ class EngineCoreProc(EngineCore):
|
||||
except BaseException:
|
||||
raise
|
||||
|
||||
# 2) Handle any new client requests (Abort or Add).
|
||||
# 2) Handle any new client requests.
|
||||
while not self.input_queue.empty():
|
||||
req = self.input_queue.get_nowait()
|
||||
self._handle_client_request(req)
|
||||
self._handle_client_request(*req)
|
||||
|
||||
# 3) Step the engine core.
|
||||
outputs = self.step()
|
||||
@ -244,48 +243,40 @@ class EngineCoreProc(EngineCore):
|
||||
# 5) Put EngineCoreOutputs into the output queue.
|
||||
self.output_queue.put_nowait(outputs)
|
||||
|
||||
def _handle_client_request(self, request: EngineCoreRequestUnion) -> None:
|
||||
"""Handle EngineCoreRequest or EngineCoreABORT from Client."""
|
||||
def _handle_client_request(self, request_type: EngineCoreRequestType,
|
||||
request: Any) -> None:
|
||||
"""Dispatch request from client."""
|
||||
|
||||
if isinstance(request, EngineCoreRequest):
|
||||
if request_type == EngineCoreRequestType.ADD:
|
||||
self.add_request(request)
|
||||
elif isinstance(request, EngineCoreProfile):
|
||||
self.model_executor.profile(request.is_start)
|
||||
elif isinstance(request, EngineCoreResetPrefixCache):
|
||||
self.reset_prefix_cache()
|
||||
else:
|
||||
# TODO: make an EngineCoreAbort wrapper
|
||||
assert isinstance(request, list)
|
||||
elif request_type == EngineCoreRequestType.ABORT:
|
||||
self.abort_requests(request)
|
||||
elif request_type == EngineCoreRequestType.RESET_PREFIX_CACHE:
|
||||
self.reset_prefix_cache()
|
||||
elif request_type == EngineCoreRequestType.PROFILE:
|
||||
self.model_executor.profile(request)
|
||||
|
||||
def process_input_socket(self, input_path: str):
|
||||
"""Input socket IO thread."""
|
||||
|
||||
# Msgpack serialization decoding.
|
||||
decoder_add_req = PickleEncoder()
|
||||
decoder_abort_req = PickleEncoder()
|
||||
add_request_decoder = MsgpackDecoder(EngineCoreRequest)
|
||||
generic_decoder = MsgpackDecoder()
|
||||
|
||||
with zmq_socket_ctx(input_path, zmq.constants.PULL) as socket:
|
||||
while True:
|
||||
# (RequestType, RequestData)
|
||||
type_frame, data_frame = socket.recv_multipart(copy=False)
|
||||
request_type = type_frame.buffer
|
||||
request_data = data_frame.buffer
|
||||
request_type = EngineCoreRequestType(bytes(type_frame.buffer))
|
||||
|
||||
# Deserialize the request data.
|
||||
if request_type == EngineCoreRequestType.ADD.value:
|
||||
request = decoder_add_req.decode(request_data)
|
||||
elif request_type == EngineCoreRequestType.ABORT.value:
|
||||
request = decoder_abort_req.decode(request_data)
|
||||
elif request_type in (
|
||||
EngineCoreRequestType.PROFILE.value,
|
||||
EngineCoreRequestType.RESET_PREFIX_CACHE.value):
|
||||
request = pickle.loads(request_data)
|
||||
else:
|
||||
raise ValueError(f"Unknown RequestType: {request_type}")
|
||||
decoder = add_request_decoder if (
|
||||
request_type
|
||||
== EngineCoreRequestType.ADD) else generic_decoder
|
||||
request = decoder.decode(data_frame.buffer)
|
||||
|
||||
# Push to input queue for core busy loop.
|
||||
self.input_queue.put_nowait(request)
|
||||
self.input_queue.put_nowait((request_type, request))
|
||||
|
||||
def process_output_socket(self, output_path: str):
|
||||
"""Output socket IO thread."""
|
||||
|
||||
@ -5,7 +5,7 @@ import os
|
||||
import signal
|
||||
import weakref
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional, Type
|
||||
from typing import Any, List, Optional, Type
|
||||
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
@ -14,12 +14,11 @@ from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import (get_open_zmq_ipc_path, kill_process_tree,
|
||||
make_zmq_socket)
|
||||
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreProfile,
|
||||
EngineCoreRequest, EngineCoreRequestType,
|
||||
EngineCoreRequestUnion, EngineCoreResetPrefixCache)
|
||||
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
|
||||
EngineCoreRequestType)
|
||||
from vllm.v1.engine.core import EngineCore, EngineCoreProc
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
from vllm.v1.serial_utils import MsgpackDecoder, PickleEncoder
|
||||
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
|
||||
from vllm.v1.utils import BackgroundProcHandle
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -161,7 +160,7 @@ class MPClient(EngineCoreClient):
|
||||
signal.signal(signal.SIGUSR1, sigusr1_handler)
|
||||
|
||||
# Serialization setup.
|
||||
self.encoder = PickleEncoder()
|
||||
self.encoder = MsgpackEncoder()
|
||||
self.decoder = MsgpackDecoder(EngineCoreOutputs)
|
||||
|
||||
# ZMQ setup.
|
||||
@ -220,7 +219,7 @@ class SyncMPClient(MPClient):
|
||||
return self.decoder.decode(frame.buffer)
|
||||
|
||||
def _send_input(self, request_type: EngineCoreRequestType,
|
||||
request: EngineCoreRequestUnion) -> None:
|
||||
request: Any) -> None:
|
||||
|
||||
# (RequestType, SerializedRequest)
|
||||
msg = (request_type.value, self.encoder.encode(request))
|
||||
@ -237,12 +236,10 @@ class SyncMPClient(MPClient):
|
||||
self._send_input(EngineCoreRequestType.ABORT, request_ids)
|
||||
|
||||
def profile(self, is_start: bool = True) -> None:
|
||||
self._send_input(EngineCoreRequestType.PROFILE,
|
||||
EngineCoreProfile(is_start))
|
||||
self._send_input(EngineCoreRequestType.PROFILE, is_start)
|
||||
|
||||
def reset_prefix_cache(self) -> None:
|
||||
self._send_input(EngineCoreRequestType.RESET_PREFIX_CACHE,
|
||||
EngineCoreResetPrefixCache())
|
||||
self._send_input(EngineCoreRequestType.RESET_PREFIX_CACHE, None)
|
||||
|
||||
|
||||
class AsyncMPClient(MPClient):
|
||||
@ -277,7 +274,7 @@ class AsyncMPClient(MPClient):
|
||||
return self.decoder.decode(await self.outputs_queue.get())
|
||||
|
||||
async def _send_input(self, request_type: EngineCoreRequestType,
|
||||
request: EngineCoreRequestUnion) -> None:
|
||||
request: Any) -> None:
|
||||
|
||||
msg = (request_type.value, self.encoder.encode(request))
|
||||
await self.input_socket.send_multipart(msg, copy=False)
|
||||
@ -293,9 +290,7 @@ class AsyncMPClient(MPClient):
|
||||
await self._send_input(EngineCoreRequestType.ABORT, request_ids)
|
||||
|
||||
async def profile_async(self, is_start: bool = True) -> None:
|
||||
await self._send_input(EngineCoreRequestType.PROFILE,
|
||||
EngineCoreProfile(is_start))
|
||||
await self._send_input(EngineCoreRequestType.PROFILE, is_start)
|
||||
|
||||
async def reset_prefix_cache_async(self) -> None:
|
||||
await self._send_input(EngineCoreRequestType.RESET_PREFIX_CACHE,
|
||||
EngineCoreResetPrefixCache())
|
||||
await self._send_input(EngineCoreRequestType.RESET_PREFIX_CACHE, None)
|
||||
|
||||
@ -1,21 +1,13 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import pickle
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from msgspec import msgpack
|
||||
|
||||
CUSTOM_TYPE_CODE_PICKLE = 1
|
||||
|
||||
|
||||
class PickleEncoder:
|
||||
|
||||
def encode(self, obj: Any):
|
||||
return pickle.dumps(obj)
|
||||
|
||||
def decode(self, data: Any):
|
||||
return pickle.loads(data)
|
||||
CUSTOM_TYPE_TENSOR = 1
|
||||
CUSTOM_TYPE_PICKLE = 2
|
||||
|
||||
|
||||
class MsgpackEncoder:
|
||||
@ -34,8 +26,9 @@ class MsgpackEncoder:
|
||||
class MsgpackDecoder:
|
||||
"""Decoder with custom torch tensor serialization."""
|
||||
|
||||
def __init__(self, t: Any):
|
||||
self.decoder = msgpack.Decoder(t, ext_hook=custom_ext_hook)
|
||||
def __init__(self, t: Optional[Any] = None):
|
||||
args = () if t is None else (t, )
|
||||
self.decoder = msgpack.Decoder(*args, ext_hook=custom_ext_hook)
|
||||
|
||||
def decode(self, obj: Any):
|
||||
return self.decoder.decode(obj)
|
||||
@ -46,13 +39,15 @@ def custom_enc_hook(obj: Any) -> Any:
|
||||
# NOTE(rob): it is fastest to use numpy + pickle
|
||||
# when serializing torch tensors.
|
||||
# https://gist.github.com/tlrmchlsmth/8067f1b24a82b6e2f90450e7764fa103 # noqa: E501
|
||||
return msgpack.Ext(CUSTOM_TYPE_CODE_PICKLE, pickle.dumps(obj.numpy()))
|
||||
return msgpack.Ext(CUSTOM_TYPE_TENSOR, pickle.dumps(obj.numpy()))
|
||||
|
||||
raise NotImplementedError(f"Objects of type {type(obj)} are not supported")
|
||||
return msgpack.Ext(CUSTOM_TYPE_PICKLE, pickle.dumps(obj))
|
||||
|
||||
|
||||
def custom_ext_hook(code: int, data: memoryview) -> Any:
|
||||
if code == CUSTOM_TYPE_CODE_PICKLE:
|
||||
if code == CUSTOM_TYPE_TENSOR:
|
||||
return torch.from_numpy(pickle.loads(data))
|
||||
if code == CUSTOM_TYPE_PICKLE:
|
||||
return pickle.loads(data)
|
||||
|
||||
raise NotImplementedError(f"Extension type code {code} is not supported")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user