mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-22 00:31:19 +08:00
refactor zmq msg to object
Signed-off-by: clark <panf2333@gmail.com>
This commit is contained in:
parent
912031ceb5
commit
d35dace985
@ -1,13 +1,13 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
|
||||||
import signal
|
import signal
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
import uuid
|
import uuid
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
from http import HTTPStatus
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
@ -17,9 +17,8 @@ import zmq.asyncio
|
|||||||
from fastapi import BackgroundTasks, FastAPI, Request
|
from fastapi import BackgroundTasks, FastAPI, Request
|
||||||
from fastapi.responses import JSONResponse, StreamingResponse
|
from fastapi.responses import JSONResponse, StreamingResponse
|
||||||
|
|
||||||
from vllm.entrypoints.openai.zmq_server import (CONTENT_TYPE_ERROR,
|
from vllm.entrypoints.openai.protocol import (CompletionRequest, ZmqMsgRequest,
|
||||||
CONTENT_TYPE_JSON,
|
ZmqMsgResponse)
|
||||||
CONTENT_TYPE_STREAM)
|
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.utils import FlexibleArgumentParser
|
from vllm.utils import FlexibleArgumentParser
|
||||||
|
|
||||||
@ -28,6 +27,7 @@ logger = init_logger('vllm.entrypoints.disagg_connector')
|
|||||||
|
|
||||||
TIME_OUT = 5
|
TIME_OUT = 5
|
||||||
X_REQUEST_ID_KEY = "X-Request-Id"
|
X_REQUEST_ID_KEY = "X-Request-Id"
|
||||||
|
CONTENT_TYPE_STREAM = "text/event-stream"
|
||||||
|
|
||||||
# communication between output handlers and execute_task_async
|
# communication between output handlers and execute_task_async
|
||||||
request_queues: dict[str, asyncio.Queue]
|
request_queues: dict[str, asyncio.Queue]
|
||||||
@ -77,40 +77,44 @@ app = FastAPI(lifespan=lifespan)
|
|||||||
|
|
||||||
|
|
||||||
@app.post('/v1/completions')
|
@app.post('/v1/completions')
|
||||||
async def completions(request: Request, background_tasks: BackgroundTasks):
|
async def completions(request: CompletionRequest, raw_request: Request,
|
||||||
|
background_tasks: BackgroundTasks):
|
||||||
try:
|
try:
|
||||||
# Add the X-Request-Id header to the raw headers list
|
# Add the X-Request-Id header to the raw headers list
|
||||||
header = dict(request.headers)
|
header = dict(raw_request.headers)
|
||||||
request_id = header.get(X_REQUEST_ID_KEY)
|
request_id = header.get(X_REQUEST_ID_KEY)
|
||||||
queue = asyncio.Queue()
|
queue: asyncio.Queue[ZmqMsgResponse] = asyncio.Queue()
|
||||||
if request_id is None:
|
if request_id is None:
|
||||||
request_id = str(uuid.uuid4())
|
request_id = str(uuid.uuid4())
|
||||||
logger.debug("add X-Request-Id: %s", request_id)
|
logger.debug("add X-Request-Id: %s", request_id)
|
||||||
header[X_REQUEST_ID_KEY] = request_id
|
|
||||||
logger.debug("X-Request-Id is: %s", request_id)
|
logger.debug("X-Request-Id is: %s", request_id)
|
||||||
request_queues[request_id] = queue
|
request_queues[request_id] = queue
|
||||||
request_data = await request.json()
|
zmq_msg_request = ZmqMsgRequest(request_id=request_id,
|
||||||
|
type="completions",
|
||||||
|
body=request)
|
||||||
logger.info("Received request_id: %s, request: %s, header: %s",
|
logger.info("Received request_id: %s, request: %s, header: %s",
|
||||||
request_id, request_data, header)
|
request_id, zmq_msg_request.model_dump_json(), header)
|
||||||
original_max_tokens = request_data['max_tokens']
|
original_max_tokens = request.max_tokens
|
||||||
# change max_tokens = 1 to let it only do prefill
|
# change max_tokens = 1 to let it only do prefill
|
||||||
request_data['max_tokens'] = 1
|
request.max_tokens = 1
|
||||||
# finish prefill
|
# finish prefill
|
||||||
try:
|
try:
|
||||||
prefill_response = await prefill(header, request_data)
|
prefill_response = await prefill(zmq_msg_request)
|
||||||
if isinstance(prefill_response, JSONResponse):
|
if isinstance(prefill_response, JSONResponse
|
||||||
|
) and prefill_response.status_code != HTTPStatus.OK:
|
||||||
return prefill_response
|
return prefill_response
|
||||||
logger.debug("finish prefill start decode")
|
logger.debug("finish prefill start decode")
|
||||||
request_data['max_tokens'] = original_max_tokens
|
request.max_tokens = original_max_tokens
|
||||||
response = await decode(header, request_data)
|
response = await decode(zmq_msg_request)
|
||||||
logger.debug("finish decode")
|
logger.debug("finish decode")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Error occurred in disagg prefill proxy server, %s",
|
logger.error("Error occurred in disagg prefill proxy server, %s",
|
||||||
e)
|
e)
|
||||||
response = JSONResponse({"error": {
|
response = JSONResponse(
|
||||||
"message": str(e)
|
{"error": {
|
||||||
}},
|
"message": str(e)
|
||||||
status_code=500)
|
}},
|
||||||
|
status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -120,37 +124,27 @@ async def completions(request: Request, background_tasks: BackgroundTasks):
|
|||||||
logger.error("".join(traceback.format_exception(*exc_info)))
|
logger.error("".join(traceback.format_exception(*exc_info)))
|
||||||
response = JSONResponse({"error": {
|
response = JSONResponse({"error": {
|
||||||
"message": str(e)
|
"message": str(e)
|
||||||
}},
|
}}, HTTPStatus.INTERNAL_SERVER_ERROR)
|
||||||
status_code=500)
|
|
||||||
return response
|
return response
|
||||||
finally:
|
finally:
|
||||||
background_tasks.add_task(cleanup_request_id, request_id)
|
if request_id is not None:
|
||||||
|
background_tasks.add_task(cleanup_request_id, request_id)
|
||||||
|
|
||||||
|
|
||||||
async def socket_recv_handler(socket: zmq.asyncio.Socket, scene: str):
|
async def socket_recv_handler(socket: zmq.asyncio.Socket, scene: str):
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
[request_id, contentType, reply] = await socket.recv_multipart()
|
[body] = await socket.recv_multipart()
|
||||||
contentType_str = contentType.decode()
|
response = ZmqMsgResponse.model_validate_json(body)
|
||||||
reply_str = reply.decode()
|
request_id = response.request_id
|
||||||
request_id_str = request_id.decode()
|
logger.debug("%s socket received result: %s", scene,
|
||||||
logger.debug(
|
response.model_dump_json())
|
||||||
"%s socket received result contentType: %s, "
|
if request_id in request_queues:
|
||||||
"request_id: %s, reply: %s", scene, contentType_str,
|
request_queues[request_id].put_nowait(response)
|
||||||
request_id_str, reply_str)
|
|
||||||
if request_id_str in request_queues:
|
|
||||||
request_queues[request_id_str].put_nowait(
|
|
||||||
(contentType_str, reply_str))
|
|
||||||
if "[DONE]" in reply_str:
|
|
||||||
logger.debug(
|
|
||||||
"%s socket received stop signal request_id: %s", scene,
|
|
||||||
request_id_str)
|
|
||||||
request_queues[request_id_str].put_nowait(
|
|
||||||
(contentType_str, None))
|
|
||||||
else:
|
else:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"%s socket received but request_id not found discard: %s",
|
"%s socket received but request_id not found discard: %s",
|
||||||
scene, request_id_str)
|
scene, request_id)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
logger.error("%s handler error: %s", scene, e)
|
logger.error("%s handler error: %s", scene, e)
|
||||||
@ -167,78 +161,70 @@ async def decode_handler(decode_socket: zmq.asyncio.Socket):
|
|||||||
|
|
||||||
|
|
||||||
# select a socket and execute task
|
# select a socket and execute task
|
||||||
async def execute_task_async(headers: dict, request: dict,
|
async def execute_task_async(zmq_msg_request: ZmqMsgRequest,
|
||||||
socket: zmq.asyncio.Socket):
|
socket: zmq.asyncio.Socket):
|
||||||
try:
|
try:
|
||||||
request_id = headers.get(X_REQUEST_ID_KEY)
|
request_id = zmq_msg_request.request_id
|
||||||
requestBody = json.dumps(request)
|
requestBody = zmq_msg_request.model_dump_json()
|
||||||
logger.info("Sending requestBody: %s", requestBody)
|
logger.debug("Sending requestBody: %s", requestBody)
|
||||||
socket.send_multipart([request_id.encode(), requestBody.encode()])
|
socket.send_multipart([requestBody.encode()])
|
||||||
logger.debug("Sent end")
|
logger.debug("Sent end")
|
||||||
queue = request_queues[request_id]
|
queue = request_queues[request_id]
|
||||||
while True:
|
while True:
|
||||||
logger.debug("Waiting for reply")
|
logger.debug("Waiting for reply")
|
||||||
(contentType,
|
zmq_msg_response: ZmqMsgResponse = await asyncio.wait_for(
|
||||||
reply) = await asyncio.wait_for(queue.get(), TIME_OUT)
|
queue.get(), TIME_OUT)
|
||||||
logger.debug("Received result: %s, %s", contentType, reply)
|
logger.debug("Received result: %s",
|
||||||
if reply is None:
|
zmq_msg_response.model_dump_json())
|
||||||
logger.debug("Received stop signal, request_id: %s",
|
yield zmq_msg_response
|
||||||
request_id)
|
if zmq_msg_response.stop:
|
||||||
|
logger.debug("Received stop: %s", zmq_msg_response.stop)
|
||||||
break
|
break
|
||||||
yield (contentType, reply)
|
|
||||||
if contentType == CONTENT_TYPE_JSON:
|
|
||||||
logger.debug("Received %s message, request_id: %s",
|
|
||||||
contentType, request_id)
|
|
||||||
break
|
|
||||||
|
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
yield (CONTENT_TYPE_ERROR, "System Error")
|
yield JSONResponse("timeout", HTTPStatus.REQUEST_TIMEOUT)
|
||||||
finally:
|
finally:
|
||||||
logger.debug("request_id: %s, execute_task_async end", request_id)
|
logger.debug("request_id: %s, execute_task_async end", request_id)
|
||||||
|
|
||||||
|
|
||||||
async def prefill(header: dict,
|
async def prefill(zmq_msg_request: ZmqMsgRequest) -> Union[JSONResponse, bool]:
|
||||||
original_request_data: dict) -> Union[JSONResponse, bool]:
|
|
||||||
logger.debug("start prefill")
|
logger.debug("start prefill")
|
||||||
generator = execute_task_async(header, original_request_data,
|
generator = execute_task_async(zmq_msg_request, app.state.prefill_socket)
|
||||||
app.state.prefill_socket)
|
async for res in generator:
|
||||||
async for contentType, reply in generator:
|
logger.debug("res: %s", res)
|
||||||
logger.debug("contentType: %s, reply: %s", contentType, reply)
|
if res.body_type == "response":
|
||||||
if contentType == CONTENT_TYPE_ERROR:
|
return JSONResponse(res.body)
|
||||||
response = JSONResponse({"error": reply})
|
|
||||||
response.status_code = 500
|
|
||||||
return response
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
async def generate_stream_response(fisrt_reply: str,
|
async def generate_stream_response(
|
||||||
generator: AsyncGenerator):
|
fisrt_reply: str, generator: AsyncGenerator[ZmqMsgResponse]
|
||||||
|
) -> AsyncGenerator[dict, str]:
|
||||||
yield fisrt_reply
|
yield fisrt_reply
|
||||||
async for _, reply in generator:
|
async for reply in generator:
|
||||||
yield reply
|
yield reply.body
|
||||||
|
|
||||||
|
|
||||||
async def decode(
|
async def decode(
|
||||||
header: dict,
|
zmq_msg_request: ZmqMsgRequest
|
||||||
original_request_data: dict) -> Union[JSONResponse, StreamingResponse]:
|
) -> Union[JSONResponse, StreamingResponse]:
|
||||||
logger.info("start decode")
|
logger.debug("start decode")
|
||||||
generator = execute_task_async(header, original_request_data,
|
generator = execute_task_async(zmq_msg_request, app.state.decode_socket)
|
||||||
app.state.decode_socket)
|
|
||||||
|
|
||||||
async for contentType, reply in generator:
|
async for res in generator:
|
||||||
logger.debug("contentType: %s, reply: %s", contentType, reply)
|
logger.debug("res: %s", res)
|
||||||
if contentType == CONTENT_TYPE_ERROR:
|
if res.body_type == "response":
|
||||||
response = JSONResponse({"error": reply})
|
return JSONResponse(res.body)
|
||||||
response.status_code = 500
|
|
||||||
return response
|
|
||||||
elif contentType == CONTENT_TYPE_JSON:
|
|
||||||
return JSONResponse(reply)
|
|
||||||
else:
|
else:
|
||||||
return StreamingResponse(generate_stream_response(
|
return StreamingResponse(generate_stream_response(
|
||||||
reply, generator),
|
res.body, generator),
|
||||||
media_type=CONTENT_TYPE_STREAM)
|
media_type=CONTENT_TYPE_STREAM)
|
||||||
|
|
||||||
|
# If the generator is empty, return a default error response
|
||||||
|
logger.error("No response received from generator")
|
||||||
|
return JSONResponse({"error": "No response received from generator"},
|
||||||
|
status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
|
||||||
|
|
||||||
|
|
||||||
def cleanup_request_id(request_id: str):
|
def cleanup_request_id(request_id: str):
|
||||||
if request_id in request_queues:
|
if request_id in request_queues:
|
||||||
|
|||||||
@ -1649,3 +1649,19 @@ class TranscriptionResponseVerbose(OpenAIBaseModel):
|
|||||||
|
|
||||||
words: Optional[list[TranscriptionWord]] = None
|
words: Optional[list[TranscriptionWord]] = None
|
||||||
"""Extracted words and their corresponding timestamps."""
|
"""Extracted words and their corresponding timestamps."""
|
||||||
|
|
||||||
|
|
||||||
|
class ZmqMsgRequest(BaseModel):
|
||||||
|
request_id: str
|
||||||
|
type: str
|
||||||
|
body: Union[CompletionRequest]
|
||||||
|
|
||||||
|
|
||||||
|
class ZmqMsgResponse(BaseModel):
|
||||||
|
request_id: str
|
||||||
|
type: str
|
||||||
|
stop: bool = True
|
||||||
|
body_type: Literal["str", "response"] = "str"
|
||||||
|
body: Union[dict, str] = None
|
||||||
|
|
||||||
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||||
|
|||||||
@ -1,22 +1,22 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
|
||||||
import os
|
import os
|
||||||
import signal
|
import signal
|
||||||
import traceback
|
import traceback
|
||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
|
from http import HTTPStatus
|
||||||
|
|
||||||
import zmq
|
import zmq
|
||||||
import zmq.asyncio
|
import zmq.asyncio
|
||||||
from fastapi import Request
|
|
||||||
|
|
||||||
from vllm.config import ModelConfig
|
from vllm.config import ModelConfig
|
||||||
from vllm.engine.protocol import EngineClient
|
from vllm.engine.protocol import EngineClient
|
||||||
from vllm.entrypoints.openai.api_server import build_async_engine_client
|
from vllm.entrypoints.openai.api_server import build_async_engine_client
|
||||||
from vllm.entrypoints.openai.protocol import (CompletionRequest,
|
from vllm.entrypoints.openai.protocol import (CompletionRequest,
|
||||||
CompletionResponse,
|
CompletionResponse,
|
||||||
ErrorResponse)
|
ErrorResponse, ZmqMsgRequest,
|
||||||
|
ZmqMsgResponse)
|
||||||
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
|
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
|
||||||
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
|
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
|
||||||
OpenAIServingModels)
|
OpenAIServingModels)
|
||||||
@ -26,10 +26,6 @@ from vllm.version import __version__ as VLLM_VERSION
|
|||||||
|
|
||||||
logger = init_logger('vllm.entrypoints.openai.zmq_server')
|
logger = init_logger('vllm.entrypoints.openai.zmq_server')
|
||||||
|
|
||||||
CONTENT_TYPE_JSON = "application/json"
|
|
||||||
CONTENT_TYPE_ERROR = "error"
|
|
||||||
CONTENT_TYPE_STREAM = "text/event-stream"
|
|
||||||
|
|
||||||
openai_serving_completion: OpenAIServingCompletion
|
openai_serving_completion: OpenAIServingCompletion
|
||||||
openai_serving_models: OpenAIServingModels
|
openai_serving_models: OpenAIServingModels
|
||||||
|
|
||||||
@ -72,16 +68,22 @@ async def serve_zmq(arg) -> None:
|
|||||||
try:
|
try:
|
||||||
logger.debug("zmq Server waiting for request")
|
logger.debug("zmq Server waiting for request")
|
||||||
# get new request from the client
|
# get new request from the client
|
||||||
identity, request_id, body = await socket.recv_multipart()
|
message_parts = await socket.recv_multipart()
|
||||||
|
logger.debug("received request: %s", message_parts)
|
||||||
|
logger.debug("received len: %d", len(message_parts))
|
||||||
|
identity, body = message_parts[0], message_parts[1]
|
||||||
|
zmq_msg_request = ZmqMsgRequest.model_validate_json(body)
|
||||||
# launch request handler coroutine
|
# launch request handler coroutine
|
||||||
task = asyncio.create_task(
|
task = asyncio.create_task(
|
||||||
worker_routine(identity, request_id, body, socket))
|
worker_routine(identity, zmq_msg_request, socket))
|
||||||
running_requests.add(task)
|
running_requests.add(task)
|
||||||
task.add_done_callback(running_requests.discard)
|
task.add_done_callback(running_requests.discard)
|
||||||
except zmq.ZMQError as e:
|
except zmq.ZMQError as e:
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
logger.error("ZMQError: %s", e)
|
logger.error("ZMQError: %s", e)
|
||||||
break
|
break
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
logger.error("Unexpected error: %s", e)
|
logger.error("Unexpected error: %s", e)
|
||||||
break
|
break
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
@ -153,59 +155,68 @@ async def init_state(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def worker_routine(identity: bytes, request_id: bytes, body: bytes,
|
async def worker_routine(identity: bytes, zmq_msg_request: ZmqMsgRequest,
|
||||||
socket: zmq.asyncio.Socket):
|
socket: zmq.asyncio.Socket):
|
||||||
"""Worker routine"""
|
"""Worker routine"""
|
||||||
try:
|
try:
|
||||||
body_json = json.loads(body.decode())
|
request_id = zmq_msg_request.request_id
|
||||||
request_id_str = request_id.decode()
|
logger.debug("receive request: %s from %s, request_id: %s",
|
||||||
logger.debug("receive request: %s from %s, request_id: %s", body_json,
|
zmq_msg_request.model_dump_json(), identity.decode(),
|
||||||
identity.decode(), request_id_str)
|
request_id)
|
||||||
|
if isinstance(zmq_msg_request.body, CompletionRequest):
|
||||||
completionRequest = CompletionRequest(**body_json)
|
await create_completion(identity, zmq_msg_request, socket)
|
||||||
generator = await create_completion(completionRequest, None)
|
|
||||||
content_type_json = CONTENT_TYPE_JSON.encode('utf-8')
|
|
||||||
content_type_stream = CONTENT_TYPE_STREAM.encode('utf-8')
|
|
||||||
if isinstance(generator, ErrorResponse):
|
|
||||||
content = json.loads(generator.model_dump_json())
|
|
||||||
content.update({"status_code": generator.code})
|
|
||||||
logger.debug("send ErrorResponse %s", json.dumps(content))
|
|
||||||
await socket.send_multipart([
|
|
||||||
identity, request_id, content_type_json,
|
|
||||||
json.dumps(content).encode('utf-8')
|
|
||||||
])
|
|
||||||
elif isinstance(generator, CompletionResponse):
|
|
||||||
logger.debug("send CompletionResponse %s",
|
|
||||||
json.dumps(generator.model_dump()))
|
|
||||||
await socket.send_multipart([
|
|
||||||
identity, request_id, content_type_json,
|
|
||||||
json.dumps(generator.model_dump()).encode('utf-8')
|
|
||||||
])
|
|
||||||
else:
|
else:
|
||||||
async for chunk in generator:
|
logger.error("Error in worker routine: %s request_id: %s",
|
||||||
logger.debug(
|
"unsupported request type", request_id)
|
||||||
"send chunk identity: %s, request_id: %s, chunk: %s",
|
raise Exception("unsupported request type")
|
||||||
identity.decode(), request_id.decode(), chunk)
|
|
||||||
await socket.send_multipart([
|
|
||||||
identity, request_id, content_type_stream,
|
|
||||||
chunk.encode('utf-8')
|
|
||||||
])
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Error in worker routine: %s request_id: %s", e,
|
logger.error("Error in worker routine: %s request_id: %s", e,
|
||||||
request_id_str)
|
request_id)
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
content_type_stream = CONTENT_TYPE_STREAM.encode('utf-8')
|
|
||||||
logger.debug("send ErrorResponse %s", str(e))
|
logger.debug("send ErrorResponse %s", str(e))
|
||||||
await socket.send_multipart([
|
await socket.send_multipart([
|
||||||
identity, request_id,
|
identity,
|
||||||
CONTENT_TYPE_ERROR.encode('utf-8'),
|
ZmqMsgResponse(request_id=request_id,
|
||||||
str(e).encode('utf-8')
|
type=zmq_msg_request.type,
|
||||||
|
body={
|
||||||
|
"content": "unsupported request type",
|
||||||
|
"status_code": HTTPStatus.INTERNAL_SERVER_ERROR
|
||||||
|
}).model_dump_json().encode(),
|
||||||
])
|
])
|
||||||
|
|
||||||
|
|
||||||
async def create_completion(request: CompletionRequest, raw_request: Request):
|
async def create_completion(identity: bytes, zmq_msg_request: ZmqMsgRequest,
|
||||||
logger.debug("zmq request post: %s", request)
|
socket: zmq.asyncio.Socket):
|
||||||
generator = await openai_serving_completion.create_completion(
|
request: CompletionRequest = zmq_msg_request.body
|
||||||
request, raw_request)
|
logger.debug("zmq request post: %s", request.model_dump_json())
|
||||||
|
generator = await openai_serving_completion.create_completion(request)
|
||||||
logger.debug("zmq request end post")
|
logger.debug("zmq request end post")
|
||||||
return generator
|
request_id = zmq_msg_request.request_id
|
||||||
|
if isinstance(generator, (ErrorResponse, CompletionResponse)):
|
||||||
|
logger.debug("send response %s", generator.model_dump_json())
|
||||||
|
zmq_msg_response = ZmqMsgResponse(request_id=request_id,
|
||||||
|
type=zmq_msg_request.type,
|
||||||
|
body_type="response")
|
||||||
|
if isinstance(generator, ErrorResponse):
|
||||||
|
zmq_msg_response.body = {
|
||||||
|
"content": generator.model_dump(),
|
||||||
|
"status_code": generator.code
|
||||||
|
}
|
||||||
|
elif isinstance(generator, CompletionResponse):
|
||||||
|
zmq_msg_response.body = {"content": generator.model_dump()}
|
||||||
|
|
||||||
|
await socket.send_multipart(
|
||||||
|
[identity, zmq_msg_response.model_dump_json().encode()])
|
||||||
|
else:
|
||||||
|
async for chunk in generator:
|
||||||
|
zmq_msg_response = ZmqMsgResponse(request_id=request_id,
|
||||||
|
type=zmq_msg_request.type,
|
||||||
|
body=chunk)
|
||||||
|
if "data: [DONE]" not in chunk:
|
||||||
|
zmq_msg_response.stop = False
|
||||||
|
logger.debug("send chunk identity: %s, request_id: %s, chunk: %s",
|
||||||
|
identity.decode(), request_id, chunk)
|
||||||
|
await socket.send_multipart(
|
||||||
|
[identity,
|
||||||
|
zmq_msg_response.model_dump_json().encode()])
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user