1. fix mypy issue

Signed-off-by: clark <panf2333@gmail.com>
This commit is contained in:
clark 2025-01-08 23:15:23 +08:00
parent 897db7b93d
commit 187f112ccd
2 changed files with 26 additions and 27 deletions

View File

@ -4,7 +4,6 @@ import uvicorn
import zmq import zmq
import zmq.asyncio import zmq.asyncio
from fastapi import FastAPI, Request from fastapi import FastAPI, Request
from starlette.datastructures import Headers
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
# from fastapi.lifespan import Lifespan # from fastapi.lifespan import Lifespan
@ -24,7 +23,7 @@ logger = init_logger('vllm.entrypoints.connect')
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
# create scoket pool with prefill and decode # create socket pool with prefill and decode
logger.info("start create_socket_pool") logger.info("start create_socket_pool")
app.state.zmqctx = zmq.asyncio.Context() app.state.zmqctx = zmq.asyncio.Context()
app.state.sockets_prefill = await create_socket_pool(app.state.prefill_addr, socket_prefill_num, zmqctx=app.state.zmqctx) app.state.sockets_prefill = await create_socket_pool(app.state.prefill_addr, socket_prefill_num, zmqctx=app.state.zmqctx)
@ -39,7 +38,7 @@ async def lifespan(app: FastAPI):
app = FastAPI(lifespan=lifespan) app = FastAPI(lifespan=lifespan)
# create async socket pool with num_sockets use ZMQ_DEALER # create async socket pool with num_sockets use ZMQ_DEALER
async def create_socket_pool(url: str, num_sockets: int, zmqctx: zmq.asyncio.Context): async def create_socket_pool(url: str, num_sockets: int, zmqctx: zmq.asyncio.Context) -> Queue:
sockets = Queue() sockets = Queue()
for i in range(num_sockets): for i in range(num_sockets):
sock = zmqctx.socket(zmq.DEALER) sock = zmqctx.socket(zmq.DEALER)
@ -50,8 +49,8 @@ async def create_socket_pool(url: str, num_sockets: int, zmqctx: zmq.asyncio.Con
await sockets.put(sock) await sockets.put(sock)
return sockets return sockets
# select a scoket and execute task # select a socket and execute task
async def execute_task_async(route: str, headers: dict, request: dict, sockets: list): async def execute_task_async(route: str, headers: dict, request: dict, sockets: Queue):
sock = await sockets.get() sock = await sockets.get()
try: try:
requestBody = json.dumps(request) requestBody = json.dumps(request)

View File

@ -1,12 +1,13 @@
import json
from typing import Optional
import zmq import zmq
import zmq.asyncio import zmq.asyncio
import tempfile import tempfile
import uuid import uuid
import httpx import httpx
import json import json
import traceback
from typing import Optional
from fastapi import FastAPI, Request from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
@ -22,7 +23,6 @@ from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.openai.serving_tokenization import OpenAIServingTokenization from vllm.entrypoints.openai.serving_tokenization import OpenAIServingTokenization
from vllm.logger import init_logger from vllm.logger import init_logger
import traceback
prometheus_multiproc_dir: tempfile.TemporaryDirectory prometheus_multiproc_dir: tempfile.TemporaryDirectory
@ -54,7 +54,7 @@ def bytes_to_headers(bytes_data: bytes) -> httpx.Headers:
return httpx.Headers(headers_dict) return httpx.Headers(headers_dict)
async def worker_routine(worker_url: str, app: FastAPI, async def worker_routine(worker_url: str, app: FastAPI,
context: zmq.asyncio.Context = None, i: int = 0): context: zmq.asyncio.Context, i: int = 0):
"""Worker routine""" """Worker routine"""
try: try:
# Socket to talk to dispatcher # Socket to talk to dispatcher
@ -65,46 +65,46 @@ async def worker_routine(worker_url: str, app: FastAPI,
logger.info(f"{worker_identity} started at {worker_url}") logger.info(f"{worker_identity} started at {worker_url}")
while True: while True:
identity, url, header, body = await socket.recv_multipart() identity, url, header, body = await socket.recv_multipart()
logger.info(f"worker-{i} Received request identity: [{identity} ]") logger.info(f"worker-{i} Received request identity: [{identity.decode()} ]")
url = url.decode() url_str = url.decode()
logger.info(f"worker-{i} Received request url: [{url} ]") logger.info(f"worker-{i} Received request url: [{url_str} ]")
header = bytes_to_headers(header) headers = bytes_to_headers(header)
logger.info(f"worker-{i} Received request headers: [{header} ]") logger.info(f"worker-{i} Received request headers: [{headers} ]")
body = json.loads(body.decode()) body_json = json.loads(body.decode())
logger.info(f"worker-{i} Received request body: [{body} ]") logger.info(f"worker-{i} Received request body: [{body_json} ]")
logger.info(f"worker-{i} Calling OpenAI API") logger.info(f"worker-{i} Calling OpenAI API")
completionRequest = CompletionRequest(**body) completionRequest = CompletionRequest(**body_json)
createRequest = create_request(url, "POST", body, header) createRequest = create_request(url_str, "POST", body_json, headers)
generator = await create_completion(app, completionRequest, createRequest) generator = await create_completion(app, completionRequest, createRequest)
logger.info(f"worker-{i} Received response: [{generator} ]") logger.info(f"worker-{i} Received response: [{generator} ]")
if isinstance(generator, ErrorResponse): if isinstance(generator, ErrorResponse):
content = generator.model_dump_json() content = generator.model_dump_json()
context = json.loads(content) context_json = json.loads(content)
context.append("status_code", generator.code) context_json.append("status_code", generator.code)
await socket.send_multipart([identity, b"application/json", json.dumps(context).encode()]) await socket.send_multipart([identity, b"application/json", json.dumps(context_json).encode('utf-8')])
elif isinstance(generator, CompletionResponse): elif isinstance(generator, CompletionResponse):
await socket.send_multipart([identity, b"application/json", JSONResponse.render(content=generator.model_dump())]) await socket.send_multipart([identity, b"application/json", json.dumps(generator.model_dump()).encode('utf-8')])
else: else:
async for chunk in generator: async for chunk in generator:
logger.info(f"worker-{i} Sending response chunk: [{chunk} ]") logger.info(f"worker-{i} Sending response chunk: [{chunk} ]")
await socket.send_multipart([identity, b"text/event-stream", chunk.encode()]) await socket.send_multipart([identity, b"text/event-stream", chunk.encode('utf-8')])
except Exception as e: except Exception as e:
logger.error(f"Error in worker routine: {e} worker-{i}") logger.error(f"Error in worker routine: {e} worker-{i}")
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
async def create_completion(app: FastAPI, request: CompletionRequest, raw_request: Request): async def create_completion(app: FastAPI, request: CompletionRequest, raw_request: Request):
handler = completion(app) handler = completion(app)
logger.info(f"zmq requset post: {request}") logger.info(f"zmq request post: {request}")
if handler is None: if handler is None:
return base(app).create_error_response( return base(app).create_error_response(
message="The model does not support Completions API") message="The model does not support Completions API")
generator = await handler.create_completion(request, raw_request) generator = await handler.create_completion(request, raw_request)
logger.info(f"zmq requset end post: {generator}") logger.info(f"zmq request end post: {generator}")
return generator return generator
def create_request(path: str, method: str, body: bytes, headers: dict = None): def create_request(path: str, method: str, body: dict, headers: httpx.Headers) -> Request:
scope = { scope = {
'type': 'http', 'type': 'http',
'http_version': '1.1', 'http_version': '1.1',
@ -113,7 +113,7 @@ def create_request(path: str, method: str, body: bytes, headers: dict = None):
'headers': list(headers.items()) if headers else [], 'headers': list(headers.items()) if headers else [],
} }
if body: if body:
scope['body'] = json.dumps(body).encode('utf-8') scope['body'] = json.dumps(body)
async def receive(): async def receive():
return { return {
'type': 'http.request', 'type': 'http.request',