diff --git a/benchmarks/disagg_benchmarks/zmq/test_connect_server1.py b/benchmarks/disagg_benchmarks/zmq/test_connect_server1.py index 8747acfea0cce..43a6b5c1c17f0 100644 --- a/benchmarks/disagg_benchmarks/zmq/test_connect_server1.py +++ b/benchmarks/disagg_benchmarks/zmq/test_connect_server1.py @@ -14,8 +14,11 @@ async def worker_routine(worker_url: str, socket.connect(worker_url) print(f"worker-{i} {worker_url} started") while True: - identity, string = await socket.recv_multipart() - print(f"worker-{i} Received request: [{identity} {string} ]") + identity, url, headers, string = await socket.recv_multipart() + print(f"worker-{i} Received request identity: [{identity} ]") + print(f"worker-{i} Received request url: [{url} ]") + print(f"worker-{i} Received request headers: [{headers} ]") + print(f"worker-{i} Received request string: [{string} ]") streamreply = ['{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-4o-mini", "system_fingerprint": "fp_44709d6fcb", "choices":[{"index":0,"delta":{"role":"assistant","content":""},"logprobs":null,"finish_reason":null}]}', '{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-4o-mini", "system_fingerprint": "fp_44709d6fcb", "choices":[{"index":0,"delta":{"content":"Hello"},"logprobs":null,"finish_reason":null}]}', '{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-4o-mini", "system_fingerprint": "fp_44709d6fcb", "choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]}' diff --git a/benchmarks/disagg_benchmarks/zmq/test_connect_server2.py b/benchmarks/disagg_benchmarks/zmq/test_connect_server2.py index 6b7ed4da6554d..f9f4d1a5ce7ee 100644 --- a/benchmarks/disagg_benchmarks/zmq/test_connect_server2.py +++ b/benchmarks/disagg_benchmarks/zmq/test_connect_server2.py @@ -14,8 +14,11 @@ async def worker_routine(worker_url: str, socket.connect(worker_url) print(f"worker-{i} {worker_url} started") while True: - identity, string = await socket.recv_multipart() - print(f"worker-{i} Received request: [{identity} {string} ]") + identity, url, headers, string = await socket.recv_multipart() + print(f"worker-{i} Received request identity: [{identity} ]") + print(f"worker-{i} Received request url: [{url} ]") + print(f"worker-{i} Received request headers: [{headers} ]") + print(f"worker-{i} Received request string: [{string} ]") streamreply = ['{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-4o-mini", "system_fingerprint": "fp_44709d6fcb", "choices":[{"index":0,"delta":{"role":"assistant","content":""},"logprobs":null,"finish_reason":null}]}', '{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-4o-mini", "system_fingerprint": "fp_44709d6fcb", "choices":[{"index":0,"delta":{"content":"Hello"},"logprobs":null,"finish_reason":null}]}', '{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-4o-mini", "system_fingerprint": "fp_44709d6fcb", "choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]}' diff --git a/benchmarks/disagg_benchmarks/zmq/test_request.py b/benchmarks/disagg_benchmarks/zmq/test_request.py index 59af34d162b6e..f5c247d7d3e97 100644 --- a/benchmarks/disagg_benchmarks/zmq/test_request.py +++ b/benchmarks/disagg_benchmarks/zmq/test_request.py @@ -24,7 +24,7 @@ async def test_connect(session): "stream_options": { "include_usage": True } -}) as response: +}, headers={"Content-Type": "application/json"}) as response: print(response.status) if response.status == 200: transfer_encoding = response.headers.get('Transfer-Encoding') diff --git a/vllm/entrypoints/connect.py b/vllm/entrypoints/connect.py index a5b89aba26108..a7bb5a9daccf7 100644 --- a/vllm/entrypoints/connect.py +++ b/vllm/entrypoints/connect.py @@ -3,6 +3,7 @@ import uvicorn import zmq import zmq.asyncio from fastapi import FastAPI, Request +from starlette.datastructures import Headers from fastapi.responses import StreamingResponse from contextlib import asynccontextmanager # from fastapi.lifespan import Lifespan @@ -20,7 +21,6 @@ socket_decode_num = 5 # Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765) logger = init_logger('vllm.entrypoints.connect') - @asynccontextmanager async def lifespan(app: FastAPI): # create scoket pool with prefill and decode @@ -50,12 +50,13 @@ async def create_socket_pool(url: str, num_sockets: int, zmqctx: zmq.asyncio.Con return sockets # select a scoket and execute task -async def execute_task_async(request: dict, sockets: list): +async def execute_task_async(route: str, headers: Headers, request: dict, sockets: list): sock = await sockets.get() try: requestBody = json.dumps(request) - logger.info(f"Sending requestBody: {requestBody}") - await sock.send(requestBody.encode()) + headersJson = json.dumps(dict(headers)) + logger.info(f"Sending requestBody: {requestBody} to {route} with headers: {headersJson}") + await sock.send_multipart([route.encode(), headersJson.encode(), requestBody.encode()]) logger.info(f"Sent end") while True: logger.info(f"Waiting for reply") @@ -73,18 +74,19 @@ async def execute_task_async(request: dict, sockets: list): async def chat_completions(request: Request): try: original_request_data = await request.json() + header = request.headers logger.info(f"Received request: {original_request_data}") prefill_request = original_request_data.copy() # change max_tokens = 1 to let it only do prefill prefill_request['max_tokens'] = 1 - + route = "/v1/completions" # finish prefill - async for x in execute_task_async(prefill_request, app.state.sockets_prefill): + async for x in execute_task_async(route, header, prefill_request, app.state.sockets_prefill): logger.info(f"{x}") continue # return decode - return StreamingResponse(execute_task_async(original_request_data, app.state.sockets_decode), media_type="text/event-stream") + return StreamingResponse(execute_task_async(route, header,original_request_data, app.state.sockets_decode), media_type="text/event-stream") except Exception as e: import sys