Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2026-04-01 09:47:03 +08:00)
[V0 Deprecation] Remove MQLLMEngine (#25019)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>

parent 58d4c705a8
commit 5801e49776
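
The test changes below swap mocked engines from the removed MQLLMEngineClient over to the V1 AsyncLLM. A minimal sketch of that pattern, reusing only names that appear in the diff (the helper function itself is hypothetical, not part of this commit):

from unittest.mock import MagicMock

from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.v1.engine.async_llm import AsyncLLM

MODEL_NAME = "openai-community/gpt2"


def make_mock_engine():
    # Hypothetical helper: mock the V1 AsyncLLM instead of the removed
    # MQLLMEngineClient, as the updated tests in this commit do.
    mock_engine = MagicMock(spec=AsyncLLM)
    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    mock_engine.errored = False
    return mock_engine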
@@ -46,7 +46,6 @@ steps:
  mirror_hardwares: [amdexperimental]
  source_file_dependencies:
    - vllm/
    - tests/mq_llm_engine
    - tests/async_engine
    - tests/test_inputs.py
    - tests/test_outputs.py
@@ -57,7 +56,6 @@ steps:
    - tests/transformers_utils
  commands:
    - python3 standalone_tests/lazy_imports.py
    - pytest -v -s mq_llm_engine # MQLLMEngine
    - pytest -v -s async_engine # AsyncLLMEngine
    - pytest -v -s test_inputs.py
    - pytest -v -s test_outputs.py
@@ -10,7 +10,6 @@ from unittest.mock import MagicMock
import pytest

from vllm.config.multimodal import MultiModalConfig
from vllm.engine.multiprocessing.client import MQLLMEngineClient
from vllm.entrypoints.openai.protocol import CompletionRequest, ErrorResponse
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
@@ -18,6 +17,7 @@ from vllm.entrypoints.openai.serving_models import (BaseModelPath,
from vllm.lora.request import LoRARequest
from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.v1.engine.async_llm import AsyncLLM

MODEL_NAME = "openai-community/gpt2"
BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]
@@ -82,7 +82,7 @@ def register_mock_resolver():
@pytest.fixture
def mock_serving_setup():
    """Provides a mocked engine and serving completion instance."""
    mock_engine = MagicMock(spec=MQLLMEngineClient)
    mock_engine = MagicMock(spec=AsyncLLM)
    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    mock_engine.errored = False
@@ -13,12 +13,12 @@ import pytest
import pytest_asyncio

from vllm.config.multimodal import MultiModalConfig
from vllm.engine.multiprocessing.client import MQLLMEngineClient
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
                                                    OpenAIServingModels)
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.v1.engine.async_llm import AsyncLLM

from ...utils import RemoteOpenAIServer

@@ -276,7 +276,7 @@ def test_async_serving_chat_init():

@pytest.mark.asyncio
async def test_serving_chat_returns_correct_model_name():
    mock_engine = MagicMock(spec=MQLLMEngineClient)
    mock_engine = MagicMock(spec=AsyncLLM)
    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    mock_engine.errored = False

@@ -312,7 +312,7 @@ async def test_serving_chat_returns_correct_model_name():

@pytest.mark.asyncio
async def test_serving_chat_should_set_correct_max_tokens():
    mock_engine = MagicMock(spec=MQLLMEngineClient)
    mock_engine = MagicMock(spec=AsyncLLM)
    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    mock_engine.errored = False

@@ -355,7 +355,7 @@ async def test_serving_chat_should_set_correct_max_tokens():
    }

    # Reinitialize the engine with new settings
    mock_engine = MagicMock(spec=MQLLMEngineClient)
    mock_engine = MagicMock(spec=AsyncLLM)
    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    mock_engine.errored = False

@@ -410,7 +410,7 @@ async def test_serving_chat_should_set_correct_max_tokens():
    }

    # Reinitialize the engine with new settings
    mock_engine = MagicMock(spec=MQLLMEngineClient)
    mock_engine = MagicMock(spec=AsyncLLM)
    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    mock_engine.errored = False

@@ -467,7 +467,7 @@ async def test_serving_chat_could_load_correct_generation_config():
        "repetition_penalty": 1.05
    }

    mock_engine = MagicMock(spec=MQLLMEngineClient)
    mock_engine = MagicMock(spec=AsyncLLM)
    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    mock_engine.errored = False

@@ -523,7 +523,7 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type):
    mock_model_config = MockModelConfig()
    mock_model_config.hf_config.model_type = model_type

    mock_engine = MagicMock(spec=MQLLMEngineClient)
    mock_engine = MagicMock(spec=AsyncLLM)
    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    mock_engine.errored = False
@@ -1,12 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest


@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch):
    """
    Since this module is V0 only, set VLLM_USE_V1=0 for
    all tests in the module.
    """
    monkeypatch.setenv('VLLM_USE_V1', '0')
@@ -1,69 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Test that aborting is handled properly."""

import asyncio
import tempfile
import uuid

import pytest

from tests.mq_llm_engine.utils import RemoteMQLLMEngine, generate
from vllm.engine.arg_utils import AsyncEngineArgs

MODEL = "google/gemma-1.1-2b-it"
ENGINE_ARGS = AsyncEngineArgs(model=MODEL)
RAISED_ERROR = KeyError
RAISED_VALUE = "foo"
EXPECTED_TOKENS = 250


@pytest.fixture(scope="function")
def tmp_socket():
    with tempfile.TemporaryDirectory() as td:
        yield f"ipc://{td}/{uuid.uuid4()}"


@pytest.mark.asyncio
async def test_abort(tmp_socket):
    with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
                           ipc_path=tmp_socket) as engine:

        client = await engine.make_client()

        request_id_to_be_aborted = "request-aborted"
        request_ids_a = [f"request-a-{idx}" for idx in range(10)]
        request_ids_b = [f"request-b-{idx}" for idx in range(10)]

        # Requests started before one to be aborted.
        tasks = []
        for request_id in request_ids_a:
            tasks.append(
                asyncio.create_task(
                    generate(client, request_id, EXPECTED_TOKENS)))

        # Aborted.
        task_aborted = asyncio.create_task(
            generate(client, request_id_to_be_aborted, EXPECTED_TOKENS))

        # Requests started after one to be aborted.
        for request_id in request_ids_b:
            tasks.append(
                asyncio.create_task(
                    generate(client, request_id, EXPECTED_TOKENS)))

        # Actually abort.
        await asyncio.sleep(0.5)
        await client.abort(request_id_to_be_aborted)

        # Confirm that we got all the EXPECTED tokens from the requests.
        for task in tasks:
            count, request_id = await task
            assert count == EXPECTED_TOKENS, (
                f"{request_id} generated only {count} tokens")

        # Cancel task (this will hang indefinitely if not).
        task_aborted.cancel()

        # Shutdown.
        client.close()
@@ -1,376 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Test that various errors are handled properly."""

import asyncio
import tempfile
import time
import uuid
from unittest.mock import Mock

import pytest

from tests.mq_llm_engine.utils import RemoteMQLLMEngine
from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.engine.multiprocessing import MQEngineDeadError
from vllm.engine.multiprocessing.engine import MQLLMEngine
from vllm.entrypoints.openai.api_server import build_async_engine_client
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.lora.request import LoRARequest
from vllm.sequence import SequenceGroupMetadata
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser

MODEL = "google/gemma-1.1-2b-it"
ENGINE_ARGS = AsyncEngineArgs(model=MODEL, enforce_eager=True)
RAISED_ERROR = KeyError
RAISED_VALUE = "foo"


@pytest.fixture(scope="function")
def tmp_socket():
    with tempfile.TemporaryDirectory() as td:
        yield f"ipc://{td}/{uuid.uuid4()}"


def run_with_evil_forward(engine_args: AsyncEngineArgs, ipc_path: str):
    # Make engine.
    engine = MQLLMEngine.from_engine_args(
        engine_args=engine_args,
        usage_context=UsageContext.UNKNOWN_CONTEXT,
        ipc_path=ipc_path)

    # Raise error during first forward pass.
    engine.engine.model_executor.execute_model = Mock(
        side_effect=RAISED_ERROR(RAISED_VALUE))

    # Run engine.
    engine.start()


@pytest.mark.asyncio
async def test_evil_forward(tmp_socket):
    with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
                           ipc_path=tmp_socket,
                           run_fn=run_with_evil_forward) as engine:

        client = await engine.make_client()

        # Server should be healthy after initial probe.
        await asyncio.sleep(2.0)
        await client.check_health()

        # Throws an error that should get ENGINE_DEAD_ERROR.
        with pytest.raises(MQEngineDeadError):
            async for _ in client.generate(prompt="Hello my name is",
                                           sampling_params=SamplingParams(),
                                           request_id=str(uuid.uuid4())):
                pass
        assert client.errored

        await asyncio.sleep(1.0)
        with pytest.raises(RAISED_ERROR):
            await client.check_health()
        assert client.errored

        # Shutdown.
        client.close()


def run_with_evil_model_executor_health(engine_args: AsyncEngineArgs,
                                        ipc_path: str):
    # Make engine.
    engine = MQLLMEngine.from_engine_args(
        engine_args=engine_args,
        usage_context=UsageContext.UNKNOWN_CONTEXT,
        ipc_path=ipc_path)

    # Raise error during first forward pass.
    engine.engine.model_executor.check_health = Mock(side_effect=RAISED_ERROR)

    # Run engine.
    engine.start()


@pytest.mark.asyncio
async def test_failed_health_check(tmp_socket):
    with RemoteMQLLMEngine(
            engine_args=ENGINE_ARGS,
            ipc_path=tmp_socket,
            run_fn=run_with_evil_model_executor_health) as engine:

        client = await engine.make_client()
        assert client.is_running

        # Health probe should throw RAISED_ERROR.
        await asyncio.sleep(15.)

        with pytest.raises(RAISED_ERROR):
            await client.check_health()
        assert client.errored

        # Generate call should throw ENGINE_DEAD_ERROR
        with pytest.raises(MQEngineDeadError):
            async for _ in client.generate(prompt="Hello my name is",
                                           sampling_params=SamplingParams(),
                                           request_id=str(uuid.uuid4())):
                pass

        client.close()


def run_with_evil_abort(engine_args: AsyncEngineArgs, ipc_path: str):
    # Make engine.
    engine = MQLLMEngine.from_engine_args(
        engine_args=engine_args,
        usage_context=UsageContext.UNKNOWN_CONTEXT,
        ipc_path=ipc_path)

    # Raise error during abort call.
    engine.engine.abort_request = Mock(side_effect=RAISED_ERROR)

    # Run engine.
    engine.start()


@pytest.mark.asyncio
async def test_failed_abort(tmp_socket):
    with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
                           ipc_path=tmp_socket,
                           run_fn=run_with_evil_abort) as engine:

        client = await engine.make_client()
        assert client.is_running

        # First check health should work.
        await client.check_health()

        # Trigger an abort on the client side.
        # This request ID does not exist, and will cause the engine to error
        await client.abort(request_id="foo")

        # Future generation requests will now fail
        # with reference to the original KeyError("foo")
        with pytest.raises(MQEngineDeadError) as execinfo:
            async for _ in client.generate(
                    prompt="Hello my name is",
                    sampling_params=SamplingParams(max_tokens=10),
                    request_id=str(uuid.uuid4())):
                pass
        assert "KeyError" in repr(execinfo.value)
        assert client.errored

        # This should raise the original error.
        with pytest.raises(RAISED_ERROR):
            await client.check_health()

        client.close()


@pytest.mark.asyncio
async def test_batch_error(tmp_socket):
    with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
                           ipc_path=tmp_socket,
                           run_fn=run_with_evil_abort) as engine:

        client = await engine.make_client()
        assert client.is_running

        # First check health should work.
        await client.check_health()

        # Batch of requests
        async def do_generate(client):
            # min_tokens=2048 to keep busy the engine busy
            # to get enough time to get process a request
            # that will crash the engine
            params = SamplingParams(min_tokens=2048, max_tokens=2048)
            async for _ in client.generate(prompt="Hello my name is",
                                           sampling_params=params,
                                           request_id=str(uuid.uuid4())):
                pass

        tasks = [asyncio.create_task(do_generate(client)) for _ in range(10)]

        # This request will force a processing batch to raise
        # an exception and next the engine get errored
        await client.abort(request_id="foo")

        # The batch of those request failed, then they
        # should get the same exception as a MQEngineDeadError.
        errors = await asyncio.gather(*tasks, return_exceptions=True)
        for e in errors:
            assert isinstance(e, MQEngineDeadError)
            assert "KeyError" in repr(e)

        client.close()


@pytest.mark.asyncio
async def test_bad_request(tmp_socket):
    with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
                           ipc_path=tmp_socket) as engine:

        client = await engine.make_client()

        # Invalid request should fail, but not crash the server.
        with pytest.raises(ValueError):
            async for _ in client.generate(prompt="Hello my name is",
                                           sampling_params=SamplingParams(),
                                           request_id="abcd-1",
                                           lora_request=LoRARequest(
                                               "invalid-lora", 1,
                                               "invalid-path")):
                pass

        # This request should be okay.
        async for _ in client.generate(prompt="Hello my name is",
                                       sampling_params=SamplingParams(),
                                       request_id="abcd-2"):
            pass

        # Shutdown.
        client.close()


@pytest.mark.asyncio
async def test_mp_crash_detection(monkeypatch: pytest.MonkeyPatch):
    with monkeypatch.context() as m:

        parser = FlexibleArgumentParser(
            description="vLLM's remote OpenAI server.")
        parser = make_arg_parser(parser)
        args = parser.parse_args([])

        # When LLMEngine is loaded, it will crash.
        def mock_init():
            raise ValueError

        m.setattr(LLMEngine, "__init__", mock_init)

        start = time.perf_counter()
        async with build_async_engine_client(args):
            pass
        end = time.perf_counter()

        assert end - start < 100, (
            "Expected vLLM to gracefully shutdown in <100s "
            "if there is an error in the startup.")


@pytest.mark.asyncio
async def test_mp_cuda_init():
    # it should not crash, when cuda is initialized
    # in the API server process
    import torch
    torch.cuda.init()
    parser = FlexibleArgumentParser(description="vLLM's remote OpenAI server.")
    parser = make_arg_parser(parser)
    args = parser.parse_args([])

    async with build_async_engine_client(args):
        pass


@pytest.mark.asyncio
async def test_engine_process_death(tmp_socket):
    with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
                           ipc_path=tmp_socket) as engine:

        client = await engine.make_client()
        assert client.is_running

        # kill the engine process
        engine.proc.kill()

        # Generate call should fail
        with pytest.raises(MQEngineDeadError):
            async for _ in client.generate(prompt="Hello my name is",
                                           sampling_params=SamplingParams(),
                                           request_id=str(uuid.uuid4())):
                pass

        # And the health check should show the engine is dead
        with pytest.raises(RuntimeError, match="Engine process .* died"):
            await client.check_health()

        client.close()


def run_with_evil_input_processing(engine_args: AsyncEngineArgs,
                                   ipc_path: str):
    """Simulate an exception while preparing inputs for the model.
    In the wild, this could be something like a multimodal input processor
    failing on invalid image data."""

    # Make engine.
    engine = MQLLMEngine.from_engine_args(
        engine_args=engine_args,
        usage_context=UsageContext.UNKNOWN_CONTEXT,
        ipc_path=ipc_path)

    runner = engine.engine.model_executor.driver_worker.worker.model_runner

    # Raise error in the model runner when adding a sequence group.
    # See class ModelInputForGPUBuilder
    def raiser(_, seq_group_metadata: SequenceGroupMetadata):
        if seq_group_metadata.request_id.startswith("evil"):
            raise RAISED_ERROR(RAISED_VALUE)

    runner.builder.per_seq_group_compute_fns.append(raiser)

    # Run engine.
    engine.start()


@pytest.mark.asyncio
async def test_failed_inputs(tmp_socket):
    with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
                           ipc_path=tmp_socket,
                           run_fn=run_with_evil_input_processing) as engine:

        client = await engine.make_client()
        assert client.is_running

        # Engine should be healthy
        await client.check_health()

        async def run_failing_request():
            async for _ in client.generate(
                    prompt="Hello my name is",
                    sampling_params=SamplingParams(max_tokens=10),
                    request_id="evil" + str(uuid.uuid4())):
                pass

        async def run_passing_request():
            async for _ in client.generate(
                    prompt="Hello my name is",
                    sampling_params=SamplingParams(max_tokens=10),
                    request_id=str(uuid.uuid4())):
                pass

        passing_tasks = [
            asyncio.create_task(run_passing_request()) for _ in range(10)
        ]
        failing_tasks = [
            asyncio.create_task(run_failing_request()) for _ in range(10)
        ]
        await asyncio.gather(*failing_tasks, return_exceptions=True)
        await asyncio.gather(*passing_tasks)

        # All the bad inputs should have raised
        for task in failing_tasks:
            with pytest.raises(RAISED_ERROR):
                task.result()

        # But all good inputs should have still succeeded
        for task in passing_tasks:
            task.result()

        # And the engine should remain healthy
        assert not client.errored
        await client.check_health()

        client.close()
@@ -1,59 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Test that the MQLLMEngine is able to handle 10k concurrent requests."""

import asyncio
import tempfile
import uuid

import pytest

from tests.mq_llm_engine.utils import RemoteMQLLMEngine, generate
from vllm.engine.arg_utils import AsyncEngineArgs

MODEL = "google/gemma-1.1-2b-it"
NUM_EXPECTED_TOKENS = 10
NUM_REQUESTS = 10000

# Scenarios to test for num generated token.
ENGINE_ARGS = AsyncEngineArgs(model=MODEL)


@pytest.fixture(scope="function")
def tmp_socket():
    with tempfile.TemporaryDirectory() as td:
        yield f"ipc://{td}/{uuid.uuid4()}"


@pytest.mark.asyncio
async def test_load(tmp_socket):
    with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
                           ipc_path=tmp_socket) as engine:

        client = await engine.make_client()

        request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]

        # Create concurrent requests.
        tasks = []
        for request_id in request_ids:
            tasks.append(
                asyncio.create_task(
                    generate(client, request_id, NUM_EXPECTED_TOKENS)))

        # Confirm that we got all the EXPECTED tokens from the requests.
        failed_request_id = None
        tokens = None
        for task in tasks:
            num_generated_tokens, request_id = await task
            if (num_generated_tokens != NUM_EXPECTED_TOKENS
                    and failed_request_id is None):
                failed_request_id = request_id
                tokens = num_generated_tokens

        assert failed_request_id is None, (
            f"{failed_request_id} generated {tokens} but "
            f"expected {NUM_EXPECTED_TOKENS}")

        # Shutdown.
        client.close()
@@ -1,81 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import asyncio
import multiprocessing
from typing import Callable, Union

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.multiprocessing.client import MQLLMEngineClient
from vllm.engine.multiprocessing.engine import MQLLMEngine
from vllm.outputs import RequestOutput
from vllm.usage.usage_lib import UsageContext


async def generate(
        client: MQLLMEngineClient,
        request_id: str,
        num_tokens: int,
        return_output: bool = False) -> Union[RequestOutput, tuple[int, str]]:

    final_output = None
    count = 0
    async for out in client.generate(
            request_id=request_id,
            prompt="Hello my name is Robert and",
            sampling_params=SamplingParams(max_tokens=num_tokens,
                                           temperature=0)):

        count += 1
        final_output = out
        await asyncio.sleep(0.)

    if return_output:
        return final_output

    # Confirm we generated all the tokens we expected.
    return count, request_id


def run_normal(engine_args: AsyncEngineArgs, ipc_path: str):
    # Make engine.
    engine = MQLLMEngine.from_engine_args(
        engine_args=engine_args,
        usage_context=UsageContext.UNKNOWN_CONTEXT,
        ipc_path=ipc_path)

    # Run engine.
    engine.start()


class RemoteMQLLMEngine:

    def __init__(self,
                 engine_args: AsyncEngineArgs,
                 ipc_path: str,
                 run_fn: Callable = run_normal) -> None:

        self.engine_args = engine_args
        self.ipc_path = ipc_path
        context = multiprocessing.get_context("spawn")
        self.proc = context.Process(target=run_fn,
                                    args=(engine_args, ipc_path))
        self.proc.start()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.proc.kill()

    async def make_client(self) -> MQLLMEngineClient:
        engine_config = self.engine_args.create_engine_config()
        client = MQLLMEngineClient(self.ipc_path, engine_config, self.proc.pid)
        while True:
            try:
                await client.setup()
                break
            except TimeoutError:
                assert self.proc.is_alive()
        return client
@@ -1,145 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import uuid
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Mapping, Optional, Union

from vllm import PoolingParams
from vllm.inputs import PromptType
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.utils import Device

VLLM_RPC_SUCCESS_STR = "SUCCESS"

IPC_INPUT_EXT = "_input_socket"
IPC_OUTPUT_EXT = "_output_socket"
IPC_HEALTH_EXT = "_health_socket"
IPC_DATA_EXT = "_data_socket"


class MQEngineDeadError(RuntimeError):
    pass


@dataclass
class RPCProcessRequest:
    prompt: PromptType
    params: Union[SamplingParams, PoolingParams]
    request_id: str
    lora_request: Optional[LoRARequest] = None
    trace_headers: Optional[Mapping[str, str]] = None
    priority: int = 0

    def __init__(
        self,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
    ) -> None:
        super().__init__()

        self.prompt = prompt
        self.params = params
        self.request_id = request_id
        self.lora_request = lora_request
        self.trace_headers = trace_headers
        self.priority = priority


@dataclass
class RPCError:
    request_id: Optional[str]
    is_engine_errored: bool
    exception: BaseException


@dataclass
class RPCAbortRequest:
    request_id: str


class RPCStartupRequest(Enum):
    IS_SERVER_READY = 1


@dataclass
class RPCStartupResponse:
    tracing_enabled: bool


class RPCUProfileRequest(Enum):
    START_PROFILE = 1
    STOP_PROFILE = 2


class RPCResetMultiModalCacheRequest(Enum):
    RESET = 1


@dataclass
class RPCResetPrefixCacheRequest:
    device: Device


class RPCSleepRequest(Enum):
    SLEEP_LEVEL_1 = 1
    SLEEP_LEVEL_2 = 2


@dataclass
class RPCWakeUpRequest:
    tags: Optional[list[str]] = None


@dataclass
class RPCIsSleepingRequest:
    # Set the default value of request_id to a new UUID
    request_id: str = field(default_factory=lambda: str(uuid.uuid4()))


@dataclass
class RPCIsSleepingResponse:
    request_id: str
    is_sleeping: bool


@dataclass
class RPCLoadAdapterRequest:
    lora_request: LoRARequest
    # Set the default value of request_id to a new UUID
    request_id: str = field(default_factory=lambda: str(uuid.uuid4()))


@dataclass
class RPCAdapterLoadedResponse:
    request_id: str
    lora_loaded: bool


RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest,
                      RPCUProfileRequest, RPCLoadAdapterRequest,
                      RPCResetMultiModalCacheRequest,
                      RPCResetPrefixCacheRequest, RPCSleepRequest,
                      RPCWakeUpRequest, RPCIsSleepingRequest]

REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCAdapterLoadedResponse,
                          RPCIsSleepingResponse, RPCError]


def ENGINE_DEAD_ERROR(
        error: Optional[BaseException] = None) -> MQEngineDeadError:
    if error is None:
        return MQEngineDeadError(
            "Engine loop is not running. Inspect the stacktrace to "
            "find the original error")

    return MQEngineDeadError(
        "Engine loop is not running. Inspect the stacktrace to "
        f"find the original error: {repr(error)}.")
@@ -1,643 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import asyncio
import copy
import pickle
from contextlib import contextmanager, suppress
from typing import (Any, AsyncGenerator, Dict, Iterable, Iterator, List,
                    Mapping, Optional, Union)

import cloudpickle
import psutil
import zmq
import zmq.asyncio
from zmq import Frame  # type: ignore[attr-defined]
from zmq.asyncio import Socket

from vllm import PoolingParams
from vllm.config import DecodingConfig, ModelConfig, VllmConfig
from vllm.core.scheduler import SchedulerOutputs
# yapf conflicts with isort for this block
# yapf: disable
from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
                                         IPC_HEALTH_EXT, IPC_INPUT_EXT,
                                         IPC_OUTPUT_EXT, RPC_REQUEST_T,
                                         VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
                                         RPCAdapterLoadedResponse, RPCError,
                                         RPCIsSleepingRequest,
                                         RPCIsSleepingResponse,
                                         RPCLoadAdapterRequest,
                                         RPCProcessRequest,
                                         RPCResetMultiModalCacheRequest,
                                         RPCResetPrefixCacheRequest,
                                         RPCSleepRequest, RPCStartupRequest,
                                         RPCStartupResponse,
                                         RPCUProfileRequest, RPCWakeUpRequest)
from vllm.engine.protocol import EngineClient
# yapf: enable
from vllm.envs import VLLM_RPC_TIMEOUT
from vllm.inputs import PromptType
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.utils import Device

logger = init_logger(__name__)


class MQClientClosedError(Exception):
    """Exception class raised when the client is used post-close.

    The client can be closed, which closes the ZMQ context. This normally
    happens on server shutdown. In some cases, methods like abort and
    do_log_stats will still be called and then try to open a socket, which
    causes a ZMQError and creates a huge stack trace.
    So, we throw this error such that we can suppress it.
    """


class MQLLMEngineClient(EngineClient):
    """A client wrapper for MQLLMEngine that conforms to the
    EngineClient protocol.

    MQLLMEngine and MQLLMEngineClient are intended to run in separate
    processes communicating via zeromq ipc sockets.

    The entrypoint to MQLLMEngineClient is through the generate()
    method. On generate() MQLLMEngine does three things:
        - Creates an asyncio output queue
        - Sends a RPCGenerateRequest to the MQLLMEngine via zmq
        - Pulls RequestOutputs from its queue and yields them

    MQLLMEngine runs two background loops:
        - output_loop: the output loop pulls List[RequestOutput]
            from the MQLLMEngine via zmq (each list is the output
            of one engine_step in the LLMEngine). It then parses
            the list and pushes individual request_outputs into
            the corresponding output_queue such that they can be
            consumed by the .generate() method.
        - health_loop: the health loop queries the health socket
            every N seconds, confirming the engine is healthy
    """

    def __init__(self, ipc_path: str, engine_config: VllmConfig,
                 engine_pid: int):
        self.context = zmq.asyncio.Context()
        self._errored_with: Optional[BaseException] = None

        # Get the configs.
        self.vllm_config = engine_config
        self.model_config = engine_config.model_config
        self.decoding_config = engine_config.decoding_config

        if self.vllm_config.model_config.skip_tokenizer_init:
            self.tokenizer = None

        else:
            # Create the tokenizer group.
            self.tokenizer = init_tokenizer_from_configs(
                model_config=self.model_config,
                scheduler_config=engine_config.scheduler_config,
                lora_config=engine_config.lora_config)

        self.input_preprocessor = InputPreprocessor(self.model_config,
                                                    self.tokenizer)

        # Send RPCGenerateRequest to the MQLLMEngine.
        self.input_socket: Socket = self.context.socket(zmq.constants.PUSH)
        self.input_socket.connect(f"{ipc_path}{IPC_INPUT_EXT}")

        # Receive streams of RequestOutput from the MQLLMEngine.
        self.output_socket: Socket = self.context.socket(zmq.constants.PULL)
        self.output_socket.connect(f"{ipc_path}{IPC_OUTPUT_EXT}")

        # IPC path for acking heartbeats.
        self.heartbeat_socket: Socket = self.context.socket(zmq.constants.PULL)
        self.heartbeat_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}")

        # IPC path for the data socket.
        self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"

        # Stream for each individual request.
        self.output_queues: Dict[str, asyncio.Queue] = {}

        # Loop to handle output of the LLMEngine periodically.
        # Started after the MQLLMEngine is ready so that we can
        # build the Client in an executor to enable clean shutdown.
        self.output_loop: Optional[asyncio.Task] = None

        # Loop to check health of the LLMEngine periodically.
        # Started after the MQLLMEngine is ready.
        self.health_loop: Optional[asyncio.Task] = None
        self._engine_process = psutil.Process(engine_pid)

    @staticmethod
    def is_unsupported_config(vllm_config: VllmConfig):
        # Pipeline parallel not yet supported
        return vllm_config.parallel_config.pipeline_parallel_size > 1

    @contextmanager
    def get_data_socket(self) -> Iterator[Socket]:
        socket = self.context.socket(zmq.constants.DEALER)
        try:
            socket.connect(self.data_ipc_path)
            yield socket
        finally:
            socket.close(linger=0)

    async def run_heartbeat_loop(self, timeout: int):
        """Background loop that continually checks to ensure the engine process
        is still alive.
        """
        try:
            while True:
                # Check if the engine process is running:
                if not self._engine_process.is_running() or (
                        self._engine_process.status() == psutil.STATUS_ZOMBIE):
                    # NB: is_running() returns True for zombies
                    self._set_errored(
                        RuntimeError(
                            f"Engine process (pid {self._engine_process.pid}) "
                            "died."))
                    break

                if await self.heartbeat_socket.poll(timeout=timeout):
                    # Heartbeat received- check the message
                    await self._check_success(
                        error_message="Heartbeat failed.",
                        socket=self.heartbeat_socket)

                logger.debug("Heartbeat successful.")

        except asyncio.CancelledError:
            logger.debug("Shutting down MQLLMEngineClient check health loop.")

        except psutil.NoSuchProcess:
            self._set_errored(
                RuntimeError(
                    f"Engine process (pid {self._engine_process.pid}) died."))

        except Exception as e:
            self._set_errored(e)

    async def run_output_handler_loop(self):
        """Get RequestOutputs from Engine and stream to Request Queues"""

        try:
            while True:
                # Poll, checking for ENGINE_DEAD
                while await self.output_socket.poll(timeout=VLLM_RPC_TIMEOUT
                                                    ) == 0:
                    logger.debug("Waiting for output from MQLLMEngine.")

                    # If errored, alert all running requests.
                    if self.errored:
                        for queue_j in tuple(self.output_queues.values()):
                            queue_j.put_nowait(
                                ENGINE_DEAD_ERROR(self._errored_with))
                        return

                message: Frame = await self.output_socket.recv(copy=False)
                request_outputs = pickle.loads(message.buffer)

                is_error = isinstance(request_outputs,
                                      (BaseException, RPCError))
                if is_error:
                    if isinstance(request_outputs, RPCError):
                        rpc_error: RPCError = request_outputs
                        request_id = rpc_error.request_id
                        exception = rpc_error.exception
                        is_engine_errored = rpc_error.is_engine_errored
                    else:
                        # MPLLMEngine should always return an RPCError to
                        # the output_socket when an issue arises.
                        # If we are here, we are in a bad state and
                        # should shut down the server.
                        error: BaseException = request_outputs
                        logger.error(
                            "Received Exception %s rather than RPCError from "
                            "MPLLMEngine. This should never happen.", error)
                        request_id = None
                        exception = error
                        is_engine_errored = True

                    # Set to error state only on engine critical error
                    # (and record only the first one)
                    if is_engine_errored and not self._errored_with:
                        self._errored_with = exception
                        # If engine is errored, no matter the type of exception
                        # it will no longer be able to receive new requests,
                        # therefore we have to inform that the current
                        # processed requests failed as well. Send back a dead
                        # engine error give this feedback and also give a
                        # 'hint' to the server to shut down next.
                        exception = self.dead_error

                    if request_id is None:
                        # If request_id is None, then the engine raised an
                        # exception for a batch, and we may not know the
                        # request that caused it, neither if it was actually
                        # caused by any of them (e.g. CUDA OOM). Therefore we
                        # broadcast the same exception for all requests.
                        for queue_i in tuple(self.output_queues.values()):
                            queue_i.put_nowait(exception)
                    else:
                        queue = self.output_queues.get(request_id)
                        if queue is not None:
                            queue.put_nowait(exception)
                # Put each output into the appropriate queue.
                elif isinstance(
                        request_outputs,
                        (RPCAdapterLoadedResponse, RPCIsSleepingResponse)):
                    self._add_output(request_outputs)
                else:
                    for request_output in request_outputs:
                        self._add_output(request_output)

        except asyncio.CancelledError:
            logger.debug("Shutting down MQLLMEngineClient output handler.")

    def _add_output(self, request_output: Union[RequestOutput,
                                                RPCAdapterLoadedResponse,
                                                RPCIsSleepingResponse]):
        queue = self.output_queues.get(request_output.request_id)
        if queue is not None:
            queue.put_nowait(request_output)

    async def setup(self):
        """Set up the client before it starts sending server requests."""

        # Start output_loop
        if self.output_loop is None:
            # only generate once to avoid multiple concurrent output_loops
            # this will lead to race conditions and wrong orders of tokens
            # returned by the engine
            # setup will be called multiple times during the startup of
            # the engine
            self.output_loop = asyncio.create_task(
                self.run_output_handler_loop())

        with self.get_data_socket() as socket:
            # Wait until server is ready.
            response = await self._wait_for_server_rpc(socket)

            self.tracing_flag = response.tracing_enabled

            # Start health_loop.
            if self.health_loop is None:
                self.health_loop = asyncio.create_task(
                    self.run_heartbeat_loop(timeout=VLLM_RPC_TIMEOUT))

    def close(self):
        """Destroy the ZeroMQ Context."""
        # Close all sockets and terminate the context.
        self.context.destroy(linger=0)

        # Cancel background tasks.
        if self.health_loop is not None:
            self.health_loop.cancel()
        if self.output_loop is not None:
            self.output_loop.cancel()

    def _set_errored(self, e: BaseException):
        logger.exception(repr(e))
        if self._errored_with is None:
            self._errored_with = e

    @staticmethod
    async def _send_get_data_rpc_request(request: RPCStartupRequest,
                                         expected_type: Any,
                                         error_message: str,
                                         socket: Socket) -> Any:
        """Send an RPC request that is expecting data back."""

        # Ping RPCServer with a request.
        await socket.send_multipart((pickle.dumps(request), ), copy=False)

        # Make sure the server responds in time.
        if await socket.poll(timeout=VLLM_RPC_TIMEOUT) == 0:
            raise TimeoutError("RPCServer didn't reply within "
                               f"{VLLM_RPC_TIMEOUT} ms")

        # Await the data from the Server.
        frame = await socket.recv(copy=False)
        data = pickle.loads(frame.buffer)

        if isinstance(data, BaseException):
            raise data
        elif not isinstance(data, expected_type):
            raise ValueError(error_message)

        return data

    @staticmethod
    async def _send_one_way_rpc_request(request: RPC_REQUEST_T,
                                        socket: Socket):
        """Send one-way RPC request to trigger an action."""

        if socket.closed:
            raise MQClientClosedError()

        await socket.send_multipart((pickle.dumps(request), ))

    async def _await_ack(self, error_message: str, socket: Socket):
        """Await acknowledgement that a request succeeded."""

        if socket.closed:
            raise MQClientClosedError()

        if await socket.poll(timeout=VLLM_RPC_TIMEOUT) == 0:
            raise TimeoutError("MQLLMEngine didn't reply within "
                               f"{VLLM_RPC_TIMEOUT}ms")

        await self._check_success(error_message, socket)

    @staticmethod
    async def _check_success(error_message: str, socket: Socket):
        """Confirm that socket has a VLLM_RPC_SUCCESS_STR message"""

        if socket.closed:
            raise MQClientClosedError()

        frame = await socket.recv(copy=False)
        response = pickle.loads(frame.buffer)

        # Raise error if unsuccessful
        if isinstance(response, BaseException):
            raise response
        elif (not isinstance(response, str)
              or response != VLLM_RPC_SUCCESS_STR):
            raise ValueError(error_message)

    async def get_input_preprocessor(self) -> InputPreprocessor:
        return self.input_preprocessor

    async def get_tokenizer(self, lora_request: Optional[LoRARequest] = None):
        if self.tokenizer is None:
            return None
        else:
            return await self.tokenizer.get_lora_tokenizer_async(lora_request)

    async def get_vllm_config(self) -> VllmConfig:
        return self.vllm_config

    async def get_decoding_config(self) -> DecodingConfig:
        return self.decoding_config

    async def get_model_config(self) -> ModelConfig:
        return self.model_config

    async def is_tracing_enabled(self) -> bool:
        return self.tracing_flag

    async def _wait_for_server_rpc(self, socket: Socket) -> RPCStartupResponse:
        """Wait for the RPCServer to start up."""

        return await self._send_get_data_rpc_request(
            request=RPCStartupRequest.IS_SERVER_READY,
            expected_type=RPCStartupResponse,
            error_message="Unable to start RPC Server",
            socket=socket)

    async def abort(self, request_id: Union[str, Iterable[str]]):
        """Send an ABORT_REQUEST signal to the RPC Server"""

        if not isinstance(request_id, str):
            raise RuntimeError("Only single-request abort supported in"
                               " deprecated V0")

        with suppress(MQClientClosedError):
            await self._send_one_way_rpc_request(
                request=RPCAbortRequest(request_id), socket=self.input_socket)

    async def do_log_stats(
        self,
        scheduler_outputs: Optional[SchedulerOutputs] = None,
        model_output: Optional[List[SamplerOutput]] = None,
    ) -> None:
        """
        Ignore do_log_stats (handled on MQLLMEngine polling)
        """
        pass

    async def check_health(self):
        """
        The check health loop probes the health status of the
        Engine's health every N seconds and sets _errored_with
        if the engine is unhealthy.
        """
        if self._errored_with is not None:
            raise self._errored_with

    @property
    def is_running(self) -> bool:
        return not self.errored

    @property
    def is_stopped(self) -> bool:
        return self.errored

    @property
    def errored(self) -> bool:
        return self._errored_with is not None

    @property
    def dead_error(self) -> BaseException:
        return ENGINE_DEAD_ERROR(self._errored_with)

    def generate(
        self,
        prompt: PromptType,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Generate outputs for a request.

        Generate outputs for a request. This method is a coroutine. It adds the
        request into the waiting queue of the LLMEngine and streams the outputs
        from the LLMEngine to the caller.

        Args:
            prompt: The prompt to the LLM. See
                [`PromptType`][vllm.inputs.PromptType] for more details about
                the format of each input.
            sampling_params: The sampling parameters of the request.
            request_id: The unique id of the request.
            lora_request: LoRA request to use for generation, if any.
            trace_headers: OpenTelemetry trace headers.
            priority: Priority of the request (lower means earlier handling).
                Any priority other than 0 will lead to an error if the
                scheduling policy is not "priority".
        """
        return self._process_request(prompt, sampling_params, request_id,
                                     lora_request, trace_headers, priority)

    def encode(
        self,
        prompt: PromptType,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> AsyncGenerator[PoolingRequestOutput, None]:
        raise NotImplementedError(
            "Pooling models are not supported in vLLM V0")

    async def _process_request(
        self,
        prompt: PromptType,
        params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Send an RPCGenerateRequest to the RPCServer and stream responses."""

        # If already dead, error out.
        if self._errored_with is not None:
            raise ENGINE_DEAD_ERROR(self._errored_with)

        # Ensure the request id is unique among running requests
        if request_id in self.output_queues:
            raise ValueError(f"Request {request_id} already exists")

        # 1) Create output queue for this request.
        queue: asyncio.Queue[Union[RequestOutput,
                                   BaseException]] = asyncio.Queue()
        self.output_queues[request_id] = queue

        try:
            # 2) Detach logits processors so that they can be pickled
            # separately (may require cloudpickle which is slower)
            if params.logits_processors:
                # Defensive shallow copy
                params = copy.copy(params)
                logits_processors = params.logits_processors
                params.logits_processors = None
                lp_bytes = cloudpickle.dumps(logits_processors)
            else:
                lp_bytes = None

            request_bytes = pickle.dumps(
                RPCProcessRequest(
                    prompt=prompt,
                    params=params,
                    request_id=request_id,
                    lora_request=lora_request,
                    trace_headers=trace_headers,
                    priority=priority,
                ))

            # 3) Send the RPCGenerateRequest to the MQLLMEngine.
            parts = (request_bytes,
                     lp_bytes) if lp_bytes else (request_bytes, )
            await self.input_socket.send_multipart(parts, copy=False)

            # 4) Stream the RequestOutputs from the output queue. Note
            # that the output_loop pushes RequestOutput objects to this
            # queue after pulling them from the zmq socket.
            finished = False
            try:
                while not finished:
                    request_output = await queue.get()

                    if isinstance(request_output, BaseException):
                        raise request_output

                    finished = request_output.finished
                    yield request_output
            finally:
                # Request was canceled by the client.
                if not finished and not self.errored:
                    await self.abort(request_id)
        finally:
            self.output_queues.pop(request_id)

    async def start_profile(self) -> None:
        """Start profiling the engine"""

        await self._send_one_way_rpc_request(
            request=RPCUProfileRequest.START_PROFILE, socket=self.input_socket)

    async def stop_profile(self) -> None:
        """Stop profiling the engine"""

        await self._send_one_way_rpc_request(
            request=RPCUProfileRequest.STOP_PROFILE, socket=self.input_socket)

    async def reset_mm_cache(self) -> None:
        """Reset the multi-modal cache"""

        await self._send_one_way_rpc_request(
            request=RPCResetMultiModalCacheRequest.RESET,
            socket=self.input_socket)

    async def reset_prefix_cache(self,
                                 device: Optional[Device] = None) -> None:
        """Reset the prefix cache"""

        await self._send_one_way_rpc_request(
            request=RPCResetPrefixCacheRequest(device),
            socket=self.input_socket)

    async def sleep(self, level: int = 1) -> None:
        """Sleep the engine for a given level"""
        return await self._send_one_way_rpc_request(
            request=RPCSleepRequest(level), socket=self.input_socket)

    async def wake_up(self, tags: Optional[list[str]] = None) -> None:
        """Wake up the engine"""
        return await self._send_one_way_rpc_request(
            request=RPCWakeUpRequest(tags), socket=self.input_socket)

    async def is_sleeping(self) -> bool:
        """Check whether the engine is sleeping"""
        request = RPCIsSleepingRequest()

        queue: asyncio.Queue[Union[BaseException,
                                   RPCIsSleepingResponse]] = asyncio.Queue()
        self.output_queues[request.request_id] = queue

        request_bytes = pickle.dumps(request)
        await self.input_socket.send_multipart((request_bytes, ), copy=False)

        request_output = await queue.get()
        self.output_queues.pop(request.request_id)

        if isinstance(request_output, BaseException):
            raise request_output
        return request_output.is_sleeping

    async def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a new LoRA adapter into the engine for future requests."""
        # Uses the same I/O as generate requests
        request = RPCLoadAdapterRequest(lora_request)

        # Create output queue for this request.
        queue: asyncio.Queue[Union[
            BaseException, RPCAdapterLoadedResponse]] = asyncio.Queue()
        self.output_queues[request.request_id] = queue

        # Send the request
        request_bytes = pickle.dumps(request)
        await self.input_socket.send_multipart((request_bytes, ), copy=False)

        # Wait for the response
        request_output = await queue.get()
        self.output_queues.pop(request.request_id)

        # Raise on error, otherwise happily return None
        if isinstance(request_output, BaseException):
            raise request_output
        return request_output.lora_loaded
@ -1,470 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pickle
|
||||
import signal
|
||||
from contextlib import contextmanager
|
||||
from typing import Iterator, List, Optional, Union
|
||||
|
||||
import cloudpickle
|
||||
import zmq
|
||||
|
||||
from vllm import AsyncEngineArgs, SamplingParams
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.engine.llm_engine import LLMEngine
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
|
||||
IPC_HEALTH_EXT, IPC_INPUT_EXT,
|
||||
IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T,
|
||||
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
|
||||
RPCAdapterLoadedResponse, RPCError,
|
||||
RPCIsSleepingRequest,
|
||||
RPCIsSleepingResponse,
|
||||
RPCLoadAdapterRequest,
|
||||
RPCProcessRequest,
|
||||
RPCResetMultiModalCacheRequest,
|
||||
RPCResetPrefixCacheRequest,
|
||||
RPCSleepRequest, RPCStartupRequest,
|
||||
RPCStartupResponse,
|
||||
RPCUProfileRequest, RPCWakeUpRequest)
|
||||
# yapf: enable
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.transformers_utils.config import (
|
||||
maybe_register_config_serialize_by_value)
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import deprecate_kwargs
|
||||
from vllm.worker.model_runner_base import InputProcessingError
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
POLLING_TIMEOUT_MS = 10000
|
||||
HEALTHY_RESPONSE = (pickle.dumps(VLLM_RPC_SUCCESS_STR), )
|
||||
|
||||
|
||||
class MQLLMEngine:
|
||||
"""A multiprocessing wrapper for
|
||||
[`LLMEngine`][vllm.engine.llm_engine.LLMEngine].
|
||||
|
||||
This class is used to wrap the
|
||||
[`LLMEngine`][vllm.engine.llm_engine.LLMEngine] class to enable use
|
||||
in concurrent manner. It runs a background loop and uses zeromq to
|
||||
receive new requests and stream outputs incrementally via ipc.
|
||||
|
||||
The [`LLMEngine`][vllm.engine.llm_engine.LLMEngine] generate or encode
|
||||
process is kicked off when a new RPCProcessRequest is received by the
|
||||
input_socket.
|
||||
|
||||
The self.engine_loop checks the input_socket for new requests,
|
||||
adds them to the LLMEngine if there are any, calls the internal
|
||||
[`LLMEngine.step()`][vllm.engine.llm_engine.LLMEngine.step], and sends
|
||||
the RequestOutputs back over the output_socket.
|
||||
|
||||
If use_async_sockets is set, the logic associated with reading new
|
||||
requests from the socket and sending data to the socket is passed
|
||||
as a callback to the llm_engine, which calls the logic asynchronously
|
||||
such that the IPC can be overlapped with the GPU.
|
||||
|
||||
Args:
|
||||
ipc_path: Base path for zeromq interprocess messaging
|
||||
use_async_sockets: Whether to make send/recv async with GPU
|
||||
log_requests: Whether to log the requests.
|
||||
*args: Arguments for [`LLMEngine`][vllm.engine.llm_engine.LLMEngine].
|
||||
**kwargs: Arguments for [`LLMEngine`][vllm.engine.llm_engine.LLMEngine].
|
||||
"""
|
||||
|
||||
    def __init__(self,
                 ipc_path: str,
                 use_async_sockets: bool,
                 *args,
                 log_requests: bool = True,
                 **kwargs) -> None:
        # For MQLLMEngine, we can use cached outputs, since each new request
        # output is immediately pickled and sent over the socket, which frees
        # the Python object to be reused.
        kwargs['use_cached_outputs'] = True

        self.engine = LLMEngine(*args, **kwargs)
        self.log_requests = log_requests

        self.use_async_sockets = use_async_sockets
        if self.use_async_sockets:
            self.engine.process_request_outputs_callback = \
                self._async_socket_engine_callback

        self.ctx = zmq.Context()  # type: ignore[attr-defined]

        # Receive input from the client.
        self.input_socket = self.ctx.socket(zmq.constants.PULL)
        self.input_socket.bind(f"{ipc_path}{IPC_INPUT_EXT}")

        # Send output stream back to client.
        self.output_socket = self.ctx.socket(zmq.constants.PUSH)
        self.output_socket.bind(f"{ipc_path}{IPC_OUTPUT_EXT}")

        # Send heartbeats back to client.
        self.heartbeat_socket = self.ctx.socket(zmq.constants.PUSH)
        self.heartbeat_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}")

        # IPC path for the data socket.
        self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"

        # Error state.
        self._errored_with: Optional[BaseException] = None

    @property
    def dead_error(self) -> BaseException:
        if self._errored_with is not None:
            return ENGINE_DEAD_ERROR(self._errored_with)
        else:
            return ENGINE_DEAD_ERROR()

    @classmethod
    @deprecate_kwargs(
        "disable_log_requests",
        additional_message=("This argument will have no effect. "
                            "Use `enable_log_requests` instead."),
    )
    def from_vllm_config(
            cls,
            vllm_config: VllmConfig,
            usage_context: UsageContext,
            enable_log_requests: bool,
            disable_log_stats: bool,
            ipc_path: str,
            disable_log_requests: bool = True,  # Deprecated, will be removed
    ) -> "MQLLMEngine":
        # Setup plugins for each process
        from vllm.plugins import load_general_plugins
        load_general_plugins()

        use_async_sockets = vllm_config.model_config.use_async_output_proc

        return cls(
            vllm_config=vllm_config,
            executor_class=LLMEngine._get_executor_cls(vllm_config),
            ipc_path=ipc_path,
            usage_context=usage_context,
            use_async_sockets=use_async_sockets,
            log_requests=enable_log_requests,
            log_stats=(not disable_log_stats),
        )

    @staticmethod
    def from_engine_args(engine_args: AsyncEngineArgs,
                         usage_context: UsageContext, ipc_path: str):
        """Creates an MQLLMEngine from the engine arguments."""

        vllm_config = engine_args.create_engine_config(usage_context)
        return MQLLMEngine.from_vllm_config(
            ipc_path=ipc_path,
            vllm_config=vllm_config,
            usage_context=usage_context,
            enable_log_requests=engine_args.enable_log_requests,
            disable_log_stats=engine_args.disable_log_stats,
        )

    def start(self):
        try:
            try:
                logger.debug("Starting Startup Loop.")
                self.run_startup_loop()
                logger.debug("Starting Engine Loop.")
                self.run_engine_loop()
            except Exception as e:
                logger.exception(repr(e))
        except KeyboardInterrupt:
            logger.debug("Shutting down MQLLMEngine.")
        finally:
            logger.debug("MQLLMEngine is shut down.")
            self.cleanup()

    def cleanup(self):
        """Cleanup zeromq state on shutdown."""
        # Closes all sockets and destroys context.
        self.ctx.destroy(linger=0)
        del self.engine

    @contextmanager
    def make_data_socket(
            self) -> Iterator[zmq.Socket]:  # type: ignore[name-defined]
        socket = self.ctx.socket(zmq.constants.ROUTER)
        try:
            socket.bind(self.data_ipc_path)
            yield socket
        finally:
            socket.close(linger=0)

    def run_startup_loop(self) -> None:
        """Startup loop for sending data from Engine -> Client."""

        with self.make_data_socket() as socket:
            response: Union[RPCStartupResponse, BaseException]
            try:
                identity, message = socket.recv_multipart(copy=False)
                request: RPCStartupRequest = pickle.loads(message.buffer)

                # Handle the query from the Client.
                if request == RPCStartupRequest.IS_SERVER_READY:
                    tracing_enabled = self.engine.is_tracing_enabled()
                    response = RPCStartupResponse(
                        tracing_enabled=tracing_enabled)

            except Exception as e:
                response = e

            socket.send_multipart((identity, pickle.dumps(response)),
                                  copy=False)

    def run_engine_loop(self):
        """Core busy loop of the LLMEngine."""

        while True:
            if not self.engine.has_unfinished_requests():
                # Poll until there is work to do.
                while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
                    # When there's no work, check on engine health and send
                    # health status back to client
                    self._health_check()
                    self.engine.do_log_stats()
                    logger.debug("Waiting for new requests in engine loop.")

            # Handle any input from the client.
            self.handle_new_input()

            # Engine step.
            request_outputs = self.engine_step()

            # Send request outputs (if async, done in engine_step callback).
            if not self.use_async_sockets:
                self._send_outputs(request_outputs)

    def engine_step(self) -> List[RequestOutput]:
        """Engine step wrapper with error handling."""
        try:
            return self.engine.step()
        except SystemExit:
            raise
        except InputProcessingError as e:
            # Special case where we handle an error preparing the inputs for
            # a single request in the batch
            rpc_err = RPCError(request_id=e.request_id,
                               is_engine_errored=False,
                               exception=e.__cause__)
            self._send_outputs(rpc_err)
            return []
        except BaseException as e:
            self._set_errored(e)
            rpc_err = RPCError(request_id=None,
                               is_engine_errored=True,
                               exception=e)
            self._send_outputs(rpc_err)
            raise e

    def handle_new_input(self):
        """Handle new input from the socket"""
        try:
            while self.input_socket.poll(timeout=0) != 0:
                frames = self.input_socket.recv_multipart(copy=False)
                request = pickle.loads(frames[0].buffer)

                if isinstance(request, RPCProcessRequest):
                    if len(frames) > 1:
                        # Use cloudpickle for logits processors
                        assert isinstance(request.params, SamplingParams)
                        lprocs = cloudpickle.loads(frames[1].buffer)
                        request.params.logits_processors = lprocs
                    self._handle_process_request(request)
                elif isinstance(request, RPCAbortRequest):
                    self._handle_abort_request(request)
                elif isinstance(request, RPCUProfileRequest):
                    if request == RPCUProfileRequest.START_PROFILE:
                        self.start_profile()
                    else:
                        self.stop_profile()
                elif isinstance(request, RPCLoadAdapterRequest):
                    self._handle_load_adapter_request(request)
                elif isinstance(request, RPCResetMultiModalCacheRequest):
                    self.reset_mm_cache()
                elif isinstance(request, RPCResetPrefixCacheRequest):
                    self.reset_prefix_cache()
                elif isinstance(request, RPCSleepRequest):
                    self.sleep(request.value)
                elif isinstance(request, RPCWakeUpRequest):
                    self.wake_up(request.tags)
                elif isinstance(request, RPCIsSleepingRequest):
                    self._handle_is_sleeping_request(request)
                else:
                    raise ValueError("Unknown RPCRequest Type: "
                                     f"{type(request)}")

        except Exception as e:
            self._set_errored(e)
            self._send_unhealthy(e)
            raise e from None

    def _handle_process_request(self, request: RPCProcessRequest):
        """Handle RPCProcessRequest by adding it to the LLMEngine."""
        request_id = request.request_id

        if self._errored_with is not None:
            rpc_err = RPCError(request_id=request_id,
                               is_engine_errored=True,
                               exception=ENGINE_DEAD_ERROR(self._errored_with))
            self._send_outputs(rpc_err)

        try:
            self.engine.add_request(request_id=request_id,
                                    prompt=request.prompt,
                                    params=request.params,
                                    lora_request=request.lora_request,
                                    trace_headers=request.trace_headers,
                                    priority=request.priority)

            if self.log_requests:
                logger.info("Added request %s.", request.request_id)

        except Exception as e:
            # We do not set self._errored = True here, since the error
            # is due to an issue adding this request to the engine,
            # rather than an issue with the engine itself.
            logger.debug("Failed to add request %s to engine. %s",
                         request.request_id, e)
            is_errored = self._errored_with is not None
            rpc_err = RPCError(request_id=request_id,
                               is_engine_errored=is_errored,
                               exception=e)
            self._send_outputs(rpc_err)

            # Remove request from the engine.
            self.engine.abort_request(request_id)

    def _handle_abort_request(self, request: RPCAbortRequest):
        self.engine.abort_request(request.request_id)
        if self.log_requests:
            logger.info("Aborted request %s.", request.request_id)

    def _handle_load_adapter_request(self, request: RPCLoadAdapterRequest):
        try:
            lora_loaded = self.engine.add_lora(request.lora_request)
        except BaseException as e:
            # Send back an error if the adapter fails to load
            rpc_err = RPCError(request_id=request.request_id,
                               is_engine_errored=False,
                               exception=e)
            self._send_outputs(rpc_err)
            return
        # Otherwise, send back the successful load message
        self._send_outputs(
            RPCAdapterLoadedResponse(request_id=request.request_id,
                                     lora_loaded=lora_loaded))

    def _handle_is_sleeping_request(self, request: RPCIsSleepingRequest):
        is_sleeping = self.is_sleeping()
        self._send_outputs(
            RPCIsSleepingResponse(request_id=request.request_id,
                                  is_sleeping=is_sleeping))

    def _health_check(self):
        # Send unhealthy if engine has already errored
        if self._errored_with is not None:
            self._send_unhealthy(self._errored_with)
        try:
            self.engine.check_health()
            self._send_healthy()
        except Exception as e:
            self._set_errored(e)
            self._send_unhealthy(e)

    def _send_outputs(self, outputs: REQUEST_OUTPUTS_T):
        """Send outputs back to the engine client. These can be:
        - Exceptions
        - A list of generation outputs
        - A response from loading a lora adapter
        """
        if outputs:
            try:
                from ray.exceptions import RayTaskError

                # RayTaskError might not be picklable here. We need to unpack
                # the underlying exception as the real exception in the output.
                if (isinstance(outputs, RPCError)
                        and isinstance(outputs.exception, RayTaskError)):
                    outputs.exception = outputs.exception.cause
            except ImportError:
                pass

            output_bytes = pickle.dumps(outputs)
            self.output_socket.send_multipart((output_bytes, ), copy=False)

    def _send_healthy(self):
        """Send HEALTHY message to RPCClient."""
        if not self.heartbeat_socket.closed:
            self.heartbeat_socket.send_multipart(HEALTHY_RESPONSE, copy=False)

    def _send_unhealthy(self, error: BaseException):
        """Send UNHEALTHY message to RPCClient."""
        if not self.heartbeat_socket.closed:
            error_bytes = pickle.dumps(error)
            self.heartbeat_socket.send_multipart((error_bytes, ), copy=False)

    def _async_socket_engine_callback(self,
                                      request_outputs: REQUEST_OUTPUTS_T):
        """Callback used by engine to make socket handling async with GPU."""
        self._send_outputs(request_outputs)
        self.handle_new_input()

    def _set_errored(self, e: BaseException):
        """Log and set errored status if this is the first issue."""
        if self._errored_with is None:
            self._errored_with = e

    def start_profile(self) -> None:
        self.engine.start_profile()

    def stop_profile(self) -> None:
        self.engine.stop_profile()

    def reset_mm_cache(self) -> bool:
        return self.engine.reset_mm_cache()

    def reset_prefix_cache(self) -> bool:
        return self.engine.reset_prefix_cache()

    def sleep(self, level: int = 1) -> None:
        self.engine.sleep(level)

    def wake_up(self, tags: Optional[list[str]] = None) -> None:
        self.engine.wake_up(tags)

    def is_sleeping(self) -> bool:
        return self.engine.is_sleeping()
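

# Illustrative sketch (not part of the original module): a minimal,
# hypothetical client-side counterpart to the sockets bound in
# MQLLMEngine.__init__ and answered by run_startup_loop above. The real
# client is MQLLMEngineClient in vllm/engine/multiprocessing/client.py;
# the function name `_probe_engine` and the timeout value are invented for
# illustration, and the snippet reuses the module-level imports
# (zmq, pickle, IPC_DATA_EXT, IPC_OUTPUT_EXT, RPCStartupRequest).
def _probe_engine(ipc_path: str, timeout_ms: int = 5000):
    ctx = zmq.Context()  # type: ignore[attr-defined]
    try:
        # Startup handshake: a DEALER talks to the engine's ROUTER data
        # socket and asks whether the server is ready.
        data_socket = ctx.socket(zmq.constants.DEALER)
        data_socket.connect(f"{ipc_path}{IPC_DATA_EXT}")
        data_socket.send(pickle.dumps(RPCStartupRequest.IS_SERVER_READY))
        if not data_socket.poll(timeout=timeout_ms):
            raise TimeoutError("Engine did not answer the startup handshake")
        startup = pickle.loads(data_socket.recv())
        data_socket.close(linger=0)

        # Output stream: the engine PUSHes pickled RequestOutput batches
        # (or RPC error/adapter responses) that the client PULLs.
        output_socket = ctx.socket(zmq.constants.PULL)
        output_socket.connect(f"{ipc_path}{IPC_OUTPUT_EXT}")
        outputs = None
        if output_socket.poll(timeout=timeout_ms):
            frames = output_socket.recv_multipart(copy=False)
            outputs = pickle.loads(frames[0].buffer)
        output_socket.close(linger=0)
        return startup, outputs
    finally:
        ctx.destroy(linger=0)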


def signal_handler(*_) -> None:
    raise KeyboardInterrupt("MQLLMEngine terminated")


def run_mp_engine(vllm_config: VllmConfig, usage_context: UsageContext,
                  ipc_path: str, disable_log_stats: bool,
                  enable_log_requests: bool, engine_alive):
    try:
        # Ensure we can serialize transformer config before spawning
        maybe_register_config_serialize_by_value()

        engine = MQLLMEngine.from_vllm_config(
            vllm_config=vllm_config,
            usage_context=usage_context,
            disable_log_stats=disable_log_stats,
            enable_log_requests=enable_log_requests,
            ipc_path=ipc_path)

        signal.signal(signal.SIGTERM, signal_handler)

        engine.start()

    except BaseException as e:
        logger.exception(e)
        engine_alive.value = False
        raise e from None
@@ -12,7 +12,6 @@ from fastapi import FastAPI, Request, Response

from vllm import envs
from vllm.engine.async_llm_engine import AsyncEngineDeadError
from vllm.engine.multiprocessing import MQEngineDeadError
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.constants import (H11_MAX_HEADER_COUNT_DEFAULT,
                                        H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT)
@@ -156,7 +155,6 @@ def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None:

    @app.exception_handler(RuntimeError)
    @app.exception_handler(AsyncEngineDeadError)
    @app.exception_handler(MQEngineDeadError)
    @app.exception_handler(EngineDeadError)
    @app.exception_handler(EngineGenerateError)
    async def runtime_exception_handler(request: Request, __):

@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import asyncio
import atexit
import gc
import importlib
import inspect
@@ -17,7 +16,6 @@ import uuid
from argparse import Namespace
from collections.abc import AsyncGenerator, AsyncIterator, Awaitable
from contextlib import asynccontextmanager
from functools import partial
from http import HTTPStatus
from typing import Annotated, Any, Callable, Optional

@@ -42,8 +40,6 @@ import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine  # type: ignore
from vllm.engine.multiprocessing.client import MQLLMEngineClient
from vllm.engine.multiprocessing.engine import run_mp_engine
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import (load_chat_template,
                                         resolve_hf_chat_template,
@@ -102,13 +98,10 @@ from vllm.entrypoints.utils import (cli_env_setup, load_aware_call,
                                    log_non_default_args, with_cancellation)
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParserManager
from vllm.transformers_utils.config import (
    maybe_register_config_serialize_by_value)
from vllm.transformers_utils.tokenizer import MistralTokenizer
from vllm.usage.usage_lib import UsageContext
from vllm.utils import (Device, FlexibleArgumentParser, decorate_logs,
                        get_open_zmq_ipc_path, is_valid_ipv6_address,
                        set_ulimit)
                        is_valid_ipv6_address, set_ulimit)
from vllm.v1.metrics.prometheus import get_prometheus_registry
from vllm.version import __version__ as VLLM_VERSION

@@ -237,8 +230,7 @@ async def build_async_engine_client_from_engine_args(
            async_llm.shutdown()

    # V0 AsyncLLM.
    elif (MQLLMEngineClient.is_unsupported_config(vllm_config)
          or disable_frontend_multiprocessing):
    else:

        engine_client: Optional[EngineClient] = None
        try:
@@ -252,96 +244,6 @@ async def build_async_engine_client_from_engine_args(
            if engine_client and hasattr(engine_client, "shutdown"):
                engine_client.shutdown()

    # V0 MQLLMEngine.
    else:
        if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
            # Make TemporaryDirectory for prometheus multiprocessing
            # Note: global TemporaryDirectory will be automatically
            # cleaned up upon exit.
            global prometheus_multiproc_dir
            prometheus_multiproc_dir = tempfile.TemporaryDirectory()
            os.environ[
                "PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
        else:
            logger.warning(
                "Found PROMETHEUS_MULTIPROC_DIR was set by user. "
                "This directory must be wiped between vLLM runs or "
                "you will find inaccurate metrics. Unset the variable "
                "and vLLM will properly handle cleanup.")

        # Select random path for IPC.
        ipc_path = get_open_zmq_ipc_path()
        logger.debug("Multiprocessing frontend to use %s for IPC Path.",
                     ipc_path)

        # Start RPCServer in separate process (holds the LLMEngine).
        # the current process might have CUDA context,
        # so we need to spawn a new process
        context = multiprocessing.get_context("spawn")

        # Ensure we can serialize transformer config before spawning
        maybe_register_config_serialize_by_value()

        # The Process can raise an exception during startup, which may
        # not actually result in an exitcode being reported. As a result
        # we use a shared variable to communicate the information.
        engine_alive = multiprocessing.Value('b', True, lock=False)
        engine_process = context.Process(
            target=run_mp_engine,
            args=(vllm_config, UsageContext.OPENAI_API_SERVER, ipc_path,
                  engine_args.disable_log_stats,
                  engine_args.enable_log_requests, engine_alive))
        engine_process.start()
        engine_pid = engine_process.pid
        assert engine_pid is not None, "Engine process failed to start."
        logger.info("Started engine process with PID %d", engine_pid)

        def _cleanup_ipc_path():
            socket_path = ipc_path.replace("ipc://", "")
            if os.path.exists(socket_path):
                os.remove(socket_path)

        # Ensure we clean up the local IPC socket file on exit.
        atexit.register(_cleanup_ipc_path)

        # Build RPCClient, which conforms to EngineClient Protocol.
        build_client = partial(MQLLMEngineClient, ipc_path, vllm_config,
                               engine_pid)
        mq_engine_client = await asyncio.get_running_loop().run_in_executor(
            None, build_client)
        try:
            while True:
                try:
                    await mq_engine_client.setup()
                    break
                except TimeoutError:
                    if (not engine_process.is_alive()
                            or not engine_alive.value):
                        raise RuntimeError(
                            "Engine process failed to start. See stack "
                            "trace for the root cause.") from None

            yield mq_engine_client  # type: ignore[misc]
        finally:
            # Ensure rpc server process was terminated
            engine_process.terminate()

            # Close all open connections to the backend
            mq_engine_client.close()

            # Wait for engine process to join
            engine_process.join(4)
            if engine_process.exitcode is None:
                # Kill if it has not stopped within the join timeout
                engine_process.kill()

            # Lazy import for prometheus multiprocessing.
            # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
            # before prometheus_client is imported.
            # See https://prometheus.github.io/client_python/multiprocess/
            from prometheus_client import multiprocess
            multiprocess.mark_process_dead(engine_process.pid)


async def validate_json_request(raw_request: Request):
    content_type = raw_request.headers.get("content-type", "").lower()

@@ -191,7 +191,7 @@ class RocmPlatform(Platform):
                             kv_cache_dtype, block_size, use_v1, use_mla,
                             has_sink) -> str:
        if use_mla:
            from vllm.attention.backends.rocm_aiter_mla import (
            from vllm.v1.attention.backends.mla.rocm_aiter_mla import (
                is_aiter_mla_enabled)

            if selected_backend is None: