[PD] let toy proxy handle /chat/completions (#19730)

Signed-off-by: Linkun <github@lkchen.net>
This commit is contained in:
lkchen 2025-06-25 12:17:45 -07:00 committed by GitHub
parent 8b8c209e35
commit 4734704b30
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -196,8 +196,7 @@ async def stream_service_response(client_info: dict, endpoint: str,
yield chunk yield chunk
@app.post("/v1/completions") async def _handle_completions(api: str, request: Request):
async def handle_completions(request: Request):
try: try:
req_data = await request.json() req_data = await request.json()
request_id = str(uuid.uuid4()) request_id = str(uuid.uuid4())
@ -206,9 +205,8 @@ async def handle_completions(request: Request):
prefill_client_info = get_next_client(request.app, 'prefill') prefill_client_info = get_next_client(request.app, 'prefill')
# Send request to prefill service # Send request to prefill service
response = await send_request_to_service(prefill_client_info, response = await send_request_to_service(prefill_client_info, api,
"/completions", req_data, req_data, request_id)
request_id)
# Extract the needed fields # Extract the needed fields
response_json = response.json() response_json = response.json()
@ -224,7 +222,7 @@ async def handle_completions(request: Request):
# Stream response from decode service # Stream response from decode service
async def generate_stream(): async def generate_stream():
async for chunk in stream_service_response(decode_client_info, async for chunk in stream_service_response(decode_client_info,
"/completions", api,
req_data, req_data,
request_id=request_id): request_id=request_id):
yield chunk yield chunk
@ -237,12 +235,22 @@ async def handle_completions(request: Request):
import traceback import traceback
exc_info = sys.exc_info() exc_info = sys.exc_info()
print("Error occurred in disagg prefill proxy server" print("Error occurred in disagg prefill proxy server"
" - completions endpoint") f" - {api} endpoint")
print(e) print(e)
print("".join(traceback.format_exception(*exc_info))) print("".join(traceback.format_exception(*exc_info)))
raise raise
@app.post("/v1/completions")
async def handle_completions(request: Request):
return await _handle_completions("/completions", request)
@app.post("/v1/chat/completions")
async def handle_chat_completions(request: Request):
return await _handle_completions("/chat/completions", request)
@app.get("/healthcheck") @app.get("/healthcheck")
async def healthcheck(): async def healthcheck():
"""Simple endpoint to check if the server is running.""" """Simple endpoint to check if the server is running."""