Run yapf and ruff

Signed-off-by: clark <panf2333@gmail.com>
clark 2025-01-09 00:12:58 +08:00
parent 187f112ccd
commit 2c31e4c3ea
5 changed files with 140 additions and 91 deletions

View File

@@ -1,31 +1,34 @@
 import asyncio
 import json
 import aiohttp

-# test connect completions we assume prefill and decode are on the same node
-# 1. node:vllm serve facebook/opt-125m --port 7001 --zmq-server-port 7010 --chat-template ~/vllm/examples/template_chatglm2.jinja
+# test connect completions we assume prefill and decode are on the same node
+# 1. node:vllm serve facebook/opt-125m --port 7001 --zmq-server-port 7010 \
+#    --chat-template ~/vllm/examples/template_chatglm2.jinja
 # 2. vllm connect --prefill-addr nodeIp:7010 --decode-addr nodeIp:7010
 # 3. python test_request.py
 async def test_connect_completions(session):
     try:
         base_url = "http://localhost:8001/v1/connect/completions"
         body = {
-            "temperature": 0.5,
-            "top_p": 0.9,
-            "max_tokens": 150,
-            "frequency_penalty": 1.3,
-            "presence_penalty": 0.2,
-            "repetition_penalty": 1.2,
-            "model": "facebook/opt-125m",
-            "prompt": "Can you introduce vllm?",
-            "stream": True,
-            "stream_options": {
+            "temperature": 0.5,
+            "top_p": 0.9,
+            "max_tokens": 150,
+            "frequency_penalty": 1.3,
+            "presence_penalty": 0.2,
+            "repetition_penalty": 1.2,
+            "model": "facebook/opt-125m",
+            "prompt": "Can you introduce vllm?",
+            "stream": True,
+            "stream_options": {
                 "include_usage": True
-        }}
-        print(f"Sending request to {base_url}, body {body}")
-        async with session.post(base_url, json= body) as response:
+            }
+        }
+        print(f"Sending request to {base_url}, body {body}")
+        async with session.post(base_url, json=body) as response:
             print(response.status)
             print(response.headers)
             responseText = ""
@@ -40,13 +43,18 @@ async def test_connect_completions(session):
                         print(f"Error decoding chunk: {chunk!r}")
                 else:
                     # Print the headers and JSON response
-                    print(f"Unexpected Transfer-Encoding: {transfer_encoding} {response.headers} {await response.json()}")
+                    print("Unexpected Transfer-Encoding: {} {} {}".format(
+                        transfer_encoding, response.headers, await
+                        response.json()))
             else:
                 print(f"Request failed with status code {response.status}")
-            print(f"baseurl {base_url} response data {extract_data(responseText)}")
+            print(
+                f"baseurl {base_url} response data {extract_data(responseText)}"
+            )
     except aiohttp.ClientError as e:
         print(f"Error: {e}")

 def extract_data(responseText):
     reply = ""
     for data in responseText.split("\n\n"):
@@ -66,7 +74,7 @@ def extract_data(responseText):
     return reply

 async def main():
     async with aiohttp.ClientSession() as session:
         tasks = []
@@ -76,4 +84,3 @@ async def main():
 asyncio.run(main())
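
For context: the test accumulates the `text/event-stream` body into `responseText`, and `extract_data` (its body is elided above) walks the `data:` blocks to recover the generated text. A minimal sketch of that parsing, assuming OpenAI-style SSE framing (`data: {...}` blocks separated by blank lines, terminated by `data: [DONE]`); the helper name and field access are illustrative, not taken from the diff:

```python
import json


def extract_data_sketch(response_text: str) -> str:
    """Collect completion text from an OpenAI-style SSE stream."""
    reply = ""
    for block in response_text.split("\n\n"):
        block = block.strip()
        if not block.startswith("data:"):
            continue
        payload = block[len("data:"):].strip()
        if payload == "[DONE]":  # end-of-stream sentinel
            break
        chunk = json.loads(payload)
        # completions chunks carry generated text under choices[n].text
        for choice in chunk.get("choices", []):
            reply += choice.get("text", "")
    return reply


print(extract_data_sketch(
    'data: {"choices": [{"text": "Hello"}]}\n\ndata: [DONE]\n\n'))
```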

View File

@@ -1,15 +1,18 @@
 # SPDX-License-Identifier: Apache-2.0
 import json
+import signal
+import uuid
+# from fastapi.lifespan import Lifespan
+from asyncio import Queue
+from contextlib import asynccontextmanager

 import uvicorn
 import zmq
 import zmq.asyncio
 from fastapi import FastAPI, Request
 from fastapi.responses import StreamingResponse
-from contextlib import asynccontextmanager
-# from fastapi.lifespan import Lifespan
-from asyncio import Queue
-import uuid
-import signal

 from vllm.logger import init_logger

 # default prefill and decode url
@@ -21,55 +24,69 @@ socket_decode_num = 5
 # Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
 logger = init_logger('vllm.entrypoints.connect')


 @asynccontextmanager
 async def lifespan(app: FastAPI):
     # create socket pool with prefill and decode
     logger.info("start create_socket_pool")
     app.state.zmqctx = zmq.asyncio.Context()
-    app.state.sockets_prefill = await create_socket_pool(app.state.prefill_addr, socket_prefill_num, zmqctx=app.state.zmqctx)
+    app.state.sockets_prefill = await create_socket_pool(
+        app.state.prefill_addr, socket_prefill_num, zmqctx=app.state.zmqctx)
     logger.info("success create_socket_pool sockets_prefill")
-    app.state.sockets_decode = await create_socket_pool(app.state.decode_addr, socket_decode_num, zmqctx=app.state.zmqctx)
+    app.state.sockets_decode = await create_socket_pool(
+        app.state.decode_addr, socket_decode_num, zmqctx=app.state.zmqctx)
     logger.info("success create_socket_pool sockets_decode")
     yield
     ## close zmq context
     logger.info("term zmqctx")
     app.state.zmqctx.destroy(linger=0)


 app = FastAPI(lifespan=lifespan)


 # create async socket pool with num_sockets use ZMQ_DEALER
-async def create_socket_pool(url: str, num_sockets: int, zmqctx: zmq.asyncio.Context) -> Queue:
-    sockets = Queue()
+async def create_socket_pool(url: str, num_sockets: int,
+                             zmqctx: zmq.asyncio.Context) -> Queue:
+    sockets: Queue = Queue()
     for i in range(num_sockets):
         sock = zmqctx.socket(zmq.DEALER)
         identity = f"worker-{i}-{uuid.uuid4()}"
         sock.setsockopt(zmq.IDENTITY, identity.encode())
         sock.connect(url)
-        logger.info(f"{identity} started at {url} {sockets.qsize()}")
+        logger.info("%s started at %s with queue size %s", identity, url,
+                    sockets.qsize())
         await sockets.put(sock)
     return sockets


 # select a socket and execute task
-async def execute_task_async(route: str, headers: dict, request: dict, sockets: Queue):
+async def execute_task_async(route: str, headers: dict, request: dict,
+                             sockets: Queue):
     sock = await sockets.get()
     try:
         requestBody = json.dumps(request)
         headersJson = json.dumps(headers)
-        logger.info(f"Sending requestBody: {requestBody} to {route} with headers: {headersJson}")
-        await sock.send_multipart([route.encode(), headersJson.encode(), requestBody.encode()])
-        logger.info(f"Sent end")
+        logger.info("Sending requestBody: %s to %s with headers: %s",
+                    requestBody, route, headersJson)
+        await sock.send_multipart(
+            [route.encode(),
+             headersJson.encode(),
+             requestBody.encode()])
+        logger.info("Sent end")
         while True:
-            logger.info(f"Waiting for reply")
+            logger.info("Waiting for reply")
             [contentType, reply] = await sock.recv_multipart()
-            logger.info(f"Received result: {contentType}, {reply}")
+            logger.info("Received result: %s, %s", contentType, reply)
             reply = reply.decode()
             yield f"{reply}"
             if "[DONE]" in reply:
-                logger.info(f"Received stop signal, return socket")
+                logger.info("Received stop signal, return socket")
                 break
     finally:
         await sockets.put(sock)


 @app.post('/v1/connect/completions')
 async def chat_completions(request: Request):
     try:
@@ -77,21 +94,26 @@ async def chat_completions(request: Request):
         x_request_id = str(uuid.uuid4())
         header = dict(request.headers)
         if header.get("X-Request-Id") is None:
-            logger.info(f"add X-Request-Id: {x_request_id}")
+            logger.info("add X-Request-Id: %s", x_request_id)
             header["X-Request-Id"] = x_request_id
         original_request_data = await request.json()
-        logger.info(f"Received request: {original_request_data} header: {header}")
+        logger.info("Received request: %s header: %s", original_request_data,
+                    header)
         prefill_request = original_request_data.copy()
         # change max_tokens = 1 to let it only do prefill
         prefill_request['max_tokens'] = 1
         route = "/v1/completions"
         # finish prefill
-        async for _ in execute_task_async(route, header, prefill_request, app.state.sockets_prefill):
+        async for _ in execute_task_async(route, header, prefill_request,
+                                          app.state.sockets_prefill):
             continue
         # return decode
-        return StreamingResponse(execute_task_async(route, header,original_request_data, app.state.sockets_decode), media_type="text/event-stream")
+        return StreamingResponse(execute_task_async(route, header,
+                                                    original_request_data,
+                                                    app.state.sockets_decode),
+                                 media_type="text/event-stream")
     except Exception as e:
         import sys
         import traceback
@@ -99,16 +121,20 @@ async def chat_completions(request: Request):
         logger.error("Error occurred in disagg prefill proxy server")
         logger.error(e)
         logger.error("".join(traceback.format_exception(*exc_info)))


 async def run_disagg_connector(args, **uvicorn_kwargs) -> None:
-    logger.info(f"vLLM Disaggregate Connector start {args} {uvicorn_kwargs}")
+    logger.info("vLLM Disaggregate Connector start %s %s", args,
+                uvicorn_kwargs)
     logger.info(args.prefill_addr)
-    app.state.prefill_addr = f"tcp://{args.prefill_addr}" if args.prefill_addr is not None else url_prefill
-    app.state.decode_addr = f"tcp://{args.decode_addr}" if args.decode_addr is not None else url_decode
-    logger.info(f"start connect url_prefill: {app.state.prefill_addr} url_decode: {app.state.decode_addr}")
+    app.state.prefill_addr = (f"tcp://{args.prefill_addr}" if args.prefill_addr
+                              is not None else url_prefill)
+    app.state.decode_addr = (f"tcp://{args.decode_addr}"
+                             if args.decode_addr is not None else url_decode)
+    logger.info("start connect url_prefill: %s url_decode: %s",
+                app.state.prefill_addr, app.state.decode_addr)

     def signal_handler(*_) -> None:
         # Interrupt server on sigterm while initializing
         raise KeyboardInterrupt("terminated")
@@ -119,8 +145,7 @@ async def run_disagg_connector(args, **uvicorn_kwargs) -> None:
     server = uvicorn.Server(config)
     await server.serve()


 if __name__ == "__main__":
-    # url = 'tcp://127.0.0.1:5555'
-    uvicorn.run(app, host="0.0.0.0", port=8001)
+    uvicorn.run(app, host="0.0.0.0", port=8001)
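
The pattern worth noting in `execute_task_async`: borrow a DEALER socket from an `asyncio.Queue`, send `[route, headers, body]` as one multipart message, stream reply frames back until a `[DONE]` sentinel, and return the socket in `finally`. A self-contained sketch of that checkout/return discipline against a toy ROUTER peer (the port, frame contents, and echo behaviour are illustrative, not vllm's):

```python
import asyncio
import uuid

import zmq
import zmq.asyncio


async def main():
    ctx = zmq.asyncio.Context()

    # Toy ROUTER peer standing in for the real prefill/decode server.
    router = ctx.socket(zmq.ROUTER)
    router.bind("tcp://127.0.0.1:5599")

    # Pool of DEALER sockets, checked out and returned via a Queue.
    pool: asyncio.Queue = asyncio.Queue()
    for i in range(2):
        sock = ctx.socket(zmq.DEALER)
        sock.setsockopt(zmq.IDENTITY, f"worker-{i}-{uuid.uuid4()}".encode())
        sock.connect("tcp://127.0.0.1:5599")
        await pool.put(sock)

    async def serve_once():
        # ROUTER sees [identity, route, headers, body]; reply with a sentinel.
        identity, route, headers, body = await router.recv_multipart()
        await router.send_multipart(
            [identity, b"text/event-stream", b"data: [DONE]"])

    async def request_once():
        sock = await pool.get()  # borrow a socket from the pool
        try:
            await sock.send_multipart([b"/v1/completions", b"{}", b"{}"])
            while True:
                content_type, reply = await sock.recv_multipart()
                print(content_type, reply)
                if b"[DONE]" in reply:  # sentinel: stream finished
                    break
        finally:
            await pool.put(sock)  # always return it

    await asyncio.gather(serve_once(), request_once())
    ctx.destroy(linger=0)


asyncio.run(main())
```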

View File

@@ -6,17 +6,16 @@ import socket
 from http import HTTPStatus
 from typing import Any, Optional

+import uvicorn
 import zmq
 import zmq.asyncio
-import uvicorn
 from fastapi import FastAPI, Request, Response

-from vllm.entrypoints.openai.connect_worker import worker_routine
 from vllm import envs
 from vllm.engine.async_llm_engine import AsyncEngineDeadError
 from vllm.engine.multiprocessing import MQEngineDeadError
 from vllm.entrypoints.ssl import SSLCertRefresher
+from vllm.entrypoints.openai.connect_worker import worker_routine
 from vllm.logger import init_logger
 from vllm.utils import find_process_using_port
@@ -76,11 +75,13 @@ async def serve_http(app: FastAPI,
                 "port %s is used by process %s launched with command:\n%s",
                 port, process, " ".join(process.cmdline()))
         logger.info("Shutting down FastAPI HTTP server.")
-        return server.shutdown()
+        return server.shutdown()


 async def serve_zmq(arg, zmq_server_port: int, app: FastAPI) -> None:
     """Server routine"""
-    logger.info(f"zmq Server start arg: {arg}, zmq_port: {zmq_server_port}")
+    logger.info("zmq Server start arg: %s, zmq_server_port: %d", arg,
+                zmq_server_port)
     url_worker = "inproc://workers"
     url_client = f"tcp://0.0.0.0:{zmq_server_port}"
     # Prepare our context and sockets
@@ -89,15 +90,18 @@ async def serve_zmq(arg, zmq_server_port: int, app: FastAPI) -> None:
     # Socket to talk to clients
     clients = context.socket(zmq.ROUTER)
     clients.bind(url_client)
-    logger.info(f"ZMQ Server ROUTER started at {url_client}")
+    logger.info("ZMQ Server ROUTER started at %s", url_client)
     # Socket to talk to workers
     workers = context.socket(zmq.DEALER)
     workers.bind(url_worker)
-    logger.info(f"ZMQ Worker DEALER started at {url_worker}")
-    tasks = [asyncio.create_task(worker_routine(url_worker, app, context, i)) for i in range(5)]
+    logger.info("ZMQ Worker DEALER started at %s", url_worker)
+    tasks = [
+        asyncio.create_task(worker_routine(url_worker, app, context, i))
+        for i in range(5)
+    ]
     proxy_task = asyncio.to_thread(zmq.proxy, clients, workers)
     try:
         await asyncio.gather(*tasks, proxy_task)
     except KeyboardInterrupt:
@@ -110,6 +114,7 @@ async def serve_zmq(arg, zmq_server_port: int, app: FastAPI) -> None:
         workers.close()
         context.destroy(linger=0)


 def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None:
     """Adds handlers for fatal errors that should crash the server"""

View File

@@ -1031,7 +1031,8 @@ async def run_server(args, **uvicorn_kwargs) -> None:
     zmq_server_port = args.zmq_server_port
     if zmq_server_port is not None:
-        logger.info("asyncio.create_task Starting ZMQ server at port %d", zmq_server_port)
+        logger.info("asyncio.create_task Starting ZMQ server at port %d",
+                    zmq_server_port)
         asyncio.create_task(serve_zmq(args, zmq_server_port, app))

     shutdown_task = await serve_http(
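
In the hunk above, `asyncio.create_task(serve_zmq(...))` starts the ZMQ front end in the background while `run_server` goes on to await the HTTP server. A toy sketch of that shape (both server stubs are placeholders, not vllm APIs):

```python
import asyncio


async def serve_zmq_stub() -> None:
    while True:  # placeholder for the ZMQ server loop
        await asyncio.sleep(3600)


async def serve_http_stub() -> None:
    await asyncio.sleep(1)  # placeholder for uvicorn's server.serve()


async def run_server_sketch() -> None:
    # Background task: scheduled immediately, not awaited here.
    zmq_task = asyncio.create_task(serve_zmq_stub())
    # Foreground: block until the HTTP server exits.
    await serve_http_stub()
    zmq_task.cancel()  # tear the background task down with the server


asyncio.run(run_server_sketch())
```

One caveat with the pattern as committed: asyncio holds only weak references to running tasks, so a task created without keeping the returned handle can in principle be garbage-collected mid-flight; holding a reference, as the sketch does, avoids that.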

View File

@@ -1,27 +1,27 @@
 # SPDX-License-Identifier: Apache-2.0
+import json
+import tempfile
+import traceback
+import uuid
+from typing import Optional
+
+import httpx
 import zmq
 import zmq.asyncio
-import tempfile
-import uuid
-import httpx
-import json
-import traceback
-from typing import Optional
 from fastapi import FastAPI, Request
 from fastapi.responses import JSONResponse

 # yapf conflicts with isort for this block
 # yapf: disable
 from vllm.entrypoints.openai.protocol import (CompletionRequest,
                                               CompletionResponse,
                                               ErrorResponse)
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
 from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
 from vllm.entrypoints.openai.serving_engine import OpenAIServing
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
-from vllm.entrypoints.openai.serving_tokenization import OpenAIServingTokenization
+from vllm.entrypoints.openai.serving_tokenization import (
+    OpenAIServingTokenization)
 from vllm.logger import init_logger

 prometheus_multiproc_dir: tempfile.TemporaryDirectory
@@ -62,49 +62,61 @@ async def worker_routine(worker_url: str, app: FastAPI,
         worker_identity = f"worker-{i}-{uuid.uuid4()}"
         socket.setsockopt(zmq.IDENTITY, worker_identity.encode())
         socket.connect(worker_url)
-        logger.info(f"{worker_identity} started at {worker_url}")
+        logger.info("%s started at %s", worker_identity, worker_url)
         while True:
             identity, url, header, body = await socket.recv_multipart()
-            logger.info(f"worker-{i} Received request identity: [{identity.decode()} ]")
+            logger.info("worker-%d Received request identity: [ %s ]",
+                        i, identity.decode())
             url_str = url.decode()
-            logger.info(f"worker-{i} Received request url: [{url_str} ]")
+            logger.info("worker-%d Received request url: [ %s ]", i, url_str)
             headers = bytes_to_headers(header)
-            logger.info(f"worker-{i} Received request headers: [{headers} ]")
+            logger.info("worker-%d Received request headers: [ %s ]",
+                        i, headers)
             body_json = json.loads(body.decode())
-            logger.info(f"worker-{i} Received request body: [{body_json} ]")
-            logger.info(f"worker-{i} Calling OpenAI API")
+            logger.info("worker-%d Received request body: [ %s ]",
+                        i, body_json)
+            logger.info("worker-%d Calling OpenAI API", i)
             completionRequest = CompletionRequest(**body_json)
             createRequest = create_request(url_str, "POST", body_json, headers)
-            generator = await create_completion(app, completionRequest, createRequest)
-            logger.info(f"worker-{i} Received response: [{generator} ]")
+            generator = await create_completion(app, completionRequest,
                                                 createRequest)
+            logger.info("worker-%d Received response: [ %s ]", i, generator)
             if isinstance(generator, ErrorResponse):
                 content = generator.model_dump_json()
                 context_json = json.loads(content)
                 context_json.append("status_code", generator.code)
-                await socket.send_multipart([identity, b"application/json", json.dumps(context_json).encode('utf-8')])
+                await socket.send_multipart([identity, b"application/json",
+                    json.dumps(context_json).encode('utf-8')])
             elif isinstance(generator, CompletionResponse):
-                await socket.send_multipart([identity, b"application/json", json.dumps(generator.model_dump()).encode('utf-8')])
+                await socket.send_multipart([identity, b"application/json",
+                    json.dumps(generator.model_dump()).encode('utf-8')])
             else:
                 async for chunk in generator:
-                    logger.info(f"worker-{i} Sending response chunk: [{chunk} ]")
-                    await socket.send_multipart([identity, b"text/event-stream", chunk.encode('utf-8')])
+                    logger.info("worker-%d Sending response chunk: [ %s ]",
                                 i, chunk)
+                    await socket.send_multipart([identity,
                                                  b"text/event-stream",
                                                  chunk.encode('utf-8')])
     except Exception as e:
-        logger.error(f"Error in worker routine: {e} worker-{i}")
+        logger.error("Error in worker routine: %s worker-%d", e, i)
         logger.error(traceback.format_exc())


-async def create_completion(app: FastAPI, request: CompletionRequest, raw_request: Request):
+async def create_completion(app: FastAPI, request: CompletionRequest,
                             raw_request: Request):
     handler = completion(app)
-    logger.info(f"zmq request post: {request}")
+    logger.info("zmq request post: %s", request)
     if handler is None:
         return base(app).create_error_response(
             message="The model does not support Completions API")
     generator = await handler.create_completion(request, raw_request)
-    logger.info(f"zmq request end post: {generator}")
+    logger.info("zmq request end post: %s", generator)
     return generator


-def create_request(path: str, method: str, body: dict, headers: httpx.Headers) -> Request:
+def create_request(path: str, method: str, body: dict,
                    headers: httpx.Headers) -> Request:
     scope = {
         'type': 'http',
         'http_version': '1.1',
@@ -120,10 +132,9 @@ def create_request(path: str, method: str, body: dict,
             'body': scope.get('body', b''),
         }

     async def send(message):
-        pass
+        pass

     return Request(scope, receive=receive, send=send)


 if __name__ == "__main__":
     print(bytes_to_headers(b'{"Content-Type": "application/json"}'))
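
`create_request` is the bridge that lets the ZMQ worker reuse the HTTP handlers: it fabricates a Starlette `Request` from a hand-built ASGI scope plus `receive`/`send` callables, so `create_completion` can be invoked as if an HTTP request had arrived. A runnable sketch of the same trick (scope keys and header values here are illustrative):

```python
import asyncio

from fastapi import Request


def make_request(path: str, body: bytes) -> Request:
    # Minimal ASGI scope; the real code copies method, path and headers through.
    scope = {
        'type': 'http',
        'http_version': '1.1',
        'method': 'POST',
        'path': path,
        'headers': [(b'content-type', b'application/json')],
        'query_string': b'',
    }

    async def receive():
        # Single-shot body, matching how Starlette drains the receive channel.
        return {'type': 'http.request', 'body': body, 'more_body': False}

    async def send(message):
        pass  # replies go back over ZMQ, not the ASGI send channel

    return Request(scope, receive=receive, send=send)


async def main():
    req = make_request('/v1/completions', b'{"prompt": "hi"}')
    print(req.method, req.url.path, await req.json())


asyncio.run(main())
```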