mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-11 20:22:18 +08:00
add /v1/completions stream support
Signed-off-by: clark <panf2333@gmail.com>
This commit is contained in:
parent
905424ed65
commit
bfde1688e7
@ -50,22 +50,22 @@ async def create_socket_pool(url: str, num_sockets: int, zmqctx: zmq.asyncio.Con
|
|||||||
return sockets
|
return sockets
|
||||||
|
|
||||||
# Select a socket from the pool and execute the task on it.
async def execute_task_async(route: str, headers: dict, request: dict, sockets: list):
    """Send one request over a pooled ZMQ socket and yield every reply frame.

    The request is sent as a three-frame multipart message
    ``[route, headers-as-JSON, body-as-JSON]``. Each reply is expected to be
    a two-frame message ``[content_type, payload]``; the decoded payload is
    yielded unchanged.

    Args:
        route: Path of the remote endpoint, sent as the first frame.
        headers: JSON-serialisable header mapping, sent as the second frame.
        request: JSON-serialisable request body, sent as the third frame.
        sockets: Socket pool. Despite the ``list`` annotation it is used
            through awaitable ``get()`` / ``put()`` calls (presumably an
            ``asyncio.Queue`` -- confirm against ``create_socket_pool``).

    Yields:
        The decoded payload of every reply frame, in arrival order.
    """
    sock = await sockets.get()
    try:
        requestBody = json.dumps(request)
        headersJson = json.dumps(headers)
        logger.info(f"Sending requestBody: {requestBody} to {route} with headers: {headersJson}")
        await sock.send_multipart([route.encode(), headersJson.encode(), requestBody.encode()])
        logger.info("Sent end")
        while True:
            logger.info("Waiting for reply")
            content_type, payload = await sock.recv_multipart()
            logger.info(f"Received result: {content_type}, {payload}")
            reply = payload.decode()
            yield reply
            if content_type != b"text/event-stream":
                # Non-streaming replies (e.g. application/json results or
                # errors) consist of a single frame and never carry a
                # "[DONE]" marker; waiting for one would block forever and
                # keep the socket out of the pool.
                logger.info("Received non-streaming reply, return socket")
                break
            if "[DONE]" in reply:
                logger.info("Received stop signal, return socket")
                break
    finally:
        # Always hand the socket back, even if the consumer stops early.
        await sockets.put(sock)
|
||||||
@ -73,16 +73,20 @@ async def execute_task_async(route: str, headers: Headers, request: dict, socket
|
|||||||
@app.post('/v1/connect/completions')
|
@app.post('/v1/connect/completions')
|
||||||
async def chat_completions(request: Request):
|
async def chat_completions(request: Request):
|
||||||
try:
|
try:
|
||||||
|
# Add the X-Request-Id header to the raw headers list
|
||||||
|
x_request_id = str(uuid.uuid4())
|
||||||
|
header = dict(request.headers)
|
||||||
|
if header.get("X-Request-Id") is None:
|
||||||
|
logger.info(f"add X-Request-Id: {x_request_id}")
|
||||||
|
header["X-Request-Id"] = x_request_id
|
||||||
original_request_data = await request.json()
|
original_request_data = await request.json()
|
||||||
header = request.headers
|
logger.info(f"Received request: {original_request_data} header: {header}")
|
||||||
logger.info(f"Received request: {original_request_data}")
|
|
||||||
prefill_request = original_request_data.copy()
|
prefill_request = original_request_data.copy()
|
||||||
# change max_tokens = 1 to let it only do prefill
|
# change max_tokens = 1 to let it only do prefill
|
||||||
prefill_request['max_tokens'] = 1
|
prefill_request['max_tokens'] = 1
|
||||||
route = "/v1/completions"
|
route = "/v1/completions"
|
||||||
# finish prefill
|
# finish prefill
|
||||||
async for x in execute_task_async(route, header, prefill_request, app.state.sockets_prefill):
|
async for x in execute_task_async(route, header, prefill_request, app.state.sockets_prefill):
|
||||||
logger.info(f"{x}")
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# return decode
|
# return decode
|
||||||
|
|||||||
@ -6,9 +6,13 @@ import socket
|
|||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
import zmq
|
||||||
|
import zmq.asyncio
|
||||||
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from fastapi import FastAPI, Request, Response
|
from fastapi import FastAPI, Request, Response
|
||||||
|
|
||||||
|
from vllm.entrypoints.openai.connect_worker import worker_routine
|
||||||
from vllm import envs
|
from vllm import envs
|
||||||
from vllm.engine.async_llm_engine import AsyncEngineDeadError
|
from vllm.engine.async_llm_engine import AsyncEngineDeadError
|
||||||
from vllm.engine.multiprocessing import MQEngineDeadError
|
from vllm.engine.multiprocessing import MQEngineDeadError
|
||||||
@ -74,6 +78,37 @@ async def serve_http(app: FastAPI,
|
|||||||
logger.info("Shutting down FastAPI HTTP server.")
|
logger.info("Shutting down FastAPI HTTP server.")
|
||||||
return server.shutdown()
|
return server.shutdown()
|
||||||
|
|
||||||
|
async def serve_zmq(arg, zmq_server_port: int, app: FastAPI,
                    num_workers: int = 5) -> None:
    """Run the ZMQ front end: a ROUTER/DEALER proxy feeding worker tasks.

    A ROUTER socket is bound on ``tcp://0.0.0.0:<zmq_server_port>`` for
    clients and a DEALER socket on ``inproc://workers`` for the workers.
    ``num_workers`` ``worker_routine`` tasks are started on the same context
    and ``zmq.proxy`` shuttles messages between the two sockets from a
    helper thread.

    Args:
        arg: Parsed server arguments; only logged here.
        zmq_server_port: TCP port the client-facing ROUTER socket binds to.
        app: FastAPI application handed to every worker.
        num_workers: Number of worker tasks to start (default 5).
    """
    logger.info(f"zmq Server start arg: {arg}, zmq_port: {zmq_server_port}")
    url_worker = "inproc://workers"
    url_client = f"tcp://0.0.0.0:{zmq_server_port}"
    # Prepare our context and sockets
    context = zmq.asyncio.Context()
    tasks = []

    try:
        # Binding happens inside the try block so that a failed bind (for
        # example, port already in use) still releases the context.
        # Socket to talk to clients
        clients = context.socket(zmq.ROUTER)
        clients.bind(url_client)
        logger.info(f"ZMQ Server ROUTER started at {url_client}")
        # Socket to talk to workers
        workers = context.socket(zmq.DEALER)
        workers.bind(url_worker)
        logger.info(f"ZMQ Worker DEALER started at {url_worker}")

        tasks = [
            asyncio.create_task(worker_routine(url_worker, app, context, i))
            for i in range(num_workers)
        ]
        proxy_task = asyncio.to_thread(zmq.proxy, clients, workers)
        await asyncio.gather(*tasks, proxy_task)
    except KeyboardInterrupt:
        logger.info("ZMQ Server interrupted")
    except zmq.ZMQError as e:
        logger.error("ZMQError: %s", e)
    finally:
        # Stop the workers first so they are no longer using their sockets.
        for task in tasks:
            task.cancel()
        if tasks:
            await asyncio.gather(*tasks, return_exceptions=True)
        # destroy() closes every socket created from this context (including
        # the workers' sockets) before terminating it; a bare term() would
        # block for as long as any socket is still open.
        context.destroy(linger=0)
|
||||||
|
|
||||||
def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None:
|
def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None:
|
||||||
"""Adds handlers for fatal errors that should crash the server"""
|
"""Adds handlers for fatal errors that should crash the server"""
|
||||||
|
|||||||
@ -36,7 +36,7 @@ from vllm.engine.multiprocessing.client import MQLLMEngineClient
|
|||||||
from vllm.engine.multiprocessing.engine import run_mp_engine
|
from vllm.engine.multiprocessing.engine import run_mp_engine
|
||||||
from vllm.engine.protocol import EngineClient
|
from vllm.engine.protocol import EngineClient
|
||||||
from vllm.entrypoints.chat_utils import load_chat_template
|
from vllm.entrypoints.chat_utils import load_chat_template
|
||||||
from vllm.entrypoints.launcher import serve_http
|
from vllm.entrypoints.launcher import serve_http, serve_zmq
|
||||||
from vllm.entrypoints.logger import RequestLogger
|
from vllm.entrypoints.logger import RequestLogger
|
||||||
from vllm.entrypoints.openai.cli_args import (make_arg_parser,
|
from vllm.entrypoints.openai.cli_args import (make_arg_parser,
|
||||||
validate_parsed_serve_args)
|
validate_parsed_serve_args)
|
||||||
@ -1029,6 +1029,11 @@ async def run_server(args, **uvicorn_kwargs) -> None:
|
|||||||
"s" if is_ssl else "", _listen_addr(sock_addr[0]),
|
"s" if is_ssl else "", _listen_addr(sock_addr[0]),
|
||||||
sock_addr[1])
|
sock_addr[1])
|
||||||
|
|
||||||
|
zmq_server_port = args.zmq_server_port
|
||||||
|
if zmq_server_port is not None:
|
||||||
|
logger.info("asyncio.create_task Starting ZMQ server at port %d", zmq_server_port)
|
||||||
|
asyncio.create_task(serve_zmq(args, zmq_server_port, app))
|
||||||
|
|
||||||
shutdown_task = await serve_http(
|
shutdown_task = await serve_http(
|
||||||
app,
|
app,
|
||||||
sock=sock,
|
sock=sock,
|
||||||
|
|||||||
129
vllm/entrypoints/openai/connect_worker.py
Normal file
129
vllm/entrypoints/openai/connect_worker.py
Normal file
@ -0,0 +1,129 @@
|
|||||||
|
import json
|
||||||
|
from typing import Optional
|
||||||
|
import zmq
|
||||||
|
import zmq.asyncio
|
||||||
|
import tempfile
|
||||||
|
import uuid
|
||||||
|
import httpx
|
||||||
|
import json
|
||||||
|
|
||||||
|
from fastapi import FastAPI, Request
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
|
# yapf conflicts with isort for this block
|
||||||
|
# yapf: disable
|
||||||
|
from vllm.entrypoints.openai.protocol import (CompletionRequest,
|
||||||
|
CompletionRequest,
|
||||||
|
CompletionResponse,
|
||||||
|
ErrorResponse)
|
||||||
|
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||||
|
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
|
||||||
|
from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
||||||
|
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||||
|
from vllm.entrypoints.openai.serving_tokenization import OpenAIServingTokenization
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
prometheus_multiproc_dir: tempfile.TemporaryDirectory
|
||||||
|
|
||||||
|
# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
|
||||||
|
logger = init_logger('vllm.entrypoints.openai.connect_worker')
|
||||||
|
|
||||||
|
def base(app: FastAPI) -> OpenAIServing:
    """Return a generic serving object usable for building error responses."""
    # Reuse the existing instance
    shared_handler = tokenization(app)
    return shared_handler
|
||||||
|
|
||||||
|
|
||||||
|
def models(app: FastAPI) -> OpenAIServingModels:
    """Return the ``openai_serving_models`` object stored on ``app.state``."""
    state = app.state
    return state.openai_serving_models
|
||||||
|
|
||||||
|
|
||||||
|
def chat(app: FastAPI) -> Optional[OpenAIServingChat]:
    """Return the ``openai_serving_chat`` object stored on ``app.state``."""
    state = app.state
    return state.openai_serving_chat
|
||||||
|
|
||||||
|
|
||||||
|
def completion(app: FastAPI) -> Optional[OpenAIServingCompletion]:
    """Return the ``openai_serving_completion`` object stored on ``app.state``."""
    state = app.state
    return state.openai_serving_completion
|
||||||
|
|
||||||
|
def tokenization(app: FastAPI) -> OpenAIServingTokenization:
    """Return the ``openai_serving_tokenization`` object stored on ``app.state``."""
    state = app.state
    return state.openai_serving_tokenization
|
||||||
|
|
||||||
|
|
||||||
|
def bytes_to_headers(bytes_data: bytes) -> httpx.Headers:
    """Decode a JSON-encoded header mapping into an ``httpx.Headers`` object."""
    decoded_text = bytes_data.decode()
    return httpx.Headers(json.loads(decoded_text))
|
||||||
|
|
||||||
|
async def worker_routine(worker_url: str, app: FastAPI,
                         context: zmq.asyncio.Context = None, i: int = 0):
    """Worker routine: serve completion requests received over a DEALER socket.

    Each incoming message is expected to have four frames,
    ``[identity, url, headers-as-JSON, body-as-JSON]``. The body is turned
    into a ``CompletionRequest`` and passed to ``create_completion``. Replies
    are sent back as ``[identity, content_type, payload]``:

    * ``ErrorResponse``      -> one ``application/json`` frame containing the
      error plus a ``status_code`` field,
    * ``CompletionResponse`` -> one ``application/json`` frame,
    * anything else          -> treated as an async generator of string
      chunks, each sent as a ``text/event-stream`` frame.

    Args:
        worker_url: Endpoint of the dispatcher the DEALER socket connects to.
        app: FastAPI application whose state holds the serving handlers.
        context: ZMQ asyncio context used to create the socket.
        i: Worker index, used in the socket identity and in log messages.
    """
    socket = None
    try:
        # Socket to talk to dispatcher
        socket = context.socket(zmq.DEALER)
        worker_identity = f"worker-{i}-{uuid.uuid4()}"
        socket.setsockopt(zmq.IDENTITY, worker_identity.encode())
        socket.connect(worker_url)
        logger.info(f"{worker_identity} started at {worker_url}")
        while True:
            # A failure of recv itself (e.g. context terminated) is fatal for
            # the worker and is handled by the outer except block.
            frames = await socket.recv_multipart()
            try:
                identity, url, header, body = frames
                logger.info(f"worker-{i} Received request identity: [{identity} ]")
                url = url.decode()
                logger.info(f"worker-{i} Received request url: [{url} ]")
                header = bytes_to_headers(header)
                logger.info(f"worker-{i} Received request headers: [{header} ]")
                body = json.loads(body.decode())
                logger.info(f"worker-{i} Received request body: [{body} ]")
                logger.info(f"worker-{i} Calling OpenAI API")
                completionRequest = CompletionRequest(**body)
                createRequest = create_request(url, "POST", body, header)
                generator = await create_completion(app, completionRequest, createRequest)
                logger.info(f"worker-{i} Received response: [{generator} ]")
                if isinstance(generator, ErrorResponse):
                    # Use a dedicated name: rebinding `context` would clobber
                    # the ZMQ context parameter. The payload is a dict, so the
                    # status code is stored by key.
                    payload = json.loads(generator.model_dump_json())
                    payload["status_code"] = generator.code
                    await socket.send_multipart(
                        [identity, b"application/json", json.dumps(payload).encode()])
                elif isinstance(generator, CompletionResponse):
                    # render() is an instance method; build the response and
                    # take its already-rendered body bytes.
                    rendered = JSONResponse(content=generator.model_dump()).body
                    await socket.send_multipart(
                        [identity, b"application/json", rendered])
                else:
                    async for chunk in generator:
                        logger.info(f"worker-{i} Sending response chunk: [{chunk} ]")
                        await socket.send_multipart(
                            [identity, b"text/event-stream", chunk.encode()])
            except Exception as e:
                # One bad request must not take the whole worker down; log it
                # and keep serving.
                logger.error(f"Error in worker routine: {e} worker-{i}")
                logger.error(traceback.format_exc())
    except Exception as e:
        logger.error(f"Error in worker routine: {e} worker-{i}")
        logger.error(traceback.format_exc())
    finally:
        # Release the socket so the owning context can terminate cleanly.
        if socket is not None:
            socket.close(linger=0)
|
||||||
|
|
||||||
|
async def create_completion(app: FastAPI, request: CompletionRequest, raw_request: Request):
    """Run a completion request through the app's completion handler.

    Returns whatever the handler's ``create_completion`` returns, or an error
    response built by ``base(app)`` when no completion handler is configured.
    """
    completion_handler = completion(app)
    logger.info(f"zmq requset post: {request}")
    if completion_handler is not None:
        result = await completion_handler.create_completion(request, raw_request)
        logger.info(f"zmq requset end post: {result}")
        return result

    return base(app).create_error_response(
        message="The model does not support Completions API")
|
||||||
|
|
||||||
|
|
||||||
|
def create_request(path: str, method: str, body: bytes, headers: dict = None):
    """Build a minimal in-memory ``Request`` for handlers that expect one.

    Args:
        path: Request path placed in the ASGI scope.
        method: HTTP method placed in the ASGI scope.
        body: Request payload. ``bytes``/``bytearray`` are used verbatim;
            any other truthy value is serialised with ``json.dumps``. A falsy
            value yields an empty body.
        headers: Optional mapping-like object exposing ``items()``.

    Returns:
        A ``Request`` whose ``receive`` callable returns the body in a single
        ``http.request`` message and whose ``send`` callable is a no-op.
    """

    def _as_bytes(value) -> bytes:
        # ASGI transports header names and values as byte strings.
        if isinstance(value, (bytes, bytearray)):
            return bytes(value)
        return str(value).encode('latin-1')

    # The ASGI scope carries headers as (name, value) byte pairs with
    # lowercased names; plain str pairs are not matched by header lookups.
    raw_headers = []
    if headers:
        raw_headers = [(_as_bytes(name).lower(), _as_bytes(value))
                       for name, value in headers.items()]

    scope = {
        'type': 'http',
        'http_version': '1.1',
        'method': method,
        'path': path,
        'headers': raw_headers,
    }
    if body:
        if isinstance(body, (bytes, bytearray)):
            # Already encoded; json.dumps would raise TypeError on bytes.
            scope['body'] = bytes(body)
        else:
            scope['body'] = json.dumps(body).encode('utf-8')

    async def receive():
        return {
            'type': 'http.request',
            'body': scope.get('body', b''),
        }

    async def send(message):
        # Nothing is ever written back through this synthetic request.
        pass

    return Request(scope, receive=receive, send=send)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Manual smoke check: parse a sample JSON header blob and print the result.
    sample_headers = b'{"Content-Type": "application/json"}'
    print(bytes_to_headers(sample_headers))
|
||||||
|
|
||||||
Loading…
x
Reference in New Issue
Block a user