# SPDX-License-Identifier: Apache-2.0 import argparse import os import time from contextlib import asynccontextmanager import httpx import numpy as np from fastapi import FastAPI, Request from fastapi.responses import StreamingResponse @asynccontextmanager async def lifespan(app: FastAPI): """ Lifespan context manager to handle startup and shutdown events. """ # Startup: Initialize clients prefiller_base_url = f'http://{global_args.prefiller_host}:{global_args.prefiller_port}/v1' decoder_base_url = f'http://{global_args.decoder_host}:{global_args.decoder_port}/v1' app.state.prefill_client = httpx.AsyncClient(timeout=None, base_url=prefiller_base_url) app.state.decode_client = httpx.AsyncClient(timeout=None, base_url=decoder_base_url) yield # Shutdown: Close clients await app.state.prefill_client.aclose() await app.state.decode_client.aclose() # Update FastAPI app initialization to use lifespan app = FastAPI(lifespan=lifespan) class StatsCalculator: def __init__(self): self._stats = [] self._last_log_time = time.time() def add(self, value): self._stats.append(value) if time.time() - self._last_log_time > 5: self._log_stats() self._last_log_time = time.time() def _log_stats(self): # Print average, median, and 99th percentile np_arr = np.array(self._stats) output_str = f"\nNum requests: {len(self._stats)}" + \ "\nPrefill node TTFT stats:" + \ f"\n - Average (ms): {np.mean(np_arr)}" + \ f"\n - Median (ms): {np.median(np_arr)}" + \ f"\n - 99th Percentile (ms): {np.percentile(np_arr, 99)}\n" print("===============================", output_str, "===============================") stats_calculator = StatsCalculator() counter = 0 def parse_args(): parser = argparse.ArgumentParser() parser.add_argument("--port", type=int, default=8000) parser.add_argument("--host", type=str, default="localhost") parser.add_argument("--prefiller-host", type=str, default="localhost") parser.add_argument("--prefiller-port", type=int, default=8100) parser.add_argument("--decoder-host", type=str, default="localhost") parser.add_argument("--decoder-port", type=int, default=8200) args = parser.parse_args() return args # Initialize variables to hold the persistent clients app.state.prefill_client = None app.state.decode_client = None async def send_request_to_service(client: httpx.AsyncClient, endpoint: str, req_data: dict): """ Send a request to a service using a persistent client. """ req_data = req_data.copy() req_data['max_tokens'] = 1 if 'max_completion_tokens' in req_data: req_data['max_completion_tokens'] = 1 headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} response = await client.post(endpoint, json=req_data, headers=headers) response.raise_for_status() return response async def stream_service_response(client: httpx.AsyncClient, endpoint: str, req_data: dict): """ Asynchronously stream the response from a service using a persistent client. """ headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} async with client.stream("POST", endpoint, json=req_data, headers=headers) as response: response.raise_for_status() async for chunk in response.aiter_bytes(): yield chunk @app.post("/v1/completions") async def handle_completions(request: Request): global counter, stats_calculator counter += 1 st = time.time() try: req_data = await request.json() # Send request to prefill service, ignore the response await send_request_to_service(app.state.prefill_client, "/completions", req_data) et = time.time() stats_calculator.add(et - st) # Stream response from decode service async def generate_stream(): async for chunk in stream_service_response(app.state.decode_client, "/completions", req_data): yield chunk return StreamingResponse(generate_stream(), media_type="application/json") except Exception as e: import sys import traceback exc_info = sys.exc_info() print("Error occurred in disagg prefill proxy server" " - completions endpoint") print(e) print("".join(traceback.format_exception(*exc_info))) raise @app.post("/v1/chat/completions") async def handle_chat_completions(request: Request): global counter, stats_calculator counter += 1 st = time.time() try: req_data = await request.json() # Send request to prefill service, ignore the response await send_request_to_service(app.state.prefill_client, "/chat/completions", req_data) et = time.time() stats_calculator.add(et - st) # Stream response from decode service async def generate_stream(): async for chunk in stream_service_response(app.state.decode_client, "/chat/completions", req_data): yield chunk return StreamingResponse(generate_stream(), media_type="application/json") except Exception as e: import sys import traceback exc_info = sys.exc_info() print("Error occurred in disagg prefill proxy server " " - chat completions endpoint") print(e) print("".join(traceback.format_exception(*exc_info))) raise if __name__ == '__main__': global global_args global_args = parse_args() import uvicorn uvicorn.run(app, host=global_args.host, port=global_args.port)