mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-22 06:17:51 +08:00
1. fix mypy issue
Signed-off-by: clark <panf2333@gmail.com>
This commit is contained in:
parent
897db7b93d
commit
187f112ccd
@ -4,7 +4,6 @@ import uvicorn
|
|||||||
import zmq
|
import zmq
|
||||||
import zmq.asyncio
|
import zmq.asyncio
|
||||||
from fastapi import FastAPI, Request
|
from fastapi import FastAPI, Request
|
||||||
from starlette.datastructures import Headers
|
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
# from fastapi.lifespan import Lifespan
|
# from fastapi.lifespan import Lifespan
|
||||||
@ -24,7 +23,7 @@ logger = init_logger('vllm.entrypoints.connect')
|
|||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
# create scoket pool with prefill and decode
|
# create socket pool with prefill and decode
|
||||||
logger.info("start create_socket_pool")
|
logger.info("start create_socket_pool")
|
||||||
app.state.zmqctx = zmq.asyncio.Context()
|
app.state.zmqctx = zmq.asyncio.Context()
|
||||||
app.state.sockets_prefill = await create_socket_pool(app.state.prefill_addr, socket_prefill_num, zmqctx=app.state.zmqctx)
|
app.state.sockets_prefill = await create_socket_pool(app.state.prefill_addr, socket_prefill_num, zmqctx=app.state.zmqctx)
|
||||||
@ -39,7 +38,7 @@ async def lifespan(app: FastAPI):
|
|||||||
app = FastAPI(lifespan=lifespan)
|
app = FastAPI(lifespan=lifespan)
|
||||||
|
|
||||||
# Create an async pool of num_sockets ZMQ DEALER sockets.
|
# Create an async pool of num_sockets ZMQ DEALER sockets.
|
||||||
async def create_socket_pool(url: str, num_sockets: int, zmqctx: zmq.asyncio.Context):
|
async def create_socket_pool(url: str, num_sockets: int, zmqctx: zmq.asyncio.Context) -> Queue:
|
||||||
sockets = Queue()
|
sockets = Queue()
|
||||||
for i in range(num_sockets):
|
for i in range(num_sockets):
|
||||||
sock = zmqctx.socket(zmq.DEALER)
|
sock = zmqctx.socket(zmq.DEALER)
|
||||||
@ -50,8 +49,8 @@ async def create_socket_pool(url: str, num_sockets: int, zmqctx: zmq.asyncio.Con
|
|||||||
await sockets.put(sock)
|
await sockets.put(sock)
|
||||||
return sockets
|
return sockets
|
||||||
|
|
||||||
# select a scoket and execute task
|
# select a socket and execute task
|
||||||
async def execute_task_async(route: str, headers: dict, request: dict, sockets: list):
|
async def execute_task_async(route: str, headers: dict, request: dict, sockets: Queue):
|
||||||
sock = await sockets.get()
|
sock = await sockets.get()
|
||||||
try:
|
try:
|
||||||
requestBody = json.dumps(request)
|
requestBody = json.dumps(request)
|
||||||
|
|||||||
@ -1,12 +1,13 @@
|
|||||||
import json
|
|
||||||
from typing import Optional
|
|
||||||
import zmq
|
import zmq
|
||||||
import zmq.asyncio
|
import zmq.asyncio
|
||||||
import tempfile
|
import tempfile
|
||||||
import uuid
|
import uuid
|
||||||
import httpx
|
import httpx
|
||||||
import json
|
import json
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
from fastapi import FastAPI, Request
|
from fastapi import FastAPI, Request
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
@ -22,7 +23,6 @@ from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
|||||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||||
from vllm.entrypoints.openai.serving_tokenization import OpenAIServingTokenization
|
from vllm.entrypoints.openai.serving_tokenization import OpenAIServingTokenization
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
import traceback
|
|
||||||
|
|
||||||
prometheus_multiproc_dir: tempfile.TemporaryDirectory
|
prometheus_multiproc_dir: tempfile.TemporaryDirectory
|
||||||
|
|
||||||
@ -54,7 +54,7 @@ def bytes_to_headers(bytes_data: bytes) -> httpx.Headers:
|
|||||||
return httpx.Headers(headers_dict)
|
return httpx.Headers(headers_dict)
|
||||||
|
|
||||||
async def worker_routine(worker_url: str, app: FastAPI,
|
async def worker_routine(worker_url: str, app: FastAPI,
|
||||||
context: zmq.asyncio.Context = None, i: int = 0):
|
context: zmq.asyncio.Context, i: int = 0):
|
||||||
"""Worker routine"""
|
"""Worker routine"""
|
||||||
try:
|
try:
|
||||||
# Socket to talk to dispatcher
|
# Socket to talk to dispatcher
|
||||||
@ -65,46 +65,46 @@ async def worker_routine(worker_url: str, app: FastAPI,
|
|||||||
logger.info(f"{worker_identity} started at {worker_url}")
|
logger.info(f"{worker_identity} started at {worker_url}")
|
||||||
while True:
|
while True:
|
||||||
identity, url, header, body = await socket.recv_multipart()
|
identity, url, header, body = await socket.recv_multipart()
|
||||||
logger.info(f"worker-{i} Received request identity: [{identity} ]")
|
logger.info(f"worker-{i} Received request identity: [{identity.decode()} ]")
|
||||||
url = url.decode()
|
url_str = url.decode()
|
||||||
logger.info(f"worker-{i} Received request url: [{url} ]")
|
logger.info(f"worker-{i} Received request url: [{url_str} ]")
|
||||||
header = bytes_to_headers(header)
|
headers = bytes_to_headers(header)
|
||||||
logger.info(f"worker-{i} Received request headers: [{header} ]")
|
logger.info(f"worker-{i} Received request headers: [{headers} ]")
|
||||||
body = json.loads(body.decode())
|
body_json = json.loads(body.decode())
|
||||||
logger.info(f"worker-{i} Received request body: [{body} ]")
|
logger.info(f"worker-{i} Received request body: [{body_json} ]")
|
||||||
logger.info(f"worker-{i} Calling OpenAI API")
|
logger.info(f"worker-{i} Calling OpenAI API")
|
||||||
completionRequest = CompletionRequest(**body)
|
completionRequest = CompletionRequest(**body_json)
|
||||||
createRequest = create_request(url, "POST", body, header)
|
createRequest = create_request(url_str, "POST", body_json, headers)
|
||||||
generator = await create_completion(app, completionRequest, createRequest)
|
generator = await create_completion(app, completionRequest, createRequest)
|
||||||
logger.info(f"worker-{i} Received response: [{generator} ]")
|
logger.info(f"worker-{i} Received response: [{generator} ]")
|
||||||
if isinstance(generator, ErrorResponse):
|
if isinstance(generator, ErrorResponse):
|
||||||
content = generator.model_dump_json()
|
content = generator.model_dump_json()
|
||||||
context = json.loads(content)
|
context_json = json.loads(content)
|
||||||
context.append("status_code", generator.code)
|
context_json.append("status_code", generator.code)
|
||||||
await socket.send_multipart([identity, b"application/json", json.dumps(context).encode()])
|
await socket.send_multipart([identity, b"application/json", json.dumps(context_json).encode('utf-8')])
|
||||||
elif isinstance(generator, CompletionResponse):
|
elif isinstance(generator, CompletionResponse):
|
||||||
await socket.send_multipart([identity, b"application/json", JSONResponse.render(content=generator.model_dump())])
|
await socket.send_multipart([identity, b"application/json", json.dumps(generator.model_dump()).encode('utf-8')])
|
||||||
else:
|
else:
|
||||||
async for chunk in generator:
|
async for chunk in generator:
|
||||||
logger.info(f"worker-{i} Sending response chunk: [{chunk} ]")
|
logger.info(f"worker-{i} Sending response chunk: [{chunk} ]")
|
||||||
await socket.send_multipart([identity, b"text/event-stream", chunk.encode()])
|
await socket.send_multipart([identity, b"text/event-stream", chunk.encode('utf-8')])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in worker routine: {e} worker-{i}")
|
logger.error(f"Error in worker routine: {e} worker-{i}")
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
|
|
||||||
async def create_completion(app: FastAPI, request: CompletionRequest, raw_request: Request):
|
async def create_completion(app: FastAPI, request: CompletionRequest, raw_request: Request):
|
||||||
handler = completion(app)
|
handler = completion(app)
|
||||||
logger.info(f"zmq requset post: {request}")
|
logger.info(f"zmq request post: {request}")
|
||||||
if handler is None:
|
if handler is None:
|
||||||
return base(app).create_error_response(
|
return base(app).create_error_response(
|
||||||
message="The model does not support Completions API")
|
message="The model does not support Completions API")
|
||||||
|
|
||||||
generator = await handler.create_completion(request, raw_request)
|
generator = await handler.create_completion(request, raw_request)
|
||||||
logger.info(f"zmq requset end post: {generator}")
|
logger.info(f"zmq request end post: {generator}")
|
||||||
return generator
|
return generator
|
||||||
|
|
||||||
|
|
||||||
def create_request(path: str, method: str, body: bytes, headers: dict = None):
|
def create_request(path: str, method: str, body: dict, headers: httpx.Headers) -> Request:
|
||||||
scope = {
|
scope = {
|
||||||
'type': 'http',
|
'type': 'http',
|
||||||
'http_version': '1.1',
|
'http_version': '1.1',
|
||||||
@ -113,7 +113,7 @@ def create_request(path: str, method: str, body: bytes, headers: dict = None):
|
|||||||
'headers': list(headers.items()) if headers else [],
|
'headers': list(headers.items()) if headers else [],
|
||||||
}
|
}
|
||||||
if body:
|
if body:
|
||||||
scope['body'] = json.dumps(body).encode('utf-8')
|
scope['body'] = json.dumps(body)
|
||||||
async def receive():
|
async def receive():
|
||||||
return {
|
return {
|
||||||
'type': 'http.request',
|
'type': 'http.request',
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user