[V0 Deprecation] Remove MQLLMEngine (#25019)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
Woosuk Kwon 2025-09-16 21:29:27 -07:00 committed by GitHub
parent 58d4c705a8
commit 5801e49776
15 changed files with 12 additions and 1969 deletions

View File

@@ -46,7 +46,6 @@ steps:
mirror_hardwares: [amdexperimental]
source_file_dependencies:
- vllm/
- tests/mq_llm_engine
- tests/async_engine
- tests/test_inputs.py
- tests/test_outputs.py
@@ -57,7 +56,6 @@ steps:
- tests/transformers_utils
commands:
- python3 standalone_tests/lazy_imports.py
- pytest -v -s mq_llm_engine # MQLLMEngine
- pytest -v -s async_engine # AsyncLLMEngine
- pytest -v -s test_inputs.py
- pytest -v -s test_outputs.py

View File

@@ -10,7 +10,6 @@ from unittest.mock import MagicMock
import pytest
from vllm.config.multimodal import MultiModalConfig
-from vllm.engine.multiprocessing.client import MQLLMEngineClient
from vllm.entrypoints.openai.protocol import CompletionRequest, ErrorResponse
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
@@ -18,6 +17,7 @@ from vllm.entrypoints.openai.serving_models import (BaseModelPath,
from vllm.lora.request import LoRARequest
from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry
from vllm.transformers_utils.tokenizer import get_tokenizer
+from vllm.v1.engine.async_llm import AsyncLLM
MODEL_NAME = "openai-community/gpt2"
BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]
@@ -82,7 +82,7 @@ def register_mock_resolver():
@pytest.fixture
def mock_serving_setup():
"""Provides a mocked engine and serving completion instance."""
-mock_engine = MagicMock(spec=MQLLMEngineClient)
+mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False
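The spec swap above is the substantive change in these fixtures: a MagicMock constructed with spec= only exposes attributes that exist on the real class, so any test still reaching for a removed V0-only API fails fast instead of silently passing. A minimal, self-contained sketch of that behavior (the V1Engine class below is a hypothetical stand-in, not vLLM's real AsyncLLM):

from unittest.mock import MagicMock

class V1Engine:
    """Hypothetical stand-in for an engine class used as a mock spec."""

    def generate(self):
        ...

mock_engine = MagicMock(spec=V1Engine)
mock_engine.generate()  # OK: generate() exists on the spec
try:
    mock_engine.check_health()  # not on V1Engine, so this raises
except AttributeError as e:
    print(e)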

View File

@@ -13,12 +13,12 @@ import pytest
import pytest_asyncio
from vllm.config.multimodal import MultiModalConfig
-from vllm.engine.multiprocessing.client import MQLLMEngineClient
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
OpenAIServingModels)
from vllm.transformers_utils.tokenizer import get_tokenizer
+from vllm.v1.engine.async_llm import AsyncLLM
from ...utils import RemoteOpenAIServer
@@ -276,7 +276,7 @@ def test_async_serving_chat_init():
@pytest.mark.asyncio
async def test_serving_chat_returns_correct_model_name():
-mock_engine = MagicMock(spec=MQLLMEngineClient)
+mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False
@@ -312,7 +312,7 @@ async def test_serving_chat_returns_correct_model_name():
@pytest.mark.asyncio
async def test_serving_chat_should_set_correct_max_tokens():
-mock_engine = MagicMock(spec=MQLLMEngineClient)
+mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False
@@ -355,7 +355,7 @@ async def test_serving_chat_should_set_correct_max_tokens():
}
# Reinitialize the engine with new settings
-mock_engine = MagicMock(spec=MQLLMEngineClient)
+mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False
@@ -410,7 +410,7 @@ async def test_serving_chat_should_set_correct_max_tokens():
}
# Reinitialize the engine with new settings
-mock_engine = MagicMock(spec=MQLLMEngineClient)
+mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False
@@ -467,7 +467,7 @@ async def test_serving_chat_could_load_correct_generation_config():
"repetition_penalty": 1.05
}
-mock_engine = MagicMock(spec=MQLLMEngineClient)
+mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False
@@ -523,7 +523,7 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type):
mock_model_config = MockModelConfig()
mock_model_config.hf_config.model_type = model_type
-mock_engine = MagicMock(spec=MQLLMEngineClient)
+mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False

View File

@@ -1,12 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch):
"""
Since this module is V0 only, set VLLM_USE_V1=0 for
all tests in the module.
"""
monkeypatch.setenv('VLLM_USE_V1', '0')

View File

@@ -1,69 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Test that aborting is handled properly."""
import asyncio
import tempfile
import uuid
import pytest
from tests.mq_llm_engine.utils import RemoteMQLLMEngine, generate
from vllm.engine.arg_utils import AsyncEngineArgs
MODEL = "google/gemma-1.1-2b-it"
ENGINE_ARGS = AsyncEngineArgs(model=MODEL)
RAISED_ERROR = KeyError
RAISED_VALUE = "foo"
EXPECTED_TOKENS = 250
@pytest.fixture(scope="function")
def tmp_socket():
with tempfile.TemporaryDirectory() as td:
yield f"ipc://{td}/{uuid.uuid4()}"
@pytest.mark.asyncio
async def test_abort(tmp_socket):
with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
ipc_path=tmp_socket) as engine:
client = await engine.make_client()
request_id_to_be_aborted = "request-aborted"
request_ids_a = [f"request-a-{idx}" for idx in range(10)]
request_ids_b = [f"request-b-{idx}" for idx in range(10)]
# Requests started before one to be aborted.
tasks = []
for request_id in request_ids_a:
tasks.append(
asyncio.create_task(
generate(client, request_id, EXPECTED_TOKENS)))
# Aborted.
task_aborted = asyncio.create_task(
generate(client, request_id_to_be_aborted, EXPECTED_TOKENS))
# Requests started after one to be aborted.
for request_id in request_ids_b:
tasks.append(
asyncio.create_task(
generate(client, request_id, EXPECTED_TOKENS)))
# Actually abort.
await asyncio.sleep(0.5)
await client.abort(request_id_to_be_aborted)
# Confirm that we got all the EXPECTED tokens from the requests.
for task in tasks:
count, request_id = await task
assert count == EXPECTED_TOKENS, (
f"{request_id} generated only {count} tokens")
# Cancel the aborted task (awaiting it would otherwise hang indefinitely).
task_aborted.cancel()
# Shutdown.
client.close()

View File

@@ -1,376 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Test that various errors are handled properly."""
import asyncio
import tempfile
import time
import uuid
from unittest.mock import Mock
import pytest
from tests.mq_llm_engine.utils import RemoteMQLLMEngine
from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.engine.multiprocessing import MQEngineDeadError
from vllm.engine.multiprocessing.engine import MQLLMEngine
from vllm.entrypoints.openai.api_server import build_async_engine_client
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.lora.request import LoRARequest
from vllm.sequence import SequenceGroupMetadata
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser
MODEL = "google/gemma-1.1-2b-it"
ENGINE_ARGS = AsyncEngineArgs(model=MODEL, enforce_eager=True)
RAISED_ERROR = KeyError
RAISED_VALUE = "foo"
@pytest.fixture(scope="function")
def tmp_socket():
with tempfile.TemporaryDirectory() as td:
yield f"ipc://{td}/{uuid.uuid4()}"
def run_with_evil_forward(engine_args: AsyncEngineArgs, ipc_path: str):
# Make engine.
engine = MQLLMEngine.from_engine_args(
engine_args=engine_args,
usage_context=UsageContext.UNKNOWN_CONTEXT,
ipc_path=ipc_path)
# Raise error during first forward pass.
engine.engine.model_executor.execute_model = Mock(
side_effect=RAISED_ERROR(RAISED_VALUE))
# Run engine.
engine.start()
@pytest.mark.asyncio
async def test_evil_forward(tmp_socket):
with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
ipc_path=tmp_socket,
run_fn=run_with_evil_forward) as engine:
client = await engine.make_client()
# Server should be healthy after initial probe.
await asyncio.sleep(2.0)
await client.check_health()
# The generate call hits the dead engine and should raise MQEngineDeadError.
with pytest.raises(MQEngineDeadError):
async for _ in client.generate(prompt="Hello my name is",
sampling_params=SamplingParams(),
request_id=str(uuid.uuid4())):
pass
assert client.errored
await asyncio.sleep(1.0)
with pytest.raises(RAISED_ERROR):
await client.check_health()
assert client.errored
# Shutdown.
client.close()
def run_with_evil_model_executor_health(engine_args: AsyncEngineArgs,
ipc_path: str):
# Make engine.
engine = MQLLMEngine.from_engine_args(
engine_args=engine_args,
usage_context=UsageContext.UNKNOWN_CONTEXT,
ipc_path=ipc_path)
# Raise error during health check.
engine.engine.model_executor.check_health = Mock(side_effect=RAISED_ERROR)
# Run engine.
engine.start()
@pytest.mark.asyncio
async def test_failed_health_check(tmp_socket):
with RemoteMQLLMEngine(
engine_args=ENGINE_ARGS,
ipc_path=tmp_socket,
run_fn=run_with_evil_model_executor_health) as engine:
client = await engine.make_client()
assert client.is_running
# Health probe should throw RAISED_ERROR.
await asyncio.sleep(15.)
with pytest.raises(RAISED_ERROR):
await client.check_health()
assert client.errored
# Generate call should throw ENGINE_DEAD_ERROR
with pytest.raises(MQEngineDeadError):
async for _ in client.generate(prompt="Hello my name is",
sampling_params=SamplingParams(),
request_id=str(uuid.uuid4())):
pass
client.close()
def run_with_evil_abort(engine_args: AsyncEngineArgs, ipc_path: str):
# Make engine.
engine = MQLLMEngine.from_engine_args(
engine_args=engine_args,
usage_context=UsageContext.UNKNOWN_CONTEXT,
ipc_path=ipc_path)
# Raise error during abort call.
engine.engine.abort_request = Mock(side_effect=RAISED_ERROR)
# Run engine.
engine.start()
@pytest.mark.asyncio
async def test_failed_abort(tmp_socket):
with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
ipc_path=tmp_socket,
run_fn=run_with_evil_abort) as engine:
client = await engine.make_client()
assert client.is_running
# First check health should work.
await client.check_health()
# Trigger an abort on the client side.
# This request ID does not exist, and will cause the engine to error
await client.abort(request_id="foo")
# Future generation requests will now fail
# with reference to the original KeyError("foo")
with pytest.raises(MQEngineDeadError) as execinfo:
async for _ in client.generate(
prompt="Hello my name is",
sampling_params=SamplingParams(max_tokens=10),
request_id=str(uuid.uuid4())):
pass
assert "KeyError" in repr(execinfo.value)
assert client.errored
# This should raise the original error.
with pytest.raises(RAISED_ERROR):
await client.check_health()
client.close()
@pytest.mark.asyncio
async def test_batch_error(tmp_socket):
with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
ipc_path=tmp_socket,
run_fn=run_with_evil_abort) as engine:
client = await engine.make_client()
assert client.is_running
# First check health should work.
await client.check_health()
# Batch of requests
async def do_generate(client):
# min_tokens=2048 keeps the engine busy long enough
# to process a request that will crash the engine
params = SamplingParams(min_tokens=2048, max_tokens=2048)
async for _ in client.generate(prompt="Hello my name is",
sampling_params=params,
request_id=str(uuid.uuid4())):
pass
tasks = [asyncio.create_task(do_generate(client)) for _ in range(10)]
# This abort will force the in-flight batch to raise
# an exception, after which the engine becomes errored
await client.abort(request_id="foo")
# The batched requests fail, and each should receive
# the same exception as an MQEngineDeadError.
errors = await asyncio.gather(*tasks, return_exceptions=True)
for e in errors:
assert isinstance(e, MQEngineDeadError)
assert "KeyError" in repr(e)
client.close()
@pytest.mark.asyncio
async def test_bad_request(tmp_socket):
with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
ipc_path=tmp_socket) as engine:
client = await engine.make_client()
# Invalid request should fail, but not crash the server.
with pytest.raises(ValueError):
async for _ in client.generate(prompt="Hello my name is",
sampling_params=SamplingParams(),
request_id="abcd-1",
lora_request=LoRARequest(
"invalid-lora", 1,
"invalid-path")):
pass
# This request should be okay.
async for _ in client.generate(prompt="Hello my name is",
sampling_params=SamplingParams(),
request_id="abcd-2"):
pass
# Shutdown.
client.close()
@pytest.mark.asyncio
async def test_mp_crash_detection(monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m:
parser = FlexibleArgumentParser(
description="vLLM's remote OpenAI server.")
parser = make_arg_parser(parser)
args = parser.parse_args([])
# When LLMEngine is loaded, it will crash.
def mock_init():
raise ValueError
m.setattr(LLMEngine, "__init__", mock_init)
start = time.perf_counter()
async with build_async_engine_client(args):
pass
end = time.perf_counter()
assert end - start < 100, (
"Expected vLLM to gracefully shutdown in <100s "
"if there is an error in the startup.")
@pytest.mark.asyncio
async def test_mp_cuda_init():
# It should not crash when CUDA is initialized
# in the API server process.
import torch
torch.cuda.init()
parser = FlexibleArgumentParser(description="vLLM's remote OpenAI server.")
parser = make_arg_parser(parser)
args = parser.parse_args([])
async with build_async_engine_client(args):
pass
@pytest.mark.asyncio
async def test_engine_process_death(tmp_socket):
with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
ipc_path=tmp_socket) as engine:
client = await engine.make_client()
assert client.is_running
# kill the engine process
engine.proc.kill()
# Generate call should fail
with pytest.raises(MQEngineDeadError):
async for _ in client.generate(prompt="Hello my name is",
sampling_params=SamplingParams(),
request_id=str(uuid.uuid4())):
pass
# And the health check should show the engine is dead
with pytest.raises(RuntimeError, match="Engine process .* died"):
await client.check_health()
client.close()
def run_with_evil_input_processing(engine_args: AsyncEngineArgs,
ipc_path: str):
"""Simulate an exception while preparing inputs for the model.
In the wild, this could be something like a multimodal input processor
failing on invalid image data."""
# Make engine.
engine = MQLLMEngine.from_engine_args(
engine_args=engine_args,
usage_context=UsageContext.UNKNOWN_CONTEXT,
ipc_path=ipc_path)
runner = engine.engine.model_executor.driver_worker.worker.model_runner
# Raise error in the model runner when adding a sequence group.
# See class ModelInputForGPUBuilder
def raiser(_, seq_group_metadata: SequenceGroupMetadata):
if seq_group_metadata.request_id.startswith("evil"):
raise RAISED_ERROR(RAISED_VALUE)
runner.builder.per_seq_group_compute_fns.append(raiser)
# Run engine.
engine.start()
@pytest.mark.asyncio
async def test_failed_inputs(tmp_socket):
with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
ipc_path=tmp_socket,
run_fn=run_with_evil_input_processing) as engine:
client = await engine.make_client()
assert client.is_running
# Engine should be healthy
await client.check_health()
async def run_failing_request():
async for _ in client.generate(
prompt="Hello my name is",
sampling_params=SamplingParams(max_tokens=10),
request_id="evil" + str(uuid.uuid4())):
pass
async def run_passing_request():
async for _ in client.generate(
prompt="Hello my name is",
sampling_params=SamplingParams(max_tokens=10),
request_id=str(uuid.uuid4())):
pass
passing_tasks = [
asyncio.create_task(run_passing_request()) for _ in range(10)
]
failing_tasks = [
asyncio.create_task(run_failing_request()) for _ in range(10)
]
await asyncio.gather(*failing_tasks, return_exceptions=True)
await asyncio.gather(*passing_tasks)
# All the bad inputs should have raised
for task in failing_tasks:
with pytest.raises(RAISED_ERROR):
task.result()
# But all good inputs should have still succeeded
for task in passing_tasks:
task.result()
# And the engine should remain healthy
assert not client.errored
await client.check_health()
client.close()

View File

@@ -1,59 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Test that the MQLLMEngine is able to handle 10k concurrent requests."""
import asyncio
import tempfile
import uuid
import pytest
from tests.mq_llm_engine.utils import RemoteMQLLMEngine, generate
from vllm.engine.arg_utils import AsyncEngineArgs
MODEL = "google/gemma-1.1-2b-it"
NUM_EXPECTED_TOKENS = 10
NUM_REQUESTS = 10000
# Scenarios to test for the number of generated tokens.
ENGINE_ARGS = AsyncEngineArgs(model=MODEL)
@pytest.fixture(scope="function")
def tmp_socket():
with tempfile.TemporaryDirectory() as td:
yield f"ipc://{td}/{uuid.uuid4()}"
@pytest.mark.asyncio
async def test_load(tmp_socket):
with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
ipc_path=tmp_socket) as engine:
client = await engine.make_client()
request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]
# Create concurrent requests.
tasks = []
for request_id in request_ids:
tasks.append(
asyncio.create_task(
generate(client, request_id, NUM_EXPECTED_TOKENS)))
# Confirm that we got all the EXPECTED tokens from the requests.
failed_request_id = None
tokens = None
for task in tasks:
num_generated_tokens, request_id = await task
if (num_generated_tokens != NUM_EXPECTED_TOKENS
and failed_request_id is None):
failed_request_id = request_id
tokens = num_generated_tokens
assert failed_request_id is None, (
f"{failed_request_id} generated {tokens} but "
f"expected {NUM_EXPECTED_TOKENS}")
# Shutdown.
client.close()

View File

@@ -1,81 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import multiprocessing
from typing import Callable, Union
from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.multiprocessing.client import MQLLMEngineClient
from vllm.engine.multiprocessing.engine import MQLLMEngine
from vllm.outputs import RequestOutput
from vllm.usage.usage_lib import UsageContext
async def generate(
client: MQLLMEngineClient,
request_id: str,
num_tokens: int,
return_output: bool = False) -> Union[RequestOutput, tuple[int, str]]:
final_output = None
count = 0
async for out in client.generate(
request_id=request_id,
prompt="Hello my name is Robert and",
sampling_params=SamplingParams(max_tokens=num_tokens,
temperature=0)):
count += 1
final_output = out
await asyncio.sleep(0.)
if return_output:
return final_output
# Confirm we generated all the tokens we expected.
return count, request_id
def run_normal(engine_args: AsyncEngineArgs, ipc_path: str):
# Make engine.
engine = MQLLMEngine.from_engine_args(
engine_args=engine_args,
usage_context=UsageContext.UNKNOWN_CONTEXT,
ipc_path=ipc_path)
# Run engine.
engine.start()
class RemoteMQLLMEngine:
def __init__(self,
engine_args: AsyncEngineArgs,
ipc_path: str,
run_fn: Callable = run_normal) -> None:
self.engine_args = engine_args
self.ipc_path = ipc_path
context = multiprocessing.get_context("spawn")
self.proc = context.Process(target=run_fn,
args=(engine_args, ipc_path))
self.proc.start()
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
self.proc.kill()
async def make_client(self) -> MQLLMEngineClient:
engine_config = self.engine_args.create_engine_config()
client = MQLLMEngineClient(self.ipc_path, engine_config, self.proc.pid)
while True:
try:
await client.setup()
break
except TimeoutError:
assert self.proc.is_alive()
return client
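The make_client retry loop above is worth noting: it retries setup() on TimeoutError but asserts the engine process is still alive between attempts, so a crashed engine fails the test rather than hanging it forever. A runnable sketch of that pattern under simplified assumptions (FlakyClient is invented; no vLLM or zmq required):

import asyncio

class FlakyClient:
    """Invented client whose setup() times out twice, then succeeds."""

    def __init__(self) -> None:
        self.attempts = 0

    async def setup(self) -> None:
        self.attempts += 1
        if self.attempts < 3:
            raise TimeoutError("server not ready yet")

async def make_client() -> FlakyClient:
    client = FlakyClient()
    while True:
        try:
            await client.setup()
            break
        except TimeoutError:
            # The real harness asserted self.proc.is_alive() here, so a
            # dead engine process aborted the loop instead of spinning.
            pass
    return client

client = asyncio.run(make_client())
print(f"connected after {client.attempts} attempts")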

View File

@@ -1,145 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import uuid
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Mapping, Optional, Union
from vllm import PoolingParams
from vllm.inputs import PromptType
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.utils import Device
VLLM_RPC_SUCCESS_STR = "SUCCESS"
IPC_INPUT_EXT = "_input_socket"
IPC_OUTPUT_EXT = "_output_socket"
IPC_HEALTH_EXT = "_health_socket"
IPC_DATA_EXT = "_data_socket"
class MQEngineDeadError(RuntimeError):
pass
@dataclass
class RPCProcessRequest:
prompt: PromptType
params: Union[SamplingParams, PoolingParams]
request_id: str
lora_request: Optional[LoRARequest] = None
trace_headers: Optional[Mapping[str, str]] = None
priority: int = 0
def __init__(
self,
prompt: PromptType,
params: Union[SamplingParams, PoolingParams],
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
) -> None:
super().__init__()
self.prompt = prompt
self.params = params
self.request_id = request_id
self.lora_request = lora_request
self.trace_headers = trace_headers
self.priority = priority
@dataclass
class RPCError:
request_id: Optional[str]
is_engine_errored: bool
exception: BaseException
@dataclass
class RPCAbortRequest:
request_id: str
class RPCStartupRequest(Enum):
IS_SERVER_READY = 1
@dataclass
class RPCStartupResponse:
tracing_enabled: bool
class RPCUProfileRequest(Enum):
START_PROFILE = 1
STOP_PROFILE = 2
class RPCResetMultiModalCacheRequest(Enum):
RESET = 1
@dataclass
class RPCResetPrefixCacheRequest:
device: Device
class RPCSleepRequest(Enum):
SLEEP_LEVEL_1 = 1
SLEEP_LEVEL_2 = 2
@dataclass
class RPCWakeUpRequest:
tags: Optional[list[str]] = None
@dataclass
class RPCIsSleepingRequest:
# Set the default value of request_id to a new UUID
request_id: str = field(default_factory=lambda: str(uuid.uuid4()))
@dataclass
class RPCIsSleepingResponse:
request_id: str
is_sleeping: bool
@dataclass
class RPCLoadAdapterRequest:
lora_request: LoRARequest
# Set the default value of request_id to a new UUID
request_id: str = field(default_factory=lambda: str(uuid.uuid4()))
@dataclass
class RPCAdapterLoadedResponse:
request_id: str
lora_loaded: bool
RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest,
RPCUProfileRequest, RPCLoadAdapterRequest,
RPCResetMultiModalCacheRequest,
RPCResetPrefixCacheRequest, RPCSleepRequest,
RPCWakeUpRequest, RPCIsSleepingRequest]
REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCAdapterLoadedResponse,
RPCIsSleepingResponse, RPCError]
def ENGINE_DEAD_ERROR(
error: Optional[BaseException] = None) -> MQEngineDeadError:
if error is None:
return MQEngineDeadError(
"Engine loop is not running. Inspect the stacktrace to "
"find the original error")
return MQEngineDeadError(
"Engine loop is not running. Inspect the stacktrace to "
f"find the original error: {repr(error)}.")

View File

@@ -1,643 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import copy
import pickle
from contextlib import contextmanager, suppress
from typing import (Any, AsyncGenerator, Dict, Iterable, Iterator, List,
Mapping, Optional, Union)
import cloudpickle
import psutil
import zmq
import zmq.asyncio
from zmq import Frame # type: ignore[attr-defined]
from zmq.asyncio import Socket
from vllm import PoolingParams
from vllm.config import DecodingConfig, ModelConfig, VllmConfig
from vllm.core.scheduler import SchedulerOutputs
# yapf conflicts with isort for this block
# yapf: disable
from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
IPC_HEALTH_EXT, IPC_INPUT_EXT,
IPC_OUTPUT_EXT, RPC_REQUEST_T,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
RPCAdapterLoadedResponse, RPCError,
RPCIsSleepingRequest,
RPCIsSleepingResponse,
RPCLoadAdapterRequest,
RPCProcessRequest,
RPCResetMultiModalCacheRequest,
RPCResetPrefixCacheRequest,
RPCSleepRequest, RPCStartupRequest,
RPCStartupResponse,
RPCUProfileRequest, RPCWakeUpRequest)
from vllm.engine.protocol import EngineClient
# yapf: enable
from vllm.envs import VLLM_RPC_TIMEOUT
from vllm.inputs import PromptType
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.utils import Device
logger = init_logger(__name__)
class MQClientClosedError(Exception):
"""Exception class raised when the client is used post-close.
The client can be closed, which closes the ZMQ context. This normally
happens on server shutdown. In some cases, methods like abort and
do_log_stats will still be called and then try to open a socket, which
causes a ZMQError and creates a huge stack trace.
So, we throw this error such that we can suppress it.
"""
class MQLLMEngineClient(EngineClient):
"""A client wrapper for MQLLMEngine that conforms to the
EngineClient protocol.
MQLLMEngine and MQLLMEngineClient are intended to run in separate
processes communicating via zeromq ipc sockets.
The entrypoint to MQLLMEngineClient is through the generate()
method. On generate(), MQLLMEngineClient does three things:
- Creates an asyncio output queue
- Sends an RPCGenerateRequest to the MQLLMEngine via zmq
- Pulls RequestOutputs from its queue and yields them
MQLLMEngineClient runs two background loops:
- output_loop: the output loop pulls List[RequestOutput]
from the MQLLMEngine via zmq (each list is the output
of one engine_step in the LLMEngine). It then parses
the list and pushes individual request_outputs into
the corresponding output_queue such that they can be
consumed by the .generate() method.
- health_loop: the health loop queries the health socket
every N seconds, confirming the engine is healthy
"""
def __init__(self, ipc_path: str, engine_config: VllmConfig,
engine_pid: int):
self.context = zmq.asyncio.Context()
self._errored_with: Optional[BaseException] = None
# Get the configs.
self.vllm_config = engine_config
self.model_config = engine_config.model_config
self.decoding_config = engine_config.decoding_config
if self.vllm_config.model_config.skip_tokenizer_init:
self.tokenizer = None
else:
# Create the tokenizer group.
self.tokenizer = init_tokenizer_from_configs(
model_config=self.model_config,
scheduler_config=engine_config.scheduler_config,
lora_config=engine_config.lora_config)
self.input_preprocessor = InputPreprocessor(self.model_config,
self.tokenizer)
# Send RPCGenerateRequest to the MQLLMEngine.
self.input_socket: Socket = self.context.socket(zmq.constants.PUSH)
self.input_socket.connect(f"{ipc_path}{IPC_INPUT_EXT}")
# Receive streams of RequestOutput from the MQLLMEngine.
self.output_socket: Socket = self.context.socket(zmq.constants.PULL)
self.output_socket.connect(f"{ipc_path}{IPC_OUTPUT_EXT}")
# IPC path for acking heartbeats.
self.heartbeat_socket: Socket = self.context.socket(zmq.constants.PULL)
self.heartbeat_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}")
# IPC path for the data socket.
self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"
# Stream for each individual request.
self.output_queues: Dict[str, asyncio.Queue] = {}
# Loop to handle output of the LLMEngine periodically.
# Started after the MQLLMEngine is ready so that we can
# build the Client in an executor to enable clean shutdown.
self.output_loop: Optional[asyncio.Task] = None
# Loop to check health of the LLMEngine periodically.
# Started after the MQLLMEngine is ready.
self.health_loop: Optional[asyncio.Task] = None
self._engine_process = psutil.Process(engine_pid)
@staticmethod
def is_unsupported_config(vllm_config: VllmConfig):
# Pipeline parallel not yet supported
return vllm_config.parallel_config.pipeline_parallel_size > 1
@contextmanager
def get_data_socket(self) -> Iterator[Socket]:
socket = self.context.socket(zmq.constants.DEALER)
try:
socket.connect(self.data_ipc_path)
yield socket
finally:
socket.close(linger=0)
async def run_heartbeat_loop(self, timeout: int):
"""Background loop that continually checks to ensure the engine process
is still alive.
"""
try:
while True:
# Check if the engine process is running:
if not self._engine_process.is_running() or (
self._engine_process.status() == psutil.STATUS_ZOMBIE):
# NB: is_running() returns True for zombies
self._set_errored(
RuntimeError(
f"Engine process (pid {self._engine_process.pid}) "
"died."))
break
if await self.heartbeat_socket.poll(timeout=timeout):
# Heartbeat received: check the message
await self._check_success(
error_message="Heartbeat failed.",
socket=self.heartbeat_socket)
logger.debug("Heartbeat successful.")
except asyncio.CancelledError:
logger.debug("Shutting down MQLLMEngineClient check health loop.")
except psutil.NoSuchProcess:
self._set_errored(
RuntimeError(
f"Engine process (pid {self._engine_process.pid}) died."))
except Exception as e:
self._set_errored(e)
async def run_output_handler_loop(self):
"""Get RequestOutputs from Engine and stream to Request Queues"""
try:
while True:
# Poll, checking for ENGINE_DEAD
while await self.output_socket.poll(timeout=VLLM_RPC_TIMEOUT
) == 0:
logger.debug("Waiting for output from MQLLMEngine.")
# If errored, alert all running requests.
if self.errored:
for queue_j in tuple(self.output_queues.values()):
queue_j.put_nowait(
ENGINE_DEAD_ERROR(self._errored_with))
return
message: Frame = await self.output_socket.recv(copy=False)
request_outputs = pickle.loads(message.buffer)
is_error = isinstance(request_outputs,
(BaseException, RPCError))
if is_error:
if isinstance(request_outputs, RPCError):
rpc_error: RPCError = request_outputs
request_id = rpc_error.request_id
exception = rpc_error.exception
is_engine_errored = rpc_error.is_engine_errored
else:
# MQLLMEngine should always return an RPCError to
# the output_socket when an issue arises.
# If we are here, we are in a bad state and
# should shut down the server.
error: BaseException = request_outputs
logger.error(
"Received Exception %s rather than RPCError from "
"MPLLMEngine. This should never happen.", error)
request_id = None
exception = error
is_engine_errored = True
# Set to error state only on engine critical error
# (and record only the first one)
if is_engine_errored and not self._errored_with:
self._errored_with = exception
# If the engine is errored, it can no longer accept new
# requests regardless of the exception type, so we have
# to inform the currently processed requests that they
# failed as well. Send back a dead engine error to give
# this feedback and also to hint to the server that it
# should shut down next.
exception = self.dead_error
if request_id is None:
# If request_id is None, then the engine raised an
# exception for a batch, and we may not know the
# request that caused it, neither if it was actually
# caused by any of them (e.g. CUDA OOM). Therefore we
# broadcast the same exception for all requests.
for queue_i in tuple(self.output_queues.values()):
queue_i.put_nowait(exception)
else:
queue = self.output_queues.get(request_id)
if queue is not None:
queue.put_nowait(exception)
# Put each output into the appropriate queue.
elif isinstance(
request_outputs,
(RPCAdapterLoadedResponse, RPCIsSleepingResponse)):
self._add_output(request_outputs)
else:
for request_output in request_outputs:
self._add_output(request_output)
except asyncio.CancelledError:
logger.debug("Shutting down MQLLMEngineClient output handler.")
def _add_output(self, request_output: Union[RequestOutput,
RPCAdapterLoadedResponse,
RPCIsSleepingResponse]):
queue = self.output_queues.get(request_output.request_id)
if queue is not None:
queue.put_nowait(request_output)
async def setup(self):
"""Set up the client before it starts sending server requests."""
# Start output_loop
if self.output_loop is None:
# Start the loop only once: multiple concurrent output_loops
# would cause race conditions and out-of-order tokens
# returned by the engine. Note that setup() is called
# multiple times during engine startup.
self.output_loop = asyncio.create_task(
self.run_output_handler_loop())
with self.get_data_socket() as socket:
# Wait until server is ready.
response = await self._wait_for_server_rpc(socket)
self.tracing_flag = response.tracing_enabled
# Start health_loop.
if self.health_loop is None:
self.health_loop = asyncio.create_task(
self.run_heartbeat_loop(timeout=VLLM_RPC_TIMEOUT))
def close(self):
"""Destroy the ZeroMQ Context."""
# Close all sockets and terminate the context.
self.context.destroy(linger=0)
# Cancel background tasks.
if self.health_loop is not None:
self.health_loop.cancel()
if self.output_loop is not None:
self.output_loop.cancel()
def _set_errored(self, e: BaseException):
logger.exception(repr(e))
if self._errored_with is None:
self._errored_with = e
@staticmethod
async def _send_get_data_rpc_request(request: RPCStartupRequest,
expected_type: Any,
error_message: str,
socket: Socket) -> Any:
"""Send an RPC request that is expecting data back."""
# Ping RPCServer with a request.
await socket.send_multipart((pickle.dumps(request), ), copy=False)
# Make sure the server responds in time.
if await socket.poll(timeout=VLLM_RPC_TIMEOUT) == 0:
raise TimeoutError("RPCServer didn't reply within "
f"{VLLM_RPC_TIMEOUT} ms")
# Await the data from the Server.
frame = await socket.recv(copy=False)
data = pickle.loads(frame.buffer)
if isinstance(data, BaseException):
raise data
elif not isinstance(data, expected_type):
raise ValueError(error_message)
return data
@staticmethod
async def _send_one_way_rpc_request(request: RPC_REQUEST_T,
socket: Socket):
"""Send one-way RPC request to trigger an action."""
if socket.closed:
raise MQClientClosedError()
await socket.send_multipart((pickle.dumps(request), ))
async def _await_ack(self, error_message: str, socket: Socket):
"""Await acknowledgement that a request succeeded."""
if socket.closed:
raise MQClientClosedError()
if await socket.poll(timeout=VLLM_RPC_TIMEOUT) == 0:
raise TimeoutError("MQLLMEngine didn't reply within "
f"{VLLM_RPC_TIMEOUT}ms")
await self._check_success(error_message, socket)
@staticmethod
async def _check_success(error_message: str, socket: Socket):
"""Confirm that socket has a VLLM_RPC_SUCCESS_STR message"""
if socket.closed:
raise MQClientClosedError()
frame = await socket.recv(copy=False)
response = pickle.loads(frame.buffer)
# Raise error if unsuccessful
if isinstance(response, BaseException):
raise response
elif (not isinstance(response, str)
or response != VLLM_RPC_SUCCESS_STR):
raise ValueError(error_message)
async def get_input_preprocessor(self) -> InputPreprocessor:
return self.input_preprocessor
async def get_tokenizer(self, lora_request: Optional[LoRARequest] = None):
if self.tokenizer is None:
return None
else:
return await self.tokenizer.get_lora_tokenizer_async(lora_request)
async def get_vllm_config(self) -> VllmConfig:
return self.vllm_config
async def get_decoding_config(self) -> DecodingConfig:
return self.decoding_config
async def get_model_config(self) -> ModelConfig:
return self.model_config
async def is_tracing_enabled(self) -> bool:
return self.tracing_flag
async def _wait_for_server_rpc(self, socket: Socket) -> RPCStartupResponse:
"""Wait for the RPCServer to start up."""
return await self._send_get_data_rpc_request(
request=RPCStartupRequest.IS_SERVER_READY,
expected_type=RPCStartupResponse,
error_message="Unable to start RPC Server",
socket=socket)
async def abort(self, request_id: Union[str, Iterable[str]]):
"""Send an ABORT_REQUEST signal to the RPC Server"""
if not isinstance(request_id, str):
raise RuntimeError("Only single-request abort supported in"
" deprecated V0")
with suppress(MQClientClosedError):
await self._send_one_way_rpc_request(
request=RPCAbortRequest(request_id), socket=self.input_socket)
async def do_log_stats(
self,
scheduler_outputs: Optional[SchedulerOutputs] = None,
model_output: Optional[List[SamplerOutput]] = None,
) -> None:
"""
Ignore do_log_stats (handled on MQLLMEngine polling)
"""
pass
async def check_health(self):
"""
The health check loop probes the engine's health every
N seconds and sets _errored_with if the engine is
unhealthy.
"""
if self._errored_with is not None:
raise self._errored_with
@property
def is_running(self) -> bool:
return not self.errored
@property
def is_stopped(self) -> bool:
return self.errored
@property
def errored(self) -> bool:
return self._errored_with is not None
@property
def dead_error(self) -> BaseException:
return ENGINE_DEAD_ERROR(self._errored_with)
def generate(
self,
prompt: PromptType,
sampling_params: SamplingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
) -> AsyncGenerator[RequestOutput, None]:
"""Generate outputs for a request.
Generate outputs for a request. This method is a coroutine. It adds the
request into the waiting queue of the LLMEngine and streams the outputs
from the LLMEngine to the caller.
Args:
prompt: The prompt to the LLM. See
[`PromptType`][vllm.inputs.PromptType] for more details about
the format of each input.
sampling_params: The sampling parameters of the request.
request_id: The unique id of the request.
lora_request: LoRA request to use for generation, if any.
trace_headers: OpenTelemetry trace headers.
priority: Priority of the request (lower means earlier handling).
Any priority other than 0 will lead to an error if the
scheduling policy is not "priority".
"""
return self._process_request(prompt, sampling_params, request_id,
lora_request, trace_headers, priority)
def encode(
self,
prompt: PromptType,
pooling_params: PoolingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> AsyncGenerator[PoolingRequestOutput, None]:
raise NotImplementedError(
"Pooling models are not supported in vLLM V0")
async def _process_request(
self,
prompt: PromptType,
params: SamplingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
) -> AsyncGenerator[RequestOutput, None]:
"""Send an RPCGenerateRequest to the RPCServer and stream responses."""
# If already dead, error out.
if self._errored_with is not None:
raise ENGINE_DEAD_ERROR(self._errored_with)
# Ensure the request id is unique among running requests
if request_id in self.output_queues:
raise ValueError(f"Request {request_id} already exists")
# 1) Create output queue for this request.
queue: asyncio.Queue[Union[RequestOutput,
BaseException]] = asyncio.Queue()
self.output_queues[request_id] = queue
try:
# 2) Detach logits processors so that they can be pickled
# separately (may require cloudpickle which is slower)
if params.logits_processors:
# Defensive shallow copy
params = copy.copy(params)
logits_processors = params.logits_processors
params.logits_processors = None
lp_bytes = cloudpickle.dumps(logits_processors)
else:
lp_bytes = None
request_bytes = pickle.dumps(
RPCProcessRequest(
prompt=prompt,
params=params,
request_id=request_id,
lora_request=lora_request,
trace_headers=trace_headers,
priority=priority,
))
# 3) Send the RPCGenerateRequest to the MQLLMEngine.
parts = (request_bytes,
lp_bytes) if lp_bytes else (request_bytes, )
await self.input_socket.send_multipart(parts, copy=False)
# 4) Stream the RequestOutputs from the output queue. Note
# that the output_loop pushes RequestOutput objects to this
# queue after pulling them from the zmq socket.
finished = False
try:
while not finished:
request_output = await queue.get()
if isinstance(request_output, BaseException):
raise request_output
finished = request_output.finished
yield request_output
finally:
# Request was canceled by the client.
if not finished and not self.errored:
await self.abort(request_id)
finally:
self.output_queues.pop(request_id)
async def start_profile(self) -> None:
"""Start profiling the engine"""
await self._send_one_way_rpc_request(
request=RPCUProfileRequest.START_PROFILE, socket=self.input_socket)
async def stop_profile(self) -> None:
"""Stop profiling the engine"""
await self._send_one_way_rpc_request(
request=RPCUProfileRequest.STOP_PROFILE, socket=self.input_socket)
async def reset_mm_cache(self) -> None:
"""Reset the multi-modal cache"""
await self._send_one_way_rpc_request(
request=RPCResetMultiModalCacheRequest.RESET,
socket=self.input_socket)
async def reset_prefix_cache(self,
device: Optional[Device] = None) -> None:
"""Reset the prefix cache"""
await self._send_one_way_rpc_request(
request=RPCResetPrefixCacheRequest(device),
socket=self.input_socket)
async def sleep(self, level: int = 1) -> None:
"""Sleep the engine for a given level"""
return await self._send_one_way_rpc_request(
request=RPCSleepRequest(level), socket=self.input_socket)
async def wake_up(self, tags: Optional[list[str]] = None) -> None:
"""Wake up the engine"""
return await self._send_one_way_rpc_request(
request=RPCWakeUpRequest(tags), socket=self.input_socket)
async def is_sleeping(self) -> bool:
"""Check whether the engine is sleeping"""
request = RPCIsSleepingRequest()
queue: asyncio.Queue[Union[BaseException,
RPCIsSleepingResponse]] = asyncio.Queue()
self.output_queues[request.request_id] = queue
request_bytes = pickle.dumps(request)
await self.input_socket.send_multipart((request_bytes, ), copy=False)
request_output = await queue.get()
self.output_queues.pop(request.request_id)
if isinstance(request_output, BaseException):
raise request_output
return request_output.is_sleeping
async def add_lora(self, lora_request: LoRARequest) -> bool:
"""Load a new LoRA adapter into the engine for future requests."""
# Uses the same I/O as generate requests
request = RPCLoadAdapterRequest(lora_request)
# Create output queue for this request.
queue: asyncio.Queue[Union[
BaseException, RPCAdapterLoadedResponse]] = asyncio.Queue()
self.output_queues[request.request_id] = queue
# Send the request
request_bytes = pickle.dumps(request)
await self.input_socket.send_multipart((request_bytes, ), copy=False)
# Wait for the response
request_output = await queue.get()
self.output_queues.pop(request.request_id)
# Raise on error, otherwise happily return None
if isinstance(request_output, BaseException):
raise request_output
return request_output.lora_loaded
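The load-bearing mechanism in this client is demultiplexing: run_output_handler_loop drains one inbound stream and fans results out into a per-request asyncio.Queue, which generate() then consumes until it sees a finished flag. A simplified, runnable sketch of just that mechanism (plain asyncio; the zmq transport, pickling, and error paths are omitted):

import asyncio

output_queues: dict[str, asyncio.Queue] = {}

async def output_loop(stream: asyncio.Queue) -> None:
    # Stands in for pulling (request_id, output, finished) off zmq.
    while True:
        request_id, output, finished = await stream.get()
        output_queues[request_id].put_nowait((output, finished))

async def generate(request_id: str) -> list[str]:
    queue: asyncio.Queue = asyncio.Queue()
    output_queues[request_id] = queue
    try:
        outputs, finished = [], False
        while not finished:
            output, finished = await queue.get()
            outputs.append(output)
        return outputs
    finally:
        output_queues.pop(request_id)

async def main() -> None:
    stream: asyncio.Queue = asyncio.Queue()
    asyncio.ensure_future(output_loop(stream))
    for token, done in [("Hello", False), (" world", True)]:
        stream.put_nowait(("req-0", token, done))
    print(await generate("req-0"))

asyncio.run(main())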

View File

@@ -1,470 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pickle
import signal
from contextlib import contextmanager
from typing import Iterator, List, Optional, Union
import cloudpickle
import zmq
from vllm import AsyncEngineArgs, SamplingParams
from vllm.config import VllmConfig
from vllm.engine.llm_engine import LLMEngine
# yapf conflicts with isort for this block
# yapf: disable
from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
IPC_HEALTH_EXT, IPC_INPUT_EXT,
IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
RPCAdapterLoadedResponse, RPCError,
RPCIsSleepingRequest,
RPCIsSleepingResponse,
RPCLoadAdapterRequest,
RPCProcessRequest,
RPCResetMultiModalCacheRequest,
RPCResetPrefixCacheRequest,
RPCSleepRequest, RPCStartupRequest,
RPCStartupResponse,
RPCUProfileRequest, RPCWakeUpRequest)
# yapf: enable
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
from vllm.usage.usage_lib import UsageContext
from vllm.utils import deprecate_kwargs
from vllm.worker.model_runner_base import InputProcessingError
logger = init_logger(__name__)
POLLING_TIMEOUT_MS = 10000
HEALTHY_RESPONSE = (pickle.dumps(VLLM_RPC_SUCCESS_STR), )
class MQLLMEngine:
"""A multiprocessing wrapper for
[`LLMEngine`][vllm.engine.llm_engine.LLMEngine].
This class is used to wrap the
[`LLMEngine`][vllm.engine.llm_engine.LLMEngine] class to enable use
in a concurrent manner. It runs a background loop and uses zeromq to
receive new requests and stream outputs incrementally via ipc.
The [`LLMEngine`][vllm.engine.llm_engine.LLMEngine] generate or encode
process is kicked off when a new RPCProcessRequest is received by the
input_socket.
The self.engine_loop checks the input_socket for new requests,
adds them to the LLMEngine if there are any, calls the internal
[`LLMEngine.step()`][vllm.engine.llm_engine.LLMEngine.step], and sends
the RequestOutputs back over the output_socket.
If use_async_sockets is set, the logic associated with reading new
requests from the socket and sending data to the socket is passed
as a callback to the llm_engine, which calls the logic asynchronously
such that the IPC can be overlapped with the GPU.
Args:
ipc_path: Base path for zeromq interprocess messaging
use_async_sockets: Whether to make send/recv async with GPU
log_requests: Whether to log the requests.
*args: Arguments for [`LLMEngine`][vllm.engine.llm_engine.LLMEngine].
**kwargs: Arguments for [`LLMEngine`][vllm.engine.llm_engine.LLMEngine].
"""
def __init__(self,
ipc_path: str,
use_async_sockets: bool,
*args,
log_requests: bool = True,
**kwargs) -> None:
# For MQLLMEngine, we can use cached outputs, since each new request
# output is immediately pickled and sent over the socket, which frees
# the Python object to be reused again.
kwargs['use_cached_outputs'] = True
self.engine = LLMEngine(*args, **kwargs)
self.log_requests = log_requests
self.use_async_sockets = use_async_sockets
if self.use_async_sockets:
self.engine.process_request_outputs_callback = \
self._async_socket_engine_callback
self.ctx = zmq.Context() # type: ignore[attr-defined]
# Receive input from the client.
self.input_socket = self.ctx.socket(zmq.constants.PULL)
self.input_socket.bind(f"{ipc_path}{IPC_INPUT_EXT}")
# Send output stream back to client.
self.output_socket = self.ctx.socket(zmq.constants.PUSH)
self.output_socket.bind(f"{ipc_path}{IPC_OUTPUT_EXT}")
# Send heartbeats back to client.
self.heartbeat_socket = self.ctx.socket(zmq.constants.PUSH)
self.heartbeat_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}")
# IPC path for the data socket.
self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"
# Error state.
self._errored_with: Optional[BaseException] = None
@property
def dead_error(self) -> BaseException:
if self._errored_with is not None:
return ENGINE_DEAD_ERROR(self._errored_with)
else:
return ENGINE_DEAD_ERROR()
@classmethod
@deprecate_kwargs(
"disable_log_requests",
additional_message=("This argument will have no effect. "
"Use `enable_log_requests` instead."),
)
def from_vllm_config(
cls,
vllm_config: VllmConfig,
usage_context: UsageContext,
enable_log_requests: bool,
disable_log_stats: bool,
ipc_path: str,
disable_log_requests: bool = True, # Deprecated, will be removed
) -> "MQLLMEngine":
# Setup plugins for each process
from vllm.plugins import load_general_plugins
load_general_plugins()
use_async_sockets = vllm_config.model_config.use_async_output_proc
return cls(
vllm_config=vllm_config,
executor_class=LLMEngine._get_executor_cls(vllm_config),
ipc_path=ipc_path,
usage_context=usage_context,
use_async_sockets=use_async_sockets,
log_requests=enable_log_requests,
log_stats=(not disable_log_stats),
)
@staticmethod
def from_engine_args(engine_args: AsyncEngineArgs,
usage_context: UsageContext, ipc_path: str):
"""Creates an MQLLMEngine from the engine arguments."""
vllm_config = engine_args.create_engine_config(usage_context)
return MQLLMEngine.from_vllm_config(
ipc_path=ipc_path,
vllm_config=vllm_config,
usage_context=usage_context,
enable_log_requests=engine_args.enable_log_requests,
disable_log_stats=engine_args.disable_log_stats,
)
def start(self):
try:
try:
logger.debug("Starting Startup Loop.")
self.run_startup_loop()
logger.debug("Starting Engine Loop.")
self.run_engine_loop()
except Exception as e:
logger.exception(repr(e))
except KeyboardInterrupt:
logger.debug("Shutting down MQLLMEngine.")
finally:
logger.debug("MQLLMEngine is shut down.")
self.cleanup()
def cleanup(self):
"""Cleanup zeromq state on shutdown."""
# Closes all sockets and destroys context.
self.ctx.destroy(linger=0)
del self.engine
@contextmanager
def make_data_socket(
self) -> Iterator[zmq.Socket]: # type: ignore[name-defined]
socket = self.ctx.socket(zmq.constants.ROUTER)
try:
socket.bind(self.data_ipc_path)
yield socket
finally:
socket.close(linger=0)
def run_startup_loop(self) -> None:
"""Startup loop for sending data from Engine -> Client."""
with self.make_data_socket() as socket:
response: Union[RPCStartupResponse, BaseException]
try:
identity, message = socket.recv_multipart(copy=False)
request: RPCStartupRequest = pickle.loads(message.buffer)
# Handle the query from the Client.
if request == RPCStartupRequest.IS_SERVER_READY:
tracing_enabled = self.engine.is_tracing_enabled()
response = RPCStartupResponse(
tracing_enabled=tracing_enabled)
except Exception as e:
response = e
socket.send_multipart((identity, pickle.dumps(response)),
copy=False)
def run_engine_loop(self):
"""Core busy loop of the LLMEngine."""
while True:
if not self.engine.has_unfinished_requests():
# Poll until there is work to do.
while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
# When there's no work, check on engine health and send
# health status back to client
self._health_check()
self.engine.do_log_stats()
logger.debug("Waiting for new requests in engine loop.")
# Handle any input from the client.
self.handle_new_input()
# Engine step.
request_outputs = self.engine_step()
# Send request outputs (if async, done in engine_step callback).
if not self.use_async_sockets:
self._send_outputs(request_outputs)
def engine_step(self) -> List[RequestOutput]:
"""Engine step wrapper with error handling."""
try:
return self.engine.step()
except SystemExit:
raise
except InputProcessingError as e:
# Special case where we handle an error preparing the inputs for
# a single request in the batch
rpc_err = RPCError(request_id=e.request_id,
is_engine_errored=False,
exception=e.__cause__)
self._send_outputs(rpc_err)
return []
except BaseException as e:
self._set_errored(e)
rpc_err = RPCError(request_id=None,
is_engine_errored=True,
exception=e)
self._send_outputs(rpc_err)
raise e
def handle_new_input(self):
"""Handle new input from the socket"""
try:
while self.input_socket.poll(timeout=0) != 0:
frames = self.input_socket.recv_multipart(copy=False)
request = pickle.loads(frames[0].buffer)
if isinstance(request, RPCProcessRequest):
if len(frames) > 1:
# Use cloudpickle for logits processors
assert isinstance(request.params, SamplingParams)
lprocs = cloudpickle.loads(frames[1].buffer)
request.params.logits_processors = lprocs
self._handle_process_request(request)
elif isinstance(request, RPCAbortRequest):
self._handle_abort_request(request)
elif isinstance(request, RPCUProfileRequest):
if request == RPCUProfileRequest.START_PROFILE:
self.start_profile()
else:
self.stop_profile()
elif isinstance(request, RPCLoadAdapterRequest):
self._handle_load_adapter_request(request)
elif isinstance(request, RPCResetMultiModalCacheRequest):
self.reset_mm_cache()
elif isinstance(request, RPCResetPrefixCacheRequest):
self.reset_prefix_cache()
elif isinstance(request, RPCSleepRequest):
self.sleep(request.value)
elif isinstance(request, RPCWakeUpRequest):
self.wake_up(request.tags)
elif isinstance(request, RPCIsSleepingRequest):
self._handle_is_sleeping_request(request)
else:
raise ValueError("Unknown RPCRequest Type: "
f"{type(request)}")
except Exception as e:
self._set_errored(e)
self._send_unhealthy(e)
raise e from None
def _handle_process_request(self, request: RPCProcessRequest):
"""Handle RPCProcessRequest by adding it to the LLMEngine."""
request_id = request.request_id
if self._errored_with is not None:
rpc_err = RPCError(request_id=request_id,
is_engine_errored=True,
exception=ENGINE_DEAD_ERROR(self._errored_with))
self._send_outputs(rpc_err)
try:
self.engine.add_request(request_id=request_id,
prompt=request.prompt,
params=request.params,
lora_request=request.lora_request,
trace_headers=request.trace_headers,
priority=request.priority)
if self.log_requests:
logger.info("Added request %s.", request.request_id)
except Exception as e:
# We do not set self._errored = True here, since the error
# is due to an issue adding this request to the engine,
# rather than an issue with the engine itself.
logger.debug("Failed to add request %s to engine. %s",
request.request_id, e)
is_errored = self._errored_with is not None
rpc_err = RPCError(request_id=request_id,
is_engine_errored=is_errored,
exception=e)
self._send_outputs(rpc_err)
# Remove request from the engine.
self.engine.abort_request(request_id)
def _handle_abort_request(self, request: RPCAbortRequest):
self.engine.abort_request(request.request_id)
if self.log_requests:
logger.info("Aborted request %s.", request.request_id)
def _handle_load_adapter_request(self, request: RPCLoadAdapterRequest):
try:
lora_loaded = self.engine.add_lora(request.lora_request)
except BaseException as e:
# Send back an error if the adapter fails to load
rpc_err = RPCError(request_id=request.request_id,
is_engine_errored=False,
exception=e)
self._send_outputs(rpc_err)
return
# Otherwise, send back the successful load message
self._send_outputs(
RPCAdapterLoadedResponse(request_id=request.request_id,
lora_loaded=lora_loaded))
def _handle_is_sleeping_request(self, request: RPCIsSleepingRequest):
is_sleeping = self.is_sleeping()
self._send_outputs(
RPCIsSleepingResponse(request_id=request.request_id,
is_sleeping=is_sleeping))
def _health_check(self):
# Send unhealthy if engine has already errored
if self._errored_with is not None:
self._send_unhealthy(self._errored_with)
try:
self.engine.check_health()
self._send_healthy()
except Exception as e:
self._set_errored(e)
self._send_unhealthy(e)
def _send_outputs(self, outputs: REQUEST_OUTPUTS_T):
"""Send outputs back to the engine client. These can be:
- Exceptions
- A list of generation outputs
- A response from loading a lora adapter
"""
if outputs:
try:
from ray.exceptions import RayTaskError
# RayTaskError might not be picklable here. We need to unpack the
# underlying exception as the real exception in the output.
if (isinstance(outputs, RPCError)
and isinstance(outputs.exception, RayTaskError)):
outputs.exception = outputs.exception.cause
except ImportError:
pass
output_bytes = pickle.dumps(outputs)
self.output_socket.send_multipart((output_bytes, ), copy=False)
def _send_healthy(self):
"""Send HEALTHY message to RPCClient."""
if not self.heartbeat_socket.closed:
self.heartbeat_socket.send_multipart(HEALTHY_RESPONSE, copy=False)
def _send_unhealthy(self, error: BaseException):
"""Send UNHEALTHY message to RPCClient."""
if not self.heartbeat_socket.closed:
error_bytes = pickle.dumps(error)
self.heartbeat_socket.send_multipart((error_bytes, ), copy=False)
def _async_socket_engine_callback(self,
request_outputs: REQUEST_OUTPUTS_T):
"""Callback used by engine to make socket handling async with GPU."""
self._send_outputs(request_outputs)
self.handle_new_input()
def _set_errored(self, e: BaseException):
"""Log and set errored status if this is the first issue."""
if self._errored_with is None:
self._errored_with = e
def start_profile(self) -> None:
self.engine.start_profile()
def stop_profile(self) -> None:
self.engine.stop_profile()
def reset_mm_cache(self) -> bool:
return self.engine.reset_mm_cache()
def reset_prefix_cache(self) -> bool:
return self.engine.reset_prefix_cache()
def sleep(self, level: int = 1) -> None:
self.engine.sleep(level)
def wake_up(self, tags: Optional[list[str]] = None) -> None:
self.engine.wake_up(tags)
def is_sleeping(self) -> bool:
return self.engine.is_sleeping()
def signal_handler(*_) -> None:
raise KeyboardInterrupt("MQLLMEngine terminated")
def run_mp_engine(vllm_config: VllmConfig, usage_context: UsageContext,
ipc_path: str, disable_log_stats: bool,
enable_log_requests: bool, engine_alive):
try:
# Ensure we can serialize transformer config before spawning
maybe_register_config_serialize_by_value()
engine = MQLLMEngine.from_vllm_config(
vllm_config=vllm_config,
usage_context=usage_context,
disable_log_stats=disable_log_stats,
enable_log_requests=enable_log_requests,
ipc_path=ipc_path)
signal.signal(signal.SIGTERM, signal_handler)
engine.start()
except BaseException as e:
logger.exception(e)
engine_alive.value = False
raise e from None
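The heartbeat channel above is deliberately simple: _send_healthy pushes a pickled success string and _send_unhealthy pushes a pickled exception, while the client polls the other end and decides which case it got. A minimal sketch of that handshake, assuming pyzmq is installed and a Unix ipc transport is available (the socket path is invented):

import pickle

import zmq

SUCCESS = "SUCCESS"
PATH = "ipc:///tmp/demo_health_socket"

ctx = zmq.Context()
engine_side = ctx.socket(zmq.PUSH)
engine_side.bind(PATH)
client_side = ctx.socket(zmq.PULL)
client_side.connect(PATH)

# Engine: report healthy once, then report a failure.
engine_side.send(pickle.dumps(SUCCESS))
engine_side.send(pickle.dumps(RuntimeError("engine died")))

# Client: poll with a timeout and interpret each heartbeat.
for _ in range(2):
    if client_side.poll(timeout=1000):
        response = pickle.loads(client_side.recv())
        if isinstance(response, BaseException):
            print(f"unhealthy: {response!r}")
        else:
            print("healthy")

ctx.destroy(linger=0)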

View File

@@ -12,7 +12,6 @@ from fastapi import FastAPI, Request, Response
from vllm import envs
from vllm.engine.async_llm_engine import AsyncEngineDeadError
-from vllm.engine.multiprocessing import MQEngineDeadError
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.constants import (H11_MAX_HEADER_COUNT_DEFAULT,
H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT)
@@ -156,7 +155,6 @@ def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None:
@app.exception_handler(RuntimeError)
@app.exception_handler(AsyncEngineDeadError)
-@app.exception_handler(MQEngineDeadError)
@app.exception_handler(EngineDeadError)
@app.exception_handler(EngineGenerateError)
async def runtime_exception_handler(request: Request, __):

View File

@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
-import atexit
import gc
import importlib
import inspect
@@ -17,7 +16,6 @@ import uuid
from argparse import Namespace
from collections.abc import AsyncGenerator, AsyncIterator, Awaitable
from contextlib import asynccontextmanager
-from functools import partial
from http import HTTPStatus
from typing import Annotated, Any, Callable, Optional
@@ -42,8 +40,6 @@ import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine  # type: ignore
-from vllm.engine.multiprocessing.client import MQLLMEngineClient
-from vllm.engine.multiprocessing.engine import run_mp_engine
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import (load_chat_template,
resolve_hf_chat_template,
@@ -102,13 +98,10 @@ from vllm.entrypoints.utils import (cli_env_setup, load_aware_call,
log_non_default_args, with_cancellation)
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParserManager
-from vllm.transformers_utils.config import (
-    maybe_register_config_serialize_by_value)
from vllm.transformers_utils.tokenizer import MistralTokenizer
from vllm.usage.usage_lib import UsageContext
from vllm.utils import (Device, FlexibleArgumentParser, decorate_logs,
-                        get_open_zmq_ipc_path, is_valid_ipv6_address,
-                        set_ulimit)
+                        is_valid_ipv6_address, set_ulimit)
from vllm.v1.metrics.prometheus import get_prometheus_registry
from vllm.version import __version__ as VLLM_VERSION
@@ -237,8 +230,7 @@ async def build_async_engine_client_from_engine_args(
async_llm.shutdown()
# V0 AsyncLLM.
-elif (MQLLMEngineClient.is_unsupported_config(vllm_config)
-      or disable_frontend_multiprocessing):
+else:
engine_client: Optional[EngineClient] = None
try:
@@ -252,96 +244,6 @@
if engine_client and hasattr(engine_client, "shutdown"):
engine_client.shutdown()
# V0 MQLLMEngine.
else:
if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
# Make TemporaryDirectory for prometheus multiprocessing
# Note: global TemporaryDirectory will be automatically
# cleaned up upon exit.
global prometheus_multiproc_dir
prometheus_multiproc_dir = tempfile.TemporaryDirectory()
os.environ[
"PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
else:
logger.warning(
"Found PROMETHEUS_MULTIPROC_DIR was set by user. "
"This directory must be wiped between vLLM runs or "
"you will find inaccurate metrics. Unset the variable "
"and vLLM will properly handle cleanup.")
# Select random path for IPC.
ipc_path = get_open_zmq_ipc_path()
logger.debug("Multiprocessing frontend to use %s for IPC Path.",
ipc_path)
# Start RPCServer in separate process (holds the LLMEngine).
# the current process might have CUDA context,
# so we need to spawn a new process
context = multiprocessing.get_context("spawn")
# Ensure we can serialize transformer config before spawning
maybe_register_config_serialize_by_value()
# The Process can raise an exception during startup, which may
# not actually result in an exitcode being reported. As a result
# we use a shared variable to communicate the information.
engine_alive = multiprocessing.Value('b', True, lock=False)
engine_process = context.Process(
target=run_mp_engine,
args=(vllm_config, UsageContext.OPENAI_API_SERVER, ipc_path,
engine_args.disable_log_stats,
engine_args.enable_log_requests, engine_alive))
engine_process.start()
engine_pid = engine_process.pid
assert engine_pid is not None, "Engine process failed to start."
logger.info("Started engine process with PID %d", engine_pid)
def _cleanup_ipc_path():
socket_path = ipc_path.replace("ipc://", "")
if os.path.exists(socket_path):
os.remove(socket_path)
# Ensure we clean up the local IPC socket file on exit.
atexit.register(_cleanup_ipc_path)
# Build RPCClient, which conforms to EngineClient Protocol.
build_client = partial(MQLLMEngineClient, ipc_path, vllm_config,
engine_pid)
mq_engine_client = await asyncio.get_running_loop().run_in_executor(
None, build_client)
try:
while True:
try:
await mq_engine_client.setup()
break
except TimeoutError:
if (not engine_process.is_alive()
or not engine_alive.value):
raise RuntimeError(
"Engine process failed to start. See stack "
"trace for the root cause.") from None
yield mq_engine_client # type: ignore[misc]
finally:
# Ensure rpc server process was terminated
engine_process.terminate()
# Close all open connections to the backend
mq_engine_client.close()
# Wait for engine process to join
engine_process.join(4)
if engine_process.exitcode is None:
# Kill if it has not stopped within the join timeout
engine_process.kill()
# Lazy import for prometheus multiprocessing.
# We need to set PROMETHEUS_MULTIPROC_DIR environment variable
# before prometheus_client is imported.
# See https://prometheus.github.io/client_python/multiprocess/
from prometheus_client import multiprocess
multiprocess.mark_process_dead(engine_process.pid)
async def validate_json_request(raw_request: Request):
content_type = raw_request.headers.get("content-type", "").lower()
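One detail of the removed block worth keeping in mind: a child process can fail during startup without ever reporting a useful exit code, which is why the code shared a lock-free multiprocessing.Value that run_mp_engine flipped to False on any startup exception. A self-contained sketch of that pattern (generic Python, not the vLLM code itself):

import multiprocessing

def run_engine(engine_alive) -> None:
    try:
        raise RuntimeError("simulated startup failure")
    except BaseException:
        # Mirror of run_mp_engine: flag the failure for the parent,
        # since the exception may never surface as an exit code.
        engine_alive.value = False

if __name__ == "__main__":
    context = multiprocessing.get_context("spawn")
    engine_alive = context.Value('b', True, lock=False)
    proc = context.Process(target=run_engine, args=(engine_alive,))
    proc.start()
    proc.join()
    if not engine_alive.value:
        print("Engine process failed to start.")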

View File

@@ -191,7 +191,7 @@ class RocmPlatform(Platform):
kv_cache_dtype, block_size, use_v1, use_mla,
has_sink) -> str:
if use_mla:
-from vllm.attention.backends.rocm_aiter_mla import (
+from vllm.v1.attention.backends.mla.rocm_aiter_mla import (
is_aiter_mla_enabled)
if selected_backend is None: