commit 2ceb7bc534
parent 9f7fb5ec84
Author: Robert Shaw <rshaw@neuralmagic.com>
Date:   2025-03-22 13:25:05 -04:00

Signed-off-by: Robert Shaw <rshaw@neuralmagic.com>

2 changed files with 11 additions and 4 deletions

View File

@@ -3,6 +3,7 @@
 import asyncio
 import msgspec
 from collections.abc import AsyncGenerator
+from contextlib import asynccontextmanager
 from typing import Dict, Mapping, Optional

 import uvicorn
@@ -57,6 +58,12 @@ class PDResponse(msgspec.Struct,
     stop_reason: Optional[str] = None
     logprobs = None  # TODO

+@asynccontextmanager
+async def build_pd_engine_client(prefill_addr: str, decode_addr: str,
+                                 connector_addr: str):
+    engine = PDEngine(prefill_addr, decode_addr, connector_addr)
+    yield engine
+    engine.shutdown()

 class PDEngine:
     """
@@ -68,6 +75,7 @@ class PDEngine:
     * TODO: support more than just text input.
     * TODO: move under vllm/v1/engine once past prototype.
     """
+
     def __init__(self, prefill_addr: str, decode_addr: str, connector_addr: str):
         # Request queues.
         self.queues: Dict[str, asyncio.Queue] = {}

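The new helper is consumed with async with, so the engine's lifetime is tied to the block that uses it. A minimal usage sketch (the IPC addresses below are hypothetical placeholders, not values from this commit):

    import asyncio

    from vllm.entrypoints.disaggregated.engine import build_pd_engine_client

    async def main():
        # Hypothetical addresses; in the server they come from CLI args.
        async with build_pd_engine_client(
                prefill_addr="ipc:///tmp/prefill",
                decode_addr="ipc:///tmp/decode",
                connector_addr="ipc:///tmp/connector") as engine:
            ...  # issue requests against the PDEngine here
        # engine.shutdown() has already run by the time we get here.

    asyncio.run(main())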
View File

@@ -36,7 +36,7 @@ from vllm.engine.multiprocessing.client import MQLLMEngineClient
 from vllm.engine.multiprocessing.engine import run_mp_engine
 from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.chat_utils import load_chat_template
-from vllm.entrypoints.disaggregated.engine import PDEngine
+from vllm.entrypoints.disaggregated.engine import build_pd_engine_client
 from vllm.entrypoints.launcher import serve_http
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.cli_args import (make_arg_parser,
@@ -136,16 +136,15 @@ async def build_async_engine_client(
         args: Namespace) -> AsyncIterator[EngineClient]:

     # Case 1: We are running a P/D Connector.
-    # The Engines may be running on another node.
     if hasattr(args, "connector_addr"):
-        async with PDEngine(
+        async with build_pd_engine_client(
                 prefill_addr=args.prefill_addr,
                 decode_addr=args.decode_addr,
                 connector_addr=args.connector_addr) as engine:
             yield engine
             engine.shutdown()

-    # Case 2: We are running an actual Engine from this process.
+    # Case 2: We are running a normal instance of vLLM.
     else:
         # Context manager to handle engine_client lifecycle
         # Ensures everything is shutdown and cleaned up on error/exit
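Two details of the new call site are worth noting: engine.shutdown() is still invoked inside the async with block even though build_pd_engine_client now shuts the engine down on exit, and because the helper's yield is bare, the shutdown after it is skipped if the block raises. A more defensive shape for the helper (a sketch, not what this commit ships) routes cleanup through try/finally:

    from contextlib import asynccontextmanager

    @asynccontextmanager
    async def build_pd_engine_client(prefill_addr: str, decode_addr: str,
                                     connector_addr: str):
        engine = PDEngine(prefill_addr, decode_addr, connector_addr)
        try:
            # Hand the engine to the caller for the duration of the block.
            yield engine
        finally:
            # Runs on clean exit, exception, or cancellation alike.
            engine.shutdown()

With that shape, the redundant engine.shutdown() at the call site could simply be dropped.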