From b2c06509e58d8afefc1b5fb0f3d91f0cc9d9f279 Mon Sep 17 00:00:00 2001 From: frankie Date: Fri, 15 Aug 2025 15:01:48 +0800 Subject: [PATCH] [P/D]Provide bucket algorithm rate limiter for proxy_server (#22643) Signed-off-by: frankie-ys Signed-off-by: frankie Co-authored-by: Cyrus Leung Co-authored-by: Kuntai Du --- .../disagg_prefill_proxy_server.py | 224 ++++++++++++++---- benchmarks/disagg_benchmarks/rate_limiter.py | 45 ++++ benchmarks/disagg_benchmarks/request_queue.py | 39 +++ 3 files changed, 264 insertions(+), 44 deletions(-) create mode 100644 benchmarks/disagg_benchmarks/rate_limiter.py create mode 100644 benchmarks/disagg_benchmarks/request_queue.py diff --git a/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py b/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py index f62d8102e2d9..904f80534914 100644 --- a/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py +++ b/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py @@ -1,63 +1,199 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import argparse +import asyncio +import logging import os import aiohttp -from quart import Quart, make_response, request +from quart import Quart, Response, make_response, request +from rate_limiter import RateLimiter +from request_queue import RequestQueue -AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) - -app = Quart(__name__) +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) -async def forward_request(url, data): - async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: +def parse_args(): + """parse command line arguments""" + parser = argparse.ArgumentParser(description="vLLM P/D disaggregation proxy server") + + # Add args + parser.add_argument( + "--timeout", + type=float, + default=300, + help="Timeout for backend service requests in seconds (default: 300)", + ) + parser.add_argument( + "--max-concurrent", + type=int, + default=100, + help="Maximum concurrent requests to backend services (default: 100)", + ) + parser.add_argument( + "--queue-size", + type=int, + default=500, + help="Maximum number of requests in the queue (default: 500)", + ) + parser.add_argument( + "--rate-limit", + type=int, + default=40, + help="Maximum requests per second (default: 40)", + ) + parser.add_argument( + "--port", + type=int, + default=8000, + help="Port to run the server on (default: 8000)", + ) + parser.add_argument( + "--prefill-url", + type=str, + default="http://localhost:8100/v1/completions", + help="Prefill service endpoint URL", + ) + parser.add_argument( + "--decode-url", + type=str, + default="http://localhost:8200/v1/completions", + help="Decode service endpoint URL", + ) + + return parser.parse_args() + + +def main(): + """parse command line arguments""" + args = parse_args() + + # Initialize configuration using command line parameters + AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=args.timeout) + MAX_CONCURRENT_REQUESTS = args.max_concurrent + REQUEST_QUEUE_SIZE = args.queue_size + RATE_LIMIT = args.rate_limit + PREFILL_SERVICE_URL = args.prefill_url + DECODE_SERVICE_URL = args.decode_url + PORT = args.port + + app = Quart(__name__) + + # Initialize the rate limiter and request queue + rate_limiter = RateLimiter(RATE_LIMIT) + request_queue = RequestQueue(MAX_CONCURRENT_REQUESTS, REQUEST_QUEUE_SIZE) + + # Attach the configuration object to the application instance + app.config.update( + { + "AIOHTTP_TIMEOUT": AIOHTTP_TIMEOUT, + "rate_limiter": rate_limiter, + "request_queue": request_queue, + "PREFILL_SERVICE_URL": PREFILL_SERVICE_URL, + "DECODE_SERVICE_URL": DECODE_SERVICE_URL, + } + ) + + # Start queue processing on app startup + @app.before_serving + async def startup(): + """Start request processing task when app starts serving""" + asyncio.create_task(request_queue.process()) + + async def forward_request(url, data): + """Forward request to backend service with rate limiting and error handling""" headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} - async with session.post(url=url, json=data, headers=headers) as response: - if response.status == 200: - # if response.headers.get('Transfer-Encoding') == 'chunked': - if True: - async for chunk_bytes in response.content.iter_chunked(1024): - yield chunk_bytes - else: - content = await response.read() - yield content - -@app.route("/v1/completions", methods=["POST"]) -async def handle_request(): - try: - original_request_data = await request.get_json() - - prefill_request = original_request_data.copy() - # change max_tokens = 1 to let it only do prefill - prefill_request["max_tokens"] = 1 - - # finish prefill - async for _ in forward_request( - "http://localhost:8100/v1/completions", prefill_request + # Use rate limiter as context manager + async with ( + rate_limiter, + aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session, ): - continue + try: + async with session.post( + url=url, json=data, headers=headers + ) as response: + if response.status == 200: + # Stream response chunks + async for chunk_bytes in response.content.iter_chunked(1024): + yield chunk_bytes + else: + # Handle backend service errors + error_text = await response.text() + logger.error( + "Backend service error: %s - %s", + response.status, + error_text, + ) + yield b'{"error": "Backend service error"}' + except aiohttp.ClientError as e: + # Handle connection errors + logger.error("Connection error to %s: %s", url, str(e)) + yield b'{"error": "Service unavailable"}' + except asyncio.TimeoutError: + # Handle timeout errors + logger.error("Timeout connecting to %s", url) + yield b'{"error": "Service timeout"}' - # return decode - generator = forward_request( - "http://localhost:8200/v1/completions", original_request_data - ) - response = await make_response(generator) - response.timeout = None + async def process_request(): + """Process a single request through prefill and decode stages""" + try: + original_request_data = await request.get_json() - return response + # Create prefill request (max_tokens=1) + prefill_request = original_request_data.copy() + prefill_request["max_tokens"] = 1 - except Exception as e: - import sys - import traceback + # Execute prefill stage + async for _ in forward_request(PREFILL_SERVICE_URL, prefill_request): + continue - exc_info = sys.exc_info() - print("Error occurred in disagg prefill proxy server") - print(e) - print("".join(traceback.format_exception(*exc_info))) + # Execute decode stage and stream response + generator = forward_request(DECODE_SERVICE_URL, original_request_data) + response = await make_response(generator) + response.timeout = None # Disable timeout for streaming response + return response + + except Exception: + logger.exception("Error processing request") + return Response( + response=b'{"error": "Internal server error"}', + status=500, + content_type="application/json", + ) + + @app.route("/v1/completions", methods=["POST"]) + async def handle_request(): + """Handle incoming API requests with concurrency and rate limiting""" + # Create task for request processing + task = asyncio.create_task(process_request()) + + # Enqueue request or reject if queue is full + if not await request_queue.enqueue(task): + return Response( + response=b'{"error": "Server busy, try again later"}', + status=503, + content_type="application/json", + ) + + try: + # Return the response from the processing task + return await task + except asyncio.CancelledError: + # Handle task cancellation (timeout or queue full) + logger.warning("Request cancelled due to timeout or queue full") + return Response( + response=b'{"error": "Request cancelled"}', + status=503, + content_type="application/json", + ) + + # Start the Quart server with host can be set to 0.0.0.0 + app.run(port=PORT) if __name__ == "__main__": - app.run(port=8000) + main() diff --git a/benchmarks/disagg_benchmarks/rate_limiter.py b/benchmarks/disagg_benchmarks/rate_limiter.py new file mode 100644 index 000000000000..87ac8cb6ab1a --- /dev/null +++ b/benchmarks/disagg_benchmarks/rate_limiter.py @@ -0,0 +1,45 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import asyncio +import time + + +class RateLimiter: + """Token bucket rate limiter implementation""" + + def __init__(self, rate_limit): + self.rate_limit = rate_limit # Requests per second + self.num_available_tokens = rate_limit # Available tokens + self.last_refill = time.monotonic() # Last token refill time + self.lock = asyncio.Lock() # Synchronization lock + + async def acquire(self): + """Acquire a token from the rate limiter""" + while True: + async with self.lock: + current_time = time.monotonic() + elapsed = current_time - self.last_refill + + # Refill num_available_tokens if more than 1 second has passed + if elapsed > 1.0: + self.num_available_tokens = self.rate_limit + self.last_refill = current_time + + # Check if num_available_tokens are available + if self.num_available_tokens > 0: + self.num_available_tokens -= 1 + return True + + # Calculate wait time if no num_available_tokens available + wait_time = 1.0 - elapsed + await asyncio.sleep(wait_time) + + async def __aenter__(self): + """Enter async context manager - acquire token""" + await self.acquire() + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + """Exit async context manager - no cleanup needed""" + pass diff --git a/benchmarks/disagg_benchmarks/request_queue.py b/benchmarks/disagg_benchmarks/request_queue.py new file mode 100644 index 000000000000..410bcb956050 --- /dev/null +++ b/benchmarks/disagg_benchmarks/request_queue.py @@ -0,0 +1,39 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import asyncio +from collections import deque + + +class RequestQueue: + """Request queue manager with concurrency control""" + + def __init__(self, max_concurrent, max_queue_size): + # Maximum concurrent requests + self.max_concurrent = max_concurrent + self.max_queue_size = max_queue_size # Maximum queue size + # Concurrency control + self.semaphore = asyncio.Semaphore(max_concurrent) + self.queue = deque() # Request queue + self.queue_size = 0 # Current queue size + self.lock = asyncio.Lock() # Sync queue Lock + + async def enqueue(self, task): + """Add a request task to the queue""" + async with self.lock: + if self.queue_size >= self.max_queue_size: + return False + + self.queue.append(task) + self.queue_size += 1 + return True + + async def process(self): + """Process queued requests using semaphore for concurrency control""" + while True: + if self.queue: + async with self.semaphore, self.lock: + task = self.queue.popleft() + self.queue_size -= 1 + await task + await asyncio.sleep(0.01) # Yield control to event loop