mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-06 17:57:05 +08:00
refactor zmq msg to object
Signed-off-by: clark <panf2333@gmail.com>
This commit is contained in:
parent
912031ceb5
commit
d35dace985
@ -1,13 +1,13 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import signal
|
||||
import sys
|
||||
import traceback
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
from http import HTTPStatus
|
||||
from typing import Union
|
||||
|
||||
import uvicorn
|
||||
@ -17,9 +17,8 @@ import zmq.asyncio
|
||||
from fastapi import BackgroundTasks, FastAPI, Request
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
|
||||
from vllm.entrypoints.openai.zmq_server import (CONTENT_TYPE_ERROR,
|
||||
CONTENT_TYPE_JSON,
|
||||
CONTENT_TYPE_STREAM)
|
||||
from vllm.entrypoints.openai.protocol import (CompletionRequest, ZmqMsgRequest,
|
||||
ZmqMsgResponse)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
@ -28,6 +27,7 @@ logger = init_logger('vllm.entrypoints.disagg_connector')
|
||||
|
||||
TIME_OUT = 5
|
||||
X_REQUEST_ID_KEY = "X-Request-Id"
|
||||
CONTENT_TYPE_STREAM = "text/event-stream"
|
||||
|
||||
# communication between output handlers and execute_task_async
|
||||
request_queues: dict[str, asyncio.Queue]
|
||||
@ -77,40 +77,44 @@ app = FastAPI(lifespan=lifespan)
|
||||
|
||||
|
||||
@app.post('/v1/completions')
|
||||
async def completions(request: Request, background_tasks: BackgroundTasks):
|
||||
async def completions(request: CompletionRequest, raw_request: Request,
|
||||
background_tasks: BackgroundTasks):
|
||||
try:
|
||||
# Add the X-Request-Id header to the raw headers list
|
||||
header = dict(request.headers)
|
||||
header = dict(raw_request.headers)
|
||||
request_id = header.get(X_REQUEST_ID_KEY)
|
||||
queue = asyncio.Queue()
|
||||
queue: asyncio.Queue[ZmqMsgResponse] = asyncio.Queue()
|
||||
if request_id is None:
|
||||
request_id = str(uuid.uuid4())
|
||||
logger.debug("add X-Request-Id: %s", request_id)
|
||||
header[X_REQUEST_ID_KEY] = request_id
|
||||
logger.debug("X-Request-Id is: %s", request_id)
|
||||
request_queues[request_id] = queue
|
||||
request_data = await request.json()
|
||||
zmq_msg_request = ZmqMsgRequest(request_id=request_id,
|
||||
type="completions",
|
||||
body=request)
|
||||
logger.info("Received request_id: %s, request: %s, header: %s",
|
||||
request_id, request_data, header)
|
||||
original_max_tokens = request_data['max_tokens']
|
||||
request_id, zmq_msg_request.model_dump_json(), header)
|
||||
original_max_tokens = request.max_tokens
|
||||
# change max_tokens = 1 to let it only do prefill
|
||||
request_data['max_tokens'] = 1
|
||||
request.max_tokens = 1
|
||||
# finish prefill
|
||||
try:
|
||||
prefill_response = await prefill(header, request_data)
|
||||
if isinstance(prefill_response, JSONResponse):
|
||||
prefill_response = await prefill(zmq_msg_request)
|
||||
if isinstance(prefill_response, JSONResponse
|
||||
) and prefill_response.status_code != HTTPStatus.OK:
|
||||
return prefill_response
|
||||
logger.debug("finish prefill start decode")
|
||||
request_data['max_tokens'] = original_max_tokens
|
||||
response = await decode(header, request_data)
|
||||
request.max_tokens = original_max_tokens
|
||||
response = await decode(zmq_msg_request)
|
||||
logger.debug("finish decode")
|
||||
except Exception as e:
|
||||
logger.error("Error occurred in disagg prefill proxy server, %s",
|
||||
e)
|
||||
response = JSONResponse({"error": {
|
||||
"message": str(e)
|
||||
}},
|
||||
status_code=500)
|
||||
response = JSONResponse(
|
||||
{"error": {
|
||||
"message": str(e)
|
||||
}},
|
||||
status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
@ -120,37 +124,27 @@ async def completions(request: Request, background_tasks: BackgroundTasks):
|
||||
logger.error("".join(traceback.format_exception(*exc_info)))
|
||||
response = JSONResponse({"error": {
|
||||
"message": str(e)
|
||||
}},
|
||||
status_code=500)
|
||||
}}, HTTPStatus.INTERNAL_SERVER_ERROR)
|
||||
return response
|
||||
finally:
|
||||
background_tasks.add_task(cleanup_request_id, request_id)
|
||||
if request_id is not None:
|
||||
background_tasks.add_task(cleanup_request_id, request_id)
|
||||
|
||||
|
||||
async def socket_recv_handler(socket: zmq.asyncio.Socket, scene: str):
    """Forward replies arriving on ``socket`` to the per-request queues.

    Loops forever. Each received message must consist of exactly one frame
    holding a JSON-serialized ``ZmqMsgResponse``; the parsed object is put
    on the queue registered in ``request_queues`` under its ``request_id``.
    Replies whose ``request_id`` has no registered queue are discarded.

    Args:
        socket: asyncio ZMQ socket to receive replies from.
        scene: label used only to tag log lines.
    """
    while True:
        try:
            # Single-frame unpacking: any other frame count raises
            # ValueError, which is logged by the handler below.
            [body] = await socket.recv_multipart()
            response = ZmqMsgResponse.model_validate_json(body)
            request_id = response.request_id
            logger.debug("%s socket received result: %s", scene,
                         response.model_dump_json())
            queue = request_queues.get(request_id)
            if queue is not None:
                queue.put_nowait(response)
            else:
                logger.debug(
                    "%s socket received but request_id not found discard: %s",
                    scene, request_id)
        except Exception as e:
            # Keep the receive loop alive; one bad message must not stop
            # delivery for the other in-flight requests.
            logger.error(traceback.format_exc())
            logger.error("%s handler error: %s", scene, e)
|
||||
@ -167,78 +161,70 @@ async def decode_handler(decode_socket: zmq.asyncio.Socket):
|
||||
|
||||
|
||||
# select a socket and execute task
|
||||
async def execute_task_async(zmq_msg_request: ZmqMsgRequest,
                             socket: zmq.asyncio.Socket):
    """Send one request over ``socket`` and yield its replies as they arrive.

    The request is serialized to JSON and sent as a single frame. Replies
    are read from the queue registered in ``request_queues`` under the
    request's id.

    Yields:
        Each ``ZmqMsgResponse`` taken from the queue, up to and including
        the first one whose ``stop`` flag is set. If no reply arrives
        within ``TIME_OUT`` seconds, a ``JSONResponse`` with status
        ``HTTPStatus.REQUEST_TIMEOUT`` is yielded instead and the
        generator ends.

    Raises:
        KeyError: if no queue is registered for the request id.
    """
    # Bound outside the ``try`` so the ``finally`` log line can always use it.
    request_id = zmq_msg_request.request_id
    try:
        request_body = zmq_msg_request.model_dump_json()
        logger.debug("Sending requestBody: %s", request_body)
        # Await the send so that a failure surfaces here instead of being
        # lost in an un-awaited future.
        await socket.send_multipart([request_body.encode()])
        logger.debug("Sent end")
        queue = request_queues[request_id]
        while True:
            logger.debug("Waiting for reply")
            zmq_msg_response: ZmqMsgResponse = await asyncio.wait_for(
                queue.get(), TIME_OUT)
            logger.debug("Received result: %s",
                         zmq_msg_response.model_dump_json())
            yield zmq_msg_response
            if zmq_msg_response.stop:
                logger.debug("Received stop: %s", zmq_msg_response.stop)
                break
    except asyncio.TimeoutError:
        logger.error(traceback.format_exc())
        yield JSONResponse("timeout", HTTPStatus.REQUEST_TIMEOUT)
    finally:
        logger.debug("request_id: %s, execute_task_async end", request_id)
|
||||
|
||||
|
||||
async def prefill(zmq_msg_request: ZmqMsgRequest) -> Union[JSONResponse, bool]:
    """Run the request on the prefill socket and wait until it finishes.

    Returns:
        * the ``JSONResponse`` yielded by ``execute_task_async`` when the
          request timed out;
        * a ``JSONResponse`` built from the reply body when the worker
          answered with a complete (``body_type == "response"``) message.
          Its HTTP status is taken from the body's ``"status_code"`` entry
          when present, otherwise 200;
        * ``True`` when the reply stream ended without such a message.
    """
    logger.debug("start prefill")
    generator = execute_task_async(zmq_msg_request, app.state.prefill_socket)
    async for res in generator:
        logger.debug("res: %s", res)
        if isinstance(res, JSONResponse):
            # Timeout marker from execute_task_async; it has no
            # ``body_type`` attribute, so it must be handled first.
            return res
        if res.body_type == "response":
            status_code = HTTPStatus.OK
            if isinstance(res.body, dict):
                # Propagate the worker-reported status so callers can
                # tell an error reply from a successful one.
                status_code = res.body.get("status_code") or HTTPStatus.OK
            return JSONResponse(res.body, status_code=status_code)
    return True
|
||||
|
||||
|
||||
async def generate_stream_response(
    fisrt_reply: str, generator: "AsyncGenerator[ZmqMsgResponse, None]"
) -> AsyncGenerator[Union[dict, str], None]:
    """Yield ``fisrt_reply`` followed by the body of each remaining reply.

    Args:
        fisrt_reply: first chunk, already taken from ``generator`` by the
            caller (the parameter name keeps its original spelling so
            keyword callers keep working).
        generator: async iterator whose items expose a ``body`` attribute.

    Yields:
        ``fisrt_reply``, then ``item.body`` for every item of ``generator``.
    """
    yield fisrt_reply
    async for reply in generator:
        yield reply.body
|
||||
|
||||
|
||||
async def decode(
    zmq_msg_request: ZmqMsgRequest
) -> Union[JSONResponse, StreamingResponse]:
    """Run the request on the decode socket and build the HTTP response.

    Only the first item produced by ``execute_task_async`` decides the
    response type:

    * a ``JSONResponse`` (timeout marker) is returned unchanged;
    * a reply with ``body_type == "response"`` becomes a ``JSONResponse``
      whose status is the body's ``"status_code"`` entry when present,
      otherwise 200;
    * any other reply is treated as the first chunk of a stream and the
      rest of the generator is handed to a ``StreamingResponse``.

    If the generator produces nothing, a 500 ``JSONResponse`` is returned.
    """
    logger.debug("start decode")
    generator = execute_task_async(zmq_msg_request, app.state.decode_socket)

    # Every branch returns, so this loop body runs at most once; the loop
    # is only used to pull the first item from the async generator.
    async for res in generator:
        logger.debug("res: %s", res)
        if isinstance(res, JSONResponse):
            # Timeout marker; it has no ``body_type`` attribute.
            return res
        if res.body_type == "response":
            status_code = HTTPStatus.OK
            if isinstance(res.body, dict):
                # Propagate the worker-reported status instead of always
                # answering 200.
                status_code = res.body.get("status_code") or HTTPStatus.OK
            return JSONResponse(res.body, status_code=status_code)
        return StreamingResponse(generate_stream_response(
            res.body, generator),
                                 media_type=CONTENT_TYPE_STREAM)

    # If the generator is empty, return a default error response
    logger.error("No response received from generator")
    return JSONResponse({"error": "No response received from generator"},
                        status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
|
||||
|
||||
|
||||
def cleanup_request_id(request_id: str):
|
||||
if request_id in request_queues:
|
||||
|
||||
@ -1649,3 +1649,19 @@ class TranscriptionResponseVerbose(OpenAIBaseModel):
|
||||
|
||||
words: Optional[list[TranscriptionWord]] = None
|
||||
"""Extracted words and their corresponding timestamps."""
|
||||
|
||||
|
||||
class ZmqMsgRequest(BaseModel):
    """Envelope for a request sent over ZMQ as a single JSON frame."""

    # Correlates replies with the originating HTTP request.
    request_id: str
    # Request kind label, e.g. "completions".
    type: str
    # Request payload. ``Union`` with a single member is equivalent to
    # that member; the form is kept as written in the original.
    body: Union[CompletionRequest]
|
||||
|
||||
|
||||
class ZmqMsgResponse(BaseModel):
    """Envelope for a reply sent over ZMQ as a single JSON frame."""

    # Id of the request this reply belongs to.
    request_id: str
    # Request kind label copied from the request.
    type: str
    # True when this is the last reply for the request.
    stop: bool = True
    # "response": ``body`` is a complete response payload;
    # "str": ``body`` is one chunk of a stream.
    body_type: Literal["str", "response"] = "str"
    # Optional so that a reply serialized with a null body can still be
    # parsed back with ``model_validate_json`` on the receiving side.
    body: Optional[Union[dict, str]] = None

    model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
@ -1,22 +1,22 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import signal
|
||||
import traceback
|
||||
from argparse import Namespace
|
||||
from http import HTTPStatus
|
||||
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
from fastapi import Request
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.openai.api_server import build_async_engine_client
|
||||
from vllm.entrypoints.openai.protocol import (CompletionRequest,
|
||||
CompletionResponse,
|
||||
ErrorResponse)
|
||||
ErrorResponse, ZmqMsgRequest,
|
||||
ZmqMsgResponse)
|
||||
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
|
||||
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
|
||||
OpenAIServingModels)
|
||||
@ -26,10 +26,6 @@ from vllm.version import __version__ as VLLM_VERSION
|
||||
|
||||
logger = init_logger('vllm.entrypoints.openai.zmq_server')
|
||||
|
||||
CONTENT_TYPE_JSON = "application/json"
|
||||
CONTENT_TYPE_ERROR = "error"
|
||||
CONTENT_TYPE_STREAM = "text/event-stream"
|
||||
|
||||
openai_serving_completion: OpenAIServingCompletion
|
||||
openai_serving_models: OpenAIServingModels
|
||||
|
||||
@ -72,16 +68,22 @@ async def serve_zmq(arg) -> None:
|
||||
try:
|
||||
logger.debug("zmq Server waiting for request")
|
||||
# get new request from the client
|
||||
identity, request_id, body = await socket.recv_multipart()
|
||||
message_parts = await socket.recv_multipart()
|
||||
logger.debug("received request: %s", message_parts)
|
||||
logger.debug("received len: %d", len(message_parts))
|
||||
identity, body = message_parts[0], message_parts[1]
|
||||
zmq_msg_request = ZmqMsgRequest.model_validate_json(body)
|
||||
# launch request handler coroutine
|
||||
task = asyncio.create_task(
|
||||
worker_routine(identity, request_id, body, socket))
|
||||
worker_routine(identity, zmq_msg_request, socket))
|
||||
running_requests.add(task)
|
||||
task.add_done_callback(running_requests.discard)
|
||||
except zmq.ZMQError as e:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error("ZMQError: %s", e)
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error("Unexpected error: %s", e)
|
||||
break
|
||||
except KeyboardInterrupt:
|
||||
@ -153,59 +155,68 @@ async def init_state(
|
||||
)
|
||||
|
||||
|
||||
async def worker_routine(identity: bytes, zmq_msg_request: ZmqMsgRequest,
                         socket: zmq.asyncio.Socket):
    """Handle one request received over ZMQ and reply on ``socket``.

    Completion requests are delegated to ``create_completion``, which
    sends the replies itself. Any other body type, or any exception raised
    while handling the request, produces a single error reply addressed to
    ``identity``.

    Args:
        identity: routing frame of the peer that sent the request; it is
            sent back as the first frame of every reply.
        zmq_msg_request: parsed request envelope.
        socket: socket used to send replies.
    """
    # Bound outside the ``try`` so the error path can always reference it.
    request_id = zmq_msg_request.request_id
    try:
        logger.debug("receive request: %s from %s, request_id: %s",
                     zmq_msg_request.model_dump_json(), identity.decode(),
                     request_id)
        if isinstance(zmq_msg_request.body, CompletionRequest):
            await create_completion(identity, zmq_msg_request, socket)
        else:
            logger.error("Error in worker routine: %s request_id: %s",
                         "unsupported request type", request_id)
            raise Exception("unsupported request type")

    except Exception as e:
        logger.error("Error in worker routine: %s request_id: %s", e,
                     request_id)
        logger.error(traceback.format_exc())
        logger.debug("send ErrorResponse %s", str(e))
        # Report the actual failure rather than a fixed message, and mark
        # the reply as a complete response (same shape as the error reply
        # built in create_completion) so the receiver does not treat the
        # dict as a stream chunk.
        error_reply = ZmqMsgResponse(
            request_id=request_id,
            type=zmq_msg_request.type,
            body_type="response",
            body={
                "content": str(e),
                "status_code": HTTPStatus.INTERNAL_SERVER_ERROR.value
            })
        await socket.send_multipart([
            identity,
            error_reply.model_dump_json().encode(),
        ])
|
||||
|
||||
|
||||
async def create_completion(identity: bytes, zmq_msg_request: ZmqMsgRequest,
                            socket: zmq.asyncio.Socket):
    """Run a completion request and send its result(s) back over ZMQ.

    The request body is passed to
    ``openai_serving_completion.create_completion``. Depending on what
    that returns:

    * ``ErrorResponse`` / ``CompletionResponse``: one reply with
      ``body_type="response"`` is sent. The body holds the dumped model
      under ``"content"`` and, for errors, the error code under
      ``"status_code"``.
    * otherwise the result is iterated as an async stream of text chunks;
      one reply is sent per chunk, with ``stop`` set only on a chunk that
      contains ``"data: [DONE]"``.

    Every reply is sent as two frames: ``identity`` and the JSON-encoded
    ``ZmqMsgResponse``.
    """
    request: CompletionRequest = zmq_msg_request.body
    logger.debug("zmq request post: %s", request.model_dump_json())
    generator = await openai_serving_completion.create_completion(request)
    logger.debug("zmq request end post")
    request_id = zmq_msg_request.request_id

    if isinstance(generator, (ErrorResponse, CompletionResponse)):
        logger.debug("send response %s", generator.model_dump_json())
        body = {"content": generator.model_dump()}
        if isinstance(generator, ErrorResponse):
            body["status_code"] = generator.code
        zmq_msg_response = ZmqMsgResponse(request_id=request_id,
                                          type=zmq_msg_request.type,
                                          body_type="response",
                                          body=body)
        await socket.send_multipart(
            [identity, zmq_msg_response.model_dump_json().encode()])
        return

    async for chunk in generator:
        zmq_msg_response = ZmqMsgResponse(request_id=request_id,
                                          type=zmq_msg_request.type,
                                          stop="data: [DONE]" in chunk,
                                          body=chunk)
        logger.debug("send chunk identity: %s, request_id: %s, chunk: %s",
                     identity.decode(), request_id, chunk)
        await socket.send_multipart(
            [identity,
             zmq_msg_response.model_dump_json().encode()])
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user