Signed-off-by: Robert Shaw <rshaw@neuralmagic.com>
Robert Shaw 2025-03-22 15:58:51 -04:00
parent 2ceb7bc534
commit 120bbdfd82
4 changed files with 186 additions and 53 deletions

View File

@ -0,0 +1,122 @@
# SPDX-License-Identifier: Apache-2.0
import uvicorn
import uvloop
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse
from vllm.entrypoints.disaggregated.pd_engine import PDEngine
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
OpenAIServingModels)
from vllm.entrypoints.openai.protocol import (CompletionRequest,
                                              CompletionResponse,
                                              ErrorResponse)
from vllm.logger import init_logger
from vllm.utils import FlexibleArgumentParser
# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
logger = init_logger('vllm.entrypoints.disaggregated.api_server')
app = FastAPI()
@app.get("/v1/models")
async def show_available_models(raw_request: Request):
handler: OpenAIServingModels = raw_request.app.state.openai_serving_models
models_ = await handler.show_available_models()
return JSONResponse(content=models_.model_dump())
@app.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
handler: OpenAIServingCompletion = raw_request.app.state.openai_serving_completion
generator = await handler.create_completion(request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
elif isinstance(generator, CompletionResponse):
return JSONResponse(content=generator.model_dump())
return StreamingResponse(content=generator, media_type="text/event-stream")
@asynccontextmanager
async def pd_engine_client_ctx_manager(
        model_name: str,
        prefill_addr: str,
        decode_addr: str,
        connector_addr: str) -> AsyncIterator[PDEngine]:
    engine = PDEngine(prefill_addr=prefill_addr,
                      decode_addr=decode_addr,
                      connector_addr=connector_addr,
                      model_name=model_name)
    try:
        yield engine
    finally:
        engine.shutdown()
async def main(args, **uvicorn_kwargs):
logger.info("vLLM Disaggregate Connector Start %s %s", args,
uvicorn_kwargs)
prefill_addr = f"ipc://{args.prefill_addr}"
decode_addr = f"ipc://{args.decode_addr}"
connector_addr = f"ipc://{args.connector_addr}"
    async with pd_engine_client_ctx_manager(
            args.model, prefill_addr, decode_addr,
            connector_addr) as engine_client:
model_config = await engine_client.get_model_config()
# Models.
app.state.openai_serving_models = OpenAIServingModels(
engine_client=engine_client,
model_config=model_config,
base_model_paths=[BaseModelPath(
name=args.served_model_name or args.model,
model_path=args.model)
],
)
# Completions.
app.state.openai_serving_completion = OpenAIServingCompletion(
engine_client=engine_client,
model_config=model_config,
models=app.state.openai_serving_models,
request_logger=None,
)
        # Init Uvicorn Server.
        config = uvicorn.Config(app, host=args.host, port=args.port)
server = uvicorn.Server(config)
await server.serve()
if __name__ == "__main__":
parser = FlexibleArgumentParser(
description="vLLM OpenAI-Compatible P/D Server.")
parser.add_argument("--host",
type=str,
default="0.0.0.0",
help="The host of the HTTP server.")
parser.add_argument("--port",
type=int,
default=8001,
help="The port of the HTTP server.")
parser.add_argument("--model",
type=str,
required=True,
help="The path to the model.")
parser.add_argument("--served-model-name",
type=str,
default=None,
help="The served name of the model.")
parser.add_argument("--connector-addr",
type=str,
required=True,
help="The zmq ipc connector address")
parser.add_argument("--prefill-addr",
type=str,
required=True,
help="The zmq ipc prefill address")
parser.add_argument("--decode-addr",
type=str,
required=True,
help="The zmq ipc decode address")
args = parser.parse_args()
uvloop.run(main(args))
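
For reference, a minimal client sketch against the two endpoints defined above. It assumes the connector is reachable on localhost at the default port 8001 and that the request body follows the standard OpenAI completions schema; the model name below is a placeholder for whatever was passed via --model or --served-model-name.

import json
import urllib.request

BASE_URL = "http://localhost:8001"

# GET /v1/models: list the model registered at startup.
with urllib.request.urlopen(f"{BASE_URL}/v1/models") as resp:
    print(json.load(resp))

# POST /v1/completions: a non-streaming completion request.
payload = json.dumps({
    "model": "my-model",  # placeholder: the --served-model-name (or --model) value
    "prompt": "Hello, my name is",
    "max_tokens": 16,
}).encode()
req = urllib.request.Request(
    f"{BASE_URL}/v1/completions",
    data=payload,
    headers={"Content-Type": "application/json"})
with urllib.request.urlopen(req) as resp:
    print(json.load(resp))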

View File

@ -3,10 +3,8 @@
import asyncio
import msgspec
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from typing import Dict, Mapping, Optional
from typing import Dict, List, Mapping, Optional
import uvicorn
import uvloop
import zmq
import zmq.asyncio
@ -27,6 +25,7 @@ from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.utils import Device, FlexibleArgumentParser, make_zmq_socket
logger = init_logger(__name__)
@ -43,7 +42,7 @@ class PDRequest(msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
gc=False): # type: ignore[call-arg]
request_id: str
prompt: str
prompt_token_ids: List[int]
sampling_params: SamplingParams
# TODO: support multimodal inputs.
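
For illustration, a sketch of how a PDRequest could be serialized with msgspec for transport over the zmq sockets this engine uses. The msgpack encoder is an assumption made for the sketch; the actual wire format is whatever the engine's socket code chooses.

import msgspec
from vllm.sampling_params import SamplingParams

# PDRequest is the struct defined above.
encoder = msgspec.msgpack.Encoder()
decoder = msgspec.msgpack.Decoder(PDRequest)

req = PDRequest(request_id="req-0",
                prompt_token_ids=[1, 2, 3],
                sampling_params=SamplingParams(max_tokens=16))
wire_bytes = encoder.encode(req)        # bytes ready for a zmq send()
roundtrip = decoder.decode(wire_bytes)  # reconstructs the struct on the other side
assert roundtrip.request_id == "req-0"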
@ -58,12 +57,6 @@ class PDResponse(msgspec.Struct,
stop_reason: Optional[str] = None
logprobs = None # TODO
@asynccontextmanager
def build_pd_engine_client(prefill_addr: str, decode_addr: str,
connector_addr: str):
engine = PDEngine(prefill_addr, decode_addr, connector_addr)
yield engine
engine.shutdown()
class PDEngine:
"""
@ -76,7 +69,13 @@ class PDEngine:
* TODO: move under vllm/v1/engine one past prototype.
"""
def __init__(self, prefill_addr: str, decode_addr: str, connector_addr: str):
def __init__(
self,
prefill_addr: str,
decode_addr: str,
connector_addr: str,
model_name: str
):
# Request queues.
self.queues: Dict[str, asyncio.Queue] = {}
@ -95,6 +94,32 @@ class PDEngine:
self.output_handler: Optional[asyncio.Task] = None
self.log_running: Optional[asyncio.Task] = None
# Dummy: needed for EngineClient Protocol.
# TODO: refactor EngineClient to avoid needing this.
self.model_config = ModelConfig(
model=model_name,
tokenizer=model_name,
tokenizer_mode="auto",
trust_remote_code=False,
dtype="auto",
seed=42
)
# Dummy: needed for EngineClient Protocol.
# TODO: refactor EngineClient to avoid needing this.
init_kwargs = dict(
tokenizer_id=self.model_config.tokenizer,
enable_lora=False,
max_num_seqs=1024,
max_loras=0,
max_input_length=None,
tokenizer_mode=self.model_config.tokenizer_mode,
trust_remote_code=self.model_config.trust_remote_code,
revision=self.model_config.tokenizer_revision,
truncation_side=self.model_config.truncation_side)
self.tokenizer = TokenizerGroup(**init_kwargs)
def shutdown(self):
if (ctx := self.ctx) is not None:
ctx.destroy(linger=0)
@ -178,9 +203,22 @@ class PDEngine:
self.output_handler = asyncio.create_task(self._run_output_handler())
self.log_running = asyncio.create_task(self._run_log_running())
# TODO: expand to suppo
# TODO: Expand to support the full matrix.
if not isinstance(prompt, str):
raise ValueError("We currently only support text inputs!")
raise NotImplementedError(
"We currently only support text for P/D!")
        if lora_request is not None:
            raise NotImplementedError(
                "We currently do not support LoRA for P/D!")
        if trace_headers is not None:
            raise NotImplementedError(
                "We currently do not support tracing for P/D!")
        if prompt_adapter_request is not None:
            raise NotImplementedError(
                "We currently do not support prompt adapter for P/D!")
if priority != 0:
raise NotImplementedError(
"We currently do not support priority for P/D!")
if request_id in self.queues:
raise ValueError(f"Found duplicate request_id: {request_id}!")
@ -188,14 +226,14 @@ class PDEngine:
q: asyncio.Queue[PDResponse] = asyncio.Queue()
self.queues[request_id] = q
# (1) Perform the prefill (max_tokens=1).
# (1) Perform the Prefill.
original_max_tokens = sampling_params.max_tokens
request = PDRequest(request_id, prompt, sampling_params)
request.sampling_params.max_tokens = 1
response = await self._prefill(request, q)
yield response
# (2) Perform the decodes (original tokens).
# (2) Perform the Decodes.
request.sampling_params.max_tokens = original_max_tokens
async for response in self._decode(request, q):
yield response
@ -223,7 +261,7 @@ class PDEngine:
raise NotImplementedError
async def get_model_config(self) -> ModelConfig:
raise NotImplementedError
return self.model_config
async def get_decoding_config(self) -> DecodingConfig:
raise NotImplementedError
@ -235,7 +273,10 @@ class PDEngine:
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
raise NotImplementedError
if lora_request is not None:
raise NotImplementedError(
"LoRA is not yet supported in the PDEngine.")
return self.tokenizer.get_lora_tokenizer(None)
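
For illustration, a hypothetical helper (not part of this file) showing how a caller could combine get_tokenizer() with the new prompt_token_ids field by tokenizing a text prompt into the IDs a PDRequest expects.

from typing import List

async def tokenize_prompt(engine: "PDEngine", prompt: str) -> List[int]:
    # LoRA is not supported yet, so fetch the base tokenizer.
    tokenizer = await engine.get_tokenizer()
    # AnyTokenizer is a Hugging Face tokenizer; encode() returns the token IDs.
    return tokenizer.encode(prompt)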
async def is_tracing_enabled(self) -> bool:
return False
@ -271,23 +312,6 @@ class PDEngine:
async def add_lora(self, lora_request: LoRARequest) -> None:
raise NotImplementedError
async def run_disagg_connector(args, **uvicorn_kwargs):
logger.info("vLLM Connector Start: %s %s", args, uvicorn_kwargs)
# NOTE FOR DEVELOPERS: when we shift this to TCP, we must
# ensure that the serialization is not pickle based to
# avoid RCE issues from untrusted users!!!
app.state.port = args.port
app.state.connector_addr = f"ipc://{args.connector_addr}"
app.state.decode_addr = f"ipc://{args.decode_addr}"
app.state.prefill_addr = f"ipc://{args.prefill_addr}"
# init uvicorn server
config = uvicorn.Config(app, host=args.post, port=args.port)
server = uvicorn.Server(config)
await server.serve()
if __name__ == "__main__":

View File

@ -36,7 +36,6 @@ from vllm.engine.multiprocessing.client import MQLLMEngineClient
from vllm.engine.multiprocessing.engine import run_mp_engine
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import load_chat_template
from vllm.entrypoints.disaggregated.engine import build_pd_engine_client
from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.cli_args import (make_arg_parser,
@ -135,30 +134,18 @@ async def lifespan(app: FastAPI):
async def build_async_engine_client(
args: Namespace) -> AsyncIterator[EngineClient]:
# Case 1: We are running a P/D Connector.
if hasattr(args, "connector_addr"):
async with build_pd_engine_client(
prefill_addr=args.prefill_addr,
decode_addr=args.decode_addr,
connector_addr=args.connector_addr) as engine:
yield engine
engine.shutdown()
# Case 2: We are running a normal instance of vLLM.
else:
# Context manager to handle engine_client lifecycle
# Ensures everything is shutdown and cleaned up on error/exit
engine_args = AsyncEngineArgs.from_cli_args(args)
async with build_async_engine_client_from_engine_args(
engine_args, args.disable_frontend_multiprocessing) as engine:
yield engine
# Context manager to handle engine_client lifecycle
# Ensures everything is shutdown and cleaned up on error/exit
engine_args = AsyncEngineArgs.from_cli_args(args)
async with build_async_engine_client_from_engine_args(
engine_args, args.disable_frontend_multiprocessing) as engine:
yield engine
@asynccontextmanager
async def build_async_engine_client_from_engine_args(
engine_args: AsyncEngineArgs,
disable_frontend_multiprocessing: bool = False,
deploy_disagg_connector: bool = False,
) -> AsyncIterator[EngineClient]:
"""
Create EngineClient, either: