diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py
index b05ef3cc8c74..30e1185019d9 100644
--- a/vllm/v1/engine/__init__.py
+++ b/vllm/v1/engine/__init__.py
@@ -1,20 +1,17 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import enum
-from dataclasses import dataclass
-from typing import TYPE_CHECKING, List, Optional, Union
+from typing import List, Optional, Union
 
 import msgspec
 
+from vllm.lora.request import LoRARequest
+from vllm.multimodal import MultiModalKwargs
+from vllm.multimodal.inputs import PlaceholderRange
+from vllm.sampling_params import SamplingParams
 from vllm.v1.metrics.stats import SchedulerStats
 from vllm.v1.outputs import LogprobsLists, LogprobsTensors
 
-if TYPE_CHECKING:
-    from vllm.lora.request import LoRARequest
-    from vllm.multimodal import MultiModalKwargs
-    from vllm.multimodal.inputs import PlaceholderRange
-    from vllm.sampling_params import SamplingParams
-
 # These are possible values of RequestOutput.finish_reason,
 # so form part of the external API.
 FINISH_REASON_STRINGS = ("stop", "length", "abort")
@@ -39,8 +36,11 @@ class FinishReason(enum.IntEnum):
         return FINISH_REASON_STRINGS[self.value]
 
 
-@dataclass
-class EngineCoreRequest:
+class EngineCoreRequest(
+        msgspec.Struct,
+        array_like=True,  # type: ignore[call-arg]
+        omit_defaults=True,  # type: ignore[call-arg]
+        gc=False):  # type: ignore[call-arg]
 
     # NOTE: prompt and prompt_token_ids should be DecoderOnlyInput,
     # but this object is currently not playing well with msgspec
@@ -51,13 +51,13 @@ class EngineCoreRequest:
     # Detokenizer, but set to None when it is added to EngineCoreClient.
     prompt: Optional[str]
    prompt_token_ids: List[int]
-    mm_inputs: Optional[List[Optional["MultiModalKwargs"]]]
+    mm_inputs: Optional[List[Optional[MultiModalKwargs]]]
     mm_hashes: Optional[List[str]]
-    mm_placeholders: Optional[List["PlaceholderRange"]]
-    sampling_params: "SamplingParams"
+    mm_placeholders: Optional[List[PlaceholderRange]]
+    sampling_params: SamplingParams
     eos_token_id: Optional[int]
     arrival_time: float
-    lora_request: Optional["LoRARequest"]
+    lora_request: Optional[LoRARequest]
 
 
 class EngineCoreOutput(
@@ -94,16 +94,6 @@ class EngineCoreOutputs(
     scheduler_stats: SchedulerStats
 
 
-@dataclass
-class EngineCoreProfile:
-    is_start: bool
-
-
-@dataclass
-class EngineCoreResetPrefixCache:
-    pass
-
-
 class EngineCoreRequestType(enum.Enum):
     """
     Request types defined as hex byte strings, so it can be sent over sockets
@@ -113,7 +103,3 @@ class EngineCoreRequestType(enum.Enum):
     ABORT = b'\x01'
     PROFILE = b'\x02'
     RESET_PREFIX_CACHE = b'\x03'
-
-
-EngineCoreRequestUnion = Union[EngineCoreRequest, EngineCoreProfile,
-                               EngineCoreRequestUnion, List[str]]
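Note on the __init__.py section above: EngineCoreRequest moves from @dataclass to msgspec.Struct so it can be msgpack-encoded directly, which is also why the TYPE_CHECKING-only imports become real runtime imports (msgspec needs the concrete field types to decode). A minimal standalone sketch, not vLLM code (the Point struct is hypothetical), of what the three struct options buy:

import msgspec


class Point(msgspec.Struct, array_like=True, omit_defaults=True, gc=False):
    x: int
    y: int = 0


# array_like=True encodes fields positionally: Point(1) becomes the
# msgpack array [1] (y is a trailing default, dropped by omit_defaults)
# rather than the map {"x": 1, "y": 0} a keyed encoding would produce.
buf = msgspec.msgpack.encode(Point(1))
print(msgspec.msgpack.decode(buf, type=Point))  # Point(x=1, y=0)

A plausible reading of gc=False: request structs hold no reference cycles, so exempting instances from the cyclic garbage collector saves per-request tracking overhead for free.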
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index f3d40aa1e9cb..c90667ba0331 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -1,12 +1,11 @@
 # SPDX-License-Identifier: Apache-2.0
 
-import pickle
 import queue
 import signal
 import threading
 import time
 from multiprocessing.connection import Connection
-from typing import List, Tuple, Type
+from typing import Any, List, Tuple, Type
 
 import psutil
 import zmq
@@ -19,13 +18,12 @@ from vllm.transformers_utils.config import (
 from vllm.utils import get_exception_traceback, zmq_socket_ctx
 from vllm.v1.core.kv_cache_utils import get_kv_cache_config
 from vllm.v1.core.scheduler import Scheduler
-from vllm.v1.engine import (EngineCoreOutputs, EngineCoreProfile,
-                            EngineCoreRequest, EngineCoreRequestType,
-                            EngineCoreRequestUnion, EngineCoreResetPrefixCache)
+from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
+                            EngineCoreRequestType)
 from vllm.v1.engine.mm_input_mapper import MMInputMapperServer
 from vllm.v1.executor.abstract import Executor
 from vllm.v1.request import Request, RequestStatus
-from vllm.v1.serial_utils import MsgpackEncoder, PickleEncoder
+from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
 from vllm.version import __version__ as VLLM_VERSION
 
 logger = init_logger(__name__)
@@ -161,7 +159,8 @@ class EngineCoreProc(EngineCore):
         # and to overlap some serialization/deserialization with the
         # model forward pass.
         # Threads handle Socket <-> Queues and core_busy_loop uses Queue.
-        self.input_queue: queue.Queue[EngineCoreRequestUnion] = queue.Queue()
+        self.input_queue: queue.Queue[Tuple[EngineCoreRequestType,
+                                            Any]] = queue.Queue()
         self.output_queue: queue.Queue[EngineCoreOutputs] = queue.Queue()
         threading.Thread(target=self.process_input_socket,
                          args=(input_path, ),
@@ -223,7 +222,7 @@ class EngineCoreProc(EngineCore):
             while True:
                 try:
                     req = self.input_queue.get(timeout=POLLING_TIMEOUT_S)
-                    self._handle_client_request(req)
+                    self._handle_client_request(*req)
                     break
                 except queue.Empty:
                     logger.debug("EngineCore busy loop waiting.")
@@ -233,10 +232,10 @@ class EngineCoreProc(EngineCore):
             except BaseException:
                 raise
 
-            # 2) Handle any new client requests (Abort or Add).
+            # 2) Handle any new client requests.
             while not self.input_queue.empty():
                 req = self.input_queue.get_nowait()
-                self._handle_client_request(req)
+                self._handle_client_request(*req)
 
             # 3) Step the engine core.
             outputs = self.step()
@@ -244,48 +243,40 @@ class EngineCoreProc(EngineCore):
             # 5) Put EngineCoreOutputs into the output queue.
             self.output_queue.put_nowait(outputs)
 
-    def _handle_client_request(self, request: EngineCoreRequestUnion) -> None:
-        """Handle EngineCoreRequest or EngineCoreABORT from Client."""
+    def _handle_client_request(self, request_type: EngineCoreRequestType,
+                               request: Any) -> None:
+        """Dispatch request from client."""
 
-        if isinstance(request, EngineCoreRequest):
+        if request_type == EngineCoreRequestType.ADD:
             self.add_request(request)
-        elif isinstance(request, EngineCoreProfile):
-            self.model_executor.profile(request.is_start)
-        elif isinstance(request, EngineCoreResetPrefixCache):
-            self.reset_prefix_cache()
-        else:
-            # TODO: make an EngineCoreAbort wrapper
-            assert isinstance(request, list)
+        elif request_type == EngineCoreRequestType.ABORT:
             self.abort_requests(request)
+        elif request_type == EngineCoreRequestType.RESET_PREFIX_CACHE:
+            self.reset_prefix_cache()
+        elif request_type == EngineCoreRequestType.PROFILE:
+            self.model_executor.profile(request)
 
     def process_input_socket(self, input_path: str):
         """Input socket IO thread."""
 
         # Msgpack serialization decoding.
-        decoder_add_req = PickleEncoder()
-        decoder_abort_req = PickleEncoder()
+        add_request_decoder = MsgpackDecoder(EngineCoreRequest)
+        generic_decoder = MsgpackDecoder()
 
         with zmq_socket_ctx(input_path, zmq.constants.PULL) as socket:
             while True:
                 # (RequestType, RequestData)
                 type_frame, data_frame = socket.recv_multipart(copy=False)
-                request_type = type_frame.buffer
-                request_data = data_frame.buffer
+                request_type = EngineCoreRequestType(bytes(type_frame.buffer))
 
                 # Deserialize the request data.
-                if request_type == EngineCoreRequestType.ADD.value:
-                    request = decoder_add_req.decode(request_data)
-                elif request_type == EngineCoreRequestType.ABORT.value:
-                    request = decoder_abort_req.decode(request_data)
-                elif request_type in (
-                        EngineCoreRequestType.PROFILE.value,
-                        EngineCoreRequestType.RESET_PREFIX_CACHE.value):
-                    request = pickle.loads(request_data)
-                else:
-                    raise ValueError(f"Unknown RequestType: {request_type}")
+                decoder = add_request_decoder if (
+                    request_type
+                    == EngineCoreRequestType.ADD) else generic_decoder
+                request = decoder.decode(data_frame.buffer)
 
                 # Push to input queue for core busy loop.
-                self.input_queue.put_nowait(request)
+                self.input_queue.put_nowait((request_type, request))
 
     def process_output_socket(self, output_path: str):
         """Output socket IO thread."""
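The core.py section above carries (request_type, payload) tuples through input_queue and picks a decoder per type: ADD payloads are full EngineCoreRequest structs and need a typed decoder, while the rest (abort id lists, the profile flag, None) decode generically. A small sketch of that typed-vs-untyped distinction using plain msgspec (AddRequest is an illustrative stand-in, not the real struct):

from typing import List

import msgspec


class AddRequest(msgspec.Struct, array_like=True):
    request_id: str
    prompt_token_ids: List[int]


typed = msgspec.msgpack.Decoder(AddRequest)  # cf. MsgpackDecoder(EngineCoreRequest)
untyped = msgspec.msgpack.Decoder()          # cf. MsgpackDecoder()

buf = msgspec.msgpack.encode(AddRequest("req-0", [1, 2, 3]))
# The typed decoder rebuilds the struct; the untyped one only sees the
# positional array that array_like=True produced.
print(typed.decode(buf))    # AddRequest(request_id='req-0', prompt_token_ids=[1, 2, 3])
print(untyped.decode(buf))  # ['req-0', [1, 2, 3]]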
diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py
index cdc63acdb746..2d7d6b42ced5 100644
--- a/vllm/v1/engine/core_client.py
+++ b/vllm/v1/engine/core_client.py
@@ -5,7 +5,7 @@ import os
 import signal
 import weakref
 from abc import ABC, abstractmethod
-from typing import List, Optional, Type
+from typing import Any, List, Optional, Type
 
 import zmq
 import zmq.asyncio
@@ -14,12 +14,11 @@ from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.utils import (get_open_zmq_ipc_path, kill_process_tree,
                         make_zmq_socket)
-from vllm.v1.engine import (EngineCoreOutputs, EngineCoreProfile,
-                            EngineCoreRequest, EngineCoreRequestType,
-                            EngineCoreRequestUnion, EngineCoreResetPrefixCache)
+from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
+                            EngineCoreRequestType)
 from vllm.v1.engine.core import EngineCore, EngineCoreProc
 from vllm.v1.executor.abstract import Executor
-from vllm.v1.serial_utils import MsgpackDecoder, PickleEncoder
+from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
 from vllm.v1.utils import BackgroundProcHandle
 
 logger = init_logger(__name__)
@@ -161,7 +160,7 @@ class MPClient(EngineCoreClient):
             signal.signal(signal.SIGUSR1, sigusr1_handler)
 
         # Serialization setup.
-        self.encoder = PickleEncoder()
+        self.encoder = MsgpackEncoder()
         self.decoder = MsgpackDecoder(EngineCoreOutputs)
 
         # ZMQ setup.
@@ -220,7 +219,7 @@ class SyncMPClient(MPClient):
         return self.decoder.decode(frame.buffer)
 
     def _send_input(self, request_type: EngineCoreRequestType,
-                    request: EngineCoreRequestUnion) -> None:
+                    request: Any) -> None:
 
         # (RequestType, SerializedRequest)
         msg = (request_type.value, self.encoder.encode(request))
@@ -237,12 +236,10 @@ class SyncMPClient(MPClient):
         self._send_input(EngineCoreRequestType.ABORT, request_ids)
 
     def profile(self, is_start: bool = True) -> None:
-        self._send_input(EngineCoreRequestType.PROFILE,
-                         EngineCoreProfile(is_start))
+        self._send_input(EngineCoreRequestType.PROFILE, is_start)
 
     def reset_prefix_cache(self) -> None:
-        self._send_input(EngineCoreRequestType.RESET_PREFIX_CACHE,
-                         EngineCoreResetPrefixCache())
+        self._send_input(EngineCoreRequestType.RESET_PREFIX_CACHE, None)
 
 
 class AsyncMPClient(MPClient):
@@ -277,7 +274,7 @@ class AsyncMPClient(MPClient):
         return self.decoder.decode(await self.outputs_queue.get())
 
     async def _send_input(self, request_type: EngineCoreRequestType,
-                          request: EngineCoreRequestUnion) -> None:
+                          request: Any) -> None:
 
         msg = (request_type.value, self.encoder.encode(request))
         await self.input_socket.send_multipart(msg, copy=False)
@@ -293,9 +290,7 @@ class AsyncMPClient(MPClient):
         await self._send_input(EngineCoreRequestType.ABORT, request_ids)
 
     async def profile_async(self, is_start: bool = True) -> None:
-        await self._send_input(EngineCoreRequestType.PROFILE,
-                               EngineCoreProfile(is_start))
+        await self._send_input(EngineCoreRequestType.PROFILE, is_start)
 
     async def reset_prefix_cache_async(self) -> None:
-        await self._send_input(EngineCoreRequestType.RESET_PREFIX_CACHE,
-                               EngineCoreResetPrefixCache())
+        await self._send_input(EngineCoreRequestType.RESET_PREFIX_CACHE, None)
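On the client side above, the PROFILE and RESET_PREFIX_CACHE payloads shrink from wrapper dataclasses to a bare bool and None, since the request-type byte already disambiguates them. A hedged sketch (RequestType and make_msg are illustrative names, not the vLLM API) of the two-frame message that _send_input builds:

import enum

import msgspec


class RequestType(enum.Enum):
    ABORT = b'\x01'
    PROFILE = b'\x02'
    RESET_PREFIX_CACHE = b'\x03'


encoder = msgspec.msgpack.Encoder()


def make_msg(request_type: RequestType, payload) -> tuple:
    # Mirrors msg = (request_type.value, self.encoder.encode(request));
    # the client then ships it via input_socket.send_multipart(msg).
    return (request_type.value, encoder.encode(payload))


# Payloads are now bare values rather than wrapper objects:
print(make_msg(RequestType.PROFILE, True))             # was EngineCoreProfile(True)
print(make_msg(RequestType.RESET_PREFIX_CACHE, None))  # was EngineCoreResetPrefixCache()
print(make_msg(RequestType.ABORT, ["req-0", "req-1"])) # list of request ids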
diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py
index a7fba65e7c95..3f000abcde0d 100644
--- a/vllm/v1/serial_utils.py
+++ b/vllm/v1/serial_utils.py
@@ -1,21 +1,13 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import pickle
-from typing import Any
+from typing import Any, Optional
 
 import torch
 from msgspec import msgpack
 
-CUSTOM_TYPE_CODE_PICKLE = 1
-
-
-class PickleEncoder:
-
-    def encode(self, obj: Any):
-        return pickle.dumps(obj)
-
-    def decode(self, data: Any):
-        return pickle.loads(data)
+CUSTOM_TYPE_TENSOR = 1
+CUSTOM_TYPE_PICKLE = 2
 
 
 class MsgpackEncoder:
@@ -34,8 +26,9 @@
 class MsgpackDecoder:
     """Decoder with custom torch tensor serialization."""
 
-    def __init__(self, t: Any):
-        self.decoder = msgpack.Decoder(t, ext_hook=custom_ext_hook)
+    def __init__(self, t: Optional[Any] = None):
+        args = () if t is None else (t, )
+        self.decoder = msgpack.Decoder(*args, ext_hook=custom_ext_hook)
 
     def decode(self, obj: Any):
         return self.decoder.decode(obj)
@@ -46,13 +39,15 @@ def custom_enc_hook(obj: Any) -> Any:
         # NOTE(rob): it is fastest to use numpy + pickle
         # when serializing torch tensors.
        # https://gist.github.com/tlrmchlsmth/8067f1b24a82b6e2f90450e7764fa103 # noqa: E501
-        return msgpack.Ext(CUSTOM_TYPE_CODE_PICKLE, pickle.dumps(obj.numpy()))
+        return msgpack.Ext(CUSTOM_TYPE_TENSOR, pickle.dumps(obj.numpy()))
 
-    raise NotImplementedError(f"Objects of type {type(obj)} are not supported")
+    return msgpack.Ext(CUSTOM_TYPE_PICKLE, pickle.dumps(obj))
 
 
 def custom_ext_hook(code: int, data: memoryview) -> Any:
-    if code == CUSTOM_TYPE_CODE_PICKLE:
+    if code == CUSTOM_TYPE_TENSOR:
         return torch.from_numpy(pickle.loads(data))
+    if code == CUSTOM_TYPE_PICKLE:
+        return pickle.loads(data)
 
     raise NotImplementedError(f"Extension type code {code} is not supported")
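To see the two extension codes work end to end, here is a short round-trip check wiring the patched hooks into a plain msgspec Encoder/Decoder pair, much as MsgpackEncoder and MsgpackDecoder do (a standalone sketch with error handling trimmed; Opaque is a made-up type, and only torch and msgspec are assumed installed):

import pickle

import torch
from msgspec import msgpack

CUSTOM_TYPE_TENSOR = 1
CUSTOM_TYPE_PICKLE = 2


def enc_hook(obj):
    # Tensors take the numpy + pickle fast path under ext code 1.
    if isinstance(obj, torch.Tensor):
        return msgpack.Ext(CUSTOM_TYPE_TENSOR, pickle.dumps(obj.numpy()))
    # Anything msgpack cannot natively encode falls back to generic
    # pickle under ext code 2, instead of raising NotImplementedError.
    return msgpack.Ext(CUSTOM_TYPE_PICKLE, pickle.dumps(obj))


def ext_hook(code, data):
    if code == CUSTOM_TYPE_TENSOR:
        return torch.from_numpy(pickle.loads(data))
    return pickle.loads(data)


class Opaque:
    """A type msgpack has no native encoding for."""

    def __init__(self, tag):
        self.tag = tag


enc = msgpack.Encoder(enc_hook=enc_hook)
dec = msgpack.Decoder(ext_hook=ext_hook)

t = torch.arange(4)
assert torch.equal(dec.decode(enc.encode(t)), t)       # ext code 1
assert dec.decode(enc.encode(Opaque("x"))).tag == "x"  # ext code 2

The pickle fallback is what lets the clients drop PickleEncoder entirely: one MsgpackEncoder handles structs, tensors, and the occasional arbitrary object in a single wire format.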