mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-14 17:15:01 +08:00
[P/D]Provide bucket algorithm rate limiter for proxy_server (#22643)
Signed-off-by: frankie-ys <yongshengwang@cmbchina.com> Signed-off-by: frankie <wangyongsheng686@gmail.com> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com> Co-authored-by: Kuntai Du <kuntai@uchicago.edu>
This commit is contained in:
parent
b2f6c247a9
commit
b2c06509e5
@ -1,63 +1,199 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from quart import Quart, make_response, request
|
from quart import Quart, Response, make_response, request
|
||||||
|
from rate_limiter import RateLimiter
|
||||||
|
from request_queue import RequestQueue
|
||||||
|
|
||||||
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
|
# Configure logging
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _positive_int(text):
    """Argparse ``type=`` callable: parse *text* as an integer greater than zero."""
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid integer value: {text!r}") from None
    if value <= 0:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {value}")
    return value


def _positive_float(text):
    """Argparse ``type=`` callable: parse *text* as a float greater than zero."""
    try:
        value = float(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid float value: {text!r}") from None
    # "not value > 0" (rather than "value <= 0") also rejects NaN.
    if not value > 0:
        raise argparse.ArgumentTypeError(f"must be a positive number, got {text}")
    return value


def parse_args(argv=None):
    """Parse the proxy server's command line arguments.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, which is the original behavior.

    Returns:
        argparse.Namespace with ``timeout``, ``max_concurrent``, ``queue_size``,
        ``rate_limit``, ``port``, ``prefill_url`` and ``decode_url``.

    Raises:
        SystemExit: Raised by argparse on invalid input, including zero or
            negative values for the timeout, concurrency, queue size and
            rate limit options.
    """
    parser = argparse.ArgumentParser(description="vLLM P/D disaggregation proxy server")

    # Limits must be strictly positive: the values are handed to the rate
    # limiter and the request queue, which cannot make progress with zero.
    parser.add_argument(
        "--timeout",
        type=_positive_float,
        default=300,
        help="Timeout for backend service requests in seconds (default: 300)",
    )
    parser.add_argument(
        "--max-concurrent",
        type=_positive_int,
        default=100,
        help="Maximum concurrent requests to backend services (default: 100)",
    )
    parser.add_argument(
        "--queue-size",
        type=_positive_int,
        default=500,
        help="Maximum number of requests in the queue (default: 500)",
    )
    parser.add_argument(
        "--rate-limit",
        type=_positive_int,
        default=40,
        help="Maximum requests per second (default: 40)",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=8000,
        help="Port to run the server on (default: 8000)",
    )
    parser.add_argument(
        "--prefill-url",
        type=str,
        default="http://localhost:8100/v1/completions",
        help="Prefill service endpoint URL",
    )
    parser.add_argument(
        "--decode-url",
        type=str,
        default="http://localhost:8200/v1/completions",
        help="Decode service endpoint URL",
    )

    return parser.parse_args(argv)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Configure and run the P/D disaggregation proxy server.

    Reads the command line options, builds the Quart application together
    with its rate limiter and request queue, registers the
    ``/v1/completions`` route, and then blocks in ``app.run``.
    """
    args = parse_args()

    # Initialize configuration using command line parameters
    AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=args.timeout)
    MAX_CONCURRENT_REQUESTS = args.max_concurrent
    REQUEST_QUEUE_SIZE = args.queue_size
    RATE_LIMIT = args.rate_limit
    PREFILL_SERVICE_URL = args.prefill_url
    DECODE_SERVICE_URL = args.decode_url
    PORT = args.port

    app = Quart(__name__)

    # Initialize the rate limiter and request queue
    rate_limiter = RateLimiter(RATE_LIMIT)
    request_queue = RequestQueue(MAX_CONCURRENT_REQUESTS, REQUEST_QUEUE_SIZE)

    # Attach the configuration object to the application instance
    app.config.update(
        {
            "AIOHTTP_TIMEOUT": AIOHTTP_TIMEOUT,
            "rate_limiter": rate_limiter,
            "request_queue": request_queue,
            "PREFILL_SERVICE_URL": PREFILL_SERVICE_URL,
            "DECODE_SERVICE_URL": DECODE_SERVICE_URL,
        }
    )

    # The event loop only keeps weak references to tasks, so hold a strong
    # reference to the queue worker for as long as the server runs.
    background_tasks = set()

    # Start queue processing on app startup
    @app.before_serving
    async def startup():
        """Start the request queue worker when the app starts serving."""
        worker = asyncio.create_task(request_queue.process())
        background_tasks.add(worker)
        worker.add_done_callback(background_tasks.discard)

    async def forward_request(url, data):
        """Forward a request to a backend service and yield its response bytes.

        Acquires the rate limiter before sending. On a non-200 status, a
        connection error, or a timeout, the problem is logged and a single
        JSON error payload is yielded instead of raising.
        """
        headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}

        # Use rate limiter as context manager
        async with rate_limiter:
            async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
                try:
                    async with session.post(
                        url=url, json=data, headers=headers
                    ) as response:
                        if response.status == 200:
                            # Stream response chunks
                            async for chunk_bytes in response.content.iter_chunked(
                                1024
                            ):
                                yield chunk_bytes
                        else:
                            # Handle backend service errors
                            error_text = await response.text()
                            logger.error(
                                "Backend service error: %s - %s",
                                response.status,
                                error_text,
                            )
                            yield b'{"error": "Backend service error"}'
                except aiohttp.ClientError as e:
                    # Handle connection errors
                    logger.error("Connection error to %s: %s", url, str(e))
                    yield b'{"error": "Service unavailable"}'
                except asyncio.TimeoutError:
                    # Handle timeout errors
                    logger.error("Timeout connecting to %s", url)
                    yield b'{"error": "Service timeout"}'

    async def process_request():
        """Process a single request through the prefill and decode stages."""
        try:
            original_request_data = await request.get_json()

            # Create prefill request (max_tokens=1) so the prefill service
            # only performs the prefill step.
            prefill_request = original_request_data.copy()
            prefill_request["max_tokens"] = 1

            # Execute prefill stage; its output is drained and discarded
            async for _ in forward_request(PREFILL_SERVICE_URL, prefill_request):
                continue

            # Execute decode stage and stream response
            generator = forward_request(DECODE_SERVICE_URL, original_request_data)
            response = await make_response(generator)
            response.timeout = None  # Disable timeout for streaming response

            return response

        except Exception:
            logger.exception("Error processing request")
            return Response(
                response=b'{"error": "Internal server error"}',
                status=500,
                content_type="application/json",
            )

    @app.route("/v1/completions", methods=["POST"])
    async def handle_request():
        """Handle incoming API requests with concurrency and rate limiting."""
        # Create task for request processing
        task = asyncio.create_task(process_request())

        # Enqueue request or reject if queue is full
        if not await request_queue.enqueue(task):
            # The task was already scheduled by create_task; cancel it so a
            # rejected request does not still hit the backend services.
            task.cancel()
            return Response(
                response=b'{"error": "Server busy, try again later"}',
                status=503,
                content_type="application/json",
            )

        try:
            # Return the response from the processing task
            return await task
        except asyncio.CancelledError:
            # Handle task cancellation (timeout or queue full)
            logger.warning("Request cancelled due to timeout or queue full")
            return Response(
                response=b'{"error": "Request cancelled"}',
                status=503,
                content_type="application/json",
            )

    # Start the Quart server. Only the port is configured here; pass
    # host="0.0.0.0" to app.run to listen on all interfaces.
    app.run(port=PORT)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
app.run(port=8000)
|
main()
|
||||||
|
|||||||
45
benchmarks/disagg_benchmarks/rate_limiter.py
Normal file
45
benchmarks/disagg_benchmarks/rate_limiter.py
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
|
class RateLimiter:
    """Per-second request rate limiter usable as an async context manager.

    The limiter hands out at most ``rate_limit`` tokens per one-second
    window. Tokens are refilled all at once (back to ``rate_limit``) the
    first time ``acquire`` runs after the window has elapsed; they are not
    replenished gradually.
    """

    def __init__(self, rate_limit):
        """Create a limiter allowing ``rate_limit`` acquisitions per second.

        Raises:
            ValueError: If ``rate_limit`` is not greater than zero. Such a
                limiter could never hand out a token, so ``acquire`` would
                never return.
        """
        if rate_limit <= 0:
            raise ValueError(f"rate_limit must be positive, got {rate_limit!r}")
        self.rate_limit = rate_limit  # Requests per second
        self.num_available_tokens = rate_limit  # Tokens left in this window
        self.last_refill = time.monotonic()  # Start of the current window
        self.lock = asyncio.Lock()  # Guards the token count and window start

    async def acquire(self):
        """Wait until a token is available, consume it, and return ``True``."""
        while True:
            async with self.lock:
                current_time = time.monotonic()
                elapsed = current_time - self.last_refill

                # Start a new window once a full second has passed.
                if elapsed >= 1.0:
                    self.num_available_tokens = self.rate_limit
                    self.last_refill = current_time
                    elapsed = 0.0

                if self.num_available_tokens > 0:
                    self.num_available_tokens -= 1
                    return True

                # Out of tokens: wait for the remainder of the window.
                wait_time = max(0.0, 1.0 - elapsed)

            # Sleep outside the lock so other coroutines are not blocked
            # on the lock while this one waits.
            await asyncio.sleep(wait_time)

    async def __aenter__(self):
        """Enter the async context manager by acquiring one token."""
        await self.acquire()
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        """Exit the async context manager; tokens are not returned."""
        pass
|
||||||
39
benchmarks/disagg_benchmarks/request_queue.py
Normal file
39
benchmarks/disagg_benchmarks/request_queue.py
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from collections import deque
|
||||||
|
|
||||||
|
|
||||||
|
class RequestQueue:
    """Bounded FIFO queue of request tasks drained by a background worker."""

    def __init__(self, max_concurrent, max_queue_size):
        """Create a queue.

        Args:
            max_concurrent: Initial value of the semaphore taken while an
                item is removed from the queue.
            max_queue_size: Maximum number of queued items; ``enqueue``
                refuses new items once this many are waiting.
        """
        # Maximum concurrent requests
        self.max_concurrent = max_concurrent
        self.max_queue_size = max_queue_size  # Maximum queue size
        # Concurrency control
        self.semaphore = asyncio.Semaphore(max_concurrent)
        self.queue = deque()  # Request queue
        self.queue_size = 0  # Current queue size
        self.lock = asyncio.Lock()  # Sync queue Lock

    async def enqueue(self, task):
        """Add a request task to the queue.

        Returns:
            ``True`` if the task was queued, ``False`` if the queue already
            holds ``max_queue_size`` items.
        """
        async with self.lock:
            if self.queue_size >= self.max_queue_size:
                return False

            self.queue.append(task)
            self.queue_size += 1
            return True

    async def process(self):
        """Drain the queue forever, waiting for each task to finish in turn.

        A queued task that fails or is cancelled must not stop this loop,
        otherwise the queue would fill up and every later ``enqueue`` would
        be refused. The loop therefore only waits for completion; the task's
        result or exception is left for whoever awaits the task itself.
        """
        while True:
            task = None
            if self.queue:
                async with self.semaphore, self.lock:
                    # Re-check under the lock before popping.
                    if self.queue:
                        task = self.queue.popleft()
                        self.queue_size -= 1
            if task is not None:
                # asyncio.wait waits for completion without re-raising the
                # task's exception or cancellation into this loop.
                await asyncio.wait([asyncio.ensure_future(task)])
            await asyncio.sleep(0.01)  # Yield control to event loop
|
||||||
Loading…
x
Reference in New Issue
Block a user