From 2ceb7bc534e26076480e46f98df32245cfa643d5 Mon Sep 17 00:00:00 2001
From: Robert Shaw
Date: Sat, 22 Mar 2025 13:25:05 -0400
Subject: [PATCH] Manage the PDEngine lifecycle via a build_pd_engine_client
 context manager

Build the P/D engine client through an asynccontextmanager so that
build_async_engine_client() in api_server.py owns the engine lifecycle
the same way it does for the other engine clients: shutdown() now runs
exactly once, when the async-with block exits.

Signed-off-by: Robert Shaw
---
 vllm/entrypoints/disaggregated/engine.py | 8 ++++++++
 vllm/entrypoints/openai/api_server.py    | 8 +++-----
 2 files changed, 11 insertions(+), 5 deletions(-)

diff --git a/vllm/entrypoints/disaggregated/engine.py b/vllm/entrypoints/disaggregated/engine.py
index b3992234ecc25..162eb77de1a5c 100644
--- a/vllm/entrypoints/disaggregated/engine.py
+++ b/vllm/entrypoints/disaggregated/engine.py
@@ -3,6 +3,7 @@
 import asyncio
 import msgspec
 from collections.abc import AsyncGenerator
+from contextlib import asynccontextmanager
 from typing import Dict, Mapping, Optional
 
 import uvicorn
@@ -57,6 +58,12 @@ class PDResponse(msgspec.Struct,
     stop_reason: Optional[str] = None
     logprobs = None  # TODO
 
+@asynccontextmanager
+async def build_pd_engine_client(prefill_addr: str, decode_addr: str,
+                                 connector_addr: str):
+    engine = PDEngine(prefill_addr, decode_addr, connector_addr)
+    yield engine
+    engine.shutdown()
 
 class PDEngine:
     """
@@ -68,6 +75,7 @@ class PDEngine:
     * TODO: support more than just text input.
     * TODO: move under vllm/v1/engine one past prototype.
     """
+
     def __init__(self, prefill_addr: str, decode_addr: str, connector_addr: str):
         # Request queues.
         self.queues: Dict[str, asyncio.Queue] = {}
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 297f51ccf69e6..e14cf97c98594 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -36,7 +36,7 @@ from vllm.engine.multiprocessing.client import MQLLMEngineClient
 from vllm.engine.multiprocessing.engine import run_mp_engine
 from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.chat_utils import load_chat_template
-from vllm.entrypoints.disaggregated.engine import PDEngine
+from vllm.entrypoints.disaggregated.engine import build_pd_engine_client
 from vllm.entrypoints.launcher import serve_http
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.cli_args import (make_arg_parser,
@@ -136,16 +136,14 @@ async def build_async_engine_client(
         args: Namespace) -> AsyncIterator[EngineClient]:
 
     # Case 1: We are running a P/D Connector.
-    # The Engines may be running on another node.
     if hasattr(args, "connector_addr"):
-        async with PDEngine(
+        async with build_pd_engine_client(
                 prefill_addr=args.prefill_addr,
                 decode_addr=args.decode_addr,
                 connector_addr=args.connector_addr) as engine:
             yield engine
-            engine.shutdown()
 
-    # Case 2: We are running an actual Engine from this process.
+    # Case 2: We are running a normal instance of vLLM.
     else:
         # Context manager to handle engine_client lifecycle
         # Ensures everything is shutdown and cleaned up on error/exit
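
For reviewers, a minimal runnable sketch of the lifecycle this patch sets up. The `PDEngine` stub, the `ipc://` addresses, and the `main` driver are placeholders invented for illustration; only `build_pd_engine_client` mirrors the patch, and the `try`/`finally` is a small hardening over the patch's straight-line body so `shutdown()` still runs if the caller raises:

```python
import asyncio
from contextlib import asynccontextmanager


class PDEngine:
    """Stub standing in for vllm.entrypoints.disaggregated.engine.PDEngine."""

    def __init__(self, prefill_addr: str, decode_addr: str,
                 connector_addr: str):
        # The real class sets up its connections here; the stub just
        # records the addresses so the example stays self-contained.
        self.prefill_addr = prefill_addr
        self.decode_addr = decode_addr
        self.connector_addr = connector_addr

    def shutdown(self):
        print("PDEngine.shutdown() called exactly once")


@asynccontextmanager
async def build_pd_engine_client(prefill_addr: str, decode_addr: str,
                                 connector_addr: str):
    # Same shape as the patch: build the engine, lend it to the caller,
    # and tear it down when the async-with block exits.
    engine = PDEngine(prefill_addr, decode_addr, connector_addr)
    try:
        yield engine
    finally:
        engine.shutdown()


async def main():
    # Placeholder addresses; in api_server.py these come from
    # args.prefill_addr / args.decode_addr / args.connector_addr.
    async with build_pd_engine_client("ipc:///tmp/prefill",
                                      "ipc:///tmp/decode",
                                      "ipc:///tmp/connector") as engine:
        print(f"serving via connector {engine.connector_addr}")


if __name__ == "__main__":
    asyncio.run(main())
```

The context-manager shape keeps the connector branch of build_async_engine_client symmetric with the existing engine-client path, which likewise yields its client from a context manager that "ensures everything is shutdown and cleaned up on error/exit".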