From c0b14433452131ce5ae9ebcec142c10b9f377489 Mon Sep 17 00:00:00 2001 From: clark Date: Sun, 9 Mar 2025 13:47:22 +0800 Subject: [PATCH] fix mypy Signed-off-by: clark --- vllm/entrypoints/disagg_connector.py | 4 ++-- vllm/entrypoints/openai/protocol.py | 2 +- vllm/entrypoints/openai/zmq_server.py | 25 ++++++++++++++----------- 3 files changed, 17 insertions(+), 14 deletions(-) diff --git a/vllm/entrypoints/disagg_connector.py b/vllm/entrypoints/disagg_connector.py index 245188c3cf72d..ba0b0d748ad89 100644 --- a/vllm/entrypoints/disagg_connector.py +++ b/vllm/entrypoints/disagg_connector.py @@ -198,8 +198,8 @@ async def prefill(zmq_msg_request: ZmqMsgRequest) -> Union[JSONResponse, bool]: async def generate_stream_response( - fisrt_reply: str, generator: AsyncGenerator[ZmqMsgResponse] -) -> AsyncGenerator[dict, str]: + fisrt_reply: str, + generator: AsyncGenerator[ZmqMsgResponse]) -> AsyncGenerator[str]: yield fisrt_reply async for reply in generator: yield reply.body diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index a91cae3a34a7c..23525d80995b6 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -1662,6 +1662,6 @@ class ZmqMsgResponse(BaseModel): type: str stop: bool = True body_type: Literal["str", "response"] = "str" - body: Union[dict, str] = None + body: str model_config = ConfigDict(arbitrary_types_allowed=True) diff --git a/vllm/entrypoints/openai/zmq_server.py b/vllm/entrypoints/openai/zmq_server.py index b1ccd20e831e3..caba77e3a7462 100644 --- a/vllm/entrypoints/openai/zmq_server.py +++ b/vllm/entrypoints/openai/zmq_server.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import asyncio +import json import os import signal import traceback @@ -179,10 +180,12 @@ async def worker_routine(identity: bytes, zmq_msg_request: ZmqMsgRequest, identity, ZmqMsgResponse(request_id=request_id, type=zmq_msg_request.type, - body={ - "content": "unsupported request type", - "status_code": HTTPStatus.INTERNAL_SERVER_ERROR - }).model_dump_json().encode(), + body=json.dumps({ + "content": + "unsupported request type", + "status_code": + HTTPStatus.INTERNAL_SERVER_ERROR + })).model_dump_json().encode(), ]) @@ -195,17 +198,17 @@ async def create_completion(identity: bytes, zmq_msg_request: ZmqMsgRequest, request_id = zmq_msg_request.request_id if isinstance(generator, (ErrorResponse, CompletionResponse)): logger.debug("send response %s", generator.model_dump_json()) - zmq_msg_response = ZmqMsgResponse(request_id=request_id, - type=zmq_msg_request.type, - body_type="response") if isinstance(generator, ErrorResponse): - zmq_msg_response.body = { + body = json.dumps({ "content": generator.model_dump(), "status_code": generator.code - } + }) elif isinstance(generator, CompletionResponse): - zmq_msg_response.body = {"content": generator.model_dump()} - + body = json.dumps({"content": generator.model_dump()}) + zmq_msg_response = ZmqMsgResponse(request_id=request_id, + type=zmq_msg_request.type, + body_type="response", + body=body) await socket.send_multipart( [identity, zmq_msg_response.model_dump_json().encode()]) else: