diff --git a/vllm/entrypoints/disaggregated/api_server.py b/vllm/entrypoints/disaggregated/api_server.py
new file mode 100644
index 0000000000000..95c5e5df79a1d
--- /dev/null
+++ b/vllm/entrypoints/disaggregated/api_server.py
@@ -0,0 +1,122 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from collections.abc import AsyncIterator
+from contextlib import asynccontextmanager
+
+import uvicorn
+import uvloop
+from fastapi import FastAPI, Request
+from fastapi.responses import JSONResponse, StreamingResponse
+
+from vllm.entrypoints.disaggregated.pd_engine import PDEngine
+from vllm.entrypoints.openai.protocol import (CompletionRequest,
+                                              CompletionResponse,
+                                              ErrorResponse)
+from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
+from vllm.entrypoints.openai.serving_models import (BaseModelPath,
+                                                    OpenAIServingModels)
+from vllm.logger import init_logger
+from vllm.utils import FlexibleArgumentParser
+
+# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
+logger = init_logger('vllm.entrypoints.disaggregated.api_server')
+
+app = FastAPI()
+
+
+@app.get("/v1/models")
+async def show_available_models(raw_request: Request):
+    handler: OpenAIServingModels = raw_request.app.state.openai_serving_models
+    models_ = await handler.show_available_models()
+    return JSONResponse(content=models_.model_dump())
+
+
+@app.post("/v1/completions")
+async def create_completion(request: CompletionRequest, raw_request: Request):
+    handler: OpenAIServingCompletion = (
+        raw_request.app.state.openai_serving_completion)
+    generator = await handler.create_completion(request, raw_request)
+    if isinstance(generator, ErrorResponse):
+        return JSONResponse(content=generator.model_dump(),
+                            status_code=generator.code)
+    elif isinstance(generator, CompletionResponse):
+        return JSONResponse(content=generator.model_dump())
+
+    return StreamingResponse(content=generator,
+                             media_type="text/event-stream")
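Both routes above delegate to the existing OpenAI serving layer, so from a client's point of view the connector behaves like any OpenAI-compatible vLLM server. A minimal client-side smoke test (stdlib only; the model name and prompt are hypothetical, and it assumes the connector is already running on the default port 8001):

```python
# Hypothetical client for the connector's OpenAI-compatible /v1/completions.
# Assumes api_server.py is running on localhost:8001 with a loaded model.
import json
import urllib.request

payload = {
    "model": "my-model",  # must match --served-model-name (or --model)
    "prompt": "Hello, disaggregated world!",
    "max_tokens": 16,
}
req = urllib.request.Request(
    "http://localhost:8001/v1/completions",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(json.load(resp)["choices"][0]["text"])
```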
+
+
+@asynccontextmanager
+async def pd_engine_client_ctx_manager(
+        model_name: str,
+        prefill_addr: str,
+        decode_addr: str,
+        connector_addr: str) -> AsyncIterator[PDEngine]:
+    engine = PDEngine(prefill_addr, decode_addr, connector_addr, model_name)
+    try:
+        yield engine
+    finally:
+        engine.shutdown()
+
+
+async def main(args, **uvicorn_kwargs):
+    logger.info("vLLM Disaggregated Connector Start %s %s", args,
+                uvicorn_kwargs)
+
+    prefill_addr = f"ipc://{args.prefill_addr}"
+    decode_addr = f"ipc://{args.decode_addr}"
+    connector_addr = f"ipc://{args.connector_addr}"
+
+    async with pd_engine_client_ctx_manager(
+            args.model, prefill_addr, decode_addr,
+            connector_addr) as engine_client:
+
+        model_config = await engine_client.get_model_config()
+
+        # Models.
+        app.state.openai_serving_models = OpenAIServingModels(
+            engine_client=engine_client,
+            model_config=model_config,
+            base_model_paths=[
+                BaseModelPath(name=args.served_model_name or args.model,
+                              model_path=args.model)
+            ],
+        )
+
+        # Completions.
+        app.state.openai_serving_completion = OpenAIServingCompletion(
+            engine_client=engine_client,
+            model_config=model_config,
+            models=app.state.openai_serving_models,
+            request_logger=None,
+        )
+
+        # Init Uvicorn Server.
+        config = uvicorn.Config(app, host=args.host, port=args.port)
+        server = uvicorn.Server(config)
+        await server.serve()
+
+
+if __name__ == "__main__":
+    parser = FlexibleArgumentParser(
+        description="vLLM OpenAI-Compatible P/D Server.")
+    parser.add_argument("--host",
+                        type=str,
+                        default="0.0.0.0",
+                        help="The host of the HTTP server.")
+    parser.add_argument("--port",
+                        type=int,
+                        default=8001,
+                        help="The port of the HTTP server.")
+    parser.add_argument("--model",
+                        type=str,
+                        required=True,
+                        help="The path to the model.")
+    parser.add_argument("--served-model-name",
+                        type=str,
+                        default=None,
+                        help="The served name of the model.")
+    parser.add_argument("--connector-addr",
+                        type=str,
+                        required=True,
+                        help="The zmq ipc connector address.")
+    parser.add_argument("--prefill-addr",
+                        type=str,
+                        required=True,
+                        help="The zmq ipc prefill address.")
+    parser.add_argument("--decode-addr",
+                        type=str,
+                        required=True,
+                        help="The zmq ipc decode address.")
+    args = parser.parse_args()
+    uvloop.run(main(args))
diff --git a/vllm/entrypoints/disaggregated/engine.py b/vllm/entrypoints/disaggregated/pd_engine.py
similarity index 80%
rename from vllm/entrypoints/disaggregated/engine.py
rename to vllm/entrypoints/disaggregated/pd_engine.py
index 162eb77de1a5c..0efed4e3b973b 100644
--- a/vllm/entrypoints/disaggregated/engine.py
+++ b/vllm/entrypoints/disaggregated/pd_engine.py
@@ -3,10 +3,8 @@
 import asyncio
 import msgspec
 from collections.abc import AsyncGenerator
-from contextlib import asynccontextmanager
-from typing import Dict, Mapping, Optional
+from typing import Dict, List, Mapping, Optional
 
-import uvicorn
 import uvloop
 import zmq
 import zmq.asyncio
@@ -27,6 +25,7 @@
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import BeamSearchParams, SamplingParams
 from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.transformers_utils.tokenizer_group import TokenizerGroup
 from vllm.utils import Device, FlexibleArgumentParser, make_zmq_socket
 
 logger = init_logger(__name__)
@@ -43,7 +42,7 @@ class PDRequest(msgspec.Struct, omit_defaults=True,  # type: ignore[call-arg]
                 gc=False):  # type: ignore[call-arg]
     request_id: str
-    prompt: str
+    prompt_token_ids: List[int]
     sampling_params: SamplingParams
     # TODO: support multimodal inputs.
@@ -58,12 +57,6 @@ class PDResponse(msgspec.Struct,
     stop_reason: Optional[str] = None
     logprobs = None  # TODO
 
-@asynccontextmanager
-def build_pd_engine_client(prefill_addr: str, decode_addr: str,
-                           connector_addr: str):
-    engine = PDEngine(prefill_addr, decode_addr, connector_addr)
-    yield engine
-    engine.shutdown()
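PDRequest and PDResponse are msgspec structs, so the connector and the prefill/decode workers exchange compact msgpack frames rather than pickled objects (see the serialization note preserved further down). A self-contained sketch of that wire pattern, using a simplified stand-in struct rather than the real PDRequest:

```python
# Sketch of the msgspec/msgpack wire format between connector and workers.
# `Request` is a simplified stand-in for PDRequest; the real struct carries
# a full SamplingParams object instead of a bare max_tokens.
import msgspec


class Request(msgspec.Struct):
    request_id: str
    prompt_token_ids: list[int]
    max_tokens: int


encoder = msgspec.msgpack.Encoder()
decoder = msgspec.msgpack.Decoder(Request)

frame = encoder.encode(Request("req-0", [101, 2009, 2088], max_tokens=1))
assert decoder.decode(frame) == Request("req-0", [101, 2009, 2088],
                                        max_tokens=1)
```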
 
 class PDEngine:
     """
@@ -76,7 +69,13 @@ class PDEngine:
     * TODO: move under vllm/v1/engine once past prototype.
     """
 
-    def __init__(self, prefill_addr: str, decode_addr: str, connector_addr: str):
+    def __init__(
+        self,
+        prefill_addr: str,
+        decode_addr: str,
+        connector_addr: str,
+        model_name: str,
+    ):
         # Request queues.
         self.queues: Dict[str, asyncio.Queue] = {}
@@ -95,6 +94,32 @@ class PDEngine:
         self.output_handler: Optional[asyncio.Task] = None
         self.log_running: Optional[asyncio.Task] = None
 
+        # Dummy: needed for EngineClient Protocol.
+        # TODO: refactor EngineClient to avoid needing this.
+        self.model_config = ModelConfig(
+            model=model_name,
+            tokenizer=model_name,
+            tokenizer_mode="auto",
+            trust_remote_code=False,
+            dtype="auto",
+            seed=42,
+        )
+
+        # Dummy: needed for EngineClient Protocol.
+        # TODO: refactor EngineClient to avoid needing this.
+        init_kwargs = dict(
+            tokenizer_id=self.model_config.tokenizer,
+            enable_lora=False,
+            max_num_seqs=1024,
+            max_loras=0,
+            max_input_length=None,
+            tokenizer_mode=self.model_config.tokenizer_mode,
+            trust_remote_code=self.model_config.trust_remote_code,
+            revision=self.model_config.tokenizer_revision,
+            truncation_side=self.model_config.truncation_side)
+        self.tokenizer = TokenizerGroup(**init_kwargs)
+
     def shutdown(self):
         if (ctx := self.ctx) is not None:
             ctx.destroy(linger=0)
@@ -178,9 +203,22 @@ class PDEngine:
         self.output_handler = asyncio.create_task(self._run_output_handler())
         self.log_running = asyncio.create_task(self._run_log_running())
 
-        # TODO: expand to suppo
+        # TODO: Expand to support the full matrix.
         if not isinstance(prompt, str):
-            raise ValueError("We currently only support text inputs!")
+            raise NotImplementedError(
+                "We currently only support text for P/D!")
+        if lora_request is not None:
+            raise NotImplementedError(
+                "We currently do not support LoRA for P/D!")
+        if trace_headers is not None:
+            raise NotImplementedError(
+                "We currently do not support tracing for P/D!")
+        if prompt_adapter_request is not None:
+            raise NotImplementedError(
+                "We currently do not support prompt adapter for P/D!")
+        if priority != 0:
+            raise NotImplementedError(
+                "We currently do not support priority for P/D!")
 
         if request_id in self.queues:
             raise ValueError(f"Found duplicate request_id: {request_id}!")
@@ -188,14 +226,14 @@ class PDEngine:
         q: asyncio.Queue[PDResponse] = asyncio.Queue()
         self.queues[request_id] = q
 
-        # (1) Perform the prefill (max_tokens=1).
+        # (1) Perform the Prefill (max_tokens=1).
         original_max_tokens = sampling_params.max_tokens
-        request = PDRequest(request_id, prompt, sampling_params)
+        prompt_token_ids = self.tokenizer.get_lora_tokenizer(None).encode(
+            prompt)
+        request = PDRequest(request_id, prompt_token_ids, sampling_params)
         request.sampling_params.max_tokens = 1
         response = await self._prefill(request, q)
         yield response
 
-        # (2) Perform the decodes (original tokens).
+        # (2) Perform the Decodes with the original max_tokens.
         request.sampling_params.max_tokens = original_max_tokens
         async for response in self._decode(request, q):
             yield response
@@ -223,7 +261,7 @@ class PDEngine:
         raise NotImplementedError
 
     async def get_model_config(self) -> ModelConfig:
-        raise NotImplementedError
+        return self.model_config
 
     async def get_decoding_config(self) -> DecodingConfig:
         raise NotImplementedError
@@ -235,7 +273,10 @@ class PDEngine:
         self,
         lora_request: Optional[LoRARequest] = None,
     ) -> AnyTokenizer:
-        raise NotImplementedError
+        if lora_request is not None:
+            raise NotImplementedError(
+                "LoRA is not yet supported in the PDEngine.")
+        return self.tokenizer.get_lora_tokenizer(None)
 
     async def is_tracing_enabled(self) -> bool:
         return False
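The generate flow above is the heart of the connector: each request gets its own asyncio.Queue, the prefill worker receives the request with max_tokens forced to 1, and once that first token arrives the same request (with its original budget restored) is handed to the decode worker, whose responses stream from the same queue. A standalone, runnable sketch of that pattern, with the ZMQ transport replaced by a fake in-process worker:

```python
# Standalone sketch of PDEngine's prefill-then-decode handoff. The fake
# worker stands in for the ZMQ output handler feeding per-request queues.
import asyncio
from dataclasses import dataclass


@dataclass
class Response:
    text: str
    finished: bool


async def fake_worker(q: asyncio.Queue, n_tokens: int, tag: str) -> None:
    # Stand-in for a prefill/decode worker pushing responses back.
    for i in range(n_tokens):
        await q.put(Response(f"{tag}-{i} ", finished=(i == n_tokens - 1)))


async def generate(max_tokens: int):
    q: asyncio.Queue = asyncio.Queue()  # one queue per request_id
    # (1) Prefill: force max_tokens=1 and wait for the single first token.
    asyncio.create_task(fake_worker(q, 1, "prefill"))
    yield await q.get()
    # (2) Decode: restore the original budget and stream the rest.
    asyncio.create_task(fake_worker(q, max_tokens - 1, "decode"))
    while True:
        response = await q.get()
        yield response
        if response.finished:
            break


async def main():
    async for r in generate(max_tokens=4):
        print(r.text, end="", flush=True)


asyncio.run(main())
```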
@@ -271,23 +312,6 @@ class PDEngine:
 
     async def add_lora(self, lora_request: LoRARequest) -> None:
         raise NotImplementedError
-
-
-async def run_disagg_connector(args, **uvicorn_kwargs):
-    logger.info("vLLM Connector Start: %s %s", args, uvicorn_kwargs)
-
-    # NOTE FOR DEVELOPERS: when we shift this to TCP, we must
-    # ensure that the serialization is not pickle based to
-    # avoid RCE issues from untrusted users!!!
-    app.state.port = args.port
-    app.state.connector_addr = f"ipc://{args.connector_addr}"
-    app.state.decode_addr = f"ipc://{args.decode_addr}"
-    app.state.prefill_addr = f"ipc://{args.prefill_addr}"
-
-    # init uvicorn server
-    config = uvicorn.Config(app, host=args.post, port=args.port)
-    server = uvicorn.Server(config)
-    await server.serve()
 
 
 if __name__ == "__main__":
diff --git a/vllm/entrypoints/openai/zmq_server.py b/vllm/entrypoints/disaggregated/worker_server.py
similarity index 100%
rename from vllm/entrypoints/openai/zmq_server.py
rename to vllm/entrypoints/disaggregated/worker_server.py
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index e14cf97c98594..41c4dd32442ca 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -36,7 +36,6 @@
 from vllm.engine.multiprocessing.client import MQLLMEngineClient
 from vllm.engine.multiprocessing.engine import run_mp_engine
 from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.chat_utils import load_chat_template
-from vllm.entrypoints.disaggregated.engine import build_pd_engine_client
 from vllm.entrypoints.launcher import serve_http
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.cli_args import (make_arg_parser,
@@ -135,30 +134,18 @@ async def lifespan(app: FastAPI):
 async def build_async_engine_client(
         args: Namespace) -> AsyncIterator[EngineClient]:
 
-    # Case 1: We are running a P/D Connector.
-    if hasattr(args, "connector_addr"):
-        async with build_pd_engine_client(
-                prefill_addr=args.prefill_addr,
-                decode_addr=args.decode_addr,
-                connector_addr=args.connector_addr) as engine:
-            yield engine
-            engine.shutdown()
-
-    # Case 2: We are running a normal instance of vLLM.
-    else:
-        # Context manager to handle engine_client lifecycle
-        # Ensures everything is shutdown and cleaned up on error/exit
-        engine_args = AsyncEngineArgs.from_cli_args(args)
-        async with build_async_engine_client_from_engine_args(
-                engine_args, args.disable_frontend_multiprocessing) as engine:
-            yield engine
+    # Context manager to handle engine_client lifecycle.
+    # Ensures everything is shutdown and cleaned up on error/exit.
+    engine_args = AsyncEngineArgs.from_cli_args(args)
+    async with build_async_engine_client_from_engine_args(
+            engine_args, args.disable_frontend_multiprocessing) as engine:
+        yield engine
 
 
 @asynccontextmanager
 async def build_async_engine_client_from_engine_args(
     engine_args: AsyncEngineArgs,
     disable_frontend_multiprocessing: bool = False,
-    deploy_disagg_connector: bool = False,
 ) -> AsyncIterator[EngineClient]:
     """
     Create EngineClient, either:
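With the P/D special case carved out of the OpenAI api_server, the disaggregated path is now driven only through the new entrypoint. For completeness, a hedged sketch of using PDEngine directly as an EngineClient; the ipc paths and model are made up, prefill/decode workers are assumed to already be bound to those sockets, and the generate() signature is inferred from the argument checks shown in this diff:

```python
# Hypothetical direct driver for PDEngine. Assumes prefill/decode workers
# are already serving on these made-up ipc endpoints; the generate()
# signature is inferred from the checks in this diff, not confirmed.
import asyncio

from vllm.entrypoints.disaggregated.pd_engine import PDEngine
from vllm.sampling_params import SamplingParams


async def main():
    engine = PDEngine(prefill_addr="ipc:///tmp/prefill.sock",
                      decode_addr="ipc:///tmp/decode.sock",
                      connector_addr="ipc:///tmp/connector.sock",
                      model_name="facebook/opt-125m")
    try:
        # First response comes from the prefill worker (max_tokens forced
        # to 1), the rest stream from the decode worker.
        async for response in engine.generate(
                prompt="Hello, disaggregated world!",
                sampling_params=SamplingParams(max_tokens=16),
                request_id="req-0"):
            print(response)
    finally:
        engine.shutdown()


asyncio.run(main())
```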