mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-19 03:25:01 +08:00
[V1] Use msgpack for core request serialization (#12918)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
parent
aa0ca5ebb7
commit
67c4637ccf
@ -1,19 +1,16 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
import enum
|
import enum
|
||||||
from dataclasses import dataclass
|
from typing import List, Optional, Union
|
||||||
from typing import TYPE_CHECKING, List, Optional, Union
|
|
||||||
|
|
||||||
import msgspec
|
import msgspec
|
||||||
|
|
||||||
from vllm.v1.metrics.stats import SchedulerStats
|
|
||||||
from vllm.v1.outputs import LogprobsLists, LogprobsTensors
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
from vllm.multimodal import MultiModalKwargs
|
from vllm.multimodal import MultiModalKwargs
|
||||||
from vllm.multimodal.inputs import PlaceholderRange
|
from vllm.multimodal.inputs import PlaceholderRange
|
||||||
from vllm.sampling_params import SamplingParams
|
from vllm.sampling_params import SamplingParams
|
||||||
|
from vllm.v1.metrics.stats import SchedulerStats
|
||||||
|
from vllm.v1.outputs import LogprobsLists, LogprobsTensors
|
||||||
|
|
||||||
# These are possible values of RequestOutput.finish_reason,
|
# These are possible values of RequestOutput.finish_reason,
|
||||||
# so form part of the external API.
|
# so form part of the external API.
|
||||||
@ -39,8 +36,11 @@ class FinishReason(enum.IntEnum):
|
|||||||
return FINISH_REASON_STRINGS[self.value]
|
return FINISH_REASON_STRINGS[self.value]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
class EngineCoreRequest(
|
||||||
class EngineCoreRequest:
|
msgspec.Struct,
|
||||||
|
array_like=True, # type: ignore[call-arg]
|
||||||
|
omit_defaults=True, # type: ignore[call-arg]
|
||||||
|
gc=False): # type: ignore[call-arg]
|
||||||
|
|
||||||
# NOTE: prompt and prompt_token_ids should be DecoderOnlyInput,
|
# NOTE: prompt and prompt_token_ids should be DecoderOnlyInput,
|
||||||
# but this object is currently not playing well with msgspec
|
# but this object is currently not playing well with msgspec
|
||||||
@ -51,13 +51,13 @@ class EngineCoreRequest:
|
|||||||
# Detokenizer, but set to None when it is added to EngineCoreClient.
|
# Detokenizer, but set to None when it is added to EngineCoreClient.
|
||||||
prompt: Optional[str]
|
prompt: Optional[str]
|
||||||
prompt_token_ids: List[int]
|
prompt_token_ids: List[int]
|
||||||
mm_inputs: Optional[List[Optional["MultiModalKwargs"]]]
|
mm_inputs: Optional[List[Optional[MultiModalKwargs]]]
|
||||||
mm_hashes: Optional[List[str]]
|
mm_hashes: Optional[List[str]]
|
||||||
mm_placeholders: Optional[List["PlaceholderRange"]]
|
mm_placeholders: Optional[List[PlaceholderRange]]
|
||||||
sampling_params: "SamplingParams"
|
sampling_params: SamplingParams
|
||||||
eos_token_id: Optional[int]
|
eos_token_id: Optional[int]
|
||||||
arrival_time: float
|
arrival_time: float
|
||||||
lora_request: Optional["LoRARequest"]
|
lora_request: Optional[LoRARequest]
|
||||||
|
|
||||||
|
|
||||||
class EngineCoreOutput(
|
class EngineCoreOutput(
|
||||||
@ -94,16 +94,6 @@ class EngineCoreOutputs(
|
|||||||
scheduler_stats: SchedulerStats
|
scheduler_stats: SchedulerStats
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class EngineCoreProfile:
|
|
||||||
is_start: bool
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class EngineCoreResetPrefixCache:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class EngineCoreRequestType(enum.Enum):
|
class EngineCoreRequestType(enum.Enum):
|
||||||
"""
|
"""
|
||||||
Request types defined as hex byte strings, so it can be sent over sockets
|
Request types defined as hex byte strings, so it can be sent over sockets
|
||||||
@ -113,7 +103,3 @@ class EngineCoreRequestType(enum.Enum):
|
|||||||
ABORT = b'\x01'
|
ABORT = b'\x01'
|
||||||
PROFILE = b'\x02'
|
PROFILE = b'\x02'
|
||||||
RESET_PREFIX_CACHE = b'\x03'
|
RESET_PREFIX_CACHE = b'\x03'
|
||||||
|
|
||||||
|
|
||||||
EngineCoreRequestUnion = Union[EngineCoreRequest, EngineCoreProfile,
|
|
||||||
EngineCoreResetPrefixCache, List[str]]
|
|
||||||
|
|||||||
@ -1,12 +1,11 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
import pickle
|
|
||||||
import queue
|
import queue
|
||||||
import signal
|
import signal
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from multiprocessing.connection import Connection
|
from multiprocessing.connection import Connection
|
||||||
from typing import List, Tuple, Type
|
from typing import Any, List, Tuple, Type
|
||||||
|
|
||||||
import psutil
|
import psutil
|
||||||
import zmq
|
import zmq
|
||||||
@ -19,13 +18,12 @@ from vllm.transformers_utils.config import (
|
|||||||
from vllm.utils import get_exception_traceback, zmq_socket_ctx
|
from vllm.utils import get_exception_traceback, zmq_socket_ctx
|
||||||
from vllm.v1.core.kv_cache_utils import get_kv_cache_config
|
from vllm.v1.core.kv_cache_utils import get_kv_cache_config
|
||||||
from vllm.v1.core.scheduler import Scheduler
|
from vllm.v1.core.scheduler import Scheduler
|
||||||
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreProfile,
|
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
|
||||||
EngineCoreRequest, EngineCoreRequestType,
|
EngineCoreRequestType)
|
||||||
EngineCoreRequestUnion, EngineCoreResetPrefixCache)
|
|
||||||
from vllm.v1.engine.mm_input_mapper import MMInputMapperServer
|
from vllm.v1.engine.mm_input_mapper import MMInputMapperServer
|
||||||
from vllm.v1.executor.abstract import Executor
|
from vllm.v1.executor.abstract import Executor
|
||||||
from vllm.v1.request import Request, RequestStatus
|
from vllm.v1.request import Request, RequestStatus
|
||||||
from vllm.v1.serial_utils import MsgpackEncoder, PickleEncoder
|
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
|
||||||
from vllm.version import __version__ as VLLM_VERSION
|
from vllm.version import __version__ as VLLM_VERSION
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@ -161,7 +159,8 @@ class EngineCoreProc(EngineCore):
|
|||||||
# and to overlap some serialization/deserialization with the
|
# and to overlap some serialization/deserialization with the
|
||||||
# model forward pass.
|
# model forward pass.
|
||||||
# Threads handle Socket <-> Queues and core_busy_loop uses Queue.
|
# Threads handle Socket <-> Queues and core_busy_loop uses Queue.
|
||||||
self.input_queue: queue.Queue[EngineCoreRequestUnion] = queue.Queue()
|
self.input_queue: queue.Queue[Tuple[EngineCoreRequestType,
|
||||||
|
Any]] = queue.Queue()
|
||||||
self.output_queue: queue.Queue[EngineCoreOutputs] = queue.Queue()
|
self.output_queue: queue.Queue[EngineCoreOutputs] = queue.Queue()
|
||||||
threading.Thread(target=self.process_input_socket,
|
threading.Thread(target=self.process_input_socket,
|
||||||
args=(input_path, ),
|
args=(input_path, ),
|
||||||
@ -223,7 +222,7 @@ class EngineCoreProc(EngineCore):
|
|||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
req = self.input_queue.get(timeout=POLLING_TIMEOUT_S)
|
req = self.input_queue.get(timeout=POLLING_TIMEOUT_S)
|
||||||
self._handle_client_request(req)
|
self._handle_client_request(*req)
|
||||||
break
|
break
|
||||||
except queue.Empty:
|
except queue.Empty:
|
||||||
logger.debug("EngineCore busy loop waiting.")
|
logger.debug("EngineCore busy loop waiting.")
|
||||||
@ -233,10 +232,10 @@ class EngineCoreProc(EngineCore):
|
|||||||
except BaseException:
|
except BaseException:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
# 2) Handle any new client requests (Abort or Add).
|
# 2) Handle any new client requests.
|
||||||
while not self.input_queue.empty():
|
while not self.input_queue.empty():
|
||||||
req = self.input_queue.get_nowait()
|
req = self.input_queue.get_nowait()
|
||||||
self._handle_client_request(req)
|
self._handle_client_request(*req)
|
||||||
|
|
||||||
# 3) Step the engine core.
|
# 3) Step the engine core.
|
||||||
outputs = self.step()
|
outputs = self.step()
|
||||||
@ -244,48 +243,40 @@ class EngineCoreProc(EngineCore):
|
|||||||
# 5) Put EngineCoreOutputs into the output queue.
|
# 5) Put EngineCoreOutputs into the output queue.
|
||||||
self.output_queue.put_nowait(outputs)
|
self.output_queue.put_nowait(outputs)
|
||||||
|
|
||||||
def _handle_client_request(self, request: EngineCoreRequestUnion) -> None:
|
def _handle_client_request(self, request_type: EngineCoreRequestType,
|
||||||
"""Handle EngineCoreRequest or EngineCoreABORT from Client."""
|
request: Any) -> None:
|
||||||
|
"""Dispatch request from client."""
|
||||||
|
|
||||||
if isinstance(request, EngineCoreRequest):
|
if request_type == EngineCoreRequestType.ADD:
|
||||||
self.add_request(request)
|
self.add_request(request)
|
||||||
elif isinstance(request, EngineCoreProfile):
|
elif request_type == EngineCoreRequestType.ABORT:
|
||||||
self.model_executor.profile(request.is_start)
|
|
||||||
elif isinstance(request, EngineCoreResetPrefixCache):
|
|
||||||
self.reset_prefix_cache()
|
|
||||||
else:
|
|
||||||
# TODO: make an EngineCoreAbort wrapper
|
|
||||||
assert isinstance(request, list)
|
|
||||||
self.abort_requests(request)
|
self.abort_requests(request)
|
||||||
|
elif request_type == EngineCoreRequestType.RESET_PREFIX_CACHE:
|
||||||
|
self.reset_prefix_cache()
|
||||||
|
elif request_type == EngineCoreRequestType.PROFILE:
|
||||||
|
self.model_executor.profile(request)
|
||||||
|
|
||||||
def process_input_socket(self, input_path: str):
|
def process_input_socket(self, input_path: str):
|
||||||
"""Input socket IO thread."""
|
"""Input socket IO thread."""
|
||||||
|
|
||||||
# Msgpack serialization decoding.
|
# Msgpack serialization decoding.
|
||||||
decoder_add_req = PickleEncoder()
|
add_request_decoder = MsgpackDecoder(EngineCoreRequest)
|
||||||
decoder_abort_req = PickleEncoder()
|
generic_decoder = MsgpackDecoder()
|
||||||
|
|
||||||
with zmq_socket_ctx(input_path, zmq.constants.PULL) as socket:
|
with zmq_socket_ctx(input_path, zmq.constants.PULL) as socket:
|
||||||
while True:
|
while True:
|
||||||
# (RequestType, RequestData)
|
# (RequestType, RequestData)
|
||||||
type_frame, data_frame = socket.recv_multipart(copy=False)
|
type_frame, data_frame = socket.recv_multipart(copy=False)
|
||||||
request_type = type_frame.buffer
|
request_type = EngineCoreRequestType(bytes(type_frame.buffer))
|
||||||
request_data = data_frame.buffer
|
|
||||||
|
|
||||||
# Deserialize the request data.
|
# Deserialize the request data.
|
||||||
if request_type == EngineCoreRequestType.ADD.value:
|
decoder = add_request_decoder if (
|
||||||
request = decoder_add_req.decode(request_data)
|
request_type
|
||||||
elif request_type == EngineCoreRequestType.ABORT.value:
|
== EngineCoreRequestType.ADD) else generic_decoder
|
||||||
request = decoder_abort_req.decode(request_data)
|
request = decoder.decode(data_frame.buffer)
|
||||||
elif request_type in (
|
|
||||||
EngineCoreRequestType.PROFILE.value,
|
|
||||||
EngineCoreRequestType.RESET_PREFIX_CACHE.value):
|
|
||||||
request = pickle.loads(request_data)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unknown RequestType: {request_type}")
|
|
||||||
|
|
||||||
# Push to input queue for core busy loop.
|
# Push to input queue for core busy loop.
|
||||||
self.input_queue.put_nowait(request)
|
self.input_queue.put_nowait((request_type, request))
|
||||||
|
|
||||||
def process_output_socket(self, output_path: str):
|
def process_output_socket(self, output_path: str):
|
||||||
"""Output socket IO thread."""
|
"""Output socket IO thread."""
|
||||||
|
|||||||
@ -5,7 +5,7 @@ import os
|
|||||||
import signal
|
import signal
|
||||||
import weakref
|
import weakref
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import List, Optional, Type
|
from typing import Any, List, Optional, Type
|
||||||
|
|
||||||
import zmq
|
import zmq
|
||||||
import zmq.asyncio
|
import zmq.asyncio
|
||||||
@ -14,12 +14,11 @@ from vllm.config import VllmConfig
|
|||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.utils import (get_open_zmq_ipc_path, kill_process_tree,
|
from vllm.utils import (get_open_zmq_ipc_path, kill_process_tree,
|
||||||
make_zmq_socket)
|
make_zmq_socket)
|
||||||
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreProfile,
|
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
|
||||||
EngineCoreRequest, EngineCoreRequestType,
|
EngineCoreRequestType)
|
||||||
EngineCoreRequestUnion, EngineCoreResetPrefixCache)
|
|
||||||
from vllm.v1.engine.core import EngineCore, EngineCoreProc
|
from vllm.v1.engine.core import EngineCore, EngineCoreProc
|
||||||
from vllm.v1.executor.abstract import Executor
|
from vllm.v1.executor.abstract import Executor
|
||||||
from vllm.v1.serial_utils import MsgpackDecoder, PickleEncoder
|
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
|
||||||
from vllm.v1.utils import BackgroundProcHandle
|
from vllm.v1.utils import BackgroundProcHandle
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@ -161,7 +160,7 @@ class MPClient(EngineCoreClient):
|
|||||||
signal.signal(signal.SIGUSR1, sigusr1_handler)
|
signal.signal(signal.SIGUSR1, sigusr1_handler)
|
||||||
|
|
||||||
# Serialization setup.
|
# Serialization setup.
|
||||||
self.encoder = PickleEncoder()
|
self.encoder = MsgpackEncoder()
|
||||||
self.decoder = MsgpackDecoder(EngineCoreOutputs)
|
self.decoder = MsgpackDecoder(EngineCoreOutputs)
|
||||||
|
|
||||||
# ZMQ setup.
|
# ZMQ setup.
|
||||||
@ -220,7 +219,7 @@ class SyncMPClient(MPClient):
|
|||||||
return self.decoder.decode(frame.buffer)
|
return self.decoder.decode(frame.buffer)
|
||||||
|
|
||||||
def _send_input(self, request_type: EngineCoreRequestType,
|
def _send_input(self, request_type: EngineCoreRequestType,
|
||||||
request: EngineCoreRequestUnion) -> None:
|
request: Any) -> None:
|
||||||
|
|
||||||
# (RequestType, SerializedRequest)
|
# (RequestType, SerializedRequest)
|
||||||
msg = (request_type.value, self.encoder.encode(request))
|
msg = (request_type.value, self.encoder.encode(request))
|
||||||
@ -237,12 +236,10 @@ class SyncMPClient(MPClient):
|
|||||||
self._send_input(EngineCoreRequestType.ABORT, request_ids)
|
self._send_input(EngineCoreRequestType.ABORT, request_ids)
|
||||||
|
|
||||||
def profile(self, is_start: bool = True) -> None:
|
def profile(self, is_start: bool = True) -> None:
|
||||||
self._send_input(EngineCoreRequestType.PROFILE,
|
self._send_input(EngineCoreRequestType.PROFILE, is_start)
|
||||||
EngineCoreProfile(is_start))
|
|
||||||
|
|
||||||
def reset_prefix_cache(self) -> None:
|
def reset_prefix_cache(self) -> None:
|
||||||
self._send_input(EngineCoreRequestType.RESET_PREFIX_CACHE,
|
self._send_input(EngineCoreRequestType.RESET_PREFIX_CACHE, None)
|
||||||
EngineCoreResetPrefixCache())
|
|
||||||
|
|
||||||
|
|
||||||
class AsyncMPClient(MPClient):
|
class AsyncMPClient(MPClient):
|
||||||
@ -277,7 +274,7 @@ class AsyncMPClient(MPClient):
|
|||||||
return self.decoder.decode(await self.outputs_queue.get())
|
return self.decoder.decode(await self.outputs_queue.get())
|
||||||
|
|
||||||
async def _send_input(self, request_type: EngineCoreRequestType,
|
async def _send_input(self, request_type: EngineCoreRequestType,
|
||||||
request: EngineCoreRequestUnion) -> None:
|
request: Any) -> None:
|
||||||
|
|
||||||
msg = (request_type.value, self.encoder.encode(request))
|
msg = (request_type.value, self.encoder.encode(request))
|
||||||
await self.input_socket.send_multipart(msg, copy=False)
|
await self.input_socket.send_multipart(msg, copy=False)
|
||||||
@ -293,9 +290,7 @@ class AsyncMPClient(MPClient):
|
|||||||
await self._send_input(EngineCoreRequestType.ABORT, request_ids)
|
await self._send_input(EngineCoreRequestType.ABORT, request_ids)
|
||||||
|
|
||||||
async def profile_async(self, is_start: bool = True) -> None:
|
async def profile_async(self, is_start: bool = True) -> None:
|
||||||
await self._send_input(EngineCoreRequestType.PROFILE,
|
await self._send_input(EngineCoreRequestType.PROFILE, is_start)
|
||||||
EngineCoreProfile(is_start))
|
|
||||||
|
|
||||||
async def reset_prefix_cache_async(self) -> None:
|
async def reset_prefix_cache_async(self) -> None:
|
||||||
await self._send_input(EngineCoreRequestType.RESET_PREFIX_CACHE,
|
await self._send_input(EngineCoreRequestType.RESET_PREFIX_CACHE, None)
|
||||||
EngineCoreResetPrefixCache())
|
|
||||||
|
|||||||
@ -1,21 +1,13 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
import pickle
|
import pickle
|
||||||
from typing import Any
|
from typing import Any, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from msgspec import msgpack
|
from msgspec import msgpack
|
||||||
|
|
||||||
CUSTOM_TYPE_CODE_PICKLE = 1
|
CUSTOM_TYPE_TENSOR = 1
|
||||||
|
CUSTOM_TYPE_PICKLE = 2
|
||||||
|
|
||||||
class PickleEncoder:
|
|
||||||
|
|
||||||
def encode(self, obj: Any):
|
|
||||||
return pickle.dumps(obj)
|
|
||||||
|
|
||||||
def decode(self, data: Any):
|
|
||||||
return pickle.loads(data)
|
|
||||||
|
|
||||||
|
|
||||||
class MsgpackEncoder:
|
class MsgpackEncoder:
|
||||||
@ -34,8 +26,9 @@ class MsgpackEncoder:
|
|||||||
class MsgpackDecoder:
|
class MsgpackDecoder:
|
||||||
"""Decoder with custom torch tensor serialization."""
|
"""Decoder with custom torch tensor serialization."""
|
||||||
|
|
||||||
def __init__(self, t: Any):
|
def __init__(self, t: Optional[Any] = None):
|
||||||
self.decoder = msgpack.Decoder(t, ext_hook=custom_ext_hook)
|
args = () if t is None else (t, )
|
||||||
|
self.decoder = msgpack.Decoder(*args, ext_hook=custom_ext_hook)
|
||||||
|
|
||||||
def decode(self, obj: Any):
|
def decode(self, obj: Any):
|
||||||
return self.decoder.decode(obj)
|
return self.decoder.decode(obj)
|
||||||
@ -46,13 +39,15 @@ def custom_enc_hook(obj: Any) -> Any:
|
|||||||
# NOTE(rob): it is fastest to use numpy + pickle
|
# NOTE(rob): it is fastest to use numpy + pickle
|
||||||
# when serializing torch tensors.
|
# when serializing torch tensors.
|
||||||
# https://gist.github.com/tlrmchlsmth/8067f1b24a82b6e2f90450e7764fa103 # noqa: E501
|
# https://gist.github.com/tlrmchlsmth/8067f1b24a82b6e2f90450e7764fa103 # noqa: E501
|
||||||
return msgpack.Ext(CUSTOM_TYPE_CODE_PICKLE, pickle.dumps(obj.numpy()))
|
return msgpack.Ext(CUSTOM_TYPE_TENSOR, pickle.dumps(obj.numpy()))
|
||||||
|
|
||||||
raise NotImplementedError(f"Objects of type {type(obj)} are not supported")
|
return msgpack.Ext(CUSTOM_TYPE_PICKLE, pickle.dumps(obj))
|
||||||
|
|
||||||
|
|
||||||
def custom_ext_hook(code: int, data: memoryview) -> Any:
|
def custom_ext_hook(code: int, data: memoryview) -> Any:
|
||||||
if code == CUSTOM_TYPE_CODE_PICKLE:
|
if code == CUSTOM_TYPE_TENSOR:
|
||||||
return torch.from_numpy(pickle.loads(data))
|
return torch.from_numpy(pickle.loads(data))
|
||||||
|
if code == CUSTOM_TYPE_PICKLE:
|
||||||
|
return pickle.loads(data)
|
||||||
|
|
||||||
raise NotImplementedError(f"Extension type code {code} is not supported")
|
raise NotImplementedError(f"Extension type code {code} is not supported")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user