diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py
index e9d3342ac8a83..5e63a02c97767 100644
--- a/vllm/entrypoints/api_server.py
+++ b/vllm/entrypoints/api_server.py
@@ -2,7 +2,7 @@
 import argparse
 import json
 from typing import AsyncGenerator
 
-from fastapi import BackgroundTasks, FastAPI, Request
+from fastapi import FastAPI, Request
 from fastapi.responses import JSONResponse, Response, StreamingResponse
 import uvicorn
@@ -44,14 +44,8 @@ async def generate(request: Request) -> Response:
             ret = {"text": text_outputs}
             yield (json.dumps(ret) + "\0").encode("utf-8")
 
-    async def abort_request() -> None:
-        await engine.abort(request_id)
-
     if stream:
-        background_tasks = BackgroundTasks()
-        # Abort the request if the client disconnects.
-        background_tasks.add_task(abort_request)
-        return StreamingResponse(stream_results(), background=background_tasks)
+        return StreamingResponse(stream_results())
 
     # Non-streaming case
     final_output = None
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index bc827b3441954..eed3bebb8eb36 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -10,7 +10,7 @@ from typing import AsyncGenerator, Dict, List, Optional, Tuple, Union
 
 import fastapi
 import uvicorn
-from fastapi import BackgroundTasks, Request
+from fastapi import Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, StreamingResponse
@@ -229,9 +229,6 @@ async def create_chat_completion(request: ChatCompletionRequest,
     result_generator = engine.generate(prompt, sampling_params, request_id,
                                        token_ids)
 
-    async def abort_request() -> None:
-        await engine.abort(request_id)
-
     def create_stream_response_json(
             index: int,
             text: str,
@@ -291,19 +288,15 @@ async def create_chat_completion(request: ChatCompletionRequest,
 
     # Streaming response
     if request.stream:
-        background_tasks = BackgroundTasks()
-        # Abort the request if the client disconnects.
-        background_tasks.add_task(abort_request)
         return StreamingResponse(completion_stream_generator(),
-                                 media_type="text/event-stream",
-                                 background=background_tasks)
+                                 media_type="text/event-stream")
 
     # Non-streaming response
     final_res: RequestOutput = None
     async for res in result_generator:
         if await raw_request.is_disconnected():
             # Abort the request if the client disconnects.
-            await abort_request()
+            await engine.abort(request_id)
             return create_error_response(HTTPStatus.BAD_REQUEST,
                                          "Client disconnected")
         final_res = res
@@ -448,9 +441,6 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
                   and (request.best_of is None or request.n == request.best_of)
                   and not request.use_beam_search)
 
-    async def abort_request() -> None:
-        await engine.abort(request_id)
-
     def create_stream_response_json(
             index: int,
             text: str,
@@ -510,19 +500,15 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
 
     # Streaming response
     if stream:
-        background_tasks = BackgroundTasks()
-        # Abort the request if the client disconnects.
-        background_tasks.add_task(abort_request)
         return StreamingResponse(completion_stream_generator(),
-                                 media_type="text/event-stream",
-                                 background=background_tasks)
+                                 media_type="text/event-stream")
 
     # Non-streaming response
     final_res: RequestOutput = None
     async for res in result_generator:
         if await raw_request.is_disconnected():
             # Abort the request if the client disconnects.
-            await abort_request()
+            await engine.abort(request_id)
             return create_error_response(HTTPStatus.BAD_REQUEST,
                                          "Client disconnected")
         final_res = res
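The pattern this diff converges on is worth spelling out. A `BackgroundTasks` hook attached to a response runs only after the response has finished sending, so tying the abort to it meant the abort fired after every stream, completed or not, rather than specifically on a client disconnect. The replacement detects the disconnect where it can actually be observed, by polling `Request.is_disconnected()` between results, and calls `engine.abort(request_id)` inline at that point. Below is a minimal, self-contained sketch of that pattern; `DummyEngine` and its `generate`/`abort` methods are illustrative stand-ins, not vLLM's real `AsyncLLMEngine` API.

```python
import asyncio
import uuid

from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse

app = FastAPI()


class DummyEngine:
    """Hypothetical stand-in for an async generation engine.

    Only the two methods the pattern needs are sketched: an async
    generator yielding partial results, and an abort hook.
    """

    async def generate(self, prompt: str, request_id: str):
        text = ""
        for token in prompt.split():
            await asyncio.sleep(0.1)  # simulate one decoding step
            text += token + " "
            yield text.strip()

    async def abort(self, request_id: str) -> None:
        # A real engine would release scheduler and KV-cache state here.
        print(f"aborted request {request_id}")


engine = DummyEngine()


@app.post("/generate")
async def generate(raw_request: Request) -> JSONResponse:
    body = await raw_request.json()
    request_id = str(uuid.uuid4())

    final_text = None
    async for partial in engine.generate(body["prompt"], request_id):
        # Poll for a disconnect between decoding steps and abort inline,
        # instead of deferring the abort to a BackgroundTasks hook that
        # only runs once the response is already over.
        if await raw_request.is_disconnected():
            await engine.abort(request_id)
            return JSONResponse({"error": "Client disconnected"},
                                status_code=400)
        final_text = partial
    return JSONResponse({"text": final_text})
```

Run with `uvicorn sketch:app` and POST `{"prompt": "..."}` to `/generate`; killing the client mid-request exercises the abort branch. Note that the streaming paths in the diff are left with no HTTP-level abort at all, which only makes sense if disconnect cleanup happens inside the engine's own generator; that is an assumption about the surrounding codebase, not something visible in these hunks.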