[V1][Core] Generic mechanism for handling engine utility (#13060)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in: parent f525c0be8b, commit caf7ff4456
@@ -41,7 +41,7 @@ def download_and_prepare_lora_module():
     ]
     for tokenizer_file in tokenizer_files:
         del_path = Path(LORA_MODULE_DOWNLOAD_PATH) / tokenizer_file
-        del_path.unlink()
+        del_path.unlink(missing_ok=True)


 @pytest.fixture(autouse=True)
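The first hunk hardens the LoRA test cleanup: Path.unlink(missing_ok=True), available since Python 3.8, silently skips files that were never downloaded instead of raising FileNotFoundError. In isolation:

    from pathlib import Path

    # No-op instead of FileNotFoundError when the file is absent (3.8+).
    Path("/tmp/definitely-not-here-0000").unlink(missing_ok=True)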
@@ -3,7 +3,8 @@
 import asyncio
 import time
 import uuid
-from typing import Dict, List
+from contextlib import ExitStack
+from typing import Dict, List, Optional

 import pytest
 from transformers import AutoTokenizer
@@ -14,7 +15,9 @@ from vllm.engine.arg_utils import EngineArgs
 from vllm.platforms import current_platform
 from vllm.usage.usage_lib import UsageContext
 from vllm.v1.engine import EngineCoreRequest
-from vllm.v1.engine.core_client import EngineCoreClient
+from vllm.v1.engine.core import EngineCore
+from vllm.v1.engine.core_client import (AsyncMPClient, EngineCoreClient,
+                                        SyncMPClient)
 from vllm.v1.executor.abstract import Executor

 if not current_platform.is_cuda():
@@ -63,7 +66,7 @@ def loop_until_done(client: EngineCoreClient, outputs: Dict):
 async def loop_until_done_async(client: EngineCoreClient, outputs: Dict):

     while True:
-        engine_core_outputs = await client.get_output_async().outputs
+        engine_core_outputs = (await client.get_output_async()).outputs

         if len(engine_core_outputs) == 0:
             break
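The one-line change in loop_until_done_async fixes an operator-precedence bug: attribute access binds tighter than await, so `await client.get_output_async().outputs` tries to read .outputs off the un-awaited coroutine object and raises AttributeError. The parenthesized form awaits first. A small reproduction:

    import asyncio

    class Outputs:
        outputs = [1, 2, 3]

    async def get_output_async() -> Outputs:
        return Outputs()

    async def main():
        # await get_output_async().outputs   # AttributeError: 'coroutine'
        #                                    # object has no attribute 'outputs'
        good = (await get_output_async()).outputs
        assert good == [1, 2, 3]

    asyncio.run(main())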
@@ -78,6 +81,14 @@ async def loop_until_done_async(client: EngineCoreClient, outputs: Dict):
             break


+# Dummy utility function to monkey-patch into engine core.
+def echo(self, msg: str, err_msg: Optional[str] = None) -> str:
+    print(f"echo util function called: {msg}, {err_msg}")
+    if err_msg is not None:
+        raise ValueError(err_msg)
+    return msg
+
+
 @fork_new_process_for_each_test
 @pytest.mark.parametrize("multiprocessing_mode", [True, False])
 def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):
@@ -85,7 +96,10 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")

-        engine_args = EngineArgs(model=MODEL_NAME, compilation_config=3)
+        # Monkey-patch core engine utility function to test.
+        m.setattr(EngineCore, "echo", echo, raising=False)
+
+        engine_args = EngineArgs(model=MODEL_NAME, enforce_eager=True)
         vllm_config = engine_args.create_engine_config(
             UsageContext.UNKNOWN_CONTEXT)
         executor_class = Executor.get_class(vllm_config)
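Here m.setattr(EngineCore, "echo", echo, raising=False) grafts the module-level echo function onto the EngineCore class as a new method; raising=False tells pytest's monkeypatch not to fail because EngineCore has no pre-existing attribute of that name. Toy stand-ins showing the same pattern:

    class Core:  # stand-in for EngineCore
        pass

    def echo(self, msg):
        return msg

    def test_patch(monkeypatch):
        # Without raising=False, monkeypatch.setattr refuses to create
        # an attribute that does not already exist on the target.
        monkeypatch.setattr(Core, "echo", echo, raising=False)
        assert Core().echo("hi") == "hi"  # bound like a normal method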
@@ -147,15 +161,30 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):

         client.abort_requests([request.request_id])

+        if multiprocessing_mode:
+            """Utility method invocation"""
+
+            core_client: SyncMPClient = client
+
+            result = core_client._call_utility("echo", "testarg")
+            assert result == "testarg"
+
+            with pytest.raises(Exception) as e_info:
+                core_client._call_utility("echo", None, "help!")
+
+            assert str(e_info.value) == "Call to echo method failed: help!"
+

-@fork_new_process_for_each_test
-@pytest.mark.asyncio
+@pytest.mark.asyncio(loop_scope="function")
 async def test_engine_core_client_asyncio(monkeypatch):

-    with monkeypatch.context() as m:
+    with monkeypatch.context() as m, ExitStack() as after:
         m.setenv("VLLM_USE_V1", "1")

-        engine_args = EngineArgs(model=MODEL_NAME)
+        # Monkey-patch core engine utility function to test.
+        m.setattr(EngineCore, "echo", echo, raising=False)
+
+        engine_args = EngineArgs(model=MODEL_NAME, enforce_eager=True)
         vllm_config = engine_args.create_engine_config(
             usage_context=UsageContext.UNKNOWN_CONTEXT)
         executor_class = Executor.get_class(vllm_config)
@@ -166,6 +195,7 @@ async def test_engine_core_client_asyncio(monkeypatch):
         executor_class=executor_class,
         log_stats=True,
     )
+    after.callback(client.shutdown)

     MAX_TOKENS = 20
     params = SamplingParams(max_tokens=MAX_TOKENS)
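The async test now pairs the monkeypatch context with a contextlib.ExitStack and registers client.shutdown on it (the after.callback line above), guaranteeing the background client is torn down even when an assertion fails mid-test. The pattern in isolation:

    from contextlib import ExitStack

    class ToyClient:
        closed = False

        def shutdown(self):
            self.closed = True

    with ExitStack() as after:
        client = ToyClient()
        after.callback(client.shutdown)  # runs on normal exit or exception
        # ... exercise the client ...
    assert client.closed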
@@ -204,3 +234,14 @@ async def test_engine_core_client_asyncio(monkeypatch):
         else:
             assert len(outputs[req_id]) == MAX_TOKENS, (
                 f"{len(outputs[req_id])=}, {MAX_TOKENS=}")
+
+        """Utility method invocation"""
+
+        core_client: AsyncMPClient = client
+
+        result = await core_client._call_utility_async("echo", "testarg")
+        assert result == "testarg"
+
+        with pytest.raises(Exception) as e_info:
+            await core_client._call_utility_async("echo", None, "help!")
+
+        assert str(e_info.value) == "Call to echo method failed: help!"
@@ -2,7 +2,7 @@

 import enum
 import time
-from typing import List, Optional, Union
+from typing import Any, List, Optional, Union

 import msgspec

@@ -106,6 +106,18 @@ class EngineCoreOutput(
         return self.finish_reason is not None


+class UtilityOutput(
+        msgspec.Struct,
+        array_like=True,  # type: ignore[call-arg]
+        gc=False):  # type: ignore[call-arg]
+
+    call_id: int
+
+    # Non-None implies the call failed, result should be None.
+    failure_message: Optional[str] = None
+    result: Any = None
+
+
 class EngineCoreOutputs(
         msgspec.Struct,
         array_like=True,  # type: ignore[call-arg]
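UtilityOutput is declared with array_like=True, so msgspec serializes it as a compact positional array rather than a field-name map, and gc=False exempts instances from garbage-collector tracking. An encode/decode round trip using msgspec directly (vLLM routes this through its MsgpackEncoder/MsgpackDecoder wrappers):

    from typing import Any, Optional

    import msgspec

    class UtilityOutput(msgspec.Struct, array_like=True, gc=False):
        call_id: int
        # Non-None implies the call failed, result should be None.
        failure_message: Optional[str] = None
        result: Any = None

    buf = msgspec.msgpack.encode(UtilityOutput(call_id=7, result="ok"))
    out = msgspec.msgpack.decode(buf, type=UtilityOutput)
    assert (out.call_id, out.failure_message, out.result) == (7, None, "ok")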
@@ -116,10 +128,12 @@ class EngineCoreOutputs(
     # e.g. columnwise layout

     # [num_reqs]
-    outputs: List[EngineCoreOutput]
-    scheduler_stats: Optional[SchedulerStats]
+    outputs: List[EngineCoreOutput] = []
+    scheduler_stats: Optional[SchedulerStats] = None
     timestamp: float = 0.0

+    utility_output: Optional[UtilityOutput] = None
+
     def __post_init__(self):
         if self.timestamp == 0.0:
             self.timestamp = time.monotonic()
@@ -132,6 +146,4 @@ class EngineCoreRequestType(enum.Enum):
     """
     ADD = b'\x00'
     ABORT = b'\x01'
-    PROFILE = b'\x02'
-    RESET_PREFIX_CACHE = b'\x03'
-    ADD_LORA = b'\x04'
+    UTILITY = b'\x02'
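With PROFILE, RESET_PREFIX_CACHE, and ADD_LORA deleted, the wire protocol is down to three single-byte opcodes; any new control operation travels inside a UTILITY payload instead of claiming its own opcode. The engine rebuilds the enum from the first ZMQ frame's byte, as in this round trip:

    import enum

    class EngineCoreRequestType(enum.Enum):
        ADD = b'\x00'
        ABORT = b'\x01'
        UTILITY = b'\x02'

    assert EngineCoreRequestType(b'\x02') is EngineCoreRequestType.UTILITY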
@@ -5,9 +5,11 @@ import signal
 import threading
 import time
 from concurrent.futures import Future
+from inspect import isclass, signature
 from multiprocessing.connection import Connection
 from typing import Any, List, Optional, Tuple, Type

+import msgspec
 import psutil
 import zmq
 import zmq.asyncio
@@ -21,7 +23,7 @@ from vllm.utils import get_exception_traceback, zmq_socket_ctx
 from vllm.v1.core.kv_cache_utils import get_kv_cache_configs
 from vllm.v1.core.scheduler import Scheduler, SchedulerOutput
 from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
-                            EngineCoreRequestType)
+                            EngineCoreRequestType, UtilityOutput)
 from vllm.v1.engine.mm_input_cache import MMInputCacheServer
 from vllm.v1.executor.abstract import Executor
 from vllm.v1.outputs import ModelRunnerOutput
@@ -330,19 +332,39 @@ class EngineCoreProc(EngineCore):
             self.add_request(request)
         elif request_type == EngineCoreRequestType.ABORT:
             self.abort_requests(request)
-        elif request_type == EngineCoreRequestType.RESET_PREFIX_CACHE:
-            self.reset_prefix_cache()
-        elif request_type == EngineCoreRequestType.PROFILE:
-            self.model_executor.profile(request)
-        elif request_type == EngineCoreRequestType.ADD_LORA:
-            self.model_executor.add_lora(request)
+        elif request_type == EngineCoreRequestType.UTILITY:
+            call_id, method_name, args = request
+            output = UtilityOutput(call_id)
+            try:
+                method = getattr(self, method_name)
+                output.result = method(
+                    *self._convert_msgspec_args(method, args))
+            except BaseException as e:
+                logger.exception("Invocation of %s method failed", method_name)
+                output.failure_message = (f"Call to {method_name} method"
+                                          f" failed: {str(e)}")
+            self.output_queue.put_nowait(
+                EngineCoreOutputs(utility_output=output))
+
+    @staticmethod
+    def _convert_msgspec_args(method, args):
+        """If a provided arg type doesn't match corresponding target method
+        arg type, try converting to msgspec object."""
+        if not args:
+            return args
+        arg_types = signature(method).parameters.values()
+        assert len(args) <= len(arg_types)
+        return tuple(
+            msgspec.convert(v, type=p.annotation) if isclass(p.annotation)
+            and issubclass(p.annotation, msgspec.Struct)
+            and not isinstance(v, p.annotation) else v
+            for v, p in zip(args, arg_types))

     def process_input_socket(self, input_path: str):
         """Input socket IO thread."""

         # Msgpack serialization decoding.
         add_request_decoder = MsgpackDecoder(EngineCoreRequest)
-        add_lora_decoder = MsgpackDecoder(LoRARequest)
         generic_decoder = MsgpackDecoder()

         with zmq_socket_ctx(input_path, zmq.constants.PULL) as socket:
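_convert_msgspec_args exists because UTILITY payloads are decoded by the untyped generic decoder: a msgspec Struct argument (e.g. the LoRARequest handed to add_lora) arrives as plain decoded data, and the target method's signature annotations say what to convert it back into. The same idea with a toy struct (hypothetical Point/move_to names):

    from inspect import signature

    import msgspec

    class Point(msgspec.Struct):
        x: int
        y: int

    def move_to(point: Point) -> str:
        return f"({point.x}, {point.y})"

    # An untyped decode of an encoded Point yields plain data, not a Point;
    # msgspec.convert restores the Struct named in the annotation.
    raw = msgspec.msgpack.decode(msgspec.msgpack.encode(Point(1, 2)))
    param = next(iter(signature(move_to).parameters.values()))
    assert move_to(msgspec.convert(raw, type=param.annotation)) == "(1, 2)"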
@@ -352,14 +374,9 @@ class EngineCoreProc(EngineCore):
                 request_type = EngineCoreRequestType(bytes(type_frame.buffer))

                 # Deserialize the request data.
-                decoder = None
-                if request_type == EngineCoreRequestType.ADD:
-                    decoder = add_request_decoder
-                elif request_type == EngineCoreRequestType.ADD_LORA:
-                    decoder = add_lora_decoder
-                else:
-                    decoder = generic_decoder
-
+                decoder = add_request_decoder if (
+                    request_type
+                    == EngineCoreRequestType.ADD) else generic_decoder
                 request = decoder.decode(data_frame.buffer)

                 # Push to input queue for core busy loop.
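Only ADD keeps a dedicated typed decoder (EngineCoreRequest is reconstructed directly into its Struct); ABORT's request-id list and UTILITY's (call_id, method, args) tuple are decoded generically. The remaining hunks move to the client side, where every MPClient keeps a utility_results map from call id to a waiting future that the output reader resolves. That bookkeeping in miniature:

    from concurrent.futures import Future

    utility_results = {}

    call_id = 42
    future = Future()
    utility_results[call_id] = future

    # ...the output reader later receives a UtilityOutput for call 42...
    utility_results.pop(call_id).set_result("ok")
    assert future.result() == "ok"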
@@ -2,10 +2,14 @@

 import asyncio
 import os
+import queue
 import signal
+import uuid
 import weakref
 from abc import ABC, abstractmethod
-from typing import Any, List, Optional, Type
+from concurrent.futures import Future
+from threading import Thread
+from typing import Any, Dict, List, Optional, Type, Union

 import zmq
 import zmq.asyncio
@@ -16,7 +20,7 @@ from vllm.lora.request import LoRARequest
 from vllm.utils import (get_open_zmq_ipc_path, kill_process_tree,
                         make_zmq_socket)
 from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
-                            EngineCoreRequestType)
+                            EngineCoreRequestType, UtilityOutput)
 from vllm.v1.engine.core import EngineCore, EngineCoreProc
 from vllm.v1.executor.abstract import Executor
 from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
@@ -24,6 +28,8 @@ from vllm.v1.utils import BackgroundProcHandle

 logger = init_logger(__name__)

+AnyFuture = Union[asyncio.Future[Any], Future[Any]]
+

 class EngineCoreClient(ABC):
     """
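The AnyFuture alias works because concurrent.futures.Future and asyncio.Future share the set_result/set_exception surface, so the single _process_utility_output helper below can complete either kind; only retrieval differs (blocking .result() vs await). A sketch:

    import asyncio
    from concurrent.futures import Future
    from typing import Any, Union

    AnyFuture = Union["asyncio.Future[Any]", "Future[Any]"]

    def complete(future: AnyFuture, value: Any) -> None:
        future.set_result(value)  # identical API on both future types

    sync_f: Future = Future()
    complete(sync_f, 1)
    assert sync_f.result() == 1  # blocking retrieval

    async def main():
        async_f = asyncio.get_running_loop().create_future()
        complete(async_f, 2)
        assert await async_f == 2  # awaited retrieval

    asyncio.run(main())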
@@ -204,6 +210,8 @@ class MPClient(EngineCoreClient):
             "log_stats": log_stats,
         })

+        self.utility_results: Dict[int, AnyFuture] = {}
+
     def shutdown(self):
         """Clean up background resources."""
         if hasattr(self, "proc_handle"):
@@ -212,6 +220,16 @@ class MPClient(EngineCoreClient):
         self._finalizer()


+def _process_utility_output(output: UtilityOutput,
+                            utility_results: Dict[int, AnyFuture]):
+    """Set the result from a utility method in the waiting future"""
+    future = utility_results.pop(output.call_id)
+    if output.failure_message is not None:
+        future.set_exception(Exception(output.failure_message))
+    else:
+        future.set_result(output.result)
+
+
 class SyncMPClient(MPClient):
     """Synchronous client for multi-proc EngineCore."""

@@ -224,10 +242,30 @@ class SyncMPClient(MPClient):
             log_stats=log_stats,
         )

-    def get_output(self) -> EngineCoreOutputs:
-
-        (frame, ) = self.output_socket.recv_multipart(copy=False)
-        return self.decoder.decode(frame.buffer)
+        self.outputs_queue: queue.Queue[EngineCoreOutputs] = queue.Queue()
+
+        # Ensure that the outputs socket processing thread does not have
+        # a ref to the client which prevents gc.
+        output_socket = self.output_socket
+        decoder = self.decoder
+        utility_results = self.utility_results
+        outputs_queue = self.outputs_queue
+
+        def process_outputs_socket():
+            while True:
+                (frame, ) = output_socket.recv_multipart(copy=False)
+                outputs = decoder.decode(frame.buffer)
+                if outputs.utility_output:
+                    _process_utility_output(outputs.utility_output,
+                                            utility_results)
+                else:
+                    outputs_queue.put_nowait(outputs)
+
+        # Process outputs from engine in separate thread.
+        Thread(target=process_outputs_socket, daemon=True).start()
+
+    def get_output(self) -> EngineCoreOutputs:
+        return self.outputs_queue.get()

     def _send_input(self, request_type: EngineCoreRequestType,
                     request: Any) -> None:
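SyncMPClient's constructor now starts a daemon thread that drains the output socket into a local queue, and the hunk deliberately hoists output_socket, decoder, utility_results, and outputs_queue into locals so the closure never captures self: a thread holding self would keep the client reachable and block its finalizer. The effect, reduced to a toy (the final assert relies on CPython refcounting):

    import queue
    import threading
    import weakref

    class Client:
        def __init__(self):
            self.outputs: queue.Queue = queue.Queue()
            outputs = self.outputs  # capture the queue, not self

            def reader():  # stand-in for the socket recv loop
                outputs.put("payload")

            threading.Thread(target=reader, daemon=True).start()

    c = Client()
    ref = weakref.ref(c)
    assert c.outputs.get() == "payload"
    del c  # the thread never held self, so the client is collectable
    assert ref() is None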
@@ -236,6 +274,16 @@ class SyncMPClient(MPClient):
         msg = (request_type.value, self.encoder.encode(request))
         self.input_socket.send_multipart(msg, copy=False)

+    def _call_utility(self, method: str, *args) -> Any:
+        call_id = uuid.uuid1().int >> 64
+        future: Future[Any] = Future()
+        self.utility_results[call_id] = future
+
+        self._send_input(EngineCoreRequestType.UTILITY,
+                         (call_id, method, args))
+
+        return future.result()
+
     def add_request(self, request: EngineCoreRequest) -> None:
         # NOTE: text prompt is not needed in the core engine as it has been
         # tokenized.
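_call_utility generates its correlation id as uuid.uuid1().int >> 64, i.e. the time-based top half of a 128-bit UUID squeezed into 64 bits, then simply blocks on future.result(); a failure_message from the engine is re-raised here as an Exception, which is exactly what the test's "Call to echo method failed: help!" assertion exercises. Both halves in isolation:

    import uuid
    from concurrent.futures import Future

    call_id = uuid.uuid1().int >> 64
    assert 0 <= call_id < 2**64  # top 64 bits of the 128-bit UUID

    future: Future = Future()
    future.set_exception(Exception("Call to echo method failed: help!"))
    try:
        future.result()  # blocks until set, then re-raises
    except Exception as e:
        assert str(e) == "Call to echo method failed: help!"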
@@ -247,13 +295,13 @@ class SyncMPClient(MPClient):
         self._send_input(EngineCoreRequestType.ABORT, request_ids)

     def profile(self, is_start: bool = True) -> None:
-        self._send_input(EngineCoreRequestType.PROFILE, is_start)
+        self._call_utility("profile", is_start)

     def reset_prefix_cache(self) -> None:
-        self._send_input(EngineCoreRequestType.RESET_PREFIX_CACHE, None)
+        self._call_utility("reset_prefix_cache")

     def add_lora(self, lora_request: LoRARequest) -> None:
-        self._send_input(EngineCoreRequestType.ADD_LORA, lora_request)
+        self._call_utility("add_lora", lora_request)


 class AsyncMPClient(MPClient):
@@ -268,24 +316,35 @@ class AsyncMPClient(MPClient):
             log_stats=log_stats,
         )

-        self.outputs_queue: Optional[asyncio.Queue[bytes]] = None
+        self.outputs_queue: Optional[asyncio.Queue[EngineCoreOutputs]] = None
         self.queue_task: Optional[asyncio.Task] = None

+    async def _start_output_queue_task(self):
+        # Perform IO in separate task to parallelize as much as possible.
+        # Avoid task having direct reference back to the client.
+        self.outputs_queue = asyncio.Queue()
+        output_socket = self.output_socket
+        decoder = self.decoder
+        utility_results = self.utility_results
+        outputs_queue = self.outputs_queue
+
+        async def process_outputs_socket():
+            while True:
+                (frame, ) = await output_socket.recv_multipart(copy=False)
+                outputs: EngineCoreOutputs = decoder.decode(frame.buffer)
+                if outputs.utility_output:
+                    _process_utility_output(outputs.utility_output,
+                                            utility_results)
+                else:
+                    outputs_queue.put_nowait(outputs)
+
+        self.queue_task = asyncio.create_task(process_outputs_socket())
+
     async def get_output_async(self) -> EngineCoreOutputs:
         if self.outputs_queue is None:
-            # Perform IO in separate task to parallelize as much as possible
-            self.outputs_queue = asyncio.Queue()
-
-            async def process_outputs_socket():
-                assert self.outputs_queue is not None
-                while True:
-                    (frame, ) = await self.output_socket.recv_multipart(
-                        copy=False)
-                    self.outputs_queue.put_nowait(frame.buffer)
-
-            self.queue_task = asyncio.create_task(process_outputs_socket())
-
-        return self.decoder.decode(await self.outputs_queue.get())
+            await self._start_output_queue_task()
+            assert self.outputs_queue is not None
+        return await self.outputs_queue.get()

     async def _send_input(self, request_type: EngineCoreRequestType,
                           request: Any) -> None:
@@ -293,6 +352,18 @@ class AsyncMPClient(MPClient):
         msg = (request_type.value, self.encoder.encode(request))
         await self.input_socket.send_multipart(msg, copy=False)

+        if self.outputs_queue is None:
+            await self._start_output_queue_task()
+
+    async def _call_utility_async(self, method: str, *args) -> Any:
+        call_id = uuid.uuid1().int >> 64
+        future = asyncio.get_running_loop().create_future()
+        self.utility_results[call_id] = future
+        await self._send_input(EngineCoreRequestType.UTILITY,
+                               (call_id, method, args))
+
+        return await future
+
     async def add_request_async(self, request: EngineCoreRequest) -> None:
         # NOTE: text prompt is not needed in the core engine as it has been
         # tokenized.
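Note that _send_input now lazily starts the output reader task before any reply can arrive; without a running reader, the future awaited by _call_utility_async would never resolve, since utility replies come back on the same outputs socket as regular engine output. The async call shape, reduced to its essentials:

    import asyncio

    async def main():
        loop = asyncio.get_running_loop()
        future = loop.create_future()

        def reader_delivers():  # stands in for process_outputs_socket
            future.set_result("testarg")

        loop.call_soon(reader_delivers)  # the reader task resolves it later
        assert await future == "testarg"

    asyncio.run(main())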
@@ -304,10 +375,10 @@ class AsyncMPClient(MPClient):
         await self._send_input(EngineCoreRequestType.ABORT, request_ids)

     async def profile_async(self, is_start: bool = True) -> None:
-        await self._send_input(EngineCoreRequestType.PROFILE, is_start)
+        await self._call_utility_async("profile", is_start)

     async def reset_prefix_cache_async(self) -> None:
-        await self._send_input(EngineCoreRequestType.RESET_PREFIX_CACHE, None)
+        await self._call_utility_async("reset_prefix_cache")

     async def add_lora_async(self, lora_request: LoRARequest) -> None:
-        await self._send_input(EngineCoreRequestType.ADD_LORA, lora_request)
+        await self._call_utility_async("add_lora", lora_request)
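Net effect: profile, reset_prefix_cache, and add_lora become ordinary utility calls, and future control operations need no new request types, decoders, or socket plumbing; defining a method on EngineCore is enough to make it remotely callable. A purely hypothetical illustration (set_log_level is not part of this commit):

    class EngineCore:
        def set_log_level(self, level: str) -> None:  # hypothetical method
            ...

    # Client side, following the pattern above:
    #   client._call_utility("set_log_level", "DEBUG")
    #   await async_client._call_utility_async("set_log_level", "DEBUG")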