Validate API tokens in constant time (#25781)

Signed-off-by: rentianyue-jk <rentianyue-jk@360shuke.com>
Signed-off-by: Russell Bryant <rbryant@redhat.com>
Co-authored-by: rentianyue-jk <rentianyue-jk@360shuke.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
This commit is contained in:
Russell Bryant 2025-09-27 06:09:26 -04:00 committed by simon-mo
parent bb79c4da2f
commit ee10d7e6ff

View File

@ -3,12 +3,14 @@
import asyncio
import gc
import hashlib
import importlib
import inspect
import json
import multiprocessing
import multiprocessing.forkserver as forkserver
import os
import secrets
import signal
import socket
import tempfile
@ -1252,7 +1254,7 @@ def load_log_config(log_config_file: Optional[str]) -> Optional[dict]:
class AuthenticationMiddleware:
"""
Pure ASGI middleware that authenticates each request by checking
if the Authorization header exists and equals "Bearer {api_key}".
if the Authorization Bearer token exists and equals anyof "{api_key}".
Notes
-----
@ -1263,7 +1265,26 @@ class AuthenticationMiddleware:
def __init__(self, app: ASGIApp, tokens: list[str]) -> None:
self.app = app
self.api_tokens = {f"Bearer {token}" for token in tokens}
self.api_tokens = [
hashlib.sha256(t.encode("utf-8")).digest() for t in tokens
]
def verify_token(self, headers: Headers) -> bool:
authorization_header_value = headers.get("Authorization")
if not authorization_header_value:
return False
scheme, _, param = authorization_header_value.partition(" ")
if scheme.lower() != "bearer":
return False
param_hash = hashlib.sha256(param.encode("utf-8")).digest()
token_match = False
for token_hash in self.api_tokens:
token_match |= secrets.compare_digest(param_hash, token_hash)
return token_match
def __call__(self, scope: Scope, receive: Receive,
send: Send) -> Awaitable[None]:
@ -1276,8 +1297,7 @@ class AuthenticationMiddleware:
url_path = URL(scope=scope).path.removeprefix(root_path)
headers = Headers(scope=scope)
# Type narrow to satisfy mypy.
if url_path.startswith("/v1") and headers.get(
"Authorization") not in self.api_tokens:
if url_path.startswith("/v1") and not self.verify_token(headers):
response = JSONResponse(content={"error": "Unauthorized"},
status_code=401)
return response(scope, receive, send)