Signed-off-by: clark <panf2333@gmail.com>
This commit is contained in:
clark 2025-03-09 13:47:22 +08:00
parent d35dace985
commit c0b1443345
3 changed files with 17 additions and 14 deletions

View File

@@ -198,8 +198,8 @@ async def prefill(zmq_msg_request: ZmqMsgRequest) -> Union[JSONResponse, bool]:
async def generate_stream_response(
fisrt_reply: str, generator: AsyncGenerator[ZmqMsgResponse]
) -> AsyncGenerator[dict, str]:
fisrt_reply: str,
generator: AsyncGenerator[ZmqMsgResponse]) -> AsyncGenerator[str]:
yield fisrt_reply
async for reply in generator:
yield reply.body

View File

@@ -1662,6 +1662,6 @@ class ZmqMsgResponse(BaseModel):
type: str
stop: bool = True
body_type: Literal["str", "response"] = "str"
body: Union[dict, str] = None
body: str
model_config = ConfigDict(arbitrary_types_allowed=True)

View File

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
import asyncio
import json
import os
import signal
import traceback
@@ -179,10 +180,12 @@ async def worker_routine(identity: bytes, zmq_msg_request: ZmqMsgRequest,
identity,
ZmqMsgResponse(request_id=request_id,
type=zmq_msg_request.type,
body={
"content": "unsupported request type",
"status_code": HTTPStatus.INTERNAL_SERVER_ERROR
}).model_dump_json().encode(),
body=json.dumps({
"content":
"unsupported request type",
"status_code":
HTTPStatus.INTERNAL_SERVER_ERROR
})).model_dump_json().encode(),
])
@@ -195,17 +198,17 @@ async def create_completion(identity: bytes, zmq_msg_request: ZmqMsgRequest,
request_id = zmq_msg_request.request_id
if isinstance(generator, (ErrorResponse, CompletionResponse)):
logger.debug("send response %s", generator.model_dump_json())
zmq_msg_response = ZmqMsgResponse(request_id=request_id,
type=zmq_msg_request.type,
body_type="response")
if isinstance(generator, ErrorResponse):
zmq_msg_response.body = {
body = json.dumps({
"content": generator.model_dump(),
"status_code": generator.code
}
})
elif isinstance(generator, CompletionResponse):
zmq_msg_response.body = {"content": generator.model_dump()}
body = json.dumps({"content": generator.model_dump()})
zmq_msg_response = ZmqMsgResponse(request_id=request_id,
type=zmq_msg_request.type,
body_type="response",
body=body)
await socket.send_multipart(
[identity, zmq_msg_response.model_dump_json().encode()])
else: