mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 20:15:01 +08:00
Validate API tokens in constant time (#25781)
Signed-off-by: rentianyue-jk <rentianyue-jk@360shuke.com> Signed-off-by: Russell Bryant <rbryant@redhat.com> Co-authored-by: rentianyue-jk <rentianyue-jk@360shuke.com> Signed-off-by: simon-mo <simon.mo@hey.com>
This commit is contained in:
parent
bb79c4da2f
commit
ee10d7e6ff
@ -3,12 +3,14 @@
|
||||
|
||||
import asyncio
|
||||
import gc
|
||||
import hashlib
|
||||
import importlib
|
||||
import inspect
|
||||
import json
|
||||
import multiprocessing
|
||||
import multiprocessing.forkserver as forkserver
|
||||
import os
|
||||
import secrets
|
||||
import signal
|
||||
import socket
|
||||
import tempfile
|
||||
@ -1252,7 +1254,7 @@ def load_log_config(log_config_file: Optional[str]) -> Optional[dict]:
|
||||
class AuthenticationMiddleware:
|
||||
"""
|
||||
Pure ASGI middleware that authenticates each request by checking
|
||||
if the Authorization header exists and equals "Bearer {api_key}".
|
||||
if the Authorization Bearer token exists and equals anyof "{api_key}".
|
||||
|
||||
Notes
|
||||
-----
|
||||
@ -1263,7 +1265,26 @@ class AuthenticationMiddleware:
|
||||
|
||||
def __init__(self, app: ASGIApp, tokens: list[str]) -> None:
|
||||
self.app = app
|
||||
self.api_tokens = {f"Bearer {token}" for token in tokens}
|
||||
self.api_tokens = [
|
||||
hashlib.sha256(t.encode("utf-8")).digest() for t in tokens
|
||||
]
|
||||
|
||||
def verify_token(self, headers: Headers) -> bool:
|
||||
authorization_header_value = headers.get("Authorization")
|
||||
if not authorization_header_value:
|
||||
return False
|
||||
|
||||
scheme, _, param = authorization_header_value.partition(" ")
|
||||
if scheme.lower() != "bearer":
|
||||
return False
|
||||
|
||||
param_hash = hashlib.sha256(param.encode("utf-8")).digest()
|
||||
|
||||
token_match = False
|
||||
for token_hash in self.api_tokens:
|
||||
token_match |= secrets.compare_digest(param_hash, token_hash)
|
||||
|
||||
return token_match
|
||||
|
||||
def __call__(self, scope: Scope, receive: Receive,
|
||||
send: Send) -> Awaitable[None]:
|
||||
@ -1276,8 +1297,7 @@ class AuthenticationMiddleware:
|
||||
url_path = URL(scope=scope).path.removeprefix(root_path)
|
||||
headers = Headers(scope=scope)
|
||||
# Type narrow to satisfy mypy.
|
||||
if url_path.startswith("/v1") and headers.get(
|
||||
"Authorization") not in self.api_tokens:
|
||||
if url_path.startswith("/v1") and not self.verify_token(headers):
|
||||
response = JSONResponse(content={"error": "Unauthorized"},
|
||||
status_code=401)
|
||||
return response(scope, receive, send)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user