mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-27 18:07:07 +08:00
updated
Signed-off-by: Robert Shaw <rshaw@neuralmagic.com>
This commit is contained in:
parent
b89d89f456
commit
a8a621e419
@ -1,283 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import signal
|
|
||||||
import sys
|
|
||||||
import traceback
|
|
||||||
import uuid
|
|
||||||
from collections.abc import AsyncGenerator
|
|
||||||
from contextlib import asynccontextmanager
|
|
||||||
from http import HTTPStatus
|
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
import uvicorn
|
|
||||||
import uvloop
|
|
||||||
import zmq
|
|
||||||
import zmq.asyncio
|
|
||||||
from fastapi import BackgroundTasks, FastAPI, Request
|
|
||||||
from fastapi.responses import JSONResponse, StreamingResponse
|
|
||||||
|
|
||||||
from vllm.entrypoints.openai.protocol import (CompletionRequest, ZmqMsgRequest,
|
|
||||||
ZmqMsgResponse)
|
|
||||||
from vllm.logger import init_logger
|
|
||||||
from vllm.utils import FlexibleArgumentParser
|
|
||||||
|
|
||||||
# The logger name is spelled out on purpose: __name__ cannot be used here
# (see https://github.com/vllm-project/vllm/pull/4765).
logger = init_logger('vllm.entrypoints.disagg_connector')

# Seconds to wait for each reply taken from a per-request queue.
TIME_OUT = 5
# HTTP header that carries the request id.
X_REQUEST_ID_KEY = "X-Request-Id"
# Media type set on streamed responses.
CONTENT_TYPE_STREAM = "text/event-stream"

# Per-request reply queues keyed by request id: the socket receive
# handlers put replies in, execute_task_async takes them out. Only the
# annotation lives here; the dict itself is created in lifespan().
request_queues: dict[str, asyncio.Queue]
|
|
||||||
|
|
||||||
|
|
||||||
async def log_stats(request_queues: dict[str, asyncio.Queue]):
    """Log how many requests are being tracked, once every ten seconds.

    Args:
        request_queues: Mapping of request id to reply queue; only its
            size is read.

    This coroutine never returns on its own; it ends only by cancellation.
    """
    report_interval = 10
    while True:
        in_flight = len(request_queues)
        logger.info("Running requests: %d", in_flight)
        await asyncio.sleep(report_interval)
|
|
||||||
|
|
||||||
|
|
||||||
# create async socket use ZMQ_DEALER
|
|
||||||
async def create_socket(url: str,
                        zmqctx: zmq.asyncio.Context) -> zmq.asyncio.Socket:
    """Create a DEALER socket with a unique identity and connect it.

    Args:
        url: Endpoint passed to ``connect``.
        zmqctx: Context the socket is created from.

    Returns:
        The connected socket.
    """
    identity = f"connector-{uuid.uuid4()}"
    dealer = zmqctx.socket(zmq.DEALER)
    dealer.setsockopt(zmq.IDENTITY, identity.encode())
    # A high-water mark of 0 means unlimited, for both directions.
    for hwm_option in (zmq.SNDHWM, zmq.RCVHWM):
        dealer.setsockopt(hwm_option, 0)
    dealer.connect(url)
    logger.info("%s started at %s", identity, url)
    return dealer
|
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Own the connector's ZMQ resources for the lifetime of the app.

    Startup: create the ZMQ context, connect the prefill and decode sockets
    (addresses are read from ``app.state``), reset the global
    ``request_queues`` dict and start the two receive handlers plus the
    stats logger.

    Shutdown (also reached when startup fails part-way): cancel the
    background tasks, wait for them to finish, then destroy the context.
    """
    global request_queues

    logger.info("start connect zmq server")
    app.state.zmqctx = zmq.asyncio.Context()
    background_tasks: list[asyncio.Task] = []
    try:
        # create socket pool with prefill and decode
        app.state.prefill_socket = await create_socket(
            app.state.prefill_addr, zmqctx=app.state.zmqctx)
        logger.info("success create_socket sockets_prefill")
        app.state.decode_socket = await create_socket(
            app.state.decode_addr, zmqctx=app.state.zmqctx)
        logger.info("success create_socket sockets_decode")

        request_queues = {}
        # The event loop only keeps weak references to tasks, so hold
        # strong ones here; they are also needed to cancel on shutdown.
        background_tasks = [
            asyncio.create_task(prefill_handler(app.state.prefill_socket)),
            asyncio.create_task(decode_handler(app.state.decode_socket)),
            asyncio.create_task(log_stats(request_queues)),
        ]
        yield
    finally:
        logger.info("shutdown disagg connector")
        # Stop the loops before their sockets are torn down.
        for task in background_tasks:
            task.cancel()
        await asyncio.gather(*background_tasks, return_exceptions=True)
        ## close zmq context
        logger.info("term zmqctx")
        app.state.zmqctx.destroy(linger=0)


app = FastAPI(lifespan=lifespan)
|
|
||||||
|
|
||||||
|
|
||||||
@app.post('/v1/completions')
async def completions(request: CompletionRequest, raw_request: Request,
                      background_tasks: BackgroundTasks):
    """Serve a completion by running prefill first and decode second.

    The request is wrapped in a ``ZmqMsgRequest`` and sent to the prefill
    socket with ``max_tokens`` forced to 1, then sent to the decode socket
    with the caller's original ``max_tokens``.

    Args:
        request: Parsed completion request body.
        raw_request: Underlying HTTP request; its headers supply the
            optional ``X-Request-Id``.
        background_tasks: Used to schedule removal of the request's reply
            queue after the response has been sent.

    Returns:
        The response produced by ``decode``; the prefill response when it
        is a non-OK ``JSONResponse``; or a 500 ``JSONResponse`` carrying
        the message of the exception that was raised.
    """
    # Bound before the try block so the finally clause can always read it.
    request_id = None
    try:
        header = dict(raw_request.headers)
        # Look the id up on the Headers object itself: that lookup is
        # case-insensitive, while the plain dict copy only holds
        # lower-cased keys and would never match "X-Request-Id".
        request_id = raw_request.headers.get(X_REQUEST_ID_KEY)
        if request_id is None:
            request_id = str(uuid.uuid4())
            logger.debug("add X-Request-Id: %s", request_id)
        logger.debug("X-Request-Id is: %s", request_id)

        queue: asyncio.Queue[ZmqMsgResponse] = asyncio.Queue()
        request_queues[request_id] = queue
        zmq_msg_request = ZmqMsgRequest(request_id=request_id,
                                        type="completions",
                                        body=request)
        # NOTE(review): this logs every request header at INFO level,
        # which may include credentials -- confirm that is intended.
        logger.info("Received request_id: %s, request: %s, header: %s",
                    request_id, zmq_msg_request.model_dump_json(), header)

        original_max_tokens = request.max_tokens
        # change max_tokens = 1 to let it only do prefill
        request.max_tokens = 1
        try:
            prefill_response = await prefill(zmq_msg_request)
            if (isinstance(prefill_response, JSONResponse)
                    and prefill_response.status_code != HTTPStatus.OK):
                return prefill_response
            logger.debug("finish prefill start decode")
            request.max_tokens = original_max_tokens
            response = await decode(zmq_msg_request)
            logger.debug("finish decode")
        except Exception as e:
            logger.error("Error occurred in disagg prefill proxy server, %s",
                         e)
            response = JSONResponse(
                {"error": {
                    "message": str(e)
                }},
                status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
        return response

    except Exception as e:
        logger.error("Error occurred in disagg prefill proxy server")
        logger.error(e)
        logger.error(traceback.format_exc())
        return JSONResponse({"error": {
            "message": str(e)
        }}, HTTPStatus.INTERNAL_SERVER_ERROR)
    finally:
        # Runs on every exit path; the queue is dropped only after the
        # response (including a streamed one) has been sent.
        if request_id is not None:
            background_tasks.add_task(cleanup_request_id, request_id)
|
|
||||||
|
|
||||||
|
|
||||||
async def socket_recv_handler(socket: zmq.asyncio.Socket, scene: str):
    """Receive replies from ``socket`` and route them to request queues.

    Each single-frame message is parsed as a ``ZmqMsgResponse`` and put on
    the queue registered under its ``request_id``; replies for unknown ids
    are discarded. Errors are logged and the loop keeps going, unless the
    socket has been closed, in which case the handler returns.

    Args:
        socket: Socket to read from.
        scene: Label used in log messages (e.g. "prefill" or "decode").
    """
    while True:
        try:
            [body] = await socket.recv_multipart()
            response = ZmqMsgResponse.model_validate_json(body)
            request_id = response.request_id
            logger.debug("%s socket received result: %s", scene,
                         response.model_dump_json())
            queue = request_queues.get(request_id)
            if queue is None:
                logger.debug(
                    "%s socket received but request_id not found discard: %s",
                    scene, request_id)
            else:
                queue.put_nowait(response)
        except Exception as e:
            logger.error(traceback.format_exc())
            logger.error("%s handler error: %s", scene, e)
            if socket.closed:
                # Every further recv on a closed socket fails at once;
                # stop instead of spinning in an error-logging loop.
                return
|
|
||||||
|
|
||||||
|
|
||||||
# prefill handler
|
|
||||||
async def prefill_handler(prefill_socket: zmq.asyncio.Socket):
    """Run the shared receive loop on the prefill socket."""
    receive_loop = socket_recv_handler(prefill_socket, "prefill")
    await receive_loop
|
|
||||||
|
|
||||||
|
|
||||||
# decode handler
|
|
||||||
async def decode_handler(decode_socket: zmq.asyncio.Socket):
    """Run the shared receive loop on the decode socket."""
    receive_loop = socket_recv_handler(decode_socket, "decode")
    await receive_loop
|
|
||||||
|
|
||||||
|
|
||||||
# select a socket and execute task
|
|
||||||
async def execute_task_async(zmq_msg_request: ZmqMsgRequest,
                             socket: zmq.asyncio.Socket):
    """Send a request over ``socket`` and yield its replies as they arrive.

    Replies are taken from the queue registered in ``request_queues`` under
    the request's id. Iteration ends after a reply whose ``stop`` is truthy.

    Args:
        zmq_msg_request: Request to serialize and send.
        socket: Socket the request is sent on.

    Yields:
        Each ``ZmqMsgResponse`` received. If no reply arrives within
        ``TIME_OUT`` seconds, a ``JSONResponse("timeout")`` with status 408
        is yielded instead and iteration ends.
    """
    # Read before the try block so the finally clause can always log it.
    request_id = zmq_msg_request.request_id
    try:
        request_body = zmq_msg_request.model_dump_json()
        logger.debug("Sending requestBody: %s", request_body)
        # Sends on a zmq.asyncio socket return an awaitable; await it so
        # that a failed send is raised here instead of being dropped.
        await socket.send_multipart([request_body.encode()])
        logger.debug("Sent end")
        queue = request_queues[request_id]
        while True:
            logger.debug("Waiting for reply")
            zmq_msg_response: ZmqMsgResponse = await asyncio.wait_for(
                queue.get(), TIME_OUT)
            logger.debug("Received result: %s",
                         zmq_msg_response.model_dump_json())
            yield zmq_msg_response
            if zmq_msg_response.stop:
                logger.debug("Received stop: %s", zmq_msg_response.stop)
                break
    except asyncio.TimeoutError:
        logger.error(traceback.format_exc())
        yield JSONResponse("timeout", HTTPStatus.REQUEST_TIMEOUT)
    finally:
        logger.debug("request_id: %s, execute_task_async end", request_id)
|
|
||||||
|
|
||||||
|
|
||||||
async def prefill(zmq_msg_request: ZmqMsgRequest) -> Union[JSONResponse, bool]:
    """Run the request on the prefill socket and wait for it to finish.

    Returns:
        A ``JSONResponse`` as soon as a reply with ``body_type ==
        "response"`` arrives, or the timeout response yielded by
        ``execute_task_async``; ``True`` when the reply stream ends
        without either.
    """
    logger.debug("start prefill")
    generator = execute_task_async(zmq_msg_request, app.state.prefill_socket)
    async for res in generator:
        logger.debug("res: %s", res)
        # On timeout the generator yields a ready-made JSONResponse, which
        # has no ``body_type``; return it instead of raising AttributeError.
        if isinstance(res, JSONResponse):
            return res
        if res.body_type == "response":
            return JSONResponse(res.body)
    return True
|
|
||||||
|
|
||||||
|
|
||||||
async def generate_stream_response(
        fisrt_reply: str,
        generator: AsyncGenerator[ZmqMsgResponse]) -> AsyncGenerator[str]:
    """Yield ``fisrt_reply``, then the ``body`` of each remaining reply.

    Args:
        fisrt_reply: Chunk to emit before reading from ``generator``.
        generator: Source of the remaining replies.
    """
    yield fisrt_reply
    async for message in generator:
        chunk = message.body
        yield chunk
|
|
||||||
|
|
||||||
|
|
||||||
async def decode(
    zmq_msg_request: ZmqMsgRequest
) -> Union[JSONResponse, StreamingResponse]:
    """Run the request on the decode socket and build the HTTP response.

    The first reply decides the response type: ``body_type == "response"``
    gives a ``JSONResponse``; anything else starts a ``StreamingResponse``
    that emits the first body and then the bodies of the remaining
    replies. The timeout response yielded by ``execute_task_async`` is
    returned as is. If no reply is produced at all, a 500 ``JSONResponse``
    is returned.
    """
    logger.debug("start decode")
    generator = execute_task_async(zmq_msg_request, app.state.decode_socket)

    async for res in generator:
        logger.debug("res: %s", res)
        # On timeout the generator yields a ready-made JSONResponse, which
        # has no ``body_type``; return it instead of raising AttributeError.
        if isinstance(res, JSONResponse):
            return res
        if res.body_type == "response":
            return JSONResponse(res.body)
        # Hand the partially consumed generator to the streaming body.
        return StreamingResponse(generate_stream_response(
            res.body, generator),
                                 media_type=CONTENT_TYPE_STREAM)

    # If the generator is empty, return a default error response
    logger.error("No response received from generator")
    return JSONResponse({"error": "No response received from generator"},
                        status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
|
|
||||||
|
|
||||||
|
|
||||||
def cleanup_request_id(request_id: str):
    """Drop the reply queue registered for ``request_id``, if there is one.

    Does nothing (and logs nothing) when the id is not registered.
    """
    if request_id not in request_queues:
        return
    logger.info("del request_id: %s, decode finished", request_id)
    request_queues.pop(request_id)
|
|
||||||
|
|
||||||
|
|
||||||
async def run_disagg_connector(args, **uvicorn_kwargs):
    """Configure ``app.state`` from ``args`` and serve the app with uvicorn.

    Args:
        args: Parsed CLI arguments; ``port``, ``prefill_addr`` and
            ``decode_addr`` are read.
        **uvicorn_kwargs: Only logged here.
    """
    logger.info("vLLM Disaggregate Connector start %s %s", args,
                uvicorn_kwargs)
    logger.info(args.prefill_addr)

    state = app.state
    state.port = args.port
    state.prefill_addr = f"ipc://{args.prefill_addr}"
    state.decode_addr = f"ipc://{args.decode_addr}"
    logger.info(
        "start connect prefill_addr: %s decode_addr: %s "
        "zmq server fastapi port: %s", state.prefill_addr,
        state.decode_addr, state.port)

    def _raise_keyboard_interrupt(*_) -> None:
        # Interrupt server on sigterm while initializing
        raise KeyboardInterrupt("terminated")

    signal.signal(signal.SIGTERM, _raise_keyboard_interrupt)

    # init uvicorn server
    server = uvicorn.Server(
        uvicorn.Config(app, host="0.0.0.0", port=state.port))
    await server.serve()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
    # NOTE(simon):
    # Keep this section in sync with vllm/entrypoints/cli/connect.py, which
    # provides the CLI entrypoint.
    parser = FlexibleArgumentParser(description="vLLM disagg connect server.")
    parser.add_argument("--port",
                        type=int,
                        default=8001,
                        help="The fastapi server port default 8001")
    # security concern only support ipc now
    for flag, description in (
        ("--prefill-addr", "The zmq ipc prefill address"),
        ("--decode-addr", "The zmq ipc decode address"),
    ):
        parser.add_argument(flag, type=str, required=True, help=description)

    args = parser.parse_args()

    uvloop.run(run_disagg_connector(args))
|
|
||||||
0
vllm/entrypoints/disaggregated/__init__.py
Normal file
0
vllm/entrypoints/disaggregated/__init__.py
Normal file
305
vllm/entrypoints/disaggregated/connector.py
Normal file
305
vllm/entrypoints/disaggregated/connector.py
Normal file
@ -0,0 +1,305 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import msgspec
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
from typing import Dict, Mapping, Optional
|
||||||
|
|
||||||
|
import uvicorn
|
||||||
|
import uvloop
|
||||||
|
import zmq
|
||||||
|
import zmq.asyncio
|
||||||
|
|
||||||
|
from vllm import SamplingParams
|
||||||
|
from vllm.config import DecodingConfig, ModelConfig
|
||||||
|
from vllm.core.scheduler import SchedulerOutputs
|
||||||
|
from vllm.entrypoints.openai.api_server import run_server
|
||||||
|
from vllm.entrypoints.openai.cli_args import (make_arg_parser,
|
||||||
|
validate_parsed_serve_args)
|
||||||
|
from vllm.inputs.data import PromptType
|
||||||
|
from vllm.inputs.preprocess import InputPreprocessor
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.lora.request import LoRARequest
|
||||||
|
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||||
|
from vllm.outputs import PoolingRequestOutput, RequestOutput
|
||||||
|
from vllm.pooling_params import PoolingParams
|
||||||
|
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||||
|
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||||
|
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||||
|
from vllm.utils import Device, FlexibleArgumentParser, make_zmq_socket
|
||||||
|
|
||||||
|
# Module-level logger, named after this module.
logger = init_logger(__name__)

# NOTE(review): not referenced anywhere in the visible part of this
# module -- presumably a default cap on generated tokens; confirm.
DEFAULT_MAX_TOKENS = 32000
|
||||||
|
|
||||||
|
# NOTE FOR DEVELOPERS:
|
||||||
|
# DO NOT USE PICKLE FOR THESE CLASSES. IN A MULTI NODE
|
||||||
|
# SETUP WE WILL USE TCP. WE CANNOT USE PICKLE OTHERWISE
|
||||||
|
# WE RISK REMOTE CODE EXECUTION FROM UNTRUSTED USERS.
|
||||||
|
|
||||||
|
class PDRequest(
        msgspec.Struct,
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True,  # type: ignore[call-arg]
        gc=False):  # type: ignore[call-arg]
    """Request message sent to the prefill and decode services.

    Field order is significant: the struct is ``array_like`` and is also
    constructed positionally by ``PDEngine.generate``.
    """
    # TODO: support multimodal inputs.

    # Id used to route responses back to the waiting caller.
    request_id: str
    # Text prompt.
    prompt: str
    # Sampling parameters forwarded with the request.
    sampling_params: SamplingParams
|
||||||
|
|
||||||
|
class PDResponse(
        msgspec.Struct,
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True,  # type: ignore[call-arg]
        gc=False):  # type: ignore[call-arg]
    """Response message received from the prefill and decode services.

    Field order is significant because the struct is ``array_like``.
    """
    # Id of the request this response belongs to.
    request_id: str
    # False makes PDEngine raise for the request.
    success: bool
    # NOTE(review): presumably the newly generated text -- confirm.
    delta_text: Optional[str] = None
    # A non-None value marks the last response of a decode stream.
    finish_reason: Optional[str] = None
    stop_reason: Optional[str] = None
    # Unannotated, so this is a plain class attribute rather than a field.
    logprobs = None  # TODO
|
||||||
|
|
||||||
|
|
||||||
|
class PDEngine:
    """
    PDEngine:
        Equivalent of AsyncLLM for P/D. Assumes there is
        a Prefill and Decode service already running.

        * TODO: actually handle errors and failure.
        * TODO: support more than just text input.
        * TODO: move under vllm/v1/engine once past prototype.
    """

    def __init__(self, prefill_addr: str, decode_addr: str,
                 connector_addr: str):
        """Create the PUSH sockets to the prefill and decode services.

        Args:
            prefill_addr: ZMQ address requests for prefill are pushed to.
            decode_addr: ZMQ address requests for decode are pushed to.
            connector_addr: ZMQ address responses are pulled from; the
                PULL socket is created lazily by the output handler.
        """
        # Request queues, keyed by request id.
        self.queues: Dict[str, asyncio.Queue] = {}

        # Serialization encoder.
        self.encoder = msgspec.msgpack.Encoder()

        # ZMQ communication.
        self.ctx = zmq.asyncio.Context()
        self.to_decode = make_zmq_socket(self.ctx, f"{decode_addr}",
                                         zmq.constants.PUSH)
        self.to_prefill = make_zmq_socket(self.ctx, f"{prefill_addr}",
                                          zmq.constants.PUSH)
        self.connector_addr = connector_addr

        # Background loops (started on first generate()).
        self.output_handler: Optional[asyncio.Task] = None
        self.log_running: Optional[asyncio.Task] = None

    def shutdown(self):
        """Cancel the background loops and destroy the ZMQ context."""
        if (task := self.log_running) is not None:
            task.cancel()
        if (task := self.output_handler) is not None:
            task.cancel()
        if (ctx := self.ctx) is not None:
            ctx.destroy(linger=0)

    async def _run_log_running(self):
        """Log the number of in-flight requests every ten seconds."""
        # Loop forever; without the loop only a single line was logged.
        while True:
            logger.info("Running requests: %d", len(self.queues))
            await asyncio.sleep(10.)

    async def _run_output_handler(self):
        """
        Pull responses from Decode + Prefill engines and
        distribute back to the generate() tasks.
        """
        decoder = msgspec.msgpack.Decoder(PDResponse)

        socket: Optional[zmq.asyncio.Socket] = None
        try:
            socket = make_zmq_socket(self.ctx, self.connector_addr,
                                     zmq.constants.PULL)

            while True:
                # Await the receive first, then read the frame's buffer;
                # copy=False yields a Frame so decoding avoids a copy.
                frame = await socket.recv(copy=False)
                response = decoder.decode(frame.buffer)
                q = self.queues.get(response.request_id)
                if q is None:
                    # The request is already gone; one stray response
                    # must not take the whole handler down.
                    logger.warning(
                        "Dropping response for unknown request_id: %s",
                        response.request_id)
                    continue
                q.put_nowait(response)
        finally:
            # TODO: actually handle failure and shutdown.
            if socket is not None:
                socket.close(linger=0)

    async def _prefill(self, request: PDRequest,
                       q: asyncio.Queue[PDResponse]) -> PDResponse:
        """Send ``request`` to prefill and return its single response.

        Raises:
            Exception: If the response reports ``success == False``.
        """
        # Send request to the prefill instance.
        req_bytes = self.encoder.encode(request)
        await self.to_prefill.send(req_bytes, copy=False)

        # Wait for the prefill to be done.
        response = await q.get()
        assert response.request_id == request.request_id
        if not response.success:
            # TODO: actual error handling and shutdown.
            raise Exception("Failed Prefill Request.")

        return response

    async def _decode(
            self, request: PDRequest,
            q: asyncio.Queue[PDResponse]) -> AsyncGenerator[PDResponse]:
        """Send ``request`` to decode and yield responses until finished.

        The last response yielded is the first one whose
        ``finish_reason`` is not None.

        Raises:
            Exception: If a response reports ``success == False``.
        """
        # Send request to the decode instance.
        req_bytes = self.encoder.encode(request)
        await self.to_decode.send(req_bytes, copy=False)

        # Iterate response queue and yield each response to caller.
        finished = False
        while not finished:
            response = await q.get()
            if not response.success:
                # TODO: actual error handling and shutdown.
                raise Exception("Failed Decode Request.")
            finished = response.finish_reason is not None
            yield response

    async def generate(
        self,
        prompt: PromptType,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> AsyncGenerator[PDResponse]:
        """Run prefill and then decode for one request.

        Yields the prefill response first, then every decode response.
        ``lora_request``, ``trace_headers``, ``prompt_adapter_request``
        and ``priority`` are accepted but not used.

        Raises:
            ValueError: If ``prompt`` is not a ``str`` or ``request_id``
                is already in flight.
        """
        # Start loops on first request.
        if self.output_handler is None:
            self.output_handler = asyncio.create_task(
                self._run_output_handler())
            self.log_running = asyncio.create_task(self._run_log_running())

        # TODO: expand to support more than just text input.
        if not isinstance(prompt, str):
            raise ValueError("We currently only support text inputs!")
        if request_id in self.queues:
            raise ValueError(f"Found duplicate request_id: {request_id}!")

        # Queue to gather output from output_handler.
        q: asyncio.Queue[PDResponse] = asyncio.Queue()
        self.queues[request_id] = q

        original_max_tokens = sampling_params.max_tokens
        try:
            # (1) Perform the prefill (max_tokens=1).
            request = PDRequest(request_id, prompt, sampling_params)
            request.sampling_params.max_tokens = 1
            response = await self._prefill(request, q)
            yield response

            # (2) Perform the decodes (original tokens).
            request.sampling_params.max_tokens = original_max_tokens
            async for response in self._decode(request, q):
                yield response
        finally:
            # Give the caller's params back untouched even on failure, and
            # drop the queue so it does not leak or block reuse of the id.
            sampling_params.max_tokens = original_max_tokens
            self.queues.pop(request_id, None)

    async def beam_search(
        self,
        prompt: PromptType,
        request_id: str,
        params: BeamSearchParams,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Not supported."""
        raise NotImplementedError

    def encode(
        self,
        prompt: PromptType,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
    ) -> AsyncGenerator[PoolingRequestOutput, None]:
        """Not supported."""
        raise NotImplementedError

    async def abort(self, request_id: str) -> None:
        """Not supported."""
        raise NotImplementedError

    async def get_model_config(self) -> ModelConfig:
        """Not supported."""
        raise NotImplementedError

    async def get_decoding_config(self) -> DecodingConfig:
        """Not supported."""
        raise NotImplementedError

    async def get_input_preprocessor(self) -> InputPreprocessor:
        """Not supported."""
        raise NotImplementedError

    async def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Not supported."""
        raise NotImplementedError

    async def is_tracing_enabled(self) -> bool:
        """Always False."""
        return False

    async def do_log_stats(
        self,
        scheduler_outputs: Optional[SchedulerOutputs] = None,
        # Builtin generic: ``List`` is not imported in this module.
        model_output: Optional[list[SamplerOutput]] = None,
    ) -> None:
        """No-op."""
        pass

    async def check_health(self) -> None:
        """No-op."""
        pass

    async def start_profile(self) -> None:
        """Not supported."""
        raise NotImplementedError

    async def stop_profile(self) -> None:
        """Not supported."""
        raise NotImplementedError

    async def reset_prefix_cache(self,
                                 device: Optional[Device] = None) -> None:
        """Not supported."""
        raise NotImplementedError

    async def sleep(self, level: int = 1) -> None:
        """Not supported."""
        raise NotImplementedError

    async def wake_up(self) -> None:
        """Not supported."""
        raise NotImplementedError

    async def is_sleeping(self) -> bool:
        """Always False."""
        return False

    async def add_lora(self, lora_request: LoRARequest) -> None:
        """Not supported."""
        raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
async def run_disagg_connector(args, **uvicorn_kwargs):
    """Configure ``app.state`` from ``args`` and serve the app with uvicorn.

    Args:
        args: Parsed CLI arguments; ``host``, ``port``, ``connector_addr``,
            ``decode_addr`` and ``prefill_addr`` are read.
        **uvicorn_kwargs: Only logged here.
    """
    logger.info("vLLM Connector Start: %s %s", args, uvicorn_kwargs)

    # NOTE FOR DEVELOPERS: when we shift this to TCP, we must
    # ensure that the serialization is not pickle based to
    # avoid RCE issues from untrusted users!!!
    # NOTE(review): `app` is neither defined nor imported in the visible
    # part of this module, so this function raises NameError as written;
    # confirm where the application object is meant to come from.
    app.state.port = args.port
    app.state.connector_addr = f"ipc://{args.connector_addr}"
    app.state.decode_addr = f"ipc://{args.decode_addr}"
    app.state.prefill_addr = f"ipc://{args.prefill_addr}"

    # init uvicorn server
    # `args.post` was a typo; the bind address is presumably the `--host`
    # option added by make_arg_parser -- confirm.
    config = uvicorn.Config(app, host=args.host, port=args.port)
    server = uvicorn.Server(config)
    await server.serve()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Connector-specific options first, then the options added by
    # make_arg_parser.
    parser = FlexibleArgumentParser(
        description="vLLM OpenAI-Compatible RESTful API server.")
    for flag, description in (
        ("--connector-addr", "The zmq ipc connector address"),
        ("--prefill-addr", "The zmq ipc prefill address"),
        ("--decode-addr", "The zmq ipc decode address"),
    ):
        parser.add_argument(flag, type=str, required=True, help=description)
    parser = make_arg_parser(parser)

    args = parser.parse_args()
    validate_parsed_serve_args(args)

    uvloop.run(run_server(args))
|
||||||
Loading…
x
Reference in New Issue
Block a user