diff --git a/vllm/entrypoints/utils.py b/vllm/entrypoints/utils.py index 423b99dbe565c..6c37ce818e6d0 100644 --- a/vllm/entrypoints/utils.py +++ b/vllm/entrypoints/utils.py @@ -33,10 +33,12 @@ async def listen_for_disconnect(request: Request) -> None: while True: message = await request.receive() if message["type"] == "http.disconnect": - if request.app.state.enable_server_load_tracking: - # on timeout/cancellation the BackgroundTask in load_aware_call - # cannot decrement the server load metrics. - # Must be decremented by with_cancellation instead. + # If load tracking is enabled *and* the counter exists, decrement + # it. Combines the previous nested checks into a single condition + # to satisfy the linter rule. + if (getattr(request.app.state, "enable_server_load_tracking", + False) + and hasattr(request.app.state, "server_load_metrics")): request.app.state.server_load_metrics -= 1 break @@ -101,9 +103,14 @@ def load_aware_call(func): raise ValueError( "raw_request required when server load tracking is enabled") - if not raw_request.app.state.enable_server_load_tracking: + if not getattr(raw_request.app.state, "enable_server_load_tracking", + False): return await func(*args, **kwargs) + # ensure the counter exists + if not hasattr(raw_request.app.state, "server_load_metrics"): + raw_request.app.state.server_load_metrics = 0 + raw_request.app.state.server_load_metrics += 1 try: response = await func(*args, **kwargs)