Validate API tokens in constant time (#25781)

Signed-off-by: rentianyue-jk <rentianyue-jk@360shuke.com>
Signed-off-by: Russell Bryant <rbryant@redhat.com>
Co-authored-by: rentianyue-jk <rentianyue-jk@360shuke.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
This commit is contained in:
Russell Bryant 2025-09-27 06:09:26 -04:00 committed by simon-mo
parent bb79c4da2f
commit ee10d7e6ff

View File

@ -3,12 +3,14 @@
import asyncio import asyncio
import gc import gc
import hashlib
import importlib import importlib
import inspect import inspect
import json import json
import multiprocessing import multiprocessing
import multiprocessing.forkserver as forkserver import multiprocessing.forkserver as forkserver
import os import os
import secrets
import signal import signal
import socket import socket
import tempfile import tempfile
@ -1252,7 +1254,7 @@ def load_log_config(log_config_file: Optional[str]) -> Optional[dict]:
class AuthenticationMiddleware: class AuthenticationMiddleware:
""" """
Pure ASGI middleware that authenticates each request by checking Pure ASGI middleware that authenticates each request by checking
if the Authorization header exists and equals "Bearer {api_key}". if the Authorization Bearer token exists and equals anyof "{api_key}".
Notes Notes
----- -----
@ -1263,7 +1265,26 @@ class AuthenticationMiddleware:
def __init__(self, app: ASGIApp, tokens: list[str]) -> None: def __init__(self, app: ASGIApp, tokens: list[str]) -> None:
self.app = app self.app = app
self.api_tokens = {f"Bearer {token}" for token in tokens} self.api_tokens = [
hashlib.sha256(t.encode("utf-8")).digest() for t in tokens
]
def verify_token(self, headers: Headers) -> bool:
authorization_header_value = headers.get("Authorization")
if not authorization_header_value:
return False
scheme, _, param = authorization_header_value.partition(" ")
if scheme.lower() != "bearer":
return False
param_hash = hashlib.sha256(param.encode("utf-8")).digest()
token_match = False
for token_hash in self.api_tokens:
token_match |= secrets.compare_digest(param_hash, token_hash)
return token_match
def __call__(self, scope: Scope, receive: Receive, def __call__(self, scope: Scope, receive: Receive,
send: Send) -> Awaitable[None]: send: Send) -> Awaitable[None]:
@ -1276,8 +1297,7 @@ class AuthenticationMiddleware:
url_path = URL(scope=scope).path.removeprefix(root_path) url_path = URL(scope=scope).path.removeprefix(root_path)
headers = Headers(scope=scope) headers = Headers(scope=scope)
# Type narrow to satisfy mypy. # Type narrow to satisfy mypy.
if url_path.startswith("/v1") and headers.get( if url_path.startswith("/v1") and not self.verify_token(headers):
"Authorization") not in self.api_tokens:
response = JSONResponse(content={"error": "Unauthorized"}, response = JSONResponse(content={"error": "Unauthorized"},
status_code=401) status_code=401)
return response(scope, receive, send) return response(scope, receive, send)