mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-29 07:57:08 +08:00
updated
Signed-off-by: Robert Shaw <rshaw@neuralmagic.com>
This commit is contained in:
parent
2ceb7bc534
commit
120bbdfd82
122
vllm/entrypoints/disaggregated/api_server.py
Normal file
122
vllm/entrypoints/disaggregated/api_server.py
Normal file
@ -0,0 +1,122 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
import uvicorn
|
||||||
|
import uvloop
|
||||||
|
|
||||||
|
from collections.abc import AsyncIterator
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from fastapi import FastAPI, Request
|
||||||
|
from fastapi.responses import JSONResponse, StreamingResponse
|
||||||
|
|
||||||
|
from vllm.entrypoints.disaggregated.pd_engine import PDEngine
|
||||||
|
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
|
||||||
|
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
|
||||||
|
OpenAIServingModels)
|
||||||
|
from vllm.entrypoints.openai.protocol import CompletionRequest
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.utils import FlexibleArgumentParser
|
||||||
|
from vllm.entrypoints.openai.protocol import (
|
||||||
|
CompletionResponse, ErrorResponse)
|
||||||
|
|
||||||
|
# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
|
||||||
|
logger = init_logger('vllm.entrypoints.disaggregated.api_server')
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
@app.get("/v1/models")
|
||||||
|
async def show_available_models(raw_request: Request):
|
||||||
|
handler: OpenAIServingModels = raw_request.app.state.openai_serving_models
|
||||||
|
models_ = await handler.show_available_models()
|
||||||
|
return JSONResponse(content=models_.model_dump())
|
||||||
|
|
||||||
|
@app.post("/v1/completions")
|
||||||
|
async def create_completion(request: CompletionRequest, raw_request: Request):
|
||||||
|
handler: OpenAIServingCompletion = raw_request.app.state.openai_serving_completion
|
||||||
|
generator = await handler.create_completion(request, raw_request)
|
||||||
|
if isinstance(generator, ErrorResponse):
|
||||||
|
return JSONResponse(content=generator.model_dump(),
|
||||||
|
status_code=generator.code)
|
||||||
|
elif isinstance(generator, CompletionResponse):
|
||||||
|
return JSONResponse(content=generator.model_dump())
|
||||||
|
|
||||||
|
return StreamingResponse(content=generator, media_type="text/event-stream")
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def pd_engine_client_ctx_manager(
|
||||||
|
model_name: str,
|
||||||
|
prefill_addr: str,
|
||||||
|
decode_addr: str,
|
||||||
|
connector_addr: str) -> AsyncIterator[PDEngine]:
|
||||||
|
engine = PDEngine(model_name, prefill_addr, decode_addr, connector_addr)
|
||||||
|
yield engine
|
||||||
|
engine.shutdown()
|
||||||
|
|
||||||
|
async def main(args, **uvicorn_kwargs):
|
||||||
|
logger.info("vLLM Disaggregate Connector Start %s %s", args,
|
||||||
|
uvicorn_kwargs)
|
||||||
|
|
||||||
|
prefill_addr = f"ipc://{args.prefill_addr}"
|
||||||
|
decode_addr = f"ipc://{args.decode_addr}"
|
||||||
|
connector_addr = f"ipc://{args.connector_addr}"
|
||||||
|
|
||||||
|
with pd_engine_client_ctx_manager(
|
||||||
|
args.model, prefill_addr, decode_addr, connector_addr) as engine_client:
|
||||||
|
|
||||||
|
model_config = await engine_client.get_model_config()
|
||||||
|
|
||||||
|
# Models.
|
||||||
|
app.state.openai_serving_models = OpenAIServingModels(
|
||||||
|
engine_client=engine_client,
|
||||||
|
model_config=model_config,
|
||||||
|
base_model_paths=[BaseModelPath(
|
||||||
|
name=args.served_model_name or args.model,
|
||||||
|
model_path=args.model)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Completions.
|
||||||
|
app.state.openai_serving_completion = OpenAIServingCompletion(
|
||||||
|
engine_client=engine_client,
|
||||||
|
model_config=model_config,
|
||||||
|
models=app.state.openai_serving_models,
|
||||||
|
request_logger=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Init Uvicorn Server. Server.
|
||||||
|
config = uvicorn.Config(app, host="0.0.0.0", port=args.port)
|
||||||
|
server = uvicorn.Server(config)
|
||||||
|
await server.serve()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = FlexibleArgumentParser(
|
||||||
|
description="vLLM OpenAI-Compatible P/D Server.")
|
||||||
|
parser.add_argument("--host",
|
||||||
|
type=str,
|
||||||
|
default="0.0.0.0",
|
||||||
|
help="The host of the HTTP server.")
|
||||||
|
parser.add_argument("--port",
|
||||||
|
type=int,
|
||||||
|
default=8001,
|
||||||
|
help="The port of the HTTP server.")
|
||||||
|
parser.add_argument("--model",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="The path to the model.")
|
||||||
|
parser.add_argument("--served-model-name",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="The served name of the model.")
|
||||||
|
parser.add_argument("--connector-addr",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="The zmq ipc connector address")
|
||||||
|
parser.add_argument("--prefill-addr",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="The zmq ipc prefill address")
|
||||||
|
parser.add_argument("--decode-addr",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="The zmq ipc decode address")
|
||||||
|
args = parser.parse_args()
|
||||||
|
uvloop.run(main(args))
|
||||||
@ -3,10 +3,8 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import msgspec
|
import msgspec
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
from contextlib import asynccontextmanager
|
from typing import Dict, List, Mapping, Optional
|
||||||
from typing import Dict, Mapping, Optional
|
|
||||||
|
|
||||||
import uvicorn
|
|
||||||
import uvloop
|
import uvloop
|
||||||
import zmq
|
import zmq
|
||||||
import zmq.asyncio
|
import zmq.asyncio
|
||||||
@ -27,6 +25,7 @@ from vllm.pooling_params import PoolingParams
|
|||||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||||
|
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
|
||||||
from vllm.utils import Device, FlexibleArgumentParser, make_zmq_socket
|
from vllm.utils import Device, FlexibleArgumentParser, make_zmq_socket
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@ -43,7 +42,7 @@ class PDRequest(msgspec.Struct,
|
|||||||
omit_defaults=True, # type: ignore[call-arg]
|
omit_defaults=True, # type: ignore[call-arg]
|
||||||
gc=False): # type: ignore[call-arg]
|
gc=False): # type: ignore[call-arg]
|
||||||
request_id: str
|
request_id: str
|
||||||
prompt: str
|
prompt_token_ids: List[int]
|
||||||
sampling_params: SamplingParams
|
sampling_params: SamplingParams
|
||||||
# TODO: support multimodal inputs.
|
# TODO: support multimodal inputs.
|
||||||
|
|
||||||
@ -58,12 +57,6 @@ class PDResponse(msgspec.Struct,
|
|||||||
stop_reason: Optional[str] = None
|
stop_reason: Optional[str] = None
|
||||||
logprobs = None # TODO
|
logprobs = None # TODO
|
||||||
|
|
||||||
@asynccontextmanager
|
|
||||||
def build_pd_engine_client(prefill_addr: str, decode_addr: str,
|
|
||||||
connector_addr: str):
|
|
||||||
engine = PDEngine(prefill_addr, decode_addr, connector_addr)
|
|
||||||
yield engine
|
|
||||||
engine.shutdown()
|
|
||||||
|
|
||||||
class PDEngine:
|
class PDEngine:
|
||||||
"""
|
"""
|
||||||
@ -76,7 +69,13 @@ class PDEngine:
|
|||||||
* TODO: move under vllm/v1/engine one past prototype.
|
* TODO: move under vllm/v1/engine one past prototype.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, prefill_addr: str, decode_addr: str, connector_addr: str):
|
def __init__(
|
||||||
|
self,
|
||||||
|
prefill_addr: str,
|
||||||
|
decode_addr: str,
|
||||||
|
connector_addr: str,
|
||||||
|
model_name: str
|
||||||
|
):
|
||||||
# Request queues.
|
# Request queues.
|
||||||
self.queues: Dict[str, asyncio.Queue] = {}
|
self.queues: Dict[str, asyncio.Queue] = {}
|
||||||
|
|
||||||
@ -95,6 +94,32 @@ class PDEngine:
|
|||||||
self.output_handler: Optional[asyncio.Task] = None
|
self.output_handler: Optional[asyncio.Task] = None
|
||||||
self.log_running: Optional[asyncio.Task] = None
|
self.log_running: Optional[asyncio.Task] = None
|
||||||
|
|
||||||
|
# Dummy: needed for EngineClient Protocol.
|
||||||
|
# TODO: refactor EngineClient to avoid needing this.
|
||||||
|
self.model_config = ModelConfig(
|
||||||
|
model=model_name,
|
||||||
|
tokenizer=model_name,
|
||||||
|
tokenizer_mode="auto",
|
||||||
|
trust_remote_code=False,
|
||||||
|
dtype="auto",
|
||||||
|
seed=42
|
||||||
|
)
|
||||||
|
|
||||||
|
# Dummy: needed for EngineClient Protocol.
|
||||||
|
# TODO: refactor EngineClient to avoid needing this.
|
||||||
|
init_kwargs = dict(
|
||||||
|
tokenizer_id=self.model_config.tokenizer,
|
||||||
|
enable_lora=False,
|
||||||
|
max_num_seqs=1024,
|
||||||
|
max_loras=0,
|
||||||
|
max_input_length=None,
|
||||||
|
tokenizer_mode=self.model_config.tokenizer_mode,
|
||||||
|
trust_remote_code=self.model_config.trust_remote_code,
|
||||||
|
revision=self.model_config.tokenizer_revision,
|
||||||
|
truncation_side=self.model_config.truncation_side)
|
||||||
|
self.tokenizer = TokenizerGroup(**init_kwargs)
|
||||||
|
|
||||||
|
|
||||||
def shutdown(self):
|
def shutdown(self):
|
||||||
if (ctx := self.ctx) is not None:
|
if (ctx := self.ctx) is not None:
|
||||||
ctx.destroy(linger=0)
|
ctx.destroy(linger=0)
|
||||||
@ -178,9 +203,22 @@ class PDEngine:
|
|||||||
self.output_handler = asyncio.create_task(self._run_output_handler())
|
self.output_handler = asyncio.create_task(self._run_output_handler())
|
||||||
self.log_running = asyncio.create_task(self._run_log_running())
|
self.log_running = asyncio.create_task(self._run_log_running())
|
||||||
|
|
||||||
# TODO: expand to suppo
|
# TODO: Expand to support the full matrix.
|
||||||
if not isinstance(prompt, str):
|
if not isinstance(prompt, str):
|
||||||
raise ValueError("We currently only support text inputs!")
|
raise NotImplementedError(
|
||||||
|
"We currently only support text for P/D!")
|
||||||
|
if lora_request is not None:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"We currently do not suppport LoRA for P/D!")
|
||||||
|
if trace_headers is not None:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"We currently do not suppport tracing for P/D!")
|
||||||
|
if prompt_adapter_request is not None:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"We currently do not suppport prompt adapter for P/D!")
|
||||||
|
if priority != 0:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"We currently do not support priority for P/D!")
|
||||||
if request_id in self.queues:
|
if request_id in self.queues:
|
||||||
raise ValueError(f"Found duplicate request_id: {request_id}!")
|
raise ValueError(f"Found duplicate request_id: {request_id}!")
|
||||||
|
|
||||||
@ -188,14 +226,14 @@ class PDEngine:
|
|||||||
q: asyncio.Queue[PDResponse] = asyncio.Queue()
|
q: asyncio.Queue[PDResponse] = asyncio.Queue()
|
||||||
self.queues[request_id] = q
|
self.queues[request_id] = q
|
||||||
|
|
||||||
# (1) Perform the prefill (max_tokens=1).
|
# (1) Perform the Prefill.
|
||||||
original_max_tokens = sampling_params.max_tokens
|
original_max_tokens = sampling_params.max_tokens
|
||||||
request = PDRequest(request_id, prompt, sampling_params)
|
request = PDRequest(request_id, prompt, sampling_params)
|
||||||
request.sampling_params.max_tokens = 1
|
request.sampling_params.max_tokens = 1
|
||||||
response = await self._prefill(request, q)
|
response = await self._prefill(request, q)
|
||||||
yield response
|
yield response
|
||||||
|
|
||||||
# (2) Perform the decodes (original tokens).
|
# (2) Perform the Decodes.
|
||||||
request.sampling_params.max_tokens = original_max_tokens
|
request.sampling_params.max_tokens = original_max_tokens
|
||||||
async for response in self._decode(request, q):
|
async for response in self._decode(request, q):
|
||||||
yield response
|
yield response
|
||||||
@ -223,7 +261,7 @@ class PDEngine:
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
async def get_model_config(self) -> ModelConfig:
|
async def get_model_config(self) -> ModelConfig:
|
||||||
raise NotImplementedError
|
return self.model_config
|
||||||
|
|
||||||
async def get_decoding_config(self) -> DecodingConfig:
|
async def get_decoding_config(self) -> DecodingConfig:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
@ -235,7 +273,10 @@ class PDEngine:
|
|||||||
self,
|
self,
|
||||||
lora_request: Optional[LoRARequest] = None,
|
lora_request: Optional[LoRARequest] = None,
|
||||||
) -> AnyTokenizer:
|
) -> AnyTokenizer:
|
||||||
raise NotImplementedError
|
if lora_request is not None:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"LoRA is not yet supported in the PDEngine.")
|
||||||
|
return self.tokenizer.get_lora_tokenizer(None)
|
||||||
|
|
||||||
async def is_tracing_enabled(self) -> bool:
|
async def is_tracing_enabled(self) -> bool:
|
||||||
return False
|
return False
|
||||||
@ -273,23 +314,6 @@ class PDEngine:
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
async def run_disagg_connector(args, **uvicorn_kwargs):
|
|
||||||
logger.info("vLLM Connector Start: %s %s", args, uvicorn_kwargs)
|
|
||||||
|
|
||||||
# NOTE FOR DEVELOPERS: when we shift this to TCP, we must
|
|
||||||
# ensure that the serialization is not pickle based to
|
|
||||||
# avoid RCE issues from untrusted users!!!
|
|
||||||
app.state.port = args.port
|
|
||||||
app.state.connector_addr = f"ipc://{args.connector_addr}"
|
|
||||||
app.state.decode_addr = f"ipc://{args.decode_addr}"
|
|
||||||
app.state.prefill_addr = f"ipc://{args.prefill_addr}"
|
|
||||||
|
|
||||||
# init uvicorn server
|
|
||||||
config = uvicorn.Config(app, host=args.post, port=args.port)
|
|
||||||
server = uvicorn.Server(config)
|
|
||||||
await server.serve()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = FlexibleArgumentParser(
|
parser = FlexibleArgumentParser(
|
||||||
description="vLLM OpenAI-Compatible RESTful API server.")
|
description="vLLM OpenAI-Compatible RESTful API server.")
|
||||||
@ -36,7 +36,6 @@ from vllm.engine.multiprocessing.client import MQLLMEngineClient
|
|||||||
from vllm.engine.multiprocessing.engine import run_mp_engine
|
from vllm.engine.multiprocessing.engine import run_mp_engine
|
||||||
from vllm.engine.protocol import EngineClient
|
from vllm.engine.protocol import EngineClient
|
||||||
from vllm.entrypoints.chat_utils import load_chat_template
|
from vllm.entrypoints.chat_utils import load_chat_template
|
||||||
from vllm.entrypoints.disaggregated.engine import build_pd_engine_client
|
|
||||||
from vllm.entrypoints.launcher import serve_http
|
from vllm.entrypoints.launcher import serve_http
|
||||||
from vllm.entrypoints.logger import RequestLogger
|
from vllm.entrypoints.logger import RequestLogger
|
||||||
from vllm.entrypoints.openai.cli_args import (make_arg_parser,
|
from vllm.entrypoints.openai.cli_args import (make_arg_parser,
|
||||||
@ -135,30 +134,18 @@ async def lifespan(app: FastAPI):
|
|||||||
async def build_async_engine_client(
|
async def build_async_engine_client(
|
||||||
args: Namespace) -> AsyncIterator[EngineClient]:
|
args: Namespace) -> AsyncIterator[EngineClient]:
|
||||||
|
|
||||||
# Case 1: We are running a P/D Connector.
|
# Context manager to handle engine_client lifecycle
|
||||||
if hasattr(args, "connector_addr"):
|
# Ensures everything is shutdown and cleaned up on error/exit
|
||||||
async with build_pd_engine_client(
|
engine_args = AsyncEngineArgs.from_cli_args(args)
|
||||||
prefill_addr=args.prefill_addr,
|
async with build_async_engine_client_from_engine_args(
|
||||||
decode_addr=args.decode_addr,
|
engine_args, args.disable_frontend_multiprocessing) as engine:
|
||||||
connector_addr=args.connector_addr) as engine:
|
yield engine
|
||||||
yield engine
|
|
||||||
engine.shutdown()
|
|
||||||
|
|
||||||
# Case 2: We are running a normal instance of vLLM.
|
|
||||||
else:
|
|
||||||
# Context manager to handle engine_client lifecycle
|
|
||||||
# Ensures everything is shutdown and cleaned up on error/exit
|
|
||||||
engine_args = AsyncEngineArgs.from_cli_args(args)
|
|
||||||
async with build_async_engine_client_from_engine_args(
|
|
||||||
engine_args, args.disable_frontend_multiprocessing) as engine:
|
|
||||||
yield engine
|
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def build_async_engine_client_from_engine_args(
|
async def build_async_engine_client_from_engine_args(
|
||||||
engine_args: AsyncEngineArgs,
|
engine_args: AsyncEngineArgs,
|
||||||
disable_frontend_multiprocessing: bool = False,
|
disable_frontend_multiprocessing: bool = False,
|
||||||
deploy_disagg_connector: bool = False,
|
|
||||||
) -> AsyncIterator[EngineClient]:
|
) -> AsyncIterator[EngineClient]:
|
||||||
"""
|
"""
|
||||||
Create EngineClient, either:
|
Create EngineClient, either:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user