Repository: https://git.datalinker.icu/vllm-project/vllm.git
commit 2c31e4c3ea (parent 187f112ccd)

Run yapf and ruff

Signed-off-by: clark <panf2333@gmail.com>
@@ -1,31 +1,34 @@
 import asyncio
 import json

 import aiohttp

 # test connect completions we assume prefill and decode are on the same node
-# 1. node:vllm serve facebook/opt-125m --port 7001 --zmq-server-port 7010 --chat-template ~/vllm/examples/template_chatglm2.jinja
+# 1. node:vllm serve facebook/opt-125m --port 7001 --zmq-server-port 7010 \
+#    --chat-template ~/vllm/examples/template_chatglm2.jinja
 # 2. vllm connect --prefill-addr nodeIp:7010 --decode-addr nodeIp:7010
 # 3. python test_request.py


 async def test_connect_completions(session):
     try:
         base_url = "http://localhost:8001/v1/connect/completions"
         body = {
             "temperature": 0.5,
             "top_p": 0.9,
             "max_tokens": 150,
             "frequency_penalty": 1.3,
             "presence_penalty": 0.2,
             "repetition_penalty": 1.2,
             "model": "facebook/opt-125m",
             "prompt": "Can you introduce vllm?",
             "stream": True,
             "stream_options": {
                 "include_usage": True
-        }}
-        print(f"Sending request to {base_url}, body {body}")
-        async with session.post(base_url, json= body) as response:
+            }
+        }
+        print(f"Sending request to {base_url}, body {body}")
+        async with session.post(base_url, json=body) as response:
             print(response.status)
             print(response.headers)
             responseText = ""
@@ -40,13 +43,18 @@ async def test_connect_completions(session):
                         print(f"Error decoding chunk: {chunk!r}")
                 else:
                     # Print the headers and JSON response
-                    print(f"Unexpected Transfer-Encoding: {transfer_encoding} {response.headers} {await response.json()}")
+                    print("Unexpected Transfer-Encoding: {} {} {}".format(
+                        transfer_encoding, response.headers, await
+                        response.json()))
             else:
                 print(f"Request failed with status code {response.status}")
-            print(f"baseurl {base_url} response data {extract_data(responseText)}")
+            print(
+                f"baseurl {base_url} response data {extract_data(responseText)}"
+            )
     except aiohttp.ClientError as e:
         print(f"Error: {e}")


 def extract_data(responseText):
     reply = ""
     for data in responseText.split("\n\n"):
@@ -66,7 +74,7 @@ def extract_data(responseText):
     return reply


 async def main():
     async with aiohttp.ClientSession() as session:
         tasks = []
@@ -76,4 +84,3 @@ async def main():


 asyncio.run(main())
-
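The hunks above elide the body of extract_data. For orientation, here is a minimal sketch of an accumulator with that shape, assuming the standard OpenAI-style SSE framing that /v1/completions streams (events prefixed with "data: ", separated by blank lines, terminated by "[DONE]"); the choices[0].text field is that protocol's shape, not something shown in this diff:

```python
import json


def extract_data(responseText: str) -> str:
    # Sketch only: walk SSE events, skip non-data frames, stop at [DONE],
    # and concatenate the streamed completion text.
    reply = ""
    for data in responseText.split("\n\n"):
        line = data.strip()
        if not line.startswith("data:"):
            continue
        payload = line[len("data:"):].strip()
        if payload == "[DONE]":
            break
        chunk = json.loads(payload)
        choices = chunk.get("choices") or []
        if choices:
            reply += choices[0].get("text", "")
    return reply
```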
@@ -1,15 +1,18 @@
 # SPDX-License-Identifier: Apache-2.0

 import json
+import signal
+import uuid
+# from fastapi.lifespan import Lifespan
+from asyncio import Queue
+from contextlib import asynccontextmanager

 import uvicorn
 import zmq
 import zmq.asyncio
 from fastapi import FastAPI, Request
 from fastapi.responses import StreamingResponse
-from contextlib import asynccontextmanager
-# from fastapi.lifespan import Lifespan
-from asyncio import Queue
-import uuid
-import signal

 from vllm.logger import init_logger

 # default prefill and decode url
@@ -21,55 +24,69 @@ socket_decode_num = 5
 # Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
 logger = init_logger('vllm.entrypoints.connect')


 @asynccontextmanager
 async def lifespan(app: FastAPI):
     # create socket pool with prefill and decode
     logger.info("start create_socket_pool")
     app.state.zmqctx = zmq.asyncio.Context()
-    app.state.sockets_prefill = await create_socket_pool(app.state.prefill_addr, socket_prefill_num, zmqctx=app.state.zmqctx)
+    app.state.sockets_prefill = await create_socket_pool(
+        app.state.prefill_addr, socket_prefill_num, zmqctx=app.state.zmqctx)
     logger.info("success create_socket_pool sockets_prefill")
-    app.state.sockets_decode = await create_socket_pool(app.state.decode_addr, socket_decode_num, zmqctx=app.state.zmqctx)
+    app.state.sockets_decode = await create_socket_pool(
+        app.state.decode_addr, socket_decode_num, zmqctx=app.state.zmqctx)
     logger.info("success create_socket_pool sockets_decode")
     yield
     ## close zmq context
     logger.info("term zmqctx")
     app.state.zmqctx.destroy(linger=0)


 app = FastAPI(lifespan=lifespan)


 # create async socket pool with num_sockets use ZMQ_DEALER
-async def create_socket_pool(url: str, num_sockets: int, zmqctx: zmq.asyncio.Context) -> Queue:
-    sockets = Queue()
+async def create_socket_pool(url: str, num_sockets: int,
+                             zmqctx: zmq.asyncio.Context) -> Queue:
+    sockets: Queue = Queue()
     for i in range(num_sockets):
         sock = zmqctx.socket(zmq.DEALER)
         identity = f"worker-{i}-{uuid.uuid4()}"
         sock.setsockopt(zmq.IDENTITY, identity.encode())
         sock.connect(url)
-        logger.info(f"{identity} started at {url} {sockets.qsize()}")
+        logger.info("%s started at %s with queue size %s", identity, url,
+                    sockets.qsize())
         await sockets.put(sock)
     return sockets


 # select a socket and execute task
-async def execute_task_async(route: str, headers: dict, request: dict, sockets: Queue):
+async def execute_task_async(route: str, headers: dict, request: dict,
+                             sockets: Queue):
     sock = await sockets.get()
     try:
         requestBody = json.dumps(request)
         headersJson = json.dumps(headers)
-        logger.info(f"Sending requestBody: {requestBody} to {route} with headers: {headersJson}")
-        await sock.send_multipart([route.encode(), headersJson.encode(), requestBody.encode()])
-        logger.info(f"Sent end")
+        logger.info("Sending requestBody: %s to %s with headers: %s",
+                    requestBody, route, headersJson)
+        await sock.send_multipart(
+            [route.encode(),
+             headersJson.encode(),
+             requestBody.encode()])
+        logger.info("Sent end")
         while True:
-            logger.info(f"Waiting for reply")
+            logger.info("Waiting for reply")
             [contentType, reply] = await sock.recv_multipart()
-            logger.info(f"Received result: {contentType}, {reply}")
+            logger.info("Received result: %s, %s", contentType, reply)
             reply = reply.decode()
             yield f"{reply}"
             if "[DONE]" in reply:
-                logger.info(f"Received stop signal, return socket")
+                logger.info("Received stop signal, return socket")
                 break
     finally:
         await sockets.put(sock)


 @app.post('/v1/connect/completions')
 async def chat_completions(request: Request):
     try:
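Why a Queue of DEALER sockets: a single DEALER socket interleaves replies from concurrent requests with no per-caller tagging, so execute_task_async checks a socket out for the whole request/stream/[DONE] exchange and only returns it afterwards. A stripped-down sketch of that checkout/return discipline (a hypothetical helper, not part of the commit):

```python
from asyncio import Queue


async def with_socket(sockets: Queue, frames: list):
    # Sketch: hold one conversation per socket at a time.
    sock = await sockets.get()          # block until a pooled socket is free
    try:
        await sock.send_multipart(frames)
        while True:
            reply = await sock.recv_multipart()
            yield reply
            if b"[DONE]" in reply[-1]:  # server marks end of stream
                break
    finally:
        await sockets.put(sock)         # always return the socket to the pool
```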
@@ -77,21 +94,26 @@ async def chat_completions(request: Request):
         x_request_id = str(uuid.uuid4())
         header = dict(request.headers)
         if header.get("X-Request-Id") is None:
-            logger.info(f"add X-Request-Id: {x_request_id}")
+            logger.info("add X-Request-Id: %s", x_request_id)
             header["X-Request-Id"] = x_request_id
         original_request_data = await request.json()
-        logger.info(f"Received request: {original_request_data} header: {header}")
+        logger.info("Received request: %s header: %s", original_request_data,
+                    header)
         prefill_request = original_request_data.copy()
         # change max_tokens = 1 to let it only do prefill
         prefill_request['max_tokens'] = 1
         route = "/v1/completions"
         # finish prefill
-        async for _ in execute_task_async(route, header, prefill_request, app.state.sockets_prefill):
+        async for _ in execute_task_async(route, header, prefill_request,
+                                          app.state.sockets_prefill):
             continue

         # return decode
-        return StreamingResponse(execute_task_async(route, header,original_request_data, app.state.sockets_decode), media_type="text/event-stream")
-
+        return StreamingResponse(execute_task_async(route, header,
+                                                    original_request_data,
+                                                    app.state.sockets_decode),
+                                 media_type="text/event-stream")

     except Exception as e:
         import sys
         import traceback
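Note on the prefill pass above: forcing max_tokens = 1 makes the prefill instance run the full forward pass over the prompt, which is what materializes the KV cache that a disaggregated deployment hands to the decode instance, while emitting only a single throwaway token; the subsequent decode call then streams the real completion from that cache.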
@@ -99,16 +121,20 @@ async def chat_completions(request: Request):
         logger.error("Error occurred in disagg prefill proxy server")
         logger.error(e)
         logger.error("".join(traceback.format_exception(*exc_info)))


 async def run_disagg_connector(args, **uvicorn_kwargs) -> None:
-    logger.info(f"vLLM Disaggregate Connector start {args} {uvicorn_kwargs}")
+    logger.info("vLLM Disaggregate Connector start %s %s", args,
+                uvicorn_kwargs)
     logger.info(args.prefill_addr)

-    app.state.prefill_addr = f"tcp://{args.prefill_addr}" if args.prefill_addr is not None else url_prefill
-    app.state.decode_addr = f"tcp://{args.decode_addr}" if args.decode_addr is not None else url_decode
-    logger.info(f"start connect url_prefill: {app.state.prefill_addr} url_decode: {app.state.decode_addr}")
-
+    app.state.prefill_addr = (f"tcp://{args.prefill_addr}" if args.prefill_addr
+                              is not None else url_prefill)
+    app.state.decode_addr = (f"tcp://{args.decode_addr}"
+                             if args.decode_addr is not None else url_decode)
+    logger.info("start connect url_prefill: %s url_decode: %s",
+                app.state.prefill_addr, app.state.decode_addr)

     def signal_handler(*_) -> None:
         # Interrupt server on sigterm while initializing
         raise KeyboardInterrupt("terminated")
@@ -119,8 +145,7 @@ async def run_disagg_connector(args, **uvicorn_kwargs) -> None:
     server = uvicorn.Server(config)
     await server.serve()


 if __name__ == "__main__":
     # url = 'tcp://127.0.0.1:5555'
-    uvicorn.run(app, host="0.0.0.0", port=8001)
+    uvicorn.run(app, host="0.0.0.0", port=8001)
@@ -6,17 +6,16 @@ import socket
 from http import HTTPStatus
 from typing import Any, Optional

+import uvicorn
 import zmq
 import zmq.asyncio
-
-import uvicorn
 from fastapi import FastAPI, Request, Response

-from vllm.entrypoints.openai.connect_worker import worker_routine
 from vllm import envs
 from vllm.engine.async_llm_engine import AsyncEngineDeadError
 from vllm.engine.multiprocessing import MQEngineDeadError
+from vllm.entrypoints.openai.connect_worker import worker_routine
 from vllm.entrypoints.ssl import SSLCertRefresher
 from vllm.logger import init_logger
 from vllm.utils import find_process_using_port
@@ -76,11 +75,13 @@ async def serve_http(app: FastAPI,
                 "port %s is used by process %s launched with command:\n%s",
                 port, process, " ".join(process.cmdline()))
         logger.info("Shutting down FastAPI HTTP server.")
     return server.shutdown()


 async def serve_zmq(arg, zmq_server_port: int, app: FastAPI) -> None:
     """Server routine"""
-    logger.info(f"zmq Server start arg: {arg}, zmq_port: {zmq_server_port}")
+    logger.info("zmq Server start arg: %s, zmq_server_port: %d", arg,
+                zmq_server_port)
     url_worker = "inproc://workers"
     url_client = f"tcp://0.0.0.0:{zmq_server_port}"
     # Prepare our context and sockets
@@ -89,15 +90,18 @@ async def serve_zmq(arg, zmq_server_port: int, app: FastAPI) -> None:
     # Socket to talk to clients
     clients = context.socket(zmq.ROUTER)
     clients.bind(url_client)
-    logger.info(f"ZMQ Server ROUTER started at {url_client}")
+    logger.info("ZMQ Server ROUTER started at %s", url_client)
     # Socket to talk to workers
     workers = context.socket(zmq.DEALER)
     workers.bind(url_worker)
-    logger.info(f"ZMQ Worker DEALER started at {url_worker}")
+    logger.info("ZMQ Worker DEALER started at %s", url_worker)

-    tasks = [asyncio.create_task(worker_routine(url_worker, app, context, i)) for i in range(5)]
-    proxy_task = asyncio.to_thread(zmq.proxy, clients, workers)
-
+    tasks = [
+        asyncio.create_task(worker_routine(url_worker, app, context, i))
+        for i in range(5)
+    ]
+    proxy_task = asyncio.to_thread(zmq.proxy, clients, workers)

     try:
         await asyncio.gather(*tasks, proxy_task)
     except KeyboardInterrupt:
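serve_zmq is the classic ZeroMQ multithreaded-server topology: a ROUTER socket faces external clients, a DEALER socket faces in-process workers, and zmq.proxy shuttles frames between them (wrapped in asyncio.to_thread above because zmq.proxy blocks). A self-contained synchronous sketch of the same layout (the port and echo worker are illustrative only):

```python
import threading

import zmq


def worker(ctx: zmq.Context, url: str) -> None:
    # REP worker: each request gets exactly one reply.
    sock = ctx.socket(zmq.REP)
    sock.connect(url)
    while True:
        sock.send(b"echo: " + sock.recv())


ctx = zmq.Context()
clients = ctx.socket(zmq.ROUTER)   # external-facing frontend
clients.bind("tcp://127.0.0.1:7010")
workers = ctx.socket(zmq.DEALER)   # internal-facing backend
workers.bind("inproc://workers")
for _ in range(5):
    threading.Thread(target=worker, args=(ctx, "inproc://workers"),
                     daemon=True).start()
zmq.proxy(clients, workers)        # blocks, moving frames both ways
```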
@@ -110,6 +114,7 @@ async def serve_zmq(arg, zmq_server_port: int, app: FastAPI) -> None:
     workers.close()
     context.destroy(linger=0)

+
 def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None:
     """Adds handlers for fatal errors that should crash the server"""
@@ -1031,7 +1031,8 @@ async def run_server(args, **uvicorn_kwargs) -> None:

     zmq_server_port = args.zmq_server_port
     if zmq_server_port is not None:
-        logger.info("asyncio.create_task Starting ZMQ server at port %d", zmq_server_port)
+        logger.info("asyncio.create_task Starting ZMQ server at port %d",
+                    zmq_server_port)
         asyncio.create_task(serve_zmq(args, zmq_server_port, app))

     shutdown_task = await serve_http(
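One caveat, not addressed by this formatting commit: asyncio holds only weak references to tasks, so the bare asyncio.create_task above can in principle be garbage-collected while the ZMQ server is still running. The documented idiom is to keep a strong reference; a hedged sketch reusing the names from the hunk:

```python
# Sketch: retain the task handle so the ZMQ server task cannot be collected.
background_tasks: set = set()
task = asyncio.create_task(serve_zmq(args, zmq_server_port, app))
background_tasks.add(task)
task.add_done_callback(background_tasks.discard)
```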
@@ -1,27 +1,27 @@
 # SPDX-License-Identifier: Apache-2.0

+import json
+import tempfile
+import traceback
+import uuid
+from typing import Optional
+
+import httpx
 import zmq
 import zmq.asyncio
-import tempfile
-import uuid
-import httpx
-import json
-import traceback
-
-from typing import Optional
 from fastapi import FastAPI, Request
 from fastapi.responses import JSONResponse

 # yapf conflicts with isort for this block
 # yapf: disable
 from vllm.entrypoints.openai.protocol import (CompletionRequest,
                                               CompletionResponse,
                                               ErrorResponse)
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
 from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
 from vllm.entrypoints.openai.serving_engine import OpenAIServing
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
-from vllm.entrypoints.openai.serving_tokenization import OpenAIServingTokenization
+from vllm.entrypoints.openai.serving_tokenization import (
+    OpenAIServingTokenization)
 from vllm.logger import init_logger

 prometheus_multiproc_dir: tempfile.TemporaryDirectory
@@ -62,49 +62,61 @@ async def worker_routine(worker_url: str, app: FastAPI,
         worker_identity = f"worker-{i}-{uuid.uuid4()}"
         socket.setsockopt(zmq.IDENTITY, worker_identity.encode())
         socket.connect(worker_url)
-        logger.info(f"{worker_identity} started at {worker_url}")
+        logger.info("%s started at %s", worker_identity, worker_url)
         while True:
             identity, url, header, body = await socket.recv_multipart()
-            logger.info(f"worker-{i} Received request identity: [{identity.decode()} ]")
+            logger.info("worker-%d Received request identity: [ %s ]",
+                        i, identity.decode())
             url_str = url.decode()
-            logger.info(f"worker-{i} Received request url: [{url_str} ]")
+            logger.info("worker-%d Received request url: [ %s ]",
+                        i, url_str)
             headers = bytes_to_headers(header)
-            logger.info(f"worker-{i} Received request headers: [{headers} ]")
+            logger.info("worker-%d Received request headers: [ %s ]",
+                        i, headers)
             body_json = json.loads(body.decode())
-            logger.info(f"worker-{i} Received request body: [{body_json} ]")
-            logger.info(f"worker-{i} Calling OpenAI API")
+            logger.info("worker-%d Received request body: [ %s ]",
+                        i, body_json)
+            logger.info("worker-%d Calling OpenAI API", i)
             completionRequest = CompletionRequest(**body_json)
             createRequest = create_request(url_str, "POST", body_json, headers)
-            generator = await create_completion(app, completionRequest, createRequest)
-            logger.info(f"worker-{i} Received response: [{generator} ]")
+            generator = await create_completion(app, completionRequest,
+                                                createRequest)
+            logger.info("worker-%d Received response: [ %s ]", i, generator)
             if isinstance(generator, ErrorResponse):
                 content = generator.model_dump_json()
                 context_json = json.loads(content)
                 context_json.append("status_code", generator.code)
-                await socket.send_multipart([identity, b"application/json", json.dumps(context_json).encode('utf-8')])
+                await socket.send_multipart([identity, b"application/json",
+                                             json.dumps(context_json).encode('utf-8')])
             elif isinstance(generator, CompletionResponse):
-                await socket.send_multipart([identity, b"application/json", json.dumps(generator.model_dump()).encode('utf-8')])
+                await socket.send_multipart([identity, b"application/json",
+                                             json.dumps(generator.model_dump()).encode('utf-8')])
             else:
                 async for chunk in generator:
-                    logger.info(f"worker-{i} Sending response chunk: [{chunk} ]")
-                    await socket.send_multipart([identity, b"text/event-stream", chunk.encode('utf-8')])
+                    logger.info("worker-%d Sending response chunk: [ %s ]",
+                                i, chunk)
+                    await socket.send_multipart([identity,
+                                                 b"text/event-stream",
+                                                 chunk.encode('utf-8')])
     except Exception as e:
-        logger.error(f"Error in worker routine: {e} worker-{i}")
+        logger.error("Error in worker routine: %s worker-%d", e, i)
         logger.error(traceback.format_exc())
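A caveat in the ErrorResponse branch above (pre-existing code, untouched by this formatting commit): json.loads returns a dict, and dicts have no .append method, so context_json.append("status_code", generator.code) would raise AttributeError at runtime; presumably a key assignment was intended. A corrected sketch of that branch, reusing the hunk's names:

```python
# Sketch of the presumably intended error path: attach the status code by
# key assignment before framing the JSON error reply.
content = generator.model_dump_json()
context_json = json.loads(content)
context_json["status_code"] = generator.code
await socket.send_multipart([identity, b"application/json",
                             json.dumps(context_json).encode('utf-8')])
```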

-async def create_completion(app: FastAPI, request: CompletionRequest, raw_request: Request):
+async def create_completion(app: FastAPI, request: CompletionRequest,
+                            raw_request: Request):
     handler = completion(app)
-    logger.info(f"zmq request post: {request}")
+    logger.info("zmq request post: %s", request)
     if handler is None:
         return base(app).create_error_response(
             message="The model does not support Completions API")

     generator = await handler.create_completion(request, raw_request)
-    logger.info(f"zmq request end post: {generator}")
+    logger.info("zmq request end post: %s", generator)
     return generator


-def create_request(path: str, method: str, body: dict, headers: httpx.Headers) -> Request:
+def create_request(path: str, method: str, body: dict,
+                   headers: httpx.Headers) -> Request:
     scope = {
         'type': 'http',
         'http_version': '1.1',
@@ -120,10 +132,9 @@ def create_request(path: str, method: str, body: dict, headers: httpx.Headers) -
         'body': scope.get('body', b''),
     }

     async def send(message):
         pass

     return Request(scope, receive=receive, send=send)


 if __name__ == "__main__":
     print(bytes_to_headers(b'{"Content-Type": "application/json"}'))
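bytes_to_headers itself lies outside the hunks shown here. Given that the __main__ check above feeds it a JSON byte string and worker_routine passes its result into create_request as httpx.Headers, a plausible sketch (an assumption, not the repository's code):

```python
import json

import httpx


def bytes_to_headers(header_bytes: bytes) -> httpx.Headers:
    # Decode the JSON-encoded header frame back into case-insensitive
    # httpx.Headers, matching what create_request expects.
    return httpx.Headers(json.loads(header_bytes.decode()))
```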