diff --git a/vllm/entrypoints/utils.py b/vllm/entrypoints/utils.py index 60cbb58af3d9a..773f52fa38f88 100644 --- a/vllm/entrypoints/utils.py +++ b/vllm/entrypoints/utils.py @@ -68,13 +68,20 @@ def decrement_server_load(request: Request): def load_aware_call(func): @functools.wraps(func) - async def wrapper(*args, raw_request: Request, **kwargs): + async def wrapper(*args, **kwargs): + raw_request = kwargs.get("raw_request", + args[1] if len(args) > 1 else None) + + if raw_request is None: + raise ValueError( + "raw_request required when server load tracking is enabled") + if not raw_request.app.state.enable_server_load_tracking: - return await func(*args, raw_request=raw_request, **kwargs) + return await func(*args, **kwargs) raw_request.app.state.server_load_metrics += 1 try: - response = await func(*args, raw_request=raw_request, **kwargs) + response = await func(*args, **kwargs) except Exception: raw_request.app.state.server_load_metrics -= 1 raise